Merge branch 'enhance/refactor-and-move-internals' into 'master'

Refactor and move internals

See merge request litecord/litecord!47
This commit is contained in:
Luna 2019-10-25 19:52:31 +00:00
commit 57220179bc
35 changed files with 1097 additions and 1089 deletions

View File

@ -19,17 +19,13 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import base64
import binascii
from random import randint
from typing import Tuple, Optional
import bcrypt
from asyncpg import UniqueViolationError
from itsdangerous import TimestampSigner, BadSignature
from logbook import Logger
from quart import request, current_app as app
from litecord.errors import Forbidden, Unauthorized, BadRequest
from litecord.snowflake import get_snowflake
from litecord.errors import Forbidden, Unauthorized
from litecord.enums import UserFlags
@ -56,11 +52,11 @@ async def raw_token_check(token: str, db=None) -> int:
# just try by fragments instead of
# unpacking
fragments = token.split(".")
user_id = fragments[0]
user_id_str = fragments[0]
try:
user_id = base64.b64decode(user_id.encode())
user_id = int(user_id)
user_id_decoded = base64.b64decode(user_id_str.encode())
user_id = int(user_id_decoded)
except (ValueError, binascii.Error):
raise Unauthorized("Invalid user ID type")
@ -150,105 +146,3 @@ async def hash_data(data: str, loop=None) -> str:
hashed = await loop.run_in_executor(None, bcrypt.hashpw, buf, bcrypt.gensalt(14))
return hashed.decode()
async def check_username_usage(username: str):
"""Raise an error if too many people are with the same username."""
same_username = await app.db.fetchval(
"""
SELECT COUNT(*)
FROM users
WHERE username = $1
""",
username,
)
if same_username > 9000:
raise BadRequest(
"Too many people.",
{
"username": "Too many people used the same username. "
"Please choose another"
},
)
def _raw_discrim() -> str:
discrim_number = randint(1, 9999)
return "%04d" % discrim_number
async def roll_discrim(username: str) -> Optional[str]:
"""Roll a discriminator for a DiscordTag.
Tries to generate one 10 times.
Calls check_username_usage.
"""
# we shouldn't roll discrims for usernames
# that have been used too much.
await check_username_usage(username)
# max 10 times for a reroll
for _ in range(10):
# generate random discrim
discrim = _raw_discrim()
# check if anyone is with it
res = await app.db.fetchval(
"""
SELECT id
FROM users
WHERE username = $1 AND discriminator = $2
""",
username,
discrim,
)
# if no user is found with the (username, discrim)
# pair, then this is unique! return it.
if res is None:
return discrim
return None
async def create_user(username: str, email: str, password: str) -> Tuple[int, str]:
"""Create a single user.
Generates a distriminator and other information. You can fetch the user
data back with :meth:`Storage.get_user`.
"""
db = app.db
loop = app.loop
new_id = get_snowflake()
new_discrim = await roll_discrim(username)
if new_discrim is None:
raise BadRequest(
"Unable to register.",
{"username": "Too many people are with this username."},
)
pwd_hash = await hash_data(password, loop)
try:
await db.execute(
"""
INSERT INTO users
(id, email, username, discriminator, password_hash)
VALUES
($1, $2, $3, $4, $5)
""",
new_id,
email,
username,
new_discrim,
pwd_hash,
)
except UniqueViolationError:
raise BadRequest("Email already used.")
return new_id, pwd_hash

View File

@ -22,7 +22,7 @@ from quart import Blueprint, jsonify, current_app as app, request
from litecord.auth import admin_check
from litecord.schemas import validate
from litecord.admin_schemas import GUILD_UPDATE
from litecord.blueprints.guilds import delete_guild
from litecord.common.guilds import delete_guild
from litecord.errors import GuildNotFound
bp = Blueprint("guilds_admin", __name__)

View File

@ -20,13 +20,17 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from quart import Blueprint, jsonify, current_app as app, request
from litecord.auth import admin_check
from litecord.blueprints.auth import create_user
from litecord.schemas import validate
from litecord.admin_schemas import USER_CREATE, USER_UPDATE
from litecord.errors import BadRequest, Forbidden
from litecord.utils import async_map
from litecord.blueprints.users import delete_user, user_disconnect, mass_user_update
from litecord.enums import UserFlags
from litecord.common.users import (
create_user,
delete_user,
user_disconnect,
mass_user_update,
)
bp = Blueprint("users_admin", __name__)

View File

