gateway: more comments and docstrings

This commit is contained in:
Luna Mendes 2018-10-13 17:30:02 -03:00
parent 6be85ea305
commit 1d3520876d
3 changed files with 49 additions and 20 deletions

View File

@ -8,7 +8,11 @@ def gen_session_id() -> str:
class PayloadStore:
"""Store manager for payloads."""
"""Store manager for payloads.
This will only store a maximum of MAX_STORE_SIZE,
dropping the older payloads when adding new ones.
"""
MAX_STORE_SIZE = 250
def __init__(self):
@ -30,14 +34,6 @@ class PayloadStore:
self.store[opcode] = payload
class Presence:
def __init__(self, raw: dict):
self.afk = raw.get('afk', False)
self.status = raw.get('status', 'online')
self.game = raw.get('game', None)
self.since = raw.get('since', 0)
class GatewayState:
"""Main websocket state.
@ -46,13 +42,28 @@ class GatewayState:
def __init__(self, **kwargs):
self.session_id = kwargs.get('session_id', gen_session_id())
#: event sequence number
self.seq = kwargs.get('seq', 0)
#: last seq sent by us, the backend
self.last_seq = 0
#: shard information about the state,
# its id and shard count
self.shard = kwargs.get('shard', [0, 1])
self.user_id = kwargs.get('user_id')
self.bot = kwargs.get('bot', False)
#: set by the gateway connection
# on OP STATUS_UPDATE
self.presence = {}
#: set by the backend once identify happens
self.ws = None
#: store (kind of) all payloads sent by us
self.store = PayloadStore()
for key in kwargs:

View File

@ -13,6 +13,13 @@ class StateManager:
"""Manager for gateway state information."""
def __init__(self):
# {
# user_id: {
# session_id: GatewayState,
# session_id_2: GatewayState, ...
# },
# user_id_2: {}, ...
# }
self.states = defaultdict(dict)
def insert(self, state: GatewayState):
@ -23,7 +30,14 @@ class StateManager:
user_states[state.session_id] = state
def fetch(self, user_id: int, session_id: str) -> GatewayState:
"""Fetch a state object from the registry."""
"""Fetch a state object from the manager.
Raises
------
KeyError
When the user_id or session_id
aren't found in the store.
"""
return self.states[user_id][session_id]
def remove(self, state):

View File

@ -256,15 +256,14 @@ class GatewayWebsocket:
# async dispatch of guilds
self.ext.loop.create_task(self.guild_dispatch(guilds))
async def _check_shards(self):
shard = self.state.shard
async def _check_shards(self, shard, user_id):
current_shard, shard_count = shard
guilds = await self.ext.db.fetchval("""
SELECT COUNT(*)
FROM members
WHERE user_id = $1
""", self.state.user_id)
""", user_id)
recommended = max(int(guilds / 1200), 1)
@ -390,6 +389,9 @@ class GatewayWebsocket:
WHERE id = $1
""", user_id)
await self._check_shards(shard, user_id)
# only create a state after checking everything
self.state = GatewayState(
user_id=user_id,
bot=bot,
@ -401,9 +403,9 @@ class GatewayWebsocket:
ws=self
)
await self._check_shards()
# link the state to the user
self.ext.state_manager.insert(self.state)
await self.update_status(presence)
await self.subscribe_all()
await self.dispatch_ready()
@ -419,6 +421,7 @@ class GatewayWebsocket:
async def handle_4(self, payload: Dict[str, Any]):
"""Handle OP 4 Voice Status Update."""
data = payload['d']
# for now, ignore
log.debug('got VSU cid={} gid={} deaf={} mute={} video={}',
data.get('channel_id'),
data.get('guild_id'),
@ -426,9 +429,6 @@ class GatewayWebsocket:
data.get('self_mute'),
data.get('self_video'))
# for now, do nothing
pass
async def _handle_5(self, payload: Dict[str, Any]):
"""Handle OP 5 Voice Server Ping.
@ -452,6 +452,9 @@ class GatewayWebsocket:
})
if not resumable and self.state:
# since the state will be removed from
# the manager, it will become unreachable
# when trying to resume.
self.ext.state_manager.remove(self.state)
async def _resume(self, replay_seqs: iter):
@ -476,12 +479,14 @@ class GatewayWebsocket:
await self.send(payload)
except Exception:
log.exception('error while resuming')
await self.invalidate_session()
await self.invalidate_session(False)
return
if presences:
await self.dispatch('PRESENCE_REPLACE', presences)
await self.dispatch('RESUMED', {})
async def handle_6(self, payload: Dict[str, Any]):
"""Handle OP 6 Resume."""
data = payload['d']
@ -515,7 +520,6 @@ class GatewayWebsocket:
state.ws = self
await self._resume(range(seq, state.seq))
await self.dispatch('RESUMED', {})
async def _req_guild_members(self, guild_id: str, user_ids: List[int],
query: str, limit: int):