139 lines
5.0 KiB
Python
139 lines
5.0 KiB
Python
"""公司数据补充服务(协程版)"""
|
|
|
|
import asyncio
|
|
from datetime import datetime
|
|
|
|
from sqlalchemy import text
|
|
|
|
from app.config import settings
|
|
from app.core.database import MysqlSession
|
|
from app.core.logger import log
|
|
from app.ai.model_config import CompanyCleanModel
|
|
from app.ai.prompts import COMPANY_ENRICH_SYSTEM
|
|
from app.services.ai_tool import ai_chat_json
|
|
from app.services.dict_cache_service import dict_cache
|
|
|
|
|
|
async def run_company_clean() -> None:
    """Run one batch of company enrichment.

    Claims up to ``settings.company_batch_size`` rows with ``status = 0`` using
    ``SELECT ... FOR UPDATE SKIP LOCKED``, so rows locked by another worker are
    skipped. It then sets the claimed rows to ``status = 3`` and commits, which
    releases the lock. Finally it enriches the rows concurrently, with at most
    ``settings.company_concurrency`` enrichments in flight at once.
    """
    from sqlalchemy import bindparam

    # Claim a batch of companies that still need enrichment.
    async with MysqlSession() as mysql:
        result = await mysql.execute(
            text("""
                SELECT * FROM bg_company
                WHERE status = 0
                LIMIT :limit
                FOR UPDATE SKIP LOCKED
            """),
            {"limit": settings.company_batch_size},
        )
        rows = result.mappings().all()
        if not rows:
            return

        # The expanding bind parameter renders ``IN (...)`` with one
        # placeholder per id, so no SQL is assembled by string formatting.
        claim_stmt = text(
            "UPDATE bg_company SET status = 3, update_time = NOW() WHERE id IN :ids"
        ).bindparams(bindparam("ids", expanding=True))
        await mysql.execute(claim_stmt, {"ids": [r["id"] for r in rows]})
        await mysql.commit()

    log.info("公司补充:锁定{}条数据", len(rows))

    # Run the enrichments as coroutines; the semaphore bounds concurrency.
    sem = asyncio.Semaphore(settings.company_concurrency)
    outcomes = await asyncio.gather(
        *(_clean_one(sem, dict(r)) for r in rows),
        return_exceptions=True,
    )

    # _clean_one handles its own errors. Anything that still escapes would
    # otherwise be dropped silently because of return_exceptions=True.
    for row, outcome in zip(rows, outcomes):
        if isinstance(outcome, BaseException):
            log.error("公司补充任务未处理异常, id={}: {!r}", row["id"], outcome)
|
|
|
|
|
|
async def _clean_one(sem: asyncio.Semaphore, company: dict) -> None:
    """Enrich a single company while holding the concurrency semaphore.

    Exceptions raised by the enrichment are logged and swallowed, so one bad
    row cannot disturb the rest of the batch.

    The batch sets each row to status 3 before this runs. A failure here would
    otherwise leave the row at status 3 indefinitely. As a best-effort fallback,
    the row is therefore moved to status 4, the same status ``_do_clean`` uses
    when the AI result is unusable. If that update also fails, the error is only
    logged.

    Args:
        sem: Semaphore that bounds how many enrichments run at once.
        company: ``bg_company`` row as a plain dict; must contain ``id``.
    """
    async with sem:
        try:
            await _do_clean(company)
        except Exception as e:
            log.error("公司补充异常, id={}, shortName={}: {}", company["id"], company.get("short_name"), e)
            try:
                await _update_status(company["id"], 4)
            except Exception as status_err:
                log.error("公司补充失败状态回写异常, id={}: {}", company["id"], status_err)
|
|
|
|
|
|
async def _do_clean(company: dict) -> None:
    """Enrich one company through the AI model and persist the outcome.

    The company's short name and the cached industry list
    (``dict_cache.industry_text``) are sent to ``CompanyCleanModel.ENRICH``
    with the ``COMPANY_ENRICH_SYSTEM`` prompt.

    * The row is set to status 4 when any of these hold:
      - there is no short name;
      - there is no reply;
      - the reply is not a JSON object;
      - the reply lacks a truthy ``valid`` flag.
    * Otherwise the returned fields are written back to ``bg_company`` and the
      row is set to status 1.

    Args:
        company: ``bg_company`` row as a dict; must contain ``id`` and may
            contain ``short_name``.
    """
    company_id = company["id"]
    # ``.get(key, "")`` still returns None for a NULL column, and the f-string
    # below would then render the literal text "None"; ``or ""`` covers both.
    short_name = company.get("short_name") or ""
    if not short_name:
        # The prompt carries nothing but the short name, so the model has
        # nothing to work with; fail fast instead of spending a call.
        log.info("公司补充跳过:简称为空, id={}", company_id)
        await _update_status(company_id, 4)
        return

    user_msg = f"【公司简称】\n{short_name}\n\n【行业列表】\n{dict_cache.industry_text}"
    result = await ai_chat_json(CompanyCleanModel.ENRICH, COMPANY_ENRICH_SYSTEM, user_msg)

    # A JSON array or scalar has no ``.get``; treat it as an invalid reply
    # rather than raising AttributeError.
    if not isinstance(result, dict) or not result.get("valid", False):
        await _update_status(company_id, 4)
        return

    # Map the model's city name to a region code via the dictionary cache.
    city = result.get("city")
    region_code = dict_cache.match_region_code(city) if city else None

    # Write the results back. COALESCE keeps the existing column value
    # wherever the model returned null.
    # NOTE(review): industry_id, tags and news are assigned unconditionally,
    # so a missing value clears the stored one. Confirm this is intended.
    now = datetime.now()
    async with MysqlSession() as mysql:
        await mysql.execute(
            text("""
                UPDATE bg_company SET
                    name = COALESCE(:name, name),
                    region_code = COALESCE(:region_code, region_code),
                    company_type = COALESCE(:company_type, company_type),
                    industry_id = :industry_id,
                    tags = :tags,
                    summary = COALESCE(:summary, summary),
                    description = COALESCE(:description, description),
                    founded_year = COALESCE(:founded_year, founded_year),
                    address = COALESCE(:address, address),
                    scale = COALESCE(:scale, scale),
                    website = COALESCE(:website, website),
                    financing_stage = COALESCE(:financing_stage, financing_stage),
                    latest_valuation = COALESCE(:latest_valuation, latest_valuation),
                    news = :news,
                    status = 1,
                    update_time = :now
                WHERE id = :id
            """),
            {
                "name": result.get("name"),
                "region_code": region_code,
                "company_type": result.get("companyType"),
                "industry_id": result.get("industryId"),
                "tags": _to_json(result.get("tags")),
                "summary": result.get("summary"),
                "description": result.get("description"),
                "founded_year": result.get("foundedYear"),
                "address": result.get("address"),
                "scale": result.get("scale"),
                "website": result.get("website"),
                "financing_stage": result.get("financingStage"),
                "latest_valuation": result.get("latestValuation"),
                "news": _to_json(result.get("news")),
                "now": now,
                "id": company_id,
            },
        )
        await mysql.commit()

    log.info("公司补充完成, id={}, shortName={}", company_id, short_name)
|
|
|
|
|
|
async def _update_status(company_id: int, status: int) -> None:
    """Set ``bg_company.status`` for one company and refresh ``update_time``.

    Uses its own session and commits immediately.

    Args:
        company_id: Primary key of the ``bg_company`` row.
        status: New status value to store.
    """
    stmt = text("UPDATE bg_company SET status = :s, update_time = NOW() WHERE id = :id")
    bind_values = {"s": status, "id": company_id}
    async with MysqlSession() as session:
        await session.execute(stmt, bind_values)
        await session.commit()
|
|
|
|
|
|
def _to_json(value) -> str | None:
|
|
"""列表转 JSON 字符串"""
|
|
import json
|
|
if value and isinstance(value, list):
|
|
return json.dumps(value, ensure_ascii=False)
|
|
return None
|