"""Audit-logging decorator for request handlers.

The ``audit`` decorator wraps a handler, gathers request/user context from
the handler's bound arguments, runs the handler, and writes one record via
``AuditService.log`` for both successful and failing calls.
"""

import inspect
import logging
import time
from functools import wraps
from typing import Any, Dict, Mapping, Optional

import jwt
from fastapi import Request
from jwt.exceptions import InvalidTokenError

from domain.auth import ALGORITHM
from domain.config import ConfigService
from models.database import UserAccount
from .service import AuditService
from .types import AuditAction

logger = logging.getLogger(__name__)


def _extract_request(bound_args: Mapping[str, Any]) -> Request | None:
    """Return the first ``Request`` instance among the bound arguments, if any."""
    for value in bound_args.values():
        if isinstance(value, Request):
            return value
    return None


async def _resolve_user(
    request: Request | None, user_obj: Any | None
) -> tuple[Optional[int], Optional[str]]:
    """Work out ``(user_id, username)`` for the audit record.

    The bearer token in the ``Authorization`` header is tried first: it is
    decoded with the configured secret, the username is read from the ``sub``
    (or ``username``) claim, and the matching ``UserAccount`` supplies the id.
    Only when that yields neither an id nor a username is ``user_obj`` (an
    object or a dict) inspected for ``id``/``user_id`` and ``username``/``name``.

    Resolution is best-effort: any failure is logged and leaves the
    corresponding value as ``None``; this function does not raise for a bad
    token or a failed lookup.
    """
    user_id: int | None = None
    username: str | None = None

    if request is not None:
        auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
        if auth_header and auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1]
            try:
                secret = await ConfigService.get_secret_key("SECRET_KEY")
                payload = jwt.decode(token, secret, algorithms=[ALGORITHM])
                username = payload.get("sub") or payload.get("username")
                if username:
                    user = await UserAccount.get_or_none(username=username)
                    user_id = user.id if user else None
            except InvalidTokenError:
                # Expected for expired/forged tokens; the audit record is still
                # written, just without token-derived identity.
                logger.debug("Audit: bearer token could not be decoded")
            except Exception:  # noqa: BLE001 - auditing must never break the request
                logger.warning("Audit: failed to resolve user from bearer token", exc_info=True)

    if user_id is None and username is None and user_obj is not None:
        user_id = getattr(user_obj, "id", None) or getattr(user_obj, "user_id", None)
        username = getattr(user_obj, "username", None) or getattr(user_obj, "name", None)
        if isinstance(user_obj, dict):
            # Dict payloads take precedence over the attribute probes above.
            user_id = user_obj.get("id", user_obj.get("user_id", user_id))
            username = user_obj.get("username", user_obj.get("name", username))

    return user_id, username


def _as_dict(value: Any) -> Optional[Dict[str, Any]]:
    """Convert ``value`` to a dict via ``model_dump``/``dict``/``__dict__``, else ``None``.

    Only the first matching strategy is attempted; if that strategy raises,
    ``None`` is returned rather than falling through to the next one.
    """
    if hasattr(value, "model_dump"):
        try:
            return value.model_dump()
        except Exception:  # noqa: BLE001 - arbitrary user objects; stay best-effort
            return None
    if hasattr(value, "dict"):
        try:
            return value.dict()
        except Exception:  # noqa: BLE001 - arbitrary user objects; stay best-effort
            return None
    if isinstance(value, dict):
        return value
    if hasattr(value, "__dict__"):
        return dict(value.__dict__)
    return None


def _extract_body_fields(
    bound_args: Mapping[str, Any],
    body_fields: list[str] | None,
    redact_fields: list[str] | None,
):
    """Collect the requested ``body_fields`` from the handler's arguments.

    Each bound argument that can be viewed as a dict is searched in order; the
    first occurrence of a field wins. Fields listed in ``redact_fields`` are
    replaced with an empty string. Returns ``None`` when no fields were
    requested or none were found.
    """
    if not body_fields:
        return None

    body: Dict[str, Any] = {}
    for value in bound_args.values():
        data = _as_dict(value)
        if not isinstance(data, dict):
            continue
        for field in body_fields:
            if field in data and field not in body:
                body[field] = data[field]

    if not body:
        return None

    for field in set(redact_fields or []):
        if field in body:
            body[field] = ""
    return body


def _build_request_params(request: Request | None) -> Dict[str, Any] | None:
    """Return ``{"query": ..., "path": ...}`` for the request, omitting empty parts.

    Returns ``None`` when there is no request or both parts are empty.
    """
    if request is None:
        return None

    params: Dict[str, Any] = {}
    query = dict(request.query_params)
    if query:
        params["query"] = query
    path_params = dict(request.path_params or {})
    if path_params:
        params["path"] = path_params
    return params or None


def _get_client_ip(request: Request | None) -> str | None:
    """Determine the client IP, preferring proxy headers over the socket peer.

    Order: ``CF-Connecting-IP``, ``X-Real-IP``, the first usable entry of
    ``X-Forwarded-For`` (skipping blanks and ``unknown``), then
    ``request.client.host``.

    NOTE(review): these headers are taken at face value; confirm the
    deployment only accepts them from a trusted proxy.
    """
    if request is None:
        return None

    headers = request.headers
    single_value_headers = (
        ("cf-connecting-ip", "CF-Connecting-IP"),
        ("x-real-ip", "X-Real-IP"),
    )
    for lower_name, canonical_name in single_value_headers:
        raw = headers.get(lower_name) or headers.get(canonical_name)
        if raw:
            ip = raw.strip()
            if ip:
                return ip

    forwarded = headers.get("x-forwarded-for") or headers.get("X-Forwarded-For")
    if forwarded:
        for part in forwarded.split(","):
            ip = part.strip()
            if ip and ip.lower() != "unknown":
                return ip

    return request.client.host if request.client else None


def _status_code_from_response(response: Any) -> int:
    """Return ``response.status_code`` as an int, defaulting to 200.

    The default applies when the attribute is missing or cannot be converted.
    """
    if hasattr(response, "status_code"):
        try:
            return int(getattr(response, "status_code"))
        except (TypeError, ValueError, OverflowError):
            pass
    return 200


def audit(
    *,
    action: AuditAction,
    description: str | None = None,
    body_fields: list[str] | None = None,
    redact_fields: list[str] | None = None,
    user_kw: str = "current_user",
):
    """Decorate a handler so that every call is written to the audit log.

    Args:
        action: Audit action recorded with each entry.
        description: Optional description recorded with each entry.
        body_fields: Names of fields to copy from the handler's arguments
            into the record's ``request_body``.
        redact_fields: Subset of ``body_fields`` whose values are blanked.
        user_kw: Name of the handler argument holding the current user, used
            as a fallback when the bearer token does not identify one.

    Returns:
        A decorator producing an ``async`` wrapper. The wrapper returns the
        handler's result unchanged and re-raises any ``Exception`` the handler
        raises after recording it. Failures while writing the audit record
        are logged and never propagate.
    """

    def decorator(func):
        # Introspect once at decoration time instead of on every call.
        signature = inspect.signature(func)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            bound = signature.bind_partial(*args, **kwargs)
            bound.apply_defaults()
            request = _extract_request(bound.arguments)

            # The timer starts before user resolution, so duration_ms includes
            # the context gathering below as well as the handler itself.
            start = time.perf_counter()
            user_info = bound.arguments.get(user_kw)
            user_id, username = await _resolve_user(request, user_info)
            request_params = _build_request_params(request)
            request_body = _extract_body_fields(bound.arguments, body_fields, redact_fields)

            async def write_record(status_code: Any, success: bool, error: str | None) -> None:
                """Write one audit entry; never raises."""
                duration_ms = round((time.perf_counter() - start) * 1000, 2)
                try:
                    await AuditService.log(
                        action=action,
                        description=description,
                        user_id=user_id,
                        username=username,
                        client_ip=_get_client_ip(request),
                        method=request.method if request else "",
                        path=request.url.path if request else func.__name__,
                        status_code=status_code,
                        duration_ms=duration_ms,
                        success=success,
                        request_params=request_params,
                        request_body=request_body,
                        error=error,
                    )
                except Exception:  # noqa: BLE001 - auditing must never break the request
                    logger.warning("Audit: failed to write audit log entry", exc_info=True)

            try:
                result = func(*args, **kwargs)
                if inspect.isawaitable(result):
                    result = await result
                status_code = _status_code_from_response(result)
            except Exception as exc:  # noqa: BLE001
                await write_record(getattr(exc, "status_code", 500), False, str(exc))
                raise

            await write_record(status_code, True, None)
            return result

        return wrapper

    return decorator