diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py
index 6711181..11c18a5 100644
--- a/litecord/blueprints/channels.py
+++ b/litecord/blueprints/channels.py
@@ -18,6 +18,7 @@ along with this program. If not, see .
"""
import time
+import datetime
from typing import List, Optional
from quart import Blueprint, request, current_app as app, jsonify
@@ -25,9 +26,10 @@ from logbook import Logger
from litecord.auth import token_check
from litecord.enums import ChannelType, GUILD_CHANS, MessageType
-from litecord.errors import ChannelNotFound, Forbidden
+from litecord.errors import ChannelNotFound, Forbidden, BadRequest
from litecord.schemas import (
- validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE
+ validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE,
+ BULK_DELETE,
)
from litecord.blueprints.checks import channel_check, channel_perm_check
@@ -37,6 +39,7 @@ from litecord.blueprints.dm_channels import (
)
from litecord.utils import search_result_from_list
from litecord.embed.messages import process_url_embed, msg_update_embeds
+from litecord.snowflake import snowflake_datetime
log = Logger(__name__)
bp = Blueprint('channels', __name__)
@@ -664,3 +667,49 @@ async def suppress_embeds(channel_id: int, message_id: int):
)
return '', 204
+
+
+@bp.route('//messages/bulk-delete', methods=['POST'])
+async def bulk_delete(channel_id: int):
+ user_id = await token_check()
+ ctype, guild_id = await channel_check(user_id, channel_id)
+ guild_id = guild_id if ctype in GUILD_CHANS else None
+
+ await channel_perm_check(user_id, channel_id, 'manage_messages')
+
+ j = validate(await request.get_json(), BULK_DELETE)
+ message_ids = set(j['messages'])
+
+ # as per discord behavior, if any id here is older than two weeks,
+ # we must error. a cuter behavior would be returning the message ids
+ # that were deleted, ignoring the 2 week+ old ones.
+ for message_id in message_ids:
+ message_dt = snowflake_datetime(message_id)
+ delta = datetime.datetime.utcnow() - message_dt
+
+ if delta.days > 14:
+ raise BadRequest(50034)
+
+ payload = {
+ 'guild_id': str(guild_id),
+ 'channel_id': str(channel_id),
+ 'ids': list(map(str, message_ids)),
+ }
+
+ # payload.guild_id is optional in the event, not nullable.
+ if guild_id is None:
+ payload.pop('guild_id')
+
+ res = await app.db.execute("""
+ DELETE FROM messages
+ WHERE
+ channel_id = $1
+ AND ARRAY[id] <@ $2::bigint[]
+ """, channel_id, list(message_ids))
+
+ if res == 'DELETE 0':
+ raise BadRequest('No messages were removed')
+
+ await app.dispatcher.dispatch(
+ 'channel', channel_id, 'MESSAGE_DELETE_BULK', payload)
+ return '', 204
diff --git a/litecord/errors.py b/litecord/errors.py
index 76f937c..c171b15 100644
--- a/litecord/errors.py
+++ b/litecord/errors.py
@@ -74,18 +74,24 @@ class LitecordError(Exception):
"""Base class for litecord errors"""
status_code = 500
+ def _get_err_msg(self, err_code: int) -> str:
+ if err_code is not None:
+ return ERR_MSG_MAP.get(err_code) or self.args[0]
+
+ return repr(self)
+
@property
def message(self) -> str:
"""Get an error's message string."""
try:
- return self.args[0]
+ message = self.args[0]
+
+ if isinstance(message, int):
+ return self._get_err_msg(message)
+
+ return message
except IndexError:
- err_code = getattr(self, 'error_code', None)
-
- if err_code is not None:
- return ERR_MSG_MAP.get(err_code) or self.args[0]
-
- return repr(self)
+ return self._get_err_msg(getattr(self, 'error_code', None))
@property
def json(self):
diff --git a/litecord/schemas.py b/litecord/schemas.py
index 50e6d09..bc9cce7 100644
--- a/litecord/schemas.py
+++ b/litecord/schemas.py
@@ -18,7 +18,7 @@ along with this program. If not, see .
"""
import re
-from typing import Union, Dict, List
+from typing import Union, Dict, List, Optional
from cerberus import Validator
from logbook import Logger
@@ -158,7 +158,7 @@ class LitecordValidator(Validator):
return isinstance(value, str) and (len(value) < 32)
-def validate(reqjson: Union[Dict, List], schema: Dict,
+def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
raise_err: bool = True) -> Dict:
"""Validate the given user-given data against a schema, giving the
"correct" version of the document, with all defaults applied.
@@ -175,6 +175,9 @@ def validate(reqjson: Union[Dict, List], schema: Dict,
"""
validator = LitecordValidator(schema)
+ if reqjson is None:
+ raise BadRequest('No JSON provided')
+
try:
valid = validator.validate(reqjson)
except Exception:
@@ -737,3 +740,11 @@ WEBHOOK_MESSAGE_CREATE = {
'schema': {'type': 'dict', 'schema': EMBED_OBJECT}
}
}
+
+BULK_DELETE = {
+ 'messages': {
+ 'type': 'list', 'required': True,
+ 'minlength': 2, 'maxlength': 100,
+ 'schema': {'coerce': int}
+ }
+}