188 lines
6.3 KiB
Python
188 lines
6.3 KiB
Python
import json
|
|
import re
|
|
import time
|
|
from datetime import datetime
|
|
from typing import Any, Dict
|
|
|
|
import jwt
|
|
import shortuuid
|
|
from fastapi import Request, Response
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from app.config import settings
|
|
from app.core.context import RequestContext
|
|
from app.core.logger import log
|
|
|
|
# Documentation routes that several middlewares exempt from their custom
# handling. Membership is an exact-path match.
# NOTE(review): sub-paths (e.g. a /docs/oauth2-redirect route, if enabled)
# are not covered by this set.
_SKIP_PATHS = {"/docs", "/redoc", "/openapi.json"}

# Characters ("*", "/" and space) stripped from both the request path and
# each whitelist entry before _is_whitelisted compares them.
_CLEAN_PATTERN = re.compile(r"[*/ ]")
|
|
|
|
|
|
def _is_whitelisted(path: str) -> bool:
    """Return True when *path* falls under a configured auth-whitelist entry.

    Kept consistent with the Java implementation:
    - Every ``*``, ``/`` and space is stripped from the request path and from
      each ``settings.auth_whitelist`` entry.
    - The stripped path is then prefix-tested (``startswith``) against each
      stripped entry.
    - Entries that become empty after stripping are skipped; an empty prefix
      would otherwise match every path.

    NOTE(review): because ``/`` is removed before a plain prefix test, path
    segment boundaries are not respected. For example, the entry ``/api/user``
    also matches ``/api/users`` and ``/apiuser``. Left as-is for parity with
    the Java side.
    """
    normalized_path = _CLEAN_PATTERN.sub("", path)
    prefixes = (_CLEAN_PATTERN.sub("", entry) for entry in settings.auth_whitelist)
    return any(prefix and normalized_path.startswith(prefix) for prefix in prefixes)
|
|
|
|
|
|
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Tag every request with a random 18-character ShortUUID.

    The generated ID is:
    - stored on ``request.state.uuid`` and in ``RequestContext.request_id``;
    - bound as ``request_id`` in the logger context while downstream
      middleware and handlers run;
    - returned to the client in the ``X-Request-ID`` response header.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        rid = shortuuid.ShortUUID().random(length=18)
        request.state.uuid = rid
        RequestContext.request_id.set(rid)

        with log.contextualize(request_id=rid):
            downstream_response = await call_next(request)

        downstream_response.headers["X-Request-ID"] = rid
        return downstream_response
|
|
|
|
|
|
class JwtAuthMiddleware(BaseHTTPMiddleware):
    """Resolve the caller's user ID from a JWT and store it in ``RequestContext``.

    This middleware never rejects a request. It only sets
    ``RequestContext.user_id`` when the token verifies and the device it was
    issued for is still listed as logged in. Rejecting unauthenticated
    requests is left to ``AuthRequiredMiddleware``.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        if _is_whitelisted(request.url.path):
            return await call_next(request)

        token = self._extract_token(request)
        if token:
            try:
                await self._verify_and_set_user(token)
            except Exception as e:
                # Best-effort by design: a bad/expired token or a Redis/parse
                # failure leaves the request anonymous instead of failing it here.
                log.warning(f"Token 校验失败: {e}")

        return await call_next(request)

    @staticmethod
    def _extract_token(request: Request):
        """Return the ``Token`` cookie, or the ``Token`` header if the cookie is absent/empty.

        Returns None (or an empty string) when neither carries a value.
        """
        return request.cookies.get("Token") or request.headers.get("Token")

    @staticmethod
    async def _verify_and_set_user(token: str) -> None:
        """Verify *token* against the Redis login record and set ``RequestContext.user_id``.

        Conditions for setting the user ID:
        - The token must decode as an HS256 JWT signed with
          ``settings.jwt_secret``.
        - Its payload must carry truthy ``userId`` and ``uuId`` claims.
        - The Redis key ``client:login:token:{userId}`` must hold a JSON object
          whose ``loginDevices`` list has an entry with the same ``uuId``.

        Only then is ``RequestContext.user_id`` set to ``int(userId)``. If any
        piece is missing (claims, Redis client, key, matching device), the
        method returns without setting a user.

        NOTE(review): an earlier docstring mentioned renewing the login session
        ("续期"), but no TTL refresh is performed here.

        Raises:
            Errors from ``jwt.decode`` (invalid signature, malformed or expired
            token), ``ValueError`` for a non-JSON Redis value or a non-numeric
            ``userId``, and any error raised by the Redis client. The caller
            logs and swallows them.
        """
        payload = jwt.decode(token, settings.jwt_secret, algorithms=["HS256"])
        user_id = payload.get("userId")
        uu_id = payload.get("uuId")
        if not user_id or not uu_id:
            return

        # Imported at call time so the module-level client's *current* value is
        # read (presumably assigned during application startup).
        from app.core.redis import redis_client

        if redis_client is None:
            return

        raw = await redis_client.get(f"client:login:token:{user_id}")
        if not raw:
            return

        # A Redis client without response decoding returns bytes. json.loads
        # accepts str, bytes and bytearray alike; any other type is assumed to
        # be already decoded.
        info = json.loads(raw) if isinstance(raw, (str, bytes, bytearray)) else raw

        # `or []` also covers an explicit null "loginDevices". Malformed entries
        # (non-dicts, missing "uuId") are skipped instead of aborting the check.
        devices = info.get("loginDevices") or []
        if not any(isinstance(d, dict) and d.get("uuId") == uu_id for d in devices):
            return

        RequestContext.user_id.set(int(user_id))
|
|
|
|
|
|
class AuthRequiredMiddleware(BaseHTTPMiddleware):
    """Block requests to protected paths when no user has been authenticated.

    A path is public if ``_is_whitelisted`` accepts it or it is one of the
    documentation routes in ``_SKIP_PATHS``. For every other path,
    ``RequestContext.user_id`` must already be set (``JwtAuthMiddleware`` does
    that). Otherwise a 401 JSON response built from ``StandardResponse.fail``
    is returned and the request is not forwarded.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        path = request.url.path
        is_public = _is_whitelisted(path) or path in _SKIP_PATHS

        if not is_public and RequestContext.user_id.get(None) is None:
            # Imported lazily, presumably to avoid an import cycle at module load.
            from app.core.schemas.responses import StandardResponse

            body = StandardResponse.fail(msg="未经授权,请登录", code=401).model_dump_json()
            return Response(content=body, status_code=401, media_type="application/json")

        return await call_next(request)
