You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

91 lines
2.7 KiB

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import sqlite3
import sys
from pathlib import Path
# Tables that must already exist in the target database before any SQL file
# is imported; import_sql_files refuses to run scripts if one is missing.
REQUIRED_TABLES: tuple[str, ...] = ("households", "option_items")
def parse_args() -> argparse.Namespace:
    """Parse the command line from ``sys.argv``.

    Returns:
        A namespace with ``database`` (str, the target database path) and
        ``sql_files`` (list of str, one or more script paths in the order
        they were given).
    """
    cli = argparse.ArgumentParser(
        description="Import one or more SQLite .sql files into a target database.",
    )
    # Positional arguments, registered in the order argparse must consume them.
    positional_specs = (
        ("database", {"help": "Path to the target SQLite database file."}),
        ("sql_files", {"nargs": "+", "help": "SQL files to import in order."}),
    )
    for dest, options in positional_specs:
        cli.add_argument(dest, **options)
    return cli.parse_args()
def resolve_sql_files(raw_files: list[str]) -> list[Path]:
    """Validate the given SQL script paths and return them as absolute paths.

    ``~`` is expanded first. Each path is checked in input order, and the
    first problem found aborts the whole call.

    Args:
        raw_files: Script paths as given on the command line.

    Returns:
        The resolved paths, in the same order as ``raw_files``.

    Raises:
        FileNotFoundError: If a path does not exist, or exists but is not a
            regular file (for example, a directory).
    """
    expanded = [Path(raw).expanduser() for raw in raw_files]
    for candidate in expanded:
        if not candidate.exists():
            raise FileNotFoundError(f"SQL file not found: {candidate}")
        if not candidate.is_file():
            raise FileNotFoundError(f"SQL path is not a file: {candidate}")
    return [candidate.resolve() for candidate in expanded]
def _find_missing_tables(
    conn: sqlite3.Connection,
    required: tuple[str, ...],
) -> list[str]:
    """Return the names in *required* that are not tables in *conn*, in order."""
    existing = {
        row[0]
        for row in conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
    }
    return [name for name in required if name not in existing]


def _missing_tables_error(missing: list[str]) -> RuntimeError:
    """Build the error reported when required tables are absent."""
    joined = ", ".join(missing)
    return RuntimeError(
        f"Database is missing required tables: {joined}. "
        "Run migrations or initialize the app database first.",
    )


def import_sql_files(
    database_path: Path,
    sql_paths: list[Path],
    *,
    required_tables: tuple[str, ...] | None = None,
) -> None:
    """Execute each SQL script in *sql_paths*, in order, against *database_path*.

    The database must already contain every table in *required_tables*
    (``REQUIRED_TABLES`` when not given). Otherwise ``RuntimeError`` is
    raised before any script runs. If the database file does not exist and
    tables are required, nothing is created on disk.

    Each script is read as UTF-8 (a leading BOM is stripped) and announced
    on stdout as ``[run] <path>`` before it is executed with
    ``Connection.executescript``. ``executescript`` commits any pending
    transaction before it starts and runs the script's statements directly.
    The import is therefore NOT atomic across files: scripts that finished
    before a failure stay applied. Statements of the failing script that ran
    before the error also persist, unless that script opened its own
    transaction; the ``rollback()`` below only undoes such an open
    transaction.

    Args:
        database_path: SQLite database file to import into.
        sql_paths: Script files, executed in the given order.
        required_tables: Table names that must exist before importing.

    Raises:
        RuntimeError: If a required table is missing, or the database file
            is absent while tables are required.
        sqlite3.Error: If a script fails to execute.
        OSError, UnicodeDecodeError: If a script cannot be read.
    """
    required = REQUIRED_TABLES if required_tables is None else required_tables

    if not database_path.exists():
        if required:
            # A database that does not exist yet cannot contain the required
            # tables. Fail before sqlite3.connect would create an empty file
            # (and its parent directories) only to report this same error.
            raise _missing_tables_error(list(required))
        # Nothing is required, so a brand-new database is acceptable; make
        # sure its directory exists so sqlite3 can create the file.
        database_path.parent.mkdir(parents=True, exist_ok=True)

    conn = sqlite3.connect(database_path)
    try:
        # Per-connection setting; SQLite does not enforce foreign keys unless
        # it is switched on.
        conn.execute("PRAGMA foreign_keys = ON;")
        missing = _find_missing_tables(conn, tuple(required))
        if missing:
            raise _missing_tables_error(missing)
        for sql_path in sql_paths:
            script = sql_path.read_text(encoding="utf-8-sig")
            print(f"[run] {sql_path}")
            conn.executescript(script)
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
def main() -> int:
    """Run the importer from the command line and return a process exit code.

    On any exception, prints ``[error] <message>`` to stderr and returns 1.
    On success, prints a ``[done]`` summary to stdout and returns 0.
    """
    args = parse_args()
    database_path = Path(args.database).expanduser().resolve()
    try:
        sql_paths = resolve_sql_files(args.sql_files)
        # A missing database is allowed here; an existing non-file (such as
        # a directory) is rejected.
        if database_path.exists() and not database_path.is_file():
            raise FileNotFoundError(f"Database path is not a file: {database_path}")
        import_sql_files(database_path, sql_paths)
    except Exception as exc:
        # Top-level boundary: report the failure and map it to exit status 1.
        print(f"[error] {exc}", file=sys.stderr)
        return 1
    else:
        print(f"[done] Imported {len(sql_paths)} SQL file(s) into {database_path}")
        return 0
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())