mirror of https://gitlab.com/litecord/litecord.git
blueprints: add users.py blueprint
- errors: change AuthError to Unauthorized and Forbidden - auth: fix bug on token_check - storage: add Storage.get_user_guilds
This commit is contained in:
parent
cb8ab6d836
commit
f5ea44c8d7
|
|
@ -5,7 +5,7 @@ from itsdangerous import Signer, BadSignature
|
|||
from logbook import Logger
|
||||
from quart import request, current_app as app
|
||||
|
||||
from .errors import AuthError
|
||||
from .errors import Forbidden, Unauthorized
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -19,7 +19,7 @@ async def raw_token_check(token, db=None):
|
|||
user_id = base64.b64decode(user_id.encode())
|
||||
user_id = int(user_id)
|
||||
except (ValueError, binascii.Error):
|
||||
raise AuthError('Invalid user ID type')
|
||||
raise Unauthorized('Invalid user ID type')
|
||||
|
||||
pwd_hash = await db.fetchval("""
|
||||
SELECT password_hash
|
||||
|
|
@ -28,7 +28,7 @@ async def raw_token_check(token, db=None):
|
|||
""", user_id)
|
||||
|
||||
if not pwd_hash:
|
||||
raise AuthError('User ID not found')
|
||||
raise Unauthorized('User ID not found')
|
||||
|
||||
signer = Signer(pwd_hash)
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ async def raw_token_check(token, db=None):
|
|||
return user_id
|
||||
except BadSignature:
|
||||
log.warning('token failed for uid {}', user_id)
|
||||
raise AuthError('Invalid token')
|
||||
raise Forbidden('Invalid token')
|
||||
|
||||
|
||||
async def token_check():
|
||||
|
|
@ -46,6 +46,6 @@ async def token_check():
|
|||
try:
|
||||
token = request.headers['Authorization']
|
||||
except KeyError:
|
||||
raise AuthError('No token provided')
|
||||
raise Unauthorized('No token provided')
|
||||
|
||||
await raw_token_check(token)
|
||||
return await raw_token_check(token)
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
from .gateway import bp as gateway
|
||||
from .auth import bp as auth
|
||||
from .users import bp as users
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import bcrypt
|
|||
from quart import Blueprint, jsonify, request, current_app as app
|
||||
|
||||
from litecord.snowflake import get_snowflake
|
||||
from litecord.errors import AuthError
|
||||
from litecord.errors import BadRequest
|
||||
|
||||
|
||||
bp = Blueprint('auth', __name__)
|
||||
|
|
@ -59,7 +59,7 @@ async def register():
|
|||
VALUES ($1, $2, $3, $4, $5)
|
||||
""", new_id, email, username, new_discrim, pwd_hash)
|
||||
except asyncpg.UniqueViolationError:
|
||||
raise AuthError('Email already used.')
|
||||
raise BadRequest('Email already used.')
|
||||
|
||||
return jsonify({
|
||||
'token': make_token(new_id, pwd_hash)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,111 @@
|
|||
from quart import Blueprint, jsonify, request, current_app as app
|
||||
from asyncpg import UniqueViolationError
|
||||
|
||||
from ..auth import token_check
|
||||
from ..errors import Forbidden, BadRequest
|
||||
|
||||
bp = Blueprint('user', __name__)
|
||||
|
||||
|
||||
@bp.route('/@me', methods=['GET'])
|
||||
async def get_me():
|
||||
"""Get the current user's information."""
|
||||
user_id = await token_check()
|
||||
user = await app.storage.get_user(user_id, True)
|
||||
return jsonify(user)
|
||||
|
||||
|
||||
@bp.route('/<int:user_id>', methods=['GET'])
|
||||
async def get_other():
|
||||
"""Get any user, given the user ID."""
|
||||
user_id = await token_check()
|
||||
|
||||
bot = await app.db.fetchval("""
|
||||
SELECT bot FROM users
|
||||
WHERE users.id = $1
|
||||
""", user_id)
|
||||
|
||||
if not bot:
|
||||
raise Forbidden('Only bots can use this endpoint')
|
||||
|
||||
other = await app.storage.get_user(user_id)
|
||||
return jsonify(other)
|
||||
|
||||
|
||||
@bp.route('/@me', methods=['PATCH'])
|
||||
async def patch_me():
|
||||
"""Patch the current user's information."""
|
||||
user_id = await token_check()
|
||||
j = await request.get_json()
|
||||
|
||||
if not isinstance(j, dict):
|
||||
raise BadRequest('Invalid payload')
|
||||
|
||||
user = await app.storage.get_user(user_id, True)
|
||||
|
||||
if 'username' in j:
|
||||
try:
|
||||
await app.db.execute("""
|
||||
UPDATE users
|
||||
SET username = $1
|
||||
WHERE users.id = $2
|
||||
""", j['username'], user_id)
|
||||
except UniqueViolationError:
|
||||
raise BadRequest('Username already used.')
|
||||
|
||||
user['username'] = j['username']
|
||||
|
||||
return jsonify(user)
|
||||
|
||||
|
||||
@bp.route('/@me/guilds', methods=['GET'])
|
||||
async def get_me_guilds():
|
||||
"""Get partial user guilds."""
|
||||
user_id = await token_check()
|
||||
guild_ids = await app.storage.get_user_guilds(user_id)
|
||||
|
||||
partials = []
|
||||
|
||||
for guild_id in guild_ids:
|
||||
partial = await app.db.fetchrow("""
|
||||
SELECT id::text, name, icon, owner_id
|
||||
FROM guilds
|
||||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
# TODO: partial['permissions']
|
||||
partial['owner'] = partial['owner_id'] == user_id
|
||||
partial.pop('owner_id')
|
||||
|
||||
partials.append(partial)
|
||||
|
||||
return jsonify(partials)
|
||||
|
||||
|
||||
@bp.route('/@me/guilds/<int:guild_id>', methods=['DELETE'])
|
||||
async def leave_guild(guild_id):
|
||||
user_id = await token_check()
|
||||
|
||||
await app.db.execute("""
|
||||
DELETE FROM members
|
||||
WHERE user_id = $1 AND guild_id = $2
|
||||
""", user_id, guild_id)
|
||||
|
||||
# TODO: something to dispatch events to the users
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route('/@me/connections', methods=['GET'])
|
||||
async def get_connections():
|
||||
pass
|
||||
|
||||
|
||||
# @bp.route('/@me/channels', methods=['GET'])
|
||||
async def get_dms():
|
||||
pass
|
||||
|
||||
|
||||
# @bp.route('/@me/channels', methods=['POST'])
|
||||
async def start_dm():
|
||||
pass
|
||||
|
|
@ -6,7 +6,15 @@ class LitecordError(Exception):
|
|||
return self.args[0]
|
||||
|
||||
|
||||
class AuthError(LitecordError):
|
||||
class BadRequest(LitecordError):
|
||||
status_code = 400
|
||||
|
||||
|
||||
class Unauthorized(LitecordError):
|
||||
status_code = 401
|
||||
|
||||
|
||||
class Forbidden(LitecordError):
|
||||
status_code = 403
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import List
|
|||
import earl
|
||||
from logbook import Logger
|
||||
|
||||
from litecord.errors import WebsocketClose, AuthError
|
||||
from litecord.errors import WebsocketClose, Unauthorized, Forbidden
|
||||
from litecord.auth import raw_token_check
|
||||
from .errors import DecodeError, UnknownOPCode, \
|
||||
InvalidShard, ShardingRequired
|
||||
|
|
@ -100,11 +100,7 @@ class GatewayWebsocket:
|
|||
# TODO: This function does not account for sharding.
|
||||
user_id = self.state.user_id
|
||||
|
||||
guild_ids = await self.ext.db.fetch("""
|
||||
SELECT guild_id
|
||||
FROM members
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
guild_ids = await self.storage.get_user_guilds(user_id)
|
||||
|
||||
if self.state.bot:
|
||||
return [{
|
||||
|
|
@ -188,7 +184,7 @@ class GatewayWebsocket:
|
|||
|
||||
try:
|
||||
user_id = await raw_token_check(token, self.ext.db)
|
||||
except AuthError:
|
||||
except (Unauthorized, Forbidden):
|
||||
raise WebsocketClose(4004, 'Authentication failed')
|
||||
|
||||
bot = await self.ext.db.fetchval("""
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ class Storage:
|
|||
|
||||
async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
|
||||
"""Get a single user payload."""
|
||||
user_id = int(user_id)
|
||||
|
||||
user_row = await self.db.fetchrow("""
|
||||
SELECT id::text, username, discriminator, avatar, email,
|
||||
flags, bot, mfa_enabled, verified, premium
|
||||
|
|
@ -17,6 +19,9 @@ class Storage:
|
|||
WHERE users.id = $1
|
||||
""", user_id)
|
||||
|
||||
if not user_row:
|
||||
return
|
||||
|
||||
duser = dict(user_row)
|
||||
|
||||
if not secure:
|
||||
|
|
@ -58,6 +63,16 @@ class Storage:
|
|||
'emojis': [],
|
||||
}}
|
||||
|
||||
async def get_user_guilds(self, user_id: int) -> List[int]:
|
||||
"""Get all guild IDs a user is on."""
|
||||
guild_ids = await self.db.fetch("""
|
||||
SELECT guild_id
|
||||
FROM members
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
|
||||
return guild_ids
|
||||
|
||||
async def get_member_data(self, guild_id) -> List[Dict[str, Any]]:
|
||||
"""Get member information on a guild."""
|
||||
members_basic = await self.db.fetch("""
|
||||
|
|
|
|||
3
run.py
3
run.py
|
|
@ -8,7 +8,7 @@ from quart import Quart, g, jsonify
|
|||
from logbook import StreamHandler, Logger
|
||||
|
||||
import config
|
||||
from litecord.blueprints import gateway, auth
|
||||
from litecord.blueprints import gateway, auth, users
|
||||
from litecord.gateway import websocket_handler
|
||||
from litecord.errors import LitecordError
|
||||
from litecord.gateway.state_manager import StateManager
|
||||
|
|
@ -35,6 +35,7 @@ def make_app():
|
|||
app = make_app()
|
||||
app.register_blueprint(gateway, url_prefix='/api/v6')
|
||||
app.register_blueprint(auth, url_prefix='/api/v6')
|
||||
app.register_blueprint(users, url_prefix='/api/v6/users')
|
||||
|
||||
|
||||
@app.before_serving
|
||||
|
|
|
|||
Loading…
Reference in New Issue