diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py
index 1a3d6b5..8e375b1 100644
--- a/litecord/blueprints/channel/messages.py
+++ b/litecord/blueprints/channel/messages.py
@@ -17,65 +17,37 @@ along with this program. If not, see .
"""
-import json
from pathlib import Path
-from PIL import Image
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check, channel_perm_check
from litecord.blueprints.dms import try_dm_state
-from litecord.errors import MessageNotFound, Forbidden, BadRequest
+from litecord.errors import MessageNotFound, Forbidden
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
from litecord.snowflake import get_snowflake
from litecord.schemas import validate, MESSAGE_CREATE
-from litecord.utils import pg_set_json
+from litecord.utils import pg_set_json, query_tuple_from_args, extract_limit
from litecord.permissions import get_permissions
from litecord.embed.sanitizer import fill_embed
from litecord.embed.messages import process_url_embed
-from litecord.blueprints.channel.dm_checks import dm_pre_check
+from litecord.common.channels import dm_pre_check, try_dm_state
from litecord.images import try_unlink
+from litecord.common.messages import (
+ msg_create_request,
+ msg_create_check_content,
+ msg_add_attachment,
+ msg_guild_text_mentions,
+)
log = Logger(__name__)
bp = Blueprint("channel_messages", __name__)
-def extract_limit(request_, default: int = 50, max_val: int = 100):
- """Extract a limit kwarg."""
- try:
- limit = int(request_.args.get("limit", default))
-
- if limit not in range(0, max_val + 1):
- raise ValueError()
- except (TypeError, ValueError):
- raise BadRequest("limit not int")
-
- return limit
-
-
-def query_tuple_from_args(args: dict, limit: int) -> tuple:
- """Extract a 2-tuple out of request arguments."""
- before, after = None, None
-
- if "around" in request.args:
- average = int(limit / 2)
- around = int(args["around"])
-
- after = around - average
- before = around + average
-
- elif "before" in args:
- before = int(args["before"])
- elif "after" in args:
- before = int(args["after"])
-
- return before, after
-
-
@bp.route("//messages", methods=["GET"])
async def get_messages(channel_id):
user_id = await token_check()
@@ -204,185 +176,6 @@ async def create_message(
return message_id
-async def msg_guild_text_mentions(
- payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
-):
- """Calculates mention data side-effects."""
- channel_id = int(payload["channel_id"])
-
- # calculate the user ids we'll bump the mention count for
- uids = set()
-
- # first is extracting user mentions
- for mention in payload["mentions"]:
- uids.add(int(mention["id"]))
-
- # then role mentions
- for role_mention in payload["mention_roles"]:
- role_id = int(role_mention)
- member_ids = await app.storage.get_role_members(role_id)
-
- for member_id in member_ids:
- uids.add(member_id)
-
- # at-here only updates the state
- # for the users that have a state
- # in the channel.
- if mentions_here:
- uids = set()
-
- await app.db.execute(
- """
- UPDATE user_read_state
- SET mention_count = mention_count + 1
- WHERE channel_id = $1
- """,
- channel_id,
- )
-
- # at-here updates the read state
- # for all users, including the ones
- # that might not have read permissions
- # to the channel.
- if mentions_everyone:
- uids = set()
-
- member_ids = await app.storage.get_member_ids(guild_id)
-
- await app.db.executemany(
- """
- UPDATE user_read_state
- SET mention_count = mention_count + 1
- WHERE channel_id = $1 AND user_id = $2
- """,
- [(channel_id, uid) for uid in member_ids],
- )
-
- for user_id in uids:
- await app.db.execute(
- """
- UPDATE user_read_state
- SET mention_count = mention_count + 1
- WHERE user_id = $1
- AND channel_id = $2
- """,
- user_id,
- channel_id,
- )
-
-
-async def msg_create_request() -> tuple:
- """Extract the json input and any file information
- the client gave to us in the request.
-
- This only applies to create message route.
- """
- form = await request.form
- request_json = await request.get_json() or {}
-
- # NOTE: embed isn't set on form data
- json_from_form = {
- "content": form.get("content", ""),
- "nonce": form.get("nonce", "0"),
- "tts": json.loads(form.get("tts", "false")),
- }
-
- payload_json = json.loads(form.get("payload_json", "{}"))
-
- json_from_form.update(request_json)
- json_from_form.update(payload_json)
-
- files = await request.files
-
- # we don't really care about the given fields on the files dict, so
- # we only extract the values
- return json_from_form, [v for k, v in files.items()]
-
-
-def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
- """Check if there is actually any content being sent to us."""
- has_content = bool(payload.get("content", ""))
- has_files = len(files) > 0
-
- embed_field = "embeds" if use_embeds else "embed"
- has_embed = embed_field in payload and payload.get(embed_field) is not None
-
- has_total_content = has_content or has_embed or has_files
-
- if not has_total_content:
- raise BadRequest("No content has been provided.")
-
-
-async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
- """Add an attachment to a message.
-
- Parameters
- ----------
- message_id: int
- The ID of the message getting the attachment.
- channel_id: int
- The ID of the channel the message belongs to.
-
- Exists because the attachment URL scheme contains
- a channel id. The purpose is unknown, but we are
- implementing Discord's behavior.
- attachment_file: quart.FileStorage
- quart FileStorage instance of the file.
- """
-
- attachment_id = get_snowflake()
- filename = attachment_file.filename
-
- # understand file info
- mime = attachment_file.mimetype
- is_image = mime.startswith("image/")
-
- img_width, img_height = None, None
-
- # extract file size
- # TODO: this is probably inneficient
- file_size = attachment_file.stream.getbuffer().nbytes
-
- if is_image:
- # open with pillow, extract image size
- image = Image.open(attachment_file.stream)
- img_width, img_height = image.size
-
- # NOTE: DO NOT close the image, as closing the image will
- # also close the stream.
-
- # reset it to 0 for later usage
- attachment_file.stream.seek(0)
-
- await app.db.execute(
- """
- INSERT INTO attachments
- (id, channel_id, message_id,
- filename, filesize,
- image, width, height)
- VALUES
- ($1, $2, $3, $4, $5, $6, $7, $8)
- """,
- attachment_id,
- channel_id,
- message_id,
- filename,
- file_size,
- is_image,
- img_width,
- img_height,
- )
-
- ext = filename.split(".")[-1]
-
- with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
- attach_file.write(attachment_file.stream.read())
-
- log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
-
- return attachment_id
-
-
async def _spawn_embed(payload, **kwargs):
app.sched.spawn(process_url_embed(payload, **kwargs))