diff --git a/litecord/gateway/errors.py b/litecord/gateway/errors.py index e217d27..583e397 100644 --- a/litecord/gateway/errors.py +++ b/litecord/gateway/errors.py @@ -1,6 +1,15 @@ from litecord.errors import WebsocketClose +class GatewayError(WebsocketClose): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # hacky solution to + # decrease code repetition + self.args = [4000, self.args[0]] + + class UnknownOPCode(WebsocketClose): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index aa745c6..d1fe574 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -1,7 +1,8 @@ -import json import collections +import asyncio import pprint import zlib +import json from typing import List, Dict, Any from random import randint @@ -18,11 +19,12 @@ from .state import GatewayState from ..errors import BadRequest from ..schemas import validate, GW_STATUS_UPDATE +from ..utils import task_wrapper log = Logger(__name__) WebsocketProperties = collections.namedtuple( - 'WebsocketProperties', 'v encoding compress zctx' + 'WebsocketProperties', 'v encoding compress zctx tasks' ) WebsocketObjects = collections.namedtuple( @@ -58,7 +60,8 @@ class GatewayWebsocket: self.wsp = WebsocketProperties(kwargs.get('v'), kwargs.get('encoding', 'json'), kwargs.get('compress', None), - zlib.compressobj()) + zlib.compressobj(), + {}) self.state = None @@ -95,19 +98,39 @@ class GatewayWebsocket: # TODO: pure zlib await self.ws.send(encoded.decode()) + async def _hb_wait(self, interval: int): + """Wait heartbeat""" + await asyncio.sleep(interval / 1000) + await self.ws.close(4000, 'Heartbeat expired') + + def _hb_start(self, interval: int): + # always refresh the heartbeat task + # when possible + task = self.wsp.tasks.get('heartbeat') + if task: + task.cancel() + + self.wsp.tasks['heartbeat'] = self.ext.loop.create_task( + task_wrapper('hb wait', self._hb_wait(interval)) + ) + async def send_hello(self): """Send the OP 10 Hello packet over the websocket.""" + # random heartbeat intervals + interval = randint(40, 46) * 1000 + await self.send({ 'op': OP.HELLO, 'd': { - # random heartbeat intervals - 'heartbeat_interval': randint(40, 46) * 1000, + 'heartbeat_interval': interval, '_trace': [ 'lesbian-server' ], } }) + self._hb_start(interval) + async def dispatch(self, event: str, data: Any): """Dispatch an event to the websocket.""" self.state.seq += 1 @@ -328,8 +351,14 @@ class GatewayWebsocket: async def handle_1(self, payload: Dict[str, Any]): """Handle OP 1 Heartbeat packets.""" - # TODO: handling heartbeats - pass + # give the client 3 more seconds before we + # close the websocket + self._hb_start((46 + 3) * 1000) + cliseq = payload.get('d') + self.state.last_seq = cliseq + await self.send({ + 'op': OP.HEARTBEAT_ACK, + }) async def handle_2(self, payload: Dict[str, Any]): """Handle the OP 2 Identify packet.""" diff --git a/litecord/utils.py b/litecord/utils.py index 1a8aefb..a350dad 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -1,3 +1,9 @@ +import asyncio +from logbook import Logger + +log = Logger(__name__) + + async def async_map(function, iterable) -> list: """Map a coroutine to an iterable.""" res = [] @@ -7,3 +13,12 @@ async def async_map(function, iterable) -> list: res.append(result) return res + + +async def task_wrapper(name: str, coro): + try: + await coro + except asyncio.CancelledError: + pass + except: + log.exception('{} task error', name)