@ -118,7 +118,7 @@ async def deprecate_region(region):
return "", 204
async def guild_region_check(app_):
async def guild_region_check():
"""Check all guilds for voice region inconsistencies.
Since the voice migration caused all guilds.region columns
@ -126,13 +126,13 @@ async def guild_region_check(app_):
than one region setup.
"""
regions = await app_.storage.all_voice_regions()
regions = await app.storage.all_voice_regions()
if not regions:
log.info("region check: no regions to move guilds to")
return
res = await app_.db.execute(
res = await app.db.execute(
"""
UPDATE guilds
SET region = (

View File

@ -25,7 +25,8 @@ import bcrypt
from quart import Blueprint, jsonify, request, current_app as app
from logbook import Logger
from litecord.auth import token_check, create_user
from litecord.auth import token_check
from litecord.common.users import create_user
from litecord.schemas import validate, REGISTER, REGISTER_WITH_INVITE
from litecord.errors import BadRequest
from litecord.snowflake import get_snowflake
@ -120,7 +121,7 @@ async def _register_with_invite():
)
user_id, pwd_hash = await create_user(
data["username"], data["email"], data["password"], app.db
data["username"], data["email"], data["password"]
)
return jsonify({"token": make_token(user_id, pwd_hash), "user_id": str(user_id)})

View File

@ -17,65 +17,36 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import json
from pathlib import Path
from PIL import Image
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check, channel_perm_check
from litecord.blueprints.dms import try_dm_state
from litecord.errors import MessageNotFound, Forbidden, BadRequest
from litecord.errors import MessageNotFound, Forbidden
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
from litecord.snowflake import get_snowflake
from litecord.schemas import validate, MESSAGE_CREATE
from litecord.utils import pg_set_json
from litecord.utils import pg_set_json, query_tuple_from_args, extract_limit
from litecord.permissions import get_permissions
from litecord.embed.sanitizer import fill_embed
from litecord.embed.messages import process_url_embed
from litecord.blueprints.channel.dm_checks import dm_pre_check
from litecord.common.channels import dm_pre_check, try_dm_state
from litecord.images import try_unlink
from litecord.common.messages import (
msg_create_request,
msg_create_check_content,
msg_add_attachment,
msg_guild_text_mentions,
)
log = Logger(__name__)
bp = Blueprint("channel_messages", __name__)
def extract_limit(request_, default: int = 50, max_val: int = 100):
"""Extract a limit kwarg."""
try:
limit = int(request_.args.get("limit", default))
if limit not in range(0, max_val + 1):
raise ValueError()
except (TypeError, ValueError):
raise BadRequest("limit not int")
return limit
def query_tuple_from_args(args: dict, limit: int) -> tuple:
"""Extract a 2-tuple out of request arguments."""
before, after = None, None
if "around" in request.args:
average = int(limit / 2)
around = int(args["around"])
after = around - average
before = around + average
elif "before" in args:
before = int(args["before"])
elif "after" in args:
before = int(args["after"])
return before, after
@bp.route("/<int:channel_id>/messages", methods=["GET"])
async def get_messages(channel_id):
user_id = await token_check()
@ -204,191 +175,8 @@ async def create_message(
return message_id
async def msg_guild_text_mentions(
payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
):
"""Calculates mention data side-effects."""
channel_id = int(payload["channel_id"])
# calculate the user ids we'll bump the mention count for
uids = set()
# first is extracting user mentions
for mention in payload["mentions"]:
uids.add(int(mention["id"]))
# then role mentions
for role_mention in payload["mention_roles"]:
role_id = int(role_mention)
member_ids = await app.storage.get_role_members(role_id)
for member_id in member_ids:
uids.add(member_id)
# at-here only updates the state
# for the users that have a state
# in the channel.
if mentions_here:
uids = set()
await app.db.execute(
"""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1
""",
channel_id,
)
# at-here updates the read state
# for all users, including the ones
# that might not have read permissions
# to the channel.
if mentions_everyone:
uids = set()
member_ids = await app.storage.get_member_ids(guild_id)
await app.db.executemany(
"""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1 AND user_id = $2
""",
[(channel_id, uid) for uid in member_ids],
)
for user_id in uids:
await app.db.execute(
"""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE user_id = $1
AND channel_id = $2
""",
user_id,
channel_id,
)
async def msg_create_request() -> tuple:
"""Extract the json input and any file information
the client gave to us in the request.
This only applies to create message route.
"""
form = await request.form
request_json = await request.get_json() or {}
# NOTE: embed isn't set on form data
json_from_form = {
"content": form.get("content", ""),
"nonce": form.get("nonce", "0"),
"tts": json.loads(form.get("tts", "false")),
}
payload_json = json.loads(form.get("payload_json", "{}"))
json_from_form.update(request_json)
json_from_form.update(payload_json)
files = await request.files
# we don't really care about the given fields on the files dict, so
# we only extract the values
return json_from_form, [v for k, v in files.items()]
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
"""Check if there is actually any content being sent to us."""
has_content = bool(payload.get("content", ""))
has_files = len(files) > 0
embed_field = "embeds" if use_embeds else "embed"
has_embed = embed_field in payload and payload.get(embed_field) is not None
has_total_content = has_content or has_embed or has_files
if not has_total_content:
raise BadRequest("No content has been provided.")
async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
"""Add an attachment to a message.
Parameters
----------
message_id: int
The ID of the message getting the attachment.
channel_id: int
The ID of the channel the message belongs to.
Exists because the attachment URL scheme contains
a channel id. The purpose is unknown, but we are
implementing Discord's behavior.
attachment_file: quart.FileStorage
quart FileStorage instance of the file.
"""
attachment_id = get_snowflake()
filename = attachment_file.filename
# understand file info
mime = attachment_file.mimetype
is_image = mime.startswith("image/")
img_width, img_height = None, None
# extract file size
# TODO: this is probably inneficient
file_size = attachment_file.stream.getbuffer().nbytes
if is_image:
# open with pillow, extract image size
image = Image.open(attachment_file.stream)
img_width, img_height = image.size
# NOTE: DO NOT close the image, as closing the image will
# also close the stream.
# reset it to 0 for later usage
attachment_file.stream.seek(0)
await app.db.execute(
"""
INSERT INTO attachments
(id, channel_id, message_id,
filename, filesize,
image, width, height)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8)
""",
attachment_id,
channel_id,
message_id,
filename,
file_size,
is_image,
img_width,
img_height,
)
ext = filename.split(".")[-1]
with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
attach_file.write(attachment_file.stream.read())
log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
return attachment_id
async def _spawn_embed(app_, payload, **kwargs):
app_.sched.spawn(
process_url_embed(
app_.config, app_.storage, app_.dispatcher, app_.session, payload, **kwargs
)
)
async def _spawn_embed(payload, **kwargs):
app.sched.spawn(process_url_embed(payload, **kwargs))
@bp.route("/<int:channel_id>/messages", methods=["POST"])
@ -458,7 +246,7 @@ async def _create_message(channel_id):
# spawn url processor for embedding of images
perms = await get_permissions(user_id, channel_id)
if perms.bits.embed_links:
await _spawn_embed(app, payload)
await _spawn_embed(payload)
# update read state for the author
await app.db.execute(
@ -536,7 +324,6 @@ async def edit_message(channel_id, message_id):
perms = await get_permissions(user_id, channel_id)
if perms.bits.embed_links:
await _spawn_embed(
app,
{
"id": message_id,
"channel_id": channel_id,

View File

@ -23,10 +23,9 @@ from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
from litecord.utils import async_map
from litecord.utils import async_map, query_tuple_from_args, extract_limit
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check, channel_perm_check
from litecord.blueprints.channel.messages import query_tuple_from_args, extract_limit
from litecord.enums import GUILD_CHANS
@ -165,7 +164,8 @@ def _emoji_sql_simple(emoji: str, param=4):
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
async def remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str):
async def _remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str):
"""Remove given reaction from a message."""
ctype, guild_id = await channel_check(user_id, channel_id)
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
@ -201,8 +201,7 @@ async def remove_own_reaction(channel_id, message_id, emoji):
"""Remove a reaction."""
user_id = await token_check()
await remove_reaction(channel_id, message_id, user_id, emoji)
await _remove_reaction(channel_id, message_id, user_id, emoji)
return "", 204
@ -212,7 +211,7 @@ async def remove_user_reaction(channel_id, message_id, emoji, other_id):
user_id = await token_check()
await channel_perm_check(user_id, channel_id, "manage_messages")
await remove_reaction(channel_id, message_id, other_id, emoji)
await _remove_reaction(channel_id, message_id, other_id, emoji)
return "", 204

View File

@ -42,6 +42,7 @@ from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
from litecord.utils import search_result_from_list
from litecord.embed.messages import process_url_embed, msg_update_embeds
from litecord.snowflake import snowflake_datetime
from litecord.common.channels import channel_ack
log = Logger(__name__)
bp = Blueprint("channels", __name__)
@ -136,7 +137,7 @@ async def _update_guild_chan_cat(guild_id: int, channel_id: int):
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child)
async def delete_messages(channel_id):
async def _delete_messages(channel_id):
await app.db.execute(
"""
DELETE FROM channel_pins
@ -162,7 +163,7 @@ async def delete_messages(channel_id):
)
async def guild_cleanup(channel_id):
async def _guild_cleanup(channel_id):
await app.db.execute(
"""
DELETE FROM channel_overwrites
@ -220,8 +221,8 @@ async def close_channel(channel_id):
# didn't work on my setup, so I delete
# everything before moving to the main
# channel table deletes
await delete_messages(channel_id)
await guild_cleanup(channel_id)
await _delete_messages(channel_id)
await _guild_cleanup(channel_id)
await app.db.execute(
f"""
@ -595,48 +596,6 @@ async def trigger_typing(channel_id):
return "", 204
async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
"""ACK a channel."""
if not message_id:
message_id = await app.storage.chan_last_message(channel_id)
await app.db.execute(
"""
INSERT INTO user_read_state
(user_id, channel_id, last_message_id, mention_count)
VALUES
($1, $2, $3, 0)
ON CONFLICT ON CONSTRAINT user_read_state_pkey
DO
UPDATE
SET last_message_id = $3, mention_count = 0
WHERE user_read_state.user_id = $1
AND user_read_state.channel_id = $2
""",
user_id,
channel_id,
message_id,
)
if guild_id:
await app.dispatcher.dispatch_user_guild(
user_id,
guild_id,
"MESSAGE_ACK",
{"message_id": str(message_id), "channel_id": str(channel_id)},
)
else:
# we don't use ChannelDispatcher here because since
# guild_id is None, all user devices are already subscribed
# to the given channel (a dm or a group dm)
await app.dispatcher.dispatch_user(
user_id,
"MESSAGE_ACK",
{"message_id": str(message_id), "channel_id": str(channel_id)},
)
@bp.route("/<int:channel_id>/messages/<int:message_id>/ack", methods=["POST"])
async def ack_channel(channel_id, message_id):
"""Acknowledge a channel."""
@ -799,7 +758,7 @@ async def suppress_embeds(channel_id: int, message_id: int):
message["flags"] = message.get("flags", 0) | MessageFlags.suppress_embeds
await msg_update_embeds(message, [], app.storage, app.dispatcher)
await msg_update_embeds(message, [])
elif not suppress and not url_embeds:
# spawn process_url_embed to restore the embeds, if any
await _msg_unset_flags(message_id, MessageFlags.suppress_embeds)
@ -809,11 +768,7 @@ async def suppress_embeds(channel_id: int, message_id: int):
except KeyError:
pass
app.sched.spawn(
process_url_embed(
app.config, app.storage, app.dispatcher, app.session, message
)
)
app.sched.spawn(process_url_embed(message))
return "", 204

View File

@ -31,6 +31,7 @@ from ..snowflake import get_snowflake
from .auth import token_check
from litecord.blueprints.dm_channels import gdm_create, gdm_add_recipient
from litecord.common.channels import try_dm_state
log = Logger(__name__)
bp = Blueprint("dms", __name__)
@ -44,24 +45,6 @@ async def get_dms():
return jsonify(dms)
async def try_dm_state(user_id: int, dm_id: int):
"""Try inserting the user into the dm state
for the given DM.
Does not do anything if the user is already
in the dm state.
"""
await app.db.execute(
"""
INSERT INTO dm_channel_state (user_id, dm_id)
VALUES ($1, $2)
ON CONFLICT DO NOTHING
""",
user_id,
dm_id,
)
async def jsonify_dm(dm_id: int, user_id: int):
dm_chan = await app.storage.get_dm(dm_id, user_id)
return jsonify(dm_chan)

View File

@ -27,77 +27,11 @@ from litecord.blueprints.guild.roles import gen_pairs
from litecord.schemas import validate, ROLE_UPDATE_POSITION, CHAN_CREATE
from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
from litecord.common.guilds import create_guild_channel
bp = Blueprint("guild_channels", __name__)
async def _specific_chan_create(channel_id, ctype, **kwargs):
if ctype == ChannelType.GUILD_TEXT:
await app.db.execute(
"""
INSERT INTO guild_text_channels (id, topic)
VALUES ($1, $2)
""",
channel_id,
kwargs.get("topic", ""),
)
elif ctype == ChannelType.GUILD_VOICE:
await app.db.execute(
"""
INSERT INTO guild_voice_channels (id, bitrate, user_limit)
VALUES ($1, $2, $3)
""",
channel_id,
kwargs.get("bitrate", 64),
kwargs.get("user_limit", 0),
)
async def create_guild_channel(
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
):
"""Create a channel in a guild."""
await app.db.execute(
"""
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
""",
channel_id,
ctype.value,
)
# calc new pos
max_pos = await app.db.fetchval(
"""
SELECT MAX(position)
FROM guild_channels
WHERE guild_id = $1
""",
guild_id,
)
# account for the first channel in a guild too
max_pos = max_pos or 0
# all channels go to guild_channels
await app.db.execute(
"""
INSERT INTO guild_channels (id, guild_id, name, position)
VALUES ($1, $2, $3, $4)
""",
channel_id,
guild_id,
kwargs["name"],
max_pos + 1,
)
# the rest of sql magic is dependant on the channel
# we're creating (a text or voice or category),
# so we use this function.
await _specific_chan_create(channel_id, ctype, **kwargs)
@bp.route("/<int:guild_id>/channels", methods=["GET"])
async def get_guild_channels(guild_id):
"""Get the list of channels in a guild."""

View File

@ -23,47 +23,11 @@ from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import guild_perm_check
from litecord.schemas import validate, GUILD_PRUNE
from litecord.common.guilds import remove_member, remove_member_multi
bp = Blueprint("guild_moderation", __name__)
async def remove_member(guild_id: int, member_id: int):
"""Do common tasks related to deleting a member from the guild,
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
await app.db.execute(
"""
DELETE FROM members
WHERE guild_id = $1 AND user_id = $2
""",
guild_id,
member_id,
)
await app.dispatcher.dispatch_user_guild(
member_id,
guild_id,
"GUILD_DELETE",
{"guild_id": str(guild_id), "unavailable": False},
)
await app.dispatcher.unsub("guild", guild_id, member_id)
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
await app.dispatcher.dispatch_guild(
guild_id,
"GUILD_MEMBER_REMOVE",
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
)
async def remove_member_multi(guild_id: int, members: list):
"""Remove multiple members."""
for member_id in members:
await remove_member(guild_id, member_id)
@bp.route("/<int:guild_id>/members/<int:member_id>", methods=["DELETE"])
async def kick_guild_member(guild_id, member_id):
"""Remove a member from a guild."""
@ -221,6 +185,5 @@ async def begin_guild_prune(guild_id):
days = j["days"]
member_ids = await get_prune(guild_id, days)
app.loop.create_task(remove_member_multi(guild_id, member_ids))
app.sched.spawn(remove_member_multi(guild_id, member_ids))
return jsonify({"pruned": len(member_ids)})

View File

@ -27,11 +27,9 @@ from litecord.auth import token_check
from litecord.blueprints.checks import guild_check, guild_perm_check
from litecord.schemas import validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
from litecord.snowflake import get_snowflake
from litecord.utils import dict_get
from litecord.permissions import get_role_perms
from litecord.utils import maybe_lazy_guild_dispatch
from litecord.common.guilds import create_role
DEFAULT_EVERYONE_PERMS = 104324161
log = Logger(__name__)
bp = Blueprint("guild_roles", __name__)
@ -45,71 +43,6 @@ async def get_guild_roles(guild_id):
return jsonify(await app.storage.get_role_data(guild_id))
async def _maybe_lg(guild_id: int, event: str, role, force: bool = False):
# sometimes we want to dispatch an event
# even if the role isn't hoisted
# an example of such a case is when a role loses
# its hoist status.
# check if is a dict first because role_delete
# only receives the role id.
if isinstance(role, dict) and not role["hoist"] and not force:
return
await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
async def create_role(guild_id, name: str, **kwargs):
"""Create a role in a guild."""
new_role_id = get_snowflake()
everyone_perms = await get_role_perms(guild_id, guild_id)
default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
# update all roles so that we have space for pos 1, but without
# sending GUILD_ROLE_UPDATE for everyone
await app.db.execute(
"""
UPDATE roles
SET
position = position + 1
WHERE guild_id = $1
AND NOT (position = 0)
""",
guild_id,
)
await app.db.execute(
"""
INSERT INTO roles (id, guild_id, name, color,
hoist, position, permissions, managed, mentionable)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
""",
new_role_id,
guild_id,
name,
dict_get(kwargs, "color", 0),
dict_get(kwargs, "hoist", False),
# always set ourselves on position 1
1,
int(dict_get(kwargs, "permissions", default_perms)),
False,
dict_get(kwargs, "mentionable", False),
)
role = await app.storage.get_role(new_role_id, guild_id)
# we need to update the lazy guild handlers for the newly created group
await _maybe_lg(guild_id, "new_role", role)
await app.dispatcher.dispatch_guild(
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
)
return role
@bp.route("/<int:guild_id>/roles", methods=["POST"])
async def create_guild_role(guild_id: int):
"""Add a role to a guild"""
@ -132,7 +65,7 @@ async def _role_update_dispatch(role_id: int, guild_id: int):
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
role = await app.storage.get_role(role_id, guild_id)
await _maybe_lg(guild_id, "role_pos_upd", role)
await maybe_lazy_guild_dispatch(guild_id, "role_pos_upd", role)
await app.dispatcher.dispatch_guild(
guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}
@ -343,7 +276,7 @@ async def update_guild_role(guild_id, role_id):
)
role = await _role_update_dispatch(role_id, guild_id)
await _maybe_lg(guild_id, "role_update", role, True)
await maybe_lazy_guild_dispatch(guild_id, "role_update", role, True)
return jsonify(role)
@ -369,7 +302,7 @@ async def delete_guild_role(guild_id, role_id):
if res == "DELETE 0":
return "", 204
await _maybe_lg(guild_id, "role_delete", role_id, True)
await maybe_lazy_guild_dispatch(guild_id, "role_delete", role_id, True)
await app.dispatcher.dispatch_guild(
guild_id,

View File

@ -21,8 +21,12 @@ from typing import Optional, List
from quart import Blueprint, request, current_app as app, jsonify
from litecord.blueprints.guild.channels import create_guild_channel
from litecord.blueprints.guild.roles import create_role, DEFAULT_EVERYONE_PERMS
from litecord.common.guilds import (
create_role,
create_guild_channel,
delete_guild,
create_guild_settings,
)
from ..auth import token_check
from ..snowflake import get_snowflake
@ -34,44 +38,17 @@ from ..schemas import (
SEARCH_CHANNEL,
VANITY_URL_PATCH,
)
from .channels import channel_ack
from .checks import guild_check, guild_owner_check, guild_perm_check
from ..common.channels import channel_ack
from litecord.utils import to_update, search_result_from_list
from litecord.errors import BadRequest
from litecord.permissions import get_permissions
DEFAULT_EVERYONE_PERMS = 104324161
bp = Blueprint("guilds", __name__)
async def create_guild_settings(guild_id: int, user_id: int):
"""Create guild settings for the user
joining the guild."""
# new guild_settings are based off the currently
# set guild settings (for the guild)
m_notifs = await app.db.fetchval(
"""
SELECT default_message_notifications
FROM guilds
WHERE id = $1
""",
guild_id,
)
await app.db.execute(
"""
INSERT INTO guild_settings
(user_id, guild_id, message_notifications)
VALUES
($1, $2, $3)
""",
user_id,
guild_id,
m_notifs,
)
async def add_member(guild_id: int, user_id: int):
"""Add a user to a guild."""
await app.db.execute(
@ -393,36 +370,6 @@ async def _update_guild(guild_id):
return jsonify(guild)
async def delete_guild(guild_id: int, *, app_=None):
"""Delete a single guild."""
app_ = app_ or app
await app_.db.execute(
"""
DELETE FROM guilds
WHERE guilds.id = $1
""",
guild_id,
)
# Discord's client expects IDs being string
await app_.dispatcher.dispatch(
"guild",
guild_id,
"GUILD_DELETE",
{
"guild_id": str(guild_id),
"id": str(guild_id),
# 'unavailable': False,
},
)
# remove from the dispatcher so nobody
# becomes the little memer that tries to fuck up with
# everybody's gateway
await app_.dispatcher.remove("guild", guild_id)
@bp.route("/<int:guild_id>", methods=["DELETE"])
# this endpoint is not documented, but used by the official client.
@bp.route("/<int:guild_id>/delete", methods=["POST"])

View File

@ -64,7 +64,7 @@ async def _get_default_user_avatar(default_id: int):
async def _handle_webhook_avatar(md_url_redir: str):
md_url = make_md_req_url(app.config, "img", EmbedURL(md_url_redir))
md_url = make_md_req_url("img", EmbedURL(md_url_redir))
return redirect(md_url)

View File

@ -28,7 +28,6 @@ from ..auth import token_check
from ..schemas import validate, INVITE
from ..enums import ChannelType
from ..errors import BadRequest, Forbidden
from .guilds import create_guild_settings
from ..utils import async_map
from litecord.blueprints.checks import (
@ -39,6 +38,7 @@ from litecord.blueprints.checks import (
)
from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient
from litecord.common.guilds import create_guild_settings
log = Logger(__name__)
bp = Blueprint("invites", __name__)

View File

@ -30,7 +30,7 @@ from litecord.snowflake import snowflake_datetime, get_snowflake
from litecord.errors import BadRequest
from litecord.types import timestamp_, HOURS
from litecord.enums import UserFlags, PremiumType
from litecord.blueprints.users import mass_user_update
from litecord.common.users import mass_user_update
log = Logger(__name__)
bp = Blueprint("users_billing", __name__)
@ -122,11 +122,8 @@ async def get_payment_source_ids(user_id: int) -> list:
return [r["id"] for r in rows]
async def get_payment_ids(user_id: int, db=None) -> list:
if not db:
db = app.db
rows = await db.fetch(
async def get_payment_ids(user_id: int) -> list:
rows = await app.db.fetch(
"""
SELECT id
FROM user_payments
@ -151,13 +148,9 @@ async def get_subscription_ids(user_id: int) -> list:
return [r["id"] for r in rows]
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
async def get_payment_source(user_id: int, source_id: int) -> dict:
"""Get a payment source's information."""
if not db:
db = app.db
source_type = await db.fetchval(
source_type = await app.db.fetchval(
"""
SELECT source_type
FROM user_payment_sources
@ -176,7 +169,7 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
fields = ",".join(specific_fields)
extras_row = await db.fetchrow(
extras_row = await app.db.fetchrow(
f"""
SELECT {fields}, billing_address, default_, id::text
FROM user_payment_sources
@ -199,12 +192,9 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
return {**source, **derow}
async def get_subscription(subscription_id: int, db=None):
async def get_subscription(subscription_id: int):
"""Get a subscription's information."""
if not db:
db = app.db
row = await db.fetchrow(
row = await app.db.fetchrow(
"""
SELECT id::text, source_id::text AS payment_source_id,
user_id,
@ -231,12 +221,9 @@ async def get_subscription(subscription_id: int, db=None):
return drow
async def get_payment(payment_id: int, db=None):
async def get_payment(payment_id: int):
"""Get a single payment's information."""
if not db:
db = app.db
row = await db.fetchrow(
row = await app.db.fetchrow(
"""
SELECT id::text, source_id, subscription_id, user_id,
amount, amount_refunded, currency,
@ -255,27 +242,22 @@ async def get_payment(payment_id: int, db=None):
drow["created_at"] = snowflake_datetime(int(drow["id"]))
drow["payment_source"] = await get_payment_source(
row["user_id"], row["source_id"], db
)
drow["payment_source"] = await get_payment_source(row["user_id"], row["source_id"])
drow["subscription"] = await get_subscription(row["subscription_id"], db)
drow["subscription"] = await get_subscription(row["subscription_id"])
return drow
async def create_payment(subscription_id, db=None):
async def create_payment(subscription_id):
"""Create a payment."""
if not db:
db = app.db
sub = await get_subscription(subscription_id, db)
sub = await get_subscription(subscription_id)
new_id = get_snowflake()
amount = AMOUNTS[sub["payment_gateway_plan_id"]]
await db.execute(
await app.db.execute(
"""
INSERT INTO user_payments (
id, source_id, subscription_id, user_id,
@ -298,9 +280,9 @@ async def create_payment(subscription_id, db=None):
return new_id
async def process_subscription(app, subscription_id: int):
async def process_subscription(subscription_id: int):
"""Process a single subscription."""
sub = await get_subscription(subscription_id, app.db)
sub = await get_subscription(subscription_id)
user_id = int(sub["user_id"])
@ -365,7 +347,7 @@ async def process_subscription(app, subscription_id: int):
)
# dispatch updated user to all possible clients
await mass_user_update(user_id, app)
await mass_user_update(user_id)
@bp.route("/@me/billing/payment-sources", methods=["GET"])
@ -474,11 +456,11 @@ async def _create_subscription():
1,
)
await create_payment(new_id, app.db)
await create_payment(new_id)
# make sure we update the user's premium status
# and dispatch respective user updates to other people.
await process_subscription(app, new_id)
await process_subscription(new_id)
return jsonify(await get_subscription(new_id))

View File

@ -21,6 +21,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
this file only serves the periodic payment job code.
"""
import datetime
from quart import current_app as app
from asyncio import sleep, CancelledError
from logbook import Logger
@ -47,14 +49,14 @@ THRESHOLDS = {
}
async def _resched(app):
async def _resched():
log.debug("waiting 30 minutes for job.")
await sleep(30 * MINUTES)
app.sched.spawn(payment_job(app))
app.sched.spawn(payment_job())
async def _process_user_payments(app, user_id: int):
payments = await get_payment_ids(user_id, app.db)
async def _process_user_payments(user_id: int):
payments = await get_payment_ids(user_id)
if not payments:
log.debug("no payments for uid {}, skipping", user_id)
@ -64,7 +66,7 @@ async def _process_user_payments(app, user_id: int):
latest_payment = max(payments)
payment_data = await get_payment(latest_payment, app.db)
payment_data = await get_payment(latest_payment)
# calculate the difference between this payment
# and now.
@ -74,7 +76,7 @@ async def _process_user_payments(app, user_id: int):
delta = now - payment_tstamp
sub_id = int(payment_data["subscription"]["id"])
subscription = await get_subscription(sub_id, app.db)
subscription = await get_subscription(sub_id)
# if the max payment is X days old, we create another.
# X is 30 for monthly subscriptions of nitro,
@ -89,12 +91,12 @@ async def _process_user_payments(app, user_id: int):
# create_payment does not call any Stripe
# or BrainTree APIs at all, since we'll just
# give it as free.
await create_payment(sub_id, app.db)
await create_payment(sub_id)
else:
log.debug("sid={}, missing {} days", sub_id, threshold - delta.days)
async def payment_job(app):
async def payment_job():
"""Main payment job function.
This function will check through users' payments
@ -115,7 +117,7 @@ async def payment_job(app):
for row in user_ids:
user_id = row["user_id"]
try:
await _process_user_payments(app, user_id)
await _process_user_payments(user_id)
except Exception:
log.exception("error while processing user payments")
@ -128,11 +130,11 @@ async def payment_job(app):
for row in subscribers:
try:
await process_subscription(app, row["id"])
await process_subscription(row["id"])
except Exception:
log.exception("error while processing subscription")
log.debug("rescheduling..")
try:
await _resched(app)
await _resched()
except CancelledError:
log.info("cancelled while waiting for resched")

View File

@ -17,7 +17,6 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from os import urandom
from asyncpg import UniqueViolationError
from quart import Blueprint, jsonify, request, current_app as app
@ -27,8 +26,8 @@ from ..errors import Forbidden, BadRequest, Unauthorized
from ..schemas import validate, USER_UPDATE, GET_MENTIONS
from .guilds import guild_check
from litecord.auth import token_check, hash_data, check_username_usage, roll_discrim
from litecord.blueprints.guild.mod import remove_member
from litecord.auth import token_check, hash_data
from litecord.common.guilds import remove_member
from litecord.enums import PremiumType
from litecord.images import parse_data_uri
@ -36,46 +35,18 @@ from litecord.permissions import base_permissions
from litecord.blueprints.auth import check_password
from litecord.utils import to_update
from litecord.common.users import (
mass_user_update,
delete_user,
check_username_usage,
roll_discrim,
user_disconnect,
)
bp = Blueprint("user", __name__)
log = Logger(__name__)
async def mass_user_update(user_id):
"""Dispatch USER_UPDATE in a mass way."""
# by using dispatch_with_filter
# we're guaranteeing all shards will get
# a USER_UPDATE once and not any others.
session_ids = []
public_user = await app.storage.get_user(user_id)
private_user = await app.storage.get_user(user_id, secure=True)
session_ids.extend(
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
)
guild_ids = await app.user_storage.get_user_guilds(user_id)
friend_ids = await app.user_storage.get_friend_ids(user_id)
session_ids.extend(
await app.dispatcher.dispatch_many_filter_list(
"guild", guild_ids, session_ids, "USER_UPDATE", public_user
)
)
session_ids.extend(
await app.dispatcher.dispatch_many_filter_list(
"friend", friend_ids, session_ids, "USER_UPDATE", public_user
)
)
await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
return public_user, private_user
@bp.route("/@me", methods=["GET"])
async def get_me():
"""Get the current user's information."""
@ -276,7 +247,7 @@ async def patch_me():
user.pop("password_hash")
_, private_user = await mass_user_update(user_id, app)
_, private_user = await mass_user_update(user_id)
return jsonify(private_user)
@ -319,7 +290,6 @@ async def leave_guild(guild_id: int):
await guild_check(user_id, guild_id)
await remove_member(guild_id, user_id)
return "", 204
@ -468,118 +438,6 @@ async def _get_mentions():
return jsonify(res)
def rand_hex(length: int = 8) -> str:
"""Generate random hex characters."""
return urandom(length).hex()[:length]
async def _del_from_table(db, table: str, user_id: int):
"""Delete a row from a table."""
column = {
"channel_overwrites": "target_user",
"user_settings": "id",
"group_dm_members": "member_id",
}.get(table, "user_id")
res = await db.execute(
f"""
DELETE FROM {table}
WHERE {column} = $1
""",
user_id,
)
log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res)
async def delete_user(user_id, *, mass_update: bool = True):
"""Delete a user. Does not disconnect the user."""
db = app.db
new_username = f"Deleted User {rand_hex()}"
# by using a random hex in password_hash
# we break attempts at using the default '123' password hash
# to issue valid tokens for deleted users.
await db.execute(
"""
UPDATE users
SET
username = $1,
email = NULL,
mfa_enabled = false,
verified = false,
avatar = NULL,
flags = 0,
premium_since = NULL,
phone = '',
password_hash = $2
WHERE
id = $3
""",
new_username,
rand_hex(32),
user_id,
)
# remove the user from various tables
await _del_from_table(db, "user_settings", user_id)
await _del_from_table(db, "user_payment_sources", user_id)
await _del_from_table(db, "user_subscriptions", user_id)
await _del_from_table(db, "user_payments", user_id)
await _del_from_table(db, "user_read_state", user_id)
await _del_from_table(db, "guild_settings", user_id)
await _del_from_table(db, "guild_settings_channel_overrides", user_id)
await db.execute(
"""
DELETE FROM relationships
WHERE user_id = $1 OR peer_id = $1
""",
user_id,
)
# DMs are still maintained, but not the state.
await _del_from_table(db, "dm_channel_state", user_id)
# NOTE: we don't delete the group dms the user is an owner of...
# TODO: group dm owner reassign when the owner leaves a gdm
await _del_from_table(db, "group_dm_members", user_id)
await _del_from_table(db, "members", user_id)
await _del_from_table(db, "member_roles", user_id)
await _del_from_table(db, "channel_overwrites", user_id)
# after updating the user, we send USER_UPDATE so that all the other
# clients can refresh their caches on the now-deleted user
if mass_update:
await mass_user_update(user_id)
async def user_disconnect(user_id: int):
"""Disconnects the given user's devices."""
# after removing the user from all tables, we need to force
# all known user states to reconnect, causing the user to not
# be online anymore.
user_states = app.state_manager.user_states(user_id)
for state in user_states:
# make it unable to resume
app.state_manager.remove(state)
if not state.ws:
continue
# force a close, 4000 should make the client reconnect.
await state.ws.ws.close(4000)
# force everyone to see the user as offline
await app.presence.dispatch_pres(
user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
)
@bp.route("/@me/delete", methods=["POST"])
async def delete_account():
"""Delete own account.

View File

@ -43,7 +43,7 @@ from litecord.snowflake import get_snowflake
from litecord.utils import async_map
from litecord.errors import WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest
from litecord.blueprints.channel.messages import (
from litecord.common.messages import (
msg_create_request,
msg_create_check_content,
msg_add_attachment,
@ -499,9 +499,7 @@ async def execute_webhook(webhook_id: int, webhook_token):
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
# spawn embedder in the background, even when we're on a webhook.
app.sched.spawn(
process_url_embed(app.config, app.storage, app.dispatcher, app.session, payload)
)
app.sched.spawn(process_url_embed(payload))
# we can assume its a guild text channel, so just call it
await msg_guild_text_mentions(payload, guild_id, mentions_everyone, mentions_here)

View File

View File

@ -16,15 +16,55 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from quart import current_app as app
from litecord.errors import Forbidden
from litecord.errors import ForbiddenDM
from litecord.enums import RelationshipType
class ForbiddenDM(Forbidden):
error_code = 50007
async def channel_ack(
user_id: int, guild_id: int, channel_id: int, message_id: int = None
):
"""ACK a channel."""
if not message_id:
message_id = await app.storage.chan_last_message(channel_id)
await app.db.execute(
"""
INSERT INTO user_read_state
(user_id, channel_id, last_message_id, mention_count)
VALUES
($1, $2, $3, 0)
ON CONFLICT ON CONSTRAINT user_read_state_pkey
DO
UPDATE
SET last_message_id = $3, mention_count = 0
WHERE user_read_state.user_id = $1
AND user_read_state.channel_id = $2
""",
user_id,
channel_id,
message_id,
)
if guild_id:
await app.dispatcher.dispatch_user_guild(
user_id,
guild_id,
"MESSAGE_ACK",
{"message_id": str(message_id), "channel_id": str(channel_id)},
)
else:
# we don't use ChannelDispatcher here because since
# guild_id is None, all user devices are already subscribed
# to the given channel (a dm or a group dm)
await app.dispatcher.dispatch_user(
user_id,
"MESSAGE_ACK",
{"message_id": str(message_id), "channel_id": str(channel_id)},
)
async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
@ -75,3 +115,21 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
# if after this filtering we don't have any more guilds, error
if not mutual_guilds:
raise ForbiddenDM()
async def try_dm_state(user_id: int, dm_id: int):
"""Try inserting the user into the dm state
for the given DM.
Does not do anything if the user is already
in the dm state.
"""
await app.db.execute(
"""
INSERT INTO dm_channel_state (user_id, dm_id)
VALUES ($1, $2)
ON CONFLICT DO NOTHING
""",
user_id,
dm_id,
)

234
litecord/common/guilds.py Normal file
View File

@ -0,0 +1,234 @@
"""
Litecord
Copyright (C) 2018-2019 Luna Mendes
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from quart import current_app as app
from ..snowflake import get_snowflake
from ..permissions import get_role_perms
from ..utils import dict_get, maybe_lazy_guild_dispatch
from ..enums import ChannelType
async def remove_member(guild_id: int, member_id: int):
"""Do common tasks related to deleting a member from the guild,
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
await app.db.execute(
"""
DELETE FROM members
WHERE guild_id = $1 AND user_id = $2
""",
guild_id,
member_id,
)
await app.dispatcher.dispatch_user_guild(
member_id,
guild_id,
"GUILD_DELETE",
{"guild_id": str(guild_id), "unavailable": False},
)
await app.dispatcher.unsub("guild", guild_id, member_id)
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
await app.dispatcher.dispatch_guild(
guild_id,
"GUILD_MEMBER_REMOVE",
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
)
async def remove_member_multi(guild_id: int, members: list):
"""Remove multiple members."""
for member_id in members:
await remove_member(guild_id, member_id)
async def create_role(guild_id, name: str, **kwargs):
"""Create a role in a guild."""
new_role_id = get_snowflake()
everyone_perms = await get_role_perms(guild_id, guild_id)
default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
# update all roles so that we have space for pos 1, but without
# sending GUILD_ROLE_UPDATE for everyone
await app.db.execute(
"""
UPDATE roles
SET
position = position + 1
WHERE guild_id = $1
AND NOT (position = 0)
""",
guild_id,
)
await app.db.execute(
"""
INSERT INTO roles (id, guild_id, name, color,
hoist, position, permissions, managed, mentionable)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
""",
new_role_id,
guild_id,
name,
dict_get(kwargs, "color", 0),
dict_get(kwargs, "hoist", False),
# always set ourselves on position 1
1,
int(dict_get(kwargs, "permissions", default_perms)),
False,
dict_get(kwargs, "mentionable", False),
)
role = await app.storage.get_role(new_role_id, guild_id)
# we need to update the lazy guild handlers for the newly created group
await maybe_lazy_guild_dispatch(guild_id, "new_role", role)
await app.dispatcher.dispatch_guild(
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
)
return role
async def _specific_chan_create(channel_id, ctype, **kwargs):
if ctype == ChannelType.GUILD_TEXT:
await app.db.execute(
"""
INSERT INTO guild_text_channels (id, topic)
VALUES ($1, $2)
""",
channel_id,
kwargs.get("topic", ""),
)
elif ctype == ChannelType.GUILD_VOICE:
await app.db.execute(
"""
INSERT INTO guild_voice_channels (id, bitrate, user_limit)
VALUES ($1, $2, $3)
""",
channel_id,
kwargs.get("bitrate", 64),
kwargs.get("user_limit", 0),
)
async def create_guild_channel(
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
):
"""Create a channel in a guild."""
await app.db.execute(
"""
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
""",
channel_id,
ctype.value,
)
# calc new pos
max_pos = await app.db.fetchval(
"""
SELECT MAX(position)
FROM guild_channels
WHERE guild_id = $1
""",
guild_id,
)
# account for the first channel in a guild too
max_pos = max_pos or 0
# all channels go to guild_channels
await app.db.execute(
"""
INSERT INTO guild_channels (id, guild_id, name, position)
VALUES ($1, $2, $3, $4)
""",
channel_id,
guild_id,
kwargs["name"],
max_pos + 1,
)
# the rest of sql magic is dependant on the channel
# we're creating (a text or voice or category),
# so we use this function.
await _specific_chan_create(channel_id, ctype, **kwargs)
async def delete_guild(guild_id: int):
"""Delete a single guild."""
await app.db.execute(
"""
DELETE FROM guilds
WHERE guilds.id = $1
""",
guild_id,
)
# Discord's client expects IDs being string
await app.dispatcher.dispatch(
"guild",
guild_id,
"GUILD_DELETE",
{
"guild_id": str(guild_id),
"id": str(guild_id),
# 'unavailable': False,
},
)
# remove from the dispatcher so nobody
# becomes the little memer that tries to fuck up with
# everybody's gateway
await app.dispatcher.remove("guild", guild_id)
async def create_guild_settings(guild_id: int, user_id: int):
"""Create guild settings for the user
joining the guild."""
# new guild_settings are based off the currently
# set guild settings (for the guild)
m_notifs = await app.db.fetchval(
"""
SELECT default_message_notifications
FROM guilds
WHERE id = $1
""",
guild_id,
)
await app.db.execute(
"""
INSERT INTO guild_settings
(user_id, guild_id, message_notifications)
VALUES
($1, $2, $3)
""",
user_id,
guild_id,
m_notifs,
)

189
litecord/common/messages.py Normal file
View File

@ -0,0 +1,189 @@
import json
import logging
from PIL import Image
from quart import request, current_app as app
from litecord.errors import BadRequest
from ..snowflake import get_snowflake
log = logging.getLogger(__name__)
async def msg_create_request() -> tuple:
"""Extract the json input and any file information
the client gave to us in the request.
This only applies to create message route.
"""
form = await request.form
request_json = await request.get_json() or {}
# NOTE: embed isn't set on form data
json_from_form = {
"content": form.get("content", ""),
"nonce": form.get("nonce", "0"),
"tts": json.loads(form.get("tts", "false")),
}
payload_json = json.loads(form.get("payload_json", "{}"))
json_from_form.update(request_json)
json_from_form.update(payload_json)
files = await request.files
# we don't really care about the given fields on the files dict, so
# we only extract the values
return json_from_form, [v for k, v in files.items()]
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
"""Check if there is actually any content being sent to us."""
has_content = bool(payload.get("content", ""))
has_files = len(files) > 0
embed_field = "embeds" if use_embeds else "embed"
has_embed = embed_field in payload and payload.get(embed_field) is not None
has_total_content = has_content or has_embed or has_files
if not has_total_content:
raise BadRequest("No content has been provided.")
async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
"""Add an attachment to a message.
Parameters
----------
message_id: int
The ID of the message getting the attachment.
channel_id: int
The ID of the channel the message belongs to.
Exists because the attachment URL scheme contains
a channel id. The purpose is unknown, but we are
implementing Discord's behavior.
attachment_file: quart.FileStorage
quart FileStorage instance of the file.
"""
attachment_id = get_snowflake()
filename = attachment_file.filename
# understand file info
mime = attachment_file.mimetype
is_image = mime.startswith("image/")
img_width, img_height = None, None
# extract file size
# TODO: this is probably inneficient
file_size = attachment_file.stream.getbuffer().nbytes
if is_image:
# open with pillow, extract image size
image = Image.open(attachment_file.stream)
img_width, img_height = image.size
# NOTE: DO NOT close the image, as closing the image will
# also close the stream.
# reset it to 0 for later usage
attachment_file.stream.seek(0)
await app.db.execute(
"""
INSERT INTO attachments
(id, channel_id, message_id,
filename, filesize,
image, width, height)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8)
""",
attachment_id,
channel_id,
message_id,
filename,
file_size,
is_image,
img_width,
img_height,
)
ext = filename.split(".")[-1]
with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
attach_file.write(attachment_file.stream.read())
log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
return attachment_id
async def msg_guild_text_mentions(
payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
):
"""Calculates mention data side-effects."""
channel_id = int(payload["channel_id"])
# calculate the user ids we'll bump the mention count for
uids = set()
# first is extracting user mentions
for mention in payload["mentions"]:
uids.add(int(mention["id"]))
# then role mentions
for role_mention in payload["mention_roles"]:
role_id = int(role_mention)
member_ids = await app.storage.get_role_members(role_id)
for member_id in member_ids:
uids.add(member_id)
# at-here only updates the state
# for the users that have a state
# in the channel.
if mentions_here:
uids = set()
await app.db.execute(
"""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1
""",
channel_id,
)
# at-here updates the read state
# for all users, including the ones
# that might not have read permissions
# to the channel.
if mentions_everyone:
uids = set()
member_ids = await app.storage.get_member_ids(guild_id)
await app.db.executemany(
"""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1 AND user_id = $2
""",
[(channel_id, uid) for uid in member_ids],
)
for user_id in uids:
await app.db.execute(
"""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE user_id = $1
AND channel_id = $2
""",
user_id,
channel_id,
)

273
litecord/common/users.py Normal file
View File

@ -0,0 +1,273 @@
"""
Litecord
Copyright (C) 2018-2019 Luna Mendes
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import logging
from random import randint
from typing import Tuple, Optional
from quart import current_app as app
from asyncpg import UniqueViolationError
from ..snowflake import get_snowflake
from ..errors import BadRequest
from ..auth import hash_data
from ..utils import rand_hex
log = logging.getLogger(__name__)
async def mass_user_update(user_id):
"""Dispatch USER_UPDATE in a mass way."""
# by using dispatch_with_filter
# we're guaranteeing all shards will get
# a USER_UPDATE once and not any others.
session_ids = []
public_user = await app.storage.get_user(user_id)
private_user = await app.storage.get_user(user_id, secure=True)
session_ids.extend(
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
)
guild_ids = await app.user_storage.get_user_guilds(user_id)
friend_ids = await app.user_storage.get_friend_ids(user_id)
session_ids.extend(
await app.dispatcher.dispatch_many_filter_list(
"guild", guild_ids, session_ids, "USER_UPDATE", public_user
)
)
session_ids.extend(
await app.dispatcher.dispatch_many_filter_list(
"friend", friend_ids, session_ids, "USER_UPDATE", public_user
)
)
await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
return public_user, private_user
async def check_username_usage(username: str):
"""Raise an error if too many people are with the same username."""
same_username = await app.db.fetchval(
"""
SELECT COUNT(*)
FROM users
WHERE username = $1
""",
username,
)
if same_username > 9000:
raise BadRequest(
"Too many people.",
{
"username": "Too many people used the same username. "
"Please choose another"
},
)
def _raw_discrim() -> str:
discrim_number = randint(1, 9999)
return "%04d" % discrim_number
async def roll_discrim(username: str) -> Optional[str]:
"""Roll a discriminator for a DiscordTag.
Tries to generate one 10 times.
Calls check_username_usage.
"""
# we shouldn't roll discrims for usernames
# that have been used too much.
await check_username_usage(username)
# max 10 times for a reroll
for _ in range(10):
# generate random discrim
discrim = _raw_discrim()
# check if anyone is with it
res = await app.db.fetchval(
"""
SELECT id
FROM users
WHERE username = $1 AND discriminator = $2
""",
username,
discrim,
)
# if no user is found with the (username, discrim)
# pair, then this is unique! return it.
if res is None:
return discrim
return None
async def create_user(username: str, email: str, password: str) -> Tuple[int, str]:
"""Create a single user.
Generates a distriminator and other information. You can fetch the user
data back with :meth:`Storage.get_user`.
"""
new_id = get_snowflake()
new_discrim = await roll_discrim(username)
if new_discrim is None:
raise BadRequest(
"Unable to register.",
{"username": "Too many people are with this username."},
)
pwd_hash = await hash_data(password)
try:
await app.db.execute(
"""
INSERT INTO users
(id, email, username, discriminator, password_hash)
VALUES
($1, $2, $3, $4, $5)
""",
new_id,
email,
username,
new_discrim,
pwd_hash,
)
except UniqueViolationError:
raise BadRequest("Email already used.")
return new_id, pwd_hash
async def _del_from_table(db, table: str, user_id: int):
"""Delete a row from a table."""
column = {
"channel_overwrites": "target_user",
"user_settings": "id",
"group_dm_members": "member_id",
}.get(table, "user_id")
res = await db.execute(
f"""
DELETE FROM {table}
WHERE {column} = $1
""",
user_id,
)
log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res)
async def delete_user(user_id, *, mass_update: bool = True):
"""Delete a user. Does not disconnect the user."""
db = app.db
new_username = f"Deleted User {rand_hex()}"
# by using a random hex in password_hash
# we break attempts at using the default '123' password hash
# to issue valid tokens for deleted users.
await db.execute(
"""
UPDATE users
SET
username = $1,
email = NULL,
mfa_enabled = false,
verified = false,
avatar = NULL,
flags = 0,
premium_since = NULL,
phone = '',
password_hash = $2
WHERE
id = $3
""",
new_username,
rand_hex(32),
user_id,
)
# remove the user from various tables
await _del_from_table(db, "user_settings", user_id)
await _del_from_table(db, "user_payment_sources", user_id)
await _del_from_table(db, "user_subscriptions", user_id)
await _del_from_table(db, "user_payments", user_id)
await _del_from_table(db, "user_read_state", user_id)
await _del_from_table(db, "guild_settings", user_id)
await _del_from_table(db, "guild_settings_channel_overrides", user_id)
await db.execute(
"""
DELETE FROM relationships
WHERE user_id = $1 OR peer_id = $1
""",
user_id,
)
# DMs are still maintained, but not the state.
await _del_from_table(db, "dm_channel_state", user_id)
# NOTE: we don't delete the group dms the user is an owner of...
# TODO: group dm owner reassign when the owner leaves a gdm
await _del_from_table(db, "group_dm_members", user_id)
await _del_from_table(db, "members", user_id)
await _del_from_table(db, "member_roles", user_id)
await _del_from_table(db, "channel_overwrites", user_id)
# after updating the user, we send USER_UPDATE so that all the other
# clients can refresh their caches on the now-deleted user
if mass_update:
await mass_user_update(user_id)
async def user_disconnect(user_id: int):
"""Disconnects the given user's devices."""
# after removing the user from all tables, we need to force
# all known user states to reconnect, causing the user to not
# be online anymore.
user_states = app.state_manager.user_states(user_id)
for state in user_states:
# make it unable to resume
app.state_manager.remove(state)
if not state.ws:
continue
# force a close, 4000 should make the client reconnect.
await state.ws.ws.close(4000)
# force everyone to see the user as offline
await app.presence.dispatch_pres(
user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
)

View File

@ -22,6 +22,7 @@ import asyncio
import urllib.parse
from pathlib import Path
from quart import current_app as app
from logbook import Logger
from litecord.embed.sanitizer import proxify, fetch_metadata, fetch_embed
@ -33,10 +34,10 @@ log = Logger(__name__)
MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm")
async def insert_media_meta(url, config, session):
async def insert_media_meta(url):
"""Insert media metadata as an embed."""
img_proxy_url = proxify(url, config=config)
meta = await fetch_metadata(url, config=config, session=session)
img_proxy_url = proxify(url)
meta = await fetch_metadata(url)
if meta is None:
return
@ -56,14 +57,14 @@ async def insert_media_meta(url, config, session):
}
async def msg_update_embeds(payload, new_embeds, storage, dispatcher):
async def msg_update_embeds(payload, new_embeds):
"""Update the message with the given embeds and dispatch a MESSAGE_UPDATE
to users."""
message_id = int(payload["id"])
channel_id = int(payload["channel_id"])
await storage.execute_with_json(
await app.storage.execute_with_json(
"""
UPDATE messages
SET embeds = $1
@ -85,7 +86,9 @@ async def msg_update_embeds(payload, new_embeds, storage, dispatcher):
if "flags" in payload:
update_payload["flags"] = payload["flags"]
await dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", update_payload)
await app.dispatcher.dispatch(
"channel", channel_id, "MESSAGE_UPDATE", update_payload
)
def is_media_url(url) -> bool:
@ -102,15 +105,13 @@ def is_media_url(url) -> bool:
return extension in MEDIA_EXTENSIONS
async def insert_mp_embed(parsed, config, session):
async def insert_mp_embed(parsed):
"""Insert mediaproxy embed."""
embed = await fetch_embed(parsed, config=config, session=session)
embed = await fetch_embed(parsed)
return embed
async def process_url_embed(
config, storage, dispatcher, session, payload: dict, *, delay=0
):
async def process_url_embed(payload: dict, *, delay=0):
"""Process URLs in a message and generate embeds based on that."""
await asyncio.sleep(delay)
@ -145,9 +146,9 @@ async def process_url_embed(
url = EmbedURL(url)
if is_media_url(url):
embed = await insert_media_meta(url, config, session)
embed = await insert_media_meta(url)
else:
embed = await insert_mp_embed(url, config, session)
embed = await insert_mp_embed(url)
if not embed:
continue
@ -160,4 +161,4 @@ async def process_url_embed(
log.debug("made {} embeds for mid {}", len(new_embeds), message_id)
await msg_update_embeds(payload, new_embeds, storage, dispatcher)
await msg_update_embeds(payload, new_embeds)

View File

@ -75,35 +75,24 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]):
return False
def _mk_cfg_sess(config, session) -> tuple:
"""Return a tuple of (config, session)."""
if config is None:
config = app.config
if session is None:
session = app.session
return config, session
def _md_base(config) -> Optional[tuple]:
def _md_base() -> Optional[tuple]:
"""Return the protocol and base url for the mediaproxy."""
md_base_url = config["MEDIA_PROXY"]
md_base_url = app.config["MEDIA_PROXY"]
if md_base_url is None:
return None
proto = "https" if config["IS_SSL"] else "http"
proto = "https" if app.config["IS_SSL"] else "http"
return proto, md_base_url
def make_md_req_url(config, scope: str, url):
"""Make a mediaproxy request URL given the config, scope, and the url
def make_md_req_url(scope: str, url):
"""Make a mediaproxy request URL given the scope and the url
to be proxied.
When MEDIA_PROXY is None, however, returns the original URL.
"""
base = _md_base(config)
base = _md_base()
if base is None:
return url.url if isinstance(url, EmbedURL) else url
@ -111,38 +100,25 @@ def make_md_req_url(config, scope: str, url):
return f"{proto}://{base_url}/{scope}/{url.to_md_path}"
def proxify(url, *, config=None) -> str:
def proxify(url) -> str:
"""Return a mediaproxy url for the given EmbedURL. Returns an
/img/ scope."""
config, _sess = _mk_cfg_sess(config, False)
if isinstance(url, str):
url = EmbedURL(url)
return make_md_req_url(config, "img", url)
return make_md_req_url("img", url)
async def _md_client_req(
config, session, scope: str, url, *, ret_resp=False
scope: str, url, *, ret_resp=False
) -> Optional[Union[Tuple, Dict]]:
"""Makes a request to the mediaproxy.
This has common code between all the main mediaproxy request functions
to decrease code repetition.
Note that config and session exist because there are cases where the app
isn't retrievable (as those functions usually run in background tasks,
not in the app itself).
Parameters
----------
config: dict-like
the app configuration, if None, this will get the global one from the
app instance.
session: aiohttp client session
the aiohttp ClientSession instance to use, if None, this will get
the global one from the app.
scope: str
the scope of your request. one of 'meta', 'img', or 'embed' are
available for the mediaproxy's API.
@ -155,14 +131,12 @@ async def _md_client_req(
the raw bytes of the response, but by the time this function is
returned, the response object is invalid and the socket is closed
"""
config, session = _mk_cfg_sess(config, session)
if not isinstance(url, EmbedURL):
url = EmbedURL(url)
request_url = make_md_req_url(config, scope, url)
request_url = make_md_req_url(scope, url)
async with session.get(request_url) as resp:
async with app.session.get(request_url) as resp:
if resp.status == 200:
if ret_resp:
return resp, await resp.read()
@ -174,18 +148,18 @@ async def _md_client_req(
return None
async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]:
async def fetch_metadata(url) -> Optional[Dict]:
"""Fetch metadata for a url (image width, mime, etc)."""
return await _md_client_req(config, session, "meta", url)
return await _md_client_req("meta", url)
async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
async def fetch_raw_img(url) -> Optional[tuple]:
"""Fetch raw data for a url (the bytes given off, used to proxy images).
Returns a tuple containing the response object and the raw bytes given by
the website.
"""
tup = await _md_client_req(config, session, "img", url, ret_resp=True)
tup = await _md_client_req("img", url, ret_resp=True)
if not tup:
return None
@ -193,13 +167,13 @@ async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
return tup
async def fetch_embed(url, *, config=None, session=None) -> Dict[str, Any]:
async def fetch_embed(url) -> Dict[str, Any]:
"""Fetch an embed for a given webpage (an automatically generated embed
by the mediaproxy, look over the project on how it generates embeds).
Returns a discord embed object.
"""
return await _md_client_req(config, session, "embed", url)
return await _md_client_req("embed", url)
async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:

View File

@ -54,28 +54,22 @@ class Flags:
"""
def __init_subclass__(cls, **_kwargs):
attrs = inspect.getmembers(cls, lambda x: not inspect.isroutine(x))
# get only the members that represent a field
cls._attrs = inspect.getmembers(cls, lambda x: isinstance(x, int))
def _make_int(value):
@classmethod
def from_int(cls, value: int):
"""Create a Flags from a given int value."""
res = Flags()
setattr(res, "value", value)
for attr, val in attrs:
# get only the ones that represent a field in the
# number's bits
if not isinstance(val, int):
continue
for attr, val in cls._attrs:
has_attr = (value & val) == val
# set each attribute
# set attributes dynamically
setattr(res, f"is_{attr}", has_attr)
return res
cls.from_int = _make_int
class ChannelType(EasyEnum):
GUILD_TEXT = 0

View File

@ -116,6 +116,10 @@ class Forbidden(LitecordError):
status_code = 403
class ForbiddenDM(Forbidden):
error_code = 50007
class NotFound(LitecordError):
status_code = 404

View File

@ -21,7 +21,7 @@ import collections
import asyncio
import pprint
import zlib
from typing import List, Dict, Any
from typing import List, Dict, Any, Iterable
from random import randint
import websockets
@ -56,41 +56,15 @@ WebsocketProperties = collections.namedtuple(
"WebsocketProperties", "v encoding compress zctx zsctx tasks"
)
WebsocketObjects = collections.namedtuple(
"WebsocketObjects",
(
"db",
"state_manager",
"storage",
"loop",
"dispatcher",
"presence",
"ratelimiter",
"user_storage",
"voice",
),
)
class GatewayWebsocket:
"""Main gateway websocket logic."""
def __init__(self, ws, app, **kwargs):
self.ext = WebsocketObjects(
app.db,
app.state_manager,
app.storage,
app.loop,
app.dispatcher,
app.presence,
app.ratelimiter,
app.user_storage,
app.voice,
)
self.storage = self.ext.storage
self.user_storage = self.ext.user_storage
self.presence = self.ext.presence
self.app = app
self.storage = app.storage
self.user_storage = app.user_storage
self.presence = app.presence
self.ws = ws
self.wsp = WebsocketProperties(
@ -225,7 +199,7 @@ class GatewayWebsocket:
await self.send({"op": op_code, "d": data, "t": None, "s": None})
def _check_ratelimit(self, key: str, ratelimit_key):
ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}")
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}")
bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit()
@ -245,7 +219,7 @@ class GatewayWebsocket:
if task:
task.cancel()
self.wsp.tasks["heartbeat"] = self.ext.loop.create_task(
self.wsp.tasks["heartbeat"] = self.app.loop.create_task(
task_wrapper("hb wait", self._hb_wait(interval))
)
@ -330,7 +304,7 @@ class GatewayWebsocket:
if r["type"] == RelationshipType.FRIEND.value
]
friend_presences = await self.ext.presence.friend_presences(friend_ids)
friend_presences = await self.app.presence.friend_presences(friend_ids)
settings = await self.user_storage.get_user_settings(user_id)
return {
@ -377,14 +351,14 @@ class GatewayWebsocket:
await self.dispatch("READY", {**base_ready, **user_ready})
# async dispatch of guilds
self.ext.loop.create_task(self._guild_dispatch(guilds))
self.app.loop.create_task(self._guild_dispatch(guilds))
async def _check_shards(self, shard, user_id):
"""Check if the given `shard` value in IDENTIFY has good enough values.
"""
current_shard, shard_count = shard
guilds = await self.ext.db.fetchval(
guilds = await self.app.db.fetchval(
"""
SELECT COUNT(*)
FROM members
@ -460,7 +434,7 @@ class GatewayWebsocket:
("channel", gdm_ids),
]
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
await self.app.dispatcher.mass_sub(user_id, channels_to_sub)
if not self.state.bot:
# subscribe to all friends
@ -468,7 +442,7 @@ class GatewayWebsocket:
# when they come online)
friend_ids = await self.user_storage.get_friend_ids(user_id)
log.info("subscribing to {} friends", len(friend_ids))
await self.ext.dispatcher.sub_many("friend", user_id, friend_ids)
await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
async def update_status(self, status: dict):
"""Update the status of the current websocket connection."""
@ -520,7 +494,7 @@ class GatewayWebsocket:
f'Updating presence status={status["status"]} for '
f"uid={self.state.user_id}"
)
await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence)
await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence)
async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets."""
@ -558,13 +532,13 @@ class GatewayWebsocket:
presence = data.get("presence")
try:
user_id = await raw_token_check(token, self.ext.db)
user_id = await raw_token_check(token, self.app.db)
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, "Authentication failed")
await self._connect_ratelimit(user_id)
bot = await self.ext.db.fetchval(
bot = await self.app.db.fetchval(
"""
SELECT bot FROM users
WHERE id = $1
@ -587,7 +561,7 @@ class GatewayWebsocket:
)
# link the state to the user
self.ext.state_manager.insert(self.state)
self.app.state_manager.insert(self.state)
await self.update_status(presence)
await self.subscribe_all(data.get("guild_subscriptions", True))
@ -631,12 +605,12 @@ class GatewayWebsocket:
# if its null and null, disconnect the user from any voice
# TODO: maybe just leave from DMs? idk...
if channel_id is None and guild_id is None:
return await self.ext.voice.leave_all(self.state.user_id)
return await self.app.voice.leave_all(self.state.user_id)
# if guild is not none but channel is, we are leaving
# a guild's channel
if channel_id is None:
return await self.ext.voice.leave(guild_id, self.state.user_id)
return await self.app.voice.leave(guild_id, self.state.user_id)
# fetch an existing state given user and guild OR user and channel
chan_type = ChannelType(await self.storage.get_chan_type(channel_id))
@ -659,10 +633,10 @@ class GatewayWebsocket:
# this state id format takes care of that.
voice_key = (self.state.user_id, state_id2)
voice_state = await self.ext.voice.get_state(voice_key)
voice_state = await self.app.voice.get_state(voice_key)
if voice_state is None:
return await self.ext.voice.create_state(voice_key, data)
return await self.app.voice.create_state(voice_key, data)
same_guild = guild_id == voice_state.guild_id
same_channel = channel_id == voice_state.channel_id
@ -670,10 +644,10 @@ class GatewayWebsocket:
prop = await self._vsu_get_prop(voice_state, data)
if same_guild and same_channel:
return await self.ext.voice.update_state(voice_state, prop)
return await self.app.voice.update_state(voice_state, prop)
if same_guild and not same_channel:
return await self.ext.voice.move_state(voice_state, channel_id)
return await self.app.voice.move_state(voice_state, channel_id)
async def _handle_5(self, payload: Dict[str, Any]):
"""Handle OP 5 Voice Server Ping.
@ -698,9 +672,9 @@ class GatewayWebsocket:
# since the state will be removed from
# the manager, it will become unreachable
# when trying to resume.
self.ext.state_manager.remove(self.state)
self.app.state_manager.remove(self.state)
async def _resume(self, replay_seqs: iter):
async def _resume(self, replay_seqs: Iterable):
presences = []
try:
@ -740,12 +714,12 @@ class GatewayWebsocket:
raise DecodeError("Invalid resume payload")
try:
user_id = await raw_token_check(token, self.ext.db)
user_id = await raw_token_check(token, self.app.db)
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, "Invalid token")
try:
state = self.ext.state_manager.fetch(user_id, sess_id)
state = self.app.state_manager.fetch(user_id, sess_id)
except KeyError:
return await self.invalidate_session(False)
@ -948,7 +922,7 @@ class GatewayWebsocket:
log.debug("lazy request: members: {}", data.get("members", []))
# make shard query
lazy_guilds = self.ext.dispatcher.backends["lazy_guild"]
lazy_guilds = self.app.dispatcher.backends["lazy_guild"]
for chan_id, ranges in data.get("channels", {}).items():
chan_id = int(chan_id)
@ -992,10 +966,10 @@ class GatewayWebsocket:
# close anyone trying to login while the
# server is shutting down
if self.ext.state_manager.closed:
if self.app.state_manager.closed:
raise WebsocketClose(4000, "state manager closed")
if not self.ext.state_manager.accept_new:
if not self.app.state_manager.accept_new:
raise WebsocketClose(4000, "state manager closed for new")
while True:
@ -1016,7 +990,7 @@ class GatewayWebsocket:
task.cancel()
if self.state:
self.ext.state_manager.remove(self.state)
self.app.state_manager.remove(self.state)
self.state.ws = None
self.state = None
@ -1031,14 +1005,14 @@ class GatewayWebsocket:
# TODO: account for sharding
# this only updates status to offline once
# ALL shards have come offline
states = self.ext.state_manager.user_states(user_id)
states = self.app.state_manager.user_states(user_id)
with_ws = [s for s in states if s.ws]
# there arent any other states with websocket
if not with_ws:
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
await self.ext.presence.dispatch_pres(user_id, offline)
await self.app.presence.dispatch_pres(user_id, offline)
async def run(self):
"""Wrap :meth:`listen_messages` inside

View File

@ -18,7 +18,9 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import asyncio
from typing import Any
from quart.ctx import copy_current_app_context
from logbook import Logger
log = Logger(__name__)
@ -47,9 +49,14 @@ class JobManager:
def spawn(self, coro):
"""Spawn a given future or coroutine in the background."""
task = self.loop.create_task(self._wrapper(coro))
@copy_current_app_context
async def _ctx_wrapper_bg() -> Any:
return await coro
task = self.loop.create_task(self._wrapper(_ctx_wrapper_bg()))
self.jobs.append(task)
return task
def close(self):
"""Close the job manager, cancelling all existing jobs.

View File

@ -19,11 +19,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import json
import secrets
from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
from logbook import Logger
from quart.json import JSONEncoder
from quart import current_app as app
from quart import current_app as app, request
from .errors import BadRequest
log = Logger(__name__)
@ -233,3 +236,57 @@ def maybe_int(val: Any) -> Union[int, Any]:
return int(val)
except (ValueError, TypeError):
return val
async def maybe_lazy_guild_dispatch(
guild_id: int, event: str, role, force: bool = False
):
# sometimes we want to dispatch an event
# even if the role isn't hoisted
# an example of such a case is when a role loses
# its hoist status.
# check if is a dict first because role_delete
# only receives the role id.
if isinstance(role, dict) and not role["hoist"] and not force:
return
await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
def extract_limit(request_, default: int = 50, max_val: int = 100):
"""Extract a limit kwarg."""
try:
limit = int(request_.args.get("limit", default))
if limit not in range(0, max_val + 1):
raise ValueError()
except (TypeError, ValueError):
raise BadRequest("limit not int")
return limit
def query_tuple_from_args(args: dict, limit: int) -> tuple:
"""Extract a 2-tuple out of request arguments."""
before, after = None, None
if "around" in request.args:
average = int(limit / 2)
around = int(args["around"])
after = around - average
before = around + average
elif "before" in args:
before = int(args["before"])
elif "after" in args:
before = int(args["after"])
return before, after
def rand_hex(length: int = 8) -> str:
"""Generate random hex characters."""
return secrets.token_hex(length)[:length]

View File

@ -17,9 +17,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from litecord.auth import create_user
from litecord.common.users import create_user, delete_user
from litecord.blueprints.auth import make_token
from litecord.blueprints.users import delete_user
from litecord.enums import UserFlags

4
run.py
View File

@ -337,9 +337,9 @@ async def api_index(app_):
async def post_app_start(app_):
# we'll need to start a billing job
app_.sched.spawn(payment_job(app_))
app_.sched.spawn(payment_job())
app_.sched.spawn(api_index(app_))
app_.sched.spawn(guild_region_check(app_))
app_.sched.spawn(guild_region_check())
def start_websocket(host, port, ws_handler) -> asyncio.Future:

View File

@ -30,10 +30,9 @@ from tests.common import email, TestClient
from run import app as main_app, set_blueprints
from litecord.auth import create_user
from litecord.common.users import create_user, delete_user
from litecord.enums import UserFlags
from litecord.blueprints.auth import make_token
from litecord.blueprints.users import delete_user
@pytest.fixture(name="app")

View File

@ -54,6 +54,11 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
return rjson
async def _delete_guild(test_cli, guild_id: int):
async with test_cli.app.app_context():
await delete_guild(int(guild_id))
@pytest.mark.asyncio
async def test_guild_fetch(test_cli_staff):
"""Test the creation and fetching of a guild via the Admin API."""
@ -63,7 +68,7 @@ async def test_guild_fetch(test_cli_staff):
try:
await _fetch_guild(test_cli_staff, guild_id)
finally:
await delete_guild(int(guild_id), app_=test_cli_staff.app)
await _delete_guild(test_cli_staff, int(guild_id))
@pytest.mark.asyncio
@ -91,7 +96,7 @@ async def test_guild_update(test_cli_staff):
rjson = await _fetch_guild(test_cli_staff, guild_id)
assert rjson["unavailable"]
finally:
await delete_guild(int(guild_id), app_=test_cli_staff.app)
await _delete_guild(test_cli_staff, int(guild_id))
@pytest.mark.asyncio
@ -113,4 +118,4 @@ async def test_guild_delete(test_cli_staff):
assert rjson["error"]
assert rjson["code"] == GuildNotFound.error_code
finally:
await delete_guild(int(guild_id), app_=test_cli_staff.app)
await _delete_guild(test_cli_staff, int(guild_id))