mirror of https://gitlab.com/litecord/litecord.git
gateway.websocket: rewrite op 4 handler
mostly we moved the permission checks from the websocket to the voice manager. - voice.manager: add leave_all(), leave() - voice.state add VoiceState.guild_id, VoiceState.key
This commit is contained in:
parent
d47d79977b
commit
c6aea0b7b1
|
|
@ -30,12 +30,12 @@ from logbook import Logger
|
||||||
import earl
|
import earl
|
||||||
|
|
||||||
from litecord.auth import raw_token_check
|
from litecord.auth import raw_token_check
|
||||||
from litecord.enums import RelationshipType, ChannelType, VOICE_CHANNELS
|
from litecord.enums import RelationshipType, ChannelType
|
||||||
from litecord.schemas import validate, GW_STATUS_UPDATE
|
from litecord.schemas import validate, GW_STATUS_UPDATE
|
||||||
from litecord.utils import (
|
from litecord.utils import (
|
||||||
task_wrapper, LitecordJSONEncoder, yield_chunks
|
task_wrapper, LitecordJSONEncoder, yield_chunks
|
||||||
)
|
)
|
||||||
from litecord.permissions import get_permissions, ALL_PERMISSIONS
|
from litecord.permissions import get_permissions
|
||||||
|
|
||||||
from litecord.gateway.opcodes import OP
|
from litecord.gateway.opcodes import OP
|
||||||
from litecord.gateway.state import GatewayState
|
from litecord.gateway.state import GatewayState
|
||||||
|
|
@ -603,50 +603,10 @@ class GatewayWebsocket:
|
||||||
# setting new presence to state
|
# setting new presence to state
|
||||||
await self.update_status(presence)
|
await self.update_status(presence)
|
||||||
|
|
||||||
@property
|
def voice_key(self, channel_id: int, guild_id: int):
|
||||||
def voice_key(self):
|
|
||||||
"""Voice state key."""
|
"""Voice state key."""
|
||||||
return (self.state.user_id, self.state.session_id)
|
return (self.state.user_id, self.state.session_id)
|
||||||
|
|
||||||
async def _voice_check(self, guild_id: int, channel_id: int):
|
|
||||||
"""Check if the user can join the given guild/channel pair."""
|
|
||||||
guild = None
|
|
||||||
if guild_id:
|
|
||||||
guild = await self.storage.get_guild(guild_id)
|
|
||||||
|
|
||||||
channel = await self.storage.get_channel(channel_id)
|
|
||||||
ctype = ChannelType(channel['type'])
|
|
||||||
|
|
||||||
if ctype not in VOICE_CHANNELS:
|
|
||||||
return
|
|
||||||
|
|
||||||
if guild and channel.get(['guild_id']) != guild['id']:
|
|
||||||
return
|
|
||||||
|
|
||||||
is_guild_voice = ctype == ChannelType.GUILD_VOICE
|
|
||||||
|
|
||||||
states = await self.ext.voice.state_count(channel_id)
|
|
||||||
perms = (ALL_PERMISSIONS
|
|
||||||
if not is_guild_voice else
|
|
||||||
await get_permissions(self.state.user_id,
|
|
||||||
channel_id, storage=self.storage)
|
|
||||||
)
|
|
||||||
|
|
||||||
is_full = states >= channel['user_limit']
|
|
||||||
is_bot = self.state.bot
|
|
||||||
|
|
||||||
is_manager = perms.bits.manage_channels
|
|
||||||
|
|
||||||
# if the channel is full AND:
|
|
||||||
# - user is not a bot
|
|
||||||
# - user is not manage channels
|
|
||||||
# then it fails
|
|
||||||
if not is_bot and not is_manager and is_full:
|
|
||||||
return
|
|
||||||
|
|
||||||
# all checks passed.
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _vsu_get_prop(self, state, data):
|
async def _vsu_get_prop(self, state, data):
|
||||||
"""Get voice state properties from data, fallbacking to
|
"""Get voice state properties from data, fallbacking to
|
||||||
user settings."""
|
user settings."""
|
||||||
|
|
@ -664,52 +624,6 @@ class GatewayWebsocket:
|
||||||
'self_mute': self_mute,
|
'self_mute': self_mute,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _move_voice(self, guild_id, channel_id, state, data):
|
|
||||||
"""Move an existing voice state to the given target."""
|
|
||||||
# first case: consider when the user is leaving the
|
|
||||||
# voice channel.
|
|
||||||
if channel_id is None:
|
|
||||||
return await self.ext.voice.del_state(self.voice_key)
|
|
||||||
|
|
||||||
# second case: an update of voice state while being in
|
|
||||||
# the same channel
|
|
||||||
if channel_id == state.channel_id:
|
|
||||||
# we are moving to the same channel, so a simple update
|
|
||||||
# to the self_deaf / self_mute should suffice.
|
|
||||||
prop = await self._vsu_get_prop(state, data)
|
|
||||||
return await self.ext.voice.update_state(
|
|
||||||
self.voice_key, prop)
|
|
||||||
|
|
||||||
# third case: moving between channels, check if the
|
|
||||||
# user can join the targeted channel first
|
|
||||||
if not await self._voice_check(guild_id, channel_id):
|
|
||||||
return
|
|
||||||
|
|
||||||
# if they can join, move the state to there.
|
|
||||||
# this will delete the old one and construct a new one.
|
|
||||||
await self.ext.voice.move_channels(self.voice_key, channel_id)
|
|
||||||
|
|
||||||
async def _create_voice(self, guild_id, channel_id, _state, data):
|
|
||||||
"""Create a voice state."""
|
|
||||||
|
|
||||||
# if we are trying to create a voice state pointing torwards
|
|
||||||
# nowhere, we ignore it.
|
|
||||||
|
|
||||||
# NOTE: HOWEVER, shouldn't we update the users' settings for
|
|
||||||
# self_mute and self_deaf?
|
|
||||||
if channel_id is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# we ignore the given existing state as it'll be basically
|
|
||||||
# none, lol.
|
|
||||||
|
|
||||||
# check if we can join the channel
|
|
||||||
if not await self._voice_check(guild_id, channel_id):
|
|
||||||
return
|
|
||||||
|
|
||||||
# if yes, create the state
|
|
||||||
await self.ext.voice.create_state(self.voice_key, channel_id, data)
|
|
||||||
|
|
||||||
async def handle_4(self, payload: Dict[str, Any]):
|
async def handle_4(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 4 Voice Status Update."""
|
"""Handle OP 4 Voice Status Update."""
|
||||||
data = payload['d']
|
data = payload['d']
|
||||||
|
|
@ -720,11 +634,54 @@ class GatewayWebsocket:
|
||||||
channel_id = int_(data.get('channel_id'))
|
channel_id = int_(data.get('channel_id'))
|
||||||
guild_id = int_(data.get('guild_id'))
|
guild_id = int_(data.get('guild_id'))
|
||||||
|
|
||||||
# fetch an existing voice state
|
# if its null and null, disconnect the user from any voice
|
||||||
voice_state = await self.ext.voice.get_state(self.voice_key)
|
# TODO: maybe just leave from DMs? idk...
|
||||||
|
if channel_id is None and guild_id is None:
|
||||||
|
await self.ext.voice.leave_all(self.state.user_id)
|
||||||
|
|
||||||
func = self._move_voice if voice_state else self._create_voice
|
# if guild is not none but channel is, we are leaving
|
||||||
await func(guild_id, channel_id, voice_state, data)
|
# a guild's channel
|
||||||
|
if channel_id is None:
|
||||||
|
await self.ext.voice.leave(guild_id, self.state.user_id)
|
||||||
|
|
||||||
|
# fetch an existing state given user and guild OR user and channel
|
||||||
|
chan_type = ChannelType(
|
||||||
|
await self.storage.get_chan_type(channel_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
state_id2 = channel_id
|
||||||
|
|
||||||
|
if chan_type == ChannelType.GUILD_VOICE:
|
||||||
|
state_id2 = guild_id
|
||||||
|
|
||||||
|
# a voice state key is a Tuple[int, int]
|
||||||
|
# - [0] is the user id
|
||||||
|
# - [1] is the channel id or guild id
|
||||||
|
|
||||||
|
# the old approach was a (user_id, session_id), but
|
||||||
|
# that does not work.
|
||||||
|
|
||||||
|
# this works since users can be connected to many channels
|
||||||
|
# using a single gateway websocket connection. HOWEVER,
|
||||||
|
# they CAN NOT enter two channels in a single guild.
|
||||||
|
|
||||||
|
# this state id format takes care of that.
|
||||||
|
voice_key = (self.state.user_id, state_id2)
|
||||||
|
voice_state = await self.ext.voice.get_state(voice_key)
|
||||||
|
|
||||||
|
if voice_state is None:
|
||||||
|
await self.ext.voice.create_state(voice_key)
|
||||||
|
|
||||||
|
same_guild = guild_id == voice_state.guild_id
|
||||||
|
same_channel = channel_id == voice_state.channel_id
|
||||||
|
|
||||||
|
prop = await self._vsu_get_prop(voice_state, data)
|
||||||
|
|
||||||
|
if same_guild and same_channel:
|
||||||
|
await self.ext.voice.update_state(voice_state, prop)
|
||||||
|
|
||||||
|
if same_guild and not same_channel:
|
||||||
|
await self.ext.voice.move_state(voice_state, channel_id)
|
||||||
|
|
||||||
async def _handle_5(self, payload: Dict[str, Any]):
|
async def _handle_5(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 5 Voice Server Ping.
|
"""Handle OP 5 Voice Server Ping.
|
||||||
|
|
|
||||||
|
|
@ -943,7 +943,7 @@ class Storage:
|
||||||
|
|
||||||
return dm_chan
|
return dm_chan
|
||||||
|
|
||||||
async def guild_from_channel(self, channel_id: int):
|
async def guild_from_channel(self, channel_id: int) -> int:
|
||||||
"""Get the guild id coming from a channel id."""
|
"""Get the guild id coming from a channel id."""
|
||||||
return await self.db.fetchval("""
|
return await self.db.fetchval("""
|
||||||
SELECT guild_id
|
SELECT guild_id
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,8 @@ from dataclasses import fields
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
|
from litecord.permissions import get_permissions
|
||||||
|
from litecord.enums import ChannelType, VOICE_CHANNELS
|
||||||
from litecord.voice.state import VoiceState
|
from litecord.voice.state import VoiceState
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -43,53 +45,95 @@ class VoiceManager:
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
self.app = app
|
self.app = app
|
||||||
|
|
||||||
|
# double dict, first key is guild/channel id, second key is user id
|
||||||
self.states = defaultdict(dict)
|
self.states = defaultdict(dict)
|
||||||
|
|
||||||
# TODO: hold voice server LVSP connections
|
# TODO: hold voice server LVSP connections
|
||||||
# TODO: map channel ids to voice servers
|
# TODO: map channel ids to voice servers
|
||||||
|
|
||||||
|
async def can_join(self, user_id: int, channel_id: int) -> int:
|
||||||
|
"""Return if a user can join a channel."""
|
||||||
|
|
||||||
|
channel = await self.app.storage.get_channel(channel_id)
|
||||||
|
ctype = ChannelType(channel['type'])
|
||||||
|
|
||||||
|
if ctype not in VOICE_CHANNELS:
|
||||||
|
return
|
||||||
|
|
||||||
|
states = await self.app.voice.state_count(channel_id)
|
||||||
|
|
||||||
|
# get_permissions returns ALL_PERMISSIONS when
|
||||||
|
# the channel isn't from a guild
|
||||||
|
perms = await get_permissions(
|
||||||
|
user_id, channel_id, storage=self.app.storage
|
||||||
|
)
|
||||||
|
|
||||||
|
# hacky user_limit but should work, as channels not
|
||||||
|
# in guilds won't have that field.
|
||||||
|
is_full = states >= channel.get('user_limit', 100)
|
||||||
|
is_bot = (await self.app.storage.get_user(user_id))['bot']
|
||||||
|
is_manager = perms.bits.manage_channels
|
||||||
|
|
||||||
|
# if the channel is full AND:
|
||||||
|
# - user is not a bot
|
||||||
|
# - user is not manage channels
|
||||||
|
# then it fails
|
||||||
|
if not is_bot and not is_manager and is_full:
|
||||||
|
return
|
||||||
|
|
||||||
|
# all good
|
||||||
|
return True
|
||||||
|
|
||||||
async def state_count(self, channel_id: int) -> int:
|
async def state_count(self, channel_id: int) -> int:
|
||||||
"""Get the current amount of voice states in a channel."""
|
"""Get the current amount of voice states in a channel."""
|
||||||
return len(self.states[channel_id])
|
return len(self.states[channel_id])
|
||||||
|
|
||||||
async def fetch_states(self, channel_id: int) -> Dict[int, VoiceState]:
|
async def fetch_states(self, channel_id: int) -> Dict[int, VoiceState]:
|
||||||
"""Fetch the states of the given channel."""
|
"""Fetch the states of the given channel."""
|
||||||
# NOTE: maybe we *could* optimize by just returning a reference to the
|
# since the state key is (user_id, guild_id | channel_id), we need
|
||||||
# states dict instead of calling dict()...
|
# to determine which kind of search we want to do.
|
||||||
|
guild_id = await self.app.storage.guild_from_channel(channel_id)
|
||||||
|
|
||||||
# however I'm really worried about state inconsistencies caused
|
# if there isn't a guild for the channel, it is a dm or group dm.
|
||||||
# by this, so i'll just use dict().
|
# those are simple to handle.
|
||||||
return dict(self.states[channel_id])
|
if not guild_id:
|
||||||
|
return dict(self.states[channel_id])
|
||||||
|
|
||||||
|
# guild states hold a dict mapping user ids to guild states,
|
||||||
|
# same as channels, thats the structure.
|
||||||
|
guild_states = self.states[guild_id]
|
||||||
|
res = {}
|
||||||
|
|
||||||
|
# iterate over all users with states and add the channel matches
|
||||||
|
# into res
|
||||||
|
for user_id, state in guild_states.items():
|
||||||
|
if state.channel_id == channel_id:
|
||||||
|
res[user_id] = state
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
async def get_state(self, voice_key: VoiceKey) -> VoiceState:
|
async def get_state(self, voice_key: VoiceKey) -> VoiceState:
|
||||||
"""Get a single VoiceState for a user in a channel. Returns None
|
"""Get a single VoiceState for a user in a channel. Returns None
|
||||||
if no VoiceState is found."""
|
if no VoiceState is found."""
|
||||||
channel_id, user_id = voice_key
|
user_id, sec_key_id = voice_key
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return self.states[channel_id][user_id]
|
return self.states[sec_key_id][user_id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def del_state(self, voice_key: VoiceKey):
|
async def del_state(self, voice_key: VoiceKey):
|
||||||
"""Delete a given voice state."""
|
"""Delete a given voice state."""
|
||||||
chan_id, user_id = voice_key
|
user_id, sec_key_id = voice_key
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: tell that to the voice server of the channel.
|
# TODO: tell that to the voice server of the channel.
|
||||||
self.states[chan_id].pop(user_id)
|
self.states[sec_key_id].pop(user_id)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def update_state(self, voice_key: VoiceKey, prop: dict):
|
async def update_state(self, state: VoiceState, prop: dict):
|
||||||
"""Update a state in a channel"""
|
"""Update a state in a channel"""
|
||||||
chan_id, user_id = voice_key
|
|
||||||
|
|
||||||
try:
|
|
||||||
state = self.states[chan_id][user_id]
|
|
||||||
except KeyError:
|
|
||||||
return
|
|
||||||
|
|
||||||
# construct a new state based on the old one + properties
|
# construct a new state based on the old one + properties
|
||||||
new_state_dict = dict(state.as_json)
|
new_state_dict = dict(state.as_json)
|
||||||
|
|
||||||
|
|
@ -103,7 +147,7 @@ class VoiceManager:
|
||||||
new_state = _construct_state(new_state_dict)
|
new_state = _construct_state(new_state_dict)
|
||||||
|
|
||||||
# TODO: dispatch to voice server
|
# TODO: dispatch to voice server
|
||||||
self.states[chan_id][user_id] = new_state
|
self.states[state.key][state.user_id] = new_state
|
||||||
|
|
||||||
async def move_channels(self, old_voice_key: VoiceKey, channel_id: int):
|
async def move_channels(self, old_voice_key: VoiceKey, channel_id: int):
|
||||||
"""Move a user between channels."""
|
"""Move a user between channels."""
|
||||||
|
|
@ -113,3 +157,26 @@ class VoiceManager:
|
||||||
async def create_state(self, voice_key: VoiceKey, channel_id: int,
|
async def create_state(self, voice_key: VoiceKey, channel_id: int,
|
||||||
data: dict):
|
data: dict):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def leave_all(self, user_id: int) -> int:
|
||||||
|
"""Leave all voice channels."""
|
||||||
|
|
||||||
|
# iterate over every state finding matches
|
||||||
|
|
||||||
|
# NOTE: we copy the current states dict since we're modifying
|
||||||
|
# on iteration. this is SLOW.
|
||||||
|
|
||||||
|
# TODO: better solution instead of copying, maybe we can generate
|
||||||
|
# a list of tasks to run that actually do the deletion by themselves
|
||||||
|
# instead of us generating a delete. then only start running them later
|
||||||
|
# on.
|
||||||
|
for sec_key_id, states in dict(self.states):
|
||||||
|
for state in states:
|
||||||
|
if state.user_id != user_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
await self.del_state((user_id, sec_key_id))
|
||||||
|
|
||||||
|
async def leave(self, guild_id: int, user_id: int):
|
||||||
|
"""Make a user leave a channel IN A GUILD."""
|
||||||
|
await self.del_state((guild_id, user_id))
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ from dataclasses import dataclass, asdict
|
||||||
@dataclass
|
@dataclass
|
||||||
class VoiceState:
|
class VoiceState:
|
||||||
"""Represents a voice state."""
|
"""Represents a voice state."""
|
||||||
|
guild_id: int
|
||||||
channel_id: int
|
channel_id: int
|
||||||
user_id: int
|
user_id: int
|
||||||
session_id: str
|
session_id: str
|
||||||
|
|
@ -32,6 +33,11 @@ class VoiceState:
|
||||||
self_mute: bool
|
self_mute: bool
|
||||||
suppressed_by: int
|
suppressed_by: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def key(self):
|
||||||
|
"""Get the second part of a key identifying a state."""
|
||||||
|
return self.channel_id if self.guild_id is None else self.guild_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def as_json(self):
|
def as_json(self):
|
||||||
"""Return JSON-serializable dict."""
|
"""Return JSON-serializable dict."""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue