359 lines
15 KiB
Python
359 lines
15 KiB
Python
"""
|
||
召回测试执行器:对每条问答对调用 dagent 语义召回接口,记录结果。
|
||
"""
|
||
import asyncio
|
||
import json
|
||
import sys
|
||
import time
|
||
from dataclasses import dataclass, field
|
||
import aiohttp
|
||
|
||
# Fix Windows GBK encoding issue: force UTF-8 (with replacement of unencodable
# characters) on the standard streams so non-ASCII debug output does not raise
# UnicodeEncodeError. A stream that is None or has been replaced by an object
# without ``reconfigure`` is left untouched instead of failing at import time.
for _stream in (sys.stdout, sys.stderr):
    _reconfigure = getattr(_stream, "reconfigure", None)
    if callable(_reconfigure):
        _reconfigure(encoding='utf-8', errors='replace')
|
||
|
||
from .parser import Section, QAPair
|
||
|
||
|
||
@dataclass
|
||
class RecallResult:
|
||
section_path: str
|
||
doc_name: str
|
||
file_id: str | None
|
||
match_type: str | None # exact / contains / fuzzy / unmatched
|
||
qid: str
|
||
question: str
|
||
reference_answer: str
|
||
top_k: int # 用于判断命中的top_k值
|
||
hit_top_k: int # 用于判断切片是否命中的top_k阈值(可能不同于召回时的top_k)
|
||
retrieved: list[dict] = field(default_factory=list) # 召回的切片列表(全部,不截断)
|
||
latency_ms: int = 0
|
||
error: str | None = None
|
||
expected_chunk_id: str | None = None # 期望命中的切片ID
|
||
raw_chunk_headers: str | None = None # 原始切片标题(从元数据解析)
|
||
|
||
# 计算属性
|
||
@property
|
||
def best_cosine_sim(self) -> float | None:
|
||
sims = [1.0 - r.get("cosine_distance_1", 1.0) for r in self.retrieved if r.get("cosine_distance_1") is not None]
|
||
return round(max(sims), 4) if sims else None
|
||
|
||
@property
|
||
def avg_cosine_sim(self) -> float | None:
|
||
sims = [1.0 - r.get("cosine_distance_1", 1.0) for r in self.retrieved if r.get("cosine_distance_1") is not None]
|
||
return round(sum(sims) / len(sims), 4) if sims else None
|
||
|
||
@property
|
||
def is_empty(self) -> bool:
|
||
return len(self.retrieved) == 0
|
||
|
||
@property
|
||
def retrieved_file_ids(self) -> list[str]:
|
||
return list({r.get("file_id", "") for r in self.retrieved if r.get("file_id")})
|
||
|
||
@property
|
||
def retrieved_chunk_ids(self) -> list[str]:
|
||
"""获取召回的所有切片ID"""
|
||
chunk_ids = []
|
||
for r in self.retrieved:
|
||
chunk_id = r.get("knowledge_md_header_split_id") or r.get("id") or r.get("chunk_id")
|
||
if chunk_id:
|
||
chunk_ids.append(chunk_id)
|
||
return chunk_ids
|
||
|
||
@property
|
||
def is_chunk_hit(self) -> bool:
|
||
"""检查期望切片是否在召回结果的前hit_top_k个结果中"""
|
||
if not self.expected_chunk_id:
|
||
return False
|
||
return self.expected_chunk_id in self.retrieved_chunk_ids[:self.hit_top_k]
|
||
|
||
@property
|
||
def chunk_hit_rank(self) -> int | None:
|
||
"""返回期望切片在召回结果中的排名(1-based),未命中返回None
|
||
|
||
只在hit_top_k范围内查找,超出范围视为未命中
|
||
"""
|
||
if not self.expected_chunk_id:
|
||
return None
|
||
try:
|
||
idx = self.retrieved_chunk_ids[:self.hit_top_k].index(self.expected_chunk_id)
|
||
return idx + 1
|
||
except ValueError:
|
||
return None
|
||
|
||
@property
|
||
def is_file_hit(self) -> bool:
|
||
"""检查期望文件是否在召回结果的前hit_top_k个结果中"""
|
||
if not self.file_id:
|
||
return False
|
||
# 获取前hit_top_k个结果的file_ids
|
||
top_file_ids = []
|
||
for r in self.retrieved[:self.hit_top_k]:
|
||
fid = r.get("file_id")
|
||
if fid:
|
||
top_file_ids.append(fid)
|
||
return self.file_id in top_file_ids
|
||
|
||
|
||
class RecallTester:
    """Run recall requests against a dagent deployment and collect per-question results.

    Two recall paths are supported: the knowledge-hub semantic search endpoint
    (default) and the agent chat SSE endpoint (used when an ``agent_id`` is given).
    """

    def __init__(self, env_url: str, org_id: str, d_user_id: str = "test"):
        """Store the base URL and the headers sent with every request.

        Args:
            env_url: Base URL of the environment; trailing slashes are stripped.
            org_id: Organisation id, sent as the ``org-id`` header and in search payloads.
            d_user_id: Value of the ``d-user-id`` header.
        """
        self.env_url = env_url.rstrip("/")
        self.org_id = org_id
        self.headers = {
            "Content-Type": "application/json",
            "d-user-id": d_user_id,
            "org-id": org_id,
        }

    async def _recall_one(
        self,
        session: aiohttp.ClientSession,
        question: str,
        file_id_list: list[str] | None,
        recall_top_k: int,  # top_k sent to the API; may be large to fetch everything
        agent_id: str = "",  # agent id used for the recall test
    ) -> tuple[list[dict], int]:
        """Recall chunks for one question.

        When ``agent_id`` is non-empty the request is delegated to
        ``_recall_via_agent``; otherwise the knowledge-hub semantic search
        endpoint is called with up to three attempts on timeout.

        Returns:
            ``(items, latency_ms)`` where ``items`` is the concatenation of
            ``standard_answer_results`` and
            ``related_knowledge_rerank_results_top`` and ``latency_ms`` is the
            duration of the successful attempt only (failed attempts and
            back-off sleeps are excluded).

        Raises:
            asyncio.TimeoutError: every attempt timed out.
            RuntimeError: the response body is not a JSON object or carries a
                non-zero business ``code``.
            Exception: any other error from the HTTP call is propagated unchanged.
        """
        # With an agent_id, recall goes through the agent chat API.
        if agent_id:
            return await self._recall_via_agent(session, question, agent_id, recall_top_k)

        # Otherwise call the knowledge-base search API directly.
        url = f"{self.env_url}/dagent/knowledge/hub/semantic_search_knowledge/detail"
        payload: dict = {
            "query": question,
            "org_id": self.org_id,
            "top_k": recall_top_k,
        }
        if file_id_list:
            payload["file_id_list"] = file_id_list

        # 60 second timeout per attempt; only timeouts are retried.
        max_retries = 3
        request_timeout = aiohttp.ClientTimeout(total=60)

        for attempt in range(1, max_retries + 1):
            # Restart the clock on every attempt so the reported latency is
            # that of the request which actually succeeded.
            start = time.monotonic()
            try:
                async with session.post(url, json=payload, timeout=request_timeout) as resp:
                    resp.raise_for_status()
                    data = await resp.json()
            except asyncio.TimeoutError:
                print(f"[DEBUG] Recall timeout (attempt {attempt}/{max_retries}) for: {question[:50]}...")
                if attempt == max_retries:
                    raise  # last attempt failed: surface the timeout
                await asyncio.sleep(2 ** (attempt - 1))  # exponential back-off: 1s, then 2s
            else:
                latency_ms = int((time.monotonic() - start) * 1000)
                break

        if not isinstance(data, dict):
            raise RuntimeError(f"API error: unexpected response type {type(data).__name__}")

        # Check the business error code returned by the API.
        code = data.get("code")
        if code is not None and code != 0:
            msg = data.get("msg", "Unknown error")
            raise RuntimeError(f"API error: code={code}, msg={msg}")

        result_data = data.get("data", {}) or {}
        standard = result_data.get("standard_answer_results") or []
        rerank_top = result_data.get("related_knowledge_rerank_results_top") or []
        all_items = standard + rerank_top

        # Debug output when nothing was recalled.
        if not all_items:
            print(f"[DEBUG] Empty/No results for question: {question[:50]}...")
            print(f"[DEBUG] Response code: {data.get('code')}, msg: {data.get('msg')}")
            print(f"[DEBUG] org_id used: {self.org_id}")
            print(f"[DEBUG] Request payload: {payload}")
            print(f"[DEBUG] Response data keys: {list(data.keys())}")
            if result_data:
                print(f"[DEBUG] result_data keys: {list(result_data.keys())}")
            print(f"[DEBUG] No recall results for: {question[:50]}... (standard={len(standard)}, rerank={len(rerank_top)})")

        return all_items, latency_ms

    @staticmethod
    def _tool_end_items(line: str) -> list[dict]:
        """Return the dict items carried by one SSE line, or ``[]`` if it is not a TOOL_END event."""
        line = line.rstrip("\r")
        if not line.startswith("data:"):
            return []
        data_str = line[5:].strip()
        if not data_str or data_str == "[DONE]":
            return []
        try:
            chunk = json.loads(data_str)
        except json.JSONDecodeError:
            return []
        # A JSON scalar or array is not an event envelope; skip it rather than
        # letting ``.get`` raise and abort the whole stream.
        if not isinstance(chunk, dict):
            return []
        if chunk.get("message_type") != "EVENT" or chunk.get("is_chunk_data"):
            return []

        event_data = chunk.get("data")
        if isinstance(event_data, str):
            # The event body may itself be JSON encoded as a string.
            try:
                event_data = json.loads(event_data)
            except json.JSONDecodeError:
                return []
        if not isinstance(event_data, dict):
            return []
        if event_data.get("event_name") != "TOOL_END":
            return []

        tool_event_data = event_data.get("event_data")
        if not isinstance(tool_event_data, dict):
            return []
        reference_items = tool_event_data.get("items") or []
        if not isinstance(reference_items, list):
            return []
        return [item for item in reference_items if isinstance(item, dict)]

    @staticmethod
    def _normalize_agent_item(item: dict) -> dict | None:
        """Map one agent reference item onto the key layout used for recalled chunks.

        Returns None for entries that have neither a file id nor a chunk id
        (external-link style entries carrying only file_name + url).
        """
        file_id = str(item.get("file_id") or "")
        chunk_id = str(
            item.get("paragraph_chunk_id")
            or item.get("knowledge_md_header_split_id")
            or ""
        )
        if not file_id and not chunk_id:
            return None
        return {
            "file_id": file_id,
            "file_name": "",
            "headers": str(item.get("headers") or ""),
            "content": item.get("active_paragraph_context")
            or item.get("active_context") or "",
            "knowledge_md_header_split_id": chunk_id,
            "id": chunk_id,
            "paragraph_md5": str(item.get("paragraph_md5") or ""),
            "cosine_distance_1": None,
        }

    async def _recall_via_agent(
        self,
        session: aiohttp.ClientSession,
        question: str,
        agent_id: str,
        recall_top_k: int,
    ) -> tuple[list[dict], int]:
        """Fetch recall results through the agent chat SSE endpoint.

        Parsing strategy:
        - The stream is split on single ``\\n`` bytes (the code does not rely
          on blank-line event separators); a final line without a trailing
          newline is processed too.
        - Every event whose ``event_name`` is ``"TOOL_END"`` contributes the
          chunks in ``event_data.items``.
        - Items from all TOOL_END events are accumulated and de-duplicated by
          ``(file_id, chunk_id)``.
        - First-seen order is kept and acts as a pseudo rank for hit statistics.

        Returns:
            ``(items[:recall_top_k], latency_ms)``.

        Raises:
            Exception: the request or stream failed before any item was
                collected. If items were already collected, the error is only
                logged and the partial list is returned.
        """
        import uuid

        payload = {
            "chat_id": uuid.uuid4().hex,
            "task": question,
            "agent_id": agent_id,
            "llm_type": "deepseek_v3",
        }

        start = time.monotonic()
        items: list[dict] = []
        seen: set[tuple[str, str]] = set()

        def _collect(raw_line: bytes) -> None:
            # Decoding per line is safe: the newline byte never occurs inside
            # a multi-byte UTF-8 sequence.
            line = raw_line.decode("utf-8", errors="replace")
            for reference in self._tool_end_items(line):
                normalized = self._normalize_agent_item(reference)
                if normalized is None:
                    continue
                key = (normalized["file_id"], normalized["id"])
                if key in seen:
                    continue
                seen.add(key)
                items.append(normalized)

        try:
            async with session.post(
                f"{self.env_url}/dagent/agent/chat",
                json=payload,
                headers={"Accept": "text/event-stream"},
                timeout=aiohttp.ClientTimeout(total=300),
            ) as resp:
                resp.raise_for_status()
                pending = b""
                # iter_any() yields whatever bytes are available. Iterating the
                # stream directly reads line by line and fails with
                # "Chunk too big" on very long SSE lines.
                async for raw in resp.content.iter_any():
                    pending += raw
                    *complete_lines, pending = pending.split(b"\n")
                    for raw_line in complete_lines:
                        _collect(raw_line)
                if pending:
                    _collect(pending)
        except Exception as e:
            print(f"[DEBUG] Agent recall error: {e}")
            # Keep whatever was gathered before the failure. With nothing
            # gathered, propagate so the failure is not mistaken for an
            # empty recall.
            if not items:
                raise

        latency_ms = int((time.monotonic() - start) * 1000)
        return items[:recall_top_k], latency_ms

    async def run(
        self,
        sections: list[Section],
        file_map: dict[str, dict | None],
        top_k: int = 5,  # top_k threshold used for judging hits
        recall_top_k: int = 100,  # top_k sent to the API; defaults to 100 to fetch more results
        concurrency: int = 20,  # maximum number of in-flight requests
        cross_chunk: bool = False,  # kept for backward compatibility; no longer affects the search scope
        result_cb=None,
        progress_cb=None,  # kept for backward compatibility
        chunk_map: dict[str, str] | None = None,  # question -> expected_chunk_id
        agent_id: str = "",  # agent id used for the recall test
    ) -> list[RecallResult]:
        """Run the recall test for every QA pair in ``sections``.

        Args:
            sections: Sections whose ``qa_pairs`` are tested.
            file_map: ``section_path`` -> mapping with ``file_id`` and
                ``match_type`` keys, or None when the section has no match.
            top_k: Stored on each result as both ``top_k`` and ``hit_top_k``.
            recall_top_k: Number of results requested per question.
            concurrency: Upper bound on concurrent requests; values below 1
                are treated as 1.
            cross_chunk: Unused.
            result_cb: Optional ``async (result, done, total)`` callback
                awaited after each question.
            progress_cb: Optional ``async (done, total)`` callback, awaited
                every 10 results and at the end; only used when ``result_cb``
                is not given.
            chunk_map: Fallback lookup for the expected chunk id, keyed by
                question text.
            agent_id: When non-empty, recall goes through the agent chat API.

        Returns:
            One ``RecallResult`` per QA pair, in completion order. A failed
            recall is reported through ``RecallResult.error`` rather than raised.
        """
        results: list[RecallResult] = []
        # Semaphore(0) would block every request forever, so clamp to 1.
        sem = asyncio.Semaphore(max(1, concurrency))
        total = sum(len(s.qa_pairs) for s in sections)
        done = 0

        async with aiohttp.ClientSession(headers=self.headers) as session:
            async def _test_one(section: Section, qa: QAPair) -> RecallResult:
                nonlocal done
                mapping = file_map.get(section.section_path)
                file_id = mapping["file_id"] if mapping else None
                match_type = mapping["match_type"] if mapping else "unmatched"

                # Prefer the chunk id already attached to the QAPair, then fall back to chunk_map.
                expected_chunk_id = qa.expected_chunk_id or (
                    chunk_map.get(qa.question) if chunk_map else None
                )

                result = RecallResult(
                    section_path=section.section_path,
                    doc_name=section.doc_name,
                    file_id=file_id,
                    match_type=match_type,
                    qid=qa.qid,
                    question=qa.question,
                    reference_answer=qa.answer,
                    top_k=top_k,
                    hit_top_k=top_k,  # threshold used for judging hits
                    expected_chunk_id=expected_chunk_id,
                    raw_chunk_headers=section.raw_chunk_headers,
                )

                # Always search the whole knowledge base (no file_id_list);
                # chunk hit is the primary metric. A large recall_top_k is
                # used to fetch all recall results.
                async with sem:
                    try:
                        chunks, latency = await self._recall_one(
                            session, qa.question, None, recall_top_k, agent_id
                        )
                        result.retrieved = chunks
                        result.latency_ms = latency
                        if len(chunks) == 0:
                            print(f"[DEBUG] Empty recall for question: {qa.question[:60]}... (section: {section.section_path[:40]}...)")
                    except Exception as e:
                        # One failing question must not abort the run; record it on the result.
                        result.error = str(e)
                        print(f"[DEBUG] Recall error for question: {qa.question[:60]}... Error: {e}")

                done += 1
                if result_cb:
                    await result_cb(result, done, total)
                elif progress_cb and (done % 10 == 0 or done == total):
                    await progress_cb(done, total)
                return result

            tasks = [
                _test_one(section, qa)
                for section in sections
                for qa in section.qa_pairs
            ]
            for finished in asyncio.as_completed(tasks):
                results.append(await finished)

        return results
|