blueprints: add users.py blueprint

- errors: change AuthError to Unauthorized and Forbidden
 - auth: fix bug on token_check
 - storage: add Storage.get_user_guilds
This commit is contained in:
Luna Mendes 2018-06-20 23:29:30 -03:00
parent cb8ab6d836
commit f5ea44c8d7
8 changed files with 149 additions and 17 deletions

View File

@ -5,7 +5,7 @@ from itsdangerous import Signer, BadSignature
from logbook import Logger
from quart import request, current_app as app
from .errors import AuthError
from .errors import Forbidden, Unauthorized
log = Logger(__name__)
@ -19,7 +19,7 @@ async def raw_token_check(token, db=None):
user_id = base64.b64decode(user_id.encode())
user_id = int(user_id)
except (ValueError, binascii.Error):
raise AuthError('Invalid user ID type')
raise Unauthorized('Invalid user ID type')
pwd_hash = await db.fetchval("""
SELECT password_hash
@ -28,7 +28,7 @@ async def raw_token_check(token, db=None):
""", user_id)
if not pwd_hash:
raise AuthError('User ID not found')
raise Unauthorized('User ID not found')
signer = Signer(pwd_hash)
@ -38,7 +38,7 @@ async def raw_token_check(token, db=None):
return user_id
except BadSignature:
log.warning('token failed for uid {}', user_id)
raise AuthError('Invalid token')
raise Forbidden('Invalid token')
async def token_check():
@ -46,6 +46,6 @@ async def token_check():
try:
token = request.headers['Authorization']
except KeyError:
raise AuthError('No token provided')
raise Unauthorized('No token provided')
await raw_token_check(token)
return await raw_token_check(token)

View File

@ -1,2 +1,3 @@
from .gateway import bp as gateway
from .auth import bp as auth
from .users import bp as users

View File

@ -7,7 +7,7 @@ import bcrypt
from quart import Blueprint, jsonify, request, current_app as app
from litecord.snowflake import get_snowflake
from litecord.errors import AuthError
from litecord.errors import BadRequest
bp = Blueprint('auth', __name__)
@ -59,7 +59,7 @@ async def register():
VALUES ($1, $2, $3, $4, $5)
""", new_id, email, username, new_discrim, pwd_hash)
except asyncpg.UniqueViolationError:
raise AuthError('Email already used.')
raise BadRequest('Email already used.')
return jsonify({
'token': make_token(new_id, pwd_hash)

View File

@ -0,0 +1,111 @@
from quart import Blueprint, jsonify, request, current_app as app
from asyncpg import UniqueViolationError
from ..auth import token_check
from ..errors import Forbidden, BadRequest
bp = Blueprint('user', __name__)
@bp.route('/@me', methods=['GET'])
async def get_me():
"""Get the current user's information."""
user_id = await token_check()
user = await app.storage.get_user(user_id, True)
return jsonify(user)
@bp.route('/<int:user_id>', methods=['GET'])
async def get_other():
"""Get any user, given the user ID."""
user_id = await token_check()
bot = await app.db.fetchval("""
SELECT bot FROM users
WHERE users.id = $1
""", user_id)
if not bot:
raise Forbidden('Only bots can use this endpoint')
other = await app.storage.get_user(user_id)
return jsonify(other)
@bp.route('/@me', methods=['PATCH'])
async def patch_me():
"""Patch the current user's information."""
user_id = await token_check()
j = await request.get_json()
if not isinstance(j, dict):
raise BadRequest('Invalid payload')
user = await app.storage.get_user(user_id, True)
if 'username' in j:
try:
await app.db.execute("""
UPDATE users
SET username = $1
WHERE users.id = $2
""", j['username'], user_id)
except UniqueViolationError:
raise BadRequest('Username already used.')
user['username'] = j['username']
return jsonify(user)
@bp.route('/@me/guilds', methods=['GET'])
async def get_me_guilds():
"""Get partial user guilds."""
user_id = await token_check()
guild_ids = await app.storage.get_user_guilds(user_id)
partials = []
for guild_id in guild_ids:
partial = await app.db.fetchrow("""
SELECT id::text, name, icon, owner_id
FROM guilds
WHERE guild_id = $1
""", guild_id)
# TODO: partial['permissions']
partial['owner'] = partial['owner_id'] == user_id
partial.pop('owner_id')
partials.append(partial)
return jsonify(partials)
@bp.route('/@me/guilds/<int:guild_id>', methods=['DELETE'])
async def leave_guild(guild_id):
user_id = await token_check()
await app.db.execute("""
DELETE FROM members
WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id)
# TODO: something to dispatch events to the users
return '', 204
@bp.route('/@me/connections', methods=['GET'])
async def get_connections():
pass
# @bp.route('/@me/channels', methods=['GET'])
async def get_dms():
pass
# @bp.route('/@me/channels', methods=['POST'])
async def start_dm():
pass

View File

@ -6,7 +6,15 @@ class LitecordError(Exception):
return self.args[0]
class AuthError(LitecordError):
class BadRequest(LitecordError):
status_code = 400
class Unauthorized(LitecordError):
status_code = 401
class Forbidden(LitecordError):
status_code = 403

View File

@ -5,7 +5,7 @@ from typing import List
import earl
from logbook import Logger
from litecord.errors import WebsocketClose, AuthError
from litecord.errors import WebsocketClose, Unauthorized, Forbidden
from litecord.auth import raw_token_check
from .errors import DecodeError, UnknownOPCode, \
InvalidShard, ShardingRequired
@ -100,11 +100,7 @@ class GatewayWebsocket:
# TODO: This function does not account for sharding.
user_id = self.state.user_id
guild_ids = await self.ext.db.fetch("""
SELECT guild_id
FROM members
WHERE user_id = $1
""", user_id)
guild_ids = await self.storage.get_user_guilds(user_id)
if self.state.bot:
return [{
@ -188,7 +184,7 @@ class GatewayWebsocket:
try:
user_id = await raw_token_check(token, self.ext.db)
except AuthError:
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, 'Authentication failed')
bot = await self.ext.db.fetchval("""

View File

@ -10,6 +10,8 @@ class Storage:
async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
"""Get a single user payload."""
user_id = int(user_id)
user_row = await self.db.fetchrow("""
SELECT id::text, username, discriminator, avatar, email,
flags, bot, mfa_enabled, verified, premium
@ -17,6 +19,9 @@ class Storage:
WHERE users.id = $1
""", user_id)
if not user_row:
return
duser = dict(user_row)
if not secure:
@ -58,6 +63,16 @@ class Storage:
'emojis': [],
}}
async def get_user_guilds(self, user_id: int) -> List[int]:
"""Get all guild IDs a user is on."""
guild_ids = await self.db.fetch("""
SELECT guild_id
FROM members
WHERE user_id = $1
""", user_id)
return guild_ids
async def get_member_data(self, guild_id) -> List[Dict[str, Any]]:
"""Get member information on a guild."""
members_basic = await self.db.fetch("""

3
run.py
View File

@ -8,7 +8,7 @@ from quart import Quart, g, jsonify
from logbook import StreamHandler, Logger
import config
from litecord.blueprints import gateway, auth
from litecord.blueprints import gateway, auth, users
from litecord.gateway import websocket_handler
from litecord.errors import LitecordError
from litecord.gateway.state_manager import StateManager
@ -35,6 +35,7 @@ def make_app():
app = make_app()
app.register_blueprint(gateway, url_prefix='/api/v6')
app.register_blueprint(auth, url_prefix='/api/v6')
app.register_blueprint(users, url_prefix='/api/v6/users')
@app.before_serving