207 lines
7.1 KiB
Python
207 lines
7.1 KiB
Python
import json
|
||
import re
|
||
import time
|
||
from datetime import datetime, timezone
|
||
from typing import Any, Dict
|
||
|
||
import jwt
|
||
import shortuuid
|
||
from fastapi import Request, Response
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
|
||
from app.config import settings
|
||
from app.core.context import RequestContext
|
||
from app.core.logger import log
|
||
|
||
# Paths that bypass the custom handling (API docs endpoints)
|
||
_SKIP_PATHS = {"/openapi.json", "/docs", "/redoc"}
|
||
|
||
# Characters stripped from paths and whitelist entries before whitelist matching
|
||
_CLEAN_PATTERN = re.compile(r"[*/ ]")
|
||
|
||
|
||
def _is_whitelisted(path: str) -> bool:
    """Return True when *path* matches an entry of ``settings.auth_whitelist``.

    Both the path and each whitelist entry are normalised by removing every
    ``*``, ``/`` and space character; the path matches when its normalised
    form starts with a non-empty normalised entry. The original comment
    states this mirrors the logic used on the Java side.
    """
    normalized_path = _CLEAN_PATTERN.sub("", path)
    normalized_entries = (
        _CLEAN_PATTERN.sub("", entry) for entry in settings.auth_whitelist
    )
    # Entries that normalise to "" are ignored; otherwise they would match
    # every path.
    return any(
        bool(entry) and normalized_path.startswith(entry)
        for entry in normalized_entries
    )
