diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 6579982..f4c23c7 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -21,7 +21,7 @@ import collections import asyncio import pprint import zlib -from typing import List, Dict, Any +from typing import List, Dict, Any, Iterable from random import randint import websockets @@ -56,41 +56,15 @@ WebsocketProperties = collections.namedtuple( "WebsocketProperties", "v encoding compress zctx zsctx tasks" ) -WebsocketObjects = collections.namedtuple( - "WebsocketObjects", - ( - "db", - "state_manager", - "storage", - "loop", - "dispatcher", - "presence", - "ratelimiter", - "user_storage", - "voice", - ), -) - class GatewayWebsocket: """Main gateway websocket logic.""" def __init__(self, ws, app, **kwargs): - self.ext = WebsocketObjects( - app.db, - app.state_manager, - app.storage, - app.loop, - app.dispatcher, - app.presence, - app.ratelimiter, - app.user_storage, - app.voice, - ) - - self.storage = self.ext.storage - self.user_storage = self.ext.user_storage - self.presence = self.ext.presence + self.app = app + self.storage = app.storage + self.user_storage = app.user_storage + self.presence = app.presence self.ws = ws self.wsp = WebsocketProperties( @@ -225,7 +199,7 @@ class GatewayWebsocket: await self.send({"op": op_code, "d": data, "t": None, "s": None}) def _check_ratelimit(self, key: str, ratelimit_key): - ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}") + ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}") bucket = ratelimit.get_bucket(ratelimit_key) return bucket.update_rate_limit() @@ -245,7 +219,7 @@ class GatewayWebsocket: if task: task.cancel() - self.wsp.tasks["heartbeat"] = self.ext.loop.create_task( + self.wsp.tasks["heartbeat"] = self.app.loop.create_task( task_wrapper("hb wait", self._hb_wait(interval)) ) @@ -330,7 +304,7 @@ class GatewayWebsocket: if r["type"] == RelationshipType.FRIEND.value ] - friend_presences = await self.ext.presence.friend_presences(friend_ids) + friend_presences = await self.app.presence.friend_presences(friend_ids) settings = await self.user_storage.get_user_settings(user_id) return { @@ -377,14 +351,14 @@ class GatewayWebsocket: await self.dispatch("READY", {**base_ready, **user_ready}) # async dispatch of guilds - self.ext.loop.create_task(self._guild_dispatch(guilds)) + self.app.loop.create_task(self._guild_dispatch(guilds)) async def _check_shards(self, shard, user_id): """Check if the given `shard` value in IDENTIFY has good enough values. """ current_shard, shard_count = shard - guilds = await self.ext.db.fetchval( + guilds = await self.app.db.fetchval( """ SELECT COUNT(*) FROM members @@ -460,7 +434,7 @@ class GatewayWebsocket: ("channel", gdm_ids), ] - await self.ext.dispatcher.mass_sub(user_id, channels_to_sub) + await self.app.dispatcher.mass_sub(user_id, channels_to_sub) if not self.state.bot: # subscribe to all friends @@ -468,7 +442,7 @@ class GatewayWebsocket: # when they come online) friend_ids = await self.user_storage.get_friend_ids(user_id) log.info("subscribing to {} friends", len(friend_ids)) - await self.ext.dispatcher.sub_many("friend", user_id, friend_ids) + await self.app.dispatcher.sub_many("friend", user_id, friend_ids) async def update_status(self, status: dict): """Update the status of the current websocket connection.""" @@ -520,7 +494,7 @@ class GatewayWebsocket: f'Updating presence status={status["status"]} for ' f"uid={self.state.user_id}" ) - await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence) + await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence) async def handle_1(self, payload: Dict[str, Any]): """Handle OP 1 Heartbeat packets.""" @@ -558,13 +532,13 @@ class GatewayWebsocket: presence = data.get("presence") try: - user_id = await raw_token_check(token, self.ext.db) + user_id = await raw_token_check(token, self.app.db) except (Unauthorized, Forbidden): raise WebsocketClose(4004, "Authentication failed") await self._connect_ratelimit(user_id) - bot = await self.ext.db.fetchval( + bot = await self.app.db.fetchval( """ SELECT bot FROM users WHERE id = $1 @@ -587,7 +561,7 @@ class GatewayWebsocket: ) # link the state to the user - self.ext.state_manager.insert(self.state) + self.app.state_manager.insert(self.state) await self.update_status(presence) await self.subscribe_all(data.get("guild_subscriptions", True)) @@ -631,12 +605,12 @@ class GatewayWebsocket: # if its null and null, disconnect the user from any voice # TODO: maybe just leave from DMs? idk... if channel_id is None and guild_id is None: - return await self.ext.voice.leave_all(self.state.user_id) + return await self.app.voice.leave_all(self.state.user_id) # if guild is not none but channel is, we are leaving # a guild's channel if channel_id is None: - return await self.ext.voice.leave(guild_id, self.state.user_id) + return await self.app.voice.leave(guild_id, self.state.user_id) # fetch an existing state given user and guild OR user and channel chan_type = ChannelType(await self.storage.get_chan_type(channel_id)) @@ -659,10 +633,10 @@ class GatewayWebsocket: # this state id format takes care of that. voice_key = (self.state.user_id, state_id2) - voice_state = await self.ext.voice.get_state(voice_key) + voice_state = await self.app.voice.get_state(voice_key) if voice_state is None: - return await self.ext.voice.create_state(voice_key, data) + return await self.app.voice.create_state(voice_key, data) same_guild = guild_id == voice_state.guild_id same_channel = channel_id == voice_state.channel_id @@ -670,10 +644,10 @@ class GatewayWebsocket: prop = await self._vsu_get_prop(voice_state, data) if same_guild and same_channel: - return await self.ext.voice.update_state(voice_state, prop) + return await self.app.voice.update_state(voice_state, prop) if same_guild and not same_channel: - return await self.ext.voice.move_state(voice_state, channel_id) + return await self.app.voice.move_state(voice_state, channel_id) async def _handle_5(self, payload: Dict[str, Any]): """Handle OP 5 Voice Server Ping. @@ -698,9 +672,9 @@ class GatewayWebsocket: # since the state will be removed from # the manager, it will become unreachable # when trying to resume. - self.ext.state_manager.remove(self.state) + self.app.state_manager.remove(self.state) - async def _resume(self, replay_seqs: iter): + async def _resume(self, replay_seqs: Iterable): presences = [] try: @@ -740,12 +714,12 @@ class GatewayWebsocket: raise DecodeError("Invalid resume payload") try: - user_id = await raw_token_check(token, self.ext.db) + user_id = await raw_token_check(token, self.app.db) except (Unauthorized, Forbidden): raise WebsocketClose(4004, "Invalid token") try: - state = self.ext.state_manager.fetch(user_id, sess_id) + state = self.app.state_manager.fetch(user_id, sess_id) except KeyError: return await self.invalidate_session(False) @@ -948,7 +922,7 @@ class GatewayWebsocket: log.debug("lazy request: members: {}", data.get("members", [])) # make shard query - lazy_guilds = self.ext.dispatcher.backends["lazy_guild"] + lazy_guilds = self.app.dispatcher.backends["lazy_guild"] for chan_id, ranges in data.get("channels", {}).items(): chan_id = int(chan_id) @@ -992,10 +966,10 @@ class GatewayWebsocket: # close anyone trying to login while the # server is shutting down - if self.ext.state_manager.closed: + if self.app.state_manager.closed: raise WebsocketClose(4000, "state manager closed") - if not self.ext.state_manager.accept_new: + if not self.app.state_manager.accept_new: raise WebsocketClose(4000, "state manager closed for new") while True: @@ -1016,7 +990,7 @@ class GatewayWebsocket: task.cancel() if self.state: - self.ext.state_manager.remove(self.state) + self.app.state_manager.remove(self.state) self.state.ws = None self.state = None @@ -1031,14 +1005,14 @@ class GatewayWebsocket: # TODO: account for sharding # this only updates status to offline once # ALL shards have come offline - states = self.ext.state_manager.user_states(user_id) + states = self.app.state_manager.user_states(user_id) with_ws = [s for s in states if s.ws] # there arent any other states with websocket if not with_ws: offline = {"afk": False, "status": "offline", "game": None, "since": 0} - await self.ext.presence.dispatch_pres(user_id, offline) + await self.app.presence.dispatch_pres(user_id, offline) async def run(self): """Wrap :meth:`listen_messages` inside