You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

763 lines
26 KiB

from __future__ import annotations
import csv
import io
import json
import secrets
from collections import Counter
from collections.abc import Iterator, Mapping, Sequence
from datetime import UTC, datetime, timedelta
from decimal import Decimal
from pathlib import Path
import re
from typing import TypedDict
from flask import current_app
from app.extensions import db
from app.models import Household, OptionItem
from app.services.households import (
build_new_household_draft,
parse_admin_form,
serialize_admin_edit_snapshot,
)
# Column order for the CSV template, export, and import; the upload's header
# row must contain every one of these names.
CSV_IMPORT_FIELDS = (
    "household_code",
    "head_name",
    "phone",
    "side",
    "invite_status",
    "attendance_status",
    "expected_attendee_count",
    "actual_attendee_count",
    "child_count",
    "red_packet_child_count",
    "total_gift_amount",
    "gift_method_option_code",
    "gift_scene_option_code",
    "favor_status",
    "candy_status",
    "child_red_packet_status",
    "note",
)
# Subdirectory of the Flask instance path where preview JSON files live.
PREVIEW_DIRECTORY_NAME = "csv_previews"
# Preview tokens come from secrets.token_urlsafe: URL-safe base64 alphabet,
# at least 10 characters. Also guards the token before it is used in a path.
PREVIEW_FILE_PATTERN = re.compile(r"^[A-Za-z0-9_-]{10,}$")
# Accepted values for the import conflict-resolution mode.
IMPORT_CONFLICT_MODES = ("skip_conflicts", "update_by_code")
# OptionItem groups loaded into the validation index (_build_option_index).
OPTION_GROUPS = (
    "relation_category",
    "relation_detail",
    "tag",
    "gift_method",
    "gift_scene",
)
class OptionIndex(TypedDict):
    """Pre-built OptionItem lookup tables used while validating CSV rows."""

    relation_category_ids: set[int]
    # relation_detail option id -> its parent option id (or None).
    relation_detail_parent_map: dict[int, int | None]
    tag_ids: set[int]
    gift_method_ids: set[int]
    gift_scene_ids: set[int]
    # option_code -> OptionItem, for resolving the CSV code columns.
    gift_method_by_code: dict[str, OptionItem]
    gift_scene_by_code: dict[str, OptionItem]
class PreviewSummary(TypedDict):
    """Per-status row counts shown on the import preview screen."""

    create_count: int
    update_count: int
    conflict_count: int
    invalid_count: int
class HouseholdBrief(TypedDict):
    """Minimal household projection rendered in the preview tables."""

    id: int
    household_code: str
    head_name: str
    phone: str | None
    side: str
class PreviewRow(TypedDict):
    """One classified CSV data row in a stored preview."""

    # 1-based row number in the uploaded file (header is row 1).
    row_number: int
    # The normalized raw cell values, keyed by CSV_IMPORT_FIELDS.
    raw: dict[str, str]
    # JSON-safe validated payload, or None when the row is invalid.
    payload: dict[str, object] | None
    # One of: "create", "update", "conflict", "invalid".
    status: str
    errors: list[str]
    # Household matched by code (update candidate), if any.
    matched_household: HouseholdBrief | None
    # Household matched by name/side/phone identity (suspected duplicate).
    conflict_household: HouseholdBrief | None
class PreviewData(TypedDict):
    """A complete import preview, persisted as JSON under its token."""

    token: str
    file_name: str
    # ISO-8601 UTC timestamp of when the preview was generated.
    created_at: str
    total_rows: int
    rows: list[PreviewRow]
    summary: PreviewSummary
class ImportSummary(TypedDict):
    """Outcome counters reported after a preview has been applied."""

    file_name: str
    total_rows_parsed: int
    rows_created: int
    rows_updated: int
    rows_skipped: int
    rows_invalid: int
    conflict_mode: str
class CsvImportError(ValueError):
    """User-visible error raised during the CSV import preview stage."""
def build_household_csv_template() -> str:
    """Return the downloadable CSV template: a UTF-8 BOM plus the header row."""
    out = io.StringIO(newline="")
    csv.writer(out, lineterminator="\r\n").writerow(CSV_IMPORT_FIELDS)
    return f"\ufeff{out.getvalue()}"
def build_household_export_rows(households: Sequence[Household]) -> Iterator[list[str]]:
    """Yield the CSV header row, then one row of strings per household."""
    yield list(CSV_IMPORT_FIELDS)
    for record in households:
        method = record.gift_method_option
        scene = record.gift_scene_option
        yield [
            record.household_code,
            record.head_name,
            record.phone or "",
            record.side,
            record.invite_status,
            record.attendance_status,
            str(record.expected_attendee_count),
            str(record.actual_attendee_count),
            str(record.child_count),
            str(record.red_packet_child_count),
            # Amounts are exported with exactly two decimal places.
            f"{record.total_gift_amount:.2f}",
            method.option_code if method else "",
            scene.option_code if scene else "",
            record.favor_status,
            record.candy_status,
            record.child_red_packet_status,
            record.note or "",
        ]
def stream_household_csv(households: Sequence[Household]) -> Iterator[str]:
    """Stream a UTF-8 CSV: a BOM chunk first, then one encoded row per chunk.

    A single StringIO is reused as scratch space so memory stays flat no
    matter how many households are exported.
    """
    yield "\ufeff"
    scratch = io.StringIO(newline="")
    writer = csv.writer(scratch, lineterminator="\r\n")
    for row in build_household_export_rows(households):
        writer.writerow(row)
        chunk = scratch.getvalue()
        scratch.seek(0)
        scratch.truncate(0)
        yield chunk
def build_household_export_filename(*, scope: str) -> str:
timestamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
return f"households-{scope}-{timestamp}.csv"
def preview_household_csv(*, file_name: str, file_bytes: bytes) -> PreviewData:
    """Parse an uploaded CSV, classify every row, and persist a preview.

    Each data row is classified as create / update / conflict / invalid
    against the current household table; the resulting preview is written to
    disk under a random token and returned.

    Raises:
        CsvImportError: for a non-.csv filename, an empty or oversized file,
            non-UTF-8 content, missing required headers, or no data rows.
    """
    if not file_name.lower().endswith(".csv"):
        raise CsvImportError("请上传 .csv 格式文件。")
    if not file_bytes:
        raise CsvImportError("上传文件不能为空。")
    max_bytes = int(current_app.config["HOUSEHOLD_CSV_MAX_UPLOAD_BYTES"])
    if len(file_bytes) > max_bytes:
        # NOTE(review): the message hard-codes "5 MB" while the actual limit
        # is config-driven — keep the two in sync (or render max_bytes here).
        raise CsvImportError("CSV 文件不能超过 5 MB。")
    try:
        # utf-8-sig transparently strips a leading BOM (our template emits one).
        decoded_text = file_bytes.decode("utf-8-sig")
    except UnicodeDecodeError as exc:
        raise CsvImportError("CSV 文件必须使用 UTF-8 编码。") from exc
    reader = csv.DictReader(io.StringIO(decoded_text, newline=""))
    if reader.fieldnames is None:
        raise CsvImportError("CSV 文件缺少表头,无法导入。")
    normalized_headers = [header.strip() for header in reader.fieldnames]
    missing_headers = [field for field in CSV_IMPORT_FIELDS if field not in normalized_headers]
    if missing_headers:
        # Bug fix: these were joined with "" and the header names ran
        # together in the error message; join with a separator instead.
        missing_header_text = "、".join(missing_headers)
        raise CsvImportError(f"CSV 缺少必填表头:{missing_header_text}")
    # Keep only the known columns and strip whitespace from every cell.
    normalized_rows = [
        {
            field: ((row.get(field) or "").strip())
            for field in CSV_IMPORT_FIELDS
        }
        for row in reader
    ]
    if not normalized_rows:
        raise CsvImportError("CSV 文件中没有可导入的数据行。")
    option_index = _build_option_index()
    all_households = db.session.execute(
        db.select(Household).order_by(Household.id.asc()),
    ).scalars().all()
    # Code lookups include soft-deleted households (their codes stay
    # reserved); identity lookups consider only active ones.
    households = [household for household in all_households if household.deleted_at is None]
    households_by_code = {household.household_code: household for household in all_households}
    households_by_identity = {
        _household_identity_key(household.head_name, household.side, household.phone): household
        for household in households
        if _household_identity_key(household.head_name, household.side, household.phone) is not None
    }
    duplicate_codes = Counter(
        row["household_code"]
        for row in normalized_rows
        if row["household_code"]
    )
    preview_rows: list[PreviewRow] = []
    # start=2: row 1 of the file is the header, so data rows number from 2.
    for index, row in enumerate(normalized_rows, start=2):
        code_match = households_by_code.get(row["household_code"])
        deleted_code_match = code_match if code_match is not None and code_match.deleted_at is not None else None
        active_code_match = code_match if code_match is not None and code_match.deleted_at is None else None
        identity_key = _household_identity_key(row["head_name"], row["side"], row["phone"])
        conflict_match = households_by_identity.get(identity_key) if identity_key is not None else None
        # A row matching itself by identity is an update, not a conflict.
        if active_code_match is not None and conflict_match is active_code_match:
            conflict_match = None
        row_errors: list[str] = []
        if row["household_code"] and duplicate_codes[row["household_code"]] > 1:
            row_errors.append("CSV 文件中存在重复的户编码,请先去重后再导入。")
        if deleted_code_match is not None:
            row_errors.append("户编码已被历史删除记录占用,请改用新的户编码。")
        payload, validation_errors = _validate_preview_row(
            row=row,
            option_index=option_index,
            existing_household=active_code_match,
        )
        row_errors.extend(validation_errors)
        if row_errors:
            row_status = "invalid"
        elif active_code_match is not None:
            row_status = "update"
        elif conflict_match is not None:
            row_status = "conflict"
        else:
            row_status = "create"
        preview_row: PreviewRow = {
            "row_number": index,
            "raw": row,
            "payload": _serialize_payload(payload),
            "status": row_status,
            "errors": row_errors,
            "matched_household": _household_brief(active_code_match),
            "conflict_household": _household_brief(conflict_match),
        }
        preview_rows.append(preview_row)
    preview: PreviewData = {
        "token": _generate_preview_token(),
        "file_name": file_name,
        "created_at": datetime.now(UTC).isoformat(),
        "total_rows": len(preview_rows),
        "rows": preview_rows,
        "summary": _build_preview_summary(preview_rows),
    }
    save_household_import_preview(preview)
    return preview
def save_household_import_preview(preview: PreviewData) -> None:
    """Persist the preview as compact JSON under the instance preview dir."""
    _purge_expired_preview_files()
    target = _preview_file_path(str(preview["token"]))
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(preview, ensure_ascii=False, separators=(",", ":"))
    target.write_text(serialized, encoding="utf-8")
