mirror of https://gitlab.com/litecord/litecord.git
gateway.websocket: add basics of resuming
- gateway.state: add PayloadStore - state: add last_seq - gateway.websocket: send string on non zlib-stream - gateway.websocket: add cleanup of state on ws close
This commit is contained in:
parent
2276308c5d
commit
d62db421b0
|
|
@ -7,6 +7,18 @@ def gen_session_id() -> str:
|
||||||
return hashlib.sha1(os.urandom(256)).hexdigest()
|
return hashlib.sha1(os.urandom(256)).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class PayloadStore:
|
||||||
|
"""Store manager for payloads."""
|
||||||
|
def __init__(self):
|
||||||
|
self.store = {}
|
||||||
|
|
||||||
|
def __getitem__(self, opcode: int):
|
||||||
|
return self.store[opcode]
|
||||||
|
|
||||||
|
def __setitem__(self, opcode: int, payload: dict):
|
||||||
|
self.store[opcode] = payload
|
||||||
|
|
||||||
|
|
||||||
class GatewayState:
|
class GatewayState:
|
||||||
"""Main websocket state.
|
"""Main websocket state.
|
||||||
|
|
||||||
|
|
@ -16,9 +28,11 @@ class GatewayState:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.session_id = kwargs.get('session_id', gen_session_id())
|
self.session_id = kwargs.get('session_id', gen_session_id())
|
||||||
self.seq = kwargs.get('seq', 0)
|
self.seq = kwargs.get('seq', 0)
|
||||||
|
self.last_seq = 0
|
||||||
self.shard = kwargs.get('shard', [0, 1])
|
self.shard = kwargs.get('shard', [0, 1])
|
||||||
self.user_id = kwargs.get('user_id')
|
self.user_id = kwargs.get('user_id')
|
||||||
self.bot = kwargs.get('bot', False)
|
self.bot = kwargs.get('bot', False)
|
||||||
|
self.store = PayloadStore()
|
||||||
|
|
||||||
for key in kwargs:
|
for key in kwargs:
|
||||||
value = kwargs[key]
|
value = kwargs[key]
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,10 @@ import json
|
||||||
import collections
|
import collections
|
||||||
import pprint
|
import pprint
|
||||||
import zlib
|
import zlib
|
||||||
from typing import List
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
import earl
|
import earl
|
||||||
|
import websockets
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from litecord.errors import WebsocketClose, Unauthorized, Forbidden
|
from litecord.errors import WebsocketClose, Unauthorized, Forbidden
|
||||||
|
|
@ -68,14 +69,19 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
self.encoder, self.decoder = encodings[encoding]
|
self.encoder, self.decoder = encodings[encoding]
|
||||||
|
|
||||||
async def send(self, payload: dict):
|
async def send(self, payload: Dict[str, Any]):
|
||||||
"""Send a payload to the websocket"""
|
"""Send a payload to the websocket.
|
||||||
|
|
||||||
|
This function accounts for the zlib-stream
|
||||||
|
transport method used by Discord.
|
||||||
|
"""
|
||||||
log.debug('Sending {}', pprint.pformat(payload))
|
log.debug('Sending {}', pprint.pformat(payload))
|
||||||
encoded = self.encoder(payload)
|
encoded = self.encoder(payload)
|
||||||
|
|
||||||
if not isinstance(encoded, bytes):
|
if not isinstance(encoded, bytes):
|
||||||
encoded = encoded.encode()
|
encoded = encoded.encode()
|
||||||
|
|
||||||
|
print(self.wsp.compress)
|
||||||
if self.wsp.compress == 'zlib-stream':
|
if self.wsp.compress == 'zlib-stream':
|
||||||
data1 = self.wsp.zctx.compress(encoded)
|
data1 = self.wsp.zctx.compress(encoded)
|
||||||
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
|
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
|
||||||
|
|
@ -83,7 +89,7 @@ class GatewayWebsocket:
|
||||||
await self.ws.send(data1 + data2)
|
await self.ws.send(data1 + data2)
|
||||||
else:
|
else:
|
||||||
# TODO: pure zlib
|
# TODO: pure zlib
|
||||||
await self.ws.send(encoded)
|
await self.ws.send(encoded.decode())
|
||||||
|
|
||||||
async def send_hello(self):
|
async def send_hello(self):
|
||||||
"""Send the OP 10 Hello packet over the websocket."""
|
"""Send the OP 10 Hello packet over the websocket."""
|
||||||
|
|
@ -97,16 +103,19 @@ class GatewayWebsocket:
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
async def dispatch(self, event, data):
|
async def dispatch(self, event: str, data: Any):
|
||||||
"""Dispatch an event to the websocket."""
|
"""Dispatch an event to the websocket."""
|
||||||
self.state.seq += 1
|
self.state.seq += 1
|
||||||
|
|
||||||
await self.send({
|
payload = {
|
||||||
'op': OP.DISPATCH,
|
'op': OP.DISPATCH,
|
||||||
't': event.upper(),
|
't': event.upper(),
|
||||||
's': self.state.seq,
|
's': self.state.seq,
|
||||||
'd': data,
|
'd': data,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
self.state.store[self.state.seq] = payload
|
||||||
|
await self.send(payload)
|
||||||
|
|
||||||
async def _make_guild_list(self) -> List[int]:
|
async def _make_guild_list(self) -> List[int]:
|
||||||
# TODO: This function does not account for sharding.
|
# TODO: This function does not account for sharding.
|
||||||
|
|
@ -129,7 +138,7 @@ class GatewayWebsocket:
|
||||||
for row in guild_ids
|
for row in guild_ids
|
||||||
]
|
]
|
||||||
|
|
||||||
async def guild_dispatch(self, unavailable_guilds: List[dict]):
|
async def guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]):
|
||||||
"""Dispatch GUILD_CREATE information."""
|
"""Dispatch GUILD_CREATE information."""
|
||||||
|
|
||||||
# Users don't get asynchronous guild dispatching.
|
# Users don't get asynchronous guild dispatching.
|
||||||
|
|
@ -237,10 +246,11 @@ class GatewayWebsocket:
|
||||||
if current_shard > shard_count:
|
if current_shard > shard_count:
|
||||||
raise InvalidShard('Shard count > Total shards')
|
raise InvalidShard('Shard count > Total shards')
|
||||||
|
|
||||||
async def handle_1(self, payload: dict):
|
async def handle_1(self, payload: Dict[str, Any]):
|
||||||
|
"""Handle OP 1 Heartbeat packets."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def handle_2(self, payload: dict):
|
async def handle_2(self, payload: Dict[str, Any]):
|
||||||
"""Handle the OP 2 Identify packet."""
|
"""Handle the OP 2 Identify packet."""
|
||||||
data = payload['d']
|
data = payload['d']
|
||||||
try:
|
try:
|
||||||
|
|
@ -282,6 +292,60 @@ class GatewayWebsocket:
|
||||||
self.ext.state_manager.insert(self.state)
|
self.ext.state_manager.insert(self.state)
|
||||||
await self.dispatch_ready()
|
await self.dispatch_ready()
|
||||||
|
|
||||||
|
async def handle_3(self, payload: Dict[str, Any]):
|
||||||
|
"""Handle OP 3 Status Update."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def handle_6(self, payload: Dict[str, Any]):
|
||||||
|
"""Handle OP 6 Resume."""
|
||||||
|
data = payload['d']
|
||||||
|
|
||||||
|
try:
|
||||||
|
token, sess_id, seq = data['token'], \
|
||||||
|
data['session_id'], data['seq']
|
||||||
|
except KeyError:
|
||||||
|
raise DecodeError('Invalid resume payload')
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_id = await raw_token_check(token, self.ext.db)
|
||||||
|
except (Unauthorized, Forbidden):
|
||||||
|
raise WebsocketClose(4004, 'Invalid token')
|
||||||
|
|
||||||
|
try:
|
||||||
|
state = self.ext.state_manager.fetch(user_id, sess_id)
|
||||||
|
except KeyError:
|
||||||
|
return await self.send({
|
||||||
|
'op': 9,
|
||||||
|
'd': False,
|
||||||
|
})
|
||||||
|
|
||||||
|
if seq > state.seq:
|
||||||
|
raise WebsocketClose(4007, 'Invalid seq')
|
||||||
|
|
||||||
|
# check if a websocket isnt on that state already
|
||||||
|
if state.ws is not None:
|
||||||
|
log.info('Resuming failed, websocket already connected')
|
||||||
|
return await self.send({
|
||||||
|
'op': 9,
|
||||||
|
'd': False,
|
||||||
|
})
|
||||||
|
|
||||||
|
# relink this connection
|
||||||
|
self.state = state
|
||||||
|
state.ws = self
|
||||||
|
|
||||||
|
# TODO: resend payloads
|
||||||
|
|
||||||
|
await self.dispatch('RESUMED', {})
|
||||||
|
|
||||||
|
async def handle_12(self, payload: Dict[str, Any]):
|
||||||
|
"""Handle OP 12 Guild Sync."""
|
||||||
|
data = payload['d']
|
||||||
|
|
||||||
|
for _guild_id in data:
|
||||||
|
# check if user in guild
|
||||||
|
pass
|
||||||
|
|
||||||
async def process_message(self, payload):
|
async def process_message(self, payload):
|
||||||
"""Process a single message coming in from the client."""
|
"""Process a single message coming in from the client."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -292,7 +356,7 @@ class GatewayWebsocket:
|
||||||
try:
|
try:
|
||||||
handler = getattr(self, f'handle_{op_code}')
|
handler = getattr(self, f'handle_{op_code}')
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise UnknownOPCode('Bad OP code')
|
raise UnknownOPCode(f'Bad OP code: {op_code}')
|
||||||
|
|
||||||
await handler(payload)
|
await handler(payload)
|
||||||
|
|
||||||
|
|
@ -316,7 +380,12 @@ class GatewayWebsocket:
|
||||||
try:
|
try:
|
||||||
await self.send_hello()
|
await self.send_hello()
|
||||||
await self.listen_messages()
|
await self.listen_messages()
|
||||||
|
except websockets.exceptions.ConnectionClosed as err:
|
||||||
|
log.warning('Client closed, state={}, err={}', self.state, err)
|
||||||
except WebsocketClose as err:
|
except WebsocketClose as err:
|
||||||
log.warning('closed a client, state={} err={}', self.state, err)
|
log.warning('closed a client, state={} err={}', self.state, err)
|
||||||
|
|
||||||
await self.ws.close(code=err.code, reason=err.reason)
|
await self.ws.close(code=err.code, reason=err.reason)
|
||||||
|
finally:
|
||||||
|
if self.state:
|
||||||
|
self.state.ws = None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue