black fmt pass

This commit is contained in:
Luna 2019-10-25 07:27:50 -03:00
parent 0bc4b1ba3f
commit 83a1c1ae29
109 changed files with 5575 additions and 4775 deletions

View File

@ -17,13 +17,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
MODE = 'CI'
MODE = "CI"
class Config:
"""Default configuration values for litecord."""
MAIN_URL = 'localhost:1'
NAME = 'gitlab ci'
MAIN_URL = "localhost:1"
NAME = "gitlab ci"
# Enable debug logging?
DEBUG = False
@ -37,11 +38,11 @@ class Config:
# Set this url to somewhere *your users*
# will hit the websocket.
# e.g 'gateway.example.com' for reverse proxies.
WEBSOCKET_URL = 'localhost:5001'
WEBSOCKET_URL = "localhost:5001"
# Where to host the websocket?
# (a local address the server will bind to)
WS_HOST = 'localhost'
WS_HOST = "localhost"
WS_PORT = 5001
# Postgres credentials
@ -51,10 +52,10 @@ class Config:
class Development(Config):
DEBUG = True
POSTGRES = {
'host': 'localhost',
'user': 'litecord',
'password': '123',
'database': 'litecord',
"host": "localhost",
"user": "litecord",
"password": "123",
"database": "litecord",
}
@ -66,8 +67,4 @@ class Production(Config):
class CI(Config):
DEBUG = True
POSTGRES = {
'host': 'postgres',
'user': 'postgres',
'password': ''
}
POSTGRES = {"host": "postgres", "user": "postgres", "password": ""}

View File

@ -17,16 +17,17 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
MODE = 'Development'
MODE = "Development"
class Config:
"""Default configuration values for litecord."""
#: Main URL of the instance.
MAIN_URL = 'discordapp.io'
MAIN_URL = "discordapp.io"
#: Name of the instance
NAME = 'Litecord/Nya'
NAME = "Litecord/Nya"
#: Enable debug logging?
DEBUG = False
@ -45,17 +46,17 @@ class Config:
# Set this url to somewhere *your users*
# will hit the websocket.
# e.g 'gateway.example.com' for reverse proxies.
WEBSOCKET_URL = 'localhost:5001'
WEBSOCKET_URL = "localhost:5001"
#: Where to host the websocket?
# (a local address the server will bind to)
WS_HOST = '0.0.0.0'
WS_HOST = "0.0.0.0"
WS_PORT = 5001
#: Mediaproxy URL on the internet
# mediaproxy is made to prevent client IPs being leaked.
# None is a valid value if you don't want to deploy mediaproxy.
MEDIA_PROXY = 'localhost:5002'
MEDIA_PROXY = "localhost:5002"
#: Postgres credentials
POSTGRES = {}
@ -65,10 +66,10 @@ class Development(Config):
DEBUG = True
POSTGRES = {
'host': 'localhost',
'user': 'litecord',
'password': '123',
'database': 'litecord',
"host": "localhost",
"user": "litecord",
"password": "123",
"database": "litecord",
}
@ -77,8 +78,8 @@ class Production(Config):
IS_SSL = True
POSTGRES = {
'host': 'some_production_postgres',
'user': 'some_production_user',
'password': 'some_production_password',
'database': 'litecord_or_anything_else_really',
"host": "some_production_postgres",
"user": "some_production_user",
"password": "some_production_password",
"database": "litecord_or_anything_else_really",
}

View File

@ -16,4 +16,3 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

View File

@ -19,42 +19,33 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from litecord.enums import Feature, UserFlags
VOICE_SERVER = {
'hostname': {'type': 'string', 'maxlength': 255, 'required': True}
}
VOICE_SERVER = {"hostname": {"type": "string", "maxlength": 255, "required": True}}
VOICE_REGION = {
'id': {'type': 'string', 'maxlength': 255, 'required': True},
'name': {'type': 'string', 'maxlength': 255, 'required': True},
'vip': {'type': 'boolean', 'default': False},
'deprecated': {'type': 'boolean', 'default': False},
'custom': {'type': 'boolean', 'default': False},
"id": {"type": "string", "maxlength": 255, "required": True},
"name": {"type": "string", "maxlength": 255, "required": True},
"vip": {"type": "boolean", "default": False},
"deprecated": {"type": "boolean", "default": False},
"custom": {"type": "boolean", "default": False},
}
FEATURES = {
'features': {
'type': 'list', 'required': True,
"features": {
"type": "list",
"required": True,
# using Feature doesn't seem to work with a "not callable" error.
'schema': {'coerce': lambda x: Feature(x)}
"schema": {"coerce": lambda x: Feature(x)},
}
}
USER_CREATE = {
'username': {'type': 'username', 'required': True},
'email': {'type': 'email', 'required': True},
'password': {'type': 'string', 'minlength': 5, 'required': True},
"username": {"type": "username", "required": True},
"email": {"type": "email", "required": True},
"password": {"type": "string", "minlength": 5, "required": True},
}
INSTANCE_INVITE = {
'max_uses': {'type': 'integer', 'required': True}
}
INSTANCE_INVITE = {"max_uses": {"type": "integer", "required": True}}
GUILD_UPDATE = {
'unavailable': {'type': 'boolean', 'required': False}
}
GUILD_UPDATE = {"unavailable": {"type": "boolean", "required": False}}
USER_UPDATE = {
'flags': {'required': False, 'coerce': UserFlags.from_int}
}
USER_UPDATE = {"flags": {"required": False, "coerce": UserFlags.from_int}}

View File

@ -55,44 +55,50 @@ async def raw_token_check(token: str, db=None) -> int:
# just try by fragments instead of
# unpacking
fragments = token.split('.')
fragments = token.split(".")
user_id = fragments[0]
try:
user_id = base64.b64decode(user_id.encode())
user_id = int(user_id)
except (ValueError, binascii.Error):
raise Unauthorized('Invalid user ID type')
raise Unauthorized("Invalid user ID type")
pwd_hash = await db.fetchval("""
pwd_hash = await db.fetchval(
"""
SELECT password_hash
FROM users
WHERE id = $1
""", user_id)
""",
user_id,
)
if not pwd_hash:
raise Unauthorized('User ID not found')
raise Unauthorized("User ID not found")
signer = TimestampSigner(pwd_hash)
try:
signer.unsign(token)
log.debug('login for uid {} successful', user_id)
log.debug("login for uid {} successful", user_id)
# update the user's last_session field
# so that we can keep an exact track of activity,
# even on long-lived single sessions (that can happen
# with people leaving their clients open forever)
await db.execute("""
await db.execute(
"""
UPDATE users
SET last_session = (now() at time zone 'utc')
WHERE id = $1
""", user_id)
""",
user_id,
)
return user_id
except BadSignature:
log.warning('token failed for uid {}', user_id)
raise Forbidden('Invalid token')
log.warning("token failed for uid {}", user_id)
raise Forbidden("Invalid token")
async def token_check() -> int:
@ -104,12 +110,12 @@ async def token_check() -> int:
pass
try:
token = request.headers['Authorization']
token = request.headers["Authorization"]
except KeyError:
raise Unauthorized('No token provided')
raise Unauthorized("No token provided")
if token.startswith('Bot '):
token = token.replace('Bot ', '')
if token.startswith("Bot "):
token = token.replace("Bot ", "")
user_id = await raw_token_check(token)
request.user_id = user_id
@ -120,15 +126,18 @@ async def admin_check() -> int:
"""Check if the user is an admin."""
user_id = await token_check()
flags = await app.db.fetchval("""
flags = await app.db.fetchval(
"""
SELECT flags
FROM users
WHERE id = $1
""", user_id)
""",
user_id,
)
flags = UserFlags.from_int(flags)
if not flags.is_staff:
raise Unauthorized('you are not staff')
raise Unauthorized("you are not staff")
return user_id
@ -138,9 +147,7 @@ async def hash_data(data: str, loop=None) -> str:
loop = loop or app.loop
buf = data.encode()
hashed = await loop.run_in_executor(
None, bcrypt.hashpw, buf, bcrypt.gensalt(14)
)
hashed = await loop.run_in_executor(None, bcrypt.hashpw, buf, bcrypt.gensalt(14))
return hashed.decode()
@ -148,22 +155,28 @@ async def hash_data(data: str, loop=None) -> str:
async def check_username_usage(username: str, db=None):
"""Raise an error if too many people are with the same username."""
db = db or app.db
same_username = await db.fetchval("""
same_username = await db.fetchval(
"""
SELECT COUNT(*)
FROM users
WHERE username = $1
""", username)
""",
username,
)
if same_username > 9000:
raise BadRequest('Too many people.', {
'username': 'Too many people used the same username. '
'Please choose another'
})
raise BadRequest(
"Too many people.",
{
"username": "Too many people used the same username. "
"Please choose another"
},
)
def _raw_discrim() -> str:
new_discrim = randint(1, 9999)
new_discrim = '%04d' % new_discrim
new_discrim = "%04d" % new_discrim
return new_discrim
@ -186,11 +199,15 @@ async def roll_discrim(username: str, *, db=None) -> Optional[str]:
discrim = _raw_discrim()
# check if anyone is with it
res = await db.fetchval("""
res = await db.fetchval(
"""
SELECT id
FROM users
WHERE username = $1 AND discriminator = $2
""", username, discrim)
""",
username,
discrim,
)
# if no user is found with the (username, discrim)
# pair, then this is unique! return it.
@ -200,8 +217,9 @@ async def roll_discrim(username: str, *, db=None) -> Optional[str]:
return None
async def create_user(username: str, email: str, password: str,
db=None, loop=None) -> Tuple[int, str]:
async def create_user(
username: str, email: str, password: str, db=None, loop=None
) -> Tuple[int, str]:
"""Create a single user.
Generates a distriminator and other information. You can fetch the user
@ -214,20 +232,28 @@ async def create_user(username: str, email: str, password: str,
new_discrim = await roll_discrim(username, db=db)
if new_discrim is None:
raise BadRequest('Unable to register.', {
'username': 'Too many people are with this username.'
})
raise BadRequest(
"Unable to register.",
{"username": "Too many people are with this username."},
)
pwd_hash = await hash_data(password, loop)
try:
await db.execute("""
await db.execute(
"""
INSERT INTO users
(id, email, username, discriminator, password_hash)
VALUES
($1, $2, $3, $4, $5)
""", new_id, email, username, new_discrim, pwd_hash)
""",
new_id,
email,
username,
new_discrim,
pwd_hash,
)
except UniqueViolationError:
raise BadRequest('Email already used.')
raise BadRequest("Email already used.")
return new_id, pwd_hash

View File

@ -34,7 +34,21 @@ from .static import bp as static
from .attachments import bp as attachments
from .dm_channels import bp as dm_channels
__all__ = ['gateway', 'auth', 'users', 'guilds', 'channels',
'webhooks', 'science', 'voice', 'invites', 'relationships',
'dms', 'icons', 'nodeinfo', 'static', 'attachments',
'dm_channels']
__all__ = [
"gateway",
"auth",
"users",
"guilds",
"channels",
"webhooks",
"science",
"voice",
"invites",
"relationships",
"dms",
"icons",
"nodeinfo",
"static",
"attachments",
"dm_channels",
]

View File

@ -23,4 +23,4 @@ from .guilds import bp as guilds
from .users import bp as users
from .instance_invites import bp as instance_invites
__all__ = ['voice', 'features', 'guilds', 'users', 'instance_invites']
__all__ = ["voice", "features", "guilds", "users", "instance_invites"]

View File

@ -25,45 +25,53 @@ from litecord.errors import BadRequest
from litecord.schemas import validate
from litecord.admin_schemas import FEATURES
bp = Blueprint('features_admin', __name__)
bp = Blueprint("features_admin", __name__)
async def _features_from_req() -> List[str]:
j = validate(await request.get_json(), FEATURES)
return [feature.value for feature in j['features']]
return [feature.value for feature in j["features"]]
async def _features(guild_id: int):
return jsonify({
'features': await app.storage.guild_features(guild_id)
})
return jsonify({"features": await app.storage.guild_features(guild_id)})
async def _update_features(guild_id: int, features: list):
if 'VANITY_URL' not in features:
if "VANITY_URL" not in features:
existing_inv = await app.storage.vanity_invite(guild_id)
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM vanity_invites
WHERE guild_id = $1
""", guild_id)
""",
guild_id,
)
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM invites
WHERE code = $1
""", existing_inv)
""",
existing_inv,
)
await app.db.execute("""
await app.db.execute(
"""
UPDATE guilds
SET features = $1
WHERE id = $2
""", features, guild_id)
""",
features,
guild_id,
)
guild = await app.storage.get_guild_full(guild_id)
await app.dispatcher.dispatch('guild', guild_id, 'GUILD_UPDATE', guild)
await app.dispatcher.dispatch("guild", guild_id, "GUILD_UPDATE", guild)
@bp.route('/<int:guild_id>/features', methods=['PATCH'])
@bp.route("/<int:guild_id>/features", methods=["PATCH"])
async def replace_features(guild_id: int):
"""Replace the feature list in a guild"""
await admin_check()
@ -76,7 +84,7 @@ async def replace_features(guild_id: int):
return await _features(guild_id)
@bp.route('/<int:guild_id>/features', methods=['PUT'])
@bp.route("/<int:guild_id>/features", methods=["PUT"])
async def insert_features(guild_id: int):
"""Insert a feature on a guild."""
await admin_check()
@ -93,7 +101,7 @@ async def insert_features(guild_id: int):
return await _features(guild_id)
@bp.route('/<int:guild_id>/features', methods=['DELETE'])
@bp.route("/<int:guild_id>/features", methods=["DELETE"])
async def remove_features(guild_id: int):
"""Remove a feature from a guild"""
await admin_check()
@ -104,7 +112,7 @@ async def remove_features(guild_id: int):
try:
features.remove(feature)
except ValueError:
raise BadRequest('Trying to remove already removed feature.')
raise BadRequest("Trying to remove already removed feature.")
await _update_features(guild_id, features)
return await _features(guild_id)

View File

@ -25,9 +25,10 @@ from litecord.admin_schemas import GUILD_UPDATE
from litecord.blueprints.guilds import delete_guild
from litecord.errors import GuildNotFound
bp = Blueprint('guilds_admin', __name__)
bp = Blueprint("guilds_admin", __name__)
@bp.route('/<int:guild_id>', methods=['GET'])
@bp.route("/<int:guild_id>", methods=["GET"])
async def get_guild(guild_id: int):
"""Get a basic guild payload."""
await admin_check()
@ -40,7 +41,7 @@ async def get_guild(guild_id: int):
return jsonify(guild)
@bp.route('/<int:guild_id>', methods=['PATCH'])
@bp.route("/<int:guild_id>", methods=["PATCH"])
async def update_guild(guild_id: int):
await admin_check()
@ -48,13 +49,13 @@ async def update_guild(guild_id: int):
# TODO: what happens to the other guild attributes when its
# unavailable? do they vanish?
old_unavailable = app.guild_store.get(guild_id, 'unavailable')
new_unavailable = j.get('unavailable', old_unavailable)
old_unavailable = app.guild_store.get(guild_id, "unavailable")
new_unavailable = j.get("unavailable", old_unavailable)
# always set unavailable status since new_unavailable will be
# old_unavailable when not provided, so we don't need to check if
# j.unavailable is there
app.guild_store.set(guild_id, 'unavailable', j['unavailable'])
app.guild_store.set(guild_id, "unavailable", j["unavailable"])
guild = await app.storage.get_guild(guild_id)
@ -62,17 +63,17 @@ async def update_guild(guild_id: int):
if old_unavailable and not new_unavailable:
# guild became available
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_CREATE', guild)
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild)
else:
# guild became unavailable
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_DELETE', guild)
await app.dispatcher.dispatch_guild(guild_id, "GUILD_DELETE", guild)
return jsonify(guild)
@bp.route('/<int:guild_id>', methods=['DELETE'])
@bp.route("/<int:guild_id>", methods=["DELETE"])
async def delete_guild_as_admin(guild_id):
"""Delete a single guild via the admin API without ownership checks."""
await admin_check()
await delete_guild(guild_id)
return '', 204
return "", 204

View File

@ -27,13 +27,13 @@ from litecord.types import timestamp_
from litecord.schemas import validate
from litecord.admin_schemas import INSTANCE_INVITE
bp = Blueprint('instance_invites', __name__)
bp = Blueprint("instance_invites", __name__)
ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits
async def _gen_inv() -> str:
"""Generate an invite code"""
return ''.join(choice(ALPHABET) for _ in range(6))
return "".join(choice(ALPHABET) for _ in range(6))
async def gen_inv(ctx) -> str:
@ -41,11 +41,14 @@ async def gen_inv(ctx) -> str:
for _ in range(10):
possible_inv = await _gen_inv()
created_at = await ctx.db.fetchval("""
created_at = await ctx.db.fetchval(
"""
SELECT created_at
FROM instance_invites
WHERE code = $1
""", possible_inv)
""",
possible_inv,
)
if created_at is None:
return possible_inv
@ -53,57 +56,71 @@ async def gen_inv(ctx) -> str:
return None
@bp.route('', methods=['GET'])
@bp.route("", methods=["GET"])
async def _all_instance_invites():
await admin_check()
rows = await app.db.fetch("""
rows = await app.db.fetch(
"""
SELECT code, created_at, uses, max_uses
FROM instance_invites
""")
"""
)
rows = [dict(row) for row in rows]
for row in rows:
row['created_at'] = timestamp_(row['created_at'])
row["created_at"] = timestamp_(row["created_at"])
return jsonify(rows)
@bp.route('', methods=['PUT'])
@bp.route("", methods=["PUT"])
async def _create_invite():
await admin_check()
code = await gen_inv(app)
if code is None:
return 'failed to make invite', 500
return "failed to make invite", 500
j = validate(await request.get_json(), INSTANCE_INVITE)
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO instance_invites (code, max_uses)
VALUES ($1, $2)
""", code, j['max_uses'])
""",
code,
j["max_uses"],
)
inv = dict(await app.db.fetchrow("""
inv = dict(
await app.db.fetchrow(
"""
SELECT code, created_at, uses, max_uses
FROM instance_invites
WHERE code = $1
""", code))
""",
code,
)
)
return jsonify(dict(inv))
@bp.route('/<invite>', methods=['DELETE'])
@bp.route("/<invite>", methods=["DELETE"])
async def _del_invite(invite: str):
await admin_check()
res = await app.db.execute("""
res = await app.db.execute(
"""
DELETE FROM instance_invites
WHERE code = $1
""", invite)
""",
invite,
)
if res.lower() == 'delete 0':
return 'invite not found', 404
if res.lower() == "delete 0":
return "invite not found", 404
return '', 204
return "", 204

View File

@ -25,24 +25,21 @@ from litecord.schemas import validate
from litecord.admin_schemas import USER_CREATE, USER_UPDATE
from litecord.errors import BadRequest, Forbidden
from litecord.utils import async_map
from litecord.blueprints.users import (
delete_user, user_disconnect, mass_user_update
)
from litecord.blueprints.users import delete_user, user_disconnect, mass_user_update
from litecord.enums import UserFlags
bp = Blueprint('users_admin', __name__)
bp = Blueprint("users_admin", __name__)
@bp.route('', methods=['POST'])
@bp.route("", methods=["POST"])
async def _create_user():
await admin_check()
j = validate(await request.get_json(), USER_CREATE)
user_id, _ = await create_user(j['username'], j['email'], j['password'])
user_id, _ = await create_user(j["username"], j["email"], j["password"])
return jsonify(
await app.storage.get_user(user_id)
)
return jsonify(await app.storage.get_user(user_id))
def args_try(args: dict, typ, field: str, default):
@ -51,29 +48,29 @@ def args_try(args: dict, typ, field: str, default):
try:
return typ(args.get(field, default))
except (TypeError, ValueError):
raise BadRequest(f'invalid {field} value')
raise BadRequest(f"invalid {field} value")
@bp.route('', methods=['GET'])
@bp.route("", methods=["GET"])
async def _search_users():
await admin_check()
args = request.args
username, discrim = args.get('username'), args.get('discriminator')
username, discrim = args.get("username"), args.get("discriminator")
per_page = args_try(args, int, 'per_page', 20)
page = args_try(args, int, 'page', 0)
per_page = args_try(args, int, "per_page", 20)
page = args_try(args, int, "page", 0)
if page < 0:
raise BadRequest('invalid page number')
raise BadRequest("invalid page number")
if per_page > 50:
raise BadRequest('invalid per page number')
raise BadRequest("invalid per page number")
# any of those must be available.
if not any((username, discrim)):
raise BadRequest('must insert username or discrim')
raise BadRequest("must insert username or discrim")
wheres, args = [], []
@ -82,29 +79,31 @@ async def _search_users():
args.append(username)
if discrim:
wheres.append(f'discriminator = ${len(args) + 2}')
wheres.append(f"discriminator = ${len(args) + 2}")
args.append(discrim)
where_tot = 'WHERE ' if args else ''
where_tot += ' AND '.join(wheres)
where_tot = "WHERE " if args else ""
where_tot += " AND ".join(wheres)
rows = await app.db.fetch(f"""
rows = await app.db.fetch(
f"""
SELECT id
FROM users
{where_tot}
ORDER BY id ASC
LIMIT {per_page}
OFFSET ($1 * {per_page})
""", page, *args)
rows = [r['id'] for r in rows]
return jsonify(
await async_map(app.storage.get_user, rows)
""",
page,
*args,
)
rows = [r["id"] for r in rows]
@bp.route('/<int:user_id>', methods=['DELETE'])
return jsonify(await async_map(app.storage.get_user, rows))
@bp.route("/<int:user_id>", methods=["DELETE"])
async def _delete_single_user(user_id: int):
await admin_check()
@ -115,13 +114,10 @@ async def _delete_single_user(user_id: int):
new_user = await app.storage.get_user(user_id)
return jsonify({
'old': old_user,
'new': new_user
})
return jsonify({"old": old_user, "new": new_user})
@bp.route('/<int:user_id>', methods=['PATCH'])
@bp.route("/<int:user_id>", methods=["PATCH"])
async def patch_user(user_id: int):
await admin_check()
@ -129,21 +125,25 @@ async def patch_user(user_id: int):
# get the original user for flags checking
user = await app.storage.get_user(user_id)
old_flags = UserFlags.from_int(user['flags'])
old_flags = UserFlags.from_int(user["flags"])
# j.flags is already a UserFlags since we coerce it.
if 'flags' in j:
new_flags = j['flags']
if "flags" in j:
new_flags = j["flags"]
# disallow any changes to the staff badge
if new_flags.is_staff != old_flags.is_staff:
raise Forbidden('you can not change a users staff badge')
raise Forbidden("you can not change a users staff badge")
await app.db.execute("""
await app.db.execute(
"""
UPDATE users
SET flags = $1
WHERE id = $2
""", new_flags.value, user_id)
""",
new_flags.value,
user_id,
)
public_user, _ = await mass_user_update(user_id, app)
return jsonify(public_user)

View File

@ -27,10 +27,10 @@ from litecord.admin_schemas import VOICE_SERVER, VOICE_REGION
from litecord.errors import BadRequest
log = Logger(__name__)
bp = Blueprint('voice_admin', __name__)
bp = Blueprint("voice_admin", __name__)
@bp.route('/regions/<region>', methods=['GET'])
@bp.route("/regions/<region>", methods=["GET"])
async def get_region_servers(region):
"""Return a list of all servers for a region."""
await admin_check()
@ -38,18 +38,25 @@ async def get_region_servers(region):
return jsonify(servers)
@bp.route('/regions', methods=['PUT'])
@bp.route("/regions", methods=["PUT"])
async def insert_new_region():
"""Create a voice region."""
await admin_check()
j = validate(await request.get_json(), VOICE_REGION)
j['id'] = j['id'].lower()
j["id"] = j["id"].lower()
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO voice_regions (id, name, vip, deprecated, custom)
VALUES ($1, $2, $3, $4, $5)
""", j['id'], j['name'], j['vip'], j['deprecated'], j['custom'])
""",
j["id"],
j["name"],
j["vip"],
j["deprecated"],
j["custom"],
)
regions = await app.storage.all_voice_regions()
region_count = len(regions)
@ -57,34 +64,41 @@ async def insert_new_region():
# if region count is 1, this is the first region to be created,
# so we should update all guilds to that region
if region_count == 1:
res = await app.db.execute("""
res = await app.db.execute(
"""
UPDATE guilds
SET region = $1
""", j['id'])
""",
j["id"],
)
log.info('updating guilds to first voice region: {}', res)
log.info("updating guilds to first voice region: {}", res)
return jsonify(regions)
@bp.route('/regions/<region>/servers', methods=['PUT'])
@bp.route("/regions/<region>/servers", methods=["PUT"])
async def put_region_server(region):
"""Insert a voice server to a region"""
await admin_check()
j = validate(await request.get_json(), VOICE_SERVER)
try:
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO voice_servers (hostname, region_id)
VALUES ($1, $2)
""", j['hostname'], region)
""",
j["hostname"],
region,
)
except asyncpg.UniqueViolationError:
raise BadRequest('voice server already exists with given hostname')
raise BadRequest("voice server already exists with given hostname")
return '', 204
return "", 204
@bp.route('/regions/<region>/deprecate', methods=['PUT'])
@bp.route("/regions/<region>/deprecate", methods=["PUT"])
async def deprecate_region(region):
"""Deprecate a voice region."""
await admin_check()
@ -92,13 +106,16 @@ async def deprecate_region(region):
# TODO: write this
await app.voice.disable_region(region)
await app.db.execute("""
await app.db.execute(
"""
UPDATE voice_regions
SET deprecated = true
WHERE id = $1
""", region)
""",
region,
)
return '', 204
return "", 204
async def guild_region_check(app_):
@ -112,10 +129,11 @@ async def guild_region_check(app_):
regions = await app_.storage.all_voice_regions()
if not regions:
log.info('region check: no regions to move guilds to')
log.info("region check: no regions to move guilds to")
return
res = await app_.db.execute("""
res = await app_.db.execute(
"""
UPDATE guilds
SET region = (
SELECT id
@ -124,6 +142,8 @@ async def guild_region_check(app_):
LIMIT 1
)
WHERE region = NULL
""", len(regions))
""",
len(regions),
)
log.info('region check: updating guild.region=null: {!r}', res)
log.info("region check: updating guild.region=null: {!r}", res)

View File

@ -24,16 +24,17 @@ from PIL import Image
from litecord.images import resize_gif
bp = Blueprint('attachments', __name__)
ATTACHMENTS = Path.cwd() / 'attachments'
bp = Blueprint("attachments", __name__)
ATTACHMENTS = Path.cwd() / "attachments"
async def _resize_gif(attach_id: int, resized_path: Path,
width: int, height: int) -> str:
async def _resize_gif(
attach_id: int, resized_path: Path, width: int, height: int
) -> str:
"""Resize a GIF attachment."""
# get original gif bytes
orig_path = ATTACHMENTS / f'{attach_id}.gif'
orig_path = ATTACHMENTS / f"{attach_id}.gif"
orig_bytes = orig_path.read_bytes()
# give them and the target size to the
@ -47,10 +48,7 @@ async def _resize_gif(attach_id: int, resized_path: Path,
return str(resized_path)
FORMAT_HARDCODE = {
'jpg': 'jpeg',
'jpe': 'jpeg'
}
FORMAT_HARDCODE = {"jpg": "jpeg", "jpe": "jpeg"}
def to_format(ext: str) -> str:
@ -63,11 +61,10 @@ def to_format(ext: str) -> str:
return ext
async def _resize(image, attach_id: int, ext: str,
width: int, height: int) -> str:
async def _resize(image, attach_id: int, ext: str, width: int, height: int) -> str:
"""Resize an image."""
# check if we have it on the folder
resized_path = ATTACHMENTS / f'{attach_id}_{width}_{height}.{ext}'
resized_path = ATTACHMENTS / f"{attach_id}_{width}_{height}.{ext}"
# keep a str-fied instance since that is what
# we'll return.
@ -81,7 +78,7 @@ async def _resize(image, attach_id: int, ext: str,
# the process is different for gif files because we need
# gifsicle. doing it manually is too troublesome.
if ext == 'gif':
if ext == "gif":
return await _resize_gif(attach_id, resized_path, width, height)
# NOTE: this is the same resize mode for icons.
@ -91,38 +88,42 @@ async def _resize(image, attach_id: int, ext: str,
return resized_path_s
@bp.route('/attachments'
'/<int:channel_id>/<int:message_id>/<filename>',
methods=['GET'])
async def _get_attachment(channel_id: int, message_id: int,
filename: str):
@bp.route(
"/attachments" "/<int:channel_id>/<int:message_id>/<filename>", methods=["GET"]
)
async def _get_attachment(channel_id: int, message_id: int, filename: str):
attach_id = await app.db.fetchval("""
attach_id = await app.db.fetchval(
"""
SELECT id
FROM attachments
WHERE channel_id = $1
AND message_id = $2
AND filename = $3
""", channel_id, message_id, filename)
""",
channel_id,
message_id,
filename,
)
if attach_id is None:
return '', 404
return "", 404
ext = filename.split('.')[-1]
filepath = f'./attachments/{attach_id}.{ext}'
ext = filename.split(".")[-1]
filepath = f"./attachments/{attach_id}.{ext}"
image = Image.open(filepath)
im_width, im_height = image.size
try:
width = int(request.args.get('width', 0)) or im_width
width = int(request.args.get("width", 0)) or im_width
except ValueError:
return '', 400
return "", 400
try:
height = int(request.args.get('height', 0)) or im_height
height = int(request.args.get("height", 0)) or im_height
except ValueError:
return '', 400
return "", 400
# if width and height are the same (happens if they weren't provided)
if width == im_width and height == im_height:

View File

@ -32,7 +32,7 @@ from litecord.snowflake import get_snowflake
from .invites import use_invite
log = Logger(__name__)
bp = Blueprint('auth', __name__)
bp = Blueprint("auth", __name__)
async def check_password(pwd_hash: str, given_password: str) -> bool:
@ -53,141 +53,139 @@ def make_token(user_id, user_pwd_hash) -> str:
return signer.sign(user_id).decode()
@bp.route('/register', methods=['POST'])
@bp.route("/register", methods=["POST"])
async def register():
"""Register a single user."""
enabled = app.config.get('REGISTRATIONS')
enabled = app.config.get("REGISTRATIONS")
if not enabled:
raise BadRequest('Registrations disabled', {
'email': 'Registrations are disabled.'
})
raise BadRequest(
"Registrations disabled", {"email": "Registrations are disabled."}
)
j = await request.get_json()
if not 'password' in j:
if not "password" in j:
# we need a password to generate a token.
# passwords are optional, so
j['password'] = 'default_password'
j["password"] = "default_password"
j = validate(j, REGISTER)
# they're optional
email = j.get('email')
invite = j.get('invite')
email = j.get("email")
invite = j.get("invite")
username, password = j['username'], j['password']
username, password = j["username"], j["password"]
new_id, pwd_hash = await create_user(
username, email, password, app.db
)
new_id, pwd_hash = await create_user(username, email, password, app.db)
if invite:
try:
await use_invite(new_id, invite)
except Exception:
log.exception('failed to use invite for register {} {!r}',
new_id, invite)
log.exception("failed to use invite for register {} {!r}", new_id, invite)
return jsonify({
'token': make_token(new_id, pwd_hash)
})
return jsonify({"token": make_token(new_id, pwd_hash)})
@bp.route('/register_inv', methods=['POST'])
@bp.route("/register_inv", methods=["POST"])
async def _register_with_invite():
data = await request.form
data = validate(await request.form, REGISTER_WITH_INVITE)
invcode = data['invcode']
invcode = data["invcode"]
row = await app.db.fetchrow("""
row = await app.db.fetchrow(
"""
SELECT uses, max_uses
FROM instance_invites
WHERE code = $1
""", invcode)
""",
invcode,
)
if row is None:
raise BadRequest('unknown instance invite')
raise BadRequest("unknown instance invite")
if row['max_uses'] != -1 and row['uses'] >= row['max_uses']:
raise BadRequest('invite expired')
if row["max_uses"] != -1 and row["uses"] >= row["max_uses"]:
raise BadRequest("invite expired")
await app.db.execute("""
await app.db.execute(
"""
UPDATE instance_invites
SET uses = uses + 1
WHERE code = $1
""", invcode)
""",
invcode,
)
user_id, pwd_hash = await create_user(
data['username'], data['email'], data['password'], app.db)
data["username"], data["email"], data["password"], app.db
)
return jsonify({
'token': make_token(user_id, pwd_hash),
'user_id': str(user_id),
})
return jsonify({"token": make_token(user_id, pwd_hash), "user_id": str(user_id)})
@bp.route('/login', methods=['POST'])
@bp.route("/login", methods=["POST"])
async def login():
j = await request.get_json()
email, password = j['email'], j['password']
email, password = j["email"], j["password"]
row = await app.db.fetchrow("""
row = await app.db.fetchrow(
"""
SELECT id, password_hash
FROM users
WHERE email = $1
""", email)
""",
email,
)
if not row:
return jsonify({'email': ['User not found.']}), 401
return jsonify({"email": ["User not found."]}), 401
user_id, pwd_hash = row
if not await check_password(pwd_hash, password):
return jsonify({'password': ['Password does not match.']}), 401
return jsonify({"password": ["Password does not match."]}), 401
return jsonify({
'token': make_token(user_id, pwd_hash)
})
return jsonify({"token": make_token(user_id, pwd_hash)})
@bp.route('/consent-required', methods=['GET'])
@bp.route("/consent-required", methods=["GET"])
async def consent_required():
return jsonify({
'required': True,
})
return jsonify({"required": True})
@bp.route('/verify/resend', methods=['POST'])
@bp.route("/verify/resend", methods=["POST"])
async def verify_user():
user_id = await token_check()
# TODO: actually verify a user by sending an email
await app.db.execute("""
await app.db.execute(
"""
UPDATE users
SET verified = true
WHERE id = $1
""", user_id)
""",
user_id,
)
new_user = await app.storage.get_user(user_id, True)
await app.dispatcher.dispatch_user(
user_id, 'USER_UPDATE', new_user)
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", new_user)
return '', 204
return "", 204
@bp.route('/logout', methods=['POST'])
@bp.route("/logout", methods=["POST"])
async def _logout():
"""Called by the client to logout."""
return '', 204
return "", 204
@bp.route('/fingerprint', methods=['POST'])
@bp.route("/fingerprint", methods=["POST"])
async def _fingerprint():
"""No idea what this route is about."""
fingerprint_id = get_snowflake()
fingerprint = f'{fingerprint_id}.{secrets.token_urlsafe(32)}'
fingerprint = f"{fingerprint_id}.{secrets.token_urlsafe(32)}"
return jsonify({
'fingerprint': fingerprint
})
return jsonify({"fingerprint": fingerprint})

View File

@ -21,4 +21,4 @@ from .messages import bp as channel_messages
from .reactions import bp as channel_reactions
from .pins import bp as channel_pins
__all__ = ['channel_messages', 'channel_reactions', 'channel_pins']
__all__ = ["channel_messages", "channel_reactions", "channel_pins"]

View File

@ -30,13 +30,18 @@ class ForbiddenDM(Forbidden):
async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
"""Check if the user can DM the peer."""
# first step is checking if there is a block in any direction
blockrow = await app.db.fetchrow("""
blockrow = await app.db.fetchrow(
"""
SELECT rel_type
FROM relationships
WHERE rel_type = $3
AND user_id IN ($1, $2)
AND peer_id IN ($1, $2)
""", user_id, peer_id, RelationshipType.BLOCK.value)
""",
user_id,
peer_id,
RelationshipType.BLOCK.value,
)
if blockrow is not None:
raise ForbiddenDM()
@ -58,8 +63,8 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
user_settings = await app.user_storage.get_user_settings(user_id)
peer_settings = await app.user_storage.get_user_settings(peer_id)
restricted_user_ = [int(v) for v in user_settings['restricted_guilds']]
restricted_peer_ = [int(v) for v in peer_settings['restricted_guilds']]
restricted_user_ = [int(v) for v in user_settings["restricted_guilds"]]
restricted_peer_ = [int(v) for v in peer_settings["restricted_guilds"]]
restricted_user = set(restricted_user_)
restricted_peer = set(restricted_peer_)

View File

@ -41,18 +41,18 @@ from litecord.images import try_unlink
log = Logger(__name__)
bp = Blueprint('channel_messages', __name__)
bp = Blueprint("channel_messages", __name__)
def extract_limit(request_, default: int = 50, max_val: int = 100):
"""Extract a limit kwarg."""
try:
limit = int(request_.args.get('limit', default))
limit = int(request_.args.get("limit", default))
if limit not in range(0, max_val + 1):
raise ValueError()
except (TypeError, ValueError):
raise BadRequest('limit not int')
raise BadRequest("limit not int")
return limit
@ -61,27 +61,27 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple:
"""Extract a 2-tuple out of request arguments."""
before, after = None, None
if 'around' in request.args:
if "around" in request.args:
average = int(limit / 2)
around = int(args['around'])
around = int(args["around"])
after = around - average
before = around + average
elif 'before' in args:
before = int(args['before'])
elif 'after' in args:
before = int(args['after'])
elif "before" in args:
before = int(args["before"])
elif "after" in args:
before = int(args["after"])
return before, after
@bp.route('/<int:channel_id>/messages', methods=['GET'])
@bp.route("/<int:channel_id>/messages", methods=["GET"])
async def get_messages(channel_id):
user_id = await token_check()
ctype, peer_id = await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'read_history')
await channel_perm_check(user_id, channel_id, "read_history")
if ctype == ChannelType.DM:
# make sure both parties will be subbed
@ -91,42 +91,45 @@ async def get_messages(channel_id):
limit = extract_limit(request, 50)
where_clause = ''
where_clause = ""
before, after = query_tuple_from_args(request.args, limit)
if before:
where_clause += f'AND id < {before}'
where_clause += f"AND id < {before}"
if after:
where_clause += f'AND id > {after}'
where_clause += f"AND id > {after}"
message_ids = await app.db.fetch(f"""
message_ids = await app.db.fetch(
f"""
SELECT id
FROM messages
WHERE channel_id = $1 {where_clause}
ORDER BY id DESC
LIMIT {limit}
""", channel_id)
""",
channel_id,
)
result = []
for message_id in message_ids:
msg = await app.storage.get_message(message_id['id'], user_id)
msg = await app.storage.get_message(message_id["id"], user_id)
if msg is None:
continue
result.append(msg)
log.info('Fetched {} messages', len(result))
log.info("Fetched {} messages", len(result))
return jsonify(result)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET'])
@bp.route("/<int:channel_id>/messages/<int:message_id>", methods=["GET"])
async def get_single_message(channel_id, message_id):
user_id = await token_check()
await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'read_history')
await channel_perm_check(user_id, channel_id, "read_history")
message = await app.storage.get_message(message_id, user_id)
@ -142,11 +145,15 @@ async def _dm_pre_dispatch(channel_id, peer_id):
# check the other party's dm_channel_state
dm_state = await app.db.fetchval("""
dm_state = await app.db.fetchval(
"""
SELECT dm_id
FROM dm_channel_state
WHERE user_id = $1 AND dm_id = $2
""", peer_id, channel_id)
""",
peer_id,
channel_id,
)
if dm_state:
# the peer already has the channel
@ -157,18 +164,19 @@ async def _dm_pre_dispatch(channel_id, peer_id):
# dispatch CHANNEL_CREATE so the client knows which
# channel the future event is about
await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan)
await app.dispatcher.dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan)
# subscribe the peer to the channel
await app.dispatcher.sub('channel', channel_id, peer_id)
await app.dispatcher.sub("channel", channel_id, peer_id)
# insert it on dm_channel_state so the client
# is subscribed on the future
await try_dm_state(peer_id, channel_id)
async def create_message(channel_id: int, actual_guild_id: int,
author_id: int, data: dict) -> int:
async def create_message(
channel_id: int, actual_guild_id: int, author_id: int, data: dict
) -> int:
message_id = get_snowflake()
async with app.db.acquire() as conn:
@ -185,32 +193,32 @@ async def create_message(channel_id: int, actual_guild_id: int,
channel_id,
actual_guild_id,
author_id,
data['content'],
data['tts'],
data['everyone_mention'],
data['nonce'],
data["content"],
data["tts"],
data["everyone_mention"],
data["nonce"],
MessageType.DEFAULT.value,
data.get('embeds') or []
data.get("embeds") or [],
)
return message_id
async def msg_guild_text_mentions(payload: dict, guild_id: int,
mentions_everyone: bool, mentions_here: bool):
async def msg_guild_text_mentions(
payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
):
"""Calculates mention data side-effects."""
channel_id = int(payload['channel_id'])
channel_id = int(payload["channel_id"])
# calculate the user ids we'll bump the mention count for
uids = set()
# first is extracting user mentions
for mention in payload['mentions']:
uids.add(int(mention['id']))
for mention in payload["mentions"]:
uids.add(int(mention["id"]))
# then role mentions
for role_mention in payload['mention_roles']:
for role_mention in payload["mention_roles"]:
role_id = int(role_mention)
member_ids = await app.storage.get_role_members(role_id)
@ -223,11 +231,14 @@ async def msg_guild_text_mentions(payload: dict, guild_id: int,
if mentions_here:
uids = set()
await app.db.execute("""
await app.db.execute(
"""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1
""", channel_id)
""",
channel_id,
)
# at-here updates the read state
# for all users, including the ones
@ -238,19 +249,26 @@ async def msg_guild_text_mentions(payload: dict, guild_id: int,
member_ids = await app.storage.get_member_ids(guild_id)
await app.db.executemany("""
await app.db.executemany(
"""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1 AND user_id = $2
""", [(channel_id, uid) for uid in member_ids])
""",
[(channel_id, uid) for uid in member_ids],
)
for user_id in uids:
await app.db.execute("""
await app.db.execute(
"""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE user_id = $1
AND channel_id = $2
""", user_id, channel_id)
""",
user_id,
channel_id,
)
async def msg_create_request() -> tuple:
@ -264,12 +282,12 @@ async def msg_create_request() -> tuple:
# NOTE: embed isn't set on form data
json_from_form = {
'content': form.get('content', ''),
'nonce': form.get('nonce', '0'),
'tts': json.loads(form.get('tts', 'false')),
"content": form.get("content", ""),
"nonce": form.get("nonce", "0"),
"tts": json.loads(form.get("tts", "false")),
}
payload_json = json.loads(form.get('payload_json', '{}'))
payload_json = json.loads(form.get("payload_json", "{}"))
json_from_form.update(request_json)
json_from_form.update(payload_json)
@ -283,20 +301,19 @@ async def msg_create_request() -> tuple:
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
"""Check if there is actually any content being sent to us."""
has_content = bool(payload.get('content', ''))
has_content = bool(payload.get("content", ""))
has_files = len(files) > 0
embed_field = 'embeds' if use_embeds else 'embed'
embed_field = "embeds" if use_embeds else "embed"
has_embed = embed_field in payload and payload.get(embed_field) is not None
has_total_content = has_content or has_embed or has_files
if not has_total_content:
raise BadRequest('No content has been provided.')
raise BadRequest("No content has been provided.")
async def msg_add_attachment(message_id: int, channel_id: int,
attachment_file) -> int:
async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
"""Add an attachment to a message.
Parameters
@ -318,7 +335,7 @@ async def msg_add_attachment(message_id: int, channel_id: int,
# understand file info
mime = attachment_file.mimetype
is_image = mime.startswith('image/')
is_image = mime.startswith("image/")
img_width, img_height = None, None
@ -346,17 +363,22 @@ async def msg_add_attachment(message_id: int, channel_id: int,
VALUES
($1, $2, $3, $4, $5, $6, $7, $8)
""",
attachment_id, channel_id, message_id,
filename, file_size,
is_image, img_width, img_height)
attachment_id,
channel_id,
message_id,
filename,
file_size,
is_image,
img_width,
img_height,
)
ext = filename.split('.')[-1]
ext = filename.split(".")[-1]
with open(f'attachments/{attachment_id}.{ext}', 'wb') as attach_file:
with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
attach_file.write(attachment_file.stream.read())
log.debug('written {} bytes for attachment id {}',
file_size, attachment_id)
log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
return attachment_id
@ -364,12 +386,12 @@ async def msg_add_attachment(message_id: int, channel_id: int,
async def _spawn_embed(app_, payload, **kwargs):
app_.sched.spawn(
process_url_embed(
app_.config, app_.storage, app_.dispatcher, app_.session,
payload, **kwargs)
app_.config, app_.storage, app_.dispatcher, app_.session, payload, **kwargs
)
)
@bp.route('/<int:channel_id>/messages', methods=['POST'])
@bp.route("/<int:channel_id>/messages", methods=["POST"])
async def _create_message(channel_id):
"""Create a message."""
@ -379,7 +401,7 @@ async def _create_message(channel_id):
actual_guild_id = None
if ctype in GUILD_CHANS:
await channel_perm_check(user_id, channel_id, 'send_messages')
await channel_perm_check(user_id, channel_id, "send_messages")
actual_guild_id = guild_id
payload_json, files = await msg_create_request()
@ -394,29 +416,31 @@ async def _create_message(channel_id):
await dm_pre_check(user_id, channel_id, guild_id)
can_everyone = await channel_perm_check(
user_id, channel_id, 'mention_everyone', False
user_id, channel_id, "mention_everyone", False
)
mentions_everyone = ('@everyone' in j['content']) and can_everyone
mentions_here = ('@here' in j['content']) and can_everyone
mentions_everyone = ("@everyone" in j["content"]) and can_everyone
mentions_here = ("@here" in j["content"]) and can_everyone
is_tts = (j.get('tts', False) and
await channel_perm_check(
user_id, channel_id, 'send_tts_messages', False
))
is_tts = j.get("tts", False) and await channel_perm_check(
user_id, channel_id, "send_tts_messages", False
)
message_id = await create_message(
channel_id, actual_guild_id, user_id, {
'content': j['content'],
'tts': is_tts,
'nonce': int(j.get('nonce', 0)),
'everyone_mention': mentions_everyone or mentions_here,
channel_id,
actual_guild_id,
user_id,
{
"content": j["content"],
"tts": is_tts,
"nonce": int(j.get("nonce", 0)),
"everyone_mention": mentions_everyone or mentions_here,
# fill_embed takes care of filling proxy and width/height
'embeds': ([await fill_embed(j['embed'])]
if j.get('embed') is not None
else []),
})
"embeds": (
[await fill_embed(j["embed"])] if j.get("embed") is not None else []
),
},
)
# for each file given, we add it as an attachment
for pre_attachment in files:
@ -429,8 +453,7 @@ async def _create_message(channel_id):
await _dm_pre_dispatch(channel_id, user_id)
await _dm_pre_dispatch(channel_id, guild_id)
await app.dispatcher.dispatch('channel', channel_id,
'MESSAGE_CREATE', payload)
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
# spawn url processor for embedding of images
perms = await get_permissions(user_id, channel_id)
@ -438,54 +461,71 @@ async def _create_message(channel_id):
await _spawn_embed(app, payload)
# update read state for the author
await app.db.execute("""
await app.db.execute(
"""
UPDATE user_read_state
SET last_message_id = $1
WHERE channel_id = $2 AND user_id = $3
""", message_id, channel_id, user_id)
""",
message_id,
channel_id,
user_id,
)
if ctype == ChannelType.GUILD_TEXT:
await msg_guild_text_mentions(
payload, guild_id, mentions_everyone, mentions_here)
payload, guild_id, mentions_everyone, mentions_here
)
return jsonify(payload)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['PATCH'])
@bp.route("/<int:channel_id>/messages/<int:message_id>", methods=["PATCH"])
async def edit_message(channel_id, message_id):
user_id = await token_check()
_ctype, _guild_id = await channel_check(user_id, channel_id)
author_id = await app.db.fetchval("""
author_id = await app.db.fetchval(
"""
SELECT author_id FROM messages
WHERE messages.id = $1
""", message_id)
""",
message_id,
)
if not author_id == user_id:
raise Forbidden('You can not edit this message')
raise Forbidden("You can not edit this message")
j = await request.get_json()
updated = 'content' in j or 'embed' in j
updated = "content" in j or "embed" in j
old_message = await app.storage.get_message(message_id)
if 'content' in j:
await app.db.execute("""
if "content" in j:
await app.db.execute(
"""
UPDATE messages
SET content=$1
WHERE messages.id = $2
""", j['content'], message_id)
""",
j["content"],
message_id,
)
if 'embed' in j:
embeds = [await fill_embed(j['embed'])]
if "embed" in j:
embeds = [await fill_embed(j["embed"])]
await app.db.execute("""
await app.db.execute(
"""
UPDATE messages
SET embeds=$1
WHERE messages.id = $2
""", embeds, message_id)
""",
embeds,
message_id,
)
# do not spawn process_url_embed since we already have embeds.
elif 'content' in j:
elif "content" in j:
# if there weren't any embed changes BUT
# we had a content change, we dispatch process_url_embed but with
# an artificial delay.
@ -495,46 +535,55 @@ async def edit_message(channel_id, message_id):
# BEFORE the MESSAGE_UPDATE with the new embeds (based on content)
perms = await get_permissions(user_id, channel_id)
if perms.bits.embed_links:
await _spawn_embed(app, {
'id': message_id,
'channel_id': channel_id,
'content': j['content'],
'embeds': old_message['embeds']
}, delay=0.2)
await _spawn_embed(
app,
{
"id": message_id,
"channel_id": channel_id,
"content": j["content"],
"embeds": old_message["embeds"],
},
delay=0.2,
)
# only set new timestamp upon actual update
if updated:
await app.db.execute("""
await app.db.execute(
"""
UPDATE messages
SET edited_at = (now() at time zone 'utc')
WHERE id = $1
""", message_id)
""",
message_id,
)
message = await app.storage.get_message(message_id, user_id)
# only dispatch MESSAGE_UPDATE if any update
# actually happened
if updated:
await app.dispatcher.dispatch('channel', channel_id,
'MESSAGE_UPDATE', message)
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", message)
return jsonify(message)
async def _del_msg_fkeys(message_id: int):
attachs = await app.db.fetch("""
attachs = await app.db.fetch(
"""
SELECT id FROM attachments
WHERE message_id = $1
""", message_id)
""",
message_id,
)
attachs = [r['id'] for r in attachs]
attachs = [r["id"] for r in attachs]
attachments = Path('./attachments')
attachments = Path("./attachments")
for attach_id in attachs:
# anything starting with the given attachment shall be
# deleted, because there may be resizes of the original
# attachment laying around.
for filepath in attachments.glob(f'{attach_id}*'):
for filepath in attachments.glob(f"{attach_id}*"):
try_unlink(filepath)
# after trying to delete all available attachments, delete
@ -542,51 +591,64 @@ async def _del_msg_fkeys(message_id: int):
# take the chance and delete all the data from the other tables too!
tables = ['attachments', 'message_webhook_info',
'message_reactions', 'channel_pins']
tables = [
"attachments",
"message_webhook_info",
"message_reactions",
"channel_pins",
]
for table in tables:
await app.db.execute(f"""
await app.db.execute(
f"""
DELETE FROM {table}
WHERE message_id = $1
""", message_id)
""",
message_id,
)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['DELETE'])
@bp.route("/<int:channel_id>/messages/<int:message_id>", methods=["DELETE"])
async def delete_message(channel_id, message_id):
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
author_id = await app.db.fetchval("""
author_id = await app.db.fetchval(
"""
SELECT author_id FROM messages
WHERE messages.id = $1
""", message_id)
by_perm = await channel_perm_check(
user_id, channel_id, 'manage_messages', False
""",
message_id,
)
by_perm = await channel_perm_check(user_id, channel_id, "manage_messages", False)
by_ownership = author_id == user_id
can_delete = by_perm or by_ownership
if not can_delete:
raise Forbidden('You can not delete this message')
raise Forbidden("You can not delete this message")
await _del_msg_fkeys(message_id)
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM messages
WHERE messages.id = $1
""", message_id)
""",
message_id,
)
await app.dispatcher.dispatch(
'channel', channel_id,
'MESSAGE_DELETE', {
'id': str(message_id),
'channel_id': str(channel_id),
"channel",
channel_id,
"MESSAGE_DELETE",
{
"id": str(message_id),
"channel_id": str(channel_id),
# for lazy guilds
'guild_id': str(guild_id),
})
"guild_id": str(guild_id),
},
)
return '', 204
return "", 204

View File

@ -28,28 +28,32 @@ from litecord.system_messages import send_sys_message
from litecord.enums import MessageType, SYS_MESSAGES
from litecord.errors import BadRequest
bp = Blueprint('channel_pins', __name__)
bp = Blueprint("channel_pins", __name__)
class SysMsgInvalidAction(BadRequest):
"""Invalid action on a system message."""
error_code = 50021
@bp.route('/<int:channel_id>/pins', methods=['GET'])
@bp.route("/<int:channel_id>/pins", methods=["GET"])
async def get_pins(channel_id):
"""Get the pins for a channel"""
user_id = await token_check()
await channel_check(user_id, channel_id)
ids = await app.db.fetch("""
ids = await app.db.fetch(
"""
SELECT message_id
FROM channel_pins
WHERE channel_id = $1
ORDER BY message_id DESC
""", channel_id)
""",
channel_id,
)
ids = [r['message_id'] for r in ids]
ids = [r["message_id"] for r in ids]
res = []
for message_id in ids:
@ -60,80 +64,96 @@ async def get_pins(channel_id):
return jsonify(res)
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['PUT'])
@bp.route("/<int:channel_id>/pins/<int:message_id>", methods=["PUT"])
async def add_pin(channel_id, message_id):
"""Add a pin to a channel"""
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'manage_messages')
await channel_perm_check(user_id, channel_id, "manage_messages")
mtype = await app.db.fetchval("""
mtype = await app.db.fetchval(
"""
SELECT message_type
FROM messages
WHERE id = $1
""", message_id)
""",
message_id,
)
if mtype in SYS_MESSAGES:
raise SysMsgInvalidAction(
'Cannot execute action on a system message')
raise SysMsgInvalidAction("Cannot execute action on a system message")
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO channel_pins (channel_id, message_id)
VALUES ($1, $2)
""", channel_id, message_id)
""",
channel_id,
message_id,
)
row = await app.db.fetchrow("""
row = await app.db.fetchrow(
"""
SELECT message_id
FROM channel_pins
WHERE channel_id = $1
ORDER BY message_id ASC
LIMIT 1
""", channel_id)
timestamp = snowflake_datetime(row['message_id'])
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_PINS_UPDATE',
{
'channel_id': str(channel_id),
'last_pin_timestamp': timestamp_(timestamp)
}
""",
channel_id,
)
await send_sys_message(app, channel_id,
MessageType.CHANNEL_PINNED_MESSAGE,
message_id, user_id)
timestamp = snowflake_datetime(row["message_id"])
return '', 204
await app.dispatcher.dispatch(
"channel",
channel_id,
"CHANNEL_PINS_UPDATE",
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp_(timestamp)},
)
await send_sys_message(
app, channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id
)
return "", 204
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['DELETE'])
@bp.route("/<int:channel_id>/pins/<int:message_id>", methods=["DELETE"])
async def delete_pin(channel_id, message_id):
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'manage_messages')
await channel_perm_check(user_id, channel_id, "manage_messages")
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM channel_pins
WHERE channel_id = $1 AND message_id = $2
""", channel_id, message_id)
""",
channel_id,
message_id,
)
row = await app.db.fetchrow("""
row = await app.db.fetchrow(
"""
SELECT message_id
FROM channel_pins
WHERE channel_id = $1
ORDER BY message_id ASC
LIMIT 1
""", channel_id)
""",
channel_id,
)
timestamp = snowflake_datetime(row['message_id'])
timestamp = snowflake_datetime(row["message_id"])
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_PINS_UPDATE', {
'channel_id': str(channel_id),
'last_pin_timestamp': timestamp.isoformat()
})
"channel",
channel_id,
"CHANNEL_PINS_UPDATE",
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp.isoformat()},
)
return '', 204
return "", 204

View File

@ -26,17 +26,15 @@ from logbook import Logger
from litecord.utils import async_map
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check, channel_perm_check
from litecord.blueprints.channel.messages import (
query_tuple_from_args, extract_limit
)
from litecord.blueprints.channel.messages import query_tuple_from_args, extract_limit
from litecord.enums import GUILD_CHANS
log = Logger(__name__)
bp = Blueprint('channel_reactions', __name__)
bp = Blueprint("channel_reactions", __name__)
BASEPATH = '/<int:channel_id>/messages/<int:message_id>/reactions'
BASEPATH = "/<int:channel_id>/messages/<int:message_id>/reactions"
class EmojiType(IntEnum):
@ -51,16 +49,14 @@ def emoji_info_from_str(emoji: str) -> tuple:
# unicode emoji just have the raw unicode.
# try checking if the emoji is custom or unicode
emoji_type = 0 if ':' in emoji else 1
emoji_type = 0 if ":" in emoji else 1
emoji_type = EmojiType(emoji_type)
# extract the emoji id OR the unicode value of the emoji
# depending if it is custom or not
emoji_id = (int(emoji.split(':')[1])
if emoji_type == EmojiType.CUSTOM
else emoji)
emoji_id = int(emoji.split(":")[1]) if emoji_type == EmojiType.CUSTOM else emoji
emoji_name = emoji.split(':')[0]
emoji_name = emoji.split(":")[0]
return emoji_type, emoji_id, emoji_name
@ -68,27 +64,27 @@ def emoji_info_from_str(emoji: str) -> tuple:
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
print(emoji_type, emoji_id, emoji_name)
return {
'id': None if emoji_type == EmojiType.UNICODE else emoji_id,
'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id
"id": None if emoji_type == EmojiType.UNICODE else emoji_id,
"name": emoji_name if emoji_type == EmojiType.UNICODE else emoji_id,
}
def _make_payload(user_id, channel_id, message_id, partial):
return {
'user_id': str(user_id),
'channel_id': str(channel_id),
'message_id': str(message_id),
'emoji': partial
"user_id": str(user_id),
"channel_id": str(channel_id),
"message_id": str(message_id),
"emoji": partial,
}
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['PUT'])
@bp.route(f"{BASEPATH}/<emoji>/@me", methods=["PUT"])
async def add_reaction(channel_id: int, message_id: int, emoji: str):
"""Put a reaction."""
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'read_history')
await channel_perm_check(user_id, channel_id, "read_history")
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
@ -97,52 +93,64 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str):
# ADD_REACTIONS is only checked when this is the first
# reaction in a message.
reaction_count = await app.db.fetchval("""
reaction_count = await app.db.fetchval(
"""
SELECT COUNT(*)
FROM message_reactions
WHERE message_id = $1
AND emoji_type = $2
AND emoji_id = $3
AND emoji_text = $4
""", message_id, emoji_type, emoji_id, emoji_text)
""",
message_id,
emoji_type,
emoji_id,
emoji_text,
)
if reaction_count == 0:
await channel_perm_check(user_id, channel_id, 'add_reactions')
await channel_perm_check(user_id, channel_id, "add_reactions")
await app.db.execute(
"""
INSERT INTO message_reactions (message_id, user_id,
emoji_type, emoji_id, emoji_text)
VALUES ($1, $2, $3, $4, $5)
""", message_id, user_id, emoji_type,
""",
message_id,
user_id,
emoji_type,
# if it is custom, we put the emoji_id on emoji_id
# column, if it isn't, we put it on emoji_text
# column.
emoji_id, emoji_text
emoji_id,
emoji_text,
)
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
payload = _make_payload(user_id, channel_id, message_id, partial)
if ctype in GUILD_CHANS:
payload['guild_id'] = str(guild_id)
payload["guild_id"] = str(guild_id)
await app.dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_REACTION_ADD', payload)
"channel", channel_id, "MESSAGE_REACTION_ADD", payload
)
return '', 204
return "", 204
def emoji_sql(emoji_type, emoji_id, emoji_name, param=4):
"""Extract SQL clauses to search for specific emoji
in the message_reactions table."""
param = f'${param}'
param = f"${param}"
# know which column to filter with
where_ext = (f'AND emoji_id = {param}'
if emoji_type == EmojiType.CUSTOM else
f'AND emoji_text = {param}')
where_ext = (
f"AND emoji_id = {param}"
if emoji_type == EmojiType.CUSTOM
else f"AND emoji_text = {param}"
)
# which emoji to remove (custom or unicode)
main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name
@ -157,8 +165,7 @@ def _emoji_sql_simple(emoji: str, param=4):
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
async def remove_reaction(channel_id: int, message_id: int,
user_id: int, emoji: str):
async def remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str):
ctype, guild_id = await channel_check(user_id, channel_id)
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
@ -171,39 +178,45 @@ async def remove_reaction(channel_id: int, message_id: int,
AND user_id = $2
AND emoji_type = $3
{where_ext}
""", message_id, user_id, emoji_type, main_emoji)
""",
message_id,
user_id,
emoji_type,
main_emoji,
)
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
payload = _make_payload(user_id, channel_id, message_id, partial)
if ctype in GUILD_CHANS:
payload['guild_id'] = str(guild_id)
payload["guild_id"] = str(guild_id)
await app.dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_REACTION_REMOVE', payload)
"channel", channel_id, "MESSAGE_REACTION_REMOVE", payload
)
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['DELETE'])
@bp.route(f"{BASEPATH}/<emoji>/@me", methods=["DELETE"])
async def remove_own_reaction(channel_id, message_id, emoji):
"""Remove a reaction."""
user_id = await token_check()
await remove_reaction(channel_id, message_id, user_id, emoji)
return '', 204
return "", 204
@bp.route(f'{BASEPATH}/<emoji>/<int:other_id>', methods=['DELETE'])
@bp.route(f"{BASEPATH}/<emoji>/<int:other_id>", methods=["DELETE"])
async def remove_user_reaction(channel_id, message_id, emoji, other_id):
"""Remove a reaction made by another user."""
user_id = await token_check()
await channel_perm_check(user_id, channel_id, 'manage_messages')
await channel_perm_check(user_id, channel_id, "manage_messages")
await remove_reaction(channel_id, message_id, other_id, emoji)
return '', 204
return "", 204
@bp.route(f'{BASEPATH}/<emoji>', methods=['GET'])
@bp.route(f"{BASEPATH}/<emoji>", methods=["GET"])
async def list_users_reaction(channel_id, message_id, emoji):
"""Get the list of all users who reacted with a certain emoji."""
user_id = await token_check()
@ -215,42 +228,49 @@ async def list_users_reaction(channel_id, message_id, emoji):
limit = extract_limit(request, 25)
before, after = query_tuple_from_args(request.args, limit)
before_clause = 'AND user_id < $2' if before else ''
after_clause = 'AND user_id > $3' if after else ''
before_clause = "AND user_id < $2" if before else ""
after_clause = "AND user_id > $3" if after else ""
where_ext, main_emoji = _emoji_sql_simple(emoji, 4)
rows = await app.db.fetch(f"""
rows = await app.db.fetch(
f"""
SELECT user_id
FROM message_reactions
WHERE message_id = $1 {before_clause} {after_clause} {where_ext}
""", message_id, before, after, main_emoji)
""",
message_id,
before,
after,
main_emoji,
)
user_ids = [r['user_id'] for r in rows]
user_ids = [r["user_id"] for r in rows]
users = await async_map(app.storage.get_user, user_ids)
return jsonify(users)
@bp.route(f'{BASEPATH}', methods=['DELETE'])
@bp.route(f"{BASEPATH}", methods=["DELETE"])
async def remove_all_reactions(channel_id, message_id):
"""Remove all reactions in a message."""
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'manage_messages')
await channel_perm_check(user_id, channel_id, "manage_messages")
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM message_reactions
WHERE message_id = $1
""", message_id)
""",
message_id,
)
payload = {
'channel_id': str(channel_id),
'message_id': str(message_id),
}
payload = {"channel_id": str(channel_id), "message_id": str(message_id)}
if ctype in GUILD_CHANS:
payload['guild_id'] = str(guild_id)
payload["guild_id"] = str(guild_id)
await app.dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_REACTION_REMOVE_ALL', payload)
"channel", channel_id, "MESSAGE_REACTION_REMOVE_ALL", payload
)

View File

@ -28,24 +28,26 @@ from litecord.auth import token_check
from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags
from litecord.errors import ChannelNotFound, Forbidden, BadRequest
from litecord.schemas import (
validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE,
validate,
CHAN_UPDATE,
CHAN_OVERWRITE,
SEARCH_CHANNEL,
GROUP_DM_UPDATE,
BULK_DELETE,
)
from litecord.blueprints.checks import channel_check, channel_perm_check
from litecord.system_messages import send_sys_message
from litecord.blueprints.dm_channels import (
gdm_remove_recipient, gdm_destroy
)
from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
from litecord.utils import search_result_from_list
from litecord.embed.messages import process_url_embed, msg_update_embeds
from litecord.snowflake import snowflake_datetime
log = Logger(__name__)
bp = Blueprint('channels', __name__)
bp = Blueprint("channels", __name__)
@bp.route('/<int:channel_id>', methods=['GET'])
@bp.route("/<int:channel_id>", methods=["GET"])
async def get_channel(channel_id):
"""Get a single channel's information"""
user_id = await token_check()
@ -56,7 +58,7 @@ async def get_channel(channel_id):
chan = await app.storage.get_channel(channel_id)
if not chan:
raise ChannelNotFound('single channel not found')
raise ChannelNotFound("single channel not found")
return jsonify(chan)
@ -64,106 +66,129 @@ async def get_channel(channel_id):
async def __guild_chan_sql(guild_id, channel_id, field: str) -> str:
"""Update a guild's channel id field to NULL,
if it was set to the given channel id before."""
return await app.db.execute(f"""
return await app.db.execute(
f"""
UPDATE guilds
SET {field} = NULL
WHERE guilds.id = $1 AND {field} = $2
""", guild_id, channel_id)
""",
guild_id,
channel_id,
)
async def _update_guild_chan_text(guild_id: int, channel_id: int):
res_embed = await __guild_chan_sql(
guild_id, channel_id, 'embed_channel_id')
res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id")
res_widget = await __guild_chan_sql(
guild_id, channel_id, 'widget_channel_id')
res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id")
res_system = await __guild_chan_sql(
guild_id, channel_id, 'system_channel_id')
res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id")
# if none of them were actually updated,
# ignore and dont dispatch anything
if 'UPDATE 1' not in (res_embed, res_widget, res_system):
if "UPDATE 1" not in (res_embed, res_widget, res_system):
return
# at least one of the fields were updated,
# dispatch GUILD_UPDATE
guild = await app.storage.get_guild(guild_id)
await app.dispatcher.dispatch_guild(
guild_id, 'GUILD_UPDATE', guild)
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
async def _update_guild_chan_voice(guild_id: int, channel_id: int):
res = await __guild_chan_sql(guild_id, channel_id, 'afk_channel_id')
res = await __guild_chan_sql(guild_id, channel_id, "afk_channel_id")
# guild didnt update
if res == 'UPDATE 0':
if res == "UPDATE 0":
return
guild = await app.storage.get_guild(guild_id)
await app.dispatcher.dispatch_guild(
guild_id, 'GUILD_UPDATE', guild)
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
async def _update_guild_chan_cat(guild_id: int, channel_id: int):
# get all channels that were childs of the category
childs = await app.db.fetch("""
childs = await app.db.fetch(
"""
SELECT id
FROM guild_channels
WHERE guild_id = $1 AND parent_id = $2
""", guild_id, channel_id)
childs = [c['id'] for c in childs]
""",
guild_id,
channel_id,
)
childs = [c["id"] for c in childs]
# update every child channel to parent_id = NULL
await app.db.execute("""
await app.db.execute(
"""
UPDATE guild_channels
SET parent_id = NULL
WHERE guild_id = $1 AND parent_id = $2
""", guild_id, channel_id)
""",
guild_id,
channel_id,
)
# tell all people in the guild of the category removal
for child_id in childs:
child = await app.storage.get_channel(child_id)
await app.dispatcher.dispatch_guild(
guild_id, 'CHANNEL_UPDATE', child
)
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child)
async def delete_messages(channel_id):
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM channel_pins
WHERE channel_id = $1
""", channel_id)
""",
channel_id,
)
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM user_read_state
WHERE channel_id = $1
""", channel_id)
""",
channel_id,
)
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM messages
WHERE channel_id = $1
""", channel_id)
""",
channel_id,
)
async def guild_cleanup(channel_id):
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM channel_overwrites
WHERE channel_id = $1
""", channel_id)
""",
channel_id,
)
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM invites
WHERE channel_id = $1
""", channel_id)
""",
channel_id,
)
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM webhooks
WHERE channel_id = $1
""", channel_id)
""",
channel_id,
)
@bp.route('/<int:channel_id>', methods=['DELETE'])
@bp.route("/<int:channel_id>", methods=["DELETE"])
async def close_channel(channel_id):
"""Close or delete a channel."""
user_id = await token_check()
@ -184,9 +209,8 @@ async def close_channel(channel_id):
}[ctype]
main_tbl = {
ChannelType.GUILD_TEXT: 'guild_text_channels',
ChannelType.GUILD_VOICE: 'guild_voice_channels',
ChannelType.GUILD_TEXT: "guild_text_channels",
ChannelType.GUILD_VOICE: "guild_voice_channels",
# TODO: categories?
}[ctype]
@ -199,29 +223,37 @@ async def close_channel(channel_id):
await delete_messages(channel_id)
await guild_cleanup(channel_id)
await app.db.execute(f"""
await app.db.execute(
f"""
DELETE FROM {main_tbl}
WHERE id = $1
""", channel_id)
""",
channel_id,
)
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM guild_channels
WHERE id = $1
""", channel_id)
""",
channel_id,
)
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM channels
WHERE id = $1
""", channel_id)
""",
channel_id,
)
# clean its member list representation
lazy_guilds = app.dispatcher.backends['lazy_guild']
lazy_guilds = app.dispatcher.backends["lazy_guild"]
lazy_guilds.remove_channel(channel_id)
await app.dispatcher.dispatch_guild(
guild_id, 'CHANNEL_DELETE', chan)
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_DELETE", chan)
await app.dispatcher.remove('channel', channel_id)
await app.dispatcher.remove("channel", channel_id)
return jsonify(chan)
if ctype == ChannelType.DM:
@ -231,27 +263,34 @@ async def close_channel(channel_id):
# instead, we close the channel for the user that is making
# the request via removing the link between them and
# the channel on dm_channel_state
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM dm_channel_state
WHERE user_id = $1 AND dm_id = $2
""", user_id, channel_id)
""",
user_id,
channel_id,
)
# unsubscribe
await app.dispatcher.unsub('channel', channel_id, user_id)
await app.dispatcher.unsub("channel", channel_id, user_id)
# nothing happens to the other party of the dm channel
await app.dispatcher.dispatch_user(user_id, 'CHANNEL_DELETE', chan)
await app.dispatcher.dispatch_user(user_id, "CHANNEL_DELETE", chan)
return jsonify(chan)
if ctype == ChannelType.GROUP_DM:
await gdm_remove_recipient(channel_id, user_id)
gdm_count = await app.db.fetchval("""
gdm_count = await app.db.fetchval(
"""
SELECT COUNT(*)
FROM group_dm_members
WHERE id = $1
""", channel_id)
""",
channel_id,
)
if gdm_count == 0:
# destroy dm
@ -261,11 +300,15 @@ async def close_channel(channel_id):
async def _update_pos(channel_id, pos: int):
await app.db.execute("""
await app.db.execute(
"""
UPDATE guild_channels
SET position = $1
WHERE id = $2
""", pos, channel_id)
""",
pos,
channel_id,
)
async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
@ -274,20 +317,19 @@ async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
continue
chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch(
'guild', guild_id, 'CHANNEL_UPDATE', chan)
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
async def _process_overwrites(channel_id: int, overwrites: list):
for overwrite in overwrites:
# 0 for member overwrite, 1 for role overwrite
target_type = 0 if overwrite['type'] == 'member' else 1
target_role = None if target_type == 0 else overwrite['id']
target_user = overwrite['id'] if target_type == 0 else None
target_type = 0 if overwrite["type"] == "member" else 1
target_role = None if target_type == 0 else overwrite["id"]
target_user = overwrite["id"] if target_type == 0 else None
col_name = 'target_user' if target_type == 0 else 'target_role'
constraint_name = f'channel_overwrites_{col_name}_uniq'
col_name = "target_user" if target_type == 0 else "target_role"
constraint_name = f"channel_overwrites_{col_name}_uniq"
await app.db.execute(
f"""
@ -301,53 +343,66 @@ async def _process_overwrites(channel_id: int, overwrites: list):
UPDATE
SET allow = $5, deny = $6
""",
channel_id, target_type,
target_role, target_user,
overwrite['allow'], overwrite['deny'])
channel_id,
target_type,
target_role,
target_user,
overwrite["allow"],
overwrite["deny"],
)
@bp.route('/<int:channel_id>/permissions/<int:overwrite_id>', methods=['PUT'])
@bp.route("/<int:channel_id>/permissions/<int:overwrite_id>", methods=["PUT"])
async def put_channel_overwrite(channel_id: int, overwrite_id: int):
"""Insert or modify a channel overwrite."""
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
if ctype not in GUILD_CHANS:
raise ChannelNotFound('Only usable for guild channels.')
raise ChannelNotFound("Only usable for guild channels.")
await channel_perm_check(user_id, guild_id, 'manage_roles')
await channel_perm_check(user_id, guild_id, "manage_roles")
j = validate(
# inserting a fake id on the payload so validation passes through
{**await request.get_json(), **{'id': -1}},
CHAN_OVERWRITE
{**await request.get_json(), **{"id": -1}},
CHAN_OVERWRITE,
)
await _process_overwrites(channel_id, [{
'allow': j['allow'],
'deny': j['deny'],
'type': j['type'],
'id': overwrite_id
}])
await _process_overwrites(
channel_id,
[
{
"allow": j["allow"],
"deny": j["deny"],
"type": j["type"],
"id": overwrite_id,
}
],
)
await _mass_chan_update(guild_id, [channel_id])
return '', 204
return "", 204
async def _update_channel_common(channel_id, guild_id: int, j: dict):
if 'name' in j:
await app.db.execute("""
if "name" in j:
await app.db.execute(
"""
UPDATE guild_channels
SET name = $1
WHERE id = $2
""", j['name'], channel_id)
""",
j["name"],
channel_id,
)
if 'position' in j:
if "position" in j:
channel_data = await app.storage.get_channel_data(guild_id)
chans = [None] * len(channel_data)
for chandata in channel_data:
chans.insert(chandata['position'], int(chandata['id']))
chans.insert(chandata["position"], int(chandata["id"]))
# are we changing to the left or to the right?
@ -358,7 +413,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
# channelN-1 going to the position channel2
# was occupying.
current_pos = chans.index(channel_id)
new_pos = j['position']
new_pos = j["position"]
# if the new position is bigger than the current one,
# we're making a left shift of all the channels that are
@ -366,113 +421,136 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
left_shift = new_pos > current_pos
# find all channels that we'll have to shift
shift_block = (chans[current_pos:new_pos]
if left_shift else
chans[new_pos:current_pos]
shift_block = (
chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos]
)
shift = -1 if left_shift else 1
# do the shift (to the left or to the right)
await app.db.executemany("""
await app.db.executemany(
"""
UPDATE guild_channels
SET position = position + $1
WHERE id = $2
""", [(shift, chan_id) for chan_id in shift_block])
""",
[(shift, chan_id) for chan_id in shift_block],
)
await _mass_chan_update(guild_id, shift_block)
# since theres now an empty slot, move current channel to it
await _update_pos(channel_id, new_pos)
if 'channel_overwrites' in j:
overwrites = j['channel_overwrites']
if "channel_overwrites" in j:
overwrites = j["channel_overwrites"]
await _process_overwrites(channel_id, overwrites)
async def _common_guild_chan(channel_id, j: dict):
# common updates to the guild_channels table
for field in [field for field in j.keys()
if field in ('nsfw', 'parent_id')]:
await app.db.execute(f"""
for field in [field for field in j.keys() if field in ("nsfw", "parent_id")]:
await app.db.execute(
f"""
UPDATE guild_channels
SET {field} = $1
WHERE id = $2
""", j[field], channel_id)
""",
j[field],
channel_id,
)
async def _update_text_channel(channel_id: int, j: dict, _user_id: int):
# first do the specific ones related to guild_text_channels
for field in [field for field in j.keys()
if field in ('topic', 'rate_limit_per_user')]:
await app.db.execute(f"""
for field in [
field for field in j.keys() if field in ("topic", "rate_limit_per_user")
]:
await app.db.execute(
f"""
UPDATE guild_text_channels
SET {field} = $1
WHERE id = $2
""", j[field], channel_id)
""",
j[field],
channel_id,
)
await _common_guild_chan(channel_id, j)
async def _update_voice_channel(channel_id: int, j: dict, _user_id: int):
# first do the specific ones in guild_voice_channels
for field in [field for field in j.keys()
if field in ('bitrate', 'user_limit')]:
await app.db.execute(f"""
for field in [field for field in j.keys() if field in ("bitrate", "user_limit")]:
await app.db.execute(
f"""
UPDATE guild_voice_channels
SET {field} = $1
WHERE id = $2
""", j[field], channel_id)
""",
j[field],
channel_id,
)
# yes, i'm letting voice channels have nsfw, you cant stop me
await _common_guild_chan(channel_id, j)
async def _update_group_dm(channel_id: int, j: dict, author_id: int):
if 'name' in j:
await app.db.execute("""
if "name" in j:
await app.db.execute(
"""
UPDATE group_dm_channels
SET name = $1
WHERE id = $2
""", j['name'], channel_id)
""",
j["name"],
channel_id,
)
await send_sys_message(
app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id
)
if 'icon' in j:
if "icon" in j:
new_icon = await app.icons.update(
'channel-icons', channel_id, j['icon'], always_icon=True
"channel-icons", channel_id, j["icon"], always_icon=True
)
await app.db.execute("""
await app.db.execute(
"""
UPDATE group_dm_channels
SET icon = $1
WHERE id = $2
""", new_icon.icon_hash, channel_id)
""",
new_icon.icon_hash,
channel_id,
)
await send_sys_message(
app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id
)
@bp.route('/<int:channel_id>', methods=['PUT', 'PATCH'])
@bp.route("/<int:channel_id>", methods=["PUT", "PATCH"])
async def update_channel(channel_id):
"""Update a channel's information"""
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
if ctype not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
ChannelType.GROUP_DM):
raise ChannelNotFound('unable to edit unsupported chan type')
if ctype not in (
ChannelType.GUILD_TEXT,
ChannelType.GUILD_VOICE,
ChannelType.GROUP_DM,
):
raise ChannelNotFound("unable to edit unsupported chan type")
is_guild = ctype in GUILD_CHANS
if is_guild:
await channel_perm_check(user_id, channel_id, 'manage_channels')
await channel_perm_check(user_id, channel_id, "manage_channels")
j = validate(await request.get_json(),
CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
j = validate(await request.get_json(), CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
# TODO: categories
update_handler = {
@ -489,30 +567,32 @@ async def update_channel(channel_id):
chan = await app.storage.get_channel(channel_id)
if is_guild:
await app.dispatcher.dispatch(
'guild', guild_id, 'CHANNEL_UPDATE', chan)
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
else:
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_UPDATE', chan)
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
return jsonify(chan)
@bp.route('/<int:channel_id>/typing', methods=['POST'])
@bp.route("/<int:channel_id>/typing", methods=["POST"])
async def trigger_typing(channel_id):
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
await app.dispatcher.dispatch('channel', channel_id, 'TYPING_START', {
'channel_id': str(channel_id),
'user_id': str(user_id),
'timestamp': int(time.time()),
await app.dispatcher.dispatch(
"channel",
channel_id,
"TYPING_START",
{
"channel_id": str(channel_id),
"user_id": str(user_id),
"timestamp": int(time.time()),
# guild_id for lazy guilds
'guild_id': str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
})
"guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
},
)
return '', 204
return "", 204
async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
@ -521,7 +601,8 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
if not message_id:
message_id = await app.storage.chan_last_message(channel_id)
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO user_read_state
(user_id, channel_id, last_message_id, mention_count)
VALUES
@ -532,26 +613,31 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
SET last_message_id = $3, mention_count = 0
WHERE user_read_state.user_id = $1
AND user_read_state.channel_id = $2
""", user_id, channel_id, message_id)
""",
user_id,
channel_id,
message_id,
)
if guild_id:
await app.dispatcher.dispatch_user_guild(
user_id, guild_id, 'MESSAGE_ACK', {
'message_id': str(message_id),
'channel_id': str(channel_id)
})
user_id,
guild_id,
"MESSAGE_ACK",
{"message_id": str(message_id), "channel_id": str(channel_id)},
)
else:
# we don't use ChannelDispatcher here because since
# guild_id is None, all user devices are already subscribed
# to the given channel (a dm or a group dm)
await app.dispatcher.dispatch_user(
user_id, 'MESSAGE_ACK', {
'message_id': str(message_id),
'channel_id': str(channel_id)
})
user_id,
"MESSAGE_ACK",
{"message_id": str(message_id), "channel_id": str(channel_id)},
)
@bp.route('/<int:channel_id>/messages/<int:message_id>/ack', methods=['POST'])
@bp.route("/<int:channel_id>/messages/<int:message_id>/ack", methods=["POST"])
async def ack_channel(channel_id, message_id):
"""Acknowledge a channel."""
user_id = await token_check()
@ -562,40 +648,47 @@ async def ack_channel(channel_id, message_id):
await channel_ack(user_id, guild_id, channel_id, message_id)
return jsonify({
return jsonify(
{
# token seems to be used for
# data collection activities,
# so we never use it.
'token': None
})
"token": None
}
)
@bp.route('/<int:channel_id>/messages/ack', methods=['DELETE'])
@bp.route("/<int:channel_id>/messages/ack", methods=["DELETE"])
async def delete_read_state(channel_id):
"""Delete the read state of a channel."""
user_id = await token_check()
await channel_check(user_id, channel_id)
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM user_read_state
WHERE user_id = $1 AND channel_id = $2
""", user_id, channel_id)
""",
user_id,
channel_id,
)
return '', 204
return "", 204
@bp.route('/<int:channel_id>/messages/search', methods=['GET'])
@bp.route("/<int:channel_id>/messages/search", methods=["GET"])
async def _search_channel(channel_id):
"""Search in DMs or group DMs"""
user_id = await token_check()
await channel_check(user_id, channel_id)
await channel_perm_check(user_id, channel_id, 'read_messages')
await channel_perm_check(user_id, channel_id, "read_messages")
j = validate(dict(request.args), SEARCH_CHANNEL)
# main search query
# the context (before/after) columns are copied from the guilds blueprint.
rows = await app.db.fetch(f"""
rows = await app.db.fetch(
f"""
SELECT orig.id AS current_id,
COUNT(*) OVER() AS total_results,
array((SELECT messages.id AS before_id
@ -611,28 +704,40 @@ async def _search_channel(channel_id):
ORDER BY orig.id DESC
LIMIT 50
OFFSET $2
""", channel_id, j['offset'], j['content'])
""",
channel_id,
j["offset"],
j["content"],
)
return jsonify(await search_result_from_list(rows))
# NOTE that those functions stay here until some other
# route or code wants it.
async def _msg_update_flags(message_id: int, flags: int):
await app.db.execute("""
await app.db.execute(
"""
UPDATE messages
SET flags = $1
WHERE id = $2
""", flags, message_id)
""",
flags,
message_id,
)
async def _msg_get_flags(message_id: int):
return await app.db.fetchval("""
return await app.db.fetchval(
"""
SELECT flags
FROM messages
WHERE id = $1
""", message_id)
""",
message_id,
)
async def _msg_set_flags(message_id: int, new_flags: int):
@ -647,8 +752,9 @@ async def _msg_unset_flags(message_id: int, unset_flags: int):
await _msg_update_flags(message_id, flags)
@bp.route('/<int:channel_id>/messages/<int:message_id>/suppress-embeds',
methods=['POST'])
@bp.route(
"/<int:channel_id>/messages/<int:message_id>/suppress-embeds", methods=["POST"]
)
async def suppress_embeds(channel_id: int, message_id: int):
"""Toggle the embeds in a message.
@ -661,29 +767,27 @@ async def suppress_embeds(channel_id: int, message_id: int):
# the checks here have been copied from the delete_message()
# handler on blueprints.channel.messages. maybe we can combine
# them someday?
author_id = await app.db.fetchval("""
author_id = await app.db.fetchval(
"""
SELECT author_id FROM messages
WHERE messages.id = $1
""", message_id)
""",
message_id,
)
by_perms = await channel_perm_check(
user_id, channel_id, 'manage_messages', False)
by_perms = await channel_perm_check(user_id, channel_id, "manage_messages", False)
by_author = author_id == user_id
can_suppress = by_perms or by_author
if not can_suppress:
raise Forbidden('Not enough permissions.')
raise Forbidden("Not enough permissions.")
j = validate(
await request.get_json(),
{'suppress': {'type': 'boolean'}},
)
j = validate(await request.get_json(), {"suppress": {"type": "boolean"}})
suppress = j['suppress']
suppress = j["suppress"]
message = await app.storage.get_message(message_id)
url_embeds = sum(
1 for embed in message['embeds'] if embed['type'] == 'url')
url_embeds = sum(1 for embed in message["embeds"] if embed["type"] == "url")
# NOTE for any future self. discord doing flags an optional thing instead
# of just giving 0 is a pretty bad idea because now i have to deal with
@ -693,8 +797,7 @@ async def suppress_embeds(channel_id: int, message_id: int):
# delete all embeds then dispatch an update
await _msg_set_flags(message_id, MessageFlags.suppress_embeds)
message['flags'] = \
message.get('flags', 0) | MessageFlags.suppress_embeds
message["flags"] = message.get("flags", 0) | MessageFlags.suppress_embeds
await msg_update_embeds(message, [], app.storage, app.dispatcher)
elif not suppress and not url_embeds:
@ -702,30 +805,29 @@ async def suppress_embeds(channel_id: int, message_id: int):
await _msg_unset_flags(message_id, MessageFlags.suppress_embeds)
try:
message.pop('flags')
message.pop("flags")
except KeyError:
pass
app.sched.spawn(
process_url_embed(
app.config, app.storage, app.dispatcher, app.session,
message
app.config, app.storage, app.dispatcher, app.session, message
)
)
return '', 204
return "", 204
@bp.route('/<int:channel_id>/messages/bulk-delete', methods=['POST'])
@bp.route("/<int:channel_id>/messages/bulk-delete", methods=["POST"])
async def bulk_delete(channel_id: int):
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
guild_id = guild_id if ctype in GUILD_CHANS else None
await channel_perm_check(user_id, channel_id, 'manage_messages')
await channel_perm_check(user_id, channel_id, "manage_messages")
j = validate(await request.get_json(), BULK_DELETE)
message_ids = set(j['messages'])
message_ids = set(j["messages"])
# as per discord behavior, if any id here is older than two weeks,
# we must error. a cuter behavior would be returning the message ids
@ -738,25 +840,28 @@ async def bulk_delete(channel_id: int):
raise BadRequest(50034)
payload = {
'guild_id': str(guild_id),
'channel_id': str(channel_id),
'ids': list(map(str, message_ids)),
"guild_id": str(guild_id),
"channel_id": str(channel_id),
"ids": list(map(str, message_ids)),
}
# payload.guild_id is optional in the event, not nullable.
if guild_id is None:
payload.pop('guild_id')
payload.pop("guild_id")
res = await app.db.execute("""
res = await app.db.execute(
"""
DELETE FROM messages
WHERE
channel_id = $1
AND ARRAY[id] <@ $2::bigint[]
""", channel_id, list(message_ids))
""",
channel_id,
list(message_ids),
)
if res == 'DELETE 0':
raise BadRequest('No messages were removed')
if res == "DELETE 0":
raise BadRequest("No messages were removed")
await app.dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_DELETE_BULK', payload)
return '', 204
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_DELETE_BULK", payload)
return "", 204

View File

@ -23,46 +23,57 @@ from quart import current_app as app
from litecord.enums import ChannelType, GUILD_CHANS
from litecord.errors import (
GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions
GuildNotFound,
ChannelNotFound,
Forbidden,
MissingPermissions,
)
from litecord.permissions import base_permissions, get_permissions
async def guild_check(user_id: int, guild_id: int):
"""Check if a user is in a guild."""
joined_at = await app.db.fetchval("""
joined_at = await app.db.fetchval(
"""
SELECT joined_at
FROM members
WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id)
""",
user_id,
guild_id,
)
if not joined_at:
raise GuildNotFound('guild not found')
raise GuildNotFound("guild not found")
async def guild_owner_check(user_id: int, guild_id: int):
"""Check if a user is the owner of the guild."""
owner_id = await app.db.fetchval("""
owner_id = await app.db.fetchval(
"""
SELECT owner_id
FROM guilds
WHERE guilds.id = $1
""", guild_id)
""",
guild_id,
)
if not owner_id:
raise GuildNotFound()
if user_id != owner_id:
raise Forbidden('You are not the owner of the guild')
raise Forbidden("You are not the owner of the guild")
async def channel_check(user_id, channel_id, *,
only: Union[ChannelType, List[ChannelType]] = None):
async def channel_check(
user_id, channel_id, *, only: Union[ChannelType, List[ChannelType]] = None
):
"""Check if the current user is authorized
to read the channel's information."""
chan_type = await app.storage.get_chan_type(channel_id)
if chan_type is None:
raise ChannelNotFound('channel type not found')
raise ChannelNotFound("channel type not found")
ctype = ChannelType(chan_type)
@ -70,14 +81,17 @@ async def channel_check(user_id, channel_id, *,
only = [only]
if only and ctype not in only:
raise ChannelNotFound('invalid channel type')
raise ChannelNotFound("invalid channel type")
if ctype in GUILD_CHANS:
guild_id = await app.db.fetchval("""
guild_id = await app.db.fetchval(
"""
SELECT guild_id
FROM guild_channels
WHERE guild_channels.id = $1
""", channel_id)
""",
channel_id,
)
await guild_check(user_id, guild_id)
return ctype, guild_id
@ -87,11 +101,14 @@ async def channel_check(user_id, channel_id, *,
return ctype, peer_id
if ctype == ChannelType.GROUP_DM:
owner_id = await app.db.fetchval("""
owner_id = await app.db.fetchval(
"""
SELECT owner_id
FROM group_dm_channels
WHERE id = $1
""", channel_id)
""",
channel_id,
)
return ctype, owner_id
@ -102,18 +119,17 @@ async def guild_perm_check(user_id, guild_id, permission: str):
hasperm = getattr(base_perms.bits, permission)
if not hasperm:
raise MissingPermissions('Missing permissions.')
raise MissingPermissions("Missing permissions.")
return bool(hasperm)
async def channel_perm_check(user_id, channel_id,
permission: str, raise_err=True):
async def channel_perm_check(user_id, channel_id, permission: str, raise_err=True):
"""Check channel permissions for a user."""
base_perms = await get_permissions(user_id, channel_id)
hasperm = getattr(base_perms.bits, permission)
if not hasperm and raise_err:
raise MissingPermissions('Missing permissions.')
raise MissingPermissions("Missing permissions.")
return bool(hasperm)

View File

@ -29,21 +29,29 @@ from litecord.system_messages import send_sys_message
from litecord.pubsub.channel import gdm_recipient_view
log = Logger(__name__)
bp = Blueprint('dm_channels', __name__)
bp = Blueprint("dm_channels", __name__)
async def _raw_gdm_add(channel_id, user_id):
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO group_dm_members (id, member_id)
VALUES ($1, $2)
""", channel_id, user_id)
""",
channel_id,
user_id,
)
async def _raw_gdm_remove(channel_id, user_id):
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM group_dm_members
WHERE id = $1 AND member_id = $2
""", channel_id, user_id)
""",
channel_id,
user_id,
)
async def gdm_create(user_id, peer_id) -> int:
@ -53,24 +61,32 @@ async def gdm_create(user_id, peer_id) -> int:
"""
channel_id = get_snowflake()
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
""", channel_id, ChannelType.GROUP_DM.value)
""",
channel_id,
ChannelType.GROUP_DM.value,
)
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO group_dm_channels (id, owner_id, name, icon)
VALUES ($1, $2, NULL, NULL)
""", channel_id, user_id)
""",
channel_id,
user_id,
)
await _raw_gdm_add(channel_id, user_id)
await _raw_gdm_add(channel_id, peer_id)
await app.dispatcher.sub('channel', channel_id, user_id)
await app.dispatcher.sub('channel', channel_id, peer_id)
await app.dispatcher.sub("channel", channel_id, user_id)
await app.dispatcher.sub("channel", channel_id, peer_id)
chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch('channel', channel_id, 'CHANNEL_CREATE', chan)
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_CREATE", chan)
return channel_id
@ -89,17 +105,16 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
# the reasoning behind gdm_recipient_view is in its docstring.
await app.dispatcher.dispatch(
'user', peer_id, 'CHANNEL_CREATE', gdm_recipient_view(chan, peer_id))
"user", peer_id, "CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)
)
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_UPDATE', chan)
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
await app.dispatcher.sub('channel', peer_id)
await app.dispatcher.sub("channel", peer_id)
if user_id:
await send_sys_message(
app, channel_id, MessageType.RECIPIENT_ADD,
user_id, peer_id
app, channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id
)
@ -116,22 +131,22 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch(
'user', peer_id, 'CHANNEL_DELETE', gdm_recipient_view(chan, user_id))
"user", peer_id, "CHANNEL_DELETE", gdm_recipient_view(chan, user_id)
)
await app.dispatcher.unsub('channel', peer_id)
await app.dispatcher.unsub("channel", peer_id)
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_RECIPIENT_REMOVE', {
'channel_id': str(channel_id),
'user': await app.storage.get_user(peer_id)
}
"channel",
channel_id,
"CHANNEL_RECIPIENT_REMOVE",
{"channel_id": str(channel_id), "user": await app.storage.get_user(peer_id)},
)
author_id = peer_id if user_id is None else user_id
await send_sys_message(
app, channel_id, MessageType.RECIPIENT_REMOVE,
author_id, peer_id
app, channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id
)
@ -139,40 +154,51 @@ async def gdm_destroy(channel_id):
"""Destroy a Group DM."""
chan = await app.storage.get_channel(channel_id)
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM group_dm_members
WHERE id = $1
""", channel_id)
await app.db.execute("""
DELETE FROM group_dm_channels
WHERE id = $1
""", channel_id)
await app.db.execute("""
DELETE FROM channels
WHERE id = $1
""", channel_id)
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_DELETE', chan
""",
channel_id,
)
await app.dispatcher.remove('channel', channel_id)
await app.db.execute(
"""
DELETE FROM group_dm_channels
WHERE id = $1
""",
channel_id,
)
await app.db.execute(
"""
DELETE FROM channels
WHERE id = $1
""",
channel_id,
)
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_DELETE", chan)
await app.dispatcher.remove("channel", channel_id)
async def gdm_is_member(channel_id: int, user_id: int) -> bool:
"""Return if the given user is a member of the Group DM."""
row = await app.db.fetchval("""
row = await app.db.fetchval(
"""
SELECT id
FROM group_dm_members
WHERE id = $1 AND member_id = $2
""", channel_id, user_id)
""",
channel_id,
user_id,
)
return row is not None
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['PUT'])
@bp.route("/<int:dm_chan>/recipients/<int:peer_id>", methods=["PUT"])
async def add_to_group_dm(dm_chan, peer_id):
"""Adds a member to a group dm OR creates a group dm."""
user_id = await token_check()
@ -182,8 +208,7 @@ async def add_to_group_dm(dm_chan, peer_id):
# other_id is the peer of the dm if the given channel is a dm
ctype, other_id = await channel_check(
user_id, dm_chan,
only=[ChannelType.DM, ChannelType.GROUP_DM]
user_id, dm_chan, only=[ChannelType.DM, ChannelType.GROUP_DM]
)
# check relationship with the given user id
@ -191,30 +216,24 @@ async def add_to_group_dm(dm_chan, peer_id):
friends = await app.user_storage.are_friends_with(user_id, peer_id)
if not friends:
raise BadRequest('Cant insert peer into dm')
raise BadRequest("Cant insert peer into dm")
if ctype == ChannelType.DM:
dm_chan = await gdm_create(
user_id, other_id
)
dm_chan = await gdm_create(user_id, other_id)
await gdm_add_recipient(dm_chan, peer_id, user_id=user_id)
return jsonify(
await app.storage.get_channel(dm_chan)
)
return jsonify(await app.storage.get_channel(dm_chan))
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['DELETE'])
@bp.route("/<int:dm_chan>/recipients/<int:peer_id>", methods=["DELETE"])
async def remove_from_group_dm(dm_chan, peer_id):
"""Remove users from group dm."""
user_id = await token_check()
_ctype, owner_id = await channel_check(
user_id, dm_chan, only=ChannelType.GROUP_DM
)
_ctype, owner_id = await channel_check(user_id, dm_chan, only=ChannelType.GROUP_DM)
if owner_id != user_id:
raise Forbidden('You are now the owner of the group DM')
raise Forbidden("You are now the owner of the group DM")
await gdm_remove_recipient(dm_chan, peer_id)
return '', 204
return "", 204

View File

@ -30,15 +30,13 @@ from ..snowflake import get_snowflake
from .auth import token_check
from litecord.blueprints.dm_channels import (
gdm_create, gdm_add_recipient
)
from litecord.blueprints.dm_channels import gdm_create, gdm_add_recipient
log = Logger(__name__)
bp = Blueprint('dms', __name__)
bp = Blueprint("dms", __name__)
@bp.route('/@me/channels', methods=['GET'])
@bp.route("/@me/channels", methods=["GET"])
async def get_dms():
"""Get the open DMs for the user."""
user_id = await token_check()
@ -53,11 +51,15 @@ async def try_dm_state(user_id: int, dm_id: int):
Does not do anything if the user is already
in the dm state.
"""
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO dm_channel_state (user_id, dm_id)
VALUES ($1, $2)
ON CONFLICT DO NOTHING
""", user_id, dm_id)
""",
user_id,
dm_id,
)
async def jsonify_dm(dm_id: int, user_id: int):
@ -69,12 +71,16 @@ async def create_dm(user_id, recipient_id):
"""Create a new dm with a user,
or get the existing DM id if it already exists."""
dm_id = await app.db.fetchval("""
dm_id = await app.db.fetchval(
"""
SELECT id
FROM dm_channels
WHERE (party1_id = $1 OR party2_id = $1) AND
(party1_id = $2 OR party2_id = $2)
""", user_id, recipient_id)
""",
user_id,
recipient_id,
)
if dm_id:
return await jsonify_dm(dm_id, user_id)
@ -82,15 +88,24 @@ async def create_dm(user_id, recipient_id):
# if no dm was found, create a new one
dm_id = get_snowflake()
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
""", dm_id, ChannelType.DM.value)
""",
dm_id,
ChannelType.DM.value,
)
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO dm_channels (id, party1_id, party2_id)
VALUES ($1, $2, $3)
""", dm_id, user_id, recipient_id)
""",
dm_id,
user_id,
recipient_id,
)
# the dm state is something we use
# to give the currently "open dms"
@ -103,24 +118,24 @@ async def create_dm(user_id, recipient_id):
return await jsonify_dm(dm_id, user_id)
@bp.route('/@me/channels', methods=['POST'])
@bp.route("/@me/channels", methods=["POST"])
async def start_dm():
"""Create a DM with a user."""
user_id = await token_check()
j = validate(await request.get_json(), CREATE_DM)
recipient_id = j['recipient_id']
recipient_id = j["recipient_id"]
return await create_dm(user_id, recipient_id)
@bp.route('/<int:p_user_id>/channels', methods=['POST'])
@bp.route("/<int:p_user_id>/channels", methods=["POST"])
async def create_group_dm(p_user_id: int):
"""Create a DM or a Group DM with user(s)."""
user_id = await token_check()
assert user_id == p_user_id
j = validate(await request.get_json(), CREATE_GROUP_DM)
recipients = j['recipients']
recipients = j["recipients"]
if len(recipients) == 1:
# its a group dm with 1 user... a dm!

View File

@ -23,37 +23,38 @@ from quart import Blueprint, jsonify, current_app as app
from ..auth import token_check
bp = Blueprint('gateway', __name__)
bp = Blueprint("gateway", __name__)
def get_gw():
"""Get the gateway's web"""
proto = 'wss://' if app.config['IS_SSL'] else 'ws://'
proto = "wss://" if app.config["IS_SSL"] else "ws://"
return f'{proto}{app.config["WEBSOCKET_URL"]}'
@bp.route('/gateway')
@bp.route("/gateway")
def api_gateway():
"""Get the raw URL."""
return jsonify({
'url': get_gw()
})
return jsonify({"url": get_gw()})
@bp.route('/gateway/bot')
@bp.route("/gateway/bot")
async def api_gateway_bot():
user_id = await token_check()
guild_count = await app.db.fetchval("""
guild_count = await app.db.fetchval(
"""
SELECT COUNT(*)
FROM members
WHERE user_id = $1
""", user_id)
""",
user_id,
)
shards = max(int(guild_count / 1000), 1)
# get _ws.session ratelimit
ratelimit = app.ratelimiter.get_ratelimit('_ws.session')
ratelimit = app.ratelimiter.get_ratelimit("_ws.session")
bucket = ratelimit.get_bucket(user_id)
# timestamp of bucket reset
@ -62,13 +63,14 @@ async def api_gateway_bot():
# how many seconds until bucket reset
reset_after_ts = reset_ts - time.time()
return jsonify({
'url': get_gw(),
'shards': shards,
'session_start_limit': {
'total': bucket.requests,
'remaining': bucket._tokens,
'reset_after': int(reset_after_ts * 1000),
return jsonify(
{
"url": get_gw(),
"shards": shards,
"session_start_limit": {
"total": bucket.requests,
"remaining": bucket._tokens,
"reset_after": int(reset_after_ts * 1000),
},
}
})
)

View File

@ -23,5 +23,4 @@ from .channels import bp as guild_channels
from .mod import bp as guild_mod
from .emoji import bp as guild_emoji
__all__ = ['guild_roles', 'guild_members', 'guild_channels', 'guild_mod',
'guild_emoji']
__all__ = ["guild_roles", "guild_members", "guild_channels", "guild_mod", "guild_emoji"]

View File

@ -25,23 +25,23 @@ from litecord.errors import BadRequest
from litecord.enums import ChannelType
from litecord.blueprints.guild.roles import gen_pairs
from litecord.schemas import (
validate, ROLE_UPDATE_POSITION, CHAN_CREATE
)
from litecord.blueprints.checks import (
guild_check, guild_owner_check, guild_perm_check
)
from litecord.schemas import validate, ROLE_UPDATE_POSITION, CHAN_CREATE
from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
bp = Blueprint('guild_channels', __name__)
bp = Blueprint("guild_channels", __name__)
async def _specific_chan_create(channel_id, ctype, **kwargs):
if ctype == ChannelType.GUILD_TEXT:
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO guild_text_channels (id, topic)
VALUES ($1, $2)
""", channel_id, kwargs.get('topic', ''))
""",
channel_id,
kwargs.get("topic", ""),
)
elif ctype == ChannelType.GUILD_VOICE:
await app.db.execute(
"""
@ -49,34 +49,48 @@ async def _specific_chan_create(channel_id, ctype, **kwargs):
VALUES ($1, $2, $3)
""",
channel_id,
kwargs.get('bitrate', 64),
kwargs.get('user_limit', 0)
kwargs.get("bitrate", 64),
kwargs.get("user_limit", 0),
)
async def create_guild_channel(guild_id: int, channel_id: int,
ctype: ChannelType, **kwargs):
async def create_guild_channel(
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
):
"""Create a channel in a guild."""
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
""", channel_id, ctype.value)
""",
channel_id,
ctype.value,
)
# calc new pos
max_pos = await app.db.fetchval("""
max_pos = await app.db.fetchval(
"""
SELECT MAX(position)
FROM guild_channels
WHERE guild_id = $1
""", guild_id)
""",
guild_id,
)
# account for the first channel in a guild too
max_pos = max_pos or 0
# all channels go to guild_channels
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO guild_channels (id, guild_id, name, position)
VALUES ($1, $2, $3, $4)
""", channel_id, guild_id, kwargs['name'], max_pos + 1)
""",
channel_id,
guild_id,
kwargs["name"],
max_pos + 1,
)
# the rest of sql magic is dependant on the channel
# we're creating (a text or voice or category),
@ -84,35 +98,32 @@ async def create_guild_channel(guild_id: int, channel_id: int,
await _specific_chan_create(channel_id, ctype, **kwargs)
@bp.route('/<int:guild_id>/channels', methods=['GET'])
@bp.route("/<int:guild_id>/channels", methods=["GET"])
async def get_guild_channels(guild_id):
"""Get the list of channels in a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
return jsonify(
await app.storage.get_channel_data(guild_id))
return jsonify(await app.storage.get_channel_data(guild_id))
@bp.route('/<int:guild_id>/channels', methods=['POST'])
@bp.route("/<int:guild_id>/channels", methods=["POST"])
async def create_channel(guild_id):
"""Create a channel in a guild."""
user_id = await token_check()
j = validate(await request.get_json(), CHAN_CREATE)
await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_channels')
await guild_perm_check(user_id, guild_id, "manage_channels")
channel_type = j.get('type', ChannelType.GUILD_TEXT)
channel_type = j.get("type", ChannelType.GUILD_TEXT)
channel_type = ChannelType(channel_type)
if channel_type not in (ChannelType.GUILD_TEXT,
ChannelType.GUILD_VOICE):
raise BadRequest('Invalid channel type')
if channel_type not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE):
raise BadRequest("Invalid channel type")
new_channel_id = get_snowflake()
await create_guild_channel(
guild_id, new_channel_id, channel_type, **j)
await create_guild_channel(guild_id, new_channel_id, channel_type, **j)
# TODO: do a better method
# subscribe the currently subscribed users to the new channel
@ -120,14 +131,13 @@ async def create_channel(guild_id):
# since GuildDispatcher calls Storage.get_channel_ids,
# it will subscribe all users to the newly created channel.
guild_pubsub = app.dispatcher.backends['guild']
guild_pubsub = app.dispatcher.backends["guild"]
user_ids = guild_pubsub.state[guild_id]
for uid in user_ids:
await app.dispatcher.sub('guild', guild_id, uid)
await app.dispatcher.sub("guild", guild_id, uid)
chan = await app.storage.get_channel(new_channel_id)
await app.dispatcher.dispatch_guild(
guild_id, 'CHANNEL_CREATE', chan)
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_CREATE", chan)
return jsonify(chan)
@ -135,7 +145,7 @@ async def _chan_update_dispatch(guild_id: int, channel_id: int):
"""Fetch new information about the channel and dispatch
a single CHANNEL_UPDATE event to the guild."""
chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_UPDATE', chan)
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", chan)
async def _do_single_swap(guild_id: int, pair: tuple):
@ -149,13 +159,14 @@ async def _do_single_swap(guild_id: int, pair: tuple):
conn = await app.db.acquire()
async with conn.transaction():
await conn.executemany("""
await conn.executemany(
"""
UPDATE guild_channels
SET position = $1
WHERE id = $2 AND guild_id = $3
""", [
(new_pos_1, channel_1, guild_id),
(new_pos_2, channel_2, guild_id)])
""",
[(new_pos_1, channel_1, guild_id), (new_pos_2, channel_2, guild_id)],
)
await app.db.release(conn)
@ -173,30 +184,26 @@ async def _do_channel_swaps(guild_id: int, swap_pairs: list):
await _do_single_swap(guild_id, pair)
@bp.route('/<int:guild_id>/channels', methods=['PATCH'])
@bp.route("/<int:guild_id>/channels", methods=["PATCH"])
async def modify_channel_pos(guild_id):
"""Change positions of channels in a guild."""
user_id = await token_check()
await guild_owner_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_channels')
await guild_perm_check(user_id, guild_id, "manage_channels")
# same thing as guild.roles, so we use
# the same schema and all.
raw_j = await request.get_json()
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
j = j['roles']
j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
j = j["roles"]
channels = await app.storage.get_channel_data(guild_id)
channel_positions = {chan['position']: int(chan['id'])
for chan in channels}
channel_positions = {chan["position"]: int(chan["id"]) for chan in channels}
swap_pairs = gen_pairs(
j,
channel_positions
)
swap_pairs = gen_pairs(j, channel_positions)
await _do_channel_swaps(guild_id, swap_pairs)
return '', 204
return "", 204

View File

@ -27,78 +27,85 @@ from litecord.types import KILOBYTES
from litecord.images import parse_data_uri
from litecord.errors import BadRequest
bp = Blueprint('guild.emoji', __name__)
bp = Blueprint("guild.emoji", __name__)
async def _dispatch_emojis(guild_id):
"""Dispatch a Guild Emojis Update payload to a guild."""
await app.dispatcher.dispatch('guild', guild_id, 'GUILD_EMOJIS_UPDATE', {
'guild_id': str(guild_id),
'emojis': await app.storage.get_guild_emojis(guild_id)
})
await app.dispatcher.dispatch(
"guild",
guild_id,
"GUILD_EMOJIS_UPDATE",
{
"guild_id": str(guild_id),
"emojis": await app.storage.get_guild_emojis(guild_id),
},
)
@bp.route('/<int:guild_id>/emojis', methods=['GET'])
@bp.route("/<int:guild_id>/emojis", methods=["GET"])
async def _get_guild_emoji(guild_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
return jsonify(
await app.storage.get_guild_emojis(guild_id)
)
return jsonify(await app.storage.get_guild_emojis(guild_id))
@bp.route('/<int:guild_id>/emojis/<int:emoji_id>', methods=['GET'])
@bp.route("/<int:guild_id>/emojis/<int:emoji_id>", methods=["GET"])
async def _get_guild_emoji_one(guild_id, emoji_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
return jsonify(
await app.storage.get_emoji(emoji_id)
)
return jsonify(await app.storage.get_emoji(emoji_id))
async def _guild_emoji_size_check(guild_id: int, mime: str):
limit = 50
if await app.storage.has_feature(guild_id, 'MORE_EMOJI'):
if await app.storage.has_feature(guild_id, "MORE_EMOJI"):
limit = 200
# NOTE: I'm assuming you can have 200 animated emojis.
select_animated = mime == 'image/gif'
select_animated = mime == "image/gif"
total_emoji = await app.db.fetchval("""
total_emoji = await app.db.fetchval(
"""
SELECT COUNT(*) FROM guild_emoji
WHERE guild_id = $1 AND animated = $2
""", guild_id, select_animated)
""",
guild_id,
select_animated,
)
if total_emoji >= limit:
# TODO: really return a BadRequest? needs more looking.
raise BadRequest(f'too many emoji ({limit})')
raise BadRequest(f"too many emoji ({limit})")
@bp.route('/<int:guild_id>/emojis', methods=['POST'])
@bp.route("/<int:guild_id>/emojis", methods=["POST"])
async def _put_emoji(guild_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_emojis')
await guild_perm_check(user_id, guild_id, "manage_emojis")
j = validate(await request.get_json(), NEW_EMOJI)
# we have to parse it before passing on so that we know which
# size to check.
mime, _ = parse_data_uri(j['image'])
mime, _ = parse_data_uri(j["image"])
await _guild_emoji_size_check(guild_id, mime)
emoji_id = get_snowflake()
icon = await app.icons.put(
'emoji', emoji_id, j['image'],
"emoji",
emoji_id,
j["image"],
# limits to emojis
bsize=128 * KILOBYTES, size=(128, 128)
bsize=128 * KILOBYTES,
size=(128, 128),
)
if not icon:
return '', 400
return "", 400
# TODO: better way to detect animated emoji rather than just gifs,
# maybe a list perhaps?
@ -109,25 +116,25 @@ async def _put_emoji(guild_id):
VALUES
($1, $2, $3, $4, $5, $6)
""",
emoji_id, guild_id, user_id,
j['name'],
emoji_id,
guild_id,
user_id,
j["name"],
icon.icon_hash,
icon.mime == 'image/gif'
icon.mime == "image/gif",
)
await _dispatch_emojis(guild_id)
return jsonify(
await app.storage.get_emoji(emoji_id)
)
return jsonify(await app.storage.get_emoji(emoji_id))
@bp.route('/<int:guild_id>/emojis/<int:emoji_id>', methods=['PATCH'])
@bp.route("/<int:guild_id>/emojis/<int:emoji_id>", methods=["PATCH"])
async def _patch_emoji(guild_id, emoji_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_emojis')
await guild_perm_check(user_id, guild_id, "manage_emojis")
j = validate(await request.get_json(), PATCH_EMOJI)
emoji = await app.storage.get_emoji(emoji_id)
@ -135,34 +142,39 @@ async def _patch_emoji(guild_id, emoji_id):
# if emoji.name is still the same, we don't update anything
# or send ane events, just return the same emoji we'd send
# as if we updated it.
if j['name'] == emoji['name']:
if j["name"] == emoji["name"]:
return jsonify(emoji)
await app.db.execute("""
await app.db.execute(
"""
UPDATE guild_emoji
SET name = $1
WHERE id = $2
""", j['name'], emoji_id)
""",
j["name"],
emoji_id,
)
await _dispatch_emojis(guild_id)
return jsonify(
await app.storage.get_emoji(emoji_id)
)
return jsonify(await app.storage.get_emoji(emoji_id))
@bp.route('/<int:guild_id>/emojis/<int:emoji_id>', methods=['DELETE'])
@bp.route("/<int:guild_id>/emojis/<int:emoji_id>", methods=["DELETE"])
async def _del_emoji(guild_id, emoji_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_emojis')
await guild_perm_check(user_id, guild_id, "manage_emojis")
# TODO: check if actually deleted
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM guild_emoji
WHERE id = $2
""", emoji_id)
""",
emoji_id,
)
await _dispatch_emojis(guild_id)
return '', 204
return "", 204

View File

@ -22,18 +22,14 @@ from quart import Blueprint, request, current_app as app, jsonify
from litecord.blueprints.auth import token_check
from litecord.errors import BadRequest
from litecord.schemas import (
validate, MEMBER_UPDATE
)
from litecord.schemas import validate, MEMBER_UPDATE
from litecord.blueprints.checks import (
guild_check, guild_owner_check, guild_perm_check
)
from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
bp = Blueprint('guild_members', __name__)
bp = Blueprint("guild_members", __name__)
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['GET'])
@bp.route("/<int:guild_id>/members/<int:member_id>", methods=["GET"])
async def get_guild_member(guild_id, member_id):
"""Get a member's information in a guild."""
user_id = await token_check()
@ -42,7 +38,7 @@ async def get_guild_member(guild_id, member_id):
return jsonify(member)
@bp.route('/<int:guild_id>/members', methods=['GET'])
@bp.route("/<int:guild_id>/members", methods=["GET"])
async def get_members(guild_id):
"""Get members inside a guild."""
user_id = await token_check()
@ -50,34 +46,41 @@ async def get_members(guild_id):
j = await request.get_json()
limit, after = int(j.get('limit', 1)), j.get('after', 0)
limit, after = int(j.get("limit", 1)), j.get("after", 0)
if limit < 1 or limit > 1000:
raise BadRequest('limit not in 1-1000 range')
raise BadRequest("limit not in 1-1000 range")
user_ids = await app.db.fetch(f"""
user_ids = await app.db.fetch(
f"""
SELECT user_id
WHERE guild_id = $1, user_id > $2
LIMIT {limit}
ORDER BY user_id ASC
""", guild_id, after)
""",
guild_id,
after,
)
user_ids = [r[0] for r in user_ids]
members = await app.storage.get_member_multi(guild_id, user_ids)
return jsonify(members)
async def _update_member_roles(guild_id: int, member_id: int,
wanted_roles: set):
async def _update_member_roles(guild_id: int, member_id: int, wanted_roles: set):
"""Update the roles a member has."""
# first, fetch all current roles
roles = await app.db.fetch("""
roles = await app.db.fetch(
"""
SELECT role_id from member_roles
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
""",
guild_id,
member_id,
)
roles = [r['role_id'] for r in roles]
roles = [r["role_id"] for r in roles]
roles = set(roles)
wanted_roles = set(wanted_roles)
@ -96,26 +99,30 @@ async def _update_member_roles(guild_id: int, member_id: int,
async with conn.transaction():
# add roles
await app.db.executemany("""
await app.db.executemany(
"""
INSERT INTO member_roles (user_id, guild_id, role_id)
VALUES ($1, $2, $3)
""", [(member_id, guild_id, role_id)
for role_id in added_roles])
""",
[(member_id, guild_id, role_id) for role_id in added_roles],
)
# remove roles
await app.db.executemany("""
await app.db.executemany(
"""
DELETE FROM member_roles
WHERE
user_id = $1
AND guild_id = $2
AND role_id = $3
""", [(member_id, guild_id, role_id)
for role_id in removed_roles])
""",
[(member_id, guild_id, role_id) for role_id in removed_roles],
)
await app.db.release(conn)
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['PATCH'])
@bp.route("/<int:guild_id>/members/<int:member_id>", methods=["PATCH"])
async def modify_guild_member(guild_id, member_id):
"""Modify a members' information in a guild."""
user_id = await token_check()
@ -124,96 +131,112 @@ async def modify_guild_member(guild_id, member_id):
j = validate(await request.get_json(), MEMBER_UPDATE)
nick_flag = False
if 'nick' in j:
await guild_perm_check(user_id, guild_id, 'manage_nicknames')
if "nick" in j:
await guild_perm_check(user_id, guild_id, "manage_nicknames")
nick = j['nick'] or None
nick = j["nick"] or None
await app.db.execute("""
await app.db.execute(
"""
UPDATE members
SET nickname = $1
WHERE user_id = $2 AND guild_id = $3
""", nick, member_id, guild_id)
""",
nick,
member_id,
guild_id,
)
nick_flag = True
if 'mute' in j:
await guild_perm_check(user_id, guild_id, 'mute_members')
if "mute" in j:
await guild_perm_check(user_id, guild_id, "mute_members")
await app.db.execute("""
await app.db.execute(
"""
UPDATE members
SET muted = $1
WHERE user_id = $2 AND guild_id = $3
""", j['mute'], member_id, guild_id)
""",
j["mute"],
member_id,
guild_id,
)
if 'deaf' in j:
await guild_perm_check(user_id, guild_id, 'deafen_members')
if "deaf" in j:
await guild_perm_check(user_id, guild_id, "deafen_members")
await app.db.execute("""
await app.db.execute(
"""
UPDATE members
SET deafened = $1
WHERE user_id = $2 AND guild_id = $3
""", j['deaf'], member_id, guild_id)
""",
j["deaf"],
member_id,
guild_id,
)
if 'channel_id' in j:
if "channel_id" in j:
# TODO: check MOVE_MEMBERS and CONNECT to the channel
# TODO: change the member's voice channel
pass
if 'roles' in j:
await guild_perm_check(user_id, guild_id, 'manage_roles')
await _update_member_roles(guild_id, member_id, j['roles'])
if "roles" in j:
await guild_perm_check(user_id, guild_id, "manage_roles")
await _update_member_roles(guild_id, member_id, j["roles"])
member = await app.storage.get_member_data_one(guild_id, member_id)
member.pop('joined_at')
member.pop("joined_at")
# call pres_update for role and nick changes.
partial = {
'roles': member['roles']
}
partial = {"roles": member["roles"]}
if nick_flag:
partial['nick'] = j['nick']
partial["nick"] = j["nick"]
await app.dispatcher.dispatch(
'lazy_guild', guild_id, 'pres_update', user_id, partial)
"lazy_guild", guild_id, "pres_update", user_id, partial
)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
'guild_id': str(guild_id)
}, **member})
await app.dispatcher.dispatch_guild(
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
)
return '', 204
return "", 204
@bp.route('/<int:guild_id>/members/@me/nick', methods=['PATCH'])
@bp.route("/<int:guild_id>/members/@me/nick", methods=["PATCH"])
async def update_nickname(guild_id):
"""Update a member's nickname in a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
j = validate(await request.get_json(), {
'nick': {'type': 'nickname'}
})
j = validate(await request.get_json(), {"nick": {"type": "nickname"}})
nick = j['nick'] or None
nick = j["nick"] or None
await app.db.execute("""
await app.db.execute(
"""
UPDATE members
SET nickname = $1
WHERE user_id = $2 AND guild_id = $3
""", nick, user_id, guild_id)
""",
nick,
user_id,
guild_id,
)
member = await app.storage.get_member_data_one(guild_id, user_id)
member.pop('joined_at')
member.pop("joined_at")
# call pres_update for nick changes, etc.
await app.dispatcher.dispatch(
'lazy_guild', guild_id, 'pres_update', user_id, {
'nick': j['nick']
})
"lazy_guild", guild_id, "pres_update", user_id, {"nick": j["nick"]}
)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
'guild_id': str(guild_id)
}, **member})
await app.dispatcher.dispatch_guild(
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
)
return j['nick']
return j["nick"]

View File

@ -24,33 +24,38 @@ from litecord.blueprints.checks import guild_perm_check
from litecord.schemas import validate, GUILD_PRUNE
bp = Blueprint('guild_moderation', __name__)
bp = Blueprint("guild_moderation", __name__)
async def remove_member(guild_id: int, member_id: int):
"""Do common tasks related to deleting a member from the guild,
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM members
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
""",
guild_id,
member_id,
)
await app.dispatcher.dispatch_user_guild(
member_id, guild_id, 'GUILD_DELETE', {
'guild_id': str(guild_id),
'unavailable': False,
})
member_id,
guild_id,
"GUILD_DELETE",
{"guild_id": str(guild_id), "unavailable": False},
)
await app.dispatcher.unsub('guild', guild_id, member_id)
await app.dispatcher.unsub("guild", guild_id, member_id)
await app.dispatcher.dispatch(
'lazy_guild', guild_id, 'remove_member', member_id)
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
'guild_id': str(guild_id),
'user': await app.storage.get_user(member_id),
})
await app.dispatcher.dispatch_guild(
guild_id,
"GUILD_MEMBER_REMOVE",
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
)
async def remove_member_multi(guild_id: int, members: list):
@ -59,84 +64,100 @@ async def remove_member_multi(guild_id: int, members: list):
await remove_member(guild_id, member_id)
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['DELETE'])
@bp.route("/<int:guild_id>/members/<int:member_id>", methods=["DELETE"])
async def kick_guild_member(guild_id, member_id):
"""Remove a member from a guild."""
user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'kick_members')
await guild_perm_check(user_id, guild_id, "kick_members")
await remove_member(guild_id, member_id)
return '', 204
return "", 204
@bp.route('/<int:guild_id>/bans', methods=['GET'])
@bp.route("/<int:guild_id>/bans", methods=["GET"])
async def get_bans(guild_id):
user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'ban_members')
await guild_perm_check(user_id, guild_id, "ban_members")
bans = await app.db.fetch("""
bans = await app.db.fetch(
"""
SELECT user_id, reason
FROM bans
WHERE bans.guild_id = $1
""", guild_id)
""",
guild_id,
)
res = []
for ban in bans:
res.append({
'reason': ban['reason'],
'user': await app.storage.get_user(ban['user_id'])
})
res.append(
{
"reason": ban["reason"],
"user": await app.storage.get_user(ban["user_id"]),
}
)
return jsonify(res)
@bp.route('/<int:guild_id>/bans/<int:member_id>', methods=['PUT'])
@bp.route("/<int:guild_id>/bans/<int:member_id>", methods=["PUT"])
async def create_ban(guild_id, member_id):
user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'ban_members')
await guild_perm_check(user_id, guild_id, "ban_members")
j = await request.get_json()
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO bans (guild_id, user_id, reason)
VALUES ($1, $2, $3)
""", guild_id, member_id, j.get('reason', ''))
""",
guild_id,
member_id,
j.get("reason", ""),
)
await remove_member(guild_id, member_id)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {
'guild_id': str(guild_id),
'user': await app.storage.get_user(member_id)
})
await app.dispatcher.dispatch_guild(
guild_id,
"GUILD_BAN_ADD",
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
)
return '', 204
return "", 204
@bp.route('/<int:guild_id>/bans/<int:banned_id>', methods=['DELETE'])
@bp.route("/<int:guild_id>/bans/<int:banned_id>", methods=["DELETE"])
async def remove_ban(guild_id, banned_id):
user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'ban_members')
await guild_perm_check(user_id, guild_id, "ban_members")
res = await app.db.execute("""
res = await app.db.execute(
"""
DELETE FROM bans
WHERE guild_id = $1 AND user_id = $@
""", guild_id, banned_id)
""",
guild_id,
banned_id,
)
# we don't really need to dispatch GUILD_BAN_REMOVE
# when no bans were actually removed.
if res == 'DELETE 0':
return '', 204
if res == "DELETE 0":
return "", 204
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', {
'guild_id': str(guild_id),
'user': await app.storage.get_user(banned_id)
})
await app.dispatcher.dispatch_guild(
guild_id,
"GUILD_BAN_REMOVE",
{"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)},
)
return '', 204
return "", 204
async def get_prune(guild_id: int, days: int) -> list:
@ -146,23 +167,30 @@ async def get_prune(guild_id: int, days: int) -> list:
- don't have any roles.
"""
# a good solution would be in pure sql.
member_ids = await app.db.fetch(f"""
member_ids = await app.db.fetch(
f"""
SELECT id
FROM users
JOIN members
ON members.guild_id = $1 AND members.user_id = users.id
WHERE users.last_session < (now() - (interval '{days} days'))
""", guild_id)
""",
guild_id,
)
member_ids = [r['id'] for r in member_ids]
member_ids = [r["id"] for r in member_ids]
members = []
for member_id in member_ids:
role_count = await app.db.fetchval("""
role_count = await app.db.fetchval(
"""
SELECT COUNT(*)
FROM member_roles
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
""",
guild_id,
member_id,
)
if role_count == 0:
members.append(member_id)
@ -170,33 +198,29 @@ async def get_prune(guild_id: int, days: int) -> list:
return members
@bp.route('/<int:guild_id>/prune', methods=['GET'])
@bp.route("/<int:guild_id>/prune", methods=["GET"])
async def get_guild_prune_count(guild_id):
user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'kick_members')
await guild_perm_check(user_id, guild_id, "kick_members")
j = validate(request.args, GUILD_PRUNE)
days = j['days']
days = j["days"]
member_ids = await get_prune(guild_id, days)
return jsonify({
'pruned': len(member_ids),
})
return jsonify({"pruned": len(member_ids)})
@bp.route('/<int:guild_id>/prune', methods=['POST'])
@bp.route("/<int:guild_id>/prune", methods=["POST"])
async def begin_guild_prune(guild_id):
user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'kick_members')
await guild_perm_check(user_id, guild_id, "kick_members")
j = validate(request.args, GUILD_PRUNE)
days = j['days']
days = j["days"]
member_ids = await get_prune(guild_id, days)
app.loop.create_task(remove_member_multi(guild_id, member_ids))
return jsonify({
'pruned': len(member_ids)
})
return jsonify({"pruned": len(member_ids)})

View File

@ -24,12 +24,8 @@ from logbook import Logger
from litecord.auth import token_check
from litecord.blueprints.checks import (
guild_check, guild_perm_check
)
from litecord.schemas import (
validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
)
from litecord.blueprints.checks import guild_check, guild_perm_check
from litecord.schemas import validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
from litecord.snowflake import get_snowflake
from litecord.utils import dict_get
@ -37,22 +33,19 @@ from litecord.permissions import get_role_perms
DEFAULT_EVERYONE_PERMS = 104324161
log = Logger(__name__)
bp = Blueprint('guild_roles', __name__)
bp = Blueprint("guild_roles", __name__)
@bp.route('/<int:guild_id>/roles', methods=['GET'])
@bp.route("/<int:guild_id>/roles", methods=["GET"])
async def get_guild_roles(guild_id):
"""Get all roles in a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
return jsonify(
await app.storage.get_role_data(guild_id)
)
return jsonify(await app.storage.get_role_data(guild_id))
async def _maybe_lg(guild_id: int, event: str,
role, force: bool = False):
async def _maybe_lg(guild_id: int, event: str, role, force: bool = False):
# sometimes we want to dispatch an event
# even if the role isn't hoisted
@ -61,11 +54,10 @@ async def _maybe_lg(guild_id: int, event: str,
# check if is a dict first because role_delete
# only receives the role id.
if isinstance(role, dict) and not role['hoist'] and not force:
if isinstance(role, dict) and not role["hoist"] and not force:
return
await app.dispatcher.dispatch(
'lazy_guild', guild_id, event, role)
await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
async def create_role(guild_id, name: str, **kwargs):
@ -73,18 +65,20 @@ async def create_role(guild_id, name: str, **kwargs):
new_role_id = get_snowflake()
everyone_perms = await get_role_perms(guild_id, guild_id)
default_perms = dict_get(kwargs, 'default_perms',
everyone_perms.binary)
default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
# update all roles so that we have space for pos 1, but without
# sending GUILD_ROLE_UPDATE for everyone
await app.db.execute("""
await app.db.execute(
"""
UPDATE roles
SET
position = position + 1
WHERE guild_id = $1
AND NOT (position = 0)
""", guild_id)
""",
guild_id,
)
await app.db.execute(
"""
@ -95,42 +89,39 @@ async def create_role(guild_id, name: str, **kwargs):
new_role_id,
guild_id,
name,
dict_get(kwargs, 'color', 0),
dict_get(kwargs, 'hoist', False),
dict_get(kwargs, "color", 0),
dict_get(kwargs, "hoist", False),
# always set ourselves on position 1
1,
int(dict_get(kwargs, 'permissions', default_perms)),
int(dict_get(kwargs, "permissions", default_perms)),
False,
dict_get(kwargs, 'mentionable', False)
dict_get(kwargs, "mentionable", False),
)
role = await app.storage.get_role(new_role_id, guild_id)
# we need to update the lazy guild handlers for the newly created group
await _maybe_lg(guild_id, 'new_role', role)
await _maybe_lg(guild_id, "new_role", role)
await app.dispatcher.dispatch_guild(
guild_id, 'GUILD_ROLE_CREATE', {
'guild_id': str(guild_id),
'role': role,
})
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
)
return role
@bp.route('/<int:guild_id>/roles', methods=['POST'])
@bp.route("/<int:guild_id>/roles", methods=["POST"])
async def create_guild_role(guild_id: int):
"""Add a role to a guild"""
user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'manage_roles')
await guild_perm_check(user_id, guild_id, "manage_roles")
# client can just send null
j = validate(await request.get_json() or {}, ROLE_CREATE)
role_name = j['name']
j.pop('name')
role_name = j["name"]
j.pop("name")
role = await create_role(guild_id, role_name, **j)
@ -141,12 +132,11 @@ async def _role_update_dispatch(role_id: int, guild_id: int):
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
role = await app.storage.get_role(role_id, guild_id)
await _maybe_lg(guild_id, 'role_pos_upd', role)
await _maybe_lg(guild_id, "role_pos_upd", role)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', {
'guild_id': str(guild_id),
'role': role,
})
await app.dispatcher.dispatch_guild(
guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}
)
return role
@ -166,17 +156,25 @@ async def _role_pairs_update(guild_id: int, pairs: list):
async with conn.transaction():
# update happens in a transaction
# so we don't fuck it up
await conn.execute("""
await conn.execute(
"""
UPDATE roles
SET position = $1
WHERE roles.id = $2
""", new_pos_1, role_1)
""",
new_pos_1,
role_1,
)
await conn.execute("""
await conn.execute(
"""
UPDATE roles
SET position = $1
WHERE roles.id = $2
""", new_pos_2, role_2)
""",
new_pos_2,
role_2,
)
await app.db.release(conn)
@ -184,11 +182,15 @@ async def _role_pairs_update(guild_id: int, pairs: list):
await _role_update_dispatch(role_1, guild_id)
await _role_update_dispatch(role_2, guild_id)
PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
def gen_pairs(list_of_changes: List[Dict[str, int]],
def gen_pairs(
list_of_changes: List[Dict[str, int]],
current_state: Dict[int, int],
blacklist: List[int] = None) -> PairList:
blacklist: List[int] = None,
) -> PairList:
"""Generate a list of pairs that, when applied to the database,
will generate the desired state given in list_of_changes.
@ -230,8 +232,9 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
pairs = []
blacklist = blacklist or []
preferred_state = {element['id']: element['position']
for element in list_of_changes}
preferred_state = {
element["id"]: element["position"] for element in list_of_changes
}
for blacklisted_id in blacklist:
preferred_state.pop(blacklisted_id)
@ -239,7 +242,7 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
# for each change, we must find a matching change
# in the same list, so we can make a swap pair
for change in list_of_changes:
element_1, new_pos_1 = change['id'], change['position']
element_1, new_pos_1 = change["id"], change["position"]
# check current pairs
# so we don't repeat an element
@ -267,36 +270,34 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
# if its being swapped to leave space, add it
# to the pairs list
if new_pos_2 is not None:
pairs.append(
((element_1, new_pos_1), (element_2, new_pos_2))
)
pairs.append(((element_1, new_pos_1), (element_2, new_pos_2)))
return pairs
@bp.route('/<int:guild_id>/roles', methods=['PATCH'])
@bp.route("/<int:guild_id>/roles", methods=["PATCH"])
async def update_guild_role_positions(guild_id):
"""Update the positions for a bunch of roles."""
user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'manage_roles')
await guild_perm_check(user_id, guild_id, "manage_roles")
raw_j = await request.get_json()
# we need to do this hackiness because thats
# cerberus for ya.
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
# extract the list out
j = j['roles']
j = j["roles"]
log.debug('role stuff: {!r}', j)
log.debug("role stuff: {!r}", j)
all_roles = await app.storage.get_role_data(guild_id)
# we'll have to calculate pairs of changing roles,
# then do the changes, etc.
roles_pos = {role['position']: int(role['id']) for role in all_roles}
roles_pos = {role["position"]: int(role["id"]) for role in all_roles}
# TODO: check if the user can even change the roles in the first place,
# preferrably when we have a proper perms system.
@ -306,10 +307,9 @@ async def update_guild_role_positions(guild_id):
pairs = gen_pairs(
j,
roles_pos,
# always ignore people trying to change
# the @everyone's role position
[guild_id]
[guild_id],
)
await _role_pairs_update(guild_id, pairs)
@ -318,31 +318,36 @@ async def update_guild_role_positions(guild_id):
return jsonify(await app.storage.get_role_data(guild_id))
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['PATCH'])
@bp.route("/<int:guild_id>/roles/<int:role_id>", methods=["PATCH"])
async def update_guild_role(guild_id, role_id):
"""Update a single role's information."""
user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'manage_roles')
await guild_perm_check(user_id, guild_id, "manage_roles")
j = validate(await request.get_json(), ROLE_UPDATE)
# we only update ints on the db, not Permissions
j['permissions'] = int(j['permissions'])
j["permissions"] = int(j["permissions"])
for field in j:
await app.db.execute(f"""
await app.db.execute(
f"""
UPDATE roles
SET {field} = $1
WHERE roles.id = $2 AND roles.guild_id = $3
""", j[field], role_id, guild_id)
""",
j[field],
role_id,
guild_id,
)
role = await _role_update_dispatch(role_id, guild_id)
await _maybe_lg(guild_id, 'role_update', role, True)
await _maybe_lg(guild_id, "role_update", role, True)
return jsonify(role)
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['DELETE'])
@bp.route("/<int:guild_id>/roles/<int:role_id>", methods=["DELETE"])
async def delete_guild_role(guild_id, role_id):
"""Delete a role.
@ -350,21 +355,26 @@ async def delete_guild_role(guild_id, role_id):
"""
user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'manage_roles')
await guild_perm_check(user_id, guild_id, "manage_roles")
res = await app.db.execute("""
res = await app.db.execute(
"""
DELETE FROM roles
WHERE guild_id = $1 AND id = $2
""", guild_id, role_id)
""",
guild_id,
role_id,
)
if res == 'DELETE 0':
return '', 204
if res == "DELETE 0":
return "", 204
await _maybe_lg(guild_id, 'role_delete', role_id, True)
await _maybe_lg(guild_id, "role_delete", role_id, True)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_DELETE', {
'guild_id': str(guild_id),
'role_id': str(role_id),
})
await app.dispatcher.dispatch_guild(
guild_id,
"GUILD_ROLE_DELETE",
{"guild_id": str(guild_id), "role_id": str(role_id)},
)
return '', 204
return "", 204

View File

@ -22,16 +22,17 @@ from typing import Optional, List
from quart import Blueprint, request, current_app as app, jsonify
from litecord.blueprints.guild.channels import create_guild_channel
from litecord.blueprints.guild.roles import (
create_role, DEFAULT_EVERYONE_PERMS
)
from litecord.blueprints.guild.roles import create_role, DEFAULT_EVERYONE_PERMS
from ..auth import token_check
from ..snowflake import get_snowflake
from ..enums import ChannelType
from ..schemas import (
validate, GUILD_CREATE, GUILD_UPDATE, SEARCH_CHANNEL,
VANITY_URL_PATCH
validate,
GUILD_CREATE,
GUILD_UPDATE,
SEARCH_CHANNEL,
VANITY_URL_PATCH,
)
from .channels import channel_ack
from .checks import guild_check, guild_owner_check, guild_perm_check
@ -40,7 +41,7 @@ from litecord.errors import BadRequest
from litecord.permissions import get_permissions
bp = Blueprint('guilds', __name__)
bp = Blueprint("guilds", __name__)
async def create_guild_settings(guild_id: int, user_id: int):
@ -49,26 +50,38 @@ async def create_guild_settings(guild_id: int, user_id: int):
# new guild_settings are based off the currently
# set guild settings (for the guild)
m_notifs = await app.db.fetchval("""
m_notifs = await app.db.fetchval(
"""
SELECT default_message_notifications
FROM guilds
WHERE id = $1
""", guild_id)
""",
guild_id,
)
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO guild_settings
(user_id, guild_id, message_notifications)
VALUES
($1, $2, $3)
""", user_id, guild_id, m_notifs)
""",
user_id,
guild_id,
m_notifs,
)
async def add_member(guild_id: int, user_id: int):
"""Add a user to a guild."""
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO members (user_id, guild_id)
VALUES ($1, $2)
""", user_id, guild_id)
""",
user_id,
guild_id,
)
await create_guild_settings(guild_id, user_id)
@ -83,28 +96,28 @@ async def guild_create_roles_prep(guild_id: int, roles: list):
# are patches to the @everyone role
everyone_patches = roles[0]
for field in everyone_patches:
await app.db.execute(f"""
await app.db.execute(
f"""
UPDATE roles
SET {field}={everyone_patches[field]}
WHERE roles.id = $1
""", guild_id)
""",
guild_id,
)
default_perms = (everyone_patches.get('permissions')
or DEFAULT_EVERYONE_PERMS)
default_perms = everyone_patches.get("permissions") or DEFAULT_EVERYONE_PERMS
# from the 2nd and forward,
# should be treated as new roles
for role in roles[1:]:
await create_role(
guild_id, role['name'], default_perms=default_perms, **role
)
await create_role(guild_id, role["name"], default_perms=default_perms, **role)
async def guild_create_channels_prep(guild_id: int, channels: list):
"""Create channels pre-guild create"""
for channel_raw in channels:
channel_id = get_snowflake()
ctype = ChannelType(channel_raw['type'])
ctype = ChannelType(channel_raw["type"])
await create_guild_channel(guild_id, channel_id, ctype)
@ -114,37 +127,29 @@ def sanitize_icon(icon: Optional[str]) -> Optional[str]:
Defaults to a jpeg icon when the header isn't given.
"""
if icon and icon.startswith('data'):
if icon and icon.startswith("data"):
return icon
return (f'data:image/jpeg;base64,{icon}'
if icon
else None)
return f"data:image/jpeg;base64,{icon}" if icon else None
async def _general_guild_icon(scope: str, guild_id: int,
icon: str, **kwargs):
async def _general_guild_icon(scope: str, guild_id: int, icon: str, **kwargs):
encoded = sanitize_icon(icon)
icon_kwargs = {
'always_icon': True
}
icon_kwargs = {"always_icon": True}
if 'size' in kwargs:
icon_kwargs['size'] = kwargs['size']
if "size" in kwargs:
icon_kwargs["size"] = kwargs["size"]
return await app.icons.put(
scope, guild_id, encoded,
**icon_kwargs
)
return await app.icons.put(scope, guild_id, encoded, **icon_kwargs)
async def put_guild_icon(guild_id: int, icon: Optional[str]):
"""Insert a guild icon on the icon database."""
return await _general_guild_icon('guild', guild_id, icon, size=(128, 128))
return await _general_guild_icon("guild", guild_id, icon, size=(128, 128))
@bp.route('', methods=['POST'])
@bp.route("", methods=["POST"])
async def create_guild():
"""Create a new guild, assigning
the user creating it as the owner and
@ -154,8 +159,8 @@ async def create_guild():
guild_id = get_snowflake()
if 'icon' in j:
image = await put_guild_icon(guild_id, j['icon'])
if "icon" in j:
image = await put_guild_icon(guild_id, j["icon"])
image = image.icon_hash
else:
image = None
@ -166,10 +171,16 @@ async def create_guild():
verification_level, default_message_notifications,
explicit_content_filter)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
""", guild_id, j['name'], j['region'], image, user_id,
j.get('verification_level', 0),
j.get('default_message_notifications', 0),
j.get('explicit_content_filter', 0))
""",
guild_id,
j["name"],
j["region"],
image,
user_id,
j.get("verification_level", 0),
j.get("default_message_notifications", 0),
j.get("explicit_content_filter", 0),
)
await add_member(guild_id, user_id)
@ -179,107 +190,127 @@ async def create_guild():
# we also don't use create_role because the id of the role
# is the same as the id of the guild, and create_role
# generates a new snowflake.
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO roles (id, guild_id, name, position, permissions)
VALUES ($1, $2, $3, $4, $5)
""", guild_id, guild_id, '@everyone', 0, DEFAULT_EVERYONE_PERMS)
""",
guild_id,
guild_id,
"@everyone",
0,
DEFAULT_EVERYONE_PERMS,
)
# add the @everyone role to the guild creator
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO member_roles (user_id, guild_id, role_id)
VALUES ($1, $2, $3)
""", user_id, guild_id, guild_id)
""",
user_id,
guild_id,
guild_id,
)
# create a single #general channel.
general_id = get_snowflake()
await create_guild_channel(
guild_id, general_id, ChannelType.GUILD_TEXT,
name='general')
guild_id, general_id, ChannelType.GUILD_TEXT, name="general"
)
if j.get('roles'):
await guild_create_roles_prep(guild_id, j['roles'])
if j.get("roles"):
await guild_create_roles_prep(guild_id, j["roles"])
if j.get('channels'):
await guild_create_channels_prep(guild_id, j['channels'])
if j.get("channels"):
await guild_create_channels_prep(guild_id, j["channels"])
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
await app.dispatcher.sub('guild', guild_id, user_id)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_CREATE', guild_total)
await app.dispatcher.sub("guild", guild_id, user_id)
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild_total)
return jsonify(guild_total)
@bp.route('/<int:guild_id>', methods=['GET'])
@bp.route("/<int:guild_id>", methods=["GET"])
async def get_guild(guild_id):
"""Get a single guilds' information."""
user_id = await token_check()
await guild_check(user_id, guild_id)
return jsonify(
await app.storage.get_guild_full(guild_id, user_id, 250)
)
return jsonify(await app.storage.get_guild_full(guild_id, user_id, 250))
async def _guild_update_icon(scope: str, guild_id: int,
icon: Optional[str], **kwargs):
async def _guild_update_icon(scope: str, guild_id: int, icon: Optional[str], **kwargs):
"""Update icon."""
new_icon = await app.icons.update(
scope, guild_id, icon, always_icon=True, **kwargs
)
new_icon = await app.icons.update(scope, guild_id, icon, always_icon=True, **kwargs)
table = {
'guild': 'icon',
}.get(scope, scope)
table = {"guild": "icon"}.get(scope, scope)
await app.db.execute(f"""
await app.db.execute(
f"""
UPDATE guilds
SET {table} = $1
WHERE id = $2
""", new_icon.icon_hash, guild_id)
""",
new_icon.icon_hash,
guild_id,
)
async def _guild_update_region(guild_id, region):
is_vip = region.vip
can_vip = await app.storage.has_feature(guild_id, 'VIP_REGIONS')
can_vip = await app.storage.has_feature(guild_id, "VIP_REGIONS")
if is_vip and not can_vip:
raise BadRequest('can not assign guild to vip-only region')
raise BadRequest("can not assign guild to vip-only region")
await app.db.execute("""
await app.db.execute(
"""
UPDATE guilds
SET region = $1
WHERE id = $2
""", region.id, guild_id)
""",
region.id,
guild_id,
)
@bp.route('/<int:guild_id>', methods=['PATCH'])
@bp.route("/<int:guild_id>", methods=["PATCH"])
async def _update_guild(guild_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_guild')
await guild_perm_check(user_id, guild_id, "manage_guild")
j = validate(await request.get_json(), GUILD_UPDATE)
if 'owner_id' in j:
if "owner_id" in j:
await guild_owner_check(user_id, guild_id)
await app.db.execute("""
await app.db.execute(
"""
UPDATE guilds
SET owner_id = $1
WHERE id = $2
""", int(j['owner_id']), guild_id)
""",
int(j["owner_id"]),
guild_id,
)
if 'name' in j:
await app.db.execute("""
if "name" in j:
await app.db.execute(
"""
UPDATE guilds
SET name = $1
WHERE id = $2
""", j['name'], guild_id)
""",
j["name"],
guild_id,
)
if 'region' in j:
region = app.voice.lvsp.region(j['region'])
if "region" in j:
region = app.voice.lvsp.region(j["region"])
if region is not None:
await _guild_update_region(guild_id, region)
@ -287,65 +318,77 @@ async def _update_guild(guild_id):
# small guild to work with to_update()
guild = await app.storage.get_guild(guild_id)
if to_update(j, guild, 'icon'):
await _guild_update_icon(
'guild', guild_id, j['icon'], size=(128, 128))
if to_update(j, guild, "icon"):
await _guild_update_icon("guild", guild_id, j["icon"], size=(128, 128))
if to_update(j, guild, 'splash'):
if not await app.storage.has_feature(guild_id, 'INVITE_SPLASH'):
raise BadRequest('guild does not have INVITE_SPLASH feature')
if to_update(j, guild, "splash"):
if not await app.storage.has_feature(guild_id, "INVITE_SPLASH"):
raise BadRequest("guild does not have INVITE_SPLASH feature")
await _guild_update_icon('splash', guild_id, j['splash'])
await _guild_update_icon("splash", guild_id, j["splash"])
if to_update(j, guild, 'banner'):
if not await app.storage.has_feature(guild_id, 'VERIFIED'):
raise BadRequest('guild is not verified')
if to_update(j, guild, "banner"):
if not await app.storage.has_feature(guild_id, "VERIFIED"):
raise BadRequest("guild is not verified")
await _guild_update_icon('banner', guild_id, j['banner'])
await _guild_update_icon("banner", guild_id, j["banner"])
fields = ['verification_level', 'default_message_notifications',
'explicit_content_filter', 'afk_timeout', 'description']
fields = [
"verification_level",
"default_message_notifications",
"explicit_content_filter",
"afk_timeout",
"description",
]
for field in [f for f in fields if f in j]:
await app.db.execute(f"""
await app.db.execute(
f"""
UPDATE guilds
SET {field} = $1
WHERE id = $2
""", j[field], guild_id)
""",
j[field],
guild_id,
)
channel_fields = ['afk_channel_id', 'system_channel_id']
channel_fields = ["afk_channel_id", "system_channel_id"]
for field in [f for f in channel_fields if f in j]:
# setting to null should remove the link between the afk/sys channel
# to the guild.
if j[field] is None:
await app.db.execute(f"""
await app.db.execute(
f"""
UPDATE guilds
SET {field} = NULL
WHERE id = $1
""", guild_id)
""",
guild_id,
)
continue
chan = await app.storage.get_channel(int(j[field]))
if chan is None:
raise BadRequest('invalid channel id')
raise BadRequest("invalid channel id")
if chan['guild_id'] != str(guild_id):
raise BadRequest('channel id not linked to guild')
if chan["guild_id"] != str(guild_id):
raise BadRequest("channel id not linked to guild")
await app.db.execute(f"""
await app.db.execute(
f"""
UPDATE guilds
SET {field} = $1
WHERE id = $2
""", j[field], guild_id)
guild = await app.storage.get_guild_full(
guild_id, user_id
""",
j[field],
guild_id,
)
await app.dispatcher.dispatch_guild(
guild_id, 'GUILD_UPDATE', guild)
guild = await app.storage.get_guild_full(guild_id, user_id)
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
return jsonify(guild)
@ -354,33 +397,41 @@ async def delete_guild(guild_id: int, *, app_=None):
"""Delete a single guild."""
app_ = app_ or app
await app_.db.execute("""
await app_.db.execute(
"""
DELETE FROM guilds
WHERE guilds.id = $1
""", guild_id)
""",
guild_id,
)
# Discord's client expects IDs being string
await app_.dispatcher.dispatch('guild', guild_id, 'GUILD_DELETE', {
'guild_id': str(guild_id),
'id': str(guild_id),
await app_.dispatcher.dispatch(
"guild",
guild_id,
"GUILD_DELETE",
{
"guild_id": str(guild_id),
"id": str(guild_id),
# 'unavailable': False,
})
},
)
# remove from the dispatcher so nobody
# becomes the little memer that tries to fuck up with
# everybody's gateway
await app_.dispatcher.remove('guild', guild_id)
await app_.dispatcher.remove("guild", guild_id)
@bp.route('/<int:guild_id>', methods=['DELETE'])
@bp.route("/<int:guild_id>", methods=["DELETE"])
# this endpoint is not documented, but used by the official client.
@bp.route('/<int:guild_id>/delete', methods=['POST'])
@bp.route("/<int:guild_id>/delete", methods=["POST"])
async def delete_guild_handler(guild_id):
"""Delete a guild."""
user_id = await token_check()
await guild_owner_check(user_id, guild_id)
await delete_guild(guild_id)
return '', 204
return "", 204
async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]:
@ -397,7 +448,7 @@ async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]:
return res
@bp.route('/<int:guild_id>/messages/search', methods=['GET'])
@bp.route("/<int:guild_id>/messages/search", methods=["GET"])
async def search_messages(guild_id):
"""Search messages in a guild.
@ -415,7 +466,8 @@ async def search_messages(guild_id):
# use that list on the main search query.
can_read = await fetch_readable_channels(guild_id, user_id)
rows = await app.db.fetch(f"""
rows = await app.db.fetch(
f"""
SELECT orig.id AS current_id,
COUNT(*) OVER() as total_results,
array((SELECT messages.id AS before_id
@ -432,12 +484,17 @@ async def search_messages(guild_id):
ORDER BY orig.id DESC
LIMIT 50
OFFSET $3
""", guild_id, j['content'], j['offset'], can_read)
""",
guild_id,
j["content"],
j["offset"],
can_read,
)
return jsonify(await search_result_from_list(rows))
@bp.route('/<int:guild_id>/ack', methods=['POST'])
@bp.route("/<int:guild_id>/ack", methods=["POST"])
async def ack_guild(guild_id):
"""ACKnowledge all messages in the guild."""
user_id = await token_check()
@ -448,45 +505,43 @@ async def ack_guild(guild_id):
for chan_id in chan_ids:
await channel_ack(user_id, guild_id, chan_id)
return '', 204
return "", 204
@bp.route('/<int:guild_id>/vanity-url', methods=['GET'])
@bp.route("/<int:guild_id>/vanity-url", methods=["GET"])
async def get_vanity_url(guild_id: int):
"""Get the vanity url of a guild."""
user_id = await token_check()
await guild_perm_check(user_id, guild_id, 'manage_guild')
await guild_perm_check(user_id, guild_id, "manage_guild")
inv_code = await app.storage.vanity_invite(guild_id)
if inv_code is None:
return jsonify({'code': None})
return jsonify({"code": None})
return jsonify(
await app.storage.get_invite(inv_code)
)
return jsonify(await app.storage.get_invite(inv_code))
@bp.route('/<int:guild_id>/vanity-url', methods=['PATCH'])
@bp.route("/<int:guild_id>/vanity-url", methods=["PATCH"])
async def change_vanity_url(guild_id: int):
"""Get the vanity url of a guild."""
user_id = await token_check()
if not await app.storage.has_feature(guild_id, 'VANITY_URL'):
if not await app.storage.has_feature(guild_id, "VANITY_URL"):
# TODO: is this the right error
raise BadRequest('guild has no vanity url support')
raise BadRequest("guild has no vanity url support")
await guild_perm_check(user_id, guild_id, 'manage_guild')
await guild_perm_check(user_id, guild_id, "manage_guild")
j = validate(await request.get_json(), VANITY_URL_PATCH)
inv_code = j['code']
inv_code = j["code"]
# store old vanity in a variable to delete it from
# invites table
old_vanity = await app.storage.vanity_invite(guild_id)
if old_vanity == inv_code:
raise BadRequest('can not change to same invite')
raise BadRequest("can not change to same invite")
# this is sad because we don't really use the things
# sql gives us, but i havent really found a way to put
@ -494,19 +549,22 @@ async def change_vanity_url(guild_id: int):
# guild_id_fkey fails but INSERT when code_fkey fails..
inv = await app.storage.get_invite(inv_code)
if inv:
raise BadRequest('invite already exists')
raise BadRequest("invite already exists")
# TODO: this is bad, what if a guild has no channels?
# we should probably choose the first channel that has
# @everyone read messages
channels = await app.storage.get_channel_data(guild_id)
channel_id = int(channels[0]['id'])
channel_id = int(channels[0]["id"])
# delete the old invite, insert new one
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM invites
WHERE code = $1
""", old_vanity)
""",
old_vanity,
)
await app.db.execute(
"""
@ -515,21 +573,27 @@ async def change_vanity_url(guild_id: int):
max_age, temporary)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""",
inv_code, guild_id, channel_id, user_id,
inv_code,
guild_id,
channel_id,
user_id,
# sane defaults for vanity urls.
0, 0, False,
0,
0,
False,
)
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO vanity_invites (guild_id, code)
VALUES ($1, $2)
ON CONFLICT ON CONSTRAINT vanity_invites_pkey DO
UPDATE
SET code = $2
WHERE vanity_invites.guild_id = $1
""", guild_id, inv_code)
return jsonify(
await app.storage.get_invite(inv_code)
""",
guild_id,
inv_code,
)
return jsonify(await app.storage.get_invite(inv_code))

View File

@ -24,41 +24,39 @@ from quart import Blueprint, current_app as app, send_file, redirect
from litecord.embed.sanitizer import make_md_req_url
from litecord.embed.schemas import EmbedURL
bp = Blueprint('images', __name__)
bp = Blueprint("images", __name__)
async def send_icon(scope, key, icon_hash, **kwargs):
"""Send an icon."""
icon = await app.icons.generic_get(
scope, key, icon_hash, **kwargs)
icon = await app.icons.generic_get(scope, key, icon_hash, **kwargs)
if not icon:
return '', 404
return "", 404
return await send_file(icon.as_path)
def splitext_(filepath):
name, ext = splitext(filepath)
return name, ext.strip('.')
return name, ext.strip(".")
@bp.route('/emojis/<emoji_file>', methods=['GET'])
@bp.route("/emojis/<emoji_file>", methods=["GET"])
async def _get_raw_emoji(emoji_file):
# emoji = app.icons.get_emoji(emoji_id, ext=ext)
# just a test file for now
emoji_id, ext = splitext_(emoji_file)
return await send_icon(
'emoji', emoji_id, None, ext=ext)
return await send_icon("emoji", emoji_id, None, ext=ext)
@bp.route('/icons/<int:guild_id>/<icon_file>', methods=['GET'])
@bp.route("/icons/<int:guild_id>/<icon_file>", methods=["GET"])
async def _get_guild_icon(guild_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file)
return await send_icon('guild', guild_id, icon_hash, ext=ext)
return await send_icon("guild", guild_id, icon_hash, ext=ext)
@bp.route('/embed/avatars/<int:default_id>.png')
@bp.route("/embed/avatars/<int:default_id>.png")
async def _get_default_user_avatar(default_id: int):
# TODO: how do we determine which assets to use for this?
# I don't think we can use discord assets.
@ -66,25 +64,29 @@ async def _get_default_user_avatar(default_id: int):
async def _handle_webhook_avatar(md_url_redir: str):
md_url = make_md_req_url(app.config, 'img', EmbedURL(md_url_redir))
md_url = make_md_req_url(app.config, "img", EmbedURL(md_url_redir))
return redirect(md_url)
@bp.route('/avatars/<int:user_id>/<avatar_file>')
@bp.route("/avatars/<int:user_id>/<avatar_file>")
async def _get_user_avatar(user_id, avatar_file):
avatar_hash, ext = splitext_(avatar_file)
# first, check if this is a webhook avatar to redir to
md_url_redir = await app.db.fetchval("""
md_url_redir = await app.db.fetchval(
"""
SELECT md_url_redir
FROM webhook_avatars
WHERE webhook_id = $1 AND hash = $2
""", user_id, avatar_hash)
""",
user_id,
avatar_hash,
)
if md_url_redir:
return await _handle_webhook_avatar(md_url_redir)
return await send_icon('user', user_id, avatar_hash, ext=ext)
return await send_icon("user", user_id, avatar_hash, ext=ext)
# @bp.route('/app-icons/<int:application_id>/<icon_hash>.<ext>')
@ -92,19 +94,19 @@ async def get_app_icon(application_id, icon_hash, ext):
pass
@bp.route('/channel-icons/<int:channel_id>/<icon_file>', methods=['GET'])
@bp.route("/channel-icons/<int:channel_id>/<icon_file>", methods=["GET"])
async def _get_gdm_icon(channel_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file)
return await send_icon('channel-icons', channel_id, icon_hash, ext=ext)
return await send_icon("channel-icons", channel_id, icon_hash, ext=ext)
@bp.route('/splashes/<int:guild_id>/<icon_file>', methods=['GET'])
@bp.route("/splashes/<int:guild_id>/<icon_file>", methods=["GET"])
async def _get_guild_splash(guild_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file)
return await send_icon('splash', guild_id, icon_hash, ext=ext)
return await send_icon("splash", guild_id, icon_hash, ext=ext)
@bp.route('/banners/<int:guild_id>/<icon_file>', methods=['GET'])
@bp.route("/banners/<int:guild_id>/<icon_file>", methods=["GET"])
async def _get_guild_banner(guild_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file)
return await send_icon('banner', guild_id, icon_hash, ext=ext)
return await send_icon("banner", guild_id, icon_hash, ext=ext)

View File

@ -32,13 +32,16 @@ from .guilds import create_guild_settings
from ..utils import async_map
from litecord.blueprints.checks import (
channel_check, channel_perm_check, guild_check, guild_perm_check
channel_check,
channel_perm_check,
guild_check,
guild_perm_check,
)
from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient
log = Logger(__name__)
bp = Blueprint('invites', __name__)
bp = Blueprint("invites", __name__)
class UnknownInvite(BadRequest):
@ -48,16 +51,18 @@ class UnknownInvite(BadRequest):
class InvalidInvite(Forbidden):
error_code = 50020
class AlreadyInvited(BaseException):
pass
def gen_inv_code() -> str:
"""Generate an invite code.
This is a primitive and does not guarantee uniqueness.
"""
raw = secrets.token_urlsafe(10)
raw = re.sub(r'\/|\+|\-|\_', '', raw)
raw = re.sub(r"\/|\+|\-|\_", "", raw)
return raw[:7]
@ -65,23 +70,31 @@ def gen_inv_code() -> str:
async def invite_precheck(user_id: int, guild_id: int):
"""pre-check invite use in the context of a guild."""
joined = await app.db.fetchval("""
joined = await app.db.fetchval(
"""
SELECT joined_at
FROM members
WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id)
""",
user_id,
guild_id,
)
if joined is not None:
raise AlreadyInvited('You are already in the guild')
raise AlreadyInvited("You are already in the guild")
banned = await app.db.fetchval("""
banned = await app.db.fetchval(
"""
SELECT reason
FROM bans
WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id)
""",
user_id,
guild_id,
)
if banned is not None:
raise InvalidInvite('You are banned.')
raise InvalidInvite("You are banned.")
async def invite_precheck_gdm(user_id: int, channel_id: int):
@ -89,23 +102,23 @@ async def invite_precheck_gdm(user_id: int, channel_id: int):
is_member = await gdm_is_member(channel_id, user_id)
if is_member:
raise AlreadyInvited('You are already in the Group DM')
raise AlreadyInvited("You are already in the Group DM")
async def _inv_check_age(inv: dict):
if inv['max_age'] == 0:
if inv["max_age"] == 0:
return
now = datetime.datetime.utcnow()
delta_sec = (now - inv['created_at']).total_seconds()
delta_sec = (now - inv["created_at"]).total_seconds()
if delta_sec > inv['max_age']:
await delete_invite(inv['code'])
raise InvalidInvite('Invite is expired')
if delta_sec > inv["max_age"]:
await delete_invite(inv["code"])
raise InvalidInvite("Invite is expired")
if inv['max_uses'] is not -1 and inv['uses'] > inv['max_uses']:
await delete_invite(inv['code'])
raise InvalidInvite('Too many uses')
if inv["max_uses"] is not -1 and inv["uses"] > inv["max_uses"]:
await delete_invite(inv["code"])
raise InvalidInvite("Too many uses")
async def _guild_add_member(guild_id: int, user_id: int):
@ -119,78 +132,89 @@ async def _guild_add_member(guild_id: int, user_id: int):
"""
# TODO: system message for member join
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO members (user_id, guild_id)
VALUES ($1, $2)
""", user_id, guild_id)
""",
user_id,
guild_id,
)
await create_guild_settings(guild_id, user_id)
# add the @everyone role to the invited member
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO member_roles (user_id, guild_id, role_id)
VALUES ($1, $2, $3)
""", user_id, guild_id, guild_id)
""",
user_id,
guild_id,
guild_id,
)
# tell current members a new member came up
member = await app.storage.get_member_data_one(guild_id, user_id)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_ADD', {
**member,
**{
'guild_id': str(guild_id),
},
})
await app.dispatcher.dispatch_guild(
guild_id, "GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}}
)
# update member lists for the new member
await app.dispatcher.dispatch(
'lazy_guild', guild_id, 'new_member', user_id)
await app.dispatcher.dispatch("lazy_guild", guild_id, "new_member", user_id)
# subscribe new member to guild, so they get events n stuff
await app.dispatcher.sub('guild', guild_id, user_id)
await app.dispatcher.sub("guild", guild_id, user_id)
# tell the new member that theres the guild it just joined.
# we use dispatch_user_guild so that we send the GUILD_CREATE
# just to the shards that are actually tied to it.
guild = await app.storage.get_guild_full(guild_id, user_id, 250)
await app.dispatcher.dispatch_user_guild(
user_id, guild_id, 'GUILD_CREATE', guild)
await app.dispatcher.dispatch_user_guild(user_id, guild_id, "GUILD_CREATE", guild)
async def use_invite(user_id, invite_code):
"""Try using an invite"""
inv = await app.db.fetchrow("""
inv = await app.db.fetchrow(
"""
SELECT code, channel_id, guild_id, created_at,
max_age, uses, max_uses
FROM invites
WHERE code = $1
""", invite_code)
""",
invite_code,
)
if inv is None:
raise UnknownInvite('Unknown invite')
raise UnknownInvite("Unknown invite")
await _inv_check_age(inv)
# NOTE: if group dm invite, guild_id is null.
guild_id = inv['guild_id']
guild_id = inv["guild_id"]
try:
if guild_id is None:
channel_id = inv['channel_id']
await invite_precheck_gdm(user_id, inv['channel_id'])
channel_id = inv["channel_id"]
await invite_precheck_gdm(user_id, inv["channel_id"])
await gdm_add_recipient(channel_id, user_id)
else:
await invite_precheck(user_id, guild_id)
await _guild_add_member(guild_id, user_id)
await app.db.execute("""
await app.db.execute(
"""
UPDATE invites
SET uses = uses + 1
WHERE code = $1
""", invite_code)
""",
invite_code,
)
except AlreadyInvited:
pass
@bp.route('/channels/<int:channel_id>/invites', methods=['POST'])
@bp.route("/channels/<int:channel_id>/invites", methods=["POST"])
async def create_invite(channel_id):
"""Create an invite to a channel."""
user_id = await token_check()
@ -201,12 +225,14 @@ async def create_invite(channel_id):
# NOTE: this works on group dms, since it returns ALL_PERMISSIONS on
# non-guild channels.
await channel_perm_check(user_id, channel_id, 'create_invites')
await channel_perm_check(user_id, channel_id, "create_invites")
if chantype not in (ChannelType.GUILD_TEXT,
if chantype not in (
ChannelType.GUILD_TEXT,
ChannelType.GUILD_VOICE,
ChannelType.GROUP_DM):
raise BadRequest('Invalid channel type')
ChannelType.GROUP_DM,
):
raise BadRequest("Invalid channel type")
invite_code = gen_inv_code()
@ -222,101 +248,122 @@ async def create_invite(channel_id):
max_age, temporary)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""",
invite_code, guild_id, channel_id, user_id,
j['max_uses'], j['max_age'], j['temporary']
invite_code,
guild_id,
channel_id,
user_id,
j["max_uses"],
j["max_age"],
j["temporary"],
)
invite = await app.storage.get_invite(invite_code)
return jsonify(invite)
@bp.route('/invite/<invite_code>', methods=['GET'])
@bp.route('/invites/<invite_code>', methods=['GET'])
@bp.route("/invite/<invite_code>", methods=["GET"])
@bp.route("/invites/<invite_code>", methods=["GET"])
async def get_invite(invite_code: str):
inv = await app.storage.get_invite(invite_code)
if not inv:
return '', 404
return "", 404
if request.args.get('with_counts'):
if request.args.get("with_counts"):
extra = await app.storage.get_invite_extra(invite_code)
inv.update(extra)
return jsonify(inv)
async def delete_invite(invite_code: str):
"""Delete an invite."""
await app.db.fetchval("""
await app.db.fetchval(
"""
DELETE FROM invites
WHERE code = $1
""", invite_code)
""",
invite_code,
)
@bp.route('/invite/<invite_code>', methods=['DELETE'])
@bp.route('/invites/<invite_code>', methods=['DELETE'])
@bp.route("/invite/<invite_code>", methods=["DELETE"])
@bp.route("/invites/<invite_code>", methods=["DELETE"])
async def _delete_invite(invite_code: str):
user_id = await token_check()
guild_id = await app.db.fetchval("""
guild_id = await app.db.fetchval(
"""
SELECT guild_id
FROM invites
WHERE code = $1
""", invite_code)
""",
invite_code,
)
if guild_id is None:
raise BadRequest('Unknown invite')
raise BadRequest("Unknown invite")
await guild_perm_check(user_id, guild_id, 'manage_channels')
await guild_perm_check(user_id, guild_id, "manage_channels")
inv = await app.storage.get_invite(invite_code)
await delete_invite(invite_code)
return jsonify(inv)
async def _get_inv(code):
inv = await app.storage.get_invite(code)
meta = await app.storage.get_invite_metadata(code)
return {**inv, **meta}
@bp.route('/guilds/<int:guild_id>/invites', methods=['GET'])
@bp.route("/guilds/<int:guild_id>/invites", methods=["GET"])
async def get_guild_invites(guild_id: int):
"""Get all invites for a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_guild')
await guild_perm_check(user_id, guild_id, "manage_guild")
inv_codes = await app.db.fetch("""
inv_codes = await app.db.fetch(
"""
SELECT code
FROM invites
WHERE guild_id = $1
""", guild_id)
""",
guild_id,
)
inv_codes = [r['code'] for r in inv_codes]
inv_codes = [r["code"] for r in inv_codes]
invs = await async_map(_get_inv, inv_codes)
return jsonify(invs)
@bp.route('/channels/<int:channel_id>/invites', methods=['GET'])
@bp.route("/channels/<int:channel_id>/invites", methods=["GET"])
async def get_channel_invites(channel_id: int):
"""Get all invites for a channel."""
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
await guild_perm_check(user_id, guild_id, 'manage_channels')
await guild_perm_check(user_id, guild_id, "manage_channels")
inv_codes = await app.db.fetch("""
inv_codes = await app.db.fetch(
"""
SELECT code
FROM invites
WHERE guild_id = $1 AND channel_id = $2
""", guild_id, channel_id)
""",
guild_id,
channel_id,
)
inv_codes = [r['code'] for r in inv_codes]
inv_codes = [r["code"] for r in inv_codes]
invs = await async_map(_get_inv, inv_codes)
return jsonify(invs)
@bp.route('/invite/<invite_code>', methods=['POST'])
@bp.route('/invites/<invite_code>', methods=['POST'])
@bp.route("/invite/<invite_code>", methods=["POST"])
@bp.route("/invites/<invite_code>", methods=["POST"])
async def _use_invite(invite_code):
"""Use an invite."""
user_id = await token_check()
@ -327,9 +374,4 @@ async def _use_invite(invite_code):
inv = await app.storage.get_invite(invite_code)
inv_meta = await app.storage.get_invite_metadata(invite_code)
return jsonify({
**inv,
**{
'inviter': inv_meta['inviter']
}
})
return jsonify({**inv, **{"inviter": inv_meta["inviter"]}})

View File

@ -19,83 +19,75 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from quart import Blueprint, current_app as app, jsonify, request
bp = Blueprint('nodeinfo', __name__)
bp = Blueprint("nodeinfo", __name__)
@bp.route('/.well-known/nodeinfo')
@bp.route("/.well-known/nodeinfo")
async def _dummy_nodeinfo_index():
proto = 'http' if not app.config['IS_SSL'] else 'https'
main_url = app.config.get('MAIN_URL', request.host)
proto = "http" if not app.config["IS_SSL"] else "https"
main_url = app.config.get("MAIN_URL", request.host)
return jsonify({
'links': [{
'href': f'{proto}://{main_url}/nodeinfo/2.0.json',
'rel': 'http://nodeinfo.diaspora.software/ns/schema/2.0'
}, {
'href': f'{proto}://{main_url}/nodeinfo/2.1.json',
'rel': 'http://nodeinfo.diaspora.software/ns/schema/2.1'
}]
})
return jsonify(
{
"links": [
{
"href": f"{proto}://{main_url}/nodeinfo/2.0.json",
"rel": "http://nodeinfo.diaspora.software/ns/schema/2.0",
},
{
"href": f"{proto}://{main_url}/nodeinfo/2.1.json",
"rel": "http://nodeinfo.diaspora.software/ns/schema/2.1",
},
]
}
)
async def fetch_nodeinfo_20():
usercount = await app.db.fetchval("""
usercount = await app.db.fetchval(
"""
SELECT COUNT(*)
FROM users
""")
"""
)
message_count = await app.db.fetchval("""
message_count = await app.db.fetchval(
"""
SELECT COUNT(*)
FROM messages
""")
"""
)
return {
'metadata': {
'features': [
'discord_api'
],
'nodeDescription': 'A Litecord instance',
'nodeName': 'Litecord/Nya',
'private': False,
'federation': {}
"metadata": {
"features": ["discord_api"],
"nodeDescription": "A Litecord instance",
"nodeName": "Litecord/Nya",
"private": False,
"federation": {},
},
'openRegistrations': app.config['REGISTRATIONS'],
'protocols': [],
'software': {
'name': 'litecord',
'version': 'litecord v0',
},
'services': {
'inbound': [],
'outbound': [],
},
'usage': {
'localPosts': message_count,
'users': {
'total': usercount
}
},
'version': '2.0',
"openRegistrations": app.config["REGISTRATIONS"],
"protocols": [],
"software": {"name": "litecord", "version": "litecord v0"},
"services": {"inbound": [], "outbound": []},
"usage": {"localPosts": message_count, "users": {"total": usercount}},
"version": "2.0",
}
@bp.route('/nodeinfo/2.0.json')
@bp.route("/nodeinfo/2.0.json")
async def _nodeinfo_20():
"""Handler for nodeinfo 2.0."""
raw_nodeinfo = await fetch_nodeinfo_20()
return jsonify(raw_nodeinfo)
@bp.route('/nodeinfo/2.1.json')
@bp.route("/nodeinfo/2.1.json")
async def _nodeinfo_21():
"""Handler for nodeinfo 2.1."""
raw_nodeinfo = await fetch_nodeinfo_20()
raw_nodeinfo['software']['repository'] = 'https://gitlab.com/litecord/litecord'
raw_nodeinfo['version'] = '2.1'
raw_nodeinfo["software"]["repository"] = "https://gitlab.com/litecord/litecord"
raw_nodeinfo["version"] = "2.1"
return jsonify(raw_nodeinfo)

View File

@ -26,76 +26,89 @@ from ..enums import RelationshipType
from litecord.errors import BadRequest
bp = Blueprint('relationship', __name__)
bp = Blueprint("relationship", __name__)
@bp.route('/@me/relationships', methods=['GET'])
@bp.route("/@me/relationships", methods=["GET"])
async def get_me_relationships():
user_id = await token_check()
return jsonify(
await app.user_storage.get_relationships(user_id))
return jsonify(await app.user_storage.get_relationships(user_id))
async def _dispatch_single_pres(user_id, presence: dict):
await app.dispatcher.dispatch(
'user', user_id, 'PRESENCE_UPDATE', presence
)
await app.dispatcher.dispatch("user", user_id, "PRESENCE_UPDATE", presence)
async def _unsub_friend(user_id, peer_id):
await app.dispatcher.unsub('friend', user_id, peer_id)
await app.dispatcher.unsub('friend', peer_id, user_id)
await app.dispatcher.unsub("friend", user_id, peer_id)
await app.dispatcher.unsub("friend", peer_id, user_id)
async def _sub_friend(user_id, peer_id):
await app.dispatcher.sub('friend', user_id, peer_id)
await app.dispatcher.sub('friend', peer_id, user_id)
await app.dispatcher.sub("friend", user_id, peer_id)
await app.dispatcher.sub("friend", peer_id, user_id)
# dispatch presence update to the user and peer about
# eachother's presence.
user_pres, peer_pres = await app.presence.friend_presences(
[user_id, peer_id]
)
user_pres, peer_pres = await app.presence.friend_presences([user_id, peer_id])
await _dispatch_single_pres(user_id, peer_pres)
await _dispatch_single_pres(peer_id, user_pres)
async def make_friend(user_id: int, peer_id: int,
rel_type=RelationshipType.FRIEND.value):
async def make_friend(
user_id: int, peer_id: int, rel_type=RelationshipType.FRIEND.value
):
_friend = RelationshipType.FRIEND.value
_block = RelationshipType.BLOCK.value
try:
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO relationships (user_id, peer_id, rel_type)
VALUES ($1, $2, $3)
""", user_id, peer_id, rel_type)
""",
user_id,
peer_id,
rel_type,
)
except UniqueViolationError:
# try to update rel_type
old_rel_type = await app.db.fetchval("""
old_rel_type = await app.db.fetchval(
"""
SELECT rel_type
FROM relationships
WHERE user_id = $1 AND peer_id = $2
""", user_id, peer_id)
""",
user_id,
peer_id,
)
if old_rel_type == _friend and rel_type == _block:
await app.db.execute("""
await app.db.execute(
"""
UPDATE relationships
SET rel_type = $1
WHERE user_id = $2 AND peer_id = $3
""", rel_type, user_id, peer_id)
""",
rel_type,
user_id,
peer_id,
)
# remove any existing friendship before the block
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM relationships
WHERE peer_id = $1 AND user_id = $2 AND rel_type = $3
""", peer_id, user_id, _friend)
""",
peer_id,
user_id,
_friend,
)
await app.dispatcher.dispatch_user(
peer_id, 'RELATIONSHIP_REMOVE', {
'type': _friend,
'id': str(user_id)
}
peer_id, "RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)}
)
await _unsub_friend(user_id, peer_id)
@ -106,95 +119,118 @@ async def make_friend(user_id: int, peer_id: int,
# check if this is an acceptance
# of a friend request
existing = await app.db.fetchrow("""
existing = await app.db.fetchrow(
"""
SELECT user_id, peer_id
FROM relationships
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
""", peer_id, user_id, _friend)
""",
peer_id,
user_id,
_friend,
)
_dispatch = app.dispatcher.dispatch_user
if existing:
# accepted a friend request, dispatch respective
# relationship events
await _dispatch(user_id, 'RELATIONSHIP_REMOVE', {
'type': RelationshipType.INCOMING.value,
'id': str(peer_id)
})
await _dispatch(
user_id,
"RELATIONSHIP_REMOVE",
{"type": RelationshipType.INCOMING.value, "id": str(peer_id)},
)
await _dispatch(user_id, 'RELATIONSHIP_ADD', {
'type': _friend,
'id': str(peer_id),
'user': await app.storage.get_user(peer_id)
})
await _dispatch(
user_id,
"RELATIONSHIP_ADD",
{
"type": _friend,
"id": str(peer_id),
"user": await app.storage.get_user(peer_id),
},
)
await _dispatch(peer_id, 'RELATIONSHIP_ADD', {
'type': _friend,
'id': str(user_id),
'user': await app.storage.get_user(user_id)
})
await _dispatch(
peer_id,
"RELATIONSHIP_ADD",
{
"type": _friend,
"id": str(user_id),
"user": await app.storage.get_user(user_id),
},
)
await _sub_friend(user_id, peer_id)
return '', 204
return "", 204
# check if friend AND not acceptance of fr
if rel_type == _friend:
await _dispatch(user_id, 'RELATIONSHIP_ADD', {
'id': str(peer_id),
'type': RelationshipType.OUTGOING.value,
'user': await app.storage.get_user(peer_id),
})
await _dispatch(
user_id,
"RELATIONSHIP_ADD",
{
"id": str(peer_id),
"type": RelationshipType.OUTGOING.value,
"user": await app.storage.get_user(peer_id),
},
)
await _dispatch(peer_id, 'RELATIONSHIP_ADD', {
'id': str(user_id),
'type': RelationshipType.INCOMING.value,
'user': await app.storage.get_user(user_id)
})
await _dispatch(
peer_id,
"RELATIONSHIP_ADD",
{
"id": str(user_id),
"type": RelationshipType.INCOMING.value,
"user": await app.storage.get_user(user_id),
},
)
# we don't make the pubsub link
# until the peer accepts the friend request
return '', 204
return "", 204
return
class RelationshipFailed(BadRequest):
"""Exception for general relationship errors."""
error_code = 80004
class RelationshipBlocked(BadRequest):
"""Exception for when the peer has blocked the user."""
error_code = 80001
@bp.route('/@me/relationships', methods=['POST'])
@bp.route("/@me/relationships", methods=["POST"])
async def post_relationship():
user_id = await token_check()
j = validate(await request.get_json(), SPECIFIC_FRIEND)
uid = await app.storage.search_user(j['username'],
str(j['discriminator']))
uid = await app.storage.search_user(j["username"], str(j["discriminator"]))
if not uid:
raise RelationshipFailed('No users with DiscordTag exist')
raise RelationshipFailed("No users with DiscordTag exist")
res = await make_friend(user_id, uid)
if res is None:
raise RelationshipBlocked('Can not friend user due to block')
raise RelationshipBlocked("Can not friend user due to block")
return '', 204
return "", 204
@bp.route('/@me/relationships/<int:peer_id>', methods=['PUT'])
@bp.route("/@me/relationships/<int:peer_id>", methods=["PUT"])
async def add_relationship(peer_id: int):
"""Add a relationship to the peer."""
user_id = await token_check()
payload = validate(await request.get_json(), RELATIONSHIP)
rel_type = payload['type']
rel_type = payload["type"]
res = await make_friend(user_id, peer_id, rel_type)
@ -204,18 +240,22 @@ async def add_relationship(peer_id: int):
# make_friend did not succeed, so we
# assume it is a block and dispatch
# the respective RELATIONSHIP_ADD.
await app.dispatcher.dispatch_user(user_id, 'RELATIONSHIP_ADD', {
'id': str(peer_id),
'type': RelationshipType.BLOCK.value,
'user': await app.storage.get_user(peer_id)
})
await app.dispatcher.dispatch_user(
user_id,
"RELATIONSHIP_ADD",
{
"id": str(peer_id),
"type": RelationshipType.BLOCK.value,
"user": await app.storage.get_user(peer_id),
},
)
await _unsub_friend(user_id, peer_id)
return '', 204
return "", 204
@bp.route('/@me/relationships/<int:peer_id>', methods=['DELETE'])
@bp.route("/@me/relationships/<int:peer_id>", methods=["DELETE"])
async def remove_relationship(peer_id: int):
"""Remove an existing relationship"""
user_id = await token_check()
@ -223,69 +263,86 @@ async def remove_relationship(peer_id: int):
_block = RelationshipType.BLOCK.value
_dispatch = app.dispatcher.dispatch_user
rel_type = await app.db.fetchval("""
rel_type = await app.db.fetchval(
"""
SELECT rel_type
FROM relationships
WHERE user_id = $1 AND peer_id = $2
""", user_id, peer_id)
""",
user_id,
peer_id,
)
incoming_rel_type = await app.db.fetchval("""
incoming_rel_type = await app.db.fetchval(
"""
SELECT rel_type
FROM relationships
WHERE user_id = $1 AND peer_id = $2
""", peer_id, user_id)
""",
peer_id,
user_id,
)
# if any of those are friend
if _friend in (rel_type, incoming_rel_type):
# closing the friendship, have to delete both rows
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM relationships
WHERE (
(user_id = $1 AND peer_id = $2) OR
(user_id = $2 AND peer_id = $1)
) AND rel_type = $3
""", user_id, peer_id, _friend)
""",
user_id,
peer_id,
_friend,
)
# if there wasnt any mutual friendship before,
# assume they were requests of INCOMING
# and OUTGOING.
user_del_type = RelationshipType.OUTGOING.value if \
incoming_rel_type != _friend else _friend
user_del_type = (
RelationshipType.OUTGOING.value if incoming_rel_type != _friend else _friend
)
await _dispatch(user_id, 'RELATIONSHIP_REMOVE', {
'id': str(peer_id),
'type': user_del_type,
})
await _dispatch(
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}
)
peer_del_type = RelationshipType.INCOMING.value if \
incoming_rel_type != _friend else _friend
peer_del_type = (
RelationshipType.INCOMING.value if incoming_rel_type != _friend else _friend
)
await _dispatch(peer_id, 'RELATIONSHIP_REMOVE', {
'id': str(user_id),
'type': peer_del_type,
})
await _dispatch(
peer_id, "RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}
)
await _unsub_friend(user_id, peer_id)
return '', 204
return "", 204
# was a block!
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM relationships
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
""", user_id, peer_id, _block)
""",
user_id,
peer_id,
_block,
)
await _dispatch(user_id, 'RELATIONSHIP_REMOVE', {
'id': str(peer_id),
'type': _block,
})
await _dispatch(
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block}
)
await _unsub_friend(user_id, peer_id)
return '', 204
return "", 204
@bp.route('/<int:peer_id>/relationships', methods=['GET'])
@bp.route("/<int:peer_id>/relationships", methods=["GET"])
async def get_mutual_friends(peer_id: int):
"""Fetch a users' mutual friends with the current user."""
user_id = await token_check()
@ -294,17 +351,15 @@ async def get_mutual_friends(peer_id: int):
peer = await app.storage.get_user(peer_id)
if not peer:
return '', 204
return "", 204
# NOTE: maybe this could be better with pure SQL calculations
# but it would be beyond my current SQL knowledge, so...
user_rels = await app.user_storage.get_relationships(user_id)
peer_rels = await app.user_storage.get_relationships(peer_id)
user_friends = {rel['user']['id']
for rel in user_rels if rel['type'] == _friend}
peer_friends = {rel['user']['id']
for rel in peer_rels if rel['type'] == _friend}
user_friends = {rel["user"]["id"] for rel in user_rels if rel["type"] == _friend}
peer_friends = {rel["user"]["id"] for rel in peer_rels if rel["type"] == _friend}
# get the intersection, then map them to Storage.get_user() calls
mutual_ids = user_friends & peer_friends
@ -312,8 +367,6 @@ async def get_mutual_friends(peer_id: int):
mutual_friends = []
for friend_id in mutual_ids:
mutual_friends.append(
await app.storage.get_user(int(friend_id))
)
mutual_friends.append(await app.storage.get_user(int(friend_id)))
return jsonify(mutual_friends)

View File

@ -19,21 +19,19 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from quart import Blueprint, jsonify
bp = Blueprint('science', __name__)
bp = Blueprint("science", __name__)
@bp.route('/science', methods=['POST'])
@bp.route("/science", methods=["POST"])
async def science():
return '', 204
return "", 204
@bp.route('/applications', methods=['GET'])
@bp.route("/applications", methods=["GET"])
async def applications():
return jsonify([])
@bp.route('/experiments', methods=['GET'])
@bp.route("/experiments", methods=["GET"])
async def experiments():
return jsonify({
'assignments': []
})
return jsonify({"assignments": []})

View File

@ -20,23 +20,24 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from quart import Blueprint, current_app as app, render_template_string
from pathlib import Path
bp = Blueprint('static', __name__)
bp = Blueprint("static", __name__)
@bp.route('/<path:path>')
@bp.route("/<path:path>")
async def static_pages(path):
"""Map requests from / to /static."""
if '..' in path:
return 'no', 404
if ".." in path:
return "no", 404
static_path = Path.cwd() / Path('static') / path
static_path = Path.cwd() / Path("static") / path
return await app.send_static_file(str(static_path))
@bp.route('/')
@bp.route('/api')
@bp.route("/")
@bp.route("/api")
async def index_handler():
"""Handler for the index page."""
index_path = Path.cwd() / Path('static') / 'index.html'
index_path = Path.cwd() / Path("static") / "index.html"
return await render_template_string(
index_path.read_text(), inst_name=app.config['NAME'])
index_path.read_text(), inst_name=app.config["NAME"]
)

View File

@ -21,4 +21,4 @@ from .billing import bp as user_billing
from .settings import bp as user_settings
from .fake_store import bp as fake_store
__all__ = ['user_billing', 'user_settings', 'fake_store']
__all__ = ["user_billing", "user_settings", "fake_store"]

View File

@ -33,7 +33,7 @@ from litecord.enums import UserFlags, PremiumType
from litecord.blueprints.users import mass_user_update
log = Logger(__name__)
bp = Blueprint('users_billing', __name__)
bp = Blueprint("users_billing", __name__)
class PaymentSource(Enum):
@ -68,78 +68,87 @@ class PaymentStatus:
PLAN_ID_TO_TYPE = {
'premium_month_tier_1': PremiumType.TIER_1,
'premium_month_tier_2': PremiumType.TIER_2,
'premium_year_tier_1': PremiumType.TIER_1,
'premium_year_tier_2': PremiumType.TIER_2,
"premium_month_tier_1": PremiumType.TIER_1,
"premium_month_tier_2": PremiumType.TIER_2,
"premium_year_tier_1": PremiumType.TIER_1,
"premium_year_tier_2": PremiumType.TIER_2,
}
# how much should a payment be, depending
# of the subscription
AMOUNTS = {
'premium_month_tier_1': 499,
'premium_month_tier_2': 999,
'premium_year_tier_1': 4999,
'premium_year_tier_2': 9999,
"premium_month_tier_1": 499,
"premium_month_tier_2": 999,
"premium_year_tier_1": 4999,
"premium_year_tier_2": 9999,
}
CREATE_SUBSCRIPTION = {
'payment_gateway_plan_id': {'type': 'string'},
'payment_source_id': {'coerce': int}
"payment_gateway_plan_id": {"type": "string"},
"payment_source_id": {"coerce": int},
}
PAYMENT_SOURCE = {
'billing_address': {
'type': 'dict',
'schema': {
'country': {'type': 'string', 'required': True},
'city': {'type': 'string', 'required': True},
'name': {'type': 'string', 'required': True},
'line_1': {'type': 'string', 'required': False},
'line_2': {'type': 'string', 'required': False},
'postal_code': {'type': 'string', 'required': True},
'state': {'type': 'string', 'required': True},
}
"billing_address": {
"type": "dict",
"schema": {
"country": {"type": "string", "required": True},
"city": {"type": "string", "required": True},
"name": {"type": "string", "required": True},
"line_1": {"type": "string", "required": False},
"line_2": {"type": "string", "required": False},
"postal_code": {"type": "string", "required": True},
"state": {"type": "string", "required": True},
},
'payment_gateway': {'type': 'number', 'required': True},
'token': {'type': 'string', 'required': True},
},
"payment_gateway": {"type": "number", "required": True},
"token": {"type": "string", "required": True},
}
async def get_payment_source_ids(user_id: int) -> list:
rows = await app.db.fetch("""
rows = await app.db.fetch(
"""
SELECT id
FROM user_payment_sources
WHERE user_id = $1
""", user_id)
""",
user_id,
)
return [r['id'] for r in rows]
return [r["id"] for r in rows]
async def get_payment_ids(user_id: int, db=None) -> list:
if not db:
db = app.db
rows = await db.fetch("""
rows = await db.fetch(
"""
SELECT id
FROM user_payments
WHERE user_id = $1
""", user_id)
""",
user_id,
)
return [r['id'] for r in rows]
return [r["id"] for r in rows]
async def get_subscription_ids(user_id: int) -> list:
rows = await app.db.fetch("""
rows = await app.db.fetch(
"""
SELECT id
FROM user_subscriptions
WHERE user_id = $1
""", user_id)
""",
user_id,
)
return [r['id'] for r in rows]
return [r["id"] for r in rows]
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
@ -148,41 +157,44 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
if not db:
db = app.db
source_type = await db.fetchval("""
source_type = await db.fetchval(
"""
SELECT source_type
FROM user_payment_sources
WHERE id = $1 AND user_id = $2
""", source_id, user_id)
""",
source_id,
user_id,
)
source_type = PaymentSource(source_type)
specific_fields = {
PaymentSource.PAYPAL: ['paypal_email'],
PaymentSource.CREDIT: ['expires_month', 'expires_year',
'brand', 'cc_full']
PaymentSource.PAYPAL: ["paypal_email"],
PaymentSource.CREDIT: ["expires_month", "expires_year", "brand", "cc_full"],
}[source_type]
fields = ','.join(specific_fields)
fields = ",".join(specific_fields)
extras_row = await db.fetchrow(f"""
extras_row = await db.fetchrow(
f"""
SELECT {fields}, billing_address, default_, id::text
FROM user_payment_sources
WHERE id = $1
""", source_id)
""",
source_id,
)
derow = dict(extras_row)
if source_type == PaymentSource.CREDIT:
derow['last_4'] = derow['cc_full'][-4:]
derow.pop('cc_full')
derow["last_4"] = derow["cc_full"][-4:]
derow.pop("cc_full")
derow['default'] = derow['default_']
derow.pop('default_')
derow["default"] = derow["default_"]
derow.pop("default_")
source = {
'id': str(source_id),
'type': source_type.value,
}
source = {"id": str(source_id), "type": source_type.value}
return {**source, **derow}
@ -192,7 +204,8 @@ async def get_subscription(subscription_id: int, db=None):
if not db:
db = app.db
row = await db.fetchrow("""
row = await db.fetchrow(
"""
SELECT id::text, source_id::text AS payment_source_id,
user_id,
payment_gateway, payment_gateway_plan_id,
@ -201,14 +214,16 @@ async def get_subscription(subscription_id: int, db=None):
canceled_at, s_type, status
FROM user_subscriptions
WHERE id = $1
""", subscription_id)
""",
subscription_id,
)
drow = dict(row)
drow['type'] = drow['s_type']
drow.pop('s_type')
drow["type"] = drow["s_type"]
drow.pop("s_type")
to_tstamp = ['current_period_start', 'current_period_end', 'canceled_at']
to_tstamp = ["current_period_start", "current_period_end", "canceled_at"]
for field in to_tstamp:
drow[field] = timestamp_(drow[field])
@ -221,27 +236,30 @@ async def get_payment(payment_id: int, db=None):
if not db:
db = app.db
row = await db.fetchrow("""
row = await db.fetchrow(
"""
SELECT id::text, source_id, subscription_id, user_id,
amount, amount_refunded, currency,
description, status, tax, tax_inclusive
FROM user_payments
WHERE id = $1
""", payment_id)
""",
payment_id,
)
drow = dict(row)
drow.pop('source_id')
drow.pop('subscription_id')
drow.pop('user_id')
drow.pop("source_id")
drow.pop("subscription_id")
drow.pop("user_id")
drow['created_at'] = snowflake_datetime(int(drow['id']))
drow["created_at"] = snowflake_datetime(int(drow["id"]))
drow['payment_source'] = await get_payment_source(
row['user_id'], row['source_id'], db)
drow["payment_source"] = await get_payment_source(
row["user_id"], row["source_id"], db
)
drow['subscription'] = await get_subscription(
row['subscription_id'], db)
drow["subscription"] = await get_subscription(row["subscription_id"], db)
return drow
@ -255,7 +273,7 @@ async def create_payment(subscription_id, db=None):
new_id = get_snowflake()
amount = AMOUNTS[sub['payment_gateway_plan_id']]
amount = AMOUNTS[sub["payment_gateway_plan_id"]]
await db.execute(
"""
@ -266,10 +284,16 @@ async def create_payment(subscription_id, db=None):
)
VALUES
($1, $2, $3, $4, $5, 0, $6, $7, $8, 0, false)
""", new_id, int(sub['payment_source_id']),
subscription_id, int(sub['user_id']),
amount, 'usd', 'FUCK NITRO',
PaymentStatus.SUCCESS)
""",
new_id,
int(sub["payment_source_id"]),
subscription_id,
int(sub["user_id"]),
amount,
"usd",
"FUCK NITRO",
PaymentStatus.SUCCESS,
)
return new_id
@ -278,29 +302,34 @@ async def process_subscription(app, subscription_id: int):
"""Process a single subscription."""
sub = await get_subscription(subscription_id, app.db)
user_id = int(sub['user_id'])
user_id = int(sub["user_id"])
if sub['status'] != SubscriptionStatus.ACTIVE:
log.debug('ignoring sub {}, not active',
subscription_id)
if sub["status"] != SubscriptionStatus.ACTIVE:
log.debug("ignoring sub {}, not active", subscription_id)
return
# if the subscription is still active
# (should get cancelled status on failed
# payments), then we should update premium status
first_payment_id = await app.db.fetchval("""
first_payment_id = await app.db.fetchval(
"""
SELECT MIN(id)
FROM user_payments
WHERE subscription_id = $1
""", subscription_id)
""",
subscription_id,
)
first_payment_ts = snowflake_datetime(first_payment_id)
premium_since = await app.db.fetchval("""
premium_since = await app.db.fetchval(
"""
SELECT premium_since
FROM users
WHERE id = $1
""", user_id)
""",
user_id,
)
premium_since = premium_since or datetime.datetime.fromtimestamp(0)
@ -312,27 +341,34 @@ async def process_subscription(app, subscription_id: int):
if delta.total_seconds() < 24 * HOURS:
return
old_flags = await app.db.fetchval("""
old_flags = await app.db.fetchval(
"""
SELECT flags
FROM users
WHERE id = $1
""", user_id)
""",
user_id,
)
new_flags = old_flags | UserFlags.premium_early
log.debug('updating flags {}, {} => {}',
user_id, old_flags, new_flags)
log.debug("updating flags {}, {} => {}", user_id, old_flags, new_flags)
await app.db.execute("""
await app.db.execute(
"""
UPDATE users
SET premium_since = $1, flags = $2
WHERE id = $3
""", first_payment_ts, new_flags, user_id)
""",
first_payment_ts,
new_flags,
user_id,
)
# dispatch updated user to all possible clients
await mass_user_update(user_id, app)
@bp.route('/@me/billing/payment-sources', methods=['GET'])
@bp.route("/@me/billing/payment-sources", methods=["GET"])
async def _get_billing_sources():
user_id = await token_check()
source_ids = await get_payment_source_ids(user_id)
@ -346,7 +382,7 @@ async def _get_billing_sources():
return jsonify(res)
@bp.route('/@me/billing/subscriptions', methods=['GET'])
@bp.route("/@me/billing/subscriptions", methods=["GET"])
async def _get_billing_subscriptions():
user_id = await token_check()
sub_ids = await get_subscription_ids(user_id)
@ -358,7 +394,7 @@ async def _get_billing_subscriptions():
return jsonify(res)
@bp.route('/@me/billing/payments', methods=['GET'])
@bp.route("/@me/billing/payments", methods=["GET"])
async def _get_billing_payments():
user_id = await token_check()
payment_ids = await get_payment_ids(user_id)
@ -370,7 +406,7 @@ async def _get_billing_payments():
return jsonify(res)
@bp.route('/@me/billing/payment-sources', methods=['POST'])
@bp.route("/@me/billing/payment-sources", methods=["POST"])
async def _create_payment_source():
user_id = await token_check()
j = validate(await request.get_json(), PAYMENT_SOURCE)
@ -383,34 +419,40 @@ async def _create_payment_source():
default_, expires_month, expires_year, brand, cc_full,
billing_address)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
""", new_source_id, user_id, PaymentSource.CREDIT.value,
True, 12, 6969, 'Visa', '4242424242424242',
json.dumps(j['billing_address']))
return jsonify(
await get_payment_source(user_id, new_source_id)
""",
new_source_id,
user_id,
PaymentSource.CREDIT.value,
True,
12,
6969,
"Visa",
"4242424242424242",
json.dumps(j["billing_address"]),
)
return jsonify(await get_payment_source(user_id, new_source_id))
@bp.route('/@me/billing/subscriptions', methods=['POST'])
@bp.route("/@me/billing/subscriptions", methods=["POST"])
async def _create_subscription():
user_id = await token_check()
j = validate(await request.get_json(), CREATE_SUBSCRIPTION)
source = await get_payment_source(user_id, j['payment_source_id'])
source = await get_payment_source(user_id, j["payment_source_id"])
if not source:
raise BadRequest('invalid source id')
raise BadRequest("invalid source id")
plan_id = j['payment_gateway_plan_id']
plan_id = j["payment_gateway_plan_id"]
# tier 1 is lightro / classic
# tier 2 is nitro
period_end = {
'premium_month_tier_1': '1 month',
'premium_month_tier_2': '1 month',
'premium_year_tier_1': '1 year',
'premium_year_tier_2': '1 year',
"premium_month_tier_1": "1 month",
"premium_month_tier_2": "1 month",
"premium_year_tier_1": "1 year",
"premium_year_tier_2": "1 year",
}[plan_id]
new_id = get_snowflake()
@ -422,9 +464,15 @@ async def _create_subscription():
status, period_end)
VALUES ($1, $2, $3, $4, $5, $6, $7,
now()::timestamp + interval '{period_end}')
""", new_id, j['payment_source_id'], user_id,
SubscriptionType.PURCHASE, PaymentGateway.STRIPE,
plan_id, 1)
""",
new_id,
j["payment_source_id"],
user_id,
SubscriptionType.PURCHASE,
PaymentGateway.STRIPE,
plan_id,
1,
)
await create_payment(new_id, app.db)
@ -432,21 +480,17 @@ async def _create_subscription():
# and dispatch respective user updates to other people.
await process_subscription(app, new_id)
return jsonify(
await get_subscription(new_id)
)
return jsonify(await get_subscription(new_id))
@bp.route('/@me/billing/subscriptions/<int:subscription_id>',
methods=['DELETE'])
@bp.route("/@me/billing/subscriptions/<int:subscription_id>", methods=["DELETE"])
async def _delete_subscription(subscription_id):
# user_id = await token_check()
# return '', 204
pass
@bp.route('/@me/billing/subscriptions/<int:subscription_id>',
methods=['PATCH'])
@bp.route("/@me/billing/subscriptions/<int:subscription_id>", methods=["PATCH"])
async def _patch_subscription(subscription_id):
"""change a subscription's payment source"""
# user_id = await token_check()

View File

@ -25,8 +25,11 @@ from asyncio import sleep, CancelledError
from logbook import Logger
from litecord.blueprints.user.billing import (
get_subscription, get_payment_ids, get_payment, create_payment,
process_subscription
get_subscription,
get_payment_ids,
get_payment,
create_payment,
process_subscription,
)
from litecord.snowflake import snowflake_datetime
@ -37,15 +40,15 @@ log = Logger(__name__)
# how many days until a payment needs
# to be issued
THRESHOLDS = {
'premium_month_tier_1': 30,
'premium_month_tier_2': 30,
'premium_year_tier_1': 365,
'premium_year_tier_2': 365,
"premium_month_tier_1": 30,
"premium_month_tier_2": 30,
"premium_year_tier_1": 365,
"premium_year_tier_2": 365,
}
async def _resched(app):
log.debug('waiting 30 minutes for job.')
log.debug("waiting 30 minutes for job.")
await sleep(30 * MINUTES)
app.sched.spawn(payment_job(app))
@ -54,10 +57,10 @@ async def _process_user_payments(app, user_id: int):
payments = await get_payment_ids(user_id, app.db)
if not payments:
log.debug('no payments for uid {}, skipping', user_id)
log.debug("no payments for uid {}, skipping", user_id)
return
log.debug('{} payments for uid {}', len(payments), user_id)
log.debug("{} payments for uid {}", len(payments), user_id)
latest_payment = max(payments)
@ -66,33 +69,29 @@ async def _process_user_payments(app, user_id: int):
# calculate the difference between this payment
# and now.
now = datetime.datetime.now()
payment_tstamp = snowflake_datetime(int(payment_data['id']))
payment_tstamp = snowflake_datetime(int(payment_data["id"]))
delta = now - payment_tstamp
sub_id = int(payment_data['subscription']['id'])
subscription = await get_subscription(
sub_id, app.db)
sub_id = int(payment_data["subscription"]["id"])
subscription = await get_subscription(sub_id, app.db)
# if the max payment is X days old, we create another.
# X is 30 for monthly subscriptions of nitro,
# X is 365 for yearly subscriptions of nitro
threshold = THRESHOLDS[subscription['payment_gateway_plan_id']]
threshold = THRESHOLDS[subscription["payment_gateway_plan_id"]]
log.debug('delta {} delta days {} threshold {}',
delta, delta.days, threshold)
log.debug("delta {} delta days {} threshold {}", delta, delta.days, threshold)
if delta.days > threshold:
log.info('creating payment for sid={}',
sub_id)
log.info("creating payment for sid={}", sub_id)
# create_payment does not call any Stripe
# or BrainTree APIs at all, since we'll just
# give it as free.
await create_payment(sub_id, app.db)
else:
log.debug('sid={}, missing {} days',
sub_id, threshold - delta.days)
log.debug("sid={}, missing {} days", sub_id, threshold - delta.days)
async def payment_job(app):
@ -101,35 +100,39 @@ async def payment_job(app):
This function will check through users' payments
and add a new one once a month / year.
"""
log.debug('payment job start!')
log.debug("payment job start!")
user_ids = await app.db.fetch("""
user_ids = await app.db.fetch(
"""
SELECT DISTINCT user_id
FROM user_payments
""")
"""
)
log.debug('working {} users', len(user_ids))
log.debug("working {} users", len(user_ids))
# go through each user's payments
for row in user_ids:
user_id = row['user_id']
user_id = row["user_id"]
try:
await _process_user_payments(app, user_id)
except Exception:
log.exception('error while processing user payments')
log.exception("error while processing user payments")
subscribers = await app.db.fetch("""
subscribers = await app.db.fetch(
"""
SELECT id
FROM user_subscriptions
""")
"""
)
for row in subscribers:
try:
await process_subscription(app, row['id'])
await process_subscription(app, row["id"])
except Exception:
log.exception('error while processing subscription')
log.debug('rescheduling..')
log.exception("error while processing subscription")
log.debug("rescheduling..")
try:
await _resched(app)
except CancelledError:
log.info('cancelled while waiting for resched')
log.info("cancelled while waiting for resched")

View File

@ -22,24 +22,26 @@ fake routes for discord store
"""
from quart import Blueprint, jsonify
bp = Blueprint('fake_store', __name__)
bp = Blueprint("fake_store", __name__)
@bp.route('/promotions')
@bp.route("/promotions")
async def _get_promotions():
return jsonify([])
@bp.route('/users/@me/library')
@bp.route("/users/@me/library")
async def _get_library():
return jsonify([])
@bp.route('/users/@me/feed/settings')
@bp.route("/users/@me/feed/settings")
async def _get_feed_settings():
return jsonify({
'subscribed_games': [],
'subscribed_users': [],
'unsubscribed_users': [],
'unsubscribed_games': [],
})
return jsonify(
{
"subscribed_games": [],
"subscribed_users": [],
"unsubscribed_users": [],
"unsubscribed_games": [],
}
)

View File

@ -23,10 +23,10 @@ from litecord.auth import token_check
from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS
from litecord.blueprints.checks import guild_check
bp = Blueprint('users_settings', __name__)
bp = Blueprint("users_settings", __name__)
@bp.route('/@me/settings', methods=['GET'])
@bp.route("/@me/settings", methods=["GET"])
async def get_user_settings():
"""Get the current user's settings."""
user_id = await token_check()
@ -34,7 +34,7 @@ async def get_user_settings():
return jsonify(settings)
@bp.route('/@me/settings', methods=['PATCH'])
@bp.route("/@me/settings", methods=["PATCH"])
async def patch_current_settings():
"""Patch the users' current settings.
@ -47,19 +47,22 @@ async def patch_current_settings():
for key in j:
val = j[key]
await app.storage.execute_with_json(f"""
await app.storage.execute_with_json(
f"""
UPDATE user_settings
SET {key}=$1
WHERE id = $2
""", val, user_id)
""",
val,
user_id,
)
settings = await app.user_storage.get_user_settings(user_id)
await app.dispatcher.dispatch_user(
user_id, 'USER_SETTINGS_UPDATE', settings)
await app.dispatcher.dispatch_user(user_id, "USER_SETTINGS_UPDATE", settings)
return jsonify(settings)
@bp.route('/@me/guilds/<int:guild_id>/settings', methods=['PATCH'])
@bp.route("/@me/guilds/<int:guild_id>/settings", methods=["PATCH"])
async def patch_guild_settings(guild_id: int):
"""Update the users' guild settings for a given guild.
@ -74,16 +77,21 @@ async def patch_guild_settings(guild_id: int):
# will make sure they exist in the table.
await app.user_storage.get_guild_settings_one(user_id, guild_id)
for field in (k for k in j.keys() if k != 'channel_overrides'):
await app.db.execute(f"""
for field in (k for k in j.keys() if k != "channel_overrides"):
await app.db.execute(
f"""
UPDATE guild_settings
SET {field} = $1
WHERE user_id = $2 AND guild_id = $3
""", j[field], user_id, guild_id)
""",
j[field],
user_id,
guild_id,
)
chan_ids = await app.storage.get_channel_ids(guild_id)
for chandata in j.get('channel_overrides', {}).items():
for chandata in j.get("channel_overrides", {}).items():
chan_id, chan_overrides = chandata
chan_id = int(chan_id)
@ -92,7 +100,8 @@ async def patch_guild_settings(guild_id: int):
continue
for field in chan_overrides:
await app.db.execute(f"""
await app.db.execute(
f"""
INSERT INTO guild_settings_channel_overrides
(user_id, guild_id, channel_id, {field})
VALUES
@ -105,18 +114,21 @@ async def patch_guild_settings(guild_id: int):
WHERE guild_settings_channel_overrides.user_id = $1
AND guild_settings_channel_overrides.guild_id = $2
AND guild_settings_channel_overrides.channel_id = $3
""", user_id, guild_id, chan_id, chan_overrides[field])
""",
user_id,
guild_id,
chan_id,
chan_overrides[field],
)
settings = await app.user_storage.get_guild_settings_one(
user_id, guild_id)
settings = await app.user_storage.get_guild_settings_one(user_id, guild_id)
await app.dispatcher.dispatch_user(
user_id, 'USER_GUILD_SETTINGS_UPDATE', settings)
await app.dispatcher.dispatch_user(user_id, "USER_GUILD_SETTINGS_UPDATE", settings)
return jsonify(settings)
@bp.route('/@me/notes/<int:target_id>', methods=['PUT'])
@bp.route("/@me/notes/<int:target_id>", methods=["PUT"])
async def put_note(target_id: int):
"""Put a note to a user.
@ -126,10 +138,11 @@ async def put_note(target_id: int):
user_id = await token_check()
j = await request.get_json()
note = str(j['note'])
note = str(j["note"])
# UPSERTs are beautiful
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO notes (user_id, target_id, note)
VALUES ($1, $2, $3)
@ -138,12 +151,14 @@ async def put_note(target_id: int):
note = $3
WHERE notes.user_id = $1
AND notes.target_id = $2
""", user_id, target_id, note)
""",
user_id,
target_id,
note,
)
await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', {
'id': str(target_id),
'note': note,
})
return '', 204
await app.dispatcher.dispatch_user(
user_id, "USER_NOTE_UPDATE", {"id": str(target_id), "note": note}
)
return "", 204

View File

@ -27,9 +27,7 @@ from ..errors import Forbidden, BadRequest, Unauthorized
from ..schemas import validate, USER_UPDATE, GET_MENTIONS
from .guilds import guild_check
from litecord.auth import (
token_check, hash_data, check_username_usage, roll_discrim
)
from litecord.auth import token_check, hash_data, check_username_usage, roll_discrim
from litecord.blueprints.guild.mod import remove_member
from litecord.enums import PremiumType
@ -39,7 +37,7 @@ from litecord.permissions import base_permissions
from litecord.blueprints.auth import check_password
from litecord.utils import to_update
bp = Blueprint('user', __name__)
bp = Blueprint("user", __name__)
log = Logger(__name__)
@ -58,8 +56,7 @@ async def mass_user_update(user_id, app_=None):
private_user = await app_.storage.get_user(user_id, secure=True)
session_ids.extend(
await app_.dispatcher.dispatch_user(
user_id, 'USER_UPDATE', private_user)
await app_.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
)
guild_ids = await app_.user_storage.get_user_guilds(user_id)
@ -67,26 +64,22 @@ async def mass_user_update(user_id, app_=None):
session_ids.extend(
await app_.dispatcher.dispatch_many_filter_list(
'guild', guild_ids, session_ids,
'USER_UPDATE', public_user
"guild", guild_ids, session_ids, "USER_UPDATE", public_user
)
)
session_ids.extend(
await app_.dispatcher.dispatch_many_filter_list(
'friend', friend_ids, session_ids,
'USER_UPDATE', public_user
"friend", friend_ids, session_ids, "USER_UPDATE", public_user
)
)
await app_.dispatcher.dispatch_many(
'lazy_guild', guild_ids, 'update_user', user_id
)
await app_.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
return public_user, private_user
@bp.route('/@me', methods=['GET'])
@bp.route("/@me", methods=["GET"])
async def get_me():
"""Get the current user's information."""
user_id = await token_check()
@ -94,18 +87,21 @@ async def get_me():
return jsonify(user)
@bp.route('/<int:target_id>', methods=['GET'])
@bp.route("/<int:target_id>", methods=["GET"])
async def get_other(target_id):
"""Get any user, given the user ID."""
user_id = await token_check()
bot = await app.db.fetchval("""
bot = await app.db.fetchval(
"""
SELECT bot FROM users
WHERE users.id = $1
""", user_id)
""",
user_id,
)
if not bot:
raise Forbidden('Only bots can use this endpoint')
raise Forbidden("Only bots can use this endpoint")
other = await app.storage.get_user(target_id)
return jsonify(other)
@ -116,66 +112,80 @@ async def _try_username_patch(user_id, new_username: str) -> str:
discrim = None
try:
await app.db.execute("""
await app.db.execute(
"""
UPDATE users
SET username = $1
WHERE users.id = $2
""", new_username, user_id)
""",
new_username,
user_id,
)
return await app.db.fetchval("""
return await app.db.fetchval(
"""
SELECT discriminator
FROM users
WHERE users.id = $1
""", user_id)
""",
user_id,
)
except UniqueViolationError:
discrim = await roll_discrim(new_username)
if not discrim:
raise BadRequest('Unable to change username', {
'username': 'Too many people are with this username.'
})
raise BadRequest(
"Unable to change username",
{"username": "Too many people are with this username."},
)
await app.db.execute("""
await app.db.execute(
"""
UPDATE users
SET username = $1, discriminator = $2
WHERE users.id = $3
""", new_username, discrim, user_id)
""",
new_username,
discrim,
user_id,
)
return discrim
async def _try_discrim_patch(user_id, new_discrim: str):
try:
await app.db.execute("""
await app.db.execute(
"""
UPDATE users
SET discriminator = $1
WHERE id = $2
""", new_discrim, user_id)
""",
new_discrim,
user_id,
)
except UniqueViolationError:
raise BadRequest('Invalid discriminator', {
'discriminator': 'Someone already used this discriminator.'
})
raise BadRequest(
"Invalid discriminator",
{"discriminator": "Someone already used this discriminator."},
)
async def _check_pass(j, user):
# Do not do password checks on unclaimed accounts
if user['email'] is None:
if user["email"] is None:
return
if not j['password']:
raise BadRequest('password required', {
'password': 'password required'
})
if not j["password"]:
raise BadRequest("password required", {"password": "password required"})
phash = user['password_hash']
phash = user["password_hash"]
if not await check_password(phash, j['password']):
raise BadRequest('password incorrect', {
'password': 'password does not match.'
})
if not await check_password(phash, j["password"]):
raise BadRequest("password incorrect", {"password": "password does not match."})
@bp.route('/@me', methods=['PATCH'])
@bp.route("/@me", methods=["PATCH"])
async def patch_me():
"""Patch the current user's information."""
user_id = await token_check()
@ -183,36 +193,43 @@ async def patch_me():
j = validate(await request.get_json(), USER_UPDATE)
user = await app.storage.get_user(user_id, True)
user['password_hash'] = await app.db.fetchval("""
user["password_hash"] = await app.db.fetchval(
"""
SELECT password_hash
FROM users
WHERE id = $1
""", user_id)
""",
user_id,
)
if to_update(j, user, 'username'):
if to_update(j, user, "username"):
# this will take care of regenning a new discriminator
discrim = await _try_username_patch(user_id, j['username'])
user['username'] = j['username']
user['discriminator'] = discrim
discrim = await _try_username_patch(user_id, j["username"])
user["username"] = j["username"]
user["discriminator"] = discrim
if to_update(j, user, 'discriminator'):
if to_update(j, user, "discriminator"):
# the API treats discriminators as integers,
# but I work with strings on the database.
new_discrim = str(j['discriminator'])
new_discrim = str(j["discriminator"])
await _try_discrim_patch(user_id, new_discrim)
user['discriminator'] = new_discrim
user["discriminator"] = new_discrim
if to_update(j, user, 'email'):
if to_update(j, user, "email"):
await _check_pass(j, user)
# TODO: reverify the new email?
await app.db.execute("""
await app.db.execute(
"""
UPDATE users
SET email = $1
WHERE id = $2
""", j['email'], user_id)
user['email'] = j['email']
""",
j["email"],
user_id,
)
user["email"] = j["email"]
# only update if values are different
# from what the user gave.
@ -224,44 +241,49 @@ async def patch_me():
# IconManager.update will take care of validating
# the value once put()-ing
if to_update(j, user, 'avatar'):
mime, _ = parse_data_uri(j['avatar'])
if to_update(j, user, "avatar"):
mime, _ = parse_data_uri(j["avatar"])
if mime == 'image/gif' and user['premium_type'] == PremiumType.NONE:
raise BadRequest('no gif without nitro')
if mime == "image/gif" and user["premium_type"] == PremiumType.NONE:
raise BadRequest("no gif without nitro")
new_icon = await app.icons.update(
'user', user_id, j['avatar'], size=(128, 128))
new_icon = await app.icons.update("user", user_id, j["avatar"], size=(128, 128))
await app.db.execute("""
await app.db.execute(
"""
UPDATE users
SET avatar = $1
WHERE id = $2
""", new_icon.icon_hash, user_id)
""",
new_icon.icon_hash,
user_id,
)
if user['email'] is None and not 'new_password' in j:
raise BadRequest('missing password', {
'password': 'Please set a password.'
})
if user["email"] is None and not "new_password" in j:
raise BadRequest("missing password", {"password": "Please set a password."})
if 'new_password' in j and j['new_password']:
if "new_password" in j and j["new_password"]:
await _check_pass(j, user)
new_hash = await hash_data(j['new_password'])
new_hash = await hash_data(j["new_password"])
await app.db.execute("""
await app.db.execute(
"""
UPDATE users
SET password_hash = $1
WHERE id = $2
""", new_hash, user_id)
""",
new_hash,
user_id,
)
user.pop('password_hash')
user.pop("password_hash")
_, private_user = await mass_user_update(user_id, app)
return jsonify(private_user)
@bp.route('/@me/guilds', methods=['GET'])
@bp.route("/@me/guilds", methods=["GET"])
async def get_me_guilds():
"""Get partial user guilds."""
user_id = await token_check()
@ -270,27 +292,30 @@ async def get_me_guilds():
partials = []
for guild_id in guild_ids:
partial = await app.db.fetchrow("""
partial = await app.db.fetchrow(
"""
SELECT id::text, name, icon, owner_id
FROM guilds
WHERE guilds.id = $1
""", guild_id)
""",
guild_id,
)
partial = dict(partial)
user_perms = await base_permissions(user_id, guild_id)
partial['permissions'] = user_perms.binary
partial["permissions"] = user_perms.binary
partial['owner'] = partial['owner_id'] == user_id
partial["owner"] = partial["owner_id"] == user_id
partial.pop('owner_id')
partial.pop("owner_id")
partials.append(partial)
return jsonify(partials)
@bp.route('/@me/guilds/<int:guild_id>', methods=['DELETE'])
@bp.route("/@me/guilds/<int:guild_id>", methods=["DELETE"])
async def leave_guild(guild_id: int):
"""Leave a guild."""
user_id = await token_check()
@ -298,7 +323,7 @@ async def leave_guild(guild_id: int):
await remove_member(guild_id, user_id)
return '', 204
return "", 204
# @bp.route('/@me/connections', methods=['GET'])
@ -306,7 +331,7 @@ async def get_connections():
pass
@bp.route('/@me/consent', methods=['GET', 'POST'])
@bp.route("/@me/consent", methods=["GET", "POST"])
async def get_consent():
"""Always disable data collection.
@ -314,57 +339,58 @@ async def get_consent():
by the client and ignores them, as they
will always be false.
"""
return jsonify({
'usage_statistics': {
'consented': False,
},
'personalization': {
'consented': False,
return jsonify(
{
"usage_statistics": {"consented": False},
"personalization": {"consented": False},
}
})
)
@bp.route('/@me/harvest', methods=['GET'])
@bp.route("/@me/harvest", methods=["GET"])
async def get_harvest():
"""Dummy route"""
return '', 204
return "", 204
@bp.route('/@me/activities/statistics/applications', methods=['GET'])
@bp.route("/@me/activities/statistics/applications", methods=["GET"])
async def get_stats_applications():
"""Dummy route for info on gameplay time and such"""
return jsonify([])
@bp.route('/@me/library', methods=['GET'])
@bp.route("/@me/library", methods=["GET"])
async def get_library():
"""Probably related to Discord Store?"""
return jsonify([])
@bp.route('/<int:peer_id>/profile', methods=['GET'])
@bp.route("/<int:peer_id>/profile", methods=["GET"])
async def get_profile(peer_id: int):
"""Get a user's profile."""
user_id = await token_check()
peer = await app.storage.get_user(peer_id)
if not peer:
return '', 404
return "", 404
mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id)
friends = await app.user_storage.are_friends_with(user_id, peer_id)
# don't return a proper card if no guilds are being shared.
if not mutuals and not friends:
return '', 404
return "", 404
# actual premium status is determined by that
# column being NULL or not
peer_premium = await app.db.fetchval("""
peer_premium = await app.db.fetchval(
"""
SELECT premium_since
FROM users
WHERE id = $1
""", peer_id)
""",
peer_id,
)
mutual_guilds = await app.user_storage.get_mutual_guilds(user_id, peer_id)
mutual_res = []
@ -372,45 +398,49 @@ async def get_profile(peer_id: int):
# ascending sorting
for guild_id in sorted(mutual_guilds):
nick = await app.db.fetchval("""
nick = await app.db.fetchval(
"""
SELECT nickname
FROM members
WHERE guild_id = $1 AND user_id = $2
""", guild_id, peer_id)
""",
guild_id,
peer_id,
)
mutual_res.append({
'id': str(guild_id),
'nick': nick,
})
mutual_res.append({"id": str(guild_id), "nick": nick})
return jsonify({
'user': peer,
'connected_accounts': [],
'premium_since': peer_premium,
'mutual_guilds': mutual_res,
})
return jsonify(
{
"user": peer,
"connected_accounts": [],
"premium_since": peer_premium,
"mutual_guilds": mutual_res,
}
)
@bp.route('/@me/mentions', methods=['GET'])
@bp.route("/@me/mentions", methods=["GET"])
async def _get_mentions():
user_id = await token_check()
j = validate(dict(request.args), GET_MENTIONS)
guild_query = 'AND messages.guild_id = $2' if 'guild_id' in j else ''
role_query = "OR content LIKE '%<@&%'" if j['roles'] else ''
everyone_query = "OR content LIKE '%@everyone%'" if j['everyone'] else ''
mention_user = f'<@{user_id}>'
guild_query = "AND messages.guild_id = $2" if "guild_id" in j else ""
role_query = "OR content LIKE '%<@&%'" if j["roles"] else ""
everyone_query = "OR content LIKE '%@everyone%'" if j["everyone"] else ""
mention_user = f"<@{user_id}>"
args = [mention_user]
if guild_query:
args.append(j['guild_id'])
args.append(j["guild_id"])
guild_ids = await app.user_storage.get_user_guilds(user_id)
gids = ','.join(str(guild_id) for guild_id in guild_ids)
gids = ",".join(str(guild_id) for guild_id in guild_ids)
rows = await app.db.fetch(f"""
rows = await app.db.fetch(
f"""
SELECT messages.id
FROM messages
JOIN channels ON messages.channel_id = channels.id
@ -423,20 +453,20 @@ async def _get_mentions():
{guild_query}
)
LIMIT {j["limit"]}
""", *args)
""",
*args,
)
res = []
for row in rows:
message = await app.storage.get_message(row['id'])
gid = int(message['guild_id'])
message = await app.storage.get_message(row["id"])
gid = int(message["guild_id"])
# ignore messages pre-messages.guild_id
if gid not in guild_ids:
continue
res.append(
message
)
res.append(message)
return jsonify(res)
@ -449,18 +479,20 @@ def rand_hex(length: int = 8) -> str:
async def _del_from_table(db, table: str, user_id: int):
"""Delete a row from a table."""
column = {
'channel_overwrites': 'target_user',
'user_settings': 'id',
'group_dm_members': 'member_id'
}.get(table, 'user_id')
"channel_overwrites": "target_user",
"user_settings": "id",
"group_dm_members": "member_id",
}.get(table, "user_id")
res = await db.execute(f"""
res = await db.execute(
f"""
DELETE FROM {table}
WHERE {column} = $1
""", user_id)
""",
user_id,
)
log.info('Deleting uid {} from {}, res: {!r}',
user_id, table, res)
log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res)
async def delete_user(user_id, *, app_=None):
@ -470,13 +502,14 @@ async def delete_user(user_id, *, app_=None):
db = app_.db
new_username = f'Deleted User {rand_hex()}'
new_username = f"Deleted User {rand_hex()}"
# by using a random hex in password_hash
# we break attempts at using the default '123' password hash
# to issue valid tokens for deleted users.
await db.execute("""
await db.execute(
"""
UPDATE users
SET
username = $1,
@ -490,32 +523,39 @@ async def delete_user(user_id, *, app_=None):
password_hash = $2
WHERE
id = $3
""", new_username, rand_hex(32), user_id)
""",
new_username,
rand_hex(32),
user_id,
)
# remove the user from various tables
await _del_from_table(db, 'user_settings', user_id)
await _del_from_table(db, 'user_payment_sources', user_id)
await _del_from_table(db, 'user_subscriptions', user_id)
await _del_from_table(db, 'user_payments', user_id)
await _del_from_table(db, 'user_read_state', user_id)
await _del_from_table(db, 'guild_settings', user_id)
await _del_from_table(db, 'guild_settings_channel_overrides', user_id)
await _del_from_table(db, "user_settings", user_id)
await _del_from_table(db, "user_payment_sources", user_id)
await _del_from_table(db, "user_subscriptions", user_id)
await _del_from_table(db, "user_payments", user_id)
await _del_from_table(db, "user_read_state", user_id)
await _del_from_table(db, "guild_settings", user_id)
await _del_from_table(db, "guild_settings_channel_overrides", user_id)
await db.execute("""
await db.execute(
"""
DELETE FROM relationships
WHERE user_id = $1 OR peer_id = $1
""", user_id)
""",
user_id,
)
# DMs are still maintained, but not the state.
await _del_from_table(db, 'dm_channel_state', user_id)
await _del_from_table(db, "dm_channel_state", user_id)
# NOTE: we don't delete the group dms the user is an owner of...
# TODO: group dm owner reassign when the owner leaves a gdm
await _del_from_table(db, 'group_dm_members', user_id)
await _del_from_table(db, "group_dm_members", user_id)
await _del_from_table(db, 'members', user_id)
await _del_from_table(db, 'member_roles', user_id)
await _del_from_table(db, 'channel_overwrites', user_id)
await _del_from_table(db, "members", user_id)
await _del_from_table(db, "member_roles", user_id)
await _del_from_table(db, "channel_overwrites", user_id)
# after updating the user, we send USER_UPDATE so that all the other
# clients can refresh their caches on the now-deleted user
@ -540,15 +580,12 @@ async def user_disconnect(user_id: int):
await state.ws.ws.close(4000)
# force everyone to see the user as offline
await app.presence.dispatch_pres(user_id, {
'afk': False,
'status': 'offline',
'game': None,
'since': 0,
})
await app.presence.dispatch_pres(
user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
)
@bp.route('/@me/delete', methods=['POST'])
@bp.route("/@me/delete", methods=["POST"])
async def delete_account():
"""Delete own account.
@ -560,29 +597,35 @@ async def delete_account():
j = await request.get_json()
try:
password = j['password']
password = j["password"]
except KeyError:
raise BadRequest('password required')
raise BadRequest("password required")
owned_guilds = await app.db.fetchval("""
owned_guilds = await app.db.fetchval(
"""
SELECT COUNT(*)
FROM guilds
WHERE owner_id = $1
""", user_id)
""",
user_id,
)
if owned_guilds > 0:
raise BadRequest('You still own guilds.')
raise BadRequest("You still own guilds.")
pwd_hash = await app.db.fetchval("""
pwd_hash = await app.db.fetchval(
"""
SELECT password_hash
FROM users
WHERE id = $1
""", user_id)
""",
user_id,
)
if not await check_password(pwd_hash, password):
raise Unauthorized('password does not match')
raise Unauthorized("password does not match")
await delete_user(user_id)
await user_disconnect(user_id)
return '', 204
return "", 204

View File

@ -25,7 +25,7 @@ from quart import Blueprint, jsonify, current_app as app
from litecord.blueprints.auth import token_check
bp = Blueprint('voice', __name__)
bp = Blueprint("voice", __name__)
def _majority_region_count(regions: list) -> str:
@ -39,12 +39,14 @@ def _majority_region_count(regions: list) -> str:
async def _choose_random_region() -> Optional[str]:
"""Give a random voice region."""
regions = await app.db.fetch("""
regions = await app.db.fetch(
"""
SELECT id
FROM voice_regions
""")
"""
)
regions = [r['id'] for r in regions]
regions = [r["id"] for r in regions]
if not regions:
return None
@ -64,11 +66,14 @@ async def _majority_region_any(user_id) -> Optional[str]:
res = []
for guild_id in guilds:
region = await app.db.fetchval("""
region = await app.db.fetchval(
"""
SELECT region
FROM guilds
WHERE id = $1
""", guild_id)
""",
guild_id,
)
res.append(region)
@ -83,20 +88,23 @@ async def _majority_region_any(user_id) -> Optional[str]:
async def majority_region(user_id: int) -> Optional[str]:
"""Given a user ID, give the most likely region for the user to be
happy with."""
regions = await app.db.fetch("""
regions = await app.db.fetch(
"""
SELECT region
FROM guilds
WHERE owner_id = $1
""", user_id)
""",
user_id,
)
if not regions:
return await _majority_region_any(user_id)
regions = [r['region'] for r in regions]
regions = [r["region"] for r in regions]
return _majority_region_count(regions)
@bp.route('/regions', methods=['GET'])
@bp.route("/regions", methods=["GET"])
async def voice_regions():
"""Return voice regions."""
user_id = await token_check()
@ -105,6 +113,6 @@ async def voice_regions():
regions = await app.storage.all_voice_regions()
for region in regions:
region['optimal'] = region['id'] == best_region
region["optimal"] = region["id"] == best_region
return jsonify(regions)

View File

@ -26,22 +26,28 @@ from quart import Blueprint, jsonify, current_app as app, request
from litecord.auth import token_check
from litecord.blueprints.checks import (
channel_check, channel_perm_check, guild_check, guild_perm_check
channel_check,
channel_perm_check,
guild_check,
guild_perm_check,
)
from litecord.schemas import (
validate, WEBHOOK_CREATE, WEBHOOK_UPDATE, WEBHOOK_MESSAGE_CREATE
validate,
WEBHOOK_CREATE,
WEBHOOK_UPDATE,
WEBHOOK_MESSAGE_CREATE,
)
from litecord.enums import ChannelType
from litecord.snowflake import get_snowflake
from litecord.utils import async_map
from litecord.errors import (
WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest
)
from litecord.errors import WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest
from litecord.blueprints.channel.messages import (
msg_create_request, msg_create_check_content, msg_add_attachment,
msg_guild_text_mentions
msg_create_request,
msg_create_check_content,
msg_add_attachment,
msg_guild_text_mentions,
)
from litecord.embed.sanitizer import fill_embed, fetch_raw_img
from litecord.embed.messages import process_url_embed, is_media_url
@ -50,30 +56,34 @@ from litecord.utils import pg_set_json
from litecord.enums import MessageType
from litecord.images import STATIC_IMAGE_MIMES
bp = Blueprint('webhooks', __name__)
bp = Blueprint("webhooks", __name__)
async def get_webhook(webhook_id: int, *,
secure: bool=True) -> Optional[Dict[str, Any]]:
async def get_webhook(
webhook_id: int, *, secure: bool = True
) -> Optional[Dict[str, Any]]:
"""Get a webhook data"""
row = await app.db.fetchrow("""
row = await app.db.fetchrow(
"""
SELECT id::text, guild_id::text, channel_id::text, creator_id,
name, avatar, token
FROM webhooks
WHERE id = $1
""", webhook_id)
""",
webhook_id,
)
if not row:
return None
drow = dict(row)
drow['user'] = await app.storage.get_user(row['creator_id'])
drow.pop('creator_id')
drow["user"] = await app.storage.get_user(row["creator_id"])
drow.pop("creator_id")
if not secure:
drow.pop('user')
drow.pop('guild_id')
drow.pop("user")
drow.pop("guild_id")
return drow
@ -82,7 +92,7 @@ async def _webhook_check(channel_id):
user_id = await token_check()
await channel_check(user_id, channel_id, only=ChannelType.GUILD_TEXT)
await channel_perm_check(user_id, channel_id, 'manage_webhooks')
await channel_perm_check(user_id, channel_id, "manage_webhooks")
return user_id
@ -91,17 +101,20 @@ async def _webhook_check_guild(guild_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
await guild_perm_check(user_id, guild_id, 'manage_webhooks')
await guild_perm_check(user_id, guild_id, "manage_webhooks")
return user_id
async def _webhook_check_fw(webhook_id):
"""Make a check from an incoming webhook id (fw = from webhook)."""
guild_id = await app.db.fetchval("""
guild_id = await app.db.fetchval(
"""
SELECT guild_id FROM webhooks
WHERE id = $1
""", webhook_id)
""",
webhook_id,
)
if guild_id is None:
raise WebhookNotFound()
@ -110,42 +123,48 @@ async def _webhook_check_fw(webhook_id):
async def _webhook_many(where_clause, arg: int):
webhook_ids = await app.db.fetch(f"""
webhook_ids = await app.db.fetch(
f"""
SELECT id
FROM webhooks
{where_clause}
""", arg)
webhook_ids = [r['id'] for r in webhook_ids]
return jsonify(
await async_map(get_webhook, webhook_ids)
""",
arg,
)
webhook_ids = [r["id"] for r in webhook_ids]
return jsonify(await async_map(get_webhook, webhook_ids))
async def webhook_token_check(webhook_id: int, webhook_token: str):
"""token_check() equivalent for webhooks."""
row = await app.db.fetchrow("""
row = await app.db.fetchrow(
"""
SELECT guild_id, channel_id
FROM webhooks
WHERE id = $1 AND token = $2
""", webhook_id, webhook_token)
""",
webhook_id,
webhook_token,
)
if row is None:
raise Unauthorized('webhook not found or unauthorized')
raise Unauthorized("webhook not found or unauthorized")
return row['guild_id'], row['channel_id']
return row["guild_id"], row["channel_id"]
async def _dispatch_webhook_update(guild_id: int, channel_id):
await app.dispatcher.dispatch('guild', guild_id, 'WEBHOOKS_UPDATE', {
'guild_id': str(guild_id),
'channel_id': str(channel_id)
})
await app.dispatcher.dispatch(
"guild",
guild_id,
"WEBHOOKS_UPDATE",
{"guild_id": str(guild_id), "channel_id": str(channel_id)},
)
@bp.route('/channels/<int:channel_id>/webhooks', methods=['POST'])
@bp.route("/channels/<int:channel_id>/webhooks", methods=["POST"])
async def create_webhook(channel_id: int):
"""Create a webhook given a channel."""
user_id = await _webhook_check(channel_id)
@ -162,8 +181,7 @@ async def create_webhook(channel_id: int):
token = secrets.token_urlsafe(40)
webhook_icon = await app.icons.put(
'user', webhook_id, j.get('avatar'),
always_icon=True, size=(128, 128)
"user", webhook_id, j.get("avatar"), always_icon=True, size=(128, 128)
)
await app.db.execute(
@ -173,36 +191,41 @@ async def create_webhook(channel_id: int):
VALUES
($1, $2, $3, $4, $5, $6, $7)
""",
webhook_id, guild_id, channel_id, user_id,
j['name'], webhook_icon.icon_hash, token
webhook_id,
guild_id,
channel_id,
user_id,
j["name"],
webhook_icon.icon_hash,
token,
)
await _dispatch_webhook_update(guild_id, channel_id)
return jsonify(await get_webhook(webhook_id))
@bp.route('/channels/<int:channel_id>/webhooks', methods=['GET'])
@bp.route("/channels/<int:channel_id>/webhooks", methods=["GET"])
async def get_channel_webhook(channel_id: int):
"""Get a list of webhooks in a channel"""
await _webhook_check(channel_id)
return await _webhook_many('WHERE channel_id = $1', channel_id)
return await _webhook_many("WHERE channel_id = $1", channel_id)
@bp.route('/guilds/<int:guild_id>/webhooks', methods=['GET'])
@bp.route("/guilds/<int:guild_id>/webhooks", methods=["GET"])
async def get_guild_webhook(guild_id):
"""Get all webhooks in a guild"""
await _webhook_check_guild(guild_id)
return await _webhook_many('WHERE guild_id = $1', guild_id)
return await _webhook_many("WHERE guild_id = $1", guild_id)
@bp.route('/webhooks/<int:webhook_id>', methods=['GET'])
@bp.route("/webhooks/<int:webhook_id>", methods=["GET"])
async def get_single_webhook(webhook_id):
"""Get a single webhook's information."""
await _webhook_check_fw(webhook_id)
return await jsonify(await get_webhook(webhook_id))
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['GET'])
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["GET"])
async def get_tokened_webhook(webhook_id, webhook_token):
"""Get a webhook using its token."""
await webhook_token_check(webhook_id, webhook_token)
@ -210,46 +233,58 @@ async def get_tokened_webhook(webhook_id, webhook_token):
async def _update_webhook(webhook_id: int, j: dict):
if 'name' in j:
await app.db.execute("""
if "name" in j:
await app.db.execute(
"""
UPDATE webhooks
SET name = $1
WHERE id = $2
""", j['name'], webhook_id)
""",
j["name"],
webhook_id,
)
if 'channel_id' in j:
await app.db.execute("""
if "channel_id" in j:
await app.db.execute(
"""
UPDATE webhooks
SET channel_id = $1
WHERE id = $2
""", j['channel_id'], webhook_id)
if 'avatar' in j:
new_icon = await app.icons.update(
'user', webhook_id, j['avatar'], always_icon=True, size=(128, 128)
""",
j["channel_id"],
webhook_id,
)
await app.db.execute("""
if "avatar" in j:
new_icon = await app.icons.update(
"user", webhook_id, j["avatar"], always_icon=True, size=(128, 128)
)
await app.db.execute(
"""
UPDATE webhooks
SET icon = $1
WHERE id = $2
""", new_icon.icon_hash, webhook_id)
""",
new_icon.icon_hash,
webhook_id,
)
@bp.route('/webhooks/<int:webhook_id>', methods=['PATCH'])
@bp.route("/webhooks/<int:webhook_id>", methods=["PATCH"])
async def modify_webhook(webhook_id: int):
"""Patch a webhook."""
_user_id, guild_id = await _webhook_check_fw(webhook_id)
j = validate(await request.get_json(), WEBHOOK_UPDATE)
if 'channel_id' in j:
if "channel_id" in j:
# pre checks
chan = await app.storage.get_channel(j['channel_id'])
chan = await app.storage.get_channel(j["channel_id"])
# short-circuiting should ensure chan isn't none
# by the time we do chan['guild_id']
if chan and chan['guild_id'] != str(guild_id):
raise ChannelNotFound('cant assign webhook to channel')
if chan and chan["guild_id"] != str(guild_id):
raise ChannelNotFound("cant assign webhook to channel")
await _update_webhook(webhook_id, j)
@ -257,20 +292,18 @@ async def modify_webhook(webhook_id: int):
# we don't need to cast channel_id to int since that isn't
# used in the dispatcher call
await _dispatch_webhook_update(guild_id, webhook['channel_id'])
await _dispatch_webhook_update(guild_id, webhook["channel_id"])
return jsonify(webhook)
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['PATCH'])
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["PATCH"])
async def modify_webhook_tokened(webhook_id, webhook_token):
"""Modify a webhook, using its token."""
guild_id, channel_id = await webhook_token_check(
webhook_id, webhook_token)
guild_id, channel_id = await webhook_token_check(webhook_id, webhook_token)
# forcefully pop() the channel id out of the schema
# instead of making another, for simplicity's sake
j = validate(await request.get_json(),
WEBHOOK_UPDATE.pop('channel_id'))
j = validate(await request.get_json(), WEBHOOK_UPDATE.pop("channel_id"))
await _update_webhook(webhook_id, j)
await _dispatch_webhook_update(guild_id, channel_id)
@ -281,35 +314,36 @@ async def delete_webhook(webhook_id: int):
"""Delete a webhook."""
webhook = await get_webhook(webhook_id)
res = await app.db.execute("""
res = await app.db.execute(
"""
DELETE FROM webhooks
WHERE id = $1
""", webhook_id)
""",
webhook_id,
)
if res.lower() == 'delete 0':
if res.lower() == "delete 0":
raise WebhookNotFound()
# only casting the guild id since that's whats used
# on the dispatcher call.
await _dispatch_webhook_update(
int(webhook['guild_id']), webhook['channel_id']
)
await _dispatch_webhook_update(int(webhook["guild_id"]), webhook["channel_id"])
@bp.route('/webhooks/<int:webhook_id>', methods=['DELETE'])
@bp.route("/webhooks/<int:webhook_id>", methods=["DELETE"])
async def del_webhook(webhook_id):
"""Delete a webhook."""
await _webhook_check_fw(webhook_id)
await delete_webhook(webhook_id)
return '', 204
return "", 204
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['DELETE'])
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["DELETE"])
async def del_webhook_tokened(webhook_id, webhook_token):
"""Delete a webhook, with its token."""
await webhook_token_check(webhook_id, webhook_token)
await delete_webhook(webhook_id)
return '', 204
return "", 204
async def create_message_webhook(guild_id, channel_id, webhook_id, data):
@ -328,23 +362,27 @@ async def create_message_webhook(guild_id, channel_id, webhook_id, data):
message_id,
channel_id,
guild_id,
data['content'],
data['tts'],
data['everyone_mention'],
data["content"],
data["tts"],
data["everyone_mention"],
MessageType.DEFAULT.value,
data.get('embeds', [])
data.get("embeds", []),
)
info = data['info']
info = data["info"]
await conn.execute("""
await conn.execute(
"""
INSERT INTO message_webhook_info
(message_id, webhook_id, name, avatar)
VALUES
($1, $2, $3, $4)
""", message_id, webhook_id, info['name'], info['avatar'])
""",
message_id,
webhook_id,
info["name"],
info["avatar"],
)
return message_id
@ -354,10 +392,15 @@ async def _webhook_avy_redir(webhook_id: int, avatar_url: EmbedURL):
url_hash = hashlib.sha256(avatar_url.to_md_path.encode()).hexdigest()
try:
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO webhook_avatars (webhook_id, hash, md_url_redir)
VALUES ($1, $2, $3)
""", webhook_id, url_hash, avatar_url.url)
""",
webhook_id,
url_hash,
avatar_url.url,
)
except asyncpg.UniqueViolationError:
pass
@ -371,36 +414,36 @@ async def _create_avatar(webhook_id: int, avatar_url: EmbedURL) -> str:
Litecord will write an URL that redirects to the given avatar_url,
using mediaproxy.
"""
if avatar_url.scheme not in ('http', 'https'):
raise BadRequest('invalid avatar url scheme')
if avatar_url.scheme not in ("http", "https"):
raise BadRequest("invalid avatar url scheme")
if not is_media_url(avatar_url):
raise BadRequest('url is not media url')
raise BadRequest("url is not media url")
# we still fetch the URL to check its validity, mimetypes, etc
# but in the end, we will store it under the webhook_avatars table,
# not IconManager.
resp, raw = await fetch_raw_img(avatar_url)
#raw_b64 = base64.b64encode(raw).decode()
# raw_b64 = base64.b64encode(raw).decode()
mime = resp.headers['content-type']
mime = resp.headers["content-type"]
# TODO: apng checks are missing (for this and everywhere else)
if mime not in STATIC_IMAGE_MIMES:
raise BadRequest('invalid mime type for given url')
raise BadRequest("invalid mime type for given url")
#b64_data = f'data:{mime};base64,{raw_b64}'
# b64_data = f'data:{mime};base64,{raw_b64}'
# TODO: replace this by webhook_avatars
#icon = await app.icons.put(
# icon = await app.icons.put(
# 'user', webhook_id, b64_data,
# always_icon=True, size=(128, 128)
#)
# )
return await _webhook_avy_redir(webhook_id, avatar_url)
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['POST'])
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["POST"])
async def execute_webhook(webhook_id: int, webhook_token):
"""Execute a webhook. Sends a message to the channel the webhook
is tied to."""
@ -413,41 +456,39 @@ async def execute_webhook(webhook_id: int, webhook_token):
# NOTE: we really pop here instead of adding a kwarg
# to msg_create_request just because of webhooks.
# nonce isn't allowed on WEBHOOK_MESSAGE_CREATE
payload_json.pop('nonce')
payload_json.pop("nonce")
j = validate(payload_json, WEBHOOK_MESSAGE_CREATE)
msg_create_check_content(j, files)
# webhooks don't need permissions.
mentions_everyone = '@everyone' in j['content']
mentions_here = '@here' in j['content']
mentions_everyone = "@everyone" in j["content"]
mentions_here = "@here" in j["content"]
given_embeds = j.get('embeds', [])
given_embeds = j.get("embeds", [])
webhook = await get_webhook(webhook_id)
# webhooks have TWO avatars. one is from settings, the other is from
# the json's icon_url. one can be handled gracefully by IconManager,
# but the other can't, at all.
avatar = webhook['avatar']
avatar = webhook["avatar"]
if 'avatar_url' in j and j['avatar_url'] is not None:
avatar = await _create_avatar(webhook_id, j['avatar_url'])
if "avatar_url" in j and j["avatar_url"] is not None:
avatar = await _create_avatar(webhook_id, j["avatar_url"])
message_id = await create_message_webhook(
guild_id, channel_id, webhook_id, {
'content': j.get('content', ''),
'tts': j.get('tts', False),
'everyone_mention': mentions_everyone or mentions_here,
'embeds': await async_map(fill_embed, given_embeds),
'info': {
'name': j.get('username', webhook['name']),
'avatar': avatar
}
}
guild_id,
channel_id,
webhook_id,
{
"content": j.get("content", ""),
"tts": j.get("tts", False),
"everyone_mention": mentions_everyone or mentions_here,
"embeds": await async_map(fill_embed, given_embeds),
"info": {"name": j.get("username", webhook["name"]), "avatar": avatar},
},
)
for pre_attachment in files:
@ -455,33 +496,28 @@ async def execute_webhook(webhook_id: int, webhook_token):
payload = await app.storage.get_message(message_id)
await app.dispatcher.dispatch('channel', channel_id,
'MESSAGE_CREATE', payload)
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
# spawn embedder in the background, even when we're on a webhook.
app.sched.spawn(
process_url_embed(
app.config, app.storage, app.dispatcher, app.session,
payload
)
process_url_embed(app.config, app.storage, app.dispatcher, app.session, payload)
)
# we can assume its a guild text channel, so just call it
await msg_guild_text_mentions(
payload, guild_id, mentions_everyone, mentions_here)
await msg_guild_text_mentions(payload, guild_id, mentions_everyone, mentions_here)
# TODO: is it really 204?
return '', 204
return "", 204
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>/slack',
methods=['POST'])
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>/slack", methods=["POST"])
async def execute_slack_webhook(webhook_id, webhook_token):
"""Execute a webhook but expecting Slack data."""
# TODO: know slack webhooks
await webhook_token_check(webhook_id, webhook_token)
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>/github', methods=['POST'])
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>/github", methods=["POST"])
async def execute_github_webhook(webhook_id, webhook_token):
"""Execute a webhook but expecting GitHub data."""
# TODO: know github webhooks

View File

@ -21,9 +21,14 @@ from typing import List, Any, Dict
from logbook import Logger
from .pubsub import GuildDispatcher, MemberDispatcher, \
UserDispatcher, ChannelDispatcher, FriendDispatcher, \
LazyGuildDispatcher
from .pubsub import (
GuildDispatcher,
MemberDispatcher,
UserDispatcher,
ChannelDispatcher,
FriendDispatcher,
LazyGuildDispatcher,
)
log = Logger(__name__)
@ -44,17 +49,18 @@ class EventDispatcher:
when dispatching, the backend can do its own logic, given
its subscriber ids.
"""
def __init__(self, app):
self.state_manager = app.state_manager
self.app = app
self.backends = {
'guild': GuildDispatcher(self),
'member': MemberDispatcher(self),
'channel': ChannelDispatcher(self),
'user': UserDispatcher(self),
'friend': FriendDispatcher(self),
'lazy_guild': LazyGuildDispatcher(self),
"guild": GuildDispatcher(self),
"member": MemberDispatcher(self),
"channel": ChannelDispatcher(self),
"user": UserDispatcher(self),
"friend": FriendDispatcher(self),
"lazy_guild": LazyGuildDispatcher(self),
}
async def action(self, backend_str: str, action: str, key, identifier, *args):
@ -71,13 +77,13 @@ class EventDispatcher:
return await method(key, identifier, *args)
async def subscribe(self, backend: str, key: Any, identifier: Any,
flags: Dict[str, Any] = None):
async def subscribe(
self, backend: str, key: Any, identifier: Any, flags: Dict[str, Any] = None
):
"""Subscribe a single element to the given backend."""
flags = flags or {}
log.debug('SUB backend={} key={} <= id={}',
backend, key, identifier, backend)
log.debug("SUB backend={} key={} <= id={}", backend, key, identifier, backend)
# this is a hacky solution for backwards compatibility between backends
# that implement flags and backends that don't.
@ -85,16 +91,15 @@ class EventDispatcher:
# passing flags to backends that don't implement flags will
# cause errors as expected.
if flags:
return await self.action(backend, 'sub', key, identifier, flags)
return await self.action(backend, "sub", key, identifier, flags)
return await self.action(backend, 'sub', key, identifier)
return await self.action(backend, "sub", key, identifier)
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
"""Unsubscribe an element from the given backend."""
log.debug('UNSUB backend={} key={} => id={}',
backend, key, identifier, backend)
log.debug("UNSUB backend={} key={} => id={}", backend, key, identifier, backend)
return await self.action(backend, 'unsub', key, identifier)
return await self.action(backend, "unsub", key, identifier)
async def sub(self, backend, key, identifier):
"""Alias to subscribe()."""
@ -104,8 +109,13 @@ class EventDispatcher:
"""Alias to unsubscribe()."""
return await self.unsubscribe(backend, key, identifier)
async def sub_many(self, backend_str: str, identifier: Any,
keys: list, flags: Dict[str, Any] = None):
async def sub_many(
self,
backend_str: str,
identifier: Any,
keys: list,
flags: Dict[str, Any] = None,
):
"""Subscribe to multiple channels (all in a single backend)
at a time.
@ -116,8 +126,7 @@ class EventDispatcher:
for key in keys:
await self.subscribe(backend_str, key, identifier, flags)
async def mass_sub(self, identifier: Any,
backends: List[tuple]):
async def mass_sub(self, identifier: Any, backends: List[tuple]):
"""Mass subscribe to many backends at once."""
for bcall in backends:
backend_str, keys = bcall[0], bcall[1]
@ -128,8 +137,13 @@ class EventDispatcher:
# we have flags
flags = bcall[2]
log.debug('subscribing {} to {} keys in backend {}, flags: {}',
identifier, len(keys), backend_str, flags)
log.debug(
"subscribing {} to {} keys in backend {}, flags: {}",
identifier,
len(keys),
backend_str,
flags,
)
await self.sub_many(backend_str, identifier, keys, flags)
@ -145,17 +159,14 @@ class EventDispatcher:
key = backend.KEY_TYPE(key)
return await backend.dispatch(key, *args, **kwargs)
async def dispatch_many(self, backend_str: str,
keys: List[Any], *args, **kwargs):
async def dispatch_many(self, backend_str: str, keys: List[Any], *args, **kwargs):
"""Dispatch to multiple keys in a single backend."""
log.info('MULTI DISPATCH: {!r}, {} keys',
backend_str, len(keys))
log.info("MULTI DISPATCH: {!r}, {} keys", backend_str, len(keys))
for key in keys:
await self.dispatch(backend_str, key, *args, **kwargs)
async def dispatch_filter(self, backend_str: str,
key: Any, func, *args):
async def dispatch_filter(self, backend_str: str, key: Any, func, *args):
"""Dispatch to a backend that only accepts
(event, data) arguments with an optional filter
function."""
@ -163,9 +174,9 @@ class EventDispatcher:
key = backend.KEY_TYPE(key)
return await backend.dispatch_filter(key, func, *args)
async def dispatch_many_filter_list(self, backend_str: str,
keys: List[Any], sess_list: List[str],
*args):
async def dispatch_many_filter_list(
self, backend_str: str, keys: List[Any], sess_list: List[str], *args
):
"""Make a "unique" dispatch given a list of session ids.
This only works for backends that have a dispatch_filter
@ -175,9 +186,8 @@ class EventDispatcher:
for key in keys:
sess_list.extend(
await self.dispatch_filter(
backend_str, key,
lambda sess_id: sess_id not in sess_list,
*args)
backend_str, key, lambda sess_id: sess_id not in sess_list, *args
)
)
return sess_list
@ -197,12 +207,12 @@ class EventDispatcher:
async def dispatch_guild(self, guild_id, event, data):
"""Backwards compatibility with old EventDispatcher."""
return await self.dispatch('guild', guild_id, event, data)
return await self.dispatch("guild", guild_id, event, data)
async def dispatch_user_guild(self, user_id, guild_id, event, data):
"""Backwards compatibility with old EventDispatcher."""
return await self.dispatch('member', (guild_id, user_id), event, data)
return await self.dispatch("member", (guild_id, user_id), event, data)
async def dispatch_user(self, user_id, event, data):
"""Backwards compatibility with old EventDispatcher."""
return await self.dispatch('user', user_id, event, data)
return await self.dispatch("user", user_id, event, data)

View File

@ -19,4 +19,4 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from .sanitizer import sanitize_embed
__all__ = ['sanitize_embed']
__all__ = ["sanitize_embed"]

View File

@ -30,11 +30,7 @@ from litecord.embed.schemas import EmbedURL
log = Logger(__name__)
MEDIA_EXTENSIONS = (
'png',
'jpg', 'jpeg',
'gif', 'webm'
)
MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm")
async def insert_media_meta(url, config, session):
@ -45,18 +41,18 @@ async def insert_media_meta(url, config, session):
if meta is None:
return
if not meta['image']:
if not meta["image"]:
return
return {
'type': 'image',
'url': url,
'thumbnail': {
'width': meta['width'],
'height': meta['height'],
'url': url,
'proxy_url': img_proxy_url
}
"type": "image",
"url": url,
"thumbnail": {
"width": meta["width"],
"height": meta["height"],
"url": url,
"proxy_url": img_proxy_url,
},
}
@ -64,29 +60,32 @@ async def msg_update_embeds(payload, new_embeds, storage, dispatcher):
"""Update the message with the given embeds and dispatch a MESSAGE_UPDATE
to users."""
message_id = int(payload['id'])
channel_id = int(payload['channel_id'])
message_id = int(payload["id"])
channel_id = int(payload["channel_id"])
await storage.execute_with_json("""
await storage.execute_with_json(
"""
UPDATE messages
SET embeds = $1
WHERE messages.id = $2
""", new_embeds, message_id)
""",
new_embeds,
message_id,
)
update_payload = {
'id': str(message_id),
'channel_id': str(channel_id),
'embeds': new_embeds,
"id": str(message_id),
"channel_id": str(channel_id),
"embeds": new_embeds,
}
if 'guild_id' in payload:
update_payload['guild_id'] = payload['guild_id']
if "guild_id" in payload:
update_payload["guild_id"] = payload["guild_id"]
if 'flags' in payload:
update_payload['flags'] = payload['flags']
if "flags" in payload:
update_payload["flags"] = payload["flags"]
await dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_UPDATE', update_payload)
await dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", update_payload)
def is_media_url(url) -> bool:
@ -98,7 +97,7 @@ def is_media_url(url) -> bool:
parsed = urllib.parse.urlparse(url)
path = Path(parsed.path)
extension = path.suffix.lstrip('.')
extension = path.suffix.lstrip(".")
return extension in MEDIA_EXTENSIONS
@ -109,20 +108,20 @@ async def insert_mp_embed(parsed, config, session):
return embed
async def process_url_embed(config, storage, dispatcher,
session, payload: dict, *, delay=0):
async def process_url_embed(
config, storage, dispatcher, session, payload: dict, *, delay=0
):
"""Process URLs in a message and generate embeds based on that."""
await asyncio.sleep(delay)
message_id = int(payload['id'])
message_id = int(payload["id"])
# if we already have embeds
# we shouldn't add our own.
embeds = payload['embeds']
embeds = payload["embeds"]
if embeds:
log.debug('url processor: ignoring existing embeds @ mid {}',
message_id)
log.debug("url processor: ignoring existing embeds @ mid {}", message_id)
return
# now, we have two types of embeds:
@ -130,7 +129,7 @@ async def process_url_embed(config, storage, dispatcher,
# - url embeds
# use regex to get URLs
urls = re.findall(r'(https?://\S+)', payload['content'])
urls = re.findall(r"(https?://\S+)", payload["content"])
urls = urls[:5]
# from there, we need to parse each found url and check its path.
@ -159,7 +158,6 @@ async def process_url_embed(config, storage, dispatcher,
if not new_embeds:
return
log.debug('made {} embeds for mid {}',
len(new_embeds), message_id)
log.debug("made {} embeds for mid {}", len(new_embeds), message_id)
await msg_update_embeds(payload, new_embeds, storage, dispatcher)

View File

@ -39,9 +39,7 @@ def sanitize_embed(embed: Embed) -> Embed:
This is non-complex sanitization as it doesn't
need the app object.
"""
return {**embed, **{
'type': 'rich'
}}
return {**embed, **{"type": "rich"}}
def path_exists(embed: Embed, components_in: Union[List[str], str]):
@ -55,7 +53,7 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]):
# get the list of components given
if isinstance(components_in, str):
components = components_in.split('.')
components = components_in.split(".")
else:
components = list(components_in)
@ -77,7 +75,6 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]):
return False
def _mk_cfg_sess(config, session) -> tuple:
"""Return a tuple of (config, session)."""
if config is None:
@ -91,11 +88,11 @@ def _mk_cfg_sess(config, session) -> tuple:
def _md_base(config) -> Optional[tuple]:
"""Return the protocol and base url for the mediaproxy."""
md_base_url = config['MEDIA_PROXY']
md_base_url = config["MEDIA_PROXY"]
if md_base_url is None:
return None
proto = 'https' if config['IS_SSL'] else 'http'
proto = "https" if config["IS_SSL"] else "http"
return proto, md_base_url
@ -111,7 +108,7 @@ def make_md_req_url(config, scope: str, url):
return url.url if isinstance(url, EmbedURL) else url
proto, base_url = base
return f'{proto}://{base_url}/{scope}/{url.to_md_path}'
return f"{proto}://{base_url}/{scope}/{url.to_md_path}"
def proxify(url, *, config=None) -> str:
@ -122,11 +119,12 @@ def proxify(url, *, config=None) -> str:
if isinstance(url, str):
url = EmbedURL(url)
return make_md_req_url(config, 'img', url)
return make_md_req_url(config, "img", url)
async def _md_client_req(config, session, scope: str,
url, *, ret_resp=False) -> Optional[Union[Tuple, Dict]]:
async def _md_client_req(
config, session, scope: str, url, *, ret_resp=False
) -> Optional[Union[Tuple, Dict]]:
"""Makes a request to the mediaproxy.
This has common code between all the main mediaproxy request functions
@ -172,17 +170,13 @@ async def _md_client_req(config, session, scope: str,
return await resp.json()
body = await resp.text()
log.warning('failed to call {!r}, {} {!r}',
request_url, resp.status, body)
log.warning("failed to call {!r}, {} {!r}", request_url, resp.status, body)
return None
async def fetch_metadata(url, *, config=None,
session=None) -> Optional[Dict]:
async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]:
"""Fetch metadata for a url (image width, mime, etc)."""
return await _md_client_req(
config, session, 'meta', url
)
return await _md_client_req(config, session, "meta", url)
async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
@ -191,9 +185,7 @@ async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
Returns a tuple containing the response object and the raw bytes given by
the website.
"""
tup = await _md_client_req(
config, session, 'img', url, ret_resp=True
)
tup = await _md_client_req(config, session, "img", url, ret_resp=True)
if not tup:
return None
@ -207,9 +199,7 @@ async def fetch_embed(url, *, config=None, session=None) -> Dict[str, Any]:
Returns a discord embed object.
"""
return await _md_client_req(
config, session, 'embed', url
)
return await _md_client_req(config, session, "embed", url)
async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:
@ -229,22 +219,20 @@ async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:
embed = sanitize_embed(embed)
if path_exists(embed, 'footer.icon_url'):
embed['footer']['proxy_icon_url'] = \
proxify(embed['footer']['icon_url'])
if path_exists(embed, "footer.icon_url"):
embed["footer"]["proxy_icon_url"] = proxify(embed["footer"]["icon_url"])
if path_exists(embed, 'author.icon_url'):
embed['author']['proxy_icon_url'] = \
proxify(embed['author']['icon_url'])
if path_exists(embed, "author.icon_url"):
embed["author"]["proxy_icon_url"] = proxify(embed["author"]["icon_url"])
if path_exists(embed, 'image.url'):
image_url = embed['image']['url']
if path_exists(embed, "image.url"):
image_url = embed["image"]["url"]
meta = await fetch_metadata(image_url)
embed['image']['proxy_url'] = proxify(image_url)
embed["image"]["proxy_url"] = proxify(image_url)
if meta and meta['image']:
embed['image']['width'] = meta['width']
embed['image']['height'] = meta['height']
if meta and meta["image"]:
embed["image"]["width"] = meta["width"]
embed["image"]["height"] = meta["height"]
return embed

View File

@ -28,8 +28,8 @@ class EmbedURL:
def __init__(self, url: str):
parsed = urllib.parse.urlparse(url)
if parsed.scheme not in ('http', 'https', 'attachment'):
raise ValueError('Invalid URL scheme')
if parsed.scheme not in ("http", "https", "attachment"):
raise ValueError("Invalid URL scheme")
self.scheme = parsed.scheme
self.raw_url = url
@ -54,105 +54,61 @@ class EmbedURL:
def to_md_path(self) -> str:
"""Convert the EmbedURL to a mediaproxy path (post img/meta)."""
parsed = self.parsed
return (
f'{parsed.scheme}/{parsed.netloc}'
f'{parsed.path}?{parsed.query}'
)
return f"{parsed.scheme}/{parsed.netloc}" f"{parsed.path}?{parsed.query}"
EMBED_FOOTER = {
'text': {
'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': True},
'icon_url': {
'coerce': EmbedURL, 'required': False,
},
"text": {"type": "string", "minlength": 1, "maxlength": 1024, "required": True},
"icon_url": {"coerce": EmbedURL, "required": False},
# NOTE: proxy_icon_url set by us
}
EMBED_IMAGE = {
'url': {'coerce': EmbedURL, 'required': True},
"url": {"coerce": EmbedURL, "required": True},
# NOTE: proxy_url, width, height set by us
}
EMBED_THUMBNAIL = EMBED_IMAGE
EMBED_AUTHOR = {
'name': {
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': False
},
'url': {
'coerce': EmbedURL, 'required': False,
},
'icon_url': {
'coerce': EmbedURL, 'required': False,
}
"name": {"type": "string", "minlength": 1, "maxlength": 256, "required": False},
"url": {"coerce": EmbedURL, "required": False},
"icon_url": {"coerce": EmbedURL, "required": False}
# NOTE: proxy_icon_url set by us
}
EMBED_FIELD = {
'name': {
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True
},
'value': {
'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': True
},
'inline': {
'type': 'boolean', 'required': False, 'default': True,
},
"name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True},
"value": {"type": "string", "minlength": 1, "maxlength": 1024, "required": True},
"inline": {"type": "boolean", "required": False, "default": True},
}
EMBED_OBJECT = {
'title': {
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': False},
"title": {"type": "string", "minlength": 1, "maxlength": 256, "required": False},
# NOTE: type set by us
'description': {
'type': 'string', 'minlength': 1, 'maxlength': 2048, 'required': False,
"description": {
"type": "string",
"minlength": 1,
"maxlength": 2048,
"required": False,
},
'url': {
'coerce': EmbedURL, 'required': False,
},
'timestamp': {
"url": {"coerce": EmbedURL, "required": False},
"timestamp": {
# TODO: an ISO 8601 type
# TODO: maybe replace the default in here with now().isoformat?
'type': 'string', 'required': False
"type": "string",
"required": False,
},
'color': {
'coerce': Color, 'required': False
},
'footer': {
'type': 'dict',
'schema': EMBED_FOOTER,
'required': False,
},
'image': {
'type': 'dict',
'schema': EMBED_IMAGE,
'required': False,
},
'thumbnail': {
'type': 'dict',
'schema': EMBED_THUMBNAIL,
'required': False,
},
"color": {"coerce": Color, "required": False},
"footer": {"type": "dict", "schema": EMBED_FOOTER, "required": False},
"image": {"type": "dict", "schema": EMBED_IMAGE, "required": False},
"thumbnail": {"type": "dict", "schema": EMBED_THUMBNAIL, "required": False},
# NOTE: 'video' set by us
# NOTE: 'provider' set by us
'author': {
'type': 'dict',
'schema': EMBED_AUTHOR,
'required': False,
},
'fields': {
'type': 'list',
'schema': {'type': 'dict', 'schema': EMBED_FIELD},
'required': False,
"author": {"type": "dict", "schema": EMBED_AUTHOR, "required": False},
"fields": {
"type": "list",
"schema": {"type": "dict", "schema": EMBED_FIELD},
"required": False,
},
}

View File

@ -52,13 +52,14 @@ class Flags:
>>> i2.is_field_3
False
"""
def __init_subclass__(cls, **_kwargs):
attrs = inspect.getmembers(cls, lambda x: not inspect.isroutine(x))
def _make_int(value):
res = Flags()
setattr(res, 'value', value)
setattr(res, "value", value)
for attr, val in attrs:
# get only the ones that represent a field in the
@ -69,7 +70,7 @@ class Flags:
has_attr = (value & val) == val
# set each attribute
setattr(res, f'is_{attr}', has_attr)
setattr(res, f"is_{attr}", has_attr)
return res
@ -84,17 +85,16 @@ class ChannelType(EasyEnum):
GUILD_CATEGORY = 4
GUILD_CHANS = (ChannelType.GUILD_TEXT,
GUILD_CHANS = (
ChannelType.GUILD_TEXT,
ChannelType.GUILD_VOICE,
ChannelType.GUILD_CATEGORY)
VOICE_CHANNELS = (
ChannelType.DM, ChannelType.GUILD_VOICE,
ChannelType.GUILD_CATEGORY
ChannelType.GUILD_CATEGORY,
)
VOICE_CHANNELS = (ChannelType.DM, ChannelType.GUILD_VOICE, ChannelType.GUILD_CATEGORY)
class ActivityType(EasyEnum):
PLAYING = 0
STREAMING = 1
@ -120,7 +120,7 @@ SYS_MESSAGES = (
MessageType.CHANNEL_NAME_CHANGE,
MessageType.CHANNEL_ICON_CHANGE,
MessageType.CHANNEL_PINNED_MESSAGE,
MessageType.GUILD_MEMBER_JOIN
MessageType.GUILD_MEMBER_JOIN,
)
@ -137,6 +137,7 @@ class ActivityFlags(Flags):
Only related to rich presence.
"""
instance = 1
join = 2
spectate = 4
@ -150,6 +151,7 @@ class UserFlags(Flags):
Used by the client to show badges.
"""
staff = 1
partner = 2
hypesquad = 4
@ -166,6 +168,7 @@ class UserFlags(Flags):
class MessageFlags(Flags):
"""Message flags."""
none = 0
crossposted = 1 << 0
@ -175,11 +178,12 @@ class MessageFlags(Flags):
class StatusType(EasyEnum):
"""All statuses there can be in a presence."""
ONLINE = 'online'
DND = 'dnd'
IDLE = 'idle'
INVISIBLE = 'invisible'
OFFLINE = 'offline'
ONLINE = "online"
DND = "dnd"
IDLE = "idle"
INVISIBLE = "invisible"
OFFLINE = "offline"
class ExplicitFilter(EasyEnum):
@ -187,6 +191,7 @@ class ExplicitFilter(EasyEnum):
Also applies to guilds.
"""
EDGE = 0
FRIENDS = 1
SAFE = 2
@ -194,6 +199,7 @@ class ExplicitFilter(EasyEnum):
class VerificationLevel(IntEnum):
"""Verification level for guilds."""
NONE = 0
LOW = 1
MEDIUM = 2
@ -205,6 +211,7 @@ class VerificationLevel(IntEnum):
class RelationshipType(EasyEnum):
"""Relationship types between users."""
FRIEND = 1
BLOCK = 2
INCOMING = 3
@ -213,6 +220,7 @@ class RelationshipType(EasyEnum):
class MessageNotifications(EasyEnum):
"""Message notifications"""
ALL = 0
MENTIONS = 1
NOTHING = 2
@ -220,6 +228,7 @@ class MessageNotifications(EasyEnum):
class PremiumType:
"""Premium (Nitro) type."""
TIER_1 = 1
TIER_2 = 2
NONE = None
@ -227,12 +236,13 @@ class PremiumType:
class Feature(EasyEnum):
"""Guild features."""
invite_splash = 'INVITE_SPLASH'
vip = 'VIP_REGIONS'
vanity = 'VANITY_URL'
emoji = 'MORE_EMOJI'
verified = 'VERIFIED'
invite_splash = "INVITE_SPLASH"
vip = "VIP_REGIONS"
vanity = "VANITY_URL"
emoji = "MORE_EMOJI"
verified = "VERIFIED"
# unknown
commerce = 'COMMERCE'
news = 'NEWS'
commerce = "COMMERCE"
news = "NEWS"

View File

@ -18,60 +18,64 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
ERR_MSG_MAP = {
10001: 'Unknown account',
10002: 'Unknown application',
10003: 'Unknown channel',
10004: 'Unknown guild',
10005: 'Unknown integration',
10006: 'Unknown invite',
10007: 'Unknown member',
10008: 'Unknown message',
10009: 'Unknown overwrite',
10010: 'Unknown provider',
10011: 'Unknown role',
10012: 'Unknown token',
10013: 'Unknown user',
10014: 'Unknown Emoji',
10015: 'Unknown Webhook',
20001: 'Bots cannot use this endpoint',
20002: 'Only bots can use this endpoint',
30001: 'Maximum number of guilds reached (100)',
30002: 'Maximum number of friends reached (1000)',
30003: 'Maximum number of pins reached (50)',
30005: 'Maximum number of guild roles reached (250)',
30010: 'Maximum number of reactions reached (20)',
30013: 'Maximum number of guild channels reached (500)',
40001: 'Unauthorized',
50001: 'Missing access',
50002: 'Invalid account type',
50003: 'Cannot execute action on a DM channel',
50004: 'Widget Disabled',
50005: 'Cannot edit a message authored by another user',
50006: 'Cannot send an empty message',
50007: 'Cannot send messages to this user',
50008: 'Cannot send messages in a voice channel',
50009: 'Channel verification level is too high',
50010: 'OAuth2 application does not have a bot',
50011: 'OAuth2 application limit reached',
50012: 'Invalid OAuth state',
50013: 'Missing permissions',
50014: 'Invalid authentication token',
50015: 'Note is too long',
50016: ('Provided too few or too many messages to delete. Must provide at '
'least 2 and fewer than 100 messages to delete.'),
50019: 'A message can only be pinned to the channel it was sent in',
50020: 'Invite code is either invalid or taken.',
50021: 'Cannot execute action on a system message',
50025: 'Invalid OAuth2 access token',
50034: 'A message provided was too old to bulk delete',
50035: 'Invalid Form Body',
50036: 'An invite was accepted to a guild the application\'s bot is not in',
50041: 'Invalid API version',
90001: 'Reaction blocked',
10001: "Unknown account",
10002: "Unknown application",
10003: "Unknown channel",
10004: "Unknown guild",
10005: "Unknown integration",
10006: "Unknown invite",
10007: "Unknown member",
10008: "Unknown message",
10009: "Unknown overwrite",
10010: "Unknown provider",
10011: "Unknown role",
10012: "Unknown token",
10013: "Unknown user",
10014: "Unknown Emoji",
10015: "Unknown Webhook",
20001: "Bots cannot use this endpoint",
20002: "Only bots can use this endpoint",
30001: "Maximum number of guilds reached (100)",
30002: "Maximum number of friends reached (1000)",
30003: "Maximum number of pins reached (50)",
30005: "Maximum number of guild roles reached (250)",
30010: "Maximum number of reactions reached (20)",
30013: "Maximum number of guild channels reached (500)",
40001: "Unauthorized",
50001: "Missing access",
50002: "Invalid account type",
50003: "Cannot execute action on a DM channel",
50004: "Widget Disabled",
50005: "Cannot edit a message authored by another user",
50006: "Cannot send an empty message",
50007: "Cannot send messages to this user",
50008: "Cannot send messages in a voice channel",
50009: "Channel verification level is too high",
50010: "OAuth2 application does not have a bot",
50011: "OAuth2 application limit reached",
50012: "Invalid OAuth state",
50013: "Missing permissions",
50014: "Invalid authentication token",
50015: "Note is too long",
50016: (
"Provided too few or too many messages to delete. Must provide at "
"least 2 and fewer than 100 messages to delete."
),
50019: "A message can only be pinned to the channel it was sent in",
50020: "Invite code is either invalid or taken.",
50021: "Cannot execute action on a system message",
50025: "Invalid OAuth2 access token",
50034: "A message provided was too old to bulk delete",
50035: "Invalid Form Body",
50036: "An invite was accepted to a guild the application's bot is not in",
50041: "Invalid API version",
90001: "Reaction blocked",
}
class LitecordError(Exception):
"""Base class for litecord errors"""
status_code = 500
def _get_err_msg(self, err_code: int) -> str:
@ -91,7 +95,7 @@ class LitecordError(Exception):
return message
except IndexError:
return self._get_err_msg(getattr(self, 'error_code', None))
return self._get_err_msg(getattr(self, "error_code", None))
@property
def json(self):
@ -143,7 +147,7 @@ class MissingPermissions(Forbidden):
class WebsocketClose(Exception):
@property
def code(self):
from_class = getattr(self, 'close_code', None)
from_class = getattr(self, "close_code", None)
if from_class:
return from_class
@ -152,7 +156,7 @@ class WebsocketClose(Exception):
@property
def reason(self):
from_class = getattr(self, 'close_code', None)
from_class = getattr(self, "close_code", None)
if from_class:
return self.args[0]

View File

@ -25,8 +25,7 @@ from litecord.utils import LitecordJSONEncoder
def encode_json(payload) -> str:
"""Encode a given payload to JSON."""
return json.dumps(payload, separators=(',', ':'),
cls=LitecordJSONEncoder)
return json.dumps(payload, separators=(",", ":"), cls=LitecordJSONEncoder)
def decode_json(data: str):
@ -71,6 +70,7 @@ def _etf_decode_dict(data):
return result
def decode_etf(data: bytes):
"""Decode data in ETF to any."""
res = earl.unpack(data)

View File

@ -24,37 +24,36 @@ from litecord.gateway.websocket import GatewayWebsocket
async def websocket_handler(app, ws, url):
"""Main websocket handler, checks query arguments when connecting to
the gateway and spawns a GatewayWebsocket instance for the connection."""
args = urllib.parse.parse_qs(
urllib.parse.urlparse(url).query
)
args = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
# pull a dict.get but in a really bad way.
try:
gw_version = args['v'][0]
gw_version = args["v"][0]
except (KeyError, IndexError):
gw_version = '6'
gw_version = "6"
try:
gw_encoding = args['encoding'][0]
gw_encoding = args["encoding"][0]
except (KeyError, IndexError):
gw_encoding = 'json'
gw_encoding = "json"
if gw_version not in ('6', '7'):
return await ws.close(1000, 'Invalid gateway version')
if gw_version not in ("6", "7"):
return await ws.close(1000, "Invalid gateway version")
if gw_encoding not in ('json', 'etf'):
return await ws.close(1000, 'Invalid gateway encoding')
if gw_encoding not in ("json", "etf"):
return await ws.close(1000, "Invalid gateway encoding")
try:
gw_compress = args['compress'][0]
gw_compress = args["compress"][0]
except (KeyError, IndexError):
gw_compress = None
if gw_compress and gw_compress not in ('zlib-stream', 'zstd-stream'):
return await ws.close(1000, 'Invalid gateway compress')
if gw_compress and gw_compress not in ("zlib-stream", "zstd-stream"):
return await ws.close(1000, "Invalid gateway compress")
gws = GatewayWebsocket(
ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress)
ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress
)
# this can be run with a single await since this whole coroutine
# is already running in the background.

View File

@ -17,8 +17,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
class OP:
"""Gateway OP codes."""
DISPATCH = 0
HEARTBEAT = 1
IDENTIFY = 2

View File

@ -32,6 +32,7 @@ class PayloadStore:
This will only store a maximum of MAX_STORE_SIZE,
dropping the older payloads when adding new ones.
"""
MAX_STORE_SIZE = 250
def __init__(self):
@ -60,20 +61,20 @@ class GatewayState:
"""
def __init__(self, **kwargs):
self.session_id = kwargs.get('session_id', gen_session_id())
self.session_id = kwargs.get("session_id", gen_session_id())
#: event sequence number
self.seq = kwargs.get('seq', 0)
self.seq = kwargs.get("seq", 0)
#: last seq sent by us, the backend
self.last_seq = 0
#: shard information about the state,
# its id and shard count
self.shard = kwargs.get('shard', [0, 1])
self.shard = kwargs.get("shard", [0, 1])
self.user_id = kwargs.get('user_id')
self.bot = kwargs.get('bot', False)
self.user_id = kwargs.get("user_id")
self.bot = kwargs.get("bot", False)
#: set by the gateway connection
# on OP STATUS_UPDATE
@ -90,5 +91,4 @@ class GatewayState:
self.__dict__[key] = value
def __repr__(self):
return (f'GatewayState<seq={self.seq} '
f'shard={self.shard} uid={self.user_id}>')
return f"GatewayState<seq={self.seq} " f"shard={self.shard} uid={self.user_id}>"

View File

@ -39,6 +39,7 @@ class ManagerClose(Exception):
class StateDictWrapper:
"""Wrap a mapping so that any kind of access to the mapping while the
state manager is closed raises a ManagerClose error"""
def __init__(self, state_manager, mapping):
self.state_manager = state_manager
self._map = mapping
@ -98,7 +99,7 @@ class StateManager:
"""Insert a new state object."""
user_states = self.states[state.user_id]
log.debug('inserting state: {!r}', state)
log.debug("inserting state: {!r}", state)
user_states[state.session_id] = state
self.states_raw[state.session_id] = state
@ -128,7 +129,7 @@ class StateManager:
pass
try:
log.debug('removing state: {!r}', state)
log.debug("removing state: {!r}", state)
self.states[state.user_id].pop(state.session_id)
except KeyError:
pass
@ -152,8 +153,7 @@ class StateManager:
"""Fetch all states tied to a single user."""
return list(self.states[user_id].values())
def guild_states(self, member_ids: List[int],
guild_id: int) -> List[GatewayState]:
def guild_states(self, member_ids: List[int], guild_id: int) -> List[GatewayState]:
"""Fetch all possible states about members in a guild."""
states = []
@ -164,14 +164,14 @@ class StateManager:
# since server start, so we need to add a dummy state
if not member_states:
dummy_state = GatewayState(
session_id='',
session_id="",
user_id=member_id,
presence={
'afk': False,
'status': 'offline',
'game': None,
'since': 0
}
"afk": False,
"status": "offline",
"game": None,
"since": 0,
},
)
states.append(dummy_state)
@ -187,9 +187,7 @@ class StateManager:
"""Send OP Reconnect to a single connection."""
websocket = state.ws
await websocket.send({
'op': OP.RECONNECT
})
await websocket.send({"op": OP.RECONNECT})
# wait 200ms
# so that the client has time to process
@ -198,12 +196,9 @@ class StateManager:
try:
# try to close the connection ourselves
await websocket.ws.close(
code=4000,
reason='litecord shutting down'
)
await websocket.ws.close(code=4000, reason="litecord shutting down")
except ConnectionClosed:
log.info('client {} already closed', state)
log.info("client {} already closed", state)
def gen_close_tasks(self):
"""Generate the tasks that will order the clients
@ -222,11 +217,9 @@ class StateManager:
if not state.ws:
continue
tasks.append(
self.shutdown_single(state)
)
tasks.append(self.shutdown_single(state))
log.info('made {} shutdown tasks', len(tasks))
log.info("made {} shutdown tasks", len(tasks))
return tasks
def close(self):

View File

@ -19,9 +19,11 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
class WebsocketFileHandler:
"""A handler around a websocket that wraps normal I/O calls into
the websocket's respective asyncio calls via asyncio.ensure_future."""
def __init__(self, ws):
self.ws = ws

View File

@ -31,23 +31,20 @@ from logbook import Logger
from litecord.auth import raw_token_check
from litecord.enums import RelationshipType, ChannelType
from litecord.schemas import validate, GW_STATUS_UPDATE
from litecord.utils import (
task_wrapper, yield_chunks, maybe_int
)
from litecord.utils import task_wrapper, yield_chunks, maybe_int
from litecord.permissions import get_permissions
from litecord.gateway.opcodes import OP
from litecord.gateway.state import GatewayState
from litecord.errors import (
WebsocketClose, Unauthorized, Forbidden, BadRequest
)
from litecord.errors import WebsocketClose, Unauthorized, Forbidden, BadRequest
from litecord.gateway.errors import (
DecodeError, UnknownOPCode, InvalidShard, ShardingRequired
)
from litecord.gateway.encoding import (
encode_json, decode_json, encode_etf, decode_etf
DecodeError,
UnknownOPCode,
InvalidShard,
ShardingRequired,
)
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
from litecord.gateway.utils import WebsocketFileHandler
@ -56,15 +53,22 @@ from litecord.storage import int_
log = Logger(__name__)
WebsocketProperties = collections.namedtuple(
'WebsocketProperties', 'v encoding compress zctx zsctx tasks'
"WebsocketProperties", "v encoding compress zctx zsctx tasks"
)
WebsocketObjects = collections.namedtuple(
'WebsocketObjects', (
'db', 'state_manager', 'storage',
'loop', 'dispatcher', 'presence', 'ratelimiter',
'user_storage', 'voice'
)
"WebsocketObjects",
(
"db",
"state_manager",
"storage",
"loop",
"dispatcher",
"presence",
"ratelimiter",
"user_storage",
"voice",
),
)
@ -73,9 +77,15 @@ class GatewayWebsocket:
def __init__(self, ws, app, **kwargs):
self.ext = WebsocketObjects(
app.db, app.state_manager, app.storage, app.loop,
app.dispatcher, app.presence, app.ratelimiter,
app.user_storage, app.voice
app.db,
app.state_manager,
app.storage,
app.loop,
app.dispatcher,
app.presence,
app.ratelimiter,
app.user_storage,
app.voice,
)
self.storage = self.ext.storage
@ -84,15 +94,15 @@ class GatewayWebsocket:
self.ws = ws
self.wsp = WebsocketProperties(
kwargs.get('v'),
kwargs.get('encoding', 'json'),
kwargs.get('compress', None),
kwargs.get("v"),
kwargs.get("encoding", "json"),
kwargs.get("compress", None),
zlib.compressobj(),
zstd.ZstdCompressor(),
{}
{},
)
log.debug('websocket properties: {!r}', self.wsp)
log.debug("websocket properties: {!r}", self.wsp)
self.state = None
@ -102,8 +112,8 @@ class GatewayWebsocket:
encoding = self.wsp.encoding
encodings = {
'json': (encode_json, decode_json),
'etf': (encode_etf, decode_etf),
"json": (encode_json, decode_json),
"etf": (encode_etf, decode_etf),
}
self.encoder, self.decoder = encodings[encoding]
@ -111,16 +121,17 @@ class GatewayWebsocket:
async def _chunked_send(self, data: bytes, chunk_size: int):
"""Split data in chunk_size-big chunks and send them
over the websocket."""
log.debug('zlib-stream: chunking {} bytes into {}-byte chunks',
len(data), chunk_size)
log.debug(
"zlib-stream: chunking {} bytes into {}-byte chunks", len(data), chunk_size
)
total_chunks = 0
for chunk in yield_chunks(data, chunk_size):
total_chunks += 1
log.debug('zlib-stream: chunk {}', total_chunks)
log.debug("zlib-stream: chunk {}", total_chunks)
await self.ws.send(chunk)
log.debug('zlib-stream: sent {} chunks', total_chunks)
log.debug("zlib-stream: sent {} chunks", total_chunks)
async def _zlib_stream_send(self, encoded):
"""Sending a single payload across multiple compressed
@ -130,8 +141,12 @@ class GatewayWebsocket:
data1 = self.wsp.zctx.compress(encoded)
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
log.debug('zlib-stream: length {} -> compressed ({} + {})',
len(encoded), len(data1), len(data2))
log.debug(
"zlib-stream: length {} -> compressed ({} + {})",
len(encoded),
len(data1),
len(data2),
)
if not data1:
# if data1 is nothing, that might cause problems
@ -139,8 +154,11 @@ class GatewayWebsocket:
data1 = bytes([data2[0]])
data2 = data2[1:]
log.debug('zlib-stream: len(data1) == 0, remaking as ({} + {})',
len(data1), len(data2))
log.debug(
"zlib-stream: len(data1) == 0, remaking as ({} + {})",
len(data1),
len(data2),
)
# NOTE: the old approach was ws.send(data1 + data2).
# I changed this to a chunked send of data1 and data2
@ -157,8 +175,7 @@ class GatewayWebsocket:
await self._chunked_send(data2, 1024)
async def _zstd_stream_send(self, encoded):
compressor = self.wsp.zsctx.stream_writer(
WebsocketFileHandler(self.ws))
compressor = self.wsp.zsctx.stream_writer(WebsocketFileHandler(self.ws))
compressor.write(encoded)
compressor.flush(zstd.FLUSH_FRAME)
@ -172,21 +189,23 @@ class GatewayWebsocket:
encoded = self.encoder(payload)
if len(encoded) < 2048:
log.debug('sending\n{}', pprint.pformat(payload))
log.debug("sending\n{}", pprint.pformat(payload))
else:
log.debug('sending {}', pprint.pformat(payload))
log.debug('sending op={} s={} t={} (too big)',
payload.get('op'),
payload.get('s'),
payload.get('t'))
log.debug("sending {}", pprint.pformat(payload))
log.debug(
"sending op={} s={} t={} (too big)",
payload.get("op"),
payload.get("s"),
payload.get("t"),
)
# treat encoded as bytes
if not isinstance(encoded, bytes):
encoded = encoded.encode()
if self.wsp.compress == 'zlib-stream':
if self.wsp.compress == "zlib-stream":
await self._zlib_stream_send(encoded)
elif self.wsp.compress == 'zstd-stream':
elif self.wsp.compress == "zstd-stream":
await self._zstd_stream_send(encoded)
elif self.state and self.state.compress and len(encoded) > 1024:
# TODO: should we only compress on >1KB packets? or maybe we
@ -203,16 +222,10 @@ class GatewayWebsocket:
async def send_op(self, op_code: int, data: Any):
"""Send a packet but just the OP code information is filled in."""
await self.send({
'op': op_code,
'd': data,
't': None,
's': None
})
await self.send({"op": op_code, "d": data, "t": None, "s": None})
def _check_ratelimit(self, key: str, ratelimit_key):
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}")
bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit()
@ -221,19 +234,19 @@ class GatewayWebsocket:
# if the client heartbeats in time,
# this task will be cancelled.
await asyncio.sleep(interval / 1000)
await self.ws.close(4000, 'Heartbeat expired')
await self.ws.close(4000, "Heartbeat expired")
self._cleanup()
def _hb_start(self, interval: int):
# always refresh the heartbeat task
# when possible
task = self.wsp.tasks.get('heartbeat')
task = self.wsp.tasks.get("heartbeat")
if task:
task.cancel()
self.wsp.tasks['heartbeat'] = self.ext.loop.create_task(
task_wrapper('hb wait', self._hb_wait(interval))
self.wsp.tasks["heartbeat"] = self.ext.loop.create_task(
task_wrapper("hb wait", self._hb_wait(interval))
)
async def _send_hello(self):
@ -241,12 +254,9 @@ class GatewayWebsocket:
# random heartbeat intervals
interval = randint(40, 46) * 1000
await self.send_op(OP.HELLO, {
'heartbeat_interval': interval,
'_trace': [
'lesbian-server'
],
})
await self.send_op(
OP.HELLO, {"heartbeat_interval": interval, "_trace": ["lesbian-server"]}
)
self._hb_start(interval)
@ -255,16 +265,15 @@ class GatewayWebsocket:
self.state.seq += 1
payload = {
'op': OP.DISPATCH,
't': event.upper(),
's': self.state.seq,
'd': data,
"op": OP.DISPATCH,
"t": event.upper(),
"s": self.state.seq,
"d": data,
}
self.state.store[self.state.seq] = payload
log.debug('sending payload {!r} sid {}',
event.upper(), self.state.session_id)
log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id)
await self.send(payload)
@ -274,16 +283,14 @@ class GatewayWebsocket:
guild_ids = await self._guild_ids()
if self.state.bot:
return [{
'id': row,
'unavailable': True,
} for row in guild_ids]
return [{"id": row, "unavailable": True} for row in guild_ids]
return [
{
**await self.storage.get_guild(guild_id, user_id),
**await self.storage.get_guild_extra(guild_id, user_id,
self.state.large)
**await self.storage.get_guild_extra(
guild_id, user_id, self.state.large
),
}
for guild_id in guild_ids
]
@ -298,13 +305,13 @@ class GatewayWebsocket:
for guild_obj in unavailable_guilds:
# fetch full guild object including the 'large' field
guild = await self.storage.get_guild_full(
int(guild_obj['id']), self.state.user_id, self.state.large
int(guild_obj["id"]), self.state.user_id, self.state.large
)
if guild is None:
continue
await self.dispatch('GUILD_CREATE', guild)
await self.dispatch("GUILD_CREATE", guild)
async def _user_ready(self) -> dict:
"""Fetch information about users in the READY packet.
@ -317,28 +324,28 @@ class GatewayWebsocket:
relationships = await self.user_storage.get_relationships(user_id)
friend_ids = [int(r['user']['id']) for r in relationships
if r['type'] == RelationshipType.FRIEND.value]
friend_ids = [
int(r["user"]["id"])
for r in relationships
if r["type"] == RelationshipType.FRIEND.value
]
friend_presences = await self.ext.presence.friend_presences(friend_ids)
settings = await self.user_storage.get_user_settings(user_id)
return {
'user_settings': settings,
'notes': await self.user_storage.fetch_notes(user_id),
'relationships': relationships,
'presences': friend_presences,
'read_state': await self.user_storage.get_read_state(user_id),
'user_guild_settings': await self.user_storage.get_guild_settings(
user_id),
'friend_suggestion_count': 0,
"user_settings": settings,
"notes": await self.user_storage.fetch_notes(user_id),
"relationships": relationships,
"presences": friend_presences,
"read_state": await self.user_storage.get_read_state(user_id),
"user_guild_settings": await self.user_storage.get_guild_settings(user_id),
"friend_suggestion_count": 0,
# those are unused default values.
'connected_accounts': [],
'experiments': [],
'guild_experiments': [],
'analytics_token': 'transbian',
"connected_accounts": [],
"experiments": [],
"guild_experiments": [],
"analytics_token": "transbian",
}
async def dispatch_ready(self):
@ -353,24 +360,21 @@ class GatewayWebsocket:
# user, fetch info
user_ready = await self._user_ready()
private_channels = (
await self.user_storage.get_dms(user_id) +
await self.user_storage.get_gdms(user_id)
)
private_channels = await self.user_storage.get_dms(
user_id
) + await self.user_storage.get_gdms(user_id)
base_ready = {
'v': 6,
'user': user,
'private_channels': private_channels,
'guilds': guilds,
'session_id': self.state.session_id,
'_trace': ['transbian'],
'shard': self.state.shard,
"v": 6,
"user": user,
"private_channels": private_channels,
"guilds": guilds,
"session_id": self.state.session_id,
"_trace": ["transbian"],
"shard": self.state.shard,
}
await self.dispatch('READY', {**base_ready, **user_ready})
await self.dispatch("READY", {**base_ready, **user_ready})
# async dispatch of guilds
self.ext.loop.create_task(self._guild_dispatch(guilds))
@ -380,33 +384,32 @@ class GatewayWebsocket:
"""
current_shard, shard_count = shard
guilds = await self.ext.db.fetchval("""
guilds = await self.ext.db.fetchval(
"""
SELECT COUNT(*)
FROM members
WHERE user_id = $1
""", user_id)
""",
user_id,
)
recommended = max(int(guilds / 1200), 1)
if shard_count < recommended:
raise ShardingRequired('Too many guilds for shard '
f'{current_shard}')
raise ShardingRequired("Too many guilds for shard " f"{current_shard}")
if guilds > 2500 and guilds / shard_count > 0.8:
raise ShardingRequired('Too many shards. '
f'(g={guilds} sc={shard_count})')
raise ShardingRequired("Too many shards. " f"(g={guilds} sc={shard_count})")
if current_shard > shard_count:
raise InvalidShard('Shard count > Total shards')
raise InvalidShard("Shard count > Total shards")
async def _guild_ids(self) -> list:
"""Get a list of Guild IDs that are tied to this connection.
The implementation is shard-aware.
"""
guild_ids = await self.user_storage.get_user_guilds(
self.state.user_id
)
guild_ids = await self.user_storage.get_user_guilds(self.state.user_id)
shard_id = self.state.current_shard
shard_count = self.state.shard_count
@ -414,10 +417,7 @@ class GatewayWebsocket:
def _get_shard(guild_id):
return (guild_id >> 22) % shard_count
filtered = filter(
lambda guild_id: _get_shard(guild_id) == shard_id,
guild_ids
)
filtered = filter(lambda guild_id: _get_shard(guild_id) == shard_id, guild_ids)
return list(filtered)
@ -432,13 +432,17 @@ class GatewayWebsocket:
# subscribe the user to all dms they have OPENED.
dms = await self.user_storage.get_dms(user_id)
dm_ids = [int(dm['id']) for dm in dms]
dm_ids = [int(dm["id"]) for dm in dms]
# fetch all group dms the user is a member of.
gdm_ids = await self.user_storage.get_gdms_internal(user_id)
log.info('subscribing to {} guilds {} dms {} gdms',
len(guild_ids), len(dm_ids), len(gdm_ids))
log.info(
"subscribing to {} guilds {} dms {} gdms",
len(guild_ids),
len(dm_ids),
len(gdm_ids),
)
# guild_subscriptions:
# enables dispatching of guild subscription events
@ -447,10 +451,13 @@ class GatewayWebsocket:
# we enable processing of guild_subscriptions by adding flags
# when subscribing to the given backend. those are optional.
channels_to_sub = [
('guild', guild_ids,
{'presence': guild_subscriptions, 'typing': guild_subscriptions}),
('channel', dm_ids),
('channel', gdm_ids),
(
"guild",
guild_ids,
{"presence": guild_subscriptions, "typing": guild_subscriptions},
),
("channel", dm_ids),
("channel", gdm_ids),
]
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
@ -460,28 +467,26 @@ class GatewayWebsocket:
# (their friends will also subscribe back
# when they come online)
friend_ids = await self.user_storage.get_friend_ids(user_id)
log.info('subscribing to {} friends', len(friend_ids))
await self.ext.dispatcher.sub_many('friend', user_id, friend_ids)
log.info("subscribing to {} friends", len(friend_ids))
await self.ext.dispatcher.sub_many("friend", user_id, friend_ids)
async def update_status(self, status: dict):
"""Update the status of the current websocket connection."""
if not self.state:
return
if self._check_ratelimit('presence', self.state.session_id):
if self._check_ratelimit("presence", self.state.session_id):
# Presence Updates beyond the ratelimit
# are just silently dropped.
return
default_status = {
'afk': False,
"afk": False,
# TODO: fetch status from settings
'status': 'online',
'game': None,
"status": "online",
"game": None,
# TODO: this
'since': 0,
"since": 0,
}
status = {**(status or {}), **default_status}
@ -489,39 +494,40 @@ class GatewayWebsocket:
try:
status = validate(status, GW_STATUS_UPDATE)
except BadRequest as err:
log.warning(f'Invalid status update: {err}')
log.warning(f"Invalid status update: {err}")
return
# try to extract game from activities
# when game not provided
if not status.get('game'):
if not status.get("game"):
try:
game = status['activities'][0]
game = status["activities"][0]
except (KeyError, IndexError):
game = None
else:
game = status['game']
game = status["game"]
# construct final status
status = {
'afk': status.get('afk', False),
'status': status.get('status', 'online'),
'game': game,
'since': status.get('since', 0),
"afk": status.get("afk", False),
"status": status.get("status", "online"),
"game": game,
"since": status.get("since", 0),
}
self.state.presence = status
log.info(f'Updating presence status={status["status"]} for '
f'uid={self.state.user_id}')
await self.ext.presence.dispatch_pres(self.state.user_id,
self.state.presence)
log.info(
f'Updating presence status={status["status"]} for '
f"uid={self.state.user_id}"
)
await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence)
async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets."""
# give the client 3 more seconds before we
# close the websocket
self._hb_start((46 + 3) * 1000)
cliseq = payload.get('d')
cliseq = payload.get("d")
if self.state:
self.state.last_seq = cliseq
@ -529,39 +535,42 @@ class GatewayWebsocket:
await self.send_op(OP.HEARTBEAT_ACK, None)
async def _connect_ratelimit(self, user_id: int):
if self._check_ratelimit('connect', user_id):
if self._check_ratelimit("connect", user_id):
await self.invalidate_session(False)
raise WebsocketClose(4009, 'You are being ratelimited.')
raise WebsocketClose(4009, "You are being ratelimited.")
if self._check_ratelimit('session', user_id):
if self._check_ratelimit("session", user_id):
await self.invalidate_session(False)
raise WebsocketClose(4004, 'Websocket Session Ratelimit reached.')
raise WebsocketClose(4004, "Websocket Session Ratelimit reached.")
async def handle_2(self, payload: Dict[str, Any]):
"""Handle the OP 2 Identify packet."""
try:
data = payload['d']
token = data['token']
data = payload["d"]
token = data["token"]
except KeyError:
raise DecodeError('Invalid identify parameters')
raise DecodeError("Invalid identify parameters")
compress = data.get('compress', False)
large = data.get('large_threshold', 50)
compress = data.get("compress", False)
large = data.get("large_threshold", 50)
shard = data.get('shard', [0, 1])
presence = data.get('presence')
shard = data.get("shard", [0, 1])
presence = data.get("presence")
try:
user_id = await raw_token_check(token, self.ext.db)
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, 'Authentication failed')
raise WebsocketClose(4004, "Authentication failed")
await self._connect_ratelimit(user_id)
bot = await self.ext.db.fetchval("""
bot = await self.ext.db.fetchval(
"""
SELECT bot FROM users
WHERE id = $1
""", user_id)
""",
user_id,
)
await self._check_shards(shard, user_id)
@ -574,19 +583,19 @@ class GatewayWebsocket:
shard=shard,
current_shard=shard[0],
shard_count=shard[1],
ws=self
ws=self,
)
# link the state to the user
self.ext.state_manager.insert(self.state)
await self.update_status(presence)
await self.subscribe_all(data.get('guild_subscriptions', True))
await self.subscribe_all(data.get("guild_subscriptions", True))
await self.dispatch_ready()
async def handle_3(self, payload: Dict[str, Any]):
"""Handle OP 3 Status Update."""
presence = payload['d']
presence = payload["d"]
# update_status will take care of validation and
# setting new presence to state
@ -597,27 +606,27 @@ class GatewayWebsocket:
user settings."""
try:
# TODO: fetch from settings if not provided
self_deaf = bool(data['self_deaf'])
self_mute = bool(data['self_mute'])
self_deaf = bool(data["self_deaf"])
self_mute = bool(data["self_mute"])
except (KeyError, ValueError):
pass
return {
'deaf': state.deaf,
'mute': state.mute,
'self_deaf': self_deaf,
'self_mute': self_mute,
"deaf": state.deaf,
"mute": state.mute,
"self_deaf": self_deaf,
"self_mute": self_mute,
}
async def handle_4(self, payload: Dict[str, Any]):
"""Handle OP 4 Voice Status Update."""
data = payload['d']
data = payload["d"]
if not self.state:
return
channel_id = int_(data.get('channel_id'))
guild_id = int_(data.get('guild_id'))
channel_id = int_(data.get("channel_id"))
guild_id = int_(data.get("guild_id"))
# if its null and null, disconnect the user from any voice
# TODO: maybe just leave from DMs? idk...
@ -630,9 +639,7 @@ class GatewayWebsocket:
return await self.ext.voice.leave(guild_id, self.state.user_id)
# fetch an existing state given user and guild OR user and channel
chan_type = ChannelType(
await self.storage.get_chan_type(channel_id)
)
chan_type = ChannelType(await self.storage.get_chan_type(channel_id))
state_id2 = channel_id
@ -704,39 +711,38 @@ class GatewayWebsocket:
# ignore unknown seqs
continue
payload_t = payload.get('t')
payload_t = payload.get("t")
# presence resumption happens
# on a separate event, PRESENCE_REPLACE.
if payload_t == 'PRESENCE_UPDATE':
presences.append(payload.get('d'))
if payload_t == "PRESENCE_UPDATE":
presences.append(payload.get("d"))
continue
await self.send(payload)
except Exception:
log.exception('error while resuming')
log.exception("error while resuming")
await self.invalidate_session(False)
return
if presences:
await self.dispatch('PRESENCE_REPLACE', presences)
await self.dispatch("PRESENCE_REPLACE", presences)
await self.dispatch('RESUMED', {})
await self.dispatch("RESUMED", {})
async def handle_6(self, payload: Dict[str, Any]):
"""Handle OP 6 Resume."""
data = payload['d']
data = payload["d"]
try:
token, sess_id, seq = data['token'], \
data['session_id'], data['seq']
token, sess_id, seq = data["token"], data["session_id"], data["seq"]
except KeyError:
raise DecodeError('Invalid resume payload')
raise DecodeError("Invalid resume payload")
try:
user_id = await raw_token_check(token, self.ext.db)
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, 'Invalid token')
raise WebsocketClose(4004, "Invalid token")
try:
state = self.ext.state_manager.fetch(user_id, sess_id)
@ -744,11 +750,11 @@ class GatewayWebsocket:
return await self.invalidate_session(False)
if seq > state.seq:
raise WebsocketClose(4007, 'Invalid seq')
raise WebsocketClose(4007, "Invalid seq")
# check if a websocket isnt on that state already
if state.ws is not None:
log.info('Resuming failed, websocket already connected')
log.info("Resuming failed, websocket already connected")
return await self.invalidate_session(False)
# relink this connection
@ -757,8 +763,9 @@ class GatewayWebsocket:
await self._resume(range(seq, state.seq))
async def _req_guild_members(self, guild_id, user_ids: List[int],
query: str, limit: int):
async def _req_guild_members(
self, guild_id, user_ids: List[int], query: str, limit: int
):
try:
guild_id = int(guild_id)
except (TypeError, ValueError):
@ -778,32 +785,32 @@ class GatewayWebsocket:
# ASSUMPTION: requesting user_ids means we don't do query.
if user_ids:
members = await self.storage.get_member_multi(guild_id, user_ids)
mids = [m['user']['id'] for m in members]
mids = [m["user"]["id"] for m in members]
not_found = [uid for uid in user_ids if uid not in mids]
await self.dispatch('GUILD_MEMBERS_CHUNK', {
'guild_id': str(guild_id),
'members': members,
'not_found': not_found,
})
await self.dispatch(
"GUILD_MEMBERS_CHUNK",
{"guild_id": str(guild_id), "members": members, "not_found": not_found},
)
return
# do the search
result = await self.storage.query_members(guild_id, query, limit)
await self.dispatch('GUILD_MEMBERS_CHUNK', {
'guild_id': str(guild_id),
'members': result
})
await self.dispatch(
"GUILD_MEMBERS_CHUNK", {"guild_id": str(guild_id), "members": result}
)
async def handle_8(self, payload: Dict):
"""Handle OP 8 Request Guild Members."""
data = payload['d']
gids = data['guild_id']
data = payload["d"]
gids = data["guild_id"]
uids, query, limit = data.get('user_ids', []), \
data.get('query', ''), \
data.get('limit', 0)
uids, query, limit = (
data.get("user_ids", []),
data.get("query", ""),
data.get("limit", 0),
)
if isinstance(gids, str):
await self._req_guild_members(gids, uids, query, limit)
@ -820,23 +827,21 @@ class GatewayWebsocket:
GUILD_SYNC event with that info.
"""
members = await self.storage.get_member_data(guild_id)
member_ids = [int(m['user']['id']) for m in members]
member_ids = [int(m["user"]["id"]) for m in members]
log.debug(f'Syncing guild {guild_id} with {len(member_ids)} members')
log.debug(f"Syncing guild {guild_id} with {len(member_ids)} members")
presences = await self.presence.guild_presences(member_ids, guild_id)
await self.dispatch('GUILD_SYNC', {
'id': str(guild_id),
'presences': presences,
'members': members,
})
await self.dispatch(
"GUILD_SYNC",
{"id": str(guild_id), "presences": presences, "members": members},
)
async def handle_12(self, payload: Dict[str, Any]):
"""Handle OP 12 Guild Sync."""
data = payload['d']
data = payload["d"]
gids = await self.user_storage.get_user_guilds(
self.state.user_id)
gids = await self.user_storage.get_user_guilds(self.state.user_id)
for guild_id in data:
try:
@ -931,35 +936,33 @@ class GatewayWebsocket:
]
}
"""
data = payload['d']
data = payload["d"]
gids = await self.user_storage.get_user_guilds(self.state.user_id)
guild_id = int(data['guild_id'])
guild_id = int(data["guild_id"])
# make sure to not extract info you shouldn't get
if guild_id not in gids:
return
log.debug('lazy request: members: {}',
data.get('members', []))
log.debug("lazy request: members: {}", data.get("members", []))
# make shard query
lazy_guilds = self.ext.dispatcher.backends['lazy_guild']
lazy_guilds = self.ext.dispatcher.backends["lazy_guild"]
for chan_id, ranges in data.get('channels', {}).items():
for chan_id, ranges in data.get("channels", {}).items():
chan_id = int(chan_id)
member_list = await lazy_guilds.get_gml(chan_id)
perms = await get_permissions(
self.state.user_id, chan_id, storage=self.storage)
self.state.user_id, chan_id, storage=self.storage
)
if not perms.bits.read_messages:
# ignore requests to unknown channels
return
await member_list.shard_query(
self.state.session_id, ranges
)
await member_list.shard_query(self.state.session_id, ranges)
async def _handle_23(self, payload):
# TODO reverse-engineer opcode 23, sent by client
@ -968,21 +971,21 @@ class GatewayWebsocket:
async def _process_message(self, payload):
"""Process a single message coming in from the client."""
try:
op_code = payload['op']
op_code = payload["op"]
except KeyError:
raise UnknownOPCode('No OP code')
raise UnknownOPCode("No OP code")
try:
handler = getattr(self, f'handle_{op_code}')
handler = getattr(self, f"handle_{op_code}")
except AttributeError:
log.warning('Payload with bad op: {}', pprint.pformat(payload))
raise UnknownOPCode(f'Bad OP code: {op_code}')
log.warning("Payload with bad op: {}", pprint.pformat(payload))
raise UnknownOPCode(f"Bad OP code: {op_code}")
await handler(payload)
async def _msg_ratelimit(self):
if self._check_ratelimit('messages', self.state.session_id):
raise WebsocketClose(4008, 'You are being ratelimited.')
if self._check_ratelimit("messages", self.state.session_id):
raise WebsocketClose(4008, "You are being ratelimited.")
async def _listen_messages(self):
"""Listen for messages coming in from the websocket."""
@ -990,15 +993,15 @@ class GatewayWebsocket:
# close anyone trying to login while the
# server is shutting down
if self.ext.state_manager.closed:
raise WebsocketClose(4000, 'state manager closed')
raise WebsocketClose(4000, "state manager closed")
if not self.ext.state_manager.accept_new:
raise WebsocketClose(4000, 'state manager closed for new')
raise WebsocketClose(4000, "state manager closed for new")
while True:
message = await self.ws.recv()
if len(message) > 4096:
raise DecodeError('Payload length exceeded')
raise DecodeError("Payload length exceeded")
if self.state:
await self._msg_ratelimit()
@ -1033,17 +1036,9 @@ class GatewayWebsocket:
# there arent any other states with websocket
if not with_ws:
offline = {
'afk': False,
'status': 'offline',
'game': None,
'since': 0,
}
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
await self.ext.presence.dispatch_pres(
user_id,
offline
)
await self.ext.presence.dispatch_pres(user_id, offline)
async def run(self):
"""Wrap :meth:`listen_messages` inside
@ -1052,12 +1047,12 @@ class GatewayWebsocket:
await self._send_hello()
await self._listen_messages()
except websockets.exceptions.ConnectionClosed as err:
log.warning('conn close, state={}, err={}', self.state, err)
log.warning("conn close, state={}, err={}", self.state, err)
except WebsocketClose as err:
log.warning('ws close, state={} err={}', self.state, err)
log.warning("ws close, state={} err={}", self.state, err)
await self.ws.close(code=err.code, reason=err.reason)
except Exception as err:
log.exception('An exception has occoured. state={}', self.state)
log.exception("An exception has occoured. state={}", self.state)
await self.ws.close(code=4000, reason=repr(err))
finally:
user_id = self.state.user_id if self.state else None

View File

@ -17,19 +17,21 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
class GuildMemoryStore:
"""Store in-memory properties about guilds.
I could have just used Redis... probably too overkill to add
aioredis to the already long depedency list, plus, I don't need
"""
def __init__(self):
self._store = {}
def get(self, guild_id: int, attribute: str, default=None):
"""get a key"""
return self._store.get(f'{guild_id}:{attribute}', default)
return self._store.get(f"{guild_id}:{attribute}", default)
def set(self, guild_id: int, attribute: str, value):
"""set a key"""
self._store[f'{guild_id}:{attribute}'] = value
self._store[f"{guild_id}:{attribute}"] = value

View File

@ -33,47 +33,42 @@ from logbook import Logger
from PIL import Image
IMAGE_FOLDER = Path('./images')
IMAGE_FOLDER = Path("./images")
log = Logger(__name__)
EXTENSIONS = {
'image/jpeg': 'jpeg',
'image/webp': 'webp'
}
EXTENSIONS = {"image/jpeg": "jpeg", "image/webp": "webp"}
MIMES = {
'jpg': 'image/jpeg',
'jpe': 'image/jpeg',
'jpeg': 'image/jpeg',
'webp': 'image/webp',
"jpg": "image/jpeg",
"jpe": "image/jpeg",
"jpeg": "image/jpeg",
"webp": "image/webp",
}
STATIC_IMAGE_MIMES = [
'image/png',
'image/jpeg',
'image/webp'
]
STATIC_IMAGE_MIMES = ["image/png", "image/jpeg", "image/webp"]
def get_ext(mime: str) -> str:
if mime in EXTENSIONS:
return EXTENSIONS[mime]
extensions = mimetypes.guess_all_extensions(mime)
return extensions[0].strip('.')
return extensions[0].strip(".")
def get_mime(ext: str):
if ext in MIMES:
return MIMES[ext]
return mimetypes.types_map[f'.{ext}']
return mimetypes.types_map[f".{ext}"]
@dataclass
class Icon:
"""Main icon class"""
key: Optional[str]
icon_hash: Optional[str]
mime: Optional[str]
@ -85,7 +80,7 @@ class Icon:
return None
ext = get_ext(self.mime)
return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}')
return str(IMAGE_FOLDER / f"{self.key}_{self.icon_hash}.{ext}")
@property
def as_pathlib(self) -> Optional[Path]:
@ -106,13 +101,14 @@ class Icon:
class ImageError(Exception):
"""Image error class."""
pass
def to_raw(data_type: str, data: str) -> Optional[bytes]:
"""Given a data type in the data URI and data,
give the raw bytes being encoded."""
if data_type == 'base64':
if data_type == "base64":
return base64.b64decode(data)
return None
@ -136,7 +132,7 @@ def _calculate_hash(fhandler) -> str:
"""
hash_obj = sha256()
for chunk in iter(lambda: fhandler.read(4096), b''):
for chunk in iter(lambda: fhandler.read(4096), b""):
hash_obj.update(chunk)
# so that we can reuse the same handler
@ -162,39 +158,36 @@ async def calculate_hash(fhandle, loop=None) -> str:
def parse_data_uri(string) -> tuple:
"""Extract image data."""
try:
header, headered_data = string.split(';')
header, headered_data = string.split(";")
_, given_mime = header.split(':')
data_type, data = headered_data.split(',')
_, given_mime = header.split(":")
data_type, data = headered_data.split(",")
raw_data = to_raw(data_type, data)
if raw_data is None:
raise ImageError('Unknown data header')
raise ImageError("Unknown data header")
return given_mime, raw_data
except ValueError:
raise ImageError('data URI invalid syntax')
raise ImageError("data URI invalid syntax")
def _gen_update_sql(scope: str) -> str:
# match a scope to (table, field)
field = {
'user': 'avatar',
'guild': 'icon',
'splash': 'splash',
'banner': 'banner',
'channel-icons': 'icon',
"user": "avatar",
"guild": "icon",
"splash": "splash",
"banner": "banner",
"channel-icons": "icon",
}[scope]
table = {
'user': 'users',
'guild': 'guilds',
'splash': 'guilds',
'banner': 'guilds',
'channel-icons': 'group_dm_channels'
"user": "users",
"guild": "guilds",
"splash": "guilds",
"banner": "guilds",
"channel-icons": "group_dm_channels",
}[scope]
return f"""
@ -204,10 +197,10 @@ def _gen_update_sql(scope: str) -> str:
def _invalid(kwargs: dict) -> Optional[Icon]:
"""Send an invalid value."""
if not kwargs.get('always_icon', False):
if not kwargs.get("always_icon", False):
return None
return Icon(None, None, '')
return Icon(None, None, "")
def try_unlink(path: Union[Path, str]):
@ -225,18 +218,17 @@ def try_unlink(path: Union[Path, str]):
async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
"""Resize a GIF image."""
# generate a temporary file to call gifsticle to and from.
input_fd, input_path = tempfile.mkstemp(suffix='.gif')
_, output_path = tempfile.mkstemp(suffix='.gif')
input_fd, input_path = tempfile.mkstemp(suffix=".gif")
_, output_path = tempfile.mkstemp(suffix=".gif")
input_handler = os.fdopen(input_fd, 'wb')
input_handler = os.fdopen(input_fd, "wb")
# make sure its valid image data
data_fd = BytesIO(raw_data)
image = Image.open(data_fd)
image.close()
log.info('resizing a GIF from {} to {}',
image.size, target)
log.info("resizing a GIF from {} to {}", image.size, target)
# insert image info on input_handler
# close it to make it ready for consumption by gifsicle
@ -244,12 +236,11 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
input_handler.close()
# call gifsicle under subprocess
log.debug('input: {}', input_path)
log.debug('output: {}', output_path)
log.debug("input: {}", input_path)
log.debug("output: {}", output_path)
process = await asyncio.create_subprocess_shell(
f'gifsicle --resize {target[0]}x{target[1]} '
f'{input_path} > {output_path}',
f"gifsicle --resize {target[0]}x{target[1]} " f"{input_path} > {output_path}",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
@ -257,11 +248,11 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
# run it, etc.
out, err = await process.communicate()
log.debug('out + err from gifsicle: {}', out + err)
log.debug("out + err from gifsicle: {}", out + err)
# write over an empty data_fd
data_fd = BytesIO()
output_handler = open(output_path, 'rb')
output_handler = open(output_path, "rb")
data_fd.write(output_handler.read())
# close unused handlers
@ -283,40 +274,40 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
class IconManager:
"""Main icon manager."""
def __init__(self, app):
self.app = app
self.storage = app.storage
async def _convert_ext(self, icon: Icon, target: str):
target = 'jpeg' if target == 'jpg' else target
target = "jpeg" if target == "jpg" else target
target_mime = get_mime(target)
log.info('converting from {} to {}', icon.mime, target_mime)
log.info("converting from {} to {}", icon.mime, target_mime)
target_path = IMAGE_FOLDER / f'{icon.key}_{icon.icon_hash}.{target}'
target_path = IMAGE_FOLDER / f"{icon.key}_{icon.icon_hash}.{target}"
if target_path.exists():
return Icon(icon.key, icon.icon_hash, target_mime)
image = Image.open(icon.as_path)
target_fd = target_path.open('wb')
target_fd = target_path.open("wb")
if target == 'jpeg':
image = image.convert('RGB')
if target == "jpeg":
image = image.convert("RGB")
image.save(target_fd, format=target)
target_fd.close()
return Icon(icon.key, icon.icon_hash, target_mime)
async def generic_get(self, scope, key, icon_hash,
**kwargs) -> Optional[Icon]:
async def generic_get(self, scope, key, icon_hash, **kwargs) -> Optional[Icon]:
"""Get any icon."""
log.debug('GET {} {} {}', scope, key, icon_hash)
log.debug("GET {} {} {}", scope, key, icon_hash)
key = str(key)
hash_query = 'AND hash = $3' if icon_hash else ''
hash_query = "AND hash = $3" if icon_hash else ""
# hacky solution to only add icon_hash
# when needed.
@ -325,18 +316,21 @@ class IconManager:
if icon_hash:
args.append(icon_hash)
icon_row = await self.storage.db.fetchrow(f"""
icon_row = await self.storage.db.fetchrow(
f"""
SELECT key, hash, mime
FROM icons
WHERE scope = $1
AND key = $2
{hash_query}
""", *args)
""",
*args,
)
if not icon_row:
return None
icon = Icon(icon_row['key'], icon_row['hash'], icon_row['mime'])
icon = Icon(icon_row["key"], icon_row["hash"], icon_row["mime"])
# ensure we aren't messing with NULLs everywhere.
if icon.as_pathlib is None:
@ -349,18 +343,16 @@ class IconManager:
if icon.extension is None:
return None
if 'ext' in kwargs and kwargs['ext'] != icon.extension:
return await self._convert_ext(icon, kwargs['ext'])
if "ext" in kwargs and kwargs["ext"] != icon.extension:
return await self._convert_ext(icon, kwargs["ext"])
return icon
async def get_guild_icon(self, guild_id: int, icon_hash: str, **kwargs):
"""Get an icon for a guild."""
return await self.generic_get(
'guild', guild_id, icon_hash, **kwargs)
return await self.generic_get("guild", guild_id, icon_hash, **kwargs)
async def put(self, scope: str, key: str,
b64_data: str, **kwargs) -> Icon:
async def put(self, scope: str, key: str, b64_data: str, **kwargs) -> Icon:
"""Insert an icon."""
if b64_data is None:
return _invalid(kwargs)
@ -373,23 +365,22 @@ class IconManager:
# get an extension for the given data uri
extension = get_ext(mime)
if 'bsize' in kwargs and len(raw_data) > kwargs['bsize']:
if "bsize" in kwargs and len(raw_data) > kwargs["bsize"]:
return _invalid(kwargs)
# size management is different for gif files
# as they're composed of multiple frames.
if 'size' in kwargs and mime == 'image/gif':
data_fd, raw_data = await resize_gif(raw_data, kwargs['size'])
elif 'size' in kwargs:
if "size" in kwargs and mime == "image/gif":
data_fd, raw_data = await resize_gif(raw_data, kwargs["size"])
elif "size" in kwargs:
image = Image.open(data_fd)
if mime == 'image/jpeg':
if mime == "image/jpeg":
image = image.convert("RGB")
want = kwargs['size']
want = kwargs["size"]
log.info('resizing from {} to {}',
image.size, want)
log.info("resizing from {} to {}", image.size, want)
resized = image.resize(want, resample=Image.LANCZOS)
@ -404,23 +395,26 @@ class IconManager:
# calculate sha256
# ignore icon hashes if we're talking about emoji
icon_hash = (await calculate_hash(data_fd)
if scope != 'emoji'
else None)
icon_hash = await calculate_hash(data_fd) if scope != "emoji" else None
if scope == 'user' and mime == 'image/gif':
icon_hash = f'a_{icon_hash}'
if scope == "user" and mime == "image/gif":
icon_hash = f"a_{icon_hash}"
log.debug('PUT icon {!r} {!r} {!r} {!r}',
scope, key, icon_hash, mime)
log.debug("PUT icon {!r} {!r} {!r} {!r}", scope, key, icon_hash, mime)
await self.storage.db.execute("""
await self.storage.db.execute(
"""
INSERT INTO icons (scope, key, hash, mime)
VALUES ($1, $2, $3, $4)
""", scope, str(key), icon_hash, mime)
""",
scope,
str(key),
icon_hash,
mime,
)
# write it off to fs
icon_path = IMAGE_FOLDER / f'{key}_{icon_hash}.{extension}'
icon_path = IMAGE_FOLDER / f"{key}_{icon_hash}.{extension}"
icon_path.write_bytes(raw_data)
# copy from data_fd to icon_fd
@ -434,57 +428,80 @@ class IconManager:
if not icon:
return
log.debug('DEL {}',
icon)
log.debug("DEL {}", icon)
# dereference
await self.storage.db.execute("""
await self.storage.db.execute(
"""
UPDATE users
SET avatar = NULL
WHERE avatar = $1
""", icon.icon_hash)
""",
icon.icon_hash,
)
await self.storage.db.execute("""
await self.storage.db.execute(
"""
UPDATE group_dm_channels
SET icon = NULL
WHERE icon = $1
""", icon.icon_hash)
""",
icon.icon_hash,
)
await self.storage.db.execute("""
await self.storage.db.execute(
"""
DELETE FROM guild_emoji
WHERE image = $1
""", icon.icon_hash)
""",
icon.icon_hash,
)
await self.storage.db.execute("""
await self.storage.db.execute(
"""
UPDATE guilds
SET icon = NULL
WHERE icon = $1
""", icon.icon_hash)
""",
icon.icon_hash,
)
await self.storage.db.execute("""
await self.storage.db.execute(
"""
UPDATE guilds
SET splash = NULL
WHERE splash = $1
""", icon.icon_hash)
""",
icon.icon_hash,
)
await self.storage.db.execute("""
await self.storage.db.execute(
"""
UPDATE guilds
SET banner = NULL
WHERE banner = $1
""", icon.icon_hash)
""",
icon.icon_hash,
)
await self.storage.db.execute("""
await self.storage.db.execute(
"""
UPDATE group_dm_channels
SET icon = NULL
WHERE icon = $1
""", icon.icon_hash)
""",
icon.icon_hash,
)
await self.storage.db.execute("""
await self.storage.db.execute(
"""
DELETE FROM icons
WHERE hash = $1
""", icon.icon_hash)
""",
icon.icon_hash,
)
paths = IMAGE_FOLDER.glob(f'{icon.key}_{icon.icon_hash}.*')
paths = IMAGE_FOLDER.glob(f"{icon.key}_{icon.icon_hash}.*")
for path in paths:
try:
@ -492,11 +509,9 @@ class IconManager:
except FileNotFoundError:
pass
async def update(self, scope: str, key: str,
new_icon_data: str, **kwargs) -> Icon:
async def update(self, scope: str, key: str, new_icon_data: str, **kwargs) -> Icon:
"""Update an icon on a key."""
old_icon_hash = await self.storage.db.fetchval(
_gen_update_sql(scope), key)
old_icon_hash = await self.storage.db.fetchval(_gen_update_sql(scope), key)
# converting key to str only here since from here onwards
# its operations on the icons table (or a dereference with

View File

@ -20,6 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
from logbook import Logger
log = Logger(__name__)
@ -30,6 +31,7 @@ class JobManager:
use helpers such as asyncio.gather and asyncio.Task.all_tasks. It only uses
its own internal list of jobs.
"""
def __init__(self, loop=None):
self.loop = loop or asyncio.get_event_loop()
self.jobs = []
@ -41,13 +43,11 @@ class JobManager:
try:
await coro
except Exception:
log.exception('Error while running job')
log.exception("Error while running job")
def spawn(self, coro):
"""Spawn a given future or coroutine in the background."""
task = self.loop.create_task(
self._wrapper(coro)
)
task = self.loop.create_task(self._wrapper(coro))
self.jobs.append(task)

View File

@ -26,40 +26,42 @@ from quart import current_app as app
# type for all the fields
_i = ctypes.c_uint8
class _RawPermsBits(ctypes.LittleEndianStructure):
"""raw bitfield for discord's permission number."""
_fields_ = [
('create_invites', _i, 1),
('kick_members', _i, 1),
('ban_members', _i, 1),
('administrator', _i, 1),
('manage_channels', _i, 1),
('manage_guild', _i, 1),
('add_reactions', _i, 1),
('view_audit_log', _i, 1),
('priority_speaker', _i, 1),
('stream', _i, 1),
('read_messages', _i, 1),
('send_messages', _i, 1),
('send_tts', _i, 1),
('manage_messages', _i, 1),
('embed_links', _i, 1),
('attach_files', _i, 1),
('read_history', _i, 1),
('mention_everyone', _i, 1),
('external_emojis', _i, 1),
('_unused2', _i, 1),
('connect', _i, 1),
('speak', _i, 1),
('mute_members', _i, 1),
('deafen_members', _i, 1),
('move_members', _i, 1),
('use_voice_activation', _i, 1),
('change_nickname', _i, 1),
('manage_nicknames', _i, 1),
('manage_roles', _i, 1),
('manage_webhooks', _i, 1),
('manage_emojis', _i, 1),
("create_invites", _i, 1),
("kick_members", _i, 1),
("ban_members", _i, 1),
("administrator", _i, 1),
("manage_channels", _i, 1),
("manage_guild", _i, 1),
("add_reactions", _i, 1),
("view_audit_log", _i, 1),
("priority_speaker", _i, 1),
("stream", _i, 1),
("read_messages", _i, 1),
("send_messages", _i, 1),
("send_tts", _i, 1),
("manage_messages", _i, 1),
("embed_links", _i, 1),
("attach_files", _i, 1),
("read_history", _i, 1),
("mention_everyone", _i, 1),
("external_emojis", _i, 1),
("_unused2", _i, 1),
("connect", _i, 1),
("speak", _i, 1),
("mute_members", _i, 1),
("deafen_members", _i, 1),
("move_members", _i, 1),
("use_voice_activation", _i, 1),
("change_nickname", _i, 1),
("manage_nicknames", _i, 1),
("manage_roles", _i, 1),
("manage_webhooks", _i, 1),
("manage_emojis", _i, 1),
]
@ -72,16 +74,14 @@ class Permissions(ctypes.Union):
val
The permissions value as an integer.
"""
_fields_ = [
('bits', _RawPermsBits),
('binary', ctypes.c_uint64),
]
_fields_ = [("bits", _RawPermsBits), ("binary", ctypes.c_uint64)]
def __init__(self, val: int):
self.binary = val
def __repr__(self):
return f'<Permissions binary={self.binary}>'
return f"<Permissions binary={self.binary}>"
def __int__(self):
return self.binary
@ -95,11 +95,15 @@ async def get_role_perms(guild_id, role_id, storage=None) -> Permissions:
if not storage:
storage = app.storage
perms = await storage.db.fetchval("""
perms = await storage.db.fetchval(
"""
SELECT permissions
FROM roles
WHERE guild_id = $1 AND id = $2
""", guild_id, role_id)
""",
guild_id,
role_id,
)
return Permissions(perms)
@ -118,11 +122,14 @@ async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
if not storage:
storage = app.storage
owner_id = await storage.db.fetchval("""
owner_id = await storage.db.fetchval(
"""
SELECT owner_id
FROM guilds
WHERE id = $1
""", guild_id)
""",
guild_id,
)
if owner_id == member_id:
return ALL_PERMISSIONS
@ -130,20 +137,27 @@ async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
# get permissions for @everyone
permissions = await get_role_perms(guild_id, guild_id, storage)
role_ids = await storage.db.fetch("""
role_ids = await storage.db.fetch(
"""
SELECT role_id
FROM member_roles
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
""",
guild_id,
member_id,
)
role_perms = []
for row in role_ids:
rperm = await storage.db.fetchval("""
rperm = await storage.db.fetchval(
"""
SELECT permissions
FROM roles
WHERE id = $1
""", row['role_id'])
""",
row["role_id"],
)
role_perms.append(rperm)
@ -164,16 +178,17 @@ def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions:
result = perms.binary
# negate the permissions that are denied
result &= ~overwrite['deny']
result &= ~overwrite["deny"]
# combine the permissions that are allowed
result |= overwrite['allow']
result |= overwrite["allow"]
return Permissions(result)
def overwrite_find_mix(perms: Permissions, overwrites: dict,
target_id: int) -> Permissions:
def overwrite_find_mix(
perms: Permissions, overwrites: dict, target_id: int
) -> Permissions:
"""Mix a given permission with a given overwrite.
Returns the given permission if an overwrite is not found.
@ -201,19 +216,25 @@ def overwrite_find_mix(perms: Permissions, overwrites: dict,
return perms
async def role_permissions(guild_id: int, role_id: int,
channel_id: int, storage=None) -> Permissions:
async def role_permissions(
guild_id: int, role_id: int, channel_id: int, storage=None
) -> Permissions:
"""Get the permissions for a role, in relation to a channel"""
if not storage:
storage = app.storage
perms = await get_role_perms(guild_id, role_id, storage)
overwrite = await storage.db.fetchrow("""
overwrite = await storage.db.fetchrow(
"""
SELECT allow, deny
FROM channel_overwrites
WHERE channel_id = $1 AND target_type = $2 AND target_role = $3
""", channel_id, 1, role_id)
""",
channel_id,
1,
role_id,
)
if overwrite:
perms = overwrite_mix(perms, overwrite)
@ -221,10 +242,13 @@ async def role_permissions(guild_id: int, role_id: int,
return perms
async def compute_overwrites(base_perms: Permissions,
user_id, channel_id: int,
async def compute_overwrites(
base_perms: Permissions,
user_id,
channel_id: int,
guild_id: Optional[int] = None,
storage=None):
storage=None,
):
"""Compute the permissions in the context of a channel."""
if not storage:
storage = app.storage
@ -245,7 +269,7 @@ async def compute_overwrites(base_perms: Permissions,
return ALL_PERMISSIONS
# make it a map for better usage
overwrites = {int(o['id']): o for o in overwrites}
overwrites = {int(o["id"]): o for o in overwrites}
perms = overwrite_find_mix(perms, overwrites, guild_id)
@ -260,14 +284,11 @@ async def compute_overwrites(base_perms: Permissions,
for role_id in role_ids:
overwrite = overwrites.get(role_id)
if overwrite:
allow |= overwrite['allow']
deny |= overwrite['deny']
allow |= overwrite["allow"]
deny |= overwrite["deny"]
# final step for roles: mix
perms = overwrite_mix(perms, {
'allow': allow,
'deny': deny
})
perms = overwrite_mix(perms, {"allow": allow, "deny": deny})
# apply member specific overwrites
perms = overwrite_find_mix(perms, overwrites, user_id)
@ -275,8 +296,7 @@ async def compute_overwrites(base_perms: Permissions,
return perms
async def get_permissions(member_id: int, channel_id,
*, storage=None) -> Permissions:
async def get_permissions(member_id: int, channel_id, *, storage=None) -> Permissions:
"""Get the permissions for a user in a channel."""
if not storage:
storage = app.storage
@ -290,4 +310,5 @@ async def get_permissions(member_id: int, channel_id,
base_perms = await base_permissions(member_id, guild_id, storage)
return await compute_overwrites(
base_perms, member_id, channel_id, guild_id, storage)
base_perms, member_id, channel_id, guild_id, storage
)

View File

@ -32,62 +32,56 @@ def status_cmp(status: str, other_status: str) -> bool:
in the status hierarchy.
"""
hierarchy = {
'online': 3,
'idle': 2,
'dnd': 1,
'offline': 0,
None: -1,
}
hierarchy = {"online": 3, "idle": 2, "dnd": 1, "offline": 0, None: -1}
return hierarchy[status] > hierarchy[other_status]
def _best_presence(shards):
"""Find the 'best' presence given a list of GatewayState."""
best = {'status': None, 'game': None}
best = {"status": None, "game": None}
for state in shards:
presence = state.presence
status = presence['status']
status = presence["status"]
if not presence:
continue
# shards with a better status
# in the hierarchy are treated as best
if status_cmp(status, best['status']):
best['status'] = status
if status_cmp(status, best["status"]):
best["status"] = status
# if we have any game, use it
if presence['game'] is not None:
best['game'] = presence['game']
if presence["game"] is not None:
best["game"] = presence["game"]
# best['status'] is None when no
# status was good enough.
return None if not best['status'] else best
return None if not best["status"] else best
def fill_presence(presence: dict, *, game=None) -> dict:
"""Fill a given presence object with some specific fields."""
presence['client_status'] = {}
presence['mobile'] = False
presence["client_status"] = {}
presence["mobile"] = False
if 'since' not in presence:
presence['since'] = 0
if "since" not in presence:
presence["since"] = 0
# fill game and activities array depending if game
# is there or not
game = game or presence.get('game')
game = game or presence.get("game")
# casting to bool since a game of {} is still invalid
if game:
presence['game'] = game
presence['activities'] = [game]
presence["game"] = game
presence["activities"] = [game]
else:
presence['game'] = None
presence['activities'] = []
presence["game"] = None
presence["activities"] = []
return presence
@ -96,14 +90,13 @@ async def _pres(storage, user_id: int, status_obj: dict) -> dict:
"""Convert a given status into a presence, given the User ID and the
:class:`Storage` instance."""
ext = {
'user': await storage.get_user(user_id),
'activities': [],
"user": await storage.get_user(user_id),
"activities": [],
# NOTE: we are purposefully overwriting the fields, as there
# isn't any push for us to actually implement mobile detection, or
# web detection, etc.
'client_status': {},
'mobile': False,
"client_status": {},
"mobile": False,
}
return fill_presence({**status_obj, **ext})
@ -115,14 +108,16 @@ class PresenceManager:
Has common functions to deal with fetching or updating presences, including
side-effects (events).
"""
def __init__(self, app):
self.storage = app.storage
self.user_storage = app.user_storage
self.state_manager = app.state_manager
self.dispatcher = app.dispatcher
async def guild_presences(self, member_ids: List[int],
guild_id: int) -> List[Dict[Any, str]]:
async def guild_presences(
self, member_ids: List[int], guild_id: int
) -> List[Dict[Any, str]]:
"""Fetch all presences in a guild."""
# this works via fetching all connected GatewayState on a guild
# then fetching its respective member and merging that info with
@ -132,34 +127,36 @@ class PresenceManager:
presences = []
for state in states:
member = await self.storage.get_member_data_one(
guild_id, state.user_id)
member = await self.storage.get_member_data_one(guild_id, state.user_id)
game = state.presence.get('game', None)
game = state.presence.get("game", None)
# only use the data we need.
presences.append(fill_presence({
'user': member['user'],
'roles': member['roles'],
'guild_id': str(guild_id),
presences.append(
fill_presence(
{
"user": member["user"],
"roles": member["roles"],
"guild_id": str(guild_id),
# if a state is connected to the guild
# we assume its online.
'status': state.presence.get('status', 'online'),
}, game=game))
"status": state.presence.get("status", "online"),
},
game=game,
)
)
return presences
async def dispatch_guild_pres(self, guild_id: int,
user_id: int, new_state: dict):
async def dispatch_guild_pres(self, guild_id: int, user_id: int, new_state: dict):
"""Dispatch a Presence update to an entire guild."""
state = dict(new_state)
member = await self.storage.get_member_data_one(guild_id, user_id)
game = state['game']
game = state["game"]
lazy_guild_store = self.dispatcher.backends['lazy_guild']
lazy_guild_store = self.dispatcher.backends["lazy_guild"]
lists = lazy_guild_store.get_gml_guild(guild_id)
# shards that are in lazy guilds with 'everyone'
@ -168,49 +165,44 @@ class PresenceManager:
for member_list in lists:
session_ids = await member_list.pres_update(
int(member['user']['id']),
{
'roles': member['roles'],
'status': state['status'],
'game': game
}
int(member["user"]["id"]),
{"roles": member["roles"], "status": state["status"], "game": game},
)
log.debug('Lazy Dispatch to {}',
len(session_ids))
log.debug("Lazy Dispatch to {}", len(session_ids))
# if we are on the 'everyone' member list, we don't
# dispatch a PRESENCE_UPDATE for those shards.
if member_list.channel_id == member_list.guild_id:
in_lazy.extend(session_ids)
pres_update_payload = fill_presence({
'guild_id': str(guild_id),
'user': member['user'],
'roles': member['roles'],
'status': state['status'],
}, game=game)
pres_update_payload = fill_presence(
{
"guild_id": str(guild_id),
"user": member["user"],
"roles": member["roles"],
"status": state["status"],
},
game=game,
)
# given a session id, return if the session id actually connects to
# a given user, and if the state has not been dispatched via lazy guild.
def _session_check(session_id):
state = self.state_manager.fetch_raw(session_id)
uid = int(member['user']['id'])
uid = int(member["user"]["id"])
if not state:
return False
# we don't want to send a presence update
# to the same user
return (state.user_id != uid and
session_id not in in_lazy)
return state.user_id != uid and session_id not in in_lazy
# everyone not in lazy guild mode
# gets a PRESENCE_UPDATE
await self.dispatcher.dispatch_filter(
'guild', guild_id,
_session_check,
'PRESENCE_UPDATE', pres_update_payload
"guild", guild_id, _session_check, "PRESENCE_UPDATE", pres_update_payload
)
return in_lazy
@ -220,25 +212,25 @@ class PresenceManager:
Also dispatches the presence to all the users' friends
"""
if state['status'] == 'invisible':
state['status'] = 'offline'
if state["status"] == "invisible":
state["status"] = "offline"
# TODO: shard-aware
guild_ids = await self.user_storage.get_user_guilds(user_id)
for guild_id in guild_ids:
await self.dispatch_guild_pres(
guild_id, user_id, state)
await self.dispatch_guild_pres(guild_id, user_id, state)
# dispatch to all friends that are subscribed to them
user = await self.storage.get_user(user_id)
game = state['game']
game = state["game"]
await self.dispatcher.dispatch(
'friend', user_id, 'PRESENCE_UPDATE', fill_presence({
'user': user,
'status': state['status'],
}, game=game))
"friend",
user_id,
"PRESENCE_UPDATE",
fill_presence({"user": user, "status": state["status"]}, game=game),
)
async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]:
"""Fetch presences for a group of users.
@ -254,22 +246,25 @@ class PresenceManager:
if not friend_states:
# append offline
res.append(await _pres(storage, friend_id, {
'afk': False,
'status': 'offline',
'game': None,
'since': 0
}))
res.append(
await _pres(
storage,
friend_id,
{"afk": False, "status": "offline", "game": None, "since": 0},
)
)
continue
# filter the best shards:
# - all with id 0 (are the first shards in the collection) or
# - all shards with count = 1 (single shards)
good_shards = list(filter(
good_shards = list(
filter(
lambda state: state.shard[0] == 0 or state.shard[1] == 1,
friend_states
))
friend_states,
)
)
if good_shards:
best_pres = _best_presence(good_shards)

View File

@ -24,6 +24,11 @@ from .channel import ChannelDispatcher
from .friend import FriendDispatcher
from .lazy_guild import LazyGuildDispatcher
__all__ = ['GuildDispatcher', 'MemberDispatcher',
'UserDispatcher', 'ChannelDispatcher',
'FriendDispatcher', 'LazyGuildDispatcher']
__all__ = [
"GuildDispatcher",
"MemberDispatcher",
"UserDispatcher",
"ChannelDispatcher",
"FriendDispatcher",
"LazyGuildDispatcher",
]

View File

@ -38,23 +38,20 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
# make a copy or the original channel object
data = dict(orig)
idx = index_by_func(
lambda user: user['id'] == str(user_id),
data['recipients']
)
idx = index_by_func(lambda user: user["id"] == str(user_id), data["recipients"])
data['recipients'].pop(idx)
data["recipients"].pop(idx)
return data
class ChannelDispatcher(DispatcherWithFlags):
"""Main channel Pub/Sub logic."""
KEY_TYPE = int
VAL_TYPE = int
async def dispatch(self, channel_id,
event: str, data: Any) -> List[str]:
async def dispatch(self, channel_id, event: str, data: Any) -> List[str]:
"""Dispatch an event to a channel."""
# get everyone who is subscribed
# and store the number of states we dispatched the event to
@ -75,9 +72,11 @@ class ChannelDispatcher(DispatcherWithFlags):
# TODO: make a fetch_states that fetches shards
# - with id 0 (count any) OR
# - single shards (id=0, count=1)
states = (self.sm.fetch_states(user_id, guild_id)
if guild_id else
self.sm.user_states(user_id))
states = (
self.sm.fetch_states(user_id, guild_id)
if guild_id
else self.sm.user_states(user_id)
)
# unsub people who don't have any states tied to the channel.
if not states:
@ -85,28 +84,28 @@ class ChannelDispatcher(DispatcherWithFlags):
continue
# skip typing events for users that don't want it
if event.startswith('TYPING_') and \
not self.flags_get(channel_id, user_id, 'typing', True):
if event.startswith("TYPING_") and not self.flags_get(
channel_id, user_id, "typing", True
):
continue
cur_sess = []
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
and data.get('type') == ChannelType.GROUP_DM.value:
if (
event in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
and data.get("type") == ChannelType.GROUP_DM.value
):
# we edit the channel payload so it doesn't show
# the user as a recipient
new_data = gdm_recipient_view(data, user_id)
cur_sess = await self._dispatch_states(
states, event, new_data)
cur_sess = await self._dispatch_states(states, event, new_data)
else:
cur_sess = await self._dispatch_states(
states, event, data)
cur_sess = await self._dispatch_states(states, event, data)
sessions.extend(cur_sess)
dispatched += len(cur_sess)
log.info('Dispatched chan={} {!r} to {} states',
channel_id, event, dispatched)
log.info("Dispatched chan={} {!r} to {} states", channel_id, event, dispatched)
return sessions

View File

@ -80,8 +80,7 @@ class Dispatcher:
"""
raise NotImplementedError
async def _dispatch_states(self, states: list, event: str,
data) -> List[str]:
async def _dispatch_states(self, states: list, event: str, data) -> List[str]:
"""Dispatch an event to a list of states."""
res = []
@ -90,7 +89,7 @@ class Dispatcher:
await state.ws.dispatch(event, data)
res.append(state.session_id)
except Exception:
log.exception('error while dispatching')
log.exception("error while dispatching")
return res
@ -102,6 +101,7 @@ class DispatcherWithState(Dispatcher):
of boilerplate code on Pub/Sub backends
that have that dictionary.
"""
def __init__(self, main):
super().__init__(main)

View File

@ -31,6 +31,7 @@ class FriendDispatcher(DispatcherWithState):
channels. If that friend updates their presence, it will be
broadcasted through that channel to basically all their friends.
"""
KEY_TYPE = int
VAL_TYPE = int
@ -44,17 +45,13 @@ class FriendDispatcher(DispatcherWithState):
# since relationships broadcast to all shards.
sessions.extend(
await self.main_dispatcher.dispatch_filter(
'user', peer_id, func, event, data)
"user", peer_id, func, event, data
)
)
log.info('dispatched uid={} {!r} to {} states',
user_id, event, len(sessions))
log.info("dispatched uid={} {!r} to {} states", user_id, event, len(sessions))
return sessions
async def dispatch(self, user_id, event, data):
return await self.dispatch_filter(
user_id,
lambda sess_id: True,
event, data,
)
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)

View File

@ -29,11 +29,11 @@ log = Logger(__name__)
class GuildDispatcher(DispatcherWithFlags):
"""Guild backend for Pub/Sub"""
KEY_TYPE = int
VAL_TYPE = int
async def _chan_action(self, action: str,
guild_id: int, user_id: int, flags=None):
async def _chan_action(self, action: str, guild_id: int, user_id: int, flags=None):
"""Send an action to all channels of the guild."""
flags = flags or {}
chan_ids = await self.app.storage.get_channel_ids(guild_id)
@ -43,33 +43,31 @@ class GuildDispatcher(DispatcherWithFlags):
# only do an action for users that can
# actually read the channel to start with.
chan_perms = await get_permissions(
user_id, chan_id,
storage=self.main_dispatcher.app.storage)
user_id, chan_id, storage=self.main_dispatcher.app.storage
)
if not chan_perms.bits.read_messages:
log.debug('skipping cid={}, no read messages',
chan_id)
log.debug("skipping cid={}, no read messages", chan_id)
continue
log.debug('sending raw action {!r} to chan={}',
action, chan_id)
log.debug("sending raw action {!r} to chan={}", action, chan_id)
# for now, only sub() has support for flags.
# it is an idea to have flags support for other actions
args = []
if action == 'sub':
if action == "sub":
chanflags = dict(flags)
# channels don't need presence flags
try:
chanflags.pop('presence')
chanflags.pop("presence")
except KeyError:
pass
args.append(chanflags)
await self.main_dispatcher.action(
'channel', action, chan_id, user_id, *args
"channel", action, chan_id, user_id, *args
)
async def _chan_call(self, meth: str, guild_id: int, *args):
@ -77,26 +75,24 @@ class GuildDispatcher(DispatcherWithFlags):
in the guild."""
chan_ids = await self.app.storage.get_channel_ids(guild_id)
chan_dispatcher = self.main_dispatcher.backends['channel']
chan_dispatcher = self.main_dispatcher.backends["channel"]
method = getattr(chan_dispatcher, meth)
for chan_id in chan_ids:
log.debug('calling {} to chan={}',
meth, chan_id)
log.debug("calling {} to chan={}", meth, chan_id)
await method(chan_id, *args)
async def sub(self, guild_id: int, user_id: int, flags=None):
"""Subscribe a user to the guild."""
await super().sub(guild_id, user_id, flags)
await self._chan_action('sub', guild_id, user_id, flags)
await self._chan_action("sub", guild_id, user_id, flags)
async def unsub(self, guild_id: int, user_id: int):
"""Unsubscribe a user from the guild."""
await super().unsub(guild_id, user_id)
await self._chan_action('unsub', guild_id, user_id)
await self._chan_action("unsub", guild_id, user_id)
async def dispatch_filter(self, guild_id: int, func,
event: str, data: Any):
async def dispatch_filter(self, guild_id: int, func, event: str, data: Any):
"""Selectively dispatch to session ids that have
func(session_id) true."""
user_ids = self.state[guild_id]
@ -121,31 +117,23 @@ class GuildDispatcher(DispatcherWithFlags):
# note that this does not equate to any unsubscription
# of the channel.
if event.startswith('PRESENCE_') and \
not self.flags_get(guild_id, user_id, 'presence', True):
if event.startswith("PRESENCE_") and not self.flags_get(
guild_id, user_id, "presence", True
):
continue
# filter the ones that matter
states = list(filter(
lambda state: func(state.session_id), states
))
states = list(filter(lambda state: func(state.session_id), states))
cur_sess = await self._dispatch_states(
states, event, data)
cur_sess = await self._dispatch_states(states, event, data)
sessions.extend(cur_sess)
dispatched += len(cur_sess)
log.info('Dispatched {} {!r} to {} states',
guild_id, event, dispatched)
log.info("Dispatched {} {!r} to {} states", guild_id, event, dispatched)
return sessions
async def dispatch(self, guild_id: int,
event: str, data: Any):
async def dispatch(self, guild_id: int, event: str, data: Any):
"""Dispatch an event to all subscribers of the guild."""
return await self.dispatch_filter(
guild_id,
lambda sess_id: True,
event, data,
)
return await self.dispatch_filter(guild_id, lambda sess_id: True, event, data)

File diff suppressed because it is too large Load Diff

View File

@ -22,6 +22,7 @@ from .dispatcher import Dispatcher
class MemberDispatcher(Dispatcher):
"""Member backend for Pub/Sub."""
KEY_TYPE = tuple
async def dispatch(self, key, event, data):
@ -39,7 +40,7 @@ class MemberDispatcher(Dispatcher):
# if no states were found, we should
# unsub the user from the GUILD channel
if not states:
await self.main_dispatcher.unsub('guild', guild_id, user_id)
await self.main_dispatcher.unsub("guild", guild_id, user_id)
return
return await self._dispatch_states(states, event, data)

View File

@ -22,22 +22,18 @@ from .dispatcher import Dispatcher
class UserDispatcher(Dispatcher):
"""User backend for Pub/Sub."""
KEY_TYPE = int
async def dispatch_filter(self, user_id: int, func, event, data):
"""Dispatch an event to all shards of a user."""
# filter only states where func() gives true
states = list(filter(
lambda state: func(state.session_id),
self.sm.user_states(user_id)
))
states = list(
filter(lambda state: func(state.session_id), self.sm.user_states(user_id))
)
return await self._dispatch_states(states, event, data)
async def dispatch(self, user_id: int, event, data):
return await self.dispatch_filter(
user_id,
lambda sess_id: True,
event, data,
)
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)

View File

@ -28,6 +28,7 @@ import time
class RatelimitBucket:
"""Main ratelimit bucket class."""
def __init__(self, tokens, second):
self.requests = tokens
self.second = second
@ -88,17 +89,19 @@ class RatelimitBucket:
Used to manage multiple ratelimits to users.
"""
return RatelimitBucket(self.requests,
self.second)
return RatelimitBucket(self.requests, self.second)
def __repr__(self):
return (f'<RatelimitBucket requests={self.requests} '
f'second={self.second} window: {self._window} '
f'tokens={self._tokens}>')
return (
f"<RatelimitBucket requests={self.requests} "
f"second={self.second} window: {self._window} "
f"tokens={self._tokens}>"
)
class Ratelimit:
"""Manages buckets."""
def __init__(self, tokens, second, keys=None):
self._cache = {}
if keys is None:
@ -107,12 +110,11 @@ class Ratelimit:
self._cooldown = RatelimitBucket(tokens, second)
def __repr__(self):
return (f'<Ratelimit cooldown={self._cooldown}>')
return f"<Ratelimit cooldown={self._cooldown}>"
def _verify_cache(self):
current = time.time()
dead_keys = [k for k, v in self._cache.items()
if current > v._last + v.second]
dead_keys = [k for k, v in self._cache.items() if current > v._last + v.second]
for k in dead_keys:
del self._cache[k]

View File

@ -31,10 +31,10 @@ async def _check_bucket(bucket):
if retry_after:
request.retry_after = retry_after
raise Ratelimited('You are being rate limited.', {
'retry_after': int(retry_after * 1000),
'global': request.bucket_global,
})
raise Ratelimited(
"You are being rate limited.",
{"retry_after": int(retry_after * 1000), "global": request.bucket_global},
)
async def _handle_global(ratelimit):
@ -59,13 +59,13 @@ async def _handle_specific(ratelimit):
keys = ratelimit.keys
# base key is the user id
key_components = [f'user_id:{user_id}']
key_components = [f"user_id:{user_id}"]
for key in keys:
val = request.view_args[key]
key_components.append(f'{key}:{val}')
key_components.append(f"{key}:{val}")
bucket_key = ':'.join(key_components)
bucket_key = ":".join(key_components)
bucket = ratelimit.get_bucket(bucket_key)
await _check_bucket(bucket)
@ -78,9 +78,7 @@ async def ratelimit_handler():
rule = request.url_rule
if rule is None:
return await _handle_global(
app.ratelimiter.global_bucket
)
return await _handle_global(app.ratelimiter.global_bucket)
# rule.endpoint is composed of '<blueprint>.<function>'
# and so we can use that to make routes with different
@ -97,6 +95,4 @@ async def ratelimit_handler():
ratelimit = app.ratelimiter.get_ratelimit(rule_path)
await _handle_specific(ratelimit)
except KeyError:
await _handle_global(
app.ratelimiter.global_bucket
)
await _handle_global(app.ratelimiter.global_bucket)

View File

@ -34,33 +34,30 @@ WS:
|All Sent Messages| | 120/60s | per-session
"""
REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id'))
REACTION_BUCKET = Ratelimit(1, 0.25, ("channel_id"))
RATELIMITS = {
'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')),
'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')),
"channel_messages.create_message": Ratelimit(5, 5, ("channel_id")),
"channel_messages.delete_message": Ratelimit(5, 1, ("channel_id")),
# all of those share the same bucket.
'channel_reactions.add_reaction': REACTION_BUCKET,
'channel_reactions.remove_own_reaction': REACTION_BUCKET,
'channel_reactions.remove_user_reaction': REACTION_BUCKET,
'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')),
'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')),
"channel_reactions.add_reaction": REACTION_BUCKET,
"channel_reactions.remove_own_reaction": REACTION_BUCKET,
"channel_reactions.remove_user_reaction": REACTION_BUCKET,
"guild_members.modify_guild_member": Ratelimit(10, 10, ("guild_id")),
"guild_members.update_nickname": Ratelimit(1, 1, ("guild_id")),
# this only applies to username.
# 'users.patch_me': Ratelimit(2, 3600),
'_ws.connect': Ratelimit(1, 5),
'_ws.presence': Ratelimit(5, 60),
'_ws.messages': Ratelimit(120, 60),
"_ws.connect": Ratelimit(1, 5),
"_ws.presence": Ratelimit(5, 60),
"_ws.messages": Ratelimit(120, 60),
# 1000 / 4h for new session issuing
'_ws.session': Ratelimit(1000, 14400)
"_ws.session": Ratelimit(1000, 14400),
}
class RatelimitManager:
"""Manager for the bucket managers"""
def __init__(self, testing_flag=False):
self._ratelimiters = {}
self._test = testing_flag
@ -74,9 +71,7 @@ class RatelimitManager:
# NOTE: this is a bad way to do it, but
# we only need to change that one for now.
rtl = (Ratelimit(10, 1)
if self._test and path == '_ws.connect'
else rtl)
rtl = Ratelimit(10, 1) if self._test and path == "_ws.connect" else rtl
self._ratelimiters[path] = rtl

View File

@ -28,26 +28,30 @@ from .errors import BadRequest
from .permissions import Permissions
from .types import Color
from .enums import (
ActivityType, StatusType, ExplicitFilter, RelationshipType,
MessageNotifications, ChannelType, VerificationLevel
ActivityType,
StatusType,
ExplicitFilter,
RelationshipType,
MessageNotifications,
ChannelType,
VerificationLevel,
)
from litecord.embed.schemas import EMBED_OBJECT, EmbedURL
log = Logger(__name__)
USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_ ]{2,30}$', re.A)
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
re.A)
DATA_REGEX = re.compile(r'data\:image/(png|jpeg|gif);base64,(.+)', re.A)
USERNAME_REGEX = re.compile(r"^[a-zA-Z0-9_ ]{2,30}$", re.A)
EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$", re.A)
DATA_REGEX = re.compile(r"data\:image/(png|jpeg|gif);base64,(.+)", re.A)
# collection of regexes
USER_MENTION = re.compile(r'<@!?(\d+)>', re.A | re.M)
CHAN_MENTION = re.compile(r'<#(\d+)>', re.A | re.M)
ROLE_MENTION = re.compile(r'<@&(\d+)>', re.A | re.M)
EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M)
ANIMOJI_MENTION = re.compile(r'<a:(\.+):(\d+)>', re.A | re.M)
USER_MENTION = re.compile(r"<@!?(\d+)>", re.A | re.M)
CHAN_MENTION = re.compile(r"<#(\d+)>", re.A | re.M)
ROLE_MENTION = re.compile(r"<@&(\d+)>", re.A | re.M)
EMOJO_MENTION = re.compile(r"<:(\.+):(\d+)>", re.A | re.M)
ANIMOJI_MENTION = re.compile(r"<a:(\.+):(\d+)>", re.A | re.M)
def _in_enum(enum, value) -> bool:
@ -61,6 +65,7 @@ def _in_enum(enum, value) -> bool:
class LitecordValidator(Validator):
"""Main validator class for Litecord, containing custom types."""
def _validate_type_username(self, value: str) -> bool:
"""Validate against the username regex."""
return bool(USERNAME_REGEX.match(value))
@ -130,8 +135,7 @@ class LitecordValidator(Validator):
return False
# nobody is allowed to use the INCOMING and OUTGOING rel types
return val in (RelationshipType.FRIEND.value,
RelationshipType.BLOCK.value)
return val in (RelationshipType.FRIEND.value, RelationshipType.BLOCK.value)
def _validate_type_msg_notifications(self, value: str):
try:
@ -152,14 +156,15 @@ class LitecordValidator(Validator):
return self._validate_type_guild_name(value)
def _validate_type_theme(self, value: str) -> bool:
return value in ['light', 'dark']
return value in ["light", "dark"]
def _validate_type_nickname(self, value: str) -> bool:
return isinstance(value, str) and (len(value) < 32)
def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
raise_err: bool = True) -> Dict:
def validate(
reqjson: Optional[Union[Dict, List]], schema: Dict, raise_err: bool = True
) -> Dict:
"""Validate the given user-given data against a schema, giving the
"correct" version of the document, with all defaults applied.
@ -176,20 +181,20 @@ def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
validator = LitecordValidator(schema)
if reqjson is None:
raise BadRequest('No JSON provided')
raise BadRequest("No JSON provided")
try:
valid = validator.validate(reqjson)
except Exception:
log.exception('Error while validating')
raise Exception(f'Error while validating: {reqjson}')
log.exception("Error while validating")
raise Exception(f"Error while validating: {reqjson}")
if not valid:
errs = validator.errors
log.warning('Error validating doc {!r}: {!r}', reqjson, errs)
log.warning("Error validating doc {!r}: {!r}", reqjson, errs)
if raise_err:
raise BadRequest('bad payload', errs)
raise BadRequest("bad payload", errs)
return None
@ -197,554 +202,441 @@ def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
REGISTER = {
'username': {'type': 'username', 'required': True},
'email': {'type': 'email', 'required': False},
'password': {'type': 'password', 'required': False},
"username": {"type": "username", "required": True},
"email": {"type": "email", "required": False},
"password": {"type": "password", "required": False},
# invite stands for a guild invite, not an instance invite (that's on
# the register_with_invite handler).
'invite': {'type': 'string', 'required': False, 'nullable': True},
"invite": {"type": "string", "required": False, "nullable": True},
# following fields only sent by official client, unused by us
'fingerprint': {'type': 'string', 'required': False, 'nullable': True},
'captcha_key': {'type': 'string', 'required': False, 'nullable': True},
'gift_code_sku_id': {'type': 'string', 'required': False, 'nullable': True},
'consent': {'type': 'boolean', 'required': False},
"fingerprint": {"type": "string", "required": False, "nullable": True},
"captcha_key": {"type": "string", "required": False, "nullable": True},
"gift_code_sku_id": {"type": "string", "required": False, "nullable": True},
"consent": {"type": "boolean", "required": False},
}
# only used by us, not discord, hence 'invcode' (to separate from discord)
REGISTER_WITH_INVITE = {**REGISTER, **{
'invcode': {'type': 'string', 'required': True}
}}
REGISTER_WITH_INVITE = {**REGISTER, **{"invcode": {"type": "string", "required": True}}}
USER_UPDATE = {
'username': {
'type': 'username', 'minlength': 2,
'maxlength': 30, 'required': False},
'discriminator': {
'type': 'discriminator',
'required': False,
'nullable': True,
"username": {
"type": "username",
"minlength": 2,
"maxlength": 30,
"required": False,
},
'password': {
'type': 'password', 'required': False,
"discriminator": {"type": "discriminator", "required": False, "nullable": True},
"password": {"type": "password", "required": False},
"new_password": {
"type": "password",
"required": False,
"dependencies": "password",
"nullable": True,
},
'new_password': {
'type': 'password', 'required': False,
'dependencies': 'password', 'nullable': True
},
'email': {
'type': 'email', 'required': False, 'dependencies': 'password',
},
'avatar': {
"email": {"type": "email", "required": False, "dependencies": "password"},
"avatar": {
# can be both b64_icon or string (just the hash)
'type': 'string', 'required': False,
'nullable': True
"type": "string",
"required": False,
"nullable": True,
},
}
PARTIAL_ROLE_GUILD_CREATE = {
'type': 'dict',
'schema': {
'name': {'type': 'role_name'},
'color': {'type': 'number', 'default': 0},
'hoist': {'type': 'boolean', 'default': False},
"type": "dict",
"schema": {
"name": {"type": "role_name"},
"color": {"type": "number", "default": 0},
"hoist": {"type": "boolean", "default": False},
# NOTE: no position on partial role (on guild create)
'permissions': {'coerce': Permissions, 'required': False},
'mentionable': {'type': 'boolean', 'default': False},
}
"permissions": {"coerce": Permissions, "required": False},
"mentionable": {"type": "boolean", "default": False},
},
}
PARTIAL_CHANNEL_GUILD_CREATE = {
'type': 'dict',
'schema': {
'name': {'type': 'channel_name'},
'type': {'type': 'channel_type'},
}
"type": "dict",
"schema": {"name": {"type": "channel_name"}, "type": {"type": "channel_type"}},
}
GUILD_CREATE = {
'name': {'type': 'guild_name'},
'region': {'type': 'voice_region', 'nullable': True},
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
'verification_level': {
'type': 'verification_level', 'default': 0},
'default_message_notifications': {
'type': 'msg_notifications', 'default': 0},
'explicit_content_filter': {
'type': 'explicit', 'default': 0},
'roles': {
'type': 'list', 'required': False,
'schema': PARTIAL_ROLE_GUILD_CREATE},
'channels': {
'type': 'list', 'default': [], 'schema': PARTIAL_CHANNEL_GUILD_CREATE},
"name": {"type": "guild_name"},
"region": {"type": "voice_region", "nullable": True},
"icon": {"type": "b64_icon", "required": False, "nullable": True},
"verification_level": {"type": "verification_level", "default": 0},
"default_message_notifications": {"type": "msg_notifications", "default": 0},
"explicit_content_filter": {"type": "explicit", "default": 0},
"roles": {"type": "list", "required": False, "schema": PARTIAL_ROLE_GUILD_CREATE},
"channels": {"type": "list", "default": [], "schema": PARTIAL_CHANNEL_GUILD_CREATE},
}
GUILD_UPDATE = {
'name': {
'type': 'guild_name',
'required': False
},
'region': {'type': 'voice_region', 'required': False, 'nullable': True},
"name": {"type": "guild_name", "required": False},
"region": {"type": "voice_region", "required": False, "nullable": True},
# all three can have hashes
'icon': {'type': 'string', 'required': False, 'nullable': True},
'banner': {'type': 'string', 'required': False, 'nullable': True},
'splash': {'type': 'string', 'required': False, 'nullable': True},
'description': {
'type': 'string', 'required': False,
'minlength': 1, 'maxlength': 120,
'nullable': True
"icon": {"type": "string", "required": False, "nullable": True},
"banner": {"type": "string", "required": False, "nullable": True},
"splash": {"type": "string", "required": False, "nullable": True},
"description": {
"type": "string",
"required": False,
"minlength": 1,
"maxlength": 120,
"nullable": True,
},
'verification_level': {
'type': 'verification_level', 'required': False},
'default_message_notifications': {
'type': 'msg_notifications', 'required': False},
'explicit_content_filter': {'type': 'explicit', 'required': False},
'afk_channel_id': {
'type': 'snowflake', 'required': False, 'nullable': True},
'afk_timeout': {'type': 'number', 'required': False},
'owner_id': {'type': 'snowflake', 'required': False},
'system_channel_id': {
'type': 'snowflake', 'required': False, 'nullable': True},
"verification_level": {"type": "verification_level", "required": False},
"default_message_notifications": {"type": "msg_notifications", "required": False},
"explicit_content_filter": {"type": "explicit", "required": False},
"afk_channel_id": {"type": "snowflake", "required": False, "nullable": True},
"afk_timeout": {"type": "number", "required": False},
"owner_id": {"type": "snowflake", "required": False},
"system_channel_id": {"type": "snowflake", "required": False, "nullable": True},
}
CHAN_OVERWRITE = {
'id': {'coerce': int},
'type': {'type': 'string', 'allowed': ['role', 'member']},
'allow': {'coerce': Permissions},
'deny': {'coerce': Permissions}
"id": {"coerce": int},
"type": {"type": "string", "allowed": ["role", "member"]},
"allow": {"coerce": Permissions},
"deny": {"coerce": Permissions},
}
CHAN_CREATE = {
'name': {
'type': 'string', 'minlength': 2,
'maxlength': 100, 'required': True
},
'type': {'type': 'channel_type',
'default': ChannelType.GUILD_TEXT.value},
'position': {'coerce': int, 'required': False},
'topic': {
'type': 'string', 'minlength': 0,
'maxlength': 1024, 'required': False},
'nsfw': {'type': 'boolean', 'required': False},
'rate_limit_per_user': {
'coerce': int, 'min': 0,
'max': 120, 'required': False},
'bitrate': {
'coerce': int, 'min': 8000,
"name": {"type": "string", "minlength": 2, "maxlength": 100, "required": True},
"type": {"type": "channel_type", "default": ChannelType.GUILD_TEXT.value},
"position": {"coerce": int, "required": False},
"topic": {"type": "string", "minlength": 0, "maxlength": 1024, "required": False},
"nsfw": {"type": "boolean", "required": False},
"rate_limit_per_user": {"coerce": int, "min": 0, "max": 120, "required": False},
"bitrate": {
"coerce": int,
"min": 8000,
# NOTE: 'max' is 96000 for non-vip guilds
'max': 128000, 'required': False},
'user_limit': {
"max": 128000,
"required": False,
},
"user_limit": {
# user_limit being 0 means infinite.
'coerce': int, 'min': 0,
'max': 99, 'required': False
"coerce": int,
"min": 0,
"max": 99,
"required": False,
},
'permission_overwrites': {
'type': 'list',
'schema': {'type': 'dict', 'schema': CHAN_OVERWRITE},
'required': False
"permission_overwrites": {
"type": "list",
"schema": {"type": "dict", "schema": CHAN_OVERWRITE},
"required": False,
},
'parent_id': {'coerce': int, 'required': False, 'nullable': True}
"parent_id": {"coerce": int, "required": False, "nullable": True},
}
CHAN_UPDATE = {**CHAN_CREATE, **{
'name': {
'type': 'string', 'minlength': 2,
'maxlength': 100, 'required': False},
}}
CHAN_UPDATE = {
**CHAN_CREATE,
**{"name": {"type": "string", "minlength": 2, "maxlength": 100, "required": False}},
}
ROLE_CREATE = {
'name': {'type': 'string', 'default': 'new role'},
'permissions': {'coerce': Permissions, 'nullable': True},
'color': {'coerce': Color, 'default': 0},
'hoist': {'type': 'boolean', 'default': False},
'mentionable': {'type': 'boolean', 'default': False},
"name": {"type": "string", "default": "new role"},
"permissions": {"coerce": Permissions, "nullable": True},
"color": {"coerce": Color, "default": 0},
"hoist": {"type": "boolean", "default": False},
"mentionable": {"type": "boolean", "default": False},
}
ROLE_UPDATE = {
'name': {'type': 'string', 'required': False},
'permissions': {'coerce': Permissions, 'required': False},
'color': {'coerce': Color, 'required': False},
'hoist': {'type': 'boolean', 'required': False},
'mentionable': {'type': 'boolean', 'required': False},
"name": {"type": "string", "required": False},
"permissions": {"coerce": Permissions, "required": False},
"color": {"coerce": Color, "required": False},
"hoist": {"type": "boolean", "required": False},
"mentionable": {"type": "boolean", "required": False},
}
ROLE_UPDATE_POSITION = {
'roles': {
'type': 'list',
'schema': {
'type': 'dict',
'schema': {
'id': {'coerce': int},
'position': {'coerce': int},
"roles": {
"type": "list",
"schema": {
"type": "dict",
"schema": {"id": {"coerce": int}, "position": {"coerce": int}},
},
}
}
}
MEMBER_UPDATE = {
'nick': {
'type': 'nickname', 'required': False},
'roles': {'type': 'list', 'required': False,
'schema': {'coerce': int}},
'mute': {'type': 'boolean', 'required': False},
'deaf': {'type': 'boolean', 'required': False},
'channel_id': {'type': 'snowflake', 'required': False},
"nick": {"type": "nickname", "required": False},
"roles": {"type": "list", "required": False, "schema": {"coerce": int}},
"mute": {"type": "boolean", "required": False},
"deaf": {"type": "boolean", "required": False},
"channel_id": {"type": "snowflake", "required": False},
}
# NOTE: things such as payload_json are parsed at the handler
# for creating a message.
MESSAGE_CREATE = {
'content': {'type': 'string', 'minlength': 0, 'maxlength': 2000},
'nonce': {'type': 'snowflake', 'required': False},
'tts': {'type': 'boolean', 'required': False},
'embed': {
'type': 'dict',
'schema': EMBED_OBJECT,
'required': False,
'nullable': True
}
"content": {"type": "string", "minlength": 0, "maxlength": 2000},
"nonce": {"type": "snowflake", "required": False},
"tts": {"type": "boolean", "required": False},
"embed": {
"type": "dict",
"schema": EMBED_OBJECT,
"required": False,
"nullable": True,
},
}
GW_ACTIVITY = {
'name': {'type': 'string', 'required': True},
'type': {'type': 'activity_type', 'required': True},
'url': {'type': 'string', 'required': False, 'nullable': True},
'timestamps': {
'type': 'dict',
'required': False,
'schema': {
'start': {'type': 'number', 'required': False},
'end': {'type': 'number', 'required': False},
"name": {"type": "string", "required": True},
"type": {"type": "activity_type", "required": True},
"url": {"type": "string", "required": False, "nullable": True},
"timestamps": {
"type": "dict",
"required": False,
"schema": {
"start": {"type": "number", "required": False},
"end": {"type": "number", "required": False},
},
},
'application_id': {'type': 'snowflake', 'required': False,
'nullable': False},
'details': {'type': 'string', 'required': False, 'nullable': True},
'state': {'type': 'string', 'required': False, 'nullable': True},
'party': {
'type': 'dict',
'required': False,
'schema': {
'id': {'type': 'snowflake', 'required': False},
'size': {'type': 'list', 'required': False},
}
"application_id": {"type": "snowflake", "required": False, "nullable": False},
"details": {"type": "string", "required": False, "nullable": True},
"state": {"type": "string", "required": False, "nullable": True},
"party": {
"type": "dict",
"required": False,
"schema": {
"id": {"type": "snowflake", "required": False},
"size": {"type": "list", "required": False},
},
'assets': {
'type': 'dict',
'required': False,
'schema': {
'large_image': {'type': 'snowflake', 'required': False},
'large_text': {'type': 'string', 'required': False},
'small_image': {'type': 'snowflake', 'required': False},
'small_text': {'type': 'string', 'required': False},
}
},
'secrets': {
'type': 'dict',
'required': False,
'schema': {
'join': {'type': 'string', 'required': False},
'spectate': {'type': 'string', 'required': False},
'match': {'type': 'string', 'required': False},
}
"assets": {
"type": "dict",
"required": False,
"schema": {
"large_image": {"type": "snowflake", "required": False},
"large_text": {"type": "string", "required": False},
"small_image": {"type": "snowflake", "required": False},
"small_text": {"type": "string", "required": False},
},
'instance': {'type': 'boolean', 'required': False},
'flags': {'type': 'number', 'required': False},
},
"secrets": {
"type": "dict",
"required": False,
"schema": {
"join": {"type": "string", "required": False},
"spectate": {"type": "string", "required": False},
"match": {"type": "string", "required": False},
},
},
"instance": {"type": "boolean", "required": False},
"flags": {"type": "number", "required": False},
}
GW_STATUS_UPDATE = {
'status': {'type': 'status_external', 'required': False,
'default': 'online'},
'activities': {
'type': 'list', 'required': False,
'schema': {'type': 'dict', 'schema': GW_ACTIVITY}
"status": {"type": "status_external", "required": False, "default": "online"},
"activities": {
"type": "list",
"required": False,
"schema": {"type": "dict", "schema": GW_ACTIVITY},
},
'afk': {'type': 'boolean', 'required': False},
'since': {'type': 'number', 'required': False, 'nullable': True},
'game': {
'type': 'dict',
'required': False,
'nullable': True,
'schema': GW_ACTIVITY,
"afk": {"type": "boolean", "required": False},
"since": {"type": "number", "required": False, "nullable": True},
"game": {
"type": "dict",
"required": False,
"nullable": True,
"schema": GW_ACTIVITY,
},
}
INVITE = {
# max_age in seconds
# 0 for infinite
'max_age': {
'type': 'number',
'min': 0,
'max': 86400,
"max_age": {
"type": "number",
"min": 0,
"max": 86400,
# a day
'default': 86400
"default": 86400,
},
# max invite uses
'max_uses': {
'type': 'number',
'min': 0,
"max_uses": {
"type": "number",
"min": 0,
# idk
'max': 1000,
"max": 1000,
# default infinite
'default': 0
"default": 0,
},
'temporary': {'type': 'boolean', 'required': False, 'default': False},
'unique': {'type': 'boolean', 'required': False, 'default': True},
'validate': {'type': 'string', 'required': False, 'nullable': True} # discord client sends invite code there
"temporary": {"type": "boolean", "required": False, "default": False},
"unique": {"type": "boolean", "required": False, "default": True},
"validate": {
"type": "string",
"required": False,
"nullable": True,
}, # discord client sends invite code there
}
USER_SETTINGS = {
'afk_timeout': {
'type': 'number', 'required': False, 'min': 0, 'max': 3000},
'animate_emoji': {'type': 'boolean', 'required': False},
'convert_emoticons': {'type': 'boolean', 'required': False},
'default_guilds_restricted': {'type': 'boolean', 'required': False},
'detect_platform_accounts': {'type': 'boolean', 'required': False},
'developer_mode': {'type': 'boolean', 'required': False},
'disable_games_tab': {'type': 'boolean', 'required': False},
'enable_tts_command': {'type': 'boolean', 'required': False},
'explicit_content_filter': {'type': 'explicit', 'required': False},
'friend_source': {
'type': 'dict',
'required': False,
'schema': {
'all': {'type': 'boolean', 'required': False},
'mutual_guilds': {'type': 'boolean', 'required': False},
'mutual_friends': {'type': 'boolean', 'required': False},
}
"afk_timeout": {"type": "number", "required": False, "min": 0, "max": 3000},
"animate_emoji": {"type": "boolean", "required": False},
"convert_emoticons": {"type": "boolean", "required": False},
"default_guilds_restricted": {"type": "boolean", "required": False},
"detect_platform_accounts": {"type": "boolean", "required": False},
"developer_mode": {"type": "boolean", "required": False},
"disable_games_tab": {"type": "boolean", "required": False},
"enable_tts_command": {"type": "boolean", "required": False},
"explicit_content_filter": {"type": "explicit", "required": False},
"friend_source": {
"type": "dict",
"required": False,
"schema": {
"all": {"type": "boolean", "required": False},
"mutual_guilds": {"type": "boolean", "required": False},
"mutual_friends": {"type": "boolean", "required": False},
},
'guild_positions': {
'type': 'list',
'required': False,
'schema': {'type': 'snowflake'}
},
'restricted_guilds': {
'type': 'list',
'required': False,
'schema': {'type': 'snowflake'}
"guild_positions": {
"type": "list",
"required": False,
"schema": {"type": "snowflake"},
},
'gif_auto_play': {'type': 'boolean', 'required': False},
'inline_attachment_media': {'type': 'boolean', 'required': False},
'inline_embed_media': {'type': 'boolean', 'required': False},
'message_display_compact': {'type': 'boolean', 'required': False},
'render_embeds': {'type': 'boolean', 'required': False},
'render_reactions': {'type': 'boolean', 'required': False},
'show_current_game': {'type': 'boolean', 'required': False},
'timezone_offset': {'type': 'number', 'required': False},
'status': {'type': 'status_external', 'required': False},
'theme': {'type': 'theme', 'required': False}
"restricted_guilds": {
"type": "list",
"required": False,
"schema": {"type": "snowflake"},
},
"gif_auto_play": {"type": "boolean", "required": False},
"inline_attachment_media": {"type": "boolean", "required": False},
"inline_embed_media": {"type": "boolean", "required": False},
"message_display_compact": {"type": "boolean", "required": False},
"render_embeds": {"type": "boolean", "required": False},
"render_reactions": {"type": "boolean", "required": False},
"show_current_game": {"type": "boolean", "required": False},
"timezone_offset": {"type": "number", "required": False},
"status": {"type": "status_external", "required": False},
"theme": {"type": "theme", "required": False},
}
RELATIONSHIP = {
'type': {
'type': 'rel_type',
'required': False,
'default': RelationshipType.FRIEND.value
"type": {
"type": "rel_type",
"required": False,
"default": RelationshipType.FRIEND.value,
}
}
CREATE_DM = {
'recipient_id': {
'type': 'snowflake',
'required': True
}
}
CREATE_DM = {"recipient_id": {"type": "snowflake", "required": True}}
CREATE_GROUP_DM = {
'recipients': {
'type': 'list',
'required': True,
'schema': {'type': 'snowflake'}
},
"recipients": {"type": "list", "required": True, "schema": {"type": "snowflake"}}
}
GROUP_DM_UPDATE = {
'name': {
'type': 'guild_name',
'required': False
},
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
"name": {"type": "guild_name", "required": False},
"icon": {"type": "b64_icon", "required": False, "nullable": True},
}
SPECIFIC_FRIEND = {
'username': {'type': 'username'},
'discriminator': {'type': 'discriminator'}
"username": {"type": "username"},
"discriminator": {"type": "discriminator"},
}
GUILD_SETTINGS_CHAN_OVERRIDE = {
'type': 'dict',
'schema': {
'muted': {
'type': 'boolean', 'required': False},
'message_notifications': {
'type': 'msg_notifications',
'required': False,
}
}
"type": "dict",
"schema": {
"muted": {"type": "boolean", "required": False},
"message_notifications": {"type": "msg_notifications", "required": False},
},
}
GUILD_SETTINGS = {
'channel_overrides': {
'type': 'dict',
'valueschema': GUILD_SETTINGS_CHAN_OVERRIDE,
'keyschema': {'type': 'snowflake'},
'required': False,
"channel_overrides": {
"type": "dict",
"valueschema": GUILD_SETTINGS_CHAN_OVERRIDE,
"keyschema": {"type": "snowflake"},
"required": False,
},
'suppress_everyone': {
'type': 'boolean', 'required': False},
'muted': {
'type': 'boolean', 'required': False},
'mobile_push': {
'type': 'boolean', 'required': False},
'message_notifications': {
'type': 'msg_notifications',
'required': False,
}
"suppress_everyone": {"type": "boolean", "required": False},
"muted": {"type": "boolean", "required": False},
"mobile_push": {"type": "boolean", "required": False},
"message_notifications": {"type": "msg_notifications", "required": False},
}
GUILD_PRUNE = {
'days': {'type': 'number', 'coerce': int, 'min': 1, 'max': 30, 'default': 7},
'compute_prune_count': {'type': 'string', 'default': 'true'}
"days": {"type": "number", "coerce": int, "min": 1, "max": 30, "default": 7},
"compute_prune_count": {"type": "string", "default": "true"},
}
NEW_EMOJI = {
'name': {
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True},
'image': {'type': 'b64_icon', 'required': True},
'roles': {'type': 'list', 'schema': {'coerce': int}}
"name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True},
"image": {"type": "b64_icon", "required": True},
"roles": {"type": "list", "schema": {"coerce": int}},
}
PATCH_EMOJI = {
'name': {
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True},
'roles': {'type': 'list', 'schema': {'coerce': int}}
"name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True},
"roles": {"type": "list", "schema": {"coerce": int}},
}
SEARCH_CHANNEL = {
'content': {'type': 'string', 'minlength': 1, 'required': True},
'include_nsfw': {'coerce': bool, 'default': False},
'offset': {'coerce': int, 'default': 0}
"content": {"type": "string", "minlength": 1, "required": True},
"include_nsfw": {"coerce": bool, "default": False},
"offset": {"coerce": int, "default": 0},
}
GET_MENTIONS = {
'limit': {'coerce': int, 'default': 25},
'roles': {'coerce': bool, 'default': True},
'everyone': {'coerce': bool, 'default': True},
'guild_id': {'coerce': int, 'required': False}
"limit": {"coerce": int, "default": 25},
"roles": {"coerce": bool, "default": True},
"everyone": {"coerce": bool, "default": True},
"guild_id": {"coerce": int, "required": False},
}
VANITY_URL_PATCH = {
# TODO: put proper values in maybe an invite data type
'code': {'type': 'string', 'minlength': 5, 'maxlength': 30}
"code": {"type": "string", "minlength": 5, "maxlength": 30}
}
WEBHOOK_CREATE = {
'name': {
'type': 'string', 'minlength': 2, 'maxlength': 32,
'required': True
},
'avatar': {'type': 'b64_icon', 'required': False, 'nullable': False}
"name": {"type": "string", "minlength": 2, "maxlength": 32, "required": True},
"avatar": {"type": "b64_icon", "required": False, "nullable": False},
}
WEBHOOK_UPDATE = {
'name': {
'type': 'string', 'minlength': 2, 'maxlength': 32,
'required': False
},
"name": {"type": "string", "minlength": 2, "maxlength": 32, "required": False},
# TODO: check if its b64_icon or string since the client
# could pass an icon hash instead.
'avatar': {'type': 'b64_icon', 'required': False, 'nullable': False},
'channel_id': {'coerce': int, 'required': False, 'nullable': False}
"avatar": {"type": "b64_icon", "required": False, "nullable": False},
"channel_id": {"coerce": int, "required": False, "nullable": False},
}
WEBHOOK_MESSAGE_CREATE = {
'content': {
'type': 'string',
'minlength': 0, 'maxlength': 2000, 'required': False
"content": {"type": "string", "minlength": 0, "maxlength": 2000, "required": False},
"tts": {"type": "boolean", "required": False},
"username": {"type": "string", "minlength": 2, "maxlength": 32, "required": False},
"avatar_url": {"coerce": EmbedURL, "required": False},
"embeds": {
"type": "list",
"required": False,
"schema": {"type": "dict", "schema": EMBED_OBJECT},
},
'tts': {'type': 'boolean', 'required': False},
'username': {
'type': 'string',
'minlength': 2, 'maxlength': 32, 'required': False
},
'avatar_url': {
'coerce': EmbedURL, 'required': False
},
'embeds': {
'type': 'list',
'required': False,
'schema': {'type': 'dict', 'schema': EMBED_OBJECT}
}
}
BULK_DELETE = {
'messages': {
'type': 'list', 'required': True,
'minlength': 2, 'maxlength': 100,
'schema': {'coerce': int}
"messages": {
"type": "list",
"required": True,
"minlength": 2,
"maxlength": 100,
"schema": {"coerce": int},
}
}

View File

@ -61,19 +61,19 @@ def _snowflake(timestamp: int) -> Snowflake:
# bits 0-12 encode _generated_ids (size 12)
# modulo'd to prevent overflows
genid_b = '{0:012b}'.format(_generated_ids % 4096)
genid_b = "{0:012b}".format(_generated_ids % 4096)
# bits 12-17 encode PROCESS_ID (size 5)
procid_b = '{0:05b}'.format(PROCESS_ID)
procid_b = "{0:05b}".format(PROCESS_ID)
# bits 17-22 encode WORKER_ID (size 5)
workid_b = '{0:05b}'.format(WORKER_ID)
workid_b = "{0:05b}".format(WORKER_ID)
# bits 22-64 encode (timestamp - EPOCH) (size 42)
epochized = timestamp - EPOCH
epoch_b = '{0:042b}'.format(epochized)
epoch_b = "{0:042b}".format(epochized)
snowflake_b = f'{epoch_b}{workid_b}{procid_b}{genid_b}'
snowflake_b = f"{epoch_b}{workid_b}{procid_b}{genid_b}"
_generated_ids += 1
return int(snowflake_b, 2)
@ -87,7 +87,7 @@ def snowflake_time(snowflake: Snowflake) -> float:
# the total size for a snowflake is 64 bits,
# considering it is a string, position 0 to 42 will give us
# the `epochized` variable
snowflake_b = '{0:064b}'.format(snowflake)
snowflake_b = "{0:064b}".format(snowflake)
epochized_b = snowflake_b[:42]
epochized = int(epochized_b, 2)

File diff suppressed because it is too large Load Diff

View File

@ -24,6 +24,7 @@ from litecord.enums import MessageType
log = Logger(__name__)
async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
"""Handle a message pin."""
new_id = get_snowflake()
@ -37,8 +38,10 @@ async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
($1, $2, NULL, $3, NULL, '',
$4)
""",
new_id, channel_id, author_id,
MessageType.CHANNEL_PINNED_MESSAGE.value
new_id,
channel_id,
author_id,
MessageType.CHANNEL_PINNED_MESSAGE.value,
)
return new_id
@ -56,15 +59,16 @@ async def _handle_recp_add(app, channel_id, author_id, peer_id):
VALUES
($1, $2, $3, NULL, $4, $5)
""",
new_id, channel_id, author_id,
f'<@{peer_id}>',
MessageType.RECIPIENT_ADD.value
new_id,
channel_id,
author_id,
f"<@{peer_id}>",
MessageType.RECIPIENT_ADD.value,
)
return new_id
async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
new_id = get_snowflake()
@ -76,9 +80,11 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
VALUES
($1, $2, $3, NULL, $4, $5)
""",
new_id, channel_id, author_id,
f'<@{peer_id}>',
MessageType.RECIPIENT_REMOVE.value
new_id,
channel_id,
author_id,
f"<@{peer_id}>",
MessageType.RECIPIENT_REMOVE.value,
)
return new_id
@ -87,13 +93,16 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
async def _handle_gdm_name_edit(app, channel_id, author_id):
new_id = get_snowflake()
gdm_name = await app.db.fetchval("""
gdm_name = await app.db.fetchval(
"""
SELECT name FROM group_dm_channels
WHERE id = $1
""", channel_id)
""",
channel_id,
)
if not gdm_name:
log.warning('no gdm name found for sys message')
log.warning("no gdm name found for sys message")
return
await app.db.execute(
@ -104,9 +113,11 @@ async def _handle_gdm_name_edit(app, channel_id, author_id):
VALUES
($1, $2, $3, NULL, $4, $5)
""",
new_id, channel_id, author_id,
new_id,
channel_id,
author_id,
gdm_name,
MessageType.CHANNEL_NAME_CHANGE.value
MessageType.CHANNEL_NAME_CHANGE.value,
)
return new_id
@ -123,16 +134,19 @@ async def _handle_gdm_icon_edit(app, channel_id, author_id):
VALUES
($1, $2, $3, NULL, $4, $5)
""",
new_id, channel_id, author_id,
'',
MessageType.CHANNEL_ICON_CHANGE.value
new_id,
channel_id,
author_id,
"",
MessageType.CHANNEL_ICON_CHANGE.value,
)
return new_id
async def send_sys_message(app, channel_id: int, m_type: MessageType,
*args, **kwargs) -> int:
async def send_sys_message(
app, channel_id: int, m_type: MessageType, *args, **kwargs
) -> int:
"""Send a system message.
The handler for a given message type MUST return an integer, that integer
@ -156,22 +170,19 @@ async def send_sys_message(app, channel_id: int, m_type: MessageType,
try:
handler = {
MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg,
# gdm specific
MessageType.RECIPIENT_ADD: _handle_recp_add,
MessageType.RECIPIENT_REMOVE: _handle_recp_rmv,
MessageType.CHANNEL_NAME_CHANGE: _handle_gdm_name_edit,
MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit
MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit,
}[m_type]
except KeyError:
raise ValueError('Invalid system message type')
raise ValueError("Invalid system message type")
message_id = await handler(app, channel_id, *args, **kwargs)
message = await app.storage.get_message(message_id)
await app.dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_CREATE', message
)
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", message)
return message_id

View File

@ -29,6 +29,7 @@ HOURS = 60 * MINUTES
class Color:
"""Custom color class"""
def __init__(self, val: int):
self.blue = val & 255
self.green = (val >> 8) & 255
@ -37,7 +38,7 @@ class Color:
@property
def value(self):
"""Give the actual RGB integer encoding this color."""
return int('%02x%02x%02x' % (self.red, self.green, self.blue), 16)
return int("%02x%02x%02x" % (self.red, self.green, self.blue), 16)
@property
def to_json(self):
@ -49,4 +50,4 @@ class Color:
def timestamp_(dt) -> Optional[str]:
"""safer version for dt.isoformat()"""
return f'{dt.isoformat()}+00:00' if dt else None
return f"{dt.isoformat()}+00:00" if dt else None

View File

@ -27,43 +27,52 @@ log = Logger(__name__)
class UserStorage:
"""Storage functions related to a single user."""
def __init__(self, storage):
self.storage = storage
self.db = storage.db
async def fetch_notes(self, user_id: int) -> dict:
"""Fetch a users' notes"""
note_rows = await self.db.fetch("""
note_rows = await self.db.fetch(
"""
SELECT target_id, note
FROM notes
WHERE user_id = $1
""", user_id)
""",
user_id,
)
return {str(row['target_id']): row['note']
for row in note_rows}
return {str(row["target_id"]): row["note"] for row in note_rows}
async def get_user_settings(self, user_id: int) -> Dict[str, Any]:
"""Get current user settings."""
row = await self.storage.fetchrow_with_json("""
row = await self.storage.fetchrow_with_json(
"""
SELECT *
FROM user_settings
WHERE id = $1
""", user_id)
""",
user_id,
)
if not row:
log.info('Generating user settings for {}', user_id)
log.info("Generating user settings for {}", user_id)
await self.db.execute("""
await self.db.execute(
"""
INSERT INTO user_settings (id)
VALUES ($1)
""", user_id)
""",
user_id,
)
# recalling get_user_settings
# should work after adding
return await self.get_user_settings(user_id)
drow = dict(row)
drow.pop('id')
drow.pop("id")
return drow
async def get_relationships(self, user_id: int) -> List[Dict[str, Any]]:
@ -76,11 +85,15 @@ class UserStorage:
_outgoing = RelationshipType.OUTGOING.value
# check all outgoing friends
friends = await self.db.fetch("""
friends = await self.db.fetch(
"""
SELECT user_id, peer_id, rel_type
FROM relationships
WHERE user_id = $1 AND rel_type = $2
""", user_id, _friend)
""",
user_id,
_friend,
)
friends = list(map(dict, friends))
# mutuals is a list of ints
@ -95,66 +108,80 @@ class UserStorage:
SELECT user_id, peer_id
FROM relationships
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
""", row['peer_id'], row['user_id'],
_friend)
""",
row["peer_id"],
row["user_id"],
_friend,
)
if is_friend is not None:
mutuals.append(row['peer_id'])
mutuals.append(row["peer_id"])
# fetch friend requests directed at us
incoming_friends = await self.db.fetch("""
incoming_friends = await self.db.fetch(
"""
SELECT user_id, peer_id
FROM relationships
WHERE peer_id = $1 AND rel_type = $2
""", user_id, _friend)
""",
user_id,
_friend,
)
# only need their ids
incoming_friends = [r['user_id'] for r in incoming_friends
if r['user_id'] not in mutuals]
incoming_friends = [
r["user_id"] for r in incoming_friends if r["user_id"] not in mutuals
]
# only fetch blocks we did,
# not fetching the ones people did to us
blocks = await self.db.fetch("""
blocks = await self.db.fetch(
"""
SELECT user_id, peer_id, rel_type
FROM relationships
WHERE user_id = $1 AND rel_type = $2
""", user_id, _block)
""",
user_id,
_block,
)
blocks = list(map(dict, blocks))
res = []
for drow in friends:
drow['type'] = drow['rel_type']
drow['id'] = str(drow['peer_id'])
drow.pop('rel_type')
drow["type"] = drow["rel_type"]
drow["id"] = str(drow["peer_id"])
drow.pop("rel_type")
# check if the receiver is a mutual
# if it isnt, its still on a friend request stage
if drow['peer_id'] not in mutuals:
drow['type'] = _outgoing
if drow["peer_id"] not in mutuals:
drow["type"] = _outgoing
drow['user'] = await self.storage.get_user(drow['peer_id'])
drow["user"] = await self.storage.get_user(drow["peer_id"])
drow.pop('user_id')
drow.pop('peer_id')
drow.pop("user_id")
drow.pop("peer_id")
res.append(drow)
for peer_id in incoming_friends:
res.append({
'id': str(peer_id),
'user': await self.storage.get_user(peer_id),
'type': _incoming,
})
res.append(
{
"id": str(peer_id),
"user": await self.storage.get_user(peer_id),
"type": _incoming,
}
)
for drow in blocks:
drow['type'] = drow['rel_type']
drow.pop('rel_type')
drow["type"] = drow["rel_type"]
drow.pop("rel_type")
drow['id'] = str(drow['peer_id'])
drow['user'] = await self.storage.get_user(drow['peer_id'])
drow["id"] = str(drow["peer_id"])
drow["user"] = await self.storage.get_user(drow["peer_id"])
drow.pop('user_id')
drow.pop('peer_id')
drow.pop("user_id")
drow.pop("peer_id")
res.append(drow)
return res
@ -163,9 +190,11 @@ class UserStorage:
"""Get all friend IDs for a user."""
rels = await self.get_relationships(user_id)
return [int(r['user']['id'])
return [
int(r["user"]["id"])
for r in rels
if r['type'] == RelationshipType.FRIEND.value]
if r["type"] == RelationshipType.FRIEND.value
]
async def get_dms(self, user_id: int) -> List[Dict[str, Any]]:
"""Get all DM channels for a user, including group DMs.
@ -173,13 +202,16 @@ class UserStorage:
This will only fetch channels the user has in their state,
which is different than the whole list of DM channels.
"""
dm_ids = await self.db.fetch("""
dm_ids = await self.db.fetch(
"""
SELECT dm_id
FROM dm_channel_state
WHERE user_id = $1
""", user_id)
""",
user_id,
)
dm_ids = [r['dm_id'] for r in dm_ids]
dm_ids = [r["dm_id"] for r in dm_ids]
res = []
@ -191,21 +223,24 @@ class UserStorage:
async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]:
"""Get the read state for a user."""
rows = await self.db.fetch("""
rows = await self.db.fetch(
"""
SELECT channel_id, last_message_id, mention_count
FROM user_read_state
WHERE user_id = $1
""", user_id)
""",
user_id,
)
res = []
for row in rows:
drow = dict(row)
drow['id'] = str(drow['channel_id'])
drow.pop('channel_id')
drow["id"] = str(drow["channel_id"])
drow.pop("channel_id")
drow['last_message_id'] = str(drow['last_message_id'])
drow["last_message_id"] = str(drow["last_message_id"])
res.append(drow)
@ -214,13 +249,17 @@ class UserStorage:
async def _get_chan_overrides(self, user_id: int, guild_id: int) -> List:
chan_overrides = []
overrides = await self.db.fetch("""
overrides = await self.db.fetch(
"""
SELECT channel_id::text, muted, message_notifications
FROM guild_settings_channel_overrides
WHERE
user_id = $1
AND guild_id = $2
""", user_id, guild_id)
""",
user_id,
guild_id,
)
for chan_row in overrides:
dcrow = dict(chan_row)
@ -228,30 +267,35 @@ class UserStorage:
return chan_overrides
async def get_guild_settings_one(self, user_id: int,
guild_id: int) -> dict:
async def get_guild_settings_one(self, user_id: int, guild_id: int) -> dict:
"""Get guild settings information for a single guild."""
row = await self.db.fetchrow("""
row = await self.db.fetchrow(
"""
SELECT guild_id::text, suppress_everyone, muted,
message_notifications, mobile_push
FROM guild_settings
WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id)
""",
user_id,
guild_id,
)
if not row:
await self.db.execute("""
await self.db.execute(
"""
INSERT INTO guild_settings (user_id, guild_id)
VALUES ($1, $2)
""", user_id, guild_id)
""",
user_id,
guild_id,
)
return await self.get_guild_settings_one(user_id, guild_id)
gid = int(row['guild_id'])
gid = int(row["guild_id"])
drow = dict(row)
chan_overrides = await self._get_chan_overrides(user_id, gid)
return {**drow, **{
'channel_overrides': chan_overrides
}}
return {**drow, **{"channel_overrides": chan_overrides}}
async def get_guild_settings(self, user_id: int):
"""Get the specific User Guild Settings,
@ -259,34 +303,38 @@ class UserStorage:
res = []
settings = await self.db.fetch("""
settings = await self.db.fetch(
"""
SELECT guild_id::text, suppress_everyone, muted,
message_notifications, mobile_push
FROM guild_settings
WHERE user_id = $1
""", user_id)
""",
user_id,
)
for row in settings:
gid = int(row['guild_id'])
gid = int(row["guild_id"])
drow = dict(row)
chan_overrides = await self._get_chan_overrides(user_id, gid)
res.append({**drow, **{
'channel_overrides': chan_overrides
}})
res.append({**drow, **{"channel_overrides": chan_overrides}})
return res
async def get_user_guilds(self, user_id: int) -> List[int]:
"""Get all guild IDs a user is on."""
guild_ids = await self.db.fetch("""
guild_ids = await self.db.fetch(
"""
SELECT guild_id
FROM members
WHERE user_id = $1
""", user_id)
""",
user_id,
)
return [row['guild_id'] for row in guild_ids]
return [row["guild_id"] for row in guild_ids]
async def get_mutual_guilds(self, user_id: int, peer_id: int) -> List[int]:
"""Get a list of guilds two separate users
@ -301,13 +349,17 @@ class UserStorage:
return await self.get_user_guilds(user_id) or [0]
mutual_guilds = await self.db.fetch("""
mutual_guilds = await self.db.fetch(
"""
SELECT guild_id FROM members WHERE user_id = $1
INTERSECT
SELECT guild_id FROM members WHERE user_id = $2
""", user_id, peer_id)
""",
user_id,
peer_id,
)
mutual_guilds = [r['guild_id'] for r in mutual_guilds]
mutual_guilds = [r["guild_id"] for r in mutual_guilds]
return mutual_guilds
@ -316,7 +368,8 @@ class UserStorage:
This returns false even if there is a friend request.
"""
return await self.db.fetchval("""
return await self.db.fetchval(
"""
SELECT
(
SELECT EXISTS(
@ -337,17 +390,23 @@ class UserStorage:
AND rel_type = 1
)
)
""", user_id, peer_id)
""",
user_id,
peer_id,
)
async def get_gdms_internal(self, user_id) -> List[int]:
"""Return a list of Group DM IDs the user is a member of."""
rows = await self.db.fetch("""
rows = await self.db.fetch(
"""
SELECT id
FROM group_dm_members
WHERE member_id = $1
""", user_id)
""",
user_id,
)
return [r['id'] for r in rows]
return [r["id"] for r in rows]
async def get_gdms(self, user_id) -> List[Dict[str, Any]]:
"""Get list of group DMs a user is in."""
@ -356,8 +415,6 @@ class UserStorage:
res = []
for gdm_id in gdm_ids:
res.append(
await self.storage.get_channel(gdm_id, user_id=user_id)
)
res.append(await self.storage.get_channel(gdm_id, user_id=user_id))
return res

View File

@ -46,7 +46,7 @@ async def task_wrapper(name: str, coro):
except asyncio.CancelledError:
pass
except:
log.exception('{} task error', name)
log.exception("{} task error", name)
def dict_get(mapping, key, default):
@ -84,54 +84,66 @@ def mmh3(inp_str: str, seed: int = 0):
h1 = seed
# mm3 constants
c1 = 0xcc9e2d51
c2 = 0x1b873593
c1 = 0xCC9E2D51
c2 = 0x1B873593
i = 0
while i < bytecount:
k1 = (
(key[i] & 0xff) |
((key[i + 1] & 0xff) << 8) |
((key[i + 2] & 0xff) << 16) |
((key[i + 3] & 0xff) << 24)
(key[i] & 0xFF)
| ((key[i + 1] & 0xFF) << 8)
| ((key[i + 2] & 0xFF) << 16)
| ((key[i + 3] & 0xFF) << 24)
)
i += 4
k1 = ((((k1 & 0xffff) * c1) + ((((_u(k1) >> 16) * c1) & 0xffff) << 16))) & 0xffffffff
k1 = (
(((k1 & 0xFFFF) * c1) + ((((_u(k1) >> 16) * c1) & 0xFFFF) << 16))
) & 0xFFFFFFFF
k1 = (k1 << 15) | (_u(k1) >> 17)
k1 = ((((k1 & 0xffff) * c2) + ((((_u(k1) >> 16) * c2) & 0xffff) << 16))) & 0xffffffff;
k1 = (
(((k1 & 0xFFFF) * c2) + ((((_u(k1) >> 16) * c2) & 0xFFFF) << 16))
) & 0xFFFFFFFF
h1 ^= k1
h1 = (h1 << 13) | (_u(h1) >> 19);
h1b = ((((h1 & 0xffff) * 5) + ((((_u(h1) >> 16) * 5) & 0xffff) << 16))) & 0xffffffff;
h1 = (((h1b & 0xffff) + 0x6b64) + ((((_u(h1b) >> 16) + 0xe654) & 0xffff) << 16))
h1 = (h1 << 13) | (_u(h1) >> 19)
h1b = (
(((h1 & 0xFFFF) * 5) + ((((_u(h1) >> 16) * 5) & 0xFFFF) << 16))
) & 0xFFFFFFFF
h1 = ((h1b & 0xFFFF) + 0x6B64) + ((((_u(h1b) >> 16) + 0xE654) & 0xFFFF) << 16)
k1 = 0
v = None
if remainder == 3:
v = (key[i + 2] & 0xff) << 16
v = (key[i + 2] & 0xFF) << 16
elif remainder == 2:
v = (key[i + 1] & 0xff) << 8
v = (key[i + 1] & 0xFF) << 8
elif remainder == 1:
v = (key[i] & 0xff)
v = key[i] & 0xFF
if v is not None:
k1 ^= v
k1 = (((k1 & 0xffff) * c1) + ((((_u(k1) >> 16) * c1) & 0xffff) << 16)) & 0xffffffff
k1 = (((k1 & 0xFFFF) * c1) + ((((_u(k1) >> 16) * c1) & 0xFFFF) << 16)) & 0xFFFFFFFF
k1 = (k1 << 15) | (_u(k1) >> 17)
k1 = (((k1 & 0xffff) * c2) + ((((_u(k1) >> 16) * c2) & 0xffff) << 16)) & 0xffffffff
k1 = (((k1 & 0xFFFF) * c2) + ((((_u(k1) >> 16) * c2) & 0xFFFF) << 16)) & 0xFFFFFFFF
h1 ^= k1
h1 ^= len(key)
h1 ^= _u(h1) >> 16
h1 = (((h1 & 0xffff) * 0x85ebca6b) + ((((_u(h1) >> 16) * 0x85ebca6b) & 0xffff) << 16)) & 0xffffffff
h1 = (
((h1 & 0xFFFF) * 0x85EBCA6B) + ((((_u(h1) >> 16) * 0x85EBCA6B) & 0xFFFF) << 16)
) & 0xFFFFFFFF
h1 ^= _u(h1) >> 13
h1 = ((((h1 & 0xffff) * 0xc2b2ae35) + ((((_u(h1) >> 16) * 0xc2b2ae35) & 0xffff) << 16))) & 0xffffffff
h1 = (
(
((h1 & 0xFFFF) * 0xC2B2AE35)
+ ((((_u(h1) >> 16) * 0xC2B2AE35) & 0xFFFF) << 16)
)
) & 0xFFFFFFFF
h1 ^= _u(h1) >> 16
return _u(h1) >> 0
@ -139,6 +151,7 @@ def mmh3(inp_str: str, seed: int = 0):
class LitecordJSONEncoder(JSONEncoder):
"""Custom JSON encoder for Litecord."""
def default(self, value: Any):
"""By default, this will try to get the to_json attribute of a given
value being JSON encoded."""
@ -151,17 +164,17 @@ class LitecordJSONEncoder(JSONEncoder):
async def pg_set_json(con):
"""Set JSON and JSONB codecs for an asyncpg connection."""
await con.set_type_codec(
'json',
"json",
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads,
schema='pg_catalog'
schema="pg_catalog",
)
await con.set_type_codec(
'jsonb',
"jsonb",
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads,
schema='pg_catalog'
schema="pg_catalog",
)
@ -177,7 +190,8 @@ def yield_chunks(input_list: Sequence[Any], chunk_size: int):
# range accepts step param, so we use that to
# make the chunks
for idx in range(0, len(input_list), chunk_size):
yield input_list[idx:idx + chunk_size]
yield input_list[idx : idx + chunk_size]
def to_update(j: dict, orig: dict, field: str) -> bool:
"""Compare values to check if j[field] is actually updating
@ -193,27 +207,23 @@ async def search_result_from_list(rows: List) -> Dict[str, Any]:
- An int (?) on `total_results`
- Two bigint[], each on `before` and `after` respectively.
"""
results = 0 if not rows else rows[0]['total_results']
results = 0 if not rows else rows[0]["total_results"]
res = []
for row in rows:
before, after = [], []
for before_id in reversed(row['before']):
for before_id in reversed(row["before"]):
before.append(await app.storage.get_message(before_id))
for after_id in row['after']:
for after_id in row["after"]:
after.append(await app.storage.get_message(after_id))
msg = await app.storage.get_message(row['current_id'])
msg['hit'] = True
msg = await app.storage.get_message(row["current_id"])
msg["hit"] = True
res.append(before + [msg] + after)
return {
'total_results': results,
'messages': res,
'analytics_id': '',
}
return {"total_results": results, "messages": res, "analytics_id": ""}
def maybe_int(val: Any) -> Union[int, Any]:

View File

@ -31,6 +31,7 @@ log = Logger(__name__)
class LVSPConnection:
"""Represents a single LVSP connection."""
def __init__(self, lvsp, region: str, hostname: str):
self.lvsp = lvsp
self.app = lvsp.app
@ -46,7 +47,7 @@ class LVSPConnection:
@property
def _log_id(self):
return f'region={self.region} hostname={self.hostname}'
return f"region={self.region} hostname={self.hostname}"
async def send(self, payload):
"""Send a payload down the websocket."""
@ -61,50 +62,42 @@ class LVSPConnection:
async def send_op(self, opcode: int, data: dict):
"""Send a message with an OP code included"""
await self.send({
'op': opcode,
'd': data
})
await self.send({"op": opcode, "d": data})
async def send_info(self, info_type: str, info_data: Dict):
"""Send an INFO message down the websocket."""
await self.send({
'op': OP.info,
'd': {
'type': InfoTable[info_type.upper()],
'data': info_data
await self.send(
{
"op": OP.info,
"d": {"type": InfoTable[info_type.upper()], "data": info_data},
}
})
)
async def _heartbeater(self, hb_interval: int):
try:
await asyncio.sleep(hb_interval)
# TODO: add self._seq
await self.send_op(OP.heartbeat, {
's': 0
})
await self.send_op(OP.heartbeat, {"s": 0})
# give the server 300 milliseconds to reply.
await asyncio.sleep(300)
await self.conn.close(4000, 'heartbeat timeout')
await self.conn.close(4000, "heartbeat timeout")
except asyncio.CancelledError:
pass
def _start_hb(self):
self._hb_task = self.app.loop.create_task(
self._heartbeater(self._hb_interval)
)
self._hb_task = self.app.loop.create_task(self._heartbeater(self._hb_interval))
def _stop_hb(self):
self._hb_task.cancel()
async def _handle_0(self, msg):
"""Handle HELLO message."""
data = msg['d']
data = msg["d"]
# nonce = data['nonce']
self._hb_interval = data['heartbeat_interval']
self._hb_interval = data["heartbeat_interval"]
# TODO: send identify
@ -112,48 +105,52 @@ class LVSPConnection:
"""Update the health value of a given voice server."""
self.health = new_health
await self.app.db.execute("""
await self.app.db.execute(
"""
UPDATE voice_servers
SET health = $1
WHERE hostname = $2
""", new_health, self.hostname)
""",
new_health,
self.hostname,
)
async def _handle_3(self, msg):
"""Handle READY message.
We only start heartbeating after READY.
"""
await self._update_health(msg['health'])
await self._update_health(msg["health"])
self._start_hb()
async def _handle_5(self, msg):
"""Handle HEARTBEAT_ACK."""
self._stop_hb()
await self._update_health(msg['health'])
await self._update_health(msg["health"])
self._start_hb()
async def _handle_6(self, msg):
"""Handle INFO messages."""
info = msg['d']
info_type_str = InfoReverse[info['type']].lower()
info = msg["d"]
info_type_str = InfoReverse[info["type"]].lower()
try:
info_handler = getattr(self, f'_handle_info_{info_type_str}')
info_handler = getattr(self, f"_handle_info_{info_type_str}")
except AttributeError:
return
await info_handler(info['data'])
await info_handler(info["data"])
async def _handle_info_channel_assign(self, data: dict):
"""called by the server once we got a channel assign."""
try:
channel_id = data['channel_id']
channel_id = data["channel_id"]
channel_id = int(channel_id)
except (TypeError, ValueError):
return
try:
guild_id = data['guild_id']
guild_id = data["guild_id"]
guild_id = int(guild_id)
except (TypeError, ValueError):
guild_id = None
@ -166,19 +163,19 @@ class LVSPConnection:
msg = await self.recv()
try:
opcode = msg['op']
handler = getattr(self, f'_handle_{opcode}')
opcode = msg["op"]
handler = getattr(self, f"_handle_{opcode}")
await handler(msg)
except (KeyError, AttributeError):
# TODO: error codes in LVSP
raise Exception('invalid op code')
raise Exception("invalid op code")
async def start(self):
"""Try to start a websocket connection."""
try:
self.conn = await websockets.connect(f'wss://{self.hostname}')
self.conn = await websockets.connect(f"wss://{self.hostname}")
except Exception:
log.exception('failed to start lvsp conn to {}', self.hostname)
log.exception("failed to start lvsp conn to {}", self.hostname)
async def run(self):
"""Start the websocket."""
@ -186,15 +183,15 @@ class LVSPConnection:
try:
if not self.conn:
log.error('failed to start lvsp connection, stopping')
log.error("failed to start lvsp connection, stopping")
return
await self._loop()
except websockets.exceptions.ConnectionClosed as err:
log.warning('conn close, {}, err={}', self._log_id, err)
log.warning("conn close, {}, err={}", self._log_id, err)
# except WebsocketClose as err:
# log.warning('ws close, state={} err={}', self.state, err)
# await self.conn.close(code=err.code, reason=err.reason)
except Exception as err:
log.exception('An exception has occoured. {}', self._log_id)
log.exception("An exception has occoured. {}", self._log_id)
await self.conn.close(code=4000, reason=repr(err))

View File

@ -31,6 +31,7 @@ log = Logger(__name__)
@dataclass
class Region:
"""Voice region data."""
id: str
vip: bool
@ -40,6 +41,7 @@ class LVSPManager:
Spawns :class:`LVSPConnection` as needed, etc.
"""
def __init__(self, app, voice):
self.app = app
self.voice = voice
@ -61,49 +63,50 @@ class LVSPManager:
async def _spawn(self):
"""Spawn LVSPConnection for each region."""
regions = await self.app.db.fetch("""
regions = await self.app.db.fetch(
"""
SELECT id, vip
FROM voice_regions
WHERE deprecated = false
""")
"""
)
regions = [Region(r['id'], r['vip']) for r in regions]
regions = [Region(r["id"], r["vip"]) for r in regions]
if not regions:
log.warning('no regions are setup')
log.warning("no regions are setup")
return
for region in regions:
# store it locally for region() function
self.regions[region.id] = region
self.app.loop.create_task(
self._spawn_region(region)
)
self.app.loop.create_task(self._spawn_region(region))
async def _spawn_region(self, region: Region):
"""Spawn a region. Involves fetching all the hostnames
for the regions and spawning a LVSPConnection for each."""
servers = await self.app.db.fetch("""
servers = await self.app.db.fetch(
"""
SELECT hostname
FROM voice_servers
WHERE region_id = $1
""", region.id)
""",
region.id,
)
if not servers:
log.warning('region {} does not have servers', region)
log.warning("region {} does not have servers", region)
return
servers = [r['hostname'] for r in servers]
servers = [r["hostname"] for r in servers]
self.servers[region.id] = servers
for hostname in servers:
conn = LVSPConnection(self, region.id, hostname)
self.conns[hostname] = conn
self.app.loop.create_task(
conn.run()
)
self.app.loop.create_task(conn.run())
async def del_conn(self, conn):
"""Delete a connection from the connection pool."""
@ -119,11 +122,14 @@ class LVSPManager:
async def guild_region(self, guild_id: int) -> Optional[str]:
"""Return the voice region of a guild."""
return await self.app.db.fetchval("""
return await self.app.db.fetchval(
"""
SELECT region
FROM guilds
WHERE id = $1
""", guild_id)
""",
guild_id,
)
def get_health(self, hostname: str) -> float:
"""Get voice server health, given hostname."""
@ -144,10 +150,7 @@ class LVSPManager:
region = await self.guild_region(guild_id)
# sort connected servers by health
sorted_servers = sorted(
self.servers[region],
key=self.get_health
)
sorted_servers = sorted(self.servers[region], key=self.get_health)
try:
hostname = sorted_servers[0]

View File

@ -17,8 +17,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
class OPCodes:
"""LVSP OP codes."""
hello = 0
identify = 1
resume = 2
@ -29,13 +31,13 @@ class OPCodes:
InfoTable = {
'CHANNEL_REQ': 0,
'CHANNEL_ASSIGN': 1,
'CHANNEL_UPDATE': 2,
'CHANNEL_DESTROY': 3,
'VST_CREATE': 4,
'VST_UPDATE': 5,
'VST_LEAVE': 6,
"CHANNEL_REQ": 0,
"CHANNEL_ASSIGN": 1,
"CHANNEL_UPDATE": 2,
"CHANNEL_DESTROY": 3,
"VST_CREATE": 4,
"VST_UPDATE": 5,
"VST_LEAVE": 6,
}
InfoReverse = {v: k for k, v in InfoTable.items()}

View File

@ -43,6 +43,7 @@ def _construct_state(state_dict: dict) -> VoiceState:
class VoiceManager:
"""Main voice manager class."""
def __init__(self, app):
self.app = app
@ -56,7 +57,7 @@ class VoiceManager:
"""Return if a user can join a channel."""
channel = await self.app.storage.get_channel(channel_id)
ctype = ChannelType(channel['type'])
ctype = ChannelType(channel["type"])
if ctype not in VOICE_CHANNELS:
return
@ -65,14 +66,12 @@ class VoiceManager:
# get_permissions returns ALL_PERMISSIONS when
# the channel isn't from a guild
perms = await get_permissions(
user_id, channel_id, storage=self.app.storage
)
perms = await get_permissions(user_id, channel_id, storage=self.app.storage)
# hacky user_limit but should work, as channels not
# in guilds won't have that field.
is_full = states >= channel.get('user_limit', 100)
is_bot = (await self.app.storage.get_user(user_id))['bot']
is_full = states >= channel.get("user_limit", 100)
is_bot = (await self.app.storage.get_user(user_id))["bot"]
is_manager = perms.bits.manage_channels
# if the channel is full AND:
@ -140,8 +139,8 @@ class VoiceManager:
for field in prop:
# NOTE: this should not happen, ever.
if field in ('channel_id', 'user_id'):
raise ValueError('properties are updating channel or user')
if field in ("channel_id", "user_id"):
raise ValueError("properties are updating channel or user")
new_state_dict[field] = prop[field]
@ -153,27 +152,28 @@ class VoiceManager:
async def move_channels(self, old_voice_key: VoiceKey, channel_id: int):
"""Move a user between channels."""
await self.del_state(old_voice_key)
await self.create_state(old_voice_key, {'channel_id': channel_id})
await self.create_state(old_voice_key, {"channel_id": channel_id})
async def _lvsp_info_guild(self, guild_id, info_type, info_data):
hostname = await self.lvsp.get_guild_server(guild_id)
if hostname is None:
log.error('no voice server for guild id {}', guild_id)
log.error("no voice server for guild id {}", guild_id)
return
conn = self.lvsp.get_conn(hostname)
await conn.send_info(info_type, info_data)
async def _create_ctx_guild(self, guild_id, channel_id):
await self._lvsp_info_guild(guild_id, 'CHANNEL_REQ', {
'guild_id': str(guild_id),
'channel_id': str(channel_id),
})
await self._lvsp_info_guild(
guild_id,
"CHANNEL_REQ",
{"guild_id": str(guild_id), "channel_id": str(channel_id)},
)
async def _start_voice_guild(self, voice_key: VoiceKey, data: dict):
"""Start a voice context in a guild."""
user_id, guild_id = voice_key
channel_id = int(data['channel_id'])
channel_id = int(data["channel_id"])
existing_states = self.states[voice_key]
channel_exists = any(
@ -183,11 +183,15 @@ class VoiceManager:
if not channel_exists:
await self._create_ctx_guild(guild_id, channel_id)
await self._lvsp_info_guild(guild_id, 'VST_CREATE', {
'user_id': str(user_id),
'guild_id': str(guild_id),
'channel_id': str(channel_id),
})
await self._lvsp_info_guild(
guild_id,
"VST_CREATE",
{
"user_id": str(user_id),
"guild_id": str(guild_id),
"channel_id": str(channel_id),
},
)
async def create_state(self, voice_key: VoiceKey, data: dict):
"""Creates (or tries to create) a voice state.
@ -249,10 +253,13 @@ class VoiceManager:
async def voice_server_list(self, region: str) -> List[dict]:
"""Get a list of voice server objects"""
rows = await self.app.db.fetch("""
rows = await self.app.db.fetch(
"""
SELECT hostname, last_health
FROM voice_servers
WHERE region_id = $1
""", region)
""",
region,
)
return list(map(dict, rows))

View File

@ -23,6 +23,7 @@ from dataclasses import dataclass, asdict
@dataclass
class VoiceState:
"""Represents a voice state."""
guild_id: int
channel_id: int
user_id: int
@ -55,7 +56,7 @@ class VoiceState:
# a better approach would be actually using
# the suppressed_by field for backend efficiency.
self_dict['suppress'] = user_id == self.suppressed_by
self_dict.pop('suppressed_by')
self_dict["suppress"] = user_id == self.suppressed_by
self_dict.pop("suppressed_by")
return self_dict

View File

@ -27,5 +27,5 @@ import config
logging.basicConfig(level=logging.DEBUG)
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main(config))

View File

@ -16,4 +16,3 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

View File

@ -26,7 +26,7 @@ ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits
async def _gen_inv() -> str:
"""Generate an invite code"""
return ''.join(choice(ALPHABET) for _ in range(6))
return "".join(choice(ALPHABET) for _ in range(6))
async def gen_inv(ctx) -> str:
@ -34,11 +34,14 @@ async def gen_inv(ctx) -> str:
for _ in range(10):
possible_inv = await _gen_inv()
created_at = await ctx.db.fetchval("""
created_at = await ctx.db.fetchval(
"""
SELECT created_at
FROM instance_invites
WHERE code = $1
""", possible_inv)
""",
possible_inv,
)
if created_at is None:
return possible_inv
@ -51,27 +54,32 @@ async def make_inv(ctx, args):
max_uses = args.max_uses
await ctx.db.execute("""
await ctx.db.execute(
"""
INSERT INTO instance_invites (code, max_uses)
VALUES ($1, $2)
""", code, max_uses)
""",
code,
max_uses,
)
print(f'invite created with {max_uses} max uses', code)
print(f"invite created with {max_uses} max uses", code)
async def list_invs(ctx, args):
rows = await ctx.db.fetch("""
rows = await ctx.db.fetch(
"""
SELECT code, created_at, uses, max_uses
FROM instance_invites
""")
"""
)
print(len(rows), 'invites')
print(len(rows), "invites")
for row in rows:
max_uses = row['max_uses']
delta = datetime.datetime.utcnow() - row['created_at']
usage = ('infinite uses' if max_uses == -1
else f'{row["uses"]} / {max_uses}')
max_uses = row["max_uses"]
delta = datetime.datetime.utcnow() - row["created_at"]
usage = "infinite uses" if max_uses == -1 else f'{row["uses"]} / {max_uses}'
print(f'\t{row["code"]}, {usage}, made {delta} ago')
@ -79,40 +87,37 @@ async def list_invs(ctx, args):
async def delete_inv(ctx, args):
inv = args.invite_code
res = await ctx.db.execute("""
res = await ctx.db.execute(
"""
DELETE FROM instance_invites
WHERE code = $1
""", inv)
""",
inv,
)
if res == 'DELETE 0':
print('NOT FOUND')
if res == "DELETE 0":
print("NOT FOUND")
return
print('OK')
print("OK")
def setup(subparser):
makeinv_parser = subparser.add_parser(
'makeinv',
help='create an invite',
)
makeinv_parser = subparser.add_parser("makeinv", help="create an invite")
makeinv_parser.add_argument(
'max_uses', nargs='?', type=int, default=-1,
help='Maximum amount of uses before the invite is unavailable',
"max_uses",
nargs="?",
type=int,
default=-1,
help="Maximum amount of uses before the invite is unavailable",
)
makeinv_parser.set_defaults(func=make_inv)
listinv_parser = subparser.add_parser(
'listinv',
help='list all invites',
)
listinv_parser = subparser.add_parser("listinv", help="list all invites")
listinv_parser.set_defaults(func=list_invs)
delinv_parser = subparser.add_parser(
'delinv',
help='delete an invite',
)
delinv_parser.add_argument('invite_code')
delinv_parser = subparser.add_parser("delinv", help="delete an invite")
delinv_parser.add_argument("invite_code")
delinv_parser.set_defaults(func=delete_inv)

View File

@ -19,4 +19,4 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from .command import setup as migration
__all__ = ['migration']
__all__ = ["migration"]

View File

@ -32,18 +32,19 @@ from logbook import Logger
log = Logger(__name__)
Migration = namedtuple('Migration', 'id name path')
Migration = namedtuple("Migration", "id name path")
# line of change, 4 april 2019, at 1am (gmt+0)
BREAK = datetime.datetime(2019, 4, 4, 1)
# if a database has those tables, it ran 0_base.sql.
HAS_BASE = ['users', 'guilds', 'e']
HAS_BASE = ["users", "guilds", "e"]
@dataclass
class MigrationContext:
"""Hold information about migration."""
migration_folder: Path
scripts: Dict[int, Migration]
@ -60,22 +61,21 @@ def make_migration_ctx() -> MigrationContext:
script_folder = os.sep.join(script_path.split(os.sep)[:-1])
script_folder = Path(script_folder)
migration_folder = script_folder / 'scripts'
migration_folder = script_folder / "scripts"
mctx = MigrationContext(migration_folder, {})
for mig_path in migration_folder.glob('*.sql'):
for mig_path in migration_folder.glob("*.sql"):
mig_path_str = str(mig_path)
# extract migration script id and name
mig_filename = mig_path_str.split(os.sep)[-1].split('.')[0]
name_fragments = mig_filename.split('_')
mig_filename = mig_path_str.split(os.sep)[-1].split(".")[0]
name_fragments = mig_filename.split("_")
mig_id = int(name_fragments[0])
mig_name = '_'.join(name_fragments[1:])
mig_name = "_".join(name_fragments[1:])
mctx.scripts[mig_id] = Migration(
mig_id, mig_name, mig_path)
mctx.scripts[mig_id] = Migration(mig_id, mig_name, mig_path)
return mctx
@ -83,7 +83,8 @@ def make_migration_ctx() -> MigrationContext:
async def _ensure_changelog(app, ctx):
# make sure we have the migration table up
try:
await app.db.execute("""
await app.db.execute(
"""
CREATE TABLE migration_log (
change_num bigint NOT NULL,
@ -94,43 +95,56 @@ async def _ensure_changelog(app, ctx):
PRIMARY KEY (change_num)
);
""")
"""
)
except asyncpg.DuplicateTableError:
log.debug('existing migration table')
log.debug("existing migration table")
# NOTE: this is a migration breakage,
# only applying to databases that had their first migration
# before 4 april 2019 (more on BREAK)
# if migration_log is empty, just assume this is new
first = await app.db.fetchval("""
first = (
await app.db.fetchval(
"""
SELECT apply_ts FROM migration_log
ORDER BY apply_ts ASC
LIMIT 1
""") or BREAK
"""
)
or BREAK
)
if first < BREAK:
log.info('deleting migration_log due to migration structure change')
log.info("deleting migration_log due to migration structure change")
await app.db.execute("DROP TABLE migration_log")
await _ensure_changelog(app, ctx)
async def _insert_log(app, migration_id: int, description) -> bool:
try:
await app.db.execute("""
await app.db.execute(
"""
INSERT INTO migration_log (change_num, description)
VALUES ($1, $2)
""", migration_id, description)
""",
migration_id,
description,
)
return True
except asyncpg.UniqueViolationError:
log.warning('already inserted {}', migration_id)
log.warning("already inserted {}", migration_id)
return False
async def _delete_log(app, migration_id: int):
await app.db.execute("""
await app.db.execute(
"""
DELETE FROM migration_log WHERE change_num = $1
""", migration_id)
""",
migration_id,
)
async def apply_migration(app, migration: Migration) -> bool:
@ -144,21 +158,20 @@ async def apply_migration(app, migration: Migration) -> bool:
Returns a boolean signaling if this failed or not.
"""
migration_sql = migration.path.read_text(encoding='utf-8')
migration_sql = migration.path.read_text(encoding="utf-8")
res = await _insert_log(
app, migration.id, f'migration: {migration.name}')
res = await _insert_log(app, migration.id, f"migration: {migration.name}")
if not res:
return False
try:
await app.db.execute(migration_sql)
log.info('applied {} {}', migration.id, migration.name)
log.info("applied {} {}", migration.id, migration.name)
return True
except:
log.exception('failed to run migration, rollbacking log')
log.exception("failed to run migration, rollbacking log")
await _delete_log(app, migration.id)
return False
@ -169,9 +182,11 @@ async def _check_base(app) -> bool:
file."""
try:
for table in HAS_BASE:
await app.db.execute(f"""
await app.db.execute(
f"""
SELECT * FROM {table} LIMIT 0
""")
"""
)
except asyncpg.UndefinedTableError:
return False
@ -197,14 +212,16 @@ async def migrate_cmd(app, _args):
has_base = await _check_base(app)
# fetch latest local migration that has been run on this database
local_change = await app.db.fetchval("""
local_change = await app.db.fetchval(
"""
SELECT max(change_num)
FROM migration_log
""")
"""
)
# if base exists, add it to logs, if not, apply (and add to logs)
if has_base:
await _insert_log(app, 0, 'migration setup (from existing)')
await _insert_log(app, 0, "migration setup (from existing)")
else:
await apply_migration(app, ctx.scripts[0])
@ -215,10 +232,10 @@ async def migrate_cmd(app, _args):
local_change = local_change or 0
latest_change = ctx.latest
log.debug('local: {}, latest: {}', local_change, latest_change)
log.debug("local: {}, latest: {}", local_change, latest_change)
if local_change == latest_change:
print('no changes to do, exiting')
print("no changes to do, exiting")
return
# we do local_change + 1 so we start from the
@ -227,15 +244,13 @@ async def migrate_cmd(app, _args):
for idx in range(local_change + 1, latest_change + 1):
migration = ctx.scripts.get(idx)
print('applying', migration.id, migration.name)
print("applying", migration.id, migration.name)
await apply_migration(app, migration)
def setup(subparser):
migrate_parser = subparser.add_parser(
'migrate',
help='Run migration tasks',
description=migrate_cmd.__doc__
"migrate", help="Run migration tasks", description=migrate_cmd.__doc__
)
migrate_parser.set_defaults(func=migrate_cmd)

View File

@ -24,39 +24,51 @@ from litecord.enums import UserFlags
async def find_user(username, discrim, ctx) -> int:
"""Get a user ID via the username/discrim pair."""
return await ctx.db.fetchval("""
return await ctx.db.fetchval(
"""
SELECT id
FROM users
WHERE username = $1 AND discriminator = $2
""", username, discrim)
""",
username,
discrim,
)
async def set_user_staff(user_id, ctx):
"""Give a single user staff status."""
old_flags = await ctx.db.fetchval("""
old_flags = await ctx.db.fetchval(
"""
SELECT flags
FROM users
WHERE id = $1
""", user_id)
""",
user_id,
)
new_flags = old_flags | UserFlags.staff
await ctx.db.execute("""
await ctx.db.execute(
"""
UPDATE users
SET flags=$1
WHERE id = $2
""", new_flags, user_id)
""",
new_flags,
user_id,
)
async def adduser(ctx, args):
"""Create a single user."""
uid, _ = await create_user(args.username, args.email,
args.password, ctx.db, ctx.loop)
uid, _ = await create_user(
args.username, args.email, args.password, ctx.db, ctx.loop
)
user = await ctx.storage.get_user(uid)
print('created!')
print(f'\tuid: {uid}')
print("created!")
print(f"\tuid: {uid}")
print(f'\tusername: {user["username"]}')
print(f'\tdiscrim: {user["discriminator"]}')
@ -72,22 +84,26 @@ async def make_staff(ctx, args):
uid = await find_user(args.username, args.discrim, ctx)
if not uid:
return print('user not found')
return print("user not found")
await set_user_staff(uid, ctx)
print('OK: set staff')
print("OK: set staff")
async def generate_bot_token(ctx, args):
"""Generate a token for specified bot."""
password_hash = await ctx.db.fetchval("""
password_hash = await ctx.db.fetchval(
"""
SELECT password_hash
FROM users
WHERE id = $1 AND bot = 'true'
""", int(args.user_id))
""",
int(args.user_id),
)
if not password_hash:
return print('cannot find a bot with specified id')
return print("cannot find a bot with specified id")
print(make_token(args.user_id, password_hash))
@ -97,7 +113,7 @@ async def del_user(ctx, args):
uid = await find_user(args.username, args.discrim, ctx)
if uid is None:
print('user not found')
print("user not found")
return
user = await ctx.storage.get_user(uid)
@ -106,57 +122,48 @@ async def del_user(ctx, args):
print(f'\tuname: {user["username"]}')
print(f'\tdiscrim: {user["discriminator"]}')
print('\n you sure you want to delete user? press Y (uppercase)')
print("\n you sure you want to delete user? press Y (uppercase)")
confirm = input()
if confirm != 'Y':
print('not confirmed')
if confirm != "Y":
print("not confirmed")
return
await delete_user(uid, app_=ctx)
print('ok')
print("ok")
def setup(subparser):
setup_test_parser = subparser.add_parser(
'adduser',
help='create a user',
)
setup_test_parser = subparser.add_parser("adduser", help="create a user")
setup_test_parser.add_argument(
'username', help='username of the user')
setup_test_parser.add_argument(
'email', help='email of the user')
setup_test_parser.add_argument(
'password', help='password of the user')
setup_test_parser.add_argument("username", help="username of the user")
setup_test_parser.add_argument("email", help="email of the user")
setup_test_parser.add_argument("password", help="password of the user")
setup_test_parser.set_defaults(func=adduser)
staff_parser = subparser.add_parser(
'make_staff',
help='make a user staff',
description=make_staff.__doc__
"make_staff", help="make a user staff", description=make_staff.__doc__
)
staff_parser.add_argument('username')
staff_parser.add_argument(
'discrim', help='the discriminator of the user')
staff_parser.add_argument("username")
staff_parser.add_argument("discrim", help="the discriminator of the user")
staff_parser.set_defaults(func=make_staff)
del_user_parser = subparser.add_parser(
'deluser', help='delete a single user')
del_user_parser = subparser.add_parser("deluser", help="delete a single user")
del_user_parser.add_argument('username')
del_user_parser.add_argument('discrim')
del_user_parser.add_argument("username")
del_user_parser.add_argument("discrim")
del_user_parser.set_defaults(func=del_user)
token_parser = subparser.add_parser(
'generate_token',
help='generate a token for specified bot',
description=generate_bot_token.__doc__)
"generate_token",
help="generate a token for specified bot",
description=generate_bot_token.__doc__,
)
token_parser.add_argument('user_id')
token_parser.add_argument("user_id")
token_parser.set_defaults(func=generate_bot_token)

View File

@ -34,6 +34,7 @@ log = Logger(__name__)
@dataclass
class FakeApp:
"""Fake app instance."""
config: dict
db = None
loop: asyncio.BaseEventLoop = None
@ -50,7 +51,7 @@ class FakeApp:
def init_parser():
parser = argparse.ArgumentParser()
subparser = parser.add_subparsers(help='operations')
subparser = parser.add_subparsers(help="operations")
migration(subparser)
users.setup(subparser)
@ -78,12 +79,12 @@ def main(config):
# only init app managers when we aren't migrating
# as the managers require it
# and the migrate command also sets the db up
if argv[1] != 'migrate':
if argv[1] != "migrate":
init_app_managers(app, voice=False)
args = parser.parse_args()
loop.run_until_complete(args.func(app, args))
except Exception:
log.exception('error while running command')
log.exception("error while running command")
finally:
loop.run_until_complete(app.db.close())

236
run.py
View File

@ -33,32 +33,51 @@ from aiohttp import ClientSession
import config
from litecord.blueprints import (
gateway, auth, users, guilds, channels, webhooks, science,
voice, invites, relationships, dms, icons, nodeinfo, static,
attachments, dm_channels
gateway,
auth,
users,
guilds,
channels,
webhooks,
science,
voice,
invites,
relationships,
dms,
icons,
nodeinfo,
static,
attachments,
dm_channels,
)
# those blueprints are separated from the "main" ones
# for code readability if people want to dig through
# the codebase.
from litecord.blueprints.guild import (
guild_roles, guild_members, guild_channels, guild_mod,
guild_emoji
guild_roles,
guild_members,
guild_channels,
guild_mod,
guild_emoji,
)
from litecord.blueprints.channel import (
channel_messages, channel_reactions, channel_pins
channel_messages,
channel_reactions,
channel_pins,
)
from litecord.blueprints.user import (
user_settings, user_billing, fake_store
)
from litecord.blueprints.user import user_settings, user_billing, fake_store
from litecord.blueprints.user.billing_job import payment_job
from litecord.blueprints.admin_api import (
voice as voice_admin, features as features_admin,
guilds as guilds_admin, users as users_admin, instance_invites
voice as voice_admin,
features as features_admin,
guilds as guilds_admin,
users as users_admin,
instance_invites,
)
from litecord.blueprints.admin_api.voice import guild_region_check
@ -84,23 +103,23 @@ from litecord.utils import LitecordJSONEncoder
# setup logbook
handler = StreamHandler(sys.stdout, level=logbook.INFO)
handler.push_application()
log = Logger('litecord.boot')
log = Logger("litecord.boot")
redirect_logging()
def make_app():
app = Quart(__name__)
app.config.from_object(f'config.{config.MODE}')
is_debug = app.config.get('DEBUG', False)
app.config.from_object(f"config.{config.MODE}")
is_debug = app.config.get("DEBUG", False)
app.debug = is_debug
if is_debug:
log.info('on debug')
log.info("on debug")
handler.level = logbook.DEBUG
app.logger.level = logbook.DEBUG
# always keep websockets on INFO
logging.getLogger('websockets').setLevel(logbook.INFO)
logging.getLogger("websockets").setLevel(logbook.INFO)
# use our custom json encoder for custom data types
app.json_encoder = LitecordJSONEncoder
@ -112,51 +131,44 @@ def set_blueprints(app_):
"""Set the blueprints for a given app instance"""
bps = {
gateway: None,
auth: '/auth',
users: '/users',
user_settings: '/users',
user_billing: '/users',
relationships: '/users',
guilds: '/guilds',
guild_roles: '/guilds',
guild_members: '/guilds',
guild_channels: '/guilds',
guild_mod: '/guilds',
guild_emoji: '/guilds',
channels: '/channels',
channel_messages: '/channels',
channel_reactions: '/channels',
channel_pins: '/channels',
auth: "/auth",
users: "/users",
user_settings: "/users",
user_billing: "/users",
relationships: "/users",
guilds: "/guilds",
guild_roles: "/guilds",
guild_members: "/guilds",
guild_channels: "/guilds",
guild_mod: "/guilds",
guild_emoji: "/guilds",
channels: "/channels",
channel_messages: "/channels",
channel_reactions: "/channels",
channel_pins: "/channels",
webhooks: None,
science: None,
voice: '/voice',
voice: "/voice",
invites: None,
dms: '/users',
dm_channels: '/channels',
dms: "/users",
dm_channels: "/channels",
fake_store: None,
icons: -1,
attachments: -1,
nodeinfo: -1,
static: -1,
voice_admin: '/admin/voice',
features_admin: '/admin/guilds',
guilds_admin: '/admin/guilds',
users_admin: '/admin/users',
instance_invites: '/admin/instance/invites'
voice_admin: "/admin/voice",
features_admin: "/admin/guilds",
guilds_admin: "/admin/guilds",
users_admin: "/admin/users",
instance_invites: "/admin/instance/invites",
}
for bp, suffix in bps.items():
url_prefix = f'/api/v6{suffix or ""}'
if suffix == -1:
url_prefix = ''
url_prefix = ""
app_.register_blueprint(bp, url_prefix=url_prefix)
@ -175,37 +187,35 @@ async def app_before_request():
@app.after_request
async def app_after_request(resp):
"""Handle CORS headers."""
origin = request.headers.get('Origin', '*')
resp.headers['Access-Control-Allow-Origin'] = origin
resp.headers['Access-Control-Allow-Headers'] = (
'*, X-Super-Properties, '
'X-Fingerprint, '
'X-Context-Properties, '
'X-Failed-Requests, '
'X-Debug-Options, '
'Content-Type, '
'Authorization, '
'Origin, '
'If-None-Match'
origin = request.headers.get("Origin", "*")
resp.headers["Access-Control-Allow-Origin"] = origin
resp.headers["Access-Control-Allow-Headers"] = (
"*, X-Super-Properties, "
"X-Fingerprint, "
"X-Context-Properties, "
"X-Failed-Requests, "
"X-Debug-Options, "
"Content-Type, "
"Authorization, "
"Origin, "
"If-None-Match"
)
resp.headers['Access-Control-Allow-Methods'] = \
resp.headers.get('allow', '*')
resp.headers["Access-Control-Allow-Methods"] = resp.headers.get("allow", "*")
return resp
def _set_rtl_reset(bucket, resp):
reset = bucket._window + bucket.second
precision = request.headers.get('x-ratelimit-precision', 'second')
precision = request.headers.get("x-ratelimit-precision", "second")
if precision == 'second':
resp.headers['X-RateLimit-Reset'] = str(round(reset))
elif precision == 'millisecond':
resp.headers['X-RateLimit-Reset'] = str(reset)
if precision == "second":
resp.headers["X-RateLimit-Reset"] = str(round(reset))
elif precision == "millisecond":
resp.headers["X-RateLimit-Reset"] = str(reset)
else:
resp.headers['X-RateLimit-Reset'] = (
'Invalid X-RateLimit-Precision, '
'valid options are (second, millisecond)'
resp.headers["X-RateLimit-Reset"] = (
"Invalid X-RateLimit-Precision, " "valid options are (second, millisecond)"
)
@ -218,15 +228,15 @@ async def app_set_ratelimit_headers(resp):
if bucket is None:
raise AttributeError()
resp.headers['X-RateLimit-Limit'] = str(bucket.requests)
resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens)
resp.headers['X-RateLimit-Global'] = str(request.bucket_global).lower()
resp.headers["X-RateLimit-Limit"] = str(bucket.requests)
resp.headers["X-RateLimit-Remaining"] = str(bucket._tokens)
resp.headers["X-RateLimit-Global"] = str(request.bucket_global).lower()
_set_rtl_reset(bucket, resp)
# only add Retry-After if we actually hit a ratelimit
retry_after = request.retry_after
if request.retry_after:
resp.headers['Retry-After'] = str(retry_after)
resp.headers["Retry-After"] = str(retry_after)
except AttributeError:
pass
@ -238,8 +248,8 @@ async def init_app_db(app_):
Also spawns the job scheduler.
"""
log.info('db connect')
app_.db = await asyncpg.create_pool(**app.config['POSTGRES'])
log.info("db connect")
app_.db = await asyncpg.create_pool(**app.config["POSTGRES"])
app_.sched = JobManager()
@ -247,7 +257,7 @@ async def init_app_db(app_):
def init_app_managers(app_, *, voice=True):
"""Initialize singleton classes."""
app_.loop = asyncio.get_event_loop()
app_.ratelimiter = RatelimitManager(app_.config.get('_testing'))
app_.ratelimiter = RatelimitManager(app_.config.get("_testing"))
app_.state_manager = StateManager()
app_.storage = Storage(app_)
@ -274,15 +284,12 @@ async def api_index(app_):
to_find = {}
found = []
with open('discord_endpoints.txt') as fd:
with open("discord_endpoints.txt") as fd:
for line in fd.readlines():
components = line.split(' ')
components = list(filter(
bool,
components
))
components = line.split(" ")
components = list(filter(bool, components))
name, method, path = components
path = f'/api/v6{path.strip()}'
path = f"/api/v6{path.strip()}"
method = method.strip()
to_find[(path, method)] = name
@ -290,17 +297,17 @@ async def api_index(app_):
path = rule.rule
# convert the path to the discord_endpoints file's style
path = path.replace('_', '.')
path = path.replace('<', '{')
path = path.replace('>', '}')
path = path.replace('int:', '')
path = path.replace("_", ".")
path = path.replace("<", "{")
path = path.replace(">", "}")
path = path.replace("int:", "")
# change our parameters into user.id
path = path.replace('member.id', 'user.id')
path = path.replace('banned.id', 'user.id')
path = path.replace('target.id', 'user.id')
path = path.replace('other.id', 'user.id')
path = path.replace('peer.id', 'user.id')
path = path.replace("member.id", "user.id")
path = path.replace("banned.id", "user.id")
path = path.replace("target.id", "user.id")
path = path.replace("other.id", "user.id")
path = path.replace("peer.id", "user.id")
methods = rule.methods
@ -317,10 +324,15 @@ async def api_index(app_):
percentage = (len(found) / len(api)) * 100
percentage = round(percentage, 2)
log.debug('API compliance: {} out of {} ({} missing), {}% compliant',
len(found), len(api), len(missing), percentage)
log.debug(
"API compliance: {} out of {} ({} missing), {}% compliant",
len(found),
len(api),
len(missing),
percentage,
)
log.debug('missing: {}', missing)
log.debug("missing: {}", missing)
async def post_app_start(app_):
@ -332,7 +344,7 @@ async def post_app_start(app_):
def start_websocket(host, port, ws_handler) -> asyncio.Future:
"""Start a websocket. Returns the websocket future"""
log.info(f'starting websocket at {host} {port}')
log.info(f"starting websocket at {host} {port}")
async def _wrapper(ws, url):
# We wrap the main websocket_handler
@ -348,7 +360,7 @@ async def app_before_serving():
Also sets up the websocket handlers.
"""
log.info('opening db')
log.info("opening db")
await init_app_db(app)
app.session = ClientSession()
@ -359,8 +371,7 @@ async def app_before_serving():
# start gateway websocket
# voice websocket is handled by the voice server
ws_fut = start_websocket(
app.config['WS_HOST'], app.config['WS_PORT'],
websocket_handler
app.config["WS_HOST"], app.config["WS_PORT"], websocket_handler
)
await ws_fut
@ -379,7 +390,7 @@ async def app_after_serving():
app.sched.close()
log.info('closing db')
log.info("closing db")
await app.db.close()
@ -391,24 +402,23 @@ async def handle_litecord_err(err):
ejson = {}
try:
ejson['code'] = err.error_code
ejson["code"] = err.error_code
except AttributeError:
pass
log.warning('error: {} {!r}', err.status_code, err.message)
log.warning("error: {} {!r}", err.status_code, err.message)
return jsonify({
'error': True,
'status': err.status_code,
'message': err.message,
**ejson
}), err.status_code
return (
jsonify(
{"error": True, "status": err.status_code, "message": err.message, **ejson}
),
err.status_code,
)
@app.errorhandler(500)
async def handle_500(err):
return jsonify({
'error': True,
'message': repr(err),
'internal_server_error': True,
}), 500
return (
jsonify({"error": True, "message": repr(err), "internal_server_error": True}),
500,
)

View File

@ -20,10 +20,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from setuptools import setup
setup(
name='litecord',
version='0.0.1',
description='Implementation of the Discord API',
url='https://litecord.top',
author='Luna Mendes',
python_requires='>=3.7'
name="litecord",
version="0.0.1",
description="Implementation of the Discord API",
url="https://litecord.top",
author="Luna Mendes",
python_requires=">=3.7",
)

View File

@ -19,13 +19,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import secrets
def email() -> str:
return f'{secrets.token_hex(5)}@{secrets.token_hex(5)}.com'
return f"{secrets.token_hex(5)}@{secrets.token_hex(5)}.com"
class TestClient:
"""Test client that wraps pytest-sanic's TestClient and a test
user and adds authorization headers to test requests."""
def __init__(self, test_cli, test_user):
self.cli = test_cli
self.app = test_cli.app
@ -37,31 +39,31 @@ class TestClient:
def _inject_auth(self, kwargs: dict) -> list:
"""Inject the test user's API key into the test request before
passing the request on to the underlying TestClient."""
headers = kwargs.get('headers', {})
headers['authorization'] = self.user['token']
headers = kwargs.get("headers", {})
headers["authorization"] = self.user["token"]
return headers
async def get(self, *args, **kwargs):
"""Send a GET request."""
kwargs['headers'] = self._inject_auth(kwargs)
kwargs["headers"] = self._inject_auth(kwargs)
return await self.cli.get(*args, **kwargs)
async def post(self, *args, **kwargs):
"""Send a POST request."""
kwargs['headers'] = self._inject_auth(kwargs)
kwargs["headers"] = self._inject_auth(kwargs)
return await self.cli.post(*args, **kwargs)
async def put(self, *args, **kwargs):
"""Send a POST request."""
kwargs['headers'] = self._inject_auth(kwargs)
kwargs["headers"] = self._inject_auth(kwargs)
return await self.cli.put(*args, **kwargs)
async def patch(self, *args, **kwargs):
"""Send a PATCH request."""
kwargs['headers'] = self._inject_auth(kwargs)
kwargs["headers"] = self._inject_auth(kwargs)
return await self.cli.patch(*args, **kwargs)
async def delete(self, *args, **kwargs):
"""Send a DELETE request."""
kwargs['headers'] = self._inject_auth(kwargs)
kwargs["headers"] = self._inject_auth(kwargs)
return await self.cli.delete(*args, **kwargs)

View File

@ -36,22 +36,22 @@ from litecord.blueprints.auth import make_token
from litecord.blueprints.users import delete_user
@pytest.fixture(name='app')
@pytest.fixture(name="app")
def _test_app(unused_tcp_port, event_loop):
set_blueprints(main_app)
main_app.config['_testing'] = True
main_app.config["_testing"] = True
# reassign an unused tcp port for websockets
# since the config might give a used one.
ws_port = unused_tcp_port
main_app.config['IS_SSL'] = False
main_app.config['WS_PORT'] = ws_port
main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}'
main_app.config["IS_SSL"] = False
main_app.config["WS_PORT"] = ws_port
main_app.config["WEBSOCKET_URL"] = f"localhost:{ws_port}"
# testing user creations requires hardcoding this to true
# on testing
main_app.config['REGISTRATIONS'] = True
main_app.config["REGISTRATIONS"] = True
# make sure we're calling the before_serving hooks
event_loop.run_until_complete(main_app.startup())
@ -63,11 +63,12 @@ def _test_app(unused_tcp_port, event_loop):
event_loop.run_until_complete(main_app.shutdown())
@pytest.fixture(name='test_cli')
@pytest.fixture(name="test_cli")
def _test_cli(app):
"""Give a test client."""
return app.test_client()
# code shamelessly stolen from my elixire mr
# https://gitlab.com/elixire/elixire/merge_requests/52
async def _user_fixture_setup(app):
@ -76,21 +77,26 @@ async def _user_fixture_setup(app):
user_email = email()
user_id, pwd_hash = await create_user(
username, user_email, password, app.db, app.loop)
username, user_email, password, app.db, app.loop
)
# generate a token for api access
user_token = make_token(user_id, pwd_hash)
return {'id': user_id, 'token': user_token,
'email': user_email, 'username': username,
'password': password}
return {
"id": user_id,
"token": user_token,
"email": user_email,
"username": username,
"password": password,
}
async def _user_fixture_teardown(app, udata: dict):
await delete_user(udata['id'], app_=app)
await delete_user(udata["id"], app_=app)
@pytest.fixture(name='test_user')
@pytest.fixture(name="test_user")
async def test_user_fixture(app):
"""Yield a randomly generated test user."""
udata = await _user_fixture_setup(app)
@ -113,18 +119,25 @@ async def test_cli_staff(test_cli):
# same test_cli_user, which isn't acceptable.
app = test_cli.app
test_user = await _user_fixture_setup(app)
user_id = test_user['id']
user_id = test_user["id"]
# copied from manage.cmd.users.set_user_staff.
old_flags = await app.db.fetchval("""
old_flags = await app.db.fetchval(
"""
SELECT flags FROM users WHERE id = $1
""", user_id)
""",
user_id,
)
new_flags = old_flags | UserFlags.staff
await app.db.execute("""
await app.db.execute(
"""
UPDATE users SET flags = $1 WHERE id = $2
""", new_flags, user_id)
""",
new_flags,
user_id,
)
yield TestClient(test_cli, test_user)
await _user_fixture_teardown(test_cli.app, test_user)

View File

@ -24,24 +24,24 @@ import pytest
from litecord.blueprints.guilds import delete_guild
from litecord.errors import GuildNotFound
async def _create_guild(test_cli_staff):
genned_name = secrets.token_hex(6)
resp = await test_cli_staff.post('/api/v6/guilds', json={
'name': genned_name,
'region': None
})
resp = await test_cli_staff.post(
"/api/v6/guilds", json={"name": genned_name, "region": None}
)
assert resp.status_code == 200
rjson = await resp.json
assert isinstance(rjson, dict)
assert rjson['name'] == genned_name
assert rjson["name"] == genned_name
return rjson
async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
resp = await test_cli_staff.get(f'/api/v6/admin/guilds/{guild_id}')
resp = await test_cli_staff.get(f"/api/v6/admin/guilds/{guild_id}")
if ret_early:
return resp
@ -49,7 +49,7 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
assert resp.status_code == 200
rjson = await resp.json
assert isinstance(rjson, dict)
assert rjson['id'] == guild_id
assert rjson["id"] == guild_id
return rjson
@ -58,7 +58,7 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
async def test_guild_fetch(test_cli_staff):
"""Test the creation and fetching of a guild via the Admin API."""
rjson = await _create_guild(test_cli_staff)
guild_id = rjson['id']
guild_id = rjson["id"]
try:
await _fetch_guild(test_cli_staff, guild_id)
@ -70,8 +70,8 @@ async def test_guild_fetch(test_cli_staff):
async def test_guild_update(test_cli_staff):
"""Test the update of a guild via the Admin API."""
rjson = await _create_guild(test_cli_staff)
guild_id = rjson['id']
assert not rjson['unavailable']
guild_id = rjson["id"]
assert not rjson["unavailable"]
try:
# I believe setting up an entire gateway client registered to the guild
@ -79,19 +79,17 @@ async def test_guild_update(test_cli_staff):
# testing them. Yes, I know its a bad idea, but if someone has an easier
# way to write that, do send an MR.
resp = await test_cli_staff.patch(
f'/api/v6/admin/guilds/{guild_id}',
json={
'unavailable': True
})
f"/api/v6/admin/guilds/{guild_id}", json={"unavailable": True}
)
assert resp.status_code == 200
rjson = await resp.json
assert isinstance(rjson, dict)
assert rjson['id'] == guild_id
assert rjson['unavailable']
assert rjson["id"] == guild_id
assert rjson["unavailable"]
rjson = await _fetch_guild(test_cli_staff, guild_id)
assert rjson['unavailable']
assert rjson["unavailable"]
finally:
await delete_guild(int(guild_id), app_=test_cli_staff.app)
@ -100,20 +98,19 @@ async def test_guild_update(test_cli_staff):
async def test_guild_delete(test_cli_staff):
"""Test the update of a guild via the Admin API."""
rjson = await _create_guild(test_cli_staff)
guild_id = rjson['id']
guild_id = rjson["id"]
try:
resp = await test_cli_staff.delete(f'/api/v6/admin/guilds/{guild_id}')
resp = await test_cli_staff.delete(f"/api/v6/admin/guilds/{guild_id}")
assert resp.status_code == 204
resp = await _fetch_guild(
test_cli_staff, guild_id, ret_early=True)
resp = await _fetch_guild(test_cli_staff, guild_id, ret_early=True)
assert resp.status_code == 404
rjson = await resp.json
assert isinstance(rjson, dict)
assert rjson['error']
assert rjson['code'] == GuildNotFound.error_code
assert rjson["error"]
assert rjson["code"] == GuildNotFound.error_code
finally:
await delete_guild(int(guild_id), app_=test_cli_staff.app)

View File

@ -21,7 +21,7 @@ import pytest
async def _get_invs(test_cli):
resp = await test_cli.get('/api/v6/admin/instance/invites')
resp = await test_cli.get("/api/v6/admin/instance/invites")
assert resp.status_code == 200
rjson = await resp.json
@ -39,7 +39,7 @@ async def test_get_invites(test_cli_staff):
async def test_inv_delete_invalid(test_cli_staff):
"""Test errors happen when trying to delete a
non-existing instance invite."""
resp = await test_cli_staff.delete('/api/v6/admin/instance/invites/aaaaaa')
resp = await test_cli_staff.delete("/api/v6/admin/instance/invites/aaaaaa")
assert resp.status_code == 404
@ -48,21 +48,20 @@ async def test_inv_delete_invalid(test_cli_staff):
async def test_create_invite(test_cli_staff):
"""Test the creation of an instance invite, then listing it,
then deleting it."""
resp = await test_cli_staff.put('/api/v6/admin/instance/invites', json={
'max_uses': 1
})
resp = await test_cli_staff.put(
"/api/v6/admin/instance/invites", json={"max_uses": 1}
)
assert resp.status_code == 200
rjson = await resp.json
assert isinstance(rjson, dict)
code = rjson['code']
code = rjson["code"]
# assert that the invite is in the list
invites = await _get_invs(test_cli_staff)
assert any(inv['code'] == code for inv in invites)
assert any(inv["code"] == code for inv in invites)
# delete it, and assert it worked
resp = await test_cli_staff.delete(
f'/api/v6/admin/instance/invites/{code}')
resp = await test_cli_staff.delete(f"/api/v6/admin/instance/invites/{code}")
assert resp.status_code == 204

Some files were not shown because too many files have changed in this diff Show More