You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
91 lines
2.7 KiB
91 lines
2.7 KiB
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import sqlite3
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Tables that must already exist in the target database before any SQL file
# is imported; import_sql_files() refuses to run when one of them is missing.
REQUIRED_TABLES: tuple[str, ...] = (
    "households",
    "option_items",
)
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
|
|
parser = argparse.ArgumentParser(
|
|
description="Import one or more SQLite .sql files into a target database.",
|
|
)
|
|
parser.add_argument(
|
|
"database",
|
|
help="Path to the target SQLite database file.",
|
|
)
|
|
parser.add_argument(
|
|
"sql_files",
|
|
nargs="+",
|
|
help="SQL files to import in order.",
|
|
)
|
|
return parser.parse_args()
|
|
|
|
|
|
def resolve_sql_files(raw_files: list[str]) -> list[Path]:
    """Validate SQL file arguments and return their absolute paths.

    Every argument is first converted to a ``Path`` with ``~`` expanded; the
    paths are then checked in the given order and the first invalid one
    aborts the whole call.

    Args:
        raw_files: SQL file paths as given on the command line.

    Returns:
        The resolved (absolute) paths, in the same order as ``raw_files``.

    Raises:
        FileNotFoundError: If a path does not exist or is not a regular file.
    """

    def _checked(candidate: Path) -> Path:
        # Existence is checked before file-ness so a missing path gets the
        # more specific "not found" message.
        if not candidate.exists():
            raise FileNotFoundError(f"SQL file not found: {candidate}")
        if not candidate.is_file():
            raise FileNotFoundError(f"SQL path is not a file: {candidate}")
        return candidate.resolve()

    candidates = [Path(raw).expanduser() for raw in raw_files]
    return [_checked(candidate) for candidate in candidates]
|
|
|
|
|
|
def import_sql_files(database_path: Path, sql_paths: list[Path]) -> None:
|
|
database_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
conn = sqlite3.connect(database_path)
|
|
try:
|
|
conn.execute("PRAGMA foreign_keys = ON;")
|
|
existing_tables = {
|
|
row[0]
|
|
for row in conn.execute(
|
|
"SELECT name FROM sqlite_master WHERE type = 'table'",
|
|
).fetchall()
|
|
}
|
|
missing_tables = [table_name for table_name in REQUIRED_TABLES if table_name not in existing_tables]
|
|
if missing_tables:
|
|
joined = ", ".join(missing_tables)
|
|
raise RuntimeError(
|
|
f"Database is missing required tables: {joined}. "
|
|
"Run migrations or initialize the app database first.",
|
|
)
|
|
for sql_path in sql_paths:
|
|
script = sql_path.read_text(encoding="utf-8-sig")
|
|
print(f"[run] {sql_path}")
|
|
conn.executescript(script)
|
|
conn.commit()
|
|
except Exception:
|
|
conn.rollback()
|
|
raise
|
|
finally:
|
|
conn.close()
|
|
|
|
|
|
def main() -> int:
    """Run the importer from the command line.

    Returns:
        Process exit status: 0 when every SQL file was imported, 1 when any
        step failed (the failure is reported on stderr as ``[error] ...``).
    """
    args = parse_args()
    database_path = Path(args.database).expanduser().resolve()

    try:
        sql_paths = resolve_sql_files(args.sql_files)
        # An existing path that is a directory (or other non-file) is rejected.
        occupied_by_non_file = database_path.exists() and not database_path.is_file()
        if occupied_by_non_file:
            raise FileNotFoundError(f"Database path is not a file: {database_path}")
        import_sql_files(database_path, sql_paths)
    except Exception as exc:
        print(f"[error] {exc}", file=sys.stderr)
        exit_code = 1
    else:
        print(f"[done] Imported {len(sql_paths)} SQL file(s) into {database_path}")
        exit_code = 0

    return exit_code
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the exit status.
    sys.exit(main())
|
|
|