初始话项目框架
This commit is contained in:
@@ -0,0 +1,200 @@
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict
|
||||
|
||||
import jwt
|
||||
import shortuuid
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.config import settings
|
||||
from app.core.context import RequestContext
|
||||
from app.core.logger import log
|
||||
|
||||
# 跳过自定义处理的路径(docs 相关)
|
||||
_SKIP_PATHS = {"/openapi.json", "/docs", "/redoc"}
|
||||
|
||||
# 白名单匹配用的正则字符
|
||||
_CLEAN_PATTERN = re.compile(r"[*/ ]")
|
||||
|
||||
|
||||
def _is_whitelisted(path: str) -> bool:
|
||||
"""判断路径是否在白名单中,逻辑与 Java 端一致"""
|
||||
cleaned_path = _CLEAN_PATTERN.sub("", path)
|
||||
for item in settings.auth_whitelist:
|
||||
cleaned_item = _CLEAN_PATTERN.sub("", item)
|
||||
if not cleaned_item:
|
||||
continue
|
||||
if cleaned_path.startswith(cleaned_item):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
"""为每个请求生成唯一的 ShortUUID 并写入响应头"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
request_id = shortuuid.ShortUUID().random(length=18)
|
||||
request.state.uuid = request_id
|
||||
RequestContext.request_id.set(request_id)
|
||||
with log.contextualize(request_id=request_id):
|
||||
response = await call_next(request)
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
return response
|
||||
|
||||
|
||||
class JwtAuthMiddleware(BaseHTTPMiddleware):
|
||||
"""从请求中解析 JWT Token,校验登录状态,存入 RequestContext"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
if _is_whitelisted(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# 优先从 Cookie 获取,不存在则从请求头获取
|
||||
token = None
|
||||
if cookie_token := request.cookies.get("Token"):
|
||||
token = cookie_token
|
||||
if not token:
|
||||
token = request.headers.get("Token")
|
||||
|
||||
if not token:
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
await self._verify_and_set_user(token)
|
||||
except Exception as e:
|
||||
log.warning(f"Token 校验失败: {e}")
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@staticmethod
|
||||
async def _verify_and_set_user(token: str) -> None:
|
||||
"""解析 JWT,校验 Redis 登录信息,续期"""
|
||||
payload = jwt.decode(token, settings.jwt_secret, algorithms=["HS256"])
|
||||
user_id = payload.get("userId")
|
||||
uu_id = payload.get("uuId")
|
||||
if not user_id or not uu_id:
|
||||
return
|
||||
|
||||
from app.core.redis import redis_client
|
||||
if redis_client is None:
|
||||
return
|
||||
|
||||
redis_key = f"login:token:{user_id}"
|
||||
raw = await redis_client.get(redis_key)
|
||||
if not raw:
|
||||
return
|
||||
|
||||
info = json.loads(raw) if isinstance(raw, str) else raw
|
||||
devices = info.get("loginDevices", [])
|
||||
|
||||
# 过滤过期设备
|
||||
now = datetime.now(timezone.utc)
|
||||
valid_devices = []
|
||||
for d in devices:
|
||||
last_login = datetime.fromisoformat(d["lastLoginTime"])
|
||||
if (now - last_login).total_seconds() < settings.token_expire_seconds:
|
||||
valid_devices.append(d)
|
||||
|
||||
# 校验当前设备
|
||||
device_map = {d["uuId"]: d for d in valid_devices}
|
||||
if uu_id not in device_map:
|
||||
return
|
||||
|
||||
# 续期
|
||||
device_map[uu_id]["lastLoginTime"] = now.isoformat()
|
||||
info["loginDevices"] = valid_devices
|
||||
await redis_client.set(redis_key, json.dumps(info), ex=settings.token_expire_seconds)
|
||||
|
||||
RequestContext.user_id.set(int(user_id))
|
||||
|
||||
|
||||
class AuthRequiredMiddleware(BaseHTTPMiddleware):
|
||||
"""鉴权拦截:非白名单路径必须有 user_id"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
if _is_whitelisted(request.url.path) or request.url.path in _SKIP_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
user_id = RequestContext.user_id.get(None)
|
||||
if user_id is None:
|
||||
from app.core.schemas.responses import StandardResponse
|
||||
content = StandardResponse.fail(msg="未经授权,请登录", code=401).model_dump_json()
|
||||
return Response(content=content, status_code=401, media_type="application/json")
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class RequestLogMiddleware(BaseHTTPMiddleware):
|
||||
"""记录请求和响应信息"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
if request.url.path in _SKIP_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
log_data: Dict[str, Any] = {
|
||||
"method": request.method,
|
||||
"url": str(request.url),
|
||||
"path": request.url.path,
|
||||
"client_ip": request.client.host if request.client else None,
|
||||
}
|
||||
|
||||
if request.path_params:
|
||||
log_data["path_params"] = dict(request.path_params)
|
||||
|
||||
if request.query_params:
|
||||
log_data["query_params"] = dict(request.query_params)
|
||||
|
||||
log.info(f"Request: {log_data}")
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
process_time = round(time.time() - start_time, 4)
|
||||
log.info(f"Response: status={response.status_code} time={process_time}s")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ResponseWrapMiddleware(BaseHTTPMiddleware):
|
||||
"""将业务路由的 JSON 响应统一包装为 StandardResponse 格式"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
if request.url.path in _SKIP_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if response.status_code != 200 or "application/json" not in content_type:
|
||||
return response
|
||||
|
||||
# 读取原始响应体
|
||||
body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body += chunk if isinstance(chunk, bytes) else chunk.encode()
|
||||
|
||||
data = json.loads(body)
|
||||
|
||||
# 已经是 StandardResponse 格式的不重复包装
|
||||
if isinstance(data, dict) and "code" in data and "msg" in data and "timestamp" in data:
|
||||
return Response(content=body, status_code=200, media_type="application/json")
|
||||
|
||||
from app.core.schemas.responses import StandardResponse
|
||||
uuid = getattr(request.state, "uuid", None)
|
||||
content = StandardResponse.success(data=data, uuid=uuid).model_dump_json()
|
||||
|
||||
return Response(content=content, status_code=200, media_type="application/json")
|
||||
|
||||
|
||||
def register_middleware(app) -> None:
|
||||
"""注册所有中间件(注意:注册顺序与执行顺序相反)"""
|
||||
# 最后注册的最先执行
|
||||
app.add_middleware(ResponseWrapMiddleware)
|
||||
app.add_middleware(RequestLogMiddleware)
|
||||
app.add_middleware(AuthRequiredMiddleware)
|
||||
app.add_middleware(JwtAuthMiddleware)
|
||||
app.add_middleware(RequestIDMiddleware)
|
||||
Reference in New Issue
Block a user