283 lines
11 KiB
Python
283 lines
11 KiB
Python
"""
|
|
多跳召回测试 API
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "sdk"))
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from models.db import get_db, _now, _id
|
|
|
|
router = APIRouter(prefix="/api/multi-hop", tags=["多跳召回测试"])
|
|
|
|
|
|
@router.get("/dagent/agents")
|
|
async def list_dagent_agents(env_url: str, org_id: str, d_user_id: str = "test"):
|
|
"""从 dagent 平台拉取可用的 Agent 列表"""
|
|
import aiohttp
|
|
url = f"{env_url.rstrip('/')}/dagent/agent/page"
|
|
headers = {
|
|
"Content-Type": "application/json",
|
|
"d-user-id": d_user_id,
|
|
"org-id": org_id,
|
|
}
|
|
payload = {"current": 1, "page_size": 100, "org_id": org_id}
|
|
try:
|
|
async with aiohttp.ClientSession(headers=headers) as session:
|
|
async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
|
resp.raise_for_status()
|
|
data = await resp.json()
|
|
agents = data.get("data", {}).get("list", [])
|
|
return {"status": 0, "data": [
|
|
{"id": a.get("id"), "name": a.get("agent_name"), "type": a.get("agent_type"), "description": a.get("agent_description")}
|
|
for a in agents
|
|
]}
|
|
except Exception as e:
|
|
raise HTTPException(status_code=502, detail=f"无法连接 dagent: {e}")
|
|
|
|
|
|
@router.post("/task")
|
|
async def create_task(
|
|
file: UploadFile = File(...),
|
|
name: str = Form(""),
|
|
env_url: str = Form(...),
|
|
org_id: str = Form(...),
|
|
d_user_id: str = Form("test"),
|
|
agent_id: str = Form(...),
|
|
llm_type: str = Form("deepseek_v3"),
|
|
top_k: int = Form(10),
|
|
concurrency: int = Form(5),
|
|
):
|
|
content = await file.read()
|
|
qa_text = content.decode("utf-8")
|
|
|
|
task_id = _id()
|
|
async with get_db() as db:
|
|
await db.execute(
|
|
"""INSERT INTO multi_hop_task
|
|
(id,name,env_url,org_id,d_user_id,agent_id,llm_type,top_k,concurrency,status,created_at)
|
|
VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
|
|
(task_id, name or file.filename, env_url, org_id,
|
|
d_user_id, agent_id, llm_type, top_k, concurrency, "pending", _now()),
|
|
)
|
|
await db.commit()
|
|
|
|
asyncio.create_task(_run_task(
|
|
task_id, qa_text, env_url, org_id, d_user_id, agent_id, llm_type, top_k, concurrency
|
|
))
|
|
return {"status": 0, "data": {"id": task_id}}
|
|
|
|
|
|
@router.get("/task/list")
|
|
async def list_tasks():
|
|
async with get_db() as db:
|
|
rows = await db.execute_fetchall(
|
|
"SELECT * FROM multi_hop_task ORDER BY created_at DESC"
|
|
)
|
|
return {"status": 0, "data": [dict(r) for r in rows]}
|
|
|
|
|
|
@router.get("/task/{task_id}")
|
|
async def get_task(task_id: str):
|
|
async with get_db() as db:
|
|
rows = await db.execute_fetchall(
|
|
"SELECT * FROM multi_hop_task WHERE id=?", (task_id,)
|
|
)
|
|
if not rows:
|
|
raise HTTPException(status_code=404, detail="Task not found")
|
|
return {"status": 0, "data": dict(rows[0])}
|
|
|
|
|
|
@router.delete("/task/{task_id}")
|
|
async def delete_task(task_id: str):
|
|
async with get_db() as db:
|
|
await db.execute("DELETE FROM multi_hop_result WHERE task_id=?", (task_id,))
|
|
await db.execute("DELETE FROM multi_hop_task WHERE id=?", (task_id,))
|
|
await db.commit()
|
|
return {"status": 0}
|
|
|
|
|
|
@router.get("/task/{task_id}/results")
|
|
async def get_results(task_id: str):
|
|
async with get_db() as db:
|
|
rows = await db.execute_fetchall(
|
|
"SELECT * FROM multi_hop_result WHERE task_id=? ORDER BY qid",
|
|
(task_id,),
|
|
)
|
|
results = []
|
|
for r in rows:
|
|
d = dict(r)
|
|
d["hops"] = json.loads(d.get("hops") or "[]")
|
|
d["actual_hops"] = json.loads(d.get("actual_hops") or "[]")
|
|
d["retrieved"] = json.loads(d.get("retrieved") or "[]")
|
|
results.append(d)
|
|
return {"status": 0, "data": results}
|
|
|
|
|
|
@router.get("/task/{task_id}/summary")
|
|
async def get_summary(task_id: str):
|
|
async with get_db() as db:
|
|
task_rows = await db.execute_fetchall(
|
|
"SELECT * FROM multi_hop_task WHERE id=?", (task_id,)
|
|
)
|
|
if not task_rows:
|
|
raise HTTPException(status_code=404)
|
|
task = dict(task_rows[0])
|
|
|
|
rows = await db.execute_fetchall(
|
|
"""SELECT
|
|
COUNT(*) as total,
|
|
SUM(CASE WHEN error IS NOT NULL THEN 1 ELSE 0 END) as errors,
|
|
SUM(full_hit) as full_hit_count,
|
|
SUM(partial_hit) as partial_hit_count,
|
|
SUM(full_chunk_hit) as full_chunk_hit_count,
|
|
SUM(partial_chunk_hit) as partial_chunk_hit_count,
|
|
AVG(CASE WHEN hop_count > 0 THEN CAST(hop_hit_count AS REAL) / hop_count ELSE 0 END) as avg_hop_hit_rate,
|
|
AVG(CASE WHEN hop_count > 0 THEN CAST(chunk_hit_count AS REAL) / hop_count ELSE 0 END) as avg_chunk_hit_rate,
|
|
AVG(latency_ms) as avg_latency_ms
|
|
FROM multi_hop_result WHERE task_id=?""",
|
|
(task_id,),
|
|
)
|
|
stats = dict(rows[0]) if rows else {}
|
|
|
|
total = stats.get("total") or 0
|
|
full_hit = stats.get("full_hit_count") or 0
|
|
partial_hit = stats.get("partial_hit_count") or 0
|
|
full_chunk_hit = stats.get("full_chunk_hit_count") or 0
|
|
partial_chunk_hit = stats.get("partial_chunk_hit_count") or 0
|
|
|
|
return {
|
|
"status": 0,
|
|
"data": {
|
|
"task": task,
|
|
"total": total,
|
|
"full_hit_count": full_hit,
|
|
"full_hit_rate": round(full_hit / total, 4) if total else 0.0,
|
|
"partial_hit_count": partial_hit,
|
|
"partial_hit_rate": round(partial_hit / total, 4) if total else 0.0,
|
|
"full_chunk_hit_count": full_chunk_hit,
|
|
"full_chunk_hit_rate": round(full_chunk_hit / total, 4) if total else 0.0,
|
|
"partial_chunk_hit_count": partial_chunk_hit,
|
|
"partial_chunk_hit_rate": round(partial_chunk_hit / total, 4) if total else 0.0,
|
|
"error_count": stats.get("errors") or 0,
|
|
"avg_hop_hit_rate": round(stats.get("avg_hop_hit_rate") or 0.0, 4),
|
|
"avg_chunk_hit_rate": round(stats.get("avg_chunk_hit_rate") or 0.0, 4),
|
|
"avg_latency_ms": round(stats.get("avg_latency_ms") or 0.0, 1),
|
|
}
|
|
}
|
|
|
|
|
|
# ── 后台执行 ───────────────────────────────────────────────────────────────────
|
|
|
|
async def _run_task(task_id: str, qa_text: str, env_url: str, org_id: str,
|
|
d_user_id: str, agent_id: str, llm_type: str,
|
|
top_k: int, concurrency: int):
|
|
try:
|
|
from rag_eval.multi_hop.parser import parse_multi_hop_text
|
|
from rag_eval.multi_hop.tester import MultiHopTester
|
|
from rag_eval.single_jump.mapper import FileMapper
|
|
|
|
case = parse_multi_hop_text(qa_text)
|
|
qa_pairs = case.qa_pairs
|
|
if not qa_pairs:
|
|
raise ValueError("未解析到任何多跳问答对")
|
|
|
|
total = len(qa_pairs)
|
|
async with get_db() as db:
|
|
await db.execute(
|
|
"UPDATE multi_hop_task SET status='running', total=? WHERE id=?",
|
|
(total, task_id),
|
|
)
|
|
await db.commit()
|
|
|
|
mapper = FileMapper(env_url, org_id, d_user_id)
|
|
await mapper.load_files()
|
|
all_paths = {hop.section_path for qa in qa_pairs for hop in qa.hops}
|
|
file_map = {path: mapper.map_section_to_file(path) for path in all_paths}
|
|
|
|
tester = MultiHopTester(
|
|
env_url, org_id, d_user_id,
|
|
agent_id=agent_id, llm_type=llm_type or "deepseek_v3",
|
|
)
|
|
|
|
write_buf = []
|
|
FLUSH_SIZE = 20
|
|
|
|
async def flush_buf(buf: list, progress: int):
|
|
async with get_db() as db2:
|
|
for r in buf:
|
|
await db2.execute(
|
|
"""INSERT INTO multi_hop_result
|
|
(id,task_id,qid,question,answer,type,top_k,
|
|
hops,actual_hops,retrieved,agent_answer,
|
|
latency_ms,error,best_cosine_sim,
|
|
full_hit,partial_hit,hop_count,hop_hit_count,
|
|
chunk_hit_count,full_chunk_hit,partial_chunk_hit)
|
|
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
|
|
(
|
|
_id(), task_id, r.qid, r.question, r.answer, r.type, r.top_k,
|
|
json.dumps([{
|
|
"section_path": h.section_path,
|
|
"file_id": h.file_id,
|
|
"file_name": h.file_name,
|
|
"hit": h.hit,
|
|
"hit_at_hop": h.hit_at_hop,
|
|
"contribution": h.contribution,
|
|
"expected_chunk_id": h.expected_chunk_id,
|
|
"chunk_hit": h.chunk_hit,
|
|
"chunk_hit_at_hop": h.chunk_hit_at_hop,
|
|
} for h in r.hop_results], ensure_ascii=False),
|
|
json.dumps([{
|
|
"hop_index": ah.hop_index,
|
|
"query": ah.query,
|
|
"retrieved": ah.retrieved,
|
|
} for ah in r.actual_hops], ensure_ascii=False),
|
|
json.dumps(r.retrieved, ensure_ascii=False),
|
|
r.agent_answer or "",
|
|
r.latency_ms, r.error, r.best_cosine_sim,
|
|
int(r.full_hit), int(r.partial_hit),
|
|
r.hop_count, r.hop_hit_count,
|
|
r.chunk_hit_count,
|
|
int(r.full_chunk_hit), int(r.partial_chunk_hit),
|
|
),
|
|
)
|
|
await db2.execute(
|
|
"UPDATE multi_hop_task SET progress=? WHERE id=?", (progress, task_id)
|
|
)
|
|
await db2.commit()
|
|
|
|
async def on_result(r, done, _total):
|
|
write_buf.append(r)
|
|
if len(write_buf) >= FLUSH_SIZE or done == _total:
|
|
buf = write_buf[:]
|
|
write_buf.clear()
|
|
await flush_buf(buf, done)
|
|
|
|
await tester.run(
|
|
qa_pairs, file_map,
|
|
top_k=top_k, concurrency=concurrency,
|
|
result_cb=on_result,
|
|
)
|
|
|
|
if write_buf:
|
|
await flush_buf(write_buf, total)
|
|
|
|
async with get_db() as db:
|
|
await db.execute(
|
|
"UPDATE multi_hop_task SET status='done', finished_at=? WHERE id=?",
|
|
(_now(), task_id),
|
|
)
|
|
await db.commit()
|
|
|
|
except Exception as exc:
|
|
async with get_db() as db:
|
|
await db.execute(
|
|
"UPDATE multi_hop_task SET status='failed', error_message=? WHERE id=?",
|
|
(str(exc), task_id),
|
|
)
|
|
await db.commit()
|