mirror of https://gitlab.com/litecord/litecord.git
messages: use common/utils functions
This commit is contained in:
parent
a67b6580ba
commit
2ebb94f476
|
|
@ -17,65 +17,37 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
from logbook import Logger
|
||||
|
||||
from litecord.blueprints.auth import token_check
|
||||
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||
from litecord.blueprints.dms import try_dm_state
|
||||
from litecord.errors import MessageNotFound, Forbidden, BadRequest
|
||||
from litecord.errors import MessageNotFound, Forbidden
|
||||
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
|
||||
from litecord.snowflake import get_snowflake
|
||||
from litecord.schemas import validate, MESSAGE_CREATE
|
||||
from litecord.utils import pg_set_json
|
||||
from litecord.utils import pg_set_json, query_tuple_from_args, extract_limit
|
||||
from litecord.permissions import get_permissions
|
||||
|
||||
from litecord.embed.sanitizer import fill_embed
|
||||
from litecord.embed.messages import process_url_embed
|
||||
from litecord.blueprints.channel.dm_checks import dm_pre_check
|
||||
from litecord.common.channels import dm_pre_check, try_dm_state
|
||||
from litecord.images import try_unlink
|
||||
from litecord.common.messages import (
|
||||
msg_create_request,
|
||||
msg_create_check_content,
|
||||
msg_add_attachment,
|
||||
msg_guild_text_mentions,
|
||||
)
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint("channel_messages", __name__)
|
||||
|
||||
|
||||
def extract_limit(request_, default: int = 50, max_val: int = 100):
|
||||
"""Extract a limit kwarg."""
|
||||
try:
|
||||
limit = int(request_.args.get("limit", default))
|
||||
|
||||
if limit not in range(0, max_val + 1):
|
||||
raise ValueError()
|
||||
except (TypeError, ValueError):
|
||||
raise BadRequest("limit not int")
|
||||
|
||||
return limit
|
||||
|
||||
|
||||
def query_tuple_from_args(args: dict, limit: int) -> tuple:
|
||||
"""Extract a 2-tuple out of request arguments."""
|
||||
before, after = None, None
|
||||
|
||||
if "around" in request.args:
|
||||
average = int(limit / 2)
|
||||
around = int(args["around"])
|
||||
|
||||
after = around - average
|
||||
before = around + average
|
||||
|
||||
elif "before" in args:
|
||||
before = int(args["before"])
|
||||
elif "after" in args:
|
||||
before = int(args["after"])
|
||||
|
||||
return before, after
|
||||
|
||||
|
||||
@bp.route("/<int:channel_id>/messages", methods=["GET"])
|
||||
async def get_messages(channel_id):
|
||||
user_id = await token_check()
|
||||
|
|
@ -204,185 +176,6 @@ async def create_message(
|
|||
return message_id
|
||||
|
||||
|
||||
async def msg_guild_text_mentions(
|
||||
payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
|
||||
):
|
||||
"""Calculates mention data side-effects."""
|
||||
channel_id = int(payload["channel_id"])
|
||||
|
||||
# calculate the user ids we'll bump the mention count for
|
||||
uids = set()
|
||||
|
||||
# first is extracting user mentions
|
||||
for mention in payload["mentions"]:
|
||||
uids.add(int(mention["id"]))
|
||||
|
||||
# then role mentions
|
||||
for role_mention in payload["mention_roles"]:
|
||||
role_id = int(role_mention)
|
||||
member_ids = await app.storage.get_role_members(role_id)
|
||||
|
||||
for member_id in member_ids:
|
||||
uids.add(member_id)
|
||||
|
||||
# at-here only updates the state
|
||||
# for the users that have a state
|
||||
# in the channel.
|
||||
if mentions_here:
|
||||
uids = set()
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE user_read_state
|
||||
SET mention_count = mention_count + 1
|
||||
WHERE channel_id = $1
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
# at-here updates the read state
|
||||
# for all users, including the ones
|
||||
# that might not have read permissions
|
||||
# to the channel.
|
||||
if mentions_everyone:
|
||||
uids = set()
|
||||
|
||||
member_ids = await app.storage.get_member_ids(guild_id)
|
||||
|
||||
await app.db.executemany(
|
||||
"""
|
||||
UPDATE user_read_state
|
||||
SET mention_count = mention_count + 1
|
||||
WHERE channel_id = $1 AND user_id = $2
|
||||
""",
|
||||
[(channel_id, uid) for uid in member_ids],
|
||||
)
|
||||
|
||||
for user_id in uids:
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE user_read_state
|
||||
SET mention_count = mention_count + 1
|
||||
WHERE user_id = $1
|
||||
AND channel_id = $2
|
||||
""",
|
||||
user_id,
|
||||
channel_id,
|
||||
)
|
||||
|
||||
|
||||
async def msg_create_request() -> tuple:
|
||||
"""Extract the json input and any file information
|
||||
the client gave to us in the request.
|
||||
|
||||
This only applies to create message route.
|
||||
"""
|
||||
form = await request.form
|
||||
request_json = await request.get_json() or {}
|
||||
|
||||
# NOTE: embed isn't set on form data
|
||||
json_from_form = {
|
||||
"content": form.get("content", ""),
|
||||
"nonce": form.get("nonce", "0"),
|
||||
"tts": json.loads(form.get("tts", "false")),
|
||||
}
|
||||
|
||||
payload_json = json.loads(form.get("payload_json", "{}"))
|
||||
|
||||
json_from_form.update(request_json)
|
||||
json_from_form.update(payload_json)
|
||||
|
||||
files = await request.files
|
||||
|
||||
# we don't really care about the given fields on the files dict, so
|
||||
# we only extract the values
|
||||
return json_from_form, [v for k, v in files.items()]
|
||||
|
||||
|
||||
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
|
||||
"""Check if there is actually any content being sent to us."""
|
||||
has_content = bool(payload.get("content", ""))
|
||||
has_files = len(files) > 0
|
||||
|
||||
embed_field = "embeds" if use_embeds else "embed"
|
||||
has_embed = embed_field in payload and payload.get(embed_field) is not None
|
||||
|
||||
has_total_content = has_content or has_embed or has_files
|
||||
|
||||
if not has_total_content:
|
||||
raise BadRequest("No content has been provided.")
|
||||
|
||||
|
||||
async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
|
||||
"""Add an attachment to a message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
message_id: int
|
||||
The ID of the message getting the attachment.
|
||||
channel_id: int
|
||||
The ID of the channel the message belongs to.
|
||||
|
||||
Exists because the attachment URL scheme contains
|
||||
a channel id. The purpose is unknown, but we are
|
||||
implementing Discord's behavior.
|
||||
attachment_file: quart.FileStorage
|
||||
quart FileStorage instance of the file.
|
||||
"""
|
||||
|
||||
attachment_id = get_snowflake()
|
||||
filename = attachment_file.filename
|
||||
|
||||
# understand file info
|
||||
mime = attachment_file.mimetype
|
||||
is_image = mime.startswith("image/")
|
||||
|
||||
img_width, img_height = None, None
|
||||
|
||||
# extract file size
|
||||
# TODO: this is probably inneficient
|
||||
file_size = attachment_file.stream.getbuffer().nbytes
|
||||
|
||||
if is_image:
|
||||
# open with pillow, extract image size
|
||||
image = Image.open(attachment_file.stream)
|
||||
img_width, img_height = image.size
|
||||
|
||||
# NOTE: DO NOT close the image, as closing the image will
|
||||
# also close the stream.
|
||||
|
||||
# reset it to 0 for later usage
|
||||
attachment_file.stream.seek(0)
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO attachments
|
||||
(id, channel_id, message_id,
|
||||
filename, filesize,
|
||||
image, width, height)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
""",
|
||||
attachment_id,
|
||||
channel_id,
|
||||
message_id,
|
||||
filename,
|
||||
file_size,
|
||||
is_image,
|
||||
img_width,
|
||||
img_height,
|
||||
)
|
||||
|
||||
ext = filename.split(".")[-1]
|
||||
|
||||
with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
|
||||
attach_file.write(attachment_file.stream.read())
|
||||
|
||||
log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
|
||||
|
||||
return attachment_id
|
||||
|
||||
|
||||
async def _spawn_embed(payload, **kwargs):
|
||||
app.sched.spawn(process_url_embed(payload, **kwargs))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue