gateway.websocket: add handler for heartbeats

this should keep connections more stable since we reply and update
WebsocketState.last_seq
This commit is contained in:
Luna Mendes 2018-09-28 17:50:18 -03:00
parent d9506f450d
commit b06c07c097
3 changed files with 60 additions and 7 deletions

View File

@ -1,6 +1,15 @@
from litecord.errors import WebsocketClose
class GatewayError(WebsocketClose):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# hacky solution to
# decrease code repetition
self.args = [4000, self.args[0]]
class UnknownOPCode(WebsocketClose):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@ -1,7 +1,8 @@
import json
import collections
import asyncio
import pprint
import zlib
import json
from typing import List, Dict, Any
from random import randint
@ -18,11 +19,12 @@ from .state import GatewayState
from ..errors import BadRequest
from ..schemas import validate, GW_STATUS_UPDATE
from ..utils import task_wrapper
log = Logger(__name__)
WebsocketProperties = collections.namedtuple(
'WebsocketProperties', 'v encoding compress zctx'
'WebsocketProperties', 'v encoding compress zctx tasks'
)
WebsocketObjects = collections.namedtuple(
@ -58,7 +60,8 @@ class GatewayWebsocket:
self.wsp = WebsocketProperties(kwargs.get('v'),
kwargs.get('encoding', 'json'),
kwargs.get('compress', None),
zlib.compressobj())
zlib.compressobj(),
{})
self.state = None
@ -95,19 +98,39 @@ class GatewayWebsocket:
# TODO: pure zlib
await self.ws.send(encoded.decode())
async def _hb_wait(self, interval: int):
"""Wait heartbeat"""
await asyncio.sleep(interval / 1000)
await self.ws.close(4000, 'Heartbeat expired')
def _hb_start(self, interval: int):
# always refresh the heartbeat task
# when possible
task = self.wsp.tasks.get('heartbeat')
if task:
task.cancel()
self.wsp.tasks['heartbeat'] = self.ext.loop.create_task(
task_wrapper('hb wait', self._hb_wait(interval))
)
async def send_hello(self):
"""Send the OP 10 Hello packet over the websocket."""
# random heartbeat intervals
interval = randint(40, 46) * 1000
await self.send({
'op': OP.HELLO,
'd': {
# random heartbeat intervals
'heartbeat_interval': randint(40, 46) * 1000,
'heartbeat_interval': interval,
'_trace': [
'lesbian-server'
],
}
})
self._hb_start(interval)
async def dispatch(self, event: str, data: Any):
"""Dispatch an event to the websocket."""
self.state.seq += 1
@ -328,8 +351,14 @@ class GatewayWebsocket:
async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets."""
# TODO: handling heartbeats
pass
# give the client 3 more seconds before we
# close the websocket
self._hb_start((46 + 3) * 1000)
cliseq = payload.get('d')
self.state.last_seq = cliseq
await self.send({
'op': OP.HEARTBEAT_ACK,
})
async def handle_2(self, payload: Dict[str, Any]):
"""Handle the OP 2 Identify packet."""

View File

@ -1,3 +1,9 @@
import asyncio
from logbook import Logger
log = Logger(__name__)
async def async_map(function, iterable) -> list:
"""Map a coroutine to an iterable."""
res = []
@ -7,3 +13,12 @@ async def async_map(function, iterable) -> list:
res.append(result)
return res
async def task_wrapper(name: str, coro):
try:
await coro
except asyncio.CancelledError:
pass
except:
log.exception('{} task error', name)