create litecord.common

This commit is contained in:
Luna 2019-10-25 13:31:47 -03:00
parent 71a058e542
commit 1efc65511c
4 changed files with 463 additions and 10 deletions

View File

View File

@ -16,15 +16,55 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from quart import current_app as app from quart import current_app as app
from litecord.errors import Forbidden
from litecord.errors import ForbiddenDM
from litecord.enums import RelationshipType from litecord.enums import RelationshipType
class ForbiddenDM(Forbidden): async def channel_ack(
error_code = 50007 user_id: int, guild_id: int, channel_id: int, message_id: int = None
):
"""ACK a channel."""
if not message_id:
message_id = await app.storage.chan_last_message(channel_id)
await app.db.execute(
"""
INSERT INTO user_read_state
(user_id, channel_id, last_message_id, mention_count)
VALUES
($1, $2, $3, 0)
ON CONFLICT ON CONSTRAINT user_read_state_pkey
DO
UPDATE
SET last_message_id = $3, mention_count = 0
WHERE user_read_state.user_id = $1
AND user_read_state.channel_id = $2
""",
user_id,
channel_id,
message_id,
)
if guild_id:
await app.dispatcher.dispatch_user_guild(
user_id,
guild_id,
"MESSAGE_ACK",
{"message_id": str(message_id), "channel_id": str(channel_id)},
)
else:
# we don't use ChannelDispatcher here because since
# guild_id is None, all user devices are already subscribed
# to the given channel (a dm or a group dm)
await app.dispatcher.dispatch_user(
user_id,
"MESSAGE_ACK",
{"message_id": str(message_id), "channel_id": str(channel_id)},
)
async def dm_pre_check(user_id: int, channel_id: int, peer_id: int): async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
@ -75,3 +115,21 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
# if after this filtering we don't have any more guilds, error # if after this filtering we don't have any more guilds, error
if not mutual_guilds: if not mutual_guilds:
raise ForbiddenDM() raise ForbiddenDM()
async def try_dm_state(user_id: int, dm_id: int):
"""Try inserting the user into the dm state
for the given DM.
Does not do anything if the user is already
in the dm state.
"""
await app.db.execute(
"""
INSERT INTO dm_channel_state (user_id, dm_id)
VALUES ($1, $2)
ON CONFLICT DO NOTHING
""",
user_id,
dm_id,
)

206
litecord/common/guilds.py Normal file
View File

@ -0,0 +1,206 @@
"""
Litecord
Copyright (C) 2018-2019 Luna Mendes
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from quart import current_app as app
from ..snowflake import get_snowflake
from ..permissions import get_role_perms
from ..utils import dict_get, maybe_lazy_guild_dispatch
from ..enums import ChannelType
async def remove_member(guild_id: int, member_id: int):
"""Do common tasks related to deleting a member from the guild,
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
await app.db.execute(
"""
DELETE FROM members
WHERE guild_id = $1 AND user_id = $2
""",
guild_id,
member_id,
)
await app.dispatcher.dispatch_user_guild(
member_id,
guild_id,
"GUILD_DELETE",
{"guild_id": str(guild_id), "unavailable": False},
)
await app.dispatcher.unsub("guild", guild_id, member_id)
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
await app.dispatcher.dispatch_guild(
guild_id,
"GUILD_MEMBER_REMOVE",
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
)
async def remove_member_multi(guild_id: int, members: list):
"""Remove multiple members."""
for member_id in members:
await remove_member(guild_id, member_id)
async def create_role(guild_id, name: str, **kwargs):
"""Create a role in a guild."""
new_role_id = get_snowflake()
everyone_perms = await get_role_perms(guild_id, guild_id)
default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
# update all roles so that we have space for pos 1, but without
# sending GUILD_ROLE_UPDATE for everyone
await app.db.execute(
"""
UPDATE roles
SET
position = position + 1
WHERE guild_id = $1
AND NOT (position = 0)
""",
guild_id,
)
await app.db.execute(
"""
INSERT INTO roles (id, guild_id, name, color,
hoist, position, permissions, managed, mentionable)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
""",
new_role_id,
guild_id,
name,
dict_get(kwargs, "color", 0),
dict_get(kwargs, "hoist", False),
# always set ourselves on position 1
1,
int(dict_get(kwargs, "permissions", default_perms)),
False,
dict_get(kwargs, "mentionable", False),
)
role = await app.storage.get_role(new_role_id, guild_id)
# we need to update the lazy guild handlers for the newly created group
await maybe_lazy_guild_dispatch(guild_id, "new_role", role)
await app.dispatcher.dispatch_guild(
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
)
return role
async def _specific_chan_create(channel_id, ctype, **kwargs):
if ctype == ChannelType.GUILD_TEXT:
await app.db.execute(
"""
INSERT INTO guild_text_channels (id, topic)
VALUES ($1, $2)
""",
channel_id,
kwargs.get("topic", ""),
)
elif ctype == ChannelType.GUILD_VOICE:
await app.db.execute(
"""
INSERT INTO guild_voice_channels (id, bitrate, user_limit)
VALUES ($1, $2, $3)
""",
channel_id,
kwargs.get("bitrate", 64),
kwargs.get("user_limit", 0),
)
async def create_guild_channel(
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
):
"""Create a channel in a guild."""
await app.db.execute(
"""
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
""",
channel_id,
ctype.value,
)
# calc new pos
max_pos = await app.db.fetchval(
"""
SELECT MAX(position)
FROM guild_channels
WHERE guild_id = $1
""",
guild_id,
)
# account for the first channel in a guild too
max_pos = max_pos or 0
# all channels go to guild_channels
await app.db.execute(
"""
INSERT INTO guild_channels (id, guild_id, name, position)
VALUES ($1, $2, $3, $4)
""",
channel_id,
guild_id,
kwargs["name"],
max_pos + 1,
)
# the rest of sql magic is dependant on the channel
# we're creating (a text or voice or category),
# so we use this function.
await _specific_chan_create(channel_id, ctype, **kwargs)
async def delete_guild(guild_id: int):
"""Delete a single guild."""
await app.db.execute(
"""
DELETE FROM guilds
WHERE guilds.id = $1
""",
guild_id,
)
# Discord's client expects IDs being string
await app.dispatcher.dispatch(
"guild",
guild_id,
"GUILD_DELETE",
{
"guild_id": str(guild_id),
"id": str(guild_id),
# 'unavailable': False,
},
)
# remove from the dispatcher so nobody
# becomes the little memer that tries to fuck up with
# everybody's gateway
await app.dispatcher.remove("guild", guild_id)

