From d39783e666e2b26468a8a570064ba2a385ee8174 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 1 Sep 2018 23:53:36 -0300 Subject: [PATCH] Add barebones implementation for GUILD_SYNC - state_manager: add StateManager.guild_states - add PresenceManager in the presence module - fix get_user_guilds not returning ints - gateway: add dummy handler for op 4 - gateway: add hazmat implementation for op 14 - run: keep websockets logger on INFO - run: add more headers on app_after_request --- litecord/gateway/opcodes.py | 1 + litecord/gateway/state_manager.py | 16 ++++++++- litecord/gateway/websocket.py | 58 +++++++++++++++++++++++++++---- litecord/presence.py | 28 +++++++++++++++ litecord/storage.py | 2 +- run.py | 14 ++++++-- 6 files changed, 109 insertions(+), 10 deletions(-) create mode 100644 litecord/presence.py diff --git a/litecord/gateway/opcodes.py b/litecord/gateway/opcodes.py index 1cf83de..2b6a256 100644 --- a/litecord/gateway/opcodes.py +++ b/litecord/gateway/opcodes.py @@ -13,3 +13,4 @@ class OP: HELLO = 10 HEARTBEAT_ACK = 11 GUILD_SYNC = 12 + UNKNOWN = 14 diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index 67c4dd5..291cd2d 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Dict, Any from collections import defaultdict from logbook import Logger @@ -51,3 +51,17 @@ class StateManager: states.append(state) return states + + def guild_states(self, member_ids: List[int], + guild_id: int) -> List[GatewayState]: + states = [] + + for member_id in member_ids: + member_states = self.fetch_states(member_id, guild_id) + + # for now, just get the first state + state = next(iter(member_states)) + + states.append(state) + + return states diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 5cd43c3..6571334 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -22,7 +22,7 @@ WebsocketProperties = collections.namedtuple( ) WebsocketObjects = collections.namedtuple( - 'WebsocketObjects', 'db state_manager storage loop dispatcher' + 'WebsocketObjects', 'db state_manager storage loop dispatcher presence' ) @@ -48,6 +48,7 @@ class GatewayWebsocket: def __init__(self, ws, **kwargs): self.ext = WebsocketObjects(*kwargs['prop']) self.storage = self.ext.storage + self.presence = self.ext.presence self.ws = ws self.wsp = WebsocketProperties(kwargs.get('v'), @@ -130,11 +131,11 @@ class GatewayWebsocket: return [ { - **await self.storage.get_guild(row[0], user_id), - **await self.storage.get_guild_extra(row[0], user_id, + **await self.storage.get_guild(guild_id, user_id), + **await self.storage.get_guild_extra(guild_id, user_id, self.state.large) } - for row in guild_ids + for guild_id in guild_ids ] async def guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]): @@ -307,6 +308,10 @@ class GatewayWebsocket: """Handle OP 3 Status Update.""" pass + async def handle_4(self, payload: Dict[str, Any]): + """Handle OP 4 Voice Status Update.""" + pass + async def handle_6(self, payload: Dict[str, Any]): """Handle OP 6 Resume.""" data = payload['d'] @@ -349,13 +354,54 @@ class GatewayWebsocket: await self.dispatch('RESUMED', {}) + async def _guild_sync(self, guild_id: int): + members = await self.storage.get_member_data(guild_id) + member_ids = [int(m['user']['id']) for m in members] + + log.debug(f'Syncing guild {guild_id} with {len(member_ids)} members') + presences = await self.presence.guild_presences(member_ids, guild_id) + + await self.dispatch('GUILD_SYNC', { + 'id': str(guild_id), + 'presences': presences, + 'members': members, + }) + async def handle_12(self, payload: Dict[str, Any]): """Handle OP 12 Guild Sync.""" data = payload['d'] - for _guild_id in data: + gids = await self.storage.get_user_guilds(self.state.user_id) + + for guild_id in data: + try: + guild_id = int(guild_id) + except (ValueError, TypeError): + continue + # check if user in guild - pass + if guild_id not in gids: + continue + + await self._guild_sync(guild_id) + + async def handle_14(self, payload: Dict[str, Any]): + # NOTE: put your HAZMAT suit on. + # OP 12 wasn't sent by the client, but OP 14 was, + # it contained a guild id, so i assume this is an + # evolution of OP 12. + # OP 14 is undocumented. + data = payload['d'] + + gids = await self.storage.get_user_guilds(self.state.user_id) + guild_id = int(data['guild_id']) + + # make sure we are dealing with a sync to a guild + # the user is in. + if guild_id not in gids: + return + + await self._guild_sync(guild_id) async def process_message(self, payload): """Process a single message coming in from the client.""" diff --git a/litecord/presence.py b/litecord/presence.py new file mode 100644 index 0000000..2ac1532 --- /dev/null +++ b/litecord/presence.py @@ -0,0 +1,28 @@ +from typing import List, Dict, Any + + +class PresenceManager: + """Presence related functions.""" + def __init__(self, storage, state_manager): + self.storage = storage + self.state_manager = state_manager + + async def guild_presences(self, member_ids: List[int], + guild_id: int) -> List[Dict[Any, str]]: + states = self.state_manager.guild_states(member_ids, guild_id) + + presences = [] + + for state in states: + member = await self.storage.get_member_data_one( + guild_id, state.user_id) + + presences.append({ + 'user': member['user'], + 'roles': member['roles'], + 'game': state.presence['game'], + 'guild_id': guild_id, + 'status': state.presence['status'], + }) + + return presences diff --git a/litecord/storage.py b/litecord/storage.py index c608288..3e3f6c0 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -80,7 +80,7 @@ class Storage: WHERE user_id = $1 """, user_id) - return guild_ids + return [row['guild_id'] for row in guild_ids] async def get_member_data_one(self, guild_id, member_id) -> Dict[str, any]: basic = await self.db.fetchrow(""" diff --git a/run.py b/run.py index a999703..9c744fb 100644 --- a/run.py +++ b/run.py @@ -3,6 +3,7 @@ import sys import asyncpg import logbook +import logging import websockets from quart import Quart, g, jsonify from logbook import StreamHandler, Logger @@ -16,6 +17,7 @@ from litecord.errors import LitecordError from litecord.gateway.state_manager import StateManager from litecord.storage import Storage from litecord.dispatcher import EventDispatcher +from litecord.presence import PresenceManager # setup logbook handler = StreamHandler(sys.stdout, level=logbook.INFO) @@ -35,6 +37,9 @@ def make_app(): handler.level = logbook.DEBUG app.logger.level = logbook.DEBUG + # always keep websockets on INFO + logging.getLogger('websockets').setLevel(logbook.INFO) + return app @@ -63,7 +68,10 @@ async def app_after_request(resp): 'X-Fingerprint, ' 'X-Context-Properties, ' 'X-Failed-Requests, ' - 'Content-Type') + 'Content-Type, ' + 'Authorization, ' + 'Origin, ' + 'If-None-Match') resp.headers['Access-Control-Allow-Methods'] = '*' return resp @@ -80,6 +88,7 @@ async def app_before_serving(): app.state_manager = StateManager() app.dispatcher = EventDispatcher(app.state_manager) app.storage = Storage(app.db) + app.presence = PresenceManager(app.storage, app.state_manager) # start the websocket, etc host, port = app.config['WS_HOST'], app.config['WS_PORT'] @@ -89,7 +98,8 @@ async def app_before_serving(): # We wrap the main websocket_handler # so we can pass quart's app object. await websocket_handler((app.db, app.state_manager, app.storage, - app.loop, app.dispatcher), ws, url) + app.loop, app.dispatcher, app.presence), + ws, url) ws_future = websockets.serve(_wrapper, host, port)