From 0f600543e8efa90929d09301daab3a173ad72684 Mon Sep 17 00:00:00 2001 From: Luna <508270-luna@users.noreply.gitlab.com> Date: Thu, 30 Jan 2020 01:50:50 +0000 Subject: [PATCH] Refactor GatewayState with proper types Make websocket loop be inside app context --- litecord/common/users.py | 5 +- litecord/gateway/state.py | 36 +++--- litecord/gateway/websocket.py | 72 ++++++------ litecord/presence.py | 205 ++++++++++++++-------------------- litecord/utils.py | 8 ++ 5 files changed, 154 insertions(+), 172 deletions(-) diff --git a/litecord/common/users.py b/litecord/common/users.py index 7edbc33..a3ea268 100644 --- a/litecord/common/users.py +++ b/litecord/common/users.py @@ -24,6 +24,7 @@ from quart import current_app as app from asyncpg import UniqueViolationError from logbook import Logger +from ..presence import BasePresence from ..snowflake import get_snowflake from ..errors import BadRequest from ..auth import hash_data @@ -268,6 +269,4 @@ async def user_disconnect(user_id: int): await state.ws.ws.close(4000) # force everyone to see the user as offline - await app.presence.dispatch_pres( - user_id, {"afk": False, "status": "offline", "game": None, "since": 0} - ) + await app.presence.dispatch_pres(user_id, BasePresence(status="offline")) diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py index 9df23e3..3bac060 100644 --- a/litecord/gateway/state.py +++ b/litecord/gateway/state.py @@ -20,6 +20,9 @@ along with this program. If not, see . import hashlib import os +from typing import Optional +from litecord.presence import BasePresence + def gen_session_id() -> str: """Generate a random session ID.""" @@ -61,34 +64,35 @@ class GatewayState: """ def __init__(self, **kwargs): - self.session_id = kwargs.get("session_id", gen_session_id()) + self.session_id: str = kwargs.get("session_id", gen_session_id()) - #: event sequence number - self.seq = kwargs.get("seq", 0) + #: last seq received by the client + self.seq: int = int(kwargs.get("seq") or 0) - #: last seq sent by us, the backend - self.last_seq = 0 + #: last seq sent by gateway + self.last_seq: int = 0 - #: shard information about the state, - # its id and shard count - self.shard = kwargs.get("shard", [0, 1]) + #: shard information (id and total count) + shard = kwargs.get("shard") or [0, 1] + self.current_shard: int = int(shard[0]) + self.shard_count: int = int(shard[1]) - self.user_id = kwargs.get("user_id") - self.bot = kwargs.get("bot", False) + self.user_id: int = int(kwargs["user_id"]) + self.bot: bool = bool(kwargs.get("bot") or False) #: set by the gateway connection # on OP STATUS_UPDATE - self.presence = {} + self.presence: Optional[BasePresence] = None #: set by the backend once identify happens self.ws = None - #: store (kind of) all payloads sent by us + #: store of all payloads sent by the gateway (for recovery purposes) self.store = PayloadStore() - for key in kwargs: - value = kwargs[key] - self.__dict__[key] = value + self.compress: bool = kwargs.get("compress") or False + + self.large: int = kwargs.get("large") or 50 def __repr__(self): - return f"GatewayState" + return f"GatewayState" diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index b295290..a4f10e8 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -31,8 +31,15 @@ from logbook import Logger from litecord.auth import raw_token_check from litecord.enums import RelationshipType, ChannelType from litecord.schemas import validate, GW_STATUS_UPDATE -from litecord.utils import task_wrapper, yield_chunks, maybe_int +from litecord.utils import ( + task_wrapper, + yield_chunks, + maybe_int, + want_bytes, + want_string, +) from litecord.permissions import get_permissions +from litecord.presence import BasePresence from litecord.gateway.opcodes import OP from litecord.gateway.state import GatewayState @@ -173,13 +180,15 @@ class GatewayWebsocket: payload.get("t"), ) + # TODO encode to bytes only when absolutely needed e.g + # when compressing, because encoding == json means bytes won't work if isinstance(encoded, str): encoded = encoded.encode() if self.wsp.compress == "zlib-stream": - await self._zlib_stream_send(encoded) + await self._zlib_stream_send(want_bytes(encoded)) elif self.wsp.compress == "zstd-stream": - await self._zstd_stream_send(encoded) + await self._zstd_stream_send(want_bytes(encoded)) elif ( self.state and self.state.compress @@ -188,9 +197,13 @@ class GatewayWebsocket: ): # TODO determine better conditions to trigger a compress set # by identify - await self.ws.send(zlib.compress(encoded)) + await self.ws.send(zlib.compress(want_bytes(encoded))) else: - await self.ws.send(encoded) + await self.ws.send( + want_bytes(encoded) + if self.wsp.encoding == "etf" + else want_string(encoded) + ) async def send_op(self, op_code: int, data: Any): """Send a packet but just the OP code information is filled in.""" @@ -343,7 +356,7 @@ class GatewayWebsocket: "guilds": guilds, "session_id": self.state.session_id, "_trace": ["transbian"], - "shard": self.state.shard, + "shard": [self.state.current_shard, self.state.shard_count], } await self.dispatch("READY", {**base_ready, **user_ready}) @@ -442,7 +455,7 @@ class GatewayWebsocket: log.info("subscribing to {} friends", len(friend_ids)) await self.app.dispatcher.sub_many("friend", user_id, friend_ids) - async def update_status(self, status: dict): + async def update_status(self, incoming_status: dict): """Update the status of the current websocket connection.""" if not self.state: return @@ -452,7 +465,7 @@ class GatewayWebsocket: # are just silently dropped. return - default_status = { + status = { "afk": False, # TODO: fetch status from settings "status": "online", @@ -460,8 +473,7 @@ class GatewayWebsocket: # TODO: this "since": 0, } - - status = {**(status or {}), **default_status} + status.update(incoming_status or {}) try: status = validate(status, GW_STATUS_UPDATE) @@ -479,15 +491,10 @@ class GatewayWebsocket: else: game = status["game"] - # construct final status - status = { - "afk": status.get("afk", False), - "status": status.get("status", "online"), - "game": game, - "since": status.get("since", 0), - } + pres_status = status.get("status") or "online" + pres_status = "offline" if pres_status == "invisible" else pres_status + self.state.presence = BasePresence(status=pres_status, game=game) - self.state.presence = status log.info( f'Updating presence status={status["status"]} for ' f"uid={self.state.user_id}" @@ -523,6 +530,8 @@ class GatewayWebsocket: except KeyError: raise DecodeError("Invalid identify parameters") + # TODO proper validation of this payload + compress = data.get("compress", False) large = data.get("large_threshold", 50) @@ -552,12 +561,12 @@ class GatewayWebsocket: bot=bot, compress=compress, large=large, - shard=shard, current_shard=shard[0], shard_count=shard[1], - ws=self, ) + self.state.ws = self + # link the state to the user self.app.state_manager.insert(self.state) @@ -1004,24 +1013,23 @@ class GatewayWebsocket: if not user_id: return - # TODO: account for sharding - # this only updates status to offline once - # ALL shards have come offline + # TODO: account for sharding. this only checks to dispatch an offline + # when all the shards have come fully offline, which is inefficient. + + # TODO why is this inneficient? states = self.app.state_manager.user_states(user_id) - with_ws = [s for s in states if s.ws] - - # there arent any other states with websocket - if not with_ws: - offline = {"afk": False, "status": "offline", "game": None, "since": 0} - - await self.app.presence.dispatch_pres(user_id, offline) + if not any(s.ws for s in states): + await self.app.presence.dispatch_pres( + user_id, BasePresence(status="offline") + ) async def run(self): """Wrap :meth:`listen_messages` inside a try/except block for WebsocketClose handling.""" try: - await self._send_hello() - await self._listen_messages() + async with self.app.app_context(): + await self._send_hello() + await self._listen_messages() except websockets.exceptions.ConnectionClosed as err: log.warning("conn close, state={}, err={}", self.state, err) except WebsocketClose as err: diff --git a/litecord/presence.py b/litecord/presence.py index 55d7832..9911daf 100644 --- a/litecord/presence.py +++ b/litecord/presence.py @@ -17,13 +17,33 @@ along with this program. If not, see . """ -from typing import List, Dict, Any, Iterable +from typing import List, Dict, Any, Iterable, Optional from random import choice +from dataclasses import dataclass +from quart import current_app as app from logbook import Logger log = Logger(__name__) + +@dataclass +class BasePresence: + status: str + game: Optional[dict] = None + + @property + def partial_dict(self) -> dict: + return { + "status": self.status, + "game": self.game, + "since": 0, + "client_status": {}, + "mobile": False, + "activities": [self.game] if self.game else [], + } + + Presence = Dict[str, Any] @@ -37,69 +57,29 @@ def status_cmp(status: str, other_status: str) -> bool: return hierarchy[status] > hierarchy[other_status] -def _best_presence(shards): - """Find the 'best' presence given a list of GatewayState.""" - best = {"status": None, "game": None} +def _merge_state_presences(shards: list) -> BasePresence: + """create a 'best' presence given a list of states.""" + best = BasePresence(status="offline") for state in shards: - presence = state.presence - - status = presence["status"] - - if not presence: + if state.presence is None: continue # shards with a better status # in the hierarchy are treated as best - if status_cmp(status, best["status"]): - best["status"] = status + if status_cmp(state.presence.status, best.status): + best.status = state.presence.status # if we have any game, use it - if presence["game"] is not None: - best["game"] = presence["game"] + if state.presence.game: + best.game = state.presence.game - # best['status'] is None when no - # status was good enough. - return None if not best["status"] else best + return best -def fill_presence(presence: dict, *, game=None) -> dict: - """Fill a given presence object with some specific fields.""" - presence["client_status"] = {} - presence["mobile"] = False - - if "since" not in presence: - presence["since"] = 0 - - # fill game and activities array depending if game - # is there or not - game = game or presence.get("game") - - # casting to bool since a game of {} is still invalid - if game: - presence["game"] = game - presence["activities"] = [game] - else: - presence["game"] = None - presence["activities"] = [] - - return presence - - -async def _pres(storage, user_id: int, status_obj: dict) -> dict: - """Convert a given status into a presence, given the User ID and the - :class:`Storage` instance.""" - ext = { - "user": await storage.get_user(user_id), - "activities": [], - # NOTE: we are purposefully overwriting the fields, as there - # isn't any push for us to actually implement mobile detection, or - # web detection, etc. - "client_status": {}, - "mobile": False, - } - - return fill_presence({**status_obj, **ext}) +async def _pres(user_id: int, presence: BasePresence) -> dict: + """Take a given base presence and convert it to a full friend presence.""" + return {**presence.partial_dict, **{"user": await app.storage.get_user(user_id)}} class PresenceManager: @@ -123,39 +103,30 @@ class PresenceManager: # then fetching its respective member and merging that info with # the state's set presence. states = self.state_manager.guild_states(member_ids, guild_id) - presences = [] for state in states: member = await self.storage.get_member_data_one(guild_id, state.user_id) - - game = state.presence.get("game", None) - - # only use the data we need. presences.append( - fill_presence( - { + { + **(state.presence or BasePresence(status="offline")).partial_dict, + **{ "user": member["user"], "roles": member["roles"], "guild_id": str(guild_id), - # if a state is connected to the guild - # we assume its online. - "status": state.presence.get("status", "online"), }, - game=game, - ) + } ) return presences - async def dispatch_guild_pres(self, guild_id: int, user_id: int, new_state: dict): + async def dispatch_guild_pres( + self, guild_id: int, user_id: int, presence: BasePresence + ): """Dispatch a Presence update to an entire guild.""" - state = dict(new_state) member = await self.storage.get_member_data_one(guild_id, user_id) - game = state["game"] - lazy_guild_store = self.dispatcher.backends["lazy_guild"] lists = lazy_guild_store.get_gml_guild(guild_id) @@ -166,7 +137,11 @@ class PresenceManager: for member_list in lists: session_ids = await member_list.pres_update( int(member["user"]["id"]), - {"roles": member["roles"], "status": state["status"], "game": game}, + { + "roles": member["roles"], + "status": presence.status, + "game": presence.game, + }, ) log.debug("Lazy Dispatch to {}", len(session_ids)) @@ -176,15 +151,14 @@ class PresenceManager: if member_list.channel_id == member_list.guild_id: in_lazy.extend(session_ids) - pres_update_payload = fill_presence( - { + event_payload = { + **presence.partial_dict, + **{ "guild_id": str(guild_id), "user": member["user"], "roles": member["roles"], - "status": state["status"], }, - game=game, - ) + } # given a session id, return if the session id actually connects to # a given user, and if the state has not been dispatched via lazy guild. @@ -202,79 +176,68 @@ class PresenceManager: # everyone not in lazy guild mode # gets a PRESENCE_UPDATE await self.dispatcher.dispatch_filter( - "guild", guild_id, _session_check, "PRESENCE_UPDATE", pres_update_payload + "guild", guild_id, _session_check, "PRESENCE_UPDATE", event_payload ) return in_lazy - async def dispatch_pres(self, user_id: int, state: dict): + async def dispatch_pres(self, user_id: int, presence: BasePresence) -> None: """Dispatch a new presence to all guilds the user is in. Also dispatches the presence to all the users' friends """ - if state["status"] == "invisible": - state["status"] = "offline" - - # TODO: shard-aware + # TODO: shard-aware (needs to only dispatch guilds of the shard) guild_ids = await self.user_storage.get_user_guilds(user_id) - for guild_id in guild_ids: - await self.dispatch_guild_pres(guild_id, user_id, state) + await self.dispatch_guild_pres(guild_id, user_id, presence) # dispatch to all friends that are subscribed to them user = await self.storage.get_user(user_id) - game = state["game"] - await self.dispatcher.dispatch( "friend", user_id, "PRESENCE_UPDATE", - fill_presence({"user": user, "status": state["status"]}, game=game), + {**presence.partial_dict, **{"user": user}}, ) + def fetch_friend_presence(self, friend_id: int) -> BasePresence: + """Fetch a presence for a friend. + + This is a different algorithm than guild presence. + """ + friend_states = self.state_manager.user_states(friend_id) + + if not friend_states: + return BasePresence(status="offline") + + # filter the best shards: + # - all with id 0 (are the first shards in the collection) or + # - all shards with count = 1 (single shards) + good_shards = list( + filter( + lambda state: state.shard[0] == 0 or state.shard[1] == 1, friend_states + ) + ) + + if good_shards: + return _merge_state_presences(good_shards) + + # if there aren't any shards with id 0 + # AND none that are single, just go with a random one. + shard = choice([s for s in friend_states if s.presence]) + assert shard.presence is not None + return shard.presence + async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]: """Fetch presences for a group of users. This assumes the users are friends and so only gets states that are single or have ID 0. """ - storage = self.storage res = [] for friend_id in friend_ids: - friend_states = self.state_manager.user_states(friend_id) - - if not friend_states: - # append offline - res.append( - await _pres( - storage, - friend_id, - {"afk": False, "status": "offline", "game": None, "since": 0}, - ) - ) - - continue - - # filter the best shards: - # - all with id 0 (are the first shards in the collection) or - # - all shards with count = 1 (single shards) - good_shards = list( - filter( - lambda state: state.shard[0] == 0 or state.shard[1] == 1, - friend_states, - ) - ) - - if good_shards: - best_pres = _best_presence(good_shards) - best_pres = await _pres(storage, friend_id, best_pres) - res.append(best_pres) - continue - - # if there aren't any shards with id 0 - # AND none that are single, just go with a random - shard = choice(friend_states) - res.append(await _pres(storage, friend_id, shard.presence)) + presence = self.fetch_friend_presence(friend_id) + res.append(await _pres(friend_id, presence)) return res diff --git a/litecord/utils.py b/litecord/utils.py index 91e7d8a..fe3c2d4 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -290,3 +290,11 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple: def rand_hex(length: int = 8) -> str: """Generate random hex characters.""" return secrets.token_hex(length)[:length] + + +def want_bytes(data: Union[str, bytes]) -> bytes: + return data if isinstance(data, bytes) else data.encode() + + +def want_string(data: Union[str, bytes]) -> str: + return data.decode() if isinstance(data, bytes) else data