diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 2df02da..846d3e3 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -18,7 +18,7 @@ along with this program. If not, see . """ from pathlib import Path -from typing import Optional +from typing import Optional, List from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger @@ -49,6 +49,49 @@ log = Logger(__name__) bp = Blueprint("channel_messages", __name__) +async def message_search( + channel_id: int, + before: Optional[int], + after: Optional[int], + limit: int, +) -> List[int]: + where_clause = "" + if before: + where_clause += f"AND id < {before}" + + if after: + where_clause += f"AND id > {after}" + + return await app.db.fetch( + f""" + SELECT id + FROM messages + WHERE channel_id = $1 {where_clause} + ORDER BY id DESC + LIMIT {limit} + """, + channel_id, + ) + + +async def around_message_search( + channel_id: int, + around_id: int, + limit: int, +) -> List[int]: + # search limit/2 messages BEFORE around_id + # search limit/2 messages AFTER around_id + # merge it all together: before + [around_id] + after + halved_limit = limit // 2 + before_messages = await message_search( + channel_id, before=around_id, after=None, limit=halved_limit + ) + after_messages = await message_search( + channel_id, before=None, after=around_id, limit=halved_limit + ) + return before_messages + [around_id] + after_messages + + @bp.route("//messages", methods=["GET"]) async def get_messages(channel_id): user_id = await token_check() @@ -64,25 +107,15 @@ async def get_messages(channel_id): limit = extract_limit(request, 50) - where_clause = "" - before, after = query_tuple_from_args(request.args, limit) - - if before: - where_clause += f"AND id < {before}" - - if after: - where_clause += f"AND id > {after}" - - message_ids = await app.db.fetch( - f""" - SELECT id - FROM messages - WHERE channel_id = $1 {where_clause} - ORDER BY id DESC - LIMIT {limit} - """, - channel_id, - ) + if "around" in request.args: + message_ids = await around_message_search( + channel_id, int(request.args["around"]), limit + ) + else: + before, after = query_tuple_from_args(request.args, limit) + message_ids = await message_search( + channel_id, before=before, after=after, limit=limit + ) result = [] diff --git a/litecord/utils.py b/litecord/utils.py index b516029..86be8c2 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -26,7 +26,7 @@ from typing import Any, Iterable, Optional, Sequence, List, Dict, Union from logbook import Logger from quart.json import JSONEncoder -from quart import current_app as app, request +from quart import current_app as app from .errors import BadRequest @@ -274,14 +274,7 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple: """Extract a 2-tuple out of request arguments.""" before, after = None, None - if "around" in request.args: - average = int(limit / 2) - around = int(args["around"]) - - after = around - average - before = around + average - - elif "before" in args: + if "before" in args: before = int(args["before"]) elif "after" in args: before = int(args["after"])