630 lines
24 KiB
Python
630 lines
24 KiB
Python
"""
|
||
Loop task API - Automated QA generation and testing with pause/resume.
|
||
"""
|
||
import asyncio
|
||
import json
|
||
from io import BytesIO
|
||
from typing import Optional
|
||
|
||
from fastapi import APIRouter, Form, HTTPException, Query
|
||
from fastapi.responses import JSONResponse, StreamingResponse
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
sys.path.insert(0, str(Path(__file__).parent.parent / "sdk"))
|
||
|
||
from models.db import get_db, _id, _now
|
||
from service.loop_recall_md import DEFAULT_LLM_NOTE, append_recall_md_section
|
||
from service.loop_engine import (
|
||
run_loop_task, pause_loop, resume_loop, stop_loop,
|
||
_loop_controls, _update_loop_stats
|
||
)
|
||
|
||
# Router for the loop-task endpoints; every route registered on it is served under the /api/loop prefix.
router = APIRouter(prefix="/api/loop", tags=["Loop Task"])
|
||
|
||
|
||
# Strong references to background loops started by create_loop_task.
# The event loop only keeps weak references to tasks, so a fire-and-forget
# task may be garbage-collected before it completes unless it is held here.
_created_loop_jobs: set = set()


@router.post("/task")
async def create_loop_task(
    name: str = Form(...),
    org_id: str = Form(...),
    judge_config_id: str = Form(...),
    file_ids: str = Form(""),  # comma-separated
    questions_per_section: int = Form(5),
    quality_threshold: float = Form(0.6),
    include_multimodal: bool = Form(True),
    env_url: str = Form(...),
    d_user_id: str = Form("test"),
    agent_id: str = Form(""),  # agent ID used for recall testing
    top_k: int = Form(64),
    recall_top_k: int = Form(64),
    concurrency: int = Form(20),
    cross_chunk: bool = Form(True),
    max_rounds: int = Form(0),
    max_questions: int = Form(0),
    global_dedup: bool = Form(False),  # whether to deduplicate globally (across tasks)
    expected_chunk_count: int = Form(0),  # total chunks in this batch, matching chunk_batches_plan.chunk_count; >0 enables fetch-completeness checking
):
    """Create a loop task row and start the loop in the background.

    The task is inserted into ``loop_task`` with status ``pending`` and all
    counters at zero, then ``run_loop_task`` is scheduled on the event loop.

    Args:
        file_ids: Comma-separated file IDs; blanks are dropped.
        top_k: Threshold used to decide whether a chunk/file counts as a hit
            (default 64).
        recall_top_k: ``top_k`` requested from the recall API (default 64).
        agent_id: Agent ID for recall testing (optional; when empty the
            knowledge-base search is called directly).
        expected_chunk_count: Optional; when it matches the batch chunk_count,
            an incomplete fetch is retried and eventually fails instead of
            silently missing chunks. Only persisted on the row here (as NULL
            when not positive); it is not passed to ``run_loop_task``.

    Returns:
        ``{"status": 0, "data": {"id": <new task id>}}``.
    """
    task_id = _id()
    file_id_list = [f.strip() for f in file_ids.split(",") if f.strip()]

    # Non-positive / missing counts are stored as NULL.
    ecc = int(expected_chunk_count) if expected_chunk_count else 0
    if ecc <= 0:
        ecc = None

    async with get_db() as db:
        await db.execute(
            """INSERT INTO loop_task
            (id,name,org_id,judge_config_id,file_ids,questions_per_section,quality_threshold,
            include_multimodal,env_url,d_user_id,agent_id,top_k,recall_top_k,concurrency,cross_chunk,
            status,current_round,max_rounds,max_questions,total_generated,total_approved,
            total_duplicates,total_tested,total_recalled,total_file_hit,total_file_miss,
            total_recall_failed,global_dedup,expected_chunk_count,created_at)
            VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
            (task_id, name, org_id, judge_config_id, ",".join(file_id_list),
             questions_per_section, quality_threshold, int(include_multimodal),
             env_url, d_user_id, agent_id, top_k, recall_top_k, concurrency, int(cross_chunk),
             "pending", 0, max_rounds, max_questions,
             0, 0, 0, 0, 0, 0, 0, 0, int(global_dedup), ecc, _now()),
        )
        await db.commit()

    # Start the loop in the background and hold a reference until it is done.
    job = asyncio.create_task(run_loop_task(
        loop_task_id=task_id,
        org_id=org_id,
        file_ids=file_id_list,
        judge_config_id=judge_config_id,
        questions_per_section=questions_per_section,
        quality_threshold=quality_threshold,
        include_multimodal=include_multimodal,
        env_url=env_url,
        d_user_id=d_user_id,
        agent_id=agent_id,
        top_k=top_k,
        recall_top_k=recall_top_k,
        concurrency=concurrency,
        cross_chunk=cross_chunk,
        max_rounds=max_rounds,
        max_questions=max_questions,
        global_dedup=global_dedup,
    ))
    _created_loop_jobs.add(job)
    job.add_done_callback(_created_loop_jobs.discard)

    return {"status": 0, "data": {"id": task_id}}
|
||
|
||
|
||
@router.get("/task/list")
async def list_loop_tasks(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
):
    """Return one page of loop tasks, newest first, with derived rates.

    Each item is the stored ``loop_task`` row plus ``recall_rate``
    (recalled / tested), ``file_hit_rate`` and ``file_miss_rate`` (both over
    recalled), rounded to 4 decimals and 0 when the denominator is 0.
    """
    skip = page_size * (page - 1)

    async with get_db() as db:
        task_rows = await db.execute_fetchall(
            """SELECT * FROM loop_task
            ORDER BY created_at DESC
            LIMIT ? OFFSET ?""",
            (page_size, skip),
        )
        count_rows = await db.execute_fetchall(
            "SELECT COUNT(*) as cnt FROM loop_task"
        )

    items = []
    for record in task_rows:
        entry = dict(record)

        # NULL counters are treated as zero.
        tested = entry.get("total_tested") or 0
        recalled = entry.get("total_recalled") or 0
        hits = entry.get("total_file_hit") or 0
        misses = entry.get("total_file_miss") or 0

        if tested > 0:
            entry["recall_rate"] = round(recalled / tested, 4)
        else:
            entry["recall_rate"] = 0

        if recalled > 0:
            entry["file_hit_rate"] = round(hits / recalled, 4)
            entry["file_miss_rate"] = round(misses / recalled, 4)
        else:
            entry["file_hit_rate"] = 0
            entry["file_miss_rate"] = 0

        items.append(entry)

    grand_total = count_rows[0]["cnt"] if count_rows else 0
    return {"status": 0, "data": {"total": grand_total, "items": items}}
|
||
|
||
|
||
@router.get("/task/{task_id}")
async def get_loop_task(task_id: str):
    """Return a single loop task row together with its derived rates.

    Raises:
        HTTPException: 404 when no ``loop_task`` row has this id.
    """
    async with get_db() as db:
        found = await db.execute_fetchall(
            "SELECT * FROM loop_task WHERE id=?", (task_id,)
        )

    if not found:
        raise HTTPException(status_code=404, detail="Task not found")

    detail = dict(found[0])

    # NULL counters are treated as zero.
    tested = detail.get("total_tested") or 0
    recalled = detail.get("total_recalled") or 0
    hits = detail.get("total_file_hit") or 0
    misses = detail.get("total_file_miss") or 0

    # Rates are rounded to 4 decimals and fall back to 0 on an empty denominator.
    detail["recall_rate"] = 0 if tested <= 0 else round(recalled / tested, 4)
    detail["file_hit_rate"] = 0 if recalled <= 0 else round(hits / recalled, 4)
    detail["file_miss_rate"] = 0 if recalled <= 0 else round(misses / recalled, 4)

    return {"status": 0, "data": detail}
|
||
|
||
|
||
@router.post("/task/{task_id}/pause")
async def pause_task(task_id: str):
    """Pause a running loop task and return its stored row.

    Raises:
        HTTPException: 400 when ``pause_loop`` reports a falsy result,
            404 when the row cannot be found afterwards.
    """
    paused = await pause_loop(task_id)
    if not paused:
        raise HTTPException(status_code=400, detail="Task not running")

    # Read the row back so the caller gets the updated task state.
    async with get_db() as db:
        found = await db.execute_fetchall(
            "SELECT * FROM loop_task WHERE id=?", (task_id,)
        )

    if not found:
        raise HTTPException(status_code=404, detail="Task not found")

    return {"status": 0, "data": dict(found[0])}
|
||
|
||
|
||
# Strong references to loops relaunched by resume_task; the event loop only
# keeps weak references to tasks, so they are held here until they finish.
_resumed_loop_jobs: set = set()


@router.post("/task/{task_id}/resume")
async def resume_task(task_id: str):
    """Resume a paused loop task and return its stored row.

    The row is flipped to ``running`` first, then ``resume_loop`` is asked to
    wake the in-memory loop. If that reports a falsy result, the loop is
    relaunched with ``run_loop_task`` using the parameters stored on the row.

    Raises:
        HTTPException: 404 when the task does not exist (including when it
            disappears while resuming), 400 when it is not ``paused``.
    """
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT status FROM loop_task WHERE id=?", (task_id,)
        )
    if not rows:
        raise HTTPException(status_code=404, detail="Task not found")

    if dict(rows[0])["status"] != "paused":
        raise HTTPException(status_code=400, detail="Task not paused")

    # Switch the status to running right away so the frontend sees feedback immediately.
    async with get_db() as db:
        await db.execute(
            "UPDATE loop_task SET status='running', paused_at=NULL WHERE id=?",
            (task_id,),
        )
        await db.commit()

    # Try to wake the in-memory task.
    result = await resume_loop(task_id)
    if not result:
        # Nothing in memory (the service was restarted): start the task again.
        async with get_db() as db:
            task_rows = await db.execute_fetchall(
                "SELECT * FROM loop_task WHERE id=?", (task_id,)
            )
        if not task_rows:
            # Row vanished between the status check and the relaunch.
            raise HTTPException(status_code=404, detail="Task not found")
        task = dict(task_rows[0])
        file_ids = [f.strip() for f in (task.get("file_ids") or "").split(",") if f.strip()]

        # dict.get() defaults do not apply when the column exists but is NULL,
        # so NULL values are mapped to their defaults explicitly.
        stored_agent_id = task.get("agent_id")
        stored_recall_top_k = task.get("recall_top_k")

        job = asyncio.create_task(run_loop_task(
            loop_task_id=task_id,
            org_id=task["org_id"],
            file_ids=file_ids,
            judge_config_id=task["judge_config_id"],
            questions_per_section=task["questions_per_section"],
            quality_threshold=task["quality_threshold"],
            include_multimodal=bool(task["include_multimodal"]),
            env_url=task["env_url"],
            d_user_id=task["d_user_id"],
            agent_id="" if stored_agent_id is None else stored_agent_id,
            top_k=task["top_k"],
            recall_top_k=64 if stored_recall_top_k is None else stored_recall_top_k,
            concurrency=task["concurrency"],
            cross_chunk=bool(task["cross_chunk"]),
            max_rounds=task["max_rounds"],
            max_questions=task["max_questions"],
            global_dedup=bool(task.get("global_dedup") or 0),
        ))
        _resumed_loop_jobs.add(job)
        job.add_done_callback(_resumed_loop_jobs.discard)

    # Return the updated task state.
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM loop_task WHERE id=?", (task_id,)
        )
    if not rows:
        raise HTTPException(status_code=404, detail="Task not found")
    task = dict(rows[0])
    return {"status": 0, "data": task}
|
||
|
||
|
||
@router.post("/task/{task_id}/stop")
async def stop_task(task_id: str):
    """Permanently stop a loop task that is running or paused.

    Sets the stop flag on the in-memory control entry (when one exists) and
    always marks the row ``stopped`` with a ``finished_at`` timestamp.

    Raises:
        HTTPException: 404 when the task does not exist, 400 when its status
            is neither ``running`` nor ``paused``.
    """
    # The task must exist and be in a stoppable state.
    async with get_db() as db:
        found = await db.execute_fetchall(
            "SELECT status FROM loop_task WHERE id=?", (task_id,)
        )
    if not found:
        raise HTTPException(status_code=404, detail="Task not found")

    current = found[0]["status"]
    if current != "running" and current != "paused":
        raise HTTPException(status_code=400, detail="Task not running or paused")

    # Signal the in-memory loop, if one is registered for this task.
    from service.loop_engine import _loop_controls
    control = _loop_controls.get(task_id)
    if control:
        control["stop"] = True
        # Setting the pause event lets a paused loop proceed and see the stop flag.
        control["pause_event"].set()

    # The database status is updated whether or not a control entry existed.
    async with get_db() as db:
        await db.execute(
            "UPDATE loop_task SET status='stopped', finished_at=? WHERE id=?",
            (_now(), task_id),
        )
        await db.commit()

    return {"status": 0, "data": True}
|
||
|
||
|
||
@router.delete("/task/{task_id}")
async def delete_task(task_id: str):
    """Delete a loop task together with its rounds and per-round data.

    For every round the linked QA-generation task (and its questions) and the
    linked single-jump task (and its results) are removed, then the rounds
    and finally the loop task row itself. Everything is committed once.
    """
    # Stop any background loop still registered for this task.
    from service.loop_engine import _loop_controls
    control = _loop_controls.get(task_id)
    if control:
        control["stop"] = True
        control["pause_event"].set()
        _loop_controls.pop(task_id, None)

    async with get_db() as db:
        # Rounds carry the ids of the related QA / single-jump tasks.
        round_rows = await db.execute_fetchall(
            "SELECT qa_gen_task_id, single_jump_task_id FROM loop_round WHERE loop_task_id=?",
            (task_id,),
        )

        for rnd in round_rows:
            # (linked id, delete children, delete parent) — QA first, then single-jump.
            cleanup = (
                (
                    rnd["qa_gen_task_id"],
                    "DELETE FROM qa_gen_question WHERE task_id=?",
                    "DELETE FROM qa_gen_task WHERE id=?",
                ),
                (
                    rnd["single_jump_task_id"],
                    "DELETE FROM single_jump_result WHERE task_id=?",
                    "DELETE FROM single_jump_task WHERE id=?",
                ),
            )
            for linked_id, children_sql, parent_sql in cleanup:
                if not linked_id:
                    continue
                await db.execute(children_sql, (linked_id,))
                await db.execute(parent_sql, (linked_id,))

        # Rounds go next, the task row last.
        await db.execute(
            "DELETE FROM loop_round WHERE loop_task_id=?", (task_id,)
        )
        await db.execute(
            "DELETE FROM loop_task WHERE id=?", (task_id,)
        )

        await db.commit()

    return {"status": 0, "data": True}
|
||
|
||
|
||
@router.get("/task/{task_id}/rounds")
async def get_rounds(task_id: str):
    """Return every ``loop_round`` row of a loop task, ordered by round number."""
    async with get_db() as db:
        round_rows = await db.execute_fetchall(
            """SELECT * FROM loop_round
            WHERE loop_task_id=?
            ORDER BY round_number""",
            (task_id,),
        )
        # Materialise the rows as dicts before the connection is released.
        payload = list(map(dict, round_rows))

    return {"status": 0, "data": payload}
|
||
|
||
|
||
@router.get("/task/{task_id}/questions")
async def get_questions(
    task_id: str,
    status: Optional[str] = Query(None),  # approved, rejected, duplicate
    category: Optional[str] = Query(None),  # hit, file_miss, recall_failed
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
):
    """
    Page through the generated questions of every round of a loop task.

    - status: filter on the qa_gen_question status; ``duplicate`` instead
      selects rows whose ``dup_of`` is set
    - category: filter on the recall test outcome (hit / file_miss /
      recall_failed); any other value applies no category filter
    """
    skip = page_size * (page - 1)

    # SQL fragment for each recognised test-result category.
    category_filters = {
        "hit": "r.is_file_hit = 1",
        "file_miss": "r.is_file_hit = 0 AND COALESCE(json_array_length(r.retrieved), 0) > 0",
        "recall_failed": "COALESCE(json_array_length(r.retrieved), 0) = 0 AND r.error IS NULL",
    }

    conditions = ["lr.loop_task_id = ?"]
    bind_values = [task_id]

    if status == "duplicate":
        conditions.append("q.dup_of IS NOT NULL")
    elif status:
        conditions.append("q.status = ?")
        bind_values.append(status)

    category_sql = category_filters.get(category)
    if category_sql is not None:
        conditions.append(category_sql)

    where_sql = " AND ".join(conditions)

    async with get_db() as db:
        # Each question is paired with the newest single_jump_result row
        # (largest rowid) that has the same question text in the round's task.
        page_rows = await db.execute_fetchall(
            f"""SELECT
            q.id, q.section_path, q.question, q.reference_answer,
            q.source_chunk, q.quality_score, q.status,
            q.dup_of, q.dup_similarity,
            q.chunk_headers, q.chunk_id, q.file_name,
            lr.round_number,
            r.is_file_hit, r.retrieved, r.best_cosine_sim, r.latency_ms, r.error,
            r.expected_chunk_id, r.is_chunk_hit, r.chunk_hit_rank
            FROM qa_gen_question q
            JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
            LEFT JOIN single_jump_result r ON r.rowid = (
            SELECT r2.rowid FROM single_jump_result r2
            WHERE r2.task_id = lr.single_jump_task_id AND r2.question = q.question
            ORDER BY r2.rowid DESC LIMIT 1
            )
            WHERE {where_sql}
            ORDER BY lr.round_number DESC, q.created_at DESC
            LIMIT ? OFFSET ?""",
            (*bind_values, page_size, skip),
        )

        # Materialise the rows as dicts before the connection is released.
        items = list(map(dict, page_rows))

        # Total number of matching questions, ignoring pagination.
        count_rows = await db.execute_fetchall(
            f"""SELECT COUNT(DISTINCT q.id) as cnt
            FROM qa_gen_question q
            JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
            LEFT JOIN single_jump_result r ON r.rowid = (
            SELECT r2.rowid FROM single_jump_result r2
            WHERE r2.task_id = lr.single_jump_task_id AND r2.question = q.question
            ORDER BY r2.rowid DESC LIMIT 1
            )
            WHERE {where_sql}""",
            bind_values,
        )

    matched = count_rows[0]["cnt"] if count_rows else 0
    return {"status": 0, "data": {"total": matched, "items": items}}
|
||
|
||
|
||
def _decode_retrieved(raw):
    """Decode the stored ``retrieved`` JSON text; return [] when empty or malformed."""
    if not raw:
        return []
    try:
        return json.loads(raw)
    except (TypeError, ValueError):
        # One corrupt row should not abort the whole export.
        return []


@router.get("/task/{task_id}/export")
async def export_questions(
    task_id: str,
    category: str = Query("all"),  # all, hit, file_miss, recall_failed
    format: str = Query("md"),  # md, json
):
    """Export a loop task's questions as a downloadable MD or JSON file.

    When any round of the task has a ``qa_gen_task_id``, approved questions
    are read from ``qa_gen_question`` and paired with their newest
    ``single_jump_result``; otherwise rows are read from
    ``single_jump_result`` directly.

    Args:
        task_id: Loop task whose questions are exported.
        category: ``hit``, ``file_miss`` or ``recall_failed`` to filter by
            recall outcome; any other value exports everything.
        format: ``json`` for a JSON document; any other value produces
            Markdown.

    Returns:
        A streaming attachment named ``loop_<task_id>_<category>.<ext>``, or
        a plain-text 404 response when no rows match.
    """
    from collections import defaultdict
    from urllib.parse import quote

    from fastapi.responses import PlainTextResponse

    # The recall-outcome filter is identical for both query variants.
    category_filters = {
        "hit": "r.is_file_hit = 1",
        "file_miss": "r.is_file_hit = 0 AND COALESCE(json_array_length(r.retrieved), 0) > 0",
        "recall_failed": "COALESCE(json_array_length(r.retrieved), 0) = 0 AND r.error IS NULL",
    }
    where_clause = category_filters.get(category, "1=1")

    async with get_db() as db:
        # Check if we have qa_gen_task_id in loop_round
        has_qa_task = await db.execute_fetchall(
            """SELECT COUNT(*) as cnt FROM loop_round
            WHERE loop_task_id=? AND qa_gen_task_id IS NOT NULL""",
            (task_id,)
        )

        use_qa_task = has_qa_task[0]["cnt"] > 0 if has_qa_task else False

        if use_qa_task:
            # New tasks: query from qa_gen_question and join single_jump_result for expected_chunk_id.
            # NOTE: do not JOIN qa_gen_question ON chunk_id; several questions under the
            # same chunk would multiply the rows and duplicate exported entries.
            # If single_jump_result has several rows for the same question text within
            # one task, only the newest one (largest rowid) is used.
            db_rows = await db.execute_fetchall(
                f"""SELECT
                q.id as qa_question_id,
                q.section_path, q.file_name, q.question, q.reference_answer,
                q.source_chunk, q.quality_score, q.status,
                q.dup_of, q.dup_similarity,
                q.chunk_headers, q.chunk_id,
                lr.round_number,
                r.is_file_hit, r.retrieved, r.best_cosine_sim,
                r.expected_chunk_id,
                (SELECT q2b.chunk_headers FROM qa_gen_question q2b
                WHERE q2b.chunk_id = r.expected_chunk_id
                AND q2b.chunk_id IS NOT NULL AND trim(COALESCE(q2b.chunk_headers, '')) != ''
                LIMIT 1) AS expected_chunk_name
                FROM qa_gen_question q
                JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
                LEFT JOIN single_jump_result r ON r.rowid = (
                SELECT r2.rowid FROM single_jump_result r2
                WHERE r2.task_id = lr.single_jump_task_id AND r2.question = q.question
                ORDER BY r2.rowid DESC LIMIT 1
                )
                WHERE lr.loop_task_id = ? AND q.status = 'approved' AND {where_clause}
                ORDER BY lr.round_number, q.chunk_headers, q.created_at""",
                (task_id,),
            )
        else:
            # Old tasks: query from single_jump_result directly
            db_rows = await db.execute_fetchall(
                f"""SELECT
                r.rowid as result_rowid,
                r.section_path, r.file_name, r.question, r.reference_answer,
                '' as source_chunk, 1.0 as quality_score, 'approved' as status,
                NULL as dup_of, NULL as dup_similarity,
                COALESCE(r.raw_chunk_headers, r.section_path) as chunk_headers,
                r.expected_chunk_id as chunk_id,
                lr.round_number,
                r.is_file_hit, r.retrieved, r.best_cosine_sim,
                r.expected_chunk_id,
                (SELECT qb.chunk_headers FROM qa_gen_question qb
                WHERE qb.chunk_id = r.expected_chunk_id LIMIT 1) AS expected_chunk_name
                FROM single_jump_result r
                JOIN loop_round lr ON r.task_id = lr.single_jump_task_id
                WHERE lr.loop_task_id = ? AND {where_clause}
                ORDER BY lr.round_number, r.section_path""",
                (task_id,),
            )

        # Convert rows to dicts while connection is still open
        rows = [dict(row) for row in db_rows]

    if not rows:
        # Nothing matched: answer with a plain-text 404 instead of an empty file.
        return PlainTextResponse(
            "没有符合条件的数据",
            status_code=404
        )

    # Group by section: chunk_headers first, then section_path, then file_name.
    sections: dict[str, list] = defaultdict(list)
    for row in rows:
        section_key = row.get("chunk_headers") or row.get("section_path") or row.get("file_name") or "default"
        sections[section_key].append(row)

    if format == "json":
        # JSON export: one flat list of questions, in section order.
        data = {
            "task_id": task_id,
            "category": category,
            "exported_at": _now(),
            "questions": [
                {
                    "section_path": section_path,
                    "file_name": item.get("file_name"),
                    "round": item["round_number"],
                    "question": item["question"],
                    "reference_answer": item["reference_answer"],
                    "source_chunk": item["source_chunk"],
                    "quality_score": item["quality_score"],
                    "status": item["status"],
                    "is_duplicate": bool(item.get("dup_of")),
                    "dup_similarity": item.get("dup_similarity"),
                    "is_file_hit": bool(item.get("is_file_hit")),
                    "recall_results": _decode_retrieved(item.get("retrieved")),
                    "best_cosine_sim": item["best_cosine_sim"],
                    "expected_chunk_id": item.get("expected_chunk_id"),
                    "expected_chunk_name": item.get("expected_chunk_name"),
                    "chunk_id": item.get("chunk_id") or item.get("expected_chunk_id"),
                }
                for section_path, items in sections.items()
                for item in items
            ],
        }

        content = json.dumps(data, ensure_ascii=False, indent=2)
        filename = f"loop_{task_id}_{category}.json"
        media_type = "application/json"

    else:
        # MD export: uses the same loop_recall_md helpers as the single-jump parser,
        # the in-loop single-jump MD and the offline scripts.
        lines: list[str] = []

        def _after_answer(_i: int, item: dict):
            # Extra quote lines emitted after each answer.
            if item.get("expected_chunk_name"):
                yield f"> 预期切片: {item['expected_chunk_name']}"
            sc = item.get("source_chunk")
            if sc:
                yield f"> Source: {str(sc)[:200]}..."

        for section_index, (section_key, items) in enumerate(sections.items(), start=1):
            first = items[0]
            file_name = (first.get("file_name") or "").strip()
            slice_title = (first.get("chunk_headers") or "").strip() or section_key
            meta = [f"> 代表轮次: {first['round_number']}", DEFAULT_LLM_NOTE]
            if category != "all":
                meta.insert(0, f"> 导出分类: {category}")
            qa_items = [
                {
                    "question": it["question"],
                    "reference_answer": it["reference_answer"],
                    "chunk_id": (it.get("chunk_id") or it.get("expected_chunk_id") or ""),
                }
                for it in items
            ]
            append_recall_md_section(
                lines,
                section_index,
                file_name=file_name,
                slice_title=slice_title,
                qa_items=qa_items,
                meta_lines=meta,
                after_answer_lines=_after_answer,
            )

        content = "\n".join(lines)
        filename = f"loop_{task_id}_{category}.md"
        media_type = "text/markdown"

    # Percent-encode the name for the RFC 5987 style filename* parameter.
    filename_encoded = quote(filename)

    return StreamingResponse(
        BytesIO(content.encode("utf-8")),
        media_type=media_type,
        headers={"Content-Disposition": f"attachment; filename*=UTF-8''{filename_encoded}"},
    )
|