335 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
多跳召回测试执行器 v4
策略:调用 dagent 的 /agent/chat SSE 接口,让 Agent 自主决定搜几次、用什么 query。
解析 SSE 流中的 TOOL_END 事件,收集每一跳的召回文档,和期望 hop 做对比。
"""
import asyncio
import json
import time
from dataclasses import dataclass, field
import aiohttp
from .parser import MultiHopQAPair, Hop
@dataclass
class HopResult:
section_path: str
file_id: str | None
file_name: str | None
contribution: str
expected_chunk_id: str = "" # 期望命中的切片ID
hit: bool = False # 文件级命中
hit_at_hop: int | None = None
chunk_hit: bool = False # 切片级命中
chunk_hit_at_hop: int | None = None
@dataclass
class ActualHop:
    """A single hop that the Agent actually executed."""

    hop_index: int          # position of this hop in the run
    query: str              # query text recorded for this hop
    retrieved: list[dict]   # documents recalled by this hop
@dataclass
class MultiHopResult:
    """Outcome of running one multi-hop question through the Agent.

    Holds the question metadata, the expected hops (``hop_results``), the
    hops the Agent actually performed (``actual_hops``) and a set of derived
    read-only statistics exposed as properties.
    """

    qid: str
    question: str
    answer: str
    type: str
    top_k: int
    hop_results: list[HopResult]
    actual_hops: list[ActualHop] = field(default_factory=list)
    agent_answer: str = ""
    latency_ms: int = 0
    error: str | None = None

    @property
    def hop_count(self) -> int:
        """Number of expected hops."""
        return len(self.hop_results)

    @property
    def actual_hop_count(self) -> int:
        """Number of hops recorded in ``actual_hops``."""
        return len(self.actual_hops)

    @property
    def hop_hit_count(self) -> int:
        """Number of expected hops with a file-level hit."""
        return sum(1 for h in self.hop_results if h.hit)

    @property
    def chunk_hit_count(self) -> int:
        """Number of expected hops with a chunk-level hit."""
        return sum(1 for h in self.hop_results if h.chunk_hit)

    @property
    def full_hit(self) -> bool:
        """True when every hop that has a ``file_id`` was hit.

        Hops without a ``file_id`` are ignored; if no hop has one the result
        is False rather than vacuously True.
        """
        mappable = [h for h in self.hop_results if h.file_id]
        return bool(mappable) and all(h.hit for h in mappable)

    @property
    def full_chunk_hit(self) -> bool:
        """True when every hop with an ``expected_chunk_id`` was chunk-hit.

        Returns False when no hop carries an expected chunk ID.
        """
        mappable = [h for h in self.hop_results if h.expected_chunk_id]
        return bool(mappable) and all(h.chunk_hit for h in mappable)

    @property
    def partial_hit(self) -> bool:
        """True when at least one expected hop has a file-level hit."""
        return any(h.hit for h in self.hop_results)

    @property
    def partial_chunk_hit(self) -> bool:
        """True when at least one expected hop has a chunk-level hit."""
        return any(h.chunk_hit for h in self.hop_results)

    @property
    def retrieved(self) -> list[dict]:
        """Recalled documents of all hops, merged and de-duplicated.

        Documents are de-duplicated on the pair (file_id, headers); the first
        occurrence wins and the original order is preserved.
        """
        seen: set[tuple[str, str]] = set()
        merged: list[dict] = []
        for ah in self.actual_hops:
            for doc in ah.retrieved:
                # A tuple key (instead of concatenating the two strings)
                # cannot confuse ("ab", "c") with ("a", "bc"), and the
                # ``or ""`` / ``str()`` normalisation keeps a None or
                # non-string value from raising TypeError.
                key = (
                    str(doc.get("file_id") or ""),
                    str(doc.get("headers") or ""),
                )
                if key not in seen:
                    seen.add(key)
                    merged.append(doc)
        return merged

    @property
    def retrieved_file_ids(self) -> set[str]:
        """Set of non-empty ``file_id`` values among the merged documents."""
        return {fid for doc in self.retrieved if (fid := doc.get("file_id"))}

    @property
    def best_cosine_sim(self) -> float | None:
        """Highest ``1 - cosine_distance_1`` over the merged documents.

        Only documents that carry a non-None ``cosine_distance_1`` are
        considered; returns None when there is none. The value is rounded to
        four decimal places.
        """
        sims = [
            1.0 - dist
            for doc in self.retrieved
            if (dist := doc.get("cosine_distance_1")) is not None
        ]
        return round(max(sims), 4) if sims else None
def _docs_from_tool_end(event_data) -> list[dict]:
    """Return the knowledge-chunk records listed in a TOOL_END ``event_data``."""
    if not isinstance(event_data, dict):
        return []
    items = event_data.get("items")
    if not isinstance(items, (list, tuple)):
        return []
    docs: list[dict] = []
    for item in items:
        if not isinstance(item, dict):
            continue
        file_id = str(item.get("file_id") or "")
        chunk_id = str(item.get("paragraph_chunk_id") or "")
        # Skip results of external-link style tools: they carry neither a
        # file_id nor a chunk id.
        if not file_id and not chunk_id:
            continue
        docs.append({
            "file_id": file_id,
            "headers": item.get("headers", ""),
            "paragraph_md5": item.get("paragraph_md5", ""),
            "paragraph_chunk_id": chunk_id,
        })
    return docs


def _query_from_tool_args(tool_args: str) -> str:
    """Return the text inside ``<query>…</query>``, or the whole argument text."""
    import re as _re

    match = _re.search(r"<query>(.*?)</query>", tool_args, _re.DOTALL)
    return (match.group(1) if match else tool_args).strip()


async def _parse_agent_chat_sse(
    session: aiohttp.ClientSession,
    url: str,
    payload: dict,
    timeout_s: int = 300,
) -> tuple[list[ActualHop], str]:
    """POST ``payload`` to the /agent/chat SSE endpoint and parse the stream.

    Stream format handled here: one ``data: {...}`` JSON message per line,
    lines separated by a single newline (not a blank line). Lines that do not
    start with ``data:``, empty payloads, ``[DONE]`` and payloads that are not
    a JSON object are ignored.

    Args:
        session: aiohttp client session used for the request.
        url: Full URL of the chat endpoint.
        payload: JSON body sent with the POST.
        timeout_s: Total request timeout in seconds.

    Returns:
        ``(actual_hops, agent_answer)``: one ``ActualHop`` per TOOL_END event
        that returned at least one knowledge chunk (numbered from 1), and the
        concatenated, stripped answer text.

    Raises:
        aiohttp.ClientResponseError: via ``raise_for_status()`` on an HTTP
            error status. Other network/timeout errors raised by aiohttp
            propagate unchanged.
    """
    import codecs

    actual_hops: list[ActualHop] = []
    answer_chunks: list[str] = []
    tool_arg_chunks: list[str] = []  # argument text of the tool call in flight

    def handle_line(line: str) -> None:
        """Process one complete line of the stream."""
        line = line.rstrip("\r")
        if not line.startswith("data:"):
            return
        data_str = line[5:].strip()
        if not data_str or data_str == "[DONE]":
            return
        try:
            parsed = json.loads(data_str)
        except json.JSONDecodeError:
            return
        if not isinstance(parsed, dict):
            return

        mt = parsed.get("message_type", "")
        is_chunk = parsed.get("is_chunk_data", False)
        data = parsed.get("data", "")

        if is_chunk:
            if isinstance(data, str):
                if mt == "TOOL_CHUNK":
                    # Tool-call arguments (they carry the query). They are
                    # not part of the final answer, so they are kept out of
                    # answer_chunks.
                    tool_arg_chunks.append(data)
                elif mt not in ("THINKING_CHUNK", "EVENT"):
                    answer_chunks.append(data)
            return

        if mt != "EVENT":
            return
        try:
            event = json.loads(data) if isinstance(data, str) else data
        except (json.JSONDecodeError, TypeError):
            return
        if not isinstance(event, dict):
            return

        event_name = event.get("event_name", "")
        if event_name == "TOOL_START":
            tool_arg_chunks.clear()
        elif event_name == "TOOL_END":
            docs = _docs_from_tool_end(event.get("event_data"))
            # Only a tool call that actually recalled knowledge chunks
            # counts as a hop.
            if docs:
                actual_hops.append(ActualHop(
                    hop_index=len(actual_hops) + 1,
                    query=_query_from_tool_args("".join(tool_arg_chunks)),
                    retrieved=docs,
                ))
            tool_arg_chunks.clear()

    async with session.post(
        url, json=payload,
        timeout=aiohttp.ClientTimeout(total=timeout_s),
    ) as resp:
        resp.raise_for_status()
        # Read raw chunks instead of iterating the stream line by line:
        # aiohttp's line reader raises ValueError("Chunk too big") on very
        # long lines. Because a raw chunk may end in the middle of a
        # multi-byte character, decoding has to be incremental.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        pending = ""
        async for raw in resp.content.iter_any():
            pending += decoder.decode(raw)
            if "\n" not in pending:
                continue
            # Everything before the last newline is complete; the tail is
            # carried over to the next chunk.
            *complete, pending = pending.split("\n")
            for line in complete:
                handle_line(line)
        # Flush the decoder and process a final line that was not terminated
        # by a newline.
        pending += decoder.decode(b"", final=True)
        if pending:
            handle_line(pending)

    agent_answer = "".join(answer_chunks).strip()
    return actual_hops, agent_answer
class MultiHopTester:
    """Runs multi-hop QA pairs against the agent chat endpoint and scores recall."""

    def __init__(self, env_url: str, org_id: str, d_user_id: str = "test",
                 agent_id: str = "", llm_type: str = "deepseek_v3"):
        """Store connection settings and build the request headers.

        Args:
            env_url: Base URL of the environment; trailing slashes are removed.
            org_id: Sent as the ``org-id`` header.
            d_user_id: Sent as the ``d-user-id`` header.
            agent_id: Sent as ``agent_id`` in every chat payload.
            llm_type: Sent as ``llm_type`` in every chat payload.
        """
        self.env_url = env_url.rstrip("/")
        self.org_id = org_id
        self.agent_id = agent_id
        self.llm_type = llm_type
        self.headers = {
            "Content-Type": "application/json",
            "d-user-id": d_user_id,
            "org-id": org_id,
        }

    def _chat_url(self) -> str:
        """Return the chat endpoint URL derived from ``env_url``.

        If ``env_url`` already ends with ``/dagent`` only ``/agent/chat`` is
        appended, otherwise ``/dagent/agent/chat`` is appended.
        """
        base = self.env_url.rstrip("/")
        suffix = "/agent/chat" if base.endswith("/dagent") else "/dagent/agent/chat"
        return f"{base}{suffix}"

    @staticmethod
    def _describe_error(exc: BaseException) -> str:
        """Return a non-empty description of ``exc``.

        Some exceptions (e.g. a bare ``TimeoutError()``) have an empty
        ``str()``; falling back to the class name keeps ``error`` truthy so a
        failed request is not mistaken for a successful one.
        """
        return str(exc) or type(exc).__name__

    @staticmethod
    def _first_hop_containing(actual_hops, key: str, expected) -> int | None:
        """Return ``hop_index`` of the first hop with a doc where ``doc[key]`` matches.

        Values are compared as strings, so an ID that is an int on one side
        and a str on the other still matches.
        """
        wanted = str(expected)
        for hop in actual_hops:
            for doc in hop.retrieved:
                value = doc.get(key)
                if value is not None and str(value) == wanted:
                    return hop.hop_index
        return None

    @classmethod
    def _score_hops(cls, hop_results, actual_hops) -> None:
        """Set the file-level and chunk-level hit fields on each expected hop."""
        for hr in hop_results:
            # File-level hit: the expected file appears in any hop's recall.
            if hr.file_id:
                at = cls._first_hop_containing(actual_hops, "file_id", hr.file_id)
                if at is not None:
                    hr.hit = True
                    hr.hit_at_hop = at
            # Chunk-level hit: the expected chunk id appears in any hop's recall.
            if hr.expected_chunk_id:
                at = cls._first_hop_containing(
                    actual_hops, "paragraph_chunk_id", hr.expected_chunk_id
                )
                if at is not None:
                    hr.chunk_hit = True
                    hr.chunk_hit_at_hop = at

    async def run(
        self,
        qa_pairs: list[MultiHopQAPair],
        file_map: dict[str, dict | None],
        top_k: int = 10,
        concurrency: int = 5,
        result_cb=None,
    ) -> list[MultiHopResult]:
        """Send every QA pair to the agent and collect one result per pair.

        Args:
            qa_pairs: Questions with their expected hops.
            file_map: Maps a hop's ``section_path`` to a dict read for
                ``file_id`` / ``file_name``, or to None when unmapped.
            top_k: Recorded on each result; it is not sent to the agent.
            concurrency: Maximum number of requests in flight; values below 1
                are treated as 1.
            result_cb: Optional async callable awaited as
                ``result_cb(result, done, total)`` after each pair finishes.

        Returns:
            The results in completion order (not input order). A request that
            fails yields a result whose ``error`` is set instead of raising.
        """
        import uuid

        results: list[MultiHopResult] = []
        # Semaphore(0) would block forever and a negative value raises.
        sem = asyncio.Semaphore(max(1, concurrency))
        total = len(qa_pairs)
        done = 0
        chat_url = self._chat_url()

        def elapsed_ms(start: float) -> int:
            return int((time.monotonic() - start) * 1000)

        # NOTE(review): ssl=False disables TLS certificate verification.
        # Confirm this tool is only pointed at trusted environments.
        connector = aiohttp.TCPConnector(ssl=False)
        async with aiohttp.ClientSession(
            headers=self.headers, connector=connector
        ) as session:

            async def _test_one(qa: MultiHopQAPair) -> MultiHopResult:
                nonlocal done
                hop_results = []
                for hop in qa.hops:
                    # ``or {}`` folds the "unmapped" None case into plain
                    # .get() lookups; a mapping lacking a key yields None
                    # instead of a KeyError that would abort the whole run.
                    mapping = file_map.get(hop.section_path) or {}
                    hop_results.append(HopResult(
                        section_path=hop.section_path,
                        file_id=mapping.get("file_id"),
                        file_name=mapping.get("file_name"),
                        contribution=hop.contribution,
                        expected_chunk_id=hop.chunk_id or "",
                    ))
                result = MultiHopResult(
                    qid=qa.qid,
                    question=qa.question,
                    answer=qa.answer,
                    type=qa.type,
                    top_k=top_k,
                    hop_results=hop_results,
                )
                async with sem:
                    start = time.monotonic()
                    try:
                        payload = {
                            "task": qa.question,
                            "agent_id": self.agent_id,
                            # Fresh chat id per question.
                            "chat_id": uuid.uuid4().hex,
                            "llm_type": self.llm_type,
                        }
                        actual_hops, agent_answer = await _parse_agent_chat_sse(
                            session, chat_url, payload, timeout_s=300,
                        )
                        result.actual_hops = actual_hops
                        result.agent_answer = agent_answer
                        result.latency_ms = elapsed_ms(start)
                        self._score_hops(result.hop_results, actual_hops)
                    except Exception as e:
                        # Per-question boundary: record the failure on the
                        # result so the remaining questions still run.
                        result.error = self._describe_error(e)
                        result.latency_ms = elapsed_ms(start)
                done += 1
                if result_cb:
                    await result_cb(result, done, total)
                return result

            tasks = [_test_one(qa) for qa in qa_pairs]
            for coro in asyncio.as_completed(tasks):
                results.append(await coro)
        return results