52 lines
1.8 KiB
Python
52 lines
1.8 KiB
Python
"""AI 调用工具封装"""
|
|
|
|
import re
|
|
from typing import Any
|
|
|
|
from json_repair import repair_json
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_core.messages import SystemMessage, HumanMessage
|
|
|
|
from app.core.logger import log
|
|
|
|
# Matches a <think>...</think> section (shortest span, may cross newlines,
# tag case-insensitive) so reasoning-model "thinking" output can be stripped.
_THINK_RE = re.compile(r"<think>.*?</think>", flags=re.IGNORECASE | re.DOTALL)

# Matches a markdown code fence, optionally tagged json/json5/jsonc/..., and
# captures the text between the fences in group 1. A fence with some other
# tag still matches, but that tag then stays at the start of group 1.
_CODE_BLOCK_RE = re.compile(
    r"```(?:json\w*)?\s*\n?(.*?)\n?\s*```",
    flags=re.IGNORECASE | re.DOTALL,
)
|
|
|
|
|
|
def parse_llm_json(text: str) -> Any:
    """Parse JSON from raw LLM output, tolerating think tags, code fences and bad JSON.

    Cleaning steps, applied in order:

    1. Remove every ``<think>...</think>`` section (case-insensitive, may span
       lines). If a closing ``</think>`` is still present afterwards -- i.e. the
       opening tag was not part of the text -- everything up to and including
       the first such orphan closing tag is dropped as well.
    2. If the remaining text does not already start with ``{`` or ``[`` and it
       contains a fenced markdown code block, keep only the first block's contents.
    3. Pass the result to ``repair_json(..., return_objects=True)``, which repairs
       non-standard JSON (trailing commas, missing quotes, invalid escapes, ...)
       and returns Python objects.

    Args:
        text: Raw model output.

    Returns:
        Whatever ``repair_json`` produces for the cleaned text.
    """
    # 1. Drop paired <think>...</think> reasoning sections.
    cleaned = _THINK_RE.sub("", text)
    # When the opening <think> is not part of the completion (e.g. a chat template
    # pre-fills it in the prompt), the text reads "reasoning...</think>answer" and
    # the paired pattern above cannot match; keep only what follows the orphan tag.
    pieces = re.split(r"</think>", cleaned, maxsplit=1, flags=re.IGNORECASE)
    cleaned = pieces[-1].strip()

    # 2. Unwrap a ```json ... ``` fence, but only when the text is not already bare
    # JSON: a JSON value whose strings contain a fenced snippet (e.g. generated
    # markdown with code samples) would otherwise be replaced by that snippet.
    if not cleaned.startswith(("{", "[")):
        match = _CODE_BLOCK_RE.search(cleaned)
        if match:
            cleaned = match.group(1).strip()

    # 3. Tolerant parse that repairs non-standard JSON.
    return repair_json(cleaned, return_objects=True)
|
|
|
|
|
|
async def ai_chat(llm: ChatOpenAI, system_prompt: str, user_message: str) -> str:
    """Call the LLM asynchronously and return its reply as plain text.

    Args:
        llm: Chat model to invoke.
        system_prompt: Content of the system message.
        user_message: Content of the human message.

    Returns:
        The reply text. LangChain types ``message.content`` as either a ``str`` or
        a list of content blocks (``str`` or ``dict``); in the list case the text
        parts are concatenated so callers always receive a ``str``.
    """
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_message),
    ]
    response = await llm.ainvoke(messages)
    content = response.content
    if isinstance(content, str):
        return content

    # Multi-block content: keep bare strings and {"type": "text", "text": ...}
    # blocks; skip any other block type, which is not reply text.
    parts = []
    for block in content:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            parts.append(str(block.get("text", "")))
    return "".join(parts)
|
|
|
|
|
|
async def ai_chat_json(llm: ChatOpenAI, system_prompt: str, user_message: str) -> Any:
    """Call the LLM asynchronously and return its reply parsed as JSON.

    Args:
        llm: Chat model to invoke.
        system_prompt: Content of the system message.
        user_message: Content of the human message.

    Returns:
        The parsed JSON value, or ``None`` when the reply is blank, when
        ``parse_llm_json`` raises, or when parsing yields an empty string
        (nothing recoverable). Every ``None`` outcome is logged as a warning.
    """
    raw = await ai_chat(llm, system_prompt, user_message)
    if not raw or not raw.strip():
        log.warning("AI 返回为空")
        return None

    try:
        result = parse_llm_json(raw)
    except Exception as e:  # best-effort boundary: degrade to None instead of raising
        log.warning("AI JSON 解析失败: {}, raw={}", e, raw[:200])
        return None

    # json_repair returns "" when it cannot recover any JSON from the input;
    # report that as a failure like the other paths instead of returning "".
    if result == "":
        log.warning("AI 返回中未找到可解析的 JSON, raw={}", raw[:200])
        return None
    return result
|