storage: multiple enhancements

storage:
 - don't use SELECT * when get_guild'ing
 - use field::text instead of the str() call madness
 - simplify get_member_data_one
 - fix parent_id being an int on get_channel_data
 - add Storage.get_role
 - better mention resolving on get_message
 - remove Storage.get_all_dms
This commit is contained in:
Luna Mendes 2018-10-12 19:06:13 -03:00
parent d28c0f1bc6
commit da4ce66a0c
1 changed files with 123 additions and 88 deletions

View File

@ -14,8 +14,20 @@ async def _dummy(any_id):
return str(any_id) return str(any_id)
def maybe(typ, val):
return typ(val) if val is not None else None
def dict_(val): def dict_(val):
return dict(val) if val else None return maybe(dict, val)
def str_(val):
return maybe(str, val)
def timestamp_(dt):
return dt.isoformat() if dt else None
async def _set_json(con): async def _set_json(con):
@ -108,7 +120,13 @@ class Storage:
async def get_guild(self, guild_id: int, user_id=None) -> Dict: async def get_guild(self, guild_id: int, user_id=None) -> Dict:
"""Get gulid payload.""" """Get gulid payload."""
row = await self.db.fetchrow(""" row = await self.db.fetchrow("""
SELECT * SELECT id::text, owner_id::text, name, icon, splash,
region, afk_channel_id::text, afk_timeout,
verification_level, default_message_notifications,
explicit_content_filter, mfa_level,
embed_enabled, embed_channel_id::text,
widget_enabled, widget_channel_id::text,
system_channel_id::text
FROM guilds FROM guilds
WHERE guilds.id = $1 WHERE guilds.id = $1
""", guild_id) """, guild_id)
@ -119,22 +137,7 @@ class Storage:
drow = dict(row) drow = dict(row)
if user_id: if user_id:
drow['owner'] = drow['owner_id'] == user_id drow['owner'] = drow['owner_id'] == str(user_id)
# TODO: Probably a really bad idea to repeat str() calls
# Any ideas to make this simpler?
# (No, changing the types on the db wouldn't be nice)
drow['id'] = str(drow['id'])
drow['owner_id'] = str(drow['owner_id'])
drow['afk_channel_id'] = str(drow['afk_channel_id']) \
if drow['afk_channel_id'] else None
drow['embed_channel_id'] = str(drow['embed_channel_id']) \
if drow['embed_channel_id'] else None
drow['widget_channel_id'] = str(drow['widget_channel_id']) \
if drow['widget_channel_id'] else None
drow['system_channel_id'] = str(drow['system_channel_id']) \
if drow['system_channel_id'] else None
# TODO: emojis # TODO: emojis
drow['emojis'] = [] drow['emojis'] = []
@ -150,31 +153,13 @@ class Storage:
return [row['guild_id'] for row in guild_ids] return [row['guild_id'] for row in guild_ids]
async def get_member_data_one(self, guild_id, member_id) -> Dict[str, any]: async def _member_basic(self, guild_id: int, member_id: int):
basic = await self.db.fetchrow(""" return await self.db.fetchrow("""
SELECT user_id, nickname, joined_at, deafened, muted SELECT user_id, nickname, joined_at, deafened, muted
FROM members FROM members
WHERE guild_id = $1 and user_id = $2 WHERE guild_id = $1 and user_id = $2
""", guild_id, member_id) """, guild_id, member_id)
if not basic:
return
members_roles = await self.db.fetch("""
SELECT role_id::text
FROM member_roles
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
return {
'user': await self.get_user(member_id),
'nick': basic['nickname'],
'roles': [row[0] for row in members_roles],
'joined_at': basic['joined_at'].isoformat(),
'deaf': basic['deafened'],
'mute': basic['muted'],
}
async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]: async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]:
members_roles = await self.db.fetch(""" members_roles = await self.db.fetch("""
SELECT role_id::text SELECT role_id::text
@ -191,18 +176,26 @@ class Storage:
'mute': row['muted'], 'mute': row['muted'],
} }
async def get_member_data_one(self, guild_id: int,
member_id: int) -> Dict[str, Any]:
"""Get data about one member in a guild."""
basic = await self._member_basic(guild_id, member_id)
if not basic:
return
return await self._member_dict(basic, guild_id, member_id)
async def get_member_multi(self, guild_id: int, async def get_member_multi(self, guild_id: int,
user_ids: List[int]) -> List[Dict[str, Any]]: user_ids: List[int]) -> List[Dict[str, Any]]:
"""Get member information about multiple users in a guild.""" """Get member information about multiple users in a guild."""
members = [] members = []
# bad idea bad idea bad idea
for user_id in user_ids: for user_id in user_ids:
row = await self.db.fetchrow(""" row = await self._member_basic(guild_id, user_id)
SELECT user_id, nickname, joined_at, defened, muted
FROM members if not row:
WHERE guild_id = $1 AND user_id = $2 continue
""", guild_id, user_id)
member = await self._member_dict(row, guild_id, user_id) member = await self._member_dict(row, guild_id, user_id)
members.append(member) members.append(member)
@ -247,14 +240,14 @@ class Storage:
WHERE channel_id = $1 WHERE channel_id = $1
""", channel_id) """, channel_id)
async def chan_last_message_str(self, channel_id: int) -> int: async def chan_last_message_str(self, channel_id: int) -> str:
"""Get the last message ID but in a string. """Get the last message ID but in a string.
Converts to None (not the string "None") when Converts to None (not the string "None") when
no last message ID is found. no last message ID is found.
""" """
last_msg = await self.chan_last_message(channel_id) last_msg = await self.chan_last_message(channel_id)
return str(last_msg) if last_msg is not None else None return str_(last_msg)
async def _channels_extra(self, row) -> Dict: async def _channels_extra(self, row) -> Dict:
"""Fill in more information about a channel.""" """Fill in more information about a channel."""
@ -293,7 +286,7 @@ class Storage:
WHERE channels.id = $1 WHERE channels.id = $1
""", channel_id) """, channel_id)
async def _chan_overwrites(self, channel_id): async def _chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]:
overwrite_rows = await self.db.fetch(""" overwrite_rows = await self.db.fetch("""
SELECT target_type, target_role, target_user, allow, deny SELECT target_type, target_role, target_user, allow, deny
FROM channel_overwrites FROM channel_overwrites
@ -369,11 +362,13 @@ class Storage:
drow['id'] = str(drow['id']) drow['id'] = str(drow['id'])
return drow return drow
elif ctype == ChannelType.GROUP_DM: elif ctype == ChannelType.GROUP_DM:
# TODO: group dms
pass pass
return None return None
async def get_channel_ids(self, guild_id: int) -> List[int]: async def get_channel_ids(self, guild_id: int) -> List[int]:
"""Get all channel IDs in a guild."""
rows = await self.db.fetch(""" rows = await self.db.fetch("""
SELECT id SELECT id
FROM guild_channels FROM guild_channels
@ -383,9 +378,9 @@ class Storage:
return [r['id'] for r in rows] return [r['id'] for r in rows]
async def get_channel_data(self, guild_id) -> List[Dict]: async def get_channel_data(self, guild_id) -> List[Dict]:
"""Get channel information on a guild""" """Get channel list information on a guild"""
channel_basics = await self.db.fetch(""" channel_basics = await self.db.fetch("""
SELECT id, guild_id::text, parent_id, name, position, nsfw SELECT id, guild_id::text, parent_id::text, name, position, nsfw
FROM guild_channels FROM guild_channels
WHERE guild_id = $1 WHERE guild_id = $1
""", guild_id) """, guild_id)
@ -412,7 +407,31 @@ class Storage:
return channels return channels
async def get_role(self, role_id: int,
guild_id: int = None) -> Dict[str, Any]:
"""get a single role's information."""
guild_field = 'AND guild_id = $2' if guild_id else ''
args = [role_id]
if guild_id:
args.append(guild_id)
row = await self.db.fetchrow(f"""
SELECT id::text, name, color, hoist, position,
permissions, managed, mentionable
FROM roles
WHERE id = $1 {guild_field}
LIMIT 1
""", *args)
if not row:
return
return dict(row)
async def get_role_data(self, guild_id: int) -> List[Dict[str, Any]]: async def get_role_data(self, guild_id: int) -> List[Dict[str, Any]]:
"""Get role list information on a guild."""
roledata = await self.db.fetch(""" roledata = await self.db.fetch("""
SELECT id::text, name, color, hoist, position, SELECT id::text, name, color, hoist, position,
permissions, managed, mentionable permissions, managed, mentionable
@ -420,12 +439,7 @@ class Storage:
WHERE guild_id = $1 WHERE guild_id = $1
""", guild_id) """, guild_id)
roles = [] return list(map(dict, roledata))
for row in roledata:
roles.append(dict(row))
return roles
async def get_guild_extra(self, guild_id: int, async def get_guild_extra(self, guild_id: int,
user_id=None, large=None) -> Dict: user_id=None, large=None) -> Dict:
@ -438,14 +452,16 @@ class Storage:
WHERE guild_id = $1 WHERE guild_id = $1
""", guild_id) """, guild_id)
if user_id and large: if large:
res['large'] = member_count > large
if user_id:
joined_at = await self.db.fetchval(""" joined_at = await self.db.fetchval("""
SELECT joined_at SELECT joined_at
FROM members FROM members
WHERE guild_id = $1 AND user_id = $2 WHERE guild_id = $1 AND user_id = $2
""", guild_id, user_id) """, guild_id, user_id)
res['large'] = member_count > large
res['joined_at'] = joined_at.isoformat() res['joined_at'] = joined_at.isoformat()
members = await self.get_member_data(guild_id) members = await self.get_member_data(guild_id)
@ -457,23 +473,29 @@ class Storage:
return {**res, **{ return {**res, **{
'member_count': member_count, 'member_count': member_count,
'members': members, 'members': members,
'voice_states': [],
'channels': channels, 'channels': channels,
'roles': roles, 'roles': roles,
'presences': await self.presence.guild_presences( 'presences': await self.presence.guild_presences(
mids, guild_id mids, guild_id
), ),
# TODO: voice state management
'voice_states': [],
}} }}
async def get_guild_full(self, guild_id: int, async def get_guild_full(self, guild_id: int,
user_id: int, large_count: int = 250) -> Dict: user_id: int, large_count: int = 250) -> Dict:
"""Get full information on a guild.
This is a very expensive operation.
"""
guild = await self.get_guild(guild_id, user_id) guild = await self.get_guild(guild_id, user_id)
extra = await self.get_guild_extra(guild_id, user_id, large_count) extra = await self.get_guild_extra(guild_id, user_id, large_count)
return {**guild, **extra} return {**guild, **extra}
async def guild_exists(self, guild_id: int): async def guild_exists(self, guild_id: int) -> bool:
"""Return if a given guild ID exists.""" """Return if a given guild ID exists."""
owner_id = await self.db.fetch(""" owner_id = await self.db.fetch("""
SELECT owner_id SELECT owner_id
@ -484,6 +506,7 @@ class Storage:
return owner_id is not None return owner_id is not None
async def get_member_ids(self, guild_id: int) -> List[int]: async def get_member_ids(self, guild_id: int) -> List[int]:
"""Get member IDs inside a guild"""
rows = await self.db.fetch(""" rows = await self.db.fetch("""
SELECT user_id SELECT user_id
FROM members FROM members
@ -492,7 +515,7 @@ class Storage:
return [r[0] for r in rows] return [r[0] for r in rows]
async def _msg_regex(self, regex, method, content) -> List[Dict]: async def _msg_regex(self, regex, func, content) -> List[Dict]:
res = [] res = []
for match in regex.finditer(content): for match in regex.finditer(content):
@ -503,8 +526,8 @@ class Storage:
except ValueError: except ValueError:
continue continue
obj = await method(found_id) obj = await func(found_id)
if obj: if obj is not None:
res.append(obj) res.append(obj)
return res return res
@ -525,17 +548,50 @@ class Storage:
res = dict(row) res = dict(row)
res['nonce'] = str(res['nonce']) res['nonce'] = str(res['nonce'])
res['timestamp'] = res['timestamp'].isoformat() res['timestamp'] = res['timestamp'].isoformat()
res['edited_timestamp'] = timestamp_(res['edited_timestamp'])
res['type'] = res['message_type'] res['type'] = res['message_type']
res.pop('message_type') res.pop('message_type')
channel_id = int(row['channel_id'])
content = row['content']
guild_id = await self.guild_from_channel(channel_id)
# calculate user mentions and role mentions by regex # calculate user mentions and role mentions by regex
res['mentions'] = await self._msg_regex(USER_MENTION, self.get_user, async def _get_member(user_id):
user = await self.get_user(user_id)
member = None
if guild_id:
# TODO: maybe make this partial?
member = await self.get_member_data_one(guild_id, user_id)
return {**user, **{'member': member}} if member else user
res['mentions'] = await self._msg_regex(USER_MENTION, _get_member,
row['content']) row['content'])
# _dummy just returns the string of the id, since we don't # _dummy just returns the string of the id, since we don't
# actually use the role objects in mention_roles, just their ids. # actually use the role objects in mention_roles, just their ids.
res['mention_roles'] = await self._msg_regex(ROLE_MENTION, _dummy, async def _get_role_mention(role_id: int):
row['content']) if not guild_id:
return str(role_id)
if role_id == guild_id:
# TODO: check MENTION_EVERYONE permission
return str(role_id)
role = await self.get_role(role_id, guild_id)
if not role:
return
if not role['mentionable']:
return
return str(role_id)
res['mention_roles'] = await self._msg_regex(
ROLE_MENTION, _get_role_mention, content)
# TODO: handle webhook authors # TODO: handle webhook authors
res['author'] = await self.get_user(res['author_id']) res['author'] = await self.get_user(res['author_id'])
@ -553,13 +609,7 @@ class Storage:
# TODO: res['pinned'] # TODO: res['pinned']
res['pinned'] = False res['pinned'] = False
# this is specifically for lazy guilds. # this is specifically for lazy guilds:
guild_id = await self.db.fetchval("""
SELECT guild_id
FROM guild_channels
WHERE guild_channels.id = $1
""", int(res['channel_id']))
# only insert when the channel # only insert when the channel
# is actually from a guild. # is actually from a guild.
if guild_id: if guild_id:
@ -811,22 +861,6 @@ class Storage:
return res return res
async def get_all_dms(self, user_id: int) -> List[Dict[str, Any]]:
"""Get all DMs for a user, regardless of the DM state."""
dm_ids = await self.db.fetch("""
SELECT id
FROM dm_channels
WHERE party1_id = $1 OR party2_id = $2
""", user_id)
res = []
for dm_id in dm_ids:
dm_chan = await self.get_dm(dm_id, user_id)
res.append(dm_chan)
return res
async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]: async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]:
"""Get the read state for a user.""" """Get the read state for a user."""
rows = await self.db.fetch(""" rows = await self.db.fetch("""
@ -850,6 +884,7 @@ class Storage:
return res return res
async def guild_from_channel(self, channel_id: int): async def guild_from_channel(self, channel_id: int):
"""Get the guild id coming from a channel id."""
return await self.db.fetchval(""" return await self.db.fetchval("""
SELECT guild_id SELECT guild_id
FROM guild_channels FROM guild_channels