# Code-viewer listing metadata (not part of the program): 258 lines, 10 KiB, Python
import asyncio
import logging
import uuid
from dataclasses import dataclass
from typing import Callable

from .adapters.base import RAGAdapter
from .dataset.schema import EvalDataset, EvalSample
from .evaluators.retrieval import hit_rate, mrr, ndcg
from .judge.base import LLMJudge
from .report import EvalReport, SampleResult


# Metric keys produced by the retrieval step: rule-based (hit_rate, mrr, ndcg)
# and LLM-judged (context_precision, context_recall).
RETRIEVAL_METRIC_KEYS = {"hit_rate", "mrr", "ndcg", "context_precision", "context_recall"}

# Metric keys produced by judging the generated answer.
GENERATION_METRIC_KEYS = {"faithfulness", "answer_relevance", "answer_correctness", "groundedness"}


@dataclass
|
|
class RunConfig:
|
|
agent_id: str
|
|
knowledge_hub_id: str
|
|
top_k: int = 10
|
|
eval_retrieval: bool = True
|
|
eval_generation: bool = True
|
|
selected_metrics: list[str] | None = None
|
|
file_id_list: list[str] | None = None
|
|
concurrency: int = 3 # 并发评测样本数
|
|
faithfulness_threshold: float = 0.7 # 低于此值视为幻觉
|
|
|
|
def should_eval(self, metric_key: str) -> bool:
|
|
"""判断是否需要计算某个指标"""
|
|
if self.selected_metrics:
|
|
return metric_key in self.selected_metrics
|
|
# 向后兼容:未指定 selected_metrics 时按 eval_retrieval/eval_generation 开关
|
|
if metric_key in RETRIEVAL_METRIC_KEYS:
|
|
return self.eval_retrieval
|
|
if metric_key in GENERATION_METRIC_KEYS:
|
|
return self.eval_generation
|
|
return True
|
|
|
|
@property
|
|
def need_retrieval(self) -> bool:
|
|
if self.selected_metrics:
|
|
return bool(set(self.selected_metrics) & RETRIEVAL_METRIC_KEYS)
|
|
return self.eval_retrieval
|
|
|
|
@property
|
|
def need_generation(self) -> bool:
|
|
if self.selected_metrics:
|
|
return bool(set(self.selected_metrics) & GENERATION_METRIC_KEYS)
|
|
return self.eval_generation
|
|
|
|
|
|
class EvalRunner:
    """Evaluate dataset samples against a RAG adapter and aggregate a report.

    Retrieval and chat calls go through ``adapter``; LLM-as-judge scores come
    from ``judge``.
    """

    def __init__(self, adapter: RAGAdapter, judge: LLMJudge):
        self.adapter = adapter
        self.judge = judge

    async def run(
        self,
        dataset: EvalDataset | str,
        config: RunConfig,
        progress_cb: Callable[[int, int], None] | None = None,
    ) -> EvalReport:
        """Run the full evaluation pipeline.

        Args:
            dataset: An EvalDataset object, or the path of a JSON file to load one from.
            config: Evaluation settings.
            progress_cb: Optional progress callback, called as
                ``progress_cb(finished, total)`` after each sample completes.

        Returns:
            The aggregated report. ``results`` follow the order of
            ``dataset.samples`` regardless of which sample finished first.
        """
        if isinstance(dataset, str):
            import json

            with open(dataset, encoding="utf-8") as f:
                dataset = EvalDataset.from_dict(json.load(f))

        samples = dataset.samples
        total = len(samples)

        # Semaphore(0) would park every task forever and a negative value
        # raises ValueError, so always allow at least one sample in flight.
        sem = asyncio.Semaphore(max(1, config.concurrency))

        async def _eval_one(index: int, sample: EvalSample) -> tuple[int, SampleResult]:
            async with sem:
                return index, await self._eval_sample(sample, config)

        # Slots are filled by dataset index so the report order is deterministic
        # even though samples complete in arbitrary order.
        ordered: list[SampleResult | None] = [None] * total
        finished = 0
        pending = [_eval_one(i, s) for i, s in enumerate(samples)]
        for coro in asyncio.as_completed(pending):
            index, result = await coro
            ordered[index] = result
            finished += 1
            if progress_cb:
                progress_cb(finished, total)

        return self._build_report(
            task_id=uuid.uuid4().hex,
            dataset=dataset,
            results=ordered,
            config=config,
        )

    async def _retrieve_into(self, result: SampleResult, sample: EvalSample, config: RunConfig) -> None:
        """Retrieve chunks for the sample's question and store ids and contents on ``result``."""
        chunks = await self.adapter.retrieve(
            query=sample.question,
            knowledge_hub_id=config.knowledge_hub_id,
            top_k=config.top_k,
            file_id_list=config.file_id_list,
        )
        result.retrieved_chunk_ids = [c.chunk_id for c in chunks]
        result.retrieved_chunks = [c.content for c in chunks]

    async def _eval_sample(self, sample: EvalSample, config: RunConfig) -> SampleResult:
        """Evaluate one sample; failures are recorded on ``result.error`` instead of raised."""
        result = SampleResult(
            sample_id=sample.id,
            question=sample.question,
            reference_answer=sample.reference_answer,
        )
        try:
            # ── Step 1: Retrieval ─────────────────────────────────────────
            if config.need_retrieval:
                await self._retrieve_into(result, sample, config)

                # Rule-based metrics need ground-truth chunk ids.
                if sample.relevant_chunk_ids:
                    if config.should_eval("hit_rate"):
                        result.hit_rate = hit_rate(result.retrieved_chunk_ids, sample.relevant_chunk_ids)
                    if config.should_eval("mrr"):
                        result.mrr = mrr(result.retrieved_chunk_ids, sample.relevant_chunk_ids)
                    if config.should_eval("ndcg"):
                        result.ndcg = ndcg(result.retrieved_chunk_ids, sample.relevant_chunk_ids, k=config.top_k)

                # LLM-as-judge retrieval metrics need a reference answer and chunks.
                if sample.reference_answer and result.retrieved_chunks:
                    if config.should_eval("context_precision"):
                        cp, raw_cp = await self.judge.score_context_precision(
                            sample.question, sample.reference_answer, result.retrieved_chunks
                        )
                        result.context_precision = cp
                        result.judge_detail["context_precision"] = raw_cp

                    if config.should_eval("context_recall"):
                        cr, raw_cr = await self.judge.score_context_recall(
                            sample.reference_answer, result.retrieved_chunks
                        )
                        result.context_recall = cr
                        result.judge_detail["context_recall"] = raw_cr

            # ── Step 2: Generation ────────────────────────────────────────
            if config.need_generation:
                agent_resp = await self.adapter.chat(
                    query=sample.question,
                    agent_id=config.agent_id,
                )
                result.agent_answer = agent_resp.answer
                result.latency_ms = agent_resp.latency_ms

                # If no chunks are available (e.g. the retrieval step was
                # skipped), retrieve once so generation metrics can be judged.
                if not result.retrieved_chunks:
                    try:
                        await self._retrieve_into(result, sample, config)
                    except Exception:
                        # Best effort: chunk-dependent metrics are simply
                        # skipped, but the failure is no longer invisible.
                        logging.getLogger(__name__).warning(
                            "Fallback retrieval failed for sample %s; "
                            "chunk-dependent generation metrics will be skipped",
                            sample.id,
                            exc_info=True,
                        )

                # NOTE(review): the source this was reconstructed from had lost
                # its indentation; all four generation metrics are assumed to
                # sit under this guard — confirm against the original file.
                if result.agent_answer and result.retrieved_chunks:
                    if config.should_eval("faithfulness"):
                        faith, raw_faith = await self.judge.score_faithfulness(
                            result.agent_answer, result.retrieved_chunks
                        )
                        result.faithfulness = faith
                        result.judge_detail["faithfulness"] = raw_faith

                    if config.should_eval("answer_relevance"):
                        rel, raw_rel = await self.judge.score_relevance(
                            sample.question, result.agent_answer
                        )
                        result.answer_relevance = rel
                        result.judge_detail["answer_relevance"] = raw_rel

                    if config.should_eval("groundedness"):
                        ground, raw_ground = await self.judge.score_groundedness(
                            result.agent_answer,
                            [{"content": c} for c in result.retrieved_chunks],
                        )
                        result.groundedness = ground
                        result.judge_detail["groundedness"] = raw_ground

                    if config.should_eval("answer_correctness") and sample.reference_answer:
                        corr, raw_corr = await self.judge.score_correctness(
                            result.agent_answer, sample.reference_answer
                        )
                        result.answer_correctness = corr
                        result.judge_detail["answer_correctness"] = raw_corr

        except Exception as exc:
            # Some exceptions (e.g. TimeoutError()) stringify to "", which would
            # make the failure look like success; fall back to repr in that case.
            result.error = str(exc) or repr(exc)

        return result

    def _build_report(
        self,
        task_id: str,
        dataset: EvalDataset,
        results: list[SampleResult],
        config: RunConfig,
    ) -> EvalReport:
        """Aggregate per-sample results into an EvalReport.

        Each average covers only the samples where that metric was computed
        (non-None) and is rounded to 4 decimals; a metric with no values
        averages to None.
        """

        def _collect(attr: str) -> list[float]:
            values = (getattr(r, attr) for r in results)
            return [v for v in values if v is not None]

        def _avg(vals: list[float]) -> float | None:
            return round(sum(vals) / len(vals), 4) if vals else None

        avg_hit_rate = _avg(_collect("hit_rate"))
        avg_mrr = _avg(_collect("mrr"))
        avg_ndcg = _avg(_collect("ndcg"))
        avg_ctx_prec = _avg(_collect("context_precision"))
        avg_ctx_rec = _avg(_collect("context_recall"))
        faith_vals = _collect("faithfulness")
        avg_faithfulness = _avg(faith_vals)
        avg_answer_relevance = _avg(_collect("answer_relevance"))
        avg_answer_correctness = _avg(_collect("answer_correctness"))
        avg_groundedness = _avg(_collect("groundedness"))

        # RAG Score: harmonic mean of the available core metrics. A core metric
        # that averages 0 forces the score to 0 rather than being dropped,
        # which would otherwise inflate the score.
        core = [
            s
            for s in (avg_faithfulness, avg_answer_relevance, avg_ctx_prec, avg_ctx_rec)
            if s is not None
        ]
        if not core:
            rag_score = None
        elif any(s <= 0 for s in core):
            rag_score = 0.0
        else:
            rag_score = round(len(core) / sum(1 / s for s in core), 4)

        # Hallucination rate: share of scored samples whose faithfulness is
        # below the configured threshold.
        hallucination_rate = (
            round(sum(1 for f in faith_vals if f < config.faithfulness_threshold) / len(faith_vals), 4)
            if faith_vals
            else None
        )

        return EvalReport(
            task_id=task_id,
            dataset_name=dataset.name,
            sample_count=len(results),
            results=results,
            avg_hit_rate=avg_hit_rate,
            avg_mrr=avg_mrr,
            avg_ndcg=avg_ndcg,
            avg_context_precision=avg_ctx_prec,
            avg_context_recall=avg_ctx_rec,
            avg_faithfulness=avg_faithfulness,
            avg_answer_relevance=avg_answer_relevance,
            avg_answer_correctness=avg_answer_correctness,
            avg_groundedness=avg_groundedness,
            rag_score=rag_score,
            hallucination_rate=hallucination_rate,
        )