mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'master' into db-streamlining
This commit is contained in:
commit
81bf33c364
|
|
@ -20,7 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
from random import randint
|
from random import randint
|
||||||
from typing import Tuple
|
from typing import Tuple, Optional
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from asyncpg import UniqueViolationError
|
from asyncpg import UniqueViolationError
|
||||||
|
|
@ -145,7 +145,6 @@ async def hash_data(data: str, loop=None) -> str:
|
||||||
return hashed.decode()
|
return hashed.decode()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def check_username_usage(username: str, db=None):
|
async def check_username_usage(username: str, db=None):
|
||||||
"""Raise an error if too many people are with the same username."""
|
"""Raise an error if too many people are with the same username."""
|
||||||
db = db or app.db
|
db = db or app.db
|
||||||
|
|
@ -155,13 +154,52 @@ async def check_username_usage(username: str, db=None):
|
||||||
WHERE username = $1
|
WHERE username = $1
|
||||||
""", username)
|
""", username)
|
||||||
|
|
||||||
if same_username > 8000:
|
if same_username > 9000:
|
||||||
raise BadRequest('Too many people.', {
|
raise BadRequest('Too many people.', {
|
||||||
'username': 'Too many people used the same username. '
|
'username': 'Too many people used the same username. '
|
||||||
'Please choose another'
|
'Please choose another'
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _raw_discrim() -> str:
|
||||||
|
new_discrim = randint(1, 9999)
|
||||||
|
new_discrim = '%04d' % new_discrim
|
||||||
|
return new_discrim
|
||||||
|
|
||||||
|
|
||||||
|
async def roll_discrim(username: str, *, db=None) -> Optional[str]:
|
||||||
|
"""Roll a discriminator for a DiscordTag.
|
||||||
|
|
||||||
|
Tries to generate one 10 times.
|
||||||
|
|
||||||
|
Calls check_username_usage.
|
||||||
|
"""
|
||||||
|
db = db or app.db
|
||||||
|
|
||||||
|
# we shouldn't roll discrims for usernames
|
||||||
|
# that have been used too much.
|
||||||
|
await check_username_usage(username, db)
|
||||||
|
|
||||||
|
# max 10 times for a reroll
|
||||||
|
for _ in range(10):
|
||||||
|
# generate random discrim
|
||||||
|
discrim = _raw_discrim()
|
||||||
|
|
||||||
|
# check if anyone is with it
|
||||||
|
res = await db.fetchval("""
|
||||||
|
SELECT id
|
||||||
|
FROM users
|
||||||
|
WHERE username = $1 AND discriminator = $2
|
||||||
|
""", username, discrim)
|
||||||
|
|
||||||
|
# if no user is found with the (username, discrim)
|
||||||
|
# pair, then this is unique! return it.
|
||||||
|
if res is None:
|
||||||
|
return discrim
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def create_user(username: str, email: str, password: str,
|
async def create_user(username: str, email: str, password: str,
|
||||||
db=None, loop=None) -> Tuple[int, str]:
|
db=None, loop=None) -> Tuple[int, str]:
|
||||||
"""Create a single user.
|
"""Create a single user.
|
||||||
|
|
@ -173,16 +211,15 @@ async def create_user(username: str, email: str, password: str,
|
||||||
loop = loop or app.loop
|
loop = loop or app.loop
|
||||||
|
|
||||||
new_id = get_snowflake()
|
new_id = get_snowflake()
|
||||||
|
new_discrim = await roll_discrim(username, db=db)
|
||||||
|
|
||||||
# TODO: unified discrim generation based off username, that also includes
|
if new_discrim is None:
|
||||||
# the check_username_usage()
|
raise BadRequest('Unable to register.', {
|
||||||
new_discrim = randint(1, 9999)
|
'username': 'Too many people are with this username.'
|
||||||
new_discrim = '%04d' % new_discrim
|
})
|
||||||
|
|
||||||
pwd_hash = await hash_data(password, loop)
|
pwd_hash = await hash_data(password, loop)
|
||||||
|
|
||||||
await check_username_usage(username, db)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await db.execute("""
|
await db.execute("""
|
||||||
INSERT INTO users
|
INSERT INTO users
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,7 @@ async def create_message(channel_id: int, actual_guild_id: int,
|
||||||
|
|
||||||
data['nonce'],
|
data['nonce'],
|
||||||
MessageType.DEFAULT.value,
|
MessageType.DEFAULT.value,
|
||||||
data.get('embeds', [])
|
data.get('embeds') or []
|
||||||
)
|
)
|
||||||
|
|
||||||
return message_id
|
return message_id
|
||||||
|
|
@ -286,7 +286,7 @@ def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
|
||||||
has_files = len(files) > 0
|
has_files = len(files) > 0
|
||||||
|
|
||||||
embed_field = 'embeds' if use_embeds else 'embed'
|
embed_field = 'embeds' if use_embeds else 'embed'
|
||||||
has_embed = embed_field in payload
|
has_embed = embed_field in payload and payload.get(embed_field) is not None
|
||||||
|
|
||||||
has_total_content = has_content or has_embed or has_files
|
has_total_content = has_content or has_embed or has_files
|
||||||
|
|
||||||
|
|
@ -405,7 +405,8 @@ async def _create_message(channel_id):
|
||||||
|
|
||||||
# fill_embed takes care of filling proxy and width/height
|
# fill_embed takes care of filling proxy and width/height
|
||||||
'embeds': ([await fill_embed(j['embed'])]
|
'embeds': ([await fill_embed(j['embed'])]
|
||||||
if 'embed' in j else []),
|
if j.get('embed') is not None
|
||||||
|
else []),
|
||||||
})
|
})
|
||||||
|
|
||||||
# for each file given, we add it as an attachment
|
# for each file given, we add it as an attachment
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import random
|
|
||||||
from os import urandom
|
from os import urandom
|
||||||
|
|
||||||
from asyncpg import UniqueViolationError
|
from asyncpg import UniqueViolationError
|
||||||
|
|
@ -28,7 +27,9 @@ from ..errors import Forbidden, BadRequest, Unauthorized
|
||||||
from ..schemas import validate, USER_UPDATE, GET_MENTIONS
|
from ..schemas import validate, USER_UPDATE, GET_MENTIONS
|
||||||
|
|
||||||
from .guilds import guild_check
|
from .guilds import guild_check
|
||||||
from litecord.auth import token_check, hash_data, check_username_usage
|
from litecord.auth import (
|
||||||
|
token_check, hash_data, check_username_usage, roll_discrim
|
||||||
|
)
|
||||||
from litecord.blueprints.guild.mod import remove_member
|
from litecord.blueprints.guild.mod import remove_member
|
||||||
|
|
||||||
from litecord.enums import PremiumType
|
from litecord.enums import PremiumType
|
||||||
|
|
@ -110,36 +111,6 @@ async def get_other(target_id):
|
||||||
return jsonify(other)
|
return jsonify(other)
|
||||||
|
|
||||||
|
|
||||||
async def _try_reroll(user_id, preferred_username: str = None):
|
|
||||||
for _ in range(10):
|
|
||||||
reroll = str(random.randint(1, 9999))
|
|
||||||
|
|
||||||
if preferred_username:
|
|
||||||
existing_uid = await app.db.fetchrow("""
|
|
||||||
SELECT user_id
|
|
||||||
FROM users
|
|
||||||
WHERE username = $1 AND discriminator = $2
|
|
||||||
""", preferred_username, reroll)
|
|
||||||
|
|
||||||
if not existing_uid:
|
|
||||||
return reroll
|
|
||||||
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
await app.db.execute("""
|
|
||||||
UPDATE users
|
|
||||||
SET discriminator = $1
|
|
||||||
WHERE users.id = $2
|
|
||||||
""", reroll, user_id)
|
|
||||||
|
|
||||||
return reroll
|
|
||||||
except UniqueViolationError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
async def _try_username_patch(user_id, new_username: str) -> str:
|
async def _try_username_patch(user_id, new_username: str) -> str:
|
||||||
await check_username_usage(new_username)
|
await check_username_usage(new_username)
|
||||||
discrim = None
|
discrim = None
|
||||||
|
|
@ -157,7 +128,7 @@ async def _try_username_patch(user_id, new_username: str) -> str:
|
||||||
WHERE users.id = $1
|
WHERE users.id = $1
|
||||||
""", user_id)
|
""", user_id)
|
||||||
except UniqueViolationError:
|
except UniqueViolationError:
|
||||||
discrim = await _try_reroll(user_id, new_username)
|
discrim = await roll_discrim(new_username)
|
||||||
|
|
||||||
if not discrim:
|
if not discrim:
|
||||||
raise BadRequest('Unable to change username', {
|
raise BadRequest('Unable to change username', {
|
||||||
|
|
|
||||||
|
|
@ -185,6 +185,9 @@ async def fetch_embed(url, *, config=None, session=None) -> dict:
|
||||||
|
|
||||||
async def fill_embed(embed: Embed) -> Embed:
|
async def fill_embed(embed: Embed) -> Embed:
|
||||||
"""Fill an embed with more information, such as proxy URLs."""
|
"""Fill an embed with more information, such as proxy URLs."""
|
||||||
|
if embed is None:
|
||||||
|
return
|
||||||
|
|
||||||
embed = sanitize_embed(embed)
|
embed = sanitize_embed(embed)
|
||||||
|
|
||||||
if path_exists(embed, 'footer.icon_url'):
|
if path_exists(embed, 'footer.icon_url'):
|
||||||
|
|
|
||||||
|
|
@ -434,7 +434,8 @@ MESSAGE_CREATE = {
|
||||||
'embed': {
|
'embed': {
|
||||||
'type': 'dict',
|
'type': 'dict',
|
||||||
'schema': EMBED_OBJECT,
|
'schema': EMBED_OBJECT,
|
||||||
'required': False
|
'required': False,
|
||||||
|
'nullable': True
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,10 @@ def _test_app(unused_tcp_port, event_loop):
|
||||||
main_app.config['WS_PORT'] = ws_port
|
main_app.config['WS_PORT'] = ws_port
|
||||||
main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}'
|
main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}'
|
||||||
|
|
||||||
|
# testing user creations requires hardcoding this to true
|
||||||
|
# on testing
|
||||||
|
main_app.config['REGISTRATIONS'] = True
|
||||||
|
|
||||||
# make sure we're calling the before_serving hooks
|
# make sure we're calling the before_serving hooks
|
||||||
event_loop.run_until_complete(main_app.startup())
|
event_loop.run_until_complete(main_app.startup())
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import secrets
|
||||||
|
|
||||||
from tests.common import login, get_uid
|
from tests.common import login, get_uid
|
||||||
|
|
||||||
|
|
@ -33,6 +34,14 @@ async def test_get_me(test_cli):
|
||||||
rjson = await resp.json
|
rjson = await resp.json
|
||||||
assert isinstance(rjson, dict)
|
assert isinstance(rjson, dict)
|
||||||
|
|
||||||
|
# incomplete user assertions, but should be enough
|
||||||
|
assert isinstance(rjson['id'], str)
|
||||||
|
assert isinstance(rjson['username'], str)
|
||||||
|
assert isinstance(rjson['discriminator'], str)
|
||||||
|
assert rjson['avatar'] is None or isinstance(rjson['avatar'], str)
|
||||||
|
assert isinstance(rjson['flags'], int)
|
||||||
|
assert isinstance(rjson['bot'], bool)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_me_guilds(test_cli):
|
async def test_get_me_guilds(test_cli):
|
||||||
|
|
@ -63,3 +72,46 @@ async def test_get_profile_self(test_cli):
|
||||||
assert (rjson['premium_since'] is None
|
assert (rjson['premium_since'] is None
|
||||||
or isinstance(rjson['premium_since'], str))
|
or isinstance(rjson['premium_since'], str))
|
||||||
assert isinstance(rjson['mutual_guilds'], list)
|
assert isinstance(rjson['mutual_guilds'], list)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_user(test_cli):
|
||||||
|
"""Test the creation and deletion of a user."""
|
||||||
|
username = secrets.token_hex(4)
|
||||||
|
_email = secrets.token_hex(5)
|
||||||
|
email = f'{_email}@{_email}.com'
|
||||||
|
password = secrets.token_hex(6)
|
||||||
|
|
||||||
|
resp = await test_cli.post('/api/v6/auth/register', json={
|
||||||
|
'username': username,
|
||||||
|
'email': email,
|
||||||
|
'password': password
|
||||||
|
})
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
rjson = await resp.json
|
||||||
|
|
||||||
|
assert isinstance(rjson, dict)
|
||||||
|
token = rjson['token']
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
resp = await test_cli.get('/api/v6/users/@me', headers={
|
||||||
|
'Authorization': token,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
rjson = await resp.json
|
||||||
|
assert rjson['username'] == username
|
||||||
|
assert rjson['email'] == email
|
||||||
|
|
||||||
|
resp = await test_cli.post('/api/v6/users/@me/delete', headers={
|
||||||
|
'Authorization': token,
|
||||||
|
}, json={
|
||||||
|
'password': password
|
||||||
|
})
|
||||||
|
|
||||||
|
assert resp.status_code == 204
|
||||||
|
|
||||||
|
await test_cli.app.db.execute("""
|
||||||
|
DELETE FROM users WHERE id = $1
|
||||||
|
""", int(rjson['id']))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue