diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 8220228..1d3fdc3 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -3,10 +3,23 @@ from quart import Blueprint, request, current_app as app, jsonify from ..auth import token_check from ..snowflake import get_snowflake from ..enums import ChannelType +from ..errors import Forbidden, GuildNotFound, BadRequest bp = Blueprint('guilds', __name__) +async def guild_check(user_id: int, guild_id: int): + """Check if a user is in a guild.""" + joined_at = await app.db.execute(""" + SELECT joined_at + FROM members + WHERE user_id = $1 AND guild_id = $2 + """, user_id, guild_id) + + if not joined_at: + raise GuildNotFound() + + @bp.route('', methods=['POST']) async def create_guild(): user_id = await token_check() @@ -58,3 +71,152 @@ async def create_guild(): guild_extra = await app.storage.get_guild_extra(guild_id, user_id, 250) return jsonify({**guild_json, **guild_extra}) + + +@bp.route('/', methods=['GET']) +async def get_guild(guild_id): + user_id = await token_check() + gj = await app.storage.get_guild(guild_id, user_id) + gj_extra = await app.storage.get_guild_extra(guild_id, user_id, 250) + + return jsonify({**gj, **gj_extra}) + + +@bp.route('/', methods=['DELETE']) +async def delete_guild(guild_id): + user_id = await token_check() + + owner_id = await app.db.fetchval(""" + SELECT owner_id + FROM guilds + WHERE guild_id = $1 + """, guild_id) + + if not owner_id: + raise GuildNotFound() + + if user_id != owner_id: + raise Forbidden('You are not the owner of the guild') + + # TODO: delete guild, fire GUILD_DELETE to guild + + return '', 204 + + +@bp.route('//channels', methods=['GET']) +async def get_guild_channels(guild_id): + user_id = await token_check() + await guild_check(user_id, guild_id) + + channels = await app.storage.get_channel_data(guild_id) + return jsonify(channels) + + +@bp.route('//channels', methods=['POST']) +async def create_channel(guild_id): + user_id = await token_check() + j = await request.get_json() + + # TODO: check permissions for MANAGE_CHANNELS + await guild_check(user_id, guild_id) + + new_channel_id = get_snowflake() + channel_type = j.get('type', ChannelType.GUILD_TEXT) + + if channel_type not in (ChannelType.GUILD_TEXT, + ChannelType.GUILD_VOICE): + raise BadRequest('Invalid channel type') + + await app.db.execute(""" + INSERT INTO channels (id, channel_type) + VALUES ($1, $2) + """, new_channel_id, channel_type) + + max_pos = await app.db.fetch(""" + SELECT MAX(position) + FROM guild_channels + WHERE guild_id = $1 + """, guild_id) + + channel = { + 'id': str(new_channel_id), + 'type': channel_type, + 'guild_id': str(guild_id), + 'position': max_pos + 1, + 'permission_overwrites': [], + 'nsfw': False, + 'name': j['name'], + } + + if channel_type == ChannelType.GUILD_TEXT: + await app.db.execute(""" + INSERT INTO guild_channels (id, guild_id, name, position) + VALUES ($1, $2, $3, $4) + """, new_channel_id, guild_id, j['name'], max_pos + 1) + + await app.db.execute(""" + INSERT INTO guild_text_channels (id) + VALUES ($1) + """, new_channel_id) + + channel['topic'] = None + elif channel_type == ChannelType.GUILD_VOICE: + channel['user_limit'] = 0 + channel['bitrate'] = 64 + + raise NotImplementedError() + + # TODO: fire Channel Create event + + return jsonify(channel) + + +@bp.route('//members/', methods=['GET']) +async def get_guild_member(guild_id, member_id): + user_id = await token_check() + await guild_check(user_id, guild_id) + + member = await app.storage.get_single_member(guild_id, member_id) + return jsonify(member) + + +@bp.route('//members', methods=['GET']) +async def get_members(guild_id): + user_id = await token_check() + await guild_check(user_id, guild_id) + + j = await request.get_json() + + limit, after = int(j.get('limit', 1)), j.get('after', 0) + + if limit < 1 or limit > 1000: + raise BadRequest('limit not in 1-1000 range') + + user_ids = await app.db.fetch(f""" + SELECT user_id + WHERE guild_id = $1, user_id > $2 + LIMIT {limit} + ORDER BY user_id ASC + """, guild_id, after) + + user_ids = [r[0] for r in user_ids] + members = await app.storage.get_member_multi(guild_id, user_ids) + return jsonify(members) + + +@bp.route('//members/@me/nick', methods=['PATCH']) +async def update_nickname(guild_id): + user_id = await token_check() + await guild_check(user_id, guild_id) + + j = await request.get_json() + + await app.db.execute(""" + UPDATE members + SET nickname = $1 + WHERE user_id = $2 AND guild_id = $3 + """, j['nick'], user_id, guild_id) + + # TODO: fire guild member update event + + return j['nick'] diff --git a/litecord/errors.py b/litecord/errors.py index 8bcb93a..5e72376 100644 --- a/litecord/errors.py +++ b/litecord/errors.py @@ -18,6 +18,14 @@ class Forbidden(LitecordError): status_code = 403 +class NotFound(LitecordError): + status_code = 404 + + +class GuildNotFound(LitecordError): + status_code = 404 + + class WebsocketClose(Exception): @property def code(self): diff --git a/litecord/storage.py b/litecord/storage.py index 44cb4ba..aac208f 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -72,7 +72,66 @@ class Storage: return guild_ids - async def get_member_data(self, guild_id) -> List[Dict[str, Any]]: + async def get_member_data_one(self, guild_id, member_id) -> Dict[str, any]: + basic = await self.db.fetchrow(""" + SELECT user_id, nickname, joined_at, deafened, muted + FROM members + WHERE guild_id = $1 and user_id = $2 + """, guild_id, member_id) + + if not basic: + return + + members_roles = await self.db.fetch(""" + SELECT role_id::text + FROM member_roles + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + return { + 'user': await self.get_user(member_id), + 'nick': basic['nickname'], + 'roles': [row[0] for row in members_roles], + 'joined_at': basic['joined_at'].isoformat(), + 'deaf': basic['deafened'], + 'mute': basic['muted'], + } + + async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]: + members_roles = await self.db.fetch(""" + SELECT role_id::text + FROM member_roles + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + return { + 'user': await self.get_user(member_id), + 'nick': row['nickname'], + 'roles': [row[0] for row in members_roles], + 'joined_at': row['joined_at'].isoformat(), + 'deaf': row['deafened'], + 'mute': row['muted'], + } + + async def get_member_multi(self, guild_id: int, + user_ids: List[int]) -> List[Dict[str, Any]]: + """Get member information about multiple users in a guild.""" + members = [] + + # bad idea bad idea bad idea + for user_id in user_ids: + row = await self.db.fetchrow(""" + SELECT user_id, nickname, joined_at, defened, muted + FROM members + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, user_id) + + member = await self._member_dict(row, guild_id, user_id) + members.append(member) + + return members + + async def get_member_data(self, guild_id: int) -> List[Dict[str, Any]]: """Get member information on a guild.""" members_basic = await self.db.fetch(""" SELECT user_id, nickname, joined_at, deafened, muted @@ -83,22 +142,8 @@ class Storage: members = [] for row in members_basic: - member_id = row['user_id'] - - members_roles = await self.db.fetch(""" - SELECT role_id::text - FROM member_roles - WHERE guild_id = $1 AND user_id = $2 - """, guild_id, member_id) - - members.append({ - 'user': await self.get_user(member_id), - 'nick': row['nickname'], - 'roles': [row[0] for row in members_roles], - 'joined_at': row['joined_at'].isoformat(), - 'deaf': row['deafened'], - 'mute': row['muted'], - }) + member = await self._member_dict(row, guild_id, row['user_id']) + members.append(member) return members @@ -139,7 +184,7 @@ class Storage: WHERE id = $1 """, row['id']) - res = await self._channels_extra(row, ctype) + res = await self._channels_extra(dict(row), ctype) # type is a SQL keyword, so we can't do # 'overwrite_type AS type'