359 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
召回测试执行器:对每条问答对调用 dagent 语义召回接口,记录结果。
"""
import asyncio
import json
import sys
import time
from dataclasses import dataclass, field
import aiohttp
# Fix Windows GBK encoding issue: force UTF-8 (errors='replace') on
# stdout/stderr so the Chinese/debug output below cannot raise
# UnicodeEncodeError on consoles with a legacy default encoding.
# Streams that are None (e.g. pythonw) or do not implement reconfigure()
# (e.g. StringIO redirects or custom wrappers) are left untouched instead
# of making this module fail at import time.
def _force_utf8_stdio(*streams) -> None:
    """Reconfigure each given text stream to UTF-8 when it supports reconfigure()."""
    for stream in streams:
        reconfigure = getattr(stream, "reconfigure", None)
        if reconfigure is not None:
            reconfigure(encoding="utf-8", errors="replace")


_force_utf8_stdio(sys.stdout, sys.stderr)
from .parser import Section, QAPair
@dataclass
class RecallResult:
section_path: str
doc_name: str
file_id: str | None
match_type: str | None # exact / contains / fuzzy / unmatched
qid: str
question: str
reference_answer: str
top_k: int # 用于判断命中的top_k值
hit_top_k: int # 用于判断切片是否命中的top_k阈值可能不同于召回时的top_k
retrieved: list[dict] = field(default_factory=list) # 召回的切片列表(全部,不截断)
latency_ms: int = 0
error: str | None = None
expected_chunk_id: str | None = None # 期望命中的切片ID
raw_chunk_headers: str | None = None # 原始切片标题(从元数据解析)
# 计算属性
@property
def best_cosine_sim(self) -> float | None:
sims = [1.0 - r.get("cosine_distance_1", 1.0) for r in self.retrieved if r.get("cosine_distance_1") is not None]
return round(max(sims), 4) if sims else None
@property
def avg_cosine_sim(self) -> float | None:
sims = [1.0 - r.get("cosine_distance_1", 1.0) for r in self.retrieved if r.get("cosine_distance_1") is not None]
return round(sum(sims) / len(sims), 4) if sims else None
@property
def is_empty(self) -> bool:
return len(self.retrieved) == 0
@property
def retrieved_file_ids(self) -> list[str]:
return list({r.get("file_id", "") for r in self.retrieved if r.get("file_id")})
@property
def retrieved_chunk_ids(self) -> list[str]:
"""获取召回的所有切片ID"""
chunk_ids = []
for r in self.retrieved:
chunk_id = r.get("knowledge_md_header_split_id") or r.get("id") or r.get("chunk_id")
if chunk_id:
chunk_ids.append(chunk_id)
return chunk_ids
@property
def is_chunk_hit(self) -> bool:
"""检查期望切片是否在召回结果的前hit_top_k个结果中"""
if not self.expected_chunk_id:
return False
return self.expected_chunk_id in self.retrieved_chunk_ids[:self.hit_top_k]
@property
def chunk_hit_rank(self) -> int | None:
"""返回期望切片在召回结果中的排名1-based未命中返回None
只在hit_top_k范围内查找超出范围视为未命中
"""
if not self.expected_chunk_id:
return None
try:
idx = self.retrieved_chunk_ids[:self.hit_top_k].index(self.expected_chunk_id)
return idx + 1
except ValueError:
return None
@property
def is_file_hit(self) -> bool:
"""检查期望文件是否在召回结果的前hit_top_k个结果中"""
if not self.file_id:
return False
# 获取前hit_top_k个结果的file_ids
top_file_ids = []
for r in self.retrieved[:self.hit_top_k]:
fid = r.get("file_id")
if fid:
top_file_ids.append(fid)
return self.file_id in top_file_ids
class RecallTester:
    """Run recall queries against a dagent environment and build RecallResults.

    Two recall back ends are used:
    - the knowledge-hub semantic search endpoint (default), and
    - the agent chat SSE endpoint, when ``agent_id`` is non-empty.
    """

    # Semantic search: number of attempts per question (only timeouts are
    # retried) and the total timeout of each attempt, in seconds.
    search_max_retries: int = 3
    search_timeout_s: float = 60
    # Agent chat: total timeout for the whole SSE stream, in seconds.
    agent_timeout_s: float = 300

    def __init__(self, env_url: str, org_id: str, d_user_id: str = "test"):
        self.env_url = env_url.rstrip("/")
        self.org_id = org_id
        # Default headers of the aiohttp session created in run().
        self.headers = {
            "Content-Type": "application/json",
            "d-user-id": d_user_id,
            "org-id": org_id,
        }

    async def _recall_one(
        self,
        session: aiohttp.ClientSession,
        question: str,
        file_id_list: list[str] | None,
        recall_top_k: int,  # top_k sent to the API; may be large to fetch all results
        agent_id: str = "",  # agent ID used for recall testing
    ) -> tuple[list[dict], int]:
        """Recall chunks for one question.

        Args:
            session: open aiohttp session.
            question: query text.
            file_id_list: optional list of file IDs to restrict the search to.
            recall_top_k: top_k requested from the API.
            agent_id: when non-empty, recall through the agent chat SSE
                endpoint instead of the semantic search endpoint.

        Returns:
            ``(items, latency_ms)``: standard-answer results followed by
            rerank results, and wall time in ms (including failed attempts
            and back-off sleeps).

        Raises:
            asyncio.TimeoutError: every attempt timed out.
            aiohttp.ClientResponseError: non-2xx HTTP status.
            Exception: unexpected response body or non-zero business ``code``.
        """
        if agent_id:
            return await self._recall_via_agent(session, question, agent_id, recall_top_k)

        # Otherwise call the knowledge-base search API directly.
        url = f"{self.env_url}/dagent/knowledge/hub/semantic_search_knowledge/detail"
        payload: dict = {
            "query": question,
            "org_id": self.org_id,
            "top_k": recall_top_k,
        }
        if file_id_list:
            payload["file_id_list"] = file_id_list

        start = time.monotonic()
        max_retries = max(1, self.search_max_retries)
        timeout = aiohttp.ClientTimeout(total=self.search_timeout_s)
        for attempt in range(max_retries):
            try:
                async with session.post(url, json=payload, timeout=timeout) as resp:
                    resp.raise_for_status()
                    data = await resp.json()
                break  # success: stop retrying
            except asyncio.TimeoutError:
                print(f"[DEBUG] Recall timeout (attempt {attempt+1}/{max_retries}) for: {question[:50]}...")
                if attempt == max_retries - 1:
                    raise  # out of attempts
                # Exponential back-off before the next attempt: 1s, 2s, ...
                await asyncio.sleep(2 ** attempt)
            # Any other exception propagates immediately (not retried).
        latency_ms = int((time.monotonic() - start) * 1000)

        if not isinstance(data, dict):
            raise Exception(f"API error: unexpected response body type {type(data).__name__}")

        # Business-level error code returned by the API.
        code = data.get("code")
        if code is not None and code != 0:
            msg = data.get("msg", "Unknown error")
            raise Exception(f"API error: code={code}, msg={msg}")

        result_data = data.get("data", {}) or {}
        # Debug: dump request/response details when nothing came back.
        if not result_data or (not result_data.get("standard_answer_results") and not result_data.get("related_knowledge_rerank_results_top")):
            print(f"[DEBUG] Empty/No results for question: {question[:50]}...")
            print(f"[DEBUG] Response code: {data.get('code')}, msg: {data.get('msg')}")
            print(f"[DEBUG] org_id used: {self.org_id}")
            print(f"[DEBUG] Request payload: {payload}")
            print(f"[DEBUG] Response data keys: {list(data.keys())}")
            if result_data:
                print(f"[DEBUG] result_data keys: {list(result_data.keys())}")

        standard = result_data.get("standard_answer_results") or []
        rerank_top = result_data.get("related_knowledge_rerank_results_top") or []
        all_items = standard + rerank_top
        if not all_items:
            print(f"[DEBUG] No recall results for: {question[:50]}... (standard={len(standard)}, rerank={len(rerank_top)})")
        return all_items, latency_ms

    async def _recall_via_agent(
        self,
        session: aiohttp.ClientSession,
        question: str,
        agent_id: str,
        recall_top_k: int,
    ) -> tuple[list[dict], int]:
        """Recall chunks for one question through the agent chat SSE endpoint.

        Parsing strategy:
        - SSE lines are separated by a single newline (not a blank line).
        - Each EVENT whose event_name is "TOOL_END" carries a batch of chunks
          in event_data.items.
        - The agent may call tools several times; every TOOL_END batch is
          accumulated and de-duplicated on (file_id, chunk_id).
        - First-seen order is kept and serves as a pseudo rank for hit
          statistics.

        The body is read in raw chunks (``iter_any``) and split into lines
        here. This means a very large event line cannot hit aiohttp's
        per-line size limit. A final line without a trailing newline is
        still parsed.

        Errors: if the request fails before any chunk was collected, the
        exception is re-raised, so run() records it as the result's error.
        If the stream breaks after some chunks arrived, the partial list is
        returned.

        Returns:
            ``(items[:recall_top_k], latency_ms)``.
        """
        import uuid

        payload = {
            "chat_id": uuid.uuid4().hex,
            "task": question,
            "agent_id": agent_id,
            "llm_type": "deepseek_v3",
        }
        start = time.monotonic()
        items: list[dict] = []
        seen: set[tuple[str, str]] = set()
        try:
            async with session.post(
                f"{self.env_url}/dagent/agent/chat",
                json=payload,
                headers={"Accept": "text/event-stream"},
                timeout=aiohttp.ClientTimeout(total=self.agent_timeout_s),
            ) as resp:
                resp.raise_for_status()
                # Pieces of the current line that is not yet newline-terminated.
                fragments: list[bytes] = []
                async for raw in resp.content.iter_any():
                    *complete, rest = raw.split(b"\n")
                    for piece in complete:
                        fragments.append(piece)
                        self._collect_sse_line(b"".join(fragments), items, seen)
                        fragments.clear()
                    if rest:
                        fragments.append(rest)
                if fragments:
                    self._collect_sse_line(b"".join(fragments), items, seen)
        except Exception as e:
            print(f"[DEBUG] Agent recall error: {e}")
            if not items:
                raise
        latency_ms = int((time.monotonic() - start) * 1000)
        return items[:recall_top_k], latency_ms

    @staticmethod
    def _collect_sse_line(raw_line: bytes, items: list[dict], seen: set[tuple[str, str]]) -> None:
        """Parse one SSE line and append any new TOOL_END chunks to ``items``.

        The following are ignored:
        - non-``data:`` lines and ``[DONE]``;
        - malformed or non-object JSON;
        - chunk-data messages and events other than TOOL_END.

        Items are de-duplicated on ``(file_id, chunk_id)`` through ``seen``,
        which is updated in place.
        """
        line = raw_line.decode("utf-8", errors="replace").rstrip("\r")
        if not line.startswith("data:"):
            return
        data_str = line[5:].strip()
        if not data_str or data_str == "[DONE]":
            return
        try:
            chunk = json.loads(data_str)
        except json.JSONDecodeError:
            return
        if not isinstance(chunk, dict):
            return
        if chunk.get("message_type") != "EVENT" or chunk.get("is_chunk_data"):
            return

        # "data" may be a JSON-encoded string or an already-decoded object.
        event_data = chunk.get("data")
        if isinstance(event_data, str):
            try:
                event_data = json.loads(event_data)
            except json.JSONDecodeError:
                return
        if not isinstance(event_data, dict) or event_data.get("event_name") != "TOOL_END":
            return
        tool_event_data = event_data.get("event_data")
        if not isinstance(tool_event_data, dict):
            return
        reference_items = tool_event_data.get("items") or []
        if not isinstance(reference_items, list):
            return

        for item in reference_items:
            if not isinstance(item, dict):
                continue
            file_id = str(item.get("file_id") or "")
            chunk_id = str(
                item.get("paragraph_chunk_id")
                or item.get("knowledge_md_header_split_id")
                or ""
            )
            # Skip external-link style entries that carry neither file_id nor
            # chunk_id (only file_name + url).
            if not file_id and not chunk_id:
                continue
            key = (file_id, chunk_id)
            if key in seen:
                continue
            seen.add(key)
            items.append({
                "file_id": file_id,
                "file_name": "",
                "headers": str(item.get("headers") or ""),
                "content": item.get("active_paragraph_context")
                or item.get("active_context") or "",
                "knowledge_md_header_split_id": chunk_id,
                "id": chunk_id,
                "paragraph_md5": str(item.get("paragraph_md5") or ""),
                "cosine_distance_1": None,
            })

    async def run(
        self,
        sections: list[Section],
        file_map: dict[str, dict | None],
        top_k: int = 5,  # threshold for judging hits
        recall_top_k: int = 100,  # top_k sent to the API (default 100 to fetch more results)
        concurrency: int = 20,  # max concurrent recall requests
        cross_chunk: bool = False,  # kept for compatibility with old callers; no longer controls search scope
        result_cb=None,
        progress_cb=None,  # kept for compatibility with old callers
        chunk_map: dict[str, str] | None = None,  # question -> expected_chunk_id
        agent_id: str = "",  # agent ID used for recall testing
    ) -> list[RecallResult]:
        """Run recall for every QA pair of every section.

        Args:
            sections: sections whose ``qa_pairs`` are tested.
            file_map: section_path -> mapping dict (keys ``file_id`` and
                ``match_type``) or None for unmatched sections.
            top_k: stored on each result as both ``top_k`` and ``hit_top_k``.
            recall_top_k: top_k requested from the recall API.
            concurrency: max number of recall requests in flight (clamped to
                at least 1).
            cross_chunk: unused; kept for compatibility.
            result_cb: optional ``async (result, done, total)`` awaited after
                each QA pair finishes.
            progress_cb: optional ``async (done, total)`` awaited every 10
                pairs and at the end; only used when ``result_cb`` is None.
            chunk_map: question -> expected chunk ID, used when the QA pair
                carries none.
            agent_id: when non-empty, recall via the agent chat endpoint.

        Returns:
            One RecallResult per QA pair, in ``sections``/``qa_pairs`` order.
            Callbacks still fire in completion order. Recall failures are
            stored in ``RecallResult.error`` rather than raised.
        """
        # Semaphore(0) would block forever; always allow at least one request.
        sem = asyncio.Semaphore(max(1, concurrency))
        total = sum(len(s.qa_pairs) for s in sections)
        done = 0

        async with aiohttp.ClientSession(headers=self.headers) as session:

            async def _test_one(section: Section, qa: QAPair) -> RecallResult:
                nonlocal done
                mapping = file_map.get(section.section_path)
                file_id = mapping.get("file_id") if mapping else None
                match_type = mapping.get("match_type") if mapping else "unmatched"
                # Prefer the chunk_id already injected on the QAPair, then chunk_map.
                expected_chunk_id = qa.expected_chunk_id or (
                    chunk_map.get(qa.question) if chunk_map else None
                )
                result = RecallResult(
                    section_path=section.section_path,
                    doc_name=section.doc_name,
                    file_id=file_id,
                    match_type=match_type,
                    qid=qa.qid,
                    question=qa.question,
                    reference_answer=qa.answer,
                    top_k=top_k,
                    hit_top_k=top_k,
                    expected_chunk_id=expected_chunk_id,
                    raw_chunk_headers=section.raw_chunk_headers,
                )
                # Always search the whole knowledge base (no file_id_list);
                # chunk hits are the primary metric. A large recall_top_k
                # fetches as many results as possible.
                async with sem:
                    try:
                        chunks, latency = await self._recall_one(session, qa.question, None, recall_top_k, agent_id)
                        result.retrieved = chunks
                        result.latency_ms = latency
                        if not chunks:
                            print(f"[DEBUG] Empty recall for question: {qa.question[:60]}... (section: {section.section_path[:40]}...)")
                    except Exception as e:
                        # str() of e.g. asyncio.TimeoutError is "", which would
                        # look like "no error"; fall back to the exception type.
                        result.error = str(e) or type(e).__name__
                        print(f"[DEBUG] Recall error for question: {qa.question[:60]}... Error: {result.error}")
                done += 1
                if result_cb:
                    await result_cb(result, done, total)
                elif progress_cb and (done % 10 == 0 or done == total):
                    await progress_cb(done, total)
                return result

            # gather keeps results in input order (as_completed did not).
            return await asyncio.gather(
                *(_test_one(section, qa) for section in sections for qa in section.qa_pairs)
            )