black fmt pass

This commit is contained in:
Luna 2019-10-25 07:27:50 -03:00
parent 0bc4b1ba3f
commit 83a1c1ae29
109 changed files with 5575 additions and 4775 deletions

View File

@ -17,13 +17,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
MODE = 'CI' MODE = "CI"
class Config: class Config:
"""Default configuration values for litecord.""" """Default configuration values for litecord."""
MAIN_URL = 'localhost:1'
NAME = 'gitlab ci' MAIN_URL = "localhost:1"
NAME = "gitlab ci"
# Enable debug logging? # Enable debug logging?
DEBUG = False DEBUG = False
@ -37,11 +38,11 @@ class Config:
# Set this url to somewhere *your users* # Set this url to somewhere *your users*
# will hit the websocket. # will hit the websocket.
# e.g 'gateway.example.com' for reverse proxies. # e.g 'gateway.example.com' for reverse proxies.
WEBSOCKET_URL = 'localhost:5001' WEBSOCKET_URL = "localhost:5001"
# Where to host the websocket? # Where to host the websocket?
# (a local address the server will bind to) # (a local address the server will bind to)
WS_HOST = 'localhost' WS_HOST = "localhost"
WS_PORT = 5001 WS_PORT = 5001
# Postgres credentials # Postgres credentials
@ -51,10 +52,10 @@ class Config:
class Development(Config): class Development(Config):
DEBUG = True DEBUG = True
POSTGRES = { POSTGRES = {
'host': 'localhost', "host": "localhost",
'user': 'litecord', "user": "litecord",
'password': '123', "password": "123",
'database': 'litecord', "database": "litecord",
} }
@ -66,8 +67,4 @@ class Production(Config):
class CI(Config): class CI(Config):
DEBUG = True DEBUG = True
POSTGRES = { POSTGRES = {"host": "postgres", "user": "postgres", "password": ""}
'host': 'postgres',
'user': 'postgres',
'password': ''
}

View File

@ -17,16 +17,17 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
MODE = 'Development' MODE = "Development"
class Config: class Config:
"""Default configuration values for litecord.""" """Default configuration values for litecord."""
#: Main URL of the instance. #: Main URL of the instance.
MAIN_URL = 'discordapp.io' MAIN_URL = "discordapp.io"
#: Name of the instance #: Name of the instance
NAME = 'Litecord/Nya' NAME = "Litecord/Nya"
#: Enable debug logging? #: Enable debug logging?
DEBUG = False DEBUG = False
@ -45,17 +46,17 @@ class Config:
# Set this url to somewhere *your users* # Set this url to somewhere *your users*
# will hit the websocket. # will hit the websocket.
# e.g 'gateway.example.com' for reverse proxies. # e.g 'gateway.example.com' for reverse proxies.
WEBSOCKET_URL = 'localhost:5001' WEBSOCKET_URL = "localhost:5001"
#: Where to host the websocket? #: Where to host the websocket?
# (a local address the server will bind to) # (a local address the server will bind to)
WS_HOST = '0.0.0.0' WS_HOST = "0.0.0.0"
WS_PORT = 5001 WS_PORT = 5001
#: Mediaproxy URL on the internet #: Mediaproxy URL on the internet
# mediaproxy is made to prevent client IPs being leaked. # mediaproxy is made to prevent client IPs being leaked.
# None is a valid value if you don't want to deploy mediaproxy. # None is a valid value if you don't want to deploy mediaproxy.
MEDIA_PROXY = 'localhost:5002' MEDIA_PROXY = "localhost:5002"
#: Postgres credentials #: Postgres credentials
POSTGRES = {} POSTGRES = {}
@ -65,10 +66,10 @@ class Development(Config):
DEBUG = True DEBUG = True
POSTGRES = { POSTGRES = {
'host': 'localhost', "host": "localhost",
'user': 'litecord', "user": "litecord",
'password': '123', "password": "123",
'database': 'litecord', "database": "litecord",
} }
@ -77,8 +78,8 @@ class Production(Config):
IS_SSL = True IS_SSL = True
POSTGRES = { POSTGRES = {
'host': 'some_production_postgres', "host": "some_production_postgres",
'user': 'some_production_user', "user": "some_production_user",
'password': 'some_production_password', "password": "some_production_password",
'database': 'litecord_or_anything_else_really', "database": "litecord_or_anything_else_really",
} }

View File

@ -16,4 +16,3 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """

View File

@ -19,42 +19,33 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from litecord.enums import Feature, UserFlags from litecord.enums import Feature, UserFlags
VOICE_SERVER = { VOICE_SERVER = {"hostname": {"type": "string", "maxlength": 255, "required": True}}
'hostname': {'type': 'string', 'maxlength': 255, 'required': True}
}
VOICE_REGION = { VOICE_REGION = {
'id': {'type': 'string', 'maxlength': 255, 'required': True}, "id": {"type": "string", "maxlength": 255, "required": True},
'name': {'type': 'string', 'maxlength': 255, 'required': True}, "name": {"type": "string", "maxlength": 255, "required": True},
"vip": {"type": "boolean", "default": False},
'vip': {'type': 'boolean', 'default': False}, "deprecated": {"type": "boolean", "default": False},
'deprecated': {'type': 'boolean', 'default': False}, "custom": {"type": "boolean", "default": False},
'custom': {'type': 'boolean', 'default': False},
} }
FEATURES = { FEATURES = {
'features': { "features": {
'type': 'list', 'required': True, "type": "list",
"required": True,
# using Feature doesn't seem to work with a "not callable" error. # using Feature doesn't seem to work with a "not callable" error.
'schema': {'coerce': lambda x: Feature(x)} "schema": {"coerce": lambda x: Feature(x)},
} }
} }
USER_CREATE = { USER_CREATE = {
'username': {'type': 'username', 'required': True}, "username": {"type": "username", "required": True},
'email': {'type': 'email', 'required': True}, "email": {"type": "email", "required": True},
'password': {'type': 'string', 'minlength': 5, 'required': True}, "password": {"type": "string", "minlength": 5, "required": True},
} }
INSTANCE_INVITE = { INSTANCE_INVITE = {"max_uses": {"type": "integer", "required": True}}
'max_uses': {'type': 'integer', 'required': True}
}
GUILD_UPDATE = { GUILD_UPDATE = {"unavailable": {"type": "boolean", "required": False}}
'unavailable': {'type': 'boolean', 'required': False}
}
USER_UPDATE = { USER_UPDATE = {"flags": {"required": False, "coerce": UserFlags.from_int}}
'flags': {'required': False, 'coerce': UserFlags.from_int}
}

View File

@ -55,44 +55,50 @@ async def raw_token_check(token: str, db=None) -> int:
# just try by fragments instead of # just try by fragments instead of
# unpacking # unpacking
fragments = token.split('.') fragments = token.split(".")
user_id = fragments[0] user_id = fragments[0]
try: try:
user_id = base64.b64decode(user_id.encode()) user_id = base64.b64decode(user_id.encode())
user_id = int(user_id) user_id = int(user_id)
except (ValueError, binascii.Error): except (ValueError, binascii.Error):
raise Unauthorized('Invalid user ID type') raise Unauthorized("Invalid user ID type")
pwd_hash = await db.fetchval(""" pwd_hash = await db.fetchval(
"""
SELECT password_hash SELECT password_hash
FROM users FROM users
WHERE id = $1 WHERE id = $1
""", user_id) """,
user_id,
)
if not pwd_hash: if not pwd_hash:
raise Unauthorized('User ID not found') raise Unauthorized("User ID not found")
signer = TimestampSigner(pwd_hash) signer = TimestampSigner(pwd_hash)
try: try:
signer.unsign(token) signer.unsign(token)
log.debug('login for uid {} successful', user_id) log.debug("login for uid {} successful", user_id)
# update the user's last_session field # update the user's last_session field
# so that we can keep an exact track of activity, # so that we can keep an exact track of activity,
# even on long-lived single sessions (that can happen # even on long-lived single sessions (that can happen
# with people leaving their clients open forever) # with people leaving their clients open forever)
await db.execute(""" await db.execute(
"""
UPDATE users UPDATE users
SET last_session = (now() at time zone 'utc') SET last_session = (now() at time zone 'utc')
WHERE id = $1 WHERE id = $1
""", user_id) """,
user_id,
)
return user_id return user_id
except BadSignature: except BadSignature:
log.warning('token failed for uid {}', user_id) log.warning("token failed for uid {}", user_id)
raise Forbidden('Invalid token') raise Forbidden("Invalid token")
async def token_check() -> int: async def token_check() -> int:
@ -104,12 +110,12 @@ async def token_check() -> int:
pass pass
try: try:
token = request.headers['Authorization'] token = request.headers["Authorization"]
except KeyError: except KeyError:
raise Unauthorized('No token provided') raise Unauthorized("No token provided")
if token.startswith('Bot '): if token.startswith("Bot "):
token = token.replace('Bot ', '') token = token.replace("Bot ", "")
user_id = await raw_token_check(token) user_id = await raw_token_check(token)
request.user_id = user_id request.user_id = user_id
@ -120,15 +126,18 @@ async def admin_check() -> int:
"""Check if the user is an admin.""" """Check if the user is an admin."""
user_id = await token_check() user_id = await token_check()
flags = await app.db.fetchval(""" flags = await app.db.fetchval(
"""
SELECT flags SELECT flags
FROM users FROM users
WHERE id = $1 WHERE id = $1
""", user_id) """,
user_id,
)
flags = UserFlags.from_int(flags) flags = UserFlags.from_int(flags)
if not flags.is_staff: if not flags.is_staff:
raise Unauthorized('you are not staff') raise Unauthorized("you are not staff")
return user_id return user_id
@ -138,9 +147,7 @@ async def hash_data(data: str, loop=None) -> str:
loop = loop or app.loop loop = loop or app.loop
buf = data.encode() buf = data.encode()
hashed = await loop.run_in_executor( hashed = await loop.run_in_executor(None, bcrypt.hashpw, buf, bcrypt.gensalt(14))
None, bcrypt.hashpw, buf, bcrypt.gensalt(14)
)
return hashed.decode() return hashed.decode()
@ -148,22 +155,28 @@ async def hash_data(data: str, loop=None) -> str:
async def check_username_usage(username: str, db=None): async def check_username_usage(username: str, db=None):
"""Raise an error if too many people are with the same username.""" """Raise an error if too many people are with the same username."""
db = db or app.db db = db or app.db
same_username = await db.fetchval(""" same_username = await db.fetchval(
"""
SELECT COUNT(*) SELECT COUNT(*)
FROM users FROM users
WHERE username = $1 WHERE username = $1
""", username) """,
username,
)
if same_username > 9000: if same_username > 9000:
raise BadRequest('Too many people.', { raise BadRequest(
'username': 'Too many people used the same username. ' "Too many people.",
'Please choose another' {
}) "username": "Too many people used the same username. "
"Please choose another"
},
)
def _raw_discrim() -> str: def _raw_discrim() -> str:
new_discrim = randint(1, 9999) new_discrim = randint(1, 9999)
new_discrim = '%04d' % new_discrim new_discrim = "%04d" % new_discrim
return new_discrim return new_discrim
@ -186,11 +199,15 @@ async def roll_discrim(username: str, *, db=None) -> Optional[str]:
discrim = _raw_discrim() discrim = _raw_discrim()
# check if anyone is with it # check if anyone is with it
res = await db.fetchval(""" res = await db.fetchval(
"""
SELECT id SELECT id
FROM users FROM users
WHERE username = $1 AND discriminator = $2 WHERE username = $1 AND discriminator = $2
""", username, discrim) """,
username,
discrim,
)
# if no user is found with the (username, discrim) # if no user is found with the (username, discrim)
# pair, then this is unique! return it. # pair, then this is unique! return it.
@ -200,8 +217,9 @@ async def roll_discrim(username: str, *, db=None) -> Optional[str]:
return None return None
async def create_user(username: str, email: str, password: str, async def create_user(
db=None, loop=None) -> Tuple[int, str]: username: str, email: str, password: str, db=None, loop=None
) -> Tuple[int, str]:
"""Create a single user. """Create a single user.
Generates a distriminator and other information. You can fetch the user Generates a distriminator and other information. You can fetch the user
@ -214,20 +232,28 @@ async def create_user(username: str, email: str, password: str,
new_discrim = await roll_discrim(username, db=db) new_discrim = await roll_discrim(username, db=db)
if new_discrim is None: if new_discrim is None:
raise BadRequest('Unable to register.', { raise BadRequest(
'username': 'Too many people are with this username.' "Unable to register.",
}) {"username": "Too many people are with this username."},
)
pwd_hash = await hash_data(password, loop) pwd_hash = await hash_data(password, loop)
try: try:
await db.execute(""" await db.execute(
"""
INSERT INTO users INSERT INTO users
(id, email, username, discriminator, password_hash) (id, email, username, discriminator, password_hash)
VALUES VALUES
($1, $2, $3, $4, $5) ($1, $2, $3, $4, $5)
""", new_id, email, username, new_discrim, pwd_hash) """,
new_id,
email,
username,
new_discrim,
pwd_hash,
)
except UniqueViolationError: except UniqueViolationError:
raise BadRequest('Email already used.') raise BadRequest("Email already used.")
return new_id, pwd_hash return new_id, pwd_hash

View File

@ -34,7 +34,21 @@ from .static import bp as static
from .attachments import bp as attachments from .attachments import bp as attachments
from .dm_channels import bp as dm_channels from .dm_channels import bp as dm_channels
__all__ = ['gateway', 'auth', 'users', 'guilds', 'channels', __all__ = [
'webhooks', 'science', 'voice', 'invites', 'relationships', "gateway",
'dms', 'icons', 'nodeinfo', 'static', 'attachments', "auth",
'dm_channels'] "users",
"guilds",
"channels",
"webhooks",
"science",
"voice",
"invites",
"relationships",
"dms",
"icons",
"nodeinfo",
"static",
"attachments",
"dm_channels",
]

View File

@ -23,4 +23,4 @@ from .guilds import bp as guilds
from .users import bp as users from .users import bp as users
from .instance_invites import bp as instance_invites from .instance_invites import bp as instance_invites
__all__ = ['voice', 'features', 'guilds', 'users', 'instance_invites'] __all__ = ["voice", "features", "guilds", "users", "instance_invites"]

View File

@ -25,45 +25,53 @@ from litecord.errors import BadRequest
from litecord.schemas import validate from litecord.schemas import validate
from litecord.admin_schemas import FEATURES from litecord.admin_schemas import FEATURES
bp = Blueprint('features_admin', __name__) bp = Blueprint("features_admin", __name__)
async def _features_from_req() -> List[str]: async def _features_from_req() -> List[str]:
j = validate(await request.get_json(), FEATURES) j = validate(await request.get_json(), FEATURES)
return [feature.value for feature in j['features']] return [feature.value for feature in j["features"]]
async def _features(guild_id: int): async def _features(guild_id: int):
return jsonify({ return jsonify({"features": await app.storage.guild_features(guild_id)})
'features': await app.storage.guild_features(guild_id)
})
async def _update_features(guild_id: int, features: list): async def _update_features(guild_id: int, features: list):
if 'VANITY_URL' not in features: if "VANITY_URL" not in features:
existing_inv = await app.storage.vanity_invite(guild_id) existing_inv = await app.storage.vanity_invite(guild_id)
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM vanity_invites DELETE FROM vanity_invites
WHERE guild_id = $1 WHERE guild_id = $1
""", guild_id) """,
guild_id,
)
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM invites DELETE FROM invites
WHERE code = $1 WHERE code = $1
""", existing_inv) """,
existing_inv,
)
await app.db.execute(""" await app.db.execute(
"""
UPDATE guilds UPDATE guilds
SET features = $1 SET features = $1
WHERE id = $2 WHERE id = $2
""", features, guild_id) """,
features,
guild_id,
)
guild = await app.storage.get_guild_full(guild_id) guild = await app.storage.get_guild_full(guild_id)
await app.dispatcher.dispatch('guild', guild_id, 'GUILD_UPDATE', guild) await app.dispatcher.dispatch("guild", guild_id, "GUILD_UPDATE", guild)
@bp.route('/<int:guild_id>/features', methods=['PATCH']) @bp.route("/<int:guild_id>/features", methods=["PATCH"])
async def replace_features(guild_id: int): async def replace_features(guild_id: int):
"""Replace the feature list in a guild""" """Replace the feature list in a guild"""
await admin_check() await admin_check()
@ -76,7 +84,7 @@ async def replace_features(guild_id: int):
return await _features(guild_id) return await _features(guild_id)
@bp.route('/<int:guild_id>/features', methods=['PUT']) @bp.route("/<int:guild_id>/features", methods=["PUT"])
async def insert_features(guild_id: int): async def insert_features(guild_id: int):
"""Insert a feature on a guild.""" """Insert a feature on a guild."""
await admin_check() await admin_check()
@ -93,7 +101,7 @@ async def insert_features(guild_id: int):
return await _features(guild_id) return await _features(guild_id)
@bp.route('/<int:guild_id>/features', methods=['DELETE']) @bp.route("/<int:guild_id>/features", methods=["DELETE"])
async def remove_features(guild_id: int): async def remove_features(guild_id: int):
"""Remove a feature from a guild""" """Remove a feature from a guild"""
await admin_check() await admin_check()
@ -104,7 +112,7 @@ async def remove_features(guild_id: int):
try: try:
features.remove(feature) features.remove(feature)
except ValueError: except ValueError:
raise BadRequest('Trying to remove already removed feature.') raise BadRequest("Trying to remove already removed feature.")
await _update_features(guild_id, features) await _update_features(guild_id, features)
return await _features(guild_id) return await _features(guild_id)

View File

@ -25,9 +25,10 @@ from litecord.admin_schemas import GUILD_UPDATE
from litecord.blueprints.guilds import delete_guild from litecord.blueprints.guilds import delete_guild
from litecord.errors import GuildNotFound from litecord.errors import GuildNotFound
bp = Blueprint('guilds_admin', __name__) bp = Blueprint("guilds_admin", __name__)
@bp.route('/<int:guild_id>', methods=['GET'])
@bp.route("/<int:guild_id>", methods=["GET"])
async def get_guild(guild_id: int): async def get_guild(guild_id: int):
"""Get a basic guild payload.""" """Get a basic guild payload."""
await admin_check() await admin_check()
@ -40,7 +41,7 @@ async def get_guild(guild_id: int):
return jsonify(guild) return jsonify(guild)
@bp.route('/<int:guild_id>', methods=['PATCH']) @bp.route("/<int:guild_id>", methods=["PATCH"])
async def update_guild(guild_id: int): async def update_guild(guild_id: int):
await admin_check() await admin_check()
@ -48,13 +49,13 @@ async def update_guild(guild_id: int):
# TODO: what happens to the other guild attributes when its # TODO: what happens to the other guild attributes when its
# unavailable? do they vanish? # unavailable? do they vanish?
old_unavailable = app.guild_store.get(guild_id, 'unavailable') old_unavailable = app.guild_store.get(guild_id, "unavailable")
new_unavailable = j.get('unavailable', old_unavailable) new_unavailable = j.get("unavailable", old_unavailable)
# always set unavailable status since new_unavailable will be # always set unavailable status since new_unavailable will be
# old_unavailable when not provided, so we don't need to check if # old_unavailable when not provided, so we don't need to check if
# j.unavailable is there # j.unavailable is there
app.guild_store.set(guild_id, 'unavailable', j['unavailable']) app.guild_store.set(guild_id, "unavailable", j["unavailable"])
guild = await app.storage.get_guild(guild_id) guild = await app.storage.get_guild(guild_id)
@ -62,17 +63,17 @@ async def update_guild(guild_id: int):
if old_unavailable and not new_unavailable: if old_unavailable and not new_unavailable:
# guild became available # guild became available
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_CREATE', guild) await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild)
else: else:
# guild became unavailable # guild became unavailable
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_DELETE', guild) await app.dispatcher.dispatch_guild(guild_id, "GUILD_DELETE", guild)
return jsonify(guild) return jsonify(guild)
@bp.route('/<int:guild_id>', methods=['DELETE']) @bp.route("/<int:guild_id>", methods=["DELETE"])
async def delete_guild_as_admin(guild_id): async def delete_guild_as_admin(guild_id):
"""Delete a single guild via the admin API without ownership checks.""" """Delete a single guild via the admin API without ownership checks."""
await admin_check() await admin_check()
await delete_guild(guild_id) await delete_guild(guild_id)
return '', 204 return "", 204

View File

@ -27,13 +27,13 @@ from litecord.types import timestamp_
from litecord.schemas import validate from litecord.schemas import validate
from litecord.admin_schemas import INSTANCE_INVITE from litecord.admin_schemas import INSTANCE_INVITE
bp = Blueprint('instance_invites', __name__) bp = Blueprint("instance_invites", __name__)
ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits
async def _gen_inv() -> str: async def _gen_inv() -> str:
"""Generate an invite code""" """Generate an invite code"""
return ''.join(choice(ALPHABET) for _ in range(6)) return "".join(choice(ALPHABET) for _ in range(6))
async def gen_inv(ctx) -> str: async def gen_inv(ctx) -> str:
@ -41,11 +41,14 @@ async def gen_inv(ctx) -> str:
for _ in range(10): for _ in range(10):
possible_inv = await _gen_inv() possible_inv = await _gen_inv()
created_at = await ctx.db.fetchval(""" created_at = await ctx.db.fetchval(
"""
SELECT created_at SELECT created_at
FROM instance_invites FROM instance_invites
WHERE code = $1 WHERE code = $1
""", possible_inv) """,
possible_inv,
)
if created_at is None: if created_at is None:
return possible_inv return possible_inv
@ -53,57 +56,71 @@ async def gen_inv(ctx) -> str:
return None return None
@bp.route('', methods=['GET']) @bp.route("", methods=["GET"])
async def _all_instance_invites(): async def _all_instance_invites():
await admin_check() await admin_check()
rows = await app.db.fetch(""" rows = await app.db.fetch(
"""
SELECT code, created_at, uses, max_uses SELECT code, created_at, uses, max_uses
FROM instance_invites FROM instance_invites
""") """
)
rows = [dict(row) for row in rows] rows = [dict(row) for row in rows]
for row in rows: for row in rows:
row['created_at'] = timestamp_(row['created_at']) row["created_at"] = timestamp_(row["created_at"])
return jsonify(rows) return jsonify(rows)
@bp.route('', methods=['PUT']) @bp.route("", methods=["PUT"])
async def _create_invite(): async def _create_invite():
await admin_check() await admin_check()
code = await gen_inv(app) code = await gen_inv(app)
if code is None: if code is None:
return 'failed to make invite', 500 return "failed to make invite", 500
j = validate(await request.get_json(), INSTANCE_INVITE) j = validate(await request.get_json(), INSTANCE_INVITE)
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO instance_invites (code, max_uses) INSERT INTO instance_invites (code, max_uses)
VALUES ($1, $2) VALUES ($1, $2)
""", code, j['max_uses']) """,
code,
j["max_uses"],
)
inv = dict(await app.db.fetchrow(""" inv = dict(
await app.db.fetchrow(
"""
SELECT code, created_at, uses, max_uses SELECT code, created_at, uses, max_uses
FROM instance_invites FROM instance_invites
WHERE code = $1 WHERE code = $1
""", code)) """,
code,
)
)
return jsonify(dict(inv)) return jsonify(dict(inv))
@bp.route('/<invite>', methods=['DELETE']) @bp.route("/<invite>", methods=["DELETE"])
async def _del_invite(invite: str): async def _del_invite(invite: str):
await admin_check() await admin_check()
res = await app.db.execute(""" res = await app.db.execute(
"""
DELETE FROM instance_invites DELETE FROM instance_invites
WHERE code = $1 WHERE code = $1
""", invite) """,
invite,
)
if res.lower() == 'delete 0': if res.lower() == "delete 0":
return 'invite not found', 404 return "invite not found", 404
return '', 204 return "", 204

View File

@ -25,24 +25,21 @@ from litecord.schemas import validate
from litecord.admin_schemas import USER_CREATE, USER_UPDATE from litecord.admin_schemas import USER_CREATE, USER_UPDATE
from litecord.errors import BadRequest, Forbidden from litecord.errors import BadRequest, Forbidden
from litecord.utils import async_map from litecord.utils import async_map
from litecord.blueprints.users import ( from litecord.blueprints.users import delete_user, user_disconnect, mass_user_update
delete_user, user_disconnect, mass_user_update
)
from litecord.enums import UserFlags from litecord.enums import UserFlags
bp = Blueprint('users_admin', __name__) bp = Blueprint("users_admin", __name__)
@bp.route('', methods=['POST'])
@bp.route("", methods=["POST"])
async def _create_user(): async def _create_user():
await admin_check() await admin_check()
j = validate(await request.get_json(), USER_CREATE) j = validate(await request.get_json(), USER_CREATE)
user_id, _ = await create_user(j['username'], j['email'], j['password']) user_id, _ = await create_user(j["username"], j["email"], j["password"])
return jsonify( return jsonify(await app.storage.get_user(user_id))
await app.storage.get_user(user_id)
)
def args_try(args: dict, typ, field: str, default): def args_try(args: dict, typ, field: str, default):
@ -51,29 +48,29 @@ def args_try(args: dict, typ, field: str, default):
try: try:
return typ(args.get(field, default)) return typ(args.get(field, default))
except (TypeError, ValueError): except (TypeError, ValueError):
raise BadRequest(f'invalid {field} value') raise BadRequest(f"invalid {field} value")
@bp.route('', methods=['GET']) @bp.route("", methods=["GET"])
async def _search_users(): async def _search_users():
await admin_check() await admin_check()
args = request.args args = request.args
username, discrim = args.get('username'), args.get('discriminator') username, discrim = args.get("username"), args.get("discriminator")
per_page = args_try(args, int, 'per_page', 20) per_page = args_try(args, int, "per_page", 20)
page = args_try(args, int, 'page', 0) page = args_try(args, int, "page", 0)
if page < 0: if page < 0:
raise BadRequest('invalid page number') raise BadRequest("invalid page number")
if per_page > 50: if per_page > 50:
raise BadRequest('invalid per page number') raise BadRequest("invalid per page number")
# any of those must be available. # any of those must be available.
if not any((username, discrim)): if not any((username, discrim)):
raise BadRequest('must insert username or discrim') raise BadRequest("must insert username or discrim")
wheres, args = [], [] wheres, args = [], []
@ -82,29 +79,31 @@ async def _search_users():
args.append(username) args.append(username)
if discrim: if discrim:
wheres.append(f'discriminator = ${len(args) + 2}') wheres.append(f"discriminator = ${len(args) + 2}")
args.append(discrim) args.append(discrim)
where_tot = 'WHERE ' if args else '' where_tot = "WHERE " if args else ""
where_tot += ' AND '.join(wheres) where_tot += " AND ".join(wheres)
rows = await app.db.fetch(f""" rows = await app.db.fetch(
f"""
SELECT id SELECT id
FROM users FROM users
{where_tot} {where_tot}
ORDER BY id ASC ORDER BY id ASC
LIMIT {per_page} LIMIT {per_page}
OFFSET ($1 * {per_page}) OFFSET ($1 * {per_page})
""", page, *args) """,
page,
rows = [r['id'] for r in rows] *args,
return jsonify(
await async_map(app.storage.get_user, rows)
) )
rows = [r["id"] for r in rows]
@bp.route('/<int:user_id>', methods=['DELETE']) return jsonify(await async_map(app.storage.get_user, rows))
@bp.route("/<int:user_id>", methods=["DELETE"])
async def _delete_single_user(user_id: int): async def _delete_single_user(user_id: int):
await admin_check() await admin_check()
@ -115,13 +114,10 @@ async def _delete_single_user(user_id: int):
new_user = await app.storage.get_user(user_id) new_user = await app.storage.get_user(user_id)
return jsonify({ return jsonify({"old": old_user, "new": new_user})
'old': old_user,
'new': new_user
})
@bp.route('/<int:user_id>', methods=['PATCH']) @bp.route("/<int:user_id>", methods=["PATCH"])
async def patch_user(user_id: int): async def patch_user(user_id: int):
await admin_check() await admin_check()
@ -129,21 +125,25 @@ async def patch_user(user_id: int):
# get the original user for flags checking # get the original user for flags checking
user = await app.storage.get_user(user_id) user = await app.storage.get_user(user_id)
old_flags = UserFlags.from_int(user['flags']) old_flags = UserFlags.from_int(user["flags"])
# j.flags is already a UserFlags since we coerce it. # j.flags is already a UserFlags since we coerce it.
if 'flags' in j: if "flags" in j:
new_flags = j['flags'] new_flags = j["flags"]
# disallow any changes to the staff badge # disallow any changes to the staff badge
if new_flags.is_staff != old_flags.is_staff: if new_flags.is_staff != old_flags.is_staff:
raise Forbidden('you can not change a users staff badge') raise Forbidden("you can not change a users staff badge")
await app.db.execute(""" await app.db.execute(
"""
UPDATE users UPDATE users
SET flags = $1 SET flags = $1
WHERE id = $2 WHERE id = $2
""", new_flags.value, user_id) """,
new_flags.value,
user_id,
)
public_user, _ = await mass_user_update(user_id, app) public_user, _ = await mass_user_update(user_id, app)
return jsonify(public_user) return jsonify(public_user)

View File

@ -27,10 +27,10 @@ from litecord.admin_schemas import VOICE_SERVER, VOICE_REGION
from litecord.errors import BadRequest from litecord.errors import BadRequest
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('voice_admin', __name__) bp = Blueprint("voice_admin", __name__)
@bp.route('/regions/<region>', methods=['GET']) @bp.route("/regions/<region>", methods=["GET"])
async def get_region_servers(region): async def get_region_servers(region):
"""Return a list of all servers for a region.""" """Return a list of all servers for a region."""
await admin_check() await admin_check()
@ -38,18 +38,25 @@ async def get_region_servers(region):
return jsonify(servers) return jsonify(servers)
@bp.route('/regions', methods=['PUT']) @bp.route("/regions", methods=["PUT"])
async def insert_new_region(): async def insert_new_region():
"""Create a voice region.""" """Create a voice region."""
await admin_check() await admin_check()
j = validate(await request.get_json(), VOICE_REGION) j = validate(await request.get_json(), VOICE_REGION)
j['id'] = j['id'].lower() j["id"] = j["id"].lower()
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO voice_regions (id, name, vip, deprecated, custom) INSERT INTO voice_regions (id, name, vip, deprecated, custom)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
""", j['id'], j['name'], j['vip'], j['deprecated'], j['custom']) """,
j["id"],
j["name"],
j["vip"],
j["deprecated"],
j["custom"],
)
regions = await app.storage.all_voice_regions() regions = await app.storage.all_voice_regions()
region_count = len(regions) region_count = len(regions)
@ -57,34 +64,41 @@ async def insert_new_region():
# if region count is 1, this is the first region to be created, # if region count is 1, this is the first region to be created,
# so we should update all guilds to that region # so we should update all guilds to that region
if region_count == 1: if region_count == 1:
res = await app.db.execute(""" res = await app.db.execute(
"""
UPDATE guilds UPDATE guilds
SET region = $1 SET region = $1
""", j['id']) """,
j["id"],
)
log.info('updating guilds to first voice region: {}', res) log.info("updating guilds to first voice region: {}", res)
return jsonify(regions) return jsonify(regions)
@bp.route('/regions/<region>/servers', methods=['PUT']) @bp.route("/regions/<region>/servers", methods=["PUT"])
async def put_region_server(region): async def put_region_server(region):
"""Insert a voice server to a region""" """Insert a voice server to a region"""
await admin_check() await admin_check()
j = validate(await request.get_json(), VOICE_SERVER) j = validate(await request.get_json(), VOICE_SERVER)
try: try:
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO voice_servers (hostname, region_id) INSERT INTO voice_servers (hostname, region_id)
VALUES ($1, $2) VALUES ($1, $2)
""", j['hostname'], region) """,
j["hostname"],
region,
)
except asyncpg.UniqueViolationError: except asyncpg.UniqueViolationError:
raise BadRequest('voice server already exists with given hostname') raise BadRequest("voice server already exists with given hostname")
return '', 204 return "", 204
@bp.route('/regions/<region>/deprecate', methods=['PUT']) @bp.route("/regions/<region>/deprecate", methods=["PUT"])
async def deprecate_region(region): async def deprecate_region(region):
"""Deprecate a voice region.""" """Deprecate a voice region."""
await admin_check() await admin_check()
@ -92,13 +106,16 @@ async def deprecate_region(region):
# TODO: write this # TODO: write this
await app.voice.disable_region(region) await app.voice.disable_region(region)
await app.db.execute(""" await app.db.execute(
"""
UPDATE voice_regions UPDATE voice_regions
SET deprecated = true SET deprecated = true
WHERE id = $1 WHERE id = $1
""", region) """,
region,
)
return '', 204 return "", 204
async def guild_region_check(app_): async def guild_region_check(app_):
@ -112,10 +129,11 @@ async def guild_region_check(app_):
regions = await app_.storage.all_voice_regions() regions = await app_.storage.all_voice_regions()
if not regions: if not regions:
log.info('region check: no regions to move guilds to') log.info("region check: no regions to move guilds to")
return return
res = await app_.db.execute(""" res = await app_.db.execute(
"""
UPDATE guilds UPDATE guilds
SET region = ( SET region = (
SELECT id SELECT id
@ -124,6 +142,8 @@ async def guild_region_check(app_):
LIMIT 1 LIMIT 1
) )
WHERE region = NULL WHERE region = NULL
""", len(regions)) """,
len(regions),
)
log.info('region check: updating guild.region=null: {!r}', res) log.info("region check: updating guild.region=null: {!r}", res)

View File

@ -24,16 +24,17 @@ from PIL import Image
from litecord.images import resize_gif from litecord.images import resize_gif
bp = Blueprint('attachments', __name__) bp = Blueprint("attachments", __name__)
ATTACHMENTS = Path.cwd() / 'attachments' ATTACHMENTS = Path.cwd() / "attachments"
async def _resize_gif(attach_id: int, resized_path: Path, async def _resize_gif(
width: int, height: int) -> str: attach_id: int, resized_path: Path, width: int, height: int
) -> str:
"""Resize a GIF attachment.""" """Resize a GIF attachment."""
# get original gif bytes # get original gif bytes
orig_path = ATTACHMENTS / f'{attach_id}.gif' orig_path = ATTACHMENTS / f"{attach_id}.gif"
orig_bytes = orig_path.read_bytes() orig_bytes = orig_path.read_bytes()
# give them and the target size to the # give them and the target size to the
@ -47,10 +48,7 @@ async def _resize_gif(attach_id: int, resized_path: Path,
return str(resized_path) return str(resized_path)
FORMAT_HARDCODE = { FORMAT_HARDCODE = {"jpg": "jpeg", "jpe": "jpeg"}
'jpg': 'jpeg',
'jpe': 'jpeg'
}
def to_format(ext: str) -> str: def to_format(ext: str) -> str:
@ -63,11 +61,10 @@ def to_format(ext: str) -> str:
return ext return ext
async def _resize(image, attach_id: int, ext: str, async def _resize(image, attach_id: int, ext: str, width: int, height: int) -> str:
width: int, height: int) -> str:
"""Resize an image.""" """Resize an image."""
# check if we have it on the folder # check if we have it on the folder
resized_path = ATTACHMENTS / f'{attach_id}_{width}_{height}.{ext}' resized_path = ATTACHMENTS / f"{attach_id}_{width}_{height}.{ext}"
# keep a str-fied instance since that is what # keep a str-fied instance since that is what
# we'll return. # we'll return.
@ -81,7 +78,7 @@ async def _resize(image, attach_id: int, ext: str,
# the process is different for gif files because we need # the process is different for gif files because we need
# gifsicle. doing it manually is too troublesome. # gifsicle. doing it manually is too troublesome.
if ext == 'gif': if ext == "gif":
return await _resize_gif(attach_id, resized_path, width, height) return await _resize_gif(attach_id, resized_path, width, height)
# NOTE: this is the same resize mode for icons. # NOTE: this is the same resize mode for icons.
@ -91,38 +88,42 @@ async def _resize(image, attach_id: int, ext: str,
return resized_path_s return resized_path_s
@bp.route('/attachments' @bp.route(
'/<int:channel_id>/<int:message_id>/<filename>', "/attachments" "/<int:channel_id>/<int:message_id>/<filename>", methods=["GET"]
methods=['GET']) )
async def _get_attachment(channel_id: int, message_id: int, async def _get_attachment(channel_id: int, message_id: int, filename: str):
filename: str):
attach_id = await app.db.fetchval(""" attach_id = await app.db.fetchval(
"""
SELECT id SELECT id
FROM attachments FROM attachments
WHERE channel_id = $1 WHERE channel_id = $1
AND message_id = $2 AND message_id = $2
AND filename = $3 AND filename = $3
""", channel_id, message_id, filename) """,
channel_id,
message_id,
filename,
)
if attach_id is None: if attach_id is None:
return '', 404 return "", 404
ext = filename.split('.')[-1] ext = filename.split(".")[-1]
filepath = f'./attachments/{attach_id}.{ext}' filepath = f"./attachments/{attach_id}.{ext}"
image = Image.open(filepath) image = Image.open(filepath)
im_width, im_height = image.size im_width, im_height = image.size
try: try:
width = int(request.args.get('width', 0)) or im_width width = int(request.args.get("width", 0)) or im_width
except ValueError: except ValueError:
return '', 400 return "", 400
try: try:
height = int(request.args.get('height', 0)) or im_height height = int(request.args.get("height", 0)) or im_height
except ValueError: except ValueError:
return '', 400 return "", 400
# if width and height are the same (happens if they weren't provided) # if width and height are the same (happens if they weren't provided)
if width == im_width and height == im_height: if width == im_width and height == im_height:

View File

@ -32,7 +32,7 @@ from litecord.snowflake import get_snowflake
from .invites import use_invite from .invites import use_invite
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('auth', __name__) bp = Blueprint("auth", __name__)
async def check_password(pwd_hash: str, given_password: str) -> bool: async def check_password(pwd_hash: str, given_password: str) -> bool:
@ -53,141 +53,139 @@ def make_token(user_id, user_pwd_hash) -> str:
return signer.sign(user_id).decode() return signer.sign(user_id).decode()
@bp.route('/register', methods=['POST']) @bp.route("/register", methods=["POST"])
async def register(): async def register():
"""Register a single user.""" """Register a single user."""
enabled = app.config.get('REGISTRATIONS') enabled = app.config.get("REGISTRATIONS")
if not enabled: if not enabled:
raise BadRequest('Registrations disabled', { raise BadRequest(
'email': 'Registrations are disabled.' "Registrations disabled", {"email": "Registrations are disabled."}
}) )
j = await request.get_json() j = await request.get_json()
if not 'password' in j: if not "password" in j:
# we need a password to generate a token. # we need a password to generate a token.
# passwords are optional, so # passwords are optional, so
j['password'] = 'default_password' j["password"] = "default_password"
j = validate(j, REGISTER) j = validate(j, REGISTER)
# they're optional # they're optional
email = j.get('email') email = j.get("email")
invite = j.get('invite') invite = j.get("invite")
username, password = j['username'], j['password'] username, password = j["username"], j["password"]
new_id, pwd_hash = await create_user( new_id, pwd_hash = await create_user(username, email, password, app.db)
username, email, password, app.db
)
if invite: if invite:
try: try:
await use_invite(new_id, invite) await use_invite(new_id, invite)
except Exception: except Exception:
log.exception('failed to use invite for register {} {!r}', log.exception("failed to use invite for register {} {!r}", new_id, invite)
new_id, invite)
return jsonify({ return jsonify({"token": make_token(new_id, pwd_hash)})
'token': make_token(new_id, pwd_hash)
})
@bp.route('/register_inv', methods=['POST']) @bp.route("/register_inv", methods=["POST"])
async def _register_with_invite(): async def _register_with_invite():
data = await request.form data = await request.form
data = validate(await request.form, REGISTER_WITH_INVITE) data = validate(await request.form, REGISTER_WITH_INVITE)
invcode = data['invcode'] invcode = data["invcode"]
row = await app.db.fetchrow(""" row = await app.db.fetchrow(
"""
SELECT uses, max_uses SELECT uses, max_uses
FROM instance_invites FROM instance_invites
WHERE code = $1 WHERE code = $1
""", invcode) """,
invcode,
)
if row is None: if row is None:
raise BadRequest('unknown instance invite') raise BadRequest("unknown instance invite")
if row['max_uses'] != -1 and row['uses'] >= row['max_uses']: if row["max_uses"] != -1 and row["uses"] >= row["max_uses"]:
raise BadRequest('invite expired') raise BadRequest("invite expired")
await app.db.execute(""" await app.db.execute(
"""
UPDATE instance_invites UPDATE instance_invites
SET uses = uses + 1 SET uses = uses + 1
WHERE code = $1 WHERE code = $1
""", invcode) """,
invcode,
)
user_id, pwd_hash = await create_user( user_id, pwd_hash = await create_user(
data['username'], data['email'], data['password'], app.db) data["username"], data["email"], data["password"], app.db
)
return jsonify({ return jsonify({"token": make_token(user_id, pwd_hash), "user_id": str(user_id)})
'token': make_token(user_id, pwd_hash),
'user_id': str(user_id),
})
@bp.route('/login', methods=['POST']) @bp.route("/login", methods=["POST"])
async def login(): async def login():
j = await request.get_json() j = await request.get_json()
email, password = j['email'], j['password'] email, password = j["email"], j["password"]
row = await app.db.fetchrow(""" row = await app.db.fetchrow(
"""
SELECT id, password_hash SELECT id, password_hash
FROM users FROM users
WHERE email = $1 WHERE email = $1
""", email) """,
email,
)
if not row: if not row:
return jsonify({'email': ['User not found.']}), 401 return jsonify({"email": ["User not found."]}), 401
user_id, pwd_hash = row user_id, pwd_hash = row
if not await check_password(pwd_hash, password): if not await check_password(pwd_hash, password):
return jsonify({'password': ['Password does not match.']}), 401 return jsonify({"password": ["Password does not match."]}), 401
return jsonify({ return jsonify({"token": make_token(user_id, pwd_hash)})
'token': make_token(user_id, pwd_hash)
})
@bp.route('/consent-required', methods=['GET']) @bp.route("/consent-required", methods=["GET"])
async def consent_required(): async def consent_required():
return jsonify({ return jsonify({"required": True})
'required': True,
})
@bp.route('/verify/resend', methods=['POST']) @bp.route("/verify/resend", methods=["POST"])
async def verify_user(): async def verify_user():
user_id = await token_check() user_id = await token_check()
# TODO: actually verify a user by sending an email # TODO: actually verify a user by sending an email
await app.db.execute(""" await app.db.execute(
"""
UPDATE users UPDATE users
SET verified = true SET verified = true
WHERE id = $1 WHERE id = $1
""", user_id) """,
user_id,
)
new_user = await app.storage.get_user(user_id, True) new_user = await app.storage.get_user(user_id, True)
await app.dispatcher.dispatch_user( await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", new_user)
user_id, 'USER_UPDATE', new_user)
return '', 204 return "", 204
@bp.route('/logout', methods=['POST']) @bp.route("/logout", methods=["POST"])
async def _logout(): async def _logout():
"""Called by the client to logout.""" """Called by the client to logout."""
return '', 204 return "", 204
@bp.route('/fingerprint', methods=['POST']) @bp.route("/fingerprint", methods=["POST"])
async def _fingerprint(): async def _fingerprint():
"""No idea what this route is about.""" """No idea what this route is about."""
fingerprint_id = get_snowflake() fingerprint_id = get_snowflake()
fingerprint = f'{fingerprint_id}.{secrets.token_urlsafe(32)}' fingerprint = f"{fingerprint_id}.{secrets.token_urlsafe(32)}"
return jsonify({ return jsonify({"fingerprint": fingerprint})
'fingerprint': fingerprint
})

View File

@ -21,4 +21,4 @@ from .messages import bp as channel_messages
from .reactions import bp as channel_reactions from .reactions import bp as channel_reactions
from .pins import bp as channel_pins from .pins import bp as channel_pins
__all__ = ['channel_messages', 'channel_reactions', 'channel_pins'] __all__ = ["channel_messages", "channel_reactions", "channel_pins"]

View File

@ -30,13 +30,18 @@ class ForbiddenDM(Forbidden):
async def dm_pre_check(user_id: int, channel_id: int, peer_id: int): async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
"""Check if the user can DM the peer.""" """Check if the user can DM the peer."""
# first step is checking if there is a block in any direction # first step is checking if there is a block in any direction
blockrow = await app.db.fetchrow(""" blockrow = await app.db.fetchrow(
"""
SELECT rel_type SELECT rel_type
FROM relationships FROM relationships
WHERE rel_type = $3 WHERE rel_type = $3
AND user_id IN ($1, $2) AND user_id IN ($1, $2)
AND peer_id IN ($1, $2) AND peer_id IN ($1, $2)
""", user_id, peer_id, RelationshipType.BLOCK.value) """,
user_id,
peer_id,
RelationshipType.BLOCK.value,
)
if blockrow is not None: if blockrow is not None:
raise ForbiddenDM() raise ForbiddenDM()
@ -58,8 +63,8 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
user_settings = await app.user_storage.get_user_settings(user_id) user_settings = await app.user_storage.get_user_settings(user_id)
peer_settings = await app.user_storage.get_user_settings(peer_id) peer_settings = await app.user_storage.get_user_settings(peer_id)
restricted_user_ = [int(v) for v in user_settings['restricted_guilds']] restricted_user_ = [int(v) for v in user_settings["restricted_guilds"]]
restricted_peer_ = [int(v) for v in peer_settings['restricted_guilds']] restricted_peer_ = [int(v) for v in peer_settings["restricted_guilds"]]
restricted_user = set(restricted_user_) restricted_user = set(restricted_user_)
restricted_peer = set(restricted_peer_) restricted_peer = set(restricted_peer_)

View File

@ -41,18 +41,18 @@ from litecord.images import try_unlink
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('channel_messages', __name__) bp = Blueprint("channel_messages", __name__)
def extract_limit(request_, default: int = 50, max_val: int = 100): def extract_limit(request_, default: int = 50, max_val: int = 100):
"""Extract a limit kwarg.""" """Extract a limit kwarg."""
try: try:
limit = int(request_.args.get('limit', default)) limit = int(request_.args.get("limit", default))
if limit not in range(0, max_val + 1): if limit not in range(0, max_val + 1):
raise ValueError() raise ValueError()
except (TypeError, ValueError): except (TypeError, ValueError):
raise BadRequest('limit not int') raise BadRequest("limit not int")
return limit return limit
@ -61,27 +61,27 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple:
"""Extract a 2-tuple out of request arguments.""" """Extract a 2-tuple out of request arguments."""
before, after = None, None before, after = None, None
if 'around' in request.args: if "around" in request.args:
average = int(limit / 2) average = int(limit / 2)
around = int(args['around']) around = int(args["around"])
after = around - average after = around - average
before = around + average before = around + average
elif 'before' in args: elif "before" in args:
before = int(args['before']) before = int(args["before"])
elif 'after' in args: elif "after" in args:
before = int(args['after']) before = int(args["after"])
return before, after return before, after
@bp.route('/<int:channel_id>/messages', methods=['GET']) @bp.route("/<int:channel_id>/messages", methods=["GET"])
async def get_messages(channel_id): async def get_messages(channel_id):
user_id = await token_check() user_id = await token_check()
ctype, peer_id = await channel_check(user_id, channel_id) ctype, peer_id = await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'read_history') await channel_perm_check(user_id, channel_id, "read_history")
if ctype == ChannelType.DM: if ctype == ChannelType.DM:
# make sure both parties will be subbed # make sure both parties will be subbed
@ -91,42 +91,45 @@ async def get_messages(channel_id):
limit = extract_limit(request, 50) limit = extract_limit(request, 50)
where_clause = '' where_clause = ""
before, after = query_tuple_from_args(request.args, limit) before, after = query_tuple_from_args(request.args, limit)
if before: if before:
where_clause += f'AND id < {before}' where_clause += f"AND id < {before}"
if after: if after:
where_clause += f'AND id > {after}' where_clause += f"AND id > {after}"
message_ids = await app.db.fetch(f""" message_ids = await app.db.fetch(
f"""
SELECT id SELECT id
FROM messages FROM messages
WHERE channel_id = $1 {where_clause} WHERE channel_id = $1 {where_clause}
ORDER BY id DESC ORDER BY id DESC
LIMIT {limit} LIMIT {limit}
""", channel_id) """,
channel_id,
)
result = [] result = []
for message_id in message_ids: for message_id in message_ids:
msg = await app.storage.get_message(message_id['id'], user_id) msg = await app.storage.get_message(message_id["id"], user_id)
if msg is None: if msg is None:
continue continue
result.append(msg) result.append(msg)
log.info('Fetched {} messages', len(result)) log.info("Fetched {} messages", len(result))
return jsonify(result) return jsonify(result)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET']) @bp.route("/<int:channel_id>/messages/<int:message_id>", methods=["GET"])
async def get_single_message(channel_id, message_id): async def get_single_message(channel_id, message_id):
user_id = await token_check() user_id = await token_check()
await channel_check(user_id, channel_id) await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'read_history') await channel_perm_check(user_id, channel_id, "read_history")
message = await app.storage.get_message(message_id, user_id) message = await app.storage.get_message(message_id, user_id)
@ -142,11 +145,15 @@ async def _dm_pre_dispatch(channel_id, peer_id):
# check the other party's dm_channel_state # check the other party's dm_channel_state
dm_state = await app.db.fetchval(""" dm_state = await app.db.fetchval(
"""
SELECT dm_id SELECT dm_id
FROM dm_channel_state FROM dm_channel_state
WHERE user_id = $1 AND dm_id = $2 WHERE user_id = $1 AND dm_id = $2
""", peer_id, channel_id) """,
peer_id,
channel_id,
)
if dm_state: if dm_state:
# the peer already has the channel # the peer already has the channel
@ -157,18 +164,19 @@ async def _dm_pre_dispatch(channel_id, peer_id):
# dispatch CHANNEL_CREATE so the client knows which # dispatch CHANNEL_CREATE so the client knows which
# channel the future event is about # channel the future event is about
await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan) await app.dispatcher.dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan)
# subscribe the peer to the channel # subscribe the peer to the channel
await app.dispatcher.sub('channel', channel_id, peer_id) await app.dispatcher.sub("channel", channel_id, peer_id)
# insert it on dm_channel_state so the client # insert it on dm_channel_state so the client
# is subscribed on the future # is subscribed on the future
await try_dm_state(peer_id, channel_id) await try_dm_state(peer_id, channel_id)
async def create_message(channel_id: int, actual_guild_id: int, async def create_message(
author_id: int, data: dict) -> int: channel_id: int, actual_guild_id: int, author_id: int, data: dict
) -> int:
message_id = get_snowflake() message_id = get_snowflake()
async with app.db.acquire() as conn: async with app.db.acquire() as conn:
@ -185,32 +193,32 @@ async def create_message(channel_id: int, actual_guild_id: int,
channel_id, channel_id,
actual_guild_id, actual_guild_id,
author_id, author_id,
data['content'], data["content"],
data["tts"],
data['tts'], data["everyone_mention"],
data['everyone_mention'], data["nonce"],
data['nonce'],
MessageType.DEFAULT.value, MessageType.DEFAULT.value,
data.get('embeds') or [] data.get("embeds") or [],
) )
return message_id return message_id
async def msg_guild_text_mentions(payload: dict, guild_id: int,
mentions_everyone: bool, mentions_here: bool): async def msg_guild_text_mentions(
payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
):
"""Calculates mention data side-effects.""" """Calculates mention data side-effects."""
channel_id = int(payload['channel_id']) channel_id = int(payload["channel_id"])
# calculate the user ids we'll bump the mention count for # calculate the user ids we'll bump the mention count for
uids = set() uids = set()
# first is extracting user mentions # first is extracting user mentions
for mention in payload['mentions']: for mention in payload["mentions"]:
uids.add(int(mention['id'])) uids.add(int(mention["id"]))
# then role mentions # then role mentions
for role_mention in payload['mention_roles']: for role_mention in payload["mention_roles"]:
role_id = int(role_mention) role_id = int(role_mention)
member_ids = await app.storage.get_role_members(role_id) member_ids = await app.storage.get_role_members(role_id)
@ -223,11 +231,14 @@ async def msg_guild_text_mentions(payload: dict, guild_id: int,
if mentions_here: if mentions_here:
uids = set() uids = set()
await app.db.execute(""" await app.db.execute(
"""
UPDATE user_read_state UPDATE user_read_state
SET mention_count = mention_count + 1 SET mention_count = mention_count + 1
WHERE channel_id = $1 WHERE channel_id = $1
""", channel_id) """,
channel_id,
)
# at-here updates the read state # at-here updates the read state
# for all users, including the ones # for all users, including the ones
@ -238,19 +249,26 @@ async def msg_guild_text_mentions(payload: dict, guild_id: int,
member_ids = await app.storage.get_member_ids(guild_id) member_ids = await app.storage.get_member_ids(guild_id)
await app.db.executemany(""" await app.db.executemany(
"""
UPDATE user_read_state UPDATE user_read_state
SET mention_count = mention_count + 1 SET mention_count = mention_count + 1
WHERE channel_id = $1 AND user_id = $2 WHERE channel_id = $1 AND user_id = $2
""", [(channel_id, uid) for uid in member_ids]) """,
[(channel_id, uid) for uid in member_ids],
)
for user_id in uids: for user_id in uids:
await app.db.execute(""" await app.db.execute(
"""
UPDATE user_read_state UPDATE user_read_state
SET mention_count = mention_count + 1 SET mention_count = mention_count + 1
WHERE user_id = $1 WHERE user_id = $1
AND channel_id = $2 AND channel_id = $2
""", user_id, channel_id) """,
user_id,
channel_id,
)
async def msg_create_request() -> tuple: async def msg_create_request() -> tuple:
@ -264,12 +282,12 @@ async def msg_create_request() -> tuple:
# NOTE: embed isn't set on form data # NOTE: embed isn't set on form data
json_from_form = { json_from_form = {
'content': form.get('content', ''), "content": form.get("content", ""),
'nonce': form.get('nonce', '0'), "nonce": form.get("nonce", "0"),
'tts': json.loads(form.get('tts', 'false')), "tts": json.loads(form.get("tts", "false")),
} }
payload_json = json.loads(form.get('payload_json', '{}')) payload_json = json.loads(form.get("payload_json", "{}"))
json_from_form.update(request_json) json_from_form.update(request_json)
json_from_form.update(payload_json) json_from_form.update(payload_json)
@ -283,20 +301,19 @@ async def msg_create_request() -> tuple:
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False): def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
"""Check if there is actually any content being sent to us.""" """Check if there is actually any content being sent to us."""
has_content = bool(payload.get('content', '')) has_content = bool(payload.get("content", ""))
has_files = len(files) > 0 has_files = len(files) > 0
embed_field = 'embeds' if use_embeds else 'embed' embed_field = "embeds" if use_embeds else "embed"
has_embed = embed_field in payload and payload.get(embed_field) is not None has_embed = embed_field in payload and payload.get(embed_field) is not None
has_total_content = has_content or has_embed or has_files has_total_content = has_content or has_embed or has_files
if not has_total_content: if not has_total_content:
raise BadRequest('No content has been provided.') raise BadRequest("No content has been provided.")
async def msg_add_attachment(message_id: int, channel_id: int, async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
attachment_file) -> int:
"""Add an attachment to a message. """Add an attachment to a message.
Parameters Parameters
@ -318,7 +335,7 @@ async def msg_add_attachment(message_id: int, channel_id: int,
# understand file info # understand file info
mime = attachment_file.mimetype mime = attachment_file.mimetype
is_image = mime.startswith('image/') is_image = mime.startswith("image/")
img_width, img_height = None, None img_width, img_height = None, None
@ -346,17 +363,22 @@ async def msg_add_attachment(message_id: int, channel_id: int,
VALUES VALUES
($1, $2, $3, $4, $5, $6, $7, $8) ($1, $2, $3, $4, $5, $6, $7, $8)
""", """,
attachment_id, channel_id, message_id, attachment_id,
filename, file_size, channel_id,
is_image, img_width, img_height) message_id,
filename,
file_size,
is_image,
img_width,
img_height,
)
ext = filename.split('.')[-1] ext = filename.split(".")[-1]
with open(f'attachments/{attachment_id}.{ext}', 'wb') as attach_file: with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
attach_file.write(attachment_file.stream.read()) attach_file.write(attachment_file.stream.read())
log.debug('written {} bytes for attachment id {}', log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
file_size, attachment_id)
return attachment_id return attachment_id
@ -364,12 +386,12 @@ async def msg_add_attachment(message_id: int, channel_id: int,
async def _spawn_embed(app_, payload, **kwargs): async def _spawn_embed(app_, payload, **kwargs):
app_.sched.spawn( app_.sched.spawn(
process_url_embed( process_url_embed(
app_.config, app_.storage, app_.dispatcher, app_.session, app_.config, app_.storage, app_.dispatcher, app_.session, payload, **kwargs
payload, **kwargs) )
) )
@bp.route('/<int:channel_id>/messages', methods=['POST']) @bp.route("/<int:channel_id>/messages", methods=["POST"])
async def _create_message(channel_id): async def _create_message(channel_id):
"""Create a message.""" """Create a message."""
@ -379,7 +401,7 @@ async def _create_message(channel_id):
actual_guild_id = None actual_guild_id = None
if ctype in GUILD_CHANS: if ctype in GUILD_CHANS:
await channel_perm_check(user_id, channel_id, 'send_messages') await channel_perm_check(user_id, channel_id, "send_messages")
actual_guild_id = guild_id actual_guild_id = guild_id
payload_json, files = await msg_create_request() payload_json, files = await msg_create_request()
@ -394,29 +416,31 @@ async def _create_message(channel_id):
await dm_pre_check(user_id, channel_id, guild_id) await dm_pre_check(user_id, channel_id, guild_id)
can_everyone = await channel_perm_check( can_everyone = await channel_perm_check(
user_id, channel_id, 'mention_everyone', False user_id, channel_id, "mention_everyone", False
) )
mentions_everyone = ('@everyone' in j['content']) and can_everyone mentions_everyone = ("@everyone" in j["content"]) and can_everyone
mentions_here = ('@here' in j['content']) and can_everyone mentions_here = ("@here" in j["content"]) and can_everyone
is_tts = (j.get('tts', False) and is_tts = j.get("tts", False) and await channel_perm_check(
await channel_perm_check( user_id, channel_id, "send_tts_messages", False
user_id, channel_id, 'send_tts_messages', False )
))
message_id = await create_message( message_id = await create_message(
channel_id, actual_guild_id, user_id, { channel_id,
'content': j['content'], actual_guild_id,
'tts': is_tts, user_id,
'nonce': int(j.get('nonce', 0)), {
'everyone_mention': mentions_everyone or mentions_here, "content": j["content"],
"tts": is_tts,
"nonce": int(j.get("nonce", 0)),
"everyone_mention": mentions_everyone or mentions_here,
# fill_embed takes care of filling proxy and width/height # fill_embed takes care of filling proxy and width/height
'embeds': ([await fill_embed(j['embed'])] "embeds": (
if j.get('embed') is not None [await fill_embed(j["embed"])] if j.get("embed") is not None else []
else []), ),
}) },
)
# for each file given, we add it as an attachment # for each file given, we add it as an attachment
for pre_attachment in files: for pre_attachment in files:
@ -429,8 +453,7 @@ async def _create_message(channel_id):
await _dm_pre_dispatch(channel_id, user_id) await _dm_pre_dispatch(channel_id, user_id)
await _dm_pre_dispatch(channel_id, guild_id) await _dm_pre_dispatch(channel_id, guild_id)
await app.dispatcher.dispatch('channel', channel_id, await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
'MESSAGE_CREATE', payload)
# spawn url processor for embedding of images # spawn url processor for embedding of images
perms = await get_permissions(user_id, channel_id) perms = await get_permissions(user_id, channel_id)
@ -438,54 +461,71 @@ async def _create_message(channel_id):
await _spawn_embed(app, payload) await _spawn_embed(app, payload)
# update read state for the author # update read state for the author
await app.db.execute(""" await app.db.execute(
"""
UPDATE user_read_state UPDATE user_read_state
SET last_message_id = $1 SET last_message_id = $1
WHERE channel_id = $2 AND user_id = $3 WHERE channel_id = $2 AND user_id = $3
""", message_id, channel_id, user_id) """,
message_id,
channel_id,
user_id,
)
if ctype == ChannelType.GUILD_TEXT: if ctype == ChannelType.GUILD_TEXT:
await msg_guild_text_mentions( await msg_guild_text_mentions(
payload, guild_id, mentions_everyone, mentions_here) payload, guild_id, mentions_everyone, mentions_here
)
return jsonify(payload) return jsonify(payload)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['PATCH']) @bp.route("/<int:channel_id>/messages/<int:message_id>", methods=["PATCH"])
async def edit_message(channel_id, message_id): async def edit_message(channel_id, message_id):
user_id = await token_check() user_id = await token_check()
_ctype, _guild_id = await channel_check(user_id, channel_id) _ctype, _guild_id = await channel_check(user_id, channel_id)
author_id = await app.db.fetchval(""" author_id = await app.db.fetchval(
"""
SELECT author_id FROM messages SELECT author_id FROM messages
WHERE messages.id = $1 WHERE messages.id = $1
""", message_id) """,
message_id,
)
if not author_id == user_id: if not author_id == user_id:
raise Forbidden('You can not edit this message') raise Forbidden("You can not edit this message")
j = await request.get_json() j = await request.get_json()
updated = 'content' in j or 'embed' in j updated = "content" in j or "embed" in j
old_message = await app.storage.get_message(message_id) old_message = await app.storage.get_message(message_id)
if 'content' in j: if "content" in j:
await app.db.execute(""" await app.db.execute(
"""
UPDATE messages UPDATE messages
SET content=$1 SET content=$1
WHERE messages.id = $2 WHERE messages.id = $2
""", j['content'], message_id) """,
j["content"],
message_id,
)
if 'embed' in j: if "embed" in j:
embeds = [await fill_embed(j['embed'])] embeds = [await fill_embed(j["embed"])]
await app.db.execute(""" await app.db.execute(
"""
UPDATE messages UPDATE messages
SET embeds=$1 SET embeds=$1
WHERE messages.id = $2 WHERE messages.id = $2
""", embeds, message_id) """,
embeds,
message_id,
)
# do not spawn process_url_embed since we already have embeds. # do not spawn process_url_embed since we already have embeds.
elif 'content' in j: elif "content" in j:
# if there weren't any embed changes BUT # if there weren't any embed changes BUT
# we had a content change, we dispatch process_url_embed but with # we had a content change, we dispatch process_url_embed but with
# an artificial delay. # an artificial delay.
@ -495,46 +535,55 @@ async def edit_message(channel_id, message_id):
# BEFORE the MESSAGE_UPDATE with the new embeds (based on content) # BEFORE the MESSAGE_UPDATE with the new embeds (based on content)
perms = await get_permissions(user_id, channel_id) perms = await get_permissions(user_id, channel_id)
if perms.bits.embed_links: if perms.bits.embed_links:
await _spawn_embed(app, { await _spawn_embed(
'id': message_id, app,
'channel_id': channel_id, {
'content': j['content'], "id": message_id,
'embeds': old_message['embeds'] "channel_id": channel_id,
}, delay=0.2) "content": j["content"],
"embeds": old_message["embeds"],
},
delay=0.2,
)
# only set new timestamp upon actual update # only set new timestamp upon actual update
if updated: if updated:
await app.db.execute(""" await app.db.execute(
"""
UPDATE messages UPDATE messages
SET edited_at = (now() at time zone 'utc') SET edited_at = (now() at time zone 'utc')
WHERE id = $1 WHERE id = $1
""", message_id) """,
message_id,
)
message = await app.storage.get_message(message_id, user_id) message = await app.storage.get_message(message_id, user_id)
# only dispatch MESSAGE_UPDATE if any update # only dispatch MESSAGE_UPDATE if any update
# actually happened # actually happened
if updated: if updated:
await app.dispatcher.dispatch('channel', channel_id, await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", message)
'MESSAGE_UPDATE', message)
return jsonify(message) return jsonify(message)
async def _del_msg_fkeys(message_id: int): async def _del_msg_fkeys(message_id: int):
attachs = await app.db.fetch(""" attachs = await app.db.fetch(
"""
SELECT id FROM attachments SELECT id FROM attachments
WHERE message_id = $1 WHERE message_id = $1
""", message_id) """,
message_id,
)
attachs = [r['id'] for r in attachs] attachs = [r["id"] for r in attachs]
attachments = Path('./attachments') attachments = Path("./attachments")
for attach_id in attachs: for attach_id in attachs:
# anything starting with the given attachment shall be # anything starting with the given attachment shall be
# deleted, because there may be resizes of the original # deleted, because there may be resizes of the original
# attachment laying around. # attachment laying around.
for filepath in attachments.glob(f'{attach_id}*'): for filepath in attachments.glob(f"{attach_id}*"):
try_unlink(filepath) try_unlink(filepath)
# after trying to delete all available attachments, delete # after trying to delete all available attachments, delete
@ -542,51 +591,64 @@ async def _del_msg_fkeys(message_id: int):
# take the chance and delete all the data from the other tables too! # take the chance and delete all the data from the other tables too!
tables = ['attachments', 'message_webhook_info', tables = [
'message_reactions', 'channel_pins'] "attachments",
"message_webhook_info",
"message_reactions",
"channel_pins",
]
for table in tables: for table in tables:
await app.db.execute(f""" await app.db.execute(
f"""
DELETE FROM {table} DELETE FROM {table}
WHERE message_id = $1 WHERE message_id = $1
""", message_id) """,
message_id,
)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['DELETE']) @bp.route("/<int:channel_id>/messages/<int:message_id>", methods=["DELETE"])
async def delete_message(channel_id, message_id): async def delete_message(channel_id, message_id):
user_id = await token_check() user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id) _ctype, guild_id = await channel_check(user_id, channel_id)
author_id = await app.db.fetchval(""" author_id = await app.db.fetchval(
"""
SELECT author_id FROM messages SELECT author_id FROM messages
WHERE messages.id = $1 WHERE messages.id = $1
""", message_id) """,
message_id,
by_perm = await channel_perm_check(
user_id, channel_id, 'manage_messages', False
) )
by_perm = await channel_perm_check(user_id, channel_id, "manage_messages", False)
by_ownership = author_id == user_id by_ownership = author_id == user_id
can_delete = by_perm or by_ownership can_delete = by_perm or by_ownership
if not can_delete: if not can_delete:
raise Forbidden('You can not delete this message') raise Forbidden("You can not delete this message")
await _del_msg_fkeys(message_id) await _del_msg_fkeys(message_id)
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM messages DELETE FROM messages
WHERE messages.id = $1 WHERE messages.id = $1
""", message_id) """,
message_id,
)
await app.dispatcher.dispatch( await app.dispatcher.dispatch(
'channel', channel_id, "channel",
'MESSAGE_DELETE', { channel_id,
'id': str(message_id), "MESSAGE_DELETE",
'channel_id': str(channel_id), {
"id": str(message_id),
"channel_id": str(channel_id),
# for lazy guilds # for lazy guilds
'guild_id': str(guild_id), "guild_id": str(guild_id),
}) },
)
return '', 204 return "", 204

View File

@ -28,28 +28,32 @@ from litecord.system_messages import send_sys_message
from litecord.enums import MessageType, SYS_MESSAGES from litecord.enums import MessageType, SYS_MESSAGES
from litecord.errors import BadRequest from litecord.errors import BadRequest
bp = Blueprint('channel_pins', __name__) bp = Blueprint("channel_pins", __name__)
class SysMsgInvalidAction(BadRequest): class SysMsgInvalidAction(BadRequest):
"""Invalid action on a system message.""" """Invalid action on a system message."""
error_code = 50021 error_code = 50021
@bp.route('/<int:channel_id>/pins', methods=['GET']) @bp.route("/<int:channel_id>/pins", methods=["GET"])
async def get_pins(channel_id): async def get_pins(channel_id):
"""Get the pins for a channel""" """Get the pins for a channel"""
user_id = await token_check() user_id = await token_check()
await channel_check(user_id, channel_id) await channel_check(user_id, channel_id)
ids = await app.db.fetch(""" ids = await app.db.fetch(
"""
SELECT message_id SELECT message_id
FROM channel_pins FROM channel_pins
WHERE channel_id = $1 WHERE channel_id = $1
ORDER BY message_id DESC ORDER BY message_id DESC
""", channel_id) """,
channel_id,
)
ids = [r['message_id'] for r in ids] ids = [r["message_id"] for r in ids]
res = [] res = []
for message_id in ids: for message_id in ids:
@ -60,80 +64,96 @@ async def get_pins(channel_id):
return jsonify(res) return jsonify(res)
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['PUT']) @bp.route("/<int:channel_id>/pins/<int:message_id>", methods=["PUT"])
async def add_pin(channel_id, message_id): async def add_pin(channel_id, message_id):
"""Add a pin to a channel""" """Add a pin to a channel"""
user_id = await token_check() user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id) _ctype, guild_id = await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'manage_messages') await channel_perm_check(user_id, channel_id, "manage_messages")
mtype = await app.db.fetchval(""" mtype = await app.db.fetchval(
"""
SELECT message_type SELECT message_type
FROM messages FROM messages
WHERE id = $1 WHERE id = $1
""", message_id) """,
message_id,
)
if mtype in SYS_MESSAGES: if mtype in SYS_MESSAGES:
raise SysMsgInvalidAction( raise SysMsgInvalidAction("Cannot execute action on a system message")
'Cannot execute action on a system message')
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO channel_pins (channel_id, message_id) INSERT INTO channel_pins (channel_id, message_id)
VALUES ($1, $2) VALUES ($1, $2)
""", channel_id, message_id) """,
channel_id,
message_id,
)
row = await app.db.fetchrow(""" row = await app.db.fetchrow(
"""
SELECT message_id SELECT message_id
FROM channel_pins FROM channel_pins
WHERE channel_id = $1 WHERE channel_id = $1
ORDER BY message_id ASC ORDER BY message_id ASC
LIMIT 1 LIMIT 1
""", channel_id) """,
channel_id,
timestamp = snowflake_datetime(row['message_id'])
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_PINS_UPDATE',
{
'channel_id': str(channel_id),
'last_pin_timestamp': timestamp_(timestamp)
}
) )
await send_sys_message(app, channel_id, timestamp = snowflake_datetime(row["message_id"])
MessageType.CHANNEL_PINNED_MESSAGE,
message_id, user_id)
return '', 204 await app.dispatcher.dispatch(
"channel",
channel_id,
"CHANNEL_PINS_UPDATE",
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp_(timestamp)},
)
await send_sys_message(
app, channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id
)
return "", 204
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['DELETE']) @bp.route("/<int:channel_id>/pins/<int:message_id>", methods=["DELETE"])
async def delete_pin(channel_id, message_id): async def delete_pin(channel_id, message_id):
user_id = await token_check() user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id) _ctype, guild_id = await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'manage_messages') await channel_perm_check(user_id, channel_id, "manage_messages")
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM channel_pins DELETE FROM channel_pins
WHERE channel_id = $1 AND message_id = $2 WHERE channel_id = $1 AND message_id = $2
""", channel_id, message_id) """,
channel_id,
message_id,
)
row = await app.db.fetchrow(""" row = await app.db.fetchrow(
"""
SELECT message_id SELECT message_id
FROM channel_pins FROM channel_pins
WHERE channel_id = $1 WHERE channel_id = $1
ORDER BY message_id ASC ORDER BY message_id ASC
LIMIT 1 LIMIT 1
""", channel_id) """,
channel_id,
)
timestamp = snowflake_datetime(row['message_id']) timestamp = snowflake_datetime(row["message_id"])
await app.dispatcher.dispatch( await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_PINS_UPDATE', { "channel",
'channel_id': str(channel_id), channel_id,
'last_pin_timestamp': timestamp.isoformat() "CHANNEL_PINS_UPDATE",
}) {"channel_id": str(channel_id), "last_pin_timestamp": timestamp.isoformat()},
)
return '', 204 return "", 204

View File

@ -26,17 +26,15 @@ from logbook import Logger
from litecord.utils import async_map from litecord.utils import async_map
from litecord.blueprints.auth import token_check from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check, channel_perm_check from litecord.blueprints.checks import channel_check, channel_perm_check
from litecord.blueprints.channel.messages import ( from litecord.blueprints.channel.messages import query_tuple_from_args, extract_limit
query_tuple_from_args, extract_limit
)
from litecord.enums import GUILD_CHANS from litecord.enums import GUILD_CHANS
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('channel_reactions', __name__) bp = Blueprint("channel_reactions", __name__)
BASEPATH = '/<int:channel_id>/messages/<int:message_id>/reactions' BASEPATH = "/<int:channel_id>/messages/<int:message_id>/reactions"
class EmojiType(IntEnum): class EmojiType(IntEnum):
@ -51,16 +49,14 @@ def emoji_info_from_str(emoji: str) -> tuple:
# unicode emoji just have the raw unicode. # unicode emoji just have the raw unicode.
# try checking if the emoji is custom or unicode # try checking if the emoji is custom or unicode
emoji_type = 0 if ':' in emoji else 1 emoji_type = 0 if ":" in emoji else 1
emoji_type = EmojiType(emoji_type) emoji_type = EmojiType(emoji_type)
# extract the emoji id OR the unicode value of the emoji # extract the emoji id OR the unicode value of the emoji
# depending if it is custom or not # depending if it is custom or not
emoji_id = (int(emoji.split(':')[1]) emoji_id = int(emoji.split(":")[1]) if emoji_type == EmojiType.CUSTOM else emoji
if emoji_type == EmojiType.CUSTOM
else emoji)
emoji_name = emoji.split(':')[0] emoji_name = emoji.split(":")[0]
return emoji_type, emoji_id, emoji_name return emoji_type, emoji_id, emoji_name
@ -68,27 +64,27 @@ def emoji_info_from_str(emoji: str) -> tuple:
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict: def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
print(emoji_type, emoji_id, emoji_name) print(emoji_type, emoji_id, emoji_name)
return { return {
'id': None if emoji_type == EmojiType.UNICODE else emoji_id, "id": None if emoji_type == EmojiType.UNICODE else emoji_id,
'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id "name": emoji_name if emoji_type == EmojiType.UNICODE else emoji_id,
} }
def _make_payload(user_id, channel_id, message_id, partial): def _make_payload(user_id, channel_id, message_id, partial):
return { return {
'user_id': str(user_id), "user_id": str(user_id),
'channel_id': str(channel_id), "channel_id": str(channel_id),
'message_id': str(message_id), "message_id": str(message_id),
'emoji': partial "emoji": partial,
} }
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['PUT']) @bp.route(f"{BASEPATH}/<emoji>/@me", methods=["PUT"])
async def add_reaction(channel_id: int, message_id: int, emoji: str): async def add_reaction(channel_id: int, message_id: int, emoji: str):
"""Put a reaction.""" """Put a reaction."""
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'read_history') await channel_perm_check(user_id, channel_id, "read_history")
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
@ -97,52 +93,64 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str):
# ADD_REACTIONS is only checked when this is the first # ADD_REACTIONS is only checked when this is the first
# reaction in a message. # reaction in a message.
reaction_count = await app.db.fetchval(""" reaction_count = await app.db.fetchval(
"""
SELECT COUNT(*) SELECT COUNT(*)
FROM message_reactions FROM message_reactions
WHERE message_id = $1 WHERE message_id = $1
AND emoji_type = $2 AND emoji_type = $2
AND emoji_id = $3 AND emoji_id = $3
AND emoji_text = $4 AND emoji_text = $4
""", message_id, emoji_type, emoji_id, emoji_text) """,
message_id,
emoji_type,
emoji_id,
emoji_text,
)
if reaction_count == 0: if reaction_count == 0:
await channel_perm_check(user_id, channel_id, 'add_reactions') await channel_perm_check(user_id, channel_id, "add_reactions")
await app.db.execute( await app.db.execute(
""" """
INSERT INTO message_reactions (message_id, user_id, INSERT INTO message_reactions (message_id, user_id,
emoji_type, emoji_id, emoji_text) emoji_type, emoji_id, emoji_text)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
""", message_id, user_id, emoji_type, """,
message_id,
user_id,
emoji_type,
# if it is custom, we put the emoji_id on emoji_id # if it is custom, we put the emoji_id on emoji_id
# column, if it isn't, we put it on emoji_text # column, if it isn't, we put it on emoji_text
# column. # column.
emoji_id, emoji_text emoji_id,
emoji_text,
) )
partial = partial_emoji(emoji_type, emoji_id, emoji_name) partial = partial_emoji(emoji_type, emoji_id, emoji_name)
payload = _make_payload(user_id, channel_id, message_id, partial) payload = _make_payload(user_id, channel_id, message_id, partial)
if ctype in GUILD_CHANS: if ctype in GUILD_CHANS:
payload['guild_id'] = str(guild_id) payload["guild_id"] = str(guild_id)
await app.dispatcher.dispatch( await app.dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_REACTION_ADD', payload) "channel", channel_id, "MESSAGE_REACTION_ADD", payload
)
return '', 204 return "", 204
def emoji_sql(emoji_type, emoji_id, emoji_name, param=4): def emoji_sql(emoji_type, emoji_id, emoji_name, param=4):
"""Extract SQL clauses to search for specific emoji """Extract SQL clauses to search for specific emoji
in the message_reactions table.""" in the message_reactions table."""
param = f'${param}' param = f"${param}"
# know which column to filter with # know which column to filter with
where_ext = (f'AND emoji_id = {param}' where_ext = (
if emoji_type == EmojiType.CUSTOM else f"AND emoji_id = {param}"
f'AND emoji_text = {param}') if emoji_type == EmojiType.CUSTOM
else f"AND emoji_text = {param}"
)
# which emoji to remove (custom or unicode) # which emoji to remove (custom or unicode)
main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name
@ -157,8 +165,7 @@ def _emoji_sql_simple(emoji: str, param=4):
return emoji_sql(emoji_type, emoji_id, emoji_name, param) return emoji_sql(emoji_type, emoji_id, emoji_name, param)
async def remove_reaction(channel_id: int, message_id: int, async def remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str):
user_id: int, emoji: str):
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
@ -171,39 +178,45 @@ async def remove_reaction(channel_id: int, message_id: int,
AND user_id = $2 AND user_id = $2
AND emoji_type = $3 AND emoji_type = $3
{where_ext} {where_ext}
""", message_id, user_id, emoji_type, main_emoji) """,
message_id,
user_id,
emoji_type,
main_emoji,
)
partial = partial_emoji(emoji_type, emoji_id, emoji_name) partial = partial_emoji(emoji_type, emoji_id, emoji_name)
payload = _make_payload(user_id, channel_id, message_id, partial) payload = _make_payload(user_id, channel_id, message_id, partial)
if ctype in GUILD_CHANS: if ctype in GUILD_CHANS:
payload['guild_id'] = str(guild_id) payload["guild_id"] = str(guild_id)
await app.dispatcher.dispatch( await app.dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_REACTION_REMOVE', payload) "channel", channel_id, "MESSAGE_REACTION_REMOVE", payload
)
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['DELETE']) @bp.route(f"{BASEPATH}/<emoji>/@me", methods=["DELETE"])
async def remove_own_reaction(channel_id, message_id, emoji): async def remove_own_reaction(channel_id, message_id, emoji):
"""Remove a reaction.""" """Remove a reaction."""
user_id = await token_check() user_id = await token_check()
await remove_reaction(channel_id, message_id, user_id, emoji) await remove_reaction(channel_id, message_id, user_id, emoji)
return '', 204 return "", 204
@bp.route(f'{BASEPATH}/<emoji>/<int:other_id>', methods=['DELETE']) @bp.route(f"{BASEPATH}/<emoji>/<int:other_id>", methods=["DELETE"])
async def remove_user_reaction(channel_id, message_id, emoji, other_id): async def remove_user_reaction(channel_id, message_id, emoji, other_id):
"""Remove a reaction made by another user.""" """Remove a reaction made by another user."""
user_id = await token_check() user_id = await token_check()
await channel_perm_check(user_id, channel_id, 'manage_messages') await channel_perm_check(user_id, channel_id, "manage_messages")
await remove_reaction(channel_id, message_id, other_id, emoji) await remove_reaction(channel_id, message_id, other_id, emoji)
return '', 204 return "", 204
@bp.route(f'{BASEPATH}/<emoji>', methods=['GET']) @bp.route(f"{BASEPATH}/<emoji>", methods=["GET"])
async def list_users_reaction(channel_id, message_id, emoji): async def list_users_reaction(channel_id, message_id, emoji):
"""Get the list of all users who reacted with a certain emoji.""" """Get the list of all users who reacted with a certain emoji."""
user_id = await token_check() user_id = await token_check()
@ -215,42 +228,49 @@ async def list_users_reaction(channel_id, message_id, emoji):
limit = extract_limit(request, 25) limit = extract_limit(request, 25)
before, after = query_tuple_from_args(request.args, limit) before, after = query_tuple_from_args(request.args, limit)
before_clause = 'AND user_id < $2' if before else '' before_clause = "AND user_id < $2" if before else ""
after_clause = 'AND user_id > $3' if after else '' after_clause = "AND user_id > $3" if after else ""
where_ext, main_emoji = _emoji_sql_simple(emoji, 4) where_ext, main_emoji = _emoji_sql_simple(emoji, 4)
rows = await app.db.fetch(f""" rows = await app.db.fetch(
f"""
SELECT user_id SELECT user_id
FROM message_reactions FROM message_reactions
WHERE message_id = $1 {before_clause} {after_clause} {where_ext} WHERE message_id = $1 {before_clause} {after_clause} {where_ext}
""", message_id, before, after, main_emoji) """,
message_id,
before,
after,
main_emoji,
)
user_ids = [r['user_id'] for r in rows] user_ids = [r["user_id"] for r in rows]
users = await async_map(app.storage.get_user, user_ids) users = await async_map(app.storage.get_user, user_ids)
return jsonify(users) return jsonify(users)
@bp.route(f'{BASEPATH}', methods=['DELETE']) @bp.route(f"{BASEPATH}", methods=["DELETE"])
async def remove_all_reactions(channel_id, message_id): async def remove_all_reactions(channel_id, message_id):
"""Remove all reactions in a message.""" """Remove all reactions in a message."""
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'manage_messages') await channel_perm_check(user_id, channel_id, "manage_messages")
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM message_reactions DELETE FROM message_reactions
WHERE message_id = $1 WHERE message_id = $1
""", message_id) """,
message_id,
)
payload = { payload = {"channel_id": str(channel_id), "message_id": str(message_id)}
'channel_id': str(channel_id),
'message_id': str(message_id),
}
if ctype in GUILD_CHANS: if ctype in GUILD_CHANS:
payload['guild_id'] = str(guild_id) payload["guild_id"] = str(guild_id)
await app.dispatcher.dispatch( await app.dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_REACTION_REMOVE_ALL', payload) "channel", channel_id, "MESSAGE_REACTION_REMOVE_ALL", payload
)

View File

@ -28,24 +28,26 @@ from litecord.auth import token_check
from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags
from litecord.errors import ChannelNotFound, Forbidden, BadRequest from litecord.errors import ChannelNotFound, Forbidden, BadRequest
from litecord.schemas import ( from litecord.schemas import (
validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE, validate,
CHAN_UPDATE,
CHAN_OVERWRITE,
SEARCH_CHANNEL,
GROUP_DM_UPDATE,
BULK_DELETE, BULK_DELETE,
) )
from litecord.blueprints.checks import channel_check, channel_perm_check from litecord.blueprints.checks import channel_check, channel_perm_check
from litecord.system_messages import send_sys_message from litecord.system_messages import send_sys_message
from litecord.blueprints.dm_channels import ( from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
gdm_remove_recipient, gdm_destroy
)
from litecord.utils import search_result_from_list from litecord.utils import search_result_from_list
from litecord.embed.messages import process_url_embed, msg_update_embeds from litecord.embed.messages import process_url_embed, msg_update_embeds
from litecord.snowflake import snowflake_datetime from litecord.snowflake import snowflake_datetime
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('channels', __name__) bp = Blueprint("channels", __name__)
@bp.route('/<int:channel_id>', methods=['GET']) @bp.route("/<int:channel_id>", methods=["GET"])
async def get_channel(channel_id): async def get_channel(channel_id):
"""Get a single channel's information""" """Get a single channel's information"""
user_id = await token_check() user_id = await token_check()
@ -56,7 +58,7 @@ async def get_channel(channel_id):
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
if not chan: if not chan:
raise ChannelNotFound('single channel not found') raise ChannelNotFound("single channel not found")
return jsonify(chan) return jsonify(chan)
@ -64,106 +66,129 @@ async def get_channel(channel_id):
async def __guild_chan_sql(guild_id, channel_id, field: str) -> str: async def __guild_chan_sql(guild_id, channel_id, field: str) -> str:
"""Update a guild's channel id field to NULL, """Update a guild's channel id field to NULL,
if it was set to the given channel id before.""" if it was set to the given channel id before."""
return await app.db.execute(f""" return await app.db.execute(
f"""
UPDATE guilds UPDATE guilds
SET {field} = NULL SET {field} = NULL
WHERE guilds.id = $1 AND {field} = $2 WHERE guilds.id = $1 AND {field} = $2
""", guild_id, channel_id) """,
guild_id,
channel_id,
)
async def _update_guild_chan_text(guild_id: int, channel_id: int): async def _update_guild_chan_text(guild_id: int, channel_id: int):
res_embed = await __guild_chan_sql( res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id")
guild_id, channel_id, 'embed_channel_id')
res_widget = await __guild_chan_sql( res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id")
guild_id, channel_id, 'widget_channel_id')
res_system = await __guild_chan_sql( res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id")
guild_id, channel_id, 'system_channel_id')
# if none of them were actually updated, # if none of them were actually updated,
# ignore and dont dispatch anything # ignore and dont dispatch anything
if 'UPDATE 1' not in (res_embed, res_widget, res_system): if "UPDATE 1" not in (res_embed, res_widget, res_system):
return return
# at least one of the fields were updated, # at least one of the fields were updated,
# dispatch GUILD_UPDATE # dispatch GUILD_UPDATE
guild = await app.storage.get_guild(guild_id) guild = await app.storage.get_guild(guild_id)
await app.dispatcher.dispatch_guild( await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
guild_id, 'GUILD_UPDATE', guild)
async def _update_guild_chan_voice(guild_id: int, channel_id: int): async def _update_guild_chan_voice(guild_id: int, channel_id: int):
res = await __guild_chan_sql(guild_id, channel_id, 'afk_channel_id') res = await __guild_chan_sql(guild_id, channel_id, "afk_channel_id")
# guild didnt update # guild didnt update
if res == 'UPDATE 0': if res == "UPDATE 0":
return return
guild = await app.storage.get_guild(guild_id) guild = await app.storage.get_guild(guild_id)
await app.dispatcher.dispatch_guild( await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
guild_id, 'GUILD_UPDATE', guild)
async def _update_guild_chan_cat(guild_id: int, channel_id: int): async def _update_guild_chan_cat(guild_id: int, channel_id: int):
# get all channels that were childs of the category # get all channels that were childs of the category
childs = await app.db.fetch(""" childs = await app.db.fetch(
"""
SELECT id SELECT id
FROM guild_channels FROM guild_channels
WHERE guild_id = $1 AND parent_id = $2 WHERE guild_id = $1 AND parent_id = $2
""", guild_id, channel_id) """,
childs = [c['id'] for c in childs] guild_id,
channel_id,
)
childs = [c["id"] for c in childs]
# update every child channel to parent_id = NULL # update every child channel to parent_id = NULL
await app.db.execute(""" await app.db.execute(
"""
UPDATE guild_channels UPDATE guild_channels
SET parent_id = NULL SET parent_id = NULL
WHERE guild_id = $1 AND parent_id = $2 WHERE guild_id = $1 AND parent_id = $2
""", guild_id, channel_id) """,
guild_id,
channel_id,
)
# tell all people in the guild of the category removal # tell all people in the guild of the category removal
for child_id in childs: for child_id in childs:
child = await app.storage.get_channel(child_id) child = await app.storage.get_channel(child_id)
await app.dispatcher.dispatch_guild( await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child)
guild_id, 'CHANNEL_UPDATE', child
)
async def delete_messages(channel_id): async def delete_messages(channel_id):
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM channel_pins DELETE FROM channel_pins
WHERE channel_id = $1 WHERE channel_id = $1
""", channel_id) """,
channel_id,
)
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM user_read_state DELETE FROM user_read_state
WHERE channel_id = $1 WHERE channel_id = $1
""", channel_id) """,
channel_id,
)
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM messages DELETE FROM messages
WHERE channel_id = $1 WHERE channel_id = $1
""", channel_id) """,
channel_id,
)
async def guild_cleanup(channel_id): async def guild_cleanup(channel_id):
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM channel_overwrites DELETE FROM channel_overwrites
WHERE channel_id = $1 WHERE channel_id = $1
""", channel_id) """,
channel_id,
)
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM invites DELETE FROM invites
WHERE channel_id = $1 WHERE channel_id = $1
""", channel_id) """,
channel_id,
)
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM webhooks DELETE FROM webhooks
WHERE channel_id = $1 WHERE channel_id = $1
""", channel_id) """,
channel_id,
)
@bp.route('/<int:channel_id>', methods=['DELETE']) @bp.route("/<int:channel_id>", methods=["DELETE"])
async def close_channel(channel_id): async def close_channel(channel_id):
"""Close or delete a channel.""" """Close or delete a channel."""
user_id = await token_check() user_id = await token_check()
@ -184,9 +209,8 @@ async def close_channel(channel_id):
}[ctype] }[ctype]
main_tbl = { main_tbl = {
ChannelType.GUILD_TEXT: 'guild_text_channels', ChannelType.GUILD_TEXT: "guild_text_channels",
ChannelType.GUILD_VOICE: 'guild_voice_channels', ChannelType.GUILD_VOICE: "guild_voice_channels",
# TODO: categories? # TODO: categories?
}[ctype] }[ctype]
@ -199,29 +223,37 @@ async def close_channel(channel_id):
await delete_messages(channel_id) await delete_messages(channel_id)
await guild_cleanup(channel_id) await guild_cleanup(channel_id)
await app.db.execute(f""" await app.db.execute(
f"""
DELETE FROM {main_tbl} DELETE FROM {main_tbl}
WHERE id = $1 WHERE id = $1
""", channel_id) """,
channel_id,
)
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM guild_channels DELETE FROM guild_channels
WHERE id = $1 WHERE id = $1
""", channel_id) """,
channel_id,
)
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM channels DELETE FROM channels
WHERE id = $1 WHERE id = $1
""", channel_id) """,
channel_id,
)
# clean its member list representation # clean its member list representation
lazy_guilds = app.dispatcher.backends['lazy_guild'] lazy_guilds = app.dispatcher.backends["lazy_guild"]
lazy_guilds.remove_channel(channel_id) lazy_guilds.remove_channel(channel_id)
await app.dispatcher.dispatch_guild( await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_DELETE", chan)
guild_id, 'CHANNEL_DELETE', chan)
await app.dispatcher.remove('channel', channel_id) await app.dispatcher.remove("channel", channel_id)
return jsonify(chan) return jsonify(chan)
if ctype == ChannelType.DM: if ctype == ChannelType.DM:
@ -231,27 +263,34 @@ async def close_channel(channel_id):
# instead, we close the channel for the user that is making # instead, we close the channel for the user that is making
# the request via removing the link between them and # the request via removing the link between them and
# the channel on dm_channel_state # the channel on dm_channel_state
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM dm_channel_state DELETE FROM dm_channel_state
WHERE user_id = $1 AND dm_id = $2 WHERE user_id = $1 AND dm_id = $2
""", user_id, channel_id) """,
user_id,
channel_id,
)
# unsubscribe # unsubscribe
await app.dispatcher.unsub('channel', channel_id, user_id) await app.dispatcher.unsub("channel", channel_id, user_id)
# nothing happens to the other party of the dm channel # nothing happens to the other party of the dm channel
await app.dispatcher.dispatch_user(user_id, 'CHANNEL_DELETE', chan) await app.dispatcher.dispatch_user(user_id, "CHANNEL_DELETE", chan)
return jsonify(chan) return jsonify(chan)
if ctype == ChannelType.GROUP_DM: if ctype == ChannelType.GROUP_DM:
await gdm_remove_recipient(channel_id, user_id) await gdm_remove_recipient(channel_id, user_id)
gdm_count = await app.db.fetchval(""" gdm_count = await app.db.fetchval(
"""
SELECT COUNT(*) SELECT COUNT(*)
FROM group_dm_members FROM group_dm_members
WHERE id = $1 WHERE id = $1
""", channel_id) """,
channel_id,
)
if gdm_count == 0: if gdm_count == 0:
# destroy dm # destroy dm
@ -261,11 +300,15 @@ async def close_channel(channel_id):
async def _update_pos(channel_id, pos: int): async def _update_pos(channel_id, pos: int):
await app.db.execute(""" await app.db.execute(
"""
UPDATE guild_channels UPDATE guild_channels
SET position = $1 SET position = $1
WHERE id = $2 WHERE id = $2
""", pos, channel_id) """,
pos,
channel_id,
)
async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]): async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
@ -274,20 +317,19 @@ async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
continue continue
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch( await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
'guild', guild_id, 'CHANNEL_UPDATE', chan)
async def _process_overwrites(channel_id: int, overwrites: list): async def _process_overwrites(channel_id: int, overwrites: list):
for overwrite in overwrites: for overwrite in overwrites:
# 0 for member overwrite, 1 for role overwrite # 0 for member overwrite, 1 for role overwrite
target_type = 0 if overwrite['type'] == 'member' else 1 target_type = 0 if overwrite["type"] == "member" else 1
target_role = None if target_type == 0 else overwrite['id'] target_role = None if target_type == 0 else overwrite["id"]
target_user = overwrite['id'] if target_type == 0 else None target_user = overwrite["id"] if target_type == 0 else None
col_name = 'target_user' if target_type == 0 else 'target_role' col_name = "target_user" if target_type == 0 else "target_role"
constraint_name = f'channel_overwrites_{col_name}_uniq' constraint_name = f"channel_overwrites_{col_name}_uniq"
await app.db.execute( await app.db.execute(
f""" f"""
@ -301,53 +343,66 @@ async def _process_overwrites(channel_id: int, overwrites: list):
UPDATE UPDATE
SET allow = $5, deny = $6 SET allow = $5, deny = $6
""", """,
channel_id, target_type, channel_id,
target_role, target_user, target_type,
overwrite['allow'], overwrite['deny']) target_role,
target_user,
overwrite["allow"],
overwrite["deny"],
)
@bp.route('/<int:channel_id>/permissions/<int:overwrite_id>', methods=['PUT']) @bp.route("/<int:channel_id>/permissions/<int:overwrite_id>", methods=["PUT"])
async def put_channel_overwrite(channel_id: int, overwrite_id: int): async def put_channel_overwrite(channel_id: int, overwrite_id: int):
"""Insert or modify a channel overwrite.""" """Insert or modify a channel overwrite."""
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
if ctype not in GUILD_CHANS: if ctype not in GUILD_CHANS:
raise ChannelNotFound('Only usable for guild channels.') raise ChannelNotFound("Only usable for guild channels.")
await channel_perm_check(user_id, guild_id, 'manage_roles') await channel_perm_check(user_id, guild_id, "manage_roles")
j = validate( j = validate(
# inserting a fake id on the payload so validation passes through # inserting a fake id on the payload so validation passes through
{**await request.get_json(), **{'id': -1}}, {**await request.get_json(), **{"id": -1}},
CHAN_OVERWRITE CHAN_OVERWRITE,
) )
await _process_overwrites(channel_id, [{ await _process_overwrites(
'allow': j['allow'], channel_id,
'deny': j['deny'], [
'type': j['type'], {
'id': overwrite_id "allow": j["allow"],
}]) "deny": j["deny"],
"type": j["type"],
"id": overwrite_id,
}
],
)
await _mass_chan_update(guild_id, [channel_id]) await _mass_chan_update(guild_id, [channel_id])
return '', 204 return "", 204
async def _update_channel_common(channel_id, guild_id: int, j: dict): async def _update_channel_common(channel_id, guild_id: int, j: dict):
if 'name' in j: if "name" in j:
await app.db.execute(""" await app.db.execute(
"""
UPDATE guild_channels UPDATE guild_channels
SET name = $1 SET name = $1
WHERE id = $2 WHERE id = $2
""", j['name'], channel_id) """,
j["name"],
channel_id,
)
if 'position' in j: if "position" in j:
channel_data = await app.storage.get_channel_data(guild_id) channel_data = await app.storage.get_channel_data(guild_id)
chans = [None] * len(channel_data) chans = [None] * len(channel_data)
for chandata in channel_data: for chandata in channel_data:
chans.insert(chandata['position'], int(chandata['id'])) chans.insert(chandata["position"], int(chandata["id"]))
# are we changing to the left or to the right? # are we changing to the left or to the right?
@ -358,7 +413,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
# channelN-1 going to the position channel2 # channelN-1 going to the position channel2
# was occupying. # was occupying.
current_pos = chans.index(channel_id) current_pos = chans.index(channel_id)
new_pos = j['position'] new_pos = j["position"]
# if the new position is bigger than the current one, # if the new position is bigger than the current one,
# we're making a left shift of all the channels that are # we're making a left shift of all the channels that are
@ -366,113 +421,136 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
left_shift = new_pos > current_pos left_shift = new_pos > current_pos
# find all channels that we'll have to shift # find all channels that we'll have to shift
shift_block = (chans[current_pos:new_pos] shift_block = (
if left_shift else chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos]
chans[new_pos:current_pos] )
)
shift = -1 if left_shift else 1 shift = -1 if left_shift else 1
# do the shift (to the left or to the right) # do the shift (to the left or to the right)
await app.db.executemany(""" await app.db.executemany(
"""
UPDATE guild_channels UPDATE guild_channels
SET position = position + $1 SET position = position + $1
WHERE id = $2 WHERE id = $2
""", [(shift, chan_id) for chan_id in shift_block]) """,
[(shift, chan_id) for chan_id in shift_block],
)
await _mass_chan_update(guild_id, shift_block) await _mass_chan_update(guild_id, shift_block)
# since theres now an empty slot, move current channel to it # since theres now an empty slot, move current channel to it
await _update_pos(channel_id, new_pos) await _update_pos(channel_id, new_pos)
if 'channel_overwrites' in j: if "channel_overwrites" in j:
overwrites = j['channel_overwrites'] overwrites = j["channel_overwrites"]
await _process_overwrites(channel_id, overwrites) await _process_overwrites(channel_id, overwrites)
async def _common_guild_chan(channel_id, j: dict): async def _common_guild_chan(channel_id, j: dict):
# common updates to the guild_channels table # common updates to the guild_channels table
for field in [field for field in j.keys() for field in [field for field in j.keys() if field in ("nsfw", "parent_id")]:
if field in ('nsfw', 'parent_id')]: await app.db.execute(
await app.db.execute(f""" f"""
UPDATE guild_channels UPDATE guild_channels
SET {field} = $1 SET {field} = $1
WHERE id = $2 WHERE id = $2
""", j[field], channel_id) """,
j[field],
channel_id,
)
async def _update_text_channel(channel_id: int, j: dict, _user_id: int): async def _update_text_channel(channel_id: int, j: dict, _user_id: int):
# first do the specific ones related to guild_text_channels # first do the specific ones related to guild_text_channels
for field in [field for field in j.keys() for field in [
if field in ('topic', 'rate_limit_per_user')]: field for field in j.keys() if field in ("topic", "rate_limit_per_user")
await app.db.execute(f""" ]:
await app.db.execute(
f"""
UPDATE guild_text_channels UPDATE guild_text_channels
SET {field} = $1 SET {field} = $1
WHERE id = $2 WHERE id = $2
""", j[field], channel_id) """,
j[field],
channel_id,
)
await _common_guild_chan(channel_id, j) await _common_guild_chan(channel_id, j)
async def _update_voice_channel(channel_id: int, j: dict, _user_id: int): async def _update_voice_channel(channel_id: int, j: dict, _user_id: int):
# first do the specific ones in guild_voice_channels # first do the specific ones in guild_voice_channels
for field in [field for field in j.keys() for field in [field for field in j.keys() if field in ("bitrate", "user_limit")]:
if field in ('bitrate', 'user_limit')]: await app.db.execute(
await app.db.execute(f""" f"""
UPDATE guild_voice_channels UPDATE guild_voice_channels
SET {field} = $1 SET {field} = $1
WHERE id = $2 WHERE id = $2
""", j[field], channel_id) """,
j[field],
channel_id,
)
# yes, i'm letting voice channels have nsfw, you cant stop me # yes, i'm letting voice channels have nsfw, you cant stop me
await _common_guild_chan(channel_id, j) await _common_guild_chan(channel_id, j)
async def _update_group_dm(channel_id: int, j: dict, author_id: int): async def _update_group_dm(channel_id: int, j: dict, author_id: int):
if 'name' in j: if "name" in j:
await app.db.execute(""" await app.db.execute(
"""
UPDATE group_dm_channels UPDATE group_dm_channels
SET name = $1 SET name = $1
WHERE id = $2 WHERE id = $2
""", j['name'], channel_id) """,
j["name"],
channel_id,
)
await send_sys_message( await send_sys_message(
app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id
) )
if 'icon' in j: if "icon" in j:
new_icon = await app.icons.update( new_icon = await app.icons.update(
'channel-icons', channel_id, j['icon'], always_icon=True "channel-icons", channel_id, j["icon"], always_icon=True
) )
await app.db.execute(""" await app.db.execute(
"""
UPDATE group_dm_channels UPDATE group_dm_channels
SET icon = $1 SET icon = $1
WHERE id = $2 WHERE id = $2
""", new_icon.icon_hash, channel_id) """,
new_icon.icon_hash,
channel_id,
)
await send_sys_message( await send_sys_message(
app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id
) )
@bp.route('/<int:channel_id>', methods=['PUT', 'PATCH']) @bp.route("/<int:channel_id>", methods=["PUT", "PATCH"])
async def update_channel(channel_id): async def update_channel(channel_id):
"""Update a channel's information""" """Update a channel's information"""
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
if ctype not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE, if ctype not in (
ChannelType.GROUP_DM): ChannelType.GUILD_TEXT,
raise ChannelNotFound('unable to edit unsupported chan type') ChannelType.GUILD_VOICE,
ChannelType.GROUP_DM,
):
raise ChannelNotFound("unable to edit unsupported chan type")
is_guild = ctype in GUILD_CHANS is_guild = ctype in GUILD_CHANS
if is_guild: if is_guild:
await channel_perm_check(user_id, channel_id, 'manage_channels') await channel_perm_check(user_id, channel_id, "manage_channels")
j = validate(await request.get_json(), j = validate(await request.get_json(), CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
# TODO: categories # TODO: categories
update_handler = { update_handler = {
@ -489,30 +567,32 @@ async def update_channel(channel_id):
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
if is_guild: if is_guild:
await app.dispatcher.dispatch( await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
'guild', guild_id, 'CHANNEL_UPDATE', chan)
else: else:
await app.dispatcher.dispatch( await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
'channel', channel_id, 'CHANNEL_UPDATE', chan)
return jsonify(chan) return jsonify(chan)
@bp.route('/<int:channel_id>/typing', methods=['POST']) @bp.route("/<int:channel_id>/typing", methods=["POST"])
async def trigger_typing(channel_id): async def trigger_typing(channel_id):
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
await app.dispatcher.dispatch('channel', channel_id, 'TYPING_START', { await app.dispatcher.dispatch(
'channel_id': str(channel_id), "channel",
'user_id': str(user_id), channel_id,
'timestamp': int(time.time()), "TYPING_START",
{
"channel_id": str(channel_id),
"user_id": str(user_id),
"timestamp": int(time.time()),
# guild_id for lazy guilds
"guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
},
)
# guild_id for lazy guilds return "", 204
'guild_id': str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
})
return '', 204
async def channel_ack(user_id, guild_id, channel_id, message_id: int = None): async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
@ -521,7 +601,8 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
if not message_id: if not message_id:
message_id = await app.storage.chan_last_message(channel_id) message_id = await app.storage.chan_last_message(channel_id)
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO user_read_state INSERT INTO user_read_state
(user_id, channel_id, last_message_id, mention_count) (user_id, channel_id, last_message_id, mention_count)
VALUES VALUES
@ -532,26 +613,31 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
SET last_message_id = $3, mention_count = 0 SET last_message_id = $3, mention_count = 0
WHERE user_read_state.user_id = $1 WHERE user_read_state.user_id = $1
AND user_read_state.channel_id = $2 AND user_read_state.channel_id = $2
""", user_id, channel_id, message_id) """,
user_id,
channel_id,
message_id,
)
if guild_id: if guild_id:
await app.dispatcher.dispatch_user_guild( await app.dispatcher.dispatch_user_guild(
user_id, guild_id, 'MESSAGE_ACK', { user_id,
'message_id': str(message_id), guild_id,
'channel_id': str(channel_id) "MESSAGE_ACK",
}) {"message_id": str(message_id), "channel_id": str(channel_id)},
)
else: else:
# we don't use ChannelDispatcher here because since # we don't use ChannelDispatcher here because since
# guild_id is None, all user devices are already subscribed # guild_id is None, all user devices are already subscribed
# to the given channel (a dm or a group dm) # to the given channel (a dm or a group dm)
await app.dispatcher.dispatch_user( await app.dispatcher.dispatch_user(
user_id, 'MESSAGE_ACK', { user_id,
'message_id': str(message_id), "MESSAGE_ACK",
'channel_id': str(channel_id) {"message_id": str(message_id), "channel_id": str(channel_id)},
}) )
@bp.route('/<int:channel_id>/messages/<int:message_id>/ack', methods=['POST']) @bp.route("/<int:channel_id>/messages/<int:message_id>/ack", methods=["POST"])
async def ack_channel(channel_id, message_id): async def ack_channel(channel_id, message_id):
"""Acknowledge a channel.""" """Acknowledge a channel."""
user_id = await token_check() user_id = await token_check()
@ -562,40 +648,47 @@ async def ack_channel(channel_id, message_id):
await channel_ack(user_id, guild_id, channel_id, message_id) await channel_ack(user_id, guild_id, channel_id, message_id)
return jsonify({ return jsonify(
# token seems to be used for {
# data collection activities, # token seems to be used for
# so we never use it. # data collection activities,
'token': None # so we never use it.
}) "token": None
}
)
@bp.route('/<int:channel_id>/messages/ack', methods=['DELETE']) @bp.route("/<int:channel_id>/messages/ack", methods=["DELETE"])
async def delete_read_state(channel_id): async def delete_read_state(channel_id):
"""Delete the read state of a channel.""" """Delete the read state of a channel."""
user_id = await token_check() user_id = await token_check()
await channel_check(user_id, channel_id) await channel_check(user_id, channel_id)
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM user_read_state DELETE FROM user_read_state
WHERE user_id = $1 AND channel_id = $2 WHERE user_id = $1 AND channel_id = $2
""", user_id, channel_id) """,
user_id,
channel_id,
)
return '', 204 return "", 204
@bp.route('/<int:channel_id>/messages/search', methods=['GET']) @bp.route("/<int:channel_id>/messages/search", methods=["GET"])
async def _search_channel(channel_id): async def _search_channel(channel_id):
"""Search in DMs or group DMs""" """Search in DMs or group DMs"""
user_id = await token_check() user_id = await token_check()
await channel_check(user_id, channel_id) await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'read_messages') await channel_perm_check(user_id, channel_id, "read_messages")
j = validate(dict(request.args), SEARCH_CHANNEL) j = validate(dict(request.args), SEARCH_CHANNEL)
# main search query # main search query
# the context (before/after) columns are copied from the guilds blueprint. # the context (before/after) columns are copied from the guilds blueprint.
rows = await app.db.fetch(f""" rows = await app.db.fetch(
f"""
SELECT orig.id AS current_id, SELECT orig.id AS current_id,
COUNT(*) OVER() AS total_results, COUNT(*) OVER() AS total_results,
array((SELECT messages.id AS before_id array((SELECT messages.id AS before_id
@ -611,28 +704,40 @@ async def _search_channel(channel_id):
ORDER BY orig.id DESC ORDER BY orig.id DESC
LIMIT 50 LIMIT 50
OFFSET $2 OFFSET $2
""", channel_id, j['offset'], j['content']) """,
channel_id,
j["offset"],
j["content"],
)
return jsonify(await search_result_from_list(rows)) return jsonify(await search_result_from_list(rows))
# NOTE that those functions stay here until some other # NOTE that those functions stay here until some other
# route or code wants it. # route or code wants it.
async def _msg_update_flags(message_id: int, flags: int): async def _msg_update_flags(message_id: int, flags: int):
await app.db.execute(""" await app.db.execute(
"""
UPDATE messages UPDATE messages
SET flags = $1 SET flags = $1
WHERE id = $2 WHERE id = $2
""", flags, message_id) """,
flags,
message_id,
)
async def _msg_get_flags(message_id: int): async def _msg_get_flags(message_id: int):
return await app.db.fetchval(""" return await app.db.fetchval(
"""
SELECT flags SELECT flags
FROM messages FROM messages
WHERE id = $1 WHERE id = $1
""", message_id) """,
message_id,
)
async def _msg_set_flags(message_id: int, new_flags: int): async def _msg_set_flags(message_id: int, new_flags: int):
@ -647,8 +752,9 @@ async def _msg_unset_flags(message_id: int, unset_flags: int):
await _msg_update_flags(message_id, flags) await _msg_update_flags(message_id, flags)
@bp.route('/<int:channel_id>/messages/<int:message_id>/suppress-embeds', @bp.route(
methods=['POST']) "/<int:channel_id>/messages/<int:message_id>/suppress-embeds", methods=["POST"]
)
async def suppress_embeds(channel_id: int, message_id: int): async def suppress_embeds(channel_id: int, message_id: int):
"""Toggle the embeds in a message. """Toggle the embeds in a message.
@ -661,29 +767,27 @@ async def suppress_embeds(channel_id: int, message_id: int):
# the checks here have been copied from the delete_message() # the checks here have been copied from the delete_message()
# handler on blueprints.channel.messages. maybe we can combine # handler on blueprints.channel.messages. maybe we can combine
# them someday? # them someday?
author_id = await app.db.fetchval(""" author_id = await app.db.fetchval(
"""
SELECT author_id FROM messages SELECT author_id FROM messages
WHERE messages.id = $1 WHERE messages.id = $1
""", message_id) """,
message_id,
)
by_perms = await channel_perm_check( by_perms = await channel_perm_check(user_id, channel_id, "manage_messages", False)
user_id, channel_id, 'manage_messages', False)
by_author = author_id == user_id by_author = author_id == user_id
can_suppress = by_perms or by_author can_suppress = by_perms or by_author
if not can_suppress: if not can_suppress:
raise Forbidden('Not enough permissions.') raise Forbidden("Not enough permissions.")
j = validate( j = validate(await request.get_json(), {"suppress": {"type": "boolean"}})
await request.get_json(),
{'suppress': {'type': 'boolean'}},
)
suppress = j['suppress'] suppress = j["suppress"]
message = await app.storage.get_message(message_id) message = await app.storage.get_message(message_id)
url_embeds = sum( url_embeds = sum(1 for embed in message["embeds"] if embed["type"] == "url")
1 for embed in message['embeds'] if embed['type'] == 'url')
# NOTE for any future self. discord doing flags an optional thing instead # NOTE for any future self. discord doing flags an optional thing instead
# of just giving 0 is a pretty bad idea because now i have to deal with # of just giving 0 is a pretty bad idea because now i have to deal with
@ -693,8 +797,7 @@ async def suppress_embeds(channel_id: int, message_id: int):
# delete all embeds then dispatch an update # delete all embeds then dispatch an update
await _msg_set_flags(message_id, MessageFlags.suppress_embeds) await _msg_set_flags(message_id, MessageFlags.suppress_embeds)
message['flags'] = \ message["flags"] = message.get("flags", 0) | MessageFlags.suppress_embeds
message.get('flags', 0) | MessageFlags.suppress_embeds
await msg_update_embeds(message, [], app.storage, app.dispatcher) await msg_update_embeds(message, [], app.storage, app.dispatcher)
elif not suppress and not url_embeds: elif not suppress and not url_embeds:
@ -702,30 +805,29 @@ async def suppress_embeds(channel_id: int, message_id: int):
await _msg_unset_flags(message_id, MessageFlags.suppress_embeds) await _msg_unset_flags(message_id, MessageFlags.suppress_embeds)
try: try:
message.pop('flags') message.pop("flags")
except KeyError: except KeyError:
pass pass
app.sched.spawn( app.sched.spawn(
process_url_embed( process_url_embed(
app.config, app.storage, app.dispatcher, app.session, app.config, app.storage, app.dispatcher, app.session, message
message
) )
) )
return '', 204 return "", 204
@bp.route('/<int:channel_id>/messages/bulk-delete', methods=['POST']) @bp.route("/<int:channel_id>/messages/bulk-delete", methods=["POST"])
async def bulk_delete(channel_id: int): async def bulk_delete(channel_id: int):
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
guild_id = guild_id if ctype in GUILD_CHANS else None guild_id = guild_id if ctype in GUILD_CHANS else None
await channel_perm_check(user_id, channel_id, 'manage_messages') await channel_perm_check(user_id, channel_id, "manage_messages")
j = validate(await request.get_json(), BULK_DELETE) j = validate(await request.get_json(), BULK_DELETE)
message_ids = set(j['messages']) message_ids = set(j["messages"])
# as per discord behavior, if any id here is older than two weeks, # as per discord behavior, if any id here is older than two weeks,
# we must error. a cuter behavior would be returning the message ids # we must error. a cuter behavior would be returning the message ids
@ -738,25 +840,28 @@ async def bulk_delete(channel_id: int):
raise BadRequest(50034) raise BadRequest(50034)
payload = { payload = {
'guild_id': str(guild_id), "guild_id": str(guild_id),
'channel_id': str(channel_id), "channel_id": str(channel_id),
'ids': list(map(str, message_ids)), "ids": list(map(str, message_ids)),
} }
# payload.guild_id is optional in the event, not nullable. # payload.guild_id is optional in the event, not nullable.
if guild_id is None: if guild_id is None:
payload.pop('guild_id') payload.pop("guild_id")
res = await app.db.execute(""" res = await app.db.execute(
"""
DELETE FROM messages DELETE FROM messages
WHERE WHERE
channel_id = $1 channel_id = $1
AND ARRAY[id] <@ $2::bigint[] AND ARRAY[id] <@ $2::bigint[]
""", channel_id, list(message_ids)) """,
channel_id,
list(message_ids),
)
if res == 'DELETE 0': if res == "DELETE 0":
raise BadRequest('No messages were removed') raise BadRequest("No messages were removed")
await app.dispatcher.dispatch( await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_DELETE_BULK", payload)
'channel', channel_id, 'MESSAGE_DELETE_BULK', payload) return "", 204
return '', 204

View File

@ -23,46 +23,57 @@ from quart import current_app as app
from litecord.enums import ChannelType, GUILD_CHANS from litecord.enums import ChannelType, GUILD_CHANS
from litecord.errors import ( from litecord.errors import (
GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions GuildNotFound,
ChannelNotFound,
Forbidden,
MissingPermissions,
) )
from litecord.permissions import base_permissions, get_permissions from litecord.permissions import base_permissions, get_permissions
async def guild_check(user_id: int, guild_id: int): async def guild_check(user_id: int, guild_id: int):
"""Check if a user is in a guild.""" """Check if a user is in a guild."""
joined_at = await app.db.fetchval(""" joined_at = await app.db.fetchval(
"""
SELECT joined_at SELECT joined_at
FROM members FROM members
WHERE user_id = $1 AND guild_id = $2 WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id) """,
user_id,
guild_id,
)
if not joined_at: if not joined_at:
raise GuildNotFound('guild not found') raise GuildNotFound("guild not found")
async def guild_owner_check(user_id: int, guild_id: int): async def guild_owner_check(user_id: int, guild_id: int):
"""Check if a user is the owner of the guild.""" """Check if a user is the owner of the guild."""
owner_id = await app.db.fetchval(""" owner_id = await app.db.fetchval(
"""
SELECT owner_id SELECT owner_id
FROM guilds FROM guilds
WHERE guilds.id = $1 WHERE guilds.id = $1
""", guild_id) """,
guild_id,
)
if not owner_id: if not owner_id:
raise GuildNotFound() raise GuildNotFound()
if user_id != owner_id: if user_id != owner_id:
raise Forbidden('You are not the owner of the guild') raise Forbidden("You are not the owner of the guild")
async def channel_check(user_id, channel_id, *, async def channel_check(
only: Union[ChannelType, List[ChannelType]] = None): user_id, channel_id, *, only: Union[ChannelType, List[ChannelType]] = None
):
"""Check if the current user is authorized """Check if the current user is authorized
to read the channel's information.""" to read the channel's information."""
chan_type = await app.storage.get_chan_type(channel_id) chan_type = await app.storage.get_chan_type(channel_id)
if chan_type is None: if chan_type is None:
raise ChannelNotFound('channel type not found') raise ChannelNotFound("channel type not found")
ctype = ChannelType(chan_type) ctype = ChannelType(chan_type)
@ -70,14 +81,17 @@ async def channel_check(user_id, channel_id, *,
only = [only] only = [only]
if only and ctype not in only: if only and ctype not in only:
raise ChannelNotFound('invalid channel type') raise ChannelNotFound("invalid channel type")
if ctype in GUILD_CHANS: if ctype in GUILD_CHANS:
guild_id = await app.db.fetchval(""" guild_id = await app.db.fetchval(
"""
SELECT guild_id SELECT guild_id
FROM guild_channels FROM guild_channels
WHERE guild_channels.id = $1 WHERE guild_channels.id = $1
""", channel_id) """,
channel_id,
)
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
return ctype, guild_id return ctype, guild_id
@ -87,11 +101,14 @@ async def channel_check(user_id, channel_id, *,
return ctype, peer_id return ctype, peer_id
if ctype == ChannelType.GROUP_DM: if ctype == ChannelType.GROUP_DM:
owner_id = await app.db.fetchval(""" owner_id = await app.db.fetchval(
"""
SELECT owner_id SELECT owner_id
FROM group_dm_channels FROM group_dm_channels
WHERE id = $1 WHERE id = $1
""", channel_id) """,
channel_id,
)
return ctype, owner_id return ctype, owner_id
@ -102,18 +119,17 @@ async def guild_perm_check(user_id, guild_id, permission: str):
hasperm = getattr(base_perms.bits, permission) hasperm = getattr(base_perms.bits, permission)
if not hasperm: if not hasperm:
raise MissingPermissions('Missing permissions.') raise MissingPermissions("Missing permissions.")
return bool(hasperm) return bool(hasperm)
async def channel_perm_check(user_id, channel_id, async def channel_perm_check(user_id, channel_id, permission: str, raise_err=True):
permission: str, raise_err=True):
"""Check channel permissions for a user.""" """Check channel permissions for a user."""
base_perms = await get_permissions(user_id, channel_id) base_perms = await get_permissions(user_id, channel_id)
hasperm = getattr(base_perms.bits, permission) hasperm = getattr(base_perms.bits, permission)
if not hasperm and raise_err: if not hasperm and raise_err:
raise MissingPermissions('Missing permissions.') raise MissingPermissions("Missing permissions.")
return bool(hasperm) return bool(hasperm)

View File

@ -29,21 +29,29 @@ from litecord.system_messages import send_sys_message
from litecord.pubsub.channel import gdm_recipient_view from litecord.pubsub.channel import gdm_recipient_view
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('dm_channels', __name__) bp = Blueprint("dm_channels", __name__)
async def _raw_gdm_add(channel_id, user_id): async def _raw_gdm_add(channel_id, user_id):
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO group_dm_members (id, member_id) INSERT INTO group_dm_members (id, member_id)
VALUES ($1, $2) VALUES ($1, $2)
""", channel_id, user_id) """,
channel_id,
user_id,
)
async def _raw_gdm_remove(channel_id, user_id): async def _raw_gdm_remove(channel_id, user_id):
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM group_dm_members DELETE FROM group_dm_members
WHERE id = $1 AND member_id = $2 WHERE id = $1 AND member_id = $2
""", channel_id, user_id) """,
channel_id,
user_id,
)
async def gdm_create(user_id, peer_id) -> int: async def gdm_create(user_id, peer_id) -> int:
@ -53,24 +61,32 @@ async def gdm_create(user_id, peer_id) -> int:
""" """
channel_id = get_snowflake() channel_id = get_snowflake()
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO channels (id, channel_type) INSERT INTO channels (id, channel_type)
VALUES ($1, $2) VALUES ($1, $2)
""", channel_id, ChannelType.GROUP_DM.value) """,
channel_id,
ChannelType.GROUP_DM.value,
)
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO group_dm_channels (id, owner_id, name, icon) INSERT INTO group_dm_channels (id, owner_id, name, icon)
VALUES ($1, $2, NULL, NULL) VALUES ($1, $2, NULL, NULL)
""", channel_id, user_id) """,
channel_id,
user_id,
)
await _raw_gdm_add(channel_id, user_id) await _raw_gdm_add(channel_id, user_id)
await _raw_gdm_add(channel_id, peer_id) await _raw_gdm_add(channel_id, peer_id)
await app.dispatcher.sub('channel', channel_id, user_id) await app.dispatcher.sub("channel", channel_id, user_id)
await app.dispatcher.sub('channel', channel_id, peer_id) await app.dispatcher.sub("channel", channel_id, peer_id)
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch('channel', channel_id, 'CHANNEL_CREATE', chan) await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_CREATE", chan)
return channel_id return channel_id
@ -89,17 +105,16 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
# the reasoning behind gdm_recipient_view is in its docstring. # the reasoning behind gdm_recipient_view is in its docstring.
await app.dispatcher.dispatch( await app.dispatcher.dispatch(
'user', peer_id, 'CHANNEL_CREATE', gdm_recipient_view(chan, peer_id)) "user", peer_id, "CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)
)
await app.dispatcher.dispatch( await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
'channel', channel_id, 'CHANNEL_UPDATE', chan)
await app.dispatcher.sub('channel', peer_id) await app.dispatcher.sub("channel", peer_id)
if user_id: if user_id:
await send_sys_message( await send_sys_message(
app, channel_id, MessageType.RECIPIENT_ADD, app, channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id
user_id, peer_id
) )
@ -116,22 +131,22 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch( await app.dispatcher.dispatch(
'user', peer_id, 'CHANNEL_DELETE', gdm_recipient_view(chan, user_id)) "user", peer_id, "CHANNEL_DELETE", gdm_recipient_view(chan, user_id)
)
await app.dispatcher.unsub('channel', peer_id) await app.dispatcher.unsub("channel", peer_id)
await app.dispatcher.dispatch( await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_RECIPIENT_REMOVE', { "channel",
'channel_id': str(channel_id), channel_id,
'user': await app.storage.get_user(peer_id) "CHANNEL_RECIPIENT_REMOVE",
} {"channel_id": str(channel_id), "user": await app.storage.get_user(peer_id)},
) )
author_id = peer_id if user_id is None else user_id author_id = peer_id if user_id is None else user_id
await send_sys_message( await send_sys_message(
app, channel_id, MessageType.RECIPIENT_REMOVE, app, channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id
author_id, peer_id
) )
@ -139,40 +154,51 @@ async def gdm_destroy(channel_id):
"""Destroy a Group DM.""" """Destroy a Group DM."""
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM group_dm_members DELETE FROM group_dm_members
WHERE id = $1 WHERE id = $1
""", channel_id) """,
channel_id,
await app.db.execute("""
DELETE FROM group_dm_channels
WHERE id = $1
""", channel_id)
await app.db.execute("""
DELETE FROM channels
WHERE id = $1
""", channel_id)
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_DELETE', chan
) )
await app.dispatcher.remove('channel', channel_id) await app.db.execute(
"""
DELETE FROM group_dm_channels
WHERE id = $1
""",
channel_id,
)
await app.db.execute(
"""
DELETE FROM channels
WHERE id = $1
""",
channel_id,
)
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_DELETE", chan)
await app.dispatcher.remove("channel", channel_id)
async def gdm_is_member(channel_id: int, user_id: int) -> bool: async def gdm_is_member(channel_id: int, user_id: int) -> bool:
"""Return if the given user is a member of the Group DM.""" """Return if the given user is a member of the Group DM."""
row = await app.db.fetchval(""" row = await app.db.fetchval(
"""
SELECT id SELECT id
FROM group_dm_members FROM group_dm_members
WHERE id = $1 AND member_id = $2 WHERE id = $1 AND member_id = $2
""", channel_id, user_id) """,
channel_id,
user_id,
)
return row is not None return row is not None
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['PUT']) @bp.route("/<int:dm_chan>/recipients/<int:peer_id>", methods=["PUT"])
async def add_to_group_dm(dm_chan, peer_id): async def add_to_group_dm(dm_chan, peer_id):
"""Adds a member to a group dm OR creates a group dm.""" """Adds a member to a group dm OR creates a group dm."""
user_id = await token_check() user_id = await token_check()
@ -182,8 +208,7 @@ async def add_to_group_dm(dm_chan, peer_id):
# other_id is the peer of the dm if the given channel is a dm # other_id is the peer of the dm if the given channel is a dm
ctype, other_id = await channel_check( ctype, other_id = await channel_check(
user_id, dm_chan, user_id, dm_chan, only=[ChannelType.DM, ChannelType.GROUP_DM]
only=[ChannelType.DM, ChannelType.GROUP_DM]
) )
# check relationship with the given user id # check relationship with the given user id
@ -191,30 +216,24 @@ async def add_to_group_dm(dm_chan, peer_id):
friends = await app.user_storage.are_friends_with(user_id, peer_id) friends = await app.user_storage.are_friends_with(user_id, peer_id)
if not friends: if not friends:
raise BadRequest('Cant insert peer into dm') raise BadRequest("Cant insert peer into dm")
if ctype == ChannelType.DM: if ctype == ChannelType.DM:
dm_chan = await gdm_create( dm_chan = await gdm_create(user_id, other_id)
user_id, other_id
)
await gdm_add_recipient(dm_chan, peer_id, user_id=user_id) await gdm_add_recipient(dm_chan, peer_id, user_id=user_id)
return jsonify( return jsonify(await app.storage.get_channel(dm_chan))
await app.storage.get_channel(dm_chan)
)
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['DELETE']) @bp.route("/<int:dm_chan>/recipients/<int:peer_id>", methods=["DELETE"])
async def remove_from_group_dm(dm_chan, peer_id): async def remove_from_group_dm(dm_chan, peer_id):
"""Remove users from group dm.""" """Remove users from group dm."""
user_id = await token_check() user_id = await token_check()
_ctype, owner_id = await channel_check( _ctype, owner_id = await channel_check(user_id, dm_chan, only=ChannelType.GROUP_DM)
user_id, dm_chan, only=ChannelType.GROUP_DM
)
if owner_id != user_id: if owner_id != user_id:
raise Forbidden('You are now the owner of the group DM') raise Forbidden("You are now the owner of the group DM")
await gdm_remove_recipient(dm_chan, peer_id) await gdm_remove_recipient(dm_chan, peer_id)
return '', 204 return "", 204

View File

@ -30,15 +30,13 @@ from ..snowflake import get_snowflake
from .auth import token_check from .auth import token_check
from litecord.blueprints.dm_channels import ( from litecord.blueprints.dm_channels import gdm_create, gdm_add_recipient
gdm_create, gdm_add_recipient
)
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('dms', __name__) bp = Blueprint("dms", __name__)
@bp.route('/@me/channels', methods=['GET']) @bp.route("/@me/channels", methods=["GET"])
async def get_dms(): async def get_dms():
"""Get the open DMs for the user.""" """Get the open DMs for the user."""
user_id = await token_check() user_id = await token_check()
@ -53,11 +51,15 @@ async def try_dm_state(user_id: int, dm_id: int):
Does not do anything if the user is already Does not do anything if the user is already
in the dm state. in the dm state.
""" """
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO dm_channel_state (user_id, dm_id) INSERT INTO dm_channel_state (user_id, dm_id)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
""", user_id, dm_id) """,
user_id,
dm_id,
)
async def jsonify_dm(dm_id: int, user_id: int): async def jsonify_dm(dm_id: int, user_id: int):
@ -69,12 +71,16 @@ async def create_dm(user_id, recipient_id):
"""Create a new dm with a user, """Create a new dm with a user,
or get the existing DM id if it already exists.""" or get the existing DM id if it already exists."""
dm_id = await app.db.fetchval(""" dm_id = await app.db.fetchval(
"""
SELECT id SELECT id
FROM dm_channels FROM dm_channels
WHERE (party1_id = $1 OR party2_id = $1) AND WHERE (party1_id = $1 OR party2_id = $1) AND
(party1_id = $2 OR party2_id = $2) (party1_id = $2 OR party2_id = $2)
""", user_id, recipient_id) """,
user_id,
recipient_id,
)
if dm_id: if dm_id:
return await jsonify_dm(dm_id, user_id) return await jsonify_dm(dm_id, user_id)
@ -82,15 +88,24 @@ async def create_dm(user_id, recipient_id):
# if no dm was found, create a new one # if no dm was found, create a new one
dm_id = get_snowflake() dm_id = get_snowflake()
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO channels (id, channel_type) INSERT INTO channels (id, channel_type)
VALUES ($1, $2) VALUES ($1, $2)
""", dm_id, ChannelType.DM.value) """,
dm_id,
ChannelType.DM.value,
)
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO dm_channels (id, party1_id, party2_id) INSERT INTO dm_channels (id, party1_id, party2_id)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
""", dm_id, user_id, recipient_id) """,
dm_id,
user_id,
recipient_id,
)
# the dm state is something we use # the dm state is something we use
# to give the currently "open dms" # to give the currently "open dms"
@ -103,24 +118,24 @@ async def create_dm(user_id, recipient_id):
return await jsonify_dm(dm_id, user_id) return await jsonify_dm(dm_id, user_id)
@bp.route('/@me/channels', methods=['POST']) @bp.route("/@me/channels", methods=["POST"])
async def start_dm(): async def start_dm():
"""Create a DM with a user.""" """Create a DM with a user."""
user_id = await token_check() user_id = await token_check()
j = validate(await request.get_json(), CREATE_DM) j = validate(await request.get_json(), CREATE_DM)
recipient_id = j['recipient_id'] recipient_id = j["recipient_id"]
return await create_dm(user_id, recipient_id) return await create_dm(user_id, recipient_id)
@bp.route('/<int:p_user_id>/channels', methods=['POST']) @bp.route("/<int:p_user_id>/channels", methods=["POST"])
async def create_group_dm(p_user_id: int): async def create_group_dm(p_user_id: int):
"""Create a DM or a Group DM with user(s).""" """Create a DM or a Group DM with user(s)."""
user_id = await token_check() user_id = await token_check()
assert user_id == p_user_id assert user_id == p_user_id
j = validate(await request.get_json(), CREATE_GROUP_DM) j = validate(await request.get_json(), CREATE_GROUP_DM)
recipients = j['recipients'] recipients = j["recipients"]
if len(recipients) == 1: if len(recipients) == 1:
# its a group dm with 1 user... a dm! # its a group dm with 1 user... a dm!

View File

@ -23,37 +23,38 @@ from quart import Blueprint, jsonify, current_app as app
from ..auth import token_check from ..auth import token_check
bp = Blueprint('gateway', __name__) bp = Blueprint("gateway", __name__)
def get_gw(): def get_gw():
"""Get the gateway's web""" """Get the gateway's web"""
proto = 'wss://' if app.config['IS_SSL'] else 'ws://' proto = "wss://" if app.config["IS_SSL"] else "ws://"
return f'{proto}{app.config["WEBSOCKET_URL"]}' return f'{proto}{app.config["WEBSOCKET_URL"]}'
@bp.route('/gateway') @bp.route("/gateway")
def api_gateway(): def api_gateway():
"""Get the raw URL.""" """Get the raw URL."""
return jsonify({ return jsonify({"url": get_gw()})
'url': get_gw()
})
@bp.route('/gateway/bot') @bp.route("/gateway/bot")
async def api_gateway_bot(): async def api_gateway_bot():
user_id = await token_check() user_id = await token_check()
guild_count = await app.db.fetchval(""" guild_count = await app.db.fetchval(
"""
SELECT COUNT(*) SELECT COUNT(*)
FROM members FROM members
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """,
user_id,
)
shards = max(int(guild_count / 1000), 1) shards = max(int(guild_count / 1000), 1)
# get _ws.session ratelimit # get _ws.session ratelimit
ratelimit = app.ratelimiter.get_ratelimit('_ws.session') ratelimit = app.ratelimiter.get_ratelimit("_ws.session")
bucket = ratelimit.get_bucket(user_id) bucket = ratelimit.get_bucket(user_id)
# timestamp of bucket reset # timestamp of bucket reset
@ -62,13 +63,14 @@ async def api_gateway_bot():
# how many seconds until bucket reset # how many seconds until bucket reset
reset_after_ts = reset_ts - time.time() reset_after_ts = reset_ts - time.time()
return jsonify({ return jsonify(
'url': get_gw(), {
'shards': shards, "url": get_gw(),
"shards": shards,
'session_start_limit': { "session_start_limit": {
'total': bucket.requests, "total": bucket.requests,
'remaining': bucket._tokens, "remaining": bucket._tokens,
'reset_after': int(reset_after_ts * 1000), "reset_after": int(reset_after_ts * 1000),
},
} }
}) )

View File

@ -23,5 +23,4 @@ from .channels import bp as guild_channels
from .mod import bp as guild_mod from .mod import bp as guild_mod
from .emoji import bp as guild_emoji from .emoji import bp as guild_emoji
__all__ = ['guild_roles', 'guild_members', 'guild_channels', 'guild_mod', __all__ = ["guild_roles", "guild_members", "guild_channels", "guild_mod", "guild_emoji"]
'guild_emoji']

View File

@ -25,23 +25,23 @@ from litecord.errors import BadRequest
from litecord.enums import ChannelType from litecord.enums import ChannelType
from litecord.blueprints.guild.roles import gen_pairs from litecord.blueprints.guild.roles import gen_pairs
from litecord.schemas import ( from litecord.schemas import validate, ROLE_UPDATE_POSITION, CHAN_CREATE
validate, ROLE_UPDATE_POSITION, CHAN_CREATE from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
)
from litecord.blueprints.checks import (
guild_check, guild_owner_check, guild_perm_check
)
bp = Blueprint('guild_channels', __name__) bp = Blueprint("guild_channels", __name__)
async def _specific_chan_create(channel_id, ctype, **kwargs): async def _specific_chan_create(channel_id, ctype, **kwargs):
if ctype == ChannelType.GUILD_TEXT: if ctype == ChannelType.GUILD_TEXT:
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO guild_text_channels (id, topic) INSERT INTO guild_text_channels (id, topic)
VALUES ($1, $2) VALUES ($1, $2)
""", channel_id, kwargs.get('topic', '')) """,
channel_id,
kwargs.get("topic", ""),
)
elif ctype == ChannelType.GUILD_VOICE: elif ctype == ChannelType.GUILD_VOICE:
await app.db.execute( await app.db.execute(
""" """
@ -49,34 +49,48 @@ async def _specific_chan_create(channel_id, ctype, **kwargs):
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
""", """,
channel_id, channel_id,
kwargs.get('bitrate', 64), kwargs.get("bitrate", 64),
kwargs.get('user_limit', 0) kwargs.get("user_limit", 0),
) )
async def create_guild_channel(guild_id: int, channel_id: int, async def create_guild_channel(
ctype: ChannelType, **kwargs): guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
):
"""Create a channel in a guild.""" """Create a channel in a guild."""
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO channels (id, channel_type) INSERT INTO channels (id, channel_type)
VALUES ($1, $2) VALUES ($1, $2)
""", channel_id, ctype.value) """,
channel_id,
ctype.value,
)
# calc new pos # calc new pos
max_pos = await app.db.fetchval(""" max_pos = await app.db.fetchval(
"""
SELECT MAX(position) SELECT MAX(position)
FROM guild_channels FROM guild_channels
WHERE guild_id = $1 WHERE guild_id = $1
""", guild_id) """,
guild_id,
)
# account for the first channel in a guild too # account for the first channel in a guild too
max_pos = max_pos or 0 max_pos = max_pos or 0
# all channels go to guild_channels # all channels go to guild_channels
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO guild_channels (id, guild_id, name, position) INSERT INTO guild_channels (id, guild_id, name, position)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
""", channel_id, guild_id, kwargs['name'], max_pos + 1) """,
channel_id,
guild_id,
kwargs["name"],
max_pos + 1,
)
# the rest of sql magic is dependant on the channel # the rest of sql magic is dependant on the channel
# we're creating (a text or voice or category), # we're creating (a text or voice or category),
@ -84,35 +98,32 @@ async def create_guild_channel(guild_id: int, channel_id: int,
await _specific_chan_create(channel_id, ctype, **kwargs) await _specific_chan_create(channel_id, ctype, **kwargs)
@bp.route('/<int:guild_id>/channels', methods=['GET']) @bp.route("/<int:guild_id>/channels", methods=["GET"])
async def get_guild_channels(guild_id): async def get_guild_channels(guild_id):
"""Get the list of channels in a guild.""" """Get the list of channels in a guild."""
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
return jsonify( return jsonify(await app.storage.get_channel_data(guild_id))
await app.storage.get_channel_data(guild_id))
@bp.route('/<int:guild_id>/channels', methods=['POST']) @bp.route("/<int:guild_id>/channels", methods=["POST"])
async def create_channel(guild_id): async def create_channel(guild_id):
"""Create a channel in a guild.""" """Create a channel in a guild."""
user_id = await token_check() user_id = await token_check()
j = validate(await request.get_json(), CHAN_CREATE) j = validate(await request.get_json(), CHAN_CREATE)
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_channels') await guild_perm_check(user_id, guild_id, "manage_channels")
channel_type = j.get('type', ChannelType.GUILD_TEXT) channel_type = j.get("type", ChannelType.GUILD_TEXT)
channel_type = ChannelType(channel_type) channel_type = ChannelType(channel_type)
if channel_type not in (ChannelType.GUILD_TEXT, if channel_type not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE):
ChannelType.GUILD_VOICE): raise BadRequest("Invalid channel type")
raise BadRequest('Invalid channel type')
new_channel_id = get_snowflake() new_channel_id = get_snowflake()
await create_guild_channel( await create_guild_channel(guild_id, new_channel_id, channel_type, **j)
guild_id, new_channel_id, channel_type, **j)
# TODO: do a better method # TODO: do a better method
# subscribe the currently subscribed users to the new channel # subscribe the currently subscribed users to the new channel
@ -120,14 +131,13 @@ async def create_channel(guild_id):
# since GuildDispatcher calls Storage.get_channel_ids, # since GuildDispatcher calls Storage.get_channel_ids,
# it will subscribe all users to the newly created channel. # it will subscribe all users to the newly created channel.
guild_pubsub = app.dispatcher.backends['guild'] guild_pubsub = app.dispatcher.backends["guild"]
user_ids = guild_pubsub.state[guild_id] user_ids = guild_pubsub.state[guild_id]
for uid in user_ids: for uid in user_ids:
await app.dispatcher.sub('guild', guild_id, uid) await app.dispatcher.sub("guild", guild_id, uid)
chan = await app.storage.get_channel(new_channel_id) chan = await app.storage.get_channel(new_channel_id)
await app.dispatcher.dispatch_guild( await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_CREATE", chan)
guild_id, 'CHANNEL_CREATE', chan)
return jsonify(chan) return jsonify(chan)
@ -135,7 +145,7 @@ async def _chan_update_dispatch(guild_id: int, channel_id: int):
"""Fetch new information about the channel and dispatch """Fetch new information about the channel and dispatch
a single CHANNEL_UPDATE event to the guild.""" a single CHANNEL_UPDATE event to the guild."""
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_UPDATE', chan) await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", chan)
async def _do_single_swap(guild_id: int, pair: tuple): async def _do_single_swap(guild_id: int, pair: tuple):
@ -149,13 +159,14 @@ async def _do_single_swap(guild_id: int, pair: tuple):
conn = await app.db.acquire() conn = await app.db.acquire()
async with conn.transaction(): async with conn.transaction():
await conn.executemany(""" await conn.executemany(
"""
UPDATE guild_channels UPDATE guild_channels
SET position = $1 SET position = $1
WHERE id = $2 AND guild_id = $3 WHERE id = $2 AND guild_id = $3
""", [ """,
(new_pos_1, channel_1, guild_id), [(new_pos_1, channel_1, guild_id), (new_pos_2, channel_2, guild_id)],
(new_pos_2, channel_2, guild_id)]) )
await app.db.release(conn) await app.db.release(conn)
@ -173,30 +184,26 @@ async def _do_channel_swaps(guild_id: int, swap_pairs: list):
await _do_single_swap(guild_id, pair) await _do_single_swap(guild_id, pair)
@bp.route('/<int:guild_id>/channels', methods=['PATCH']) @bp.route("/<int:guild_id>/channels", methods=["PATCH"])
async def modify_channel_pos(guild_id): async def modify_channel_pos(guild_id):
"""Change positions of channels in a guild.""" """Change positions of channels in a guild."""
user_id = await token_check() user_id = await token_check()
await guild_owner_check(user_id, guild_id) await guild_owner_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_channels') await guild_perm_check(user_id, guild_id, "manage_channels")
# same thing as guild.roles, so we use # same thing as guild.roles, so we use
# the same schema and all. # the same schema and all.
raw_j = await request.get_json() raw_j = await request.get_json()
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION) j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
j = j['roles'] j = j["roles"]
channels = await app.storage.get_channel_data(guild_id) channels = await app.storage.get_channel_data(guild_id)
channel_positions = {chan['position']: int(chan['id']) channel_positions = {chan["position"]: int(chan["id"]) for chan in channels}
for chan in channels}
swap_pairs = gen_pairs( swap_pairs = gen_pairs(j, channel_positions)
j,
channel_positions
)
await _do_channel_swaps(guild_id, swap_pairs) await _do_channel_swaps(guild_id, swap_pairs)
return '', 204 return "", 204

View File

@ -27,78 +27,85 @@ from litecord.types import KILOBYTES
from litecord.images import parse_data_uri from litecord.images import parse_data_uri
from litecord.errors import BadRequest from litecord.errors import BadRequest
bp = Blueprint('guild.emoji', __name__) bp = Blueprint("guild.emoji", __name__)
async def _dispatch_emojis(guild_id): async def _dispatch_emojis(guild_id):
"""Dispatch a Guild Emojis Update payload to a guild.""" """Dispatch a Guild Emojis Update payload to a guild."""
await app.dispatcher.dispatch('guild', guild_id, 'GUILD_EMOJIS_UPDATE', { await app.dispatcher.dispatch(
'guild_id': str(guild_id), "guild",
'emojis': await app.storage.get_guild_emojis(guild_id) guild_id,
}) "GUILD_EMOJIS_UPDATE",
{
"guild_id": str(guild_id),
"emojis": await app.storage.get_guild_emojis(guild_id),
},
)
@bp.route('/<int:guild_id>/emojis', methods=['GET']) @bp.route("/<int:guild_id>/emojis", methods=["GET"])
async def _get_guild_emoji(guild_id): async def _get_guild_emoji(guild_id):
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
return jsonify( return jsonify(await app.storage.get_guild_emojis(guild_id))
await app.storage.get_guild_emojis(guild_id)
)
@bp.route('/<int:guild_id>/emojis/<int:emoji_id>', methods=['GET']) @bp.route("/<int:guild_id>/emojis/<int:emoji_id>", methods=["GET"])
async def _get_guild_emoji_one(guild_id, emoji_id): async def _get_guild_emoji_one(guild_id, emoji_id):
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
return jsonify( return jsonify(await app.storage.get_emoji(emoji_id))
await app.storage.get_emoji(emoji_id)
)
async def _guild_emoji_size_check(guild_id: int, mime: str): async def _guild_emoji_size_check(guild_id: int, mime: str):
limit = 50 limit = 50
if await app.storage.has_feature(guild_id, 'MORE_EMOJI'): if await app.storage.has_feature(guild_id, "MORE_EMOJI"):
limit = 200 limit = 200
# NOTE: I'm assuming you can have 200 animated emojis. # NOTE: I'm assuming you can have 200 animated emojis.
select_animated = mime == 'image/gif' select_animated = mime == "image/gif"
total_emoji = await app.db.fetchval(""" total_emoji = await app.db.fetchval(
"""
SELECT COUNT(*) FROM guild_emoji SELECT COUNT(*) FROM guild_emoji
WHERE guild_id = $1 AND animated = $2 WHERE guild_id = $1 AND animated = $2
""", guild_id, select_animated) """,
guild_id,
select_animated,
)
if total_emoji >= limit: if total_emoji >= limit:
# TODO: really return a BadRequest? needs more looking. # TODO: really return a BadRequest? needs more looking.
raise BadRequest(f'too many emoji ({limit})') raise BadRequest(f"too many emoji ({limit})")
@bp.route('/<int:guild_id>/emojis', methods=['POST']) @bp.route("/<int:guild_id>/emojis", methods=["POST"])
async def _put_emoji(guild_id): async def _put_emoji(guild_id):
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_emojis') await guild_perm_check(user_id, guild_id, "manage_emojis")
j = validate(await request.get_json(), NEW_EMOJI) j = validate(await request.get_json(), NEW_EMOJI)
# we have to parse it before passing on so that we know which # we have to parse it before passing on so that we know which
# size to check. # size to check.
mime, _ = parse_data_uri(j['image']) mime, _ = parse_data_uri(j["image"])
await _guild_emoji_size_check(guild_id, mime) await _guild_emoji_size_check(guild_id, mime)
emoji_id = get_snowflake() emoji_id = get_snowflake()
icon = await app.icons.put( icon = await app.icons.put(
'emoji', emoji_id, j['image'], "emoji",
emoji_id,
j["image"],
# limits to emojis # limits to emojis
bsize=128 * KILOBYTES, size=(128, 128) bsize=128 * KILOBYTES,
size=(128, 128),
) )
if not icon: if not icon:
return '', 400 return "", 400
# TODO: better way to detect animated emoji rather than just gifs, # TODO: better way to detect animated emoji rather than just gifs,
# maybe a list perhaps? # maybe a list perhaps?
@ -109,25 +116,25 @@ async def _put_emoji(guild_id):
VALUES VALUES
($1, $2, $3, $4, $5, $6) ($1, $2, $3, $4, $5, $6)
""", """,
emoji_id, guild_id, user_id, emoji_id,
j['name'], guild_id,
user_id,
j["name"],
icon.icon_hash, icon.icon_hash,
icon.mime == 'image/gif' icon.mime == "image/gif",
) )
await _dispatch_emojis(guild_id) await _dispatch_emojis(guild_id)
return jsonify( return jsonify(await app.storage.get_emoji(emoji_id))
await app.storage.get_emoji(emoji_id)
)
@bp.route('/<int:guild_id>/emojis/<int:emoji_id>', methods=['PATCH']) @bp.route("/<int:guild_id>/emojis/<int:emoji_id>", methods=["PATCH"])
async def _patch_emoji(guild_id, emoji_id): async def _patch_emoji(guild_id, emoji_id):
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_emojis') await guild_perm_check(user_id, guild_id, "manage_emojis")
j = validate(await request.get_json(), PATCH_EMOJI) j = validate(await request.get_json(), PATCH_EMOJI)
emoji = await app.storage.get_emoji(emoji_id) emoji = await app.storage.get_emoji(emoji_id)
@ -135,34 +142,39 @@ async def _patch_emoji(guild_id, emoji_id):
# if emoji.name is still the same, we don't update anything # if emoji.name is still the same, we don't update anything
# or send ane events, just return the same emoji we'd send # or send ane events, just return the same emoji we'd send
# as if we updated it. # as if we updated it.
if j['name'] == emoji['name']: if j["name"] == emoji["name"]:
return jsonify(emoji) return jsonify(emoji)
await app.db.execute(""" await app.db.execute(
"""
UPDATE guild_emoji UPDATE guild_emoji
SET name = $1 SET name = $1
WHERE id = $2 WHERE id = $2
""", j['name'], emoji_id) """,
j["name"],
emoji_id,
)
await _dispatch_emojis(guild_id) await _dispatch_emojis(guild_id)
return jsonify( return jsonify(await app.storage.get_emoji(emoji_id))
await app.storage.get_emoji(emoji_id)
)
@bp.route('/<int:guild_id>/emojis/<int:emoji_id>', methods=['DELETE']) @bp.route("/<int:guild_id>/emojis/<int:emoji_id>", methods=["DELETE"])
async def _del_emoji(guild_id, emoji_id): async def _del_emoji(guild_id, emoji_id):
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_emojis') await guild_perm_check(user_id, guild_id, "manage_emojis")
# TODO: check if actually deleted # TODO: check if actually deleted
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM guild_emoji DELETE FROM guild_emoji
WHERE id = $2 WHERE id = $2
""", emoji_id) """,
emoji_id,
)
await _dispatch_emojis(guild_id) await _dispatch_emojis(guild_id)
return '', 204 return "", 204

View File

@ -22,18 +22,14 @@ from quart import Blueprint, request, current_app as app, jsonify
from litecord.blueprints.auth import token_check from litecord.blueprints.auth import token_check
from litecord.errors import BadRequest from litecord.errors import BadRequest
from litecord.schemas import ( from litecord.schemas import validate, MEMBER_UPDATE
validate, MEMBER_UPDATE
)
from litecord.blueprints.checks import ( from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
guild_check, guild_owner_check, guild_perm_check
)
bp = Blueprint('guild_members', __name__) bp = Blueprint("guild_members", __name__)
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['GET']) @bp.route("/<int:guild_id>/members/<int:member_id>", methods=["GET"])
async def get_guild_member(guild_id, member_id): async def get_guild_member(guild_id, member_id):
"""Get a member's information in a guild.""" """Get a member's information in a guild."""
user_id = await token_check() user_id = await token_check()
@ -42,7 +38,7 @@ async def get_guild_member(guild_id, member_id):
return jsonify(member) return jsonify(member)
@bp.route('/<int:guild_id>/members', methods=['GET']) @bp.route("/<int:guild_id>/members", methods=["GET"])
async def get_members(guild_id): async def get_members(guild_id):
"""Get members inside a guild.""" """Get members inside a guild."""
user_id = await token_check() user_id = await token_check()
@ -50,34 +46,41 @@ async def get_members(guild_id):
j = await request.get_json() j = await request.get_json()
limit, after = int(j.get('limit', 1)), j.get('after', 0) limit, after = int(j.get("limit", 1)), j.get("after", 0)
if limit < 1 or limit > 1000: if limit < 1 or limit > 1000:
raise BadRequest('limit not in 1-1000 range') raise BadRequest("limit not in 1-1000 range")
user_ids = await app.db.fetch(f""" user_ids = await app.db.fetch(
f"""
SELECT user_id SELECT user_id
WHERE guild_id = $1, user_id > $2 WHERE guild_id = $1, user_id > $2
LIMIT {limit} LIMIT {limit}
ORDER BY user_id ASC ORDER BY user_id ASC
""", guild_id, after) """,
guild_id,
after,
)
user_ids = [r[0] for r in user_ids] user_ids = [r[0] for r in user_ids]
members = await app.storage.get_member_multi(guild_id, user_ids) members = await app.storage.get_member_multi(guild_id, user_ids)
return jsonify(members) return jsonify(members)
async def _update_member_roles(guild_id: int, member_id: int, async def _update_member_roles(guild_id: int, member_id: int, wanted_roles: set):
wanted_roles: set):
"""Update the roles a member has.""" """Update the roles a member has."""
# first, fetch all current roles # first, fetch all current roles
roles = await app.db.fetch(""" roles = await app.db.fetch(
"""
SELECT role_id from member_roles SELECT role_id from member_roles
WHERE guild_id = $1 AND user_id = $2 WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id) """,
guild_id,
member_id,
)
roles = [r['role_id'] for r in roles] roles = [r["role_id"] for r in roles]
roles = set(roles) roles = set(roles)
wanted_roles = set(wanted_roles) wanted_roles = set(wanted_roles)
@ -96,26 +99,30 @@ async def _update_member_roles(guild_id: int, member_id: int,
async with conn.transaction(): async with conn.transaction():
# add roles # add roles
await app.db.executemany(""" await app.db.executemany(
"""
INSERT INTO member_roles (user_id, guild_id, role_id) INSERT INTO member_roles (user_id, guild_id, role_id)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
""", [(member_id, guild_id, role_id) """,
for role_id in added_roles]) [(member_id, guild_id, role_id) for role_id in added_roles],
)
# remove roles # remove roles
await app.db.executemany(""" await app.db.executemany(
"""
DELETE FROM member_roles DELETE FROM member_roles
WHERE WHERE
user_id = $1 user_id = $1
AND guild_id = $2 AND guild_id = $2
AND role_id = $3 AND role_id = $3
""", [(member_id, guild_id, role_id) """,
for role_id in removed_roles]) [(member_id, guild_id, role_id) for role_id in removed_roles],
)
await app.db.release(conn) await app.db.release(conn)
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['PATCH']) @bp.route("/<int:guild_id>/members/<int:member_id>", methods=["PATCH"])
async def modify_guild_member(guild_id, member_id): async def modify_guild_member(guild_id, member_id):
"""Modify a members' information in a guild.""" """Modify a members' information in a guild."""
user_id = await token_check() user_id = await token_check()
@ -124,96 +131,112 @@ async def modify_guild_member(guild_id, member_id):
j = validate(await request.get_json(), MEMBER_UPDATE) j = validate(await request.get_json(), MEMBER_UPDATE)
nick_flag = False nick_flag = False
if 'nick' in j: if "nick" in j:
await guild_perm_check(user_id, guild_id, 'manage_nicknames') await guild_perm_check(user_id, guild_id, "manage_nicknames")
nick = j['nick'] or None nick = j["nick"] or None
await app.db.execute(""" await app.db.execute(
"""
UPDATE members UPDATE members
SET nickname = $1 SET nickname = $1
WHERE user_id = $2 AND guild_id = $3 WHERE user_id = $2 AND guild_id = $3
""", nick, member_id, guild_id) """,
nick,
member_id,
guild_id,
)
nick_flag = True nick_flag = True
if 'mute' in j: if "mute" in j:
await guild_perm_check(user_id, guild_id, 'mute_members') await guild_perm_check(user_id, guild_id, "mute_members")
await app.db.execute(""" await app.db.execute(
"""
UPDATE members UPDATE members
SET muted = $1 SET muted = $1
WHERE user_id = $2 AND guild_id = $3 WHERE user_id = $2 AND guild_id = $3
""", j['mute'], member_id, guild_id) """,
j["mute"],
member_id,
guild_id,
)
if 'deaf' in j: if "deaf" in j:
await guild_perm_check(user_id, guild_id, 'deafen_members') await guild_perm_check(user_id, guild_id, "deafen_members")
await app.db.execute(""" await app.db.execute(
"""
UPDATE members UPDATE members
SET deafened = $1 SET deafened = $1
WHERE user_id = $2 AND guild_id = $3 WHERE user_id = $2 AND guild_id = $3
""", j['deaf'], member_id, guild_id) """,
j["deaf"],
member_id,
guild_id,
)
if 'channel_id' in j: if "channel_id" in j:
# TODO: check MOVE_MEMBERS and CONNECT to the channel # TODO: check MOVE_MEMBERS and CONNECT to the channel
# TODO: change the member's voice channel # TODO: change the member's voice channel
pass pass
if 'roles' in j: if "roles" in j:
await guild_perm_check(user_id, guild_id, 'manage_roles') await guild_perm_check(user_id, guild_id, "manage_roles")
await _update_member_roles(guild_id, member_id, j['roles']) await _update_member_roles(guild_id, member_id, j["roles"])
member = await app.storage.get_member_data_one(guild_id, member_id) member = await app.storage.get_member_data_one(guild_id, member_id)
member.pop('joined_at') member.pop("joined_at")
# call pres_update for role and nick changes. # call pres_update for role and nick changes.
partial = { partial = {"roles": member["roles"]}
'roles': member['roles']
}
if nick_flag: if nick_flag:
partial['nick'] = j['nick'] partial["nick"] = j["nick"]
await app.dispatcher.dispatch( await app.dispatcher.dispatch(
'lazy_guild', guild_id, 'pres_update', user_id, partial) "lazy_guild", guild_id, "pres_update", user_id, partial
)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{ await app.dispatcher.dispatch_guild(
'guild_id': str(guild_id) guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
}, **member}) )
return '', 204 return "", 204
@bp.route('/<int:guild_id>/members/@me/nick', methods=['PATCH']) @bp.route("/<int:guild_id>/members/@me/nick", methods=["PATCH"])
async def update_nickname(guild_id): async def update_nickname(guild_id):
"""Update a member's nickname in a guild.""" """Update a member's nickname in a guild."""
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
j = validate(await request.get_json(), { j = validate(await request.get_json(), {"nick": {"type": "nickname"}})
'nick': {'type': 'nickname'}
})
nick = j['nick'] or None nick = j["nick"] or None
await app.db.execute(""" await app.db.execute(
"""
UPDATE members UPDATE members
SET nickname = $1 SET nickname = $1
WHERE user_id = $2 AND guild_id = $3 WHERE user_id = $2 AND guild_id = $3
""", nick, user_id, guild_id) """,
nick,
user_id,
guild_id,
)
member = await app.storage.get_member_data_one(guild_id, user_id) member = await app.storage.get_member_data_one(guild_id, user_id)
member.pop('joined_at') member.pop("joined_at")
# call pres_update for nick changes, etc. # call pres_update for nick changes, etc.
await app.dispatcher.dispatch( await app.dispatcher.dispatch(
'lazy_guild', guild_id, 'pres_update', user_id, { "lazy_guild", guild_id, "pres_update", user_id, {"nick": j["nick"]}
'nick': j['nick'] )
})
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{ await app.dispatcher.dispatch_guild(
'guild_id': str(guild_id) guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
}, **member}) )
return j['nick'] return j["nick"]

View File

@ -24,33 +24,38 @@ from litecord.blueprints.checks import guild_perm_check
from litecord.schemas import validate, GUILD_PRUNE from litecord.schemas import validate, GUILD_PRUNE
bp = Blueprint('guild_moderation', __name__) bp = Blueprint("guild_moderation", __name__)
async def remove_member(guild_id: int, member_id: int): async def remove_member(guild_id: int, member_id: int):
"""Do common tasks related to deleting a member from the guild, """Do common tasks related to deleting a member from the guild,
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE.""" such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM members DELETE FROM members
WHERE guild_id = $1 AND user_id = $2 WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id) """,
guild_id,
member_id,
)
await app.dispatcher.dispatch_user_guild( await app.dispatcher.dispatch_user_guild(
member_id, guild_id, 'GUILD_DELETE', { member_id,
'guild_id': str(guild_id), guild_id,
'unavailable': False, "GUILD_DELETE",
}) {"guild_id": str(guild_id), "unavailable": False},
)
await app.dispatcher.unsub('guild', guild_id, member_id) await app.dispatcher.unsub("guild", guild_id, member_id)
await app.dispatcher.dispatch( await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
'lazy_guild', guild_id, 'remove_member', member_id)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', { await app.dispatcher.dispatch_guild(
'guild_id': str(guild_id), guild_id,
'user': await app.storage.get_user(member_id), "GUILD_MEMBER_REMOVE",
}) {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
)
async def remove_member_multi(guild_id: int, members: list): async def remove_member_multi(guild_id: int, members: list):
@ -59,84 +64,100 @@ async def remove_member_multi(guild_id: int, members: list):
await remove_member(guild_id, member_id) await remove_member(guild_id, member_id)
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['DELETE']) @bp.route("/<int:guild_id>/members/<int:member_id>", methods=["DELETE"])
async def kick_guild_member(guild_id, member_id): async def kick_guild_member(guild_id, member_id):
"""Remove a member from a guild.""" """Remove a member from a guild."""
user_id = await token_check() user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'kick_members') await guild_perm_check(user_id, guild_id, "kick_members")
await remove_member(guild_id, member_id) await remove_member(guild_id, member_id)
return '', 204 return "", 204
@bp.route('/<int:guild_id>/bans', methods=['GET']) @bp.route("/<int:guild_id>/bans", methods=["GET"])
async def get_bans(guild_id): async def get_bans(guild_id):
user_id = await token_check() user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'ban_members') await guild_perm_check(user_id, guild_id, "ban_members")
bans = await app.db.fetch(""" bans = await app.db.fetch(
"""
SELECT user_id, reason SELECT user_id, reason
FROM bans FROM bans
WHERE bans.guild_id = $1 WHERE bans.guild_id = $1
""", guild_id) """,
guild_id,
)
res = [] res = []
for ban in bans: for ban in bans:
res.append({ res.append(
'reason': ban['reason'], {
'user': await app.storage.get_user(ban['user_id']) "reason": ban["reason"],
}) "user": await app.storage.get_user(ban["user_id"]),
}
)
return jsonify(res) return jsonify(res)
@bp.route('/<int:guild_id>/bans/<int:member_id>', methods=['PUT']) @bp.route("/<int:guild_id>/bans/<int:member_id>", methods=["PUT"])
async def create_ban(guild_id, member_id): async def create_ban(guild_id, member_id):
user_id = await token_check() user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'ban_members') await guild_perm_check(user_id, guild_id, "ban_members")
j = await request.get_json() j = await request.get_json()
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO bans (guild_id, user_id, reason) INSERT INTO bans (guild_id, user_id, reason)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
""", guild_id, member_id, j.get('reason', '')) """,
guild_id,
member_id,
j.get("reason", ""),
)
await remove_member(guild_id, member_id) await remove_member(guild_id, member_id)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', { await app.dispatcher.dispatch_guild(
'guild_id': str(guild_id), guild_id,
'user': await app.storage.get_user(member_id) "GUILD_BAN_ADD",
}) {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
)
return '', 204 return "", 204
@bp.route('/<int:guild_id>/bans/<int:banned_id>', methods=['DELETE']) @bp.route("/<int:guild_id>/bans/<int:banned_id>", methods=["DELETE"])
async def remove_ban(guild_id, banned_id): async def remove_ban(guild_id, banned_id):
user_id = await token_check() user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'ban_members') await guild_perm_check(user_id, guild_id, "ban_members")
res = await app.db.execute(""" res = await app.db.execute(
"""
DELETE FROM bans DELETE FROM bans
WHERE guild_id = $1 AND user_id = $@ WHERE guild_id = $1 AND user_id = $@
""", guild_id, banned_id) """,
guild_id,
banned_id,
)
# we don't really need to dispatch GUILD_BAN_REMOVE # we don't really need to dispatch GUILD_BAN_REMOVE
# when no bans were actually removed. # when no bans were actually removed.
if res == 'DELETE 0': if res == "DELETE 0":
return '', 204 return "", 204
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', { await app.dispatcher.dispatch_guild(
'guild_id': str(guild_id), guild_id,
'user': await app.storage.get_user(banned_id) "GUILD_BAN_REMOVE",
}) {"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)},
)
return '', 204 return "", 204
async def get_prune(guild_id: int, days: int) -> list: async def get_prune(guild_id: int, days: int) -> list:
@ -146,23 +167,30 @@ async def get_prune(guild_id: int, days: int) -> list:
- don't have any roles. - don't have any roles.
""" """
# a good solution would be in pure sql. # a good solution would be in pure sql.
member_ids = await app.db.fetch(f""" member_ids = await app.db.fetch(
f"""
SELECT id SELECT id
FROM users FROM users
JOIN members JOIN members
ON members.guild_id = $1 AND members.user_id = users.id ON members.guild_id = $1 AND members.user_id = users.id
WHERE users.last_session < (now() - (interval '{days} days')) WHERE users.last_session < (now() - (interval '{days} days'))
""", guild_id) """,
guild_id,
)
member_ids = [r['id'] for r in member_ids] member_ids = [r["id"] for r in member_ids]
members = [] members = []
for member_id in member_ids: for member_id in member_ids:
role_count = await app.db.fetchval(""" role_count = await app.db.fetchval(
"""
SELECT COUNT(*) SELECT COUNT(*)
FROM member_roles FROM member_roles
WHERE guild_id = $1 AND user_id = $2 WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id) """,
guild_id,
member_id,
)
if role_count == 0: if role_count == 0:
members.append(member_id) members.append(member_id)
@ -170,33 +198,29 @@ async def get_prune(guild_id: int, days: int) -> list:
return members return members
@bp.route('/<int:guild_id>/prune', methods=['GET']) @bp.route("/<int:guild_id>/prune", methods=["GET"])
async def get_guild_prune_count(guild_id): async def get_guild_prune_count(guild_id):
user_id = await token_check() user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'kick_members') await guild_perm_check(user_id, guild_id, "kick_members")
j = validate(request.args, GUILD_PRUNE) j = validate(request.args, GUILD_PRUNE)
days = j['days'] days = j["days"]
member_ids = await get_prune(guild_id, days) member_ids = await get_prune(guild_id, days)
return jsonify({ return jsonify({"pruned": len(member_ids)})
'pruned': len(member_ids),
})
@bp.route('/<int:guild_id>/prune', methods=['POST']) @bp.route("/<int:guild_id>/prune", methods=["POST"])
async def begin_guild_prune(guild_id): async def begin_guild_prune(guild_id):
user_id = await token_check() user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'kick_members') await guild_perm_check(user_id, guild_id, "kick_members")
j = validate(request.args, GUILD_PRUNE) j = validate(request.args, GUILD_PRUNE)
days = j['days'] days = j["days"]
member_ids = await get_prune(guild_id, days) member_ids = await get_prune(guild_id, days)
app.loop.create_task(remove_member_multi(guild_id, member_ids)) app.loop.create_task(remove_member_multi(guild_id, member_ids))
return jsonify({ return jsonify({"pruned": len(member_ids)})
'pruned': len(member_ids)
})

View File

@ -24,12 +24,8 @@ from logbook import Logger
from litecord.auth import token_check from litecord.auth import token_check
from litecord.blueprints.checks import ( from litecord.blueprints.checks import guild_check, guild_perm_check
guild_check, guild_perm_check from litecord.schemas import validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
)
from litecord.schemas import (
validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
)
from litecord.snowflake import get_snowflake from litecord.snowflake import get_snowflake
from litecord.utils import dict_get from litecord.utils import dict_get
@ -37,22 +33,19 @@ from litecord.permissions import get_role_perms
DEFAULT_EVERYONE_PERMS = 104324161 DEFAULT_EVERYONE_PERMS = 104324161
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('guild_roles', __name__) bp = Blueprint("guild_roles", __name__)
@bp.route('/<int:guild_id>/roles', methods=['GET']) @bp.route("/<int:guild_id>/roles", methods=["GET"])
async def get_guild_roles(guild_id): async def get_guild_roles(guild_id):
"""Get all roles in a guild.""" """Get all roles in a guild."""
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
return jsonify( return jsonify(await app.storage.get_role_data(guild_id))
await app.storage.get_role_data(guild_id)
)
async def _maybe_lg(guild_id: int, event: str, async def _maybe_lg(guild_id: int, event: str, role, force: bool = False):
role, force: bool = False):
# sometimes we want to dispatch an event # sometimes we want to dispatch an event
# even if the role isn't hoisted # even if the role isn't hoisted
@ -61,11 +54,10 @@ async def _maybe_lg(guild_id: int, event: str,
# check if is a dict first because role_delete # check if is a dict first because role_delete
# only receives the role id. # only receives the role id.
if isinstance(role, dict) and not role['hoist'] and not force: if isinstance(role, dict) and not role["hoist"] and not force:
return return
await app.dispatcher.dispatch( await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
'lazy_guild', guild_id, event, role)
async def create_role(guild_id, name: str, **kwargs): async def create_role(guild_id, name: str, **kwargs):
@ -73,18 +65,20 @@ async def create_role(guild_id, name: str, **kwargs):
new_role_id = get_snowflake() new_role_id = get_snowflake()
everyone_perms = await get_role_perms(guild_id, guild_id) everyone_perms = await get_role_perms(guild_id, guild_id)
default_perms = dict_get(kwargs, 'default_perms', default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
everyone_perms.binary)
# update all roles so that we have space for pos 1, but without # update all roles so that we have space for pos 1, but without
# sending GUILD_ROLE_UPDATE for everyone # sending GUILD_ROLE_UPDATE for everyone
await app.db.execute(""" await app.db.execute(
"""
UPDATE roles UPDATE roles
SET SET
position = position + 1 position = position + 1
WHERE guild_id = $1 WHERE guild_id = $1
AND NOT (position = 0) AND NOT (position = 0)
""", guild_id) """,
guild_id,
)
await app.db.execute( await app.db.execute(
""" """
@ -95,42 +89,39 @@ async def create_role(guild_id, name: str, **kwargs):
new_role_id, new_role_id,
guild_id, guild_id,
name, name,
dict_get(kwargs, 'color', 0), dict_get(kwargs, "color", 0),
dict_get(kwargs, 'hoist', False), dict_get(kwargs, "hoist", False),
# always set ourselves on position 1 # always set ourselves on position 1
1, 1,
int(dict_get(kwargs, 'permissions', default_perms)), int(dict_get(kwargs, "permissions", default_perms)),
False, False,
dict_get(kwargs, 'mentionable', False) dict_get(kwargs, "mentionable", False),
) )
role = await app.storage.get_role(new_role_id, guild_id) role = await app.storage.get_role(new_role_id, guild_id)
# we need to update the lazy guild handlers for the newly created group # we need to update the lazy guild handlers for the newly created group
await _maybe_lg(guild_id, 'new_role', role) await _maybe_lg(guild_id, "new_role", role)
await app.dispatcher.dispatch_guild( await app.dispatcher.dispatch_guild(
guild_id, 'GUILD_ROLE_CREATE', { guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
'guild_id': str(guild_id), )
'role': role,
})
return role return role
@bp.route('/<int:guild_id>/roles', methods=['POST']) @bp.route("/<int:guild_id>/roles", methods=["POST"])
async def create_guild_role(guild_id: int): async def create_guild_role(guild_id: int):
"""Add a role to a guild""" """Add a role to a guild"""
user_id = await token_check() user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'manage_roles') await guild_perm_check(user_id, guild_id, "manage_roles")
# client can just send null # client can just send null
j = validate(await request.get_json() or {}, ROLE_CREATE) j = validate(await request.get_json() or {}, ROLE_CREATE)
role_name = j['name'] role_name = j["name"]
j.pop('name') j.pop("name")
role = await create_role(guild_id, role_name, **j) role = await create_role(guild_id, role_name, **j)
@ -141,12 +132,11 @@ async def _role_update_dispatch(role_id: int, guild_id: int):
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role.""" """Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
role = await app.storage.get_role(role_id, guild_id) role = await app.storage.get_role(role_id, guild_id)
await _maybe_lg(guild_id, 'role_pos_upd', role) await _maybe_lg(guild_id, "role_pos_upd", role)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', { await app.dispatcher.dispatch_guild(
'guild_id': str(guild_id), guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}
'role': role, )
})
return role return role
@ -166,17 +156,25 @@ async def _role_pairs_update(guild_id: int, pairs: list):
async with conn.transaction(): async with conn.transaction():
# update happens in a transaction # update happens in a transaction
# so we don't fuck it up # so we don't fuck it up
await conn.execute(""" await conn.execute(
"""
UPDATE roles UPDATE roles
SET position = $1 SET position = $1
WHERE roles.id = $2 WHERE roles.id = $2
""", new_pos_1, role_1) """,
new_pos_1,
role_1,
)
await conn.execute(""" await conn.execute(
"""
UPDATE roles UPDATE roles
SET position = $1 SET position = $1
WHERE roles.id = $2 WHERE roles.id = $2
""", new_pos_2, role_2) """,
new_pos_2,
role_2,
)
await app.db.release(conn) await app.db.release(conn)
@ -184,11 +182,15 @@ async def _role_pairs_update(guild_id: int, pairs: list):
await _role_update_dispatch(role_1, guild_id) await _role_update_dispatch(role_1, guild_id)
await _role_update_dispatch(role_2, guild_id) await _role_update_dispatch(role_2, guild_id)
PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]] PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
def gen_pairs(list_of_changes: List[Dict[str, int]],
current_state: Dict[int, int], def gen_pairs(
blacklist: List[int] = None) -> PairList: list_of_changes: List[Dict[str, int]],
current_state: Dict[int, int],
blacklist: List[int] = None,
) -> PairList:
"""Generate a list of pairs that, when applied to the database, """Generate a list of pairs that, when applied to the database,
will generate the desired state given in list_of_changes. will generate the desired state given in list_of_changes.
@ -230,8 +232,9 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
pairs = [] pairs = []
blacklist = blacklist or [] blacklist = blacklist or []
preferred_state = {element['id']: element['position'] preferred_state = {
for element in list_of_changes} element["id"]: element["position"] for element in list_of_changes
}
for blacklisted_id in blacklist: for blacklisted_id in blacklist:
preferred_state.pop(blacklisted_id) preferred_state.pop(blacklisted_id)
@ -239,7 +242,7 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
# for each change, we must find a matching change # for each change, we must find a matching change
# in the same list, so we can make a swap pair # in the same list, so we can make a swap pair
for change in list_of_changes: for change in list_of_changes:
element_1, new_pos_1 = change['id'], change['position'] element_1, new_pos_1 = change["id"], change["position"]
# check current pairs # check current pairs
# so we don't repeat an element # so we don't repeat an element
@ -267,36 +270,34 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
# if its being swapped to leave space, add it # if its being swapped to leave space, add it
# to the pairs list # to the pairs list
if new_pos_2 is not None: if new_pos_2 is not None:
pairs.append( pairs.append(((element_1, new_pos_1), (element_2, new_pos_2)))
((element_1, new_pos_1), (element_2, new_pos_2))
)
return pairs return pairs
@bp.route('/<int:guild_id>/roles', methods=['PATCH']) @bp.route("/<int:guild_id>/roles", methods=["PATCH"])
async def update_guild_role_positions(guild_id): async def update_guild_role_positions(guild_id):
"""Update the positions for a bunch of roles.""" """Update the positions for a bunch of roles."""
user_id = await token_check() user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'manage_roles') await guild_perm_check(user_id, guild_id, "manage_roles")
raw_j = await request.get_json() raw_j = await request.get_json()
# we need to do this hackiness because thats # we need to do this hackiness because thats
# cerberus for ya. # cerberus for ya.
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION) j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
# extract the list out # extract the list out
j = j['roles'] j = j["roles"]
log.debug('role stuff: {!r}', j) log.debug("role stuff: {!r}", j)
all_roles = await app.storage.get_role_data(guild_id) all_roles = await app.storage.get_role_data(guild_id)
# we'll have to calculate pairs of changing roles, # we'll have to calculate pairs of changing roles,
# then do the changes, etc. # then do the changes, etc.
roles_pos = {role['position']: int(role['id']) for role in all_roles} roles_pos = {role["position"]: int(role["id"]) for role in all_roles}
# TODO: check if the user can even change the roles in the first place, # TODO: check if the user can even change the roles in the first place,
# preferrably when we have a proper perms system. # preferrably when we have a proper perms system.
@ -306,10 +307,9 @@ async def update_guild_role_positions(guild_id):
pairs = gen_pairs( pairs = gen_pairs(
j, j,
roles_pos, roles_pos,
# always ignore people trying to change # always ignore people trying to change
# the @everyone's role position # the @everyone's role position
[guild_id] [guild_id],
) )
await _role_pairs_update(guild_id, pairs) await _role_pairs_update(guild_id, pairs)
@ -318,31 +318,36 @@ async def update_guild_role_positions(guild_id):
return jsonify(await app.storage.get_role_data(guild_id)) return jsonify(await app.storage.get_role_data(guild_id))
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['PATCH']) @bp.route("/<int:guild_id>/roles/<int:role_id>", methods=["PATCH"])
async def update_guild_role(guild_id, role_id): async def update_guild_role(guild_id, role_id):
"""Update a single role's information.""" """Update a single role's information."""
user_id = await token_check() user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'manage_roles') await guild_perm_check(user_id, guild_id, "manage_roles")
j = validate(await request.get_json(), ROLE_UPDATE) j = validate(await request.get_json(), ROLE_UPDATE)
# we only update ints on the db, not Permissions # we only update ints on the db, not Permissions
j['permissions'] = int(j['permissions']) j["permissions"] = int(j["permissions"])
for field in j: for field in j:
await app.db.execute(f""" await app.db.execute(
f"""
UPDATE roles UPDATE roles
SET {field} = $1 SET {field} = $1
WHERE roles.id = $2 AND roles.guild_id = $3 WHERE roles.id = $2 AND roles.guild_id = $3
""", j[field], role_id, guild_id) """,
j[field],
role_id,
guild_id,
)
role = await _role_update_dispatch(role_id, guild_id) role = await _role_update_dispatch(role_id, guild_id)
await _maybe_lg(guild_id, 'role_update', role, True) await _maybe_lg(guild_id, "role_update", role, True)
return jsonify(role) return jsonify(role)
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['DELETE']) @bp.route("/<int:guild_id>/roles/<int:role_id>", methods=["DELETE"])
async def delete_guild_role(guild_id, role_id): async def delete_guild_role(guild_id, role_id):
"""Delete a role. """Delete a role.
@ -350,21 +355,26 @@ async def delete_guild_role(guild_id, role_id):
""" """
user_id = await token_check() user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'manage_roles') await guild_perm_check(user_id, guild_id, "manage_roles")
res = await app.db.execute(""" res = await app.db.execute(
"""
DELETE FROM roles DELETE FROM roles
WHERE guild_id = $1 AND id = $2 WHERE guild_id = $1 AND id = $2
""", guild_id, role_id) """,
guild_id,
role_id,
)
if res == 'DELETE 0': if res == "DELETE 0":
return '', 204 return "", 204
await _maybe_lg(guild_id, 'role_delete', role_id, True) await _maybe_lg(guild_id, "role_delete", role_id, True)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_DELETE', { await app.dispatcher.dispatch_guild(
'guild_id': str(guild_id), guild_id,
'role_id': str(role_id), "GUILD_ROLE_DELETE",
}) {"guild_id": str(guild_id), "role_id": str(role_id)},
)
return '', 204 return "", 204

View File

@ -22,16 +22,17 @@ from typing import Optional, List
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
from litecord.blueprints.guild.channels import create_guild_channel from litecord.blueprints.guild.channels import create_guild_channel
from litecord.blueprints.guild.roles import ( from litecord.blueprints.guild.roles import create_role, DEFAULT_EVERYONE_PERMS
create_role, DEFAULT_EVERYONE_PERMS
)
from ..auth import token_check from ..auth import token_check
from ..snowflake import get_snowflake from ..snowflake import get_snowflake
from ..enums import ChannelType from ..enums import ChannelType
from ..schemas import ( from ..schemas import (
validate, GUILD_CREATE, GUILD_UPDATE, SEARCH_CHANNEL, validate,
VANITY_URL_PATCH GUILD_CREATE,
GUILD_UPDATE,
SEARCH_CHANNEL,
VANITY_URL_PATCH,
) )
from .channels import channel_ack from .channels import channel_ack
from .checks import guild_check, guild_owner_check, guild_perm_check from .checks import guild_check, guild_owner_check, guild_perm_check
@ -40,7 +41,7 @@ from litecord.errors import BadRequest
from litecord.permissions import get_permissions from litecord.permissions import get_permissions
bp = Blueprint('guilds', __name__) bp = Blueprint("guilds", __name__)
async def create_guild_settings(guild_id: int, user_id: int): async def create_guild_settings(guild_id: int, user_id: int):
@ -49,26 +50,38 @@ async def create_guild_settings(guild_id: int, user_id: int):
# new guild_settings are based off the currently # new guild_settings are based off the currently
# set guild settings (for the guild) # set guild settings (for the guild)
m_notifs = await app.db.fetchval(""" m_notifs = await app.db.fetchval(
"""
SELECT default_message_notifications SELECT default_message_notifications
FROM guilds FROM guilds
WHERE id = $1 WHERE id = $1
""", guild_id) """,
guild_id,
)
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO guild_settings INSERT INTO guild_settings
(user_id, guild_id, message_notifications) (user_id, guild_id, message_notifications)
VALUES VALUES
($1, $2, $3) ($1, $2, $3)
""", user_id, guild_id, m_notifs) """,
user_id,
guild_id,
m_notifs,
)
async def add_member(guild_id: int, user_id: int): async def add_member(guild_id: int, user_id: int):
"""Add a user to a guild.""" """Add a user to a guild."""
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO members (user_id, guild_id) INSERT INTO members (user_id, guild_id)
VALUES ($1, $2) VALUES ($1, $2)
""", user_id, guild_id) """,
user_id,
guild_id,
)
await create_guild_settings(guild_id, user_id) await create_guild_settings(guild_id, user_id)
@ -83,28 +96,28 @@ async def guild_create_roles_prep(guild_id: int, roles: list):
# are patches to the @everyone role # are patches to the @everyone role
everyone_patches = roles[0] everyone_patches = roles[0]
for field in everyone_patches: for field in everyone_patches:
await app.db.execute(f""" await app.db.execute(
f"""
UPDATE roles UPDATE roles
SET {field}={everyone_patches[field]} SET {field}={everyone_patches[field]}
WHERE roles.id = $1 WHERE roles.id = $1
""", guild_id) """,
guild_id,
)
default_perms = (everyone_patches.get('permissions') default_perms = everyone_patches.get("permissions") or DEFAULT_EVERYONE_PERMS
or DEFAULT_EVERYONE_PERMS)
# from the 2nd and forward, # from the 2nd and forward,
# should be treated as new roles # should be treated as new roles
for role in roles[1:]: for role in roles[1:]:
await create_role( await create_role(guild_id, role["name"], default_perms=default_perms, **role)
guild_id, role['name'], default_perms=default_perms, **role
)
async def guild_create_channels_prep(guild_id: int, channels: list): async def guild_create_channels_prep(guild_id: int, channels: list):
"""Create channels pre-guild create""" """Create channels pre-guild create"""
for channel_raw in channels: for channel_raw in channels:
channel_id = get_snowflake() channel_id = get_snowflake()
ctype = ChannelType(channel_raw['type']) ctype = ChannelType(channel_raw["type"])
await create_guild_channel(guild_id, channel_id, ctype) await create_guild_channel(guild_id, channel_id, ctype)
@ -114,37 +127,29 @@ def sanitize_icon(icon: Optional[str]) -> Optional[str]:
Defaults to a jpeg icon when the header isn't given. Defaults to a jpeg icon when the header isn't given.
""" """
if icon and icon.startswith('data'): if icon and icon.startswith("data"):
return icon return icon
return (f'data:image/jpeg;base64,{icon}' return f"data:image/jpeg;base64,{icon}" if icon else None
if icon
else None)
async def _general_guild_icon(scope: str, guild_id: int, async def _general_guild_icon(scope: str, guild_id: int, icon: str, **kwargs):
icon: str, **kwargs):
encoded = sanitize_icon(icon) encoded = sanitize_icon(icon)
icon_kwargs = { icon_kwargs = {"always_icon": True}
'always_icon': True
}
if 'size' in kwargs: if "size" in kwargs:
icon_kwargs['size'] = kwargs['size'] icon_kwargs["size"] = kwargs["size"]
return await app.icons.put( return await app.icons.put(scope, guild_id, encoded, **icon_kwargs)
scope, guild_id, encoded,
**icon_kwargs
)
async def put_guild_icon(guild_id: int, icon: Optional[str]): async def put_guild_icon(guild_id: int, icon: Optional[str]):
"""Insert a guild icon on the icon database.""" """Insert a guild icon on the icon database."""
return await _general_guild_icon('guild', guild_id, icon, size=(128, 128)) return await _general_guild_icon("guild", guild_id, icon, size=(128, 128))
@bp.route('', methods=['POST']) @bp.route("", methods=["POST"])
async def create_guild(): async def create_guild():
"""Create a new guild, assigning """Create a new guild, assigning
the user creating it as the owner and the user creating it as the owner and
@ -154,8 +159,8 @@ async def create_guild():
guild_id = get_snowflake() guild_id = get_snowflake()
if 'icon' in j: if "icon" in j:
image = await put_guild_icon(guild_id, j['icon']) image = await put_guild_icon(guild_id, j["icon"])
image = image.icon_hash image = image.icon_hash
else: else:
image = None image = None
@ -166,10 +171,16 @@ async def create_guild():
verification_level, default_message_notifications, verification_level, default_message_notifications,
explicit_content_filter) explicit_content_filter)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
""", guild_id, j['name'], j['region'], image, user_id, """,
j.get('verification_level', 0), guild_id,
j.get('default_message_notifications', 0), j["name"],
j.get('explicit_content_filter', 0)) j["region"],
image,
user_id,
j.get("verification_level", 0),
j.get("default_message_notifications", 0),
j.get("explicit_content_filter", 0),
)
await add_member(guild_id, user_id) await add_member(guild_id, user_id)
@ -179,107 +190,127 @@ async def create_guild():
# we also don't use create_role because the id of the role # we also don't use create_role because the id of the role
# is the same as the id of the guild, and create_role # is the same as the id of the guild, and create_role
# generates a new snowflake. # generates a new snowflake.
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO roles (id, guild_id, name, position, permissions) INSERT INTO roles (id, guild_id, name, position, permissions)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
""", guild_id, guild_id, '@everyone', 0, DEFAULT_EVERYONE_PERMS) """,
guild_id,
guild_id,
"@everyone",
0,
DEFAULT_EVERYONE_PERMS,
)
# add the @everyone role to the guild creator # add the @everyone role to the guild creator
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO member_roles (user_id, guild_id, role_id) INSERT INTO member_roles (user_id, guild_id, role_id)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
""", user_id, guild_id, guild_id) """,
user_id,
guild_id,
guild_id,
)
# create a single #general channel. # create a single #general channel.
general_id = get_snowflake() general_id = get_snowflake()
await create_guild_channel( await create_guild_channel(
guild_id, general_id, ChannelType.GUILD_TEXT, guild_id, general_id, ChannelType.GUILD_TEXT, name="general"
name='general') )
if j.get('roles'): if j.get("roles"):
await guild_create_roles_prep(guild_id, j['roles']) await guild_create_roles_prep(guild_id, j["roles"])
if j.get('channels'): if j.get("channels"):
await guild_create_channels_prep(guild_id, j['channels']) await guild_create_channels_prep(guild_id, j["channels"])
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250) guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
await app.dispatcher.sub('guild', guild_id, user_id) await app.dispatcher.sub("guild", guild_id, user_id)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_CREATE', guild_total) await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild_total)
return jsonify(guild_total) return jsonify(guild_total)
@bp.route('/<int:guild_id>', methods=['GET']) @bp.route("/<int:guild_id>", methods=["GET"])
async def get_guild(guild_id): async def get_guild(guild_id):
"""Get a single guilds' information.""" """Get a single guilds' information."""
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
return jsonify( return jsonify(await app.storage.get_guild_full(guild_id, user_id, 250))
await app.storage.get_guild_full(guild_id, user_id, 250)
)
async def _guild_update_icon(scope: str, guild_id: int, async def _guild_update_icon(scope: str, guild_id: int, icon: Optional[str], **kwargs):
icon: Optional[str], **kwargs):
"""Update icon.""" """Update icon."""
new_icon = await app.icons.update( new_icon = await app.icons.update(scope, guild_id, icon, always_icon=True, **kwargs)
scope, guild_id, icon, always_icon=True, **kwargs
)
table = { table = {"guild": "icon"}.get(scope, scope)
'guild': 'icon',
}.get(scope, scope)
await app.db.execute(f""" await app.db.execute(
f"""
UPDATE guilds UPDATE guilds
SET {table} = $1 SET {table} = $1
WHERE id = $2 WHERE id = $2
""", new_icon.icon_hash, guild_id) """,
new_icon.icon_hash,
guild_id,
)
async def _guild_update_region(guild_id, region): async def _guild_update_region(guild_id, region):
is_vip = region.vip is_vip = region.vip
can_vip = await app.storage.has_feature(guild_id, 'VIP_REGIONS') can_vip = await app.storage.has_feature(guild_id, "VIP_REGIONS")
if is_vip and not can_vip: if is_vip and not can_vip:
raise BadRequest('can not assign guild to vip-only region') raise BadRequest("can not assign guild to vip-only region")
await app.db.execute(""" await app.db.execute(
"""
UPDATE guilds UPDATE guilds
SET region = $1 SET region = $1
WHERE id = $2 WHERE id = $2
""", region.id, guild_id) """,
region.id,
guild_id,
)
@bp.route("/<int:guild_id>", methods=["PATCH"])
@bp.route('/<int:guild_id>', methods=['PATCH'])
async def _update_guild(guild_id): async def _update_guild(guild_id):
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_guild') await guild_perm_check(user_id, guild_id, "manage_guild")
j = validate(await request.get_json(), GUILD_UPDATE) j = validate(await request.get_json(), GUILD_UPDATE)
if 'owner_id' in j: if "owner_id" in j:
await guild_owner_check(user_id, guild_id) await guild_owner_check(user_id, guild_id)
await app.db.execute(""" await app.db.execute(
"""
UPDATE guilds UPDATE guilds
SET owner_id = $1 SET owner_id = $1
WHERE id = $2 WHERE id = $2
""", int(j['owner_id']), guild_id) """,
int(j["owner_id"]),
guild_id,
)
if 'name' in j: if "name" in j:
await app.db.execute(""" await app.db.execute(
"""
UPDATE guilds UPDATE guilds
SET name = $1 SET name = $1
WHERE id = $2 WHERE id = $2
""", j['name'], guild_id) """,
j["name"],
guild_id,
)
if 'region' in j: if "region" in j:
region = app.voice.lvsp.region(j['region']) region = app.voice.lvsp.region(j["region"])
if region is not None: if region is not None:
await _guild_update_region(guild_id, region) await _guild_update_region(guild_id, region)
@ -287,65 +318,77 @@ async def _update_guild(guild_id):
# small guild to work with to_update() # small guild to work with to_update()
guild = await app.storage.get_guild(guild_id) guild = await app.storage.get_guild(guild_id)
if to_update(j, guild, 'icon'): if to_update(j, guild, "icon"):
await _guild_update_icon( await _guild_update_icon("guild", guild_id, j["icon"], size=(128, 128))
'guild', guild_id, j['icon'], size=(128, 128))
if to_update(j, guild, 'splash'): if to_update(j, guild, "splash"):
if not await app.storage.has_feature(guild_id, 'INVITE_SPLASH'): if not await app.storage.has_feature(guild_id, "INVITE_SPLASH"):
raise BadRequest('guild does not have INVITE_SPLASH feature') raise BadRequest("guild does not have INVITE_SPLASH feature")
await _guild_update_icon('splash', guild_id, j['splash']) await _guild_update_icon("splash", guild_id, j["splash"])
if to_update(j, guild, 'banner'): if to_update(j, guild, "banner"):
if not await app.storage.has_feature(guild_id, 'VERIFIED'): if not await app.storage.has_feature(guild_id, "VERIFIED"):
raise BadRequest('guild is not verified') raise BadRequest("guild is not verified")
await _guild_update_icon('banner', guild_id, j['banner']) await _guild_update_icon("banner", guild_id, j["banner"])
fields = ['verification_level', 'default_message_notifications', fields = [
'explicit_content_filter', 'afk_timeout', 'description'] "verification_level",
"default_message_notifications",
"explicit_content_filter",
"afk_timeout",
"description",
]
for field in [f for f in fields if f in j]: for field in [f for f in fields if f in j]:
await app.db.execute(f""" await app.db.execute(
f"""
UPDATE guilds UPDATE guilds
SET {field} = $1 SET {field} = $1
WHERE id = $2 WHERE id = $2
""", j[field], guild_id) """,
j[field],
guild_id,
)
channel_fields = ['afk_channel_id', 'system_channel_id'] channel_fields = ["afk_channel_id", "system_channel_id"]
for field in [f for f in channel_fields if f in j]: for field in [f for f in channel_fields if f in j]:
# setting to null should remove the link between the afk/sys channel # setting to null should remove the link between the afk/sys channel
# to the guild. # to the guild.
if j[field] is None: if j[field] is None:
await app.db.execute(f""" await app.db.execute(
f"""
UPDATE guilds UPDATE guilds
SET {field} = NULL SET {field} = NULL
WHERE id = $1 WHERE id = $1
""", guild_id) """,
guild_id,
)
continue continue
chan = await app.storage.get_channel(int(j[field])) chan = await app.storage.get_channel(int(j[field]))
if chan is None: if chan is None:
raise BadRequest('invalid channel id') raise BadRequest("invalid channel id")
if chan['guild_id'] != str(guild_id): if chan["guild_id"] != str(guild_id):
raise BadRequest('channel id not linked to guild') raise BadRequest("channel id not linked to guild")
await app.db.execute(f""" await app.db.execute(
f"""
UPDATE guilds UPDATE guilds
SET {field} = $1 SET {field} = $1
WHERE id = $2 WHERE id = $2
""", j[field], guild_id) """,
j[field],
guild_id,
)
guild = await app.storage.get_guild_full( guild = await app.storage.get_guild_full(guild_id, user_id)
guild_id, user_id
)
await app.dispatcher.dispatch_guild( await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
guild_id, 'GUILD_UPDATE', guild)
return jsonify(guild) return jsonify(guild)
@ -354,33 +397,41 @@ async def delete_guild(guild_id: int, *, app_=None):
"""Delete a single guild.""" """Delete a single guild."""
app_ = app_ or app app_ = app_ or app
await app_.db.execute(""" await app_.db.execute(
"""
DELETE FROM guilds DELETE FROM guilds
WHERE guilds.id = $1 WHERE guilds.id = $1
""", guild_id) """,
guild_id,
)
# Discord's client expects IDs being string # Discord's client expects IDs being string
await app_.dispatcher.dispatch('guild', guild_id, 'GUILD_DELETE', { await app_.dispatcher.dispatch(
'guild_id': str(guild_id), "guild",
'id': str(guild_id), guild_id,
# 'unavailable': False, "GUILD_DELETE",
}) {
"guild_id": str(guild_id),
"id": str(guild_id),
# 'unavailable': False,
},
)
# remove from the dispatcher so nobody # remove from the dispatcher so nobody
# becomes the little memer that tries to fuck up with # becomes the little memer that tries to fuck up with
# everybody's gateway # everybody's gateway
await app_.dispatcher.remove('guild', guild_id) await app_.dispatcher.remove("guild", guild_id)
@bp.route('/<int:guild_id>', methods=['DELETE']) @bp.route("/<int:guild_id>", methods=["DELETE"])
# this endpoint is not documented, but used by the official client. # this endpoint is not documented, but used by the official client.
@bp.route('/<int:guild_id>/delete', methods=['POST']) @bp.route("/<int:guild_id>/delete", methods=["POST"])
async def delete_guild_handler(guild_id): async def delete_guild_handler(guild_id):
"""Delete a guild.""" """Delete a guild."""
user_id = await token_check() user_id = await token_check()
await guild_owner_check(user_id, guild_id) await guild_owner_check(user_id, guild_id)
await delete_guild(guild_id) await delete_guild(guild_id)
return '', 204 return "", 204
async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]: async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]:
@ -397,7 +448,7 @@ async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]:
return res return res
@bp.route('/<int:guild_id>/messages/search', methods=['GET']) @bp.route("/<int:guild_id>/messages/search", methods=["GET"])
async def search_messages(guild_id): async def search_messages(guild_id):
"""Search messages in a guild. """Search messages in a guild.
@ -415,7 +466,8 @@ async def search_messages(guild_id):
# use that list on the main search query. # use that list on the main search query.
can_read = await fetch_readable_channels(guild_id, user_id) can_read = await fetch_readable_channels(guild_id, user_id)
rows = await app.db.fetch(f""" rows = await app.db.fetch(
f"""
SELECT orig.id AS current_id, SELECT orig.id AS current_id,
COUNT(*) OVER() as total_results, COUNT(*) OVER() as total_results,
array((SELECT messages.id AS before_id array((SELECT messages.id AS before_id
@ -432,12 +484,17 @@ async def search_messages(guild_id):
ORDER BY orig.id DESC ORDER BY orig.id DESC
LIMIT 50 LIMIT 50
OFFSET $3 OFFSET $3
""", guild_id, j['content'], j['offset'], can_read) """,
guild_id,
j["content"],
j["offset"],
can_read,
)
return jsonify(await search_result_from_list(rows)) return jsonify(await search_result_from_list(rows))
@bp.route('/<int:guild_id>/ack', methods=['POST']) @bp.route("/<int:guild_id>/ack", methods=["POST"])
async def ack_guild(guild_id): async def ack_guild(guild_id):
"""ACKnowledge all messages in the guild.""" """ACKnowledge all messages in the guild."""
user_id = await token_check() user_id = await token_check()
@ -448,45 +505,43 @@ async def ack_guild(guild_id):
for chan_id in chan_ids: for chan_id in chan_ids:
await channel_ack(user_id, guild_id, chan_id) await channel_ack(user_id, guild_id, chan_id)
return '', 204 return "", 204
@bp.route('/<int:guild_id>/vanity-url', methods=['GET']) @bp.route("/<int:guild_id>/vanity-url", methods=["GET"])
async def get_vanity_url(guild_id: int): async def get_vanity_url(guild_id: int):
"""Get the vanity url of a guild.""" """Get the vanity url of a guild."""
user_id = await token_check() user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'manage_guild') await guild_perm_check(user_id, guild_id, "manage_guild")
inv_code = await app.storage.vanity_invite(guild_id) inv_code = await app.storage.vanity_invite(guild_id)
if inv_code is None: if inv_code is None:
return jsonify({'code': None}) return jsonify({"code": None})
return jsonify( return jsonify(await app.storage.get_invite(inv_code))
await app.storage.get_invite(inv_code)
)
@bp.route('/<int:guild_id>/vanity-url', methods=['PATCH']) @bp.route("/<int:guild_id>/vanity-url", methods=["PATCH"])
async def change_vanity_url(guild_id: int): async def change_vanity_url(guild_id: int):
"""Get the vanity url of a guild.""" """Get the vanity url of a guild."""
user_id = await token_check() user_id = await token_check()
if not await app.storage.has_feature(guild_id, 'VANITY_URL'): if not await app.storage.has_feature(guild_id, "VANITY_URL"):
# TODO: is this the right error # TODO: is this the right error
raise BadRequest('guild has no vanity url support') raise BadRequest("guild has no vanity url support")
await guild_perm_check(user_id, guild_id, 'manage_guild') await guild_perm_check(user_id, guild_id, "manage_guild")
j = validate(await request.get_json(), VANITY_URL_PATCH) j = validate(await request.get_json(), VANITY_URL_PATCH)
inv_code = j['code'] inv_code = j["code"]
# store old vanity in a variable to delete it from # store old vanity in a variable to delete it from
# invites table # invites table
old_vanity = await app.storage.vanity_invite(guild_id) old_vanity = await app.storage.vanity_invite(guild_id)
if old_vanity == inv_code: if old_vanity == inv_code:
raise BadRequest('can not change to same invite') raise BadRequest("can not change to same invite")
# this is sad because we don't really use the things # this is sad because we don't really use the things
# sql gives us, but i havent really found a way to put # sql gives us, but i havent really found a way to put
@ -494,19 +549,22 @@ async def change_vanity_url(guild_id: int):
# guild_id_fkey fails but INSERT when code_fkey fails.. # guild_id_fkey fails but INSERT when code_fkey fails..
inv = await app.storage.get_invite(inv_code) inv = await app.storage.get_invite(inv_code)
if inv: if inv:
raise BadRequest('invite already exists') raise BadRequest("invite already exists")
# TODO: this is bad, what if a guild has no channels? # TODO: this is bad, what if a guild has no channels?
# we should probably choose the first channel that has # we should probably choose the first channel that has
# @everyone read messages # @everyone read messages
channels = await app.storage.get_channel_data(guild_id) channels = await app.storage.get_channel_data(guild_id)
channel_id = int(channels[0]['id']) channel_id = int(channels[0]["id"])
# delete the old invite, insert new one # delete the old invite, insert new one
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM invites DELETE FROM invites
WHERE code = $1 WHERE code = $1
""", old_vanity) """,
old_vanity,
)
await app.db.execute( await app.db.execute(
""" """
@ -515,21 +573,27 @@ async def change_vanity_url(guild_id: int):
max_age, temporary) max_age, temporary)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7)
""", """,
inv_code, guild_id, channel_id, user_id, inv_code,
guild_id,
channel_id,
user_id,
# sane defaults for vanity urls. # sane defaults for vanity urls.
0, 0, False, 0,
0,
False,
) )
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO vanity_invites (guild_id, code) INSERT INTO vanity_invites (guild_id, code)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT ON CONSTRAINT vanity_invites_pkey DO ON CONFLICT ON CONSTRAINT vanity_invites_pkey DO
UPDATE UPDATE
SET code = $2 SET code = $2
WHERE vanity_invites.guild_id = $1 WHERE vanity_invites.guild_id = $1
""", guild_id, inv_code) """,
guild_id,
return jsonify( inv_code,
await app.storage.get_invite(inv_code)
) )
return jsonify(await app.storage.get_invite(inv_code))

View File

@ -24,41 +24,39 @@ from quart import Blueprint, current_app as app, send_file, redirect
from litecord.embed.sanitizer import make_md_req_url from litecord.embed.sanitizer import make_md_req_url
from litecord.embed.schemas import EmbedURL from litecord.embed.schemas import EmbedURL
bp = Blueprint('images', __name__) bp = Blueprint("images", __name__)
async def send_icon(scope, key, icon_hash, **kwargs): async def send_icon(scope, key, icon_hash, **kwargs):
"""Send an icon.""" """Send an icon."""
icon = await app.icons.generic_get( icon = await app.icons.generic_get(scope, key, icon_hash, **kwargs)
scope, key, icon_hash, **kwargs)
if not icon: if not icon:
return '', 404 return "", 404
return await send_file(icon.as_path) return await send_file(icon.as_path)
def splitext_(filepath): def splitext_(filepath):
name, ext = splitext(filepath) name, ext = splitext(filepath)
return name, ext.strip('.') return name, ext.strip(".")
@bp.route('/emojis/<emoji_file>', methods=['GET']) @bp.route("/emojis/<emoji_file>", methods=["GET"])
async def _get_raw_emoji(emoji_file): async def _get_raw_emoji(emoji_file):
# emoji = app.icons.get_emoji(emoji_id, ext=ext) # emoji = app.icons.get_emoji(emoji_id, ext=ext)
# just a test file for now # just a test file for now
emoji_id, ext = splitext_(emoji_file) emoji_id, ext = splitext_(emoji_file)
return await send_icon( return await send_icon("emoji", emoji_id, None, ext=ext)
'emoji', emoji_id, None, ext=ext)
@bp.route('/icons/<int:guild_id>/<icon_file>', methods=['GET']) @bp.route("/icons/<int:guild_id>/<icon_file>", methods=["GET"])
async def _get_guild_icon(guild_id: int, icon_file: str): async def _get_guild_icon(guild_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file) icon_hash, ext = splitext_(icon_file)
return await send_icon('guild', guild_id, icon_hash, ext=ext) return await send_icon("guild", guild_id, icon_hash, ext=ext)
@bp.route('/embed/avatars/<int:default_id>.png') @bp.route("/embed/avatars/<int:default_id>.png")
async def _get_default_user_avatar(default_id: int): async def _get_default_user_avatar(default_id: int):
# TODO: how do we determine which assets to use for this? # TODO: how do we determine which assets to use for this?
# I don't think we can use discord assets. # I don't think we can use discord assets.
@ -66,25 +64,29 @@ async def _get_default_user_avatar(default_id: int):
async def _handle_webhook_avatar(md_url_redir: str): async def _handle_webhook_avatar(md_url_redir: str):
md_url = make_md_req_url(app.config, 'img', EmbedURL(md_url_redir)) md_url = make_md_req_url(app.config, "img", EmbedURL(md_url_redir))
return redirect(md_url) return redirect(md_url)
@bp.route('/avatars/<int:user_id>/<avatar_file>') @bp.route("/avatars/<int:user_id>/<avatar_file>")
async def _get_user_avatar(user_id, avatar_file): async def _get_user_avatar(user_id, avatar_file):
avatar_hash, ext = splitext_(avatar_file) avatar_hash, ext = splitext_(avatar_file)
# first, check if this is a webhook avatar to redir to # first, check if this is a webhook avatar to redir to
md_url_redir = await app.db.fetchval(""" md_url_redir = await app.db.fetchval(
"""
SELECT md_url_redir SELECT md_url_redir
FROM webhook_avatars FROM webhook_avatars
WHERE webhook_id = $1 AND hash = $2 WHERE webhook_id = $1 AND hash = $2
""", user_id, avatar_hash) """,
user_id,
avatar_hash,
)
if md_url_redir: if md_url_redir:
return await _handle_webhook_avatar(md_url_redir) return await _handle_webhook_avatar(md_url_redir)
return await send_icon('user', user_id, avatar_hash, ext=ext) return await send_icon("user", user_id, avatar_hash, ext=ext)
# @bp.route('/app-icons/<int:application_id>/<icon_hash>.<ext>') # @bp.route('/app-icons/<int:application_id>/<icon_hash>.<ext>')
@ -92,19 +94,19 @@ async def get_app_icon(application_id, icon_hash, ext):
pass pass
@bp.route('/channel-icons/<int:channel_id>/<icon_file>', methods=['GET']) @bp.route("/channel-icons/<int:channel_id>/<icon_file>", methods=["GET"])
async def _get_gdm_icon(channel_id: int, icon_file: str): async def _get_gdm_icon(channel_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file) icon_hash, ext = splitext_(icon_file)
return await send_icon('channel-icons', channel_id, icon_hash, ext=ext) return await send_icon("channel-icons", channel_id, icon_hash, ext=ext)
@bp.route('/splashes/<int:guild_id>/<icon_file>', methods=['GET']) @bp.route("/splashes/<int:guild_id>/<icon_file>", methods=["GET"])
async def _get_guild_splash(guild_id: int, icon_file: str): async def _get_guild_splash(guild_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file) icon_hash, ext = splitext_(icon_file)
return await send_icon('splash', guild_id, icon_hash, ext=ext) return await send_icon("splash", guild_id, icon_hash, ext=ext)
@bp.route('/banners/<int:guild_id>/<icon_file>', methods=['GET']) @bp.route("/banners/<int:guild_id>/<icon_file>", methods=["GET"])
async def _get_guild_banner(guild_id: int, icon_file: str): async def _get_guild_banner(guild_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file) icon_hash, ext = splitext_(icon_file)
return await send_icon('banner', guild_id, icon_hash, ext=ext) return await send_icon("banner", guild_id, icon_hash, ext=ext)

View File

@ -32,13 +32,16 @@ from .guilds import create_guild_settings
from ..utils import async_map from ..utils import async_map
from litecord.blueprints.checks import ( from litecord.blueprints.checks import (
channel_check, channel_perm_check, guild_check, guild_perm_check channel_check,
channel_perm_check,
guild_check,
guild_perm_check,
) )
from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('invites', __name__) bp = Blueprint("invites", __name__)
class UnknownInvite(BadRequest): class UnknownInvite(BadRequest):
@ -48,16 +51,18 @@ class UnknownInvite(BadRequest):
class InvalidInvite(Forbidden): class InvalidInvite(Forbidden):
error_code = 50020 error_code = 50020
class AlreadyInvited(BaseException): class AlreadyInvited(BaseException):
pass pass
def gen_inv_code() -> str: def gen_inv_code() -> str:
"""Generate an invite code. """Generate an invite code.
This is a primitive and does not guarantee uniqueness. This is a primitive and does not guarantee uniqueness.
""" """
raw = secrets.token_urlsafe(10) raw = secrets.token_urlsafe(10)
raw = re.sub(r'\/|\+|\-|\_', '', raw) raw = re.sub(r"\/|\+|\-|\_", "", raw)
return raw[:7] return raw[:7]
@ -65,23 +70,31 @@ def gen_inv_code() -> str:
async def invite_precheck(user_id: int, guild_id: int): async def invite_precheck(user_id: int, guild_id: int):
"""pre-check invite use in the context of a guild.""" """pre-check invite use in the context of a guild."""
joined = await app.db.fetchval(""" joined = await app.db.fetchval(
"""
SELECT joined_at SELECT joined_at
FROM members FROM members
WHERE user_id = $1 AND guild_id = $2 WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id) """,
user_id,
guild_id,
)
if joined is not None: if joined is not None:
raise AlreadyInvited('You are already in the guild') raise AlreadyInvited("You are already in the guild")
banned = await app.db.fetchval(""" banned = await app.db.fetchval(
"""
SELECT reason SELECT reason
FROM bans FROM bans
WHERE user_id = $1 AND guild_id = $2 WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id) """,
user_id,
guild_id,
)
if banned is not None: if banned is not None:
raise InvalidInvite('You are banned.') raise InvalidInvite("You are banned.")
async def invite_precheck_gdm(user_id: int, channel_id: int): async def invite_precheck_gdm(user_id: int, channel_id: int):
@ -89,23 +102,23 @@ async def invite_precheck_gdm(user_id: int, channel_id: int):
is_member = await gdm_is_member(channel_id, user_id) is_member = await gdm_is_member(channel_id, user_id)
if is_member: if is_member:
raise AlreadyInvited('You are already in the Group DM') raise AlreadyInvited("You are already in the Group DM")
async def _inv_check_age(inv: dict): async def _inv_check_age(inv: dict):
if inv['max_age'] == 0: if inv["max_age"] == 0:
return return
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
delta_sec = (now - inv['created_at']).total_seconds() delta_sec = (now - inv["created_at"]).total_seconds()
if delta_sec > inv['max_age']: if delta_sec > inv["max_age"]:
await delete_invite(inv['code']) await delete_invite(inv["code"])
raise InvalidInvite('Invite is expired') raise InvalidInvite("Invite is expired")
if inv['max_uses'] is not -1 and inv['uses'] > inv['max_uses']: if inv["max_uses"] is not -1 and inv["uses"] > inv["max_uses"]:
await delete_invite(inv['code']) await delete_invite(inv["code"])
raise InvalidInvite('Too many uses') raise InvalidInvite("Too many uses")
async def _guild_add_member(guild_id: int, user_id: int): async def _guild_add_member(guild_id: int, user_id: int):
@ -119,78 +132,89 @@ async def _guild_add_member(guild_id: int, user_id: int):
""" """
# TODO: system message for member join # TODO: system message for member join
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO members (user_id, guild_id) INSERT INTO members (user_id, guild_id)
VALUES ($1, $2) VALUES ($1, $2)
""", user_id, guild_id) """,
user_id,
guild_id,
)
await create_guild_settings(guild_id, user_id) await create_guild_settings(guild_id, user_id)
# add the @everyone role to the invited member # add the @everyone role to the invited member
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO member_roles (user_id, guild_id, role_id) INSERT INTO member_roles (user_id, guild_id, role_id)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
""", user_id, guild_id, guild_id) """,
user_id,
guild_id,
guild_id,
)
# tell current members a new member came up # tell current members a new member came up
member = await app.storage.get_member_data_one(guild_id, user_id) member = await app.storage.get_member_data_one(guild_id, user_id)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_ADD', { await app.dispatcher.dispatch_guild(
**member, guild_id, "GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}}
**{ )
'guild_id': str(guild_id),
},
})
# update member lists for the new member # update member lists for the new member
await app.dispatcher.dispatch( await app.dispatcher.dispatch("lazy_guild", guild_id, "new_member", user_id)
'lazy_guild', guild_id, 'new_member', user_id)
# subscribe new member to guild, so they get events n stuff # subscribe new member to guild, so they get events n stuff
await app.dispatcher.sub('guild', guild_id, user_id) await app.dispatcher.sub("guild", guild_id, user_id)
# tell the new member that theres the guild it just joined. # tell the new member that theres the guild it just joined.
# we use dispatch_user_guild so that we send the GUILD_CREATE # we use dispatch_user_guild so that we send the GUILD_CREATE
# just to the shards that are actually tied to it. # just to the shards that are actually tied to it.
guild = await app.storage.get_guild_full(guild_id, user_id, 250) guild = await app.storage.get_guild_full(guild_id, user_id, 250)
await app.dispatcher.dispatch_user_guild( await app.dispatcher.dispatch_user_guild(user_id, guild_id, "GUILD_CREATE", guild)
user_id, guild_id, 'GUILD_CREATE', guild)
async def use_invite(user_id, invite_code): async def use_invite(user_id, invite_code):
"""Try using an invite""" """Try using an invite"""
inv = await app.db.fetchrow(""" inv = await app.db.fetchrow(
"""
SELECT code, channel_id, guild_id, created_at, SELECT code, channel_id, guild_id, created_at,
max_age, uses, max_uses max_age, uses, max_uses
FROM invites FROM invites
WHERE code = $1 WHERE code = $1
""", invite_code) """,
invite_code,
)
if inv is None: if inv is None:
raise UnknownInvite('Unknown invite') raise UnknownInvite("Unknown invite")
await _inv_check_age(inv) await _inv_check_age(inv)
# NOTE: if group dm invite, guild_id is null. # NOTE: if group dm invite, guild_id is null.
guild_id = inv['guild_id'] guild_id = inv["guild_id"]
try: try:
if guild_id is None: if guild_id is None:
channel_id = inv['channel_id'] channel_id = inv["channel_id"]
await invite_precheck_gdm(user_id, inv['channel_id']) await invite_precheck_gdm(user_id, inv["channel_id"])
await gdm_add_recipient(channel_id, user_id) await gdm_add_recipient(channel_id, user_id)
else: else:
await invite_precheck(user_id, guild_id) await invite_precheck(user_id, guild_id)
await _guild_add_member(guild_id, user_id) await _guild_add_member(guild_id, user_id)
await app.db.execute(""" await app.db.execute(
"""
UPDATE invites UPDATE invites
SET uses = uses + 1 SET uses = uses + 1
WHERE code = $1 WHERE code = $1
""", invite_code) """,
invite_code,
)
except AlreadyInvited: except AlreadyInvited:
pass pass
@bp.route('/channels/<int:channel_id>/invites', methods=['POST'])
@bp.route("/channels/<int:channel_id>/invites", methods=["POST"])
async def create_invite(channel_id): async def create_invite(channel_id):
"""Create an invite to a channel.""" """Create an invite to a channel."""
user_id = await token_check() user_id = await token_check()
@ -201,12 +225,14 @@ async def create_invite(channel_id):
# NOTE: this works on group dms, since it returns ALL_PERMISSIONS on # NOTE: this works on group dms, since it returns ALL_PERMISSIONS on
# non-guild channels. # non-guild channels.
await channel_perm_check(user_id, channel_id, 'create_invites') await channel_perm_check(user_id, channel_id, "create_invites")
if chantype not in (ChannelType.GUILD_TEXT, if chantype not in (
ChannelType.GUILD_VOICE, ChannelType.GUILD_TEXT,
ChannelType.GROUP_DM): ChannelType.GUILD_VOICE,
raise BadRequest('Invalid channel type') ChannelType.GROUP_DM,
):
raise BadRequest("Invalid channel type")
invite_code = gen_inv_code() invite_code = gen_inv_code()
@ -222,101 +248,122 @@ async def create_invite(channel_id):
max_age, temporary) max_age, temporary)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7)
""", """,
invite_code, guild_id, channel_id, user_id, invite_code,
j['max_uses'], j['max_age'], j['temporary'] guild_id,
channel_id,
user_id,
j["max_uses"],
j["max_age"],
j["temporary"],
) )
invite = await app.storage.get_invite(invite_code) invite = await app.storage.get_invite(invite_code)
return jsonify(invite) return jsonify(invite)
@bp.route('/invite/<invite_code>', methods=['GET']) @bp.route("/invite/<invite_code>", methods=["GET"])
@bp.route('/invites/<invite_code>', methods=['GET']) @bp.route("/invites/<invite_code>", methods=["GET"])
async def get_invite(invite_code: str): async def get_invite(invite_code: str):
inv = await app.storage.get_invite(invite_code) inv = await app.storage.get_invite(invite_code)
if not inv: if not inv:
return '', 404 return "", 404
if request.args.get('with_counts'): if request.args.get("with_counts"):
extra = await app.storage.get_invite_extra(invite_code) extra = await app.storage.get_invite_extra(invite_code)
inv.update(extra) inv.update(extra)
return jsonify(inv) return jsonify(inv)
async def delete_invite(invite_code: str): async def delete_invite(invite_code: str):
"""Delete an invite.""" """Delete an invite."""
await app.db.fetchval(""" await app.db.fetchval(
"""
DELETE FROM invites DELETE FROM invites
WHERE code = $1 WHERE code = $1
""", invite_code) """,
invite_code,
)
@bp.route('/invite/<invite_code>', methods=['DELETE'])
@bp.route('/invites/<invite_code>', methods=['DELETE']) @bp.route("/invite/<invite_code>", methods=["DELETE"])
@bp.route("/invites/<invite_code>", methods=["DELETE"])
async def _delete_invite(invite_code: str): async def _delete_invite(invite_code: str):
user_id = await token_check() user_id = await token_check()
guild_id = await app.db.fetchval(""" guild_id = await app.db.fetchval(
"""
SELECT guild_id SELECT guild_id
FROM invites FROM invites
WHERE code = $1 WHERE code = $1
""", invite_code) """,
invite_code,
)
if guild_id is None: if guild_id is None:
raise BadRequest('Unknown invite') raise BadRequest("Unknown invite")
await guild_perm_check(user_id, guild_id, 'manage_channels') await guild_perm_check(user_id, guild_id, "manage_channels")
inv = await app.storage.get_invite(invite_code) inv = await app.storage.get_invite(invite_code)
await delete_invite(invite_code) await delete_invite(invite_code)
return jsonify(inv) return jsonify(inv)
async def _get_inv(code): async def _get_inv(code):
inv = await app.storage.get_invite(code) inv = await app.storage.get_invite(code)
meta = await app.storage.get_invite_metadata(code) meta = await app.storage.get_invite_metadata(code)
return {**inv, **meta} return {**inv, **meta}
@bp.route('/guilds/<int:guild_id>/invites', methods=['GET']) @bp.route("/guilds/<int:guild_id>/invites", methods=["GET"])
async def get_guild_invites(guild_id: int): async def get_guild_invites(guild_id: int):
"""Get all invites for a guild.""" """Get all invites for a guild."""
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_guild') await guild_perm_check(user_id, guild_id, "manage_guild")
inv_codes = await app.db.fetch(""" inv_codes = await app.db.fetch(
"""
SELECT code SELECT code
FROM invites FROM invites
WHERE guild_id = $1 WHERE guild_id = $1
""", guild_id) """,
guild_id,
)
inv_codes = [r['code'] for r in inv_codes] inv_codes = [r["code"] for r in inv_codes]
invs = await async_map(_get_inv, inv_codes) invs = await async_map(_get_inv, inv_codes)
return jsonify(invs) return jsonify(invs)
@bp.route('/channels/<int:channel_id>/invites', methods=['GET']) @bp.route("/channels/<int:channel_id>/invites", methods=["GET"])
async def get_channel_invites(channel_id: int): async def get_channel_invites(channel_id: int):
"""Get all invites for a channel.""" """Get all invites for a channel."""
user_id = await token_check() user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id) _ctype, guild_id = await channel_check(user_id, channel_id)
await guild_perm_check(user_id, guild_id, 'manage_channels') await guild_perm_check(user_id, guild_id, "manage_channels")
inv_codes = await app.db.fetch(""" inv_codes = await app.db.fetch(
"""
SELECT code SELECT code
FROM invites FROM invites
WHERE guild_id = $1 AND channel_id = $2 WHERE guild_id = $1 AND channel_id = $2
""", guild_id, channel_id) """,
guild_id,
channel_id,
)
inv_codes = [r['code'] for r in inv_codes] inv_codes = [r["code"] for r in inv_codes]
invs = await async_map(_get_inv, inv_codes) invs = await async_map(_get_inv, inv_codes)
return jsonify(invs) return jsonify(invs)
@bp.route('/invite/<invite_code>', methods=['POST']) @bp.route("/invite/<invite_code>", methods=["POST"])
@bp.route('/invites/<invite_code>', methods=['POST']) @bp.route("/invites/<invite_code>", methods=["POST"])
async def _use_invite(invite_code): async def _use_invite(invite_code):
"""Use an invite.""" """Use an invite."""
user_id = await token_check() user_id = await token_check()
@ -327,9 +374,4 @@ async def _use_invite(invite_code):
inv = await app.storage.get_invite(invite_code) inv = await app.storage.get_invite(invite_code)
inv_meta = await app.storage.get_invite_metadata(invite_code) inv_meta = await app.storage.get_invite_metadata(invite_code)
return jsonify({ return jsonify({**inv, **{"inviter": inv_meta["inviter"]}})
**inv,
**{
'inviter': inv_meta['inviter']
}
})

View File

@ -19,83 +19,75 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from quart import Blueprint, current_app as app, jsonify, request from quart import Blueprint, current_app as app, jsonify, request
bp = Blueprint('nodeinfo', __name__) bp = Blueprint("nodeinfo", __name__)
@bp.route('/.well-known/nodeinfo') @bp.route("/.well-known/nodeinfo")
async def _dummy_nodeinfo_index(): async def _dummy_nodeinfo_index():
proto = 'http' if not app.config['IS_SSL'] else 'https' proto = "http" if not app.config["IS_SSL"] else "https"
main_url = app.config.get('MAIN_URL', request.host) main_url = app.config.get("MAIN_URL", request.host)
return jsonify({ return jsonify(
'links': [{ {
'href': f'{proto}://{main_url}/nodeinfo/2.0.json', "links": [
'rel': 'http://nodeinfo.diaspora.software/ns/schema/2.0' {
}, { "href": f"{proto}://{main_url}/nodeinfo/2.0.json",
'href': f'{proto}://{main_url}/nodeinfo/2.1.json', "rel": "http://nodeinfo.diaspora.software/ns/schema/2.0",
'rel': 'http://nodeinfo.diaspora.software/ns/schema/2.1' },
}] {
}) "href": f"{proto}://{main_url}/nodeinfo/2.1.json",
"rel": "http://nodeinfo.diaspora.software/ns/schema/2.1",
},
]
}
)
async def fetch_nodeinfo_20(): async def fetch_nodeinfo_20():
usercount = await app.db.fetchval(""" usercount = await app.db.fetchval(
"""
SELECT COUNT(*) SELECT COUNT(*)
FROM users FROM users
""") """
)
message_count = await app.db.fetchval(""" message_count = await app.db.fetchval(
"""
SELECT COUNT(*) SELECT COUNT(*)
FROM messages FROM messages
""") """
)
return { return {
'metadata': { "metadata": {
'features': [ "features": ["discord_api"],
'discord_api' "nodeDescription": "A Litecord instance",
], "nodeName": "Litecord/Nya",
"private": False,
'nodeDescription': 'A Litecord instance', "federation": {},
'nodeName': 'Litecord/Nya',
'private': False,
'federation': {}
}, },
'openRegistrations': app.config['REGISTRATIONS'], "openRegistrations": app.config["REGISTRATIONS"],
'protocols': [], "protocols": [],
'software': { "software": {"name": "litecord", "version": "litecord v0"},
'name': 'litecord', "services": {"inbound": [], "outbound": []},
'version': 'litecord v0', "usage": {"localPosts": message_count, "users": {"total": usercount}},
}, "version": "2.0",
'services': {
'inbound': [],
'outbound': [],
},
'usage': {
'localPosts': message_count,
'users': {
'total': usercount
}
},
'version': '2.0',
} }
@bp.route('/nodeinfo/2.0.json') @bp.route("/nodeinfo/2.0.json")
async def _nodeinfo_20(): async def _nodeinfo_20():
"""Handler for nodeinfo 2.0.""" """Handler for nodeinfo 2.0."""
raw_nodeinfo = await fetch_nodeinfo_20() raw_nodeinfo = await fetch_nodeinfo_20()
return jsonify(raw_nodeinfo) return jsonify(raw_nodeinfo)
@bp.route('/nodeinfo/2.1.json') @bp.route("/nodeinfo/2.1.json")
async def _nodeinfo_21(): async def _nodeinfo_21():
"""Handler for nodeinfo 2.1.""" """Handler for nodeinfo 2.1."""
raw_nodeinfo = await fetch_nodeinfo_20() raw_nodeinfo = await fetch_nodeinfo_20()
raw_nodeinfo['software']['repository'] = 'https://gitlab.com/litecord/litecord' raw_nodeinfo["software"]["repository"] = "https://gitlab.com/litecord/litecord"
raw_nodeinfo['version'] = '2.1' raw_nodeinfo["version"] = "2.1"
return jsonify(raw_nodeinfo) return jsonify(raw_nodeinfo)

View File

@ -26,76 +26,89 @@ from ..enums import RelationshipType
from litecord.errors import BadRequest from litecord.errors import BadRequest
bp = Blueprint('relationship', __name__) bp = Blueprint("relationship", __name__)
@bp.route('/@me/relationships', methods=['GET']) @bp.route("/@me/relationships", methods=["GET"])
async def get_me_relationships(): async def get_me_relationships():
user_id = await token_check() user_id = await token_check()
return jsonify( return jsonify(await app.user_storage.get_relationships(user_id))
await app.user_storage.get_relationships(user_id))
async def _dispatch_single_pres(user_id, presence: dict): async def _dispatch_single_pres(user_id, presence: dict):
await app.dispatcher.dispatch( await app.dispatcher.dispatch("user", user_id, "PRESENCE_UPDATE", presence)
'user', user_id, 'PRESENCE_UPDATE', presence
)
async def _unsub_friend(user_id, peer_id): async def _unsub_friend(user_id, peer_id):
await app.dispatcher.unsub('friend', user_id, peer_id) await app.dispatcher.unsub("friend", user_id, peer_id)
await app.dispatcher.unsub('friend', peer_id, user_id) await app.dispatcher.unsub("friend", peer_id, user_id)
async def _sub_friend(user_id, peer_id): async def _sub_friend(user_id, peer_id):
await app.dispatcher.sub('friend', user_id, peer_id) await app.dispatcher.sub("friend", user_id, peer_id)
await app.dispatcher.sub('friend', peer_id, user_id) await app.dispatcher.sub("friend", peer_id, user_id)
# dispatch presence update to the user and peer about # dispatch presence update to the user and peer about
# eachother's presence. # eachother's presence.
user_pres, peer_pres = await app.presence.friend_presences( user_pres, peer_pres = await app.presence.friend_presences([user_id, peer_id])
[user_id, peer_id]
)
await _dispatch_single_pres(user_id, peer_pres) await _dispatch_single_pres(user_id, peer_pres)
await _dispatch_single_pres(peer_id, user_pres) await _dispatch_single_pres(peer_id, user_pres)
async def make_friend(user_id: int, peer_id: int, async def make_friend(
rel_type=RelationshipType.FRIEND.value): user_id: int, peer_id: int, rel_type=RelationshipType.FRIEND.value
):
_friend = RelationshipType.FRIEND.value _friend = RelationshipType.FRIEND.value
_block = RelationshipType.BLOCK.value _block = RelationshipType.BLOCK.value
try: try:
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO relationships (user_id, peer_id, rel_type) INSERT INTO relationships (user_id, peer_id, rel_type)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
""", user_id, peer_id, rel_type) """,
user_id,
peer_id,
rel_type,
)
except UniqueViolationError: except UniqueViolationError:
# try to update rel_type # try to update rel_type
old_rel_type = await app.db.fetchval(""" old_rel_type = await app.db.fetchval(
"""
SELECT rel_type SELECT rel_type
FROM relationships FROM relationships
WHERE user_id = $1 AND peer_id = $2 WHERE user_id = $1 AND peer_id = $2
""", user_id, peer_id) """,
user_id,
peer_id,
)
if old_rel_type == _friend and rel_type == _block: if old_rel_type == _friend and rel_type == _block:
await app.db.execute(""" await app.db.execute(
"""
UPDATE relationships UPDATE relationships
SET rel_type = $1 SET rel_type = $1
WHERE user_id = $2 AND peer_id = $3 WHERE user_id = $2 AND peer_id = $3
""", rel_type, user_id, peer_id) """,
rel_type,
user_id,
peer_id,
)
# remove any existing friendship before the block # remove any existing friendship before the block
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM relationships DELETE FROM relationships
WHERE peer_id = $1 AND user_id = $2 AND rel_type = $3 WHERE peer_id = $1 AND user_id = $2 AND rel_type = $3
""", peer_id, user_id, _friend) """,
peer_id,
user_id,
_friend,
)
await app.dispatcher.dispatch_user( await app.dispatcher.dispatch_user(
peer_id, 'RELATIONSHIP_REMOVE', { peer_id, "RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)}
'type': _friend,
'id': str(user_id)
}
) )
await _unsub_friend(user_id, peer_id) await _unsub_friend(user_id, peer_id)
@ -106,95 +119,118 @@ async def make_friend(user_id: int, peer_id: int,
# check if this is an acceptance # check if this is an acceptance
# of a friend request # of a friend request
existing = await app.db.fetchrow(""" existing = await app.db.fetchrow(
"""
SELECT user_id, peer_id SELECT user_id, peer_id
FROM relationships FROM relationships
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3 WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
""", peer_id, user_id, _friend) """,
peer_id,
user_id,
_friend,
)
_dispatch = app.dispatcher.dispatch_user _dispatch = app.dispatcher.dispatch_user
if existing: if existing:
# accepted a friend request, dispatch respective # accepted a friend request, dispatch respective
# relationship events # relationship events
await _dispatch(user_id, 'RELATIONSHIP_REMOVE', { await _dispatch(
'type': RelationshipType.INCOMING.value, user_id,
'id': str(peer_id) "RELATIONSHIP_REMOVE",
}) {"type": RelationshipType.INCOMING.value, "id": str(peer_id)},
)
await _dispatch(user_id, 'RELATIONSHIP_ADD', { await _dispatch(
'type': _friend, user_id,
'id': str(peer_id), "RELATIONSHIP_ADD",
'user': await app.storage.get_user(peer_id) {
}) "type": _friend,
"id": str(peer_id),
"user": await app.storage.get_user(peer_id),
},
)
await _dispatch(peer_id, 'RELATIONSHIP_ADD', { await _dispatch(
'type': _friend, peer_id,
'id': str(user_id), "RELATIONSHIP_ADD",
'user': await app.storage.get_user(user_id) {
}) "type": _friend,
"id": str(user_id),
"user": await app.storage.get_user(user_id),
},
)
await _sub_friend(user_id, peer_id) await _sub_friend(user_id, peer_id)
return '', 204 return "", 204
# check if friend AND not acceptance of fr # check if friend AND not acceptance of fr
if rel_type == _friend: if rel_type == _friend:
await _dispatch(user_id, 'RELATIONSHIP_ADD', { await _dispatch(
'id': str(peer_id), user_id,
'type': RelationshipType.OUTGOING.value, "RELATIONSHIP_ADD",
'user': await app.storage.get_user(peer_id), {
}) "id": str(peer_id),
"type": RelationshipType.OUTGOING.value,
"user": await app.storage.get_user(peer_id),
},
)
await _dispatch(peer_id, 'RELATIONSHIP_ADD', { await _dispatch(
'id': str(user_id), peer_id,
'type': RelationshipType.INCOMING.value, "RELATIONSHIP_ADD",
'user': await app.storage.get_user(user_id) {
}) "id": str(user_id),
"type": RelationshipType.INCOMING.value,
"user": await app.storage.get_user(user_id),
},
)
# we don't make the pubsub link # we don't make the pubsub link
# until the peer accepts the friend request # until the peer accepts the friend request
return '', 204 return "", 204
return return
class RelationshipFailed(BadRequest): class RelationshipFailed(BadRequest):
"""Exception for general relationship errors.""" """Exception for general relationship errors."""
error_code = 80004 error_code = 80004
class RelationshipBlocked(BadRequest): class RelationshipBlocked(BadRequest):
"""Exception for when the peer has blocked the user.""" """Exception for when the peer has blocked the user."""
error_code = 80001 error_code = 80001
@bp.route('/@me/relationships', methods=['POST']) @bp.route("/@me/relationships", methods=["POST"])
async def post_relationship(): async def post_relationship():
user_id = await token_check() user_id = await token_check()
j = validate(await request.get_json(), SPECIFIC_FRIEND) j = validate(await request.get_json(), SPECIFIC_FRIEND)
uid = await app.storage.search_user(j['username'], uid = await app.storage.search_user(j["username"], str(j["discriminator"]))
str(j['discriminator']))
if not uid: if not uid:
raise RelationshipFailed('No users with DiscordTag exist') raise RelationshipFailed("No users with DiscordTag exist")
res = await make_friend(user_id, uid) res = await make_friend(user_id, uid)
if res is None: if res is None:
raise RelationshipBlocked('Can not friend user due to block') raise RelationshipBlocked("Can not friend user due to block")
return '', 204 return "", 204
@bp.route('/@me/relationships/<int:peer_id>', methods=['PUT']) @bp.route("/@me/relationships/<int:peer_id>", methods=["PUT"])
async def add_relationship(peer_id: int): async def add_relationship(peer_id: int):
"""Add a relationship to the peer.""" """Add a relationship to the peer."""
user_id = await token_check() user_id = await token_check()
payload = validate(await request.get_json(), RELATIONSHIP) payload = validate(await request.get_json(), RELATIONSHIP)
rel_type = payload['type'] rel_type = payload["type"]
res = await make_friend(user_id, peer_id, rel_type) res = await make_friend(user_id, peer_id, rel_type)
@ -204,18 +240,22 @@ async def add_relationship(peer_id: int):
# make_friend did not succeed, so we # make_friend did not succeed, so we
# assume it is a block and dispatch # assume it is a block and dispatch
# the respective RELATIONSHIP_ADD. # the respective RELATIONSHIP_ADD.
await app.dispatcher.dispatch_user(user_id, 'RELATIONSHIP_ADD', { await app.dispatcher.dispatch_user(
'id': str(peer_id), user_id,
'type': RelationshipType.BLOCK.value, "RELATIONSHIP_ADD",
'user': await app.storage.get_user(peer_id) {
}) "id": str(peer_id),
"type": RelationshipType.BLOCK.value,
"user": await app.storage.get_user(peer_id),
},
)
await _unsub_friend(user_id, peer_id) await _unsub_friend(user_id, peer_id)
return '', 204 return "", 204
@bp.route('/@me/relationships/<int:peer_id>', methods=['DELETE']) @bp.route("/@me/relationships/<int:peer_id>", methods=["DELETE"])
async def remove_relationship(peer_id: int): async def remove_relationship(peer_id: int):
"""Remove an existing relationship""" """Remove an existing relationship"""
user_id = await token_check() user_id = await token_check()
@ -223,69 +263,86 @@ async def remove_relationship(peer_id: int):
_block = RelationshipType.BLOCK.value _block = RelationshipType.BLOCK.value
_dispatch = app.dispatcher.dispatch_user _dispatch = app.dispatcher.dispatch_user
rel_type = await app.db.fetchval(""" rel_type = await app.db.fetchval(
"""
SELECT rel_type SELECT rel_type
FROM relationships FROM relationships
WHERE user_id = $1 AND peer_id = $2 WHERE user_id = $1 AND peer_id = $2
""", user_id, peer_id) """,
user_id,
peer_id,
)
incoming_rel_type = await app.db.fetchval(""" incoming_rel_type = await app.db.fetchval(
"""
SELECT rel_type SELECT rel_type
FROM relationships FROM relationships
WHERE user_id = $1 AND peer_id = $2 WHERE user_id = $1 AND peer_id = $2
""", peer_id, user_id) """,
peer_id,
user_id,
)
# if any of those are friend # if any of those are friend
if _friend in (rel_type, incoming_rel_type): if _friend in (rel_type, incoming_rel_type):
# closing the friendship, have to delete both rows # closing the friendship, have to delete both rows
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM relationships DELETE FROM relationships
WHERE ( WHERE (
(user_id = $1 AND peer_id = $2) OR (user_id = $1 AND peer_id = $2) OR
(user_id = $2 AND peer_id = $1) (user_id = $2 AND peer_id = $1)
) AND rel_type = $3 ) AND rel_type = $3
""", user_id, peer_id, _friend) """,
user_id,
peer_id,
_friend,
)
# if there wasnt any mutual friendship before, # if there wasnt any mutual friendship before,
# assume they were requests of INCOMING # assume they were requests of INCOMING
# and OUTGOING. # and OUTGOING.
user_del_type = RelationshipType.OUTGOING.value if \ user_del_type = (
incoming_rel_type != _friend else _friend RelationshipType.OUTGOING.value if incoming_rel_type != _friend else _friend
)
await _dispatch(user_id, 'RELATIONSHIP_REMOVE', { await _dispatch(
'id': str(peer_id), user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}
'type': user_del_type, )
})
peer_del_type = RelationshipType.INCOMING.value if \ peer_del_type = (
incoming_rel_type != _friend else _friend RelationshipType.INCOMING.value if incoming_rel_type != _friend else _friend
)
await _dispatch(peer_id, 'RELATIONSHIP_REMOVE', { await _dispatch(
'id': str(user_id), peer_id, "RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}
'type': peer_del_type, )
})
await _unsub_friend(user_id, peer_id) await _unsub_friend(user_id, peer_id)
return '', 204 return "", 204
# was a block! # was a block!
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM relationships DELETE FROM relationships
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3 WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
""", user_id, peer_id, _block) """,
user_id,
peer_id,
_block,
)
await _dispatch(user_id, 'RELATIONSHIP_REMOVE', { await _dispatch(
'id': str(peer_id), user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block}
'type': _block, )
})
await _unsub_friend(user_id, peer_id) await _unsub_friend(user_id, peer_id)
return '', 204 return "", 204
@bp.route('/<int:peer_id>/relationships', methods=['GET']) @bp.route("/<int:peer_id>/relationships", methods=["GET"])
async def get_mutual_friends(peer_id: int): async def get_mutual_friends(peer_id: int):
"""Fetch a users' mutual friends with the current user.""" """Fetch a users' mutual friends with the current user."""
user_id = await token_check() user_id = await token_check()
@ -294,17 +351,15 @@ async def get_mutual_friends(peer_id: int):
peer = await app.storage.get_user(peer_id) peer = await app.storage.get_user(peer_id)
if not peer: if not peer:
return '', 204 return "", 204
# NOTE: maybe this could be better with pure SQL calculations # NOTE: maybe this could be better with pure SQL calculations
# but it would be beyond my current SQL knowledge, so... # but it would be beyond my current SQL knowledge, so...
user_rels = await app.user_storage.get_relationships(user_id) user_rels = await app.user_storage.get_relationships(user_id)
peer_rels = await app.user_storage.get_relationships(peer_id) peer_rels = await app.user_storage.get_relationships(peer_id)
user_friends = {rel['user']['id'] user_friends = {rel["user"]["id"] for rel in user_rels if rel["type"] == _friend}
for rel in user_rels if rel['type'] == _friend} peer_friends = {rel["user"]["id"] for rel in peer_rels if rel["type"] == _friend}
peer_friends = {rel['user']['id']
for rel in peer_rels if rel['type'] == _friend}
# get the intersection, then map them to Storage.get_user() calls # get the intersection, then map them to Storage.get_user() calls
mutual_ids = user_friends & peer_friends mutual_ids = user_friends & peer_friends
@ -312,8 +367,6 @@ async def get_mutual_friends(peer_id: int):
mutual_friends = [] mutual_friends = []
for friend_id in mutual_ids: for friend_id in mutual_ids:
mutual_friends.append( mutual_friends.append(await app.storage.get_user(int(friend_id)))
await app.storage.get_user(int(friend_id))
)
return jsonify(mutual_friends) return jsonify(mutual_friends)

View File

@ -19,21 +19,19 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from quart import Blueprint, jsonify from quart import Blueprint, jsonify
bp = Blueprint('science', __name__) bp = Blueprint("science", __name__)
@bp.route('/science', methods=['POST']) @bp.route("/science", methods=["POST"])
async def science(): async def science():
return '', 204 return "", 204
@bp.route('/applications', methods=['GET']) @bp.route("/applications", methods=["GET"])
async def applications(): async def applications():
return jsonify([]) return jsonify([])
@bp.route('/experiments', methods=['GET']) @bp.route("/experiments", methods=["GET"])
async def experiments(): async def experiments():
return jsonify({ return jsonify({"assignments": []})
'assignments': []
})

View File

@ -20,23 +20,24 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from quart import Blueprint, current_app as app, render_template_string from quart import Blueprint, current_app as app, render_template_string
from pathlib import Path from pathlib import Path
bp = Blueprint('static', __name__) bp = Blueprint("static", __name__)
@bp.route('/<path:path>') @bp.route("/<path:path>")
async def static_pages(path): async def static_pages(path):
"""Map requests from / to /static.""" """Map requests from / to /static."""
if '..' in path: if ".." in path:
return 'no', 404 return "no", 404
static_path = Path.cwd() / Path('static') / path static_path = Path.cwd() / Path("static") / path
return await app.send_static_file(str(static_path)) return await app.send_static_file(str(static_path))
@bp.route('/') @bp.route("/")
@bp.route('/api') @bp.route("/api")
async def index_handler(): async def index_handler():
"""Handler for the index page.""" """Handler for the index page."""
index_path = Path.cwd() / Path('static') / 'index.html' index_path = Path.cwd() / Path("static") / "index.html"
return await render_template_string( return await render_template_string(
index_path.read_text(), inst_name=app.config['NAME']) index_path.read_text(), inst_name=app.config["NAME"]
)

View File

@ -21,4 +21,4 @@ from .billing import bp as user_billing
from .settings import bp as user_settings from .settings import bp as user_settings
from .fake_store import bp as fake_store from .fake_store import bp as fake_store
__all__ = ['user_billing', 'user_settings', 'fake_store'] __all__ = ["user_billing", "user_settings", "fake_store"]

View File

@ -33,7 +33,7 @@ from litecord.enums import UserFlags, PremiumType
from litecord.blueprints.users import mass_user_update from litecord.blueprints.users import mass_user_update
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('users_billing', __name__) bp = Blueprint("users_billing", __name__)
class PaymentSource(Enum): class PaymentSource(Enum):
@ -68,78 +68,87 @@ class PaymentStatus:
PLAN_ID_TO_TYPE = { PLAN_ID_TO_TYPE = {
'premium_month_tier_1': PremiumType.TIER_1, "premium_month_tier_1": PremiumType.TIER_1,
'premium_month_tier_2': PremiumType.TIER_2, "premium_month_tier_2": PremiumType.TIER_2,
'premium_year_tier_1': PremiumType.TIER_1, "premium_year_tier_1": PremiumType.TIER_1,
'premium_year_tier_2': PremiumType.TIER_2, "premium_year_tier_2": PremiumType.TIER_2,
} }
# how much should a payment be, depending # how much should a payment be, depending
# of the subscription # of the subscription
AMOUNTS = { AMOUNTS = {
'premium_month_tier_1': 499, "premium_month_tier_1": 499,
'premium_month_tier_2': 999, "premium_month_tier_2": 999,
'premium_year_tier_1': 4999, "premium_year_tier_1": 4999,
'premium_year_tier_2': 9999, "premium_year_tier_2": 9999,
} }
CREATE_SUBSCRIPTION = { CREATE_SUBSCRIPTION = {
'payment_gateway_plan_id': {'type': 'string'}, "payment_gateway_plan_id": {"type": "string"},
'payment_source_id': {'coerce': int} "payment_source_id": {"coerce": int},
} }
PAYMENT_SOURCE = { PAYMENT_SOURCE = {
'billing_address': { "billing_address": {
'type': 'dict', "type": "dict",
'schema': { "schema": {
'country': {'type': 'string', 'required': True}, "country": {"type": "string", "required": True},
'city': {'type': 'string', 'required': True}, "city": {"type": "string", "required": True},
'name': {'type': 'string', 'required': True}, "name": {"type": "string", "required": True},
'line_1': {'type': 'string', 'required': False}, "line_1": {"type": "string", "required": False},
'line_2': {'type': 'string', 'required': False}, "line_2": {"type": "string", "required": False},
'postal_code': {'type': 'string', 'required': True}, "postal_code": {"type": "string", "required": True},
'state': {'type': 'string', 'required': True}, "state": {"type": "string", "required": True},
} },
}, },
'payment_gateway': {'type': 'number', 'required': True}, "payment_gateway": {"type": "number", "required": True},
'token': {'type': 'string', 'required': True}, "token": {"type": "string", "required": True},
} }
async def get_payment_source_ids(user_id: int) -> list: async def get_payment_source_ids(user_id: int) -> list:
rows = await app.db.fetch(""" rows = await app.db.fetch(
"""
SELECT id SELECT id
FROM user_payment_sources FROM user_payment_sources
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """,
user_id,
)
return [r['id'] for r in rows] return [r["id"] for r in rows]
async def get_payment_ids(user_id: int, db=None) -> list: async def get_payment_ids(user_id: int, db=None) -> list:
if not db: if not db:
db = app.db db = app.db
rows = await db.fetch(""" rows = await db.fetch(
"""
SELECT id SELECT id
FROM user_payments FROM user_payments
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """,
user_id,
)
return [r['id'] for r in rows] return [r["id"] for r in rows]
async def get_subscription_ids(user_id: int) -> list: async def get_subscription_ids(user_id: int) -> list:
rows = await app.db.fetch(""" rows = await app.db.fetch(
"""
SELECT id SELECT id
FROM user_subscriptions FROM user_subscriptions
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """,
user_id,
)
return [r['id'] for r in rows] return [r["id"] for r in rows]
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict: async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
@ -148,41 +157,44 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
if not db: if not db:
db = app.db db = app.db
source_type = await db.fetchval(""" source_type = await db.fetchval(
"""
SELECT source_type SELECT source_type
FROM user_payment_sources FROM user_payment_sources
WHERE id = $1 AND user_id = $2 WHERE id = $1 AND user_id = $2
""", source_id, user_id) """,
source_id,
user_id,
)
source_type = PaymentSource(source_type) source_type = PaymentSource(source_type)
specific_fields = { specific_fields = {
PaymentSource.PAYPAL: ['paypal_email'], PaymentSource.PAYPAL: ["paypal_email"],
PaymentSource.CREDIT: ['expires_month', 'expires_year', PaymentSource.CREDIT: ["expires_month", "expires_year", "brand", "cc_full"],
'brand', 'cc_full']
}[source_type] }[source_type]
fields = ','.join(specific_fields) fields = ",".join(specific_fields)
extras_row = await db.fetchrow(f""" extras_row = await db.fetchrow(
f"""
SELECT {fields}, billing_address, default_, id::text SELECT {fields}, billing_address, default_, id::text
FROM user_payment_sources FROM user_payment_sources
WHERE id = $1 WHERE id = $1
""", source_id) """,
source_id,
)
derow = dict(extras_row) derow = dict(extras_row)
if source_type == PaymentSource.CREDIT: if source_type == PaymentSource.CREDIT:
derow['last_4'] = derow['cc_full'][-4:] derow["last_4"] = derow["cc_full"][-4:]
derow.pop('cc_full') derow.pop("cc_full")
derow['default'] = derow['default_'] derow["default"] = derow["default_"]
derow.pop('default_') derow.pop("default_")
source = { source = {"id": str(source_id), "type": source_type.value}
'id': str(source_id),
'type': source_type.value,
}
return {**source, **derow} return {**source, **derow}
@ -192,7 +204,8 @@ async def get_subscription(subscription_id: int, db=None):
if not db: if not db:
db = app.db db = app.db
row = await db.fetchrow(""" row = await db.fetchrow(
"""
SELECT id::text, source_id::text AS payment_source_id, SELECT id::text, source_id::text AS payment_source_id,
user_id, user_id,
payment_gateway, payment_gateway_plan_id, payment_gateway, payment_gateway_plan_id,
@ -201,14 +214,16 @@ async def get_subscription(subscription_id: int, db=None):
canceled_at, s_type, status canceled_at, s_type, status
FROM user_subscriptions FROM user_subscriptions
WHERE id = $1 WHERE id = $1
""", subscription_id) """,
subscription_id,
)
drow = dict(row) drow = dict(row)
drow['type'] = drow['s_type'] drow["type"] = drow["s_type"]
drow.pop('s_type') drow.pop("s_type")
to_tstamp = ['current_period_start', 'current_period_end', 'canceled_at'] to_tstamp = ["current_period_start", "current_period_end", "canceled_at"]
for field in to_tstamp: for field in to_tstamp:
drow[field] = timestamp_(drow[field]) drow[field] = timestamp_(drow[field])
@ -221,27 +236,30 @@ async def get_payment(payment_id: int, db=None):
if not db: if not db:
db = app.db db = app.db
row = await db.fetchrow(""" row = await db.fetchrow(
"""
SELECT id::text, source_id, subscription_id, user_id, SELECT id::text, source_id, subscription_id, user_id,
amount, amount_refunded, currency, amount, amount_refunded, currency,
description, status, tax, tax_inclusive description, status, tax, tax_inclusive
FROM user_payments FROM user_payments
WHERE id = $1 WHERE id = $1
""", payment_id) """,
payment_id,
)
drow = dict(row) drow = dict(row)
drow.pop('source_id') drow.pop("source_id")
drow.pop('subscription_id') drow.pop("subscription_id")
drow.pop('user_id') drow.pop("user_id")
drow['created_at'] = snowflake_datetime(int(drow['id'])) drow["created_at"] = snowflake_datetime(int(drow["id"]))
drow['payment_source'] = await get_payment_source( drow["payment_source"] = await get_payment_source(
row['user_id'], row['source_id'], db) row["user_id"], row["source_id"], db
)
drow['subscription'] = await get_subscription( drow["subscription"] = await get_subscription(row["subscription_id"], db)
row['subscription_id'], db)
return drow return drow
@ -255,7 +273,7 @@ async def create_payment(subscription_id, db=None):
new_id = get_snowflake() new_id = get_snowflake()
amount = AMOUNTS[sub['payment_gateway_plan_id']] amount = AMOUNTS[sub["payment_gateway_plan_id"]]
await db.execute( await db.execute(
""" """
@ -266,10 +284,16 @@ async def create_payment(subscription_id, db=None):
) )
VALUES VALUES
($1, $2, $3, $4, $5, 0, $6, $7, $8, 0, false) ($1, $2, $3, $4, $5, 0, $6, $7, $8, 0, false)
""", new_id, int(sub['payment_source_id']), """,
subscription_id, int(sub['user_id']), new_id,
amount, 'usd', 'FUCK NITRO', int(sub["payment_source_id"]),
PaymentStatus.SUCCESS) subscription_id,
int(sub["user_id"]),
amount,
"usd",
"FUCK NITRO",
PaymentStatus.SUCCESS,
)
return new_id return new_id
@ -278,29 +302,34 @@ async def process_subscription(app, subscription_id: int):
"""Process a single subscription.""" """Process a single subscription."""
sub = await get_subscription(subscription_id, app.db) sub = await get_subscription(subscription_id, app.db)
user_id = int(sub['user_id']) user_id = int(sub["user_id"])
if sub['status'] != SubscriptionStatus.ACTIVE: if sub["status"] != SubscriptionStatus.ACTIVE:
log.debug('ignoring sub {}, not active', log.debug("ignoring sub {}, not active", subscription_id)
subscription_id)
return return
# if the subscription is still active # if the subscription is still active
# (should get cancelled status on failed # (should get cancelled status on failed
# payments), then we should update premium status # payments), then we should update premium status
first_payment_id = await app.db.fetchval(""" first_payment_id = await app.db.fetchval(
"""
SELECT MIN(id) SELECT MIN(id)
FROM user_payments FROM user_payments
WHERE subscription_id = $1 WHERE subscription_id = $1
""", subscription_id) """,
subscription_id,
)
first_payment_ts = snowflake_datetime(first_payment_id) first_payment_ts = snowflake_datetime(first_payment_id)
premium_since = await app.db.fetchval(""" premium_since = await app.db.fetchval(
"""
SELECT premium_since SELECT premium_since
FROM users FROM users
WHERE id = $1 WHERE id = $1
""", user_id) """,
user_id,
)
premium_since = premium_since or datetime.datetime.fromtimestamp(0) premium_since = premium_since or datetime.datetime.fromtimestamp(0)
@ -312,27 +341,34 @@ async def process_subscription(app, subscription_id: int):
if delta.total_seconds() < 24 * HOURS: if delta.total_seconds() < 24 * HOURS:
return return
old_flags = await app.db.fetchval(""" old_flags = await app.db.fetchval(
"""
SELECT flags SELECT flags
FROM users FROM users
WHERE id = $1 WHERE id = $1
""", user_id) """,
user_id,
)
new_flags = old_flags | UserFlags.premium_early new_flags = old_flags | UserFlags.premium_early
log.debug('updating flags {}, {} => {}', log.debug("updating flags {}, {} => {}", user_id, old_flags, new_flags)
user_id, old_flags, new_flags)
await app.db.execute(""" await app.db.execute(
"""
UPDATE users UPDATE users
SET premium_since = $1, flags = $2 SET premium_since = $1, flags = $2
WHERE id = $3 WHERE id = $3
""", first_payment_ts, new_flags, user_id) """,
first_payment_ts,
new_flags,
user_id,
)
# dispatch updated user to all possible clients # dispatch updated user to all possible clients
await mass_user_update(user_id, app) await mass_user_update(user_id, app)
@bp.route('/@me/billing/payment-sources', methods=['GET']) @bp.route("/@me/billing/payment-sources", methods=["GET"])
async def _get_billing_sources(): async def _get_billing_sources():
user_id = await token_check() user_id = await token_check()
source_ids = await get_payment_source_ids(user_id) source_ids = await get_payment_source_ids(user_id)
@ -346,7 +382,7 @@ async def _get_billing_sources():
return jsonify(res) return jsonify(res)
@bp.route('/@me/billing/subscriptions', methods=['GET']) @bp.route("/@me/billing/subscriptions", methods=["GET"])
async def _get_billing_subscriptions(): async def _get_billing_subscriptions():
user_id = await token_check() user_id = await token_check()
sub_ids = await get_subscription_ids(user_id) sub_ids = await get_subscription_ids(user_id)
@ -358,7 +394,7 @@ async def _get_billing_subscriptions():
return jsonify(res) return jsonify(res)
@bp.route('/@me/billing/payments', methods=['GET']) @bp.route("/@me/billing/payments", methods=["GET"])
async def _get_billing_payments(): async def _get_billing_payments():
user_id = await token_check() user_id = await token_check()
payment_ids = await get_payment_ids(user_id) payment_ids = await get_payment_ids(user_id)
@ -370,7 +406,7 @@ async def _get_billing_payments():
return jsonify(res) return jsonify(res)
@bp.route('/@me/billing/payment-sources', methods=['POST']) @bp.route("/@me/billing/payment-sources", methods=["POST"])
async def _create_payment_source(): async def _create_payment_source():
user_id = await token_check() user_id = await token_check()
j = validate(await request.get_json(), PAYMENT_SOURCE) j = validate(await request.get_json(), PAYMENT_SOURCE)
@ -383,34 +419,40 @@ async def _create_payment_source():
default_, expires_month, expires_year, brand, cc_full, default_, expires_month, expires_year, brand, cc_full,
billing_address) billing_address)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
""", new_source_id, user_id, PaymentSource.CREDIT.value, """,
True, 12, 6969, 'Visa', '4242424242424242', new_source_id,
json.dumps(j['billing_address'])) user_id,
PaymentSource.CREDIT.value,
return jsonify( True,
await get_payment_source(user_id, new_source_id) 12,
6969,
"Visa",
"4242424242424242",
json.dumps(j["billing_address"]),
) )
return jsonify(await get_payment_source(user_id, new_source_id))
@bp.route('/@me/billing/subscriptions', methods=['POST'])
@bp.route("/@me/billing/subscriptions", methods=["POST"])
async def _create_subscription(): async def _create_subscription():
user_id = await token_check() user_id = await token_check()
j = validate(await request.get_json(), CREATE_SUBSCRIPTION) j = validate(await request.get_json(), CREATE_SUBSCRIPTION)
source = await get_payment_source(user_id, j['payment_source_id']) source = await get_payment_source(user_id, j["payment_source_id"])
if not source: if not source:
raise BadRequest('invalid source id') raise BadRequest("invalid source id")
plan_id = j['payment_gateway_plan_id'] plan_id = j["payment_gateway_plan_id"]
# tier 1 is lightro / classic # tier 1 is lightro / classic
# tier 2 is nitro # tier 2 is nitro
period_end = { period_end = {
'premium_month_tier_1': '1 month', "premium_month_tier_1": "1 month",
'premium_month_tier_2': '1 month', "premium_month_tier_2": "1 month",
'premium_year_tier_1': '1 year', "premium_year_tier_1": "1 year",
'premium_year_tier_2': '1 year', "premium_year_tier_2": "1 year",
}[plan_id] }[plan_id]
new_id = get_snowflake() new_id = get_snowflake()
@ -422,9 +464,15 @@ async def _create_subscription():
status, period_end) status, period_end)
VALUES ($1, $2, $3, $4, $5, $6, $7, VALUES ($1, $2, $3, $4, $5, $6, $7,
now()::timestamp + interval '{period_end}') now()::timestamp + interval '{period_end}')
""", new_id, j['payment_source_id'], user_id, """,
SubscriptionType.PURCHASE, PaymentGateway.STRIPE, new_id,
plan_id, 1) j["payment_source_id"],
user_id,
SubscriptionType.PURCHASE,
PaymentGateway.STRIPE,
plan_id,
1,
)
await create_payment(new_id, app.db) await create_payment(new_id, app.db)
@ -432,21 +480,17 @@ async def _create_subscription():
# and dispatch respective user updates to other people. # and dispatch respective user updates to other people.
await process_subscription(app, new_id) await process_subscription(app, new_id)
return jsonify( return jsonify(await get_subscription(new_id))
await get_subscription(new_id)
)
@bp.route('/@me/billing/subscriptions/<int:subscription_id>', @bp.route("/@me/billing/subscriptions/<int:subscription_id>", methods=["DELETE"])
methods=['DELETE'])
async def _delete_subscription(subscription_id): async def _delete_subscription(subscription_id):
# user_id = await token_check() # user_id = await token_check()
# return '', 204 # return '', 204
pass pass
@bp.route('/@me/billing/subscriptions/<int:subscription_id>', @bp.route("/@me/billing/subscriptions/<int:subscription_id>", methods=["PATCH"])
methods=['PATCH'])
async def _patch_subscription(subscription_id): async def _patch_subscription(subscription_id):
"""change a subscription's payment source""" """change a subscription's payment source"""
# user_id = await token_check() # user_id = await token_check()

View File

@ -25,8 +25,11 @@ from asyncio import sleep, CancelledError
from logbook import Logger from logbook import Logger
from litecord.blueprints.user.billing import ( from litecord.blueprints.user.billing import (
get_subscription, get_payment_ids, get_payment, create_payment, get_subscription,
process_subscription get_payment_ids,
get_payment,
create_payment,
process_subscription,
) )
from litecord.snowflake import snowflake_datetime from litecord.snowflake import snowflake_datetime
@ -37,15 +40,15 @@ log = Logger(__name__)
# how many days until a payment needs # how many days until a payment needs
# to be issued # to be issued
THRESHOLDS = { THRESHOLDS = {
'premium_month_tier_1': 30, "premium_month_tier_1": 30,
'premium_month_tier_2': 30, "premium_month_tier_2": 30,
'premium_year_tier_1': 365, "premium_year_tier_1": 365,
'premium_year_tier_2': 365, "premium_year_tier_2": 365,
} }
async def _resched(app): async def _resched(app):
log.debug('waiting 30 minutes for job.') log.debug("waiting 30 minutes for job.")
await sleep(30 * MINUTES) await sleep(30 * MINUTES)
app.sched.spawn(payment_job(app)) app.sched.spawn(payment_job(app))
@ -54,10 +57,10 @@ async def _process_user_payments(app, user_id: int):
payments = await get_payment_ids(user_id, app.db) payments = await get_payment_ids(user_id, app.db)
if not payments: if not payments:
log.debug('no payments for uid {}, skipping', user_id) log.debug("no payments for uid {}, skipping", user_id)
return return
log.debug('{} payments for uid {}', len(payments), user_id) log.debug("{} payments for uid {}", len(payments), user_id)
latest_payment = max(payments) latest_payment = max(payments)
@ -66,33 +69,29 @@ async def _process_user_payments(app, user_id: int):
# calculate the difference between this payment # calculate the difference between this payment
# and now. # and now.
now = datetime.datetime.now() now = datetime.datetime.now()
payment_tstamp = snowflake_datetime(int(payment_data['id'])) payment_tstamp = snowflake_datetime(int(payment_data["id"]))
delta = now - payment_tstamp delta = now - payment_tstamp
sub_id = int(payment_data['subscription']['id']) sub_id = int(payment_data["subscription"]["id"])
subscription = await get_subscription( subscription = await get_subscription(sub_id, app.db)
sub_id, app.db)
# if the max payment is X days old, we create another. # if the max payment is X days old, we create another.
# X is 30 for monthly subscriptions of nitro, # X is 30 for monthly subscriptions of nitro,
# X is 365 for yearly subscriptions of nitro # X is 365 for yearly subscriptions of nitro
threshold = THRESHOLDS[subscription['payment_gateway_plan_id']] threshold = THRESHOLDS[subscription["payment_gateway_plan_id"]]
log.debug('delta {} delta days {} threshold {}', log.debug("delta {} delta days {} threshold {}", delta, delta.days, threshold)
delta, delta.days, threshold)
if delta.days > threshold: if delta.days > threshold:
log.info('creating payment for sid={}', log.info("creating payment for sid={}", sub_id)
sub_id)
# create_payment does not call any Stripe # create_payment does not call any Stripe
# or BrainTree APIs at all, since we'll just # or BrainTree APIs at all, since we'll just
# give it as free. # give it as free.
await create_payment(sub_id, app.db) await create_payment(sub_id, app.db)
else: else:
log.debug('sid={}, missing {} days', log.debug("sid={}, missing {} days", sub_id, threshold - delta.days)
sub_id, threshold - delta.days)
async def payment_job(app): async def payment_job(app):
@ -101,35 +100,39 @@ async def payment_job(app):
This function will check through users' payments This function will check through users' payments
and add a new one once a month / year. and add a new one once a month / year.
""" """
log.debug('payment job start!') log.debug("payment job start!")
user_ids = await app.db.fetch(""" user_ids = await app.db.fetch(
"""
SELECT DISTINCT user_id SELECT DISTINCT user_id
FROM user_payments FROM user_payments
""") """
)
log.debug('working {} users', len(user_ids)) log.debug("working {} users", len(user_ids))
# go through each user's payments # go through each user's payments
for row in user_ids: for row in user_ids:
user_id = row['user_id'] user_id = row["user_id"]
try: try:
await _process_user_payments(app, user_id) await _process_user_payments(app, user_id)
except Exception: except Exception:
log.exception('error while processing user payments') log.exception("error while processing user payments")
subscribers = await app.db.fetch(""" subscribers = await app.db.fetch(
"""
SELECT id SELECT id
FROM user_subscriptions FROM user_subscriptions
""") """
)
for row in subscribers: for row in subscribers:
try: try:
await process_subscription(app, row['id']) await process_subscription(app, row["id"])
except Exception: except Exception:
log.exception('error while processing subscription') log.exception("error while processing subscription")
log.debug('rescheduling..') log.debug("rescheduling..")
try: try:
await _resched(app) await _resched(app)
except CancelledError: except CancelledError:
log.info('cancelled while waiting for resched') log.info("cancelled while waiting for resched")

View File

@ -22,24 +22,26 @@ fake routes for discord store
""" """
from quart import Blueprint, jsonify from quart import Blueprint, jsonify
bp = Blueprint('fake_store', __name__) bp = Blueprint("fake_store", __name__)
@bp.route('/promotions') @bp.route("/promotions")
async def _get_promotions(): async def _get_promotions():
return jsonify([]) return jsonify([])
@bp.route('/users/@me/library') @bp.route("/users/@me/library")
async def _get_library(): async def _get_library():
return jsonify([]) return jsonify([])
@bp.route('/users/@me/feed/settings') @bp.route("/users/@me/feed/settings")
async def _get_feed_settings(): async def _get_feed_settings():
return jsonify({ return jsonify(
'subscribed_games': [], {
'subscribed_users': [], "subscribed_games": [],
'unsubscribed_users': [], "subscribed_users": [],
'unsubscribed_games': [], "unsubscribed_users": [],
}) "unsubscribed_games": [],
}
)

View File

@ -23,10 +23,10 @@ from litecord.auth import token_check
from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS
from litecord.blueprints.checks import guild_check from litecord.blueprints.checks import guild_check
bp = Blueprint('users_settings', __name__) bp = Blueprint("users_settings", __name__)
@bp.route('/@me/settings', methods=['GET']) @bp.route("/@me/settings", methods=["GET"])
async def get_user_settings(): async def get_user_settings():
"""Get the current user's settings.""" """Get the current user's settings."""
user_id = await token_check() user_id = await token_check()
@ -34,7 +34,7 @@ async def get_user_settings():
return jsonify(settings) return jsonify(settings)
@bp.route('/@me/settings', methods=['PATCH']) @bp.route("/@me/settings", methods=["PATCH"])
async def patch_current_settings(): async def patch_current_settings():
"""Patch the users' current settings. """Patch the users' current settings.
@ -47,19 +47,22 @@ async def patch_current_settings():
for key in j: for key in j:
val = j[key] val = j[key]
await app.storage.execute_with_json(f""" await app.storage.execute_with_json(
f"""
UPDATE user_settings UPDATE user_settings
SET {key}=$1 SET {key}=$1
WHERE id = $2 WHERE id = $2
""", val, user_id) """,
val,
user_id,
)
settings = await app.user_storage.get_user_settings(user_id) settings = await app.user_storage.get_user_settings(user_id)
await app.dispatcher.dispatch_user( await app.dispatcher.dispatch_user(user_id, "USER_SETTINGS_UPDATE", settings)
user_id, 'USER_SETTINGS_UPDATE', settings)
return jsonify(settings) return jsonify(settings)
@bp.route('/@me/guilds/<int:guild_id>/settings', methods=['PATCH']) @bp.route("/@me/guilds/<int:guild_id>/settings", methods=["PATCH"])
async def patch_guild_settings(guild_id: int): async def patch_guild_settings(guild_id: int):
"""Update the users' guild settings for a given guild. """Update the users' guild settings for a given guild.
@ -74,16 +77,21 @@ async def patch_guild_settings(guild_id: int):
# will make sure they exist in the table. # will make sure they exist in the table.
await app.user_storage.get_guild_settings_one(user_id, guild_id) await app.user_storage.get_guild_settings_one(user_id, guild_id)
for field in (k for k in j.keys() if k != 'channel_overrides'): for field in (k for k in j.keys() if k != "channel_overrides"):
await app.db.execute(f""" await app.db.execute(
f"""
UPDATE guild_settings UPDATE guild_settings
SET {field} = $1 SET {field} = $1
WHERE user_id = $2 AND guild_id = $3 WHERE user_id = $2 AND guild_id = $3
""", j[field], user_id, guild_id) """,
j[field],
user_id,
guild_id,
)
chan_ids = await app.storage.get_channel_ids(guild_id) chan_ids = await app.storage.get_channel_ids(guild_id)
for chandata in j.get('channel_overrides', {}).items(): for chandata in j.get("channel_overrides", {}).items():
chan_id, chan_overrides = chandata chan_id, chan_overrides = chandata
chan_id = int(chan_id) chan_id = int(chan_id)
@ -92,7 +100,8 @@ async def patch_guild_settings(guild_id: int):
continue continue
for field in chan_overrides: for field in chan_overrides:
await app.db.execute(f""" await app.db.execute(
f"""
INSERT INTO guild_settings_channel_overrides INSERT INTO guild_settings_channel_overrides
(user_id, guild_id, channel_id, {field}) (user_id, guild_id, channel_id, {field})
VALUES VALUES
@ -105,18 +114,21 @@ async def patch_guild_settings(guild_id: int):
WHERE guild_settings_channel_overrides.user_id = $1 WHERE guild_settings_channel_overrides.user_id = $1
AND guild_settings_channel_overrides.guild_id = $2 AND guild_settings_channel_overrides.guild_id = $2
AND guild_settings_channel_overrides.channel_id = $3 AND guild_settings_channel_overrides.channel_id = $3
""", user_id, guild_id, chan_id, chan_overrides[field]) """,
user_id,
guild_id,
chan_id,
chan_overrides[field],
)
settings = await app.user_storage.get_guild_settings_one( settings = await app.user_storage.get_guild_settings_one(user_id, guild_id)
user_id, guild_id)
await app.dispatcher.dispatch_user( await app.dispatcher.dispatch_user(user_id, "USER_GUILD_SETTINGS_UPDATE", settings)
user_id, 'USER_GUILD_SETTINGS_UPDATE', settings)
return jsonify(settings) return jsonify(settings)
@bp.route('/@me/notes/<int:target_id>', methods=['PUT']) @bp.route("/@me/notes/<int:target_id>", methods=["PUT"])
async def put_note(target_id: int): async def put_note(target_id: int):
"""Put a note to a user. """Put a note to a user.
@ -126,10 +138,11 @@ async def put_note(target_id: int):
user_id = await token_check() user_id = await token_check()
j = await request.get_json() j = await request.get_json()
note = str(j['note']) note = str(j["note"])
# UPSERTs are beautiful # UPSERTs are beautiful
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO notes (user_id, target_id, note) INSERT INTO notes (user_id, target_id, note)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
@ -138,12 +151,14 @@ async def put_note(target_id: int):
note = $3 note = $3
WHERE notes.user_id = $1 WHERE notes.user_id = $1
AND notes.target_id = $2 AND notes.target_id = $2
""", user_id, target_id, note) """,
user_id,
target_id,
note,
)
await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', { await app.dispatcher.dispatch_user(
'id': str(target_id), user_id, "USER_NOTE_UPDATE", {"id": str(target_id), "note": note}
'note': note, )
})
return '', 204
return "", 204

View File

@ -27,9 +27,7 @@ from ..errors import Forbidden, BadRequest, Unauthorized
from ..schemas import validate, USER_UPDATE, GET_MENTIONS from ..schemas import validate, USER_UPDATE, GET_MENTIONS
from .guilds import guild_check from .guilds import guild_check
from litecord.auth import ( from litecord.auth import token_check, hash_data, check_username_usage, roll_discrim
token_check, hash_data, check_username_usage, roll_discrim
)
from litecord.blueprints.guild.mod import remove_member from litecord.blueprints.guild.mod import remove_member
from litecord.enums import PremiumType from litecord.enums import PremiumType
@ -39,7 +37,7 @@ from litecord.permissions import base_permissions
from litecord.blueprints.auth import check_password from litecord.blueprints.auth import check_password
from litecord.utils import to_update from litecord.utils import to_update
bp = Blueprint('user', __name__) bp = Blueprint("user", __name__)
log = Logger(__name__) log = Logger(__name__)
@ -58,8 +56,7 @@ async def mass_user_update(user_id, app_=None):
private_user = await app_.storage.get_user(user_id, secure=True) private_user = await app_.storage.get_user(user_id, secure=True)
session_ids.extend( session_ids.extend(
await app_.dispatcher.dispatch_user( await app_.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
user_id, 'USER_UPDATE', private_user)
) )
guild_ids = await app_.user_storage.get_user_guilds(user_id) guild_ids = await app_.user_storage.get_user_guilds(user_id)
@ -67,26 +64,22 @@ async def mass_user_update(user_id, app_=None):
session_ids.extend( session_ids.extend(
await app_.dispatcher.dispatch_many_filter_list( await app_.dispatcher.dispatch_many_filter_list(
'guild', guild_ids, session_ids, "guild", guild_ids, session_ids, "USER_UPDATE", public_user
'USER_UPDATE', public_user
) )
) )
session_ids.extend( session_ids.extend(
await app_.dispatcher.dispatch_many_filter_list( await app_.dispatcher.dispatch_many_filter_list(
'friend', friend_ids, session_ids, "friend", friend_ids, session_ids, "USER_UPDATE", public_user
'USER_UPDATE', public_user
) )
) )
await app_.dispatcher.dispatch_many( await app_.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
'lazy_guild', guild_ids, 'update_user', user_id
)
return public_user, private_user return public_user, private_user
@bp.route('/@me', methods=['GET']) @bp.route("/@me", methods=["GET"])
async def get_me(): async def get_me():
"""Get the current user's information.""" """Get the current user's information."""
user_id = await token_check() user_id = await token_check()
@ -94,18 +87,21 @@ async def get_me():
return jsonify(user) return jsonify(user)
@bp.route('/<int:target_id>', methods=['GET']) @bp.route("/<int:target_id>", methods=["GET"])
async def get_other(target_id): async def get_other(target_id):
"""Get any user, given the user ID.""" """Get any user, given the user ID."""
user_id = await token_check() user_id = await token_check()
bot = await app.db.fetchval(""" bot = await app.db.fetchval(
"""
SELECT bot FROM users SELECT bot FROM users
WHERE users.id = $1 WHERE users.id = $1
""", user_id) """,
user_id,
)
if not bot: if not bot:
raise Forbidden('Only bots can use this endpoint') raise Forbidden("Only bots can use this endpoint")
other = await app.storage.get_user(target_id) other = await app.storage.get_user(target_id)
return jsonify(other) return jsonify(other)
@ -116,66 +112,80 @@ async def _try_username_patch(user_id, new_username: str) -> str:
discrim = None discrim = None
try: try:
await app.db.execute(""" await app.db.execute(
"""
UPDATE users UPDATE users
SET username = $1 SET username = $1
WHERE users.id = $2 WHERE users.id = $2
""", new_username, user_id) """,
new_username,
user_id,
)
return await app.db.fetchval(""" return await app.db.fetchval(
"""
SELECT discriminator SELECT discriminator
FROM users FROM users
WHERE users.id = $1 WHERE users.id = $1
""", user_id) """,
user_id,
)
except UniqueViolationError: except UniqueViolationError:
discrim = await roll_discrim(new_username) discrim = await roll_discrim(new_username)
if not discrim: if not discrim:
raise BadRequest('Unable to change username', { raise BadRequest(
'username': 'Too many people are with this username.' "Unable to change username",
}) {"username": "Too many people are with this username."},
)
await app.db.execute(""" await app.db.execute(
"""
UPDATE users UPDATE users
SET username = $1, discriminator = $2 SET username = $1, discriminator = $2
WHERE users.id = $3 WHERE users.id = $3
""", new_username, discrim, user_id) """,
new_username,
discrim,
user_id,
)
return discrim return discrim
async def _try_discrim_patch(user_id, new_discrim: str): async def _try_discrim_patch(user_id, new_discrim: str):
try: try:
await app.db.execute(""" await app.db.execute(
"""
UPDATE users UPDATE users
SET discriminator = $1 SET discriminator = $1
WHERE id = $2 WHERE id = $2
""", new_discrim, user_id) """,
new_discrim,
user_id,
)
except UniqueViolationError: except UniqueViolationError:
raise BadRequest('Invalid discriminator', { raise BadRequest(
'discriminator': 'Someone already used this discriminator.' "Invalid discriminator",
}) {"discriminator": "Someone already used this discriminator."},
)
async def _check_pass(j, user): async def _check_pass(j, user):
# Do not do password checks on unclaimed accounts # Do not do password checks on unclaimed accounts
if user['email'] is None: if user["email"] is None:
return return
if not j['password']: if not j["password"]:
raise BadRequest('password required', { raise BadRequest("password required", {"password": "password required"})
'password': 'password required'
})
phash = user['password_hash'] phash = user["password_hash"]
if not await check_password(phash, j['password']): if not await check_password(phash, j["password"]):
raise BadRequest('password incorrect', { raise BadRequest("password incorrect", {"password": "password does not match."})
'password': 'password does not match.'
})
@bp.route('/@me', methods=['PATCH']) @bp.route("/@me", methods=["PATCH"])
async def patch_me(): async def patch_me():
"""Patch the current user's information.""" """Patch the current user's information."""
user_id = await token_check() user_id = await token_check()
@ -183,36 +193,43 @@ async def patch_me():
j = validate(await request.get_json(), USER_UPDATE) j = validate(await request.get_json(), USER_UPDATE)
user = await app.storage.get_user(user_id, True) user = await app.storage.get_user(user_id, True)
user['password_hash'] = await app.db.fetchval(""" user["password_hash"] = await app.db.fetchval(
"""
SELECT password_hash SELECT password_hash
FROM users FROM users
WHERE id = $1 WHERE id = $1
""", user_id) """,
user_id,
)
if to_update(j, user, 'username'): if to_update(j, user, "username"):
# this will take care of regenning a new discriminator # this will take care of regenning a new discriminator
discrim = await _try_username_patch(user_id, j['username']) discrim = await _try_username_patch(user_id, j["username"])
user['username'] = j['username'] user["username"] = j["username"]
user['discriminator'] = discrim user["discriminator"] = discrim
if to_update(j, user, 'discriminator'): if to_update(j, user, "discriminator"):
# the API treats discriminators as integers, # the API treats discriminators as integers,
# but I work with strings on the database. # but I work with strings on the database.
new_discrim = str(j['discriminator']) new_discrim = str(j["discriminator"])
await _try_discrim_patch(user_id, new_discrim) await _try_discrim_patch(user_id, new_discrim)
user['discriminator'] = new_discrim user["discriminator"] = new_discrim
if to_update(j, user, 'email'): if to_update(j, user, "email"):
await _check_pass(j, user) await _check_pass(j, user)
# TODO: reverify the new email? # TODO: reverify the new email?
await app.db.execute(""" await app.db.execute(
"""
UPDATE users UPDATE users
SET email = $1 SET email = $1
WHERE id = $2 WHERE id = $2
""", j['email'], user_id) """,
user['email'] = j['email'] j["email"],
user_id,
)
user["email"] = j["email"]
# only update if values are different # only update if values are different
# from what the user gave. # from what the user gave.
@ -224,44 +241,49 @@ async def patch_me():
# IconManager.update will take care of validating # IconManager.update will take care of validating
# the value once put()-ing # the value once put()-ing
if to_update(j, user, 'avatar'): if to_update(j, user, "avatar"):
mime, _ = parse_data_uri(j['avatar']) mime, _ = parse_data_uri(j["avatar"])
if mime == 'image/gif' and user['premium_type'] == PremiumType.NONE: if mime == "image/gif" and user["premium_type"] == PremiumType.NONE:
raise BadRequest('no gif without nitro') raise BadRequest("no gif without nitro")
new_icon = await app.icons.update( new_icon = await app.icons.update("user", user_id, j["avatar"], size=(128, 128))
'user', user_id, j['avatar'], size=(128, 128))
await app.db.execute(""" await app.db.execute(
"""
UPDATE users UPDATE users
SET avatar = $1 SET avatar = $1
WHERE id = $2 WHERE id = $2
""", new_icon.icon_hash, user_id) """,
new_icon.icon_hash,
user_id,
)
if user['email'] is None and not 'new_password' in j: if user["email"] is None and not "new_password" in j:
raise BadRequest('missing password', { raise BadRequest("missing password", {"password": "Please set a password."})
'password': 'Please set a password.'
})
if 'new_password' in j and j['new_password']: if "new_password" in j and j["new_password"]:
await _check_pass(j, user) await _check_pass(j, user)
new_hash = await hash_data(j['new_password']) new_hash = await hash_data(j["new_password"])
await app.db.execute(""" await app.db.execute(
"""
UPDATE users UPDATE users
SET password_hash = $1 SET password_hash = $1
WHERE id = $2 WHERE id = $2
""", new_hash, user_id) """,
new_hash,
user_id,
)
user.pop('password_hash') user.pop("password_hash")
_, private_user = await mass_user_update(user_id, app) _, private_user = await mass_user_update(user_id, app)
return jsonify(private_user) return jsonify(private_user)
@bp.route('/@me/guilds', methods=['GET']) @bp.route("/@me/guilds", methods=["GET"])
async def get_me_guilds(): async def get_me_guilds():
"""Get partial user guilds.""" """Get partial user guilds."""
user_id = await token_check() user_id = await token_check()
@ -270,27 +292,30 @@ async def get_me_guilds():
partials = [] partials = []
for guild_id in guild_ids: for guild_id in guild_ids:
partial = await app.db.fetchrow(""" partial = await app.db.fetchrow(
"""
SELECT id::text, name, icon, owner_id SELECT id::text, name, icon, owner_id
FROM guilds FROM guilds
WHERE guilds.id = $1 WHERE guilds.id = $1
""", guild_id) """,
guild_id,
)
partial = dict(partial) partial = dict(partial)
user_perms = await base_permissions(user_id, guild_id) user_perms = await base_permissions(user_id, guild_id)
partial['permissions'] = user_perms.binary partial["permissions"] = user_perms.binary
partial['owner'] = partial['owner_id'] == user_id partial["owner"] = partial["owner_id"] == user_id
partial.pop('owner_id') partial.pop("owner_id")
partials.append(partial) partials.append(partial)
return jsonify(partials) return jsonify(partials)
@bp.route('/@me/guilds/<int:guild_id>', methods=['DELETE']) @bp.route("/@me/guilds/<int:guild_id>", methods=["DELETE"])
async def leave_guild(guild_id: int): async def leave_guild(guild_id: int):
"""Leave a guild.""" """Leave a guild."""
user_id = await token_check() user_id = await token_check()
@ -298,7 +323,7 @@ async def leave_guild(guild_id: int):
await remove_member(guild_id, user_id) await remove_member(guild_id, user_id)
return '', 204 return "", 204
# @bp.route('/@me/connections', methods=['GET']) # @bp.route('/@me/connections', methods=['GET'])
@ -306,7 +331,7 @@ async def get_connections():
pass pass
@bp.route('/@me/consent', methods=['GET', 'POST']) @bp.route("/@me/consent", methods=["GET", "POST"])
async def get_consent(): async def get_consent():
"""Always disable data collection. """Always disable data collection.
@ -314,57 +339,58 @@ async def get_consent():
by the client and ignores them, as they by the client and ignores them, as they
will always be false. will always be false.
""" """
return jsonify({ return jsonify(
'usage_statistics': { {
'consented': False, "usage_statistics": {"consented": False},
}, "personalization": {"consented": False},
'personalization': {
'consented': False,
} }
}) )
@bp.route('/@me/harvest', methods=['GET']) @bp.route("/@me/harvest", methods=["GET"])
async def get_harvest(): async def get_harvest():
"""Dummy route""" """Dummy route"""
return '', 204 return "", 204
@bp.route('/@me/activities/statistics/applications', methods=['GET']) @bp.route("/@me/activities/statistics/applications", methods=["GET"])
async def get_stats_applications(): async def get_stats_applications():
"""Dummy route for info on gameplay time and such""" """Dummy route for info on gameplay time and such"""
return jsonify([]) return jsonify([])
@bp.route('/@me/library', methods=['GET']) @bp.route("/@me/library", methods=["GET"])
async def get_library(): async def get_library():
"""Probably related to Discord Store?""" """Probably related to Discord Store?"""
return jsonify([]) return jsonify([])
@bp.route('/<int:peer_id>/profile', methods=['GET']) @bp.route("/<int:peer_id>/profile", methods=["GET"])
async def get_profile(peer_id: int): async def get_profile(peer_id: int):
"""Get a user's profile.""" """Get a user's profile."""
user_id = await token_check() user_id = await token_check()
peer = await app.storage.get_user(peer_id) peer = await app.storage.get_user(peer_id)
if not peer: if not peer:
return '', 404 return "", 404
mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id) mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id)
friends = await app.user_storage.are_friends_with(user_id, peer_id) friends = await app.user_storage.are_friends_with(user_id, peer_id)
# don't return a proper card if no guilds are being shared. # don't return a proper card if no guilds are being shared.
if not mutuals and not friends: if not mutuals and not friends:
return '', 404 return "", 404
# actual premium status is determined by that # actual premium status is determined by that
# column being NULL or not # column being NULL or not
peer_premium = await app.db.fetchval(""" peer_premium = await app.db.fetchval(
"""
SELECT premium_since SELECT premium_since
FROM users FROM users
WHERE id = $1 WHERE id = $1
""", peer_id) """,
peer_id,
)
mutual_guilds = await app.user_storage.get_mutual_guilds(user_id, peer_id) mutual_guilds = await app.user_storage.get_mutual_guilds(user_id, peer_id)
mutual_res = [] mutual_res = []
@ -372,45 +398,49 @@ async def get_profile(peer_id: int):
# ascending sorting # ascending sorting
for guild_id in sorted(mutual_guilds): for guild_id in sorted(mutual_guilds):
nick = await app.db.fetchval(""" nick = await app.db.fetchval(
"""
SELECT nickname SELECT nickname
FROM members FROM members
WHERE guild_id = $1 AND user_id = $2 WHERE guild_id = $1 AND user_id = $2
""", guild_id, peer_id) """,
guild_id,
peer_id,
)
mutual_res.append({ mutual_res.append({"id": str(guild_id), "nick": nick})
'id': str(guild_id),
'nick': nick,
})
return jsonify({ return jsonify(
'user': peer, {
'connected_accounts': [], "user": peer,
'premium_since': peer_premium, "connected_accounts": [],
'mutual_guilds': mutual_res, "premium_since": peer_premium,
}) "mutual_guilds": mutual_res,
}
)
@bp.route('/@me/mentions', methods=['GET']) @bp.route("/@me/mentions", methods=["GET"])
async def _get_mentions(): async def _get_mentions():
user_id = await token_check() user_id = await token_check()
j = validate(dict(request.args), GET_MENTIONS) j = validate(dict(request.args), GET_MENTIONS)
guild_query = 'AND messages.guild_id = $2' if 'guild_id' in j else '' guild_query = "AND messages.guild_id = $2" if "guild_id" in j else ""
role_query = "OR content LIKE '%<@&%'" if j['roles'] else '' role_query = "OR content LIKE '%<@&%'" if j["roles"] else ""
everyone_query = "OR content LIKE '%@everyone%'" if j['everyone'] else '' everyone_query = "OR content LIKE '%@everyone%'" if j["everyone"] else ""
mention_user = f'<@{user_id}>' mention_user = f"<@{user_id}>"
args = [mention_user] args = [mention_user]
if guild_query: if guild_query:
args.append(j['guild_id']) args.append(j["guild_id"])
guild_ids = await app.user_storage.get_user_guilds(user_id) guild_ids = await app.user_storage.get_user_guilds(user_id)
gids = ','.join(str(guild_id) for guild_id in guild_ids) gids = ",".join(str(guild_id) for guild_id in guild_ids)
rows = await app.db.fetch(f""" rows = await app.db.fetch(
f"""
SELECT messages.id SELECT messages.id
FROM messages FROM messages
JOIN channels ON messages.channel_id = channels.id JOIN channels ON messages.channel_id = channels.id
@ -423,20 +453,20 @@ async def _get_mentions():
{guild_query} {guild_query}
) )
LIMIT {j["limit"]} LIMIT {j["limit"]}
""", *args) """,
*args,
)
res = [] res = []
for row in rows: for row in rows:
message = await app.storage.get_message(row['id']) message = await app.storage.get_message(row["id"])
gid = int(message['guild_id']) gid = int(message["guild_id"])
# ignore messages pre-messages.guild_id # ignore messages pre-messages.guild_id
if gid not in guild_ids: if gid not in guild_ids:
continue continue
res.append( res.append(message)
message
)
return jsonify(res) return jsonify(res)
@ -449,18 +479,20 @@ def rand_hex(length: int = 8) -> str:
async def _del_from_table(db, table: str, user_id: int): async def _del_from_table(db, table: str, user_id: int):
"""Delete a row from a table.""" """Delete a row from a table."""
column = { column = {
'channel_overwrites': 'target_user', "channel_overwrites": "target_user",
'user_settings': 'id', "user_settings": "id",
'group_dm_members': 'member_id' "group_dm_members": "member_id",
}.get(table, 'user_id') }.get(table, "user_id")
res = await db.execute(f""" res = await db.execute(
f"""
DELETE FROM {table} DELETE FROM {table}
WHERE {column} = $1 WHERE {column} = $1
""", user_id) """,
user_id,
)
log.info('Deleting uid {} from {}, res: {!r}', log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res)
user_id, table, res)
async def delete_user(user_id, *, app_=None): async def delete_user(user_id, *, app_=None):
@ -470,13 +502,14 @@ async def delete_user(user_id, *, app_=None):
db = app_.db db = app_.db
new_username = f'Deleted User {rand_hex()}' new_username = f"Deleted User {rand_hex()}"
# by using a random hex in password_hash # by using a random hex in password_hash
# we break attempts at using the default '123' password hash # we break attempts at using the default '123' password hash
# to issue valid tokens for deleted users. # to issue valid tokens for deleted users.
await db.execute(""" await db.execute(
"""
UPDATE users UPDATE users
SET SET
username = $1, username = $1,
@ -490,32 +523,39 @@ async def delete_user(user_id, *, app_=None):
password_hash = $2 password_hash = $2
WHERE WHERE
id = $3 id = $3
""", new_username, rand_hex(32), user_id) """,
new_username,
rand_hex(32),
user_id,
)
# remove the user from various tables # remove the user from various tables
await _del_from_table(db, 'user_settings', user_id) await _del_from_table(db, "user_settings", user_id)
await _del_from_table(db, 'user_payment_sources', user_id) await _del_from_table(db, "user_payment_sources", user_id)
await _del_from_table(db, 'user_subscriptions', user_id) await _del_from_table(db, "user_subscriptions", user_id)
await _del_from_table(db, 'user_payments', user_id) await _del_from_table(db, "user_payments", user_id)
await _del_from_table(db, 'user_read_state', user_id) await _del_from_table(db, "user_read_state", user_id)
await _del_from_table(db, 'guild_settings', user_id) await _del_from_table(db, "guild_settings", user_id)
await _del_from_table(db, 'guild_settings_channel_overrides', user_id) await _del_from_table(db, "guild_settings_channel_overrides", user_id)
await db.execute(""" await db.execute(
"""
DELETE FROM relationships DELETE FROM relationships
WHERE user_id = $1 OR peer_id = $1 WHERE user_id = $1 OR peer_id = $1
""", user_id) """,
user_id,
)
# DMs are still maintained, but not the state. # DMs are still maintained, but not the state.
await _del_from_table(db, 'dm_channel_state', user_id) await _del_from_table(db, "dm_channel_state", user_id)
# NOTE: we don't delete the group dms the user is an owner of... # NOTE: we don't delete the group dms the user is an owner of...
# TODO: group dm owner reassign when the owner leaves a gdm # TODO: group dm owner reassign when the owner leaves a gdm
await _del_from_table(db, 'group_dm_members', user_id) await _del_from_table(db, "group_dm_members", user_id)
await _del_from_table(db, 'members', user_id) await _del_from_table(db, "members", user_id)
await _del_from_table(db, 'member_roles', user_id) await _del_from_table(db, "member_roles", user_id)
await _del_from_table(db, 'channel_overwrites', user_id) await _del_from_table(db, "channel_overwrites", user_id)
# after updating the user, we send USER_UPDATE so that all the other # after updating the user, we send USER_UPDATE so that all the other
# clients can refresh their caches on the now-deleted user # clients can refresh their caches on the now-deleted user
@ -540,15 +580,12 @@ async def user_disconnect(user_id: int):
await state.ws.ws.close(4000) await state.ws.ws.close(4000)
# force everyone to see the user as offline # force everyone to see the user as offline
await app.presence.dispatch_pres(user_id, { await app.presence.dispatch_pres(
'afk': False, user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
'status': 'offline', )
'game': None,
'since': 0,
})
@bp.route('/@me/delete', methods=['POST']) @bp.route("/@me/delete", methods=["POST"])
async def delete_account(): async def delete_account():
"""Delete own account. """Delete own account.
@ -560,29 +597,35 @@ async def delete_account():
j = await request.get_json() j = await request.get_json()
try: try:
password = j['password'] password = j["password"]
except KeyError: except KeyError:
raise BadRequest('password required') raise BadRequest("password required")
owned_guilds = await app.db.fetchval(""" owned_guilds = await app.db.fetchval(
"""
SELECT COUNT(*) SELECT COUNT(*)
FROM guilds FROM guilds
WHERE owner_id = $1 WHERE owner_id = $1
""", user_id) """,
user_id,
)
if owned_guilds > 0: if owned_guilds > 0:
raise BadRequest('You still own guilds.') raise BadRequest("You still own guilds.")
pwd_hash = await app.db.fetchval(""" pwd_hash = await app.db.fetchval(
"""
SELECT password_hash SELECT password_hash
FROM users FROM users
WHERE id = $1 WHERE id = $1
""", user_id) """,
user_id,
)
if not await check_password(pwd_hash, password): if not await check_password(pwd_hash, password):
raise Unauthorized('password does not match') raise Unauthorized("password does not match")
await delete_user(user_id) await delete_user(user_id)
await user_disconnect(user_id) await user_disconnect(user_id)
return '', 204 return "", 204

View File

@ -25,7 +25,7 @@ from quart import Blueprint, jsonify, current_app as app
from litecord.blueprints.auth import token_check from litecord.blueprints.auth import token_check
bp = Blueprint('voice', __name__) bp = Blueprint("voice", __name__)
def _majority_region_count(regions: list) -> str: def _majority_region_count(regions: list) -> str:
@ -39,12 +39,14 @@ def _majority_region_count(regions: list) -> str:
async def _choose_random_region() -> Optional[str]: async def _choose_random_region() -> Optional[str]:
"""Give a random voice region.""" """Give a random voice region."""
regions = await app.db.fetch(""" regions = await app.db.fetch(
"""
SELECT id SELECT id
FROM voice_regions FROM voice_regions
""") """
)
regions = [r['id'] for r in regions] regions = [r["id"] for r in regions]
if not regions: if not regions:
return None return None
@ -64,11 +66,14 @@ async def _majority_region_any(user_id) -> Optional[str]:
res = [] res = []
for guild_id in guilds: for guild_id in guilds:
region = await app.db.fetchval(""" region = await app.db.fetchval(
"""
SELECT region SELECT region
FROM guilds FROM guilds
WHERE id = $1 WHERE id = $1
""", guild_id) """,
guild_id,
)
res.append(region) res.append(region)
@ -83,20 +88,23 @@ async def _majority_region_any(user_id) -> Optional[str]:
async def majority_region(user_id: int) -> Optional[str]: async def majority_region(user_id: int) -> Optional[str]:
"""Given a user ID, give the most likely region for the user to be """Given a user ID, give the most likely region for the user to be
happy with.""" happy with."""
regions = await app.db.fetch(""" regions = await app.db.fetch(
"""
SELECT region SELECT region
FROM guilds FROM guilds
WHERE owner_id = $1 WHERE owner_id = $1
""", user_id) """,
user_id,
)
if not regions: if not regions:
return await _majority_region_any(user_id) return await _majority_region_any(user_id)
regions = [r['region'] for r in regions] regions = [r["region"] for r in regions]
return _majority_region_count(regions) return _majority_region_count(regions)
@bp.route('/regions', methods=['GET']) @bp.route("/regions", methods=["GET"])
async def voice_regions(): async def voice_regions():
"""Return voice regions.""" """Return voice regions."""
user_id = await token_check() user_id = await token_check()
@ -105,6 +113,6 @@ async def voice_regions():
regions = await app.storage.all_voice_regions() regions = await app.storage.all_voice_regions()
for region in regions: for region in regions:
region['optimal'] = region['id'] == best_region region["optimal"] = region["id"] == best_region
return jsonify(regions) return jsonify(regions)

View File

@ -26,22 +26,28 @@ from quart import Blueprint, jsonify, current_app as app, request
from litecord.auth import token_check from litecord.auth import token_check
from litecord.blueprints.checks import ( from litecord.blueprints.checks import (
channel_check, channel_perm_check, guild_check, guild_perm_check channel_check,
channel_perm_check,
guild_check,
guild_perm_check,
) )
from litecord.schemas import ( from litecord.schemas import (
validate, WEBHOOK_CREATE, WEBHOOK_UPDATE, WEBHOOK_MESSAGE_CREATE validate,
WEBHOOK_CREATE,
WEBHOOK_UPDATE,
WEBHOOK_MESSAGE_CREATE,
) )
from litecord.enums import ChannelType from litecord.enums import ChannelType
from litecord.snowflake import get_snowflake from litecord.snowflake import get_snowflake
from litecord.utils import async_map from litecord.utils import async_map
from litecord.errors import ( from litecord.errors import WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest
WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest
)
from litecord.blueprints.channel.messages import ( from litecord.blueprints.channel.messages import (
msg_create_request, msg_create_check_content, msg_add_attachment, msg_create_request,
msg_guild_text_mentions msg_create_check_content,
msg_add_attachment,
msg_guild_text_mentions,
) )
from litecord.embed.sanitizer import fill_embed, fetch_raw_img from litecord.embed.sanitizer import fill_embed, fetch_raw_img
from litecord.embed.messages import process_url_embed, is_media_url from litecord.embed.messages import process_url_embed, is_media_url
@ -50,30 +56,34 @@ from litecord.utils import pg_set_json
from litecord.enums import MessageType from litecord.enums import MessageType
from litecord.images import STATIC_IMAGE_MIMES from litecord.images import STATIC_IMAGE_MIMES
bp = Blueprint('webhooks', __name__) bp = Blueprint("webhooks", __name__)
async def get_webhook(webhook_id: int, *, async def get_webhook(
secure: bool=True) -> Optional[Dict[str, Any]]: webhook_id: int, *, secure: bool = True
) -> Optional[Dict[str, Any]]:
"""Get a webhook data""" """Get a webhook data"""
row = await app.db.fetchrow(""" row = await app.db.fetchrow(
"""
SELECT id::text, guild_id::text, channel_id::text, creator_id, SELECT id::text, guild_id::text, channel_id::text, creator_id,
name, avatar, token name, avatar, token
FROM webhooks FROM webhooks
WHERE id = $1 WHERE id = $1
""", webhook_id) """,
webhook_id,
)
if not row: if not row:
return None return None
drow = dict(row) drow = dict(row)
drow['user'] = await app.storage.get_user(row['creator_id']) drow["user"] = await app.storage.get_user(row["creator_id"])
drow.pop('creator_id') drow.pop("creator_id")
if not secure: if not secure:
drow.pop('user') drow.pop("user")
drow.pop('guild_id') drow.pop("guild_id")
return drow return drow
@ -82,7 +92,7 @@ async def _webhook_check(channel_id):
user_id = await token_check() user_id = await token_check()
await channel_check(user_id, channel_id, only=ChannelType.GUILD_TEXT) await channel_check(user_id, channel_id, only=ChannelType.GUILD_TEXT)
await channel_perm_check(user_id, channel_id, 'manage_webhooks') await channel_perm_check(user_id, channel_id, "manage_webhooks")
return user_id return user_id
@ -91,17 +101,20 @@ async def _webhook_check_guild(guild_id):
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_webhooks') await guild_perm_check(user_id, guild_id, "manage_webhooks")
return user_id return user_id
async def _webhook_check_fw(webhook_id): async def _webhook_check_fw(webhook_id):
"""Make a check from an incoming webhook id (fw = from webhook).""" """Make a check from an incoming webhook id (fw = from webhook)."""
guild_id = await app.db.fetchval(""" guild_id = await app.db.fetchval(
"""
SELECT guild_id FROM webhooks SELECT guild_id FROM webhooks
WHERE id = $1 WHERE id = $1
""", webhook_id) """,
webhook_id,
)
if guild_id is None: if guild_id is None:
raise WebhookNotFound() raise WebhookNotFound()
@ -110,42 +123,48 @@ async def _webhook_check_fw(webhook_id):
async def _webhook_many(where_clause, arg: int): async def _webhook_many(where_clause, arg: int):
webhook_ids = await app.db.fetch(f""" webhook_ids = await app.db.fetch(
f"""
SELECT id SELECT id
FROM webhooks FROM webhooks
{where_clause} {where_clause}
""", arg) """,
arg,
webhook_ids = [r['id'] for r in webhook_ids]
return jsonify(
await async_map(get_webhook, webhook_ids)
) )
webhook_ids = [r["id"] for r in webhook_ids]
return jsonify(await async_map(get_webhook, webhook_ids))
async def webhook_token_check(webhook_id: int, webhook_token: str): async def webhook_token_check(webhook_id: int, webhook_token: str):
"""token_check() equivalent for webhooks.""" """token_check() equivalent for webhooks."""
row = await app.db.fetchrow(""" row = await app.db.fetchrow(
"""
SELECT guild_id, channel_id SELECT guild_id, channel_id
FROM webhooks FROM webhooks
WHERE id = $1 AND token = $2 WHERE id = $1 AND token = $2
""", webhook_id, webhook_token) """,
webhook_id,
webhook_token,
)
if row is None: if row is None:
raise Unauthorized('webhook not found or unauthorized') raise Unauthorized("webhook not found or unauthorized")
return row['guild_id'], row['channel_id'] return row["guild_id"], row["channel_id"]
async def _dispatch_webhook_update(guild_id: int, channel_id): async def _dispatch_webhook_update(guild_id: int, channel_id):
await app.dispatcher.dispatch('guild', guild_id, 'WEBHOOKS_UPDATE', { await app.dispatcher.dispatch(
'guild_id': str(guild_id), "guild",
'channel_id': str(channel_id) guild_id,
}) "WEBHOOKS_UPDATE",
{"guild_id": str(guild_id), "channel_id": str(channel_id)},
)
@bp.route("/channels/<int:channel_id>/webhooks", methods=["POST"])
@bp.route('/channels/<int:channel_id>/webhooks', methods=['POST'])
async def create_webhook(channel_id: int): async def create_webhook(channel_id: int):
"""Create a webhook given a channel.""" """Create a webhook given a channel."""
user_id = await _webhook_check(channel_id) user_id = await _webhook_check(channel_id)
@ -162,8 +181,7 @@ async def create_webhook(channel_id: int):
token = secrets.token_urlsafe(40) token = secrets.token_urlsafe(40)
webhook_icon = await app.icons.put( webhook_icon = await app.icons.put(
'user', webhook_id, j.get('avatar'), "user", webhook_id, j.get("avatar"), always_icon=True, size=(128, 128)
always_icon=True, size=(128, 128)
) )
await app.db.execute( await app.db.execute(
@ -173,36 +191,41 @@ async def create_webhook(channel_id: int):
VALUES VALUES
($1, $2, $3, $4, $5, $6, $7) ($1, $2, $3, $4, $5, $6, $7)
""", """,
webhook_id, guild_id, channel_id, user_id, webhook_id,
j['name'], webhook_icon.icon_hash, token guild_id,
channel_id,
user_id,
j["name"],
webhook_icon.icon_hash,
token,
) )
await _dispatch_webhook_update(guild_id, channel_id) await _dispatch_webhook_update(guild_id, channel_id)
return jsonify(await get_webhook(webhook_id)) return jsonify(await get_webhook(webhook_id))
@bp.route('/channels/<int:channel_id>/webhooks', methods=['GET']) @bp.route("/channels/<int:channel_id>/webhooks", methods=["GET"])
async def get_channel_webhook(channel_id: int): async def get_channel_webhook(channel_id: int):
"""Get a list of webhooks in a channel""" """Get a list of webhooks in a channel"""
await _webhook_check(channel_id) await _webhook_check(channel_id)
return await _webhook_many('WHERE channel_id = $1', channel_id) return await _webhook_many("WHERE channel_id = $1", channel_id)
@bp.route('/guilds/<int:guild_id>/webhooks', methods=['GET']) @bp.route("/guilds/<int:guild_id>/webhooks", methods=["GET"])
async def get_guild_webhook(guild_id): async def get_guild_webhook(guild_id):
"""Get all webhooks in a guild""" """Get all webhooks in a guild"""
await _webhook_check_guild(guild_id) await _webhook_check_guild(guild_id)
return await _webhook_many('WHERE guild_id = $1', guild_id) return await _webhook_many("WHERE guild_id = $1", guild_id)
@bp.route('/webhooks/<int:webhook_id>', methods=['GET']) @bp.route("/webhooks/<int:webhook_id>", methods=["GET"])
async def get_single_webhook(webhook_id): async def get_single_webhook(webhook_id):
"""Get a single webhook's information.""" """Get a single webhook's information."""
await _webhook_check_fw(webhook_id) await _webhook_check_fw(webhook_id)
return await jsonify(await get_webhook(webhook_id)) return await jsonify(await get_webhook(webhook_id))
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['GET']) @bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["GET"])
async def get_tokened_webhook(webhook_id, webhook_token): async def get_tokened_webhook(webhook_id, webhook_token):
"""Get a webhook using its token.""" """Get a webhook using its token."""
await webhook_token_check(webhook_id, webhook_token) await webhook_token_check(webhook_id, webhook_token)
@ -210,46 +233,58 @@ async def get_tokened_webhook(webhook_id, webhook_token):
async def _update_webhook(webhook_id: int, j: dict): async def _update_webhook(webhook_id: int, j: dict):
if 'name' in j: if "name" in j:
await app.db.execute(""" await app.db.execute(
"""
UPDATE webhooks UPDATE webhooks
SET name = $1 SET name = $1
WHERE id = $2 WHERE id = $2
""", j['name'], webhook_id) """,
j["name"],
webhook_id,
)
if 'channel_id' in j: if "channel_id" in j:
await app.db.execute(""" await app.db.execute(
"""
UPDATE webhooks UPDATE webhooks
SET channel_id = $1 SET channel_id = $1
WHERE id = $2 WHERE id = $2
""", j['channel_id'], webhook_id) """,
j["channel_id"],
if 'avatar' in j: webhook_id,
new_icon = await app.icons.update(
'user', webhook_id, j['avatar'], always_icon=True, size=(128, 128)
) )
await app.db.execute(""" if "avatar" in j:
new_icon = await app.icons.update(
"user", webhook_id, j["avatar"], always_icon=True, size=(128, 128)
)
await app.db.execute(
"""
UPDATE webhooks UPDATE webhooks
SET icon = $1 SET icon = $1
WHERE id = $2 WHERE id = $2
""", new_icon.icon_hash, webhook_id) """,
new_icon.icon_hash,
webhook_id,
)
@bp.route('/webhooks/<int:webhook_id>', methods=['PATCH']) @bp.route("/webhooks/<int:webhook_id>", methods=["PATCH"])
async def modify_webhook(webhook_id: int): async def modify_webhook(webhook_id: int):
"""Patch a webhook.""" """Patch a webhook."""
_user_id, guild_id = await _webhook_check_fw(webhook_id) _user_id, guild_id = await _webhook_check_fw(webhook_id)
j = validate(await request.get_json(), WEBHOOK_UPDATE) j = validate(await request.get_json(), WEBHOOK_UPDATE)
if 'channel_id' in j: if "channel_id" in j:
# pre checks # pre checks
chan = await app.storage.get_channel(j['channel_id']) chan = await app.storage.get_channel(j["channel_id"])
# short-circuiting should ensure chan isn't none # short-circuiting should ensure chan isn't none
# by the time we do chan['guild_id'] # by the time we do chan['guild_id']
if chan and chan['guild_id'] != str(guild_id): if chan and chan["guild_id"] != str(guild_id):
raise ChannelNotFound('cant assign webhook to channel') raise ChannelNotFound("cant assign webhook to channel")
await _update_webhook(webhook_id, j) await _update_webhook(webhook_id, j)
@ -257,20 +292,18 @@ async def modify_webhook(webhook_id: int):
# we don't need to cast channel_id to int since that isn't # we don't need to cast channel_id to int since that isn't
# used in the dispatcher call # used in the dispatcher call
await _dispatch_webhook_update(guild_id, webhook['channel_id']) await _dispatch_webhook_update(guild_id, webhook["channel_id"])
return jsonify(webhook) return jsonify(webhook)
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['PATCH']) @bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["PATCH"])
async def modify_webhook_tokened(webhook_id, webhook_token): async def modify_webhook_tokened(webhook_id, webhook_token):
"""Modify a webhook, using its token.""" """Modify a webhook, using its token."""
guild_id, channel_id = await webhook_token_check( guild_id, channel_id = await webhook_token_check(webhook_id, webhook_token)
webhook_id, webhook_token)
# forcefully pop() the channel id out of the schema # forcefully pop() the channel id out of the schema
# instead of making another, for simplicity's sake # instead of making another, for simplicity's sake
j = validate(await request.get_json(), j = validate(await request.get_json(), WEBHOOK_UPDATE.pop("channel_id"))
WEBHOOK_UPDATE.pop('channel_id'))
await _update_webhook(webhook_id, j) await _update_webhook(webhook_id, j)
await _dispatch_webhook_update(guild_id, channel_id) await _dispatch_webhook_update(guild_id, channel_id)
@ -281,35 +314,36 @@ async def delete_webhook(webhook_id: int):
"""Delete a webhook.""" """Delete a webhook."""
webhook = await get_webhook(webhook_id) webhook = await get_webhook(webhook_id)
res = await app.db.execute(""" res = await app.db.execute(
"""
DELETE FROM webhooks DELETE FROM webhooks
WHERE id = $1 WHERE id = $1
""", webhook_id) """,
webhook_id,
)
if res.lower() == 'delete 0': if res.lower() == "delete 0":
raise WebhookNotFound() raise WebhookNotFound()
# only casting the guild id since that's whats used # only casting the guild id since that's whats used
# on the dispatcher call. # on the dispatcher call.
await _dispatch_webhook_update( await _dispatch_webhook_update(int(webhook["guild_id"]), webhook["channel_id"])
int(webhook['guild_id']), webhook['channel_id']
)
@bp.route('/webhooks/<int:webhook_id>', methods=['DELETE']) @bp.route("/webhooks/<int:webhook_id>", methods=["DELETE"])
async def del_webhook(webhook_id): async def del_webhook(webhook_id):
"""Delete a webhook.""" """Delete a webhook."""
await _webhook_check_fw(webhook_id) await _webhook_check_fw(webhook_id)
await delete_webhook(webhook_id) await delete_webhook(webhook_id)
return '', 204 return "", 204
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['DELETE']) @bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["DELETE"])
async def del_webhook_tokened(webhook_id, webhook_token): async def del_webhook_tokened(webhook_id, webhook_token):
"""Delete a webhook, with its token.""" """Delete a webhook, with its token."""
await webhook_token_check(webhook_id, webhook_token) await webhook_token_check(webhook_id, webhook_token)
await delete_webhook(webhook_id) await delete_webhook(webhook_id)
return '', 204 return "", 204
async def create_message_webhook(guild_id, channel_id, webhook_id, data): async def create_message_webhook(guild_id, channel_id, webhook_id, data):
@ -328,23 +362,27 @@ async def create_message_webhook(guild_id, channel_id, webhook_id, data):
message_id, message_id,
channel_id, channel_id,
guild_id, guild_id,
data['content'], data["content"],
data["tts"],
data['tts'], data["everyone_mention"],
data['everyone_mention'],
MessageType.DEFAULT.value, MessageType.DEFAULT.value,
data.get('embeds', []) data.get("embeds", []),
) )
info = data['info'] info = data["info"]
await conn.execute(""" await conn.execute(
"""
INSERT INTO message_webhook_info INSERT INTO message_webhook_info
(message_id, webhook_id, name, avatar) (message_id, webhook_id, name, avatar)
VALUES VALUES
($1, $2, $3, $4) ($1, $2, $3, $4)
""", message_id, webhook_id, info['name'], info['avatar']) """,
message_id,
webhook_id,
info["name"],
info["avatar"],
)
return message_id return message_id
@ -354,10 +392,15 @@ async def _webhook_avy_redir(webhook_id: int, avatar_url: EmbedURL):
url_hash = hashlib.sha256(avatar_url.to_md_path.encode()).hexdigest() url_hash = hashlib.sha256(avatar_url.to_md_path.encode()).hexdigest()
try: try:
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO webhook_avatars (webhook_id, hash, md_url_redir) INSERT INTO webhook_avatars (webhook_id, hash, md_url_redir)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
""", webhook_id, url_hash, avatar_url.url) """,
webhook_id,
url_hash,
avatar_url.url,
)
except asyncpg.UniqueViolationError: except asyncpg.UniqueViolationError:
pass pass
@ -371,36 +414,36 @@ async def _create_avatar(webhook_id: int, avatar_url: EmbedURL) -> str:
Litecord will write an URL that redirects to the given avatar_url, Litecord will write an URL that redirects to the given avatar_url,
using mediaproxy. using mediaproxy.
""" """
if avatar_url.scheme not in ('http', 'https'): if avatar_url.scheme not in ("http", "https"):
raise BadRequest('invalid avatar url scheme') raise BadRequest("invalid avatar url scheme")
if not is_media_url(avatar_url): if not is_media_url(avatar_url):
raise BadRequest('url is not media url') raise BadRequest("url is not media url")
# we still fetch the URL to check its validity, mimetypes, etc # we still fetch the URL to check its validity, mimetypes, etc
# but in the end, we will store it under the webhook_avatars table, # but in the end, we will store it under the webhook_avatars table,
# not IconManager. # not IconManager.
resp, raw = await fetch_raw_img(avatar_url) resp, raw = await fetch_raw_img(avatar_url)
#raw_b64 = base64.b64encode(raw).decode() # raw_b64 = base64.b64encode(raw).decode()
mime = resp.headers['content-type'] mime = resp.headers["content-type"]
# TODO: apng checks are missing (for this and everywhere else) # TODO: apng checks are missing (for this and everywhere else)
if mime not in STATIC_IMAGE_MIMES: if mime not in STATIC_IMAGE_MIMES:
raise BadRequest('invalid mime type for given url') raise BadRequest("invalid mime type for given url")
#b64_data = f'data:{mime};base64,{raw_b64}' # b64_data = f'data:{mime};base64,{raw_b64}'
# TODO: replace this by webhook_avatars # TODO: replace this by webhook_avatars
#icon = await app.icons.put( # icon = await app.icons.put(
# 'user', webhook_id, b64_data, # 'user', webhook_id, b64_data,
# always_icon=True, size=(128, 128) # always_icon=True, size=(128, 128)
#) # )
return await _webhook_avy_redir(webhook_id, avatar_url) return await _webhook_avy_redir(webhook_id, avatar_url)
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['POST']) @bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["POST"])
async def execute_webhook(webhook_id: int, webhook_token): async def execute_webhook(webhook_id: int, webhook_token):
"""Execute a webhook. Sends a message to the channel the webhook """Execute a webhook. Sends a message to the channel the webhook
is tied to.""" is tied to."""
@ -413,41 +456,39 @@ async def execute_webhook(webhook_id: int, webhook_token):
# NOTE: we really pop here instead of adding a kwarg # NOTE: we really pop here instead of adding a kwarg
# to msg_create_request just because of webhooks. # to msg_create_request just because of webhooks.
# nonce isn't allowed on WEBHOOK_MESSAGE_CREATE # nonce isn't allowed on WEBHOOK_MESSAGE_CREATE
payload_json.pop('nonce') payload_json.pop("nonce")
j = validate(payload_json, WEBHOOK_MESSAGE_CREATE) j = validate(payload_json, WEBHOOK_MESSAGE_CREATE)
msg_create_check_content(j, files) msg_create_check_content(j, files)
# webhooks don't need permissions. # webhooks don't need permissions.
mentions_everyone = '@everyone' in j['content'] mentions_everyone = "@everyone" in j["content"]
mentions_here = '@here' in j['content'] mentions_here = "@here" in j["content"]
given_embeds = j.get('embeds', []) given_embeds = j.get("embeds", [])
webhook = await get_webhook(webhook_id) webhook = await get_webhook(webhook_id)
# webhooks have TWO avatars. one is from settings, the other is from # webhooks have TWO avatars. one is from settings, the other is from
# the json's icon_url. one can be handled gracefully by IconManager, # the json's icon_url. one can be handled gracefully by IconManager,
# but the other can't, at all. # but the other can't, at all.
avatar = webhook['avatar'] avatar = webhook["avatar"]
if 'avatar_url' in j and j['avatar_url'] is not None: if "avatar_url" in j and j["avatar_url"] is not None:
avatar = await _create_avatar(webhook_id, j['avatar_url']) avatar = await _create_avatar(webhook_id, j["avatar_url"])
message_id = await create_message_webhook( message_id = await create_message_webhook(
guild_id, channel_id, webhook_id, { guild_id,
'content': j.get('content', ''), channel_id,
'tts': j.get('tts', False), webhook_id,
{
'everyone_mention': mentions_everyone or mentions_here, "content": j.get("content", ""),
'embeds': await async_map(fill_embed, given_embeds), "tts": j.get("tts", False),
"everyone_mention": mentions_everyone or mentions_here,
'info': { "embeds": await async_map(fill_embed, given_embeds),
'name': j.get('username', webhook['name']), "info": {"name": j.get("username", webhook["name"]), "avatar": avatar},
'avatar': avatar },
}
}
) )
for pre_attachment in files: for pre_attachment in files:
@ -455,33 +496,28 @@ async def execute_webhook(webhook_id: int, webhook_token):
payload = await app.storage.get_message(message_id) payload = await app.storage.get_message(message_id)
await app.dispatcher.dispatch('channel', channel_id, await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
'MESSAGE_CREATE', payload)
# spawn embedder in the background, even when we're on a webhook. # spawn embedder in the background, even when we're on a webhook.
app.sched.spawn( app.sched.spawn(
process_url_embed( process_url_embed(app.config, app.storage, app.dispatcher, app.session, payload)
app.config, app.storage, app.dispatcher, app.session,
payload
)
) )
# we can assume its a guild text channel, so just call it # we can assume its a guild text channel, so just call it
await msg_guild_text_mentions( await msg_guild_text_mentions(payload, guild_id, mentions_everyone, mentions_here)
payload, guild_id, mentions_everyone, mentions_here)
# TODO: is it really 204? # TODO: is it really 204?
return '', 204 return "", 204
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>/slack',
methods=['POST']) @bp.route("/webhooks/<int:webhook_id>/<webhook_token>/slack", methods=["POST"])
async def execute_slack_webhook(webhook_id, webhook_token): async def execute_slack_webhook(webhook_id, webhook_token):
"""Execute a webhook but expecting Slack data.""" """Execute a webhook but expecting Slack data."""
# TODO: know slack webhooks # TODO: know slack webhooks
await webhook_token_check(webhook_id, webhook_token) await webhook_token_check(webhook_id, webhook_token)
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>/github', methods=['POST']) @bp.route("/webhooks/<int:webhook_id>/<webhook_token>/github", methods=["POST"])
async def execute_github_webhook(webhook_id, webhook_token): async def execute_github_webhook(webhook_id, webhook_token):
"""Execute a webhook but expecting GitHub data.""" """Execute a webhook but expecting GitHub data."""
# TODO: know github webhooks # TODO: know github webhooks

View File

@ -21,9 +21,14 @@ from typing import List, Any, Dict
from logbook import Logger from logbook import Logger
from .pubsub import GuildDispatcher, MemberDispatcher, \ from .pubsub import (
UserDispatcher, ChannelDispatcher, FriendDispatcher, \ GuildDispatcher,
LazyGuildDispatcher MemberDispatcher,
UserDispatcher,
ChannelDispatcher,
FriendDispatcher,
LazyGuildDispatcher,
)
log = Logger(__name__) log = Logger(__name__)
@ -44,17 +49,18 @@ class EventDispatcher:
when dispatching, the backend can do its own logic, given when dispatching, the backend can do its own logic, given
its subscriber ids. its subscriber ids.
""" """
def __init__(self, app): def __init__(self, app):
self.state_manager = app.state_manager self.state_manager = app.state_manager
self.app = app self.app = app
self.backends = { self.backends = {
'guild': GuildDispatcher(self), "guild": GuildDispatcher(self),
'member': MemberDispatcher(self), "member": MemberDispatcher(self),
'channel': ChannelDispatcher(self), "channel": ChannelDispatcher(self),
'user': UserDispatcher(self), "user": UserDispatcher(self),
'friend': FriendDispatcher(self), "friend": FriendDispatcher(self),
'lazy_guild': LazyGuildDispatcher(self), "lazy_guild": LazyGuildDispatcher(self),
} }
async def action(self, backend_str: str, action: str, key, identifier, *args): async def action(self, backend_str: str, action: str, key, identifier, *args):
@ -71,13 +77,13 @@ class EventDispatcher:
return await method(key, identifier, *args) return await method(key, identifier, *args)
async def subscribe(self, backend: str, key: Any, identifier: Any, async def subscribe(
flags: Dict[str, Any] = None): self, backend: str, key: Any, identifier: Any, flags: Dict[str, Any] = None
):
"""Subscribe a single element to the given backend.""" """Subscribe a single element to the given backend."""
flags = flags or {} flags = flags or {}
log.debug('SUB backend={} key={} <= id={}', log.debug("SUB backend={} key={} <= id={}", backend, key, identifier, backend)
backend, key, identifier, backend)
# this is a hacky solution for backwards compatibility between backends # this is a hacky solution for backwards compatibility between backends
# that implement flags and backends that don't. # that implement flags and backends that don't.
@ -85,16 +91,15 @@ class EventDispatcher:
# passing flags to backends that don't implement flags will # passing flags to backends that don't implement flags will
# cause errors as expected. # cause errors as expected.
if flags: if flags:
return await self.action(backend, 'sub', key, identifier, flags) return await self.action(backend, "sub", key, identifier, flags)
return await self.action(backend, 'sub', key, identifier) return await self.action(backend, "sub", key, identifier)
async def unsubscribe(self, backend: str, key: Any, identifier: Any): async def unsubscribe(self, backend: str, key: Any, identifier: Any):
"""Unsubscribe an element from the given backend.""" """Unsubscribe an element from the given backend."""
log.debug('UNSUB backend={} key={} => id={}', log.debug("UNSUB backend={} key={} => id={}", backend, key, identifier, backend)
backend, key, identifier, backend)
return await self.action(backend, 'unsub', key, identifier) return await self.action(backend, "unsub", key, identifier)
async def sub(self, backend, key, identifier): async def sub(self, backend, key, identifier):
"""Alias to subscribe().""" """Alias to subscribe()."""
@ -104,8 +109,13 @@ class EventDispatcher:
"""Alias to unsubscribe().""" """Alias to unsubscribe()."""
return await self.unsubscribe(backend, key, identifier) return await self.unsubscribe(backend, key, identifier)
async def sub_many(self, backend_str: str, identifier: Any, async def sub_many(
keys: list, flags: Dict[str, Any] = None): self,
backend_str: str,
identifier: Any,
keys: list,
flags: Dict[str, Any] = None,
):
"""Subscribe to multiple channels (all in a single backend) """Subscribe to multiple channels (all in a single backend)
at a time. at a time.
@ -116,8 +126,7 @@ class EventDispatcher:
for key in keys: for key in keys:
await self.subscribe(backend_str, key, identifier, flags) await self.subscribe(backend_str, key, identifier, flags)
async def mass_sub(self, identifier: Any, async def mass_sub(self, identifier: Any, backends: List[tuple]):
backends: List[tuple]):
"""Mass subscribe to many backends at once.""" """Mass subscribe to many backends at once."""
for bcall in backends: for bcall in backends:
backend_str, keys = bcall[0], bcall[1] backend_str, keys = bcall[0], bcall[1]
@ -128,8 +137,13 @@ class EventDispatcher:
# we have flags # we have flags
flags = bcall[2] flags = bcall[2]
log.debug('subscribing {} to {} keys in backend {}, flags: {}', log.debug(
identifier, len(keys), backend_str, flags) "subscribing {} to {} keys in backend {}, flags: {}",
identifier,
len(keys),
backend_str,
flags,
)
await self.sub_many(backend_str, identifier, keys, flags) await self.sub_many(backend_str, identifier, keys, flags)
@ -145,17 +159,14 @@ class EventDispatcher:
key = backend.KEY_TYPE(key) key = backend.KEY_TYPE(key)
return await backend.dispatch(key, *args, **kwargs) return await backend.dispatch(key, *args, **kwargs)
async def dispatch_many(self, backend_str: str, async def dispatch_many(self, backend_str: str, keys: List[Any], *args, **kwargs):
keys: List[Any], *args, **kwargs):
"""Dispatch to multiple keys in a single backend.""" """Dispatch to multiple keys in a single backend."""
log.info('MULTI DISPATCH: {!r}, {} keys', log.info("MULTI DISPATCH: {!r}, {} keys", backend_str, len(keys))
backend_str, len(keys))
for key in keys: for key in keys:
await self.dispatch(backend_str, key, *args, **kwargs) await self.dispatch(backend_str, key, *args, **kwargs)
async def dispatch_filter(self, backend_str: str, async def dispatch_filter(self, backend_str: str, key: Any, func, *args):
key: Any, func, *args):
"""Dispatch to a backend that only accepts """Dispatch to a backend that only accepts
(event, data) arguments with an optional filter (event, data) arguments with an optional filter
function.""" function."""
@ -163,9 +174,9 @@ class EventDispatcher:
key = backend.KEY_TYPE(key) key = backend.KEY_TYPE(key)
return await backend.dispatch_filter(key, func, *args) return await backend.dispatch_filter(key, func, *args)
async def dispatch_many_filter_list(self, backend_str: str, async def dispatch_many_filter_list(
keys: List[Any], sess_list: List[str], self, backend_str: str, keys: List[Any], sess_list: List[str], *args
*args): ):
"""Make a "unique" dispatch given a list of session ids. """Make a "unique" dispatch given a list of session ids.
This only works for backends that have a dispatch_filter This only works for backends that have a dispatch_filter
@ -175,9 +186,8 @@ class EventDispatcher:
for key in keys: for key in keys:
sess_list.extend( sess_list.extend(
await self.dispatch_filter( await self.dispatch_filter(
backend_str, key, backend_str, key, lambda sess_id: sess_id not in sess_list, *args
lambda sess_id: sess_id not in sess_list, )
*args)
) )
return sess_list return sess_list
@ -197,12 +207,12 @@ class EventDispatcher:
async def dispatch_guild(self, guild_id, event, data): async def dispatch_guild(self, guild_id, event, data):
"""Backwards compatibility with old EventDispatcher.""" """Backwards compatibility with old EventDispatcher."""
return await self.dispatch('guild', guild_id, event, data) return await self.dispatch("guild", guild_id, event, data)
async def dispatch_user_guild(self, user_id, guild_id, event, data): async def dispatch_user_guild(self, user_id, guild_id, event, data):
"""Backwards compatibility with old EventDispatcher.""" """Backwards compatibility with old EventDispatcher."""
return await self.dispatch('member', (guild_id, user_id), event, data) return await self.dispatch("member", (guild_id, user_id), event, data)
async def dispatch_user(self, user_id, event, data): async def dispatch_user(self, user_id, event, data):
"""Backwards compatibility with old EventDispatcher.""" """Backwards compatibility with old EventDispatcher."""
return await self.dispatch('user', user_id, event, data) return await self.dispatch("user", user_id, event, data)

View File

@ -19,4 +19,4 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from .sanitizer import sanitize_embed from .sanitizer import sanitize_embed
__all__ = ['sanitize_embed'] __all__ = ["sanitize_embed"]

View File

@ -30,11 +30,7 @@ from litecord.embed.schemas import EmbedURL
log = Logger(__name__) log = Logger(__name__)
MEDIA_EXTENSIONS = ( MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm")
'png',
'jpg', 'jpeg',
'gif', 'webm'
)
async def insert_media_meta(url, config, session): async def insert_media_meta(url, config, session):
@ -45,18 +41,18 @@ async def insert_media_meta(url, config, session):
if meta is None: if meta is None:
return return
if not meta['image']: if not meta["image"]:
return return
return { return {
'type': 'image', "type": "image",
'url': url, "url": url,
'thumbnail': { "thumbnail": {
'width': meta['width'], "width": meta["width"],
'height': meta['height'], "height": meta["height"],
'url': url, "url": url,
'proxy_url': img_proxy_url "proxy_url": img_proxy_url,
} },
} }
@ -64,29 +60,32 @@ async def msg_update_embeds(payload, new_embeds, storage, dispatcher):
"""Update the message with the given embeds and dispatch a MESSAGE_UPDATE """Update the message with the given embeds and dispatch a MESSAGE_UPDATE
to users.""" to users."""
message_id = int(payload['id']) message_id = int(payload["id"])
channel_id = int(payload['channel_id']) channel_id = int(payload["channel_id"])
await storage.execute_with_json(""" await storage.execute_with_json(
"""
UPDATE messages UPDATE messages
SET embeds = $1 SET embeds = $1
WHERE messages.id = $2 WHERE messages.id = $2
""", new_embeds, message_id) """,
new_embeds,
message_id,
)
update_payload = { update_payload = {
'id': str(message_id), "id": str(message_id),
'channel_id': str(channel_id), "channel_id": str(channel_id),
'embeds': new_embeds, "embeds": new_embeds,
} }
if 'guild_id' in payload: if "guild_id" in payload:
update_payload['guild_id'] = payload['guild_id'] update_payload["guild_id"] = payload["guild_id"]
if 'flags' in payload: if "flags" in payload:
update_payload['flags'] = payload['flags'] update_payload["flags"] = payload["flags"]
await dispatcher.dispatch( await dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", update_payload)
'channel', channel_id, 'MESSAGE_UPDATE', update_payload)
def is_media_url(url) -> bool: def is_media_url(url) -> bool:
@ -98,7 +97,7 @@ def is_media_url(url) -> bool:
parsed = urllib.parse.urlparse(url) parsed = urllib.parse.urlparse(url)
path = Path(parsed.path) path = Path(parsed.path)
extension = path.suffix.lstrip('.') extension = path.suffix.lstrip(".")
return extension in MEDIA_EXTENSIONS return extension in MEDIA_EXTENSIONS
@ -109,20 +108,20 @@ async def insert_mp_embed(parsed, config, session):
return embed return embed
async def process_url_embed(config, storage, dispatcher, async def process_url_embed(
session, payload: dict, *, delay=0): config, storage, dispatcher, session, payload: dict, *, delay=0
):
"""Process URLs in a message and generate embeds based on that.""" """Process URLs in a message and generate embeds based on that."""
await asyncio.sleep(delay) await asyncio.sleep(delay)
message_id = int(payload['id']) message_id = int(payload["id"])
# if we already have embeds # if we already have embeds
# we shouldn't add our own. # we shouldn't add our own.
embeds = payload['embeds'] embeds = payload["embeds"]
if embeds: if embeds:
log.debug('url processor: ignoring existing embeds @ mid {}', log.debug("url processor: ignoring existing embeds @ mid {}", message_id)
message_id)
return return
# now, we have two types of embeds: # now, we have two types of embeds:
@ -130,7 +129,7 @@ async def process_url_embed(config, storage, dispatcher,
# - url embeds # - url embeds
# use regex to get URLs # use regex to get URLs
urls = re.findall(r'(https?://\S+)', payload['content']) urls = re.findall(r"(https?://\S+)", payload["content"])
urls = urls[:5] urls = urls[:5]
# from there, we need to parse each found url and check its path. # from there, we need to parse each found url and check its path.
@ -159,7 +158,6 @@ async def process_url_embed(config, storage, dispatcher,
if not new_embeds: if not new_embeds:
return return
log.debug('made {} embeds for mid {}', log.debug("made {} embeds for mid {}", len(new_embeds), message_id)
len(new_embeds), message_id)
await msg_update_embeds(payload, new_embeds, storage, dispatcher) await msg_update_embeds(payload, new_embeds, storage, dispatcher)

View File

@ -39,9 +39,7 @@ def sanitize_embed(embed: Embed) -> Embed:
This is non-complex sanitization as it doesn't This is non-complex sanitization as it doesn't
need the app object. need the app object.
""" """
return {**embed, **{ return {**embed, **{"type": "rich"}}
'type': 'rich'
}}
def path_exists(embed: Embed, components_in: Union[List[str], str]): def path_exists(embed: Embed, components_in: Union[List[str], str]):
@ -55,7 +53,7 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]):
# get the list of components given # get the list of components given
if isinstance(components_in, str): if isinstance(components_in, str):
components = components_in.split('.') components = components_in.split(".")
else: else:
components = list(components_in) components = list(components_in)
@ -77,7 +75,6 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]):
return False return False
def _mk_cfg_sess(config, session) -> tuple: def _mk_cfg_sess(config, session) -> tuple:
"""Return a tuple of (config, session).""" """Return a tuple of (config, session)."""
if config is None: if config is None:
@ -91,11 +88,11 @@ def _mk_cfg_sess(config, session) -> tuple:
def _md_base(config) -> Optional[tuple]: def _md_base(config) -> Optional[tuple]:
"""Return the protocol and base url for the mediaproxy.""" """Return the protocol and base url for the mediaproxy."""
md_base_url = config['MEDIA_PROXY'] md_base_url = config["MEDIA_PROXY"]
if md_base_url is None: if md_base_url is None:
return None return None
proto = 'https' if config['IS_SSL'] else 'http' proto = "https" if config["IS_SSL"] else "http"
return proto, md_base_url return proto, md_base_url
@ -111,7 +108,7 @@ def make_md_req_url(config, scope: str, url):
return url.url if isinstance(url, EmbedURL) else url return url.url if isinstance(url, EmbedURL) else url
proto, base_url = base proto, base_url = base
return f'{proto}://{base_url}/{scope}/{url.to_md_path}' return f"{proto}://{base_url}/{scope}/{url.to_md_path}"
def proxify(url, *, config=None) -> str: def proxify(url, *, config=None) -> str:
@ -122,11 +119,12 @@ def proxify(url, *, config=None) -> str:
if isinstance(url, str): if isinstance(url, str):
url = EmbedURL(url) url = EmbedURL(url)
return make_md_req_url(config, 'img', url) return make_md_req_url(config, "img", url)
async def _md_client_req(config, session, scope: str, async def _md_client_req(
url, *, ret_resp=False) -> Optional[Union[Tuple, Dict]]: config, session, scope: str, url, *, ret_resp=False
) -> Optional[Union[Tuple, Dict]]:
"""Makes a request to the mediaproxy. """Makes a request to the mediaproxy.
This has common code between all the main mediaproxy request functions This has common code between all the main mediaproxy request functions
@ -172,17 +170,13 @@ async def _md_client_req(config, session, scope: str,
return await resp.json() return await resp.json()
body = await resp.text() body = await resp.text()
log.warning('failed to call {!r}, {} {!r}', log.warning("failed to call {!r}, {} {!r}", request_url, resp.status, body)
request_url, resp.status, body)
return None return None
async def fetch_metadata(url, *, config=None, async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]:
session=None) -> Optional[Dict]:
"""Fetch metadata for a url (image width, mime, etc).""" """Fetch metadata for a url (image width, mime, etc)."""
return await _md_client_req( return await _md_client_req(config, session, "meta", url)
config, session, 'meta', url
)
async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]: async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
@ -191,9 +185,7 @@ async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
Returns a tuple containing the response object and the raw bytes given by Returns a tuple containing the response object and the raw bytes given by
the website. the website.
""" """
tup = await _md_client_req( tup = await _md_client_req(config, session, "img", url, ret_resp=True)
config, session, 'img', url, ret_resp=True
)
if not tup: if not tup:
return None return None
@ -207,9 +199,7 @@ async def fetch_embed(url, *, config=None, session=None) -> Dict[str, Any]:
Returns a discord embed object. Returns a discord embed object.
""" """
return await _md_client_req( return await _md_client_req(config, session, "embed", url)
config, session, 'embed', url
)
async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]: async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:
@ -229,22 +219,20 @@ async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:
embed = sanitize_embed(embed) embed = sanitize_embed(embed)
if path_exists(embed, 'footer.icon_url'): if path_exists(embed, "footer.icon_url"):
embed['footer']['proxy_icon_url'] = \ embed["footer"]["proxy_icon_url"] = proxify(embed["footer"]["icon_url"])
proxify(embed['footer']['icon_url'])
if path_exists(embed, 'author.icon_url'): if path_exists(embed, "author.icon_url"):
embed['author']['proxy_icon_url'] = \ embed["author"]["proxy_icon_url"] = proxify(embed["author"]["icon_url"])
proxify(embed['author']['icon_url'])
if path_exists(embed, 'image.url'): if path_exists(embed, "image.url"):
image_url = embed['image']['url'] image_url = embed["image"]["url"]
meta = await fetch_metadata(image_url) meta = await fetch_metadata(image_url)
embed['image']['proxy_url'] = proxify(image_url) embed["image"]["proxy_url"] = proxify(image_url)
if meta and meta['image']: if meta and meta["image"]:
embed['image']['width'] = meta['width'] embed["image"]["width"] = meta["width"]
embed['image']['height'] = meta['height'] embed["image"]["height"] = meta["height"]
return embed return embed

View File

@ -28,8 +28,8 @@ class EmbedURL:
def __init__(self, url: str): def __init__(self, url: str):
parsed = urllib.parse.urlparse(url) parsed = urllib.parse.urlparse(url)
if parsed.scheme not in ('http', 'https', 'attachment'): if parsed.scheme not in ("http", "https", "attachment"):
raise ValueError('Invalid URL scheme') raise ValueError("Invalid URL scheme")
self.scheme = parsed.scheme self.scheme = parsed.scheme
self.raw_url = url self.raw_url = url
@ -54,105 +54,61 @@ class EmbedURL:
def to_md_path(self) -> str: def to_md_path(self) -> str:
"""Convert the EmbedURL to a mediaproxy path (post img/meta).""" """Convert the EmbedURL to a mediaproxy path (post img/meta)."""
parsed = self.parsed parsed = self.parsed
return ( return f"{parsed.scheme}/{parsed.netloc}" f"{parsed.path}?{parsed.query}"
f'{parsed.scheme}/{parsed.netloc}'
f'{parsed.path}?{parsed.query}'
)
EMBED_FOOTER = { EMBED_FOOTER = {
'text': { "text": {"type": "string", "minlength": 1, "maxlength": 1024, "required": True},
'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': True}, "icon_url": {"coerce": EmbedURL, "required": False},
'icon_url': {
'coerce': EmbedURL, 'required': False,
},
# NOTE: proxy_icon_url set by us # NOTE: proxy_icon_url set by us
} }
EMBED_IMAGE = { EMBED_IMAGE = {
'url': {'coerce': EmbedURL, 'required': True}, "url": {"coerce": EmbedURL, "required": True},
# NOTE: proxy_url, width, height set by us # NOTE: proxy_url, width, height set by us
} }
EMBED_THUMBNAIL = EMBED_IMAGE EMBED_THUMBNAIL = EMBED_IMAGE
EMBED_AUTHOR = { EMBED_AUTHOR = {
'name': { "name": {"type": "string", "minlength": 1, "maxlength": 256, "required": False},
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': False "url": {"coerce": EmbedURL, "required": False},
}, "icon_url": {"coerce": EmbedURL, "required": False}
'url': {
'coerce': EmbedURL, 'required': False,
},
'icon_url': {
'coerce': EmbedURL, 'required': False,
}
# NOTE: proxy_icon_url set by us # NOTE: proxy_icon_url set by us
} }
EMBED_FIELD = { EMBED_FIELD = {
'name': { "name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True},
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True "value": {"type": "string", "minlength": 1, "maxlength": 1024, "required": True},
}, "inline": {"type": "boolean", "required": False, "default": True},
'value': {
'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': True
},
'inline': {
'type': 'boolean', 'required': False, 'default': True,
},
} }
EMBED_OBJECT = { EMBED_OBJECT = {
'title': { "title": {"type": "string", "minlength": 1, "maxlength": 256, "required": False},
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': False},
# NOTE: type set by us # NOTE: type set by us
'description': { "description": {
'type': 'string', 'minlength': 1, 'maxlength': 2048, 'required': False, "type": "string",
"minlength": 1,
"maxlength": 2048,
"required": False,
}, },
'url': { "url": {"coerce": EmbedURL, "required": False},
'coerce': EmbedURL, 'required': False, "timestamp": {
},
'timestamp': {
# TODO: an ISO 8601 type # TODO: an ISO 8601 type
# TODO: maybe replace the default in here with now().isoformat? # TODO: maybe replace the default in here with now().isoformat?
'type': 'string', 'required': False "type": "string",
"required": False,
}, },
"color": {"coerce": Color, "required": False},
'color': { "footer": {"type": "dict", "schema": EMBED_FOOTER, "required": False},
'coerce': Color, 'required': False "image": {"type": "dict", "schema": EMBED_IMAGE, "required": False},
}, "thumbnail": {"type": "dict", "schema": EMBED_THUMBNAIL, "required": False},
'footer': {
'type': 'dict',
'schema': EMBED_FOOTER,
'required': False,
},
'image': {
'type': 'dict',
'schema': EMBED_IMAGE,
'required': False,
},
'thumbnail': {
'type': 'dict',
'schema': EMBED_THUMBNAIL,
'required': False,
},
# NOTE: 'video' set by us # NOTE: 'video' set by us
# NOTE: 'provider' set by us # NOTE: 'provider' set by us
"author": {"type": "dict", "schema": EMBED_AUTHOR, "required": False},
'author': { "fields": {
'type': 'dict', "type": "list",
'schema': EMBED_AUTHOR, "schema": {"type": "dict", "schema": EMBED_FIELD},
'required': False, "required": False,
},
'fields': {
'type': 'list',
'schema': {'type': 'dict', 'schema': EMBED_FIELD},
'required': False,
}, },
} }

View File

@ -52,13 +52,14 @@ class Flags:
>>> i2.is_field_3 >>> i2.is_field_3
False False
""" """
def __init_subclass__(cls, **_kwargs): def __init_subclass__(cls, **_kwargs):
attrs = inspect.getmembers(cls, lambda x: not inspect.isroutine(x)) attrs = inspect.getmembers(cls, lambda x: not inspect.isroutine(x))
def _make_int(value): def _make_int(value):
res = Flags() res = Flags()
setattr(res, 'value', value) setattr(res, "value", value)
for attr, val in attrs: for attr, val in attrs:
# get only the ones that represent a field in the # get only the ones that represent a field in the
@ -69,7 +70,7 @@ class Flags:
has_attr = (value & val) == val has_attr = (value & val) == val
# set each attribute # set each attribute
setattr(res, f'is_{attr}', has_attr) setattr(res, f"is_{attr}", has_attr)
return res return res
@ -84,17 +85,16 @@ class ChannelType(EasyEnum):
GUILD_CATEGORY = 4 GUILD_CATEGORY = 4
GUILD_CHANS = (ChannelType.GUILD_TEXT, GUILD_CHANS = (
ChannelType.GUILD_VOICE, ChannelType.GUILD_TEXT,
ChannelType.GUILD_CATEGORY) ChannelType.GUILD_VOICE,
ChannelType.GUILD_CATEGORY,
VOICE_CHANNELS = (
ChannelType.DM, ChannelType.GUILD_VOICE,
ChannelType.GUILD_CATEGORY
) )
VOICE_CHANNELS = (ChannelType.DM, ChannelType.GUILD_VOICE, ChannelType.GUILD_CATEGORY)
class ActivityType(EasyEnum): class ActivityType(EasyEnum):
PLAYING = 0 PLAYING = 0
STREAMING = 1 STREAMING = 1
@ -120,7 +120,7 @@ SYS_MESSAGES = (
MessageType.CHANNEL_NAME_CHANGE, MessageType.CHANNEL_NAME_CHANGE,
MessageType.CHANNEL_ICON_CHANGE, MessageType.CHANNEL_ICON_CHANGE,
MessageType.CHANNEL_PINNED_MESSAGE, MessageType.CHANNEL_PINNED_MESSAGE,
MessageType.GUILD_MEMBER_JOIN MessageType.GUILD_MEMBER_JOIN,
) )
@ -137,6 +137,7 @@ class ActivityFlags(Flags):
Only related to rich presence. Only related to rich presence.
""" """
instance = 1 instance = 1
join = 2 join = 2
spectate = 4 spectate = 4
@ -150,6 +151,7 @@ class UserFlags(Flags):
Used by the client to show badges. Used by the client to show badges.
""" """
staff = 1 staff = 1
partner = 2 partner = 2
hypesquad = 4 hypesquad = 4
@ -166,6 +168,7 @@ class UserFlags(Flags):
class MessageFlags(Flags): class MessageFlags(Flags):
"""Message flags.""" """Message flags."""
none = 0 none = 0
crossposted = 1 << 0 crossposted = 1 << 0
@ -175,11 +178,12 @@ class MessageFlags(Flags):
class StatusType(EasyEnum): class StatusType(EasyEnum):
"""All statuses there can be in a presence.""" """All statuses there can be in a presence."""
ONLINE = 'online'
DND = 'dnd' ONLINE = "online"
IDLE = 'idle' DND = "dnd"
INVISIBLE = 'invisible' IDLE = "idle"
OFFLINE = 'offline' INVISIBLE = "invisible"
OFFLINE = "offline"
class ExplicitFilter(EasyEnum): class ExplicitFilter(EasyEnum):
@ -187,6 +191,7 @@ class ExplicitFilter(EasyEnum):
Also applies to guilds. Also applies to guilds.
""" """
EDGE = 0 EDGE = 0
FRIENDS = 1 FRIENDS = 1
SAFE = 2 SAFE = 2
@ -194,6 +199,7 @@ class ExplicitFilter(EasyEnum):
class VerificationLevel(IntEnum): class VerificationLevel(IntEnum):
"""Verification level for guilds.""" """Verification level for guilds."""
NONE = 0 NONE = 0
LOW = 1 LOW = 1
MEDIUM = 2 MEDIUM = 2
@ -205,6 +211,7 @@ class VerificationLevel(IntEnum):
class RelationshipType(EasyEnum): class RelationshipType(EasyEnum):
"""Relationship types between users.""" """Relationship types between users."""
FRIEND = 1 FRIEND = 1
BLOCK = 2 BLOCK = 2
INCOMING = 3 INCOMING = 3
@ -213,6 +220,7 @@ class RelationshipType(EasyEnum):
class MessageNotifications(EasyEnum): class MessageNotifications(EasyEnum):
"""Message notifications""" """Message notifications"""
ALL = 0 ALL = 0
MENTIONS = 1 MENTIONS = 1
NOTHING = 2 NOTHING = 2
@ -220,6 +228,7 @@ class MessageNotifications(EasyEnum):
class PremiumType: class PremiumType:
"""Premium (Nitro) type.""" """Premium (Nitro) type."""
TIER_1 = 1 TIER_1 = 1
TIER_2 = 2 TIER_2 = 2
NONE = None NONE = None
@ -227,12 +236,13 @@ class PremiumType:
class Feature(EasyEnum): class Feature(EasyEnum):
"""Guild features.""" """Guild features."""
invite_splash = 'INVITE_SPLASH'
vip = 'VIP_REGIONS' invite_splash = "INVITE_SPLASH"
vanity = 'VANITY_URL' vip = "VIP_REGIONS"
emoji = 'MORE_EMOJI' vanity = "VANITY_URL"
verified = 'VERIFIED' emoji = "MORE_EMOJI"
verified = "VERIFIED"
# unknown # unknown
commerce = 'COMMERCE' commerce = "COMMERCE"
news = 'NEWS' news = "NEWS"

View File

@ -18,60 +18,64 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
ERR_MSG_MAP = { ERR_MSG_MAP = {
10001: 'Unknown account', 10001: "Unknown account",
10002: 'Unknown application', 10002: "Unknown application",
10003: 'Unknown channel', 10003: "Unknown channel",
10004: 'Unknown guild', 10004: "Unknown guild",
10005: 'Unknown integration', 10005: "Unknown integration",
10006: 'Unknown invite', 10006: "Unknown invite",
10007: 'Unknown member', 10007: "Unknown member",
10008: 'Unknown message', 10008: "Unknown message",
10009: 'Unknown overwrite', 10009: "Unknown overwrite",
10010: 'Unknown provider', 10010: "Unknown provider",
10011: 'Unknown role', 10011: "Unknown role",
10012: 'Unknown token', 10012: "Unknown token",
10013: 'Unknown user', 10013: "Unknown user",
10014: 'Unknown Emoji', 10014: "Unknown Emoji",
10015: 'Unknown Webhook', 10015: "Unknown Webhook",
20001: 'Bots cannot use this endpoint', 20001: "Bots cannot use this endpoint",
20002: 'Only bots can use this endpoint', 20002: "Only bots can use this endpoint",
30001: 'Maximum number of guilds reached (100)', 30001: "Maximum number of guilds reached (100)",
30002: 'Maximum number of friends reached (1000)', 30002: "Maximum number of friends reached (1000)",
30003: 'Maximum number of pins reached (50)', 30003: "Maximum number of pins reached (50)",
30005: 'Maximum number of guild roles reached (250)', 30005: "Maximum number of guild roles reached (250)",
30010: 'Maximum number of reactions reached (20)', 30010: "Maximum number of reactions reached (20)",
30013: 'Maximum number of guild channels reached (500)', 30013: "Maximum number of guild channels reached (500)",
40001: 'Unauthorized', 40001: "Unauthorized",
50001: 'Missing access', 50001: "Missing access",
50002: 'Invalid account type', 50002: "Invalid account type",
50003: 'Cannot execute action on a DM channel', 50003: "Cannot execute action on a DM channel",
50004: 'Widget Disabled', 50004: "Widget Disabled",
50005: 'Cannot edit a message authored by another user', 50005: "Cannot edit a message authored by another user",
50006: 'Cannot send an empty message', 50006: "Cannot send an empty message",
50007: 'Cannot send messages to this user', 50007: "Cannot send messages to this user",
50008: 'Cannot send messages in a voice channel', 50008: "Cannot send messages in a voice channel",
50009: 'Channel verification level is too high', 50009: "Channel verification level is too high",
50010: 'OAuth2 application does not have a bot', 50010: "OAuth2 application does not have a bot",
50011: 'OAuth2 application limit reached', 50011: "OAuth2 application limit reached",
50012: 'Invalid OAuth state', 50012: "Invalid OAuth state",
50013: 'Missing permissions', 50013: "Missing permissions",
50014: 'Invalid authentication token', 50014: "Invalid authentication token",
50015: 'Note is too long', 50015: "Note is too long",
50016: ('Provided too few or too many messages to delete. Must provide at ' 50016: (
'least 2 and fewer than 100 messages to delete.'), "Provided too few or too many messages to delete. Must provide at "
50019: 'A message can only be pinned to the channel it was sent in', "least 2 and fewer than 100 messages to delete."
50020: 'Invite code is either invalid or taken.', ),
50021: 'Cannot execute action on a system message', 50019: "A message can only be pinned to the channel it was sent in",
50025: 'Invalid OAuth2 access token', 50020: "Invite code is either invalid or taken.",
50034: 'A message provided was too old to bulk delete', 50021: "Cannot execute action on a system message",
50035: 'Invalid Form Body', 50025: "Invalid OAuth2 access token",
50036: 'An invite was accepted to a guild the application\'s bot is not in', 50034: "A message provided was too old to bulk delete",
50041: 'Invalid API version', 50035: "Invalid Form Body",
90001: 'Reaction blocked', 50036: "An invite was accepted to a guild the application's bot is not in",
50041: "Invalid API version",
90001: "Reaction blocked",
} }
class LitecordError(Exception): class LitecordError(Exception):
"""Base class for litecord errors""" """Base class for litecord errors"""
status_code = 500 status_code = 500
def _get_err_msg(self, err_code: int) -> str: def _get_err_msg(self, err_code: int) -> str:
@ -91,7 +95,7 @@ class LitecordError(Exception):
return message return message
except IndexError: except IndexError:
return self._get_err_msg(getattr(self, 'error_code', None)) return self._get_err_msg(getattr(self, "error_code", None))
@property @property
def json(self): def json(self):
@ -143,7 +147,7 @@ class MissingPermissions(Forbidden):
class WebsocketClose(Exception): class WebsocketClose(Exception):
@property @property
def code(self): def code(self):
from_class = getattr(self, 'close_code', None) from_class = getattr(self, "close_code", None)
if from_class: if from_class:
return from_class return from_class
@ -152,7 +156,7 @@ class WebsocketClose(Exception):
@property @property
def reason(self): def reason(self):
from_class = getattr(self, 'close_code', None) from_class = getattr(self, "close_code", None)
if from_class: if from_class:
return self.args[0] return self.args[0]

View File

@ -25,8 +25,7 @@ from litecord.utils import LitecordJSONEncoder
def encode_json(payload) -> str: def encode_json(payload) -> str:
"""Encode a given payload to JSON.""" """Encode a given payload to JSON."""
return json.dumps(payload, separators=(',', ':'), return json.dumps(payload, separators=(",", ":"), cls=LitecordJSONEncoder)
cls=LitecordJSONEncoder)
def decode_json(data: str): def decode_json(data: str):
@ -71,6 +70,7 @@ def _etf_decode_dict(data):
return result return result
def decode_etf(data: bytes): def decode_etf(data: bytes):
"""Decode data in ETF to any.""" """Decode data in ETF to any."""
res = earl.unpack(data) res = earl.unpack(data)

View File

@ -24,37 +24,36 @@ from litecord.gateway.websocket import GatewayWebsocket
async def websocket_handler(app, ws, url): async def websocket_handler(app, ws, url):
"""Main websocket handler, checks query arguments when connecting to """Main websocket handler, checks query arguments when connecting to
the gateway and spawns a GatewayWebsocket instance for the connection.""" the gateway and spawns a GatewayWebsocket instance for the connection."""
args = urllib.parse.parse_qs( args = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
urllib.parse.urlparse(url).query
)
# pull a dict.get but in a really bad way. # pull a dict.get but in a really bad way.
try: try:
gw_version = args['v'][0] gw_version = args["v"][0]
except (KeyError, IndexError): except (KeyError, IndexError):
gw_version = '6' gw_version = "6"
try: try:
gw_encoding = args['encoding'][0] gw_encoding = args["encoding"][0]
except (KeyError, IndexError): except (KeyError, IndexError):
gw_encoding = 'json' gw_encoding = "json"
if gw_version not in ('6', '7'): if gw_version not in ("6", "7"):
return await ws.close(1000, 'Invalid gateway version') return await ws.close(1000, "Invalid gateway version")
if gw_encoding not in ('json', 'etf'): if gw_encoding not in ("json", "etf"):
return await ws.close(1000, 'Invalid gateway encoding') return await ws.close(1000, "Invalid gateway encoding")
try: try:
gw_compress = args['compress'][0] gw_compress = args["compress"][0]
except (KeyError, IndexError): except (KeyError, IndexError):
gw_compress = None gw_compress = None
if gw_compress and gw_compress not in ('zlib-stream', 'zstd-stream'): if gw_compress and gw_compress not in ("zlib-stream", "zstd-stream"):
return await ws.close(1000, 'Invalid gateway compress') return await ws.close(1000, "Invalid gateway compress")
gws = GatewayWebsocket( gws = GatewayWebsocket(
ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress) ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress
)
# this can be run with a single await since this whole coroutine # this can be run with a single await since this whole coroutine
# is already running in the background. # is already running in the background.

View File

@ -17,8 +17,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
class OP: class OP:
"""Gateway OP codes.""" """Gateway OP codes."""
DISPATCH = 0 DISPATCH = 0
HEARTBEAT = 1 HEARTBEAT = 1
IDENTIFY = 2 IDENTIFY = 2

View File

@ -32,6 +32,7 @@ class PayloadStore:
This will only store a maximum of MAX_STORE_SIZE, This will only store a maximum of MAX_STORE_SIZE,
dropping the older payloads when adding new ones. dropping the older payloads when adding new ones.
""" """
MAX_STORE_SIZE = 250 MAX_STORE_SIZE = 250
def __init__(self): def __init__(self):
@ -60,20 +61,20 @@ class GatewayState:
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.session_id = kwargs.get('session_id', gen_session_id()) self.session_id = kwargs.get("session_id", gen_session_id())
#: event sequence number #: event sequence number
self.seq = kwargs.get('seq', 0) self.seq = kwargs.get("seq", 0)
#: last seq sent by us, the backend #: last seq sent by us, the backend
self.last_seq = 0 self.last_seq = 0
#: shard information about the state, #: shard information about the state,
# its id and shard count # its id and shard count
self.shard = kwargs.get('shard', [0, 1]) self.shard = kwargs.get("shard", [0, 1])
self.user_id = kwargs.get('user_id') self.user_id = kwargs.get("user_id")
self.bot = kwargs.get('bot', False) self.bot = kwargs.get("bot", False)
#: set by the gateway connection #: set by the gateway connection
# on OP STATUS_UPDATE # on OP STATUS_UPDATE
@ -90,5 +91,4 @@ class GatewayState:
self.__dict__[key] = value self.__dict__[key] = value
def __repr__(self): def __repr__(self):
return (f'GatewayState<seq={self.seq} ' return f"GatewayState<seq={self.seq} " f"shard={self.shard} uid={self.user_id}>"
f'shard={self.shard} uid={self.user_id}>')

View File

@ -39,6 +39,7 @@ class ManagerClose(Exception):
class StateDictWrapper: class StateDictWrapper:
"""Wrap a mapping so that any kind of access to the mapping while the """Wrap a mapping so that any kind of access to the mapping while the
state manager is closed raises a ManagerClose error""" state manager is closed raises a ManagerClose error"""
def __init__(self, state_manager, mapping): def __init__(self, state_manager, mapping):
self.state_manager = state_manager self.state_manager = state_manager
self._map = mapping self._map = mapping
@ -98,7 +99,7 @@ class StateManager:
"""Insert a new state object.""" """Insert a new state object."""
user_states = self.states[state.user_id] user_states = self.states[state.user_id]
log.debug('inserting state: {!r}', state) log.debug("inserting state: {!r}", state)
user_states[state.session_id] = state user_states[state.session_id] = state
self.states_raw[state.session_id] = state self.states_raw[state.session_id] = state
@ -128,7 +129,7 @@ class StateManager:
pass pass
try: try:
log.debug('removing state: {!r}', state) log.debug("removing state: {!r}", state)
self.states[state.user_id].pop(state.session_id) self.states[state.user_id].pop(state.session_id)
except KeyError: except KeyError:
pass pass
@ -152,8 +153,7 @@ class StateManager:
"""Fetch all states tied to a single user.""" """Fetch all states tied to a single user."""
return list(self.states[user_id].values()) return list(self.states[user_id].values())
def guild_states(self, member_ids: List[int], def guild_states(self, member_ids: List[int], guild_id: int) -> List[GatewayState]:
guild_id: int) -> List[GatewayState]:
"""Fetch all possible states about members in a guild.""" """Fetch all possible states about members in a guild."""
states = [] states = []
@ -164,14 +164,14 @@ class StateManager:
# since server start, so we need to add a dummy state # since server start, so we need to add a dummy state
if not member_states: if not member_states:
dummy_state = GatewayState( dummy_state = GatewayState(
session_id='', session_id="",
user_id=member_id, user_id=member_id,
presence={ presence={
'afk': False, "afk": False,
'status': 'offline', "status": "offline",
'game': None, "game": None,
'since': 0 "since": 0,
} },
) )
states.append(dummy_state) states.append(dummy_state)
@ -187,9 +187,7 @@ class StateManager:
"""Send OP Reconnect to a single connection.""" """Send OP Reconnect to a single connection."""
websocket = state.ws websocket = state.ws
await websocket.send({ await websocket.send({"op": OP.RECONNECT})
'op': OP.RECONNECT
})
# wait 200ms # wait 200ms
# so that the client has time to process # so that the client has time to process
@ -198,12 +196,9 @@ class StateManager:
try: try:
# try to close the connection ourselves # try to close the connection ourselves
await websocket.ws.close( await websocket.ws.close(code=4000, reason="litecord shutting down")
code=4000,
reason='litecord shutting down'
)
except ConnectionClosed: except ConnectionClosed:
log.info('client {} already closed', state) log.info("client {} already closed", state)
def gen_close_tasks(self): def gen_close_tasks(self):
"""Generate the tasks that will order the clients """Generate the tasks that will order the clients
@ -222,11 +217,9 @@ class StateManager:
if not state.ws: if not state.ws:
continue continue
tasks.append( tasks.append(self.shutdown_single(state))
self.shutdown_single(state)
)
log.info('made {} shutdown tasks', len(tasks)) log.info("made {} shutdown tasks", len(tasks))
return tasks return tasks
def close(self): def close(self):

View File

@ -19,9 +19,11 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio import asyncio
class WebsocketFileHandler: class WebsocketFileHandler:
"""A handler around a websocket that wraps normal I/O calls into """A handler around a websocket that wraps normal I/O calls into
the websocket's respective asyncio calls via asyncio.ensure_future.""" the websocket's respective asyncio calls via asyncio.ensure_future."""
def __init__(self, ws): def __init__(self, ws):
self.ws = ws self.ws = ws

View File

@ -31,23 +31,20 @@ from logbook import Logger
from litecord.auth import raw_token_check from litecord.auth import raw_token_check
from litecord.enums import RelationshipType, ChannelType from litecord.enums import RelationshipType, ChannelType
from litecord.schemas import validate, GW_STATUS_UPDATE from litecord.schemas import validate, GW_STATUS_UPDATE
from litecord.utils import ( from litecord.utils import task_wrapper, yield_chunks, maybe_int
task_wrapper, yield_chunks, maybe_int
)
from litecord.permissions import get_permissions from litecord.permissions import get_permissions
from litecord.gateway.opcodes import OP from litecord.gateway.opcodes import OP
from litecord.gateway.state import GatewayState from litecord.gateway.state import GatewayState
from litecord.errors import ( from litecord.errors import WebsocketClose, Unauthorized, Forbidden, BadRequest
WebsocketClose, Unauthorized, Forbidden, BadRequest
)
from litecord.gateway.errors import ( from litecord.gateway.errors import (
DecodeError, UnknownOPCode, InvalidShard, ShardingRequired DecodeError,
) UnknownOPCode,
from litecord.gateway.encoding import ( InvalidShard,
encode_json, decode_json, encode_etf, decode_etf ShardingRequired,
) )
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
from litecord.gateway.utils import WebsocketFileHandler from litecord.gateway.utils import WebsocketFileHandler
@ -56,15 +53,22 @@ from litecord.storage import int_
log = Logger(__name__) log = Logger(__name__)
WebsocketProperties = collections.namedtuple( WebsocketProperties = collections.namedtuple(
'WebsocketProperties', 'v encoding compress zctx zsctx tasks' "WebsocketProperties", "v encoding compress zctx zsctx tasks"
) )
WebsocketObjects = collections.namedtuple( WebsocketObjects = collections.namedtuple(
'WebsocketObjects', ( "WebsocketObjects",
'db', 'state_manager', 'storage', (
'loop', 'dispatcher', 'presence', 'ratelimiter', "db",
'user_storage', 'voice' "state_manager",
) "storage",
"loop",
"dispatcher",
"presence",
"ratelimiter",
"user_storage",
"voice",
),
) )
@ -73,9 +77,15 @@ class GatewayWebsocket:
def __init__(self, ws, app, **kwargs): def __init__(self, ws, app, **kwargs):
self.ext = WebsocketObjects( self.ext = WebsocketObjects(
app.db, app.state_manager, app.storage, app.loop, app.db,
app.dispatcher, app.presence, app.ratelimiter, app.state_manager,
app.user_storage, app.voice app.storage,
app.loop,
app.dispatcher,
app.presence,
app.ratelimiter,
app.user_storage,
app.voice,
) )
self.storage = self.ext.storage self.storage = self.ext.storage
@ -84,15 +94,15 @@ class GatewayWebsocket:
self.ws = ws self.ws = ws
self.wsp = WebsocketProperties( self.wsp = WebsocketProperties(
kwargs.get('v'), kwargs.get("v"),
kwargs.get('encoding', 'json'), kwargs.get("encoding", "json"),
kwargs.get('compress', None), kwargs.get("compress", None),
zlib.compressobj(), zlib.compressobj(),
zstd.ZstdCompressor(), zstd.ZstdCompressor(),
{} {},
) )
log.debug('websocket properties: {!r}', self.wsp) log.debug("websocket properties: {!r}", self.wsp)
self.state = None self.state = None
@ -102,8 +112,8 @@ class GatewayWebsocket:
encoding = self.wsp.encoding encoding = self.wsp.encoding
encodings = { encodings = {
'json': (encode_json, decode_json), "json": (encode_json, decode_json),
'etf': (encode_etf, decode_etf), "etf": (encode_etf, decode_etf),
} }
self.encoder, self.decoder = encodings[encoding] self.encoder, self.decoder = encodings[encoding]
@ -111,16 +121,17 @@ class GatewayWebsocket:
async def _chunked_send(self, data: bytes, chunk_size: int): async def _chunked_send(self, data: bytes, chunk_size: int):
"""Split data in chunk_size-big chunks and send them """Split data in chunk_size-big chunks and send them
over the websocket.""" over the websocket."""
log.debug('zlib-stream: chunking {} bytes into {}-byte chunks', log.debug(
len(data), chunk_size) "zlib-stream: chunking {} bytes into {}-byte chunks", len(data), chunk_size
)
total_chunks = 0 total_chunks = 0
for chunk in yield_chunks(data, chunk_size): for chunk in yield_chunks(data, chunk_size):
total_chunks += 1 total_chunks += 1
log.debug('zlib-stream: chunk {}', total_chunks) log.debug("zlib-stream: chunk {}", total_chunks)
await self.ws.send(chunk) await self.ws.send(chunk)
log.debug('zlib-stream: sent {} chunks', total_chunks) log.debug("zlib-stream: sent {} chunks", total_chunks)
async def _zlib_stream_send(self, encoded): async def _zlib_stream_send(self, encoded):
"""Sending a single payload across multiple compressed """Sending a single payload across multiple compressed
@ -130,8 +141,12 @@ class GatewayWebsocket:
data1 = self.wsp.zctx.compress(encoded) data1 = self.wsp.zctx.compress(encoded)
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH) data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
log.debug('zlib-stream: length {} -> compressed ({} + {})', log.debug(
len(encoded), len(data1), len(data2)) "zlib-stream: length {} -> compressed ({} + {})",
len(encoded),
len(data1),
len(data2),
)
if not data1: if not data1:
# if data1 is nothing, that might cause problems # if data1 is nothing, that might cause problems
@ -139,8 +154,11 @@ class GatewayWebsocket:
data1 = bytes([data2[0]]) data1 = bytes([data2[0]])
data2 = data2[1:] data2 = data2[1:]
log.debug('zlib-stream: len(data1) == 0, remaking as ({} + {})', log.debug(
len(data1), len(data2)) "zlib-stream: len(data1) == 0, remaking as ({} + {})",
len(data1),
len(data2),
)
# NOTE: the old approach was ws.send(data1 + data2). # NOTE: the old approach was ws.send(data1 + data2).
# I changed this to a chunked send of data1 and data2 # I changed this to a chunked send of data1 and data2
@ -157,8 +175,7 @@ class GatewayWebsocket:
await self._chunked_send(data2, 1024) await self._chunked_send(data2, 1024)
async def _zstd_stream_send(self, encoded): async def _zstd_stream_send(self, encoded):
compressor = self.wsp.zsctx.stream_writer( compressor = self.wsp.zsctx.stream_writer(WebsocketFileHandler(self.ws))
WebsocketFileHandler(self.ws))
compressor.write(encoded) compressor.write(encoded)
compressor.flush(zstd.FLUSH_FRAME) compressor.flush(zstd.FLUSH_FRAME)
@ -172,21 +189,23 @@ class GatewayWebsocket:
encoded = self.encoder(payload) encoded = self.encoder(payload)
if len(encoded) < 2048: if len(encoded) < 2048:
log.debug('sending\n{}', pprint.pformat(payload)) log.debug("sending\n{}", pprint.pformat(payload))
else: else:
log.debug('sending {}', pprint.pformat(payload)) log.debug("sending {}", pprint.pformat(payload))
log.debug('sending op={} s={} t={} (too big)', log.debug(
payload.get('op'), "sending op={} s={} t={} (too big)",
payload.get('s'), payload.get("op"),
payload.get('t')) payload.get("s"),
payload.get("t"),
)
# treat encoded as bytes # treat encoded as bytes
if not isinstance(encoded, bytes): if not isinstance(encoded, bytes):
encoded = encoded.encode() encoded = encoded.encode()
if self.wsp.compress == 'zlib-stream': if self.wsp.compress == "zlib-stream":
await self._zlib_stream_send(encoded) await self._zlib_stream_send(encoded)
elif self.wsp.compress == 'zstd-stream': elif self.wsp.compress == "zstd-stream":
await self._zstd_stream_send(encoded) await self._zstd_stream_send(encoded)
elif self.state and self.state.compress and len(encoded) > 1024: elif self.state and self.state.compress and len(encoded) > 1024:
# TODO: should we only compress on >1KB packets? or maybe we # TODO: should we only compress on >1KB packets? or maybe we
@ -203,16 +222,10 @@ class GatewayWebsocket:
async def send_op(self, op_code: int, data: Any): async def send_op(self, op_code: int, data: Any):
"""Send a packet but just the OP code information is filled in.""" """Send a packet but just the OP code information is filled in."""
await self.send({ await self.send({"op": op_code, "d": data, "t": None, "s": None})
'op': op_code,
'd': data,
't': None,
's': None
})
def _check_ratelimit(self, key: str, ratelimit_key): def _check_ratelimit(self, key: str, ratelimit_key):
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}') ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}")
bucket = ratelimit.get_bucket(ratelimit_key) bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit() return bucket.update_rate_limit()
@ -221,19 +234,19 @@ class GatewayWebsocket:
# if the client heartbeats in time, # if the client heartbeats in time,
# this task will be cancelled. # this task will be cancelled.
await asyncio.sleep(interval / 1000) await asyncio.sleep(interval / 1000)
await self.ws.close(4000, 'Heartbeat expired') await self.ws.close(4000, "Heartbeat expired")
self._cleanup() self._cleanup()
def _hb_start(self, interval: int): def _hb_start(self, interval: int):
# always refresh the heartbeat task # always refresh the heartbeat task
# when possible # when possible
task = self.wsp.tasks.get('heartbeat') task = self.wsp.tasks.get("heartbeat")
if task: if task:
task.cancel() task.cancel()
self.wsp.tasks['heartbeat'] = self.ext.loop.create_task( self.wsp.tasks["heartbeat"] = self.ext.loop.create_task(
task_wrapper('hb wait', self._hb_wait(interval)) task_wrapper("hb wait", self._hb_wait(interval))
) )
async def _send_hello(self): async def _send_hello(self):
@ -241,12 +254,9 @@ class GatewayWebsocket:
# random heartbeat intervals # random heartbeat intervals
interval = randint(40, 46) * 1000 interval = randint(40, 46) * 1000
await self.send_op(OP.HELLO, { await self.send_op(
'heartbeat_interval': interval, OP.HELLO, {"heartbeat_interval": interval, "_trace": ["lesbian-server"]}
'_trace': [ )
'lesbian-server'
],
})
self._hb_start(interval) self._hb_start(interval)
@ -255,16 +265,15 @@ class GatewayWebsocket:
self.state.seq += 1 self.state.seq += 1
payload = { payload = {
'op': OP.DISPATCH, "op": OP.DISPATCH,
't': event.upper(), "t": event.upper(),
's': self.state.seq, "s": self.state.seq,
'd': data, "d": data,
} }
self.state.store[self.state.seq] = payload self.state.store[self.state.seq] = payload
log.debug('sending payload {!r} sid {}', log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id)
event.upper(), self.state.session_id)
await self.send(payload) await self.send(payload)
@ -274,16 +283,14 @@ class GatewayWebsocket:
guild_ids = await self._guild_ids() guild_ids = await self._guild_ids()
if self.state.bot: if self.state.bot:
return [{ return [{"id": row, "unavailable": True} for row in guild_ids]
'id': row,
'unavailable': True,
} for row in guild_ids]
return [ return [
{ {
**await self.storage.get_guild(guild_id, user_id), **await self.storage.get_guild(guild_id, user_id),
**await self.storage.get_guild_extra(guild_id, user_id, **await self.storage.get_guild_extra(
self.state.large) guild_id, user_id, self.state.large
),
} }
for guild_id in guild_ids for guild_id in guild_ids
] ]
@ -298,13 +305,13 @@ class GatewayWebsocket:
for guild_obj in unavailable_guilds: for guild_obj in unavailable_guilds:
# fetch full guild object including the 'large' field # fetch full guild object including the 'large' field
guild = await self.storage.get_guild_full( guild = await self.storage.get_guild_full(
int(guild_obj['id']), self.state.user_id, self.state.large int(guild_obj["id"]), self.state.user_id, self.state.large
) )
if guild is None: if guild is None:
continue continue
await self.dispatch('GUILD_CREATE', guild) await self.dispatch("GUILD_CREATE", guild)
async def _user_ready(self) -> dict: async def _user_ready(self) -> dict:
"""Fetch information about users in the READY packet. """Fetch information about users in the READY packet.
@ -317,28 +324,28 @@ class GatewayWebsocket:
relationships = await self.user_storage.get_relationships(user_id) relationships = await self.user_storage.get_relationships(user_id)
friend_ids = [int(r['user']['id']) for r in relationships friend_ids = [
if r['type'] == RelationshipType.FRIEND.value] int(r["user"]["id"])
for r in relationships
if r["type"] == RelationshipType.FRIEND.value
]
friend_presences = await self.ext.presence.friend_presences(friend_ids) friend_presences = await self.ext.presence.friend_presences(friend_ids)
settings = await self.user_storage.get_user_settings(user_id) settings = await self.user_storage.get_user_settings(user_id)
return { return {
'user_settings': settings, "user_settings": settings,
'notes': await self.user_storage.fetch_notes(user_id), "notes": await self.user_storage.fetch_notes(user_id),
'relationships': relationships, "relationships": relationships,
'presences': friend_presences, "presences": friend_presences,
'read_state': await self.user_storage.get_read_state(user_id), "read_state": await self.user_storage.get_read_state(user_id),
'user_guild_settings': await self.user_storage.get_guild_settings( "user_guild_settings": await self.user_storage.get_guild_settings(user_id),
user_id), "friend_suggestion_count": 0,
'friend_suggestion_count': 0,
# those are unused default values. # those are unused default values.
'connected_accounts': [], "connected_accounts": [],
'experiments': [], "experiments": [],
'guild_experiments': [], "guild_experiments": [],
'analytics_token': 'transbian', "analytics_token": "transbian",
} }
async def dispatch_ready(self): async def dispatch_ready(self):
@ -353,24 +360,21 @@ class GatewayWebsocket:
# user, fetch info # user, fetch info
user_ready = await self._user_ready() user_ready = await self._user_ready()
private_channels = ( private_channels = await self.user_storage.get_dms(
await self.user_storage.get_dms(user_id) + user_id
await self.user_storage.get_gdms(user_id) ) + await self.user_storage.get_gdms(user_id)
)
base_ready = { base_ready = {
'v': 6, "v": 6,
'user': user, "user": user,
"private_channels": private_channels,
'private_channels': private_channels, "guilds": guilds,
"session_id": self.state.session_id,
'guilds': guilds, "_trace": ["transbian"],
'session_id': self.state.session_id, "shard": self.state.shard,
'_trace': ['transbian'],
'shard': self.state.shard,
} }
await self.dispatch('READY', {**base_ready, **user_ready}) await self.dispatch("READY", {**base_ready, **user_ready})
# async dispatch of guilds # async dispatch of guilds
self.ext.loop.create_task(self._guild_dispatch(guilds)) self.ext.loop.create_task(self._guild_dispatch(guilds))
@ -380,33 +384,32 @@ class GatewayWebsocket:
""" """
current_shard, shard_count = shard current_shard, shard_count = shard
guilds = await self.ext.db.fetchval(""" guilds = await self.ext.db.fetchval(
"""
SELECT COUNT(*) SELECT COUNT(*)
FROM members FROM members
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """,
user_id,
)
recommended = max(int(guilds / 1200), 1) recommended = max(int(guilds / 1200), 1)
if shard_count < recommended: if shard_count < recommended:
raise ShardingRequired('Too many guilds for shard ' raise ShardingRequired("Too many guilds for shard " f"{current_shard}")
f'{current_shard}')
if guilds > 2500 and guilds / shard_count > 0.8: if guilds > 2500 and guilds / shard_count > 0.8:
raise ShardingRequired('Too many shards. ' raise ShardingRequired("Too many shards. " f"(g={guilds} sc={shard_count})")
f'(g={guilds} sc={shard_count})')
if current_shard > shard_count: if current_shard > shard_count:
raise InvalidShard('Shard count > Total shards') raise InvalidShard("Shard count > Total shards")
async def _guild_ids(self) -> list: async def _guild_ids(self) -> list:
"""Get a list of Guild IDs that are tied to this connection. """Get a list of Guild IDs that are tied to this connection.
The implementation is shard-aware. The implementation is shard-aware.
""" """
guild_ids = await self.user_storage.get_user_guilds( guild_ids = await self.user_storage.get_user_guilds(self.state.user_id)
self.state.user_id
)
shard_id = self.state.current_shard shard_id = self.state.current_shard
shard_count = self.state.shard_count shard_count = self.state.shard_count
@ -414,10 +417,7 @@ class GatewayWebsocket:
def _get_shard(guild_id): def _get_shard(guild_id):
return (guild_id >> 22) % shard_count return (guild_id >> 22) % shard_count
filtered = filter( filtered = filter(lambda guild_id: _get_shard(guild_id) == shard_id, guild_ids)
lambda guild_id: _get_shard(guild_id) == shard_id,
guild_ids
)
return list(filtered) return list(filtered)
@ -432,13 +432,17 @@ class GatewayWebsocket:
# subscribe the user to all dms they have OPENED. # subscribe the user to all dms they have OPENED.
dms = await self.user_storage.get_dms(user_id) dms = await self.user_storage.get_dms(user_id)
dm_ids = [int(dm['id']) for dm in dms] dm_ids = [int(dm["id"]) for dm in dms]
# fetch all group dms the user is a member of. # fetch all group dms the user is a member of.
gdm_ids = await self.user_storage.get_gdms_internal(user_id) gdm_ids = await self.user_storage.get_gdms_internal(user_id)
log.info('subscribing to {} guilds {} dms {} gdms', log.info(
len(guild_ids), len(dm_ids), len(gdm_ids)) "subscribing to {} guilds {} dms {} gdms",
len(guild_ids),
len(dm_ids),
len(gdm_ids),
)
# guild_subscriptions: # guild_subscriptions:
# enables dispatching of guild subscription events # enables dispatching of guild subscription events
@ -447,10 +451,13 @@ class GatewayWebsocket:
# we enable processing of guild_subscriptions by adding flags # we enable processing of guild_subscriptions by adding flags
# when subscribing to the given backend. those are optional. # when subscribing to the given backend. those are optional.
channels_to_sub = [ channels_to_sub = [
('guild', guild_ids, (
{'presence': guild_subscriptions, 'typing': guild_subscriptions}), "guild",
('channel', dm_ids), guild_ids,
('channel', gdm_ids), {"presence": guild_subscriptions, "typing": guild_subscriptions},
),
("channel", dm_ids),
("channel", gdm_ids),
] ]
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub) await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
@ -460,28 +467,26 @@ class GatewayWebsocket:
# (their friends will also subscribe back # (their friends will also subscribe back
# when they come online) # when they come online)
friend_ids = await self.user_storage.get_friend_ids(user_id) friend_ids = await self.user_storage.get_friend_ids(user_id)
log.info('subscribing to {} friends', len(friend_ids)) log.info("subscribing to {} friends", len(friend_ids))
await self.ext.dispatcher.sub_many('friend', user_id, friend_ids) await self.ext.dispatcher.sub_many("friend", user_id, friend_ids)
async def update_status(self, status: dict): async def update_status(self, status: dict):
"""Update the status of the current websocket connection.""" """Update the status of the current websocket connection."""
if not self.state: if not self.state:
return return
if self._check_ratelimit('presence', self.state.session_id): if self._check_ratelimit("presence", self.state.session_id):
# Presence Updates beyond the ratelimit # Presence Updates beyond the ratelimit
# are just silently dropped. # are just silently dropped.
return return
default_status = { default_status = {
'afk': False, "afk": False,
# TODO: fetch status from settings # TODO: fetch status from settings
'status': 'online', "status": "online",
'game': None, "game": None,
# TODO: this # TODO: this
'since': 0, "since": 0,
} }
status = {**(status or {}), **default_status} status = {**(status or {}), **default_status}
@ -489,39 +494,40 @@ class GatewayWebsocket:
try: try:
status = validate(status, GW_STATUS_UPDATE) status = validate(status, GW_STATUS_UPDATE)
except BadRequest as err: except BadRequest as err:
log.warning(f'Invalid status update: {err}') log.warning(f"Invalid status update: {err}")
return return
# try to extract game from activities # try to extract game from activities
# when game not provided # when game not provided
if not status.get('game'): if not status.get("game"):
try: try:
game = status['activities'][0] game = status["activities"][0]
except (KeyError, IndexError): except (KeyError, IndexError):
game = None game = None
else: else:
game = status['game'] game = status["game"]
# construct final status # construct final status
status = { status = {
'afk': status.get('afk', False), "afk": status.get("afk", False),
'status': status.get('status', 'online'), "status": status.get("status", "online"),
'game': game, "game": game,
'since': status.get('since', 0), "since": status.get("since", 0),
} }
self.state.presence = status self.state.presence = status
log.info(f'Updating presence status={status["status"]} for ' log.info(
f'uid={self.state.user_id}') f'Updating presence status={status["status"]} for '
await self.ext.presence.dispatch_pres(self.state.user_id, f"uid={self.state.user_id}"
self.state.presence) )
await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence)
async def handle_1(self, payload: Dict[str, Any]): async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets.""" """Handle OP 1 Heartbeat packets."""
# give the client 3 more seconds before we # give the client 3 more seconds before we
# close the websocket # close the websocket
self._hb_start((46 + 3) * 1000) self._hb_start((46 + 3) * 1000)
cliseq = payload.get('d') cliseq = payload.get("d")
if self.state: if self.state:
self.state.last_seq = cliseq self.state.last_seq = cliseq
@ -529,39 +535,42 @@ class GatewayWebsocket:
await self.send_op(OP.HEARTBEAT_ACK, None) await self.send_op(OP.HEARTBEAT_ACK, None)
async def _connect_ratelimit(self, user_id: int): async def _connect_ratelimit(self, user_id: int):
if self._check_ratelimit('connect', user_id): if self._check_ratelimit("connect", user_id):
await self.invalidate_session(False) await self.invalidate_session(False)
raise WebsocketClose(4009, 'You are being ratelimited.') raise WebsocketClose(4009, "You are being ratelimited.")
if self._check_ratelimit('session', user_id): if self._check_ratelimit("session", user_id):
await self.invalidate_session(False) await self.invalidate_session(False)
raise WebsocketClose(4004, 'Websocket Session Ratelimit reached.') raise WebsocketClose(4004, "Websocket Session Ratelimit reached.")
async def handle_2(self, payload: Dict[str, Any]): async def handle_2(self, payload: Dict[str, Any]):
"""Handle the OP 2 Identify packet.""" """Handle the OP 2 Identify packet."""
try: try:
data = payload['d'] data = payload["d"]
token = data['token'] token = data["token"]
except KeyError: except KeyError:
raise DecodeError('Invalid identify parameters') raise DecodeError("Invalid identify parameters")
compress = data.get('compress', False) compress = data.get("compress", False)
large = data.get('large_threshold', 50) large = data.get("large_threshold", 50)
shard = data.get('shard', [0, 1]) shard = data.get("shard", [0, 1])
presence = data.get('presence') presence = data.get("presence")
try: try:
user_id = await raw_token_check(token, self.ext.db) user_id = await raw_token_check(token, self.ext.db)
except (Unauthorized, Forbidden): except (Unauthorized, Forbidden):
raise WebsocketClose(4004, 'Authentication failed') raise WebsocketClose(4004, "Authentication failed")
await self._connect_ratelimit(user_id) await self._connect_ratelimit(user_id)
bot = await self.ext.db.fetchval(""" bot = await self.ext.db.fetchval(
"""
SELECT bot FROM users SELECT bot FROM users
WHERE id = $1 WHERE id = $1
""", user_id) """,
user_id,
)
await self._check_shards(shard, user_id) await self._check_shards(shard, user_id)
@ -574,19 +583,19 @@ class GatewayWebsocket:
shard=shard, shard=shard,
current_shard=shard[0], current_shard=shard[0],
shard_count=shard[1], shard_count=shard[1],
ws=self ws=self,
) )
# link the state to the user # link the state to the user
self.ext.state_manager.insert(self.state) self.ext.state_manager.insert(self.state)
await self.update_status(presence) await self.update_status(presence)
await self.subscribe_all(data.get('guild_subscriptions', True)) await self.subscribe_all(data.get("guild_subscriptions", True))
await self.dispatch_ready() await self.dispatch_ready()
async def handle_3(self, payload: Dict[str, Any]): async def handle_3(self, payload: Dict[str, Any]):
"""Handle OP 3 Status Update.""" """Handle OP 3 Status Update."""
presence = payload['d'] presence = payload["d"]
# update_status will take care of validation and # update_status will take care of validation and
# setting new presence to state # setting new presence to state
@ -597,27 +606,27 @@ class GatewayWebsocket:
user settings.""" user settings."""
try: try:
# TODO: fetch from settings if not provided # TODO: fetch from settings if not provided
self_deaf = bool(data['self_deaf']) self_deaf = bool(data["self_deaf"])
self_mute = bool(data['self_mute']) self_mute = bool(data["self_mute"])
except (KeyError, ValueError): except (KeyError, ValueError):
pass pass
return { return {
'deaf': state.deaf, "deaf": state.deaf,
'mute': state.mute, "mute": state.mute,
'self_deaf': self_deaf, "self_deaf": self_deaf,
'self_mute': self_mute, "self_mute": self_mute,
} }
async def handle_4(self, payload: Dict[str, Any]): async def handle_4(self, payload: Dict[str, Any]):
"""Handle OP 4 Voice Status Update.""" """Handle OP 4 Voice Status Update."""
data = payload['d'] data = payload["d"]
if not self.state: if not self.state:
return return
channel_id = int_(data.get('channel_id')) channel_id = int_(data.get("channel_id"))
guild_id = int_(data.get('guild_id')) guild_id = int_(data.get("guild_id"))
# if its null and null, disconnect the user from any voice # if its null and null, disconnect the user from any voice
# TODO: maybe just leave from DMs? idk... # TODO: maybe just leave from DMs? idk...
@ -630,9 +639,7 @@ class GatewayWebsocket:
return await self.ext.voice.leave(guild_id, self.state.user_id) return await self.ext.voice.leave(guild_id, self.state.user_id)
# fetch an existing state given user and guild OR user and channel # fetch an existing state given user and guild OR user and channel
chan_type = ChannelType( chan_type = ChannelType(await self.storage.get_chan_type(channel_id))
await self.storage.get_chan_type(channel_id)
)
state_id2 = channel_id state_id2 = channel_id
@ -704,39 +711,38 @@ class GatewayWebsocket:
# ignore unknown seqs # ignore unknown seqs
continue continue
payload_t = payload.get('t') payload_t = payload.get("t")
# presence resumption happens # presence resumption happens
# on a separate event, PRESENCE_REPLACE. # on a separate event, PRESENCE_REPLACE.
if payload_t == 'PRESENCE_UPDATE': if payload_t == "PRESENCE_UPDATE":
presences.append(payload.get('d')) presences.append(payload.get("d"))
continue continue
await self.send(payload) await self.send(payload)
except Exception: except Exception:
log.exception('error while resuming') log.exception("error while resuming")
await self.invalidate_session(False) await self.invalidate_session(False)
return return
if presences: if presences:
await self.dispatch('PRESENCE_REPLACE', presences) await self.dispatch("PRESENCE_REPLACE", presences)
await self.dispatch('RESUMED', {}) await self.dispatch("RESUMED", {})
async def handle_6(self, payload: Dict[str, Any]): async def handle_6(self, payload: Dict[str, Any]):
"""Handle OP 6 Resume.""" """Handle OP 6 Resume."""
data = payload['d'] data = payload["d"]
try: try:
token, sess_id, seq = data['token'], \ token, sess_id, seq = data["token"], data["session_id"], data["seq"]
data['session_id'], data['seq']
except KeyError: except KeyError:
raise DecodeError('Invalid resume payload') raise DecodeError("Invalid resume payload")
try: try:
user_id = await raw_token_check(token, self.ext.db) user_id = await raw_token_check(token, self.ext.db)
except (Unauthorized, Forbidden): except (Unauthorized, Forbidden):
raise WebsocketClose(4004, 'Invalid token') raise WebsocketClose(4004, "Invalid token")
try: try:
state = self.ext.state_manager.fetch(user_id, sess_id) state = self.ext.state_manager.fetch(user_id, sess_id)
@ -744,11 +750,11 @@ class GatewayWebsocket:
return await self.invalidate_session(False) return await self.invalidate_session(False)
if seq > state.seq: if seq > state.seq:
raise WebsocketClose(4007, 'Invalid seq') raise WebsocketClose(4007, "Invalid seq")
# check if a websocket isnt on that state already # check if a websocket isnt on that state already
if state.ws is not None: if state.ws is not None:
log.info('Resuming failed, websocket already connected') log.info("Resuming failed, websocket already connected")
return await self.invalidate_session(False) return await self.invalidate_session(False)
# relink this connection # relink this connection
@ -757,8 +763,9 @@ class GatewayWebsocket:
await self._resume(range(seq, state.seq)) await self._resume(range(seq, state.seq))
async def _req_guild_members(self, guild_id, user_ids: List[int], async def _req_guild_members(
query: str, limit: int): self, guild_id, user_ids: List[int], query: str, limit: int
):
try: try:
guild_id = int(guild_id) guild_id = int(guild_id)
except (TypeError, ValueError): except (TypeError, ValueError):
@ -778,32 +785,32 @@ class GatewayWebsocket:
# ASSUMPTION: requesting user_ids means we don't do query. # ASSUMPTION: requesting user_ids means we don't do query.
if user_ids: if user_ids:
members = await self.storage.get_member_multi(guild_id, user_ids) members = await self.storage.get_member_multi(guild_id, user_ids)
mids = [m['user']['id'] for m in members] mids = [m["user"]["id"] for m in members]
not_found = [uid for uid in user_ids if uid not in mids] not_found = [uid for uid in user_ids if uid not in mids]
await self.dispatch('GUILD_MEMBERS_CHUNK', { await self.dispatch(
'guild_id': str(guild_id), "GUILD_MEMBERS_CHUNK",
'members': members, {"guild_id": str(guild_id), "members": members, "not_found": not_found},
'not_found': not_found, )
})
return return
# do the search # do the search
result = await self.storage.query_members(guild_id, query, limit) result = await self.storage.query_members(guild_id, query, limit)
await self.dispatch('GUILD_MEMBERS_CHUNK', { await self.dispatch(
'guild_id': str(guild_id), "GUILD_MEMBERS_CHUNK", {"guild_id": str(guild_id), "members": result}
'members': result )
})
async def handle_8(self, payload: Dict): async def handle_8(self, payload: Dict):
"""Handle OP 8 Request Guild Members.""" """Handle OP 8 Request Guild Members."""
data = payload['d'] data = payload["d"]
gids = data['guild_id'] gids = data["guild_id"]
uids, query, limit = data.get('user_ids', []), \ uids, query, limit = (
data.get('query', ''), \ data.get("user_ids", []),
data.get('limit', 0) data.get("query", ""),
data.get("limit", 0),
)
if isinstance(gids, str): if isinstance(gids, str):
await self._req_guild_members(gids, uids, query, limit) await self._req_guild_members(gids, uids, query, limit)
@ -820,23 +827,21 @@ class GatewayWebsocket:
GUILD_SYNC event with that info. GUILD_SYNC event with that info.
""" """
members = await self.storage.get_member_data(guild_id) members = await self.storage.get_member_data(guild_id)
member_ids = [int(m['user']['id']) for m in members] member_ids = [int(m["user"]["id"]) for m in members]
log.debug(f'Syncing guild {guild_id} with {len(member_ids)} members') log.debug(f"Syncing guild {guild_id} with {len(member_ids)} members")
presences = await self.presence.guild_presences(member_ids, guild_id) presences = await self.presence.guild_presences(member_ids, guild_id)
await self.dispatch('GUILD_SYNC', { await self.dispatch(
'id': str(guild_id), "GUILD_SYNC",
'presences': presences, {"id": str(guild_id), "presences": presences, "members": members},
'members': members, )
})
async def handle_12(self, payload: Dict[str, Any]): async def handle_12(self, payload: Dict[str, Any]):
"""Handle OP 12 Guild Sync.""" """Handle OP 12 Guild Sync."""
data = payload['d'] data = payload["d"]
gids = await self.user_storage.get_user_guilds( gids = await self.user_storage.get_user_guilds(self.state.user_id)
self.state.user_id)
for guild_id in data: for guild_id in data:
try: try:
@ -931,35 +936,33 @@ class GatewayWebsocket:
] ]
} }
""" """
data = payload['d'] data = payload["d"]
gids = await self.user_storage.get_user_guilds(self.state.user_id) gids = await self.user_storage.get_user_guilds(self.state.user_id)
guild_id = int(data['guild_id']) guild_id = int(data["guild_id"])
# make sure to not extract info you shouldn't get # make sure to not extract info you shouldn't get
if guild_id not in gids: if guild_id not in gids:
return return
log.debug('lazy request: members: {}', log.debug("lazy request: members: {}", data.get("members", []))
data.get('members', []))
# make shard query # make shard query
lazy_guilds = self.ext.dispatcher.backends['lazy_guild'] lazy_guilds = self.ext.dispatcher.backends["lazy_guild"]
for chan_id, ranges in data.get('channels', {}).items(): for chan_id, ranges in data.get("channels", {}).items():
chan_id = int(chan_id) chan_id = int(chan_id)
member_list = await lazy_guilds.get_gml(chan_id) member_list = await lazy_guilds.get_gml(chan_id)
perms = await get_permissions( perms = await get_permissions(
self.state.user_id, chan_id, storage=self.storage) self.state.user_id, chan_id, storage=self.storage
)
if not perms.bits.read_messages: if not perms.bits.read_messages:
# ignore requests to unknown channels # ignore requests to unknown channels
return return
await member_list.shard_query( await member_list.shard_query(self.state.session_id, ranges)
self.state.session_id, ranges
)
async def _handle_23(self, payload): async def _handle_23(self, payload):
# TODO reverse-engineer opcode 23, sent by client # TODO reverse-engineer opcode 23, sent by client
@ -968,21 +971,21 @@ class GatewayWebsocket:
async def _process_message(self, payload): async def _process_message(self, payload):
"""Process a single message coming in from the client.""" """Process a single message coming in from the client."""
try: try:
op_code = payload['op'] op_code = payload["op"]
except KeyError: except KeyError:
raise UnknownOPCode('No OP code') raise UnknownOPCode("No OP code")
try: try:
handler = getattr(self, f'handle_{op_code}') handler = getattr(self, f"handle_{op_code}")
except AttributeError: except AttributeError:
log.warning('Payload with bad op: {}', pprint.pformat(payload)) log.warning("Payload with bad op: {}", pprint.pformat(payload))
raise UnknownOPCode(f'Bad OP code: {op_code}') raise UnknownOPCode(f"Bad OP code: {op_code}")
await handler(payload) await handler(payload)
async def _msg_ratelimit(self): async def _msg_ratelimit(self):
if self._check_ratelimit('messages', self.state.session_id): if self._check_ratelimit("messages", self.state.session_id):
raise WebsocketClose(4008, 'You are being ratelimited.') raise WebsocketClose(4008, "You are being ratelimited.")
async def _listen_messages(self): async def _listen_messages(self):
"""Listen for messages coming in from the websocket.""" """Listen for messages coming in from the websocket."""
@ -990,15 +993,15 @@ class GatewayWebsocket:
# close anyone trying to login while the # close anyone trying to login while the
# server is shutting down # server is shutting down
if self.ext.state_manager.closed: if self.ext.state_manager.closed:
raise WebsocketClose(4000, 'state manager closed') raise WebsocketClose(4000, "state manager closed")
if not self.ext.state_manager.accept_new: if not self.ext.state_manager.accept_new:
raise WebsocketClose(4000, 'state manager closed for new') raise WebsocketClose(4000, "state manager closed for new")
while True: while True:
message = await self.ws.recv() message = await self.ws.recv()
if len(message) > 4096: if len(message) > 4096:
raise DecodeError('Payload length exceeded') raise DecodeError("Payload length exceeded")
if self.state: if self.state:
await self._msg_ratelimit() await self._msg_ratelimit()
@ -1033,17 +1036,9 @@ class GatewayWebsocket:
# there arent any other states with websocket # there arent any other states with websocket
if not with_ws: if not with_ws:
offline = { offline = {"afk": False, "status": "offline", "game": None, "since": 0}
'afk': False,
'status': 'offline',
'game': None,
'since': 0,
}
await self.ext.presence.dispatch_pres( await self.ext.presence.dispatch_pres(user_id, offline)
user_id,
offline
)
async def run(self): async def run(self):
"""Wrap :meth:`listen_messages` inside """Wrap :meth:`listen_messages` inside
@ -1052,12 +1047,12 @@ class GatewayWebsocket:
await self._send_hello() await self._send_hello()
await self._listen_messages() await self._listen_messages()
except websockets.exceptions.ConnectionClosed as err: except websockets.exceptions.ConnectionClosed as err:
log.warning('conn close, state={}, err={}', self.state, err) log.warning("conn close, state={}, err={}", self.state, err)
except WebsocketClose as err: except WebsocketClose as err:
log.warning('ws close, state={} err={}', self.state, err) log.warning("ws close, state={} err={}", self.state, err)
await self.ws.close(code=err.code, reason=err.reason) await self.ws.close(code=err.code, reason=err.reason)
except Exception as err: except Exception as err:
log.exception('An exception has occoured. state={}', self.state) log.exception("An exception has occoured. state={}", self.state)
await self.ws.close(code=4000, reason=repr(err)) await self.ws.close(code=4000, reason=repr(err))
finally: finally:
user_id = self.state.user_id if self.state else None user_id = self.state.user_id if self.state else None

View File

@ -17,19 +17,21 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
class GuildMemoryStore: class GuildMemoryStore:
"""Store in-memory properties about guilds. """Store in-memory properties about guilds.
I could have just used Redis... probably too overkill to add I could have just used Redis... probably too overkill to add
aioredis to the already long depedency list, plus, I don't need aioredis to the already long depedency list, plus, I don't need
""" """
def __init__(self): def __init__(self):
self._store = {} self._store = {}
def get(self, guild_id: int, attribute: str, default=None): def get(self, guild_id: int, attribute: str, default=None):
"""get a key""" """get a key"""
return self._store.get(f'{guild_id}:{attribute}', default) return self._store.get(f"{guild_id}:{attribute}", default)
def set(self, guild_id: int, attribute: str, value): def set(self, guild_id: int, attribute: str, value):
"""set a key""" """set a key"""
self._store[f'{guild_id}:{attribute}'] = value self._store[f"{guild_id}:{attribute}"] = value

View File

@ -33,47 +33,42 @@ from logbook import Logger
from PIL import Image from PIL import Image
IMAGE_FOLDER = Path('./images') IMAGE_FOLDER = Path("./images")
log = Logger(__name__) log = Logger(__name__)
EXTENSIONS = { EXTENSIONS = {"image/jpeg": "jpeg", "image/webp": "webp"}
'image/jpeg': 'jpeg',
'image/webp': 'webp'
}
MIMES = { MIMES = {
'jpg': 'image/jpeg', "jpg": "image/jpeg",
'jpe': 'image/jpeg', "jpe": "image/jpeg",
'jpeg': 'image/jpeg', "jpeg": "image/jpeg",
'webp': 'image/webp', "webp": "image/webp",
} }
STATIC_IMAGE_MIMES = [ STATIC_IMAGE_MIMES = ["image/png", "image/jpeg", "image/webp"]
'image/png',
'image/jpeg',
'image/webp'
]
def get_ext(mime: str) -> str: def get_ext(mime: str) -> str:
if mime in EXTENSIONS: if mime in EXTENSIONS:
return EXTENSIONS[mime] return EXTENSIONS[mime]
extensions = mimetypes.guess_all_extensions(mime) extensions = mimetypes.guess_all_extensions(mime)
return extensions[0].strip('.') return extensions[0].strip(".")
def get_mime(ext: str): def get_mime(ext: str):
if ext in MIMES: if ext in MIMES:
return MIMES[ext] return MIMES[ext]
return mimetypes.types_map[f'.{ext}'] return mimetypes.types_map[f".{ext}"]
@dataclass @dataclass
class Icon: class Icon:
"""Main icon class""" """Main icon class"""
key: Optional[str] key: Optional[str]
icon_hash: Optional[str] icon_hash: Optional[str]
mime: Optional[str] mime: Optional[str]
@ -85,7 +80,7 @@ class Icon:
return None return None
ext = get_ext(self.mime) ext = get_ext(self.mime)
return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}') return str(IMAGE_FOLDER / f"{self.key}_{self.icon_hash}.{ext}")
@property @property
def as_pathlib(self) -> Optional[Path]: def as_pathlib(self) -> Optional[Path]:
@ -106,13 +101,14 @@ class Icon:
class ImageError(Exception): class ImageError(Exception):
"""Image error class.""" """Image error class."""
pass pass
def to_raw(data_type: str, data: str) -> Optional[bytes]: def to_raw(data_type: str, data: str) -> Optional[bytes]:
"""Given a data type in the data URI and data, """Given a data type in the data URI and data,
give the raw bytes being encoded.""" give the raw bytes being encoded."""
if data_type == 'base64': if data_type == "base64":
return base64.b64decode(data) return base64.b64decode(data)
return None return None
@ -136,7 +132,7 @@ def _calculate_hash(fhandler) -> str:
""" """
hash_obj = sha256() hash_obj = sha256()
for chunk in iter(lambda: fhandler.read(4096), b''): for chunk in iter(lambda: fhandler.read(4096), b""):
hash_obj.update(chunk) hash_obj.update(chunk)
# so that we can reuse the same handler # so that we can reuse the same handler
@ -162,39 +158,36 @@ async def calculate_hash(fhandle, loop=None) -> str:
def parse_data_uri(string) -> tuple: def parse_data_uri(string) -> tuple:
"""Extract image data.""" """Extract image data."""
try: try:
header, headered_data = string.split(';') header, headered_data = string.split(";")
_, given_mime = header.split(':') _, given_mime = header.split(":")
data_type, data = headered_data.split(',') data_type, data = headered_data.split(",")
raw_data = to_raw(data_type, data) raw_data = to_raw(data_type, data)
if raw_data is None: if raw_data is None:
raise ImageError('Unknown data header') raise ImageError("Unknown data header")
return given_mime, raw_data return given_mime, raw_data
except ValueError: except ValueError:
raise ImageError('data URI invalid syntax') raise ImageError("data URI invalid syntax")
def _gen_update_sql(scope: str) -> str: def _gen_update_sql(scope: str) -> str:
# match a scope to (table, field) # match a scope to (table, field)
field = { field = {
'user': 'avatar', "user": "avatar",
'guild': 'icon', "guild": "icon",
'splash': 'splash', "splash": "splash",
'banner': 'banner', "banner": "banner",
"channel-icons": "icon",
'channel-icons': 'icon',
}[scope] }[scope]
table = { table = {
'user': 'users', "user": "users",
"guild": "guilds",
'guild': 'guilds', "splash": "guilds",
'splash': 'guilds', "banner": "guilds",
'banner': 'guilds', "channel-icons": "group_dm_channels",
'channel-icons': 'group_dm_channels'
}[scope] }[scope]
return f""" return f"""
@ -204,10 +197,10 @@ def _gen_update_sql(scope: str) -> str:
def _invalid(kwargs: dict) -> Optional[Icon]: def _invalid(kwargs: dict) -> Optional[Icon]:
"""Send an invalid value.""" """Send an invalid value."""
if not kwargs.get('always_icon', False): if not kwargs.get("always_icon", False):
return None return None
return Icon(None, None, '') return Icon(None, None, "")
def try_unlink(path: Union[Path, str]): def try_unlink(path: Union[Path, str]):
@ -225,18 +218,17 @@ def try_unlink(path: Union[Path, str]):
async def resize_gif(raw_data: bytes, target: tuple) -> tuple: async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
"""Resize a GIF image.""" """Resize a GIF image."""
# generate a temporary file to call gifsticle to and from. # generate a temporary file to call gifsticle to and from.
input_fd, input_path = tempfile.mkstemp(suffix='.gif') input_fd, input_path = tempfile.mkstemp(suffix=".gif")
_, output_path = tempfile.mkstemp(suffix='.gif') _, output_path = tempfile.mkstemp(suffix=".gif")
input_handler = os.fdopen(input_fd, 'wb') input_handler = os.fdopen(input_fd, "wb")
# make sure its valid image data # make sure its valid image data
data_fd = BytesIO(raw_data) data_fd = BytesIO(raw_data)
image = Image.open(data_fd) image = Image.open(data_fd)
image.close() image.close()
log.info('resizing a GIF from {} to {}', log.info("resizing a GIF from {} to {}", image.size, target)
image.size, target)
# insert image info on input_handler # insert image info on input_handler
# close it to make it ready for consumption by gifsicle # close it to make it ready for consumption by gifsicle
@ -244,12 +236,11 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
input_handler.close() input_handler.close()
# call gifsicle under subprocess # call gifsicle under subprocess
log.debug('input: {}', input_path) log.debug("input: {}", input_path)
log.debug('output: {}', output_path) log.debug("output: {}", output_path)
process = await asyncio.create_subprocess_shell( process = await asyncio.create_subprocess_shell(
f'gifsicle --resize {target[0]}x{target[1]} ' f"gifsicle --resize {target[0]}x{target[1]} " f"{input_path} > {output_path}",
f'{input_path} > {output_path}',
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
) )
@ -257,11 +248,11 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
# run it, etc. # run it, etc.
out, err = await process.communicate() out, err = await process.communicate()
log.debug('out + err from gifsicle: {}', out + err) log.debug("out + err from gifsicle: {}", out + err)
# write over an empty data_fd # write over an empty data_fd
data_fd = BytesIO() data_fd = BytesIO()
output_handler = open(output_path, 'rb') output_handler = open(output_path, "rb")
data_fd.write(output_handler.read()) data_fd.write(output_handler.read())
# close unused handlers # close unused handlers
@ -283,40 +274,40 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
class IconManager: class IconManager:
"""Main icon manager.""" """Main icon manager."""
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
self.storage = app.storage self.storage = app.storage
async def _convert_ext(self, icon: Icon, target: str): async def _convert_ext(self, icon: Icon, target: str):
target = 'jpeg' if target == 'jpg' else target target = "jpeg" if target == "jpg" else target
target_mime = get_mime(target) target_mime = get_mime(target)
log.info('converting from {} to {}', icon.mime, target_mime) log.info("converting from {} to {}", icon.mime, target_mime)
target_path = IMAGE_FOLDER / f'{icon.key}_{icon.icon_hash}.{target}' target_path = IMAGE_FOLDER / f"{icon.key}_{icon.icon_hash}.{target}"
if target_path.exists(): if target_path.exists():
return Icon(icon.key, icon.icon_hash, target_mime) return Icon(icon.key, icon.icon_hash, target_mime)
image = Image.open(icon.as_path) image = Image.open(icon.as_path)
target_fd = target_path.open('wb') target_fd = target_path.open("wb")
if target == 'jpeg': if target == "jpeg":
image = image.convert('RGB') image = image.convert("RGB")
image.save(target_fd, format=target) image.save(target_fd, format=target)
target_fd.close() target_fd.close()
return Icon(icon.key, icon.icon_hash, target_mime) return Icon(icon.key, icon.icon_hash, target_mime)
async def generic_get(self, scope, key, icon_hash, async def generic_get(self, scope, key, icon_hash, **kwargs) -> Optional[Icon]:
**kwargs) -> Optional[Icon]:
"""Get any icon.""" """Get any icon."""
log.debug('GET {} {} {}', scope, key, icon_hash) log.debug("GET {} {} {}", scope, key, icon_hash)
key = str(key) key = str(key)
hash_query = 'AND hash = $3' if icon_hash else '' hash_query = "AND hash = $3" if icon_hash else ""
# hacky solution to only add icon_hash # hacky solution to only add icon_hash
# when needed. # when needed.
@ -325,18 +316,21 @@ class IconManager:
if icon_hash: if icon_hash:
args.append(icon_hash) args.append(icon_hash)
icon_row = await self.storage.db.fetchrow(f""" icon_row = await self.storage.db.fetchrow(
f"""
SELECT key, hash, mime SELECT key, hash, mime
FROM icons FROM icons
WHERE scope = $1 WHERE scope = $1
AND key = $2 AND key = $2
{hash_query} {hash_query}
""", *args) """,
*args,
)
if not icon_row: if not icon_row:
return None return None
icon = Icon(icon_row['key'], icon_row['hash'], icon_row['mime']) icon = Icon(icon_row["key"], icon_row["hash"], icon_row["mime"])
# ensure we aren't messing with NULLs everywhere. # ensure we aren't messing with NULLs everywhere.
if icon.as_pathlib is None: if icon.as_pathlib is None:
@ -349,18 +343,16 @@ class IconManager:
if icon.extension is None: if icon.extension is None:
return None return None
if 'ext' in kwargs and kwargs['ext'] != icon.extension: if "ext" in kwargs and kwargs["ext"] != icon.extension:
return await self._convert_ext(icon, kwargs['ext']) return await self._convert_ext(icon, kwargs["ext"])
return icon return icon
async def get_guild_icon(self, guild_id: int, icon_hash: str, **kwargs): async def get_guild_icon(self, guild_id: int, icon_hash: str, **kwargs):
"""Get an icon for a guild.""" """Get an icon for a guild."""
return await self.generic_get( return await self.generic_get("guild", guild_id, icon_hash, **kwargs)
'guild', guild_id, icon_hash, **kwargs)
async def put(self, scope: str, key: str, async def put(self, scope: str, key: str, b64_data: str, **kwargs) -> Icon:
b64_data: str, **kwargs) -> Icon:
"""Insert an icon.""" """Insert an icon."""
if b64_data is None: if b64_data is None:
return _invalid(kwargs) return _invalid(kwargs)
@ -373,23 +365,22 @@ class IconManager:
# get an extension for the given data uri # get an extension for the given data uri
extension = get_ext(mime) extension = get_ext(mime)
if 'bsize' in kwargs and len(raw_data) > kwargs['bsize']: if "bsize" in kwargs and len(raw_data) > kwargs["bsize"]:
return _invalid(kwargs) return _invalid(kwargs)
# size management is different for gif files # size management is different for gif files
# as they're composed of multiple frames. # as they're composed of multiple frames.
if 'size' in kwargs and mime == 'image/gif': if "size" in kwargs and mime == "image/gif":
data_fd, raw_data = await resize_gif(raw_data, kwargs['size']) data_fd, raw_data = await resize_gif(raw_data, kwargs["size"])
elif 'size' in kwargs: elif "size" in kwargs:
image = Image.open(data_fd) image = Image.open(data_fd)
if mime == 'image/jpeg': if mime == "image/jpeg":
image = image.convert("RGB") image = image.convert("RGB")
want = kwargs['size'] want = kwargs["size"]
log.info('resizing from {} to {}', log.info("resizing from {} to {}", image.size, want)
image.size, want)
resized = image.resize(want, resample=Image.LANCZOS) resized = image.resize(want, resample=Image.LANCZOS)
@ -404,23 +395,26 @@ class IconManager:
# calculate sha256 # calculate sha256
# ignore icon hashes if we're talking about emoji # ignore icon hashes if we're talking about emoji
icon_hash = (await calculate_hash(data_fd) icon_hash = await calculate_hash(data_fd) if scope != "emoji" else None
if scope != 'emoji'
else None)
if scope == 'user' and mime == 'image/gif': if scope == "user" and mime == "image/gif":
icon_hash = f'a_{icon_hash}' icon_hash = f"a_{icon_hash}"
log.debug('PUT icon {!r} {!r} {!r} {!r}', log.debug("PUT icon {!r} {!r} {!r} {!r}", scope, key, icon_hash, mime)
scope, key, icon_hash, mime)
await self.storage.db.execute(""" await self.storage.db.execute(
"""
INSERT INTO icons (scope, key, hash, mime) INSERT INTO icons (scope, key, hash, mime)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
""", scope, str(key), icon_hash, mime) """,
scope,
str(key),
icon_hash,
mime,
)
# write it off to fs # write it off to fs
icon_path = IMAGE_FOLDER / f'{key}_{icon_hash}.{extension}' icon_path = IMAGE_FOLDER / f"{key}_{icon_hash}.{extension}"
icon_path.write_bytes(raw_data) icon_path.write_bytes(raw_data)
# copy from data_fd to icon_fd # copy from data_fd to icon_fd
@ -434,57 +428,80 @@ class IconManager:
if not icon: if not icon:
return return
log.debug('DEL {}', log.debug("DEL {}", icon)
icon)
# dereference # dereference
await self.storage.db.execute(""" await self.storage.db.execute(
"""
UPDATE users UPDATE users
SET avatar = NULL SET avatar = NULL
WHERE avatar = $1 WHERE avatar = $1
""", icon.icon_hash) """,
icon.icon_hash,
)
await self.storage.db.execute(""" await self.storage.db.execute(
"""
UPDATE group_dm_channels UPDATE group_dm_channels
SET icon = NULL SET icon = NULL
WHERE icon = $1 WHERE icon = $1
""", icon.icon_hash) """,
icon.icon_hash,
)
await self.storage.db.execute(""" await self.storage.db.execute(
"""
DELETE FROM guild_emoji DELETE FROM guild_emoji
WHERE image = $1 WHERE image = $1
""", icon.icon_hash) """,
icon.icon_hash,
)
await self.storage.db.execute(""" await self.storage.db.execute(
"""
UPDATE guilds UPDATE guilds
SET icon = NULL SET icon = NULL
WHERE icon = $1 WHERE icon = $1
""", icon.icon_hash) """,
icon.icon_hash,
)
await self.storage.db.execute(""" await self.storage.db.execute(
"""
UPDATE guilds UPDATE guilds
SET splash = NULL SET splash = NULL
WHERE splash = $1 WHERE splash = $1
""", icon.icon_hash) """,
icon.icon_hash,
)
await self.storage.db.execute(""" await self.storage.db.execute(
"""
UPDATE guilds UPDATE guilds
SET banner = NULL SET banner = NULL
WHERE banner = $1 WHERE banner = $1
""", icon.icon_hash) """,
icon.icon_hash,
)
await self.storage.db.execute(""" await self.storage.db.execute(
"""
UPDATE group_dm_channels UPDATE group_dm_channels
SET icon = NULL SET icon = NULL
WHERE icon = $1 WHERE icon = $1
""", icon.icon_hash) """,
icon.icon_hash,
)
await self.storage.db.execute(""" await self.storage.db.execute(
"""
DELETE FROM icons DELETE FROM icons
WHERE hash = $1 WHERE hash = $1
""", icon.icon_hash) """,
icon.icon_hash,
)
paths = IMAGE_FOLDER.glob(f'{icon.key}_{icon.icon_hash}.*') paths = IMAGE_FOLDER.glob(f"{icon.key}_{icon.icon_hash}.*")
for path in paths: for path in paths:
try: try:
@ -492,11 +509,9 @@ class IconManager:
except FileNotFoundError: except FileNotFoundError:
pass pass
async def update(self, scope: str, key: str, async def update(self, scope: str, key: str, new_icon_data: str, **kwargs) -> Icon:
new_icon_data: str, **kwargs) -> Icon:
"""Update an icon on a key.""" """Update an icon on a key."""
old_icon_hash = await self.storage.db.fetchval( old_icon_hash = await self.storage.db.fetchval(_gen_update_sql(scope), key)
_gen_update_sql(scope), key)
# converting key to str only here since from here onwards # converting key to str only here since from here onwards
# its operations on the icons table (or a dereference with # its operations on the icons table (or a dereference with

View File

@ -20,6 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio import asyncio
from logbook import Logger from logbook import Logger
log = Logger(__name__) log = Logger(__name__)
@ -30,6 +31,7 @@ class JobManager:
use helpers such as asyncio.gather and asyncio.Task.all_tasks. It only uses use helpers such as asyncio.gather and asyncio.Task.all_tasks. It only uses
its own internal list of jobs. its own internal list of jobs.
""" """
def __init__(self, loop=None): def __init__(self, loop=None):
self.loop = loop or asyncio.get_event_loop() self.loop = loop or asyncio.get_event_loop()
self.jobs = [] self.jobs = []
@ -41,13 +43,11 @@ class JobManager:
try: try:
await coro await coro
except Exception: except Exception:
log.exception('Error while running job') log.exception("Error while running job")
def spawn(self, coro): def spawn(self, coro):
"""Spawn a given future or coroutine in the background.""" """Spawn a given future or coroutine in the background."""
task = self.loop.create_task( task = self.loop.create_task(self._wrapper(coro))
self._wrapper(coro)
)
self.jobs.append(task) self.jobs.append(task)

View File

@ -26,40 +26,42 @@ from quart import current_app as app
# type for all the fields # type for all the fields
_i = ctypes.c_uint8 _i = ctypes.c_uint8
class _RawPermsBits(ctypes.LittleEndianStructure): class _RawPermsBits(ctypes.LittleEndianStructure):
"""raw bitfield for discord's permission number.""" """raw bitfield for discord's permission number."""
_fields_ = [ _fields_ = [
('create_invites', _i, 1), ("create_invites", _i, 1),
('kick_members', _i, 1), ("kick_members", _i, 1),
('ban_members', _i, 1), ("ban_members", _i, 1),
('administrator', _i, 1), ("administrator", _i, 1),
('manage_channels', _i, 1), ("manage_channels", _i, 1),
('manage_guild', _i, 1), ("manage_guild", _i, 1),
('add_reactions', _i, 1), ("add_reactions", _i, 1),
('view_audit_log', _i, 1), ("view_audit_log", _i, 1),
('priority_speaker', _i, 1), ("priority_speaker", _i, 1),
('stream', _i, 1), ("stream", _i, 1),
('read_messages', _i, 1), ("read_messages", _i, 1),
('send_messages', _i, 1), ("send_messages", _i, 1),
('send_tts', _i, 1), ("send_tts", _i, 1),
('manage_messages', _i, 1), ("manage_messages", _i, 1),
('embed_links', _i, 1), ("embed_links", _i, 1),
('attach_files', _i, 1), ("attach_files", _i, 1),
('read_history', _i, 1), ("read_history", _i, 1),
('mention_everyone', _i, 1), ("mention_everyone", _i, 1),
('external_emojis', _i, 1), ("external_emojis", _i, 1),
('_unused2', _i, 1), ("_unused2", _i, 1),
('connect', _i, 1), ("connect", _i, 1),
('speak', _i, 1), ("speak", _i, 1),
('mute_members', _i, 1), ("mute_members", _i, 1),
('deafen_members', _i, 1), ("deafen_members", _i, 1),
('move_members', _i, 1), ("move_members", _i, 1),
('use_voice_activation', _i, 1), ("use_voice_activation", _i, 1),
('change_nickname', _i, 1), ("change_nickname", _i, 1),
('manage_nicknames', _i, 1), ("manage_nicknames", _i, 1),
('manage_roles', _i, 1), ("manage_roles", _i, 1),
('manage_webhooks', _i, 1), ("manage_webhooks", _i, 1),
('manage_emojis', _i, 1), ("manage_emojis", _i, 1),
] ]
@ -72,16 +74,14 @@ class Permissions(ctypes.Union):
val val
The permissions value as an integer. The permissions value as an integer.
""" """
_fields_ = [
('bits', _RawPermsBits), _fields_ = [("bits", _RawPermsBits), ("binary", ctypes.c_uint64)]
('binary', ctypes.c_uint64),
]
def __init__(self, val: int): def __init__(self, val: int):
self.binary = val self.binary = val
def __repr__(self): def __repr__(self):
return f'<Permissions binary={self.binary}>' return f"<Permissions binary={self.binary}>"
def __int__(self): def __int__(self):
return self.binary return self.binary
@ -95,11 +95,15 @@ async def get_role_perms(guild_id, role_id, storage=None) -> Permissions:
if not storage: if not storage:
storage = app.storage storage = app.storage
perms = await storage.db.fetchval(""" perms = await storage.db.fetchval(
"""
SELECT permissions SELECT permissions
FROM roles FROM roles
WHERE guild_id = $1 AND id = $2 WHERE guild_id = $1 AND id = $2
""", guild_id, role_id) """,
guild_id,
role_id,
)
return Permissions(perms) return Permissions(perms)
@ -118,11 +122,14 @@ async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
if not storage: if not storage:
storage = app.storage storage = app.storage
owner_id = await storage.db.fetchval(""" owner_id = await storage.db.fetchval(
"""
SELECT owner_id SELECT owner_id
FROM guilds FROM guilds
WHERE id = $1 WHERE id = $1
""", guild_id) """,
guild_id,
)
if owner_id == member_id: if owner_id == member_id:
return ALL_PERMISSIONS return ALL_PERMISSIONS
@ -130,20 +137,27 @@ async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
# get permissions for @everyone # get permissions for @everyone
permissions = await get_role_perms(guild_id, guild_id, storage) permissions = await get_role_perms(guild_id, guild_id, storage)
role_ids = await storage.db.fetch(""" role_ids = await storage.db.fetch(
"""
SELECT role_id SELECT role_id
FROM member_roles FROM member_roles
WHERE guild_id = $1 AND user_id = $2 WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id) """,
guild_id,
member_id,
)
role_perms = [] role_perms = []
for row in role_ids: for row in role_ids:
rperm = await storage.db.fetchval(""" rperm = await storage.db.fetchval(
"""
SELECT permissions SELECT permissions
FROM roles FROM roles
WHERE id = $1 WHERE id = $1
""", row['role_id']) """,
row["role_id"],
)
role_perms.append(rperm) role_perms.append(rperm)
@ -164,16 +178,17 @@ def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions:
result = perms.binary result = perms.binary
# negate the permissions that are denied # negate the permissions that are denied
result &= ~overwrite['deny'] result &= ~overwrite["deny"]
# combine the permissions that are allowed # combine the permissions that are allowed
result |= overwrite['allow'] result |= overwrite["allow"]
return Permissions(result) return Permissions(result)
def overwrite_find_mix(perms: Permissions, overwrites: dict, def overwrite_find_mix(
target_id: int) -> Permissions: perms: Permissions, overwrites: dict, target_id: int
) -> Permissions:
"""Mix a given permission with a given overwrite. """Mix a given permission with a given overwrite.
Returns the given permission if an overwrite is not found. Returns the given permission if an overwrite is not found.
@ -201,19 +216,25 @@ def overwrite_find_mix(perms: Permissions, overwrites: dict,
return perms return perms
async def role_permissions(guild_id: int, role_id: int, async def role_permissions(
channel_id: int, storage=None) -> Permissions: guild_id: int, role_id: int, channel_id: int, storage=None
) -> Permissions:
"""Get the permissions for a role, in relation to a channel""" """Get the permissions for a role, in relation to a channel"""
if not storage: if not storage:
storage = app.storage storage = app.storage
perms = await get_role_perms(guild_id, role_id, storage) perms = await get_role_perms(guild_id, role_id, storage)
overwrite = await storage.db.fetchrow(""" overwrite = await storage.db.fetchrow(
"""
SELECT allow, deny SELECT allow, deny
FROM channel_overwrites FROM channel_overwrites
WHERE channel_id = $1 AND target_type = $2 AND target_role = $3 WHERE channel_id = $1 AND target_type = $2 AND target_role = $3
""", channel_id, 1, role_id) """,
channel_id,
1,
role_id,
)
if overwrite: if overwrite:
perms = overwrite_mix(perms, overwrite) perms = overwrite_mix(perms, overwrite)
@ -221,10 +242,13 @@ async def role_permissions(guild_id: int, role_id: int,
return perms return perms
async def compute_overwrites(base_perms: Permissions, async def compute_overwrites(
user_id, channel_id: int, base_perms: Permissions,
guild_id: Optional[int] = None, user_id,
storage=None): channel_id: int,
guild_id: Optional[int] = None,
storage=None,
):
"""Compute the permissions in the context of a channel.""" """Compute the permissions in the context of a channel."""
if not storage: if not storage:
storage = app.storage storage = app.storage
@ -245,7 +269,7 @@ async def compute_overwrites(base_perms: Permissions,
return ALL_PERMISSIONS return ALL_PERMISSIONS
# make it a map for better usage # make it a map for better usage
overwrites = {int(o['id']): o for o in overwrites} overwrites = {int(o["id"]): o for o in overwrites}
perms = overwrite_find_mix(perms, overwrites, guild_id) perms = overwrite_find_mix(perms, overwrites, guild_id)
@ -260,14 +284,11 @@ async def compute_overwrites(base_perms: Permissions,
for role_id in role_ids: for role_id in role_ids:
overwrite = overwrites.get(role_id) overwrite = overwrites.get(role_id)
if overwrite: if overwrite:
allow |= overwrite['allow'] allow |= overwrite["allow"]
deny |= overwrite['deny'] deny |= overwrite["deny"]
# final step for roles: mix # final step for roles: mix
perms = overwrite_mix(perms, { perms = overwrite_mix(perms, {"allow": allow, "deny": deny})
'allow': allow,
'deny': deny
})
# apply member specific overwrites # apply member specific overwrites
perms = overwrite_find_mix(perms, overwrites, user_id) perms = overwrite_find_mix(perms, overwrites, user_id)
@ -275,8 +296,7 @@ async def compute_overwrites(base_perms: Permissions,
return perms return perms
async def get_permissions(member_id: int, channel_id, async def get_permissions(member_id: int, channel_id, *, storage=None) -> Permissions:
*, storage=None) -> Permissions:
"""Get the permissions for a user in a channel.""" """Get the permissions for a user in a channel."""
if not storage: if not storage:
storage = app.storage storage = app.storage
@ -290,4 +310,5 @@ async def get_permissions(member_id: int, channel_id,
base_perms = await base_permissions(member_id, guild_id, storage) base_perms = await base_permissions(member_id, guild_id, storage)
return await compute_overwrites( return await compute_overwrites(
base_perms, member_id, channel_id, guild_id, storage) base_perms, member_id, channel_id, guild_id, storage
)

View File

@ -32,62 +32,56 @@ def status_cmp(status: str, other_status: str) -> bool:
in the status hierarchy. in the status hierarchy.
""" """
hierarchy = { hierarchy = {"online": 3, "idle": 2, "dnd": 1, "offline": 0, None: -1}
'online': 3,
'idle': 2,
'dnd': 1,
'offline': 0,
None: -1,
}
return hierarchy[status] > hierarchy[other_status] return hierarchy[status] > hierarchy[other_status]
def _best_presence(shards): def _best_presence(shards):
"""Find the 'best' presence given a list of GatewayState.""" """Find the 'best' presence given a list of GatewayState."""
best = {'status': None, 'game': None} best = {"status": None, "game": None}
for state in shards: for state in shards:
presence = state.presence presence = state.presence
status = presence['status'] status = presence["status"]
if not presence: if not presence:
continue continue
# shards with a better status # shards with a better status
# in the hierarchy are treated as best # in the hierarchy are treated as best
if status_cmp(status, best['status']): if status_cmp(status, best["status"]):
best['status'] = status best["status"] = status
# if we have any game, use it # if we have any game, use it
if presence['game'] is not None: if presence["game"] is not None:
best['game'] = presence['game'] best["game"] = presence["game"]
# best['status'] is None when no # best['status'] is None when no
# status was good enough. # status was good enough.
return None if not best['status'] else best return None if not best["status"] else best
def fill_presence(presence: dict, *, game=None) -> dict: def fill_presence(presence: dict, *, game=None) -> dict:
"""Fill a given presence object with some specific fields.""" """Fill a given presence object with some specific fields."""
presence['client_status'] = {} presence["client_status"] = {}
presence['mobile'] = False presence["mobile"] = False
if 'since' not in presence: if "since" not in presence:
presence['since'] = 0 presence["since"] = 0
# fill game and activities array depending if game # fill game and activities array depending if game
# is there or not # is there or not
game = game or presence.get('game') game = game or presence.get("game")
# casting to bool since a game of {} is still invalid # casting to bool since a game of {} is still invalid
if game: if game:
presence['game'] = game presence["game"] = game
presence['activities'] = [game] presence["activities"] = [game]
else: else:
presence['game'] = None presence["game"] = None
presence['activities'] = [] presence["activities"] = []
return presence return presence
@ -96,14 +90,13 @@ async def _pres(storage, user_id: int, status_obj: dict) -> dict:
"""Convert a given status into a presence, given the User ID and the """Convert a given status into a presence, given the User ID and the
:class:`Storage` instance.""" :class:`Storage` instance."""
ext = { ext = {
'user': await storage.get_user(user_id), "user": await storage.get_user(user_id),
'activities': [], "activities": [],
# NOTE: we are purposefully overwriting the fields, as there # NOTE: we are purposefully overwriting the fields, as there
# isn't any push for us to actually implement mobile detection, or # isn't any push for us to actually implement mobile detection, or
# web detection, etc. # web detection, etc.
'client_status': {}, "client_status": {},
'mobile': False, "mobile": False,
} }
return fill_presence({**status_obj, **ext}) return fill_presence({**status_obj, **ext})
@ -115,14 +108,16 @@ class PresenceManager:
Has common functions to deal with fetching or updating presences, including Has common functions to deal with fetching or updating presences, including
side-effects (events). side-effects (events).
""" """
def __init__(self, app): def __init__(self, app):
self.storage = app.storage self.storage = app.storage
self.user_storage = app.user_storage self.user_storage = app.user_storage
self.state_manager = app.state_manager self.state_manager = app.state_manager
self.dispatcher = app.dispatcher self.dispatcher = app.dispatcher
async def guild_presences(self, member_ids: List[int], async def guild_presences(
guild_id: int) -> List[Dict[Any, str]]: self, member_ids: List[int], guild_id: int
) -> List[Dict[Any, str]]:
"""Fetch all presences in a guild.""" """Fetch all presences in a guild."""
# this works via fetching all connected GatewayState on a guild # this works via fetching all connected GatewayState on a guild
# then fetching its respective member and merging that info with # then fetching its respective member and merging that info with
@ -132,34 +127,36 @@ class PresenceManager:
presences = [] presences = []
for state in states: for state in states:
member = await self.storage.get_member_data_one( member = await self.storage.get_member_data_one(guild_id, state.user_id)
guild_id, state.user_id)
game = state.presence.get('game', None) game = state.presence.get("game", None)
# only use the data we need. # only use the data we need.
presences.append(fill_presence({ presences.append(
'user': member['user'], fill_presence(
'roles': member['roles'], {
'guild_id': str(guild_id), "user": member["user"],
"roles": member["roles"],
# if a state is connected to the guild "guild_id": str(guild_id),
# we assume its online. # if a state is connected to the guild
'status': state.presence.get('status', 'online'), # we assume its online.
}, game=game)) "status": state.presence.get("status", "online"),
},
game=game,
)
)
return presences return presences
async def dispatch_guild_pres(self, guild_id: int, async def dispatch_guild_pres(self, guild_id: int, user_id: int, new_state: dict):
user_id: int, new_state: dict):
"""Dispatch a Presence update to an entire guild.""" """Dispatch a Presence update to an entire guild."""
state = dict(new_state) state = dict(new_state)
member = await self.storage.get_member_data_one(guild_id, user_id) member = await self.storage.get_member_data_one(guild_id, user_id)
game = state['game'] game = state["game"]
lazy_guild_store = self.dispatcher.backends['lazy_guild'] lazy_guild_store = self.dispatcher.backends["lazy_guild"]
lists = lazy_guild_store.get_gml_guild(guild_id) lists = lazy_guild_store.get_gml_guild(guild_id)
# shards that are in lazy guilds with 'everyone' # shards that are in lazy guilds with 'everyone'
@ -168,49 +165,44 @@ class PresenceManager:
for member_list in lists: for member_list in lists:
session_ids = await member_list.pres_update( session_ids = await member_list.pres_update(
int(member['user']['id']), int(member["user"]["id"]),
{ {"roles": member["roles"], "status": state["status"], "game": game},
'roles': member['roles'],
'status': state['status'],
'game': game
}
) )
log.debug('Lazy Dispatch to {}', log.debug("Lazy Dispatch to {}", len(session_ids))
len(session_ids))
# if we are on the 'everyone' member list, we don't # if we are on the 'everyone' member list, we don't
# dispatch a PRESENCE_UPDATE for those shards. # dispatch a PRESENCE_UPDATE for those shards.
if member_list.channel_id == member_list.guild_id: if member_list.channel_id == member_list.guild_id:
in_lazy.extend(session_ids) in_lazy.extend(session_ids)
pres_update_payload = fill_presence({ pres_update_payload = fill_presence(
'guild_id': str(guild_id), {
'user': member['user'], "guild_id": str(guild_id),
'roles': member['roles'], "user": member["user"],
'status': state['status'], "roles": member["roles"],
}, game=game) "status": state["status"],
},
game=game,
)
# given a session id, return if the session id actually connects to # given a session id, return if the session id actually connects to
# a given user, and if the state has not been dispatched via lazy guild. # a given user, and if the state has not been dispatched via lazy guild.
def _session_check(session_id): def _session_check(session_id):
state = self.state_manager.fetch_raw(session_id) state = self.state_manager.fetch_raw(session_id)
uid = int(member['user']['id']) uid = int(member["user"]["id"])
if not state: if not state:
return False return False
# we don't want to send a presence update # we don't want to send a presence update
# to the same user # to the same user
return (state.user_id != uid and return state.user_id != uid and session_id not in in_lazy
session_id not in in_lazy)
# everyone not in lazy guild mode # everyone not in lazy guild mode
# gets a PRESENCE_UPDATE # gets a PRESENCE_UPDATE
await self.dispatcher.dispatch_filter( await self.dispatcher.dispatch_filter(
'guild', guild_id, "guild", guild_id, _session_check, "PRESENCE_UPDATE", pres_update_payload
_session_check,
'PRESENCE_UPDATE', pres_update_payload
) )
return in_lazy return in_lazy
@ -220,25 +212,25 @@ class PresenceManager:
Also dispatches the presence to all the users' friends Also dispatches the presence to all the users' friends
""" """
if state['status'] == 'invisible': if state["status"] == "invisible":
state['status'] = 'offline' state["status"] = "offline"
# TODO: shard-aware # TODO: shard-aware
guild_ids = await self.user_storage.get_user_guilds(user_id) guild_ids = await self.user_storage.get_user_guilds(user_id)
for guild_id in guild_ids: for guild_id in guild_ids:
await self.dispatch_guild_pres( await self.dispatch_guild_pres(guild_id, user_id, state)
guild_id, user_id, state)
# dispatch to all friends that are subscribed to them # dispatch to all friends that are subscribed to them
user = await self.storage.get_user(user_id) user = await self.storage.get_user(user_id)
game = state['game'] game = state["game"]
await self.dispatcher.dispatch( await self.dispatcher.dispatch(
'friend', user_id, 'PRESENCE_UPDATE', fill_presence({ "friend",
'user': user, user_id,
'status': state['status'], "PRESENCE_UPDATE",
}, game=game)) fill_presence({"user": user, "status": state["status"]}, game=game),
)
async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]: async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]:
"""Fetch presences for a group of users. """Fetch presences for a group of users.
@ -254,22 +246,25 @@ class PresenceManager:
if not friend_states: if not friend_states:
# append offline # append offline
res.append(await _pres(storage, friend_id, { res.append(
'afk': False, await _pres(
'status': 'offline', storage,
'game': None, friend_id,
'since': 0 {"afk": False, "status": "offline", "game": None, "since": 0},
})) )
)
continue continue
# filter the best shards: # filter the best shards:
# - all with id 0 (are the first shards in the collection) or # - all with id 0 (are the first shards in the collection) or
# - all shards with count = 1 (single shards) # - all shards with count = 1 (single shards)
good_shards = list(filter( good_shards = list(
lambda state: state.shard[0] == 0 or state.shard[1] == 1, filter(
friend_states lambda state: state.shard[0] == 0 or state.shard[1] == 1,
)) friend_states,
)
)
if good_shards: if good_shards:
best_pres = _best_presence(good_shards) best_pres = _best_presence(good_shards)

View File

@ -24,6 +24,11 @@ from .channel import ChannelDispatcher
from .friend import FriendDispatcher from .friend import FriendDispatcher
from .lazy_guild import LazyGuildDispatcher from .lazy_guild import LazyGuildDispatcher
__all__ = ['GuildDispatcher', 'MemberDispatcher', __all__ = [
'UserDispatcher', 'ChannelDispatcher', "GuildDispatcher",
'FriendDispatcher', 'LazyGuildDispatcher'] "MemberDispatcher",
"UserDispatcher",
"ChannelDispatcher",
"FriendDispatcher",
"LazyGuildDispatcher",
]

View File

@ -38,23 +38,20 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
# make a copy or the original channel object # make a copy or the original channel object
data = dict(orig) data = dict(orig)
idx = index_by_func( idx = index_by_func(lambda user: user["id"] == str(user_id), data["recipients"])
lambda user: user['id'] == str(user_id),
data['recipients']
)
data['recipients'].pop(idx) data["recipients"].pop(idx)
return data return data
class ChannelDispatcher(DispatcherWithFlags): class ChannelDispatcher(DispatcherWithFlags):
"""Main channel Pub/Sub logic.""" """Main channel Pub/Sub logic."""
KEY_TYPE = int KEY_TYPE = int
VAL_TYPE = int VAL_TYPE = int
async def dispatch(self, channel_id, async def dispatch(self, channel_id, event: str, data: Any) -> List[str]:
event: str, data: Any) -> List[str]:
"""Dispatch an event to a channel.""" """Dispatch an event to a channel."""
# get everyone who is subscribed # get everyone who is subscribed
# and store the number of states we dispatched the event to # and store the number of states we dispatched the event to
@ -75,9 +72,11 @@ class ChannelDispatcher(DispatcherWithFlags):
# TODO: make a fetch_states that fetches shards # TODO: make a fetch_states that fetches shards
# - with id 0 (count any) OR # - with id 0 (count any) OR
# - single shards (id=0, count=1) # - single shards (id=0, count=1)
states = (self.sm.fetch_states(user_id, guild_id) states = (
if guild_id else self.sm.fetch_states(user_id, guild_id)
self.sm.user_states(user_id)) if guild_id
else self.sm.user_states(user_id)
)
# unsub people who don't have any states tied to the channel. # unsub people who don't have any states tied to the channel.
if not states: if not states:
@ -85,28 +84,28 @@ class ChannelDispatcher(DispatcherWithFlags):
continue continue
# skip typing events for users that don't want it # skip typing events for users that don't want it
if event.startswith('TYPING_') and \ if event.startswith("TYPING_") and not self.flags_get(
not self.flags_get(channel_id, user_id, 'typing', True): channel_id, user_id, "typing", True
):
continue continue
cur_sess = [] cur_sess = []
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \ if (
and data.get('type') == ChannelType.GROUP_DM.value: event in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
and data.get("type") == ChannelType.GROUP_DM.value
):
# we edit the channel payload so it doesn't show # we edit the channel payload so it doesn't show
# the user as a recipient # the user as a recipient
new_data = gdm_recipient_view(data, user_id) new_data = gdm_recipient_view(data, user_id)
cur_sess = await self._dispatch_states( cur_sess = await self._dispatch_states(states, event, new_data)
states, event, new_data)
else: else:
cur_sess = await self._dispatch_states( cur_sess = await self._dispatch_states(states, event, data)
states, event, data)
sessions.extend(cur_sess) sessions.extend(cur_sess)
dispatched += len(cur_sess) dispatched += len(cur_sess)
log.info('Dispatched chan={} {!r} to {} states', log.info("Dispatched chan={} {!r} to {} states", channel_id, event, dispatched)
channel_id, event, dispatched)
return sessions return sessions

View File

@ -80,8 +80,7 @@ class Dispatcher:
""" """
raise NotImplementedError raise NotImplementedError
async def _dispatch_states(self, states: list, event: str, async def _dispatch_states(self, states: list, event: str, data) -> List[str]:
data) -> List[str]:
"""Dispatch an event to a list of states.""" """Dispatch an event to a list of states."""
res = [] res = []
@ -90,7 +89,7 @@ class Dispatcher:
await state.ws.dispatch(event, data) await state.ws.dispatch(event, data)
res.append(state.session_id) res.append(state.session_id)
except Exception: except Exception:
log.exception('error while dispatching') log.exception("error while dispatching")
return res return res
@ -102,6 +101,7 @@ class DispatcherWithState(Dispatcher):
of boilerplate code on Pub/Sub backends of boilerplate code on Pub/Sub backends
that have that dictionary. that have that dictionary.
""" """
def __init__(self, main): def __init__(self, main):
super().__init__(main) super().__init__(main)

View File

@ -31,6 +31,7 @@ class FriendDispatcher(DispatcherWithState):
channels. If that friend updates their presence, it will be channels. If that friend updates their presence, it will be
broadcasted through that channel to basically all their friends. broadcasted through that channel to basically all their friends.
""" """
KEY_TYPE = int KEY_TYPE = int
VAL_TYPE = int VAL_TYPE = int
@ -44,17 +45,13 @@ class FriendDispatcher(DispatcherWithState):
# since relationships broadcast to all shards. # since relationships broadcast to all shards.
sessions.extend( sessions.extend(
await self.main_dispatcher.dispatch_filter( await self.main_dispatcher.dispatch_filter(
'user', peer_id, func, event, data) "user", peer_id, func, event, data
)
) )
log.info('dispatched uid={} {!r} to {} states', log.info("dispatched uid={} {!r} to {} states", user_id, event, len(sessions))
user_id, event, len(sessions))
return sessions return sessions
async def dispatch(self, user_id, event, data): async def dispatch(self, user_id, event, data):
return await self.dispatch_filter( return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)
user_id,
lambda sess_id: True,
event, data,
)

View File

@ -29,11 +29,11 @@ log = Logger(__name__)
class GuildDispatcher(DispatcherWithFlags): class GuildDispatcher(DispatcherWithFlags):
"""Guild backend for Pub/Sub""" """Guild backend for Pub/Sub"""
KEY_TYPE = int KEY_TYPE = int
VAL_TYPE = int VAL_TYPE = int
async def _chan_action(self, action: str, async def _chan_action(self, action: str, guild_id: int, user_id: int, flags=None):
guild_id: int, user_id: int, flags=None):
"""Send an action to all channels of the guild.""" """Send an action to all channels of the guild."""
flags = flags or {} flags = flags or {}
chan_ids = await self.app.storage.get_channel_ids(guild_id) chan_ids = await self.app.storage.get_channel_ids(guild_id)
@ -43,33 +43,31 @@ class GuildDispatcher(DispatcherWithFlags):
# only do an action for users that can # only do an action for users that can
# actually read the channel to start with. # actually read the channel to start with.
chan_perms = await get_permissions( chan_perms = await get_permissions(
user_id, chan_id, user_id, chan_id, storage=self.main_dispatcher.app.storage
storage=self.main_dispatcher.app.storage) )
if not chan_perms.bits.read_messages: if not chan_perms.bits.read_messages:
log.debug('skipping cid={}, no read messages', log.debug("skipping cid={}, no read messages", chan_id)
chan_id)
continue continue
log.debug('sending raw action {!r} to chan={}', log.debug("sending raw action {!r} to chan={}", action, chan_id)
action, chan_id)
# for now, only sub() has support for flags. # for now, only sub() has support for flags.
# it is an idea to have flags support for other actions # it is an idea to have flags support for other actions
args = [] args = []
if action == 'sub': if action == "sub":
chanflags = dict(flags) chanflags = dict(flags)
# channels don't need presence flags # channels don't need presence flags
try: try:
chanflags.pop('presence') chanflags.pop("presence")
except KeyError: except KeyError:
pass pass
args.append(chanflags) args.append(chanflags)
await self.main_dispatcher.action( await self.main_dispatcher.action(
'channel', action, chan_id, user_id, *args "channel", action, chan_id, user_id, *args
) )
async def _chan_call(self, meth: str, guild_id: int, *args): async def _chan_call(self, meth: str, guild_id: int, *args):
@ -77,26 +75,24 @@ class GuildDispatcher(DispatcherWithFlags):
in the guild.""" in the guild."""
chan_ids = await self.app.storage.get_channel_ids(guild_id) chan_ids = await self.app.storage.get_channel_ids(guild_id)
chan_dispatcher = self.main_dispatcher.backends['channel'] chan_dispatcher = self.main_dispatcher.backends["channel"]
method = getattr(chan_dispatcher, meth) method = getattr(chan_dispatcher, meth)
for chan_id in chan_ids: for chan_id in chan_ids:
log.debug('calling {} to chan={}', log.debug("calling {} to chan={}", meth, chan_id)
meth, chan_id)
await method(chan_id, *args) await method(chan_id, *args)
async def sub(self, guild_id: int, user_id: int, flags=None): async def sub(self, guild_id: int, user_id: int, flags=None):
"""Subscribe a user to the guild.""" """Subscribe a user to the guild."""
await super().sub(guild_id, user_id, flags) await super().sub(guild_id, user_id, flags)
await self._chan_action('sub', guild_id, user_id, flags) await self._chan_action("sub", guild_id, user_id, flags)
async def unsub(self, guild_id: int, user_id: int): async def unsub(self, guild_id: int, user_id: int):
"""Unsubscribe a user from the guild.""" """Unsubscribe a user from the guild."""
await super().unsub(guild_id, user_id) await super().unsub(guild_id, user_id)
await self._chan_action('unsub', guild_id, user_id) await self._chan_action("unsub", guild_id, user_id)
async def dispatch_filter(self, guild_id: int, func, async def dispatch_filter(self, guild_id: int, func, event: str, data: Any):
event: str, data: Any):
"""Selectively dispatch to session ids that have """Selectively dispatch to session ids that have
func(session_id) true.""" func(session_id) true."""
user_ids = self.state[guild_id] user_ids = self.state[guild_id]
@ -121,31 +117,23 @@ class GuildDispatcher(DispatcherWithFlags):
# note that this does not equate to any unsubscription # note that this does not equate to any unsubscription
# of the channel. # of the channel.
if event.startswith('PRESENCE_') and \ if event.startswith("PRESENCE_") and not self.flags_get(
not self.flags_get(guild_id, user_id, 'presence', True): guild_id, user_id, "presence", True
):
continue continue
# filter the ones that matter # filter the ones that matter
states = list(filter( states = list(filter(lambda state: func(state.session_id), states))
lambda state: func(state.session_id), states
))
cur_sess = await self._dispatch_states( cur_sess = await self._dispatch_states(states, event, data)
states, event, data)
sessions.extend(cur_sess) sessions.extend(cur_sess)
dispatched += len(cur_sess) dispatched += len(cur_sess)
log.info('Dispatched {} {!r} to {} states', log.info("Dispatched {} {!r} to {} states", guild_id, event, dispatched)
guild_id, event, dispatched)
return sessions return sessions
async def dispatch(self, guild_id: int, async def dispatch(self, guild_id: int, event: str, data: Any):
event: str, data: Any):
"""Dispatch an event to all subscribers of the guild.""" """Dispatch an event to all subscribers of the guild."""
return await self.dispatch_filter( return await self.dispatch_filter(guild_id, lambda sess_id: True, event, data)
guild_id,
lambda sess_id: True,
event, data,
)

File diff suppressed because it is too large Load Diff

View File

@ -22,6 +22,7 @@ from .dispatcher import Dispatcher
class MemberDispatcher(Dispatcher): class MemberDispatcher(Dispatcher):
"""Member backend for Pub/Sub.""" """Member backend for Pub/Sub."""
KEY_TYPE = tuple KEY_TYPE = tuple
async def dispatch(self, key, event, data): async def dispatch(self, key, event, data):
@ -39,7 +40,7 @@ class MemberDispatcher(Dispatcher):
# if no states were found, we should # if no states were found, we should
# unsub the user from the GUILD channel # unsub the user from the GUILD channel
if not states: if not states:
await self.main_dispatcher.unsub('guild', guild_id, user_id) await self.main_dispatcher.unsub("guild", guild_id, user_id)
return return
return await self._dispatch_states(states, event, data) return await self._dispatch_states(states, event, data)

View File

@ -22,22 +22,18 @@ from .dispatcher import Dispatcher
class UserDispatcher(Dispatcher): class UserDispatcher(Dispatcher):
"""User backend for Pub/Sub.""" """User backend for Pub/Sub."""
KEY_TYPE = int KEY_TYPE = int
async def dispatch_filter(self, user_id: int, func, event, data): async def dispatch_filter(self, user_id: int, func, event, data):
"""Dispatch an event to all shards of a user.""" """Dispatch an event to all shards of a user."""
# filter only states where func() gives true # filter only states where func() gives true
states = list(filter( states = list(
lambda state: func(state.session_id), filter(lambda state: func(state.session_id), self.sm.user_states(user_id))
self.sm.user_states(user_id) )
))
return await self._dispatch_states(states, event, data) return await self._dispatch_states(states, event, data)
async def dispatch(self, user_id: int, event, data): async def dispatch(self, user_id: int, event, data):
return await self.dispatch_filter( return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)
user_id,
lambda sess_id: True,
event, data,
)

View File

@ -28,6 +28,7 @@ import time
class RatelimitBucket: class RatelimitBucket:
"""Main ratelimit bucket class.""" """Main ratelimit bucket class."""
def __init__(self, tokens, second): def __init__(self, tokens, second):
self.requests = tokens self.requests = tokens
self.second = second self.second = second
@ -88,17 +89,19 @@ class RatelimitBucket:
Used to manage multiple ratelimits to users. Used to manage multiple ratelimits to users.
""" """
return RatelimitBucket(self.requests, return RatelimitBucket(self.requests, self.second)
self.second)
def __repr__(self): def __repr__(self):
return (f'<RatelimitBucket requests={self.requests} ' return (
f'second={self.second} window: {self._window} ' f"<RatelimitBucket requests={self.requests} "
f'tokens={self._tokens}>') f"second={self.second} window: {self._window} "
f"tokens={self._tokens}>"
)
class Ratelimit: class Ratelimit:
"""Manages buckets.""" """Manages buckets."""
def __init__(self, tokens, second, keys=None): def __init__(self, tokens, second, keys=None):
self._cache = {} self._cache = {}
if keys is None: if keys is None:
@ -107,12 +110,11 @@ class Ratelimit:
self._cooldown = RatelimitBucket(tokens, second) self._cooldown = RatelimitBucket(tokens, second)
def __repr__(self): def __repr__(self):
return (f'<Ratelimit cooldown={self._cooldown}>') return f"<Ratelimit cooldown={self._cooldown}>"
def _verify_cache(self): def _verify_cache(self):
current = time.time() current = time.time()
dead_keys = [k for k, v in self._cache.items() dead_keys = [k for k, v in self._cache.items() if current > v._last + v.second]
if current > v._last + v.second]
for k in dead_keys: for k in dead_keys:
del self._cache[k] del self._cache[k]

View File

@ -31,10 +31,10 @@ async def _check_bucket(bucket):
if retry_after: if retry_after:
request.retry_after = retry_after request.retry_after = retry_after
raise Ratelimited('You are being rate limited.', { raise Ratelimited(
'retry_after': int(retry_after * 1000), "You are being rate limited.",
'global': request.bucket_global, {"retry_after": int(retry_after * 1000), "global": request.bucket_global},
}) )
async def _handle_global(ratelimit): async def _handle_global(ratelimit):
@ -59,13 +59,13 @@ async def _handle_specific(ratelimit):
keys = ratelimit.keys keys = ratelimit.keys
# base key is the user id # base key is the user id
key_components = [f'user_id:{user_id}'] key_components = [f"user_id:{user_id}"]
for key in keys: for key in keys:
val = request.view_args[key] val = request.view_args[key]
key_components.append(f'{key}:{val}') key_components.append(f"{key}:{val}")
bucket_key = ':'.join(key_components) bucket_key = ":".join(key_components)
bucket = ratelimit.get_bucket(bucket_key) bucket = ratelimit.get_bucket(bucket_key)
await _check_bucket(bucket) await _check_bucket(bucket)
@ -78,9 +78,7 @@ async def ratelimit_handler():
rule = request.url_rule rule = request.url_rule
if rule is None: if rule is None:
return await _handle_global( return await _handle_global(app.ratelimiter.global_bucket)
app.ratelimiter.global_bucket
)
# rule.endpoint is composed of '<blueprint>.<function>' # rule.endpoint is composed of '<blueprint>.<function>'
# and so we can use that to make routes with different # and so we can use that to make routes with different
@ -97,6 +95,4 @@ async def ratelimit_handler():
ratelimit = app.ratelimiter.get_ratelimit(rule_path) ratelimit = app.ratelimiter.get_ratelimit(rule_path)
await _handle_specific(ratelimit) await _handle_specific(ratelimit)
except KeyError: except KeyError:
await _handle_global( await _handle_global(app.ratelimiter.global_bucket)
app.ratelimiter.global_bucket
)

View File

@ -34,33 +34,30 @@ WS:
|All Sent Messages| | 120/60s | per-session |All Sent Messages| | 120/60s | per-session
""" """
REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id')) REACTION_BUCKET = Ratelimit(1, 0.25, ("channel_id"))
RATELIMITS = { RATELIMITS = {
'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')), "channel_messages.create_message": Ratelimit(5, 5, ("channel_id")),
'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')), "channel_messages.delete_message": Ratelimit(5, 1, ("channel_id")),
# all of those share the same bucket. # all of those share the same bucket.
'channel_reactions.add_reaction': REACTION_BUCKET, "channel_reactions.add_reaction": REACTION_BUCKET,
'channel_reactions.remove_own_reaction': REACTION_BUCKET, "channel_reactions.remove_own_reaction": REACTION_BUCKET,
'channel_reactions.remove_user_reaction': REACTION_BUCKET, "channel_reactions.remove_user_reaction": REACTION_BUCKET,
"guild_members.modify_guild_member": Ratelimit(10, 10, ("guild_id")),
'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')), "guild_members.update_nickname": Ratelimit(1, 1, ("guild_id")),
'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')),
# this only applies to username. # this only applies to username.
# 'users.patch_me': Ratelimit(2, 3600), # 'users.patch_me': Ratelimit(2, 3600),
"_ws.connect": Ratelimit(1, 5),
'_ws.connect': Ratelimit(1, 5), "_ws.presence": Ratelimit(5, 60),
'_ws.presence': Ratelimit(5, 60), "_ws.messages": Ratelimit(120, 60),
'_ws.messages': Ratelimit(120, 60),
# 1000 / 4h for new session issuing # 1000 / 4h for new session issuing
'_ws.session': Ratelimit(1000, 14400) "_ws.session": Ratelimit(1000, 14400),
} }
class RatelimitManager: class RatelimitManager:
"""Manager for the bucket managers""" """Manager for the bucket managers"""
def __init__(self, testing_flag=False): def __init__(self, testing_flag=False):
self._ratelimiters = {} self._ratelimiters = {}
self._test = testing_flag self._test = testing_flag
@ -74,9 +71,7 @@ class RatelimitManager:
# NOTE: this is a bad way to do it, but # NOTE: this is a bad way to do it, but
# we only need to change that one for now. # we only need to change that one for now.
rtl = (Ratelimit(10, 1) rtl = Ratelimit(10, 1) if self._test and path == "_ws.connect" else rtl
if self._test and path == '_ws.connect'
else rtl)
self._ratelimiters[path] = rtl self._ratelimiters[path] = rtl

View File

@ -28,26 +28,30 @@ from .errors import BadRequest
from .permissions import Permissions from .permissions import Permissions
from .types import Color from .types import Color
from .enums import ( from .enums import (
ActivityType, StatusType, ExplicitFilter, RelationshipType, ActivityType,
MessageNotifications, ChannelType, VerificationLevel StatusType,
ExplicitFilter,
RelationshipType,
MessageNotifications,
ChannelType,
VerificationLevel,
) )
from litecord.embed.schemas import EMBED_OBJECT, EmbedURL from litecord.embed.schemas import EMBED_OBJECT, EmbedURL
log = Logger(__name__) log = Logger(__name__)
USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_ ]{2,30}$', re.A) USERNAME_REGEX = re.compile(r"^[a-zA-Z0-9_ ]{2,30}$", re.A)
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$', EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$", re.A)
re.A) DATA_REGEX = re.compile(r"data\:image/(png|jpeg|gif);base64,(.+)", re.A)
DATA_REGEX = re.compile(r'data\:image/(png|jpeg|gif);base64,(.+)', re.A)
# collection of regexes # collection of regexes
USER_MENTION = re.compile(r'<@!?(\d+)>', re.A | re.M) USER_MENTION = re.compile(r"<@!?(\d+)>", re.A | re.M)
CHAN_MENTION = re.compile(r'<#(\d+)>', re.A | re.M) CHAN_MENTION = re.compile(r"<#(\d+)>", re.A | re.M)
ROLE_MENTION = re.compile(r'<@&(\d+)>', re.A | re.M) ROLE_MENTION = re.compile(r"<@&(\d+)>", re.A | re.M)
EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M) EMOJO_MENTION = re.compile(r"<:(\.+):(\d+)>", re.A | re.M)
ANIMOJI_MENTION = re.compile(r'<a:(\.+):(\d+)>', re.A | re.M) ANIMOJI_MENTION = re.compile(r"<a:(\.+):(\d+)>", re.A | re.M)
def _in_enum(enum, value) -> bool: def _in_enum(enum, value) -> bool:
@ -61,6 +65,7 @@ def _in_enum(enum, value) -> bool:
class LitecordValidator(Validator): class LitecordValidator(Validator):
"""Main validator class for Litecord, containing custom types.""" """Main validator class for Litecord, containing custom types."""
def _validate_type_username(self, value: str) -> bool: def _validate_type_username(self, value: str) -> bool:
"""Validate against the username regex.""" """Validate against the username regex."""
return bool(USERNAME_REGEX.match(value)) return bool(USERNAME_REGEX.match(value))
@ -130,8 +135,7 @@ class LitecordValidator(Validator):
return False return False
# nobody is allowed to use the INCOMING and OUTGOING rel types # nobody is allowed to use the INCOMING and OUTGOING rel types
return val in (RelationshipType.FRIEND.value, return val in (RelationshipType.FRIEND.value, RelationshipType.BLOCK.value)
RelationshipType.BLOCK.value)
def _validate_type_msg_notifications(self, value: str): def _validate_type_msg_notifications(self, value: str):
try: try:
@ -152,14 +156,15 @@ class LitecordValidator(Validator):
return self._validate_type_guild_name(value) return self._validate_type_guild_name(value)
def _validate_type_theme(self, value: str) -> bool: def _validate_type_theme(self, value: str) -> bool:
return value in ['light', 'dark'] return value in ["light", "dark"]
def _validate_type_nickname(self, value: str) -> bool: def _validate_type_nickname(self, value: str) -> bool:
return isinstance(value, str) and (len(value) < 32) return isinstance(value, str) and (len(value) < 32)
def validate(reqjson: Optional[Union[Dict, List]], schema: Dict, def validate(
raise_err: bool = True) -> Dict: reqjson: Optional[Union[Dict, List]], schema: Dict, raise_err: bool = True
) -> Dict:
"""Validate the given user-given data against a schema, giving the """Validate the given user-given data against a schema, giving the
"correct" version of the document, with all defaults applied. "correct" version of the document, with all defaults applied.
@ -176,20 +181,20 @@ def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
validator = LitecordValidator(schema) validator = LitecordValidator(schema)
if reqjson is None: if reqjson is None:
raise BadRequest('No JSON provided') raise BadRequest("No JSON provided")
try: try:
valid = validator.validate(reqjson) valid = validator.validate(reqjson)
except Exception: except Exception:
log.exception('Error while validating') log.exception("Error while validating")
raise Exception(f'Error while validating: {reqjson}') raise Exception(f"Error while validating: {reqjson}")
if not valid: if not valid:
errs = validator.errors errs = validator.errors
log.warning('Error validating doc {!r}: {!r}', reqjson, errs) log.warning("Error validating doc {!r}: {!r}", reqjson, errs)
if raise_err: if raise_err:
raise BadRequest('bad payload', errs) raise BadRequest("bad payload", errs)
return None return None
@ -197,554 +202,441 @@ def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
REGISTER = { REGISTER = {
'username': {'type': 'username', 'required': True}, "username": {"type": "username", "required": True},
'email': {'type': 'email', 'required': False}, "email": {"type": "email", "required": False},
'password': {'type': 'password', 'required': False}, "password": {"type": "password", "required": False},
# invite stands for a guild invite, not an instance invite (that's on # invite stands for a guild invite, not an instance invite (that's on
# the register_with_invite handler). # the register_with_invite handler).
'invite': {'type': 'string', 'required': False, 'nullable': True}, "invite": {"type": "string", "required": False, "nullable": True},
# following fields only sent by official client, unused by us # following fields only sent by official client, unused by us
'fingerprint': {'type': 'string', 'required': False, 'nullable': True}, "fingerprint": {"type": "string", "required": False, "nullable": True},
'captcha_key': {'type': 'string', 'required': False, 'nullable': True}, "captcha_key": {"type": "string", "required": False, "nullable": True},
'gift_code_sku_id': {'type': 'string', 'required': False, 'nullable': True}, "gift_code_sku_id": {"type": "string", "required": False, "nullable": True},
'consent': {'type': 'boolean', 'required': False}, "consent": {"type": "boolean", "required": False},
} }
# only used by us, not discord, hence 'invcode' (to separate from discord) # only used by us, not discord, hence 'invcode' (to separate from discord)
REGISTER_WITH_INVITE = {**REGISTER, **{ REGISTER_WITH_INVITE = {**REGISTER, **{"invcode": {"type": "string", "required": True}}}
'invcode': {'type': 'string', 'required': True}
}}
USER_UPDATE = { USER_UPDATE = {
'username': { "username": {
'type': 'username', 'minlength': 2, "type": "username",
'maxlength': 30, 'required': False}, "minlength": 2,
"maxlength": 30,
'discriminator': { "required": False,
'type': 'discriminator',
'required': False,
'nullable': True,
}, },
"discriminator": {"type": "discriminator", "required": False, "nullable": True},
'password': { "password": {"type": "password", "required": False},
'type': 'password', 'required': False, "new_password": {
"type": "password",
"required": False,
"dependencies": "password",
"nullable": True,
}, },
"email": {"type": "email", "required": False, "dependencies": "password"},
'new_password': { "avatar": {
'type': 'password', 'required': False,
'dependencies': 'password', 'nullable': True
},
'email': {
'type': 'email', 'required': False, 'dependencies': 'password',
},
'avatar': {
# can be both b64_icon or string (just the hash) # can be both b64_icon or string (just the hash)
'type': 'string', 'required': False, "type": "string",
'nullable': True "required": False,
"nullable": True,
}, },
} }
PARTIAL_ROLE_GUILD_CREATE = { PARTIAL_ROLE_GUILD_CREATE = {
'type': 'dict', "type": "dict",
'schema': { "schema": {
'name': {'type': 'role_name'}, "name": {"type": "role_name"},
'color': {'type': 'number', 'default': 0}, "color": {"type": "number", "default": 0},
'hoist': {'type': 'boolean', 'default': False}, "hoist": {"type": "boolean", "default": False},
# NOTE: no position on partial role (on guild create) # NOTE: no position on partial role (on guild create)
"permissions": {"coerce": Permissions, "required": False},
'permissions': {'coerce': Permissions, 'required': False}, "mentionable": {"type": "boolean", "default": False},
'mentionable': {'type': 'boolean', 'default': False}, },
}
} }
PARTIAL_CHANNEL_GUILD_CREATE = { PARTIAL_CHANNEL_GUILD_CREATE = {
'type': 'dict', "type": "dict",
'schema': { "schema": {"name": {"type": "channel_name"}, "type": {"type": "channel_type"}},
'name': {'type': 'channel_name'},
'type': {'type': 'channel_type'},
}
} }
GUILD_CREATE = { GUILD_CREATE = {
'name': {'type': 'guild_name'}, "name": {"type": "guild_name"},
'region': {'type': 'voice_region', 'nullable': True}, "region": {"type": "voice_region", "nullable": True},
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True}, "icon": {"type": "b64_icon", "required": False, "nullable": True},
"verification_level": {"type": "verification_level", "default": 0},
'verification_level': { "default_message_notifications": {"type": "msg_notifications", "default": 0},
'type': 'verification_level', 'default': 0}, "explicit_content_filter": {"type": "explicit", "default": 0},
'default_message_notifications': { "roles": {"type": "list", "required": False, "schema": PARTIAL_ROLE_GUILD_CREATE},
'type': 'msg_notifications', 'default': 0}, "channels": {"type": "list", "default": [], "schema": PARTIAL_CHANNEL_GUILD_CREATE},
'explicit_content_filter': {
'type': 'explicit', 'default': 0},
'roles': {
'type': 'list', 'required': False,
'schema': PARTIAL_ROLE_GUILD_CREATE},
'channels': {
'type': 'list', 'default': [], 'schema': PARTIAL_CHANNEL_GUILD_CREATE},
} }
GUILD_UPDATE = { GUILD_UPDATE = {
'name': { "name": {"type": "guild_name", "required": False},
'type': 'guild_name', "region": {"type": "voice_region", "required": False, "nullable": True},
'required': False
},
'region': {'type': 'voice_region', 'required': False, 'nullable': True},
# all three can have hashes # all three can have hashes
'icon': {'type': 'string', 'required': False, 'nullable': True}, "icon": {"type": "string", "required": False, "nullable": True},
'banner': {'type': 'string', 'required': False, 'nullable': True}, "banner": {"type": "string", "required": False, "nullable": True},
'splash': {'type': 'string', 'required': False, 'nullable': True}, "splash": {"type": "string", "required": False, "nullable": True},
"description": {
'description': { "type": "string",
'type': 'string', 'required': False, "required": False,
'minlength': 1, 'maxlength': 120, "minlength": 1,
'nullable': True "maxlength": 120,
"nullable": True,
}, },
"verification_level": {"type": "verification_level", "required": False},
'verification_level': { "default_message_notifications": {"type": "msg_notifications", "required": False},
'type': 'verification_level', 'required': False}, "explicit_content_filter": {"type": "explicit", "required": False},
'default_message_notifications': { "afk_channel_id": {"type": "snowflake", "required": False, "nullable": True},
'type': 'msg_notifications', 'required': False}, "afk_timeout": {"type": "number", "required": False},
'explicit_content_filter': {'type': 'explicit', 'required': False}, "owner_id": {"type": "snowflake", "required": False},
"system_channel_id": {"type": "snowflake", "required": False, "nullable": True},
'afk_channel_id': {
'type': 'snowflake', 'required': False, 'nullable': True},
'afk_timeout': {'type': 'number', 'required': False},
'owner_id': {'type': 'snowflake', 'required': False},
'system_channel_id': {
'type': 'snowflake', 'required': False, 'nullable': True},
} }
CHAN_OVERWRITE = { CHAN_OVERWRITE = {
'id': {'coerce': int}, "id": {"coerce": int},
'type': {'type': 'string', 'allowed': ['role', 'member']}, "type": {"type": "string", "allowed": ["role", "member"]},
'allow': {'coerce': Permissions}, "allow": {"coerce": Permissions},
'deny': {'coerce': Permissions} "deny": {"coerce": Permissions},
} }
CHAN_CREATE = { CHAN_CREATE = {
'name': { "name": {"type": "string", "minlength": 2, "maxlength": 100, "required": True},
'type': 'string', 'minlength': 2, "type": {"type": "channel_type", "default": ChannelType.GUILD_TEXT.value},
'maxlength': 100, 'required': True "position": {"coerce": int, "required": False},
}, "topic": {"type": "string", "minlength": 0, "maxlength": 1024, "required": False},
"nsfw": {"type": "boolean", "required": False},
'type': {'type': 'channel_type', "rate_limit_per_user": {"coerce": int, "min": 0, "max": 120, "required": False},
'default': ChannelType.GUILD_TEXT.value}, "bitrate": {
"coerce": int,
'position': {'coerce': int, 'required': False}, "min": 8000,
'topic': {
'type': 'string', 'minlength': 0,
'maxlength': 1024, 'required': False},
'nsfw': {'type': 'boolean', 'required': False},
'rate_limit_per_user': {
'coerce': int, 'min': 0,
'max': 120, 'required': False},
'bitrate': {
'coerce': int, 'min': 8000,
# NOTE: 'max' is 96000 for non-vip guilds # NOTE: 'max' is 96000 for non-vip guilds
'max': 128000, 'required': False}, "max": 128000,
"required": False,
'user_limit': { },
"user_limit": {
# user_limit being 0 means infinite. # user_limit being 0 means infinite.
'coerce': int, 'min': 0, "coerce": int,
'max': 99, 'required': False "min": 0,
"max": 99,
"required": False,
}, },
"permission_overwrites": {
'permission_overwrites': { "type": "list",
'type': 'list', "schema": {"type": "dict", "schema": CHAN_OVERWRITE},
'schema': {'type': 'dict', 'schema': CHAN_OVERWRITE}, "required": False,
'required': False
}, },
"parent_id": {"coerce": int, "required": False, "nullable": True},
'parent_id': {'coerce': int, 'required': False, 'nullable': True}
} }
CHAN_UPDATE = {**CHAN_CREATE, **{ CHAN_UPDATE = {
'name': { **CHAN_CREATE,
'type': 'string', 'minlength': 2, **{"name": {"type": "string", "minlength": 2, "maxlength": 100, "required": False}},
'maxlength': 100, 'required': False}, }
}}
ROLE_CREATE = { ROLE_CREATE = {
'name': {'type': 'string', 'default': 'new role'}, "name": {"type": "string", "default": "new role"},
'permissions': {'coerce': Permissions, 'nullable': True}, "permissions": {"coerce": Permissions, "nullable": True},
'color': {'coerce': Color, 'default': 0}, "color": {"coerce": Color, "default": 0},
'hoist': {'type': 'boolean', 'default': False}, "hoist": {"type": "boolean", "default": False},
'mentionable': {'type': 'boolean', 'default': False}, "mentionable": {"type": "boolean", "default": False},
} }
ROLE_UPDATE = { ROLE_UPDATE = {
'name': {'type': 'string', 'required': False}, "name": {"type": "string", "required": False},
'permissions': {'coerce': Permissions, 'required': False}, "permissions": {"coerce": Permissions, "required": False},
'color': {'coerce': Color, 'required': False}, "color": {"coerce": Color, "required": False},
'hoist': {'type': 'boolean', 'required': False}, "hoist": {"type": "boolean", "required": False},
'mentionable': {'type': 'boolean', 'required': False}, "mentionable": {"type": "boolean", "required": False},
} }
ROLE_UPDATE_POSITION = { ROLE_UPDATE_POSITION = {
'roles': { "roles": {
'type': 'list', "type": "list",
'schema': { "schema": {
'type': 'dict', "type": "dict",
'schema': { "schema": {"id": {"coerce": int}, "position": {"coerce": int}},
'id': {'coerce': int}, },
'position': {'coerce': int},
},
}
} }
} }
MEMBER_UPDATE = { MEMBER_UPDATE = {
'nick': { "nick": {"type": "nickname", "required": False},
'type': 'nickname', 'required': False}, "roles": {"type": "list", "required": False, "schema": {"coerce": int}},
'roles': {'type': 'list', 'required': False, "mute": {"type": "boolean", "required": False},
'schema': {'coerce': int}}, "deaf": {"type": "boolean", "required": False},
'mute': {'type': 'boolean', 'required': False}, "channel_id": {"type": "snowflake", "required": False},
'deaf': {'type': 'boolean', 'required': False},
'channel_id': {'type': 'snowflake', 'required': False},
} }
# NOTE: things such as payload_json are parsed at the handler # NOTE: things such as payload_json are parsed at the handler
# for creating a message. # for creating a message.
MESSAGE_CREATE = { MESSAGE_CREATE = {
'content': {'type': 'string', 'minlength': 0, 'maxlength': 2000}, "content": {"type": "string", "minlength": 0, "maxlength": 2000},
'nonce': {'type': 'snowflake', 'required': False}, "nonce": {"type": "snowflake", "required": False},
'tts': {'type': 'boolean', 'required': False}, "tts": {"type": "boolean", "required": False},
"embed": {
'embed': { "type": "dict",
'type': 'dict', "schema": EMBED_OBJECT,
'schema': EMBED_OBJECT, "required": False,
'required': False, "nullable": True,
'nullable': True },
}
} }
GW_ACTIVITY = { GW_ACTIVITY = {
'name': {'type': 'string', 'required': True}, "name": {"type": "string", "required": True},
'type': {'type': 'activity_type', 'required': True}, "type": {"type": "activity_type", "required": True},
"url": {"type": "string", "required": False, "nullable": True},
'url': {'type': 'string', 'required': False, 'nullable': True}, "timestamps": {
"type": "dict",
'timestamps': { "required": False,
'type': 'dict', "schema": {
'required': False, "start": {"type": "number", "required": False},
'schema': { "end": {"type": "number", "required": False},
'start': {'type': 'number', 'required': False},
'end': {'type': 'number', 'required': False},
}, },
}, },
"application_id": {"type": "snowflake", "required": False, "nullable": False},
'application_id': {'type': 'snowflake', 'required': False, "details": {"type": "string", "required": False, "nullable": True},
'nullable': False}, "state": {"type": "string", "required": False, "nullable": True},
'details': {'type': 'string', 'required': False, 'nullable': True}, "party": {
'state': {'type': 'string', 'required': False, 'nullable': True}, "type": "dict",
"required": False,
'party': { "schema": {
'type': 'dict', "id": {"type": "snowflake", "required": False},
'required': False, "size": {"type": "list", "required": False},
'schema': { },
'id': {'type': 'snowflake', 'required': False},
'size': {'type': 'list', 'required': False},
}
}, },
"assets": {
'assets': { "type": "dict",
'type': 'dict', "required": False,
'required': False, "schema": {
'schema': { "large_image": {"type": "snowflake", "required": False},
'large_image': {'type': 'snowflake', 'required': False}, "large_text": {"type": "string", "required": False},
'large_text': {'type': 'string', 'required': False}, "small_image": {"type": "snowflake", "required": False},
'small_image': {'type': 'snowflake', 'required': False}, "small_text": {"type": "string", "required": False},
'small_text': {'type': 'string', 'required': False}, },
}
}, },
"secrets": {
'secrets': { "type": "dict",
'type': 'dict', "required": False,
'required': False, "schema": {
'schema': { "join": {"type": "string", "required": False},
'join': {'type': 'string', 'required': False}, "spectate": {"type": "string", "required": False},
'spectate': {'type': 'string', 'required': False}, "match": {"type": "string", "required": False},
'match': {'type': 'string', 'required': False}, },
}
}, },
"instance": {"type": "boolean", "required": False},
'instance': {'type': 'boolean', 'required': False}, "flags": {"type": "number", "required": False},
'flags': {'type': 'number', 'required': False},
} }
GW_STATUS_UPDATE = { GW_STATUS_UPDATE = {
'status': {'type': 'status_external', 'required': False, "status": {"type": "status_external", "required": False, "default": "online"},
'default': 'online'}, "activities": {
'activities': { "type": "list",
'type': 'list', 'required': False, "required": False,
'schema': {'type': 'dict', 'schema': GW_ACTIVITY} "schema": {"type": "dict", "schema": GW_ACTIVITY},
}, },
'afk': {'type': 'boolean', 'required': False}, "afk": {"type": "boolean", "required": False},
"since": {"type": "number", "required": False, "nullable": True},
'since': {'type': 'number', 'required': False, 'nullable': True}, "game": {
'game': { "type": "dict",
'type': 'dict', "required": False,
'required': False, "nullable": True,
'nullable': True, "schema": GW_ACTIVITY,
'schema': GW_ACTIVITY,
}, },
} }
INVITE = { INVITE = {
# max_age in seconds # max_age in seconds
# 0 for infinite # 0 for infinite
'max_age': { "max_age": {
'type': 'number', "type": "number",
'min': 0, "min": 0,
'max': 86400, "max": 86400,
# a day # a day
'default': 86400 "default": 86400,
}, },
# max invite uses # max invite uses
'max_uses': { "max_uses": {
'type': 'number', "type": "number",
'min': 0, "min": 0,
# idk # idk
'max': 1000, "max": 1000,
# default infinite # default infinite
'default': 0 "default": 0,
}, },
"temporary": {"type": "boolean", "required": False, "default": False},
'temporary': {'type': 'boolean', 'required': False, 'default': False}, "unique": {"type": "boolean", "required": False, "default": True},
'unique': {'type': 'boolean', 'required': False, 'default': True}, "validate": {
'validate': {'type': 'string', 'required': False, 'nullable': True} # discord client sends invite code there "type": "string",
"required": False,
"nullable": True,
}, # discord client sends invite code there
} }
USER_SETTINGS = { USER_SETTINGS = {
'afk_timeout': { "afk_timeout": {"type": "number", "required": False, "min": 0, "max": 3000},
'type': 'number', 'required': False, 'min': 0, 'max': 3000}, "animate_emoji": {"type": "boolean", "required": False},
"convert_emoticons": {"type": "boolean", "required": False},
'animate_emoji': {'type': 'boolean', 'required': False}, "default_guilds_restricted": {"type": "boolean", "required": False},
'convert_emoticons': {'type': 'boolean', 'required': False}, "detect_platform_accounts": {"type": "boolean", "required": False},
'default_guilds_restricted': {'type': 'boolean', 'required': False}, "developer_mode": {"type": "boolean", "required": False},
'detect_platform_accounts': {'type': 'boolean', 'required': False}, "disable_games_tab": {"type": "boolean", "required": False},
'developer_mode': {'type': 'boolean', 'required': False}, "enable_tts_command": {"type": "boolean", "required": False},
'disable_games_tab': {'type': 'boolean', 'required': False}, "explicit_content_filter": {"type": "explicit", "required": False},
'enable_tts_command': {'type': 'boolean', 'required': False}, "friend_source": {
"type": "dict",
'explicit_content_filter': {'type': 'explicit', 'required': False}, "required": False,
"schema": {
'friend_source': { "all": {"type": "boolean", "required": False},
'type': 'dict', "mutual_guilds": {"type": "boolean", "required": False},
'required': False, "mutual_friends": {"type": "boolean", "required": False},
'schema': { },
'all': {'type': 'boolean', 'required': False},
'mutual_guilds': {'type': 'boolean', 'required': False},
'mutual_friends': {'type': 'boolean', 'required': False},
}
}, },
'guild_positions': { "guild_positions": {
'type': 'list', "type": "list",
'required': False, "required": False,
'schema': {'type': 'snowflake'} "schema": {"type": "snowflake"},
}, },
'restricted_guilds': { "restricted_guilds": {
'type': 'list', "type": "list",
'required': False, "required": False,
'schema': {'type': 'snowflake'} "schema": {"type": "snowflake"},
}, },
"gif_auto_play": {"type": "boolean", "required": False},
'gif_auto_play': {'type': 'boolean', 'required': False}, "inline_attachment_media": {"type": "boolean", "required": False},
'inline_attachment_media': {'type': 'boolean', 'required': False}, "inline_embed_media": {"type": "boolean", "required": False},
'inline_embed_media': {'type': 'boolean', 'required': False}, "message_display_compact": {"type": "boolean", "required": False},
'message_display_compact': {'type': 'boolean', 'required': False}, "render_embeds": {"type": "boolean", "required": False},
'render_embeds': {'type': 'boolean', 'required': False}, "render_reactions": {"type": "boolean", "required": False},
'render_reactions': {'type': 'boolean', 'required': False}, "show_current_game": {"type": "boolean", "required": False},
'show_current_game': {'type': 'boolean', 'required': False}, "timezone_offset": {"type": "number", "required": False},
"status": {"type": "status_external", "required": False},
'timezone_offset': {'type': 'number', 'required': False}, "theme": {"type": "theme", "required": False},
'status': {'type': 'status_external', 'required': False},
'theme': {'type': 'theme', 'required': False}
} }
RELATIONSHIP = { RELATIONSHIP = {
'type': { "type": {
'type': 'rel_type', "type": "rel_type",
'required': False, "required": False,
'default': RelationshipType.FRIEND.value "default": RelationshipType.FRIEND.value,
} }
} }
CREATE_DM = { CREATE_DM = {"recipient_id": {"type": "snowflake", "required": True}}
'recipient_id': {
'type': 'snowflake',
'required': True
}
}
CREATE_GROUP_DM = { CREATE_GROUP_DM = {
'recipients': { "recipients": {"type": "list", "required": True, "schema": {"type": "snowflake"}}
'type': 'list',
'required': True,
'schema': {'type': 'snowflake'}
},
} }
GROUP_DM_UPDATE = { GROUP_DM_UPDATE = {
'name': { "name": {"type": "guild_name", "required": False},
'type': 'guild_name', "icon": {"type": "b64_icon", "required": False, "nullable": True},
'required': False
},
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
} }
SPECIFIC_FRIEND = { SPECIFIC_FRIEND = {
'username': {'type': 'username'}, "username": {"type": "username"},
'discriminator': {'type': 'discriminator'} "discriminator": {"type": "discriminator"},
} }
GUILD_SETTINGS_CHAN_OVERRIDE = { GUILD_SETTINGS_CHAN_OVERRIDE = {
'type': 'dict', "type": "dict",
'schema': { "schema": {
'muted': { "muted": {"type": "boolean", "required": False},
'type': 'boolean', 'required': False}, "message_notifications": {"type": "msg_notifications", "required": False},
'message_notifications': { },
'type': 'msg_notifications',
'required': False,
}
}
} }
GUILD_SETTINGS = { GUILD_SETTINGS = {
'channel_overrides': { "channel_overrides": {
'type': 'dict', "type": "dict",
'valueschema': GUILD_SETTINGS_CHAN_OVERRIDE, "valueschema": GUILD_SETTINGS_CHAN_OVERRIDE,
'keyschema': {'type': 'snowflake'}, "keyschema": {"type": "snowflake"},
'required': False, "required": False,
}, },
'suppress_everyone': { "suppress_everyone": {"type": "boolean", "required": False},
'type': 'boolean', 'required': False}, "muted": {"type": "boolean", "required": False},
'muted': { "mobile_push": {"type": "boolean", "required": False},
'type': 'boolean', 'required': False}, "message_notifications": {"type": "msg_notifications", "required": False},
'mobile_push': {
'type': 'boolean', 'required': False},
'message_notifications': {
'type': 'msg_notifications',
'required': False,
}
} }
GUILD_PRUNE = { GUILD_PRUNE = {
'days': {'type': 'number', 'coerce': int, 'min': 1, 'max': 30, 'default': 7}, "days": {"type": "number", "coerce": int, "min": 1, "max": 30, "default": 7},
'compute_prune_count': {'type': 'string', 'default': 'true'} "compute_prune_count": {"type": "string", "default": "true"},
} }
NEW_EMOJI = { NEW_EMOJI = {
'name': { "name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True},
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True}, "image": {"type": "b64_icon", "required": True},
'image': {'type': 'b64_icon', 'required': True}, "roles": {"type": "list", "schema": {"coerce": int}},
'roles': {'type': 'list', 'schema': {'coerce': int}}
} }
PATCH_EMOJI = { PATCH_EMOJI = {
'name': { "name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True},
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True}, "roles": {"type": "list", "schema": {"coerce": int}},
'roles': {'type': 'list', 'schema': {'coerce': int}}
} }
SEARCH_CHANNEL = { SEARCH_CHANNEL = {
'content': {'type': 'string', 'minlength': 1, 'required': True}, "content": {"type": "string", "minlength": 1, "required": True},
'include_nsfw': {'coerce': bool, 'default': False}, "include_nsfw": {"coerce": bool, "default": False},
'offset': {'coerce': int, 'default': 0} "offset": {"coerce": int, "default": 0},
} }
GET_MENTIONS = { GET_MENTIONS = {
'limit': {'coerce': int, 'default': 25}, "limit": {"coerce": int, "default": 25},
'roles': {'coerce': bool, 'default': True}, "roles": {"coerce": bool, "default": True},
'everyone': {'coerce': bool, 'default': True}, "everyone": {"coerce": bool, "default": True},
'guild_id': {'coerce': int, 'required': False} "guild_id": {"coerce": int, "required": False},
} }
VANITY_URL_PATCH = { VANITY_URL_PATCH = {
# TODO: put proper values in maybe an invite data type # TODO: put proper values in maybe an invite data type
'code': {'type': 'string', 'minlength': 5, 'maxlength': 30} "code": {"type": "string", "minlength": 5, "maxlength": 30}
} }
WEBHOOK_CREATE = { WEBHOOK_CREATE = {
'name': { "name": {"type": "string", "minlength": 2, "maxlength": 32, "required": True},
'type': 'string', 'minlength': 2, 'maxlength': 32, "avatar": {"type": "b64_icon", "required": False, "nullable": False},
'required': True
},
'avatar': {'type': 'b64_icon', 'required': False, 'nullable': False}
} }
WEBHOOK_UPDATE = { WEBHOOK_UPDATE = {
'name': { "name": {"type": "string", "minlength": 2, "maxlength": 32, "required": False},
'type': 'string', 'minlength': 2, 'maxlength': 32,
'required': False
},
# TODO: check if its b64_icon or string since the client # TODO: check if its b64_icon or string since the client
# could pass an icon hash instead. # could pass an icon hash instead.
'avatar': {'type': 'b64_icon', 'required': False, 'nullable': False}, "avatar": {"type": "b64_icon", "required": False, "nullable": False},
'channel_id': {'coerce': int, 'required': False, 'nullable': False} "channel_id": {"coerce": int, "required": False, "nullable": False},
} }
WEBHOOK_MESSAGE_CREATE = { WEBHOOK_MESSAGE_CREATE = {
'content': { "content": {"type": "string", "minlength": 0, "maxlength": 2000, "required": False},
'type': 'string', "tts": {"type": "boolean", "required": False},
'minlength': 0, 'maxlength': 2000, 'required': False "username": {"type": "string", "minlength": 2, "maxlength": 32, "required": False},
"avatar_url": {"coerce": EmbedURL, "required": False},
"embeds": {
"type": "list",
"required": False,
"schema": {"type": "dict", "schema": EMBED_OBJECT},
}, },
'tts': {'type': 'boolean', 'required': False},
'username': {
'type': 'string',
'minlength': 2, 'maxlength': 32, 'required': False
},
'avatar_url': {
'coerce': EmbedURL, 'required': False
},
'embeds': {
'type': 'list',
'required': False,
'schema': {'type': 'dict', 'schema': EMBED_OBJECT}
}
} }
BULK_DELETE = { BULK_DELETE = {
'messages': { "messages": {
'type': 'list', 'required': True, "type": "list",
'minlength': 2, 'maxlength': 100, "required": True,
'schema': {'coerce': int} "minlength": 2,
"maxlength": 100,
"schema": {"coerce": int},
} }
} }

View File

@ -61,19 +61,19 @@ def _snowflake(timestamp: int) -> Snowflake:
# bits 0-12 encode _generated_ids (size 12) # bits 0-12 encode _generated_ids (size 12)
# modulo'd to prevent overflows # modulo'd to prevent overflows
genid_b = '{0:012b}'.format(_generated_ids % 4096) genid_b = "{0:012b}".format(_generated_ids % 4096)
# bits 12-17 encode PROCESS_ID (size 5) # bits 12-17 encode PROCESS_ID (size 5)
procid_b = '{0:05b}'.format(PROCESS_ID) procid_b = "{0:05b}".format(PROCESS_ID)
# bits 17-22 encode WORKER_ID (size 5) # bits 17-22 encode WORKER_ID (size 5)
workid_b = '{0:05b}'.format(WORKER_ID) workid_b = "{0:05b}".format(WORKER_ID)
# bits 22-64 encode (timestamp - EPOCH) (size 42) # bits 22-64 encode (timestamp - EPOCH) (size 42)
epochized = timestamp - EPOCH epochized = timestamp - EPOCH
epoch_b = '{0:042b}'.format(epochized) epoch_b = "{0:042b}".format(epochized)
snowflake_b = f'{epoch_b}{workid_b}{procid_b}{genid_b}' snowflake_b = f"{epoch_b}{workid_b}{procid_b}{genid_b}"
_generated_ids += 1 _generated_ids += 1
return int(snowflake_b, 2) return int(snowflake_b, 2)
@ -87,7 +87,7 @@ def snowflake_time(snowflake: Snowflake) -> float:
# the total size for a snowflake is 64 bits, # the total size for a snowflake is 64 bits,
# considering it is a string, position 0 to 42 will give us # considering it is a string, position 0 to 42 will give us
# the `epochized` variable # the `epochized` variable
snowflake_b = '{0:064b}'.format(snowflake) snowflake_b = "{0:064b}".format(snowflake)
epochized_b = snowflake_b[:42] epochized_b = snowflake_b[:42]
epochized = int(epochized_b, 2) epochized = int(epochized_b, 2)

File diff suppressed because it is too large Load Diff

View File

@ -24,6 +24,7 @@ from litecord.enums import MessageType
log = Logger(__name__) log = Logger(__name__)
async def _handle_pin_msg(app, channel_id, _pinned_id, author_id): async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
"""Handle a message pin.""" """Handle a message pin."""
new_id = get_snowflake() new_id = get_snowflake()
@ -37,8 +38,10 @@ async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
($1, $2, NULL, $3, NULL, '', ($1, $2, NULL, $3, NULL, '',
$4) $4)
""", """,
new_id, channel_id, author_id, new_id,
MessageType.CHANNEL_PINNED_MESSAGE.value channel_id,
author_id,
MessageType.CHANNEL_PINNED_MESSAGE.value,
) )
return new_id return new_id
@ -56,15 +59,16 @@ async def _handle_recp_add(app, channel_id, author_id, peer_id):
VALUES VALUES
($1, $2, $3, NULL, $4, $5) ($1, $2, $3, NULL, $4, $5)
""", """,
new_id, channel_id, author_id, new_id,
f'<@{peer_id}>', channel_id,
MessageType.RECIPIENT_ADD.value author_id,
f"<@{peer_id}>",
MessageType.RECIPIENT_ADD.value,
) )
return new_id return new_id
async def _handle_recp_rmv(app, channel_id, author_id, peer_id): async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
new_id = get_snowflake() new_id = get_snowflake()
@ -76,9 +80,11 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
VALUES VALUES
($1, $2, $3, NULL, $4, $5) ($1, $2, $3, NULL, $4, $5)
""", """,
new_id, channel_id, author_id, new_id,
f'<@{peer_id}>', channel_id,
MessageType.RECIPIENT_REMOVE.value author_id,
f"<@{peer_id}>",
MessageType.RECIPIENT_REMOVE.value,
) )
return new_id return new_id
@ -87,13 +93,16 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
async def _handle_gdm_name_edit(app, channel_id, author_id): async def _handle_gdm_name_edit(app, channel_id, author_id):
new_id = get_snowflake() new_id = get_snowflake()
gdm_name = await app.db.fetchval(""" gdm_name = await app.db.fetchval(
"""
SELECT name FROM group_dm_channels SELECT name FROM group_dm_channels
WHERE id = $1 WHERE id = $1
""", channel_id) """,
channel_id,
)
if not gdm_name: if not gdm_name:
log.warning('no gdm name found for sys message') log.warning("no gdm name found for sys message")
return return
await app.db.execute( await app.db.execute(
@ -104,9 +113,11 @@ async def _handle_gdm_name_edit(app, channel_id, author_id):
VALUES VALUES
($1, $2, $3, NULL, $4, $5) ($1, $2, $3, NULL, $4, $5)
""", """,
new_id, channel_id, author_id, new_id,
channel_id,
author_id,
gdm_name, gdm_name,
MessageType.CHANNEL_NAME_CHANGE.value MessageType.CHANNEL_NAME_CHANGE.value,
) )
return new_id return new_id
@ -123,16 +134,19 @@ async def _handle_gdm_icon_edit(app, channel_id, author_id):
VALUES VALUES
($1, $2, $3, NULL, $4, $5) ($1, $2, $3, NULL, $4, $5)
""", """,
new_id, channel_id, author_id, new_id,
'', channel_id,
MessageType.CHANNEL_ICON_CHANGE.value author_id,
"",
MessageType.CHANNEL_ICON_CHANGE.value,
) )
return new_id return new_id
async def send_sys_message(app, channel_id: int, m_type: MessageType, async def send_sys_message(
*args, **kwargs) -> int: app, channel_id: int, m_type: MessageType, *args, **kwargs
) -> int:
"""Send a system message. """Send a system message.
The handler for a given message type MUST return an integer, that integer The handler for a given message type MUST return an integer, that integer
@ -156,22 +170,19 @@ async def send_sys_message(app, channel_id: int, m_type: MessageType,
try: try:
handler = { handler = {
MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg, MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg,
# gdm specific # gdm specific
MessageType.RECIPIENT_ADD: _handle_recp_add, MessageType.RECIPIENT_ADD: _handle_recp_add,
MessageType.RECIPIENT_REMOVE: _handle_recp_rmv, MessageType.RECIPIENT_REMOVE: _handle_recp_rmv,
MessageType.CHANNEL_NAME_CHANGE: _handle_gdm_name_edit, MessageType.CHANNEL_NAME_CHANGE: _handle_gdm_name_edit,
MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit,
}[m_type] }[m_type]
except KeyError: except KeyError:
raise ValueError('Invalid system message type') raise ValueError("Invalid system message type")
message_id = await handler(app, channel_id, *args, **kwargs) message_id = await handler(app, channel_id, *args, **kwargs)
message = await app.storage.get_message(message_id) message = await app.storage.get_message(message_id)
await app.dispatcher.dispatch( await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", message)
'channel', channel_id, 'MESSAGE_CREATE', message
)
return message_id return message_id

View File

@ -29,6 +29,7 @@ HOURS = 60 * MINUTES
class Color: class Color:
"""Custom color class""" """Custom color class"""
def __init__(self, val: int): def __init__(self, val: int):
self.blue = val & 255 self.blue = val & 255
self.green = (val >> 8) & 255 self.green = (val >> 8) & 255
@ -37,7 +38,7 @@ class Color:
@property @property
def value(self): def value(self):
"""Give the actual RGB integer encoding this color.""" """Give the actual RGB integer encoding this color."""
return int('%02x%02x%02x' % (self.red, self.green, self.blue), 16) return int("%02x%02x%02x" % (self.red, self.green, self.blue), 16)
@property @property
def to_json(self): def to_json(self):
@ -49,4 +50,4 @@ class Color:
def timestamp_(dt) -> Optional[str]: def timestamp_(dt) -> Optional[str]:
"""safer version for dt.isoformat()""" """safer version for dt.isoformat()"""
return f'{dt.isoformat()}+00:00' if dt else None return f"{dt.isoformat()}+00:00" if dt else None

View File

@ -27,43 +27,52 @@ log = Logger(__name__)
class UserStorage: class UserStorage:
"""Storage functions related to a single user.""" """Storage functions related to a single user."""
def __init__(self, storage): def __init__(self, storage):
self.storage = storage self.storage = storage
self.db = storage.db self.db = storage.db
async def fetch_notes(self, user_id: int) -> dict: async def fetch_notes(self, user_id: int) -> dict:
"""Fetch a users' notes""" """Fetch a users' notes"""
note_rows = await self.db.fetch(""" note_rows = await self.db.fetch(
"""
SELECT target_id, note SELECT target_id, note
FROM notes FROM notes
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """,
user_id,
)
return {str(row['target_id']): row['note'] return {str(row["target_id"]): row["note"] for row in note_rows}
for row in note_rows}
async def get_user_settings(self, user_id: int) -> Dict[str, Any]: async def get_user_settings(self, user_id: int) -> Dict[str, Any]:
"""Get current user settings.""" """Get current user settings."""
row = await self.storage.fetchrow_with_json(""" row = await self.storage.fetchrow_with_json(
"""
SELECT * SELECT *
FROM user_settings FROM user_settings
WHERE id = $1 WHERE id = $1
""", user_id) """,
user_id,
)
if not row: if not row:
log.info('Generating user settings for {}', user_id) log.info("Generating user settings for {}", user_id)
await self.db.execute(""" await self.db.execute(
"""
INSERT INTO user_settings (id) INSERT INTO user_settings (id)
VALUES ($1) VALUES ($1)
""", user_id) """,
user_id,
)
# recalling get_user_settings # recalling get_user_settings
# should work after adding # should work after adding
return await self.get_user_settings(user_id) return await self.get_user_settings(user_id)
drow = dict(row) drow = dict(row)
drow.pop('id') drow.pop("id")
return drow return drow
async def get_relationships(self, user_id: int) -> List[Dict[str, Any]]: async def get_relationships(self, user_id: int) -> List[Dict[str, Any]]:
@ -76,11 +85,15 @@ class UserStorage:
_outgoing = RelationshipType.OUTGOING.value _outgoing = RelationshipType.OUTGOING.value
# check all outgoing friends # check all outgoing friends
friends = await self.db.fetch(""" friends = await self.db.fetch(
"""
SELECT user_id, peer_id, rel_type SELECT user_id, peer_id, rel_type
FROM relationships FROM relationships
WHERE user_id = $1 AND rel_type = $2 WHERE user_id = $1 AND rel_type = $2
""", user_id, _friend) """,
user_id,
_friend,
)
friends = list(map(dict, friends)) friends = list(map(dict, friends))
# mutuals is a list of ints # mutuals is a list of ints
@ -95,66 +108,80 @@ class UserStorage:
SELECT user_id, peer_id SELECT user_id, peer_id
FROM relationships FROM relationships
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3 WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
""", row['peer_id'], row['user_id'], """,
_friend) row["peer_id"],
row["user_id"],
_friend,
)
if is_friend is not None: if is_friend is not None:
mutuals.append(row['peer_id']) mutuals.append(row["peer_id"])
# fetch friend requests directed at us # fetch friend requests directed at us
incoming_friends = await self.db.fetch(""" incoming_friends = await self.db.fetch(
"""
SELECT user_id, peer_id SELECT user_id, peer_id
FROM relationships FROM relationships
WHERE peer_id = $1 AND rel_type = $2 WHERE peer_id = $1 AND rel_type = $2
""", user_id, _friend) """,
user_id,
_friend,
)
# only need their ids # only need their ids
incoming_friends = [r['user_id'] for r in incoming_friends incoming_friends = [
if r['user_id'] not in mutuals] r["user_id"] for r in incoming_friends if r["user_id"] not in mutuals
]
# only fetch blocks we did, # only fetch blocks we did,
# not fetching the ones people did to us # not fetching the ones people did to us
blocks = await self.db.fetch(""" blocks = await self.db.fetch(
"""
SELECT user_id, peer_id, rel_type SELECT user_id, peer_id, rel_type
FROM relationships FROM relationships
WHERE user_id = $1 AND rel_type = $2 WHERE user_id = $1 AND rel_type = $2
""", user_id, _block) """,
user_id,
_block,
)
blocks = list(map(dict, blocks)) blocks = list(map(dict, blocks))
res = [] res = []
for drow in friends: for drow in friends:
drow['type'] = drow['rel_type'] drow["type"] = drow["rel_type"]
drow['id'] = str(drow['peer_id']) drow["id"] = str(drow["peer_id"])
drow.pop('rel_type') drow.pop("rel_type")
# check if the receiver is a mutual # check if the receiver is a mutual
# if it isnt, its still on a friend request stage # if it isnt, its still on a friend request stage
if drow['peer_id'] not in mutuals: if drow["peer_id"] not in mutuals:
drow['type'] = _outgoing drow["type"] = _outgoing
drow['user'] = await self.storage.get_user(drow['peer_id']) drow["user"] = await self.storage.get_user(drow["peer_id"])
drow.pop('user_id') drow.pop("user_id")
drow.pop('peer_id') drow.pop("peer_id")
res.append(drow) res.append(drow)
for peer_id in incoming_friends: for peer_id in incoming_friends:
res.append({ res.append(
'id': str(peer_id), {
'user': await self.storage.get_user(peer_id), "id": str(peer_id),
'type': _incoming, "user": await self.storage.get_user(peer_id),
}) "type": _incoming,
}
)
for drow in blocks: for drow in blocks:
drow['type'] = drow['rel_type'] drow["type"] = drow["rel_type"]
drow.pop('rel_type') drow.pop("rel_type")
drow['id'] = str(drow['peer_id']) drow["id"] = str(drow["peer_id"])
drow['user'] = await self.storage.get_user(drow['peer_id']) drow["user"] = await self.storage.get_user(drow["peer_id"])
drow.pop('user_id') drow.pop("user_id")
drow.pop('peer_id') drow.pop("peer_id")
res.append(drow) res.append(drow)
return res return res
@ -163,9 +190,11 @@ class UserStorage:
"""Get all friend IDs for a user.""" """Get all friend IDs for a user."""
rels = await self.get_relationships(user_id) rels = await self.get_relationships(user_id)
return [int(r['user']['id']) return [
for r in rels int(r["user"]["id"])
if r['type'] == RelationshipType.FRIEND.value] for r in rels
if r["type"] == RelationshipType.FRIEND.value
]
async def get_dms(self, user_id: int) -> List[Dict[str, Any]]: async def get_dms(self, user_id: int) -> List[Dict[str, Any]]:
"""Get all DM channels for a user, including group DMs. """Get all DM channels for a user, including group DMs.
@ -173,13 +202,16 @@ class UserStorage:
This will only fetch channels the user has in their state, This will only fetch channels the user has in their state,
which is different than the whole list of DM channels. which is different than the whole list of DM channels.
""" """
dm_ids = await self.db.fetch(""" dm_ids = await self.db.fetch(
"""
SELECT dm_id SELECT dm_id
FROM dm_channel_state FROM dm_channel_state
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """,
user_id,
)
dm_ids = [r['dm_id'] for r in dm_ids] dm_ids = [r["dm_id"] for r in dm_ids]
res = [] res = []
@ -191,21 +223,24 @@ class UserStorage:
async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]: async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]:
"""Get the read state for a user.""" """Get the read state for a user."""
rows = await self.db.fetch(""" rows = await self.db.fetch(
"""
SELECT channel_id, last_message_id, mention_count SELECT channel_id, last_message_id, mention_count
FROM user_read_state FROM user_read_state
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """,
user_id,
)
res = [] res = []
for row in rows: for row in rows:
drow = dict(row) drow = dict(row)
drow['id'] = str(drow['channel_id']) drow["id"] = str(drow["channel_id"])
drow.pop('channel_id') drow.pop("channel_id")
drow['last_message_id'] = str(drow['last_message_id']) drow["last_message_id"] = str(drow["last_message_id"])
res.append(drow) res.append(drow)
@ -214,13 +249,17 @@ class UserStorage:
async def _get_chan_overrides(self, user_id: int, guild_id: int) -> List: async def _get_chan_overrides(self, user_id: int, guild_id: int) -> List:
chan_overrides = [] chan_overrides = []
overrides = await self.db.fetch(""" overrides = await self.db.fetch(
"""
SELECT channel_id::text, muted, message_notifications SELECT channel_id::text, muted, message_notifications
FROM guild_settings_channel_overrides FROM guild_settings_channel_overrides
WHERE WHERE
user_id = $1 user_id = $1
AND guild_id = $2 AND guild_id = $2
""", user_id, guild_id) """,
user_id,
guild_id,
)
for chan_row in overrides: for chan_row in overrides:
dcrow = dict(chan_row) dcrow = dict(chan_row)
@ -228,30 +267,35 @@ class UserStorage:
return chan_overrides return chan_overrides
async def get_guild_settings_one(self, user_id: int, async def get_guild_settings_one(self, user_id: int, guild_id: int) -> dict:
guild_id: int) -> dict:
"""Get guild settings information for a single guild.""" """Get guild settings information for a single guild."""
row = await self.db.fetchrow(""" row = await self.db.fetchrow(
"""
SELECT guild_id::text, suppress_everyone, muted, SELECT guild_id::text, suppress_everyone, muted,
message_notifications, mobile_push message_notifications, mobile_push
FROM guild_settings FROM guild_settings
WHERE user_id = $1 AND guild_id = $2 WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id) """,
user_id,
guild_id,
)
if not row: if not row:
await self.db.execute(""" await self.db.execute(
"""
INSERT INTO guild_settings (user_id, guild_id) INSERT INTO guild_settings (user_id, guild_id)
VALUES ($1, $2) VALUES ($1, $2)
""", user_id, guild_id) """,
user_id,
guild_id,
)
return await self.get_guild_settings_one(user_id, guild_id) return await self.get_guild_settings_one(user_id, guild_id)
gid = int(row['guild_id']) gid = int(row["guild_id"])
drow = dict(row) drow = dict(row)
chan_overrides = await self._get_chan_overrides(user_id, gid) chan_overrides = await self._get_chan_overrides(user_id, gid)
return {**drow, **{ return {**drow, **{"channel_overrides": chan_overrides}}
'channel_overrides': chan_overrides
}}
async def get_guild_settings(self, user_id: int): async def get_guild_settings(self, user_id: int):
"""Get the specific User Guild Settings, """Get the specific User Guild Settings,
@ -259,34 +303,38 @@ class UserStorage:
res = [] res = []
settings = await self.db.fetch(""" settings = await self.db.fetch(
"""
SELECT guild_id::text, suppress_everyone, muted, SELECT guild_id::text, suppress_everyone, muted,
message_notifications, mobile_push message_notifications, mobile_push
FROM guild_settings FROM guild_settings
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """,
user_id,
)
for row in settings: for row in settings:
gid = int(row['guild_id']) gid = int(row["guild_id"])
drow = dict(row) drow = dict(row)
chan_overrides = await self._get_chan_overrides(user_id, gid) chan_overrides = await self._get_chan_overrides(user_id, gid)
res.append({**drow, **{ res.append({**drow, **{"channel_overrides": chan_overrides}})
'channel_overrides': chan_overrides
}})
return res return res
async def get_user_guilds(self, user_id: int) -> List[int]: async def get_user_guilds(self, user_id: int) -> List[int]:
"""Get all guild IDs a user is on.""" """Get all guild IDs a user is on."""
guild_ids = await self.db.fetch(""" guild_ids = await self.db.fetch(
"""
SELECT guild_id SELECT guild_id
FROM members FROM members
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """,
user_id,
)
return [row['guild_id'] for row in guild_ids] return [row["guild_id"] for row in guild_ids]
async def get_mutual_guilds(self, user_id: int, peer_id: int) -> List[int]: async def get_mutual_guilds(self, user_id: int, peer_id: int) -> List[int]:
"""Get a list of guilds two separate users """Get a list of guilds two separate users
@ -301,13 +349,17 @@ class UserStorage:
return await self.get_user_guilds(user_id) or [0] return await self.get_user_guilds(user_id) or [0]
mutual_guilds = await self.db.fetch(""" mutual_guilds = await self.db.fetch(
"""
SELECT guild_id FROM members WHERE user_id = $1 SELECT guild_id FROM members WHERE user_id = $1
INTERSECT INTERSECT
SELECT guild_id FROM members WHERE user_id = $2 SELECT guild_id FROM members WHERE user_id = $2
""", user_id, peer_id) """,
user_id,
peer_id,
)
mutual_guilds = [r['guild_id'] for r in mutual_guilds] mutual_guilds = [r["guild_id"] for r in mutual_guilds]
return mutual_guilds return mutual_guilds
@ -316,7 +368,8 @@ class UserStorage:
This returns false even if there is a friend request. This returns false even if there is a friend request.
""" """
return await self.db.fetchval(""" return await self.db.fetchval(
"""
SELECT SELECT
( (
SELECT EXISTS( SELECT EXISTS(
@ -337,17 +390,23 @@ class UserStorage:
AND rel_type = 1 AND rel_type = 1
) )
) )
""", user_id, peer_id) """,
user_id,
peer_id,
)
async def get_gdms_internal(self, user_id) -> List[int]: async def get_gdms_internal(self, user_id) -> List[int]:
"""Return a list of Group DM IDs the user is a member of.""" """Return a list of Group DM IDs the user is a member of."""
rows = await self.db.fetch(""" rows = await self.db.fetch(
"""
SELECT id SELECT id
FROM group_dm_members FROM group_dm_members
WHERE member_id = $1 WHERE member_id = $1
""", user_id) """,
user_id,
)
return [r['id'] for r in rows] return [r["id"] for r in rows]
async def get_gdms(self, user_id) -> List[Dict[str, Any]]: async def get_gdms(self, user_id) -> List[Dict[str, Any]]:
"""Get list of group DMs a user is in.""" """Get list of group DMs a user is in."""
@ -356,8 +415,6 @@ class UserStorage:
res = [] res = []
for gdm_id in gdm_ids: for gdm_id in gdm_ids:
res.append( res.append(await self.storage.get_channel(gdm_id, user_id=user_id))
await self.storage.get_channel(gdm_id, user_id=user_id)
)
return res return res

View File

@ -46,7 +46,7 @@ async def task_wrapper(name: str, coro):
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except: except:
log.exception('{} task error', name) log.exception("{} task error", name)
def dict_get(mapping, key, default): def dict_get(mapping, key, default):
@ -84,54 +84,66 @@ def mmh3(inp_str: str, seed: int = 0):
h1 = seed h1 = seed
# mm3 constants # mm3 constants
c1 = 0xcc9e2d51 c1 = 0xCC9E2D51
c2 = 0x1b873593 c2 = 0x1B873593
i = 0 i = 0
while i < bytecount: while i < bytecount:
k1 = ( k1 = (
(key[i] & 0xff) | (key[i] & 0xFF)
((key[i + 1] & 0xff) << 8) | | ((key[i + 1] & 0xFF) << 8)
((key[i + 2] & 0xff) << 16) | | ((key[i + 2] & 0xFF) << 16)
((key[i + 3] & 0xff) << 24) | ((key[i + 3] & 0xFF) << 24)
) )
i += 4 i += 4
k1 = ((((k1 & 0xffff) * c1) + ((((_u(k1) >> 16) * c1) & 0xffff) << 16))) & 0xffffffff k1 = (
(((k1 & 0xFFFF) * c1) + ((((_u(k1) >> 16) * c1) & 0xFFFF) << 16))
) & 0xFFFFFFFF
k1 = (k1 << 15) | (_u(k1) >> 17) k1 = (k1 << 15) | (_u(k1) >> 17)
k1 = ((((k1 & 0xffff) * c2) + ((((_u(k1) >> 16) * c2) & 0xffff) << 16))) & 0xffffffff; k1 = (
(((k1 & 0xFFFF) * c2) + ((((_u(k1) >> 16) * c2) & 0xFFFF) << 16))
) & 0xFFFFFFFF
h1 ^= k1 h1 ^= k1
h1 = (h1 << 13) | (_u(h1) >> 19); h1 = (h1 << 13) | (_u(h1) >> 19)
h1b = ((((h1 & 0xffff) * 5) + ((((_u(h1) >> 16) * 5) & 0xffff) << 16))) & 0xffffffff; h1b = (
h1 = (((h1b & 0xffff) + 0x6b64) + ((((_u(h1b) >> 16) + 0xe654) & 0xffff) << 16)) (((h1 & 0xFFFF) * 5) + ((((_u(h1) >> 16) * 5) & 0xFFFF) << 16))
) & 0xFFFFFFFF
h1 = ((h1b & 0xFFFF) + 0x6B64) + ((((_u(h1b) >> 16) + 0xE654) & 0xFFFF) << 16)
k1 = 0 k1 = 0
v = None v = None
if remainder == 3: if remainder == 3:
v = (key[i + 2] & 0xff) << 16 v = (key[i + 2] & 0xFF) << 16
elif remainder == 2: elif remainder == 2:
v = (key[i + 1] & 0xff) << 8 v = (key[i + 1] & 0xFF) << 8
elif remainder == 1: elif remainder == 1:
v = (key[i] & 0xff) v = key[i] & 0xFF
if v is not None: if v is not None:
k1 ^= v k1 ^= v
k1 = (((k1 & 0xffff) * c1) + ((((_u(k1) >> 16) * c1) & 0xffff) << 16)) & 0xffffffff k1 = (((k1 & 0xFFFF) * c1) + ((((_u(k1) >> 16) * c1) & 0xFFFF) << 16)) & 0xFFFFFFFF
k1 = (k1 << 15) | (_u(k1) >> 17) k1 = (k1 << 15) | (_u(k1) >> 17)
k1 = (((k1 & 0xffff) * c2) + ((((_u(k1) >> 16) * c2) & 0xffff) << 16)) & 0xffffffff k1 = (((k1 & 0xFFFF) * c2) + ((((_u(k1) >> 16) * c2) & 0xFFFF) << 16)) & 0xFFFFFFFF
h1 ^= k1 h1 ^= k1
h1 ^= len(key) h1 ^= len(key)
h1 ^= _u(h1) >> 16 h1 ^= _u(h1) >> 16
h1 = (((h1 & 0xffff) * 0x85ebca6b) + ((((_u(h1) >> 16) * 0x85ebca6b) & 0xffff) << 16)) & 0xffffffff h1 = (
((h1 & 0xFFFF) * 0x85EBCA6B) + ((((_u(h1) >> 16) * 0x85EBCA6B) & 0xFFFF) << 16)
) & 0xFFFFFFFF
h1 ^= _u(h1) >> 13 h1 ^= _u(h1) >> 13
h1 = ((((h1 & 0xffff) * 0xc2b2ae35) + ((((_u(h1) >> 16) * 0xc2b2ae35) & 0xffff) << 16))) & 0xffffffff h1 = (
(
((h1 & 0xFFFF) * 0xC2B2AE35)
+ ((((_u(h1) >> 16) * 0xC2B2AE35) & 0xFFFF) << 16)
)
) & 0xFFFFFFFF
h1 ^= _u(h1) >> 16 h1 ^= _u(h1) >> 16
return _u(h1) >> 0 return _u(h1) >> 0
@ -139,6 +151,7 @@ def mmh3(inp_str: str, seed: int = 0):
class LitecordJSONEncoder(JSONEncoder): class LitecordJSONEncoder(JSONEncoder):
"""Custom JSON encoder for Litecord.""" """Custom JSON encoder for Litecord."""
def default(self, value: Any): def default(self, value: Any):
"""By default, this will try to get the to_json attribute of a given """By default, this will try to get the to_json attribute of a given
value being JSON encoded.""" value being JSON encoded."""
@ -151,17 +164,17 @@ class LitecordJSONEncoder(JSONEncoder):
async def pg_set_json(con): async def pg_set_json(con):
"""Set JSON and JSONB codecs for an asyncpg connection.""" """Set JSON and JSONB codecs for an asyncpg connection."""
await con.set_type_codec( await con.set_type_codec(
'json', "json",
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder), encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads, decoder=json.loads,
schema='pg_catalog' schema="pg_catalog",
) )
await con.set_type_codec( await con.set_type_codec(
'jsonb', "jsonb",
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder), encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads, decoder=json.loads,
schema='pg_catalog' schema="pg_catalog",
) )
@ -177,7 +190,8 @@ def yield_chunks(input_list: Sequence[Any], chunk_size: int):
# range accepts step param, so we use that to # range accepts step param, so we use that to
# make the chunks # make the chunks
for idx in range(0, len(input_list), chunk_size): for idx in range(0, len(input_list), chunk_size):
yield input_list[idx:idx + chunk_size] yield input_list[idx : idx + chunk_size]
def to_update(j: dict, orig: dict, field: str) -> bool: def to_update(j: dict, orig: dict, field: str) -> bool:
"""Compare values to check if j[field] is actually updating """Compare values to check if j[field] is actually updating
@ -193,27 +207,23 @@ async def search_result_from_list(rows: List) -> Dict[str, Any]:
- An int (?) on `total_results` - An int (?) on `total_results`
- Two bigint[], each on `before` and `after` respectively. - Two bigint[], each on `before` and `after` respectively.
""" """
results = 0 if not rows else rows[0]['total_results'] results = 0 if not rows else rows[0]["total_results"]
res = [] res = []
for row in rows: for row in rows:
before, after = [], [] before, after = [], []
for before_id in reversed(row['before']): for before_id in reversed(row["before"]):
before.append(await app.storage.get_message(before_id)) before.append(await app.storage.get_message(before_id))
for after_id in row['after']: for after_id in row["after"]:
after.append(await app.storage.get_message(after_id)) after.append(await app.storage.get_message(after_id))
msg = await app.storage.get_message(row['current_id']) msg = await app.storage.get_message(row["current_id"])
msg['hit'] = True msg["hit"] = True
res.append(before + [msg] + after) res.append(before + [msg] + after)
return { return {"total_results": results, "messages": res, "analytics_id": ""}
'total_results': results,
'messages': res,
'analytics_id': '',
}
def maybe_int(val: Any) -> Union[int, Any]: def maybe_int(val: Any) -> Union[int, Any]:

View File

@ -31,6 +31,7 @@ log = Logger(__name__)
class LVSPConnection: class LVSPConnection:
"""Represents a single LVSP connection.""" """Represents a single LVSP connection."""
def __init__(self, lvsp, region: str, hostname: str): def __init__(self, lvsp, region: str, hostname: str):
self.lvsp = lvsp self.lvsp = lvsp
self.app = lvsp.app self.app = lvsp.app
@ -46,7 +47,7 @@ class LVSPConnection:
@property @property
def _log_id(self): def _log_id(self):
return f'region={self.region} hostname={self.hostname}' return f"region={self.region} hostname={self.hostname}"
async def send(self, payload): async def send(self, payload):
"""Send a payload down the websocket.""" """Send a payload down the websocket."""
@ -61,50 +62,42 @@ class LVSPConnection:
async def send_op(self, opcode: int, data: dict): async def send_op(self, opcode: int, data: dict):
"""Send a message with an OP code included""" """Send a message with an OP code included"""
await self.send({ await self.send({"op": opcode, "d": data})
'op': opcode,
'd': data
})
async def send_info(self, info_type: str, info_data: Dict): async def send_info(self, info_type: str, info_data: Dict):
"""Send an INFO message down the websocket.""" """Send an INFO message down the websocket."""
await self.send({ await self.send(
'op': OP.info, {
'd': { "op": OP.info,
'type': InfoTable[info_type.upper()], "d": {"type": InfoTable[info_type.upper()], "data": info_data},
'data': info_data
} }
}) )
async def _heartbeater(self, hb_interval: int): async def _heartbeater(self, hb_interval: int):
try: try:
await asyncio.sleep(hb_interval) await asyncio.sleep(hb_interval)
# TODO: add self._seq # TODO: add self._seq
await self.send_op(OP.heartbeat, { await self.send_op(OP.heartbeat, {"s": 0})
's': 0
})
# give the server 300 milliseconds to reply. # give the server 300 milliseconds to reply.
await asyncio.sleep(300) await asyncio.sleep(300)
await self.conn.close(4000, 'heartbeat timeout') await self.conn.close(4000, "heartbeat timeout")
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
def _start_hb(self): def _start_hb(self):
self._hb_task = self.app.loop.create_task( self._hb_task = self.app.loop.create_task(self._heartbeater(self._hb_interval))
self._heartbeater(self._hb_interval)
)
def _stop_hb(self): def _stop_hb(self):
self._hb_task.cancel() self._hb_task.cancel()
async def _handle_0(self, msg): async def _handle_0(self, msg):
"""Handle HELLO message.""" """Handle HELLO message."""
data = msg['d'] data = msg["d"]
# nonce = data['nonce'] # nonce = data['nonce']
self._hb_interval = data['heartbeat_interval'] self._hb_interval = data["heartbeat_interval"]
# TODO: send identify # TODO: send identify
@ -112,48 +105,52 @@ class LVSPConnection:
"""Update the health value of a given voice server.""" """Update the health value of a given voice server."""
self.health = new_health self.health = new_health
await self.app.db.execute(""" await self.app.db.execute(
"""
UPDATE voice_servers UPDATE voice_servers
SET health = $1 SET health = $1
WHERE hostname = $2 WHERE hostname = $2
""", new_health, self.hostname) """,
new_health,
self.hostname,
)
async def _handle_3(self, msg): async def _handle_3(self, msg):
"""Handle READY message. """Handle READY message.
We only start heartbeating after READY. We only start heartbeating after READY.
""" """
await self._update_health(msg['health']) await self._update_health(msg["health"])
self._start_hb() self._start_hb()
async def _handle_5(self, msg): async def _handle_5(self, msg):
"""Handle HEARTBEAT_ACK.""" """Handle HEARTBEAT_ACK."""
self._stop_hb() self._stop_hb()
await self._update_health(msg['health']) await self._update_health(msg["health"])
self._start_hb() self._start_hb()
async def _handle_6(self, msg): async def _handle_6(self, msg):
"""Handle INFO messages.""" """Handle INFO messages."""
info = msg['d'] info = msg["d"]
info_type_str = InfoReverse[info['type']].lower() info_type_str = InfoReverse[info["type"]].lower()
try: try:
info_handler = getattr(self, f'_handle_info_{info_type_str}') info_handler = getattr(self, f"_handle_info_{info_type_str}")
except AttributeError: except AttributeError:
return return
await info_handler(info['data']) await info_handler(info["data"])
async def _handle_info_channel_assign(self, data: dict): async def _handle_info_channel_assign(self, data: dict):
"""called by the server once we got a channel assign.""" """called by the server once we got a channel assign."""
try: try:
channel_id = data['channel_id'] channel_id = data["channel_id"]
channel_id = int(channel_id) channel_id = int(channel_id)
except (TypeError, ValueError): except (TypeError, ValueError):
return return
try: try:
guild_id = data['guild_id'] guild_id = data["guild_id"]
guild_id = int(guild_id) guild_id = int(guild_id)
except (TypeError, ValueError): except (TypeError, ValueError):
guild_id = None guild_id = None
@ -166,19 +163,19 @@ class LVSPConnection:
msg = await self.recv() msg = await self.recv()
try: try:
opcode = msg['op'] opcode = msg["op"]
handler = getattr(self, f'_handle_{opcode}') handler = getattr(self, f"_handle_{opcode}")
await handler(msg) await handler(msg)
except (KeyError, AttributeError): except (KeyError, AttributeError):
# TODO: error codes in LVSP # TODO: error codes in LVSP
raise Exception('invalid op code') raise Exception("invalid op code")
async def start(self): async def start(self):
"""Try to start a websocket connection.""" """Try to start a websocket connection."""
try: try:
self.conn = await websockets.connect(f'wss://{self.hostname}') self.conn = await websockets.connect(f"wss://{self.hostname}")
except Exception: except Exception:
log.exception('failed to start lvsp conn to {}', self.hostname) log.exception("failed to start lvsp conn to {}", self.hostname)
async def run(self): async def run(self):
"""Start the websocket.""" """Start the websocket."""
@ -186,15 +183,15 @@ class LVSPConnection:
try: try:
if not self.conn: if not self.conn:
log.error('failed to start lvsp connection, stopping') log.error("failed to start lvsp connection, stopping")
return return
await self._loop() await self._loop()
except websockets.exceptions.ConnectionClosed as err: except websockets.exceptions.ConnectionClosed as err:
log.warning('conn close, {}, err={}', self._log_id, err) log.warning("conn close, {}, err={}", self._log_id, err)
# except WebsocketClose as err: # except WebsocketClose as err:
# log.warning('ws close, state={} err={}', self.state, err) # log.warning('ws close, state={} err={}', self.state, err)
# await self.conn.close(code=err.code, reason=err.reason) # await self.conn.close(code=err.code, reason=err.reason)
except Exception as err: except Exception as err:
log.exception('An exception has occoured. {}', self._log_id) log.exception("An exception has occoured. {}", self._log_id)
await self.conn.close(code=4000, reason=repr(err)) await self.conn.close(code=4000, reason=repr(err))

View File

@ -31,6 +31,7 @@ log = Logger(__name__)
@dataclass @dataclass
class Region: class Region:
"""Voice region data.""" """Voice region data."""
id: str id: str
vip: bool vip: bool
@ -40,6 +41,7 @@ class LVSPManager:
Spawns :class:`LVSPConnection` as needed, etc. Spawns :class:`LVSPConnection` as needed, etc.
""" """
def __init__(self, app, voice): def __init__(self, app, voice):
self.app = app self.app = app
self.voice = voice self.voice = voice
@ -61,49 +63,50 @@ class LVSPManager:
async def _spawn(self): async def _spawn(self):
"""Spawn LVSPConnection for each region.""" """Spawn LVSPConnection for each region."""
regions = await self.app.db.fetch(""" regions = await self.app.db.fetch(
"""
SELECT id, vip SELECT id, vip
FROM voice_regions FROM voice_regions
WHERE deprecated = false WHERE deprecated = false
""") """
)
regions = [Region(r['id'], r['vip']) for r in regions] regions = [Region(r["id"], r["vip"]) for r in regions]
if not regions: if not regions:
log.warning('no regions are setup') log.warning("no regions are setup")
return return
for region in regions: for region in regions:
# store it locally for region() function # store it locally for region() function
self.regions[region.id] = region self.regions[region.id] = region
self.app.loop.create_task( self.app.loop.create_task(self._spawn_region(region))
self._spawn_region(region)
)
async def _spawn_region(self, region: Region): async def _spawn_region(self, region: Region):
"""Spawn a region. Involves fetching all the hostnames """Spawn a region. Involves fetching all the hostnames
for the regions and spawning a LVSPConnection for each.""" for the regions and spawning a LVSPConnection for each."""
servers = await self.app.db.fetch(""" servers = await self.app.db.fetch(
"""
SELECT hostname SELECT hostname
FROM voice_servers FROM voice_servers
WHERE region_id = $1 WHERE region_id = $1
""", region.id) """,
region.id,
)
if not servers: if not servers:
log.warning('region {} does not have servers', region) log.warning("region {} does not have servers", region)
return return
servers = [r['hostname'] for r in servers] servers = [r["hostname"] for r in servers]
self.servers[region.id] = servers self.servers[region.id] = servers
for hostname in servers: for hostname in servers:
conn = LVSPConnection(self, region.id, hostname) conn = LVSPConnection(self, region.id, hostname)
self.conns[hostname] = conn self.conns[hostname] = conn
self.app.loop.create_task( self.app.loop.create_task(conn.run())
conn.run()
)
async def del_conn(self, conn): async def del_conn(self, conn):
"""Delete a connection from the connection pool.""" """Delete a connection from the connection pool."""
@ -119,11 +122,14 @@ class LVSPManager:
async def guild_region(self, guild_id: int) -> Optional[str]: async def guild_region(self, guild_id: int) -> Optional[str]:
"""Return the voice region of a guild.""" """Return the voice region of a guild."""
return await self.app.db.fetchval(""" return await self.app.db.fetchval(
"""
SELECT region SELECT region
FROM guilds FROM guilds
WHERE id = $1 WHERE id = $1
""", guild_id) """,
guild_id,
)
def get_health(self, hostname: str) -> float: def get_health(self, hostname: str) -> float:
"""Get voice server health, given hostname.""" """Get voice server health, given hostname."""
@ -144,10 +150,7 @@ class LVSPManager:
region = await self.guild_region(guild_id) region = await self.guild_region(guild_id)
# sort connected servers by health # sort connected servers by health
sorted_servers = sorted( sorted_servers = sorted(self.servers[region], key=self.get_health)
self.servers[region],
key=self.get_health
)
try: try:
hostname = sorted_servers[0] hostname = sorted_servers[0]

View File

@ -17,8 +17,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
class OPCodes: class OPCodes:
"""LVSP OP codes.""" """LVSP OP codes."""
hello = 0 hello = 0
identify = 1 identify = 1
resume = 2 resume = 2
@ -29,13 +31,13 @@ class OPCodes:
InfoTable = { InfoTable = {
'CHANNEL_REQ': 0, "CHANNEL_REQ": 0,
'CHANNEL_ASSIGN': 1, "CHANNEL_ASSIGN": 1,
'CHANNEL_UPDATE': 2, "CHANNEL_UPDATE": 2,
'CHANNEL_DESTROY': 3, "CHANNEL_DESTROY": 3,
'VST_CREATE': 4, "VST_CREATE": 4,
'VST_UPDATE': 5, "VST_UPDATE": 5,
'VST_LEAVE': 6, "VST_LEAVE": 6,
} }
InfoReverse = {v: k for k, v in InfoTable.items()} InfoReverse = {v: k for k, v in InfoTable.items()}

View File

@ -43,6 +43,7 @@ def _construct_state(state_dict: dict) -> VoiceState:
class VoiceManager: class VoiceManager:
"""Main voice manager class.""" """Main voice manager class."""
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
@ -56,7 +57,7 @@ class VoiceManager:
"""Return if a user can join a channel.""" """Return if a user can join a channel."""
channel = await self.app.storage.get_channel(channel_id) channel = await self.app.storage.get_channel(channel_id)
ctype = ChannelType(channel['type']) ctype = ChannelType(channel["type"])
if ctype not in VOICE_CHANNELS: if ctype not in VOICE_CHANNELS:
return return
@ -65,14 +66,12 @@ class VoiceManager:
# get_permissions returns ALL_PERMISSIONS when # get_permissions returns ALL_PERMISSIONS when
# the channel isn't from a guild # the channel isn't from a guild
perms = await get_permissions( perms = await get_permissions(user_id, channel_id, storage=self.app.storage)
user_id, channel_id, storage=self.app.storage
)
# hacky user_limit but should work, as channels not # hacky user_limit but should work, as channels not
# in guilds won't have that field. # in guilds won't have that field.
is_full = states >= channel.get('user_limit', 100) is_full = states >= channel.get("user_limit", 100)
is_bot = (await self.app.storage.get_user(user_id))['bot'] is_bot = (await self.app.storage.get_user(user_id))["bot"]
is_manager = perms.bits.manage_channels is_manager = perms.bits.manage_channels
# if the channel is full AND: # if the channel is full AND:
@ -140,8 +139,8 @@ class VoiceManager:
for field in prop: for field in prop:
# NOTE: this should not happen, ever. # NOTE: this should not happen, ever.
if field in ('channel_id', 'user_id'): if field in ("channel_id", "user_id"):
raise ValueError('properties are updating channel or user') raise ValueError("properties are updating channel or user")
new_state_dict[field] = prop[field] new_state_dict[field] = prop[field]
@ -153,27 +152,28 @@ class VoiceManager:
async def move_channels(self, old_voice_key: VoiceKey, channel_id: int): async def move_channels(self, old_voice_key: VoiceKey, channel_id: int):
"""Move a user between channels.""" """Move a user between channels."""
await self.del_state(old_voice_key) await self.del_state(old_voice_key)
await self.create_state(old_voice_key, {'channel_id': channel_id}) await self.create_state(old_voice_key, {"channel_id": channel_id})
async def _lvsp_info_guild(self, guild_id, info_type, info_data): async def _lvsp_info_guild(self, guild_id, info_type, info_data):
hostname = await self.lvsp.get_guild_server(guild_id) hostname = await self.lvsp.get_guild_server(guild_id)
if hostname is None: if hostname is None:
log.error('no voice server for guild id {}', guild_id) log.error("no voice server for guild id {}", guild_id)
return return
conn = self.lvsp.get_conn(hostname) conn = self.lvsp.get_conn(hostname)
await conn.send_info(info_type, info_data) await conn.send_info(info_type, info_data)
async def _create_ctx_guild(self, guild_id, channel_id): async def _create_ctx_guild(self, guild_id, channel_id):
await self._lvsp_info_guild(guild_id, 'CHANNEL_REQ', { await self._lvsp_info_guild(
'guild_id': str(guild_id), guild_id,
'channel_id': str(channel_id), "CHANNEL_REQ",
}) {"guild_id": str(guild_id), "channel_id": str(channel_id)},
)
async def _start_voice_guild(self, voice_key: VoiceKey, data: dict): async def _start_voice_guild(self, voice_key: VoiceKey, data: dict):
"""Start a voice context in a guild.""" """Start a voice context in a guild."""
user_id, guild_id = voice_key user_id, guild_id = voice_key
channel_id = int(data['channel_id']) channel_id = int(data["channel_id"])
existing_states = self.states[voice_key] existing_states = self.states[voice_key]
channel_exists = any( channel_exists = any(
@ -183,11 +183,15 @@ class VoiceManager:
if not channel_exists: if not channel_exists:
await self._create_ctx_guild(guild_id, channel_id) await self._create_ctx_guild(guild_id, channel_id)
await self._lvsp_info_guild(guild_id, 'VST_CREATE', { await self._lvsp_info_guild(
'user_id': str(user_id), guild_id,
'guild_id': str(guild_id), "VST_CREATE",
'channel_id': str(channel_id), {
}) "user_id": str(user_id),
"guild_id": str(guild_id),
"channel_id": str(channel_id),
},
)
async def create_state(self, voice_key: VoiceKey, data: dict): async def create_state(self, voice_key: VoiceKey, data: dict):
"""Creates (or tries to create) a voice state. """Creates (or tries to create) a voice state.
@ -249,10 +253,13 @@ class VoiceManager:
async def voice_server_list(self, region: str) -> List[dict]: async def voice_server_list(self, region: str) -> List[dict]:
"""Get a list of voice server objects""" """Get a list of voice server objects"""
rows = await self.app.db.fetch(""" rows = await self.app.db.fetch(
"""
SELECT hostname, last_health SELECT hostname, last_health
FROM voice_servers FROM voice_servers
WHERE region_id = $1 WHERE region_id = $1
""", region) """,
region,
)
return list(map(dict, rows)) return list(map(dict, rows))

View File

@ -23,6 +23,7 @@ from dataclasses import dataclass, asdict
@dataclass @dataclass
class VoiceState: class VoiceState:
"""Represents a voice state.""" """Represents a voice state."""
guild_id: int guild_id: int
channel_id: int channel_id: int
user_id: int user_id: int
@ -55,7 +56,7 @@ class VoiceState:
# a better approach would be actually using # a better approach would be actually using
# the suppressed_by field for backend efficiency. # the suppressed_by field for backend efficiency.
self_dict['suppress'] = user_id == self.suppressed_by self_dict["suppress"] = user_id == self.suppressed_by
self_dict.pop('suppressed_by') self_dict.pop("suppressed_by")
return self_dict return self_dict

View File

@ -27,5 +27,5 @@ import config
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
if __name__ == '__main__': if __name__ == "__main__":
sys.exit(main(config)) sys.exit(main(config))

View File

@ -16,4 +16,3 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """

View File

@ -26,7 +26,7 @@ ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits
async def _gen_inv() -> str: async def _gen_inv() -> str:
"""Generate an invite code""" """Generate an invite code"""
return ''.join(choice(ALPHABET) for _ in range(6)) return "".join(choice(ALPHABET) for _ in range(6))
async def gen_inv(ctx) -> str: async def gen_inv(ctx) -> str:
@ -34,11 +34,14 @@ async def gen_inv(ctx) -> str:
for _ in range(10): for _ in range(10):
possible_inv = await _gen_inv() possible_inv = await _gen_inv()
created_at = await ctx.db.fetchval(""" created_at = await ctx.db.fetchval(
"""
SELECT created_at SELECT created_at
FROM instance_invites FROM instance_invites
WHERE code = $1 WHERE code = $1
""", possible_inv) """,
possible_inv,
)
if created_at is None: if created_at is None:
return possible_inv return possible_inv
@ -51,27 +54,32 @@ async def make_inv(ctx, args):
max_uses = args.max_uses max_uses = args.max_uses
await ctx.db.execute(""" await ctx.db.execute(
"""
INSERT INTO instance_invites (code, max_uses) INSERT INTO instance_invites (code, max_uses)
VALUES ($1, $2) VALUES ($1, $2)
""", code, max_uses) """,
code,
max_uses,
)
print(f'invite created with {max_uses} max uses', code) print(f"invite created with {max_uses} max uses", code)
async def list_invs(ctx, args): async def list_invs(ctx, args):
rows = await ctx.db.fetch(""" rows = await ctx.db.fetch(
"""
SELECT code, created_at, uses, max_uses SELECT code, created_at, uses, max_uses
FROM instance_invites FROM instance_invites
""") """
)
print(len(rows), 'invites') print(len(rows), "invites")
for row in rows: for row in rows:
max_uses = row['max_uses'] max_uses = row["max_uses"]
delta = datetime.datetime.utcnow() - row['created_at'] delta = datetime.datetime.utcnow() - row["created_at"]
usage = ('infinite uses' if max_uses == -1 usage = "infinite uses" if max_uses == -1 else f'{row["uses"]} / {max_uses}'
else f'{row["uses"]} / {max_uses}')
print(f'\t{row["code"]}, {usage}, made {delta} ago') print(f'\t{row["code"]}, {usage}, made {delta} ago')
@ -79,40 +87,37 @@ async def list_invs(ctx, args):
async def delete_inv(ctx, args): async def delete_inv(ctx, args):
inv = args.invite_code inv = args.invite_code
res = await ctx.db.execute(""" res = await ctx.db.execute(
"""
DELETE FROM instance_invites DELETE FROM instance_invites
WHERE code = $1 WHERE code = $1
""", inv) """,
inv,
)
if res == 'DELETE 0': if res == "DELETE 0":
print('NOT FOUND') print("NOT FOUND")
return return
print('OK') print("OK")
def setup(subparser): def setup(subparser):
makeinv_parser = subparser.add_parser( makeinv_parser = subparser.add_parser("makeinv", help="create an invite")
'makeinv',
help='create an invite',
)
makeinv_parser.add_argument( makeinv_parser.add_argument(
'max_uses', nargs='?', type=int, default=-1, "max_uses",
help='Maximum amount of uses before the invite is unavailable', nargs="?",
type=int,
default=-1,
help="Maximum amount of uses before the invite is unavailable",
) )
makeinv_parser.set_defaults(func=make_inv) makeinv_parser.set_defaults(func=make_inv)
listinv_parser = subparser.add_parser( listinv_parser = subparser.add_parser("listinv", help="list all invites")
'listinv',
help='list all invites',
)
listinv_parser.set_defaults(func=list_invs) listinv_parser.set_defaults(func=list_invs)
delinv_parser = subparser.add_parser( delinv_parser = subparser.add_parser("delinv", help="delete an invite")
'delinv', delinv_parser.add_argument("invite_code")
help='delete an invite',
)
delinv_parser.add_argument('invite_code')
delinv_parser.set_defaults(func=delete_inv) delinv_parser.set_defaults(func=delete_inv)

View File

@ -19,4 +19,4 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from .command import setup as migration from .command import setup as migration
__all__ = ['migration'] __all__ = ["migration"]

View File

@ -32,18 +32,19 @@ from logbook import Logger
log = Logger(__name__) log = Logger(__name__)
Migration = namedtuple('Migration', 'id name path') Migration = namedtuple("Migration", "id name path")
# line of change, 4 april 2019, at 1am (gmt+0) # line of change, 4 april 2019, at 1am (gmt+0)
BREAK = datetime.datetime(2019, 4, 4, 1) BREAK = datetime.datetime(2019, 4, 4, 1)
# if a database has those tables, it ran 0_base.sql. # if a database has those tables, it ran 0_base.sql.
HAS_BASE = ['users', 'guilds', 'e'] HAS_BASE = ["users", "guilds", "e"]
@dataclass @dataclass
class MigrationContext: class MigrationContext:
"""Hold information about migration.""" """Hold information about migration."""
migration_folder: Path migration_folder: Path
scripts: Dict[int, Migration] scripts: Dict[int, Migration]
@ -60,22 +61,21 @@ def make_migration_ctx() -> MigrationContext:
script_folder = os.sep.join(script_path.split(os.sep)[:-1]) script_folder = os.sep.join(script_path.split(os.sep)[:-1])
script_folder = Path(script_folder) script_folder = Path(script_folder)
migration_folder = script_folder / 'scripts' migration_folder = script_folder / "scripts"
mctx = MigrationContext(migration_folder, {}) mctx = MigrationContext(migration_folder, {})
for mig_path in migration_folder.glob('*.sql'): for mig_path in migration_folder.glob("*.sql"):
mig_path_str = str(mig_path) mig_path_str = str(mig_path)
# extract migration script id and name # extract migration script id and name
mig_filename = mig_path_str.split(os.sep)[-1].split('.')[0] mig_filename = mig_path_str.split(os.sep)[-1].split(".")[0]
name_fragments = mig_filename.split('_') name_fragments = mig_filename.split("_")
mig_id = int(name_fragments[0]) mig_id = int(name_fragments[0])
mig_name = '_'.join(name_fragments[1:]) mig_name = "_".join(name_fragments[1:])
mctx.scripts[mig_id] = Migration( mctx.scripts[mig_id] = Migration(mig_id, mig_name, mig_path)
mig_id, mig_name, mig_path)
return mctx return mctx
@ -83,7 +83,8 @@ def make_migration_ctx() -> MigrationContext:
async def _ensure_changelog(app, ctx): async def _ensure_changelog(app, ctx):
# make sure we have the migration table up # make sure we have the migration table up
try: try:
await app.db.execute(""" await app.db.execute(
"""
CREATE TABLE migration_log ( CREATE TABLE migration_log (
change_num bigint NOT NULL, change_num bigint NOT NULL,
@ -94,43 +95,56 @@ async def _ensure_changelog(app, ctx):
PRIMARY KEY (change_num) PRIMARY KEY (change_num)
); );
""") """
)
except asyncpg.DuplicateTableError: except asyncpg.DuplicateTableError:
log.debug('existing migration table') log.debug("existing migration table")
# NOTE: this is a migration breakage, # NOTE: this is a migration breakage,
# only applying to databases that had their first migration # only applying to databases that had their first migration
# before 4 april 2019 (more on BREAK) # before 4 april 2019 (more on BREAK)
# if migration_log is empty, just assume this is new # if migration_log is empty, just assume this is new
first = await app.db.fetchval(""" first = (
await app.db.fetchval(
"""
SELECT apply_ts FROM migration_log SELECT apply_ts FROM migration_log
ORDER BY apply_ts ASC ORDER BY apply_ts ASC
LIMIT 1 LIMIT 1
""") or BREAK """
)
or BREAK
)
if first < BREAK: if first < BREAK:
log.info('deleting migration_log due to migration structure change') log.info("deleting migration_log due to migration structure change")
await app.db.execute("DROP TABLE migration_log") await app.db.execute("DROP TABLE migration_log")
await _ensure_changelog(app, ctx) await _ensure_changelog(app, ctx)
async def _insert_log(app, migration_id: int, description) -> bool: async def _insert_log(app, migration_id: int, description) -> bool:
try: try:
await app.db.execute(""" await app.db.execute(
"""
INSERT INTO migration_log (change_num, description) INSERT INTO migration_log (change_num, description)
VALUES ($1, $2) VALUES ($1, $2)
""", migration_id, description) """,
migration_id,
description,
)
return True return True
except asyncpg.UniqueViolationError: except asyncpg.UniqueViolationError:
log.warning('already inserted {}', migration_id) log.warning("already inserted {}", migration_id)
return False return False
async def _delete_log(app, migration_id: int): async def _delete_log(app, migration_id: int):
await app.db.execute(""" await app.db.execute(
"""
DELETE FROM migration_log WHERE change_num = $1 DELETE FROM migration_log WHERE change_num = $1
""", migration_id) """,
migration_id,
)
async def apply_migration(app, migration: Migration) -> bool: async def apply_migration(app, migration: Migration) -> bool:
@ -144,21 +158,20 @@ async def apply_migration(app, migration: Migration) -> bool:
Returns a boolean signaling if this failed or not. Returns a boolean signaling if this failed or not.
""" """
migration_sql = migration.path.read_text(encoding='utf-8') migration_sql = migration.path.read_text(encoding="utf-8")
res = await _insert_log( res = await _insert_log(app, migration.id, f"migration: {migration.name}")
app, migration.id, f'migration: {migration.name}')
if not res: if not res:
return False return False
try: try:
await app.db.execute(migration_sql) await app.db.execute(migration_sql)
log.info('applied {} {}', migration.id, migration.name) log.info("applied {} {}", migration.id, migration.name)
return True return True
except: except:
log.exception('failed to run migration, rollbacking log') log.exception("failed to run migration, rollbacking log")
await _delete_log(app, migration.id) await _delete_log(app, migration.id)
return False return False
@ -169,9 +182,11 @@ async def _check_base(app) -> bool:
file.""" file."""
try: try:
for table in HAS_BASE: for table in HAS_BASE:
await app.db.execute(f""" await app.db.execute(
f"""
SELECT * FROM {table} LIMIT 0 SELECT * FROM {table} LIMIT 0
""") """
)
except asyncpg.UndefinedTableError: except asyncpg.UndefinedTableError:
return False return False
@ -197,14 +212,16 @@ async def migrate_cmd(app, _args):
has_base = await _check_base(app) has_base = await _check_base(app)
# fetch latest local migration that has been run on this database # fetch latest local migration that has been run on this database
local_change = await app.db.fetchval(""" local_change = await app.db.fetchval(
"""
SELECT max(change_num) SELECT max(change_num)
FROM migration_log FROM migration_log
""") """
)
# if base exists, add it to logs, if not, apply (and add to logs) # if base exists, add it to logs, if not, apply (and add to logs)
if has_base: if has_base:
await _insert_log(app, 0, 'migration setup (from existing)') await _insert_log(app, 0, "migration setup (from existing)")
else: else:
await apply_migration(app, ctx.scripts[0]) await apply_migration(app, ctx.scripts[0])
@ -215,10 +232,10 @@ async def migrate_cmd(app, _args):
local_change = local_change or 0 local_change = local_change or 0
latest_change = ctx.latest latest_change = ctx.latest
log.debug('local: {}, latest: {}', local_change, latest_change) log.debug("local: {}, latest: {}", local_change, latest_change)
if local_change == latest_change: if local_change == latest_change:
print('no changes to do, exiting') print("no changes to do, exiting")
return return
# we do local_change + 1 so we start from the # we do local_change + 1 so we start from the
@ -227,15 +244,13 @@ async def migrate_cmd(app, _args):
for idx in range(local_change + 1, latest_change + 1): for idx in range(local_change + 1, latest_change + 1):
migration = ctx.scripts.get(idx) migration = ctx.scripts.get(idx)
print('applying', migration.id, migration.name) print("applying", migration.id, migration.name)
await apply_migration(app, migration) await apply_migration(app, migration)
def setup(subparser): def setup(subparser):
migrate_parser = subparser.add_parser( migrate_parser = subparser.add_parser(
'migrate', "migrate", help="Run migration tasks", description=migrate_cmd.__doc__
help='Run migration tasks',
description=migrate_cmd.__doc__
) )
migrate_parser.set_defaults(func=migrate_cmd) migrate_parser.set_defaults(func=migrate_cmd)

View File

@ -24,39 +24,51 @@ from litecord.enums import UserFlags
async def find_user(username, discrim, ctx) -> int: async def find_user(username, discrim, ctx) -> int:
"""Get a user ID via the username/discrim pair.""" """Get a user ID via the username/discrim pair."""
return await ctx.db.fetchval(""" return await ctx.db.fetchval(
"""
SELECT id SELECT id
FROM users FROM users
WHERE username = $1 AND discriminator = $2 WHERE username = $1 AND discriminator = $2
""", username, discrim) """,
username,
discrim,
)
async def set_user_staff(user_id, ctx): async def set_user_staff(user_id, ctx):
"""Give a single user staff status.""" """Give a single user staff status."""
old_flags = await ctx.db.fetchval(""" old_flags = await ctx.db.fetchval(
"""
SELECT flags SELECT flags
FROM users FROM users
WHERE id = $1 WHERE id = $1
""", user_id) """,
user_id,
)
new_flags = old_flags | UserFlags.staff new_flags = old_flags | UserFlags.staff
await ctx.db.execute(""" await ctx.db.execute(
"""
UPDATE users UPDATE users
SET flags=$1 SET flags=$1
WHERE id = $2 WHERE id = $2
""", new_flags, user_id) """,
new_flags,
user_id,
)
async def adduser(ctx, args): async def adduser(ctx, args):
"""Create a single user.""" """Create a single user."""
uid, _ = await create_user(args.username, args.email, uid, _ = await create_user(
args.password, ctx.db, ctx.loop) args.username, args.email, args.password, ctx.db, ctx.loop
)
user = await ctx.storage.get_user(uid) user = await ctx.storage.get_user(uid)
print('created!') print("created!")
print(f'\tuid: {uid}') print(f"\tuid: {uid}")
print(f'\tusername: {user["username"]}') print(f'\tusername: {user["username"]}')
print(f'\tdiscrim: {user["discriminator"]}') print(f'\tdiscrim: {user["discriminator"]}')
@ -72,22 +84,26 @@ async def make_staff(ctx, args):
uid = await find_user(args.username, args.discrim, ctx) uid = await find_user(args.username, args.discrim, ctx)
if not uid: if not uid:
return print('user not found') return print("user not found")
await set_user_staff(uid, ctx) await set_user_staff(uid, ctx)
print('OK: set staff') print("OK: set staff")
async def generate_bot_token(ctx, args): async def generate_bot_token(ctx, args):
"""Generate a token for specified bot.""" """Generate a token for specified bot."""
password_hash = await ctx.db.fetchval(""" password_hash = await ctx.db.fetchval(
"""
SELECT password_hash SELECT password_hash
FROM users FROM users
WHERE id = $1 AND bot = 'true' WHERE id = $1 AND bot = 'true'
""", int(args.user_id)) """,
int(args.user_id),
)
if not password_hash: if not password_hash:
return print('cannot find a bot with specified id') return print("cannot find a bot with specified id")
print(make_token(args.user_id, password_hash)) print(make_token(args.user_id, password_hash))
@ -97,7 +113,7 @@ async def del_user(ctx, args):
uid = await find_user(args.username, args.discrim, ctx) uid = await find_user(args.username, args.discrim, ctx)
if uid is None: if uid is None:
print('user not found') print("user not found")
return return
user = await ctx.storage.get_user(uid) user = await ctx.storage.get_user(uid)
@ -106,57 +122,48 @@ async def del_user(ctx, args):
print(f'\tuname: {user["username"]}') print(f'\tuname: {user["username"]}')
print(f'\tdiscrim: {user["discriminator"]}') print(f'\tdiscrim: {user["discriminator"]}')
print('\n you sure you want to delete user? press Y (uppercase)') print("\n you sure you want to delete user? press Y (uppercase)")
confirm = input() confirm = input()
if confirm != 'Y': if confirm != "Y":
print('not confirmed') print("not confirmed")
return return
await delete_user(uid, app_=ctx) await delete_user(uid, app_=ctx)
print('ok') print("ok")
def setup(subparser): def setup(subparser):
setup_test_parser = subparser.add_parser( setup_test_parser = subparser.add_parser("adduser", help="create a user")
'adduser',
help='create a user',
)
setup_test_parser.add_argument( setup_test_parser.add_argument("username", help="username of the user")
'username', help='username of the user') setup_test_parser.add_argument("email", help="email of the user")
setup_test_parser.add_argument( setup_test_parser.add_argument("password", help="password of the user")
'email', help='email of the user')
setup_test_parser.add_argument(
'password', help='password of the user')
setup_test_parser.set_defaults(func=adduser) setup_test_parser.set_defaults(func=adduser)
staff_parser = subparser.add_parser( staff_parser = subparser.add_parser(
'make_staff', "make_staff", help="make a user staff", description=make_staff.__doc__
help='make a user staff',
description=make_staff.__doc__
) )
staff_parser.add_argument('username') staff_parser.add_argument("username")
staff_parser.add_argument( staff_parser.add_argument("discrim", help="the discriminator of the user")
'discrim', help='the discriminator of the user')
staff_parser.set_defaults(func=make_staff) staff_parser.set_defaults(func=make_staff)
del_user_parser = subparser.add_parser( del_user_parser = subparser.add_parser("deluser", help="delete a single user")
'deluser', help='delete a single user')
del_user_parser.add_argument('username') del_user_parser.add_argument("username")
del_user_parser.add_argument('discrim') del_user_parser.add_argument("discrim")
del_user_parser.set_defaults(func=del_user) del_user_parser.set_defaults(func=del_user)
token_parser = subparser.add_parser( token_parser = subparser.add_parser(
'generate_token', "generate_token",
help='generate a token for specified bot', help="generate a token for specified bot",
description=generate_bot_token.__doc__) description=generate_bot_token.__doc__,
)
token_parser.add_argument('user_id') token_parser.add_argument("user_id")
token_parser.set_defaults(func=generate_bot_token) token_parser.set_defaults(func=generate_bot_token)

View File

@ -34,6 +34,7 @@ log = Logger(__name__)
@dataclass @dataclass
class FakeApp: class FakeApp:
"""Fake app instance.""" """Fake app instance."""
config: dict config: dict
db = None db = None
loop: asyncio.BaseEventLoop = None loop: asyncio.BaseEventLoop = None
@ -50,7 +51,7 @@ class FakeApp:
def init_parser(): def init_parser():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
subparser = parser.add_subparsers(help='operations') subparser = parser.add_subparsers(help="operations")
migration(subparser) migration(subparser)
users.setup(subparser) users.setup(subparser)
@ -78,12 +79,12 @@ def main(config):
# only init app managers when we aren't migrating # only init app managers when we aren't migrating
# as the managers require it # as the managers require it
# and the migrate command also sets the db up # and the migrate command also sets the db up
if argv[1] != 'migrate': if argv[1] != "migrate":
init_app_managers(app, voice=False) init_app_managers(app, voice=False)
args = parser.parse_args() args = parser.parse_args()
loop.run_until_complete(args.func(app, args)) loop.run_until_complete(args.func(app, args))
except Exception: except Exception:
log.exception('error while running command') log.exception("error while running command")
finally: finally:
loop.run_until_complete(app.db.close()) loop.run_until_complete(app.db.close())

236
run.py
View File

@ -33,32 +33,51 @@ from aiohttp import ClientSession
import config import config
from litecord.blueprints import ( from litecord.blueprints import (
gateway, auth, users, guilds, channels, webhooks, science, gateway,
voice, invites, relationships, dms, icons, nodeinfo, static, auth,
attachments, dm_channels users,
guilds,
channels,
webhooks,
science,
voice,
invites,
relationships,
dms,
icons,
nodeinfo,
static,
attachments,
dm_channels,
) )
# those blueprints are separated from the "main" ones # those blueprints are separated from the "main" ones
# for code readability if people want to dig through # for code readability if people want to dig through
# the codebase. # the codebase.
from litecord.blueprints.guild import ( from litecord.blueprints.guild import (
guild_roles, guild_members, guild_channels, guild_mod, guild_roles,
guild_emoji guild_members,
guild_channels,
guild_mod,
guild_emoji,
) )
from litecord.blueprints.channel import ( from litecord.blueprints.channel import (
channel_messages, channel_reactions, channel_pins channel_messages,
channel_reactions,
channel_pins,
) )
from litecord.blueprints.user import ( from litecord.blueprints.user import user_settings, user_billing, fake_store
user_settings, user_billing, fake_store
)
from litecord.blueprints.user.billing_job import payment_job from litecord.blueprints.user.billing_job import payment_job
from litecord.blueprints.admin_api import ( from litecord.blueprints.admin_api import (
voice as voice_admin, features as features_admin, voice as voice_admin,
guilds as guilds_admin, users as users_admin, instance_invites features as features_admin,
guilds as guilds_admin,
users as users_admin,
instance_invites,
) )
from litecord.blueprints.admin_api.voice import guild_region_check from litecord.blueprints.admin_api.voice import guild_region_check
@ -84,23 +103,23 @@ from litecord.utils import LitecordJSONEncoder
# setup logbook # setup logbook
handler = StreamHandler(sys.stdout, level=logbook.INFO) handler = StreamHandler(sys.stdout, level=logbook.INFO)
handler.push_application() handler.push_application()
log = Logger('litecord.boot') log = Logger("litecord.boot")
redirect_logging() redirect_logging()
def make_app(): def make_app():
app = Quart(__name__) app = Quart(__name__)
app.config.from_object(f'config.{config.MODE}') app.config.from_object(f"config.{config.MODE}")
is_debug = app.config.get('DEBUG', False) is_debug = app.config.get("DEBUG", False)
app.debug = is_debug app.debug = is_debug
if is_debug: if is_debug:
log.info('on debug') log.info("on debug")
handler.level = logbook.DEBUG handler.level = logbook.DEBUG
app.logger.level = logbook.DEBUG app.logger.level = logbook.DEBUG
# always keep websockets on INFO # always keep websockets on INFO
logging.getLogger('websockets').setLevel(logbook.INFO) logging.getLogger("websockets").setLevel(logbook.INFO)
# use our custom json encoder for custom data types # use our custom json encoder for custom data types
app.json_encoder = LitecordJSONEncoder app.json_encoder = LitecordJSONEncoder
@ -112,51 +131,44 @@ def set_blueprints(app_):
"""Set the blueprints for a given app instance""" """Set the blueprints for a given app instance"""
bps = { bps = {
gateway: None, gateway: None,
auth: '/auth', auth: "/auth",
users: "/users",
users: '/users', user_settings: "/users",
user_settings: '/users', user_billing: "/users",
user_billing: '/users', relationships: "/users",
relationships: '/users', guilds: "/guilds",
guild_roles: "/guilds",
guilds: '/guilds', guild_members: "/guilds",
guild_roles: '/guilds', guild_channels: "/guilds",
guild_members: '/guilds', guild_mod: "/guilds",
guild_channels: '/guilds', guild_emoji: "/guilds",
guild_mod: '/guilds', channels: "/channels",
guild_emoji: '/guilds', channel_messages: "/channels",
channel_reactions: "/channels",
channels: '/channels', channel_pins: "/channels",
channel_messages: '/channels',
channel_reactions: '/channels',
channel_pins: '/channels',
webhooks: None, webhooks: None,
science: None, science: None,
voice: '/voice', voice: "/voice",
invites: None, invites: None,
dms: '/users', dms: "/users",
dm_channels: '/channels', dm_channels: "/channels",
fake_store: None, fake_store: None,
icons: -1, icons: -1,
attachments: -1, attachments: -1,
nodeinfo: -1, nodeinfo: -1,
static: -1, static: -1,
voice_admin: "/admin/voice",
voice_admin: '/admin/voice', features_admin: "/admin/guilds",
features_admin: '/admin/guilds', guilds_admin: "/admin/guilds",
guilds_admin: '/admin/guilds', users_admin: "/admin/users",
users_admin: '/admin/users', instance_invites: "/admin/instance/invites",
instance_invites: '/admin/instance/invites'
} }
for bp, suffix in bps.items(): for bp, suffix in bps.items():
url_prefix = f'/api/v6{suffix or ""}' url_prefix = f'/api/v6{suffix or ""}'
if suffix == -1: if suffix == -1:
url_prefix = '' url_prefix = ""
app_.register_blueprint(bp, url_prefix=url_prefix) app_.register_blueprint(bp, url_prefix=url_prefix)
@ -175,37 +187,35 @@ async def app_before_request():
@app.after_request @app.after_request
async def app_after_request(resp): async def app_after_request(resp):
"""Handle CORS headers.""" """Handle CORS headers."""
origin = request.headers.get('Origin', '*') origin = request.headers.get("Origin", "*")
resp.headers['Access-Control-Allow-Origin'] = origin resp.headers["Access-Control-Allow-Origin"] = origin
resp.headers['Access-Control-Allow-Headers'] = ( resp.headers["Access-Control-Allow-Headers"] = (
'*, X-Super-Properties, ' "*, X-Super-Properties, "
'X-Fingerprint, ' "X-Fingerprint, "
'X-Context-Properties, ' "X-Context-Properties, "
'X-Failed-Requests, ' "X-Failed-Requests, "
'X-Debug-Options, ' "X-Debug-Options, "
'Content-Type, ' "Content-Type, "
'Authorization, ' "Authorization, "
'Origin, ' "Origin, "
'If-None-Match' "If-None-Match"
) )
resp.headers['Access-Control-Allow-Methods'] = \ resp.headers["Access-Control-Allow-Methods"] = resp.headers.get("allow", "*")
resp.headers.get('allow', '*')
return resp return resp
def _set_rtl_reset(bucket, resp): def _set_rtl_reset(bucket, resp):
reset = bucket._window + bucket.second reset = bucket._window + bucket.second
precision = request.headers.get('x-ratelimit-precision', 'second') precision = request.headers.get("x-ratelimit-precision", "second")
if precision == 'second': if precision == "second":
resp.headers['X-RateLimit-Reset'] = str(round(reset)) resp.headers["X-RateLimit-Reset"] = str(round(reset))
elif precision == 'millisecond': elif precision == "millisecond":
resp.headers['X-RateLimit-Reset'] = str(reset) resp.headers["X-RateLimit-Reset"] = str(reset)
else: else:
resp.headers['X-RateLimit-Reset'] = ( resp.headers["X-RateLimit-Reset"] = (
'Invalid X-RateLimit-Precision, ' "Invalid X-RateLimit-Precision, " "valid options are (second, millisecond)"
'valid options are (second, millisecond)'
) )
@ -218,15 +228,15 @@ async def app_set_ratelimit_headers(resp):
if bucket is None: if bucket is None:
raise AttributeError() raise AttributeError()
resp.headers['X-RateLimit-Limit'] = str(bucket.requests) resp.headers["X-RateLimit-Limit"] = str(bucket.requests)
resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens) resp.headers["X-RateLimit-Remaining"] = str(bucket._tokens)
resp.headers['X-RateLimit-Global'] = str(request.bucket_global).lower() resp.headers["X-RateLimit-Global"] = str(request.bucket_global).lower()
_set_rtl_reset(bucket, resp) _set_rtl_reset(bucket, resp)
# only add Retry-After if we actually hit a ratelimit # only add Retry-After if we actually hit a ratelimit
retry_after = request.retry_after retry_after = request.retry_after
if request.retry_after: if request.retry_after:
resp.headers['Retry-After'] = str(retry_after) resp.headers["Retry-After"] = str(retry_after)
except AttributeError: except AttributeError:
pass pass
@ -238,8 +248,8 @@ async def init_app_db(app_):
Also spawns the job scheduler. Also spawns the job scheduler.
""" """
log.info('db connect') log.info("db connect")
app_.db = await asyncpg.create_pool(**app.config['POSTGRES']) app_.db = await asyncpg.create_pool(**app.config["POSTGRES"])
app_.sched = JobManager() app_.sched = JobManager()
@ -247,7 +257,7 @@ async def init_app_db(app_):
def init_app_managers(app_, *, voice=True): def init_app_managers(app_, *, voice=True):
"""Initialize singleton classes.""" """Initialize singleton classes."""
app_.loop = asyncio.get_event_loop() app_.loop = asyncio.get_event_loop()
app_.ratelimiter = RatelimitManager(app_.config.get('_testing')) app_.ratelimiter = RatelimitManager(app_.config.get("_testing"))
app_.state_manager = StateManager() app_.state_manager = StateManager()
app_.storage = Storage(app_) app_.storage = Storage(app_)
@ -274,15 +284,12 @@ async def api_index(app_):
to_find = {} to_find = {}
found = [] found = []
with open('discord_endpoints.txt') as fd: with open("discord_endpoints.txt") as fd:
for line in fd.readlines(): for line in fd.readlines():
components = line.split(' ') components = line.split(" ")
components = list(filter( components = list(filter(bool, components))
bool,
components
))
name, method, path = components name, method, path = components
path = f'/api/v6{path.strip()}' path = f"/api/v6{path.strip()}"
method = method.strip() method = method.strip()
to_find[(path, method)] = name to_find[(path, method)] = name
@ -290,17 +297,17 @@ async def api_index(app_):
path = rule.rule path = rule.rule
# convert the path to the discord_endpoints file's style # convert the path to the discord_endpoints file's style
path = path.replace('_', '.') path = path.replace("_", ".")
path = path.replace('<', '{') path = path.replace("<", "{")
path = path.replace('>', '}') path = path.replace(">", "}")
path = path.replace('int:', '') path = path.replace("int:", "")
# change our parameters into user.id # change our parameters into user.id
path = path.replace('member.id', 'user.id') path = path.replace("member.id", "user.id")
path = path.replace('banned.id', 'user.id') path = path.replace("banned.id", "user.id")
path = path.replace('target.id', 'user.id') path = path.replace("target.id", "user.id")
path = path.replace('other.id', 'user.id') path = path.replace("other.id", "user.id")
path = path.replace('peer.id', 'user.id') path = path.replace("peer.id", "user.id")
methods = rule.methods methods = rule.methods
@ -317,10 +324,15 @@ async def api_index(app_):
percentage = (len(found) / len(api)) * 100 percentage = (len(found) / len(api)) * 100
percentage = round(percentage, 2) percentage = round(percentage, 2)
log.debug('API compliance: {} out of {} ({} missing), {}% compliant', log.debug(
len(found), len(api), len(missing), percentage) "API compliance: {} out of {} ({} missing), {}% compliant",
len(found),
len(api),
len(missing),
percentage,
)
log.debug('missing: {}', missing) log.debug("missing: {}", missing)
async def post_app_start(app_): async def post_app_start(app_):
@ -332,7 +344,7 @@ async def post_app_start(app_):
def start_websocket(host, port, ws_handler) -> asyncio.Future: def start_websocket(host, port, ws_handler) -> asyncio.Future:
"""Start a websocket. Returns the websocket future""" """Start a websocket. Returns the websocket future"""
log.info(f'starting websocket at {host} {port}') log.info(f"starting websocket at {host} {port}")
async def _wrapper(ws, url): async def _wrapper(ws, url):
# We wrap the main websocket_handler # We wrap the main websocket_handler
@ -348,7 +360,7 @@ async def app_before_serving():
Also sets up the websocket handlers. Also sets up the websocket handlers.
""" """
log.info('opening db') log.info("opening db")
await init_app_db(app) await init_app_db(app)
app.session = ClientSession() app.session = ClientSession()
@ -359,8 +371,7 @@ async def app_before_serving():
# start gateway websocket # start gateway websocket
# voice websocket is handled by the voice server # voice websocket is handled by the voice server
ws_fut = start_websocket( ws_fut = start_websocket(
app.config['WS_HOST'], app.config['WS_PORT'], app.config["WS_HOST"], app.config["WS_PORT"], websocket_handler
websocket_handler
) )
await ws_fut await ws_fut
@ -379,7 +390,7 @@ async def app_after_serving():
app.sched.close() app.sched.close()
log.info('closing db') log.info("closing db")
await app.db.close() await app.db.close()
@ -391,24 +402,23 @@ async def handle_litecord_err(err):
ejson = {} ejson = {}
try: try:
ejson['code'] = err.error_code ejson["code"] = err.error_code
except AttributeError: except AttributeError:
pass pass
log.warning('error: {} {!r}', err.status_code, err.message) log.warning("error: {} {!r}", err.status_code, err.message)
return jsonify({ return (
'error': True, jsonify(
'status': err.status_code, {"error": True, "status": err.status_code, "message": err.message, **ejson}
'message': err.message, ),
**ejson err.status_code,
}), err.status_code )
@app.errorhandler(500) @app.errorhandler(500)
async def handle_500(err): async def handle_500(err):
return jsonify({ return (
'error': True, jsonify({"error": True, "message": repr(err), "internal_server_error": True}),
'message': repr(err), 500,
'internal_server_error': True, )
}), 500

View File

@ -20,10 +20,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from setuptools import setup from setuptools import setup
setup( setup(
name='litecord', name="litecord",
version='0.0.1', version="0.0.1",
description='Implementation of the Discord API', description="Implementation of the Discord API",
url='https://litecord.top', url="https://litecord.top",
author='Luna Mendes', author="Luna Mendes",
python_requires='>=3.7' python_requires=">=3.7",
) )

View File

@ -19,13 +19,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import secrets import secrets
def email() -> str: def email() -> str:
return f'{secrets.token_hex(5)}@{secrets.token_hex(5)}.com' return f"{secrets.token_hex(5)}@{secrets.token_hex(5)}.com"
class TestClient: class TestClient:
"""Test client that wraps pytest-sanic's TestClient and a test """Test client that wraps pytest-sanic's TestClient and a test
user and adds authorization headers to test requests.""" user and adds authorization headers to test requests."""
def __init__(self, test_cli, test_user): def __init__(self, test_cli, test_user):
self.cli = test_cli self.cli = test_cli
self.app = test_cli.app self.app = test_cli.app
@ -37,31 +39,31 @@ class TestClient:
def _inject_auth(self, kwargs: dict) -> list: def _inject_auth(self, kwargs: dict) -> list:
"""Inject the test user's API key into the test request before """Inject the test user's API key into the test request before
passing the request on to the underlying TestClient.""" passing the request on to the underlying TestClient."""
headers = kwargs.get('headers', {}) headers = kwargs.get("headers", {})
headers['authorization'] = self.user['token'] headers["authorization"] = self.user["token"]
return headers return headers
async def get(self, *args, **kwargs): async def get(self, *args, **kwargs):
"""Send a GET request.""" """Send a GET request."""
kwargs['headers'] = self._inject_auth(kwargs) kwargs["headers"] = self._inject_auth(kwargs)
return await self.cli.get(*args, **kwargs) return await self.cli.get(*args, **kwargs)
async def post(self, *args, **kwargs): async def post(self, *args, **kwargs):
"""Send a POST request.""" """Send a POST request."""
kwargs['headers'] = self._inject_auth(kwargs) kwargs["headers"] = self._inject_auth(kwargs)
return await self.cli.post(*args, **kwargs) return await self.cli.post(*args, **kwargs)
async def put(self, *args, **kwargs): async def put(self, *args, **kwargs):
"""Send a POST request.""" """Send a POST request."""
kwargs['headers'] = self._inject_auth(kwargs) kwargs["headers"] = self._inject_auth(kwargs)
return await self.cli.put(*args, **kwargs) return await self.cli.put(*args, **kwargs)
async def patch(self, *args, **kwargs): async def patch(self, *args, **kwargs):
"""Send a PATCH request.""" """Send a PATCH request."""
kwargs['headers'] = self._inject_auth(kwargs) kwargs["headers"] = self._inject_auth(kwargs)
return await self.cli.patch(*args, **kwargs) return await self.cli.patch(*args, **kwargs)
async def delete(self, *args, **kwargs): async def delete(self, *args, **kwargs):
"""Send a DELETE request.""" """Send a DELETE request."""
kwargs['headers'] = self._inject_auth(kwargs) kwargs["headers"] = self._inject_auth(kwargs)
return await self.cli.delete(*args, **kwargs) return await self.cli.delete(*args, **kwargs)

View File

@ -36,22 +36,22 @@ from litecord.blueprints.auth import make_token
from litecord.blueprints.users import delete_user from litecord.blueprints.users import delete_user
@pytest.fixture(name='app') @pytest.fixture(name="app")
def _test_app(unused_tcp_port, event_loop): def _test_app(unused_tcp_port, event_loop):
set_blueprints(main_app) set_blueprints(main_app)
main_app.config['_testing'] = True main_app.config["_testing"] = True
# reassign an unused tcp port for websockets # reassign an unused tcp port for websockets
# since the config might give a used one. # since the config might give a used one.
ws_port = unused_tcp_port ws_port = unused_tcp_port
main_app.config['IS_SSL'] = False main_app.config["IS_SSL"] = False
main_app.config['WS_PORT'] = ws_port main_app.config["WS_PORT"] = ws_port
main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}' main_app.config["WEBSOCKET_URL"] = f"localhost:{ws_port}"
# testing user creations requires hardcoding this to true # testing user creations requires hardcoding this to true
# on testing # on testing
main_app.config['REGISTRATIONS'] = True main_app.config["REGISTRATIONS"] = True
# make sure we're calling the before_serving hooks # make sure we're calling the before_serving hooks
event_loop.run_until_complete(main_app.startup()) event_loop.run_until_complete(main_app.startup())
@ -63,11 +63,12 @@ def _test_app(unused_tcp_port, event_loop):
event_loop.run_until_complete(main_app.shutdown()) event_loop.run_until_complete(main_app.shutdown())
@pytest.fixture(name='test_cli') @pytest.fixture(name="test_cli")
def _test_cli(app): def _test_cli(app):
"""Give a test client.""" """Give a test client."""
return app.test_client() return app.test_client()
# code shamelessly stolen from my elixire mr # code shamelessly stolen from my elixire mr
# https://gitlab.com/elixire/elixire/merge_requests/52 # https://gitlab.com/elixire/elixire/merge_requests/52
async def _user_fixture_setup(app): async def _user_fixture_setup(app):
@ -76,21 +77,26 @@ async def _user_fixture_setup(app):
user_email = email() user_email = email()
user_id, pwd_hash = await create_user( user_id, pwd_hash = await create_user(
username, user_email, password, app.db, app.loop) username, user_email, password, app.db, app.loop
)
# generate a token for api access # generate a token for api access
user_token = make_token(user_id, pwd_hash) user_token = make_token(user_id, pwd_hash)
return {'id': user_id, 'token': user_token, return {
'email': user_email, 'username': username, "id": user_id,
'password': password} "token": user_token,
"email": user_email,
"username": username,
"password": password,
}
async def _user_fixture_teardown(app, udata: dict): async def _user_fixture_teardown(app, udata: dict):
await delete_user(udata['id'], app_=app) await delete_user(udata["id"], app_=app)
@pytest.fixture(name='test_user') @pytest.fixture(name="test_user")
async def test_user_fixture(app): async def test_user_fixture(app):
"""Yield a randomly generated test user.""" """Yield a randomly generated test user."""
udata = await _user_fixture_setup(app) udata = await _user_fixture_setup(app)
@ -113,18 +119,25 @@ async def test_cli_staff(test_cli):
# same test_cli_user, which isn't acceptable. # same test_cli_user, which isn't acceptable.
app = test_cli.app app = test_cli.app
test_user = await _user_fixture_setup(app) test_user = await _user_fixture_setup(app)
user_id = test_user['id'] user_id = test_user["id"]
# copied from manage.cmd.users.set_user_staff. # copied from manage.cmd.users.set_user_staff.
old_flags = await app.db.fetchval(""" old_flags = await app.db.fetchval(
"""
SELECT flags FROM users WHERE id = $1 SELECT flags FROM users WHERE id = $1
""", user_id) """,
user_id,
)
new_flags = old_flags | UserFlags.staff new_flags = old_flags | UserFlags.staff
await app.db.execute(""" await app.db.execute(
"""
UPDATE users SET flags = $1 WHERE id = $2 UPDATE users SET flags = $1 WHERE id = $2
""", new_flags, user_id) """,
new_flags,
user_id,
)
yield TestClient(test_cli, test_user) yield TestClient(test_cli, test_user)
await _user_fixture_teardown(test_cli.app, test_user) await _user_fixture_teardown(test_cli.app, test_user)

View File

@ -24,24 +24,24 @@ import pytest
from litecord.blueprints.guilds import delete_guild from litecord.blueprints.guilds import delete_guild
from litecord.errors import GuildNotFound from litecord.errors import GuildNotFound
async def _create_guild(test_cli_staff): async def _create_guild(test_cli_staff):
genned_name = secrets.token_hex(6) genned_name = secrets.token_hex(6)
resp = await test_cli_staff.post('/api/v6/guilds', json={ resp = await test_cli_staff.post(
'name': genned_name, "/api/v6/guilds", json={"name": genned_name, "region": None}
'region': None )
})
assert resp.status_code == 200 assert resp.status_code == 200
rjson = await resp.json rjson = await resp.json
assert isinstance(rjson, dict) assert isinstance(rjson, dict)
assert rjson['name'] == genned_name assert rjson["name"] == genned_name
return rjson return rjson
async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False): async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
resp = await test_cli_staff.get(f'/api/v6/admin/guilds/{guild_id}') resp = await test_cli_staff.get(f"/api/v6/admin/guilds/{guild_id}")
if ret_early: if ret_early:
return resp return resp
@ -49,7 +49,7 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
assert resp.status_code == 200 assert resp.status_code == 200
rjson = await resp.json rjson = await resp.json
assert isinstance(rjson, dict) assert isinstance(rjson, dict)
assert rjson['id'] == guild_id assert rjson["id"] == guild_id
return rjson return rjson
@ -58,7 +58,7 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
async def test_guild_fetch(test_cli_staff): async def test_guild_fetch(test_cli_staff):
"""Test the creation and fetching of a guild via the Admin API.""" """Test the creation and fetching of a guild via the Admin API."""
rjson = await _create_guild(test_cli_staff) rjson = await _create_guild(test_cli_staff)
guild_id = rjson['id'] guild_id = rjson["id"]
try: try:
await _fetch_guild(test_cli_staff, guild_id) await _fetch_guild(test_cli_staff, guild_id)
@ -70,8 +70,8 @@ async def test_guild_fetch(test_cli_staff):
async def test_guild_update(test_cli_staff): async def test_guild_update(test_cli_staff):
"""Test the update of a guild via the Admin API.""" """Test the update of a guild via the Admin API."""
rjson = await _create_guild(test_cli_staff) rjson = await _create_guild(test_cli_staff)
guild_id = rjson['id'] guild_id = rjson["id"]
assert not rjson['unavailable'] assert not rjson["unavailable"]
try: try:
# I believe setting up an entire gateway client registered to the guild # I believe setting up an entire gateway client registered to the guild
@ -79,19 +79,17 @@ async def test_guild_update(test_cli_staff):
# testing them. Yes, I know its a bad idea, but if someone has an easier # testing them. Yes, I know its a bad idea, but if someone has an easier
# way to write that, do send an MR. # way to write that, do send an MR.
resp = await test_cli_staff.patch( resp = await test_cli_staff.patch(
f'/api/v6/admin/guilds/{guild_id}', f"/api/v6/admin/guilds/{guild_id}", json={"unavailable": True}
json={ )
'unavailable': True
})
assert resp.status_code == 200 assert resp.status_code == 200
rjson = await resp.json rjson = await resp.json
assert isinstance(rjson, dict) assert isinstance(rjson, dict)
assert rjson['id'] == guild_id assert rjson["id"] == guild_id
assert rjson['unavailable'] assert rjson["unavailable"]
rjson = await _fetch_guild(test_cli_staff, guild_id) rjson = await _fetch_guild(test_cli_staff, guild_id)
assert rjson['unavailable'] assert rjson["unavailable"]
finally: finally:
await delete_guild(int(guild_id), app_=test_cli_staff.app) await delete_guild(int(guild_id), app_=test_cli_staff.app)
@ -100,20 +98,19 @@ async def test_guild_update(test_cli_staff):
async def test_guild_delete(test_cli_staff): async def test_guild_delete(test_cli_staff):
"""Test the update of a guild via the Admin API.""" """Test the update of a guild via the Admin API."""
rjson = await _create_guild(test_cli_staff) rjson = await _create_guild(test_cli_staff)
guild_id = rjson['id'] guild_id = rjson["id"]
try: try:
resp = await test_cli_staff.delete(f'/api/v6/admin/guilds/{guild_id}') resp = await test_cli_staff.delete(f"/api/v6/admin/guilds/{guild_id}")
assert resp.status_code == 204 assert resp.status_code == 204
resp = await _fetch_guild( resp = await _fetch_guild(test_cli_staff, guild_id, ret_early=True)
test_cli_staff, guild_id, ret_early=True)
assert resp.status_code == 404 assert resp.status_code == 404
rjson = await resp.json rjson = await resp.json
assert isinstance(rjson, dict) assert isinstance(rjson, dict)
assert rjson['error'] assert rjson["error"]
assert rjson['code'] == GuildNotFound.error_code assert rjson["code"] == GuildNotFound.error_code
finally: finally:
await delete_guild(int(guild_id), app_=test_cli_staff.app) await delete_guild(int(guild_id), app_=test_cli_staff.app)

View File

@ -21,7 +21,7 @@ import pytest
async def _get_invs(test_cli): async def _get_invs(test_cli):
resp = await test_cli.get('/api/v6/admin/instance/invites') resp = await test_cli.get("/api/v6/admin/instance/invites")
assert resp.status_code == 200 assert resp.status_code == 200
rjson = await resp.json rjson = await resp.json
@ -39,7 +39,7 @@ async def test_get_invites(test_cli_staff):
async def test_inv_delete_invalid(test_cli_staff): async def test_inv_delete_invalid(test_cli_staff):
"""Test errors happen when trying to delete a """Test errors happen when trying to delete a
non-existing instance invite.""" non-existing instance invite."""
resp = await test_cli_staff.delete('/api/v6/admin/instance/invites/aaaaaa') resp = await test_cli_staff.delete("/api/v6/admin/instance/invites/aaaaaa")
assert resp.status_code == 404 assert resp.status_code == 404
@ -48,21 +48,20 @@ async def test_inv_delete_invalid(test_cli_staff):
async def test_create_invite(test_cli_staff): async def test_create_invite(test_cli_staff):
"""Test the creation of an instance invite, then listing it, """Test the creation of an instance invite, then listing it,
then deleting it.""" then deleting it."""
resp = await test_cli_staff.put('/api/v6/admin/instance/invites', json={ resp = await test_cli_staff.put(
'max_uses': 1 "/api/v6/admin/instance/invites", json={"max_uses": 1}
}) )
assert resp.status_code == 200 assert resp.status_code == 200
rjson = await resp.json rjson = await resp.json
assert isinstance(rjson, dict) assert isinstance(rjson, dict)
code = rjson['code'] code = rjson["code"]
# assert that the invite is in the list # assert that the invite is in the list
invites = await _get_invs(test_cli_staff) invites = await _get_invs(test_cli_staff)
assert any(inv['code'] == code for inv in invites) assert any(inv["code"] == code for inv in invites)
# delete it, and assert it worked # delete it, and assert it worked
resp = await test_cli_staff.delete( resp = await test_cli_staff.delete(f"/api/v6/admin/instance/invites/{code}")
f'/api/v6/admin/instance/invites/{code}')
assert resp.status_code == 204 assert resp.status_code == 204

Some files were not shown because too many files have changed in this diff Show More