From 95028cdef9dbcf0937b91bf0eb89298dfa3300cf Mon Sep 17 00:00:00 2001 From: zk Date: Fri, 13 Mar 2026 15:37:01 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=8E=88=E6=9D=83=E4=BE=9D?= =?UTF-8?q?=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/auth.py | 58 +++++++++++++++++ app/core/middleware.py | 8 ++- app/models/func_permission.py | 22 +++++++ app/models/user_func_permission_stock.py | 24 +++++++ app/models/user_func_usage_log.py | 18 ++++++ app/services/func_permission_service.py | 81 ++++++++++++++++++++++++ 6 files changed, 210 insertions(+), 1 deletion(-) create mode 100644 app/core/auth.py create mode 100644 app/models/func_permission.py create mode 100644 app/models/user_func_permission_stock.py create mode 100644 app/models/user_func_usage_log.py create mode 100644 app/services/func_permission_service.py diff --git a/app/core/auth.py b/app/core/auth.py new file mode 100644 index 0000000..67349fb --- /dev/null +++ b/app/core/auth.py @@ -0,0 +1,58 @@ +"""权限校验依赖 + +Usage: + from app.core.auth import require_login, func_permission + + # 仅需登录 + @router.get("/profile") + async def get_profile(user_id: int = Depends(require_login)): + ... + + # 需要功能权限(校验 + 扣库存,异常自动回退) + @router.post("/ai/generate") + async def generate(_: None = Depends(func_permission("ai:generate"))): + ... +""" + +from fastapi import HTTPException + +from app.core.context import RequestContext +from app.core.database import get_db +from app.core.logger import log + + +async def require_login() -> int: + """要求登录,返回 user_id""" + user_id = RequestContext.user_id.get(None) + if user_id is None: + raise HTTPException(status_code=401, detail="未经授权,请登录") + return user_id + + +def func_permission(func_code: str): + """功能权限校验:校验权限 + 扣库存 → 执行业务 → 异常回退""" + + async def dependency(): + user_id = await require_login() + + from app.services.func_permission_service import FuncPermissionService + + # 事务1:校验 + 扣库存 + async for session in get_db(): + service = FuncPermissionService(session) + log.info(f"功能权限校验 userId:{user_id} funcCode:{func_code}") + log_id = await service.check_and_deduct(user_id, func_code) + + try: + yield + except Exception as e: + # 事务2:回滚 + log.warning( + f"业务异常,回退使用记录 logId:{log_id} userId:{user_id} funcCode:{func_code}" + ) + async for session in get_db(): + service = FuncPermissionService(session) + await service.rollback_usage(log_id, user_id, func_code) + raise + + return dependency diff --git a/app/core/middleware.py b/app/core/middleware.py index 2e1b18a..b4dce67 100644 --- a/app/core/middleware.py +++ b/app/core/middleware.py @@ -94,7 +94,13 @@ class JwtAuthMiddleware(BaseHTTPMiddleware): now = datetime.now(timezone.utc) valid_devices = [] for d in devices: - last_login = datetime.fromisoformat(d["lastLoginTime"]) + last_login_str = d["lastLoginTime"] + # 兼容 Java Instant 格式(尾部 Z) + if last_login_str.endswith("Z"): + last_login_str = last_login_str[:-1] + "+00:00" + last_login = datetime.fromisoformat(last_login_str) + if last_login.tzinfo is None: + last_login = last_login.replace(tzinfo=timezone.utc) if (now - last_login).total_seconds() < settings.token_expire_seconds: valid_devices.append(d) diff --git a/app/models/func_permission.py b/app/models/func_permission.py new file mode 100644 index 0000000..f328e23 --- /dev/null +++ b/app/models/func_permission.py @@ -0,0 +1,22 @@ +"""功能权限表""" + +from datetime import datetime + +from sqlalchemy import BigInteger, Integer, String, DateTime +from sqlalchemy.orm import Mapped, mapped_column + +from app.core.database import Base + + +class FuncPermission(Base): + """功能权限定义表 bg_func_permission""" + __tablename__ = "bg_func_permission" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + func_code: Mapped[str] = mapped_column(String(12), nullable=False, comment="权限编码") + func_name: Mapped[str] = mapped_column(String(64), nullable=False, comment="功能名称") + daily_free_count: Mapped[int] = mapped_column(Integer, default=0, comment="每日免费次数,0表示无免费额度") + status: Mapped[int] = mapped_column(Integer, default=1, comment="状态 1=启用 0=禁用") + create_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment="创建时间") + update_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, onupdate=datetime.now, comment="修改时间") + is_delete: Mapped[int] = mapped_column(BigInteger, default=0, comment="删除标识 0正常 非0删除") diff --git a/app/models/user_func_permission_stock.py b/app/models/user_func_permission_stock.py new file mode 100644 index 0000000..2978519 --- /dev/null +++ b/app/models/user_func_permission_stock.py @@ -0,0 +1,24 @@ +"""用户功能权限库存表""" + +from datetime import datetime +from typing import Optional + +from sqlalchemy import BigInteger, Integer, String, DateTime +from sqlalchemy.orm import Mapped, mapped_column + +from app.core.database import Base + + +class UserFuncPermissionStock(Base): + """用户功能权限库存表 bg_user_func_permission_stock""" + __tablename__ = "bg_user_func_permission_stock" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, comment="用户ID") + func_code: Mapped[str] = mapped_column(String(12), nullable=False, comment="权限编码") + time_limit: Mapped[int] = mapped_column(Integer, default=0, comment="0=不限时 1=限时") + count_limit: Mapped[int] = mapped_column(Integer, default=0, comment="0=不限次 1=限次") + expire_time: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True, comment="过期时间") + remain_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, comment="剩余次数") + create_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment="创建时间") + update_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, onupdate=datetime.now, comment="修改时间") diff --git a/app/models/user_func_usage_log.py b/app/models/user_func_usage_log.py new file mode 100644 index 0000000..cd4372b --- /dev/null +++ b/app/models/user_func_usage_log.py @@ -0,0 +1,18 @@ +"""用户功能使用记录表""" + +from datetime import datetime + +from sqlalchemy import BigInteger, String, DateTime +from sqlalchemy.orm import Mapped, mapped_column + +from app.core.database import Base + + +class UserFuncUsageLog(Base): + """用户功能使用记录表 bg_user_func_usage_log""" + __tablename__ = "bg_user_func_usage_log" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True) + user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, comment="用户ID") + func_code: Mapped[str] = mapped_column(String(12), nullable=False, comment="功能编码") + create_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment="使用时间") diff --git a/app/services/func_permission_service.py b/app/services/func_permission_service.py new file mode 100644 index 0000000..716e568 --- /dev/null +++ b/app/services/func_permission_service.py @@ -0,0 +1,81 @@ +"""功能权限 Service + +校验用户功能权限并扣减库存,业务异常时回退。 +逻辑与 Java 端 FuncPermissionService 完全一致。 +""" + +from datetime import datetime + +from fastapi import HTTPException +from sqlalchemy import select, func, update, delete +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.func_permission import FuncPermission +from app.models.user_func_permission_stock import UserFuncPermissionStock +from app.models.user_func_usage_log import UserFuncUsageLog + + +class FuncPermissionService: + + def __init__(self, session: AsyncSession): + self.session = session + + async def check_and_deduct(self, user_id: int, func_code: str) -> int: + """校验权限 + 扣减库存,返回使用记录ID""" + + # 1. 查功能权限定义 + result = await self.session.execute(select(FuncPermission).where(FuncPermission.func_code == func_code, FuncPermission.status == 1, FuncPermission.is_delete == 0)) + perm = result.scalar_one_or_none() + + if perm is None: + raise HTTPException(status_code=403, detail="功能不存在或未启用") + + # 2. 判断每日免费额度 + daily_free = perm.daily_free_count or 0 + if daily_free > 0: + today_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) + + # noinspection PyTypeChecker + result = await self.session.execute(select(func.count()).select_from(UserFuncUsageLog).where(UserFuncUsageLog.user_id == user_id, UserFuncUsageLog.func_code == func_code, UserFuncUsageLog.create_time >= today_start,)) + today_used = result.scalar() or 0 + + if today_used < daily_free: + return await self._insert_usage_log(user_id, func_code) + + # 3. 查付费库存 + result = await self.session.execute(select(UserFuncPermissionStock).where(UserFuncPermissionStock.user_id == user_id, UserFuncPermissionStock.func_code == func_code)) + stock = result.scalar_one_or_none() + if stock is None: + raise HTTPException(status_code=403, detail="无该功能权限") + + # 4. 时间维度校验 + if stock.time_limit == 1 and stock.expire_time is not None: + if stock.expire_time < datetime.now(): + raise HTTPException(status_code=403, detail="功能权限已过期") + + # 5. 次数维度校验 + if stock.count_limit == 0: + return await self._insert_usage_log(user_id, func_code) + + # 限次,SQL 原子扣减 + result = await self.session.execute(update(UserFuncPermissionStock).where(UserFuncPermissionStock.user_id == user_id, UserFuncPermissionStock.func_code == func_code, UserFuncPermissionStock.remain_count > 0).values(remain_count=UserFuncPermissionStock.remain_count - 1)) + + if result.rowcount == 0: + raise HTTPException(status_code=403, detail="功能使用次数已用完") + + return await self._insert_usage_log(user_id, func_code) + + async def rollback_usage(self, log_id: int, user_id: int, func_code: str) -> None: + """回退使用记录 + 库存""" + # 删除使用记录 + await self.session.execute(delete(UserFuncUsageLog).where(UserFuncUsageLog.id == log_id)) + + # 尝试回退库存(count_limit=1 才会匹配) + await self.session.execute(update(UserFuncPermissionStock).where(UserFuncPermissionStock.user_id == user_id, UserFuncPermissionStock.func_code == func_code, UserFuncPermissionStock.count_limit == 1).values(remain_count=UserFuncPermissionStock.remain_count + 1)) + + async def _insert_usage_log(self, user_id: int, func_code: str) -> int: + """插入使用记录,返回记录ID""" + usage_log = UserFuncUsageLog(user_id=user_id, func_code=func_code) + self.session.add(usage_log) + await self.session.flush() + return usage_log.id