mirror of https://gitlab.com/litecord/litecord.git
create litecord.common
This commit is contained in:
parent
71a058e542
commit
1efc65511c
|
|
@ -16,15 +16,55 @@ You should have received a copy of the GNU General Public License
|
|||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
"""
|
||||
|
||||
from quart import current_app as app
|
||||
|
||||
from litecord.errors import Forbidden
|
||||
|
||||
from litecord.errors import ForbiddenDM
|
||||
from litecord.enums import RelationshipType
|
||||
|
||||
|
||||
class ForbiddenDM(Forbidden):
|
||||
error_code = 50007
|
||||
async def channel_ack(
|
||||
user_id: int, guild_id: int, channel_id: int, message_id: int = None
|
||||
):
|
||||
"""ACK a channel."""
|
||||
|
||||
if not message_id:
|
||||
message_id = await app.storage.chan_last_message(channel_id)
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO user_read_state
|
||||
(user_id, channel_id, last_message_id, mention_count)
|
||||
VALUES
|
||||
($1, $2, $3, 0)
|
||||
ON CONFLICT ON CONSTRAINT user_read_state_pkey
|
||||
DO
|
||||
UPDATE
|
||||
SET last_message_id = $3, mention_count = 0
|
||||
WHERE user_read_state.user_id = $1
|
||||
AND user_read_state.channel_id = $2
|
||||
""",
|
||||
user_id,
|
||||
channel_id,
|
||||
message_id,
|
||||
)
|
||||
|
||||
if guild_id:
|
||||
await app.dispatcher.dispatch_user_guild(
|
||||
user_id,
|
||||
guild_id,
|
||||
"MESSAGE_ACK",
|
||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||
)
|
||||
else:
|
||||
# we don't use ChannelDispatcher here because since
|
||||
# guild_id is None, all user devices are already subscribed
|
||||
# to the given channel (a dm or a group dm)
|
||||
await app.dispatcher.dispatch_user(
|
||||
user_id,
|
||||
"MESSAGE_ACK",
|
||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||
)
|
||||
|
||||
|
||||
async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
||||
|
|
@ -32,12 +72,12 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
|||
# first step is checking if there is a block in any direction
|
||||
blockrow = await app.db.fetchrow(
|
||||
"""
|
||||
SELECT rel_type
|
||||
FROM relationships
|
||||
WHERE rel_type = $3
|
||||
AND user_id IN ($1, $2)
|
||||
AND peer_id IN ($1, $2)
|
||||
""",
|
||||
SELECT rel_type
|
||||
FROM relationships
|
||||
WHERE rel_type = $3
|
||||
AND user_id IN ($1, $2)
|
||||
AND peer_id IN ($1, $2)
|
||||
""",
|
||||
user_id,
|
||||
peer_id,
|
||||
RelationshipType.BLOCK.value,
|
||||
|
|
@ -75,3 +115,21 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
|||
# if after this filtering we don't have any more guilds, error
|
||||
if not mutual_guilds:
|
||||
raise ForbiddenDM()
|
||||
|
||||
|
||||
async def try_dm_state(user_id: int, dm_id: int):
|
||||
"""Try inserting the user into the dm state
|
||||
for the given DM.
|
||||
|
||||
Does not do anything if the user is already
|
||||
in the dm state.
|
||||
"""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO dm_channel_state (user_id, dm_id)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT DO NOTHING
|
||||
""",
|
||||
user_id,
|
||||
dm_id,
|
||||
)
|
||||
|
|
@ -0,0 +1,206 @@
|
|||
"""
|
||||
|
||||
Litecord
|
||||
Copyright (C) 2018-2019 Luna Mendes
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, version 3 of the License.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
"""
|
||||
|
||||
from quart import current_app as app
|
||||
|
||||
from ..snowflake import get_snowflake
|
||||
from ..permissions import get_role_perms
|
||||
from ..utils import dict_get, maybe_lazy_guild_dispatch
|
||||
from ..enums import ChannelType
|
||||
|
||||
|
||||
async def remove_member(guild_id: int, member_id: int):
|
||||
"""Do common tasks related to deleting a member from the guild,
|
||||
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM members
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""",
|
||||
guild_id,
|
||||
member_id,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_user_guild(
|
||||
member_id,
|
||||
guild_id,
|
||||
"GUILD_DELETE",
|
||||
{"guild_id": str(guild_id), "unavailable": False},
|
||||
)
|
||||
|
||||
await app.dispatcher.unsub("guild", guild_id, member_id)
|
||||
|
||||
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id,
|
||||
"GUILD_MEMBER_REMOVE",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||
)
|
||||
|
||||
|
||||
async def remove_member_multi(guild_id: int, members: list):
|
||||
"""Remove multiple members."""
|
||||
for member_id in members:
|
||||
await remove_member(guild_id, member_id)
|
||||
|
||||
|
||||
async def create_role(guild_id, name: str, **kwargs):
|
||||
"""Create a role in a guild."""
|
||||
new_role_id = get_snowflake()
|
||||
|
||||
everyone_perms = await get_role_perms(guild_id, guild_id)
|
||||
default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
|
||||
|
||||
# update all roles so that we have space for pos 1, but without
|
||||
# sending GUILD_ROLE_UPDATE for everyone
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE roles
|
||||
SET
|
||||
position = position + 1
|
||||
WHERE guild_id = $1
|
||||
AND NOT (position = 0)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO roles (id, guild_id, name, color,
|
||||
hoist, position, permissions, managed, mentionable)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
""",
|
||||
new_role_id,
|
||||
guild_id,
|
||||
name,
|
||||
dict_get(kwargs, "color", 0),
|
||||
dict_get(kwargs, "hoist", False),
|
||||
# always set ourselves on position 1
|
||||
1,
|
||||
int(dict_get(kwargs, "permissions", default_perms)),
|
||||
False,
|
||||
dict_get(kwargs, "mentionable", False),
|
||||
)
|
||||
|
||||
role = await app.storage.get_role(new_role_id, guild_id)
|
||||
|
||||
# we need to update the lazy guild handlers for the newly created group
|
||||
await maybe_lazy_guild_dispatch(guild_id, "new_role", role)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
|
||||
)
|
||||
|
||||
return role
|
||||
|
||||
|
||||
async def _specific_chan_create(channel_id, ctype, **kwargs):
|
||||
if ctype == ChannelType.GUILD_TEXT:
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO guild_text_channels (id, topic)
|
||||
VALUES ($1, $2)
|
||||
""",
|
||||
channel_id,
|
||||
kwargs.get("topic", ""),
|
||||
)
|
||||
elif ctype == ChannelType.GUILD_VOICE:
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO guild_voice_channels (id, bitrate, user_limit)
|
||||
VALUES ($1, $2, $3)
|
||||
""",
|
||||
channel_id,
|
||||
kwargs.get("bitrate", 64),
|
||||
kwargs.get("user_limit", 0),
|
||||
)
|
||||
|
||||
|
||||
async def create_guild_channel(
|
||||
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
|
||||
):
|
||||
"""Create a channel in a guild."""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO channels (id, channel_type)
|
||||
VALUES ($1, $2)
|
||||
""",
|
||||
channel_id,
|
||||
ctype.value,
|
||||
)
|
||||
|
||||
# calc new pos
|
||||
max_pos = await app.db.fetchval(
|
||||
"""
|
||||
SELECT MAX(position)
|
||||
FROM guild_channels
|
||||
WHERE guild_id = $1
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
# account for the first channel in a guild too
|
||||
max_pos = max_pos or 0
|
||||
|
||||
# all channels go to guild_channels
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO guild_channels (id, guild_id, name, position)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
""",
|
||||
channel_id,
|
||||
guild_id,
|
||||
kwargs["name"],
|
||||
max_pos + 1,
|
||||
)
|
||||
|
||||
# the rest of sql magic is dependant on the channel
|
||||
# we're creating (a text or voice or category),
|
||||
# so we use this function.
|
||||
await _specific_chan_create(channel_id, ctype, **kwargs)
|
||||
|
||||
|
||||
async def delete_guild(guild_id: int):
|
||||
"""Delete a single guild."""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM guilds
|
||||
WHERE guilds.id = $1
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
# Discord's client expects IDs being string
|
||||
await app.dispatcher.dispatch(
|
||||
"guild",
|
||||
guild_id,
|
||||
"GUILD_DELETE",
|
||||
{
|
||||
"guild_id": str(guild_id),
|
||||
"id": str(guild_id),
|
||||
# 'unavailable': False,
|
||||
},
|
||||
)
|
||||
|
||||
# remove from the dispatcher so nobody
|
||||
# becomes the little memer that tries to fuck up with
|
||||
# everybody's gateway
|
||||
await app.dispatcher.remove("guild", guild_id)
|
||||
|
|
@ -0,0 +1,189 @@
|
|||
import json
|
||||
import logging
|
||||
|
||||
from PIL import Image
|
||||
from quart import request, current_app as app
|
||||
|
||||
from litecord.errors import BadRequest
|
||||
from ..snowflake import get_snowflake
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def msg_create_request() -> tuple:
|
||||
"""Extract the json input and any file information
|
||||
the client gave to us in the request.
|
||||
|
||||
This only applies to create message route.
|
||||
"""
|
||||
form = await request.form
|
||||
request_json = await request.get_json() or {}
|
||||
|
||||
# NOTE: embed isn't set on form data
|
||||
json_from_form = {
|
||||
"content": form.get("content", ""),
|
||||
"nonce": form.get("nonce", "0"),
|
||||
"tts": json.loads(form.get("tts", "false")),
|
||||
}
|
||||
|
||||
payload_json = json.loads(form.get("payload_json", "{}"))
|
||||
|
||||
json_from_form.update(request_json)
|
||||
json_from_form.update(payload_json)
|
||||
|
||||
files = await request.files
|
||||
|
||||
# we don't really care about the given fields on the files dict, so
|
||||
# we only extract the values
|
||||
return json_from_form, [v for k, v in files.items()]
|
||||
|
||||
|
||||
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
|
||||
"""Check if there is actually any content being sent to us."""
|
||||
has_content = bool(payload.get("content", ""))
|
||||
has_files = len(files) > 0
|
||||
|
||||
embed_field = "embeds" if use_embeds else "embed"
|
||||
has_embed = embed_field in payload and payload.get(embed_field) is not None
|
||||
|
||||
has_total_content = has_content or has_embed or has_files
|
||||
|
||||
if not has_total_content:
|
||||
raise BadRequest("No content has been provided.")
|
||||
|
||||
|
||||
async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
|
||||
"""Add an attachment to a message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
message_id: int
|
||||
The ID of the message getting the attachment.
|
||||
channel_id: int
|
||||
The ID of the channel the message belongs to.
|
||||
|
||||
Exists because the attachment URL scheme contains
|
||||
a channel id. The purpose is unknown, but we are
|
||||
implementing Discord's behavior.
|
||||
attachment_file: quart.FileStorage
|
||||
quart FileStorage instance of the file.
|
||||
"""
|
||||
|
||||
attachment_id = get_snowflake()
|
||||
filename = attachment_file.filename
|
||||
|
||||
# understand file info
|
||||
mime = attachment_file.mimetype
|
||||
is_image = mime.startswith("image/")
|
||||
|
||||
img_width, img_height = None, None
|
||||
|
||||
# extract file size
|
||||
# TODO: this is probably inneficient
|
||||
file_size = attachment_file.stream.getbuffer().nbytes
|
||||
|
||||
if is_image:
|
||||
# open with pillow, extract image size
|
||||
image = Image.open(attachment_file.stream)
|
||||
img_width, img_height = image.size
|
||||
|
||||
# NOTE: DO NOT close the image, as closing the image will
|
||||
# also close the stream.
|
||||
|
||||
# reset it to 0 for later usage
|
||||
attachment_file.stream.seek(0)
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO attachments
|
||||
(id, channel_id, message_id,
|
||||
filename, filesize,
|
||||
image, width, height)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
""",
|
||||
attachment_id,
|
||||
channel_id,
|
||||
message_id,
|
||||
filename,
|
||||
file_size,
|
||||
is_image,
|
||||
img_width,
|
||||
img_height,
|
||||
)
|
||||
|
||||
ext = filename.split(".")[-1]
|
||||
|
||||
with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
|
||||
attach_file.write(attachment_file.stream.read())
|
||||
|
||||
log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
|
||||
|
||||
return attachment_id
|
||||
|
||||
|
||||
async def msg_guild_text_mentions(
|
||||
payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
|
||||
):
|
||||
"""Calculates mention data side-effects."""
|
||||
channel_id = int(payload["channel_id"])
|
||||
|
||||
# calculate the user ids we'll bump the mention count for
|
||||
uids = set()
|
||||
|
||||
# first is extracting user mentions
|
||||
for mention in payload["mentions"]:
|
||||
uids.add(int(mention["id"]))
|
||||
|
||||
# then role mentions
|
||||
for role_mention in payload["mention_roles"]:
|
||||
role_id = int(role_mention)
|
||||
member_ids = await app.storage.get_role_members(role_id)
|
||||
|
||||
for member_id in member_ids:
|
||||
uids.add(member_id)
|
||||
|
||||
# at-here only updates the state
|
||||
# for the users that have a state
|
||||
# in the channel.
|
||||
if mentions_here:
|
||||
uids = set()
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE user_read_state
|
||||
SET mention_count = mention_count + 1
|
||||
WHERE channel_id = $1
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
# at-here updates the read state
|
||||
# for all users, including the ones
|
||||
# that might not have read permissions
|
||||
# to the channel.
|
||||
if mentions_everyone:
|
||||
uids = set()
|
||||
|
||||
member_ids = await app.storage.get_member_ids(guild_id)
|
||||
|
||||
await app.db.executemany(
|
||||
"""
|
||||
UPDATE user_read_state
|
||||
SET mention_count = mention_count + 1
|
||||
WHERE channel_id = $1 AND user_id = $2
|
||||
""",
|
||||
[(channel_id, uid) for uid in member_ids],
|
||||
)
|
||||
|
||||
for user_id in uids:
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE user_read_state
|
||||
SET mention_count = mention_count + 1
|
||||
WHERE user_id = $1
|
||||
AND channel_id = $2
|
||||
""",
|
||||
user_id,
|
||||
channel_id,
|
||||
)
|
||||
Loading…
Reference in New Issue