mirror of https://gitlab.com/litecord/litecord.git
Refactor GatewayState with proper types
Make websocket loop be inside app context
This commit is contained in:
parent
70443ca379
commit
0f600543e8
|
|
@ -24,6 +24,7 @@ from quart import current_app as app
|
|||
from asyncpg import UniqueViolationError
|
||||
from logbook import Logger
|
||||
|
||||
from ..presence import BasePresence
|
||||
from ..snowflake import get_snowflake
|
||||
from ..errors import BadRequest
|
||||
from ..auth import hash_data
|
||||
|
|
@ -268,6 +269,4 @@ async def user_disconnect(user_id: int):
|
|||
await state.ws.ws.close(4000)
|
||||
|
||||
# force everyone to see the user as offline
|
||||
await app.presence.dispatch_pres(
|
||||
user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
|
||||
)
|
||||
await app.presence.dispatch_pres(user_id, BasePresence(status="offline"))
|
||||
|
|
|
|||
|
|
@ -20,6 +20,9 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
import hashlib
|
||||
import os
|
||||
|
||||
from typing import Optional
|
||||
from litecord.presence import BasePresence
|
||||
|
||||
|
||||
def gen_session_id() -> str:
|
||||
"""Generate a random session ID."""
|
||||
|
|
@ -61,34 +64,35 @@ class GatewayState:
|
|||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.session_id = kwargs.get("session_id", gen_session_id())
|
||||
self.session_id: str = kwargs.get("session_id", gen_session_id())
|
||||
|
||||
#: event sequence number
|
||||
self.seq = kwargs.get("seq", 0)
|
||||
#: last seq received by the client
|
||||
self.seq: int = int(kwargs.get("seq") or 0)
|
||||
|
||||
#: last seq sent by us, the backend
|
||||
self.last_seq = 0
|
||||
#: last seq sent by gateway
|
||||
self.last_seq: int = 0
|
||||
|
||||
#: shard information about the state,
|
||||
# its id and shard count
|
||||
self.shard = kwargs.get("shard", [0, 1])
|
||||
#: shard information (id and total count)
|
||||
shard = kwargs.get("shard") or [0, 1]
|
||||
self.current_shard: int = int(shard[0])
|
||||
self.shard_count: int = int(shard[1])
|
||||
|
||||
self.user_id = kwargs.get("user_id")
|
||||
self.bot = kwargs.get("bot", False)
|
||||
self.user_id: int = int(kwargs["user_id"])
|
||||
self.bot: bool = bool(kwargs.get("bot") or False)
|
||||
|
||||
#: set by the gateway connection
|
||||
# on OP STATUS_UPDATE
|
||||
self.presence = {}
|
||||
self.presence: Optional[BasePresence] = None
|
||||
|
||||
#: set by the backend once identify happens
|
||||
self.ws = None
|
||||
|
||||
#: store (kind of) all payloads sent by us
|
||||
#: store of all payloads sent by the gateway (for recovery purposes)
|
||||
self.store = PayloadStore()
|
||||
|
||||
for key in kwargs:
|
||||
value = kwargs[key]
|
||||
self.__dict__[key] = value
|
||||
self.compress: bool = kwargs.get("compress") or False
|
||||
|
||||
self.large: int = kwargs.get("large") or 50
|
||||
|
||||
def __repr__(self):
|
||||
return f"GatewayState<seq={self.seq} " f"shard={self.shard} uid={self.user_id}>"
|
||||
return f"GatewayState<seq={self.seq} shard={self.current_shard},{self.shard_count} uid={self.user_id}>"
|
||||
|
|
|
|||
|
|
@ -31,8 +31,15 @@ from logbook import Logger
|
|||
from litecord.auth import raw_token_check
|
||||
from litecord.enums import RelationshipType, ChannelType
|
||||
from litecord.schemas import validate, GW_STATUS_UPDATE
|
||||
from litecord.utils import task_wrapper, yield_chunks, maybe_int
|
||||
from litecord.utils import (
|
||||
task_wrapper,
|
||||
yield_chunks,
|
||||
maybe_int,
|
||||
want_bytes,
|
||||
want_string,
|
||||
)
|
||||
from litecord.permissions import get_permissions
|
||||
from litecord.presence import BasePresence
|
||||
|
||||
from litecord.gateway.opcodes import OP
|
||||
from litecord.gateway.state import GatewayState
|
||||
|
|
@ -173,13 +180,15 @@ class GatewayWebsocket:
|
|||
payload.get("t"),
|
||||
)
|
||||
|
||||
# TODO encode to bytes only when absolutely needed e.g
|
||||
# when compressing, because encoding == json means bytes won't work
|
||||
if isinstance(encoded, str):
|
||||
encoded = encoded.encode()
|
||||
|
||||
if self.wsp.compress == "zlib-stream":
|
||||
await self._zlib_stream_send(encoded)
|
||||
await self._zlib_stream_send(want_bytes(encoded))
|
||||
elif self.wsp.compress == "zstd-stream":
|
||||
await self._zstd_stream_send(encoded)
|
||||
await self._zstd_stream_send(want_bytes(encoded))
|
||||
elif (
|
||||
self.state
|
||||
and self.state.compress
|
||||
|
|
@ -188,9 +197,13 @@ class GatewayWebsocket:
|
|||
):
|
||||
# TODO determine better conditions to trigger a compress set
|
||||
# by identify
|
||||
await self.ws.send(zlib.compress(encoded))
|
||||
await self.ws.send(zlib.compress(want_bytes(encoded)))
|
||||
else:
|
||||
await self.ws.send(encoded)
|
||||
await self.ws.send(
|
||||
want_bytes(encoded)
|
||||
if self.wsp.encoding == "etf"
|
||||
else want_string(encoded)
|
||||
)
|
||||
|
||||
async def send_op(self, op_code: int, data: Any):
|
||||
"""Send a packet but just the OP code information is filled in."""
|
||||
|
|
@ -343,7 +356,7 @@ class GatewayWebsocket:
|
|||
"guilds": guilds,
|
||||
"session_id": self.state.session_id,
|
||||
"_trace": ["transbian"],
|
||||
"shard": self.state.shard,
|
||||
"shard": [self.state.current_shard, self.state.shard_count],
|
||||
}
|
||||
|
||||
await self.dispatch("READY", {**base_ready, **user_ready})
|
||||
|
|
@ -442,7 +455,7 @@ class GatewayWebsocket:
|
|||
log.info("subscribing to {} friends", len(friend_ids))
|
||||
await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
|
||||
|
||||
async def update_status(self, status: dict):
|
||||
async def update_status(self, incoming_status: dict):
|
||||
"""Update the status of the current websocket connection."""
|
||||
if not self.state:
|
||||
return
|
||||
|
|
@ -452,7 +465,7 @@ class GatewayWebsocket:
|
|||
# are just silently dropped.
|
||||
return
|
||||
|
||||
default_status = {
|
||||
status = {
|
||||
"afk": False,
|
||||
# TODO: fetch status from settings
|
||||
"status": "online",
|
||||
|
|
@ -460,8 +473,7 @@ class GatewayWebsocket:
|
|||
# TODO: this
|
||||
"since": 0,
|
||||
}
|
||||
|
||||
status = {**(status or {}), **default_status}
|
||||
status.update(incoming_status or {})
|
||||
|
||||
try:
|
||||
status = validate(status, GW_STATUS_UPDATE)
|
||||
|
|
@ -479,15 +491,10 @@ class GatewayWebsocket:
|
|||
else:
|
||||
game = status["game"]
|
||||
|
||||
# construct final status
|
||||
status = {
|
||||
"afk": status.get("afk", False),
|
||||
"status": status.get("status", "online"),
|
||||
"game": game,
|
||||
"since": status.get("since", 0),
|
||||
}
|
||||
pres_status = status.get("status") or "online"
|
||||
pres_status = "offline" if pres_status == "invisible" else pres_status
|
||||
self.state.presence = BasePresence(status=pres_status, game=game)
|
||||
|
||||
self.state.presence = status
|
||||
log.info(
|
||||
f'Updating presence status={status["status"]} for '
|
||||
f"uid={self.state.user_id}"
|
||||
|
|
@ -523,6 +530,8 @@ class GatewayWebsocket:
|
|||
except KeyError:
|
||||
raise DecodeError("Invalid identify parameters")
|
||||
|
||||
# TODO proper validation of this payload
|
||||
|
||||
compress = data.get("compress", False)
|
||||
large = data.get("large_threshold", 50)
|
||||
|
||||
|
|
@ -552,12 +561,12 @@ class GatewayWebsocket:
|
|||
bot=bot,
|
||||
compress=compress,
|
||||
large=large,
|
||||
shard=shard,
|
||||
current_shard=shard[0],
|
||||
shard_count=shard[1],
|
||||
ws=self,
|
||||
)
|
||||
|
||||
self.state.ws = self
|
||||
|
||||
# link the state to the user
|
||||
self.app.state_manager.insert(self.state)
|
||||
|
||||
|
|
@ -1004,22 +1013,21 @@ class GatewayWebsocket:
|
|||
if not user_id:
|
||||
return
|
||||
|
||||
# TODO: account for sharding
|
||||
# this only updates status to offline once
|
||||
# ALL shards have come offline
|
||||
# TODO: account for sharding. this only checks to dispatch an offline
|
||||
# when all the shards have come fully offline, which is inefficient.
|
||||
|
||||
# TODO why is this inneficient?
|
||||
states = self.app.state_manager.user_states(user_id)
|
||||
with_ws = [s for s in states if s.ws]
|
||||
|
||||
# there arent any other states with websocket
|
||||
if not with_ws:
|
||||
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
|
||||
|
||||
await self.app.presence.dispatch_pres(user_id, offline)
|
||||
if not any(s.ws for s in states):
|
||||
await self.app.presence.dispatch_pres(
|
||||
user_id, BasePresence(status="offline")
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
"""Wrap :meth:`listen_messages` inside
|
||||
a try/except block for WebsocketClose handling."""
|
||||
try:
|
||||
async with self.app.app_context():
|
||||
await self._send_hello()
|
||||
await self._listen_messages()
|
||||
except websockets.exceptions.ConnectionClosed as err:
|
||||
|
|
|
|||
|
|
@ -17,13 +17,33 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Iterable
|
||||
from typing import List, Dict, Any, Iterable, Optional
|
||||
from random import choice
|
||||
from dataclasses import dataclass
|
||||
from quart import current_app as app
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasePresence:
|
||||
status: str
|
||||
game: Optional[dict] = None
|
||||
|
||||
@property
|
||||
def partial_dict(self) -> dict:
|
||||
return {
|
||||
"status": self.status,
|
||||
"game": self.game,
|
||||
"since": 0,
|
||||
"client_status": {},
|
||||
"mobile": False,
|
||||
"activities": [self.game] if self.game else [],
|
||||
}
|
||||
|
||||
|
||||
Presence = Dict[str, Any]
|
||||
|
||||
|
||||
|
|
@ -37,69 +57,29 @@ def status_cmp(status: str, other_status: str) -> bool:
|
|||
return hierarchy[status] > hierarchy[other_status]
|
||||
|
||||
|
||||
def _best_presence(shards):
|
||||
"""Find the 'best' presence given a list of GatewayState."""
|
||||
best = {"status": None, "game": None}
|
||||
def _merge_state_presences(shards: list) -> BasePresence:
|
||||
"""create a 'best' presence given a list of states."""
|
||||
best = BasePresence(status="offline")
|
||||
|
||||
for state in shards:
|
||||
presence = state.presence
|
||||
|
||||
status = presence["status"]
|
||||
|
||||
if not presence:
|
||||
if state.presence is None:
|
||||
continue
|
||||
|
||||
# shards with a better status
|
||||
# in the hierarchy are treated as best
|
||||
if status_cmp(status, best["status"]):
|
||||
best["status"] = status
|
||||
if status_cmp(state.presence.status, best.status):
|
||||
best.status = state.presence.status
|
||||
|
||||
# if we have any game, use it
|
||||
if presence["game"] is not None:
|
||||
best["game"] = presence["game"]
|
||||
if state.presence.game:
|
||||
best.game = state.presence.game
|
||||
|
||||
# best['status'] is None when no
|
||||
# status was good enough.
|
||||
return None if not best["status"] else best
|
||||
return best
|
||||
|
||||
|
||||
def fill_presence(presence: dict, *, game=None) -> dict:
|
||||
"""Fill a given presence object with some specific fields."""
|
||||
presence["client_status"] = {}
|
||||
presence["mobile"] = False
|
||||
|
||||
if "since" not in presence:
|
||||
presence["since"] = 0
|
||||
|
||||
# fill game and activities array depending if game
|
||||
# is there or not
|
||||
game = game or presence.get("game")
|
||||
|
||||
# casting to bool since a game of {} is still invalid
|
||||
if game:
|
||||
presence["game"] = game
|
||||
presence["activities"] = [game]
|
||||
else:
|
||||
presence["game"] = None
|
||||
presence["activities"] = []
|
||||
|
||||
return presence
|
||||
|
||||
|
||||
async def _pres(storage, user_id: int, status_obj: dict) -> dict:
|
||||
"""Convert a given status into a presence, given the User ID and the
|
||||
:class:`Storage` instance."""
|
||||
ext = {
|
||||
"user": await storage.get_user(user_id),
|
||||
"activities": [],
|
||||
# NOTE: we are purposefully overwriting the fields, as there
|
||||
# isn't any push for us to actually implement mobile detection, or
|
||||
# web detection, etc.
|
||||
"client_status": {},
|
||||
"mobile": False,
|
||||
}
|
||||
|
||||
return fill_presence({**status_obj, **ext})
|
||||
async def _pres(user_id: int, presence: BasePresence) -> dict:
|
||||
"""Take a given base presence and convert it to a full friend presence."""
|
||||
return {**presence.partial_dict, **{"user": await app.storage.get_user(user_id)}}
|
||||
|
||||
|
||||
class PresenceManager:
|
||||
|
|
@ -123,39 +103,30 @@ class PresenceManager:
|
|||
# then fetching its respective member and merging that info with
|
||||
# the state's set presence.
|
||||
states = self.state_manager.guild_states(member_ids, guild_id)
|
||||
|
||||
presences = []
|
||||
|
||||
for state in states:
|
||||
member = await self.storage.get_member_data_one(guild_id, state.user_id)
|
||||
|
||||
game = state.presence.get("game", None)
|
||||
|
||||
# only use the data we need.
|
||||
presences.append(
|
||||
fill_presence(
|
||||
{
|
||||
**(state.presence or BasePresence(status="offline")).partial_dict,
|
||||
**{
|
||||
"user": member["user"],
|
||||
"roles": member["roles"],
|
||||
"guild_id": str(guild_id),
|
||||
# if a state is connected to the guild
|
||||
# we assume its online.
|
||||
"status": state.presence.get("status", "online"),
|
||||
},
|
||||
game=game,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
return presences
|
||||
|
||||
async def dispatch_guild_pres(self, guild_id: int, user_id: int, new_state: dict):
|
||||
async def dispatch_guild_pres(
|
||||
self, guild_id: int, user_id: int, presence: BasePresence
|
||||
):
|
||||
"""Dispatch a Presence update to an entire guild."""
|
||||
state = dict(new_state)
|
||||
|
||||
member = await self.storage.get_member_data_one(guild_id, user_id)
|
||||
|
||||
game = state["game"]
|
||||
|
||||
lazy_guild_store = self.dispatcher.backends["lazy_guild"]
|
||||
lists = lazy_guild_store.get_gml_guild(guild_id)
|
||||
|
||||
|
|
@ -166,7 +137,11 @@ class PresenceManager:
|
|||
for member_list in lists:
|
||||
session_ids = await member_list.pres_update(
|
||||
int(member["user"]["id"]),
|
||||
{"roles": member["roles"], "status": state["status"], "game": game},
|
||||
{
|
||||
"roles": member["roles"],
|
||||
"status": presence.status,
|
||||
"game": presence.game,
|
||||
},
|
||||
)
|
||||
|
||||
log.debug("Lazy Dispatch to {}", len(session_ids))
|
||||
|
|
@ -176,15 +151,14 @@ class PresenceManager:
|
|||
if member_list.channel_id == member_list.guild_id:
|
||||
in_lazy.extend(session_ids)
|
||||
|
||||
pres_update_payload = fill_presence(
|
||||
{
|
||||
event_payload = {
|
||||
**presence.partial_dict,
|
||||
**{
|
||||
"guild_id": str(guild_id),
|
||||
"user": member["user"],
|
||||
"roles": member["roles"],
|
||||
"status": state["status"],
|
||||
},
|
||||
game=game,
|
||||
)
|
||||
}
|
||||
|
||||
# given a session id, return if the session id actually connects to
|
||||
# a given user, and if the state has not been dispatched via lazy guild.
|
||||
|
|
@ -202,79 +176,68 @@ class PresenceManager:
|
|||
# everyone not in lazy guild mode
|
||||
# gets a PRESENCE_UPDATE
|
||||
await self.dispatcher.dispatch_filter(
|
||||
"guild", guild_id, _session_check, "PRESENCE_UPDATE", pres_update_payload
|
||||
"guild", guild_id, _session_check, "PRESENCE_UPDATE", event_payload
|
||||
)
|
||||
|
||||
return in_lazy
|
||||
|
||||
async def dispatch_pres(self, user_id: int, state: dict):
|
||||
async def dispatch_pres(self, user_id: int, presence: BasePresence) -> None:
|
||||
"""Dispatch a new presence to all guilds the user is in.
|
||||
|
||||
Also dispatches the presence to all the users' friends
|
||||
"""
|
||||
if state["status"] == "invisible":
|
||||
state["status"] = "offline"
|
||||
|
||||
# TODO: shard-aware
|
||||
# TODO: shard-aware (needs to only dispatch guilds of the shard)
|
||||
guild_ids = await self.user_storage.get_user_guilds(user_id)
|
||||
|
||||
for guild_id in guild_ids:
|
||||
await self.dispatch_guild_pres(guild_id, user_id, state)
|
||||
await self.dispatch_guild_pres(guild_id, user_id, presence)
|
||||
|
||||
# dispatch to all friends that are subscribed to them
|
||||
user = await self.storage.get_user(user_id)
|
||||
game = state["game"]
|
||||
|
||||
await self.dispatcher.dispatch(
|
||||
"friend",
|
||||
user_id,
|
||||
"PRESENCE_UPDATE",
|
||||
fill_presence({"user": user, "status": state["status"]}, game=game),
|
||||
{**presence.partial_dict, **{"user": user}},
|
||||
)
|
||||
|
||||
def fetch_friend_presence(self, friend_id: int) -> BasePresence:
|
||||
"""Fetch a presence for a friend.
|
||||
|
||||
This is a different algorithm than guild presence.
|
||||
"""
|
||||
friend_states = self.state_manager.user_states(friend_id)
|
||||
|
||||
if not friend_states:
|
||||
return BasePresence(status="offline")
|
||||
|
||||
# filter the best shards:
|
||||
# - all with id 0 (are the first shards in the collection) or
|
||||
# - all shards with count = 1 (single shards)
|
||||
good_shards = list(
|
||||
filter(
|
||||
lambda state: state.shard[0] == 0 or state.shard[1] == 1, friend_states
|
||||
)
|
||||
)
|
||||
|
||||
if good_shards:
|
||||
return _merge_state_presences(good_shards)
|
||||
|
||||
# if there aren't any shards with id 0
|
||||
# AND none that are single, just go with a random one.
|
||||
shard = choice([s for s in friend_states if s.presence])
|
||||
assert shard.presence is not None
|
||||
return shard.presence
|
||||
|
||||
async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]:
|
||||
"""Fetch presences for a group of users.
|
||||
|
||||
This assumes the users are friends and so
|
||||
only gets states that are single or have ID 0.
|
||||
"""
|
||||
storage = self.storage
|
||||
res = []
|
||||
|
||||
for friend_id in friend_ids:
|
||||
friend_states = self.state_manager.user_states(friend_id)
|
||||
|
||||
if not friend_states:
|
||||
# append offline
|
||||
res.append(
|
||||
await _pres(
|
||||
storage,
|
||||
friend_id,
|
||||
{"afk": False, "status": "offline", "game": None, "since": 0},
|
||||
)
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
# filter the best shards:
|
||||
# - all with id 0 (are the first shards in the collection) or
|
||||
# - all shards with count = 1 (single shards)
|
||||
good_shards = list(
|
||||
filter(
|
||||
lambda state: state.shard[0] == 0 or state.shard[1] == 1,
|
||||
friend_states,
|
||||
)
|
||||
)
|
||||
|
||||
if good_shards:
|
||||
best_pres = _best_presence(good_shards)
|
||||
best_pres = await _pres(storage, friend_id, best_pres)
|
||||
res.append(best_pres)
|
||||
continue
|
||||
|
||||
# if there aren't any shards with id 0
|
||||
# AND none that are single, just go with a random
|
||||
shard = choice(friend_states)
|
||||
res.append(await _pres(storage, friend_id, shard.presence))
|
||||
presence = self.fetch_friend_presence(friend_id)
|
||||
res.append(await _pres(friend_id, presence))
|
||||
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -290,3 +290,11 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple:
|
|||
def rand_hex(length: int = 8) -> str:
|
||||
"""Generate random hex characters."""
|
||||
return secrets.token_hex(length)[:length]
|
||||
|
||||
|
||||
def want_bytes(data: Union[str, bytes]) -> bytes:
|
||||
return data if isinstance(data, bytes) else data.encode()
|
||||
|
||||
|
||||
def want_string(data: Union[str, bytes]) -> str:
|
||||
return data.decode() if isinstance(data, bytes) else data
|
||||
|
|
|
|||
Loading…
Reference in New Issue