From 61d553efb84068140f8c9a92c16718c93379f268 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 14 Nov 2018 23:24:38 -0300 Subject: [PATCH] gateway.websocket: account for sharding in _guild_ids Closes #1 --- litecord/gateway/websocket.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index c2d7080..6810b24 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -196,10 +196,9 @@ class GatewayWebsocket: await self.send(payload) async def _make_guild_list(self) -> List[int]: - # TODO: This function does not account for sharding. user_id = self.state.user_id - guild_ids = await self.storage.get_user_guilds(user_id) + guild_ids = await self._guild_ids() if self.state.bot: return [{ @@ -315,11 +314,23 @@ class GatewayWebsocket: raise InvalidShard('Shard count > Total shards') async def _guild_ids(self): - # TODO: account for sharding - return await self.storage.get_user_guilds( + guild_ids = await self.storage.get_user_guilds( self.state.user_id ) + shard_id = self.state.current_shard + shard_count = self.state.shard_count + + def _get_shard(guild_id): + return (guild_id >> 22) % shard_count + + filtered = filter( + lambda guild_id: _get_shard(guild_id) == shard_id, + guild_ids + ) + + return list(filtered) + async def subscribe_all(self): """Subscribe to all guilds, DM channels, and friends.