From ff1469f05f70fcb1155aba3840ed548a0c4ecb3b Mon Sep 17 00:00:00 2001 From: Luna Date: Sun, 31 Mar 2019 18:59:38 -0300 Subject: [PATCH] litecord.auth: add roll_discrim() this should fix any issues arising when genning discriminators at register-time. - litecord.blueprints.users: use roll_discrim in favour of _try_reroll --- litecord/auth.py | 55 ++++++++++++++++++++++++++++++------ litecord/blueprints/users.py | 36 +++-------------------- 2 files changed, 50 insertions(+), 41 deletions(-) diff --git a/litecord/auth.py b/litecord/auth.py index b52e1a1..3c52ddf 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -20,7 +20,7 @@ along with this program. If not, see . import base64 import binascii from random import randint -from typing import Tuple +from typing import Tuple, Optional import bcrypt from asyncpg import UniqueViolationError @@ -145,7 +145,6 @@ async def hash_data(data: str, loop=None) -> str: return hashed.decode() - async def check_username_usage(username: str, db=None): """Raise an error if too many people are with the same username.""" db = db or app.db @@ -155,13 +154,52 @@ async def check_username_usage(username: str, db=None): WHERE username = $1 """, username) - if same_username > 8000: + if same_username > 9000: raise BadRequest('Too many people.', { 'username': 'Too many people used the same username. ' 'Please choose another' }) +def _raw_discrim() -> str: + new_discrim = randint(1, 9999) + new_discrim = '%04d' % new_discrim + return new_discrim + + +async def roll_discrim(username: str, *, db=None) -> Optional[str]: + """Roll a discriminator for a DiscordTag. + + Tries to generate one 10 times. + + Calls check_username_usage. + """ + db = db or app.db + + # we shouldn't roll discrims for usernames + # that have been used too much. + await check_username_usage(username, db) + + # max 10 times for a reroll + for _ in range(10): + # generate random discrim + discrim = _raw_discrim() + + # check if anyone is with it + res = await db.fetchval(""" + SELECT id + FROM users + WHERE username = $1 AND discriminator = $2 + """, username, discrim) + + # if no user is found with the (username, discrim) + # pair, then this is unique! return it. + if res is None: + return discrim + + return None + + async def create_user(username: str, email: str, password: str, db=None, loop=None) -> Tuple[int, str]: """Create a single user. @@ -173,16 +211,15 @@ async def create_user(username: str, email: str, password: str, loop = loop or app.loop new_id = get_snowflake() + new_discrim = await roll_discrim(username, db=db) - # TODO: unified discrim generation based off username, that also includes - # the check_username_usage() - new_discrim = randint(1, 9999) - new_discrim = '%04d' % new_discrim + if new_discrim is None: + raise BadRequest('Unable to register.', { + 'username': 'Too many people are with this username.' + }) pwd_hash = await hash_data(password, loop) - await check_username_usage(username, db) - try: await db.execute(""" INSERT INTO users diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 3789eda..8727734 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -28,7 +28,9 @@ from ..errors import Forbidden, BadRequest, Unauthorized from ..schemas import validate, USER_UPDATE, GET_MENTIONS from .guilds import guild_check -from litecord.auth import token_check, hash_data, check_username_usage +from litecord.auth import ( + token_check, hash_data, check_username_usage, roll_discrim +) from litecord.blueprints.guild.mod import remove_member from litecord.enums import PremiumType @@ -110,36 +112,6 @@ async def get_other(target_id): return jsonify(other) -async def _try_reroll(user_id, preferred_username: str = None): - for _ in range(10): - reroll = str(random.randint(1, 9999)) - - if preferred_username: - existing_uid = await app.db.fetchrow(""" - SELECT user_id - FROM users - WHERE username = $1 AND discriminator = $2 - """, preferred_username, reroll) - - if not existing_uid: - return reroll - - continue - - try: - await app.db.execute(""" - UPDATE users - SET discriminator = $1 - WHERE users.id = $2 - """, reroll, user_id) - - return reroll - except UniqueViolationError: - continue - - return - - async def _try_username_patch(user_id, new_username: str) -> str: await check_username_usage(new_username) discrim = None @@ -157,7 +129,7 @@ async def _try_username_patch(user_id, new_username: str) -> str: WHERE users.id = $1 """, user_id) except UniqueViolationError: - discrim = await _try_reroll(user_id, new_username) + discrim = await roll_discrim(new_username) if not discrim: raise BadRequest('Unable to change username', {