dagent_eval/server/api/multi_hop.py

283 lines
11 KiB
Python

"""
多跳召回测试 API
"""
import asyncio
import json
import sys
from pathlib import Path
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "sdk"))
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.db import get_db, _now, _id
router = APIRouter(prefix="/api/multi-hop", tags=["多跳召回测试"])
@router.get("/dagent/agents")
async def list_dagent_agents(env_url: str, org_id: str, d_user_id: str = "test"):
"""从 dagent 平台拉取可用的 Agent 列表"""
import aiohttp
url = f"{env_url.rstrip('/')}/dagent/agent/page"
headers = {
"Content-Type": "application/json",
"d-user-id": d_user_id,
"org-id": org_id,
}
payload = {"current": 1, "page_size": 100, "org_id": org_id}
try:
async with aiohttp.ClientSession(headers=headers) as session:
async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=10)) as resp:
resp.raise_for_status()
data = await resp.json()
agents = data.get("data", {}).get("list", [])
return {"status": 0, "data": [
{"id": a.get("id"), "name": a.get("agent_name"), "type": a.get("agent_type"), "description": a.get("agent_description")}
for a in agents
]}
except Exception as e:
raise HTTPException(status_code=502, detail=f"无法连接 dagent: {e}")
@router.post("/task")
async def create_task(
file: UploadFile = File(...),
name: str = Form(""),
env_url: str = Form(...),
org_id: str = Form(...),
d_user_id: str = Form("test"),
agent_id: str = Form(...),
llm_type: str = Form("deepseek_v3"),
top_k: int = Form(10),
concurrency: int = Form(5),
):
content = await file.read()
qa_text = content.decode("utf-8")
task_id = _id()
async with get_db() as db:
await db.execute(
"""INSERT INTO multi_hop_task
(id,name,env_url,org_id,d_user_id,agent_id,llm_type,top_k,concurrency,status,created_at)
VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
(task_id, name or file.filename, env_url, org_id,
d_user_id, agent_id, llm_type, top_k, concurrency, "pending", _now()),
)
await db.commit()
asyncio.create_task(_run_task(
task_id, qa_text, env_url, org_id, d_user_id, agent_id, llm_type, top_k, concurrency
))
return {"status": 0, "data": {"id": task_id}}
@router.get("/task/list")
async def list_tasks():
async with get_db() as db:
rows = await db.execute_fetchall(
"SELECT * FROM multi_hop_task ORDER BY created_at DESC"
)
return {"status": 0, "data": [dict(r) for r in rows]}
@router.get("/task/{task_id}")
async def get_task(task_id: str):
async with get_db() as db:
rows = await db.execute_fetchall(
"SELECT * FROM multi_hop_task WHERE id=?", (task_id,)
)
if not rows:
raise HTTPException(status_code=404, detail="Task not found")
return {"status": 0, "data": dict(rows[0])}
@router.delete("/task/{task_id}")
async def delete_task(task_id: str):
async with get_db() as db:
await db.execute("DELETE FROM multi_hop_result WHERE task_id=?", (task_id,))
await db.execute("DELETE FROM multi_hop_task WHERE id=?", (task_id,))
await db.commit()
return {"status": 0}
@router.get("/task/{task_id}/results")
async def get_results(task_id: str):
async with get_db() as db:
rows = await db.execute_fetchall(
"SELECT * FROM multi_hop_result WHERE task_id=? ORDER BY qid",
(task_id,),
)
results = []
for r in rows:
d = dict(r)
d["hops"] = json.loads(d.get("hops") or "[]")
d["actual_hops"] = json.loads(d.get("actual_hops") or "[]")
d["retrieved"] = json.loads(d.get("retrieved") or "[]")
results.append(d)
return {"status": 0, "data": results}
@router.get("/task/{task_id}/summary")
async def get_summary(task_id: str):
async with get_db() as db:
task_rows = await db.execute_fetchall(
"SELECT * FROM multi_hop_task WHERE id=?", (task_id,)
)
if not task_rows:
raise HTTPException(status_code=404)
task = dict(task_rows[0])
rows = await db.execute_fetchall(
"""SELECT
COUNT(*) as total,
SUM(CASE WHEN error IS NOT NULL THEN 1 ELSE 0 END) as errors,
SUM(full_hit) as full_hit_count,
SUM(partial_hit) as partial_hit_count,
SUM(full_chunk_hit) as full_chunk_hit_count,
SUM(partial_chunk_hit) as partial_chunk_hit_count,
AVG(CASE WHEN hop_count > 0 THEN CAST(hop_hit_count AS REAL) / hop_count ELSE 0 END) as avg_hop_hit_rate,
AVG(CASE WHEN hop_count > 0 THEN CAST(chunk_hit_count AS REAL) / hop_count ELSE 0 END) as avg_chunk_hit_rate,
AVG(latency_ms) as avg_latency_ms
FROM multi_hop_result WHERE task_id=?""",
(task_id,),
)
stats = dict(rows[0]) if rows else {}
total = stats.get("total") or 0
full_hit = stats.get("full_hit_count") or 0
partial_hit = stats.get("partial_hit_count") or 0
full_chunk_hit = stats.get("full_chunk_hit_count") or 0
partial_chunk_hit = stats.get("partial_chunk_hit_count") or 0
return {
"status": 0,
"data": {
"task": task,
"total": total,
"full_hit_count": full_hit,
"full_hit_rate": round(full_hit / total, 4) if total else 0.0,
"partial_hit_count": partial_hit,
"partial_hit_rate": round(partial_hit / total, 4) if total else 0.0,
"full_chunk_hit_count": full_chunk_hit,
"full_chunk_hit_rate": round(full_chunk_hit / total, 4) if total else 0.0,
"partial_chunk_hit_count": partial_chunk_hit,
"partial_chunk_hit_rate": round(partial_chunk_hit / total, 4) if total else 0.0,
"error_count": stats.get("errors") or 0,
"avg_hop_hit_rate": round(stats.get("avg_hop_hit_rate") or 0.0, 4),
"avg_chunk_hit_rate": round(stats.get("avg_chunk_hit_rate") or 0.0, 4),
"avg_latency_ms": round(stats.get("avg_latency_ms") or 0.0, 1),
}
}
# ── 后台执行 ───────────────────────────────────────────────────────────────────
async def _run_task(task_id: str, qa_text: str, env_url: str, org_id: str,
d_user_id: str, agent_id: str, llm_type: str,
top_k: int, concurrency: int):
try:
from rag_eval.multi_hop.parser import parse_multi_hop_text
from rag_eval.multi_hop.tester import MultiHopTester
from rag_eval.single_jump.mapper import FileMapper
case = parse_multi_hop_text(qa_text)
qa_pairs = case.qa_pairs
if not qa_pairs:
raise ValueError("未解析到任何多跳问答对")
total = len(qa_pairs)
async with get_db() as db:
await db.execute(
"UPDATE multi_hop_task SET status='running', total=? WHERE id=?",
(total, task_id),
)
await db.commit()
mapper = FileMapper(env_url, org_id, d_user_id)
await mapper.load_files()
all_paths = {hop.section_path for qa in qa_pairs for hop in qa.hops}
file_map = {path: mapper.map_section_to_file(path) for path in all_paths}
tester = MultiHopTester(
env_url, org_id, d_user_id,
agent_id=agent_id, llm_type=llm_type or "deepseek_v3",
)
write_buf = []
FLUSH_SIZE = 20
async def flush_buf(buf: list, progress: int):
async with get_db() as db2:
for r in buf:
await db2.execute(
"""INSERT INTO multi_hop_result
(id,task_id,qid,question,answer,type,top_k,
hops,actual_hops,retrieved,agent_answer,
latency_ms,error,best_cosine_sim,
full_hit,partial_hit,hop_count,hop_hit_count,
chunk_hit_count,full_chunk_hit,partial_chunk_hit)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
(
_id(), task_id, r.qid, r.question, r.answer, r.type, r.top_k,
json.dumps([{
"section_path": h.section_path,
"file_id": h.file_id,
"file_name": h.file_name,
"hit": h.hit,
"hit_at_hop": h.hit_at_hop,
"contribution": h.contribution,
"expected_chunk_id": h.expected_chunk_id,
"chunk_hit": h.chunk_hit,
"chunk_hit_at_hop": h.chunk_hit_at_hop,
} for h in r.hop_results], ensure_ascii=False),
json.dumps([{
"hop_index": ah.hop_index,
"query": ah.query,
"retrieved": ah.retrieved,
} for ah in r.actual_hops], ensure_ascii=False),
json.dumps(r.retrieved, ensure_ascii=False),
r.agent_answer or "",
r.latency_ms, r.error, r.best_cosine_sim,
int(r.full_hit), int(r.partial_hit),
r.hop_count, r.hop_hit_count,
r.chunk_hit_count,
int(r.full_chunk_hit), int(r.partial_chunk_hit),
),
)
await db2.execute(
"UPDATE multi_hop_task SET progress=? WHERE id=?", (progress, task_id)
)
await db2.commit()
async def on_result(r, done, _total):
write_buf.append(r)
if len(write_buf) >= FLUSH_SIZE or done == _total:
buf = write_buf[:]
write_buf.clear()
await flush_buf(buf, done)
await tester.run(
qa_pairs, file_map,
top_k=top_k, concurrency=concurrency,
result_cb=on_result,
)
if write_buf:
await flush_buf(write_buf, total)
async with get_db() as db:
await db.execute(
"UPDATE multi_hop_task SET status='done', finished_at=? WHERE id=?",
(_now(), task_id),
)
await db.commit()
except Exception as exc:
async with get_db() as db:
await db.execute(
"UPDATE multi_hop_task SET status='failed', error_message=? WHERE id=?",
(str(exc), task_id),
)
await db.commit()