mirror of https://gitlab.com/litecord/litecord.git
gateway.websocket: add handler for heartbeats
this should keep connections more stable since we reply and update WebsocketState.last_seq
This commit is contained in:
parent
d9506f450d
commit
b06c07c097
|
|
@ -1,6 +1,15 @@
|
||||||
from litecord.errors import WebsocketClose
|
from litecord.errors import WebsocketClose
|
||||||
|
|
||||||
|
|
||||||
|
class GatewayError(WebsocketClose):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# hacky solution to
|
||||||
|
# decrease code repetition
|
||||||
|
self.args = [4000, self.args[0]]
|
||||||
|
|
||||||
|
|
||||||
class UnknownOPCode(WebsocketClose):
|
class UnknownOPCode(WebsocketClose):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import json
|
|
||||||
import collections
|
import collections
|
||||||
|
import asyncio
|
||||||
import pprint
|
import pprint
|
||||||
import zlib
|
import zlib
|
||||||
|
import json
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
from random import randint
|
from random import randint
|
||||||
|
|
||||||
|
|
@ -18,11 +19,12 @@ from .state import GatewayState
|
||||||
from ..errors import BadRequest
|
from ..errors import BadRequest
|
||||||
|
|
||||||
from ..schemas import validate, GW_STATUS_UPDATE
|
from ..schemas import validate, GW_STATUS_UPDATE
|
||||||
|
from ..utils import task_wrapper
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
WebsocketProperties = collections.namedtuple(
|
WebsocketProperties = collections.namedtuple(
|
||||||
'WebsocketProperties', 'v encoding compress zctx'
|
'WebsocketProperties', 'v encoding compress zctx tasks'
|
||||||
)
|
)
|
||||||
|
|
||||||
WebsocketObjects = collections.namedtuple(
|
WebsocketObjects = collections.namedtuple(
|
||||||
|
|
@ -58,7 +60,8 @@ class GatewayWebsocket:
|
||||||
self.wsp = WebsocketProperties(kwargs.get('v'),
|
self.wsp = WebsocketProperties(kwargs.get('v'),
|
||||||
kwargs.get('encoding', 'json'),
|
kwargs.get('encoding', 'json'),
|
||||||
kwargs.get('compress', None),
|
kwargs.get('compress', None),
|
||||||
zlib.compressobj())
|
zlib.compressobj(),
|
||||||
|
{})
|
||||||
|
|
||||||
self.state = None
|
self.state = None
|
||||||
|
|
||||||
|
|
@ -95,19 +98,39 @@ class GatewayWebsocket:
|
||||||
# TODO: pure zlib
|
# TODO: pure zlib
|
||||||
await self.ws.send(encoded.decode())
|
await self.ws.send(encoded.decode())
|
||||||
|
|
||||||
|
async def _hb_wait(self, interval: int):
|
||||||
|
"""Wait heartbeat"""
|
||||||
|
await asyncio.sleep(interval / 1000)
|
||||||
|
await self.ws.close(4000, 'Heartbeat expired')
|
||||||
|
|
||||||
|
def _hb_start(self, interval: int):
|
||||||
|
# always refresh the heartbeat task
|
||||||
|
# when possible
|
||||||
|
task = self.wsp.tasks.get('heartbeat')
|
||||||
|
if task:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
self.wsp.tasks['heartbeat'] = self.ext.loop.create_task(
|
||||||
|
task_wrapper('hb wait', self._hb_wait(interval))
|
||||||
|
)
|
||||||
|
|
||||||
async def send_hello(self):
|
async def send_hello(self):
|
||||||
"""Send the OP 10 Hello packet over the websocket."""
|
"""Send the OP 10 Hello packet over the websocket."""
|
||||||
|
# random heartbeat intervals
|
||||||
|
interval = randint(40, 46) * 1000
|
||||||
|
|
||||||
await self.send({
|
await self.send({
|
||||||
'op': OP.HELLO,
|
'op': OP.HELLO,
|
||||||
'd': {
|
'd': {
|
||||||
# random heartbeat intervals
|
'heartbeat_interval': interval,
|
||||||
'heartbeat_interval': randint(40, 46) * 1000,
|
|
||||||
'_trace': [
|
'_trace': [
|
||||||
'lesbian-server'
|
'lesbian-server'
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
self._hb_start(interval)
|
||||||
|
|
||||||
async def dispatch(self, event: str, data: Any):
|
async def dispatch(self, event: str, data: Any):
|
||||||
"""Dispatch an event to the websocket."""
|
"""Dispatch an event to the websocket."""
|
||||||
self.state.seq += 1
|
self.state.seq += 1
|
||||||
|
|
@ -328,8 +351,14 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
async def handle_1(self, payload: Dict[str, Any]):
|
async def handle_1(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 1 Heartbeat packets."""
|
"""Handle OP 1 Heartbeat packets."""
|
||||||
# TODO: handling heartbeats
|
# give the client 3 more seconds before we
|
||||||
pass
|
# close the websocket
|
||||||
|
self._hb_start((46 + 3) * 1000)
|
||||||
|
cliseq = payload.get('d')
|
||||||
|
self.state.last_seq = cliseq
|
||||||
|
await self.send({
|
||||||
|
'op': OP.HEARTBEAT_ACK,
|
||||||
|
})
|
||||||
|
|
||||||
async def handle_2(self, payload: Dict[str, Any]):
|
async def handle_2(self, payload: Dict[str, Any]):
|
||||||
"""Handle the OP 2 Identify packet."""
|
"""Handle the OP 2 Identify packet."""
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,9 @@
|
||||||
|
import asyncio
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def async_map(function, iterable) -> list:
|
async def async_map(function, iterable) -> list:
|
||||||
"""Map a coroutine to an iterable."""
|
"""Map a coroutine to an iterable."""
|
||||||
res = []
|
res = []
|
||||||
|
|
@ -7,3 +13,12 @@ async def async_map(function, iterable) -> list:
|
||||||
res.append(result)
|
res.append(result)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
async def task_wrapper(name: str, coro):
|
||||||
|
try:
|
||||||
|
await coro
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except:
|
||||||
|
log.exception('{} task error', name)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue