diff --git a/litecord/auth.py b/litecord/auth.py index b52e1a1..3c52ddf 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -20,7 +20,7 @@ along with this program. If not, see . import base64 import binascii from random import randint -from typing import Tuple +from typing import Tuple, Optional import bcrypt from asyncpg import UniqueViolationError @@ -145,7 +145,6 @@ async def hash_data(data: str, loop=None) -> str: return hashed.decode() - async def check_username_usage(username: str, db=None): """Raise an error if too many people are with the same username.""" db = db or app.db @@ -155,13 +154,52 @@ async def check_username_usage(username: str, db=None): WHERE username = $1 """, username) - if same_username > 8000: + if same_username > 9000: raise BadRequest('Too many people.', { 'username': 'Too many people used the same username. ' 'Please choose another' }) +def _raw_discrim() -> str: + new_discrim = randint(1, 9999) + new_discrim = '%04d' % new_discrim + return new_discrim + + +async def roll_discrim(username: str, *, db=None) -> Optional[str]: + """Roll a discriminator for a DiscordTag. + + Tries to generate one 10 times. + + Calls check_username_usage. + """ + db = db or app.db + + # we shouldn't roll discrims for usernames + # that have been used too much. + await check_username_usage(username, db) + + # max 10 times for a reroll + for _ in range(10): + # generate random discrim + discrim = _raw_discrim() + + # check if anyone is with it + res = await db.fetchval(""" + SELECT id + FROM users + WHERE username = $1 AND discriminator = $2 + """, username, discrim) + + # if no user is found with the (username, discrim) + # pair, then this is unique! return it. + if res is None: + return discrim + + return None + + async def create_user(username: str, email: str, password: str, db=None, loop=None) -> Tuple[int, str]: """Create a single user. @@ -173,16 +211,15 @@ async def create_user(username: str, email: str, password: str, loop = loop or app.loop new_id = get_snowflake() + new_discrim = await roll_discrim(username, db=db) - # TODO: unified discrim generation based off username, that also includes - # the check_username_usage() - new_discrim = randint(1, 9999) - new_discrim = '%04d' % new_discrim + if new_discrim is None: + raise BadRequest('Unable to register.', { + 'username': 'Too many people are with this username.' + }) pwd_hash = await hash_data(password, loop) - await check_username_usage(username, db) - try: await db.execute(""" INSERT INTO users diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 3789eda..a3d7672 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -17,7 +17,6 @@ along with this program. If not, see . """ -import random from os import urandom from asyncpg import UniqueViolationError @@ -28,7 +27,9 @@ from ..errors import Forbidden, BadRequest, Unauthorized from ..schemas import validate, USER_UPDATE, GET_MENTIONS from .guilds import guild_check -from litecord.auth import token_check, hash_data, check_username_usage +from litecord.auth import ( + token_check, hash_data, check_username_usage, roll_discrim +) from litecord.blueprints.guild.mod import remove_member from litecord.enums import PremiumType @@ -110,36 +111,6 @@ async def get_other(target_id): return jsonify(other) -async def _try_reroll(user_id, preferred_username: str = None): - for _ in range(10): - reroll = str(random.randint(1, 9999)) - - if preferred_username: - existing_uid = await app.db.fetchrow(""" - SELECT user_id - FROM users - WHERE username = $1 AND discriminator = $2 - """, preferred_username, reroll) - - if not existing_uid: - return reroll - - continue - - try: - await app.db.execute(""" - UPDATE users - SET discriminator = $1 - WHERE users.id = $2 - """, reroll, user_id) - - return reroll - except UniqueViolationError: - continue - - return - - async def _try_username_patch(user_id, new_username: str) -> str: await check_username_usage(new_username) discrim = None @@ -157,7 +128,7 @@ async def _try_username_patch(user_id, new_username: str) -> str: WHERE users.id = $1 """, user_id) except UniqueViolationError: - discrim = await _try_reroll(user_id, new_username) + discrim = await roll_discrim(new_username) if not discrim: raise BadRequest('Unable to change username', { diff --git a/tests/conftest.py b/tests/conftest.py index bb8b171..48c5726 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -43,6 +43,10 @@ def _test_app(unused_tcp_port, event_loop): main_app.config['WS_PORT'] = ws_port main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}' + # testing user creations requires hardcoding this to true + # on testing + main_app.config['REGISTRATIONS'] = True + # make sure we're calling the before_serving hooks event_loop.run_until_complete(main_app.startup()) diff --git a/tests/test_user.py b/tests/test_user.py index 3e00c40..7a851a0 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -18,6 +18,7 @@ along with this program. If not, see . """ import pytest +import secrets from tests.common import login, get_uid @@ -33,6 +34,14 @@ async def test_get_me(test_cli): rjson = await resp.json assert isinstance(rjson, dict) + # incomplete user assertions, but should be enough + assert isinstance(rjson['id'], str) + assert isinstance(rjson['username'], str) + assert isinstance(rjson['discriminator'], str) + assert rjson['avatar'] is None or isinstance(rjson['avatar'], str) + assert isinstance(rjson['flags'], int) + assert isinstance(rjson['bot'], bool) + @pytest.mark.asyncio async def test_get_me_guilds(test_cli): @@ -63,3 +72,46 @@ async def test_get_profile_self(test_cli): assert (rjson['premium_since'] is None or isinstance(rjson['premium_since'], str)) assert isinstance(rjson['mutual_guilds'], list) + + +@pytest.mark.asyncio +async def test_create_user(test_cli): + """Test the creation and deletion of a user.""" + username = secrets.token_hex(4) + _email = secrets.token_hex(5) + email = f'{_email}@{_email}.com' + password = secrets.token_hex(6) + + resp = await test_cli.post('/api/v6/auth/register', json={ + 'username': username, + 'email': email, + 'password': password + }) + + assert resp.status_code == 200 + rjson = await resp.json + + assert isinstance(rjson, dict) + token = rjson['token'] + assert isinstance(token, str) + + resp = await test_cli.get('/api/v6/users/@me', headers={ + 'Authorization': token, + }) + + assert resp.status_code == 200 + rjson = await resp.json + assert rjson['username'] == username + assert rjson['email'] == email + + resp = await test_cli.post('/api/v6/users/@me/delete', headers={ + 'Authorization': token, + }, json={ + 'password': password + }) + + assert resp.status_code == 204 + + await test_cli.app.db.execute(""" + DELETE FROM users WHERE id = $1 + """, int(rjson['id']))