diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 6711181..11c18a5 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -18,6 +18,7 @@ along with this program. If not, see . """ import time +import datetime from typing import List, Optional from quart import Blueprint, request, current_app as app, jsonify @@ -25,9 +26,10 @@ from logbook import Logger from litecord.auth import token_check from litecord.enums import ChannelType, GUILD_CHANS, MessageType -from litecord.errors import ChannelNotFound, Forbidden +from litecord.errors import ChannelNotFound, Forbidden, BadRequest from litecord.schemas import ( - validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE + validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE, + BULK_DELETE, ) from litecord.blueprints.checks import channel_check, channel_perm_check @@ -37,6 +39,7 @@ from litecord.blueprints.dm_channels import ( ) from litecord.utils import search_result_from_list from litecord.embed.messages import process_url_embed, msg_update_embeds +from litecord.snowflake import snowflake_datetime log = Logger(__name__) bp = Blueprint('channels', __name__) @@ -664,3 +667,49 @@ async def suppress_embeds(channel_id: int, message_id: int): ) return '', 204 + + +@bp.route('//messages/bulk-delete', methods=['POST']) +async def bulk_delete(channel_id: int): + user_id = await token_check() + ctype, guild_id = await channel_check(user_id, channel_id) + guild_id = guild_id if ctype in GUILD_CHANS else None + + await channel_perm_check(user_id, channel_id, 'manage_messages') + + j = validate(await request.get_json(), BULK_DELETE) + message_ids = set(j['messages']) + + # as per discord behavior, if any id here is older than two weeks, + # we must error. a cuter behavior would be returning the message ids + # that were deleted, ignoring the 2 week+ old ones. + for message_id in message_ids: + message_dt = snowflake_datetime(message_id) + delta = datetime.datetime.utcnow() - message_dt + + if delta.days > 14: + raise BadRequest(50034) + + payload = { + 'guild_id': str(guild_id), + 'channel_id': str(channel_id), + 'ids': list(map(str, message_ids)), + } + + # payload.guild_id is optional in the event, not nullable. + if guild_id is None: + payload.pop('guild_id') + + res = await app.db.execute(""" + DELETE FROM messages + WHERE + channel_id = $1 + AND ARRAY[id] <@ $2::bigint[] + """, channel_id, list(message_ids)) + + if res == 'DELETE 0': + raise BadRequest('No messages were removed') + + await app.dispatcher.dispatch( + 'channel', channel_id, 'MESSAGE_DELETE_BULK', payload) + return '', 204 diff --git a/litecord/errors.py b/litecord/errors.py index 76f937c..c171b15 100644 --- a/litecord/errors.py +++ b/litecord/errors.py @@ -74,18 +74,24 @@ class LitecordError(Exception): """Base class for litecord errors""" status_code = 500 + def _get_err_msg(self, err_code: int) -> str: + if err_code is not None: + return ERR_MSG_MAP.get(err_code) or self.args[0] + + return repr(self) + @property def message(self) -> str: """Get an error's message string.""" try: - return self.args[0] + message = self.args[0] + + if isinstance(message, int): + return self._get_err_msg(message) + + return message except IndexError: - err_code = getattr(self, 'error_code', None) - - if err_code is not None: - return ERR_MSG_MAP.get(err_code) or self.args[0] - - return repr(self) + return self._get_err_msg(getattr(self, 'error_code', None)) @property def json(self): diff --git a/litecord/schemas.py b/litecord/schemas.py index 50e6d09..bc9cce7 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -18,7 +18,7 @@ along with this program. If not, see . """ import re -from typing import Union, Dict, List +from typing import Union, Dict, List, Optional from cerberus import Validator from logbook import Logger @@ -158,7 +158,7 @@ class LitecordValidator(Validator): return isinstance(value, str) and (len(value) < 32) -def validate(reqjson: Union[Dict, List], schema: Dict, +def validate(reqjson: Optional[Union[Dict, List]], schema: Dict, raise_err: bool = True) -> Dict: """Validate the given user-given data against a schema, giving the "correct" version of the document, with all defaults applied. @@ -175,6 +175,9 @@ def validate(reqjson: Union[Dict, List], schema: Dict, """ validator = LitecordValidator(schema) + if reqjson is None: + raise BadRequest('No JSON provided') + try: valid = validator.validate(reqjson) except Exception: @@ -737,3 +740,11 @@ WEBHOOK_MESSAGE_CREATE = { 'schema': {'type': 'dict', 'schema': EMBED_OBJECT} } } + +BULK_DELETE = { + 'messages': { + 'type': 'list', 'required': True, + 'minlength': 2, 'maxlength': 100, + 'schema': {'coerce': int} + } +}