You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

244 lines
7.6 KiB

from __future__ import annotations
from collections.abc import Sequence
from datetime import UTC, datetime
from decimal import Decimal, InvalidOperation
from sqlalchemy import Select
from sqlalchemy.orm import selectinload
from app.extensions import db
from app.models import GiftRecord, Household
# Every record type accepted for a gift record, in display order.
GIFT_RECORD_TYPE_OPTIONS = (
    "cash_gift",
    "physical_gift",
    "pre_wedding_cash",
    "wedding_day_cash",
    "post_wedding_cash",
)

# Human-readable label shown for each record type.
GIFT_RECORD_TYPE_LABELS = {
    "cash_gift": "礼金",
    "physical_gift": "礼品",
    "pre_wedding_cash": "婚前礼金",
    "wedding_day_cash": "婚礼当天礼金",
    "post_wedding_cash": "婚后补记礼金",
}

# Record types that represent money; "physical_gift" is the only type left out.
CASH_GIFT_RECORD_TYPES = frozenset(
    (
        "cash_gift",
        "pre_wedding_cash",
        "wedding_day_cash",
        "post_wedding_cash",
    ),
)
def gift_record_query(
    household_id: int | None = None,
    *,
    include_deleted: bool = False,
) -> Select[tuple[GiftRecord]]:
    """Build the base SELECT statement for gift records.

    The statement eagerly loads the ``method_option`` and ``scene_option``
    relationships via ``selectinload`` and orders rows by ``record_time``
    descending, then ``id`` descending.

    Args:
        household_id: When given, restrict the statement to this household.
        include_deleted: When false (the default), rows whose ``deleted_at``
            is set are filtered out.

    Returns:
        The unexecuted statement, so callers can add further criteria.
    """
    statement = db.select(GiftRecord).options(
        selectinload(GiftRecord.method_option),
        selectinload(GiftRecord.scene_option),
    )

    # Collect the optional filters first, then apply them in the same order
    # the conditions are evaluated.
    filters = []
    if household_id is not None:
        filters.append(GiftRecord.household_id == household_id)
    if not include_deleted:
        filters.append(GiftRecord.deleted_at.is_(None))
    for condition in filters:
        statement = statement.where(condition)

    newest_first = (GiftRecord.record_time.desc(), GiftRecord.id.desc())
    return statement.order_by(*newest_first)
def list_gift_records(household_id: int, *, include_deleted: bool = False) -> Sequence[GiftRecord]:
    """Return all gift records of one household, in ``gift_record_query`` order.

    Args:
        household_id: Household whose records are fetched.
        include_deleted: Forwarded to ``gift_record_query``; when false,
            soft-deleted records are left out.
    """
    statement = gift_record_query(household_id, include_deleted=include_deleted)
    rows = db.session.execute(statement)
    return rows.scalars().all()
def get_gift_record_or_none(record_id: int, *, include_deleted: bool = False) -> GiftRecord | None:
    """Fetch a single gift record by primary key.

    Args:
        record_id: Value matched against ``GiftRecord.id``.
        include_deleted: Forwarded to ``gift_record_query``; when false,
            a soft-deleted record is treated as missing.

    Returns:
        The matching record, or ``None`` when no row matches.
    """
    statement = gift_record_query(include_deleted=include_deleted)
    statement = statement.where(GiftRecord.id == record_id)
    return db.session.execute(statement).scalar_one_or_none()
def build_new_gift_record_draft(household_id: int) -> GiftRecord:
    """Create an unsaved ``GiftRecord`` pre-filled with default values.

    The draft is a ``"cash_gift"`` record with a zero amount, no gift details,
    no method/scene option, no note, and ``record_time`` set to the current
    UTC time stored as a naive datetime. The record is not added to the
    session here.
    """
    # Current UTC time with the tzinfo stripped, i.e. a naive datetime.
    naive_utc_now = datetime.now(UTC).replace(tzinfo=None)
    defaults = {
        "household_id": household_id,
        "record_type": "cash_gift",
        "amount": Decimal("0.00"),
        "gift_name": None,
        "estimated_value": None,
        "method_option_id": None,
        "scene_option_id": None,
        "record_time": naive_utc_now,
        "note": None,
    }

    draft = GiftRecord()
    for attribute, value in defaults.items():
        setattr(draft, attribute, value)
    return draft
def household_has_active_gift_records(household_id: int) -> bool:
    """Return ``True`` when the household has at least one non-deleted gift record."""
    # Existence probe: fetching a single matching id is enough to answer.
    probe = (
        db.select(GiftRecord.id)
        .where(
            GiftRecord.household_id == household_id,
            GiftRecord.deleted_at.is_(None),
        )
        .limit(1)
    )
    first_match = db.session.execute(probe).scalar_one_or_none()
    return first_match is not None
def get_gift_record_count(household_id: int) -> int:
    """Return the number of non-deleted gift records of a household.

    The previous implementation selected one id with ``LIMIT 1`` and returned
    ``1`` or ``0``, so it could never report more than one record despite
    being named and typed as a count. It now runs a real ``COUNT`` aggregate.
    Callers that only compare the result with ``0`` behave as before.

    Args:
        household_id: Household whose active records are counted.

    Returns:
        The count of rows with a matching ``household_id`` and no
        ``deleted_at`` value; ``0`` when there are none.
    """
    # Local import keeps this block self-contained; sqlalchemy is already a
    # dependency of this module.
    from sqlalchemy import func

    statement = db.select(func.count(GiftRecord.id)).where(
        GiftRecord.household_id == household_id,
        GiftRecord.deleted_at.is_(None),
    )
    # An aggregate without GROUP BY always yields exactly one row.
    return int(db.session.execute(statement).scalar_one())
def parse_gift_record_form(
    form: dict[str, str],
    *,
    household_id: int,
    valid_method_ids: set[int],
    valid_scene_ids: set[int],
) -> tuple[dict[str, object], list[str]]:
    """Validate submitted gift-record fields and build the attribute payload.

    Every field is checked even after an earlier one fails, so the caller
    receives all validation messages at once.

    Args:
        form: Raw field values keyed by field name; missing keys count as "".
        household_id: Copied into the payload unchanged.
        valid_method_ids: Option ids accepted for ``method_option_id``.
        valid_scene_ids: Option ids accepted for ``scene_option_id``.

    Returns:
        ``(payload, [])`` when every check passes, otherwise ``({}, errors)``
        with the messages in the order the checks ran.
    """
    problems: list[str] = []

    def raw(field_name: str) -> str:
        # Missing fields are treated like empty input.
        return form.get(field_name, "")

    kind = raw("record_type").strip()
    if kind not in GIFT_RECORD_TYPE_OPTIONS:
        problems.append("礼金记录类型不合法。")

    cash_amount = _parse_optional_decimal(raw("amount"), field_label="礼金金额", errors=problems)
    gift_value = _parse_optional_decimal(raw("estimated_value"), field_label="礼品估值", errors=problems)
    gift_title = raw("gift_name").strip() or None

    # Cross-field rules: cash records need an amount, physical gifts a name.
    if cash_amount is None and kind in CASH_GIFT_RECORD_TYPES:
        problems.append("现金类礼金记录必须填写礼金金额。")
    if kind == "physical_gift" and not gift_title:
        problems.append("礼品记录必须填写礼品名称。")

    method_id = _parse_optional_option_id(
        raw("method_option_id"),
        valid_ids=valid_method_ids,
        field_label="礼金方式",
        errors=problems,
    )
    scene_id = _parse_optional_option_id(
        raw("scene_option_id"),
        valid_ids=valid_scene_ids,
        field_label="记录场景",
        errors=problems,
    )
    recorded_at = _parse_record_time(raw("record_time"), problems)

    if problems:
        return {}, problems

    payload: dict[str, object] = {
        "household_id": household_id,
        "record_type": kind,
        "amount": cash_amount,
        "gift_name": gift_title,
        "estimated_value": gift_value,
        "method_option_id": method_id,
        "scene_option_id": scene_id,
        "record_time": recorded_at,
        "note": raw("note").strip() or None,
    }
    return payload, []
def serialize_gift_record_snapshot(record: GiftRecord) -> dict[str, object]:
    """Convert a gift record into a plain dict of its column values.

    Decimal fields (``amount``, ``estimated_value``) become strings and
    datetime fields (``record_time``, ``deleted_at``) become ISO-8601 text
    with a space between date and time. ``None`` values stay ``None``.
    """

    def money_text(value):
        return None if value is None else str(value)

    def timestamp_text(moment):
        return None if moment is None else moment.isoformat(sep=" ")

    snapshot: dict[str, object] = {
        "id": record.id,
        "household_id": record.household_id,
        "record_type": record.record_type,
        "amount": money_text(record.amount),
        "gift_name": record.gift_name,
        "estimated_value": money_text(record.estimated_value),
        "method_option_id": record.method_option_id,
        "scene_option_id": record.scene_option_id,
        "record_time": timestamp_text(record.record_time),
        "deleted_at": timestamp_text(record.deleted_at),
        "note": record.note,
    }
    return snapshot
