From f5ea44c8d711deb48dc6f466846597c3f9ad7f16 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 20 Jun 2018 23:29:30 -0300 Subject: [PATCH] blueprints: add users.py blueprint - errors: change AuthError to Unauthorized and Forbidden - auth: fix bug on token_check - storage: add Storage.get_user_guilds --- litecord/auth.py | 12 ++-- litecord/blueprints/__init__.py | 1 + litecord/blueprints/auth.py | 4 +- litecord/blueprints/users.py | 111 ++++++++++++++++++++++++++++++++ litecord/errors.py | 10 ++- litecord/gateway/websocket.py | 10 +-- litecord/storage.py | 15 +++++ run.py | 3 +- 8 files changed, 149 insertions(+), 17 deletions(-) create mode 100644 litecord/blueprints/users.py diff --git a/litecord/auth.py b/litecord/auth.py index e6f483f..d766d90 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -5,7 +5,7 @@ from itsdangerous import Signer, BadSignature from logbook import Logger from quart import request, current_app as app -from .errors import AuthError +from .errors import Forbidden, Unauthorized log = Logger(__name__) @@ -19,7 +19,7 @@ async def raw_token_check(token, db=None): user_id = base64.b64decode(user_id.encode()) user_id = int(user_id) except (ValueError, binascii.Error): - raise AuthError('Invalid user ID type') + raise Unauthorized('Invalid user ID type') pwd_hash = await db.fetchval(""" SELECT password_hash @@ -28,7 +28,7 @@ async def raw_token_check(token, db=None): """, user_id) if not pwd_hash: - raise AuthError('User ID not found') + raise Unauthorized('User ID not found') signer = Signer(pwd_hash) @@ -38,7 +38,7 @@ async def raw_token_check(token, db=None): return user_id except BadSignature: log.warning('token failed for uid {}', user_id) - raise AuthError('Invalid token') + raise Forbidden('Invalid token') async def token_check(): @@ -46,6 +46,6 @@ async def token_check(): try: token = request.headers['Authorization'] except KeyError: - raise AuthError('No token provided') + raise Unauthorized('No token provided') - await raw_token_check(token) + return await raw_token_check(token) diff --git a/litecord/blueprints/__init__.py b/litecord/blueprints/__init__.py index 9099edc..d36112f 100644 --- a/litecord/blueprints/__init__.py +++ b/litecord/blueprints/__init__.py @@ -1,2 +1,3 @@ from .gateway import bp as gateway from .auth import bp as auth +from .users import bp as users diff --git a/litecord/blueprints/auth.py b/litecord/blueprints/auth.py index 1c69fcd..3bcc073 100644 --- a/litecord/blueprints/auth.py +++ b/litecord/blueprints/auth.py @@ -7,7 +7,7 @@ import bcrypt from quart import Blueprint, jsonify, request, current_app as app from litecord.snowflake import get_snowflake -from litecord.errors import AuthError +from litecord.errors import BadRequest bp = Blueprint('auth', __name__) @@ -59,7 +59,7 @@ async def register(): VALUES ($1, $2, $3, $4, $5) """, new_id, email, username, new_discrim, pwd_hash) except asyncpg.UniqueViolationError: - raise AuthError('Email already used.') + raise BadRequest('Email already used.') return jsonify({ 'token': make_token(new_id, pwd_hash) diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py new file mode 100644 index 0000000..05984f6 --- /dev/null +++ b/litecord/blueprints/users.py @@ -0,0 +1,111 @@ +from quart import Blueprint, jsonify, request, current_app as app +from asyncpg import UniqueViolationError + +from ..auth import token_check +from ..errors import Forbidden, BadRequest + +bp = Blueprint('user', __name__) + + +@bp.route('/@me', methods=['GET']) +async def get_me(): + """Get the current user's information.""" + user_id = await token_check() + user = await app.storage.get_user(user_id, True) + return jsonify(user) + + +@bp.route('/', methods=['GET']) +async def get_other(): + """Get any user, given the user ID.""" + user_id = await token_check() + + bot = await app.db.fetchval(""" + SELECT bot FROM users + WHERE users.id = $1 + """, user_id) + + if not bot: + raise Forbidden('Only bots can use this endpoint') + + other = await app.storage.get_user(user_id) + return jsonify(other) + + +@bp.route('/@me', methods=['PATCH']) +async def patch_me(): + """Patch the current user's information.""" + user_id = await token_check() + j = await request.get_json() + + if not isinstance(j, dict): + raise BadRequest('Invalid payload') + + user = await app.storage.get_user(user_id, True) + + if 'username' in j: + try: + await app.db.execute(""" + UPDATE users + SET username = $1 + WHERE users.id = $2 + """, j['username'], user_id) + except UniqueViolationError: + raise BadRequest('Username already used.') + + user['username'] = j['username'] + + return jsonify(user) + + +@bp.route('/@me/guilds', methods=['GET']) +async def get_me_guilds(): + """Get partial user guilds.""" + user_id = await token_check() + guild_ids = await app.storage.get_user_guilds(user_id) + + partials = [] + + for guild_id in guild_ids: + partial = await app.db.fetchrow(""" + SELECT id::text, name, icon, owner_id + FROM guilds + WHERE guild_id = $1 + """, guild_id) + + # TODO: partial['permissions'] + partial['owner'] = partial['owner_id'] == user_id + partial.pop('owner_id') + + partials.append(partial) + + return jsonify(partials) + + +@bp.route('/@me/guilds/', methods=['DELETE']) +async def leave_guild(guild_id): + user_id = await token_check() + + await app.db.execute(""" + DELETE FROM members + WHERE user_id = $1 AND guild_id = $2 + """, user_id, guild_id) + + # TODO: something to dispatch events to the users + + return '', 204 + + +@bp.route('/@me/connections', methods=['GET']) +async def get_connections(): + pass + + +# @bp.route('/@me/channels', methods=['GET']) +async def get_dms(): + pass + + +# @bp.route('/@me/channels', methods=['POST']) +async def start_dm(): + pass diff --git a/litecord/errors.py b/litecord/errors.py index 12f99f6..8bcb93a 100644 --- a/litecord/errors.py +++ b/litecord/errors.py @@ -6,7 +6,15 @@ class LitecordError(Exception): return self.args[0] -class AuthError(LitecordError): +class BadRequest(LitecordError): + status_code = 400 + + +class Unauthorized(LitecordError): + status_code = 401 + + +class Forbidden(LitecordError): status_code = 403 diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 796f3ab..c84f2af 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -5,7 +5,7 @@ from typing import List import earl from logbook import Logger -from litecord.errors import WebsocketClose, AuthError +from litecord.errors import WebsocketClose, Unauthorized, Forbidden from litecord.auth import raw_token_check from .errors import DecodeError, UnknownOPCode, \ InvalidShard, ShardingRequired @@ -100,11 +100,7 @@ class GatewayWebsocket: # TODO: This function does not account for sharding. user_id = self.state.user_id - guild_ids = await self.ext.db.fetch(""" - SELECT guild_id - FROM members - WHERE user_id = $1 - """, user_id) + guild_ids = await self.storage.get_user_guilds(user_id) if self.state.bot: return [{ @@ -188,7 +184,7 @@ class GatewayWebsocket: try: user_id = await raw_token_check(token, self.ext.db) - except AuthError: + except (Unauthorized, Forbidden): raise WebsocketClose(4004, 'Authentication failed') bot = await self.ext.db.fetchval(""" diff --git a/litecord/storage.py b/litecord/storage.py index 0a6c576..10cb1a1 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -10,6 +10,8 @@ class Storage: async def get_user(self, user_id, secure=False) -> Dict[str, Any]: """Get a single user payload.""" + user_id = int(user_id) + user_row = await self.db.fetchrow(""" SELECT id::text, username, discriminator, avatar, email, flags, bot, mfa_enabled, verified, premium @@ -17,6 +19,9 @@ class Storage: WHERE users.id = $1 """, user_id) + if not user_row: + return + duser = dict(user_row) if not secure: @@ -58,6 +63,16 @@ class Storage: 'emojis': [], }} + async def get_user_guilds(self, user_id: int) -> List[int]: + """Get all guild IDs a user is on.""" + guild_ids = await self.db.fetch(""" + SELECT guild_id + FROM members + WHERE user_id = $1 + """, user_id) + + return guild_ids + async def get_member_data(self, guild_id) -> List[Dict[str, Any]]: """Get member information on a guild.""" members_basic = await self.db.fetch(""" diff --git a/run.py b/run.py index 59b9c42..96adeb8 100644 --- a/run.py +++ b/run.py @@ -8,7 +8,7 @@ from quart import Quart, g, jsonify from logbook import StreamHandler, Logger import config -from litecord.blueprints import gateway, auth +from litecord.blueprints import gateway, auth, users from litecord.gateway import websocket_handler from litecord.errors import LitecordError from litecord.gateway.state_manager import StateManager @@ -35,6 +35,7 @@ def make_app(): app = make_app() app.register_blueprint(gateway, url_prefix='/api/v6') app.register_blueprint(auth, url_prefix='/api/v6') +app.register_blueprint(users, url_prefix='/api/v6/users') @app.before_serving