From c6aea0b7b1cdecd4f09c83620bca38d88491ec7a Mon Sep 17 00:00:00 2001 From: Luna Date: Sun, 3 Mar 2019 00:04:08 -0300 Subject: [PATCH] gateway.websocket: rewrite op 4 handler mostly we moved the permission checks from the websocket to the voice manager. - voice.manager: add leave_all(), leave() - voice.state add VoiceState.guild_id, VoiceState.key --- litecord/gateway/websocket.py | 143 ++++++++++++---------------------- litecord/storage.py | 2 +- litecord/voice/manager.py | 103 +++++++++++++++++++----- litecord/voice/state.py | 6 ++ litecord/voice/utils.py | 0 5 files changed, 142 insertions(+), 112 deletions(-) create mode 100644 litecord/voice/utils.py diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 2fb6000..e999875 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -30,12 +30,12 @@ from logbook import Logger import earl from litecord.auth import raw_token_check -from litecord.enums import RelationshipType, ChannelType, VOICE_CHANNELS +from litecord.enums import RelationshipType, ChannelType from litecord.schemas import validate, GW_STATUS_UPDATE from litecord.utils import ( task_wrapper, LitecordJSONEncoder, yield_chunks ) -from litecord.permissions import get_permissions, ALL_PERMISSIONS +from litecord.permissions import get_permissions from litecord.gateway.opcodes import OP from litecord.gateway.state import GatewayState @@ -603,50 +603,10 @@ class GatewayWebsocket: # setting new presence to state await self.update_status(presence) - @property - def voice_key(self): + def voice_key(self, channel_id: int, guild_id: int): """Voice state key.""" return (self.state.user_id, self.state.session_id) - async def _voice_check(self, guild_id: int, channel_id: int): - """Check if the user can join the given guild/channel pair.""" - guild = None - if guild_id: - guild = await self.storage.get_guild(guild_id) - - channel = await self.storage.get_channel(channel_id) - ctype = ChannelType(channel['type']) - - if ctype not in VOICE_CHANNELS: - return - - if guild and channel.get(['guild_id']) != guild['id']: - return - - is_guild_voice = ctype == ChannelType.GUILD_VOICE - - states = await self.ext.voice.state_count(channel_id) - perms = (ALL_PERMISSIONS - if not is_guild_voice else - await get_permissions(self.state.user_id, - channel_id, storage=self.storage) - ) - - is_full = states >= channel['user_limit'] - is_bot = self.state.bot - - is_manager = perms.bits.manage_channels - - # if the channel is full AND: - # - user is not a bot - # - user is not manage channels - # then it fails - if not is_bot and not is_manager and is_full: - return - - # all checks passed. - return True - async def _vsu_get_prop(self, state, data): """Get voice state properties from data, fallbacking to user settings.""" @@ -664,52 +624,6 @@ class GatewayWebsocket: 'self_mute': self_mute, } - async def _move_voice(self, guild_id, channel_id, state, data): - """Move an existing voice state to the given target.""" - # first case: consider when the user is leaving the - # voice channel. - if channel_id is None: - return await self.ext.voice.del_state(self.voice_key) - - # second case: an update of voice state while being in - # the same channel - if channel_id == state.channel_id: - # we are moving to the same channel, so a simple update - # to the self_deaf / self_mute should suffice. - prop = await self._vsu_get_prop(state, data) - return await self.ext.voice.update_state( - self.voice_key, prop) - - # third case: moving between channels, check if the - # user can join the targeted channel first - if not await self._voice_check(guild_id, channel_id): - return - - # if they can join, move the state to there. - # this will delete the old one and construct a new one. - await self.ext.voice.move_channels(self.voice_key, channel_id) - - async def _create_voice(self, guild_id, channel_id, _state, data): - """Create a voice state.""" - - # if we are trying to create a voice state pointing torwards - # nowhere, we ignore it. - - # NOTE: HOWEVER, shouldn't we update the users' settings for - # self_mute and self_deaf? - if channel_id is None: - return - - # we ignore the given existing state as it'll be basically - # none, lol. - - # check if we can join the channel - if not await self._voice_check(guild_id, channel_id): - return - - # if yes, create the state - await self.ext.voice.create_state(self.voice_key, channel_id, data) - async def handle_4(self, payload: Dict[str, Any]): """Handle OP 4 Voice Status Update.""" data = payload['d'] @@ -720,11 +634,54 @@ class GatewayWebsocket: channel_id = int_(data.get('channel_id')) guild_id = int_(data.get('guild_id')) - # fetch an existing voice state - voice_state = await self.ext.voice.get_state(self.voice_key) + # if its null and null, disconnect the user from any voice + # TODO: maybe just leave from DMs? idk... + if channel_id is None and guild_id is None: + await self.ext.voice.leave_all(self.state.user_id) - func = self._move_voice if voice_state else self._create_voice - await func(guild_id, channel_id, voice_state, data) + # if guild is not none but channel is, we are leaving + # a guild's channel + if channel_id is None: + await self.ext.voice.leave(guild_id, self.state.user_id) + + # fetch an existing state given user and guild OR user and channel + chan_type = ChannelType( + await self.storage.get_chan_type(channel_id) + ) + + state_id2 = channel_id + + if chan_type == ChannelType.GUILD_VOICE: + state_id2 = guild_id + + # a voice state key is a Tuple[int, int] + # - [0] is the user id + # - [1] is the channel id or guild id + + # the old approach was a (user_id, session_id), but + # that does not work. + + # this works since users can be connected to many channels + # using a single gateway websocket connection. HOWEVER, + # they CAN NOT enter two channels in a single guild. + + # this state id format takes care of that. + voice_key = (self.state.user_id, state_id2) + voice_state = await self.ext.voice.get_state(voice_key) + + if voice_state is None: + await self.ext.voice.create_state(voice_key) + + same_guild = guild_id == voice_state.guild_id + same_channel = channel_id == voice_state.channel_id + + prop = await self._vsu_get_prop(voice_state, data) + + if same_guild and same_channel: + await self.ext.voice.update_state(voice_state, prop) + + if same_guild and not same_channel: + await self.ext.voice.move_state(voice_state, channel_id) async def _handle_5(self, payload: Dict[str, Any]): """Handle OP 5 Voice Server Ping. diff --git a/litecord/storage.py b/litecord/storage.py index 0d56c28..0910a65 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -943,7 +943,7 @@ class Storage: return dm_chan - async def guild_from_channel(self, channel_id: int): + async def guild_from_channel(self, channel_id: int) -> int: """Get the guild id coming from a channel id.""" return await self.db.fetchval(""" SELECT guild_id diff --git a/litecord/voice/manager.py b/litecord/voice/manager.py index e3301d1..8b74293 100644 --- a/litecord/voice/manager.py +++ b/litecord/voice/manager.py @@ -23,6 +23,8 @@ from dataclasses import fields from logbook import Logger +from litecord.permissions import get_permissions +from litecord.enums import ChannelType, VOICE_CHANNELS from litecord.voice.state import VoiceState @@ -43,53 +45,95 @@ class VoiceManager: def __init__(self, app): self.app = app + # double dict, first key is guild/channel id, second key is user id self.states = defaultdict(dict) # TODO: hold voice server LVSP connections # TODO: map channel ids to voice servers + async def can_join(self, user_id: int, channel_id: int) -> int: + """Return if a user can join a channel.""" + + channel = await self.app.storage.get_channel(channel_id) + ctype = ChannelType(channel['type']) + + if ctype not in VOICE_CHANNELS: + return + + states = await self.app.voice.state_count(channel_id) + + # get_permissions returns ALL_PERMISSIONS when + # the channel isn't from a guild + perms = await get_permissions( + user_id, channel_id, storage=self.app.storage + ) + + # hacky user_limit but should work, as channels not + # in guilds won't have that field. + is_full = states >= channel.get('user_limit', 100) + is_bot = (await self.app.storage.get_user(user_id))['bot'] + is_manager = perms.bits.manage_channels + + # if the channel is full AND: + # - user is not a bot + # - user is not manage channels + # then it fails + if not is_bot and not is_manager and is_full: + return + + # all good + return True + async def state_count(self, channel_id: int) -> int: """Get the current amount of voice states in a channel.""" return len(self.states[channel_id]) async def fetch_states(self, channel_id: int) -> Dict[int, VoiceState]: """Fetch the states of the given channel.""" - # NOTE: maybe we *could* optimize by just returning a reference to the - # states dict instead of calling dict()... + # since the state key is (user_id, guild_id | channel_id), we need + # to determine which kind of search we want to do. + guild_id = await self.app.storage.guild_from_channel(channel_id) - # however I'm really worried about state inconsistencies caused - # by this, so i'll just use dict(). - return dict(self.states[channel_id]) + # if there isn't a guild for the channel, it is a dm or group dm. + # those are simple to handle. + if not guild_id: + return dict(self.states[channel_id]) + + # guild states hold a dict mapping user ids to guild states, + # same as channels, thats the structure. + guild_states = self.states[guild_id] + res = {} + + # iterate over all users with states and add the channel matches + # into res + for user_id, state in guild_states.items(): + if state.channel_id == channel_id: + res[user_id] = state + + return res async def get_state(self, voice_key: VoiceKey) -> VoiceState: """Get a single VoiceState for a user in a channel. Returns None if no VoiceState is found.""" - channel_id, user_id = voice_key + user_id, sec_key_id = voice_key try: - return self.states[channel_id][user_id] + return self.states[sec_key_id][user_id] except KeyError: return None async def del_state(self, voice_key: VoiceKey): """Delete a given voice state.""" - chan_id, user_id = voice_key + user_id, sec_key_id = voice_key try: # TODO: tell that to the voice server of the channel. - self.states[chan_id].pop(user_id) + self.states[sec_key_id].pop(user_id) except KeyError: pass - async def update_state(self, voice_key: VoiceKey, prop: dict): + async def update_state(self, state: VoiceState, prop: dict): """Update a state in a channel""" - chan_id, user_id = voice_key - - try: - state = self.states[chan_id][user_id] - except KeyError: - return - # construct a new state based on the old one + properties new_state_dict = dict(state.as_json) @@ -103,7 +147,7 @@ class VoiceManager: new_state = _construct_state(new_state_dict) # TODO: dispatch to voice server - self.states[chan_id][user_id] = new_state + self.states[state.key][state.user_id] = new_state async def move_channels(self, old_voice_key: VoiceKey, channel_id: int): """Move a user between channels.""" @@ -113,3 +157,26 @@ class VoiceManager: async def create_state(self, voice_key: VoiceKey, channel_id: int, data: dict): pass + + async def leave_all(self, user_id: int) -> int: + """Leave all voice channels.""" + + # iterate over every state finding matches + + # NOTE: we copy the current states dict since we're modifying + # on iteration. this is SLOW. + + # TODO: better solution instead of copying, maybe we can generate + # a list of tasks to run that actually do the deletion by themselves + # instead of us generating a delete. then only start running them later + # on. + for sec_key_id, states in dict(self.states): + for state in states: + if state.user_id != user_id: + continue + + await self.del_state((user_id, sec_key_id)) + + async def leave(self, guild_id: int, user_id: int): + """Make a user leave a channel IN A GUILD.""" + await self.del_state((guild_id, user_id)) diff --git a/litecord/voice/state.py b/litecord/voice/state.py index adb3c20..d5e8732 100644 --- a/litecord/voice/state.py +++ b/litecord/voice/state.py @@ -23,6 +23,7 @@ from dataclasses import dataclass, asdict @dataclass class VoiceState: """Represents a voice state.""" + guild_id: int channel_id: int user_id: int session_id: str @@ -32,6 +33,11 @@ class VoiceState: self_mute: bool suppressed_by: int + @property + def key(self): + """Get the second part of a key identifying a state.""" + return self.channel_id if self.guild_id is None else self.guild_id + @property def as_json(self): """Return JSON-serializable dict.""" diff --git a/litecord/voice/utils.py b/litecord/voice/utils.py new file mode 100644 index 0000000..e69de29