初始话项目框架
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
from .context import RequestContext
|
||||
from .exceptions import register_exception_handlers
|
||||
from .lifespan import lifespan
|
||||
from .logger import log
|
||||
from .middleware import register_middleware
|
||||
from .schemas.responses import StandardResponse
|
||||
|
||||
__all__ = [
|
||||
"RequestContext",
|
||||
"StandardResponse",
|
||||
"lifespan",
|
||||
"log",
|
||||
"register_exception_handlers",
|
||||
"register_middleware",
|
||||
]
|
||||
@@ -0,0 +1,11 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
|
||||
class RequestContext:
|
||||
"""请求上下文变量"""
|
||||
|
||||
# 当前请求 ID
|
||||
request_id: ContextVar[str] = ContextVar("request_id", default="system")
|
||||
|
||||
# 当前用户id
|
||||
user_id: ContextVar[int] = ContextVar("user_id", default=None)
|
||||
@@ -0,0 +1,59 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config import settings
|
||||
from app.core.logger import log
|
||||
|
||||
# 全局引擎和会话工厂
|
||||
engine: Optional[AsyncEngine] = None
|
||||
async_session_factory: Optional[async_sessionmaker[AsyncSession]] = None
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""ORM 声明基类"""
|
||||
pass
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""初始化数据库引擎和会话工厂"""
|
||||
global engine, async_session_factory
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.database_url,
|
||||
pool_size=settings.db_pool_size,
|
||||
max_overflow=settings.db_max_overflow,
|
||||
pool_recycle=settings.db_pool_recycle,
|
||||
echo=settings.env == "dev",
|
||||
)
|
||||
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
log.info("数据库连接池已初始化")
|
||||
|
||||
|
||||
async def close_db() -> None:
|
||||
"""关闭数据库引擎,释放连接池"""
|
||||
global engine
|
||||
if engine:
|
||||
await engine.dispose()
|
||||
log.info("数据库连接池已关闭")
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""依赖注入:提供异步数据库会话,自动 commit/rollback/close"""
|
||||
if async_session_factory is None:
|
||||
raise RuntimeError("数据库未初始化,请先调用 init_db()")
|
||||
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
@@ -0,0 +1,73 @@
|
||||
import traceback
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from app.config import settings
|
||||
from app.core.logger import log
|
||||
from app.core.schemas.responses import StandardResponse
|
||||
|
||||
# 友好的 HTTP 状态码消息映射
|
||||
_FRIENDLY_MESSAGES = {
|
||||
400: "请求参数错误",
|
||||
401: "未经授权,请登录",
|
||||
403: "权限不足,禁止访问",
|
||||
404: "请求的资源不存在",
|
||||
405: "请求方法不允许",
|
||||
422: "数据验证失败",
|
||||
429: "请求过于频繁",
|
||||
500: "服务器内部错误",
|
||||
}
|
||||
|
||||
|
||||
def _get_uuid(request: Request) -> str | None:
|
||||
return getattr(request.state, "uuid", None)
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
|
||||
uuid = _get_uuid(request)
|
||||
log.error(f"HTTPException -- uuid: {uuid} | status: {exc.status_code} | msg: {exc.detail}")
|
||||
message = _FRIENDLY_MESSAGES.get(exc.status_code, str(exc.detail))
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=StandardResponse.fail(msg=message, code=exc.status_code, uuid=uuid).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||
uuid = _get_uuid(request)
|
||||
errors = exc.errors()
|
||||
log.error(f"ValidationError -- uuid: {uuid} | errors: {errors}")
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content=StandardResponse.fail(msg="数据验证失败", code=422, data=errors, uuid=uuid).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
async def assertion_error_handler(request: Request, exc: AssertionError) -> JSONResponse:
|
||||
uuid = _get_uuid(request)
|
||||
log.error(f"AssertionError -- uuid: {uuid} | msg: {exc}\n{traceback.format_exc()}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=StandardResponse.fail(msg=f"断言错误: {exc}", uuid=uuid).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
uuid = _get_uuid(request)
|
||||
log.error(f"Unhandled Exception -- uuid: {uuid} | msg: {exc}\n{traceback.format_exc()}")
|
||||
msg = str(exc) if settings.env == "dev" else "服务器内部错误"
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=StandardResponse.fail(msg=msg, uuid=uuid).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
def register_exception_handlers(app) -> None:
|
||||
"""将异常处理器挂载到 FastAPI 应用"""
|
||||
app.add_exception_handler(StarletteHTTPException, http_exception_handler)
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
app.add_exception_handler(AssertionError, assertion_error_handler)
|
||||
app.add_exception_handler(Exception, global_exception_handler)
|
||||
@@ -0,0 +1,43 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.core.logger import log
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理:初始化和释放所有连接资源"""
|
||||
from app.core.database import init_db, close_db
|
||||
from app.core.redis import init_redis, close_redis
|
||||
from pathlib import Path
|
||||
|
||||
# 启动:初始化资源
|
||||
try:
|
||||
log.info("初始化数据库连接")
|
||||
await init_db()
|
||||
except Exception as e:
|
||||
log.error(f"数据库连接初始化失败: {e}")
|
||||
raise e
|
||||
|
||||
try:
|
||||
log.info("初始化redis连接")
|
||||
await init_redis()
|
||||
except Exception as e:
|
||||
log.warning(f"Redis 连接初始化失败: {e}")
|
||||
raise e
|
||||
log.info("所有资源已初始化,应用启动完成")
|
||||
|
||||
# 打印 banner
|
||||
banner_path = Path(__file__).parent.parent / "banner.txt"
|
||||
if banner_path.exists():
|
||||
print(banner_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
|
||||
yield
|
||||
|
||||
# 关闭:释放资源
|
||||
await close_db()
|
||||
await close_redis()
|
||||
log.info("所有连接资源已释放")
|
||||
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.config import settings
|
||||
|
||||
# 日志目录
|
||||
_root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
_log_dir = os.path.join(_root_dir, "logs")
|
||||
|
||||
# 控制台日志格式
|
||||
_CONSOLE_FORMAT = (
|
||||
"<green>{time:YYYYMMDD HH:mm:ss}</green> | "
|
||||
"{process.name}:{process.id} | "
|
||||
"{extra[request_id]} | "
|
||||
"<cyan>{module}</cyan>.<cyan>{function}</cyan>"
|
||||
":<cyan>{line}</cyan> | "
|
||||
"<level>{level}</level>: "
|
||||
"<level>{message}</level>"
|
||||
)
|
||||
|
||||
# 文件日志格式
|
||||
_FILE_FORMAT = (
|
||||
"{time:YYYYMMDD HH:mm:ss} - "
|
||||
"{process.name}:{process.id} | "
|
||||
"{extra[request_id]} | "
|
||||
"{module}.{function}:{line} - {level} - {message}"
|
||||
)
|
||||
|
||||
|
||||
def _setup_logger() -> logger.__class__:
|
||||
"""配置并返回 Loguru 日志实例"""
|
||||
logger.remove()
|
||||
logger.configure(extra={"request_id": "system"})
|
||||
|
||||
# 控制台输出
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
level=settings.logging_level.upper(),
|
||||
format=_CONSOLE_FORMAT,
|
||||
)
|
||||
|
||||
# 非开发环境写入文件
|
||||
if settings.env != "dev":
|
||||
os.makedirs(_log_dir, exist_ok=True)
|
||||
log_file_path = os.path.join(_log_dir, settings.log_file_name)
|
||||
logger.add(
|
||||
log_file_path,
|
||||
level=settings.logging_level.upper(),
|
||||
encoding="UTF-8",
|
||||
format=_FILE_FORMAT,
|
||||
rotation="10 MB",
|
||||
retention=20,
|
||||
enqueue=True,
|
||||
)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
log = _setup_logger()
|
||||
@@ -0,0 +1,200 @@
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict
|
||||
|
||||
import jwt
|
||||
import shortuuid
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.config import settings
|
||||
from app.core.context import RequestContext
|
||||
from app.core.logger import log
|
||||
|
||||
# 跳过自定义处理的路径(docs 相关)
|
||||
_SKIP_PATHS = {"/openapi.json", "/docs", "/redoc"}
|
||||
|
||||
# 白名单匹配用的正则字符
|
||||
_CLEAN_PATTERN = re.compile(r"[*/ ]")
|
||||
|
||||
|
||||
def _is_whitelisted(path: str) -> bool:
|
||||
"""判断路径是否在白名单中,逻辑与 Java 端一致"""
|
||||
cleaned_path = _CLEAN_PATTERN.sub("", path)
|
||||
for item in settings.auth_whitelist:
|
||||
cleaned_item = _CLEAN_PATTERN.sub("", item)
|
||||
if not cleaned_item:
|
||||
continue
|
||||
if cleaned_path.startswith(cleaned_item):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
"""为每个请求生成唯一的 ShortUUID 并写入响应头"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
request_id = shortuuid.ShortUUID().random(length=18)
|
||||
request.state.uuid = request_id
|
||||
RequestContext.request_id.set(request_id)
|
||||
with log.contextualize(request_id=request_id):
|
||||
response = await call_next(request)
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
return response
|
||||
|
||||
|
||||
class JwtAuthMiddleware(BaseHTTPMiddleware):
|
||||
"""从请求中解析 JWT Token,校验登录状态,存入 RequestContext"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
if _is_whitelisted(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# 优先从 Cookie 获取,不存在则从请求头获取
|
||||
token = None
|
||||
if cookie_token := request.cookies.get("Token"):
|
||||
token = cookie_token
|
||||
if not token:
|
||||
token = request.headers.get("Token")
|
||||
|
||||
if not token:
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
await self._verify_and_set_user(token)
|
||||
except Exception as e:
|
||||
log.warning(f"Token 校验失败: {e}")
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@staticmethod
|
||||
async def _verify_and_set_user(token: str) -> None:
|
||||
"""解析 JWT,校验 Redis 登录信息,续期"""
|
||||
payload = jwt.decode(token, settings.jwt_secret, algorithms=["HS256"])
|
||||
user_id = payload.get("userId")
|
||||
uu_id = payload.get("uuId")
|
||||
if not user_id or not uu_id:
|
||||
return
|
||||
|
||||
from app.core.redis import redis_client
|
||||
if redis_client is None:
|
||||
return
|
||||
|
||||
redis_key = f"login:token:{user_id}"
|
||||
raw = await redis_client.get(redis_key)
|
||||
if not raw:
|
||||
return
|
||||
|
||||
info = json.loads(raw) if isinstance(raw, str) else raw
|
||||
devices = info.get("loginDevices", [])
|
||||
|
||||
# 过滤过期设备
|
||||
now = datetime.now(timezone.utc)
|
||||
valid_devices = []
|
||||
for d in devices:
|
||||
last_login = datetime.fromisoformat(d["lastLoginTime"])
|
||||
if (now - last_login).total_seconds() < settings.token_expire_seconds:
|
||||
valid_devices.append(d)
|
||||
|
||||
# 校验当前设备
|
||||
device_map = {d["uuId"]: d for d in valid_devices}
|
||||
if uu_id not in device_map:
|
||||
return
|
||||
|
||||
# 续期
|
||||
device_map[uu_id]["lastLoginTime"] = now.isoformat()
|
||||
info["loginDevices"] = valid_devices
|
||||
await redis_client.set(redis_key, json.dumps(info), ex=settings.token_expire_seconds)
|
||||
|
||||
RequestContext.user_id.set(int(user_id))
|
||||
|
||||
|
||||
class AuthRequiredMiddleware(BaseHTTPMiddleware):
|
||||
"""鉴权拦截:非白名单路径必须有 user_id"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
if _is_whitelisted(request.url.path) or request.url.path in _SKIP_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
user_id = RequestContext.user_id.get(None)
|
||||
if user_id is None:
|
||||
from app.core.schemas.responses import StandardResponse
|
||||
content = StandardResponse.fail(msg="未经授权,请登录", code=401).model_dump_json()
|
||||
return Response(content=content, status_code=401, media_type="application/json")
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class RequestLogMiddleware(BaseHTTPMiddleware):
|
||||
"""记录请求和响应信息"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
if request.url.path in _SKIP_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
log_data: Dict[str, Any] = {
|
||||
"method": request.method,
|
||||
"url": str(request.url),
|
||||
"path": request.url.path,
|
||||
"client_ip": request.client.host if request.client else None,
|
||||
}
|
||||
|
||||
if request.path_params:
|
||||
log_data["path_params"] = dict(request.path_params)
|
||||
|
||||
if request.query_params:
|
||||
log_data["query_params"] = dict(request.query_params)
|
||||
|
||||
log.info(f"Request: {log_data}")
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
process_time = round(time.time() - start_time, 4)
|
||||
log.info(f"Response: status={response.status_code} time={process_time}s")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ResponseWrapMiddleware(BaseHTTPMiddleware):
|
||||
"""将业务路由的 JSON 响应统一包装为 StandardResponse 格式"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
if request.url.path in _SKIP_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if response.status_code != 200 or "application/json" not in content_type:
|
||||
return response
|
||||
|
||||
# 读取原始响应体
|
||||
body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body += chunk if isinstance(chunk, bytes) else chunk.encode()
|
||||
|
||||
data = json.loads(body)
|
||||
|
||||
# 已经是 StandardResponse 格式的不重复包装
|
||||
if isinstance(data, dict) and "code" in data and "msg" in data and "timestamp" in data:
|
||||
return Response(content=body, status_code=200, media_type="application/json")
|
||||
|
||||
from app.core.schemas.responses import StandardResponse
|
||||
uuid = getattr(request.state, "uuid", None)
|
||||
content = StandardResponse.success(data=data, uuid=uuid).model_dump_json()
|
||||
|
||||
return Response(content=content, status_code=200, media_type="application/json")
|
||||
|
||||
|
||||
def register_middleware(app) -> None:
|
||||
"""注册所有中间件(注意:注册顺序与执行顺序相反)"""
|
||||
# 最后注册的最先执行
|
||||
app.add_middleware(ResponseWrapMiddleware)
|
||||
app.add_middleware(RequestLogMiddleware)
|
||||
app.add_middleware(AuthRequiredMiddleware)
|
||||
app.add_middleware(JwtAuthMiddleware)
|
||||
app.add_middleware(RequestIDMiddleware)
|
||||
@@ -0,0 +1,39 @@
|
||||
from typing import Optional
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from app.config import settings
|
||||
from app.core.logger import log
|
||||
|
||||
redis_client: Optional[aioredis.Redis] = None
|
||||
|
||||
|
||||
async def init_redis() -> None:
|
||||
"""初始化 Redis 连接池"""
|
||||
global redis_client
|
||||
try:
|
||||
redis_client = aioredis.from_url(
|
||||
settings.redis_url,
|
||||
max_connections=settings.redis_pool_size,
|
||||
decode_responses=True,
|
||||
)
|
||||
await redis_client.ping()
|
||||
log.info("Redis 连接池已初始化")
|
||||
except Exception as e:
|
||||
log.error(f"Redis 连接初始化失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def close_redis() -> None:
|
||||
"""关闭 Redis 连接池"""
|
||||
global redis_client
|
||||
if redis_client:
|
||||
await redis_client.close()
|
||||
log.info("Redis 连接池已关闭")
|
||||
|
||||
|
||||
async def get_redis() -> aioredis.Redis:
|
||||
"""依赖注入:提供 Redis 客户端实例"""
|
||||
if redis_client is None:
|
||||
raise RuntimeError("Redis 未初始化,请先调用 init_redis()")
|
||||
return redis_client
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class StandardResponse(BaseModel):
|
||||
"""统一响应格式"""
|
||||
code: int = 200
|
||||
msg: str = "正常响应"
|
||||
data: Any = None
|
||||
timestamp: str = Field(default_factory=lambda: str(int(time.time())))
|
||||
uuid: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def success(
|
||||
cls,
|
||||
data: Any = None,
|
||||
msg: str = "正常响应",
|
||||
code: int = 200,
|
||||
uuid: str | None = None,
|
||||
) -> "StandardResponse":
|
||||
return cls(code=code, msg=msg, data=data, uuid=uuid)
|
||||
|
||||
@classmethod
|
||||
def fail(
|
||||
cls,
|
||||
msg: str = "操作失败",
|
||||
code: int = 500,
|
||||
data: Any = None,
|
||||
uuid: str | None = None,
|
||||
) -> "StandardResponse":
|
||||
return cls(code=code, msg=msg, data=data, uuid=uuid)
|
||||
Reference in New Issue
Block a user