diff --git a/litecord/blueprints/admin_api/voice.py b/litecord/blueprints/admin_api/voice.py
index b6e123a..e700b27 100644
--- a/litecord/blueprints/admin_api/voice.py
+++ b/litecord/blueprints/admin_api/voice.py
@@ -19,12 +19,14 @@ along with this program. If not, see .
import asyncpg
from quart import Blueprint, jsonify, current_app as app, request
+from logbook import Logger
from litecord.auth import admin_check
from litecord.schemas import validate
from litecord.admin_schemas import VOICE_SERVER, VOICE_REGION
from litecord.errors import BadRequest
+log = Logger(__name__)
bp = Blueprint('voice_admin', __name__)
@@ -49,9 +51,20 @@ async def insert_new_region():
VALUES ($1, $2, $3, $4, $5)
""", j['id'], j['name'], j['vip'], j['deprecated'], j['custom'])
- return jsonify(
- await app.storage.all_voice_regions()
- )
+ regions = await app.storage.all_voice_regions()
+ region_count = len(regions)
+
+ # if region count is 1, this is the first region to be created,
+ # so we should update all guilds to that region
+ if region_count == 1:
+ res = await app.db.execute("""
+ UPDATE guilds
+ SET region = $1
+ """, j['id'])
+
+ log.info('updating guilds to first voice region: {}', res)
+
+ return jsonify(regions)
@bp.route('/regions//servers', methods=['PUT'])
@@ -86,3 +99,31 @@ async def deprecate_region(region):
""", region)
return '', 204
+
+
+async def guild_region_check(app_):
+ """Check all guilds for voice region inconsistencies.
+
+ Since the voice migration caused all guilds.region columns
+ to become NULL, we need to remove such NULLs if we have more
+ than one region setup.
+ """
+
+ regions = await app_.storage.all_voice_regions()
+
+ if not regions:
+ log.info('region check: no regions to move guilds to')
+ return
+
+ res = await app_.db.execute("""
+ UPDATE guilds
+ SET region = (
+ SELECT id
+ FROM voice_regions
+ OFFSET floor(random()*$1)
+ LIMIT 1
+ )
+ WHERE region = NULL
+ """, len(regions))
+
+ log.info('region check: updating guild.region=null: {!r}', res)
diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py
index 2922bfb..f900ef6 100644
--- a/litecord/blueprints/channels.py
+++ b/litecord/blueprints/channels.py
@@ -393,7 +393,7 @@ async def _common_guild_chan(channel_id, j: dict):
""", j[field], channel_id)
-async def _update_text_channel(channel_id: int, j: dict):
+async def _update_text_channel(channel_id: int, j: dict, _user_id: int):
# first do the specific ones related to guild_text_channels
for field in [field for field in j.keys()
if field in ('topic', 'rate_limit_per_user')]:
@@ -406,7 +406,7 @@ async def _update_text_channel(channel_id: int, j: dict):
await _common_guild_chan(channel_id, j)
-async def _update_voice_channel(channel_id: int, j: dict):
+async def _update_voice_channel(channel_id: int, j: dict, _user_id: int):
# first do the specific ones in guild_voice_channels
for field in [field for field in j.keys()
if field in ('bitrate', 'user_limit')]:
diff --git a/run.py b/run.py
index ed3cdb9..8dcfc51 100644
--- a/run.py
+++ b/run.py
@@ -60,6 +60,8 @@ from litecord.blueprints.admin_api import (
voice as voice_admin
)
+from litecord.blueprints.admin_api.voice import guild_region_check
+
from litecord.ratelimits.handler import ratelimit_handler
from litecord.ratelimits.main import RatelimitManager
@@ -298,6 +300,7 @@ async def post_app_start(app_):
# we'll need to start a billing job
app_.sched.spawn(payment_job(app_))
app_.sched.spawn(api_index(app_))
+ app_.sched.spawn(guild_region_check(app_))
def start_websocket(host, port, ws_handler) -> asyncio.Future: