From 68c6442375ad622137cff433fae1a7d7ee81e89d Mon Sep 17 00:00:00 2001 From: Luna Date: Mon, 4 Mar 2019 15:12:38 -0300 Subject: [PATCH] storage: fix more type hints, handle None on more places --- litecord/storage.py | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/litecord/storage.py b/litecord/storage.py index 2b63e82..d105aca 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -60,11 +60,9 @@ def bool_(val): return maybe(int, val) -def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int): +def _filter_recipients(recipients: List[Dict[str, Any]], user_id: str): """Filter recipients in a list of recipients, removing the one that is reundant (ourselves).""" - user_id = str(user_id) - return list(filter( lambda recipient: recipient['id'] != user_id, recipients)) @@ -148,7 +146,7 @@ class Storage: WHERE username = $1 AND discriminator = $2 """, username, discriminator) - async def get_guild(self, guild_id: int, user_id=None) -> Dict: + async def get_guild(self, guild_id: int, user_id=None) -> Optional[Dict]: """Get gulid payload.""" row = await self.db.fetchrow(""" SELECT id::text, owner_id::text, name, icon, splash, @@ -163,7 +161,7 @@ class Storage: """, guild_id) if not row: - return + return None drow = dict(row) @@ -180,7 +178,7 @@ class Storage: """, guild_id, member_id) async def get_member_role_ids(self, guild_id: int, - member_id: int) -> List[int]: + member_id: int) -> List[str]: """Get a list of role IDs that are on a member.""" roles = await self.db.fetch(""" SELECT role_id::text @@ -322,6 +320,7 @@ class Storage: return {**row, **dict(vrow)} log.warning('unknown channel type: {}', chan_type) + return row async def get_chan_type(self, channel_id: int) -> int: """Get the channel type integer, given channel ID.""" @@ -386,9 +385,12 @@ class Storage: if user_id == reference_id: continue - res.append( - await self.get_user(user_id) - ) + user = await self.get_user(user_id) + + if user is None: + continue + + res.append(user) return res @@ -600,13 +602,17 @@ class Storage: 'voice_states': await self.guild_voice_states(guild_id), }} - async def get_guild_full(self, guild_id: int, - user_id: int, large_count: int = 250) -> Dict: + async def get_guild_full(self, guild_id: int, user_id: int, + large_count: int = 250) -> Optional[Dict]: """Get full information on a guild. This is a very expensive operation. """ guild = await self.get_guild(guild_id, user_id) + + if guild is None: + return None + extra = await self.get_guild_extra(guild_id, user_id, large_count) return {**guild, **extra} @@ -856,7 +862,7 @@ class Storage: return res - async def get_invite(self, invite_code: str) -> dict: + async def get_invite(self, invite_code: str) -> Optional[Dict]: """Fetch invite information given its code.""" invite = await self.db.fetchrow(""" SELECT code, guild_id, channel_id @@ -885,6 +891,10 @@ class Storage: dinv['guild'] = {} chan = await self.get_channel(invite['channel_id']) + + if chan is None: + return None + dinv['channel'] = { 'id': chan['id'], 'name': chan['name'], @@ -936,12 +946,13 @@ class Storage: return dinv - async def get_dm(self, dm_id: int, user_id: int = None): + async def get_dm(self, dm_id: int, user_id: int = None) -> Optional[Dict]: + """Get a DM channel.""" dm_chan = await self.get_channel(dm_id) - if user_id: + if user_id and dm_chan: dm_chan['recipients'] = _filter_recipients( - dm_chan['recipients'], user_id + dm_chan['recipients'], str(user_id) ) return dm_chan