From bcf12c8d3e934b479dc9beb69c008fc137d12bcf Mon Sep 17 00:00:00 2001 From: Luna Date: Thu, 30 May 2019 21:34:36 -0300 Subject: [PATCH] guilds: add search filter for readable channels --- litecord/blueprints/guilds.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 0bc7585..9535367 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -17,7 +17,7 @@ along with this program. If not, see . """ -from typing import Optional +from typing import Optional, List from quart import Blueprint, request, current_app as app, jsonify @@ -36,8 +36,8 @@ from ..schemas import ( from .channels import channel_ack from .checks import guild_check, guild_owner_check, guild_perm_check from litecord.utils import to_update - from litecord.errors import BadRequest +from litecord.permissions import get_permissions bp = Blueprint('guilds', __name__) @@ -383,6 +383,20 @@ async def delete_guild_handler(guild_id): return '', 204 +async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]: + """Fetch readable channel IDs.""" + channel_ids = await app.storage.get_channel_ids(guild_id) + res = [] + + for channel_id in channel_ids: + perms = await get_permissions(user_id, channel_id) + + if perms.read_messages: + res.append(channel_id) + + return res + + @bp.route('//messages/search', methods=['GET']) async def search_messages(guild_id): """Search messages in a guild. @@ -394,9 +408,12 @@ async def search_messages(guild_id): j = validate(dict(request.args), SEARCH_CHANNEL) - # main message ids - # TODO: filter only channels where user can - # read messages to prevent leaking + # instead of writing a function in pure sql (which would be + # better/faster for this usecase), consdering that it would be + # hard to write the function in the first place, we generate + # a list of channels the user can read AHEAD of time, then + # use that list on the main search query. + can_read = await fetch_readable_channels(guild_id, user_id) rows = await app.db.fetch(f""" SELECT messages.id, @@ -404,10 +421,11 @@ async def search_messages(guild_id): FROM messages WHERE guild_id = $1 AND messages.content LIKE '%'||$2||'%' + AND ARRAY[messages.channel_id] <@ $4::bigint[] ORDER BY messages.id DESC LIMIT 50 OFFSET $3 - """, guild_id, j['content'], j['offset']) + """, guild_id, j['content'], j['offset'], can_read) results = 0 if not rows else rows[0]['total_results'] main_messages = [r['id'] for r in rows]