From 262273b6180c82c51e11b4433e78aa8b67cc137b Mon Sep 17 00:00:00 2001 From: Luna Date: Tue, 18 Jun 2019 16:53:36 -0300 Subject: [PATCH] split search result generation into own func on utils --- litecord/blueprints/channels.py | 42 +++++++++++++-------------------- litecord/blueprints/guilds.py | 24 ++----------------- litecord/utils.py | 34 +++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 48 deletions(-) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 2690440..c67e788 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -35,6 +35,7 @@ from litecord.system_messages import send_sys_message from litecord.blueprints.dm_channels import ( gdm_remove_recipient, gdm_destroy ) +from litecord.utils import search_result_from_list log = Logger(__name__) bp = Blueprint('channels', __name__) @@ -589,33 +590,24 @@ async def _search_channel(channel_id): j = validate(request.args, SEARCH_CHANNEL) - # main message ids + # main search query + # the context (before/after) columns are copied from the guilds blueprint. rows = await app.db.fetch(f""" - SELECT messages.id, - COUNT(*) OVER() as total_results - FROM messages - WHERE channel_id = $1 AND content LIKE '%'||$3||'%' - ORDER BY messages.id DESC + SELECT orig.id AS current_id, + COUNT(*) OVER() AS total_results, + array((SELECT messages.id AS before_id + FROM messages WHERE messages.id < orig.id + ORDER BY messages.id DESC LIMIT 2)) AS before, + array((SELECT messages.id AS after_id + FROM messages WHERE messages.id > orig.id + ORDER BY messages.id ASC LIMIT 2)) AS after + + FROM messages AS orig + WHERE channel_id = $1 + AND content LIKE '%'||$3||'%' + ORDER BY orig.id DESC LIMIT 50 OFFSET $2 """, channel_id, j['offset'], j['content']) - results = 0 if not rows else rows[0]['total_results'] - main_messages = [r['message_id'] for r in rows] - - # fetch contexts for each message - # (2 messages before, 2 messages after). - - # TODO: actual contexts - res = [] - - for message_id in main_messages: - msg = await app.storage.get_message(message_id) - msg['hit'] = True - res.append([msg]) - - return jsonify({ - 'total_results': results, - 'messages': res, - 'analytics_id': '', - }) + return jsonify(await search_result_from_list(rows)) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index e3938f0..3c3c7dc 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -35,7 +35,7 @@ from ..schemas import ( ) from .channels import channel_ack from .checks import guild_check, guild_owner_check, guild_perm_check -from litecord.utils import to_update +from litecord.utils import to_update, search_result_from_list from litecord.errors import BadRequest from litecord.permissions import get_permissions @@ -434,27 +434,7 @@ async def search_messages(guild_id): OFFSET $3 """, guild_id, j['content'], j['offset'], can_read) - results = 0 if not rows else rows[0]['total_results'] - res = [] - - for row in rows: - before, after = [], [] - - for before_id in reversed(row['before']): - before.append(await app.storage.get_message(before_id)) - - for after_id in row['after']: - after.append(await app.storage.get_message(after_id)) - - msg = await app.storage.get_message(row['current_id']) - msg['hit'] = True - res.append(before + [msg] + after) - - return jsonify({ - 'total_results': results, - 'messages': res, - 'analytics_id': '', - }) + return jsonify(await search_result_from_list(rows)) @bp.route('//ack', methods=['POST']) diff --git a/litecord/utils.py b/litecord/utils.py index db34b28..6018ba9 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -19,10 +19,11 @@ along with this program. If not, see . import asyncio import json -from typing import Any, Iterable, Optional, Sequence +from typing import Any, Iterable, Optional, Sequence, List, Dict from logbook import Logger from quart.json import JSONEncoder +from quart import current_app as app log = Logger(__name__) @@ -182,3 +183,34 @@ def to_update(j: dict, orig: dict, field: str) -> bool: """Compare values to check if j[field] is actually updating the value in orig[field]. Useful for icon checks.""" return field in j and j[field] and j[field] != orig[field] + + +async def search_result_from_list(rows: List) -> Dict[str, Any]: + """Generate the end result of the search query, given a list of rows. + + Each row must contain: + - A bigint on `current_id` + - An int (?) on `total_results` + - Two bigint[], each on `before` and `after` respectively. + """ + results = 0 if not rows else rows[0]['total_results'] + res = [] + + for row in rows: + before, after = [], [] + + for before_id in reversed(row['before']): + before.append(await app.storage.get_message(before_id)) + + for after_id in row['after']: + after.append(await app.storage.get_message(after_id)) + + msg = await app.storage.get_message(row['current_id']) + msg['hit'] = True + res.append(before + [msg] + after) + + return { + 'total_results': results, + 'messages': res, + 'analytics_id': '', + }