diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 1a3d6b5..8e375b1 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -17,65 +17,37 @@ along with this program. If not, see . """ -import json from pathlib import Path -from PIL import Image from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger from litecord.blueprints.auth import token_check from litecord.blueprints.checks import channel_check, channel_perm_check from litecord.blueprints.dms import try_dm_state -from litecord.errors import MessageNotFound, Forbidden, BadRequest +from litecord.errors import MessageNotFound, Forbidden from litecord.enums import MessageType, ChannelType, GUILD_CHANS from litecord.snowflake import get_snowflake from litecord.schemas import validate, MESSAGE_CREATE -from litecord.utils import pg_set_json +from litecord.utils import pg_set_json, query_tuple_from_args, extract_limit from litecord.permissions import get_permissions from litecord.embed.sanitizer import fill_embed from litecord.embed.messages import process_url_embed -from litecord.blueprints.channel.dm_checks import dm_pre_check +from litecord.common.channels import dm_pre_check, try_dm_state from litecord.images import try_unlink +from litecord.common.messages import ( + msg_create_request, + msg_create_check_content, + msg_add_attachment, + msg_guild_text_mentions, +) log = Logger(__name__) bp = Blueprint("channel_messages", __name__) -def extract_limit(request_, default: int = 50, max_val: int = 100): - """Extract a limit kwarg.""" - try: - limit = int(request_.args.get("limit", default)) - - if limit not in range(0, max_val + 1): - raise ValueError() - except (TypeError, ValueError): - raise BadRequest("limit not int") - - return limit - - -def query_tuple_from_args(args: dict, limit: int) -> tuple: - """Extract a 2-tuple out of request arguments.""" - before, after = None, None - - if "around" in request.args: - average = int(limit / 2) - around = int(args["around"]) - - after = around - average - before = around + average - - elif "before" in args: - before = int(args["before"]) - elif "after" in args: - before = int(args["after"]) - - return before, after - - @bp.route("//messages", methods=["GET"]) async def get_messages(channel_id): user_id = await token_check() @@ -204,185 +176,6 @@ async def create_message( return message_id -async def msg_guild_text_mentions( - payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool -): - """Calculates mention data side-effects.""" - channel_id = int(payload["channel_id"]) - - # calculate the user ids we'll bump the mention count for - uids = set() - - # first is extracting user mentions - for mention in payload["mentions"]: - uids.add(int(mention["id"])) - - # then role mentions - for role_mention in payload["mention_roles"]: - role_id = int(role_mention) - member_ids = await app.storage.get_role_members(role_id) - - for member_id in member_ids: - uids.add(member_id) - - # at-here only updates the state - # for the users that have a state - # in the channel. - if mentions_here: - uids = set() - - await app.db.execute( - """ - UPDATE user_read_state - SET mention_count = mention_count + 1 - WHERE channel_id = $1 - """, - channel_id, - ) - - # at-here updates the read state - # for all users, including the ones - # that might not have read permissions - # to the channel. - if mentions_everyone: - uids = set() - - member_ids = await app.storage.get_member_ids(guild_id) - - await app.db.executemany( - """ - UPDATE user_read_state - SET mention_count = mention_count + 1 - WHERE channel_id = $1 AND user_id = $2 - """, - [(channel_id, uid) for uid in member_ids], - ) - - for user_id in uids: - await app.db.execute( - """ - UPDATE user_read_state - SET mention_count = mention_count + 1 - WHERE user_id = $1 - AND channel_id = $2 - """, - user_id, - channel_id, - ) - - -async def msg_create_request() -> tuple: - """Extract the json input and any file information - the client gave to us in the request. - - This only applies to create message route. - """ - form = await request.form - request_json = await request.get_json() or {} - - # NOTE: embed isn't set on form data - json_from_form = { - "content": form.get("content", ""), - "nonce": form.get("nonce", "0"), - "tts": json.loads(form.get("tts", "false")), - } - - payload_json = json.loads(form.get("payload_json", "{}")) - - json_from_form.update(request_json) - json_from_form.update(payload_json) - - files = await request.files - - # we don't really care about the given fields on the files dict, so - # we only extract the values - return json_from_form, [v for k, v in files.items()] - - -def msg_create_check_content(payload: dict, files: list, *, use_embeds=False): - """Check if there is actually any content being sent to us.""" - has_content = bool(payload.get("content", "")) - has_files = len(files) > 0 - - embed_field = "embeds" if use_embeds else "embed" - has_embed = embed_field in payload and payload.get(embed_field) is not None - - has_total_content = has_content or has_embed or has_files - - if not has_total_content: - raise BadRequest("No content has been provided.") - - -async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int: - """Add an attachment to a message. - - Parameters - ---------- - message_id: int - The ID of the message getting the attachment. - channel_id: int - The ID of the channel the message belongs to. - - Exists because the attachment URL scheme contains - a channel id. The purpose is unknown, but we are - implementing Discord's behavior. - attachment_file: quart.FileStorage - quart FileStorage instance of the file. - """ - - attachment_id = get_snowflake() - filename = attachment_file.filename - - # understand file info - mime = attachment_file.mimetype - is_image = mime.startswith("image/") - - img_width, img_height = None, None - - # extract file size - # TODO: this is probably inneficient - file_size = attachment_file.stream.getbuffer().nbytes - - if is_image: - # open with pillow, extract image size - image = Image.open(attachment_file.stream) - img_width, img_height = image.size - - # NOTE: DO NOT close the image, as closing the image will - # also close the stream. - - # reset it to 0 for later usage - attachment_file.stream.seek(0) - - await app.db.execute( - """ - INSERT INTO attachments - (id, channel_id, message_id, - filename, filesize, - image, width, height) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8) - """, - attachment_id, - channel_id, - message_id, - filename, - file_size, - is_image, - img_width, - img_height, - ) - - ext = filename.split(".")[-1] - - with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file: - attach_file.write(attachment_file.stream.read()) - - log.debug("written {} bytes for attachment id {}", file_size, attachment_id) - - return attachment_id - - async def _spawn_embed(payload, **kwargs): app.sched.spawn(process_url_embed(payload, **kwargs))