import collections
import asyncio
import pprint
import zlib
import json
from typing import List, Dict, Any, Iterable
from random import randint

import earl
import websockets
from logbook import Logger

from litecord.errors import WebsocketClose, Unauthorized, Forbidden
from litecord.auth import raw_token_check
from .errors import DecodeError, UnknownOPCode, \
    InvalidShard, ShardingRequired
from .opcodes import OP
from .state import GatewayState

from ..errors import BadRequest
from ..schemas import validate, GW_STATUS_UPDATE
from ..utils import task_wrapper

log = Logger(__name__)

# Per-connection transport settings: gateway version, payload encoding,
# compression mode, the zlib compression context and the running tasks.
WebsocketProperties = collections.namedtuple(
    'WebsocketProperties', 'v encoding compress zctx tasks'
)

# Shared application objects handed to every websocket connection.
WebsocketObjects = collections.namedtuple(
    'WebsocketObjects', 'db state_manager storage loop dispatcher presence'
)


def encode_json(payload) -> str:
    """Serialize a payload to a JSON string."""
    return json.dumps(payload)


def decode_json(data: str):
    """Deserialize a JSON string into a payload."""
    return json.loads(data)


def encode_etf(payload) -> bytes:
    """Serialize a payload with earl (ETF)."""
    return earl.pack(payload)


def decode_etf(data: bytes):
    """Deserialize an ETF blob with earl."""
    return earl.unpack(data)


class GatewayWebsocket:
    """Main gateway websocket logic."""

    def __init__(self, ws, **kwargs):
        """Set up a gateway connection.

        ``kwargs['prop']`` must be an iterable matching the fields of
        WebsocketObjects. ``v``, ``encoding`` (default ``'json'``) and
        ``compress`` (default ``None``) describe the transport.
        """
        self.ext = WebsocketObjects(*kwargs['prop'])
        self.storage = self.ext.storage
        self.presence = self.ext.presence
        self.ws = ws

        self.wsp = WebsocketProperties(kwargs.get('v'),
                                       kwargs.get('encoding', 'json'),
                                       kwargs.get('compress', None),
                                       zlib.compressobj(),
                                       {})

        # filled in by IDENTIFY (handle_2) or RESUME (handle_6)
        self.state = None

        self._set_encoders()

    def _set_encoders(self):
        """Pick the encoder/decoder pair matching the connection encoding.

        Raises KeyError when the encoding is neither 'json' nor 'etf'.
        """
        encoding = self.wsp.encoding

        encodings = {
            'json': (encode_json, decode_json),
            'etf': (encode_etf, decode_etf),
        }

        self.encoder, self.decoder = encodings[encoding]

    async def send(self, payload: Dict[str, Any]):
        """Send a payload to the websocket.

        This function accounts for the zlib-stream transport method
        used by Discord.
        """
        encoded = self.encoder(payload)

        if len(encoded) < 1024:
            log.debug('sending {}', pprint.pformat(payload))
        else:
            # big payloads only get a summary line, dumping them whole
            # would defeat the purpose of the size check.
            log.debug('sending op={} s={} t={} (too big)',
                      payload.get('op'), payload.get('s'), payload.get('t'))

        if not isinstance(encoded, bytes):
            encoded = encoded.encode()

        if self.wsp.compress == 'zlib-stream':
            data1 = self.wsp.zctx.compress(encoded)
            data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)

            await self.ws.send(data1 + data2)
        elif self.wsp.encoding == 'etf':
            # ETF is a binary format and is not valid UTF-8, it has to
            # go out as a binary frame instead of being decoded to text.
            await self.ws.send(encoded)
        else:
            # TODO: pure zlib
            await self.ws.send(encoded.decode())

    async def _hb_wait(self, interval: int):
        """Wait heartbeat"""
        # if the client heartbeats in time,
        # this task will be cancelled.
        await asyncio.sleep(interval / 1000)
        await self.ws.close(4000, 'Heartbeat expired')

        # same teardown as run(): grab the user id before the state is
        # dropped so the offline presence can still be dispatched.
        # whichever of the two runs second sees state=None and no-ops.
        user_id = self.state.user_id if self.state else None
        self._cleanup()
        await self._check_conns(user_id)

    def _hb_start(self, interval: int):
        """(Re)start the heartbeat timeout task.

        ``interval`` is in milliseconds.
        """
        # always refresh the heartbeat task
        # when possible
        task = self.wsp.tasks.get('heartbeat')
        if task:
            task.cancel()

        self.wsp.tasks['heartbeat'] = self.ext.loop.create_task(
            task_wrapper('hb wait', self._hb_wait(interval))
        )

    async def send_hello(self):
        """Send the OP 10 Hello packet over the websocket."""
        # random heartbeat intervals
        interval = randint(40, 46) * 1000

        await self.send({
            'op': OP.HELLO,
            'd': {
                'heartbeat_interval': interval,
                '_trace': [
                    'lesbian-server'
                ],
            }
        })

        self._hb_start(interval)

    async def dispatch(self, event: str, data: Any):
        """Dispatch an event to the websocket.

        The payload is stored under its sequence number so it
        can be replayed on RESUME.
        """
        self.state.seq += 1

        payload = {
            'op': OP.DISPATCH,
            't': event.upper(),
            's': self.state.seq,
            'd': data,
        }

        self.state.store[self.state.seq] = payload
        await self.send(payload)

    async def _make_guild_list(self) -> List[Dict[str, Any]]:
        """Build the guild list for the READY packet.

        Bots get unavailable guild stubs, users get full guild objects.
        """
        # TODO: This function does not account for sharding.
        user_id = self.state.user_id

        guild_ids = await self.storage.get_user_guilds(user_id)

        if self.state.bot:
            # get_user_guilds is used as a list of plain guild ids
            # everywhere else in this class, so use the id as-is.
            return [{
                'id': guild_id,
                'unavailable': True,
            } for guild_id in guild_ids]

        return [
            {
                **await self.storage.get_guild(guild_id, user_id),
                **await self.storage.get_guild_extra(guild_id, user_id,
                                                     self.state.large)
            }
            for guild_id in guild_ids
        ]

    async def guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]):
        """Dispatch GUILD_CREATE information."""

        # Users don't get asynchronous guild dispatching.
        if not self.state.bot:
            return

        for guild_obj in unavailable_guilds:
            guild = await self.storage.get_guild(guild_obj['id'],
                                                 self.state.user_id)

            if not guild:
                continue

            await self.dispatch('GUILD_CREATE', dict(guild))

    async def user_ready(self):
        """Fetch information about users in the READY packet.

        This part of the API is completely undocumented.
        PLEASE DISCORD DO NOT BAN ME
        """

        return {
            # TODO
            'relationships': [],

            # TODO
            'user_guild_settings': [],

            'notes': await self.storage.fetch_notes(self.state.user_id),
            'friend_suggestion_count': 0,

            # TODO
            'presences': [],

            # TODO
            'read_state': [],

            'experiments': [],
            'guild_experiments': [],

            # TODO
            'connected_accounts': [],

            # TODO: make those changeable
            'user_settings': {
                'afk_timeout': 300,
                'animate_emoji': True,
                'convert_emoticons': False,
                'default_guilds_restricted': True,
                'detect_platform_accounts': False,
                'developer_mode': True,
                'disable_games_tab': True,
                'enable_tts_command': False,
                'explicit_content_filter': 2,
                'friend_source_flags': {
                    'mutual_friends': True
                },
                'gif_auto_play': True,
                'guild_positions': [],
                'inline_attachment_media': True,
                'inline_embed_media': True,
                'locale': 'en-US',
                'message_display_compact': False,
                'render_embeds': True,
                'render_reactions': True,
                'restricted_guilds': [],
                'show_current_game': True,
                'status': 'online',
                'theme': 'dark',
                'timezone_offset': 420,
            },

            'analytics_token': 'transbian',
        }

    async def dispatch_ready(self):
        """Dispatch the READY packet for a connecting account."""
        guilds = await self._make_guild_list()
        user = await self.storage.get_user(self.state.user_id, True)

        uready = {}
        if not self.state.bot:
            # user, fetch info
            uready = await self.user_ready()

        await self.dispatch('READY', {**{
            'v': 6,
            'user': user,

            # TODO: dms
            'private_channels': [],

            'guilds': guilds,
            'session_id': self.state.session_id,
            '_trace': ['transbian']
        }, **uready})

        # async dispatch of guilds
        self.ext.loop.create_task(self.guild_dispatch(guilds))

    async def _check_shards(self):
        """Validate the shard pair given on IDENTIFY.

        Raises ShardingRequired when the shard count does not fit the
        amount of guilds the account is in, and InvalidShard when the
        shard id is outside ``[0, shard_count)``.
        """
        shard = self.state.shard
        current_shard, shard_count = shard

        guilds = await self.ext.db.fetchval("""
        SELECT COUNT(*)
        FROM members
        WHERE user_id = $1
        """, self.state.user_id)

        recommended = max(int(guilds / 1200), 1)

        if shard_count < recommended:
            raise ShardingRequired('Too many guilds for shard '
                                   f'{current_shard}')

        # "too many shards" means too few guilds per shard. the previous
        # `>` comparison rejected every account above 2500 guilds.
        if guilds > 2500 and guilds / shard_count < 0.8:
            raise ShardingRequired('Too many shards. '
                                   f'(g={guilds} sc={shard_count})')

        # shard ids are zero-based, so shard_count itself is out of range.
        if current_shard < 0 or current_shard >= shard_count:
            raise InvalidShard('Shard count > Total shards')

    async def _guild_ids(self):
        """Return the ids of all guilds the current user is a member of."""
        # TODO: account for sharding
        guild_ids = await self.ext.db.fetch("""
        SELECT guild_id
        FROM members
        WHERE user_id = $1
        """, self.state.user_id)

        return [r['guild_id'] for r in guild_ids]

    async def subscribe_guilds(self):
        """Subscribe to all available guilds"""
        guild_ids = await self._guild_ids()
        log.info('subscribing to {} guilds', len(guild_ids))
        self.ext.dispatcher.sub_many(self.state.user_id, guild_ids)

    async def update_status(self, status: dict):
        """Validate and apply a presence update, then dispatch it.

        ``status`` may be None, in which case a default online
        presence is used. Invalid payloads are logged and ignored.
        """
        if status is None:
            status = {
                'afk': False,

                # TODO: fetch status from settings
                'status': 'online',
                'game': None,

                # TODO: this
                'since': 0,
            }

        # NOTE(review): this stores the status before it is validated;
        # kept as-is since other code may rely on the attribute being set.
        self.state.presence = status

        try:
            status = validate(status, GW_STATUS_UPDATE)
        except BadRequest as err:
            log.warning(f'Invalid payload: {err}')
            return

        # try to extract game from activities
        # when game not provided
        game = status.get('game')

        if not game:
            try:
                game = status['activities'][0]
            except (KeyError, IndexError, TypeError):
                # no activities key, empty list, or activities is None
                game = None

        # construct final status
        status = {
            'afk': status.get('afk', False),
            'status': status.get('status', 'online'),
            'game': game,
            'since': status.get('since', 0),
        }

        self.state.presence = status
        log.info(f'Updating presence status={status["status"]} for '
                 f'uid={self.state.user_id}')
        await self.ext.presence.dispatch_pres(self.state.user_id,
                                              self.state.presence)

    async def handle_1(self, payload: Dict[str, Any]):
        """Handle OP 1 Heartbeat packets."""
        # give the client 3 more seconds before we
        # close the websocket
        self._hb_start((46 + 3) * 1000)
        cliseq = payload.get('d')
        self.state.last_seq = cliseq
        await self.send({
            'op': OP.HEARTBEAT_ACK,
        })

    async def handle_2(self, payload: Dict[str, Any]):
        """Handle the OP 2 Identify packet.

        Raises DecodeError on a malformed payload, WebsocketClose(4004)
        on a bad token and InvalidShard on a malformed shard pair.
        """
        try:
            data = payload['d']
            token = data['token']
        except (KeyError, TypeError):
            # TypeError: 'd' is not a mapping
            raise DecodeError('Invalid identify parameters')

        compress = data.get('compress', False)
        large = data.get('large_threshold', 50)

        shard = data.get('shard', [0, 1])
        presence = data.get('presence')

        try:
            user_id = await raw_token_check(token, self.ext.db)
        except (Unauthorized, Forbidden):
            raise WebsocketClose(4004, 'Authentication failed')

        bot = await self.ext.db.fetchval("""
        SELECT bot FROM users
        WHERE id = $1
        """, user_id)

        # the shard pair comes straight from the client, make sure it
        # really is [shard_id, shard_count] before indexing into it.
        if (not isinstance(shard, (list, tuple)) or len(shard) != 2
                or not all(isinstance(val, int) for val in shard)):
            raise InvalidShard('Invalid shard format')

        self.state = GatewayState(
            user_id=user_id,
            bot=bot,
            compress=compress,
            large=large,
            shard=shard,
            current_shard=shard[0],
            shard_count=shard[1],
            ws=self
        )

        await self._check_shards()

        self.ext.state_manager.insert(self.state)
        await self.update_status(presence)
        await self.subscribe_guilds()
        await self.dispatch_ready()

    async def handle_3(self, payload: Dict[str, Any]):
        """Handle OP 3 Status Update."""
        presence = payload['d']

        # update_status will take care of validation and
        # setting new presence to state
        await self.update_status(presence)

    async def handle_4(self, payload: Dict[str, Any]):
        """Handle OP 4 Voice Status Update."""
        data = payload['d']
        log.debug('got VSU cid={} gid={} deaf={} mute={} video={}',
                  data.get('channel_id'),
                  data.get('guild_id'),
                  data.get('self_deaf'),
                  data.get('self_mute'),
                  data.get('self_video'))

        # for now, do nothing
        pass

    async def _handle_5(self, payload: Dict[str, Any]):
        """Handle OP 5 Voice Server Ping.

        packet's data structure:

        {
            delay: num,
            speaking: num,
            ssrc: num
        }

        The leading underscore keeps process_message from routing
        OP 5 here, as it only looks up ``handle_<op>``.
        """
        pass

    async def invalidate_session(self, resumable: bool = True):
        """Invalidate the current session and signal that to the client."""
        await self.send({
            'op': OP.INVALID_SESSION,
            'd': resumable,
        })

        if not resumable and self.state:
            self.ext.state_manager.remove(self.state)

    async def _resume(self, replay_seqs: Iterable[int]) -> bool:
        """Replay stored payloads for the given sequence numbers.

        Returns True when the replay went through, False when it
        failed and the session was invalidated instead.
        """
        presences = []

        try:
            for seq in replay_seqs:
                try:
                    payload = self.state.store[seq]
                except KeyError:
                    # ignore unknown seqs
                    continue

                payload_t = payload.get('t')

                # presence resumption happens
                # on a separate event, PRESENCE_REPLACE.
                if payload_t == 'PRESENCE_UPDATE':
                    presences.append(payload.get('d'))
                    continue

                await self.send(payload)
        except Exception:
            log.exception('error while resuming')
            await self.invalidate_session()
            return False

        if presences:
            await self.dispatch('PRESENCE_REPLACE', presences)

        return True

    async def handle_6(self, payload: Dict[str, Any]):
        """Handle OP 6 Resume.

        Raises DecodeError on a malformed payload, WebsocketClose(4004)
        on a bad token and WebsocketClose(4007) on a bad sequence number.
        """
        try:
            data = payload['d']
            token, sess_id, seq = data['token'], \
                data['session_id'], data['seq']
        except (KeyError, TypeError):
            # TypeError: 'd' is not a mapping
            raise DecodeError('Invalid resume payload')

        try:
            user_id = await raw_token_check(token, self.ext.db)
        except (Unauthorized, Forbidden):
            raise WebsocketClose(4004, 'Invalid token')

        try:
            state = self.ext.state_manager.fetch(user_id, sess_id)
        except KeyError:
            return await self.invalidate_session(False)

        # seq is client-controlled, a non-integer would blow up
        # on the comparison below.
        if not isinstance(seq, int) or seq > state.seq:
            raise WebsocketClose(4007, 'Invalid seq')

        # check if a websocket isnt on that state already
        if state.ws is not None:
            log.info('Resuming failed, websocket already connected')
            return await self.invalidate_session(False)

        # relink this connection
        self.state = state
        state.ws = self

        # seq is the last event the client received, so the replay
        # covers everything after it up to and including state.seq.
        resumed = await self._resume(range(seq + 1, state.seq + 1))

        if not resumed:
            # INVALID_SESSION was already sent, don't claim success.
            return

        await self.dispatch('RESUMED', {})

    async def _guild_sync(self, guild_id: int):
        """Dispatch GUILD_SYNC with the members and presences of a guild."""
        members = await self.storage.get_member_data(guild_id)
        member_ids = [int(m['user']['id']) for m in members]

        log.debug(f'Syncing guild {guild_id} with {len(member_ids)} members')
        presences = await self.presence.guild_presences(member_ids, guild_id)

        await self.dispatch('GUILD_SYNC', {
            'id': str(guild_id),
            'presences': presences,
            'members': members,
        })

    async def handle_12(self, payload: Dict[str, Any]):
        """Handle OP 12 Guild Sync."""
        data = payload['d']

        gids = await self.storage.get_user_guilds(self.state.user_id)

        for guild_id in data:
            try:
                guild_id = int(guild_id)
            except (ValueError, TypeError):
                continue

            # check if user in guild
            if guild_id not in gids:
                continue

            await self._guild_sync(guild_id)

    async def handle_14(self, payload: Dict[str, Any]):
        """Lazy guilds handler.

        This is the known structure of an OP 14:

        lazy_request = {
            'guild_id': guild_id,
            'channels': {
                // the client wants a specific range of members
                // from the channel. so you must assume each query is
                // for people with roles that can Read Messages
                channel_id -> [[min, max], ...],
                ...
            },

            'members': [?], // ???
            'activities': bool, // ???
            'typing': bool, // ???
        }

        This is the known structure of GUILD_MEMBER_LIST_UPDATE:

        sync_item = {
            'group': {
                'id': string, // 'online' | 'offline' | any role id
                'count': num
            }
        } | {
            'member': member_object
        }

        list_op = 'SYNC' | 'INVALIDATE' | 'INSERT' | 'UPDATE' | 'DELETE'

        list_data = {
            'id': "everyone" // ??
            'guild_id': guild_id,

            'ops': [
                {
                    'op': list_op,

                    // exists if op = 'SYNC' or 'INVALIDATE'
                    'range': [num, num],

                    // exists if op = 'SYNC'
                    'items': sync_item[],

                    // exists if op = 'INSERT' or 'DELETE'
                    'index': num,

                    // exists if op = 'INSERT'
                    'item': sync_item,
                }
            ],

            // maybe those represent roles that show people
            // separately from the online list?
            'groups': [
                {
                    'id': string // 'online' | 'offline' | any role id
                    'count': num
                }, ...
            ]
        }

        # Implementation details.

        Lazy guilds are complicated to deal with in the backend level
        as there are a lot of computation to be done for each request.

        The current implementation is rudimentary and does not account
        for any roles inside the guild.

        A correct implementation would take account of roles and make
        the correct groups on list_data:

        For each channel in lazy_request['channels']:
         - get all roles that have Read Messages on the channel:
           - Also fetch their member counts, as it'll be important
         - with the role list, order them like you normally would
            (by their role priority)
         - based on the channel's range's min and max and the ordered
            role list, you can get the roles wanted for your list_data reply.
         - make new groups ONLY when the role is hoisted.
        """
        data = payload['d']

        gids = await self.storage.get_user_guilds(self.state.user_id)
        guild_id = int(data['guild_id'])

        # make sure to not extract info you shouldn't get
        if guild_id not in gids:
            return

        member_ids = await self.storage.get_member_ids(guild_id)
        log.debug('lazy: loading {} members', len(member_ids))

        # the current implementation is rudimentary and only
        # generates two groups: online and offline, using
        # PresenceManager.guild_presences to fill list_data.

        # this also doesn't take account the channels in lazy_request.

        guild_presences = await self.presence.guild_presences(member_ids,
                                                              guild_id)

        online = [{'member': p}
                  for p in guild_presences if p['status'] == 'online']
        offline = [{'member': p}
                   for p in guild_presences if p['status'] == 'offline']

        log.debug('lazy: {} presences, online={}, offline={}',
                  len(guild_presences),
                  len(online),
                  len(offline))

        # construct items in the WORST WAY POSSIBLE.
        items = [{
            'group': {
                'id': 'online',
                'count': len(online),
            }
        }] + online + [{
            'group': {
                'id': 'offline',
                'count': len(offline),
            }
        }] + offline

        await self.dispatch('GUILD_MEMBER_LIST_UPDATE', {
            'id': 'everyone',
            'guild_id': data['guild_id'],
            'groups': [
                {
                    'id': 'online',
                    'count': len(online),
                },
                {
                    'id': 'offline',
                    'count': len(offline),
                }
            ],

            'ops': [
                {
                    'range': [0, 99],
                    'op': 'SYNC',
                    'items': items
                }
            ]
        })

    async def process_message(self, payload):
        """Process a single message coming in from the client.

        Routes the payload to ``handle_<op>``. Raises UnknownOPCode
        when the OP code is missing or has no handler.
        """
        try:
            op_code = payload['op']
        except (KeyError, TypeError):
            # TypeError: the decoded payload is not a mapping
            raise UnknownOPCode('No OP code')

        try:
            handler = getattr(self, f'handle_{op_code}')
        except AttributeError:
            raise UnknownOPCode(f'Bad OP code: {op_code}')

        await handler(payload)

    async def listen_messages(self):
        """Listen for messages coming in from the websocket."""
        while True:
            message = await self.ws.recv()
            if len(message) > 4096:
                raise DecodeError('Payload length exceeded')

            payload = self.decoder(message)
            await self.process_message(payload)

    def _cleanup(self):
        """Detach this connection from its state, if it has one."""
        if self.state:
            self.ext.state_manager.remove(self.state)
            self.state.ws = None
            self.state = None

    async def _check_conns(self, user_id):
        """Check if there are any existing connections.

        If there aren't, dispatch a presence for offline.
        """
        if not user_id:
            return

        # TODO: account for sharding
        # this only updates status to offline once
        # ALL shards have come offline
        states = self.ext.state_manager.user_states(user_id)
        with_ws = [s for s in states if s.ws]

        # there arent any other states with websocket
        if not with_ws:
            offline = {
                'afk': False,
                'status': 'offline',
                'game': None,
                'since': 0,
            }

            await self.ext.presence.dispatch_pres(
                user_id,
                offline
            )

    async def run(self):
        """Wrap listen_messages inside
        a try/except block for WebsocketClose handling."""
        try:
            await self.send_hello()
            await self.listen_messages()
        except websockets.exceptions.ConnectionClosed as err:
            log.warning('conn close, state={}, err={}', self.state, err)
        except WebsocketClose as err:
            log.warning('ws close, state={} err={}', self.state, err)

            await self.ws.close(code=err.code, reason=err.reason)
        except Exception as err:
            log.exception('An exception has occoured. state={}', self.state)

            # a close frame carries at most 125 bytes, 2 of which are the
            # close code. an arbitrary repr() can easily go over that.
            reason = repr(err).encode()[:123].decode(errors='ignore')
            await self.ws.close(code=4000, reason=reason)
        finally:
            user_id = self.state.user_id if self.state else None
            self._cleanup()
            await self._check_conns(user_id)