mirror of https://gitlab.com/litecord/litecord.git
gateway.websocket: add handler for heartbeats
this should keep connections more stable since we reply and update WebsocketState.last_seq
This commit is contained in:
parent
d9506f450d
commit
b06c07c097
|
|
@ -1,6 +1,15 @@
|
|||
from litecord.errors import WebsocketClose
|
||||
|
||||
|
||||
class GatewayError(WebsocketClose):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# hacky solution to
|
||||
# decrease code repetition
|
||||
self.args = [4000, self.args[0]]
|
||||
|
||||
|
||||
class UnknownOPCode(WebsocketClose):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import json
|
||||
import collections
|
||||
import asyncio
|
||||
import pprint
|
||||
import zlib
|
||||
import json
|
||||
from typing import List, Dict, Any
|
||||
from random import randint
|
||||
|
||||
|
|
@ -18,11 +19,12 @@ from .state import GatewayState
|
|||
from ..errors import BadRequest
|
||||
|
||||
from ..schemas import validate, GW_STATUS_UPDATE
|
||||
from ..utils import task_wrapper
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
WebsocketProperties = collections.namedtuple(
|
||||
'WebsocketProperties', 'v encoding compress zctx'
|
||||
'WebsocketProperties', 'v encoding compress zctx tasks'
|
||||
)
|
||||
|
||||
WebsocketObjects = collections.namedtuple(
|
||||
|
|
@ -58,7 +60,8 @@ class GatewayWebsocket:
|
|||
self.wsp = WebsocketProperties(kwargs.get('v'),
|
||||
kwargs.get('encoding', 'json'),
|
||||
kwargs.get('compress', None),
|
||||
zlib.compressobj())
|
||||
zlib.compressobj(),
|
||||
{})
|
||||
|
||||
self.state = None
|
||||
|
||||
|
|
@ -95,19 +98,39 @@ class GatewayWebsocket:
|
|||
# TODO: pure zlib
|
||||
await self.ws.send(encoded.decode())
|
||||
|
||||
async def _hb_wait(self, interval: int):
|
||||
"""Wait heartbeat"""
|
||||
await asyncio.sleep(interval / 1000)
|
||||
await self.ws.close(4000, 'Heartbeat expired')
|
||||
|
||||
def _hb_start(self, interval: int):
|
||||
# always refresh the heartbeat task
|
||||
# when possible
|
||||
task = self.wsp.tasks.get('heartbeat')
|
||||
if task:
|
||||
task.cancel()
|
||||
|
||||
self.wsp.tasks['heartbeat'] = self.ext.loop.create_task(
|
||||
task_wrapper('hb wait', self._hb_wait(interval))
|
||||
)
|
||||
|
||||
async def send_hello(self):
|
||||
"""Send the OP 10 Hello packet over the websocket."""
|
||||
# random heartbeat intervals
|
||||
interval = randint(40, 46) * 1000
|
||||
|
||||
await self.send({
|
||||
'op': OP.HELLO,
|
||||
'd': {
|
||||
# random heartbeat intervals
|
||||
'heartbeat_interval': randint(40, 46) * 1000,
|
||||
'heartbeat_interval': interval,
|
||||
'_trace': [
|
||||
'lesbian-server'
|
||||
],
|
||||
}
|
||||
})
|
||||
|
||||
self._hb_start(interval)
|
||||
|
||||
async def dispatch(self, event: str, data: Any):
|
||||
"""Dispatch an event to the websocket."""
|
||||
self.state.seq += 1
|
||||
|
|
@ -328,8 +351,14 @@ class GatewayWebsocket:
|
|||
|
||||
async def handle_1(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 1 Heartbeat packets."""
|
||||
# TODO: handling heartbeats
|
||||
pass
|
||||
# give the client 3 more seconds before we
|
||||
# close the websocket
|
||||
self._hb_start((46 + 3) * 1000)
|
||||
cliseq = payload.get('d')
|
||||
self.state.last_seq = cliseq
|
||||
await self.send({
|
||||
'op': OP.HEARTBEAT_ACK,
|
||||
})
|
||||
|
||||
async def handle_2(self, payload: Dict[str, Any]):
|
||||
"""Handle the OP 2 Identify packet."""
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
import asyncio
|
||||
from logbook import Logger
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
async def async_map(function, iterable) -> list:
|
||||
"""Map a coroutine to an iterable."""
|
||||
res = []
|
||||
|
|
@ -7,3 +13,12 @@ async def async_map(function, iterable) -> list:
|
|||
res.append(result)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
async def task_wrapper(name: str, coro):
|
||||
try:
|
||||
await coro
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except:
|
||||
log.exception('{} task error', name)
|
||||
|
|
|
|||
Loading…
Reference in New Issue