diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index e6fa7d4..d17cf27 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -670,7 +670,7 @@ class GatewayWebsocket: voice_state = await self.ext.voice.get_state(voice_key) if voice_state is None: - await self.ext.voice.create_state(voice_key) + return await self.ext.voice.create_state(voice_key) same_guild = guild_id == voice_state.guild_id same_channel = channel_id == voice_state.channel_id @@ -678,10 +678,10 @@ class GatewayWebsocket: prop = await self._vsu_get_prop(voice_state, data) if same_guild and same_channel: - await self.ext.voice.update_state(voice_state, prop) + return await self.ext.voice.update_state(voice_state, prop) if same_guild and not same_channel: - await self.ext.voice.move_state(voice_state, channel_id) + return await self.ext.voice.move_state(voice_state, channel_id) async def _handle_5(self, payload: Dict[str, Any]): """Handle OP 5 Voice Server Ping. diff --git a/litecord/permissions.py b/litecord/permissions.py index 4e2a4b6..4b062e8 100644 --- a/litecord/permissions.py +++ b/litecord/permissions.py @@ -18,6 +18,7 @@ along with this program. If not, see . """ import ctypes +from typing import Optional from quart import current_app as app diff --git a/litecord/utils.py b/litecord/utils.py index 0145825..513fafa 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -19,7 +19,7 @@ along with this program. If not, see . import asyncio import json -from typing import Any, Iterable, Optional, Indexable +from typing import Any, Iterable, Optional, Sequence from logbook import Logger from quart.json import JSONEncoder @@ -27,7 +27,7 @@ from quart.json import JSONEncoder log = Logger(__name__) -async def async_map(function, iterable) -> list: +async def async_map(function, iterable: Iterable) -> list: """Map a coroutine to an iterable.""" res = [] @@ -52,7 +52,7 @@ def dict_get(mapping, key, default): return mapping.get(key) or default -def index_by_func(function, indexable: Indexable) -> Optional[int]: +def index_by_func(function, indexable: Sequence[Any]) -> Optional[int]: """Search in an idexable and return the index number for an iterm that has func(item) = True.""" for index, item in enumerate(indexable): @@ -161,7 +161,7 @@ async def pg_set_json(con): ) -def yield_chunks(input_list: Iterable, chunk_size: int): +def yield_chunks(input_list: Sequence[Any], chunk_size: int): """Yield successive n-sized chunks from l. Taken from https://stackoverflow.com/a/312464. diff --git a/litecord/voice/lvsp_conn.py b/litecord/voice/lvsp_conn.py index a350737..143ba97 100644 --- a/litecord/voice/lvsp_conn.py +++ b/litecord/voice/lvsp_conn.py @@ -38,6 +38,7 @@ class LVSPConnection: self.hostname = hostname self.conn = None + self.health = 0.5 self._hb_task = None self._hb_interval = None @@ -98,6 +99,8 @@ class LVSPConnection: async def _update_health(self, new_health: float): """Update the health value of a given voice server.""" + self.health = new_health + await self.app.db.execute(""" UPDATE voice_servers SET health = $1 diff --git a/litecord/voice/lvsp_manager.py b/litecord/voice/lvsp_manager.py index 2182859..a40ff83 100644 --- a/litecord/voice/lvsp_manager.py +++ b/litecord/voice/lvsp_manager.py @@ -17,6 +17,7 @@ along with this program. If not, see . """ +from typing import Optional from collections import defaultdict from logbook import Logger @@ -34,7 +35,15 @@ class LVSPManager: self.app = app self.voice = voice - self.servers = defaultdict(dict) + # map servers to LVSPConnection + self.conns = {} + + # maps regions to server hostnames + self.servers = defaultdict(list) + + # maps guilds to server hostnames + self.guild_servers = {} + self.app.loop.create_task(self._spawn()) async def _spawn(self): @@ -71,10 +80,11 @@ class LVSPManager: return servers = [r['hostname'] for r in servers] + self.servers[region] = servers for hostname in servers: conn = LVSPConnection(self, region, hostname) - self.servers[region][hostname] = conn + self.conns[hostname] = conn self.app.loop.create_task( conn.run() @@ -83,6 +93,47 @@ class LVSPManager: async def del_conn(self, conn): """Delete a connection from the connection pool.""" try: - self.servers[conn.region].pop(conn.hostname) + self.servers[conn.region].remove(conn.hostname) except KeyError: pass + + try: + self.conns.pop(conn.hostname) + except KeyError: + pass + + async def guild_region(self, guild_id: int) -> Optional[str]: + """Return the voice region of a guild.""" + return await self.app.db.fetchval(""" + SELECT region + FROM guilds + WHERE id = $1 + """, guild_id) + + def get_health(self, hostname: str) -> float: + """Get voice server health, given hostname.""" + try: + conn = self.conns[hostname] + except KeyError: + return -1 + + return conn.health + + async def get_server(self, guild_id: int) -> str: + """Get a voice server for the given guild, assigns + one if there isn't any.""" + + try: + hostname = self.guild_servers[guild_id] + except KeyError: + region = await self.guild_region(guild_id) + + # sort connected servers by health + sorted_servers = sorted( + self.servers[region], + self.get_health, + ) + + hostname = sorted_servers[0] + + return hostname diff --git a/litecord/voice/manager.py b/litecord/voice/manager.py index 535b6b2..a408909 100644 --- a/litecord/voice/manager.py +++ b/litecord/voice/manager.py @@ -153,11 +153,72 @@ class VoiceManager: async def move_channels(self, old_voice_key: VoiceKey, channel_id: int): """Move a user between channels.""" await self.del_state(old_voice_key) - await self.create_state(old_voice_key, channel_id, {}) + await self.create_state(old_voice_key, {'channel_id': channel_id}) - async def create_state(self, voice_key: VoiceKey, channel_id: int, - data: dict): - pass + async def _create_ctx_guild(self, guild_id, channel_id): + # get a voice server + server = await self.lvsp.get_server(guild_id) + conn = self.lvsp.get_conn(server) + chan = await self.app.storage.get_channel(channel_id) + + # TODO: this, but properly + # TODO: when the server sends a reply to CHAN_REQ, we need to update + # LVSPManager.guild_servers. + await conn.send_info('CHAN_REQ', { + 'guild_id': str(guild_id), + 'channel_id': str(channel_id), + + 'channel_properties': { + 'bitrate': chan['bitrate'] + } + }) + + async def _start_voice_guild(self, voice_key: VoiceKey, data: dict): + """Start a voice context in a guild.""" + user_id, guild_id = voice_key + channel_id = int(data['channel_id']) + + existing_states = self.states[voice_key] + channel_exists = any( + state.channel_id == channel_id for state in existing_states) + + if not channel_exists: + await self._create_ctx_guild(guild_id, channel_id) + + async def create_state(self, voice_key: VoiceKey, data: dict): + """Creates (or tries to create) a voice state. + + Depending on the VoiceKey given, it will use the guild's voice + region or assign one based on the starter of a call, or the owner of + a Group DM. + + Once a region is assigned, it'll choose the best voice server + and send a request to it. + """ + + # TODO: handle CALL events. + + # compare if this voice key is for a guild or a channel + _uid, id2 = voice_key + guild = await self.app.storage.get_guild(id2) + + # if guild not found, then we are dealing with a dm or group dm + if not guild: + ctype = await self.app.storage.get_chan_type(id2) + ctype = ChannelType(ctype) + + if ctype == ChannelType.GROUP_DM: + # await self._start_voice_dm() + pass + elif ctype == ChannelType.DM: + # await self._start_voice_gdm() + pass + + return + + # if guild found, then data.channel_id exists, and we treat it + # as a guild + # await self._start_voice_guild() async def leave_all(self, user_id: int) -> int: """Leave all voice channels."""