diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 2690440..771fbfd 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -35,6 +35,7 @@ from litecord.system_messages import send_sys_message from litecord.blueprints.dm_channels import ( gdm_remove_recipient, gdm_destroy ) +from litecord.utils import search_result_from_list log = Logger(__name__) bp = Blueprint('channels', __name__) @@ -587,35 +588,26 @@ async def _search_channel(channel_id): await channel_check(user_id, channel_id) await channel_perm_check(user_id, channel_id, 'read_messages') - j = validate(request.args, SEARCH_CHANNEL) + j = validate(dict(request.args), SEARCH_CHANNEL) - # main message ids + # main search query + # the context (before/after) columns are copied from the guilds blueprint. rows = await app.db.fetch(f""" - SELECT messages.id, - COUNT(*) OVER() as total_results - FROM messages - WHERE channel_id = $1 AND content LIKE '%'||$3||'%' - ORDER BY messages.id DESC + SELECT orig.id AS current_id, + COUNT(*) OVER() AS total_results, + array((SELECT messages.id AS before_id + FROM messages WHERE messages.id < orig.id + ORDER BY messages.id DESC LIMIT 2)) AS before, + array((SELECT messages.id AS after_id + FROM messages WHERE messages.id > orig.id + ORDER BY messages.id ASC LIMIT 2)) AS after + + FROM messages AS orig + WHERE channel_id = $1 + AND content LIKE '%'||$3||'%' + ORDER BY orig.id DESC LIMIT 50 OFFSET $2 """, channel_id, j['offset'], j['content']) - results = 0 if not rows else rows[0]['total_results'] - main_messages = [r['message_id'] for r in rows] - - # fetch contexts for each message - # (2 messages before, 2 messages after). - - # TODO: actual contexts - res = [] - - for message_id in main_messages: - msg = await app.storage.get_message(message_id) - msg['hit'] = True - res.append([msg]) - - return jsonify({ - 'total_results': results, - 'messages': res, - 'analytics_id': '', - }) + return jsonify(await search_result_from_list(rows)) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 9535367..3c3c7dc 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -35,7 +35,7 @@ from ..schemas import ( ) from .channels import channel_ack from .checks import guild_check, guild_owner_check, guild_perm_check -from litecord.utils import to_update +from litecord.utils import to_update, search_result_from_list from litecord.errors import BadRequest from litecord.permissions import get_permissions @@ -391,7 +391,7 @@ async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]: for channel_id in channel_ids: perms = await get_permissions(user_id, channel_id) - if perms.read_messages: + if perms.bits.read_messages: res.append(channel_id) return res @@ -416,36 +416,25 @@ async def search_messages(guild_id): can_read = await fetch_readable_channels(guild_id, user_id) rows = await app.db.fetch(f""" - SELECT messages.id, - COUNT(*) OVER() as total_results - FROM messages + SELECT orig.id AS current_id, + COUNT(*) OVER() as total_results, + array((SELECT messages.id AS before_id + FROM messages WHERE messages.id < orig.id + ORDER BY messages.id DESC LIMIT 2)) AS before, + array((SELECT messages.id AS after_id + FROM messages WHERE messages.id > orig.id + ORDER BY messages.id ASC LIMIT 2)) AS after + + FROM messages AS orig WHERE guild_id = $1 - AND messages.content LIKE '%'||$2||'%' - AND ARRAY[messages.channel_id] <@ $4::bigint[] - ORDER BY messages.id DESC + AND orig.content LIKE '%'||$2||'%' + AND ARRAY[orig.channel_id] <@ $4::bigint[] + ORDER BY orig.id DESC LIMIT 50 OFFSET $3 """, guild_id, j['content'], j['offset'], can_read) - results = 0 if not rows else rows[0]['total_results'] - main_messages = [r['id'] for r in rows] - - # fetch contexts for each message - # (2 messages before, 2 messages after). - - # TODO: actual contexts - res = [] - - for message_id in main_messages: - msg = await app.storage.get_message(message_id) - msg['hit'] = True - res.append([msg]) - - return jsonify({ - 'total_results': results, - 'messages': res, - 'analytics_id': '', - }) + return jsonify(await search_result_from_list(rows)) @bp.route('//ack', methods=['POST']) diff --git a/litecord/utils.py b/litecord/utils.py index db34b28..6018ba9 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -19,10 +19,11 @@ along with this program. If not, see . import asyncio import json -from typing import Any, Iterable, Optional, Sequence +from typing import Any, Iterable, Optional, Sequence, List, Dict from logbook import Logger from quart.json import JSONEncoder +from quart import current_app as app log = Logger(__name__) @@ -182,3 +183,34 @@ def to_update(j: dict, orig: dict, field: str) -> bool: """Compare values to check if j[field] is actually updating the value in orig[field]. Useful for icon checks.""" return field in j and j[field] and j[field] != orig[field] + + +async def search_result_from_list(rows: List) -> Dict[str, Any]: + """Generate the end result of the search query, given a list of rows. + + Each row must contain: + - A bigint on `current_id` + - An int (?) on `total_results` + - Two bigint[], each on `before` and `after` respectively. + """ + results = 0 if not rows else rows[0]['total_results'] + res = [] + + for row in rows: + before, after = [], [] + + for before_id in reversed(row['before']): + before.append(await app.storage.get_message(before_id)) + + for after_id in row['after']: + after.append(await app.storage.get_message(after_id)) + + msg = await app.storage.get_message(row['current_id']) + msg['hit'] = True + res.append(before + [msg] + after) + + return { + 'total_results': results, + 'messages': res, + 'analytics_id': '', + }