"""Session-based authentication helpers and view decorators."""

from __future__ import annotations

from collections.abc import Callable
from functools import wraps
from typing import TypeVar
from urllib.parse import urljoin, urlparse

try:
    from typing import ParamSpec
except ImportError:  # pragma: no cover - Python 3.10+ uses typing.ParamSpec directly
    from typing_extensions import ParamSpec

from flask import flash, g, redirect, request, session, url_for
from flask.typing import ResponseReturnValue

from app.extensions import db
from app.models import Account
from app.services import account_snapshot, write_audit_log

SESSION_ACCOUNT_ID_KEY = "account_id"
ALLOWED_ACCOUNT_ROLES = {"admin", "editor", "entry_only", "quick_editor"}
ACTIVE_ACCOUNT_STATUS = "active"

P = ParamSpec("P")
R = TypeVar("R")


def get_current_account() -> Account | None:
    """Return the active account referenced by the session, if any.

    Returns:
        The ``Account`` whose id is stored under ``SESSION_ACCOUNT_ID_KEY``
        when it exists and its status equals ``ACTIVE_ACCOUNT_STATUS``;
        otherwise ``None``. When the stored id points at a missing or
        non-active account, the id is also removed from the session.
    """
    account_id = session.get(SESSION_ACCOUNT_ID_KEY)
    if not isinstance(account_id, int):
        return None

    account = db.session.get(Account, account_id)
    if account is None or account.status != ACTIVE_ACCOUNT_STATUS:
        # The stored id is stale (deleted or deactivated account): forget it.
        session.pop(SESSION_ACCOUNT_ID_KEY, None)
        return None
    return account


def load_current_account() -> None:
    """Resolve the session's account and store it on ``g.current_account``."""
    g.current_account = get_current_account()


def login_account(account: Account) -> None:
    """Start a fresh permanent session bound to ``account``.

    Args:
        account: The account whose ``id`` is written to the session.
    """
    # Drop everything stored before login so it is not carried into the
    # authenticated session.
    session.clear()
    session[SESSION_ACCOUNT_ID_KEY] = account.id
    session.permanent = True


def logout_account() -> None:
    """Remove all data from the current session."""
    session.clear()


def is_authenticated() -> bool:
    """Return ``True`` when ``g.current_account`` holds an ``Account``."""
    return isinstance(getattr(g, "current_account", None), Account)


def _login_next_target() -> str:
    """Return the current request's path (plus query string) for ``next``."""
    # ``request.full_path`` always contains a "?" even without a query string.
    # Only use it when a query string exists, so a "?" that legitimately ends
    # the query (or the path) is never stripped by accident.
    if request.query_string:
        return request.full_path
    return request.path


def _redirect_unauthenticated() -> ResponseReturnValue:
    """Flash a login prompt, audit the attempt, and redirect to the login page."""
    flash("请先登录。", "warning")
    redirect_target = url_for("auth.login", next=_login_next_target())
    write_audit_log(
        action_type="auth_unauthenticated",
        target_type="route",
        target_display_name=request.path,
        after_data={
            "reason": "login_required",
            "redirect_to": redirect_target,
        },
    )
    return redirect(redirect_target)


def login_required(view: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
    """Decorate ``view`` so it only runs for an authenticated request.

    Unauthenticated requests are audited and redirected to ``auth.login``
    with the current location passed as ``next``.
    """

    @wraps(view)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
        if not is_authenticated():
            return _redirect_unauthenticated()
        return view(*args, **kwargs)

    return wrapped


def role_required(
    *roles: str,
) -> Callable[[Callable[P, R]], Callable[P, R | ResponseReturnValue]]:
    """Build a decorator that restricts a view to accounts with given roles.

    Args:
        *roles: Role names allowed to access the decorated view.

    Returns:
        A decorator. The wrapped view redirects unauthenticated requests to
        the login page, and redirects authenticated accounts whose role is
        not in ``roles`` to their default landing page; both cases are
        written to the audit log.
    """
    required_roles = set(roles)

    def decorator(view: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
        @wraps(view)
        def wrapped(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
            current_account = getattr(g, "current_account", None)
            if not isinstance(current_account, Account):
                return _redirect_unauthenticated()

            if current_account.role not in required_roles:
                flash("您没有权限访问该页面。", "warning")
                redirect_target = default_redirect_path(current_account)
                write_audit_log(
                    action_type="auth_forbidden",
                    target_type="route",
                    actor=current_account,
                    target_display_name=request.path,
                    after_data={
                        "account": account_snapshot(current_account),
                        "required_roles": sorted(required_roles),
                        "redirect_to": redirect_target,
                    },
                )
                return redirect(redirect_target)

            return view(*args, **kwargs)

        return wrapped

    return decorator


def default_redirect_path(account: Account) -> str:
    """Return the landing URL for ``account`` based on its role."""
    if account.role in {"entry_only", "quick_editor"}:
        return url_for("quick_entry.index")
    return url_for("main.index")


def _has_unsafe_redirect_chars(target: str) -> bool:
    """Return ``True`` if ``target`` has characters browsers reinterpret."""
    # Browsers strip surrounding whitespace, drop control characters, and
    # treat "\" like "/", so such a value can point at another host even
    # though Python's URL parsing resolves it to the current one.
    if target != target.strip():
        return True
    return any(ch == "\\" or ch < " " or ch == "\x7f" for ch in target)


def is_safe_redirect_target(target: str | None) -> bool:
    """Return ``True`` if ``target`` is a redirect that stays on this host.

    Args:
        target: A user-supplied redirect location (may be ``None``).

    Returns:
        ``True`` only when ``target`` is non-empty, contains no backslashes,
        control characters or surrounding whitespace, does not start with
        ``///``, does not carry a scheme without a host, and resolves
        against ``request.host_url`` to an http(s) URL with the same netloc.
    """
    if not target:
        return False
    # "///host" has an empty netloc for urlparse, but browsers skip the
    # extra slashes and navigate to "host".
    if _has_unsafe_redirect_chars(target) or target.startswith("///"):
        return False

    host_url = request.host_url
    try:
        raw_url = urlparse(target)
        ref_url = urlparse(host_url)
        test_url = urlparse(urljoin(host_url, target))
    except ValueError:
        # Malformed input (e.g. an unterminated IPv6 literal) is never safe.
        return False

    # "http:///host" or "https:host" resolve to the current host under
    # urljoin, yet browsers may resolve them to a different one.
    if raw_url.scheme and not raw_url.netloc:
        return False

    return test_url.scheme in {"http", "https"} and test_url.netloc == ref_url.netloc


def resolve_post_login_redirect(account: Account) -> str:
    """Return where to send ``account`` after a successful login.

    Uses the request's ``next`` value when it passes
    ``is_safe_redirect_target``; otherwise falls back to
    ``default_redirect_path(account)``.
    """
    next_target = request.values.get("next")
    if is_safe_redirect_target(next_target):
        return str(next_target)
    return default_redirect_path(account)