From 3eb6d5e60f6d467d411180f9b16241191c445dc2 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 20 Jun 2018 16:53:22 -0300 Subject: [PATCH] litecord: add Storage Storage serves as a way to reduce code repeatbility. So that we don't need to keep repeating the same SQL statements over and over, and to detach some SQL calls into their own code (like guild fetching) - gateway.websocket: add WebsocketObjects to hold db, state_manager, storage and loop - gateway.websocket: add _make_guild_list - schema: add members.deafened, members.muted --- litecord/gateway/gateway.py | 4 +- litecord/gateway/websocket.py | 53 ++++++++++++++++--- litecord/storage.py | 95 +++++++++++++++++++++++++++++++++++ run.py | 5 +- schema.sql | 2 + 5 files changed, 148 insertions(+), 11 deletions(-) create mode 100644 litecord/storage.py diff --git a/litecord/gateway/gateway.py b/litecord/gateway/gateway.py index d0105e5..fe86226 100644 --- a/litecord/gateway/gateway.py +++ b/litecord/gateway/gateway.py @@ -2,7 +2,7 @@ import urllib.parse from .websocket import GatewayWebsocket -async def websocket_handler(db, sm, ws, url): +async def websocket_handler(prop, ws, url): qs = urllib.parse.parse_qs( urllib.parse.urlparse(url).query ) @@ -27,6 +27,6 @@ async def websocket_handler(db, sm, ws, url): if gw_compress and gw_compress not in ('zlib-stream',): return await ws.close(1000, 'Invalid gateway compress') - gws = GatewayWebsocket(sm, db, ws, v=gw_version, + gws = GatewayWebsocket(ws, prop=prop, v=gw_version, encoding=gw_encoding, compress=gw_compress) await gws.run() diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 4da6940..430e98b 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -1,5 +1,6 @@ import json import collections +from typing import List import earl from logbook import Logger @@ -17,6 +18,10 @@ WebsocketProperties = collections.namedtuple( 'WebsocketProperties', 'v encoding compress' ) +WebsocketObjects = collections.namedtuple( + 'WebsocketObjects', 'db state_manager storage loop' +) + def encode_json(payload) -> str: return json.dumps(payload) @@ -30,16 +35,17 @@ def encode_etf(payload) -> str: return earl.pack(payload) -def decode_etf(data): +def decode_etf(data: bytes): return earl.unpack(data) class GatewayWebsocket: """Main gateway websocket logic.""" - def __init__(self, sm, db, ws, **kwargs): - self.state_manager = sm - self.db = db + def __init__(self, ws, **kwargs): + self.ext = WebsocketObjects(*kwargs['prop']) + self.storage = self.ext.storage + self.state_manager = self.ext.state_manager self.ws = ws self.wsp = WebsocketProperties(kwargs.get('v'), @@ -91,22 +97,53 @@ class GatewayWebsocket: 'd': data, }) + async def _make_guild_list(self) -> List[int]: + # TODO: This function does not account for sharding. + user_id = self.state.user_id + + guild_ids = await self.ext.db.fetch(""" + SELECT guild_id + FROM members + WHERE user_id = $1 + """, user_id) + + return [{ + 'id': row[0], + 'unavailable': True, + } for row in guild_ids] + + async def guild_dispatch(self, unavailable_guilds: List[dict]): + for guild_obj in unavailable_guilds: + guild = await self.storage.get_guild(guild_obj['id'], + self.state.user_id) + + if not guild: + continue + + await self.dispatch('GUILD_CREATE', dict(guild)) + async def dispatch_ready(self): """Dispatch the READY packet for a connecting user.""" + guilds = await self._make_guild_list() + user = await self.storage.get_user(self.state.user_id, True) + await self.dispatch('READY', { 'v': 6, - 'user': {}, + 'user': user, 'private_channels': [], - 'guilds': [], + 'guilds': guilds, 'session_id': self.state.session_id, '_trace': ['transbian'] }) + # async dispatch of guilds + self.ext.loop.create_task(self.guild_dispatch(guilds)) + async def _check_shards(self): shard = self.state.shard current_shard, shard_count = shard - guilds = await self.db.fetchval(""" + guilds = await self.ext.db.fetchval(""" SELECT COUNT(*) FROM members WHERE user_id = $1 @@ -139,7 +176,7 @@ class GatewayWebsocket: presence = data.get('presence') try: - user_id = await raw_token_check(token, self.db) + user_id = await raw_token_check(token, self.ext.db) except AuthError: raise WebsocketClose(4004, 'Authentication failed') diff --git a/litecord/storage.py b/litecord/storage.py new file mode 100644 index 0000000..44bf9a3 --- /dev/null +++ b/litecord/storage.py @@ -0,0 +1,95 @@ +from typing import Dict + + +class Storage: + """Class for common SQL statements.""" + def __init__(self, db): + self.db = db + + async def get_user(self, guild_id, secure=False): + pass + + async def get_guild(self, guild_id: int, state) -> Dict: + row = await self.db.fetchrow(""" + SELECT * + FROM guilds + WHERE guilds.id = $1 + """, guild_id) + + if not row: + return + + drow = dict(row) + + if state: + drow['owner'] = drow['owner_id'] == state.user_id + + # TODO: Probably a really bad idea to repeat str() calls + # Any ideas to make this simpler? + # (No, changing the types on the db wouldn't be nice) + drow['id'] = str(drow['id']) + drow['owner_id'] = str(drow['owner_id']) + drow['afk_channel_id'] = str(drow['afk_channel_id']) + drow['embed_channel_id'] = str(drow['embed_channel_id']) + drow['widget_channel_id'] = str(drow['widget_channel_id']) + drow['system_channel_id'] = str(drow['system_channel_id']) + + return {**drow, **{ + 'roles': [], + 'emojis': [], + }} + + async def get_guild_extra(self, guild_id: int, state=None) -> Dict: + """Get extra information about a guild.""" + res = {} + + member_count = await self.db.fetchval(""" + SELECT COUNT(*) + FROM members + WHERE guild_id = $1 + """, guild_id) + + if state: + joined_at = await self.db.fetchval(""" + SELECT joined_at + FROM members + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, state.user_id) + + res['large'] = state.large > member_count + res['joined_at'] = joined_at.isoformat() + + members_basic = await self.db.fetch(""" + SELECT user_id, nickname, joined_at + FROM members + WHERE guild_id = $1 + """, guild_id) + + members = [] + + for row in members_basic: + member_id = row['user_id'] + + members_roles = await self.db.fetch(""" + SELECT role_id + FROM member_roles + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + members.append({ + 'user': await self.get_user(member_id), + 'nick': row['nickname'], + 'roles': [str(row[0]) for row in members_roles], + 'joined_at': row['joined_at'].isoformat(), + 'deaf': row['deafened'], + 'mute': row['muted'], + }) + + return {**res, **{ + 'member_count': member_count, + 'members': members, + 'voice_states': [], + # TODO: finish those + 'channels': [], + 'presences': [], + }} diff --git a/run.py b/run.py index b46906b..59b9c42 100644 --- a/run.py +++ b/run.py @@ -12,6 +12,7 @@ from litecord.blueprints import gateway, auth from litecord.gateway import websocket_handler from litecord.errors import LitecordError from litecord.gateway.state_manager import StateManager +from litecord.storage import Storage # setup logbook handler = StreamHandler(sys.stdout, level=logbook.INFO) @@ -46,6 +47,7 @@ async def app_before_serving(): g.loop = asyncio.get_event_loop() app.state_manager = StateManager() + app.storage = Storage(app.db) # start the websocket, etc host, port = app.config['WS_HOST'], app.config['WS_PORT'] @@ -54,7 +56,8 @@ async def app_before_serving(): async def _wrapper(ws, url): # We wrap the main websocket_handler # so we can pass quart's app object. - await websocket_handler(app.db, app.state_manager, ws, url) + await websocket_handler((app.db, app.state_manager, + app.storage, app.loop), ws, url) ws_future = websockets.serve(_wrapper, host, port) diff --git a/schema.sql b/schema.sql index dfa9a9a..581d42c 100644 --- a/schema.sql +++ b/schema.sql @@ -267,6 +267,8 @@ CREATE TABLE IF NOT EXISTS members ( guild_id bigint REFERENCES guilds (id) ON DELETE CASCADE, nickname varchar(100) DEFAULT NULL, joined_at timestamp without time zone default now(), + deafened boolean DEFAULT false, + muted boolean DEFAULT false, PRIMARY KEY (user_id, guild_id) );