mirror of https://gitlab.com/litecord/litecord.git
1053 lines
33 KiB
Python
1053 lines
33 KiB
Python
"""
|
|
|
|
Litecord
|
|
Copyright (C) 2018-2019 Luna Mendes
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, version 3 of the License.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
"""
|
|
|
|
import collections
|
|
import asyncio
|
|
import pprint
|
|
import zlib
|
|
from typing import List, Dict, Any
|
|
from random import randint
|
|
|
|
import websockets
|
|
import zstandard as zstd
|
|
from logbook import Logger
|
|
|
|
from litecord.auth import raw_token_check
|
|
from litecord.enums import RelationshipType, ChannelType
|
|
from litecord.schemas import validate, GW_STATUS_UPDATE
|
|
from litecord.utils import (
|
|
task_wrapper, yield_chunks
|
|
)
|
|
from litecord.permissions import get_permissions
|
|
|
|
from litecord.gateway.opcodes import OP
|
|
from litecord.gateway.state import GatewayState
|
|
|
|
from litecord.errors import (
|
|
WebsocketClose, Unauthorized, Forbidden, BadRequest
|
|
)
|
|
from litecord.gateway.errors import (
|
|
DecodeError, UnknownOPCode, InvalidShard, ShardingRequired
|
|
)
|
|
from litecord.gateway.encoding import (
|
|
encode_json, decode_json, encode_etf, decode_etf
|
|
)
|
|
|
|
from litecord.gateway.utils import WebsocketFileHandler
|
|
|
|
from litecord.storage import int_
|
|
|
|
# Module-level logbook logger used by the gateway websocket code below.
log = Logger(__name__)
|
|
|
|
# Per-connection settings: protocol version, payload encoding, transport
# compression mode, the zlib / zstd compression contexts and a dict of
# background tasks. Instances are built in GatewayWebsocket.__init__.
WebsocketProperties = collections.namedtuple(
    'WebsocketProperties',
    ['v', 'encoding', 'compress', 'zctx', 'zsctx', 'tasks'],
)

# Application-wide services a websocket connection needs; GatewayWebsocket
# copies them from the application object, in this exact field order.
WebsocketObjects = collections.namedtuple(
    'WebsocketObjects',
    [
        'db', 'state_manager', 'storage', 'loop', 'dispatcher',
        'presence', 'ratelimiter', 'user_storage', 'voice',
    ],
)
|
|
|
|
|
|
class GatewayWebsocket:
    """Main gateway websocket logic.

    One instance wraps a single client websocket: it encodes and
    (optionally) compresses outgoing payloads, decodes incoming ones,
    routes them to the ``handle_<op>`` methods and keeps the connection's
    GatewayState registered in the state manager.
    """

    def __init__(self, ws, app, **kwargs):
        """Wrap ``ws`` using the services exposed by ``app``.

        Recognized ``kwargs``: ``v`` (gateway version), ``encoding``
        (``'json'`` by default, or ``'etf'``) and ``compress``
        (``None``, ``'zlib-stream'`` or ``'zstd-stream'``).
        """
        self.ext = WebsocketObjects(
            app.db, app.state_manager, app.storage, app.loop,
            app.dispatcher, app.presence, app.ratelimiter,
            app.user_storage, app.voice
        )

        self.storage = self.ext.storage
        self.user_storage = self.ext.user_storage
        self.presence = self.ext.presence
        self.ws = ws

        self.wsp = WebsocketProperties(
            kwargs.get('v'),
            kwargs.get('encoding', 'json'),
            kwargs.get('compress', None),
            zlib.compressobj(),
            zstd.ZstdCompressor(),
            {}
        )

        log.debug('websocket properties: {!r}', self.wsp)

        # set by handle_2 (IDENTIFY) or handle_6 (RESUME)
        self.state = None

        self._set_encoders()

    def _set_encoders(self):
        """Pick the encoder/decoder pair for the negotiated encoding.

        Raises KeyError for an encoding other than 'json' or 'etf'.
        """
        encoding = self.wsp.encoding

        encodings = {
            'json': (encode_json, decode_json),
            'etf': (encode_etf, decode_etf),
        }

        self.encoder, self.decoder = encodings[encoding]

    async def _chunked_send(self, data: bytes, chunk_size: int):
        """Split data in chunk_size-big chunks and send each one as its
        own websocket message."""
        log.debug('zlib-stream: chunking {} bytes into {}-byte chunks',
                  len(data), chunk_size)

        total_chunks = 0
        for chunk in yield_chunks(data, chunk_size):
            total_chunks += 1
            log.debug('zlib-stream: chunk {}', total_chunks)
            await self.ws.send(chunk)

        log.debug('zlib-stream: sent {} chunks', total_chunks)

    async def _zlib_stream_send(self, encoded):
        """Send a single payload across multiple compressed
        websocket messages, using the connection-wide zlib context."""

        # compress and flush (for the rest of compressed data + ZLIB_SUFFIX)
        data1 = self.wsp.zctx.compress(encoded)
        data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)

        log.debug('zlib-stream: length {} -> compressed ({} + {})',
                  len(encoded), len(data1), len(data2))

        if not data1:
            # if data1 is nothing, that might cause problems
            # to clients, since they'll receive an empty message
            data1 = bytes([data2[0]])
            data2 = data2[1:]

            log.debug('zlib-stream: len(data1) == 0, remaking as ({} + {})',
                      len(data1), len(data2))

        # NOTE: the old approach was ws.send(data1 + data2).
        # It was changed to a chunked send of data1 and data2 because
        # sending a potentially very big packet as a single message
        # can cause network problems.

        # clients should handle chunked sends (via detection
        # of the ZLIB_SUFFIX suffix appended to data2), so
        # this shouldn't bring problems.

        # TODO: the chunks are 1024 bytes, 1KB, is this good enough?
        await self._chunked_send(data1, 1024)
        await self._chunked_send(data2, 1024)

    async def _zstd_stream_send(self, encoded):
        """Compress ``encoded`` as one complete zstd frame, written through
        a WebsocketFileHandler wrapping the websocket.

        NOTE(review): write/flush are synchronous calls here; presumably
        WebsocketFileHandler forwards the bytes to the websocket itself —
        confirm the send is not silently dropped.
        """
        compressor = self.wsp.zsctx.stream_writer(
            WebsocketFileHandler(self.ws))

        compressor.write(encoded)
        compressor.flush(zstd.FLUSH_FRAME)

    async def send(self, payload: Dict[str, Any]):
        """Encode ``payload`` and send it to the websocket.

        The transport is chosen in this order: zlib-stream, zstd-stream,
        one-shot ``zlib.compress`` (when the client set ``compress`` in
        IDENTIFY and the encoded payload exceeds 1024 bytes), or no
        compression. Uncompressed payloads that decode as UTF-8 (JSON) are
        sent as ``str``; anything else (ETF) is sent as raw bytes.
        """
        encoded = self.encoder(payload)

        if len(encoded) < 2048:
            log.debug('sending\n{}', pprint.pformat(payload))
        else:
            # only log a summary; pretty-printing huge payloads floods logs
            log.debug('sending op={} s={} t={} (too big)',
                      payload.get('op'),
                      payload.get('s'),
                      payload.get('t'))

        # treat encoded as bytes
        if not isinstance(encoded, bytes):
            encoded = encoded.encode()

        if self.wsp.compress == 'zlib-stream':
            await self._zlib_stream_send(encoded)
        elif self.wsp.compress == 'zstd-stream':
            await self._zstd_stream_send(encoded)
        elif self.state and self.state.compress and len(encoded) > 1024:
            # TODO: should we only compress on >1KB packets? or maybe we
            # should do all?
            await self.ws.send(zlib.compress(encoded))
        else:
            try:
                # assume encoded is string, json based, decoding it
                # should give reasonable messages down the websocket
                await self.ws.send(encoded.decode())
            except UnicodeDecodeError:
                # in here, encoded is ETF, its bytes(), so we send it raw
                await self.ws.send(encoded)

    async def send_op(self, op_code: int, data: Any):
        """Send a non-dispatch packet: only ``op`` and ``d`` are filled in,
        ``t`` and ``s`` are null."""
        await self.send({
            'op': op_code,
            'd': data,

            't': None,
            's': None
        })

    def _check_ratelimit(self, key: str, ratelimit_key):
        """Update the ``_ws.<key>`` ratelimit bucket for ``ratelimit_key``.

        Callers treat a truthy return value as "ratelimited".
        """
        ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
        bucket = ratelimit.get_bucket(ratelimit_key)
        return bucket.update_rate_limit()

    async def _hb_wait(self, interval: int):
        """Close the websocket if no heartbeat arrives within ``interval``
        milliseconds.

        If the client heartbeats in time, this task is cancelled and
        replaced by _hb_start.
        """
        await asyncio.sleep(interval / 1000)
        await self.ws.close(4000, 'Heartbeat expired')

        self._cleanup()

    def _hb_start(self, interval: int):
        """(Re)start the heartbeat timeout task with ``interval`` ms."""
        # always refresh the heartbeat task when possible
        task = self.wsp.tasks.get('heartbeat')
        if task:
            task.cancel()

        self.wsp.tasks['heartbeat'] = self.ext.loop.create_task(
            task_wrapper('hb wait', self._hb_wait(interval))
        )

    async def _send_hello(self):
        """Send the OP 10 Hello packet and start the heartbeat timer."""
        # random heartbeat interval between 40 and 46 seconds, in ms
        interval = randint(40, 46) * 1000

        await self.send_op(OP.HELLO, {
            'heartbeat_interval': interval,
            '_trace': [
                'lesbian-server'
            ],
        })

        self._hb_start(interval)

    async def dispatch(self, event: str, data: Any):
        """Dispatch an event to the websocket.

        Increments the state's sequence number and stores the payload in
        ``state.store`` under that number so it can be replayed on RESUME.
        """
        self.state.seq += 1

        payload = {
            'op': OP.DISPATCH,
            't': event.upper(),
            's': self.state.seq,
            'd': data,
        }

        self.state.store[self.state.seq] = payload

        log.debug('sending payload {!r} sid {}',
                  event.upper(), self.state.session_id)

        await self.send(payload)

    async def _make_guild_list(self) -> List[Dict[str, Any]]:
        """Build the ``guilds`` list of the READY packet.

        Bots get unavailable stubs (the full guilds follow later via
        GUILD_CREATE, see _guild_dispatch); users get full guild objects.
        """
        user_id = self.state.user_id

        guild_ids = await self._guild_ids()

        if self.state.bot:
            return [{
                'id': row,
                'unavailable': True,
            } for row in guild_ids]

        return [
            {
                **await self.storage.get_guild(guild_id, user_id),
                **await self.storage.get_guild_extra(guild_id, user_id,
                                                     self.state.large)
            }
            for guild_id in guild_ids
        ]

    async def _guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]):
        """Dispatch GUILD_CREATE for every guild stub sent in READY.

        Only bots get this; guilds that no longer resolve are skipped.
        """
        # Users don't get asynchronous guild dispatching.
        if not self.state.bot:
            return

        for guild_obj in unavailable_guilds:
            # fetch full guild object including the 'large' field
            guild = await self.storage.get_guild_full(
                int(guild_obj['id']), self.state.user_id, self.state.large
            )

            if guild is None:
                continue

            await self.dispatch('GUILD_CREATE', guild)

    async def _user_ready(self) -> dict:
        """Fetch the user-only fields of the READY packet.

        This part of the API is completely undocumented.
        PLEASE DISCORD DO NOT BAN ME
        """
        user_id = self.state.user_id

        relationships = await self.user_storage.get_relationships(user_id)

        friend_ids = [int(r['user']['id']) for r in relationships
                      if r['type'] == RelationshipType.FRIEND.value]

        friend_presences = await self.ext.presence.friend_presences(friend_ids)
        settings = await self.user_storage.get_user_settings(user_id)

        return {
            'user_settings': settings,
            'notes': await self.user_storage.fetch_notes(user_id),
            'relationships': relationships,
            'presences': friend_presences,
            'read_state': await self.user_storage.get_read_state(user_id),
            'user_guild_settings': await self.user_storage.get_guild_settings(
                user_id),

            'friend_suggestion_count': 0,

            # those are unused default values.
            'connected_accounts': [],
            'experiments': [],
            'guild_experiments': [],
            'analytics_token': 'transbian',
        }

    async def dispatch_ready(self):
        """Dispatch the READY packet for a connecting account, then start
        the background GUILD_CREATE dispatch (bots only)."""
        guilds = await self._make_guild_list()

        user_id = self.state.user_id
        user = await self.storage.get_user(user_id, True)

        user_ready = {}
        if not self.state.bot:
            # user, fetch info
            user_ready = await self._user_ready()

        private_channels = (
            await self.user_storage.get_dms(user_id) +
            await self.user_storage.get_gdms(user_id)
        )

        base_ready = {
            'v': 6,
            'user': user,

            'private_channels': private_channels,

            'guilds': guilds,
            'session_id': self.state.session_id,
            '_trace': ['transbian'],
            'shard': self.state.shard,
        }

        await self.dispatch('READY', {**base_ready, **user_ready})

        # async dispatch of guilds; registered in wsp.tasks so that
        # _cleanup cancels it if the connection goes away mid-dispatch
        self.wsp.tasks['guild_dispatch'] = self.ext.loop.create_task(
            task_wrapper('guild dispatch', self._guild_dispatch(guilds))
        )

    async def _check_shards(self, shard, user_id):
        """Check if the ``shard`` value given in IDENTIFY is acceptable.

        ``shard`` must be a ``[shard_id, shard_count]`` pair of ints with
        ``shard_count >= 1`` and ``0 <= shard_id < shard_count``.

        Raises InvalidShard for a malformed/out-of-range pair and
        ShardingRequired when the shard count does not fit the user's
        guild count.
        """
        try:
            current_shard, shard_count = shard
        except (TypeError, ValueError):
            raise InvalidShard('Invalid shard payload')

        if not isinstance(current_shard, int) \
                or not isinstance(shard_count, int):
            raise InvalidShard('Invalid shard payload')

        if shard_count < 1:
            raise InvalidShard('Shard count must be at least 1')

        # shard ids are zero-based: valid ids are 0 .. shard_count - 1
        if not 0 <= current_shard < shard_count:
            raise InvalidShard(
                f'Invalid shard ID {current_shard} for {shard_count} shards')

        guilds = await self.ext.db.fetchval("""
        SELECT COUNT(*)
        FROM members
        WHERE user_id = $1
        """, user_id)

        recommended = max(int(guilds / 1200), 1)

        if shard_count < recommended:
            raise ShardingRequired('Too many guilds for shard '
                                   f'{current_shard}')

        # NOTE(review): with the check above, any account in more than
        # 2500 guilds always has guilds / shard_count > 0.8 and is
        # rejected here; the intended ratio needs confirming.
        if guilds > 2500 and guilds / shard_count > 0.8:
            raise ShardingRequired('Too many shards. '
                                   f'(g={guilds} sc={shard_count})')

    async def _guild_ids(self) -> list:
        """Get the Guild IDs tied to this connection.

        Shard-aware: a guild belongs to shard
        ``(guild_id >> 22) % shard_count``.
        """
        guild_ids = await self.user_storage.get_user_guilds(
            self.state.user_id
        )

        shard_id = self.state.current_shard
        shard_count = self.state.shard_count

        def _get_shard(guild_id):
            return (guild_id >> 22) % shard_count

        return [guild_id for guild_id in guild_ids
                if _get_shard(guild_id) == shard_id]

    async def subscribe_all(self):
        """Subscribe to all guilds, DM channels, and (for users) friends.

        Note: subscribing to guild channels is already handled
        by GuildDispatcher.sub
        """
        user_id = self.state.user_id
        guild_ids = await self._guild_ids()

        # subscribe the user to all dms they have OPENED.
        dms = await self.user_storage.get_dms(user_id)
        dm_ids = [int(dm['id']) for dm in dms]

        # fetch all group dms the user is a member of.
        gdm_ids = await self.user_storage.get_gdms_internal(user_id)

        log.info('subscribing to {} guilds', len(guild_ids))
        log.info('subscribing to {} dms', len(dm_ids))
        log.info('subscribing to {} group dms', len(gdm_ids))

        await self.ext.dispatcher.mass_sub(user_id, [
            ('guild', guild_ids),
            ('channel', dm_ids),
            ('channel', gdm_ids)
        ])

        if not self.state.bot:
            # subscribe to all friends
            # (their friends will also subscribe back
            # when they come online)
            friend_ids = await self.user_storage.get_friend_ids(user_id)
            log.info('subscribing to {} friends', len(friend_ids))
            await self.ext.dispatcher.sub_many('friend', user_id, friend_ids)

    async def update_status(self, status: dict):
        """Update the presence of the current websocket connection.

        Client-provided fields override the defaults (``online``, not afk,
        no game, since 0); null fields fall back to the defaults. Invalid
        or ratelimited updates are dropped silently (invalid ones are
        logged).
        """
        if not self.state:
            return

        if self._check_ratelimit('presence', self.state.session_id):
            # Presence Updates beyond the ratelimit
            # are just silently dropped.
            return

        default_status = {
            'afk': False,

            # TODO: fetch status from settings
            'status': 'online',
            'game': None,

            # TODO: this
            'since': 0,
        }

        # defaults first so that the client's values win
        provided = status if isinstance(status, dict) else {}
        status = {
            **default_status,
            **{key: value for key, value in provided.items()
               if value is not None},
        }

        try:
            status = validate(status, GW_STATUS_UPDATE)
        except BadRequest as err:
            log.warning(f'Invalid status update: {err}')
            return

        # try to extract game from activities
        # when game not provided
        if not status.get('game'):
            try:
                game = status['activities'][0]
            except (KeyError, IndexError, TypeError):
                game = None
        else:
            game = status['game']

        # construct final status
        status = {
            'afk': status.get('afk', False),
            'status': status.get('status', 'online'),
            'game': game,
            'since': status.get('since', 0),
        }

        self.state.presence = status
        log.info(f'Updating presence status={status["status"]} for '
                 f'uid={self.state.user_id}')
        await self.ext.presence.dispatch_pres(self.state.user_id,
                                              self.state.presence)

    async def handle_1(self, payload: Dict[str, Any]):
        """Handle OP 1 Heartbeat packets: refresh the timeout, record the
        client's sequence and reply with HEARTBEAT_ACK."""
        # the largest possible hello interval (46s) plus 3 more seconds
        # before we close the websocket
        self._hb_start((46 + 3) * 1000)
        cliseq = payload.get('d')

        if self.state:
            self.state.last_seq = cliseq

        await self.send_op(OP.HEARTBEAT_ACK, None)

    async def _connect_ratelimit(self, user_id: int):
        """Apply the connect and session ratelimits for ``user_id``,
        invalidating the session and closing (4009 / 4004) when hit."""
        if self._check_ratelimit('connect', user_id):
            await self.invalidate_session(False)
            raise WebsocketClose(4009, 'You are being ratelimited.')

        if self._check_ratelimit('session', user_id):
            await self.invalidate_session(False)
            raise WebsocketClose(4004, 'Websocket Session Ratelimit reached.')

    async def handle_2(self, payload: Dict[str, Any]):
        """Handle the OP 2 Identify packet.

        Authenticates the token, applies ratelimits and shard checks,
        creates and registers the GatewayState, then sets the presence,
        subscribes to everything and dispatches READY.

        Raises DecodeError on a malformed payload, WebsocketClose(4004) on
        bad credentials and WebsocketClose(4005) if this connection already
        has a state.
        """
        if self.state:
            # a second IDENTIFY would register another state for the
            # same websocket in the state manager
            raise WebsocketClose(4005, 'Already authenticated.')

        try:
            data = payload['d']
            token = data['token']
        except (KeyError, TypeError):
            raise DecodeError('Invalid identify parameters')

        compress = data.get('compress', False)
        large = data.get('large_threshold', 50)

        shard = data.get('shard', [0, 1])
        presence = data.get('presence')

        try:
            user_id = await raw_token_check(token, self.ext.db)
        except (Unauthorized, Forbidden):
            raise WebsocketClose(4004, 'Authentication failed')

        await self._connect_ratelimit(user_id)

        bot = await self.ext.db.fetchval("""
        SELECT bot FROM users
        WHERE id = $1
        """, user_id)

        await self._check_shards(shard, user_id)

        # only create a state after checking everything
        self.state = GatewayState(
            user_id=user_id,
            bot=bot,
            compress=compress,
            large=large,
            shard=shard,
            current_shard=shard[0],
            shard_count=shard[1],
            ws=self
        )

        # link the state to the user
        self.ext.state_manager.insert(self.state)

        await self.update_status(presence)
        await self.subscribe_all()
        await self.dispatch_ready()

    async def handle_3(self, payload: Dict[str, Any]):
        """Handle OP 3 Status Update."""
        presence = payload['d']

        # update_status will take care of validation and
        # setting new presence to state
        await self.update_status(presence)

    async def _vsu_get_prop(self, state, data):
        """Build the property dict for a voice state update.

        ``deaf``/``mute`` always come from the existing voice ``state``.
        ``self_deaf``/``self_mute`` come from ``data``; when a key is
        missing, the value currently on ``state`` is kept (``False`` if
        ``state`` has no such attribute).
        """
        # TODO: fallback to user settings when not provided
        self_deaf = data.get('self_deaf', getattr(state, 'self_deaf', False))
        self_mute = data.get('self_mute', getattr(state, 'self_mute', False))

        return {
            'deaf': state.deaf,
            'mute': state.mute,
            'self_deaf': bool(self_deaf),
            'self_mute': bool(self_mute),
        }

    async def handle_4(self, payload: Dict[str, Any]):
        """Handle OP 4 Voice Status Update.

        Leaves all voice channels (both ids null), leaves a guild's voice
        channel (channel null), creates a new voice state, updates it in
        place (same guild and channel) or moves it (same guild, other
        channel). Other combinations are not handled yet.
        """
        data = payload['d']

        if not self.state:
            return

        channel_id = int_(data.get('channel_id'))
        guild_id = int_(data.get('guild_id'))

        # if its null and null, disconnect the user from any voice
        # TODO: maybe just leave from DMs? idk...
        if channel_id is None and guild_id is None:
            return await self.ext.voice.leave_all(self.state.user_id)

        # if guild is not none but channel is, we are leaving
        # a guild's channel
        if channel_id is None:
            return await self.ext.voice.leave(guild_id, self.state.user_id)

        # fetch an existing state given user and guild OR user and channel
        chan_type = ChannelType(
            await self.storage.get_chan_type(channel_id)
        )

        state_id2 = channel_id

        if chan_type == ChannelType.GUILD_VOICE:
            state_id2 = guild_id

        # a voice state key is a Tuple[int, int]
        #  - [0] is the user id
        #  - [1] is the channel id or guild id

        # the old approach was a (user_id, session_id), but
        # that does not work.

        # this works since users can be connected to many channels
        # using a single gateway websocket connection. HOWEVER,
        # they CAN NOT enter two channels in a single guild.

        # this state id format takes care of that.
        voice_key = (self.state.user_id, state_id2)
        voice_state = await self.ext.voice.get_state(voice_key)

        if voice_state is None:
            return await self.ext.voice.create_state(voice_key, data)

        same_guild = guild_id == voice_state.guild_id
        same_channel = channel_id == voice_state.channel_id

        prop = await self._vsu_get_prop(voice_state, data)

        if same_guild and same_channel:
            return await self.ext.voice.update_state(voice_state, prop)

        if same_guild and not same_channel:
            return await self.ext.voice.move_state(voice_state, channel_id)

    async def _handle_5(self, payload: Dict[str, Any]):
        """Handle OP 5 Voice Server Ping (not implemented; the leading
        underscore keeps it out of the ``handle_<op>`` dispatch).

        packet's data structure:

        {
            delay: num,
            speaking: num,
            ssrc: num
        }
        """
        pass

    async def invalidate_session(self, resumable: bool = True):
        """Send OP INVALID_SESSION; when not ``resumable``, also drop the
        state from the state manager so it cannot be resumed."""
        await self.send_op(OP.INVALID_SESSION, resumable)

        if not resumable and self.state:
            # since the state will be removed from
            # the manager, it will become unreachable
            # when trying to resume.
            self.ext.state_manager.remove(self.state)

    async def _resume(self, replay_seqs):
        """Replay stored dispatches to a resumed connection.

        ``replay_seqs`` is an iterable of sequence numbers (handle_6 passes
        a ``range``); numbers missing from ``state.store`` are skipped.
        PRESENCE_UPDATE payloads are collected and sent as a single
        PRESENCE_REPLACE instead. Any error invalidates the session as
        non-resumable.
        """
        presences = []

        try:
            for seq in replay_seqs:
                try:
                    payload = self.state.store[seq]
                except KeyError:
                    # ignore unknown seqs
                    continue

                payload_t = payload.get('t')

                # presence resumption happens
                # on a separate event, PRESENCE_REPLACE.
                if payload_t == 'PRESENCE_UPDATE':
                    presences.append(payload.get('d'))
                    continue

                await self.send(payload)
        except Exception:
            log.exception('error while resuming')
            await self.invalidate_session(False)
            return

        if presences:
            await self.dispatch('PRESENCE_REPLACE', presences)

        await self.dispatch('RESUMED', {})

    async def handle_6(self, payload: Dict[str, Any]):
        """Handle OP 6 Resume.

        ``d.seq`` is the last sequence number the client received, so the
        events ``seq + 1`` through ``state.seq`` (inclusive) are replayed.
        """
        try:
            data = payload['d']
            token, sess_id, seq = \
                data['token'], data['session_id'], data['seq']
        except (KeyError, TypeError):
            raise DecodeError('Invalid resume payload')

        if not isinstance(seq, int):
            raise DecodeError('Invalid resume payload')

        try:
            user_id = await raw_token_check(token, self.ext.db)
        except (Unauthorized, Forbidden):
            raise WebsocketClose(4004, 'Invalid token')

        try:
            state = self.ext.state_manager.fetch(user_id, sess_id)
        except KeyError:
            return await self.invalidate_session(False)

        if seq > state.seq:
            raise WebsocketClose(4007, 'Invalid seq')

        # check if a websocket isnt on that state already
        if state.ws is not None:
            log.info('Resuming failed, websocket already connected')
            return await self.invalidate_session(False)

        # relink this connection
        self.state = state
        state.ws = self

        # replay everything after the client's last seen event
        await self._resume(range(seq + 1, state.seq + 1))

    async def _req_guild_members(self, guild_id, user_ids: List[int],
                                 query: str, limit: int):
        """Answer one guild's part of an OP 8 Request Guild Members.

        Silently ignores invalid guild ids, guilds the user is not in and
        guilds that do not exist. With ``user_ids`` (a list, or a single
        id) it dispatches those members plus a ``not_found`` list;
        otherwise it dispatches up to ``limit`` members matching ``query``.
        ``limit`` of 0 / invalid means 1000, and it is capped at 1000.
        """
        try:
            guild_id = int(guild_id)
        except (TypeError, ValueError):
            return

        # only members of a guild may list its members
        user_guilds = await self.user_storage.get_user_guilds(
            self.state.user_id)
        if guild_id not in user_guilds:
            return

        if not isinstance(limit, int) or limit <= 0:
            limit = 1000
        limit = min(limit, 1000)

        exists = await self.storage.get_guild(guild_id)

        if not exists:
            return

        # accept a single id as well as a list of ids
        if user_ids is None:
            user_ids = []
        elif isinstance(user_ids, (str, int)):
            user_ids = [user_ids]

        # limit user_ids to 1000 possible members
        user_ids = list(user_ids)[:1000]

        # assumption: requesting user_ids means
        # we don't do query.
        if user_ids:
            members = await self.storage.get_member_multi(guild_id, user_ids)

            # compare as strings: clients may send ids as ints or strings
            found = {str(member['user']['id']) for member in members}
            not_found = [uid for uid in user_ids if str(uid) not in found]

            await self.dispatch('GUILD_MEMBERS_CHUNK', {
                'guild_id': str(guild_id),
                'members': members,
                'not_found': not_found,
            })

            return

        # do the search
        result = await self.storage.query_members(guild_id, query, limit)
        await self.dispatch('GUILD_MEMBERS_CHUNK', {
            'guild_id': str(guild_id),
            'members': result
        })

    async def handle_8(self, payload: Dict):
        """Handle OP 8 Request Guild Members.

        ``guild_id`` may be a single id (str or int) or a list of ids;
        ``user_ids`` are only honoured for a single guild.
        """
        if not self.state:
            return

        data = payload['d']
        gids = data['guild_id']

        uids, query, limit = data.get('user_ids', []), \
            data.get('query', ''), \
            data.get('limit', 0)

        if isinstance(gids, (str, int)):
            await self._req_guild_members(gids, uids, query, limit)
            return

        for gid in gids:
            # ignore uids on multiple guilds
            await self._req_guild_members(gid, [], query, limit)

    async def _guild_sync(self, guild_id: int):
        """Synchronize a guild.

        Fetches the members and presences of a guild and dispatches a
        GUILD_SYNC event with that info.
        """
        members = await self.storage.get_member_data(guild_id)
        member_ids = [int(m['user']['id']) for m in members]

        log.debug(f'Syncing guild {guild_id} with {len(member_ids)} members')
        presences = await self.presence.guild_presences(member_ids, guild_id)

        await self.dispatch('GUILD_SYNC', {
            'id': str(guild_id),
            'presences': presences,
            'members': members,
        })

    async def handle_12(self, payload: Dict[str, Any]):
        """Handle OP 12 Guild Sync: ``d`` is a list of guild ids; invalid
        ids and guilds the user is not in are skipped."""
        if not self.state:
            return

        data = payload['d']

        gids = set(await self.user_storage.get_user_guilds(
            self.state.user_id))

        for guild_id in data or []:
            try:
                guild_id = int(guild_id)
            except (ValueError, TypeError):
                continue

            # check if user in guild
            if guild_id not in gids:
                continue

            await self._guild_sync(guild_id)

    async def handle_13(self, payload: Dict[str, Any]):
        """Handle CALL_SYNC request.

        There isn't any need to actually finish the implementation
        since we don't have voice. Discord doesn't seem to send anything
        on text-only DMs, so I'll keep that behavior and do nothing.

        CALL_SYNC structure (for now, we don't know if there is anything else):
        {
            channel_id: snowflake
        }
        """
        pass

    async def handle_14(self, payload: Dict[str, Any]):
        """Lazy guilds handler.

        Requests for guilds the user is not in are ignored, as are
        channels the user cannot read (the remaining channels are still
        served).

        This is the known structure of an OP 14:

        lazy_request = {
            'guild_id': guild_id,
            'channels': {
                // the client wants a specific range of members
                // from the channel. so you must assume each query is
                // for people with roles that can Read Messages
                channel_id -> [[min, max], ...],
                ...
            },

            'members': [?], // ???
            'activities': bool, // ???
            'typing': bool, // ???
        }

        This is the known structure of GUILD_MEMBER_LIST_UPDATE:

        group_id = 'online' | 'offline' | role_id (string)

        sync_item = {
            'group': {
                'id': group_id,
                'count': num
            }
        } | {
            'member': member_object
        }

        list_op = 'SYNC' | 'INVALIDATE' | 'INSERT' | 'UPDATE' | 'DELETE'

        list_data = {
            'id': channel_id | 'everyone',
            'guild_id': guild_id,

            'ops': [
                {
                    'op': list_op,

                    // exists if op = 'SYNC' or 'INVALIDATE'
                    'range': [num, num],

                    // exists if op = 'SYNC'
                    'items': sync_item[],

                    // exists if op == 'INSERT' | 'DELETE' | 'UPDATE'
                    'index': num,

                    // exists if op == 'INSERT' | 'UPDATE'
                    'item': sync_item,
                }
            ],

            // maybe those represent roles that show people
            // separately from the online list?
            'groups': [
                {
                    'id': group_id
                    'count': num
                }, ...
            ]
        }
        """
        if not self.state:
            return

        data = payload['d']

        gids = await self.user_storage.get_user_guilds(self.state.user_id)
        guild_id = int(data['guild_id'])

        # make sure to not extract info you shouldn't get
        if guild_id not in gids:
            return

        log.debug('lazy request: members: {}',
                  data.get('members', []))

        # make shard query
        lazy_guilds = self.ext.dispatcher.backends['lazy_guild']

        for chan_id, ranges in data.get('channels', {}).items():
            chan_id = int(chan_id)

            perms = await get_permissions(
                self.state.user_id, chan_id, storage=self.storage)

            if not perms.bits.read_messages:
                # skip channels the user can't read, keep serving the rest
                continue

            member_list = await lazy_guilds.get_gml(chan_id)

            await member_list.shard_query(
                self.state.session_id, ranges
            )

    async def _process_message(self, payload):
        """Route a single decoded client message to its ``handle_<op>``.

        Raises DecodeError for a non-object payload and UnknownOPCode for
        a missing or unsupported op code.
        """
        if not isinstance(payload, dict):
            raise DecodeError('Invalid payload')

        try:
            op_code = payload['op']
        except KeyError:
            raise UnknownOPCode('No OP code')

        try:
            handler = getattr(self, f'handle_{op_code}')
        except AttributeError:
            log.warning('Payload with bad op: {}', pprint.pformat(payload))
            raise UnknownOPCode(f'Bad OP code: {op_code}')

        await handler(payload)

    async def _msg_ratelimit(self):
        """Close with 4008 when the session exceeds the message ratelimit."""
        if self._check_ratelimit('messages', self.state.session_id):
            raise WebsocketClose(4008, 'You are being ratelimited.')

    async def _listen_messages(self):
        """Receive, size-check, ratelimit, decode and process client
        messages until the connection ends."""

        # close anyone trying to login while the
        # server is shutting down
        if self.ext.state_manager.closed:
            raise WebsocketClose(4000, 'state manager closed')

        if not self.ext.state_manager.accept_new:
            raise WebsocketClose(4000, 'state manager closed for new')

        while True:
            message = await self.ws.recv()
            if len(message) > 4096:
                raise DecodeError('Payload length exceeded')

            if self.state:
                await self._msg_ratelimit()

            payload = self.decoder(message)
            await self._process_message(payload)

    def _cleanup(self):
        """Cancel leftover tasks, and remove the connection's state from
        the state manager (unlinking it from this websocket)."""
        for task in self.wsp.tasks.values():
            task.cancel()

        if self.state:
            self.ext.state_manager.remove(self.state)
            self.state.ws = None
            self.state = None

    async def _check_conns(self, user_id):
        """Check if the user has any other connected states.

        If there aren't, dispatch an offline presence for them.
        """
        if not user_id:
            return

        # TODO: account for sharding
        # this only updates status to offline once
        # ALL shards have come offline
        states = self.ext.state_manager.user_states(user_id)
        with_ws = [s for s in states if s.ws]

        # there arent any other states with websocket
        if not with_ws:
            offline = {
                'afk': False,
                'status': 'offline',
                'game': None,
                'since': 0,
            }

            await self.ext.presence.dispatch_pres(
                user_id,
                offline
            )

    async def run(self):
        """Send HELLO and run :meth:`_listen_messages`, closing the socket
        on WebsocketClose/unexpected errors and always cleaning up."""
        try:
            await self._send_hello()
            await self._listen_messages()
        except websockets.exceptions.ConnectionClosed as err:
            log.warning('conn close, state={}, err={}', self.state, err)
        except WebsocketClose as err:
            log.warning('ws close, state={} err={}', self.state, err)
            await self.ws.close(code=err.code, reason=err.reason)
        except Exception as err:
            log.exception('An exception has occoured. state={}', self.state)
            await self.ws.close(code=4000, reason=repr(err))
        finally:
            user_id = self.state.user_id if self.state else None
            self._cleanup()
            await self._check_conns(user_id)
|