From 221c2f5c6cd014eb072463e5a61a83cbd9386b8c Mon Sep 17 00:00:00 2001 From: Luna Date: Sun, 29 Aug 2021 21:25:15 -0300 Subject: [PATCH] refactor websocket properties name --- litecord/gateway/gateway.py | 5 +++- litecord/gateway/websocket.py | 50 ++++++++++++++++++----------------- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/litecord/gateway/gateway.py b/litecord/gateway/gateway.py index fe3786f..c7f0cf1 100644 --- a/litecord/gateway/gateway.py +++ b/litecord/gateway/gateway.py @@ -54,7 +54,10 @@ async def websocket_handler(app, ws, url): async with app.app_context(): gws = GatewayWebsocket( - ws, v=int(gw_version), encoding=gw_encoding, compress=gw_compress + ws, + version=int(gw_version), + encoding=gw_encoding or "json", + compress=gw_compress, ) # this can be run with a single await since this whole coroutine diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 30199f4..e7ef08b 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -70,11 +70,11 @@ from litecord.storage import int_ log = Logger(__name__) WebsocketProperties = collections.namedtuple( - "WebsocketProperties", "v encoding compress zctx zsctx tasks" + "WebsocketProperties", "version encoding compress zctx zsctx tasks" ) -def _complete_users_list(user_id: str, base_ready, user_ready, wsp) -> dict: +def _complete_users_list(user_id: str, base_ready, user_ready, ws_properties) -> dict: """Use the data we were already preparing to send in READY to construct the users array, saving on I/O cost.""" @@ -99,7 +99,7 @@ def _complete_users_list(user_id: str, base_ready, user_ready, wsp) -> dict: ready["users"] = [value for value in users_to_send.values()] # relationship object structure changed in v9 - if wsp.v == 9: + if ws_properties.version == 9: ready["relationships"] = [] for relationship in user_ready["relationships"]: ready["relationships"].append( @@ -211,23 +211,23 @@ def calculate_intents(data) -> Intents: class GatewayWebsocket: """Main gateway websocket logic.""" - def __init__(self, ws, **kwargs): + def __init__(self, ws, *, version, encoding, compress): self.app = app self.storage = app.storage self.user_storage = app.user_storage self.presence = app.presence self.ws = ws - self.wsp = WebsocketProperties( - kwargs.get("v"), - kwargs.get("encoding", "json"), - kwargs.get("compress", None), + self.ws_properties = WebsocketProperties( + version, + encoding, + compress, zlib.compressobj(), zstd.ZstdCompressor(), {}, ) - log.debug("websocket properties: {!r}", self.wsp) + log.debug("websocket properties: {!r}", self.ws_properties) self.state = None self._hb_counter = 0 @@ -235,7 +235,7 @@ class GatewayWebsocket: self._set_encoders() def _set_encoders(self): - encoding = self.wsp.encoding + encoding = self.ws_properties.encoding encodings = { "json": (encode_json, decode_json), @@ -264,8 +264,8 @@ class GatewayWebsocket: websocket messages.""" # compress and flush (for the rest of compressed data + ZLIB_SUFFIX) - data1 = self.wsp.zctx.compress(encoded) - data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH) + data1 = self.ws_properties.zctx.compress(encoded) + data2 = self.ws_properties.zctx.flush(zlib.Z_FULL_FLUSH) log.debug( "zlib-stream: length {} -> compressed ({} + {})", @@ -301,7 +301,9 @@ class GatewayWebsocket: await self._chunked_send(data2, 1024) async def _zstd_stream_send(self, encoded): - compressor = self.wsp.zsctx.stream_writer(WebsocketFileHandler(self.ws)) + compressor = self.ws_properties.zsctx.stream_writer( + WebsocketFileHandler(self.ws) + ) compressor.write(encoded) compressor.flush(zstd.FLUSH_FRAME) @@ -330,15 +332,15 @@ class GatewayWebsocket: if isinstance(encoded, str): encoded = encoded.encode() - if self.wsp.compress == "zlib-stream": + if self.ws_properties.compress == "zlib-stream": await self._zlib_stream_send(want_bytes(encoded)) - elif self.wsp.compress == "zstd-stream": + elif self.ws_properties.compress == "zstd-stream": await self._zstd_stream_send(want_bytes(encoded)) elif ( self.state and self.state.compress and len(encoded) > 8192 - and self.wsp.encoding != "etf" + and self.ws_properties.encoding != "etf" ): # TODO determine better conditions to trigger a compress set # by identify @@ -346,7 +348,7 @@ class GatewayWebsocket: else: await self.ws.send( want_bytes(encoded) - if self.wsp.encoding == "etf" + if self.ws_properties.encoding == "etf" else want_string(encoded) ) @@ -371,11 +373,11 @@ class GatewayWebsocket: def _hb_start(self, interval: int): # always refresh the heartbeat task # when possible - task = self.wsp.tasks.get("heartbeat") + task = self.ws_properties.tasks.get("heartbeat") if task: task.cancel() - self.wsp.tasks["heartbeat"] = app.sched.spawn( + self.ws_properties.tasks["heartbeat"] = app.sched.spawn( task_wrapper("hb wait", self._hb_wait(interval)) ) @@ -471,7 +473,7 @@ class GatewayWebsocket: friend_presences = await self.app.presence.friend_presences(friend_ids) settings = settings or await self.user_storage.get_user_settings(user_id) - if self.wsp.v < 7: # v6 and below + if self.ws_properties.version < 7: # v6 and below user_guild_settings = await self.user_storage.get_guild_settings(user_id) else: user_guild_settings = { @@ -513,7 +515,7 @@ class GatewayWebsocket: ) + await self.user_storage.get_gdms(user_id) base_ready = { - "v": self.wsp.v, + "v": self.ws_properties.version, "user": user, "private_channels": private_channels, "guilds": guilds, @@ -528,7 +530,7 @@ class GatewayWebsocket: # pass users_to_send to ready_supplemental so that its easier to # cross-reference things full_ready_data, users_to_send = _complete_users_list( - user["id"], base_ready, user_ready, self.wsp + user["id"], base_ready, user_ready, self.ws_properties ) ready_supplemental = await _compute_supplemental( self.app, base_ready, user_ready, users_to_send @@ -541,7 +543,7 @@ class GatewayWebsocket: guild["members"] = [] await self.dispatch("READY", full_ready_data) - if self.wsp.v > 6: + if self.ws_properties.version > 6: await self.dispatch("READY_SUPPLEMENTAL", ready_supplemental) app.sched.spawn(self._guild_dispatch(guilds)) @@ -1278,7 +1280,7 @@ class GatewayWebsocket: def _cleanup(self): """Cleanup any leftover tasks, and remove the connection from the state manager.""" - for task in self.wsp.tasks.values(): + for task in self.ws_properties.tasks.values(): task.cancel() if self.state: