split search result generation into own func on utils

This commit is contained in:
Luna 2019-06-18 16:53:36 -03:00
parent 2322c2603c
commit 262273b618
3 changed files with 52 additions and 48 deletions

View File

@ -35,6 +35,7 @@ from litecord.system_messages import send_sys_message
from litecord.blueprints.dm_channels import ( from litecord.blueprints.dm_channels import (
gdm_remove_recipient, gdm_destroy gdm_remove_recipient, gdm_destroy
) )
from litecord.utils import search_result_from_list
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('channels', __name__) bp = Blueprint('channels', __name__)
@ -589,33 +590,24 @@ async def _search_channel(channel_id):
j = validate(request.args, SEARCH_CHANNEL) j = validate(request.args, SEARCH_CHANNEL)
# main message ids # main search query
# the context (before/after) columns are copied from the guilds blueprint.
rows = await app.db.fetch(f""" rows = await app.db.fetch(f"""
SELECT messages.id, SELECT orig.id AS current_id,
COUNT(*) OVER() as total_results COUNT(*) OVER() AS total_results,
FROM messages array((SELECT messages.id AS before_id
WHERE channel_id = $1 AND content LIKE '%'||$3||'%' FROM messages WHERE messages.id < orig.id
ORDER BY messages.id DESC ORDER BY messages.id DESC LIMIT 2)) AS before,
array((SELECT messages.id AS after_id
FROM messages WHERE messages.id > orig.id
ORDER BY messages.id ASC LIMIT 2)) AS after
FROM messages AS orig
WHERE channel_id = $1
AND content LIKE '%'||$3||'%'
ORDER BY orig.id DESC
LIMIT 50 LIMIT 50
OFFSET $2 OFFSET $2
""", channel_id, j['offset'], j['content']) """, channel_id, j['offset'], j['content'])
results = 0 if not rows else rows[0]['total_results'] return jsonify(await search_result_from_list(rows))
main_messages = [r['message_id'] for r in rows]
# fetch contexts for each message
# (2 messages before, 2 messages after).
# TODO: actual contexts
res = []
for message_id in main_messages:
msg = await app.storage.get_message(message_id)
msg['hit'] = True
res.append([msg])
return jsonify({
'total_results': results,
'messages': res,
'analytics_id': '',
})

View File

@ -35,7 +35,7 @@ from ..schemas import (
) )
from .channels import channel_ack from .channels import channel_ack
from .checks import guild_check, guild_owner_check, guild_perm_check from .checks import guild_check, guild_owner_check, guild_perm_check
from litecord.utils import to_update from litecord.utils import to_update, search_result_from_list
from litecord.errors import BadRequest from litecord.errors import BadRequest
from litecord.permissions import get_permissions from litecord.permissions import get_permissions
@ -434,27 +434,7 @@ async def search_messages(guild_id):
OFFSET $3 OFFSET $3
""", guild_id, j['content'], j['offset'], can_read) """, guild_id, j['content'], j['offset'], can_read)
results = 0 if not rows else rows[0]['total_results'] return jsonify(await search_result_from_list(rows))
res = []
for row in rows:
before, after = [], []
for before_id in reversed(row['before']):
before.append(await app.storage.get_message(before_id))
for after_id in row['after']:
after.append(await app.storage.get_message(after_id))
msg = await app.storage.get_message(row['current_id'])
msg['hit'] = True
res.append(before + [msg] + after)
return jsonify({
'total_results': results,
'messages': res,
'analytics_id': '',
})
@bp.route('/<int:guild_id>/ack', methods=['POST']) @bp.route('/<int:guild_id>/ack', methods=['POST'])

View File

@ -19,10 +19,11 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio import asyncio
import json import json
from typing import Any, Iterable, Optional, Sequence from typing import Any, Iterable, Optional, Sequence, List, Dict
from logbook import Logger from logbook import Logger
from quart.json import JSONEncoder from quart.json import JSONEncoder
from quart import current_app as app
log = Logger(__name__) log = Logger(__name__)
@ -182,3 +183,34 @@ def to_update(j: dict, orig: dict, field: str) -> bool:
"""Compare values to check if j[field] is actually updating """Compare values to check if j[field] is actually updating
the value in orig[field]. Useful for icon checks.""" the value in orig[field]. Useful for icon checks."""
return field in j and j[field] and j[field] != orig[field] return field in j and j[field] and j[field] != orig[field]
async def search_result_from_list(rows: List) -> Dict[str, Any]:
"""Generate the end result of the search query, given a list of rows.
Each row must contain:
- A bigint on `current_id`
- An int (?) on `total_results`
- Two bigint[], each on `before` and `after` respectively.
"""
results = 0 if not rows else rows[0]['total_results']
res = []
for row in rows:
before, after = [], []
for before_id in reversed(row['before']):
before.append(await app.storage.get_message(before_id))
for after_id in row['after']:
after.append(await app.storage.get_message(after_id))
msg = await app.storage.get_message(row['current_id'])
msg['hit'] = True
res.append(before + [msg] + after)
return {
'total_results': results,
'messages': res,
'analytics_id': '',
}