diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py index dc13c57..974e2c0 100644 --- a/litecord/gateway/state.py +++ b/litecord/gateway/state.py @@ -8,7 +8,11 @@ def gen_session_id() -> str: class PayloadStore: - """Store manager for payloads.""" + """Store manager for payloads. + + This will only store a maximum of MAX_STORE_SIZE, + dropping the older payloads when adding new ones. + """ MAX_STORE_SIZE = 250 def __init__(self): @@ -30,14 +34,6 @@ class PayloadStore: self.store[opcode] = payload -class Presence: - def __init__(self, raw: dict): - self.afk = raw.get('afk', False) - self.status = raw.get('status', 'online') - self.game = raw.get('game', None) - self.since = raw.get('since', 0) - - class GatewayState: """Main websocket state. @@ -46,13 +42,28 @@ class GatewayState: def __init__(self, **kwargs): self.session_id = kwargs.get('session_id', gen_session_id()) + + #: event sequence number self.seq = kwargs.get('seq', 0) + + #: last seq sent by us, the backend self.last_seq = 0 + + #: shard information about the state, + # its id and shard count self.shard = kwargs.get('shard', [0, 1]) + self.user_id = kwargs.get('user_id') self.bot = kwargs.get('bot', False) + + #: set by the gateway connection + # on OP STATUS_UPDATE self.presence = {} + + #: set by the backend once identify happens self.ws = None + + #: store (kind of) all payloads sent by us self.store = PayloadStore() for key in kwargs: diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index bf911a0..56d00bc 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -13,6 +13,13 @@ class StateManager: """Manager for gateway state information.""" def __init__(self): + # { + # user_id: { + # session_id: GatewayState, + # session_id_2: GatewayState, ... + # }, + # user_id_2: {}, ... + # } self.states = defaultdict(dict) def insert(self, state: GatewayState): @@ -23,7 +30,14 @@ class StateManager: user_states[state.session_id] = state def fetch(self, user_id: int, session_id: str) -> GatewayState: - """Fetch a state object from the registry.""" + """Fetch a state object from the manager. + + Raises + ------ + KeyError + When the user_id or session_id + aren't found in the store. + """ return self.states[user_id][session_id] def remove(self, state): diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 39809ae..289c0e6 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -256,15 +256,14 @@ class GatewayWebsocket: # async dispatch of guilds self.ext.loop.create_task(self.guild_dispatch(guilds)) - async def _check_shards(self): - shard = self.state.shard + async def _check_shards(self, shard, user_id): current_shard, shard_count = shard guilds = await self.ext.db.fetchval(""" SELECT COUNT(*) FROM members WHERE user_id = $1 - """, self.state.user_id) + """, user_id) recommended = max(int(guilds / 1200), 1) @@ -390,6 +389,9 @@ class GatewayWebsocket: WHERE id = $1 """, user_id) + await self._check_shards(shard, user_id) + + # only create a state after checking everything self.state = GatewayState( user_id=user_id, bot=bot, @@ -401,9 +403,9 @@ class GatewayWebsocket: ws=self ) - await self._check_shards() - + # link the state to the user self.ext.state_manager.insert(self.state) + await self.update_status(presence) await self.subscribe_all() await self.dispatch_ready() @@ -419,6 +421,7 @@ class GatewayWebsocket: async def handle_4(self, payload: Dict[str, Any]): """Handle OP 4 Voice Status Update.""" data = payload['d'] + # for now, ignore log.debug('got VSU cid={} gid={} deaf={} mute={} video={}', data.get('channel_id'), data.get('guild_id'), @@ -426,9 +429,6 @@ class GatewayWebsocket: data.get('self_mute'), data.get('self_video')) - # for now, do nothing - pass - async def _handle_5(self, payload: Dict[str, Any]): """Handle OP 5 Voice Server Ping. @@ -452,6 +452,9 @@ class GatewayWebsocket: }) if not resumable and self.state: + # since the state will be removed from + # the manager, it will become unreachable + # when trying to resume. self.ext.state_manager.remove(self.state) async def _resume(self, replay_seqs: iter): @@ -476,12 +479,14 @@ class GatewayWebsocket: await self.send(payload) except Exception: log.exception('error while resuming') - await self.invalidate_session() + await self.invalidate_session(False) return if presences: await self.dispatch('PRESENCE_REPLACE', presences) + await self.dispatch('RESUMED', {}) + async def handle_6(self, payload: Dict[str, Any]): """Handle OP 6 Resume.""" data = payload['d'] @@ -515,7 +520,6 @@ class GatewayWebsocket: state.ws = self await self._resume(range(seq, state.seq)) - await self.dispatch('RESUMED', {}) async def _req_guild_members(self, guild_id: str, user_ids: List[int], query: str, limit: int):