mirror of https://gitlab.com/litecord/litecord.git
blueprints.users: finish user patch impl
- blueprints.auth: check availability of username on register - enums: add UserFlags - schemas: add DATA_REGEX, USER_UPDATE - storage: add dummy mobile and phone values on get_user
This commit is contained in:
parent
9aec27203b
commit
051cdd8ff2
|
|
@ -43,15 +43,35 @@ def make_token(user_id, user_pwd_hash) -> str:
|
|||
return signer.sign(user_id).decode()
|
||||
|
||||
|
||||
async def check_username_usage(username: str):
|
||||
"""Raise an error if too many people are with the same username."""
|
||||
same_username = await app.db.fetchval("""
|
||||
SELECT COUNT(*)
|
||||
FROM users
|
||||
WHERE username = $1
|
||||
""", username)
|
||||
|
||||
if same_username > 8000:
|
||||
raise BadRequest('Too many people.', {
|
||||
'username': 'Too many people used the same username. '
|
||||
'Please choose another'
|
||||
})
|
||||
|
||||
|
||||
@bp.route('/register', methods=['POST'])
|
||||
async def register():
|
||||
j = await request.get_json()
|
||||
email, password, username = j['email'], j['password'], j['username']
|
||||
|
||||
new_id = get_snowflake()
|
||||
|
||||
new_discrim = str(random.randint(1, 9999))
|
||||
new_discrim = '%04d' % new_discrim
|
||||
|
||||
pwd_hash = await hash_data(password)
|
||||
|
||||
await check_username_usage(username)
|
||||
|
||||
try:
|
||||
await app.db.execute("""
|
||||
INSERT INTO users (id, email, username,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,17 @@
|
|||
import random
|
||||
|
||||
from quart import Blueprint, jsonify, request, current_app as app
|
||||
from asyncpg import UniqueViolationError
|
||||
|
||||
from ..auth import token_check
|
||||
from ..snowflake import get_snowflake
|
||||
from ..errors import Forbidden, BadRequest
|
||||
from ..schemas import validate, USER_SETTINGS, CREATE_DM, CREATE_GROUP_DM
|
||||
from ..errors import Forbidden, BadRequest, Unauthorized
|
||||
from ..schemas import validate, USER_SETTINGS, \
|
||||
CREATE_DM, CREATE_GROUP_DM, USER_UPDATE
|
||||
from ..enums import ChannelType, RelationshipType
|
||||
|
||||
from .guilds import guild_check
|
||||
from .auth import hash_data, check_password, check_username_usage
|
||||
|
||||
bp = Blueprint('user', __name__)
|
||||
|
||||
|
|
@ -37,28 +41,157 @@ async def get_other(target_id):
|
|||
return jsonify(other)
|
||||
|
||||
|
||||
@bp.route('/@me', methods=['PATCH'])
|
||||
async def patch_me():
|
||||
"""Patch the current user's information."""
|
||||
user_id = await token_check()
|
||||
j = await request.get_json()
|
||||
async def _try_reroll(user_id, preferred_username: str = None):
|
||||
for _ in range(10):
|
||||
reroll = str(random.randint(1, 9999))
|
||||
|
||||
if not isinstance(j, dict):
|
||||
raise BadRequest('Invalid payload')
|
||||
if preferred_username:
|
||||
existing_uid = await app.db.fetchrow("""
|
||||
SELECT user_id
|
||||
FROM users
|
||||
WHERE preferred_username = $1 AND discriminator = $2
|
||||
""", preferred_username, reroll)
|
||||
|
||||
user = await app.storage.get_user(user_id, True)
|
||||
if not existing_uid:
|
||||
return reroll
|
||||
|
||||
continue
|
||||
|
||||
try:
|
||||
await app.db.execute("""
|
||||
UPDATE users
|
||||
SET discriminator = $1
|
||||
WHERE users.id = $2
|
||||
""", reroll, user_id)
|
||||
|
||||
return reroll
|
||||
except UniqueViolationError:
|
||||
continue
|
||||
|
||||
return
|
||||
|
||||
|
||||
async def _try_username_patch(user_id, new_username: str) -> str:
|
||||
await check_username_usage(new_username)
|
||||
discrim = None
|
||||
|
||||
if 'username' in j:
|
||||
try:
|
||||
await app.db.execute("""
|
||||
UPDATE users
|
||||
SET username = $1
|
||||
WHERE users.id = $2
|
||||
""", j['username'], user_id)
|
||||
except UniqueViolationError:
|
||||
raise BadRequest('Username already used.')
|
||||
""", new_username, user_id)
|
||||
|
||||
return await app.db.fetchval("""
|
||||
SELECT discriminator
|
||||
FROM users
|
||||
WHERE users.id = $1
|
||||
""", user_id)
|
||||
except UniqueViolationError:
|
||||
discrim = await _try_reroll(user_id, new_username)
|
||||
|
||||
if not discrim:
|
||||
raise BadRequest('Unable to change username', {
|
||||
'username': 'Too many people are with this username.'
|
||||
})
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE users
|
||||
SET username = $1, discriminator = $2
|
||||
WHERE users.id = $3
|
||||
""", new_username, discrim, user_id)
|
||||
|
||||
return discrim
|
||||
|
||||
|
||||
async def _try_discrim_patch(user_id, new_discrim: str):
|
||||
try:
|
||||
await app.db.execute("""
|
||||
UPDATE users
|
||||
SET discriminator = $1
|
||||
WHERE id = $2
|
||||
""", new_discrim, user_id)
|
||||
except UniqueViolationError:
|
||||
raise BadRequest('Invalid discriminator', {
|
||||
'discriminator': 'Someone already used this discriminator.'
|
||||
})
|
||||
|
||||
|
||||
def to_update(j: dict, user: dict, field: str):
|
||||
return field in j and j[field] and j[field] != user[field]
|
||||
|
||||
|
||||
async def _check_pass(j, user):
|
||||
if not j['password']:
|
||||
raise BadRequest('password required', {
|
||||
'password': 'password required'
|
||||
})
|
||||
|
||||
phash = user['password_hash']
|
||||
|
||||
if not await check_password(phash, j['password']):
|
||||
raise BadRequest('password incorrect', {
|
||||
'password': 'password does not match.'
|
||||
})
|
||||
|
||||
|
||||
@bp.route('/@me', methods=['PATCH'])
|
||||
async def patch_me():
|
||||
"""Patch the current user's information."""
|
||||
user_id = await token_check()
|
||||
|
||||
j = validate(await request.get_json(), USER_UPDATE)
|
||||
user = await app.storage.get_user(user_id, True)
|
||||
|
||||
user['password_hash'] = await app.db.fetchval("""
|
||||
SELECT password_hash
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
|
||||
if to_update(j, user, 'username'):
|
||||
# this will take care of regenning a new discriminator
|
||||
discrim = await _try_username_patch(user_id, j['username'])
|
||||
user['username'] = j['username']
|
||||
user['discriminator'] = discrim
|
||||
|
||||
if to_update(j, user, 'discriminator'):
|
||||
# the API treats discriminators as integers,
|
||||
# but I work with strings on the database.
|
||||
new_discrim = str(j['discriminator'])
|
||||
|
||||
await _try_discrim_patch(user_id, new_discrim)
|
||||
user['discriminator'] = new_discrim
|
||||
|
||||
if to_update(j, user, 'email'):
|
||||
await _check_pass(j, user)
|
||||
|
||||
# TODO: reverify the new email?
|
||||
await app.db.execute("""
|
||||
UPDATE users
|
||||
SET email = $1
|
||||
WHERE id = $2
|
||||
""", j['email'], user_id)
|
||||
user['email'] = j['email']
|
||||
|
||||
if 'avatar' in j:
|
||||
# TODO: update icon
|
||||
pass
|
||||
|
||||
if 'new_password' in j and j['new_password']:
|
||||
await _check_pass(j, user)
|
||||
|
||||
new_hash = await hash_data(j['new_password'])
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE users
|
||||
SET password_hash = $1
|
||||
WHERE id = $2
|
||||
""", new_hash, user_id)
|
||||
|
||||
# TODO: dispatch USER_UPDATE to guilds and users
|
||||
await app.dispatcher.dispatch_user(
|
||||
user_id, 'USER_UPDATE', user)
|
||||
|
||||
return jsonify(user)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import ctypes
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
|
|
@ -46,28 +44,28 @@ class MessageActivityType(EasyEnum):
|
|||
JOIN_REQUEST = 5
|
||||
|
||||
|
||||
uint8 = ctypes.c_uint8
|
||||
class ActivityFlags:
|
||||
instance = 1
|
||||
join = 2
|
||||
spectate = 4
|
||||
join_request = 8
|
||||
sync = 16
|
||||
play = 32
|
||||
|
||||
|
||||
# use ctypes to interpret the bits in activity flags
|
||||
class ActivityFlagsBits(ctypes.LittleEndianStructure):
|
||||
_fields_ = [
|
||||
('instance', uint8, 1),
|
||||
('join', uint8, 1),
|
||||
('spectate', uint8, 1),
|
||||
('join_request', uint8, 1),
|
||||
('sync', uint8, 1),
|
||||
('play', uint8, 1),
|
||||
]
|
||||
class UserFlags:
|
||||
staff = 1
|
||||
partner = 2
|
||||
hypesquad = 4
|
||||
bug_hunter = 8
|
||||
mfa_sms = 16
|
||||
premium_dismissed = 32
|
||||
|
||||
hsquad_house_1 = 64
|
||||
hsquad_house_2 = 128
|
||||
hsquad_house_3 = 256
|
||||
|
||||
class ActivityFlags(ctypes.Union):
|
||||
_anonymous_ = ('bit',)
|
||||
|
||||
_fields_ = [
|
||||
('bit', ActivityFlagsBits),
|
||||
('as_byte', uint8),
|
||||
]
|
||||
premium_early = 512
|
||||
|
||||
|
||||
class StatusType(EasyEnum):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ log = Logger(__name__)
|
|||
USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A)
|
||||
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
|
||||
re.A)
|
||||
DATA_REGEX = re.compile(r'data\:image/(png|jpeg|gif);base64,(.+)', re.A)
|
||||
|
||||
|
||||
# collection of regexes
|
||||
|
|
@ -27,6 +28,24 @@ class LitecordValidator(Validator):
|
|||
"""Validate against the username regex."""
|
||||
return bool(USERNAME_REGEX.match(value))
|
||||
|
||||
def _validate_type_email(self, value: str) -> bool:
|
||||
"""Validate against the username regex."""
|
||||
return bool(EMAIL_REGEX.match(value))
|
||||
|
||||
def _validate_type_b64_icon(self, value: str) -> bool:
|
||||
return bool(DATA_REGEX.match(value))
|
||||
|
||||
def _validate_type_discriminator(self, value: str) -> bool:
|
||||
"""Discriminators are numbers in the API
|
||||
that can go from 0 to 9999.
|
||||
"""
|
||||
try:
|
||||
discrim = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
return 0 < discrim <= 9999
|
||||
|
||||
def _validate_type_snowflake(self, value: str) -> bool:
|
||||
try:
|
||||
int(value)
|
||||
|
|
@ -82,6 +101,43 @@ def validate(reqjson, schema, raise_err: bool = True):
|
|||
return validator.document
|
||||
|
||||
|
||||
USER_UPDATE = {
|
||||
'username': {
|
||||
'type': 'username', 'minlength': 2,
|
||||
'maxlength': 30, 'required': False},
|
||||
|
||||
'discriminator': {
|
||||
'type': 'discriminator',
|
||||
'required': False,
|
||||
'nullable': True,
|
||||
},
|
||||
|
||||
'password': {
|
||||
'type': 'string', 'minlength': 0,
|
||||
'maxlength': 100, 'required': False,
|
||||
},
|
||||
|
||||
'new_password': {
|
||||
'type': 'string', 'minlength': 5,
|
||||
'maxlength': 100, 'required': False,
|
||||
'dependencies': 'password',
|
||||
'nullable': True
|
||||
},
|
||||
|
||||
'email': {
|
||||
'type': 'string', 'minlength': 2,
|
||||
'maxlength': 30, 'required': False,
|
||||
'dependencies': 'password',
|
||||
},
|
||||
|
||||
'avatar': {
|
||||
'type': 'b64_icon', 'required': False,
|
||||
'nullable': True
|
||||
},
|
||||
|
||||
}
|
||||
|
||||
|
||||
GUILD_UPDATE = {
|
||||
'name': {
|
||||
'type': 'string',
|
||||
|
|
@ -297,5 +353,5 @@ CREATE_GROUP_DM = {
|
|||
|
||||
SPECIFIC_FRIEND = {
|
||||
'username': {'type': 'username'},
|
||||
'discriminator': {'type': 'number'}
|
||||
'discriminator': {'type': 'discriminator'}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -96,6 +96,9 @@ class Storage:
|
|||
|
||||
duser = dict(user_row)
|
||||
|
||||
duser['mobile'] = False
|
||||
duser['phone'] = None
|
||||
|
||||
duser['premium'] = duser['premium_since'] is not None
|
||||
duser.pop('premium_since')
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue