diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py
index 2df02da..846d3e3 100644
--- a/litecord/blueprints/channel/messages.py
+++ b/litecord/blueprints/channel/messages.py
@@ -18,7 +18,7 @@ along with this program. If not, see .
"""
from pathlib import Path
-from typing import Optional
+from typing import Optional, List
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
@@ -49,6 +49,49 @@ log = Logger(__name__)
bp = Blueprint("channel_messages", __name__)
+async def message_search(
+ channel_id: int,
+ before: Optional[int],
+ after: Optional[int],
+ limit: int,
+) -> List[int]:
+ where_clause = ""
+ if before:
+ where_clause += f"AND id < {before}"
+
+ if after:
+ where_clause += f"AND id > {after}"
+
+ return await app.db.fetch(
+ f"""
+ SELECT id
+ FROM messages
+ WHERE channel_id = $1 {where_clause}
+ ORDER BY id DESC
+ LIMIT {limit}
+ """,
+ channel_id,
+ )
+
+
+async def around_message_search(
+ channel_id: int,
+ around_id: int,
+ limit: int,
+) -> List[int]:
+ # search limit/2 messages BEFORE around_id
+ # search limit/2 messages AFTER around_id
+ # merge it all together: before + [around_id] + after
+ halved_limit = limit // 2
+ before_messages = await message_search(
+ channel_id, before=around_id, after=None, limit=halved_limit
+ )
+ after_messages = await message_search(
+ channel_id, before=None, after=around_id, limit=halved_limit
+ )
+ return before_messages + [around_id] + after_messages
+
+
@bp.route("//messages", methods=["GET"])
async def get_messages(channel_id):
user_id = await token_check()
@@ -64,25 +107,15 @@ async def get_messages(channel_id):
limit = extract_limit(request, 50)
- where_clause = ""
- before, after = query_tuple_from_args(request.args, limit)
-
- if before:
- where_clause += f"AND id < {before}"
-
- if after:
- where_clause += f"AND id > {after}"
-
- message_ids = await app.db.fetch(
- f"""
- SELECT id
- FROM messages
- WHERE channel_id = $1 {where_clause}
- ORDER BY id DESC
- LIMIT {limit}
- """,
- channel_id,
- )
+ if "around" in request.args:
+ message_ids = await around_message_search(
+ channel_id, int(request.args["around"]), limit
+ )
+ else:
+ before, after = query_tuple_from_args(request.args, limit)
+ message_ids = await message_search(
+ channel_id, before=before, after=after, limit=limit
+ )
result = []
diff --git a/litecord/utils.py b/litecord/utils.py
index b516029..86be8c2 100644
--- a/litecord/utils.py
+++ b/litecord/utils.py
@@ -26,7 +26,7 @@ from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
from logbook import Logger
from quart.json import JSONEncoder
-from quart import current_app as app, request
+from quart import current_app as app
from .errors import BadRequest
@@ -274,14 +274,7 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple:
"""Extract a 2-tuple out of request arguments."""
before, after = None, None
- if "around" in request.args:
- average = int(limit / 2)
- around = int(args["around"])
-
- after = around - average
- before = around + average
-
- elif "before" in args:
+ if "before" in args:
before = int(args["before"])
elif "after" in args:
before = int(args["after"])