mirror of https://gitlab.com/litecord/litecord.git
storage: multiple enhancements
storage: - don't use SELECT * when get_guild'ing - use field::text instead of the str() call madness - simplify get_member_data_one - fix parent_id being an int on get_channel_data - add Storage.get_role - better mention resolving on get_message - remove Storage.get_all_dms
This commit is contained in:
parent
d28c0f1bc6
commit
da4ce66a0c
|
|
@ -14,8 +14,20 @@ async def _dummy(any_id):
|
|||
return str(any_id)
|
||||
|
||||
|
||||
def maybe(typ, val):
|
||||
return typ(val) if val is not None else None
|
||||
|
||||
|
||||
def dict_(val):
|
||||
return dict(val) if val else None
|
||||
return maybe(dict, val)
|
||||
|
||||
|
||||
def str_(val):
|
||||
return maybe(str, val)
|
||||
|
||||
|
||||
def timestamp_(dt):
|
||||
return dt.isoformat() if dt else None
|
||||
|
||||
|
||||
async def _set_json(con):
|
||||
|
|
@ -108,7 +120,13 @@ class Storage:
|
|||
async def get_guild(self, guild_id: int, user_id=None) -> Dict:
|
||||
"""Get gulid payload."""
|
||||
row = await self.db.fetchrow("""
|
||||
SELECT *
|
||||
SELECT id::text, owner_id::text, name, icon, splash,
|
||||
region, afk_channel_id::text, afk_timeout,
|
||||
verification_level, default_message_notifications,
|
||||
explicit_content_filter, mfa_level,
|
||||
embed_enabled, embed_channel_id::text,
|
||||
widget_enabled, widget_channel_id::text,
|
||||
system_channel_id::text
|
||||
FROM guilds
|
||||
WHERE guilds.id = $1
|
||||
""", guild_id)
|
||||
|
|
@ -119,22 +137,7 @@ class Storage:
|
|||
drow = dict(row)
|
||||
|
||||
if user_id:
|
||||
drow['owner'] = drow['owner_id'] == user_id
|
||||
|
||||
# TODO: Probably a really bad idea to repeat str() calls
|
||||
# Any ideas to make this simpler?
|
||||
# (No, changing the types on the db wouldn't be nice)
|
||||
drow['id'] = str(drow['id'])
|
||||
drow['owner_id'] = str(drow['owner_id'])
|
||||
drow['afk_channel_id'] = str(drow['afk_channel_id']) \
|
||||
if drow['afk_channel_id'] else None
|
||||
drow['embed_channel_id'] = str(drow['embed_channel_id']) \
|
||||
if drow['embed_channel_id'] else None
|
||||
|
||||
drow['widget_channel_id'] = str(drow['widget_channel_id']) \
|
||||
if drow['widget_channel_id'] else None
|
||||
drow['system_channel_id'] = str(drow['system_channel_id']) \
|
||||
if drow['system_channel_id'] else None
|
||||
drow['owner'] = drow['owner_id'] == str(user_id)
|
||||
|
||||
# TODO: emojis
|
||||
drow['emojis'] = []
|
||||
|
|
@ -150,31 +153,13 @@ class Storage:
|
|||
|
||||
return [row['guild_id'] for row in guild_ids]
|
||||
|
||||
async def get_member_data_one(self, guild_id, member_id) -> Dict[str, any]:
|
||||
basic = await self.db.fetchrow("""
|
||||
async def _member_basic(self, guild_id: int, member_id: int):
|
||||
return await self.db.fetchrow("""
|
||||
SELECT user_id, nickname, joined_at, deafened, muted
|
||||
FROM members
|
||||
WHERE guild_id = $1 and user_id = $2
|
||||
""", guild_id, member_id)
|
||||
|
||||
if not basic:
|
||||
return
|
||||
|
||||
members_roles = await self.db.fetch("""
|
||||
SELECT role_id::text
|
||||
FROM member_roles
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
|
||||
return {
|
||||
'user': await self.get_user(member_id),
|
||||
'nick': basic['nickname'],
|
||||
'roles': [row[0] for row in members_roles],
|
||||
'joined_at': basic['joined_at'].isoformat(),
|
||||
'deaf': basic['deafened'],
|
||||
'mute': basic['muted'],
|
||||
}
|
||||
|
||||
async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]:
|
||||
members_roles = await self.db.fetch("""
|
||||
SELECT role_id::text
|
||||
|
|
@ -191,18 +176,26 @@ class Storage:
|
|||
'mute': row['muted'],
|
||||
}
|
||||
|
||||
async def get_member_data_one(self, guild_id: int,
|
||||
member_id: int) -> Dict[str, Any]:
|
||||
"""Get data about one member in a guild."""
|
||||
basic = await self._member_basic(guild_id, member_id)
|
||||
|
||||
if not basic:
|
||||
return
|
||||
|
||||
return await self._member_dict(basic, guild_id, member_id)
|
||||
|
||||
async def get_member_multi(self, guild_id: int,
|
||||
user_ids: List[int]) -> List[Dict[str, Any]]:
|
||||
"""Get member information about multiple users in a guild."""
|
||||
members = []
|
||||
|
||||
# bad idea bad idea bad idea
|
||||
for user_id in user_ids:
|
||||
row = await self.db.fetchrow("""
|
||||
SELECT user_id, nickname, joined_at, defened, muted
|
||||
FROM members
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, user_id)
|
||||
row = await self._member_basic(guild_id, user_id)
|
||||
|
||||
if not row:
|
||||
continue
|
||||
|
||||
member = await self._member_dict(row, guild_id, user_id)
|
||||
members.append(member)
|
||||
|
|
@ -247,14 +240,14 @@ class Storage:
|
|||
WHERE channel_id = $1
|
||||
""", channel_id)
|
||||
|
||||
async def chan_last_message_str(self, channel_id: int) -> int:
|
||||
async def chan_last_message_str(self, channel_id: int) -> str:
|
||||
"""Get the last message ID but in a string.
|
||||
|
||||
Converts to None (not the string "None") when
|
||||
no last message ID is found.
|
||||
"""
|
||||
last_msg = await self.chan_last_message(channel_id)
|
||||
return str(last_msg) if last_msg is not None else None
|
||||
return str_(last_msg)
|
||||
|
||||
async def _channels_extra(self, row) -> Dict:
|
||||
"""Fill in more information about a channel."""
|
||||
|
|
@ -293,7 +286,7 @@ class Storage:
|
|||
WHERE channels.id = $1
|
||||
""", channel_id)
|
||||
|
||||
async def _chan_overwrites(self, channel_id):
|
||||
async def _chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]:
|
||||
overwrite_rows = await self.db.fetch("""
|
||||
SELECT target_type, target_role, target_user, allow, deny
|
||||
FROM channel_overwrites
|
||||
|
|
@ -369,11 +362,13 @@ class Storage:
|
|||
drow['id'] = str(drow['id'])
|
||||
return drow
|
||||
elif ctype == ChannelType.GROUP_DM:
|
||||
# TODO: group dms
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
async def get_channel_ids(self, guild_id: int) -> List[int]:
|
||||
"""Get all channel IDs in a guild."""
|
||||
rows = await self.db.fetch("""
|
||||
SELECT id
|
||||
FROM guild_channels
|
||||
|
|
@ -383,9 +378,9 @@ class Storage:
|
|||
return [r['id'] for r in rows]
|
||||
|
||||
async def get_channel_data(self, guild_id) -> List[Dict]:
|
||||
"""Get channel information on a guild"""
|
||||
"""Get channel list information on a guild"""
|
||||
channel_basics = await self.db.fetch("""
|
||||
SELECT id, guild_id::text, parent_id, name, position, nsfw
|
||||
SELECT id, guild_id::text, parent_id::text, name, position, nsfw
|
||||
FROM guild_channels
|
||||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
|
|
@ -412,7 +407,31 @@ class Storage:
|
|||
|
||||
return channels
|
||||
|
||||
async def get_role(self, role_id: int,
|
||||
guild_id: int = None) -> Dict[str, Any]:
|
||||
"""get a single role's information."""
|
||||
|
||||
guild_field = 'AND guild_id = $2' if guild_id else ''
|
||||
|
||||
args = [role_id]
|
||||
if guild_id:
|
||||
args.append(guild_id)
|
||||
|
||||
row = await self.db.fetchrow(f"""
|
||||
SELECT id::text, name, color, hoist, position,
|
||||
permissions, managed, mentionable
|
||||
FROM roles
|
||||
WHERE id = $1 {guild_field}
|
||||
LIMIT 1
|
||||
""", *args)
|
||||
|
||||
if not row:
|
||||
return
|
||||
|
||||
return dict(row)
|
||||
|
||||
async def get_role_data(self, guild_id: int) -> List[Dict[str, Any]]:
|
||||
"""Get role list information on a guild."""
|
||||
roledata = await self.db.fetch("""
|
||||
SELECT id::text, name, color, hoist, position,
|
||||
permissions, managed, mentionable
|
||||
|
|
@ -420,12 +439,7 @@ class Storage:
|
|||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
roles = []
|
||||
|
||||
for row in roledata:
|
||||
roles.append(dict(row))
|
||||
|
||||
return roles
|
||||
return list(map(dict, roledata))
|
||||
|
||||
async def get_guild_extra(self, guild_id: int,
|
||||
user_id=None, large=None) -> Dict:
|
||||
|
|
@ -438,14 +452,16 @@ class Storage:
|
|||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
if user_id and large:
|
||||
if large:
|
||||
res['large'] = member_count > large
|
||||
|
||||
if user_id:
|
||||
joined_at = await self.db.fetchval("""
|
||||
SELECT joined_at
|
||||
FROM members
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, user_id)
|
||||
|
||||
res['large'] = member_count > large
|
||||
res['joined_at'] = joined_at.isoformat()
|
||||
|
||||
members = await self.get_member_data(guild_id)
|
||||
|
|
@ -457,23 +473,29 @@ class Storage:
|
|||
return {**res, **{
|
||||
'member_count': member_count,
|
||||
'members': members,
|
||||
'voice_states': [],
|
||||
'channels': channels,
|
||||
'roles': roles,
|
||||
|
||||
'presences': await self.presence.guild_presences(
|
||||
mids, guild_id
|
||||
),
|
||||
|
||||
# TODO: voice state management
|
||||
'voice_states': [],
|
||||
}}
|
||||
|
||||
async def get_guild_full(self, guild_id: int,
|
||||
user_id: int, large_count: int = 250) -> Dict:
|
||||
"""Get full information on a guild.
|
||||
|
||||
This is a very expensive operation.
|
||||
"""
|
||||
guild = await self.get_guild(guild_id, user_id)
|
||||
extra = await self.get_guild_extra(guild_id, user_id, large_count)
|
||||
|
||||
return {**guild, **extra}
|
||||
|
||||
async def guild_exists(self, guild_id: int):
|
||||
async def guild_exists(self, guild_id: int) -> bool:
|
||||
"""Return if a given guild ID exists."""
|
||||
owner_id = await self.db.fetch("""
|
||||
SELECT owner_id
|
||||
|
|
@ -484,6 +506,7 @@ class Storage:
|
|||
return owner_id is not None
|
||||
|
||||
async def get_member_ids(self, guild_id: int) -> List[int]:
|
||||
"""Get member IDs inside a guild"""
|
||||
rows = await self.db.fetch("""
|
||||
SELECT user_id
|
||||
FROM members
|
||||
|
|
@ -492,7 +515,7 @@ class Storage:
|
|||
|
||||
return [r[0] for r in rows]
|
||||
|
||||
async def _msg_regex(self, regex, method, content) -> List[Dict]:
|
||||
async def _msg_regex(self, regex, func, content) -> List[Dict]:
|
||||
res = []
|
||||
|
||||
for match in regex.finditer(content):
|
||||
|
|
@ -503,8 +526,8 @@ class Storage:
|
|||
except ValueError:
|
||||
continue
|
||||
|
||||
obj = await method(found_id)
|
||||
if obj:
|
||||
obj = await func(found_id)
|
||||
if obj is not None:
|
||||
res.append(obj)
|
||||
|
||||
return res
|
||||
|
|
@ -525,17 +548,50 @@ class Storage:
|
|||
res = dict(row)
|
||||
res['nonce'] = str(res['nonce'])
|
||||
res['timestamp'] = res['timestamp'].isoformat()
|
||||
res['edited_timestamp'] = timestamp_(res['edited_timestamp'])
|
||||
|
||||
res['type'] = res['message_type']
|
||||
res.pop('message_type')
|
||||
|
||||
channel_id = int(row['channel_id'])
|
||||
content = row['content']
|
||||
guild_id = await self.guild_from_channel(channel_id)
|
||||
|
||||
# calculate user mentions and role mentions by regex
|
||||
res['mentions'] = await self._msg_regex(USER_MENTION, self.get_user,
|
||||
async def _get_member(user_id):
|
||||
user = await self.get_user(user_id)
|
||||
member = None
|
||||
|
||||
if guild_id:
|
||||
# TODO: maybe make this partial?
|
||||
member = await self.get_member_data_one(guild_id, user_id)
|
||||
|
||||
return {**user, **{'member': member}} if member else user
|
||||
|
||||
res['mentions'] = await self._msg_regex(USER_MENTION, _get_member,
|
||||
row['content'])
|
||||
|
||||
# _dummy just returns the string of the id, since we don't
|
||||
# actually use the role objects in mention_roles, just their ids.
|
||||
res['mention_roles'] = await self._msg_regex(ROLE_MENTION, _dummy,
|
||||
row['content'])
|
||||
async def _get_role_mention(role_id: int):
|
||||
if not guild_id:
|
||||
return str(role_id)
|
||||
|
||||
if role_id == guild_id:
|
||||
# TODO: check MENTION_EVERYONE permission
|
||||
return str(role_id)
|
||||
|
||||
role = await self.get_role(role_id, guild_id)
|
||||
if not role:
|
||||
return
|
||||
|
||||
if not role['mentionable']:
|
||||
return
|
||||
|
||||
return str(role_id)
|
||||
|
||||
res['mention_roles'] = await self._msg_regex(
|
||||
ROLE_MENTION, _get_role_mention, content)
|
||||
|
||||
# TODO: handle webhook authors
|
||||
res['author'] = await self.get_user(res['author_id'])
|
||||
|
|
@ -553,13 +609,7 @@ class Storage:
|
|||
# TODO: res['pinned']
|
||||
res['pinned'] = False
|
||||
|
||||
# this is specifically for lazy guilds.
|
||||
guild_id = await self.db.fetchval("""
|
||||
SELECT guild_id
|
||||
FROM guild_channels
|
||||
WHERE guild_channels.id = $1
|
||||
""", int(res['channel_id']))
|
||||
|
||||
# this is specifically for lazy guilds:
|
||||
# only insert when the channel
|
||||
# is actually from a guild.
|
||||
if guild_id:
|
||||
|
|
@ -811,22 +861,6 @@ class Storage:
|
|||
|
||||
return res
|
||||
|
||||
async def get_all_dms(self, user_id: int) -> List[Dict[str, Any]]:
|
||||
"""Get all DMs for a user, regardless of the DM state."""
|
||||
dm_ids = await self.db.fetch("""
|
||||
SELECT id
|
||||
FROM dm_channels
|
||||
WHERE party1_id = $1 OR party2_id = $2
|
||||
""", user_id)
|
||||
|
||||
res = []
|
||||
|
||||
for dm_id in dm_ids:
|
||||
dm_chan = await self.get_dm(dm_id, user_id)
|
||||
res.append(dm_chan)
|
||||
|
||||
return res
|
||||
|
||||
async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]:
|
||||
"""Get the read state for a user."""
|
||||
rows = await self.db.fetch("""
|
||||
|
|
@ -850,6 +884,7 @@ class Storage:
|
|||
return res
|
||||
|
||||
async def guild_from_channel(self, channel_id: int):
|
||||
"""Get the guild id coming from a channel id."""
|
||||
return await self.db.fetchval("""
|
||||
SELECT guild_id
|
||||
FROM guild_channels
|
||||
|
|
|
|||
Loading…
Reference in New Issue