blueprints: add users.py blueprint

- errors: change AuthError to Unauthorized and Forbidden
 - auth: fix bug on token_check
 - storage: add Storage.get_user_guilds
This commit is contained in:
Luna Mendes 2018-06-20 23:29:30 -03:00
parent cb8ab6d836
commit f5ea44c8d7
8 changed files with 149 additions and 17 deletions

View File

@ -5,7 +5,7 @@ from itsdangerous import Signer, BadSignature
from logbook import Logger from logbook import Logger
from quart import request, current_app as app from quart import request, current_app as app
from .errors import AuthError from .errors import Forbidden, Unauthorized
log = Logger(__name__) log = Logger(__name__)
@ -19,7 +19,7 @@ async def raw_token_check(token, db=None):
user_id = base64.b64decode(user_id.encode()) user_id = base64.b64decode(user_id.encode())
user_id = int(user_id) user_id = int(user_id)
except (ValueError, binascii.Error): except (ValueError, binascii.Error):
raise AuthError('Invalid user ID type') raise Unauthorized('Invalid user ID type')
pwd_hash = await db.fetchval(""" pwd_hash = await db.fetchval("""
SELECT password_hash SELECT password_hash
@ -28,7 +28,7 @@ async def raw_token_check(token, db=None):
""", user_id) """, user_id)
if not pwd_hash: if not pwd_hash:
raise AuthError('User ID not found') raise Unauthorized('User ID not found')
signer = Signer(pwd_hash) signer = Signer(pwd_hash)
@ -38,7 +38,7 @@ async def raw_token_check(token, db=None):
return user_id return user_id
except BadSignature: except BadSignature:
log.warning('token failed for uid {}', user_id) log.warning('token failed for uid {}', user_id)
raise AuthError('Invalid token') raise Forbidden('Invalid token')
async def token_check(): async def token_check():
@ -46,6 +46,6 @@ async def token_check():
try: try:
token = request.headers['Authorization'] token = request.headers['Authorization']
except KeyError: except KeyError:
raise AuthError('No token provided') raise Unauthorized('No token provided')
await raw_token_check(token) return await raw_token_check(token)

View File

@ -1,2 +1,3 @@
from .gateway import bp as gateway from .gateway import bp as gateway
from .auth import bp as auth from .auth import bp as auth
from .users import bp as users

View File

@ -7,7 +7,7 @@ import bcrypt
from quart import Blueprint, jsonify, request, current_app as app from quart import Blueprint, jsonify, request, current_app as app
from litecord.snowflake import get_snowflake from litecord.snowflake import get_snowflake
from litecord.errors import AuthError from litecord.errors import BadRequest
bp = Blueprint('auth', __name__) bp = Blueprint('auth', __name__)
@ -59,7 +59,7 @@ async def register():
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
""", new_id, email, username, new_discrim, pwd_hash) """, new_id, email, username, new_discrim, pwd_hash)
except asyncpg.UniqueViolationError: except asyncpg.UniqueViolationError:
raise AuthError('Email already used.') raise BadRequest('Email already used.')
return jsonify({ return jsonify({
'token': make_token(new_id, pwd_hash) 'token': make_token(new_id, pwd_hash)

View File

@ -0,0 +1,111 @@
from quart import Blueprint, jsonify, request, current_app as app
from asyncpg import UniqueViolationError
from ..auth import token_check
from ..errors import Forbidden, BadRequest
bp = Blueprint('user', __name__)
@bp.route('/@me', methods=['GET'])
async def get_me():
"""Get the current user's information."""
user_id = await token_check()
user = await app.storage.get_user(user_id, True)
return jsonify(user)
@bp.route('/<int:user_id>', methods=['GET'])
async def get_other():
"""Get any user, given the user ID."""
user_id = await token_check()
bot = await app.db.fetchval("""
SELECT bot FROM users
WHERE users.id = $1
""", user_id)
if not bot:
raise Forbidden('Only bots can use this endpoint')
other = await app.storage.get_user(user_id)
return jsonify(other)
@bp.route('/@me', methods=['PATCH'])
async def patch_me():
"""Patch the current user's information."""
user_id = await token_check()
j = await request.get_json()
if not isinstance(j, dict):
raise BadRequest('Invalid payload')
user = await app.storage.get_user(user_id, True)
if 'username' in j:
try:
await app.db.execute("""
UPDATE users
SET username = $1
WHERE users.id = $2
""", j['username'], user_id)
except UniqueViolationError:
raise BadRequest('Username already used.')
user['username'] = j['username']
return jsonify(user)
@bp.route('/@me/guilds', methods=['GET'])
async def get_me_guilds():
"""Get partial user guilds."""
user_id = await token_check()
guild_ids = await app.storage.get_user_guilds(user_id)
partials = []
for guild_id in guild_ids:
partial = await app.db.fetchrow("""
SELECT id::text, name, icon, owner_id
FROM guilds
WHERE guild_id = $1
""", guild_id)
# TODO: partial['permissions']
partial['owner'] = partial['owner_id'] == user_id
partial.pop('owner_id')
partials.append(partial)
return jsonify(partials)
@bp.route('/@me/guilds/<int:guild_id>', methods=['DELETE'])
async def leave_guild(guild_id):
user_id = await token_check()
await app.db.execute("""
DELETE FROM members
WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id)
# TODO: something to dispatch events to the users
return '', 204
@bp.route('/@me/connections', methods=['GET'])
async def get_connections():
pass
# @bp.route('/@me/channels', methods=['GET'])
async def get_dms():
pass
# @bp.route('/@me/channels', methods=['POST'])
async def start_dm():
pass

View File

@ -6,7 +6,15 @@ class LitecordError(Exception):
return self.args[0] return self.args[0]
class AuthError(LitecordError): class BadRequest(LitecordError):
status_code = 400
class Unauthorized(LitecordError):
status_code = 401
class Forbidden(LitecordError):
status_code = 403 status_code = 403

View File

@ -5,7 +5,7 @@ from typing import List
import earl import earl
from logbook import Logger from logbook import Logger
from litecord.errors import WebsocketClose, AuthError from litecord.errors import WebsocketClose, Unauthorized, Forbidden
from litecord.auth import raw_token_check from litecord.auth import raw_token_check
from .errors import DecodeError, UnknownOPCode, \ from .errors import DecodeError, UnknownOPCode, \
InvalidShard, ShardingRequired InvalidShard, ShardingRequired
@ -100,11 +100,7 @@ class GatewayWebsocket:
# TODO: This function does not account for sharding. # TODO: This function does not account for sharding.
user_id = self.state.user_id user_id = self.state.user_id
guild_ids = await self.ext.db.fetch(""" guild_ids = await self.storage.get_user_guilds(user_id)
SELECT guild_id
FROM members
WHERE user_id = $1
""", user_id)
if self.state.bot: if self.state.bot:
return [{ return [{
@ -188,7 +184,7 @@ class GatewayWebsocket:
try: try:
user_id = await raw_token_check(token, self.ext.db) user_id = await raw_token_check(token, self.ext.db)
except AuthError: except (Unauthorized, Forbidden):
raise WebsocketClose(4004, 'Authentication failed') raise WebsocketClose(4004, 'Authentication failed')
bot = await self.ext.db.fetchval(""" bot = await self.ext.db.fetchval("""

View File

@ -10,6 +10,8 @@ class Storage:
async def get_user(self, user_id, secure=False) -> Dict[str, Any]: async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
"""Get a single user payload.""" """Get a single user payload."""
user_id = int(user_id)
user_row = await self.db.fetchrow(""" user_row = await self.db.fetchrow("""
SELECT id::text, username, discriminator, avatar, email, SELECT id::text, username, discriminator, avatar, email,
flags, bot, mfa_enabled, verified, premium flags, bot, mfa_enabled, verified, premium
@ -17,6 +19,9 @@ class Storage:
WHERE users.id = $1 WHERE users.id = $1
""", user_id) """, user_id)
if not user_row:
return
duser = dict(user_row) duser = dict(user_row)
if not secure: if not secure:
@ -58,6 +63,16 @@ class Storage:
'emojis': [], 'emojis': [],
}} }}
async def get_user_guilds(self, user_id: int) -> List[int]:
"""Get all guild IDs a user is on."""
guild_ids = await self.db.fetch("""
SELECT guild_id
FROM members
WHERE user_id = $1
""", user_id)
return guild_ids
async def get_member_data(self, guild_id) -> List[Dict[str, Any]]: async def get_member_data(self, guild_id) -> List[Dict[str, Any]]:
"""Get member information on a guild.""" """Get member information on a guild."""
members_basic = await self.db.fetch(""" members_basic = await self.db.fetch("""

3
run.py
View File

@ -8,7 +8,7 @@ from quart import Quart, g, jsonify
from logbook import StreamHandler, Logger from logbook import StreamHandler, Logger
import config import config
from litecord.blueprints import gateway, auth from litecord.blueprints import gateway, auth, users
from litecord.gateway import websocket_handler from litecord.gateway import websocket_handler
from litecord.errors import LitecordError from litecord.errors import LitecordError
from litecord.gateway.state_manager import StateManager from litecord.gateway.state_manager import StateManager
@ -35,6 +35,7 @@ def make_app():
app = make_app() app = make_app()
app.register_blueprint(gateway, url_prefix='/api/v6') app.register_blueprint(gateway, url_prefix='/api/v6')
app.register_blueprint(auth, url_prefix='/api/v6') app.register_blueprint(auth, url_prefix='/api/v6')
app.register_blueprint(users, url_prefix='/api/v6/users')
@app.before_serving @app.before_serving