mirror of https://gitlab.com/litecord/litecord.git
blueprints: add users.py blueprint
- errors: change AuthError to Unauthorized and Forbidden - auth: fix bug on token_check - storage: add Storage.get_user_guilds
This commit is contained in:
parent
cb8ab6d836
commit
f5ea44c8d7
|
|
@ -5,7 +5,7 @@ from itsdangerous import Signer, BadSignature
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
from quart import request, current_app as app
|
from quart import request, current_app as app
|
||||||
|
|
||||||
from .errors import AuthError
|
from .errors import Forbidden, Unauthorized
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
@ -19,7 +19,7 @@ async def raw_token_check(token, db=None):
|
||||||
user_id = base64.b64decode(user_id.encode())
|
user_id = base64.b64decode(user_id.encode())
|
||||||
user_id = int(user_id)
|
user_id = int(user_id)
|
||||||
except (ValueError, binascii.Error):
|
except (ValueError, binascii.Error):
|
||||||
raise AuthError('Invalid user ID type')
|
raise Unauthorized('Invalid user ID type')
|
||||||
|
|
||||||
pwd_hash = await db.fetchval("""
|
pwd_hash = await db.fetchval("""
|
||||||
SELECT password_hash
|
SELECT password_hash
|
||||||
|
|
@ -28,7 +28,7 @@ async def raw_token_check(token, db=None):
|
||||||
""", user_id)
|
""", user_id)
|
||||||
|
|
||||||
if not pwd_hash:
|
if not pwd_hash:
|
||||||
raise AuthError('User ID not found')
|
raise Unauthorized('User ID not found')
|
||||||
|
|
||||||
signer = Signer(pwd_hash)
|
signer = Signer(pwd_hash)
|
||||||
|
|
||||||
|
|
@ -38,7 +38,7 @@ async def raw_token_check(token, db=None):
|
||||||
return user_id
|
return user_id
|
||||||
except BadSignature:
|
except BadSignature:
|
||||||
log.warning('token failed for uid {}', user_id)
|
log.warning('token failed for uid {}', user_id)
|
||||||
raise AuthError('Invalid token')
|
raise Forbidden('Invalid token')
|
||||||
|
|
||||||
|
|
||||||
async def token_check():
|
async def token_check():
|
||||||
|
|
@ -46,6 +46,6 @@ async def token_check():
|
||||||
try:
|
try:
|
||||||
token = request.headers['Authorization']
|
token = request.headers['Authorization']
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise AuthError('No token provided')
|
raise Unauthorized('No token provided')
|
||||||
|
|
||||||
await raw_token_check(token)
|
return await raw_token_check(token)
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
from .gateway import bp as gateway
|
from .gateway import bp as gateway
|
||||||
from .auth import bp as auth
|
from .auth import bp as auth
|
||||||
|
from .users import bp as users
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import bcrypt
|
||||||
from quart import Blueprint, jsonify, request, current_app as app
|
from quart import Blueprint, jsonify, request, current_app as app
|
||||||
|
|
||||||
from litecord.snowflake import get_snowflake
|
from litecord.snowflake import get_snowflake
|
||||||
from litecord.errors import AuthError
|
from litecord.errors import BadRequest
|
||||||
|
|
||||||
|
|
||||||
bp = Blueprint('auth', __name__)
|
bp = Blueprint('auth', __name__)
|
||||||
|
|
@ -59,7 +59,7 @@ async def register():
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
""", new_id, email, username, new_discrim, pwd_hash)
|
""", new_id, email, username, new_discrim, pwd_hash)
|
||||||
except asyncpg.UniqueViolationError:
|
except asyncpg.UniqueViolationError:
|
||||||
raise AuthError('Email already used.')
|
raise BadRequest('Email already used.')
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'token': make_token(new_id, pwd_hash)
|
'token': make_token(new_id, pwd_hash)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,111 @@
|
||||||
|
from quart import Blueprint, jsonify, request, current_app as app
|
||||||
|
from asyncpg import UniqueViolationError
|
||||||
|
|
||||||
|
from ..auth import token_check
|
||||||
|
from ..errors import Forbidden, BadRequest
|
||||||
|
|
||||||
|
bp = Blueprint('user', __name__)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/@me', methods=['GET'])
|
||||||
|
async def get_me():
|
||||||
|
"""Get the current user's information."""
|
||||||
|
user_id = await token_check()
|
||||||
|
user = await app.storage.get_user(user_id, True)
|
||||||
|
return jsonify(user)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:user_id>', methods=['GET'])
|
||||||
|
async def get_other():
|
||||||
|
"""Get any user, given the user ID."""
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
bot = await app.db.fetchval("""
|
||||||
|
SELECT bot FROM users
|
||||||
|
WHERE users.id = $1
|
||||||
|
""", user_id)
|
||||||
|
|
||||||
|
if not bot:
|
||||||
|
raise Forbidden('Only bots can use this endpoint')
|
||||||
|
|
||||||
|
other = await app.storage.get_user(user_id)
|
||||||
|
return jsonify(other)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/@me', methods=['PATCH'])
|
||||||
|
async def patch_me():
|
||||||
|
"""Patch the current user's information."""
|
||||||
|
user_id = await token_check()
|
||||||
|
j = await request.get_json()
|
||||||
|
|
||||||
|
if not isinstance(j, dict):
|
||||||
|
raise BadRequest('Invalid payload')
|
||||||
|
|
||||||
|
user = await app.storage.get_user(user_id, True)
|
||||||
|
|
||||||
|
if 'username' in j:
|
||||||
|
try:
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE users
|
||||||
|
SET username = $1
|
||||||
|
WHERE users.id = $2
|
||||||
|
""", j['username'], user_id)
|
||||||
|
except UniqueViolationError:
|
||||||
|
raise BadRequest('Username already used.')
|
||||||
|
|
||||||
|
user['username'] = j['username']
|
||||||
|
|
||||||
|
return jsonify(user)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/@me/guilds', methods=['GET'])
|
||||||
|
async def get_me_guilds():
|
||||||
|
"""Get partial user guilds."""
|
||||||
|
user_id = await token_check()
|
||||||
|
guild_ids = await app.storage.get_user_guilds(user_id)
|
||||||
|
|
||||||
|
partials = []
|
||||||
|
|
||||||
|
for guild_id in guild_ids:
|
||||||
|
partial = await app.db.fetchrow("""
|
||||||
|
SELECT id::text, name, icon, owner_id
|
||||||
|
FROM guilds
|
||||||
|
WHERE guild_id = $1
|
||||||
|
""", guild_id)
|
||||||
|
|
||||||
|
# TODO: partial['permissions']
|
||||||
|
partial['owner'] = partial['owner_id'] == user_id
|
||||||
|
partial.pop('owner_id')
|
||||||
|
|
||||||
|
partials.append(partial)
|
||||||
|
|
||||||
|
return jsonify(partials)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/@me/guilds/<int:guild_id>', methods=['DELETE'])
|
||||||
|
async def leave_guild(guild_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
DELETE FROM members
|
||||||
|
WHERE user_id = $1 AND guild_id = $2
|
||||||
|
""", user_id, guild_id)
|
||||||
|
|
||||||
|
# TODO: something to dispatch events to the users
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/@me/connections', methods=['GET'])
|
||||||
|
async def get_connections():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# @bp.route('/@me/channels', methods=['GET'])
|
||||||
|
async def get_dms():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# @bp.route('/@me/channels', methods=['POST'])
|
||||||
|
async def start_dm():
|
||||||
|
pass
|
||||||
|
|
@ -6,7 +6,15 @@ class LitecordError(Exception):
|
||||||
return self.args[0]
|
return self.args[0]
|
||||||
|
|
||||||
|
|
||||||
class AuthError(LitecordError):
|
class BadRequest(LitecordError):
|
||||||
|
status_code = 400
|
||||||
|
|
||||||
|
|
||||||
|
class Unauthorized(LitecordError):
|
||||||
|
status_code = 401
|
||||||
|
|
||||||
|
|
||||||
|
class Forbidden(LitecordError):
|
||||||
status_code = 403
|
status_code = 403
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import List
|
||||||
import earl
|
import earl
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from litecord.errors import WebsocketClose, AuthError
|
from litecord.errors import WebsocketClose, Unauthorized, Forbidden
|
||||||
from litecord.auth import raw_token_check
|
from litecord.auth import raw_token_check
|
||||||
from .errors import DecodeError, UnknownOPCode, \
|
from .errors import DecodeError, UnknownOPCode, \
|
||||||
InvalidShard, ShardingRequired
|
InvalidShard, ShardingRequired
|
||||||
|
|
@ -100,11 +100,7 @@ class GatewayWebsocket:
|
||||||
# TODO: This function does not account for sharding.
|
# TODO: This function does not account for sharding.
|
||||||
user_id = self.state.user_id
|
user_id = self.state.user_id
|
||||||
|
|
||||||
guild_ids = await self.ext.db.fetch("""
|
guild_ids = await self.storage.get_user_guilds(user_id)
|
||||||
SELECT guild_id
|
|
||||||
FROM members
|
|
||||||
WHERE user_id = $1
|
|
||||||
""", user_id)
|
|
||||||
|
|
||||||
if self.state.bot:
|
if self.state.bot:
|
||||||
return [{
|
return [{
|
||||||
|
|
@ -188,7 +184,7 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = await raw_token_check(token, self.ext.db)
|
user_id = await raw_token_check(token, self.ext.db)
|
||||||
except AuthError:
|
except (Unauthorized, Forbidden):
|
||||||
raise WebsocketClose(4004, 'Authentication failed')
|
raise WebsocketClose(4004, 'Authentication failed')
|
||||||
|
|
||||||
bot = await self.ext.db.fetchval("""
|
bot = await self.ext.db.fetchval("""
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ class Storage:
|
||||||
|
|
||||||
async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
|
async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
|
||||||
"""Get a single user payload."""
|
"""Get a single user payload."""
|
||||||
|
user_id = int(user_id)
|
||||||
|
|
||||||
user_row = await self.db.fetchrow("""
|
user_row = await self.db.fetchrow("""
|
||||||
SELECT id::text, username, discriminator, avatar, email,
|
SELECT id::text, username, discriminator, avatar, email,
|
||||||
flags, bot, mfa_enabled, verified, premium
|
flags, bot, mfa_enabled, verified, premium
|
||||||
|
|
@ -17,6 +19,9 @@ class Storage:
|
||||||
WHERE users.id = $1
|
WHERE users.id = $1
|
||||||
""", user_id)
|
""", user_id)
|
||||||
|
|
||||||
|
if not user_row:
|
||||||
|
return
|
||||||
|
|
||||||
duser = dict(user_row)
|
duser = dict(user_row)
|
||||||
|
|
||||||
if not secure:
|
if not secure:
|
||||||
|
|
@ -58,6 +63,16 @@ class Storage:
|
||||||
'emojis': [],
|
'emojis': [],
|
||||||
}}
|
}}
|
||||||
|
|
||||||
|
async def get_user_guilds(self, user_id: int) -> List[int]:
|
||||||
|
"""Get all guild IDs a user is on."""
|
||||||
|
guild_ids = await self.db.fetch("""
|
||||||
|
SELECT guild_id
|
||||||
|
FROM members
|
||||||
|
WHERE user_id = $1
|
||||||
|
""", user_id)
|
||||||
|
|
||||||
|
return guild_ids
|
||||||
|
|
||||||
async def get_member_data(self, guild_id) -> List[Dict[str, Any]]:
|
async def get_member_data(self, guild_id) -> List[Dict[str, Any]]:
|
||||||
"""Get member information on a guild."""
|
"""Get member information on a guild."""
|
||||||
members_basic = await self.db.fetch("""
|
members_basic = await self.db.fetch("""
|
||||||
|
|
|
||||||
3
run.py
3
run.py
|
|
@ -8,7 +8,7 @@ from quart import Quart, g, jsonify
|
||||||
from logbook import StreamHandler, Logger
|
from logbook import StreamHandler, Logger
|
||||||
|
|
||||||
import config
|
import config
|
||||||
from litecord.blueprints import gateway, auth
|
from litecord.blueprints import gateway, auth, users
|
||||||
from litecord.gateway import websocket_handler
|
from litecord.gateway import websocket_handler
|
||||||
from litecord.errors import LitecordError
|
from litecord.errors import LitecordError
|
||||||
from litecord.gateway.state_manager import StateManager
|
from litecord.gateway.state_manager import StateManager
|
||||||
|
|
@ -35,6 +35,7 @@ def make_app():
|
||||||
app = make_app()
|
app = make_app()
|
||||||
app.register_blueprint(gateway, url_prefix='/api/v6')
|
app.register_blueprint(gateway, url_prefix='/api/v6')
|
||||||
app.register_blueprint(auth, url_prefix='/api/v6')
|
app.register_blueprint(auth, url_prefix='/api/v6')
|
||||||
|
app.register_blueprint(users, url_prefix='/api/v6/users')
|
||||||
|
|
||||||
|
|
||||||
@app.before_serving
|
@app.before_serving
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue