diff --git a/litecord/blueprints/auth.py b/litecord/blueprints/auth.py index 493d96c..17e77f3 100644 --- a/litecord/blueprints/auth.py +++ b/litecord/blueprints/auth.py @@ -43,15 +43,35 @@ def make_token(user_id, user_pwd_hash) -> str: return signer.sign(user_id).decode() +async def check_username_usage(username: str): + """Raise an error if too many people are with the same username.""" + same_username = await app.db.fetchval(""" + SELECT COUNT(*) + FROM users + WHERE username = $1 + """, username) + + if same_username > 8000: + raise BadRequest('Too many people.', { + 'username': 'Too many people used the same username. ' + 'Please choose another' + }) + + @bp.route('/register', methods=['POST']) async def register(): j = await request.get_json() email, password, username = j['email'], j['password'], j['username'] new_id = get_snowflake() + new_discrim = str(random.randint(1, 9999)) + new_discrim = '%04d' % new_discrim + pwd_hash = await hash_data(password) + await check_username_usage(username) + try: await app.db.execute(""" INSERT INTO users (id, email, username, diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index a5670cc..f1b094f 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -1,13 +1,17 @@ +import random + from quart import Blueprint, jsonify, request, current_app as app from asyncpg import UniqueViolationError from ..auth import token_check from ..snowflake import get_snowflake -from ..errors import Forbidden, BadRequest -from ..schemas import validate, USER_SETTINGS, CREATE_DM, CREATE_GROUP_DM +from ..errors import Forbidden, BadRequest, Unauthorized +from ..schemas import validate, USER_SETTINGS, \ + CREATE_DM, CREATE_GROUP_DM, USER_UPDATE from ..enums import ChannelType, RelationshipType from .guilds import guild_check +from .auth import hash_data, check_password, check_username_usage bp = Blueprint('user', __name__) @@ -37,28 +41,157 @@ async def get_other(target_id): return jsonify(other) +async def _try_reroll(user_id, preferred_username: str = None): + for _ in range(10): + reroll = str(random.randint(1, 9999)) + + if preferred_username: + existing_uid = await app.db.fetchrow(""" + SELECT user_id + FROM users + WHERE preferred_username = $1 AND discriminator = $2 + """, preferred_username, reroll) + + if not existing_uid: + return reroll + + continue + + try: + await app.db.execute(""" + UPDATE users + SET discriminator = $1 + WHERE users.id = $2 + """, reroll, user_id) + + return reroll + except UniqueViolationError: + continue + + return + + +async def _try_username_patch(user_id, new_username: str) -> str: + await check_username_usage(new_username) + discrim = None + + try: + await app.db.execute(""" + UPDATE users + SET username = $1 + WHERE users.id = $2 + """, new_username, user_id) + + return await app.db.fetchval(""" + SELECT discriminator + FROM users + WHERE users.id = $1 + """, user_id) + except UniqueViolationError: + discrim = await _try_reroll(user_id, new_username) + + if not discrim: + raise BadRequest('Unable to change username', { + 'username': 'Too many people are with this username.' + }) + + await app.db.execute(""" + UPDATE users + SET username = $1, discriminator = $2 + WHERE users.id = $3 + """, new_username, discrim, user_id) + + return discrim + + +async def _try_discrim_patch(user_id, new_discrim: str): + try: + await app.db.execute(""" + UPDATE users + SET discriminator = $1 + WHERE id = $2 + """, new_discrim, user_id) + except UniqueViolationError: + raise BadRequest('Invalid discriminator', { + 'discriminator': 'Someone already used this discriminator.' + }) + + +def to_update(j: dict, user: dict, field: str): + return field in j and j[field] and j[field] != user[field] + + +async def _check_pass(j, user): + if not j['password']: + raise BadRequest('password required', { + 'password': 'password required' + }) + + phash = user['password_hash'] + + if not await check_password(phash, j['password']): + raise BadRequest('password incorrect', { + 'password': 'password does not match.' + }) + + @bp.route('/@me', methods=['PATCH']) async def patch_me(): """Patch the current user's information.""" user_id = await token_check() - j = await request.get_json() - - if not isinstance(j, dict): - raise BadRequest('Invalid payload') + j = validate(await request.get_json(), USER_UPDATE) user = await app.storage.get_user(user_id, True) - if 'username' in j: - try: - await app.db.execute(""" - UPDATE users - SET username = $1 - WHERE users.id = $2 - """, j['username'], user_id) - except UniqueViolationError: - raise BadRequest('Username already used.') + user['password_hash'] = await app.db.fetchval(""" + SELECT password_hash + FROM users + WHERE id = $1 + """, user_id) + if to_update(j, user, 'username'): + # this will take care of regenning a new discriminator + discrim = await _try_username_patch(user_id, j['username']) user['username'] = j['username'] + user['discriminator'] = discrim + + if to_update(j, user, 'discriminator'): + # the API treats discriminators as integers, + # but I work with strings on the database. + new_discrim = str(j['discriminator']) + + await _try_discrim_patch(user_id, new_discrim) + user['discriminator'] = new_discrim + + if to_update(j, user, 'email'): + await _check_pass(j, user) + + # TODO: reverify the new email? + await app.db.execute(""" + UPDATE users + SET email = $1 + WHERE id = $2 + """, j['email'], user_id) + user['email'] = j['email'] + + if 'avatar' in j: + # TODO: update icon + pass + + if 'new_password' in j and j['new_password']: + await _check_pass(j, user) + + new_hash = await hash_data(j['new_password']) + + await app.db.execute(""" + UPDATE users + SET password_hash = $1 + WHERE id = $2 + """, new_hash, user_id) + + # TODO: dispatch USER_UPDATE to guilds and users + await app.dispatcher.dispatch_user( + user_id, 'USER_UPDATE', user) return jsonify(user) diff --git a/litecord/enums.py b/litecord/enums.py index 99cc0e2..edc580b 100644 --- a/litecord/enums.py +++ b/litecord/enums.py @@ -1,5 +1,3 @@ -import ctypes - from enum import Enum @@ -46,28 +44,28 @@ class MessageActivityType(EasyEnum): JOIN_REQUEST = 5 -uint8 = ctypes.c_uint8 +class ActivityFlags: + instance = 1 + join = 2 + spectate = 4 + join_request = 8 + sync = 16 + play = 32 -# use ctypes to interpret the bits in activity flags -class ActivityFlagsBits(ctypes.LittleEndianStructure): - _fields_ = [ - ('instance', uint8, 1), - ('join', uint8, 1), - ('spectate', uint8, 1), - ('join_request', uint8, 1), - ('sync', uint8, 1), - ('play', uint8, 1), - ] +class UserFlags: + staff = 1 + partner = 2 + hypesquad = 4 + bug_hunter = 8 + mfa_sms = 16 + premium_dismissed = 32 + hsquad_house_1 = 64 + hsquad_house_2 = 128 + hsquad_house_3 = 256 -class ActivityFlags(ctypes.Union): - _anonymous_ = ('bit',) - - _fields_ = [ - ('bit', ActivityFlagsBits), - ('as_byte', uint8), - ] + premium_early = 512 class StatusType(EasyEnum): diff --git a/litecord/schemas.py b/litecord/schemas.py index a6f30dd..556b84a 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -12,6 +12,7 @@ log = Logger(__name__) USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A) EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$', re.A) +DATA_REGEX = re.compile(r'data\:image/(png|jpeg|gif);base64,(.+)', re.A) # collection of regexes @@ -27,6 +28,24 @@ class LitecordValidator(Validator): """Validate against the username regex.""" return bool(USERNAME_REGEX.match(value)) + def _validate_type_email(self, value: str) -> bool: + """Validate against the username regex.""" + return bool(EMAIL_REGEX.match(value)) + + def _validate_type_b64_icon(self, value: str) -> bool: + return bool(DATA_REGEX.match(value)) + + def _validate_type_discriminator(self, value: str) -> bool: + """Discriminators are numbers in the API + that can go from 0 to 9999. + """ + try: + discrim = int(value) + except (TypeError, ValueError): + return False + + return 0 < discrim <= 9999 + def _validate_type_snowflake(self, value: str) -> bool: try: int(value) @@ -82,6 +101,43 @@ def validate(reqjson, schema, raise_err: bool = True): return validator.document +USER_UPDATE = { + 'username': { + 'type': 'username', 'minlength': 2, + 'maxlength': 30, 'required': False}, + + 'discriminator': { + 'type': 'discriminator', + 'required': False, + 'nullable': True, + }, + + 'password': { + 'type': 'string', 'minlength': 0, + 'maxlength': 100, 'required': False, + }, + + 'new_password': { + 'type': 'string', 'minlength': 5, + 'maxlength': 100, 'required': False, + 'dependencies': 'password', + 'nullable': True + }, + + 'email': { + 'type': 'string', 'minlength': 2, + 'maxlength': 30, 'required': False, + 'dependencies': 'password', + }, + + 'avatar': { + 'type': 'b64_icon', 'required': False, + 'nullable': True + }, + +} + + GUILD_UPDATE = { 'name': { 'type': 'string', @@ -297,5 +353,5 @@ CREATE_GROUP_DM = { SPECIFIC_FRIEND = { 'username': {'type': 'username'}, - 'discriminator': {'type': 'number'} + 'discriminator': {'type': 'discriminator'} } diff --git a/litecord/storage.py b/litecord/storage.py index 307acd3..b6ae49e 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -96,6 +96,9 @@ class Storage: duser = dict(user_row) + duser['mobile'] = False + duser['phone'] = None + duser['premium'] = duser['premium_since'] is not None duser.pop('premium_since')