|
|
|
|
|
|
class RequestLogMiddleware(BaseHTTPMiddleware):
    """Log one line per incoming request and one per completed response.

    Paths in ``_SKIP_PATHS`` are passed through without logging.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        if request.url.path in _SKIP_PATHS:
            return await call_next(request)

        # perf_counter is monotonic, unlike time.time(), so wall-clock
        # adjustments cannot produce negative or inflated durations.
        start = time.perf_counter()

        log_data: Dict[str, Any] = {
            "method": request.method,
            "url": str(request.url),
            "path": request.url.path,
            "client_ip": request.client.host if request.client else None,
        }
        if request.query_params:
            log_data["query_params"] = dict(request.query_params)

        log.info(f"Request: {log_data}")

        response = await call_next(request)

        process_time = round(time.perf_counter() - start, 4)
        summary = f"Response: status={response.status_code} time={process_time}s"
        # Path params are filled in by the router, which runs after this
        # middleware. Before call_next they are always empty, so they are
        # reported here instead of in the request line.
        if request.path_params:
            summary += f" path_params={dict(request.path_params)}"
        log.info(summary)

        return response
|
|
|
|
|
|
class ResponseWrapMiddleware(BaseHTTPMiddleware):
    """Wrap successful JSON responses of business routes in ``StandardResponse``.

    Which responses are affected:
    - Only ``200`` responses whose ``content-type`` contains
      ``application/json`` are rewritten.
    - ``_SKIP_PATHS`` and all other responses pass through unchanged.
    - Bodies that already look like a ``StandardResponse`` (a dict with
      ``code``, ``msg`` and ``timestamp`` keys) are forwarded without
      re-wrapping.
    - Bodies that are not valid JSON are forwarded untouched.

    Headers of the downstream response (e.g. ``set-cookie``) are preserved.
    """

    # Raw header names recomputed by the rebuilt Response; all others are copied.
    _RECOMPUTED_HEADERS = (b"content-length", b"content-type")

    async def dispatch(self, request: Request, call_next) -> Response:
        if request.url.path in _SKIP_PATHS:
            return await call_next(request)

        response = await call_next(request)

        content_type = response.headers.get("content-type", "")
        if response.status_code != 200 or "application/json" not in content_type:
            return response

        # The body iterator can only be consumed once, so from here on a new
        # Response must always be returned.
        chunks = [
            chunk if isinstance(chunk, bytes) else chunk.encode()
            async for chunk in response.body_iterator
        ]
        body = b"".join(chunks)

        try:
            data = json.loads(body)
        except ValueError:
            # Empty/malformed JSON (or non-UTF-8 bytes): forward the original
            # body instead of turning the handler's 200 into a 500.
            log.warning(f"Skip wrapping non-JSON body for {request.url.path}")
            return self._rebuild(body, response)

        # Already in StandardResponse shape: do not wrap twice.
        if isinstance(data, dict) and "code" in data and "msg" in data and "timestamp" in data:
            return self._rebuild(body, response)

        from app.core.schemas.responses import StandardResponse

        request_id = getattr(request.state, "uuid", None)
        content = StandardResponse.success(data=data, uuid=request_id).model_dump_json()
        return self._rebuild(content, response)

    @classmethod
    def _rebuild(cls, content, original: Response) -> Response:
        """Return a 200 ``application/json`` Response carrying *content*.

        All headers of *original* are carried over, except ``content-length``
        and ``content-type``, which are recomputed for the new body.
        """
        rebuilt = Response(content=content, status_code=200, media_type="application/json")
        rebuilt.raw_headers.extend(
            (name, value)
            for name, value in original.raw_headers
            if name.lower() not in cls._RECOMPUTED_HEADERS
        )
        return rebuilt
|
|
|
|
|
|
def register_middleware(app) -> None:
    """Install this module's HTTP middleware stack on *app*.

    Registration order is the reverse of execution order: the class added
    last runs first on an incoming request. The resulting request flow,
    outermost first, is:

        RequestID -> JwtAuth -> AuthRequired -> RequestLog -> ResponseWrap -> routes

    Note: RequestLogMiddleware sits inside AuthRequiredMiddleware, so requests
    rejected there with 401 are not logged by RequestLogMiddleware.
    """
    innermost_first = (
        ResponseWrapMiddleware,
        RequestLogMiddleware,
        AuthRequiredMiddleware,
        JwtAuthMiddleware,
        RequestIDMiddleware,
    )
    for middleware_cls in innermost_first:
        app.add_middleware(middleware_cls)
|