mirror of https://gitlab.com/litecord/litecord.git
Refactor GatewayState with proper types
Make websocket loop be inside app context
This commit is contained in:
parent
70443ca379
commit
0f600543e8
|
|
@ -24,6 +24,7 @@ from quart import current_app as app
|
||||||
from asyncpg import UniqueViolationError
|
from asyncpg import UniqueViolationError
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
|
from ..presence import BasePresence
|
||||||
from ..snowflake import get_snowflake
|
from ..snowflake import get_snowflake
|
||||||
from ..errors import BadRequest
|
from ..errors import BadRequest
|
||||||
from ..auth import hash_data
|
from ..auth import hash_data
|
||||||
|
|
@ -268,6 +269,4 @@ async def user_disconnect(user_id: int):
|
||||||
await state.ws.ws.close(4000)
|
await state.ws.ws.close(4000)
|
||||||
|
|
||||||
# force everyone to see the user as offline
|
# force everyone to see the user as offline
|
||||||
await app.presence.dispatch_pres(
|
await app.presence.dispatch_pres(user_id, BasePresence(status="offline"))
|
||||||
user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,9 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from litecord.presence import BasePresence
|
||||||
|
|
||||||
|
|
||||||
def gen_session_id() -> str:
|
def gen_session_id() -> str:
|
||||||
"""Generate a random session ID."""
|
"""Generate a random session ID."""
|
||||||
|
|
@ -61,34 +64,35 @@ class GatewayState:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.session_id = kwargs.get("session_id", gen_session_id())
|
self.session_id: str = kwargs.get("session_id", gen_session_id())
|
||||||
|
|
||||||
#: event sequence number
|
#: last seq received by the client
|
||||||
self.seq = kwargs.get("seq", 0)
|
self.seq: int = int(kwargs.get("seq") or 0)
|
||||||
|
|
||||||
#: last seq sent by us, the backend
|
#: last seq sent by gateway
|
||||||
self.last_seq = 0
|
self.last_seq: int = 0
|
||||||
|
|
||||||
#: shard information about the state,
|
#: shard information (id and total count)
|
||||||
# its id and shard count
|
shard = kwargs.get("shard") or [0, 1]
|
||||||
self.shard = kwargs.get("shard", [0, 1])
|
self.current_shard: int = int(shard[0])
|
||||||
|
self.shard_count: int = int(shard[1])
|
||||||
|
|
||||||
self.user_id = kwargs.get("user_id")
|
self.user_id: int = int(kwargs["user_id"])
|
||||||
self.bot = kwargs.get("bot", False)
|
self.bot: bool = bool(kwargs.get("bot") or False)
|
||||||
|
|
||||||
#: set by the gateway connection
|
#: set by the gateway connection
|
||||||
# on OP STATUS_UPDATE
|
# on OP STATUS_UPDATE
|
||||||
self.presence = {}
|
self.presence: Optional[BasePresence] = None
|
||||||
|
|
||||||
#: set by the backend once identify happens
|
#: set by the backend once identify happens
|
||||||
self.ws = None
|
self.ws = None
|
||||||
|
|
||||||
#: store (kind of) all payloads sent by us
|
#: store of all payloads sent by the gateway (for recovery purposes)
|
||||||
self.store = PayloadStore()
|
self.store = PayloadStore()
|
||||||
|
|
||||||
for key in kwargs:
|
self.compress: bool = kwargs.get("compress") or False
|
||||||
value = kwargs[key]
|
|
||||||
self.__dict__[key] = value
|
self.large: int = kwargs.get("large") or 50
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"GatewayState<seq={self.seq} " f"shard={self.shard} uid={self.user_id}>"
|
return f"GatewayState<seq={self.seq} shard={self.current_shard},{self.shard_count} uid={self.user_id}>"
|
||||||
|
|
|
||||||
|
|
@ -31,8 +31,15 @@ from logbook import Logger
|
||||||
from litecord.auth import raw_token_check
|
from litecord.auth import raw_token_check
|
||||||
from litecord.enums import RelationshipType, ChannelType
|
from litecord.enums import RelationshipType, ChannelType
|
||||||
from litecord.schemas import validate, GW_STATUS_UPDATE
|
from litecord.schemas import validate, GW_STATUS_UPDATE
|
||||||
from litecord.utils import task_wrapper, yield_chunks, maybe_int
|
from litecord.utils import (
|
||||||
|
task_wrapper,
|
||||||
|
yield_chunks,
|
||||||
|
maybe_int,
|
||||||
|
want_bytes,
|
||||||
|
want_string,
|
||||||
|
)
|
||||||
from litecord.permissions import get_permissions
|
from litecord.permissions import get_permissions
|
||||||
|
from litecord.presence import BasePresence
|
||||||
|
|
||||||
from litecord.gateway.opcodes import OP
|
from litecord.gateway.opcodes import OP
|
||||||
from litecord.gateway.state import GatewayState
|
from litecord.gateway.state import GatewayState
|
||||||
|
|
@ -173,13 +180,15 @@ class GatewayWebsocket:
|
||||||
payload.get("t"),
|
payload.get("t"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO encode to bytes only when absolutely needed e.g
|
||||||
|
# when compressing, because encoding == json means bytes won't work
|
||||||
if isinstance(encoded, str):
|
if isinstance(encoded, str):
|
||||||
encoded = encoded.encode()
|
encoded = encoded.encode()
|
||||||
|
|
||||||
if self.wsp.compress == "zlib-stream":
|
if self.wsp.compress == "zlib-stream":
|
||||||
await self._zlib_stream_send(encoded)
|
await self._zlib_stream_send(want_bytes(encoded))
|
||||||
elif self.wsp.compress == "zstd-stream":
|
elif self.wsp.compress == "zstd-stream":
|
||||||
await self._zstd_stream_send(encoded)
|
await self._zstd_stream_send(want_bytes(encoded))
|
||||||
elif (
|
elif (
|
||||||
self.state
|
self.state
|
||||||
and self.state.compress
|
and self.state.compress
|
||||||
|
|
@ -188,9 +197,13 @@ class GatewayWebsocket:
|
||||||
):
|
):
|
||||||
# TODO determine better conditions to trigger a compress set
|
# TODO determine better conditions to trigger a compress set
|
||||||
# by identify
|
# by identify
|
||||||
await self.ws.send(zlib.compress(encoded))
|
await self.ws.send(zlib.compress(want_bytes(encoded)))
|
||||||
else:
|
else:
|
||||||
await self.ws.send(encoded)
|
await self.ws.send(
|
||||||
|
want_bytes(encoded)
|
||||||
|
if self.wsp.encoding == "etf"
|
||||||
|
else want_string(encoded)
|
||||||
|
)
|
||||||
|
|
||||||
async def send_op(self, op_code: int, data: Any):
|
async def send_op(self, op_code: int, data: Any):
|
||||||
"""Send a packet but just the OP code information is filled in."""
|
"""Send a packet but just the OP code information is filled in."""
|
||||||
|
|
@ -343,7 +356,7 @@ class GatewayWebsocket:
|
||||||
"guilds": guilds,
|
"guilds": guilds,
|
||||||
"session_id": self.state.session_id,
|
"session_id": self.state.session_id,
|
||||||
"_trace": ["transbian"],
|
"_trace": ["transbian"],
|
||||||
"shard": self.state.shard,
|
"shard": [self.state.current_shard, self.state.shard_count],
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.dispatch("READY", {**base_ready, **user_ready})
|
await self.dispatch("READY", {**base_ready, **user_ready})
|
||||||
|
|
@ -442,7 +455,7 @@ class GatewayWebsocket:
|
||||||
log.info("subscribing to {} friends", len(friend_ids))
|
log.info("subscribing to {} friends", len(friend_ids))
|
||||||
await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
|
await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
|
||||||
|
|
||||||
async def update_status(self, status: dict):
|
async def update_status(self, incoming_status: dict):
|
||||||
"""Update the status of the current websocket connection."""
|
"""Update the status of the current websocket connection."""
|
||||||
if not self.state:
|
if not self.state:
|
||||||
return
|
return
|
||||||
|
|
@ -452,7 +465,7 @@ class GatewayWebsocket:
|
||||||
# are just silently dropped.
|
# are just silently dropped.
|
||||||
return
|
return
|
||||||
|
|
||||||
default_status = {
|
status = {
|
||||||
"afk": False,
|
"afk": False,
|
||||||
# TODO: fetch status from settings
|
# TODO: fetch status from settings
|
||||||
"status": "online",
|
"status": "online",
|
||||||
|
|
@ -460,8 +473,7 @@ class GatewayWebsocket:
|
||||||
# TODO: this
|
# TODO: this
|
||||||
"since": 0,
|
"since": 0,
|
||||||
}
|
}
|
||||||
|
status.update(incoming_status or {})
|
||||||
status = {**(status or {}), **default_status}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
status = validate(status, GW_STATUS_UPDATE)
|
status = validate(status, GW_STATUS_UPDATE)
|
||||||
|
|
@ -479,15 +491,10 @@ class GatewayWebsocket:
|
||||||
else:
|
else:
|
||||||
game = status["game"]
|
game = status["game"]
|
||||||
|
|
||||||
# construct final status
|
pres_status = status.get("status") or "online"
|
||||||
status = {
|
pres_status = "offline" if pres_status == "invisible" else pres_status
|
||||||
"afk": status.get("afk", False),
|
self.state.presence = BasePresence(status=pres_status, game=game)
|
||||||
"status": status.get("status", "online"),
|
|
||||||
"game": game,
|
|
||||||
"since": status.get("since", 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
self.state.presence = status
|
|
||||||
log.info(
|
log.info(
|
||||||
f'Updating presence status={status["status"]} for '
|
f'Updating presence status={status["status"]} for '
|
||||||
f"uid={self.state.user_id}"
|
f"uid={self.state.user_id}"
|
||||||
|
|
@ -523,6 +530,8 @@ class GatewayWebsocket:
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise DecodeError("Invalid identify parameters")
|
raise DecodeError("Invalid identify parameters")
|
||||||
|
|
||||||
|
# TODO proper validation of this payload
|
||||||
|
|
||||||
compress = data.get("compress", False)
|
compress = data.get("compress", False)
|
||||||
large = data.get("large_threshold", 50)
|
large = data.get("large_threshold", 50)
|
||||||
|
|
||||||
|
|
@ -552,12 +561,12 @@ class GatewayWebsocket:
|
||||||
bot=bot,
|
bot=bot,
|
||||||
compress=compress,
|
compress=compress,
|
||||||
large=large,
|
large=large,
|
||||||
shard=shard,
|
|
||||||
current_shard=shard[0],
|
current_shard=shard[0],
|
||||||
shard_count=shard[1],
|
shard_count=shard[1],
|
||||||
ws=self,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.state.ws = self
|
||||||
|
|
||||||
# link the state to the user
|
# link the state to the user
|
||||||
self.app.state_manager.insert(self.state)
|
self.app.state_manager.insert(self.state)
|
||||||
|
|
||||||
|
|
@ -1004,22 +1013,21 @@ class GatewayWebsocket:
|
||||||
if not user_id:
|
if not user_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
# TODO: account for sharding
|
# TODO: account for sharding. this only checks to dispatch an offline
|
||||||
# this only updates status to offline once
|
# when all the shards have come fully offline, which is inefficient.
|
||||||
# ALL shards have come offline
|
|
||||||
|
# TODO why is this inneficient?
|
||||||
states = self.app.state_manager.user_states(user_id)
|
states = self.app.state_manager.user_states(user_id)
|
||||||
with_ws = [s for s in states if s.ws]
|
if not any(s.ws for s in states):
|
||||||
|
await self.app.presence.dispatch_pres(
|
||||||
# there arent any other states with websocket
|
user_id, BasePresence(status="offline")
|
||||||
if not with_ws:
|
)
|
||||||
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
|
|
||||||
|
|
||||||
await self.app.presence.dispatch_pres(user_id, offline)
|
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""Wrap :meth:`listen_messages` inside
|
"""Wrap :meth:`listen_messages` inside
|
||||||
a try/except block for WebsocketClose handling."""
|
a try/except block for WebsocketClose handling."""
|
||||||
try:
|
try:
|
||||||
|
async with self.app.app_context():
|
||||||
await self._send_hello()
|
await self._send_hello()
|
||||||
await self._listen_messages()
|
await self._listen_messages()
|
||||||
except websockets.exceptions.ConnectionClosed as err:
|
except websockets.exceptions.ConnectionClosed as err:
|
||||||
|
|
|
||||||
|
|
@ -17,13 +17,33 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Dict, Any, Iterable
|
from typing import List, Dict, Any, Iterable, Optional
|
||||||
from random import choice
|
from random import choice
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from quart import current_app as app
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BasePresence:
|
||||||
|
status: str
|
||||||
|
game: Optional[dict] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def partial_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"status": self.status,
|
||||||
|
"game": self.game,
|
||||||
|
"since": 0,
|
||||||
|
"client_status": {},
|
||||||
|
"mobile": False,
|
||||||
|
"activities": [self.game] if self.game else [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
Presence = Dict[str, Any]
|
Presence = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -37,69 +57,29 @@ def status_cmp(status: str, other_status: str) -> bool:
|
||||||
return hierarchy[status] > hierarchy[other_status]
|
return hierarchy[status] > hierarchy[other_status]
|
||||||
|
|
||||||
|
|
||||||
def _best_presence(shards):
|
def _merge_state_presences(shards: list) -> BasePresence:
|
||||||
"""Find the 'best' presence given a list of GatewayState."""
|
"""create a 'best' presence given a list of states."""
|
||||||
best = {"status": None, "game": None}
|
best = BasePresence(status="offline")
|
||||||
|
|
||||||
for state in shards:
|
for state in shards:
|
||||||
presence = state.presence
|
if state.presence is None:
|
||||||
|
|
||||||
status = presence["status"]
|
|
||||||
|
|
||||||
if not presence:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# shards with a better status
|
# shards with a better status
|
||||||
# in the hierarchy are treated as best
|
# in the hierarchy are treated as best
|
||||||
if status_cmp(status, best["status"]):
|
if status_cmp(state.presence.status, best.status):
|
||||||
best["status"] = status
|
best.status = state.presence.status
|
||||||
|
|
||||||
# if we have any game, use it
|
# if we have any game, use it
|
||||||
if presence["game"] is not None:
|
if state.presence.game:
|
||||||
best["game"] = presence["game"]
|
best.game = state.presence.game
|
||||||
|
|
||||||
# best['status'] is None when no
|
return best
|
||||||
# status was good enough.
|
|
||||||
return None if not best["status"] else best
|
|
||||||
|
|
||||||
|
|
||||||
def fill_presence(presence: dict, *, game=None) -> dict:
|
async def _pres(user_id: int, presence: BasePresence) -> dict:
|
||||||
"""Fill a given presence object with some specific fields."""
|
"""Take a given base presence and convert it to a full friend presence."""
|
||||||
presence["client_status"] = {}
|
return {**presence.partial_dict, **{"user": await app.storage.get_user(user_id)}}
|
||||||
presence["mobile"] = False
|
|
||||||
|
|
||||||
if "since" not in presence:
|
|
||||||
presence["since"] = 0
|
|
||||||
|
|
||||||
# fill game and activities array depending if game
|
|
||||||
# is there or not
|
|
||||||
game = game or presence.get("game")
|
|
||||||
|
|
||||||
# casting to bool since a game of {} is still invalid
|
|
||||||
if game:
|
|
||||||
presence["game"] = game
|
|
||||||
presence["activities"] = [game]
|
|
||||||
else:
|
|
||||||
presence["game"] = None
|
|
||||||
presence["activities"] = []
|
|
||||||
|
|
||||||
return presence
|
|
||||||
|
|
||||||
|
|
||||||
async def _pres(storage, user_id: int, status_obj: dict) -> dict:
|
|
||||||
"""Convert a given status into a presence, given the User ID and the
|
|
||||||
:class:`Storage` instance."""
|
|
||||||
ext = {
|
|
||||||
"user": await storage.get_user(user_id),
|
|
||||||
"activities": [],
|
|
||||||
# NOTE: we are purposefully overwriting the fields, as there
|
|
||||||
# isn't any push for us to actually implement mobile detection, or
|
|
||||||
# web detection, etc.
|
|
||||||
"client_status": {},
|
|
||||||
"mobile": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
return fill_presence({**status_obj, **ext})
|
|
||||||
|
|
||||||
|
|
||||||
class PresenceManager:
|
class PresenceManager:
|
||||||
|
|
@ -123,39 +103,30 @@ class PresenceManager:
|
||||||
# then fetching its respective member and merging that info with
|
# then fetching its respective member and merging that info with
|
||||||
# the state's set presence.
|
# the state's set presence.
|
||||||
states = self.state_manager.guild_states(member_ids, guild_id)
|
states = self.state_manager.guild_states(member_ids, guild_id)
|
||||||
|
|
||||||
presences = []
|
presences = []
|
||||||
|
|
||||||
for state in states:
|
for state in states:
|
||||||
member = await self.storage.get_member_data_one(guild_id, state.user_id)
|
member = await self.storage.get_member_data_one(guild_id, state.user_id)
|
||||||
|
|
||||||
game = state.presence.get("game", None)
|
|
||||||
|
|
||||||
# only use the data we need.
|
|
||||||
presences.append(
|
presences.append(
|
||||||
fill_presence(
|
|
||||||
{
|
{
|
||||||
|
**(state.presence or BasePresence(status="offline")).partial_dict,
|
||||||
|
**{
|
||||||
"user": member["user"],
|
"user": member["user"],
|
||||||
"roles": member["roles"],
|
"roles": member["roles"],
|
||||||
"guild_id": str(guild_id),
|
"guild_id": str(guild_id),
|
||||||
# if a state is connected to the guild
|
|
||||||
# we assume its online.
|
|
||||||
"status": state.presence.get("status", "online"),
|
|
||||||
},
|
},
|
||||||
game=game,
|
}
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return presences
|
return presences
|
||||||
|
|
||||||
async def dispatch_guild_pres(self, guild_id: int, user_id: int, new_state: dict):
|
async def dispatch_guild_pres(
|
||||||
|
self, guild_id: int, user_id: int, presence: BasePresence
|
||||||
|
):
|
||||||
"""Dispatch a Presence update to an entire guild."""
|
"""Dispatch a Presence update to an entire guild."""
|
||||||
state = dict(new_state)
|
|
||||||
|
|
||||||
member = await self.storage.get_member_data_one(guild_id, user_id)
|
member = await self.storage.get_member_data_one(guild_id, user_id)
|
||||||
|
|
||||||
game = state["game"]
|
|
||||||
|
|
||||||
lazy_guild_store = self.dispatcher.backends["lazy_guild"]
|
lazy_guild_store = self.dispatcher.backends["lazy_guild"]
|
||||||
lists = lazy_guild_store.get_gml_guild(guild_id)
|
lists = lazy_guild_store.get_gml_guild(guild_id)
|
||||||
|
|
||||||
|
|
@ -166,7 +137,11 @@ class PresenceManager:
|
||||||
for member_list in lists:
|
for member_list in lists:
|
||||||
session_ids = await member_list.pres_update(
|
session_ids = await member_list.pres_update(
|
||||||
int(member["user"]["id"]),
|
int(member["user"]["id"]),
|
||||||
{"roles": member["roles"], "status": state["status"], "game": game},
|
{
|
||||||
|
"roles": member["roles"],
|
||||||
|
"status": presence.status,
|
||||||
|
"game": presence.game,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
log.debug("Lazy Dispatch to {}", len(session_ids))
|
log.debug("Lazy Dispatch to {}", len(session_ids))
|
||||||
|
|
@ -176,15 +151,14 @@ class PresenceManager:
|
||||||
if member_list.channel_id == member_list.guild_id:
|
if member_list.channel_id == member_list.guild_id:
|
||||||
in_lazy.extend(session_ids)
|
in_lazy.extend(session_ids)
|
||||||
|
|
||||||
pres_update_payload = fill_presence(
|
event_payload = {
|
||||||
{
|
**presence.partial_dict,
|
||||||
|
**{
|
||||||
"guild_id": str(guild_id),
|
"guild_id": str(guild_id),
|
||||||
"user": member["user"],
|
"user": member["user"],
|
||||||
"roles": member["roles"],
|
"roles": member["roles"],
|
||||||
"status": state["status"],
|
|
||||||
},
|
},
|
||||||
game=game,
|
}
|
||||||
)
|
|
||||||
|
|
||||||
# given a session id, return if the session id actually connects to
|
# given a session id, return if the session id actually connects to
|
||||||
# a given user, and if the state has not been dispatched via lazy guild.
|
# a given user, and if the state has not been dispatched via lazy guild.
|
||||||
|
|
@ -202,79 +176,68 @@ class PresenceManager:
|
||||||
# everyone not in lazy guild mode
|
# everyone not in lazy guild mode
|
||||||
# gets a PRESENCE_UPDATE
|
# gets a PRESENCE_UPDATE
|
||||||
await self.dispatcher.dispatch_filter(
|
await self.dispatcher.dispatch_filter(
|
||||||
"guild", guild_id, _session_check, "PRESENCE_UPDATE", pres_update_payload
|
"guild", guild_id, _session_check, "PRESENCE_UPDATE", event_payload
|
||||||
)
|
)
|
||||||
|
|
||||||
return in_lazy
|
return in_lazy
|
||||||
|
|
||||||
async def dispatch_pres(self, user_id: int, state: dict):
|
async def dispatch_pres(self, user_id: int, presence: BasePresence) -> None:
|
||||||
"""Dispatch a new presence to all guilds the user is in.
|
"""Dispatch a new presence to all guilds the user is in.
|
||||||
|
|
||||||
Also dispatches the presence to all the users' friends
|
Also dispatches the presence to all the users' friends
|
||||||
"""
|
"""
|
||||||
if state["status"] == "invisible":
|
# TODO: shard-aware (needs to only dispatch guilds of the shard)
|
||||||
state["status"] = "offline"
|
|
||||||
|
|
||||||
# TODO: shard-aware
|
|
||||||
guild_ids = await self.user_storage.get_user_guilds(user_id)
|
guild_ids = await self.user_storage.get_user_guilds(user_id)
|
||||||
|
|
||||||
for guild_id in guild_ids:
|
for guild_id in guild_ids:
|
||||||
await self.dispatch_guild_pres(guild_id, user_id, state)
|
await self.dispatch_guild_pres(guild_id, user_id, presence)
|
||||||
|
|
||||||
# dispatch to all friends that are subscribed to them
|
# dispatch to all friends that are subscribed to them
|
||||||
user = await self.storage.get_user(user_id)
|
user = await self.storage.get_user(user_id)
|
||||||
game = state["game"]
|
|
||||||
|
|
||||||
await self.dispatcher.dispatch(
|
await self.dispatcher.dispatch(
|
||||||
"friend",
|
"friend",
|
||||||
user_id,
|
user_id,
|
||||||
"PRESENCE_UPDATE",
|
"PRESENCE_UPDATE",
|
||||||
fill_presence({"user": user, "status": state["status"]}, game=game),
|
{**presence.partial_dict, **{"user": user}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def fetch_friend_presence(self, friend_id: int) -> BasePresence:
|
||||||
|
"""Fetch a presence for a friend.
|
||||||
|
|
||||||
|
This is a different algorithm than guild presence.
|
||||||
|
"""
|
||||||
|
friend_states = self.state_manager.user_states(friend_id)
|
||||||
|
|
||||||
|
if not friend_states:
|
||||||
|
return BasePresence(status="offline")
|
||||||
|
|
||||||
|
# filter the best shards:
|
||||||
|
# - all with id 0 (are the first shards in the collection) or
|
||||||
|
# - all shards with count = 1 (single shards)
|
||||||
|
good_shards = list(
|
||||||
|
filter(
|
||||||
|
lambda state: state.shard[0] == 0 or state.shard[1] == 1, friend_states
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if good_shards:
|
||||||
|
return _merge_state_presences(good_shards)
|
||||||
|
|
||||||
|
# if there aren't any shards with id 0
|
||||||
|
# AND none that are single, just go with a random one.
|
||||||
|
shard = choice([s for s in friend_states if s.presence])
|
||||||
|
assert shard.presence is not None
|
||||||
|
return shard.presence
|
||||||
|
|
||||||
async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]:
|
async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]:
|
||||||
"""Fetch presences for a group of users.
|
"""Fetch presences for a group of users.
|
||||||
|
|
||||||
This assumes the users are friends and so
|
This assumes the users are friends and so
|
||||||
only gets states that are single or have ID 0.
|
only gets states that are single or have ID 0.
|
||||||
"""
|
"""
|
||||||
storage = self.storage
|
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for friend_id in friend_ids:
|
for friend_id in friend_ids:
|
||||||
friend_states = self.state_manager.user_states(friend_id)
|
presence = self.fetch_friend_presence(friend_id)
|
||||||
|
res.append(await _pres(friend_id, presence))
|
||||||
if not friend_states:
|
|
||||||
# append offline
|
|
||||||
res.append(
|
|
||||||
await _pres(
|
|
||||||
storage,
|
|
||||||
friend_id,
|
|
||||||
{"afk": False, "status": "offline", "game": None, "since": 0},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
continue
|
|
||||||
|
|
||||||
# filter the best shards:
|
|
||||||
# - all with id 0 (are the first shards in the collection) or
|
|
||||||
# - all shards with count = 1 (single shards)
|
|
||||||
good_shards = list(
|
|
||||||
filter(
|
|
||||||
lambda state: state.shard[0] == 0 or state.shard[1] == 1,
|
|
||||||
friend_states,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if good_shards:
|
|
||||||
best_pres = _best_presence(good_shards)
|
|
||||||
best_pres = await _pres(storage, friend_id, best_pres)
|
|
||||||
res.append(best_pres)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# if there aren't any shards with id 0
|
|
||||||
# AND none that are single, just go with a random
|
|
||||||
shard = choice(friend_states)
|
|
||||||
res.append(await _pres(storage, friend_id, shard.presence))
|
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
|
||||||
|
|
@ -290,3 +290,11 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple:
|
||||||
def rand_hex(length: int = 8) -> str:
|
def rand_hex(length: int = 8) -> str:
|
||||||
"""Generate random hex characters."""
|
"""Generate random hex characters."""
|
||||||
return secrets.token_hex(length)[:length]
|
return secrets.token_hex(length)[:length]
|
||||||
|
|
||||||
|
|
||||||
|
def want_bytes(data: Union[str, bytes]) -> bytes:
|
||||||
|
return data if isinstance(data, bytes) else data.encode()
|
||||||
|
|
||||||
|
|
||||||
|
def want_string(data: Union[str, bytes]) -> str:
|
||||||
|
return data.decode() if isinstance(data, bytes) else data
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue