607 lines
23 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
问题生成 API
"""
import asyncio
import json
import re
import sys
from pathlib import Path
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional
from urllib.parse import quote
# Put the parent directory at the front of sys.path so the top-level
# ``models`` package imported on the next line can be resolved.
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.db import get_db, _now, _id
# Router for every endpoint in this module; all routes are mounted under /api/qa-gen.
router = APIRouter(prefix="/api/qa-gen", tags=["问题生成"])
# ── 任务 CRUD ─────────────────────────────────────────────────────────────────
# Strong references to in-flight generation tasks. The event loop only keeps
# weak references to tasks, so a task nobody else references can be
# garbage-collected before it finishes.
_background_tasks: set = set()


@router.post("/task")
async def create_task(
    file: UploadFile = File(...),
    name: str = Form(""),
    judge_config_id: str = Form(...),
    questions_per_section: int = Form(5),
    quality_threshold: float = Form(0.6),
):
    """Create a question-generation task from an uploaded Markdown file.

    The task row is inserted with status ``pending`` and generation is started
    in the background via ``_run_task``; the response returns immediately.

    Args:
        file: Uploaded knowledge document; must be UTF-8 text.
        name: Task name; falls back to the uploaded file name when empty.
        judge_config_id: Id stored on the task row as ``judge_config_id``.
        questions_per_section: Number of questions requested per section.
        quality_threshold: Stored on the task row as ``quality_threshold``.

    Returns:
        ``{"status": 0, "data": {"id": <task id>}}``.

    Raises:
        HTTPException: 400 when the upload is not valid UTF-8.
    """
    content = await file.read()
    try:
        # "utf-8-sig" also strips a leading BOM, which would otherwise stop the
        # first heading line from matching the section parser's regex.
        md_text = content.decode("utf-8-sig")
    except UnicodeDecodeError as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file must be UTF-8 encoded text"
        ) from exc

    task_id = _id()
    async with get_db() as db:
        await db.execute(
            """INSERT INTO qa_gen_task
(id,name,judge_config_id,questions_per_section,quality_threshold,status,created_at)
VALUES (?,?,?,?,?,?,?)""",
            (task_id, name or file.filename, judge_config_id,
             questions_per_section, quality_threshold, "pending", _now()),
        )
        await db.commit()

    # Keep a reference until the task completes, then drop it automatically.
    runner = asyncio.create_task(_run_task(task_id, md_text))
    _background_tasks.add(runner)
    runner.add_done_callback(_background_tasks.discard)
    return {"status": 0, "data": {"id": task_id}}
@router.get("/task/list")
async def list_tasks():
    """Return every generation task as a dict, newest ``created_at`` first."""
    query = "SELECT * FROM qa_gen_task ORDER BY created_at DESC"
    async with get_db() as db:
        task_rows = await db.execute_fetchall(query)
        tasks = [dict(task_row) for task_row in task_rows]
    return {"status": 0, "data": tasks}
# Request body for creating an evaluation dataset from a generation task.
class CreateDatasetReq(BaseModel):
    # Name written to the new eval_dataset row.
    name: str
    # Copied onto every eval_sample row that is created; empty by default.
    knowledge_hub_id: str = ""
    # Description written to the new eval_dataset row; empty by default.
    description: str = ""
@router.post("/task/{task_id}/create-dataset")
async def create_dataset_from_qa_gen(task_id: str, req: CreateDatasetReq):
    """Create an evaluation dataset from the approved questions of a task.

    One ``eval_dataset`` row is inserted, followed by one ``eval_sample`` row
    per approved question; everything is committed together at the end.

    Args:
        task_id: Id of the generation task whose questions are copied.
        req: Dataset name, description and knowledge hub id.

    Returns:
        ``{"status": 0, "data": {"dataset_id": ..., "sample_count": ...}}``.

    Raises:
        HTTPException: 404 when the task does not exist, 400 when it has no
            approved questions.
    """
    sample_sql = """INSERT INTO eval_sample
