初始话项目框架

This commit is contained in:
zk
2026-03-13 13:51:51 +08:00
commit f26585a130
25 changed files with 845 additions and 0 deletions
+2
View File
@@ -0,0 +1,2 @@
ENV=pro
+1
View File
@@ -0,0 +1 @@
ENV=test
+57
View File
@@ -0,0 +1,57 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
.env
venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# Project specific
app/token_cache.json
test/recordings/
test/resume/
*.log
# System files
.DS_Store
Thumbs.db
.directory
# Jupyter Notebook
.ipynb_checkpoints
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
View File
+1
View File
@@ -0,0 +1 @@
+48
View File
@@ -0,0 +1,48 @@
"""LLM 模型枚举与实例获取
Usage:
from app.ai.models import LLM
llm = LLM.DOUBAO_PRO_256K.create(temperature=0)
llm = LLM.DEEPSEEK_V3.create()
"""
from enum import Enum
from langchain_openai import ChatOpenAI
from app.config import settings
# 供应商连接配置
_VOLCENGINE = (lambda: settings.volcengine_api_key, lambda: settings.volcengine_base_url)
_CARDIAC = (lambda: settings.cardiacBrder_api_key, lambda: settings.cardiacBrder_base_url)
class LLM(Enum):
"""所有可用模型,每个枚举值 = (模型名, api_key函数, base_url函数)"""
# 火山引擎
DOUBAO_PRO_256K = ("doubao-pro-256k", *_VOLCENGINE)
DOUBAO_PRO_32K = ("doubao-pro-32k", *_VOLCENGINE)
DOUBAO_LITE_128K = ("doubao-lite-128k", *_VOLCENGINE)
DEEPSEEK_V3 = ("deepseek-v3-250324", *_VOLCENGINE)
DEEPSEEK_R1 = ("deepseek-r1-250528", *_VOLCENGINE)
# 心缘
GPT_4O = ("gpt-4o", *_CARDIAC)
GPT_4O_MINI = ("gpt-4o-mini", *_CARDIAC)
CLAUDE_SONNET_4 = ("claude-sonnet-4-20250514", *_CARDIAC)
def __init__(self, model_name: str, api_key_fn, base_url_fn):
self.model_name = model_name
self._api_key_fn = api_key_fn
self._base_url_fn = base_url_fn
def create(self, **kwargs) -> ChatOpenAI:
"""创建 LLM 实例,kwargs 透传给 ChatOpenAItemperature, max_tokens 等)"""
return ChatOpenAI(
model=self.model_name,
api_key=self._api_key_fn(),
base_url=self._base_url_fn(),
**kwargs,
)
+1
View File
@@ -0,0 +1 @@
+9
View File
@@ -0,0 +1,9 @@
"""健康检查路由"""
from fastapi import APIRouter
router = APIRouter(prefix="/health", tags=["健康检查"])
@router.get("/", summary="健康检查")
async def health_check():
return {"status": "ok"}
+10
View File
@@ -0,0 +1,10 @@
____ __ __ ____ _
/ __ \ / _|/ _| | _ \ (_)
| | | | |_| |_ ___ _ __ ___| |_) | _ ___
| | | | _| _/ _ \ '__/ __| __/ | |/ _ \
| |__| | | | || __/ | | |__| | | | __/
\____/|_| |_| \___|_| \___|_| |_|\___|
OfferPie AI
FastAPI Starting...
+10
View File
@@ -0,0 +1,10 @@
import os
from .settings import Settings
_env = os.getenv("ENV", "dev")
_env_files = {"dev": ".env", "test": ".env.test", "pro": ".env.prod"}
settings = Settings(_env_file=_env_files.get(_env, ".env"))
__all__ = ["settings"]
+77
View File
@@ -0,0 +1,77 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
# 环境
env: str = "dev"
# 项目基础信息
project_name: str = "FastAPI App"
version: str = "0.1.0"
server_port: int = 8000
# CORS
cors_origins: list[str] = ["*"]
cors_allow_methods: list[str] = ["*"]
cors_allow_headers: list[str] = ["*"]
# 数据库 (MySQL)
db_host: str = "192.168.31.105"
db_port: int = 3306
db_user: str = "root"
db_password: str = "123456"
db_name: str = "offerpie"
db_pool_size: int = 10
db_max_overflow: int = 20
db_pool_recycle: int = 3600
# Redis
redis_host: str = "192.168.31.105"
redis_port: int = 6379
redis_password: str = "123456"
redis_db: int = 0
redis_pool_size: int = 10
# LLM 供应商连接配置
# 火山引擎
volcengine_api_key: str = "fd065993-bee2-4f31-8bf2-56d5d3012c02"
volcengine_base_url: str = "https://ark.cn-beijing.volces.com/api/v3"
# 心缘
cardiacBrder_api_key: str = "sk-8NxoLe7ZTJveGSmtPENBm4NwN9ai4YLGw8y6fqueZrPTo4Uu"
cardiacBrder_base_url: str = "https://api-i.xykjy.com/v1"
# JWT
jwt_secret: str = "Aa123123"
token_expire_seconds: int = 5184000
# 鉴权白名单路径
auth_whitelist: list[str] = ["/health/**", "/docs/**", "/redoc/**", "/openapi.json"]
# 日志
logging_level: str = "DEBUG"
log_file_name: str = "app.log"
@property
def database_url(self) -> str:
return (
f"mysql+asyncmy://{self.db_user}:{self.db_password}"
f"@{self.db_host}:{self.db_port}/{self.db_name}"
)
@property
def redis_url(self) -> str:
if self.redis_password:
return (
f"redis://:{self.redis_password}"
f"@{self.redis_host}:{self.redis_port}/{self.redis_db}"
)
return f"redis://{self.redis_host}:{self.redis_port}/{self.redis_db}"
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore",
)
+15
View File
@@ -0,0 +1,15 @@
from .context import RequestContext
from .exceptions import register_exception_handlers
from .lifespan import lifespan
from .logger import log
from .middleware import register_middleware
from .schemas.responses import StandardResponse
__all__ = [
"RequestContext",
"StandardResponse",
"lifespan",
"log",
"register_exception_handlers",
"register_middleware",
]
+11
View File
@@ -0,0 +1,11 @@
from contextvars import ContextVar
class RequestContext:
"""请求上下文变量"""
# 当前请求 ID
request_id: ContextVar[str] = ContextVar("request_id", default="system")
# 当前用户id
user_id: ContextVar[int] = ContextVar("user_id", default=None)
+59
View File
@@ -0,0 +1,59 @@
from collections.abc import AsyncGenerator
from typing import Optional
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase
from app.config import settings
from app.core.logger import log
# 全局引擎和会话工厂
engine: Optional[AsyncEngine] = None
async_session_factory: Optional[async_sessionmaker[AsyncSession]] = None
class Base(DeclarativeBase):
"""ORM 声明基类"""
pass
async def init_db() -> None:
"""初始化数据库引擎和会话工厂"""
global engine, async_session_factory
engine = create_async_engine(
settings.database_url,
pool_size=settings.db_pool_size,
max_overflow=settings.db_max_overflow,
pool_recycle=settings.db_pool_recycle,
echo=settings.env == "dev",
)
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
log.info("数据库连接池已初始化")
async def close_db() -> None:
"""关闭数据库引擎,释放连接池"""
global engine
if engine:
await engine.dispose()
log.info("数据库连接池已关闭")
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""依赖注入:提供异步数据库会话,自动 commit/rollback/close"""
if async_session_factory is None:
raise RuntimeError("数据库未初始化,请先调用 init_db()")
async with async_session_factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
+73
View File
@@ -0,0 +1,73 @@
import traceback
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from app.config import settings
from app.core.logger import log
from app.core.schemas.responses import StandardResponse
# 友好的 HTTP 状态码消息映射
_FRIENDLY_MESSAGES = {
400: "请求参数错误",
401: "未经授权,请登录",
403: "权限不足,禁止访问",
404: "请求的资源不存在",
405: "请求方法不允许",
422: "数据验证失败",
429: "请求过于频繁",
500: "服务器内部错误",
}
def _get_uuid(request: Request) -> str | None:
return getattr(request.state, "uuid", None)
async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
uuid = _get_uuid(request)
log.error(f"HTTPException -- uuid: {uuid} | status: {exc.status_code} | msg: {exc.detail}")
message = _FRIENDLY_MESSAGES.get(exc.status_code, str(exc.detail))
return JSONResponse(
status_code=exc.status_code,
content=StandardResponse.fail(msg=message, code=exc.status_code, uuid=uuid).model_dump(),
)
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
uuid = _get_uuid(request)
errors = exc.errors()
log.error(f"ValidationError -- uuid: {uuid} | errors: {errors}")
return JSONResponse(
status_code=422,
content=StandardResponse.fail(msg="数据验证失败", code=422, data=errors, uuid=uuid).model_dump(),
)
async def assertion_error_handler(request: Request, exc: AssertionError) -> JSONResponse:
uuid = _get_uuid(request)
log.error(f"AssertionError -- uuid: {uuid} | msg: {exc}\n{traceback.format_exc()}")
return JSONResponse(
status_code=500,
content=StandardResponse.fail(msg=f"断言错误: {exc}", uuid=uuid).model_dump(),
)
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
uuid = _get_uuid(request)
log.error(f"Unhandled Exception -- uuid: {uuid} | msg: {exc}\n{traceback.format_exc()}")
msg = str(exc) if settings.env == "dev" else "服务器内部错误"
return JSONResponse(
status_code=500,
content=StandardResponse.fail(msg=msg, uuid=uuid).model_dump(),
)
def register_exception_handlers(app) -> None:
"""将异常处理器挂载到 FastAPI 应用"""
app.add_exception_handler(StarletteHTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(AssertionError, assertion_error_handler)
app.add_exception_handler(Exception, global_exception_handler)
+43
View File
@@ -0,0 +1,43 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.core.logger import log
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理:初始化和释放所有连接资源"""
from app.core.database import init_db, close_db
from app.core.redis import init_redis, close_redis
from pathlib import Path
# 启动:初始化资源
try:
log.info("初始化数据库连接")
await init_db()
except Exception as e:
log.error(f"数据库连接初始化失败: {e}")
raise e
try:
log.info("初始化redis连接")
await init_redis()
except Exception as e:
log.warning(f"Redis 连接初始化失败: {e}")
raise e
log.info("所有资源已初始化,应用启动完成")
# 打印 banner
banner_path = Path(__file__).parent.parent / "banner.txt"
if banner_path.exists():
print(banner_path.read_text(encoding="utf-8"))
yield
# 关闭:释放资源
await close_db()
await close_redis()
log.info("所有连接资源已释放")
+61
View File
@@ -0,0 +1,61 @@
import os
import sys
from loguru import logger
from app.config import settings
# 日志目录
_root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_log_dir = os.path.join(_root_dir, "logs")
# 控制台日志格式
_CONSOLE_FORMAT = (
"<green>{time:YYYYMMDD HH:mm:ss}</green> | "
"{process.name}:{process.id} | "
"{extra[request_id]} | "
"<cyan>{module}</cyan>.<cyan>{function}</cyan>"
":<cyan>{line}</cyan> | "
"<level>{level}</level>: "
"<level>{message}</level>"
)
# 文件日志格式
_FILE_FORMAT = (
"{time:YYYYMMDD HH:mm:ss} - "
"{process.name}:{process.id} | "
"{extra[request_id]} | "
"{module}.{function}:{line} - {level} - {message}"
)
def _setup_logger() -> logger.__class__:
"""配置并返回 Loguru 日志实例"""
logger.remove()
logger.configure(extra={"request_id": "system"})
# 控制台输出
logger.add(
sys.stdout,
level=settings.logging_level.upper(),
format=_CONSOLE_FORMAT,
)
# 非开发环境写入文件
if settings.env != "dev":
os.makedirs(_log_dir, exist_ok=True)
log_file_path = os.path.join(_log_dir, settings.log_file_name)
logger.add(
log_file_path,
level=settings.logging_level.upper(),
encoding="UTF-8",
format=_FILE_FORMAT,
rotation="10 MB",
retention=20,
enqueue=True,
)
return logger
log = _setup_logger()
+200
View File
@@ -0,0 +1,200 @@
import json
import re
import time
from datetime import datetime, timezone
from typing import Any, Dict
import jwt
import shortuuid
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.config import settings
from app.core.context import RequestContext
from app.core.logger import log
# 跳过自定义处理的路径(docs 相关)
_SKIP_PATHS = {"/openapi.json", "/docs", "/redoc"}
# 白名单匹配用的正则字符
_CLEAN_PATTERN = re.compile(r"[*/ ]")
def _is_whitelisted(path: str) -> bool:
"""判断路径是否在白名单中,逻辑与 Java 端一致"""
cleaned_path = _CLEAN_PATTERN.sub("", path)
for item in settings.auth_whitelist:
cleaned_item = _CLEAN_PATTERN.sub("", item)
if not cleaned_item:
continue
if cleaned_path.startswith(cleaned_item):
return True
return False
class RequestIDMiddleware(BaseHTTPMiddleware):
"""为每个请求生成唯一的 ShortUUID 并写入响应头"""
async def dispatch(self, request: Request, call_next) -> Response:
request_id = shortuuid.ShortUUID().random(length=18)
request.state.uuid = request_id
RequestContext.request_id.set(request_id)
with log.contextualize(request_id=request_id):
response = await call_next(request)
response.headers["X-Request-ID"] = request_id
return response
class JwtAuthMiddleware(BaseHTTPMiddleware):
"""从请求中解析 JWT Token,校验登录状态,存入 RequestContext"""
async def dispatch(self, request: Request, call_next) -> Response:
if _is_whitelisted(request.url.path):
return await call_next(request)
# 优先从 Cookie 获取,不存在则从请求头获取
token = None
if cookie_token := request.cookies.get("Token"):
token = cookie_token
if not token:
token = request.headers.get("Token")
if not token:
return await call_next(request)
try:
await self._verify_and_set_user(token)
except Exception as e:
log.warning(f"Token 校验失败: {e}")
return await call_next(request)
@staticmethod
async def _verify_and_set_user(token: str) -> None:
"""解析 JWT,校验 Redis 登录信息,续期"""
payload = jwt.decode(token, settings.jwt_secret, algorithms=["HS256"])
user_id = payload.get("userId")
uu_id = payload.get("uuId")
if not user_id or not uu_id:
return
from app.core.redis import redis_client
if redis_client is None:
return
redis_key = f"login:token:{user_id}"
raw = await redis_client.get(redis_key)
if not raw:
return
info = json.loads(raw) if isinstance(raw, str) else raw
devices = info.get("loginDevices", [])
# 过滤过期设备
now = datetime.now(timezone.utc)
valid_devices = []
for d in devices:
last_login = datetime.fromisoformat(d["lastLoginTime"])
if (now - last_login).total_seconds() < settings.token_expire_seconds:
valid_devices.append(d)
# 校验当前设备
device_map = {d["uuId"]: d for d in valid_devices}
if uu_id not in device_map:
return
# 续期
device_map[uu_id]["lastLoginTime"] = now.isoformat()
info["loginDevices"] = valid_devices
await redis_client.set(redis_key, json.dumps(info), ex=settings.token_expire_seconds)
RequestContext.user_id.set(int(user_id))
class AuthRequiredMiddleware(BaseHTTPMiddleware):
"""鉴权拦截:非白名单路径必须有 user_id"""
async def dispatch(self, request: Request, call_next) -> Response:
if _is_whitelisted(request.url.path) or request.url.path in _SKIP_PATHS:
return await call_next(request)
user_id = RequestContext.user_id.get(None)
if user_id is None:
from app.core.schemas.responses import StandardResponse
content = StandardResponse.fail(msg="未经授权,请登录", code=401).model_dump_json()
return Response(content=content, status_code=401, media_type="application/json")
return await call_next(request)
class RequestLogMiddleware(BaseHTTPMiddleware):
"""记录请求和响应信息"""
async def dispatch(self, request: Request, call_next) -> Response:
if request.url.path in _SKIP_PATHS:
return await call_next(request)
start_time = time.time()
log_data: Dict[str, Any] = {
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"client_ip": request.client.host if request.client else None,
}
if request.path_params:
log_data["path_params"] = dict(request.path_params)
if request.query_params:
log_data["query_params"] = dict(request.query_params)
log.info(f"Request: {log_data}")
response = await call_next(request)
process_time = round(time.time() - start_time, 4)
log.info(f"Response: status={response.status_code} time={process_time}s")
return response
class ResponseWrapMiddleware(BaseHTTPMiddleware):
"""将业务路由的 JSON 响应统一包装为 StandardResponse 格式"""
async def dispatch(self, request: Request, call_next) -> Response:
if request.url.path in _SKIP_PATHS:
return await call_next(request)
response = await call_next(request)
content_type = response.headers.get("content-type", "")
if response.status_code != 200 or "application/json" not in content_type:
return response
# 读取原始响应体
body = b""
async for chunk in response.body_iterator:
body += chunk if isinstance(chunk, bytes) else chunk.encode()
data = json.loads(body)
# 已经是 StandardResponse 格式的不重复包装
if isinstance(data, dict) and "code" in data and "msg" in data and "timestamp" in data:
return Response(content=body, status_code=200, media_type="application/json")
from app.core.schemas.responses import StandardResponse
uuid = getattr(request.state, "uuid", None)
content = StandardResponse.success(data=data, uuid=uuid).model_dump_json()
return Response(content=content, status_code=200, media_type="application/json")
def register_middleware(app) -> None:
"""注册所有中间件(注意:注册顺序与执行顺序相反)"""
# 最后注册的最先执行
app.add_middleware(ResponseWrapMiddleware)
app.add_middleware(RequestLogMiddleware)
app.add_middleware(AuthRequiredMiddleware)
app.add_middleware(JwtAuthMiddleware)
app.add_middleware(RequestIDMiddleware)
+39
View File
@@ -0,0 +1,39 @@
from typing import Optional
import redis.asyncio as aioredis
from app.config import settings
from app.core.logger import log
redis_client: Optional[aioredis.Redis] = None
async def init_redis() -> None:
"""初始化 Redis 连接池"""
global redis_client
try:
redis_client = aioredis.from_url(
settings.redis_url,
max_connections=settings.redis_pool_size,
decode_responses=True,
)
await redis_client.ping()
log.info("Redis 连接池已初始化")
except Exception as e:
log.error(f"Redis 连接初始化失败: {e}")
raise
async def close_redis() -> None:
"""关闭 Redis 连接池"""
global redis_client
if redis_client:
await redis_client.close()
log.info("Redis 连接池已关闭")
async def get_redis() -> aioredis.Redis:
"""依赖注入:提供 Redis 客户端实例"""
if redis_client is None:
raise RuntimeError("Redis 未初始化,请先调用 init_redis()")
return redis_client
+1
View File
@@ -0,0 +1 @@
+33
View File
@@ -0,0 +1,33 @@
import time
from typing import Any, Optional
from pydantic import BaseModel, Field
class StandardResponse(BaseModel):
"""统一响应格式"""
code: int = 200
msg: str = "正常响应"
data: Any = None
timestamp: str = Field(default_factory=lambda: str(int(time.time())))
uuid: Optional[str] = None
@classmethod
def success(
cls,
data: Any = None,
msg: str = "正常响应",
code: int = 200,
uuid: str | None = None,
) -> "StandardResponse":
return cls(code=code, msg=msg, data=data, uuid=uuid)
@classmethod
def fail(
cls,
msg: str = "操作失败",
code: int = 500,
data: Any = None,
uuid: str | None = None,
) -> "StandardResponse":
return cls(code=code, msg=msg, data=data, uuid=uuid)
+47
View File
@@ -0,0 +1,47 @@
import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.core.exceptions import register_exception_handlers
from app.core.lifespan import lifespan
from app.core.middleware import register_middleware
app = FastAPI(
title=settings.project_name,
version=settings.version,
lifespan=lifespan,
)
# 注册全局异常处理器
register_exception_handlers(app)
# 注册自定义中间件
register_middleware(app)
# 注册 CORS 中间件
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=False,
allow_methods=settings.cors_allow_methods,
allow_headers=settings.cors_allow_headers,
)
# ========== 路由注册 ==========
from app.api.health import router as health_router
app.include_router(health_router)
# ==============================
if __name__ == "__main__":
os.environ["ENV"] = "dev"
import uvicorn
uvicorn.run(
"app.main:app",
host="0.0.0.0",
timeout_graceful_shutdown=5,
port=settings.server_port,
)
+1
View File
@@ -0,0 +1 @@
+1
View File
@@ -0,0 +1 @@
+45
View File
@@ -0,0 +1,45 @@
# Web 框架与 ASGI 服务
fastapi>=0.115.0
uvicorn[standard]>=0.30.0,<0.32.0
gunicorn>=23.0.0
# 数据校验与配置
pydantic>=2.6.0,<3.0.0
pydantic-settings>=2.2.0,<3.0.0
# 数据库 (SQLAlchemy 异步 + MySQL)
sqlalchemy[asyncio]>=2.0.0
asyncmy>=0.2.9
# 缓存 (Redis 异步)
redis>=5.0.0
# AI SDK
openai>=1.30.0
# LangChain 生态
langchain>=0.3.0
langchain-core>=0.3.0
langchain-openai>=0.3.0
langgraph>=0.2.0
# HTTP 客户端
httpx>=0.28.0
# 日志与工具库
loguru>=0.7.2,<1.0.0
shortuuid>=1.0.11,<2.0.0
PyJWT>=2.8.0
# 数据处理
pandas>=2.2.0
numpy>=1.26.0
# FastAPI 上传与环境变量辅助
python-multipart>=0.0.9
python-dotenv>=1.0.0
# 测试
pytest>=8.0.0
pytest-asyncio>=0.24.0
hypothesis>=6.100.0