def recalculate_household_gift_summary(household: Household) -> None:
    """Refresh the household's cached gift totals from its active records.

    Only records whose type is in ``CASH_GIFT_RECORD_TYPES`` contribute.
    ``total_gift_amount`` becomes the sum of their amounts, quantized to two
    decimal places. ``gift_method_option_id`` and ``gift_scene_option_id``
    take the first non-null value found in ``list_gift_records`` order, or
    ``None`` when no cash record has one. The household is mutated in place;
    nothing is committed here.
    """
    running_total = Decimal("0.00")
    method_id = None
    scene_id = None

    for entry in list_gift_records(household.id):
        if entry.record_type not in CASH_GIFT_RECORD_TYPES:
            continue
        if entry.amount is not None:
            running_total = running_total + entry.amount
        # Keep only the first non-null option id seen for each field.
        if method_id is None and entry.method_option_id is not None:
            method_id = entry.method_option_id
        if scene_id is None and entry.scene_option_id is not None:
            scene_id = entry.scene_option_id

    household.total_gift_amount = running_total.quantize(Decimal("0.01"))
    household.gift_method_option_id = method_id
    household.gift_scene_option_id = scene_id
def gift_record_type_label(record_type: str | None) -> str:
    """Return the display label for a record type.

    ``None``, empty and whitespace-only input yield ``"-"``. A type without
    an entry in ``GIFT_RECORD_TYPE_LABELS`` is returned as-is, stripped of
    surrounding whitespace.
    """
    cleaned = (record_type or "").strip()
    if cleaned == "":
        return "-"
    return GIFT_RECORD_TYPE_LABELS.get(cleaned, cleaned)
def _parse_optional_decimal(raw_value: str, *, field_label: str, errors: list[str]) -> Decimal | None:
    """Parse an optional, non-negative money amount from form text.

    Args:
        raw_value: Text as submitted; surrounding whitespace is ignored.
        field_label: Field name used as the prefix of any error message.
        errors: Validation messages are appended to this list in place.

    Returns:
        The amount quantized to two decimal places, or ``None`` when the
        input is blank (no error recorded) or invalid (one error recorded).
    """
    value = raw_value.strip()
    if not value:
        return None

    invalid_message = f"{field_label}必须是合法金额。"
    try:
        parsed = Decimal(value)
    except InvalidOperation:
        errors.append(invalid_message)
        return None

    # Decimal() accepts "NaN"/"sNaN", but ordering comparisons on a NaN raise
    # InvalidOperation, so reject it before the sign check below.
    if parsed.is_nan():
        errors.append(invalid_message)
        return None
    if parsed < 0:
        errors.append(f"{field_label}不能为负数。")
        return None

    try:
        return parsed.quantize(Decimal("0.01"))
    except InvalidOperation:
        # Raised for "Infinity" and for magnitudes whose two-decimal form
        # exceeds the context precision (e.g. "1e30"); report it as invalid
        # input instead of letting the exception escape.
        errors.append(invalid_message)
        return None
def _parse_optional_option_id(raw_value: str, *, valid_ids: set[int], field_label: str, errors: list[str]) -> int | None:
    """Parse an optional option id and check it against the allowed ids.

    Args:
        raw_value: Text as submitted; surrounding whitespace is ignored.
        valid_ids: Ids that are acceptable for this field.
        field_label: Field name used as the prefix of the error message.
        errors: Validation messages are appended to this list in place.

    Returns:
        The parsed id, or ``None`` when the input is blank (no error
        recorded) or is not an integer in ``valid_ids`` (one error recorded).
    """
    text = raw_value.strip()
    if text == "":
        return None

    try:
        option_id = int(text)
    except ValueError:
        option_id = None

    # Non-numeric and unknown ids are reported with the same message.
    if option_id is None or option_id not in valid_ids:
        errors.append(f"{field_label}选项不合法。")
        return None
    return option_id
def _parse_record_time(raw_value: str, errors: list[str]) -> datetime | None:
value = raw_value.strip()
if not value:
errors.append("记录时间不能为空。")
return None
try:
parsed = datetime.fromisoformat(value)
except ValueError:
errors.append("记录时间格式不合法。")
return None
if parsed.tzinfo is not None:
parsed = parsed.astimezone(UTC).replace(tzinfo=None)
return parsed