gateway.websocket: account for sharding in _guild_ids

Closes #1
This commit is contained in:
Luna Mendes 2018-11-14 23:24:38 -03:00
parent a50cf8a17c
commit 61d553efb8
1 changed files with 15 additions and 4 deletions

View File

@ -196,10 +196,9 @@ class GatewayWebsocket:
await self.send(payload) await self.send(payload)
async def _make_guild_list(self) -> List[int]: async def _make_guild_list(self) -> List[int]:
# TODO: This function does not account for sharding.
user_id = self.state.user_id user_id = self.state.user_id
guild_ids = await self.storage.get_user_guilds(user_id) guild_ids = await self._guild_ids()
if self.state.bot: if self.state.bot:
return [{ return [{
@ -315,11 +314,23 @@ class GatewayWebsocket:
raise InvalidShard('Shard count > Total shards') raise InvalidShard('Shard count > Total shards')
async def _guild_ids(self): async def _guild_ids(self):
# TODO: account for sharding guild_ids = await self.storage.get_user_guilds(
return await self.storage.get_user_guilds(
self.state.user_id self.state.user_id
) )
shard_id = self.state.current_shard
shard_count = self.state.shard_count
def _get_shard(guild_id):
return (guild_id >> 22) % shard_count
filtered = filter(
lambda guild_id: _get_shard(guild_id) == shard_id,
guild_ids
)
return list(filtered)
async def subscribe_all(self): async def subscribe_all(self):
"""Subscribe to all guilds, DM channels, and friends. """Subscribe to all guilds, DM channels, and friends.