diff --git a/litecord/gateway/gateway.py b/litecord/gateway/gateway.py index d0105e5..fe86226 100644 --- a/litecord/gateway/gateway.py +++ b/litecord/gateway/gateway.py @@ -2,7 +2,7 @@ import urllib.parse from .websocket import GatewayWebsocket -async def websocket_handler(db, sm, ws, url): +async def websocket_handler(prop, ws, url): qs = urllib.parse.parse_qs( urllib.parse.urlparse(url).query ) @@ -27,6 +27,6 @@ async def websocket_handler(db, sm, ws, url): if gw_compress and gw_compress not in ('zlib-stream',): return await ws.close(1000, 'Invalid gateway compress') - gws = GatewayWebsocket(sm, db, ws, v=gw_version, + gws = GatewayWebsocket(ws, prop=prop, v=gw_version, encoding=gw_encoding, compress=gw_compress) await gws.run() diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 4da6940..430e98b 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -1,5 +1,6 @@ import json import collections +from typing import List import earl from logbook import Logger @@ -17,6 +18,10 @@ WebsocketProperties = collections.namedtuple( 'WebsocketProperties', 'v encoding compress' ) +WebsocketObjects = collections.namedtuple( + 'WebsocketObjects', 'db state_manager storage loop' +) + def encode_json(payload) -> str: return json.dumps(payload) @@ -30,16 +35,17 @@ def encode_etf(payload) -> str: return earl.pack(payload) -def decode_etf(data): +def decode_etf(data: bytes): return earl.unpack(data) class GatewayWebsocket: """Main gateway websocket logic.""" - def __init__(self, sm, db, ws, **kwargs): - self.state_manager = sm - self.db = db + def __init__(self, ws, **kwargs): + self.ext = WebsocketObjects(*kwargs['prop']) + self.storage = self.ext.storage + self.state_manager = self.ext.state_manager self.ws = ws self.wsp = WebsocketProperties(kwargs.get('v'), @@ -91,22 +97,53 @@ class GatewayWebsocket: 'd': data, }) + async def _make_guild_list(self) -> List[int]: + # TODO: This function does not account for sharding. + user_id = self.state.user_id + + guild_ids = await self.ext.db.fetch(""" + SELECT guild_id + FROM members + WHERE user_id = $1 + """, user_id) + + return [{ + 'id': row[0], + 'unavailable': True, + } for row in guild_ids] + + async def guild_dispatch(self, unavailable_guilds: List[dict]): + for guild_obj in unavailable_guilds: + guild = await self.storage.get_guild(guild_obj['id'], + self.state.user_id) + + if not guild: + continue + + await self.dispatch('GUILD_CREATE', dict(guild)) + async def dispatch_ready(self): """Dispatch the READY packet for a connecting user.""" + guilds = await self._make_guild_list() + user = await self.storage.get_user(self.state.user_id, True) + await self.dispatch('READY', { 'v': 6, - 'user': {}, + 'user': user, 'private_channels': [], - 'guilds': [], + 'guilds': guilds, 'session_id': self.state.session_id, '_trace': ['transbian'] }) + # async dispatch of guilds + self.ext.loop.create_task(self.guild_dispatch(guilds)) + async def _check_shards(self): shard = self.state.shard current_shard, shard_count = shard - guilds = await self.db.fetchval(""" + guilds = await self.ext.db.fetchval(""" SELECT COUNT(*) FROM members WHERE user_id = $1 @@ -139,7 +176,7 @@ class GatewayWebsocket: presence = data.get('presence') try: - user_id = await raw_token_check(token, self.db) + user_id = await raw_token_check(token, self.ext.db) except AuthError: raise WebsocketClose(4004, 'Authentication failed') diff --git a/litecord/storage.py b/litecord/storage.py new file mode 100644 index 0000000..44bf9a3 --- /dev/null +++ b/litecord/storage.py @@ -0,0 +1,95 @@ +from typing import Dict + + +class Storage: + """Class for common SQL statements.""" + def __init__(self, db): + self.db = db + + async def get_user(self, guild_id, secure=False): + pass + + async def get_guild(self, guild_id: int, state) -> Dict: + row = await self.db.fetchrow(""" + SELECT * + FROM guilds + WHERE guilds.id = $1 + """, guild_id) + + if not row: + return + + drow = dict(row) + + if state: + drow['owner'] = drow['owner_id'] == state.user_id + + # TODO: Probably a really bad idea to repeat str() calls + # Any ideas to make this simpler? + # (No, changing the types on the db wouldn't be nice) + drow['id'] = str(drow['id']) + drow['owner_id'] = str(drow['owner_id']) + drow['afk_channel_id'] = str(drow['afk_channel_id']) + drow['embed_channel_id'] = str(drow['embed_channel_id']) + drow['widget_channel_id'] = str(drow['widget_channel_id']) + drow['system_channel_id'] = str(drow['system_channel_id']) + + return {**drow, **{ + 'roles': [], + 'emojis': [], + }} + + async def get_guild_extra(self, guild_id: int, state=None) -> Dict: + """Get extra information about a guild.""" + res = {} + + member_count = await self.db.fetchval(""" + SELECT COUNT(*) + FROM members + WHERE guild_id = $1 + """, guild_id) + + if state: + joined_at = await self.db.fetchval(""" + SELECT joined_at + FROM members + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, state.user_id) + + res['large'] = state.large > member_count + res['joined_at'] = joined_at.isoformat() + + members_basic = await self.db.fetch(""" + SELECT user_id, nickname, joined_at + FROM members + WHERE guild_id = $1 + """, guild_id) + + members = [] + + for row in members_basic: + member_id = row['user_id'] + + members_roles = await self.db.fetch(""" + SELECT role_id + FROM member_roles + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + members.append({ + 'user': await self.get_user(member_id), + 'nick': row['nickname'], + 'roles': [str(row[0]) for row in members_roles], + 'joined_at': row['joined_at'].isoformat(), + 'deaf': row['deafened'], + 'mute': row['muted'], + }) + + return {**res, **{ + 'member_count': member_count, + 'members': members, + 'voice_states': [], + # TODO: finish those + 'channels': [], + 'presences': [], + }} diff --git a/run.py b/run.py index b46906b..59b9c42 100644 --- a/run.py +++ b/run.py @@ -12,6 +12,7 @@ from litecord.blueprints import gateway, auth from litecord.gateway import websocket_handler from litecord.errors import LitecordError from litecord.gateway.state_manager import StateManager +from litecord.storage import Storage # setup logbook handler = StreamHandler(sys.stdout, level=logbook.INFO) @@ -46,6 +47,7 @@ async def app_before_serving(): g.loop = asyncio.get_event_loop() app.state_manager = StateManager() + app.storage = Storage(app.db) # start the websocket, etc host, port = app.config['WS_HOST'], app.config['WS_PORT'] @@ -54,7 +56,8 @@ async def app_before_serving(): async def _wrapper(ws, url): # We wrap the main websocket_handler # so we can pass quart's app object. - await websocket_handler(app.db, app.state_manager, ws, url) + await websocket_handler((app.db, app.state_manager, + app.storage, app.loop), ws, url) ws_future = websockets.serve(_wrapper, host, port) diff --git a/schema.sql b/schema.sql index dfa9a9a..581d42c 100644 --- a/schema.sql +++ b/schema.sql @@ -267,6 +267,8 @@ CREATE TABLE IF NOT EXISTS members ( guild_id bigint REFERENCES guilds (id) ON DELETE CASCADE, nickname varchar(100) DEFAULT NULL, joined_at timestamp without time zone default now(), + deafened boolean DEFAULT false, + muted boolean DEFAULT false, PRIMARY KEY (user_id, guild_id) );