|
||
|
||
|
||
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Tag every request with a fresh ShortUUID and echo it in the response headers."""

    async def dispatch(self, request: Request, call_next) -> Response:
        """Generate an 18-character request id and propagate it.

        The id is stored in ``RequestContext.request_id`` and on
        ``request.state.uuid``, added to the log context while the rest of
        the stack handles the request, and returned to the client in the
        ``X-Request-ID`` response header.
        """
        rid = shortuuid.ShortUUID().random(length=18)
        RequestContext.request_id.set(rid)
        request.state.uuid = rid

        with log.contextualize(request_id=rid):
            response = await call_next(request)

        response.headers["X-Request-ID"] = rid
        return response
|
||
|
||
|
||
class JwtAuthMiddleware(BaseHTTPMiddleware):
    """Resolve the caller from a JWT and publish the user id in ``RequestContext``.

    This middleware never rejects a request itself: a missing or invalid
    token simply leaves ``RequestContext.user_id`` unset and the request
    continues down the stack.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Try to authenticate the request, then always call the next handler."""
        if _is_whitelisted(request.url.path):
            return await call_next(request)

        # The "Token" cookie takes precedence; the "Token" header is the fallback.
        token = request.cookies.get("Token") or request.headers.get("Token")
        if not token:
            return await call_next(request)

        try:
            await self._verify_and_set_user(token)
        except Exception as e:
            # Deliberate best effort: any failure (bad signature, malformed
            # session data, Redis error) is logged and the request goes on
            # without a user id.
            log.warning(f"Token 校验失败: {e}")

        return await call_next(request)

    @staticmethod
    async def _verify_and_set_user(token: str) -> None:
        """Decode the JWT, validate the login session in Redis and renew it.

        On success the device's ``lastLoginTime`` is refreshed, expired
        devices are dropped from the stored session, the Redis key TTL is
        reset to ``settings.token_expire_seconds`` and
        ``RequestContext.user_id`` is set. The method returns silently when
        the token lacks ``userId``/``uuId``, Redis is unavailable, no session
        is stored, or the device is unknown or expired.

        Raises:
            Exception: JWT decoding errors, JSON/date parsing errors and
                missing device fields propagate to the caller.
        """
        payload = jwt.decode(token, settings.jwt_secret, algorithms=["HS256"])
        user_id = payload.get("userId")
        uu_id = payload.get("uuId")
        if not user_id or not uu_id:
            return

        # NOTE(review): imported at call time as in the original code,
        # presumably so the client is read after startup initialisation - confirm.
        from app.core.redis import redis_client

        if redis_client is None:
            return

        redis_key = f"login:token:{user_id}"
        raw = await redis_client.get(redis_key)
        if not raw:
            return

        # json.loads accepts str, bytes and bytearray. A Redis client that
        # does not decode responses returns bytes, which must be parsed too;
        # any other type is assumed to be an already-decoded mapping.
        if isinstance(raw, (str, bytes, bytearray)):
            info = json.loads(raw)
        else:
            info = raw
        devices = info.get("loginDevices", [])

        # Keep only devices whose last login is within the expiry window.
        now = datetime.now(timezone.utc)
        valid_devices = []
        for d in devices:
            last_login_str = d["lastLoginTime"]
            # Java Instant strings end with "Z"; rewrite it as an explicit
            # offset before handing the string to datetime.fromisoformat.
            if last_login_str.endswith("Z"):
                last_login_str = last_login_str[:-1] + "+00:00"
            last_login = datetime.fromisoformat(last_login_str)
            if last_login.tzinfo is None:
                # Naive timestamps are treated as UTC.
                last_login = last_login.replace(tzinfo=timezone.utc)
            if (now - last_login).total_seconds() < settings.token_expire_seconds:
                valid_devices.append(d)

        # The device named in the token must still be among the valid ones.
        device_map = {d["uuId"]: d for d in valid_devices}
        if uu_id not in device_map:
            return

        # Renew: refresh this device's timestamp and the key's TTL.
        device_map[uu_id]["lastLoginTime"] = now.isoformat()
        info["loginDevices"] = valid_devices
        await redis_client.set(redis_key, json.dumps(info), ex=settings.token_expire_seconds)

        RequestContext.user_id.set(int(user_id))
|
||
|
||
|
||
class AuthRequiredMiddleware(BaseHTTPMiddleware):
    """Authorization gate: every non-public path requires a resolved user id."""

    async def dispatch(self, request: Request, call_next) -> Response:
        """Let public or authenticated requests through; answer 401 otherwise.

        A path is public when it is whitelisted or listed in ``_SKIP_PATHS``.
        A request is authenticated when ``RequestContext.user_id`` holds a
        value.
        """
        path = request.url.path
        is_public = _is_whitelisted(path) or path in _SKIP_PATHS
        if is_public or RequestContext.user_id.get(None) is not None:
            return await call_next(request)

        from app.core.schemas.responses import StandardResponse

        payload = StandardResponse.fail(msg="未经授权,请登录", code=401).model_dump_json()
        return Response(content=payload, status_code=401, media_type="application/json")
|
||
|
||
|
||
class RequestLogMiddleware(BaseHTTPMiddleware):
    """Log one entry for each incoming request and one for its response."""

    async def dispatch(self, request: Request, call_next) -> Response:
        """Log method, URL, path and client IP, then the status and elapsed time.

        Paths in ``_SKIP_PATHS`` are passed through without logging. Path and
        query parameters are included in the request entry only when present.
        """
        if request.url.path in _SKIP_PATHS:
            return await call_next(request)

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments the way time.time() can.
        started = time.perf_counter()

        log_data: Dict[str, Any] = {
            "method": request.method,
            "url": str(request.url),
            "path": request.url.path,
            "client_ip": request.client.host if request.client else None,
        }

        if request.path_params:
            log_data["path_params"] = dict(request.path_params)

        if request.query_params:
            log_data["query_params"] = dict(request.query_params)

        log.info(f"Request: {log_data}")

        response = await call_next(request)

        process_time = round(time.perf_counter() - started, 4)
        log.info(f"Response: status={response.status_code} time={process_time}s")

        return response
|
||
|
||
|
||
class ResponseWrapMiddleware(BaseHTTPMiddleware):
    """Wrap successful JSON responses of business routes in the StandardResponse envelope."""

    # Headers that are recomputed for the rebuilt body; all others are
    # carried over from the inner response.
    _RECOMPUTED_HEADERS = frozenset({"content-length", "content-type"})

    async def dispatch(self, request: Request, call_next) -> Response:
        """Return the inner response, wrapped when it is a 200 JSON payload.

        Responses are returned untouched when the path is in ``_SKIP_PATHS``,
        the status is not 200, or the content type does not contain
        ``application/json``. A body that already has the ``code``, ``msg``
        and ``timestamp`` keys, or that is not valid JSON, is forwarded
        without wrapping. Headers set by the inner application are preserved
        on every rebuilt response.
        """
        response = await call_next(request)
        if request.url.path in _SKIP_PATHS:
            return response

        content_type = response.headers.get("content-type", "")
        if response.status_code != 200 or "application/json" not in content_type:
            return response

        # Drain the streamed body; joining once avoids quadratic bytes +=.
        chunks = []
        async for chunk in response.body_iterator:
            chunks.append(chunk if isinstance(chunk, bytes) else chunk.encode())
        body = b"".join(chunks)

        try:
            data = json.loads(body)
        except ValueError:
            # Empty or malformed body despite the JSON content type: forward
            # it as-is instead of turning a successful response into a 500.
            return self._rebuild(response, body)

        # Already in StandardResponse shape: do not wrap a second time.
        if isinstance(data, dict) and "code" in data and "msg" in data and "timestamp" in data:
            return self._rebuild(response, body)

        from app.core.schemas.responses import StandardResponse

        uuid = getattr(request.state, "uuid", None)
        content = StandardResponse.success(data=data, uuid=uuid).model_dump_json()

        return self._rebuild(response, content)

    @classmethod
    def _rebuild(cls, source: Response, content) -> Response:
        """Build a 200 JSON response with *content*, keeping *source*'s other headers."""
        rebuilt = Response(content=content, status_code=200, media_type="application/json")
        for key, value in source.headers.items():
            if key.lower() not in cls._RECOMPUTED_HEADERS:
                # append() keeps repeated headers such as multiple Set-Cookie.
                rebuilt.headers.append(key, value)
        return rebuilt
|
||
|
||
|
||
def register_middleware(app) -> None:
    """Register all middlewares on *app*.

    Registration order is the reverse of execution order: the middleware
    added last runs first. The tuple below is therefore listed from the
    last-executed to the first-executed middleware.
    """
    registration_order = (
        ResponseWrapMiddleware,
        RequestLogMiddleware,
        AuthRequiredMiddleware,
        JwtAuthMiddleware,
        RequestIDMiddleware,
    )
    for middleware_cls in registration_order:
        app.add_middleware(middleware_cls)
|