litecord/manage/cmd/migration/command.py

144 lines
3.7 KiB
Python

import inspect
from pathlib import Path
from dataclasses import dataclass
from collections import namedtuple
from typing import Dict
import asyncpg
from logbook import Logger
# module-level logger, named after this module
log = Logger(__name__)
# Record describing one migration script: its numeric id, the descriptive
# name taken from the filename, and the path of the .sql file.
Migration = namedtuple(
    typename='Migration',
    field_names='id name path',
)
@dataclass
class MigrationContext:
    """Hold information about the available migrations.

    Attributes:
        migration_folder: folder holding the migration scripts.
        scripts: known migrations, keyed by their integer migration ID.
    """
    migration_folder: Path
    # forward reference, so the class does not depend on definition order
    scripts: Dict[int, 'Migration']

    @property
    def latest(self) -> int:
        """Return the latest (highest) migration ID.

        Returns 0 when no migration scripts are known, instead of letting
        ``max()`` raise ``ValueError`` on an empty mapping.
        """
        return max(self.scripts, default=0)
def make_migration_ctx() -> MigrationContext:
    """Create the MigrationContext instance.

    Scans the ``scripts`` folder that sits next to this module for
    ``*.sql`` files named ``<id>_<name>.sql`` and registers each of them
    under its integer ID.

    Returns:
        A MigrationContext holding the scripts folder and every migration
        script found in it.

    Raises:
        ValueError: if a script's filename does not start with an
            integer ID.
    """
    # pathlib takes care of the platform's path separator; splitting the
    # path string on '/' by hand breaks on Windows.
    migration_folder = Path(__file__).parent / 'scripts'
    mctx = MigrationContext(migration_folder, {})

    for mig_path in migration_folder.glob('*.sql'):
        # extract migration script id and name; anything after the
        # first dot of the filename is dropped.
        mig_filename = mig_path.name.split('.')[0]

        # the ID is everything before the first underscore, the name is
        # everything after it (and may itself contain underscores).
        id_fragment, _, mig_name = mig_filename.partition('_')
        mig_id = int(id_fragment)

        mctx.scripts[mig_id] = Migration(mig_id, mig_name, mig_path)

    return mctx
async def _ensure_changelog(app, ctx):
    """Make sure the migration_log table exists.

    When the table has to be created, a single row is inserted that
    records ``ctx.latest`` as the current version. When the table already
    exists it is left untouched.
    """
    create_table_sql = """
CREATE TABLE migration_log (
change_num bigint NOT NULL,
apply_ts timestamp without time zone default
(now() at time zone 'utc'),
description text,
PRIMARY KEY (change_num)
);
"""

    mark_latest_sql = """
INSERT INTO migration_log (change_num, description)
VALUES ($1, $2)
"""

    try:
        await app.db.execute(create_table_sql)

        # reaching this point means the table did not exist before, so
        # record that we are on the latest version.
        await app.db.execute(mark_latest_sql, ctx.latest, 'migration setup')
    except asyncpg.DuplicateTableError:
        log.debug('existing migration table')
async def apply_migration(app, migration: Migration):
    """Apply a single migration.

    The migration is first recorded in ``migration_log``; a unique
    violation on that insert means it was already applied, in which case
    nothing else is done. Otherwise the migration's SQL script is run.

    If the script fails, the ``migration_log`` row is removed again so
    the migration is not treated as applied on the next run, and the
    original error is re-raised.

    Args:
        app: application object exposing the database as ``app.db``.
        migration: the migration to apply; its ``path`` is read as UTF-8.
    """
    migration_sql = migration.path.read_text(encoding='utf-8')

    try:
        await app.db.execute("""
INSERT INTO migration_log (change_num, description)
VALUES ($1, $2)
""", migration.id, f'migration: {migration.name}')
    except asyncpg.UniqueViolationError:
        log.warning('already applied {}', migration.id)
        return

    try:
        await app.db.execute(migration_sql)
    except Exception:
        log.error('failed to apply {}', migration.id)

        # undo the log entry written above, otherwise the failed
        # migration would be skipped as "already applied" forever.
        await app.db.execute("""
DELETE FROM migration_log
WHERE change_num = $1
""", migration.id)
        raise

    log.info('applied {}', migration.id)
async def migrate_cmd(app, _args):
    """Main migration command.

    This makes sure the database
    is updated.
    """
    # NOTE: the docstring above is passed to argparse as the description
    # of the `migrate` subcommand (see setup()), so it is user-facing.
    #
    # app is the application object exposing the database as app.db;
    # _args holds the parsed command line arguments and is unused.
    ctx = make_migration_ctx()

    await _ensure_changelog(app, ctx)

    # local point in the changelog
    local_change = await app.db.fetchval("""
SELECT max(change_num)
FROM migration_log
""")

    # max() over an empty table gives NULL/None, treat that as "nothing
    # applied yet".
    local_change = local_change or 0
    latest_change = ctx.latest

    log.debug('local: {}, latest: {}', local_change, latest_change)

    # only walk over migration IDs that actually have a script, in
    # ascending order. Walking every integer in the range would hit None
    # on any gap in the numbering and crash on `migration.id`.
    pending = sorted(
        mig_id for mig_id in ctx.scripts
        if mig_id > local_change
    )

    # also covers a database that is ahead of the local scripts.
    if not pending:
        print('no changes to do, exiting')
        return

    for mig_id in pending:
        migration = ctx.scripts[mig_id]

        print('applying', migration.id, migration.name)
        await apply_migration(app, migration)
def setup(subparser):
    """Register the ``migrate`` subcommand on the given subparsers object.

    The subcommand's description is migrate_cmd's docstring, and
    migrate_cmd is stored as the ``func`` default of the parser.
    """
    parser_options = {
        'help': 'Run migration tasks',
        'description': migrate_cmd.__doc__,
    }

    parser = subparser.add_parser('migrate', **parser_options)
    parser.set_defaults(func=migrate_cmd)