litecord/litecord/auth.py

138 lines
3.5 KiB
Python

import base64
import binascii
from random import randint
import bcrypt
from asyncpg import UniqueViolationError
from itsdangerous import Signer, BadSignature
from logbook import Logger
from quart import request, current_app as app
from litecord.errors import Forbidden, Unauthorized, BadRequest
from litecord.snowflake import get_snowflake
log = Logger(__name__)
async def raw_token_check(token, db=None):
db = db or app.db
# just try by fragments instead of
# unpacking
fragments = token.split('.')
user_id = fragments[0]
try:
user_id = base64.b64decode(user_id.encode())
user_id = int(user_id)
except (ValueError, binascii.Error):
raise Unauthorized('Invalid user ID type')
pwd_hash = await db.fetchval("""
SELECT password_hash
FROM users
WHERE id = $1
""", user_id)
if not pwd_hash:
raise Unauthorized('User ID not found')
signer = Signer(pwd_hash)
try:
signer.unsign(token)
log.debug('login for uid {} successful', user_id)
# update the user's last_session field
# so that we can keep an exact track of activity,
# even on long-lived single sessions (that can happen
# with people leaving their clients open forever)
await db.execute("""
UPDATE users
SET last_session = (now() at time zone 'utc')
WHERE id = $1
""", user_id)
return user_id
except BadSignature:
log.warning('token failed for uid {}', user_id)
raise Forbidden('Invalid token')
async def token_check():
"""Check token information."""
# first, check if the request info already has a uid
try:
return request.user_id
except AttributeError:
pass
try:
token = request.headers['Authorization']
except KeyError:
raise Unauthorized('No token provided')
if token.startswith('Bot '):
token = token.replace('Bot ', '')
user_id = await raw_token_check(token)
request.user_id = user_id
return user_id
async def hash_data(data: str, loop=None) -> str:
"""Hash information with bcrypt."""
loop = loop or app.loop
buf = data.encode()
hashed = await loop.run_in_executor(
None, bcrypt.hashpw, buf, bcrypt.gensalt(14)
)
return hashed.decode()
async def check_username_usage(username: str, db=None):
"""Raise an error if too many people are with the same username."""
db = db or app.db
same_username = await db.fetchval("""
SELECT COUNT(*)
FROM users
WHERE username = $1
""", username)
if same_username > 8000:
raise BadRequest('Too many people.', {
'username': 'Too many people used the same username. '
'Please choose another'
})
async def create_user(username: str, email: str, password: str,
db=None, loop=None):
"""Create a single user."""
db = db or app.db
loop = loop or app.loop
new_id = get_snowflake()
new_discrim = randint(1, 9999)
new_discrim = '%04d' % new_discrim
pwd_hash = await hash_data(password, loop)
await check_username_usage(username, db)
try:
await db.execute("""
INSERT INTO users (id, email, username,
discriminator, password_hash)
VALUES ($1, $2, $3, $4, $5)
""", new_id, email, username, new_discrim, pwd_hash)
except UniqueViolationError:
raise BadRequest('Email already used.')
return new_id, pwd_hash