# Mirror of https://gitlab.com/litecord/litecord.git (Python, 197 lines, 5.0 KiB)
"""
|
|
|
|
Litecord
|
|
Copyright (C) 2018-2019 Luna Mendes
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, version 3 of the License.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
"""
|
|
|
|
import base64
|
|
import binascii
|
|
from random import randint
|
|
from typing import Tuple
|
|
|
|
import bcrypt
|
|
from asyncpg import UniqueViolationError
|
|
from itsdangerous import TimestampSigner, BadSignature
|
|
from logbook import Logger
|
|
from quart import request, current_app as app
|
|
|
|
from litecord.errors import Forbidden, Unauthorized, BadRequest
|
|
from litecord.snowflake import get_snowflake
|
|
from litecord.enums import UserFlags
|
|
|
|
|
|
# logbook logger for this module; used for the token-check debug/warning output
log = Logger(__name__)
|
|
|
|
|
|
async def raw_token_check(token: str, db=None) -> int:
    """Check if a given token is valid.

    The first ``.``-separated fragment of the token is decoded as the
    base64-encoded decimal user ID. The whole token is then verified with an
    itsdangerous ``TimestampSigner`` keyed on that user's password hash. On
    success the user's ``last_session`` column is refreshed.

    Parameters
    ----------
    token: str
        The raw token (without any ``Bot `` prefix).
    db: optional
        Database handle to query. Defaults to ``app.db``.

    Returns
    -------
    int
        The User ID of the given token.

    Raises
    ------
    Unauthorized
        If token is not properly formatted, or if the user does not exist.
    Forbidden
        If token validation fails.
    """
    db = db or app.db

    # only the first fragment (the encoded user ID) is needed here, so
    # index into the split result instead of unpacking a fixed shape
    encoded_id = token.split('.')[0]

    try:
        user_id = int(base64.b64decode(encoded_id.encode()))
    except (ValueError, binascii.Error):
        raise Unauthorized('Invalid user ID type')

    # NOTE(review): presumably users.id is a bigint. The decoded value is
    # attacker-controlled, and an int outside the signed 64-bit range cannot
    # be encoded as a bigint query argument (the driver would raise instead
    # of simply finding no row), so reject it here as a malformed token.
    if not -(2 ** 63) <= user_id < 2 ** 63:
        raise Unauthorized('Invalid user ID type')

    pwd_hash = await db.fetchval("""
    SELECT password_hash
    FROM users
    WHERE id = $1
    """, user_id)

    if not pwd_hash:
        raise Unauthorized('User ID not found')

    # the password hash is the signing key, so once a user's password_hash
    # changes, tokens signed with the old hash no longer validate
    signer = TimestampSigner(pwd_hash)

    try:
        signer.unsign(token)
    except BadSignature:
        log.warning('token failed for uid {}', user_id)
        raise Forbidden('Invalid token')

    log.debug('login for uid {} successful', user_id)

    # update the user's last_session field
    # so that we can keep an exact track of activity,
    # even on long-lived single sessions (that can happen
    # with people leaving their clients open forever)
    await db.execute("""
    UPDATE users
    SET last_session = (now() at time zone 'utc')
    WHERE id = $1
    """, user_id)

    return user_id
|
|
|
|
|
|
async def token_check() -> int:
    """Check token information for the current request.

    The resolved user ID is cached on ``request.user_id``, so repeated calls
    during the same request do not re-validate the token.

    Returns
    -------
    int
        The ID of the authenticated user.

    Raises
    ------
    Unauthorized
        If the request has no ``Authorization`` header.
    Unauthorized, Forbidden
        Propagated from :func:`raw_token_check` when the token is rejected.
    """
    # first, check if the request info already has a uid
    try:
        return request.user_id
    except AttributeError:
        pass

    try:
        token = request.headers['Authorization']
    except KeyError:
        raise Unauthorized('No token provided')

    # strip only the leading "Bot " prefix; str.replace would also remove
    # any later occurrence of that substring inside the token itself
    bot_prefix = 'Bot '
    if token.startswith(bot_prefix):
        token = token[len(bot_prefix):]

    user_id = await raw_token_check(token)
    request.user_id = user_id
    return user_id
|
|
|
|
|
|
async def admin_check() -> int:
    """Ensure the requesting user has the staff flag.

    Authenticates the request through :func:`token_check`, then loads the
    user's ``flags`` column and tests ``UserFlags.is_staff``.

    Returns
    -------
    int
        The ID of the authenticated staff user.

    Raises
    ------
    Unauthorized
        If the user is not staff (besides whatever :func:`token_check`
        raises).
    """
    user_id = await token_check()

    raw_flags = await app.db.fetchval("""
    SELECT flags
    FROM users
    WHERE id = $1
    """, user_id)

    if UserFlags.from_int(raw_flags).is_staff:
        return user_id

    raise Unauthorized('you are not staff')
|
|
|
|
|
|
async def hash_data(data: str, loop=None, *, rounds: int = 14) -> str:
    """Hash information with bcrypt.

    The hash is computed in the loop's default executor
    (``run_in_executor(None, ...)``) so the slow bcrypt work does not run on
    the event loop thread.

    Parameters
    ----------
    data: str
        Plaintext to hash; encoded with ``str.encode()`` (UTF-8).
    loop: optional
        Event loop whose executor runs the hash. Defaults to ``app.loop``.
    rounds: int
        bcrypt cost factor passed to ``bcrypt.gensalt`` (log2 of the number
        of key-expansion rounds). Defaults to 14.

    Returns
    -------
    str
        The bcrypt hash, decoded to ``str``.
    """
    loop = loop or app.loop
    buf = data.encode()

    hashed = await loop.run_in_executor(
        None, bcrypt.hashpw, buf, bcrypt.gensalt(rounds)
    )

    return hashed.decode()
|
|
|
|
|
|
|
|
async def check_username_usage(username: str, db=None):
    """Reject a username already held by too many accounts.

    Counts the ``users`` rows whose ``username`` equals *username* and raises
    once that count exceeds 8000; otherwise returns ``None``.

    Parameters
    ----------
    username: str
        The username to check.
    db: optional
        Database handle to query. Falls back to ``app.db``.

    Raises
    ------
    BadRequest
        If more than 8000 users already share this username.
    """
    conn = db or app.db

    holders = await conn.fetchval("""
    SELECT COUNT(*)
    FROM users
    WHERE username = $1
    """, username)

    if holders <= 8000:
        return

    raise BadRequest('Too many people.', {
        'username': 'Too many people used the same username. '
                    'Please choose another'
    })
|
|
|
|
|
|
async def create_user(username: str, email: str, password: str,
                      db=None, loop=None) -> Tuple[int, str]:
    """Create a single user.

    Generates a snowflake ID and a random 4-digit discriminator, checks the
    username is not overused, hashes the password and inserts the row into
    ``users``. You can fetch the user data back with
    :meth:`Storage.get_user`.

    Parameters
    ----------
    username: str
        The new user's username.
    email: str
        The new user's email address.
    password: str
        Plaintext password; only its bcrypt hash is stored.
    db: optional
        Database handle. Defaults to ``app.db``.
    loop: optional
        Event loop used for hashing. Defaults to ``app.loop``.

    Returns
    -------
    tuple of (int, str)
        The new user's ID and password hash.

    Raises
    ------
    BadRequest
        If too many users already share ``username``, or if the insert
        violates a unique constraint.
    """
    db = db or app.db
    loop = loop or app.loop

    new_id = get_snowflake()

    # TODO: unified discrim generation based off username, that also includes
    # the check_username_usage()
    new_discrim = '%04d' % randint(1, 9999)

    # run the cheap username check before the expensive bcrypt hash so a
    # rejected registration does not waste executor time on hashing
    await check_username_usage(username, db)

    pwd_hash = await hash_data(password, loop)

    try:
        await db.execute("""
        INSERT INTO users
            (id, email, username, discriminator, password_hash)
        VALUES
            ($1, $2, $3, $4, $5)
        """, new_id, email, username, new_discrim, pwd_hash)
    except UniqueViolationError:
        # NOTE(review): every unique constraint on users lands here, not only
        # the email one -- if e.g. (username, discriminator) is also unique,
        # a random discriminator collision is reported as a used email.
        raise BadRequest('Email already used.')

    return new_id, pwd_hash
|