From c477b2ed50db9cf1e3da6de4ebcc3cf28f91a7d4 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 20 Jun 2018 17:39:21 -0300 Subject: [PATCH] storage: add channel fetching logic - litecord: add enums - storage: add get_user - storage: use column::text instead of str() in some cases --- litecord/enums.py | 7 ++ litecord/gateway/state.py | 3 +- litecord/gateway/websocket.py | 37 ++++++--- litecord/storage.py | 148 +++++++++++++++++++++++++++------- 4 files changed, 153 insertions(+), 42 deletions(-) create mode 100644 litecord/enums.py diff --git a/litecord/enums.py b/litecord/enums.py new file mode 100644 index 0000000..984a057 --- /dev/null +++ b/litecord/enums.py @@ -0,0 +1,7 @@ + +class ChannelType: + GUILD_TEXT = 0 + DM = 1 + GUILD_VOICE = 2 + GROUP_DM = 3 + GUILD_CATEGORY = 4 diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py index a16f18e..e6101ab 100644 --- a/litecord/gateway/state.py +++ b/litecord/gateway/state.py @@ -18,8 +18,7 @@ class GatewayState: self.seq = kwargs.get('seq', 0) self.shard = kwargs.get('shard', [0, 1]) self.user_id = kwargs.get('user_id') - - self.ws = None + self.bot = kwargs.get('bot', False) for key in kwargs: value = kwargs[key] diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 430e98b..e0aba21 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -45,7 +45,6 @@ class GatewayWebsocket: def __init__(self, ws, **kwargs): self.ext = WebsocketObjects(*kwargs['prop']) self.storage = self.ext.storage - self.state_manager = self.ext.state_manager self.ws = ws self.wsp = WebsocketProperties(kwargs.get('v'), @@ -99,6 +98,7 @@ class GatewayWebsocket: async def _make_guild_list(self) -> List[int]: # TODO: This function does not account for sharding. + # TODO: This function does not account for bots. user_id = self.state.user_id guild_ids = await self.ext.db.fetch(""" @@ -107,15 +107,27 @@ class GatewayWebsocket: WHERE user_id = $1 """, user_id) - return [{ - 'id': row[0], - 'unavailable': True, - } for row in guild_ids] + if self.state.bot: + return [{ + 'id': row[0], + 'unavailable': True, + } for row in guild_ids] + + return [ + await self.storage.get_guild(row[0], self.state) + for row in guild_ids + ] async def guild_dispatch(self, unavailable_guilds: List[dict]): + """Dispatch GUILD_CREATE information.""" + + # Users don't get asynchronous guild dispatching. + if not self.state.bot: + return + for guild_obj in unavailable_guilds: guild = await self.storage.get_guild(guild_obj['id'], - self.state.user_id) + self.state) if not guild: continue @@ -123,7 +135,7 @@ class GatewayWebsocket: await self.dispatch('GUILD_CREATE', dict(guild)) async def dispatch_ready(self): - """Dispatch the READY packet for a connecting user.""" + """Dispatch the READY packet for a connecting account.""" guilds = await self._make_guild_list() user = await self.storage.get_user(self.state.user_id, True) @@ -180,20 +192,25 @@ class GatewayWebsocket: except AuthError: raise WebsocketClose(4004, 'Authentication failed') + bot = await self.ext.db.fetchval(""" + SELECT bot FROM users + WHERE id = $1 + """, user_id) + self.state = GatewayState( user_id=user_id, + bot=bot, properties=properties, compress=compress, large=large, shard=shard, presence=presence, + ws=self ) - self.state.ws = self - await self._check_shards() - self.state_manager.insert(self.state) + self.ext.state_manager.insert(self.state) await self.dispatch_ready() async def process_message(self, payload): diff --git a/litecord/storage.py b/litecord/storage.py index 44bf9a3..0a6c576 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -1,4 +1,6 @@ -from typing import Dict +from typing import List, Dict, Any + +from .enums import ChannelType class Storage: @@ -6,10 +8,27 @@ class Storage: def __init__(self, db): self.db = db - async def get_user(self, guild_id, secure=False): - pass + async def get_user(self, user_id, secure=False) -> Dict[str, Any]: + """Get a single user payload.""" + user_row = await self.db.fetchrow(""" + SELECT id::text, username, discriminator, avatar, email, + flags, bot, mfa_enabled, verified, premium + FROM users + WHERE users.id = $1 + """, user_id) - async def get_guild(self, guild_id: int, state) -> Dict: + duser = dict(user_row) + + if not secure: + duser.pop('email') + duser.pop('mfa_enabled') + duser.pop('verified') + duser.pop('mfa_enabled') + + return duser + + async def get_guild(self, guild_id: int, state=None) -> Dict: + """Get gulid payload.""" row = await self.db.fetchrow(""" SELECT * FROM guilds @@ -39,6 +58,98 @@ class Storage: 'emojis': [], }} + async def get_member_data(self, guild_id) -> List[Dict[str, Any]]: + """Get member information on a guild.""" + members_basic = await self.db.fetch(""" + SELECT user_id, nickname, joined_at + FROM members + WHERE guild_id = $1 + """, guild_id) + + members = [] + + for row in members_basic: + member_id = row['user_id'] + + members_roles = await self.db.fetch(""" + SELECT role_id::text + FROM member_roles + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + members.append({ + 'user': await self.get_user(member_id), + 'nick': row['nickname'], + 'roles': [row[0] for row in members_roles], + 'joined_at': row['joined_at'].isoformat(), + 'deaf': row['deafened'], + 'mute': row['muted'], + }) + + return members + + async def _channels_extra(self, row, channel_type: int) -> Dict: + """Fill in more information about a channel.""" + # TODO: This could probably be better with a dictionary. + + # TODO: dm and group dm? + if channel_type == ChannelType.GUILD_TEXT: + topic = await self.db.fetchval(""" + SELECT topic FROM guild_text_channels + WHERE id = $1 + """, row['id']) + + return {**row, **{ + 'topic': topic, + }} + elif channel_type == ChannelType.GUILD_VOICE: + vrow = await self.db.fetchval(""" + SELECT bitrate, user_limit FROM guild_voice_channels + WHERE id = $1 + """, row['id']) + + return {**row, **dict(vrow)} + + async def get_channel_data(self, guild_id) -> List[Dict]: + """Get channel information on a guild""" + channel_basics = await self.db.fetch(""" + SELECT * FROM guild_channels + WHERE guild_id = $1 + """, guild_id) + + channels = [] + + for row in channel_basics: + ctype = await self.db.fetchval(""" + SELECT channel_type FROM channels + WHERE id = $1 + """, row['id']) + + res = await self._channels_extra(row, ctype) + + # type is a SQL keyword, so we can't do + # 'overwrite_type AS type' + overwrite_rows = await self.db.fetch(""" + SELECT user_id::text AS id, overwrite_type, allow, deny + FROM channel_overwrites + WHERE channel_id = $1 + """, row['id']) + + def _overwrite_convert(ov_row): + drow = dict(ov_row) + drow['type'] = drow['overwrite_type'] + drow.pop('overwrite_type') + return drow + + res['permission_overwrites'] = list(map(_overwrite_convert, + overwrite_rows)) + + # Making sure. + res['id'] = str(res['id']) + channels.append(res) + + return channels + async def get_guild_extra(self, guild_id: int, state=None) -> Dict: """Get extra information about a guild.""" res = {} @@ -59,37 +170,14 @@ class Storage: res['large'] = state.large > member_count res['joined_at'] = joined_at.isoformat() - members_basic = await self.db.fetch(""" - SELECT user_id, nickname, joined_at - FROM members - WHERE guild_id = $1 - """, guild_id) - - members = [] - - for row in members_basic: - member_id = row['user_id'] - - members_roles = await self.db.fetch(""" - SELECT role_id - FROM member_roles - WHERE guild_id = $1 AND user_id = $2 - """, guild_id, member_id) - - members.append({ - 'user': await self.get_user(member_id), - 'nick': row['nickname'], - 'roles': [str(row[0]) for row in members_roles], - 'joined_at': row['joined_at'].isoformat(), - 'deaf': row['deafened'], - 'mute': row['muted'], - }) + members = await self.get_member_data(guild_id) + channels = await self.get_channel_data(guild_id) return {**res, **{ 'member_count': member_count, 'members': members, 'voice_states': [], + 'channels': channels, # TODO: finish those - 'channels': [], 'presences': [], }}