Merge branch 'bulk-delete' into 'master'

add support for message bulk delete

See merge request litecord/litecord!45
This commit is contained in:
Luna 2019-09-01 18:38:46 +00:00
commit a8f226e2c2
3 changed files with 77 additions and 11 deletions

View File

@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
import time import time
import datetime
from typing import List, Optional from typing import List, Optional
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
@ -25,9 +26,10 @@ from logbook import Logger
from litecord.auth import token_check from litecord.auth import token_check
from litecord.enums import ChannelType, GUILD_CHANS, MessageType from litecord.enums import ChannelType, GUILD_CHANS, MessageType
from litecord.errors import ChannelNotFound, Forbidden from litecord.errors import ChannelNotFound, Forbidden, BadRequest
from litecord.schemas import ( from litecord.schemas import (
validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE,
BULK_DELETE,
) )
from litecord.blueprints.checks import channel_check, channel_perm_check from litecord.blueprints.checks import channel_check, channel_perm_check
@ -37,6 +39,7 @@ from litecord.blueprints.dm_channels import (
) )
from litecord.utils import search_result_from_list from litecord.utils import search_result_from_list
from litecord.embed.messages import process_url_embed, msg_update_embeds from litecord.embed.messages import process_url_embed, msg_update_embeds
from litecord.snowflake import snowflake_datetime
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('channels', __name__) bp = Blueprint('channels', __name__)
@ -664,3 +667,49 @@ async def suppress_embeds(channel_id: int, message_id: int):
) )
return '', 204 return '', 204
@bp.route('/<int:channel_id>/messages/bulk-delete', methods=['POST'])
async def bulk_delete(channel_id: int):
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
guild_id = guild_id if ctype in GUILD_CHANS else None
await channel_perm_check(user_id, channel_id, 'manage_messages')
j = validate(await request.get_json(), BULK_DELETE)
message_ids = set(j['messages'])
# as per discord behavior, if any id here is older than two weeks,
# we must error. a cuter behavior would be returning the message ids
# that were deleted, ignoring the 2 week+ old ones.
for message_id in message_ids:
message_dt = snowflake_datetime(message_id)
delta = datetime.datetime.utcnow() - message_dt
if delta.days > 14:
raise BadRequest(50034)
payload = {
'guild_id': str(guild_id),
'channel_id': str(channel_id),
'ids': list(map(str, message_ids)),
}
# payload.guild_id is optional in the event, not nullable.
if guild_id is None:
payload.pop('guild_id')
res = await app.db.execute("""
DELETE FROM messages
WHERE
channel_id = $1
AND ARRAY[id] <@ $2::bigint[]
""", channel_id, list(message_ids))
if res == 'DELETE 0':
raise BadRequest('No messages were removed')
await app.dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_DELETE_BULK', payload)
return '', 204

View File

@ -74,18 +74,24 @@ class LitecordError(Exception):
"""Base class for litecord errors""" """Base class for litecord errors"""
status_code = 500 status_code = 500
def _get_err_msg(self, err_code: int) -> str:
if err_code is not None:
return ERR_MSG_MAP.get(err_code) or self.args[0]
return repr(self)
@property @property
def message(self) -> str: def message(self) -> str:
"""Get an error's message string.""" """Get an error's message string."""
try: try:
return self.args[0] message = self.args[0]
if isinstance(message, int):
return self._get_err_msg(message)
return message
except IndexError: except IndexError:
err_code = getattr(self, 'error_code', None) return self._get_err_msg(getattr(self, 'error_code', None))
if err_code is not None:
return ERR_MSG_MAP.get(err_code) or self.args[0]
return repr(self)
@property @property
def json(self): def json(self):

View File

@ -18,7 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
import re import re
from typing import Union, Dict, List from typing import Union, Dict, List, Optional
from cerberus import Validator from cerberus import Validator
from logbook import Logger from logbook import Logger
@ -158,7 +158,7 @@ class LitecordValidator(Validator):
return isinstance(value, str) and (len(value) < 32) return isinstance(value, str) and (len(value) < 32)
def validate(reqjson: Union[Dict, List], schema: Dict, def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
raise_err: bool = True) -> Dict: raise_err: bool = True) -> Dict:
"""Validate the given user-given data against a schema, giving the """Validate the given user-given data against a schema, giving the
"correct" version of the document, with all defaults applied. "correct" version of the document, with all defaults applied.
@ -175,6 +175,9 @@ def validate(reqjson: Union[Dict, List], schema: Dict,
""" """
validator = LitecordValidator(schema) validator = LitecordValidator(schema)
if reqjson is None:
raise BadRequest('No JSON provided')
try: try:
valid = validator.validate(reqjson) valid = validator.validate(reqjson)
except Exception: except Exception:
@ -737,3 +740,11 @@ WEBHOOK_MESSAGE_CREATE = {
'schema': {'type': 'dict', 'schema': EMBED_OBJECT} 'schema': {'type': 'dict', 'schema': EMBED_OBJECT}
} }
} }
BULK_DELETE = {
'messages': {
'type': 'list', 'required': True,
'minlength': 2, 'maxlength': 100,
'schema': {'coerce': int}
}
}