diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py
index e6fa7d4..d17cf27 100644
--- a/litecord/gateway/websocket.py
+++ b/litecord/gateway/websocket.py
@@ -670,7 +670,7 @@ class GatewayWebsocket:
voice_state = await self.ext.voice.get_state(voice_key)
if voice_state is None:
- await self.ext.voice.create_state(voice_key)
+ return await self.ext.voice.create_state(voice_key)
same_guild = guild_id == voice_state.guild_id
same_channel = channel_id == voice_state.channel_id
@@ -678,10 +678,10 @@ class GatewayWebsocket:
prop = await self._vsu_get_prop(voice_state, data)
if same_guild and same_channel:
- await self.ext.voice.update_state(voice_state, prop)
+ return await self.ext.voice.update_state(voice_state, prop)
if same_guild and not same_channel:
- await self.ext.voice.move_state(voice_state, channel_id)
+ return await self.ext.voice.move_state(voice_state, channel_id)
async def _handle_5(self, payload: Dict[str, Any]):
"""Handle OP 5 Voice Server Ping.
diff --git a/litecord/permissions.py b/litecord/permissions.py
index 4e2a4b6..4b062e8 100644
--- a/litecord/permissions.py
+++ b/litecord/permissions.py
@@ -18,6 +18,7 @@ along with this program. If not, see .
"""
import ctypes
+from typing import Optional
from quart import current_app as app
diff --git a/litecord/utils.py b/litecord/utils.py
index 0145825..513fafa 100644
--- a/litecord/utils.py
+++ b/litecord/utils.py
@@ -19,7 +19,7 @@ along with this program. If not, see .
import asyncio
import json
-from typing import Any, Iterable, Optional, Indexable
+from typing import Any, Iterable, Optional, Sequence
from logbook import Logger
from quart.json import JSONEncoder
@@ -27,7 +27,7 @@ from quart.json import JSONEncoder
log = Logger(__name__)
-async def async_map(function, iterable) -> list:
+async def async_map(function, iterable: Iterable) -> list:
"""Map a coroutine to an iterable."""
res = []
@@ -52,7 +52,7 @@ def dict_get(mapping, key, default):
return mapping.get(key) or default
-def index_by_func(function, indexable: Indexable) -> Optional[int]:
+def index_by_func(function, indexable: Sequence[Any]) -> Optional[int]:
"""Search in an idexable and return the index number
for an iterm that has func(item) = True."""
for index, item in enumerate(indexable):
@@ -161,7 +161,7 @@ async def pg_set_json(con):
)
-def yield_chunks(input_list: Iterable, chunk_size: int):
+def yield_chunks(input_list: Sequence[Any], chunk_size: int):
"""Yield successive n-sized chunks from l.
Taken from https://stackoverflow.com/a/312464.
diff --git a/litecord/voice/lvsp_conn.py b/litecord/voice/lvsp_conn.py
index a350737..143ba97 100644
--- a/litecord/voice/lvsp_conn.py
+++ b/litecord/voice/lvsp_conn.py
@@ -38,6 +38,7 @@ class LVSPConnection:
self.hostname = hostname
self.conn = None
+ self.health = 0.5
self._hb_task = None
self._hb_interval = None
@@ -98,6 +99,8 @@ class LVSPConnection:
async def _update_health(self, new_health: float):
"""Update the health value of a given voice server."""
+ self.health = new_health
+
await self.app.db.execute("""
UPDATE voice_servers
SET health = $1
diff --git a/litecord/voice/lvsp_manager.py b/litecord/voice/lvsp_manager.py
index 2182859..a40ff83 100644
--- a/litecord/voice/lvsp_manager.py
+++ b/litecord/voice/lvsp_manager.py
@@ -17,6 +17,7 @@ along with this program. If not, see .
"""
+from typing import Optional
from collections import defaultdict
from logbook import Logger
@@ -34,7 +35,15 @@ class LVSPManager:
self.app = app
self.voice = voice
- self.servers = defaultdict(dict)
+ # map servers to LVSPConnection
+ self.conns = {}
+
+ # maps regions to server hostnames
+ self.servers = defaultdict(list)
+
+ # maps guilds to server hostnames
+ self.guild_servers = {}
+
self.app.loop.create_task(self._spawn())
async def _spawn(self):
@@ -71,10 +80,11 @@ class LVSPManager:
return
servers = [r['hostname'] for r in servers]
+ self.servers[region] = servers
for hostname in servers:
conn = LVSPConnection(self, region, hostname)
- self.servers[region][hostname] = conn
+ self.conns[hostname] = conn
self.app.loop.create_task(
conn.run()
@@ -83,6 +93,47 @@ class LVSPManager:
async def del_conn(self, conn):
"""Delete a connection from the connection pool."""
try:
- self.servers[conn.region].pop(conn.hostname)
+ self.servers[conn.region].remove(conn.hostname)
except KeyError:
pass
+
+ try:
+ self.conns.pop(conn.hostname)
+ except KeyError:
+ pass
+
+ async def guild_region(self, guild_id: int) -> Optional[str]:
+ """Return the voice region of a guild."""
+ return await self.app.db.fetchval("""
+ SELECT region
+ FROM guilds
+ WHERE id = $1
+ """, guild_id)
+
+ def get_health(self, hostname: str) -> float:
+ """Get voice server health, given hostname."""
+ try:
+ conn = self.conns[hostname]
+ except KeyError:
+ return -1
+
+ return conn.health
+
+ async def get_server(self, guild_id: int) -> str:
+ """Get a voice server for the given guild, assigns
+ one if there isn't any."""
+
+ try:
+ hostname = self.guild_servers[guild_id]
+ except KeyError:
+ region = await self.guild_region(guild_id)
+
+ # sort connected servers by health
+ sorted_servers = sorted(
+ self.servers[region],
+ self.get_health,
+ )
+
+ hostname = sorted_servers[0]
+
+ return hostname
diff --git a/litecord/voice/manager.py b/litecord/voice/manager.py
index 535b6b2..a408909 100644
--- a/litecord/voice/manager.py
+++ b/litecord/voice/manager.py
@@ -153,11 +153,72 @@ class VoiceManager:
async def move_channels(self, old_voice_key: VoiceKey, channel_id: int):
"""Move a user between channels."""
await self.del_state(old_voice_key)
- await self.create_state(old_voice_key, channel_id, {})
+ await self.create_state(old_voice_key, {'channel_id': channel_id})
- async def create_state(self, voice_key: VoiceKey, channel_id: int,
- data: dict):
- pass
+ async def _create_ctx_guild(self, guild_id, channel_id):
+ # get a voice server
+ server = await self.lvsp.get_server(guild_id)
+ conn = self.lvsp.get_conn(server)
+ chan = await self.app.storage.get_channel(channel_id)
+
+ # TODO: this, but properly
+ # TODO: when the server sends a reply to CHAN_REQ, we need to update
+ # LVSPManager.guild_servers.
+ await conn.send_info('CHAN_REQ', {
+ 'guild_id': str(guild_id),
+ 'channel_id': str(channel_id),
+
+ 'channel_properties': {
+ 'bitrate': chan['bitrate']
+ }
+ })
+
+ async def _start_voice_guild(self, voice_key: VoiceKey, data: dict):
+ """Start a voice context in a guild."""
+ user_id, guild_id = voice_key
+ channel_id = int(data['channel_id'])
+
+ existing_states = self.states[voice_key]
+ channel_exists = any(
+ state.channel_id == channel_id for state in existing_states)
+
+ if not channel_exists:
+ await self._create_ctx_guild(guild_id, channel_id)
+
+ async def create_state(self, voice_key: VoiceKey, data: dict):
+ """Creates (or tries to create) a voice state.
+
+ Depending on the VoiceKey given, it will use the guild's voice
+ region or assign one based on the starter of a call, or the owner of
+ a Group DM.
+
+ Once a region is assigned, it'll choose the best voice server
+ and send a request to it.
+ """
+
+ # TODO: handle CALL events.
+
+ # compare if this voice key is for a guild or a channel
+ _uid, id2 = voice_key
+ guild = await self.app.storage.get_guild(id2)
+
+ # if guild not found, then we are dealing with a dm or group dm
+ if not guild:
+ ctype = await self.app.storage.get_chan_type(id2)
+ ctype = ChannelType(ctype)
+
+ if ctype == ChannelType.GROUP_DM:
+ # await self._start_voice_dm()
+ pass
+ elif ctype == ChannelType.DM:
+ # await self._start_voice_gdm()
+ pass
+
+ return
+
+ # if guild found, then data.channel_id exists, and we treat it
+ # as a guild
+ # await self._start_voice_guild()
async def leave_all(self, user_id: int) -> int:
"""Leave all voice channels."""