(id,dataset_id,question,reference_answer,relevant_chunk_ids,knowledge_hub_id,source_file_id,metadata)
VALUES (?,?,?,?,?,?,?,?)"""
    async with get_db() as db:
        # The task must exist.
        task_rows = await db.execute_fetchall(
            "SELECT * FROM qa_gen_task WHERE id=?", (task_id,)
        )
        if not task_rows:
            raise HTTPException(status_code=404, detail="QA 生成任务不存在")

        # Only approved questions are copied into the dataset.
        approved_rows = await db.execute_fetchall(
            "SELECT * FROM qa_gen_question WHERE task_id=? AND status='approved'",
            (task_id,)
        )
        if not approved_rows:
            raise HTTPException(status_code=400, detail="没有已通过的问题")
        sample_count = len(approved_rows)

        # Dataset header row.
        dataset_id = _id()
        await db.execute(
            "INSERT INTO eval_dataset (id,name,description,sample_count,created_at) VALUES (?,?,?,?,?)",
            (dataset_id, req.name, req.description, sample_count, _now()),
        )

        # These two JSON payloads are identical for every sample.
        no_chunk_ids = json.dumps([], ensure_ascii=False)
        provenance = json.dumps(
            {"source": "qa_gen", "qa_gen_task_id": task_id}, ensure_ascii=False
        )
        for approved in approved_rows:
            record = dict(approved)
            await db.execute(
                sample_sql,
                (_id(), dataset_id, record["question"], record["reference_answer"],
                 no_chunk_ids, req.knowledge_hub_id, None, provenance),
            )
        await db.commit()
    return {"status": 0, "data": {"dataset_id": dataset_id, "sample_count": sample_count}}
@router.get("/task/{task_id}")
async def get_task(task_id: str):
    """Return a single generation task as a dict.

    Raises:
        HTTPException: 404 when no task has the given id.
    """
    async with get_db() as db:
        matches = await db.execute_fetchall(
            "SELECT * FROM qa_gen_task WHERE id=?", (task_id,)
        )
        if len(matches) == 0:
            raise HTTPException(status_code=404, detail="Task not found")
        task = dict(matches[0])
    return {"status": 0, "data": task}
@router.delete("/task/{task_id}")
async def delete_task(task_id: str):
    """Delete a task together with all of its questions.

    The questions are removed first, then the task row; both deletes are
    committed together. No error is raised when the task does not exist.
    """
    deletions = (
        "DELETE FROM qa_gen_question WHERE task_id=?",
        "DELETE FROM qa_gen_task WHERE id=?",
    )
    async with get_db() as db:
        for statement in deletions:
            await db.execute(statement, (task_id,))
        await db.commit()
    return {"status": 0, "data": True}
# ── 问题列表 ──────────────────────────────────────────────────────────────────
@router.get("/task/{task_id}/questions")
async def list_questions(
    task_id: str,
    status: Optional[str] = None,
    section: Optional[str] = None,
    page: int = 1,
    page_size: int = 50,
):
    """Return one page of a task's questions plus the total match count.

    Args:
        task_id: Task whose questions are listed.
        status: Optional exact-match filter on ``status``.
        section: Optional exact-match filter on ``section_path``.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; values below 1 are treated as 1.

    Returns:
        ``{"status": 0, "data": {"total": int, "items": [...]}}``. In each
        item ``quality_detail`` is JSON-decoded when it holds valid JSON and
        left as stored otherwise.
    """
    # Clamp pagination: a negative LIMIT means "no limit" on SQLite (the
    # `?` placeholders / execute_fetchall usage suggest that backend), so an
    # unchecked page_size could dump the whole table in one response.
    page = max(page, 1)
    page_size = max(page_size, 1)

    conditions = ["task_id=?"]
    params: list = [task_id]
    if status:
        conditions.append("status=?")
        params.append(status)
    if section:
        conditions.append("section_path=?")
        params.append(section)
    # Only fixed fragments are interpolated; user values stay bound parameters.
    where = " AND ".join(conditions)
    offset = (page - 1) * page_size

    async with get_db() as db:
        count_rows = await db.execute_fetchall(
            f"SELECT COUNT(*) as cnt FROM qa_gen_question WHERE {where}", params
        )
        total = dict(count_rows[0])["cnt"]
        rows = await db.execute_fetchall(
            f"""SELECT id,task_id,section_path,question,reference_answer,source_chunk,
quality_score,quality_detail,dup_of,dup_similarity,status,created_at,updated_at,
chunk_headers,chunk_id,file_id,file_name
FROM qa_gen_question WHERE {where}
ORDER BY section_path, created_at
LIMIT ? OFFSET ?""",
            params + [page_size, offset],
        )
        items = []
        for row in rows:
            item = dict(row)
            raw_detail = item.get("quality_detail")
            if raw_detail:
                try:
                    item["quality_detail"] = json.loads(raw_detail)
                except (TypeError, ValueError):
                    # Not valid JSON: return the stored value unchanged.
                    pass
            items.append(item)
    return {"status": 0, "data": {"total": total, "items": items}}
@router.get("/task/{task_id}/sections")
async def list_sections(task_id: str):
    """Return per-section question statistics for a task.

    Each entry carries ``section_path``, ``total``, the ``approved`` /
    ``rejected`` / ``pending`` counts, the number of ``duplicates`` (rows with
    a non-NULL ``dup_of``) and ``avg_quality``, ordered by ``section_path``.
    """
    stats_sql = """SELECT section_path,
COUNT(*) as total,
SUM(CASE WHEN status='approved' THEN 1 ELSE 0 END) as approved,
SUM(CASE WHEN status='rejected' THEN 1 ELSE 0 END) as rejected,
SUM(CASE WHEN status='pending' THEN 1 ELSE 0 END) as pending,
SUM(CASE WHEN dup_of IS NOT NULL THEN 1 ELSE 0 END) as duplicates,
AVG(quality_score) as avg_quality
FROM qa_gen_question WHERE task_id=?
GROUP BY section_path ORDER BY section_path"""
    async with get_db() as db:
        stat_rows = await db.execute_fetchall(stats_sql, (task_id,))
        stats = [dict(stat_row) for stat_row in stat_rows]
    return {"status": 0, "data": stats}
# ── 审核操作 ──────────────────────────────────────────────────────────────────
@router.post("/question/{question_id}/approve")
async def approve_question(question_id: str):
    """Mark one question as approved and refresh its task's approved counter.

    Raises:
        HTTPException: 404 when no question has the given id.
    """
    async with get_db() as db:
        owner_rows = await db.execute_fetchall(
            "SELECT task_id FROM qa_gen_question WHERE id=?", (question_id,)
        )
        if len(owner_rows) == 0:
            raise HTTPException(status_code=404, detail="Question not found")
        owning_task = dict(owner_rows[0])["task_id"]

        stamp = _now()
        await db.execute(
            "UPDATE qa_gen_question SET status='approved', updated_at=? WHERE id=?",
            (stamp, question_id),
        )
        # Recount before committing so the task row and the question change
        # are written together.
        await _sync_approved_count(db, owning_task)
        await db.commit()
    return {"status": 0, "data": True}
@router.post("/question/{question_id}/reject")
async def reject_question(question_id: str):
    """Mark one question as rejected and refresh its task's approved counter.

    Raises:
        HTTPException: 404 when no question has the given id.
    """
    async with get_db() as db:
        owner_rows = await db.execute_fetchall(
            "SELECT task_id FROM qa_gen_question WHERE id=?", (question_id,)
        )
        if len(owner_rows) == 0:
            raise HTTPException(status_code=404, detail="Question not found")
        owning_task = dict(owner_rows[0])["task_id"]

        stamp = _now()
        await db.execute(
            "UPDATE qa_gen_question SET status='rejected', updated_at=? WHERE id=?",
            (stamp, question_id),
        )
        # Rejecting a previously approved question lowers the count, so the
        # counter is recomputed here as well.
        await _sync_approved_count(db, owning_task)
        await db.commit()
    return {"status": 0, "data": True}
# Request body for editing a question; a field left as None is not updated.
class QuestionEditReq(BaseModel):
    # Replacement question text, or None to keep the stored text.
    question: Optional[str] = None
    # Replacement reference answer, or None to keep the stored answer.
    reference_answer: Optional[str] = None
@router.put("/question/{question_id}")
async def edit_question(question_id: str, req: QuestionEditReq):
    """Update a question's text and/or reference answer.

    Any successful edit also sets the question's status to ``approved`` and
    refreshes the owning task's approved counter.

    Raises:
        HTTPException: 404 when the question does not exist, 400 when the
            request carries no field to update.
    """
    async with get_db() as db:
        owner_rows = await db.execute_fetchall(
            "SELECT task_id FROM qa_gen_question WHERE id=?", (question_id,)
        )
        if len(owner_rows) == 0:
            raise HTTPException(status_code=404, detail="Question not found")
        owning_task = dict(owner_rows[0])["task_id"]

        # Keep only the columns the caller actually supplied.
        requested = [
            (clause, value)
            for clause, value in (
                ("question=?", req.question),
                ("reference_answer=?", req.reference_answer),
            )
            if value is not None
        ]
        if len(requested) == 0:
            raise HTTPException(status_code=400, detail="No fields to update")

        set_clauses = [clause for clause, _ in requested]
        set_clauses.extend(["status='approved'", "updated_at=?"])
        bound_values = [value for _, value in requested]
        bound_values.extend([_now(), question_id])

        await db.execute(
            f"UPDATE qa_gen_question SET {', '.join(set_clauses)} WHERE id=?", bound_values
        )
        await _sync_approved_count(db, owning_task)
        await db.commit()
    return {"status": 0, "data": True}
@router.post("/task/{task_id}/batch-approve")
async def batch_approve(task_id: str, min_quality: float = 0.0):
    """Approve a task's pending, non-duplicate questions in bulk.

    A pending question is approved when ``dup_of`` is NULL and its
    ``quality_score`` is either NULL or at least ``min_quality``. The task's
    approved counter is refreshed in the same commit.
    """
    approve_sql = """UPDATE qa_gen_question SET status='approved', updated_at=?
