""" 从 Dagent 数据库导入知识库数据,生成多模态问答集 """ import asyncio import json import re import sys from pathlib import Path from fastapi import APIRouter, Form, HTTPException from typing import Optional import aiohttp import aiomysql import logging import os from datetime import datetime # 设置文件日志(必须在 Path 导入后) LOG_PATH = Path(__file__).parent.parent / "logs" LOG_PATH.mkdir(exist_ok=True) _logger = logging.getLogger("qa_gen_dagent") _logger.setLevel(logging.DEBUG) if not _logger.handlers: _fh = logging.FileHandler(LOG_PATH / "qa_gen_debug.log", encoding="utf-8") _fh.setLevel(logging.DEBUG) _fh.setFormatter(logging.Formatter("%(asctime)s - %(message)s")) _logger.addHandler(_fh) def _log(msg: str): """强制写入文件日志""" _logger.debug(msg) # Add parent directory to sys.path for relative imports sys.path.insert(0, str(Path(__file__).parent.parent)) from models.db import get_db, _now, _id router = APIRouter(prefix="/api/qa-gen", tags=["问题生成-Dagent"]) DAGENT_DB = { "host": "120.48.66.228", "port": 23306, "user": "dagent", "password": "Fd1.Ej3.fdIie48", "db": "dagent_platform", "charset": "utf8mb4", } async def get_dagent_conn(): """创建 Dagent 数据库连接""" return await aiomysql.connect(**DAGENT_DB) @router.get("/dagent/stats") async def get_dagent_stats(org_id: str, env_url: str = ""): """获取 Dagent 知识库统计信息(通过 HTTP API)""" import aiohttp # 使用默认生产环境 URL base_url = (env_url or "https://dagent.d-robotics.cc").rstrip("/") headers = { "Content-Type": "application/json", "org-id": org_id, "d-user-id": "test", } try: async with aiohttp.ClientSession(headers=headers) as session: # 获取文件列表 page = 1 page_size = 100 total_files = 0 total_paragraphs = 0 while True: async with session.post( f"{base_url}/dagent/knowledge/file/page", json={"current": page, "page_size": page_size, "org_id": org_id}, timeout=aiohttp.ClientTimeout(total=15), ) as resp: if resp.status != 200: break data = await resp.json() files = data.get("data", {}).get("list", []) if not files: break total_files += len(files) # 获取每个文件的切片数 for f in files: try: async with session.post( f"{base_url}/dagent/knowledge/chunk/page", json={"file_id": f["id"], "org_id": org_id, "page": 1, "page_size": 1}, timeout=aiohttp.ClientTimeout(total=10), ) as cr: if cr.status == 200: cd = await cr.json() total_paragraphs += cd.get("data", {}).get("total", 0) except Exception: pass if len(files) < page_size: break page += 1 return {"status": 0, "data": { "file_count": total_files, "paragraph_count": total_paragraphs, "total_images": 0, "paragraphs_with_pic_text": 0, "paragraphs_with_question": 0, }} except Exception as e: print(f"[get_dagent_stats] Error: {e}") return {"status": 0, "data": {}} @router.get("/dagent/files") async def list_dagent_files(org_id: str, env_url: str = ""): """列出 Dagent 中某组织下已处理完成的文件(通过 HTTP API)""" import aiohttp base_url = (env_url or "https://dagent.d-robotics.cc").rstrip("/") headers = { "Content-Type": "application/json", "org-id": org_id, "d-user-id": "test", } all_files = [] try: async with aiohttp.ClientSession(headers=headers) as session: page = 1 page_size = 100 while True: async with session.post( f"{base_url}/dagent/knowledge/file/page", json={"current": page, "page_size": page_size, "org_id": org_id}, timeout=aiohttp.ClientTimeout(total=15), ) as resp: if resp.status != 200: break data = await resp.json() files = data.get("data", {}).get("list", []) if not files: break for f in files: all_files.append({ "id": f.get("id"), "file_name": f.get("file_name"), "file_type": f.get("file_type"), "file_clean_status": f.get("file_clean_status", "").lower(), "file_bytes": f.get("file_bytes", 0), "create_time": f.get("create_time"), }) if len(files) < page_size: break page += 1 return {"status": 0, "data": all_files} except Exception as e: print(f"[list_dagent_files] Error: {e}") return {"status": 0, "data": []} @router.get("/dagent/tree") async def get_dagent_tree(org_id: str, env_url: str = ""): """ 获取知识库的层级树形结构 结构:大章节 -> 小章节 -> 文件 """ import aiohttp import asyncio base_url = (env_url or "https://dagent.d-robotics.cc").rstrip("/") headers = { "Content-Type": "application/json", "org-id": org_id, "d-user-id": "test", } try: async with aiohttp.ClientSession(headers=headers) as session: page = 1 all_files = [] while True: async with session.post( f"{base_url}/dagent/knowledge/file/page", json={"current": page, "page_size": 100, "org_id": org_id}, timeout=aiohttp.ClientTimeout(total=15), ) as resp: if resp.status != 200: break data = await resp.json() files = data.get("data", {}).get("list", []) if not files: break for f in files: all_files.append(f) if len(files) < 100: break page += 1 # 并发获取每个文件的 chunk 总数(page_size=1 只拿 total) sem = asyncio.Semaphore(20) async def fetch_chunk_count(file_id: str) -> int: async with sem: try: async with session.post( f"{base_url}/dagent/knowledge/chunk/page", json={"file_id": file_id, "org_id": org_id, "page": 1, "page_size": 1}, timeout=aiohttp.ClientTimeout(total=10), ) as cr: if cr.status == 200: cdata = await cr.json() return cdata.get("data", {}).get("total", 0) except Exception: pass return 0 chunk_counts = await asyncio.gather( *[fetch_chunk_count(f.get("id")) for f in all_files] ) # 解析文件路径并构建列表 parsed_files = [] for i, f in enumerate(all_files): file_name = f.get("file_name", "") parts = file_name.split("/") if len(parts) >= 2: major_chapter = parts[0] minor_chapter = "/".join(parts[:-1]) if len(parts) > 2 else parts[0] file_name_only = parts[-1] else: major_chapter = "默认章节" minor_chapter = "默认章节" file_name_only = file_name parsed_files.append({ "id": f.get("id"), "file_name": file_name_only, "full_path": file_name, "file_type": f.get("file_type", ""), "file_clean_status": f.get("file_clean_status", "").lower(), "major_chapter": major_chapter, "minor_chapter": minor_chapter, "chunk_count": chunk_counts[i], }) # 构建树形结构 tree = {} for f in parsed_files: major = f["major_chapter"] minor = f["minor_chapter"] if major not in tree: tree[major] = { "key": f"major:{major}", "title": major, "type": "major_chapter", "children": {} } if minor not in tree[major]["children"]: tree[major]["children"][minor] = { "key": f"minor:{minor}", "title": minor.split("/")[-1] if "/" in minor else minor, "full_path": minor, "type": "minor_chapter", "children": [] } tree[major]["children"][minor]["children"].append({ "key": f"file:{f['id']}", "title": f["file_name"], "type": "file", "file_id": f["id"], "file_type": f["file_type"], "status": f["file_clean_status"], "chunk_count": f["chunk_count"], }) result = [] for major_name, major_node in tree.items(): major_children = [] for minor_name, minor_node in major_node["children"].items(): minor_children = sorted(minor_node["children"], key=lambda x: x["title"]) major_children.append({ **{k: v for k, v in minor_node.items() if k != "children"}, "children": minor_children }) result.append({ "key": major_node["key"], "title": major_node["title"], "type": "major_chapter", "children": sorted(major_children, key=lambda x: x["title"]) }) return {"status": 0, "data": sorted(result, key=lambda x: x["title"])} except Exception as e: import traceback print(f"[get_dagent_tree] Error: {e}") print(traceback.format_exc()) return {"status": 1, "message": str(e), "data": []} @router.post("/task/from-dagent") async def create_task_from_dagent( org_id: str = Form(...), env_url: str = Form(""), name: str = Form(""), judge_config_id: str = Form(...), file_ids: str = Form(""), questions_per_section: int = Form(5), quality_threshold: float = Form(0.6), include_multimodal: bool = Form(True), ): """从 Dagent 数据库创建问答生成任务""" task_id = _id() file_id_list = [f.strip() for f in file_ids.split(",") if f.strip()] async with get_db() as db: await db.execute( """INSERT INTO qa_gen_task (id,name,judge_config_id,questions_per_section,quality_threshold,status,created_at) VALUES (?,?,?,?,?,?,?)""", (task_id, name or f"Dagent导入({org_id[:8]}...)", judge_config_id, questions_per_section, quality_threshold, "pending", _now()), ) await db.commit() asyncio.create_task(_run_dagent_task( task_id, org_id, file_id_list, judge_config_id, questions_per_section, quality_threshold, include_multimodal, env_url=env_url, )) return {"status": 0, "data": {"id": task_id}} # ── 内部:后台任务 ───────────────────────────────────────────────────────────── def _dedupe_paragraphs_by_chunk_id(paragraphs: list[dict]) -> list[dict]: """按 chunk id 去重,保留首次出现顺序(避免 API 重复页导致重复生成)。""" seen: set[str] = set() out: list[dict] = [] dup = 0 for p in paragraphs: cid = (p.get("id") or "").strip() if cid: if cid in seen: dup += 1 continue seen.add(cid) out.append(p) if dup: print(f"[_dedupe_paragraphs_by_chunk_id] removed {dup} duplicate chunk rows") return out def _merge_paragraphs_by_chunk_id(primary: list[dict], extra: list[dict]) -> list[dict]: """把 extra 中尚未出现在 primary 的 chunk 并入(按 id)。""" seen = {(p.get("id") or "").strip() for p in primary if (p.get("id") or "").strip()} merged = list(primary) for p in extra: cid = (p.get("id") or "").strip() if cid and cid in seen: continue if cid: seen.add(cid) merged.append(p) return merged async def _fetch_paragraphs(org_id: str, file_id_list: list[str], env_url: str = "") -> list[dict]: """从 Dagent HTTP API 提取段落数据 Args: file_id_list: 指定要处理的文件ID列表,如果为空则处理所有文件 """ import aiohttp base_url = (env_url or "https://dagent.d-robotics.cc").rstrip("/") headers = { "Content-Type": "application/json", "org-id": org_id, "d-user-id": "test", } all_paragraphs = [] # 单个文件的切片数上限(防止 API 忽略 file_id 返回全库切片);过小会整文件跳过 MAX_CHUNKS_PER_FILE = 50000 MAX_RETRIES = 5 # 分页触顶 / 网络抖动时多试几次 PAGE_SIZE = 100 try: async with aiohttp.ClientSession(headers=headers) as session: # 确定要处理的文件列表 files_to_process = [] if file_id_list: print(f"[_fetch_paragraphs] Processing {len(file_id_list)} user-selected files") files_to_process = [{"id": fid, "file_name": ""} for fid in file_id_list] else: print(f"[_fetch_paragraphs] Fetching file list...") page = 1 all_files = [] while True: async with session.post( f"{base_url}/dagent/knowledge/file/page", json={"current": page, "page_size": 100, "org_id": org_id}, timeout=aiohttp.ClientTimeout(total=15), ) as resp: if resp.status != 200: break data = await resp.json() files = data.get("data", {}).get("list", []) if not files: break all_files.extend(files) if len(files) < 100: break page += 1 print(f"[_fetch_paragraphs] Total files available: {len(all_files)}, will process all") files_to_process = all_files # 获取每个文件的切片 total_files = len(files_to_process) for idx, f in enumerate(files_to_process): file_id = f.get("id") if isinstance(f, dict) else f file_name = f.get("file_name", "") if isinstance(f, dict) else "" if idx % 10 == 0: print(f"[_fetch_paragraphs] Processing file {idx+1}/{total_files}: {file_id[:20]}...") # 先用 page_size=1 探测该文件的 total,验证 API 是否正确过滤 expected_total = None for attempt in range(MAX_RETRIES): try: async with session.post( f"{base_url}/dagent/knowledge/chunk/page", json={"file_id": file_id, "org_id": org_id, "page": 1, "page_size": 1}, timeout=aiohttp.ClientTimeout(total=15), ) as resp: if resp.status != 200: print(f"[_fetch_paragraphs] Probe failed for {file_id[:20]}: HTTP {resp.status}") await asyncio.sleep(2 ** attempt) continue probe_data = await resp.json() expected_total = probe_data.get("data", {}).get("total", 0) except Exception as e: print(f"[_fetch_paragraphs] Probe error for {file_id[:20]}: {e}") await asyncio.sleep(2 ** attempt) continue if expected_total is not None and expected_total <= MAX_CHUNKS_PER_FILE: break elif expected_total is not None and expected_total > MAX_CHUNKS_PER_FILE: print(f"[_fetch_paragraphs] WARNING: file {file_id[:20]} returned total={expected_total}, " f"likely API bug (file_id ignored). Retrying ({attempt+1}/{MAX_RETRIES})...") expected_total = None await asyncio.sleep(3 * (attempt + 1)) if expected_total is None or expected_total > MAX_CHUNKS_PER_FILE: print(f"[_fetch_paragraphs] SKIPPING file {file_id[:20]} ({file_name}): " f"total={expected_total} exceeds limit after {MAX_RETRIES} retries") continue if expected_total == 0: continue # 正式分页拉取:不得以「已收集数 >= API total」提前停——total 常低于真实切片数,会少拉约一页~数页。 # max_pages 给足余量;仅当末页 < PAGE_SIZE 或返回空 list 时视为自然结束。 for fetch_attempt in range(MAX_RETRIES): slack = 80 + fetch_attempt * 60 max_pages = min( 2000, max(50, (expected_total + PAGE_SIZE - 1) // PAGE_SIZE + slack), ) page = 1 file_chunks = [] fetch_ok = True foreign_count = 0 ended_normally = False # 空页或末页不满 PAGE_SIZE while page <= max_pages: try: async with session.post( f"{base_url}/dagent/knowledge/chunk/page", json={ "file_id": file_id, "org_id": org_id, "page": page, "page_size": PAGE_SIZE, }, timeout=aiohttp.ClientTimeout(total=30), ) as resp: if resp.status != 200: fetch_ok = False break data = await resp.json() chunks = data.get("data", {}).get("list", []) if not chunks: ended_normally = True break page_foreign = 0 for c in chunks: chunk_fid = c.get("file_id", "") if chunk_fid and chunk_fid != file_id: foreign_count += 1 page_foreign += 1 continue # large_paragraph_llm_summary:后端大段压缩后的摘要,常与 paragraph_context 二选一存在; # 若不映射,大量切片会落入「无正文」→ 生成阶段恒返回 0 题。 _ctx = ( c.get("active_paragraph_context") or c.get("paragraph_context") or c.get("active_context") or "" ) _llm_sum = (c.get("large_paragraph_llm_summary") or "").strip() _para_sum = (c.get("paragraph_summary") or "").strip() file_chunks.append({ "id": c.get("id"), "file_id": file_id, "file_name": file_name or c.get("file_name", ""), "headers": c.get("headers", ""), "paragraph_context": _ctx or _llm_sum, "paragraph_img_num": c.get("paragraph_img_num", 0), "paragraph_pic_semantics_context": c.get("paragraph_pic_semantics_context", ""), "paragraph_question": c.get("paragraph_question", ""), "paragraph_summary": _para_sum or _llm_sum, "paragraph_keywords": c.get("paragraph_keywords", ""), }) if len(chunks) < PAGE_SIZE: ended_normally = True break if page_foreign > len(chunks) * 0.5: print( f"[_fetch_paragraphs] Page {page}: high foreign ratio " f"{page_foreign}/{len(chunks)} for file {file_id[:20]}, continuing" ) page += 1 except Exception as e: print(f"[_fetch_paragraphs] Error fetching chunks for file {file_id[:20]}: {e}") fetch_ok = False break if foreign_count > 0: print( f"[_fetch_paragraphs] File {file_id[:20]}: filtered {foreign_count} foreign, " f"kept {len(file_chunks)}" ) if not fetch_ok: if fetch_attempt < MAX_RETRIES - 1: print( f"[_fetch_paragraphs] File {file_id[:20]}: fetch error, " f"retry ({fetch_attempt + 1}/{MAX_RETRIES})..." ) file_chunks = [] await asyncio.sleep(3 * (fetch_attempt + 1)) continue break if ended_normally: if expected_total and len(file_chunks) < expected_total: print( f"[_fetch_paragraphs] File {file_id[:20]}: EOF kept={len(file_chunks)} " f"vs API total={expected_total} (often foreign rows in total)" ) break # 未自然结束:多半触达 max_pages 且最后一页仍为满页,继续扩页重试 if fetch_attempt < MAX_RETRIES - 1: print( f"[_fetch_paragraphs] File {file_id[:20]}: page cap hit " f"(last_page={page - 1}, max_pages={max_pages}, kept={len(file_chunks)}), " f"retry ({fetch_attempt + 1}/{MAX_RETRIES})..." ) file_chunks = [] await asyncio.sleep(3 * (fetch_attempt + 1)) continue print( f"[_fetch_paragraphs] WARNING: file {file_id[:20]} still not EOF after " f"{MAX_RETRIES} attempts; accepting {len(file_chunks)} chunks" ) break if file_chunks: all_paragraphs.extend(file_chunks) all_paragraphs = _dedupe_paragraphs_by_chunk_id(all_paragraphs) print(f"[_fetch_paragraphs] Total paragraphs fetched: {len(all_paragraphs)} from {total_files} files") return all_paragraphs except Exception as e: import traceback print(f"[_fetch_paragraphs] Error: {e}") print(f"[_fetch_paragraphs] Traceback: {traceback.format_exc()}") return [] def _extract_json_array(text: str) -> Optional[list]: """容错解析 LLM 返回的 JSON 数组。 策略依次尝试: 1) 直接 json.loads 整个响应 2) 抠出 ```...``` 或 ```json ... ``` 代码块再 loads 3) 以第一个 `[` 为起点,按括号配平找到对应 `]`(跳过字符串内的括号) 4) 若因截断未闭合,尝试在最后一个完整对象 `}` 处强制补 `]` 再 loads 任一成功即返回 list;全部失败返回 None。 """ if not text: return None stripped = text.strip() # 1) 整体 loads try: data = json.loads(stripped) if isinstance(data, list): return data except Exception: pass # 2) 代码块 block = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.DOTALL | re.IGNORECASE) if block: try: data = json.loads(block.group(1).strip()) if isinstance(data, list): return data except Exception: pass # 3) 括号配平(跳过字符串内的括号) start = stripped.find("[") if start == -1: return None depth = 0 in_str = False escape = False end = -1 for i in range(start, len(stripped)): ch = stripped[i] if escape: escape = False continue if ch == "\\": escape = True continue if ch == '"': in_str = not in_str continue if in_str: continue if ch == "[": depth += 1 elif ch == "]": depth -= 1 if depth == 0: end = i break if end != -1: candidate = stripped[start:end + 1] try: data = json.loads(candidate) if isinstance(data, list): return data except Exception: pass # 4) 截断恢复:找最后一个完整对象的 `}`,强制补 `]` tail_brace = stripped.rfind("}") if tail_brace > start: candidate = stripped[start:tail_brace + 1] + "]" try: data = json.loads(candidate) if isinstance(data, list): return data except Exception: pass return None def _parse_quality_score(raw) -> float: """模型自评分数:缺省/非法时用 0.8,避免 float(None) 整段失败;限制在 [0,1]。""" if raw is None: return 0.8 try: v = float(raw) except (TypeError, ValueError): return 0.8 return max(0.0, min(1.0, v)) async def _call_llm_once( session, base_url: str, payload: dict, timeout_s: int ) -> tuple[Optional[str], Optional[str]]: """单次调用 LLM,返回 (content, error_str)。""" try: async with session.post( f"{base_url}/chat/completions", json=payload, timeout=aiohttp.ClientTimeout(total=timeout_s), ) as resp: if resp.status != 200: body = (await resp.text())[:500] return None, f"HTTP {resp.status}: {body}" data = await resp.json() content = data["choices"][0]["message"]["content"].strip() return content, None except Exception as e: return None, str(e) async def _generate_questions_for_paragraph( para: dict, cfg: dict, n: int, include_multimodal: bool, existing_questions: list[str] = None, # 已有的问题列表,用于避免重复 ) -> list[dict]: """为单个段落生成问答,支持传入已有问题避免重复。 改进: - 引入容错 JSON 抽取(_extract_json_array),避免贪婪正则漏解析/截断直接丢题。 - 增加重试与自适应降 n:单次失败后指数退避重试;若怀疑是 max_tokens 截断,下一次把 n 折半。 - 复用 ClientSession(在此函数内同次生成共享;跨调用暂未共享以保持接口稳定)。 - 放宽「已有历史问题」的硬约束,改为软参考,避免第 3/4 轮模型大量拒答。 """ base_url = cfg.get("base_url", "").rstrip("/") api_key = cfg.get("api_key", "") model = cfg.get("model", "gpt-4o-mini") context_plain = (para.get("paragraph_context") or "").strip() pic_semantics = (para.get("paragraph_pic_semantics_context") or "").strip() seed_question = (para.get("paragraph_question") or "").strip() headers = (para.get("headers") or "").strip() summary = (para.get("paragraph_summary") or "").strip() keywords = (para.get("paragraph_keywords") or "").strip() has_image = bool(pic_semantics and para.get("paragraph_img_num", 0) > 0) text = context_plain if not text: text = summary if not text and seed_question: text = seed_question if not text and has_image and include_multimodal and pic_semantics: text = pic_semantics[:2500] if not text and keywords: text = f"关键词:\n{keywords[:1500]}" if not text and headers: text = ( f"(该切片缺少正文/摘要,仅章节路径如下;请基于路径生成 {n} 个简短、可检索的技术问题," f"答案可写「需结合全文」类占位但问题须具体)\n{headers}" ) if not text: return [] # 构建 prompt(正文为空但用图片语义作主内容时不再重复插入图片块) pic_section = "" if has_image and include_multimodal and pic_semantics and context_plain: pic_section = f""" **图片语义描述(图片已由 AI 识别):** {pic_semantics[:800]} """ seed_section = "" if seed_question: seed_section = f"\n**已有种子问题(请避免重复,可从不同角度扩展):** {seed_question}" # 已有问题列表(来自循环任务的历史问题) # 放宽为「风格参考」——模型在历史 10+ 条强约束下常直接返回空数组,导致循环第 3/4 轮整轮 0 题。 existing_section = "" if existing_questions: sample_existing = existing_questions[:5] existing_section = ( "\n**该段落的历史问题(供参考,尽量换角度/换措辞,但不必完全不同):**\n" ) for i, eq in enumerate(sample_existing, 1): existing_section += f"{i}. {eq}\n" def _build_prompt(ask_n: int) -> str: return f"""你是一个技术文档问答生成专家。基于以下内容生成 {ask_n} 个测试问题。 **章节路径:** {headers} **文本内容:** {text[:2500]} {pic_section}{seed_section}{existing_section} **要求:** 1. 问题必须能从该章节内容直接回答 2. 覆盖关键知识点,避免过于简单的是非题 3. 如果有图片语义描述,至少生成 1 个图文结合的问题(问题中提及"如图所示"、"图中"等) 4. 答案准确,长度适中(1-3 句话) 5. source_chunk 为答案来源的原文片段(50-150 字) 6. has_image 标记该问题是否依赖图像信息 7. quality_score 为质量评估(0-1) **输出格式:严格只输出一个合法 JSON 数组,不要额外解释、不要代码块标记。** [ {{ "question": "问题文本", "answer": "参考答案", "source_chunk": "答案来源原文片段", "has_image": false, "quality_score": 0.9 }} ]""" headers_http = { "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", } MAX_ATTEMPTS = 3 TIMEOUT_S = 120 cur_n = n async with aiohttp.ClientSession(headers=headers_http) as session: for attempt in range(MAX_ATTEMPTS): prompt = _build_prompt(cur_n) payload = { "model": model, "messages": [{"role": "user", "content": prompt}], "temperature": 0.3, } content, err = await _call_llm_once(session, base_url, payload, TIMEOUT_S) if err: is_rate_limit = "429" in err or "Rate limit" in err or "rate limit" in err is_budget = "400" in err and ("Budget" in err or "budget" in err or "budget_exceeded" in err) _log( f"[_generate_questions] Attempt {attempt + 1}/{MAX_ATTEMPTS} failed " f"for headers={headers[:50]}: {err[:200]}" ) # 限流或预算超限:长退避;其他错误:短退避 if attempt < MAX_ATTEMPTS - 1: if is_budget: wait_s = 300 + 300 * attempt # 预算超限:5min, 10min, 15min _log(f"[_generate_questions] Budget exceeded, backing off {wait_s}s (wait for reset)") elif is_rate_limit: wait_s = 30 + 15 * attempt # 429: 30s, 45s, 60s _log(f"[_generate_questions] Rate limit detected, backing off {wait_s}s") else: wait_s = 2 + 2 * attempt # 其他: 2s, 4s, 6s await asyncio.sleep(wait_s) continue questions = _extract_json_array(content) if not questions: # 看起来像被截断:响应末尾既无 `]` 又无 `}`,下一轮降 n looks_truncated = not content.rstrip().endswith(("]", "}")) _log( f"[_generate_questions] Attempt {attempt + 1}/{MAX_ATTEMPTS}: " f"JSON parse failed for headers={headers[:50]} " f"(len={len(content)}, truncated={looks_truncated})" ) _log(f"[_generate_questions] Raw response preview: {content[:300]}...{content[-300:]}") if looks_truncated and cur_n > 1: cur_n = max(1, cur_n // 2) if attempt < MAX_ATTEMPTS - 1: await asyncio.sleep(1.0 * (attempt + 1)) continue result = [] for q in questions: if isinstance(q, dict) and q.get("question") and q.get("answer"): result.append({ "question": str(q["question"]).strip(), "answer": str(q["answer"]).strip(), "source_chunk": str(q.get("source_chunk", "")).strip(), "has_image": bool(q.get("has_image", False)), "quality_score": _parse_quality_score(q.get("quality_score")), "source_image_desc": pic_semantics[:300] if q.get("has_image") else "", }) if result: _log( f"[_generate_questions] Generated {len(result)} questions for " f"headers={headers[:50]} (attempt {attempt + 1}, asked={cur_n})" ) return result _log( f"[_generate_questions] Attempt {attempt + 1}/{MAX_ATTEMPTS}: " f"JSON parsed but 0 valid items for headers={headers[:50]} " f"(raw count={len(questions)})" ) if attempt < MAX_ATTEMPTS - 1: await asyncio.sleep(1.0 * (attempt + 1)) _log(f"[_generate_questions] All {MAX_ATTEMPTS} attempts exhausted for headers={headers[:50]}") return [] async def _run_dagent_task( task_id: str, org_id: str, file_id_list: list[str], judge_config_id: str, questions_per_section: int, quality_threshold: float, include_multimodal: bool, section_existing_questions: dict[str, list[str]] = None, # {section_path: [question1, question2, ...]} stop_check: callable = None, # Optional stop check function pause_check: callable = None, # Optional async pause check function env_url: str = "", # Dagent environment URL expected_chunk_count: Optional[int] = None, # 批次规划切片总数;与拉取结果对齐校验 ): """ 运行 Dagent QA 生成任务 Args: section_existing_questions: 各 section 下已有的问题列表,用于避免重复生成 stop_check: 可选的停止检查函数,返回True时应停止任务 env_url: Dagent 环境 URL expected_chunk_count: 与 chunk_batches_plan 中本批 chunk_count 一致时,强制校验去重后的拉取条数 """ import traceback section_existing_questions = section_existing_questions or {} print(f"[_run_dagent_task] Starting task {task_id}, org_id={org_id}, file_id_list={len(file_id_list)} files, env_url={env_url}") try: # 1. 先更新状态为 running,让用户知道任务已开始 async with get_db() as db: await db.execute( "UPDATE qa_gen_task SET status='running', total=0, progress=0 WHERE id=?", (task_id,), ) await db.commit() # 2. 提取段落(可多次拉取合并,直至满足 expected_chunk_count) print(f"[_run_dagent_task] Fetching paragraphs...") paragraphs = await _fetch_paragraphs(org_id, file_id_list, env_url) paragraphs = _dedupe_paragraphs_by_chunk_id(paragraphs) if expected_chunk_count and len(paragraphs) < expected_chunk_count: for refetch_i in range(3): short_by = expected_chunk_count - len(paragraphs) print( f"[_run_dagent_task] Chunk count {len(paragraphs)} < expected {expected_chunk_count} " f"(short {short_by}), refetch merge attempt {refetch_i + 1}/3" ) more = await _fetch_paragraphs(org_id, file_id_list, env_url) paragraphs = _merge_paragraphs_by_chunk_id(paragraphs, more) if len(paragraphs) >= expected_chunk_count: break await asyncio.sleep(5 * (refetch_i + 1)) if expected_chunk_count and len(paragraphs) < expected_chunk_count: raise RuntimeError( f"拉取切片 {len(paragraphs)} 条,少于批次期望 {expected_chunk_count} 条;" f"请检查 Dagent chunk/page API、file_ids 是否与 chunk_batches_plan 一致。" ) total = len(paragraphs) print( f"[_run_dagent_task] Fetched {total} paragraphs" + (f" (expected_chunk_count={expected_chunk_count})" if expected_chunk_count else "") ) if expected_chunk_count and total > expected_chunk_count + 5: print( f"[_run_dagent_task] WARN: fetched {total} > expected {expected_chunk_count} " "(plan/API 漂移,仍按已拉取切片全部生成)" ) if total == 0: print(f"[_run_dagent_task] No paragraphs found, marking as done") async with get_db() as db: await db.execute( "UPDATE qa_gen_task SET status='done', finished_at=?, total=0 WHERE id=?", (_now(), task_id), ) await db.commit() return # 更新总数 async with get_db() as db: await db.execute( "UPDATE qa_gen_task SET total=? WHERE id=?", (total, task_id), ) await db.commit() # 3. 获取 LLM 配置 _log(f"[_run_dagent_task] Getting LLM config for judge_config_id={judge_config_id}") async with get_db() as db: cfg_rows = await db.execute_fetchall( "SELECT * FROM judge_config WHERE id=?", (judge_config_id,) ) if not cfg_rows: raise ValueError("judge_config not found") cfg = dict(cfg_rows[0]) _log(f"[_run_dagent_task] LLM config: {cfg.get('model')}") # 3. 并发生成(降低并发到 3,避免触发限流;原 10 在 global_dedup 下压力过大) sem = asyncio.Semaphore(3) buf_lock = asyncio.Lock() # 保护 write_buf 的锁 done = 0 FLUSH_SIZE = 50 write_buf = [] stopped = False _log(f"[_run_dagent_task] Starting generation: {total} paragraphs, concurrency=3, flush_size=50") total_questions_written = 0 paragraphs_with_zero_questions = 0 async def flush_question_buf(buf: list): """将缓冲区问题写入 DB,并同步 progress(即使 buf 为空,也需要回写进度, 否则整轮 0 题的任务 progress 永远停在 0,前端误以为卡死)。""" async with get_db() as db2: for p, q in buf: qid = _id() status = "approved" if q["quality_score"] >= quality_threshold else "pending" await db2.execute( """INSERT INTO qa_gen_question (id,task_id,section_path,question,reference_answer,source_chunk, quality_score,status,created_at,file_id,file_name,chunk_id,chunk_headers,chunk_content_preview) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)""", (qid, task_id, p["headers"], q["question"], q["answer"], q["source_chunk"], q["quality_score"], status, _now(), p.get("file_id", ""), p.get("file_name", ""), p.get("id", ""), p.get("headers", ""), p.get("paragraph_context", "")[:500]), ) if buf: from .qa_gen import _sync_approved_count await _sync_approved_count(db2, task_id) await db2.execute( "UPDATE qa_gen_task SET progress=? WHERE id=?", (done, task_id), ) await db2.commit() async def process_one(para: dict): nonlocal done, stopped, total_questions_written, paragraphs_with_zero_questions # Check stop condition before processing if stop_check and stop_check(): stopped = True return # Check pause condition before processing if pause_check and await pause_check(): stopped = True return async with sem: # Check stop condition again before LLM call if stop_check and stop_check(): stopped = True return # Check pause condition again before LLM call if pause_check and await pause_check(): stopped = True return # 获取该 section 下已有的问题列表 headers = para.get("headers", "") existing = section_existing_questions.get(headers, []) questions: list = [] merged_existing = list(existing) max_fill_rounds = 4 consecutive_empty_rounds = 0 # 连续空轮次计数 max_consecutive_empty = 2 # 最多允许连续2轮为空才终止 for fill_round in range(max_fill_rounds): need = questions_per_section - len(questions) if need <= 0: break batch = await _generate_questions_for_paragraph( para, cfg, need, include_multimodal, existing_questions=merged_existing, ) if not batch: consecutive_empty_rounds += 1 if consecutive_empty_rounds >= max_consecutive_empty: # 连续多轮为空才真正终止 break # 单轮为空继续尝试下一轮 continue # 重置连续空轮次计数 consecutive_empty_rounds = 0 questions.extend(batch) merged_existing.extend(q["question"] for q in batch) async with buf_lock: done += 1 total_questions_written += len(questions) if not questions: paragraphs_with_zero_questions += 1 write_buf.extend([(para, q) for q in questions]) # 每100个段落打印一次进度 if done % 100 == 0 or done == total: print( f"[_run_dagent_task] Progress: {done}/{total} ({done*100//total}%) " f"questions={total_questions_written} zero_chunks={paragraphs_with_zero_questions}" ) # 有足够题目时按 FLUSH_SIZE 落盘;整轮 0 题时也要周期性回写进度(每 100 段一次) need_flush = ( len(write_buf) >= FLUSH_SIZE or done == total or (done % 100 == 0) ) if need_flush: batch = write_buf.copy() write_buf.clear() await flush_question_buf(batch) await asyncio.gather(*[process_one(p) for p in paragraphs]) # 停止/正常结束前务必刷盘,否则缓冲区里已生成的问题会整批丢失(表现为部分切片无题) async with buf_lock: if write_buf: await flush_question_buf(write_buf) write_buf.clear() print( f"[_run_dagent_task] First pass only (no second pass): paragraphs={total}, " f"questions_inserted={total_questions_written}, " f"paragraphs_with_zero_questions={paragraphs_with_zero_questions}" ) # Check if stopped early if stopped: async with get_db() as db: await db.execute( "UPDATE qa_gen_task SET status='stopped', finished_at=? WHERE id=?", (_now(), task_id), ) await db.commit() return async with get_db() as db: await db.execute( "UPDATE qa_gen_task SET status='done', finished_at=? WHERE id=?", (_now(), task_id), ) await db.commit() except Exception as exc: async with get_db() as db: await db.execute( "UPDATE qa_gen_task SET status='failed', error_message=? WHERE id=?", (str(exc), task_id), ) await db.commit()