mirror of https://gitlab.com/litecord/litecord.git
gateway.websocket: add basics of resuming
- gateway.state: add PayloadStore - state: add last_seq - gateway.websocket: send string on non zlib-stream - gateway.websocket: add cleanup of state on ws close
This commit is contained in:
parent
2276308c5d
commit
d62db421b0
|
|
@ -7,6 +7,18 @@ def gen_session_id() -> str:
|
|||
return hashlib.sha1(os.urandom(256)).hexdigest()
|
||||
|
||||
|
||||
class PayloadStore:
|
||||
"""Store manager for payloads."""
|
||||
def __init__(self):
|
||||
self.store = {}
|
||||
|
||||
def __getitem__(self, opcode: int):
|
||||
return self.store[opcode]
|
||||
|
||||
def __setitem__(self, opcode: int, payload: dict):
|
||||
self.store[opcode] = payload
|
||||
|
||||
|
||||
class GatewayState:
|
||||
"""Main websocket state.
|
||||
|
||||
|
|
@ -16,9 +28,11 @@ class GatewayState:
|
|||
def __init__(self, **kwargs):
|
||||
self.session_id = kwargs.get('session_id', gen_session_id())
|
||||
self.seq = kwargs.get('seq', 0)
|
||||
self.last_seq = 0
|
||||
self.shard = kwargs.get('shard', [0, 1])
|
||||
self.user_id = kwargs.get('user_id')
|
||||
self.bot = kwargs.get('bot', False)
|
||||
self.store = PayloadStore()
|
||||
|
||||
for key in kwargs:
|
||||
value = kwargs[key]
|
||||
|
|
|
|||
|
|
@ -2,9 +2,10 @@ import json
|
|||
import collections
|
||||
import pprint
|
||||
import zlib
|
||||
from typing import List
|
||||
from typing import List, Dict, Any
|
||||
|
||||
import earl
|
||||
import websockets
|
||||
from logbook import Logger
|
||||
|
||||
from litecord.errors import WebsocketClose, Unauthorized, Forbidden
|
||||
|
|
@ -68,14 +69,19 @@ class GatewayWebsocket:
|
|||
|
||||
self.encoder, self.decoder = encodings[encoding]
|
||||
|
||||
async def send(self, payload: dict):
|
||||
"""Send a payload to the websocket"""
|
||||
async def send(self, payload: Dict[str, Any]):
|
||||
"""Send a payload to the websocket.
|
||||
|
||||
This function accounts for the zlib-stream
|
||||
transport method used by Discord.
|
||||
"""
|
||||
log.debug('Sending {}', pprint.pformat(payload))
|
||||
encoded = self.encoder(payload)
|
||||
|
||||
if not isinstance(encoded, bytes):
|
||||
encoded = encoded.encode()
|
||||
|
||||
print(self.wsp.compress)
|
||||
if self.wsp.compress == 'zlib-stream':
|
||||
data1 = self.wsp.zctx.compress(encoded)
|
||||
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
|
||||
|
|
@ -83,7 +89,7 @@ class GatewayWebsocket:
|
|||
await self.ws.send(data1 + data2)
|
||||
else:
|
||||
# TODO: pure zlib
|
||||
await self.ws.send(encoded)
|
||||
await self.ws.send(encoded.decode())
|
||||
|
||||
async def send_hello(self):
|
||||
"""Send the OP 10 Hello packet over the websocket."""
|
||||
|
|
@ -97,16 +103,19 @@ class GatewayWebsocket:
|
|||
}
|
||||
})
|
||||
|
||||
async def dispatch(self, event, data):
|
||||
async def dispatch(self, event: str, data: Any):
|
||||
"""Dispatch an event to the websocket."""
|
||||
self.state.seq += 1
|
||||
|
||||
await self.send({
|
||||
payload = {
|
||||
'op': OP.DISPATCH,
|
||||
't': event.upper(),
|
||||
's': self.state.seq,
|
||||
'd': data,
|
||||
})
|
||||
}
|
||||
|
||||
self.state.store[self.state.seq] = payload
|
||||
await self.send(payload)
|
||||
|
||||
async def _make_guild_list(self) -> List[int]:
|
||||
# TODO: This function does not account for sharding.
|
||||
|
|
@ -129,7 +138,7 @@ class GatewayWebsocket:
|
|||
for row in guild_ids
|
||||
]
|
||||
|
||||
async def guild_dispatch(self, unavailable_guilds: List[dict]):
|
||||
async def guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]):
|
||||
"""Dispatch GUILD_CREATE information."""
|
||||
|
||||
# Users don't get asynchronous guild dispatching.
|
||||
|
|
@ -237,10 +246,11 @@ class GatewayWebsocket:
|
|||
if current_shard > shard_count:
|
||||
raise InvalidShard('Shard count > Total shards')
|
||||
|
||||
async def handle_1(self, payload: dict):
|
||||
async def handle_1(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 1 Heartbeat packets."""
|
||||
pass
|
||||
|
||||
async def handle_2(self, payload: dict):
|
||||
async def handle_2(self, payload: Dict[str, Any]):
|
||||
"""Handle the OP 2 Identify packet."""
|
||||
data = payload['d']
|
||||
try:
|
||||
|
|
@ -282,6 +292,60 @@ class GatewayWebsocket:
|
|||
self.ext.state_manager.insert(self.state)
|
||||
await self.dispatch_ready()
|
||||
|
||||
async def handle_3(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 3 Status Update."""
|
||||
pass
|
||||
|
||||
async def handle_6(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 6 Resume."""
|
||||
data = payload['d']
|
||||
|
||||
try:
|
||||
token, sess_id, seq = data['token'], \
|
||||
data['session_id'], data['seq']
|
||||
except KeyError:
|
||||
raise DecodeError('Invalid resume payload')
|
||||
|
||||
try:
|
||||
user_id = await raw_token_check(token, self.ext.db)
|
||||
except (Unauthorized, Forbidden):
|
||||
raise WebsocketClose(4004, 'Invalid token')
|
||||
|
||||
try:
|
||||
state = self.ext.state_manager.fetch(user_id, sess_id)
|
||||
except KeyError:
|
||||
return await self.send({
|
||||
'op': 9,
|
||||
'd': False,
|
||||
})
|
||||
|
||||
if seq > state.seq:
|
||||
raise WebsocketClose(4007, 'Invalid seq')
|
||||
|
||||
# check if a websocket isnt on that state already
|
||||
if state.ws is not None:
|
||||
log.info('Resuming failed, websocket already connected')
|
||||
return await self.send({
|
||||
'op': 9,
|
||||
'd': False,
|
||||
})
|
||||
|
||||
# relink this connection
|
||||
self.state = state
|
||||
state.ws = self
|
||||
|
||||
# TODO: resend payloads
|
||||
|
||||
await self.dispatch('RESUMED', {})
|
||||
|
||||
async def handle_12(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 12 Guild Sync."""
|
||||
data = payload['d']
|
||||
|
||||
for _guild_id in data:
|
||||
# check if user in guild
|
||||
pass
|
||||
|
||||
async def process_message(self, payload):
|
||||
"""Process a single message coming in from the client."""
|
||||
try:
|
||||
|
|
@ -292,7 +356,7 @@ class GatewayWebsocket:
|
|||
try:
|
||||
handler = getattr(self, f'handle_{op_code}')
|
||||
except AttributeError:
|
||||
raise UnknownOPCode('Bad OP code')
|
||||
raise UnknownOPCode(f'Bad OP code: {op_code}')
|
||||
|
||||
await handler(payload)
|
||||
|
||||
|
|
@ -316,7 +380,12 @@ class GatewayWebsocket:
|
|||
try:
|
||||
await self.send_hello()
|
||||
await self.listen_messages()
|
||||
except websockets.exceptions.ConnectionClosed as err:
|
||||
log.warning('Client closed, state={}, err={}', self.state, err)
|
||||
except WebsocketClose as err:
|
||||
log.warning('closed a client, state={} err={}', self.state, err)
|
||||
|
||||
await self.ws.close(code=err.code, reason=err.reason)
|
||||
finally:
|
||||
if self.state:
|
||||
self.state.ws = None
|
||||
|
|
|
|||
Loading…
Reference in New Issue