gateway.websocket: add handler for heartbeats

this should keep connections more stable since we reply and update
WebsocketState.last_seq
This commit is contained in:
Luna Mendes 2018-09-28 17:50:18 -03:00
parent d9506f450d
commit b06c07c097
3 changed files with 60 additions and 7 deletions

View File

@ -1,6 +1,15 @@
from litecord.errors import WebsocketClose from litecord.errors import WebsocketClose
class GatewayError(WebsocketClose):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# hacky solution to
# decrease code repetition
self.args = [4000, self.args[0]]
class UnknownOPCode(WebsocketClose): class UnknownOPCode(WebsocketClose):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View File

@ -1,7 +1,8 @@
import json
import collections import collections
import asyncio
import pprint import pprint
import zlib import zlib
import json
from typing import List, Dict, Any from typing import List, Dict, Any
from random import randint from random import randint
@ -18,11 +19,12 @@ from .state import GatewayState
from ..errors import BadRequest from ..errors import BadRequest
from ..schemas import validate, GW_STATUS_UPDATE from ..schemas import validate, GW_STATUS_UPDATE
from ..utils import task_wrapper
log = Logger(__name__) log = Logger(__name__)
WebsocketProperties = collections.namedtuple( WebsocketProperties = collections.namedtuple(
'WebsocketProperties', 'v encoding compress zctx' 'WebsocketProperties', 'v encoding compress zctx tasks'
) )
WebsocketObjects = collections.namedtuple( WebsocketObjects = collections.namedtuple(
@ -58,7 +60,8 @@ class GatewayWebsocket:
self.wsp = WebsocketProperties(kwargs.get('v'), self.wsp = WebsocketProperties(kwargs.get('v'),
kwargs.get('encoding', 'json'), kwargs.get('encoding', 'json'),
kwargs.get('compress', None), kwargs.get('compress', None),
zlib.compressobj()) zlib.compressobj(),
{})
self.state = None self.state = None
@ -95,19 +98,39 @@ class GatewayWebsocket:
# TODO: pure zlib # TODO: pure zlib
await self.ws.send(encoded.decode()) await self.ws.send(encoded.decode())
async def _hb_wait(self, interval: int):
"""Wait heartbeat"""
await asyncio.sleep(interval / 1000)
await self.ws.close(4000, 'Heartbeat expired')
def _hb_start(self, interval: int):
# always refresh the heartbeat task
# when possible
task = self.wsp.tasks.get('heartbeat')
if task:
task.cancel()
self.wsp.tasks['heartbeat'] = self.ext.loop.create_task(
task_wrapper('hb wait', self._hb_wait(interval))
)
async def send_hello(self): async def send_hello(self):
"""Send the OP 10 Hello packet over the websocket.""" """Send the OP 10 Hello packet over the websocket."""
# random heartbeat intervals
interval = randint(40, 46) * 1000
await self.send({ await self.send({
'op': OP.HELLO, 'op': OP.HELLO,
'd': { 'd': {
# random heartbeat intervals 'heartbeat_interval': interval,
'heartbeat_interval': randint(40, 46) * 1000,
'_trace': [ '_trace': [
'lesbian-server' 'lesbian-server'
], ],
} }
}) })
self._hb_start(interval)
async def dispatch(self, event: str, data: Any): async def dispatch(self, event: str, data: Any):
"""Dispatch an event to the websocket.""" """Dispatch an event to the websocket."""
self.state.seq += 1 self.state.seq += 1
@ -328,8 +351,14 @@ class GatewayWebsocket:
async def handle_1(self, payload: Dict[str, Any]): async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets.""" """Handle OP 1 Heartbeat packets."""
# TODO: handling heartbeats # give the client 3 more seconds before we
pass # close the websocket
self._hb_start((46 + 3) * 1000)
cliseq = payload.get('d')
self.state.last_seq = cliseq
await self.send({
'op': OP.HEARTBEAT_ACK,
})
async def handle_2(self, payload: Dict[str, Any]): async def handle_2(self, payload: Dict[str, Any]):
"""Handle the OP 2 Identify packet.""" """Handle the OP 2 Identify packet."""

View File

@ -1,3 +1,9 @@
import asyncio
from logbook import Logger
log = Logger(__name__)
async def async_map(function, iterable) -> list: async def async_map(function, iterable) -> list:
"""Map a coroutine to an iterable.""" """Map a coroutine to an iterable."""
res = [] res = []
@ -7,3 +13,12 @@ async def async_map(function, iterable) -> list:
res.append(result) res.append(result)
return res return res
async def task_wrapper(name: str, coro):
try:
await coro
except asyncio.CancelledError:
pass
except:
log.exception('{} task error', name)