添加授权依赖
This commit is contained in:
@@ -0,0 +1,58 @@
|
|||||||
|
"""权限校验依赖
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from app.core.auth import require_login, func_permission
|
||||||
|
|
||||||
|
# 仅需登录
|
||||||
|
@router.get("/profile")
|
||||||
|
async def get_profile(user_id: int = Depends(require_login)):
|
||||||
|
...
|
||||||
|
|
||||||
|
# 需要功能权限(校验 + 扣库存,异常自动回退)
|
||||||
|
@router.post("/ai/generate")
|
||||||
|
async def generate(_: None = Depends(func_permission("ai:generate"))):
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from app.core.context import RequestContext
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.logger import log
|
||||||
|
|
||||||
|
|
||||||
|
async def require_login() -> int:
|
||||||
|
"""要求登录,返回 user_id"""
|
||||||
|
user_id = RequestContext.user_id.get(None)
|
||||||
|
if user_id is None:
|
||||||
|
raise HTTPException(status_code=401, detail="未经授权,请登录")
|
||||||
|
return user_id
|
||||||
|
|
||||||
|
|
||||||
|
def func_permission(func_code: str):
|
||||||
|
"""功能权限校验:校验权限 + 扣库存 → 执行业务 → 异常回退"""
|
||||||
|
|
||||||
|
async def dependency():
|
||||||
|
user_id = await require_login()
|
||||||
|
|
||||||
|
from app.services.func_permission_service import FuncPermissionService
|
||||||
|
|
||||||
|
# 事务1:校验 + 扣库存
|
||||||
|
async for session in get_db():
|
||||||
|
service = FuncPermissionService(session)
|
||||||
|
log.info(f"功能权限校验 userId:{user_id} funcCode:{func_code}")
|
||||||
|
log_id = await service.check_and_deduct(user_id, func_code)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
except Exception as e:
|
||||||
|
# 事务2:回滚
|
||||||
|
log.warning(
|
||||||
|
f"业务异常,回退使用记录 logId:{log_id} userId:{user_id} funcCode:{func_code}"
|
||||||
|
)
|
||||||
|
async for session in get_db():
|
||||||
|
service = FuncPermissionService(session)
|
||||||
|
await service.rollback_usage(log_id, user_id, func_code)
|
||||||
|
raise
|
||||||
|
|
||||||
|
return dependency
|
||||||
@@ -94,7 +94,13 @@ class JwtAuthMiddleware(BaseHTTPMiddleware):
|
|||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
valid_devices = []
|
valid_devices = []
|
||||||
for d in devices:
|
for d in devices:
|
||||||
last_login = datetime.fromisoformat(d["lastLoginTime"])
|
last_login_str = d["lastLoginTime"]
|
||||||
|
# 兼容 Java Instant 格式(尾部 Z)
|
||||||
|
if last_login_str.endswith("Z"):
|
||||||
|
last_login_str = last_login_str[:-1] + "+00:00"
|
||||||
|
last_login = datetime.fromisoformat(last_login_str)
|
||||||
|
if last_login.tzinfo is None:
|
||||||
|
last_login = last_login.replace(tzinfo=timezone.utc)
|
||||||
if (now - last_login).total_seconds() < settings.token_expire_seconds:
|
if (now - last_login).total_seconds() < settings.token_expire_seconds:
|
||||||
valid_devices.append(d)
|
valid_devices.append(d)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
"""功能权限表"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, Integer, String, DateTime
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from app.core.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
class FuncPermission(Base):
|
||||||
|
"""功能权限定义表 bg_func_permission"""
|
||||||
|
__tablename__ = "bg_func_permission"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||||
|
func_code: Mapped[str] = mapped_column(String(12), nullable=False, comment="权限编码")
|
||||||
|
func_name: Mapped[str] = mapped_column(String(64), nullable=False, comment="功能名称")
|
||||||
|
daily_free_count: Mapped[int] = mapped_column(Integer, default=0, comment="每日免费次数,0表示无免费额度")
|
||||||
|
status: Mapped[int] = mapped_column(Integer, default=1, comment="状态 1=启用 0=禁用")
|
||||||
|
create_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment="创建时间")
|
||||||
|
update_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, onupdate=datetime.now, comment="修改时间")
|
||||||
|
is_delete: Mapped[int] = mapped_column(BigInteger, default=0, comment="删除标识 0正常 非0删除")
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
"""用户功能权限库存表"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, Integer, String, DateTime
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from app.core.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
class UserFuncPermissionStock(Base):
|
||||||
|
"""用户功能权限库存表 bg_user_func_permission_stock"""
|
||||||
|
__tablename__ = "bg_user_func_permission_stock"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, comment="用户ID")
|
||||||
|
func_code: Mapped[str] = mapped_column(String(12), nullable=False, comment="权限编码")
|
||||||
|
time_limit: Mapped[int] = mapped_column(Integer, default=0, comment="0=不限时 1=限时")
|
||||||
|
count_limit: Mapped[int] = mapped_column(Integer, default=0, comment="0=不限次 1=限次")
|
||||||
|
expire_time: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True, comment="过期时间")
|
||||||
|
remain_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, comment="剩余次数")
|
||||||
|
create_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment="创建时间")
|
||||||
|
update_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, onupdate=datetime.now, comment="修改时间")
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
"""用户功能使用记录表"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, String, DateTime
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from app.core.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
class UserFuncUsageLog(Base):
|
||||||
|
"""用户功能使用记录表 bg_user_func_usage_log"""
|
||||||
|
__tablename__ = "bg_user_func_usage_log"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, comment="用户ID")
|
||||||
|
func_code: Mapped[str] = mapped_column(String(12), nullable=False, comment="功能编码")
|
||||||
|
create_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment="使用时间")
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
"""功能权限 Service
|
||||||
|
|
||||||
|
校验用户功能权限并扣减库存,业务异常时回退。
|
||||||
|
逻辑与 Java 端 FuncPermissionService 完全一致。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from sqlalchemy import select, func, update, delete
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.func_permission import FuncPermission
|
||||||
|
from app.models.user_func_permission_stock import UserFuncPermissionStock
|
||||||
|
from app.models.user_func_usage_log import UserFuncUsageLog
|
||||||
|
|
||||||
|
|
||||||
|
class FuncPermissionService:
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def check_and_deduct(self, user_id: int, func_code: str) -> int:
|
||||||
|
"""校验权限 + 扣减库存,返回使用记录ID"""
|
||||||
|
|
||||||
|
# 1. 查功能权限定义
|
||||||
|
result = await self.session.execute(select(FuncPermission).where(FuncPermission.func_code == func_code, FuncPermission.status == 1, FuncPermission.is_delete == 0))
|
||||||
|
perm = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if perm is None:
|
||||||
|
raise HTTPException(status_code=403, detail="功能不存在或未启用")
|
||||||
|
|
||||||
|
# 2. 判断每日免费额度
|
||||||
|
daily_free = perm.daily_free_count or 0
|
||||||
|
if daily_free > 0:
|
||||||
|
today_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
result = await self.session.execute(select(func.count()).select_from(UserFuncUsageLog).where(UserFuncUsageLog.user_id == user_id, UserFuncUsageLog.func_code == func_code, UserFuncUsageLog.create_time >= today_start,))
|
||||||
|
today_used = result.scalar() or 0
|
||||||
|
|
||||||
|
if today_used < daily_free:
|
||||||
|
return await self._insert_usage_log(user_id, func_code)
|
||||||
|
|
||||||
|
# 3. 查付费库存
|
||||||
|
result = await self.session.execute(select(UserFuncPermissionStock).where(UserFuncPermissionStock.user_id == user_id, UserFuncPermissionStock.func_code == func_code))
|
||||||
|
stock = result.scalar_one_or_none()
|
||||||
|
if stock is None:
|
||||||
|
raise HTTPException(status_code=403, detail="无该功能权限")
|
||||||
|
|
||||||
|
# 4. 时间维度校验
|
||||||
|
if stock.time_limit == 1 and stock.expire_time is not None:
|
||||||
|
if stock.expire_time < datetime.now():
|
||||||
|
raise HTTPException(status_code=403, detail="功能权限已过期")
|
||||||
|
|
||||||
|
# 5. 次数维度校验
|
||||||
|
if stock.count_limit == 0:
|
||||||
|
return await self._insert_usage_log(user_id, func_code)
|
||||||
|
|
||||||
|
# 限次,SQL 原子扣减
|
||||||
|
result = await self.session.execute(update(UserFuncPermissionStock).where(UserFuncPermissionStock.user_id == user_id, UserFuncPermissionStock.func_code == func_code, UserFuncPermissionStock.remain_count > 0).values(remain_count=UserFuncPermissionStock.remain_count - 1))
|
||||||
|
|
||||||
|
if result.rowcount == 0:
|
||||||
|
raise HTTPException(status_code=403, detail="功能使用次数已用完")
|
||||||
|
|
||||||
|
return await self._insert_usage_log(user_id, func_code)
|
||||||
|
|
||||||
|
async def rollback_usage(self, log_id: int, user_id: int, func_code: str) -> None:
|
||||||
|
"""回退使用记录 + 库存"""
|
||||||
|
# 删除使用记录
|
||||||
|
await self.session.execute(delete(UserFuncUsageLog).where(UserFuncUsageLog.id == log_id))
|
||||||
|
|
||||||
|
# 尝试回退库存(count_limit=1 才会匹配)
|
||||||
|
await self.session.execute(update(UserFuncPermissionStock).where(UserFuncPermissionStock.user_id == user_id, UserFuncPermissionStock.func_code == func_code, UserFuncPermissionStock.count_limit == 1).values(remain_count=UserFuncPermissionStock.remain_count + 1))
|
||||||
|
|
||||||
|
async def _insert_usage_log(self, user_id: int, func_code: str) -> int:
|
||||||
|
"""插入使用记录,返回记录ID"""
|
||||||
|
usage_log = UserFuncUsageLog(user_id=user_id, func_code=func_code)
|
||||||
|
self.session.add(usage_log)
|
||||||
|
await self.session.flush()
|
||||||
|
return usage_log.id
|
||||||
Reference in New Issue
Block a user