litecord/manage/cmd/migration/command.py

144 lines
3.8 KiB
Python

import inspect
import os
from pathlib import Path
from dataclasses import dataclass
from collections import namedtuple
from typing import Dict
import asyncpg
from logbook import Logger
# module-level logbook logger for the migration command
log = Logger(__name__)

# A single migration script found on disk:
#   id   - integer prefix of the script's filename
#   name - remainder of the filename after the first '_'
#   path - pathlib.Path of the .sql file
Migration = namedtuple('Migration', ['id', 'name', 'path'])


@dataclass
class MigrationContext:
    """Migration scripts discovered on disk, keyed by their integer id."""

    # folder the *.sql migration scripts were read from
    migration_folder: Path
    # migration id -> Migration record for that script
    scripts: Dict[int, Migration]

    @property
    def latest(self):
        """Highest known migration id, or 0 when there are no scripts."""
        return max(self.scripts, default=0)


def make_migration_ctx() -> MigrationContext:
    """Create the MigrationContext instance.

    Scans the ``scripts`` folder next to this module for ``*.sql`` files
    named ``<id>_<name>.sql`` and indexes them by their integer id. Only the
    part of the filename before the first ``.`` is parsed.

    Raises:
        ValueError: if a script's filename does not start with an integer
            id, or if two scripts share the same id.
    """
    # The scripts live alongside this module. __file__ gives that path
    # directly; inspect.stack() would build the entire call stack (with
    # source context) just to read the current filename.
    migration_folder = Path(__file__).parent / 'scripts'
    mctx = MigrationContext(migration_folder, {})

    for mig_path in migration_folder.glob('*.sql'):
        # e.g. "3_add_flags.sql" -> id 3, name "add_flags"
        mig_filename = mig_path.name.split('.')[0]
        id_fragment, _, mig_name = mig_filename.partition('_')

        try:
            mig_id = int(id_fragment)
        except ValueError as err:
            raise ValueError(
                f'invalid migration filename {mig_path.name!r}: '
                'expected "<id>_<name>.sql"') from err

        # glob() order is arbitrary, so silently overwriting here would
        # drop one of the scripts nondeterministically; refuse instead.
        if mig_id in mctx.scripts:
            raise ValueError(
                f'duplicate migration id {mig_id}: '
                f'{mctx.scripts[mig_id].path.name} and {mig_path.name}')

        mctx.scripts[mig_id] = Migration(mig_id, mig_name, mig_path)

    return mctx


async def _ensure_changelog(app, ctx):
    """Make sure the migration_log table exists.

    If the table had to be created, the database is treated as already being
    on the newest version and ``ctx.latest`` is recorded as applied. If the
    table already exists, nothing is written.
    """
    create_table_sql = """
    CREATE TABLE migration_log (
        change_num bigint NOT NULL,
        apply_ts timestamp without time zone default
            (now() at time zone 'utc'),
        description text,
        PRIMARY KEY (change_num)
    );
    """

    try:
        await app.db.execute(create_table_sql)
    except asyncpg.DuplicateTableError:
        # table was set up by an earlier run; keep its history as-is
        log.debug('existing migration table')
        return

    # fresh changelog: mark the current latest script as our starting point
    await app.db.execute("""
    INSERT INTO migration_log (change_num, description)
    VALUES ($1, $2)
    """, ctx.latest, 'migration setup')


async def apply_migration(app, migration: Migration):
    """Apply a single migration script and record it in migration_log.

    Scripts whose id is already in migration_log are skipped with a warning.
    Otherwise the script is executed first and only logged once it has
    succeeded. If it raises, the exception propagates and no log row is
    written, so the migration is retried on the next run.
    """
    migration_sql = migration.path.read_text(encoding='utf-8')

    # change_num is NOT NULL, so a None result means "no such row"
    already_applied = await app.db.fetchval("""
    SELECT change_num
    FROM migration_log
    WHERE change_num = $1
    """, migration.id)

    if already_applied is not None:
        log.warning('already applied {}', migration.id)
        return

    # Run the script BEFORE logging it: logging first would mark a failing
    # migration as applied and cause it to be skipped on every later run.
    await app.db.execute(migration_sql)

    # NOTE(review): the script and this insert are not in one transaction;
    # if this insert fails after the script succeeded, the script would be
    # re-run on the next invocation.
    await app.db.execute("""
    INSERT INTO migration_log (change_num, description)
    VALUES ($1, $2)
    """, migration.id, f'migration: {migration.name}')

    log.info('applied {}', migration.id)


async def migrate_cmd(app, _args):
    """Main migration command.
    This makes sure the database
    is updated.
    """
    # NOTE: the docstring above is also used as the CLI description of the
    # `migrate` subcommand by setup(), so internals are documented here in
    # comments instead.
    ctx = make_migration_ctx()
    await _ensure_changelog(app, ctx)

    # local point in the changelog: highest applied migration id
    # (max() over an empty table is NULL -> treat as 0)
    local_change = await app.db.fetchval("""
    SELECT max(change_num)
    FROM migration_log
    """)
    local_change = local_change or 0
    latest_change = ctx.latest

    log.debug('local: {}, latest: {}', local_change, latest_change)

    if local_change > latest_change:
        log.warning('local migration {} is newer than latest known script {}',
                    local_change, latest_change)

    # Apply every known script newer than the local point, in id order.
    # Iterating over the ids that actually exist (rather than a contiguous
    # range) tolerates gaps in the numbering; a missing id in a range would
    # have yielded None and crashed on `migration.id`.
    pending = sorted(mig_id for mig_id in ctx.scripts
                     if mig_id > local_change)

    if not pending:
        print('no changes to do, exiting')
        return

    for mig_id in pending:
        migration = ctx.scripts[mig_id]
        print('applying', migration.id, migration.name)
        await apply_migration(app, migration)


def setup(subparser):
    """Register the ``migrate`` subcommand.

    *subparser* is presumably the object returned by argparse's
    ``add_subparsers()`` (it must provide ``add_parser``) - confirm at caller.
    The chosen handler is stored as ``func`` in the parsed namespace.
    """
    parser = subparser.add_parser(
        'migrate',
        description=migrate_cmd.__doc__,
        help='Run migration tasks',
    )
    parser.set_defaults(func=migrate_cmd)