mirror of https://gitlab.com/litecord/litecord.git
guilds: add search filter for readable channels
This commit is contained in:
parent
23c7ac2c34
commit
bcf12c8d3e
|
|
@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
|
||||
|
|
@ -36,8 +36,8 @@ from ..schemas import (
|
|||
from .channels import channel_ack
|
||||
from .checks import guild_check, guild_owner_check, guild_perm_check
|
||||
from litecord.utils import to_update
|
||||
|
||||
from litecord.errors import BadRequest
|
||||
from litecord.permissions import get_permissions
|
||||
|
||||
|
||||
bp = Blueprint('guilds', __name__)
|
||||
|
|
@ -383,6 +383,20 @@ async def delete_guild_handler(guild_id):
|
|||
return '', 204
|
||||
|
||||
|
||||
async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]:
|
||||
"""Fetch readable channel IDs."""
|
||||
channel_ids = await app.storage.get_channel_ids(guild_id)
|
||||
res = []
|
||||
|
||||
for channel_id in channel_ids:
|
||||
perms = await get_permissions(user_id, channel_id)
|
||||
|
||||
if perms.read_messages:
|
||||
res.append(channel_id)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/messages/search', methods=['GET'])
|
||||
async def search_messages(guild_id):
|
||||
"""Search messages in a guild.
|
||||
|
|
@ -394,9 +408,12 @@ async def search_messages(guild_id):
|
|||
|
||||
j = validate(dict(request.args), SEARCH_CHANNEL)
|
||||
|
||||
# main message ids
|
||||
# TODO: filter only channels where user can
|
||||
# read messages to prevent leaking
|
||||
# instead of writing a function in pure sql (which would be
|
||||
# better/faster for this usecase), consdering that it would be
|
||||
# hard to write the function in the first place, we generate
|
||||
# a list of channels the user can read AHEAD of time, then
|
||||
# use that list on the main search query.
|
||||
can_read = await fetch_readable_channels(guild_id, user_id)
|
||||
|
||||
rows = await app.db.fetch(f"""
|
||||
SELECT messages.id,
|
||||
|
|
@ -404,10 +421,11 @@ async def search_messages(guild_id):
|
|||
FROM messages
|
||||
WHERE guild_id = $1
|
||||
AND messages.content LIKE '%'||$2||'%'
|
||||
AND ARRAY[messages.channel_id] <@ $4::bigint[]
|
||||
ORDER BY messages.id DESC
|
||||
LIMIT 50
|
||||
OFFSET $3
|
||||
""", guild_id, j['content'], j['offset'])
|
||||
""", guild_id, j['content'], j['offset'], can_read)
|
||||
|
||||
results = 0 if not rows else rows[0]['total_results']
|
||||
main_messages = [r['id'] for r in rows]
|
||||
|
|
|
|||
Loading…
Reference in New Issue