storage: add channel fetching logic

- litecord: add enums
 - storage: add get_user
 - storage: use column::text instead of str() in some cases
This commit is contained in:
Luna Mendes 2018-06-20 17:39:21 -03:00
parent 3eb6d5e60f
commit c477b2ed50
4 changed files with 153 additions and 42 deletions

7
litecord/enums.py Normal file
View File

@ -0,0 +1,7 @@
class ChannelType:
GUILD_TEXT = 0
DM = 1
GUILD_VOICE = 2
GROUP_DM = 3
GUILD_CATEGORY = 4

View File

@ -18,8 +18,7 @@ class GatewayState:
self.seq = kwargs.get('seq', 0) self.seq = kwargs.get('seq', 0)
self.shard = kwargs.get('shard', [0, 1]) self.shard = kwargs.get('shard', [0, 1])
self.user_id = kwargs.get('user_id') self.user_id = kwargs.get('user_id')
self.bot = kwargs.get('bot', False)
self.ws = None
for key in kwargs: for key in kwargs:
value = kwargs[key] value = kwargs[key]

View File

@ -45,7 +45,6 @@ class GatewayWebsocket:
def __init__(self, ws, **kwargs): def __init__(self, ws, **kwargs):
self.ext = WebsocketObjects(*kwargs['prop']) self.ext = WebsocketObjects(*kwargs['prop'])
self.storage = self.ext.storage self.storage = self.ext.storage
self.state_manager = self.ext.state_manager
self.ws = ws self.ws = ws
self.wsp = WebsocketProperties(kwargs.get('v'), self.wsp = WebsocketProperties(kwargs.get('v'),
@ -99,6 +98,7 @@ class GatewayWebsocket:
async def _make_guild_list(self) -> List[int]: async def _make_guild_list(self) -> List[int]:
# TODO: This function does not account for sharding. # TODO: This function does not account for sharding.
# TODO: This function does not account for bots.
user_id = self.state.user_id user_id = self.state.user_id
guild_ids = await self.ext.db.fetch(""" guild_ids = await self.ext.db.fetch("""
@ -107,15 +107,27 @@ class GatewayWebsocket:
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """, user_id)
return [{ if self.state.bot:
'id': row[0], return [{
'unavailable': True, 'id': row[0],
} for row in guild_ids] 'unavailable': True,
} for row in guild_ids]
return [
await self.storage.get_guild(row[0], self.state)
for row in guild_ids
]
async def guild_dispatch(self, unavailable_guilds: List[dict]): async def guild_dispatch(self, unavailable_guilds: List[dict]):
"""Dispatch GUILD_CREATE information."""
# Users don't get asynchronous guild dispatching.
if not self.state.bot:
return
for guild_obj in unavailable_guilds: for guild_obj in unavailable_guilds:
guild = await self.storage.get_guild(guild_obj['id'], guild = await self.storage.get_guild(guild_obj['id'],
self.state.user_id) self.state)
if not guild: if not guild:
continue continue
@ -123,7 +135,7 @@ class GatewayWebsocket:
await self.dispatch('GUILD_CREATE', dict(guild)) await self.dispatch('GUILD_CREATE', dict(guild))
async def dispatch_ready(self): async def dispatch_ready(self):
"""Dispatch the READY packet for a connecting user.""" """Dispatch the READY packet for a connecting account."""
guilds = await self._make_guild_list() guilds = await self._make_guild_list()
user = await self.storage.get_user(self.state.user_id, True) user = await self.storage.get_user(self.state.user_id, True)
@ -180,20 +192,25 @@ class GatewayWebsocket:
except AuthError: except AuthError:
raise WebsocketClose(4004, 'Authentication failed') raise WebsocketClose(4004, 'Authentication failed')
bot = await self.ext.db.fetchval("""
SELECT bot FROM users
WHERE id = $1
""", user_id)
self.state = GatewayState( self.state = GatewayState(
user_id=user_id, user_id=user_id,
bot=bot,
properties=properties, properties=properties,
compress=compress, compress=compress,
large=large, large=large,
shard=shard, shard=shard,
presence=presence, presence=presence,
ws=self
) )
self.state.ws = self
await self._check_shards() await self._check_shards()
self.state_manager.insert(self.state) self.ext.state_manager.insert(self.state)
await self.dispatch_ready() await self.dispatch_ready()
async def process_message(self, payload): async def process_message(self, payload):

View File

@ -1,4 +1,6 @@
from typing import Dict from typing import List, Dict, Any
from .enums import ChannelType
class Storage: class Storage:
@ -6,10 +8,27 @@ class Storage:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
async def get_user(self, guild_id, secure=False): async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
pass """Get a single user payload."""
user_row = await self.db.fetchrow("""
SELECT id::text, username, discriminator, avatar, email,
flags, bot, mfa_enabled, verified, premium
FROM users
WHERE users.id = $1
""", user_id)
async def get_guild(self, guild_id: int, state) -> Dict: duser = dict(user_row)
if not secure:
duser.pop('email')
duser.pop('mfa_enabled')
duser.pop('verified')
duser.pop('mfa_enabled')
return duser
async def get_guild(self, guild_id: int, state=None) -> Dict:
"""Get gulid payload."""
row = await self.db.fetchrow(""" row = await self.db.fetchrow("""
SELECT * SELECT *
FROM guilds FROM guilds
@ -39,6 +58,98 @@ class Storage:
'emojis': [], 'emojis': [],
}} }}
async def get_member_data(self, guild_id) -> List[Dict[str, Any]]:
"""Get member information on a guild."""
members_basic = await self.db.fetch("""
SELECT user_id, nickname, joined_at
FROM members
WHERE guild_id = $1
""", guild_id)
members = []
for row in members_basic:
member_id = row['user_id']
members_roles = await self.db.fetch("""
SELECT role_id::text
FROM member_roles
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
members.append({
'user': await self.get_user(member_id),
'nick': row['nickname'],
'roles': [row[0] for row in members_roles],
'joined_at': row['joined_at'].isoformat(),
'deaf': row['deafened'],
'mute': row['muted'],
})
return members
async def _channels_extra(self, row, channel_type: int) -> Dict:
"""Fill in more information about a channel."""
# TODO: This could probably be better with a dictionary.
# TODO: dm and group dm?
if channel_type == ChannelType.GUILD_TEXT:
topic = await self.db.fetchval("""
SELECT topic FROM guild_text_channels
WHERE id = $1
""", row['id'])
return {**row, **{
'topic': topic,
}}
elif channel_type == ChannelType.GUILD_VOICE:
vrow = await self.db.fetchval("""
SELECT bitrate, user_limit FROM guild_voice_channels
WHERE id = $1
""", row['id'])
return {**row, **dict(vrow)}
async def get_channel_data(self, guild_id) -> List[Dict]:
"""Get channel information on a guild"""
channel_basics = await self.db.fetch("""
SELECT * FROM guild_channels
WHERE guild_id = $1
""", guild_id)
channels = []
for row in channel_basics:
ctype = await self.db.fetchval("""
SELECT channel_type FROM channels
WHERE id = $1
""", row['id'])
res = await self._channels_extra(row, ctype)
# type is a SQL keyword, so we can't do
# 'overwrite_type AS type'
overwrite_rows = await self.db.fetch("""
SELECT user_id::text AS id, overwrite_type, allow, deny
FROM channel_overwrites
WHERE channel_id = $1
""", row['id'])
def _overwrite_convert(ov_row):
drow = dict(ov_row)
drow['type'] = drow['overwrite_type']
drow.pop('overwrite_type')
return drow
res['permission_overwrites'] = list(map(_overwrite_convert,
overwrite_rows))
# Making sure.
res['id'] = str(res['id'])
channels.append(res)
return channels
async def get_guild_extra(self, guild_id: int, state=None) -> Dict: async def get_guild_extra(self, guild_id: int, state=None) -> Dict:
"""Get extra information about a guild.""" """Get extra information about a guild."""
res = {} res = {}
@ -59,37 +170,14 @@ class Storage:
res['large'] = state.large > member_count res['large'] = state.large > member_count
res['joined_at'] = joined_at.isoformat() res['joined_at'] = joined_at.isoformat()
members_basic = await self.db.fetch(""" members = await self.get_member_data(guild_id)
SELECT user_id, nickname, joined_at channels = await self.get_channel_data(guild_id)
FROM members
WHERE guild_id = $1
""", guild_id)
members = []
for row in members_basic:
member_id = row['user_id']
members_roles = await self.db.fetch("""
SELECT role_id
FROM member_roles
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
members.append({
'user': await self.get_user(member_id),
'nick': row['nickname'],
'roles': [str(row[0]) for row in members_roles],
'joined_at': row['joined_at'].isoformat(),
'deaf': row['deafened'],
'mute': row['muted'],
})
return {**res, **{ return {**res, **{
'member_count': member_count, 'member_count': member_count,
'members': members, 'members': members,
'voice_states': [], 'voice_states': [],
'channels': channels,
# TODO: finish those # TODO: finish those
'channels': [],
'presences': [], 'presences': [],
}} }}