Files
Foxel/domain/audit/decorator.py

205 lines
7.3 KiB
Python

import inspect
import time
from functools import wraps
from typing import Any, Dict, Mapping, Optional
import jwt
from fastapi import Request
from jwt.exceptions import InvalidTokenError
from domain.audit.service import AuditService
from domain.audit.types import AuditAction
from domain.auth.service import ALGORITHM
from domain.config.service import ConfigService
from models.database import UserAccount
def _extract_request(bound_args: Mapping[str, Any]) -> Request | None:
    """Return the first bound argument value that is a ``Request``.

    Args:
        bound_args: Mapping of parameter name to the value bound for a call.

    Returns:
        The first value (in mapping order) that is a ``Request`` instance,
        or ``None`` when no argument is one.
    """
    return next(
        (candidate for candidate in bound_args.values() if isinstance(candidate, Request)),
        None,
    )
async def _resolve_user(
    request: Request | None,
    user_obj: Any | None,
) -> tuple[Optional[int], Optional[str]]:
    """Work out which user is behind a call, as ``(user_id, username)``.

    The bearer token in the request's ``Authorization`` header is tried
    first: it is decoded with the configured ``SECRET_KEY`` and the username
    is read from the ``sub`` (or ``username``) claim, then looked up through
    ``UserAccount`` to obtain the id. Only when that yields neither an id nor
    a username is ``user_obj`` inspected (attributes first, then dict keys).

    Args:
        request: The incoming request, if one was found among the arguments.
        user_obj: Whatever was bound to the user parameter of the endpoint;
            may be an object with ``id``/``user_id``/``username``/``name``
            attributes, a dict with those keys, or ``None``.

    Returns:
        A ``(user_id, username)`` pair; either element may be ``None``.
    """
    resolved_id: int | None = None
    resolved_name: str | None = None

    header = None
    if request:
        header = request.headers.get("authorization") or request.headers.get("Authorization")

    if header and header.lower().startswith("bearer "):
        # The prefix check above guarantees a space, so the tail is the token.
        token = header.partition(" ")[2]
        try:
            secret = await ConfigService.get_secret_key("SECRET_KEY")
            claims = jwt.decode(token, secret, algorithms=[ALGORITHM])
            resolved_name = claims.get("sub") or claims.get("username")
            if resolved_name:
                account = await UserAccount.get_or_none(username=resolved_name)
                resolved_id = account.id if account else None
        except Exception:
            # Best effort: an invalid token (InvalidTokenError is an Exception
            # subclass) or a failed lookup must not break the audited call.
            pass

    token_identified_someone = resolved_id is not None or resolved_name is not None
    if token_identified_someone or user_obj is None:
        return resolved_id, resolved_name

    # Fall back to the user object handed to the endpoint.
    resolved_id = getattr(user_obj, "id", None) or getattr(user_obj, "user_id", None)
    resolved_name = getattr(user_obj, "username", None) or getattr(user_obj, "name", None)
    if isinstance(user_obj, dict):
        resolved_id = user_obj.get("id", user_obj.get("user_id", resolved_id))
        resolved_name = user_obj.get("username", user_obj.get("name", resolved_name))
    return resolved_id, resolved_name
def _extract_body_fields(bound_args: Mapping[str, Any], body_fields: list[str] | None, redact_fields: list[str] | None):
if not body_fields:
return None
body: Dict[str, Any] = {}
redacts = set(redact_fields or [])
for value in bound_args.values():
data: Optional[Dict[str, Any]] = None
if hasattr(value, "model_dump"):
try:
data = value.model_dump()
except Exception:
data = None
elif hasattr(value, "dict"):
try:
data = value.dict()
except Exception:
data = None
elif isinstance(value, dict):
data = value
elif hasattr(value, "__dict__"):
data = dict(value.__dict__)
if not isinstance(data, dict):
continue
for field in body_fields:
if field in data and field not in body:
body[field] = data[field]
if not body:
return None
for field in redacts:
if field in body:
body[field] = "<redacted>"
return body
def _build_request_params(request: Request | None) -> Dict[str, Any] | None:
    """Snapshot the query and path parameters of a request.

    Args:
        request: The incoming request, or a falsy value when none is available.

    Returns:
        A dict with a ``"query"`` and/or ``"path"`` entry for each non-empty
        parameter set, or ``None`` when there is no request or both are empty.
    """
    if not request:
        return None

    sections: Dict[str, Any] = {
        "query": dict(request.query_params),
        "path": dict(request.path_params or {}),
    }
    # Drop the sections that turned out empty.
    populated = {label: values for label, values in sections.items() if values}
    return populated or None
def _get_client_ip(request: Request | None) -> str | None:
    """Determine the client address for a request.

    Sources are consulted in this order, and the first non-blank value wins:
    the ``CF-Connecting-IP`` header, the ``X-Real-IP`` header, the first
    entry of ``X-Forwarded-For`` that is neither blank nor ``unknown``, and
    finally ``request.client.host``.

    Args:
        request: The incoming request, or a falsy value when none is available.

    Returns:
        The address as a string, or ``None`` when it cannot be determined.
    """
    if not request:
        return None

    def read_header(lower_name: str, canonical_name: str) -> str | None:
        return request.headers.get(lower_name) or request.headers.get(canonical_name)

    single_value_headers = (
        ("cf-connecting-ip", "CF-Connecting-IP"),
        ("x-real-ip", "X-Real-IP"),
    )
    for names in single_value_headers:
        raw = read_header(*names)
        address = raw.strip() if raw else ""
        if address:
            return address

    forwarded_chain = read_header("x-forwarded-for", "X-Forwarded-For") or ""
    hops = (hop.strip() for hop in forwarded_chain.split(","))
    first_usable = next((hop for hop in hops if hop and hop.lower() != "unknown"), None)
    if first_usable is not None:
        return first_usable

    return request.client.host if request.client else None
def _status_code_from_response(response: Any) -> int:
if hasattr(response, "status_code"):
try:
return int(getattr(response, "status_code"))
except Exception:
pass
return 200
def audit(
    *,
    action: AuditAction,
    description: str | None = None,
    body_fields: list[str] | None = None,
    redact_fields: list[str] | None = None,
    user_kw: str = "current_user",
):
    """Create a decorator that writes an audit record for each call.

    The returned wrapper is always a coroutine function. On every call it:

    1. binds the call arguments to the wrapped function's signature and
       looks for a ``Request`` among them;
    2. resolves the acting user and snapshots the request parameters and
       the selected body fields;
    3. invokes the wrapped function, awaiting its result when awaitable;
    4. reports the outcome (status code, duration in milliseconds, success
       flag, error text) through ``AuditService.log``.

    An exception raised by the wrapped function is recorded with
    ``success=False`` and the exception's ``status_code`` attribute (``500``
    when absent), then re-raised unchanged. A failure while writing the
    audit record never propagates; it is only logged as a warning.

    Args:
        action: Audit action stored with every record.
        description: Optional free-text description stored with every record.
        body_fields: Names of argument fields to capture as the request body.
        redact_fields: Captured field names whose value must be masked.
        user_kw: Name of the wrapped function's parameter that carries the
            current user, used when the user cannot be derived from the
            request itself.

    Returns:
        A decorator to apply to a sync or async callable.
    """
    import logging  # local import: keeps this block free of file-level changes

    logger = logging.getLogger(__name__)

    def decorator(func):
        # Introspect once at decoration time instead of on every request.
        signature = inspect.signature(func)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            bound = signature.bind_partial(*args, **kwargs)
            bound.apply_defaults()
            request = _extract_request(bound.arguments)
            start = time.perf_counter()
            user_id, username = await _resolve_user(request, bound.arguments.get(user_kw))
            request_params = _build_request_params(request)
            request_body = _extract_body_fields(bound.arguments, body_fields, redact_fields)

            async def record(status_code: Any, success: bool, error: str | None) -> None:
                # Single place that writes the audit entry for both outcomes.
                duration_ms = round((time.perf_counter() - start) * 1000, 2)
                try:
                    await AuditService.log(
                        action=action,
                        description=description,
                        user_id=user_id,
                        username=username,
                        client_ip=_get_client_ip(request),
                        method=request.method if request else "",
                        path=request.url.path if request else func.__name__,
                        status_code=status_code,
                        duration_ms=duration_ms,
                        success=success,
                        request_params=request_params,
                        request_body=request_body,
                        error=error,
                    )
                except Exception:
                    # Auditing is best effort: never let it affect the
                    # audited call, but leave a trace that a record was lost.
                    logger.warning(
                        "Failed to write audit log for %s",
                        func.__name__,
                        exc_info=True,
                    )

            try:
                result = func(*args, **kwargs)
                if inspect.isawaitable(result):
                    result = await result
                status_code = _status_code_from_response(result)
            except Exception as exc:  # noqa: BLE001
                await record(getattr(exc, "status_code", 500), False, str(exc))
                raise

            await record(status_code, True, None)
            return result

        return wrapper

    return decorator