from __future__ import annotations from collections.abc import Sequence from datetime import UTC, datetime from decimal import Decimal, InvalidOperation from sqlalchemy import Select from sqlalchemy.orm import selectinload from app.extensions import db from app.models import GiftRecord, Household GIFT_RECORD_TYPE_OPTIONS = ( "cash_gift", "physical_gift", "pre_wedding_cash", "wedding_day_cash", "post_wedding_cash", ) GIFT_RECORD_TYPE_LABELS = { "cash_gift": "礼金", "physical_gift": "礼品", "pre_wedding_cash": "婚前礼金", "wedding_day_cash": "婚礼当天礼金", "post_wedding_cash": "婚后补记礼金", } CASH_GIFT_RECORD_TYPES = frozenset( { "cash_gift", "pre_wedding_cash", "wedding_day_cash", "post_wedding_cash", }, ) def gift_record_query( household_id: int | None = None, *, include_deleted: bool = False, ) -> Select[tuple[GiftRecord]]: query = db.select(GiftRecord).options( selectinload(GiftRecord.method_option), selectinload(GiftRecord.scene_option), ) if household_id is not None: query = query.where(GiftRecord.household_id == household_id) if not include_deleted: query = query.where(GiftRecord.deleted_at.is_(None)) return query.order_by(GiftRecord.record_time.desc(), GiftRecord.id.desc()) def list_gift_records(household_id: int, *, include_deleted: bool = False) -> Sequence[GiftRecord]: result = db.session.execute( gift_record_query(household_id, include_deleted=include_deleted), ) return result.scalars().all() def get_gift_record_or_none(record_id: int, *, include_deleted: bool = False) -> GiftRecord | None: result = db.session.execute( gift_record_query(include_deleted=include_deleted).where(GiftRecord.id == record_id), ) return result.scalar_one_or_none() def build_new_gift_record_draft(household_id: int) -> GiftRecord: record = GiftRecord() record.household_id = household_id record.record_type = "cash_gift" record.amount = Decimal("0.00") record.gift_name = None record.estimated_value = None record.method_option_id = None record.scene_option_id = None record.record_time = datetime.now(UTC).replace(tzinfo=None) record.note = None return record def household_has_active_gift_records(household_id: int) -> bool: return get_gift_record_count(household_id) > 0 def get_gift_record_count(household_id: int) -> int: result = db.session.execute( db.select(GiftRecord.id) .where( GiftRecord.household_id == household_id, GiftRecord.deleted_at.is_(None), ) .limit(1), ) return 1 if result.scalar_one_or_none() is not None else 0 def parse_gift_record_form( form: dict[str, str], *, household_id: int, valid_method_ids: set[int], valid_scene_ids: set[int], ) -> tuple[dict[str, object], list[str]]: errors: list[str] = [] record_type = form.get("record_type", "").strip() if record_type not in GIFT_RECORD_TYPE_OPTIONS: errors.append("礼金记录类型不合法。") amount = _parse_optional_decimal(form.get("amount", ""), field_label="礼金金额", errors=errors) estimated_value = _parse_optional_decimal(form.get("estimated_value", ""), field_label="礼品估值", errors=errors) gift_name = form.get("gift_name", "").strip() or None if record_type in CASH_GIFT_RECORD_TYPES and amount is None: errors.append("现金类礼金记录必须填写礼金金额。") if record_type == "physical_gift" and not gift_name: errors.append("礼品记录必须填写礼品名称。") method_option_id = _parse_optional_option_id( form.get("method_option_id", ""), valid_ids=valid_method_ids, field_label="礼金方式", errors=errors, ) scene_option_id = _parse_optional_option_id( form.get("scene_option_id", ""), valid_ids=valid_scene_ids, field_label="记录场景", errors=errors, ) record_time = _parse_record_time(form.get("record_time", ""), errors) if errors: return {}, errors return { "household_id": household_id, "record_type": record_type, "amount": amount, "gift_name": gift_name, "estimated_value": estimated_value, "method_option_id": method_option_id, "scene_option_id": scene_option_id, "record_time": record_time, "note": form.get("note", "").strip() or None, }, [] def serialize_gift_record_snapshot(record: GiftRecord) -> dict[str, object]: return { "id": record.id, "household_id": record.household_id, "record_type": record.record_type, "amount": str(record.amount) if record.amount is not None else None, "gift_name": record.gift_name, "estimated_value": str(record.estimated_value) if record.estimated_value is not None else None, "method_option_id": record.method_option_id, "scene_option_id": record.scene_option_id, "record_time": record.record_time.isoformat(sep=" ") if record.record_time is not None else None, "deleted_at": record.deleted_at.isoformat(sep=" ") if record.deleted_at is not None else None, "note": record.note, } def recalculate_household_gift_summary(household: Household) -> None: records = list_gift_records(household.id) cash_records = [record for record in records if record.record_type in CASH_GIFT_RECORD_TYPES] household.total_gift_amount = sum( (record.amount for record in cash_records if record.amount is not None), start=Decimal("0.00"), ).quantize(Decimal("0.01")) household.gift_method_option_id = next( (record.method_option_id for record in cash_records if record.method_option_id is not None), None, ) household.gift_scene_option_id = next( (record.scene_option_id for record in cash_records if record.scene_option_id is not None), None, ) def gift_record_type_label(record_type: str | None) -> str: if not record_type: return "-" normalized_record_type = record_type.strip() if not normalized_record_type: return "-" return GIFT_RECORD_TYPE_LABELS.get(normalized_record_type, normalized_record_type) def _parse_optional_decimal(raw_value: str, *, field_label: str, errors: list[str]) -> Decimal | None: value = raw_value.strip() if not value: return None try: parsed = Decimal(value) except InvalidOperation: errors.append(f"{field_label}必须是合法金额。") return None if parsed < 0: errors.append(f"{field_label}不能为负数。") return None return parsed.quantize(Decimal("0.01")) def _parse_optional_option_id(raw_value: str, *, valid_ids: set[int], field_label: str, errors: list[str]) -> int | None: value = raw_value.strip() if not value: return None try: parsed = int(value) except ValueError: errors.append(f"{field_label}选项不合法。") return None if parsed not in valid_ids: errors.append(f"{field_label}选项不合法。") return None return parsed def _parse_record_time(raw_value: str, errors: list[str]) -> datetime | None: value = raw_value.strip() if not value: errors.append("记录时间不能为空。") return None try: parsed = datetime.fromisoformat(value) except ValueError: errors.append("记录时间格式不合法。") return None if parsed.tzinfo is not None: parsed = parsed.astimezone(UTC).replace(tzinfo=None) return parsed