You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
156 lines
4.8 KiB
156 lines
4.8 KiB
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from functools import wraps
|
|
from typing import TypeVar
|
|
from urllib.parse import urljoin, urlparse
|
|
|
|
try:
|
|
from typing import ParamSpec
|
|
except ImportError: # pragma: no cover - Python 3.10+ uses typing.ParamSpec directly
|
|
from typing_extensions import ParamSpec
|
|
|
|
from flask import flash, g, redirect, request, session, url_for
|
|
from flask.typing import ResponseReturnValue
|
|
|
|
from app.extensions import db
|
|
from app.models import Account
|
|
from app.services import account_snapshot, write_audit_log
|
|
|
|
SESSION_ACCOUNT_ID_KEY = "account_id"
|
|
ALLOWED_ACCOUNT_ROLES = {"admin", "editor", "entry_only", "quick_editor"}
|
|
ACTIVE_ACCOUNT_STATUS = "active"
|
|
|
|
P = ParamSpec("P")
|
|
R = TypeVar("R")
|
|
|
|
|
|
def get_current_account() -> Account | None:
    """Resolve the logged-in, active account referenced by the session.

    Returns:
        The ``Account`` whose id is stored under ``SESSION_ACCOUNT_ID_KEY``,
        provided that id is an ``int``, the row exists, and its ``status``
        equals ``ACTIVE_ACCOUNT_STATUS``. Otherwise ``None``.

    Side effects:
        When the stored id is an int but points at a missing or non-active
        account, the id is removed from the session before returning ``None``.
    """
    stored_id = session.get(SESSION_ACCOUNT_ID_KEY)
    if isinstance(stored_id, int):
        candidate = db.session.get(Account, stored_id)
        if candidate is not None and candidate.status == ACTIVE_ACCOUNT_STATUS:
            return candidate
        # Missing or deactivated account: drop the stale id from the session.
        session.pop(SESSION_ACCOUNT_ID_KEY, None)
    return None
|
|
|
|
|
|
def load_current_account() -> None:
    """Store the result of ``get_current_account()`` (possibly None) on ``g.current_account``."""
    resolved = get_current_account()
    g.current_account = resolved
|
|
|
|
|
|
def login_account(account: Account) -> None:
    """Start a fresh, permanent session bound to ``account``.

    Everything already in the session is discarded first, so no data from a
    previous session carries over into the authenticated one.

    Args:
        account: The account being logged in; its ``id`` is stored in the session.
    """
    session.clear()
    session.permanent = True
    session[SESSION_ACCOUNT_ID_KEY] = account.id
|
|
|
|
|
|
|
|
def logout_account() -> None:
    """End the current login by removing every key from the session."""
    session.clear()
|
|
|
|
|
|
|
|
def is_authenticated() -> bool:
    """Report whether ``g.current_account`` holds an ``Account`` instance.

    Returns False when the attribute is missing or holds anything else
    (e.g. ``None``).
    """
    current = getattr(g, "current_account", None)
    return isinstance(current, Account)
|
|
|
|
|
|
|
|
def _login_redirect_target() -> str:
    """Build the login URL whose ``next`` parameter points back at this request.

    Werkzeug's ``request.full_path`` always contains the ``?`` separator, even
    when the query string is empty. In that case the bare ``request.path`` is
    used instead. Trimming with ``rstrip("?")`` would also strip literal ``?``
    characters that end a real query string.
    """
    next_path = request.full_path if request.query_string else request.path
    return url_for("auth.login", next=next_path)


def _redirect_unauthenticated() -> ResponseReturnValue:
    """Flash a login prompt, audit the blocked request, and redirect to login.

    Shared by ``login_required`` and ``role_required`` so both record the same
    ``auth_unauthenticated`` audit entry.
    """
    # Message: "Please log in first."
    flash("请先登录。", "warning")
    redirect_target = _login_redirect_target()
    write_audit_log(
        action_type="auth_unauthenticated",
        target_type="route",
        target_display_name=request.path,
        after_data={
            "reason": "login_required",
            "redirect_to": redirect_target,
        },
    )
    return redirect(redirect_target)


def login_required(view: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
    """Decorate ``view`` so only authenticated requests reach it.

    A request counts as authenticated when ``is_authenticated()`` is True.
    Otherwise the user gets a warning flash and an ``auth_unauthenticated``
    audit entry is written. The user is then redirected to ``auth.login``,
    with ``next`` set to the current path and query.

    Args:
        view: The view function to protect.

    Returns:
        A wrapper with the same signature as ``view``. It returns either the
        view's result or a redirect response.
    """

    @wraps(view)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
        if not is_authenticated():
            return _redirect_unauthenticated()
        return view(*args, **kwargs)

    return wrapped


def role_required(*roles: str) -> Callable[[Callable[P, R]], Callable[P, R | ResponseReturnValue]]:
    """Build a decorator that admits only accounts whose ``role`` is in ``roles``.

    - Requests without an ``Account`` on ``g.current_account`` are handled
      exactly like ``login_required``.
    - Authenticated accounts with any other role get a warning flash and an
      ``auth_forbidden`` audit entry. They are then redirected to
      ``default_redirect_path(account)``.
    - Calling ``role_required()`` with no roles refuses every account.

    Args:
        *roles: Role names allowed to call the decorated view.

    Returns:
        A decorator that wraps a view while preserving its signature.
    """
    required_roles = set(roles)

    def decorator(view: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
        @wraps(view)
        def wrapped(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
            current_account = getattr(g, "current_account", None)
            if not isinstance(current_account, Account):
                return _redirect_unauthenticated()

            if current_account.role not in required_roles:
                # Message: "You do not have permission to access this page."
                flash("您没有权限访问该页面。", "warning")
                redirect_target = default_redirect_path(current_account)
                write_audit_log(
                    action_type="auth_forbidden",
                    target_type="route",
                    actor=current_account,
                    target_display_name=request.path,
                    after_data={
                        "account": account_snapshot(current_account),
                        "required_roles": sorted(required_roles),
                        "redirect_to": redirect_target,
                    },
                )
                return redirect(redirect_target)

            return view(*args, **kwargs)

        return wrapped

    return decorator
|
|
|
|
|
|
|
|
def default_redirect_path(account: Account) -> str:
    """Return the landing URL for ``account`` based on its role.

    The roles ``"entry_only"`` and ``"quick_editor"`` land on ``quick_entry.index``.
    Every other role lands on ``main.index``.
    """
    quick_entry_roles = {"entry_only", "quick_editor"}
    endpoint = "quick_entry.index" if account.role in quick_entry_roles else "main.index"
    return url_for(endpoint)
|
|
|
|
|
|
def is_safe_redirect_target(target: str | None, *, host_url: str | None = None) -> bool:
    """Return True if ``target`` may be used as a post-login redirect.

    The caller redirects to the *raw* ``target`` string. It must therefore be
    judged the way a browser will resolve it, not after ``urljoin``
    normalisation. A target is accepted only when it cannot leave this host:

    * Backslashes, ASCII control characters and leading whitespace are
      rejected. Browsers strip those characters, or treat ``\\`` as ``/``,
      before parsing. ``/\\evil.com`` or `` //evil.com`` would otherwise
      open ``evil.com``.
    * Absolute or scheme-relative URLs must use ``http``/``https`` (or no
      scheme) and carry a ``netloc`` exactly equal to this host's.
      ``https:///evil.com`` (empty netloc) is refused because browsers skip
      the extra slashes and read ``evil.com`` as the host.
    * Relative references (``/dashboard``, ``page?x=1``) are allowed unless
      they start with ``//``. Browsers read ``///evil.com`` as host ``evil.com``.

    Args:
        target: Candidate URL, typically the untrusted ``next`` parameter.
        host_url: Base URL of the application. Defaults to
            ``request.host_url``; pass it explicitly to use this check outside
            a request context.

    Returns:
        True if the redirect stays on this host. False otherwise, including
        for empty or unparseable input.
    """
    if not target:
        return False

    # Characters that browsers normalise away before URL parsing.
    if "\\" in target or target != target.lstrip():
        return False
    if any(ord(char) < 0x20 or ord(char) == 0x7F for char in target):
        return False

    if host_url is None:
        host_url = request.host_url

    try:
        ref_url = urlparse(host_url)
        parsed = urlparse(target)
    except ValueError:
        # Malformed input such as an unterminated IPv6 literal ("http://[x").
        return False

    if parsed.scheme or parsed.netloc:
        return (
            (not parsed.scheme or parsed.scheme in {"http", "https"})
            and bool(parsed.netloc)
            and parsed.netloc == ref_url.netloc
        )

    # Purely relative reference: stays on this host unless it begins with "//".
    return not target.startswith("//")
|
|
|
|
|
|
def resolve_post_login_redirect(account: Account) -> str:
    """Choose where to send ``account`` right after a successful login.

    The ``next`` value from the request (query string or form) is used only
    when ``is_safe_redirect_target`` accepts it. Otherwise the role-based
    ``default_redirect_path(account)`` is returned.
    """
    candidate = request.values.get("next")
    if not is_safe_redirect_target(candidate):
        return default_redirect_path(account)
    return str(candidate)
|
|
|