66 lines
1.9 KiB
Python
66 lines
1.9 KiB
Python
"""AI 调用工具封装"""
|
|
|
|
import json
|
|
import re
|
|
from typing import Any
|
|
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_core.messages import SystemMessage, HumanMessage
|
|
|
|
from app.core.logger import log
|
|
|
|
# Finds a fenced markdown code block. An optional word-character language tag
# right after the opening ``` is skipped; group(1) captures the fenced body.
# DOTALL lets the body span multiple lines.
_CODE_BLOCK_RE = re.compile(r"```\w*\s*\n?(.*?)\n?\s*```", flags=re.DOTALL)

# Finds ASCII control characters 0x00-0x1f, except \t (0x09), \n (0x0a)
# and \r (0x0d), which are kept.
_CONTROL_CHAR_RE = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f]")
|
|
|
|
|
|
def clean_ai_response(response: str) -> str:
    """Extract a clean JSON candidate string from raw LLM output.

    Processing steps:

    1. ``None``, empty, or whitespace-only input yields ``""``.
    2. If the text contains a fenced markdown code block, only the body of the
       first such block is kept.
    3. The text is narrowed to the span from the first ``{`` / ``[`` to the
       last matching closing bracket. This drops prose before or after the
       JSON, and a language tag left inside the fence body (e.g. "``` json").
    4. ASCII control characters other than ``\\t``, ``\\n`` and ``\\r`` are
       removed.

    Args:
        response: Raw text returned by the model.

    Returns:
        The cleaned string. It is not guaranteed to be valid JSON; callers
        still need to parse it.
    """
    if not response or not response.strip():
        return ""

    text = response.strip()

    # Prefer the contents of a markdown code fence when one is present.
    match = _CODE_BLOCK_RE.search(text)
    if match:
        text = match.group(1).strip()

    # Applied to both branches. A fence opened as "``` json" leaves "json" in
    # the captured body, and unfenced replies may surround the JSON with prose.
    text = _slice_json_span(text)

    # Remove control characters (keeping \t, \n, \r); they would otherwise
    # make json.loads fail.
    return _CONTROL_CHAR_RE.sub("", text)


def _slice_json_span(text: str) -> str:
    """Narrow ``text`` to its outermost JSON object/array candidate.

    The span starts at the first ``{`` or ``[`` (whichever comes first) and
    ends at the last occurrence of the matching closer. If there is no opener,
    the text is returned unchanged. If the closer is missing (e.g. truncated
    output), everything from the opener onward is returned.
    """
    openers = [pos for pos in (text.find("{"), text.find("[")) if pos >= 0]
    if not openers:
        return text
    start = min(openers)
    closer = "}" if text[start] == "{" else "]"
    end = text.rfind(closer)
    if end < start:
        return text[start:]
    return text[start : end + 1]
|
|
|
|
|
|
async def ai_chat(llm: ChatOpenAI, system_prompt: str, user_message: str) -> str:
    """Send a system prompt plus one user message to the LLM; return the reply text.

    Args:
        llm: Chat model, invoked asynchronously via ``ainvoke``.
        system_prompt: Text sent as the system message.
        user_message: Text sent as the human message.

    Returns:
        The reply content as a string. When the model returns structured
        content (a list of parts), the text parts are concatenated, so the
        declared ``str`` return type always holds.
    """
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_message),
    ]
    response = await llm.ainvoke(messages)
    return _content_to_text(response.content)


def _content_to_text(content: Any) -> str:
    """Flatten LangChain message content (str, or list of str/dict parts) to text.

    String parts and dict parts of the form ``{"type": "text", "text": ...}``
    are concatenated in order; other parts (e.g. images) are skipped.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                pieces.append(str(part.get("text", "")))
        return "".join(pieces)
    return "" if content is None else str(content)
|
|
|
|
|
|
async def ai_chat_json(llm: ChatOpenAI, system_prompt: str, user_message: str) -> Any:
    """Call the LLM and parse its reply as JSON.

    The raw reply from ``ai_chat`` is passed through ``clean_ai_response``
    before parsing.

    Args:
        llm: Chat model used for the call.
        system_prompt: Text sent as the system message.
        user_message: Text sent as the human message.

    Returns:
        The parsed JSON value, or ``None`` if the cleaned reply cannot be
        parsed. On failure a warning is logged with the first 200 characters
        of the raw reply. A reply that is literally ``null`` also yields
        ``None``.
    """
    raw = await ai_chat(llm, system_prompt, user_message)
    cleaned = clean_ai_response(raw)
    try:
        # strict=False accepts raw control characters (e.g. literal newlines
        # or tabs) inside JSON strings, which LLM output may contain in long
        # text values; strict mode would reject the whole document.
        return json.loads(cleaned, strict=False)
    except json.JSONDecodeError as e:
        log.warning("AI JSON 解析失败: {}, raw={}", e, raw[:200])
        return None
|