def load_household_import_preview(token: str) -> PreviewData | None:
    """Load and re-validate a stored preview by its token.

    Returns None for malformed tokens, missing or unreadable files, or any
    JSON whose shape does not match PreviewData — callers treat all of those
    cases as "preview expired".
    """
    _purge_expired_preview_files()
    # Reject anything that is not a plausible token before touching the disk
    # (this also prevents path traversal through the token).
    if not PREVIEW_FILE_PATTERN.fullmatch(token):
        return None
    preview_path = _preview_file_path(token)
    if not preview_path.exists():
        return None
    try:
        preview = json.loads(preview_path.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        return None
    if not isinstance(preview, dict):
        return None
    token_value = preview.get("token")
    file_name_value = preview.get("file_name")
    created_at_value = preview.get("created_at")
    total_rows_value = preview.get("total_rows")
    rows_value = preview.get("rows")
    summary_value = preview.get("summary")
    # Top-level shape checks: the scalar fields first, then the containers.
    if not isinstance(token_value, str) or not isinstance(file_name_value, str) or not isinstance(created_at_value, str):
        return None
    if not isinstance(total_rows_value, int) or not isinstance(rows_value, list) or not isinstance(summary_value, dict):
        return None
    preview_rows: list[PreviewRow] = []
    for row in rows_value:
        # Any malformed row invalidates the whole preview.
        if not isinstance(row, dict):
            return None
        row_number = row.get("row_number")
        raw = row.get("raw")
        payload = row.get("payload")
        status = row.get("status")
        errors = row.get("errors")
        matched_household = row.get("matched_household")
        conflict_household = row.get("conflict_household")
        if not isinstance(row_number, int) or not isinstance(raw, dict) or not isinstance(status, str) or not isinstance(errors, list):
            return None
        preview_row: PreviewRow = {
            "row_number": row_number,
            # Coerce raw cells back to str in case the file was edited.
            "raw": {str(key): str(value) for key, value in raw.items()},
            "payload": payload if isinstance(payload, dict) else None,
            "status": status,
            "errors": [str(error) for error in errors],
            "matched_household": _coerce_household_brief(matched_household),
            "conflict_household": _coerce_household_brief(conflict_household),
        }
        preview_rows.append(preview_row)
    preview_data: PreviewData = {
        "token": token_value,
        "file_name": file_name_value,
        "created_at": created_at_value,
        "total_rows": total_rows_value,
        "rows": preview_rows,
        "summary": {
            "create_count": _int_value(summary_value.get("create_count")),
            "update_count": _int_value(summary_value.get("update_count")),
            "conflict_count": _int_value(summary_value.get("conflict_count")),
            "invalid_count": _int_value(summary_value.get("invalid_count")),
        },
    }
    return preview_data
def delete_household_import_preview(token: str) -> None:
    """Remove the preview file for *token*; unknown tokens are ignored.

    Uses ``unlink(missing_ok=True)`` — matching ``_purge_expired_preview_files``
    — instead of an exists-then-unlink pair, which could raise
    FileNotFoundError if the file were purged between the two calls.
    """
    if not PREVIEW_FILE_PATTERN.fullmatch(token):
        return
    _preview_file_path(token).unlink(missing_ok=True)
def apply_household_import_preview(*, preview: PreviewData, actor_id: int | None, conflict_mode: str) -> ImportSummary:
    """Apply a confirmed preview, creating/updating households per conflict_mode.

    This function flushes new rows but does not commit — the caller owns the
    transaction.

    Raises:
        CsvImportError: when conflict_mode is not in IMPORT_CONFLICT_MODES.
    """
    if conflict_mode not in IMPORT_CONFLICT_MODES:
        raise CsvImportError("导入处理方式不合法。")
    rows = preview["rows"]
    file_name = preview["file_name"]
    total_rows = preview["total_rows"]
    invalid_rows = 0
    created_rows = 0
    updated_rows = 0
    skipped_rows = 0
    # Re-read households at apply time: the DB may have changed since preview.
    current_households = db.session.execute(
        db.select(Household),
    ).scalars().all()
    active_households = [household for household in current_households if household.deleted_at is None]
    # Code lookups include soft-deleted rows; identity lookups only active ones.
    households_by_code = {household.household_code: household for household in current_households}
    households_by_identity = {
        _household_identity_key(household.head_name, household.side, household.phone): household
        for household in active_households
        if _household_identity_key(household.head_name, household.side, household.phone) is not None
    }
    for row in rows:
        status = row["status"]
        if status == "invalid":
            invalid_rows += 1
            continue
        payload = _deserialize_payload(row["payload"])
        if payload is None:
            invalid_rows += 1
            continue
        household_code = str(payload["household_code"])
        existing_household = households_by_code.get(household_code)
        # A code held by a soft-deleted household can never be reused.
        if existing_household is not None and existing_household.deleted_at is not None:
            skipped_rows += 1
            continue
        identity_key = _household_identity_key(
            str(payload["head_name"]),
            str(payload["side"]),
            _string_or_none(payload.get("phone")),
        )
        conflict_household = households_by_identity.get(identity_key) if identity_key is not None else None
        # Matching oneself by identity is not a conflict.
        if existing_household is not None and conflict_household is existing_household:
            conflict_household = None
        if existing_household is not None:
            if conflict_mode != "update_by_code":
                skipped_rows += 1
                continue
            # Compare snapshots so a no-op update is counted as skipped.
            # NOTE(review): by this point _apply_payload has already bumped
            # version/updated_by — confirm serialize_admin_edit_snapshot
            # excludes those fields and whether the mutation should be undone
            # for rows counted as skipped.
            before_snapshot = serialize_admin_edit_snapshot(existing_household)
            _apply_payload(existing_household, payload, actor_id=actor_id, increment_version=True)
            after_snapshot = serialize_admin_edit_snapshot(existing_household)
            if after_snapshot == before_snapshot:
                skipped_rows += 1
                continue
            updated_rows += 1
            updated_identity_key = _household_identity_key(existing_household.head_name, existing_household.side, existing_household.phone)
            if updated_identity_key is not None:
                households_by_identity[updated_identity_key] = existing_household
            continue
        if conflict_household is not None:
            skipped_rows += 1
            continue
        new_household = build_new_household_draft(household_code=household_code)
        _apply_payload(new_household, payload, actor_id=actor_id, increment_version=False)
        db.session.add(new_household)
        # Flush so the new row gets an id and is visible to later iterations.
        db.session.flush()
        households_by_code[new_household.household_code] = new_household
        new_identity_key = _household_identity_key(new_household.head_name, new_household.side, new_household.phone)
        if new_identity_key is not None:
            households_by_identity[new_identity_key] = new_household
        created_rows += 1
    summary: ImportSummary = {
        "file_name": file_name,
        "total_rows_parsed": total_rows,
        "rows_created": created_rows,
        "rows_updated": updated_rows,
        "rows_skipped": skipped_rows,
        "rows_invalid": invalid_rows,
        "conflict_mode": conflict_mode,
    }
    return summary
def get_household_import_conflict_modes() -> tuple[tuple[str, str, str], ...]:
    """Return the (code, label, description) tuples for the conflict modes."""
    skip_mode = ("skip_conflicts", "跳过冲突与更新候选", "只新增明确可创建的行,编码更新候选和疑似重复行都会跳过")
    update_mode = ("update_by_code", "按户编码更新", "编码命中的行会更新,疑似重复但无编码命中的冲突行仍会跳过")
    return (skip_mode, update_mode)
def option_code_label(option_group: str, option_code: str | None) -> str:
    """Resolve an option code to its display label.

    Falls back to the raw code when no matching OptionItem exists, and to ""
    for a blank/None code.
    """
    if not option_code:
        return ""
    stmt = (
        db.select(OptionItem)
        .where(OptionItem.option_group == option_group, OptionItem.option_code == option_code)
        .limit(1)
    )
    match = db.session.execute(stmt).scalar_one_or_none()
    return option_code if match is None else match.option_label
def _validate_preview_row(
    *,
    row: Mapping[str, str],
    option_index: OptionIndex,
    existing_household: Household | None,
) -> tuple[dict[str, object] | None, list[str]]:
    """Validate one normalized CSV row.

    Returns (payload, []) when valid, or (None, errors) listing every problem
    found in the row.
    """
    problems: list[str] = []
    method_id, method_issue = _resolve_option_code(
        row["gift_method_option_code"],
        option_group="gift_method",
        code_map=option_index["gift_method_by_code"],
        field_label="礼金方式",
    )
    scene_id, scene_issue = _resolve_option_code(
        row["gift_scene_option_code"],
        option_group="gift_scene",
        code_map=option_index["gift_scene_by_code"],
        field_label="礼金记录场景",
    )
    problems.extend(issue for issue in (method_issue, scene_issue) if issue)
    # Run the shared admin-form validation even when code lookups failed so
    # the user sees every error at once.
    form = _build_validation_form(
        row=row,
        existing_household=existing_household,
        gift_method_option_id=method_id,
        gift_scene_option_id=scene_id,
    )
    payload, form_issues = parse_admin_form(
        form,
        raw_tag_option_ids=_existing_tag_values(existing_household),
        valid_relation_category_ids=option_index["relation_category_ids"],
        relation_detail_parent_map=option_index["relation_detail_parent_map"],
        valid_tag_ids=option_index["tag_ids"],
        valid_method_ids=option_index["gift_method_ids"],
        valid_scene_ids=option_index["gift_scene_ids"],
    )
    problems.extend(form_issues)
    return (None, problems) if problems else (payload, [])
def _build_validation_form(
*,
row: Mapping[str, str],
existing_household: Household | None,
gift_method_option_id: int | None,
gift_scene_option_id: int | None,
) -> dict[str, str]:
return {
"household_code": row["household_code"],
"head_name": row["head_name"],
"phone": row["phone"],
"side": row["side"],
"relation_category_option_id": str(existing_household.relation_category_option_id or "") if existing_household else "",
"relation_detail_option_id": str(existing_household.relation_detail_option_id or "") if existing_household else "",
"invite_status": row["invite_status"],
"attendance_status": row["attendance_status"],
"expected_attendee_count": row["expected_attendee_count"],
"actual_attendee_count": row["actual_attendee_count"],
"child_count": row["child_count"],
"red_packet_child_count": row["red_packet_child_count"],
"total_gift_amount": row["total_gift_amount"],
"gift_method_option_id": str(gift_method_option_id or ""),
"gift_scene_option_id": str(gift_scene_option_id or ""),
"favor_status": row["favor_status"],
"candy_status": row["candy_status"],
"child_red_packet_status": row["child_red_packet_status"],
"note": row["note"],
}
def _existing_tag_values(existing_household: Household | None) -> list[str]:
if existing_household is None or not existing_household.tag_option_ids_json:
return []
return [str(value) for value in existing_household.tag_option_ids_json]
def _resolve_option_code(
option_code: str,
*,
option_group: str,
code_map: Mapping[str, OptionItem],
field_label: str,
) -> tuple[int | None, str | None]:
normalized_code = option_code.strip()
if not normalized_code:
return None, None
option = code_map.get(normalized_code)
if option is None:
return None, f"{field_label}编码“{normalized_code}”不存在。"
if option.option_group != option_group:
return None, f"{field_label}编码“{normalized_code}”不属于正确分组。"
return option.id, None
def _build_option_index() -> OptionIndex:
    """Load every relevant OptionItem row and bucket it into lookup tables."""
    options = db.session.execute(
        db.select(OptionItem).where(OptionItem.option_group.in_(OPTION_GROUPS)),
    ).scalars().all()

    def _in_group(group: str) -> list[OptionItem]:
        # Small helper: all loaded options belonging to one group.
        return [option for option in options if option.option_group == group]

    return {
        "relation_category_ids": {option.id for option in _in_group("relation_category")},
        "relation_detail_parent_map": {
            option.id: option.parent_id for option in _in_group("relation_detail")
        },
        "tag_ids": {option.id for option in _in_group("tag")},
        "gift_method_ids": {option.id for option in _in_group("gift_method")},
        "gift_scene_ids": {option.id for option in _in_group("gift_scene")},
        "gift_method_by_code": {option.option_code: option for option in _in_group("gift_method")},
        "gift_scene_by_code": {option.option_code: option for option in _in_group("gift_scene")},
    }
def _build_preview_summary(rows: Sequence[PreviewRow]) -> PreviewSummary:
return {
"create_count": sum(1 for row in rows if row.get("status") == "create"),
"update_count": sum(1 for row in rows if row.get("status") == "update"),
"conflict_count": sum(1 for row in rows if row.get("status") == "conflict"),
"invalid_count": sum(1 for row in rows if row.get("status") == "invalid"),
}
def _serialize_payload(payload: Mapping[str, object] | None) -> dict[str, object] | None:
if payload is None:
return None
serialized: dict[str, object] = {}
for key, value in payload.items():
if isinstance(value, Decimal):
serialized[key] = f"{value:.2f}"
elif isinstance(value, list):
serialized[key] = list(value)
else:
serialized[key] = value
return serialized
def _deserialize_payload(payload: object) -> dict[str, object] | None:
if not isinstance(payload, dict):
return None
deserialized: dict[str, object] = {}
for key, value in payload.items():
if key == "total_gift_amount":
deserialized[key] = Decimal(str(value))
elif key == "tag_option_ids_json" and isinstance(value, list):
deserialized[key] = [int(str(item)) for item in value]
else:
deserialized[key] = value
return deserialized
def _apply_payload(
household: Household,
payload: Mapping[str, object],
*,
actor_id: int | None,
increment_version: bool,
) -> None:
for field_name, value in payload.items():
setattr(household, field_name, value)
if actor_id is not None:
if household.created_by is None:
household.created_by = actor_id
household.updated_by = actor_id
if increment_version:
household.version += 1
def _generate_preview_token() -> str:
return secrets.token_urlsafe(18)
def _preview_directory() -> Path:
    """Directory under the Flask instance path that holds preview JSON files."""
    return Path(current_app.instance_path, PREVIEW_DIRECTORY_NAME)
def _preview_file_path(token: str) -> Path:
    """Path of the preview JSON file for *token*."""
    return _preview_directory().joinpath(f"{token}.json")
def _purge_expired_preview_files() -> None:
    """Delete preview JSON files older than the configured TTL."""
    directory = _preview_directory()
    if not directory.exists():
        return
    ttl = timedelta(seconds=int(current_app.config["HOUSEHOLD_CSV_PREVIEW_TTL_SECONDS"]))
    cutoff = datetime.now(UTC) - ttl
    for candidate in directory.glob("*.json"):
        modified_at = datetime.fromtimestamp(candidate.stat().st_mtime, UTC)
        if modified_at < cutoff:
            # missing_ok: another request may purge the same file concurrently.
            candidate.unlink(missing_ok=True)
def _household_identity_key(head_name: str | None, side: str | None, phone: str | None) -> str | None:
    """Build the dedupe key from normalized name/side/phone.

    Returns None when any component normalizes to blank — such households
    cannot participate in identity matching.
    """
    parts = (
        _normalize_text(head_name),
        _normalize_text(side),
        _normalize_phone(phone),
    )
    if all(parts):
        return "|".join(parts)
    return None
def _normalize_text(value: str | None) -> str:
if not value:
return ""
return "".join(value.strip().lower().split())
def _normalize_phone(value: str | None) -> str:
if not value:
return ""
digits_only = "".join(character for character in value if character.isdigit())
if digits_only:
return digits_only
return _normalize_text(value)
def _household_brief(household: Household | None) -> HouseholdBrief | None:
if household is None:
return None
return {
"id": household.id,
"household_code": household.household_code,
"head_name": household.head_name,
"phone": household.phone,
"side": household.side,
}
def _coerce_household_brief(value: object) -> HouseholdBrief | None:
if not isinstance(value, dict):
return None
household_id = value.get("id")
household_code = value.get("household_code")
head_name = value.get("head_name")
side = value.get("side")
phone = value.get("phone")
if not isinstance(household_id, int) or not isinstance(household_code, str) or not isinstance(head_name, str) or not isinstance(side, str):
return None
return {
"id": household_id,
"household_code": household_code,
"head_name": head_name,
"phone": _string_or_none(phone),
"side": side,
}
def _int_value(value: object) -> int:
if isinstance(value, int):
return value
if isinstance(value, str) and value.strip():
return int(value)
return 0
def _string_or_none(value: object) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(value)