from collections.abc import AsyncGenerator
from typing import Optional

from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase

from app.config import settings
from app.core.logger import log

# Module-level engine and session factory. Both are populated by init_db()
# and reset to None by close_db(); None means "not initialized".
engine: Optional[AsyncEngine] = None
async_session_factory: Optional[async_sessionmaker[AsyncSession]] = None


class Base(DeclarativeBase):
    """Declarative base class for the ORM models."""

    pass


async def init_db() -> None:
    """Create the async engine and session factory.

    The database URL and pool options (``db_pool_size``, ``db_max_overflow``,
    ``db_pool_recycle``) are read from ``settings``. SQL echo is enabled only
    when ``settings.env == "dev"``.

    If an engine already exists (init_db() called again without an
    intervening close_db()), the previous engine is disposed first so its
    connection pool is not leaked.
    """
    global engine, async_session_factory

    if engine is not None:
        # Re-initialization: release the old pool before replacing the engine.
        await engine.dispose()

    engine = create_async_engine(
        settings.database_url,
        pool_size=settings.db_pool_size,
        max_overflow=settings.db_max_overflow,
        pool_recycle=settings.db_pool_recycle,
        echo=settings.env == "dev",
    )
    # expire_on_commit=False keeps loaded attributes readable after commit,
    # avoiding implicit lazy refreshes (which would need awaiting in async code).
    async_session_factory = async_sessionmaker(engine, expire_on_commit=False)
    log.info("数据库连接池已初始化")


async def close_db() -> None:
    """Dispose of the engine and release its connection pool.

    This is a no-op if init_db() has not been called or the database was
    already closed. Both module-level globals are reset to None, so a
    subsequent get_db() raises RuntimeError. It therefore cannot keep using
    a factory bound to a disposed engine.
    """
    global engine, async_session_factory

    if engine is None:
        return

    await engine.dispose()
    engine = None
    async_session_factory = None
    log.info("数据库连接池已关闭")


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Dependency that yields an async session for one unit of work.

    The session is committed after the consumer finishes without error and
    rolled back (then the exception re-raised) if an exception is thrown into
    the generator. The session is always closed by the ``async with`` block.

    Raises:
        RuntimeError: if init_db() has not been called (or close_db() has
            since been called).
    """
    if async_session_factory is None:
        raise RuntimeError("数据库未初始化,请先调用 init_db()")

    async with async_session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise