Merge branch 'unified-discrim-gen' into 'master'

Unified discrim gen

See merge request litecord/litecord!30
This commit is contained in:
Luna 2019-04-04 00:14:25 +00:00
commit 1679f63217
4 changed files with 106 additions and 42 deletions

View File

@ -20,7 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import base64 import base64
import binascii import binascii
from random import randint from random import randint
from typing import Tuple from typing import Tuple, Optional
import bcrypt import bcrypt
from asyncpg import UniqueViolationError from asyncpg import UniqueViolationError
@ -145,7 +145,6 @@ async def hash_data(data: str, loop=None) -> str:
return hashed.decode() return hashed.decode()
async def check_username_usage(username: str, db=None): async def check_username_usage(username: str, db=None):
"""Raise an error if too many people are with the same username.""" """Raise an error if too many people are with the same username."""
db = db or app.db db = db or app.db
@ -155,13 +154,52 @@ async def check_username_usage(username: str, db=None):
WHERE username = $1 WHERE username = $1
""", username) """, username)
if same_username > 8000: if same_username > 9000:
raise BadRequest('Too many people.', { raise BadRequest('Too many people.', {
'username': 'Too many people used the same username. ' 'username': 'Too many people used the same username. '
'Please choose another' 'Please choose another'
}) })
def _raw_discrim() -> str:
new_discrim = randint(1, 9999)
new_discrim = '%04d' % new_discrim
return new_discrim
async def roll_discrim(username: str, *, db=None) -> Optional[str]:
"""Roll a discriminator for a DiscordTag.
Tries to generate one 10 times.
Calls check_username_usage.
"""
db = db or app.db
# we shouldn't roll discrims for usernames
# that have been used too much.
await check_username_usage(username, db)
# max 10 times for a reroll
for _ in range(10):
# generate random discrim
discrim = _raw_discrim()
# check if anyone is with it
res = await db.fetchval("""
SELECT id
FROM users
WHERE username = $1 AND discriminator = $2
""", username, discrim)
# if no user is found with the (username, discrim)
# pair, then this is unique! return it.
if res is None:
return discrim
return None
async def create_user(username: str, email: str, password: str, async def create_user(username: str, email: str, password: str,
db=None, loop=None) -> Tuple[int, str]: db=None, loop=None) -> Tuple[int, str]:
"""Create a single user. """Create a single user.
@ -173,16 +211,15 @@ async def create_user(username: str, email: str, password: str,
loop = loop or app.loop loop = loop or app.loop
new_id = get_snowflake() new_id = get_snowflake()
new_discrim = await roll_discrim(username, db=db)
# TODO: unified discrim generation based off username, that also includes if new_discrim is None:
# the check_username_usage() raise BadRequest('Unable to register.', {
new_discrim = randint(1, 9999) 'username': 'Too many people are with this username.'
new_discrim = '%04d' % new_discrim })
pwd_hash = await hash_data(password, loop) pwd_hash = await hash_data(password, loop)
await check_username_usage(username, db)
try: try:
await db.execute(""" await db.execute("""
INSERT INTO users INSERT INTO users

View File

@ -17,7 +17,6 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
import random
from os import urandom from os import urandom
from asyncpg import UniqueViolationError from asyncpg import UniqueViolationError
@ -28,7 +27,9 @@ from ..errors import Forbidden, BadRequest, Unauthorized
from ..schemas import validate, USER_UPDATE, GET_MENTIONS from ..schemas import validate, USER_UPDATE, GET_MENTIONS
from .guilds import guild_check from .guilds import guild_check
from litecord.auth import token_check, hash_data, check_username_usage from litecord.auth import (
token_check, hash_data, check_username_usage, roll_discrim
)
from litecord.blueprints.guild.mod import remove_member from litecord.blueprints.guild.mod import remove_member
from litecord.enums import PremiumType from litecord.enums import PremiumType
@ -110,36 +111,6 @@ async def get_other(target_id):
return jsonify(other) return jsonify(other)
async def _try_reroll(user_id, preferred_username: str = None):
for _ in range(10):
reroll = str(random.randint(1, 9999))
if preferred_username:
existing_uid = await app.db.fetchrow("""
SELECT user_id
FROM users
WHERE username = $1 AND discriminator = $2
""", preferred_username, reroll)
if not existing_uid:
return reroll
continue
try:
await app.db.execute("""
UPDATE users
SET discriminator = $1
WHERE users.id = $2
""", reroll, user_id)
return reroll
except UniqueViolationError:
continue
return
async def _try_username_patch(user_id, new_username: str) -> str: async def _try_username_patch(user_id, new_username: str) -> str:
await check_username_usage(new_username) await check_username_usage(new_username)
discrim = None discrim = None
@ -157,7 +128,7 @@ async def _try_username_patch(user_id, new_username: str) -> str:
WHERE users.id = $1 WHERE users.id = $1
""", user_id) """, user_id)
except UniqueViolationError: except UniqueViolationError:
discrim = await _try_reroll(user_id, new_username) discrim = await roll_discrim(new_username)
if not discrim: if not discrim:
raise BadRequest('Unable to change username', { raise BadRequest('Unable to change username', {

View File

@ -43,6 +43,10 @@ def _test_app(unused_tcp_port, event_loop):
main_app.config['WS_PORT'] = ws_port main_app.config['WS_PORT'] = ws_port
main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}' main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}'
# testing user creations requires hardcoding this to true
# on testing
main_app.config['REGISTRATIONS'] = True
# make sure we're calling the before_serving hooks # make sure we're calling the before_serving hooks
event_loop.run_until_complete(main_app.startup()) event_loop.run_until_complete(main_app.startup())

View File

@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
import pytest import pytest
import secrets
from tests.common import login, get_uid from tests.common import login, get_uid
@ -33,6 +34,14 @@ async def test_get_me(test_cli):
rjson = await resp.json rjson = await resp.json
assert isinstance(rjson, dict) assert isinstance(rjson, dict)
# incomplete user assertions, but should be enough
assert isinstance(rjson['id'], str)
assert isinstance(rjson['username'], str)
assert isinstance(rjson['discriminator'], str)
assert rjson['avatar'] is None or isinstance(rjson['avatar'], str)
assert isinstance(rjson['flags'], int)
assert isinstance(rjson['bot'], bool)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_me_guilds(test_cli): async def test_get_me_guilds(test_cli):
@ -63,3 +72,46 @@ async def test_get_profile_self(test_cli):
assert (rjson['premium_since'] is None assert (rjson['premium_since'] is None
or isinstance(rjson['premium_since'], str)) or isinstance(rjson['premium_since'], str))
assert isinstance(rjson['mutual_guilds'], list) assert isinstance(rjson['mutual_guilds'], list)
@pytest.mark.asyncio
async def test_create_user(test_cli):
"""Test the creation and deletion of a user."""
username = secrets.token_hex(4)
_email = secrets.token_hex(5)
email = f'{_email}@{_email}.com'
password = secrets.token_hex(6)
resp = await test_cli.post('/api/v6/auth/register', json={
'username': username,
'email': email,
'password': password
})
assert resp.status_code == 200
rjson = await resp.json
assert isinstance(rjson, dict)
token = rjson['token']
assert isinstance(token, str)
resp = await test_cli.get('/api/v6/users/@me', headers={
'Authorization': token,
})
assert resp.status_code == 200
rjson = await resp.json
assert rjson['username'] == username
assert rjson['email'] == email
resp = await test_cli.post('/api/v6/users/@me/delete', headers={
'Authorization': token,
}, json={
'password': password
})
assert resp.status_code == 204
await test_cli.app.db.execute("""
DELETE FROM users WHERE id = $1
""", int(rjson['id']))