WHERE task_id=? AND status='pending' AND dup_of IS NULL
AND (quality_score IS NULL OR quality_score >= ?)"""
    async with get_db() as db:
        stamp = _now()
        await db.execute(approve_sql, (stamp, task_id, min_quality))
        await _sync_approved_count(db, task_id)
        await db.commit()
    return {"status": 0, "data": True}
# ── 导出 MD ───────────────────────────────────────────────────────────────────
def _clean_for_parser(text: str) -> str:
    """Sanitize *text* so it matches the downstream parser's regex.

    Chinese characters, ASCII letters, digits, ``_``, ``/``, space, ``.`` and
    ``-`` are kept; every other character becomes ``_``. A leading ``.`` is
    replaced by ``_``. Empty input yields ``"default"``; input that is empty
    after stripping yields ``"default_section"``.
    """
    if not text:
        return "default"
    cleaned = re.sub(r'[^一-龥a-zA-Z0-9_/ .\-]', '_', text).strip()
    if cleaned.startswith('.'):
        cleaned = '_' + cleaned[1:]
    return cleaned or "default_section"


def _build_export_markdown(rows: list) -> str:
    """Render approved question rows (dicts) as the export Markdown text.

    Rows are grouped by ``section_path`` in first-seen order. Each section
    gets a three-line header, a note line and a separator, then one
    ``## Qn`` / ``**An:**`` pair per question, then a closing separator.
    """
    sections: dict = {}
    section_file_names: dict = {}
    for row in rows:
        path = row["section_path"]
        sections.setdefault(path, []).append(row)
        # The first non-empty file_name seen for a section names that section.
        if row.get("file_name") and path not in section_file_names:
            section_file_names[path] = row["file_name"]

    lines = []
    for section_index, (section_path, items) in enumerate(sections.items(), start=1):
        file_name = section_file_names.get(section_path)
        if file_name:
            # Use the stored file_name (minus its extension) as the section identity.
            doc_name = file_name.rsplit(".", 1)[0]
            short_name = doc_name.split('/')[-1]
            lines.append(f"# {section_index}{short_name}")
            lines.append(f"## {file_name} / {doc_name}")
            lines.append(f"# {section_index}. {short_name}_Document")
        else:
            # Fallback: no file_name, so derive names from the sanitized
            # section_path. A NULL path is treated as empty instead of crashing.
            raw_path = section_path or ""
            clean_section_path = _clean_for_parser(raw_path)
            clean_doc_name = _clean_for_parser(raw_path.split("/")[-1])
            lines.append(f"# {section_index}{clean_doc_name}")
            lines.append(f"## {clean_section_path} / {clean_doc_name}")
            lines.append(f"# {section_index}. {clean_doc_name}_Document")
        lines.append("> Generated from QA generation task")
        lines.append("---")
        lines.append("")
        for item in items:
            qid = item["qid"]
            aid = qid.replace("Q", "A")
            lines.append(f"## {qid}: {item['question']}")
            lines.append(f"**{aid}:** {item['reference_answer']}")
            lines.append("")
        lines.append("---")
        lines.append("")
    return "\n".join(lines)


@router.get("/task/{task_id}/export-md")
async def export_md(task_id: str):
    """Export a task's approved questions as a Markdown download.

    The output uses the standard Markdown layout consumed by the single-jump
    test input (per the original note on this endpoint). Questions are
    numbered ``Q1..Qn`` per section in ``created_at`` order.

    Returns:
        A ``text/markdown`` attachment named ``qa_<task name>.md``.

    Raises:
        HTTPException: 404 when the task does not exist or has no approved
            questions.
    """
    async with get_db() as db:
        task_rows = await db.execute_fetchall(
            "SELECT name FROM qa_gen_task WHERE id=?", (task_id,)
        )
        if not task_rows:
            raise HTTPException(status_code=404, detail="Task not found")
        # `name` can be NULL; fall back to the id rather than emitting "qa_None.md".
        task_name = dict(task_rows[0]).get("name") or task_id
        rows = await db.execute_fetchall(
            """SELECT section_path, qid, question, reference_answer, file_name
FROM (
SELECT section_path, question, reference_answer, file_name,
ROW_NUMBER() OVER (PARTITION BY section_path ORDER BY created_at) as rn,
'Q' || ROW_NUMBER() OVER (PARTITION BY section_path ORDER BY created_at) as qid
FROM qa_gen_question
WHERE task_id=? AND status='approved'
)
ORDER BY section_path, rn""",
            (task_id,),
        )
        # Convert rows to dicts while the connection is still open.
        row_dicts = [dict(r) for r in rows]
    if not row_dicts:
        raise HTTPException(status_code=404, detail="没有已通过的问题")

    md_content = _build_export_markdown(row_dicts)
    # safe="" also percent-encodes "/", which is not allowed unescaped in a
    # filename* value.
    filename_encoded = quote(f"qa_{task_name}.md".replace(" ", "_"), safe="")
    return StreamingResponse(
        iter([md_content.encode("utf-8")]),
        media_type="text/markdown",
        headers={"Content-Disposition": f"attachment; filename*=UTF-8''{filename_encoded}"},
    )
# ── 内部:运行生成任务 ─────────────────────────────────────────────────────────
async def _run_task(task_id: str, md_text: str):
    """Run a generation task end to end and record its outcome on the task row.

    Steps: split the document into sections, mark the task ``running`` with
    its section total, load the task joined with its judge config, generate
    questions per section (at most 3 sections concurrently), then mark the
    task ``done``. Any exception marks the task ``failed`` and stores the
    error message; nothing is re-raised.

    Args:
        task_id: Id of the ``qa_gen_task`` row to process.
        md_text: Full Markdown text of the knowledge document.
    """
    try:
        # md_text is the raw knowledge document, so it is split by heading with
        # the local parser. (The previously imported single_jump parser was
        # never used, yet an ImportError from it failed the whole task.)
        sections = _parse_knowledge_md(md_text)
        total = len(sections)
        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET status='running', total=? WHERE id=?",
                (total, task_id),
            )
            await db.commit()

        # Load the task together with its judge config. With SELECT * over a
        # join, same-named columns collapse into one dict key.
        async with get_db() as db:
            cfg_rows = await db.execute_fetchall(
                "SELECT * FROM qa_gen_task t JOIN judge_config j ON t.judge_config_id=j.id WHERE t.id=?",
                (task_id,),
            )
        if not cfg_rows:
            raise ValueError("judge_config not found")
        cfg = dict(cfg_rows[0])
        questions_per_section = cfg["questions_per_section"]
        quality_threshold = cfg["quality_threshold"]

        # Generate section by section, at most 3 at a time.
        sem = asyncio.Semaphore(3)
        done = 0

        async def gen_section(section_path: str, content: str):
            """Generate and store the questions of one section, then bump progress."""
            nonlocal done
            async with sem:
                questions = await _generate_questions(
                    cfg=cfg,
                    section_path=section_path,
                    content=content,
                    n=questions_per_section,
                )
                async with get_db() as db2:
                    for q in questions:
                        qid = _id()
                        # Simple quality scoring: use the score reported by the
                        # LLM for now; can be extended later.
                        quality_score = q.get("quality_score", 0.8)
                        status = "approved" if quality_score >= quality_threshold else "pending"
                        await db2.execute(
                            """INSERT INTO qa_gen_question