189
litecord/common/messages.py Normal file
View File

@ -0,0 +1,189 @@
import json
import logging
from PIL import Image
from quart import request, current_app as app
from litecord.errors import BadRequest
from ..snowflake import get_snowflake
log = logging.getLogger(__name__)
async def msg_create_request() -> tuple:
"""Extract the json input and any file information
the client gave to us in the request.
This only applies to create message route.
"""
form = await request.form
request_json = await request.get_json() or {}
# NOTE: embed isn't set on form data
json_from_form = {
"content": form.get("content", ""),
"nonce": form.get("nonce", "0"),
"tts": json.loads(form.get("tts", "false")),
}
payload_json = json.loads(form.get("payload_json", "{}"))
json_from_form.update(request_json)
json_from_form.update(payload_json)
files = await request.files
# we don't really care about the given fields on the files dict, so
# we only extract the values
return json_from_form, [v for k, v in files.items()]
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
"""Check if there is actually any content being sent to us."""
has_content = bool(payload.get("content", ""))
has_files = len(files) > 0
embed_field = "embeds" if use_embeds else "embed"
has_embed = embed_field in payload and payload.get(embed_field) is not None
has_total_content = has_content or has_embed or has_files
if not has_total_content:
raise BadRequest("No content has been provided.")
async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
"""Add an attachment to a message.
Parameters
----------
message_id: int
The ID of the message getting the attachment.
channel_id: int
The ID of the channel the message belongs to.
Exists because the attachment URL scheme contains
a channel id. The purpose is unknown, but we are
implementing Discord's behavior.
attachment_file: quart.FileStorage
quart FileStorage instance of the file.
"""
attachment_id = get_snowflake()
filename = attachment_file.filename
# understand file info
mime = attachment_file.mimetype
is_image = mime.startswith("image/")
img_width, img_height = None, None
# extract file size
# TODO: this is probably inneficient
file_size = attachment_file.stream.getbuffer().nbytes
if is_image:
# open with pillow, extract image size
image = Image.open(attachment_file.stream)
img_width, img_height = image.size
# NOTE: DO NOT close the image, as closing the image will
# also close the stream.
# reset it to 0 for later usage
attachment_file.stream.seek(0)
await app.db.execute(
"""
INSERT INTO attachments
(id, channel_id, message_id,
filename, filesize,
image, width, height)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8)
""",
attachment_id,
channel_id,
message_id,
filename,
file_size,
is_image,
img_width,
img_height,
)
ext = filename.split(".")[-1]
with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
attach_file.write(attachment_file.stream.read())
log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
return attachment_id
async def msg_guild_text_mentions(
payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
):
"""Calculates mention data side-effects."""
channel_id = int(payload["channel_id"])
# calculate the user ids we'll bump the mention count for
uids = set()
# first is extracting user mentions
for mention in payload["mentions"]:
uids.add(int(mention["id"]))
# then role mentions
for role_mention in payload["mention_roles"]:
role_id = int(role_mention)
member_ids = await app.storage.get_role_members(role_id)
for member_id in member_ids:
uids.add(member_id)
# at-here only updates the state
# for the users that have a state
# in the channel.
if mentions_here:
uids = set()
await app.db.execute(
"""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1
""",
channel_id,
)
# at-here updates the read state
# for all users, including the ones
# that might not have read permissions
# to the channel.
if mentions_everyone:
uids = set()
member_ids = await app.storage.get_member_ids(guild_id)
await app.db.executemany(
"""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1 AND user_id = $2
""",
[(channel_id, uid) for uid in member_ids],
)
for user_id in uids:
await app.db.execute(
"""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE user_id = $1
AND channel_id = $2
""",
user_id,
channel_id,
)