dagent_eval/server/api/qa_gen_dagent.py

1175 lines
47 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
从 Dagent 数据库导入知识库数据,生成多模态问答集
"""
import asyncio
import json
import re
import sys
from pathlib import Path
from fastapi import APIRouter, Form, HTTPException
from typing import Optional
import aiohttp
import aiomysql
import logging
import os
from datetime import datetime
# Set up file logging (must come after the Path import).
# The log directory is "<parent of this file's directory>/logs".
LOG_PATH = Path(__file__).parent.parent / "logs"
LOG_PATH.mkdir(exist_ok=True)
_logger = logging.getLogger("qa_gen_dagent")
_logger.setLevel(logging.DEBUG)
# Add the file handler only when the logger has no handler yet.
if not _logger.handlers:
    _fh = logging.FileHandler(LOG_PATH / "qa_gen_debug.log", encoding="utf-8")
    _fh.setLevel(logging.DEBUG)
    _fh.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
    _logger.addHandler(_fh)
def _log(msg: str):
    """Emit ``msg`` at DEBUG level on the module logger (``qa_gen_dagent``)."""
    _logger.log(logging.DEBUG, msg)
# Add the parent directory to sys.path so that `models.db` can be imported below.
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.db import get_db, _now, _id
# Every endpoint declared on this router is served under the /api/qa-gen prefix.
router = APIRouter(prefix="/api/qa-gen", tags=["问题生成-Dagent"])
# Connection settings for the Dagent MySQL database (passed to aiomysql.connect).
# SECURITY: credentials should not live in source control. Every value can be
# overridden through the environment variable named below; the literals are kept
# only as fallbacks so existing deployments keep working.
# TODO(review): rotate this password and remove the hard-coded fallback.
DAGENT_DB = {
    "host": os.environ.get("DAGENT_DB_HOST", "120.48.66.228"),
    "port": int(os.environ.get("DAGENT_DB_PORT", "23306")),
    "user": os.environ.get("DAGENT_DB_USER", "dagent"),
    "password": os.environ.get("DAGENT_DB_PASSWORD", "Fd1.Ej3.fdIie48"),
    "db": os.environ.get("DAGENT_DB_NAME", "dagent_platform"),
    "charset": os.environ.get("DAGENT_DB_CHARSET", "utf8mb4"),
}
async def get_dagent_conn():
    """Open a new aiomysql connection using the ``DAGENT_DB`` settings and return it."""
    connection = await aiomysql.connect(**DAGENT_DB)
    return connection
@router.get("/dagent/stats")
async def get_dagent_stats(org_id: str, env_url: str = ""):
    """Return file and chunk counts for an organisation's Dagent knowledge base (via the HTTP API).

    Args:
        org_id: Organisation id; sent both as the ``org-id`` header and in request bodies.
        env_url: Base URL of the Dagent environment; the production URL is used when empty.

    Returns:
        ``{"status": 0, "data": {...}}``. ``data`` holds ``file_count`` and
        ``paragraph_count``; the remaining counters are always reported as 0.
        If the request as a whole fails, ``data`` is an empty dict.
    """
    # NOTE(review): env_url is caller-supplied and used as the request target;
    # consider validating it against an allow-list.
    base_url = (env_url or "https://dagent.d-robotics.cc").rstrip("/")
    headers = {
        "Content-Type": "application/json",
        "org-id": org_id,
        "d-user-id": "test",
    }
    page_size = 100
    try:
        async with aiohttp.ClientSession(headers=headers) as session:
            page = 1
            total_files = 0
            total_paragraphs = 0
            while True:
                # Fetch one page of the file list.
                async with session.post(
                    f"{base_url}/dagent/knowledge/file/page",
                    json={"current": page, "page_size": page_size, "org_id": org_id},
                    timeout=aiohttp.ClientTimeout(total=15),
                ) as resp:
                    if resp.status != 200:
                        break
                    data = await resp.json()
                # Tolerate an explicit null for "data" / "list" instead of failing the whole call.
                files = (data.get("data") or {}).get("list") or []
                if not files:
                    break
                total_files += len(files)
                # Ask for a single chunk per file: only the reported "total" is needed.
                for f in files:
                    try:
                        async with session.post(
                            f"{base_url}/dagent/knowledge/chunk/page",
                            json={"file_id": f["id"], "org_id": org_id, "page": 1, "page_size": 1},
                            timeout=aiohttp.ClientTimeout(total=10),
                        ) as cr:
                            if cr.status == 200:
                                cd = await cr.json()
                                total_paragraphs += (cd.get("data") or {}).get("total") or 0
                    except Exception as e:
                        # Best effort: one failing file must not abort the statistics,
                        # but the failure is reported instead of being silently dropped.
                        print(f"[get_dagent_stats] chunk count failed, skipping file: {e!r}")
                # A short page means this was the last one.
                if len(files) < page_size:
                    break
                page += 1
            return {"status": 0, "data": {
                "file_count": total_files,
                "paragraph_count": total_paragraphs,
                "total_images": 0,
                "paragraphs_with_pic_text": 0,
                "paragraphs_with_question": 0,
            }}
    except Exception as e:
        print(f"[get_dagent_stats] Error: {e}")
        return {"status": 0, "data": {}}
@router.get("/dagent/files")
async def list_dagent_files(org_id: str, env_url: str = ""):
    """List the files of an organisation's Dagent knowledge base (via the HTTP API).

    Args:
        org_id: Organisation id; sent both as the ``org-id`` header and in request bodies.
        env_url: Base URL of the Dagent environment; the production URL is used when empty.

    Returns:
        ``{"status": 0, "data": [...]}`` where each entry carries ``id``, ``file_name``,
        ``file_type``, lower-cased ``file_clean_status``, ``file_bytes`` and ``create_time``.
        If the request as a whole fails, ``data`` is an empty list.
    """
    base_url = (env_url or "https://dagent.d-robotics.cc").rstrip("/")
    headers = {
        "Content-Type": "application/json",
        "org-id": org_id,
        "d-user-id": "test",
    }
    page_size = 100
    all_files = []
    try:
        async with aiohttp.ClientSession(headers=headers) as session:
            page = 1
            while True:
                async with session.post(
                    f"{base_url}/dagent/knowledge/file/page",
                    json={"current": page, "page_size": page_size, "org_id": org_id},
                    timeout=aiohttp.ClientTimeout(total=15),
                ) as resp:
                    if resp.status != 200:
                        break
                    data = await resp.json()
                # Tolerate an explicit null for "data" / "list".
                files = (data.get("data") or {}).get("list") or []
                if not files:
                    break
                for f in files:
                    all_files.append({
                        "id": f.get("id"),
                        "file_name": f.get("file_name"),
                        "file_type": f.get("file_type"),
                        # `or ""` guards against an explicit null status, which would
                        # otherwise raise on .lower() and discard the whole listing.
                        "file_clean_status": (f.get("file_clean_status") or "").lower(),
                        "file_bytes": f.get("file_bytes", 0),
                        "create_time": f.get("create_time"),
                    })
                # A short page means this was the last one.
                if len(files) < page_size:
                    break
                page += 1
        return {"status": 0, "data": all_files}
    except Exception as e:
        print(f"[list_dagent_files] Error: {e}")
        return {"status": 0, "data": []}
@router.get("/dagent/tree")
async def get_dagent_tree(org_id: str, env_url: str = ""):
    """Return the knowledge base as a three-level tree: major chapter -> minor chapter -> file.

    Chapters are derived from the ``/``-separated ``file_name`` of each file: the first
    path segment is the major chapter and the full directory path is the minor chapter.
    Files without a directory part are grouped under "默认章节".

    Args:
        org_id: Organisation id; sent both as the ``org-id`` header and in request bodies.
        env_url: Base URL of the Dagent environment; the production URL is used when empty.

    Returns:
        ``{"status": 0, "data": [...]}`` with nodes sorted by title at every level, or
        ``{"status": 1, "message": ..., "data": []}`` when an exception occurs.
    """
    base_url = (env_url or "https://dagent.d-robotics.cc").rstrip("/")
    headers = {
        "Content-Type": "application/json",
        "org-id": org_id,
        "d-user-id": "test",
    }
    page_size = 100
    try:
        async with aiohttp.ClientSession(headers=headers) as session:
            # 1. Page through the complete file list.
            page = 1
            all_files = []
            while True:
                async with session.post(
                    f"{base_url}/dagent/knowledge/file/page",
                    json={"current": page, "page_size": page_size, "org_id": org_id},
                    timeout=aiohttp.ClientTimeout(total=15),
                ) as resp:
                    if resp.status != 200:
                        break
                    data = await resp.json()
                # Tolerate an explicit null for "data" / "list".
                files = (data.get("data") or {}).get("list") or []
                if not files:
                    break
                all_files.extend(files)
                if len(files) < page_size:
                    break
                page += 1

            # 2. Fetch every file's chunk total concurrently (page_size=1: only "total" is read).
            sem = asyncio.Semaphore(20)

            async def fetch_chunk_count(file_id: str) -> int:
                """Return the chunk total reported for ``file_id``, or 0 on any failure."""
                async with sem:
                    try:
                        async with session.post(
                            f"{base_url}/dagent/knowledge/chunk/page",
                            json={"file_id": file_id, "org_id": org_id, "page": 1, "page_size": 1},
                            timeout=aiohttp.ClientTimeout(total=10),
                        ) as cr:
                            if cr.status == 200:
                                cdata = await cr.json()
                                return (cdata.get("data") or {}).get("total") or 0
                    except Exception as e:
                        # Best effort: report the failure and fall back to 0.
                        print(f"[get_dagent_tree] chunk count failed for file {file_id}: {e!r}")
                return 0

            chunk_counts = await asyncio.gather(
                *[fetch_chunk_count(f.get("id")) for f in all_files]
            )

        # 3. Group files under their major / minor chapter.
        tree = {}
        for f, chunk_count in zip(all_files, chunk_counts):
            # `or ""` guards against explicit nulls, which would crash split()/lower().
            full_path = f.get("file_name") or ""
            parts = full_path.split("/")
            if len(parts) >= 2:
                major = parts[0]
                # Directory path of the file; equals `major` for files one level deep.
                minor = "/".join(parts[:-1])
                leaf_name = parts[-1]
            else:
                major = "默认章节"
                minor = "默认章节"
                leaf_name = full_path
            major_node = tree.setdefault(major, {
                "key": f"major:{major}",
                "title": major,
                "type": "major_chapter",
                "children": {},
            })
            minor_node = major_node["children"].setdefault(minor, {
                "key": f"minor:{minor}",
                "title": minor.split("/")[-1],
                "full_path": minor,
                "type": "minor_chapter",
                "children": [],
            })
            minor_node["children"].append({
                "key": f"file:{f.get('id')}",
                "title": leaf_name,
                "type": "file",
                "file_id": f.get("id"),
                "file_type": f.get("file_type", ""),
                "status": (f.get("file_clean_status") or "").lower(),
                "chunk_count": chunk_count,
            })

        # 4. Convert the nested dicts into title-sorted lists.
        result = []
        for major_node in tree.values():
            major_children = []
            for minor_node in major_node["children"].values():
                minor_children = sorted(minor_node["children"], key=lambda x: x["title"])
                major_children.append({
                    **{k: v for k, v in minor_node.items() if k != "children"},
                    "children": minor_children,
                })
            result.append({
                "key": major_node["key"],
                "title": major_node["title"],
                "type": "major_chapter",
                "children": sorted(major_children, key=lambda x: x["title"]),
            })
        return {"status": 0, "data": sorted(result, key=lambda x: x["title"])}
    except Exception as e:
        import traceback
        print(f"[get_dagent_tree] Error: {e}")
        print(traceback.format_exc())
        return {"status": 1, "message": str(e), "data": []}
# Strong references to running background tasks. The event loop only keeps weak
# references to tasks, so a task whose handle is discarded may be garbage-collected
# before it finishes.
_BACKGROUND_TASKS: set = set()


@router.post("/task/from-dagent")
async def create_task_from_dagent(
    org_id: str = Form(...),
    env_url: str = Form(""),
    name: str = Form(""),
    judge_config_id: str = Form(...),
    file_ids: str = Form(""),
    questions_per_section: int = Form(5),
    quality_threshold: float = Form(0.6),
    include_multimodal: bool = Form(True),
):
    """Create a QA-generation task from Dagent data and start it in the background.

    Inserts a ``pending`` row into ``qa_gen_task``, schedules ``_run_dagent_task`` on the
    running event loop and returns immediately.

    Args:
        org_id: Dagent organisation id.
        env_url: Base URL of the Dagent environment (empty string = default).
        name: Task name; a name derived from ``org_id`` is used when empty.
        judge_config_id: Id of the ``judge_config`` row holding the LLM settings.
        file_ids: Comma-separated file ids; empty means all files.
        questions_per_section: Number of questions requested per chunk.
        quality_threshold: Score at or above which a question is stored as approved.
        include_multimodal: Whether image semantics are used during generation.

    Returns:
        ``{"status": 0, "data": {"id": <task id>}}``.
    """
    task_id = _id()
    file_id_list = [f.strip() for f in file_ids.split(",") if f.strip()]
    async with get_db() as db:
        await db.execute(
            """INSERT INTO qa_gen_task
(id,name,judge_config_id,questions_per_section,quality_threshold,status,created_at)
VALUES (?,?,?,?,?,?,?)""",
            (task_id, name or f"Dagent导入({org_id[:8]}...)",
             judge_config_id, questions_per_section, quality_threshold, "pending", _now()),
        )
        await db.commit()
    background = asyncio.create_task(_run_dagent_task(
        task_id, org_id, file_id_list, judge_config_id,
        questions_per_section, quality_threshold, include_multimodal,
        env_url=env_url,
    ))
    # Keep the task alive until it completes, then drop the reference.
    _BACKGROUND_TASKS.add(background)
    background.add_done_callback(_BACKGROUND_TASKS.discard)
    return {"status": 0, "data": {"id": task_id}}
# ── 内部:后台任务 ─────────────────────────────────────────────────────────────
def _dedupe_paragraphs_by_chunk_id(paragraphs: list[dict]) -> list[dict]:
"""按 chunk id 去重,保留首次出现顺序(避免 API 重复页导致重复生成)。"""
seen: set[str] = set()
out: list[dict] = []
dup = 0
for p in paragraphs:
cid = (p.get("id") or "").strip()
if cid:
if cid in seen:
dup += 1
continue
seen.add(cid)
out.append(p)
if dup:
print(f"[_dedupe_paragraphs_by_chunk_id] removed {dup} duplicate chunk rows")
return out
def _merge_paragraphs_by_chunk_id(primary: list[dict], extra: list[dict]) -> list[dict]:
"""把 extra 中尚未出现在 primary 的 chunk 并入(按 id"""
seen = {(p.get("id") or "").strip() for p in primary if (p.get("id") or "").strip()}
merged = list(primary)
for p in extra:
cid = (p.get("id") or "").strip()
if cid and cid in seen:
continue
if cid:
seen.add(cid)
merged.append(p)
return merged
async def _fetch_paragraphs(org_id: str, file_id_list: list[str], env_url: str = "") -> list[dict]:
    """Fetch paragraph (chunk) data from the Dagent HTTP API.

    For every file the function first probes the chunk total with ``page_size=1`` and
    then pages through ``/dagent/knowledge/chunk/page``, retrying on errors and when
    the page cap is reached. Chunks whose ``file_id`` differs from the requested file
    are discarded. The combined result is de-duplicated by chunk id.

    Args:
        org_id: Organisation id; sent as the ``org-id`` header and in request bodies.
        file_id_list: Ids of the files to process; when empty, all files are processed.
        env_url: Base URL of the Dagent environment; the production URL is used when empty.

    Returns:
        A list of paragraph dicts (``id``, ``file_id``, ``file_name``, ``headers``,
        ``paragraph_context``, ``paragraph_img_num``, ``paragraph_pic_semantics_context``,
        ``paragraph_question``, ``paragraph_summary``, ``paragraph_keywords``).
        Returns an empty list if an unexpected exception escapes the fetch loop.
    """
    import aiohttp
    base_url = (env_url or "https://dagent.d-robotics.cc").rstrip("/")
    headers = {
        "Content-Type": "application/json",
        "org-id": org_id,
        "d-user-id": "test",
    }
    all_paragraphs = []
    # Upper bound on the chunk count of a single file (guards against the API ignoring
    # file_id and returning the chunks of the whole library); a value that is too small
    # would cause entire files to be skipped.
    MAX_CHUNKS_PER_FILE = 50000
    MAX_RETRIES = 5  # retry a few times when the page cap is hit or the network is flaky
    PAGE_SIZE = 100
    try:
        async with aiohttp.ClientSession(headers=headers) as session:
            # Determine the list of files to process.
            files_to_process = []
            if file_id_list:
                print(f"[_fetch_paragraphs] Processing {len(file_id_list)} user-selected files")
                files_to_process = [{"id": fid, "file_name": ""} for fid in file_id_list]
            else:
                print(f"[_fetch_paragraphs] Fetching file list...")
                page = 1
                all_files = []
                while True:
                    async with session.post(
                        f"{base_url}/dagent/knowledge/file/page",
                        json={"current": page, "page_size": 100, "org_id": org_id},
                        timeout=aiohttp.ClientTimeout(total=15),
                    ) as resp:
                        if resp.status != 200:
                            break
                        data = await resp.json()
                        files = data.get("data", {}).get("list", [])
                        if not files:
                            break
                        all_files.extend(files)
                        if len(files) < 100:
                            break
                        page += 1
                print(f"[_fetch_paragraphs] Total files available: {len(all_files)}, will process all")
                files_to_process = all_files
            # Fetch the chunks of every file.
            total_files = len(files_to_process)
            for idx, f in enumerate(files_to_process):
                file_id = f.get("id") if isinstance(f, dict) else f
                file_name = f.get("file_name", "") if isinstance(f, dict) else ""
                if idx % 10 == 0:
                    print(f"[_fetch_paragraphs] Processing file {idx+1}/{total_files}: {file_id[:20]}...")
                # Probe the file's total with page_size=1 first, to verify that the API
                # really filters by file_id.
                expected_total = None
                for attempt in range(MAX_RETRIES):
                    try:
                        async with session.post(
                            f"{base_url}/dagent/knowledge/chunk/page",
                            json={"file_id": file_id, "org_id": org_id, "page": 1, "page_size": 1},
                            timeout=aiohttp.ClientTimeout(total=15),
                        ) as resp:
                            if resp.status != 200:
                                print(f"[_fetch_paragraphs] Probe failed for {file_id[:20]}: HTTP {resp.status}")
                                await asyncio.sleep(2 ** attempt)
                                continue
                            probe_data = await resp.json()
                            expected_total = probe_data.get("data", {}).get("total", 0)
                    except Exception as e:
                        print(f"[_fetch_paragraphs] Probe error for {file_id[:20]}: {e}")
                        await asyncio.sleep(2 ** attempt)
                        continue
                    if expected_total is not None and expected_total <= MAX_CHUNKS_PER_FILE:
                        break
                    elif expected_total is not None and expected_total > MAX_CHUNKS_PER_FILE:
                        print(f"[_fetch_paragraphs] WARNING: file {file_id[:20]} returned total={expected_total}, "
                              f"likely API bug (file_id ignored). Retrying ({attempt+1}/{MAX_RETRIES})...")
                        # Reset so that an exhausted retry loop ends with "no valid total".
                        expected_total = None
                        await asyncio.sleep(3 * (attempt + 1))
                if expected_total is None or expected_total > MAX_CHUNKS_PER_FILE:
                    print(f"[_fetch_paragraphs] SKIPPING file {file_id[:20]} ({file_name}): "
                          f"total={expected_total} exceeds limit after {MAX_RETRIES} retries")
                    continue
                if expected_total == 0:
                    continue
                # Actual paginated fetch: do NOT stop early on "collected >= API total" --
                # the reported total is often lower than the real chunk count, which would
                # drop roughly one to several pages.
                # max_pages is given generous slack; the fetch only counts as naturally
                # finished when the last page is shorter than PAGE_SIZE or the list is empty.
                for fetch_attempt in range(MAX_RETRIES):
                    # Each retry allows more pages than the previous attempt.
                    slack = 80 + fetch_attempt * 60
                    max_pages = min(
                        2000,
                        max(50, (expected_total + PAGE_SIZE - 1) // PAGE_SIZE + slack),
                    )
                    page = 1
                    file_chunks = []
                    fetch_ok = True
                    foreign_count = 0
                    ended_normally = False  # empty page, or last page shorter than PAGE_SIZE
                    while page <= max_pages:
                        try:
                            async with session.post(
                                f"{base_url}/dagent/knowledge/chunk/page",
                                json={
                                    "file_id": file_id,
                                    "org_id": org_id,
                                    "page": page,
                                    "page_size": PAGE_SIZE,
                                },
                                timeout=aiohttp.ClientTimeout(total=30),
                            ) as resp:
                                if resp.status != 200:
                                    fetch_ok = False
                                    break
                                data = await resp.json()
                                chunks = data.get("data", {}).get("list", [])
                                if not chunks:
                                    ended_normally = True
                                    break
                                page_foreign = 0
                                for c in chunks:
                                    # Skip chunks that belong to a different file.
                                    chunk_fid = c.get("file_id", "")
                                    if chunk_fid and chunk_fid != file_id:
                                        foreign_count += 1
                                        page_foreign += 1
                                        continue
                                    # large_paragraph_llm_summary: the summary the backend produces
                                    # when compressing large paragraphs; it often exists instead of
                                    # paragraph_context. If it is not mapped, many chunks end up with
                                    # "no body text" and the generation stage always yields 0 questions.
                                    _ctx = (
                                        c.get("active_paragraph_context")
                                        or c.get("paragraph_context")
                                        or c.get("active_context")
                                        or ""
                                    )
                                    _llm_sum = (c.get("large_paragraph_llm_summary") or "").strip()
                                    _para_sum = (c.get("paragraph_summary") or "").strip()
                                    file_chunks.append({
                                        "id": c.get("id"),
                                        "file_id": file_id,
                                        "file_name": file_name or c.get("file_name", ""),
                                        "headers": c.get("headers", ""),
                                        "paragraph_context": _ctx or _llm_sum,
                                        "paragraph_img_num": c.get("paragraph_img_num", 0),
                                        "paragraph_pic_semantics_context": c.get("paragraph_pic_semantics_context", ""),
                                        "paragraph_question": c.get("paragraph_question", ""),
                                        "paragraph_summary": _para_sum or _llm_sum,
                                        "paragraph_keywords": c.get("paragraph_keywords", ""),
                                    })
                                if len(chunks) < PAGE_SIZE:
                                    ended_normally = True
                                    break
                                if page_foreign > len(chunks) * 0.5:
                                    print(
                                        f"[_fetch_paragraphs] Page {page}: high foreign ratio "
                                        f"{page_foreign}/{len(chunks)} for file {file_id[:20]}, continuing"
                                    )
                                page += 1
                        except Exception as e:
                            print(f"[_fetch_paragraphs] Error fetching chunks for file {file_id[:20]}: {e}")
                            fetch_ok = False
                            break
                    if foreign_count > 0:
                        print(
                            f"[_fetch_paragraphs] File {file_id[:20]}: filtered {foreign_count} foreign, "
                            f"kept {len(file_chunks)}"
                        )
                    if not fetch_ok:
                        if fetch_attempt < MAX_RETRIES - 1:
                            print(
                                f"[_fetch_paragraphs] File {file_id[:20]}: fetch error, "
                                f"retry ({fetch_attempt + 1}/{MAX_RETRIES})..."
                            )
                            file_chunks = []
                            await asyncio.sleep(3 * (fetch_attempt + 1))
                            continue
                        # Last attempt failed: keep whatever was collected so far.
                        break
                    if ended_normally:
                        if expected_total and len(file_chunks) < expected_total:
                            print(
                                f"[_fetch_paragraphs] File {file_id[:20]}: EOF kept={len(file_chunks)} "
                                f"vs API total={expected_total} (often foreign rows in total)"
                            )
                        break
                    # Did not end naturally: most likely max_pages was reached while the last
                    # page was still full, so retry with a larger page allowance.
                    if fetch_attempt < MAX_RETRIES - 1:
                        print(
                            f"[_fetch_paragraphs] File {file_id[:20]}: page cap hit "
                            f"(last_page={page - 1}, max_pages={max_pages}, kept={len(file_chunks)}), "
                            f"retry ({fetch_attempt + 1}/{MAX_RETRIES})..."
                        )
                        file_chunks = []
                        await asyncio.sleep(3 * (fetch_attempt + 1))
                        continue
                    print(
                        f"[_fetch_paragraphs] WARNING: file {file_id[:20]} still not EOF after "
                        f"{MAX_RETRIES} attempts; accepting {len(file_chunks)} chunks"
                    )
                    break
                if file_chunks:
                    all_paragraphs.extend(file_chunks)
            all_paragraphs = _dedupe_paragraphs_by_chunk_id(all_paragraphs)
            print(f"[_fetch_paragraphs] Total paragraphs fetched: {len(all_paragraphs)} from {total_files} files")
            return all_paragraphs
    except Exception as e:
        import traceback
        print(f"[_fetch_paragraphs] Error: {e}")
        print(f"[_fetch_paragraphs] Traceback: {traceback.format_exc()}")
        return []
def _extract_json_array(text: str) -> Optional[list]:
"""容错解析 LLM 返回的 JSON 数组。
策略依次尝试:
1) 直接 json.loads 整个响应
2) 抠出 ```...``` 或 ```json ... ``` 代码块再 loads
3) 以第一个 `[` 为起点,按括号配平找到对应 `]`(跳过字符串内的括号)
4) 若因截断未闭合,尝试在最后一个完整对象 `}` 处强制补 `]` 再 loads
任一成功即返回 list全部失败返回 None。
"""
if not text:
return None
stripped = text.strip()
# 1) 整体 loads
try:
data = json.loads(stripped)
if isinstance(data, list):
return data
except Exception:
pass
# 2) 代码块
block = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.DOTALL | re.IGNORECASE)
if block:
try:
data = json.loads(block.group(1).strip())
if isinstance(data, list):
return data
except Exception:
pass
# 3) 括号配平(跳过字符串内的括号)
start = stripped.find("[")
if start == -1:
return None
depth = 0
in_str = False
escape = False
end = -1
for i in range(start, len(stripped)):
ch = stripped[i]
if escape:
escape = False
continue
if ch == "\\":
escape = True
continue
if ch == '"':
in_str = not in_str
continue
if in_str:
continue
if ch == "[":
depth += 1
elif ch == "]":
depth -= 1
if depth == 0:
end = i
break
if end != -1:
candidate = stripped[start:end + 1]
try:
data = json.loads(candidate)
if isinstance(data, list):
return data
except Exception:
pass
# 4) 截断恢复:找最后一个完整对象的 `}`,强制补 `]`
tail_brace = stripped.rfind("}")
if tail_brace > start:
candidate = stripped[start:tail_brace + 1] + "]"
try:
data = json.loads(candidate)
if isinstance(data, list):
return data
except Exception:
pass
return None
def _parse_quality_score(raw) -> float:
"""模型自评分数:缺省/非法时用 0.8,避免 float(None) 整段失败;限制在 [0,1]。"""
if raw is None:
return 0.8
try:
v = float(raw)
except (TypeError, ValueError):
return 0.8
return max(0.0, min(1.0, v))
async def _call_llm_once(
session, base_url: str, payload: dict, timeout_s: int
) -> tuple[Optional[str], Optional[str]]:
"""单次调用 LLM返回 (content, error_str)。"""
try:
async with session.post(
f"{base_url}/chat/completions",
json=payload,
timeout=aiohttp.ClientTimeout(total=timeout_s),
) as resp:
if resp.status != 200:
body = (await resp.text())[:500]
return None, f"HTTP {resp.status}: {body}"
data = await resp.json()
content = data["choices"][0]["message"]["content"].strip()
return content, None
except Exception as e:
return None, str(e)
async def _generate_questions_for_paragraph(
para: dict, cfg: dict, n: int, include_multimodal: bool,
existing_questions: list[str] = None, # 已有的问题列表,用于避免重复
) -> list[dict]:
"""为单个段落生成问答,支持传入已有问题避免重复。
改进:
- 引入容错 JSON 抽取_extract_json_array避免贪婪正则漏解析/截断直接丢题。
- 增加重试与自适应降 n单次失败后指数退避重试若怀疑是 max_tokens 截断,下一次把 n 折半。
- 复用 ClientSession在此函数内同次生成共享跨调用暂未共享以保持接口稳定
- 放宽「已有历史问题」的硬约束,改为软参考,避免第 3/4 轮模型大量拒答。
"""
base_url = cfg.get("base_url", "").rstrip("/")
api_key = cfg.get("api_key", "")
model = cfg.get("model", "gpt-4o-mini")
context_plain = (para.get("paragraph_context") or "").strip()
pic_semantics = (para.get("paragraph_pic_semantics_context") or "").strip()
seed_question = (para.get("paragraph_question") or "").strip()
headers = (para.get("headers") or "").strip()
summary = (para.get("paragraph_summary") or "").strip()
keywords = (para.get("paragraph_keywords") or "").strip()
has_image = bool(pic_semantics and para.get("paragraph_img_num", 0) > 0)
text = context_plain
if not text:
text = summary
if not text and seed_question:
text = seed_question
if not text and has_image and include_multimodal and pic_semantics:
text = pic_semantics[:2500]
if not text and keywords:
text = f"关键词:\n{keywords[:1500]}"
if not text and headers:
text = (
f"(该切片缺少正文/摘要,仅章节路径如下;请基于路径生成 {n} 个简短、可检索的技术问题,"
f"答案可写「需结合全文」类占位但问题须具体)\n{headers}"
)
if not text:
return []
# 构建 prompt正文为空但用图片语义作主内容时不再重复插入图片块
pic_section = ""
if has_image and include_multimodal and pic_semantics and context_plain:
pic_section = f"""
**图片语义描述(图片已由 AI 识别):**
{pic_semantics[:800]}
"""
seed_section = ""
if seed_question:
seed_section = f"\n**已有种子问题(请避免重复,可从不同角度扩展):** {seed_question}"
# 已有问题列表(来自循环任务的历史问题)
# 放宽为「风格参考」——模型在历史 10+ 条强约束下常直接返回空数组,导致循环第 3/4 轮整轮 0 题。
existing_section = ""
if existing_questions:
sample_existing = existing_questions[:5]
existing_section = (
"\n**该段落的历史问题(供参考,尽量换角度/换措辞,但不必完全不同):**\n"
)
for i, eq in enumerate(sample_existing, 1):
existing_section += f"{i}. {eq}\n"
def _build_prompt(ask_n: int) -> str:
return f"""你是一个技术文档问答生成专家。基于以下内容生成 {ask_n} 个测试问题。
**章节路径:** {headers}
**文本内容:**
{text[:2500]}
{pic_section}{seed_section}{existing_section}
**要求:**
1. 问题必须能从该章节内容直接回答
2. 覆盖关键知识点,避免过于简单的是非题
3. 如果有图片语义描述,至少生成 1 个图文结合的问题(问题中提及"如图所示""图中"等)
4. 答案准确长度适中1-3 句话)
5. source_chunk 为答案来源的原文片段50-150 字)
6. has_image 标记该问题是否依赖图像信息
7. quality_score 为质量评估0-1
**输出格式:严格只输出一个合法 JSON 数组,不要额外解释、不要代码块标记。**
[
{{
"question": "问题文本",
"answer": "参考答案",
"source_chunk": "答案来源原文片段",
"has_image": false,
"quality_score": 0.9
}}
]"""
headers_http = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
MAX_ATTEMPTS = 3
TIMEOUT_S = 120
cur_n = n
async with aiohttp.ClientSession(headers=headers_http) as session:
for attempt in range(MAX_ATTEMPTS):
prompt = _build_prompt(cur_n)
payload = {
"model": model,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.3,
}
content, err = await _call_llm_once(session, base_url, payload, TIMEOUT_S)
if err:
is_rate_limit = "429" in err or "Rate limit" in err or "rate limit" in err
is_budget = "400" in err and ("Budget" in err or "budget" in err or "budget_exceeded" in err)
_log(
f"[_generate_questions] Attempt {attempt + 1}/{MAX_ATTEMPTS} failed "
f"for headers={headers[:50]}: {err[:200]}"
)
# 限流或预算超限:长退避;其他错误:短退避
if attempt < MAX_ATTEMPTS - 1:
if is_budget:
wait_s = 300 + 300 * attempt # 预算超限5min, 10min, 15min
_log(f"[_generate_questions] Budget exceeded, backing off {wait_s}s (wait for reset)")
elif is_rate_limit:
wait_s = 30 + 15 * attempt # 429: 30s, 45s, 60s
_log(f"[_generate_questions] Rate limit detected, backing off {wait_s}s")
else:
wait_s = 2 + 2 * attempt # 其他: 2s, 4s, 6s
await asyncio.sleep(wait_s)
continue
questions = _extract_json_array(content)
if not questions:
# 看起来像被截断:响应末尾既无 `]` 又无 `}`,下一轮降 n
looks_truncated = not content.rstrip().endswith(("]", "}"))
_log(
f"[_generate_questions] Attempt {attempt + 1}/{MAX_ATTEMPTS}: "
f"JSON parse failed for headers={headers[:50]} "
f"(len={len(content)}, truncated={looks_truncated})"
)
_log(f"[_generate_questions] Raw response preview: {content[:300]}...{content[-300:]}")
if looks_truncated and cur_n > 1:
cur_n = max(1, cur_n // 2)
if attempt < MAX_ATTEMPTS - 1:
await asyncio.sleep(1.0 * (attempt + 1))
continue
result = []
for q in questions:
if isinstance(q, dict) and q.get("question") and q.get("answer"):
result.append({
"question": str(q["question"]).strip(),
"answer": str(q["answer"]).strip(),
"source_chunk": str(q.get("source_chunk", "")).strip(),
"has_image": bool(q.get("has_image", False)),
"quality_score": _parse_quality_score(q.get("quality_score")),
"source_image_desc": pic_semantics[:300] if q.get("has_image") else "",
})
if result:
_log(
f"[_generate_questions] Generated {len(result)} questions for "
f"headers={headers[:50]} (attempt {attempt + 1}, asked={cur_n})"
)
return result
_log(
f"[_generate_questions] Attempt {attempt + 1}/{MAX_ATTEMPTS}: "
f"JSON parsed but 0 valid items for headers={headers[:50]} "
f"(raw count={len(questions)})"
)
if attempt < MAX_ATTEMPTS - 1:
await asyncio.sleep(1.0 * (attempt + 1))
_log(f"[_generate_questions] All {MAX_ATTEMPTS} attempts exhausted for headers={headers[:50]}")
return []
async def _run_dagent_task(
    task_id: str,
    org_id: str,
    file_id_list: list[str],
    judge_config_id: str,
    questions_per_section: int,
    quality_threshold: float,
    include_multimodal: bool,
    section_existing_questions: dict[str, list[str]] = None,  # {section_path: [question1, question2, ...]}
    stop_check: callable = None,  # Optional stop check function
    pause_check: callable = None,  # Optional async pause check function
    env_url: str = "",  # Dagent environment URL
    expected_chunk_count: Optional[int] = None,  # planned chunk total of the batch; checked against the fetch result
):
    """Run a Dagent QA-generation task end to end.

    Steps: mark the task ``running``; fetch (and, if needed, re-fetch and merge)
    paragraphs; load the LLM configuration; generate questions concurrently; write
    them to ``qa_gen_question`` in batches while updating ``progress``; finally mark
    the task ``done``, ``stopped`` or ``failed``.

    Args:
        task_id: Id of the ``qa_gen_task`` row to update.
        org_id: Dagent organisation id.
        file_id_list: File ids to process; empty means all files.
        judge_config_id: Id of the ``judge_config`` row holding the LLM settings.
        questions_per_section: Target number of questions per paragraph.
        quality_threshold: Score at or above which a question is stored as ``approved``
            (otherwise ``pending``).
        include_multimodal: Whether image semantics may be used during generation.
        section_existing_questions: Existing questions per section path, passed to the
            generator to reduce repetition.
        stop_check: Optional callable; the task stops when it returns a truthy value.
        pause_check: Optional async callable; a truthy result also stops the task.
        env_url: Dagent environment URL.
        expected_chunk_count: When set, the de-duplicated fetch must reach at least this
            many chunks (after up to three merge re-fetches), otherwise the task fails.
    """
    import traceback
    section_existing_questions = section_existing_questions or {}
    print(f"[_run_dagent_task] Starting task {task_id}, org_id={org_id}, file_id_list={len(file_id_list)} files, env_url={env_url}")
    try:
        # 1. Mark the task as running first so the user can see that it has started.
        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET status='running', total=0, progress=0 WHERE id=?",
                (task_id,),
            )
            await db.commit()
        # 2. Fetch paragraphs (may fetch several times and merge until expected_chunk_count is met).
        print(f"[_run_dagent_task] Fetching paragraphs...")
        paragraphs = await _fetch_paragraphs(org_id, file_id_list, env_url)
        paragraphs = _dedupe_paragraphs_by_chunk_id(paragraphs)
        if expected_chunk_count and len(paragraphs) < expected_chunk_count:
            for refetch_i in range(3):
                short_by = expected_chunk_count - len(paragraphs)
                print(
                    f"[_run_dagent_task] Chunk count {len(paragraphs)} < expected {expected_chunk_count} "
                    f"(short {short_by}), refetch merge attempt {refetch_i + 1}/3"
                )
                more = await _fetch_paragraphs(org_id, file_id_list, env_url)
                paragraphs = _merge_paragraphs_by_chunk_id(paragraphs, more)
                if len(paragraphs) >= expected_chunk_count:
                    break
                await asyncio.sleep(5 * (refetch_i + 1))
        if expected_chunk_count and len(paragraphs) < expected_chunk_count:
            raise RuntimeError(
                f"拉取切片 {len(paragraphs)} 条,少于批次期望 {expected_chunk_count} 条;"
                f"请检查 Dagent chunk/page API、file_ids 是否与 chunk_batches_plan 一致。"
            )
        total = len(paragraphs)
        print(
            f"[_run_dagent_task] Fetched {total} paragraphs"
            + (f" (expected_chunk_count={expected_chunk_count})" if expected_chunk_count else "")
        )
        if expected_chunk_count and total > expected_chunk_count + 5:
            print(
                f"[_run_dagent_task] WARN: fetched {total} > expected {expected_chunk_count} "
                "(plan/API 漂移,仍按已拉取切片全部生成)"
            )
        if total == 0:
            print(f"[_run_dagent_task] No paragraphs found, marking as done")
            async with get_db() as db:
                await db.execute(
                    "UPDATE qa_gen_task SET status='done', finished_at=?, total=0 WHERE id=?",
                    (_now(), task_id),
                )
                await db.commit()
            return
        # Record the total number of paragraphs.
        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET total=? WHERE id=?",
                (total, task_id),
            )
            await db.commit()
        # 3. Load the LLM configuration.
        _log(f"[_run_dagent_task] Getting LLM config for judge_config_id={judge_config_id}")
        async with get_db() as db:
            cfg_rows = await db.execute_fetchall(
                "SELECT * FROM judge_config WHERE id=?", (judge_config_id,)
            )
        if not cfg_rows:
            raise ValueError("judge_config not found")
        cfg = dict(cfg_rows[0])
        _log(f"[_run_dagent_task] LLM config: {cfg.get('model')}")
        # 4. Generate concurrently (concurrency limited to 3 to avoid rate limiting).
        sem = asyncio.Semaphore(3)
        buf_lock = asyncio.Lock()  # protects write_buf and the counters below
        done = 0
        FLUSH_SIZE = 50
        write_buf = []
        stopped = False
        _log(f"[_run_dagent_task] Starting generation: {total} paragraphs, concurrency=3, flush_size=50")
        total_questions_written = 0
        paragraphs_with_zero_questions = 0

        async def flush_question_buf(buf: list):
            """Write buffered questions to the DB and store the current progress.

            Progress is written even when ``buf`` is empty; otherwise a task that
            produces no questions would keep ``progress`` at 0 and look stuck.
            """
            async with get_db() as db2:
                for p, q in buf:
                    qid = _id()
                    status = "approved" if q["quality_score"] >= quality_threshold else "pending"
                    await db2.execute(
                        """INSERT INTO qa_gen_question
(id,task_id,section_path,question,reference_answer,source_chunk,
quality_score,status,created_at,file_id,file_name,chunk_id,chunk_headers,chunk_content_preview)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
                        (qid, task_id, p["headers"],
                         q["question"], q["answer"], q["source_chunk"],
                         q["quality_score"], status, _now(),
                         p.get("file_id", ""), p.get("file_name", ""),
                         p.get("id", ""), p.get("headers", ""), p.get("paragraph_context", "")[:500]),
                    )
                if buf:
                    from .qa_gen import _sync_approved_count
                    await _sync_approved_count(db2, task_id)
                await db2.execute(
                    "UPDATE qa_gen_task SET progress=? WHERE id=?",
                    (done, task_id),
                )
                await db2.commit()

        async def process_one(para: dict):
            """Generate questions for one paragraph and queue them for writing."""
            nonlocal done, stopped, total_questions_written, paragraphs_with_zero_questions
            # Check stop condition before processing
            if stop_check and stop_check():
                stopped = True
                return
            # Check pause condition before processing
            if pause_check and await pause_check():
                stopped = True
                return
            async with sem:
                # Check stop condition again before LLM call
                if stop_check and stop_check():
                    stopped = True
                    return
                # Check pause condition again before LLM call
                if pause_check and await pause_check():
                    stopped = True
                    return
                # Existing questions of this paragraph's section.
                headers = para.get("headers", "")
                existing = section_existing_questions.get(headers, [])
                questions: list = []
                merged_existing = list(existing)
                max_fill_rounds = 4
                consecutive_empty_rounds = 0  # number of consecutive rounds without questions
                max_consecutive_empty = 2  # give up after this many empty rounds in a row
                for fill_round in range(max_fill_rounds):
                    need = questions_per_section - len(questions)
                    if need <= 0:
                        break
                    batch = await _generate_questions_for_paragraph(
                        para, cfg, need, include_multimodal,
                        existing_questions=merged_existing,
                    )
                    if not batch:
                        consecutive_empty_rounds += 1
                        if consecutive_empty_rounds >= max_consecutive_empty:
                            # Only stop after several empty rounds in a row.
                            break
                        # A single empty round: try the next round.
                        continue
                    # Reset the empty-round counter.
                    consecutive_empty_rounds = 0
                    questions.extend(batch)
                    merged_existing.extend(q["question"] for q in batch)
                async with buf_lock:
                    done += 1
                    total_questions_written += len(questions)
                    if not questions:
                        paragraphs_with_zero_questions += 1
                    write_buf.extend([(para, q) for q in questions])
                    # Print progress every 100 paragraphs and at the end.
                    if done % 100 == 0 or done == total:
                        print(
                            f"[_run_dagent_task] Progress: {done}/{total} ({done*100//total}%) "
                            f"questions={total_questions_written} zero_chunks={paragraphs_with_zero_questions}"
                        )
                    # Flush once FLUSH_SIZE questions are buffered; also flush periodically
                    # (every 100 paragraphs) so progress is stored even without questions.
                    need_flush = (
                        len(write_buf) >= FLUSH_SIZE
                        or done == total
                        or (done % 100 == 0)
                    )
                    if need_flush:
                        batch = write_buf.copy()
                        write_buf.clear()
                        await flush_question_buf(batch)

        await asyncio.gather(*[process_one(p) for p in paragraphs])
        # Always flush before finishing or stopping; otherwise questions still in the
        # buffer would be lost.
        async with buf_lock:
            if write_buf:
                await flush_question_buf(write_buf)
                write_buf.clear()
        print(
            f"[_run_dagent_task] First pass only (no second pass): paragraphs={total}, "
            f"questions_inserted={total_questions_written}, "
            f"paragraphs_with_zero_questions={paragraphs_with_zero_questions}"
        )
        # Check if stopped early
        if stopped:
            async with get_db() as db:
                await db.execute(
                    "UPDATE qa_gen_task SET status='stopped', finished_at=? WHERE id=?",
                    (_now(), task_id),
                )
                await db.commit()
            return
        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET status='done', finished_at=? WHERE id=?",
                (_now(), task_id),
            )
            await db.commit()
    except Exception as exc:
        # Log the full traceback: the DB row only stores str(exc), which is not enough
        # to locate the failure in a background task.
        print(f"[_run_dagent_task] Task {task_id} failed: {exc}")
        print(traceback.format_exc())
        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET status='failed', error_message=? WHERE id=?",
                (str(exc), task_id),
            )
            await db.commit()