diff --git a/config.ci.py b/config.ci.py
index 24d6046..911085a 100644
--- a/config.ci.py
+++ b/config.ci.py
@@ -17,13 +17,14 @@ along with this program. If not, see .
"""
-MODE = 'CI'
+MODE = "CI"
class Config:
"""Default configuration values for litecord."""
- MAIN_URL = 'localhost:1'
- NAME = 'gitlab ci'
+
+ MAIN_URL = "localhost:1"
+ NAME = "gitlab ci"
# Enable debug logging?
DEBUG = False
@@ -37,11 +38,11 @@ class Config:
# Set this url to somewhere *your users*
# will hit the websocket.
# e.g 'gateway.example.com' for reverse proxies.
- WEBSOCKET_URL = 'localhost:5001'
+ WEBSOCKET_URL = "localhost:5001"
# Where to host the websocket?
# (a local address the server will bind to)
- WS_HOST = 'localhost'
+ WS_HOST = "localhost"
WS_PORT = 5001
# Postgres credentials
@@ -51,10 +52,10 @@ class Config:
class Development(Config):
DEBUG = True
POSTGRES = {
- 'host': 'localhost',
- 'user': 'litecord',
- 'password': '123',
- 'database': 'litecord',
+ "host": "localhost",
+ "user": "litecord",
+ "password": "123",
+ "database": "litecord",
}
@@ -66,8 +67,4 @@ class Production(Config):
class CI(Config):
DEBUG = True
- POSTGRES = {
- 'host': 'postgres',
- 'user': 'postgres',
- 'password': ''
- }
+ POSTGRES = {"host": "postgres", "user": "postgres", "password": ""}
diff --git a/config.example.py b/config.example.py
index 2e225e4..999add7 100644
--- a/config.example.py
+++ b/config.example.py
@@ -17,16 +17,17 @@ along with this program. If not, see .
"""
-MODE = 'Development'
+MODE = "Development"
class Config:
"""Default configuration values for litecord."""
+
#: Main URL of the instance.
- MAIN_URL = 'discordapp.io'
+ MAIN_URL = "discordapp.io"
#: Name of the instance
- NAME = 'Litecord/Nya'
+ NAME = "Litecord/Nya"
#: Enable debug logging?
DEBUG = False
@@ -45,17 +46,17 @@ class Config:
# Set this url to somewhere *your users*
# will hit the websocket.
# e.g 'gateway.example.com' for reverse proxies.
- WEBSOCKET_URL = 'localhost:5001'
+ WEBSOCKET_URL = "localhost:5001"
#: Where to host the websocket?
# (a local address the server will bind to)
- WS_HOST = '0.0.0.0'
+ WS_HOST = "0.0.0.0"
WS_PORT = 5001
#: Mediaproxy URL on the internet
# mediaproxy is made to prevent client IPs being leaked.
# None is a valid value if you don't want to deploy mediaproxy.
- MEDIA_PROXY = 'localhost:5002'
+ MEDIA_PROXY = "localhost:5002"
#: Postgres credentials
POSTGRES = {}
@@ -65,10 +66,10 @@ class Development(Config):
DEBUG = True
POSTGRES = {
- 'host': 'localhost',
- 'user': 'litecord',
- 'password': '123',
- 'database': 'litecord',
+ "host": "localhost",
+ "user": "litecord",
+ "password": "123",
+ "database": "litecord",
}
@@ -77,8 +78,8 @@ class Production(Config):
IS_SSL = True
POSTGRES = {
- 'host': 'some_production_postgres',
- 'user': 'some_production_user',
- 'password': 'some_production_password',
- 'database': 'litecord_or_anything_else_really',
+ "host": "some_production_postgres",
+ "user": "some_production_user",
+ "password": "some_production_password",
+ "database": "litecord_or_anything_else_really",
}
diff --git a/litecord/__init__.py b/litecord/__init__.py
index ce49370..d21f555 100644
--- a/litecord/__init__.py
+++ b/litecord/__init__.py
@@ -16,4 +16,3 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see .
"""
-
diff --git a/litecord/admin_schemas.py b/litecord/admin_schemas.py
index 22d6bdb..900f734 100644
--- a/litecord/admin_schemas.py
+++ b/litecord/admin_schemas.py
@@ -19,42 +19,33 @@ along with this program. If not, see .
from litecord.enums import Feature, UserFlags
-VOICE_SERVER = {
- 'hostname': {'type': 'string', 'maxlength': 255, 'required': True}
-}
+VOICE_SERVER = {"hostname": {"type": "string", "maxlength": 255, "required": True}}
VOICE_REGION = {
- 'id': {'type': 'string', 'maxlength': 255, 'required': True},
- 'name': {'type': 'string', 'maxlength': 255, 'required': True},
-
- 'vip': {'type': 'boolean', 'default': False},
- 'deprecated': {'type': 'boolean', 'default': False},
- 'custom': {'type': 'boolean', 'default': False},
+ "id": {"type": "string", "maxlength": 255, "required": True},
+ "name": {"type": "string", "maxlength": 255, "required": True},
+ "vip": {"type": "boolean", "default": False},
+ "deprecated": {"type": "boolean", "default": False},
+ "custom": {"type": "boolean", "default": False},
}
FEATURES = {
- 'features': {
- 'type': 'list', 'required': True,
-
+ "features": {
+ "type": "list",
+ "required": True,
# using Feature doesn't seem to work with a "not callable" error.
- 'schema': {'coerce': lambda x: Feature(x)}
+ "schema": {"coerce": lambda x: Feature(x)},
}
}
USER_CREATE = {
- 'username': {'type': 'username', 'required': True},
- 'email': {'type': 'email', 'required': True},
- 'password': {'type': 'string', 'minlength': 5, 'required': True},
+ "username": {"type": "username", "required": True},
+ "email": {"type": "email", "required": True},
+ "password": {"type": "string", "minlength": 5, "required": True},
}
-INSTANCE_INVITE = {
- 'max_uses': {'type': 'integer', 'required': True}
-}
+INSTANCE_INVITE = {"max_uses": {"type": "integer", "required": True}}
-GUILD_UPDATE = {
- 'unavailable': {'type': 'boolean', 'required': False}
-}
+GUILD_UPDATE = {"unavailable": {"type": "boolean", "required": False}}
-USER_UPDATE = {
- 'flags': {'required': False, 'coerce': UserFlags.from_int}
-}
+USER_UPDATE = {"flags": {"required": False, "coerce": UserFlags.from_int}}
diff --git a/litecord/auth.py b/litecord/auth.py
index 3c52ddf..e9ac19f 100644
--- a/litecord/auth.py
+++ b/litecord/auth.py
@@ -55,44 +55,50 @@ async def raw_token_check(token: str, db=None) -> int:
# just try by fragments instead of
# unpacking
- fragments = token.split('.')
+ fragments = token.split(".")
user_id = fragments[0]
try:
user_id = base64.b64decode(user_id.encode())
user_id = int(user_id)
except (ValueError, binascii.Error):
- raise Unauthorized('Invalid user ID type')
+ raise Unauthorized("Invalid user ID type")
- pwd_hash = await db.fetchval("""
+ pwd_hash = await db.fetchval(
+ """
SELECT password_hash
FROM users
WHERE id = $1
- """, user_id)
+ """,
+ user_id,
+ )
if not pwd_hash:
- raise Unauthorized('User ID not found')
+ raise Unauthorized("User ID not found")
signer = TimestampSigner(pwd_hash)
try:
signer.unsign(token)
- log.debug('login for uid {} successful', user_id)
+ log.debug("login for uid {} successful", user_id)
# update the user's last_session field
# so that we can keep an exact track of activity,
# even on long-lived single sessions (that can happen
# with people leaving their clients open forever)
- await db.execute("""
+ await db.execute(
+ """
UPDATE users
SET last_session = (now() at time zone 'utc')
WHERE id = $1
- """, user_id)
+ """,
+ user_id,
+ )
return user_id
except BadSignature:
- log.warning('token failed for uid {}', user_id)
- raise Forbidden('Invalid token')
+ log.warning("token failed for uid {}", user_id)
+ raise Forbidden("Invalid token")
async def token_check() -> int:
@@ -104,12 +110,12 @@ async def token_check() -> int:
pass
try:
- token = request.headers['Authorization']
+ token = request.headers["Authorization"]
except KeyError:
- raise Unauthorized('No token provided')
+ raise Unauthorized("No token provided")
- if token.startswith('Bot '):
- token = token.replace('Bot ', '')
+ if token.startswith("Bot "):
+ token = token.replace("Bot ", "")
user_id = await raw_token_check(token)
request.user_id = user_id
@@ -120,15 +126,18 @@ async def admin_check() -> int:
"""Check if the user is an admin."""
user_id = await token_check()
- flags = await app.db.fetchval("""
+ flags = await app.db.fetchval(
+ """
SELECT flags
FROM users
WHERE id = $1
- """, user_id)
+ """,
+ user_id,
+ )
flags = UserFlags.from_int(flags)
if not flags.is_staff:
- raise Unauthorized('you are not staff')
+ raise Unauthorized("you are not staff")
return user_id
@@ -138,9 +147,7 @@ async def hash_data(data: str, loop=None) -> str:
loop = loop or app.loop
buf = data.encode()
- hashed = await loop.run_in_executor(
- None, bcrypt.hashpw, buf, bcrypt.gensalt(14)
- )
+ hashed = await loop.run_in_executor(None, bcrypt.hashpw, buf, bcrypt.gensalt(14))
return hashed.decode()
@@ -148,22 +155,28 @@ async def hash_data(data: str, loop=None) -> str:
async def check_username_usage(username: str, db=None):
"""Raise an error if too many people are with the same username."""
db = db or app.db
- same_username = await db.fetchval("""
+ same_username = await db.fetchval(
+ """
SELECT COUNT(*)
FROM users
WHERE username = $1
- """, username)
+ """,
+ username,
+ )
if same_username > 9000:
- raise BadRequest('Too many people.', {
- 'username': 'Too many people used the same username. '
- 'Please choose another'
- })
+ raise BadRequest(
+ "Too many people.",
+ {
+ "username": "Too many people used the same username. "
+ "Please choose another"
+ },
+ )
def _raw_discrim() -> str:
new_discrim = randint(1, 9999)
- new_discrim = '%04d' % new_discrim
+ new_discrim = "%04d" % new_discrim
return new_discrim
@@ -186,11 +199,15 @@ async def roll_discrim(username: str, *, db=None) -> Optional[str]:
discrim = _raw_discrim()
# check if anyone is with it
- res = await db.fetchval("""
+ res = await db.fetchval(
+ """
SELECT id
FROM users
WHERE username = $1 AND discriminator = $2
- """, username, discrim)
+ """,
+ username,
+ discrim,
+ )
# if no user is found with the (username, discrim)
# pair, then this is unique! return it.
@@ -200,8 +217,9 @@ async def roll_discrim(username: str, *, db=None) -> Optional[str]:
return None
-async def create_user(username: str, email: str, password: str,
- db=None, loop=None) -> Tuple[int, str]:
+async def create_user(
+ username: str, email: str, password: str, db=None, loop=None
+) -> Tuple[int, str]:
"""Create a single user.
Generates a distriminator and other information. You can fetch the user
@@ -214,20 +232,28 @@ async def create_user(username: str, email: str, password: str,
new_discrim = await roll_discrim(username, db=db)
if new_discrim is None:
- raise BadRequest('Unable to register.', {
- 'username': 'Too many people are with this username.'
- })
+ raise BadRequest(
+ "Unable to register.",
+ {"username": "Too many people are with this username."},
+ )
pwd_hash = await hash_data(password, loop)
try:
- await db.execute("""
+ await db.execute(
+ """
INSERT INTO users
(id, email, username, discriminator, password_hash)
VALUES
($1, $2, $3, $4, $5)
- """, new_id, email, username, new_discrim, pwd_hash)
+ """,
+ new_id,
+ email,
+ username,
+ new_discrim,
+ pwd_hash,
+ )
except UniqueViolationError:
- raise BadRequest('Email already used.')
+ raise BadRequest("Email already used.")
return new_id, pwd_hash
diff --git a/litecord/blueprints/__init__.py b/litecord/blueprints/__init__.py
index 9b021b4..cc15930 100644
--- a/litecord/blueprints/__init__.py
+++ b/litecord/blueprints/__init__.py
@@ -34,7 +34,21 @@ from .static import bp as static
from .attachments import bp as attachments
from .dm_channels import bp as dm_channels
-__all__ = ['gateway', 'auth', 'users', 'guilds', 'channels',
- 'webhooks', 'science', 'voice', 'invites', 'relationships',
- 'dms', 'icons', 'nodeinfo', 'static', 'attachments',
- 'dm_channels']
+__all__ = [
+ "gateway",
+ "auth",
+ "users",
+ "guilds",
+ "channels",
+ "webhooks",
+ "science",
+ "voice",
+ "invites",
+ "relationships",
+ "dms",
+ "icons",
+ "nodeinfo",
+ "static",
+ "attachments",
+ "dm_channels",
+]
diff --git a/litecord/blueprints/admin_api/__init__.py b/litecord/blueprints/admin_api/__init__.py
index d27cc52..ab0e1e8 100644
--- a/litecord/blueprints/admin_api/__init__.py
+++ b/litecord/blueprints/admin_api/__init__.py
@@ -23,4 +23,4 @@ from .guilds import bp as guilds
from .users import bp as users
from .instance_invites import bp as instance_invites
-__all__ = ['voice', 'features', 'guilds', 'users', 'instance_invites']
+__all__ = ["voice", "features", "guilds", "users", "instance_invites"]
diff --git a/litecord/blueprints/admin_api/features.py b/litecord/blueprints/admin_api/features.py
index 9de58d6..314bb44 100644
--- a/litecord/blueprints/admin_api/features.py
+++ b/litecord/blueprints/admin_api/features.py
@@ -25,45 +25,53 @@ from litecord.errors import BadRequest
from litecord.schemas import validate
from litecord.admin_schemas import FEATURES
-bp = Blueprint('features_admin', __name__)
+bp = Blueprint("features_admin", __name__)
async def _features_from_req() -> List[str]:
j = validate(await request.get_json(), FEATURES)
- return [feature.value for feature in j['features']]
+ return [feature.value for feature in j["features"]]
async def _features(guild_id: int):
- return jsonify({
- 'features': await app.storage.guild_features(guild_id)
- })
+ return jsonify({"features": await app.storage.guild_features(guild_id)})
async def _update_features(guild_id: int, features: list):
- if 'VANITY_URL' not in features:
+ if "VANITY_URL" not in features:
existing_inv = await app.storage.vanity_invite(guild_id)
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM vanity_invites
WHERE guild_id = $1
- """, guild_id)
+ """,
+ guild_id,
+ )
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM invites
WHERE code = $1
- """, existing_inv)
+ """,
+ existing_inv,
+ )
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE guilds
SET features = $1
WHERE id = $2
- """, features, guild_id)
+ """,
+ features,
+ guild_id,
+ )
guild = await app.storage.get_guild_full(guild_id)
- await app.dispatcher.dispatch('guild', guild_id, 'GUILD_UPDATE', guild)
+ await app.dispatcher.dispatch("guild", guild_id, "GUILD_UPDATE", guild)
-@bp.route('//features', methods=['PATCH'])
+@bp.route("//features", methods=["PATCH"])
async def replace_features(guild_id: int):
"""Replace the feature list in a guild"""
await admin_check()
@@ -76,7 +84,7 @@ async def replace_features(guild_id: int):
return await _features(guild_id)
-@bp.route('//features', methods=['PUT'])
+@bp.route("//features", methods=["PUT"])
async def insert_features(guild_id: int):
"""Insert a feature on a guild."""
await admin_check()
@@ -93,7 +101,7 @@ async def insert_features(guild_id: int):
return await _features(guild_id)
-@bp.route('//features', methods=['DELETE'])
+@bp.route("//features", methods=["DELETE"])
async def remove_features(guild_id: int):
"""Remove a feature from a guild"""
await admin_check()
@@ -104,7 +112,7 @@ async def remove_features(guild_id: int):
try:
features.remove(feature)
except ValueError:
- raise BadRequest('Trying to remove already removed feature.')
+ raise BadRequest("Trying to remove already removed feature.")
await _update_features(guild_id, features)
return await _features(guild_id)
diff --git a/litecord/blueprints/admin_api/guilds.py b/litecord/blueprints/admin_api/guilds.py
index e27f544..1cf792b 100644
--- a/litecord/blueprints/admin_api/guilds.py
+++ b/litecord/blueprints/admin_api/guilds.py
@@ -25,9 +25,10 @@ from litecord.admin_schemas import GUILD_UPDATE
from litecord.blueprints.guilds import delete_guild
from litecord.errors import GuildNotFound
-bp = Blueprint('guilds_admin', __name__)
+bp = Blueprint("guilds_admin", __name__)
-@bp.route('/', methods=['GET'])
+
+@bp.route("/", methods=["GET"])
async def get_guild(guild_id: int):
"""Get a basic guild payload."""
await admin_check()
@@ -40,7 +41,7 @@ async def get_guild(guild_id: int):
return jsonify(guild)
-@bp.route('/', methods=['PATCH'])
+@bp.route("/", methods=["PATCH"])
async def update_guild(guild_id: int):
await admin_check()
@@ -48,13 +49,13 @@ async def update_guild(guild_id: int):
# TODO: what happens to the other guild attributes when its
# unavailable? do they vanish?
- old_unavailable = app.guild_store.get(guild_id, 'unavailable')
- new_unavailable = j.get('unavailable', old_unavailable)
+ old_unavailable = app.guild_store.get(guild_id, "unavailable")
+ new_unavailable = j.get("unavailable", old_unavailable)
# always set unavailable status since new_unavailable will be
# old_unavailable when not provided, so we don't need to check if
# j.unavailable is there
- app.guild_store.set(guild_id, 'unavailable', j['unavailable'])
+ app.guild_store.set(guild_id, "unavailable", j["unavailable"])
guild = await app.storage.get_guild(guild_id)
@@ -62,17 +63,17 @@ async def update_guild(guild_id: int):
if old_unavailable and not new_unavailable:
# guild became available
- await app.dispatcher.dispatch_guild(guild_id, 'GUILD_CREATE', guild)
+ await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild)
else:
# guild became unavailable
- await app.dispatcher.dispatch_guild(guild_id, 'GUILD_DELETE', guild)
+ await app.dispatcher.dispatch_guild(guild_id, "GUILD_DELETE", guild)
return jsonify(guild)
-@bp.route('/', methods=['DELETE'])
+@bp.route("/", methods=["DELETE"])
async def delete_guild_as_admin(guild_id):
"""Delete a single guild via the admin API without ownership checks."""
await admin_check()
await delete_guild(guild_id)
- return '', 204
+ return "", 204
diff --git a/litecord/blueprints/admin_api/instance_invites.py b/litecord/blueprints/admin_api/instance_invites.py
index dbd6c78..c410c47 100644
--- a/litecord/blueprints/admin_api/instance_invites.py
+++ b/litecord/blueprints/admin_api/instance_invites.py
@@ -27,13 +27,13 @@ from litecord.types import timestamp_
from litecord.schemas import validate
from litecord.admin_schemas import INSTANCE_INVITE
-bp = Blueprint('instance_invites', __name__)
+bp = Blueprint("instance_invites", __name__)
ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits
async def _gen_inv() -> str:
"""Generate an invite code"""
- return ''.join(choice(ALPHABET) for _ in range(6))
+ return "".join(choice(ALPHABET) for _ in range(6))
async def gen_inv(ctx) -> str:
@@ -41,11 +41,14 @@ async def gen_inv(ctx) -> str:
for _ in range(10):
possible_inv = await _gen_inv()
- created_at = await ctx.db.fetchval("""
+ created_at = await ctx.db.fetchval(
+ """
SELECT created_at
FROM instance_invites
WHERE code = $1
- """, possible_inv)
+ """,
+ possible_inv,
+ )
if created_at is None:
return possible_inv
@@ -53,57 +56,71 @@ async def gen_inv(ctx) -> str:
return None
-@bp.route('', methods=['GET'])
+@bp.route("", methods=["GET"])
async def _all_instance_invites():
await admin_check()
- rows = await app.db.fetch("""
+ rows = await app.db.fetch(
+ """
SELECT code, created_at, uses, max_uses
FROM instance_invites
- """)
+ """
+ )
rows = [dict(row) for row in rows]
for row in rows:
- row['created_at'] = timestamp_(row['created_at'])
+ row["created_at"] = timestamp_(row["created_at"])
return jsonify(rows)
-@bp.route('', methods=['PUT'])
+@bp.route("", methods=["PUT"])
async def _create_invite():
await admin_check()
code = await gen_inv(app)
if code is None:
- return 'failed to make invite', 500
+ return "failed to make invite", 500
j = validate(await request.get_json(), INSTANCE_INVITE)
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO instance_invites (code, max_uses)
VALUES ($1, $2)
- """, code, j['max_uses'])
+ """,
+ code,
+ j["max_uses"],
+ )
- inv = dict(await app.db.fetchrow("""
+ inv = dict(
+ await app.db.fetchrow(
+ """
SELECT code, created_at, uses, max_uses
FROM instance_invites
WHERE code = $1
- """, code))
+ """,
+ code,
+ )
+ )
return jsonify(dict(inv))
-@bp.route('/', methods=['DELETE'])
+@bp.route("/", methods=["DELETE"])
async def _del_invite(invite: str):
await admin_check()
- res = await app.db.execute("""
+ res = await app.db.execute(
+ """
DELETE FROM instance_invites
WHERE code = $1
- """, invite)
+ """,
+ invite,
+ )
- if res.lower() == 'delete 0':
- return 'invite not found', 404
+ if res.lower() == "delete 0":
+ return "invite not found", 404
- return '', 204
+ return "", 204
diff --git a/litecord/blueprints/admin_api/users.py b/litecord/blueprints/admin_api/users.py
index ab3ba37..c1df4e8 100644
--- a/litecord/blueprints/admin_api/users.py
+++ b/litecord/blueprints/admin_api/users.py
@@ -25,24 +25,21 @@ from litecord.schemas import validate
from litecord.admin_schemas import USER_CREATE, USER_UPDATE
from litecord.errors import BadRequest, Forbidden
from litecord.utils import async_map
-from litecord.blueprints.users import (
- delete_user, user_disconnect, mass_user_update
-)
+from litecord.blueprints.users import delete_user, user_disconnect, mass_user_update
from litecord.enums import UserFlags
-bp = Blueprint('users_admin', __name__)
+bp = Blueprint("users_admin", __name__)
-@bp.route('', methods=['POST'])
+
+@bp.route("", methods=["POST"])
async def _create_user():
await admin_check()
j = validate(await request.get_json(), USER_CREATE)
- user_id, _ = await create_user(j['username'], j['email'], j['password'])
+ user_id, _ = await create_user(j["username"], j["email"], j["password"])
- return jsonify(
- await app.storage.get_user(user_id)
- )
+ return jsonify(await app.storage.get_user(user_id))
def args_try(args: dict, typ, field: str, default):
@@ -51,29 +48,29 @@ def args_try(args: dict, typ, field: str, default):
try:
return typ(args.get(field, default))
except (TypeError, ValueError):
- raise BadRequest(f'invalid {field} value')
+ raise BadRequest(f"invalid {field} value")
-@bp.route('', methods=['GET'])
+@bp.route("", methods=["GET"])
async def _search_users():
await admin_check()
args = request.args
- username, discrim = args.get('username'), args.get('discriminator')
+ username, discrim = args.get("username"), args.get("discriminator")
- per_page = args_try(args, int, 'per_page', 20)
- page = args_try(args, int, 'page', 0)
+ per_page = args_try(args, int, "per_page", 20)
+ page = args_try(args, int, "page", 0)
if page < 0:
- raise BadRequest('invalid page number')
+ raise BadRequest("invalid page number")
if per_page > 50:
- raise BadRequest('invalid per page number')
+ raise BadRequest("invalid per page number")
# any of those must be available.
if not any((username, discrim)):
- raise BadRequest('must insert username or discrim')
+ raise BadRequest("must insert username or discrim")
wheres, args = [], []
@@ -82,29 +79,31 @@ async def _search_users():
args.append(username)
if discrim:
- wheres.append(f'discriminator = ${len(args) + 2}')
+ wheres.append(f"discriminator = ${len(args) + 2}")
args.append(discrim)
- where_tot = 'WHERE ' if args else ''
- where_tot += ' AND '.join(wheres)
+ where_tot = "WHERE " if args else ""
+ where_tot += " AND ".join(wheres)
- rows = await app.db.fetch(f"""
+ rows = await app.db.fetch(
+ f"""
SELECT id
FROM users
{where_tot}
ORDER BY id ASC
LIMIT {per_page}
OFFSET ($1 * {per_page})
- """, page, *args)
-
- rows = [r['id'] for r in rows]
-
- return jsonify(
- await async_map(app.storage.get_user, rows)
+ """,
+ page,
+ *args,
)
+ rows = [r["id"] for r in rows]
-@bp.route('/', methods=['DELETE'])
+ return jsonify(await async_map(app.storage.get_user, rows))
+
+
+@bp.route("/", methods=["DELETE"])
async def _delete_single_user(user_id: int):
await admin_check()
@@ -115,13 +114,10 @@ async def _delete_single_user(user_id: int):
new_user = await app.storage.get_user(user_id)
- return jsonify({
- 'old': old_user,
- 'new': new_user
- })
+ return jsonify({"old": old_user, "new": new_user})
-@bp.route('/', methods=['PATCH'])
+@bp.route("/", methods=["PATCH"])
async def patch_user(user_id: int):
await admin_check()
@@ -129,21 +125,25 @@ async def patch_user(user_id: int):
# get the original user for flags checking
user = await app.storage.get_user(user_id)
- old_flags = UserFlags.from_int(user['flags'])
+ old_flags = UserFlags.from_int(user["flags"])
# j.flags is already a UserFlags since we coerce it.
- if 'flags' in j:
- new_flags = j['flags']
+ if "flags" in j:
+ new_flags = j["flags"]
# disallow any changes to the staff badge
if new_flags.is_staff != old_flags.is_staff:
- raise Forbidden('you can not change a users staff badge')
+ raise Forbidden("you can not change a users staff badge")
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE users
SET flags = $1
WHERE id = $2
- """, new_flags.value, user_id)
+ """,
+ new_flags.value,
+ user_id,
+ )
public_user, _ = await mass_user_update(user_id, app)
return jsonify(public_user)
diff --git a/litecord/blueprints/admin_api/voice.py b/litecord/blueprints/admin_api/voice.py
index e700b27..334bcfe 100644
--- a/litecord/blueprints/admin_api/voice.py
+++ b/litecord/blueprints/admin_api/voice.py
@@ -27,10 +27,10 @@ from litecord.admin_schemas import VOICE_SERVER, VOICE_REGION
from litecord.errors import BadRequest
log = Logger(__name__)
-bp = Blueprint('voice_admin', __name__)
+bp = Blueprint("voice_admin", __name__)
-@bp.route('/regions/', methods=['GET'])
+@bp.route("/regions/", methods=["GET"])
async def get_region_servers(region):
"""Return a list of all servers for a region."""
await admin_check()
@@ -38,18 +38,25 @@ async def get_region_servers(region):
return jsonify(servers)
-@bp.route('/regions', methods=['PUT'])
+@bp.route("/regions", methods=["PUT"])
async def insert_new_region():
"""Create a voice region."""
await admin_check()
j = validate(await request.get_json(), VOICE_REGION)
- j['id'] = j['id'].lower()
+ j["id"] = j["id"].lower()
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO voice_regions (id, name, vip, deprecated, custom)
VALUES ($1, $2, $3, $4, $5)
- """, j['id'], j['name'], j['vip'], j['deprecated'], j['custom'])
+ """,
+ j["id"],
+ j["name"],
+ j["vip"],
+ j["deprecated"],
+ j["custom"],
+ )
regions = await app.storage.all_voice_regions()
region_count = len(regions)
@@ -57,34 +64,41 @@ async def insert_new_region():
# if region count is 1, this is the first region to be created,
# so we should update all guilds to that region
if region_count == 1:
- res = await app.db.execute("""
+ res = await app.db.execute(
+ """
UPDATE guilds
SET region = $1
- """, j['id'])
+ """,
+ j["id"],
+ )
- log.info('updating guilds to first voice region: {}', res)
+ log.info("updating guilds to first voice region: {}", res)
return jsonify(regions)
-@bp.route('/regions//servers', methods=['PUT'])
+@bp.route("/regions//servers", methods=["PUT"])
async def put_region_server(region):
"""Insert a voice server to a region"""
await admin_check()
j = validate(await request.get_json(), VOICE_SERVER)
try:
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO voice_servers (hostname, region_id)
VALUES ($1, $2)
- """, j['hostname'], region)
+ """,
+ j["hostname"],
+ region,
+ )
except asyncpg.UniqueViolationError:
- raise BadRequest('voice server already exists with given hostname')
+ raise BadRequest("voice server already exists with given hostname")
- return '', 204
+ return "", 204
-@bp.route('/regions//deprecate', methods=['PUT'])
+@bp.route("/regions//deprecate", methods=["PUT"])
async def deprecate_region(region):
"""Deprecate a voice region."""
await admin_check()
@@ -92,13 +106,16 @@ async def deprecate_region(region):
# TODO: write this
await app.voice.disable_region(region)
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE voice_regions
SET deprecated = true
WHERE id = $1
- """, region)
+ """,
+ region,
+ )
- return '', 204
+ return "", 204
async def guild_region_check(app_):
@@ -112,10 +129,11 @@ async def guild_region_check(app_):
regions = await app_.storage.all_voice_regions()
if not regions:
- log.info('region check: no regions to move guilds to')
+ log.info("region check: no regions to move guilds to")
return
- res = await app_.db.execute("""
+ res = await app_.db.execute(
+ """
UPDATE guilds
SET region = (
SELECT id
@@ -124,6 +142,8 @@ async def guild_region_check(app_):
LIMIT 1
)
WHERE region = NULL
- """, len(regions))
+ """,
+ len(regions),
+ )
- log.info('region check: updating guild.region=null: {!r}', res)
+ log.info("region check: updating guild.region=null: {!r}", res)
diff --git a/litecord/blueprints/attachments.py b/litecord/blueprints/attachments.py
index 9d41755..49da21a 100644
--- a/litecord/blueprints/attachments.py
+++ b/litecord/blueprints/attachments.py
@@ -24,16 +24,17 @@ from PIL import Image
from litecord.images import resize_gif
-bp = Blueprint('attachments', __name__)
-ATTACHMENTS = Path.cwd() / 'attachments'
+bp = Blueprint("attachments", __name__)
+ATTACHMENTS = Path.cwd() / "attachments"
-async def _resize_gif(attach_id: int, resized_path: Path,
- width: int, height: int) -> str:
+async def _resize_gif(
+ attach_id: int, resized_path: Path, width: int, height: int
+) -> str:
"""Resize a GIF attachment."""
# get original gif bytes
- orig_path = ATTACHMENTS / f'{attach_id}.gif'
+ orig_path = ATTACHMENTS / f"{attach_id}.gif"
orig_bytes = orig_path.read_bytes()
# give them and the target size to the
@@ -47,10 +48,7 @@ async def _resize_gif(attach_id: int, resized_path: Path,
return str(resized_path)
-FORMAT_HARDCODE = {
- 'jpg': 'jpeg',
- 'jpe': 'jpeg'
-}
+FORMAT_HARDCODE = {"jpg": "jpeg", "jpe": "jpeg"}
def to_format(ext: str) -> str:
@@ -63,11 +61,10 @@ def to_format(ext: str) -> str:
return ext
-async def _resize(image, attach_id: int, ext: str,
- width: int, height: int) -> str:
+async def _resize(image, attach_id: int, ext: str, width: int, height: int) -> str:
"""Resize an image."""
# check if we have it on the folder
- resized_path = ATTACHMENTS / f'{attach_id}_{width}_{height}.{ext}'
+ resized_path = ATTACHMENTS / f"{attach_id}_{width}_{height}.{ext}"
# keep a str-fied instance since that is what
# we'll return.
@@ -81,7 +78,7 @@ async def _resize(image, attach_id: int, ext: str,
# the process is different for gif files because we need
# gifsicle. doing it manually is too troublesome.
- if ext == 'gif':
+ if ext == "gif":
return await _resize_gif(attach_id, resized_path, width, height)
# NOTE: this is the same resize mode for icons.
@@ -91,38 +88,42 @@ async def _resize(image, attach_id: int, ext: str,
return resized_path_s
-@bp.route('/attachments'
- '///',
- methods=['GET'])
-async def _get_attachment(channel_id: int, message_id: int,
- filename: str):
+@bp.route(
+ "/attachments" "///", methods=["GET"]
+)
+async def _get_attachment(channel_id: int, message_id: int, filename: str):
- attach_id = await app.db.fetchval("""
+ attach_id = await app.db.fetchval(
+ """
SELECT id
FROM attachments
WHERE channel_id = $1
AND message_id = $2
AND filename = $3
- """, channel_id, message_id, filename)
+ """,
+ channel_id,
+ message_id,
+ filename,
+ )
if attach_id is None:
- return '', 404
+ return "", 404
- ext = filename.split('.')[-1]
- filepath = f'./attachments/{attach_id}.{ext}'
+ ext = filename.split(".")[-1]
+ filepath = f"./attachments/{attach_id}.{ext}"
image = Image.open(filepath)
im_width, im_height = image.size
try:
- width = int(request.args.get('width', 0)) or im_width
+ width = int(request.args.get("width", 0)) or im_width
except ValueError:
- return '', 400
+ return "", 400
try:
- height = int(request.args.get('height', 0)) or im_height
+ height = int(request.args.get("height", 0)) or im_height
except ValueError:
- return '', 400
+ return "", 400
# if width and height are the same (happens if they weren't provided)
if width == im_width and height == im_height:
diff --git a/litecord/blueprints/auth.py b/litecord/blueprints/auth.py
index e6d4124..38ce17c 100644
--- a/litecord/blueprints/auth.py
+++ b/litecord/blueprints/auth.py
@@ -32,7 +32,7 @@ from litecord.snowflake import get_snowflake
from .invites import use_invite
log = Logger(__name__)
-bp = Blueprint('auth', __name__)
+bp = Blueprint("auth", __name__)
async def check_password(pwd_hash: str, given_password: str) -> bool:
@@ -53,141 +53,139 @@ def make_token(user_id, user_pwd_hash) -> str:
return signer.sign(user_id).decode()
-@bp.route('/register', methods=['POST'])
+@bp.route("/register", methods=["POST"])
async def register():
"""Register a single user."""
- enabled = app.config.get('REGISTRATIONS')
+ enabled = app.config.get("REGISTRATIONS")
if not enabled:
- raise BadRequest('Registrations disabled', {
- 'email': 'Registrations are disabled.'
- })
+ raise BadRequest(
+ "Registrations disabled", {"email": "Registrations are disabled."}
+ )
j = await request.get_json()
- if not 'password' in j:
+ if not "password" in j:
# we need a password to generate a token.
# passwords are optional, so
- j['password'] = 'default_password'
+ j["password"] = "default_password"
j = validate(j, REGISTER)
# they're optional
- email = j.get('email')
- invite = j.get('invite')
+ email = j.get("email")
+ invite = j.get("invite")
- username, password = j['username'], j['password']
+ username, password = j["username"], j["password"]
- new_id, pwd_hash = await create_user(
- username, email, password, app.db
- )
+ new_id, pwd_hash = await create_user(username, email, password, app.db)
if invite:
try:
await use_invite(new_id, invite)
except Exception:
- log.exception('failed to use invite for register {} {!r}',
- new_id, invite)
+ log.exception("failed to use invite for register {} {!r}", new_id, invite)
- return jsonify({
- 'token': make_token(new_id, pwd_hash)
- })
+ return jsonify({"token": make_token(new_id, pwd_hash)})
-@bp.route('/register_inv', methods=['POST'])
+@bp.route("/register_inv", methods=["POST"])
async def _register_with_invite():
data = await request.form
data = validate(await request.form, REGISTER_WITH_INVITE)
- invcode = data['invcode']
+ invcode = data["invcode"]
- row = await app.db.fetchrow("""
+ row = await app.db.fetchrow(
+ """
SELECT uses, max_uses
FROM instance_invites
WHERE code = $1
- """, invcode)
+ """,
+ invcode,
+ )
if row is None:
- raise BadRequest('unknown instance invite')
+ raise BadRequest("unknown instance invite")
- if row['max_uses'] != -1 and row['uses'] >= row['max_uses']:
- raise BadRequest('invite expired')
+ if row["max_uses"] != -1 and row["uses"] >= row["max_uses"]:
+ raise BadRequest("invite expired")
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE instance_invites
SET uses = uses + 1
WHERE code = $1
- """, invcode)
+ """,
+ invcode,
+ )
user_id, pwd_hash = await create_user(
- data['username'], data['email'], data['password'], app.db)
+ data["username"], data["email"], data["password"], app.db
+ )
- return jsonify({
- 'token': make_token(user_id, pwd_hash),
- 'user_id': str(user_id),
- })
+ return jsonify({"token": make_token(user_id, pwd_hash), "user_id": str(user_id)})
-@bp.route('/login', methods=['POST'])
+@bp.route("/login", methods=["POST"])
async def login():
j = await request.get_json()
- email, password = j['email'], j['password']
+ email, password = j["email"], j["password"]
- row = await app.db.fetchrow("""
+ row = await app.db.fetchrow(
+ """
SELECT id, password_hash
FROM users
WHERE email = $1
- """, email)
+ """,
+ email,
+ )
if not row:
- return jsonify({'email': ['User not found.']}), 401
+ return jsonify({"email": ["User not found."]}), 401
user_id, pwd_hash = row
if not await check_password(pwd_hash, password):
- return jsonify({'password': ['Password does not match.']}), 401
+ return jsonify({"password": ["Password does not match."]}), 401
- return jsonify({
- 'token': make_token(user_id, pwd_hash)
- })
+ return jsonify({"token": make_token(user_id, pwd_hash)})
-@bp.route('/consent-required', methods=['GET'])
+@bp.route("/consent-required", methods=["GET"])
async def consent_required():
- return jsonify({
- 'required': True,
- })
+ return jsonify({"required": True})
-@bp.route('/verify/resend', methods=['POST'])
+@bp.route("/verify/resend", methods=["POST"])
async def verify_user():
user_id = await token_check()
# TODO: actually verify a user by sending an email
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE users
SET verified = true
WHERE id = $1
- """, user_id)
+ """,
+ user_id,
+ )
new_user = await app.storage.get_user(user_id, True)
- await app.dispatcher.dispatch_user(
- user_id, 'USER_UPDATE', new_user)
+ await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", new_user)
- return '', 204
+ return "", 204
-@bp.route('/logout', methods=['POST'])
+@bp.route("/logout", methods=["POST"])
async def _logout():
"""Called by the client to logout."""
- return '', 204
+ return "", 204
-@bp.route('/fingerprint', methods=['POST'])
+@bp.route("/fingerprint", methods=["POST"])
async def _fingerprint():
"""No idea what this route is about."""
fingerprint_id = get_snowflake()
- fingerprint = f'{fingerprint_id}.{secrets.token_urlsafe(32)}'
+ fingerprint = f"{fingerprint_id}.{secrets.token_urlsafe(32)}"
- return jsonify({
- 'fingerprint': fingerprint
- })
+ return jsonify({"fingerprint": fingerprint})
diff --git a/litecord/blueprints/channel/__init__.py b/litecord/blueprints/channel/__init__.py
index fac85e9..fc08dd2 100644
--- a/litecord/blueprints/channel/__init__.py
+++ b/litecord/blueprints/channel/__init__.py
@@ -21,4 +21,4 @@ from .messages import bp as channel_messages
from .reactions import bp as channel_reactions
from .pins import bp as channel_pins
-__all__ = ['channel_messages', 'channel_reactions', 'channel_pins']
+__all__ = ["channel_messages", "channel_reactions", "channel_pins"]
diff --git a/litecord/blueprints/channel/dm_checks.py b/litecord/blueprints/channel/dm_checks.py
index 1d0851a..e2cb195 100644
--- a/litecord/blueprints/channel/dm_checks.py
+++ b/litecord/blueprints/channel/dm_checks.py
@@ -30,13 +30,18 @@ class ForbiddenDM(Forbidden):
async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
"""Check if the user can DM the peer."""
# first step is checking if there is a block in any direction
- blockrow = await app.db.fetchrow("""
+ blockrow = await app.db.fetchrow(
+ """
SELECT rel_type
FROM relationships
WHERE rel_type = $3
AND user_id IN ($1, $2)
AND peer_id IN ($1, $2)
- """, user_id, peer_id, RelationshipType.BLOCK.value)
+ """,
+ user_id,
+ peer_id,
+ RelationshipType.BLOCK.value,
+ )
if blockrow is not None:
raise ForbiddenDM()
@@ -58,8 +63,8 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
user_settings = await app.user_storage.get_user_settings(user_id)
peer_settings = await app.user_storage.get_user_settings(peer_id)
- restricted_user_ = [int(v) for v in user_settings['restricted_guilds']]
- restricted_peer_ = [int(v) for v in peer_settings['restricted_guilds']]
+ restricted_user_ = [int(v) for v in user_settings["restricted_guilds"]]
+ restricted_peer_ = [int(v) for v in peer_settings["restricted_guilds"]]
restricted_user = set(restricted_user_)
restricted_peer = set(restricted_peer_)
diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py
index f051f1f..551508b 100644
--- a/litecord/blueprints/channel/messages.py
+++ b/litecord/blueprints/channel/messages.py
@@ -41,18 +41,18 @@ from litecord.images import try_unlink
log = Logger(__name__)
-bp = Blueprint('channel_messages', __name__)
+bp = Blueprint("channel_messages", __name__)
def extract_limit(request_, default: int = 50, max_val: int = 100):
"""Extract a limit kwarg."""
try:
- limit = int(request_.args.get('limit', default))
+ limit = int(request_.args.get("limit", default))
if limit not in range(0, max_val + 1):
raise ValueError()
except (TypeError, ValueError):
- raise BadRequest('limit not int')
+ raise BadRequest("limit not int")
return limit
@@ -61,27 +61,27 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple:
"""Extract a 2-tuple out of request arguments."""
before, after = None, None
- if 'around' in request.args:
+ if "around" in request.args:
average = int(limit / 2)
- around = int(args['around'])
+ around = int(args["around"])
after = around - average
before = around + average
- elif 'before' in args:
- before = int(args['before'])
- elif 'after' in args:
- before = int(args['after'])
+ elif "before" in args:
+ before = int(args["before"])
+ elif "after" in args:
+ before = int(args["after"])
return before, after
-@bp.route('//messages', methods=['GET'])
+@bp.route("//messages", methods=["GET"])
async def get_messages(channel_id):
user_id = await token_check()
ctype, peer_id = await channel_check(user_id, channel_id)
- await channel_perm_check(user_id, channel_id, 'read_history')
+ await channel_perm_check(user_id, channel_id, "read_history")
if ctype == ChannelType.DM:
# make sure both parties will be subbed
@@ -91,42 +91,45 @@ async def get_messages(channel_id):
limit = extract_limit(request, 50)
- where_clause = ''
+ where_clause = ""
before, after = query_tuple_from_args(request.args, limit)
if before:
- where_clause += f'AND id < {before}'
+ where_clause += f"AND id < {before}"
if after:
- where_clause += f'AND id > {after}'
+ where_clause += f"AND id > {after}"
- message_ids = await app.db.fetch(f"""
+ message_ids = await app.db.fetch(
+ f"""
SELECT id
FROM messages
WHERE channel_id = $1 {where_clause}
ORDER BY id DESC
LIMIT {limit}
- """, channel_id)
+ """,
+ channel_id,
+ )
result = []
for message_id in message_ids:
- msg = await app.storage.get_message(message_id['id'], user_id)
+ msg = await app.storage.get_message(message_id["id"], user_id)
if msg is None:
continue
result.append(msg)
- log.info('Fetched {} messages', len(result))
+ log.info("Fetched {} messages", len(result))
return jsonify(result)
-@bp.route('//messages/', methods=['GET'])
+@bp.route("//messages/", methods=["GET"])
async def get_single_message(channel_id, message_id):
user_id = await token_check()
await channel_check(user_id, channel_id)
- await channel_perm_check(user_id, channel_id, 'read_history')
+ await channel_perm_check(user_id, channel_id, "read_history")
message = await app.storage.get_message(message_id, user_id)
@@ -142,11 +145,15 @@ async def _dm_pre_dispatch(channel_id, peer_id):
# check the other party's dm_channel_state
- dm_state = await app.db.fetchval("""
+ dm_state = await app.db.fetchval(
+ """
SELECT dm_id
FROM dm_channel_state
WHERE user_id = $1 AND dm_id = $2
- """, peer_id, channel_id)
+ """,
+ peer_id,
+ channel_id,
+ )
if dm_state:
# the peer already has the channel
@@ -157,18 +164,19 @@ async def _dm_pre_dispatch(channel_id, peer_id):
# dispatch CHANNEL_CREATE so the client knows which
# channel the future event is about
- await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan)
+ await app.dispatcher.dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan)
# subscribe the peer to the channel
- await app.dispatcher.sub('channel', channel_id, peer_id)
+ await app.dispatcher.sub("channel", channel_id, peer_id)
# insert it on dm_channel_state so the client
# is subscribed on the future
await try_dm_state(peer_id, channel_id)
-async def create_message(channel_id: int, actual_guild_id: int,
- author_id: int, data: dict) -> int:
+async def create_message(
+ channel_id: int, actual_guild_id: int, author_id: int, data: dict
+) -> int:
message_id = get_snowflake()
async with app.db.acquire() as conn:
@@ -185,32 +193,32 @@ async def create_message(channel_id: int, actual_guild_id: int,
channel_id,
actual_guild_id,
author_id,
- data['content'],
-
- data['tts'],
- data['everyone_mention'],
-
- data['nonce'],
+ data["content"],
+ data["tts"],
+ data["everyone_mention"],
+ data["nonce"],
MessageType.DEFAULT.value,
- data.get('embeds') or []
+ data.get("embeds") or [],
)
return message_id
-async def msg_guild_text_mentions(payload: dict, guild_id: int,
- mentions_everyone: bool, mentions_here: bool):
+
+async def msg_guild_text_mentions(
+ payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
+):
"""Calculates mention data side-effects."""
- channel_id = int(payload['channel_id'])
+ channel_id = int(payload["channel_id"])
# calculate the user ids we'll bump the mention count for
uids = set()
# first is extracting user mentions
- for mention in payload['mentions']:
- uids.add(int(mention['id']))
+ for mention in payload["mentions"]:
+ uids.add(int(mention["id"]))
# then role mentions
- for role_mention in payload['mention_roles']:
+ for role_mention in payload["mention_roles"]:
role_id = int(role_mention)
member_ids = await app.storage.get_role_members(role_id)
@@ -223,11 +231,14 @@ async def msg_guild_text_mentions(payload: dict, guild_id: int,
if mentions_here:
uids = set()
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
# at-here updates the read state
# for all users, including the ones
@@ -238,19 +249,26 @@ async def msg_guild_text_mentions(payload: dict, guild_id: int,
member_ids = await app.storage.get_member_ids(guild_id)
- await app.db.executemany("""
+ await app.db.executemany(
+ """
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1 AND user_id = $2
- """, [(channel_id, uid) for uid in member_ids])
+ """,
+ [(channel_id, uid) for uid in member_ids],
+ )
for user_id in uids:
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE user_id = $1
AND channel_id = $2
- """, user_id, channel_id)
+ """,
+ user_id,
+ channel_id,
+ )
async def msg_create_request() -> tuple:
@@ -264,12 +282,12 @@ async def msg_create_request() -> tuple:
# NOTE: embed isn't set on form data
json_from_form = {
- 'content': form.get('content', ''),
- 'nonce': form.get('nonce', '0'),
- 'tts': json.loads(form.get('tts', 'false')),
+ "content": form.get("content", ""),
+ "nonce": form.get("nonce", "0"),
+ "tts": json.loads(form.get("tts", "false")),
}
- payload_json = json.loads(form.get('payload_json', '{}'))
+ payload_json = json.loads(form.get("payload_json", "{}"))
json_from_form.update(request_json)
json_from_form.update(payload_json)
@@ -283,20 +301,19 @@ async def msg_create_request() -> tuple:
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
"""Check if there is actually any content being sent to us."""
- has_content = bool(payload.get('content', ''))
+ has_content = bool(payload.get("content", ""))
has_files = len(files) > 0
- embed_field = 'embeds' if use_embeds else 'embed'
+ embed_field = "embeds" if use_embeds else "embed"
has_embed = embed_field in payload and payload.get(embed_field) is not None
has_total_content = has_content or has_embed or has_files
if not has_total_content:
- raise BadRequest('No content has been provided.')
+ raise BadRequest("No content has been provided.")
-async def msg_add_attachment(message_id: int, channel_id: int,
- attachment_file) -> int:
+async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
"""Add an attachment to a message.
Parameters
@@ -318,7 +335,7 @@ async def msg_add_attachment(message_id: int, channel_id: int,
# understand file info
mime = attachment_file.mimetype
- is_image = mime.startswith('image/')
+ is_image = mime.startswith("image/")
img_width, img_height = None, None
@@ -346,17 +363,22 @@ async def msg_add_attachment(message_id: int, channel_id: int,
VALUES
($1, $2, $3, $4, $5, $6, $7, $8)
""",
- attachment_id, channel_id, message_id,
- filename, file_size,
- is_image, img_width, img_height)
+ attachment_id,
+ channel_id,
+ message_id,
+ filename,
+ file_size,
+ is_image,
+ img_width,
+ img_height,
+ )
- ext = filename.split('.')[-1]
+ ext = filename.split(".")[-1]
- with open(f'attachments/{attachment_id}.{ext}', 'wb') as attach_file:
+ with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
attach_file.write(attachment_file.stream.read())
- log.debug('written {} bytes for attachment id {}',
- file_size, attachment_id)
+ log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
return attachment_id
@@ -364,12 +386,12 @@ async def msg_add_attachment(message_id: int, channel_id: int,
async def _spawn_embed(app_, payload, **kwargs):
app_.sched.spawn(
process_url_embed(
- app_.config, app_.storage, app_.dispatcher, app_.session,
- payload, **kwargs)
+ app_.config, app_.storage, app_.dispatcher, app_.session, payload, **kwargs
+ )
)
-@bp.route('//messages', methods=['POST'])
+@bp.route("//messages", methods=["POST"])
async def _create_message(channel_id):
"""Create a message."""
@@ -379,7 +401,7 @@ async def _create_message(channel_id):
actual_guild_id = None
if ctype in GUILD_CHANS:
- await channel_perm_check(user_id, channel_id, 'send_messages')
+ await channel_perm_check(user_id, channel_id, "send_messages")
actual_guild_id = guild_id
payload_json, files = await msg_create_request()
@@ -394,29 +416,31 @@ async def _create_message(channel_id):
await dm_pre_check(user_id, channel_id, guild_id)
can_everyone = await channel_perm_check(
- user_id, channel_id, 'mention_everyone', False
+ user_id, channel_id, "mention_everyone", False
)
- mentions_everyone = ('@everyone' in j['content']) and can_everyone
- mentions_here = ('@here' in j['content']) and can_everyone
+ mentions_everyone = ("@everyone" in j["content"]) and can_everyone
+ mentions_here = ("@here" in j["content"]) and can_everyone
- is_tts = (j.get('tts', False) and
- await channel_perm_check(
- user_id, channel_id, 'send_tts_messages', False
- ))
+ is_tts = j.get("tts", False) and await channel_perm_check(
+ user_id, channel_id, "send_tts_messages", False
+ )
message_id = await create_message(
- channel_id, actual_guild_id, user_id, {
- 'content': j['content'],
- 'tts': is_tts,
- 'nonce': int(j.get('nonce', 0)),
- 'everyone_mention': mentions_everyone or mentions_here,
-
+ channel_id,
+ actual_guild_id,
+ user_id,
+ {
+ "content": j["content"],
+ "tts": is_tts,
+ "nonce": int(j.get("nonce", 0)),
+ "everyone_mention": mentions_everyone or mentions_here,
# fill_embed takes care of filling proxy and width/height
- 'embeds': ([await fill_embed(j['embed'])]
- if j.get('embed') is not None
- else []),
- })
+ "embeds": (
+ [await fill_embed(j["embed"])] if j.get("embed") is not None else []
+ ),
+ },
+ )
# for each file given, we add it as an attachment
for pre_attachment in files:
@@ -429,8 +453,7 @@ async def _create_message(channel_id):
await _dm_pre_dispatch(channel_id, user_id)
await _dm_pre_dispatch(channel_id, guild_id)
- await app.dispatcher.dispatch('channel', channel_id,
- 'MESSAGE_CREATE', payload)
+ await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
# spawn url processor for embedding of images
perms = await get_permissions(user_id, channel_id)
@@ -438,54 +461,71 @@ async def _create_message(channel_id):
await _spawn_embed(app, payload)
# update read state for the author
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE user_read_state
SET last_message_id = $1
WHERE channel_id = $2 AND user_id = $3
- """, message_id, channel_id, user_id)
+ """,
+ message_id,
+ channel_id,
+ user_id,
+ )
if ctype == ChannelType.GUILD_TEXT:
await msg_guild_text_mentions(
- payload, guild_id, mentions_everyone, mentions_here)
+ payload, guild_id, mentions_everyone, mentions_here
+ )
return jsonify(payload)
-@bp.route('//messages/', methods=['PATCH'])
+@bp.route("//messages/", methods=["PATCH"])
async def edit_message(channel_id, message_id):
user_id = await token_check()
_ctype, _guild_id = await channel_check(user_id, channel_id)
- author_id = await app.db.fetchval("""
+ author_id = await app.db.fetchval(
+ """
SELECT author_id FROM messages
WHERE messages.id = $1
- """, message_id)
+ """,
+ message_id,
+ )
if not author_id == user_id:
- raise Forbidden('You can not edit this message')
+ raise Forbidden("You can not edit this message")
j = await request.get_json()
- updated = 'content' in j or 'embed' in j
+ updated = "content" in j or "embed" in j
old_message = await app.storage.get_message(message_id)
- if 'content' in j:
- await app.db.execute("""
+ if "content" in j:
+ await app.db.execute(
+ """
UPDATE messages
SET content=$1
WHERE messages.id = $2
- """, j['content'], message_id)
+ """,
+ j["content"],
+ message_id,
+ )
- if 'embed' in j:
- embeds = [await fill_embed(j['embed'])]
+ if "embed" in j:
+ embeds = [await fill_embed(j["embed"])]
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE messages
SET embeds=$1
WHERE messages.id = $2
- """, embeds, message_id)
+ """,
+ embeds,
+ message_id,
+ )
# do not spawn process_url_embed since we already have embeds.
- elif 'content' in j:
+ elif "content" in j:
# if there weren't any embed changes BUT
# we had a content change, we dispatch process_url_embed but with
# an artificial delay.
@@ -495,46 +535,55 @@ async def edit_message(channel_id, message_id):
# BEFORE the MESSAGE_UPDATE with the new embeds (based on content)
perms = await get_permissions(user_id, channel_id)
if perms.bits.embed_links:
- await _spawn_embed(app, {
- 'id': message_id,
- 'channel_id': channel_id,
- 'content': j['content'],
- 'embeds': old_message['embeds']
- }, delay=0.2)
+ await _spawn_embed(
+ app,
+ {
+ "id": message_id,
+ "channel_id": channel_id,
+ "content": j["content"],
+ "embeds": old_message["embeds"],
+ },
+ delay=0.2,
+ )
# only set new timestamp upon actual update
if updated:
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE messages
SET edited_at = (now() at time zone 'utc')
WHERE id = $1
- """, message_id)
+ """,
+ message_id,
+ )
message = await app.storage.get_message(message_id, user_id)
# only dispatch MESSAGE_UPDATE if any update
# actually happened
if updated:
- await app.dispatcher.dispatch('channel', channel_id,
- 'MESSAGE_UPDATE', message)
+ await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", message)
return jsonify(message)
async def _del_msg_fkeys(message_id: int):
- attachs = await app.db.fetch("""
+ attachs = await app.db.fetch(
+ """
SELECT id FROM attachments
WHERE message_id = $1
- """, message_id)
+ """,
+ message_id,
+ )
- attachs = [r['id'] for r in attachs]
+ attachs = [r["id"] for r in attachs]
- attachments = Path('./attachments')
+ attachments = Path("./attachments")
for attach_id in attachs:
# anything starting with the given attachment shall be
# deleted, because there may be resizes of the original
# attachment laying around.
- for filepath in attachments.glob(f'{attach_id}*'):
+ for filepath in attachments.glob(f"{attach_id}*"):
try_unlink(filepath)
# after trying to delete all available attachments, delete
@@ -542,51 +591,64 @@ async def _del_msg_fkeys(message_id: int):
# take the chance and delete all the data from the other tables too!
- tables = ['attachments', 'message_webhook_info',
- 'message_reactions', 'channel_pins']
+ tables = [
+ "attachments",
+ "message_webhook_info",
+ "message_reactions",
+ "channel_pins",
+ ]
for table in tables:
- await app.db.execute(f"""
+ await app.db.execute(
+ f"""
DELETE FROM {table}
WHERE message_id = $1
- """, message_id)
+ """,
+ message_id,
+ )
-@bp.route('//messages/', methods=['DELETE'])
+@bp.route("//messages/", methods=["DELETE"])
async def delete_message(channel_id, message_id):
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
- author_id = await app.db.fetchval("""
+ author_id = await app.db.fetchval(
+ """
SELECT author_id FROM messages
WHERE messages.id = $1
- """, message_id)
-
- by_perm = await channel_perm_check(
- user_id, channel_id, 'manage_messages', False
+ """,
+ message_id,
)
+ by_perm = await channel_perm_check(user_id, channel_id, "manage_messages", False)
+
by_ownership = author_id == user_id
can_delete = by_perm or by_ownership
if not can_delete:
- raise Forbidden('You can not delete this message')
+ raise Forbidden("You can not delete this message")
await _del_msg_fkeys(message_id)
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM messages
WHERE messages.id = $1
- """, message_id)
+ """,
+ message_id,
+ )
await app.dispatcher.dispatch(
- 'channel', channel_id,
- 'MESSAGE_DELETE', {
- 'id': str(message_id),
- 'channel_id': str(channel_id),
-
+ "channel",
+ channel_id,
+ "MESSAGE_DELETE",
+ {
+ "id": str(message_id),
+ "channel_id": str(channel_id),
# for lazy guilds
- 'guild_id': str(guild_id),
- })
+ "guild_id": str(guild_id),
+ },
+ )
- return '', 204
+ return "", 204
diff --git a/litecord/blueprints/channel/pins.py b/litecord/blueprints/channel/pins.py
index f5245d7..5cdef0a 100644
--- a/litecord/blueprints/channel/pins.py
+++ b/litecord/blueprints/channel/pins.py
@@ -28,28 +28,32 @@ from litecord.system_messages import send_sys_message
from litecord.enums import MessageType, SYS_MESSAGES
from litecord.errors import BadRequest
-bp = Blueprint('channel_pins', __name__)
+bp = Blueprint("channel_pins", __name__)
class SysMsgInvalidAction(BadRequest):
"""Invalid action on a system message."""
+
error_code = 50021
-@bp.route('//pins', methods=['GET'])
+@bp.route("//pins", methods=["GET"])
async def get_pins(channel_id):
"""Get the pins for a channel"""
user_id = await token_check()
await channel_check(user_id, channel_id)
- ids = await app.db.fetch("""
+ ids = await app.db.fetch(
+ """
SELECT message_id
FROM channel_pins
WHERE channel_id = $1
ORDER BY message_id DESC
- """, channel_id)
+ """,
+ channel_id,
+ )
- ids = [r['message_id'] for r in ids]
+ ids = [r["message_id"] for r in ids]
res = []
for message_id in ids:
@@ -60,80 +64,96 @@ async def get_pins(channel_id):
return jsonify(res)
-@bp.route('//pins/', methods=['PUT'])
+@bp.route("//pins/", methods=["PUT"])
async def add_pin(channel_id, message_id):
"""Add a pin to a channel"""
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
- await channel_perm_check(user_id, channel_id, 'manage_messages')
+ await channel_perm_check(user_id, channel_id, "manage_messages")
- mtype = await app.db.fetchval("""
+ mtype = await app.db.fetchval(
+ """
SELECT message_type
FROM messages
WHERE id = $1
- """, message_id)
+ """,
+ message_id,
+ )
if mtype in SYS_MESSAGES:
- raise SysMsgInvalidAction(
- 'Cannot execute action on a system message')
+ raise SysMsgInvalidAction("Cannot execute action on a system message")
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO channel_pins (channel_id, message_id)
VALUES ($1, $2)
- """, channel_id, message_id)
+ """,
+ channel_id,
+ message_id,
+ )
- row = await app.db.fetchrow("""
+ row = await app.db.fetchrow(
+ """
SELECT message_id
FROM channel_pins
WHERE channel_id = $1
ORDER BY message_id ASC
LIMIT 1
- """, channel_id)
-
- timestamp = snowflake_datetime(row['message_id'])
-
- await app.dispatcher.dispatch(
- 'channel', channel_id, 'CHANNEL_PINS_UPDATE',
- {
- 'channel_id': str(channel_id),
- 'last_pin_timestamp': timestamp_(timestamp)
- }
+ """,
+ channel_id,
)
- await send_sys_message(app, channel_id,
- MessageType.CHANNEL_PINNED_MESSAGE,
- message_id, user_id)
+ timestamp = snowflake_datetime(row["message_id"])
- return '', 204
+ await app.dispatcher.dispatch(
+ "channel",
+ channel_id,
+ "CHANNEL_PINS_UPDATE",
+ {"channel_id": str(channel_id), "last_pin_timestamp": timestamp_(timestamp)},
+ )
+
+ await send_sys_message(
+ app, channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id
+ )
+
+ return "", 204
-@bp.route('//pins/', methods=['DELETE'])
+@bp.route("//pins/", methods=["DELETE"])
async def delete_pin(channel_id, message_id):
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
- await channel_perm_check(user_id, channel_id, 'manage_messages')
+ await channel_perm_check(user_id, channel_id, "manage_messages")
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM channel_pins
WHERE channel_id = $1 AND message_id = $2
- """, channel_id, message_id)
+ """,
+ channel_id,
+ message_id,
+ )
- row = await app.db.fetchrow("""
+ row = await app.db.fetchrow(
+ """
SELECT message_id
FROM channel_pins
WHERE channel_id = $1
ORDER BY message_id ASC
LIMIT 1
- """, channel_id)
+ """,
+ channel_id,
+ )
- timestamp = snowflake_datetime(row['message_id'])
+ timestamp = snowflake_datetime(row["message_id"])
await app.dispatcher.dispatch(
- 'channel', channel_id, 'CHANNEL_PINS_UPDATE', {
- 'channel_id': str(channel_id),
- 'last_pin_timestamp': timestamp.isoformat()
- })
+ "channel",
+ channel_id,
+ "CHANNEL_PINS_UPDATE",
+ {"channel_id": str(channel_id), "last_pin_timestamp": timestamp.isoformat()},
+ )
- return '', 204
+ return "", 204
diff --git a/litecord/blueprints/channel/reactions.py b/litecord/blueprints/channel/reactions.py
index 7ede137..4c8dabd 100644
--- a/litecord/blueprints/channel/reactions.py
+++ b/litecord/blueprints/channel/reactions.py
@@ -26,17 +26,15 @@ from logbook import Logger
from litecord.utils import async_map
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check, channel_perm_check
-from litecord.blueprints.channel.messages import (
- query_tuple_from_args, extract_limit
-)
+from litecord.blueprints.channel.messages import query_tuple_from_args, extract_limit
from litecord.enums import GUILD_CHANS
log = Logger(__name__)
-bp = Blueprint('channel_reactions', __name__)
+bp = Blueprint("channel_reactions", __name__)
-BASEPATH = '//messages//reactions'
+BASEPATH = "//messages//reactions"
class EmojiType(IntEnum):
@@ -51,16 +49,14 @@ def emoji_info_from_str(emoji: str) -> tuple:
# unicode emoji just have the raw unicode.
# try checking if the emoji is custom or unicode
- emoji_type = 0 if ':' in emoji else 1
+ emoji_type = 0 if ":" in emoji else 1
emoji_type = EmojiType(emoji_type)
# extract the emoji id OR the unicode value of the emoji
# depending if it is custom or not
- emoji_id = (int(emoji.split(':')[1])
- if emoji_type == EmojiType.CUSTOM
- else emoji)
+ emoji_id = int(emoji.split(":")[1]) if emoji_type == EmojiType.CUSTOM else emoji
- emoji_name = emoji.split(':')[0]
+ emoji_name = emoji.split(":")[0]
return emoji_type, emoji_id, emoji_name
@@ -68,27 +64,27 @@ def emoji_info_from_str(emoji: str) -> tuple:
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
print(emoji_type, emoji_id, emoji_name)
return {
- 'id': None if emoji_type == EmojiType.UNICODE else emoji_id,
- 'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id
+ "id": None if emoji_type == EmojiType.UNICODE else emoji_id,
+ "name": emoji_name if emoji_type == EmojiType.UNICODE else emoji_id,
}
def _make_payload(user_id, channel_id, message_id, partial):
return {
- 'user_id': str(user_id),
- 'channel_id': str(channel_id),
- 'message_id': str(message_id),
- 'emoji': partial
+ "user_id": str(user_id),
+ "channel_id": str(channel_id),
+ "message_id": str(message_id),
+ "emoji": partial,
}
-@bp.route(f'{BASEPATH}//@me', methods=['PUT'])
+@bp.route(f"{BASEPATH}//@me", methods=["PUT"])
async def add_reaction(channel_id: int, message_id: int, emoji: str):
"""Put a reaction."""
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
- await channel_perm_check(user_id, channel_id, 'read_history')
+ await channel_perm_check(user_id, channel_id, "read_history")
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
@@ -97,52 +93,64 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str):
# ADD_REACTIONS is only checked when this is the first
# reaction in a message.
- reaction_count = await app.db.fetchval("""
+ reaction_count = await app.db.fetchval(
+ """
SELECT COUNT(*)
FROM message_reactions
WHERE message_id = $1
AND emoji_type = $2
AND emoji_id = $3
AND emoji_text = $4
- """, message_id, emoji_type, emoji_id, emoji_text)
+ """,
+ message_id,
+ emoji_type,
+ emoji_id,
+ emoji_text,
+ )
if reaction_count == 0:
- await channel_perm_check(user_id, channel_id, 'add_reactions')
+ await channel_perm_check(user_id, channel_id, "add_reactions")
await app.db.execute(
"""
INSERT INTO message_reactions (message_id, user_id,
emoji_type, emoji_id, emoji_text)
VALUES ($1, $2, $3, $4, $5)
- """, message_id, user_id, emoji_type,
-
+ """,
+ message_id,
+ user_id,
+ emoji_type,
# if it is custom, we put the emoji_id on emoji_id
# column, if it isn't, we put it on emoji_text
# column.
- emoji_id, emoji_text
+ emoji_id,
+ emoji_text,
)
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
payload = _make_payload(user_id, channel_id, message_id, partial)
if ctype in GUILD_CHANS:
- payload['guild_id'] = str(guild_id)
+ payload["guild_id"] = str(guild_id)
await app.dispatcher.dispatch(
- 'channel', channel_id, 'MESSAGE_REACTION_ADD', payload)
+ "channel", channel_id, "MESSAGE_REACTION_ADD", payload
+ )
- return '', 204
+ return "", 204
def emoji_sql(emoji_type, emoji_id, emoji_name, param=4):
"""Extract SQL clauses to search for specific emoji
in the message_reactions table."""
- param = f'${param}'
+ param = f"${param}"
# know which column to filter with
- where_ext = (f'AND emoji_id = {param}'
- if emoji_type == EmojiType.CUSTOM else
- f'AND emoji_text = {param}')
+ where_ext = (
+ f"AND emoji_id = {param}"
+ if emoji_type == EmojiType.CUSTOM
+ else f"AND emoji_text = {param}"
+ )
# which emoji to remove (custom or unicode)
main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name
@@ -157,8 +165,7 @@ def _emoji_sql_simple(emoji: str, param=4):
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
-async def remove_reaction(channel_id: int, message_id: int,
- user_id: int, emoji: str):
+async def remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str):
ctype, guild_id = await channel_check(user_id, channel_id)
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
@@ -171,39 +178,45 @@ async def remove_reaction(channel_id: int, message_id: int,
AND user_id = $2
AND emoji_type = $3
{where_ext}
- """, message_id, user_id, emoji_type, main_emoji)
+ """,
+ message_id,
+ user_id,
+ emoji_type,
+ main_emoji,
+ )
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
payload = _make_payload(user_id, channel_id, message_id, partial)
if ctype in GUILD_CHANS:
- payload['guild_id'] = str(guild_id)
+ payload["guild_id"] = str(guild_id)
await app.dispatcher.dispatch(
- 'channel', channel_id, 'MESSAGE_REACTION_REMOVE', payload)
+ "channel", channel_id, "MESSAGE_REACTION_REMOVE", payload
+ )
-@bp.route(f'{BASEPATH}//@me', methods=['DELETE'])
+@bp.route(f"{BASEPATH}//@me", methods=["DELETE"])
async def remove_own_reaction(channel_id, message_id, emoji):
"""Remove a reaction."""
user_id = await token_check()
await remove_reaction(channel_id, message_id, user_id, emoji)
- return '', 204
+ return "", 204
-@bp.route(f'{BASEPATH}//', methods=['DELETE'])
+@bp.route(f"{BASEPATH}//", methods=["DELETE"])
async def remove_user_reaction(channel_id, message_id, emoji, other_id):
"""Remove a reaction made by another user."""
user_id = await token_check()
- await channel_perm_check(user_id, channel_id, 'manage_messages')
+ await channel_perm_check(user_id, channel_id, "manage_messages")
await remove_reaction(channel_id, message_id, other_id, emoji)
- return '', 204
+ return "", 204
-@bp.route(f'{BASEPATH}/', methods=['GET'])
+@bp.route(f"{BASEPATH}/", methods=["GET"])
async def list_users_reaction(channel_id, message_id, emoji):
"""Get the list of all users who reacted with a certain emoji."""
user_id = await token_check()
@@ -215,42 +228,49 @@ async def list_users_reaction(channel_id, message_id, emoji):
limit = extract_limit(request, 25)
before, after = query_tuple_from_args(request.args, limit)
- before_clause = 'AND user_id < $2' if before else ''
- after_clause = 'AND user_id > $3' if after else ''
+ before_clause = "AND user_id < $2" if before else ""
+ after_clause = "AND user_id > $3" if after else ""
where_ext, main_emoji = _emoji_sql_simple(emoji, 4)
- rows = await app.db.fetch(f"""
+ rows = await app.db.fetch(
+ f"""
SELECT user_id
FROM message_reactions
WHERE message_id = $1 {before_clause} {after_clause} {where_ext}
- """, message_id, before, after, main_emoji)
+ """,
+ message_id,
+ before,
+ after,
+ main_emoji,
+ )
- user_ids = [r['user_id'] for r in rows]
+ user_ids = [r["user_id"] for r in rows]
users = await async_map(app.storage.get_user, user_ids)
return jsonify(users)
-@bp.route(f'{BASEPATH}', methods=['DELETE'])
+@bp.route(f"{BASEPATH}", methods=["DELETE"])
async def remove_all_reactions(channel_id, message_id):
"""Remove all reactions in a message."""
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
- await channel_perm_check(user_id, channel_id, 'manage_messages')
+ await channel_perm_check(user_id, channel_id, "manage_messages")
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM message_reactions
WHERE message_id = $1
- """, message_id)
+ """,
+ message_id,
+ )
- payload = {
- 'channel_id': str(channel_id),
- 'message_id': str(message_id),
- }
+ payload = {"channel_id": str(channel_id), "message_id": str(message_id)}
if ctype in GUILD_CHANS:
- payload['guild_id'] = str(guild_id)
+ payload["guild_id"] = str(guild_id)
await app.dispatcher.dispatch(
- 'channel', channel_id, 'MESSAGE_REACTION_REMOVE_ALL', payload)
+ "channel", channel_id, "MESSAGE_REACTION_REMOVE_ALL", payload
+ )
diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py
index 2c9824b..a4b6c45 100644
--- a/litecord/blueprints/channels.py
+++ b/litecord/blueprints/channels.py
@@ -28,24 +28,26 @@ from litecord.auth import token_check
from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags
from litecord.errors import ChannelNotFound, Forbidden, BadRequest
from litecord.schemas import (
- validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE,
+ validate,
+ CHAN_UPDATE,
+ CHAN_OVERWRITE,
+ SEARCH_CHANNEL,
+ GROUP_DM_UPDATE,
BULK_DELETE,
)
from litecord.blueprints.checks import channel_check, channel_perm_check
from litecord.system_messages import send_sys_message
-from litecord.blueprints.dm_channels import (
- gdm_remove_recipient, gdm_destroy
-)
+from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
from litecord.utils import search_result_from_list
from litecord.embed.messages import process_url_embed, msg_update_embeds
from litecord.snowflake import snowflake_datetime
log = Logger(__name__)
-bp = Blueprint('channels', __name__)
+bp = Blueprint("channels", __name__)
-@bp.route('/', methods=['GET'])
+@bp.route("/", methods=["GET"])
async def get_channel(channel_id):
"""Get a single channel's information"""
user_id = await token_check()
@@ -56,7 +58,7 @@ async def get_channel(channel_id):
chan = await app.storage.get_channel(channel_id)
if not chan:
- raise ChannelNotFound('single channel not found')
+ raise ChannelNotFound("single channel not found")
return jsonify(chan)
@@ -64,106 +66,129 @@ async def get_channel(channel_id):
async def __guild_chan_sql(guild_id, channel_id, field: str) -> str:
"""Update a guild's channel id field to NULL,
if it was set to the given channel id before."""
- return await app.db.execute(f"""
+ return await app.db.execute(
+ f"""
UPDATE guilds
SET {field} = NULL
WHERE guilds.id = $1 AND {field} = $2
- """, guild_id, channel_id)
+ """,
+ guild_id,
+ channel_id,
+ )
async def _update_guild_chan_text(guild_id: int, channel_id: int):
- res_embed = await __guild_chan_sql(
- guild_id, channel_id, 'embed_channel_id')
+ res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id")
- res_widget = await __guild_chan_sql(
- guild_id, channel_id, 'widget_channel_id')
+ res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id")
- res_system = await __guild_chan_sql(
- guild_id, channel_id, 'system_channel_id')
+ res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id")
# if none of them were actually updated,
# ignore and dont dispatch anything
- if 'UPDATE 1' not in (res_embed, res_widget, res_system):
+ if "UPDATE 1" not in (res_embed, res_widget, res_system):
return
# at least one of the fields were updated,
# dispatch GUILD_UPDATE
guild = await app.storage.get_guild(guild_id)
- await app.dispatcher.dispatch_guild(
- guild_id, 'GUILD_UPDATE', guild)
+ await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
async def _update_guild_chan_voice(guild_id: int, channel_id: int):
- res = await __guild_chan_sql(guild_id, channel_id, 'afk_channel_id')
+ res = await __guild_chan_sql(guild_id, channel_id, "afk_channel_id")
# guild didnt update
- if res == 'UPDATE 0':
+ if res == "UPDATE 0":
return
guild = await app.storage.get_guild(guild_id)
- await app.dispatcher.dispatch_guild(
- guild_id, 'GUILD_UPDATE', guild)
+ await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
async def _update_guild_chan_cat(guild_id: int, channel_id: int):
# get all channels that were childs of the category
- childs = await app.db.fetch("""
+ childs = await app.db.fetch(
+ """
SELECT id
FROM guild_channels
WHERE guild_id = $1 AND parent_id = $2
- """, guild_id, channel_id)
- childs = [c['id'] for c in childs]
+ """,
+ guild_id,
+ channel_id,
+ )
+ childs = [c["id"] for c in childs]
# update every child channel to parent_id = NULL
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE guild_channels
SET parent_id = NULL
WHERE guild_id = $1 AND parent_id = $2
- """, guild_id, channel_id)
+ """,
+ guild_id,
+ channel_id,
+ )
# tell all people in the guild of the category removal
for child_id in childs:
child = await app.storage.get_channel(child_id)
- await app.dispatcher.dispatch_guild(
- guild_id, 'CHANNEL_UPDATE', child
- )
+ await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child)
async def delete_messages(channel_id):
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM channel_pins
WHERE channel_id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM user_read_state
WHERE channel_id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM messages
WHERE channel_id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
async def guild_cleanup(channel_id):
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM channel_overwrites
WHERE channel_id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM invites
WHERE channel_id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM webhooks
WHERE channel_id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
-@bp.route('/', methods=['DELETE'])
+@bp.route("/", methods=["DELETE"])
async def close_channel(channel_id):
"""Close or delete a channel."""
user_id = await token_check()
@@ -184,9 +209,8 @@ async def close_channel(channel_id):
}[ctype]
main_tbl = {
- ChannelType.GUILD_TEXT: 'guild_text_channels',
- ChannelType.GUILD_VOICE: 'guild_voice_channels',
-
+ ChannelType.GUILD_TEXT: "guild_text_channels",
+ ChannelType.GUILD_VOICE: "guild_voice_channels",
# TODO: categories?
}[ctype]
@@ -199,29 +223,37 @@ async def close_channel(channel_id):
await delete_messages(channel_id)
await guild_cleanup(channel_id)
- await app.db.execute(f"""
+ await app.db.execute(
+ f"""
DELETE FROM {main_tbl}
WHERE id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM guild_channels
WHERE id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM channels
WHERE id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
# clean its member list representation
- lazy_guilds = app.dispatcher.backends['lazy_guild']
+ lazy_guilds = app.dispatcher.backends["lazy_guild"]
lazy_guilds.remove_channel(channel_id)
- await app.dispatcher.dispatch_guild(
- guild_id, 'CHANNEL_DELETE', chan)
+ await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_DELETE", chan)
- await app.dispatcher.remove('channel', channel_id)
+ await app.dispatcher.remove("channel", channel_id)
return jsonify(chan)
if ctype == ChannelType.DM:
@@ -231,27 +263,34 @@ async def close_channel(channel_id):
# instead, we close the channel for the user that is making
# the request via removing the link between them and
# the channel on dm_channel_state
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM dm_channel_state
WHERE user_id = $1 AND dm_id = $2
- """, user_id, channel_id)
+ """,
+ user_id,
+ channel_id,
+ )
# unsubscribe
- await app.dispatcher.unsub('channel', channel_id, user_id)
+ await app.dispatcher.unsub("channel", channel_id, user_id)
# nothing happens to the other party of the dm channel
- await app.dispatcher.dispatch_user(user_id, 'CHANNEL_DELETE', chan)
+ await app.dispatcher.dispatch_user(user_id, "CHANNEL_DELETE", chan)
return jsonify(chan)
if ctype == ChannelType.GROUP_DM:
await gdm_remove_recipient(channel_id, user_id)
- gdm_count = await app.db.fetchval("""
+ gdm_count = await app.db.fetchval(
+ """
SELECT COUNT(*)
FROM group_dm_members
WHERE id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
if gdm_count == 0:
# destroy dm
@@ -261,11 +300,15 @@ async def close_channel(channel_id):
async def _update_pos(channel_id, pos: int):
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE guild_channels
SET position = $1
WHERE id = $2
- """, pos, channel_id)
+ """,
+ pos,
+ channel_id,
+ )
async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
@@ -274,20 +317,19 @@ async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
continue
chan = await app.storage.get_channel(channel_id)
- await app.dispatcher.dispatch(
- 'guild', guild_id, 'CHANNEL_UPDATE', chan)
+ await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
async def _process_overwrites(channel_id: int, overwrites: list):
for overwrite in overwrites:
# 0 for member overwrite, 1 for role overwrite
- target_type = 0 if overwrite['type'] == 'member' else 1
- target_role = None if target_type == 0 else overwrite['id']
- target_user = overwrite['id'] if target_type == 0 else None
+ target_type = 0 if overwrite["type"] == "member" else 1
+ target_role = None if target_type == 0 else overwrite["id"]
+ target_user = overwrite["id"] if target_type == 0 else None
- col_name = 'target_user' if target_type == 0 else 'target_role'
- constraint_name = f'channel_overwrites_{col_name}_uniq'
+ col_name = "target_user" if target_type == 0 else "target_role"
+ constraint_name = f"channel_overwrites_{col_name}_uniq"
await app.db.execute(
f"""
@@ -301,53 +343,66 @@ async def _process_overwrites(channel_id: int, overwrites: list):
UPDATE
SET allow = $5, deny = $6
""",
- channel_id, target_type,
- target_role, target_user,
- overwrite['allow'], overwrite['deny'])
+ channel_id,
+ target_type,
+ target_role,
+ target_user,
+ overwrite["allow"],
+ overwrite["deny"],
+ )
-@bp.route('//permissions/', methods=['PUT'])
+@bp.route("//permissions/", methods=["PUT"])
async def put_channel_overwrite(channel_id: int, overwrite_id: int):
"""Insert or modify a channel overwrite."""
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
if ctype not in GUILD_CHANS:
- raise ChannelNotFound('Only usable for guild channels.')
+ raise ChannelNotFound("Only usable for guild channels.")
- await channel_perm_check(user_id, guild_id, 'manage_roles')
+ await channel_perm_check(user_id, guild_id, "manage_roles")
j = validate(
# inserting a fake id on the payload so validation passes through
- {**await request.get_json(), **{'id': -1}},
- CHAN_OVERWRITE
+ {**await request.get_json(), **{"id": -1}},
+ CHAN_OVERWRITE,
)
- await _process_overwrites(channel_id, [{
- 'allow': j['allow'],
- 'deny': j['deny'],
- 'type': j['type'],
- 'id': overwrite_id
- }])
+ await _process_overwrites(
+ channel_id,
+ [
+ {
+ "allow": j["allow"],
+ "deny": j["deny"],
+ "type": j["type"],
+ "id": overwrite_id,
+ }
+ ],
+ )
await _mass_chan_update(guild_id, [channel_id])
- return '', 204
+ return "", 204
async def _update_channel_common(channel_id, guild_id: int, j: dict):
- if 'name' in j:
- await app.db.execute("""
+ if "name" in j:
+ await app.db.execute(
+ """
UPDATE guild_channels
SET name = $1
WHERE id = $2
- """, j['name'], channel_id)
+ """,
+ j["name"],
+ channel_id,
+ )
- if 'position' in j:
+ if "position" in j:
channel_data = await app.storage.get_channel_data(guild_id)
chans = [None] * len(channel_data)
for chandata in channel_data:
- chans.insert(chandata['position'], int(chandata['id']))
+ chans.insert(chandata["position"], int(chandata["id"]))
# are we changing to the left or to the right?
@@ -358,7 +413,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
# channelN-1 going to the position channel2
# was occupying.
current_pos = chans.index(channel_id)
- new_pos = j['position']
+ new_pos = j["position"]
# if the new position is bigger than the current one,
# we're making a left shift of all the channels that are
@@ -366,113 +421,136 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
left_shift = new_pos > current_pos
# find all channels that we'll have to shift
- shift_block = (chans[current_pos:new_pos]
- if left_shift else
- chans[new_pos:current_pos]
- )
+ shift_block = (
+ chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos]
+ )
shift = -1 if left_shift else 1
# do the shift (to the left or to the right)
- await app.db.executemany("""
+ await app.db.executemany(
+ """
UPDATE guild_channels
SET position = position + $1
WHERE id = $2
- """, [(shift, chan_id) for chan_id in shift_block])
+ """,
+ [(shift, chan_id) for chan_id in shift_block],
+ )
await _mass_chan_update(guild_id, shift_block)
# since theres now an empty slot, move current channel to it
await _update_pos(channel_id, new_pos)
- if 'channel_overwrites' in j:
- overwrites = j['channel_overwrites']
+ if "channel_overwrites" in j:
+ overwrites = j["channel_overwrites"]
await _process_overwrites(channel_id, overwrites)
async def _common_guild_chan(channel_id, j: dict):
# common updates to the guild_channels table
- for field in [field for field in j.keys()
- if field in ('nsfw', 'parent_id')]:
- await app.db.execute(f"""
+ for field in [field for field in j.keys() if field in ("nsfw", "parent_id")]:
+ await app.db.execute(
+ f"""
UPDATE guild_channels
SET {field} = $1
WHERE id = $2
- """, j[field], channel_id)
+ """,
+ j[field],
+ channel_id,
+ )
async def _update_text_channel(channel_id: int, j: dict, _user_id: int):
# first do the specific ones related to guild_text_channels
- for field in [field for field in j.keys()
- if field in ('topic', 'rate_limit_per_user')]:
- await app.db.execute(f"""
+ for field in [
+ field for field in j.keys() if field in ("topic", "rate_limit_per_user")
+ ]:
+ await app.db.execute(
+ f"""
UPDATE guild_text_channels
SET {field} = $1
WHERE id = $2
- """, j[field], channel_id)
+ """,
+ j[field],
+ channel_id,
+ )
await _common_guild_chan(channel_id, j)
async def _update_voice_channel(channel_id: int, j: dict, _user_id: int):
# first do the specific ones in guild_voice_channels
- for field in [field for field in j.keys()
- if field in ('bitrate', 'user_limit')]:
- await app.db.execute(f"""
+ for field in [field for field in j.keys() if field in ("bitrate", "user_limit")]:
+ await app.db.execute(
+ f"""
UPDATE guild_voice_channels
SET {field} = $1
WHERE id = $2
- """, j[field], channel_id)
+ """,
+ j[field],
+ channel_id,
+ )
# yes, i'm letting voice channels have nsfw, you cant stop me
await _common_guild_chan(channel_id, j)
async def _update_group_dm(channel_id: int, j: dict, author_id: int):
- if 'name' in j:
- await app.db.execute("""
+ if "name" in j:
+ await app.db.execute(
+ """
UPDATE group_dm_channels
SET name = $1
WHERE id = $2
- """, j['name'], channel_id)
+ """,
+ j["name"],
+ channel_id,
+ )
await send_sys_message(
app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id
)
- if 'icon' in j:
+ if "icon" in j:
new_icon = await app.icons.update(
- 'channel-icons', channel_id, j['icon'], always_icon=True
+ "channel-icons", channel_id, j["icon"], always_icon=True
)
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE group_dm_channels
SET icon = $1
WHERE id = $2
- """, new_icon.icon_hash, channel_id)
+ """,
+ new_icon.icon_hash,
+ channel_id,
+ )
await send_sys_message(
app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id
)
-@bp.route('/', methods=['PUT', 'PATCH'])
+@bp.route("/", methods=["PUT", "PATCH"])
async def update_channel(channel_id):
"""Update a channel's information"""
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
- if ctype not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
- ChannelType.GROUP_DM):
- raise ChannelNotFound('unable to edit unsupported chan type')
+ if ctype not in (
+ ChannelType.GUILD_TEXT,
+ ChannelType.GUILD_VOICE,
+ ChannelType.GROUP_DM,
+ ):
+ raise ChannelNotFound("unable to edit unsupported chan type")
is_guild = ctype in GUILD_CHANS
if is_guild:
- await channel_perm_check(user_id, channel_id, 'manage_channels')
+ await channel_perm_check(user_id, channel_id, "manage_channels")
- j = validate(await request.get_json(),
- CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
+ j = validate(await request.get_json(), CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
# TODO: categories
update_handler = {
@@ -489,30 +567,32 @@ async def update_channel(channel_id):
chan = await app.storage.get_channel(channel_id)
if is_guild:
- await app.dispatcher.dispatch(
- 'guild', guild_id, 'CHANNEL_UPDATE', chan)
+ await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
else:
- await app.dispatcher.dispatch(
- 'channel', channel_id, 'CHANNEL_UPDATE', chan)
+ await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
return jsonify(chan)
-@bp.route('//typing', methods=['POST'])
+@bp.route("//typing", methods=["POST"])
async def trigger_typing(channel_id):
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
- await app.dispatcher.dispatch('channel', channel_id, 'TYPING_START', {
- 'channel_id': str(channel_id),
- 'user_id': str(user_id),
- 'timestamp': int(time.time()),
+ await app.dispatcher.dispatch(
+ "channel",
+ channel_id,
+ "TYPING_START",
+ {
+ "channel_id": str(channel_id),
+ "user_id": str(user_id),
+ "timestamp": int(time.time()),
+ # guild_id for lazy guilds
+ "guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
+ },
+ )
- # guild_id for lazy guilds
- 'guild_id': str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
- })
-
- return '', 204
+ return "", 204
async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
@@ -521,7 +601,8 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
if not message_id:
message_id = await app.storage.chan_last_message(channel_id)
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO user_read_state
(user_id, channel_id, last_message_id, mention_count)
VALUES
@@ -532,26 +613,31 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
SET last_message_id = $3, mention_count = 0
WHERE user_read_state.user_id = $1
AND user_read_state.channel_id = $2
- """, user_id, channel_id, message_id)
+ """,
+ user_id,
+ channel_id,
+ message_id,
+ )
if guild_id:
await app.dispatcher.dispatch_user_guild(
- user_id, guild_id, 'MESSAGE_ACK', {
- 'message_id': str(message_id),
- 'channel_id': str(channel_id)
- })
+ user_id,
+ guild_id,
+ "MESSAGE_ACK",
+ {"message_id": str(message_id), "channel_id": str(channel_id)},
+ )
else:
# we don't use ChannelDispatcher here because since
# guild_id is None, all user devices are already subscribed
# to the given channel (a dm or a group dm)
await app.dispatcher.dispatch_user(
- user_id, 'MESSAGE_ACK', {
- 'message_id': str(message_id),
- 'channel_id': str(channel_id)
- })
+ user_id,
+ "MESSAGE_ACK",
+ {"message_id": str(message_id), "channel_id": str(channel_id)},
+ )
-@bp.route('//messages//ack', methods=['POST'])
+@bp.route("//messages//ack", methods=["POST"])
async def ack_channel(channel_id, message_id):
"""Acknowledge a channel."""
user_id = await token_check()
@@ -562,40 +648,47 @@ async def ack_channel(channel_id, message_id):
await channel_ack(user_id, guild_id, channel_id, message_id)
- return jsonify({
- # token seems to be used for
- # data collection activities,
- # so we never use it.
- 'token': None
- })
+ return jsonify(
+ {
+ # token seems to be used for
+ # data collection activities,
+ # so we never use it.
+ "token": None
+ }
+ )
-@bp.route('//messages/ack', methods=['DELETE'])
+@bp.route("//messages/ack", methods=["DELETE"])
async def delete_read_state(channel_id):
"""Delete the read state of a channel."""
user_id = await token_check()
await channel_check(user_id, channel_id)
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM user_read_state
WHERE user_id = $1 AND channel_id = $2
- """, user_id, channel_id)
+ """,
+ user_id,
+ channel_id,
+ )
- return '', 204
+ return "", 204
-@bp.route('//messages/search', methods=['GET'])
+@bp.route("//messages/search", methods=["GET"])
async def _search_channel(channel_id):
"""Search in DMs or group DMs"""
user_id = await token_check()
await channel_check(user_id, channel_id)
- await channel_perm_check(user_id, channel_id, 'read_messages')
+ await channel_perm_check(user_id, channel_id, "read_messages")
j = validate(dict(request.args), SEARCH_CHANNEL)
# main search query
# the context (before/after) columns are copied from the guilds blueprint.
- rows = await app.db.fetch(f"""
+ rows = await app.db.fetch(
+ f"""
SELECT orig.id AS current_id,
COUNT(*) OVER() AS total_results,
array((SELECT messages.id AS before_id
@@ -611,28 +704,40 @@ async def _search_channel(channel_id):
ORDER BY orig.id DESC
LIMIT 50
OFFSET $2
- """, channel_id, j['offset'], j['content'])
+ """,
+ channel_id,
+ j["offset"],
+ j["content"],
+ )
return jsonify(await search_result_from_list(rows))
+
# NOTE that those functions stay here until some other
# route or code wants it.
async def _msg_update_flags(message_id: int, flags: int):
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE messages
SET flags = $1
WHERE id = $2
- """, flags, message_id)
+ """,
+ flags,
+ message_id,
+ )
async def _msg_get_flags(message_id: int):
- return await app.db.fetchval("""
+ return await app.db.fetchval(
+ """
SELECT flags
FROM messages
WHERE id = $1
- """, message_id)
+ """,
+ message_id,
+ )
async def _msg_set_flags(message_id: int, new_flags: int):
@@ -647,8 +752,9 @@ async def _msg_unset_flags(message_id: int, unset_flags: int):
await _msg_update_flags(message_id, flags)
-@bp.route('//messages//suppress-embeds',
- methods=['POST'])
+@bp.route(
+ "//messages//suppress-embeds", methods=["POST"]
+)
async def suppress_embeds(channel_id: int, message_id: int):
"""Toggle the embeds in a message.
@@ -661,29 +767,27 @@ async def suppress_embeds(channel_id: int, message_id: int):
# the checks here have been copied from the delete_message()
# handler on blueprints.channel.messages. maybe we can combine
# them someday?
- author_id = await app.db.fetchval("""
+ author_id = await app.db.fetchval(
+ """
SELECT author_id FROM messages
WHERE messages.id = $1
- """, message_id)
+ """,
+ message_id,
+ )
- by_perms = await channel_perm_check(
- user_id, channel_id, 'manage_messages', False)
+ by_perms = await channel_perm_check(user_id, channel_id, "manage_messages", False)
by_author = author_id == user_id
can_suppress = by_perms or by_author
if not can_suppress:
- raise Forbidden('Not enough permissions.')
+ raise Forbidden("Not enough permissions.")
- j = validate(
- await request.get_json(),
- {'suppress': {'type': 'boolean'}},
- )
+ j = validate(await request.get_json(), {"suppress": {"type": "boolean"}})
- suppress = j['suppress']
+ suppress = j["suppress"]
message = await app.storage.get_message(message_id)
- url_embeds = sum(
- 1 for embed in message['embeds'] if embed['type'] == 'url')
+ url_embeds = sum(1 for embed in message["embeds"] if embed["type"] == "url")
# NOTE for any future self. discord doing flags an optional thing instead
# of just giving 0 is a pretty bad idea because now i have to deal with
@@ -693,8 +797,7 @@ async def suppress_embeds(channel_id: int, message_id: int):
# delete all embeds then dispatch an update
await _msg_set_flags(message_id, MessageFlags.suppress_embeds)
- message['flags'] = \
- message.get('flags', 0) | MessageFlags.suppress_embeds
+ message["flags"] = message.get("flags", 0) | MessageFlags.suppress_embeds
await msg_update_embeds(message, [], app.storage, app.dispatcher)
elif not suppress and not url_embeds:
@@ -702,30 +805,29 @@ async def suppress_embeds(channel_id: int, message_id: int):
await _msg_unset_flags(message_id, MessageFlags.suppress_embeds)
try:
- message.pop('flags')
+ message.pop("flags")
except KeyError:
pass
app.sched.spawn(
process_url_embed(
- app.config, app.storage, app.dispatcher, app.session,
- message
+ app.config, app.storage, app.dispatcher, app.session, message
)
)
- return '', 204
+ return "", 204
-@bp.route('//messages/bulk-delete', methods=['POST'])
+@bp.route("//messages/bulk-delete", methods=["POST"])
async def bulk_delete(channel_id: int):
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
guild_id = guild_id if ctype in GUILD_CHANS else None
- await channel_perm_check(user_id, channel_id, 'manage_messages')
+ await channel_perm_check(user_id, channel_id, "manage_messages")
j = validate(await request.get_json(), BULK_DELETE)
- message_ids = set(j['messages'])
+ message_ids = set(j["messages"])
# as per discord behavior, if any id here is older than two weeks,
# we must error. a cuter behavior would be returning the message ids
@@ -738,25 +840,28 @@ async def bulk_delete(channel_id: int):
raise BadRequest(50034)
payload = {
- 'guild_id': str(guild_id),
- 'channel_id': str(channel_id),
- 'ids': list(map(str, message_ids)),
+ "guild_id": str(guild_id),
+ "channel_id": str(channel_id),
+ "ids": list(map(str, message_ids)),
}
# payload.guild_id is optional in the event, not nullable.
if guild_id is None:
- payload.pop('guild_id')
+ payload.pop("guild_id")
- res = await app.db.execute("""
+ res = await app.db.execute(
+ """
DELETE FROM messages
WHERE
channel_id = $1
AND ARRAY[id] <@ $2::bigint[]
- """, channel_id, list(message_ids))
+ """,
+ channel_id,
+ list(message_ids),
+ )
- if res == 'DELETE 0':
- raise BadRequest('No messages were removed')
+ if res == "DELETE 0":
+ raise BadRequest("No messages were removed")
- await app.dispatcher.dispatch(
- 'channel', channel_id, 'MESSAGE_DELETE_BULK', payload)
- return '', 204
+ await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_DELETE_BULK", payload)
+ return "", 204
diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py
index e897c21..5141f5b 100644
--- a/litecord/blueprints/checks.py
+++ b/litecord/blueprints/checks.py
@@ -23,46 +23,57 @@ from quart import current_app as app
from litecord.enums import ChannelType, GUILD_CHANS
from litecord.errors import (
- GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions
+ GuildNotFound,
+ ChannelNotFound,
+ Forbidden,
+ MissingPermissions,
)
from litecord.permissions import base_permissions, get_permissions
async def guild_check(user_id: int, guild_id: int):
"""Check if a user is in a guild."""
- joined_at = await app.db.fetchval("""
+ joined_at = await app.db.fetchval(
+ """
SELECT joined_at
FROM members
WHERE user_id = $1 AND guild_id = $2
- """, user_id, guild_id)
+ """,
+ user_id,
+ guild_id,
+ )
if not joined_at:
- raise GuildNotFound('guild not found')
+ raise GuildNotFound("guild not found")
async def guild_owner_check(user_id: int, guild_id: int):
"""Check if a user is the owner of the guild."""
- owner_id = await app.db.fetchval("""
+ owner_id = await app.db.fetchval(
+ """
SELECT owner_id
FROM guilds
WHERE guilds.id = $1
- """, guild_id)
+ """,
+ guild_id,
+ )
if not owner_id:
raise GuildNotFound()
if user_id != owner_id:
- raise Forbidden('You are not the owner of the guild')
+ raise Forbidden("You are not the owner of the guild")
-async def channel_check(user_id, channel_id, *,
- only: Union[ChannelType, List[ChannelType]] = None):
+async def channel_check(
+ user_id, channel_id, *, only: Union[ChannelType, List[ChannelType]] = None
+):
"""Check if the current user is authorized
to read the channel's information."""
chan_type = await app.storage.get_chan_type(channel_id)
if chan_type is None:
- raise ChannelNotFound('channel type not found')
+ raise ChannelNotFound("channel type not found")
ctype = ChannelType(chan_type)
@@ -70,14 +81,17 @@ async def channel_check(user_id, channel_id, *,
only = [only]
if only and ctype not in only:
- raise ChannelNotFound('invalid channel type')
+ raise ChannelNotFound("invalid channel type")
if ctype in GUILD_CHANS:
- guild_id = await app.db.fetchval("""
+ guild_id = await app.db.fetchval(
+ """
SELECT guild_id
FROM guild_channels
WHERE guild_channels.id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
await guild_check(user_id, guild_id)
return ctype, guild_id
@@ -87,11 +101,14 @@ async def channel_check(user_id, channel_id, *,
return ctype, peer_id
if ctype == ChannelType.GROUP_DM:
- owner_id = await app.db.fetchval("""
+ owner_id = await app.db.fetchval(
+ """
SELECT owner_id
FROM group_dm_channels
WHERE id = $1
- """, channel_id)
+ """,
+ channel_id,
+ )
return ctype, owner_id
@@ -102,18 +119,17 @@ async def guild_perm_check(user_id, guild_id, permission: str):
hasperm = getattr(base_perms.bits, permission)
if not hasperm:
- raise MissingPermissions('Missing permissions.')
+ raise MissingPermissions("Missing permissions.")
return bool(hasperm)
-async def channel_perm_check(user_id, channel_id,
- permission: str, raise_err=True):
+async def channel_perm_check(user_id, channel_id, permission: str, raise_err=True):
"""Check channel permissions for a user."""
base_perms = await get_permissions(user_id, channel_id)
hasperm = getattr(base_perms.bits, permission)
if not hasperm and raise_err:
- raise MissingPermissions('Missing permissions.')
+ raise MissingPermissions("Missing permissions.")
return bool(hasperm)
diff --git a/litecord/blueprints/dm_channels.py b/litecord/blueprints/dm_channels.py
index ac92855..76bd24a 100644
--- a/litecord/blueprints/dm_channels.py
+++ b/litecord/blueprints/dm_channels.py
@@ -29,21 +29,29 @@ from litecord.system_messages import send_sys_message
from litecord.pubsub.channel import gdm_recipient_view
log = Logger(__name__)
-bp = Blueprint('dm_channels', __name__)
+bp = Blueprint("dm_channels", __name__)
async def _raw_gdm_add(channel_id, user_id):
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO group_dm_members (id, member_id)
VALUES ($1, $2)
- """, channel_id, user_id)
+ """,
+ channel_id,
+ user_id,
+ )
async def _raw_gdm_remove(channel_id, user_id):
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM group_dm_members
WHERE id = $1 AND member_id = $2
- """, channel_id, user_id)
+ """,
+ channel_id,
+ user_id,
+ )
async def gdm_create(user_id, peer_id) -> int:
@@ -53,24 +61,32 @@ async def gdm_create(user_id, peer_id) -> int:
"""
channel_id = get_snowflake()
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
- """, channel_id, ChannelType.GROUP_DM.value)
+ """,
+ channel_id,
+ ChannelType.GROUP_DM.value,
+ )
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO group_dm_channels (id, owner_id, name, icon)
VALUES ($1, $2, NULL, NULL)
- """, channel_id, user_id)
+ """,
+ channel_id,
+ user_id,
+ )
await _raw_gdm_add(channel_id, user_id)
await _raw_gdm_add(channel_id, peer_id)
- await app.dispatcher.sub('channel', channel_id, user_id)
- await app.dispatcher.sub('channel', channel_id, peer_id)
+ await app.dispatcher.sub("channel", channel_id, user_id)
+ await app.dispatcher.sub("channel", channel_id, peer_id)
chan = await app.storage.get_channel(channel_id)
- await app.dispatcher.dispatch('channel', channel_id, 'CHANNEL_CREATE', chan)
+ await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_CREATE", chan)
return channel_id
@@ -89,17 +105,16 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
# the reasoning behind gdm_recipient_view is in its docstring.
await app.dispatcher.dispatch(
- 'user', peer_id, 'CHANNEL_CREATE', gdm_recipient_view(chan, peer_id))
+ "user", peer_id, "CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)
+ )
- await app.dispatcher.dispatch(
- 'channel', channel_id, 'CHANNEL_UPDATE', chan)
+ await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
- await app.dispatcher.sub('channel', peer_id)
+ await app.dispatcher.sub("channel", peer_id)
if user_id:
await send_sys_message(
- app, channel_id, MessageType.RECIPIENT_ADD,
- user_id, peer_id
+ app, channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id
)
@@ -116,22 +131,22 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch(
- 'user', peer_id, 'CHANNEL_DELETE', gdm_recipient_view(chan, user_id))
+ "user", peer_id, "CHANNEL_DELETE", gdm_recipient_view(chan, user_id)
+ )
- await app.dispatcher.unsub('channel', peer_id)
+ await app.dispatcher.unsub("channel", peer_id)
await app.dispatcher.dispatch(
- 'channel', channel_id, 'CHANNEL_RECIPIENT_REMOVE', {
- 'channel_id': str(channel_id),
- 'user': await app.storage.get_user(peer_id)
- }
+ "channel",
+ channel_id,
+ "CHANNEL_RECIPIENT_REMOVE",
+ {"channel_id": str(channel_id), "user": await app.storage.get_user(peer_id)},
)
author_id = peer_id if user_id is None else user_id
await send_sys_message(
- app, channel_id, MessageType.RECIPIENT_REMOVE,
- author_id, peer_id
+ app, channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id
)
@@ -139,40 +154,51 @@ async def gdm_destroy(channel_id):
"""Destroy a Group DM."""
chan = await app.storage.get_channel(channel_id)
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM group_dm_members
WHERE id = $1
- """, channel_id)
-
- await app.db.execute("""
- DELETE FROM group_dm_channels
- WHERE id = $1
- """, channel_id)
-
- await app.db.execute("""
- DELETE FROM channels
- WHERE id = $1
- """, channel_id)
-
- await app.dispatcher.dispatch(
- 'channel', channel_id, 'CHANNEL_DELETE', chan
+ """,
+ channel_id,
)
- await app.dispatcher.remove('channel', channel_id)
+ await app.db.execute(
+ """
+ DELETE FROM group_dm_channels
+ WHERE id = $1
+ """,
+ channel_id,
+ )
+
+ await app.db.execute(
+ """
+ DELETE FROM channels
+ WHERE id = $1
+ """,
+ channel_id,
+ )
+
+ await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_DELETE", chan)
+
+ await app.dispatcher.remove("channel", channel_id)
async def gdm_is_member(channel_id: int, user_id: int) -> bool:
"""Return if the given user is a member of the Group DM."""
- row = await app.db.fetchval("""
+ row = await app.db.fetchval(
+ """
SELECT id
FROM group_dm_members
WHERE id = $1 AND member_id = $2
- """, channel_id, user_id)
+ """,
+ channel_id,
+ user_id,
+ )
return row is not None
-@bp.route('//recipients/', methods=['PUT'])
+@bp.route("//recipients/", methods=["PUT"])
async def add_to_group_dm(dm_chan, peer_id):
"""Adds a member to a group dm OR creates a group dm."""
user_id = await token_check()
@@ -182,8 +208,7 @@ async def add_to_group_dm(dm_chan, peer_id):
# other_id is the peer of the dm if the given channel is a dm
ctype, other_id = await channel_check(
- user_id, dm_chan,
- only=[ChannelType.DM, ChannelType.GROUP_DM]
+ user_id, dm_chan, only=[ChannelType.DM, ChannelType.GROUP_DM]
)
# check relationship with the given user id
@@ -191,30 +216,24 @@ async def add_to_group_dm(dm_chan, peer_id):
friends = await app.user_storage.are_friends_with(user_id, peer_id)
if not friends:
- raise BadRequest('Cant insert peer into dm')
+ raise BadRequest("Cant insert peer into dm")
if ctype == ChannelType.DM:
- dm_chan = await gdm_create(
- user_id, other_id
- )
+ dm_chan = await gdm_create(user_id, other_id)
await gdm_add_recipient(dm_chan, peer_id, user_id=user_id)
- return jsonify(
- await app.storage.get_channel(dm_chan)
- )
+ return jsonify(await app.storage.get_channel(dm_chan))
-@bp.route('//recipients/', methods=['DELETE'])
+@bp.route("//recipients/", methods=["DELETE"])
async def remove_from_group_dm(dm_chan, peer_id):
"""Remove users from group dm."""
user_id = await token_check()
- _ctype, owner_id = await channel_check(
- user_id, dm_chan, only=ChannelType.GROUP_DM
- )
+ _ctype, owner_id = await channel_check(user_id, dm_chan, only=ChannelType.GROUP_DM)
if owner_id != user_id:
- raise Forbidden('You are now the owner of the group DM')
+ raise Forbidden("You are now the owner of the group DM")
await gdm_remove_recipient(dm_chan, peer_id)
- return '', 204
+ return "", 204
diff --git a/litecord/blueprints/dms.py b/litecord/blueprints/dms.py
index 4289161..975d7ea 100644
--- a/litecord/blueprints/dms.py
+++ b/litecord/blueprints/dms.py
@@ -30,15 +30,13 @@ from ..snowflake import get_snowflake
from .auth import token_check
-from litecord.blueprints.dm_channels import (
- gdm_create, gdm_add_recipient
-)
+from litecord.blueprints.dm_channels import gdm_create, gdm_add_recipient
log = Logger(__name__)
-bp = Blueprint('dms', __name__)
+bp = Blueprint("dms", __name__)
-@bp.route('/@me/channels', methods=['GET'])
+@bp.route("/@me/channels", methods=["GET"])
async def get_dms():
"""Get the open DMs for the user."""
user_id = await token_check()
@@ -53,11 +51,15 @@ async def try_dm_state(user_id: int, dm_id: int):
Does not do anything if the user is already
in the dm state.
"""
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO dm_channel_state (user_id, dm_id)
VALUES ($1, $2)
ON CONFLICT DO NOTHING
- """, user_id, dm_id)
+ """,
+ user_id,
+ dm_id,
+ )
async def jsonify_dm(dm_id: int, user_id: int):
@@ -69,12 +71,16 @@ async def create_dm(user_id, recipient_id):
"""Create a new dm with a user,
or get the existing DM id if it already exists."""
- dm_id = await app.db.fetchval("""
+ dm_id = await app.db.fetchval(
+ """
SELECT id
FROM dm_channels
WHERE (party1_id = $1 OR party2_id = $1) AND
(party1_id = $2 OR party2_id = $2)
- """, user_id, recipient_id)
+ """,
+ user_id,
+ recipient_id,
+ )
if dm_id:
return await jsonify_dm(dm_id, user_id)
@@ -82,15 +88,24 @@ async def create_dm(user_id, recipient_id):
# if no dm was found, create a new one
dm_id = get_snowflake()
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
- """, dm_id, ChannelType.DM.value)
+ """,
+ dm_id,
+ ChannelType.DM.value,
+ )
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO dm_channels (id, party1_id, party2_id)
VALUES ($1, $2, $3)
- """, dm_id, user_id, recipient_id)
+ """,
+ dm_id,
+ user_id,
+ recipient_id,
+ )
# the dm state is something we use
# to give the currently "open dms"
@@ -103,24 +118,24 @@ async def create_dm(user_id, recipient_id):
return await jsonify_dm(dm_id, user_id)
-@bp.route('/@me/channels', methods=['POST'])
+@bp.route("/@me/channels", methods=["POST"])
async def start_dm():
"""Create a DM with a user."""
user_id = await token_check()
j = validate(await request.get_json(), CREATE_DM)
- recipient_id = j['recipient_id']
+ recipient_id = j["recipient_id"]
return await create_dm(user_id, recipient_id)
-@bp.route('//channels', methods=['POST'])
+@bp.route("//channels", methods=["POST"])
async def create_group_dm(p_user_id: int):
"""Create a DM or a Group DM with user(s)."""
user_id = await token_check()
assert user_id == p_user_id
j = validate(await request.get_json(), CREATE_GROUP_DM)
- recipients = j['recipients']
+ recipients = j["recipients"]
if len(recipients) == 1:
# its a group dm with 1 user... a dm!
diff --git a/litecord/blueprints/gateway.py b/litecord/blueprints/gateway.py
index 05303f0..a50a60e 100644
--- a/litecord/blueprints/gateway.py
+++ b/litecord/blueprints/gateway.py
@@ -23,37 +23,38 @@ from quart import Blueprint, jsonify, current_app as app
from ..auth import token_check
-bp = Blueprint('gateway', __name__)
+bp = Blueprint("gateway", __name__)
def get_gw():
"""Get the gateway's web"""
- proto = 'wss://' if app.config['IS_SSL'] else 'ws://'
+ proto = "wss://" if app.config["IS_SSL"] else "ws://"
return f'{proto}{app.config["WEBSOCKET_URL"]}'
-@bp.route('/gateway')
+@bp.route("/gateway")
def api_gateway():
"""Get the raw URL."""
- return jsonify({
- 'url': get_gw()
- })
+ return jsonify({"url": get_gw()})
-@bp.route('/gateway/bot')
+@bp.route("/gateway/bot")
async def api_gateway_bot():
user_id = await token_check()
- guild_count = await app.db.fetchval("""
+ guild_count = await app.db.fetchval(
+ """
SELECT COUNT(*)
FROM members
WHERE user_id = $1
- """, user_id)
+ """,
+ user_id,
+ )
shards = max(int(guild_count / 1000), 1)
# get _ws.session ratelimit
- ratelimit = app.ratelimiter.get_ratelimit('_ws.session')
+ ratelimit = app.ratelimiter.get_ratelimit("_ws.session")
bucket = ratelimit.get_bucket(user_id)
# timestamp of bucket reset
@@ -62,13 +63,14 @@ async def api_gateway_bot():
# how many seconds until bucket reset
reset_after_ts = reset_ts - time.time()
- return jsonify({
- 'url': get_gw(),
- 'shards': shards,
-
- 'session_start_limit': {
- 'total': bucket.requests,
- 'remaining': bucket._tokens,
- 'reset_after': int(reset_after_ts * 1000),
+ return jsonify(
+ {
+ "url": get_gw(),
+ "shards": shards,
+ "session_start_limit": {
+ "total": bucket.requests,
+ "remaining": bucket._tokens,
+ "reset_after": int(reset_after_ts * 1000),
+ },
}
- })
+ )
diff --git a/litecord/blueprints/guild/__init__.py b/litecord/blueprints/guild/__init__.py
index 7e37ef6..31c2640 100644
--- a/litecord/blueprints/guild/__init__.py
+++ b/litecord/blueprints/guild/__init__.py
@@ -23,5 +23,4 @@ from .channels import bp as guild_channels
from .mod import bp as guild_mod
from .emoji import bp as guild_emoji
-__all__ = ['guild_roles', 'guild_members', 'guild_channels', 'guild_mod',
- 'guild_emoji']
+__all__ = ["guild_roles", "guild_members", "guild_channels", "guild_mod", "guild_emoji"]
diff --git a/litecord/blueprints/guild/channels.py b/litecord/blueprints/guild/channels.py
index 2b0467b..c8a2d2e 100644
--- a/litecord/blueprints/guild/channels.py
+++ b/litecord/blueprints/guild/channels.py
@@ -25,23 +25,23 @@ from litecord.errors import BadRequest
from litecord.enums import ChannelType
from litecord.blueprints.guild.roles import gen_pairs
-from litecord.schemas import (
- validate, ROLE_UPDATE_POSITION, CHAN_CREATE
-)
-from litecord.blueprints.checks import (
- guild_check, guild_owner_check, guild_perm_check
-)
+from litecord.schemas import validate, ROLE_UPDATE_POSITION, CHAN_CREATE
+from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
-bp = Blueprint('guild_channels', __name__)
+bp = Blueprint("guild_channels", __name__)
async def _specific_chan_create(channel_id, ctype, **kwargs):
if ctype == ChannelType.GUILD_TEXT:
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO guild_text_channels (id, topic)
VALUES ($1, $2)
- """, channel_id, kwargs.get('topic', ''))
+ """,
+ channel_id,
+ kwargs.get("topic", ""),
+ )
elif ctype == ChannelType.GUILD_VOICE:
await app.db.execute(
"""
@@ -49,34 +49,48 @@ async def _specific_chan_create(channel_id, ctype, **kwargs):
VALUES ($1, $2, $3)
""",
channel_id,
- kwargs.get('bitrate', 64),
- kwargs.get('user_limit', 0)
+ kwargs.get("bitrate", 64),
+ kwargs.get("user_limit", 0),
)
-async def create_guild_channel(guild_id: int, channel_id: int,
- ctype: ChannelType, **kwargs):
+async def create_guild_channel(
+ guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
+):
"""Create a channel in a guild."""
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
- """, channel_id, ctype.value)
+ """,
+ channel_id,
+ ctype.value,
+ )
# calc new pos
- max_pos = await app.db.fetchval("""
+ max_pos = await app.db.fetchval(
+ """
SELECT MAX(position)
FROM guild_channels
WHERE guild_id = $1
- """, guild_id)
+ """,
+ guild_id,
+ )
# account for the first channel in a guild too
max_pos = max_pos or 0
# all channels go to guild_channels
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO guild_channels (id, guild_id, name, position)
VALUES ($1, $2, $3, $4)
- """, channel_id, guild_id, kwargs['name'], max_pos + 1)
+ """,
+ channel_id,
+ guild_id,
+ kwargs["name"],
+ max_pos + 1,
+ )
# the rest of sql magic is dependant on the channel
# we're creating (a text or voice or category),
@@ -84,35 +98,32 @@ async def create_guild_channel(guild_id: int, channel_id: int,
await _specific_chan_create(channel_id, ctype, **kwargs)
-@bp.route('//channels', methods=['GET'])
+@bp.route("//channels", methods=["GET"])
async def get_guild_channels(guild_id):
"""Get the list of channels in a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
- return jsonify(
- await app.storage.get_channel_data(guild_id))
+ return jsonify(await app.storage.get_channel_data(guild_id))
-@bp.route('//channels', methods=['POST'])
+@bp.route("//channels", methods=["POST"])
async def create_channel(guild_id):
"""Create a channel in a guild."""
user_id = await token_check()
j = validate(await request.get_json(), CHAN_CREATE)
await guild_check(user_id, guild_id)
- await guild_perm_check(user_id, guild_id, 'manage_channels')
+ await guild_perm_check(user_id, guild_id, "manage_channels")
- channel_type = j.get('type', ChannelType.GUILD_TEXT)
+ channel_type = j.get("type", ChannelType.GUILD_TEXT)
channel_type = ChannelType(channel_type)
- if channel_type not in (ChannelType.GUILD_TEXT,
- ChannelType.GUILD_VOICE):
- raise BadRequest('Invalid channel type')
+ if channel_type not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE):
+ raise BadRequest("Invalid channel type")
new_channel_id = get_snowflake()
- await create_guild_channel(
- guild_id, new_channel_id, channel_type, **j)
+ await create_guild_channel(guild_id, new_channel_id, channel_type, **j)
# TODO: do a better method
# subscribe the currently subscribed users to the new channel
@@ -120,14 +131,13 @@ async def create_channel(guild_id):
# since GuildDispatcher calls Storage.get_channel_ids,
# it will subscribe all users to the newly created channel.
- guild_pubsub = app.dispatcher.backends['guild']
+ guild_pubsub = app.dispatcher.backends["guild"]
user_ids = guild_pubsub.state[guild_id]
for uid in user_ids:
- await app.dispatcher.sub('guild', guild_id, uid)
+ await app.dispatcher.sub("guild", guild_id, uid)
chan = await app.storage.get_channel(new_channel_id)
- await app.dispatcher.dispatch_guild(
- guild_id, 'CHANNEL_CREATE', chan)
+ await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_CREATE", chan)
return jsonify(chan)
@@ -135,7 +145,7 @@ async def _chan_update_dispatch(guild_id: int, channel_id: int):
"""Fetch new information about the channel and dispatch
a single CHANNEL_UPDATE event to the guild."""
chan = await app.storage.get_channel(channel_id)
- await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_UPDATE', chan)
+ await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", chan)
async def _do_single_swap(guild_id: int, pair: tuple):
@@ -149,13 +159,14 @@ async def _do_single_swap(guild_id: int, pair: tuple):
conn = await app.db.acquire()
async with conn.transaction():
- await conn.executemany("""
+ await conn.executemany(
+ """
UPDATE guild_channels
SET position = $1
WHERE id = $2 AND guild_id = $3
- """, [
- (new_pos_1, channel_1, guild_id),
- (new_pos_2, channel_2, guild_id)])
+ """,
+ [(new_pos_1, channel_1, guild_id), (new_pos_2, channel_2, guild_id)],
+ )
await app.db.release(conn)
@@ -173,30 +184,26 @@ async def _do_channel_swaps(guild_id: int, swap_pairs: list):
await _do_single_swap(guild_id, pair)
-@bp.route('//channels', methods=['PATCH'])
+@bp.route("//channels", methods=["PATCH"])
async def modify_channel_pos(guild_id):
"""Change positions of channels in a guild."""
user_id = await token_check()
await guild_owner_check(user_id, guild_id)
- await guild_perm_check(user_id, guild_id, 'manage_channels')
+ await guild_perm_check(user_id, guild_id, "manage_channels")
# same thing as guild.roles, so we use
# the same schema and all.
raw_j = await request.get_json()
- j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
- j = j['roles']
+ j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
+ j = j["roles"]
channels = await app.storage.get_channel_data(guild_id)
- channel_positions = {chan['position']: int(chan['id'])
- for chan in channels}
+ channel_positions = {chan["position"]: int(chan["id"]) for chan in channels}
- swap_pairs = gen_pairs(
- j,
- channel_positions
- )
+ swap_pairs = gen_pairs(j, channel_positions)
await _do_channel_swaps(guild_id, swap_pairs)
- return '', 204
+ return "", 204
diff --git a/litecord/blueprints/guild/emoji.py b/litecord/blueprints/guild/emoji.py
index cc48dfe..9a5a7d5 100644
--- a/litecord/blueprints/guild/emoji.py
+++ b/litecord/blueprints/guild/emoji.py
@@ -27,78 +27,85 @@ from litecord.types import KILOBYTES
from litecord.images import parse_data_uri
from litecord.errors import BadRequest
-bp = Blueprint('guild.emoji', __name__)
+bp = Blueprint("guild.emoji", __name__)
async def _dispatch_emojis(guild_id):
"""Dispatch a Guild Emojis Update payload to a guild."""
- await app.dispatcher.dispatch('guild', guild_id, 'GUILD_EMOJIS_UPDATE', {
- 'guild_id': str(guild_id),
- 'emojis': await app.storage.get_guild_emojis(guild_id)
- })
+ await app.dispatcher.dispatch(
+ "guild",
+ guild_id,
+ "GUILD_EMOJIS_UPDATE",
+ {
+ "guild_id": str(guild_id),
+ "emojis": await app.storage.get_guild_emojis(guild_id),
+ },
+ )
-@bp.route('//emojis', methods=['GET'])
+@bp.route("//emojis", methods=["GET"])
async def _get_guild_emoji(guild_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
- return jsonify(
- await app.storage.get_guild_emojis(guild_id)
- )
+ return jsonify(await app.storage.get_guild_emojis(guild_id))
-@bp.route('//emojis/', methods=['GET'])
+@bp.route("//emojis/", methods=["GET"])
async def _get_guild_emoji_one(guild_id, emoji_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
- return jsonify(
- await app.storage.get_emoji(emoji_id)
- )
+ return jsonify(await app.storage.get_emoji(emoji_id))
async def _guild_emoji_size_check(guild_id: int, mime: str):
limit = 50
- if await app.storage.has_feature(guild_id, 'MORE_EMOJI'):
+ if await app.storage.has_feature(guild_id, "MORE_EMOJI"):
limit = 200
# NOTE: I'm assuming you can have 200 animated emojis.
- select_animated = mime == 'image/gif'
+ select_animated = mime == "image/gif"
- total_emoji = await app.db.fetchval("""
+ total_emoji = await app.db.fetchval(
+ """
SELECT COUNT(*) FROM guild_emoji
WHERE guild_id = $1 AND animated = $2
- """, guild_id, select_animated)
+ """,
+ guild_id,
+ select_animated,
+ )
if total_emoji >= limit:
# TODO: really return a BadRequest? needs more looking.
- raise BadRequest(f'too many emoji ({limit})')
+ raise BadRequest(f"too many emoji ({limit})")
-@bp.route('//emojis', methods=['POST'])
+@bp.route("//emojis", methods=["POST"])
async def _put_emoji(guild_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
- await guild_perm_check(user_id, guild_id, 'manage_emojis')
+ await guild_perm_check(user_id, guild_id, "manage_emojis")
j = validate(await request.get_json(), NEW_EMOJI)
# we have to parse it before passing on so that we know which
# size to check.
- mime, _ = parse_data_uri(j['image'])
+ mime, _ = parse_data_uri(j["image"])
await _guild_emoji_size_check(guild_id, mime)
emoji_id = get_snowflake()
icon = await app.icons.put(
- 'emoji', emoji_id, j['image'],
-
+ "emoji",
+ emoji_id,
+ j["image"],
# limits to emojis
- bsize=128 * KILOBYTES, size=(128, 128)
+ bsize=128 * KILOBYTES,
+ size=(128, 128),
)
if not icon:
- return '', 400
+ return "", 400
# TODO: better way to detect animated emoji rather than just gifs,
# maybe a list perhaps?
@@ -109,25 +116,25 @@ async def _put_emoji(guild_id):
VALUES
($1, $2, $3, $4, $5, $6)
""",
- emoji_id, guild_id, user_id,
- j['name'],
+ emoji_id,
+ guild_id,
+ user_id,
+ j["name"],
icon.icon_hash,
- icon.mime == 'image/gif'
+ icon.mime == "image/gif",
)
await _dispatch_emojis(guild_id)
- return jsonify(
- await app.storage.get_emoji(emoji_id)
- )
+ return jsonify(await app.storage.get_emoji(emoji_id))
-@bp.route('//emojis/', methods=['PATCH'])
+@bp.route("//emojis/", methods=["PATCH"])
async def _patch_emoji(guild_id, emoji_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
- await guild_perm_check(user_id, guild_id, 'manage_emojis')
+ await guild_perm_check(user_id, guild_id, "manage_emojis")
j = validate(await request.get_json(), PATCH_EMOJI)
emoji = await app.storage.get_emoji(emoji_id)
@@ -135,34 +142,39 @@ async def _patch_emoji(guild_id, emoji_id):
# if emoji.name is still the same, we don't update anything
# or send ane events, just return the same emoji we'd send
# as if we updated it.
- if j['name'] == emoji['name']:
+ if j["name"] == emoji["name"]:
return jsonify(emoji)
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE guild_emoji
SET name = $1
WHERE id = $2
- """, j['name'], emoji_id)
+ """,
+ j["name"],
+ emoji_id,
+ )
await _dispatch_emojis(guild_id)
- return jsonify(
- await app.storage.get_emoji(emoji_id)
- )
+ return jsonify(await app.storage.get_emoji(emoji_id))
-@bp.route('//emojis/', methods=['DELETE'])
+@bp.route("//emojis/", methods=["DELETE"])
async def _del_emoji(guild_id, emoji_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
- await guild_perm_check(user_id, guild_id, 'manage_emojis')
+ await guild_perm_check(user_id, guild_id, "manage_emojis")
# TODO: check if actually deleted
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM guild_emoji
WHERE id = $2
- """, emoji_id)
+ """,
+ emoji_id,
+ )
await _dispatch_emojis(guild_id)
- return '', 204
+ return "", 204
diff --git a/litecord/blueprints/guild/members.py b/litecord/blueprints/guild/members.py
index 3558168..4ea9abb 100644
--- a/litecord/blueprints/guild/members.py
+++ b/litecord/blueprints/guild/members.py
@@ -22,18 +22,14 @@ from quart import Blueprint, request, current_app as app, jsonify
from litecord.blueprints.auth import token_check
from litecord.errors import BadRequest
-from litecord.schemas import (
- validate, MEMBER_UPDATE
-)
+from litecord.schemas import validate, MEMBER_UPDATE
-from litecord.blueprints.checks import (
- guild_check, guild_owner_check, guild_perm_check
-)
+from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
-bp = Blueprint('guild_members', __name__)
+bp = Blueprint("guild_members", __name__)
-@bp.route('//members/', methods=['GET'])
+@bp.route("//members/", methods=["GET"])
async def get_guild_member(guild_id, member_id):
"""Get a member's information in a guild."""
user_id = await token_check()
@@ -42,7 +38,7 @@ async def get_guild_member(guild_id, member_id):
return jsonify(member)
-@bp.route('//members', methods=['GET'])
+@bp.route("//members", methods=["GET"])
async def get_members(guild_id):
"""Get members inside a guild."""
user_id = await token_check()
@@ -50,34 +46,41 @@ async def get_members(guild_id):
j = await request.get_json()
- limit, after = int(j.get('limit', 1)), j.get('after', 0)
+ limit, after = int(j.get("limit", 1)), j.get("after", 0)
if limit < 1 or limit > 1000:
- raise BadRequest('limit not in 1-1000 range')
+ raise BadRequest("limit not in 1-1000 range")
- user_ids = await app.db.fetch(f"""
+ user_ids = await app.db.fetch(
+ f"""
SELECT user_id
WHERE guild_id = $1, user_id > $2
LIMIT {limit}
ORDER BY user_id ASC
- """, guild_id, after)
+ """,
+ guild_id,
+ after,
+ )
user_ids = [r[0] for r in user_ids]
members = await app.storage.get_member_multi(guild_id, user_ids)
return jsonify(members)
-async def _update_member_roles(guild_id: int, member_id: int,
- wanted_roles: set):
+async def _update_member_roles(guild_id: int, member_id: int, wanted_roles: set):
"""Update the roles a member has."""
# first, fetch all current roles
- roles = await app.db.fetch("""
+ roles = await app.db.fetch(
+ """
SELECT role_id from member_roles
WHERE guild_id = $1 AND user_id = $2
- """, guild_id, member_id)
+ """,
+ guild_id,
+ member_id,
+ )
- roles = [r['role_id'] for r in roles]
+ roles = [r["role_id"] for r in roles]
roles = set(roles)
wanted_roles = set(wanted_roles)
@@ -96,26 +99,30 @@ async def _update_member_roles(guild_id: int, member_id: int,
async with conn.transaction():
# add roles
- await app.db.executemany("""
+ await app.db.executemany(
+ """
INSERT INTO member_roles (user_id, guild_id, role_id)
VALUES ($1, $2, $3)
- """, [(member_id, guild_id, role_id)
- for role_id in added_roles])
+ """,
+ [(member_id, guild_id, role_id) for role_id in added_roles],
+ )
# remove roles
- await app.db.executemany("""
+ await app.db.executemany(
+ """
DELETE FROM member_roles
WHERE
user_id = $1
AND guild_id = $2
AND role_id = $3
- """, [(member_id, guild_id, role_id)
- for role_id in removed_roles])
+ """,
+ [(member_id, guild_id, role_id) for role_id in removed_roles],
+ )
await app.db.release(conn)
-@bp.route('//members/', methods=['PATCH'])
+@bp.route("//members/", methods=["PATCH"])
async def modify_guild_member(guild_id, member_id):
"""Modify a members' information in a guild."""
user_id = await token_check()
@@ -124,96 +131,112 @@ async def modify_guild_member(guild_id, member_id):
j = validate(await request.get_json(), MEMBER_UPDATE)
nick_flag = False
- if 'nick' in j:
- await guild_perm_check(user_id, guild_id, 'manage_nicknames')
+ if "nick" in j:
+ await guild_perm_check(user_id, guild_id, "manage_nicknames")
- nick = j['nick'] or None
+ nick = j["nick"] or None
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE members
SET nickname = $1
WHERE user_id = $2 AND guild_id = $3
- """, nick, member_id, guild_id)
+ """,
+ nick,
+ member_id,
+ guild_id,
+ )
nick_flag = True
- if 'mute' in j:
- await guild_perm_check(user_id, guild_id, 'mute_members')
+ if "mute" in j:
+ await guild_perm_check(user_id, guild_id, "mute_members")
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE members
SET muted = $1
WHERE user_id = $2 AND guild_id = $3
- """, j['mute'], member_id, guild_id)
+ """,
+ j["mute"],
+ member_id,
+ guild_id,
+ )
- if 'deaf' in j:
- await guild_perm_check(user_id, guild_id, 'deafen_members')
+ if "deaf" in j:
+ await guild_perm_check(user_id, guild_id, "deafen_members")
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE members
SET deafened = $1
WHERE user_id = $2 AND guild_id = $3
- """, j['deaf'], member_id, guild_id)
+ """,
+ j["deaf"],
+ member_id,
+ guild_id,
+ )
- if 'channel_id' in j:
+ if "channel_id" in j:
# TODO: check MOVE_MEMBERS and CONNECT to the channel
# TODO: change the member's voice channel
pass
- if 'roles' in j:
- await guild_perm_check(user_id, guild_id, 'manage_roles')
- await _update_member_roles(guild_id, member_id, j['roles'])
+ if "roles" in j:
+ await guild_perm_check(user_id, guild_id, "manage_roles")
+ await _update_member_roles(guild_id, member_id, j["roles"])
member = await app.storage.get_member_data_one(guild_id, member_id)
- member.pop('joined_at')
+ member.pop("joined_at")
# call pres_update for role and nick changes.
- partial = {
- 'roles': member['roles']
- }
+ partial = {"roles": member["roles"]}
if nick_flag:
- partial['nick'] = j['nick']
+ partial["nick"] = j["nick"]
await app.dispatcher.dispatch(
- 'lazy_guild', guild_id, 'pres_update', user_id, partial)
+ "lazy_guild", guild_id, "pres_update", user_id, partial
+ )
- await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
- 'guild_id': str(guild_id)
- }, **member})
+ await app.dispatcher.dispatch_guild(
+ guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
+ )
- return '', 204
+ return "", 204
-@bp.route('//members/@me/nick', methods=['PATCH'])
+@bp.route("//members/@me/nick", methods=["PATCH"])
async def update_nickname(guild_id):
"""Update a member's nickname in a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
- j = validate(await request.get_json(), {
- 'nick': {'type': 'nickname'}
- })
+ j = validate(await request.get_json(), {"nick": {"type": "nickname"}})
- nick = j['nick'] or None
+ nick = j["nick"] or None
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE members
SET nickname = $1
WHERE user_id = $2 AND guild_id = $3
- """, nick, user_id, guild_id)
+ """,
+ nick,
+ user_id,
+ guild_id,
+ )
member = await app.storage.get_member_data_one(guild_id, user_id)
- member.pop('joined_at')
+ member.pop("joined_at")
# call pres_update for nick changes, etc.
await app.dispatcher.dispatch(
- 'lazy_guild', guild_id, 'pres_update', user_id, {
- 'nick': j['nick']
- })
+ "lazy_guild", guild_id, "pres_update", user_id, {"nick": j["nick"]}
+ )
- await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
- 'guild_id': str(guild_id)
- }, **member})
+ await app.dispatcher.dispatch_guild(
+ guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
+ )
- return j['nick']
+ return j["nick"]
diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py
index 7878d3b..5949fd0 100644
--- a/litecord/blueprints/guild/mod.py
+++ b/litecord/blueprints/guild/mod.py
@@ -24,33 +24,38 @@ from litecord.blueprints.checks import guild_perm_check
from litecord.schemas import validate, GUILD_PRUNE
-bp = Blueprint('guild_moderation', __name__)
+bp = Blueprint("guild_moderation", __name__)
async def remove_member(guild_id: int, member_id: int):
"""Do common tasks related to deleting a member from the guild,
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM members
WHERE guild_id = $1 AND user_id = $2
- """, guild_id, member_id)
+ """,
+ guild_id,
+ member_id,
+ )
await app.dispatcher.dispatch_user_guild(
- member_id, guild_id, 'GUILD_DELETE', {
- 'guild_id': str(guild_id),
- 'unavailable': False,
- })
+ member_id,
+ guild_id,
+ "GUILD_DELETE",
+ {"guild_id": str(guild_id), "unavailable": False},
+ )
- await app.dispatcher.unsub('guild', guild_id, member_id)
+ await app.dispatcher.unsub("guild", guild_id, member_id)
- await app.dispatcher.dispatch(
- 'lazy_guild', guild_id, 'remove_member', member_id)
+ await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
- await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
- 'guild_id': str(guild_id),
- 'user': await app.storage.get_user(member_id),
- })
+ await app.dispatcher.dispatch_guild(
+ guild_id,
+ "GUILD_MEMBER_REMOVE",
+ {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
+ )
async def remove_member_multi(guild_id: int, members: list):
@@ -59,84 +64,100 @@ async def remove_member_multi(guild_id: int, members: list):
await remove_member(guild_id, member_id)
-@bp.route('//members/', methods=['DELETE'])
+@bp.route("//members/", methods=["DELETE"])
async def kick_guild_member(guild_id, member_id):
"""Remove a member from a guild."""
user_id = await token_check()
- await guild_perm_check(user_id, guild_id, 'kick_members')
+ await guild_perm_check(user_id, guild_id, "kick_members")
await remove_member(guild_id, member_id)
- return '', 204
+ return "", 204
-@bp.route('//bans', methods=['GET'])
+@bp.route("//bans", methods=["GET"])
async def get_bans(guild_id):
user_id = await token_check()
- await guild_perm_check(user_id, guild_id, 'ban_members')
+ await guild_perm_check(user_id, guild_id, "ban_members")
- bans = await app.db.fetch("""
+ bans = await app.db.fetch(
+ """
SELECT user_id, reason
FROM bans
WHERE bans.guild_id = $1
- """, guild_id)
+ """,
+ guild_id,
+ )
res = []
for ban in bans:
- res.append({
- 'reason': ban['reason'],
- 'user': await app.storage.get_user(ban['user_id'])
- })
+ res.append(
+ {
+ "reason": ban["reason"],
+ "user": await app.storage.get_user(ban["user_id"]),
+ }
+ )
return jsonify(res)
-@bp.route('//bans/', methods=['PUT'])
+@bp.route("//bans/", methods=["PUT"])
async def create_ban(guild_id, member_id):
user_id = await token_check()
- await guild_perm_check(user_id, guild_id, 'ban_members')
+ await guild_perm_check(user_id, guild_id, "ban_members")
j = await request.get_json()
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO bans (guild_id, user_id, reason)
VALUES ($1, $2, $3)
- """, guild_id, member_id, j.get('reason', ''))
+ """,
+ guild_id,
+ member_id,
+ j.get("reason", ""),
+ )
await remove_member(guild_id, member_id)
- await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {
- 'guild_id': str(guild_id),
- 'user': await app.storage.get_user(member_id)
- })
+ await app.dispatcher.dispatch_guild(
+ guild_id,
+ "GUILD_BAN_ADD",
+ {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
+ )
- return '', 204
+ return "", 204
-@bp.route('//bans/', methods=['DELETE'])
+@bp.route("//bans/", methods=["DELETE"])
async def remove_ban(guild_id, banned_id):
user_id = await token_check()
- await guild_perm_check(user_id, guild_id, 'ban_members')
+ await guild_perm_check(user_id, guild_id, "ban_members")
- res = await app.db.execute("""
+ res = await app.db.execute(
+ """
DELETE FROM bans
WHERE guild_id = $1 AND user_id = $@
- """, guild_id, banned_id)
+ """,
+ guild_id,
+ banned_id,
+ )
# we don't really need to dispatch GUILD_BAN_REMOVE
# when no bans were actually removed.
- if res == 'DELETE 0':
- return '', 204
+ if res == "DELETE 0":
+ return "", 204
- await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', {
- 'guild_id': str(guild_id),
- 'user': await app.storage.get_user(banned_id)
- })
+ await app.dispatcher.dispatch_guild(
+ guild_id,
+ "GUILD_BAN_REMOVE",
+ {"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)},
+ )
- return '', 204
+ return "", 204
async def get_prune(guild_id: int, days: int) -> list:
@@ -146,23 +167,30 @@ async def get_prune(guild_id: int, days: int) -> list:
- don't have any roles.
"""
# a good solution would be in pure sql.
- member_ids = await app.db.fetch(f"""
+ member_ids = await app.db.fetch(
+ f"""
SELECT id
FROM users
JOIN members
ON members.guild_id = $1 AND members.user_id = users.id
WHERE users.last_session < (now() - (interval '{days} days'))
- """, guild_id)
+ """,
+ guild_id,
+ )
- member_ids = [r['id'] for r in member_ids]
+ member_ids = [r["id"] for r in member_ids]
members = []
for member_id in member_ids:
- role_count = await app.db.fetchval("""
+ role_count = await app.db.fetchval(
+ """
SELECT COUNT(*)
FROM member_roles
WHERE guild_id = $1 AND user_id = $2
- """, guild_id, member_id)
+ """,
+ guild_id,
+ member_id,
+ )
if role_count == 0:
members.append(member_id)
@@ -170,33 +198,29 @@ async def get_prune(guild_id: int, days: int) -> list:
return members
-@bp.route('//prune', methods=['GET'])
+@bp.route("//prune", methods=["GET"])
async def get_guild_prune_count(guild_id):
user_id = await token_check()
- await guild_perm_check(user_id, guild_id, 'kick_members')
+ await guild_perm_check(user_id, guild_id, "kick_members")
j = validate(request.args, GUILD_PRUNE)
- days = j['days']
+ days = j["days"]
member_ids = await get_prune(guild_id, days)
- return jsonify({
- 'pruned': len(member_ids),
- })
+ return jsonify({"pruned": len(member_ids)})
-@bp.route('//prune', methods=['POST'])
+@bp.route("//prune", methods=["POST"])
async def begin_guild_prune(guild_id):
user_id = await token_check()
- await guild_perm_check(user_id, guild_id, 'kick_members')
+ await guild_perm_check(user_id, guild_id, "kick_members")
j = validate(request.args, GUILD_PRUNE)
- days = j['days']
+ days = j["days"]
member_ids = await get_prune(guild_id, days)
app.loop.create_task(remove_member_multi(guild_id, member_ids))
- return jsonify({
- 'pruned': len(member_ids)
- })
+ return jsonify({"pruned": len(member_ids)})
diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py
index 5dfcbe7..9516aa4 100644
--- a/litecord/blueprints/guild/roles.py
+++ b/litecord/blueprints/guild/roles.py
@@ -24,12 +24,8 @@ from logbook import Logger
from litecord.auth import token_check
-from litecord.blueprints.checks import (
- guild_check, guild_perm_check
-)
-from litecord.schemas import (
- validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
-)
+from litecord.blueprints.checks import guild_check, guild_perm_check
+from litecord.schemas import validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
from litecord.snowflake import get_snowflake
from litecord.utils import dict_get
@@ -37,22 +33,19 @@ from litecord.permissions import get_role_perms
DEFAULT_EVERYONE_PERMS = 104324161
log = Logger(__name__)
-bp = Blueprint('guild_roles', __name__)
+bp = Blueprint("guild_roles", __name__)
-@bp.route('//roles', methods=['GET'])
+@bp.route("//roles", methods=["GET"])
async def get_guild_roles(guild_id):
"""Get all roles in a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
- return jsonify(
- await app.storage.get_role_data(guild_id)
- )
+ return jsonify(await app.storage.get_role_data(guild_id))
-async def _maybe_lg(guild_id: int, event: str,
- role, force: bool = False):
+async def _maybe_lg(guild_id: int, event: str, role, force: bool = False):
# sometimes we want to dispatch an event
# even if the role isn't hoisted
@@ -61,11 +54,10 @@ async def _maybe_lg(guild_id: int, event: str,
# check if is a dict first because role_delete
# only receives the role id.
- if isinstance(role, dict) and not role['hoist'] and not force:
+ if isinstance(role, dict) and not role["hoist"] and not force:
return
- await app.dispatcher.dispatch(
- 'lazy_guild', guild_id, event, role)
+ await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
async def create_role(guild_id, name: str, **kwargs):
@@ -73,18 +65,20 @@ async def create_role(guild_id, name: str, **kwargs):
new_role_id = get_snowflake()
everyone_perms = await get_role_perms(guild_id, guild_id)
- default_perms = dict_get(kwargs, 'default_perms',
- everyone_perms.binary)
+ default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
# update all roles so that we have space for pos 1, but without
# sending GUILD_ROLE_UPDATE for everyone
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE roles
SET
position = position + 1
WHERE guild_id = $1
AND NOT (position = 0)
- """, guild_id)
+ """,
+ guild_id,
+ )
await app.db.execute(
"""
@@ -95,42 +89,39 @@ async def create_role(guild_id, name: str, **kwargs):
new_role_id,
guild_id,
name,
- dict_get(kwargs, 'color', 0),
- dict_get(kwargs, 'hoist', False),
-
+ dict_get(kwargs, "color", 0),
+ dict_get(kwargs, "hoist", False),
# always set ourselves on position 1
1,
- int(dict_get(kwargs, 'permissions', default_perms)),
+ int(dict_get(kwargs, "permissions", default_perms)),
False,
- dict_get(kwargs, 'mentionable', False)
+ dict_get(kwargs, "mentionable", False),
)
role = await app.storage.get_role(new_role_id, guild_id)
# we need to update the lazy guild handlers for the newly created group
- await _maybe_lg(guild_id, 'new_role', role)
+ await _maybe_lg(guild_id, "new_role", role)
await app.dispatcher.dispatch_guild(
- guild_id, 'GUILD_ROLE_CREATE', {
- 'guild_id': str(guild_id),
- 'role': role,
- })
+ guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
+ )
return role
-@bp.route('//roles', methods=['POST'])
+@bp.route("//roles", methods=["POST"])
async def create_guild_role(guild_id: int):
"""Add a role to a guild"""
user_id = await token_check()
- await guild_perm_check(user_id, guild_id, 'manage_roles')
+ await guild_perm_check(user_id, guild_id, "manage_roles")
# client can just send null
j = validate(await request.get_json() or {}, ROLE_CREATE)
- role_name = j['name']
- j.pop('name')
+ role_name = j["name"]
+ j.pop("name")
role = await create_role(guild_id, role_name, **j)
@@ -141,12 +132,11 @@ async def _role_update_dispatch(role_id: int, guild_id: int):
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
role = await app.storage.get_role(role_id, guild_id)
- await _maybe_lg(guild_id, 'role_pos_upd', role)
+ await _maybe_lg(guild_id, "role_pos_upd", role)
- await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', {
- 'guild_id': str(guild_id),
- 'role': role,
- })
+ await app.dispatcher.dispatch_guild(
+ guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}
+ )
return role
@@ -166,17 +156,25 @@ async def _role_pairs_update(guild_id: int, pairs: list):
async with conn.transaction():
# update happens in a transaction
# so we don't fuck it up
- await conn.execute("""
+ await conn.execute(
+ """
UPDATE roles
SET position = $1
WHERE roles.id = $2
- """, new_pos_1, role_1)
+ """,
+ new_pos_1,
+ role_1,
+ )
- await conn.execute("""
+ await conn.execute(
+ """
UPDATE roles
SET position = $1
WHERE roles.id = $2
- """, new_pos_2, role_2)
+ """,
+ new_pos_2,
+ role_2,
+ )
await app.db.release(conn)
@@ -184,11 +182,15 @@ async def _role_pairs_update(guild_id: int, pairs: list):
await _role_update_dispatch(role_1, guild_id)
await _role_update_dispatch(role_2, guild_id)
+
PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
-def gen_pairs(list_of_changes: List[Dict[str, int]],
- current_state: Dict[int, int],
- blacklist: List[int] = None) -> PairList:
+
+def gen_pairs(
+ list_of_changes: List[Dict[str, int]],
+ current_state: Dict[int, int],
+ blacklist: List[int] = None,
+) -> PairList:
"""Generate a list of pairs that, when applied to the database,
will generate the desired state given in list_of_changes.
@@ -230,8 +232,9 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
pairs = []
blacklist = blacklist or []
- preferred_state = {element['id']: element['position']
- for element in list_of_changes}
+ preferred_state = {
+ element["id"]: element["position"] for element in list_of_changes
+ }
for blacklisted_id in blacklist:
preferred_state.pop(blacklisted_id)
@@ -239,7 +242,7 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
# for each change, we must find a matching change
# in the same list, so we can make a swap pair
for change in list_of_changes:
- element_1, new_pos_1 = change['id'], change['position']
+ element_1, new_pos_1 = change["id"], change["position"]
# check current pairs
# so we don't repeat an element
@@ -267,36 +270,34 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
# if its being swapped to leave space, add it
# to the pairs list
if new_pos_2 is not None:
- pairs.append(
- ((element_1, new_pos_1), (element_2, new_pos_2))
- )
+ pairs.append(((element_1, new_pos_1), (element_2, new_pos_2)))
return pairs
-@bp.route('//roles', methods=['PATCH'])
+@bp.route("//roles", methods=["PATCH"])
async def update_guild_role_positions(guild_id):
"""Update the positions for a bunch of roles."""
user_id = await token_check()
- await guild_perm_check(user_id, guild_id, 'manage_roles')
+ await guild_perm_check(user_id, guild_id, "manage_roles")
raw_j = await request.get_json()
# we need to do this hackiness because thats
# cerberus for ya.
- j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
+ j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
# extract the list out
- j = j['roles']
+ j = j["roles"]
- log.debug('role stuff: {!r}', j)
+ log.debug("role stuff: {!r}", j)
all_roles = await app.storage.get_role_data(guild_id)
# we'll have to calculate pairs of changing roles,
# then do the changes, etc.
- roles_pos = {role['position']: int(role['id']) for role in all_roles}
+ roles_pos = {role["position"]: int(role["id"]) for role in all_roles}
# TODO: check if the user can even change the roles in the first place,
# preferrably when we have a proper perms system.
@@ -306,10 +307,9 @@ async def update_guild_role_positions(guild_id):
pairs = gen_pairs(
j,
roles_pos,
-
# always ignore people trying to change
# the @everyone's role position
- [guild_id]
+ [guild_id],
)
await _role_pairs_update(guild_id, pairs)
@@ -318,31 +318,36 @@ async def update_guild_role_positions(guild_id):
return jsonify(await app.storage.get_role_data(guild_id))
-@bp.route('//roles/', methods=['PATCH'])
+@bp.route("//roles/", methods=["PATCH"])
async def update_guild_role(guild_id, role_id):
"""Update a single role's information."""
user_id = await token_check()
- await guild_perm_check(user_id, guild_id, 'manage_roles')
+ await guild_perm_check(user_id, guild_id, "manage_roles")
j = validate(await request.get_json(), ROLE_UPDATE)
# we only update ints on the db, not Permissions
- j['permissions'] = int(j['permissions'])
+ j["permissions"] = int(j["permissions"])
for field in j:
- await app.db.execute(f"""
+ await app.db.execute(
+ f"""
UPDATE roles
SET {field} = $1
WHERE roles.id = $2 AND roles.guild_id = $3
- """, j[field], role_id, guild_id)
+ """,
+ j[field],
+ role_id,
+ guild_id,
+ )
role = await _role_update_dispatch(role_id, guild_id)
- await _maybe_lg(guild_id, 'role_update', role, True)
+ await _maybe_lg(guild_id, "role_update", role, True)
return jsonify(role)
-@bp.route('//roles/', methods=['DELETE'])
+@bp.route("//roles/", methods=["DELETE"])
async def delete_guild_role(guild_id, role_id):
"""Delete a role.
@@ -350,21 +355,26 @@ async def delete_guild_role(guild_id, role_id):
"""
user_id = await token_check()
- await guild_perm_check(user_id, guild_id, 'manage_roles')
+ await guild_perm_check(user_id, guild_id, "manage_roles")
- res = await app.db.execute("""
+ res = await app.db.execute(
+ """
DELETE FROM roles
WHERE guild_id = $1 AND id = $2
- """, guild_id, role_id)
+ """,
+ guild_id,
+ role_id,
+ )
- if res == 'DELETE 0':
- return '', 204
+ if res == "DELETE 0":
+ return "", 204
- await _maybe_lg(guild_id, 'role_delete', role_id, True)
+ await _maybe_lg(guild_id, "role_delete", role_id, True)
- await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_DELETE', {
- 'guild_id': str(guild_id),
- 'role_id': str(role_id),
- })
+ await app.dispatcher.dispatch_guild(
+ guild_id,
+ "GUILD_ROLE_DELETE",
+ {"guild_id": str(guild_id), "role_id": str(role_id)},
+ )
- return '', 204
+ return "", 204
diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py
index 3c3c7dc..7db4224 100644
--- a/litecord/blueprints/guilds.py
+++ b/litecord/blueprints/guilds.py
@@ -22,16 +22,17 @@ from typing import Optional, List
from quart import Blueprint, request, current_app as app, jsonify
from litecord.blueprints.guild.channels import create_guild_channel
-from litecord.blueprints.guild.roles import (
- create_role, DEFAULT_EVERYONE_PERMS
-)
+from litecord.blueprints.guild.roles import create_role, DEFAULT_EVERYONE_PERMS
from ..auth import token_check
from ..snowflake import get_snowflake
from ..enums import ChannelType
from ..schemas import (
- validate, GUILD_CREATE, GUILD_UPDATE, SEARCH_CHANNEL,
- VANITY_URL_PATCH
+ validate,
+ GUILD_CREATE,
+ GUILD_UPDATE,
+ SEARCH_CHANNEL,
+ VANITY_URL_PATCH,
)
from .channels import channel_ack
from .checks import guild_check, guild_owner_check, guild_perm_check
@@ -40,7 +41,7 @@ from litecord.errors import BadRequest
from litecord.permissions import get_permissions
-bp = Blueprint('guilds', __name__)
+bp = Blueprint("guilds", __name__)
async def create_guild_settings(guild_id: int, user_id: int):
@@ -49,26 +50,38 @@ async def create_guild_settings(guild_id: int, user_id: int):
# new guild_settings are based off the currently
# set guild settings (for the guild)
- m_notifs = await app.db.fetchval("""
+ m_notifs = await app.db.fetchval(
+ """
SELECT default_message_notifications
FROM guilds
WHERE id = $1
- """, guild_id)
+ """,
+ guild_id,
+ )
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO guild_settings
(user_id, guild_id, message_notifications)
VALUES
($1, $2, $3)
- """, user_id, guild_id, m_notifs)
+ """,
+ user_id,
+ guild_id,
+ m_notifs,
+ )
async def add_member(guild_id: int, user_id: int):
"""Add a user to a guild."""
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO members (user_id, guild_id)
VALUES ($1, $2)
- """, user_id, guild_id)
+ """,
+ user_id,
+ guild_id,
+ )
await create_guild_settings(guild_id, user_id)
@@ -83,28 +96,28 @@ async def guild_create_roles_prep(guild_id: int, roles: list):
# are patches to the @everyone role
everyone_patches = roles[0]
for field in everyone_patches:
- await app.db.execute(f"""
+ await app.db.execute(
+ f"""
UPDATE roles
SET {field}={everyone_patches[field]}
WHERE roles.id = $1
- """, guild_id)
+ """,
+ guild_id,
+ )
- default_perms = (everyone_patches.get('permissions')
- or DEFAULT_EVERYONE_PERMS)
+ default_perms = everyone_patches.get("permissions") or DEFAULT_EVERYONE_PERMS
# from the 2nd and forward,
# should be treated as new roles
for role in roles[1:]:
- await create_role(
- guild_id, role['name'], default_perms=default_perms, **role
- )
+ await create_role(guild_id, role["name"], default_perms=default_perms, **role)
async def guild_create_channels_prep(guild_id: int, channels: list):
"""Create channels pre-guild create"""
for channel_raw in channels:
channel_id = get_snowflake()
- ctype = ChannelType(channel_raw['type'])
+ ctype = ChannelType(channel_raw["type"])
await create_guild_channel(guild_id, channel_id, ctype)
@@ -114,37 +127,29 @@ def sanitize_icon(icon: Optional[str]) -> Optional[str]:
Defaults to a jpeg icon when the header isn't given.
"""
- if icon and icon.startswith('data'):
+ if icon and icon.startswith("data"):
return icon
- return (f'data:image/jpeg;base64,{icon}'
- if icon
- else None)
+ return f"data:image/jpeg;base64,{icon}" if icon else None
-async def _general_guild_icon(scope: str, guild_id: int,
- icon: str, **kwargs):
+async def _general_guild_icon(scope: str, guild_id: int, icon: str, **kwargs):
encoded = sanitize_icon(icon)
- icon_kwargs = {
- 'always_icon': True
- }
+ icon_kwargs = {"always_icon": True}
- if 'size' in kwargs:
- icon_kwargs['size'] = kwargs['size']
+ if "size" in kwargs:
+ icon_kwargs["size"] = kwargs["size"]
- return await app.icons.put(
- scope, guild_id, encoded,
- **icon_kwargs
- )
+ return await app.icons.put(scope, guild_id, encoded, **icon_kwargs)
async def put_guild_icon(guild_id: int, icon: Optional[str]):
"""Insert a guild icon on the icon database."""
- return await _general_guild_icon('guild', guild_id, icon, size=(128, 128))
+ return await _general_guild_icon("guild", guild_id, icon, size=(128, 128))
-@bp.route('', methods=['POST'])
+@bp.route("", methods=["POST"])
async def create_guild():
"""Create a new guild, assigning
the user creating it as the owner and
@@ -154,8 +159,8 @@ async def create_guild():
guild_id = get_snowflake()
- if 'icon' in j:
- image = await put_guild_icon(guild_id, j['icon'])
+ if "icon" in j:
+ image = await put_guild_icon(guild_id, j["icon"])
image = image.icon_hash
else:
image = None
@@ -166,10 +171,16 @@ async def create_guild():
verification_level, default_message_notifications,
explicit_content_filter)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
- """, guild_id, j['name'], j['region'], image, user_id,
- j.get('verification_level', 0),
- j.get('default_message_notifications', 0),
- j.get('explicit_content_filter', 0))
+ """,
+ guild_id,
+ j["name"],
+ j["region"],
+ image,
+ user_id,
+ j.get("verification_level", 0),
+ j.get("default_message_notifications", 0),
+ j.get("explicit_content_filter", 0),
+ )
await add_member(guild_id, user_id)
@@ -179,107 +190,127 @@ async def create_guild():
# we also don't use create_role because the id of the role
# is the same as the id of the guild, and create_role
# generates a new snowflake.
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO roles (id, guild_id, name, position, permissions)
VALUES ($1, $2, $3, $4, $5)
- """, guild_id, guild_id, '@everyone', 0, DEFAULT_EVERYONE_PERMS)
+ """,
+ guild_id,
+ guild_id,
+ "@everyone",
+ 0,
+ DEFAULT_EVERYONE_PERMS,
+ )
# add the @everyone role to the guild creator
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO member_roles (user_id, guild_id, role_id)
VALUES ($1, $2, $3)
- """, user_id, guild_id, guild_id)
+ """,
+ user_id,
+ guild_id,
+ guild_id,
+ )
# create a single #general channel.
general_id = get_snowflake()
await create_guild_channel(
- guild_id, general_id, ChannelType.GUILD_TEXT,
- name='general')
+ guild_id, general_id, ChannelType.GUILD_TEXT, name="general"
+ )
- if j.get('roles'):
- await guild_create_roles_prep(guild_id, j['roles'])
+ if j.get("roles"):
+ await guild_create_roles_prep(guild_id, j["roles"])
- if j.get('channels'):
- await guild_create_channels_prep(guild_id, j['channels'])
+ if j.get("channels"):
+ await guild_create_channels_prep(guild_id, j["channels"])
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
- await app.dispatcher.sub('guild', guild_id, user_id)
- await app.dispatcher.dispatch_guild(guild_id, 'GUILD_CREATE', guild_total)
+ await app.dispatcher.sub("guild", guild_id, user_id)
+ await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild_total)
return jsonify(guild_total)
-@bp.route('/', methods=['GET'])
+@bp.route("/", methods=["GET"])
async def get_guild(guild_id):
"""Get a single guilds' information."""
user_id = await token_check()
await guild_check(user_id, guild_id)
- return jsonify(
- await app.storage.get_guild_full(guild_id, user_id, 250)
- )
+ return jsonify(await app.storage.get_guild_full(guild_id, user_id, 250))
-async def _guild_update_icon(scope: str, guild_id: int,
- icon: Optional[str], **kwargs):
+async def _guild_update_icon(scope: str, guild_id: int, icon: Optional[str], **kwargs):
"""Update icon."""
- new_icon = await app.icons.update(
- scope, guild_id, icon, always_icon=True, **kwargs
- )
+ new_icon = await app.icons.update(scope, guild_id, icon, always_icon=True, **kwargs)
- table = {
- 'guild': 'icon',
- }.get(scope, scope)
+ table = {"guild": "icon"}.get(scope, scope)
- await app.db.execute(f"""
+ await app.db.execute(
+ f"""
UPDATE guilds
SET {table} = $1
WHERE id = $2
- """, new_icon.icon_hash, guild_id)
+ """,
+ new_icon.icon_hash,
+ guild_id,
+ )
async def _guild_update_region(guild_id, region):
is_vip = region.vip
- can_vip = await app.storage.has_feature(guild_id, 'VIP_REGIONS')
+ can_vip = await app.storage.has_feature(guild_id, "VIP_REGIONS")
if is_vip and not can_vip:
- raise BadRequest('can not assign guild to vip-only region')
+ raise BadRequest("can not assign guild to vip-only region")
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE guilds
SET region = $1
WHERE id = $2
- """, region.id, guild_id)
+ """,
+ region.id,
+ guild_id,
+ )
-
-@bp.route('/', methods=['PATCH'])
+@bp.route("/", methods=["PATCH"])
async def _update_guild(guild_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
- await guild_perm_check(user_id, guild_id, 'manage_guild')
+ await guild_perm_check(user_id, guild_id, "manage_guild")
j = validate(await request.get_json(), GUILD_UPDATE)
- if 'owner_id' in j:
+ if "owner_id" in j:
await guild_owner_check(user_id, guild_id)
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE guilds
SET owner_id = $1
WHERE id = $2
- """, int(j['owner_id']), guild_id)
+ """,
+ int(j["owner_id"]),
+ guild_id,
+ )
- if 'name' in j:
- await app.db.execute("""
+ if "name" in j:
+ await app.db.execute(
+ """
UPDATE guilds
SET name = $1
WHERE id = $2
- """, j['name'], guild_id)
+ """,
+ j["name"],
+ guild_id,
+ )
- if 'region' in j:
- region = app.voice.lvsp.region(j['region'])
+ if "region" in j:
+ region = app.voice.lvsp.region(j["region"])
if region is not None:
await _guild_update_region(guild_id, region)
@@ -287,65 +318,77 @@ async def _update_guild(guild_id):
# small guild to work with to_update()
guild = await app.storage.get_guild(guild_id)
- if to_update(j, guild, 'icon'):
- await _guild_update_icon(
- 'guild', guild_id, j['icon'], size=(128, 128))
+ if to_update(j, guild, "icon"):
+ await _guild_update_icon("guild", guild_id, j["icon"], size=(128, 128))
- if to_update(j, guild, 'splash'):
- if not await app.storage.has_feature(guild_id, 'INVITE_SPLASH'):
- raise BadRequest('guild does not have INVITE_SPLASH feature')
+ if to_update(j, guild, "splash"):
+ if not await app.storage.has_feature(guild_id, "INVITE_SPLASH"):
+ raise BadRequest("guild does not have INVITE_SPLASH feature")
- await _guild_update_icon('splash', guild_id, j['splash'])
+ await _guild_update_icon("splash", guild_id, j["splash"])
- if to_update(j, guild, 'banner'):
- if not await app.storage.has_feature(guild_id, 'VERIFIED'):
- raise BadRequest('guild is not verified')
+ if to_update(j, guild, "banner"):
+ if not await app.storage.has_feature(guild_id, "VERIFIED"):
+ raise BadRequest("guild is not verified")
- await _guild_update_icon('banner', guild_id, j['banner'])
+ await _guild_update_icon("banner", guild_id, j["banner"])
- fields = ['verification_level', 'default_message_notifications',
- 'explicit_content_filter', 'afk_timeout', 'description']
+ fields = [
+ "verification_level",
+ "default_message_notifications",
+ "explicit_content_filter",
+ "afk_timeout",
+ "description",
+ ]
for field in [f for f in fields if f in j]:
- await app.db.execute(f"""
+ await app.db.execute(
+ f"""
UPDATE guilds
SET {field} = $1
WHERE id = $2
- """, j[field], guild_id)
+ """,
+ j[field],
+ guild_id,
+ )
- channel_fields = ['afk_channel_id', 'system_channel_id']
+ channel_fields = ["afk_channel_id", "system_channel_id"]
for field in [f for f in channel_fields if f in j]:
# setting to null should remove the link between the afk/sys channel
# to the guild.
if j[field] is None:
- await app.db.execute(f"""
+ await app.db.execute(
+ f"""
UPDATE guilds
SET {field} = NULL
WHERE id = $1
- """, guild_id)
+ """,
+ guild_id,
+ )
continue
chan = await app.storage.get_channel(int(j[field]))
if chan is None:
- raise BadRequest('invalid channel id')
+ raise BadRequest("invalid channel id")
- if chan['guild_id'] != str(guild_id):
- raise BadRequest('channel id not linked to guild')
+ if chan["guild_id"] != str(guild_id):
+ raise BadRequest("channel id not linked to guild")
- await app.db.execute(f"""
+ await app.db.execute(
+ f"""
UPDATE guilds
SET {field} = $1
WHERE id = $2
- """, j[field], guild_id)
+ """,
+ j[field],
+ guild_id,
+ )
- guild = await app.storage.get_guild_full(
- guild_id, user_id
- )
+ guild = await app.storage.get_guild_full(guild_id, user_id)
- await app.dispatcher.dispatch_guild(
- guild_id, 'GUILD_UPDATE', guild)
+ await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
return jsonify(guild)
@@ -354,33 +397,41 @@ async def delete_guild(guild_id: int, *, app_=None):
"""Delete a single guild."""
app_ = app_ or app
- await app_.db.execute("""
+ await app_.db.execute(
+ """
DELETE FROM guilds
WHERE guilds.id = $1
- """, guild_id)
+ """,
+ guild_id,
+ )
# Discord's client expects IDs being string
- await app_.dispatcher.dispatch('guild', guild_id, 'GUILD_DELETE', {
- 'guild_id': str(guild_id),
- 'id': str(guild_id),
- # 'unavailable': False,
- })
+ await app_.dispatcher.dispatch(
+ "guild",
+ guild_id,
+ "GUILD_DELETE",
+ {
+ "guild_id": str(guild_id),
+ "id": str(guild_id),
+ # 'unavailable': False,
+ },
+ )
# remove from the dispatcher so nobody
# becomes the little memer that tries to fuck up with
# everybody's gateway
- await app_.dispatcher.remove('guild', guild_id)
+ await app_.dispatcher.remove("guild", guild_id)
-@bp.route('/', methods=['DELETE'])
+@bp.route("/", methods=["DELETE"])
# this endpoint is not documented, but used by the official client.
-@bp.route('//delete', methods=['POST'])
+@bp.route("//delete", methods=["POST"])
async def delete_guild_handler(guild_id):
"""Delete a guild."""
user_id = await token_check()
await guild_owner_check(user_id, guild_id)
await delete_guild(guild_id)
- return '', 204
+ return "", 204
async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]:
@@ -397,7 +448,7 @@ async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]:
return res
-@bp.route('//messages/search', methods=['GET'])
+@bp.route("//messages/search", methods=["GET"])
async def search_messages(guild_id):
"""Search messages in a guild.
@@ -415,7 +466,8 @@ async def search_messages(guild_id):
# use that list on the main search query.
can_read = await fetch_readable_channels(guild_id, user_id)
- rows = await app.db.fetch(f"""
+ rows = await app.db.fetch(
+ f"""
SELECT orig.id AS current_id,
COUNT(*) OVER() as total_results,
array((SELECT messages.id AS before_id
@@ -432,12 +484,17 @@ async def search_messages(guild_id):
ORDER BY orig.id DESC
LIMIT 50
OFFSET $3
- """, guild_id, j['content'], j['offset'], can_read)
+ """,
+ guild_id,
+ j["content"],
+ j["offset"],
+ can_read,
+ )
return jsonify(await search_result_from_list(rows))
-@bp.route('//ack', methods=['POST'])
+@bp.route("//ack", methods=["POST"])
async def ack_guild(guild_id):
"""ACKnowledge all messages in the guild."""
user_id = await token_check()
@@ -448,45 +505,43 @@ async def ack_guild(guild_id):
for chan_id in chan_ids:
await channel_ack(user_id, guild_id, chan_id)
- return '', 204
+ return "", 204
-@bp.route('//vanity-url', methods=['GET'])
+@bp.route("//vanity-url", methods=["GET"])
async def get_vanity_url(guild_id: int):
"""Get the vanity url of a guild."""
user_id = await token_check()
- await guild_perm_check(user_id, guild_id, 'manage_guild')
+ await guild_perm_check(user_id, guild_id, "manage_guild")
inv_code = await app.storage.vanity_invite(guild_id)
if inv_code is None:
- return jsonify({'code': None})
+ return jsonify({"code": None})
- return jsonify(
- await app.storage.get_invite(inv_code)
- )
+ return jsonify(await app.storage.get_invite(inv_code))
-@bp.route('//vanity-url', methods=['PATCH'])
+@bp.route("//vanity-url", methods=["PATCH"])
async def change_vanity_url(guild_id: int):
"""Get the vanity url of a guild."""
user_id = await token_check()
- if not await app.storage.has_feature(guild_id, 'VANITY_URL'):
+ if not await app.storage.has_feature(guild_id, "VANITY_URL"):
# TODO: is this the right error
- raise BadRequest('guild has no vanity url support')
+ raise BadRequest("guild has no vanity url support")
- await guild_perm_check(user_id, guild_id, 'manage_guild')
+ await guild_perm_check(user_id, guild_id, "manage_guild")
j = validate(await request.get_json(), VANITY_URL_PATCH)
- inv_code = j['code']
+ inv_code = j["code"]
# store old vanity in a variable to delete it from
# invites table
old_vanity = await app.storage.vanity_invite(guild_id)
if old_vanity == inv_code:
- raise BadRequest('can not change to same invite')
+ raise BadRequest("can not change to same invite")
# this is sad because we don't really use the things
# sql gives us, but i havent really found a way to put
@@ -494,19 +549,22 @@ async def change_vanity_url(guild_id: int):
# guild_id_fkey fails but INSERT when code_fkey fails..
inv = await app.storage.get_invite(inv_code)
if inv:
- raise BadRequest('invite already exists')
+ raise BadRequest("invite already exists")
# TODO: this is bad, what if a guild has no channels?
# we should probably choose the first channel that has
# @everyone read messages
channels = await app.storage.get_channel_data(guild_id)
- channel_id = int(channels[0]['id'])
+ channel_id = int(channels[0]["id"])
# delete the old invite, insert new one
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM invites
WHERE code = $1
- """, old_vanity)
+ """,
+ old_vanity,
+ )
await app.db.execute(
"""
@@ -515,21 +573,27 @@ async def change_vanity_url(guild_id: int):
max_age, temporary)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""",
- inv_code, guild_id, channel_id, user_id,
-
+ inv_code,
+ guild_id,
+ channel_id,
+ user_id,
# sane defaults for vanity urls.
- 0, 0, False,
+ 0,
+ 0,
+ False,
)
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO vanity_invites (guild_id, code)
VALUES ($1, $2)
ON CONFLICT ON CONSTRAINT vanity_invites_pkey DO
UPDATE
SET code = $2
WHERE vanity_invites.guild_id = $1
- """, guild_id, inv_code)
-
- return jsonify(
- await app.storage.get_invite(inv_code)
+ """,
+ guild_id,
+ inv_code,
)
+
+ return jsonify(await app.storage.get_invite(inv_code))
diff --git a/litecord/blueprints/icons.py b/litecord/blueprints/icons.py
index b01bac7..9b0f378 100644
--- a/litecord/blueprints/icons.py
+++ b/litecord/blueprints/icons.py
@@ -24,41 +24,39 @@ from quart import Blueprint, current_app as app, send_file, redirect
from litecord.embed.sanitizer import make_md_req_url
from litecord.embed.schemas import EmbedURL
-bp = Blueprint('images', __name__)
+bp = Blueprint("images", __name__)
async def send_icon(scope, key, icon_hash, **kwargs):
"""Send an icon."""
- icon = await app.icons.generic_get(
- scope, key, icon_hash, **kwargs)
+ icon = await app.icons.generic_get(scope, key, icon_hash, **kwargs)
if not icon:
- return '', 404
+ return "", 404
return await send_file(icon.as_path)
def splitext_(filepath):
name, ext = splitext(filepath)
- return name, ext.strip('.')
+ return name, ext.strip(".")
-@bp.route('/emojis/', methods=['GET'])
+@bp.route("/emojis/", methods=["GET"])
async def _get_raw_emoji(emoji_file):
# emoji = app.icons.get_emoji(emoji_id, ext=ext)
# just a test file for now
emoji_id, ext = splitext_(emoji_file)
- return await send_icon(
- 'emoji', emoji_id, None, ext=ext)
+ return await send_icon("emoji", emoji_id, None, ext=ext)
-@bp.route('/icons//', methods=['GET'])
+@bp.route("/icons//", methods=["GET"])
async def _get_guild_icon(guild_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file)
- return await send_icon('guild', guild_id, icon_hash, ext=ext)
+ return await send_icon("guild", guild_id, icon_hash, ext=ext)
-@bp.route('/embed/avatars/.png')
+@bp.route("/embed/avatars/.png")
async def _get_default_user_avatar(default_id: int):
# TODO: how do we determine which assets to use for this?
# I don't think we can use discord assets.
@@ -66,25 +64,29 @@ async def _get_default_user_avatar(default_id: int):
async def _handle_webhook_avatar(md_url_redir: str):
- md_url = make_md_req_url(app.config, 'img', EmbedURL(md_url_redir))
+ md_url = make_md_req_url(app.config, "img", EmbedURL(md_url_redir))
return redirect(md_url)
-@bp.route('/avatars//')
+@bp.route("/avatars//")
async def _get_user_avatar(user_id, avatar_file):
avatar_hash, ext = splitext_(avatar_file)
# first, check if this is a webhook avatar to redir to
- md_url_redir = await app.db.fetchval("""
+ md_url_redir = await app.db.fetchval(
+ """
SELECT md_url_redir
FROM webhook_avatars
WHERE webhook_id = $1 AND hash = $2
- """, user_id, avatar_hash)
+ """,
+ user_id,
+ avatar_hash,
+ )
if md_url_redir:
return await _handle_webhook_avatar(md_url_redir)
- return await send_icon('user', user_id, avatar_hash, ext=ext)
+ return await send_icon("user", user_id, avatar_hash, ext=ext)
# @bp.route('/app-icons//.')
@@ -92,19 +94,19 @@ async def get_app_icon(application_id, icon_hash, ext):
pass
-@bp.route('/channel-icons//', methods=['GET'])
+@bp.route("/channel-icons//", methods=["GET"])
async def _get_gdm_icon(channel_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file)
- return await send_icon('channel-icons', channel_id, icon_hash, ext=ext)
+ return await send_icon("channel-icons", channel_id, icon_hash, ext=ext)
-@bp.route('/splashes//', methods=['GET'])
+@bp.route("/splashes//", methods=["GET"])
async def _get_guild_splash(guild_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file)
- return await send_icon('splash', guild_id, icon_hash, ext=ext)
+ return await send_icon("splash", guild_id, icon_hash, ext=ext)
-@bp.route('/banners//', methods=['GET'])
+@bp.route("/banners//", methods=["GET"])
async def _get_guild_banner(guild_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file)
- return await send_icon('banner', guild_id, icon_hash, ext=ext)
+ return await send_icon("banner", guild_id, icon_hash, ext=ext)
diff --git a/litecord/blueprints/invites.py b/litecord/blueprints/invites.py
index 1b3506a..02b7610 100644
--- a/litecord/blueprints/invites.py
+++ b/litecord/blueprints/invites.py
@@ -32,13 +32,16 @@ from .guilds import create_guild_settings
from ..utils import async_map
from litecord.blueprints.checks import (
- channel_check, channel_perm_check, guild_check, guild_perm_check
+ channel_check,
+ channel_perm_check,
+ guild_check,
+ guild_perm_check,
)
from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient
log = Logger(__name__)
-bp = Blueprint('invites', __name__)
+bp = Blueprint("invites", __name__)
class UnknownInvite(BadRequest):
@@ -48,16 +51,18 @@ class UnknownInvite(BadRequest):
class InvalidInvite(Forbidden):
error_code = 50020
+
class AlreadyInvited(BaseException):
pass
+
def gen_inv_code() -> str:
"""Generate an invite code.
This is a primitive and does not guarantee uniqueness.
"""
raw = secrets.token_urlsafe(10)
- raw = re.sub(r'\/|\+|\-|\_', '', raw)
+ raw = re.sub(r"\/|\+|\-|\_", "", raw)
return raw[:7]
@@ -65,23 +70,31 @@ def gen_inv_code() -> str:
async def invite_precheck(user_id: int, guild_id: int):
"""pre-check invite use in the context of a guild."""
- joined = await app.db.fetchval("""
+ joined = await app.db.fetchval(
+ """
SELECT joined_at
FROM members
WHERE user_id = $1 AND guild_id = $2
- """, user_id, guild_id)
+ """,
+ user_id,
+ guild_id,
+ )
if joined is not None:
- raise AlreadyInvited('You are already in the guild')
+ raise AlreadyInvited("You are already in the guild")
- banned = await app.db.fetchval("""
+ banned = await app.db.fetchval(
+ """
SELECT reason
FROM bans
WHERE user_id = $1 AND guild_id = $2
- """, user_id, guild_id)
+ """,
+ user_id,
+ guild_id,
+ )
if banned is not None:
- raise InvalidInvite('You are banned.')
+ raise InvalidInvite("You are banned.")
async def invite_precheck_gdm(user_id: int, channel_id: int):
@@ -89,23 +102,23 @@ async def invite_precheck_gdm(user_id: int, channel_id: int):
is_member = await gdm_is_member(channel_id, user_id)
if is_member:
- raise AlreadyInvited('You are already in the Group DM')
+ raise AlreadyInvited("You are already in the Group DM")
async def _inv_check_age(inv: dict):
- if inv['max_age'] == 0:
+ if inv["max_age"] == 0:
return
now = datetime.datetime.utcnow()
- delta_sec = (now - inv['created_at']).total_seconds()
+ delta_sec = (now - inv["created_at"]).total_seconds()
- if delta_sec > inv['max_age']:
- await delete_invite(inv['code'])
- raise InvalidInvite('Invite is expired')
+ if delta_sec > inv["max_age"]:
+ await delete_invite(inv["code"])
+ raise InvalidInvite("Invite is expired")
- if inv['max_uses'] is not -1 and inv['uses'] > inv['max_uses']:
- await delete_invite(inv['code'])
- raise InvalidInvite('Too many uses')
+ if inv["max_uses"] is not -1 and inv["uses"] > inv["max_uses"]:
+ await delete_invite(inv["code"])
+ raise InvalidInvite("Too many uses")
async def _guild_add_member(guild_id: int, user_id: int):
@@ -119,78 +132,89 @@ async def _guild_add_member(guild_id: int, user_id: int):
"""
# TODO: system message for member join
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO members (user_id, guild_id)
VALUES ($1, $2)
- """, user_id, guild_id)
+ """,
+ user_id,
+ guild_id,
+ )
await create_guild_settings(guild_id, user_id)
# add the @everyone role to the invited member
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO member_roles (user_id, guild_id, role_id)
VALUES ($1, $2, $3)
- """, user_id, guild_id, guild_id)
+ """,
+ user_id,
+ guild_id,
+ guild_id,
+ )
# tell current members a new member came up
member = await app.storage.get_member_data_one(guild_id, user_id)
- await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_ADD', {
- **member,
- **{
- 'guild_id': str(guild_id),
- },
- })
+ await app.dispatcher.dispatch_guild(
+ guild_id, "GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}}
+ )
# update member lists for the new member
- await app.dispatcher.dispatch(
- 'lazy_guild', guild_id, 'new_member', user_id)
+ await app.dispatcher.dispatch("lazy_guild", guild_id, "new_member", user_id)
# subscribe new member to guild, so they get events n stuff
- await app.dispatcher.sub('guild', guild_id, user_id)
+ await app.dispatcher.sub("guild", guild_id, user_id)
# tell the new member that theres the guild it just joined.
# we use dispatch_user_guild so that we send the GUILD_CREATE
# just to the shards that are actually tied to it.
guild = await app.storage.get_guild_full(guild_id, user_id, 250)
- await app.dispatcher.dispatch_user_guild(
- user_id, guild_id, 'GUILD_CREATE', guild)
+ await app.dispatcher.dispatch_user_guild(user_id, guild_id, "GUILD_CREATE", guild)
async def use_invite(user_id, invite_code):
"""Try using an invite"""
- inv = await app.db.fetchrow("""
+ inv = await app.db.fetchrow(
+ """
SELECT code, channel_id, guild_id, created_at,
max_age, uses, max_uses
FROM invites
WHERE code = $1
- """, invite_code)
+ """,
+ invite_code,
+ )
if inv is None:
- raise UnknownInvite('Unknown invite')
+ raise UnknownInvite("Unknown invite")
await _inv_check_age(inv)
# NOTE: if group dm invite, guild_id is null.
- guild_id = inv['guild_id']
+ guild_id = inv["guild_id"]
- try:
+ try:
if guild_id is None:
- channel_id = inv['channel_id']
- await invite_precheck_gdm(user_id, inv['channel_id'])
+ channel_id = inv["channel_id"]
+ await invite_precheck_gdm(user_id, inv["channel_id"])
await gdm_add_recipient(channel_id, user_id)
else:
await invite_precheck(user_id, guild_id)
await _guild_add_member(guild_id, user_id)
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE invites
SET uses = uses + 1
WHERE code = $1
- """, invite_code)
+ """,
+ invite_code,
+ )
except AlreadyInvited:
pass
-@bp.route('/channels//invites', methods=['POST'])
+
+@bp.route("/channels//invites", methods=["POST"])
async def create_invite(channel_id):
"""Create an invite to a channel."""
user_id = await token_check()
@@ -201,12 +225,14 @@ async def create_invite(channel_id):
# NOTE: this works on group dms, since it returns ALL_PERMISSIONS on
# non-guild channels.
- await channel_perm_check(user_id, channel_id, 'create_invites')
+ await channel_perm_check(user_id, channel_id, "create_invites")
- if chantype not in (ChannelType.GUILD_TEXT,
- ChannelType.GUILD_VOICE,
- ChannelType.GROUP_DM):
- raise BadRequest('Invalid channel type')
+ if chantype not in (
+ ChannelType.GUILD_TEXT,
+ ChannelType.GUILD_VOICE,
+ ChannelType.GROUP_DM,
+ ):
+ raise BadRequest("Invalid channel type")
invite_code = gen_inv_code()
@@ -222,101 +248,122 @@ async def create_invite(channel_id):
max_age, temporary)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""",
- invite_code, guild_id, channel_id, user_id,
- j['max_uses'], j['max_age'], j['temporary']
+ invite_code,
+ guild_id,
+ channel_id,
+ user_id,
+ j["max_uses"],
+ j["max_age"],
+ j["temporary"],
)
invite = await app.storage.get_invite(invite_code)
return jsonify(invite)
-@bp.route('/invite/', methods=['GET'])
-@bp.route('/invites/', methods=['GET'])
+@bp.route("/invite/", methods=["GET"])
+@bp.route("/invites/", methods=["GET"])
async def get_invite(invite_code: str):
inv = await app.storage.get_invite(invite_code)
if not inv:
- return '', 404
+ return "", 404
- if request.args.get('with_counts'):
+ if request.args.get("with_counts"):
extra = await app.storage.get_invite_extra(invite_code)
inv.update(extra)
return jsonify(inv)
+
async def delete_invite(invite_code: str):
"""Delete an invite."""
- await app.db.fetchval("""
+ await app.db.fetchval(
+ """
DELETE FROM invites
WHERE code = $1
- """, invite_code)
+ """,
+ invite_code,
+ )
-@bp.route('/invite/', methods=['DELETE'])
-@bp.route('/invites/', methods=['DELETE'])
+
+@bp.route("/invite/", methods=["DELETE"])
+@bp.route("/invites/", methods=["DELETE"])
async def _delete_invite(invite_code: str):
user_id = await token_check()
- guild_id = await app.db.fetchval("""
+ guild_id = await app.db.fetchval(
+ """
SELECT guild_id
FROM invites
WHERE code = $1
- """, invite_code)
+ """,
+ invite_code,
+ )
if guild_id is None:
- raise BadRequest('Unknown invite')
+ raise BadRequest("Unknown invite")
- await guild_perm_check(user_id, guild_id, 'manage_channels')
+ await guild_perm_check(user_id, guild_id, "manage_channels")
inv = await app.storage.get_invite(invite_code)
await delete_invite(invite_code)
return jsonify(inv)
+
async def _get_inv(code):
inv = await app.storage.get_invite(code)
meta = await app.storage.get_invite_metadata(code)
return {**inv, **meta}
-@bp.route('/guilds//invites', methods=['GET'])
+@bp.route("/guilds//invites", methods=["GET"])
async def get_guild_invites(guild_id: int):
"""Get all invites for a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
- await guild_perm_check(user_id, guild_id, 'manage_guild')
+ await guild_perm_check(user_id, guild_id, "manage_guild")
- inv_codes = await app.db.fetch("""
+ inv_codes = await app.db.fetch(
+ """
SELECT code
FROM invites
WHERE guild_id = $1
- """, guild_id)
+ """,
+ guild_id,
+ )
- inv_codes = [r['code'] for r in inv_codes]
+ inv_codes = [r["code"] for r in inv_codes]
invs = await async_map(_get_inv, inv_codes)
return jsonify(invs)
-@bp.route('/channels//invites', methods=['GET'])
+@bp.route("/channels//invites", methods=["GET"])
async def get_channel_invites(channel_id: int):
"""Get all invites for a channel."""
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
- await guild_perm_check(user_id, guild_id, 'manage_channels')
+ await guild_perm_check(user_id, guild_id, "manage_channels")
- inv_codes = await app.db.fetch("""
+ inv_codes = await app.db.fetch(
+ """
SELECT code
FROM invites
WHERE guild_id = $1 AND channel_id = $2
- """, guild_id, channel_id)
+ """,
+ guild_id,
+ channel_id,
+ )
- inv_codes = [r['code'] for r in inv_codes]
+ inv_codes = [r["code"] for r in inv_codes]
invs = await async_map(_get_inv, inv_codes)
return jsonify(invs)
-@bp.route('/invite/', methods=['POST'])
-@bp.route('/invites/', methods=['POST'])
+@bp.route("/invite/", methods=["POST"])
+@bp.route("/invites/", methods=["POST"])
async def _use_invite(invite_code):
"""Use an invite."""
user_id = await token_check()
@@ -327,9 +374,4 @@ async def _use_invite(invite_code):
inv = await app.storage.get_invite(invite_code)
inv_meta = await app.storage.get_invite_metadata(invite_code)
- return jsonify({
- **inv,
- **{
- 'inviter': inv_meta['inviter']
- }
- })
+ return jsonify({**inv, **{"inviter": inv_meta["inviter"]}})
diff --git a/litecord/blueprints/nodeinfo.py b/litecord/blueprints/nodeinfo.py
index 75ecf6f..7afefb2 100644
--- a/litecord/blueprints/nodeinfo.py
+++ b/litecord/blueprints/nodeinfo.py
@@ -19,83 +19,75 @@ along with this program. If not, see .
from quart import Blueprint, current_app as app, jsonify, request
-bp = Blueprint('nodeinfo', __name__)
+bp = Blueprint("nodeinfo", __name__)
-@bp.route('/.well-known/nodeinfo')
+@bp.route("/.well-known/nodeinfo")
async def _dummy_nodeinfo_index():
- proto = 'http' if not app.config['IS_SSL'] else 'https'
- main_url = app.config.get('MAIN_URL', request.host)
+ proto = "http" if not app.config["IS_SSL"] else "https"
+ main_url = app.config.get("MAIN_URL", request.host)
- return jsonify({
- 'links': [{
- 'href': f'{proto}://{main_url}/nodeinfo/2.0.json',
- 'rel': 'http://nodeinfo.diaspora.software/ns/schema/2.0'
- }, {
- 'href': f'{proto}://{main_url}/nodeinfo/2.1.json',
- 'rel': 'http://nodeinfo.diaspora.software/ns/schema/2.1'
- }]
- })
+ return jsonify(
+ {
+ "links": [
+ {
+ "href": f"{proto}://{main_url}/nodeinfo/2.0.json",
+ "rel": "http://nodeinfo.diaspora.software/ns/schema/2.0",
+ },
+ {
+ "href": f"{proto}://{main_url}/nodeinfo/2.1.json",
+ "rel": "http://nodeinfo.diaspora.software/ns/schema/2.1",
+ },
+ ]
+ }
+ )
async def fetch_nodeinfo_20():
- usercount = await app.db.fetchval("""
+ usercount = await app.db.fetchval(
+ """
SELECT COUNT(*)
FROM users
- """)
+ """
+ )
- message_count = await app.db.fetchval("""
+ message_count = await app.db.fetchval(
+ """
SELECT COUNT(*)
FROM messages
- """)
+ """
+ )
return {
- 'metadata': {
- 'features': [
- 'discord_api'
- ],
-
- 'nodeDescription': 'A Litecord instance',
- 'nodeName': 'Litecord/Nya',
- 'private': False,
-
- 'federation': {}
+ "metadata": {
+ "features": ["discord_api"],
+ "nodeDescription": "A Litecord instance",
+ "nodeName": "Litecord/Nya",
+ "private": False,
+ "federation": {},
},
- 'openRegistrations': app.config['REGISTRATIONS'],
- 'protocols': [],
- 'software': {
- 'name': 'litecord',
- 'version': 'litecord v0',
- },
-
- 'services': {
- 'inbound': [],
- 'outbound': [],
- },
-
- 'usage': {
- 'localPosts': message_count,
- 'users': {
- 'total': usercount
- }
- },
- 'version': '2.0',
+ "openRegistrations": app.config["REGISTRATIONS"],
+ "protocols": [],
+ "software": {"name": "litecord", "version": "litecord v0"},
+ "services": {"inbound": [], "outbound": []},
+ "usage": {"localPosts": message_count, "users": {"total": usercount}},
+ "version": "2.0",
}
-@bp.route('/nodeinfo/2.0.json')
+@bp.route("/nodeinfo/2.0.json")
async def _nodeinfo_20():
"""Handler for nodeinfo 2.0."""
raw_nodeinfo = await fetch_nodeinfo_20()
return jsonify(raw_nodeinfo)
-@bp.route('/nodeinfo/2.1.json')
+@bp.route("/nodeinfo/2.1.json")
async def _nodeinfo_21():
"""Handler for nodeinfo 2.1."""
raw_nodeinfo = await fetch_nodeinfo_20()
- raw_nodeinfo['software']['repository'] = 'https://gitlab.com/litecord/litecord'
- raw_nodeinfo['version'] = '2.1'
+ raw_nodeinfo["software"]["repository"] = "https://gitlab.com/litecord/litecord"
+ raw_nodeinfo["version"] = "2.1"
return jsonify(raw_nodeinfo)
diff --git a/litecord/blueprints/relationships.py b/litecord/blueprints/relationships.py
index d93e90b..1fa9d01 100644
--- a/litecord/blueprints/relationships.py
+++ b/litecord/blueprints/relationships.py
@@ -26,76 +26,89 @@ from ..enums import RelationshipType
from litecord.errors import BadRequest
-bp = Blueprint('relationship', __name__)
+bp = Blueprint("relationship", __name__)
-@bp.route('/@me/relationships', methods=['GET'])
+@bp.route("/@me/relationships", methods=["GET"])
async def get_me_relationships():
user_id = await token_check()
- return jsonify(
- await app.user_storage.get_relationships(user_id))
+ return jsonify(await app.user_storage.get_relationships(user_id))
async def _dispatch_single_pres(user_id, presence: dict):
- await app.dispatcher.dispatch(
- 'user', user_id, 'PRESENCE_UPDATE', presence
- )
+ await app.dispatcher.dispatch("user", user_id, "PRESENCE_UPDATE", presence)
async def _unsub_friend(user_id, peer_id):
- await app.dispatcher.unsub('friend', user_id, peer_id)
- await app.dispatcher.unsub('friend', peer_id, user_id)
+ await app.dispatcher.unsub("friend", user_id, peer_id)
+ await app.dispatcher.unsub("friend", peer_id, user_id)
+
async def _sub_friend(user_id, peer_id):
- await app.dispatcher.sub('friend', user_id, peer_id)
- await app.dispatcher.sub('friend', peer_id, user_id)
+ await app.dispatcher.sub("friend", user_id, peer_id)
+ await app.dispatcher.sub("friend", peer_id, user_id)
# dispatch presence update to the user and peer about
# eachother's presence.
- user_pres, peer_pres = await app.presence.friend_presences(
- [user_id, peer_id]
- )
+ user_pres, peer_pres = await app.presence.friend_presences([user_id, peer_id])
await _dispatch_single_pres(user_id, peer_pres)
await _dispatch_single_pres(peer_id, user_pres)
-async def make_friend(user_id: int, peer_id: int,
- rel_type=RelationshipType.FRIEND.value):
+async def make_friend(
+ user_id: int, peer_id: int, rel_type=RelationshipType.FRIEND.value
+):
_friend = RelationshipType.FRIEND.value
_block = RelationshipType.BLOCK.value
try:
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO relationships (user_id, peer_id, rel_type)
VALUES ($1, $2, $3)
- """, user_id, peer_id, rel_type)
+ """,
+ user_id,
+ peer_id,
+ rel_type,
+ )
except UniqueViolationError:
# try to update rel_type
- old_rel_type = await app.db.fetchval("""
+ old_rel_type = await app.db.fetchval(
+ """
SELECT rel_type
FROM relationships
WHERE user_id = $1 AND peer_id = $2
- """, user_id, peer_id)
+ """,
+ user_id,
+ peer_id,
+ )
if old_rel_type == _friend and rel_type == _block:
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE relationships
SET rel_type = $1
WHERE user_id = $2 AND peer_id = $3
- """, rel_type, user_id, peer_id)
+ """,
+ rel_type,
+ user_id,
+ peer_id,
+ )
# remove any existing friendship before the block
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM relationships
WHERE peer_id = $1 AND user_id = $2 AND rel_type = $3
- """, peer_id, user_id, _friend)
+ """,
+ peer_id,
+ user_id,
+ _friend,
+ )
await app.dispatcher.dispatch_user(
- peer_id, 'RELATIONSHIP_REMOVE', {
- 'type': _friend,
- 'id': str(user_id)
- }
+ peer_id, "RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)}
)
await _unsub_friend(user_id, peer_id)
@@ -106,95 +119,118 @@ async def make_friend(user_id: int, peer_id: int,
# check if this is an acceptance
# of a friend request
- existing = await app.db.fetchrow("""
+ existing = await app.db.fetchrow(
+ """
SELECT user_id, peer_id
FROM relationships
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
- """, peer_id, user_id, _friend)
+ """,
+ peer_id,
+ user_id,
+ _friend,
+ )
_dispatch = app.dispatcher.dispatch_user
if existing:
# accepted a friend request, dispatch respective
# relationship events
- await _dispatch(user_id, 'RELATIONSHIP_REMOVE', {
- 'type': RelationshipType.INCOMING.value,
- 'id': str(peer_id)
- })
+ await _dispatch(
+ user_id,
+ "RELATIONSHIP_REMOVE",
+ {"type": RelationshipType.INCOMING.value, "id": str(peer_id)},
+ )
- await _dispatch(user_id, 'RELATIONSHIP_ADD', {
- 'type': _friend,
- 'id': str(peer_id),
- 'user': await app.storage.get_user(peer_id)
- })
+ await _dispatch(
+ user_id,
+ "RELATIONSHIP_ADD",
+ {
+ "type": _friend,
+ "id": str(peer_id),
+ "user": await app.storage.get_user(peer_id),
+ },
+ )
- await _dispatch(peer_id, 'RELATIONSHIP_ADD', {
- 'type': _friend,
- 'id': str(user_id),
- 'user': await app.storage.get_user(user_id)
- })
+ await _dispatch(
+ peer_id,
+ "RELATIONSHIP_ADD",
+ {
+ "type": _friend,
+ "id": str(user_id),
+ "user": await app.storage.get_user(user_id),
+ },
+ )
await _sub_friend(user_id, peer_id)
- return '', 204
+ return "", 204
# check if friend AND not acceptance of fr
if rel_type == _friend:
- await _dispatch(user_id, 'RELATIONSHIP_ADD', {
- 'id': str(peer_id),
- 'type': RelationshipType.OUTGOING.value,
- 'user': await app.storage.get_user(peer_id),
- })
+ await _dispatch(
+ user_id,
+ "RELATIONSHIP_ADD",
+ {
+ "id": str(peer_id),
+ "type": RelationshipType.OUTGOING.value,
+ "user": await app.storage.get_user(peer_id),
+ },
+ )
- await _dispatch(peer_id, 'RELATIONSHIP_ADD', {
- 'id': str(user_id),
- 'type': RelationshipType.INCOMING.value,
- 'user': await app.storage.get_user(user_id)
- })
+ await _dispatch(
+ peer_id,
+ "RELATIONSHIP_ADD",
+ {
+ "id": str(user_id),
+ "type": RelationshipType.INCOMING.value,
+ "user": await app.storage.get_user(user_id),
+ },
+ )
# we don't make the pubsub link
# until the peer accepts the friend request
- return '', 204
+ return "", 204
return
class RelationshipFailed(BadRequest):
"""Exception for general relationship errors."""
+
error_code = 80004
class RelationshipBlocked(BadRequest):
"""Exception for when the peer has blocked the user."""
+
error_code = 80001
-@bp.route('/@me/relationships', methods=['POST'])
+@bp.route("/@me/relationships", methods=["POST"])
async def post_relationship():
user_id = await token_check()
j = validate(await request.get_json(), SPECIFIC_FRIEND)
- uid = await app.storage.search_user(j['username'],
- str(j['discriminator']))
+ uid = await app.storage.search_user(j["username"], str(j["discriminator"]))
if not uid:
- raise RelationshipFailed('No users with DiscordTag exist')
+ raise RelationshipFailed("No users with DiscordTag exist")
res = await make_friend(user_id, uid)
if res is None:
- raise RelationshipBlocked('Can not friend user due to block')
+ raise RelationshipBlocked("Can not friend user due to block")
- return '', 204
+ return "", 204
-@bp.route('/@me/relationships/', methods=['PUT'])
+@bp.route("/@me/relationships/", methods=["PUT"])
async def add_relationship(peer_id: int):
"""Add a relationship to the peer."""
user_id = await token_check()
payload = validate(await request.get_json(), RELATIONSHIP)
- rel_type = payload['type']
+ rel_type = payload["type"]
res = await make_friend(user_id, peer_id, rel_type)
@@ -204,18 +240,22 @@ async def add_relationship(peer_id: int):
# make_friend did not succeed, so we
# assume it is a block and dispatch
# the respective RELATIONSHIP_ADD.
- await app.dispatcher.dispatch_user(user_id, 'RELATIONSHIP_ADD', {
- 'id': str(peer_id),
- 'type': RelationshipType.BLOCK.value,
- 'user': await app.storage.get_user(peer_id)
- })
+ await app.dispatcher.dispatch_user(
+ user_id,
+ "RELATIONSHIP_ADD",
+ {
+ "id": str(peer_id),
+ "type": RelationshipType.BLOCK.value,
+ "user": await app.storage.get_user(peer_id),
+ },
+ )
await _unsub_friend(user_id, peer_id)
- return '', 204
+ return "", 204
-@bp.route('/@me/relationships/', methods=['DELETE'])
+@bp.route("/@me/relationships/", methods=["DELETE"])
async def remove_relationship(peer_id: int):
"""Remove an existing relationship"""
user_id = await token_check()
@@ -223,69 +263,86 @@ async def remove_relationship(peer_id: int):
_block = RelationshipType.BLOCK.value
_dispatch = app.dispatcher.dispatch_user
- rel_type = await app.db.fetchval("""
+ rel_type = await app.db.fetchval(
+ """
SELECT rel_type
FROM relationships
WHERE user_id = $1 AND peer_id = $2
- """, user_id, peer_id)
+ """,
+ user_id,
+ peer_id,
+ )
- incoming_rel_type = await app.db.fetchval("""
+ incoming_rel_type = await app.db.fetchval(
+ """
SELECT rel_type
FROM relationships
WHERE user_id = $1 AND peer_id = $2
- """, peer_id, user_id)
+ """,
+ peer_id,
+ user_id,
+ )
# if any of those are friend
if _friend in (rel_type, incoming_rel_type):
# closing the friendship, have to delete both rows
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM relationships
WHERE (
(user_id = $1 AND peer_id = $2) OR
(user_id = $2 AND peer_id = $1)
) AND rel_type = $3
- """, user_id, peer_id, _friend)
+ """,
+ user_id,
+ peer_id,
+ _friend,
+ )
# if there wasnt any mutual friendship before,
# assume they were requests of INCOMING
# and OUTGOING.
- user_del_type = RelationshipType.OUTGOING.value if \
- incoming_rel_type != _friend else _friend
+ user_del_type = (
+ RelationshipType.OUTGOING.value if incoming_rel_type != _friend else _friend
+ )
- await _dispatch(user_id, 'RELATIONSHIP_REMOVE', {
- 'id': str(peer_id),
- 'type': user_del_type,
- })
+ await _dispatch(
+ user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}
+ )
- peer_del_type = RelationshipType.INCOMING.value if \
- incoming_rel_type != _friend else _friend
+ peer_del_type = (
+ RelationshipType.INCOMING.value if incoming_rel_type != _friend else _friend
+ )
- await _dispatch(peer_id, 'RELATIONSHIP_REMOVE', {
- 'id': str(user_id),
- 'type': peer_del_type,
- })
+ await _dispatch(
+ peer_id, "RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}
+ )
await _unsub_friend(user_id, peer_id)
- return '', 204
+ return "", 204
# was a block!
- await app.db.execute("""
+ await app.db.execute(
+ """
DELETE FROM relationships
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
- """, user_id, peer_id, _block)
+ """,
+ user_id,
+ peer_id,
+ _block,
+ )
- await _dispatch(user_id, 'RELATIONSHIP_REMOVE', {
- 'id': str(peer_id),
- 'type': _block,
- })
+ await _dispatch(
+ user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block}
+ )
await _unsub_friend(user_id, peer_id)
- return '', 204
+ return "", 204
-@bp.route('//relationships', methods=['GET'])
+@bp.route("//relationships", methods=["GET"])
async def get_mutual_friends(peer_id: int):
"""Fetch a users' mutual friends with the current user."""
user_id = await token_check()
@@ -294,17 +351,15 @@ async def get_mutual_friends(peer_id: int):
peer = await app.storage.get_user(peer_id)
if not peer:
- return '', 204
+ return "", 204
# NOTE: maybe this could be better with pure SQL calculations
# but it would be beyond my current SQL knowledge, so...
user_rels = await app.user_storage.get_relationships(user_id)
peer_rels = await app.user_storage.get_relationships(peer_id)
- user_friends = {rel['user']['id']
- for rel in user_rels if rel['type'] == _friend}
- peer_friends = {rel['user']['id']
- for rel in peer_rels if rel['type'] == _friend}
+ user_friends = {rel["user"]["id"] for rel in user_rels if rel["type"] == _friend}
+ peer_friends = {rel["user"]["id"] for rel in peer_rels if rel["type"] == _friend}
# get the intersection, then map them to Storage.get_user() calls
mutual_ids = user_friends & peer_friends
@@ -312,8 +367,6 @@ async def get_mutual_friends(peer_id: int):
mutual_friends = []
for friend_id in mutual_ids:
- mutual_friends.append(
- await app.storage.get_user(int(friend_id))
- )
+ mutual_friends.append(await app.storage.get_user(int(friend_id)))
return jsonify(mutual_friends)
diff --git a/litecord/blueprints/science.py b/litecord/blueprints/science.py
index 80ffe15..7a5273c 100644
--- a/litecord/blueprints/science.py
+++ b/litecord/blueprints/science.py
@@ -19,21 +19,19 @@ along with this program. If not, see .
from quart import Blueprint, jsonify
-bp = Blueprint('science', __name__)
+bp = Blueprint("science", __name__)
-@bp.route('/science', methods=['POST'])
+@bp.route("/science", methods=["POST"])
async def science():
- return '', 204
+ return "", 204
-@bp.route('/applications', methods=['GET'])
+@bp.route("/applications", methods=["GET"])
async def applications():
return jsonify([])
-@bp.route('/experiments', methods=['GET'])
+@bp.route("/experiments", methods=["GET"])
async def experiments():
- return jsonify({
- 'assignments': []
- })
+ return jsonify({"assignments": []})
diff --git a/litecord/blueprints/static.py b/litecord/blueprints/static.py
index 5ce08df..61f2148 100644
--- a/litecord/blueprints/static.py
+++ b/litecord/blueprints/static.py
@@ -20,23 +20,24 @@ along with this program. If not, see .
from quart import Blueprint, current_app as app, render_template_string
from pathlib import Path
-bp = Blueprint('static', __name__)
+bp = Blueprint("static", __name__)
-@bp.route('/')
+@bp.route("/")
async def static_pages(path):
"""Map requests from / to /static."""
- if '..' in path:
- return 'no', 404
+ if ".." in path:
+ return "no", 404
- static_path = Path.cwd() / Path('static') / path
+ static_path = Path.cwd() / Path("static") / path
return await app.send_static_file(str(static_path))
-@bp.route('/')
-@bp.route('/api')
+@bp.route("/")
+@bp.route("/api")
async def index_handler():
"""Handler for the index page."""
- index_path = Path.cwd() / Path('static') / 'index.html'
+ index_path = Path.cwd() / Path("static") / "index.html"
return await render_template_string(
- index_path.read_text(), inst_name=app.config['NAME'])
+ index_path.read_text(), inst_name=app.config["NAME"]
+ )
diff --git a/litecord/blueprints/user/__init__.py b/litecord/blueprints/user/__init__.py
index 9e7b884..bd27fa4 100644
--- a/litecord/blueprints/user/__init__.py
+++ b/litecord/blueprints/user/__init__.py
@@ -21,4 +21,4 @@ from .billing import bp as user_billing
from .settings import bp as user_settings
from .fake_store import bp as fake_store
-__all__ = ['user_billing', 'user_settings', 'fake_store']
+__all__ = ["user_billing", "user_settings", "fake_store"]
diff --git a/litecord/blueprints/user/billing.py b/litecord/blueprints/user/billing.py
index f37744e..a1a4148 100644
--- a/litecord/blueprints/user/billing.py
+++ b/litecord/blueprints/user/billing.py
@@ -33,7 +33,7 @@ from litecord.enums import UserFlags, PremiumType
from litecord.blueprints.users import mass_user_update
log = Logger(__name__)
-bp = Blueprint('users_billing', __name__)
+bp = Blueprint("users_billing", __name__)
class PaymentSource(Enum):
@@ -68,78 +68,87 @@ class PaymentStatus:
PLAN_ID_TO_TYPE = {
- 'premium_month_tier_1': PremiumType.TIER_1,
- 'premium_month_tier_2': PremiumType.TIER_2,
- 'premium_year_tier_1': PremiumType.TIER_1,
- 'premium_year_tier_2': PremiumType.TIER_2,
+ "premium_month_tier_1": PremiumType.TIER_1,
+ "premium_month_tier_2": PremiumType.TIER_2,
+ "premium_year_tier_1": PremiumType.TIER_1,
+ "premium_year_tier_2": PremiumType.TIER_2,
}
# how much should a payment be, depending
# of the subscription
AMOUNTS = {
- 'premium_month_tier_1': 499,
- 'premium_month_tier_2': 999,
- 'premium_year_tier_1': 4999,
- 'premium_year_tier_2': 9999,
+ "premium_month_tier_1": 499,
+ "premium_month_tier_2": 999,
+ "premium_year_tier_1": 4999,
+ "premium_year_tier_2": 9999,
}
CREATE_SUBSCRIPTION = {
- 'payment_gateway_plan_id': {'type': 'string'},
- 'payment_source_id': {'coerce': int}
+ "payment_gateway_plan_id": {"type": "string"},
+ "payment_source_id": {"coerce": int},
}
PAYMENT_SOURCE = {
- 'billing_address': {
- 'type': 'dict',
- 'schema': {
- 'country': {'type': 'string', 'required': True},
- 'city': {'type': 'string', 'required': True},
- 'name': {'type': 'string', 'required': True},
- 'line_1': {'type': 'string', 'required': False},
- 'line_2': {'type': 'string', 'required': False},
- 'postal_code': {'type': 'string', 'required': True},
- 'state': {'type': 'string', 'required': True},
- }
+ "billing_address": {
+ "type": "dict",
+ "schema": {
+ "country": {"type": "string", "required": True},
+ "city": {"type": "string", "required": True},
+ "name": {"type": "string", "required": True},
+ "line_1": {"type": "string", "required": False},
+ "line_2": {"type": "string", "required": False},
+ "postal_code": {"type": "string", "required": True},
+ "state": {"type": "string", "required": True},
+ },
},
- 'payment_gateway': {'type': 'number', 'required': True},
- 'token': {'type': 'string', 'required': True},
+ "payment_gateway": {"type": "number", "required": True},
+ "token": {"type": "string", "required": True},
}
async def get_payment_source_ids(user_id: int) -> list:
- rows = await app.db.fetch("""
+ rows = await app.db.fetch(
+ """
SELECT id
FROM user_payment_sources
WHERE user_id = $1
- """, user_id)
+ """,
+ user_id,
+ )
- return [r['id'] for r in rows]
+ return [r["id"] for r in rows]
async def get_payment_ids(user_id: int, db=None) -> list:
if not db:
db = app.db
- rows = await db.fetch("""
+ rows = await db.fetch(
+ """
SELECT id
FROM user_payments
WHERE user_id = $1
- """, user_id)
+ """,
+ user_id,
+ )
- return [r['id'] for r in rows]
+ return [r["id"] for r in rows]
async def get_subscription_ids(user_id: int) -> list:
- rows = await app.db.fetch("""
+ rows = await app.db.fetch(
+ """
SELECT id
FROM user_subscriptions
WHERE user_id = $1
- """, user_id)
+ """,
+ user_id,
+ )
- return [r['id'] for r in rows]
+ return [r["id"] for r in rows]
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
@@ -148,41 +157,44 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
if not db:
db = app.db
- source_type = await db.fetchval("""
+ source_type = await db.fetchval(
+ """
SELECT source_type
FROM user_payment_sources
WHERE id = $1 AND user_id = $2
- """, source_id, user_id)
+ """,
+ source_id,
+ user_id,
+ )
source_type = PaymentSource(source_type)
specific_fields = {
- PaymentSource.PAYPAL: ['paypal_email'],
- PaymentSource.CREDIT: ['expires_month', 'expires_year',
- 'brand', 'cc_full']
+ PaymentSource.PAYPAL: ["paypal_email"],
+ PaymentSource.CREDIT: ["expires_month", "expires_year", "brand", "cc_full"],
}[source_type]
- fields = ','.join(specific_fields)
+ fields = ",".join(specific_fields)
- extras_row = await db.fetchrow(f"""
+ extras_row = await db.fetchrow(
+ f"""
SELECT {fields}, billing_address, default_, id::text
FROM user_payment_sources
WHERE id = $1
- """, source_id)
+ """,
+ source_id,
+ )
derow = dict(extras_row)
if source_type == PaymentSource.CREDIT:
- derow['last_4'] = derow['cc_full'][-4:]
- derow.pop('cc_full')
+ derow["last_4"] = derow["cc_full"][-4:]
+ derow.pop("cc_full")
- derow['default'] = derow['default_']
- derow.pop('default_')
+ derow["default"] = derow["default_"]
+ derow.pop("default_")
- source = {
- 'id': str(source_id),
- 'type': source_type.value,
- }
+ source = {"id": str(source_id), "type": source_type.value}
return {**source, **derow}
@@ -192,7 +204,8 @@ async def get_subscription(subscription_id: int, db=None):
if not db:
db = app.db
- row = await db.fetchrow("""
+ row = await db.fetchrow(
+ """
SELECT id::text, source_id::text AS payment_source_id,
user_id,
payment_gateway, payment_gateway_plan_id,
@@ -201,14 +214,16 @@ async def get_subscription(subscription_id: int, db=None):
canceled_at, s_type, status
FROM user_subscriptions
WHERE id = $1
- """, subscription_id)
+ """,
+ subscription_id,
+ )
drow = dict(row)
- drow['type'] = drow['s_type']
- drow.pop('s_type')
+ drow["type"] = drow["s_type"]
+ drow.pop("s_type")
- to_tstamp = ['current_period_start', 'current_period_end', 'canceled_at']
+ to_tstamp = ["current_period_start", "current_period_end", "canceled_at"]
for field in to_tstamp:
drow[field] = timestamp_(drow[field])
@@ -221,27 +236,30 @@ async def get_payment(payment_id: int, db=None):
if not db:
db = app.db
- row = await db.fetchrow("""
+ row = await db.fetchrow(
+ """
SELECT id::text, source_id, subscription_id, user_id,
amount, amount_refunded, currency,
description, status, tax, tax_inclusive
FROM user_payments
WHERE id = $1
- """, payment_id)
+ """,
+ payment_id,
+ )
drow = dict(row)
- drow.pop('source_id')
- drow.pop('subscription_id')
- drow.pop('user_id')
+ drow.pop("source_id")
+ drow.pop("subscription_id")
+ drow.pop("user_id")
- drow['created_at'] = snowflake_datetime(int(drow['id']))
+ drow["created_at"] = snowflake_datetime(int(drow["id"]))
- drow['payment_source'] = await get_payment_source(
- row['user_id'], row['source_id'], db)
+ drow["payment_source"] = await get_payment_source(
+ row["user_id"], row["source_id"], db
+ )
- drow['subscription'] = await get_subscription(
- row['subscription_id'], db)
+ drow["subscription"] = await get_subscription(row["subscription_id"], db)
return drow
@@ -255,7 +273,7 @@ async def create_payment(subscription_id, db=None):
new_id = get_snowflake()
- amount = AMOUNTS[sub['payment_gateway_plan_id']]
+ amount = AMOUNTS[sub["payment_gateway_plan_id"]]
await db.execute(
"""
@@ -266,10 +284,16 @@ async def create_payment(subscription_id, db=None):
)
VALUES
($1, $2, $3, $4, $5, 0, $6, $7, $8, 0, false)
- """, new_id, int(sub['payment_source_id']),
- subscription_id, int(sub['user_id']),
- amount, 'usd', 'FUCK NITRO',
- PaymentStatus.SUCCESS)
+ """,
+ new_id,
+ int(sub["payment_source_id"]),
+ subscription_id,
+ int(sub["user_id"]),
+ amount,
+ "usd",
+ "FUCK NITRO",
+ PaymentStatus.SUCCESS,
+ )
return new_id
@@ -278,29 +302,34 @@ async def process_subscription(app, subscription_id: int):
"""Process a single subscription."""
sub = await get_subscription(subscription_id, app.db)
- user_id = int(sub['user_id'])
+ user_id = int(sub["user_id"])
- if sub['status'] != SubscriptionStatus.ACTIVE:
- log.debug('ignoring sub {}, not active',
- subscription_id)
+ if sub["status"] != SubscriptionStatus.ACTIVE:
+ log.debug("ignoring sub {}, not active", subscription_id)
return
# if the subscription is still active
# (should get cancelled status on failed
# payments), then we should update premium status
- first_payment_id = await app.db.fetchval("""
+ first_payment_id = await app.db.fetchval(
+ """
SELECT MIN(id)
FROM user_payments
WHERE subscription_id = $1
- """, subscription_id)
+ """,
+ subscription_id,
+ )
first_payment_ts = snowflake_datetime(first_payment_id)
- premium_since = await app.db.fetchval("""
+ premium_since = await app.db.fetchval(
+ """
SELECT premium_since
FROM users
WHERE id = $1
- """, user_id)
+ """,
+ user_id,
+ )
premium_since = premium_since or datetime.datetime.fromtimestamp(0)
@@ -312,27 +341,34 @@ async def process_subscription(app, subscription_id: int):
if delta.total_seconds() < 24 * HOURS:
return
- old_flags = await app.db.fetchval("""
+ old_flags = await app.db.fetchval(
+ """
SELECT flags
FROM users
WHERE id = $1
- """, user_id)
+ """,
+ user_id,
+ )
new_flags = old_flags | UserFlags.premium_early
- log.debug('updating flags {}, {} => {}',
- user_id, old_flags, new_flags)
+ log.debug("updating flags {}, {} => {}", user_id, old_flags, new_flags)
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE users
SET premium_since = $1, flags = $2
WHERE id = $3
- """, first_payment_ts, new_flags, user_id)
+ """,
+ first_payment_ts,
+ new_flags,
+ user_id,
+ )
# dispatch updated user to all possible clients
await mass_user_update(user_id, app)
-@bp.route('/@me/billing/payment-sources', methods=['GET'])
+@bp.route("/@me/billing/payment-sources", methods=["GET"])
async def _get_billing_sources():
user_id = await token_check()
source_ids = await get_payment_source_ids(user_id)
@@ -346,7 +382,7 @@ async def _get_billing_sources():
return jsonify(res)
-@bp.route('/@me/billing/subscriptions', methods=['GET'])
+@bp.route("/@me/billing/subscriptions", methods=["GET"])
async def _get_billing_subscriptions():
user_id = await token_check()
sub_ids = await get_subscription_ids(user_id)
@@ -358,7 +394,7 @@ async def _get_billing_subscriptions():
return jsonify(res)
-@bp.route('/@me/billing/payments', methods=['GET'])
+@bp.route("/@me/billing/payments", methods=["GET"])
async def _get_billing_payments():
user_id = await token_check()
payment_ids = await get_payment_ids(user_id)
@@ -370,7 +406,7 @@ async def _get_billing_payments():
return jsonify(res)
-@bp.route('/@me/billing/payment-sources', methods=['POST'])
+@bp.route("/@me/billing/payment-sources", methods=["POST"])
async def _create_payment_source():
user_id = await token_check()
j = validate(await request.get_json(), PAYMENT_SOURCE)
@@ -383,34 +419,40 @@ async def _create_payment_source():
default_, expires_month, expires_year, brand, cc_full,
billing_address)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
- """, new_source_id, user_id, PaymentSource.CREDIT.value,
- True, 12, 6969, 'Visa', '4242424242424242',
- json.dumps(j['billing_address']))
-
- return jsonify(
- await get_payment_source(user_id, new_source_id)
+ """,
+ new_source_id,
+ user_id,
+ PaymentSource.CREDIT.value,
+ True,
+ 12,
+ 6969,
+ "Visa",
+ "4242424242424242",
+ json.dumps(j["billing_address"]),
)
+ return jsonify(await get_payment_source(user_id, new_source_id))
-@bp.route('/@me/billing/subscriptions', methods=['POST'])
+
+@bp.route("/@me/billing/subscriptions", methods=["POST"])
async def _create_subscription():
user_id = await token_check()
j = validate(await request.get_json(), CREATE_SUBSCRIPTION)
- source = await get_payment_source(user_id, j['payment_source_id'])
+ source = await get_payment_source(user_id, j["payment_source_id"])
if not source:
- raise BadRequest('invalid source id')
+ raise BadRequest("invalid source id")
- plan_id = j['payment_gateway_plan_id']
+ plan_id = j["payment_gateway_plan_id"]
# tier 1 is lightro / classic
# tier 2 is nitro
period_end = {
- 'premium_month_tier_1': '1 month',
- 'premium_month_tier_2': '1 month',
- 'premium_year_tier_1': '1 year',
- 'premium_year_tier_2': '1 year',
+ "premium_month_tier_1": "1 month",
+ "premium_month_tier_2": "1 month",
+ "premium_year_tier_1": "1 year",
+ "premium_year_tier_2": "1 year",
}[plan_id]
new_id = get_snowflake()
@@ -422,9 +464,15 @@ async def _create_subscription():
status, period_end)
VALUES ($1, $2, $3, $4, $5, $6, $7,
now()::timestamp + interval '{period_end}')
- """, new_id, j['payment_source_id'], user_id,
- SubscriptionType.PURCHASE, PaymentGateway.STRIPE,
- plan_id, 1)
+ """,
+ new_id,
+ j["payment_source_id"],
+ user_id,
+ SubscriptionType.PURCHASE,
+ PaymentGateway.STRIPE,
+ plan_id,
+ 1,
+ )
await create_payment(new_id, app.db)
@@ -432,21 +480,17 @@ async def _create_subscription():
# and dispatch respective user updates to other people.
await process_subscription(app, new_id)
- return jsonify(
- await get_subscription(new_id)
- )
+ return jsonify(await get_subscription(new_id))
-@bp.route('/@me/billing/subscriptions/',
- methods=['DELETE'])
+@bp.route("/@me/billing/subscriptions/", methods=["DELETE"])
async def _delete_subscription(subscription_id):
# user_id = await token_check()
# return '', 204
pass
-@bp.route('/@me/billing/subscriptions/',
- methods=['PATCH'])
+@bp.route("/@me/billing/subscriptions/", methods=["PATCH"])
async def _patch_subscription(subscription_id):
"""change a subscription's payment source"""
# user_id = await token_check()
diff --git a/litecord/blueprints/user/billing_job.py b/litecord/blueprints/user/billing_job.py
index b5e9f7f..4148415 100644
--- a/litecord/blueprints/user/billing_job.py
+++ b/litecord/blueprints/user/billing_job.py
@@ -25,8 +25,11 @@ from asyncio import sleep, CancelledError
from logbook import Logger
from litecord.blueprints.user.billing import (
- get_subscription, get_payment_ids, get_payment, create_payment,
- process_subscription
+ get_subscription,
+ get_payment_ids,
+ get_payment,
+ create_payment,
+ process_subscription,
)
from litecord.snowflake import snowflake_datetime
@@ -37,15 +40,15 @@ log = Logger(__name__)
# how many days until a payment needs
# to be issued
THRESHOLDS = {
- 'premium_month_tier_1': 30,
- 'premium_month_tier_2': 30,
- 'premium_year_tier_1': 365,
- 'premium_year_tier_2': 365,
+ "premium_month_tier_1": 30,
+ "premium_month_tier_2": 30,
+ "premium_year_tier_1": 365,
+ "premium_year_tier_2": 365,
}
async def _resched(app):
- log.debug('waiting 30 minutes for job.')
+ log.debug("waiting 30 minutes for job.")
await sleep(30 * MINUTES)
app.sched.spawn(payment_job(app))
@@ -54,10 +57,10 @@ async def _process_user_payments(app, user_id: int):
payments = await get_payment_ids(user_id, app.db)
if not payments:
- log.debug('no payments for uid {}, skipping', user_id)
+ log.debug("no payments for uid {}, skipping", user_id)
return
- log.debug('{} payments for uid {}', len(payments), user_id)
+ log.debug("{} payments for uid {}", len(payments), user_id)
latest_payment = max(payments)
@@ -66,33 +69,29 @@ async def _process_user_payments(app, user_id: int):
# calculate the difference between this payment
# and now.
now = datetime.datetime.now()
- payment_tstamp = snowflake_datetime(int(payment_data['id']))
+ payment_tstamp = snowflake_datetime(int(payment_data["id"]))
delta = now - payment_tstamp
- sub_id = int(payment_data['subscription']['id'])
- subscription = await get_subscription(
- sub_id, app.db)
+ sub_id = int(payment_data["subscription"]["id"])
+ subscription = await get_subscription(sub_id, app.db)
# if the max payment is X days old, we create another.
# X is 30 for monthly subscriptions of nitro,
# X is 365 for yearly subscriptions of nitro
- threshold = THRESHOLDS[subscription['payment_gateway_plan_id']]
+ threshold = THRESHOLDS[subscription["payment_gateway_plan_id"]]
- log.debug('delta {} delta days {} threshold {}',
- delta, delta.days, threshold)
+ log.debug("delta {} delta days {} threshold {}", delta, delta.days, threshold)
if delta.days > threshold:
- log.info('creating payment for sid={}',
- sub_id)
+ log.info("creating payment for sid={}", sub_id)
# create_payment does not call any Stripe
# or BrainTree APIs at all, since we'll just
# give it as free.
await create_payment(sub_id, app.db)
else:
- log.debug('sid={}, missing {} days',
- sub_id, threshold - delta.days)
+ log.debug("sid={}, missing {} days", sub_id, threshold - delta.days)
async def payment_job(app):
@@ -101,35 +100,39 @@ async def payment_job(app):
This function will check through users' payments
and add a new one once a month / year.
"""
- log.debug('payment job start!')
+ log.debug("payment job start!")
- user_ids = await app.db.fetch("""
+ user_ids = await app.db.fetch(
+ """
SELECT DISTINCT user_id
FROM user_payments
- """)
+ """
+ )
- log.debug('working {} users', len(user_ids))
+ log.debug("working {} users", len(user_ids))
# go through each user's payments
for row in user_ids:
- user_id = row['user_id']
+ user_id = row["user_id"]
try:
await _process_user_payments(app, user_id)
except Exception:
- log.exception('error while processing user payments')
+ log.exception("error while processing user payments")
- subscribers = await app.db.fetch("""
+ subscribers = await app.db.fetch(
+ """
SELECT id
FROM user_subscriptions
- """)
+ """
+ )
for row in subscribers:
try:
- await process_subscription(app, row['id'])
+ await process_subscription(app, row["id"])
except Exception:
- log.exception('error while processing subscription')
- log.debug('rescheduling..')
+ log.exception("error while processing subscription")
+ log.debug("rescheduling..")
try:
await _resched(app)
except CancelledError:
- log.info('cancelled while waiting for resched')
+ log.info("cancelled while waiting for resched")
diff --git a/litecord/blueprints/user/fake_store.py b/litecord/blueprints/user/fake_store.py
index 68c3dff..7f4fc3c 100644
--- a/litecord/blueprints/user/fake_store.py
+++ b/litecord/blueprints/user/fake_store.py
@@ -22,24 +22,26 @@ fake routes for discord store
"""
from quart import Blueprint, jsonify
-bp = Blueprint('fake_store', __name__)
+bp = Blueprint("fake_store", __name__)
-@bp.route('/promotions')
+@bp.route("/promotions")
async def _get_promotions():
return jsonify([])
-@bp.route('/users/@me/library')
+@bp.route("/users/@me/library")
async def _get_library():
return jsonify([])
-@bp.route('/users/@me/feed/settings')
+@bp.route("/users/@me/feed/settings")
async def _get_feed_settings():
- return jsonify({
- 'subscribed_games': [],
- 'subscribed_users': [],
- 'unsubscribed_users': [],
- 'unsubscribed_games': [],
- })
+ return jsonify(
+ {
+ "subscribed_games": [],
+ "subscribed_users": [],
+ "unsubscribed_users": [],
+ "unsubscribed_games": [],
+ }
+ )
diff --git a/litecord/blueprints/user/settings.py b/litecord/blueprints/user/settings.py
index 2cc73dc..e64e27e 100644
--- a/litecord/blueprints/user/settings.py
+++ b/litecord/blueprints/user/settings.py
@@ -23,10 +23,10 @@ from litecord.auth import token_check
from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS
from litecord.blueprints.checks import guild_check
-bp = Blueprint('users_settings', __name__)
+bp = Blueprint("users_settings", __name__)
-@bp.route('/@me/settings', methods=['GET'])
+@bp.route("/@me/settings", methods=["GET"])
async def get_user_settings():
"""Get the current user's settings."""
user_id = await token_check()
@@ -34,7 +34,7 @@ async def get_user_settings():
return jsonify(settings)
-@bp.route('/@me/settings', methods=['PATCH'])
+@bp.route("/@me/settings", methods=["PATCH"])
async def patch_current_settings():
"""Patch the users' current settings.
@@ -47,19 +47,22 @@ async def patch_current_settings():
for key in j:
val = j[key]
- await app.storage.execute_with_json(f"""
+ await app.storage.execute_with_json(
+ f"""
UPDATE user_settings
SET {key}=$1
WHERE id = $2
- """, val, user_id)
+ """,
+ val,
+ user_id,
+ )
settings = await app.user_storage.get_user_settings(user_id)
- await app.dispatcher.dispatch_user(
- user_id, 'USER_SETTINGS_UPDATE', settings)
+ await app.dispatcher.dispatch_user(user_id, "USER_SETTINGS_UPDATE", settings)
return jsonify(settings)
-@bp.route('/@me/guilds//settings', methods=['PATCH'])
+@bp.route("/@me/guilds//settings", methods=["PATCH"])
async def patch_guild_settings(guild_id: int):
"""Update the users' guild settings for a given guild.
@@ -74,16 +77,21 @@ async def patch_guild_settings(guild_id: int):
# will make sure they exist in the table.
await app.user_storage.get_guild_settings_one(user_id, guild_id)
- for field in (k for k in j.keys() if k != 'channel_overrides'):
- await app.db.execute(f"""
+ for field in (k for k in j.keys() if k != "channel_overrides"):
+ await app.db.execute(
+ f"""
UPDATE guild_settings
SET {field} = $1
WHERE user_id = $2 AND guild_id = $3
- """, j[field], user_id, guild_id)
+ """,
+ j[field],
+ user_id,
+ guild_id,
+ )
chan_ids = await app.storage.get_channel_ids(guild_id)
- for chandata in j.get('channel_overrides', {}).items():
+ for chandata in j.get("channel_overrides", {}).items():
chan_id, chan_overrides = chandata
chan_id = int(chan_id)
@@ -92,7 +100,8 @@ async def patch_guild_settings(guild_id: int):
continue
for field in chan_overrides:
- await app.db.execute(f"""
+ await app.db.execute(
+ f"""
INSERT INTO guild_settings_channel_overrides
(user_id, guild_id, channel_id, {field})
VALUES
@@ -105,18 +114,21 @@ async def patch_guild_settings(guild_id: int):
WHERE guild_settings_channel_overrides.user_id = $1
AND guild_settings_channel_overrides.guild_id = $2
AND guild_settings_channel_overrides.channel_id = $3
- """, user_id, guild_id, chan_id, chan_overrides[field])
+ """,
+ user_id,
+ guild_id,
+ chan_id,
+ chan_overrides[field],
+ )
- settings = await app.user_storage.get_guild_settings_one(
- user_id, guild_id)
+ settings = await app.user_storage.get_guild_settings_one(user_id, guild_id)
- await app.dispatcher.dispatch_user(
- user_id, 'USER_GUILD_SETTINGS_UPDATE', settings)
+ await app.dispatcher.dispatch_user(user_id, "USER_GUILD_SETTINGS_UPDATE", settings)
return jsonify(settings)
-@bp.route('/@me/notes/', methods=['PUT'])
+@bp.route("/@me/notes/", methods=["PUT"])
async def put_note(target_id: int):
"""Put a note to a user.
@@ -126,10 +138,11 @@ async def put_note(target_id: int):
user_id = await token_check()
j = await request.get_json()
- note = str(j['note'])
+ note = str(j["note"])
# UPSERTs are beautiful
- await app.db.execute("""
+ await app.db.execute(
+ """
INSERT INTO notes (user_id, target_id, note)
VALUES ($1, $2, $3)
@@ -138,12 +151,14 @@ async def put_note(target_id: int):
note = $3
WHERE notes.user_id = $1
AND notes.target_id = $2
- """, user_id, target_id, note)
+ """,
+ user_id,
+ target_id,
+ note,
+ )
- await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', {
- 'id': str(target_id),
- 'note': note,
- })
-
- return '', 204
+ await app.dispatcher.dispatch_user(
+ user_id, "USER_NOTE_UPDATE", {"id": str(target_id), "note": note}
+ )
+ return "", 204
diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py
index 64cc73e..d5bda02 100644
--- a/litecord/blueprints/users.py
+++ b/litecord/blueprints/users.py
@@ -27,9 +27,7 @@ from ..errors import Forbidden, BadRequest, Unauthorized
from ..schemas import validate, USER_UPDATE, GET_MENTIONS
from .guilds import guild_check
-from litecord.auth import (
- token_check, hash_data, check_username_usage, roll_discrim
-)
+from litecord.auth import token_check, hash_data, check_username_usage, roll_discrim
from litecord.blueprints.guild.mod import remove_member
from litecord.enums import PremiumType
@@ -39,7 +37,7 @@ from litecord.permissions import base_permissions
from litecord.blueprints.auth import check_password
from litecord.utils import to_update
-bp = Blueprint('user', __name__)
+bp = Blueprint("user", __name__)
log = Logger(__name__)
@@ -58,8 +56,7 @@ async def mass_user_update(user_id, app_=None):
private_user = await app_.storage.get_user(user_id, secure=True)
session_ids.extend(
- await app_.dispatcher.dispatch_user(
- user_id, 'USER_UPDATE', private_user)
+ await app_.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
)
guild_ids = await app_.user_storage.get_user_guilds(user_id)
@@ -67,26 +64,22 @@ async def mass_user_update(user_id, app_=None):
session_ids.extend(
await app_.dispatcher.dispatch_many_filter_list(
- 'guild', guild_ids, session_ids,
- 'USER_UPDATE', public_user
+ "guild", guild_ids, session_ids, "USER_UPDATE", public_user
)
)
session_ids.extend(
await app_.dispatcher.dispatch_many_filter_list(
- 'friend', friend_ids, session_ids,
- 'USER_UPDATE', public_user
+ "friend", friend_ids, session_ids, "USER_UPDATE", public_user
)
)
- await app_.dispatcher.dispatch_many(
- 'lazy_guild', guild_ids, 'update_user', user_id
- )
+ await app_.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
return public_user, private_user
-@bp.route('/@me', methods=['GET'])
+@bp.route("/@me", methods=["GET"])
async def get_me():
"""Get the current user's information."""
user_id = await token_check()
@@ -94,18 +87,21 @@ async def get_me():
return jsonify(user)
-@bp.route('/', methods=['GET'])
+@bp.route("/", methods=["GET"])
async def get_other(target_id):
"""Get any user, given the user ID."""
user_id = await token_check()
- bot = await app.db.fetchval("""
+ bot = await app.db.fetchval(
+ """
SELECT bot FROM users
WHERE users.id = $1
- """, user_id)
+ """,
+ user_id,
+ )
if not bot:
- raise Forbidden('Only bots can use this endpoint')
+ raise Forbidden("Only bots can use this endpoint")
other = await app.storage.get_user(target_id)
return jsonify(other)
@@ -116,66 +112,80 @@ async def _try_username_patch(user_id, new_username: str) -> str:
discrim = None
try:
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE users
SET username = $1
WHERE users.id = $2
- """, new_username, user_id)
+ """,
+ new_username,
+ user_id,
+ )
- return await app.db.fetchval("""
+ return await app.db.fetchval(
+ """
SELECT discriminator
FROM users
WHERE users.id = $1
- """, user_id)
+ """,
+ user_id,
+ )
except UniqueViolationError:
discrim = await roll_discrim(new_username)
if not discrim:
- raise BadRequest('Unable to change username', {
- 'username': 'Too many people are with this username.'
- })
+ raise BadRequest(
+ "Unable to change username",
+ {"username": "Too many people are with this username."},
+ )
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE users
SET username = $1, discriminator = $2
WHERE users.id = $3
- """, new_username, discrim, user_id)
+ """,
+ new_username,
+ discrim,
+ user_id,
+ )
return discrim
async def _try_discrim_patch(user_id, new_discrim: str):
try:
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE users
SET discriminator = $1
WHERE id = $2
- """, new_discrim, user_id)
+ """,
+ new_discrim,
+ user_id,
+ )
except UniqueViolationError:
- raise BadRequest('Invalid discriminator', {
- 'discriminator': 'Someone already used this discriminator.'
- })
+ raise BadRequest(
+ "Invalid discriminator",
+ {"discriminator": "Someone already used this discriminator."},
+ )
async def _check_pass(j, user):
# Do not do password checks on unclaimed accounts
- if user['email'] is None:
+ if user["email"] is None:
return
- if not j['password']:
- raise BadRequest('password required', {
- 'password': 'password required'
- })
+ if not j["password"]:
+ raise BadRequest("password required", {"password": "password required"})
- phash = user['password_hash']
+ phash = user["password_hash"]
- if not await check_password(phash, j['password']):
- raise BadRequest('password incorrect', {
- 'password': 'password does not match.'
- })
+ if not await check_password(phash, j["password"]):
+ raise BadRequest("password incorrect", {"password": "password does not match."})
-@bp.route('/@me', methods=['PATCH'])
+@bp.route("/@me", methods=["PATCH"])
async def patch_me():
"""Patch the current user's information."""
user_id = await token_check()
@@ -183,36 +193,43 @@ async def patch_me():
j = validate(await request.get_json(), USER_UPDATE)
user = await app.storage.get_user(user_id, True)
- user['password_hash'] = await app.db.fetchval("""
+ user["password_hash"] = await app.db.fetchval(
+ """
SELECT password_hash
FROM users
WHERE id = $1
- """, user_id)
+ """,
+ user_id,
+ )
- if to_update(j, user, 'username'):
+ if to_update(j, user, "username"):
# this will take care of regenning a new discriminator
- discrim = await _try_username_patch(user_id, j['username'])
- user['username'] = j['username']
- user['discriminator'] = discrim
+ discrim = await _try_username_patch(user_id, j["username"])
+ user["username"] = j["username"]
+ user["discriminator"] = discrim
- if to_update(j, user, 'discriminator'):
+ if to_update(j, user, "discriminator"):
# the API treats discriminators as integers,
# but I work with strings on the database.
- new_discrim = str(j['discriminator'])
+ new_discrim = str(j["discriminator"])
await _try_discrim_patch(user_id, new_discrim)
- user['discriminator'] = new_discrim
+ user["discriminator"] = new_discrim
- if to_update(j, user, 'email'):
+ if to_update(j, user, "email"):
await _check_pass(j, user)
# TODO: reverify the new email?
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE users
SET email = $1
WHERE id = $2
- """, j['email'], user_id)
- user['email'] = j['email']
+ """,
+ j["email"],
+ user_id,
+ )
+ user["email"] = j["email"]
# only update if values are different
# from what the user gave.
@@ -224,44 +241,49 @@ async def patch_me():
# IconManager.update will take care of validating
# the value once put()-ing
- if to_update(j, user, 'avatar'):
- mime, _ = parse_data_uri(j['avatar'])
+ if to_update(j, user, "avatar"):
+ mime, _ = parse_data_uri(j["avatar"])
- if mime == 'image/gif' and user['premium_type'] == PremiumType.NONE:
- raise BadRequest('no gif without nitro')
+ if mime == "image/gif" and user["premium_type"] == PremiumType.NONE:
+ raise BadRequest("no gif without nitro")
- new_icon = await app.icons.update(
- 'user', user_id, j['avatar'], size=(128, 128))
+ new_icon = await app.icons.update("user", user_id, j["avatar"], size=(128, 128))
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE users
SET avatar = $1
WHERE id = $2
- """, new_icon.icon_hash, user_id)
+ """,
+ new_icon.icon_hash,
+ user_id,
+ )
- if user['email'] is None and not 'new_password' in j:
- raise BadRequest('missing password', {
- 'password': 'Please set a password.'
- })
+ if user["email"] is None and not "new_password" in j:
+ raise BadRequest("missing password", {"password": "Please set a password."})
- if 'new_password' in j and j['new_password']:
+ if "new_password" in j and j["new_password"]:
await _check_pass(j, user)
- new_hash = await hash_data(j['new_password'])
+ new_hash = await hash_data(j["new_password"])
- await app.db.execute("""
+ await app.db.execute(
+ """
UPDATE users
SET password_hash = $1
WHERE id = $2
- """, new_hash, user_id)
+ """,
+ new_hash,
+ user_id,
+ )
- user.pop('password_hash')
+ user.pop("password_hash")
_, private_user = await mass_user_update(user_id, app)
return jsonify(private_user)
-@bp.route('/@me/guilds', methods=['GET'])
+@bp.route("/@me/guilds", methods=["GET"])
async def get_me_guilds():
"""Get partial user guilds."""
user_id = await token_check()
@@ -270,27 +292,30 @@ async def get_me_guilds():
partials = []
for guild_id in guild_ids:
- partial = await app.db.fetchrow("""
+ partial = await app.db.fetchrow(
+ """
SELECT id::text, name, icon, owner_id
FROM guilds
WHERE guilds.id = $1
- """, guild_id)
+ """,
+ guild_id,
+ )
partial = dict(partial)
user_perms = await base_permissions(user_id, guild_id)
- partial['permissions'] = user_perms.binary
+ partial["permissions"] = user_perms.binary
- partial['owner'] = partial['owner_id'] == user_id
+ partial["owner"] = partial["owner_id"] == user_id
- partial.pop('owner_id')
+ partial.pop("owner_id")
partials.append(partial)
return jsonify(partials)
-@bp.route('/@me/guilds/', methods=['DELETE'])
+@bp.route("/@me/guilds/", methods=["DELETE"])
async def leave_guild(guild_id: int):
"""Leave a guild."""
user_id = await token_check()
@@ -298,7 +323,7 @@ async def leave_guild(guild_id: int):
await remove_member(guild_id, user_id)
- return '', 204
+ return "", 204
# @bp.route('/@me/connections', methods=['GET'])
@@ -306,7 +331,7 @@ async def get_connections():
pass
-@bp.route('/@me/consent', methods=['GET', 'POST'])
+@bp.route("/@me/consent", methods=["GET", "POST"])
async def get_consent():
"""Always disable data collection.
@@ -314,57 +339,58 @@ async def get_consent():
by the client and ignores them, as they
will always be false.
"""
- return jsonify({
- 'usage_statistics': {
- 'consented': False,
- },
- 'personalization': {
- 'consented': False,
+ return jsonify(
+ {
+ "usage_statistics": {"consented": False},
+ "personalization": {"consented": False},
}
- })
+ )
-@bp.route('/@me/harvest', methods=['GET'])
+@bp.route("/@me/harvest", methods=["GET"])
async def get_harvest():
"""Dummy route"""
- return '', 204
+ return "", 204
-@bp.route('/@me/activities/statistics/applications', methods=['GET'])
+@bp.route("/@me/activities/statistics/applications", methods=["GET"])
async def get_stats_applications():
"""Dummy route for info on gameplay time and such"""
return jsonify([])
-@bp.route('/@me/library', methods=['GET'])
+@bp.route("/@me/library", methods=["GET"])
async def get_library():
"""Probably related to Discord Store?"""
return jsonify([])
-@bp.route('//profile', methods=['GET'])
+@bp.route("/