diff --git a/litecord/storage.py b/litecord/storage.py index d1efc3c..307acd3 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -14,8 +14,20 @@ async def _dummy(any_id): return str(any_id) +def maybe(typ, val): + return typ(val) if val is not None else None + + def dict_(val): - return dict(val) if val else None + return maybe(dict, val) + + +def str_(val): + return maybe(str, val) + + +def timestamp_(dt): + return dt.isoformat() if dt else None async def _set_json(con): @@ -108,7 +120,13 @@ class Storage: async def get_guild(self, guild_id: int, user_id=None) -> Dict: """Get gulid payload.""" row = await self.db.fetchrow(""" - SELECT * + SELECT id::text, owner_id::text, name, icon, splash, + region, afk_channel_id::text, afk_timeout, + verification_level, default_message_notifications, + explicit_content_filter, mfa_level, + embed_enabled, embed_channel_id::text, + widget_enabled, widget_channel_id::text, + system_channel_id::text FROM guilds WHERE guilds.id = $1 """, guild_id) @@ -119,22 +137,7 @@ class Storage: drow = dict(row) if user_id: - drow['owner'] = drow['owner_id'] == user_id - - # TODO: Probably a really bad idea to repeat str() calls - # Any ideas to make this simpler? - # (No, changing the types on the db wouldn't be nice) - drow['id'] = str(drow['id']) - drow['owner_id'] = str(drow['owner_id']) - drow['afk_channel_id'] = str(drow['afk_channel_id']) \ - if drow['afk_channel_id'] else None - drow['embed_channel_id'] = str(drow['embed_channel_id']) \ - if drow['embed_channel_id'] else None - - drow['widget_channel_id'] = str(drow['widget_channel_id']) \ - if drow['widget_channel_id'] else None - drow['system_channel_id'] = str(drow['system_channel_id']) \ - if drow['system_channel_id'] else None + drow['owner'] = drow['owner_id'] == str(user_id) # TODO: emojis drow['emojis'] = [] @@ -150,31 +153,13 @@ class Storage: return [row['guild_id'] for row in guild_ids] - async def get_member_data_one(self, guild_id, member_id) -> Dict[str, any]: - basic = await self.db.fetchrow(""" + async def _member_basic(self, guild_id: int, member_id: int): + return await self.db.fetchrow(""" SELECT user_id, nickname, joined_at, deafened, muted FROM members WHERE guild_id = $1 and user_id = $2 """, guild_id, member_id) - if not basic: - return - - members_roles = await self.db.fetch(""" - SELECT role_id::text - FROM member_roles - WHERE guild_id = $1 AND user_id = $2 - """, guild_id, member_id) - - return { - 'user': await self.get_user(member_id), - 'nick': basic['nickname'], - 'roles': [row[0] for row in members_roles], - 'joined_at': basic['joined_at'].isoformat(), - 'deaf': basic['deafened'], - 'mute': basic['muted'], - } - async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]: members_roles = await self.db.fetch(""" SELECT role_id::text @@ -191,18 +176,26 @@ class Storage: 'mute': row['muted'], } + async def get_member_data_one(self, guild_id: int, + member_id: int) -> Dict[str, Any]: + """Get data about one member in a guild.""" + basic = await self._member_basic(guild_id, member_id) + + if not basic: + return + + return await self._member_dict(basic, guild_id, member_id) + async def get_member_multi(self, guild_id: int, user_ids: List[int]) -> List[Dict[str, Any]]: """Get member information about multiple users in a guild.""" members = [] - # bad idea bad idea bad idea for user_id in user_ids: - row = await self.db.fetchrow(""" - SELECT user_id, nickname, joined_at, defened, muted - FROM members - WHERE guild_id = $1 AND user_id = $2 - """, guild_id, user_id) + row = await self._member_basic(guild_id, user_id) + + if not row: + continue member = await self._member_dict(row, guild_id, user_id) members.append(member) @@ -247,14 +240,14 @@ class Storage: WHERE channel_id = $1 """, channel_id) - async def chan_last_message_str(self, channel_id: int) -> int: + async def chan_last_message_str(self, channel_id: int) -> str: """Get the last message ID but in a string. Converts to None (not the string "None") when no last message ID is found. """ last_msg = await self.chan_last_message(channel_id) - return str(last_msg) if last_msg is not None else None + return str_(last_msg) async def _channels_extra(self, row) -> Dict: """Fill in more information about a channel.""" @@ -293,7 +286,7 @@ class Storage: WHERE channels.id = $1 """, channel_id) - async def _chan_overwrites(self, channel_id): + async def _chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]: overwrite_rows = await self.db.fetch(""" SELECT target_type, target_role, target_user, allow, deny FROM channel_overwrites @@ -369,11 +362,13 @@ class Storage: drow['id'] = str(drow['id']) return drow elif ctype == ChannelType.GROUP_DM: + # TODO: group dms pass return None async def get_channel_ids(self, guild_id: int) -> List[int]: + """Get all channel IDs in a guild.""" rows = await self.db.fetch(""" SELECT id FROM guild_channels @@ -383,9 +378,9 @@ class Storage: return [r['id'] for r in rows] async def get_channel_data(self, guild_id) -> List[Dict]: - """Get channel information on a guild""" + """Get channel list information on a guild""" channel_basics = await self.db.fetch(""" - SELECT id, guild_id::text, parent_id, name, position, nsfw + SELECT id, guild_id::text, parent_id::text, name, position, nsfw FROM guild_channels WHERE guild_id = $1 """, guild_id) @@ -412,7 +407,31 @@ class Storage: return channels + async def get_role(self, role_id: int, + guild_id: int = None) -> Dict[str, Any]: + """get a single role's information.""" + + guild_field = 'AND guild_id = $2' if guild_id else '' + + args = [role_id] + if guild_id: + args.append(guild_id) + + row = await self.db.fetchrow(f""" + SELECT id::text, name, color, hoist, position, + permissions, managed, mentionable + FROM roles + WHERE id = $1 {guild_field} + LIMIT 1 + """, *args) + + if not row: + return + + return dict(row) + async def get_role_data(self, guild_id: int) -> List[Dict[str, Any]]: + """Get role list information on a guild.""" roledata = await self.db.fetch(""" SELECT id::text, name, color, hoist, position, permissions, managed, mentionable @@ -420,12 +439,7 @@ class Storage: WHERE guild_id = $1 """, guild_id) - roles = [] - - for row in roledata: - roles.append(dict(row)) - - return roles + return list(map(dict, roledata)) async def get_guild_extra(self, guild_id: int, user_id=None, large=None) -> Dict: @@ -438,14 +452,16 @@ class Storage: WHERE guild_id = $1 """, guild_id) - if user_id and large: + if large: + res['large'] = member_count > large + + if user_id: joined_at = await self.db.fetchval(""" SELECT joined_at FROM members WHERE guild_id = $1 AND user_id = $2 """, guild_id, user_id) - res['large'] = member_count > large res['joined_at'] = joined_at.isoformat() members = await self.get_member_data(guild_id) @@ -457,23 +473,29 @@ class Storage: return {**res, **{ 'member_count': member_count, 'members': members, - 'voice_states': [], 'channels': channels, 'roles': roles, 'presences': await self.presence.guild_presences( mids, guild_id ), + + # TODO: voice state management + 'voice_states': [], }} async def get_guild_full(self, guild_id: int, user_id: int, large_count: int = 250) -> Dict: + """Get full information on a guild. + + This is a very expensive operation. + """ guild = await self.get_guild(guild_id, user_id) extra = await self.get_guild_extra(guild_id, user_id, large_count) return {**guild, **extra} - async def guild_exists(self, guild_id: int): + async def guild_exists(self, guild_id: int) -> bool: """Return if a given guild ID exists.""" owner_id = await self.db.fetch(""" SELECT owner_id @@ -484,6 +506,7 @@ class Storage: return owner_id is not None async def get_member_ids(self, guild_id: int) -> List[int]: + """Get member IDs inside a guild""" rows = await self.db.fetch(""" SELECT user_id FROM members @@ -492,7 +515,7 @@ class Storage: return [r[0] for r in rows] - async def _msg_regex(self, regex, method, content) -> List[Dict]: + async def _msg_regex(self, regex, func, content) -> List[Dict]: res = [] for match in regex.finditer(content): @@ -503,8 +526,8 @@ class Storage: except ValueError: continue - obj = await method(found_id) - if obj: + obj = await func(found_id) + if obj is not None: res.append(obj) return res @@ -525,17 +548,50 @@ class Storage: res = dict(row) res['nonce'] = str(res['nonce']) res['timestamp'] = res['timestamp'].isoformat() + res['edited_timestamp'] = timestamp_(res['edited_timestamp']) + res['type'] = res['message_type'] res.pop('message_type') + channel_id = int(row['channel_id']) + content = row['content'] + guild_id = await self.guild_from_channel(channel_id) + # calculate user mentions and role mentions by regex - res['mentions'] = await self._msg_regex(USER_MENTION, self.get_user, + async def _get_member(user_id): + user = await self.get_user(user_id) + member = None + + if guild_id: + # TODO: maybe make this partial? + member = await self.get_member_data_one(guild_id, user_id) + + return {**user, **{'member': member}} if member else user + + res['mentions'] = await self._msg_regex(USER_MENTION, _get_member, row['content']) # _dummy just returns the string of the id, since we don't # actually use the role objects in mention_roles, just their ids. - res['mention_roles'] = await self._msg_regex(ROLE_MENTION, _dummy, - row['content']) + async def _get_role_mention(role_id: int): + if not guild_id: + return str(role_id) + + if role_id == guild_id: + # TODO: check MENTION_EVERYONE permission + return str(role_id) + + role = await self.get_role(role_id, guild_id) + if not role: + return + + if not role['mentionable']: + return + + return str(role_id) + + res['mention_roles'] = await self._msg_regex( + ROLE_MENTION, _get_role_mention, content) # TODO: handle webhook authors res['author'] = await self.get_user(res['author_id']) @@ -553,13 +609,7 @@ class Storage: # TODO: res['pinned'] res['pinned'] = False - # this is specifically for lazy guilds. - guild_id = await self.db.fetchval(""" - SELECT guild_id - FROM guild_channels - WHERE guild_channels.id = $1 - """, int(res['channel_id'])) - + # this is specifically for lazy guilds: # only insert when the channel # is actually from a guild. if guild_id: @@ -811,22 +861,6 @@ class Storage: return res - async def get_all_dms(self, user_id: int) -> List[Dict[str, Any]]: - """Get all DMs for a user, regardless of the DM state.""" - dm_ids = await self.db.fetch(""" - SELECT id - FROM dm_channels - WHERE party1_id = $1 OR party2_id = $2 - """, user_id) - - res = [] - - for dm_id in dm_ids: - dm_chan = await self.get_dm(dm_id, user_id) - res.append(dm_chan) - - return res - async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]: """Get the read state for a user.""" rows = await self.db.fetch(""" @@ -850,6 +884,7 @@ class Storage: return res async def guild_from_channel(self, channel_id: int): + """Get the guild id coming from a channel id.""" return await self.db.fetchval(""" SELECT guild_id FROM guild_channels