(id,task_id,section_path,question,reference_answer,source_chunk,
quality_score,status,created_at)
VALUES (?,?,?,?,?,?,?,?,?)""",
                            (qid, task_id, section_path,
                             q["question"], q["answer"], q.get("source_chunk", ""),
                             quality_score, status, _now()),
                        )
                    # Progress counts finished sections, matching `total`.
                    done += 1
                    await db2.execute(
                        "UPDATE qa_gen_task SET progress=? WHERE id=?", (done, task_id)
                    )
                    await _sync_approved_count(db2, task_id)
                    await db2.commit()

        await asyncio.gather(*[gen_section(sp, ct) for sp, ct in sections])

        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET status='done', finished_at=? WHERE id=?",
                (_now(), task_id),
            )
            await db.commit()
    except Exception as exc:
        # Some exceptions have an empty str(); fall back to repr() so the
        # stored message always identifies the failure.
        message = str(exc) or repr(exc)
        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET status='failed', error_message=? WHERE id=?",
                (message, task_id),
            )
            await db.commit()
def _parse_knowledge_md(md_text: str) -> list[tuple[str, str]]:
"""
将知识库 MD 文件按 ## 标题切分为 (section_path, content) 列表。
支持多级标题,用 / 拼接路径。
"""
lines = md_text.splitlines()
sections: list[tuple[str, str]] = []
current_path: list[str] = []
current_lines: list[str] = []
current_level = 0
for line in lines:
m = re.match(r'^(#{1,4})\s+(.+)', line)
if m:
# 保存上一个 section
if current_path and current_lines:
content = "\n".join(current_lines).strip()
if content:
sections.append(("/".join(current_path), content))
level = len(m.group(1))
title = m.group(2).strip()
# 调整路径深度
if level > current_level:
current_path.append(title)
elif level == current_level:
if current_path:
current_path[-1] = title
else:
current_path = [title]
else:
# 回退到对应层级
current_path = current_path[:level - 1] + [title]
current_level = level
current_lines = []
else:
current_lines.append(line)
# 最后一个 section
if current_path and current_lines:
content = "\n".join(current_lines).strip()
if content:
sections.append(("/".join(current_path), content))
return sections
def _coerce_quality_score(raw, default: float = 0.8) -> float:
    """Return *raw* as a float clamped to [0, 1], or *default* if it is not numeric."""
    try:
        score = float(raw)
    except (TypeError, ValueError):
        return default
    if score != score:  # NaN compares unequal to itself
        return default
    return min(max(score, 0.0), 1.0)


def _parse_generated_questions(text: str) -> list[dict]:
    """Extract and validate the question list from the LLM's reply text.

    The first ``[`` through the last ``]`` is parsed as JSON. Only dict items
    with a non-empty ``question`` and ``answer`` are kept. Raises whatever
    ``json.loads`` raises on malformed JSON.
    """
    m = re.search(r'\[.*\]', text, re.DOTALL)
    if not m:
        return []
    questions = json.loads(m.group())
    result = []
    for q in questions:
        if isinstance(q, dict) and q.get("question") and q.get("answer"):
            result.append({
                "question": str(q["question"]).strip(),
                "answer": str(q["answer"]).strip(),
                "source_chunk": str(q.get("source_chunk", "")).strip(),
                # A null or non-numeric score falls back to the default
                # instead of discarding the whole section's questions.
                "quality_score": _coerce_quality_score(q.get("quality_score", 0.8)),
            })
    return result


async def _generate_questions(
    cfg: dict,
    section_path: str,
    content: str,
    n: int,
) -> list[dict]:
    """Ask the LLM for test questions about one document section.

    Sends a chat-completions request to ``{base_url}/chat/completions`` with
    a 60 second total timeout and parses the JSON array in the reply.

    Args:
        cfg: Config dict; ``base_url``, ``api_key`` and ``model`` are read
            (``model`` defaults to ``gpt-4o-mini``).
        section_path: Section path shown to the model in the prompt.
        content: Section text; only the first 3000 characters are sent.
        n: Number of questions to request.

    Returns:
        A list of ``{question, answer, source_chunk, quality_score}`` dicts.
        Any failure (network, HTTP status, malformed reply) is logged and
        yields an empty list so one section cannot abort the whole task.
    """
    import logging

    import aiohttp

    # `or` also covers keys that are present but NULL in the database row.
    base_url = (cfg.get("base_url") or "").rstrip("/")
    api_key = cfg.get("api_key") or ""
    model = cfg.get("model") or "gpt-4o-mini"
    # Truncate overly long content.
    content_truncated = content[:3000]
    prompt = f"""你是一个专业的技术文档测试问题生成专家。
根据以下技术文档章节内容,生成 {n} 个测试问题。
章节路径:{section_path}
章节内容:
{content_truncated}
要求:
1. 问题必须能从该章节内容直接回答,不要生成需要跨文档才能回答的问题
2. 问题应覆盖章节的关键知识点,避免过于简单的是非题
3. 问题表述清晰,无歧义
4. 答案准确与原文一致长度适中1-3句话
5. source_chunk 为答案来源的原文片段50-150字
6. quality_score 为你对该问题质量的评估(0-1,1 为最高质量)
只输出 JSON 数组,不要有其他内容:
[
{{
"question": "问题文本",
"answer": "参考答案",
"source_chunk": "答案来源原文片段",
"quality_score": 0.9
}}
]"""
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.3,
    }
    try:
        async with aiohttp.ClientSession(headers=headers) as session:
            async with session.post(
                f"{base_url}/chat/completions",
                json=payload,
                timeout=aiohttp.ClientTimeout(total=60),
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()
        text = data["choices"][0]["message"]["content"].strip()
        return _parse_generated_questions(text)
    except Exception as exc:
        # Deliberate best-effort: a failed section must not abort the task,
        # but the failure is logged so it is not invisible.
        logging.getLogger(__name__).warning(
            "Question generation failed for section %r: %r", section_path, exc
        )
        return []
async def _sync_approved_count(db, task_id: str):
    """Recount a task's approved questions and store the result on the task.

    Writes ``qa_gen_task.approved`` using the given connection. The caller is
    responsible for committing.
    """
    count_rows = await db.execute_fetchall(
        "SELECT COUNT(*) as cnt FROM qa_gen_question WHERE task_id=? AND status='approved'",
        (task_id,),
    )
    if count_rows:
        approved_total = dict(count_rows[0])["cnt"]
    else:
        approved_total = 0
    await db.execute(
        "UPDATE qa_gen_task SET approved=? WHERE id=?", (approved_total, task_id)
    )