Refactor GatewayState with proper types

Make websocket loop be inside app context
This commit is contained in:
Luna 2020-01-30 01:50:50 +00:00
parent 70443ca379
commit 0f600543e8
5 changed files with 154 additions and 172 deletions

View File

@ -24,6 +24,7 @@ from quart import current_app as app
from asyncpg import UniqueViolationError from asyncpg import UniqueViolationError
from logbook import Logger from logbook import Logger
from ..presence import BasePresence
from ..snowflake import get_snowflake from ..snowflake import get_snowflake
from ..errors import BadRequest from ..errors import BadRequest
from ..auth import hash_data from ..auth import hash_data
@ -268,6 +269,4 @@ async def user_disconnect(user_id: int):
await state.ws.ws.close(4000) await state.ws.ws.close(4000)
# force everyone to see the user as offline # force everyone to see the user as offline
await app.presence.dispatch_pres( await app.presence.dispatch_pres(user_id, BasePresence(status="offline"))
user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
)

View File

@ -20,6 +20,9 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import hashlib import hashlib
import os import os
from typing import Optional
from litecord.presence import BasePresence
def gen_session_id() -> str: def gen_session_id() -> str:
"""Generate a random session ID.""" """Generate a random session ID."""
@ -61,34 +64,35 @@ class GatewayState:
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.session_id = kwargs.get("session_id", gen_session_id()) self.session_id: str = kwargs.get("session_id", gen_session_id())
#: event sequence number #: last seq received by the client
self.seq = kwargs.get("seq", 0) self.seq: int = int(kwargs.get("seq") or 0)
#: last seq sent by us, the backend #: last seq sent by gateway
self.last_seq = 0 self.last_seq: int = 0
#: shard information about the state, #: shard information (id and total count)
# its id and shard count shard = kwargs.get("shard") or [0, 1]
self.shard = kwargs.get("shard", [0, 1]) self.current_shard: int = int(shard[0])
self.shard_count: int = int(shard[1])
self.user_id = kwargs.get("user_id") self.user_id: int = int(kwargs["user_id"])
self.bot = kwargs.get("bot", False) self.bot: bool = bool(kwargs.get("bot") or False)
#: set by the gateway connection #: set by the gateway connection
# on OP STATUS_UPDATE # on OP STATUS_UPDATE
self.presence = {} self.presence: Optional[BasePresence] = None
#: set by the backend once identify happens #: set by the backend once identify happens
self.ws = None self.ws = None
#: store (kind of) all payloads sent by us #: store of all payloads sent by the gateway (for recovery purposes)
self.store = PayloadStore() self.store = PayloadStore()
for key in kwargs: self.compress: bool = kwargs.get("compress") or False
value = kwargs[key]
self.__dict__[key] = value self.large: int = kwargs.get("large") or 50
def __repr__(self): def __repr__(self):
return f"GatewayState<seq={self.seq} " f"shard={self.shard} uid={self.user_id}>" return f"GatewayState<seq={self.seq} shard={self.current_shard},{self.shard_count} uid={self.user_id}>"

View File

@ -31,8 +31,15 @@ from logbook import Logger
from litecord.auth import raw_token_check from litecord.auth import raw_token_check
from litecord.enums import RelationshipType, ChannelType from litecord.enums import RelationshipType, ChannelType
from litecord.schemas import validate, GW_STATUS_UPDATE from litecord.schemas import validate, GW_STATUS_UPDATE
from litecord.utils import task_wrapper, yield_chunks, maybe_int from litecord.utils import (
task_wrapper,
yield_chunks,
maybe_int,
want_bytes,
want_string,
)
from litecord.permissions import get_permissions from litecord.permissions import get_permissions
from litecord.presence import BasePresence
from litecord.gateway.opcodes import OP from litecord.gateway.opcodes import OP
from litecord.gateway.state import GatewayState from litecord.gateway.state import GatewayState
@ -173,13 +180,15 @@ class GatewayWebsocket:
payload.get("t"), payload.get("t"),
) )
# TODO encode to bytes only when absolutely needed e.g
# when compressing, because encoding == json means bytes won't work
if isinstance(encoded, str): if isinstance(encoded, str):
encoded = encoded.encode() encoded = encoded.encode()
if self.wsp.compress == "zlib-stream": if self.wsp.compress == "zlib-stream":
await self._zlib_stream_send(encoded) await self._zlib_stream_send(want_bytes(encoded))
elif self.wsp.compress == "zstd-stream": elif self.wsp.compress == "zstd-stream":
await self._zstd_stream_send(encoded) await self._zstd_stream_send(want_bytes(encoded))
elif ( elif (
self.state self.state
and self.state.compress and self.state.compress
@ -188,9 +197,13 @@ class GatewayWebsocket:
): ):
# TODO determine better conditions to trigger a compress set # TODO determine better conditions to trigger a compress set
# by identify # by identify
await self.ws.send(zlib.compress(encoded)) await self.ws.send(zlib.compress(want_bytes(encoded)))
else: else:
await self.ws.send(encoded) await self.ws.send(
want_bytes(encoded)
if self.wsp.encoding == "etf"
else want_string(encoded)
)
async def send_op(self, op_code: int, data: Any): async def send_op(self, op_code: int, data: Any):
"""Send a packet but just the OP code information is filled in.""" """Send a packet but just the OP code information is filled in."""
@ -343,7 +356,7 @@ class GatewayWebsocket:
"guilds": guilds, "guilds": guilds,
"session_id": self.state.session_id, "session_id": self.state.session_id,
"_trace": ["transbian"], "_trace": ["transbian"],
"shard": self.state.shard, "shard": [self.state.current_shard, self.state.shard_count],
} }
await self.dispatch("READY", {**base_ready, **user_ready}) await self.dispatch("READY", {**base_ready, **user_ready})
@ -442,7 +455,7 @@ class GatewayWebsocket:
log.info("subscribing to {} friends", len(friend_ids)) log.info("subscribing to {} friends", len(friend_ids))
await self.app.dispatcher.sub_many("friend", user_id, friend_ids) await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
async def update_status(self, status: dict): async def update_status(self, incoming_status: dict):
"""Update the status of the current websocket connection.""" """Update the status of the current websocket connection."""
if not self.state: if not self.state:
return return
@ -452,7 +465,7 @@ class GatewayWebsocket:
# are just silently dropped. # are just silently dropped.
return return
default_status = { status = {
"afk": False, "afk": False,
# TODO: fetch status from settings # TODO: fetch status from settings
"status": "online", "status": "online",
@ -460,8 +473,7 @@ class GatewayWebsocket:
# TODO: this # TODO: this
"since": 0, "since": 0,
} }
status.update(incoming_status or {})
status = {**(status or {}), **default_status}
try: try:
status = validate(status, GW_STATUS_UPDATE) status = validate(status, GW_STATUS_UPDATE)
@ -479,15 +491,10 @@ class GatewayWebsocket:
else: else:
game = status["game"] game = status["game"]
# construct final status pres_status = status.get("status") or "online"
status = { pres_status = "offline" if pres_status == "invisible" else pres_status
"afk": status.get("afk", False), self.state.presence = BasePresence(status=pres_status, game=game)
"status": status.get("status", "online"),
"game": game,
"since": status.get("since", 0),
}
self.state.presence = status
log.info( log.info(
f'Updating presence status={status["status"]} for ' f'Updating presence status={status["status"]} for '
f"uid={self.state.user_id}" f"uid={self.state.user_id}"
@ -523,6 +530,8 @@ class GatewayWebsocket:
except KeyError: except KeyError:
raise DecodeError("Invalid identify parameters") raise DecodeError("Invalid identify parameters")
# TODO proper validation of this payload
compress = data.get("compress", False) compress = data.get("compress", False)
large = data.get("large_threshold", 50) large = data.get("large_threshold", 50)
@ -552,12 +561,12 @@ class GatewayWebsocket:
bot=bot, bot=bot,
compress=compress, compress=compress,
large=large, large=large,
shard=shard,
current_shard=shard[0], current_shard=shard[0],
shard_count=shard[1], shard_count=shard[1],
ws=self,
) )
self.state.ws = self
# link the state to the user # link the state to the user
self.app.state_manager.insert(self.state) self.app.state_manager.insert(self.state)
@ -1004,24 +1013,23 @@ class GatewayWebsocket:
if not user_id: if not user_id:
return return
# TODO: account for sharding # TODO: account for sharding. this only checks to dispatch an offline
# this only updates status to offline once # when all the shards have come fully offline, which is inefficient.
# ALL shards have come offline
# TODO why is this inneficient?
states = self.app.state_manager.user_states(user_id) states = self.app.state_manager.user_states(user_id)
with_ws = [s for s in states if s.ws] if not any(s.ws for s in states):
await self.app.presence.dispatch_pres(
# there arent any other states with websocket user_id, BasePresence(status="offline")
if not with_ws: )
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
await self.app.presence.dispatch_pres(user_id, offline)
async def run(self): async def run(self):
"""Wrap :meth:`listen_messages` inside """Wrap :meth:`listen_messages` inside
a try/except block for WebsocketClose handling.""" a try/except block for WebsocketClose handling."""
try: try:
await self._send_hello() async with self.app.app_context():
await self._listen_messages() await self._send_hello()
await self._listen_messages()
except websockets.exceptions.ConnectionClosed as err: except websockets.exceptions.ConnectionClosed as err:
log.warning("conn close, state={}, err={}", self.state, err) log.warning("conn close, state={}, err={}", self.state, err)
except WebsocketClose as err: except WebsocketClose as err:

View File

@ -17,13 +17,33 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List, Dict, Any, Iterable from typing import List, Dict, Any, Iterable, Optional
from random import choice from random import choice
from dataclasses import dataclass
from quart import current_app as app
from logbook import Logger from logbook import Logger
log = Logger(__name__) log = Logger(__name__)
@dataclass
class BasePresence:
status: str
game: Optional[dict] = None
@property
def partial_dict(self) -> dict:
return {
"status": self.status,
"game": self.game,
"since": 0,
"client_status": {},
"mobile": False,
"activities": [self.game] if self.game else [],
}
Presence = Dict[str, Any] Presence = Dict[str, Any]
@ -37,69 +57,29 @@ def status_cmp(status: str, other_status: str) -> bool:
return hierarchy[status] > hierarchy[other_status] return hierarchy[status] > hierarchy[other_status]
def _best_presence(shards): def _merge_state_presences(shards: list) -> BasePresence:
"""Find the 'best' presence given a list of GatewayState.""" """create a 'best' presence given a list of states."""
best = {"status": None, "game": None} best = BasePresence(status="offline")
for state in shards: for state in shards:
presence = state.presence if state.presence is None:
status = presence["status"]
if not presence:
continue continue
# shards with a better status # shards with a better status
# in the hierarchy are treated as best # in the hierarchy are treated as best
if status_cmp(status, best["status"]): if status_cmp(state.presence.status, best.status):
best["status"] = status best.status = state.presence.status
# if we have any game, use it # if we have any game, use it
if presence["game"] is not None: if state.presence.game:
best["game"] = presence["game"] best.game = state.presence.game
# best['status'] is None when no return best
# status was good enough.
return None if not best["status"] else best
def fill_presence(presence: dict, *, game=None) -> dict: async def _pres(user_id: int, presence: BasePresence) -> dict:
"""Fill a given presence object with some specific fields.""" """Take a given base presence and convert it to a full friend presence."""
presence["client_status"] = {} return {**presence.partial_dict, **{"user": await app.storage.get_user(user_id)}}
presence["mobile"] = False
if "since" not in presence:
presence["since"] = 0
# fill game and activities array depending if game
# is there or not
game = game or presence.get("game")
# casting to bool since a game of {} is still invalid
if game:
presence["game"] = game
presence["activities"] = [game]
else:
presence["game"] = None
presence["activities"] = []
return presence
async def _pres(storage, user_id: int, status_obj: dict) -> dict:
"""Convert a given status into a presence, given the User ID and the
:class:`Storage` instance."""
ext = {
"user": await storage.get_user(user_id),
"activities": [],
# NOTE: we are purposefully overwriting the fields, as there
# isn't any push for us to actually implement mobile detection, or
# web detection, etc.
"client_status": {},
"mobile": False,
}
return fill_presence({**status_obj, **ext})
class PresenceManager: class PresenceManager:
@ -123,39 +103,30 @@ class PresenceManager:
# then fetching its respective member and merging that info with # then fetching its respective member and merging that info with
# the state's set presence. # the state's set presence.
states = self.state_manager.guild_states(member_ids, guild_id) states = self.state_manager.guild_states(member_ids, guild_id)
presences = [] presences = []
for state in states: for state in states:
member = await self.storage.get_member_data_one(guild_id, state.user_id) member = await self.storage.get_member_data_one(guild_id, state.user_id)
game = state.presence.get("game", None)
# only use the data we need.
presences.append( presences.append(
fill_presence( {
{ **(state.presence or BasePresence(status="offline")).partial_dict,
**{
"user": member["user"], "user": member["user"],
"roles": member["roles"], "roles": member["roles"],
"guild_id": str(guild_id), "guild_id": str(guild_id),
# if a state is connected to the guild
# we assume its online.
"status": state.presence.get("status", "online"),
}, },
game=game, }
)
) )
return presences return presences
async def dispatch_guild_pres(self, guild_id: int, user_id: int, new_state: dict): async def dispatch_guild_pres(
self, guild_id: int, user_id: int, presence: BasePresence
):
"""Dispatch a Presence update to an entire guild.""" """Dispatch a Presence update to an entire guild."""
state = dict(new_state)
member = await self.storage.get_member_data_one(guild_id, user_id) member = await self.storage.get_member_data_one(guild_id, user_id)
game = state["game"]
lazy_guild_store = self.dispatcher.backends["lazy_guild"] lazy_guild_store = self.dispatcher.backends["lazy_guild"]
lists = lazy_guild_store.get_gml_guild(guild_id) lists = lazy_guild_store.get_gml_guild(guild_id)
@ -166,7 +137,11 @@ class PresenceManager:
for member_list in lists: for member_list in lists:
session_ids = await member_list.pres_update( session_ids = await member_list.pres_update(
int(member["user"]["id"]), int(member["user"]["id"]),
{"roles": member["roles"], "status": state["status"], "game": game}, {
"roles": member["roles"],
"status": presence.status,
"game": presence.game,
},
) )
log.debug("Lazy Dispatch to {}", len(session_ids)) log.debug("Lazy Dispatch to {}", len(session_ids))
@ -176,15 +151,14 @@ class PresenceManager:
if member_list.channel_id == member_list.guild_id: if member_list.channel_id == member_list.guild_id:
in_lazy.extend(session_ids) in_lazy.extend(session_ids)
pres_update_payload = fill_presence( event_payload = {
{ **presence.partial_dict,
**{
"guild_id": str(guild_id), "guild_id": str(guild_id),
"user": member["user"], "user": member["user"],
"roles": member["roles"], "roles": member["roles"],
"status": state["status"],
}, },
game=game, }
)
# given a session id, return if the session id actually connects to # given a session id, return if the session id actually connects to
# a given user, and if the state has not been dispatched via lazy guild. # a given user, and if the state has not been dispatched via lazy guild.
@ -202,79 +176,68 @@ class PresenceManager:
# everyone not in lazy guild mode # everyone not in lazy guild mode
# gets a PRESENCE_UPDATE # gets a PRESENCE_UPDATE
await self.dispatcher.dispatch_filter( await self.dispatcher.dispatch_filter(
"guild", guild_id, _session_check, "PRESENCE_UPDATE", pres_update_payload "guild", guild_id, _session_check, "PRESENCE_UPDATE", event_payload
) )
return in_lazy return in_lazy
async def dispatch_pres(self, user_id: int, state: dict): async def dispatch_pres(self, user_id: int, presence: BasePresence) -> None:
"""Dispatch a new presence to all guilds the user is in. """Dispatch a new presence to all guilds the user is in.
Also dispatches the presence to all the users' friends Also dispatches the presence to all the users' friends
""" """
if state["status"] == "invisible": # TODO: shard-aware (needs to only dispatch guilds of the shard)
state["status"] = "offline"
# TODO: shard-aware
guild_ids = await self.user_storage.get_user_guilds(user_id) guild_ids = await self.user_storage.get_user_guilds(user_id)
for guild_id in guild_ids: for guild_id in guild_ids:
await self.dispatch_guild_pres(guild_id, user_id, state) await self.dispatch_guild_pres(guild_id, user_id, presence)
# dispatch to all friends that are subscribed to them # dispatch to all friends that are subscribed to them
user = await self.storage.get_user(user_id) user = await self.storage.get_user(user_id)
game = state["game"]
await self.dispatcher.dispatch( await self.dispatcher.dispatch(
"friend", "friend",
user_id, user_id,
"PRESENCE_UPDATE", "PRESENCE_UPDATE",
fill_presence({"user": user, "status": state["status"]}, game=game), {**presence.partial_dict, **{"user": user}},
) )
def fetch_friend_presence(self, friend_id: int) -> BasePresence:
"""Fetch a presence for a friend.
This is a different algorithm than guild presence.
"""
friend_states = self.state_manager.user_states(friend_id)
if not friend_states:
return BasePresence(status="offline")
# filter the best shards:
# - all with id 0 (are the first shards in the collection) or
# - all shards with count = 1 (single shards)
good_shards = list(
filter(
lambda state: state.shard[0] == 0 or state.shard[1] == 1, friend_states
)
)
if good_shards:
return _merge_state_presences(good_shards)
# if there aren't any shards with id 0
# AND none that are single, just go with a random one.
shard = choice([s for s in friend_states if s.presence])
assert shard.presence is not None
return shard.presence
async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]: async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]:
"""Fetch presences for a group of users. """Fetch presences for a group of users.
This assumes the users are friends and so This assumes the users are friends and so
only gets states that are single or have ID 0. only gets states that are single or have ID 0.
""" """
storage = self.storage
res = [] res = []
for friend_id in friend_ids: for friend_id in friend_ids:
friend_states = self.state_manager.user_states(friend_id) presence = self.fetch_friend_presence(friend_id)
res.append(await _pres(friend_id, presence))
if not friend_states:
# append offline
res.append(
await _pres(
storage,
friend_id,
{"afk": False, "status": "offline", "game": None, "since": 0},
)
)
continue
# filter the best shards:
# - all with id 0 (are the first shards in the collection) or
# - all shards with count = 1 (single shards)
good_shards = list(
filter(
lambda state: state.shard[0] == 0 or state.shard[1] == 1,
friend_states,
)
)
if good_shards:
best_pres = _best_presence(good_shards)
best_pres = await _pres(storage, friend_id, best_pres)
res.append(best_pres)
continue
# if there aren't any shards with id 0
# AND none that are single, just go with a random
shard = choice(friend_states)
res.append(await _pres(storage, friend_id, shard.presence))
return res return res

View File

@ -290,3 +290,11 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple:
def rand_hex(length: int = 8) -> str: def rand_hex(length: int = 8) -> str:
"""Generate random hex characters.""" """Generate random hex characters."""
return secrets.token_hex(length)[:length] return secrets.token_hex(length)[:length]
def want_bytes(data: Union[str, bytes]) -> bytes:
return data if isinstance(data, bytes) else data.encode()
def want_string(data: Union[str, bytes]) -> str:
return data.decode() if isinstance(data, bytes) else data