mirror of https://gitlab.com/litecord/litecord.git
gateway: more comments and docstrings
This commit is contained in:
parent
6be85ea305
commit
1d3520876d
|
|
@ -8,7 +8,11 @@ def gen_session_id() -> str:
|
|||
|
||||
|
||||
class PayloadStore:
|
||||
"""Store manager for payloads."""
|
||||
"""Store manager for payloads.
|
||||
|
||||
This will only store a maximum of MAX_STORE_SIZE,
|
||||
dropping the older payloads when adding new ones.
|
||||
"""
|
||||
MAX_STORE_SIZE = 250
|
||||
|
||||
def __init__(self):
|
||||
|
|
@ -30,14 +34,6 @@ class PayloadStore:
|
|||
self.store[opcode] = payload
|
||||
|
||||
|
||||
class Presence:
|
||||
def __init__(self, raw: dict):
|
||||
self.afk = raw.get('afk', False)
|
||||
self.status = raw.get('status', 'online')
|
||||
self.game = raw.get('game', None)
|
||||
self.since = raw.get('since', 0)
|
||||
|
||||
|
||||
class GatewayState:
|
||||
"""Main websocket state.
|
||||
|
||||
|
|
@ -46,13 +42,28 @@ class GatewayState:
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
self.session_id = kwargs.get('session_id', gen_session_id())
|
||||
|
||||
#: event sequence number
|
||||
self.seq = kwargs.get('seq', 0)
|
||||
|
||||
#: last seq sent by us, the backend
|
||||
self.last_seq = 0
|
||||
|
||||
#: shard information about the state,
|
||||
# its id and shard count
|
||||
self.shard = kwargs.get('shard', [0, 1])
|
||||
|
||||
self.user_id = kwargs.get('user_id')
|
||||
self.bot = kwargs.get('bot', False)
|
||||
|
||||
#: set by the gateway connection
|
||||
# on OP STATUS_UPDATE
|
||||
self.presence = {}
|
||||
|
||||
#: set by the backend once identify happens
|
||||
self.ws = None
|
||||
|
||||
#: store (kind of) all payloads sent by us
|
||||
self.store = PayloadStore()
|
||||
|
||||
for key in kwargs:
|
||||
|
|
|
|||
|
|
@ -13,6 +13,13 @@ class StateManager:
|
|||
"""Manager for gateway state information."""
|
||||
|
||||
def __init__(self):
|
||||
# {
|
||||
# user_id: {
|
||||
# session_id: GatewayState,
|
||||
# session_id_2: GatewayState, ...
|
||||
# },
|
||||
# user_id_2: {}, ...
|
||||
# }
|
||||
self.states = defaultdict(dict)
|
||||
|
||||
def insert(self, state: GatewayState):
|
||||
|
|
@ -23,7 +30,14 @@ class StateManager:
|
|||
user_states[state.session_id] = state
|
||||
|
||||
def fetch(self, user_id: int, session_id: str) -> GatewayState:
|
||||
"""Fetch a state object from the registry."""
|
||||
"""Fetch a state object from the manager.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
When the user_id or session_id
|
||||
aren't found in the store.
|
||||
"""
|
||||
return self.states[user_id][session_id]
|
||||
|
||||
def remove(self, state):
|
||||
|
|
|
|||
|
|
@ -256,15 +256,14 @@ class GatewayWebsocket:
|
|||
# async dispatch of guilds
|
||||
self.ext.loop.create_task(self.guild_dispatch(guilds))
|
||||
|
||||
async def _check_shards(self):
|
||||
shard = self.state.shard
|
||||
async def _check_shards(self, shard, user_id):
|
||||
current_shard, shard_count = shard
|
||||
|
||||
guilds = await self.ext.db.fetchval("""
|
||||
SELECT COUNT(*)
|
||||
FROM members
|
||||
WHERE user_id = $1
|
||||
""", self.state.user_id)
|
||||
""", user_id)
|
||||
|
||||
recommended = max(int(guilds / 1200), 1)
|
||||
|
||||
|
|
@ -390,6 +389,9 @@ class GatewayWebsocket:
|
|||
WHERE id = $1
|
||||
""", user_id)
|
||||
|
||||
await self._check_shards(shard, user_id)
|
||||
|
||||
# only create a state after checking everything
|
||||
self.state = GatewayState(
|
||||
user_id=user_id,
|
||||
bot=bot,
|
||||
|
|
@ -401,9 +403,9 @@ class GatewayWebsocket:
|
|||
ws=self
|
||||
)
|
||||
|
||||
await self._check_shards()
|
||||
|
||||
# link the state to the user
|
||||
self.ext.state_manager.insert(self.state)
|
||||
|
||||
await self.update_status(presence)
|
||||
await self.subscribe_all()
|
||||
await self.dispatch_ready()
|
||||
|
|
@ -419,6 +421,7 @@ class GatewayWebsocket:
|
|||
async def handle_4(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 4 Voice Status Update."""
|
||||
data = payload['d']
|
||||
# for now, ignore
|
||||
log.debug('got VSU cid={} gid={} deaf={} mute={} video={}',
|
||||
data.get('channel_id'),
|
||||
data.get('guild_id'),
|
||||
|
|
@ -426,9 +429,6 @@ class GatewayWebsocket:
|
|||
data.get('self_mute'),
|
||||
data.get('self_video'))
|
||||
|
||||
# for now, do nothing
|
||||
pass
|
||||
|
||||
async def _handle_5(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 5 Voice Server Ping.
|
||||
|
||||
|
|
@ -452,6 +452,9 @@ class GatewayWebsocket:
|
|||
})
|
||||
|
||||
if not resumable and self.state:
|
||||
# since the state will be removed from
|
||||
# the manager, it will become unreachable
|
||||
# when trying to resume.
|
||||
self.ext.state_manager.remove(self.state)
|
||||
|
||||
async def _resume(self, replay_seqs: iter):
|
||||
|
|
@ -476,12 +479,14 @@ class GatewayWebsocket:
|
|||
await self.send(payload)
|
||||
except Exception:
|
||||
log.exception('error while resuming')
|
||||
await self.invalidate_session()
|
||||
await self.invalidate_session(False)
|
||||
return
|
||||
|
||||
if presences:
|
||||
await self.dispatch('PRESENCE_REPLACE', presences)
|
||||
|
||||
await self.dispatch('RESUMED', {})
|
||||
|
||||
async def handle_6(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 6 Resume."""
|
||||
data = payload['d']
|
||||
|
|
@ -515,7 +520,6 @@ class GatewayWebsocket:
|
|||
state.ws = self
|
||||
|
||||
await self._resume(range(seq, state.seq))
|
||||
await self.dispatch('RESUMED', {})
|
||||
|
||||
async def _req_guild_members(self, guild_id: str, user_ids: List[int],
|
||||
query: str, limit: int):
|
||||
|
|
|
|||
Loading…
Reference in New Issue