from __future__ import annotations import csv import io import json import secrets from collections import Counter from collections.abc import Iterator, Mapping, Sequence from datetime import UTC, datetime, timedelta from decimal import Decimal from pathlib import Path import re from typing import TypedDict from flask import current_app from app.extensions import db from app.models import Household, OptionItem from app.services.households import ( build_new_household_draft, parse_admin_form, serialize_admin_edit_snapshot, ) CSV_IMPORT_FIELDS = ( "household_code", "head_name", "phone", "side", "invite_status", "attendance_status", "expected_attendee_count", "actual_attendee_count", "child_count", "red_packet_child_count", "total_gift_amount", "gift_method_option_code", "gift_scene_option_code", "favor_status", "candy_status", "child_red_packet_status", "note", ) PREVIEW_DIRECTORY_NAME = "csv_previews" PREVIEW_FILE_PATTERN = re.compile(r"^[A-Za-z0-9_-]{10,}$") IMPORT_CONFLICT_MODES = ("skip_conflicts", "update_by_code") OPTION_GROUPS = ( "relation_category", "relation_detail", "tag", "gift_method", "gift_scene", ) class OptionIndex(TypedDict): relation_category_ids: set[int] relation_detail_parent_map: dict[int, int | None] tag_ids: set[int] gift_method_ids: set[int] gift_scene_ids: set[int] gift_method_by_code: dict[str, OptionItem] gift_scene_by_code: dict[str, OptionItem] class PreviewSummary(TypedDict): create_count: int update_count: int conflict_count: int invalid_count: int class HouseholdBrief(TypedDict): id: int household_code: str head_name: str phone: str | None side: str class PreviewRow(TypedDict): row_number: int raw: dict[str, str] payload: dict[str, object] | None status: str errors: list[str] matched_household: HouseholdBrief | None conflict_household: HouseholdBrief | None class PreviewData(TypedDict): token: str file_name: str created_at: str total_rows: int rows: list[PreviewRow] summary: PreviewSummary class ImportSummary(TypedDict): file_name: str total_rows_parsed: int rows_created: int rows_updated: int rows_skipped: int rows_invalid: int conflict_mode: str class CsvImportError(ValueError): """CSV 导入预览阶段的用户可见错误。""" def build_household_csv_template() -> str: buffer = io.StringIO(newline="") writer = csv.writer(buffer, lineterminator="\r\n") writer.writerow(CSV_IMPORT_FIELDS) return "\ufeff" + buffer.getvalue() def build_household_export_rows(households: Sequence[Household]) -> Iterator[list[str]]: yield list(CSV_IMPORT_FIELDS) for household in households: yield [ household.household_code, household.head_name, household.phone or "", household.side, household.invite_status, household.attendance_status, str(household.expected_attendee_count), str(household.actual_attendee_count), str(household.child_count), str(household.red_packet_child_count), f"{household.total_gift_amount:.2f}", household.gift_method_option.option_code if household.gift_method_option else "", household.gift_scene_option.option_code if household.gift_scene_option else "", household.favor_status, household.candy_status, household.child_red_packet_status, household.note or "", ] def stream_household_csv(households: Sequence[Household]) -> Iterator[str]: yield "\ufeff" buffer = io.StringIO(newline="") writer = csv.writer(buffer, lineterminator="\r\n") for row in build_household_export_rows(households): writer.writerow(row) yield buffer.getvalue() buffer.seek(0) buffer.truncate(0) def build_household_export_filename(*, scope: str) -> str: timestamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S") return f"households-{scope}-{timestamp}.csv" def preview_household_csv(*, file_name: str, file_bytes: bytes) -> PreviewData: if not file_name.lower().endswith(".csv"): raise CsvImportError("请上传 .csv 格式文件。") if not file_bytes: raise CsvImportError("上传文件不能为空。") max_bytes = int(current_app.config["HOUSEHOLD_CSV_MAX_UPLOAD_BYTES"]) if len(file_bytes) > max_bytes: raise CsvImportError("CSV 文件不能超过 5 MB。") try: decoded_text = file_bytes.decode("utf-8-sig") except UnicodeDecodeError as exc: raise CsvImportError("CSV 文件必须使用 UTF-8 编码。") from exc reader = csv.DictReader(io.StringIO(decoded_text, newline="")) if reader.fieldnames is None: raise CsvImportError("CSV 文件缺少表头,无法导入。") normalized_headers = [header.strip() for header in reader.fieldnames] missing_headers = [field for field in CSV_IMPORT_FIELDS if field not in normalized_headers] if missing_headers: missing_header_text = "、".join(missing_headers) raise CsvImportError(f"CSV 缺少必填表头:{missing_header_text}。") normalized_rows = [ { field: ((row.get(field) or "").strip()) for field in CSV_IMPORT_FIELDS } for row in reader ] if not normalized_rows: raise CsvImportError("CSV 文件中没有可导入的数据行。") option_index = _build_option_index() all_households = db.session.execute( db.select(Household).order_by(Household.id.asc()), ).scalars().all() households = [household for household in all_households if household.deleted_at is None] households_by_code = {household.household_code: household for household in all_households} households_by_identity = { _household_identity_key(household.head_name, household.side, household.phone): household for household in households if _household_identity_key(household.head_name, household.side, household.phone) is not None } duplicate_codes = Counter( row["household_code"] for row in normalized_rows if row["household_code"] ) preview_rows: list[PreviewRow] = [] for index, row in enumerate(normalized_rows, start=2): code_match = households_by_code.get(row["household_code"]) deleted_code_match = code_match if code_match is not None and code_match.deleted_at is not None else None active_code_match = code_match if code_match is not None and code_match.deleted_at is None else None identity_key = _household_identity_key(row["head_name"], row["side"], row["phone"]) conflict_match = households_by_identity.get(identity_key) if identity_key is not None else None if active_code_match is not None and conflict_match is active_code_match: conflict_match = None row_errors: list[str] = [] if row["household_code"] and duplicate_codes[row["household_code"]] > 1: row_errors.append("CSV 文件中存在重复的户编码,请先去重后再导入。") if deleted_code_match is not None: row_errors.append("户编码已被历史删除记录占用,请改用新的户编码。") payload, validation_errors = _validate_preview_row( row=row, option_index=option_index, existing_household=active_code_match, ) row_errors.extend(validation_errors) if row_errors: row_status = "invalid" elif active_code_match is not None: row_status = "update" elif conflict_match is not None: row_status = "conflict" else: row_status = "create" preview_row: PreviewRow = { "row_number": index, "raw": row, "payload": _serialize_payload(payload), "status": row_status, "errors": row_errors, "matched_household": _household_brief(active_code_match), "conflict_household": _household_brief(conflict_match), } preview_rows.append(preview_row) preview: PreviewData = { "token": _generate_preview_token(), "file_name": file_name, "created_at": datetime.now(UTC).isoformat(), "total_rows": len(preview_rows), "rows": preview_rows, "summary": _build_preview_summary(preview_rows), } save_household_import_preview(preview) return preview def save_household_import_preview(preview: PreviewData) -> None: _purge_expired_preview_files() token = str(preview["token"]) preview_path = _preview_file_path(token) preview_path.parent.mkdir(parents=True, exist_ok=True) preview_path.write_text( json.dumps(preview, ensure_ascii=False, separators=(",", ":")), encoding="utf-8", ) def load_household_import_preview(token: str) -> PreviewData | None: _purge_expired_preview_files() if not PREVIEW_FILE_PATTERN.fullmatch(token): return None preview_path = _preview_file_path(token) if not preview_path.exists(): return None try: preview = json.loads(preview_path.read_text(encoding="utf-8")) except (OSError, json.JSONDecodeError): return None if not isinstance(preview, dict): return None token_value = preview.get("token") file_name_value = preview.get("file_name") created_at_value = preview.get("created_at") total_rows_value = preview.get("total_rows") rows_value = preview.get("rows") summary_value = preview.get("summary") if not isinstance(token_value, str) or not isinstance(file_name_value, str) or not isinstance(created_at_value, str): return None if not isinstance(total_rows_value, int) or not isinstance(rows_value, list) or not isinstance(summary_value, dict): return None preview_rows: list[PreviewRow] = [] for row in rows_value: if not isinstance(row, dict): return None row_number = row.get("row_number") raw = row.get("raw") payload = row.get("payload") status = row.get("status") errors = row.get("errors") matched_household = row.get("matched_household") conflict_household = row.get("conflict_household") if not isinstance(row_number, int) or not isinstance(raw, dict) or not isinstance(status, str) or not isinstance(errors, list): return None preview_row: PreviewRow = { "row_number": row_number, "raw": {str(key): str(value) for key, value in raw.items()}, "payload": payload if isinstance(payload, dict) else None, "status": status, "errors": [str(error) for error in errors], "matched_household": _coerce_household_brief(matched_household), "conflict_household": _coerce_household_brief(conflict_household), } preview_rows.append(preview_row) preview_data: PreviewData = { "token": token_value, "file_name": file_name_value, "created_at": created_at_value, "total_rows": total_rows_value, "rows": preview_rows, "summary": { "create_count": _int_value(summary_value.get("create_count")), "update_count": _int_value(summary_value.get("update_count")), "conflict_count": _int_value(summary_value.get("conflict_count")), "invalid_count": _int_value(summary_value.get("invalid_count")), }, } return preview_data def delete_household_import_preview(token: str) -> None: if not PREVIEW_FILE_PATTERN.fullmatch(token): return preview_path = _preview_file_path(token) if preview_path.exists(): preview_path.unlink() def apply_household_import_preview(*, preview: PreviewData, actor_id: int | None, conflict_mode: str) -> ImportSummary: if conflict_mode not in IMPORT_CONFLICT_MODES: raise CsvImportError("导入处理方式不合法。") rows = preview["rows"] file_name = preview["file_name"] total_rows = preview["total_rows"] invalid_rows = 0 created_rows = 0 updated_rows = 0 skipped_rows = 0 current_households = db.session.execute( db.select(Household), ).scalars().all() active_households = [household for household in current_households if household.deleted_at is None] households_by_code = {household.household_code: household for household in current_households} households_by_identity = { _household_identity_key(household.head_name, household.side, household.phone): household for household in active_households if _household_identity_key(household.head_name, household.side, household.phone) is not None } for row in rows: status = row["status"] if status == "invalid": invalid_rows += 1 continue payload = _deserialize_payload(row["payload"]) if payload is None: invalid_rows += 1 continue household_code = str(payload["household_code"]) existing_household = households_by_code.get(household_code) if existing_household is not None and existing_household.deleted_at is not None: skipped_rows += 1 continue identity_key = _household_identity_key( str(payload["head_name"]), str(payload["side"]), _string_or_none(payload.get("phone")), ) conflict_household = households_by_identity.get(identity_key) if identity_key is not None else None if existing_household is not None and conflict_household is existing_household: conflict_household = None if existing_household is not None: if conflict_mode != "update_by_code": skipped_rows += 1 continue before_snapshot = serialize_admin_edit_snapshot(existing_household) _apply_payload(existing_household, payload, actor_id=actor_id, increment_version=True) after_snapshot = serialize_admin_edit_snapshot(existing_household) if after_snapshot == before_snapshot: skipped_rows += 1 continue updated_rows += 1 updated_identity_key = _household_identity_key(existing_household.head_name, existing_household.side, existing_household.phone) if updated_identity_key is not None: households_by_identity[updated_identity_key] = existing_household continue if conflict_household is not None: skipped_rows += 1 continue new_household = build_new_household_draft(household_code=household_code) _apply_payload(new_household, payload, actor_id=actor_id, increment_version=False) db.session.add(new_household) db.session.flush() households_by_code[new_household.household_code] = new_household new_identity_key = _household_identity_key(new_household.head_name, new_household.side, new_household.phone) if new_identity_key is not None: households_by_identity[new_identity_key] = new_household created_rows += 1 summary: ImportSummary = { "file_name": file_name, "total_rows_parsed": total_rows, "rows_created": created_rows, "rows_updated": updated_rows, "rows_skipped": skipped_rows, "rows_invalid": invalid_rows, "conflict_mode": conflict_mode, } return summary def get_household_import_conflict_modes() -> tuple[tuple[str, str, str], ...]: return ( ("skip_conflicts", "跳过冲突与更新候选", "只新增明确可创建的行,编码更新候选和疑似重复行都会跳过"), ("update_by_code", "按户编码更新", "编码命中的行会更新,疑似重复但无编码命中的冲突行仍会跳过"), ) def option_code_label(option_group: str, option_code: str | None) -> str: if not option_code: return "" option = db.session.execute( db.select(OptionItem) .where(OptionItem.option_group == option_group, OptionItem.option_code == option_code) .limit(1), ).scalar_one_or_none() if option is None: return option_code return option.option_label def _validate_preview_row( *, row: Mapping[str, str], option_index: OptionIndex, existing_household: Household | None, ) -> tuple[dict[str, object] | None, list[str]]: errors: list[str] = [] gift_method_code = row["gift_method_option_code"] gift_scene_code = row["gift_scene_option_code"] method_option_id, method_error = _resolve_option_code( gift_method_code, option_group="gift_method", code_map=option_index["gift_method_by_code"], field_label="礼金方式", ) if method_error: errors.append(method_error) scene_option_id, scene_error = _resolve_option_code( gift_scene_code, option_group="gift_scene", code_map=option_index["gift_scene_by_code"], field_label="礼金记录场景", ) if scene_error: errors.append(scene_error) validation_form = _build_validation_form( row=row, existing_household=existing_household, gift_method_option_id=method_option_id, gift_scene_option_id=scene_option_id, ) payload, validation_errors = parse_admin_form( validation_form, raw_tag_option_ids=_existing_tag_values(existing_household), valid_relation_category_ids=option_index["relation_category_ids"], relation_detail_parent_map=option_index["relation_detail_parent_map"], valid_tag_ids=option_index["tag_ids"], valid_method_ids=option_index["gift_method_ids"], valid_scene_ids=option_index["gift_scene_ids"], ) errors.extend(validation_errors) if errors: return None, errors return payload, [] def _build_validation_form( *, row: Mapping[str, str], existing_household: Household | None, gift_method_option_id: int | None, gift_scene_option_id: int | None, ) -> dict[str, str]: return { "household_code": row["household_code"], "head_name": row["head_name"], "phone": row["phone"], "side": row["side"], "relation_category_option_id": str(existing_household.relation_category_option_id or "") if existing_household else "", "relation_detail_option_id": str(existing_household.relation_detail_option_id or "") if existing_household else "", "invite_status": row["invite_status"], "attendance_status": row["attendance_status"], "expected_attendee_count": row["expected_attendee_count"], "actual_attendee_count": row["actual_attendee_count"], "child_count": row["child_count"], "red_packet_child_count": row["red_packet_child_count"], "total_gift_amount": row["total_gift_amount"], "gift_method_option_id": str(gift_method_option_id or ""), "gift_scene_option_id": str(gift_scene_option_id or ""), "favor_status": row["favor_status"], "candy_status": row["candy_status"], "child_red_packet_status": row["child_red_packet_status"], "note": row["note"], } def _existing_tag_values(existing_household: Household | None) -> list[str]: if existing_household is None or not existing_household.tag_option_ids_json: return [] return [str(value) for value in existing_household.tag_option_ids_json] def _resolve_option_code( option_code: str, *, option_group: str, code_map: Mapping[str, OptionItem], field_label: str, ) -> tuple[int | None, str | None]: normalized_code = option_code.strip() if not normalized_code: return None, None option = code_map.get(normalized_code) if option is None: return None, f"{field_label}编码“{normalized_code}”不存在。" if option.option_group != option_group: return None, f"{field_label}编码“{normalized_code}”不属于正确分组。" return option.id, None def _build_option_index() -> OptionIndex: options = db.session.execute( db.select(OptionItem).where(OptionItem.option_group.in_(OPTION_GROUPS)), ).scalars().all() relation_detail_parent_map = { option.id: option.parent_id for option in options if option.option_group == "relation_detail" } return { "relation_category_ids": {option.id for option in options if option.option_group == "relation_category"}, "relation_detail_parent_map": relation_detail_parent_map, "tag_ids": {option.id for option in options if option.option_group == "tag"}, "gift_method_ids": {option.id for option in options if option.option_group == "gift_method"}, "gift_scene_ids": {option.id for option in options if option.option_group == "gift_scene"}, "gift_method_by_code": { option.option_code: option for option in options if option.option_group == "gift_method" }, "gift_scene_by_code": { option.option_code: option for option in options if option.option_group == "gift_scene" }, } def _build_preview_summary(rows: Sequence[PreviewRow]) -> PreviewSummary: return { "create_count": sum(1 for row in rows if row.get("status") == "create"), "update_count": sum(1 for row in rows if row.get("status") == "update"), "conflict_count": sum(1 for row in rows if row.get("status") == "conflict"), "invalid_count": sum(1 for row in rows if row.get("status") == "invalid"), } def _serialize_payload(payload: Mapping[str, object] | None) -> dict[str, object] | None: if payload is None: return None serialized: dict[str, object] = {} for key, value in payload.items(): if isinstance(value, Decimal): serialized[key] = f"{value:.2f}" elif isinstance(value, list): serialized[key] = list(value) else: serialized[key] = value return serialized def _deserialize_payload(payload: object) -> dict[str, object] | None: if not isinstance(payload, dict): return None deserialized: dict[str, object] = {} for key, value in payload.items(): if key == "total_gift_amount": deserialized[key] = Decimal(str(value)) elif key == "tag_option_ids_json" and isinstance(value, list): deserialized[key] = [int(str(item)) for item in value] else: deserialized[key] = value return deserialized def _apply_payload( household: Household, payload: Mapping[str, object], *, actor_id: int | None, increment_version: bool, ) -> None: for field_name, value in payload.items(): setattr(household, field_name, value) if actor_id is not None: if household.created_by is None: household.created_by = actor_id household.updated_by = actor_id if increment_version: household.version += 1 def _generate_preview_token() -> str: return secrets.token_urlsafe(18) def _preview_directory() -> Path: return Path(current_app.instance_path) / PREVIEW_DIRECTORY_NAME def _preview_file_path(token: str) -> Path: return _preview_directory() / f"{token}.json" def _purge_expired_preview_files() -> None: preview_directory = _preview_directory() if not preview_directory.exists(): return now = datetime.now(UTC) ttl_seconds = int(current_app.config["HOUSEHOLD_CSV_PREVIEW_TTL_SECONDS"]) expires_before = now - timedelta(seconds=ttl_seconds) for file_path in preview_directory.glob("*.json"): if datetime.fromtimestamp(file_path.stat().st_mtime, UTC) < expires_before: file_path.unlink(missing_ok=True) def _household_identity_key(head_name: str | None, side: str | None, phone: str | None) -> str | None: normalized_head_name = _normalize_text(head_name) normalized_side = _normalize_text(side) normalized_phone = _normalize_phone(phone) if not normalized_head_name or not normalized_side or not normalized_phone: return None return f"{normalized_head_name}|{normalized_side}|{normalized_phone}" def _normalize_text(value: str | None) -> str: if not value: return "" return "".join(value.strip().lower().split()) def _normalize_phone(value: str | None) -> str: if not value: return "" digits_only = "".join(character for character in value if character.isdigit()) if digits_only: return digits_only return _normalize_text(value) def _household_brief(household: Household | None) -> HouseholdBrief | None: if household is None: return None return { "id": household.id, "household_code": household.household_code, "head_name": household.head_name, "phone": household.phone, "side": household.side, } def _coerce_household_brief(value: object) -> HouseholdBrief | None: if not isinstance(value, dict): return None household_id = value.get("id") household_code = value.get("household_code") head_name = value.get("head_name") side = value.get("side") phone = value.get("phone") if not isinstance(household_id, int) or not isinstance(household_code, str) or not isinstance(head_name, str) or not isinstance(side, str): return None return { "id": household_id, "household_code": household_code, "head_name": head_name, "phone": _string_or_none(phone), "side": side, } def _int_value(value: object) -> int: if isinstance(value, int): return value if isinstance(value, str) and value.strip(): return int(value) return 0 def _string_or_none(value: object) -> str | None: if value is None: return None if isinstance(value, str): return value return str(value)