Merge branch 'master' into db-streamlining

This commit is contained in:
Luna 2019-04-04 13:37:32 -03:00
commit 81bf33c364
7 changed files with 115 additions and 46 deletions

View File

@ -20,7 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import base64 import base64
import binascii import binascii
from random import randint from random import randint
from typing import Tuple from typing import Tuple, Optional
import bcrypt import bcrypt
from asyncpg import UniqueViolationError from asyncpg import UniqueViolationError
@ -145,7 +145,6 @@ async def hash_data(data: str, loop=None) -> str:
return hashed.decode() return hashed.decode()
async def check_username_usage(username: str, db=None): async def check_username_usage(username: str, db=None):
"""Raise an error if too many people are with the same username.""" """Raise an error if too many people are with the same username."""
db = db or app.db db = db or app.db
@ -155,13 +154,52 @@ async def check_username_usage(username: str, db=None):
WHERE username = $1 WHERE username = $1
""", username) """, username)
if same_username > 8000: if same_username > 9000:
raise BadRequest('Too many people.', { raise BadRequest('Too many people.', {
'username': 'Too many people used the same username. ' 'username': 'Too many people used the same username. '
'Please choose another' 'Please choose another'
}) })
def _raw_discrim() -> str:
new_discrim = randint(1, 9999)
new_discrim = '%04d' % new_discrim
return new_discrim
async def roll_discrim(username: str, *, db=None) -> Optional[str]:
"""Roll a discriminator for a DiscordTag.
Tries to generate one 10 times.
Calls check_username_usage.
"""
db = db or app.db
# we shouldn't roll discrims for usernames
# that have been used too much.
await check_username_usage(username, db)
# max 10 times for a reroll
for _ in range(10):
# generate random discrim
discrim = _raw_discrim()
# check if anyone is with it
res = await db.fetchval("""
SELECT id
FROM users
WHERE username = $1 AND discriminator = $2
""", username, discrim)
# if no user is found with the (username, discrim)
# pair, then this is unique! return it.
if res is None:
return discrim
return None
async def create_user(username: str, email: str, password: str, async def create_user(username: str, email: str, password: str,
db=None, loop=None) -> Tuple[int, str]: db=None, loop=None) -> Tuple[int, str]:
"""Create a single user. """Create a single user.
@ -173,16 +211,15 @@ async def create_user(username: str, email: str, password: str,
loop = loop or app.loop loop = loop or app.loop
new_id = get_snowflake() new_id = get_snowflake()
new_discrim = await roll_discrim(username, db=db)
# TODO: unified discrim generation based off username, that also includes if new_discrim is None:
# the check_username_usage() raise BadRequest('Unable to register.', {
new_discrim = randint(1, 9999) 'username': 'Too many people are with this username.'
new_discrim = '%04d' % new_discrim })
pwd_hash = await hash_data(password, loop) pwd_hash = await hash_data(password, loop)
await check_username_usage(username, db)
try: try:
await db.execute(""" await db.execute("""
INSERT INTO users INSERT INTO users

View File

@ -191,7 +191,7 @@ async def create_message(channel_id: int, actual_guild_id: int,
data['nonce'], data['nonce'],
MessageType.DEFAULT.value, MessageType.DEFAULT.value,
data.get('embeds', []) data.get('embeds') or []
) )
return message_id return message_id
@ -286,7 +286,7 @@ def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
has_files = len(files) > 0 has_files = len(files) > 0
embed_field = 'embeds' if use_embeds else 'embed' embed_field = 'embeds' if use_embeds else 'embed'
has_embed = embed_field in payload has_embed = embed_field in payload and payload.get(embed_field) is not None
has_total_content = has_content or has_embed or has_files has_total_content = has_content or has_embed or has_files
@ -405,7 +405,8 @@ async def _create_message(channel_id):
# fill_embed takes care of filling proxy and width/height # fill_embed takes care of filling proxy and width/height
'embeds': ([await fill_embed(j['embed'])] 'embeds': ([await fill_embed(j['embed'])]
if 'embed' in j else []), if j.get('embed') is not None
else []),
}) })
# for each file given, we add it as an attachment # for each file given, we add it as an attachment

View File

@ -17,7 +17,6 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
import random
from os import urandom from os import urandom
from asyncpg import UniqueViolationError from asyncpg import UniqueViolationError
@ -28,7 +27,9 @@ from ..errors import Forbidden, BadRequest, Unauthorized
from ..schemas import validate, USER_UPDATE, GET_MENTIONS from ..schemas import validate, USER_UPDATE, GET_MENTIONS
from .guilds import guild_check from .guilds import guild_check
from litecord.auth import token_check, hash_data, check_username_usage from litecord.auth import (
token_check, hash_data, check_username_usage, roll_discrim
)
from litecord.blueprints.guild.mod import remove_member from litecord.blueprints.guild.mod import remove_member
from litecord.enums import PremiumType from litecord.enums import PremiumType
@ -110,36 +111,6 @@ async def get_other(target_id):
return jsonify(other) return jsonify(other)
async def _try_reroll(user_id, preferred_username: str = None):
for _ in range(10):
reroll = str(random.randint(1, 9999))
if preferred_username:
existing_uid = await app.db.fetchrow("""
SELECT user_id
FROM users
WHERE username = $1 AND discriminator = $2
""", preferred_username, reroll)
if not existing_uid:
return reroll
continue
try:
await app.db.execute("""
UPDATE users
SET discriminator = $1
WHERE users.id = $2
""", reroll, user_id)
return reroll
except UniqueViolationError:
continue
return
async def _try_username_patch(user_id, new_username: str) -> str: async def _try_username_patch(user_id, new_username: str) -> str:
await check_username_usage(new_username) await check_username_usage(new_username)
discrim = None discrim = None
@ -157,7 +128,7 @@ async def _try_username_patch(user_id, new_username: str) -> str:
WHERE users.id = $1 WHERE users.id = $1
""", user_id) """, user_id)
except UniqueViolationError: except UniqueViolationError:
discrim = await _try_reroll(user_id, new_username) discrim = await roll_discrim(new_username)
if not discrim: if not discrim:
raise BadRequest('Unable to change username', { raise BadRequest('Unable to change username', {

View File

@ -185,6 +185,9 @@ async def fetch_embed(url, *, config=None, session=None) -> dict:
async def fill_embed(embed: Embed) -> Embed: async def fill_embed(embed: Embed) -> Embed:
"""Fill an embed with more information, such as proxy URLs.""" """Fill an embed with more information, such as proxy URLs."""
if embed is None:
return
embed = sanitize_embed(embed) embed = sanitize_embed(embed)
if path_exists(embed, 'footer.icon_url'): if path_exists(embed, 'footer.icon_url'):

View File

@ -434,7 +434,8 @@ MESSAGE_CREATE = {
'embed': { 'embed': {
'type': 'dict', 'type': 'dict',
'schema': EMBED_OBJECT, 'schema': EMBED_OBJECT,
'required': False 'required': False,
'nullable': True
} }
} }

View File

@ -43,6 +43,10 @@ def _test_app(unused_tcp_port, event_loop):
main_app.config['WS_PORT'] = ws_port main_app.config['WS_PORT'] = ws_port
main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}' main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}'
# testing user creations requires hardcoding this to true
# on testing
main_app.config['REGISTRATIONS'] = True
# make sure we're calling the before_serving hooks # make sure we're calling the before_serving hooks
event_loop.run_until_complete(main_app.startup()) event_loop.run_until_complete(main_app.startup())

View File

@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
import pytest import pytest
import secrets
from tests.common import login, get_uid from tests.common import login, get_uid
@ -33,6 +34,14 @@ async def test_get_me(test_cli):
rjson = await resp.json rjson = await resp.json
assert isinstance(rjson, dict) assert isinstance(rjson, dict)
# incomplete user assertions, but should be enough
assert isinstance(rjson['id'], str)
assert isinstance(rjson['username'], str)
assert isinstance(rjson['discriminator'], str)
assert rjson['avatar'] is None or isinstance(rjson['avatar'], str)
assert isinstance(rjson['flags'], int)
assert isinstance(rjson['bot'], bool)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_me_guilds(test_cli): async def test_get_me_guilds(test_cli):
@ -63,3 +72,46 @@ async def test_get_profile_self(test_cli):
assert (rjson['premium_since'] is None assert (rjson['premium_since'] is None
or isinstance(rjson['premium_since'], str)) or isinstance(rjson['premium_since'], str))
assert isinstance(rjson['mutual_guilds'], list) assert isinstance(rjson['mutual_guilds'], list)
@pytest.mark.asyncio
async def test_create_user(test_cli):
"""Test the creation and deletion of a user."""
username = secrets.token_hex(4)
_email = secrets.token_hex(5)
email = f'{_email}@{_email}.com'
password = secrets.token_hex(6)
resp = await test_cli.post('/api/v6/auth/register', json={
'username': username,
'email': email,
'password': password
})
assert resp.status_code == 200
rjson = await resp.json
assert isinstance(rjson, dict)
token = rjson['token']
assert isinstance(token, str)
resp = await test_cli.get('/api/v6/users/@me', headers={
'Authorization': token,
})
assert resp.status_code == 200
rjson = await resp.json
assert rjson['username'] == username
assert rjson['email'] == email
resp = await test_cli.post('/api/v6/users/@me/delete', headers={
'Authorization': token,
}, json={
'password': password
})
assert resp.status_code == 204
await test_cli.app.db.execute("""
DELETE FROM users WHERE id = $1
""", int(rjson['id']))