import json
import re
import time
from datetime import datetime
from typing import Any, Dict

import jwt
import shortuuid
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from app.config import settings
from app.core.context import RequestContext
from app.core.logger import log

# Paths that bypass the custom auth/log/wrap handling (API docs endpoints).
_SKIP_PATHS = {"/openapi.json", "/docs", "/redoc"}

# Characters stripped from paths and whitelist entries before prefix matching.
_CLEAN_PATTERN = re.compile(r"[*/ ]")


def _is_whitelisted(path: str) -> bool:
    """Return True if *path* matches an entry of ``settings.auth_whitelist``.

    Mirrors the Java implementation: ``*``, ``/`` and spaces are stripped from
    both the request path and each whitelist entry, then a plain prefix test is
    done. Entries that are empty after stripping are ignored (an empty prefix
    would otherwise match every path).

    NOTE(review): because ``/`` is removed and matching is by prefix, an entry
    like ``/api/login`` also matches ``/apilogin`` and ``/api/login-admin``.
    Tighten only together with the Java side.
    """
    cleaned_path = _CLEAN_PATTERN.sub("", path)
    return any(
        cleaned_item and cleaned_path.startswith(cleaned_item)
        for cleaned_item in (_CLEAN_PATTERN.sub("", item) for item in settings.auth_whitelist)
    )


class RequestIDMiddleware(BaseHTTPMiddleware):
    """Assign each request an 18-char ShortUUID and echo it in ``X-Request-ID``.

    The id is stored on ``request.state.uuid`` and in ``RequestContext.request_id``,
    and is bound to the logger context while the rest of the stack runs.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        request_id = shortuuid.ShortUUID().random(length=18)
        request.state.uuid = request_id
        RequestContext.request_id.set(request_id)

        with log.contextualize(request_id=request_id):
            response = await call_next(request)
        response.headers["X-Request-ID"] = request_id
        return response


class JwtAuthMiddleware(BaseHTTPMiddleware):
    """Resolve the JWT from the request and, if valid, put the user id in RequestContext.

    This middleware never rejects a request by itself: on a missing or invalid
    token it simply leaves ``RequestContext.user_id`` unset. Rejection is done
    by ``AuthRequiredMiddleware``.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        if _is_whitelisted(request.url.path):
            return await call_next(request)

        # The "Token" cookie takes precedence; fall back to the "Token" header.
        token = request.cookies.get("Token") or request.headers.get("Token")
        if token:
            try:
                await self._verify_and_set_user(token)
            except Exception as e:
                # Best-effort: a bad token just leaves the request unauthenticated.
                log.warning(f"Token 校验失败: {e}")

        return await call_next(request)

    @staticmethod
    async def _verify_and_set_user(token: str) -> None:
        """Decode the JWT, confirm the device login in Redis, then set the user id.

        Sets ``RequestContext.user_id`` only when the token carries both
        ``userId`` and ``uuId``, the Redis login record for that user exists, and
        its ``loginDevices`` list contains an entry whose ``uuId`` equals the
        token's. Any other outcome returns silently. No TTL renewal is done here.

        Raises:
            jwt.PyJWTError: the token is malformed, badly signed, or expired.
            ValueError: ``userId`` (or the Redis payload) cannot be decoded/converted.
        """
        payload = jwt.decode(token, settings.jwt_secret, algorithms=["HS256"])
        user_id = payload.get("userId")
        uu_id = payload.get("uuId")
        if not user_id or not uu_id:
            return

        # Imported lazily, as in the original module (presumably to avoid an import cycle).
        from app.core.redis import redis_client

        if redis_client is None:
            return

        redis_key = f"client:login:token:{user_id}"
        raw = await redis_client.get(redis_key)
        if not raw:
            return

        # The Redis client returns str or bytes depending on decode_responses;
        # json.loads accepts both. Anything else is treated as already decoded.
        info = json.loads(raw) if isinstance(raw, (str, bytes, bytearray)) else raw
        if not isinstance(info, dict):
            return

        # Tolerate a null device list and malformed entries instead of failing
        # verification for every device of this user.
        devices = info.get("loginDevices") or []
        if not any(isinstance(d, dict) and d.get("uuId") == uu_id for d in devices):
            return

        RequestContext.user_id.set(int(user_id))


class AuthRequiredMiddleware(BaseHTTPMiddleware):
    """Reject non-whitelisted, non-docs requests that have no authenticated user (401)."""

    async def dispatch(self, request: Request, call_next) -> Response:
        if _is_whitelisted(request.url.path) or request.url.path in _SKIP_PATHS:
            return await call_next(request)

        user_id = RequestContext.user_id.get(None)
        if user_id is None:
            from app.core.schemas.responses import StandardResponse

            content = StandardResponse.fail(msg="未经授权,请登录", code=401).model_dump_json()
            return Response(content=content, status_code=401, media_type="application/json")

        return await call_next(request)


class RequestLogMiddleware(BaseHTTPMiddleware):
    """Log basic request information and the response status plus elapsed time."""

    async def dispatch(self, request: Request, call_next) -> Response:
        if request.url.path in _SKIP_PATHS:
            return await call_next(request)

        # perf_counter is monotonic, so durations are not skewed by clock changes.
        start_time = time.perf_counter()
        log_data: Dict[str, Any] = {
            "method": request.method,
            "url": str(request.url),
            "path": request.url.path,
            "client_ip": request.client.host if request.client else None,
        }
        # NOTE(review): app-level middleware likely runs before routing, so
        # path_params is probably always empty here — confirm before relying on it.
        if request.path_params:
            log_data["path_params"] = dict(request.path_params)
        if request.query_params:
            log_data["query_params"] = dict(request.query_params)
        log.info(f"Request: {log_data}")

        response = await call_next(request)
        process_time = round(time.perf_counter() - start_time, 4)
        log.info(f"Response: status={response.status_code} time={process_time}s")
        return response


class ResponseWrapMiddleware(BaseHTTPMiddleware):
    """Wrap successful (200) JSON responses of business routes in the StandardResponse envelope.

    Responses that are not 200, not JSON, not valid JSON, or already in the
    envelope shape (``code``/``msg``/``timestamp`` keys) keep their body as-is.
    All original headers except those describing the body are preserved.
    """

    # These describe the old body and are recomputed for the rebuilt response.
    _BODY_HEADERS = (b"content-length", b"content-type")

    async def dispatch(self, request: Request, call_next) -> Response:
        if request.url.path in _SKIP_PATHS:
            return await call_next(request)

        response = await call_next(request)
        content_type = response.headers.get("content-type", "")
        if response.status_code != 200 or "application/json" not in content_type:
            return response

        # The body iterator can only be consumed once, so every path below must
        # return a rebuilt response carrying the collected body.
        chunks = []
        async for chunk in response.body_iterator:
            chunks.append(chunk if isinstance(chunk, bytes) else chunk.encode())
        body = b"".join(chunks)

        try:
            data = json.loads(body)
        except ValueError:
            # Empty or invalid JSON despite the content type: pass it through.
            return self._rebuild(response, body)

        # Already in StandardResponse format: do not wrap twice.
        if isinstance(data, dict) and "code" in data and "msg" in data and "timestamp" in data:
            return self._rebuild(response, body)

        from app.core.schemas.responses import StandardResponse

        uuid = getattr(request.state, "uuid", None)
        content = StandardResponse.success(data=data, uuid=uuid).model_dump_json()
        return self._rebuild(response, content)

    @classmethod
    def _rebuild(cls, original: Response, content: Any) -> Response:
        """Build a JSON response with *content*, keeping *original*'s non-body headers.

        Raw headers are copied so repeated headers such as ``Set-Cookie`` survive.
        """
        rebuilt = Response(content=content, status_code=original.status_code, media_type="application/json")
        rebuilt.raw_headers.extend(
            (key, value) for key, value in original.raw_headers if key.lower() not in cls._BODY_HEADERS
        )
        return rebuilt


def register_middleware(app) -> None:
    """Register all middleware on *app*.

    Starlette runs middleware in reverse registration order: the last one added
    is the outermost. The effective order per request is therefore
    RequestID -> JwtAuth -> AuthRequired -> RequestLog -> ResponseWrap -> route.
    """
    app.add_middleware(ResponseWrapMiddleware)
    app.add_middleware(RequestLogMiddleware)
    app.add_middleware(AuthRequiredMiddleware)
    app.add_middleware(JwtAuthMiddleware)
    app.add_middleware(RequestIDMiddleware)