add proper support for "around" parameter

This commit is contained in:
Luna 2021-08-29 01:17:46 -03:00
parent 0b20ba1283
commit 92dba16237
2 changed files with 55 additions and 29 deletions

View File

@ -18,7 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, List
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger from logbook import Logger
@ -49,6 +49,49 @@ log = Logger(__name__)
bp = Blueprint("channel_messages", __name__) bp = Blueprint("channel_messages", __name__)
async def message_search(
channel_id: int,
before: Optional[int],
after: Optional[int],
limit: int,
) -> List[int]:
where_clause = ""
if before:
where_clause += f"AND id < {before}"
if after:
where_clause += f"AND id > {after}"
return await app.db.fetch(
f"""
SELECT id
FROM messages
WHERE channel_id = $1 {where_clause}
ORDER BY id DESC
LIMIT {limit}
""",
channel_id,
)
async def around_message_search(
channel_id: int,
around_id: int,
limit: int,
) -> List[int]:
# search limit/2 messages BEFORE around_id
# search limit/2 messages AFTER around_id
# merge it all together: before + [around_id] + after
halved_limit = limit // 2
before_messages = await message_search(
channel_id, before=around_id, after=None, limit=halved_limit
)
after_messages = await message_search(
channel_id, before=None, after=around_id, limit=halved_limit
)
return before_messages + [around_id] + after_messages
@bp.route("/<int:channel_id>/messages", methods=["GET"]) @bp.route("/<int:channel_id>/messages", methods=["GET"])
async def get_messages(channel_id): async def get_messages(channel_id):
user_id = await token_check() user_id = await token_check()
@ -64,25 +107,15 @@ async def get_messages(channel_id):
limit = extract_limit(request, 50) limit = extract_limit(request, 50)
where_clause = "" if "around" in request.args:
before, after = query_tuple_from_args(request.args, limit) message_ids = await around_message_search(
channel_id, int(request.args["around"]), limit
if before: )
where_clause += f"AND id < {before}" else:
before, after = query_tuple_from_args(request.args, limit)
if after: message_ids = await message_search(
where_clause += f"AND id > {after}" channel_id, before=before, after=after, limit=limit
)
message_ids = await app.db.fetch(
f"""
SELECT id
FROM messages
WHERE channel_id = $1 {where_clause}
ORDER BY id DESC
LIMIT {limit}
""",
channel_id,
)
result = [] result = []

View File

@ -26,7 +26,7 @@ from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
from logbook import Logger from logbook import Logger
from quart.json import JSONEncoder from quart.json import JSONEncoder
from quart import current_app as app, request from quart import current_app as app
from .errors import BadRequest from .errors import BadRequest
@ -274,14 +274,7 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple:
"""Extract a 2-tuple out of request arguments.""" """Extract a 2-tuple out of request arguments."""
before, after = None, None before, after = None, None
if "around" in request.args: if "before" in args:
average = int(limit / 2)
around = int(args["around"])
after = around - average
before = around + average
elif "before" in args:
before = int(args["before"]) before = int(args["before"])
elif "after" in args: elif "after" in args:
before = int(args["after"]) before = int(args["after"])