gateway.websocket: add basics of resuming

- gateway.state: add PayloadStore
 - state: add last_seq
 - gateway.websocket: send string on non zlib-stream
 - gateway.websocket: add cleanup of state on ws close
This commit is contained in:
Luna Mendes 2018-06-24 18:28:28 -03:00
parent 2276308c5d
commit d62db421b0
2 changed files with 94 additions and 11 deletions

View File

@ -7,6 +7,18 @@ def gen_session_id() -> str:
return hashlib.sha1(os.urandom(256)).hexdigest() return hashlib.sha1(os.urandom(256)).hexdigest()
class PayloadStore:
"""Store manager for payloads."""
def __init__(self):
self.store = {}
def __getitem__(self, opcode: int):
return self.store[opcode]
def __setitem__(self, opcode: int, payload: dict):
self.store[opcode] = payload
class GatewayState: class GatewayState:
"""Main websocket state. """Main websocket state.
@ -16,9 +28,11 @@ class GatewayState:
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.session_id = kwargs.get('session_id', gen_session_id()) self.session_id = kwargs.get('session_id', gen_session_id())
self.seq = kwargs.get('seq', 0) self.seq = kwargs.get('seq', 0)
self.last_seq = 0
self.shard = kwargs.get('shard', [0, 1]) self.shard = kwargs.get('shard', [0, 1])
self.user_id = kwargs.get('user_id') self.user_id = kwargs.get('user_id')
self.bot = kwargs.get('bot', False) self.bot = kwargs.get('bot', False)
self.store = PayloadStore()
for key in kwargs: for key in kwargs:
value = kwargs[key] value = kwargs[key]

View File

@ -2,9 +2,10 @@ import json
import collections import collections
import pprint import pprint
import zlib import zlib
from typing import List from typing import List, Dict, Any
import earl import earl
import websockets
from logbook import Logger from logbook import Logger
from litecord.errors import WebsocketClose, Unauthorized, Forbidden from litecord.errors import WebsocketClose, Unauthorized, Forbidden
@ -68,14 +69,19 @@ class GatewayWebsocket:
self.encoder, self.decoder = encodings[encoding] self.encoder, self.decoder = encodings[encoding]
async def send(self, payload: dict): async def send(self, payload: Dict[str, Any]):
"""Send a payload to the websocket""" """Send a payload to the websocket.
This function accounts for the zlib-stream
transport method used by Discord.
"""
log.debug('Sending {}', pprint.pformat(payload)) log.debug('Sending {}', pprint.pformat(payload))
encoded = self.encoder(payload) encoded = self.encoder(payload)
if not isinstance(encoded, bytes): if not isinstance(encoded, bytes):
encoded = encoded.encode() encoded = encoded.encode()
print(self.wsp.compress)
if self.wsp.compress == 'zlib-stream': if self.wsp.compress == 'zlib-stream':
data1 = self.wsp.zctx.compress(encoded) data1 = self.wsp.zctx.compress(encoded)
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH) data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
@ -83,7 +89,7 @@ class GatewayWebsocket:
await self.ws.send(data1 + data2) await self.ws.send(data1 + data2)
else: else:
# TODO: pure zlib # TODO: pure zlib
await self.ws.send(encoded) await self.ws.send(encoded.decode())
async def send_hello(self): async def send_hello(self):
"""Send the OP 10 Hello packet over the websocket.""" """Send the OP 10 Hello packet over the websocket."""
@ -97,16 +103,19 @@ class GatewayWebsocket:
} }
}) })
async def dispatch(self, event, data): async def dispatch(self, event: str, data: Any):
"""Dispatch an event to the websocket.""" """Dispatch an event to the websocket."""
self.state.seq += 1 self.state.seq += 1
await self.send({ payload = {
'op': OP.DISPATCH, 'op': OP.DISPATCH,
't': event.upper(), 't': event.upper(),
's': self.state.seq, 's': self.state.seq,
'd': data, 'd': data,
}) }
self.state.store[self.state.seq] = payload
await self.send(payload)
async def _make_guild_list(self) -> List[int]: async def _make_guild_list(self) -> List[int]:
# TODO: This function does not account for sharding. # TODO: This function does not account for sharding.
@ -129,7 +138,7 @@ class GatewayWebsocket:
for row in guild_ids for row in guild_ids
] ]
async def guild_dispatch(self, unavailable_guilds: List[dict]): async def guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]):
"""Dispatch GUILD_CREATE information.""" """Dispatch GUILD_CREATE information."""
# Users don't get asynchronous guild dispatching. # Users don't get asynchronous guild dispatching.
@ -237,10 +246,11 @@ class GatewayWebsocket:
if current_shard > shard_count: if current_shard > shard_count:
raise InvalidShard('Shard count > Total shards') raise InvalidShard('Shard count > Total shards')
async def handle_1(self, payload: dict): async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets."""
pass pass
async def handle_2(self, payload: dict): async def handle_2(self, payload: Dict[str, Any]):
"""Handle the OP 2 Identify packet.""" """Handle the OP 2 Identify packet."""
data = payload['d'] data = payload['d']
try: try:
@ -282,6 +292,60 @@ class GatewayWebsocket:
self.ext.state_manager.insert(self.state) self.ext.state_manager.insert(self.state)
await self.dispatch_ready() await self.dispatch_ready()
async def handle_3(self, payload: Dict[str, Any]):
"""Handle OP 3 Status Update."""
pass
async def handle_6(self, payload: Dict[str, Any]):
"""Handle OP 6 Resume."""
data = payload['d']
try:
token, sess_id, seq = data['token'], \
data['session_id'], data['seq']
except KeyError:
raise DecodeError('Invalid resume payload')
try:
user_id = await raw_token_check(token, self.ext.db)
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, 'Invalid token')
try:
state = self.ext.state_manager.fetch(user_id, sess_id)
except KeyError:
return await self.send({
'op': 9,
'd': False,
})
if seq > state.seq:
raise WebsocketClose(4007, 'Invalid seq')
# check if a websocket isnt on that state already
if state.ws is not None:
log.info('Resuming failed, websocket already connected')
return await self.send({
'op': 9,
'd': False,
})
# relink this connection
self.state = state
state.ws = self
# TODO: resend payloads
await self.dispatch('RESUMED', {})
async def handle_12(self, payload: Dict[str, Any]):
"""Handle OP 12 Guild Sync."""
data = payload['d']
for _guild_id in data:
# check if user in guild
pass
async def process_message(self, payload): async def process_message(self, payload):
"""Process a single message coming in from the client.""" """Process a single message coming in from the client."""
try: try:
@ -292,7 +356,7 @@ class GatewayWebsocket:
try: try:
handler = getattr(self, f'handle_{op_code}') handler = getattr(self, f'handle_{op_code}')
except AttributeError: except AttributeError:
raise UnknownOPCode('Bad OP code') raise UnknownOPCode(f'Bad OP code: {op_code}')
await handler(payload) await handler(payload)
@ -316,7 +380,12 @@ class GatewayWebsocket:
try: try:
await self.send_hello() await self.send_hello()
await self.listen_messages() await self.listen_messages()
except websockets.exceptions.ConnectionClosed as err:
log.warning('Client closed, state={}, err={}', self.state, err)
except WebsocketClose as err: except WebsocketClose as err:
log.warning('closed a client, state={} err={}', self.state, err) log.warning('closed a client, state={} err={}', self.state, err)
await self.ws.close(code=err.code, reason=err.reason) await self.ws.close(code=err.code, reason=err.reason)
finally:
if self.state:
self.state.ws = None