diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index c2d7080..6810b24 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -196,10 +196,9 @@ class GatewayWebsocket: await self.send(payload) async def _make_guild_list(self) -> List[int]: - # TODO: This function does not account for sharding. user_id = self.state.user_id - guild_ids = await self.storage.get_user_guilds(user_id) + guild_ids = await self._guild_ids() if self.state.bot: return [{ @@ -315,11 +314,23 @@ class GatewayWebsocket: raise InvalidShard('Shard count > Total shards') async def _guild_ids(self): - # TODO: account for sharding - return await self.storage.get_user_guilds( + guild_ids = await self.storage.get_user_guilds( self.state.user_id ) + shard_id = self.state.current_shard + shard_count = self.state.shard_count + + def _get_shard(guild_id): + return (guild_id >> 22) % shard_count + + filtered = filter( + lambda guild_id: _get_shard(guild_id) == shard_id, + guild_ids + ) + + return list(filtered) + async def subscribe_all(self): """Subscribe to all guilds, DM channels, and friends.