From 83a1c1ae294ef2ef19ab8eab391d3ce4f29d7ebb Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 07:27:50 -0300 Subject: [PATCH] black fmt pass --- config.ci.py | 25 +- config.example.py | 29 +- litecord/__init__.py | 1 - litecord/admin_schemas.py | 41 +- litecord/auth.py | 100 ++- litecord/blueprints/__init__.py | 22 +- litecord/blueprints/admin_api/__init__.py | 2 +- litecord/blueprints/admin_api/features.py | 42 +- litecord/blueprints/admin_api/guilds.py | 21 +- .../blueprints/admin_api/instance_invites.py | 57 +- litecord/blueprints/admin_api/users.py | 78 +- litecord/blueprints/admin_api/voice.py | 64 +- litecord/blueprints/attachments.py | 55 +- litecord/blueprints/auth.py | 116 ++- litecord/blueprints/channel/__init__.py | 2 +- litecord/blueprints/channel/dm_checks.py | 13 +- litecord/blueprints/channel/messages.py | 336 ++++---- litecord/blueprints/channel/pins.py | 100 ++- litecord/blueprints/channel/reactions.py | 134 ++-- litecord/blueprints/channels.py | 511 +++++++----- litecord/blueprints/checks.py | 54 +- litecord/blueprints/dm_channels.py | 139 ++-- litecord/blueprints/dms.py | 49 +- litecord/blueprints/gateway.py | 40 +- litecord/blueprints/guild/__init__.py | 3 +- litecord/blueprints/guild/channels.py | 107 +-- litecord/blueprints/guild/emoji.py | 100 ++- litecord/blueprints/guild/members.py | 157 ++-- litecord/blueprints/guild/mod.py | 148 ++-- litecord/blueprints/guild/roles.py | 166 ++-- litecord/blueprints/guilds.py | 368 +++++---- litecord/blueprints/icons.py | 46 +- litecord/blueprints/invites.py | 202 +++-- litecord/blueprints/nodeinfo.py | 92 +-- litecord/blueprints/relationships.py | 265 ++++--- litecord/blueprints/science.py | 14 +- litecord/blueprints/static.py | 19 +- litecord/blueprints/user/__init__.py | 2 +- litecord/blueprints/user/billing.py | 266 ++++--- litecord/blueprints/user/billing_job.py | 67 +- litecord/blueprints/user/fake_store.py | 22 +- litecord/blueprints/user/settings.py | 71 +- litecord/blueprints/users.py | 373 +++++---- litecord/blueprints/voice.py | 30 +- litecord/blueprints/webhooks.py | 300 ++++--- litecord/dispatcher.py | 88 +- litecord/embed/__init__.py | 2 +- litecord/embed/messages.py | 72 +- litecord/embed/sanitizer.py | 60 +- litecord/embed/schemas.py | 106 +-- litecord/enums.py | 56 +- litecord/errors.py | 110 +-- litecord/gateway/encoding.py | 4 +- litecord/gateway/gateway.py | 29 +- litecord/gateway/opcodes.py | 2 + litecord/gateway/state.py | 14 +- litecord/gateway/state_manager.py | 37 +- litecord/gateway/utils.py | 2 + litecord/gateway/websocket.py | 497 ++++++------ litecord/guild_memory_store.py | 6 +- litecord/images.py | 235 +++--- litecord/jobs.py | 8 +- litecord/permissions.py | 153 ++-- litecord/presence.py | 165 ++-- litecord/pubsub/__init__.py | 11 +- litecord/pubsub/channel.py | 39 +- litecord/pubsub/dispatcher.py | 6 +- litecord/pubsub/friend.py | 13 +- litecord/pubsub/guild.py | 56 +- litecord/pubsub/lazy_guild.py | 590 +++++++------- litecord/pubsub/member.py | 3 +- litecord/pubsub/user.py | 14 +- litecord/ratelimits/bucket.py | 18 +- litecord/ratelimits/handler.py | 22 +- litecord/ratelimits/main.py | 35 +- litecord/schemas.py | 724 +++++++---------- litecord/snowflake.py | 12 +- litecord/storage.py | 749 +++++++++++------- litecord/system_messages.py | 61 +- litecord/types.py | 5 +- litecord/user_storage.py | 223 ++++-- litecord/utils.py | 80 +- litecord/voice/lvsp_conn.py | 73 +- litecord/voice/lvsp_manager.py | 43 +- litecord/voice/lvsp_opcodes.py | 16 +- litecord/voice/manager.py | 51 +- litecord/voice/state.py | 5 +- manage.py | 2 +- manage/__init__.py | 1 - manage/cmd/invites.py | 71 +- manage/cmd/migration/__init__.py | 2 +- manage/cmd/migration/command.py | 87 +- manage/cmd/users.py | 95 ++- manage/main.py | 7 +- run.py | 236 +++--- setup.py | 12 +- tests/common.py | 18 +- tests/conftest.py | 49 +- tests/test_admin_api/test_guilds.py | 43 +- tests/test_admin_api/test_instance_invites.py | 17 +- tests/test_admin_api/test_users.py | 58 +- tests/test_embeds.py | 77 +- tests/test_gateway.py | 23 +- tests/test_guild.py | 40 +- tests/test_main.py | 2 +- tests/test_no_tracking.py | 10 +- tests/test_ratelimits.py | 11 +- tests/test_user.py | 67 +- tests/test_websocket.py | 78 +- 109 files changed, 5575 insertions(+), 4775 deletions(-) diff --git a/config.ci.py b/config.ci.py index 24d6046..911085a 100644 --- a/config.ci.py +++ b/config.ci.py @@ -17,13 +17,14 @@ along with this program. If not, see . """ -MODE = 'CI' +MODE = "CI" class Config: """Default configuration values for litecord.""" - MAIN_URL = 'localhost:1' - NAME = 'gitlab ci' + + MAIN_URL = "localhost:1" + NAME = "gitlab ci" # Enable debug logging? DEBUG = False @@ -37,11 +38,11 @@ class Config: # Set this url to somewhere *your users* # will hit the websocket. # e.g 'gateway.example.com' for reverse proxies. - WEBSOCKET_URL = 'localhost:5001' + WEBSOCKET_URL = "localhost:5001" # Where to host the websocket? # (a local address the server will bind to) - WS_HOST = 'localhost' + WS_HOST = "localhost" WS_PORT = 5001 # Postgres credentials @@ -51,10 +52,10 @@ class Config: class Development(Config): DEBUG = True POSTGRES = { - 'host': 'localhost', - 'user': 'litecord', - 'password': '123', - 'database': 'litecord', + "host": "localhost", + "user": "litecord", + "password": "123", + "database": "litecord", } @@ -66,8 +67,4 @@ class Production(Config): class CI(Config): DEBUG = True - POSTGRES = { - 'host': 'postgres', - 'user': 'postgres', - 'password': '' - } + POSTGRES = {"host": "postgres", "user": "postgres", "password": ""} diff --git a/config.example.py b/config.example.py index 2e225e4..999add7 100644 --- a/config.example.py +++ b/config.example.py @@ -17,16 +17,17 @@ along with this program. If not, see . """ -MODE = 'Development' +MODE = "Development" class Config: """Default configuration values for litecord.""" + #: Main URL of the instance. - MAIN_URL = 'discordapp.io' + MAIN_URL = "discordapp.io" #: Name of the instance - NAME = 'Litecord/Nya' + NAME = "Litecord/Nya" #: Enable debug logging? DEBUG = False @@ -45,17 +46,17 @@ class Config: # Set this url to somewhere *your users* # will hit the websocket. # e.g 'gateway.example.com' for reverse proxies. - WEBSOCKET_URL = 'localhost:5001' + WEBSOCKET_URL = "localhost:5001" #: Where to host the websocket? # (a local address the server will bind to) - WS_HOST = '0.0.0.0' + WS_HOST = "0.0.0.0" WS_PORT = 5001 #: Mediaproxy URL on the internet # mediaproxy is made to prevent client IPs being leaked. # None is a valid value if you don't want to deploy mediaproxy. - MEDIA_PROXY = 'localhost:5002' + MEDIA_PROXY = "localhost:5002" #: Postgres credentials POSTGRES = {} @@ -65,10 +66,10 @@ class Development(Config): DEBUG = True POSTGRES = { - 'host': 'localhost', - 'user': 'litecord', - 'password': '123', - 'database': 'litecord', + "host": "localhost", + "user": "litecord", + "password": "123", + "database": "litecord", } @@ -77,8 +78,8 @@ class Production(Config): IS_SSL = True POSTGRES = { - 'host': 'some_production_postgres', - 'user': 'some_production_user', - 'password': 'some_production_password', - 'database': 'litecord_or_anything_else_really', + "host": "some_production_postgres", + "user": "some_production_user", + "password": "some_production_password", + "database": "litecord_or_anything_else_really", } diff --git a/litecord/__init__.py b/litecord/__init__.py index ce49370..d21f555 100644 --- a/litecord/__init__.py +++ b/litecord/__init__.py @@ -16,4 +16,3 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ - diff --git a/litecord/admin_schemas.py b/litecord/admin_schemas.py index 22d6bdb..900f734 100644 --- a/litecord/admin_schemas.py +++ b/litecord/admin_schemas.py @@ -19,42 +19,33 @@ along with this program. If not, see . from litecord.enums import Feature, UserFlags -VOICE_SERVER = { - 'hostname': {'type': 'string', 'maxlength': 255, 'required': True} -} +VOICE_SERVER = {"hostname": {"type": "string", "maxlength": 255, "required": True}} VOICE_REGION = { - 'id': {'type': 'string', 'maxlength': 255, 'required': True}, - 'name': {'type': 'string', 'maxlength': 255, 'required': True}, - - 'vip': {'type': 'boolean', 'default': False}, - 'deprecated': {'type': 'boolean', 'default': False}, - 'custom': {'type': 'boolean', 'default': False}, + "id": {"type": "string", "maxlength": 255, "required": True}, + "name": {"type": "string", "maxlength": 255, "required": True}, + "vip": {"type": "boolean", "default": False}, + "deprecated": {"type": "boolean", "default": False}, + "custom": {"type": "boolean", "default": False}, } FEATURES = { - 'features': { - 'type': 'list', 'required': True, - + "features": { + "type": "list", + "required": True, # using Feature doesn't seem to work with a "not callable" error. - 'schema': {'coerce': lambda x: Feature(x)} + "schema": {"coerce": lambda x: Feature(x)}, } } USER_CREATE = { - 'username': {'type': 'username', 'required': True}, - 'email': {'type': 'email', 'required': True}, - 'password': {'type': 'string', 'minlength': 5, 'required': True}, + "username": {"type": "username", "required": True}, + "email": {"type": "email", "required": True}, + "password": {"type": "string", "minlength": 5, "required": True}, } -INSTANCE_INVITE = { - 'max_uses': {'type': 'integer', 'required': True} -} +INSTANCE_INVITE = {"max_uses": {"type": "integer", "required": True}} -GUILD_UPDATE = { - 'unavailable': {'type': 'boolean', 'required': False} -} +GUILD_UPDATE = {"unavailable": {"type": "boolean", "required": False}} -USER_UPDATE = { - 'flags': {'required': False, 'coerce': UserFlags.from_int} -} +USER_UPDATE = {"flags": {"required": False, "coerce": UserFlags.from_int}} diff --git a/litecord/auth.py b/litecord/auth.py index 3c52ddf..e9ac19f 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -55,44 +55,50 @@ async def raw_token_check(token: str, db=None) -> int: # just try by fragments instead of # unpacking - fragments = token.split('.') + fragments = token.split(".") user_id = fragments[0] try: user_id = base64.b64decode(user_id.encode()) user_id = int(user_id) except (ValueError, binascii.Error): - raise Unauthorized('Invalid user ID type') + raise Unauthorized("Invalid user ID type") - pwd_hash = await db.fetchval(""" + pwd_hash = await db.fetchval( + """ SELECT password_hash FROM users WHERE id = $1 - """, user_id) + """, + user_id, + ) if not pwd_hash: - raise Unauthorized('User ID not found') + raise Unauthorized("User ID not found") signer = TimestampSigner(pwd_hash) try: signer.unsign(token) - log.debug('login for uid {} successful', user_id) + log.debug("login for uid {} successful", user_id) # update the user's last_session field # so that we can keep an exact track of activity, # even on long-lived single sessions (that can happen # with people leaving their clients open forever) - await db.execute(""" + await db.execute( + """ UPDATE users SET last_session = (now() at time zone 'utc') WHERE id = $1 - """, user_id) + """, + user_id, + ) return user_id except BadSignature: - log.warning('token failed for uid {}', user_id) - raise Forbidden('Invalid token') + log.warning("token failed for uid {}", user_id) + raise Forbidden("Invalid token") async def token_check() -> int: @@ -104,12 +110,12 @@ async def token_check() -> int: pass try: - token = request.headers['Authorization'] + token = request.headers["Authorization"] except KeyError: - raise Unauthorized('No token provided') + raise Unauthorized("No token provided") - if token.startswith('Bot '): - token = token.replace('Bot ', '') + if token.startswith("Bot "): + token = token.replace("Bot ", "") user_id = await raw_token_check(token) request.user_id = user_id @@ -120,15 +126,18 @@ async def admin_check() -> int: """Check if the user is an admin.""" user_id = await token_check() - flags = await app.db.fetchval(""" + flags = await app.db.fetchval( + """ SELECT flags FROM users WHERE id = $1 - """, user_id) + """, + user_id, + ) flags = UserFlags.from_int(flags) if not flags.is_staff: - raise Unauthorized('you are not staff') + raise Unauthorized("you are not staff") return user_id @@ -138,9 +147,7 @@ async def hash_data(data: str, loop=None) -> str: loop = loop or app.loop buf = data.encode() - hashed = await loop.run_in_executor( - None, bcrypt.hashpw, buf, bcrypt.gensalt(14) - ) + hashed = await loop.run_in_executor(None, bcrypt.hashpw, buf, bcrypt.gensalt(14)) return hashed.decode() @@ -148,22 +155,28 @@ async def hash_data(data: str, loop=None) -> str: async def check_username_usage(username: str, db=None): """Raise an error if too many people are with the same username.""" db = db or app.db - same_username = await db.fetchval(""" + same_username = await db.fetchval( + """ SELECT COUNT(*) FROM users WHERE username = $1 - """, username) + """, + username, + ) if same_username > 9000: - raise BadRequest('Too many people.', { - 'username': 'Too many people used the same username. ' - 'Please choose another' - }) + raise BadRequest( + "Too many people.", + { + "username": "Too many people used the same username. " + "Please choose another" + }, + ) def _raw_discrim() -> str: new_discrim = randint(1, 9999) - new_discrim = '%04d' % new_discrim + new_discrim = "%04d" % new_discrim return new_discrim @@ -186,11 +199,15 @@ async def roll_discrim(username: str, *, db=None) -> Optional[str]: discrim = _raw_discrim() # check if anyone is with it - res = await db.fetchval(""" + res = await db.fetchval( + """ SELECT id FROM users WHERE username = $1 AND discriminator = $2 - """, username, discrim) + """, + username, + discrim, + ) # if no user is found with the (username, discrim) # pair, then this is unique! return it. @@ -200,8 +217,9 @@ async def roll_discrim(username: str, *, db=None) -> Optional[str]: return None -async def create_user(username: str, email: str, password: str, - db=None, loop=None) -> Tuple[int, str]: +async def create_user( + username: str, email: str, password: str, db=None, loop=None +) -> Tuple[int, str]: """Create a single user. Generates a distriminator and other information. You can fetch the user @@ -214,20 +232,28 @@ async def create_user(username: str, email: str, password: str, new_discrim = await roll_discrim(username, db=db) if new_discrim is None: - raise BadRequest('Unable to register.', { - 'username': 'Too many people are with this username.' - }) + raise BadRequest( + "Unable to register.", + {"username": "Too many people are with this username."}, + ) pwd_hash = await hash_data(password, loop) try: - await db.execute(""" + await db.execute( + """ INSERT INTO users (id, email, username, discriminator, password_hash) VALUES ($1, $2, $3, $4, $5) - """, new_id, email, username, new_discrim, pwd_hash) + """, + new_id, + email, + username, + new_discrim, + pwd_hash, + ) except UniqueViolationError: - raise BadRequest('Email already used.') + raise BadRequest("Email already used.") return new_id, pwd_hash diff --git a/litecord/blueprints/__init__.py b/litecord/blueprints/__init__.py index 9b021b4..cc15930 100644 --- a/litecord/blueprints/__init__.py +++ b/litecord/blueprints/__init__.py @@ -34,7 +34,21 @@ from .static import bp as static from .attachments import bp as attachments from .dm_channels import bp as dm_channels -__all__ = ['gateway', 'auth', 'users', 'guilds', 'channels', - 'webhooks', 'science', 'voice', 'invites', 'relationships', - 'dms', 'icons', 'nodeinfo', 'static', 'attachments', - 'dm_channels'] +__all__ = [ + "gateway", + "auth", + "users", + "guilds", + "channels", + "webhooks", + "science", + "voice", + "invites", + "relationships", + "dms", + "icons", + "nodeinfo", + "static", + "attachments", + "dm_channels", +] diff --git a/litecord/blueprints/admin_api/__init__.py b/litecord/blueprints/admin_api/__init__.py index d27cc52..ab0e1e8 100644 --- a/litecord/blueprints/admin_api/__init__.py +++ b/litecord/blueprints/admin_api/__init__.py @@ -23,4 +23,4 @@ from .guilds import bp as guilds from .users import bp as users from .instance_invites import bp as instance_invites -__all__ = ['voice', 'features', 'guilds', 'users', 'instance_invites'] +__all__ = ["voice", "features", "guilds", "users", "instance_invites"] diff --git a/litecord/blueprints/admin_api/features.py b/litecord/blueprints/admin_api/features.py index 9de58d6..314bb44 100644 --- a/litecord/blueprints/admin_api/features.py +++ b/litecord/blueprints/admin_api/features.py @@ -25,45 +25,53 @@ from litecord.errors import BadRequest from litecord.schemas import validate from litecord.admin_schemas import FEATURES -bp = Blueprint('features_admin', __name__) +bp = Blueprint("features_admin", __name__) async def _features_from_req() -> List[str]: j = validate(await request.get_json(), FEATURES) - return [feature.value for feature in j['features']] + return [feature.value for feature in j["features"]] async def _features(guild_id: int): - return jsonify({ - 'features': await app.storage.guild_features(guild_id) - }) + return jsonify({"features": await app.storage.guild_features(guild_id)}) async def _update_features(guild_id: int, features: list): - if 'VANITY_URL' not in features: + if "VANITY_URL" not in features: existing_inv = await app.storage.vanity_invite(guild_id) - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM vanity_invites WHERE guild_id = $1 - """, guild_id) + """, + guild_id, + ) - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM invites WHERE code = $1 - """, existing_inv) + """, + existing_inv, + ) - await app.db.execute(""" + await app.db.execute( + """ UPDATE guilds SET features = $1 WHERE id = $2 - """, features, guild_id) + """, + features, + guild_id, + ) guild = await app.storage.get_guild_full(guild_id) - await app.dispatcher.dispatch('guild', guild_id, 'GUILD_UPDATE', guild) + await app.dispatcher.dispatch("guild", guild_id, "GUILD_UPDATE", guild) -@bp.route('//features', methods=['PATCH']) +@bp.route("//features", methods=["PATCH"]) async def replace_features(guild_id: int): """Replace the feature list in a guild""" await admin_check() @@ -76,7 +84,7 @@ async def replace_features(guild_id: int): return await _features(guild_id) -@bp.route('//features', methods=['PUT']) +@bp.route("//features", methods=["PUT"]) async def insert_features(guild_id: int): """Insert a feature on a guild.""" await admin_check() @@ -93,7 +101,7 @@ async def insert_features(guild_id: int): return await _features(guild_id) -@bp.route('//features', methods=['DELETE']) +@bp.route("//features", methods=["DELETE"]) async def remove_features(guild_id: int): """Remove a feature from a guild""" await admin_check() @@ -104,7 +112,7 @@ async def remove_features(guild_id: int): try: features.remove(feature) except ValueError: - raise BadRequest('Trying to remove already removed feature.') + raise BadRequest("Trying to remove already removed feature.") await _update_features(guild_id, features) return await _features(guild_id) diff --git a/litecord/blueprints/admin_api/guilds.py b/litecord/blueprints/admin_api/guilds.py index e27f544..1cf792b 100644 --- a/litecord/blueprints/admin_api/guilds.py +++ b/litecord/blueprints/admin_api/guilds.py @@ -25,9 +25,10 @@ from litecord.admin_schemas import GUILD_UPDATE from litecord.blueprints.guilds import delete_guild from litecord.errors import GuildNotFound -bp = Blueprint('guilds_admin', __name__) +bp = Blueprint("guilds_admin", __name__) -@bp.route('/', methods=['GET']) + +@bp.route("/", methods=["GET"]) async def get_guild(guild_id: int): """Get a basic guild payload.""" await admin_check() @@ -40,7 +41,7 @@ async def get_guild(guild_id: int): return jsonify(guild) -@bp.route('/', methods=['PATCH']) +@bp.route("/", methods=["PATCH"]) async def update_guild(guild_id: int): await admin_check() @@ -48,13 +49,13 @@ async def update_guild(guild_id: int): # TODO: what happens to the other guild attributes when its # unavailable? do they vanish? - old_unavailable = app.guild_store.get(guild_id, 'unavailable') - new_unavailable = j.get('unavailable', old_unavailable) + old_unavailable = app.guild_store.get(guild_id, "unavailable") + new_unavailable = j.get("unavailable", old_unavailable) # always set unavailable status since new_unavailable will be # old_unavailable when not provided, so we don't need to check if # j.unavailable is there - app.guild_store.set(guild_id, 'unavailable', j['unavailable']) + app.guild_store.set(guild_id, "unavailable", j["unavailable"]) guild = await app.storage.get_guild(guild_id) @@ -62,17 +63,17 @@ async def update_guild(guild_id: int): if old_unavailable and not new_unavailable: # guild became available - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_CREATE', guild) + await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild) else: # guild became unavailable - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_DELETE', guild) + await app.dispatcher.dispatch_guild(guild_id, "GUILD_DELETE", guild) return jsonify(guild) -@bp.route('/', methods=['DELETE']) +@bp.route("/", methods=["DELETE"]) async def delete_guild_as_admin(guild_id): """Delete a single guild via the admin API without ownership checks.""" await admin_check() await delete_guild(guild_id) - return '', 204 + return "", 204 diff --git a/litecord/blueprints/admin_api/instance_invites.py b/litecord/blueprints/admin_api/instance_invites.py index dbd6c78..c410c47 100644 --- a/litecord/blueprints/admin_api/instance_invites.py +++ b/litecord/blueprints/admin_api/instance_invites.py @@ -27,13 +27,13 @@ from litecord.types import timestamp_ from litecord.schemas import validate from litecord.admin_schemas import INSTANCE_INVITE -bp = Blueprint('instance_invites', __name__) +bp = Blueprint("instance_invites", __name__) ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits async def _gen_inv() -> str: """Generate an invite code""" - return ''.join(choice(ALPHABET) for _ in range(6)) + return "".join(choice(ALPHABET) for _ in range(6)) async def gen_inv(ctx) -> str: @@ -41,11 +41,14 @@ async def gen_inv(ctx) -> str: for _ in range(10): possible_inv = await _gen_inv() - created_at = await ctx.db.fetchval(""" + created_at = await ctx.db.fetchval( + """ SELECT created_at FROM instance_invites WHERE code = $1 - """, possible_inv) + """, + possible_inv, + ) if created_at is None: return possible_inv @@ -53,57 +56,71 @@ async def gen_inv(ctx) -> str: return None -@bp.route('', methods=['GET']) +@bp.route("", methods=["GET"]) async def _all_instance_invites(): await admin_check() - rows = await app.db.fetch(""" + rows = await app.db.fetch( + """ SELECT code, created_at, uses, max_uses FROM instance_invites - """) + """ + ) rows = [dict(row) for row in rows] for row in rows: - row['created_at'] = timestamp_(row['created_at']) + row["created_at"] = timestamp_(row["created_at"]) return jsonify(rows) -@bp.route('', methods=['PUT']) +@bp.route("", methods=["PUT"]) async def _create_invite(): await admin_check() code = await gen_inv(app) if code is None: - return 'failed to make invite', 500 + return "failed to make invite", 500 j = validate(await request.get_json(), INSTANCE_INVITE) - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO instance_invites (code, max_uses) VALUES ($1, $2) - """, code, j['max_uses']) + """, + code, + j["max_uses"], + ) - inv = dict(await app.db.fetchrow(""" + inv = dict( + await app.db.fetchrow( + """ SELECT code, created_at, uses, max_uses FROM instance_invites WHERE code = $1 - """, code)) + """, + code, + ) + ) return jsonify(dict(inv)) -@bp.route('/', methods=['DELETE']) +@bp.route("/", methods=["DELETE"]) async def _del_invite(invite: str): await admin_check() - res = await app.db.execute(""" + res = await app.db.execute( + """ DELETE FROM instance_invites WHERE code = $1 - """, invite) + """, + invite, + ) - if res.lower() == 'delete 0': - return 'invite not found', 404 + if res.lower() == "delete 0": + return "invite not found", 404 - return '', 204 + return "", 204 diff --git a/litecord/blueprints/admin_api/users.py b/litecord/blueprints/admin_api/users.py index ab3ba37..c1df4e8 100644 --- a/litecord/blueprints/admin_api/users.py +++ b/litecord/blueprints/admin_api/users.py @@ -25,24 +25,21 @@ from litecord.schemas import validate from litecord.admin_schemas import USER_CREATE, USER_UPDATE from litecord.errors import BadRequest, Forbidden from litecord.utils import async_map -from litecord.blueprints.users import ( - delete_user, user_disconnect, mass_user_update -) +from litecord.blueprints.users import delete_user, user_disconnect, mass_user_update from litecord.enums import UserFlags -bp = Blueprint('users_admin', __name__) +bp = Blueprint("users_admin", __name__) -@bp.route('', methods=['POST']) + +@bp.route("", methods=["POST"]) async def _create_user(): await admin_check() j = validate(await request.get_json(), USER_CREATE) - user_id, _ = await create_user(j['username'], j['email'], j['password']) + user_id, _ = await create_user(j["username"], j["email"], j["password"]) - return jsonify( - await app.storage.get_user(user_id) - ) + return jsonify(await app.storage.get_user(user_id)) def args_try(args: dict, typ, field: str, default): @@ -51,29 +48,29 @@ def args_try(args: dict, typ, field: str, default): try: return typ(args.get(field, default)) except (TypeError, ValueError): - raise BadRequest(f'invalid {field} value') + raise BadRequest(f"invalid {field} value") -@bp.route('', methods=['GET']) +@bp.route("", methods=["GET"]) async def _search_users(): await admin_check() args = request.args - username, discrim = args.get('username'), args.get('discriminator') + username, discrim = args.get("username"), args.get("discriminator") - per_page = args_try(args, int, 'per_page', 20) - page = args_try(args, int, 'page', 0) + per_page = args_try(args, int, "per_page", 20) + page = args_try(args, int, "page", 0) if page < 0: - raise BadRequest('invalid page number') + raise BadRequest("invalid page number") if per_page > 50: - raise BadRequest('invalid per page number') + raise BadRequest("invalid per page number") # any of those must be available. if not any((username, discrim)): - raise BadRequest('must insert username or discrim') + raise BadRequest("must insert username or discrim") wheres, args = [], [] @@ -82,29 +79,31 @@ async def _search_users(): args.append(username) if discrim: - wheres.append(f'discriminator = ${len(args) + 2}') + wheres.append(f"discriminator = ${len(args) + 2}") args.append(discrim) - where_tot = 'WHERE ' if args else '' - where_tot += ' AND '.join(wheres) + where_tot = "WHERE " if args else "" + where_tot += " AND ".join(wheres) - rows = await app.db.fetch(f""" + rows = await app.db.fetch( + f""" SELECT id FROM users {where_tot} ORDER BY id ASC LIMIT {per_page} OFFSET ($1 * {per_page}) - """, page, *args) - - rows = [r['id'] for r in rows] - - return jsonify( - await async_map(app.storage.get_user, rows) + """, + page, + *args, ) + rows = [r["id"] for r in rows] -@bp.route('/', methods=['DELETE']) + return jsonify(await async_map(app.storage.get_user, rows)) + + +@bp.route("/", methods=["DELETE"]) async def _delete_single_user(user_id: int): await admin_check() @@ -115,13 +114,10 @@ async def _delete_single_user(user_id: int): new_user = await app.storage.get_user(user_id) - return jsonify({ - 'old': old_user, - 'new': new_user - }) + return jsonify({"old": old_user, "new": new_user}) -@bp.route('/', methods=['PATCH']) +@bp.route("/", methods=["PATCH"]) async def patch_user(user_id: int): await admin_check() @@ -129,21 +125,25 @@ async def patch_user(user_id: int): # get the original user for flags checking user = await app.storage.get_user(user_id) - old_flags = UserFlags.from_int(user['flags']) + old_flags = UserFlags.from_int(user["flags"]) # j.flags is already a UserFlags since we coerce it. - if 'flags' in j: - new_flags = j['flags'] + if "flags" in j: + new_flags = j["flags"] # disallow any changes to the staff badge if new_flags.is_staff != old_flags.is_staff: - raise Forbidden('you can not change a users staff badge') + raise Forbidden("you can not change a users staff badge") - await app.db.execute(""" + await app.db.execute( + """ UPDATE users SET flags = $1 WHERE id = $2 - """, new_flags.value, user_id) + """, + new_flags.value, + user_id, + ) public_user, _ = await mass_user_update(user_id, app) return jsonify(public_user) diff --git a/litecord/blueprints/admin_api/voice.py b/litecord/blueprints/admin_api/voice.py index e700b27..334bcfe 100644 --- a/litecord/blueprints/admin_api/voice.py +++ b/litecord/blueprints/admin_api/voice.py @@ -27,10 +27,10 @@ from litecord.admin_schemas import VOICE_SERVER, VOICE_REGION from litecord.errors import BadRequest log = Logger(__name__) -bp = Blueprint('voice_admin', __name__) +bp = Blueprint("voice_admin", __name__) -@bp.route('/regions/', methods=['GET']) +@bp.route("/regions/", methods=["GET"]) async def get_region_servers(region): """Return a list of all servers for a region.""" await admin_check() @@ -38,18 +38,25 @@ async def get_region_servers(region): return jsonify(servers) -@bp.route('/regions', methods=['PUT']) +@bp.route("/regions", methods=["PUT"]) async def insert_new_region(): """Create a voice region.""" await admin_check() j = validate(await request.get_json(), VOICE_REGION) - j['id'] = j['id'].lower() + j["id"] = j["id"].lower() - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO voice_regions (id, name, vip, deprecated, custom) VALUES ($1, $2, $3, $4, $5) - """, j['id'], j['name'], j['vip'], j['deprecated'], j['custom']) + """, + j["id"], + j["name"], + j["vip"], + j["deprecated"], + j["custom"], + ) regions = await app.storage.all_voice_regions() region_count = len(regions) @@ -57,34 +64,41 @@ async def insert_new_region(): # if region count is 1, this is the first region to be created, # so we should update all guilds to that region if region_count == 1: - res = await app.db.execute(""" + res = await app.db.execute( + """ UPDATE guilds SET region = $1 - """, j['id']) + """, + j["id"], + ) - log.info('updating guilds to first voice region: {}', res) + log.info("updating guilds to first voice region: {}", res) return jsonify(regions) -@bp.route('/regions//servers', methods=['PUT']) +@bp.route("/regions//servers", methods=["PUT"]) async def put_region_server(region): """Insert a voice server to a region""" await admin_check() j = validate(await request.get_json(), VOICE_SERVER) try: - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO voice_servers (hostname, region_id) VALUES ($1, $2) - """, j['hostname'], region) + """, + j["hostname"], + region, + ) except asyncpg.UniqueViolationError: - raise BadRequest('voice server already exists with given hostname') + raise BadRequest("voice server already exists with given hostname") - return '', 204 + return "", 204 -@bp.route('/regions//deprecate', methods=['PUT']) +@bp.route("/regions//deprecate", methods=["PUT"]) async def deprecate_region(region): """Deprecate a voice region.""" await admin_check() @@ -92,13 +106,16 @@ async def deprecate_region(region): # TODO: write this await app.voice.disable_region(region) - await app.db.execute(""" + await app.db.execute( + """ UPDATE voice_regions SET deprecated = true WHERE id = $1 - """, region) + """, + region, + ) - return '', 204 + return "", 204 async def guild_region_check(app_): @@ -112,10 +129,11 @@ async def guild_region_check(app_): regions = await app_.storage.all_voice_regions() if not regions: - log.info('region check: no regions to move guilds to') + log.info("region check: no regions to move guilds to") return - res = await app_.db.execute(""" + res = await app_.db.execute( + """ UPDATE guilds SET region = ( SELECT id @@ -124,6 +142,8 @@ async def guild_region_check(app_): LIMIT 1 ) WHERE region = NULL - """, len(regions)) + """, + len(regions), + ) - log.info('region check: updating guild.region=null: {!r}', res) + log.info("region check: updating guild.region=null: {!r}", res) diff --git a/litecord/blueprints/attachments.py b/litecord/blueprints/attachments.py index 9d41755..49da21a 100644 --- a/litecord/blueprints/attachments.py +++ b/litecord/blueprints/attachments.py @@ -24,16 +24,17 @@ from PIL import Image from litecord.images import resize_gif -bp = Blueprint('attachments', __name__) -ATTACHMENTS = Path.cwd() / 'attachments' +bp = Blueprint("attachments", __name__) +ATTACHMENTS = Path.cwd() / "attachments" -async def _resize_gif(attach_id: int, resized_path: Path, - width: int, height: int) -> str: +async def _resize_gif( + attach_id: int, resized_path: Path, width: int, height: int +) -> str: """Resize a GIF attachment.""" # get original gif bytes - orig_path = ATTACHMENTS / f'{attach_id}.gif' + orig_path = ATTACHMENTS / f"{attach_id}.gif" orig_bytes = orig_path.read_bytes() # give them and the target size to the @@ -47,10 +48,7 @@ async def _resize_gif(attach_id: int, resized_path: Path, return str(resized_path) -FORMAT_HARDCODE = { - 'jpg': 'jpeg', - 'jpe': 'jpeg' -} +FORMAT_HARDCODE = {"jpg": "jpeg", "jpe": "jpeg"} def to_format(ext: str) -> str: @@ -63,11 +61,10 @@ def to_format(ext: str) -> str: return ext -async def _resize(image, attach_id: int, ext: str, - width: int, height: int) -> str: +async def _resize(image, attach_id: int, ext: str, width: int, height: int) -> str: """Resize an image.""" # check if we have it on the folder - resized_path = ATTACHMENTS / f'{attach_id}_{width}_{height}.{ext}' + resized_path = ATTACHMENTS / f"{attach_id}_{width}_{height}.{ext}" # keep a str-fied instance since that is what # we'll return. @@ -81,7 +78,7 @@ async def _resize(image, attach_id: int, ext: str, # the process is different for gif files because we need # gifsicle. doing it manually is too troublesome. - if ext == 'gif': + if ext == "gif": return await _resize_gif(attach_id, resized_path, width, height) # NOTE: this is the same resize mode for icons. @@ -91,38 +88,42 @@ async def _resize(image, attach_id: int, ext: str, return resized_path_s -@bp.route('/attachments' - '///', - methods=['GET']) -async def _get_attachment(channel_id: int, message_id: int, - filename: str): +@bp.route( + "/attachments" "///", methods=["GET"] +) +async def _get_attachment(channel_id: int, message_id: int, filename: str): - attach_id = await app.db.fetchval(""" + attach_id = await app.db.fetchval( + """ SELECT id FROM attachments WHERE channel_id = $1 AND message_id = $2 AND filename = $3 - """, channel_id, message_id, filename) + """, + channel_id, + message_id, + filename, + ) if attach_id is None: - return '', 404 + return "", 404 - ext = filename.split('.')[-1] - filepath = f'./attachments/{attach_id}.{ext}' + ext = filename.split(".")[-1] + filepath = f"./attachments/{attach_id}.{ext}" image = Image.open(filepath) im_width, im_height = image.size try: - width = int(request.args.get('width', 0)) or im_width + width = int(request.args.get("width", 0)) or im_width except ValueError: - return '', 400 + return "", 400 try: - height = int(request.args.get('height', 0)) or im_height + height = int(request.args.get("height", 0)) or im_height except ValueError: - return '', 400 + return "", 400 # if width and height are the same (happens if they weren't provided) if width == im_width and height == im_height: diff --git a/litecord/blueprints/auth.py b/litecord/blueprints/auth.py index e6d4124..38ce17c 100644 --- a/litecord/blueprints/auth.py +++ b/litecord/blueprints/auth.py @@ -32,7 +32,7 @@ from litecord.snowflake import get_snowflake from .invites import use_invite log = Logger(__name__) -bp = Blueprint('auth', __name__) +bp = Blueprint("auth", __name__) async def check_password(pwd_hash: str, given_password: str) -> bool: @@ -53,141 +53,139 @@ def make_token(user_id, user_pwd_hash) -> str: return signer.sign(user_id).decode() -@bp.route('/register', methods=['POST']) +@bp.route("/register", methods=["POST"]) async def register(): """Register a single user.""" - enabled = app.config.get('REGISTRATIONS') + enabled = app.config.get("REGISTRATIONS") if not enabled: - raise BadRequest('Registrations disabled', { - 'email': 'Registrations are disabled.' - }) + raise BadRequest( + "Registrations disabled", {"email": "Registrations are disabled."} + ) j = await request.get_json() - if not 'password' in j: + if not "password" in j: # we need a password to generate a token. # passwords are optional, so - j['password'] = 'default_password' + j["password"] = "default_password" j = validate(j, REGISTER) # they're optional - email = j.get('email') - invite = j.get('invite') + email = j.get("email") + invite = j.get("invite") - username, password = j['username'], j['password'] + username, password = j["username"], j["password"] - new_id, pwd_hash = await create_user( - username, email, password, app.db - ) + new_id, pwd_hash = await create_user(username, email, password, app.db) if invite: try: await use_invite(new_id, invite) except Exception: - log.exception('failed to use invite for register {} {!r}', - new_id, invite) + log.exception("failed to use invite for register {} {!r}", new_id, invite) - return jsonify({ - 'token': make_token(new_id, pwd_hash) - }) + return jsonify({"token": make_token(new_id, pwd_hash)}) -@bp.route('/register_inv', methods=['POST']) +@bp.route("/register_inv", methods=["POST"]) async def _register_with_invite(): data = await request.form data = validate(await request.form, REGISTER_WITH_INVITE) - invcode = data['invcode'] + invcode = data["invcode"] - row = await app.db.fetchrow(""" + row = await app.db.fetchrow( + """ SELECT uses, max_uses FROM instance_invites WHERE code = $1 - """, invcode) + """, + invcode, + ) if row is None: - raise BadRequest('unknown instance invite') + raise BadRequest("unknown instance invite") - if row['max_uses'] != -1 and row['uses'] >= row['max_uses']: - raise BadRequest('invite expired') + if row["max_uses"] != -1 and row["uses"] >= row["max_uses"]: + raise BadRequest("invite expired") - await app.db.execute(""" + await app.db.execute( + """ UPDATE instance_invites SET uses = uses + 1 WHERE code = $1 - """, invcode) + """, + invcode, + ) user_id, pwd_hash = await create_user( - data['username'], data['email'], data['password'], app.db) + data["username"], data["email"], data["password"], app.db + ) - return jsonify({ - 'token': make_token(user_id, pwd_hash), - 'user_id': str(user_id), - }) + return jsonify({"token": make_token(user_id, pwd_hash), "user_id": str(user_id)}) -@bp.route('/login', methods=['POST']) +@bp.route("/login", methods=["POST"]) async def login(): j = await request.get_json() - email, password = j['email'], j['password'] + email, password = j["email"], j["password"] - row = await app.db.fetchrow(""" + row = await app.db.fetchrow( + """ SELECT id, password_hash FROM users WHERE email = $1 - """, email) + """, + email, + ) if not row: - return jsonify({'email': ['User not found.']}), 401 + return jsonify({"email": ["User not found."]}), 401 user_id, pwd_hash = row if not await check_password(pwd_hash, password): - return jsonify({'password': ['Password does not match.']}), 401 + return jsonify({"password": ["Password does not match."]}), 401 - return jsonify({ - 'token': make_token(user_id, pwd_hash) - }) + return jsonify({"token": make_token(user_id, pwd_hash)}) -@bp.route('/consent-required', methods=['GET']) +@bp.route("/consent-required", methods=["GET"]) async def consent_required(): - return jsonify({ - 'required': True, - }) + return jsonify({"required": True}) -@bp.route('/verify/resend', methods=['POST']) +@bp.route("/verify/resend", methods=["POST"]) async def verify_user(): user_id = await token_check() # TODO: actually verify a user by sending an email - await app.db.execute(""" + await app.db.execute( + """ UPDATE users SET verified = true WHERE id = $1 - """, user_id) + """, + user_id, + ) new_user = await app.storage.get_user(user_id, True) - await app.dispatcher.dispatch_user( - user_id, 'USER_UPDATE', new_user) + await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", new_user) - return '', 204 + return "", 204 -@bp.route('/logout', methods=['POST']) +@bp.route("/logout", methods=["POST"]) async def _logout(): """Called by the client to logout.""" - return '', 204 + return "", 204 -@bp.route('/fingerprint', methods=['POST']) +@bp.route("/fingerprint", methods=["POST"]) async def _fingerprint(): """No idea what this route is about.""" fingerprint_id = get_snowflake() - fingerprint = f'{fingerprint_id}.{secrets.token_urlsafe(32)}' + fingerprint = f"{fingerprint_id}.{secrets.token_urlsafe(32)}" - return jsonify({ - 'fingerprint': fingerprint - }) + return jsonify({"fingerprint": fingerprint}) diff --git a/litecord/blueprints/channel/__init__.py b/litecord/blueprints/channel/__init__.py index fac85e9..fc08dd2 100644 --- a/litecord/blueprints/channel/__init__.py +++ b/litecord/blueprints/channel/__init__.py @@ -21,4 +21,4 @@ from .messages import bp as channel_messages from .reactions import bp as channel_reactions from .pins import bp as channel_pins -__all__ = ['channel_messages', 'channel_reactions', 'channel_pins'] +__all__ = ["channel_messages", "channel_reactions", "channel_pins"] diff --git a/litecord/blueprints/channel/dm_checks.py b/litecord/blueprints/channel/dm_checks.py index 1d0851a..e2cb195 100644 --- a/litecord/blueprints/channel/dm_checks.py +++ b/litecord/blueprints/channel/dm_checks.py @@ -30,13 +30,18 @@ class ForbiddenDM(Forbidden): async def dm_pre_check(user_id: int, channel_id: int, peer_id: int): """Check if the user can DM the peer.""" # first step is checking if there is a block in any direction - blockrow = await app.db.fetchrow(""" + blockrow = await app.db.fetchrow( + """ SELECT rel_type FROM relationships WHERE rel_type = $3 AND user_id IN ($1, $2) AND peer_id IN ($1, $2) - """, user_id, peer_id, RelationshipType.BLOCK.value) + """, + user_id, + peer_id, + RelationshipType.BLOCK.value, + ) if blockrow is not None: raise ForbiddenDM() @@ -58,8 +63,8 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int): user_settings = await app.user_storage.get_user_settings(user_id) peer_settings = await app.user_storage.get_user_settings(peer_id) - restricted_user_ = [int(v) for v in user_settings['restricted_guilds']] - restricted_peer_ = [int(v) for v in peer_settings['restricted_guilds']] + restricted_user_ = [int(v) for v in user_settings["restricted_guilds"]] + restricted_peer_ = [int(v) for v in peer_settings["restricted_guilds"]] restricted_user = set(restricted_user_) restricted_peer = set(restricted_peer_) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index f051f1f..551508b 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -41,18 +41,18 @@ from litecord.images import try_unlink log = Logger(__name__) -bp = Blueprint('channel_messages', __name__) +bp = Blueprint("channel_messages", __name__) def extract_limit(request_, default: int = 50, max_val: int = 100): """Extract a limit kwarg.""" try: - limit = int(request_.args.get('limit', default)) + limit = int(request_.args.get("limit", default)) if limit not in range(0, max_val + 1): raise ValueError() except (TypeError, ValueError): - raise BadRequest('limit not int') + raise BadRequest("limit not int") return limit @@ -61,27 +61,27 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple: """Extract a 2-tuple out of request arguments.""" before, after = None, None - if 'around' in request.args: + if "around" in request.args: average = int(limit / 2) - around = int(args['around']) + around = int(args["around"]) after = around - average before = around + average - elif 'before' in args: - before = int(args['before']) - elif 'after' in args: - before = int(args['after']) + elif "before" in args: + before = int(args["before"]) + elif "after" in args: + before = int(args["after"]) return before, after -@bp.route('//messages', methods=['GET']) +@bp.route("//messages", methods=["GET"]) async def get_messages(channel_id): user_id = await token_check() ctype, peer_id = await channel_check(user_id, channel_id) - await channel_perm_check(user_id, channel_id, 'read_history') + await channel_perm_check(user_id, channel_id, "read_history") if ctype == ChannelType.DM: # make sure both parties will be subbed @@ -91,42 +91,45 @@ async def get_messages(channel_id): limit = extract_limit(request, 50) - where_clause = '' + where_clause = "" before, after = query_tuple_from_args(request.args, limit) if before: - where_clause += f'AND id < {before}' + where_clause += f"AND id < {before}" if after: - where_clause += f'AND id > {after}' + where_clause += f"AND id > {after}" - message_ids = await app.db.fetch(f""" + message_ids = await app.db.fetch( + f""" SELECT id FROM messages WHERE channel_id = $1 {where_clause} ORDER BY id DESC LIMIT {limit} - """, channel_id) + """, + channel_id, + ) result = [] for message_id in message_ids: - msg = await app.storage.get_message(message_id['id'], user_id) + msg = await app.storage.get_message(message_id["id"], user_id) if msg is None: continue result.append(msg) - log.info('Fetched {} messages', len(result)) + log.info("Fetched {} messages", len(result)) return jsonify(result) -@bp.route('//messages/', methods=['GET']) +@bp.route("//messages/", methods=["GET"]) async def get_single_message(channel_id, message_id): user_id = await token_check() await channel_check(user_id, channel_id) - await channel_perm_check(user_id, channel_id, 'read_history') + await channel_perm_check(user_id, channel_id, "read_history") message = await app.storage.get_message(message_id, user_id) @@ -142,11 +145,15 @@ async def _dm_pre_dispatch(channel_id, peer_id): # check the other party's dm_channel_state - dm_state = await app.db.fetchval(""" + dm_state = await app.db.fetchval( + """ SELECT dm_id FROM dm_channel_state WHERE user_id = $1 AND dm_id = $2 - """, peer_id, channel_id) + """, + peer_id, + channel_id, + ) if dm_state: # the peer already has the channel @@ -157,18 +164,19 @@ async def _dm_pre_dispatch(channel_id, peer_id): # dispatch CHANNEL_CREATE so the client knows which # channel the future event is about - await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan) + await app.dispatcher.dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan) # subscribe the peer to the channel - await app.dispatcher.sub('channel', channel_id, peer_id) + await app.dispatcher.sub("channel", channel_id, peer_id) # insert it on dm_channel_state so the client # is subscribed on the future await try_dm_state(peer_id, channel_id) -async def create_message(channel_id: int, actual_guild_id: int, - author_id: int, data: dict) -> int: +async def create_message( + channel_id: int, actual_guild_id: int, author_id: int, data: dict +) -> int: message_id = get_snowflake() async with app.db.acquire() as conn: @@ -185,32 +193,32 @@ async def create_message(channel_id: int, actual_guild_id: int, channel_id, actual_guild_id, author_id, - data['content'], - - data['tts'], - data['everyone_mention'], - - data['nonce'], + data["content"], + data["tts"], + data["everyone_mention"], + data["nonce"], MessageType.DEFAULT.value, - data.get('embeds') or [] + data.get("embeds") or [], ) return message_id -async def msg_guild_text_mentions(payload: dict, guild_id: int, - mentions_everyone: bool, mentions_here: bool): + +async def msg_guild_text_mentions( + payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool +): """Calculates mention data side-effects.""" - channel_id = int(payload['channel_id']) + channel_id = int(payload["channel_id"]) # calculate the user ids we'll bump the mention count for uids = set() # first is extracting user mentions - for mention in payload['mentions']: - uids.add(int(mention['id'])) + for mention in payload["mentions"]: + uids.add(int(mention["id"])) # then role mentions - for role_mention in payload['mention_roles']: + for role_mention in payload["mention_roles"]: role_id = int(role_mention) member_ids = await app.storage.get_role_members(role_id) @@ -223,11 +231,14 @@ async def msg_guild_text_mentions(payload: dict, guild_id: int, if mentions_here: uids = set() - await app.db.execute(""" + await app.db.execute( + """ UPDATE user_read_state SET mention_count = mention_count + 1 WHERE channel_id = $1 - """, channel_id) + """, + channel_id, + ) # at-here updates the read state # for all users, including the ones @@ -238,19 +249,26 @@ async def msg_guild_text_mentions(payload: dict, guild_id: int, member_ids = await app.storage.get_member_ids(guild_id) - await app.db.executemany(""" + await app.db.executemany( + """ UPDATE user_read_state SET mention_count = mention_count + 1 WHERE channel_id = $1 AND user_id = $2 - """, [(channel_id, uid) for uid in member_ids]) + """, + [(channel_id, uid) for uid in member_ids], + ) for user_id in uids: - await app.db.execute(""" + await app.db.execute( + """ UPDATE user_read_state SET mention_count = mention_count + 1 WHERE user_id = $1 AND channel_id = $2 - """, user_id, channel_id) + """, + user_id, + channel_id, + ) async def msg_create_request() -> tuple: @@ -264,12 +282,12 @@ async def msg_create_request() -> tuple: # NOTE: embed isn't set on form data json_from_form = { - 'content': form.get('content', ''), - 'nonce': form.get('nonce', '0'), - 'tts': json.loads(form.get('tts', 'false')), + "content": form.get("content", ""), + "nonce": form.get("nonce", "0"), + "tts": json.loads(form.get("tts", "false")), } - payload_json = json.loads(form.get('payload_json', '{}')) + payload_json = json.loads(form.get("payload_json", "{}")) json_from_form.update(request_json) json_from_form.update(payload_json) @@ -283,20 +301,19 @@ async def msg_create_request() -> tuple: def msg_create_check_content(payload: dict, files: list, *, use_embeds=False): """Check if there is actually any content being sent to us.""" - has_content = bool(payload.get('content', '')) + has_content = bool(payload.get("content", "")) has_files = len(files) > 0 - embed_field = 'embeds' if use_embeds else 'embed' + embed_field = "embeds" if use_embeds else "embed" has_embed = embed_field in payload and payload.get(embed_field) is not None has_total_content = has_content or has_embed or has_files if not has_total_content: - raise BadRequest('No content has been provided.') + raise BadRequest("No content has been provided.") -async def msg_add_attachment(message_id: int, channel_id: int, - attachment_file) -> int: +async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int: """Add an attachment to a message. Parameters @@ -318,7 +335,7 @@ async def msg_add_attachment(message_id: int, channel_id: int, # understand file info mime = attachment_file.mimetype - is_image = mime.startswith('image/') + is_image = mime.startswith("image/") img_width, img_height = None, None @@ -346,17 +363,22 @@ async def msg_add_attachment(message_id: int, channel_id: int, VALUES ($1, $2, $3, $4, $5, $6, $7, $8) """, - attachment_id, channel_id, message_id, - filename, file_size, - is_image, img_width, img_height) + attachment_id, + channel_id, + message_id, + filename, + file_size, + is_image, + img_width, + img_height, + ) - ext = filename.split('.')[-1] + ext = filename.split(".")[-1] - with open(f'attachments/{attachment_id}.{ext}', 'wb') as attach_file: + with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file: attach_file.write(attachment_file.stream.read()) - log.debug('written {} bytes for attachment id {}', - file_size, attachment_id) + log.debug("written {} bytes for attachment id {}", file_size, attachment_id) return attachment_id @@ -364,12 +386,12 @@ async def msg_add_attachment(message_id: int, channel_id: int, async def _spawn_embed(app_, payload, **kwargs): app_.sched.spawn( process_url_embed( - app_.config, app_.storage, app_.dispatcher, app_.session, - payload, **kwargs) + app_.config, app_.storage, app_.dispatcher, app_.session, payload, **kwargs + ) ) -@bp.route('//messages', methods=['POST']) +@bp.route("//messages", methods=["POST"]) async def _create_message(channel_id): """Create a message.""" @@ -379,7 +401,7 @@ async def _create_message(channel_id): actual_guild_id = None if ctype in GUILD_CHANS: - await channel_perm_check(user_id, channel_id, 'send_messages') + await channel_perm_check(user_id, channel_id, "send_messages") actual_guild_id = guild_id payload_json, files = await msg_create_request() @@ -394,29 +416,31 @@ async def _create_message(channel_id): await dm_pre_check(user_id, channel_id, guild_id) can_everyone = await channel_perm_check( - user_id, channel_id, 'mention_everyone', False + user_id, channel_id, "mention_everyone", False ) - mentions_everyone = ('@everyone' in j['content']) and can_everyone - mentions_here = ('@here' in j['content']) and can_everyone + mentions_everyone = ("@everyone" in j["content"]) and can_everyone + mentions_here = ("@here" in j["content"]) and can_everyone - is_tts = (j.get('tts', False) and - await channel_perm_check( - user_id, channel_id, 'send_tts_messages', False - )) + is_tts = j.get("tts", False) and await channel_perm_check( + user_id, channel_id, "send_tts_messages", False + ) message_id = await create_message( - channel_id, actual_guild_id, user_id, { - 'content': j['content'], - 'tts': is_tts, - 'nonce': int(j.get('nonce', 0)), - 'everyone_mention': mentions_everyone or mentions_here, - + channel_id, + actual_guild_id, + user_id, + { + "content": j["content"], + "tts": is_tts, + "nonce": int(j.get("nonce", 0)), + "everyone_mention": mentions_everyone or mentions_here, # fill_embed takes care of filling proxy and width/height - 'embeds': ([await fill_embed(j['embed'])] - if j.get('embed') is not None - else []), - }) + "embeds": ( + [await fill_embed(j["embed"])] if j.get("embed") is not None else [] + ), + }, + ) # for each file given, we add it as an attachment for pre_attachment in files: @@ -429,8 +453,7 @@ async def _create_message(channel_id): await _dm_pre_dispatch(channel_id, user_id) await _dm_pre_dispatch(channel_id, guild_id) - await app.dispatcher.dispatch('channel', channel_id, - 'MESSAGE_CREATE', payload) + await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload) # spawn url processor for embedding of images perms = await get_permissions(user_id, channel_id) @@ -438,54 +461,71 @@ async def _create_message(channel_id): await _spawn_embed(app, payload) # update read state for the author - await app.db.execute(""" + await app.db.execute( + """ UPDATE user_read_state SET last_message_id = $1 WHERE channel_id = $2 AND user_id = $3 - """, message_id, channel_id, user_id) + """, + message_id, + channel_id, + user_id, + ) if ctype == ChannelType.GUILD_TEXT: await msg_guild_text_mentions( - payload, guild_id, mentions_everyone, mentions_here) + payload, guild_id, mentions_everyone, mentions_here + ) return jsonify(payload) -@bp.route('//messages/', methods=['PATCH']) +@bp.route("//messages/", methods=["PATCH"]) async def edit_message(channel_id, message_id): user_id = await token_check() _ctype, _guild_id = await channel_check(user_id, channel_id) - author_id = await app.db.fetchval(""" + author_id = await app.db.fetchval( + """ SELECT author_id FROM messages WHERE messages.id = $1 - """, message_id) + """, + message_id, + ) if not author_id == user_id: - raise Forbidden('You can not edit this message') + raise Forbidden("You can not edit this message") j = await request.get_json() - updated = 'content' in j or 'embed' in j + updated = "content" in j or "embed" in j old_message = await app.storage.get_message(message_id) - if 'content' in j: - await app.db.execute(""" + if "content" in j: + await app.db.execute( + """ UPDATE messages SET content=$1 WHERE messages.id = $2 - """, j['content'], message_id) + """, + j["content"], + message_id, + ) - if 'embed' in j: - embeds = [await fill_embed(j['embed'])] + if "embed" in j: + embeds = [await fill_embed(j["embed"])] - await app.db.execute(""" + await app.db.execute( + """ UPDATE messages SET embeds=$1 WHERE messages.id = $2 - """, embeds, message_id) + """, + embeds, + message_id, + ) # do not spawn process_url_embed since we already have embeds. - elif 'content' in j: + elif "content" in j: # if there weren't any embed changes BUT # we had a content change, we dispatch process_url_embed but with # an artificial delay. @@ -495,46 +535,55 @@ async def edit_message(channel_id, message_id): # BEFORE the MESSAGE_UPDATE with the new embeds (based on content) perms = await get_permissions(user_id, channel_id) if perms.bits.embed_links: - await _spawn_embed(app, { - 'id': message_id, - 'channel_id': channel_id, - 'content': j['content'], - 'embeds': old_message['embeds'] - }, delay=0.2) + await _spawn_embed( + app, + { + "id": message_id, + "channel_id": channel_id, + "content": j["content"], + "embeds": old_message["embeds"], + }, + delay=0.2, + ) # only set new timestamp upon actual update if updated: - await app.db.execute(""" + await app.db.execute( + """ UPDATE messages SET edited_at = (now() at time zone 'utc') WHERE id = $1 - """, message_id) + """, + message_id, + ) message = await app.storage.get_message(message_id, user_id) # only dispatch MESSAGE_UPDATE if any update # actually happened if updated: - await app.dispatcher.dispatch('channel', channel_id, - 'MESSAGE_UPDATE', message) + await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", message) return jsonify(message) async def _del_msg_fkeys(message_id: int): - attachs = await app.db.fetch(""" + attachs = await app.db.fetch( + """ SELECT id FROM attachments WHERE message_id = $1 - """, message_id) + """, + message_id, + ) - attachs = [r['id'] for r in attachs] + attachs = [r["id"] for r in attachs] - attachments = Path('./attachments') + attachments = Path("./attachments") for attach_id in attachs: # anything starting with the given attachment shall be # deleted, because there may be resizes of the original # attachment laying around. - for filepath in attachments.glob(f'{attach_id}*'): + for filepath in attachments.glob(f"{attach_id}*"): try_unlink(filepath) # after trying to delete all available attachments, delete @@ -542,51 +591,64 @@ async def _del_msg_fkeys(message_id: int): # take the chance and delete all the data from the other tables too! - tables = ['attachments', 'message_webhook_info', - 'message_reactions', 'channel_pins'] + tables = [ + "attachments", + "message_webhook_info", + "message_reactions", + "channel_pins", + ] for table in tables: - await app.db.execute(f""" + await app.db.execute( + f""" DELETE FROM {table} WHERE message_id = $1 - """, message_id) + """, + message_id, + ) -@bp.route('//messages/', methods=['DELETE']) +@bp.route("//messages/", methods=["DELETE"]) async def delete_message(channel_id, message_id): user_id = await token_check() _ctype, guild_id = await channel_check(user_id, channel_id) - author_id = await app.db.fetchval(""" + author_id = await app.db.fetchval( + """ SELECT author_id FROM messages WHERE messages.id = $1 - """, message_id) - - by_perm = await channel_perm_check( - user_id, channel_id, 'manage_messages', False + """, + message_id, ) + by_perm = await channel_perm_check(user_id, channel_id, "manage_messages", False) + by_ownership = author_id == user_id can_delete = by_perm or by_ownership if not can_delete: - raise Forbidden('You can not delete this message') + raise Forbidden("You can not delete this message") await _del_msg_fkeys(message_id) - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM messages WHERE messages.id = $1 - """, message_id) + """, + message_id, + ) await app.dispatcher.dispatch( - 'channel', channel_id, - 'MESSAGE_DELETE', { - 'id': str(message_id), - 'channel_id': str(channel_id), - + "channel", + channel_id, + "MESSAGE_DELETE", + { + "id": str(message_id), + "channel_id": str(channel_id), # for lazy guilds - 'guild_id': str(guild_id), - }) + "guild_id": str(guild_id), + }, + ) - return '', 204 + return "", 204 diff --git a/litecord/blueprints/channel/pins.py b/litecord/blueprints/channel/pins.py index f5245d7..5cdef0a 100644 --- a/litecord/blueprints/channel/pins.py +++ b/litecord/blueprints/channel/pins.py @@ -28,28 +28,32 @@ from litecord.system_messages import send_sys_message from litecord.enums import MessageType, SYS_MESSAGES from litecord.errors import BadRequest -bp = Blueprint('channel_pins', __name__) +bp = Blueprint("channel_pins", __name__) class SysMsgInvalidAction(BadRequest): """Invalid action on a system message.""" + error_code = 50021 -@bp.route('//pins', methods=['GET']) +@bp.route("//pins", methods=["GET"]) async def get_pins(channel_id): """Get the pins for a channel""" user_id = await token_check() await channel_check(user_id, channel_id) - ids = await app.db.fetch(""" + ids = await app.db.fetch( + """ SELECT message_id FROM channel_pins WHERE channel_id = $1 ORDER BY message_id DESC - """, channel_id) + """, + channel_id, + ) - ids = [r['message_id'] for r in ids] + ids = [r["message_id"] for r in ids] res = [] for message_id in ids: @@ -60,80 +64,96 @@ async def get_pins(channel_id): return jsonify(res) -@bp.route('//pins/', methods=['PUT']) +@bp.route("//pins/", methods=["PUT"]) async def add_pin(channel_id, message_id): """Add a pin to a channel""" user_id = await token_check() _ctype, guild_id = await channel_check(user_id, channel_id) - await channel_perm_check(user_id, channel_id, 'manage_messages') + await channel_perm_check(user_id, channel_id, "manage_messages") - mtype = await app.db.fetchval(""" + mtype = await app.db.fetchval( + """ SELECT message_type FROM messages WHERE id = $1 - """, message_id) + """, + message_id, + ) if mtype in SYS_MESSAGES: - raise SysMsgInvalidAction( - 'Cannot execute action on a system message') + raise SysMsgInvalidAction("Cannot execute action on a system message") - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO channel_pins (channel_id, message_id) VALUES ($1, $2) - """, channel_id, message_id) + """, + channel_id, + message_id, + ) - row = await app.db.fetchrow(""" + row = await app.db.fetchrow( + """ SELECT message_id FROM channel_pins WHERE channel_id = $1 ORDER BY message_id ASC LIMIT 1 - """, channel_id) - - timestamp = snowflake_datetime(row['message_id']) - - await app.dispatcher.dispatch( - 'channel', channel_id, 'CHANNEL_PINS_UPDATE', - { - 'channel_id': str(channel_id), - 'last_pin_timestamp': timestamp_(timestamp) - } + """, + channel_id, ) - await send_sys_message(app, channel_id, - MessageType.CHANNEL_PINNED_MESSAGE, - message_id, user_id) + timestamp = snowflake_datetime(row["message_id"]) - return '', 204 + await app.dispatcher.dispatch( + "channel", + channel_id, + "CHANNEL_PINS_UPDATE", + {"channel_id": str(channel_id), "last_pin_timestamp": timestamp_(timestamp)}, + ) + + await send_sys_message( + app, channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id + ) + + return "", 204 -@bp.route('//pins/', methods=['DELETE']) +@bp.route("//pins/", methods=["DELETE"]) async def delete_pin(channel_id, message_id): user_id = await token_check() _ctype, guild_id = await channel_check(user_id, channel_id) - await channel_perm_check(user_id, channel_id, 'manage_messages') + await channel_perm_check(user_id, channel_id, "manage_messages") - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM channel_pins WHERE channel_id = $1 AND message_id = $2 - """, channel_id, message_id) + """, + channel_id, + message_id, + ) - row = await app.db.fetchrow(""" + row = await app.db.fetchrow( + """ SELECT message_id FROM channel_pins WHERE channel_id = $1 ORDER BY message_id ASC LIMIT 1 - """, channel_id) + """, + channel_id, + ) - timestamp = snowflake_datetime(row['message_id']) + timestamp = snowflake_datetime(row["message_id"]) await app.dispatcher.dispatch( - 'channel', channel_id, 'CHANNEL_PINS_UPDATE', { - 'channel_id': str(channel_id), - 'last_pin_timestamp': timestamp.isoformat() - }) + "channel", + channel_id, + "CHANNEL_PINS_UPDATE", + {"channel_id": str(channel_id), "last_pin_timestamp": timestamp.isoformat()}, + ) - return '', 204 + return "", 204 diff --git a/litecord/blueprints/channel/reactions.py b/litecord/blueprints/channel/reactions.py index 7ede137..4c8dabd 100644 --- a/litecord/blueprints/channel/reactions.py +++ b/litecord/blueprints/channel/reactions.py @@ -26,17 +26,15 @@ from logbook import Logger from litecord.utils import async_map from litecord.blueprints.auth import token_check from litecord.blueprints.checks import channel_check, channel_perm_check -from litecord.blueprints.channel.messages import ( - query_tuple_from_args, extract_limit -) +from litecord.blueprints.channel.messages import query_tuple_from_args, extract_limit from litecord.enums import GUILD_CHANS log = Logger(__name__) -bp = Blueprint('channel_reactions', __name__) +bp = Blueprint("channel_reactions", __name__) -BASEPATH = '//messages//reactions' +BASEPATH = "//messages//reactions" class EmojiType(IntEnum): @@ -51,16 +49,14 @@ def emoji_info_from_str(emoji: str) -> tuple: # unicode emoji just have the raw unicode. # try checking if the emoji is custom or unicode - emoji_type = 0 if ':' in emoji else 1 + emoji_type = 0 if ":" in emoji else 1 emoji_type = EmojiType(emoji_type) # extract the emoji id OR the unicode value of the emoji # depending if it is custom or not - emoji_id = (int(emoji.split(':')[1]) - if emoji_type == EmojiType.CUSTOM - else emoji) + emoji_id = int(emoji.split(":")[1]) if emoji_type == EmojiType.CUSTOM else emoji - emoji_name = emoji.split(':')[0] + emoji_name = emoji.split(":")[0] return emoji_type, emoji_id, emoji_name @@ -68,27 +64,27 @@ def emoji_info_from_str(emoji: str) -> tuple: def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict: print(emoji_type, emoji_id, emoji_name) return { - 'id': None if emoji_type == EmojiType.UNICODE else emoji_id, - 'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id + "id": None if emoji_type == EmojiType.UNICODE else emoji_id, + "name": emoji_name if emoji_type == EmojiType.UNICODE else emoji_id, } def _make_payload(user_id, channel_id, message_id, partial): return { - 'user_id': str(user_id), - 'channel_id': str(channel_id), - 'message_id': str(message_id), - 'emoji': partial + "user_id": str(user_id), + "channel_id": str(channel_id), + "message_id": str(message_id), + "emoji": partial, } -@bp.route(f'{BASEPATH}//@me', methods=['PUT']) +@bp.route(f"{BASEPATH}//@me", methods=["PUT"]) async def add_reaction(channel_id: int, message_id: int, emoji: str): """Put a reaction.""" user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) - await channel_perm_check(user_id, channel_id, 'read_history') + await channel_perm_check(user_id, channel_id, "read_history") emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) @@ -97,52 +93,64 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str): # ADD_REACTIONS is only checked when this is the first # reaction in a message. - reaction_count = await app.db.fetchval(""" + reaction_count = await app.db.fetchval( + """ SELECT COUNT(*) FROM message_reactions WHERE message_id = $1 AND emoji_type = $2 AND emoji_id = $3 AND emoji_text = $4 - """, message_id, emoji_type, emoji_id, emoji_text) + """, + message_id, + emoji_type, + emoji_id, + emoji_text, + ) if reaction_count == 0: - await channel_perm_check(user_id, channel_id, 'add_reactions') + await channel_perm_check(user_id, channel_id, "add_reactions") await app.db.execute( """ INSERT INTO message_reactions (message_id, user_id, emoji_type, emoji_id, emoji_text) VALUES ($1, $2, $3, $4, $5) - """, message_id, user_id, emoji_type, - + """, + message_id, + user_id, + emoji_type, # if it is custom, we put the emoji_id on emoji_id # column, if it isn't, we put it on emoji_text # column. - emoji_id, emoji_text + emoji_id, + emoji_text, ) partial = partial_emoji(emoji_type, emoji_id, emoji_name) payload = _make_payload(user_id, channel_id, message_id, partial) if ctype in GUILD_CHANS: - payload['guild_id'] = str(guild_id) + payload["guild_id"] = str(guild_id) await app.dispatcher.dispatch( - 'channel', channel_id, 'MESSAGE_REACTION_ADD', payload) + "channel", channel_id, "MESSAGE_REACTION_ADD", payload + ) - return '', 204 + return "", 204 def emoji_sql(emoji_type, emoji_id, emoji_name, param=4): """Extract SQL clauses to search for specific emoji in the message_reactions table.""" - param = f'${param}' + param = f"${param}" # know which column to filter with - where_ext = (f'AND emoji_id = {param}' - if emoji_type == EmojiType.CUSTOM else - f'AND emoji_text = {param}') + where_ext = ( + f"AND emoji_id = {param}" + if emoji_type == EmojiType.CUSTOM + else f"AND emoji_text = {param}" + ) # which emoji to remove (custom or unicode) main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name @@ -157,8 +165,7 @@ def _emoji_sql_simple(emoji: str, param=4): return emoji_sql(emoji_type, emoji_id, emoji_name, param) -async def remove_reaction(channel_id: int, message_id: int, - user_id: int, emoji: str): +async def remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str): ctype, guild_id = await channel_check(user_id, channel_id) emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) @@ -171,39 +178,45 @@ async def remove_reaction(channel_id: int, message_id: int, AND user_id = $2 AND emoji_type = $3 {where_ext} - """, message_id, user_id, emoji_type, main_emoji) + """, + message_id, + user_id, + emoji_type, + main_emoji, + ) partial = partial_emoji(emoji_type, emoji_id, emoji_name) payload = _make_payload(user_id, channel_id, message_id, partial) if ctype in GUILD_CHANS: - payload['guild_id'] = str(guild_id) + payload["guild_id"] = str(guild_id) await app.dispatcher.dispatch( - 'channel', channel_id, 'MESSAGE_REACTION_REMOVE', payload) + "channel", channel_id, "MESSAGE_REACTION_REMOVE", payload + ) -@bp.route(f'{BASEPATH}//@me', methods=['DELETE']) +@bp.route(f"{BASEPATH}//@me", methods=["DELETE"]) async def remove_own_reaction(channel_id, message_id, emoji): """Remove a reaction.""" user_id = await token_check() await remove_reaction(channel_id, message_id, user_id, emoji) - return '', 204 + return "", 204 -@bp.route(f'{BASEPATH}//', methods=['DELETE']) +@bp.route(f"{BASEPATH}//", methods=["DELETE"]) async def remove_user_reaction(channel_id, message_id, emoji, other_id): """Remove a reaction made by another user.""" user_id = await token_check() - await channel_perm_check(user_id, channel_id, 'manage_messages') + await channel_perm_check(user_id, channel_id, "manage_messages") await remove_reaction(channel_id, message_id, other_id, emoji) - return '', 204 + return "", 204 -@bp.route(f'{BASEPATH}/', methods=['GET']) +@bp.route(f"{BASEPATH}/", methods=["GET"]) async def list_users_reaction(channel_id, message_id, emoji): """Get the list of all users who reacted with a certain emoji.""" user_id = await token_check() @@ -215,42 +228,49 @@ async def list_users_reaction(channel_id, message_id, emoji): limit = extract_limit(request, 25) before, after = query_tuple_from_args(request.args, limit) - before_clause = 'AND user_id < $2' if before else '' - after_clause = 'AND user_id > $3' if after else '' + before_clause = "AND user_id < $2" if before else "" + after_clause = "AND user_id > $3" if after else "" where_ext, main_emoji = _emoji_sql_simple(emoji, 4) - rows = await app.db.fetch(f""" + rows = await app.db.fetch( + f""" SELECT user_id FROM message_reactions WHERE message_id = $1 {before_clause} {after_clause} {where_ext} - """, message_id, before, after, main_emoji) + """, + message_id, + before, + after, + main_emoji, + ) - user_ids = [r['user_id'] for r in rows] + user_ids = [r["user_id"] for r in rows] users = await async_map(app.storage.get_user, user_ids) return jsonify(users) -@bp.route(f'{BASEPATH}', methods=['DELETE']) +@bp.route(f"{BASEPATH}", methods=["DELETE"]) async def remove_all_reactions(channel_id, message_id): """Remove all reactions in a message.""" user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) - await channel_perm_check(user_id, channel_id, 'manage_messages') + await channel_perm_check(user_id, channel_id, "manage_messages") - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM message_reactions WHERE message_id = $1 - """, message_id) + """, + message_id, + ) - payload = { - 'channel_id': str(channel_id), - 'message_id': str(message_id), - } + payload = {"channel_id": str(channel_id), "message_id": str(message_id)} if ctype in GUILD_CHANS: - payload['guild_id'] = str(guild_id) + payload["guild_id"] = str(guild_id) await app.dispatcher.dispatch( - 'channel', channel_id, 'MESSAGE_REACTION_REMOVE_ALL', payload) + "channel", channel_id, "MESSAGE_REACTION_REMOVE_ALL", payload + ) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 2c9824b..a4b6c45 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -28,24 +28,26 @@ from litecord.auth import token_check from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags from litecord.errors import ChannelNotFound, Forbidden, BadRequest from litecord.schemas import ( - validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE, + validate, + CHAN_UPDATE, + CHAN_OVERWRITE, + SEARCH_CHANNEL, + GROUP_DM_UPDATE, BULK_DELETE, ) from litecord.blueprints.checks import channel_check, channel_perm_check from litecord.system_messages import send_sys_message -from litecord.blueprints.dm_channels import ( - gdm_remove_recipient, gdm_destroy -) +from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy from litecord.utils import search_result_from_list from litecord.embed.messages import process_url_embed, msg_update_embeds from litecord.snowflake import snowflake_datetime log = Logger(__name__) -bp = Blueprint('channels', __name__) +bp = Blueprint("channels", __name__) -@bp.route('/', methods=['GET']) +@bp.route("/", methods=["GET"]) async def get_channel(channel_id): """Get a single channel's information""" user_id = await token_check() @@ -56,7 +58,7 @@ async def get_channel(channel_id): chan = await app.storage.get_channel(channel_id) if not chan: - raise ChannelNotFound('single channel not found') + raise ChannelNotFound("single channel not found") return jsonify(chan) @@ -64,106 +66,129 @@ async def get_channel(channel_id): async def __guild_chan_sql(guild_id, channel_id, field: str) -> str: """Update a guild's channel id field to NULL, if it was set to the given channel id before.""" - return await app.db.execute(f""" + return await app.db.execute( + f""" UPDATE guilds SET {field} = NULL WHERE guilds.id = $1 AND {field} = $2 - """, guild_id, channel_id) + """, + guild_id, + channel_id, + ) async def _update_guild_chan_text(guild_id: int, channel_id: int): - res_embed = await __guild_chan_sql( - guild_id, channel_id, 'embed_channel_id') + res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id") - res_widget = await __guild_chan_sql( - guild_id, channel_id, 'widget_channel_id') + res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id") - res_system = await __guild_chan_sql( - guild_id, channel_id, 'system_channel_id') + res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id") # if none of them were actually updated, # ignore and dont dispatch anything - if 'UPDATE 1' not in (res_embed, res_widget, res_system): + if "UPDATE 1" not in (res_embed, res_widget, res_system): return # at least one of the fields were updated, # dispatch GUILD_UPDATE guild = await app.storage.get_guild(guild_id) - await app.dispatcher.dispatch_guild( - guild_id, 'GUILD_UPDATE', guild) + await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild) async def _update_guild_chan_voice(guild_id: int, channel_id: int): - res = await __guild_chan_sql(guild_id, channel_id, 'afk_channel_id') + res = await __guild_chan_sql(guild_id, channel_id, "afk_channel_id") # guild didnt update - if res == 'UPDATE 0': + if res == "UPDATE 0": return guild = await app.storage.get_guild(guild_id) - await app.dispatcher.dispatch_guild( - guild_id, 'GUILD_UPDATE', guild) + await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild) async def _update_guild_chan_cat(guild_id: int, channel_id: int): # get all channels that were childs of the category - childs = await app.db.fetch(""" + childs = await app.db.fetch( + """ SELECT id FROM guild_channels WHERE guild_id = $1 AND parent_id = $2 - """, guild_id, channel_id) - childs = [c['id'] for c in childs] + """, + guild_id, + channel_id, + ) + childs = [c["id"] for c in childs] # update every child channel to parent_id = NULL - await app.db.execute(""" + await app.db.execute( + """ UPDATE guild_channels SET parent_id = NULL WHERE guild_id = $1 AND parent_id = $2 - """, guild_id, channel_id) + """, + guild_id, + channel_id, + ) # tell all people in the guild of the category removal for child_id in childs: child = await app.storage.get_channel(child_id) - await app.dispatcher.dispatch_guild( - guild_id, 'CHANNEL_UPDATE', child - ) + await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child) async def delete_messages(channel_id): - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM channel_pins WHERE channel_id = $1 - """, channel_id) + """, + channel_id, + ) - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM user_read_state WHERE channel_id = $1 - """, channel_id) + """, + channel_id, + ) - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM messages WHERE channel_id = $1 - """, channel_id) + """, + channel_id, + ) async def guild_cleanup(channel_id): - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM channel_overwrites WHERE channel_id = $1 - """, channel_id) + """, + channel_id, + ) - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM invites WHERE channel_id = $1 - """, channel_id) + """, + channel_id, + ) - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM webhooks WHERE channel_id = $1 - """, channel_id) + """, + channel_id, + ) -@bp.route('/', methods=['DELETE']) +@bp.route("/", methods=["DELETE"]) async def close_channel(channel_id): """Close or delete a channel.""" user_id = await token_check() @@ -184,9 +209,8 @@ async def close_channel(channel_id): }[ctype] main_tbl = { - ChannelType.GUILD_TEXT: 'guild_text_channels', - ChannelType.GUILD_VOICE: 'guild_voice_channels', - + ChannelType.GUILD_TEXT: "guild_text_channels", + ChannelType.GUILD_VOICE: "guild_voice_channels", # TODO: categories? }[ctype] @@ -199,29 +223,37 @@ async def close_channel(channel_id): await delete_messages(channel_id) await guild_cleanup(channel_id) - await app.db.execute(f""" + await app.db.execute( + f""" DELETE FROM {main_tbl} WHERE id = $1 - """, channel_id) + """, + channel_id, + ) - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM guild_channels WHERE id = $1 - """, channel_id) + """, + channel_id, + ) - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM channels WHERE id = $1 - """, channel_id) + """, + channel_id, + ) # clean its member list representation - lazy_guilds = app.dispatcher.backends['lazy_guild'] + lazy_guilds = app.dispatcher.backends["lazy_guild"] lazy_guilds.remove_channel(channel_id) - await app.dispatcher.dispatch_guild( - guild_id, 'CHANNEL_DELETE', chan) + await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_DELETE", chan) - await app.dispatcher.remove('channel', channel_id) + await app.dispatcher.remove("channel", channel_id) return jsonify(chan) if ctype == ChannelType.DM: @@ -231,27 +263,34 @@ async def close_channel(channel_id): # instead, we close the channel for the user that is making # the request via removing the link between them and # the channel on dm_channel_state - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM dm_channel_state WHERE user_id = $1 AND dm_id = $2 - """, user_id, channel_id) + """, + user_id, + channel_id, + ) # unsubscribe - await app.dispatcher.unsub('channel', channel_id, user_id) + await app.dispatcher.unsub("channel", channel_id, user_id) # nothing happens to the other party of the dm channel - await app.dispatcher.dispatch_user(user_id, 'CHANNEL_DELETE', chan) + await app.dispatcher.dispatch_user(user_id, "CHANNEL_DELETE", chan) return jsonify(chan) if ctype == ChannelType.GROUP_DM: await gdm_remove_recipient(channel_id, user_id) - gdm_count = await app.db.fetchval(""" + gdm_count = await app.db.fetchval( + """ SELECT COUNT(*) FROM group_dm_members WHERE id = $1 - """, channel_id) + """, + channel_id, + ) if gdm_count == 0: # destroy dm @@ -261,11 +300,15 @@ async def close_channel(channel_id): async def _update_pos(channel_id, pos: int): - await app.db.execute(""" + await app.db.execute( + """ UPDATE guild_channels SET position = $1 WHERE id = $2 - """, pos, channel_id) + """, + pos, + channel_id, + ) async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]): @@ -274,20 +317,19 @@ async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]): continue chan = await app.storage.get_channel(channel_id) - await app.dispatcher.dispatch( - 'guild', guild_id, 'CHANNEL_UPDATE', chan) + await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan) async def _process_overwrites(channel_id: int, overwrites: list): for overwrite in overwrites: # 0 for member overwrite, 1 for role overwrite - target_type = 0 if overwrite['type'] == 'member' else 1 - target_role = None if target_type == 0 else overwrite['id'] - target_user = overwrite['id'] if target_type == 0 else None + target_type = 0 if overwrite["type"] == "member" else 1 + target_role = None if target_type == 0 else overwrite["id"] + target_user = overwrite["id"] if target_type == 0 else None - col_name = 'target_user' if target_type == 0 else 'target_role' - constraint_name = f'channel_overwrites_{col_name}_uniq' + col_name = "target_user" if target_type == 0 else "target_role" + constraint_name = f"channel_overwrites_{col_name}_uniq" await app.db.execute( f""" @@ -301,53 +343,66 @@ async def _process_overwrites(channel_id: int, overwrites: list): UPDATE SET allow = $5, deny = $6 """, - channel_id, target_type, - target_role, target_user, - overwrite['allow'], overwrite['deny']) + channel_id, + target_type, + target_role, + target_user, + overwrite["allow"], + overwrite["deny"], + ) -@bp.route('//permissions/', methods=['PUT']) +@bp.route("//permissions/", methods=["PUT"]) async def put_channel_overwrite(channel_id: int, overwrite_id: int): """Insert or modify a channel overwrite.""" user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) if ctype not in GUILD_CHANS: - raise ChannelNotFound('Only usable for guild channels.') + raise ChannelNotFound("Only usable for guild channels.") - await channel_perm_check(user_id, guild_id, 'manage_roles') + await channel_perm_check(user_id, guild_id, "manage_roles") j = validate( # inserting a fake id on the payload so validation passes through - {**await request.get_json(), **{'id': -1}}, - CHAN_OVERWRITE + {**await request.get_json(), **{"id": -1}}, + CHAN_OVERWRITE, ) - await _process_overwrites(channel_id, [{ - 'allow': j['allow'], - 'deny': j['deny'], - 'type': j['type'], - 'id': overwrite_id - }]) + await _process_overwrites( + channel_id, + [ + { + "allow": j["allow"], + "deny": j["deny"], + "type": j["type"], + "id": overwrite_id, + } + ], + ) await _mass_chan_update(guild_id, [channel_id]) - return '', 204 + return "", 204 async def _update_channel_common(channel_id, guild_id: int, j: dict): - if 'name' in j: - await app.db.execute(""" + if "name" in j: + await app.db.execute( + """ UPDATE guild_channels SET name = $1 WHERE id = $2 - """, j['name'], channel_id) + """, + j["name"], + channel_id, + ) - if 'position' in j: + if "position" in j: channel_data = await app.storage.get_channel_data(guild_id) chans = [None] * len(channel_data) for chandata in channel_data: - chans.insert(chandata['position'], int(chandata['id'])) + chans.insert(chandata["position"], int(chandata["id"])) # are we changing to the left or to the right? @@ -358,7 +413,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict): # channelN-1 going to the position channel2 # was occupying. current_pos = chans.index(channel_id) - new_pos = j['position'] + new_pos = j["position"] # if the new position is bigger than the current one, # we're making a left shift of all the channels that are @@ -366,113 +421,136 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict): left_shift = new_pos > current_pos # find all channels that we'll have to shift - shift_block = (chans[current_pos:new_pos] - if left_shift else - chans[new_pos:current_pos] - ) + shift_block = ( + chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos] + ) shift = -1 if left_shift else 1 # do the shift (to the left or to the right) - await app.db.executemany(""" + await app.db.executemany( + """ UPDATE guild_channels SET position = position + $1 WHERE id = $2 - """, [(shift, chan_id) for chan_id in shift_block]) + """, + [(shift, chan_id) for chan_id in shift_block], + ) await _mass_chan_update(guild_id, shift_block) # since theres now an empty slot, move current channel to it await _update_pos(channel_id, new_pos) - if 'channel_overwrites' in j: - overwrites = j['channel_overwrites'] + if "channel_overwrites" in j: + overwrites = j["channel_overwrites"] await _process_overwrites(channel_id, overwrites) async def _common_guild_chan(channel_id, j: dict): # common updates to the guild_channels table - for field in [field for field in j.keys() - if field in ('nsfw', 'parent_id')]: - await app.db.execute(f""" + for field in [field for field in j.keys() if field in ("nsfw", "parent_id")]: + await app.db.execute( + f""" UPDATE guild_channels SET {field} = $1 WHERE id = $2 - """, j[field], channel_id) + """, + j[field], + channel_id, + ) async def _update_text_channel(channel_id: int, j: dict, _user_id: int): # first do the specific ones related to guild_text_channels - for field in [field for field in j.keys() - if field in ('topic', 'rate_limit_per_user')]: - await app.db.execute(f""" + for field in [ + field for field in j.keys() if field in ("topic", "rate_limit_per_user") + ]: + await app.db.execute( + f""" UPDATE guild_text_channels SET {field} = $1 WHERE id = $2 - """, j[field], channel_id) + """, + j[field], + channel_id, + ) await _common_guild_chan(channel_id, j) async def _update_voice_channel(channel_id: int, j: dict, _user_id: int): # first do the specific ones in guild_voice_channels - for field in [field for field in j.keys() - if field in ('bitrate', 'user_limit')]: - await app.db.execute(f""" + for field in [field for field in j.keys() if field in ("bitrate", "user_limit")]: + await app.db.execute( + f""" UPDATE guild_voice_channels SET {field} = $1 WHERE id = $2 - """, j[field], channel_id) + """, + j[field], + channel_id, + ) # yes, i'm letting voice channels have nsfw, you cant stop me await _common_guild_chan(channel_id, j) async def _update_group_dm(channel_id: int, j: dict, author_id: int): - if 'name' in j: - await app.db.execute(""" + if "name" in j: + await app.db.execute( + """ UPDATE group_dm_channels SET name = $1 WHERE id = $2 - """, j['name'], channel_id) + """, + j["name"], + channel_id, + ) await send_sys_message( app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id ) - if 'icon' in j: + if "icon" in j: new_icon = await app.icons.update( - 'channel-icons', channel_id, j['icon'], always_icon=True + "channel-icons", channel_id, j["icon"], always_icon=True ) - await app.db.execute(""" + await app.db.execute( + """ UPDATE group_dm_channels SET icon = $1 WHERE id = $2 - """, new_icon.icon_hash, channel_id) + """, + new_icon.icon_hash, + channel_id, + ) await send_sys_message( app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id ) -@bp.route('/', methods=['PUT', 'PATCH']) +@bp.route("/", methods=["PUT", "PATCH"]) async def update_channel(channel_id): """Update a channel's information""" user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) - if ctype not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE, - ChannelType.GROUP_DM): - raise ChannelNotFound('unable to edit unsupported chan type') + if ctype not in ( + ChannelType.GUILD_TEXT, + ChannelType.GUILD_VOICE, + ChannelType.GROUP_DM, + ): + raise ChannelNotFound("unable to edit unsupported chan type") is_guild = ctype in GUILD_CHANS if is_guild: - await channel_perm_check(user_id, channel_id, 'manage_channels') + await channel_perm_check(user_id, channel_id, "manage_channels") - j = validate(await request.get_json(), - CHAN_UPDATE if is_guild else GROUP_DM_UPDATE) + j = validate(await request.get_json(), CHAN_UPDATE if is_guild else GROUP_DM_UPDATE) # TODO: categories update_handler = { @@ -489,30 +567,32 @@ async def update_channel(channel_id): chan = await app.storage.get_channel(channel_id) if is_guild: - await app.dispatcher.dispatch( - 'guild', guild_id, 'CHANNEL_UPDATE', chan) + await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan) else: - await app.dispatcher.dispatch( - 'channel', channel_id, 'CHANNEL_UPDATE', chan) + await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan) return jsonify(chan) -@bp.route('//typing', methods=['POST']) +@bp.route("//typing", methods=["POST"]) async def trigger_typing(channel_id): user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) - await app.dispatcher.dispatch('channel', channel_id, 'TYPING_START', { - 'channel_id': str(channel_id), - 'user_id': str(user_id), - 'timestamp': int(time.time()), + await app.dispatcher.dispatch( + "channel", + channel_id, + "TYPING_START", + { + "channel_id": str(channel_id), + "user_id": str(user_id), + "timestamp": int(time.time()), + # guild_id for lazy guilds + "guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None, + }, + ) - # guild_id for lazy guilds - 'guild_id': str(guild_id) if ctype == ChannelType.GUILD_TEXT else None, - }) - - return '', 204 + return "", 204 async def channel_ack(user_id, guild_id, channel_id, message_id: int = None): @@ -521,7 +601,8 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None): if not message_id: message_id = await app.storage.chan_last_message(channel_id) - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO user_read_state (user_id, channel_id, last_message_id, mention_count) VALUES @@ -532,26 +613,31 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None): SET last_message_id = $3, mention_count = 0 WHERE user_read_state.user_id = $1 AND user_read_state.channel_id = $2 - """, user_id, channel_id, message_id) + """, + user_id, + channel_id, + message_id, + ) if guild_id: await app.dispatcher.dispatch_user_guild( - user_id, guild_id, 'MESSAGE_ACK', { - 'message_id': str(message_id), - 'channel_id': str(channel_id) - }) + user_id, + guild_id, + "MESSAGE_ACK", + {"message_id": str(message_id), "channel_id": str(channel_id)}, + ) else: # we don't use ChannelDispatcher here because since # guild_id is None, all user devices are already subscribed # to the given channel (a dm or a group dm) await app.dispatcher.dispatch_user( - user_id, 'MESSAGE_ACK', { - 'message_id': str(message_id), - 'channel_id': str(channel_id) - }) + user_id, + "MESSAGE_ACK", + {"message_id": str(message_id), "channel_id": str(channel_id)}, + ) -@bp.route('//messages//ack', methods=['POST']) +@bp.route("//messages//ack", methods=["POST"]) async def ack_channel(channel_id, message_id): """Acknowledge a channel.""" user_id = await token_check() @@ -562,40 +648,47 @@ async def ack_channel(channel_id, message_id): await channel_ack(user_id, guild_id, channel_id, message_id) - return jsonify({ - # token seems to be used for - # data collection activities, - # so we never use it. - 'token': None - }) + return jsonify( + { + # token seems to be used for + # data collection activities, + # so we never use it. + "token": None + } + ) -@bp.route('//messages/ack', methods=['DELETE']) +@bp.route("//messages/ack", methods=["DELETE"]) async def delete_read_state(channel_id): """Delete the read state of a channel.""" user_id = await token_check() await channel_check(user_id, channel_id) - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM user_read_state WHERE user_id = $1 AND channel_id = $2 - """, user_id, channel_id) + """, + user_id, + channel_id, + ) - return '', 204 + return "", 204 -@bp.route('//messages/search', methods=['GET']) +@bp.route("//messages/search", methods=["GET"]) async def _search_channel(channel_id): """Search in DMs or group DMs""" user_id = await token_check() await channel_check(user_id, channel_id) - await channel_perm_check(user_id, channel_id, 'read_messages') + await channel_perm_check(user_id, channel_id, "read_messages") j = validate(dict(request.args), SEARCH_CHANNEL) # main search query # the context (before/after) columns are copied from the guilds blueprint. - rows = await app.db.fetch(f""" + rows = await app.db.fetch( + f""" SELECT orig.id AS current_id, COUNT(*) OVER() AS total_results, array((SELECT messages.id AS before_id @@ -611,28 +704,40 @@ async def _search_channel(channel_id): ORDER BY orig.id DESC LIMIT 50 OFFSET $2 - """, channel_id, j['offset'], j['content']) + """, + channel_id, + j["offset"], + j["content"], + ) return jsonify(await search_result_from_list(rows)) + # NOTE that those functions stay here until some other # route or code wants it. async def _msg_update_flags(message_id: int, flags: int): - await app.db.execute(""" + await app.db.execute( + """ UPDATE messages SET flags = $1 WHERE id = $2 - """, flags, message_id) + """, + flags, + message_id, + ) async def _msg_get_flags(message_id: int): - return await app.db.fetchval(""" + return await app.db.fetchval( + """ SELECT flags FROM messages WHERE id = $1 - """, message_id) + """, + message_id, + ) async def _msg_set_flags(message_id: int, new_flags: int): @@ -647,8 +752,9 @@ async def _msg_unset_flags(message_id: int, unset_flags: int): await _msg_update_flags(message_id, flags) -@bp.route('//messages//suppress-embeds', - methods=['POST']) +@bp.route( + "//messages//suppress-embeds", methods=["POST"] +) async def suppress_embeds(channel_id: int, message_id: int): """Toggle the embeds in a message. @@ -661,29 +767,27 @@ async def suppress_embeds(channel_id: int, message_id: int): # the checks here have been copied from the delete_message() # handler on blueprints.channel.messages. maybe we can combine # them someday? - author_id = await app.db.fetchval(""" + author_id = await app.db.fetchval( + """ SELECT author_id FROM messages WHERE messages.id = $1 - """, message_id) + """, + message_id, + ) - by_perms = await channel_perm_check( - user_id, channel_id, 'manage_messages', False) + by_perms = await channel_perm_check(user_id, channel_id, "manage_messages", False) by_author = author_id == user_id can_suppress = by_perms or by_author if not can_suppress: - raise Forbidden('Not enough permissions.') + raise Forbidden("Not enough permissions.") - j = validate( - await request.get_json(), - {'suppress': {'type': 'boolean'}}, - ) + j = validate(await request.get_json(), {"suppress": {"type": "boolean"}}) - suppress = j['suppress'] + suppress = j["suppress"] message = await app.storage.get_message(message_id) - url_embeds = sum( - 1 for embed in message['embeds'] if embed['type'] == 'url') + url_embeds = sum(1 for embed in message["embeds"] if embed["type"] == "url") # NOTE for any future self. discord doing flags an optional thing instead # of just giving 0 is a pretty bad idea because now i have to deal with @@ -693,8 +797,7 @@ async def suppress_embeds(channel_id: int, message_id: int): # delete all embeds then dispatch an update await _msg_set_flags(message_id, MessageFlags.suppress_embeds) - message['flags'] = \ - message.get('flags', 0) | MessageFlags.suppress_embeds + message["flags"] = message.get("flags", 0) | MessageFlags.suppress_embeds await msg_update_embeds(message, [], app.storage, app.dispatcher) elif not suppress and not url_embeds: @@ -702,30 +805,29 @@ async def suppress_embeds(channel_id: int, message_id: int): await _msg_unset_flags(message_id, MessageFlags.suppress_embeds) try: - message.pop('flags') + message.pop("flags") except KeyError: pass app.sched.spawn( process_url_embed( - app.config, app.storage, app.dispatcher, app.session, - message + app.config, app.storage, app.dispatcher, app.session, message ) ) - return '', 204 + return "", 204 -@bp.route('//messages/bulk-delete', methods=['POST']) +@bp.route("//messages/bulk-delete", methods=["POST"]) async def bulk_delete(channel_id: int): user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) guild_id = guild_id if ctype in GUILD_CHANS else None - await channel_perm_check(user_id, channel_id, 'manage_messages') + await channel_perm_check(user_id, channel_id, "manage_messages") j = validate(await request.get_json(), BULK_DELETE) - message_ids = set(j['messages']) + message_ids = set(j["messages"]) # as per discord behavior, if any id here is older than two weeks, # we must error. a cuter behavior would be returning the message ids @@ -738,25 +840,28 @@ async def bulk_delete(channel_id: int): raise BadRequest(50034) payload = { - 'guild_id': str(guild_id), - 'channel_id': str(channel_id), - 'ids': list(map(str, message_ids)), + "guild_id": str(guild_id), + "channel_id": str(channel_id), + "ids": list(map(str, message_ids)), } # payload.guild_id is optional in the event, not nullable. if guild_id is None: - payload.pop('guild_id') + payload.pop("guild_id") - res = await app.db.execute(""" + res = await app.db.execute( + """ DELETE FROM messages WHERE channel_id = $1 AND ARRAY[id] <@ $2::bigint[] - """, channel_id, list(message_ids)) + """, + channel_id, + list(message_ids), + ) - if res == 'DELETE 0': - raise BadRequest('No messages were removed') + if res == "DELETE 0": + raise BadRequest("No messages were removed") - await app.dispatcher.dispatch( - 'channel', channel_id, 'MESSAGE_DELETE_BULK', payload) - return '', 204 + await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_DELETE_BULK", payload) + return "", 204 diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index e897c21..5141f5b 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -23,46 +23,57 @@ from quart import current_app as app from litecord.enums import ChannelType, GUILD_CHANS from litecord.errors import ( - GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions + GuildNotFound, + ChannelNotFound, + Forbidden, + MissingPermissions, ) from litecord.permissions import base_permissions, get_permissions async def guild_check(user_id: int, guild_id: int): """Check if a user is in a guild.""" - joined_at = await app.db.fetchval(""" + joined_at = await app.db.fetchval( + """ SELECT joined_at FROM members WHERE user_id = $1 AND guild_id = $2 - """, user_id, guild_id) + """, + user_id, + guild_id, + ) if not joined_at: - raise GuildNotFound('guild not found') + raise GuildNotFound("guild not found") async def guild_owner_check(user_id: int, guild_id: int): """Check if a user is the owner of the guild.""" - owner_id = await app.db.fetchval(""" + owner_id = await app.db.fetchval( + """ SELECT owner_id FROM guilds WHERE guilds.id = $1 - """, guild_id) + """, + guild_id, + ) if not owner_id: raise GuildNotFound() if user_id != owner_id: - raise Forbidden('You are not the owner of the guild') + raise Forbidden("You are not the owner of the guild") -async def channel_check(user_id, channel_id, *, - only: Union[ChannelType, List[ChannelType]] = None): +async def channel_check( + user_id, channel_id, *, only: Union[ChannelType, List[ChannelType]] = None +): """Check if the current user is authorized to read the channel's information.""" chan_type = await app.storage.get_chan_type(channel_id) if chan_type is None: - raise ChannelNotFound('channel type not found') + raise ChannelNotFound("channel type not found") ctype = ChannelType(chan_type) @@ -70,14 +81,17 @@ async def channel_check(user_id, channel_id, *, only = [only] if only and ctype not in only: - raise ChannelNotFound('invalid channel type') + raise ChannelNotFound("invalid channel type") if ctype in GUILD_CHANS: - guild_id = await app.db.fetchval(""" + guild_id = await app.db.fetchval( + """ SELECT guild_id FROM guild_channels WHERE guild_channels.id = $1 - """, channel_id) + """, + channel_id, + ) await guild_check(user_id, guild_id) return ctype, guild_id @@ -87,11 +101,14 @@ async def channel_check(user_id, channel_id, *, return ctype, peer_id if ctype == ChannelType.GROUP_DM: - owner_id = await app.db.fetchval(""" + owner_id = await app.db.fetchval( + """ SELECT owner_id FROM group_dm_channels WHERE id = $1 - """, channel_id) + """, + channel_id, + ) return ctype, owner_id @@ -102,18 +119,17 @@ async def guild_perm_check(user_id, guild_id, permission: str): hasperm = getattr(base_perms.bits, permission) if not hasperm: - raise MissingPermissions('Missing permissions.') + raise MissingPermissions("Missing permissions.") return bool(hasperm) -async def channel_perm_check(user_id, channel_id, - permission: str, raise_err=True): +async def channel_perm_check(user_id, channel_id, permission: str, raise_err=True): """Check channel permissions for a user.""" base_perms = await get_permissions(user_id, channel_id) hasperm = getattr(base_perms.bits, permission) if not hasperm and raise_err: - raise MissingPermissions('Missing permissions.') + raise MissingPermissions("Missing permissions.") return bool(hasperm) diff --git a/litecord/blueprints/dm_channels.py b/litecord/blueprints/dm_channels.py index ac92855..76bd24a 100644 --- a/litecord/blueprints/dm_channels.py +++ b/litecord/blueprints/dm_channels.py @@ -29,21 +29,29 @@ from litecord.system_messages import send_sys_message from litecord.pubsub.channel import gdm_recipient_view log = Logger(__name__) -bp = Blueprint('dm_channels', __name__) +bp = Blueprint("dm_channels", __name__) async def _raw_gdm_add(channel_id, user_id): - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO group_dm_members (id, member_id) VALUES ($1, $2) - """, channel_id, user_id) + """, + channel_id, + user_id, + ) async def _raw_gdm_remove(channel_id, user_id): - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM group_dm_members WHERE id = $1 AND member_id = $2 - """, channel_id, user_id) + """, + channel_id, + user_id, + ) async def gdm_create(user_id, peer_id) -> int: @@ -53,24 +61,32 @@ async def gdm_create(user_id, peer_id) -> int: """ channel_id = get_snowflake() - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO channels (id, channel_type) VALUES ($1, $2) - """, channel_id, ChannelType.GROUP_DM.value) + """, + channel_id, + ChannelType.GROUP_DM.value, + ) - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO group_dm_channels (id, owner_id, name, icon) VALUES ($1, $2, NULL, NULL) - """, channel_id, user_id) + """, + channel_id, + user_id, + ) await _raw_gdm_add(channel_id, user_id) await _raw_gdm_add(channel_id, peer_id) - await app.dispatcher.sub('channel', channel_id, user_id) - await app.dispatcher.sub('channel', channel_id, peer_id) + await app.dispatcher.sub("channel", channel_id, user_id) + await app.dispatcher.sub("channel", channel_id, peer_id) chan = await app.storage.get_channel(channel_id) - await app.dispatcher.dispatch('channel', channel_id, 'CHANNEL_CREATE', chan) + await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_CREATE", chan) return channel_id @@ -89,17 +105,16 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None): # the reasoning behind gdm_recipient_view is in its docstring. await app.dispatcher.dispatch( - 'user', peer_id, 'CHANNEL_CREATE', gdm_recipient_view(chan, peer_id)) + "user", peer_id, "CHANNEL_CREATE", gdm_recipient_view(chan, peer_id) + ) - await app.dispatcher.dispatch( - 'channel', channel_id, 'CHANNEL_UPDATE', chan) + await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan) - await app.dispatcher.sub('channel', peer_id) + await app.dispatcher.sub("channel", peer_id) if user_id: await send_sys_message( - app, channel_id, MessageType.RECIPIENT_ADD, - user_id, peer_id + app, channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id ) @@ -116,22 +131,22 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None): chan = await app.storage.get_channel(channel_id) await app.dispatcher.dispatch( - 'user', peer_id, 'CHANNEL_DELETE', gdm_recipient_view(chan, user_id)) + "user", peer_id, "CHANNEL_DELETE", gdm_recipient_view(chan, user_id) + ) - await app.dispatcher.unsub('channel', peer_id) + await app.dispatcher.unsub("channel", peer_id) await app.dispatcher.dispatch( - 'channel', channel_id, 'CHANNEL_RECIPIENT_REMOVE', { - 'channel_id': str(channel_id), - 'user': await app.storage.get_user(peer_id) - } + "channel", + channel_id, + "CHANNEL_RECIPIENT_REMOVE", + {"channel_id": str(channel_id), "user": await app.storage.get_user(peer_id)}, ) author_id = peer_id if user_id is None else user_id await send_sys_message( - app, channel_id, MessageType.RECIPIENT_REMOVE, - author_id, peer_id + app, channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id ) @@ -139,40 +154,51 @@ async def gdm_destroy(channel_id): """Destroy a Group DM.""" chan = await app.storage.get_channel(channel_id) - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM group_dm_members WHERE id = $1 - """, channel_id) - - await app.db.execute(""" - DELETE FROM group_dm_channels - WHERE id = $1 - """, channel_id) - - await app.db.execute(""" - DELETE FROM channels - WHERE id = $1 - """, channel_id) - - await app.dispatcher.dispatch( - 'channel', channel_id, 'CHANNEL_DELETE', chan + """, + channel_id, ) - await app.dispatcher.remove('channel', channel_id) + await app.db.execute( + """ + DELETE FROM group_dm_channels + WHERE id = $1 + """, + channel_id, + ) + + await app.db.execute( + """ + DELETE FROM channels + WHERE id = $1 + """, + channel_id, + ) + + await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_DELETE", chan) + + await app.dispatcher.remove("channel", channel_id) async def gdm_is_member(channel_id: int, user_id: int) -> bool: """Return if the given user is a member of the Group DM.""" - row = await app.db.fetchval(""" + row = await app.db.fetchval( + """ SELECT id FROM group_dm_members WHERE id = $1 AND member_id = $2 - """, channel_id, user_id) + """, + channel_id, + user_id, + ) return row is not None -@bp.route('//recipients/', methods=['PUT']) +@bp.route("//recipients/", methods=["PUT"]) async def add_to_group_dm(dm_chan, peer_id): """Adds a member to a group dm OR creates a group dm.""" user_id = await token_check() @@ -182,8 +208,7 @@ async def add_to_group_dm(dm_chan, peer_id): # other_id is the peer of the dm if the given channel is a dm ctype, other_id = await channel_check( - user_id, dm_chan, - only=[ChannelType.DM, ChannelType.GROUP_DM] + user_id, dm_chan, only=[ChannelType.DM, ChannelType.GROUP_DM] ) # check relationship with the given user id @@ -191,30 +216,24 @@ async def add_to_group_dm(dm_chan, peer_id): friends = await app.user_storage.are_friends_with(user_id, peer_id) if not friends: - raise BadRequest('Cant insert peer into dm') + raise BadRequest("Cant insert peer into dm") if ctype == ChannelType.DM: - dm_chan = await gdm_create( - user_id, other_id - ) + dm_chan = await gdm_create(user_id, other_id) await gdm_add_recipient(dm_chan, peer_id, user_id=user_id) - return jsonify( - await app.storage.get_channel(dm_chan) - ) + return jsonify(await app.storage.get_channel(dm_chan)) -@bp.route('//recipients/', methods=['DELETE']) +@bp.route("//recipients/", methods=["DELETE"]) async def remove_from_group_dm(dm_chan, peer_id): """Remove users from group dm.""" user_id = await token_check() - _ctype, owner_id = await channel_check( - user_id, dm_chan, only=ChannelType.GROUP_DM - ) + _ctype, owner_id = await channel_check(user_id, dm_chan, only=ChannelType.GROUP_DM) if owner_id != user_id: - raise Forbidden('You are now the owner of the group DM') + raise Forbidden("You are now the owner of the group DM") await gdm_remove_recipient(dm_chan, peer_id) - return '', 204 + return "", 204 diff --git a/litecord/blueprints/dms.py b/litecord/blueprints/dms.py index 4289161..975d7ea 100644 --- a/litecord/blueprints/dms.py +++ b/litecord/blueprints/dms.py @@ -30,15 +30,13 @@ from ..snowflake import get_snowflake from .auth import token_check -from litecord.blueprints.dm_channels import ( - gdm_create, gdm_add_recipient -) +from litecord.blueprints.dm_channels import gdm_create, gdm_add_recipient log = Logger(__name__) -bp = Blueprint('dms', __name__) +bp = Blueprint("dms", __name__) -@bp.route('/@me/channels', methods=['GET']) +@bp.route("/@me/channels", methods=["GET"]) async def get_dms(): """Get the open DMs for the user.""" user_id = await token_check() @@ -53,11 +51,15 @@ async def try_dm_state(user_id: int, dm_id: int): Does not do anything if the user is already in the dm state. """ - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO dm_channel_state (user_id, dm_id) VALUES ($1, $2) ON CONFLICT DO NOTHING - """, user_id, dm_id) + """, + user_id, + dm_id, + ) async def jsonify_dm(dm_id: int, user_id: int): @@ -69,12 +71,16 @@ async def create_dm(user_id, recipient_id): """Create a new dm with a user, or get the existing DM id if it already exists.""" - dm_id = await app.db.fetchval(""" + dm_id = await app.db.fetchval( + """ SELECT id FROM dm_channels WHERE (party1_id = $1 OR party2_id = $1) AND (party1_id = $2 OR party2_id = $2) - """, user_id, recipient_id) + """, + user_id, + recipient_id, + ) if dm_id: return await jsonify_dm(dm_id, user_id) @@ -82,15 +88,24 @@ async def create_dm(user_id, recipient_id): # if no dm was found, create a new one dm_id = get_snowflake() - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO channels (id, channel_type) VALUES ($1, $2) - """, dm_id, ChannelType.DM.value) + """, + dm_id, + ChannelType.DM.value, + ) - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO dm_channels (id, party1_id, party2_id) VALUES ($1, $2, $3) - """, dm_id, user_id, recipient_id) + """, + dm_id, + user_id, + recipient_id, + ) # the dm state is something we use # to give the currently "open dms" @@ -103,24 +118,24 @@ async def create_dm(user_id, recipient_id): return await jsonify_dm(dm_id, user_id) -@bp.route('/@me/channels', methods=['POST']) +@bp.route("/@me/channels", methods=["POST"]) async def start_dm(): """Create a DM with a user.""" user_id = await token_check() j = validate(await request.get_json(), CREATE_DM) - recipient_id = j['recipient_id'] + recipient_id = j["recipient_id"] return await create_dm(user_id, recipient_id) -@bp.route('//channels', methods=['POST']) +@bp.route("//channels", methods=["POST"]) async def create_group_dm(p_user_id: int): """Create a DM or a Group DM with user(s).""" user_id = await token_check() assert user_id == p_user_id j = validate(await request.get_json(), CREATE_GROUP_DM) - recipients = j['recipients'] + recipients = j["recipients"] if len(recipients) == 1: # its a group dm with 1 user... a dm! diff --git a/litecord/blueprints/gateway.py b/litecord/blueprints/gateway.py index 05303f0..a50a60e 100644 --- a/litecord/blueprints/gateway.py +++ b/litecord/blueprints/gateway.py @@ -23,37 +23,38 @@ from quart import Blueprint, jsonify, current_app as app from ..auth import token_check -bp = Blueprint('gateway', __name__) +bp = Blueprint("gateway", __name__) def get_gw(): """Get the gateway's web""" - proto = 'wss://' if app.config['IS_SSL'] else 'ws://' + proto = "wss://" if app.config["IS_SSL"] else "ws://" return f'{proto}{app.config["WEBSOCKET_URL"]}' -@bp.route('/gateway') +@bp.route("/gateway") def api_gateway(): """Get the raw URL.""" - return jsonify({ - 'url': get_gw() - }) + return jsonify({"url": get_gw()}) -@bp.route('/gateway/bot') +@bp.route("/gateway/bot") async def api_gateway_bot(): user_id = await token_check() - guild_count = await app.db.fetchval(""" + guild_count = await app.db.fetchval( + """ SELECT COUNT(*) FROM members WHERE user_id = $1 - """, user_id) + """, + user_id, + ) shards = max(int(guild_count / 1000), 1) # get _ws.session ratelimit - ratelimit = app.ratelimiter.get_ratelimit('_ws.session') + ratelimit = app.ratelimiter.get_ratelimit("_ws.session") bucket = ratelimit.get_bucket(user_id) # timestamp of bucket reset @@ -62,13 +63,14 @@ async def api_gateway_bot(): # how many seconds until bucket reset reset_after_ts = reset_ts - time.time() - return jsonify({ - 'url': get_gw(), - 'shards': shards, - - 'session_start_limit': { - 'total': bucket.requests, - 'remaining': bucket._tokens, - 'reset_after': int(reset_after_ts * 1000), + return jsonify( + { + "url": get_gw(), + "shards": shards, + "session_start_limit": { + "total": bucket.requests, + "remaining": bucket._tokens, + "reset_after": int(reset_after_ts * 1000), + }, } - }) + ) diff --git a/litecord/blueprints/guild/__init__.py b/litecord/blueprints/guild/__init__.py index 7e37ef6..31c2640 100644 --- a/litecord/blueprints/guild/__init__.py +++ b/litecord/blueprints/guild/__init__.py @@ -23,5 +23,4 @@ from .channels import bp as guild_channels from .mod import bp as guild_mod from .emoji import bp as guild_emoji -__all__ = ['guild_roles', 'guild_members', 'guild_channels', 'guild_mod', - 'guild_emoji'] +__all__ = ["guild_roles", "guild_members", "guild_channels", "guild_mod", "guild_emoji"] diff --git a/litecord/blueprints/guild/channels.py b/litecord/blueprints/guild/channels.py index 2b0467b..c8a2d2e 100644 --- a/litecord/blueprints/guild/channels.py +++ b/litecord/blueprints/guild/channels.py @@ -25,23 +25,23 @@ from litecord.errors import BadRequest from litecord.enums import ChannelType from litecord.blueprints.guild.roles import gen_pairs -from litecord.schemas import ( - validate, ROLE_UPDATE_POSITION, CHAN_CREATE -) -from litecord.blueprints.checks import ( - guild_check, guild_owner_check, guild_perm_check -) +from litecord.schemas import validate, ROLE_UPDATE_POSITION, CHAN_CREATE +from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check -bp = Blueprint('guild_channels', __name__) +bp = Blueprint("guild_channels", __name__) async def _specific_chan_create(channel_id, ctype, **kwargs): if ctype == ChannelType.GUILD_TEXT: - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO guild_text_channels (id, topic) VALUES ($1, $2) - """, channel_id, kwargs.get('topic', '')) + """, + channel_id, + kwargs.get("topic", ""), + ) elif ctype == ChannelType.GUILD_VOICE: await app.db.execute( """ @@ -49,34 +49,48 @@ async def _specific_chan_create(channel_id, ctype, **kwargs): VALUES ($1, $2, $3) """, channel_id, - kwargs.get('bitrate', 64), - kwargs.get('user_limit', 0) + kwargs.get("bitrate", 64), + kwargs.get("user_limit", 0), ) -async def create_guild_channel(guild_id: int, channel_id: int, - ctype: ChannelType, **kwargs): +async def create_guild_channel( + guild_id: int, channel_id: int, ctype: ChannelType, **kwargs +): """Create a channel in a guild.""" - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO channels (id, channel_type) VALUES ($1, $2) - """, channel_id, ctype.value) + """, + channel_id, + ctype.value, + ) # calc new pos - max_pos = await app.db.fetchval(""" + max_pos = await app.db.fetchval( + """ SELECT MAX(position) FROM guild_channels WHERE guild_id = $1 - """, guild_id) + """, + guild_id, + ) # account for the first channel in a guild too max_pos = max_pos or 0 # all channels go to guild_channels - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO guild_channels (id, guild_id, name, position) VALUES ($1, $2, $3, $4) - """, channel_id, guild_id, kwargs['name'], max_pos + 1) + """, + channel_id, + guild_id, + kwargs["name"], + max_pos + 1, + ) # the rest of sql magic is dependant on the channel # we're creating (a text or voice or category), @@ -84,35 +98,32 @@ async def create_guild_channel(guild_id: int, channel_id: int, await _specific_chan_create(channel_id, ctype, **kwargs) -@bp.route('//channels', methods=['GET']) +@bp.route("//channels", methods=["GET"]) async def get_guild_channels(guild_id): """Get the list of channels in a guild.""" user_id = await token_check() await guild_check(user_id, guild_id) - return jsonify( - await app.storage.get_channel_data(guild_id)) + return jsonify(await app.storage.get_channel_data(guild_id)) -@bp.route('//channels', methods=['POST']) +@bp.route("//channels", methods=["POST"]) async def create_channel(guild_id): """Create a channel in a guild.""" user_id = await token_check() j = validate(await request.get_json(), CHAN_CREATE) await guild_check(user_id, guild_id) - await guild_perm_check(user_id, guild_id, 'manage_channels') + await guild_perm_check(user_id, guild_id, "manage_channels") - channel_type = j.get('type', ChannelType.GUILD_TEXT) + channel_type = j.get("type", ChannelType.GUILD_TEXT) channel_type = ChannelType(channel_type) - if channel_type not in (ChannelType.GUILD_TEXT, - ChannelType.GUILD_VOICE): - raise BadRequest('Invalid channel type') + if channel_type not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE): + raise BadRequest("Invalid channel type") new_channel_id = get_snowflake() - await create_guild_channel( - guild_id, new_channel_id, channel_type, **j) + await create_guild_channel(guild_id, new_channel_id, channel_type, **j) # TODO: do a better method # subscribe the currently subscribed users to the new channel @@ -120,14 +131,13 @@ async def create_channel(guild_id): # since GuildDispatcher calls Storage.get_channel_ids, # it will subscribe all users to the newly created channel. - guild_pubsub = app.dispatcher.backends['guild'] + guild_pubsub = app.dispatcher.backends["guild"] user_ids = guild_pubsub.state[guild_id] for uid in user_ids: - await app.dispatcher.sub('guild', guild_id, uid) + await app.dispatcher.sub("guild", guild_id, uid) chan = await app.storage.get_channel(new_channel_id) - await app.dispatcher.dispatch_guild( - guild_id, 'CHANNEL_CREATE', chan) + await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_CREATE", chan) return jsonify(chan) @@ -135,7 +145,7 @@ async def _chan_update_dispatch(guild_id: int, channel_id: int): """Fetch new information about the channel and dispatch a single CHANNEL_UPDATE event to the guild.""" chan = await app.storage.get_channel(channel_id) - await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_UPDATE', chan) + await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", chan) async def _do_single_swap(guild_id: int, pair: tuple): @@ -149,13 +159,14 @@ async def _do_single_swap(guild_id: int, pair: tuple): conn = await app.db.acquire() async with conn.transaction(): - await conn.executemany(""" + await conn.executemany( + """ UPDATE guild_channels SET position = $1 WHERE id = $2 AND guild_id = $3 - """, [ - (new_pos_1, channel_1, guild_id), - (new_pos_2, channel_2, guild_id)]) + """, + [(new_pos_1, channel_1, guild_id), (new_pos_2, channel_2, guild_id)], + ) await app.db.release(conn) @@ -173,30 +184,26 @@ async def _do_channel_swaps(guild_id: int, swap_pairs: list): await _do_single_swap(guild_id, pair) -@bp.route('//channels', methods=['PATCH']) +@bp.route("//channels", methods=["PATCH"]) async def modify_channel_pos(guild_id): """Change positions of channels in a guild.""" user_id = await token_check() await guild_owner_check(user_id, guild_id) - await guild_perm_check(user_id, guild_id, 'manage_channels') + await guild_perm_check(user_id, guild_id, "manage_channels") # same thing as guild.roles, so we use # the same schema and all. raw_j = await request.get_json() - j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION) - j = j['roles'] + j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION) + j = j["roles"] channels = await app.storage.get_channel_data(guild_id) - channel_positions = {chan['position']: int(chan['id']) - for chan in channels} + channel_positions = {chan["position"]: int(chan["id"]) for chan in channels} - swap_pairs = gen_pairs( - j, - channel_positions - ) + swap_pairs = gen_pairs(j, channel_positions) await _do_channel_swaps(guild_id, swap_pairs) - return '', 204 + return "", 204 diff --git a/litecord/blueprints/guild/emoji.py b/litecord/blueprints/guild/emoji.py index cc48dfe..9a5a7d5 100644 --- a/litecord/blueprints/guild/emoji.py +++ b/litecord/blueprints/guild/emoji.py @@ -27,78 +27,85 @@ from litecord.types import KILOBYTES from litecord.images import parse_data_uri from litecord.errors import BadRequest -bp = Blueprint('guild.emoji', __name__) +bp = Blueprint("guild.emoji", __name__) async def _dispatch_emojis(guild_id): """Dispatch a Guild Emojis Update payload to a guild.""" - await app.dispatcher.dispatch('guild', guild_id, 'GUILD_EMOJIS_UPDATE', { - 'guild_id': str(guild_id), - 'emojis': await app.storage.get_guild_emojis(guild_id) - }) + await app.dispatcher.dispatch( + "guild", + guild_id, + "GUILD_EMOJIS_UPDATE", + { + "guild_id": str(guild_id), + "emojis": await app.storage.get_guild_emojis(guild_id), + }, + ) -@bp.route('//emojis', methods=['GET']) +@bp.route("//emojis", methods=["GET"]) async def _get_guild_emoji(guild_id): user_id = await token_check() await guild_check(user_id, guild_id) - return jsonify( - await app.storage.get_guild_emojis(guild_id) - ) + return jsonify(await app.storage.get_guild_emojis(guild_id)) -@bp.route('//emojis/', methods=['GET']) +@bp.route("//emojis/", methods=["GET"]) async def _get_guild_emoji_one(guild_id, emoji_id): user_id = await token_check() await guild_check(user_id, guild_id) - return jsonify( - await app.storage.get_emoji(emoji_id) - ) + return jsonify(await app.storage.get_emoji(emoji_id)) async def _guild_emoji_size_check(guild_id: int, mime: str): limit = 50 - if await app.storage.has_feature(guild_id, 'MORE_EMOJI'): + if await app.storage.has_feature(guild_id, "MORE_EMOJI"): limit = 200 # NOTE: I'm assuming you can have 200 animated emojis. - select_animated = mime == 'image/gif' + select_animated = mime == "image/gif" - total_emoji = await app.db.fetchval(""" + total_emoji = await app.db.fetchval( + """ SELECT COUNT(*) FROM guild_emoji WHERE guild_id = $1 AND animated = $2 - """, guild_id, select_animated) + """, + guild_id, + select_animated, + ) if total_emoji >= limit: # TODO: really return a BadRequest? needs more looking. - raise BadRequest(f'too many emoji ({limit})') + raise BadRequest(f"too many emoji ({limit})") -@bp.route('//emojis', methods=['POST']) +@bp.route("//emojis", methods=["POST"]) async def _put_emoji(guild_id): user_id = await token_check() await guild_check(user_id, guild_id) - await guild_perm_check(user_id, guild_id, 'manage_emojis') + await guild_perm_check(user_id, guild_id, "manage_emojis") j = validate(await request.get_json(), NEW_EMOJI) # we have to parse it before passing on so that we know which # size to check. - mime, _ = parse_data_uri(j['image']) + mime, _ = parse_data_uri(j["image"]) await _guild_emoji_size_check(guild_id, mime) emoji_id = get_snowflake() icon = await app.icons.put( - 'emoji', emoji_id, j['image'], - + "emoji", + emoji_id, + j["image"], # limits to emojis - bsize=128 * KILOBYTES, size=(128, 128) + bsize=128 * KILOBYTES, + size=(128, 128), ) if not icon: - return '', 400 + return "", 400 # TODO: better way to detect animated emoji rather than just gifs, # maybe a list perhaps? @@ -109,25 +116,25 @@ async def _put_emoji(guild_id): VALUES ($1, $2, $3, $4, $5, $6) """, - emoji_id, guild_id, user_id, - j['name'], + emoji_id, + guild_id, + user_id, + j["name"], icon.icon_hash, - icon.mime == 'image/gif' + icon.mime == "image/gif", ) await _dispatch_emojis(guild_id) - return jsonify( - await app.storage.get_emoji(emoji_id) - ) + return jsonify(await app.storage.get_emoji(emoji_id)) -@bp.route('//emojis/', methods=['PATCH']) +@bp.route("//emojis/", methods=["PATCH"]) async def _patch_emoji(guild_id, emoji_id): user_id = await token_check() await guild_check(user_id, guild_id) - await guild_perm_check(user_id, guild_id, 'manage_emojis') + await guild_perm_check(user_id, guild_id, "manage_emojis") j = validate(await request.get_json(), PATCH_EMOJI) emoji = await app.storage.get_emoji(emoji_id) @@ -135,34 +142,39 @@ async def _patch_emoji(guild_id, emoji_id): # if emoji.name is still the same, we don't update anything # or send ane events, just return the same emoji we'd send # as if we updated it. - if j['name'] == emoji['name']: + if j["name"] == emoji["name"]: return jsonify(emoji) - await app.db.execute(""" + await app.db.execute( + """ UPDATE guild_emoji SET name = $1 WHERE id = $2 - """, j['name'], emoji_id) + """, + j["name"], + emoji_id, + ) await _dispatch_emojis(guild_id) - return jsonify( - await app.storage.get_emoji(emoji_id) - ) + return jsonify(await app.storage.get_emoji(emoji_id)) -@bp.route('//emojis/', methods=['DELETE']) +@bp.route("//emojis/", methods=["DELETE"]) async def _del_emoji(guild_id, emoji_id): user_id = await token_check() await guild_check(user_id, guild_id) - await guild_perm_check(user_id, guild_id, 'manage_emojis') + await guild_perm_check(user_id, guild_id, "manage_emojis") # TODO: check if actually deleted - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM guild_emoji WHERE id = $2 - """, emoji_id) + """, + emoji_id, + ) await _dispatch_emojis(guild_id) - return '', 204 + return "", 204 diff --git a/litecord/blueprints/guild/members.py b/litecord/blueprints/guild/members.py index 3558168..4ea9abb 100644 --- a/litecord/blueprints/guild/members.py +++ b/litecord/blueprints/guild/members.py @@ -22,18 +22,14 @@ from quart import Blueprint, request, current_app as app, jsonify from litecord.blueprints.auth import token_check from litecord.errors import BadRequest -from litecord.schemas import ( - validate, MEMBER_UPDATE -) +from litecord.schemas import validate, MEMBER_UPDATE -from litecord.blueprints.checks import ( - guild_check, guild_owner_check, guild_perm_check -) +from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check -bp = Blueprint('guild_members', __name__) +bp = Blueprint("guild_members", __name__) -@bp.route('//members/', methods=['GET']) +@bp.route("//members/", methods=["GET"]) async def get_guild_member(guild_id, member_id): """Get a member's information in a guild.""" user_id = await token_check() @@ -42,7 +38,7 @@ async def get_guild_member(guild_id, member_id): return jsonify(member) -@bp.route('//members', methods=['GET']) +@bp.route("//members", methods=["GET"]) async def get_members(guild_id): """Get members inside a guild.""" user_id = await token_check() @@ -50,34 +46,41 @@ async def get_members(guild_id): j = await request.get_json() - limit, after = int(j.get('limit', 1)), j.get('after', 0) + limit, after = int(j.get("limit", 1)), j.get("after", 0) if limit < 1 or limit > 1000: - raise BadRequest('limit not in 1-1000 range') + raise BadRequest("limit not in 1-1000 range") - user_ids = await app.db.fetch(f""" + user_ids = await app.db.fetch( + f""" SELECT user_id WHERE guild_id = $1, user_id > $2 LIMIT {limit} ORDER BY user_id ASC - """, guild_id, after) + """, + guild_id, + after, + ) user_ids = [r[0] for r in user_ids] members = await app.storage.get_member_multi(guild_id, user_ids) return jsonify(members) -async def _update_member_roles(guild_id: int, member_id: int, - wanted_roles: set): +async def _update_member_roles(guild_id: int, member_id: int, wanted_roles: set): """Update the roles a member has.""" # first, fetch all current roles - roles = await app.db.fetch(""" + roles = await app.db.fetch( + """ SELECT role_id from member_roles WHERE guild_id = $1 AND user_id = $2 - """, guild_id, member_id) + """, + guild_id, + member_id, + ) - roles = [r['role_id'] for r in roles] + roles = [r["role_id"] for r in roles] roles = set(roles) wanted_roles = set(wanted_roles) @@ -96,26 +99,30 @@ async def _update_member_roles(guild_id: int, member_id: int, async with conn.transaction(): # add roles - await app.db.executemany(""" + await app.db.executemany( + """ INSERT INTO member_roles (user_id, guild_id, role_id) VALUES ($1, $2, $3) - """, [(member_id, guild_id, role_id) - for role_id in added_roles]) + """, + [(member_id, guild_id, role_id) for role_id in added_roles], + ) # remove roles - await app.db.executemany(""" + await app.db.executemany( + """ DELETE FROM member_roles WHERE user_id = $1 AND guild_id = $2 AND role_id = $3 - """, [(member_id, guild_id, role_id) - for role_id in removed_roles]) + """, + [(member_id, guild_id, role_id) for role_id in removed_roles], + ) await app.db.release(conn) -@bp.route('//members/', methods=['PATCH']) +@bp.route("//members/", methods=["PATCH"]) async def modify_guild_member(guild_id, member_id): """Modify a members' information in a guild.""" user_id = await token_check() @@ -124,96 +131,112 @@ async def modify_guild_member(guild_id, member_id): j = validate(await request.get_json(), MEMBER_UPDATE) nick_flag = False - if 'nick' in j: - await guild_perm_check(user_id, guild_id, 'manage_nicknames') + if "nick" in j: + await guild_perm_check(user_id, guild_id, "manage_nicknames") - nick = j['nick'] or None + nick = j["nick"] or None - await app.db.execute(""" + await app.db.execute( + """ UPDATE members SET nickname = $1 WHERE user_id = $2 AND guild_id = $3 - """, nick, member_id, guild_id) + """, + nick, + member_id, + guild_id, + ) nick_flag = True - if 'mute' in j: - await guild_perm_check(user_id, guild_id, 'mute_members') + if "mute" in j: + await guild_perm_check(user_id, guild_id, "mute_members") - await app.db.execute(""" + await app.db.execute( + """ UPDATE members SET muted = $1 WHERE user_id = $2 AND guild_id = $3 - """, j['mute'], member_id, guild_id) + """, + j["mute"], + member_id, + guild_id, + ) - if 'deaf' in j: - await guild_perm_check(user_id, guild_id, 'deafen_members') + if "deaf" in j: + await guild_perm_check(user_id, guild_id, "deafen_members") - await app.db.execute(""" + await app.db.execute( + """ UPDATE members SET deafened = $1 WHERE user_id = $2 AND guild_id = $3 - """, j['deaf'], member_id, guild_id) + """, + j["deaf"], + member_id, + guild_id, + ) - if 'channel_id' in j: + if "channel_id" in j: # TODO: check MOVE_MEMBERS and CONNECT to the channel # TODO: change the member's voice channel pass - if 'roles' in j: - await guild_perm_check(user_id, guild_id, 'manage_roles') - await _update_member_roles(guild_id, member_id, j['roles']) + if "roles" in j: + await guild_perm_check(user_id, guild_id, "manage_roles") + await _update_member_roles(guild_id, member_id, j["roles"]) member = await app.storage.get_member_data_one(guild_id, member_id) - member.pop('joined_at') + member.pop("joined_at") # call pres_update for role and nick changes. - partial = { - 'roles': member['roles'] - } + partial = {"roles": member["roles"]} if nick_flag: - partial['nick'] = j['nick'] + partial["nick"] = j["nick"] await app.dispatcher.dispatch( - 'lazy_guild', guild_id, 'pres_update', user_id, partial) + "lazy_guild", guild_id, "pres_update", user_id, partial + ) - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{ - 'guild_id': str(guild_id) - }, **member}) + await app.dispatcher.dispatch_guild( + guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member} + ) - return '', 204 + return "", 204 -@bp.route('//members/@me/nick', methods=['PATCH']) +@bp.route("//members/@me/nick", methods=["PATCH"]) async def update_nickname(guild_id): """Update a member's nickname in a guild.""" user_id = await token_check() await guild_check(user_id, guild_id) - j = validate(await request.get_json(), { - 'nick': {'type': 'nickname'} - }) + j = validate(await request.get_json(), {"nick": {"type": "nickname"}}) - nick = j['nick'] or None + nick = j["nick"] or None - await app.db.execute(""" + await app.db.execute( + """ UPDATE members SET nickname = $1 WHERE user_id = $2 AND guild_id = $3 - """, nick, user_id, guild_id) + """, + nick, + user_id, + guild_id, + ) member = await app.storage.get_member_data_one(guild_id, user_id) - member.pop('joined_at') + member.pop("joined_at") # call pres_update for nick changes, etc. await app.dispatcher.dispatch( - 'lazy_guild', guild_id, 'pres_update', user_id, { - 'nick': j['nick'] - }) + "lazy_guild", guild_id, "pres_update", user_id, {"nick": j["nick"]} + ) - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{ - 'guild_id': str(guild_id) - }, **member}) + await app.dispatcher.dispatch_guild( + guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member} + ) - return j['nick'] + return j["nick"] diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py index 7878d3b..5949fd0 100644 --- a/litecord/blueprints/guild/mod.py +++ b/litecord/blueprints/guild/mod.py @@ -24,33 +24,38 @@ from litecord.blueprints.checks import guild_perm_check from litecord.schemas import validate, GUILD_PRUNE -bp = Blueprint('guild_moderation', __name__) +bp = Blueprint("guild_moderation", __name__) async def remove_member(guild_id: int, member_id: int): """Do common tasks related to deleting a member from the guild, such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE.""" - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM members WHERE guild_id = $1 AND user_id = $2 - """, guild_id, member_id) + """, + guild_id, + member_id, + ) await app.dispatcher.dispatch_user_guild( - member_id, guild_id, 'GUILD_DELETE', { - 'guild_id': str(guild_id), - 'unavailable': False, - }) + member_id, + guild_id, + "GUILD_DELETE", + {"guild_id": str(guild_id), "unavailable": False}, + ) - await app.dispatcher.unsub('guild', guild_id, member_id) + await app.dispatcher.unsub("guild", guild_id, member_id) - await app.dispatcher.dispatch( - 'lazy_guild', guild_id, 'remove_member', member_id) + await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id) - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', { - 'guild_id': str(guild_id), - 'user': await app.storage.get_user(member_id), - }) + await app.dispatcher.dispatch_guild( + guild_id, + "GUILD_MEMBER_REMOVE", + {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)}, + ) async def remove_member_multi(guild_id: int, members: list): @@ -59,84 +64,100 @@ async def remove_member_multi(guild_id: int, members: list): await remove_member(guild_id, member_id) -@bp.route('//members/', methods=['DELETE']) +@bp.route("//members/", methods=["DELETE"]) async def kick_guild_member(guild_id, member_id): """Remove a member from a guild.""" user_id = await token_check() - await guild_perm_check(user_id, guild_id, 'kick_members') + await guild_perm_check(user_id, guild_id, "kick_members") await remove_member(guild_id, member_id) - return '', 204 + return "", 204 -@bp.route('//bans', methods=['GET']) +@bp.route("//bans", methods=["GET"]) async def get_bans(guild_id): user_id = await token_check() - await guild_perm_check(user_id, guild_id, 'ban_members') + await guild_perm_check(user_id, guild_id, "ban_members") - bans = await app.db.fetch(""" + bans = await app.db.fetch( + """ SELECT user_id, reason FROM bans WHERE bans.guild_id = $1 - """, guild_id) + """, + guild_id, + ) res = [] for ban in bans: - res.append({ - 'reason': ban['reason'], - 'user': await app.storage.get_user(ban['user_id']) - }) + res.append( + { + "reason": ban["reason"], + "user": await app.storage.get_user(ban["user_id"]), + } + ) return jsonify(res) -@bp.route('//bans/', methods=['PUT']) +@bp.route("//bans/", methods=["PUT"]) async def create_ban(guild_id, member_id): user_id = await token_check() - await guild_perm_check(user_id, guild_id, 'ban_members') + await guild_perm_check(user_id, guild_id, "ban_members") j = await request.get_json() - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO bans (guild_id, user_id, reason) VALUES ($1, $2, $3) - """, guild_id, member_id, j.get('reason', '')) + """, + guild_id, + member_id, + j.get("reason", ""), + ) await remove_member(guild_id, member_id) - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', { - 'guild_id': str(guild_id), - 'user': await app.storage.get_user(member_id) - }) + await app.dispatcher.dispatch_guild( + guild_id, + "GUILD_BAN_ADD", + {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)}, + ) - return '', 204 + return "", 204 -@bp.route('//bans/', methods=['DELETE']) +@bp.route("//bans/", methods=["DELETE"]) async def remove_ban(guild_id, banned_id): user_id = await token_check() - await guild_perm_check(user_id, guild_id, 'ban_members') + await guild_perm_check(user_id, guild_id, "ban_members") - res = await app.db.execute(""" + res = await app.db.execute( + """ DELETE FROM bans WHERE guild_id = $1 AND user_id = $@ - """, guild_id, banned_id) + """, + guild_id, + banned_id, + ) # we don't really need to dispatch GUILD_BAN_REMOVE # when no bans were actually removed. - if res == 'DELETE 0': - return '', 204 + if res == "DELETE 0": + return "", 204 - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', { - 'guild_id': str(guild_id), - 'user': await app.storage.get_user(banned_id) - }) + await app.dispatcher.dispatch_guild( + guild_id, + "GUILD_BAN_REMOVE", + {"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)}, + ) - return '', 204 + return "", 204 async def get_prune(guild_id: int, days: int) -> list: @@ -146,23 +167,30 @@ async def get_prune(guild_id: int, days: int) -> list: - don't have any roles. """ # a good solution would be in pure sql. - member_ids = await app.db.fetch(f""" + member_ids = await app.db.fetch( + f""" SELECT id FROM users JOIN members ON members.guild_id = $1 AND members.user_id = users.id WHERE users.last_session < (now() - (interval '{days} days')) - """, guild_id) + """, + guild_id, + ) - member_ids = [r['id'] for r in member_ids] + member_ids = [r["id"] for r in member_ids] members = [] for member_id in member_ids: - role_count = await app.db.fetchval(""" + role_count = await app.db.fetchval( + """ SELECT COUNT(*) FROM member_roles WHERE guild_id = $1 AND user_id = $2 - """, guild_id, member_id) + """, + guild_id, + member_id, + ) if role_count == 0: members.append(member_id) @@ -170,33 +198,29 @@ async def get_prune(guild_id: int, days: int) -> list: return members -@bp.route('//prune', methods=['GET']) +@bp.route("//prune", methods=["GET"]) async def get_guild_prune_count(guild_id): user_id = await token_check() - await guild_perm_check(user_id, guild_id, 'kick_members') + await guild_perm_check(user_id, guild_id, "kick_members") j = validate(request.args, GUILD_PRUNE) - days = j['days'] + days = j["days"] member_ids = await get_prune(guild_id, days) - return jsonify({ - 'pruned': len(member_ids), - }) + return jsonify({"pruned": len(member_ids)}) -@bp.route('//prune', methods=['POST']) +@bp.route("//prune", methods=["POST"]) async def begin_guild_prune(guild_id): user_id = await token_check() - await guild_perm_check(user_id, guild_id, 'kick_members') + await guild_perm_check(user_id, guild_id, "kick_members") j = validate(request.args, GUILD_PRUNE) - days = j['days'] + days = j["days"] member_ids = await get_prune(guild_id, days) app.loop.create_task(remove_member_multi(guild_id, member_ids)) - return jsonify({ - 'pruned': len(member_ids) - }) + return jsonify({"pruned": len(member_ids)}) diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py index 5dfcbe7..9516aa4 100644 --- a/litecord/blueprints/guild/roles.py +++ b/litecord/blueprints/guild/roles.py @@ -24,12 +24,8 @@ from logbook import Logger from litecord.auth import token_check -from litecord.blueprints.checks import ( - guild_check, guild_perm_check -) -from litecord.schemas import ( - validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION -) +from litecord.blueprints.checks import guild_check, guild_perm_check +from litecord.schemas import validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION from litecord.snowflake import get_snowflake from litecord.utils import dict_get @@ -37,22 +33,19 @@ from litecord.permissions import get_role_perms DEFAULT_EVERYONE_PERMS = 104324161 log = Logger(__name__) -bp = Blueprint('guild_roles', __name__) +bp = Blueprint("guild_roles", __name__) -@bp.route('//roles', methods=['GET']) +@bp.route("//roles", methods=["GET"]) async def get_guild_roles(guild_id): """Get all roles in a guild.""" user_id = await token_check() await guild_check(user_id, guild_id) - return jsonify( - await app.storage.get_role_data(guild_id) - ) + return jsonify(await app.storage.get_role_data(guild_id)) -async def _maybe_lg(guild_id: int, event: str, - role, force: bool = False): +async def _maybe_lg(guild_id: int, event: str, role, force: bool = False): # sometimes we want to dispatch an event # even if the role isn't hoisted @@ -61,11 +54,10 @@ async def _maybe_lg(guild_id: int, event: str, # check if is a dict first because role_delete # only receives the role id. - if isinstance(role, dict) and not role['hoist'] and not force: + if isinstance(role, dict) and not role["hoist"] and not force: return - await app.dispatcher.dispatch( - 'lazy_guild', guild_id, event, role) + await app.dispatcher.dispatch("lazy_guild", guild_id, event, role) async def create_role(guild_id, name: str, **kwargs): @@ -73,18 +65,20 @@ async def create_role(guild_id, name: str, **kwargs): new_role_id = get_snowflake() everyone_perms = await get_role_perms(guild_id, guild_id) - default_perms = dict_get(kwargs, 'default_perms', - everyone_perms.binary) + default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary) # update all roles so that we have space for pos 1, but without # sending GUILD_ROLE_UPDATE for everyone - await app.db.execute(""" + await app.db.execute( + """ UPDATE roles SET position = position + 1 WHERE guild_id = $1 AND NOT (position = 0) - """, guild_id) + """, + guild_id, + ) await app.db.execute( """ @@ -95,42 +89,39 @@ async def create_role(guild_id, name: str, **kwargs): new_role_id, guild_id, name, - dict_get(kwargs, 'color', 0), - dict_get(kwargs, 'hoist', False), - + dict_get(kwargs, "color", 0), + dict_get(kwargs, "hoist", False), # always set ourselves on position 1 1, - int(dict_get(kwargs, 'permissions', default_perms)), + int(dict_get(kwargs, "permissions", default_perms)), False, - dict_get(kwargs, 'mentionable', False) + dict_get(kwargs, "mentionable", False), ) role = await app.storage.get_role(new_role_id, guild_id) # we need to update the lazy guild handlers for the newly created group - await _maybe_lg(guild_id, 'new_role', role) + await _maybe_lg(guild_id, "new_role", role) await app.dispatcher.dispatch_guild( - guild_id, 'GUILD_ROLE_CREATE', { - 'guild_id': str(guild_id), - 'role': role, - }) + guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role} + ) return role -@bp.route('//roles', methods=['POST']) +@bp.route("//roles", methods=["POST"]) async def create_guild_role(guild_id: int): """Add a role to a guild""" user_id = await token_check() - await guild_perm_check(user_id, guild_id, 'manage_roles') + await guild_perm_check(user_id, guild_id, "manage_roles") # client can just send null j = validate(await request.get_json() or {}, ROLE_CREATE) - role_name = j['name'] - j.pop('name') + role_name = j["name"] + j.pop("name") role = await create_role(guild_id, role_name, **j) @@ -141,12 +132,11 @@ async def _role_update_dispatch(role_id: int, guild_id: int): """Dispatch a GUILD_ROLE_UPDATE with updated information on a role.""" role = await app.storage.get_role(role_id, guild_id) - await _maybe_lg(guild_id, 'role_pos_upd', role) + await _maybe_lg(guild_id, "role_pos_upd", role) - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', { - 'guild_id': str(guild_id), - 'role': role, - }) + await app.dispatcher.dispatch_guild( + guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role} + ) return role @@ -166,17 +156,25 @@ async def _role_pairs_update(guild_id: int, pairs: list): async with conn.transaction(): # update happens in a transaction # so we don't fuck it up - await conn.execute(""" + await conn.execute( + """ UPDATE roles SET position = $1 WHERE roles.id = $2 - """, new_pos_1, role_1) + """, + new_pos_1, + role_1, + ) - await conn.execute(""" + await conn.execute( + """ UPDATE roles SET position = $1 WHERE roles.id = $2 - """, new_pos_2, role_2) + """, + new_pos_2, + role_2, + ) await app.db.release(conn) @@ -184,11 +182,15 @@ async def _role_pairs_update(guild_id: int, pairs: list): await _role_update_dispatch(role_1, guild_id) await _role_update_dispatch(role_2, guild_id) + PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]] -def gen_pairs(list_of_changes: List[Dict[str, int]], - current_state: Dict[int, int], - blacklist: List[int] = None) -> PairList: + +def gen_pairs( + list_of_changes: List[Dict[str, int]], + current_state: Dict[int, int], + blacklist: List[int] = None, +) -> PairList: """Generate a list of pairs that, when applied to the database, will generate the desired state given in list_of_changes. @@ -230,8 +232,9 @@ def gen_pairs(list_of_changes: List[Dict[str, int]], pairs = [] blacklist = blacklist or [] - preferred_state = {element['id']: element['position'] - for element in list_of_changes} + preferred_state = { + element["id"]: element["position"] for element in list_of_changes + } for blacklisted_id in blacklist: preferred_state.pop(blacklisted_id) @@ -239,7 +242,7 @@ def gen_pairs(list_of_changes: List[Dict[str, int]], # for each change, we must find a matching change # in the same list, so we can make a swap pair for change in list_of_changes: - element_1, new_pos_1 = change['id'], change['position'] + element_1, new_pos_1 = change["id"], change["position"] # check current pairs # so we don't repeat an element @@ -267,36 +270,34 @@ def gen_pairs(list_of_changes: List[Dict[str, int]], # if its being swapped to leave space, add it # to the pairs list if new_pos_2 is not None: - pairs.append( - ((element_1, new_pos_1), (element_2, new_pos_2)) - ) + pairs.append(((element_1, new_pos_1), (element_2, new_pos_2))) return pairs -@bp.route('//roles', methods=['PATCH']) +@bp.route("//roles", methods=["PATCH"]) async def update_guild_role_positions(guild_id): """Update the positions for a bunch of roles.""" user_id = await token_check() - await guild_perm_check(user_id, guild_id, 'manage_roles') + await guild_perm_check(user_id, guild_id, "manage_roles") raw_j = await request.get_json() # we need to do this hackiness because thats # cerberus for ya. - j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION) + j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION) # extract the list out - j = j['roles'] + j = j["roles"] - log.debug('role stuff: {!r}', j) + log.debug("role stuff: {!r}", j) all_roles = await app.storage.get_role_data(guild_id) # we'll have to calculate pairs of changing roles, # then do the changes, etc. - roles_pos = {role['position']: int(role['id']) for role in all_roles} + roles_pos = {role["position"]: int(role["id"]) for role in all_roles} # TODO: check if the user can even change the roles in the first place, # preferrably when we have a proper perms system. @@ -306,10 +307,9 @@ async def update_guild_role_positions(guild_id): pairs = gen_pairs( j, roles_pos, - # always ignore people trying to change # the @everyone's role position - [guild_id] + [guild_id], ) await _role_pairs_update(guild_id, pairs) @@ -318,31 +318,36 @@ async def update_guild_role_positions(guild_id): return jsonify(await app.storage.get_role_data(guild_id)) -@bp.route('//roles/', methods=['PATCH']) +@bp.route("//roles/", methods=["PATCH"]) async def update_guild_role(guild_id, role_id): """Update a single role's information.""" user_id = await token_check() - await guild_perm_check(user_id, guild_id, 'manage_roles') + await guild_perm_check(user_id, guild_id, "manage_roles") j = validate(await request.get_json(), ROLE_UPDATE) # we only update ints on the db, not Permissions - j['permissions'] = int(j['permissions']) + j["permissions"] = int(j["permissions"]) for field in j: - await app.db.execute(f""" + await app.db.execute( + f""" UPDATE roles SET {field} = $1 WHERE roles.id = $2 AND roles.guild_id = $3 - """, j[field], role_id, guild_id) + """, + j[field], + role_id, + guild_id, + ) role = await _role_update_dispatch(role_id, guild_id) - await _maybe_lg(guild_id, 'role_update', role, True) + await _maybe_lg(guild_id, "role_update", role, True) return jsonify(role) -@bp.route('//roles/', methods=['DELETE']) +@bp.route("//roles/", methods=["DELETE"]) async def delete_guild_role(guild_id, role_id): """Delete a role. @@ -350,21 +355,26 @@ async def delete_guild_role(guild_id, role_id): """ user_id = await token_check() - await guild_perm_check(user_id, guild_id, 'manage_roles') + await guild_perm_check(user_id, guild_id, "manage_roles") - res = await app.db.execute(""" + res = await app.db.execute( + """ DELETE FROM roles WHERE guild_id = $1 AND id = $2 - """, guild_id, role_id) + """, + guild_id, + role_id, + ) - if res == 'DELETE 0': - return '', 204 + if res == "DELETE 0": + return "", 204 - await _maybe_lg(guild_id, 'role_delete', role_id, True) + await _maybe_lg(guild_id, "role_delete", role_id, True) - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_DELETE', { - 'guild_id': str(guild_id), - 'role_id': str(role_id), - }) + await app.dispatcher.dispatch_guild( + guild_id, + "GUILD_ROLE_DELETE", + {"guild_id": str(guild_id), "role_id": str(role_id)}, + ) - return '', 204 + return "", 204 diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 3c3c7dc..7db4224 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -22,16 +22,17 @@ from typing import Optional, List from quart import Blueprint, request, current_app as app, jsonify from litecord.blueprints.guild.channels import create_guild_channel -from litecord.blueprints.guild.roles import ( - create_role, DEFAULT_EVERYONE_PERMS -) +from litecord.blueprints.guild.roles import create_role, DEFAULT_EVERYONE_PERMS from ..auth import token_check from ..snowflake import get_snowflake from ..enums import ChannelType from ..schemas import ( - validate, GUILD_CREATE, GUILD_UPDATE, SEARCH_CHANNEL, - VANITY_URL_PATCH + validate, + GUILD_CREATE, + GUILD_UPDATE, + SEARCH_CHANNEL, + VANITY_URL_PATCH, ) from .channels import channel_ack from .checks import guild_check, guild_owner_check, guild_perm_check @@ -40,7 +41,7 @@ from litecord.errors import BadRequest from litecord.permissions import get_permissions -bp = Blueprint('guilds', __name__) +bp = Blueprint("guilds", __name__) async def create_guild_settings(guild_id: int, user_id: int): @@ -49,26 +50,38 @@ async def create_guild_settings(guild_id: int, user_id: int): # new guild_settings are based off the currently # set guild settings (for the guild) - m_notifs = await app.db.fetchval(""" + m_notifs = await app.db.fetchval( + """ SELECT default_message_notifications FROM guilds WHERE id = $1 - """, guild_id) + """, + guild_id, + ) - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO guild_settings (user_id, guild_id, message_notifications) VALUES ($1, $2, $3) - """, user_id, guild_id, m_notifs) + """, + user_id, + guild_id, + m_notifs, + ) async def add_member(guild_id: int, user_id: int): """Add a user to a guild.""" - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO members (user_id, guild_id) VALUES ($1, $2) - """, user_id, guild_id) + """, + user_id, + guild_id, + ) await create_guild_settings(guild_id, user_id) @@ -83,28 +96,28 @@ async def guild_create_roles_prep(guild_id: int, roles: list): # are patches to the @everyone role everyone_patches = roles[0] for field in everyone_patches: - await app.db.execute(f""" + await app.db.execute( + f""" UPDATE roles SET {field}={everyone_patches[field]} WHERE roles.id = $1 - """, guild_id) + """, + guild_id, + ) - default_perms = (everyone_patches.get('permissions') - or DEFAULT_EVERYONE_PERMS) + default_perms = everyone_patches.get("permissions") or DEFAULT_EVERYONE_PERMS # from the 2nd and forward, # should be treated as new roles for role in roles[1:]: - await create_role( - guild_id, role['name'], default_perms=default_perms, **role - ) + await create_role(guild_id, role["name"], default_perms=default_perms, **role) async def guild_create_channels_prep(guild_id: int, channels: list): """Create channels pre-guild create""" for channel_raw in channels: channel_id = get_snowflake() - ctype = ChannelType(channel_raw['type']) + ctype = ChannelType(channel_raw["type"]) await create_guild_channel(guild_id, channel_id, ctype) @@ -114,37 +127,29 @@ def sanitize_icon(icon: Optional[str]) -> Optional[str]: Defaults to a jpeg icon when the header isn't given. """ - if icon and icon.startswith('data'): + if icon and icon.startswith("data"): return icon - return (f'data:image/jpeg;base64,{icon}' - if icon - else None) + return f"data:image/jpeg;base64,{icon}" if icon else None -async def _general_guild_icon(scope: str, guild_id: int, - icon: str, **kwargs): +async def _general_guild_icon(scope: str, guild_id: int, icon: str, **kwargs): encoded = sanitize_icon(icon) - icon_kwargs = { - 'always_icon': True - } + icon_kwargs = {"always_icon": True} - if 'size' in kwargs: - icon_kwargs['size'] = kwargs['size'] + if "size" in kwargs: + icon_kwargs["size"] = kwargs["size"] - return await app.icons.put( - scope, guild_id, encoded, - **icon_kwargs - ) + return await app.icons.put(scope, guild_id, encoded, **icon_kwargs) async def put_guild_icon(guild_id: int, icon: Optional[str]): """Insert a guild icon on the icon database.""" - return await _general_guild_icon('guild', guild_id, icon, size=(128, 128)) + return await _general_guild_icon("guild", guild_id, icon, size=(128, 128)) -@bp.route('', methods=['POST']) +@bp.route("", methods=["POST"]) async def create_guild(): """Create a new guild, assigning the user creating it as the owner and @@ -154,8 +159,8 @@ async def create_guild(): guild_id = get_snowflake() - if 'icon' in j: - image = await put_guild_icon(guild_id, j['icon']) + if "icon" in j: + image = await put_guild_icon(guild_id, j["icon"]) image = image.icon_hash else: image = None @@ -166,10 +171,16 @@ async def create_guild(): verification_level, default_message_notifications, explicit_content_filter) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - """, guild_id, j['name'], j['region'], image, user_id, - j.get('verification_level', 0), - j.get('default_message_notifications', 0), - j.get('explicit_content_filter', 0)) + """, + guild_id, + j["name"], + j["region"], + image, + user_id, + j.get("verification_level", 0), + j.get("default_message_notifications", 0), + j.get("explicit_content_filter", 0), + ) await add_member(guild_id, user_id) @@ -179,107 +190,127 @@ async def create_guild(): # we also don't use create_role because the id of the role # is the same as the id of the guild, and create_role # generates a new snowflake. - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO roles (id, guild_id, name, position, permissions) VALUES ($1, $2, $3, $4, $5) - """, guild_id, guild_id, '@everyone', 0, DEFAULT_EVERYONE_PERMS) + """, + guild_id, + guild_id, + "@everyone", + 0, + DEFAULT_EVERYONE_PERMS, + ) # add the @everyone role to the guild creator - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO member_roles (user_id, guild_id, role_id) VALUES ($1, $2, $3) - """, user_id, guild_id, guild_id) + """, + user_id, + guild_id, + guild_id, + ) # create a single #general channel. general_id = get_snowflake() await create_guild_channel( - guild_id, general_id, ChannelType.GUILD_TEXT, - name='general') + guild_id, general_id, ChannelType.GUILD_TEXT, name="general" + ) - if j.get('roles'): - await guild_create_roles_prep(guild_id, j['roles']) + if j.get("roles"): + await guild_create_roles_prep(guild_id, j["roles"]) - if j.get('channels'): - await guild_create_channels_prep(guild_id, j['channels']) + if j.get("channels"): + await guild_create_channels_prep(guild_id, j["channels"]) guild_total = await app.storage.get_guild_full(guild_id, user_id, 250) - await app.dispatcher.sub('guild', guild_id, user_id) - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_CREATE', guild_total) + await app.dispatcher.sub("guild", guild_id, user_id) + await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild_total) return jsonify(guild_total) -@bp.route('/', methods=['GET']) +@bp.route("/", methods=["GET"]) async def get_guild(guild_id): """Get a single guilds' information.""" user_id = await token_check() await guild_check(user_id, guild_id) - return jsonify( - await app.storage.get_guild_full(guild_id, user_id, 250) - ) + return jsonify(await app.storage.get_guild_full(guild_id, user_id, 250)) -async def _guild_update_icon(scope: str, guild_id: int, - icon: Optional[str], **kwargs): +async def _guild_update_icon(scope: str, guild_id: int, icon: Optional[str], **kwargs): """Update icon.""" - new_icon = await app.icons.update( - scope, guild_id, icon, always_icon=True, **kwargs - ) + new_icon = await app.icons.update(scope, guild_id, icon, always_icon=True, **kwargs) - table = { - 'guild': 'icon', - }.get(scope, scope) + table = {"guild": "icon"}.get(scope, scope) - await app.db.execute(f""" + await app.db.execute( + f""" UPDATE guilds SET {table} = $1 WHERE id = $2 - """, new_icon.icon_hash, guild_id) + """, + new_icon.icon_hash, + guild_id, + ) async def _guild_update_region(guild_id, region): is_vip = region.vip - can_vip = await app.storage.has_feature(guild_id, 'VIP_REGIONS') + can_vip = await app.storage.has_feature(guild_id, "VIP_REGIONS") if is_vip and not can_vip: - raise BadRequest('can not assign guild to vip-only region') + raise BadRequest("can not assign guild to vip-only region") - await app.db.execute(""" + await app.db.execute( + """ UPDATE guilds SET region = $1 WHERE id = $2 - """, region.id, guild_id) + """, + region.id, + guild_id, + ) - -@bp.route('/', methods=['PATCH']) +@bp.route("/", methods=["PATCH"]) async def _update_guild(guild_id): user_id = await token_check() await guild_check(user_id, guild_id) - await guild_perm_check(user_id, guild_id, 'manage_guild') + await guild_perm_check(user_id, guild_id, "manage_guild") j = validate(await request.get_json(), GUILD_UPDATE) - if 'owner_id' in j: + if "owner_id" in j: await guild_owner_check(user_id, guild_id) - await app.db.execute(""" + await app.db.execute( + """ UPDATE guilds SET owner_id = $1 WHERE id = $2 - """, int(j['owner_id']), guild_id) + """, + int(j["owner_id"]), + guild_id, + ) - if 'name' in j: - await app.db.execute(""" + if "name" in j: + await app.db.execute( + """ UPDATE guilds SET name = $1 WHERE id = $2 - """, j['name'], guild_id) + """, + j["name"], + guild_id, + ) - if 'region' in j: - region = app.voice.lvsp.region(j['region']) + if "region" in j: + region = app.voice.lvsp.region(j["region"]) if region is not None: await _guild_update_region(guild_id, region) @@ -287,65 +318,77 @@ async def _update_guild(guild_id): # small guild to work with to_update() guild = await app.storage.get_guild(guild_id) - if to_update(j, guild, 'icon'): - await _guild_update_icon( - 'guild', guild_id, j['icon'], size=(128, 128)) + if to_update(j, guild, "icon"): + await _guild_update_icon("guild", guild_id, j["icon"], size=(128, 128)) - if to_update(j, guild, 'splash'): - if not await app.storage.has_feature(guild_id, 'INVITE_SPLASH'): - raise BadRequest('guild does not have INVITE_SPLASH feature') + if to_update(j, guild, "splash"): + if not await app.storage.has_feature(guild_id, "INVITE_SPLASH"): + raise BadRequest("guild does not have INVITE_SPLASH feature") - await _guild_update_icon('splash', guild_id, j['splash']) + await _guild_update_icon("splash", guild_id, j["splash"]) - if to_update(j, guild, 'banner'): - if not await app.storage.has_feature(guild_id, 'VERIFIED'): - raise BadRequest('guild is not verified') + if to_update(j, guild, "banner"): + if not await app.storage.has_feature(guild_id, "VERIFIED"): + raise BadRequest("guild is not verified") - await _guild_update_icon('banner', guild_id, j['banner']) + await _guild_update_icon("banner", guild_id, j["banner"]) - fields = ['verification_level', 'default_message_notifications', - 'explicit_content_filter', 'afk_timeout', 'description'] + fields = [ + "verification_level", + "default_message_notifications", + "explicit_content_filter", + "afk_timeout", + "description", + ] for field in [f for f in fields if f in j]: - await app.db.execute(f""" + await app.db.execute( + f""" UPDATE guilds SET {field} = $1 WHERE id = $2 - """, j[field], guild_id) + """, + j[field], + guild_id, + ) - channel_fields = ['afk_channel_id', 'system_channel_id'] + channel_fields = ["afk_channel_id", "system_channel_id"] for field in [f for f in channel_fields if f in j]: # setting to null should remove the link between the afk/sys channel # to the guild. if j[field] is None: - await app.db.execute(f""" + await app.db.execute( + f""" UPDATE guilds SET {field} = NULL WHERE id = $1 - """, guild_id) + """, + guild_id, + ) continue chan = await app.storage.get_channel(int(j[field])) if chan is None: - raise BadRequest('invalid channel id') + raise BadRequest("invalid channel id") - if chan['guild_id'] != str(guild_id): - raise BadRequest('channel id not linked to guild') + if chan["guild_id"] != str(guild_id): + raise BadRequest("channel id not linked to guild") - await app.db.execute(f""" + await app.db.execute( + f""" UPDATE guilds SET {field} = $1 WHERE id = $2 - """, j[field], guild_id) + """, + j[field], + guild_id, + ) - guild = await app.storage.get_guild_full( - guild_id, user_id - ) + guild = await app.storage.get_guild_full(guild_id, user_id) - await app.dispatcher.dispatch_guild( - guild_id, 'GUILD_UPDATE', guild) + await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild) return jsonify(guild) @@ -354,33 +397,41 @@ async def delete_guild(guild_id: int, *, app_=None): """Delete a single guild.""" app_ = app_ or app - await app_.db.execute(""" + await app_.db.execute( + """ DELETE FROM guilds WHERE guilds.id = $1 - """, guild_id) + """, + guild_id, + ) # Discord's client expects IDs being string - await app_.dispatcher.dispatch('guild', guild_id, 'GUILD_DELETE', { - 'guild_id': str(guild_id), - 'id': str(guild_id), - # 'unavailable': False, - }) + await app_.dispatcher.dispatch( + "guild", + guild_id, + "GUILD_DELETE", + { + "guild_id": str(guild_id), + "id": str(guild_id), + # 'unavailable': False, + }, + ) # remove from the dispatcher so nobody # becomes the little memer that tries to fuck up with # everybody's gateway - await app_.dispatcher.remove('guild', guild_id) + await app_.dispatcher.remove("guild", guild_id) -@bp.route('/', methods=['DELETE']) +@bp.route("/", methods=["DELETE"]) # this endpoint is not documented, but used by the official client. -@bp.route('//delete', methods=['POST']) +@bp.route("//delete", methods=["POST"]) async def delete_guild_handler(guild_id): """Delete a guild.""" user_id = await token_check() await guild_owner_check(user_id, guild_id) await delete_guild(guild_id) - return '', 204 + return "", 204 async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]: @@ -397,7 +448,7 @@ async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]: return res -@bp.route('//messages/search', methods=['GET']) +@bp.route("//messages/search", methods=["GET"]) async def search_messages(guild_id): """Search messages in a guild. @@ -415,7 +466,8 @@ async def search_messages(guild_id): # use that list on the main search query. can_read = await fetch_readable_channels(guild_id, user_id) - rows = await app.db.fetch(f""" + rows = await app.db.fetch( + f""" SELECT orig.id AS current_id, COUNT(*) OVER() as total_results, array((SELECT messages.id AS before_id @@ -432,12 +484,17 @@ async def search_messages(guild_id): ORDER BY orig.id DESC LIMIT 50 OFFSET $3 - """, guild_id, j['content'], j['offset'], can_read) + """, + guild_id, + j["content"], + j["offset"], + can_read, + ) return jsonify(await search_result_from_list(rows)) -@bp.route('//ack', methods=['POST']) +@bp.route("//ack", methods=["POST"]) async def ack_guild(guild_id): """ACKnowledge all messages in the guild.""" user_id = await token_check() @@ -448,45 +505,43 @@ async def ack_guild(guild_id): for chan_id in chan_ids: await channel_ack(user_id, guild_id, chan_id) - return '', 204 + return "", 204 -@bp.route('//vanity-url', methods=['GET']) +@bp.route("//vanity-url", methods=["GET"]) async def get_vanity_url(guild_id: int): """Get the vanity url of a guild.""" user_id = await token_check() - await guild_perm_check(user_id, guild_id, 'manage_guild') + await guild_perm_check(user_id, guild_id, "manage_guild") inv_code = await app.storage.vanity_invite(guild_id) if inv_code is None: - return jsonify({'code': None}) + return jsonify({"code": None}) - return jsonify( - await app.storage.get_invite(inv_code) - ) + return jsonify(await app.storage.get_invite(inv_code)) -@bp.route('//vanity-url', methods=['PATCH']) +@bp.route("//vanity-url", methods=["PATCH"]) async def change_vanity_url(guild_id: int): """Get the vanity url of a guild.""" user_id = await token_check() - if not await app.storage.has_feature(guild_id, 'VANITY_URL'): + if not await app.storage.has_feature(guild_id, "VANITY_URL"): # TODO: is this the right error - raise BadRequest('guild has no vanity url support') + raise BadRequest("guild has no vanity url support") - await guild_perm_check(user_id, guild_id, 'manage_guild') + await guild_perm_check(user_id, guild_id, "manage_guild") j = validate(await request.get_json(), VANITY_URL_PATCH) - inv_code = j['code'] + inv_code = j["code"] # store old vanity in a variable to delete it from # invites table old_vanity = await app.storage.vanity_invite(guild_id) if old_vanity == inv_code: - raise BadRequest('can not change to same invite') + raise BadRequest("can not change to same invite") # this is sad because we don't really use the things # sql gives us, but i havent really found a way to put @@ -494,19 +549,22 @@ async def change_vanity_url(guild_id: int): # guild_id_fkey fails but INSERT when code_fkey fails.. inv = await app.storage.get_invite(inv_code) if inv: - raise BadRequest('invite already exists') + raise BadRequest("invite already exists") # TODO: this is bad, what if a guild has no channels? # we should probably choose the first channel that has # @everyone read messages channels = await app.storage.get_channel_data(guild_id) - channel_id = int(channels[0]['id']) + channel_id = int(channels[0]["id"]) # delete the old invite, insert new one - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM invites WHERE code = $1 - """, old_vanity) + """, + old_vanity, + ) await app.db.execute( """ @@ -515,21 +573,27 @@ async def change_vanity_url(guild_id: int): max_age, temporary) VALUES ($1, $2, $3, $4, $5, $6, $7) """, - inv_code, guild_id, channel_id, user_id, - + inv_code, + guild_id, + channel_id, + user_id, # sane defaults for vanity urls. - 0, 0, False, + 0, + 0, + False, ) - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO vanity_invites (guild_id, code) VALUES ($1, $2) ON CONFLICT ON CONSTRAINT vanity_invites_pkey DO UPDATE SET code = $2 WHERE vanity_invites.guild_id = $1 - """, guild_id, inv_code) - - return jsonify( - await app.storage.get_invite(inv_code) + """, + guild_id, + inv_code, ) + + return jsonify(await app.storage.get_invite(inv_code)) diff --git a/litecord/blueprints/icons.py b/litecord/blueprints/icons.py index b01bac7..9b0f378 100644 --- a/litecord/blueprints/icons.py +++ b/litecord/blueprints/icons.py @@ -24,41 +24,39 @@ from quart import Blueprint, current_app as app, send_file, redirect from litecord.embed.sanitizer import make_md_req_url from litecord.embed.schemas import EmbedURL -bp = Blueprint('images', __name__) +bp = Blueprint("images", __name__) async def send_icon(scope, key, icon_hash, **kwargs): """Send an icon.""" - icon = await app.icons.generic_get( - scope, key, icon_hash, **kwargs) + icon = await app.icons.generic_get(scope, key, icon_hash, **kwargs) if not icon: - return '', 404 + return "", 404 return await send_file(icon.as_path) def splitext_(filepath): name, ext = splitext(filepath) - return name, ext.strip('.') + return name, ext.strip(".") -@bp.route('/emojis/', methods=['GET']) +@bp.route("/emojis/", methods=["GET"]) async def _get_raw_emoji(emoji_file): # emoji = app.icons.get_emoji(emoji_id, ext=ext) # just a test file for now emoji_id, ext = splitext_(emoji_file) - return await send_icon( - 'emoji', emoji_id, None, ext=ext) + return await send_icon("emoji", emoji_id, None, ext=ext) -@bp.route('/icons//', methods=['GET']) +@bp.route("/icons//", methods=["GET"]) async def _get_guild_icon(guild_id: int, icon_file: str): icon_hash, ext = splitext_(icon_file) - return await send_icon('guild', guild_id, icon_hash, ext=ext) + return await send_icon("guild", guild_id, icon_hash, ext=ext) -@bp.route('/embed/avatars/.png') +@bp.route("/embed/avatars/.png") async def _get_default_user_avatar(default_id: int): # TODO: how do we determine which assets to use for this? # I don't think we can use discord assets. @@ -66,25 +64,29 @@ async def _get_default_user_avatar(default_id: int): async def _handle_webhook_avatar(md_url_redir: str): - md_url = make_md_req_url(app.config, 'img', EmbedURL(md_url_redir)) + md_url = make_md_req_url(app.config, "img", EmbedURL(md_url_redir)) return redirect(md_url) -@bp.route('/avatars//') +@bp.route("/avatars//") async def _get_user_avatar(user_id, avatar_file): avatar_hash, ext = splitext_(avatar_file) # first, check if this is a webhook avatar to redir to - md_url_redir = await app.db.fetchval(""" + md_url_redir = await app.db.fetchval( + """ SELECT md_url_redir FROM webhook_avatars WHERE webhook_id = $1 AND hash = $2 - """, user_id, avatar_hash) + """, + user_id, + avatar_hash, + ) if md_url_redir: return await _handle_webhook_avatar(md_url_redir) - return await send_icon('user', user_id, avatar_hash, ext=ext) + return await send_icon("user", user_id, avatar_hash, ext=ext) # @bp.route('/app-icons//.') @@ -92,19 +94,19 @@ async def get_app_icon(application_id, icon_hash, ext): pass -@bp.route('/channel-icons//', methods=['GET']) +@bp.route("/channel-icons//", methods=["GET"]) async def _get_gdm_icon(channel_id: int, icon_file: str): icon_hash, ext = splitext_(icon_file) - return await send_icon('channel-icons', channel_id, icon_hash, ext=ext) + return await send_icon("channel-icons", channel_id, icon_hash, ext=ext) -@bp.route('/splashes//', methods=['GET']) +@bp.route("/splashes//", methods=["GET"]) async def _get_guild_splash(guild_id: int, icon_file: str): icon_hash, ext = splitext_(icon_file) - return await send_icon('splash', guild_id, icon_hash, ext=ext) + return await send_icon("splash", guild_id, icon_hash, ext=ext) -@bp.route('/banners//', methods=['GET']) +@bp.route("/banners//", methods=["GET"]) async def _get_guild_banner(guild_id: int, icon_file: str): icon_hash, ext = splitext_(icon_file) - return await send_icon('banner', guild_id, icon_hash, ext=ext) + return await send_icon("banner", guild_id, icon_hash, ext=ext) diff --git a/litecord/blueprints/invites.py b/litecord/blueprints/invites.py index 1b3506a..02b7610 100644 --- a/litecord/blueprints/invites.py +++ b/litecord/blueprints/invites.py @@ -32,13 +32,16 @@ from .guilds import create_guild_settings from ..utils import async_map from litecord.blueprints.checks import ( - channel_check, channel_perm_check, guild_check, guild_perm_check + channel_check, + channel_perm_check, + guild_check, + guild_perm_check, ) from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient log = Logger(__name__) -bp = Blueprint('invites', __name__) +bp = Blueprint("invites", __name__) class UnknownInvite(BadRequest): @@ -48,16 +51,18 @@ class UnknownInvite(BadRequest): class InvalidInvite(Forbidden): error_code = 50020 + class AlreadyInvited(BaseException): pass + def gen_inv_code() -> str: """Generate an invite code. This is a primitive and does not guarantee uniqueness. """ raw = secrets.token_urlsafe(10) - raw = re.sub(r'\/|\+|\-|\_', '', raw) + raw = re.sub(r"\/|\+|\-|\_", "", raw) return raw[:7] @@ -65,23 +70,31 @@ def gen_inv_code() -> str: async def invite_precheck(user_id: int, guild_id: int): """pre-check invite use in the context of a guild.""" - joined = await app.db.fetchval(""" + joined = await app.db.fetchval( + """ SELECT joined_at FROM members WHERE user_id = $1 AND guild_id = $2 - """, user_id, guild_id) + """, + user_id, + guild_id, + ) if joined is not None: - raise AlreadyInvited('You are already in the guild') + raise AlreadyInvited("You are already in the guild") - banned = await app.db.fetchval(""" + banned = await app.db.fetchval( + """ SELECT reason FROM bans WHERE user_id = $1 AND guild_id = $2 - """, user_id, guild_id) + """, + user_id, + guild_id, + ) if banned is not None: - raise InvalidInvite('You are banned.') + raise InvalidInvite("You are banned.") async def invite_precheck_gdm(user_id: int, channel_id: int): @@ -89,23 +102,23 @@ async def invite_precheck_gdm(user_id: int, channel_id: int): is_member = await gdm_is_member(channel_id, user_id) if is_member: - raise AlreadyInvited('You are already in the Group DM') + raise AlreadyInvited("You are already in the Group DM") async def _inv_check_age(inv: dict): - if inv['max_age'] == 0: + if inv["max_age"] == 0: return now = datetime.datetime.utcnow() - delta_sec = (now - inv['created_at']).total_seconds() + delta_sec = (now - inv["created_at"]).total_seconds() - if delta_sec > inv['max_age']: - await delete_invite(inv['code']) - raise InvalidInvite('Invite is expired') + if delta_sec > inv["max_age"]: + await delete_invite(inv["code"]) + raise InvalidInvite("Invite is expired") - if inv['max_uses'] is not -1 and inv['uses'] > inv['max_uses']: - await delete_invite(inv['code']) - raise InvalidInvite('Too many uses') + if inv["max_uses"] is not -1 and inv["uses"] > inv["max_uses"]: + await delete_invite(inv["code"]) + raise InvalidInvite("Too many uses") async def _guild_add_member(guild_id: int, user_id: int): @@ -119,78 +132,89 @@ async def _guild_add_member(guild_id: int, user_id: int): """ # TODO: system message for member join - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO members (user_id, guild_id) VALUES ($1, $2) - """, user_id, guild_id) + """, + user_id, + guild_id, + ) await create_guild_settings(guild_id, user_id) # add the @everyone role to the invited member - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO member_roles (user_id, guild_id, role_id) VALUES ($1, $2, $3) - """, user_id, guild_id, guild_id) + """, + user_id, + guild_id, + guild_id, + ) # tell current members a new member came up member = await app.storage.get_member_data_one(guild_id, user_id) - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_ADD', { - **member, - **{ - 'guild_id': str(guild_id), - }, - }) + await app.dispatcher.dispatch_guild( + guild_id, "GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}} + ) # update member lists for the new member - await app.dispatcher.dispatch( - 'lazy_guild', guild_id, 'new_member', user_id) + await app.dispatcher.dispatch("lazy_guild", guild_id, "new_member", user_id) # subscribe new member to guild, so they get events n stuff - await app.dispatcher.sub('guild', guild_id, user_id) + await app.dispatcher.sub("guild", guild_id, user_id) # tell the new member that theres the guild it just joined. # we use dispatch_user_guild so that we send the GUILD_CREATE # just to the shards that are actually tied to it. guild = await app.storage.get_guild_full(guild_id, user_id, 250) - await app.dispatcher.dispatch_user_guild( - user_id, guild_id, 'GUILD_CREATE', guild) + await app.dispatcher.dispatch_user_guild(user_id, guild_id, "GUILD_CREATE", guild) async def use_invite(user_id, invite_code): """Try using an invite""" - inv = await app.db.fetchrow(""" + inv = await app.db.fetchrow( + """ SELECT code, channel_id, guild_id, created_at, max_age, uses, max_uses FROM invites WHERE code = $1 - """, invite_code) + """, + invite_code, + ) if inv is None: - raise UnknownInvite('Unknown invite') + raise UnknownInvite("Unknown invite") await _inv_check_age(inv) # NOTE: if group dm invite, guild_id is null. - guild_id = inv['guild_id'] + guild_id = inv["guild_id"] - try: + try: if guild_id is None: - channel_id = inv['channel_id'] - await invite_precheck_gdm(user_id, inv['channel_id']) + channel_id = inv["channel_id"] + await invite_precheck_gdm(user_id, inv["channel_id"]) await gdm_add_recipient(channel_id, user_id) else: await invite_precheck(user_id, guild_id) await _guild_add_member(guild_id, user_id) - await app.db.execute(""" + await app.db.execute( + """ UPDATE invites SET uses = uses + 1 WHERE code = $1 - """, invite_code) + """, + invite_code, + ) except AlreadyInvited: pass -@bp.route('/channels//invites', methods=['POST']) + +@bp.route("/channels//invites", methods=["POST"]) async def create_invite(channel_id): """Create an invite to a channel.""" user_id = await token_check() @@ -201,12 +225,14 @@ async def create_invite(channel_id): # NOTE: this works on group dms, since it returns ALL_PERMISSIONS on # non-guild channels. - await channel_perm_check(user_id, channel_id, 'create_invites') + await channel_perm_check(user_id, channel_id, "create_invites") - if chantype not in (ChannelType.GUILD_TEXT, - ChannelType.GUILD_VOICE, - ChannelType.GROUP_DM): - raise BadRequest('Invalid channel type') + if chantype not in ( + ChannelType.GUILD_TEXT, + ChannelType.GUILD_VOICE, + ChannelType.GROUP_DM, + ): + raise BadRequest("Invalid channel type") invite_code = gen_inv_code() @@ -222,101 +248,122 @@ async def create_invite(channel_id): max_age, temporary) VALUES ($1, $2, $3, $4, $5, $6, $7) """, - invite_code, guild_id, channel_id, user_id, - j['max_uses'], j['max_age'], j['temporary'] + invite_code, + guild_id, + channel_id, + user_id, + j["max_uses"], + j["max_age"], + j["temporary"], ) invite = await app.storage.get_invite(invite_code) return jsonify(invite) -@bp.route('/invite/', methods=['GET']) -@bp.route('/invites/', methods=['GET']) +@bp.route("/invite/", methods=["GET"]) +@bp.route("/invites/", methods=["GET"]) async def get_invite(invite_code: str): inv = await app.storage.get_invite(invite_code) if not inv: - return '', 404 + return "", 404 - if request.args.get('with_counts'): + if request.args.get("with_counts"): extra = await app.storage.get_invite_extra(invite_code) inv.update(extra) return jsonify(inv) + async def delete_invite(invite_code: str): """Delete an invite.""" - await app.db.fetchval(""" + await app.db.fetchval( + """ DELETE FROM invites WHERE code = $1 - """, invite_code) + """, + invite_code, + ) -@bp.route('/invite/', methods=['DELETE']) -@bp.route('/invites/', methods=['DELETE']) + +@bp.route("/invite/", methods=["DELETE"]) +@bp.route("/invites/", methods=["DELETE"]) async def _delete_invite(invite_code: str): user_id = await token_check() - guild_id = await app.db.fetchval(""" + guild_id = await app.db.fetchval( + """ SELECT guild_id FROM invites WHERE code = $1 - """, invite_code) + """, + invite_code, + ) if guild_id is None: - raise BadRequest('Unknown invite') + raise BadRequest("Unknown invite") - await guild_perm_check(user_id, guild_id, 'manage_channels') + await guild_perm_check(user_id, guild_id, "manage_channels") inv = await app.storage.get_invite(invite_code) await delete_invite(invite_code) return jsonify(inv) + async def _get_inv(code): inv = await app.storage.get_invite(code) meta = await app.storage.get_invite_metadata(code) return {**inv, **meta} -@bp.route('/guilds//invites', methods=['GET']) +@bp.route("/guilds//invites", methods=["GET"]) async def get_guild_invites(guild_id: int): """Get all invites for a guild.""" user_id = await token_check() await guild_check(user_id, guild_id) - await guild_perm_check(user_id, guild_id, 'manage_guild') + await guild_perm_check(user_id, guild_id, "manage_guild") - inv_codes = await app.db.fetch(""" + inv_codes = await app.db.fetch( + """ SELECT code FROM invites WHERE guild_id = $1 - """, guild_id) + """, + guild_id, + ) - inv_codes = [r['code'] for r in inv_codes] + inv_codes = [r["code"] for r in inv_codes] invs = await async_map(_get_inv, inv_codes) return jsonify(invs) -@bp.route('/channels//invites', methods=['GET']) +@bp.route("/channels//invites", methods=["GET"]) async def get_channel_invites(channel_id: int): """Get all invites for a channel.""" user_id = await token_check() _ctype, guild_id = await channel_check(user_id, channel_id) - await guild_perm_check(user_id, guild_id, 'manage_channels') + await guild_perm_check(user_id, guild_id, "manage_channels") - inv_codes = await app.db.fetch(""" + inv_codes = await app.db.fetch( + """ SELECT code FROM invites WHERE guild_id = $1 AND channel_id = $2 - """, guild_id, channel_id) + """, + guild_id, + channel_id, + ) - inv_codes = [r['code'] for r in inv_codes] + inv_codes = [r["code"] for r in inv_codes] invs = await async_map(_get_inv, inv_codes) return jsonify(invs) -@bp.route('/invite/', methods=['POST']) -@bp.route('/invites/', methods=['POST']) +@bp.route("/invite/", methods=["POST"]) +@bp.route("/invites/", methods=["POST"]) async def _use_invite(invite_code): """Use an invite.""" user_id = await token_check() @@ -327,9 +374,4 @@ async def _use_invite(invite_code): inv = await app.storage.get_invite(invite_code) inv_meta = await app.storage.get_invite_metadata(invite_code) - return jsonify({ - **inv, - **{ - 'inviter': inv_meta['inviter'] - } - }) + return jsonify({**inv, **{"inviter": inv_meta["inviter"]}}) diff --git a/litecord/blueprints/nodeinfo.py b/litecord/blueprints/nodeinfo.py index 75ecf6f..7afefb2 100644 --- a/litecord/blueprints/nodeinfo.py +++ b/litecord/blueprints/nodeinfo.py @@ -19,83 +19,75 @@ along with this program. If not, see . from quart import Blueprint, current_app as app, jsonify, request -bp = Blueprint('nodeinfo', __name__) +bp = Blueprint("nodeinfo", __name__) -@bp.route('/.well-known/nodeinfo') +@bp.route("/.well-known/nodeinfo") async def _dummy_nodeinfo_index(): - proto = 'http' if not app.config['IS_SSL'] else 'https' - main_url = app.config.get('MAIN_URL', request.host) + proto = "http" if not app.config["IS_SSL"] else "https" + main_url = app.config.get("MAIN_URL", request.host) - return jsonify({ - 'links': [{ - 'href': f'{proto}://{main_url}/nodeinfo/2.0.json', - 'rel': 'http://nodeinfo.diaspora.software/ns/schema/2.0' - }, { - 'href': f'{proto}://{main_url}/nodeinfo/2.1.json', - 'rel': 'http://nodeinfo.diaspora.software/ns/schema/2.1' - }] - }) + return jsonify( + { + "links": [ + { + "href": f"{proto}://{main_url}/nodeinfo/2.0.json", + "rel": "http://nodeinfo.diaspora.software/ns/schema/2.0", + }, + { + "href": f"{proto}://{main_url}/nodeinfo/2.1.json", + "rel": "http://nodeinfo.diaspora.software/ns/schema/2.1", + }, + ] + } + ) async def fetch_nodeinfo_20(): - usercount = await app.db.fetchval(""" + usercount = await app.db.fetchval( + """ SELECT COUNT(*) FROM users - """) + """ + ) - message_count = await app.db.fetchval(""" + message_count = await app.db.fetchval( + """ SELECT COUNT(*) FROM messages - """) + """ + ) return { - 'metadata': { - 'features': [ - 'discord_api' - ], - - 'nodeDescription': 'A Litecord instance', - 'nodeName': 'Litecord/Nya', - 'private': False, - - 'federation': {} + "metadata": { + "features": ["discord_api"], + "nodeDescription": "A Litecord instance", + "nodeName": "Litecord/Nya", + "private": False, + "federation": {}, }, - 'openRegistrations': app.config['REGISTRATIONS'], - 'protocols': [], - 'software': { - 'name': 'litecord', - 'version': 'litecord v0', - }, - - 'services': { - 'inbound': [], - 'outbound': [], - }, - - 'usage': { - 'localPosts': message_count, - 'users': { - 'total': usercount - } - }, - 'version': '2.0', + "openRegistrations": app.config["REGISTRATIONS"], + "protocols": [], + "software": {"name": "litecord", "version": "litecord v0"}, + "services": {"inbound": [], "outbound": []}, + "usage": {"localPosts": message_count, "users": {"total": usercount}}, + "version": "2.0", } -@bp.route('/nodeinfo/2.0.json') +@bp.route("/nodeinfo/2.0.json") async def _nodeinfo_20(): """Handler for nodeinfo 2.0.""" raw_nodeinfo = await fetch_nodeinfo_20() return jsonify(raw_nodeinfo) -@bp.route('/nodeinfo/2.1.json') +@bp.route("/nodeinfo/2.1.json") async def _nodeinfo_21(): """Handler for nodeinfo 2.1.""" raw_nodeinfo = await fetch_nodeinfo_20() - raw_nodeinfo['software']['repository'] = 'https://gitlab.com/litecord/litecord' - raw_nodeinfo['version'] = '2.1' + raw_nodeinfo["software"]["repository"] = "https://gitlab.com/litecord/litecord" + raw_nodeinfo["version"] = "2.1" return jsonify(raw_nodeinfo) diff --git a/litecord/blueprints/relationships.py b/litecord/blueprints/relationships.py index d93e90b..1fa9d01 100644 --- a/litecord/blueprints/relationships.py +++ b/litecord/blueprints/relationships.py @@ -26,76 +26,89 @@ from ..enums import RelationshipType from litecord.errors import BadRequest -bp = Blueprint('relationship', __name__) +bp = Blueprint("relationship", __name__) -@bp.route('/@me/relationships', methods=['GET']) +@bp.route("/@me/relationships", methods=["GET"]) async def get_me_relationships(): user_id = await token_check() - return jsonify( - await app.user_storage.get_relationships(user_id)) + return jsonify(await app.user_storage.get_relationships(user_id)) async def _dispatch_single_pres(user_id, presence: dict): - await app.dispatcher.dispatch( - 'user', user_id, 'PRESENCE_UPDATE', presence - ) + await app.dispatcher.dispatch("user", user_id, "PRESENCE_UPDATE", presence) async def _unsub_friend(user_id, peer_id): - await app.dispatcher.unsub('friend', user_id, peer_id) - await app.dispatcher.unsub('friend', peer_id, user_id) + await app.dispatcher.unsub("friend", user_id, peer_id) + await app.dispatcher.unsub("friend", peer_id, user_id) + async def _sub_friend(user_id, peer_id): - await app.dispatcher.sub('friend', user_id, peer_id) - await app.dispatcher.sub('friend', peer_id, user_id) + await app.dispatcher.sub("friend", user_id, peer_id) + await app.dispatcher.sub("friend", peer_id, user_id) # dispatch presence update to the user and peer about # eachother's presence. - user_pres, peer_pres = await app.presence.friend_presences( - [user_id, peer_id] - ) + user_pres, peer_pres = await app.presence.friend_presences([user_id, peer_id]) await _dispatch_single_pres(user_id, peer_pres) await _dispatch_single_pres(peer_id, user_pres) -async def make_friend(user_id: int, peer_id: int, - rel_type=RelationshipType.FRIEND.value): +async def make_friend( + user_id: int, peer_id: int, rel_type=RelationshipType.FRIEND.value +): _friend = RelationshipType.FRIEND.value _block = RelationshipType.BLOCK.value try: - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO relationships (user_id, peer_id, rel_type) VALUES ($1, $2, $3) - """, user_id, peer_id, rel_type) + """, + user_id, + peer_id, + rel_type, + ) except UniqueViolationError: # try to update rel_type - old_rel_type = await app.db.fetchval(""" + old_rel_type = await app.db.fetchval( + """ SELECT rel_type FROM relationships WHERE user_id = $1 AND peer_id = $2 - """, user_id, peer_id) + """, + user_id, + peer_id, + ) if old_rel_type == _friend and rel_type == _block: - await app.db.execute(""" + await app.db.execute( + """ UPDATE relationships SET rel_type = $1 WHERE user_id = $2 AND peer_id = $3 - """, rel_type, user_id, peer_id) + """, + rel_type, + user_id, + peer_id, + ) # remove any existing friendship before the block - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM relationships WHERE peer_id = $1 AND user_id = $2 AND rel_type = $3 - """, peer_id, user_id, _friend) + """, + peer_id, + user_id, + _friend, + ) await app.dispatcher.dispatch_user( - peer_id, 'RELATIONSHIP_REMOVE', { - 'type': _friend, - 'id': str(user_id) - } + peer_id, "RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)} ) await _unsub_friend(user_id, peer_id) @@ -106,95 +119,118 @@ async def make_friend(user_id: int, peer_id: int, # check if this is an acceptance # of a friend request - existing = await app.db.fetchrow(""" + existing = await app.db.fetchrow( + """ SELECT user_id, peer_id FROM relationships WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3 - """, peer_id, user_id, _friend) + """, + peer_id, + user_id, + _friend, + ) _dispatch = app.dispatcher.dispatch_user if existing: # accepted a friend request, dispatch respective # relationship events - await _dispatch(user_id, 'RELATIONSHIP_REMOVE', { - 'type': RelationshipType.INCOMING.value, - 'id': str(peer_id) - }) + await _dispatch( + user_id, + "RELATIONSHIP_REMOVE", + {"type": RelationshipType.INCOMING.value, "id": str(peer_id)}, + ) - await _dispatch(user_id, 'RELATIONSHIP_ADD', { - 'type': _friend, - 'id': str(peer_id), - 'user': await app.storage.get_user(peer_id) - }) + await _dispatch( + user_id, + "RELATIONSHIP_ADD", + { + "type": _friend, + "id": str(peer_id), + "user": await app.storage.get_user(peer_id), + }, + ) - await _dispatch(peer_id, 'RELATIONSHIP_ADD', { - 'type': _friend, - 'id': str(user_id), - 'user': await app.storage.get_user(user_id) - }) + await _dispatch( + peer_id, + "RELATIONSHIP_ADD", + { + "type": _friend, + "id": str(user_id), + "user": await app.storage.get_user(user_id), + }, + ) await _sub_friend(user_id, peer_id) - return '', 204 + return "", 204 # check if friend AND not acceptance of fr if rel_type == _friend: - await _dispatch(user_id, 'RELATIONSHIP_ADD', { - 'id': str(peer_id), - 'type': RelationshipType.OUTGOING.value, - 'user': await app.storage.get_user(peer_id), - }) + await _dispatch( + user_id, + "RELATIONSHIP_ADD", + { + "id": str(peer_id), + "type": RelationshipType.OUTGOING.value, + "user": await app.storage.get_user(peer_id), + }, + ) - await _dispatch(peer_id, 'RELATIONSHIP_ADD', { - 'id': str(user_id), - 'type': RelationshipType.INCOMING.value, - 'user': await app.storage.get_user(user_id) - }) + await _dispatch( + peer_id, + "RELATIONSHIP_ADD", + { + "id": str(user_id), + "type": RelationshipType.INCOMING.value, + "user": await app.storage.get_user(user_id), + }, + ) # we don't make the pubsub link # until the peer accepts the friend request - return '', 204 + return "", 204 return class RelationshipFailed(BadRequest): """Exception for general relationship errors.""" + error_code = 80004 class RelationshipBlocked(BadRequest): """Exception for when the peer has blocked the user.""" + error_code = 80001 -@bp.route('/@me/relationships', methods=['POST']) +@bp.route("/@me/relationships", methods=["POST"]) async def post_relationship(): user_id = await token_check() j = validate(await request.get_json(), SPECIFIC_FRIEND) - uid = await app.storage.search_user(j['username'], - str(j['discriminator'])) + uid = await app.storage.search_user(j["username"], str(j["discriminator"])) if not uid: - raise RelationshipFailed('No users with DiscordTag exist') + raise RelationshipFailed("No users with DiscordTag exist") res = await make_friend(user_id, uid) if res is None: - raise RelationshipBlocked('Can not friend user due to block') + raise RelationshipBlocked("Can not friend user due to block") - return '', 204 + return "", 204 -@bp.route('/@me/relationships/', methods=['PUT']) +@bp.route("/@me/relationships/", methods=["PUT"]) async def add_relationship(peer_id: int): """Add a relationship to the peer.""" user_id = await token_check() payload = validate(await request.get_json(), RELATIONSHIP) - rel_type = payload['type'] + rel_type = payload["type"] res = await make_friend(user_id, peer_id, rel_type) @@ -204,18 +240,22 @@ async def add_relationship(peer_id: int): # make_friend did not succeed, so we # assume it is a block and dispatch # the respective RELATIONSHIP_ADD. - await app.dispatcher.dispatch_user(user_id, 'RELATIONSHIP_ADD', { - 'id': str(peer_id), - 'type': RelationshipType.BLOCK.value, - 'user': await app.storage.get_user(peer_id) - }) + await app.dispatcher.dispatch_user( + user_id, + "RELATIONSHIP_ADD", + { + "id": str(peer_id), + "type": RelationshipType.BLOCK.value, + "user": await app.storage.get_user(peer_id), + }, + ) await _unsub_friend(user_id, peer_id) - return '', 204 + return "", 204 -@bp.route('/@me/relationships/', methods=['DELETE']) +@bp.route("/@me/relationships/", methods=["DELETE"]) async def remove_relationship(peer_id: int): """Remove an existing relationship""" user_id = await token_check() @@ -223,69 +263,86 @@ async def remove_relationship(peer_id: int): _block = RelationshipType.BLOCK.value _dispatch = app.dispatcher.dispatch_user - rel_type = await app.db.fetchval(""" + rel_type = await app.db.fetchval( + """ SELECT rel_type FROM relationships WHERE user_id = $1 AND peer_id = $2 - """, user_id, peer_id) + """, + user_id, + peer_id, + ) - incoming_rel_type = await app.db.fetchval(""" + incoming_rel_type = await app.db.fetchval( + """ SELECT rel_type FROM relationships WHERE user_id = $1 AND peer_id = $2 - """, peer_id, user_id) + """, + peer_id, + user_id, + ) # if any of those are friend if _friend in (rel_type, incoming_rel_type): # closing the friendship, have to delete both rows - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM relationships WHERE ( (user_id = $1 AND peer_id = $2) OR (user_id = $2 AND peer_id = $1) ) AND rel_type = $3 - """, user_id, peer_id, _friend) + """, + user_id, + peer_id, + _friend, + ) # if there wasnt any mutual friendship before, # assume they were requests of INCOMING # and OUTGOING. - user_del_type = RelationshipType.OUTGOING.value if \ - incoming_rel_type != _friend else _friend + user_del_type = ( + RelationshipType.OUTGOING.value if incoming_rel_type != _friend else _friend + ) - await _dispatch(user_id, 'RELATIONSHIP_REMOVE', { - 'id': str(peer_id), - 'type': user_del_type, - }) + await _dispatch( + user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type} + ) - peer_del_type = RelationshipType.INCOMING.value if \ - incoming_rel_type != _friend else _friend + peer_del_type = ( + RelationshipType.INCOMING.value if incoming_rel_type != _friend else _friend + ) - await _dispatch(peer_id, 'RELATIONSHIP_REMOVE', { - 'id': str(user_id), - 'type': peer_del_type, - }) + await _dispatch( + peer_id, "RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type} + ) await _unsub_friend(user_id, peer_id) - return '', 204 + return "", 204 # was a block! - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM relationships WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3 - """, user_id, peer_id, _block) + """, + user_id, + peer_id, + _block, + ) - await _dispatch(user_id, 'RELATIONSHIP_REMOVE', { - 'id': str(peer_id), - 'type': _block, - }) + await _dispatch( + user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block} + ) await _unsub_friend(user_id, peer_id) - return '', 204 + return "", 204 -@bp.route('//relationships', methods=['GET']) +@bp.route("//relationships", methods=["GET"]) async def get_mutual_friends(peer_id: int): """Fetch a users' mutual friends with the current user.""" user_id = await token_check() @@ -294,17 +351,15 @@ async def get_mutual_friends(peer_id: int): peer = await app.storage.get_user(peer_id) if not peer: - return '', 204 + return "", 204 # NOTE: maybe this could be better with pure SQL calculations # but it would be beyond my current SQL knowledge, so... user_rels = await app.user_storage.get_relationships(user_id) peer_rels = await app.user_storage.get_relationships(peer_id) - user_friends = {rel['user']['id'] - for rel in user_rels if rel['type'] == _friend} - peer_friends = {rel['user']['id'] - for rel in peer_rels if rel['type'] == _friend} + user_friends = {rel["user"]["id"] for rel in user_rels if rel["type"] == _friend} + peer_friends = {rel["user"]["id"] for rel in peer_rels if rel["type"] == _friend} # get the intersection, then map them to Storage.get_user() calls mutual_ids = user_friends & peer_friends @@ -312,8 +367,6 @@ async def get_mutual_friends(peer_id: int): mutual_friends = [] for friend_id in mutual_ids: - mutual_friends.append( - await app.storage.get_user(int(friend_id)) - ) + mutual_friends.append(await app.storage.get_user(int(friend_id))) return jsonify(mutual_friends) diff --git a/litecord/blueprints/science.py b/litecord/blueprints/science.py index 80ffe15..7a5273c 100644 --- a/litecord/blueprints/science.py +++ b/litecord/blueprints/science.py @@ -19,21 +19,19 @@ along with this program. If not, see . from quart import Blueprint, jsonify -bp = Blueprint('science', __name__) +bp = Blueprint("science", __name__) -@bp.route('/science', methods=['POST']) +@bp.route("/science", methods=["POST"]) async def science(): - return '', 204 + return "", 204 -@bp.route('/applications', methods=['GET']) +@bp.route("/applications", methods=["GET"]) async def applications(): return jsonify([]) -@bp.route('/experiments', methods=['GET']) +@bp.route("/experiments", methods=["GET"]) async def experiments(): - return jsonify({ - 'assignments': [] - }) + return jsonify({"assignments": []}) diff --git a/litecord/blueprints/static.py b/litecord/blueprints/static.py index 5ce08df..61f2148 100644 --- a/litecord/blueprints/static.py +++ b/litecord/blueprints/static.py @@ -20,23 +20,24 @@ along with this program. If not, see . from quart import Blueprint, current_app as app, render_template_string from pathlib import Path -bp = Blueprint('static', __name__) +bp = Blueprint("static", __name__) -@bp.route('/') +@bp.route("/") async def static_pages(path): """Map requests from / to /static.""" - if '..' in path: - return 'no', 404 + if ".." in path: + return "no", 404 - static_path = Path.cwd() / Path('static') / path + static_path = Path.cwd() / Path("static") / path return await app.send_static_file(str(static_path)) -@bp.route('/') -@bp.route('/api') +@bp.route("/") +@bp.route("/api") async def index_handler(): """Handler for the index page.""" - index_path = Path.cwd() / Path('static') / 'index.html' + index_path = Path.cwd() / Path("static") / "index.html" return await render_template_string( - index_path.read_text(), inst_name=app.config['NAME']) + index_path.read_text(), inst_name=app.config["NAME"] + ) diff --git a/litecord/blueprints/user/__init__.py b/litecord/blueprints/user/__init__.py index 9e7b884..bd27fa4 100644 --- a/litecord/blueprints/user/__init__.py +++ b/litecord/blueprints/user/__init__.py @@ -21,4 +21,4 @@ from .billing import bp as user_billing from .settings import bp as user_settings from .fake_store import bp as fake_store -__all__ = ['user_billing', 'user_settings', 'fake_store'] +__all__ = ["user_billing", "user_settings", "fake_store"] diff --git a/litecord/blueprints/user/billing.py b/litecord/blueprints/user/billing.py index f37744e..a1a4148 100644 --- a/litecord/blueprints/user/billing.py +++ b/litecord/blueprints/user/billing.py @@ -33,7 +33,7 @@ from litecord.enums import UserFlags, PremiumType from litecord.blueprints.users import mass_user_update log = Logger(__name__) -bp = Blueprint('users_billing', __name__) +bp = Blueprint("users_billing", __name__) class PaymentSource(Enum): @@ -68,78 +68,87 @@ class PaymentStatus: PLAN_ID_TO_TYPE = { - 'premium_month_tier_1': PremiumType.TIER_1, - 'premium_month_tier_2': PremiumType.TIER_2, - 'premium_year_tier_1': PremiumType.TIER_1, - 'premium_year_tier_2': PremiumType.TIER_2, + "premium_month_tier_1": PremiumType.TIER_1, + "premium_month_tier_2": PremiumType.TIER_2, + "premium_year_tier_1": PremiumType.TIER_1, + "premium_year_tier_2": PremiumType.TIER_2, } # how much should a payment be, depending # of the subscription AMOUNTS = { - 'premium_month_tier_1': 499, - 'premium_month_tier_2': 999, - 'premium_year_tier_1': 4999, - 'premium_year_tier_2': 9999, + "premium_month_tier_1": 499, + "premium_month_tier_2": 999, + "premium_year_tier_1": 4999, + "premium_year_tier_2": 9999, } CREATE_SUBSCRIPTION = { - 'payment_gateway_plan_id': {'type': 'string'}, - 'payment_source_id': {'coerce': int} + "payment_gateway_plan_id": {"type": "string"}, + "payment_source_id": {"coerce": int}, } PAYMENT_SOURCE = { - 'billing_address': { - 'type': 'dict', - 'schema': { - 'country': {'type': 'string', 'required': True}, - 'city': {'type': 'string', 'required': True}, - 'name': {'type': 'string', 'required': True}, - 'line_1': {'type': 'string', 'required': False}, - 'line_2': {'type': 'string', 'required': False}, - 'postal_code': {'type': 'string', 'required': True}, - 'state': {'type': 'string', 'required': True}, - } + "billing_address": { + "type": "dict", + "schema": { + "country": {"type": "string", "required": True}, + "city": {"type": "string", "required": True}, + "name": {"type": "string", "required": True}, + "line_1": {"type": "string", "required": False}, + "line_2": {"type": "string", "required": False}, + "postal_code": {"type": "string", "required": True}, + "state": {"type": "string", "required": True}, + }, }, - 'payment_gateway': {'type': 'number', 'required': True}, - 'token': {'type': 'string', 'required': True}, + "payment_gateway": {"type": "number", "required": True}, + "token": {"type": "string", "required": True}, } async def get_payment_source_ids(user_id: int) -> list: - rows = await app.db.fetch(""" + rows = await app.db.fetch( + """ SELECT id FROM user_payment_sources WHERE user_id = $1 - """, user_id) + """, + user_id, + ) - return [r['id'] for r in rows] + return [r["id"] for r in rows] async def get_payment_ids(user_id: int, db=None) -> list: if not db: db = app.db - rows = await db.fetch(""" + rows = await db.fetch( + """ SELECT id FROM user_payments WHERE user_id = $1 - """, user_id) + """, + user_id, + ) - return [r['id'] for r in rows] + return [r["id"] for r in rows] async def get_subscription_ids(user_id: int) -> list: - rows = await app.db.fetch(""" + rows = await app.db.fetch( + """ SELECT id FROM user_subscriptions WHERE user_id = $1 - """, user_id) + """, + user_id, + ) - return [r['id'] for r in rows] + return [r["id"] for r in rows] async def get_payment_source(user_id: int, source_id: int, db=None) -> dict: @@ -148,41 +157,44 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict: if not db: db = app.db - source_type = await db.fetchval(""" + source_type = await db.fetchval( + """ SELECT source_type FROM user_payment_sources WHERE id = $1 AND user_id = $2 - """, source_id, user_id) + """, + source_id, + user_id, + ) source_type = PaymentSource(source_type) specific_fields = { - PaymentSource.PAYPAL: ['paypal_email'], - PaymentSource.CREDIT: ['expires_month', 'expires_year', - 'brand', 'cc_full'] + PaymentSource.PAYPAL: ["paypal_email"], + PaymentSource.CREDIT: ["expires_month", "expires_year", "brand", "cc_full"], }[source_type] - fields = ','.join(specific_fields) + fields = ",".join(specific_fields) - extras_row = await db.fetchrow(f""" + extras_row = await db.fetchrow( + f""" SELECT {fields}, billing_address, default_, id::text FROM user_payment_sources WHERE id = $1 - """, source_id) + """, + source_id, + ) derow = dict(extras_row) if source_type == PaymentSource.CREDIT: - derow['last_4'] = derow['cc_full'][-4:] - derow.pop('cc_full') + derow["last_4"] = derow["cc_full"][-4:] + derow.pop("cc_full") - derow['default'] = derow['default_'] - derow.pop('default_') + derow["default"] = derow["default_"] + derow.pop("default_") - source = { - 'id': str(source_id), - 'type': source_type.value, - } + source = {"id": str(source_id), "type": source_type.value} return {**source, **derow} @@ -192,7 +204,8 @@ async def get_subscription(subscription_id: int, db=None): if not db: db = app.db - row = await db.fetchrow(""" + row = await db.fetchrow( + """ SELECT id::text, source_id::text AS payment_source_id, user_id, payment_gateway, payment_gateway_plan_id, @@ -201,14 +214,16 @@ async def get_subscription(subscription_id: int, db=None): canceled_at, s_type, status FROM user_subscriptions WHERE id = $1 - """, subscription_id) + """, + subscription_id, + ) drow = dict(row) - drow['type'] = drow['s_type'] - drow.pop('s_type') + drow["type"] = drow["s_type"] + drow.pop("s_type") - to_tstamp = ['current_period_start', 'current_period_end', 'canceled_at'] + to_tstamp = ["current_period_start", "current_period_end", "canceled_at"] for field in to_tstamp: drow[field] = timestamp_(drow[field]) @@ -221,27 +236,30 @@ async def get_payment(payment_id: int, db=None): if not db: db = app.db - row = await db.fetchrow(""" + row = await db.fetchrow( + """ SELECT id::text, source_id, subscription_id, user_id, amount, amount_refunded, currency, description, status, tax, tax_inclusive FROM user_payments WHERE id = $1 - """, payment_id) + """, + payment_id, + ) drow = dict(row) - drow.pop('source_id') - drow.pop('subscription_id') - drow.pop('user_id') + drow.pop("source_id") + drow.pop("subscription_id") + drow.pop("user_id") - drow['created_at'] = snowflake_datetime(int(drow['id'])) + drow["created_at"] = snowflake_datetime(int(drow["id"])) - drow['payment_source'] = await get_payment_source( - row['user_id'], row['source_id'], db) + drow["payment_source"] = await get_payment_source( + row["user_id"], row["source_id"], db + ) - drow['subscription'] = await get_subscription( - row['subscription_id'], db) + drow["subscription"] = await get_subscription(row["subscription_id"], db) return drow @@ -255,7 +273,7 @@ async def create_payment(subscription_id, db=None): new_id = get_snowflake() - amount = AMOUNTS[sub['payment_gateway_plan_id']] + amount = AMOUNTS[sub["payment_gateway_plan_id"]] await db.execute( """ @@ -266,10 +284,16 @@ async def create_payment(subscription_id, db=None): ) VALUES ($1, $2, $3, $4, $5, 0, $6, $7, $8, 0, false) - """, new_id, int(sub['payment_source_id']), - subscription_id, int(sub['user_id']), - amount, 'usd', 'FUCK NITRO', - PaymentStatus.SUCCESS) + """, + new_id, + int(sub["payment_source_id"]), + subscription_id, + int(sub["user_id"]), + amount, + "usd", + "FUCK NITRO", + PaymentStatus.SUCCESS, + ) return new_id @@ -278,29 +302,34 @@ async def process_subscription(app, subscription_id: int): """Process a single subscription.""" sub = await get_subscription(subscription_id, app.db) - user_id = int(sub['user_id']) + user_id = int(sub["user_id"]) - if sub['status'] != SubscriptionStatus.ACTIVE: - log.debug('ignoring sub {}, not active', - subscription_id) + if sub["status"] != SubscriptionStatus.ACTIVE: + log.debug("ignoring sub {}, not active", subscription_id) return # if the subscription is still active # (should get cancelled status on failed # payments), then we should update premium status - first_payment_id = await app.db.fetchval(""" + first_payment_id = await app.db.fetchval( + """ SELECT MIN(id) FROM user_payments WHERE subscription_id = $1 - """, subscription_id) + """, + subscription_id, + ) first_payment_ts = snowflake_datetime(first_payment_id) - premium_since = await app.db.fetchval(""" + premium_since = await app.db.fetchval( + """ SELECT premium_since FROM users WHERE id = $1 - """, user_id) + """, + user_id, + ) premium_since = premium_since or datetime.datetime.fromtimestamp(0) @@ -312,27 +341,34 @@ async def process_subscription(app, subscription_id: int): if delta.total_seconds() < 24 * HOURS: return - old_flags = await app.db.fetchval(""" + old_flags = await app.db.fetchval( + """ SELECT flags FROM users WHERE id = $1 - """, user_id) + """, + user_id, + ) new_flags = old_flags | UserFlags.premium_early - log.debug('updating flags {}, {} => {}', - user_id, old_flags, new_flags) + log.debug("updating flags {}, {} => {}", user_id, old_flags, new_flags) - await app.db.execute(""" + await app.db.execute( + """ UPDATE users SET premium_since = $1, flags = $2 WHERE id = $3 - """, first_payment_ts, new_flags, user_id) + """, + first_payment_ts, + new_flags, + user_id, + ) # dispatch updated user to all possible clients await mass_user_update(user_id, app) -@bp.route('/@me/billing/payment-sources', methods=['GET']) +@bp.route("/@me/billing/payment-sources", methods=["GET"]) async def _get_billing_sources(): user_id = await token_check() source_ids = await get_payment_source_ids(user_id) @@ -346,7 +382,7 @@ async def _get_billing_sources(): return jsonify(res) -@bp.route('/@me/billing/subscriptions', methods=['GET']) +@bp.route("/@me/billing/subscriptions", methods=["GET"]) async def _get_billing_subscriptions(): user_id = await token_check() sub_ids = await get_subscription_ids(user_id) @@ -358,7 +394,7 @@ async def _get_billing_subscriptions(): return jsonify(res) -@bp.route('/@me/billing/payments', methods=['GET']) +@bp.route("/@me/billing/payments", methods=["GET"]) async def _get_billing_payments(): user_id = await token_check() payment_ids = await get_payment_ids(user_id) @@ -370,7 +406,7 @@ async def _get_billing_payments(): return jsonify(res) -@bp.route('/@me/billing/payment-sources', methods=['POST']) +@bp.route("/@me/billing/payment-sources", methods=["POST"]) async def _create_payment_source(): user_id = await token_check() j = validate(await request.get_json(), PAYMENT_SOURCE) @@ -383,34 +419,40 @@ async def _create_payment_source(): default_, expires_month, expires_year, brand, cc_full, billing_address) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - """, new_source_id, user_id, PaymentSource.CREDIT.value, - True, 12, 6969, 'Visa', '4242424242424242', - json.dumps(j['billing_address'])) - - return jsonify( - await get_payment_source(user_id, new_source_id) + """, + new_source_id, + user_id, + PaymentSource.CREDIT.value, + True, + 12, + 6969, + "Visa", + "4242424242424242", + json.dumps(j["billing_address"]), ) + return jsonify(await get_payment_source(user_id, new_source_id)) -@bp.route('/@me/billing/subscriptions', methods=['POST']) + +@bp.route("/@me/billing/subscriptions", methods=["POST"]) async def _create_subscription(): user_id = await token_check() j = validate(await request.get_json(), CREATE_SUBSCRIPTION) - source = await get_payment_source(user_id, j['payment_source_id']) + source = await get_payment_source(user_id, j["payment_source_id"]) if not source: - raise BadRequest('invalid source id') + raise BadRequest("invalid source id") - plan_id = j['payment_gateway_plan_id'] + plan_id = j["payment_gateway_plan_id"] # tier 1 is lightro / classic # tier 2 is nitro period_end = { - 'premium_month_tier_1': '1 month', - 'premium_month_tier_2': '1 month', - 'premium_year_tier_1': '1 year', - 'premium_year_tier_2': '1 year', + "premium_month_tier_1": "1 month", + "premium_month_tier_2": "1 month", + "premium_year_tier_1": "1 year", + "premium_year_tier_2": "1 year", }[plan_id] new_id = get_snowflake() @@ -422,9 +464,15 @@ async def _create_subscription(): status, period_end) VALUES ($1, $2, $3, $4, $5, $6, $7, now()::timestamp + interval '{period_end}') - """, new_id, j['payment_source_id'], user_id, - SubscriptionType.PURCHASE, PaymentGateway.STRIPE, - plan_id, 1) + """, + new_id, + j["payment_source_id"], + user_id, + SubscriptionType.PURCHASE, + PaymentGateway.STRIPE, + plan_id, + 1, + ) await create_payment(new_id, app.db) @@ -432,21 +480,17 @@ async def _create_subscription(): # and dispatch respective user updates to other people. await process_subscription(app, new_id) - return jsonify( - await get_subscription(new_id) - ) + return jsonify(await get_subscription(new_id)) -@bp.route('/@me/billing/subscriptions/', - methods=['DELETE']) +@bp.route("/@me/billing/subscriptions/", methods=["DELETE"]) async def _delete_subscription(subscription_id): # user_id = await token_check() # return '', 204 pass -@bp.route('/@me/billing/subscriptions/', - methods=['PATCH']) +@bp.route("/@me/billing/subscriptions/", methods=["PATCH"]) async def _patch_subscription(subscription_id): """change a subscription's payment source""" # user_id = await token_check() diff --git a/litecord/blueprints/user/billing_job.py b/litecord/blueprints/user/billing_job.py index b5e9f7f..4148415 100644 --- a/litecord/blueprints/user/billing_job.py +++ b/litecord/blueprints/user/billing_job.py @@ -25,8 +25,11 @@ from asyncio import sleep, CancelledError from logbook import Logger from litecord.blueprints.user.billing import ( - get_subscription, get_payment_ids, get_payment, create_payment, - process_subscription + get_subscription, + get_payment_ids, + get_payment, + create_payment, + process_subscription, ) from litecord.snowflake import snowflake_datetime @@ -37,15 +40,15 @@ log = Logger(__name__) # how many days until a payment needs # to be issued THRESHOLDS = { - 'premium_month_tier_1': 30, - 'premium_month_tier_2': 30, - 'premium_year_tier_1': 365, - 'premium_year_tier_2': 365, + "premium_month_tier_1": 30, + "premium_month_tier_2": 30, + "premium_year_tier_1": 365, + "premium_year_tier_2": 365, } async def _resched(app): - log.debug('waiting 30 minutes for job.') + log.debug("waiting 30 minutes for job.") await sleep(30 * MINUTES) app.sched.spawn(payment_job(app)) @@ -54,10 +57,10 @@ async def _process_user_payments(app, user_id: int): payments = await get_payment_ids(user_id, app.db) if not payments: - log.debug('no payments for uid {}, skipping', user_id) + log.debug("no payments for uid {}, skipping", user_id) return - log.debug('{} payments for uid {}', len(payments), user_id) + log.debug("{} payments for uid {}", len(payments), user_id) latest_payment = max(payments) @@ -66,33 +69,29 @@ async def _process_user_payments(app, user_id: int): # calculate the difference between this payment # and now. now = datetime.datetime.now() - payment_tstamp = snowflake_datetime(int(payment_data['id'])) + payment_tstamp = snowflake_datetime(int(payment_data["id"])) delta = now - payment_tstamp - sub_id = int(payment_data['subscription']['id']) - subscription = await get_subscription( - sub_id, app.db) + sub_id = int(payment_data["subscription"]["id"]) + subscription = await get_subscription(sub_id, app.db) # if the max payment is X days old, we create another. # X is 30 for monthly subscriptions of nitro, # X is 365 for yearly subscriptions of nitro - threshold = THRESHOLDS[subscription['payment_gateway_plan_id']] + threshold = THRESHOLDS[subscription["payment_gateway_plan_id"]] - log.debug('delta {} delta days {} threshold {}', - delta, delta.days, threshold) + log.debug("delta {} delta days {} threshold {}", delta, delta.days, threshold) if delta.days > threshold: - log.info('creating payment for sid={}', - sub_id) + log.info("creating payment for sid={}", sub_id) # create_payment does not call any Stripe # or BrainTree APIs at all, since we'll just # give it as free. await create_payment(sub_id, app.db) else: - log.debug('sid={}, missing {} days', - sub_id, threshold - delta.days) + log.debug("sid={}, missing {} days", sub_id, threshold - delta.days) async def payment_job(app): @@ -101,35 +100,39 @@ async def payment_job(app): This function will check through users' payments and add a new one once a month / year. """ - log.debug('payment job start!') + log.debug("payment job start!") - user_ids = await app.db.fetch(""" + user_ids = await app.db.fetch( + """ SELECT DISTINCT user_id FROM user_payments - """) + """ + ) - log.debug('working {} users', len(user_ids)) + log.debug("working {} users", len(user_ids)) # go through each user's payments for row in user_ids: - user_id = row['user_id'] + user_id = row["user_id"] try: await _process_user_payments(app, user_id) except Exception: - log.exception('error while processing user payments') + log.exception("error while processing user payments") - subscribers = await app.db.fetch(""" + subscribers = await app.db.fetch( + """ SELECT id FROM user_subscriptions - """) + """ + ) for row in subscribers: try: - await process_subscription(app, row['id']) + await process_subscription(app, row["id"]) except Exception: - log.exception('error while processing subscription') - log.debug('rescheduling..') + log.exception("error while processing subscription") + log.debug("rescheduling..") try: await _resched(app) except CancelledError: - log.info('cancelled while waiting for resched') + log.info("cancelled while waiting for resched") diff --git a/litecord/blueprints/user/fake_store.py b/litecord/blueprints/user/fake_store.py index 68c3dff..7f4fc3c 100644 --- a/litecord/blueprints/user/fake_store.py +++ b/litecord/blueprints/user/fake_store.py @@ -22,24 +22,26 @@ fake routes for discord store """ from quart import Blueprint, jsonify -bp = Blueprint('fake_store', __name__) +bp = Blueprint("fake_store", __name__) -@bp.route('/promotions') +@bp.route("/promotions") async def _get_promotions(): return jsonify([]) -@bp.route('/users/@me/library') +@bp.route("/users/@me/library") async def _get_library(): return jsonify([]) -@bp.route('/users/@me/feed/settings') +@bp.route("/users/@me/feed/settings") async def _get_feed_settings(): - return jsonify({ - 'subscribed_games': [], - 'subscribed_users': [], - 'unsubscribed_users': [], - 'unsubscribed_games': [], - }) + return jsonify( + { + "subscribed_games": [], + "subscribed_users": [], + "unsubscribed_users": [], + "unsubscribed_games": [], + } + ) diff --git a/litecord/blueprints/user/settings.py b/litecord/blueprints/user/settings.py index 2cc73dc..e64e27e 100644 --- a/litecord/blueprints/user/settings.py +++ b/litecord/blueprints/user/settings.py @@ -23,10 +23,10 @@ from litecord.auth import token_check from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS from litecord.blueprints.checks import guild_check -bp = Blueprint('users_settings', __name__) +bp = Blueprint("users_settings", __name__) -@bp.route('/@me/settings', methods=['GET']) +@bp.route("/@me/settings", methods=["GET"]) async def get_user_settings(): """Get the current user's settings.""" user_id = await token_check() @@ -34,7 +34,7 @@ async def get_user_settings(): return jsonify(settings) -@bp.route('/@me/settings', methods=['PATCH']) +@bp.route("/@me/settings", methods=["PATCH"]) async def patch_current_settings(): """Patch the users' current settings. @@ -47,19 +47,22 @@ async def patch_current_settings(): for key in j: val = j[key] - await app.storage.execute_with_json(f""" + await app.storage.execute_with_json( + f""" UPDATE user_settings SET {key}=$1 WHERE id = $2 - """, val, user_id) + """, + val, + user_id, + ) settings = await app.user_storage.get_user_settings(user_id) - await app.dispatcher.dispatch_user( - user_id, 'USER_SETTINGS_UPDATE', settings) + await app.dispatcher.dispatch_user(user_id, "USER_SETTINGS_UPDATE", settings) return jsonify(settings) -@bp.route('/@me/guilds//settings', methods=['PATCH']) +@bp.route("/@me/guilds//settings", methods=["PATCH"]) async def patch_guild_settings(guild_id: int): """Update the users' guild settings for a given guild. @@ -74,16 +77,21 @@ async def patch_guild_settings(guild_id: int): # will make sure they exist in the table. await app.user_storage.get_guild_settings_one(user_id, guild_id) - for field in (k for k in j.keys() if k != 'channel_overrides'): - await app.db.execute(f""" + for field in (k for k in j.keys() if k != "channel_overrides"): + await app.db.execute( + f""" UPDATE guild_settings SET {field} = $1 WHERE user_id = $2 AND guild_id = $3 - """, j[field], user_id, guild_id) + """, + j[field], + user_id, + guild_id, + ) chan_ids = await app.storage.get_channel_ids(guild_id) - for chandata in j.get('channel_overrides', {}).items(): + for chandata in j.get("channel_overrides", {}).items(): chan_id, chan_overrides = chandata chan_id = int(chan_id) @@ -92,7 +100,8 @@ async def patch_guild_settings(guild_id: int): continue for field in chan_overrides: - await app.db.execute(f""" + await app.db.execute( + f""" INSERT INTO guild_settings_channel_overrides (user_id, guild_id, channel_id, {field}) VALUES @@ -105,18 +114,21 @@ async def patch_guild_settings(guild_id: int): WHERE guild_settings_channel_overrides.user_id = $1 AND guild_settings_channel_overrides.guild_id = $2 AND guild_settings_channel_overrides.channel_id = $3 - """, user_id, guild_id, chan_id, chan_overrides[field]) + """, + user_id, + guild_id, + chan_id, + chan_overrides[field], + ) - settings = await app.user_storage.get_guild_settings_one( - user_id, guild_id) + settings = await app.user_storage.get_guild_settings_one(user_id, guild_id) - await app.dispatcher.dispatch_user( - user_id, 'USER_GUILD_SETTINGS_UPDATE', settings) + await app.dispatcher.dispatch_user(user_id, "USER_GUILD_SETTINGS_UPDATE", settings) return jsonify(settings) -@bp.route('/@me/notes/', methods=['PUT']) +@bp.route("/@me/notes/", methods=["PUT"]) async def put_note(target_id: int): """Put a note to a user. @@ -126,10 +138,11 @@ async def put_note(target_id: int): user_id = await token_check() j = await request.get_json() - note = str(j['note']) + note = str(j["note"]) # UPSERTs are beautiful - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO notes (user_id, target_id, note) VALUES ($1, $2, $3) @@ -138,12 +151,14 @@ async def put_note(target_id: int): note = $3 WHERE notes.user_id = $1 AND notes.target_id = $2 - """, user_id, target_id, note) + """, + user_id, + target_id, + note, + ) - await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', { - 'id': str(target_id), - 'note': note, - }) - - return '', 204 + await app.dispatcher.dispatch_user( + user_id, "USER_NOTE_UPDATE", {"id": str(target_id), "note": note} + ) + return "", 204 diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 64cc73e..d5bda02 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -27,9 +27,7 @@ from ..errors import Forbidden, BadRequest, Unauthorized from ..schemas import validate, USER_UPDATE, GET_MENTIONS from .guilds import guild_check -from litecord.auth import ( - token_check, hash_data, check_username_usage, roll_discrim -) +from litecord.auth import token_check, hash_data, check_username_usage, roll_discrim from litecord.blueprints.guild.mod import remove_member from litecord.enums import PremiumType @@ -39,7 +37,7 @@ from litecord.permissions import base_permissions from litecord.blueprints.auth import check_password from litecord.utils import to_update -bp = Blueprint('user', __name__) +bp = Blueprint("user", __name__) log = Logger(__name__) @@ -58,8 +56,7 @@ async def mass_user_update(user_id, app_=None): private_user = await app_.storage.get_user(user_id, secure=True) session_ids.extend( - await app_.dispatcher.dispatch_user( - user_id, 'USER_UPDATE', private_user) + await app_.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user) ) guild_ids = await app_.user_storage.get_user_guilds(user_id) @@ -67,26 +64,22 @@ async def mass_user_update(user_id, app_=None): session_ids.extend( await app_.dispatcher.dispatch_many_filter_list( - 'guild', guild_ids, session_ids, - 'USER_UPDATE', public_user + "guild", guild_ids, session_ids, "USER_UPDATE", public_user ) ) session_ids.extend( await app_.dispatcher.dispatch_many_filter_list( - 'friend', friend_ids, session_ids, - 'USER_UPDATE', public_user + "friend", friend_ids, session_ids, "USER_UPDATE", public_user ) ) - await app_.dispatcher.dispatch_many( - 'lazy_guild', guild_ids, 'update_user', user_id - ) + await app_.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id) return public_user, private_user -@bp.route('/@me', methods=['GET']) +@bp.route("/@me", methods=["GET"]) async def get_me(): """Get the current user's information.""" user_id = await token_check() @@ -94,18 +87,21 @@ async def get_me(): return jsonify(user) -@bp.route('/', methods=['GET']) +@bp.route("/", methods=["GET"]) async def get_other(target_id): """Get any user, given the user ID.""" user_id = await token_check() - bot = await app.db.fetchval(""" + bot = await app.db.fetchval( + """ SELECT bot FROM users WHERE users.id = $1 - """, user_id) + """, + user_id, + ) if not bot: - raise Forbidden('Only bots can use this endpoint') + raise Forbidden("Only bots can use this endpoint") other = await app.storage.get_user(target_id) return jsonify(other) @@ -116,66 +112,80 @@ async def _try_username_patch(user_id, new_username: str) -> str: discrim = None try: - await app.db.execute(""" + await app.db.execute( + """ UPDATE users SET username = $1 WHERE users.id = $2 - """, new_username, user_id) + """, + new_username, + user_id, + ) - return await app.db.fetchval(""" + return await app.db.fetchval( + """ SELECT discriminator FROM users WHERE users.id = $1 - """, user_id) + """, + user_id, + ) except UniqueViolationError: discrim = await roll_discrim(new_username) if not discrim: - raise BadRequest('Unable to change username', { - 'username': 'Too many people are with this username.' - }) + raise BadRequest( + "Unable to change username", + {"username": "Too many people are with this username."}, + ) - await app.db.execute(""" + await app.db.execute( + """ UPDATE users SET username = $1, discriminator = $2 WHERE users.id = $3 - """, new_username, discrim, user_id) + """, + new_username, + discrim, + user_id, + ) return discrim async def _try_discrim_patch(user_id, new_discrim: str): try: - await app.db.execute(""" + await app.db.execute( + """ UPDATE users SET discriminator = $1 WHERE id = $2 - """, new_discrim, user_id) + """, + new_discrim, + user_id, + ) except UniqueViolationError: - raise BadRequest('Invalid discriminator', { - 'discriminator': 'Someone already used this discriminator.' - }) + raise BadRequest( + "Invalid discriminator", + {"discriminator": "Someone already used this discriminator."}, + ) async def _check_pass(j, user): # Do not do password checks on unclaimed accounts - if user['email'] is None: + if user["email"] is None: return - if not j['password']: - raise BadRequest('password required', { - 'password': 'password required' - }) + if not j["password"]: + raise BadRequest("password required", {"password": "password required"}) - phash = user['password_hash'] + phash = user["password_hash"] - if not await check_password(phash, j['password']): - raise BadRequest('password incorrect', { - 'password': 'password does not match.' - }) + if not await check_password(phash, j["password"]): + raise BadRequest("password incorrect", {"password": "password does not match."}) -@bp.route('/@me', methods=['PATCH']) +@bp.route("/@me", methods=["PATCH"]) async def patch_me(): """Patch the current user's information.""" user_id = await token_check() @@ -183,36 +193,43 @@ async def patch_me(): j = validate(await request.get_json(), USER_UPDATE) user = await app.storage.get_user(user_id, True) - user['password_hash'] = await app.db.fetchval(""" + user["password_hash"] = await app.db.fetchval( + """ SELECT password_hash FROM users WHERE id = $1 - """, user_id) + """, + user_id, + ) - if to_update(j, user, 'username'): + if to_update(j, user, "username"): # this will take care of regenning a new discriminator - discrim = await _try_username_patch(user_id, j['username']) - user['username'] = j['username'] - user['discriminator'] = discrim + discrim = await _try_username_patch(user_id, j["username"]) + user["username"] = j["username"] + user["discriminator"] = discrim - if to_update(j, user, 'discriminator'): + if to_update(j, user, "discriminator"): # the API treats discriminators as integers, # but I work with strings on the database. - new_discrim = str(j['discriminator']) + new_discrim = str(j["discriminator"]) await _try_discrim_patch(user_id, new_discrim) - user['discriminator'] = new_discrim + user["discriminator"] = new_discrim - if to_update(j, user, 'email'): + if to_update(j, user, "email"): await _check_pass(j, user) # TODO: reverify the new email? - await app.db.execute(""" + await app.db.execute( + """ UPDATE users SET email = $1 WHERE id = $2 - """, j['email'], user_id) - user['email'] = j['email'] + """, + j["email"], + user_id, + ) + user["email"] = j["email"] # only update if values are different # from what the user gave. @@ -224,44 +241,49 @@ async def patch_me(): # IconManager.update will take care of validating # the value once put()-ing - if to_update(j, user, 'avatar'): - mime, _ = parse_data_uri(j['avatar']) + if to_update(j, user, "avatar"): + mime, _ = parse_data_uri(j["avatar"]) - if mime == 'image/gif' and user['premium_type'] == PremiumType.NONE: - raise BadRequest('no gif without nitro') + if mime == "image/gif" and user["premium_type"] == PremiumType.NONE: + raise BadRequest("no gif without nitro") - new_icon = await app.icons.update( - 'user', user_id, j['avatar'], size=(128, 128)) + new_icon = await app.icons.update("user", user_id, j["avatar"], size=(128, 128)) - await app.db.execute(""" + await app.db.execute( + """ UPDATE users SET avatar = $1 WHERE id = $2 - """, new_icon.icon_hash, user_id) + """, + new_icon.icon_hash, + user_id, + ) - if user['email'] is None and not 'new_password' in j: - raise BadRequest('missing password', { - 'password': 'Please set a password.' - }) + if user["email"] is None and not "new_password" in j: + raise BadRequest("missing password", {"password": "Please set a password."}) - if 'new_password' in j and j['new_password']: + if "new_password" in j and j["new_password"]: await _check_pass(j, user) - new_hash = await hash_data(j['new_password']) + new_hash = await hash_data(j["new_password"]) - await app.db.execute(""" + await app.db.execute( + """ UPDATE users SET password_hash = $1 WHERE id = $2 - """, new_hash, user_id) + """, + new_hash, + user_id, + ) - user.pop('password_hash') + user.pop("password_hash") _, private_user = await mass_user_update(user_id, app) return jsonify(private_user) -@bp.route('/@me/guilds', methods=['GET']) +@bp.route("/@me/guilds", methods=["GET"]) async def get_me_guilds(): """Get partial user guilds.""" user_id = await token_check() @@ -270,27 +292,30 @@ async def get_me_guilds(): partials = [] for guild_id in guild_ids: - partial = await app.db.fetchrow(""" + partial = await app.db.fetchrow( + """ SELECT id::text, name, icon, owner_id FROM guilds WHERE guilds.id = $1 - """, guild_id) + """, + guild_id, + ) partial = dict(partial) user_perms = await base_permissions(user_id, guild_id) - partial['permissions'] = user_perms.binary + partial["permissions"] = user_perms.binary - partial['owner'] = partial['owner_id'] == user_id + partial["owner"] = partial["owner_id"] == user_id - partial.pop('owner_id') + partial.pop("owner_id") partials.append(partial) return jsonify(partials) -@bp.route('/@me/guilds/', methods=['DELETE']) +@bp.route("/@me/guilds/", methods=["DELETE"]) async def leave_guild(guild_id: int): """Leave a guild.""" user_id = await token_check() @@ -298,7 +323,7 @@ async def leave_guild(guild_id: int): await remove_member(guild_id, user_id) - return '', 204 + return "", 204 # @bp.route('/@me/connections', methods=['GET']) @@ -306,7 +331,7 @@ async def get_connections(): pass -@bp.route('/@me/consent', methods=['GET', 'POST']) +@bp.route("/@me/consent", methods=["GET", "POST"]) async def get_consent(): """Always disable data collection. @@ -314,57 +339,58 @@ async def get_consent(): by the client and ignores them, as they will always be false. """ - return jsonify({ - 'usage_statistics': { - 'consented': False, - }, - 'personalization': { - 'consented': False, + return jsonify( + { + "usage_statistics": {"consented": False}, + "personalization": {"consented": False}, } - }) + ) -@bp.route('/@me/harvest', methods=['GET']) +@bp.route("/@me/harvest", methods=["GET"]) async def get_harvest(): """Dummy route""" - return '', 204 + return "", 204 -@bp.route('/@me/activities/statistics/applications', methods=['GET']) +@bp.route("/@me/activities/statistics/applications", methods=["GET"]) async def get_stats_applications(): """Dummy route for info on gameplay time and such""" return jsonify([]) -@bp.route('/@me/library', methods=['GET']) +@bp.route("/@me/library", methods=["GET"]) async def get_library(): """Probably related to Discord Store?""" return jsonify([]) -@bp.route('//profile', methods=['GET']) +@bp.route("//profile", methods=["GET"]) async def get_profile(peer_id: int): """Get a user's profile.""" user_id = await token_check() peer = await app.storage.get_user(peer_id) if not peer: - return '', 404 + return "", 404 mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id) friends = await app.user_storage.are_friends_with(user_id, peer_id) # don't return a proper card if no guilds are being shared. if not mutuals and not friends: - return '', 404 + return "", 404 # actual premium status is determined by that # column being NULL or not - peer_premium = await app.db.fetchval(""" + peer_premium = await app.db.fetchval( + """ SELECT premium_since FROM users WHERE id = $1 - """, peer_id) + """, + peer_id, + ) mutual_guilds = await app.user_storage.get_mutual_guilds(user_id, peer_id) mutual_res = [] @@ -372,45 +398,49 @@ async def get_profile(peer_id: int): # ascending sorting for guild_id in sorted(mutual_guilds): - nick = await app.db.fetchval(""" + nick = await app.db.fetchval( + """ SELECT nickname FROM members WHERE guild_id = $1 AND user_id = $2 - """, guild_id, peer_id) + """, + guild_id, + peer_id, + ) - mutual_res.append({ - 'id': str(guild_id), - 'nick': nick, - }) + mutual_res.append({"id": str(guild_id), "nick": nick}) - return jsonify({ - 'user': peer, - 'connected_accounts': [], - 'premium_since': peer_premium, - 'mutual_guilds': mutual_res, - }) + return jsonify( + { + "user": peer, + "connected_accounts": [], + "premium_since": peer_premium, + "mutual_guilds": mutual_res, + } + ) -@bp.route('/@me/mentions', methods=['GET']) +@bp.route("/@me/mentions", methods=["GET"]) async def _get_mentions(): user_id = await token_check() j = validate(dict(request.args), GET_MENTIONS) - guild_query = 'AND messages.guild_id = $2' if 'guild_id' in j else '' - role_query = "OR content LIKE '%<@&%'" if j['roles'] else '' - everyone_query = "OR content LIKE '%@everyone%'" if j['everyone'] else '' - mention_user = f'<@{user_id}>' + guild_query = "AND messages.guild_id = $2" if "guild_id" in j else "" + role_query = "OR content LIKE '%<@&%'" if j["roles"] else "" + everyone_query = "OR content LIKE '%@everyone%'" if j["everyone"] else "" + mention_user = f"<@{user_id}>" args = [mention_user] if guild_query: - args.append(j['guild_id']) + args.append(j["guild_id"]) guild_ids = await app.user_storage.get_user_guilds(user_id) - gids = ','.join(str(guild_id) for guild_id in guild_ids) + gids = ",".join(str(guild_id) for guild_id in guild_ids) - rows = await app.db.fetch(f""" + rows = await app.db.fetch( + f""" SELECT messages.id FROM messages JOIN channels ON messages.channel_id = channels.id @@ -423,20 +453,20 @@ async def _get_mentions(): {guild_query} ) LIMIT {j["limit"]} - """, *args) + """, + *args, + ) res = [] for row in rows: - message = await app.storage.get_message(row['id']) - gid = int(message['guild_id']) + message = await app.storage.get_message(row["id"]) + gid = int(message["guild_id"]) # ignore messages pre-messages.guild_id if gid not in guild_ids: continue - res.append( - message - ) + res.append(message) return jsonify(res) @@ -449,18 +479,20 @@ def rand_hex(length: int = 8) -> str: async def _del_from_table(db, table: str, user_id: int): """Delete a row from a table.""" column = { - 'channel_overwrites': 'target_user', - 'user_settings': 'id', - 'group_dm_members': 'member_id' - }.get(table, 'user_id') + "channel_overwrites": "target_user", + "user_settings": "id", + "group_dm_members": "member_id", + }.get(table, "user_id") - res = await db.execute(f""" + res = await db.execute( + f""" DELETE FROM {table} WHERE {column} = $1 - """, user_id) + """, + user_id, + ) - log.info('Deleting uid {} from {}, res: {!r}', - user_id, table, res) + log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res) async def delete_user(user_id, *, app_=None): @@ -470,13 +502,14 @@ async def delete_user(user_id, *, app_=None): db = app_.db - new_username = f'Deleted User {rand_hex()}' + new_username = f"Deleted User {rand_hex()}" # by using a random hex in password_hash # we break attempts at using the default '123' password hash # to issue valid tokens for deleted users. - await db.execute(""" + await db.execute( + """ UPDATE users SET username = $1, @@ -490,32 +523,39 @@ async def delete_user(user_id, *, app_=None): password_hash = $2 WHERE id = $3 - """, new_username, rand_hex(32), user_id) + """, + new_username, + rand_hex(32), + user_id, + ) # remove the user from various tables - await _del_from_table(db, 'user_settings', user_id) - await _del_from_table(db, 'user_payment_sources', user_id) - await _del_from_table(db, 'user_subscriptions', user_id) - await _del_from_table(db, 'user_payments', user_id) - await _del_from_table(db, 'user_read_state', user_id) - await _del_from_table(db, 'guild_settings', user_id) - await _del_from_table(db, 'guild_settings_channel_overrides', user_id) + await _del_from_table(db, "user_settings", user_id) + await _del_from_table(db, "user_payment_sources", user_id) + await _del_from_table(db, "user_subscriptions", user_id) + await _del_from_table(db, "user_payments", user_id) + await _del_from_table(db, "user_read_state", user_id) + await _del_from_table(db, "guild_settings", user_id) + await _del_from_table(db, "guild_settings_channel_overrides", user_id) - await db.execute(""" + await db.execute( + """ DELETE FROM relationships WHERE user_id = $1 OR peer_id = $1 - """, user_id) + """, + user_id, + ) # DMs are still maintained, but not the state. - await _del_from_table(db, 'dm_channel_state', user_id) + await _del_from_table(db, "dm_channel_state", user_id) # NOTE: we don't delete the group dms the user is an owner of... # TODO: group dm owner reassign when the owner leaves a gdm - await _del_from_table(db, 'group_dm_members', user_id) + await _del_from_table(db, "group_dm_members", user_id) - await _del_from_table(db, 'members', user_id) - await _del_from_table(db, 'member_roles', user_id) - await _del_from_table(db, 'channel_overwrites', user_id) + await _del_from_table(db, "members", user_id) + await _del_from_table(db, "member_roles", user_id) + await _del_from_table(db, "channel_overwrites", user_id) # after updating the user, we send USER_UPDATE so that all the other # clients can refresh their caches on the now-deleted user @@ -540,15 +580,12 @@ async def user_disconnect(user_id: int): await state.ws.ws.close(4000) # force everyone to see the user as offline - await app.presence.dispatch_pres(user_id, { - 'afk': False, - 'status': 'offline', - 'game': None, - 'since': 0, - }) + await app.presence.dispatch_pres( + user_id, {"afk": False, "status": "offline", "game": None, "since": 0} + ) -@bp.route('/@me/delete', methods=['POST']) +@bp.route("/@me/delete", methods=["POST"]) async def delete_account(): """Delete own account. @@ -560,29 +597,35 @@ async def delete_account(): j = await request.get_json() try: - password = j['password'] + password = j["password"] except KeyError: - raise BadRequest('password required') + raise BadRequest("password required") - owned_guilds = await app.db.fetchval(""" + owned_guilds = await app.db.fetchval( + """ SELECT COUNT(*) FROM guilds WHERE owner_id = $1 - """, user_id) + """, + user_id, + ) if owned_guilds > 0: - raise BadRequest('You still own guilds.') + raise BadRequest("You still own guilds.") - pwd_hash = await app.db.fetchval(""" + pwd_hash = await app.db.fetchval( + """ SELECT password_hash FROM users WHERE id = $1 - """, user_id) + """, + user_id, + ) if not await check_password(pwd_hash, password): - raise Unauthorized('password does not match') + raise Unauthorized("password does not match") await delete_user(user_id) await user_disconnect(user_id) - return '', 204 + return "", 204 diff --git a/litecord/blueprints/voice.py b/litecord/blueprints/voice.py index a06eec1..7011211 100644 --- a/litecord/blueprints/voice.py +++ b/litecord/blueprints/voice.py @@ -25,7 +25,7 @@ from quart import Blueprint, jsonify, current_app as app from litecord.blueprints.auth import token_check -bp = Blueprint('voice', __name__) +bp = Blueprint("voice", __name__) def _majority_region_count(regions: list) -> str: @@ -39,12 +39,14 @@ def _majority_region_count(regions: list) -> str: async def _choose_random_region() -> Optional[str]: """Give a random voice region.""" - regions = await app.db.fetch(""" + regions = await app.db.fetch( + """ SELECT id FROM voice_regions - """) + """ + ) - regions = [r['id'] for r in regions] + regions = [r["id"] for r in regions] if not regions: return None @@ -64,11 +66,14 @@ async def _majority_region_any(user_id) -> Optional[str]: res = [] for guild_id in guilds: - region = await app.db.fetchval(""" + region = await app.db.fetchval( + """ SELECT region FROM guilds WHERE id = $1 - """, guild_id) + """, + guild_id, + ) res.append(region) @@ -83,20 +88,23 @@ async def _majority_region_any(user_id) -> Optional[str]: async def majority_region(user_id: int) -> Optional[str]: """Given a user ID, give the most likely region for the user to be happy with.""" - regions = await app.db.fetch(""" + regions = await app.db.fetch( + """ SELECT region FROM guilds WHERE owner_id = $1 - """, user_id) + """, + user_id, + ) if not regions: return await _majority_region_any(user_id) - regions = [r['region'] for r in regions] + regions = [r["region"] for r in regions] return _majority_region_count(regions) -@bp.route('/regions', methods=['GET']) +@bp.route("/regions", methods=["GET"]) async def voice_regions(): """Return voice regions.""" user_id = await token_check() @@ -105,6 +113,6 @@ async def voice_regions(): regions = await app.storage.all_voice_regions() for region in regions: - region['optimal'] = region['id'] == best_region + region["optimal"] = region["id"] == best_region return jsonify(regions) diff --git a/litecord/blueprints/webhooks.py b/litecord/blueprints/webhooks.py index 08712d3..9a0be5f 100644 --- a/litecord/blueprints/webhooks.py +++ b/litecord/blueprints/webhooks.py @@ -26,22 +26,28 @@ from quart import Blueprint, jsonify, current_app as app, request from litecord.auth import token_check from litecord.blueprints.checks import ( - channel_check, channel_perm_check, guild_check, guild_perm_check + channel_check, + channel_perm_check, + guild_check, + guild_perm_check, ) from litecord.schemas import ( - validate, WEBHOOK_CREATE, WEBHOOK_UPDATE, WEBHOOK_MESSAGE_CREATE + validate, + WEBHOOK_CREATE, + WEBHOOK_UPDATE, + WEBHOOK_MESSAGE_CREATE, ) from litecord.enums import ChannelType from litecord.snowflake import get_snowflake from litecord.utils import async_map -from litecord.errors import ( - WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest -) +from litecord.errors import WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest from litecord.blueprints.channel.messages import ( - msg_create_request, msg_create_check_content, msg_add_attachment, - msg_guild_text_mentions + msg_create_request, + msg_create_check_content, + msg_add_attachment, + msg_guild_text_mentions, ) from litecord.embed.sanitizer import fill_embed, fetch_raw_img from litecord.embed.messages import process_url_embed, is_media_url @@ -50,30 +56,34 @@ from litecord.utils import pg_set_json from litecord.enums import MessageType from litecord.images import STATIC_IMAGE_MIMES -bp = Blueprint('webhooks', __name__) +bp = Blueprint("webhooks", __name__) -async def get_webhook(webhook_id: int, *, - secure: bool=True) -> Optional[Dict[str, Any]]: +async def get_webhook( + webhook_id: int, *, secure: bool = True +) -> Optional[Dict[str, Any]]: """Get a webhook data""" - row = await app.db.fetchrow(""" + row = await app.db.fetchrow( + """ SELECT id::text, guild_id::text, channel_id::text, creator_id, name, avatar, token FROM webhooks WHERE id = $1 - """, webhook_id) + """, + webhook_id, + ) if not row: return None drow = dict(row) - drow['user'] = await app.storage.get_user(row['creator_id']) - drow.pop('creator_id') + drow["user"] = await app.storage.get_user(row["creator_id"]) + drow.pop("creator_id") if not secure: - drow.pop('user') - drow.pop('guild_id') + drow.pop("user") + drow.pop("guild_id") return drow @@ -82,7 +92,7 @@ async def _webhook_check(channel_id): user_id = await token_check() await channel_check(user_id, channel_id, only=ChannelType.GUILD_TEXT) - await channel_perm_check(user_id, channel_id, 'manage_webhooks') + await channel_perm_check(user_id, channel_id, "manage_webhooks") return user_id @@ -91,17 +101,20 @@ async def _webhook_check_guild(guild_id): user_id = await token_check() await guild_check(user_id, guild_id) - await guild_perm_check(user_id, guild_id, 'manage_webhooks') + await guild_perm_check(user_id, guild_id, "manage_webhooks") return user_id async def _webhook_check_fw(webhook_id): """Make a check from an incoming webhook id (fw = from webhook).""" - guild_id = await app.db.fetchval(""" + guild_id = await app.db.fetchval( + """ SELECT guild_id FROM webhooks WHERE id = $1 - """, webhook_id) + """, + webhook_id, + ) if guild_id is None: raise WebhookNotFound() @@ -110,42 +123,48 @@ async def _webhook_check_fw(webhook_id): async def _webhook_many(where_clause, arg: int): - webhook_ids = await app.db.fetch(f""" + webhook_ids = await app.db.fetch( + f""" SELECT id FROM webhooks {where_clause} - """, arg) - - webhook_ids = [r['id'] for r in webhook_ids] - - return jsonify( - await async_map(get_webhook, webhook_ids) + """, + arg, ) + webhook_ids = [r["id"] for r in webhook_ids] + + return jsonify(await async_map(get_webhook, webhook_ids)) + async def webhook_token_check(webhook_id: int, webhook_token: str): """token_check() equivalent for webhooks.""" - row = await app.db.fetchrow(""" + row = await app.db.fetchrow( + """ SELECT guild_id, channel_id FROM webhooks WHERE id = $1 AND token = $2 - """, webhook_id, webhook_token) + """, + webhook_id, + webhook_token, + ) if row is None: - raise Unauthorized('webhook not found or unauthorized') + raise Unauthorized("webhook not found or unauthorized") - return row['guild_id'], row['channel_id'] + return row["guild_id"], row["channel_id"] async def _dispatch_webhook_update(guild_id: int, channel_id): - await app.dispatcher.dispatch('guild', guild_id, 'WEBHOOKS_UPDATE', { - 'guild_id': str(guild_id), - 'channel_id': str(channel_id) - }) + await app.dispatcher.dispatch( + "guild", + guild_id, + "WEBHOOKS_UPDATE", + {"guild_id": str(guild_id), "channel_id": str(channel_id)}, + ) - -@bp.route('/channels//webhooks', methods=['POST']) +@bp.route("/channels//webhooks", methods=["POST"]) async def create_webhook(channel_id: int): """Create a webhook given a channel.""" user_id = await _webhook_check(channel_id) @@ -162,8 +181,7 @@ async def create_webhook(channel_id: int): token = secrets.token_urlsafe(40) webhook_icon = await app.icons.put( - 'user', webhook_id, j.get('avatar'), - always_icon=True, size=(128, 128) + "user", webhook_id, j.get("avatar"), always_icon=True, size=(128, 128) ) await app.db.execute( @@ -173,36 +191,41 @@ async def create_webhook(channel_id: int): VALUES ($1, $2, $3, $4, $5, $6, $7) """, - webhook_id, guild_id, channel_id, user_id, - j['name'], webhook_icon.icon_hash, token + webhook_id, + guild_id, + channel_id, + user_id, + j["name"], + webhook_icon.icon_hash, + token, ) await _dispatch_webhook_update(guild_id, channel_id) return jsonify(await get_webhook(webhook_id)) -@bp.route('/channels//webhooks', methods=['GET']) +@bp.route("/channels//webhooks", methods=["GET"]) async def get_channel_webhook(channel_id: int): """Get a list of webhooks in a channel""" await _webhook_check(channel_id) - return await _webhook_many('WHERE channel_id = $1', channel_id) + return await _webhook_many("WHERE channel_id = $1", channel_id) -@bp.route('/guilds//webhooks', methods=['GET']) +@bp.route("/guilds//webhooks", methods=["GET"]) async def get_guild_webhook(guild_id): """Get all webhooks in a guild""" await _webhook_check_guild(guild_id) - return await _webhook_many('WHERE guild_id = $1', guild_id) + return await _webhook_many("WHERE guild_id = $1", guild_id) -@bp.route('/webhooks/', methods=['GET']) +@bp.route("/webhooks/", methods=["GET"]) async def get_single_webhook(webhook_id): """Get a single webhook's information.""" await _webhook_check_fw(webhook_id) return await jsonify(await get_webhook(webhook_id)) -@bp.route('/webhooks//', methods=['GET']) +@bp.route("/webhooks//", methods=["GET"]) async def get_tokened_webhook(webhook_id, webhook_token): """Get a webhook using its token.""" await webhook_token_check(webhook_id, webhook_token) @@ -210,46 +233,58 @@ async def get_tokened_webhook(webhook_id, webhook_token): async def _update_webhook(webhook_id: int, j: dict): - if 'name' in j: - await app.db.execute(""" + if "name" in j: + await app.db.execute( + """ UPDATE webhooks SET name = $1 WHERE id = $2 - """, j['name'], webhook_id) + """, + j["name"], + webhook_id, + ) - if 'channel_id' in j: - await app.db.execute(""" + if "channel_id" in j: + await app.db.execute( + """ UPDATE webhooks SET channel_id = $1 WHERE id = $2 - """, j['channel_id'], webhook_id) - - if 'avatar' in j: - new_icon = await app.icons.update( - 'user', webhook_id, j['avatar'], always_icon=True, size=(128, 128) + """, + j["channel_id"], + webhook_id, ) - await app.db.execute(""" + if "avatar" in j: + new_icon = await app.icons.update( + "user", webhook_id, j["avatar"], always_icon=True, size=(128, 128) + ) + + await app.db.execute( + """ UPDATE webhooks SET icon = $1 WHERE id = $2 - """, new_icon.icon_hash, webhook_id) + """, + new_icon.icon_hash, + webhook_id, + ) -@bp.route('/webhooks/', methods=['PATCH']) +@bp.route("/webhooks/", methods=["PATCH"]) async def modify_webhook(webhook_id: int): """Patch a webhook.""" _user_id, guild_id = await _webhook_check_fw(webhook_id) j = validate(await request.get_json(), WEBHOOK_UPDATE) - if 'channel_id' in j: + if "channel_id" in j: # pre checks - chan = await app.storage.get_channel(j['channel_id']) + chan = await app.storage.get_channel(j["channel_id"]) # short-circuiting should ensure chan isn't none # by the time we do chan['guild_id'] - if chan and chan['guild_id'] != str(guild_id): - raise ChannelNotFound('cant assign webhook to channel') + if chan and chan["guild_id"] != str(guild_id): + raise ChannelNotFound("cant assign webhook to channel") await _update_webhook(webhook_id, j) @@ -257,20 +292,18 @@ async def modify_webhook(webhook_id: int): # we don't need to cast channel_id to int since that isn't # used in the dispatcher call - await _dispatch_webhook_update(guild_id, webhook['channel_id']) + await _dispatch_webhook_update(guild_id, webhook["channel_id"]) return jsonify(webhook) -@bp.route('/webhooks//', methods=['PATCH']) +@bp.route("/webhooks//", methods=["PATCH"]) async def modify_webhook_tokened(webhook_id, webhook_token): """Modify a webhook, using its token.""" - guild_id, channel_id = await webhook_token_check( - webhook_id, webhook_token) + guild_id, channel_id = await webhook_token_check(webhook_id, webhook_token) # forcefully pop() the channel id out of the schema # instead of making another, for simplicity's sake - j = validate(await request.get_json(), - WEBHOOK_UPDATE.pop('channel_id')) + j = validate(await request.get_json(), WEBHOOK_UPDATE.pop("channel_id")) await _update_webhook(webhook_id, j) await _dispatch_webhook_update(guild_id, channel_id) @@ -281,35 +314,36 @@ async def delete_webhook(webhook_id: int): """Delete a webhook.""" webhook = await get_webhook(webhook_id) - res = await app.db.execute(""" + res = await app.db.execute( + """ DELETE FROM webhooks WHERE id = $1 - """, webhook_id) + """, + webhook_id, + ) - if res.lower() == 'delete 0': + if res.lower() == "delete 0": raise WebhookNotFound() # only casting the guild id since that's whats used # on the dispatcher call. - await _dispatch_webhook_update( - int(webhook['guild_id']), webhook['channel_id'] - ) + await _dispatch_webhook_update(int(webhook["guild_id"]), webhook["channel_id"]) -@bp.route('/webhooks/', methods=['DELETE']) +@bp.route("/webhooks/", methods=["DELETE"]) async def del_webhook(webhook_id): """Delete a webhook.""" await _webhook_check_fw(webhook_id) await delete_webhook(webhook_id) - return '', 204 + return "", 204 -@bp.route('/webhooks//', methods=['DELETE']) +@bp.route("/webhooks//", methods=["DELETE"]) async def del_webhook_tokened(webhook_id, webhook_token): """Delete a webhook, with its token.""" await webhook_token_check(webhook_id, webhook_token) await delete_webhook(webhook_id) - return '', 204 + return "", 204 async def create_message_webhook(guild_id, channel_id, webhook_id, data): @@ -328,23 +362,27 @@ async def create_message_webhook(guild_id, channel_id, webhook_id, data): message_id, channel_id, guild_id, - data['content'], - - data['tts'], - data['everyone_mention'], - + data["content"], + data["tts"], + data["everyone_mention"], MessageType.DEFAULT.value, - data.get('embeds', []) + data.get("embeds", []), ) - info = data['info'] + info = data["info"] - await conn.execute(""" + await conn.execute( + """ INSERT INTO message_webhook_info (message_id, webhook_id, name, avatar) VALUES ($1, $2, $3, $4) - """, message_id, webhook_id, info['name'], info['avatar']) + """, + message_id, + webhook_id, + info["name"], + info["avatar"], + ) return message_id @@ -354,10 +392,15 @@ async def _webhook_avy_redir(webhook_id: int, avatar_url: EmbedURL): url_hash = hashlib.sha256(avatar_url.to_md_path.encode()).hexdigest() try: - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO webhook_avatars (webhook_id, hash, md_url_redir) VALUES ($1, $2, $3) - """, webhook_id, url_hash, avatar_url.url) + """, + webhook_id, + url_hash, + avatar_url.url, + ) except asyncpg.UniqueViolationError: pass @@ -371,36 +414,36 @@ async def _create_avatar(webhook_id: int, avatar_url: EmbedURL) -> str: Litecord will write an URL that redirects to the given avatar_url, using mediaproxy. """ - if avatar_url.scheme not in ('http', 'https'): - raise BadRequest('invalid avatar url scheme') + if avatar_url.scheme not in ("http", "https"): + raise BadRequest("invalid avatar url scheme") if not is_media_url(avatar_url): - raise BadRequest('url is not media url') + raise BadRequest("url is not media url") # we still fetch the URL to check its validity, mimetypes, etc # but in the end, we will store it under the webhook_avatars table, # not IconManager. resp, raw = await fetch_raw_img(avatar_url) - #raw_b64 = base64.b64encode(raw).decode() + # raw_b64 = base64.b64encode(raw).decode() - mime = resp.headers['content-type'] + mime = resp.headers["content-type"] # TODO: apng checks are missing (for this and everywhere else) if mime not in STATIC_IMAGE_MIMES: - raise BadRequest('invalid mime type for given url') + raise BadRequest("invalid mime type for given url") - #b64_data = f'data:{mime};base64,{raw_b64}' + # b64_data = f'data:{mime};base64,{raw_b64}' # TODO: replace this by webhook_avatars - #icon = await app.icons.put( + # icon = await app.icons.put( # 'user', webhook_id, b64_data, # always_icon=True, size=(128, 128) - #) + # ) return await _webhook_avy_redir(webhook_id, avatar_url) -@bp.route('/webhooks//', methods=['POST']) +@bp.route("/webhooks//", methods=["POST"]) async def execute_webhook(webhook_id: int, webhook_token): """Execute a webhook. Sends a message to the channel the webhook is tied to.""" @@ -413,41 +456,39 @@ async def execute_webhook(webhook_id: int, webhook_token): # NOTE: we really pop here instead of adding a kwarg # to msg_create_request just because of webhooks. # nonce isn't allowed on WEBHOOK_MESSAGE_CREATE - payload_json.pop('nonce') + payload_json.pop("nonce") j = validate(payload_json, WEBHOOK_MESSAGE_CREATE) msg_create_check_content(j, files) # webhooks don't need permissions. - mentions_everyone = '@everyone' in j['content'] - mentions_here = '@here' in j['content'] + mentions_everyone = "@everyone" in j["content"] + mentions_here = "@here" in j["content"] - given_embeds = j.get('embeds', []) + given_embeds = j.get("embeds", []) webhook = await get_webhook(webhook_id) # webhooks have TWO avatars. one is from settings, the other is from # the json's icon_url. one can be handled gracefully by IconManager, # but the other can't, at all. - avatar = webhook['avatar'] + avatar = webhook["avatar"] - if 'avatar_url' in j and j['avatar_url'] is not None: - avatar = await _create_avatar(webhook_id, j['avatar_url']) + if "avatar_url" in j and j["avatar_url"] is not None: + avatar = await _create_avatar(webhook_id, j["avatar_url"]) message_id = await create_message_webhook( - guild_id, channel_id, webhook_id, { - 'content': j.get('content', ''), - 'tts': j.get('tts', False), - - 'everyone_mention': mentions_everyone or mentions_here, - 'embeds': await async_map(fill_embed, given_embeds), - - 'info': { - 'name': j.get('username', webhook['name']), - 'avatar': avatar - } - } + guild_id, + channel_id, + webhook_id, + { + "content": j.get("content", ""), + "tts": j.get("tts", False), + "everyone_mention": mentions_everyone or mentions_here, + "embeds": await async_map(fill_embed, given_embeds), + "info": {"name": j.get("username", webhook["name"]), "avatar": avatar}, + }, ) for pre_attachment in files: @@ -455,33 +496,28 @@ async def execute_webhook(webhook_id: int, webhook_token): payload = await app.storage.get_message(message_id) - await app.dispatcher.dispatch('channel', channel_id, - 'MESSAGE_CREATE', payload) + await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload) # spawn embedder in the background, even when we're on a webhook. app.sched.spawn( - process_url_embed( - app.config, app.storage, app.dispatcher, app.session, - payload - ) + process_url_embed(app.config, app.storage, app.dispatcher, app.session, payload) ) # we can assume its a guild text channel, so just call it - await msg_guild_text_mentions( - payload, guild_id, mentions_everyone, mentions_here) + await msg_guild_text_mentions(payload, guild_id, mentions_everyone, mentions_here) # TODO: is it really 204? - return '', 204 + return "", 204 -@bp.route('/webhooks///slack', - methods=['POST']) + +@bp.route("/webhooks///slack", methods=["POST"]) async def execute_slack_webhook(webhook_id, webhook_token): """Execute a webhook but expecting Slack data.""" # TODO: know slack webhooks await webhook_token_check(webhook_id, webhook_token) -@bp.route('/webhooks///github', methods=['POST']) +@bp.route("/webhooks///github", methods=["POST"]) async def execute_github_webhook(webhook_id, webhook_token): """Execute a webhook but expecting GitHub data.""" # TODO: know github webhooks diff --git a/litecord/dispatcher.py b/litecord/dispatcher.py index 96f29d3..4028757 100644 --- a/litecord/dispatcher.py +++ b/litecord/dispatcher.py @@ -21,9 +21,14 @@ from typing import List, Any, Dict from logbook import Logger -from .pubsub import GuildDispatcher, MemberDispatcher, \ - UserDispatcher, ChannelDispatcher, FriendDispatcher, \ - LazyGuildDispatcher +from .pubsub import ( + GuildDispatcher, + MemberDispatcher, + UserDispatcher, + ChannelDispatcher, + FriendDispatcher, + LazyGuildDispatcher, +) log = Logger(__name__) @@ -44,17 +49,18 @@ class EventDispatcher: when dispatching, the backend can do its own logic, given its subscriber ids. """ + def __init__(self, app): self.state_manager = app.state_manager self.app = app self.backends = { - 'guild': GuildDispatcher(self), - 'member': MemberDispatcher(self), - 'channel': ChannelDispatcher(self), - 'user': UserDispatcher(self), - 'friend': FriendDispatcher(self), - 'lazy_guild': LazyGuildDispatcher(self), + "guild": GuildDispatcher(self), + "member": MemberDispatcher(self), + "channel": ChannelDispatcher(self), + "user": UserDispatcher(self), + "friend": FriendDispatcher(self), + "lazy_guild": LazyGuildDispatcher(self), } async def action(self, backend_str: str, action: str, key, identifier, *args): @@ -71,13 +77,13 @@ class EventDispatcher: return await method(key, identifier, *args) - async def subscribe(self, backend: str, key: Any, identifier: Any, - flags: Dict[str, Any] = None): + async def subscribe( + self, backend: str, key: Any, identifier: Any, flags: Dict[str, Any] = None + ): """Subscribe a single element to the given backend.""" flags = flags or {} - log.debug('SUB backend={} key={} <= id={}', - backend, key, identifier, backend) + log.debug("SUB backend={} key={} <= id={}", backend, key, identifier, backend) # this is a hacky solution for backwards compatibility between backends # that implement flags and backends that don't. @@ -85,16 +91,15 @@ class EventDispatcher: # passing flags to backends that don't implement flags will # cause errors as expected. if flags: - return await self.action(backend, 'sub', key, identifier, flags) + return await self.action(backend, "sub", key, identifier, flags) - return await self.action(backend, 'sub', key, identifier) + return await self.action(backend, "sub", key, identifier) async def unsubscribe(self, backend: str, key: Any, identifier: Any): """Unsubscribe an element from the given backend.""" - log.debug('UNSUB backend={} key={} => id={}', - backend, key, identifier, backend) + log.debug("UNSUB backend={} key={} => id={}", backend, key, identifier, backend) - return await self.action(backend, 'unsub', key, identifier) + return await self.action(backend, "unsub", key, identifier) async def sub(self, backend, key, identifier): """Alias to subscribe().""" @@ -104,8 +109,13 @@ class EventDispatcher: """Alias to unsubscribe().""" return await self.unsubscribe(backend, key, identifier) - async def sub_many(self, backend_str: str, identifier: Any, - keys: list, flags: Dict[str, Any] = None): + async def sub_many( + self, + backend_str: str, + identifier: Any, + keys: list, + flags: Dict[str, Any] = None, + ): """Subscribe to multiple channels (all in a single backend) at a time. @@ -116,8 +126,7 @@ class EventDispatcher: for key in keys: await self.subscribe(backend_str, key, identifier, flags) - async def mass_sub(self, identifier: Any, - backends: List[tuple]): + async def mass_sub(self, identifier: Any, backends: List[tuple]): """Mass subscribe to many backends at once.""" for bcall in backends: backend_str, keys = bcall[0], bcall[1] @@ -128,8 +137,13 @@ class EventDispatcher: # we have flags flags = bcall[2] - log.debug('subscribing {} to {} keys in backend {}, flags: {}', - identifier, len(keys), backend_str, flags) + log.debug( + "subscribing {} to {} keys in backend {}, flags: {}", + identifier, + len(keys), + backend_str, + flags, + ) await self.sub_many(backend_str, identifier, keys, flags) @@ -145,17 +159,14 @@ class EventDispatcher: key = backend.KEY_TYPE(key) return await backend.dispatch(key, *args, **kwargs) - async def dispatch_many(self, backend_str: str, - keys: List[Any], *args, **kwargs): + async def dispatch_many(self, backend_str: str, keys: List[Any], *args, **kwargs): """Dispatch to multiple keys in a single backend.""" - log.info('MULTI DISPATCH: {!r}, {} keys', - backend_str, len(keys)) + log.info("MULTI DISPATCH: {!r}, {} keys", backend_str, len(keys)) for key in keys: await self.dispatch(backend_str, key, *args, **kwargs) - async def dispatch_filter(self, backend_str: str, - key: Any, func, *args): + async def dispatch_filter(self, backend_str: str, key: Any, func, *args): """Dispatch to a backend that only accepts (event, data) arguments with an optional filter function.""" @@ -163,9 +174,9 @@ class EventDispatcher: key = backend.KEY_TYPE(key) return await backend.dispatch_filter(key, func, *args) - async def dispatch_many_filter_list(self, backend_str: str, - keys: List[Any], sess_list: List[str], - *args): + async def dispatch_many_filter_list( + self, backend_str: str, keys: List[Any], sess_list: List[str], *args + ): """Make a "unique" dispatch given a list of session ids. This only works for backends that have a dispatch_filter @@ -175,9 +186,8 @@ class EventDispatcher: for key in keys: sess_list.extend( await self.dispatch_filter( - backend_str, key, - lambda sess_id: sess_id not in sess_list, - *args) + backend_str, key, lambda sess_id: sess_id not in sess_list, *args + ) ) return sess_list @@ -197,12 +207,12 @@ class EventDispatcher: async def dispatch_guild(self, guild_id, event, data): """Backwards compatibility with old EventDispatcher.""" - return await self.dispatch('guild', guild_id, event, data) + return await self.dispatch("guild", guild_id, event, data) async def dispatch_user_guild(self, user_id, guild_id, event, data): """Backwards compatibility with old EventDispatcher.""" - return await self.dispatch('member', (guild_id, user_id), event, data) + return await self.dispatch("member", (guild_id, user_id), event, data) async def dispatch_user(self, user_id, event, data): """Backwards compatibility with old EventDispatcher.""" - return await self.dispatch('user', user_id, event, data) + return await self.dispatch("user", user_id, event, data) diff --git a/litecord/embed/__init__.py b/litecord/embed/__init__.py index 41eb159..c813d99 100644 --- a/litecord/embed/__init__.py +++ b/litecord/embed/__init__.py @@ -19,4 +19,4 @@ along with this program. If not, see . from .sanitizer import sanitize_embed -__all__ = ['sanitize_embed'] +__all__ = ["sanitize_embed"] diff --git a/litecord/embed/messages.py b/litecord/embed/messages.py index 4807c0f..ce23ea4 100644 --- a/litecord/embed/messages.py +++ b/litecord/embed/messages.py @@ -30,11 +30,7 @@ from litecord.embed.schemas import EmbedURL log = Logger(__name__) -MEDIA_EXTENSIONS = ( - 'png', - 'jpg', 'jpeg', - 'gif', 'webm' -) +MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm") async def insert_media_meta(url, config, session): @@ -45,18 +41,18 @@ async def insert_media_meta(url, config, session): if meta is None: return - if not meta['image']: + if not meta["image"]: return return { - 'type': 'image', - 'url': url, - 'thumbnail': { - 'width': meta['width'], - 'height': meta['height'], - 'url': url, - 'proxy_url': img_proxy_url - } + "type": "image", + "url": url, + "thumbnail": { + "width": meta["width"], + "height": meta["height"], + "url": url, + "proxy_url": img_proxy_url, + }, } @@ -64,29 +60,32 @@ async def msg_update_embeds(payload, new_embeds, storage, dispatcher): """Update the message with the given embeds and dispatch a MESSAGE_UPDATE to users.""" - message_id = int(payload['id']) - channel_id = int(payload['channel_id']) + message_id = int(payload["id"]) + channel_id = int(payload["channel_id"]) - await storage.execute_with_json(""" + await storage.execute_with_json( + """ UPDATE messages SET embeds = $1 WHERE messages.id = $2 - """, new_embeds, message_id) + """, + new_embeds, + message_id, + ) update_payload = { - 'id': str(message_id), - 'channel_id': str(channel_id), - 'embeds': new_embeds, + "id": str(message_id), + "channel_id": str(channel_id), + "embeds": new_embeds, } - if 'guild_id' in payload: - update_payload['guild_id'] = payload['guild_id'] + if "guild_id" in payload: + update_payload["guild_id"] = payload["guild_id"] - if 'flags' in payload: - update_payload['flags'] = payload['flags'] + if "flags" in payload: + update_payload["flags"] = payload["flags"] - await dispatcher.dispatch( - 'channel', channel_id, 'MESSAGE_UPDATE', update_payload) + await dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", update_payload) def is_media_url(url) -> bool: @@ -98,7 +97,7 @@ def is_media_url(url) -> bool: parsed = urllib.parse.urlparse(url) path = Path(parsed.path) - extension = path.suffix.lstrip('.') + extension = path.suffix.lstrip(".") return extension in MEDIA_EXTENSIONS @@ -109,20 +108,20 @@ async def insert_mp_embed(parsed, config, session): return embed -async def process_url_embed(config, storage, dispatcher, - session, payload: dict, *, delay=0): +async def process_url_embed( + config, storage, dispatcher, session, payload: dict, *, delay=0 +): """Process URLs in a message and generate embeds based on that.""" await asyncio.sleep(delay) - message_id = int(payload['id']) + message_id = int(payload["id"]) # if we already have embeds # we shouldn't add our own. - embeds = payload['embeds'] + embeds = payload["embeds"] if embeds: - log.debug('url processor: ignoring existing embeds @ mid {}', - message_id) + log.debug("url processor: ignoring existing embeds @ mid {}", message_id) return # now, we have two types of embeds: @@ -130,7 +129,7 @@ async def process_url_embed(config, storage, dispatcher, # - url embeds # use regex to get URLs - urls = re.findall(r'(https?://\S+)', payload['content']) + urls = re.findall(r"(https?://\S+)", payload["content"]) urls = urls[:5] # from there, we need to parse each found url and check its path. @@ -159,7 +158,6 @@ async def process_url_embed(config, storage, dispatcher, if not new_embeds: return - log.debug('made {} embeds for mid {}', - len(new_embeds), message_id) + log.debug("made {} embeds for mid {}", len(new_embeds), message_id) await msg_update_embeds(payload, new_embeds, storage, dispatcher) diff --git a/litecord/embed/sanitizer.py b/litecord/embed/sanitizer.py index 96c932f..b14e436 100644 --- a/litecord/embed/sanitizer.py +++ b/litecord/embed/sanitizer.py @@ -39,9 +39,7 @@ def sanitize_embed(embed: Embed) -> Embed: This is non-complex sanitization as it doesn't need the app object. """ - return {**embed, **{ - 'type': 'rich' - }} + return {**embed, **{"type": "rich"}} def path_exists(embed: Embed, components_in: Union[List[str], str]): @@ -55,7 +53,7 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]): # get the list of components given if isinstance(components_in, str): - components = components_in.split('.') + components = components_in.split(".") else: components = list(components_in) @@ -77,7 +75,6 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]): return False - def _mk_cfg_sess(config, session) -> tuple: """Return a tuple of (config, session).""" if config is None: @@ -91,11 +88,11 @@ def _mk_cfg_sess(config, session) -> tuple: def _md_base(config) -> Optional[tuple]: """Return the protocol and base url for the mediaproxy.""" - md_base_url = config['MEDIA_PROXY'] + md_base_url = config["MEDIA_PROXY"] if md_base_url is None: return None - proto = 'https' if config['IS_SSL'] else 'http' + proto = "https" if config["IS_SSL"] else "http" return proto, md_base_url @@ -111,7 +108,7 @@ def make_md_req_url(config, scope: str, url): return url.url if isinstance(url, EmbedURL) else url proto, base_url = base - return f'{proto}://{base_url}/{scope}/{url.to_md_path}' + return f"{proto}://{base_url}/{scope}/{url.to_md_path}" def proxify(url, *, config=None) -> str: @@ -122,11 +119,12 @@ def proxify(url, *, config=None) -> str: if isinstance(url, str): url = EmbedURL(url) - return make_md_req_url(config, 'img', url) + return make_md_req_url(config, "img", url) -async def _md_client_req(config, session, scope: str, - url, *, ret_resp=False) -> Optional[Union[Tuple, Dict]]: +async def _md_client_req( + config, session, scope: str, url, *, ret_resp=False +) -> Optional[Union[Tuple, Dict]]: """Makes a request to the mediaproxy. This has common code between all the main mediaproxy request functions @@ -172,17 +170,13 @@ async def _md_client_req(config, session, scope: str, return await resp.json() body = await resp.text() - log.warning('failed to call {!r}, {} {!r}', - request_url, resp.status, body) + log.warning("failed to call {!r}, {} {!r}", request_url, resp.status, body) return None -async def fetch_metadata(url, *, config=None, - session=None) -> Optional[Dict]: +async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]: """Fetch metadata for a url (image width, mime, etc).""" - return await _md_client_req( - config, session, 'meta', url - ) + return await _md_client_req(config, session, "meta", url) async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]: @@ -191,9 +185,7 @@ async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]: Returns a tuple containing the response object and the raw bytes given by the website. """ - tup = await _md_client_req( - config, session, 'img', url, ret_resp=True - ) + tup = await _md_client_req(config, session, "img", url, ret_resp=True) if not tup: return None @@ -207,9 +199,7 @@ async def fetch_embed(url, *, config=None, session=None) -> Dict[str, Any]: Returns a discord embed object. """ - return await _md_client_req( - config, session, 'embed', url - ) + return await _md_client_req(config, session, "embed", url) async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]: @@ -229,22 +219,20 @@ async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]: embed = sanitize_embed(embed) - if path_exists(embed, 'footer.icon_url'): - embed['footer']['proxy_icon_url'] = \ - proxify(embed['footer']['icon_url']) + if path_exists(embed, "footer.icon_url"): + embed["footer"]["proxy_icon_url"] = proxify(embed["footer"]["icon_url"]) - if path_exists(embed, 'author.icon_url'): - embed['author']['proxy_icon_url'] = \ - proxify(embed['author']['icon_url']) + if path_exists(embed, "author.icon_url"): + embed["author"]["proxy_icon_url"] = proxify(embed["author"]["icon_url"]) - if path_exists(embed, 'image.url'): - image_url = embed['image']['url'] + if path_exists(embed, "image.url"): + image_url = embed["image"]["url"] meta = await fetch_metadata(image_url) - embed['image']['proxy_url'] = proxify(image_url) + embed["image"]["proxy_url"] = proxify(image_url) - if meta and meta['image']: - embed['image']['width'] = meta['width'] - embed['image']['height'] = meta['height'] + if meta and meta["image"]: + embed["image"]["width"] = meta["width"] + embed["image"]["height"] = meta["height"] return embed diff --git a/litecord/embed/schemas.py b/litecord/embed/schemas.py index 6f3b306..c93bb2d 100644 --- a/litecord/embed/schemas.py +++ b/litecord/embed/schemas.py @@ -28,8 +28,8 @@ class EmbedURL: def __init__(self, url: str): parsed = urllib.parse.urlparse(url) - if parsed.scheme not in ('http', 'https', 'attachment'): - raise ValueError('Invalid URL scheme') + if parsed.scheme not in ("http", "https", "attachment"): + raise ValueError("Invalid URL scheme") self.scheme = parsed.scheme self.raw_url = url @@ -54,105 +54,61 @@ class EmbedURL: def to_md_path(self) -> str: """Convert the EmbedURL to a mediaproxy path (post img/meta).""" parsed = self.parsed - return ( - f'{parsed.scheme}/{parsed.netloc}' - f'{parsed.path}?{parsed.query}' - ) + return f"{parsed.scheme}/{parsed.netloc}" f"{parsed.path}?{parsed.query}" EMBED_FOOTER = { - 'text': { - 'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': True}, - - 'icon_url': { - 'coerce': EmbedURL, 'required': False, - }, - + "text": {"type": "string", "minlength": 1, "maxlength": 1024, "required": True}, + "icon_url": {"coerce": EmbedURL, "required": False}, # NOTE: proxy_icon_url set by us } EMBED_IMAGE = { - 'url': {'coerce': EmbedURL, 'required': True}, - + "url": {"coerce": EmbedURL, "required": True}, # NOTE: proxy_url, width, height set by us } EMBED_THUMBNAIL = EMBED_IMAGE EMBED_AUTHOR = { - 'name': { - 'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': False - }, - 'url': { - 'coerce': EmbedURL, 'required': False, - }, - 'icon_url': { - 'coerce': EmbedURL, 'required': False, - } - + "name": {"type": "string", "minlength": 1, "maxlength": 256, "required": False}, + "url": {"coerce": EmbedURL, "required": False}, + "icon_url": {"coerce": EmbedURL, "required": False} # NOTE: proxy_icon_url set by us } EMBED_FIELD = { - 'name': { - 'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True - }, - 'value': { - 'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': True - }, - 'inline': { - 'type': 'boolean', 'required': False, 'default': True, - }, + "name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True}, + "value": {"type": "string", "minlength": 1, "maxlength": 1024, "required": True}, + "inline": {"type": "boolean", "required": False, "default": True}, } EMBED_OBJECT = { - 'title': { - 'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': False}, + "title": {"type": "string", "minlength": 1, "maxlength": 256, "required": False}, # NOTE: type set by us - 'description': { - 'type': 'string', 'minlength': 1, 'maxlength': 2048, 'required': False, + "description": { + "type": "string", + "minlength": 1, + "maxlength": 2048, + "required": False, }, - 'url': { - 'coerce': EmbedURL, 'required': False, - }, - 'timestamp': { + "url": {"coerce": EmbedURL, "required": False}, + "timestamp": { # TODO: an ISO 8601 type # TODO: maybe replace the default in here with now().isoformat? - 'type': 'string', 'required': False + "type": "string", + "required": False, }, - - 'color': { - 'coerce': Color, 'required': False - }, - - 'footer': { - 'type': 'dict', - 'schema': EMBED_FOOTER, - 'required': False, - }, - 'image': { - 'type': 'dict', - 'schema': EMBED_IMAGE, - 'required': False, - }, - 'thumbnail': { - 'type': 'dict', - 'schema': EMBED_THUMBNAIL, - 'required': False, - }, - + "color": {"coerce": Color, "required": False}, + "footer": {"type": "dict", "schema": EMBED_FOOTER, "required": False}, + "image": {"type": "dict", "schema": EMBED_IMAGE, "required": False}, + "thumbnail": {"type": "dict", "schema": EMBED_THUMBNAIL, "required": False}, # NOTE: 'video' set by us # NOTE: 'provider' set by us - - 'author': { - 'type': 'dict', - 'schema': EMBED_AUTHOR, - 'required': False, - }, - - 'fields': { - 'type': 'list', - 'schema': {'type': 'dict', 'schema': EMBED_FIELD}, - 'required': False, + "author": {"type": "dict", "schema": EMBED_AUTHOR, "required": False}, + "fields": { + "type": "list", + "schema": {"type": "dict", "schema": EMBED_FIELD}, + "required": False, }, } diff --git a/litecord/enums.py b/litecord/enums.py index c9ecb9a..aaf71a0 100644 --- a/litecord/enums.py +++ b/litecord/enums.py @@ -52,13 +52,14 @@ class Flags: >>> i2.is_field_3 False """ + def __init_subclass__(cls, **_kwargs): attrs = inspect.getmembers(cls, lambda x: not inspect.isroutine(x)) def _make_int(value): res = Flags() - setattr(res, 'value', value) + setattr(res, "value", value) for attr, val in attrs: # get only the ones that represent a field in the @@ -69,7 +70,7 @@ class Flags: has_attr = (value & val) == val # set each attribute - setattr(res, f'is_{attr}', has_attr) + setattr(res, f"is_{attr}", has_attr) return res @@ -84,17 +85,16 @@ class ChannelType(EasyEnum): GUILD_CATEGORY = 4 -GUILD_CHANS = (ChannelType.GUILD_TEXT, - ChannelType.GUILD_VOICE, - ChannelType.GUILD_CATEGORY) - - -VOICE_CHANNELS = ( - ChannelType.DM, ChannelType.GUILD_VOICE, - ChannelType.GUILD_CATEGORY +GUILD_CHANS = ( + ChannelType.GUILD_TEXT, + ChannelType.GUILD_VOICE, + ChannelType.GUILD_CATEGORY, ) +VOICE_CHANNELS = (ChannelType.DM, ChannelType.GUILD_VOICE, ChannelType.GUILD_CATEGORY) + + class ActivityType(EasyEnum): PLAYING = 0 STREAMING = 1 @@ -120,7 +120,7 @@ SYS_MESSAGES = ( MessageType.CHANNEL_NAME_CHANGE, MessageType.CHANNEL_ICON_CHANGE, MessageType.CHANNEL_PINNED_MESSAGE, - MessageType.GUILD_MEMBER_JOIN + MessageType.GUILD_MEMBER_JOIN, ) @@ -137,6 +137,7 @@ class ActivityFlags(Flags): Only related to rich presence. """ + instance = 1 join = 2 spectate = 4 @@ -150,6 +151,7 @@ class UserFlags(Flags): Used by the client to show badges. """ + staff = 1 partner = 2 hypesquad = 4 @@ -166,6 +168,7 @@ class UserFlags(Flags): class MessageFlags(Flags): """Message flags.""" + none = 0 crossposted = 1 << 0 @@ -175,11 +178,12 @@ class MessageFlags(Flags): class StatusType(EasyEnum): """All statuses there can be in a presence.""" - ONLINE = 'online' - DND = 'dnd' - IDLE = 'idle' - INVISIBLE = 'invisible' - OFFLINE = 'offline' + + ONLINE = "online" + DND = "dnd" + IDLE = "idle" + INVISIBLE = "invisible" + OFFLINE = "offline" class ExplicitFilter(EasyEnum): @@ -187,6 +191,7 @@ class ExplicitFilter(EasyEnum): Also applies to guilds. """ + EDGE = 0 FRIENDS = 1 SAFE = 2 @@ -194,6 +199,7 @@ class ExplicitFilter(EasyEnum): class VerificationLevel(IntEnum): """Verification level for guilds.""" + NONE = 0 LOW = 1 MEDIUM = 2 @@ -205,6 +211,7 @@ class VerificationLevel(IntEnum): class RelationshipType(EasyEnum): """Relationship types between users.""" + FRIEND = 1 BLOCK = 2 INCOMING = 3 @@ -213,6 +220,7 @@ class RelationshipType(EasyEnum): class MessageNotifications(EasyEnum): """Message notifications""" + ALL = 0 MENTIONS = 1 NOTHING = 2 @@ -220,6 +228,7 @@ class MessageNotifications(EasyEnum): class PremiumType: """Premium (Nitro) type.""" + TIER_1 = 1 TIER_2 = 2 NONE = None @@ -227,12 +236,13 @@ class PremiumType: class Feature(EasyEnum): """Guild features.""" - invite_splash = 'INVITE_SPLASH' - vip = 'VIP_REGIONS' - vanity = 'VANITY_URL' - emoji = 'MORE_EMOJI' - verified = 'VERIFIED' + + invite_splash = "INVITE_SPLASH" + vip = "VIP_REGIONS" + vanity = "VANITY_URL" + emoji = "MORE_EMOJI" + verified = "VERIFIED" # unknown - commerce = 'COMMERCE' - news = 'NEWS' + commerce = "COMMERCE" + news = "NEWS" diff --git a/litecord/errors.py b/litecord/errors.py index c171b15..d2a72c2 100644 --- a/litecord/errors.py +++ b/litecord/errors.py @@ -18,60 +18,64 @@ along with this program. If not, see . """ ERR_MSG_MAP = { - 10001: 'Unknown account', - 10002: 'Unknown application', - 10003: 'Unknown channel', - 10004: 'Unknown guild', - 10005: 'Unknown integration', - 10006: 'Unknown invite', - 10007: 'Unknown member', - 10008: 'Unknown message', - 10009: 'Unknown overwrite', - 10010: 'Unknown provider', - 10011: 'Unknown role', - 10012: 'Unknown token', - 10013: 'Unknown user', - 10014: 'Unknown Emoji', - 10015: 'Unknown Webhook', - 20001: 'Bots cannot use this endpoint', - 20002: 'Only bots can use this endpoint', - 30001: 'Maximum number of guilds reached (100)', - 30002: 'Maximum number of friends reached (1000)', - 30003: 'Maximum number of pins reached (50)', - 30005: 'Maximum number of guild roles reached (250)', - 30010: 'Maximum number of reactions reached (20)', - 30013: 'Maximum number of guild channels reached (500)', - 40001: 'Unauthorized', - 50001: 'Missing access', - 50002: 'Invalid account type', - 50003: 'Cannot execute action on a DM channel', - 50004: 'Widget Disabled', - 50005: 'Cannot edit a message authored by another user', - 50006: 'Cannot send an empty message', - 50007: 'Cannot send messages to this user', - 50008: 'Cannot send messages in a voice channel', - 50009: 'Channel verification level is too high', - 50010: 'OAuth2 application does not have a bot', - 50011: 'OAuth2 application limit reached', - 50012: 'Invalid OAuth state', - 50013: 'Missing permissions', - 50014: 'Invalid authentication token', - 50015: 'Note is too long', - 50016: ('Provided too few or too many messages to delete. Must provide at ' - 'least 2 and fewer than 100 messages to delete.'), - 50019: 'A message can only be pinned to the channel it was sent in', - 50020: 'Invite code is either invalid or taken.', - 50021: 'Cannot execute action on a system message', - 50025: 'Invalid OAuth2 access token', - 50034: 'A message provided was too old to bulk delete', - 50035: 'Invalid Form Body', - 50036: 'An invite was accepted to a guild the application\'s bot is not in', - 50041: 'Invalid API version', - 90001: 'Reaction blocked', + 10001: "Unknown account", + 10002: "Unknown application", + 10003: "Unknown channel", + 10004: "Unknown guild", + 10005: "Unknown integration", + 10006: "Unknown invite", + 10007: "Unknown member", + 10008: "Unknown message", + 10009: "Unknown overwrite", + 10010: "Unknown provider", + 10011: "Unknown role", + 10012: "Unknown token", + 10013: "Unknown user", + 10014: "Unknown Emoji", + 10015: "Unknown Webhook", + 20001: "Bots cannot use this endpoint", + 20002: "Only bots can use this endpoint", + 30001: "Maximum number of guilds reached (100)", + 30002: "Maximum number of friends reached (1000)", + 30003: "Maximum number of pins reached (50)", + 30005: "Maximum number of guild roles reached (250)", + 30010: "Maximum number of reactions reached (20)", + 30013: "Maximum number of guild channels reached (500)", + 40001: "Unauthorized", + 50001: "Missing access", + 50002: "Invalid account type", + 50003: "Cannot execute action on a DM channel", + 50004: "Widget Disabled", + 50005: "Cannot edit a message authored by another user", + 50006: "Cannot send an empty message", + 50007: "Cannot send messages to this user", + 50008: "Cannot send messages in a voice channel", + 50009: "Channel verification level is too high", + 50010: "OAuth2 application does not have a bot", + 50011: "OAuth2 application limit reached", + 50012: "Invalid OAuth state", + 50013: "Missing permissions", + 50014: "Invalid authentication token", + 50015: "Note is too long", + 50016: ( + "Provided too few or too many messages to delete. Must provide at " + "least 2 and fewer than 100 messages to delete." + ), + 50019: "A message can only be pinned to the channel it was sent in", + 50020: "Invite code is either invalid or taken.", + 50021: "Cannot execute action on a system message", + 50025: "Invalid OAuth2 access token", + 50034: "A message provided was too old to bulk delete", + 50035: "Invalid Form Body", + 50036: "An invite was accepted to a guild the application's bot is not in", + 50041: "Invalid API version", + 90001: "Reaction blocked", } + class LitecordError(Exception): """Base class for litecord errors""" + status_code = 500 def _get_err_msg(self, err_code: int) -> str: @@ -91,7 +95,7 @@ class LitecordError(Exception): return message except IndexError: - return self._get_err_msg(getattr(self, 'error_code', None)) + return self._get_err_msg(getattr(self, "error_code", None)) @property def json(self): @@ -143,7 +147,7 @@ class MissingPermissions(Forbidden): class WebsocketClose(Exception): @property def code(self): - from_class = getattr(self, 'close_code', None) + from_class = getattr(self, "close_code", None) if from_class: return from_class @@ -152,7 +156,7 @@ class WebsocketClose(Exception): @property def reason(self): - from_class = getattr(self, 'close_code', None) + from_class = getattr(self, "close_code", None) if from_class: return self.args[0] diff --git a/litecord/gateway/encoding.py b/litecord/gateway/encoding.py index 07957f6..788d2b7 100644 --- a/litecord/gateway/encoding.py +++ b/litecord/gateway/encoding.py @@ -25,8 +25,7 @@ from litecord.utils import LitecordJSONEncoder def encode_json(payload) -> str: """Encode a given payload to JSON.""" - return json.dumps(payload, separators=(',', ':'), - cls=LitecordJSONEncoder) + return json.dumps(payload, separators=(",", ":"), cls=LitecordJSONEncoder) def decode_json(data: str): @@ -71,6 +70,7 @@ def _etf_decode_dict(data): return result + def decode_etf(data: bytes): """Decode data in ETF to any.""" res = earl.unpack(data) diff --git a/litecord/gateway/gateway.py b/litecord/gateway/gateway.py index a213528..2540269 100644 --- a/litecord/gateway/gateway.py +++ b/litecord/gateway/gateway.py @@ -24,37 +24,36 @@ from litecord.gateway.websocket import GatewayWebsocket async def websocket_handler(app, ws, url): """Main websocket handler, checks query arguments when connecting to the gateway and spawns a GatewayWebsocket instance for the connection.""" - args = urllib.parse.parse_qs( - urllib.parse.urlparse(url).query - ) + args = urllib.parse.parse_qs(urllib.parse.urlparse(url).query) # pull a dict.get but in a really bad way. try: - gw_version = args['v'][0] + gw_version = args["v"][0] except (KeyError, IndexError): - gw_version = '6' + gw_version = "6" try: - gw_encoding = args['encoding'][0] + gw_encoding = args["encoding"][0] except (KeyError, IndexError): - gw_encoding = 'json' + gw_encoding = "json" - if gw_version not in ('6', '7'): - return await ws.close(1000, 'Invalid gateway version') + if gw_version not in ("6", "7"): + return await ws.close(1000, "Invalid gateway version") - if gw_encoding not in ('json', 'etf'): - return await ws.close(1000, 'Invalid gateway encoding') + if gw_encoding not in ("json", "etf"): + return await ws.close(1000, "Invalid gateway encoding") try: - gw_compress = args['compress'][0] + gw_compress = args["compress"][0] except (KeyError, IndexError): gw_compress = None - if gw_compress and gw_compress not in ('zlib-stream', 'zstd-stream'): - return await ws.close(1000, 'Invalid gateway compress') + if gw_compress and gw_compress not in ("zlib-stream", "zstd-stream"): + return await ws.close(1000, "Invalid gateway compress") gws = GatewayWebsocket( - ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress) + ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress + ) # this can be run with a single await since this whole coroutine # is already running in the background. diff --git a/litecord/gateway/opcodes.py b/litecord/gateway/opcodes.py index 2f9596b..a36b099 100644 --- a/litecord/gateway/opcodes.py +++ b/litecord/gateway/opcodes.py @@ -17,8 +17,10 @@ along with this program. If not, see . """ + class OP: """Gateway OP codes.""" + DISPATCH = 0 HEARTBEAT = 1 IDENTIFY = 2 diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py index cadd3af..9df23e3 100644 --- a/litecord/gateway/state.py +++ b/litecord/gateway/state.py @@ -32,6 +32,7 @@ class PayloadStore: This will only store a maximum of MAX_STORE_SIZE, dropping the older payloads when adding new ones. """ + MAX_STORE_SIZE = 250 def __init__(self): @@ -60,20 +61,20 @@ class GatewayState: """ def __init__(self, **kwargs): - self.session_id = kwargs.get('session_id', gen_session_id()) + self.session_id = kwargs.get("session_id", gen_session_id()) #: event sequence number - self.seq = kwargs.get('seq', 0) + self.seq = kwargs.get("seq", 0) #: last seq sent by us, the backend self.last_seq = 0 #: shard information about the state, # its id and shard count - self.shard = kwargs.get('shard', [0, 1]) + self.shard = kwargs.get("shard", [0, 1]) - self.user_id = kwargs.get('user_id') - self.bot = kwargs.get('bot', False) + self.user_id = kwargs.get("user_id") + self.bot = kwargs.get("bot", False) #: set by the gateway connection # on OP STATUS_UPDATE @@ -90,5 +91,4 @@ class GatewayState: self.__dict__[key] = value def __repr__(self): - return (f'GatewayState') + return f"GatewayState" diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index 3e995ee..756e551 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -39,6 +39,7 @@ class ManagerClose(Exception): class StateDictWrapper: """Wrap a mapping so that any kind of access to the mapping while the state manager is closed raises a ManagerClose error""" + def __init__(self, state_manager, mapping): self.state_manager = state_manager self._map = mapping @@ -98,7 +99,7 @@ class StateManager: """Insert a new state object.""" user_states = self.states[state.user_id] - log.debug('inserting state: {!r}', state) + log.debug("inserting state: {!r}", state) user_states[state.session_id] = state self.states_raw[state.session_id] = state @@ -128,7 +129,7 @@ class StateManager: pass try: - log.debug('removing state: {!r}', state) + log.debug("removing state: {!r}", state) self.states[state.user_id].pop(state.session_id) except KeyError: pass @@ -152,8 +153,7 @@ class StateManager: """Fetch all states tied to a single user.""" return list(self.states[user_id].values()) - def guild_states(self, member_ids: List[int], - guild_id: int) -> List[GatewayState]: + def guild_states(self, member_ids: List[int], guild_id: int) -> List[GatewayState]: """Fetch all possible states about members in a guild.""" states = [] @@ -164,14 +164,14 @@ class StateManager: # since server start, so we need to add a dummy state if not member_states: dummy_state = GatewayState( - session_id='', + session_id="", user_id=member_id, presence={ - 'afk': False, - 'status': 'offline', - 'game': None, - 'since': 0 - } + "afk": False, + "status": "offline", + "game": None, + "since": 0, + }, ) states.append(dummy_state) @@ -187,9 +187,7 @@ class StateManager: """Send OP Reconnect to a single connection.""" websocket = state.ws - await websocket.send({ - 'op': OP.RECONNECT - }) + await websocket.send({"op": OP.RECONNECT}) # wait 200ms # so that the client has time to process @@ -198,12 +196,9 @@ class StateManager: try: # try to close the connection ourselves - await websocket.ws.close( - code=4000, - reason='litecord shutting down' - ) + await websocket.ws.close(code=4000, reason="litecord shutting down") except ConnectionClosed: - log.info('client {} already closed', state) + log.info("client {} already closed", state) def gen_close_tasks(self): """Generate the tasks that will order the clients @@ -222,11 +217,9 @@ class StateManager: if not state.ws: continue - tasks.append( - self.shutdown_single(state) - ) + tasks.append(self.shutdown_single(state)) - log.info('made {} shutdown tasks', len(tasks)) + log.info("made {} shutdown tasks", len(tasks)) return tasks def close(self): diff --git a/litecord/gateway/utils.py b/litecord/gateway/utils.py index d52d7df..8f158b3 100644 --- a/litecord/gateway/utils.py +++ b/litecord/gateway/utils.py @@ -19,9 +19,11 @@ along with this program. If not, see . import asyncio + class WebsocketFileHandler: """A handler around a websocket that wraps normal I/O calls into the websocket's respective asyncio calls via asyncio.ensure_future.""" + def __init__(self, ws): self.ws = ws diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 5b5b33a..6579982 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -31,23 +31,20 @@ from logbook import Logger from litecord.auth import raw_token_check from litecord.enums import RelationshipType, ChannelType from litecord.schemas import validate, GW_STATUS_UPDATE -from litecord.utils import ( - task_wrapper, yield_chunks, maybe_int -) +from litecord.utils import task_wrapper, yield_chunks, maybe_int from litecord.permissions import get_permissions from litecord.gateway.opcodes import OP from litecord.gateway.state import GatewayState -from litecord.errors import ( - WebsocketClose, Unauthorized, Forbidden, BadRequest -) +from litecord.errors import WebsocketClose, Unauthorized, Forbidden, BadRequest from litecord.gateway.errors import ( - DecodeError, UnknownOPCode, InvalidShard, ShardingRequired -) -from litecord.gateway.encoding import ( - encode_json, decode_json, encode_etf, decode_etf + DecodeError, + UnknownOPCode, + InvalidShard, + ShardingRequired, ) +from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf from litecord.gateway.utils import WebsocketFileHandler @@ -56,15 +53,22 @@ from litecord.storage import int_ log = Logger(__name__) WebsocketProperties = collections.namedtuple( - 'WebsocketProperties', 'v encoding compress zctx zsctx tasks' + "WebsocketProperties", "v encoding compress zctx zsctx tasks" ) WebsocketObjects = collections.namedtuple( - 'WebsocketObjects', ( - 'db', 'state_manager', 'storage', - 'loop', 'dispatcher', 'presence', 'ratelimiter', - 'user_storage', 'voice' - ) + "WebsocketObjects", + ( + "db", + "state_manager", + "storage", + "loop", + "dispatcher", + "presence", + "ratelimiter", + "user_storage", + "voice", + ), ) @@ -73,9 +77,15 @@ class GatewayWebsocket: def __init__(self, ws, app, **kwargs): self.ext = WebsocketObjects( - app.db, app.state_manager, app.storage, app.loop, - app.dispatcher, app.presence, app.ratelimiter, - app.user_storage, app.voice + app.db, + app.state_manager, + app.storage, + app.loop, + app.dispatcher, + app.presence, + app.ratelimiter, + app.user_storage, + app.voice, ) self.storage = self.ext.storage @@ -84,15 +94,15 @@ class GatewayWebsocket: self.ws = ws self.wsp = WebsocketProperties( - kwargs.get('v'), - kwargs.get('encoding', 'json'), - kwargs.get('compress', None), + kwargs.get("v"), + kwargs.get("encoding", "json"), + kwargs.get("compress", None), zlib.compressobj(), zstd.ZstdCompressor(), - {} + {}, ) - log.debug('websocket properties: {!r}', self.wsp) + log.debug("websocket properties: {!r}", self.wsp) self.state = None @@ -102,8 +112,8 @@ class GatewayWebsocket: encoding = self.wsp.encoding encodings = { - 'json': (encode_json, decode_json), - 'etf': (encode_etf, decode_etf), + "json": (encode_json, decode_json), + "etf": (encode_etf, decode_etf), } self.encoder, self.decoder = encodings[encoding] @@ -111,16 +121,17 @@ class GatewayWebsocket: async def _chunked_send(self, data: bytes, chunk_size: int): """Split data in chunk_size-big chunks and send them over the websocket.""" - log.debug('zlib-stream: chunking {} bytes into {}-byte chunks', - len(data), chunk_size) + log.debug( + "zlib-stream: chunking {} bytes into {}-byte chunks", len(data), chunk_size + ) total_chunks = 0 for chunk in yield_chunks(data, chunk_size): total_chunks += 1 - log.debug('zlib-stream: chunk {}', total_chunks) + log.debug("zlib-stream: chunk {}", total_chunks) await self.ws.send(chunk) - log.debug('zlib-stream: sent {} chunks', total_chunks) + log.debug("zlib-stream: sent {} chunks", total_chunks) async def _zlib_stream_send(self, encoded): """Sending a single payload across multiple compressed @@ -130,8 +141,12 @@ class GatewayWebsocket: data1 = self.wsp.zctx.compress(encoded) data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH) - log.debug('zlib-stream: length {} -> compressed ({} + {})', - len(encoded), len(data1), len(data2)) + log.debug( + "zlib-stream: length {} -> compressed ({} + {})", + len(encoded), + len(data1), + len(data2), + ) if not data1: # if data1 is nothing, that might cause problems @@ -139,8 +154,11 @@ class GatewayWebsocket: data1 = bytes([data2[0]]) data2 = data2[1:] - log.debug('zlib-stream: len(data1) == 0, remaking as ({} + {})', - len(data1), len(data2)) + log.debug( + "zlib-stream: len(data1) == 0, remaking as ({} + {})", + len(data1), + len(data2), + ) # NOTE: the old approach was ws.send(data1 + data2). # I changed this to a chunked send of data1 and data2 @@ -157,8 +175,7 @@ class GatewayWebsocket: await self._chunked_send(data2, 1024) async def _zstd_stream_send(self, encoded): - compressor = self.wsp.zsctx.stream_writer( - WebsocketFileHandler(self.ws)) + compressor = self.wsp.zsctx.stream_writer(WebsocketFileHandler(self.ws)) compressor.write(encoded) compressor.flush(zstd.FLUSH_FRAME) @@ -172,21 +189,23 @@ class GatewayWebsocket: encoded = self.encoder(payload) if len(encoded) < 2048: - log.debug('sending\n{}', pprint.pformat(payload)) + log.debug("sending\n{}", pprint.pformat(payload)) else: - log.debug('sending {}', pprint.pformat(payload)) - log.debug('sending op={} s={} t={} (too big)', - payload.get('op'), - payload.get('s'), - payload.get('t')) + log.debug("sending {}", pprint.pformat(payload)) + log.debug( + "sending op={} s={} t={} (too big)", + payload.get("op"), + payload.get("s"), + payload.get("t"), + ) # treat encoded as bytes if not isinstance(encoded, bytes): encoded = encoded.encode() - if self.wsp.compress == 'zlib-stream': + if self.wsp.compress == "zlib-stream": await self._zlib_stream_send(encoded) - elif self.wsp.compress == 'zstd-stream': + elif self.wsp.compress == "zstd-stream": await self._zstd_stream_send(encoded) elif self.state and self.state.compress and len(encoded) > 1024: # TODO: should we only compress on >1KB packets? or maybe we @@ -203,16 +222,10 @@ class GatewayWebsocket: async def send_op(self, op_code: int, data: Any): """Send a packet but just the OP code information is filled in.""" - await self.send({ - 'op': op_code, - 'd': data, - - 't': None, - 's': None - }) + await self.send({"op": op_code, "d": data, "t": None, "s": None}) def _check_ratelimit(self, key: str, ratelimit_key): - ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}') + ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}") bucket = ratelimit.get_bucket(ratelimit_key) return bucket.update_rate_limit() @@ -221,19 +234,19 @@ class GatewayWebsocket: # if the client heartbeats in time, # this task will be cancelled. await asyncio.sleep(interval / 1000) - await self.ws.close(4000, 'Heartbeat expired') + await self.ws.close(4000, "Heartbeat expired") self._cleanup() def _hb_start(self, interval: int): # always refresh the heartbeat task # when possible - task = self.wsp.tasks.get('heartbeat') + task = self.wsp.tasks.get("heartbeat") if task: task.cancel() - self.wsp.tasks['heartbeat'] = self.ext.loop.create_task( - task_wrapper('hb wait', self._hb_wait(interval)) + self.wsp.tasks["heartbeat"] = self.ext.loop.create_task( + task_wrapper("hb wait", self._hb_wait(interval)) ) async def _send_hello(self): @@ -241,12 +254,9 @@ class GatewayWebsocket: # random heartbeat intervals interval = randint(40, 46) * 1000 - await self.send_op(OP.HELLO, { - 'heartbeat_interval': interval, - '_trace': [ - 'lesbian-server' - ], - }) + await self.send_op( + OP.HELLO, {"heartbeat_interval": interval, "_trace": ["lesbian-server"]} + ) self._hb_start(interval) @@ -255,16 +265,15 @@ class GatewayWebsocket: self.state.seq += 1 payload = { - 'op': OP.DISPATCH, - 't': event.upper(), - 's': self.state.seq, - 'd': data, + "op": OP.DISPATCH, + "t": event.upper(), + "s": self.state.seq, + "d": data, } self.state.store[self.state.seq] = payload - log.debug('sending payload {!r} sid {}', - event.upper(), self.state.session_id) + log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id) await self.send(payload) @@ -274,16 +283,14 @@ class GatewayWebsocket: guild_ids = await self._guild_ids() if self.state.bot: - return [{ - 'id': row, - 'unavailable': True, - } for row in guild_ids] + return [{"id": row, "unavailable": True} for row in guild_ids] return [ { **await self.storage.get_guild(guild_id, user_id), - **await self.storage.get_guild_extra(guild_id, user_id, - self.state.large) + **await self.storage.get_guild_extra( + guild_id, user_id, self.state.large + ), } for guild_id in guild_ids ] @@ -298,13 +305,13 @@ class GatewayWebsocket: for guild_obj in unavailable_guilds: # fetch full guild object including the 'large' field guild = await self.storage.get_guild_full( - int(guild_obj['id']), self.state.user_id, self.state.large + int(guild_obj["id"]), self.state.user_id, self.state.large ) if guild is None: continue - await self.dispatch('GUILD_CREATE', guild) + await self.dispatch("GUILD_CREATE", guild) async def _user_ready(self) -> dict: """Fetch information about users in the READY packet. @@ -317,28 +324,28 @@ class GatewayWebsocket: relationships = await self.user_storage.get_relationships(user_id) - friend_ids = [int(r['user']['id']) for r in relationships - if r['type'] == RelationshipType.FRIEND.value] + friend_ids = [ + int(r["user"]["id"]) + for r in relationships + if r["type"] == RelationshipType.FRIEND.value + ] friend_presences = await self.ext.presence.friend_presences(friend_ids) settings = await self.user_storage.get_user_settings(user_id) return { - 'user_settings': settings, - 'notes': await self.user_storage.fetch_notes(user_id), - 'relationships': relationships, - 'presences': friend_presences, - 'read_state': await self.user_storage.get_read_state(user_id), - 'user_guild_settings': await self.user_storage.get_guild_settings( - user_id), - - 'friend_suggestion_count': 0, - + "user_settings": settings, + "notes": await self.user_storage.fetch_notes(user_id), + "relationships": relationships, + "presences": friend_presences, + "read_state": await self.user_storage.get_read_state(user_id), + "user_guild_settings": await self.user_storage.get_guild_settings(user_id), + "friend_suggestion_count": 0, # those are unused default values. - 'connected_accounts': [], - 'experiments': [], - 'guild_experiments': [], - 'analytics_token': 'transbian', + "connected_accounts": [], + "experiments": [], + "guild_experiments": [], + "analytics_token": "transbian", } async def dispatch_ready(self): @@ -353,24 +360,21 @@ class GatewayWebsocket: # user, fetch info user_ready = await self._user_ready() - private_channels = ( - await self.user_storage.get_dms(user_id) + - await self.user_storage.get_gdms(user_id) - ) + private_channels = await self.user_storage.get_dms( + user_id + ) + await self.user_storage.get_gdms(user_id) base_ready = { - 'v': 6, - 'user': user, - - 'private_channels': private_channels, - - 'guilds': guilds, - 'session_id': self.state.session_id, - '_trace': ['transbian'], - 'shard': self.state.shard, + "v": 6, + "user": user, + "private_channels": private_channels, + "guilds": guilds, + "session_id": self.state.session_id, + "_trace": ["transbian"], + "shard": self.state.shard, } - await self.dispatch('READY', {**base_ready, **user_ready}) + await self.dispatch("READY", {**base_ready, **user_ready}) # async dispatch of guilds self.ext.loop.create_task(self._guild_dispatch(guilds)) @@ -380,33 +384,32 @@ class GatewayWebsocket: """ current_shard, shard_count = shard - guilds = await self.ext.db.fetchval(""" + guilds = await self.ext.db.fetchval( + """ SELECT COUNT(*) FROM members WHERE user_id = $1 - """, user_id) + """, + user_id, + ) recommended = max(int(guilds / 1200), 1) if shard_count < recommended: - raise ShardingRequired('Too many guilds for shard ' - f'{current_shard}') + raise ShardingRequired("Too many guilds for shard " f"{current_shard}") if guilds > 2500 and guilds / shard_count > 0.8: - raise ShardingRequired('Too many shards. ' - f'(g={guilds} sc={shard_count})') + raise ShardingRequired("Too many shards. " f"(g={guilds} sc={shard_count})") if current_shard > shard_count: - raise InvalidShard('Shard count > Total shards') + raise InvalidShard("Shard count > Total shards") async def _guild_ids(self) -> list: """Get a list of Guild IDs that are tied to this connection. The implementation is shard-aware. """ - guild_ids = await self.user_storage.get_user_guilds( - self.state.user_id - ) + guild_ids = await self.user_storage.get_user_guilds(self.state.user_id) shard_id = self.state.current_shard shard_count = self.state.shard_count @@ -414,10 +417,7 @@ class GatewayWebsocket: def _get_shard(guild_id): return (guild_id >> 22) % shard_count - filtered = filter( - lambda guild_id: _get_shard(guild_id) == shard_id, - guild_ids - ) + filtered = filter(lambda guild_id: _get_shard(guild_id) == shard_id, guild_ids) return list(filtered) @@ -432,13 +432,17 @@ class GatewayWebsocket: # subscribe the user to all dms they have OPENED. dms = await self.user_storage.get_dms(user_id) - dm_ids = [int(dm['id']) for dm in dms] + dm_ids = [int(dm["id"]) for dm in dms] # fetch all group dms the user is a member of. gdm_ids = await self.user_storage.get_gdms_internal(user_id) - log.info('subscribing to {} guilds {} dms {} gdms', - len(guild_ids), len(dm_ids), len(gdm_ids)) + log.info( + "subscribing to {} guilds {} dms {} gdms", + len(guild_ids), + len(dm_ids), + len(gdm_ids), + ) # guild_subscriptions: # enables dispatching of guild subscription events @@ -447,10 +451,13 @@ class GatewayWebsocket: # we enable processing of guild_subscriptions by adding flags # when subscribing to the given backend. those are optional. channels_to_sub = [ - ('guild', guild_ids, - {'presence': guild_subscriptions, 'typing': guild_subscriptions}), - ('channel', dm_ids), - ('channel', gdm_ids), + ( + "guild", + guild_ids, + {"presence": guild_subscriptions, "typing": guild_subscriptions}, + ), + ("channel", dm_ids), + ("channel", gdm_ids), ] await self.ext.dispatcher.mass_sub(user_id, channels_to_sub) @@ -460,28 +467,26 @@ class GatewayWebsocket: # (their friends will also subscribe back # when they come online) friend_ids = await self.user_storage.get_friend_ids(user_id) - log.info('subscribing to {} friends', len(friend_ids)) - await self.ext.dispatcher.sub_many('friend', user_id, friend_ids) + log.info("subscribing to {} friends", len(friend_ids)) + await self.ext.dispatcher.sub_many("friend", user_id, friend_ids) async def update_status(self, status: dict): """Update the status of the current websocket connection.""" if not self.state: return - if self._check_ratelimit('presence', self.state.session_id): + if self._check_ratelimit("presence", self.state.session_id): # Presence Updates beyond the ratelimit # are just silently dropped. return default_status = { - 'afk': False, - + "afk": False, # TODO: fetch status from settings - 'status': 'online', - 'game': None, - + "status": "online", + "game": None, # TODO: this - 'since': 0, + "since": 0, } status = {**(status or {}), **default_status} @@ -489,39 +494,40 @@ class GatewayWebsocket: try: status = validate(status, GW_STATUS_UPDATE) except BadRequest as err: - log.warning(f'Invalid status update: {err}') + log.warning(f"Invalid status update: {err}") return # try to extract game from activities # when game not provided - if not status.get('game'): + if not status.get("game"): try: - game = status['activities'][0] + game = status["activities"][0] except (KeyError, IndexError): game = None else: - game = status['game'] + game = status["game"] # construct final status status = { - 'afk': status.get('afk', False), - 'status': status.get('status', 'online'), - 'game': game, - 'since': status.get('since', 0), + "afk": status.get("afk", False), + "status": status.get("status", "online"), + "game": game, + "since": status.get("since", 0), } self.state.presence = status - log.info(f'Updating presence status={status["status"]} for ' - f'uid={self.state.user_id}') - await self.ext.presence.dispatch_pres(self.state.user_id, - self.state.presence) + log.info( + f'Updating presence status={status["status"]} for ' + f"uid={self.state.user_id}" + ) + await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence) async def handle_1(self, payload: Dict[str, Any]): """Handle OP 1 Heartbeat packets.""" # give the client 3 more seconds before we # close the websocket self._hb_start((46 + 3) * 1000) - cliseq = payload.get('d') + cliseq = payload.get("d") if self.state: self.state.last_seq = cliseq @@ -529,39 +535,42 @@ class GatewayWebsocket: await self.send_op(OP.HEARTBEAT_ACK, None) async def _connect_ratelimit(self, user_id: int): - if self._check_ratelimit('connect', user_id): + if self._check_ratelimit("connect", user_id): await self.invalidate_session(False) - raise WebsocketClose(4009, 'You are being ratelimited.') + raise WebsocketClose(4009, "You are being ratelimited.") - if self._check_ratelimit('session', user_id): + if self._check_ratelimit("session", user_id): await self.invalidate_session(False) - raise WebsocketClose(4004, 'Websocket Session Ratelimit reached.') + raise WebsocketClose(4004, "Websocket Session Ratelimit reached.") async def handle_2(self, payload: Dict[str, Any]): """Handle the OP 2 Identify packet.""" try: - data = payload['d'] - token = data['token'] + data = payload["d"] + token = data["token"] except KeyError: - raise DecodeError('Invalid identify parameters') + raise DecodeError("Invalid identify parameters") - compress = data.get('compress', False) - large = data.get('large_threshold', 50) + compress = data.get("compress", False) + large = data.get("large_threshold", 50) - shard = data.get('shard', [0, 1]) - presence = data.get('presence') + shard = data.get("shard", [0, 1]) + presence = data.get("presence") try: user_id = await raw_token_check(token, self.ext.db) except (Unauthorized, Forbidden): - raise WebsocketClose(4004, 'Authentication failed') + raise WebsocketClose(4004, "Authentication failed") await self._connect_ratelimit(user_id) - bot = await self.ext.db.fetchval(""" + bot = await self.ext.db.fetchval( + """ SELECT bot FROM users WHERE id = $1 - """, user_id) + """, + user_id, + ) await self._check_shards(shard, user_id) @@ -574,19 +583,19 @@ class GatewayWebsocket: shard=shard, current_shard=shard[0], shard_count=shard[1], - ws=self + ws=self, ) # link the state to the user self.ext.state_manager.insert(self.state) await self.update_status(presence) - await self.subscribe_all(data.get('guild_subscriptions', True)) + await self.subscribe_all(data.get("guild_subscriptions", True)) await self.dispatch_ready() async def handle_3(self, payload: Dict[str, Any]): """Handle OP 3 Status Update.""" - presence = payload['d'] + presence = payload["d"] # update_status will take care of validation and # setting new presence to state @@ -597,27 +606,27 @@ class GatewayWebsocket: user settings.""" try: # TODO: fetch from settings if not provided - self_deaf = bool(data['self_deaf']) - self_mute = bool(data['self_mute']) + self_deaf = bool(data["self_deaf"]) + self_mute = bool(data["self_mute"]) except (KeyError, ValueError): pass return { - 'deaf': state.deaf, - 'mute': state.mute, - 'self_deaf': self_deaf, - 'self_mute': self_mute, + "deaf": state.deaf, + "mute": state.mute, + "self_deaf": self_deaf, + "self_mute": self_mute, } async def handle_4(self, payload: Dict[str, Any]): """Handle OP 4 Voice Status Update.""" - data = payload['d'] + data = payload["d"] if not self.state: return - channel_id = int_(data.get('channel_id')) - guild_id = int_(data.get('guild_id')) + channel_id = int_(data.get("channel_id")) + guild_id = int_(data.get("guild_id")) # if its null and null, disconnect the user from any voice # TODO: maybe just leave from DMs? idk... @@ -630,9 +639,7 @@ class GatewayWebsocket: return await self.ext.voice.leave(guild_id, self.state.user_id) # fetch an existing state given user and guild OR user and channel - chan_type = ChannelType( - await self.storage.get_chan_type(channel_id) - ) + chan_type = ChannelType(await self.storage.get_chan_type(channel_id)) state_id2 = channel_id @@ -704,39 +711,38 @@ class GatewayWebsocket: # ignore unknown seqs continue - payload_t = payload.get('t') + payload_t = payload.get("t") # presence resumption happens # on a separate event, PRESENCE_REPLACE. - if payload_t == 'PRESENCE_UPDATE': - presences.append(payload.get('d')) + if payload_t == "PRESENCE_UPDATE": + presences.append(payload.get("d")) continue await self.send(payload) except Exception: - log.exception('error while resuming') + log.exception("error while resuming") await self.invalidate_session(False) return if presences: - await self.dispatch('PRESENCE_REPLACE', presences) + await self.dispatch("PRESENCE_REPLACE", presences) - await self.dispatch('RESUMED', {}) + await self.dispatch("RESUMED", {}) async def handle_6(self, payload: Dict[str, Any]): """Handle OP 6 Resume.""" - data = payload['d'] + data = payload["d"] try: - token, sess_id, seq = data['token'], \ - data['session_id'], data['seq'] + token, sess_id, seq = data["token"], data["session_id"], data["seq"] except KeyError: - raise DecodeError('Invalid resume payload') + raise DecodeError("Invalid resume payload") try: user_id = await raw_token_check(token, self.ext.db) except (Unauthorized, Forbidden): - raise WebsocketClose(4004, 'Invalid token') + raise WebsocketClose(4004, "Invalid token") try: state = self.ext.state_manager.fetch(user_id, sess_id) @@ -744,11 +750,11 @@ class GatewayWebsocket: return await self.invalidate_session(False) if seq > state.seq: - raise WebsocketClose(4007, 'Invalid seq') + raise WebsocketClose(4007, "Invalid seq") # check if a websocket isnt on that state already if state.ws is not None: - log.info('Resuming failed, websocket already connected') + log.info("Resuming failed, websocket already connected") return await self.invalidate_session(False) # relink this connection @@ -757,8 +763,9 @@ class GatewayWebsocket: await self._resume(range(seq, state.seq)) - async def _req_guild_members(self, guild_id, user_ids: List[int], - query: str, limit: int): + async def _req_guild_members( + self, guild_id, user_ids: List[int], query: str, limit: int + ): try: guild_id = int(guild_id) except (TypeError, ValueError): @@ -778,32 +785,32 @@ class GatewayWebsocket: # ASSUMPTION: requesting user_ids means we don't do query. if user_ids: members = await self.storage.get_member_multi(guild_id, user_ids) - mids = [m['user']['id'] for m in members] + mids = [m["user"]["id"] for m in members] not_found = [uid for uid in user_ids if uid not in mids] - await self.dispatch('GUILD_MEMBERS_CHUNK', { - 'guild_id': str(guild_id), - 'members': members, - 'not_found': not_found, - }) + await self.dispatch( + "GUILD_MEMBERS_CHUNK", + {"guild_id": str(guild_id), "members": members, "not_found": not_found}, + ) return # do the search result = await self.storage.query_members(guild_id, query, limit) - await self.dispatch('GUILD_MEMBERS_CHUNK', { - 'guild_id': str(guild_id), - 'members': result - }) + await self.dispatch( + "GUILD_MEMBERS_CHUNK", {"guild_id": str(guild_id), "members": result} + ) async def handle_8(self, payload: Dict): """Handle OP 8 Request Guild Members.""" - data = payload['d'] - gids = data['guild_id'] + data = payload["d"] + gids = data["guild_id"] - uids, query, limit = data.get('user_ids', []), \ - data.get('query', ''), \ - data.get('limit', 0) + uids, query, limit = ( + data.get("user_ids", []), + data.get("query", ""), + data.get("limit", 0), + ) if isinstance(gids, str): await self._req_guild_members(gids, uids, query, limit) @@ -820,23 +827,21 @@ class GatewayWebsocket: GUILD_SYNC event with that info. """ members = await self.storage.get_member_data(guild_id) - member_ids = [int(m['user']['id']) for m in members] + member_ids = [int(m["user"]["id"]) for m in members] - log.debug(f'Syncing guild {guild_id} with {len(member_ids)} members') + log.debug(f"Syncing guild {guild_id} with {len(member_ids)} members") presences = await self.presence.guild_presences(member_ids, guild_id) - await self.dispatch('GUILD_SYNC', { - 'id': str(guild_id), - 'presences': presences, - 'members': members, - }) + await self.dispatch( + "GUILD_SYNC", + {"id": str(guild_id), "presences": presences, "members": members}, + ) async def handle_12(self, payload: Dict[str, Any]): """Handle OP 12 Guild Sync.""" - data = payload['d'] + data = payload["d"] - gids = await self.user_storage.get_user_guilds( - self.state.user_id) + gids = await self.user_storage.get_user_guilds(self.state.user_id) for guild_id in data: try: @@ -931,35 +936,33 @@ class GatewayWebsocket: ] } """ - data = payload['d'] + data = payload["d"] gids = await self.user_storage.get_user_guilds(self.state.user_id) - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) # make sure to not extract info you shouldn't get if guild_id not in gids: return - log.debug('lazy request: members: {}', - data.get('members', [])) + log.debug("lazy request: members: {}", data.get("members", [])) # make shard query - lazy_guilds = self.ext.dispatcher.backends['lazy_guild'] + lazy_guilds = self.ext.dispatcher.backends["lazy_guild"] - for chan_id, ranges in data.get('channels', {}).items(): + for chan_id, ranges in data.get("channels", {}).items(): chan_id = int(chan_id) member_list = await lazy_guilds.get_gml(chan_id) perms = await get_permissions( - self.state.user_id, chan_id, storage=self.storage) + self.state.user_id, chan_id, storage=self.storage + ) if not perms.bits.read_messages: # ignore requests to unknown channels return - await member_list.shard_query( - self.state.session_id, ranges - ) + await member_list.shard_query(self.state.session_id, ranges) async def _handle_23(self, payload): # TODO reverse-engineer opcode 23, sent by client @@ -968,21 +971,21 @@ class GatewayWebsocket: async def _process_message(self, payload): """Process a single message coming in from the client.""" try: - op_code = payload['op'] + op_code = payload["op"] except KeyError: - raise UnknownOPCode('No OP code') + raise UnknownOPCode("No OP code") try: - handler = getattr(self, f'handle_{op_code}') + handler = getattr(self, f"handle_{op_code}") except AttributeError: - log.warning('Payload with bad op: {}', pprint.pformat(payload)) - raise UnknownOPCode(f'Bad OP code: {op_code}') + log.warning("Payload with bad op: {}", pprint.pformat(payload)) + raise UnknownOPCode(f"Bad OP code: {op_code}") await handler(payload) async def _msg_ratelimit(self): - if self._check_ratelimit('messages', self.state.session_id): - raise WebsocketClose(4008, 'You are being ratelimited.') + if self._check_ratelimit("messages", self.state.session_id): + raise WebsocketClose(4008, "You are being ratelimited.") async def _listen_messages(self): """Listen for messages coming in from the websocket.""" @@ -990,15 +993,15 @@ class GatewayWebsocket: # close anyone trying to login while the # server is shutting down if self.ext.state_manager.closed: - raise WebsocketClose(4000, 'state manager closed') + raise WebsocketClose(4000, "state manager closed") if not self.ext.state_manager.accept_new: - raise WebsocketClose(4000, 'state manager closed for new') + raise WebsocketClose(4000, "state manager closed for new") while True: message = await self.ws.recv() if len(message) > 4096: - raise DecodeError('Payload length exceeded') + raise DecodeError("Payload length exceeded") if self.state: await self._msg_ratelimit() @@ -1033,17 +1036,9 @@ class GatewayWebsocket: # there arent any other states with websocket if not with_ws: - offline = { - 'afk': False, - 'status': 'offline', - 'game': None, - 'since': 0, - } + offline = {"afk": False, "status": "offline", "game": None, "since": 0} - await self.ext.presence.dispatch_pres( - user_id, - offline - ) + await self.ext.presence.dispatch_pres(user_id, offline) async def run(self): """Wrap :meth:`listen_messages` inside @@ -1052,12 +1047,12 @@ class GatewayWebsocket: await self._send_hello() await self._listen_messages() except websockets.exceptions.ConnectionClosed as err: - log.warning('conn close, state={}, err={}', self.state, err) + log.warning("conn close, state={}, err={}", self.state, err) except WebsocketClose as err: - log.warning('ws close, state={} err={}', self.state, err) + log.warning("ws close, state={} err={}", self.state, err) await self.ws.close(code=err.code, reason=err.reason) except Exception as err: - log.exception('An exception has occoured. state={}', self.state) + log.exception("An exception has occoured. state={}", self.state) await self.ws.close(code=4000, reason=repr(err)) finally: user_id = self.state.user_id if self.state else None diff --git a/litecord/guild_memory_store.py b/litecord/guild_memory_store.py index a2e3457..223891c 100644 --- a/litecord/guild_memory_store.py +++ b/litecord/guild_memory_store.py @@ -17,19 +17,21 @@ along with this program. If not, see . """ + class GuildMemoryStore: """Store in-memory properties about guilds. I could have just used Redis... probably too overkill to add aioredis to the already long depedency list, plus, I don't need """ + def __init__(self): self._store = {} def get(self, guild_id: int, attribute: str, default=None): """get a key""" - return self._store.get(f'{guild_id}:{attribute}', default) + return self._store.get(f"{guild_id}:{attribute}", default) def set(self, guild_id: int, attribute: str, value): """set a key""" - self._store[f'{guild_id}:{attribute}'] = value + self._store[f"{guild_id}:{attribute}"] = value diff --git a/litecord/images.py b/litecord/images.py index 0ce0595..345f2ef 100644 --- a/litecord/images.py +++ b/litecord/images.py @@ -33,47 +33,42 @@ from logbook import Logger from PIL import Image -IMAGE_FOLDER = Path('./images') +IMAGE_FOLDER = Path("./images") log = Logger(__name__) -EXTENSIONS = { - 'image/jpeg': 'jpeg', - 'image/webp': 'webp' -} +EXTENSIONS = {"image/jpeg": "jpeg", "image/webp": "webp"} MIMES = { - 'jpg': 'image/jpeg', - 'jpe': 'image/jpeg', - 'jpeg': 'image/jpeg', - 'webp': 'image/webp', + "jpg": "image/jpeg", + "jpe": "image/jpeg", + "jpeg": "image/jpeg", + "webp": "image/webp", } -STATIC_IMAGE_MIMES = [ - 'image/png', - 'image/jpeg', - 'image/webp' -] +STATIC_IMAGE_MIMES = ["image/png", "image/jpeg", "image/webp"] + def get_ext(mime: str) -> str: if mime in EXTENSIONS: return EXTENSIONS[mime] extensions = mimetypes.guess_all_extensions(mime) - return extensions[0].strip('.') + return extensions[0].strip(".") def get_mime(ext: str): if ext in MIMES: return MIMES[ext] - return mimetypes.types_map[f'.{ext}'] + return mimetypes.types_map[f".{ext}"] @dataclass class Icon: """Main icon class""" + key: Optional[str] icon_hash: Optional[str] mime: Optional[str] @@ -85,7 +80,7 @@ class Icon: return None ext = get_ext(self.mime) - return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}') + return str(IMAGE_FOLDER / f"{self.key}_{self.icon_hash}.{ext}") @property def as_pathlib(self) -> Optional[Path]: @@ -106,13 +101,14 @@ class Icon: class ImageError(Exception): """Image error class.""" + pass def to_raw(data_type: str, data: str) -> Optional[bytes]: """Given a data type in the data URI and data, give the raw bytes being encoded.""" - if data_type == 'base64': + if data_type == "base64": return base64.b64decode(data) return None @@ -136,7 +132,7 @@ def _calculate_hash(fhandler) -> str: """ hash_obj = sha256() - for chunk in iter(lambda: fhandler.read(4096), b''): + for chunk in iter(lambda: fhandler.read(4096), b""): hash_obj.update(chunk) # so that we can reuse the same handler @@ -162,39 +158,36 @@ async def calculate_hash(fhandle, loop=None) -> str: def parse_data_uri(string) -> tuple: """Extract image data.""" try: - header, headered_data = string.split(';') + header, headered_data = string.split(";") - _, given_mime = header.split(':') - data_type, data = headered_data.split(',') + _, given_mime = header.split(":") + data_type, data = headered_data.split(",") raw_data = to_raw(data_type, data) if raw_data is None: - raise ImageError('Unknown data header') + raise ImageError("Unknown data header") return given_mime, raw_data except ValueError: - raise ImageError('data URI invalid syntax') + raise ImageError("data URI invalid syntax") def _gen_update_sql(scope: str) -> str: # match a scope to (table, field) field = { - 'user': 'avatar', - 'guild': 'icon', - 'splash': 'splash', - 'banner': 'banner', - - 'channel-icons': 'icon', + "user": "avatar", + "guild": "icon", + "splash": "splash", + "banner": "banner", + "channel-icons": "icon", }[scope] table = { - 'user': 'users', - - 'guild': 'guilds', - 'splash': 'guilds', - 'banner': 'guilds', - - 'channel-icons': 'group_dm_channels' + "user": "users", + "guild": "guilds", + "splash": "guilds", + "banner": "guilds", + "channel-icons": "group_dm_channels", }[scope] return f""" @@ -204,10 +197,10 @@ def _gen_update_sql(scope: str) -> str: def _invalid(kwargs: dict) -> Optional[Icon]: """Send an invalid value.""" - if not kwargs.get('always_icon', False): + if not kwargs.get("always_icon", False): return None - return Icon(None, None, '') + return Icon(None, None, "") def try_unlink(path: Union[Path, str]): @@ -225,18 +218,17 @@ def try_unlink(path: Union[Path, str]): async def resize_gif(raw_data: bytes, target: tuple) -> tuple: """Resize a GIF image.""" # generate a temporary file to call gifsticle to and from. - input_fd, input_path = tempfile.mkstemp(suffix='.gif') - _, output_path = tempfile.mkstemp(suffix='.gif') + input_fd, input_path = tempfile.mkstemp(suffix=".gif") + _, output_path = tempfile.mkstemp(suffix=".gif") - input_handler = os.fdopen(input_fd, 'wb') + input_handler = os.fdopen(input_fd, "wb") # make sure its valid image data data_fd = BytesIO(raw_data) image = Image.open(data_fd) image.close() - log.info('resizing a GIF from {} to {}', - image.size, target) + log.info("resizing a GIF from {} to {}", image.size, target) # insert image info on input_handler # close it to make it ready for consumption by gifsicle @@ -244,12 +236,11 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple: input_handler.close() # call gifsicle under subprocess - log.debug('input: {}', input_path) - log.debug('output: {}', output_path) + log.debug("input: {}", input_path) + log.debug("output: {}", output_path) process = await asyncio.create_subprocess_shell( - f'gifsicle --resize {target[0]}x{target[1]} ' - f'{input_path} > {output_path}', + f"gifsicle --resize {target[0]}x{target[1]} " f"{input_path} > {output_path}", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) @@ -257,11 +248,11 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple: # run it, etc. out, err = await process.communicate() - log.debug('out + err from gifsicle: {}', out + err) + log.debug("out + err from gifsicle: {}", out + err) # write over an empty data_fd data_fd = BytesIO() - output_handler = open(output_path, 'rb') + output_handler = open(output_path, "rb") data_fd.write(output_handler.read()) # close unused handlers @@ -283,40 +274,40 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple: class IconManager: """Main icon manager.""" + def __init__(self, app): self.app = app self.storage = app.storage async def _convert_ext(self, icon: Icon, target: str): - target = 'jpeg' if target == 'jpg' else target + target = "jpeg" if target == "jpg" else target target_mime = get_mime(target) - log.info('converting from {} to {}', icon.mime, target_mime) + log.info("converting from {} to {}", icon.mime, target_mime) - target_path = IMAGE_FOLDER / f'{icon.key}_{icon.icon_hash}.{target}' + target_path = IMAGE_FOLDER / f"{icon.key}_{icon.icon_hash}.{target}" if target_path.exists(): return Icon(icon.key, icon.icon_hash, target_mime) image = Image.open(icon.as_path) - target_fd = target_path.open('wb') + target_fd = target_path.open("wb") - if target == 'jpeg': - image = image.convert('RGB') + if target == "jpeg": + image = image.convert("RGB") image.save(target_fd, format=target) target_fd.close() return Icon(icon.key, icon.icon_hash, target_mime) - async def generic_get(self, scope, key, icon_hash, - **kwargs) -> Optional[Icon]: + async def generic_get(self, scope, key, icon_hash, **kwargs) -> Optional[Icon]: """Get any icon.""" - log.debug('GET {} {} {}', scope, key, icon_hash) + log.debug("GET {} {} {}", scope, key, icon_hash) key = str(key) - hash_query = 'AND hash = $3' if icon_hash else '' + hash_query = "AND hash = $3" if icon_hash else "" # hacky solution to only add icon_hash # when needed. @@ -325,18 +316,21 @@ class IconManager: if icon_hash: args.append(icon_hash) - icon_row = await self.storage.db.fetchrow(f""" + icon_row = await self.storage.db.fetchrow( + f""" SELECT key, hash, mime FROM icons WHERE scope = $1 AND key = $2 {hash_query} - """, *args) + """, + *args, + ) if not icon_row: return None - icon = Icon(icon_row['key'], icon_row['hash'], icon_row['mime']) + icon = Icon(icon_row["key"], icon_row["hash"], icon_row["mime"]) # ensure we aren't messing with NULLs everywhere. if icon.as_pathlib is None: @@ -349,18 +343,16 @@ class IconManager: if icon.extension is None: return None - if 'ext' in kwargs and kwargs['ext'] != icon.extension: - return await self._convert_ext(icon, kwargs['ext']) + if "ext" in kwargs and kwargs["ext"] != icon.extension: + return await self._convert_ext(icon, kwargs["ext"]) return icon async def get_guild_icon(self, guild_id: int, icon_hash: str, **kwargs): """Get an icon for a guild.""" - return await self.generic_get( - 'guild', guild_id, icon_hash, **kwargs) + return await self.generic_get("guild", guild_id, icon_hash, **kwargs) - async def put(self, scope: str, key: str, - b64_data: str, **kwargs) -> Icon: + async def put(self, scope: str, key: str, b64_data: str, **kwargs) -> Icon: """Insert an icon.""" if b64_data is None: return _invalid(kwargs) @@ -373,23 +365,22 @@ class IconManager: # get an extension for the given data uri extension = get_ext(mime) - if 'bsize' in kwargs and len(raw_data) > kwargs['bsize']: + if "bsize" in kwargs and len(raw_data) > kwargs["bsize"]: return _invalid(kwargs) # size management is different for gif files # as they're composed of multiple frames. - if 'size' in kwargs and mime == 'image/gif': - data_fd, raw_data = await resize_gif(raw_data, kwargs['size']) - elif 'size' in kwargs: + if "size" in kwargs and mime == "image/gif": + data_fd, raw_data = await resize_gif(raw_data, kwargs["size"]) + elif "size" in kwargs: image = Image.open(data_fd) - if mime == 'image/jpeg': + if mime == "image/jpeg": image = image.convert("RGB") - want = kwargs['size'] + want = kwargs["size"] - log.info('resizing from {} to {}', - image.size, want) + log.info("resizing from {} to {}", image.size, want) resized = image.resize(want, resample=Image.LANCZOS) @@ -404,23 +395,26 @@ class IconManager: # calculate sha256 # ignore icon hashes if we're talking about emoji - icon_hash = (await calculate_hash(data_fd) - if scope != 'emoji' - else None) + icon_hash = await calculate_hash(data_fd) if scope != "emoji" else None - if scope == 'user' and mime == 'image/gif': - icon_hash = f'a_{icon_hash}' + if scope == "user" and mime == "image/gif": + icon_hash = f"a_{icon_hash}" - log.debug('PUT icon {!r} {!r} {!r} {!r}', - scope, key, icon_hash, mime) + log.debug("PUT icon {!r} {!r} {!r} {!r}", scope, key, icon_hash, mime) - await self.storage.db.execute(""" + await self.storage.db.execute( + """ INSERT INTO icons (scope, key, hash, mime) VALUES ($1, $2, $3, $4) - """, scope, str(key), icon_hash, mime) + """, + scope, + str(key), + icon_hash, + mime, + ) # write it off to fs - icon_path = IMAGE_FOLDER / f'{key}_{icon_hash}.{extension}' + icon_path = IMAGE_FOLDER / f"{key}_{icon_hash}.{extension}" icon_path.write_bytes(raw_data) # copy from data_fd to icon_fd @@ -434,57 +428,80 @@ class IconManager: if not icon: return - log.debug('DEL {}', - icon) + log.debug("DEL {}", icon) # dereference - await self.storage.db.execute(""" + await self.storage.db.execute( + """ UPDATE users SET avatar = NULL WHERE avatar = $1 - """, icon.icon_hash) + """, + icon.icon_hash, + ) - await self.storage.db.execute(""" + await self.storage.db.execute( + """ UPDATE group_dm_channels SET icon = NULL WHERE icon = $1 - """, icon.icon_hash) + """, + icon.icon_hash, + ) - await self.storage.db.execute(""" + await self.storage.db.execute( + """ DELETE FROM guild_emoji WHERE image = $1 - """, icon.icon_hash) + """, + icon.icon_hash, + ) - await self.storage.db.execute(""" + await self.storage.db.execute( + """ UPDATE guilds SET icon = NULL WHERE icon = $1 - """, icon.icon_hash) + """, + icon.icon_hash, + ) - await self.storage.db.execute(""" + await self.storage.db.execute( + """ UPDATE guilds SET splash = NULL WHERE splash = $1 - """, icon.icon_hash) + """, + icon.icon_hash, + ) - await self.storage.db.execute(""" + await self.storage.db.execute( + """ UPDATE guilds SET banner = NULL WHERE banner = $1 - """, icon.icon_hash) + """, + icon.icon_hash, + ) - await self.storage.db.execute(""" + await self.storage.db.execute( + """ UPDATE group_dm_channels SET icon = NULL WHERE icon = $1 - """, icon.icon_hash) + """, + icon.icon_hash, + ) - await self.storage.db.execute(""" + await self.storage.db.execute( + """ DELETE FROM icons WHERE hash = $1 - """, icon.icon_hash) + """, + icon.icon_hash, + ) - paths = IMAGE_FOLDER.glob(f'{icon.key}_{icon.icon_hash}.*') + paths = IMAGE_FOLDER.glob(f"{icon.key}_{icon.icon_hash}.*") for path in paths: try: @@ -492,11 +509,9 @@ class IconManager: except FileNotFoundError: pass - async def update(self, scope: str, key: str, - new_icon_data: str, **kwargs) -> Icon: + async def update(self, scope: str, key: str, new_icon_data: str, **kwargs) -> Icon: """Update an icon on a key.""" - old_icon_hash = await self.storage.db.fetchval( - _gen_update_sql(scope), key) + old_icon_hash = await self.storage.db.fetchval(_gen_update_sql(scope), key) # converting key to str only here since from here onwards # its operations on the icons table (or a dereference with diff --git a/litecord/jobs.py b/litecord/jobs.py index b27099e..4ad3852 100644 --- a/litecord/jobs.py +++ b/litecord/jobs.py @@ -20,6 +20,7 @@ along with this program. If not, see . import asyncio from logbook import Logger + log = Logger(__name__) @@ -30,6 +31,7 @@ class JobManager: use helpers such as asyncio.gather and asyncio.Task.all_tasks. It only uses its own internal list of jobs. """ + def __init__(self, loop=None): self.loop = loop or asyncio.get_event_loop() self.jobs = [] @@ -41,13 +43,11 @@ class JobManager: try: await coro except Exception: - log.exception('Error while running job') + log.exception("Error while running job") def spawn(self, coro): """Spawn a given future or coroutine in the background.""" - task = self.loop.create_task( - self._wrapper(coro) - ) + task = self.loop.create_task(self._wrapper(coro)) self.jobs.append(task) diff --git a/litecord/permissions.py b/litecord/permissions.py index 5029644..11656a1 100644 --- a/litecord/permissions.py +++ b/litecord/permissions.py @@ -26,40 +26,42 @@ from quart import current_app as app # type for all the fields _i = ctypes.c_uint8 + class _RawPermsBits(ctypes.LittleEndianStructure): """raw bitfield for discord's permission number.""" + _fields_ = [ - ('create_invites', _i, 1), - ('kick_members', _i, 1), - ('ban_members', _i, 1), - ('administrator', _i, 1), - ('manage_channels', _i, 1), - ('manage_guild', _i, 1), - ('add_reactions', _i, 1), - ('view_audit_log', _i, 1), - ('priority_speaker', _i, 1), - ('stream', _i, 1), - ('read_messages', _i, 1), - ('send_messages', _i, 1), - ('send_tts', _i, 1), - ('manage_messages', _i, 1), - ('embed_links', _i, 1), - ('attach_files', _i, 1), - ('read_history', _i, 1), - ('mention_everyone', _i, 1), - ('external_emojis', _i, 1), - ('_unused2', _i, 1), - ('connect', _i, 1), - ('speak', _i, 1), - ('mute_members', _i, 1), - ('deafen_members', _i, 1), - ('move_members', _i, 1), - ('use_voice_activation', _i, 1), - ('change_nickname', _i, 1), - ('manage_nicknames', _i, 1), - ('manage_roles', _i, 1), - ('manage_webhooks', _i, 1), - ('manage_emojis', _i, 1), + ("create_invites", _i, 1), + ("kick_members", _i, 1), + ("ban_members", _i, 1), + ("administrator", _i, 1), + ("manage_channels", _i, 1), + ("manage_guild", _i, 1), + ("add_reactions", _i, 1), + ("view_audit_log", _i, 1), + ("priority_speaker", _i, 1), + ("stream", _i, 1), + ("read_messages", _i, 1), + ("send_messages", _i, 1), + ("send_tts", _i, 1), + ("manage_messages", _i, 1), + ("embed_links", _i, 1), + ("attach_files", _i, 1), + ("read_history", _i, 1), + ("mention_everyone", _i, 1), + ("external_emojis", _i, 1), + ("_unused2", _i, 1), + ("connect", _i, 1), + ("speak", _i, 1), + ("mute_members", _i, 1), + ("deafen_members", _i, 1), + ("move_members", _i, 1), + ("use_voice_activation", _i, 1), + ("change_nickname", _i, 1), + ("manage_nicknames", _i, 1), + ("manage_roles", _i, 1), + ("manage_webhooks", _i, 1), + ("manage_emojis", _i, 1), ] @@ -72,16 +74,14 @@ class Permissions(ctypes.Union): val The permissions value as an integer. """ - _fields_ = [ - ('bits', _RawPermsBits), - ('binary', ctypes.c_uint64), - ] + + _fields_ = [("bits", _RawPermsBits), ("binary", ctypes.c_uint64)] def __init__(self, val: int): self.binary = val def __repr__(self): - return f'' + return f"" def __int__(self): return self.binary @@ -95,11 +95,15 @@ async def get_role_perms(guild_id, role_id, storage=None) -> Permissions: if not storage: storage = app.storage - perms = await storage.db.fetchval(""" + perms = await storage.db.fetchval( + """ SELECT permissions FROM roles WHERE guild_id = $1 AND id = $2 - """, guild_id, role_id) + """, + guild_id, + role_id, + ) return Permissions(perms) @@ -118,11 +122,14 @@ async def base_permissions(member_id, guild_id, storage=None) -> Permissions: if not storage: storage = app.storage - owner_id = await storage.db.fetchval(""" + owner_id = await storage.db.fetchval( + """ SELECT owner_id FROM guilds WHERE id = $1 - """, guild_id) + """, + guild_id, + ) if owner_id == member_id: return ALL_PERMISSIONS @@ -130,20 +137,27 @@ async def base_permissions(member_id, guild_id, storage=None) -> Permissions: # get permissions for @everyone permissions = await get_role_perms(guild_id, guild_id, storage) - role_ids = await storage.db.fetch(""" + role_ids = await storage.db.fetch( + """ SELECT role_id FROM member_roles WHERE guild_id = $1 AND user_id = $2 - """, guild_id, member_id) + """, + guild_id, + member_id, + ) role_perms = [] for row in role_ids: - rperm = await storage.db.fetchval(""" + rperm = await storage.db.fetchval( + """ SELECT permissions FROM roles WHERE id = $1 - """, row['role_id']) + """, + row["role_id"], + ) role_perms.append(rperm) @@ -164,16 +178,17 @@ def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions: result = perms.binary # negate the permissions that are denied - result &= ~overwrite['deny'] + result &= ~overwrite["deny"] # combine the permissions that are allowed - result |= overwrite['allow'] + result |= overwrite["allow"] return Permissions(result) -def overwrite_find_mix(perms: Permissions, overwrites: dict, - target_id: int) -> Permissions: +def overwrite_find_mix( + perms: Permissions, overwrites: dict, target_id: int +) -> Permissions: """Mix a given permission with a given overwrite. Returns the given permission if an overwrite is not found. @@ -201,19 +216,25 @@ def overwrite_find_mix(perms: Permissions, overwrites: dict, return perms -async def role_permissions(guild_id: int, role_id: int, - channel_id: int, storage=None) -> Permissions: +async def role_permissions( + guild_id: int, role_id: int, channel_id: int, storage=None +) -> Permissions: """Get the permissions for a role, in relation to a channel""" if not storage: storage = app.storage perms = await get_role_perms(guild_id, role_id, storage) - overwrite = await storage.db.fetchrow(""" + overwrite = await storage.db.fetchrow( + """ SELECT allow, deny FROM channel_overwrites WHERE channel_id = $1 AND target_type = $2 AND target_role = $3 - """, channel_id, 1, role_id) + """, + channel_id, + 1, + role_id, + ) if overwrite: perms = overwrite_mix(perms, overwrite) @@ -221,10 +242,13 @@ async def role_permissions(guild_id: int, role_id: int, return perms -async def compute_overwrites(base_perms: Permissions, - user_id, channel_id: int, - guild_id: Optional[int] = None, - storage=None): +async def compute_overwrites( + base_perms: Permissions, + user_id, + channel_id: int, + guild_id: Optional[int] = None, + storage=None, +): """Compute the permissions in the context of a channel.""" if not storage: storage = app.storage @@ -245,7 +269,7 @@ async def compute_overwrites(base_perms: Permissions, return ALL_PERMISSIONS # make it a map for better usage - overwrites = {int(o['id']): o for o in overwrites} + overwrites = {int(o["id"]): o for o in overwrites} perms = overwrite_find_mix(perms, overwrites, guild_id) @@ -260,14 +284,11 @@ async def compute_overwrites(base_perms: Permissions, for role_id in role_ids: overwrite = overwrites.get(role_id) if overwrite: - allow |= overwrite['allow'] - deny |= overwrite['deny'] + allow |= overwrite["allow"] + deny |= overwrite["deny"] # final step for roles: mix - perms = overwrite_mix(perms, { - 'allow': allow, - 'deny': deny - }) + perms = overwrite_mix(perms, {"allow": allow, "deny": deny}) # apply member specific overwrites perms = overwrite_find_mix(perms, overwrites, user_id) @@ -275,8 +296,7 @@ async def compute_overwrites(base_perms: Permissions, return perms -async def get_permissions(member_id: int, channel_id, - *, storage=None) -> Permissions: +async def get_permissions(member_id: int, channel_id, *, storage=None) -> Permissions: """Get the permissions for a user in a channel.""" if not storage: storage = app.storage @@ -290,4 +310,5 @@ async def get_permissions(member_id: int, channel_id, base_perms = await base_permissions(member_id, guild_id, storage) return await compute_overwrites( - base_perms, member_id, channel_id, guild_id, storage) + base_perms, member_id, channel_id, guild_id, storage + ) diff --git a/litecord/presence.py b/litecord/presence.py index d9c19ad..55d7832 100644 --- a/litecord/presence.py +++ b/litecord/presence.py @@ -32,62 +32,56 @@ def status_cmp(status: str, other_status: str) -> bool: in the status hierarchy. """ - hierarchy = { - 'online': 3, - 'idle': 2, - 'dnd': 1, - 'offline': 0, - None: -1, - } + hierarchy = {"online": 3, "idle": 2, "dnd": 1, "offline": 0, None: -1} return hierarchy[status] > hierarchy[other_status] def _best_presence(shards): """Find the 'best' presence given a list of GatewayState.""" - best = {'status': None, 'game': None} + best = {"status": None, "game": None} for state in shards: presence = state.presence - status = presence['status'] + status = presence["status"] if not presence: continue # shards with a better status # in the hierarchy are treated as best - if status_cmp(status, best['status']): - best['status'] = status + if status_cmp(status, best["status"]): + best["status"] = status # if we have any game, use it - if presence['game'] is not None: - best['game'] = presence['game'] + if presence["game"] is not None: + best["game"] = presence["game"] # best['status'] is None when no # status was good enough. - return None if not best['status'] else best + return None if not best["status"] else best def fill_presence(presence: dict, *, game=None) -> dict: """Fill a given presence object with some specific fields.""" - presence['client_status'] = {} - presence['mobile'] = False + presence["client_status"] = {} + presence["mobile"] = False - if 'since' not in presence: - presence['since'] = 0 + if "since" not in presence: + presence["since"] = 0 # fill game and activities array depending if game # is there or not - game = game or presence.get('game') + game = game or presence.get("game") # casting to bool since a game of {} is still invalid if game: - presence['game'] = game - presence['activities'] = [game] + presence["game"] = game + presence["activities"] = [game] else: - presence['game'] = None - presence['activities'] = [] + presence["game"] = None + presence["activities"] = [] return presence @@ -96,14 +90,13 @@ async def _pres(storage, user_id: int, status_obj: dict) -> dict: """Convert a given status into a presence, given the User ID and the :class:`Storage` instance.""" ext = { - 'user': await storage.get_user(user_id), - 'activities': [], - + "user": await storage.get_user(user_id), + "activities": [], # NOTE: we are purposefully overwriting the fields, as there # isn't any push for us to actually implement mobile detection, or # web detection, etc. - 'client_status': {}, - 'mobile': False, + "client_status": {}, + "mobile": False, } return fill_presence({**status_obj, **ext}) @@ -115,14 +108,16 @@ class PresenceManager: Has common functions to deal with fetching or updating presences, including side-effects (events). """ + def __init__(self, app): self.storage = app.storage self.user_storage = app.user_storage self.state_manager = app.state_manager self.dispatcher = app.dispatcher - async def guild_presences(self, member_ids: List[int], - guild_id: int) -> List[Dict[Any, str]]: + async def guild_presences( + self, member_ids: List[int], guild_id: int + ) -> List[Dict[Any, str]]: """Fetch all presences in a guild.""" # this works via fetching all connected GatewayState on a guild # then fetching its respective member and merging that info with @@ -132,34 +127,36 @@ class PresenceManager: presences = [] for state in states: - member = await self.storage.get_member_data_one( - guild_id, state.user_id) + member = await self.storage.get_member_data_one(guild_id, state.user_id) - game = state.presence.get('game', None) + game = state.presence.get("game", None) # only use the data we need. - presences.append(fill_presence({ - 'user': member['user'], - 'roles': member['roles'], - 'guild_id': str(guild_id), - - # if a state is connected to the guild - # we assume its online. - 'status': state.presence.get('status', 'online'), - }, game=game)) + presences.append( + fill_presence( + { + "user": member["user"], + "roles": member["roles"], + "guild_id": str(guild_id), + # if a state is connected to the guild + # we assume its online. + "status": state.presence.get("status", "online"), + }, + game=game, + ) + ) return presences - async def dispatch_guild_pres(self, guild_id: int, - user_id: int, new_state: dict): + async def dispatch_guild_pres(self, guild_id: int, user_id: int, new_state: dict): """Dispatch a Presence update to an entire guild.""" state = dict(new_state) member = await self.storage.get_member_data_one(guild_id, user_id) - game = state['game'] + game = state["game"] - lazy_guild_store = self.dispatcher.backends['lazy_guild'] + lazy_guild_store = self.dispatcher.backends["lazy_guild"] lists = lazy_guild_store.get_gml_guild(guild_id) # shards that are in lazy guilds with 'everyone' @@ -168,49 +165,44 @@ class PresenceManager: for member_list in lists: session_ids = await member_list.pres_update( - int(member['user']['id']), - { - 'roles': member['roles'], - 'status': state['status'], - 'game': game - } + int(member["user"]["id"]), + {"roles": member["roles"], "status": state["status"], "game": game}, ) - log.debug('Lazy Dispatch to {}', - len(session_ids)) + log.debug("Lazy Dispatch to {}", len(session_ids)) # if we are on the 'everyone' member list, we don't # dispatch a PRESENCE_UPDATE for those shards. if member_list.channel_id == member_list.guild_id: in_lazy.extend(session_ids) - pres_update_payload = fill_presence({ - 'guild_id': str(guild_id), - 'user': member['user'], - 'roles': member['roles'], - 'status': state['status'], - }, game=game) + pres_update_payload = fill_presence( + { + "guild_id": str(guild_id), + "user": member["user"], + "roles": member["roles"], + "status": state["status"], + }, + game=game, + ) # given a session id, return if the session id actually connects to # a given user, and if the state has not been dispatched via lazy guild. def _session_check(session_id): state = self.state_manager.fetch_raw(session_id) - uid = int(member['user']['id']) + uid = int(member["user"]["id"]) if not state: return False # we don't want to send a presence update # to the same user - return (state.user_id != uid and - session_id not in in_lazy) + return state.user_id != uid and session_id not in in_lazy # everyone not in lazy guild mode # gets a PRESENCE_UPDATE await self.dispatcher.dispatch_filter( - 'guild', guild_id, - _session_check, - 'PRESENCE_UPDATE', pres_update_payload + "guild", guild_id, _session_check, "PRESENCE_UPDATE", pres_update_payload ) return in_lazy @@ -220,25 +212,25 @@ class PresenceManager: Also dispatches the presence to all the users' friends """ - if state['status'] == 'invisible': - state['status'] = 'offline' + if state["status"] == "invisible": + state["status"] = "offline" # TODO: shard-aware guild_ids = await self.user_storage.get_user_guilds(user_id) for guild_id in guild_ids: - await self.dispatch_guild_pres( - guild_id, user_id, state) + await self.dispatch_guild_pres(guild_id, user_id, state) # dispatch to all friends that are subscribed to them user = await self.storage.get_user(user_id) - game = state['game'] + game = state["game"] await self.dispatcher.dispatch( - 'friend', user_id, 'PRESENCE_UPDATE', fill_presence({ - 'user': user, - 'status': state['status'], - }, game=game)) + "friend", + user_id, + "PRESENCE_UPDATE", + fill_presence({"user": user, "status": state["status"]}, game=game), + ) async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]: """Fetch presences for a group of users. @@ -254,22 +246,25 @@ class PresenceManager: if not friend_states: # append offline - res.append(await _pres(storage, friend_id, { - 'afk': False, - 'status': 'offline', - 'game': None, - 'since': 0 - })) + res.append( + await _pres( + storage, + friend_id, + {"afk": False, "status": "offline", "game": None, "since": 0}, + ) + ) continue # filter the best shards: # - all with id 0 (are the first shards in the collection) or # - all shards with count = 1 (single shards) - good_shards = list(filter( - lambda state: state.shard[0] == 0 or state.shard[1] == 1, - friend_states - )) + good_shards = list( + filter( + lambda state: state.shard[0] == 0 or state.shard[1] == 1, + friend_states, + ) + ) if good_shards: best_pres = _best_presence(good_shards) diff --git a/litecord/pubsub/__init__.py b/litecord/pubsub/__init__.py index a349982..3840695 100644 --- a/litecord/pubsub/__init__.py +++ b/litecord/pubsub/__init__.py @@ -24,6 +24,11 @@ from .channel import ChannelDispatcher from .friend import FriendDispatcher from .lazy_guild import LazyGuildDispatcher -__all__ = ['GuildDispatcher', 'MemberDispatcher', - 'UserDispatcher', 'ChannelDispatcher', - 'FriendDispatcher', 'LazyGuildDispatcher'] +__all__ = [ + "GuildDispatcher", + "MemberDispatcher", + "UserDispatcher", + "ChannelDispatcher", + "FriendDispatcher", + "LazyGuildDispatcher", +] diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py index 443d1e3..fe3c215 100644 --- a/litecord/pubsub/channel.py +++ b/litecord/pubsub/channel.py @@ -38,23 +38,20 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict: # make a copy or the original channel object data = dict(orig) - idx = index_by_func( - lambda user: user['id'] == str(user_id), - data['recipients'] - ) + idx = index_by_func(lambda user: user["id"] == str(user_id), data["recipients"]) - data['recipients'].pop(idx) + data["recipients"].pop(idx) return data class ChannelDispatcher(DispatcherWithFlags): """Main channel Pub/Sub logic.""" + KEY_TYPE = int VAL_TYPE = int - async def dispatch(self, channel_id, - event: str, data: Any) -> List[str]: + async def dispatch(self, channel_id, event: str, data: Any) -> List[str]: """Dispatch an event to a channel.""" # get everyone who is subscribed # and store the number of states we dispatched the event to @@ -75,9 +72,11 @@ class ChannelDispatcher(DispatcherWithFlags): # TODO: make a fetch_states that fetches shards # - with id 0 (count any) OR # - single shards (id=0, count=1) - states = (self.sm.fetch_states(user_id, guild_id) - if guild_id else - self.sm.user_states(user_id)) + states = ( + self.sm.fetch_states(user_id, guild_id) + if guild_id + else self.sm.user_states(user_id) + ) # unsub people who don't have any states tied to the channel. if not states: @@ -85,28 +84,28 @@ class ChannelDispatcher(DispatcherWithFlags): continue # skip typing events for users that don't want it - if event.startswith('TYPING_') and \ - not self.flags_get(channel_id, user_id, 'typing', True): + if event.startswith("TYPING_") and not self.flags_get( + channel_id, user_id, "typing", True + ): continue cur_sess = [] - if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \ - and data.get('type') == ChannelType.GROUP_DM.value: + if ( + event in ("CHANNEL_CREATE", "CHANNEL_UPDATE") + and data.get("type") == ChannelType.GROUP_DM.value + ): # we edit the channel payload so it doesn't show # the user as a recipient new_data = gdm_recipient_view(data, user_id) - cur_sess = await self._dispatch_states( - states, event, new_data) + cur_sess = await self._dispatch_states(states, event, new_data) else: - cur_sess = await self._dispatch_states( - states, event, data) + cur_sess = await self._dispatch_states(states, event, data) sessions.extend(cur_sess) dispatched += len(cur_sess) - log.info('Dispatched chan={} {!r} to {} states', - channel_id, event, dispatched) + log.info("Dispatched chan={} {!r} to {} states", channel_id, event, dispatched) return sessions diff --git a/litecord/pubsub/dispatcher.py b/litecord/pubsub/dispatcher.py index 493162a..7c4246f 100644 --- a/litecord/pubsub/dispatcher.py +++ b/litecord/pubsub/dispatcher.py @@ -80,8 +80,7 @@ class Dispatcher: """ raise NotImplementedError - async def _dispatch_states(self, states: list, event: str, - data) -> List[str]: + async def _dispatch_states(self, states: list, event: str, data) -> List[str]: """Dispatch an event to a list of states.""" res = [] @@ -90,7 +89,7 @@ class Dispatcher: await state.ws.dispatch(event, data) res.append(state.session_id) except Exception: - log.exception('error while dispatching') + log.exception("error while dispatching") return res @@ -102,6 +101,7 @@ class DispatcherWithState(Dispatcher): of boilerplate code on Pub/Sub backends that have that dictionary. """ + def __init__(self, main): super().__init__(main) diff --git a/litecord/pubsub/friend.py b/litecord/pubsub/friend.py index 2717f6d..0d0ae6b 100644 --- a/litecord/pubsub/friend.py +++ b/litecord/pubsub/friend.py @@ -31,6 +31,7 @@ class FriendDispatcher(DispatcherWithState): channels. If that friend updates their presence, it will be broadcasted through that channel to basically all their friends. """ + KEY_TYPE = int VAL_TYPE = int @@ -44,17 +45,13 @@ class FriendDispatcher(DispatcherWithState): # since relationships broadcast to all shards. sessions.extend( await self.main_dispatcher.dispatch_filter( - 'user', peer_id, func, event, data) + "user", peer_id, func, event, data + ) ) - log.info('dispatched uid={} {!r} to {} states', - user_id, event, len(sessions)) + log.info("dispatched uid={} {!r} to {} states", user_id, event, len(sessions)) return sessions async def dispatch(self, user_id, event, data): - return await self.dispatch_filter( - user_id, - lambda sess_id: True, - event, data, - ) + return await self.dispatch_filter(user_id, lambda sess_id: True, event, data) diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index 462fb63..21d5143 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -29,11 +29,11 @@ log = Logger(__name__) class GuildDispatcher(DispatcherWithFlags): """Guild backend for Pub/Sub""" + KEY_TYPE = int VAL_TYPE = int - async def _chan_action(self, action: str, - guild_id: int, user_id: int, flags=None): + async def _chan_action(self, action: str, guild_id: int, user_id: int, flags=None): """Send an action to all channels of the guild.""" flags = flags or {} chan_ids = await self.app.storage.get_channel_ids(guild_id) @@ -43,33 +43,31 @@ class GuildDispatcher(DispatcherWithFlags): # only do an action for users that can # actually read the channel to start with. chan_perms = await get_permissions( - user_id, chan_id, - storage=self.main_dispatcher.app.storage) + user_id, chan_id, storage=self.main_dispatcher.app.storage + ) if not chan_perms.bits.read_messages: - log.debug('skipping cid={}, no read messages', - chan_id) + log.debug("skipping cid={}, no read messages", chan_id) continue - log.debug('sending raw action {!r} to chan={}', - action, chan_id) + log.debug("sending raw action {!r} to chan={}", action, chan_id) # for now, only sub() has support for flags. # it is an idea to have flags support for other actions args = [] - if action == 'sub': + if action == "sub": chanflags = dict(flags) # channels don't need presence flags try: - chanflags.pop('presence') + chanflags.pop("presence") except KeyError: pass args.append(chanflags) await self.main_dispatcher.action( - 'channel', action, chan_id, user_id, *args + "channel", action, chan_id, user_id, *args ) async def _chan_call(self, meth: str, guild_id: int, *args): @@ -77,26 +75,24 @@ class GuildDispatcher(DispatcherWithFlags): in the guild.""" chan_ids = await self.app.storage.get_channel_ids(guild_id) - chan_dispatcher = self.main_dispatcher.backends['channel'] + chan_dispatcher = self.main_dispatcher.backends["channel"] method = getattr(chan_dispatcher, meth) for chan_id in chan_ids: - log.debug('calling {} to chan={}', - meth, chan_id) + log.debug("calling {} to chan={}", meth, chan_id) await method(chan_id, *args) async def sub(self, guild_id: int, user_id: int, flags=None): """Subscribe a user to the guild.""" await super().sub(guild_id, user_id, flags) - await self._chan_action('sub', guild_id, user_id, flags) + await self._chan_action("sub", guild_id, user_id, flags) async def unsub(self, guild_id: int, user_id: int): """Unsubscribe a user from the guild.""" await super().unsub(guild_id, user_id) - await self._chan_action('unsub', guild_id, user_id) + await self._chan_action("unsub", guild_id, user_id) - async def dispatch_filter(self, guild_id: int, func, - event: str, data: Any): + async def dispatch_filter(self, guild_id: int, func, event: str, data: Any): """Selectively dispatch to session ids that have func(session_id) true.""" user_ids = self.state[guild_id] @@ -121,31 +117,23 @@ class GuildDispatcher(DispatcherWithFlags): # note that this does not equate to any unsubscription # of the channel. - if event.startswith('PRESENCE_') and \ - not self.flags_get(guild_id, user_id, 'presence', True): + if event.startswith("PRESENCE_") and not self.flags_get( + guild_id, user_id, "presence", True + ): continue # filter the ones that matter - states = list(filter( - lambda state: func(state.session_id), states - )) + states = list(filter(lambda state: func(state.session_id), states)) - cur_sess = await self._dispatch_states( - states, event, data) + cur_sess = await self._dispatch_states(states, event, data) sessions.extend(cur_sess) dispatched += len(cur_sess) - log.info('Dispatched {} {!r} to {} states', - guild_id, event, dispatched) + log.info("Dispatched {} {!r} to {} states", guild_id, event, dispatched) return sessions - async def dispatch(self, guild_id: int, - event: str, data: Any): + async def dispatch(self, guild_id: int, event: str, data: Any): """Dispatch an event to all subscribers of the guild.""" - return await self.dispatch_filter( - guild_id, - lambda sess_id: True, - event, data, - ) + return await self.dispatch_filter(guild_id, lambda sess_id: True, event, data) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index d0e26c2..b391186 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -28,16 +28,17 @@ lazy guilds: import asyncio from collections import defaultdict -from typing import ( - Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set -) +from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set from dataclasses import dataclass, asdict, field from logbook import Logger from litecord.pubsub.dispatcher import Dispatcher from litecord.permissions import ( - Permissions, overwrite_find_mix, get_permissions, role_permissions + Permissions, + overwrite_find_mix, + get_permissions, + role_permissions, ) from litecord.utils import index_by_func from litecord.utils import mmh3 @@ -55,6 +56,7 @@ MAX_ROLES = 250 @dataclass class GroupInfo: """Store information about a specific group.""" + gid: GroupID name: str position: int @@ -85,6 +87,7 @@ class MemberList: channel, and since only roles with Read Messages can be in the list, we need to store that information) """ + groups: List[GroupInfo] = field(default_factory=list) data: Dict[GroupID, List[int]] = field(default_factory=dict) presences: Dict[int, Presence] = field(default_factory=dict) @@ -98,8 +101,7 @@ class MemberList: # ignore the bool status of overwrites return all( - bool(list_dict[k]) - for k in ('groups', 'data', 'presences', 'members') + bool(list_dict[k]) for k in ("groups", "data", "presences", "members") ) def __iter__(self): @@ -125,7 +127,7 @@ class MemberList: for group, member_ids in self: count = len(member_ids) - if group.gid == 'offline': + if group.gid == "offline": yield group, member_ids continue @@ -163,38 +165,36 @@ class MemberList: @dataclass class Operation: """Represents a member list operation.""" + list_op: str params: Dict[str, Any] @property def to_dict(self) -> dict: """Return a dictionary representation of the operation.""" - if self.list_op not in ('SYNC', 'INVALIDATE', - 'INSERT', 'UPDATE', 'DELETE'): - raise ValueError('Invalid list operator') + if self.list_op not in ("SYNC", "INVALIDATE", "INSERT", "UPDATE", "DELETE"): + raise ValueError("Invalid list operator") - res = { - 'op': self.list_op - } + res = {"op": self.list_op} - if self.list_op == 'SYNC': - res['items'] = self.params['items'] + if self.list_op == "SYNC": + res["items"] = self.params["items"] - if self.list_op in ('SYNC', 'INVALIDATE'): - res['range'] = self.params['range'] + if self.list_op in ("SYNC", "INVALIDATE"): + res["range"] = self.params["range"] - if self.list_op in ('INSERT', 'DELETE', 'UPDATE'): - res['index'] = self.params['index'] + if self.list_op in ("INSERT", "DELETE", "UPDATE"): + res["index"] = self.params["index"] - if self.list_op in ('INSERT', 'UPDATE'): - res['item'] = self.params['item'] + if self.list_op in ("INSERT", "UPDATE"): + res["item"] = self.params["item"] return res def _to_simple_group(presence: dict) -> str: """Return a simple group (not a role), given a presence.""" - return 'offline' if presence['status'] == 'offline' else 'online' + return "offline" if presence["status"] == "offline" else "online" async def everyone_allow(gml) -> bool: @@ -206,10 +206,7 @@ async def everyone_allow(gml) -> bool: If the role can't access the list, then the list keeps its list ID. """ everyone_perms = await role_permissions( - gml.guild_id, - gml.guild_id, - gml.channel_id, - storage=gml.storage + gml.guild_id, gml.guild_id, gml.channel_id, storage=gml.storage ) return bool(everyone_perms.bits.read_messages) @@ -218,16 +215,17 @@ async def everyone_allow(gml) -> bool: def merge(member: dict, presence: Presence) -> dict: """Merge a member dictionary and a presence dictionary into an item.""" - return {**member, **{ - 'presence': { - 'user': { - 'id': str(member['user']['id']), - }, - 'status': presence['status'], - 'game': presence['game'], - 'activities': presence['activities'] - } - }} + return { + **member, + **{ + "presence": { + "user": {"id": str(member["user"]["id"])}, + "status": presence["status"], + "game": presence["game"], + "activities": presence["activities"], + } + }, + } class GuildMemberList: @@ -257,8 +255,8 @@ class GuildMemberList: lazy guilds to all of the userbase, but users that are bots, for example, can still rely on PRESENCE_UPDATEs. """ - def __init__(self, guild_id: int, - channel_id: int, main_lg): + + def __init__(self, guild_id: int, channel_id: int, main_lg): self.guild_id = guild_id self.channel_id = channel_id @@ -294,9 +292,7 @@ class GuildMemberList: @property def list_id(self): """get the id of the member list.""" - return ('everyone' - if self.channel_id == self.guild_id - else self._calculated_id) + return "everyone" if self.channel_id == self.guild_id else self._calculated_id @property def _calculated_id(self): @@ -310,16 +306,16 @@ class GuildMemberList: for actor_id, overwrite in self.list.overwrites.items(): allow, deny = ( - Permissions(overwrite['allow']), - Permissions(overwrite['deny']) + Permissions(overwrite["allow"]), + Permissions(overwrite["deny"]), ) if allow.bits.read_messages: - ovs_i.append(f'allow:{actor_id}') + ovs_i.append(f"allow:{actor_id}") elif deny.bits.read_messages: - ovs_i.append(f'deny:{actor_id}') + ovs_i.append(f"deny:{actor_id}") - hash_in = ','.join(ovs_i) + hash_in = ",".join(ovs_i) return str(mmh3(hash_in)) def _set_empty_list(self): @@ -334,7 +330,7 @@ class GuildMemberList: async def _fetch_overwrites(self): overwrites = await self.storage.chan_overwrites(self.channel_id) - overwrites = {int(ov['id']): ov for ov in overwrites} + overwrites = {int(ov["id"]): ov for ov in overwrites} self.list.overwrites = overwrites def _calc_member_group(self, roles: List[int], status: str): @@ -344,12 +340,11 @@ class GuildMemberList: # the first group in the list # that the member is entitled to is # the selected group for the member. - group_id = next(g.gid for g in self.list.groups - if g.gid in roles) + group_id = next(g.gid for g in self.list.groups if g.gid in roles) except StopIteration: # no group was found, so we fallback # to simple group - group_id = _to_simple_group({'status': status}) + group_id = _to_simple_group({"status": status}) return group_id @@ -361,7 +356,8 @@ class GuildMemberList: # then the final perms for that role if # any overwrite exists in the channel final_perms = overwrite_find_mix( - role_perms, self.list.overwrites, int(group.gid)) + role_perms, self.list.overwrites, int(group.gid) + ) # update the group's permissions # with the mixed ones @@ -385,26 +381,25 @@ class GuildMemberList: The list is sorted by each role's position. """ - roledata = await self.storage.db.fetch(""" + roledata = await self.storage.db.fetch( + """ SELECT id, name, hoist, position, permissions FROM roles WHERE guild_id = $1 - """, self.guild_id) + """, + self.guild_id, + ) hoisted = [ GroupInfo( - row['id'], row['name'], - row['position'], - Permissions(row['permissions']) + row["id"], row["name"], row["position"], Permissions(row["permissions"]) ) - for row in roledata if row['hoist'] + for row in roledata + if row["hoist"] ] # sort role list by position - hoisted = sorted( - hoisted, key=lambda group: group.position, - reverse=True - ) + hoisted = sorted(hoisted, key=lambda group: group.position, reverse=True) # we need to store the overwrites since # we have incoming presences to manage. @@ -419,28 +414,32 @@ class GuildMemberList: # inject default groups 'online' and 'offline' # their position is always going to be the last ones. self.list.groups = role_groups + [ - GroupInfo('online', 'online', MAX_ROLES + 1, 0), - GroupInfo('offline', 'offline', MAX_ROLES + 2, 0) + GroupInfo("online", "online", MAX_ROLES + 1, 0), + GroupInfo("offline", "offline", MAX_ROLES + 2, 0), ] - async def _get_group_for_member(self, member_id: int, - roles: List[Union[str, int]], - status: str) -> Optional[GroupID]: + async def _get_group_for_member( + self, member_id: int, roles: List[Union[str, int]], status: str + ) -> Optional[GroupID]: """Return a fitting group ID for the member.""" member_roles = list(map(int, roles)) # get the member's permissions relative to the channel # (accounting for channel overwrites) member_perms = await get_permissions( - member_id, self.channel_id, storage=self.storage) + member_id, self.channel_id, storage=self.storage + ) if not member_perms.bits.read_messages: return None # if the member is offline, we # default give them the offline group. - group_id = ('offline' if status == 'offline' - else self._calc_member_group(member_roles, status)) + group_id = ( + "offline" + if status == "offline" + else self._calc_member_group(member_roles, status) + ) return group_id @@ -450,7 +449,7 @@ class GuildMemberList: presence = self.list.presences[member_id] group_id = await self._get_group_for_member( - member_id, presence['roles'], presence['status'] + member_id, presence["roles"], presence["status"] ) # skip members that don't have any group assigned. @@ -458,9 +457,7 @@ class GuildMemberList: if group_id is None: continue - member = await self.storage.get_member_data_one( - self.guild_id, member_id - ) + member = await self.storage.get_member_data_one(self.guild_id, member_id) self.list.members[member_id] = member self.list.data[group_id].append(member_id) @@ -476,33 +473,28 @@ class GuildMemberList: except KeyError: return None - username = member['user']['username'] - nickname = member['nick'] + username = member["user"]["username"] + nickname = member["nick"] return nickname or username async def _sort_groups(self): for member_ids in self.list.data.values(): # this should update the list in-place - member_ids.sort( - key=self._display_name) + member_ids.sort(key=self._display_name) async def __init_member_list(self): """Generate the main member list with groups.""" member_ids = await self.storage.get_member_ids(self.guild_id) - presences = await self.presence.guild_presences( - member_ids, self.guild_id) + presences = await self.presence.guild_presences(member_ids, self.guild_id) # set presences in the list - self.list.presences = {int(p['user']['id']): p - for p in presences} + self.list.presences = {int(p["user"]["id"]): p for p in presences} await self._set_groups() - log.debug('init: {} members, {} groups', - len(member_ids), - len(self.list.groups)) + log.debug("init: {} members, {} groups", len(member_ids), len(self.list.groups)) # allocate a list per group self.list.data = {group.gid: [] for group in self.list.groups} @@ -547,17 +539,10 @@ class GuildMemberList: if not member_ids: continue - res.append({ - 'group': { - 'id': str(group.gid), - 'count': len(member_ids), - } - }) + res.append({"group": {"id": str(group.gid), "count": len(member_ids)}}) for member_id in member_ids: - res.append({ - 'member': self._get_member_as_item(member_id) - }) + res.append({"member": self._get_member_as_item(member_id)}) return res @@ -591,27 +576,21 @@ class GuildMemberList: except KeyError: return None - async def _dispatch_sess(self, session_ids: Iterable[str], - operations: List[Operation]): + async def _dispatch_sess( + self, session_ids: Iterable[str], operations: List[Operation] + ): """Dispatch a GUILD_MEMBER_LIST_UPDATE to the given session ids.""" # construct the payload to dispatch payload = { - 'id': self.list_id, - 'guild_id': str(self.guild_id), - - 'groups': [ - { - 'id': str(group.gid), - 'count': count, - } for group, count in self.list.groups_complete + "id": self.list_id, + "guild_id": str(self.guild_id), + "groups": [ + {"id": str(group.gid), "count": count} + for group, count in self.list.groups_complete ], - - 'ops': [ - operation.to_dict - for operation in operations - ] + "ops": [operation.to_dict for operation in operations], } states = map(self._get_state, session_ids) @@ -621,15 +600,13 @@ class GuildMemberList: if state is None: continue - await state.ws.dispatch( - 'GUILD_MEMBER_LIST_UPDATE', payload) + await state.ws.dispatch("GUILD_MEMBER_LIST_UPDATE", payload) dispatched.append(state.session_id) return dispatched - async def _resync(self, session_ids: List[str], - item_index: int) -> List[str]: + async def _resync(self, session_ids: List[str], item_index: int) -> List[str]: """Send a SYNC event to all states that are subscribed to an item. Returns @@ -649,18 +626,23 @@ class GuildMemberList: try: # get the only range where the group is in - role_range = next((r_min, r_max) for r_min, r_max in ranges - if r_min <= item_index <= r_max) + role_range = next( + (r_min, r_max) + for r_min, r_max in ranges + if r_min <= item_index <= r_max + ) except StopIteration: - log.debug('ignoring sess_id={}, no range for item {}, {}', - session_id, item_index, ranges) + log.debug( + "ignoring sess_id={}, no range for item {}, {}", + session_id, + item_index, + ranges, + ) continue # do resync-ing in the background result.append(session_id) - self.loop.create_task( - self.shard_query(session_id, [role_range]) - ) + self.loop.create_task(self.shard_query(session_id, [role_range])) return result @@ -669,10 +651,7 @@ class GuildMemberList: if item_index is None: return [] - return await self._resync( - self._get_subs(item_index), - item_index - ) + return await self._resync(self._get_subs(item_index), item_index) async def shard_query(self, session_id: str, ranges: list): """Send a GUILD_MEMBER_LIST_UPDATE event @@ -699,18 +678,13 @@ class GuildMemberList: # we direct the request to the 'everyone' gml instance # instead of the current one. everyone_perms = await role_permissions( - self.guild_id, - self.guild_id, - self.channel_id, - storage=self.storage + self.guild_id, self.guild_id, self.channel_id, storage=self.storage ) - if everyone_perms.bits.read_messages and list_id != 'everyone': + if everyone_perms.bits.read_messages and list_id != "everyone": everyone_gml = await self.main.get_gml(self.guild_id) - return await everyone_gml.shard_query( - session_id, ranges - ) + return await everyone_gml.shard_query(session_id, ranges) await self._init_check() @@ -725,10 +699,11 @@ class GuildMemberList: self.state[session_id].add((start, end)) - ops.append(Operation('SYNC', { - 'range': [start, end], - 'items': self.items[start:end] - })) + ops.append( + Operation( + "SYNC", {"range": [start, end], "items": self.items[start:end]} + ) + ) # send SYNCs to the state that requested await self._dispatch_sess([session_id], ops) @@ -780,8 +755,7 @@ class GuildMemberList: def _get_subs(self, item_index: int) -> Iterable[str]: """Get the list of subscribed states to a given item.""" return filter( - lambda sess_id: self._is_subbed(item_index, sess_id), - self.state.keys() + lambda sess_id: self._is_subbed(item_index, sess_id), self.state.keys() ) async def _pres_update_simple(self, user_id: int): @@ -794,8 +768,7 @@ class GuildMemberList: item_index = self._get_item_index(user_id) if item_index is None: - log.warning('lazy guild got invalid pres update uid={}', - user_id) + log.warning("lazy guild got invalid pres update uid={}", user_id) return [] item = self.items[item_index] @@ -804,19 +777,12 @@ class GuildMemberList: # simple update means we just give an UPDATE # operation return await self._dispatch_sess( - session_ids, - [ - Operation('UPDATE', { - 'index': item_index, - 'item': item, - }) - ] + session_ids, [Operation("UPDATE", {"index": item_index, "item": item})] ) async def _pres_update_complex( - self, user_id: int, - old_group: GroupID, rel_index: int, - new_group: GroupID): + self, user_id: int, old_group: GroupID, rel_index: int, new_group: GroupID + ): """Move a member between groups. Parameters @@ -831,17 +797,20 @@ class GuildMemberList: The group the user has to move to. """ - log.debug('complex update: uid={} old={} rel_idx={} new={}', - user_id, old_group, rel_index, new_group) + log.debug( + "complex update: uid={} old={} rel_idx={} new={}", + user_id, + old_group, + rel_index, + new_group, + ) ops = [] old_user_index = self._get_item_index(user_id) old_group_index = self._get_group_item_index(old_group) - ops.append(Operation('DELETE', { - 'index': old_user_index - })) + ops.append(Operation("DELETE", {"index": old_user_index})) # do the necessary changes self.list.data[old_group].remove(user_id) @@ -851,30 +820,35 @@ class GuildMemberList: new_user_index = self._get_item_index(user_id) - ops.append(Operation('INSERT', { - 'index': new_user_index, - - # TODO: maybe construct the new item manually - # instead of resorting to items list? - 'item': self.items[new_user_index] - })) + ops.append( + Operation( + "INSERT", + { + "index": new_user_index, + # TODO: maybe construct the new item manually + # instead of resorting to items list? + "item": self.items[new_user_index], + }, + ) + ) # put a INSERT operation if this is # the first member in the group. - if self.list.is_birth(new_group) and new_group != 'offline': - ops.append(Operation('INSERT', { - 'index': self._get_group_item_index(new_group), - 'item': { - 'group': str(new_group), 'count': 1 - } - })) + if self.list.is_birth(new_group) and new_group != "offline": + ops.append( + Operation( + "INSERT", + { + "index": self._get_group_item_index(new_group), + "item": {"group": str(new_group), "count": 1}, + }, + ) + ) # only add DELETE for the old group after # both operations. if self.list.is_empty(old_group): - ops.append(Operation('DELETE', { - 'index': old_group_index, - })) + ops.append(Operation("DELETE", {"index": old_group_index})) session_ids_old = list(self._get_subs(old_user_index)) session_ids_new = list(self._get_subs(new_user_index)) @@ -894,42 +868,39 @@ class GuildMemberList: # ) # merge both results together - return (await self._resync(session_ids_old, old_user_index) + - await self._resync(session_ids_new, new_user_index)) + return await self._resync(session_ids_old, old_user_index) + await self._resync( + session_ids_new, new_user_index + ) async def new_member(self, user_id: int): """Insert a new member.""" if not self.list: - log.info('lazy: ignoring new member from not-init {}', - user_id) + log.info("lazy: ignoring new member from not-init {}", user_id) return # fetch the new member's presence - pres = await self.presence.guild_presences( - [user_id], self.guild_id) + pres = await self.presence.guild_presences([user_id], self.guild_id) try: pres = pres[0] except IndexError: - log.warning('lazy: did not find pres for new uid {}', - user_id) + log.warning("lazy: did not find pres for new uid {}", user_id) return # insert to pres dict self.list.presences[user_id] = pres - member = await self.storage.get_member_data_one( - self.guild_id, user_id) + member = await self.storage.get_member_data_one(self.guild_id, user_id) self.list.members[user_id] = member # find a group for the newcomer group_id = await self._get_group_for_member( - user_id, member['roles'], pres['status']) + user_id, member["roles"], pres["status"] + ) if group_id is None: - log.warning('lazy: not adding uid {}, no group', - user_id) + log.warning("lazy: not adding uid {}, no group", user_id) return self.list.data[group_id].append(user_id) @@ -938,16 +909,14 @@ class GuildMemberList: user_index = self._get_item_index(user_id) if not user_index: - log.warning('lazy: new uid {} was not assigned idx', - user_id) + log.warning("lazy: new uid {} was not assigned idx", user_id) return await self._resync_by_item(user_index) async def remove_member(self, user_id: int): """Remove a member from the list.""" if not self.list: - log.warning('lazy: unitialized, ignoring del uid {}', - user_id) + log.warning("lazy: unitialized, ignoring del uid {}", user_id) return # we need the old index to resync later on @@ -975,34 +944,34 @@ class GuildMemberList: old_len = len(state_keys) removed = old_len - len(self.state) - log.info('lazy: removed {} states due to remove_member {}', - removed, user_id) + log.info("lazy: removed {} states due to remove_member {}", removed, user_id) # then clean anything on the internal member list # about the member being removed. try: pres = self.list.presences.pop(user_id) except KeyError: - log.warning('lazy: unknown pres uid {}', user_id) + log.warning("lazy: unknown pres uid {}", user_id) return try: member = self.list.members.pop(user_id) except KeyError: - log.warning('lazy: unknown member uid {}', user_id) + log.warning("lazy: unknown member uid {}", user_id) return group_id = await self._get_group_for_member( - user_id, member['roles'], pres['status']) + user_id, member["roles"], pres["status"] + ) if not group_id: - log.warning('lazy: unknown group uid {}', user_id) + log.warning("lazy: unknown group uid {}", user_id) return self.list.data[group_id].remove(user_id) if old_idx is None: - log.warning('lazy: unknown old idx uid {}', user_id) + log.warning("lazy: unknown old idx uid {}", user_id) return # tell everyone about the removal. @@ -1014,20 +983,17 @@ class GuildMemberList: return if user_id not in self.list.members: - log.warning('lazy: ignoring unknown uid {}', - user_id) + log.warning("lazy: ignoring unknown uid {}", user_id) return # update user information inside self.list.members - self.list.members[user_id]['user'] = \ - await self.storage.get_user(user_id) + self.list.members[user_id]["user"] = await self.storage.get_user(user_id) # redispatch user_idx = self._get_item_index(user_id) return await self._resync_by_item(user_idx) - async def pres_update(self, user_id: int, - partial_presence: Presence): + async def pres_update(self, user_id: int, partial_presence: Presence): """Update a presence inside the member list. There are 5 types of updates that can happen for a user in a group: @@ -1052,13 +1018,13 @@ class GuildMemberList: old_group = None old_presence = self.list.presences[user_id] - has_nick = 'nick' in partial_presence + has_nick = "nick" in partial_presence # partial presences don't have 'nick'. we only use it # as a flag that we're doing a mixed update (complex # but without any inter-group changes) try: - partial_presence.pop('nick') + partial_presence.pop("nick") except KeyError: pass @@ -1066,11 +1032,10 @@ class GuildMemberList: try: old_index = member_ids.index(user_id) except ValueError: - log.debug('skipping group {}', group) + log.debug("skipping group {}", group) continue - log.debug('found index for uid={}: gid={}', - user_id, group.gid) + log.debug("found index for uid={}: gid={}", user_id, group.gid) old_group = group.gid break @@ -1080,33 +1045,35 @@ class GuildMemberList: # wasn't in the list in the first place if not old_group: - log.warning('pres update with unknown old group uid={}', - user_id) + log.warning("pres update with unknown old group uid={}", user_id) return [] - roles = partial_presence.get('roles', old_presence['roles']) - status = partial_presence.get('status', old_presence['status']) + roles = partial_presence.get("roles", old_presence["roles"]) + status = partial_presence.get("status", old_presence["status"]) # calculate a possible new group # TODO: handle when new_group is None (member loses perms) - new_group = await self._get_group_for_member( - user_id, roles, status) + new_group = await self._get_group_for_member(user_id, roles, status) - log.debug('pres update: gid={} cid={} old_g={} new_g={}', - self.guild_id, self.channel_id, old_group, new_group) + log.debug( + "pres update: gid={} cid={} old_g={} new_g={}", + self.guild_id, + self.channel_id, + old_group, + new_group, + ) # update our presence with the given partial presence # since in both cases we'd update it anyways self.list.presences[user_id].update(partial_presence) - self.list.members[user_id]['roles'] = roles + self.list.members[user_id]["roles"] = roles # if we're going to the same group AND there are no # nickname changes, treat this as a simple update if old_group == new_group and not has_nick: return await self._pres_update_simple(user_id) - return await self._pres_update_complex( - user_id, old_group, old_index, new_group) + return await self._pres_update_complex(user_id, old_group, old_index, new_group) async def new_role(self, role: dict): """Add a new role to the list. @@ -1117,25 +1084,28 @@ class GuildMemberList: if not self.list: return - group_id = int(role['id']) + group_id = int(role["id"]) new_group = GroupInfo( - group_id, role['name'], - role['position'], Permissions(role['permissions']) + group_id, role["name"], role["position"], Permissions(role["permissions"]) ) # check if new role has good perms await self._fetch_overwrites() if not self._can_read_chan(new_group): - log.info('ignoring incoming group {}', new_group) + log.info("ignoring incoming group {}", new_group) return - log.debug('new_role: inserted rid={} (gid={}, cid={})', - group_id, self.guild_id, self.channel_id) + log.debug( + "new_role: inserted rid={} (gid={}, cid={})", + group_id, + self.guild_id, + self.channel_id, + ) # maintain role sorting - self.list.groups.insert(role['position'], new_group) + self.list.groups.insert(role["position"], new_group) # since this is a new group, we can set it # as a new empty list (as nobody is in the @@ -1160,18 +1130,18 @@ class GuildMemberList: - role is not found inside the group list. """ if not self.list: - log.warning('uninitialized list for gid={} cid={} rid={}', - self.guild_id, self.channel_id, role_id) + log.warning( + "uninitialized list for gid={} cid={} rid={}", + self.guild_id, + self.channel_id, + role_id, + ) return None - groups_idx = index_by_func( - lambda g: g.gid == role_id, - self.list.groups - ) + groups_idx = index_by_func(lambda g: g.gid == role_id, self.list.groups) if groups_idx is None: - log.info('ignoring rid={}, unknown group', - role_id) + log.info("ignoring rid={}, unknown group", role_id) return None return groups_idx @@ -1183,43 +1153,53 @@ class GuildMemberList: This resorts the entire group list, which might be an inefficient operation. """ - role_id = int(role['id']) + role_id = int(role["id"]) old_index = self._get_group_item_index(role_id) if not old_index: - log.warning('lazy role_pos_update: unknown group {}', role_id) + log.warning("lazy role_pos_update: unknown group {}", role_id) return old_sessions = list(self._get_subs(old_index)) groups_idx = self._get_role_as_group_idx(role_id) if groups_idx is None: - log.debug('ignoring rid={} because not group (gid={}, cid={})', - role_id, self.guild_id, self.channel_id) + log.debug( + "ignoring rid={} because not group (gid={}, cid={})", + role_id, + self.guild_id, + self.channel_id, + ) return group = self.list.groups[groups_idx] - group.position = role['position'] + group.position = role["position"] # TODO: maybe this can be more efficient? # we could self.list.groups.insert... but I don't know. # I'm taking the safe route right now by using sorted() - new_groups = sorted(self.list.groups, - key=lambda group: group.position, - reverse=True) + new_groups = sorted( + self.list.groups, key=lambda group: group.position, reverse=True + ) - log.debug('resorted groups from role pos upd ' - 'rid={} rpos={} (gid={}, cid={}) ' - 'res={}', - role_id, group.position, self.guild_id, self.channel_id, - [g.gid for g in new_groups]) + log.debug( + "resorted groups from role pos upd " + "rid={} rpos={} (gid={}, cid={}) " + "res={}", + role_id, + group.position, + self.guild_id, + self.channel_id, + [g.gid for g in new_groups], + ) self.list.groups = new_groups new_index = self._get_group_item_index(role_id) - return (await self._resync(old_sessions, old_index) + - await self._resync_by_item(new_index)) + return await self._resync(old_sessions, old_index) + await self._resync_by_item( + new_index + ) async def role_update(self, role: dict): """Update a role. @@ -1232,23 +1212,26 @@ class GuildMemberList: if not self.list: return - role_id = int(role['id']) + role_id = int(role["id"]) group_idx = self._get_role_as_group_idx(role_id) - if not group_idx and role['hoist']: + if not group_idx and role["hoist"]: # this is a new group, so we'll treat it accordingly. - log.debug('role_update promote to new_role call rid={}', - role_id) + log.debug("role_update promote to new_role call rid={}", role_id) return await self.new_role(role) if not group_idx: - log.debug('role is not group {} (gid={}, cid={})', - role_id, self.guild_id, self.channel_id) + log.debug( + "role is not group {} (gid={}, cid={})", + role_id, + self.guild_id, + self.channel_id, + ) return group = self.list.groups[group_idx] - group.permissions = Permissions(role['permissions']) + group.permissions = Permissions(role["permissions"]) await self._fetch_overwrites() @@ -1260,15 +1243,16 @@ class GuildMemberList: # respective GUILD_MEMBER_LIST_UPDATE events # down to the subscribers. if not self._can_read_chan(group): - log.debug('role_update promote to role_delete ' - 'call rid={} (lost perms)', - role_id) + log.debug( + "role_update promote to role_delete " "call rid={} (lost perms)", + role_id, + ) return await self.role_delete(role_id) - if not role['hoist']: - log.debug('role_update promote to role_delete ' - 'call rid={} (no hoist)', - role_id) + if not role["hoist"]: + log.debug( + "role_update promote to role_delete " "call rid={} (no hoist)", role_id + ) return await self.role_delete(role_id) async def role_delete(self, role_id: int): @@ -1293,20 +1277,19 @@ class GuildMemberList: # using a filter object would cause problems # as we only resync AFTER we delete the group - sess_ids_resync = (list(self._get_subs(role_item_index)) - if role_item_index is not None - else []) + sess_ids_resync = ( + list(self._get_subs(role_item_index)) if role_item_index is not None else [] + ) # remove the group info off the list groups_index = index_by_func( - lambda group: group.gid == role_id, - self.list.groups + lambda group: group.gid == role_id, self.list.groups ) if groups_index is not None: del self.list.groups[groups_index] else: - log.warning('list unstable: {} not on group list', role_id) + log.warning("list unstable: {} not on group list", role_id) # now the data info try: @@ -1318,13 +1301,11 @@ class GuildMemberList: # when generating the guild, we can reassign # the presences into new groups and sort # the new presences so we achieve the correct state - log.debug('reassigning {} presences', len(member_ids)) - await self._list_fill_groups( - member_ids - ) + log.debug("reassigning {} presences", len(member_ids)) + await self._list_fill_groups(member_ids) await self._sort_groups() except KeyError: - log.warning('list unstable: {} not in data dict', role_id) + log.warning("list unstable: {} not in data dict", role_id) try: self.list.overwrites.pop(role_id) @@ -1336,11 +1317,18 @@ class GuildMemberList: # after removing, we do a resync with the # shards that had the group. - log.info('role_delete rid={} (gid={}, cid={})', - role_id, self.guild_id, self.channel_id) + log.info( + "role_delete rid={} (gid={}, cid={})", + role_id, + self.guild_id, + self.channel_id, + ) - log.debug('there are {} session ids to resync (for item {})', - len(sess_ids_resync), role_item_index) + log.debug( + "there are {} session ids to resync (for item {})", + len(sess_ids_resync), + role_item_index, + ) if role_item_index is not None: return await self._resync(sess_ids_resync, role_item_index) @@ -1357,7 +1345,7 @@ class GuildMemberList: # await self._list_fill_groups() # await self._sort_groups() - if self.list_id == 'everyone': + if self.list_id == "everyone": return # we are on a non-everyone gml, time to check everyone perms @@ -1370,8 +1358,12 @@ class GuildMemberList: def close(self): """Remove data.""" - log.info('closing GML gid={} cid={}, {} subscribers', - self.guild_id, self.channel_id, len(self.state)) + log.info( + "closing GML gid={} cid={}, {} subscribers", + self.guild_id, + self.channel_id, + len(self.state), + ) self.guild_id = None self.channel_id = None @@ -1382,6 +1374,7 @@ class GuildMemberList: class LazyGuildDispatcher(Dispatcher): """Main class holding the member lists for lazy guilds.""" + # channel ids KEY_TYPE = int @@ -1407,9 +1400,7 @@ class LazyGuildDispatcher(Dispatcher): try: return self.state[channel_id] except KeyError: - guild_id = await self.storage.guild_from_channel( - channel_id - ) + guild_id = await self.storage.guild_from_channel(channel_id) # if we don't find a guild, we just # set it the same as the channel. @@ -1423,10 +1414,7 @@ class LazyGuildDispatcher(Dispatcher): def get_gml_guild(self, guild_id: int) -> List[GuildMemberList]: """Get all member lists for a given guild.""" - return list(map( - self.state.get, - self.guild_map[guild_id] - )) + return list(map(self.state.get, self.guild_map[guild_id])) async def unsub(self, chan_id, session_id): """Unsubscribe a session from the list.""" @@ -1436,9 +1424,9 @@ class LazyGuildDispatcher(Dispatcher): async def dispatch(self, guild_id, event: str, *args, **kwargs): """Call a function specialized in handling the given event""" try: - handler = getattr(self, f'_handle_{event.lower()}') + handler = getattr(self, f"_handle_{event.lower()}") except AttributeError: - log.warning('unknown event: {}', event) + log.warning("unknown event: {}", event) return await handler(guild_id, *args, **kwargs) @@ -1464,8 +1452,7 @@ class LazyGuildDispatcher(Dispatcher): async def _call_all_lists(self, guild_id, method_str: str, *args): lists = self.get_gml_guild(guild_id) - log.debug('calling method={} to all {} lists', - method_str, len(lists)) + log.debug("calling method={} to all {} lists", method_str, len(lists)) for lazy_list in lists: method = getattr(lazy_list, method_str) @@ -1474,31 +1461,26 @@ class LazyGuildDispatcher(Dispatcher): async def _handle_new_role(self, guild_id: int, new_role: dict): """Handle the addition of a new group by dispatching it to the member lists.""" - await self._call_all_lists(guild_id, 'new_role', new_role) + await self._call_all_lists(guild_id, "new_role", new_role) async def _handle_role_pos_upd(self, guild_id, role: dict): - await self._call_all_lists(guild_id, 'role_pos_update', role) + await self._call_all_lists(guild_id, "role_pos_update", role) async def _handle_role_update(self, guild_id, role: dict): # handle name and hoist changes - await self._call_all_lists(guild_id, 'role_update', role) + await self._call_all_lists(guild_id, "role_update", role) async def _handle_role_delete(self, guild_id, role_id: int): - await self._call_all_lists(guild_id, 'role_delete', role_id) + await self._call_all_lists(guild_id, "role_delete", role_id) - async def _handle_pres_update(self, guild_id, user_id: int, - partial: dict): - await self._call_all_lists( - guild_id, 'pres_update', user_id, partial) + async def _handle_pres_update(self, guild_id, user_id: int, partial: dict): + await self._call_all_lists(guild_id, "pres_update", user_id, partial) async def _handle_new_member(self, guild_id, user_id: int): - await self._call_all_lists( - guild_id, 'new_member', user_id) + await self._call_all_lists(guild_id, "new_member", user_id) async def _handle_remove_member(self, guild_id, user_id: int): - await self._call_all_lists( - guild_id, 'remove_member', user_id) + await self._call_all_lists(guild_id, "remove_member", user_id) async def _handle_update_user(self, guild_id, user_id: int): - await self._call_all_lists( - guild_id, 'update_user', user_id) + await self._call_all_lists(guild_id, "update_user", user_id) diff --git a/litecord/pubsub/member.py b/litecord/pubsub/member.py index 5302a17..c5a389e 100644 --- a/litecord/pubsub/member.py +++ b/litecord/pubsub/member.py @@ -22,6 +22,7 @@ from .dispatcher import Dispatcher class MemberDispatcher(Dispatcher): """Member backend for Pub/Sub.""" + KEY_TYPE = tuple async def dispatch(self, key, event, data): @@ -39,7 +40,7 @@ class MemberDispatcher(Dispatcher): # if no states were found, we should # unsub the user from the GUILD channel if not states: - await self.main_dispatcher.unsub('guild', guild_id, user_id) + await self.main_dispatcher.unsub("guild", guild_id, user_id) return return await self._dispatch_states(states, event, data) diff --git a/litecord/pubsub/user.py b/litecord/pubsub/user.py index 94bb489..bc53094 100644 --- a/litecord/pubsub/user.py +++ b/litecord/pubsub/user.py @@ -22,22 +22,18 @@ from .dispatcher import Dispatcher class UserDispatcher(Dispatcher): """User backend for Pub/Sub.""" + KEY_TYPE = int async def dispatch_filter(self, user_id: int, func, event, data): """Dispatch an event to all shards of a user.""" # filter only states where func() gives true - states = list(filter( - lambda state: func(state.session_id), - self.sm.user_states(user_id) - )) + states = list( + filter(lambda state: func(state.session_id), self.sm.user_states(user_id)) + ) return await self._dispatch_states(states, event, data) async def dispatch(self, user_id: int, event, data): - return await self.dispatch_filter( - user_id, - lambda sess_id: True, - event, data, - ) + return await self.dispatch_filter(user_id, lambda sess_id: True, event, data) diff --git a/litecord/ratelimits/bucket.py b/litecord/ratelimits/bucket.py index 929e9f6..c79bca2 100644 --- a/litecord/ratelimits/bucket.py +++ b/litecord/ratelimits/bucket.py @@ -28,6 +28,7 @@ import time class RatelimitBucket: """Main ratelimit bucket class.""" + def __init__(self, tokens, second): self.requests = tokens self.second = second @@ -88,17 +89,19 @@ class RatelimitBucket: Used to manage multiple ratelimits to users. """ - return RatelimitBucket(self.requests, - self.second) + return RatelimitBucket(self.requests, self.second) def __repr__(self): - return (f'') + return ( + f"" + ) class Ratelimit: """Manages buckets.""" + def __init__(self, tokens, second, keys=None): self._cache = {} if keys is None: @@ -107,12 +110,11 @@ class Ratelimit: self._cooldown = RatelimitBucket(tokens, second) def __repr__(self): - return (f'') + return f"" def _verify_cache(self): current = time.time() - dead_keys = [k for k, v in self._cache.items() - if current > v._last + v.second] + dead_keys = [k for k, v in self._cache.items() if current > v._last + v.second] for k in dead_keys: del self._cache[k] diff --git a/litecord/ratelimits/handler.py b/litecord/ratelimits/handler.py index f5bfa52..05228fd 100644 --- a/litecord/ratelimits/handler.py +++ b/litecord/ratelimits/handler.py @@ -31,10 +31,10 @@ async def _check_bucket(bucket): if retry_after: request.retry_after = retry_after - raise Ratelimited('You are being rate limited.', { - 'retry_after': int(retry_after * 1000), - 'global': request.bucket_global, - }) + raise Ratelimited( + "You are being rate limited.", + {"retry_after": int(retry_after * 1000), "global": request.bucket_global}, + ) async def _handle_global(ratelimit): @@ -59,13 +59,13 @@ async def _handle_specific(ratelimit): keys = ratelimit.keys # base key is the user id - key_components = [f'user_id:{user_id}'] + key_components = [f"user_id:{user_id}"] for key in keys: val = request.view_args[key] - key_components.append(f'{key}:{val}') + key_components.append(f"{key}:{val}") - bucket_key = ':'.join(key_components) + bucket_key = ":".join(key_components) bucket = ratelimit.get_bucket(bucket_key) await _check_bucket(bucket) @@ -78,9 +78,7 @@ async def ratelimit_handler(): rule = request.url_rule if rule is None: - return await _handle_global( - app.ratelimiter.global_bucket - ) + return await _handle_global(app.ratelimiter.global_bucket) # rule.endpoint is composed of '.' # and so we can use that to make routes with different @@ -97,6 +95,4 @@ async def ratelimit_handler(): ratelimit = app.ratelimiter.get_ratelimit(rule_path) await _handle_specific(ratelimit) except KeyError: - await _handle_global( - app.ratelimiter.global_bucket - ) + await _handle_global(app.ratelimiter.global_bucket) diff --git a/litecord/ratelimits/main.py b/litecord/ratelimits/main.py index 0662e46..88ec1a2 100644 --- a/litecord/ratelimits/main.py +++ b/litecord/ratelimits/main.py @@ -34,33 +34,30 @@ WS: |All Sent Messages| | 120/60s | per-session """ -REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id')) +REACTION_BUCKET = Ratelimit(1, 0.25, ("channel_id")) RATELIMITS = { - 'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')), - 'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')), - + "channel_messages.create_message": Ratelimit(5, 5, ("channel_id")), + "channel_messages.delete_message": Ratelimit(5, 1, ("channel_id")), # all of those share the same bucket. - 'channel_reactions.add_reaction': REACTION_BUCKET, - 'channel_reactions.remove_own_reaction': REACTION_BUCKET, - 'channel_reactions.remove_user_reaction': REACTION_BUCKET, - - 'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')), - 'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')), - + "channel_reactions.add_reaction": REACTION_BUCKET, + "channel_reactions.remove_own_reaction": REACTION_BUCKET, + "channel_reactions.remove_user_reaction": REACTION_BUCKET, + "guild_members.modify_guild_member": Ratelimit(10, 10, ("guild_id")), + "guild_members.update_nickname": Ratelimit(1, 1, ("guild_id")), # this only applies to username. # 'users.patch_me': Ratelimit(2, 3600), - - '_ws.connect': Ratelimit(1, 5), - '_ws.presence': Ratelimit(5, 60), - '_ws.messages': Ratelimit(120, 60), - + "_ws.connect": Ratelimit(1, 5), + "_ws.presence": Ratelimit(5, 60), + "_ws.messages": Ratelimit(120, 60), # 1000 / 4h for new session issuing - '_ws.session': Ratelimit(1000, 14400) + "_ws.session": Ratelimit(1000, 14400), } + class RatelimitManager: """Manager for the bucket managers""" + def __init__(self, testing_flag=False): self._ratelimiters = {} self._test = testing_flag @@ -74,9 +71,7 @@ class RatelimitManager: # NOTE: this is a bad way to do it, but # we only need to change that one for now. - rtl = (Ratelimit(10, 1) - if self._test and path == '_ws.connect' - else rtl) + rtl = Ratelimit(10, 1) if self._test and path == "_ws.connect" else rtl self._ratelimiters[path] = rtl diff --git a/litecord/schemas.py b/litecord/schemas.py index bc9cce7..3919d43 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -28,26 +28,30 @@ from .errors import BadRequest from .permissions import Permissions from .types import Color from .enums import ( - ActivityType, StatusType, ExplicitFilter, RelationshipType, - MessageNotifications, ChannelType, VerificationLevel + ActivityType, + StatusType, + ExplicitFilter, + RelationshipType, + MessageNotifications, + ChannelType, + VerificationLevel, ) from litecord.embed.schemas import EMBED_OBJECT, EmbedURL log = Logger(__name__) -USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_ ]{2,30}$', re.A) -EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$', - re.A) -DATA_REGEX = re.compile(r'data\:image/(png|jpeg|gif);base64,(.+)', re.A) +USERNAME_REGEX = re.compile(r"^[a-zA-Z0-9_ ]{2,30}$", re.A) +EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$", re.A) +DATA_REGEX = re.compile(r"data\:image/(png|jpeg|gif);base64,(.+)", re.A) # collection of regexes -USER_MENTION = re.compile(r'<@!?(\d+)>', re.A | re.M) -CHAN_MENTION = re.compile(r'<#(\d+)>', re.A | re.M) -ROLE_MENTION = re.compile(r'<@&(\d+)>', re.A | re.M) -EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M) -ANIMOJI_MENTION = re.compile(r'', re.A | re.M) +USER_MENTION = re.compile(r"<@!?(\d+)>", re.A | re.M) +CHAN_MENTION = re.compile(r"<#(\d+)>", re.A | re.M) +ROLE_MENTION = re.compile(r"<@&(\d+)>", re.A | re.M) +EMOJO_MENTION = re.compile(r"<:(\.+):(\d+)>", re.A | re.M) +ANIMOJI_MENTION = re.compile(r"", re.A | re.M) def _in_enum(enum, value) -> bool: @@ -61,6 +65,7 @@ def _in_enum(enum, value) -> bool: class LitecordValidator(Validator): """Main validator class for Litecord, containing custom types.""" + def _validate_type_username(self, value: str) -> bool: """Validate against the username regex.""" return bool(USERNAME_REGEX.match(value)) @@ -130,8 +135,7 @@ class LitecordValidator(Validator): return False # nobody is allowed to use the INCOMING and OUTGOING rel types - return val in (RelationshipType.FRIEND.value, - RelationshipType.BLOCK.value) + return val in (RelationshipType.FRIEND.value, RelationshipType.BLOCK.value) def _validate_type_msg_notifications(self, value: str): try: @@ -152,14 +156,15 @@ class LitecordValidator(Validator): return self._validate_type_guild_name(value) def _validate_type_theme(self, value: str) -> bool: - return value in ['light', 'dark'] + return value in ["light", "dark"] def _validate_type_nickname(self, value: str) -> bool: return isinstance(value, str) and (len(value) < 32) -def validate(reqjson: Optional[Union[Dict, List]], schema: Dict, - raise_err: bool = True) -> Dict: +def validate( + reqjson: Optional[Union[Dict, List]], schema: Dict, raise_err: bool = True +) -> Dict: """Validate the given user-given data against a schema, giving the "correct" version of the document, with all defaults applied. @@ -176,20 +181,20 @@ def validate(reqjson: Optional[Union[Dict, List]], schema: Dict, validator = LitecordValidator(schema) if reqjson is None: - raise BadRequest('No JSON provided') + raise BadRequest("No JSON provided") try: valid = validator.validate(reqjson) except Exception: - log.exception('Error while validating') - raise Exception(f'Error while validating: {reqjson}') + log.exception("Error while validating") + raise Exception(f"Error while validating: {reqjson}") if not valid: errs = validator.errors - log.warning('Error validating doc {!r}: {!r}', reqjson, errs) + log.warning("Error validating doc {!r}: {!r}", reqjson, errs) if raise_err: - raise BadRequest('bad payload', errs) + raise BadRequest("bad payload", errs) return None @@ -197,554 +202,441 @@ def validate(reqjson: Optional[Union[Dict, List]], schema: Dict, REGISTER = { - 'username': {'type': 'username', 'required': True}, - 'email': {'type': 'email', 'required': False}, - 'password': {'type': 'password', 'required': False}, - + "username": {"type": "username", "required": True}, + "email": {"type": "email", "required": False}, + "password": {"type": "password", "required": False}, # invite stands for a guild invite, not an instance invite (that's on # the register_with_invite handler). - 'invite': {'type': 'string', 'required': False, 'nullable': True}, - + "invite": {"type": "string", "required": False, "nullable": True}, # following fields only sent by official client, unused by us - 'fingerprint': {'type': 'string', 'required': False, 'nullable': True}, - 'captcha_key': {'type': 'string', 'required': False, 'nullable': True}, - 'gift_code_sku_id': {'type': 'string', 'required': False, 'nullable': True}, - 'consent': {'type': 'boolean', 'required': False}, + "fingerprint": {"type": "string", "required": False, "nullable": True}, + "captcha_key": {"type": "string", "required": False, "nullable": True}, + "gift_code_sku_id": {"type": "string", "required": False, "nullable": True}, + "consent": {"type": "boolean", "required": False}, } # only used by us, not discord, hence 'invcode' (to separate from discord) -REGISTER_WITH_INVITE = {**REGISTER, **{ - 'invcode': {'type': 'string', 'required': True} -}} +REGISTER_WITH_INVITE = {**REGISTER, **{"invcode": {"type": "string", "required": True}}} USER_UPDATE = { - 'username': { - 'type': 'username', 'minlength': 2, - 'maxlength': 30, 'required': False}, - - 'discriminator': { - 'type': 'discriminator', - 'required': False, - 'nullable': True, + "username": { + "type": "username", + "minlength": 2, + "maxlength": 30, + "required": False, }, - - 'password': { - 'type': 'password', 'required': False, + "discriminator": {"type": "discriminator", "required": False, "nullable": True}, + "password": {"type": "password", "required": False}, + "new_password": { + "type": "password", + "required": False, + "dependencies": "password", + "nullable": True, }, - - 'new_password': { - 'type': 'password', 'required': False, - 'dependencies': 'password', 'nullable': True - }, - - 'email': { - 'type': 'email', 'required': False, 'dependencies': 'password', - }, - - 'avatar': { + "email": {"type": "email", "required": False, "dependencies": "password"}, + "avatar": { # can be both b64_icon or string (just the hash) - 'type': 'string', 'required': False, - 'nullable': True + "type": "string", + "required": False, + "nullable": True, }, - } PARTIAL_ROLE_GUILD_CREATE = { - 'type': 'dict', - 'schema': { - 'name': {'type': 'role_name'}, - 'color': {'type': 'number', 'default': 0}, - 'hoist': {'type': 'boolean', 'default': False}, - + "type": "dict", + "schema": { + "name": {"type": "role_name"}, + "color": {"type": "number", "default": 0}, + "hoist": {"type": "boolean", "default": False}, # NOTE: no position on partial role (on guild create) - - 'permissions': {'coerce': Permissions, 'required': False}, - 'mentionable': {'type': 'boolean', 'default': False}, - } + "permissions": {"coerce": Permissions, "required": False}, + "mentionable": {"type": "boolean", "default": False}, + }, } PARTIAL_CHANNEL_GUILD_CREATE = { - 'type': 'dict', - 'schema': { - 'name': {'type': 'channel_name'}, - 'type': {'type': 'channel_type'}, - } + "type": "dict", + "schema": {"name": {"type": "channel_name"}, "type": {"type": "channel_type"}}, } GUILD_CREATE = { - 'name': {'type': 'guild_name'}, - 'region': {'type': 'voice_region', 'nullable': True}, - 'icon': {'type': 'b64_icon', 'required': False, 'nullable': True}, - - 'verification_level': { - 'type': 'verification_level', 'default': 0}, - 'default_message_notifications': { - 'type': 'msg_notifications', 'default': 0}, - 'explicit_content_filter': { - 'type': 'explicit', 'default': 0}, - - 'roles': { - 'type': 'list', 'required': False, - 'schema': PARTIAL_ROLE_GUILD_CREATE}, - 'channels': { - 'type': 'list', 'default': [], 'schema': PARTIAL_CHANNEL_GUILD_CREATE}, + "name": {"type": "guild_name"}, + "region": {"type": "voice_region", "nullable": True}, + "icon": {"type": "b64_icon", "required": False, "nullable": True}, + "verification_level": {"type": "verification_level", "default": 0}, + "default_message_notifications": {"type": "msg_notifications", "default": 0}, + "explicit_content_filter": {"type": "explicit", "default": 0}, + "roles": {"type": "list", "required": False, "schema": PARTIAL_ROLE_GUILD_CREATE}, + "channels": {"type": "list", "default": [], "schema": PARTIAL_CHANNEL_GUILD_CREATE}, } GUILD_UPDATE = { - 'name': { - 'type': 'guild_name', - 'required': False - }, - 'region': {'type': 'voice_region', 'required': False, 'nullable': True}, - + "name": {"type": "guild_name", "required": False}, + "region": {"type": "voice_region", "required": False, "nullable": True}, # all three can have hashes - 'icon': {'type': 'string', 'required': False, 'nullable': True}, - 'banner': {'type': 'string', 'required': False, 'nullable': True}, - 'splash': {'type': 'string', 'required': False, 'nullable': True}, - - 'description': { - 'type': 'string', 'required': False, - 'minlength': 1, 'maxlength': 120, - 'nullable': True + "icon": {"type": "string", "required": False, "nullable": True}, + "banner": {"type": "string", "required": False, "nullable": True}, + "splash": {"type": "string", "required": False, "nullable": True}, + "description": { + "type": "string", + "required": False, + "minlength": 1, + "maxlength": 120, + "nullable": True, }, - - 'verification_level': { - 'type': 'verification_level', 'required': False}, - 'default_message_notifications': { - 'type': 'msg_notifications', 'required': False}, - 'explicit_content_filter': {'type': 'explicit', 'required': False}, - - 'afk_channel_id': { - 'type': 'snowflake', 'required': False, 'nullable': True}, - 'afk_timeout': {'type': 'number', 'required': False}, - - 'owner_id': {'type': 'snowflake', 'required': False}, - - 'system_channel_id': { - 'type': 'snowflake', 'required': False, 'nullable': True}, + "verification_level": {"type": "verification_level", "required": False}, + "default_message_notifications": {"type": "msg_notifications", "required": False}, + "explicit_content_filter": {"type": "explicit", "required": False}, + "afk_channel_id": {"type": "snowflake", "required": False, "nullable": True}, + "afk_timeout": {"type": "number", "required": False}, + "owner_id": {"type": "snowflake", "required": False}, + "system_channel_id": {"type": "snowflake", "required": False, "nullable": True}, } CHAN_OVERWRITE = { - 'id': {'coerce': int}, - 'type': {'type': 'string', 'allowed': ['role', 'member']}, - 'allow': {'coerce': Permissions}, - 'deny': {'coerce': Permissions} + "id": {"coerce": int}, + "type": {"type": "string", "allowed": ["role", "member"]}, + "allow": {"coerce": Permissions}, + "deny": {"coerce": Permissions}, } CHAN_CREATE = { - 'name': { - 'type': 'string', 'minlength': 2, - 'maxlength': 100, 'required': True - }, - - 'type': {'type': 'channel_type', - 'default': ChannelType.GUILD_TEXT.value}, - - 'position': {'coerce': int, 'required': False}, - - 'topic': { - 'type': 'string', 'minlength': 0, - 'maxlength': 1024, 'required': False}, - - 'nsfw': {'type': 'boolean', 'required': False}, - 'rate_limit_per_user': { - 'coerce': int, 'min': 0, - 'max': 120, 'required': False}, - - 'bitrate': { - 'coerce': int, 'min': 8000, - + "name": {"type": "string", "minlength": 2, "maxlength": 100, "required": True}, + "type": {"type": "channel_type", "default": ChannelType.GUILD_TEXT.value}, + "position": {"coerce": int, "required": False}, + "topic": {"type": "string", "minlength": 0, "maxlength": 1024, "required": False}, + "nsfw": {"type": "boolean", "required": False}, + "rate_limit_per_user": {"coerce": int, "min": 0, "max": 120, "required": False}, + "bitrate": { + "coerce": int, + "min": 8000, # NOTE: 'max' is 96000 for non-vip guilds - 'max': 128000, 'required': False}, - - 'user_limit': { + "max": 128000, + "required": False, + }, + "user_limit": { # user_limit being 0 means infinite. - 'coerce': int, 'min': 0, - 'max': 99, 'required': False + "coerce": int, + "min": 0, + "max": 99, + "required": False, }, - - 'permission_overwrites': { - 'type': 'list', - 'schema': {'type': 'dict', 'schema': CHAN_OVERWRITE}, - 'required': False + "permission_overwrites": { + "type": "list", + "schema": {"type": "dict", "schema": CHAN_OVERWRITE}, + "required": False, }, - - 'parent_id': {'coerce': int, 'required': False, 'nullable': True} + "parent_id": {"coerce": int, "required": False, "nullable": True}, } -CHAN_UPDATE = {**CHAN_CREATE, **{ - 'name': { - 'type': 'string', 'minlength': 2, - 'maxlength': 100, 'required': False}, - -}} +CHAN_UPDATE = { + **CHAN_CREATE, + **{"name": {"type": "string", "minlength": 2, "maxlength": 100, "required": False}}, +} ROLE_CREATE = { - 'name': {'type': 'string', 'default': 'new role'}, - 'permissions': {'coerce': Permissions, 'nullable': True}, - 'color': {'coerce': Color, 'default': 0}, - 'hoist': {'type': 'boolean', 'default': False}, - 'mentionable': {'type': 'boolean', 'default': False}, + "name": {"type": "string", "default": "new role"}, + "permissions": {"coerce": Permissions, "nullable": True}, + "color": {"coerce": Color, "default": 0}, + "hoist": {"type": "boolean", "default": False}, + "mentionable": {"type": "boolean", "default": False}, } ROLE_UPDATE = { - 'name': {'type': 'string', 'required': False}, - 'permissions': {'coerce': Permissions, 'required': False}, - 'color': {'coerce': Color, 'required': False}, - 'hoist': {'type': 'boolean', 'required': False}, - 'mentionable': {'type': 'boolean', 'required': False}, + "name": {"type": "string", "required": False}, + "permissions": {"coerce": Permissions, "required": False}, + "color": {"coerce": Color, "required": False}, + "hoist": {"type": "boolean", "required": False}, + "mentionable": {"type": "boolean", "required": False}, } ROLE_UPDATE_POSITION = { - 'roles': { - 'type': 'list', - 'schema': { - 'type': 'dict', - 'schema': { - 'id': {'coerce': int}, - 'position': {'coerce': int}, - }, - } + "roles": { + "type": "list", + "schema": { + "type": "dict", + "schema": {"id": {"coerce": int}, "position": {"coerce": int}}, + }, } } MEMBER_UPDATE = { - 'nick': { - 'type': 'nickname', 'required': False}, - 'roles': {'type': 'list', 'required': False, - 'schema': {'coerce': int}}, - 'mute': {'type': 'boolean', 'required': False}, - 'deaf': {'type': 'boolean', 'required': False}, - 'channel_id': {'type': 'snowflake', 'required': False}, + "nick": {"type": "nickname", "required": False}, + "roles": {"type": "list", "required": False, "schema": {"coerce": int}}, + "mute": {"type": "boolean", "required": False}, + "deaf": {"type": "boolean", "required": False}, + "channel_id": {"type": "snowflake", "required": False}, } # NOTE: things such as payload_json are parsed at the handler # for creating a message. MESSAGE_CREATE = { - 'content': {'type': 'string', 'minlength': 0, 'maxlength': 2000}, - 'nonce': {'type': 'snowflake', 'required': False}, - 'tts': {'type': 'boolean', 'required': False}, - - 'embed': { - 'type': 'dict', - 'schema': EMBED_OBJECT, - 'required': False, - 'nullable': True - } + "content": {"type": "string", "minlength": 0, "maxlength": 2000}, + "nonce": {"type": "snowflake", "required": False}, + "tts": {"type": "boolean", "required": False}, + "embed": { + "type": "dict", + "schema": EMBED_OBJECT, + "required": False, + "nullable": True, + }, } GW_ACTIVITY = { - 'name': {'type': 'string', 'required': True}, - 'type': {'type': 'activity_type', 'required': True}, - - 'url': {'type': 'string', 'required': False, 'nullable': True}, - - 'timestamps': { - 'type': 'dict', - 'required': False, - 'schema': { - 'start': {'type': 'number', 'required': False}, - 'end': {'type': 'number', 'required': False}, + "name": {"type": "string", "required": True}, + "type": {"type": "activity_type", "required": True}, + "url": {"type": "string", "required": False, "nullable": True}, + "timestamps": { + "type": "dict", + "required": False, + "schema": { + "start": {"type": "number", "required": False}, + "end": {"type": "number", "required": False}, }, }, - - 'application_id': {'type': 'snowflake', 'required': False, - 'nullable': False}, - 'details': {'type': 'string', 'required': False, 'nullable': True}, - 'state': {'type': 'string', 'required': False, 'nullable': True}, - - 'party': { - 'type': 'dict', - 'required': False, - 'schema': { - 'id': {'type': 'snowflake', 'required': False}, - 'size': {'type': 'list', 'required': False}, - } + "application_id": {"type": "snowflake", "required": False, "nullable": False}, + "details": {"type": "string", "required": False, "nullable": True}, + "state": {"type": "string", "required": False, "nullable": True}, + "party": { + "type": "dict", + "required": False, + "schema": { + "id": {"type": "snowflake", "required": False}, + "size": {"type": "list", "required": False}, + }, }, - - 'assets': { - 'type': 'dict', - 'required': False, - 'schema': { - 'large_image': {'type': 'snowflake', 'required': False}, - 'large_text': {'type': 'string', 'required': False}, - 'small_image': {'type': 'snowflake', 'required': False}, - 'small_text': {'type': 'string', 'required': False}, - } + "assets": { + "type": "dict", + "required": False, + "schema": { + "large_image": {"type": "snowflake", "required": False}, + "large_text": {"type": "string", "required": False}, + "small_image": {"type": "snowflake", "required": False}, + "small_text": {"type": "string", "required": False}, + }, }, - - 'secrets': { - 'type': 'dict', - 'required': False, - 'schema': { - 'join': {'type': 'string', 'required': False}, - 'spectate': {'type': 'string', 'required': False}, - 'match': {'type': 'string', 'required': False}, - } + "secrets": { + "type": "dict", + "required": False, + "schema": { + "join": {"type": "string", "required": False}, + "spectate": {"type": "string", "required": False}, + "match": {"type": "string", "required": False}, + }, }, - - 'instance': {'type': 'boolean', 'required': False}, - 'flags': {'type': 'number', 'required': False}, + "instance": {"type": "boolean", "required": False}, + "flags": {"type": "number", "required": False}, } GW_STATUS_UPDATE = { - 'status': {'type': 'status_external', 'required': False, - 'default': 'online'}, - 'activities': { - 'type': 'list', 'required': False, - 'schema': {'type': 'dict', 'schema': GW_ACTIVITY} + "status": {"type": "status_external", "required": False, "default": "online"}, + "activities": { + "type": "list", + "required": False, + "schema": {"type": "dict", "schema": GW_ACTIVITY}, }, - 'afk': {'type': 'boolean', 'required': False}, - - 'since': {'type': 'number', 'required': False, 'nullable': True}, - 'game': { - 'type': 'dict', - 'required': False, - 'nullable': True, - 'schema': GW_ACTIVITY, + "afk": {"type": "boolean", "required": False}, + "since": {"type": "number", "required": False, "nullable": True}, + "game": { + "type": "dict", + "required": False, + "nullable": True, + "schema": GW_ACTIVITY, }, } INVITE = { # max_age in seconds # 0 for infinite - 'max_age': { - 'type': 'number', - 'min': 0, - 'max': 86400, - + "max_age": { + "type": "number", + "min": 0, + "max": 86400, # a day - 'default': 86400 + "default": 86400, }, - # max invite uses - 'max_uses': { - 'type': 'number', - 'min': 0, - + "max_uses": { + "type": "number", + "min": 0, # idk - 'max': 1000, - + "max": 1000, # default infinite - 'default': 0 + "default": 0, }, - - 'temporary': {'type': 'boolean', 'required': False, 'default': False}, - 'unique': {'type': 'boolean', 'required': False, 'default': True}, - 'validate': {'type': 'string', 'required': False, 'nullable': True} # discord client sends invite code there + "temporary": {"type": "boolean", "required": False, "default": False}, + "unique": {"type": "boolean", "required": False, "default": True}, + "validate": { + "type": "string", + "required": False, + "nullable": True, + }, # discord client sends invite code there } USER_SETTINGS = { - 'afk_timeout': { - 'type': 'number', 'required': False, 'min': 0, 'max': 3000}, - - 'animate_emoji': {'type': 'boolean', 'required': False}, - 'convert_emoticons': {'type': 'boolean', 'required': False}, - 'default_guilds_restricted': {'type': 'boolean', 'required': False}, - 'detect_platform_accounts': {'type': 'boolean', 'required': False}, - 'developer_mode': {'type': 'boolean', 'required': False}, - 'disable_games_tab': {'type': 'boolean', 'required': False}, - 'enable_tts_command': {'type': 'boolean', 'required': False}, - - 'explicit_content_filter': {'type': 'explicit', 'required': False}, - - 'friend_source': { - 'type': 'dict', - 'required': False, - 'schema': { - 'all': {'type': 'boolean', 'required': False}, - 'mutual_guilds': {'type': 'boolean', 'required': False}, - 'mutual_friends': {'type': 'boolean', 'required': False}, - } + "afk_timeout": {"type": "number", "required": False, "min": 0, "max": 3000}, + "animate_emoji": {"type": "boolean", "required": False}, + "convert_emoticons": {"type": "boolean", "required": False}, + "default_guilds_restricted": {"type": "boolean", "required": False}, + "detect_platform_accounts": {"type": "boolean", "required": False}, + "developer_mode": {"type": "boolean", "required": False}, + "disable_games_tab": {"type": "boolean", "required": False}, + "enable_tts_command": {"type": "boolean", "required": False}, + "explicit_content_filter": {"type": "explicit", "required": False}, + "friend_source": { + "type": "dict", + "required": False, + "schema": { + "all": {"type": "boolean", "required": False}, + "mutual_guilds": {"type": "boolean", "required": False}, + "mutual_friends": {"type": "boolean", "required": False}, + }, }, - 'guild_positions': { - 'type': 'list', - 'required': False, - 'schema': {'type': 'snowflake'} + "guild_positions": { + "type": "list", + "required": False, + "schema": {"type": "snowflake"}, }, - 'restricted_guilds': { - 'type': 'list', - 'required': False, - 'schema': {'type': 'snowflake'} + "restricted_guilds": { + "type": "list", + "required": False, + "schema": {"type": "snowflake"}, }, - - 'gif_auto_play': {'type': 'boolean', 'required': False}, - 'inline_attachment_media': {'type': 'boolean', 'required': False}, - 'inline_embed_media': {'type': 'boolean', 'required': False}, - 'message_display_compact': {'type': 'boolean', 'required': False}, - 'render_embeds': {'type': 'boolean', 'required': False}, - 'render_reactions': {'type': 'boolean', 'required': False}, - 'show_current_game': {'type': 'boolean', 'required': False}, - - 'timezone_offset': {'type': 'number', 'required': False}, - - 'status': {'type': 'status_external', 'required': False}, - 'theme': {'type': 'theme', 'required': False} + "gif_auto_play": {"type": "boolean", "required": False}, + "inline_attachment_media": {"type": "boolean", "required": False}, + "inline_embed_media": {"type": "boolean", "required": False}, + "message_display_compact": {"type": "boolean", "required": False}, + "render_embeds": {"type": "boolean", "required": False}, + "render_reactions": {"type": "boolean", "required": False}, + "show_current_game": {"type": "boolean", "required": False}, + "timezone_offset": {"type": "number", "required": False}, + "status": {"type": "status_external", "required": False}, + "theme": {"type": "theme", "required": False}, } RELATIONSHIP = { - 'type': { - 'type': 'rel_type', - 'required': False, - 'default': RelationshipType.FRIEND.value + "type": { + "type": "rel_type", + "required": False, + "default": RelationshipType.FRIEND.value, } } -CREATE_DM = { - 'recipient_id': { - 'type': 'snowflake', - 'required': True - } -} +CREATE_DM = {"recipient_id": {"type": "snowflake", "required": True}} CREATE_GROUP_DM = { - 'recipients': { - 'type': 'list', - 'required': True, - 'schema': {'type': 'snowflake'} - }, + "recipients": {"type": "list", "required": True, "schema": {"type": "snowflake"}} } GROUP_DM_UPDATE = { - 'name': { - 'type': 'guild_name', - 'required': False - }, - 'icon': {'type': 'b64_icon', 'required': False, 'nullable': True}, + "name": {"type": "guild_name", "required": False}, + "icon": {"type": "b64_icon", "required": False, "nullable": True}, } SPECIFIC_FRIEND = { - 'username': {'type': 'username'}, - 'discriminator': {'type': 'discriminator'} + "username": {"type": "username"}, + "discriminator": {"type": "discriminator"}, } GUILD_SETTINGS_CHAN_OVERRIDE = { - 'type': 'dict', - 'schema': { - 'muted': { - 'type': 'boolean', 'required': False}, - 'message_notifications': { - 'type': 'msg_notifications', - 'required': False, - } - } + "type": "dict", + "schema": { + "muted": {"type": "boolean", "required": False}, + "message_notifications": {"type": "msg_notifications", "required": False}, + }, } GUILD_SETTINGS = { - 'channel_overrides': { - 'type': 'dict', - 'valueschema': GUILD_SETTINGS_CHAN_OVERRIDE, - 'keyschema': {'type': 'snowflake'}, - 'required': False, + "channel_overrides": { + "type": "dict", + "valueschema": GUILD_SETTINGS_CHAN_OVERRIDE, + "keyschema": {"type": "snowflake"}, + "required": False, }, - 'suppress_everyone': { - 'type': 'boolean', 'required': False}, - 'muted': { - 'type': 'boolean', 'required': False}, - 'mobile_push': { - 'type': 'boolean', 'required': False}, - 'message_notifications': { - 'type': 'msg_notifications', - 'required': False, - } + "suppress_everyone": {"type": "boolean", "required": False}, + "muted": {"type": "boolean", "required": False}, + "mobile_push": {"type": "boolean", "required": False}, + "message_notifications": {"type": "msg_notifications", "required": False}, } GUILD_PRUNE = { - 'days': {'type': 'number', 'coerce': int, 'min': 1, 'max': 30, 'default': 7}, - 'compute_prune_count': {'type': 'string', 'default': 'true'} + "days": {"type": "number", "coerce": int, "min": 1, "max": 30, "default": 7}, + "compute_prune_count": {"type": "string", "default": "true"}, } NEW_EMOJI = { - 'name': { - 'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True}, - 'image': {'type': 'b64_icon', 'required': True}, - 'roles': {'type': 'list', 'schema': {'coerce': int}} + "name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True}, + "image": {"type": "b64_icon", "required": True}, + "roles": {"type": "list", "schema": {"coerce": int}}, } PATCH_EMOJI = { - 'name': { - 'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True}, - 'roles': {'type': 'list', 'schema': {'coerce': int}} + "name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True}, + "roles": {"type": "list", "schema": {"coerce": int}}, } SEARCH_CHANNEL = { - 'content': {'type': 'string', 'minlength': 1, 'required': True}, - 'include_nsfw': {'coerce': bool, 'default': False}, - 'offset': {'coerce': int, 'default': 0} + "content": {"type": "string", "minlength": 1, "required": True}, + "include_nsfw": {"coerce": bool, "default": False}, + "offset": {"coerce": int, "default": 0}, } GET_MENTIONS = { - 'limit': {'coerce': int, 'default': 25}, - 'roles': {'coerce': bool, 'default': True}, - 'everyone': {'coerce': bool, 'default': True}, - 'guild_id': {'coerce': int, 'required': False} + "limit": {"coerce": int, "default": 25}, + "roles": {"coerce": bool, "default": True}, + "everyone": {"coerce": bool, "default": True}, + "guild_id": {"coerce": int, "required": False}, } VANITY_URL_PATCH = { # TODO: put proper values in maybe an invite data type - 'code': {'type': 'string', 'minlength': 5, 'maxlength': 30} + "code": {"type": "string", "minlength": 5, "maxlength": 30} } WEBHOOK_CREATE = { - 'name': { - 'type': 'string', 'minlength': 2, 'maxlength': 32, - 'required': True - }, - 'avatar': {'type': 'b64_icon', 'required': False, 'nullable': False} + "name": {"type": "string", "minlength": 2, "maxlength": 32, "required": True}, + "avatar": {"type": "b64_icon", "required": False, "nullable": False}, } WEBHOOK_UPDATE = { - 'name': { - 'type': 'string', 'minlength': 2, 'maxlength': 32, - 'required': False - }, - + "name": {"type": "string", "minlength": 2, "maxlength": 32, "required": False}, # TODO: check if its b64_icon or string since the client # could pass an icon hash instead. - 'avatar': {'type': 'b64_icon', 'required': False, 'nullable': False}, - 'channel_id': {'coerce': int, 'required': False, 'nullable': False} + "avatar": {"type": "b64_icon", "required": False, "nullable": False}, + "channel_id": {"coerce": int, "required": False, "nullable": False}, } WEBHOOK_MESSAGE_CREATE = { - 'content': { - 'type': 'string', - 'minlength': 0, 'maxlength': 2000, 'required': False + "content": {"type": "string", "minlength": 0, "maxlength": 2000, "required": False}, + "tts": {"type": "boolean", "required": False}, + "username": {"type": "string", "minlength": 2, "maxlength": 32, "required": False}, + "avatar_url": {"coerce": EmbedURL, "required": False}, + "embeds": { + "type": "list", + "required": False, + "schema": {"type": "dict", "schema": EMBED_OBJECT}, }, - 'tts': {'type': 'boolean', 'required': False}, - - 'username': { - 'type': 'string', - 'minlength': 2, 'maxlength': 32, 'required': False - }, - - 'avatar_url': { - 'coerce': EmbedURL, 'required': False - }, - - 'embeds': { - 'type': 'list', - 'required': False, - 'schema': {'type': 'dict', 'schema': EMBED_OBJECT} - } } BULK_DELETE = { - 'messages': { - 'type': 'list', 'required': True, - 'minlength': 2, 'maxlength': 100, - 'schema': {'coerce': int} + "messages": { + "type": "list", + "required": True, + "minlength": 2, + "maxlength": 100, + "schema": {"coerce": int}, } } diff --git a/litecord/snowflake.py b/litecord/snowflake.py index fd85995..24a83c8 100644 --- a/litecord/snowflake.py +++ b/litecord/snowflake.py @@ -61,19 +61,19 @@ def _snowflake(timestamp: int) -> Snowflake: # bits 0-12 encode _generated_ids (size 12) # modulo'd to prevent overflows - genid_b = '{0:012b}'.format(_generated_ids % 4096) + genid_b = "{0:012b}".format(_generated_ids % 4096) # bits 12-17 encode PROCESS_ID (size 5) - procid_b = '{0:05b}'.format(PROCESS_ID) + procid_b = "{0:05b}".format(PROCESS_ID) # bits 17-22 encode WORKER_ID (size 5) - workid_b = '{0:05b}'.format(WORKER_ID) + workid_b = "{0:05b}".format(WORKER_ID) # bits 22-64 encode (timestamp - EPOCH) (size 42) epochized = timestamp - EPOCH - epoch_b = '{0:042b}'.format(epochized) + epoch_b = "{0:042b}".format(epochized) - snowflake_b = f'{epoch_b}{workid_b}{procid_b}{genid_b}' + snowflake_b = f"{epoch_b}{workid_b}{procid_b}{genid_b}" _generated_ids += 1 return int(snowflake_b, 2) @@ -87,7 +87,7 @@ def snowflake_time(snowflake: Snowflake) -> float: # the total size for a snowflake is 64 bits, # considering it is a string, position 0 to 42 will give us # the `epochized` variable - snowflake_b = '{0:064b}'.format(snowflake) + snowflake_b = "{0:064b}".format(snowflake) epochized_b = snowflake_b[:42] epochized = int(epochized_b, 2) diff --git a/litecord/storage.py b/litecord/storage.py index 676f362..dae0961 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -23,9 +23,7 @@ from logbook import Logger from litecord.enums import ChannelType from litecord.schemas import USER_MENTION, ROLE_MENTION -from litecord.blueprints.channel.reactions import ( - EmojiType, emoji_sql, partial_emoji -) +from litecord.blueprints.channel.reactions import EmojiType, emoji_sql, partial_emoji from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE @@ -63,13 +61,12 @@ def bool_(val): def _filter_recipients(recipients: List[Dict[str, Any]], user_id: str): """Filter recipients in a list of recipients, removing the one that is reundant (ourselves).""" - return list(filter( - lambda recipient: recipient['id'] != user_id, - recipients)) + return list(filter(lambda recipient: recipient["id"] != user_id, recipients)) class Storage: """Class for common SQL statements.""" + def __init__(self, app): self.app = app self.db = app.db @@ -100,38 +97,51 @@ class Storage: """Get a single user payload.""" user_id = int(user_id) - fields = ['id::text', 'username', 'discriminator', - 'avatar', 'flags', 'bot', 'premium_since'] + fields = [ + "id::text", + "username", + "discriminator", + "avatar", + "flags", + "bot", + "premium_since", + ] if secure: - fields.extend(['email', 'verified', 'mfa_enabled']) + fields.extend(["email", "verified", "mfa_enabled"]) - user_row = await self.db.fetchrow(f""" + user_row = await self.db.fetchrow( + f""" SELECT {','.join(fields)} FROM users WHERE users.id = $1 - """, user_id) + """, + user_id, + ) if not user_row: return None duser = dict(user_row) - duser['premium'] = duser['premium_since'] is not None - duser.pop('premium_since') + duser["premium"] = duser["premium_since"] is not None + duser.pop("premium_since") if secure: - duser['mobile'] = False - duser['phone'] = None + duser["mobile"] = False + duser["phone"] = None - plan_id = await self.db.fetchval(""" + plan_id = await self.db.fetchval( + """ SELECT payment_gateway_plan_id FROM user_subscriptions WHERE status = 1 AND user_id = $1 - """, user_id) + """, + user_id, + ) - duser['premium_type'] = PLAN_ID_TO_TYPE.get(plan_id) + duser["premium_type"] = PLAN_ID_TO_TYPE.get(plan_id) return duser @@ -139,30 +149,41 @@ class Storage: """Search a user""" if len(discriminator) < 4: # how do we do this in f-strings again..? - discriminator = '%04d' % int(discriminator) + discriminator = "%04d" % int(discriminator) - return await self.db.fetchval(""" + return await self.db.fetchval( + """ SELECT id FROM users WHERE username = $1 AND discriminator = $2 - """, username, discriminator) + """, + username, + discriminator, + ) async def guild_features(self, guild_id: int) -> Optional[List[str]]: """Get a list of guild features for the given guild.""" - return await self.db.fetchval(""" + return await self.db.fetchval( + """ SELECT features FROM guilds WHERE id = $1 - """, guild_id) + """, + guild_id, + ) async def vanity_invite(self, guild_id: int) -> Optional[str]: """Get the vanity invite for a guild.""" - return await self.db.fetchval(""" + return await self.db.fetchval( + """ SELECT code FROM vanity_invites WHERE guild_id = $1 - """, guild_id) + """, + guild_id, + ) async def get_guild(self, guild_id: int, user_id=None) -> Optional[Dict]: """Get gulid payload.""" - row = await self.db.fetchrow(""" + row = await self.db.fetchrow( + """ SELECT id::text, owner_id::text, name, icon, splash, region, afk_channel_id::text, afk_timeout, verification_level, default_message_notifications, @@ -173,7 +194,9 @@ class Storage: banner, description FROM guilds WHERE guilds.id = $1 - """, guild_id) + """, + guild_id, + ) if not row: return None @@ -182,50 +205,49 @@ class Storage: # a guild's unavailable state is kept in memory, and we remove every # other guild related field when its unavailable. - drow['unavailable'] = self.app.guild_store.get( - guild_id, 'unavailable', False) + drow["unavailable"] = self.app.guild_store.get(guild_id, "unavailable", False) - if drow['unavailable']: - drow = { - 'id': drow['id'], - 'unavailable': True - } + if drow["unavailable"]: + drow = {"id": drow["id"], "unavailable": True} # guild.owner is dependant of the user doing the get_guild call. if user_id: - drow['owner'] = drow['owner_id'] == str(user_id) + drow["owner"] = drow["owner_id"] == str(user_id) - drow['vanity_url_code'] = await self.vanity_invite(guild_id) + drow["vanity_url_code"] = await self.vanity_invite(guild_id) # hardcoding these since: # - we aren't discord # - the limit for guilds is unknown and heavily dependant on the # hardware - drow['max_presences'] = 1000 - drow['max_members'] = 1000 + drow["max_presences"] = 1000 + drow["max_members"] = 1000 # used by guilds with DISCOVERABLE feature - drow['preffered_locale'] = 'en-US' + drow["preffered_locale"] = "en-US" return drow async def _member_basic(self, guild_id: int, member_id: int): - row = await self.db.fetchrow(""" + row = await self.db.fetchrow( + """ SELECT user_id, nickname, joined_at, deafened AS deaf, muted AS mute FROM members WHERE guild_id = $1 and user_id = $2 - """, guild_id, member_id) + """, + guild_id, + member_id, + ) if row is None: return None row = dict(row) - row['joined_at'] = timestamp_(row['joined_at']) + row["joined_at"] = timestamp_(row["joined_at"]) return row - async def _member_basic_with_roles(self, guild_id: int, - member_id: int): + async def _member_basic_with_roles(self, guild_id: int, member_id: int): basic = await self._member_basic(guild_id, member_id) if basic is None: @@ -234,20 +256,21 @@ class Storage: basic = dict(basic) roles = await self.get_member_role_ids(guild_id, member_id) - return {**basic, **{ - 'roles': roles - }} + return {**basic, **{"roles": roles}} - async def get_member_role_ids(self, guild_id: int, - member_id: int) -> List[str]: + async def get_member_role_ids(self, guild_id: int, member_id: int) -> List[str]: """Get a list of role IDs that are on a member.""" - roles = await self.db.fetch(""" + roles = await self.db.fetch( + """ SELECT role_id::text FROM member_roles WHERE guild_id = $1 AND user_id = $2 - """, guild_id, member_id) + """, + guild_id, + member_id, + ) - roles = [r['role_id'] for r in roles] + roles = [r["role_id"] for r in roles] try: roles.remove(str(guild_id)) @@ -255,10 +278,15 @@ class Storage: # if the @everyone role isn't in, we add it # to member_roles automatically (it won't # be shown on the API, though). - await self.db.execute(""" + await self.db.execute( + """ INSERT INTO member_roles (user_id, guild_id, role_id) VALUES ($1, $2, $3) - """, member_id, guild_id, guild_id) + """, + member_id, + guild_id, + guild_id, + ) return list(map(str, roles)) @@ -266,20 +294,20 @@ class Storage: roles = await self.get_member_role_ids(guild_id, member_id) return { - 'user': await self.get_user(member_id), - 'nick': row['nickname'], - + "user": await self.get_user(member_id), + "nick": row["nickname"], # we don't send the @everyone role's id to # the user since it is known that everyone has # that role. - 'roles': roles, - 'joined_at': row['joined_at'], - 'deaf': row['deaf'], - 'mute': row['mute'], + "roles": roles, + "joined_at": row["joined_at"], + "deaf": row["deaf"], + "mute": row["mute"], } - async def get_member_data_one(self, guild_id: int, - member_id: int) -> Optional[Dict[str, Any]]: + async def get_member_data_one( + self, guild_id: int, member_id: int + ) -> Optional[Dict[str, Any]]: """Get data about one member in a guild.""" basic = await self._member_basic(guild_id, member_id) @@ -288,8 +316,9 @@ class Storage: return await self._member_dict(basic, guild_id, member_id) - async def get_member_multi(self, guild_id: int, - user_ids: List[int]) -> List[Dict[str, Any]]: + async def get_member_multi( + self, guild_id: int, user_ids: List[int] + ) -> List[Dict[str, Any]]: """Get member information about multiple users in a guild.""" members = [] @@ -305,45 +334,55 @@ class Storage: async def get_member_data(self, guild_id: int) -> List[Dict[str, Any]]: """Get member information on a guild.""" - members_basic = await self.db.fetch(""" + members_basic = await self.db.fetch( + """ SELECT user_id, nickname, joined_at, deafened AS deaf, muted AS mute FROM members WHERE guild_id = $1 - """, guild_id) + """, + guild_id, + ) members = [] for row in members_basic: drow = dict(row) - drow['joined_at'] = timestamp_(drow['joined_at']) - member = await self._member_dict(drow, guild_id, drow['user_id']) + drow["joined_at"] = timestamp_(drow["joined_at"]) + member = await self._member_dict(drow, guild_id, drow["user_id"]) members.append(member) return members async def query_members(self, guild_id: int, query: str, limit: int): """Find members with usernames matching the given query.""" - mids = await self.db.fetch(f""" + mids = await self.db.fetch( + f""" SELECT user_id FROM members JOIN users ON members.user_id = users.id WHERE members.guild_id = $1 AND users.username LIKE '%'||$2 LIMIT {limit} - """, guild_id, query) + """, + guild_id, + query, + ) - mids = [r['user_id'] for r in mids] + mids = [r["user_id"] for r in mids] members = await self.get_member_multi(guild_id, mids) return members async def chan_last_message(self, channel_id: int): """Get the last message ID in a channel.""" - return await self.db.fetchval(""" + return await self.db.fetchval( + """ SELECT MAX(id) FROM messages WHERE channel_id = $1 - """, channel_id) + """, + channel_id, + ) async def chan_last_message_str(self, channel_id: int) -> str: """Get the last message ID but in a string. @@ -356,70 +395,79 @@ class Storage: async def _channels_extra(self, row) -> Dict: """Fill in more information about a channel.""" - channel_type = row['type'] + channel_type = row["type"] chan_type = ChannelType(channel_type) if chan_type == ChannelType.GUILD_TEXT: - ext_row = await self.db.fetchrow(""" + ext_row = await self.db.fetchrow( + """ SELECT topic, rate_limit_per_user FROM guild_text_channels WHERE id = $1 - """, row['id']) + """, + row["id"], + ) drow = dict(ext_row) - last_msg = await self.chan_last_message_str(row['id']) + last_msg = await self.chan_last_message_str(row["id"]) - drow['last_message_id'] = last_msg + drow["last_message_id"] = last_msg return {**row, **drow} if chan_type == ChannelType.GUILD_VOICE: - vrow = await self.db.fetchrow(""" + vrow = await self.db.fetchrow( + """ SELECT bitrate, user_limit FROM guild_voice_channels WHERE id = $1 - """, row['id']) + """, + row["id"], + ) return {**row, **dict(vrow)} - log.warning('unknown channel type: {}', chan_type) + log.warning("unknown channel type: {}", chan_type) return row async def get_chan_type(self, channel_id: int) -> int: """Get the channel type integer, given channel ID.""" - return await self.db.fetchval(""" + return await self.db.fetchval( + """ SELECT channel_type FROM channels WHERE channels.id = $1 - """, channel_id) + """, + channel_id, + ) async def chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]: - overwrite_rows = await self.db.fetch(""" + overwrite_rows = await self.db.fetch( + """ SELECT target_type, target_role, target_user, allow, deny FROM channel_overwrites WHERE channel_id = $1 - """, channel_id) + """, + channel_id, + ) def _overwrite_convert(row): drow = dict(row) - target_type = drow['target_type'] - drow['type'] = 'member' if target_type == 0 else 'role' + target_type = drow["target_type"] + drow["type"] = "member" if target_type == 0 else "role" # if type is 0, the overwrite is for a member # if type is 1, the overwrite is for a role - drow['id'] = { - 0: drow['target_user'], - 1: drow['target_role'], - }[target_type] + drow["id"] = {0: drow["target_user"], 1: drow["target_role"]}[target_type] - drow['id'] = str(drow['id']) + drow["id"] = str(drow["id"]) - drow.pop('target_type') - drow.pop('target_user') - drow.pop('target_role') + drow.pop("target_type") + drow.pop("target_user") + drow.pop("target_role") return drow @@ -428,19 +476,23 @@ class Storage: async def gdm_recipient_ids(self, channel_id: int) -> List[int]: """Get the list of user IDs that are recipients of the given Group DM.""" - user_ids = await self.db.fetch(""" + user_ids = await self.db.fetch( + """ SELECT member_id FROM group_dm_members JOIN users ON member_id = users.id WHERE group_dm_members.id = $1 ORDER BY username DESC - """, channel_id) + """, + channel_id, + ) - return [r['member_id'] for r in user_ids] + return [r["member_id"] for r in user_ids] - async def _gdm_recipients(self, channel_id: int, - reference_id: int = None) -> List[Dict]: + async def _gdm_recipients( + self, channel_id: int, reference_id: int = None + ) -> List[Dict]: """Get the list of users that are recipients of the given Group DM.""" recipients = await self.gdm_recipient_ids(channel_id) @@ -459,70 +511,75 @@ class Storage: return res - async def get_channel(self, channel_id: int, - **kwargs) -> Optional[Dict[str, Any]]: + async def get_channel(self, channel_id: int, **kwargs) -> Optional[Dict[str, Any]]: """Fetch a single channel's information.""" chan_type = await self.get_chan_type(channel_id) ctype = ChannelType(chan_type) - if ctype in (ChannelType.GUILD_TEXT, - ChannelType.GUILD_VOICE, - ChannelType.GUILD_CATEGORY): - base = await self.db.fetchrow(""" + if ctype in ( + ChannelType.GUILD_TEXT, + ChannelType.GUILD_VOICE, + ChannelType.GUILD_CATEGORY, + ): + base = await self.db.fetchrow( + """ SELECT id, guild_id::text, parent_id, name, position, nsfw FROM guild_channels WHERE guild_channels.id = $1 - """, channel_id) + """, + channel_id, + ) dbase = dict(base) - dbase['type'] = chan_type + dbase["type"] = chan_type res = await self._channels_extra(dbase) - res['permission_overwrites'] = await self.chan_overwrites( - channel_id) + res["permission_overwrites"] = await self.chan_overwrites(channel_id) - res['id'] = str(res['id']) + res["id"] = str(res["id"]) return res elif ctype == ChannelType.DM: - dm_row = await self.db.fetchrow(""" + dm_row = await self.db.fetchrow( + """ SELECT id, party1_id, party2_id FROM dm_channels WHERE id = $1 - """, channel_id) - - drow = dict(dm_row) - drow['type'] = chan_type - - drow['last_message_id'] = await self.chan_last_message_str( - channel_id + """, + channel_id, ) + drow = dict(dm_row) + drow["type"] = chan_type + + drow["last_message_id"] = await self.chan_last_message_str(channel_id) + # dms have just two recipients. - drow['recipients'] = [ - await self.get_user(drow['party1_id']), - await self.get_user(drow['party2_id']) + drow["recipients"] = [ + await self.get_user(drow["party1_id"]), + await self.get_user(drow["party2_id"]), ] - drow.pop('party1_id') - drow.pop('party2_id') + drow.pop("party1_id") + drow.pop("party2_id") - drow['id'] = str(drow['id']) + drow["id"] = str(drow["id"]) return drow elif ctype == ChannelType.GROUP_DM: - gdm_row = await self.db.fetchrow(""" + gdm_row = await self.db.fetchrow( + """ SELECT id::text, owner_id::text, name, icon FROM group_dm_channels WHERE id = $1 - """, channel_id) + """, + channel_id, + ) drow = dict(gdm_row) - drow['type'] = chan_type - drow['recipients'] = await self._gdm_recipients( - channel_id, kwargs.get('user_id') - ) - drow['last_message_id'] = await self.chan_last_message_str( - channel_id + drow["type"] = chan_type + drow["recipients"] = await self._gdm_recipients( + channel_id, kwargs.get("user_id") ) + drow["last_message_id"] = await self.chan_last_message_str(channel_id) return drow @@ -530,61 +587,73 @@ class Storage: async def get_channel_ids(self, guild_id: int) -> List[int]: """Get all channel IDs in a guild.""" - rows = await self.db.fetch(""" + rows = await self.db.fetch( + """ SELECT id FROM guild_channels WHERE guild_id = $1 - """, guild_id) + """, + guild_id, + ) - return [r['id'] for r in rows] + return [r["id"] for r in rows] async def get_channel_data(self, guild_id) -> List[Dict]: """Get channel list information on a guild""" - channel_basics = await self.db.fetch(""" + channel_basics = await self.db.fetch( + """ SELECT id, guild_id::text, parent_id::text, name, position, nsfw FROM guild_channels WHERE guild_id = $1 - """, guild_id) + """, + guild_id, + ) channels = [] for row in channel_basics: - ctype = await self.db.fetchval(""" + ctype = await self.db.fetchval( + """ SELECT channel_type FROM channels WHERE id = $1 - """, row['id']) + """, + row["id"], + ) drow = dict(row) - drow['type'] = ctype + drow["type"] = ctype res = await self._channels_extra(drow) - res['permission_overwrites'] = await self.chan_overwrites( - row['id']) + res["permission_overwrites"] = await self.chan_overwrites(row["id"]) # Making sure. - res['id'] = str(res['id']) + res["id"] = str(res["id"]) channels.append(res) return channels - async def get_role(self, role_id: int, - guild_id: int = None) -> Optional[Dict[str, Any]]: + async def get_role( + self, role_id: int, guild_id: int = None + ) -> Optional[Dict[str, Any]]: """get a single role's information.""" - guild_field = 'AND guild_id = $2' if guild_id else '' + guild_field = "AND guild_id = $2" if guild_id else "" args = [role_id] if guild_id: args.append(guild_id) - row = await self.db.fetchrow(f""" + row = await self.db.fetchrow( + f""" SELECT id::text, name, color, hoist, position, permissions, managed, mentionable FROM roles WHERE id = $1 {guild_field} LIMIT 1 - """, *args) + """, + *args, + ) if not row: return None @@ -593,18 +662,22 @@ class Storage: async def get_role_data(self, guild_id: int) -> List[Dict[str, Any]]: """Get role list information on a guild.""" - roledata = await self.db.fetch(""" + roledata = await self.db.fetch( + """ SELECT id::text, name, color, hoist, position, permissions, managed, mentionable FROM roles WHERE guild_id = $1 ORDER BY position ASC - """, guild_id) + """, + guild_id, + ) return list(map(dict, roledata)) - async def guild_voice_states(self, guild_id: int, - user_id=None) -> List[Dict[str, Any]]: + async def guild_voice_states( + self, guild_id: int, user_id=None + ) -> List[Dict[str, Any]]: """Get a list of voice states for the given guild.""" channel_ids = await self.get_channel_ids(guild_id) @@ -618,57 +691,63 @@ class Storage: # discord does NOT insert guild_id to voice states on the # guild voice state list. for state in jsonified: - state.pop('guild_id') + state.pop("guild_id") res.extend(jsonified) return res - async def get_guild_extra(self, guild_id: int, - user_id=None, large=None) -> Dict: + async def get_guild_extra(self, guild_id: int, user_id=None, large=None) -> Dict: """Get extra information about a guild.""" res = {} - member_count = await self.db.fetchval(""" + member_count = await self.db.fetchval( + """ SELECT COUNT(*) FROM members WHERE guild_id = $1 - """, guild_id) + """, + guild_id, + ) if large: - res['large'] = member_count > large + res["large"] = member_count > large if user_id: - joined_at = await self.db.fetchval(""" + joined_at = await self.db.fetchval( + """ SELECT joined_at FROM members WHERE guild_id = $1 AND user_id = $2 - """, guild_id, user_id) + """, + guild_id, + user_id, + ) - res['joined_at'] = timestamp_(joined_at) + res["joined_at"] = timestamp_(joined_at) members = await self.get_member_data(guild_id) channels = await self.get_channel_data(guild_id) roles = await self.get_role_data(guild_id) - mids = [int(m['user']['id']) for m in members] + mids = [int(m["user"]["id"]) for m in members] - return {**res, **{ - 'member_count': member_count, - 'members': members, - 'channels': channels, - 'roles': roles, + return { + **res, + **{ + "member_count": member_count, + "members": members, + "channels": channels, + "roles": roles, + "presences": await self.presence.guild_presences(mids, guild_id), + "emojis": await self.get_guild_emojis(guild_id), + "voice_states": await self.guild_voice_states(guild_id), + }, + } - 'presences': await self.presence.guild_presences( - mids, guild_id - ), - - 'emojis': await self.get_guild_emojis(guild_id), - 'voice_states': await self.guild_voice_states(guild_id), - }} - - async def get_guild_full(self, guild_id: int, user_id: Optional[int] = None, - large_count: int = 250) -> Optional[Dict]: + async def get_guild_full( + self, guild_id: int, user_id: Optional[int] = None, large_count: int = 250 + ) -> Optional[Dict]: """Get full information on a guild. This is a very expensive operation. @@ -678,7 +757,7 @@ class Storage: if guild is None: return None - if guild['unavailable']: + if guild["unavailable"]: return guild extra = await self.get_guild_extra(guild_id, user_id, large_count) @@ -687,21 +766,27 @@ class Storage: async def guild_exists(self, guild_id: int) -> bool: """Return if a given guild ID exists.""" - owner_id = await self.db.fetch(""" + owner_id = await self.db.fetch( + """ SELECT owner_id FROM guilds WHERE id = $1 - """, guild_id) + """, + guild_id, + ) return owner_id is not None async def get_member_ids(self, guild_id: int) -> List[int]: """Get member IDs inside a guild""" - rows = await self.db.fetch(""" + rows = await self.db.fetch( + """ SELECT user_id FROM members WHERE guild_id = $1 - """, guild_id) + """, + guild_id, + ) return [r[0] for r in rows] @@ -726,12 +811,15 @@ class Storage: async def get_reactions(self, message_id: int, user_id=None) -> List: """Get all reactions in a message.""" - reactions = await self.db.fetch(""" + reactions = await self.db.fetch( + """ SELECT user_id, emoji_type, emoji_id, emoji_text FROM message_reactions WHERE message_id = $1 ORDER BY react_ts - """, message_id) + """, + message_id, + ) # ordered list of emoji emoji = [] @@ -745,8 +833,8 @@ class Storage: # we can't use a set() because that # doesn't guarantee any order. for row in reactions: - etype = EmojiType(row['emoji_type']) - eid, etext = row['emoji_id'], row['emoji_text'] + etype = EmojiType(row["emoji_type"]) + eid, etext = row["emoji_id"], row["emoji_text"] # get the main key to use, given # the emoji information @@ -760,27 +848,27 @@ class Storage: emoji.append(main_emoji) react_stats[main_emoji] = { - 'count': 0, - 'me': False, - 'emoji': partial_emoji(etype, eid, etext) + "count": 0, + "me": False, + "emoji": partial_emoji(etype, eid, etext), } # then the 2nd pass, where we insert # the info for each reaction in the react_stats # dictionary for row in reactions: - etype = EmojiType(row['emoji_type']) - eid, etext = row['emoji_id'], row['emoji_text'] + etype = EmojiType(row["emoji_type"]) + eid, etext = row["emoji_id"], row["emoji_text"] # same thing as the last loop, # extracting main key _, main_emoji = emoji_sql(etype, eid, etext) stats = react_stats[main_emoji] - stats['count'] += 1 + stats["count"] += 1 - if row['user_id'] == user_id: - stats['me'] = True + if row["user_id"] == user_id: + stats["me"] = True # after processing reaction counts, # we get them in the same order @@ -789,43 +877,51 @@ class Storage: async def get_attachments(self, message_id: int) -> List[Dict[str, Any]]: """Get a list of attachment objects tied to the message.""" - attachment_ids = await self.db.fetch(""" + attachment_ids = await self.db.fetch( + """ SELECT id FROM attachments WHERE message_id = $1 - """, message_id) + """, + message_id, + ) - attachment_ids = [r['id'] for r in attachment_ids] + attachment_ids = [r["id"] for r in attachment_ids] res = [] for attachment_id in attachment_ids: - row = await self.db.fetchrow(""" + row = await self.db.fetchrow( + """ SELECT id::text, message_id, channel_id, filename, filesize, image, height, width FROM attachments WHERE id = $1 - """, attachment_id) + """, + attachment_id, + ) drow = dict(row) - drow.pop('message_id') - drow.pop('channel_id') + drow.pop("message_id") + drow.pop("channel_id") - drow['size'] = drow['filesize'] - drow.pop('size') + drow["size"] = drow["filesize"] + drow.pop("size") # construct attachment url - proto = 'https' if self.app.config['IS_SSL'] else 'http' - main_url = self.app.config['MAIN_URL'] + proto = "https" if self.app.config["IS_SSL"] else "http" + main_url = self.app.config["MAIN_URL"] - drow['url'] = (f'{proto}://{main_url}/attachments/' - f'{row["channel_id"]}/{row["message_id"]}/' - f'{row["filename"]}') + drow["url"] = ( + f"{proto}://{main_url}/attachments/" + f'{row["channel_id"]}/{row["message_id"]}/' + f'{row["filename"]}' + ) # NOTE: since the url comes from the instance itself # i think proxy_url=url is valid. - drow['proxy_url'] = drow['url'] + drow["proxy_url"] = drow["url"] res.append(drow) @@ -834,7 +930,7 @@ class Storage: async def _inject_author(self, res: dict): """Inject a pseudo-user object when the message is made by a webhook.""" - author_id = res['author_id'] + author_id = res["author_id"] # if author_id is None, we fetch webhook info # from the message_webhook_info table. @@ -843,63 +939,69 @@ class Storage: # is copied from the webhook table, or inserted by the webhook # itself. this causes a complete disconnect from the messages # table into the webhooks table. - wb_info = await self.db.fetchrow(""" + wb_info = await self.db.fetchrow( + """ SELECT webhook_id, name, avatar FROM message_webhook_info WHERE message_id = $1 - """, int(res['id'])) + """, + int(res["id"]), + ) if not wb_info: - log.warning('webhook info not found for msg {}', - res['id']) + log.warning("webhook info not found for msg {}", res["id"]) wb_info = wb_info or { - 'id': res['id'], - 'bot': True, - 'avatar': None, - 'username': '', - 'discriminator': '0000', + "id": res["id"], + "bot": True, + "avatar": None, + "username": "", + "discriminator": "0000", } - res['author'] = { - 'id': str(wb_info['webhook_id']), - 'bot': True, - 'username': wb_info['name'], - 'avatar': wb_info['avatar'], - 'discriminator': '0000', + res["author"] = { + "id": str(wb_info["webhook_id"]), + "bot": True, + "username": wb_info["name"], + "avatar": wb_info["avatar"], + "discriminator": "0000", } else: - res['author'] = await self.get_user(res['author_id']) + res["author"] = await self.get_user(res["author_id"]) - res.pop('author_id') + res.pop("author_id") - async def get_message(self, message_id: int, - user_id: Optional[int] = None) -> Optional[Dict]: + async def get_message( + self, message_id: int, user_id: Optional[int] = None + ) -> Optional[Dict]: """Get a single message's payload.""" - row = await self.fetchrow_with_json(""" + row = await self.fetchrow_with_json( + """ SELECT id::text, channel_id::text, author_id, content, created_at AS timestamp, edited_at AS edited_timestamp, tts, mention_everyone, nonce, message_type, embeds, flags FROM messages WHERE id = $1 - """, message_id) + """, + message_id, + ) if not row: return None res = dict(row) - res['nonce'] = str(res['nonce']) - res['timestamp'] = timestamp_(res['timestamp']) - res['edited_timestamp'] = timestamp_(res['edited_timestamp']) + res["nonce"] = str(res["nonce"]) + res["timestamp"] = timestamp_(res["timestamp"]) + res["edited_timestamp"] = timestamp_(res["edited_timestamp"]) - res['type'] = res['message_type'] - res.pop('message_type') + res["type"] = res["message_type"] + res.pop("message_type") - if res['content'] is None: - res['content'] = "" + if res["content"] is None: + res["content"] = "" - channel_id = int(row['channel_id']) - content = row['content'] + channel_id = int(row["channel_id"]) + content = row["content"] guild_id = await self.guild_from_channel(channel_id) # calculate user mentions and role mentions by regex @@ -911,10 +1013,11 @@ class Storage: # TODO: maybe make this partial? member = await self.get_member_data_one(guild_id, user_id) - return {**user, **{'member': member}} if member else user + return {**user, **{"member": member}} if member else user - res['mentions'] = await self._msg_regex(USER_MENTION, _get_member, - row['content']) + res["mentions"] = await self._msg_regex( + USER_MENTION, _get_member, row["content"] + ) # _dummy just returns the string of the id, since we don't # actually use the role objects in mention_roles, just their ids. @@ -929,60 +1032,67 @@ class Storage: if not role: return - if not role['mentionable']: + if not role["mentionable"]: return return str(role_id) - res['mention_roles'] = await self._msg_regex( - ROLE_MENTION, _get_role_mention, content) + res["mention_roles"] = await self._msg_regex( + ROLE_MENTION, _get_role_mention, content + ) - res['reactions'] = await self.get_reactions(message_id, user_id) + res["reactions"] = await self.get_reactions(message_id, user_id) await self._inject_author(res) - res['attachments'] = await self.get_attachments(message_id) + res["attachments"] = await self.get_attachments(message_id) # if message is not from a dm, guild_id is None and so, _member_basic # will just return None # user id can be none, though, and we need to watch out for that if user_id is not None: - res['member'] = await self._member_basic_with_roles( - guild_id, user_id) + res["member"] = await self._member_basic_with_roles(guild_id, user_id) - if res.get('member') is None: + if res.get("member") is None: try: - res.pop('member') + res.pop("member") except KeyError: pass - pin_id = await self.db.fetchval(""" + pin_id = await self.db.fetchval( + """ SELECT message_id FROM channel_pins WHERE channel_id = $1 AND message_id = $2 - """, channel_id, message_id) + """, + channel_id, + message_id, + ) - res['pinned'] = pin_id is not None + res["pinned"] = pin_id is not None # this is specifically for lazy guilds: # only insert when the channel # is actually from a guild. if guild_id: - res['guild_id'] = str(guild_id) + res["guild_id"] = str(guild_id) - if res['flags'] == 0: - res.pop('flags') + if res["flags"] == 0: + res.pop("flags") return res async def get_invite(self, invite_code: str) -> Optional[Dict]: """Fetch invite information given its code.""" - invite = await self.db.fetchrow(""" + invite = await self.db.fetchrow( + """ SELECT code, guild_id, channel_id FROM invites WHERE code = $1 - """, invite_code) + """, + invite_code, + ) if invite is None: return None @@ -990,71 +1100,75 @@ class Storage: dinv = dict_(invite) # fetch some guild info - guild = await self.db.fetchrow(""" + guild = await self.db.fetchrow( + """ SELECT id::text, name, icon, splash, banner, features, verification_level, description FROM guilds WHERE id = $1 - """, invite['guild_id']) + """, + invite["guild_id"], + ) if guild: - dinv['guild'] = dict(guild) + dinv["guild"] = dict(guild) else: - dinv['guild'] = {} + dinv["guild"] = {} - chan = await self.get_channel(invite['channel_id']) + chan = await self.get_channel(invite["channel_id"]) if chan is None: return None - dinv['channel'] = { - 'id': chan['id'], - 'name': chan['name'], - 'type': chan['type'], - } + dinv["channel"] = {"id": chan["id"], "name": chan["name"], "type": chan["type"]} - dinv.pop('guild_id') - dinv.pop('channel_id') + dinv.pop("guild_id") + dinv.pop("channel_id") return dinv async def get_invite_extra(self, invite_code: str) -> dict: """Extra information about the invite, such as approximate guild and presence counts.""" - guild_id = await self.db.fetchval(""" + guild_id = await self.db.fetchval( + """ SELECT guild_id FROM invites WHERE code = $1 - """, invite_code) + """, + invite_code, + ) if guild_id is None: return {} mids = await self.get_member_ids(guild_id) pres = await self.presence.guild_presences(mids, guild_id) - online_count = sum(1 for p in pres if p['status'] == 'online') + online_count = sum(1 for p in pres if p["status"] == "online") return { - 'approximate_presence_count': online_count, - 'approximate_member_count': len(mids), + "approximate_presence_count": online_count, + "approximate_member_count": len(mids), } - async def get_invite_metadata(self, - invite_code: str) -> Optional[Dict[str, Any]]: + async def get_invite_metadata(self, invite_code: str) -> Optional[Dict[str, Any]]: """Fetch invite metadata (max_age and friends).""" - invite = await self.db.fetchrow(""" + invite = await self.db.fetchrow( + """ SELECT code, inviter, created_at, uses, max_uses, max_age, temporary, created_at, revoked FROM invites WHERE code = $1 - """, invite_code) + """, + invite_code, + ) if invite is None: return None dinv = dict_(invite) - inviter = await self.get_user(invite['inviter']) - dinv['inviter'] = inviter + inviter = await self.get_user(invite["inviter"]) + dinv["inviter"] = inviter return dinv @@ -1063,29 +1177,36 @@ class Storage: dm_chan = await self.get_channel(dm_id) if user_id and dm_chan: - dm_chan['recipients'] = _filter_recipients( - dm_chan['recipients'], str(user_id) + dm_chan["recipients"] = _filter_recipients( + dm_chan["recipients"], str(user_id) ) return dm_chan async def guild_from_channel(self, channel_id: int) -> int: """Get the guild id coming from a channel id.""" - return await self.db.fetchval(""" + return await self.db.fetchval( + """ SELECT guild_id FROM guild_channels WHERE id = $1 - """, channel_id) + """, + channel_id, + ) async def get_dm_peer(self, channel_id: int, user_id: int) -> int: """Get the peer id on a dm""" - parties = await self.db.fetchrow(""" + parties = await self.db.fetchrow( + """ SELECT party1_id, party2_id FROM dm_channels WHERE id = $1 AND (party1_id = $2 OR party2_id = $2) - """, channel_id, user_id) + """, + channel_id, + user_id, + ) - parties = [parties['party1_id'], parties['party2_id']] + parties = [parties["party1_id"], parties["party2_id"]] # get the id of the other party parties.remove(user_id) @@ -1094,12 +1215,15 @@ class Storage: async def get_emoji(self, emoji_id: int) -> Optional[Dict[str, Any]]: """Get a single emoji.""" - row = await self.db.fetchrow(""" + row = await self.db.fetchrow( + """ SELECT id::text, name, animated, managed, require_colons, uploader_id FROM guild_emoji WHERE id = $1 - """, emoji_id) + """, + emoji_id, + ) if not row: return None @@ -1107,22 +1231,25 @@ class Storage: drow = dict(row) # ???? - drow['roles'] = [] + drow["roles"] = [] - uploader_id = drow.pop('uploader_id') - drow['user'] = await self.get_user(uploader_id) + uploader_id = drow.pop("uploader_id") + drow["user"] = await self.get_user(uploader_id) return drow async def get_guild_emojis(self, guild_id: int): """Get a list of all emoji objects in a guild.""" - rows = await self.db.fetch(""" + rows = await self.db.fetch( + """ SELECT id FROM guild_emoji WHERE guild_id = $1 - """, guild_id) + """, + guild_id, + ) - emoji_ids = [r['id'] for r in rows] + emoji_ids = [r["id"] for r in rows] res = [] @@ -1134,28 +1261,36 @@ class Storage: async def get_role_members(self, role_id: int) -> List[int]: """Get all members with a role.""" - rows = await self.db.fetch(""" + rows = await self.db.fetch( + """ SELECT user_id FROM member_roles WHERE role_id = $1 - """, role_id) + """, + role_id, + ) - return [r['id'] for r in rows] + return [r["id"] for r in rows] async def all_voice_regions(self) -> List[Dict[str, Any]]: """Return a list of all voice regions.""" - rows = await self.db.fetch(""" + rows = await self.db.fetch( + """ SELECT id, name, vip, deprecated, custom FROM voice_regions - """) + """ + ) return list(map(dict, rows)) async def has_feature(self, guild_id: int, feature: str) -> bool: """Return if a certain guild has a certain feature.""" - features = await self.db.fetchval(""" + features = await self.db.fetchval( + """ SELECT features FROM guilds WHERE id = $1 - """, guild_id) + """, + guild_id, + ) return feature.upper() in features diff --git a/litecord/system_messages.py b/litecord/system_messages.py index a9e1cab..bc63717 100644 --- a/litecord/system_messages.py +++ b/litecord/system_messages.py @@ -24,6 +24,7 @@ from litecord.enums import MessageType log = Logger(__name__) + async def _handle_pin_msg(app, channel_id, _pinned_id, author_id): """Handle a message pin.""" new_id = get_snowflake() @@ -37,8 +38,10 @@ async def _handle_pin_msg(app, channel_id, _pinned_id, author_id): ($1, $2, NULL, $3, NULL, '', $4) """, - new_id, channel_id, author_id, - MessageType.CHANNEL_PINNED_MESSAGE.value + new_id, + channel_id, + author_id, + MessageType.CHANNEL_PINNED_MESSAGE.value, ) return new_id @@ -56,15 +59,16 @@ async def _handle_recp_add(app, channel_id, author_id, peer_id): VALUES ($1, $2, $3, NULL, $4, $5) """, - new_id, channel_id, author_id, - f'<@{peer_id}>', - MessageType.RECIPIENT_ADD.value + new_id, + channel_id, + author_id, + f"<@{peer_id}>", + MessageType.RECIPIENT_ADD.value, ) return new_id - async def _handle_recp_rmv(app, channel_id, author_id, peer_id): new_id = get_snowflake() @@ -76,9 +80,11 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id): VALUES ($1, $2, $3, NULL, $4, $5) """, - new_id, channel_id, author_id, - f'<@{peer_id}>', - MessageType.RECIPIENT_REMOVE.value + new_id, + channel_id, + author_id, + f"<@{peer_id}>", + MessageType.RECIPIENT_REMOVE.value, ) return new_id @@ -87,13 +93,16 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id): async def _handle_gdm_name_edit(app, channel_id, author_id): new_id = get_snowflake() - gdm_name = await app.db.fetchval(""" + gdm_name = await app.db.fetchval( + """ SELECT name FROM group_dm_channels WHERE id = $1 - """, channel_id) + """, + channel_id, + ) if not gdm_name: - log.warning('no gdm name found for sys message') + log.warning("no gdm name found for sys message") return await app.db.execute( @@ -104,9 +113,11 @@ async def _handle_gdm_name_edit(app, channel_id, author_id): VALUES ($1, $2, $3, NULL, $4, $5) """, - new_id, channel_id, author_id, + new_id, + channel_id, + author_id, gdm_name, - MessageType.CHANNEL_NAME_CHANGE.value + MessageType.CHANNEL_NAME_CHANGE.value, ) return new_id @@ -123,16 +134,19 @@ async def _handle_gdm_icon_edit(app, channel_id, author_id): VALUES ($1, $2, $3, NULL, $4, $5) """, - new_id, channel_id, author_id, - '', - MessageType.CHANNEL_ICON_CHANGE.value + new_id, + channel_id, + author_id, + "", + MessageType.CHANNEL_ICON_CHANGE.value, ) return new_id -async def send_sys_message(app, channel_id: int, m_type: MessageType, - *args, **kwargs) -> int: +async def send_sys_message( + app, channel_id: int, m_type: MessageType, *args, **kwargs +) -> int: """Send a system message. The handler for a given message type MUST return an integer, that integer @@ -156,22 +170,19 @@ async def send_sys_message(app, channel_id: int, m_type: MessageType, try: handler = { MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg, - # gdm specific MessageType.RECIPIENT_ADD: _handle_recp_add, MessageType.RECIPIENT_REMOVE: _handle_recp_rmv, MessageType.CHANNEL_NAME_CHANGE: _handle_gdm_name_edit, - MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit + MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit, }[m_type] except KeyError: - raise ValueError('Invalid system message type') + raise ValueError("Invalid system message type") message_id = await handler(app, channel_id, *args, **kwargs) message = await app.storage.get_message(message_id) - await app.dispatcher.dispatch( - 'channel', channel_id, 'MESSAGE_CREATE', message - ) + await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", message) return message_id diff --git a/litecord/types.py b/litecord/types.py index 2c9c122..da8b157 100644 --- a/litecord/types.py +++ b/litecord/types.py @@ -29,6 +29,7 @@ HOURS = 60 * MINUTES class Color: """Custom color class""" + def __init__(self, val: int): self.blue = val & 255 self.green = (val >> 8) & 255 @@ -37,7 +38,7 @@ class Color: @property def value(self): """Give the actual RGB integer encoding this color.""" - return int('%02x%02x%02x' % (self.red, self.green, self.blue), 16) + return int("%02x%02x%02x" % (self.red, self.green, self.blue), 16) @property def to_json(self): @@ -49,4 +50,4 @@ class Color: def timestamp_(dt) -> Optional[str]: """safer version for dt.isoformat()""" - return f'{dt.isoformat()}+00:00' if dt else None + return f"{dt.isoformat()}+00:00" if dt else None diff --git a/litecord/user_storage.py b/litecord/user_storage.py index baaeb2a..45704cc 100644 --- a/litecord/user_storage.py +++ b/litecord/user_storage.py @@ -27,43 +27,52 @@ log = Logger(__name__) class UserStorage: """Storage functions related to a single user.""" + def __init__(self, storage): self.storage = storage self.db = storage.db async def fetch_notes(self, user_id: int) -> dict: """Fetch a users' notes""" - note_rows = await self.db.fetch(""" + note_rows = await self.db.fetch( + """ SELECT target_id, note FROM notes WHERE user_id = $1 - """, user_id) + """, + user_id, + ) - return {str(row['target_id']): row['note'] - for row in note_rows} + return {str(row["target_id"]): row["note"] for row in note_rows} async def get_user_settings(self, user_id: int) -> Dict[str, Any]: """Get current user settings.""" - row = await self.storage.fetchrow_with_json(""" + row = await self.storage.fetchrow_with_json( + """ SELECT * FROM user_settings WHERE id = $1 - """, user_id) + """, + user_id, + ) if not row: - log.info('Generating user settings for {}', user_id) + log.info("Generating user settings for {}", user_id) - await self.db.execute(""" + await self.db.execute( + """ INSERT INTO user_settings (id) VALUES ($1) - """, user_id) + """, + user_id, + ) # recalling get_user_settings # should work after adding return await self.get_user_settings(user_id) drow = dict(row) - drow.pop('id') + drow.pop("id") return drow async def get_relationships(self, user_id: int) -> List[Dict[str, Any]]: @@ -76,11 +85,15 @@ class UserStorage: _outgoing = RelationshipType.OUTGOING.value # check all outgoing friends - friends = await self.db.fetch(""" + friends = await self.db.fetch( + """ SELECT user_id, peer_id, rel_type FROM relationships WHERE user_id = $1 AND rel_type = $2 - """, user_id, _friend) + """, + user_id, + _friend, + ) friends = list(map(dict, friends)) # mutuals is a list of ints @@ -95,66 +108,80 @@ class UserStorage: SELECT user_id, peer_id FROM relationships WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3 - """, row['peer_id'], row['user_id'], - _friend) + """, + row["peer_id"], + row["user_id"], + _friend, + ) if is_friend is not None: - mutuals.append(row['peer_id']) + mutuals.append(row["peer_id"]) # fetch friend requests directed at us - incoming_friends = await self.db.fetch(""" + incoming_friends = await self.db.fetch( + """ SELECT user_id, peer_id FROM relationships WHERE peer_id = $1 AND rel_type = $2 - """, user_id, _friend) + """, + user_id, + _friend, + ) # only need their ids - incoming_friends = [r['user_id'] for r in incoming_friends - if r['user_id'] not in mutuals] + incoming_friends = [ + r["user_id"] for r in incoming_friends if r["user_id"] not in mutuals + ] # only fetch blocks we did, # not fetching the ones people did to us - blocks = await self.db.fetch(""" + blocks = await self.db.fetch( + """ SELECT user_id, peer_id, rel_type FROM relationships WHERE user_id = $1 AND rel_type = $2 - """, user_id, _block) + """, + user_id, + _block, + ) blocks = list(map(dict, blocks)) res = [] for drow in friends: - drow['type'] = drow['rel_type'] - drow['id'] = str(drow['peer_id']) - drow.pop('rel_type') + drow["type"] = drow["rel_type"] + drow["id"] = str(drow["peer_id"]) + drow.pop("rel_type") # check if the receiver is a mutual # if it isnt, its still on a friend request stage - if drow['peer_id'] not in mutuals: - drow['type'] = _outgoing + if drow["peer_id"] not in mutuals: + drow["type"] = _outgoing - drow['user'] = await self.storage.get_user(drow['peer_id']) + drow["user"] = await self.storage.get_user(drow["peer_id"]) - drow.pop('user_id') - drow.pop('peer_id') + drow.pop("user_id") + drow.pop("peer_id") res.append(drow) for peer_id in incoming_friends: - res.append({ - 'id': str(peer_id), - 'user': await self.storage.get_user(peer_id), - 'type': _incoming, - }) + res.append( + { + "id": str(peer_id), + "user": await self.storage.get_user(peer_id), + "type": _incoming, + } + ) for drow in blocks: - drow['type'] = drow['rel_type'] - drow.pop('rel_type') + drow["type"] = drow["rel_type"] + drow.pop("rel_type") - drow['id'] = str(drow['peer_id']) - drow['user'] = await self.storage.get_user(drow['peer_id']) + drow["id"] = str(drow["peer_id"]) + drow["user"] = await self.storage.get_user(drow["peer_id"]) - drow.pop('user_id') - drow.pop('peer_id') + drow.pop("user_id") + drow.pop("peer_id") res.append(drow) return res @@ -163,9 +190,11 @@ class UserStorage: """Get all friend IDs for a user.""" rels = await self.get_relationships(user_id) - return [int(r['user']['id']) - for r in rels - if r['type'] == RelationshipType.FRIEND.value] + return [ + int(r["user"]["id"]) + for r in rels + if r["type"] == RelationshipType.FRIEND.value + ] async def get_dms(self, user_id: int) -> List[Dict[str, Any]]: """Get all DM channels for a user, including group DMs. @@ -173,13 +202,16 @@ class UserStorage: This will only fetch channels the user has in their state, which is different than the whole list of DM channels. """ - dm_ids = await self.db.fetch(""" + dm_ids = await self.db.fetch( + """ SELECT dm_id FROM dm_channel_state WHERE user_id = $1 - """, user_id) + """, + user_id, + ) - dm_ids = [r['dm_id'] for r in dm_ids] + dm_ids = [r["dm_id"] for r in dm_ids] res = [] @@ -191,21 +223,24 @@ class UserStorage: async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]: """Get the read state for a user.""" - rows = await self.db.fetch(""" + rows = await self.db.fetch( + """ SELECT channel_id, last_message_id, mention_count FROM user_read_state WHERE user_id = $1 - """, user_id) + """, + user_id, + ) res = [] for row in rows: drow = dict(row) - drow['id'] = str(drow['channel_id']) - drow.pop('channel_id') + drow["id"] = str(drow["channel_id"]) + drow.pop("channel_id") - drow['last_message_id'] = str(drow['last_message_id']) + drow["last_message_id"] = str(drow["last_message_id"]) res.append(drow) @@ -214,13 +249,17 @@ class UserStorage: async def _get_chan_overrides(self, user_id: int, guild_id: int) -> List: chan_overrides = [] - overrides = await self.db.fetch(""" + overrides = await self.db.fetch( + """ SELECT channel_id::text, muted, message_notifications FROM guild_settings_channel_overrides WHERE user_id = $1 AND guild_id = $2 - """, user_id, guild_id) + """, + user_id, + guild_id, + ) for chan_row in overrides: dcrow = dict(chan_row) @@ -228,30 +267,35 @@ class UserStorage: return chan_overrides - async def get_guild_settings_one(self, user_id: int, - guild_id: int) -> dict: + async def get_guild_settings_one(self, user_id: int, guild_id: int) -> dict: """Get guild settings information for a single guild.""" - row = await self.db.fetchrow(""" + row = await self.db.fetchrow( + """ SELECT guild_id::text, suppress_everyone, muted, message_notifications, mobile_push FROM guild_settings WHERE user_id = $1 AND guild_id = $2 - """, user_id, guild_id) + """, + user_id, + guild_id, + ) if not row: - await self.db.execute(""" + await self.db.execute( + """ INSERT INTO guild_settings (user_id, guild_id) VALUES ($1, $2) - """, user_id, guild_id) + """, + user_id, + guild_id, + ) return await self.get_guild_settings_one(user_id, guild_id) - gid = int(row['guild_id']) + gid = int(row["guild_id"]) drow = dict(row) chan_overrides = await self._get_chan_overrides(user_id, gid) - return {**drow, **{ - 'channel_overrides': chan_overrides - }} + return {**drow, **{"channel_overrides": chan_overrides}} async def get_guild_settings(self, user_id: int): """Get the specific User Guild Settings, @@ -259,34 +303,38 @@ class UserStorage: res = [] - settings = await self.db.fetch(""" + settings = await self.db.fetch( + """ SELECT guild_id::text, suppress_everyone, muted, message_notifications, mobile_push FROM guild_settings WHERE user_id = $1 - """, user_id) + """, + user_id, + ) for row in settings: - gid = int(row['guild_id']) + gid = int(row["guild_id"]) drow = dict(row) chan_overrides = await self._get_chan_overrides(user_id, gid) - res.append({**drow, **{ - 'channel_overrides': chan_overrides - }}) + res.append({**drow, **{"channel_overrides": chan_overrides}}) return res async def get_user_guilds(self, user_id: int) -> List[int]: """Get all guild IDs a user is on.""" - guild_ids = await self.db.fetch(""" + guild_ids = await self.db.fetch( + """ SELECT guild_id FROM members WHERE user_id = $1 - """, user_id) + """, + user_id, + ) - return [row['guild_id'] for row in guild_ids] + return [row["guild_id"] for row in guild_ids] async def get_mutual_guilds(self, user_id: int, peer_id: int) -> List[int]: """Get a list of guilds two separate users @@ -301,13 +349,17 @@ class UserStorage: return await self.get_user_guilds(user_id) or [0] - mutual_guilds = await self.db.fetch(""" + mutual_guilds = await self.db.fetch( + """ SELECT guild_id FROM members WHERE user_id = $1 INTERSECT SELECT guild_id FROM members WHERE user_id = $2 - """, user_id, peer_id) + """, + user_id, + peer_id, + ) - mutual_guilds = [r['guild_id'] for r in mutual_guilds] + mutual_guilds = [r["guild_id"] for r in mutual_guilds] return mutual_guilds @@ -316,7 +368,8 @@ class UserStorage: This returns false even if there is a friend request. """ - return await self.db.fetchval(""" + return await self.db.fetchval( + """ SELECT ( SELECT EXISTS( @@ -337,17 +390,23 @@ class UserStorage: AND rel_type = 1 ) ) - """, user_id, peer_id) + """, + user_id, + peer_id, + ) async def get_gdms_internal(self, user_id) -> List[int]: """Return a list of Group DM IDs the user is a member of.""" - rows = await self.db.fetch(""" + rows = await self.db.fetch( + """ SELECT id FROM group_dm_members WHERE member_id = $1 - """, user_id) + """, + user_id, + ) - return [r['id'] for r in rows] + return [r["id"] for r in rows] async def get_gdms(self, user_id) -> List[Dict[str, Any]]: """Get list of group DMs a user is in.""" @@ -356,8 +415,6 @@ class UserStorage: res = [] for gdm_id in gdm_ids: - res.append( - await self.storage.get_channel(gdm_id, user_id=user_id) - ) + res.append(await self.storage.get_channel(gdm_id, user_id=user_id)) return res diff --git a/litecord/utils.py b/litecord/utils.py index fcc192d..b2c5478 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -46,7 +46,7 @@ async def task_wrapper(name: str, coro): except asyncio.CancelledError: pass except: - log.exception('{} task error', name) + log.exception("{} task error", name) def dict_get(mapping, key, default): @@ -84,54 +84,66 @@ def mmh3(inp_str: str, seed: int = 0): h1 = seed # mm3 constants - c1 = 0xcc9e2d51 - c2 = 0x1b873593 + c1 = 0xCC9E2D51 + c2 = 0x1B873593 i = 0 while i < bytecount: k1 = ( - (key[i] & 0xff) | - ((key[i + 1] & 0xff) << 8) | - ((key[i + 2] & 0xff) << 16) | - ((key[i + 3] & 0xff) << 24) + (key[i] & 0xFF) + | ((key[i + 1] & 0xFF) << 8) + | ((key[i + 2] & 0xFF) << 16) + | ((key[i + 3] & 0xFF) << 24) ) i += 4 - k1 = ((((k1 & 0xffff) * c1) + ((((_u(k1) >> 16) * c1) & 0xffff) << 16))) & 0xffffffff + k1 = ( + (((k1 & 0xFFFF) * c1) + ((((_u(k1) >> 16) * c1) & 0xFFFF) << 16)) + ) & 0xFFFFFFFF k1 = (k1 << 15) | (_u(k1) >> 17) - k1 = ((((k1 & 0xffff) * c2) + ((((_u(k1) >> 16) * c2) & 0xffff) << 16))) & 0xffffffff; + k1 = ( + (((k1 & 0xFFFF) * c2) + ((((_u(k1) >> 16) * c2) & 0xFFFF) << 16)) + ) & 0xFFFFFFFF h1 ^= k1 - h1 = (h1 << 13) | (_u(h1) >> 19); - h1b = ((((h1 & 0xffff) * 5) + ((((_u(h1) >> 16) * 5) & 0xffff) << 16))) & 0xffffffff; - h1 = (((h1b & 0xffff) + 0x6b64) + ((((_u(h1b) >> 16) + 0xe654) & 0xffff) << 16)) - + h1 = (h1 << 13) | (_u(h1) >> 19) + h1b = ( + (((h1 & 0xFFFF) * 5) + ((((_u(h1) >> 16) * 5) & 0xFFFF) << 16)) + ) & 0xFFFFFFFF + h1 = ((h1b & 0xFFFF) + 0x6B64) + ((((_u(h1b) >> 16) + 0xE654) & 0xFFFF) << 16) k1 = 0 v = None if remainder == 3: - v = (key[i + 2] & 0xff) << 16 + v = (key[i + 2] & 0xFF) << 16 elif remainder == 2: - v = (key[i + 1] & 0xff) << 8 + v = (key[i + 1] & 0xFF) << 8 elif remainder == 1: - v = (key[i] & 0xff) + v = key[i] & 0xFF if v is not None: k1 ^= v - k1 = (((k1 & 0xffff) * c1) + ((((_u(k1) >> 16) * c1) & 0xffff) << 16)) & 0xffffffff + k1 = (((k1 & 0xFFFF) * c1) + ((((_u(k1) >> 16) * c1) & 0xFFFF) << 16)) & 0xFFFFFFFF k1 = (k1 << 15) | (_u(k1) >> 17) - k1 = (((k1 & 0xffff) * c2) + ((((_u(k1) >> 16) * c2) & 0xffff) << 16)) & 0xffffffff + k1 = (((k1 & 0xFFFF) * c2) + ((((_u(k1) >> 16) * c2) & 0xFFFF) << 16)) & 0xFFFFFFFF h1 ^= k1 h1 ^= len(key) h1 ^= _u(h1) >> 16 - h1 = (((h1 & 0xffff) * 0x85ebca6b) + ((((_u(h1) >> 16) * 0x85ebca6b) & 0xffff) << 16)) & 0xffffffff + h1 = ( + ((h1 & 0xFFFF) * 0x85EBCA6B) + ((((_u(h1) >> 16) * 0x85EBCA6B) & 0xFFFF) << 16) + ) & 0xFFFFFFFF h1 ^= _u(h1) >> 13 - h1 = ((((h1 & 0xffff) * 0xc2b2ae35) + ((((_u(h1) >> 16) * 0xc2b2ae35) & 0xffff) << 16))) & 0xffffffff + h1 = ( + ( + ((h1 & 0xFFFF) * 0xC2B2AE35) + + ((((_u(h1) >> 16) * 0xC2B2AE35) & 0xFFFF) << 16) + ) + ) & 0xFFFFFFFF h1 ^= _u(h1) >> 16 return _u(h1) >> 0 @@ -139,6 +151,7 @@ def mmh3(inp_str: str, seed: int = 0): class LitecordJSONEncoder(JSONEncoder): """Custom JSON encoder for Litecord.""" + def default(self, value: Any): """By default, this will try to get the to_json attribute of a given value being JSON encoded.""" @@ -151,17 +164,17 @@ class LitecordJSONEncoder(JSONEncoder): async def pg_set_json(con): """Set JSON and JSONB codecs for an asyncpg connection.""" await con.set_type_codec( - 'json', + "json", encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder), decoder=json.loads, - schema='pg_catalog' + schema="pg_catalog", ) await con.set_type_codec( - 'jsonb', + "jsonb", encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder), decoder=json.loads, - schema='pg_catalog' + schema="pg_catalog", ) @@ -177,7 +190,8 @@ def yield_chunks(input_list: Sequence[Any], chunk_size: int): # range accepts step param, so we use that to # make the chunks for idx in range(0, len(input_list), chunk_size): - yield input_list[idx:idx + chunk_size] + yield input_list[idx : idx + chunk_size] + def to_update(j: dict, orig: dict, field: str) -> bool: """Compare values to check if j[field] is actually updating @@ -193,27 +207,23 @@ async def search_result_from_list(rows: List) -> Dict[str, Any]: - An int (?) on `total_results` - Two bigint[], each on `before` and `after` respectively. """ - results = 0 if not rows else rows[0]['total_results'] + results = 0 if not rows else rows[0]["total_results"] res = [] for row in rows: before, after = [], [] - for before_id in reversed(row['before']): + for before_id in reversed(row["before"]): before.append(await app.storage.get_message(before_id)) - for after_id in row['after']: + for after_id in row["after"]: after.append(await app.storage.get_message(after_id)) - msg = await app.storage.get_message(row['current_id']) - msg['hit'] = True + msg = await app.storage.get_message(row["current_id"]) + msg["hit"] = True res.append(before + [msg] + after) - return { - 'total_results': results, - 'messages': res, - 'analytics_id': '', - } + return {"total_results": results, "messages": res, "analytics_id": ""} def maybe_int(val: Any) -> Union[int, Any]: diff --git a/litecord/voice/lvsp_conn.py b/litecord/voice/lvsp_conn.py index 3902674..d7bad73 100644 --- a/litecord/voice/lvsp_conn.py +++ b/litecord/voice/lvsp_conn.py @@ -31,6 +31,7 @@ log = Logger(__name__) class LVSPConnection: """Represents a single LVSP connection.""" + def __init__(self, lvsp, region: str, hostname: str): self.lvsp = lvsp self.app = lvsp.app @@ -46,7 +47,7 @@ class LVSPConnection: @property def _log_id(self): - return f'region={self.region} hostname={self.hostname}' + return f"region={self.region} hostname={self.hostname}" async def send(self, payload): """Send a payload down the websocket.""" @@ -61,50 +62,42 @@ class LVSPConnection: async def send_op(self, opcode: int, data: dict): """Send a message with an OP code included""" - await self.send({ - 'op': opcode, - 'd': data - }) + await self.send({"op": opcode, "d": data}) async def send_info(self, info_type: str, info_data: Dict): """Send an INFO message down the websocket.""" - await self.send({ - 'op': OP.info, - 'd': { - 'type': InfoTable[info_type.upper()], - 'data': info_data + await self.send( + { + "op": OP.info, + "d": {"type": InfoTable[info_type.upper()], "data": info_data}, } - }) + ) async def _heartbeater(self, hb_interval: int): try: await asyncio.sleep(hb_interval) # TODO: add self._seq - await self.send_op(OP.heartbeat, { - 's': 0 - }) + await self.send_op(OP.heartbeat, {"s": 0}) # give the server 300 milliseconds to reply. await asyncio.sleep(300) - await self.conn.close(4000, 'heartbeat timeout') + await self.conn.close(4000, "heartbeat timeout") except asyncio.CancelledError: pass def _start_hb(self): - self._hb_task = self.app.loop.create_task( - self._heartbeater(self._hb_interval) - ) + self._hb_task = self.app.loop.create_task(self._heartbeater(self._hb_interval)) def _stop_hb(self): self._hb_task.cancel() async def _handle_0(self, msg): """Handle HELLO message.""" - data = msg['d'] + data = msg["d"] # nonce = data['nonce'] - self._hb_interval = data['heartbeat_interval'] + self._hb_interval = data["heartbeat_interval"] # TODO: send identify @@ -112,48 +105,52 @@ class LVSPConnection: """Update the health value of a given voice server.""" self.health = new_health - await self.app.db.execute(""" + await self.app.db.execute( + """ UPDATE voice_servers SET health = $1 WHERE hostname = $2 - """, new_health, self.hostname) + """, + new_health, + self.hostname, + ) async def _handle_3(self, msg): """Handle READY message. We only start heartbeating after READY. """ - await self._update_health(msg['health']) + await self._update_health(msg["health"]) self._start_hb() async def _handle_5(self, msg): """Handle HEARTBEAT_ACK.""" self._stop_hb() - await self._update_health(msg['health']) + await self._update_health(msg["health"]) self._start_hb() async def _handle_6(self, msg): """Handle INFO messages.""" - info = msg['d'] - info_type_str = InfoReverse[info['type']].lower() + info = msg["d"] + info_type_str = InfoReverse[info["type"]].lower() try: - info_handler = getattr(self, f'_handle_info_{info_type_str}') + info_handler = getattr(self, f"_handle_info_{info_type_str}") except AttributeError: return - await info_handler(info['data']) + await info_handler(info["data"]) async def _handle_info_channel_assign(self, data: dict): """called by the server once we got a channel assign.""" try: - channel_id = data['channel_id'] + channel_id = data["channel_id"] channel_id = int(channel_id) except (TypeError, ValueError): return try: - guild_id = data['guild_id'] + guild_id = data["guild_id"] guild_id = int(guild_id) except (TypeError, ValueError): guild_id = None @@ -166,19 +163,19 @@ class LVSPConnection: msg = await self.recv() try: - opcode = msg['op'] - handler = getattr(self, f'_handle_{opcode}') + opcode = msg["op"] + handler = getattr(self, f"_handle_{opcode}") await handler(msg) except (KeyError, AttributeError): # TODO: error codes in LVSP - raise Exception('invalid op code') + raise Exception("invalid op code") async def start(self): """Try to start a websocket connection.""" try: - self.conn = await websockets.connect(f'wss://{self.hostname}') + self.conn = await websockets.connect(f"wss://{self.hostname}") except Exception: - log.exception('failed to start lvsp conn to {}', self.hostname) + log.exception("failed to start lvsp conn to {}", self.hostname) async def run(self): """Start the websocket.""" @@ -186,15 +183,15 @@ class LVSPConnection: try: if not self.conn: - log.error('failed to start lvsp connection, stopping') + log.error("failed to start lvsp connection, stopping") return await self._loop() except websockets.exceptions.ConnectionClosed as err: - log.warning('conn close, {}, err={}', self._log_id, err) + log.warning("conn close, {}, err={}", self._log_id, err) # except WebsocketClose as err: # log.warning('ws close, state={} err={}', self.state, err) # await self.conn.close(code=err.code, reason=err.reason) except Exception as err: - log.exception('An exception has occoured. {}', self._log_id) + log.exception("An exception has occoured. {}", self._log_id) await self.conn.close(code=4000, reason=repr(err)) diff --git a/litecord/voice/lvsp_manager.py b/litecord/voice/lvsp_manager.py index 30407dc..2cc0e76 100644 --- a/litecord/voice/lvsp_manager.py +++ b/litecord/voice/lvsp_manager.py @@ -31,6 +31,7 @@ log = Logger(__name__) @dataclass class Region: """Voice region data.""" + id: str vip: bool @@ -40,6 +41,7 @@ class LVSPManager: Spawns :class:`LVSPConnection` as needed, etc. """ + def __init__(self, app, voice): self.app = app self.voice = voice @@ -61,49 +63,50 @@ class LVSPManager: async def _spawn(self): """Spawn LVSPConnection for each region.""" - regions = await self.app.db.fetch(""" + regions = await self.app.db.fetch( + """ SELECT id, vip FROM voice_regions WHERE deprecated = false - """) + """ + ) - regions = [Region(r['id'], r['vip']) for r in regions] + regions = [Region(r["id"], r["vip"]) for r in regions] if not regions: - log.warning('no regions are setup') + log.warning("no regions are setup") return for region in regions: # store it locally for region() function self.regions[region.id] = region - self.app.loop.create_task( - self._spawn_region(region) - ) + self.app.loop.create_task(self._spawn_region(region)) async def _spawn_region(self, region: Region): """Spawn a region. Involves fetching all the hostnames for the regions and spawning a LVSPConnection for each.""" - servers = await self.app.db.fetch(""" + servers = await self.app.db.fetch( + """ SELECT hostname FROM voice_servers WHERE region_id = $1 - """, region.id) + """, + region.id, + ) if not servers: - log.warning('region {} does not have servers', region) + log.warning("region {} does not have servers", region) return - servers = [r['hostname'] for r in servers] + servers = [r["hostname"] for r in servers] self.servers[region.id] = servers for hostname in servers: conn = LVSPConnection(self, region.id, hostname) self.conns[hostname] = conn - self.app.loop.create_task( - conn.run() - ) + self.app.loop.create_task(conn.run()) async def del_conn(self, conn): """Delete a connection from the connection pool.""" @@ -119,11 +122,14 @@ class LVSPManager: async def guild_region(self, guild_id: int) -> Optional[str]: """Return the voice region of a guild.""" - return await self.app.db.fetchval(""" + return await self.app.db.fetchval( + """ SELECT region FROM guilds WHERE id = $1 - """, guild_id) + """, + guild_id, + ) def get_health(self, hostname: str) -> float: """Get voice server health, given hostname.""" @@ -144,10 +150,7 @@ class LVSPManager: region = await self.guild_region(guild_id) # sort connected servers by health - sorted_servers = sorted( - self.servers[region], - key=self.get_health - ) + sorted_servers = sorted(self.servers[region], key=self.get_health) try: hostname = sorted_servers[0] diff --git a/litecord/voice/lvsp_opcodes.py b/litecord/voice/lvsp_opcodes.py index 4cd0b3c..b617fae 100644 --- a/litecord/voice/lvsp_opcodes.py +++ b/litecord/voice/lvsp_opcodes.py @@ -17,8 +17,10 @@ along with this program. If not, see . """ + class OPCodes: """LVSP OP codes.""" + hello = 0 identify = 1 resume = 2 @@ -29,13 +31,13 @@ class OPCodes: InfoTable = { - 'CHANNEL_REQ': 0, - 'CHANNEL_ASSIGN': 1, - 'CHANNEL_UPDATE': 2, - 'CHANNEL_DESTROY': 3, - 'VST_CREATE': 4, - 'VST_UPDATE': 5, - 'VST_LEAVE': 6, + "CHANNEL_REQ": 0, + "CHANNEL_ASSIGN": 1, + "CHANNEL_UPDATE": 2, + "CHANNEL_DESTROY": 3, + "VST_CREATE": 4, + "VST_UPDATE": 5, + "VST_LEAVE": 6, } InfoReverse = {v: k for k, v in InfoTable.items()} diff --git a/litecord/voice/manager.py b/litecord/voice/manager.py index b927cd3..1a2367a 100644 --- a/litecord/voice/manager.py +++ b/litecord/voice/manager.py @@ -43,6 +43,7 @@ def _construct_state(state_dict: dict) -> VoiceState: class VoiceManager: """Main voice manager class.""" + def __init__(self, app): self.app = app @@ -56,7 +57,7 @@ class VoiceManager: """Return if a user can join a channel.""" channel = await self.app.storage.get_channel(channel_id) - ctype = ChannelType(channel['type']) + ctype = ChannelType(channel["type"]) if ctype not in VOICE_CHANNELS: return @@ -65,14 +66,12 @@ class VoiceManager: # get_permissions returns ALL_PERMISSIONS when # the channel isn't from a guild - perms = await get_permissions( - user_id, channel_id, storage=self.app.storage - ) + perms = await get_permissions(user_id, channel_id, storage=self.app.storage) # hacky user_limit but should work, as channels not # in guilds won't have that field. - is_full = states >= channel.get('user_limit', 100) - is_bot = (await self.app.storage.get_user(user_id))['bot'] + is_full = states >= channel.get("user_limit", 100) + is_bot = (await self.app.storage.get_user(user_id))["bot"] is_manager = perms.bits.manage_channels # if the channel is full AND: @@ -140,8 +139,8 @@ class VoiceManager: for field in prop: # NOTE: this should not happen, ever. - if field in ('channel_id', 'user_id'): - raise ValueError('properties are updating channel or user') + if field in ("channel_id", "user_id"): + raise ValueError("properties are updating channel or user") new_state_dict[field] = prop[field] @@ -153,27 +152,28 @@ class VoiceManager: async def move_channels(self, old_voice_key: VoiceKey, channel_id: int): """Move a user between channels.""" await self.del_state(old_voice_key) - await self.create_state(old_voice_key, {'channel_id': channel_id}) + await self.create_state(old_voice_key, {"channel_id": channel_id}) async def _lvsp_info_guild(self, guild_id, info_type, info_data): hostname = await self.lvsp.get_guild_server(guild_id) if hostname is None: - log.error('no voice server for guild id {}', guild_id) + log.error("no voice server for guild id {}", guild_id) return conn = self.lvsp.get_conn(hostname) await conn.send_info(info_type, info_data) async def _create_ctx_guild(self, guild_id, channel_id): - await self._lvsp_info_guild(guild_id, 'CHANNEL_REQ', { - 'guild_id': str(guild_id), - 'channel_id': str(channel_id), - }) + await self._lvsp_info_guild( + guild_id, + "CHANNEL_REQ", + {"guild_id": str(guild_id), "channel_id": str(channel_id)}, + ) async def _start_voice_guild(self, voice_key: VoiceKey, data: dict): """Start a voice context in a guild.""" user_id, guild_id = voice_key - channel_id = int(data['channel_id']) + channel_id = int(data["channel_id"]) existing_states = self.states[voice_key] channel_exists = any( @@ -183,11 +183,15 @@ class VoiceManager: if not channel_exists: await self._create_ctx_guild(guild_id, channel_id) - await self._lvsp_info_guild(guild_id, 'VST_CREATE', { - 'user_id': str(user_id), - 'guild_id': str(guild_id), - 'channel_id': str(channel_id), - }) + await self._lvsp_info_guild( + guild_id, + "VST_CREATE", + { + "user_id": str(user_id), + "guild_id": str(guild_id), + "channel_id": str(channel_id), + }, + ) async def create_state(self, voice_key: VoiceKey, data: dict): """Creates (or tries to create) a voice state. @@ -249,10 +253,13 @@ class VoiceManager: async def voice_server_list(self, region: str) -> List[dict]: """Get a list of voice server objects""" - rows = await self.app.db.fetch(""" + rows = await self.app.db.fetch( + """ SELECT hostname, last_health FROM voice_servers WHERE region_id = $1 - """, region) + """, + region, + ) return list(map(dict, rows)) diff --git a/litecord/voice/state.py b/litecord/voice/state.py index d5e8732..b3d8d31 100644 --- a/litecord/voice/state.py +++ b/litecord/voice/state.py @@ -23,6 +23,7 @@ from dataclasses import dataclass, asdict @dataclass class VoiceState: """Represents a voice state.""" + guild_id: int channel_id: int user_id: int @@ -55,7 +56,7 @@ class VoiceState: # a better approach would be actually using # the suppressed_by field for backend efficiency. - self_dict['suppress'] = user_id == self.suppressed_by - self_dict.pop('suppressed_by') + self_dict["suppress"] = user_id == self.suppressed_by + self_dict.pop("suppressed_by") return self_dict diff --git a/manage.py b/manage.py index ab61ec2..454f304 100755 --- a/manage.py +++ b/manage.py @@ -27,5 +27,5 @@ import config logging.basicConfig(level=logging.DEBUG) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main(config)) diff --git a/manage/__init__.py b/manage/__init__.py index ce49370..d21f555 100644 --- a/manage/__init__.py +++ b/manage/__init__.py @@ -16,4 +16,3 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ - diff --git a/manage/cmd/invites.py b/manage/cmd/invites.py index 3fc716a..05b3edc 100644 --- a/manage/cmd/invites.py +++ b/manage/cmd/invites.py @@ -26,7 +26,7 @@ ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits async def _gen_inv() -> str: """Generate an invite code""" - return ''.join(choice(ALPHABET) for _ in range(6)) + return "".join(choice(ALPHABET) for _ in range(6)) async def gen_inv(ctx) -> str: @@ -34,11 +34,14 @@ async def gen_inv(ctx) -> str: for _ in range(10): possible_inv = await _gen_inv() - created_at = await ctx.db.fetchval(""" + created_at = await ctx.db.fetchval( + """ SELECT created_at FROM instance_invites WHERE code = $1 - """, possible_inv) + """, + possible_inv, + ) if created_at is None: return possible_inv @@ -51,27 +54,32 @@ async def make_inv(ctx, args): max_uses = args.max_uses - await ctx.db.execute(""" + await ctx.db.execute( + """ INSERT INTO instance_invites (code, max_uses) VALUES ($1, $2) - """, code, max_uses) + """, + code, + max_uses, + ) - print(f'invite created with {max_uses} max uses', code) + print(f"invite created with {max_uses} max uses", code) async def list_invs(ctx, args): - rows = await ctx.db.fetch(""" + rows = await ctx.db.fetch( + """ SELECT code, created_at, uses, max_uses FROM instance_invites - """) + """ + ) - print(len(rows), 'invites') + print(len(rows), "invites") for row in rows: - max_uses = row['max_uses'] - delta = datetime.datetime.utcnow() - row['created_at'] - usage = ('infinite uses' if max_uses == -1 - else f'{row["uses"]} / {max_uses}') + max_uses = row["max_uses"] + delta = datetime.datetime.utcnow() - row["created_at"] + usage = "infinite uses" if max_uses == -1 else f'{row["uses"]} / {max_uses}' print(f'\t{row["code"]}, {usage}, made {delta} ago') @@ -79,40 +87,37 @@ async def list_invs(ctx, args): async def delete_inv(ctx, args): inv = args.invite_code - res = await ctx.db.execute(""" + res = await ctx.db.execute( + """ DELETE FROM instance_invites WHERE code = $1 - """, inv) + """, + inv, + ) - if res == 'DELETE 0': - print('NOT FOUND') + if res == "DELETE 0": + print("NOT FOUND") return - print('OK') + print("OK") def setup(subparser): - makeinv_parser = subparser.add_parser( - 'makeinv', - help='create an invite', - ) + makeinv_parser = subparser.add_parser("makeinv", help="create an invite") makeinv_parser.add_argument( - 'max_uses', nargs='?', type=int, default=-1, - help='Maximum amount of uses before the invite is unavailable', + "max_uses", + nargs="?", + type=int, + default=-1, + help="Maximum amount of uses before the invite is unavailable", ) makeinv_parser.set_defaults(func=make_inv) - listinv_parser = subparser.add_parser( - 'listinv', - help='list all invites', - ) + listinv_parser = subparser.add_parser("listinv", help="list all invites") listinv_parser.set_defaults(func=list_invs) - delinv_parser = subparser.add_parser( - 'delinv', - help='delete an invite', - ) - delinv_parser.add_argument('invite_code') + delinv_parser = subparser.add_parser("delinv", help="delete an invite") + delinv_parser.add_argument("invite_code") delinv_parser.set_defaults(func=delete_inv) diff --git a/manage/cmd/migration/__init__.py b/manage/cmd/migration/__init__.py index d685676..6ceffc6 100644 --- a/manage/cmd/migration/__init__.py +++ b/manage/cmd/migration/__init__.py @@ -19,4 +19,4 @@ along with this program. If not, see . from .command import setup as migration -__all__ = ['migration'] +__all__ = ["migration"] diff --git a/manage/cmd/migration/command.py b/manage/cmd/migration/command.py index 1f1f91f..003d675 100644 --- a/manage/cmd/migration/command.py +++ b/manage/cmd/migration/command.py @@ -32,18 +32,19 @@ from logbook import Logger log = Logger(__name__) -Migration = namedtuple('Migration', 'id name path') +Migration = namedtuple("Migration", "id name path") # line of change, 4 april 2019, at 1am (gmt+0) BREAK = datetime.datetime(2019, 4, 4, 1) # if a database has those tables, it ran 0_base.sql. -HAS_BASE = ['users', 'guilds', 'e'] +HAS_BASE = ["users", "guilds", "e"] @dataclass class MigrationContext: """Hold information about migration.""" + migration_folder: Path scripts: Dict[int, Migration] @@ -60,22 +61,21 @@ def make_migration_ctx() -> MigrationContext: script_folder = os.sep.join(script_path.split(os.sep)[:-1]) script_folder = Path(script_folder) - migration_folder = script_folder / 'scripts' + migration_folder = script_folder / "scripts" mctx = MigrationContext(migration_folder, {}) - for mig_path in migration_folder.glob('*.sql'): + for mig_path in migration_folder.glob("*.sql"): mig_path_str = str(mig_path) # extract migration script id and name - mig_filename = mig_path_str.split(os.sep)[-1].split('.')[0] - name_fragments = mig_filename.split('_') + mig_filename = mig_path_str.split(os.sep)[-1].split(".")[0] + name_fragments = mig_filename.split("_") mig_id = int(name_fragments[0]) - mig_name = '_'.join(name_fragments[1:]) + mig_name = "_".join(name_fragments[1:]) - mctx.scripts[mig_id] = Migration( - mig_id, mig_name, mig_path) + mctx.scripts[mig_id] = Migration(mig_id, mig_name, mig_path) return mctx @@ -83,7 +83,8 @@ def make_migration_ctx() -> MigrationContext: async def _ensure_changelog(app, ctx): # make sure we have the migration table up try: - await app.db.execute(""" + await app.db.execute( + """ CREATE TABLE migration_log ( change_num bigint NOT NULL, @@ -94,43 +95,56 @@ async def _ensure_changelog(app, ctx): PRIMARY KEY (change_num) ); - """) + """ + ) except asyncpg.DuplicateTableError: - log.debug('existing migration table') + log.debug("existing migration table") # NOTE: this is a migration breakage, # only applying to databases that had their first migration # before 4 april 2019 (more on BREAK) # if migration_log is empty, just assume this is new - first = await app.db.fetchval(""" + first = ( + await app.db.fetchval( + """ SELECT apply_ts FROM migration_log ORDER BY apply_ts ASC LIMIT 1 - """) or BREAK + """ + ) + or BREAK + ) if first < BREAK: - log.info('deleting migration_log due to migration structure change') + log.info("deleting migration_log due to migration structure change") await app.db.execute("DROP TABLE migration_log") await _ensure_changelog(app, ctx) async def _insert_log(app, migration_id: int, description) -> bool: try: - await app.db.execute(""" + await app.db.execute( + """ INSERT INTO migration_log (change_num, description) VALUES ($1, $2) - """, migration_id, description) + """, + migration_id, + description, + ) return True except asyncpg.UniqueViolationError: - log.warning('already inserted {}', migration_id) + log.warning("already inserted {}", migration_id) return False async def _delete_log(app, migration_id: int): - await app.db.execute(""" + await app.db.execute( + """ DELETE FROM migration_log WHERE change_num = $1 - """, migration_id) + """, + migration_id, + ) async def apply_migration(app, migration: Migration) -> bool: @@ -144,21 +158,20 @@ async def apply_migration(app, migration: Migration) -> bool: Returns a boolean signaling if this failed or not. """ - migration_sql = migration.path.read_text(encoding='utf-8') + migration_sql = migration.path.read_text(encoding="utf-8") - res = await _insert_log( - app, migration.id, f'migration: {migration.name}') + res = await _insert_log(app, migration.id, f"migration: {migration.name}") if not res: return False try: await app.db.execute(migration_sql) - log.info('applied {} {}', migration.id, migration.name) + log.info("applied {} {}", migration.id, migration.name) return True except: - log.exception('failed to run migration, rollbacking log') + log.exception("failed to run migration, rollbacking log") await _delete_log(app, migration.id) return False @@ -169,9 +182,11 @@ async def _check_base(app) -> bool: file.""" try: for table in HAS_BASE: - await app.db.execute(f""" + await app.db.execute( + f""" SELECT * FROM {table} LIMIT 0 - """) + """ + ) except asyncpg.UndefinedTableError: return False @@ -197,14 +212,16 @@ async def migrate_cmd(app, _args): has_base = await _check_base(app) # fetch latest local migration that has been run on this database - local_change = await app.db.fetchval(""" + local_change = await app.db.fetchval( + """ SELECT max(change_num) FROM migration_log - """) + """ + ) # if base exists, add it to logs, if not, apply (and add to logs) if has_base: - await _insert_log(app, 0, 'migration setup (from existing)') + await _insert_log(app, 0, "migration setup (from existing)") else: await apply_migration(app, ctx.scripts[0]) @@ -215,10 +232,10 @@ async def migrate_cmd(app, _args): local_change = local_change or 0 latest_change = ctx.latest - log.debug('local: {}, latest: {}', local_change, latest_change) + log.debug("local: {}, latest: {}", local_change, latest_change) if local_change == latest_change: - print('no changes to do, exiting') + print("no changes to do, exiting") return # we do local_change + 1 so we start from the @@ -227,15 +244,13 @@ async def migrate_cmd(app, _args): for idx in range(local_change + 1, latest_change + 1): migration = ctx.scripts.get(idx) - print('applying', migration.id, migration.name) + print("applying", migration.id, migration.name) await apply_migration(app, migration) def setup(subparser): migrate_parser = subparser.add_parser( - 'migrate', - help='Run migration tasks', - description=migrate_cmd.__doc__ + "migrate", help="Run migration tasks", description=migrate_cmd.__doc__ ) migrate_parser.set_defaults(func=migrate_cmd) diff --git a/manage/cmd/users.py b/manage/cmd/users.py index 500a9f6..f89466a 100644 --- a/manage/cmd/users.py +++ b/manage/cmd/users.py @@ -24,39 +24,51 @@ from litecord.enums import UserFlags async def find_user(username, discrim, ctx) -> int: """Get a user ID via the username/discrim pair.""" - return await ctx.db.fetchval(""" + return await ctx.db.fetchval( + """ SELECT id FROM users WHERE username = $1 AND discriminator = $2 - """, username, discrim) + """, + username, + discrim, + ) async def set_user_staff(user_id, ctx): """Give a single user staff status.""" - old_flags = await ctx.db.fetchval(""" + old_flags = await ctx.db.fetchval( + """ SELECT flags FROM users WHERE id = $1 - """, user_id) + """, + user_id, + ) new_flags = old_flags | UserFlags.staff - await ctx.db.execute(""" + await ctx.db.execute( + """ UPDATE users SET flags=$1 WHERE id = $2 - """, new_flags, user_id) + """, + new_flags, + user_id, + ) async def adduser(ctx, args): """Create a single user.""" - uid, _ = await create_user(args.username, args.email, - args.password, ctx.db, ctx.loop) + uid, _ = await create_user( + args.username, args.email, args.password, ctx.db, ctx.loop + ) user = await ctx.storage.get_user(uid) - print('created!') - print(f'\tuid: {uid}') + print("created!") + print(f"\tuid: {uid}") print(f'\tusername: {user["username"]}') print(f'\tdiscrim: {user["discriminator"]}') @@ -72,22 +84,26 @@ async def make_staff(ctx, args): uid = await find_user(args.username, args.discrim, ctx) if not uid: - return print('user not found') + return print("user not found") await set_user_staff(uid, ctx) - print('OK: set staff') + print("OK: set staff") + async def generate_bot_token(ctx, args): """Generate a token for specified bot.""" - password_hash = await ctx.db.fetchval(""" + password_hash = await ctx.db.fetchval( + """ SELECT password_hash FROM users WHERE id = $1 AND bot = 'true' - """, int(args.user_id)) + """, + int(args.user_id), + ) if not password_hash: - return print('cannot find a bot with specified id') + return print("cannot find a bot with specified id") print(make_token(args.user_id, password_hash)) @@ -97,7 +113,7 @@ async def del_user(ctx, args): uid = await find_user(args.username, args.discrim, ctx) if uid is None: - print('user not found') + print("user not found") return user = await ctx.storage.get_user(uid) @@ -106,57 +122,48 @@ async def del_user(ctx, args): print(f'\tuname: {user["username"]}') print(f'\tdiscrim: {user["discriminator"]}') - print('\n you sure you want to delete user? press Y (uppercase)') + print("\n you sure you want to delete user? press Y (uppercase)") confirm = input() - if confirm != 'Y': - print('not confirmed') + if confirm != "Y": + print("not confirmed") return await delete_user(uid, app_=ctx) - print('ok') + print("ok") def setup(subparser): - setup_test_parser = subparser.add_parser( - 'adduser', - help='create a user', - ) + setup_test_parser = subparser.add_parser("adduser", help="create a user") - setup_test_parser.add_argument( - 'username', help='username of the user') - setup_test_parser.add_argument( - 'email', help='email of the user') - setup_test_parser.add_argument( - 'password', help='password of the user') + setup_test_parser.add_argument("username", help="username of the user") + setup_test_parser.add_argument("email", help="email of the user") + setup_test_parser.add_argument("password", help="password of the user") setup_test_parser.set_defaults(func=adduser) staff_parser = subparser.add_parser( - 'make_staff', - help='make a user staff', - description=make_staff.__doc__ + "make_staff", help="make a user staff", description=make_staff.__doc__ ) - staff_parser.add_argument('username') - staff_parser.add_argument( - 'discrim', help='the discriminator of the user') + staff_parser.add_argument("username") + staff_parser.add_argument("discrim", help="the discriminator of the user") staff_parser.set_defaults(func=make_staff) - del_user_parser = subparser.add_parser( - 'deluser', help='delete a single user') + del_user_parser = subparser.add_parser("deluser", help="delete a single user") - del_user_parser.add_argument('username') - del_user_parser.add_argument('discrim') + del_user_parser.add_argument("username") + del_user_parser.add_argument("discrim") del_user_parser.set_defaults(func=del_user) token_parser = subparser.add_parser( - 'generate_token', - help='generate a token for specified bot', - description=generate_bot_token.__doc__) + "generate_token", + help="generate a token for specified bot", + description=generate_bot_token.__doc__, + ) - token_parser.add_argument('user_id') + token_parser.add_argument("user_id") token_parser.set_defaults(func=generate_bot_token) diff --git a/manage/main.py b/manage/main.py index d475f15..562d1f6 100644 --- a/manage/main.py +++ b/manage/main.py @@ -34,6 +34,7 @@ log = Logger(__name__) @dataclass class FakeApp: """Fake app instance.""" + config: dict db = None loop: asyncio.BaseEventLoop = None @@ -50,7 +51,7 @@ class FakeApp: def init_parser(): parser = argparse.ArgumentParser() - subparser = parser.add_subparsers(help='operations') + subparser = parser.add_subparsers(help="operations") migration(subparser) users.setup(subparser) @@ -78,12 +79,12 @@ def main(config): # only init app managers when we aren't migrating # as the managers require it # and the migrate command also sets the db up - if argv[1] != 'migrate': + if argv[1] != "migrate": init_app_managers(app, voice=False) args = parser.parse_args() loop.run_until_complete(args.func(app, args)) except Exception: - log.exception('error while running command') + log.exception("error while running command") finally: loop.run_until_complete(app.db.close()) diff --git a/run.py b/run.py index 4700d1f..d15dbc1 100644 --- a/run.py +++ b/run.py @@ -33,32 +33,51 @@ from aiohttp import ClientSession import config from litecord.blueprints import ( - gateway, auth, users, guilds, channels, webhooks, science, - voice, invites, relationships, dms, icons, nodeinfo, static, - attachments, dm_channels + gateway, + auth, + users, + guilds, + channels, + webhooks, + science, + voice, + invites, + relationships, + dms, + icons, + nodeinfo, + static, + attachments, + dm_channels, ) # those blueprints are separated from the "main" ones # for code readability if people want to dig through # the codebase. from litecord.blueprints.guild import ( - guild_roles, guild_members, guild_channels, guild_mod, - guild_emoji + guild_roles, + guild_members, + guild_channels, + guild_mod, + guild_emoji, ) from litecord.blueprints.channel import ( - channel_messages, channel_reactions, channel_pins + channel_messages, + channel_reactions, + channel_pins, ) -from litecord.blueprints.user import ( - user_settings, user_billing, fake_store -) +from litecord.blueprints.user import user_settings, user_billing, fake_store from litecord.blueprints.user.billing_job import payment_job from litecord.blueprints.admin_api import ( - voice as voice_admin, features as features_admin, - guilds as guilds_admin, users as users_admin, instance_invites + voice as voice_admin, + features as features_admin, + guilds as guilds_admin, + users as users_admin, + instance_invites, ) from litecord.blueprints.admin_api.voice import guild_region_check @@ -84,23 +103,23 @@ from litecord.utils import LitecordJSONEncoder # setup logbook handler = StreamHandler(sys.stdout, level=logbook.INFO) handler.push_application() -log = Logger('litecord.boot') +log = Logger("litecord.boot") redirect_logging() def make_app(): app = Quart(__name__) - app.config.from_object(f'config.{config.MODE}') - is_debug = app.config.get('DEBUG', False) + app.config.from_object(f"config.{config.MODE}") + is_debug = app.config.get("DEBUG", False) app.debug = is_debug if is_debug: - log.info('on debug') + log.info("on debug") handler.level = logbook.DEBUG app.logger.level = logbook.DEBUG # always keep websockets on INFO - logging.getLogger('websockets').setLevel(logbook.INFO) + logging.getLogger("websockets").setLevel(logbook.INFO) # use our custom json encoder for custom data types app.json_encoder = LitecordJSONEncoder @@ -112,51 +131,44 @@ def set_blueprints(app_): """Set the blueprints for a given app instance""" bps = { gateway: None, - auth: '/auth', - - users: '/users', - user_settings: '/users', - user_billing: '/users', - relationships: '/users', - - guilds: '/guilds', - guild_roles: '/guilds', - guild_members: '/guilds', - guild_channels: '/guilds', - guild_mod: '/guilds', - guild_emoji: '/guilds', - - channels: '/channels', - channel_messages: '/channels', - channel_reactions: '/channels', - channel_pins: '/channels', - + auth: "/auth", + users: "/users", + user_settings: "/users", + user_billing: "/users", + relationships: "/users", + guilds: "/guilds", + guild_roles: "/guilds", + guild_members: "/guilds", + guild_channels: "/guilds", + guild_mod: "/guilds", + guild_emoji: "/guilds", + channels: "/channels", + channel_messages: "/channels", + channel_reactions: "/channels", + channel_pins: "/channels", webhooks: None, science: None, - voice: '/voice', + voice: "/voice", invites: None, - dms: '/users', - dm_channels: '/channels', - + dms: "/users", + dm_channels: "/channels", fake_store: None, - icons: -1, attachments: -1, nodeinfo: -1, static: -1, - - voice_admin: '/admin/voice', - features_admin: '/admin/guilds', - guilds_admin: '/admin/guilds', - users_admin: '/admin/users', - instance_invites: '/admin/instance/invites' + voice_admin: "/admin/voice", + features_admin: "/admin/guilds", + guilds_admin: "/admin/guilds", + users_admin: "/admin/users", + instance_invites: "/admin/instance/invites", } for bp, suffix in bps.items(): url_prefix = f'/api/v6{suffix or ""}' if suffix == -1: - url_prefix = '' + url_prefix = "" app_.register_blueprint(bp, url_prefix=url_prefix) @@ -175,37 +187,35 @@ async def app_before_request(): @app.after_request async def app_after_request(resp): """Handle CORS headers.""" - origin = request.headers.get('Origin', '*') - resp.headers['Access-Control-Allow-Origin'] = origin - resp.headers['Access-Control-Allow-Headers'] = ( - '*, X-Super-Properties, ' - 'X-Fingerprint, ' - 'X-Context-Properties, ' - 'X-Failed-Requests, ' - 'X-Debug-Options, ' - 'Content-Type, ' - 'Authorization, ' - 'Origin, ' - 'If-None-Match' + origin = request.headers.get("Origin", "*") + resp.headers["Access-Control-Allow-Origin"] = origin + resp.headers["Access-Control-Allow-Headers"] = ( + "*, X-Super-Properties, " + "X-Fingerprint, " + "X-Context-Properties, " + "X-Failed-Requests, " + "X-Debug-Options, " + "Content-Type, " + "Authorization, " + "Origin, " + "If-None-Match" ) - resp.headers['Access-Control-Allow-Methods'] = \ - resp.headers.get('allow', '*') + resp.headers["Access-Control-Allow-Methods"] = resp.headers.get("allow", "*") return resp def _set_rtl_reset(bucket, resp): reset = bucket._window + bucket.second - precision = request.headers.get('x-ratelimit-precision', 'second') + precision = request.headers.get("x-ratelimit-precision", "second") - if precision == 'second': - resp.headers['X-RateLimit-Reset'] = str(round(reset)) - elif precision == 'millisecond': - resp.headers['X-RateLimit-Reset'] = str(reset) + if precision == "second": + resp.headers["X-RateLimit-Reset"] = str(round(reset)) + elif precision == "millisecond": + resp.headers["X-RateLimit-Reset"] = str(reset) else: - resp.headers['X-RateLimit-Reset'] = ( - 'Invalid X-RateLimit-Precision, ' - 'valid options are (second, millisecond)' + resp.headers["X-RateLimit-Reset"] = ( + "Invalid X-RateLimit-Precision, " "valid options are (second, millisecond)" ) @@ -218,15 +228,15 @@ async def app_set_ratelimit_headers(resp): if bucket is None: raise AttributeError() - resp.headers['X-RateLimit-Limit'] = str(bucket.requests) - resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens) - resp.headers['X-RateLimit-Global'] = str(request.bucket_global).lower() + resp.headers["X-RateLimit-Limit"] = str(bucket.requests) + resp.headers["X-RateLimit-Remaining"] = str(bucket._tokens) + resp.headers["X-RateLimit-Global"] = str(request.bucket_global).lower() _set_rtl_reset(bucket, resp) # only add Retry-After if we actually hit a ratelimit retry_after = request.retry_after if request.retry_after: - resp.headers['Retry-After'] = str(retry_after) + resp.headers["Retry-After"] = str(retry_after) except AttributeError: pass @@ -238,8 +248,8 @@ async def init_app_db(app_): Also spawns the job scheduler. """ - log.info('db connect') - app_.db = await asyncpg.create_pool(**app.config['POSTGRES']) + log.info("db connect") + app_.db = await asyncpg.create_pool(**app.config["POSTGRES"]) app_.sched = JobManager() @@ -247,7 +257,7 @@ async def init_app_db(app_): def init_app_managers(app_, *, voice=True): """Initialize singleton classes.""" app_.loop = asyncio.get_event_loop() - app_.ratelimiter = RatelimitManager(app_.config.get('_testing')) + app_.ratelimiter = RatelimitManager(app_.config.get("_testing")) app_.state_manager = StateManager() app_.storage = Storage(app_) @@ -274,15 +284,12 @@ async def api_index(app_): to_find = {} found = [] - with open('discord_endpoints.txt') as fd: + with open("discord_endpoints.txt") as fd: for line in fd.readlines(): - components = line.split(' ') - components = list(filter( - bool, - components - )) + components = line.split(" ") + components = list(filter(bool, components)) name, method, path = components - path = f'/api/v6{path.strip()}' + path = f"/api/v6{path.strip()}" method = method.strip() to_find[(path, method)] = name @@ -290,17 +297,17 @@ async def api_index(app_): path = rule.rule # convert the path to the discord_endpoints file's style - path = path.replace('_', '.') - path = path.replace('<', '{') - path = path.replace('>', '}') - path = path.replace('int:', '') + path = path.replace("_", ".") + path = path.replace("<", "{") + path = path.replace(">", "}") + path = path.replace("int:", "") # change our parameters into user.id - path = path.replace('member.id', 'user.id') - path = path.replace('banned.id', 'user.id') - path = path.replace('target.id', 'user.id') - path = path.replace('other.id', 'user.id') - path = path.replace('peer.id', 'user.id') + path = path.replace("member.id", "user.id") + path = path.replace("banned.id", "user.id") + path = path.replace("target.id", "user.id") + path = path.replace("other.id", "user.id") + path = path.replace("peer.id", "user.id") methods = rule.methods @@ -317,10 +324,15 @@ async def api_index(app_): percentage = (len(found) / len(api)) * 100 percentage = round(percentage, 2) - log.debug('API compliance: {} out of {} ({} missing), {}% compliant', - len(found), len(api), len(missing), percentage) + log.debug( + "API compliance: {} out of {} ({} missing), {}% compliant", + len(found), + len(api), + len(missing), + percentage, + ) - log.debug('missing: {}', missing) + log.debug("missing: {}", missing) async def post_app_start(app_): @@ -332,7 +344,7 @@ async def post_app_start(app_): def start_websocket(host, port, ws_handler) -> asyncio.Future: """Start a websocket. Returns the websocket future""" - log.info(f'starting websocket at {host} {port}') + log.info(f"starting websocket at {host} {port}") async def _wrapper(ws, url): # We wrap the main websocket_handler @@ -348,7 +360,7 @@ async def app_before_serving(): Also sets up the websocket handlers. """ - log.info('opening db') + log.info("opening db") await init_app_db(app) app.session = ClientSession() @@ -359,8 +371,7 @@ async def app_before_serving(): # start gateway websocket # voice websocket is handled by the voice server ws_fut = start_websocket( - app.config['WS_HOST'], app.config['WS_PORT'], - websocket_handler + app.config["WS_HOST"], app.config["WS_PORT"], websocket_handler ) await ws_fut @@ -379,7 +390,7 @@ async def app_after_serving(): app.sched.close() - log.info('closing db') + log.info("closing db") await app.db.close() @@ -391,24 +402,23 @@ async def handle_litecord_err(err): ejson = {} try: - ejson['code'] = err.error_code + ejson["code"] = err.error_code except AttributeError: pass - log.warning('error: {} {!r}', err.status_code, err.message) + log.warning("error: {} {!r}", err.status_code, err.message) - return jsonify({ - 'error': True, - 'status': err.status_code, - 'message': err.message, - **ejson - }), err.status_code + return ( + jsonify( + {"error": True, "status": err.status_code, "message": err.message, **ejson} + ), + err.status_code, + ) @app.errorhandler(500) async def handle_500(err): - return jsonify({ - 'error': True, - 'message': repr(err), - 'internal_server_error': True, - }), 500 + return ( + jsonify({"error": True, "message": repr(err), "internal_server_error": True}), + 500, + ) diff --git a/setup.py b/setup.py index 741cede..ee848cf 100644 --- a/setup.py +++ b/setup.py @@ -20,10 +20,10 @@ along with this program. If not, see . from setuptools import setup setup( - name='litecord', - version='0.0.1', - description='Implementation of the Discord API', - url='https://litecord.top', - author='Luna Mendes', - python_requires='>=3.7' + name="litecord", + version="0.0.1", + description="Implementation of the Discord API", + url="https://litecord.top", + author="Luna Mendes", + python_requires=">=3.7", ) diff --git a/tests/common.py b/tests/common.py index cc0a9b3..4e573a7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -19,13 +19,15 @@ along with this program. If not, see . import secrets + def email() -> str: - return f'{secrets.token_hex(5)}@{secrets.token_hex(5)}.com' + return f"{secrets.token_hex(5)}@{secrets.token_hex(5)}.com" class TestClient: """Test client that wraps pytest-sanic's TestClient and a test user and adds authorization headers to test requests.""" + def __init__(self, test_cli, test_user): self.cli = test_cli self.app = test_cli.app @@ -37,31 +39,31 @@ class TestClient: def _inject_auth(self, kwargs: dict) -> list: """Inject the test user's API key into the test request before passing the request on to the underlying TestClient.""" - headers = kwargs.get('headers', {}) - headers['authorization'] = self.user['token'] + headers = kwargs.get("headers", {}) + headers["authorization"] = self.user["token"] return headers async def get(self, *args, **kwargs): """Send a GET request.""" - kwargs['headers'] = self._inject_auth(kwargs) + kwargs["headers"] = self._inject_auth(kwargs) return await self.cli.get(*args, **kwargs) async def post(self, *args, **kwargs): """Send a POST request.""" - kwargs['headers'] = self._inject_auth(kwargs) + kwargs["headers"] = self._inject_auth(kwargs) return await self.cli.post(*args, **kwargs) async def put(self, *args, **kwargs): """Send a POST request.""" - kwargs['headers'] = self._inject_auth(kwargs) + kwargs["headers"] = self._inject_auth(kwargs) return await self.cli.put(*args, **kwargs) async def patch(self, *args, **kwargs): """Send a PATCH request.""" - kwargs['headers'] = self._inject_auth(kwargs) + kwargs["headers"] = self._inject_auth(kwargs) return await self.cli.patch(*args, **kwargs) async def delete(self, *args, **kwargs): """Send a DELETE request.""" - kwargs['headers'] = self._inject_auth(kwargs) + kwargs["headers"] = self._inject_auth(kwargs) return await self.cli.delete(*args, **kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 7d88389..ed9e294 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,22 +36,22 @@ from litecord.blueprints.auth import make_token from litecord.blueprints.users import delete_user -@pytest.fixture(name='app') +@pytest.fixture(name="app") def _test_app(unused_tcp_port, event_loop): set_blueprints(main_app) - main_app.config['_testing'] = True + main_app.config["_testing"] = True # reassign an unused tcp port for websockets # since the config might give a used one. ws_port = unused_tcp_port - main_app.config['IS_SSL'] = False - main_app.config['WS_PORT'] = ws_port - main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}' + main_app.config["IS_SSL"] = False + main_app.config["WS_PORT"] = ws_port + main_app.config["WEBSOCKET_URL"] = f"localhost:{ws_port}" # testing user creations requires hardcoding this to true # on testing - main_app.config['REGISTRATIONS'] = True + main_app.config["REGISTRATIONS"] = True # make sure we're calling the before_serving hooks event_loop.run_until_complete(main_app.startup()) @@ -63,11 +63,12 @@ def _test_app(unused_tcp_port, event_loop): event_loop.run_until_complete(main_app.shutdown()) -@pytest.fixture(name='test_cli') +@pytest.fixture(name="test_cli") def _test_cli(app): """Give a test client.""" return app.test_client() + # code shamelessly stolen from my elixire mr # https://gitlab.com/elixire/elixire/merge_requests/52 async def _user_fixture_setup(app): @@ -76,21 +77,26 @@ async def _user_fixture_setup(app): user_email = email() user_id, pwd_hash = await create_user( - username, user_email, password, app.db, app.loop) + username, user_email, password, app.db, app.loop + ) # generate a token for api access user_token = make_token(user_id, pwd_hash) - return {'id': user_id, 'token': user_token, - 'email': user_email, 'username': username, - 'password': password} + return { + "id": user_id, + "token": user_token, + "email": user_email, + "username": username, + "password": password, + } async def _user_fixture_teardown(app, udata: dict): - await delete_user(udata['id'], app_=app) + await delete_user(udata["id"], app_=app) -@pytest.fixture(name='test_user') +@pytest.fixture(name="test_user") async def test_user_fixture(app): """Yield a randomly generated test user.""" udata = await _user_fixture_setup(app) @@ -113,18 +119,25 @@ async def test_cli_staff(test_cli): # same test_cli_user, which isn't acceptable. app = test_cli.app test_user = await _user_fixture_setup(app) - user_id = test_user['id'] + user_id = test_user["id"] # copied from manage.cmd.users.set_user_staff. - old_flags = await app.db.fetchval(""" + old_flags = await app.db.fetchval( + """ SELECT flags FROM users WHERE id = $1 - """, user_id) + """, + user_id, + ) new_flags = old_flags | UserFlags.staff - await app.db.execute(""" + await app.db.execute( + """ UPDATE users SET flags = $1 WHERE id = $2 - """, new_flags, user_id) + """, + new_flags, + user_id, + ) yield TestClient(test_cli, test_user) await _user_fixture_teardown(test_cli.app, test_user) diff --git a/tests/test_admin_api/test_guilds.py b/tests/test_admin_api/test_guilds.py index d1156e0..b6619e7 100644 --- a/tests/test_admin_api/test_guilds.py +++ b/tests/test_admin_api/test_guilds.py @@ -24,24 +24,24 @@ import pytest from litecord.blueprints.guilds import delete_guild from litecord.errors import GuildNotFound + async def _create_guild(test_cli_staff): genned_name = secrets.token_hex(6) - resp = await test_cli_staff.post('/api/v6/guilds', json={ - 'name': genned_name, - 'region': None - }) + resp = await test_cli_staff.post( + "/api/v6/guilds", json={"name": genned_name, "region": None} + ) assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, dict) - assert rjson['name'] == genned_name + assert rjson["name"] == genned_name return rjson async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False): - resp = await test_cli_staff.get(f'/api/v6/admin/guilds/{guild_id}') + resp = await test_cli_staff.get(f"/api/v6/admin/guilds/{guild_id}") if ret_early: return resp @@ -49,7 +49,7 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False): assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, dict) - assert rjson['id'] == guild_id + assert rjson["id"] == guild_id return rjson @@ -58,7 +58,7 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False): async def test_guild_fetch(test_cli_staff): """Test the creation and fetching of a guild via the Admin API.""" rjson = await _create_guild(test_cli_staff) - guild_id = rjson['id'] + guild_id = rjson["id"] try: await _fetch_guild(test_cli_staff, guild_id) @@ -70,8 +70,8 @@ async def test_guild_fetch(test_cli_staff): async def test_guild_update(test_cli_staff): """Test the update of a guild via the Admin API.""" rjson = await _create_guild(test_cli_staff) - guild_id = rjson['id'] - assert not rjson['unavailable'] + guild_id = rjson["id"] + assert not rjson["unavailable"] try: # I believe setting up an entire gateway client registered to the guild @@ -79,19 +79,17 @@ async def test_guild_update(test_cli_staff): # testing them. Yes, I know its a bad idea, but if someone has an easier # way to write that, do send an MR. resp = await test_cli_staff.patch( - f'/api/v6/admin/guilds/{guild_id}', - json={ - 'unavailable': True - }) + f"/api/v6/admin/guilds/{guild_id}", json={"unavailable": True} + ) assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, dict) - assert rjson['id'] == guild_id - assert rjson['unavailable'] + assert rjson["id"] == guild_id + assert rjson["unavailable"] rjson = await _fetch_guild(test_cli_staff, guild_id) - assert rjson['unavailable'] + assert rjson["unavailable"] finally: await delete_guild(int(guild_id), app_=test_cli_staff.app) @@ -100,20 +98,19 @@ async def test_guild_update(test_cli_staff): async def test_guild_delete(test_cli_staff): """Test the update of a guild via the Admin API.""" rjson = await _create_guild(test_cli_staff) - guild_id = rjson['id'] + guild_id = rjson["id"] try: - resp = await test_cli_staff.delete(f'/api/v6/admin/guilds/{guild_id}') + resp = await test_cli_staff.delete(f"/api/v6/admin/guilds/{guild_id}") assert resp.status_code == 204 - resp = await _fetch_guild( - test_cli_staff, guild_id, ret_early=True) + resp = await _fetch_guild(test_cli_staff, guild_id, ret_early=True) assert resp.status_code == 404 rjson = await resp.json assert isinstance(rjson, dict) - assert rjson['error'] - assert rjson['code'] == GuildNotFound.error_code + assert rjson["error"] + assert rjson["code"] == GuildNotFound.error_code finally: await delete_guild(int(guild_id), app_=test_cli_staff.app) diff --git a/tests/test_admin_api/test_instance_invites.py b/tests/test_admin_api/test_instance_invites.py index 9bdeba5..e9149c6 100644 --- a/tests/test_admin_api/test_instance_invites.py +++ b/tests/test_admin_api/test_instance_invites.py @@ -21,7 +21,7 @@ import pytest async def _get_invs(test_cli): - resp = await test_cli.get('/api/v6/admin/instance/invites') + resp = await test_cli.get("/api/v6/admin/instance/invites") assert resp.status_code == 200 rjson = await resp.json @@ -39,7 +39,7 @@ async def test_get_invites(test_cli_staff): async def test_inv_delete_invalid(test_cli_staff): """Test errors happen when trying to delete a non-existing instance invite.""" - resp = await test_cli_staff.delete('/api/v6/admin/instance/invites/aaaaaa') + resp = await test_cli_staff.delete("/api/v6/admin/instance/invites/aaaaaa") assert resp.status_code == 404 @@ -48,21 +48,20 @@ async def test_inv_delete_invalid(test_cli_staff): async def test_create_invite(test_cli_staff): """Test the creation of an instance invite, then listing it, then deleting it.""" - resp = await test_cli_staff.put('/api/v6/admin/instance/invites', json={ - 'max_uses': 1 - }) + resp = await test_cli_staff.put( + "/api/v6/admin/instance/invites", json={"max_uses": 1} + ) assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, dict) - code = rjson['code'] + code = rjson["code"] # assert that the invite is in the list invites = await _get_invs(test_cli_staff) - assert any(inv['code'] == code for inv in invites) + assert any(inv["code"] == code for inv in invites) # delete it, and assert it worked - resp = await test_cli_staff.delete( - f'/api/v6/admin/instance/invites/{code}') + resp = await test_cli_staff.delete(f"/api/v6/admin/instance/invites/{code}") assert resp.status_code == 204 diff --git a/tests/test_admin_api/test_users.py b/tests/test_admin_api/test_users.py index a78acf7..8b2e738 100644 --- a/tests/test_admin_api/test_users.py +++ b/tests/test_admin_api/test_users.py @@ -24,20 +24,16 @@ import pytest from litecord.enums import UserFlags -async def _search(test_cli, *, username='', discrim=''): - query_string = { - 'username': username, - 'discriminator': discrim - } +async def _search(test_cli, *, username="", discrim=""): + query_string = {"username": username, "discriminator": discrim} - return await test_cli.get('/api/v6/admin/users', query_string=query_string) + return await test_cli.get("/api/v6/admin/users", query_string=query_string) @pytest.mark.asyncio async def test_list_users(test_cli_staff): """Try to list as many users as possible.""" - resp = await _search( - test_cli_staff, username=test_cli_staff.user['username']) + resp = await _search(test_cli_staff, username=test_cli_staff.user["username"]) assert resp.status_code == 200 rjson = await resp.json @@ -48,36 +44,42 @@ async def test_list_users(test_cli_staff): async def _setup_user(test_cli) -> dict: genned = secrets.token_hex(7) - resp = await test_cli.post('/api/v6/admin/users', json={ - 'username': genned, - 'email': f'{genned}@{genned}.com', - 'password': genned, - }) + resp = await test_cli.post( + "/api/v6/admin/users", + json={ + "username": genned, + "email": f"{genned}@{genned}.com", + "password": genned, + }, + ) assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, dict) - assert rjson['username'] == genned + assert rjson["username"] == genned return rjson async def _del_user(test_cli, user_id): """Delete a user.""" - resp = await test_cli.delete(f'/api/v6/admin/users/{user_id}') + resp = await test_cli.delete(f"/api/v6/admin/users/{user_id}") assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, dict) - assert rjson['new']['id'] == user_id - assert rjson['old']['id'] == rjson['new']['id'] + assert rjson["new"]["id"] == user_id + assert rjson["old"]["id"] == rjson["new"]["id"] # delete the original record since the DELETE endpoint will just # replace the user by a "Deleted User ", and we don't want # to have obsolete users filling up our db every time we run tests - await test_cli.app.db.execute(""" + await test_cli.app.db.execute( + """ DELETE FROM users WHERE id = $1 - """, int(user_id)) + """, + int(user_id), + ) @pytest.mark.asyncio @@ -85,8 +87,8 @@ async def test_create_delete(test_cli_staff): """Create a user. Then delete them.""" rjson = await _setup_user(test_cli_staff) - genned = rjson['username'] - genned_uid = rjson['id'] + genned = rjson["username"] + genned_uid = rjson["id"] try: # check if side-effects went through with a search @@ -95,7 +97,7 @@ async def test_create_delete(test_cli_staff): assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, list) - assert rjson[0]['id'] == genned_uid + assert rjson[0]["id"] == genned_uid finally: await _del_user(test_cli_staff, genned_uid) @@ -105,22 +107,20 @@ async def test_user_update(test_cli_staff): """Test user update.""" rjson = await _setup_user(test_cli_staff) - user_id = rjson['id'] + user_id = rjson["id"] # test update try: # set them as partner flag resp = await test_cli_staff.patch( - f'/api/v6/admin/users/{user_id}', - json={ - 'flags': UserFlags.partner, - }) + f"/api/v6/admin/users/{user_id}", json={"flags": UserFlags.partner} + ) assert resp.status_code == 200 rjson = await resp.json - assert rjson['id'] == user_id - assert rjson['flags'] == UserFlags.partner + assert rjson["id"] == user_id + assert rjson["flags"] == UserFlags.partner # TODO: maybe we can check for side effects by fetching the # user manually too... diff --git a/tests/test_embeds.py b/tests/test_embeds.py index cbf161c..5720243 100644 --- a/tests/test_embeds.py +++ b/tests/test_embeds.py @@ -21,9 +21,11 @@ from litecord.schemas import validate from litecord.embed.schemas import EMBED_OBJECT from litecord.embed.sanitizer import path_exists + def validate_embed(embed): return validate(embed, EMBED_OBJECT) + def valid(embed: dict): try: validate_embed(embed) @@ -31,6 +33,7 @@ def valid(embed: dict): except: return False + def invalid(embed): try: validate_embed(embed) @@ -44,66 +47,48 @@ def test_empty_embed(): def test_basic_embed(): - assert valid({ - 'title': 'test', - 'description': 'acab', - 'url': 'https://www.w3.org', - 'color': 123 - }) + assert valid( + { + "title": "test", + "description": "acab", + "url": "https://www.w3.org", + "color": 123, + } + ) def test_footer_embed(): - assert invalid({ - 'footer': {} - }) + assert invalid({"footer": {}}) + + assert valid({"title": "test", "footer": {"text": "abcdef"}}) - assert valid({ - 'title': 'test', - 'footer': { - 'text': 'abcdef' - } - }) def test_image(): - assert invalid({ - 'image': {} - }) + assert invalid({"image": {}}) + + assert valid({"image": {"url": "https://www.w3.org"}}) - assert valid({ - 'image': { - 'url': 'https://www.w3.org' - } - }) def test_author(): - assert invalid({ - 'author': { - 'name': '' - } - }) + assert invalid({"author": {"name": ""}}) + + assert valid({"author": {"name": "abcdef"}}) - assert valid({ - 'author': { - 'name': 'abcdef' - } - }) def test_fields(): - assert valid({ - 'fields': [ - {'name': 'a', 'value': 'b'}, - {'name': 'c', 'value': 'd', 'inline': False}, - ] - }) + assert valid( + { + "fields": [ + {"name": "a", "value": "b"}, + {"name": "c", "value": "d", "inline": False}, + ] + } + ) - assert invalid({ - 'fields': [ - {'name': 'a'}, - ] - }) + assert invalid({"fields": [{"name": "a"}]}) def test_path_exists(): """Test the path_exists() function for embed sanitization.""" - assert path_exists({'a': {'b': 2}}, 'a.b') - assert not path_exists({'a': 'b'}, 'a.b') + assert path_exists({"a": {"b": 2}}, "a.b") + assert not path_exists({"a": "b"}, "a.b") diff --git a/tests/test_gateway.py b/tests/test_gateway.py index a160b88..073c1e0 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -19,6 +19,7 @@ along with this program. If not, see . import sys import os + sys.path.append(os.getcwd()) import pytest @@ -27,28 +28,28 @@ import pytest @pytest.mark.asyncio async def test_gw(test_cli): """Test if the gateway route works.""" - resp = await test_cli.get('/api/v6/gateway') + resp = await test_cli.get("/api/v6/gateway") assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, dict) - assert 'url' in rjson - assert isinstance(rjson['url'], str) + assert "url" in rjson + assert isinstance(rjson["url"], str) @pytest.mark.asyncio async def test_gw_bot(test_cli_user): """Test the Get Bot Gateway route""" - resp = await test_cli_user.get('/api/v6/gateway/bot') + resp = await test_cli_user.get("/api/v6/gateway/bot") assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, dict) - assert isinstance(rjson['url'], str) - assert isinstance(rjson['shards'], int) - assert 'session_start_limit' in rjson + assert isinstance(rjson["url"], str) + assert isinstance(rjson["shards"], int) + assert "session_start_limit" in rjson - ssl = rjson['session_start_limit'] - assert isinstance(ssl['total'], int) - assert isinstance(ssl['remaining'], int) - assert isinstance(ssl['reset_after'], int) + ssl = rjson["session_start_limit"] + assert isinstance(ssl["total"], int) + assert isinstance(ssl["remaining"], int) + assert isinstance(ssl["reset_after"], int) diff --git a/tests/test_guild.py b/tests/test_guild.py index c91e676..b6739f6 100644 --- a/tests/test_guild.py +++ b/tests/test_guild.py @@ -31,25 +31,24 @@ async def test_guild_create(test_cli_user): g_name = secrets.token_hex(5) # stage 1: create - resp = await test_cli_user.post('/api/v6/guilds', json={ - 'name': g_name, - 'region': None, - }) + resp = await test_cli_user.post( + "/api/v6/guilds", json={"name": g_name, "region": None} + ) assert resp.status_code == 200 rjson = await resp.json # we won't assert a full guild object. - assert isinstance(rjson['id'], str) - assert isinstance(rjson['owner_id'], str) - assert isinstance(rjson['name'], str) - assert rjson['name'] == g_name + assert isinstance(rjson["id"], str) + assert isinstance(rjson["owner_id"], str) + assert isinstance(rjson["name"], str) + assert rjson["name"] == g_name created = rjson - guild_id = created['id'] + guild_id = created["id"] # stage 2: test - resp = await test_cli_user.get('/api/v6/users/@me/guilds') + resp = await test_cli_user.get("/api/v6/users/@me/guilds") assert resp.status_code == 200 rjson = await resp.json @@ -62,23 +61,20 @@ async def test_guild_create(test_cli_user): for guild in rjson: assert isinstance(guild, dict) - assert isinstance(guild['id'], str) - assert isinstance(guild['name'], str) - assert isinstance(guild['owner'], bool) - assert guild['icon'] is None or isinstance(guild['icon'], str) + assert isinstance(guild["id"], str) + assert isinstance(guild["name"], str) + assert isinstance(guild["owner"], bool) + assert guild["icon"] is None or isinstance(guild["icon"], str) try: - our_guild = next(filter( - lambda guild: guild['id'] == guild_id, - rjson - )) + our_guild = next(filter(lambda guild: guild["id"] == guild_id, rjson)) except StopIteration: - raise Exception('created guild not found in user guild list') + raise Exception("created guild not found in user guild list") - assert our_guild['id'] == created['id'] - assert our_guild['name'] == created['name'] + assert our_guild["id"] == created["id"] + assert our_guild["name"] == created["name"] # stage 3: deletion - resp = await test_cli_user.delete(f'/api/v6/guilds/{guild_id}') + resp = await test_cli_user.delete(f"/api/v6/guilds/{guild_id}") assert resp.status_code == 204 diff --git a/tests/test_main.py b/tests/test_main.py index 3dc7a6a..729b37d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -23,5 +23,5 @@ import pytest @pytest.mark.asyncio async def test_index(test_cli): """Test if the main index page works.""" - resp = await test_cli.get('/') + resp = await test_cli.get("/") assert resp.status_code == 200 diff --git a/tests/test_no_tracking.py b/tests/test_no_tracking.py index efdac93..aefc4f5 100644 --- a/tests/test_no_tracking.py +++ b/tests/test_no_tracking.py @@ -23,14 +23,14 @@ import pytest @pytest.mark.asyncio async def test_science_empty(test_cli): """Test that the science route gives nothing.""" - resp = await test_cli.post('/api/v6/science') + resp = await test_cli.post("/api/v6/science") assert resp.status_code == 204 @pytest.mark.asyncio async def test_harvest_empty(test_cli): """test that the harvest route is empty""" - resp = await test_cli.get('/api/v6/users/@me/harvest') + resp = await test_cli.get("/api/v6/users/@me/harvest") assert resp.status_code == 204 @@ -38,12 +38,12 @@ async def test_harvest_empty(test_cli): async def test_consent_non_consenting(test_cli_user): """Test the consent route to see if we're still on a non-consent status regarding data collection.""" - resp = await test_cli_user.get('/api/v6/users/@me/consent') + resp = await test_cli_user.get("/api/v6/users/@me/consent") assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, dict) # assert that we did not consent to those - assert not rjson['usage_statistics']['consented'] - assert not rjson['personalization']['consented'] + assert not rjson["usage_statistics"]["consented"] + assert not rjson["personalization"]["consented"] diff --git a/tests/test_ratelimits.py b/tests/test_ratelimits.py index 170ed2f..bd675d0 100644 --- a/tests/test_ratelimits.py +++ b/tests/test_ratelimits.py @@ -19,6 +19,7 @@ along with this program. If not, see . import sys import os + sys.path.append(os.getcwd()) import pytest @@ -38,10 +39,10 @@ def test_ratelimit(): @pytest.mark.asyncio async def test_ratelimit_headers(test_cli): """Test if the basic ratelimit headers are sent.""" - resp = await test_cli.get('/api/v6/gateway') + resp = await test_cli.get("/api/v6/gateway") assert resp.status_code == 200 hdrs = resp.headers - assert 'X-RateLimit-Limit' in hdrs - assert 'X-RateLimit-Remaining' in hdrs - assert 'X-RateLimit-Reset' in hdrs - assert 'X-RateLimit-Global' in hdrs + assert "X-RateLimit-Limit" in hdrs + assert "X-RateLimit-Remaining" in hdrs + assert "X-RateLimit-Reset" in hdrs + assert "X-RateLimit-Global" in hdrs diff --git a/tests/test_user.py b/tests/test_user.py index a83de77..0caf14b 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -23,24 +23,24 @@ import secrets @pytest.mark.asyncio async def test_get_me(test_cli_user): - resp = await test_cli_user.get('/api/v6/users/@me') + resp = await test_cli_user.get("/api/v6/users/@me") assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, dict) # incomplete user assertions, but should be enough - assert isinstance(rjson['id'], str) - assert isinstance(rjson['username'], str) - assert isinstance(rjson['discriminator'], str) - assert rjson['avatar'] is None or isinstance(rjson['avatar'], str) - assert isinstance(rjson['flags'], int) - assert isinstance(rjson['bot'], bool) + assert isinstance(rjson["id"], str) + assert isinstance(rjson["username"], str) + assert isinstance(rjson["discriminator"], str) + assert rjson["avatar"] is None or isinstance(rjson["avatar"], str) + assert isinstance(rjson["flags"], int) + assert isinstance(rjson["bot"], bool) @pytest.mark.asyncio async def test_get_me_guilds(test_cli_user): - resp = await test_cli_user.get('/api/v6/users/@me/guilds') + resp = await test_cli_user.get("/api/v6/users/@me/guilds") assert resp.status_code == 200 rjson = await resp.json @@ -49,17 +49,16 @@ async def test_get_me_guilds(test_cli_user): @pytest.mark.asyncio async def test_get_profile_self(test_cli_user): - user_id = test_cli_user.user['id'] - resp = await test_cli_user.get(f'/api/v6/users/{user_id}/profile') + user_id = test_cli_user.user["id"] + resp = await test_cli_user.get(f"/api/v6/users/{user_id}/profile") assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, dict) - assert isinstance(rjson['user'], dict) - assert isinstance(rjson['connected_accounts'], list) - assert (rjson['premium_since'] is None - or isinstance(rjson['premium_since'], str)) - assert isinstance(rjson['mutual_guilds'], list) + assert isinstance(rjson["user"], dict) + assert isinstance(rjson["connected_accounts"], list) + assert rjson["premium_since"] is None or isinstance(rjson["premium_since"], str) + assert isinstance(rjson["mutual_guilds"], list) @pytest.mark.asyncio @@ -67,39 +66,39 @@ async def test_create_user(test_cli): """Test the creation and deletion of a user.""" username = secrets.token_hex(4) _email = secrets.token_hex(5) - email = f'{_email}@{_email}.com' + email = f"{_email}@{_email}.com" password = secrets.token_hex(6) - resp = await test_cli.post('/api/v6/auth/register', json={ - 'username': username, - 'email': email, - 'password': password - }) + resp = await test_cli.post( + "/api/v6/auth/register", + json={"username": username, "email": email, "password": password}, + ) assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, dict) - token = rjson['token'] + token = rjson["token"] assert isinstance(token, str) - resp = await test_cli.get('/api/v6/users/@me', headers={ - 'Authorization': token, - }) + resp = await test_cli.get("/api/v6/users/@me", headers={"Authorization": token}) assert resp.status_code == 200 rjson = await resp.json - assert rjson['username'] == username - assert rjson['email'] == email + assert rjson["username"] == username + assert rjson["email"] == email - resp = await test_cli.post('/api/v6/users/@me/delete', headers={ - 'Authorization': token, - }, json={ - 'password': password - }) + resp = await test_cli.post( + "/api/v6/users/@me/delete", + headers={"Authorization": token}, + json={"password": password}, + ) assert resp.status_code == 204 - await test_cli.app.db.execute(""" + await test_cli.app.db.execute( + """ DELETE FROM users WHERE id = $1 - """, int(rjson['id'])) + """, + int(rjson["id"]), + ) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index dc31235..6402f21 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -42,21 +42,18 @@ async def _json_send(conn, data): async def _json_send_op(conn, opcode, data=None): - await _json_send(conn, { - 'op': opcode, - 'd': data - }) + await _json_send(conn, {"op": opcode, "d": data}) async def _close(conn): - await conn.close(1000, 'test end') + await conn.close(1000, "test end") async def get_gw(test_cli) -> str: """Get the Gateway URL.""" - gw_resp = await test_cli.get('/api/v6/gateway') + gw_resp = await test_cli.get("/api/v6/gateway") gw_json = await gw_resp.json - return gw_json['url'] + return gw_json["url"] async def gw_start(test_cli, *, etf=False): @@ -64,7 +61,7 @@ async def gw_start(test_cli, *, etf=False): gw_url = await get_gw(test_cli) if etf: - gw_url = f'{gw_url}?encoding=etf' + gw_url = f"{gw_url}?encoding=etf" return await websockets.connect(gw_url) @@ -76,11 +73,11 @@ async def test_gw(test_cli): conn = await gw_start(test_cli) hello = await _json(conn) - assert hello['op'] == OP.HELLO + assert hello["op"] == OP.HELLO - assert isinstance(hello['d'], dict) - assert isinstance(hello['d']['heartbeat_interval'], int) - assert isinstance(hello['d']['_trace'], list) + assert isinstance(hello["d"], dict) + assert isinstance(hello["d"]["heartbeat_interval"], int) + assert isinstance(hello["d"]["_trace"], list) await _close(conn) @@ -92,12 +89,9 @@ async def test_ready(test_cli_user): # get the hello frame but ignore it await _json(conn) - await _json_send(conn, { - 'op': OP.IDENTIFY, - 'd': { - 'token': test_cli_user.user['token'], - } - }) + await _json_send( + conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}} + ) # try to get a ready try: @@ -116,35 +110,32 @@ async def test_ready_fields(test_cli_user): # get the hello frame but ignore it await _json(conn) - await _json_send(conn, { - 'op': OP.IDENTIFY, - 'd': { - 'token': test_cli_user.user['token'], - } - }) + await _json_send( + conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}} + ) try: ready = await _json(conn) assert isinstance(ready, dict) - assert ready['op'] == OP.DISPATCH - assert ready['t'] == 'READY' + assert ready["op"] == OP.DISPATCH + assert ready["t"] == "READY" - data = ready['d'] + data = ready["d"] assert isinstance(data, dict) # NOTE: change if default gateway changes - assert data['v'] == 6 + assert data["v"] == 6 # make sure other fields exist and are with # proper types. - assert isinstance(data['user'], dict) - assert isinstance(data['private_channels'], list) - assert isinstance(data['guilds'], list) - assert isinstance(data['session_id'], str) - assert isinstance(data['_trace'], list) + assert isinstance(data["user"], dict) + assert isinstance(data["private_channels"], list) + assert isinstance(data["guilds"], list) + assert isinstance(data["session_id"], str) + assert isinstance(data["_trace"], list) - if 'shard' in data: - assert isinstance(data['shard'], list) + if "shard" in data: + assert isinstance(data["shard"], list) finally: await _close(conn) @@ -156,24 +147,21 @@ async def test_heartbeat(test_cli_user): # get the hello frame but ignore it await _json(conn) - await _json_send(conn, { - 'op': OP.IDENTIFY, - 'd': { - 'token': test_cli_user.user['token'], - } - }) + await _json_send( + conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}} + ) # ignore ready data ready = await _json(conn) assert isinstance(ready, dict) - assert ready['op'] == OP.DISPATCH - assert ready['t'] == 'READY' + assert ready["op"] == OP.DISPATCH + assert ready["t"] == "READY" # test a heartbeat await _json_send_op(conn, OP.HEARTBEAT) recv = await _json(conn) assert isinstance(recv, dict) - assert recv['op'] == OP.HEARTBEAT_ACK + assert recv["op"] == OP.HEARTBEAT_ACK await _close(conn) @@ -185,6 +173,6 @@ async def test_etf(test_cli): try: hello = await _etf(conn) - assert hello['op'] == OP.HELLO + assert hello["op"] == OP.HELLO finally: await _close(conn)