diff --git a/litecord/auth.py b/litecord/auth.py
index b52e1a1..3c52ddf 100644
--- a/litecord/auth.py
+++ b/litecord/auth.py
@@ -20,7 +20,7 @@ along with this program. If not, see .
import base64
import binascii
from random import randint
-from typing import Tuple
+from typing import Tuple, Optional
import bcrypt
from asyncpg import UniqueViolationError
@@ -145,7 +145,6 @@ async def hash_data(data: str, loop=None) -> str:
return hashed.decode()
-
async def check_username_usage(username: str, db=None):
"""Raise an error if too many people are with the same username."""
db = db or app.db
@@ -155,13 +154,52 @@ async def check_username_usage(username: str, db=None):
WHERE username = $1
""", username)
- if same_username > 8000:
+ if same_username > 9000:
raise BadRequest('Too many people.', {
'username': 'Too many people used the same username. '
'Please choose another'
})
+def _raw_discrim() -> str:
+ new_discrim = randint(1, 9999)
+ new_discrim = '%04d' % new_discrim
+ return new_discrim
+
+
+async def roll_discrim(username: str, *, db=None) -> Optional[str]:
+ """Roll a discriminator for a DiscordTag.
+
+ Tries to generate one 10 times.
+
+ Calls check_username_usage.
+ """
+ db = db or app.db
+
+ # we shouldn't roll discrims for usernames
+ # that have been used too much.
+ await check_username_usage(username, db)
+
+ # max 10 times for a reroll
+ for _ in range(10):
+ # generate random discrim
+ discrim = _raw_discrim()
+
+ # check if anyone is with it
+ res = await db.fetchval("""
+ SELECT id
+ FROM users
+ WHERE username = $1 AND discriminator = $2
+ """, username, discrim)
+
+ # if no user is found with the (username, discrim)
+ # pair, then this is unique! return it.
+ if res is None:
+ return discrim
+
+ return None
+
+
async def create_user(username: str, email: str, password: str,
db=None, loop=None) -> Tuple[int, str]:
"""Create a single user.
@@ -173,16 +211,15 @@ async def create_user(username: str, email: str, password: str,
loop = loop or app.loop
new_id = get_snowflake()
+ new_discrim = await roll_discrim(username, db=db)
- # TODO: unified discrim generation based off username, that also includes
- # the check_username_usage()
- new_discrim = randint(1, 9999)
- new_discrim = '%04d' % new_discrim
+ if new_discrim is None:
+ raise BadRequest('Unable to register.', {
+ 'username': 'Too many people are with this username.'
+ })
pwd_hash = await hash_data(password, loop)
- await check_username_usage(username, db)
-
try:
await db.execute("""
INSERT INTO users
diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py
index 3789eda..a3d7672 100644
--- a/litecord/blueprints/users.py
+++ b/litecord/blueprints/users.py
@@ -17,7 +17,6 @@ along with this program. If not, see .
"""
-import random
from os import urandom
from asyncpg import UniqueViolationError
@@ -28,7 +27,9 @@ from ..errors import Forbidden, BadRequest, Unauthorized
from ..schemas import validate, USER_UPDATE, GET_MENTIONS
from .guilds import guild_check
-from litecord.auth import token_check, hash_data, check_username_usage
+from litecord.auth import (
+ token_check, hash_data, check_username_usage, roll_discrim
+)
from litecord.blueprints.guild.mod import remove_member
from litecord.enums import PremiumType
@@ -110,36 +111,6 @@ async def get_other(target_id):
return jsonify(other)
-async def _try_reroll(user_id, preferred_username: str = None):
- for _ in range(10):
- reroll = str(random.randint(1, 9999))
-
- if preferred_username:
- existing_uid = await app.db.fetchrow("""
- SELECT user_id
- FROM users
- WHERE username = $1 AND discriminator = $2
- """, preferred_username, reroll)
-
- if not existing_uid:
- return reroll
-
- continue
-
- try:
- await app.db.execute("""
- UPDATE users
- SET discriminator = $1
- WHERE users.id = $2
- """, reroll, user_id)
-
- return reroll
- except UniqueViolationError:
- continue
-
- return
-
-
async def _try_username_patch(user_id, new_username: str) -> str:
await check_username_usage(new_username)
discrim = None
@@ -157,7 +128,7 @@ async def _try_username_patch(user_id, new_username: str) -> str:
WHERE users.id = $1
""", user_id)
except UniqueViolationError:
- discrim = await _try_reroll(user_id, new_username)
+ discrim = await roll_discrim(new_username)
if not discrim:
raise BadRequest('Unable to change username', {
diff --git a/tests/conftest.py b/tests/conftest.py
index bb8b171..48c5726 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -43,6 +43,10 @@ def _test_app(unused_tcp_port, event_loop):
main_app.config['WS_PORT'] = ws_port
main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}'
+ # testing user creations requires hardcoding this to true
+ # on testing
+ main_app.config['REGISTRATIONS'] = True
+
# make sure we're calling the before_serving hooks
event_loop.run_until_complete(main_app.startup())
diff --git a/tests/test_user.py b/tests/test_user.py
index 3e00c40..7a851a0 100644
--- a/tests/test_user.py
+++ b/tests/test_user.py
@@ -18,6 +18,7 @@ along with this program. If not, see .
"""
import pytest
+import secrets
from tests.common import login, get_uid
@@ -33,6 +34,14 @@ async def test_get_me(test_cli):
rjson = await resp.json
assert isinstance(rjson, dict)
+ # incomplete user assertions, but should be enough
+ assert isinstance(rjson['id'], str)
+ assert isinstance(rjson['username'], str)
+ assert isinstance(rjson['discriminator'], str)
+ assert rjson['avatar'] is None or isinstance(rjson['avatar'], str)
+ assert isinstance(rjson['flags'], int)
+ assert isinstance(rjson['bot'], bool)
+
@pytest.mark.asyncio
async def test_get_me_guilds(test_cli):
@@ -63,3 +72,46 @@ async def test_get_profile_self(test_cli):
assert (rjson['premium_since'] is None
or isinstance(rjson['premium_since'], str))
assert isinstance(rjson['mutual_guilds'], list)
+
+
+@pytest.mark.asyncio
+async def test_create_user(test_cli):
+ """Test the creation and deletion of a user."""
+ username = secrets.token_hex(4)
+ _email = secrets.token_hex(5)
+ email = f'{_email}@{_email}.com'
+ password = secrets.token_hex(6)
+
+ resp = await test_cli.post('/api/v6/auth/register', json={
+ 'username': username,
+ 'email': email,
+ 'password': password
+ })
+
+ assert resp.status_code == 200
+ rjson = await resp.json
+
+ assert isinstance(rjson, dict)
+ token = rjson['token']
+ assert isinstance(token, str)
+
+ resp = await test_cli.get('/api/v6/users/@me', headers={
+ 'Authorization': token,
+ })
+
+ assert resp.status_code == 200
+ rjson = await resp.json
+ assert rjson['username'] == username
+ assert rjson['email'] == email
+
+ resp = await test_cli.post('/api/v6/users/@me/delete', headers={
+ 'Authorization': token,
+ }, json={
+ 'password': password
+ })
+
+ assert resp.status_code == 204
+
+ await test_cli.app.db.execute("""
+ DELETE FROM users WHERE id = $1
+ """, int(rjson['id']))