diff --git a/litecord/auth.py b/litecord/auth.py index 1c95efb..dd68a47 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -11,7 +11,8 @@ from .errors import AuthError log = logging.getLogger(__name__) -async def raw_token_check(token): +async def raw_token_check(token, db=None): + db = db or app.db user_id, _hmac = token.split('.') try: @@ -20,7 +21,7 @@ async def raw_token_check(token): except (ValueError, binascii.Error): raise AuthError('Invalid user ID type') - pwd_hash = await app.db.fetchval(""" + pwd_hash = await db.fetchval(""" SELECT password_hash FROM users WHERE id = $1 diff --git a/litecord/gateway/errors.py b/litecord/gateway/errors.py index e877dc9..c74fade 100644 --- a/litecord/gateway/errors.py +++ b/litecord/gateway/errors.py @@ -15,3 +15,17 @@ class DecodeError(WebsocketClose): super().__init__(*args, **kwargs) self.args = [4002, self.args[0]] + + +class InvalidShard(WebsocketClose): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.args = [4010, self.args[0]] + + +class ShardingRequired(WebsocketClose): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.args = [4011, self.args[0]] diff --git a/litecord/gateway/gateway.py b/litecord/gateway/gateway.py index 8b6e8cf..d0105e5 100644 --- a/litecord/gateway/gateway.py +++ b/litecord/gateway/gateway.py @@ -2,7 +2,7 @@ import urllib.parse from .websocket import GatewayWebsocket -async def websocket_handler(app, ws, url): +async def websocket_handler(db, sm, ws, url): qs = urllib.parse.parse_qs( urllib.parse.urlparse(url).query ) @@ -27,6 +27,6 @@ async def websocket_handler(app, ws, url): if gw_compress and gw_compress not in ('zlib-stream',): return await ws.close(1000, 'Invalid gateway compress') - gws = GatewayWebsocket(app, ws, v=gw_version, + gws = GatewayWebsocket(sm, db, ws, v=gw_version, encoding=gw_encoding, compress=gw_compress) await gws.run() diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py index 01e7686..b3bc46e 100644 --- a/litecord/gateway/state.py +++ b/litecord/gateway/state.py @@ -13,4 +13,17 @@ class GatewayState: Used to store all information tied to the websocket's session. """ def __init__(self, **kwargs): - pass + self.session_id = kwargs.get('session_id', gen_session_id()) + self.seq = kwargs.get('seq', 0) + self.shard = kwargs.get('shard', [0, 1]) + self.user_id = kwargs.get('user_id') + + self.ws = None + + for key in kwargs: + value = kwargs[key] + self.__dict__[key] = value + + def __repr__(self): + return (f'GatewayState') diff --git a/litecord/gateway/state_man.py b/litecord/gateway/state_man.py index 0c1de68..4bbb6b2 100644 --- a/litecord/gateway/state_man.py +++ b/litecord/gateway/state_man.py @@ -1,16 +1,49 @@ +import logging + +from typing import List +from collections import defaultdict + from .state import GatewayState +log = logging.getLogger(__name__) + + class StateManager: """Manager for gateway state information.""" def __init__(self): - self.states = {} + self.states = defaultdict(dict) def insert(self, state: GatewayState): """Insert a new state object.""" user_states = self.states[state.user_id] + + log.info(f'Inserting state {state!r}') user_states[state.session_id] = state def fetch(self, user_id: int, session_id: str) -> GatewayState: """Fetch a state object from the registry.""" return self.states[user_id][session_id] + + def remove(self, state): + """Remove a state from the registry""" + if not state: + return + + try: + log.info(f'Removing state {state!r}') + self.states[state.user_id].pop(state.session_id) + except KeyError: + pass + + def fetch_states(self, user_id, guild_id) -> List[GatewayState]: + """Fetch all states that are tied to a guild.""" + states = [] + + for state in self.states[user_id]: + shard_id = (guild_id >> 22) % state.shard_count + + if shard_id == state.current_shard: + states.append(state) + + return states diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index c7616a6..e880371 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -6,7 +6,9 @@ import earl from ..errors import WebsocketClose, AuthError from ..auth import raw_token_check -from .errors import DecodeError, UnknownOPCode +from .errors import DecodeError, UnknownOPCode, \ + InvalidShard, ShardingRequired + from .opcodes import OP from .state import GatewayState, gen_session_id @@ -34,8 +36,9 @@ def decode_etf(data): class GatewayWebsocket: """Main gateway websocket logic.""" - def __init__(self, app, ws, **kwargs): - self.app = app + def __init__(self, sm, db, ws, **kwargs): + self.state_manager = sm + self.db = db self.ws = ws self.wsp = WebsocketProperties(kwargs.get('v'), @@ -78,13 +81,47 @@ class GatewayWebsocket: async def dispatch(self, event, data): """Dispatch an event to the websocket.""" + self.state.seq += 1 + await self.send({ 'op': OP.DISPATCH, 't': event.upper(), - # 's': self.state.seq, + 's': self.state.seq, 'd': data, }) + async def dispatch_ready(self): + await self.dispatch('READY', { + 'v': 6, + 'user': {'i': 'Boobs !! ! .........'}, + 'private_channels': [], + 'guilds': [], + 'session_id': self.state.session_id, + '_trace': ['despacito'] + }) + + async def _check_shards(self): + shard = self.state.shard + current_shard, shard_count = shard + + guilds = await self.db.fetchval(""" + SELECT COUNT(*) + FROM members + WHERE user_id = $1 + """, self.state.user_id) + + recommended = max(int(guilds / 1200), 1) + + if shard_count < recommended: + raise ShardingRequired('Too many guilds for shard ' + f'{current_shard}') + + if guilds / shard_count > 0.8: + raise ShardingRequired('Too many shards.') + + if current_shard > shard_count: + raise InvaildShard('Shard count > Total shards') + async def handle_0(self, payload: dict): """Handle the OP 0 Identify packet.""" data = payload['d'] @@ -100,33 +137,25 @@ class GatewayWebsocket: presence = data.get('presence') try: - user_id = await raw_token_check(token) + user_id = await raw_token_check(token, self.db) except AuthError: raise WebsocketClose(4004, 'Authentication failed') - session_id = gen_session_id() - self.state = GatewayState( - session_id=session_id, user_id=user_id, properties=properties, compress=compress, large=large, shard=shard, - presence=presence + presence=presence, ) - self.app.state_manager.insert(self.state) + self.state.ws = self - # TODO: dispatch READY - await self.dispatch('READY', { - 'v': 6, - 'user': {'i': 'Boobs !! ! .........'}, - 'private_channels': [], - 'guilds': [], - 'session_id': session_id, - '_trace': ['despacito'] - }) + await self._check_shards() + + self.state_manager.insert(self.state) + await self.dispatch_ready() async def process_message(self, payload): """Process a single message coming in from the client.""" @@ -159,6 +188,8 @@ class GatewayWebsocket: await self.send_hello() await self.listen_messages() except WebsocketClose as err: - log.warning(f'Closed a client, {self.state or ""} {err!r}') + log.warning(f'Closed a client, state={self.state or ""} ' + f'{err!r}') + await self.ws.close(code=err.code, reason=err.reason) diff --git a/run.py b/run.py index b1b41b1..ff5628e 100644 --- a/run.py +++ b/run.py @@ -49,7 +49,7 @@ async def app_before_serving(): async def _wrapper(ws, url): # We wrap the main websocket_handler # so we can pass quart's app object. - await websocket_handler(app, ws, url) + await websocket_handler(app.db, app.state_manager, ws, url) ws_future = websockets.serve(_wrapper, host, port)