diff --git a/litecord/common/users.py b/litecord/common/users.py
index 7edbc33..a3ea268 100644
--- a/litecord/common/users.py
+++ b/litecord/common/users.py
@@ -24,6 +24,7 @@ from quart import current_app as app
from asyncpg import UniqueViolationError
from logbook import Logger
+from ..presence import BasePresence
from ..snowflake import get_snowflake
from ..errors import BadRequest
from ..auth import hash_data
@@ -268,6 +269,4 @@ async def user_disconnect(user_id: int):
await state.ws.ws.close(4000)
# force everyone to see the user as offline
- await app.presence.dispatch_pres(
- user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
- )
+ await app.presence.dispatch_pres(user_id, BasePresence(status="offline"))
diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py
index 9df23e3..3bac060 100644
--- a/litecord/gateway/state.py
+++ b/litecord/gateway/state.py
@@ -20,6 +20,9 @@ along with this program. If not, see .
import hashlib
import os
+from typing import Optional
+from litecord.presence import BasePresence
+
def gen_session_id() -> str:
"""Generate a random session ID."""
@@ -61,34 +64,35 @@ class GatewayState:
"""
def __init__(self, **kwargs):
- self.session_id = kwargs.get("session_id", gen_session_id())
+ self.session_id: str = kwargs.get("session_id", gen_session_id())
- #: event sequence number
- self.seq = kwargs.get("seq", 0)
+ #: last seq received by the client
+ self.seq: int = int(kwargs.get("seq") or 0)
- #: last seq sent by us, the backend
- self.last_seq = 0
+ #: last seq sent by gateway
+ self.last_seq: int = 0
- #: shard information about the state,
- # its id and shard count
- self.shard = kwargs.get("shard", [0, 1])
+ #: shard information (id and total count)
+ shard = kwargs.get("shard") or [0, 1]
+ self.current_shard: int = int(shard[0])
+ self.shard_count: int = int(shard[1])
- self.user_id = kwargs.get("user_id")
- self.bot = kwargs.get("bot", False)
+ self.user_id: int = int(kwargs["user_id"])
+ self.bot: bool = bool(kwargs.get("bot") or False)
#: set by the gateway connection
# on OP STATUS_UPDATE
- self.presence = {}
+ self.presence: Optional[BasePresence] = None
#: set by the backend once identify happens
self.ws = None
- #: store (kind of) all payloads sent by us
+ #: store of all payloads sent by the gateway (for recovery purposes)
self.store = PayloadStore()
- for key in kwargs:
- value = kwargs[key]
- self.__dict__[key] = value
+ self.compress: bool = kwargs.get("compress") or False
+
+ self.large: int = kwargs.get("large") or 50
def __repr__(self):
- return f"GatewayState"
+ return f"GatewayState"
diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py
index b295290..a4f10e8 100644
--- a/litecord/gateway/websocket.py
+++ b/litecord/gateway/websocket.py
@@ -31,8 +31,15 @@ from logbook import Logger
from litecord.auth import raw_token_check
from litecord.enums import RelationshipType, ChannelType
from litecord.schemas import validate, GW_STATUS_UPDATE
-from litecord.utils import task_wrapper, yield_chunks, maybe_int
+from litecord.utils import (
+ task_wrapper,
+ yield_chunks,
+ maybe_int,
+ want_bytes,
+ want_string,
+)
from litecord.permissions import get_permissions
+from litecord.presence import BasePresence
from litecord.gateway.opcodes import OP
from litecord.gateway.state import GatewayState
@@ -173,13 +180,15 @@ class GatewayWebsocket:
payload.get("t"),
)
+ # TODO encode to bytes only when absolutely needed e.g
+ # when compressing, because encoding == json means bytes won't work
if isinstance(encoded, str):
encoded = encoded.encode()
if self.wsp.compress == "zlib-stream":
- await self._zlib_stream_send(encoded)
+ await self._zlib_stream_send(want_bytes(encoded))
elif self.wsp.compress == "zstd-stream":
- await self._zstd_stream_send(encoded)
+ await self._zstd_stream_send(want_bytes(encoded))
elif (
self.state
and self.state.compress
@@ -188,9 +197,13 @@ class GatewayWebsocket:
):
# TODO determine better conditions to trigger a compress set
# by identify
- await self.ws.send(zlib.compress(encoded))
+ await self.ws.send(zlib.compress(want_bytes(encoded)))
else:
- await self.ws.send(encoded)
+ await self.ws.send(
+ want_bytes(encoded)
+ if self.wsp.encoding == "etf"
+ else want_string(encoded)
+ )
async def send_op(self, op_code: int, data: Any):
"""Send a packet but just the OP code information is filled in."""
@@ -343,7 +356,7 @@ class GatewayWebsocket:
"guilds": guilds,
"session_id": self.state.session_id,
"_trace": ["transbian"],
- "shard": self.state.shard,
+ "shard": [self.state.current_shard, self.state.shard_count],
}
await self.dispatch("READY", {**base_ready, **user_ready})
@@ -442,7 +455,7 @@ class GatewayWebsocket:
log.info("subscribing to {} friends", len(friend_ids))
await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
- async def update_status(self, status: dict):
+ async def update_status(self, incoming_status: dict):
"""Update the status of the current websocket connection."""
if not self.state:
return
@@ -452,7 +465,7 @@ class GatewayWebsocket:
# are just silently dropped.
return
- default_status = {
+ status = {
"afk": False,
# TODO: fetch status from settings
"status": "online",
@@ -460,8 +473,7 @@ class GatewayWebsocket:
# TODO: this
"since": 0,
}
-
- status = {**(status or {}), **default_status}
+ status.update(incoming_status or {})
try:
status = validate(status, GW_STATUS_UPDATE)
@@ -479,15 +491,10 @@ class GatewayWebsocket:
else:
game = status["game"]
- # construct final status
- status = {
- "afk": status.get("afk", False),
- "status": status.get("status", "online"),
- "game": game,
- "since": status.get("since", 0),
- }
+ pres_status = status.get("status") or "online"
+ pres_status = "offline" if pres_status == "invisible" else pres_status
+ self.state.presence = BasePresence(status=pres_status, game=game)
- self.state.presence = status
log.info(
f'Updating presence status={status["status"]} for '
f"uid={self.state.user_id}"
@@ -523,6 +530,8 @@ class GatewayWebsocket:
except KeyError:
raise DecodeError("Invalid identify parameters")
+ # TODO proper validation of this payload
+
compress = data.get("compress", False)
large = data.get("large_threshold", 50)
@@ -552,12 +561,12 @@ class GatewayWebsocket:
bot=bot,
compress=compress,
large=large,
- shard=shard,
current_shard=shard[0],
shard_count=shard[1],
- ws=self,
)
+ self.state.ws = self
+
# link the state to the user
self.app.state_manager.insert(self.state)
@@ -1004,24 +1013,23 @@ class GatewayWebsocket:
if not user_id:
return
- # TODO: account for sharding
- # this only updates status to offline once
- # ALL shards have come offline
+ # TODO: account for sharding. this only checks to dispatch an offline
+ # when all the shards have come fully offline, which is inefficient.
+
+ # TODO why is this inneficient?
states = self.app.state_manager.user_states(user_id)
- with_ws = [s for s in states if s.ws]
-
- # there arent any other states with websocket
- if not with_ws:
- offline = {"afk": False, "status": "offline", "game": None, "since": 0}
-
- await self.app.presence.dispatch_pres(user_id, offline)
+ if not any(s.ws for s in states):
+ await self.app.presence.dispatch_pres(
+ user_id, BasePresence(status="offline")
+ )
async def run(self):
"""Wrap :meth:`listen_messages` inside
a try/except block for WebsocketClose handling."""
try:
- await self._send_hello()
- await self._listen_messages()
+ async with self.app.app_context():
+ await self._send_hello()
+ await self._listen_messages()
except websockets.exceptions.ConnectionClosed as err:
log.warning("conn close, state={}, err={}", self.state, err)
except WebsocketClose as err:
diff --git a/litecord/presence.py b/litecord/presence.py
index 55d7832..9911daf 100644
--- a/litecord/presence.py
+++ b/litecord/presence.py
@@ -17,13 +17,33 @@ along with this program. If not, see .
"""
-from typing import List, Dict, Any, Iterable
+from typing import List, Dict, Any, Iterable, Optional
from random import choice
+from dataclasses import dataclass
+from quart import current_app as app
from logbook import Logger
log = Logger(__name__)
+
+@dataclass
+class BasePresence:
+ status: str
+ game: Optional[dict] = None
+
+ @property
+ def partial_dict(self) -> dict:
+ return {
+ "status": self.status,
+ "game": self.game,
+ "since": 0,
+ "client_status": {},
+ "mobile": False,
+ "activities": [self.game] if self.game else [],
+ }
+
+
Presence = Dict[str, Any]
@@ -37,69 +57,29 @@ def status_cmp(status: str, other_status: str) -> bool:
return hierarchy[status] > hierarchy[other_status]
-def _best_presence(shards):
- """Find the 'best' presence given a list of GatewayState."""
- best = {"status": None, "game": None}
+def _merge_state_presences(shards: list) -> BasePresence:
+ """create a 'best' presence given a list of states."""
+ best = BasePresence(status="offline")
for state in shards:
- presence = state.presence
-
- status = presence["status"]
-
- if not presence:
+ if state.presence is None:
continue
# shards with a better status
# in the hierarchy are treated as best
- if status_cmp(status, best["status"]):
- best["status"] = status
+ if status_cmp(state.presence.status, best.status):
+ best.status = state.presence.status
# if we have any game, use it
- if presence["game"] is not None:
- best["game"] = presence["game"]
+ if state.presence.game:
+ best.game = state.presence.game
- # best['status'] is None when no
- # status was good enough.
- return None if not best["status"] else best
+ return best
-def fill_presence(presence: dict, *, game=None) -> dict:
- """Fill a given presence object with some specific fields."""
- presence["client_status"] = {}
- presence["mobile"] = False
-
- if "since" not in presence:
- presence["since"] = 0
-
- # fill game and activities array depending if game
- # is there or not
- game = game or presence.get("game")
-
- # casting to bool since a game of {} is still invalid
- if game:
- presence["game"] = game
- presence["activities"] = [game]
- else:
- presence["game"] = None
- presence["activities"] = []
-
- return presence
-
-
-async def _pres(storage, user_id: int, status_obj: dict) -> dict:
- """Convert a given status into a presence, given the User ID and the
- :class:`Storage` instance."""
- ext = {
- "user": await storage.get_user(user_id),
- "activities": [],
- # NOTE: we are purposefully overwriting the fields, as there
- # isn't any push for us to actually implement mobile detection, or
- # web detection, etc.
- "client_status": {},
- "mobile": False,
- }
-
- return fill_presence({**status_obj, **ext})
+async def _pres(user_id: int, presence: BasePresence) -> dict:
+ """Take a given base presence and convert it to a full friend presence."""
+ return {**presence.partial_dict, **{"user": await app.storage.get_user(user_id)}}
class PresenceManager:
@@ -123,39 +103,30 @@ class PresenceManager:
# then fetching its respective member and merging that info with
# the state's set presence.
states = self.state_manager.guild_states(member_ids, guild_id)
-
presences = []
for state in states:
member = await self.storage.get_member_data_one(guild_id, state.user_id)
-
- game = state.presence.get("game", None)
-
- # only use the data we need.
presences.append(
- fill_presence(
- {
+ {
+ **(state.presence or BasePresence(status="offline")).partial_dict,
+ **{
"user": member["user"],
"roles": member["roles"],
"guild_id": str(guild_id),
- # if a state is connected to the guild
- # we assume its online.
- "status": state.presence.get("status", "online"),
},
- game=game,
- )
+ }
)
return presences
- async def dispatch_guild_pres(self, guild_id: int, user_id: int, new_state: dict):
+ async def dispatch_guild_pres(
+ self, guild_id: int, user_id: int, presence: BasePresence
+ ):
"""Dispatch a Presence update to an entire guild."""
- state = dict(new_state)
member = await self.storage.get_member_data_one(guild_id, user_id)
- game = state["game"]
-
lazy_guild_store = self.dispatcher.backends["lazy_guild"]
lists = lazy_guild_store.get_gml_guild(guild_id)
@@ -166,7 +137,11 @@ class PresenceManager:
for member_list in lists:
session_ids = await member_list.pres_update(
int(member["user"]["id"]),
- {"roles": member["roles"], "status": state["status"], "game": game},
+ {
+ "roles": member["roles"],
+ "status": presence.status,
+ "game": presence.game,
+ },
)
log.debug("Lazy Dispatch to {}", len(session_ids))
@@ -176,15 +151,14 @@ class PresenceManager:
if member_list.channel_id == member_list.guild_id:
in_lazy.extend(session_ids)
- pres_update_payload = fill_presence(
- {
+ event_payload = {
+ **presence.partial_dict,
+ **{
"guild_id": str(guild_id),
"user": member["user"],
"roles": member["roles"],
- "status": state["status"],
},
- game=game,
- )
+ }
# given a session id, return if the session id actually connects to
# a given user, and if the state has not been dispatched via lazy guild.
@@ -202,79 +176,68 @@ class PresenceManager:
# everyone not in lazy guild mode
# gets a PRESENCE_UPDATE
await self.dispatcher.dispatch_filter(
- "guild", guild_id, _session_check, "PRESENCE_UPDATE", pres_update_payload
+ "guild", guild_id, _session_check, "PRESENCE_UPDATE", event_payload
)
return in_lazy
- async def dispatch_pres(self, user_id: int, state: dict):
+ async def dispatch_pres(self, user_id: int, presence: BasePresence) -> None:
"""Dispatch a new presence to all guilds the user is in.
Also dispatches the presence to all the users' friends
"""
- if state["status"] == "invisible":
- state["status"] = "offline"
-
- # TODO: shard-aware
+ # TODO: shard-aware (needs to only dispatch guilds of the shard)
guild_ids = await self.user_storage.get_user_guilds(user_id)
-
for guild_id in guild_ids:
- await self.dispatch_guild_pres(guild_id, user_id, state)
+ await self.dispatch_guild_pres(guild_id, user_id, presence)
# dispatch to all friends that are subscribed to them
user = await self.storage.get_user(user_id)
- game = state["game"]
-
await self.dispatcher.dispatch(
"friend",
user_id,
"PRESENCE_UPDATE",
- fill_presence({"user": user, "status": state["status"]}, game=game),
+ {**presence.partial_dict, **{"user": user}},
)
+ def fetch_friend_presence(self, friend_id: int) -> BasePresence:
+ """Fetch a presence for a friend.
+
+ This is a different algorithm than guild presence.
+ """
+ friend_states = self.state_manager.user_states(friend_id)
+
+ if not friend_states:
+ return BasePresence(status="offline")
+
+ # filter the best shards:
+ # - all with id 0 (are the first shards in the collection) or
+ # - all shards with count = 1 (single shards)
+ good_shards = list(
+ filter(
+ lambda state: state.shard[0] == 0 or state.shard[1] == 1, friend_states
+ )
+ )
+
+ if good_shards:
+ return _merge_state_presences(good_shards)
+
+ # if there aren't any shards with id 0
+ # AND none that are single, just go with a random one.
+ shard = choice([s for s in friend_states if s.presence])
+ assert shard.presence is not None
+ return shard.presence
+
async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]:
"""Fetch presences for a group of users.
This assumes the users are friends and so
only gets states that are single or have ID 0.
"""
- storage = self.storage
res = []
for friend_id in friend_ids:
- friend_states = self.state_manager.user_states(friend_id)
-
- if not friend_states:
- # append offline
- res.append(
- await _pres(
- storage,
- friend_id,
- {"afk": False, "status": "offline", "game": None, "since": 0},
- )
- )
-
- continue
-
- # filter the best shards:
- # - all with id 0 (are the first shards in the collection) or
- # - all shards with count = 1 (single shards)
- good_shards = list(
- filter(
- lambda state: state.shard[0] == 0 or state.shard[1] == 1,
- friend_states,
- )
- )
-
- if good_shards:
- best_pres = _best_presence(good_shards)
- best_pres = await _pres(storage, friend_id, best_pres)
- res.append(best_pres)
- continue
-
- # if there aren't any shards with id 0
- # AND none that are single, just go with a random
- shard = choice(friend_states)
- res.append(await _pres(storage, friend_id, shard.presence))
+ presence = self.fetch_friend_presence(friend_id)
+ res.append(await _pres(friend_id, presence))
return res
diff --git a/litecord/utils.py b/litecord/utils.py
index 91e7d8a..fe3c2d4 100644
--- a/litecord/utils.py
+++ b/litecord/utils.py
@@ -290,3 +290,11 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple:
def rand_hex(length: int = 8) -> str:
"""Generate random hex characters."""
return secrets.token_hex(length)[:length]
+
+
+def want_bytes(data: Union[str, bytes]) -> bytes:
+ return data if isinstance(data, bytes) else data.encode()
+
+
+def want_string(data: Union[str, bytes]) -> str:
+ return data.decode() if isinstance(data, bytes) else data