mirror of https://gitlab.com/litecord/litecord.git
storage: add channel fetching logic
- litecord: add enums - storage: add get_user - storage: use column::text instead of str() in some cases
This commit is contained in:
parent
3eb6d5e60f
commit
c477b2ed50
|
|
@ -0,0 +1,7 @@
|
||||||
|
|
||||||
|
class ChannelType:
|
||||||
|
GUILD_TEXT = 0
|
||||||
|
DM = 1
|
||||||
|
GUILD_VOICE = 2
|
||||||
|
GROUP_DM = 3
|
||||||
|
GUILD_CATEGORY = 4
|
||||||
|
|
@ -18,8 +18,7 @@ class GatewayState:
|
||||||
self.seq = kwargs.get('seq', 0)
|
self.seq = kwargs.get('seq', 0)
|
||||||
self.shard = kwargs.get('shard', [0, 1])
|
self.shard = kwargs.get('shard', [0, 1])
|
||||||
self.user_id = kwargs.get('user_id')
|
self.user_id = kwargs.get('user_id')
|
||||||
|
self.bot = kwargs.get('bot', False)
|
||||||
self.ws = None
|
|
||||||
|
|
||||||
for key in kwargs:
|
for key in kwargs:
|
||||||
value = kwargs[key]
|
value = kwargs[key]
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,6 @@ class GatewayWebsocket:
|
||||||
def __init__(self, ws, **kwargs):
|
def __init__(self, ws, **kwargs):
|
||||||
self.ext = WebsocketObjects(*kwargs['prop'])
|
self.ext = WebsocketObjects(*kwargs['prop'])
|
||||||
self.storage = self.ext.storage
|
self.storage = self.ext.storage
|
||||||
self.state_manager = self.ext.state_manager
|
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
|
|
||||||
self.wsp = WebsocketProperties(kwargs.get('v'),
|
self.wsp = WebsocketProperties(kwargs.get('v'),
|
||||||
|
|
@ -99,6 +98,7 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
async def _make_guild_list(self) -> List[int]:
|
async def _make_guild_list(self) -> List[int]:
|
||||||
# TODO: This function does not account for sharding.
|
# TODO: This function does not account for sharding.
|
||||||
|
# TODO: This function does not account for bots.
|
||||||
user_id = self.state.user_id
|
user_id = self.state.user_id
|
||||||
|
|
||||||
guild_ids = await self.ext.db.fetch("""
|
guild_ids = await self.ext.db.fetch("""
|
||||||
|
|
@ -107,15 +107,27 @@ class GatewayWebsocket:
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""", user_id)
|
||||||
|
|
||||||
return [{
|
if self.state.bot:
|
||||||
'id': row[0],
|
return [{
|
||||||
'unavailable': True,
|
'id': row[0],
|
||||||
} for row in guild_ids]
|
'unavailable': True,
|
||||||
|
} for row in guild_ids]
|
||||||
|
|
||||||
|
return [
|
||||||
|
await self.storage.get_guild(row[0], self.state)
|
||||||
|
for row in guild_ids
|
||||||
|
]
|
||||||
|
|
||||||
async def guild_dispatch(self, unavailable_guilds: List[dict]):
|
async def guild_dispatch(self, unavailable_guilds: List[dict]):
|
||||||
|
"""Dispatch GUILD_CREATE information."""
|
||||||
|
|
||||||
|
# Users don't get asynchronous guild dispatching.
|
||||||
|
if not self.state.bot:
|
||||||
|
return
|
||||||
|
|
||||||
for guild_obj in unavailable_guilds:
|
for guild_obj in unavailable_guilds:
|
||||||
guild = await self.storage.get_guild(guild_obj['id'],
|
guild = await self.storage.get_guild(guild_obj['id'],
|
||||||
self.state.user_id)
|
self.state)
|
||||||
|
|
||||||
if not guild:
|
if not guild:
|
||||||
continue
|
continue
|
||||||
|
|
@ -123,7 +135,7 @@ class GatewayWebsocket:
|
||||||
await self.dispatch('GUILD_CREATE', dict(guild))
|
await self.dispatch('GUILD_CREATE', dict(guild))
|
||||||
|
|
||||||
async def dispatch_ready(self):
|
async def dispatch_ready(self):
|
||||||
"""Dispatch the READY packet for a connecting user."""
|
"""Dispatch the READY packet for a connecting account."""
|
||||||
guilds = await self._make_guild_list()
|
guilds = await self._make_guild_list()
|
||||||
user = await self.storage.get_user(self.state.user_id, True)
|
user = await self.storage.get_user(self.state.user_id, True)
|
||||||
|
|
||||||
|
|
@ -180,20 +192,25 @@ class GatewayWebsocket:
|
||||||
except AuthError:
|
except AuthError:
|
||||||
raise WebsocketClose(4004, 'Authentication failed')
|
raise WebsocketClose(4004, 'Authentication failed')
|
||||||
|
|
||||||
|
bot = await self.ext.db.fetchval("""
|
||||||
|
SELECT bot FROM users
|
||||||
|
WHERE id = $1
|
||||||
|
""", user_id)
|
||||||
|
|
||||||
self.state = GatewayState(
|
self.state = GatewayState(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
bot=bot,
|
||||||
properties=properties,
|
properties=properties,
|
||||||
compress=compress,
|
compress=compress,
|
||||||
large=large,
|
large=large,
|
||||||
shard=shard,
|
shard=shard,
|
||||||
presence=presence,
|
presence=presence,
|
||||||
|
ws=self
|
||||||
)
|
)
|
||||||
|
|
||||||
self.state.ws = self
|
|
||||||
|
|
||||||
await self._check_shards()
|
await self._check_shards()
|
||||||
|
|
||||||
self.state_manager.insert(self.state)
|
self.ext.state_manager.insert(self.state)
|
||||||
await self.dispatch_ready()
|
await self.dispatch_ready()
|
||||||
|
|
||||||
async def process_message(self, payload):
|
async def process_message(self, payload):
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from typing import Dict
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from .enums import ChannelType
|
||||||
|
|
||||||
|
|
||||||
class Storage:
|
class Storage:
|
||||||
|
|
@ -6,10 +8,27 @@ class Storage:
|
||||||
def __init__(self, db):
|
def __init__(self, db):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
async def get_user(self, guild_id, secure=False):
|
async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
|
||||||
pass
|
"""Get a single user payload."""
|
||||||
|
user_row = await self.db.fetchrow("""
|
||||||
|
SELECT id::text, username, discriminator, avatar, email,
|
||||||
|
flags, bot, mfa_enabled, verified, premium
|
||||||
|
FROM users
|
||||||
|
WHERE users.id = $1
|
||||||
|
""", user_id)
|
||||||
|
|
||||||
async def get_guild(self, guild_id: int, state) -> Dict:
|
duser = dict(user_row)
|
||||||
|
|
||||||
|
if not secure:
|
||||||
|
duser.pop('email')
|
||||||
|
duser.pop('mfa_enabled')
|
||||||
|
duser.pop('verified')
|
||||||
|
duser.pop('mfa_enabled')
|
||||||
|
|
||||||
|
return duser
|
||||||
|
|
||||||
|
async def get_guild(self, guild_id: int, state=None) -> Dict:
|
||||||
|
"""Get gulid payload."""
|
||||||
row = await self.db.fetchrow("""
|
row = await self.db.fetchrow("""
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM guilds
|
FROM guilds
|
||||||
|
|
@ -39,6 +58,98 @@ class Storage:
|
||||||
'emojis': [],
|
'emojis': [],
|
||||||
}}
|
}}
|
||||||
|
|
||||||
|
async def get_member_data(self, guild_id) -> List[Dict[str, Any]]:
|
||||||
|
"""Get member information on a guild."""
|
||||||
|
members_basic = await self.db.fetch("""
|
||||||
|
SELECT user_id, nickname, joined_at
|
||||||
|
FROM members
|
||||||
|
WHERE guild_id = $1
|
||||||
|
""", guild_id)
|
||||||
|
|
||||||
|
members = []
|
||||||
|
|
||||||
|
for row in members_basic:
|
||||||
|
member_id = row['user_id']
|
||||||
|
|
||||||
|
members_roles = await self.db.fetch("""
|
||||||
|
SELECT role_id::text
|
||||||
|
FROM member_roles
|
||||||
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
|
""", guild_id, member_id)
|
||||||
|
|
||||||
|
members.append({
|
||||||
|
'user': await self.get_user(member_id),
|
||||||
|
'nick': row['nickname'],
|
||||||
|
'roles': [row[0] for row in members_roles],
|
||||||
|
'joined_at': row['joined_at'].isoformat(),
|
||||||
|
'deaf': row['deafened'],
|
||||||
|
'mute': row['muted'],
|
||||||
|
})
|
||||||
|
|
||||||
|
return members
|
||||||
|
|
||||||
|
async def _channels_extra(self, row, channel_type: int) -> Dict:
|
||||||
|
"""Fill in more information about a channel."""
|
||||||
|
# TODO: This could probably be better with a dictionary.
|
||||||
|
|
||||||
|
# TODO: dm and group dm?
|
||||||
|
if channel_type == ChannelType.GUILD_TEXT:
|
||||||
|
topic = await self.db.fetchval("""
|
||||||
|
SELECT topic FROM guild_text_channels
|
||||||
|
WHERE id = $1
|
||||||
|
""", row['id'])
|
||||||
|
|
||||||
|
return {**row, **{
|
||||||
|
'topic': topic,
|
||||||
|
}}
|
||||||
|
elif channel_type == ChannelType.GUILD_VOICE:
|
||||||
|
vrow = await self.db.fetchval("""
|
||||||
|
SELECT bitrate, user_limit FROM guild_voice_channels
|
||||||
|
WHERE id = $1
|
||||||
|
""", row['id'])
|
||||||
|
|
||||||
|
return {**row, **dict(vrow)}
|
||||||
|
|
||||||
|
async def get_channel_data(self, guild_id) -> List[Dict]:
|
||||||
|
"""Get channel information on a guild"""
|
||||||
|
channel_basics = await self.db.fetch("""
|
||||||
|
SELECT * FROM guild_channels
|
||||||
|
WHERE guild_id = $1
|
||||||
|
""", guild_id)
|
||||||
|
|
||||||
|
channels = []
|
||||||
|
|
||||||
|
for row in channel_basics:
|
||||||
|
ctype = await self.db.fetchval("""
|
||||||
|
SELECT channel_type FROM channels
|
||||||
|
WHERE id = $1
|
||||||
|
""", row['id'])
|
||||||
|
|
||||||
|
res = await self._channels_extra(row, ctype)
|
||||||
|
|
||||||
|
# type is a SQL keyword, so we can't do
|
||||||
|
# 'overwrite_type AS type'
|
||||||
|
overwrite_rows = await self.db.fetch("""
|
||||||
|
SELECT user_id::text AS id, overwrite_type, allow, deny
|
||||||
|
FROM channel_overwrites
|
||||||
|
WHERE channel_id = $1
|
||||||
|
""", row['id'])
|
||||||
|
|
||||||
|
def _overwrite_convert(ov_row):
|
||||||
|
drow = dict(ov_row)
|
||||||
|
drow['type'] = drow['overwrite_type']
|
||||||
|
drow.pop('overwrite_type')
|
||||||
|
return drow
|
||||||
|
|
||||||
|
res['permission_overwrites'] = list(map(_overwrite_convert,
|
||||||
|
overwrite_rows))
|
||||||
|
|
||||||
|
# Making sure.
|
||||||
|
res['id'] = str(res['id'])
|
||||||
|
channels.append(res)
|
||||||
|
|
||||||
|
return channels
|
||||||
|
|
||||||
async def get_guild_extra(self, guild_id: int, state=None) -> Dict:
|
async def get_guild_extra(self, guild_id: int, state=None) -> Dict:
|
||||||
"""Get extra information about a guild."""
|
"""Get extra information about a guild."""
|
||||||
res = {}
|
res = {}
|
||||||
|
|
@ -59,37 +170,14 @@ class Storage:
|
||||||
res['large'] = state.large > member_count
|
res['large'] = state.large > member_count
|
||||||
res['joined_at'] = joined_at.isoformat()
|
res['joined_at'] = joined_at.isoformat()
|
||||||
|
|
||||||
members_basic = await self.db.fetch("""
|
members = await self.get_member_data(guild_id)
|
||||||
SELECT user_id, nickname, joined_at
|
channels = await self.get_channel_data(guild_id)
|
||||||
FROM members
|
|
||||||
WHERE guild_id = $1
|
|
||||||
""", guild_id)
|
|
||||||
|
|
||||||
members = []
|
|
||||||
|
|
||||||
for row in members_basic:
|
|
||||||
member_id = row['user_id']
|
|
||||||
|
|
||||||
members_roles = await self.db.fetch("""
|
|
||||||
SELECT role_id
|
|
||||||
FROM member_roles
|
|
||||||
WHERE guild_id = $1 AND user_id = $2
|
|
||||||
""", guild_id, member_id)
|
|
||||||
|
|
||||||
members.append({
|
|
||||||
'user': await self.get_user(member_id),
|
|
||||||
'nick': row['nickname'],
|
|
||||||
'roles': [str(row[0]) for row in members_roles],
|
|
||||||
'joined_at': row['joined_at'].isoformat(),
|
|
||||||
'deaf': row['deafened'],
|
|
||||||
'mute': row['muted'],
|
|
||||||
})
|
|
||||||
|
|
||||||
return {**res, **{
|
return {**res, **{
|
||||||
'member_count': member_count,
|
'member_count': member_count,
|
||||||
'members': members,
|
'members': members,
|
||||||
'voice_states': [],
|
'voice_states': [],
|
||||||
|
'channels': channels,
|
||||||
# TODO: finish those
|
# TODO: finish those
|
||||||
'channels': [],
|
|
||||||
'presences': [],
|
'presences': [],
|
||||||
}}
|
}}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue