630 lines
24 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Loop task API - Automated QA generation and testing with pause/resume.
"""
import asyncio
import json
from io import BytesIO
from typing import Optional
from fastapi import APIRouter, Form, HTTPException, Query
from fastapi.responses import JSONResponse, StreamingResponse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent / "sdk"))
from models.db import get_db, _id, _now
from service.loop_recall_md import DEFAULT_LLM_NOTE, append_recall_md_section
from service.loop_engine import (
run_loop_task, pause_loop, resume_loop, stop_loop,
_loop_controls, _update_loop_stats
)
# Router for the loop-task endpoints; every route registered on it is served
# under the /api/loop prefix and tagged "Loop Task".
router = APIRouter(prefix="/api/loop", tags=["Loop Task"])
# Strong references to loop tasks spawned by create_loop_task. The event loop
# only keeps weak references to tasks, so a fire-and-forget task whose result
# is discarded may be garbage-collected before it finishes.
_created_loop_tasks: set = set()


@router.post("/task")
async def create_loop_task(
    name: str = Form(...),
    org_id: str = Form(...),
    judge_config_id: str = Form(...),
    file_ids: str = Form(""),  # comma-separated
    questions_per_section: int = Form(5),
    quality_threshold: float = Form(0.6),
    include_multimodal: bool = Form(True),
    env_url: str = Form(...),
    d_user_id: str = Form("test"),
    agent_id: str = Form(""),  # agent ID used for recall testing
    top_k: int = Form(64),
    recall_top_k: int = Form(64),
    concurrency: int = Form(20),
    cross_chunk: bool = Form(True),
    max_rounds: int = Form(0),
    max_questions: int = Form(0),
    global_dedup: bool = Form(False),  # deduplicate globally (across tasks)
    # Total chunk count of this batch (matches chunk_batches_plan.chunk_count);
    # when > 0 it is stored so fetch completeness can be validated.
    expected_chunk_count: int = Form(0),
):
    """Create a loop task row and start the loop in the background.

    The task is inserted into ``loop_task`` with status ``"pending"`` and all
    counters at zero, then ``run_loop_task`` is scheduled on the event loop.

    Args:
        file_ids: Comma-separated file IDs; blanks are dropped.
        top_k: Threshold used to decide whether a chunk/file counts as hit
            (default 64).
        recall_top_k: ``top_k`` requested from the recall API (default 64).
        agent_id: Optional agent ID for recall testing; per the original
            note, an empty value means the knowledge-base search is called
            directly (handled by the engine, not visible here).
        expected_chunk_count: Optional. Values <= 0 are stored as NULL. Per
            the original note, a positive value makes an incomplete chunk
            fetch retry and eventually fail instead of silently missing
            chunks. It is only persisted here, not passed to
            ``run_loop_task``.

    Returns:
        ``{"status": 0, "data": {"id": <new task id>}}``.
    """
    task_id = _id()
    file_id_list = [part.strip() for part in file_ids.split(",") if part.strip()]

    # Non-positive (or missing) counts mean "no completeness check" -> NULL.
    chunk_total = int(expected_chunk_count or 0)
    ecc = chunk_total if chunk_total > 0 else None

    async with get_db() as db:
        await db.execute(
            """INSERT INTO loop_task
            (id,name,org_id,judge_config_id,file_ids,questions_per_section,quality_threshold,
            include_multimodal,env_url,d_user_id,agent_id,top_k,recall_top_k,concurrency,cross_chunk,
            status,current_round,max_rounds,max_questions,total_generated,total_approved,
            total_duplicates,total_tested,total_recalled,total_file_hit,total_file_miss,
            total_recall_failed,global_dedup,expected_chunk_count,created_at)
            VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
            (
                task_id, name, org_id, judge_config_id, ",".join(file_id_list),
                questions_per_section, quality_threshold, int(include_multimodal),
                env_url, d_user_id, agent_id, top_k, recall_top_k, concurrency, int(cross_chunk),
                "pending", 0, max_rounds, max_questions,
                # total_generated .. total_recall_failed: eight counters start at zero
                0, 0, 0, 0, 0, 0, 0, 0,
                int(global_dedup), ecc, _now(),
            ),
        )
        await db.commit()

    # Start the loop in the background and keep a strong reference to the
    # task until it completes.
    background = asyncio.create_task(run_loop_task(
        loop_task_id=task_id,
        org_id=org_id,
        file_ids=file_id_list,
        judge_config_id=judge_config_id,
        questions_per_section=questions_per_section,
        quality_threshold=quality_threshold,
        include_multimodal=include_multimodal,
        env_url=env_url,
        d_user_id=d_user_id,
        agent_id=agent_id,
        top_k=top_k,
        recall_top_k=recall_top_k,
        concurrency=concurrency,
        cross_chunk=cross_chunk,
        max_rounds=max_rounds,
        max_questions=max_questions,
        global_dedup=global_dedup,
    ))
    _created_loop_tasks.add(background)
    background.add_done_callback(_created_loop_tasks.discard)

    return {"status": 0, "data": {"id": task_id}}
@router.get("/task/list")
async def list_loop_tasks(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
):
    """Return one page of loop tasks, newest first, with derived rates.

    Each item is the ``loop_task`` row plus ``recall_rate``
    (recalled / tested), ``file_hit_rate`` and ``file_miss_rate``
    (hit or miss / recalled), rounded to 4 decimals; a rate is 0 when its
    denominator is 0 or NULL.

    Returns:
        ``{"status": 0, "data": {"total": <row count>, "items": [...]}}``.
    """

    def _ratio(numerator, denominator):
        # Guard against division by zero for tasks that have not run yet.
        return round(numerator / denominator, 4) if denominator > 0 else 0

    skip = (page - 1) * page_size
    async with get_db() as db:
        page_rows = await db.execute_fetchall(
            """SELECT * FROM loop_task
            ORDER BY created_at DESC
            LIMIT ? OFFSET ?""",
            (page_size, skip),
        )
        count_rows = await db.execute_fetchall(
            "SELECT COUNT(*) as cnt FROM loop_task"
        )
        items = [dict(record) for record in page_rows]
        grand_total = count_rows[0]["cnt"] if count_rows else 0

    for entry in items:
        # NULL counters are treated as zero.
        tested = entry.get("total_tested") or 0
        recalled = entry.get("total_recalled") or 0
        hits = entry.get("total_file_hit") or 0
        misses = entry.get("total_file_miss") or 0
        entry["recall_rate"] = _ratio(recalled, tested)
        entry["file_hit_rate"] = _ratio(hits, recalled)
        entry["file_miss_rate"] = _ratio(misses, recalled)

    return {"status": 0, "data": {"total": grand_total, "items": items}}
@router.get("/task/{task_id}")
async def get_loop_task(task_id: str):
    """Return a single loop task row together with its cumulative rates.

    Adds ``recall_rate`` (recalled / tested), ``file_hit_rate`` and
    ``file_miss_rate`` (hit or miss / recalled), each rounded to 4 decimals
    and 0 when the denominator is 0 or NULL.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists.
    """
    async with get_db() as db:
        found = await db.execute_fetchall(
            "SELECT * FROM loop_task WHERE id=?", (task_id,)
        )
        record = dict(found[0]) if found else None

    if record is None:
        raise HTTPException(status_code=404, detail="Task not found")

    # NULL counters are treated as zero.
    tested = record.get("total_tested") or 0
    recalled = record.get("total_recalled") or 0
    hits = record.get("total_file_hit") or 0
    misses = record.get("total_file_miss") or 0

    has_tests = tested > 0
    has_recalls = recalled > 0
    record["recall_rate"] = round(recalled / tested, 4) if has_tests else 0
    record["file_hit_rate"] = round(hits / recalled, 4) if has_recalls else 0
    record["file_miss_rate"] = round(misses / recalled, 4) if has_recalls else 0

    return {"status": 0, "data": record}
@router.post("/task/{task_id}/pause")
async def pause_task(task_id: str):
    """Pause a running loop task and return its refreshed row.

    Raises:
        HTTPException: 400 when ``pause_loop`` reports the task is not
            running; 404 when the task row cannot be found afterwards.
    """
    paused = await pause_loop(task_id)
    if not paused:
        raise HTTPException(status_code=400, detail="Task not running")

    # Re-read the row so the caller sees the updated task state.
    async with get_db() as db:
        found = await db.execute_fetchall(
            "SELECT * FROM loop_task WHERE id=?", (task_id,)
        )
        snapshot = dict(found[0]) if found else None

    if snapshot is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return {"status": 0, "data": snapshot}
# Strong references to loop tasks re-spawned by resume_task. The event loop
# only keeps weak references to tasks, so a fire-and-forget task whose result
# is discarded may be garbage-collected before it finishes.
_resumed_loop_tasks: set = set()


@router.post("/task/{task_id}/resume")
async def resume_task(task_id: str):
    """Resume a paused loop task and return its refreshed row.

    The row is flipped to ``running`` first, then ``resume_loop`` is asked to
    wake the in-memory loop. If that reports nothing to wake (for example
    after a service restart), the loop is started again from the parameters
    stored on the task row.

    Raises:
        HTTPException: 404 when the task does not exist (including when it
            disappears while resuming); 400 when it is not paused.
    """
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT status FROM loop_task WHERE id=?", (task_id,)
        )
        current_status = dict(rows[0])["status"] if rows else None
        exists = bool(rows)

    if not exists:
        raise HTTPException(status_code=404, detail="Task not found")
    if current_status != "paused":
        raise HTTPException(status_code=400, detail="Task not paused")

    # Flip the status to running right away so the frontend gets immediate
    # feedback.
    async with get_db() as db:
        await db.execute(
            "UPDATE loop_task SET status='running', paused_at=NULL WHERE id=?",
            (task_id,),
        )
        await db.commit()

    # Try to wake the task that is still held in memory.
    result = await resume_loop(task_id)
    if not result:
        # Not in memory (the service was restarted): start the loop again.
        async with get_db() as db:
            task_rows = await db.execute_fetchall(
                "SELECT * FROM loop_task WHERE id=?", (task_id,)
            )
            task = dict(task_rows[0]) if task_rows else None
        if task is None:
            # Row was deleted between the status check and the restart.
            raise HTTPException(status_code=404, detail="Task not found")

        def _stored(key, fallback):
            # SELECT * always yields the key, so dict.get's default never
            # applies; a NULL column must be mapped to the fallback here.
            value = task.get(key)
            return fallback if value is None else value

        file_ids = [
            part.strip()
            for part in (task.get("file_ids") or "").split(",")
            if part.strip()
        ]
        background = asyncio.create_task(run_loop_task(
            loop_task_id=task_id,
            org_id=task["org_id"],
            file_ids=file_ids,
            judge_config_id=task["judge_config_id"],
            questions_per_section=task["questions_per_section"],
            quality_threshold=task["quality_threshold"],
            include_multimodal=bool(task["include_multimodal"]),
            env_url=task["env_url"],
            d_user_id=task["d_user_id"],
            agent_id=_stored("agent_id", ""),
            top_k=task["top_k"],
            recall_top_k=_stored("recall_top_k", 64),
            concurrency=task["concurrency"],
            cross_chunk=bool(task["cross_chunk"]),
            max_rounds=task["max_rounds"],
            max_questions=task["max_questions"],
            global_dedup=bool(_stored("global_dedup", 0)),
        ))
        # Keep the task alive until it finishes.
        _resumed_loop_tasks.add(background)
        background.add_done_callback(_resumed_loop_tasks.discard)

    # Return the updated task state.
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM loop_task WHERE id=?", (task_id,)
        )
        refreshed = dict(rows[0]) if rows else None

    if refreshed is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return {"status": 0, "data": refreshed}
@router.post("/task/{task_id}/stop")
async def stop_task(task_id: str):
    """Stop a running or paused loop task permanently.

    Signals the in-memory loop (if one is registered) and marks the row as
    ``stopped`` with a ``finished_at`` timestamp.

    Raises:
        HTTPException: 404 when the task does not exist; 400 when its status
            is neither ``running`` nor ``paused``.
    """
    # The task must exist and be either running or paused.
    async with get_db() as db:
        found = await db.execute_fetchall(
            "SELECT status FROM loop_task WHERE id=?", (task_id,)
        )
        current = found[0]["status"] if found else None
        exists = bool(found)

    if not exists:
        raise HTTPException(status_code=404, detail="Task not found")
    if current not in ("running", "paused"):
        raise HTTPException(status_code=400, detail="Task not running or paused")

    # Signal the in-memory loop, if one is registered for this task.
    from service.loop_engine import _loop_controls
    control = _loop_controls.get(task_id)
    if control:
        control["stop"] = True
        # Also set the pause event so a loop blocked on it can continue and
        # observe the stop flag.
        control["pause_event"].set()

    # The database status is updated whether or not a live loop was found.
    async with get_db() as db:
        await db.execute(
            "UPDATE loop_task SET status='stopped', finished_at=? WHERE id=?",
            (_now(), task_id),
        )
        await db.commit()

    return {"status": 0, "data": True}
@router.delete("/task/{task_id}")
async def delete_task(task_id: str):
    """Delete a loop task together with its rounds and per-round data.

    For every round of the task, the linked QA-generation task and its
    questions, and the linked single-jump task and its results, are removed;
    then the rounds and the task row itself. Everything is committed once at
    the end. The call succeeds even if the task does not exist.
    """
    # Ask any live background loop to stop, then forget its control entry.
    from service.loop_engine import _loop_controls
    control = _loop_controls.get(task_id)
    if control:
        control["stop"] = True
        control["pause_event"].set()
        _loop_controls.pop(task_id, None)

    async with get_db() as db:
        # Each round links to at most one QA-gen task and one single-jump task.
        linked_rounds = await db.execute_fetchall(
            "SELECT qa_gen_task_id, single_jump_task_id FROM loop_round WHERE loop_task_id=?",
            (task_id,),
        )
        for link in linked_rounds:
            # (owner id, delete children, delete owner), in execution order.
            cleanup_plan = (
                (
                    link["qa_gen_task_id"],
                    "DELETE FROM qa_gen_question WHERE task_id=?",
                    "DELETE FROM qa_gen_task WHERE id=?",
                ),
                (
                    link["single_jump_task_id"],
                    "DELETE FROM single_jump_result WHERE task_id=?",
                    "DELETE FROM single_jump_task WHERE id=?",
                ),
            )
            for owner_id, children_sql, owner_sql in cleanup_plan:
                if not owner_id:
                    continue
                await db.execute(children_sql, (owner_id,))
                await db.execute(owner_sql, (owner_id,))

        # Finally remove the rounds and the task itself.
        await db.execute(
            "DELETE FROM loop_round WHERE loop_task_id=?", (task_id,)
        )
        await db.execute(
            "DELETE FROM loop_task WHERE id=?", (task_id,)
        )
        await db.commit()

    return {"status": 0, "data": True}
@router.get("/task/{task_id}/rounds")
async def get_rounds(task_id: str):
    """Return every round of a loop task, ordered by ``round_number``.

    An unknown ``task_id`` yields an empty list rather than an error.
    """
    async with get_db() as db:
        fetched = await db.execute_fetchall(
            """SELECT * FROM loop_round
            WHERE loop_task_id=?
            ORDER BY round_number""",
            (task_id,),
        )
        # Materialise plain dicts while the connection is still open.
        payload = list(map(dict, fetched))
    return {"status": 0, "data": payload}
@router.get("/task/{task_id}/questions")
async def get_questions(
    task_id: str,
    status: Optional[str] = Query(None),  # approved, rejected, duplicate
    category: Optional[str] = Query(None),  # hit, file_miss, recall_failed
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
):
    """Return one page of generated questions across all rounds of a task.

    Each question is joined to its latest ``single_jump_result`` row (highest
    rowid with the same question text in the round's single-jump task), if
    there is one.

    Args:
        status: ``"duplicate"`` selects questions with ``dup_of`` set; any
            other non-empty value is matched against ``qa_gen_question.status``.
        category: Filter on the test result: ``hit``, ``file_miss`` or
            ``recall_failed``. Unrecognised values are ignored.
        page: 1-based page number.
        page_size: Rows per page (1-100).

    Returns:
        ``{"status": 0, "data": {"total": <matching questions>, "items": [...]}}``.
    """
    offset = (page - 1) * page_size

    # Build the WHERE clause; user values only ever travel as bound parameters.
    where_clauses = ["lr.loop_task_id = ?"]
    params: list = [task_id]
    if status == "duplicate":
        where_clauses.append("q.dup_of IS NOT NULL")
    elif status:
        where_clauses.append("q.status = ?")
        params.append(status)

    # NOTE(review): because of the LEFT JOIN below, a question with no
    # single_jump_result row also satisfies the recall_failed predicate
    # (NULL retrieved, NULL error) - confirm that is intended.
    category_predicates = {
        "hit": "r.is_file_hit = 1",
        "file_miss": "r.is_file_hit = 0 AND COALESCE(json_array_length(r.retrieved), 0) > 0",
        "recall_failed": "COALESCE(json_array_length(r.retrieved), 0) = 0 AND r.error IS NULL",
    }
    if category in category_predicates:
        where_clauses.append(category_predicates[category])
    where_sql = " AND ".join(where_clauses)

    # Shared by the page query and the count query so the two cannot drift.
    from_sql = """FROM qa_gen_question q
        JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
        LEFT JOIN single_jump_result r ON r.rowid = (
            SELECT r2.rowid FROM single_jump_result r2
            WHERE r2.task_id = lr.single_jump_task_id AND r2.question = q.question
            ORDER BY r2.rowid DESC LIMIT 1
        )"""

    async with get_db() as db:
        # q.id is the final sort key: round_number/created_at are not unique,
        # and without a tie-breaker LIMIT/OFFSET pages may repeat or skip rows.
        rows = await db.execute_fetchall(
            f"""SELECT
                q.id, q.section_path, q.question, q.reference_answer,
                q.source_chunk, q.quality_score, q.status,
                q.dup_of, q.dup_similarity,
                q.chunk_headers, q.chunk_id, q.file_name,
                lr.round_number,
                r.is_file_hit, r.retrieved, r.best_cosine_sim, r.latency_ms, r.error,
                r.expected_chunk_id, r.is_chunk_hit, r.chunk_hit_rank
            {from_sql}
            WHERE {where_sql}
            ORDER BY lr.round_number DESC, q.created_at DESC, q.id DESC
            LIMIT ? OFFSET ?""",
            (*params, page_size, offset),
        )
        # Convert rows to dicts while the connection is still open.
        questions = [dict(r) for r in rows]

        # Total number of matching questions, ignoring pagination.
        total_rows = await db.execute_fetchall(
            f"""SELECT COUNT(DISTINCT q.id) as cnt
            {from_sql}
            WHERE {where_sql}""",
            tuple(params),
        )
        total = total_rows[0]["cnt"] if total_rows else 0

    return {
        "status": 0,
        "data": {
            "total": total,
            "items": questions,
        },
    }
# SQL predicates over the single_jump_result alias ``r`` for each export
# category; any other category value (including "all") exports everything.
_EXPORT_CATEGORY_SQL = {
    "hit": "r.is_file_hit = 1",
    "file_miss": "r.is_file_hit = 0 AND COALESCE(json_array_length(r.retrieved), 0) > 0",
    "recall_failed": "COALESCE(json_array_length(r.retrieved), 0) = 0 AND r.error IS NULL",
}


def _export_json_content(task_id: str, category: str, sections: dict) -> str:
    """Render grouped export rows as a pretty-printed JSON document."""
    data = {
        "task_id": task_id,
        "category": category,
        "exported_at": _now(),
        "questions": [],
    }
    for section_path, items in sections.items():
        for item in items:
            data["questions"].append({
                "section_path": section_path,
                "file_name": item.get("file_name"),
                "round": item["round_number"],
                "question": item["question"],
                "reference_answer": item["reference_answer"],
                "source_chunk": item["source_chunk"],
                "quality_score": item["quality_score"],
                "status": item["status"],
                "is_duplicate": bool(item.get("dup_of")),
                "dup_similarity": item.get("dup_similarity"),
                "is_file_hit": bool(item.get("is_file_hit")),
                # ``retrieved`` holds a JSON string; empty/NULL means no results.
                "recall_results": json.loads(item["retrieved"]) if item.get("retrieved") else [],
                "best_cosine_sim": item["best_cosine_sim"],
                "expected_chunk_id": item.get("expected_chunk_id"),
                "expected_chunk_name": item.get("expected_chunk_name"),
                "chunk_id": item.get("chunk_id") or item.get("expected_chunk_id"),
            })
    return json.dumps(data, ensure_ascii=False, indent=2)


def _export_markdown_content(category: str, sections: dict) -> str:
    """Render grouped export rows as Markdown via ``append_recall_md_section``.

    Per the original note, this shares the ``loop_recall_md`` format with the
    single-jump parser, the in-loop single-jump Markdown and offline scripts.
    """
    lines: list[str] = []

    def _after_answer(_i: int, item: dict):
        # Extra quote lines emitted after an answer.
        # NOTE(review): if append_recall_md_section passes the ``qa_items``
        # dicts built below, they carry neither ``expected_chunk_name`` nor
        # ``source_chunk`` and nothing is yielded - confirm what the callback
        # actually receives.
        if item.get("expected_chunk_name"):
            yield f"> 预期切片: {item['expected_chunk_name']}"
        sc = item.get("source_chunk")
        if sc:
            yield f"> Source: {str(sc)[:200]}..."

    for section_index, (section_key, items) in enumerate(sections.items(), start=1):
        first = items[0]
        file_name = (first.get("file_name") or "").strip()
        slice_title = (first.get("chunk_headers") or "").strip() or section_key
        # The round of the section's first row is reported as representative.
        meta = [f"> 代表轮次: {first['round_number']}", DEFAULT_LLM_NOTE]
        if category != "all":
            meta.insert(0, f"> 导出分类: {category}")
        qa_items = [
            {
                "question": it["question"],
                "reference_answer": it["reference_answer"],
                "chunk_id": (it.get("chunk_id") or it.get("expected_chunk_id") or ""),
            }
            for it in items
        ]
        append_recall_md_section(
            lines,
            section_index,
            file_name=file_name,
            slice_title=slice_title,
            qa_items=qa_items,
            meta_lines=meta,
            after_answer_lines=_after_answer,
        )
    return "\n".join(lines)


@router.get("/task/{task_id}/export")
async def export_questions(
    task_id: str,
    category: str = Query("all"),  # all, hit, file_miss, recall_failed
    format: str = Query("md"),  # md, json
):
    """Export a task's questions as a Markdown or JSON attachment.

    Tasks whose rounds reference a QA-generation task export their approved
    ``qa_gen_question`` rows joined with the latest matching
    ``single_jump_result``; older tasks without that link export straight
    from ``single_jump_result``.

    Args:
        category: ``hit``, ``file_miss`` or ``recall_failed`` to filter on the
            test result; any other value exports everything.
        format: ``"json"`` for JSON; any other value produces Markdown.

    Returns:
        A ``StreamingResponse`` attachment named
        ``loop_<task_id>_<category>.<ext>``, or a 404 ``PlainTextResponse``
        when no row matches.
    """
    from collections import defaultdict
    from urllib.parse import quote
    from fastapi.responses import PlainTextResponse

    where_clause = _EXPORT_CATEGORY_SQL.get(category, "1=1")

    async with get_db() as db:
        # Does any round of this task reference a QA-generation task?
        has_qa_task = await db.execute_fetchall(
            """SELECT COUNT(*) as cnt FROM loop_round
            WHERE loop_task_id=? AND qa_gen_task_id IS NOT NULL""",
            (task_id,)
        )
        use_qa_task = has_qa_task[0]["cnt"] > 0 if has_qa_task else False

        if use_qa_task:
            # New tasks: read qa_gen_question and join single_jump_result for
            # expected_chunk_id.
            # Note: do not JOIN qa_gen_question ON chunk_id - several questions
            # under one chunk would inflate the rows and duplicate the export.
            # When single_jump_result has several rows for the same question in
            # a task, only the latest one (highest rowid) is used.
            db_rows = await db.execute_fetchall(
                f"""SELECT
                    q.id as qa_question_id,
                    q.section_path, q.file_name, q.question, q.reference_answer,
                    q.source_chunk, q.quality_score, q.status,
                    q.dup_of, q.dup_similarity,
                    q.chunk_headers, q.chunk_id,
                    lr.round_number,
                    r.is_file_hit, r.retrieved, r.best_cosine_sim,
                    r.expected_chunk_id,
                    (SELECT q2b.chunk_headers FROM qa_gen_question q2b
                     WHERE q2b.chunk_id = r.expected_chunk_id
                     AND q2b.chunk_id IS NOT NULL AND trim(COALESCE(q2b.chunk_headers, '')) != ''
                     LIMIT 1) AS expected_chunk_name
                FROM qa_gen_question q
                JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
                LEFT JOIN single_jump_result r ON r.rowid = (
                    SELECT r2.rowid FROM single_jump_result r2
                    WHERE r2.task_id = lr.single_jump_task_id AND r2.question = q.question
                    ORDER BY r2.rowid DESC LIMIT 1
                )
                WHERE lr.loop_task_id = ? AND q.status = 'approved' AND {where_clause}
                ORDER BY lr.round_number, q.chunk_headers, q.created_at""",
                (task_id,),
            )
        else:
            # Old tasks: read single_jump_result directly, filling the
            # question-only columns with constants.
            db_rows = await db.execute_fetchall(
                f"""SELECT
                    r.rowid as result_rowid,
                    r.section_path, r.file_name, r.question, r.reference_answer,
                    '' as source_chunk, 1.0 as quality_score, 'approved' as status,
                    NULL as dup_of, NULL as dup_similarity,
                    COALESCE(r.raw_chunk_headers, r.section_path) as chunk_headers,
                    r.expected_chunk_id as chunk_id,
                    lr.round_number,
                    r.is_file_hit, r.retrieved, r.best_cosine_sim,
                    r.expected_chunk_id,
                    (SELECT qb.chunk_headers FROM qa_gen_question qb
                     WHERE qb.chunk_id = r.expected_chunk_id LIMIT 1) AS expected_chunk_name
                FROM single_jump_result r
                JOIN loop_round lr ON r.task_id = lr.single_jump_task_id
                WHERE lr.loop_task_id = ? AND {where_clause}
                ORDER BY lr.round_number, r.section_path""",
                (task_id,),
            )
        # Convert rows to dicts while the connection is still open.
        rows = [dict(row) for row in db_rows]

    if not rows:
        # Nothing to export for this task/category.
        return PlainTextResponse(
            "没有符合条件的数据",
            status_code=404
        )

    # Group by chunk headers, falling back to section path, then file name.
    sections: dict[str, list] = defaultdict(list)
    for row in rows:
        section_key = (
            row.get("chunk_headers")
            or row.get("section_path")
            or row.get("file_name")
            or "default"
        )
        sections[section_key].append(row)

    if format == "json":
        content = _export_json_content(task_id, category, sections)
        filename = f"loop_{task_id}_{category}.json"
        media_type = "application/json"
    else:
        content = _export_markdown_content(category, sections)
        filename = f"loop_{task_id}_{category}.md"
        media_type = "text/markdown"

    # Percent-encode the name for the RFC 5987 ``filename*`` parameter.
    filename_encoded = quote(filename)
    return StreamingResponse(
        BytesIO(content.encode("utf-8")),
        media_type=media_type,
        headers={"Content-Disposition": f"attachment; filename*=UTF-8''{filename_encoded}"},
    )