#!/usr/bin/env python3
"""Import one or more SQLite ``.sql`` files into an existing app database.

Usage: <script> DATABASE SQL_FILE [SQL_FILE ...]

The target database must already exist and contain every table listed in
``REQUIRED_TABLES``; this script never creates a database. SQL files are
executed in the order given on the command line.
"""
from __future__ import annotations

import argparse
import sqlite3
import sys
from pathlib import Path

# Tables that must already exist before any SQL file is applied; their absence
# means the app database has not been migrated/initialized yet.
REQUIRED_TABLES = ("households", "option_items")


def parse_args() -> argparse.Namespace:
    """Parse the command line: a target database path followed by SQL files."""
    parser = argparse.ArgumentParser(
        description="Import one or more SQLite .sql files into a target database.",
    )
    parser.add_argument(
        "database",
        help="Path to the target SQLite database file.",
    )
    parser.add_argument(
        "sql_files",
        nargs="+",
        help="SQL files to import in order.",
    )
    return parser.parse_args()


def resolve_sql_files(raw_files: list[str]) -> list[Path]:
    """Validate SQL file arguments and return them as absolute paths.

    ``~`` is expanded in each argument and the input order is preserved.

    Raises:
        FileNotFoundError: if a path does not exist or is not a regular file.
    """
    resolved: list[Path] = []
    for raw in raw_files:
        path = Path(raw).expanduser()
        if not path.exists():
            raise FileNotFoundError(f"SQL file not found: {path}")
        if not path.is_file():
            raise FileNotFoundError(f"SQL path is not a file: {path}")
        resolved.append(path.resolve())
    return resolved


def _ensure_database_file(database_path: Path) -> None:
    """Raise FileNotFoundError unless *database_path* is an existing regular file."""
    if not database_path.exists():
        raise FileNotFoundError(
            f"Database file not found: {database_path}. "
            "Run migrations or initialize the app database first.",
        )
    if not database_path.is_file():
        raise FileNotFoundError(f"Database path is not a file: {database_path}")


def _missing_required_tables(conn: sqlite3.Connection) -> list[str]:
    """Return the entries of REQUIRED_TABLES absent from the connected database."""
    existing_tables = {
        row[0]
        for row in conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table'",
        )
    }
    return [name for name in REQUIRED_TABLES if name not in existing_tables]


def import_sql_files(database_path: Path, sql_paths: list[Path]) -> None:
    """Execute each SQL file, in order, against an existing SQLite database.

    Foreign-key enforcement is switched on for the connection before any
    script runs. The database must already contain every table in
    REQUIRED_TABLES.

    NOTE: scripts run through ``Connection.executescript``, which commits any
    pending transaction first and otherwise leaves transaction control to the
    script itself. A file that does not wrap its statements in
    ``BEGIN``/``COMMIT`` is therefore not atomic, and files that completed
    before a failure stay applied. The rollback in the error path can only
    discard a transaction that a script opened itself and had not committed.

    Args:
        database_path: Path to the existing target database file.
        sql_paths: SQL files to execute, in order (decoded as UTF-8, BOM allowed).

    Raises:
        FileNotFoundError: if the database file does not exist or is not a file.
        UnicodeDecodeError: if a SQL file is not valid UTF-8.
        RuntimeError: if any table in REQUIRED_TABLES is missing.
        sqlite3.Error: if a statement in a script fails.
    """
    # Fail before connecting: sqlite3.connect() would otherwise create an empty
    # database file, which can never pass the required-tables check.
    _ensure_database_file(database_path)

    # Decode every script up front so an unreadable or mis-encoded later file
    # aborts the run before any earlier file has been applied.
    # "utf-8-sig" transparently strips a leading UTF-8 BOM.
    scripts = [
        (sql_path, sql_path.read_text(encoding="utf-8-sig"))
        for sql_path in sql_paths
    ]

    conn = sqlite3.connect(database_path)
    try:
        conn.execute("PRAGMA foreign_keys = ON;")
        missing_tables = _missing_required_tables(conn)
        if missing_tables:
            joined = ", ".join(missing_tables)
            raise RuntimeError(
                f"Database is missing required tables: {joined}. "
                "Run migrations or initialize the app database first.",
            )
        for sql_path, script in scripts:
            print(f"[run] {sql_path}")
            conn.executescript(script)
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()


def main() -> int:
    """Run the CLI; return 0 on success, 1 after reporting an error on stderr."""
    args = parse_args()
    try:
        database_path = Path(args.database).expanduser().resolve()
        sql_paths = resolve_sql_files(args.sql_files)
        import_sql_files(database_path, sql_paths)
    except Exception as exc:
        # Top-level boundary: report any failure as a one-line message and a
        # non-zero exit status instead of a traceback.
        print(f"[error] {exc}", file=sys.stderr)
        return 1
    print(f"[done] Imported {len(sql_paths)} SQL file(s) into {database_path}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())