mirror of https://gitlab.com/litecord/litecord.git
voice: more voice goodies
- lvsp manager: change internal structure of lvsp conns - voice.manager: add incomplete impl for creating a channel
This commit is contained in:
parent
9dab5b20ae
commit
e2af6b6370
|
|
@ -670,7 +670,7 @@ class GatewayWebsocket:
|
||||||
voice_state = await self.ext.voice.get_state(voice_key)
|
voice_state = await self.ext.voice.get_state(voice_key)
|
||||||
|
|
||||||
if voice_state is None:
|
if voice_state is None:
|
||||||
await self.ext.voice.create_state(voice_key)
|
return await self.ext.voice.create_state(voice_key)
|
||||||
|
|
||||||
same_guild = guild_id == voice_state.guild_id
|
same_guild = guild_id == voice_state.guild_id
|
||||||
same_channel = channel_id == voice_state.channel_id
|
same_channel = channel_id == voice_state.channel_id
|
||||||
|
|
@ -678,10 +678,10 @@ class GatewayWebsocket:
|
||||||
prop = await self._vsu_get_prop(voice_state, data)
|
prop = await self._vsu_get_prop(voice_state, data)
|
||||||
|
|
||||||
if same_guild and same_channel:
|
if same_guild and same_channel:
|
||||||
await self.ext.voice.update_state(voice_state, prop)
|
return await self.ext.voice.update_state(voice_state, prop)
|
||||||
|
|
||||||
if same_guild and not same_channel:
|
if same_guild and not same_channel:
|
||||||
await self.ext.voice.move_state(voice_state, channel_id)
|
return await self.ext.voice.move_state(voice_state, channel_id)
|
||||||
|
|
||||||
async def _handle_5(self, payload: Dict[str, Any]):
|
async def _handle_5(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 5 Voice Server Ping.
|
"""Handle OP 5 Voice Server Ping.
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from quart import current_app as app
|
from quart import current_app as app
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from typing import Any, Iterable, Optional, Indexable
|
from typing import Any, Iterable, Optional, Sequence
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
from quart.json import JSONEncoder
|
from quart.json import JSONEncoder
|
||||||
|
|
@ -27,7 +27,7 @@ from quart.json import JSONEncoder
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def async_map(function, iterable) -> list:
|
async def async_map(function, iterable: Iterable) -> list:
|
||||||
"""Map a coroutine to an iterable."""
|
"""Map a coroutine to an iterable."""
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
|
|
@ -52,7 +52,7 @@ def dict_get(mapping, key, default):
|
||||||
return mapping.get(key) or default
|
return mapping.get(key) or default
|
||||||
|
|
||||||
|
|
||||||
def index_by_func(function, indexable: Indexable) -> Optional[int]:
|
def index_by_func(function, indexable: Sequence[Any]) -> Optional[int]:
|
||||||
"""Search in an idexable and return the index number
|
"""Search in an idexable and return the index number
|
||||||
for an iterm that has func(item) = True."""
|
for an iterm that has func(item) = True."""
|
||||||
for index, item in enumerate(indexable):
|
for index, item in enumerate(indexable):
|
||||||
|
|
@ -161,7 +161,7 @@ async def pg_set_json(con):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def yield_chunks(input_list: Iterable, chunk_size: int):
|
def yield_chunks(input_list: Sequence[Any], chunk_size: int):
|
||||||
"""Yield successive n-sized chunks from l.
|
"""Yield successive n-sized chunks from l.
|
||||||
|
|
||||||
Taken from https://stackoverflow.com/a/312464.
|
Taken from https://stackoverflow.com/a/312464.
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ class LVSPConnection:
|
||||||
self.hostname = hostname
|
self.hostname = hostname
|
||||||
|
|
||||||
self.conn = None
|
self.conn = None
|
||||||
|
self.health = 0.5
|
||||||
|
|
||||||
self._hb_task = None
|
self._hb_task = None
|
||||||
self._hb_interval = None
|
self._hb_interval = None
|
||||||
|
|
@ -98,6 +99,8 @@ class LVSPConnection:
|
||||||
|
|
||||||
async def _update_health(self, new_health: float):
|
async def _update_health(self, new_health: float):
|
||||||
"""Update the health value of a given voice server."""
|
"""Update the health value of a given voice server."""
|
||||||
|
self.health = new_health
|
||||||
|
|
||||||
await self.app.db.execute("""
|
await self.app.db.execute("""
|
||||||
UPDATE voice_servers
|
UPDATE voice_servers
|
||||||
SET health = $1
|
SET health = $1
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
@ -34,7 +35,15 @@ class LVSPManager:
|
||||||
self.app = app
|
self.app = app
|
||||||
self.voice = voice
|
self.voice = voice
|
||||||
|
|
||||||
self.servers = defaultdict(dict)
|
# map servers to LVSPConnection
|
||||||
|
self.conns = {}
|
||||||
|
|
||||||
|
# maps regions to server hostnames
|
||||||
|
self.servers = defaultdict(list)
|
||||||
|
|
||||||
|
# maps guilds to server hostnames
|
||||||
|
self.guild_servers = {}
|
||||||
|
|
||||||
self.app.loop.create_task(self._spawn())
|
self.app.loop.create_task(self._spawn())
|
||||||
|
|
||||||
async def _spawn(self):
|
async def _spawn(self):
|
||||||
|
|
@ -71,10 +80,11 @@ class LVSPManager:
|
||||||
return
|
return
|
||||||
|
|
||||||
servers = [r['hostname'] for r in servers]
|
servers = [r['hostname'] for r in servers]
|
||||||
|
self.servers[region] = servers
|
||||||
|
|
||||||
for hostname in servers:
|
for hostname in servers:
|
||||||
conn = LVSPConnection(self, region, hostname)
|
conn = LVSPConnection(self, region, hostname)
|
||||||
self.servers[region][hostname] = conn
|
self.conns[hostname] = conn
|
||||||
|
|
||||||
self.app.loop.create_task(
|
self.app.loop.create_task(
|
||||||
conn.run()
|
conn.run()
|
||||||
|
|
@ -83,6 +93,47 @@ class LVSPManager:
|
||||||
async def del_conn(self, conn):
|
async def del_conn(self, conn):
|
||||||
"""Delete a connection from the connection pool."""
|
"""Delete a connection from the connection pool."""
|
||||||
try:
|
try:
|
||||||
self.servers[conn.region].pop(conn.hostname)
|
self.servers[conn.region].remove(conn.hostname)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.conns.pop(conn.hostname)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def guild_region(self, guild_id: int) -> Optional[str]:
|
||||||
|
"""Return the voice region of a guild."""
|
||||||
|
return await self.app.db.fetchval("""
|
||||||
|
SELECT region
|
||||||
|
FROM guilds
|
||||||
|
WHERE id = $1
|
||||||
|
""", guild_id)
|
||||||
|
|
||||||
|
def get_health(self, hostname: str) -> float:
|
||||||
|
"""Get voice server health, given hostname."""
|
||||||
|
try:
|
||||||
|
conn = self.conns[hostname]
|
||||||
|
except KeyError:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
return conn.health
|
||||||
|
|
||||||
|
async def get_server(self, guild_id: int) -> str:
|
||||||
|
"""Get a voice server for the given guild, assigns
|
||||||
|
one if there isn't any."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
hostname = self.guild_servers[guild_id]
|
||||||
|
except KeyError:
|
||||||
|
region = await self.guild_region(guild_id)
|
||||||
|
|
||||||
|
# sort connected servers by health
|
||||||
|
sorted_servers = sorted(
|
||||||
|
self.servers[region],
|
||||||
|
self.get_health,
|
||||||
|
)
|
||||||
|
|
||||||
|
hostname = sorted_servers[0]
|
||||||
|
|
||||||
|
return hostname
|
||||||
|
|
|
||||||
|
|
@ -153,11 +153,72 @@ class VoiceManager:
|
||||||
async def move_channels(self, old_voice_key: VoiceKey, channel_id: int):
|
async def move_channels(self, old_voice_key: VoiceKey, channel_id: int):
|
||||||
"""Move a user between channels."""
|
"""Move a user between channels."""
|
||||||
await self.del_state(old_voice_key)
|
await self.del_state(old_voice_key)
|
||||||
await self.create_state(old_voice_key, channel_id, {})
|
await self.create_state(old_voice_key, {'channel_id': channel_id})
|
||||||
|
|
||||||
async def create_state(self, voice_key: VoiceKey, channel_id: int,
|
async def _create_ctx_guild(self, guild_id, channel_id):
|
||||||
data: dict):
|
# get a voice server
|
||||||
|
server = await self.lvsp.get_server(guild_id)
|
||||||
|
conn = self.lvsp.get_conn(server)
|
||||||
|
chan = await self.app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
|
# TODO: this, but properly
|
||||||
|
# TODO: when the server sends a reply to CHAN_REQ, we need to update
|
||||||
|
# LVSPManager.guild_servers.
|
||||||
|
await conn.send_info('CHAN_REQ', {
|
||||||
|
'guild_id': str(guild_id),
|
||||||
|
'channel_id': str(channel_id),
|
||||||
|
|
||||||
|
'channel_properties': {
|
||||||
|
'bitrate': chan['bitrate']
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
async def _start_voice_guild(self, voice_key: VoiceKey, data: dict):
|
||||||
|
"""Start a voice context in a guild."""
|
||||||
|
user_id, guild_id = voice_key
|
||||||
|
channel_id = int(data['channel_id'])
|
||||||
|
|
||||||
|
existing_states = self.states[voice_key]
|
||||||
|
channel_exists = any(
|
||||||
|
state.channel_id == channel_id for state in existing_states)
|
||||||
|
|
||||||
|
if not channel_exists:
|
||||||
|
await self._create_ctx_guild(guild_id, channel_id)
|
||||||
|
|
||||||
|
async def create_state(self, voice_key: VoiceKey, data: dict):
|
||||||
|
"""Creates (or tries to create) a voice state.
|
||||||
|
|
||||||
|
Depending on the VoiceKey given, it will use the guild's voice
|
||||||
|
region or assign one based on the starter of a call, or the owner of
|
||||||
|
a Group DM.
|
||||||
|
|
||||||
|
Once a region is assigned, it'll choose the best voice server
|
||||||
|
and send a request to it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: handle CALL events.
|
||||||
|
|
||||||
|
# compare if this voice key is for a guild or a channel
|
||||||
|
_uid, id2 = voice_key
|
||||||
|
guild = await self.app.storage.get_guild(id2)
|
||||||
|
|
||||||
|
# if guild not found, then we are dealing with a dm or group dm
|
||||||
|
if not guild:
|
||||||
|
ctype = await self.app.storage.get_chan_type(id2)
|
||||||
|
ctype = ChannelType(ctype)
|
||||||
|
|
||||||
|
if ctype == ChannelType.GROUP_DM:
|
||||||
|
# await self._start_voice_dm()
|
||||||
pass
|
pass
|
||||||
|
elif ctype == ChannelType.DM:
|
||||||
|
# await self._start_voice_gdm()
|
||||||
|
pass
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
# if guild found, then data.channel_id exists, and we treat it
|
||||||
|
# as a guild
|
||||||
|
# await self._start_voice_guild()
|
||||||
|
|
||||||
async def leave_all(self, user_id: int) -> int:
|
async def leave_all(self, user_id: int) -> int:
|
||||||
"""Leave all voice channels."""
|
"""Leave all voice channels."""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue