diff --git a/litecord/auth.py b/litecord/auth.py
index 078a02f..849a45a 100644
--- a/litecord/auth.py
+++ b/litecord/auth.py
@@ -19,17 +19,13 @@ along with this program. If not, see .
import base64
import binascii
-from random import randint
-from typing import Tuple, Optional
import bcrypt
-from asyncpg import UniqueViolationError
from itsdangerous import TimestampSigner, BadSignature
from logbook import Logger
from quart import request, current_app as app
-from litecord.errors import Forbidden, Unauthorized, BadRequest
-from litecord.snowflake import get_snowflake
+from litecord.errors import Forbidden, Unauthorized
from litecord.enums import UserFlags
@@ -56,20 +52,20 @@ async def raw_token_check(token: str, db=None) -> int:
# just try by fragments instead of
# unpacking
fragments = token.split(".")
- user_id = fragments[0]
+ user_id_str = fragments[0]
try:
- user_id = base64.b64decode(user_id.encode())
- user_id = int(user_id)
+ user_id_decoded = base64.b64decode(user_id_str.encode())
+ user_id = int(user_id_decoded)
except (ValueError, binascii.Error):
raise Unauthorized("Invalid user ID type")
pwd_hash = await db.fetchval(
"""
- SELECT password_hash
- FROM users
- WHERE id = $1
- """,
+ SELECT password_hash
+ FROM users
+ WHERE id = $1
+ """,
user_id,
)
@@ -88,10 +84,10 @@ async def raw_token_check(token: str, db=None) -> int:
# with people leaving their clients open forever)
await db.execute(
"""
- UPDATE users
- SET last_session = (now() at time zone 'utc')
- WHERE id = $1
- """,
+ UPDATE users
+ SET last_session = (now() at time zone 'utc')
+ WHERE id = $1
+ """,
user_id,
)
@@ -128,10 +124,10 @@ async def admin_check() -> int:
flags = await app.db.fetchval(
"""
- SELECT flags
- FROM users
- WHERE id = $1
- """,
+ SELECT flags
+ FROM users
+ WHERE id = $1
+ """,
user_id,
)
@@ -150,105 +146,3 @@ async def hash_data(data: str, loop=None) -> str:
hashed = await loop.run_in_executor(None, bcrypt.hashpw, buf, bcrypt.gensalt(14))
return hashed.decode()
-
-
-async def check_username_usage(username: str):
- """Raise an error if too many people are with the same username."""
- same_username = await app.db.fetchval(
- """
- SELECT COUNT(*)
- FROM users
- WHERE username = $1
- """,
- username,
- )
-
- if same_username > 9000:
- raise BadRequest(
- "Too many people.",
- {
- "username": "Too many people used the same username. "
- "Please choose another"
- },
- )
-
-
-def _raw_discrim() -> str:
- discrim_number = randint(1, 9999)
- return "%04d" % discrim_number
-
-
-async def roll_discrim(username: str) -> Optional[str]:
- """Roll a discriminator for a DiscordTag.
-
- Tries to generate one 10 times.
-
- Calls check_username_usage.
- """
-
- # we shouldn't roll discrims for usernames
- # that have been used too much.
- await check_username_usage(username)
-
- # max 10 times for a reroll
- for _ in range(10):
- # generate random discrim
- discrim = _raw_discrim()
-
- # check if anyone is with it
- res = await app.db.fetchval(
- """
- SELECT id
- FROM users
- WHERE username = $1 AND discriminator = $2
- """,
- username,
- discrim,
- )
-
- # if no user is found with the (username, discrim)
- # pair, then this is unique! return it.
- if res is None:
- return discrim
-
- return None
-
-
-async def create_user(username: str, email: str, password: str) -> Tuple[int, str]:
- """Create a single user.
-
- Generates a distriminator and other information. You can fetch the user
- data back with :meth:`Storage.get_user`.
- """
- db = app.db
- loop = app.loop
-
- new_id = get_snowflake()
- new_discrim = await roll_discrim(username)
-
- if new_discrim is None:
- raise BadRequest(
- "Unable to register.",
- {"username": "Too many people are with this username."},
- )
-
- pwd_hash = await hash_data(password, loop)
-
- try:
- await db.execute(
- """
- INSERT INTO users
- (id, email, username, discriminator, password_hash)
- VALUES
- ($1, $2, $3, $4, $5)
- """,
- new_id,
- email,
- username,
- new_discrim,
- pwd_hash,
- )
- except UniqueViolationError:
- raise BadRequest("Email already used.")
-
- return new_id, pwd_hash
diff --git a/litecord/blueprints/admin_api/guilds.py b/litecord/blueprints/admin_api/guilds.py
index 1cf792b..15f2647 100644
--- a/litecord/blueprints/admin_api/guilds.py
+++ b/litecord/blueprints/admin_api/guilds.py
@@ -22,7 +22,7 @@ from quart import Blueprint, jsonify, current_app as app, request
from litecord.auth import admin_check
from litecord.schemas import validate
from litecord.admin_schemas import GUILD_UPDATE
-from litecord.blueprints.guilds import delete_guild
+from litecord.common.guilds import delete_guild
from litecord.errors import GuildNotFound
bp = Blueprint("guilds_admin", __name__)
diff --git a/litecord/blueprints/admin_api/users.py b/litecord/blueprints/admin_api/users.py
index b8b0acd..b1bbe98 100644
--- a/litecord/blueprints/admin_api/users.py
+++ b/litecord/blueprints/admin_api/users.py
@@ -20,13 +20,17 @@ along with this program. If not, see .
from quart import Blueprint, jsonify, current_app as app, request
from litecord.auth import admin_check
-from litecord.blueprints.auth import create_user
from litecord.schemas import validate
from litecord.admin_schemas import USER_CREATE, USER_UPDATE
from litecord.errors import BadRequest, Forbidden
from litecord.utils import async_map
-from litecord.blueprints.users import delete_user, user_disconnect, mass_user_update
from litecord.enums import UserFlags
+from litecord.common.users import (
+ create_user,
+ delete_user,
+ user_disconnect,
+ mass_user_update,
+)
bp = Blueprint("users_admin", __name__)
diff --git a/litecord/blueprints/admin_api/voice.py b/litecord/blueprints/admin_api/voice.py
index 334bcfe..e2b87d5 100644
--- a/litecord/blueprints/admin_api/voice.py
+++ b/litecord/blueprints/admin_api/voice.py
@@ -118,7 +118,7 @@ async def deprecate_region(region):
return "", 204
-async def guild_region_check(app_):
+async def guild_region_check():
"""Check all guilds for voice region inconsistencies.
Since the voice migration caused all guilds.region columns
@@ -126,23 +126,23 @@ async def guild_region_check(app_):
than one region setup.
"""
- regions = await app_.storage.all_voice_regions()
+ regions = await app.storage.all_voice_regions()
if not regions:
log.info("region check: no regions to move guilds to")
return
- res = await app_.db.execute(
+ res = await app.db.execute(
"""
- UPDATE guilds
- SET region = (
- SELECT id
- FROM voice_regions
- OFFSET floor(random()*$1)
- LIMIT 1
- )
- WHERE region = NULL
- """,
+ UPDATE guilds
+ SET region = (
+ SELECT id
+ FROM voice_regions
+ OFFSET floor(random()*$1)
+ LIMIT 1
+ )
+ WHERE region = NULL
+ """,
len(regions),
)
diff --git a/litecord/blueprints/auth.py b/litecord/blueprints/auth.py
index 81a25d1..4b60a69 100644
--- a/litecord/blueprints/auth.py
+++ b/litecord/blueprints/auth.py
@@ -25,7 +25,8 @@ import bcrypt
from quart import Blueprint, jsonify, request, current_app as app
from logbook import Logger
-from litecord.auth import token_check, create_user
+from litecord.auth import token_check
+from litecord.common.users import create_user
from litecord.schemas import validate, REGISTER, REGISTER_WITH_INVITE
from litecord.errors import BadRequest
from litecord.snowflake import get_snowflake
@@ -120,7 +121,7 @@ async def _register_with_invite():
)
user_id, pwd_hash = await create_user(
- data["username"], data["email"], data["password"], app.db
+ data["username"], data["email"], data["password"]
)
return jsonify({"token": make_token(user_id, pwd_hash), "user_id": str(user_id)})
diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py
index 551508b..f46bae8 100644
--- a/litecord/blueprints/channel/messages.py
+++ b/litecord/blueprints/channel/messages.py
@@ -17,65 +17,36 @@ along with this program. If not, see .
"""
-import json
from pathlib import Path
-from PIL import Image
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check, channel_perm_check
-from litecord.blueprints.dms import try_dm_state
-from litecord.errors import MessageNotFound, Forbidden, BadRequest
+from litecord.errors import MessageNotFound, Forbidden
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
from litecord.snowflake import get_snowflake
from litecord.schemas import validate, MESSAGE_CREATE
-from litecord.utils import pg_set_json
+from litecord.utils import pg_set_json, query_tuple_from_args, extract_limit
from litecord.permissions import get_permissions
from litecord.embed.sanitizer import fill_embed
from litecord.embed.messages import process_url_embed
-from litecord.blueprints.channel.dm_checks import dm_pre_check
+from litecord.common.channels import dm_pre_check, try_dm_state
from litecord.images import try_unlink
+from litecord.common.messages import (
+ msg_create_request,
+ msg_create_check_content,
+ msg_add_attachment,
+ msg_guild_text_mentions,
+)
log = Logger(__name__)
bp = Blueprint("channel_messages", __name__)
-def extract_limit(request_, default: int = 50, max_val: int = 100):
- """Extract a limit kwarg."""
- try:
- limit = int(request_.args.get("limit", default))
-
- if limit not in range(0, max_val + 1):
- raise ValueError()
- except (TypeError, ValueError):
- raise BadRequest("limit not int")
-
- return limit
-
-
-def query_tuple_from_args(args: dict, limit: int) -> tuple:
- """Extract a 2-tuple out of request arguments."""
- before, after = None, None
-
- if "around" in request.args:
- average = int(limit / 2)
- around = int(args["around"])
-
- after = around - average
- before = around + average
-
- elif "before" in args:
- before = int(args["before"])
- elif "after" in args:
- before = int(args["after"])
-
- return before, after
-
-
@bp.route("//messages", methods=["GET"])
async def get_messages(channel_id):
user_id = await token_check()
@@ -204,191 +175,8 @@ async def create_message(
return message_id
-async def msg_guild_text_mentions(
- payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
-):
- """Calculates mention data side-effects."""
- channel_id = int(payload["channel_id"])
-
- # calculate the user ids we'll bump the mention count for
- uids = set()
-
- # first is extracting user mentions
- for mention in payload["mentions"]:
- uids.add(int(mention["id"]))
-
- # then role mentions
- for role_mention in payload["mention_roles"]:
- role_id = int(role_mention)
- member_ids = await app.storage.get_role_members(role_id)
-
- for member_id in member_ids:
- uids.add(member_id)
-
- # at-here only updates the state
- # for the users that have a state
- # in the channel.
- if mentions_here:
- uids = set()
-
- await app.db.execute(
- """
- UPDATE user_read_state
- SET mention_count = mention_count + 1
- WHERE channel_id = $1
- """,
- channel_id,
- )
-
- # at-here updates the read state
- # for all users, including the ones
- # that might not have read permissions
- # to the channel.
- if mentions_everyone:
- uids = set()
-
- member_ids = await app.storage.get_member_ids(guild_id)
-
- await app.db.executemany(
- """
- UPDATE user_read_state
- SET mention_count = mention_count + 1
- WHERE channel_id = $1 AND user_id = $2
- """,
- [(channel_id, uid) for uid in member_ids],
- )
-
- for user_id in uids:
- await app.db.execute(
- """
- UPDATE user_read_state
- SET mention_count = mention_count + 1
- WHERE user_id = $1
- AND channel_id = $2
- """,
- user_id,
- channel_id,
- )
-
-
-async def msg_create_request() -> tuple:
- """Extract the json input and any file information
- the client gave to us in the request.
-
- This only applies to create message route.
- """
- form = await request.form
- request_json = await request.get_json() or {}
-
- # NOTE: embed isn't set on form data
- json_from_form = {
- "content": form.get("content", ""),
- "nonce": form.get("nonce", "0"),
- "tts": json.loads(form.get("tts", "false")),
- }
-
- payload_json = json.loads(form.get("payload_json", "{}"))
-
- json_from_form.update(request_json)
- json_from_form.update(payload_json)
-
- files = await request.files
-
- # we don't really care about the given fields on the files dict, so
- # we only extract the values
- return json_from_form, [v for k, v in files.items()]
-
-
-def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
- """Check if there is actually any content being sent to us."""
- has_content = bool(payload.get("content", ""))
- has_files = len(files) > 0
-
- embed_field = "embeds" if use_embeds else "embed"
- has_embed = embed_field in payload and payload.get(embed_field) is not None
-
- has_total_content = has_content or has_embed or has_files
-
- if not has_total_content:
- raise BadRequest("No content has been provided.")
-
-
-async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
- """Add an attachment to a message.
-
- Parameters
- ----------
- message_id: int
- The ID of the message getting the attachment.
- channel_id: int
- The ID of the channel the message belongs to.
-
- Exists because the attachment URL scheme contains
- a channel id. The purpose is unknown, but we are
- implementing Discord's behavior.
- attachment_file: quart.FileStorage
- quart FileStorage instance of the file.
- """
-
- attachment_id = get_snowflake()
- filename = attachment_file.filename
-
- # understand file info
- mime = attachment_file.mimetype
- is_image = mime.startswith("image/")
-
- img_width, img_height = None, None
-
- # extract file size
- # TODO: this is probably inneficient
- file_size = attachment_file.stream.getbuffer().nbytes
-
- if is_image:
- # open with pillow, extract image size
- image = Image.open(attachment_file.stream)
- img_width, img_height = image.size
-
- # NOTE: DO NOT close the image, as closing the image will
- # also close the stream.
-
- # reset it to 0 for later usage
- attachment_file.stream.seek(0)
-
- await app.db.execute(
- """
- INSERT INTO attachments
- (id, channel_id, message_id,
- filename, filesize,
- image, width, height)
- VALUES
- ($1, $2, $3, $4, $5, $6, $7, $8)
- """,
- attachment_id,
- channel_id,
- message_id,
- filename,
- file_size,
- is_image,
- img_width,
- img_height,
- )
-
- ext = filename.split(".")[-1]
-
- with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
- attach_file.write(attachment_file.stream.read())
-
- log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
-
- return attachment_id
-
-
-async def _spawn_embed(app_, payload, **kwargs):
- app_.sched.spawn(
- process_url_embed(
- app_.config, app_.storage, app_.dispatcher, app_.session, payload, **kwargs
- )
- )
+async def _spawn_embed(payload, **kwargs):
+ app.sched.spawn(process_url_embed(payload, **kwargs))
@bp.route("//messages", methods=["POST"])
@@ -458,7 +246,7 @@ async def _create_message(channel_id):
# spawn url processor for embedding of images
perms = await get_permissions(user_id, channel_id)
if perms.bits.embed_links:
- await _spawn_embed(app, payload)
+ await _spawn_embed(payload)
# update read state for the author
await app.db.execute(
@@ -536,7 +324,6 @@ async def edit_message(channel_id, message_id):
perms = await get_permissions(user_id, channel_id)
if perms.bits.embed_links:
await _spawn_embed(
- app,
{
"id": message_id,
"channel_id": channel_id,
diff --git a/litecord/blueprints/channel/reactions.py b/litecord/blueprints/channel/reactions.py
index 4c8dabd..ac1440f 100644
--- a/litecord/blueprints/channel/reactions.py
+++ b/litecord/blueprints/channel/reactions.py
@@ -23,10 +23,9 @@ from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
-from litecord.utils import async_map
+from litecord.utils import async_map, query_tuple_from_args, extract_limit
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check, channel_perm_check
-from litecord.blueprints.channel.messages import query_tuple_from_args, extract_limit
from litecord.enums import GUILD_CHANS
@@ -165,7 +164,8 @@ def _emoji_sql_simple(emoji: str, param=4):
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
-async def remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str):
+async def _remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str):
+ """Remove given reaction from a message."""
ctype, guild_id = await channel_check(user_id, channel_id)
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
@@ -201,8 +201,7 @@ async def remove_own_reaction(channel_id, message_id, emoji):
"""Remove a reaction."""
user_id = await token_check()
- await remove_reaction(channel_id, message_id, user_id, emoji)
-
+ await _remove_reaction(channel_id, message_id, user_id, emoji)
return "", 204
@@ -212,7 +211,7 @@ async def remove_user_reaction(channel_id, message_id, emoji, other_id):
user_id = await token_check()
await channel_perm_check(user_id, channel_id, "manage_messages")
- await remove_reaction(channel_id, message_id, other_id, emoji)
+ await _remove_reaction(channel_id, message_id, other_id, emoji)
return "", 204
diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py
index a4b6c45..701c141 100644
--- a/litecord/blueprints/channels.py
+++ b/litecord/blueprints/channels.py
@@ -42,6 +42,7 @@ from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
from litecord.utils import search_result_from_list
from litecord.embed.messages import process_url_embed, msg_update_embeds
from litecord.snowflake import snowflake_datetime
+from litecord.common.channels import channel_ack
log = Logger(__name__)
bp = Blueprint("channels", __name__)
@@ -136,7 +137,7 @@ async def _update_guild_chan_cat(guild_id: int, channel_id: int):
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child)
-async def delete_messages(channel_id):
+async def _delete_messages(channel_id):
await app.db.execute(
"""
DELETE FROM channel_pins
@@ -162,7 +163,7 @@ async def delete_messages(channel_id):
)
-async def guild_cleanup(channel_id):
+async def _guild_cleanup(channel_id):
await app.db.execute(
"""
DELETE FROM channel_overwrites
@@ -220,8 +221,8 @@ async def close_channel(channel_id):
# didn't work on my setup, so I delete
# everything before moving to the main
# channel table deletes
- await delete_messages(channel_id)
- await guild_cleanup(channel_id)
+ await _delete_messages(channel_id)
+ await _guild_cleanup(channel_id)
await app.db.execute(
f"""
@@ -595,48 +596,6 @@ async def trigger_typing(channel_id):
return "", 204
-async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
- """ACK a channel."""
-
- if not message_id:
- message_id = await app.storage.chan_last_message(channel_id)
-
- await app.db.execute(
- """
- INSERT INTO user_read_state
- (user_id, channel_id, last_message_id, mention_count)
- VALUES
- ($1, $2, $3, 0)
- ON CONFLICT ON CONSTRAINT user_read_state_pkey
- DO
- UPDATE
- SET last_message_id = $3, mention_count = 0
- WHERE user_read_state.user_id = $1
- AND user_read_state.channel_id = $2
- """,
- user_id,
- channel_id,
- message_id,
- )
-
- if guild_id:
- await app.dispatcher.dispatch_user_guild(
- user_id,
- guild_id,
- "MESSAGE_ACK",
- {"message_id": str(message_id), "channel_id": str(channel_id)},
- )
- else:
- # we don't use ChannelDispatcher here because since
- # guild_id is None, all user devices are already subscribed
- # to the given channel (a dm or a group dm)
- await app.dispatcher.dispatch_user(
- user_id,
- "MESSAGE_ACK",
- {"message_id": str(message_id), "channel_id": str(channel_id)},
- )
-
-
@bp.route("//messages//ack", methods=["POST"])
async def ack_channel(channel_id, message_id):
"""Acknowledge a channel."""
@@ -799,7 +758,7 @@ async def suppress_embeds(channel_id: int, message_id: int):
message["flags"] = message.get("flags", 0) | MessageFlags.suppress_embeds
- await msg_update_embeds(message, [], app.storage, app.dispatcher)
+ await msg_update_embeds(message, [])
elif not suppress and not url_embeds:
# spawn process_url_embed to restore the embeds, if any
await _msg_unset_flags(message_id, MessageFlags.suppress_embeds)
@@ -809,11 +768,7 @@ async def suppress_embeds(channel_id: int, message_id: int):
except KeyError:
pass
- app.sched.spawn(
- process_url_embed(
- app.config, app.storage, app.dispatcher, app.session, message
- )
- )
+ app.sched.spawn(process_url_embed(message))
return "", 204
diff --git a/litecord/blueprints/dms.py b/litecord/blueprints/dms.py
index 975d7ea..2a50837 100644
--- a/litecord/blueprints/dms.py
+++ b/litecord/blueprints/dms.py
@@ -31,6 +31,7 @@ from ..snowflake import get_snowflake
from .auth import token_check
from litecord.blueprints.dm_channels import gdm_create, gdm_add_recipient
+from litecord.common.channels import try_dm_state
log = Logger(__name__)
bp = Blueprint("dms", __name__)
@@ -44,24 +45,6 @@ async def get_dms():
return jsonify(dms)
-async def try_dm_state(user_id: int, dm_id: int):
- """Try inserting the user into the dm state
- for the given DM.
-
- Does not do anything if the user is already
- in the dm state.
- """
- await app.db.execute(
- """
- INSERT INTO dm_channel_state (user_id, dm_id)
- VALUES ($1, $2)
- ON CONFLICT DO NOTHING
- """,
- user_id,
- dm_id,
- )
-
-
async def jsonify_dm(dm_id: int, user_id: int):
dm_chan = await app.storage.get_dm(dm_id, user_id)
return jsonify(dm_chan)
diff --git a/litecord/blueprints/guild/channels.py b/litecord/blueprints/guild/channels.py
index c8a2d2e..de20b95 100644
--- a/litecord/blueprints/guild/channels.py
+++ b/litecord/blueprints/guild/channels.py
@@ -27,77 +27,11 @@ from litecord.blueprints.guild.roles import gen_pairs
from litecord.schemas import validate, ROLE_UPDATE_POSITION, CHAN_CREATE
from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
-
+from litecord.common.guilds import create_guild_channel
bp = Blueprint("guild_channels", __name__)
-async def _specific_chan_create(channel_id, ctype, **kwargs):
- if ctype == ChannelType.GUILD_TEXT:
- await app.db.execute(
- """
- INSERT INTO guild_text_channels (id, topic)
- VALUES ($1, $2)
- """,
- channel_id,
- kwargs.get("topic", ""),
- )
- elif ctype == ChannelType.GUILD_VOICE:
- await app.db.execute(
- """
- INSERT INTO guild_voice_channels (id, bitrate, user_limit)
- VALUES ($1, $2, $3)
- """,
- channel_id,
- kwargs.get("bitrate", 64),
- kwargs.get("user_limit", 0),
- )
-
-
-async def create_guild_channel(
- guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
-):
- """Create a channel in a guild."""
- await app.db.execute(
- """
- INSERT INTO channels (id, channel_type)
- VALUES ($1, $2)
- """,
- channel_id,
- ctype.value,
- )
-
- # calc new pos
- max_pos = await app.db.fetchval(
- """
- SELECT MAX(position)
- FROM guild_channels
- WHERE guild_id = $1
- """,
- guild_id,
- )
-
- # account for the first channel in a guild too
- max_pos = max_pos or 0
-
- # all channels go to guild_channels
- await app.db.execute(
- """
- INSERT INTO guild_channels (id, guild_id, name, position)
- VALUES ($1, $2, $3, $4)
- """,
- channel_id,
- guild_id,
- kwargs["name"],
- max_pos + 1,
- )
-
- # the rest of sql magic is dependant on the channel
- # we're creating (a text or voice or category),
- # so we use this function.
- await _specific_chan_create(channel_id, ctype, **kwargs)
-
-
@bp.route("//channels", methods=["GET"])
async def get_guild_channels(guild_id):
"""Get the list of channels in a guild."""
diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py
index 5949fd0..4032bb0 100644
--- a/litecord/blueprints/guild/mod.py
+++ b/litecord/blueprints/guild/mod.py
@@ -23,47 +23,11 @@ from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import guild_perm_check
from litecord.schemas import validate, GUILD_PRUNE
+from litecord.common.guilds import remove_member, remove_member_multi
bp = Blueprint("guild_moderation", __name__)
-async def remove_member(guild_id: int, member_id: int):
- """Do common tasks related to deleting a member from the guild,
- such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
-
- await app.db.execute(
- """
- DELETE FROM members
- WHERE guild_id = $1 AND user_id = $2
- """,
- guild_id,
- member_id,
- )
-
- await app.dispatcher.dispatch_user_guild(
- member_id,
- guild_id,
- "GUILD_DELETE",
- {"guild_id": str(guild_id), "unavailable": False},
- )
-
- await app.dispatcher.unsub("guild", guild_id, member_id)
-
- await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
-
- await app.dispatcher.dispatch_guild(
- guild_id,
- "GUILD_MEMBER_REMOVE",
- {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
- )
-
-
-async def remove_member_multi(guild_id: int, members: list):
- """Remove multiple members."""
- for member_id in members:
- await remove_member(guild_id, member_id)
-
-
@bp.route("//members/", methods=["DELETE"])
async def kick_guild_member(guild_id, member_id):
"""Remove a member from a guild."""
@@ -221,6 +185,5 @@ async def begin_guild_prune(guild_id):
days = j["days"]
member_ids = await get_prune(guild_id, days)
- app.loop.create_task(remove_member_multi(guild_id, member_ids))
-
+ app.sched.spawn(remove_member_multi(guild_id, member_ids))
return jsonify({"pruned": len(member_ids)})
diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py
index 9516aa4..98440fa 100644
--- a/litecord/blueprints/guild/roles.py
+++ b/litecord/blueprints/guild/roles.py
@@ -27,11 +27,9 @@ from litecord.auth import token_check
from litecord.blueprints.checks import guild_check, guild_perm_check
from litecord.schemas import validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
-from litecord.snowflake import get_snowflake
-from litecord.utils import dict_get
-from litecord.permissions import get_role_perms
+from litecord.utils import maybe_lazy_guild_dispatch
+from litecord.common.guilds import create_role
-DEFAULT_EVERYONE_PERMS = 104324161
log = Logger(__name__)
bp = Blueprint("guild_roles", __name__)
@@ -45,71 +43,6 @@ async def get_guild_roles(guild_id):
return jsonify(await app.storage.get_role_data(guild_id))
-async def _maybe_lg(guild_id: int, event: str, role, force: bool = False):
- # sometimes we want to dispatch an event
- # even if the role isn't hoisted
-
- # an example of such a case is when a role loses
- # its hoist status.
-
- # check if is a dict first because role_delete
- # only receives the role id.
- if isinstance(role, dict) and not role["hoist"] and not force:
- return
-
- await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
-
-
-async def create_role(guild_id, name: str, **kwargs):
- """Create a role in a guild."""
- new_role_id = get_snowflake()
-
- everyone_perms = await get_role_perms(guild_id, guild_id)
- default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
-
- # update all roles so that we have space for pos 1, but without
- # sending GUILD_ROLE_UPDATE for everyone
- await app.db.execute(
- """
- UPDATE roles
- SET
- position = position + 1
- WHERE guild_id = $1
- AND NOT (position = 0)
- """,
- guild_id,
- )
-
- await app.db.execute(
- """
- INSERT INTO roles (id, guild_id, name, color,
- hoist, position, permissions, managed, mentionable)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
- """,
- new_role_id,
- guild_id,
- name,
- dict_get(kwargs, "color", 0),
- dict_get(kwargs, "hoist", False),
- # always set ourselves on position 1
- 1,
- int(dict_get(kwargs, "permissions", default_perms)),
- False,
- dict_get(kwargs, "mentionable", False),
- )
-
- role = await app.storage.get_role(new_role_id, guild_id)
-
- # we need to update the lazy guild handlers for the newly created group
- await _maybe_lg(guild_id, "new_role", role)
-
- await app.dispatcher.dispatch_guild(
- guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
- )
-
- return role
-
-
@bp.route("//roles", methods=["POST"])
async def create_guild_role(guild_id: int):
"""Add a role to a guild"""
@@ -132,7 +65,7 @@ async def _role_update_dispatch(role_id: int, guild_id: int):
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
role = await app.storage.get_role(role_id, guild_id)
- await _maybe_lg(guild_id, "role_pos_upd", role)
+ await maybe_lazy_guild_dispatch(guild_id, "role_pos_upd", role)
await app.dispatcher.dispatch_guild(
guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}
@@ -343,7 +276,7 @@ async def update_guild_role(guild_id, role_id):
)
role = await _role_update_dispatch(role_id, guild_id)
- await _maybe_lg(guild_id, "role_update", role, True)
+ await maybe_lazy_guild_dispatch(guild_id, "role_update", role, True)
return jsonify(role)
@@ -369,7 +302,7 @@ async def delete_guild_role(guild_id, role_id):
if res == "DELETE 0":
return "", 204
- await _maybe_lg(guild_id, "role_delete", role_id, True)
+ await maybe_lazy_guild_dispatch(guild_id, "role_delete", role_id, True)
await app.dispatcher.dispatch_guild(
guild_id,
diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py
index 7db4224..9317429 100644
--- a/litecord/blueprints/guilds.py
+++ b/litecord/blueprints/guilds.py
@@ -21,8 +21,12 @@ from typing import Optional, List
from quart import Blueprint, request, current_app as app, jsonify
-from litecord.blueprints.guild.channels import create_guild_channel
-from litecord.blueprints.guild.roles import create_role, DEFAULT_EVERYONE_PERMS
+from litecord.common.guilds import (
+ create_role,
+ create_guild_channel,
+ delete_guild,
+ create_guild_settings,
+)
from ..auth import token_check
from ..snowflake import get_snowflake
@@ -34,44 +38,17 @@ from ..schemas import (
SEARCH_CHANNEL,
VANITY_URL_PATCH,
)
-from .channels import channel_ack
from .checks import guild_check, guild_owner_check, guild_perm_check
+from ..common.channels import channel_ack
from litecord.utils import to_update, search_result_from_list
from litecord.errors import BadRequest
from litecord.permissions import get_permissions
+DEFAULT_EVERYONE_PERMS = 104324161
bp = Blueprint("guilds", __name__)
-async def create_guild_settings(guild_id: int, user_id: int):
- """Create guild settings for the user
- joining the guild."""
-
- # new guild_settings are based off the currently
- # set guild settings (for the guild)
- m_notifs = await app.db.fetchval(
- """
- SELECT default_message_notifications
- FROM guilds
- WHERE id = $1
- """,
- guild_id,
- )
-
- await app.db.execute(
- """
- INSERT INTO guild_settings
- (user_id, guild_id, message_notifications)
- VALUES
- ($1, $2, $3)
- """,
- user_id,
- guild_id,
- m_notifs,
- )
-
-
async def add_member(guild_id: int, user_id: int):
"""Add a user to a guild."""
await app.db.execute(
@@ -393,36 +370,6 @@ async def _update_guild(guild_id):
return jsonify(guild)
-async def delete_guild(guild_id: int, *, app_=None):
- """Delete a single guild."""
- app_ = app_ or app
-
- await app_.db.execute(
- """
- DELETE FROM guilds
- WHERE guilds.id = $1
- """,
- guild_id,
- )
-
- # Discord's client expects IDs being string
- await app_.dispatcher.dispatch(
- "guild",
- guild_id,
- "GUILD_DELETE",
- {
- "guild_id": str(guild_id),
- "id": str(guild_id),
- # 'unavailable': False,
- },
- )
-
- # remove from the dispatcher so nobody
- # becomes the little memer that tries to fuck up with
- # everybody's gateway
- await app_.dispatcher.remove("guild", guild_id)
-
-
@bp.route("/", methods=["DELETE"])
# this endpoint is not documented, but used by the official client.
@bp.route("//delete", methods=["POST"])
diff --git a/litecord/blueprints/icons.py b/litecord/blueprints/icons.py
index 9b0f378..a01509e 100644
--- a/litecord/blueprints/icons.py
+++ b/litecord/blueprints/icons.py
@@ -64,7 +64,7 @@ async def _get_default_user_avatar(default_id: int):
async def _handle_webhook_avatar(md_url_redir: str):
- md_url = make_md_req_url(app.config, "img", EmbedURL(md_url_redir))
+ md_url = make_md_req_url("img", EmbedURL(md_url_redir))
return redirect(md_url)
diff --git a/litecord/blueprints/invites.py b/litecord/blueprints/invites.py
index 02b7610..0a9f7d7 100644
--- a/litecord/blueprints/invites.py
+++ b/litecord/blueprints/invites.py
@@ -28,7 +28,6 @@ from ..auth import token_check
from ..schemas import validate, INVITE
from ..enums import ChannelType
from ..errors import BadRequest, Forbidden
-from .guilds import create_guild_settings
from ..utils import async_map
from litecord.blueprints.checks import (
@@ -39,6 +38,7 @@ from litecord.blueprints.checks import (
)
from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient
+from litecord.common.guilds import create_guild_settings
log = Logger(__name__)
bp = Blueprint("invites", __name__)
diff --git a/litecord/blueprints/user/billing.py b/litecord/blueprints/user/billing.py
index a1a4148..3bf770d 100644
--- a/litecord/blueprints/user/billing.py
+++ b/litecord/blueprints/user/billing.py
@@ -30,7 +30,7 @@ from litecord.snowflake import snowflake_datetime, get_snowflake
from litecord.errors import BadRequest
from litecord.types import timestamp_, HOURS
from litecord.enums import UserFlags, PremiumType
-from litecord.blueprints.users import mass_user_update
+from litecord.common.users import mass_user_update
log = Logger(__name__)
bp = Blueprint("users_billing", __name__)
@@ -122,16 +122,13 @@ async def get_payment_source_ids(user_id: int) -> list:
return [r["id"] for r in rows]
-async def get_payment_ids(user_id: int, db=None) -> list:
- if not db:
- db = app.db
-
- rows = await db.fetch(
+async def get_payment_ids(user_id: int) -> list:
+ rows = await app.db.fetch(
"""
- SELECT id
- FROM user_payments
- WHERE user_id = $1
- """,
+ SELECT id
+ FROM user_payments
+ WHERE user_id = $1
+ """,
user_id,
)
@@ -151,18 +148,14 @@ async def get_subscription_ids(user_id: int) -> list:
return [r["id"] for r in rows]
-async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
+async def get_payment_source(user_id: int, source_id: int) -> dict:
"""Get a payment source's information."""
-
- if not db:
- db = app.db
-
- source_type = await db.fetchval(
+ source_type = await app.db.fetchval(
"""
- SELECT source_type
- FROM user_payment_sources
- WHERE id = $1 AND user_id = $2
- """,
+ SELECT source_type
+ FROM user_payment_sources
+ WHERE id = $1 AND user_id = $2
+ """,
source_id,
user_id,
)
@@ -176,7 +169,7 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
fields = ",".join(specific_fields)
- extras_row = await db.fetchrow(
+ extras_row = await app.db.fetchrow(
f"""
SELECT {fields}, billing_address, default_, id::text
FROM user_payment_sources
@@ -199,22 +192,19 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
return {**source, **derow}
-async def get_subscription(subscription_id: int, db=None):
+async def get_subscription(subscription_id: int):
"""Get a subscription's information."""
- if not db:
- db = app.db
-
- row = await db.fetchrow(
+ row = await app.db.fetchrow(
"""
- SELECT id::text, source_id::text AS payment_source_id,
- user_id,
- payment_gateway, payment_gateway_plan_id,
- period_start AS current_period_start,
- period_end AS current_period_end,
- canceled_at, s_type, status
- FROM user_subscriptions
- WHERE id = $1
- """,
+ SELECT id::text, source_id::text AS payment_source_id,
+ user_id,
+ payment_gateway, payment_gateway_plan_id,
+ period_start AS current_period_start,
+ period_end AS current_period_end,
+ canceled_at, s_type, status
+ FROM user_subscriptions
+ WHERE id = $1
+ """,
subscription_id,
)
@@ -231,19 +221,16 @@ async def get_subscription(subscription_id: int, db=None):
return drow
-async def get_payment(payment_id: int, db=None):
+async def get_payment(payment_id: int):
"""Get a single payment's information."""
- if not db:
- db = app.db
-
- row = await db.fetchrow(
+ row = await app.db.fetchrow(
"""
- SELECT id::text, source_id, subscription_id, user_id,
- amount, amount_refunded, currency,
- description, status, tax, tax_inclusive
- FROM user_payments
- WHERE id = $1
- """,
+ SELECT id::text, source_id, subscription_id, user_id,
+ amount, amount_refunded, currency,
+ description, status, tax, tax_inclusive
+ FROM user_payments
+ WHERE id = $1
+ """,
payment_id,
)
@@ -255,27 +242,22 @@ async def get_payment(payment_id: int, db=None):
drow["created_at"] = snowflake_datetime(int(drow["id"]))
- drow["payment_source"] = await get_payment_source(
- row["user_id"], row["source_id"], db
- )
+ drow["payment_source"] = await get_payment_source(row["user_id"], row["source_id"])
- drow["subscription"] = await get_subscription(row["subscription_id"], db)
+ drow["subscription"] = await get_subscription(row["subscription_id"])
return drow
-async def create_payment(subscription_id, db=None):
+async def create_payment(subscription_id):
"""Create a payment."""
- if not db:
- db = app.db
-
- sub = await get_subscription(subscription_id, db)
+ sub = await get_subscription(subscription_id)
new_id = get_snowflake()
amount = AMOUNTS[sub["payment_gateway_plan_id"]]
- await db.execute(
+ await app.db.execute(
"""
INSERT INTO user_payments (
id, source_id, subscription_id, user_id,
@@ -298,9 +280,9 @@ async def create_payment(subscription_id, db=None):
return new_id
-async def process_subscription(app, subscription_id: int):
+async def process_subscription(subscription_id: int):
"""Process a single subscription."""
- sub = await get_subscription(subscription_id, app.db)
+ sub = await get_subscription(subscription_id)
user_id = int(sub["user_id"])
@@ -313,10 +295,10 @@ async def process_subscription(app, subscription_id: int):
# payments), then we should update premium status
first_payment_id = await app.db.fetchval(
"""
- SELECT MIN(id)
- FROM user_payments
- WHERE subscription_id = $1
- """,
+ SELECT MIN(id)
+ FROM user_payments
+ WHERE subscription_id = $1
+ """,
subscription_id,
)
@@ -324,10 +306,10 @@ async def process_subscription(app, subscription_id: int):
premium_since = await app.db.fetchval(
"""
- SELECT premium_since
- FROM users
- WHERE id = $1
- """,
+ SELECT premium_since
+ FROM users
+ WHERE id = $1
+ """,
user_id,
)
@@ -343,10 +325,10 @@ async def process_subscription(app, subscription_id: int):
old_flags = await app.db.fetchval(
"""
- SELECT flags
- FROM users
- WHERE id = $1
- """,
+ SELECT flags
+ FROM users
+ WHERE id = $1
+ """,
user_id,
)
@@ -355,17 +337,17 @@ async def process_subscription(app, subscription_id: int):
await app.db.execute(
"""
- UPDATE users
- SET premium_since = $1, flags = $2
- WHERE id = $3
- """,
+ UPDATE users
+ SET premium_since = $1, flags = $2
+ WHERE id = $3
+ """,
first_payment_ts,
new_flags,
user_id,
)
# dispatch updated user to all possible clients
- await mass_user_update(user_id, app)
+ await mass_user_update(user_id)
@bp.route("/@me/billing/payment-sources", methods=["GET"])
@@ -474,11 +456,11 @@ async def _create_subscription():
1,
)
- await create_payment(new_id, app.db)
+ await create_payment(new_id)
# make sure we update the user's premium status
# and dispatch respective user updates to other people.
- await process_subscription(app, new_id)
+ await process_subscription(new_id)
return jsonify(await get_subscription(new_id))
diff --git a/litecord/blueprints/user/billing_job.py b/litecord/blueprints/user/billing_job.py
index 4148415..ee50c33 100644
--- a/litecord/blueprints/user/billing_job.py
+++ b/litecord/blueprints/user/billing_job.py
@@ -21,6 +21,8 @@ along with this program. If not, see .
this file only serves the periodic payment job code.
"""
import datetime
+
+from quart import current_app as app
from asyncio import sleep, CancelledError
from logbook import Logger
@@ -47,14 +49,14 @@ THRESHOLDS = {
}
-async def _resched(app):
+async def _resched():
log.debug("waiting 30 minutes for job.")
await sleep(30 * MINUTES)
- app.sched.spawn(payment_job(app))
+ app.sched.spawn(payment_job())
-async def _process_user_payments(app, user_id: int):
- payments = await get_payment_ids(user_id, app.db)
+async def _process_user_payments(user_id: int):
+ payments = await get_payment_ids(user_id)
if not payments:
log.debug("no payments for uid {}, skipping", user_id)
@@ -64,7 +66,7 @@ async def _process_user_payments(app, user_id: int):
latest_payment = max(payments)
- payment_data = await get_payment(latest_payment, app.db)
+ payment_data = await get_payment(latest_payment)
# calculate the difference between this payment
# and now.
@@ -74,7 +76,7 @@ async def _process_user_payments(app, user_id: int):
delta = now - payment_tstamp
sub_id = int(payment_data["subscription"]["id"])
- subscription = await get_subscription(sub_id, app.db)
+ subscription = await get_subscription(sub_id)
# if the max payment is X days old, we create another.
# X is 30 for monthly subscriptions of nitro,
@@ -89,12 +91,12 @@ async def _process_user_payments(app, user_id: int):
# create_payment does not call any Stripe
# or BrainTree APIs at all, since we'll just
# give it as free.
- await create_payment(sub_id, app.db)
+ await create_payment(sub_id)
else:
log.debug("sid={}, missing {} days", sub_id, threshold - delta.days)
-async def payment_job(app):
+async def payment_job():
"""Main payment job function.
This function will check through users' payments
@@ -104,9 +106,9 @@ async def payment_job(app):
user_ids = await app.db.fetch(
"""
- SELECT DISTINCT user_id
- FROM user_payments
- """
+ SELECT DISTINCT user_id
+ FROM user_payments
+ """
)
log.debug("working {} users", len(user_ids))
@@ -115,24 +117,24 @@ async def payment_job(app):
for row in user_ids:
user_id = row["user_id"]
try:
- await _process_user_payments(app, user_id)
+ await _process_user_payments(user_id)
except Exception:
log.exception("error while processing user payments")
subscribers = await app.db.fetch(
"""
- SELECT id
- FROM user_subscriptions
- """
+ SELECT id
+ FROM user_subscriptions
+ """
)
for row in subscribers:
try:
- await process_subscription(app, row["id"])
+ await process_subscription(row["id"])
except Exception:
log.exception("error while processing subscription")
log.debug("rescheduling..")
try:
- await _resched(app)
+ await _resched()
except CancelledError:
log.info("cancelled while waiting for resched")
diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py
index e45a1a3..9223c61 100644
--- a/litecord/blueprints/users.py
+++ b/litecord/blueprints/users.py
@@ -17,7 +17,6 @@ along with this program. If not, see .
"""
-from os import urandom
from asyncpg import UniqueViolationError
from quart import Blueprint, jsonify, request, current_app as app
@@ -27,8 +26,8 @@ from ..errors import Forbidden, BadRequest, Unauthorized
from ..schemas import validate, USER_UPDATE, GET_MENTIONS
from .guilds import guild_check
-from litecord.auth import token_check, hash_data, check_username_usage, roll_discrim
-from litecord.blueprints.guild.mod import remove_member
+from litecord.auth import token_check, hash_data
+from litecord.common.guilds import remove_member
from litecord.enums import PremiumType
from litecord.images import parse_data_uri
@@ -36,46 +35,18 @@ from litecord.permissions import base_permissions
from litecord.blueprints.auth import check_password
from litecord.utils import to_update
+from litecord.common.users import (
+ mass_user_update,
+ delete_user,
+ check_username_usage,
+ roll_discrim,
+ user_disconnect,
+)
bp = Blueprint("user", __name__)
log = Logger(__name__)
-async def mass_user_update(user_id):
- """Dispatch USER_UPDATE in a mass way."""
- # by using dispatch_with_filter
- # we're guaranteeing all shards will get
- # a USER_UPDATE once and not any others.
-
- session_ids = []
-
- public_user = await app.storage.get_user(user_id)
- private_user = await app.storage.get_user(user_id, secure=True)
-
- session_ids.extend(
- await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
- )
-
- guild_ids = await app.user_storage.get_user_guilds(user_id)
- friend_ids = await app.user_storage.get_friend_ids(user_id)
-
- session_ids.extend(
- await app.dispatcher.dispatch_many_filter_list(
- "guild", guild_ids, session_ids, "USER_UPDATE", public_user
- )
- )
-
- session_ids.extend(
- await app.dispatcher.dispatch_many_filter_list(
- "friend", friend_ids, session_ids, "USER_UPDATE", public_user
- )
- )
-
- await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
-
- return public_user, private_user
-
-
@bp.route("/@me", methods=["GET"])
async def get_me():
"""Get the current user's information."""
@@ -276,7 +247,7 @@ async def patch_me():
user.pop("password_hash")
- _, private_user = await mass_user_update(user_id, app)
+ _, private_user = await mass_user_update(user_id)
return jsonify(private_user)
@@ -319,7 +290,6 @@ async def leave_guild(guild_id: int):
await guild_check(user_id, guild_id)
await remove_member(guild_id, user_id)
-
return "", 204
@@ -468,118 +438,6 @@ async def _get_mentions():
return jsonify(res)
-def rand_hex(length: int = 8) -> str:
- """Generate random hex characters."""
- return urandom(length).hex()[:length]
-
-
-async def _del_from_table(db, table: str, user_id: int):
- """Delete a row from a table."""
- column = {
- "channel_overwrites": "target_user",
- "user_settings": "id",
- "group_dm_members": "member_id",
- }.get(table, "user_id")
-
- res = await db.execute(
- f"""
- DELETE FROM {table}
- WHERE {column} = $1
- """,
- user_id,
- )
-
- log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res)
-
-
-async def delete_user(user_id, *, mass_update: bool = True):
- """Delete a user. Does not disconnect the user."""
- db = app.db
-
- new_username = f"Deleted User {rand_hex()}"
-
- # by using a random hex in password_hash
- # we break attempts at using the default '123' password hash
- # to issue valid tokens for deleted users.
-
- await db.execute(
- """
- UPDATE users
- SET
- username = $1,
- email = NULL,
- mfa_enabled = false,
- verified = false,
- avatar = NULL,
- flags = 0,
- premium_since = NULL,
- phone = '',
- password_hash = $2
- WHERE
- id = $3
- """,
- new_username,
- rand_hex(32),
- user_id,
- )
-
- # remove the user from various tables
- await _del_from_table(db, "user_settings", user_id)
- await _del_from_table(db, "user_payment_sources", user_id)
- await _del_from_table(db, "user_subscriptions", user_id)
- await _del_from_table(db, "user_payments", user_id)
- await _del_from_table(db, "user_read_state", user_id)
- await _del_from_table(db, "guild_settings", user_id)
- await _del_from_table(db, "guild_settings_channel_overrides", user_id)
-
- await db.execute(
- """
- DELETE FROM relationships
- WHERE user_id = $1 OR peer_id = $1
- """,
- user_id,
- )
-
- # DMs are still maintained, but not the state.
- await _del_from_table(db, "dm_channel_state", user_id)
-
- # NOTE: we don't delete the group dms the user is an owner of...
- # TODO: group dm owner reassign when the owner leaves a gdm
- await _del_from_table(db, "group_dm_members", user_id)
-
- await _del_from_table(db, "members", user_id)
- await _del_from_table(db, "member_roles", user_id)
- await _del_from_table(db, "channel_overwrites", user_id)
-
- # after updating the user, we send USER_UPDATE so that all the other
- # clients can refresh their caches on the now-deleted user
- if mass_update:
- await mass_user_update(user_id)
-
-
-async def user_disconnect(user_id: int):
- """Disconnects the given user's devices."""
- # after removing the user from all tables, we need to force
- # all known user states to reconnect, causing the user to not
- # be online anymore.
- user_states = app.state_manager.user_states(user_id)
-
- for state in user_states:
- # make it unable to resume
- app.state_manager.remove(state)
-
- if not state.ws:
- continue
-
- # force a close, 4000 should make the client reconnect.
- await state.ws.ws.close(4000)
-
- # force everyone to see the user as offline
- await app.presence.dispatch_pres(
- user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
- )
-
-
@bp.route("/@me/delete", methods=["POST"])
async def delete_account():
"""Delete own account.
diff --git a/litecord/blueprints/webhooks.py b/litecord/blueprints/webhooks.py
index 9a0be5f..bb3668b 100644
--- a/litecord/blueprints/webhooks.py
+++ b/litecord/blueprints/webhooks.py
@@ -43,7 +43,7 @@ from litecord.snowflake import get_snowflake
from litecord.utils import async_map
from litecord.errors import WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest
-from litecord.blueprints.channel.messages import (
+from litecord.common.messages import (
msg_create_request,
msg_create_check_content,
msg_add_attachment,
@@ -499,9 +499,7 @@ async def execute_webhook(webhook_id: int, webhook_token):
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
# spawn embedder in the background, even when we're on a webhook.
- app.sched.spawn(
- process_url_embed(app.config, app.storage, app.dispatcher, app.session, payload)
- )
+ app.sched.spawn(process_url_embed(payload))
# we can assume its a guild text channel, so just call it
await msg_guild_text_mentions(payload, guild_id, mentions_everyone, mentions_here)
diff --git a/litecord/common/__init__.py b/litecord/common/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/litecord/blueprints/channel/dm_checks.py b/litecord/common/channels.py
similarity index 53%
rename from litecord/blueprints/channel/dm_checks.py
rename to litecord/common/channels.py
index e2cb195..dc85ef4 100644
--- a/litecord/blueprints/channel/dm_checks.py
+++ b/litecord/common/channels.py
@@ -16,15 +16,55 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see .
"""
-
from quart import current_app as app
-from litecord.errors import Forbidden
+
+from litecord.errors import ForbiddenDM
from litecord.enums import RelationshipType
-class ForbiddenDM(Forbidden):
- error_code = 50007
+async def channel_ack(
+ user_id: int, guild_id: int, channel_id: int, message_id: int = None
+):
+ """ACK a channel."""
+
+ if not message_id:
+ message_id = await app.storage.chan_last_message(channel_id)
+
+ await app.db.execute(
+ """
+ INSERT INTO user_read_state
+ (user_id, channel_id, last_message_id, mention_count)
+ VALUES
+ ($1, $2, $3, 0)
+ ON CONFLICT ON CONSTRAINT user_read_state_pkey
+ DO
+ UPDATE
+ SET last_message_id = $3, mention_count = 0
+ WHERE user_read_state.user_id = $1
+ AND user_read_state.channel_id = $2
+ """,
+ user_id,
+ channel_id,
+ message_id,
+ )
+
+ if guild_id:
+ await app.dispatcher.dispatch_user_guild(
+ user_id,
+ guild_id,
+ "MESSAGE_ACK",
+ {"message_id": str(message_id), "channel_id": str(channel_id)},
+ )
+ else:
+ # we don't use ChannelDispatcher here because since
+ # guild_id is None, all user devices are already subscribed
+ # to the given channel (a dm or a group dm)
+ await app.dispatcher.dispatch_user(
+ user_id,
+ "MESSAGE_ACK",
+ {"message_id": str(message_id), "channel_id": str(channel_id)},
+ )
async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
@@ -32,12 +72,12 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
# first step is checking if there is a block in any direction
blockrow = await app.db.fetchrow(
"""
- SELECT rel_type
- FROM relationships
- WHERE rel_type = $3
- AND user_id IN ($1, $2)
- AND peer_id IN ($1, $2)
- """,
+ SELECT rel_type
+ FROM relationships
+ WHERE rel_type = $3
+ AND user_id IN ($1, $2)
+ AND peer_id IN ($1, $2)
+ """,
user_id,
peer_id,
RelationshipType.BLOCK.value,
@@ -75,3 +115,21 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
# if after this filtering we don't have any more guilds, error
if not mutual_guilds:
raise ForbiddenDM()
+
+
+async def try_dm_state(user_id: int, dm_id: int):
+ """Try inserting the user into the dm state
+ for the given DM.
+
+ Does not do anything if the user is already
+ in the dm state.
+ """
+ await app.db.execute(
+ """
+ INSERT INTO dm_channel_state (user_id, dm_id)
+ VALUES ($1, $2)
+ ON CONFLICT DO NOTHING
+ """,
+ user_id,
+ dm_id,
+ )
diff --git a/litecord/common/guilds.py b/litecord/common/guilds.py
new file mode 100644
index 0000000..d7d1e25
--- /dev/null
+++ b/litecord/common/guilds.py
@@ -0,0 +1,234 @@
+"""
+
+Litecord
+Copyright (C) 2018-2019 Luna Mendes
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, version 3 of the License.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program. If not, see .
+
+"""
+
+from quart import current_app as app
+
+from ..snowflake import get_snowflake
+from ..permissions import get_role_perms
+from ..utils import dict_get, maybe_lazy_guild_dispatch
+from ..enums import ChannelType
+
+
+async def remove_member(guild_id: int, member_id: int):
+ """Do common tasks related to deleting a member from the guild,
+ such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
+
+ await app.db.execute(
+ """
+ DELETE FROM members
+ WHERE guild_id = $1 AND user_id = $2
+ """,
+ guild_id,
+ member_id,
+ )
+
+ await app.dispatcher.dispatch_user_guild(
+ member_id,
+ guild_id,
+ "GUILD_DELETE",
+ {"guild_id": str(guild_id), "unavailable": False},
+ )
+
+ await app.dispatcher.unsub("guild", guild_id, member_id)
+
+ await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
+
+ await app.dispatcher.dispatch_guild(
+ guild_id,
+ "GUILD_MEMBER_REMOVE",
+ {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
+ )
+
+
+async def remove_member_multi(guild_id: int, members: list):
+ """Remove multiple members."""
+ for member_id in members:
+ await remove_member(guild_id, member_id)
+
+
+async def create_role(guild_id, name: str, **kwargs):
+ """Create a role in a guild."""
+ new_role_id = get_snowflake()
+
+ everyone_perms = await get_role_perms(guild_id, guild_id)
+ default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
+
+ # update all roles so that we have space for pos 1, but without
+ # sending GUILD_ROLE_UPDATE for everyone
+ await app.db.execute(
+ """
+ UPDATE roles
+ SET
+ position = position + 1
+ WHERE guild_id = $1
+ AND NOT (position = 0)
+ """,
+ guild_id,
+ )
+
+ await app.db.execute(
+ """
+ INSERT INTO roles (id, guild_id, name, color,
+ hoist, position, permissions, managed, mentionable)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+ """,
+ new_role_id,
+ guild_id,
+ name,
+ dict_get(kwargs, "color", 0),
+ dict_get(kwargs, "hoist", False),
+ # always set ourselves on position 1
+ 1,
+ int(dict_get(kwargs, "permissions", default_perms)),
+ False,
+ dict_get(kwargs, "mentionable", False),
+ )
+
+ role = await app.storage.get_role(new_role_id, guild_id)
+
+ # we need to update the lazy guild handlers for the newly created group
+ await maybe_lazy_guild_dispatch(guild_id, "new_role", role)
+
+ await app.dispatcher.dispatch_guild(
+ guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
+ )
+
+ return role
+
+
+async def _specific_chan_create(channel_id, ctype, **kwargs):
+ if ctype == ChannelType.GUILD_TEXT:
+ await app.db.execute(
+ """
+ INSERT INTO guild_text_channels (id, topic)
+ VALUES ($1, $2)
+ """,
+ channel_id,
+ kwargs.get("topic", ""),
+ )
+ elif ctype == ChannelType.GUILD_VOICE:
+ await app.db.execute(
+ """
+ INSERT INTO guild_voice_channels (id, bitrate, user_limit)
+ VALUES ($1, $2, $3)
+ """,
+ channel_id,
+ kwargs.get("bitrate", 64),
+ kwargs.get("user_limit", 0),
+ )
+
+
+async def create_guild_channel(
+ guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
+):
+ """Create a channel in a guild."""
+ await app.db.execute(
+ """
+ INSERT INTO channels (id, channel_type)
+ VALUES ($1, $2)
+ """,
+ channel_id,
+ ctype.value,
+ )
+
+ # calc new pos
+ max_pos = await app.db.fetchval(
+ """
+ SELECT MAX(position)
+ FROM guild_channels
+ WHERE guild_id = $1
+ """,
+ guild_id,
+ )
+
+ # account for the first channel in a guild too
+ max_pos = max_pos or 0
+
+ # all channels go to guild_channels
+ await app.db.execute(
+ """
+ INSERT INTO guild_channels (id, guild_id, name, position)
+ VALUES ($1, $2, $3, $4)
+ """,
+ channel_id,
+ guild_id,
+ kwargs["name"],
+ max_pos + 1,
+ )
+
+ # the rest of sql magic is dependant on the channel
+ # we're creating (a text or voice or category),
+ # so we use this function.
+ await _specific_chan_create(channel_id, ctype, **kwargs)
+
+
+async def delete_guild(guild_id: int):
+ """Delete a single guild."""
+ await app.db.execute(
+ """
+ DELETE FROM guilds
+ WHERE guilds.id = $1
+ """,
+ guild_id,
+ )
+
+ # Discord's client expects IDs being string
+ await app.dispatcher.dispatch(
+ "guild",
+ guild_id,
+ "GUILD_DELETE",
+ {
+ "guild_id": str(guild_id),
+ "id": str(guild_id),
+ # 'unavailable': False,
+ },
+ )
+
+ # remove from the dispatcher so nobody
+ # becomes the little memer that tries to fuck up with
+ # everybody's gateway
+ await app.dispatcher.remove("guild", guild_id)
+
+
+async def create_guild_settings(guild_id: int, user_id: int):
+ """Create guild settings for the user
+ joining the guild."""
+
+ # new guild_settings are based off the currently
+ # set guild settings (for the guild)
+ m_notifs = await app.db.fetchval(
+ """
+ SELECT default_message_notifications
+ FROM guilds
+ WHERE id = $1
+ """,
+ guild_id,
+ )
+
+ await app.db.execute(
+ """
+ INSERT INTO guild_settings
+ (user_id, guild_id, message_notifications)
+ VALUES
+ ($1, $2, $3)
+ """,
+ user_id,
+ guild_id,
+ m_notifs,
+ )
diff --git a/litecord/common/messages.py b/litecord/common/messages.py
new file mode 100644
index 0000000..ab43767
--- /dev/null
+++ b/litecord/common/messages.py
@@ -0,0 +1,189 @@
+import json
+import logging
+
+from PIL import Image
+from quart import request, current_app as app
+
+from litecord.errors import BadRequest
+from ..snowflake import get_snowflake
+
+log = logging.getLogger(__name__)
+
+
+async def msg_create_request() -> tuple:
+ """Extract the json input and any file information
+ the client gave to us in the request.
+
+ This only applies to create message route.
+ """
+ form = await request.form
+ request_json = await request.get_json() or {}
+
+ # NOTE: embed isn't set on form data
+ json_from_form = {
+ "content": form.get("content", ""),
+ "nonce": form.get("nonce", "0"),
+ "tts": json.loads(form.get("tts", "false")),
+ }
+
+ payload_json = json.loads(form.get("payload_json", "{}"))
+
+ json_from_form.update(request_json)
+ json_from_form.update(payload_json)
+
+ files = await request.files
+
+ # we don't really care about the given fields on the files dict, so
+ # we only extract the values
+ return json_from_form, [v for k, v in files.items()]
+
+
+def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
+ """Check if there is actually any content being sent to us."""
+ has_content = bool(payload.get("content", ""))
+ has_files = len(files) > 0
+
+ embed_field = "embeds" if use_embeds else "embed"
+ has_embed = embed_field in payload and payload.get(embed_field) is not None
+
+ has_total_content = has_content or has_embed or has_files
+
+ if not has_total_content:
+ raise BadRequest("No content has been provided.")
+
+
+async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
+ """Add an attachment to a message.
+
+ Parameters
+ ----------
+ message_id: int
+ The ID of the message getting the attachment.
+ channel_id: int
+ The ID of the channel the message belongs to.
+
+ Exists because the attachment URL scheme contains
+ a channel id. The purpose is unknown, but we are
+ implementing Discord's behavior.
+ attachment_file: quart.FileStorage
+ quart FileStorage instance of the file.
+ """
+
+ attachment_id = get_snowflake()
+ filename = attachment_file.filename
+
+ # understand file info
+ mime = attachment_file.mimetype
+ is_image = mime.startswith("image/")
+
+ img_width, img_height = None, None
+
+ # extract file size
+ # TODO: this is probably inneficient
+ file_size = attachment_file.stream.getbuffer().nbytes
+
+ if is_image:
+ # open with pillow, extract image size
+ image = Image.open(attachment_file.stream)
+ img_width, img_height = image.size
+
+ # NOTE: DO NOT close the image, as closing the image will
+ # also close the stream.
+
+ # reset it to 0 for later usage
+ attachment_file.stream.seek(0)
+
+ await app.db.execute(
+ """
+ INSERT INTO attachments
+ (id, channel_id, message_id,
+ filename, filesize,
+ image, width, height)
+ VALUES
+ ($1, $2, $3, $4, $5, $6, $7, $8)
+ """,
+ attachment_id,
+ channel_id,
+ message_id,
+ filename,
+ file_size,
+ is_image,
+ img_width,
+ img_height,
+ )
+
+ ext = filename.split(".")[-1]
+
+ with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
+ attach_file.write(attachment_file.stream.read())
+
+ log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
+
+ return attachment_id
+
+
+async def msg_guild_text_mentions(
+ payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
+):
+ """Calculates mention data side-effects."""
+ channel_id = int(payload["channel_id"])
+
+ # calculate the user ids we'll bump the mention count for
+ uids = set()
+
+ # first is extracting user mentions
+ for mention in payload["mentions"]:
+ uids.add(int(mention["id"]))
+
+ # then role mentions
+ for role_mention in payload["mention_roles"]:
+ role_id = int(role_mention)
+ member_ids = await app.storage.get_role_members(role_id)
+
+ for member_id in member_ids:
+ uids.add(member_id)
+
+ # at-here only updates the state
+ # for the users that have a state
+ # in the channel.
+ if mentions_here:
+ uids = set()
+
+ await app.db.execute(
+ """
+ UPDATE user_read_state
+ SET mention_count = mention_count + 1
+ WHERE channel_id = $1
+ """,
+ channel_id,
+ )
+
+ # at-here updates the read state
+ # for all users, including the ones
+ # that might not have read permissions
+ # to the channel.
+ if mentions_everyone:
+ uids = set()
+
+ member_ids = await app.storage.get_member_ids(guild_id)
+
+ await app.db.executemany(
+ """
+ UPDATE user_read_state
+ SET mention_count = mention_count + 1
+ WHERE channel_id = $1 AND user_id = $2
+ """,
+ [(channel_id, uid) for uid in member_ids],
+ )
+
+ for user_id in uids:
+ await app.db.execute(
+ """
+ UPDATE user_read_state
+ SET mention_count = mention_count + 1
+ WHERE user_id = $1
+ AND channel_id = $2
+ """,
+ user_id,
+ channel_id,
+ )
diff --git a/litecord/common/users.py b/litecord/common/users.py
new file mode 100644
index 0000000..2b36082
--- /dev/null
+++ b/litecord/common/users.py
@@ -0,0 +1,273 @@
+"""
+
+Litecord
+Copyright (C) 2018-2019 Luna Mendes
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, version 3 of the License.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program. If not, see .
+
+"""
+
+import logging
+from random import randint
+from typing import Tuple, Optional
+
+from quart import current_app as app
+from asyncpg import UniqueViolationError
+
+from ..snowflake import get_snowflake
+from ..errors import BadRequest
+from ..auth import hash_data
+from ..utils import rand_hex
+
+log = logging.getLogger(__name__)
+
+
+async def mass_user_update(user_id):
+ """Dispatch USER_UPDATE in a mass way."""
+ # by using dispatch_with_filter
+ # we're guaranteeing all shards will get
+ # a USER_UPDATE once and not any others.
+
+ session_ids = []
+
+ public_user = await app.storage.get_user(user_id)
+ private_user = await app.storage.get_user(user_id, secure=True)
+
+ session_ids.extend(
+ await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
+ )
+
+ guild_ids = await app.user_storage.get_user_guilds(user_id)
+ friend_ids = await app.user_storage.get_friend_ids(user_id)
+
+ session_ids.extend(
+ await app.dispatcher.dispatch_many_filter_list(
+ "guild", guild_ids, session_ids, "USER_UPDATE", public_user
+ )
+ )
+
+ session_ids.extend(
+ await app.dispatcher.dispatch_many_filter_list(
+ "friend", friend_ids, session_ids, "USER_UPDATE", public_user
+ )
+ )
+
+ await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
+
+ return public_user, private_user
+
+
+async def check_username_usage(username: str):
+ """Raise an error if too many people are with the same username."""
+ same_username = await app.db.fetchval(
+ """
+ SELECT COUNT(*)
+ FROM users
+ WHERE username = $1
+ """,
+ username,
+ )
+
+ if same_username > 9000:
+ raise BadRequest(
+ "Too many people.",
+ {
+ "username": "Too many people used the same username. "
+ "Please choose another"
+ },
+ )
+
+
+def _raw_discrim() -> str:
+ discrim_number = randint(1, 9999)
+ return "%04d" % discrim_number
+
+
+async def roll_discrim(username: str) -> Optional[str]:
+ """Roll a discriminator for a DiscordTag.
+
+ Tries to generate one 10 times.
+
+ Calls check_username_usage.
+ """
+
+ # we shouldn't roll discrims for usernames
+ # that have been used too much.
+ await check_username_usage(username)
+
+ # max 10 times for a reroll
+ for _ in range(10):
+ # generate random discrim
+ discrim = _raw_discrim()
+
+ # check if anyone is with it
+ res = await app.db.fetchval(
+ """
+ SELECT id
+ FROM users
+ WHERE username = $1 AND discriminator = $2
+ """,
+ username,
+ discrim,
+ )
+
+ # if no user is found with the (username, discrim)
+ # pair, then this is unique! return it.
+ if res is None:
+ return discrim
+
+ return None
+
+
+async def create_user(username: str, email: str, password: str) -> Tuple[int, str]:
+ """Create a single user.
+
+ Generates a distriminator and other information. You can fetch the user
+ data back with :meth:`Storage.get_user`.
+ """
+ new_id = get_snowflake()
+ new_discrim = await roll_discrim(username)
+
+ if new_discrim is None:
+ raise BadRequest(
+ "Unable to register.",
+ {"username": "Too many people are with this username."},
+ )
+
+ pwd_hash = await hash_data(password)
+
+ try:
+ await app.db.execute(
+ """
+ INSERT INTO users
+ (id, email, username, discriminator, password_hash)
+ VALUES
+ ($1, $2, $3, $4, $5)
+ """,
+ new_id,
+ email,
+ username,
+ new_discrim,
+ pwd_hash,
+ )
+ except UniqueViolationError:
+ raise BadRequest("Email already used.")
+
+ return new_id, pwd_hash
+
+
+async def _del_from_table(db, table: str, user_id: int):
+ """Delete a row from a table."""
+ column = {
+ "channel_overwrites": "target_user",
+ "user_settings": "id",
+ "group_dm_members": "member_id",
+ }.get(table, "user_id")
+
+ res = await db.execute(
+ f"""
+ DELETE FROM {table}
+ WHERE {column} = $1
+ """,
+ user_id,
+ )
+
+ log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res)
+
+
+async def delete_user(user_id, *, mass_update: bool = True):
+ """Delete a user. Does not disconnect the user."""
+ db = app.db
+
+ new_username = f"Deleted User {rand_hex()}"
+
+ # by using a random hex in password_hash
+ # we break attempts at using the default '123' password hash
+ # to issue valid tokens for deleted users.
+
+ await db.execute(
+ """
+ UPDATE users
+ SET
+ username = $1,
+ email = NULL,
+ mfa_enabled = false,
+ verified = false,
+ avatar = NULL,
+ flags = 0,
+ premium_since = NULL,
+ phone = '',
+ password_hash = $2
+ WHERE
+ id = $3
+ """,
+ new_username,
+ rand_hex(32),
+ user_id,
+ )
+
+ # remove the user from various tables
+ await _del_from_table(db, "user_settings", user_id)
+ await _del_from_table(db, "user_payment_sources", user_id)
+ await _del_from_table(db, "user_subscriptions", user_id)
+ await _del_from_table(db, "user_payments", user_id)
+ await _del_from_table(db, "user_read_state", user_id)
+ await _del_from_table(db, "guild_settings", user_id)
+ await _del_from_table(db, "guild_settings_channel_overrides", user_id)
+
+ await db.execute(
+ """
+ DELETE FROM relationships
+ WHERE user_id = $1 OR peer_id = $1
+ """,
+ user_id,
+ )
+
+ # DMs are still maintained, but not the state.
+ await _del_from_table(db, "dm_channel_state", user_id)
+
+ # NOTE: we don't delete the group dms the user is an owner of...
+ # TODO: group dm owner reassign when the owner leaves a gdm
+ await _del_from_table(db, "group_dm_members", user_id)
+
+ await _del_from_table(db, "members", user_id)
+ await _del_from_table(db, "member_roles", user_id)
+ await _del_from_table(db, "channel_overwrites", user_id)
+
+ # after updating the user, we send USER_UPDATE so that all the other
+ # clients can refresh their caches on the now-deleted user
+ if mass_update:
+ await mass_user_update(user_id)
+
+
+async def user_disconnect(user_id: int):
+ """Disconnects the given user's devices."""
+ # after removing the user from all tables, we need to force
+ # all known user states to reconnect, causing the user to not
+ # be online anymore.
+ user_states = app.state_manager.user_states(user_id)
+
+ for state in user_states:
+ # make it unable to resume
+ app.state_manager.remove(state)
+
+ if not state.ws:
+ continue
+
+ # force a close, 4000 should make the client reconnect.
+ await state.ws.ws.close(4000)
+
+ # force everyone to see the user as offline
+ await app.presence.dispatch_pres(
+ user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
+ )
diff --git a/litecord/embed/messages.py b/litecord/embed/messages.py
index ce23ea4..bc240ea 100644
--- a/litecord/embed/messages.py
+++ b/litecord/embed/messages.py
@@ -22,6 +22,7 @@ import asyncio
import urllib.parse
from pathlib import Path
+from quart import current_app as app
from logbook import Logger
from litecord.embed.sanitizer import proxify, fetch_metadata, fetch_embed
@@ -33,10 +34,10 @@ log = Logger(__name__)
MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm")
-async def insert_media_meta(url, config, session):
+async def insert_media_meta(url):
"""Insert media metadata as an embed."""
- img_proxy_url = proxify(url, config=config)
- meta = await fetch_metadata(url, config=config, session=session)
+ img_proxy_url = proxify(url)
+ meta = await fetch_metadata(url)
if meta is None:
return
@@ -56,19 +57,19 @@ async def insert_media_meta(url, config, session):
}
-async def msg_update_embeds(payload, new_embeds, storage, dispatcher):
+async def msg_update_embeds(payload, new_embeds):
"""Update the message with the given embeds and dispatch a MESSAGE_UPDATE
to users."""
message_id = int(payload["id"])
channel_id = int(payload["channel_id"])
- await storage.execute_with_json(
+ await app.storage.execute_with_json(
"""
- UPDATE messages
- SET embeds = $1
- WHERE messages.id = $2
- """,
+ UPDATE messages
+ SET embeds = $1
+ WHERE messages.id = $2
+ """,
new_embeds,
message_id,
)
@@ -85,7 +86,9 @@ async def msg_update_embeds(payload, new_embeds, storage, dispatcher):
if "flags" in payload:
update_payload["flags"] = payload["flags"]
- await dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", update_payload)
+ await app.dispatcher.dispatch(
+ "channel", channel_id, "MESSAGE_UPDATE", update_payload
+ )
def is_media_url(url) -> bool:
@@ -102,15 +105,13 @@ def is_media_url(url) -> bool:
return extension in MEDIA_EXTENSIONS
-async def insert_mp_embed(parsed, config, session):
+async def insert_mp_embed(parsed):
"""Insert mediaproxy embed."""
- embed = await fetch_embed(parsed, config=config, session=session)
+ embed = await fetch_embed(parsed)
return embed
-async def process_url_embed(
- config, storage, dispatcher, session, payload: dict, *, delay=0
-):
+async def process_url_embed(payload: dict, *, delay=0):
"""Process URLs in a message and generate embeds based on that."""
await asyncio.sleep(delay)
@@ -145,9 +146,9 @@ async def process_url_embed(
url = EmbedURL(url)
if is_media_url(url):
- embed = await insert_media_meta(url, config, session)
+ embed = await insert_media_meta(url)
else:
- embed = await insert_mp_embed(url, config, session)
+ embed = await insert_mp_embed(url)
if not embed:
continue
@@ -160,4 +161,4 @@ async def process_url_embed(
log.debug("made {} embeds for mid {}", len(new_embeds), message_id)
- await msg_update_embeds(payload, new_embeds, storage, dispatcher)
+ await msg_update_embeds(payload, new_embeds)
diff --git a/litecord/embed/sanitizer.py b/litecord/embed/sanitizer.py
index b14e436..14e8977 100644
--- a/litecord/embed/sanitizer.py
+++ b/litecord/embed/sanitizer.py
@@ -75,35 +75,24 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]):
return False
-def _mk_cfg_sess(config, session) -> tuple:
- """Return a tuple of (config, session)."""
- if config is None:
- config = app.config
-
- if session is None:
- session = app.session
-
- return config, session
-
-
-def _md_base(config) -> Optional[tuple]:
+def _md_base() -> Optional[tuple]:
"""Return the protocol and base url for the mediaproxy."""
- md_base_url = config["MEDIA_PROXY"]
+ md_base_url = app.config["MEDIA_PROXY"]
if md_base_url is None:
return None
- proto = "https" if config["IS_SSL"] else "http"
+ proto = "https" if app.config["IS_SSL"] else "http"
return proto, md_base_url
-def make_md_req_url(config, scope: str, url):
- """Make a mediaproxy request URL given the config, scope, and the url
+def make_md_req_url(scope: str, url):
+ """Make a mediaproxy request URL given the scope and the url
to be proxied.
When MEDIA_PROXY is None, however, returns the original URL.
"""
- base = _md_base(config)
+ base = _md_base()
if base is None:
return url.url if isinstance(url, EmbedURL) else url
@@ -111,38 +100,25 @@ def make_md_req_url(config, scope: str, url):
return f"{proto}://{base_url}/{scope}/{url.to_md_path}"
-def proxify(url, *, config=None) -> str:
+def proxify(url) -> str:
"""Return a mediaproxy url for the given EmbedURL. Returns an
/img/ scope."""
- config, _sess = _mk_cfg_sess(config, False)
-
if isinstance(url, str):
url = EmbedURL(url)
- return make_md_req_url(config, "img", url)
+ return make_md_req_url("img", url)
async def _md_client_req(
- config, session, scope: str, url, *, ret_resp=False
+ scope: str, url, *, ret_resp=False
) -> Optional[Union[Tuple, Dict]]:
"""Makes a request to the mediaproxy.
This has common code between all the main mediaproxy request functions
to decrease code repetition.
- Note that config and session exist because there are cases where the app
- isn't retrievable (as those functions usually run in background tasks,
- not in the app itself).
-
Parameters
----------
- config: dict-like
- the app configuration, if None, this will get the global one from the
- app instance.
- session: aiohttp client session
- the aiohttp ClientSession instance to use, if None, this will get
- the global one from the app.
-
scope: str
the scope of your request. one of 'meta', 'img', or 'embed' are
available for the mediaproxy's API.
@@ -155,14 +131,12 @@ async def _md_client_req(
the raw bytes of the response, but by the time this function is
returned, the response object is invalid and the socket is closed
"""
- config, session = _mk_cfg_sess(config, session)
-
if not isinstance(url, EmbedURL):
url = EmbedURL(url)
- request_url = make_md_req_url(config, scope, url)
+ request_url = make_md_req_url(scope, url)
- async with session.get(request_url) as resp:
+ async with app.session.get(request_url) as resp:
if resp.status == 200:
if ret_resp:
return resp, await resp.read()
@@ -174,18 +148,18 @@ async def _md_client_req(
return None
-async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]:
+async def fetch_metadata(url) -> Optional[Dict]:
"""Fetch metadata for a url (image width, mime, etc)."""
- return await _md_client_req(config, session, "meta", url)
+ return await _md_client_req("meta", url)
-async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
+async def fetch_raw_img(url) -> Optional[tuple]:
"""Fetch raw data for a url (the bytes given off, used to proxy images).
Returns a tuple containing the response object and the raw bytes given by
the website.
"""
- tup = await _md_client_req(config, session, "img", url, ret_resp=True)
+ tup = await _md_client_req("img", url, ret_resp=True)
if not tup:
return None
@@ -193,13 +167,13 @@ async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
return tup
-async def fetch_embed(url, *, config=None, session=None) -> Dict[str, Any]:
+async def fetch_embed(url) -> Dict[str, Any]:
"""Fetch an embed for a given webpage (an automatically generated embed
by the mediaproxy, look over the project on how it generates embeds).
Returns a discord embed object.
"""
- return await _md_client_req(config, session, "embed", url)
+ return await _md_client_req("embed", url)
async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:
diff --git a/litecord/enums.py b/litecord/enums.py
index ce68e18..7c5b143 100644
--- a/litecord/enums.py
+++ b/litecord/enums.py
@@ -54,27 +54,21 @@ class Flags:
"""
def __init_subclass__(cls, **_kwargs):
- attrs = inspect.getmembers(cls, lambda x: not inspect.isroutine(x))
+ # get only the members that represent a field
+ cls._attrs = inspect.getmembers(cls, lambda x: isinstance(x, int))
- def _make_int(value):
- res = Flags()
+ @classmethod
+ def from_int(cls, value: int):
+ """Create a Flags from a given int value."""
+ res = Flags()
+ setattr(res, "value", value)
- setattr(res, "value", value)
+ for attr, val in cls._attrs:
+ has_attr = (value & val) == val
+ # set attributes dynamically
+ setattr(res, f"is_{attr}", has_attr)
- for attr, val in attrs:
- # get only the ones that represent a field in the
- # number's bits
- if not isinstance(val, int):
- continue
-
- has_attr = (value & val) == val
-
- # set each attribute
- setattr(res, f"is_{attr}", has_attr)
-
- return res
-
- cls.from_int = _make_int
+ return res
class ChannelType(EasyEnum):
diff --git a/litecord/errors.py b/litecord/errors.py
index d2a72c2..1789a5e 100644
--- a/litecord/errors.py
+++ b/litecord/errors.py
@@ -116,6 +116,10 @@ class Forbidden(LitecordError):
status_code = 403
+class ForbiddenDM(Forbidden):
+ error_code = 50007
+
+
class NotFound(LitecordError):
status_code = 404
diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py
index 323b2e4..14d01f2 100644
--- a/litecord/gateway/websocket.py
+++ b/litecord/gateway/websocket.py
@@ -21,7 +21,7 @@ import collections
import asyncio
import pprint
import zlib
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Iterable
from random import randint
import websockets
@@ -56,41 +56,15 @@ WebsocketProperties = collections.namedtuple(
"WebsocketProperties", "v encoding compress zctx zsctx tasks"
)
-WebsocketObjects = collections.namedtuple(
- "WebsocketObjects",
- (
- "db",
- "state_manager",
- "storage",
- "loop",
- "dispatcher",
- "presence",
- "ratelimiter",
- "user_storage",
- "voice",
- ),
-)
-
class GatewayWebsocket:
"""Main gateway websocket logic."""
def __init__(self, ws, app, **kwargs):
- self.ext = WebsocketObjects(
- app.db,
- app.state_manager,
- app.storage,
- app.loop,
- app.dispatcher,
- app.presence,
- app.ratelimiter,
- app.user_storage,
- app.voice,
- )
-
- self.storage = self.ext.storage
- self.user_storage = self.ext.user_storage
- self.presence = self.ext.presence
+ self.app = app
+ self.storage = app.storage
+ self.user_storage = app.user_storage
+ self.presence = app.presence
self.ws = ws
self.wsp = WebsocketProperties(
@@ -225,7 +199,7 @@ class GatewayWebsocket:
await self.send({"op": op_code, "d": data, "t": None, "s": None})
def _check_ratelimit(self, key: str, ratelimit_key):
- ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}")
+ ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}")
bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit()
@@ -245,7 +219,7 @@ class GatewayWebsocket:
if task:
task.cancel()
- self.wsp.tasks["heartbeat"] = self.ext.loop.create_task(
+ self.wsp.tasks["heartbeat"] = self.app.loop.create_task(
task_wrapper("hb wait", self._hb_wait(interval))
)
@@ -330,7 +304,7 @@ class GatewayWebsocket:
if r["type"] == RelationshipType.FRIEND.value
]
- friend_presences = await self.ext.presence.friend_presences(friend_ids)
+ friend_presences = await self.app.presence.friend_presences(friend_ids)
settings = await self.user_storage.get_user_settings(user_id)
return {
@@ -377,14 +351,14 @@ class GatewayWebsocket:
await self.dispatch("READY", {**base_ready, **user_ready})
# async dispatch of guilds
- self.ext.loop.create_task(self._guild_dispatch(guilds))
+ self.app.loop.create_task(self._guild_dispatch(guilds))
async def _check_shards(self, shard, user_id):
"""Check if the given `shard` value in IDENTIFY has good enough values.
"""
current_shard, shard_count = shard
- guilds = await self.ext.db.fetchval(
+ guilds = await self.app.db.fetchval(
"""
SELECT COUNT(*)
FROM members
@@ -460,7 +434,7 @@ class GatewayWebsocket:
("channel", gdm_ids),
]
- await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
+ await self.app.dispatcher.mass_sub(user_id, channels_to_sub)
if not self.state.bot:
# subscribe to all friends
@@ -468,7 +442,7 @@ class GatewayWebsocket:
# when they come online)
friend_ids = await self.user_storage.get_friend_ids(user_id)
log.info("subscribing to {} friends", len(friend_ids))
- await self.ext.dispatcher.sub_many("friend", user_id, friend_ids)
+ await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
async def update_status(self, status: dict):
"""Update the status of the current websocket connection."""
@@ -520,7 +494,7 @@ class GatewayWebsocket:
f'Updating presence status={status["status"]} for '
f"uid={self.state.user_id}"
)
- await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence)
+ await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence)
async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets."""
@@ -558,13 +532,13 @@ class GatewayWebsocket:
presence = data.get("presence")
try:
- user_id = await raw_token_check(token, self.ext.db)
+ user_id = await raw_token_check(token, self.app.db)
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, "Authentication failed")
await self._connect_ratelimit(user_id)
- bot = await self.ext.db.fetchval(
+ bot = await self.app.db.fetchval(
"""
SELECT bot FROM users
WHERE id = $1
@@ -587,7 +561,7 @@ class GatewayWebsocket:
)
# link the state to the user
- self.ext.state_manager.insert(self.state)
+ self.app.state_manager.insert(self.state)
await self.update_status(presence)
await self.subscribe_all(data.get("guild_subscriptions", True))
@@ -631,12 +605,12 @@ class GatewayWebsocket:
# if its null and null, disconnect the user from any voice
# TODO: maybe just leave from DMs? idk...
if channel_id is None and guild_id is None:
- return await self.ext.voice.leave_all(self.state.user_id)
+ return await self.app.voice.leave_all(self.state.user_id)
# if guild is not none but channel is, we are leaving
# a guild's channel
if channel_id is None:
- return await self.ext.voice.leave(guild_id, self.state.user_id)
+ return await self.app.voice.leave(guild_id, self.state.user_id)
# fetch an existing state given user and guild OR user and channel
chan_type = ChannelType(await self.storage.get_chan_type(channel_id))
@@ -659,10 +633,10 @@ class GatewayWebsocket:
# this state id format takes care of that.
voice_key = (self.state.user_id, state_id2)
- voice_state = await self.ext.voice.get_state(voice_key)
+ voice_state = await self.app.voice.get_state(voice_key)
if voice_state is None:
- return await self.ext.voice.create_state(voice_key, data)
+ return await self.app.voice.create_state(voice_key, data)
same_guild = guild_id == voice_state.guild_id
same_channel = channel_id == voice_state.channel_id
@@ -670,10 +644,10 @@ class GatewayWebsocket:
prop = await self._vsu_get_prop(voice_state, data)
if same_guild and same_channel:
- return await self.ext.voice.update_state(voice_state, prop)
+ return await self.app.voice.update_state(voice_state, prop)
if same_guild and not same_channel:
- return await self.ext.voice.move_state(voice_state, channel_id)
+ return await self.app.voice.move_state(voice_state, channel_id)
async def _handle_5(self, payload: Dict[str, Any]):
"""Handle OP 5 Voice Server Ping.
@@ -698,9 +672,9 @@ class GatewayWebsocket:
# since the state will be removed from
# the manager, it will become unreachable
# when trying to resume.
- self.ext.state_manager.remove(self.state)
+ self.app.state_manager.remove(self.state)
- async def _resume(self, replay_seqs: iter):
+ async def _resume(self, replay_seqs: Iterable):
presences = []
try:
@@ -740,12 +714,12 @@ class GatewayWebsocket:
raise DecodeError("Invalid resume payload")
try:
- user_id = await raw_token_check(token, self.ext.db)
+ user_id = await raw_token_check(token, self.app.db)
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, "Invalid token")
try:
- state = self.ext.state_manager.fetch(user_id, sess_id)
+ state = self.app.state_manager.fetch(user_id, sess_id)
except KeyError:
return await self.invalidate_session(False)
@@ -948,7 +922,7 @@ class GatewayWebsocket:
log.debug("lazy request: members: {}", data.get("members", []))
# make shard query
- lazy_guilds = self.ext.dispatcher.backends["lazy_guild"]
+ lazy_guilds = self.app.dispatcher.backends["lazy_guild"]
for chan_id, ranges in data.get("channels", {}).items():
chan_id = int(chan_id)
@@ -992,10 +966,10 @@ class GatewayWebsocket:
# close anyone trying to login while the
# server is shutting down
- if self.ext.state_manager.closed:
+ if self.app.state_manager.closed:
raise WebsocketClose(4000, "state manager closed")
- if not self.ext.state_manager.accept_new:
+ if not self.app.state_manager.accept_new:
raise WebsocketClose(4000, "state manager closed for new")
while True:
@@ -1016,7 +990,7 @@ class GatewayWebsocket:
task.cancel()
if self.state:
- self.ext.state_manager.remove(self.state)
+ self.app.state_manager.remove(self.state)
self.state.ws = None
self.state = None
@@ -1031,14 +1005,14 @@ class GatewayWebsocket:
# TODO: account for sharding
# this only updates status to offline once
# ALL shards have come offline
- states = self.ext.state_manager.user_states(user_id)
+ states = self.app.state_manager.user_states(user_id)
with_ws = [s for s in states if s.ws]
# there arent any other states with websocket
if not with_ws:
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
- await self.ext.presence.dispatch_pres(user_id, offline)
+ await self.app.presence.dispatch_pres(user_id, offline)
async def run(self):
"""Wrap :meth:`listen_messages` inside
diff --git a/litecord/jobs.py b/litecord/jobs.py
index 4ad3852..4f08271 100644
--- a/litecord/jobs.py
+++ b/litecord/jobs.py
@@ -18,7 +18,9 @@ along with this program. If not, see .
"""
import asyncio
+from typing import Any
+from quart.ctx import copy_current_app_context
from logbook import Logger
log = Logger(__name__)
@@ -47,9 +49,14 @@ class JobManager:
def spawn(self, coro):
"""Spawn a given future or coroutine in the background."""
- task = self.loop.create_task(self._wrapper(coro))
+ @copy_current_app_context
+ async def _ctx_wrapper_bg() -> Any:
+ return await coro
+
+ task = self.loop.create_task(self._wrapper(_ctx_wrapper_bg()))
self.jobs.append(task)
+ return task
def close(self):
"""Close the job manager, cancelling all existing jobs.
diff --git a/litecord/utils.py b/litecord/utils.py
index a0f5587..91e7d8a 100644
--- a/litecord/utils.py
+++ b/litecord/utils.py
@@ -19,11 +19,14 @@ along with this program. If not, see .
import asyncio
import json
+import secrets
from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
from logbook import Logger
from quart.json import JSONEncoder
-from quart import current_app as app
+from quart import current_app as app, request
+
+from .errors import BadRequest
log = Logger(__name__)
@@ -233,3 +236,57 @@ def maybe_int(val: Any) -> Union[int, Any]:
return int(val)
except (ValueError, TypeError):
return val
+
+
+async def maybe_lazy_guild_dispatch(
+ guild_id: int, event: str, role, force: bool = False
+):
+ # sometimes we want to dispatch an event
+ # even if the role isn't hoisted
+
+ # an example of such a case is when a role loses
+ # its hoist status.
+
+ # check if is a dict first because role_delete
+ # only receives the role id.
+ if isinstance(role, dict) and not role["hoist"] and not force:
+ return
+
+ await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
+
+
+def extract_limit(request_, default: int = 50, max_val: int = 100):
+ """Extract a limit kwarg."""
+ try:
+ limit = int(request_.args.get("limit", default))
+
+ if limit not in range(0, max_val + 1):
+ raise ValueError()
+ except (TypeError, ValueError):
+ raise BadRequest("limit not int")
+
+ return limit
+
+
+def query_tuple_from_args(args: dict, limit: int) -> tuple:
+ """Extract a 2-tuple out of request arguments."""
+ before, after = None, None
+
+ if "around" in request.args:
+ average = int(limit / 2)
+ around = int(args["around"])
+
+ after = around - average
+ before = around + average
+
+ elif "before" in args:
+ before = int(args["before"])
+ elif "after" in args:
+ before = int(args["after"])
+
+ return before, after
+
+
+def rand_hex(length: int = 8) -> str:
+ """Generate random hex characters."""
+ return secrets.token_hex(length)[:length]
diff --git a/manage/cmd/users.py b/manage/cmd/users.py
index c20e181..50818e8 100644
--- a/manage/cmd/users.py
+++ b/manage/cmd/users.py
@@ -17,9 +17,8 @@ along with this program. If not, see .
"""
-from litecord.auth import create_user
+from litecord.common.users import create_user, delete_user
from litecord.blueprints.auth import make_token
-from litecord.blueprints.users import delete_user
from litecord.enums import UserFlags
diff --git a/run.py b/run.py
index d15dbc1..55aeb37 100644
--- a/run.py
+++ b/run.py
@@ -337,9 +337,9 @@ async def api_index(app_):
async def post_app_start(app_):
# we'll need to start a billing job
- app_.sched.spawn(payment_job(app_))
+ app_.sched.spawn(payment_job())
app_.sched.spawn(api_index(app_))
- app_.sched.spawn(guild_region_check(app_))
+ app_.sched.spawn(guild_region_check())
def start_websocket(host, port, ws_handler) -> asyncio.Future:
diff --git a/tests/conftest.py b/tests/conftest.py
index 2a48d5c..f0444c9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,10 +30,9 @@ from tests.common import email, TestClient
from run import app as main_app, set_blueprints
-from litecord.auth import create_user
+from litecord.common.users import create_user, delete_user
from litecord.enums import UserFlags
from litecord.blueprints.auth import make_token
-from litecord.blueprints.users import delete_user
@pytest.fixture(name="app")
diff --git a/tests/test_admin_api/test_guilds.py b/tests/test_admin_api/test_guilds.py
index b6619e7..6ca61cb 100644
--- a/tests/test_admin_api/test_guilds.py
+++ b/tests/test_admin_api/test_guilds.py
@@ -54,6 +54,11 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
return rjson
+async def _delete_guild(test_cli, guild_id: int):
+ async with test_cli.app.app_context():
+ await delete_guild(int(guild_id))
+
+
@pytest.mark.asyncio
async def test_guild_fetch(test_cli_staff):
"""Test the creation and fetching of a guild via the Admin API."""
@@ -63,7 +68,7 @@ async def test_guild_fetch(test_cli_staff):
try:
await _fetch_guild(test_cli_staff, guild_id)
finally:
- await delete_guild(int(guild_id), app_=test_cli_staff.app)
+ await _delete_guild(test_cli_staff, int(guild_id))
@pytest.mark.asyncio
@@ -91,7 +96,7 @@ async def test_guild_update(test_cli_staff):
rjson = await _fetch_guild(test_cli_staff, guild_id)
assert rjson["unavailable"]
finally:
- await delete_guild(int(guild_id), app_=test_cli_staff.app)
+ await _delete_guild(test_cli_staff, int(guild_id))
@pytest.mark.asyncio
@@ -113,4 +118,4 @@ async def test_guild_delete(test_cli_staff):
assert rjson["error"]
assert rjson["code"] == GuildNotFound.error_code
finally:
- await delete_guild(int(guild_id), app_=test_cli_staff.app)
+ await _delete_guild(test_cli_staff, int(guild_id))