mirror of https://gitlab.com/litecord/litecord.git
gateway.websocket: detach app object from GatewayWebsocket
It doesn't work since quart's objects only work with stuff that is already from quart, e.g the current_app stuff requires you to be inside a special hidden context that only quart functions get. Gateway code is detached from quart since quart's websocket stuff can't handle custom error codes. - auth: optional db detach - gateway.errors: add InvalidShard, ShardingRequired - gateway.gateway: pass asyncpg connection and StateManager - gateway.state: add repr, etc - gateway.state_man: add remove(), fetch_states()
This commit is contained in:
parent
77c5a101c6
commit
6f0528eaec
|
|
@ -11,7 +11,8 @@ from .errors import AuthError
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def raw_token_check(token):
|
async def raw_token_check(token, db=None):
|
||||||
|
db = db or app.db
|
||||||
user_id, _hmac = token.split('.')
|
user_id, _hmac = token.split('.')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -20,7 +21,7 @@ async def raw_token_check(token):
|
||||||
except (ValueError, binascii.Error):
|
except (ValueError, binascii.Error):
|
||||||
raise AuthError('Invalid user ID type')
|
raise AuthError('Invalid user ID type')
|
||||||
|
|
||||||
pwd_hash = await app.db.fetchval("""
|
pwd_hash = await db.fetchval("""
|
||||||
SELECT password_hash
|
SELECT password_hash
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
|
|
|
||||||
|
|
@ -15,3 +15,17 @@ class DecodeError(WebsocketClose):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.args = [4002, self.args[0]]
|
self.args = [4002, self.args[0]]
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidShard(WebsocketClose):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.args = [4010, self.args[0]]
|
||||||
|
|
||||||
|
|
||||||
|
class ShardingRequired(WebsocketClose):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.args = [4011, self.args[0]]
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import urllib.parse
|
||||||
from .websocket import GatewayWebsocket
|
from .websocket import GatewayWebsocket
|
||||||
|
|
||||||
|
|
||||||
async def websocket_handler(app, ws, url):
|
async def websocket_handler(db, sm, ws, url):
|
||||||
qs = urllib.parse.parse_qs(
|
qs = urllib.parse.parse_qs(
|
||||||
urllib.parse.urlparse(url).query
|
urllib.parse.urlparse(url).query
|
||||||
)
|
)
|
||||||
|
|
@ -27,6 +27,6 @@ async def websocket_handler(app, ws, url):
|
||||||
if gw_compress and gw_compress not in ('zlib-stream',):
|
if gw_compress and gw_compress not in ('zlib-stream',):
|
||||||
return await ws.close(1000, 'Invalid gateway compress')
|
return await ws.close(1000, 'Invalid gateway compress')
|
||||||
|
|
||||||
gws = GatewayWebsocket(app, ws, v=gw_version,
|
gws = GatewayWebsocket(sm, db, ws, v=gw_version,
|
||||||
encoding=gw_encoding, compress=gw_compress)
|
encoding=gw_encoding, compress=gw_compress)
|
||||||
await gws.run()
|
await gws.run()
|
||||||
|
|
|
||||||
|
|
@ -13,4 +13,17 @@ class GatewayState:
|
||||||
Used to store all information tied to the websocket's session.
|
Used to store all information tied to the websocket's session.
|
||||||
"""
|
"""
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
pass
|
self.session_id = kwargs.get('session_id', gen_session_id())
|
||||||
|
self.seq = kwargs.get('seq', 0)
|
||||||
|
self.shard = kwargs.get('shard', [0, 1])
|
||||||
|
self.user_id = kwargs.get('user_id')
|
||||||
|
|
||||||
|
self.ws = None
|
||||||
|
|
||||||
|
for key in kwargs:
|
||||||
|
value = kwargs[key]
|
||||||
|
self.__dict__[key] = value
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (f'GatewayState<session={self.session_id} seq={self.seq} '
|
||||||
|
f'shard={self.shard} uid={self.user_id}>')
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,49 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
from .state import GatewayState
|
from .state import GatewayState
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class StateManager:
|
class StateManager:
|
||||||
"""Manager for gateway state information."""
|
"""Manager for gateway state information."""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.states = {}
|
self.states = defaultdict(dict)
|
||||||
|
|
||||||
def insert(self, state: GatewayState):
|
def insert(self, state: GatewayState):
|
||||||
"""Insert a new state object."""
|
"""Insert a new state object."""
|
||||||
user_states = self.states[state.user_id]
|
user_states = self.states[state.user_id]
|
||||||
|
|
||||||
|
log.info(f'Inserting state {state!r}')
|
||||||
user_states[state.session_id] = state
|
user_states[state.session_id] = state
|
||||||
|
|
||||||
def fetch(self, user_id: int, session_id: str) -> GatewayState:
|
def fetch(self, user_id: int, session_id: str) -> GatewayState:
|
||||||
"""Fetch a state object from the registry."""
|
"""Fetch a state object from the registry."""
|
||||||
return self.states[user_id][session_id]
|
return self.states[user_id][session_id]
|
||||||
|
|
||||||
|
def remove(self, state):
|
||||||
|
"""Remove a state from the registry"""
|
||||||
|
if not state:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
log.info(f'Removing state {state!r}')
|
||||||
|
self.states[state.user_id].pop(state.session_id)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def fetch_states(self, user_id, guild_id) -> List[GatewayState]:
|
||||||
|
"""Fetch all states that are tied to a guild."""
|
||||||
|
states = []
|
||||||
|
|
||||||
|
for state in self.states[user_id]:
|
||||||
|
shard_id = (guild_id >> 22) % state.shard_count
|
||||||
|
|
||||||
|
if shard_id == state.current_shard:
|
||||||
|
states.append(state)
|
||||||
|
|
||||||
|
return states
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,9 @@ import earl
|
||||||
|
|
||||||
from ..errors import WebsocketClose, AuthError
|
from ..errors import WebsocketClose, AuthError
|
||||||
from ..auth import raw_token_check
|
from ..auth import raw_token_check
|
||||||
from .errors import DecodeError, UnknownOPCode
|
from .errors import DecodeError, UnknownOPCode, \
|
||||||
|
InvalidShard, ShardingRequired
|
||||||
|
|
||||||
from .opcodes import OP
|
from .opcodes import OP
|
||||||
from .state import GatewayState, gen_session_id
|
from .state import GatewayState, gen_session_id
|
||||||
|
|
||||||
|
|
@ -34,8 +36,9 @@ def decode_etf(data):
|
||||||
|
|
||||||
class GatewayWebsocket:
|
class GatewayWebsocket:
|
||||||
"""Main gateway websocket logic."""
|
"""Main gateway websocket logic."""
|
||||||
def __init__(self, app, ws, **kwargs):
|
def __init__(self, sm, db, ws, **kwargs):
|
||||||
self.app = app
|
self.state_manager = sm
|
||||||
|
self.db = db
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
|
|
||||||
self.wsp = WebsocketProperties(kwargs.get('v'),
|
self.wsp = WebsocketProperties(kwargs.get('v'),
|
||||||
|
|
@ -78,13 +81,47 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
async def dispatch(self, event, data):
|
async def dispatch(self, event, data):
|
||||||
"""Dispatch an event to the websocket."""
|
"""Dispatch an event to the websocket."""
|
||||||
|
self.state.seq += 1
|
||||||
|
|
||||||
await self.send({
|
await self.send({
|
||||||
'op': OP.DISPATCH,
|
'op': OP.DISPATCH,
|
||||||
't': event.upper(),
|
't': event.upper(),
|
||||||
# 's': self.state.seq,
|
's': self.state.seq,
|
||||||
'd': data,
|
'd': data,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
async def dispatch_ready(self):
|
||||||
|
await self.dispatch('READY', {
|
||||||
|
'v': 6,
|
||||||
|
'user': {'i': 'Boobs !! ! .........'},
|
||||||
|
'private_channels': [],
|
||||||
|
'guilds': [],
|
||||||
|
'session_id': self.state.session_id,
|
||||||
|
'_trace': ['despacito']
|
||||||
|
})
|
||||||
|
|
||||||
|
async def _check_shards(self):
|
||||||
|
shard = self.state.shard
|
||||||
|
current_shard, shard_count = shard
|
||||||
|
|
||||||
|
guilds = await self.db.fetchval("""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM members
|
||||||
|
WHERE user_id = $1
|
||||||
|
""", self.state.user_id)
|
||||||
|
|
||||||
|
recommended = max(int(guilds / 1200), 1)
|
||||||
|
|
||||||
|
if shard_count < recommended:
|
||||||
|
raise ShardingRequired('Too many guilds for shard '
|
||||||
|
f'{current_shard}')
|
||||||
|
|
||||||
|
if guilds / shard_count > 0.8:
|
||||||
|
raise ShardingRequired('Too many shards.')
|
||||||
|
|
||||||
|
if current_shard > shard_count:
|
||||||
|
raise InvaildShard('Shard count > Total shards')
|
||||||
|
|
||||||
async def handle_0(self, payload: dict):
|
async def handle_0(self, payload: dict):
|
||||||
"""Handle the OP 0 Identify packet."""
|
"""Handle the OP 0 Identify packet."""
|
||||||
data = payload['d']
|
data = payload['d']
|
||||||
|
|
@ -100,33 +137,25 @@ class GatewayWebsocket:
|
||||||
presence = data.get('presence')
|
presence = data.get('presence')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = await raw_token_check(token)
|
user_id = await raw_token_check(token, self.db)
|
||||||
except AuthError:
|
except AuthError:
|
||||||
raise WebsocketClose(4004, 'Authentication failed')
|
raise WebsocketClose(4004, 'Authentication failed')
|
||||||
|
|
||||||
session_id = gen_session_id()
|
|
||||||
|
|
||||||
self.state = GatewayState(
|
self.state = GatewayState(
|
||||||
session_id=session_id,
|
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
properties=properties,
|
properties=properties,
|
||||||
compress=compress,
|
compress=compress,
|
||||||
large=large,
|
large=large,
|
||||||
shard=shard,
|
shard=shard,
|
||||||
presence=presence
|
presence=presence,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.app.state_manager.insert(self.state)
|
self.state.ws = self
|
||||||
|
|
||||||
# TODO: dispatch READY
|
await self._check_shards()
|
||||||
await self.dispatch('READY', {
|
|
||||||
'v': 6,
|
self.state_manager.insert(self.state)
|
||||||
'user': {'i': 'Boobs !! ! .........'},
|
await self.dispatch_ready()
|
||||||
'private_channels': [],
|
|
||||||
'guilds': [],
|
|
||||||
'session_id': session_id,
|
|
||||||
'_trace': ['despacito']
|
|
||||||
})
|
|
||||||
|
|
||||||
async def process_message(self, payload):
|
async def process_message(self, payload):
|
||||||
"""Process a single message coming in from the client."""
|
"""Process a single message coming in from the client."""
|
||||||
|
|
@ -159,6 +188,8 @@ class GatewayWebsocket:
|
||||||
await self.send_hello()
|
await self.send_hello()
|
||||||
await self.listen_messages()
|
await self.listen_messages()
|
||||||
except WebsocketClose as err:
|
except WebsocketClose as err:
|
||||||
log.warning(f'Closed a client, {self.state or "<none>"} {err!r}')
|
log.warning(f'Closed a client, state={self.state or "<none>"} '
|
||||||
|
f'{err!r}')
|
||||||
|
|
||||||
await self.ws.close(code=err.code,
|
await self.ws.close(code=err.code,
|
||||||
reason=err.reason)
|
reason=err.reason)
|
||||||
|
|
|
||||||
2
run.py
2
run.py
|
|
@ -49,7 +49,7 @@ async def app_before_serving():
|
||||||
async def _wrapper(ws, url):
|
async def _wrapper(ws, url):
|
||||||
# We wrap the main websocket_handler
|
# We wrap the main websocket_handler
|
||||||
# so we can pass quart's app object.
|
# so we can pass quart's app object.
|
||||||
await websocket_handler(app, ws, url)
|
await websocket_handler(app.db, app.state_manager, ws, url)
|
||||||
|
|
||||||
ws_future = websockets.serve(_wrapper, host, port)
|
ws_future = websockets.serve(_wrapper, host, port)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue