You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

156 lines
4.8 KiB

from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TypeVar
from urllib.parse import urljoin, urlparse
try:
from typing import ParamSpec
except ImportError: # pragma: no cover - Python 3.10+ uses typing.ParamSpec directly
from typing_extensions import ParamSpec
from flask import flash, g, redirect, request, session, url_for
from flask.typing import ResponseReturnValue
from app.extensions import db
from app.models import Account
from app.services import account_snapshot, write_audit_log
# Session key under which the logged-in account's id is stored.
SESSION_ACCOUNT_ID_KEY = "account_id"
# Known account role names. NOTE(review): not referenced in this part of the
# module -- presumably consumed by validation code elsewhere; confirm.
ALLOWED_ACCOUNT_ROLES = {"admin", "editor", "entry_only", "quick_editor"}
# Only accounts whose ``status`` equals this value are treated as logged in.
ACTIVE_ACCOUNT_STATUS = "active"
# Type variables that let the auth decorators preserve the wrapped view's
# parameter list (P) and return type (R) in their annotations.
P = ParamSpec("P")
R = TypeVar("R")
def get_current_account() -> Account | None:
    """Return the active account referenced by the session, or ``None``.

    The value stored under ``SESSION_ACCOUNT_ID_KEY`` is looked up through
    ``db.session``. When no row is found, or the account's status differs
    from ``ACTIVE_ACCOUNT_STATUS``, the key is removed from the session.
    A missing or non-integer session value returns ``None`` and leaves the
    session untouched.
    """
    stored_id = session.get(SESSION_ACCOUNT_ID_KEY)
    if isinstance(stored_id, int):
        candidate = db.session.get(Account, stored_id)
        is_stale = candidate is None or candidate.status != ACTIVE_ACCOUNT_STATUS
        if not is_stale:
            return candidate
        # The id points at a deleted or non-active account: forget it.
        session.pop(SESSION_ACCOUNT_ID_KEY, None)
    return None
def load_current_account() -> None:
    """Resolve the session's account and store it on ``g.current_account``."""
    resolved_account = get_current_account()
    g.current_account = resolved_account
def login_account(account: Account) -> None:
    """Start a fresh session for *account*.

    The existing session content is discarded first, then the account id is
    stored under ``SESSION_ACCOUNT_ID_KEY`` and the session is flagged as
    permanent.
    """
    # Wipe whatever the pre-login session held before recording the new id.
    session.clear()
    account_id = account.id
    session[SESSION_ACCOUNT_ID_KEY] = account_id
    session.permanent = True
def logout_account() -> None:
    """End the current login by removing every key from the session."""
    session.clear()
def is_authenticated() -> bool:
    """Return ``True`` when ``g.current_account`` holds an ``Account`` instance.

    A missing ``current_account`` attribute on ``g`` counts as unauthenticated.
    """
    current = getattr(g, "current_account", None)
    return isinstance(current, Account)
def _redirect_unauthenticated() -> ResponseReturnValue:
    """Flash a login prompt, audit the rejected request, and redirect to login.

    Shared by ``login_required`` and ``role_required`` so both decorators
    treat an unauthenticated request identically.
    """
    flash("请先登录。", "warning")
    # rstrip("?") drops the bare trailing "?" that full_path carries when the
    # request has no query string, keeping the ``next`` value clean.
    redirect_target = url_for("auth.login", next=request.full_path.rstrip("?"))
    write_audit_log(
        action_type="auth_unauthenticated",
        target_type="route",
        target_display_name=request.path,
        after_data={
            "reason": "login_required",
            "redirect_to": redirect_target,
        },
    )
    return redirect(redirect_target)


def login_required(view: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
    """Decorate *view* so it only runs for an authenticated request.

    When ``is_authenticated()`` is false the request is audit-logged and
    redirected to ``auth.login`` with the original path as ``next``;
    otherwise the view is called with its arguments unchanged.
    """

    @wraps(view)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
        if not is_authenticated():
            return _redirect_unauthenticated()
        return view(*args, **kwargs)

    return wrapped


def role_required(*roles: str) -> Callable[[Callable[P, R]], Callable[P, R | ResponseReturnValue]]:
    """Build a decorator that restricts a view to accounts holding one of *roles*.

    Unauthenticated requests are handled exactly as in ``login_required``.
    An authenticated account whose ``role`` is not in *roles* gets a warning
    flash, an ``auth_forbidden`` audit entry, and a redirect to
    ``default_redirect_path(account)``.
    """
    required_roles = set(roles)

    def decorator(view: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
        @wraps(view)
        def wrapped(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
            current_account = getattr(g, "current_account", None)
            if not isinstance(current_account, Account):
                return _redirect_unauthenticated()
            if current_account.role not in required_roles:
                flash("您没有权限访问该页面。", "warning")
                redirect_target = default_redirect_path(current_account)
                write_audit_log(
                    action_type="auth_forbidden",
                    target_type="route",
                    actor=current_account,
                    target_display_name=request.path,
                    after_data={
                        "account": account_snapshot(current_account),
                        # Sorted so the logged list is stable across runs.
                        "required_roles": sorted(required_roles),
                        "redirect_to": redirect_target,
                    },
                )
                return redirect(redirect_target)
            return view(*args, **kwargs)

        return wrapped

    return decorator
def default_redirect_path(account: Account) -> str:
    """Return the landing URL for *account* based on its role.

    ``entry_only`` and ``quick_editor`` accounts go to ``quick_entry.index``;
    every other role goes to ``main.index``.
    """
    quick_entry_roles = {"entry_only", "quick_editor"}
    endpoint = "quick_entry.index" if account.role in quick_entry_roles else "main.index"
    return url_for(endpoint)
def is_safe_redirect_target(target: str | None) -> bool:
    """Return ``True`` when *target* is safe to use as a same-site redirect.

    A target is accepted only if, resolved against ``request.host_url``, it
    is an http(s) URL on the same network location as the current request.

    Before that comparison, inputs that Python's URL functions resolve as
    local but that browsers resolve to a different host are rejected:

    * leading/trailing whitespace or any ASCII control character (browsers
      strip these, so ``"/\\t/evil.com"`` becomes ``"//evil.com"``);
    * backslashes (browsers treat ``"\\"`` like ``"/"`` for http(s) URLs);
    * a scheme or a leading ``//`` with an empty host, e.g. ``"///evil.com"``
      or ``"http:///evil.com"`` (browsers skip the extra slashes and read
      ``evil.com`` as the host).

    Args:
        target: Candidate redirect URL or path, typically the ``next``
            request parameter. ``None`` and ``""`` are never safe.
    """
    if not target:
        return False
    if target != target.strip():
        return False
    if "\\" in target or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in target):
        return False
    parsed_target = urlparse(target)
    # An explicit scheme or "//" prefix announces an authority; if none was
    # parsed, the browser and urljoin() would disagree about the destination.
    if (parsed_target.scheme or target.startswith("//")) and not parsed_target.netloc:
        return False
    host_url = request.host_url
    ref_url = urlparse(host_url)
    test_url = urlparse(urljoin(host_url, target))
    return test_url.scheme in {"http", "https"} and test_url.netloc == ref_url.netloc
def resolve_post_login_redirect(account: Account) -> str:
    """Pick the URL to send *account* to after a successful login.

    The ``next`` request value is used when ``is_safe_redirect_target``
    accepts it; otherwise the role-based default from
    ``default_redirect_path`` is returned.
    """
    requested = request.values.get("next")
    if not is_safe_redirect_target(requested):
        return default_redirect_path(account)
    return str(requested)