dagent_eval/sdk/rag_eval/judge/openai_compatible.py

289 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
from openai import AsyncOpenAI
from .base import LLMJudge
# ── Prompts ───────────────────────────────────────────────────────────────────
# All templates below are filled with ``str.format``; doubled braces ``{{`` /
# ``}}`` are format escapes that render as literal JSON braces in the prompt.

# Faithfulness, step 1. Placeholder: {answer}.
# Asks the model to split the answer into atomic factual claims and to reply
# with a JSON array of strings only.
_DECOMPOSE_PROMPT = """请将以下回答分解为独立的原子声明列表,每条声明是一个不可再分的事实陈述。
回答:{answer}
只输出 JSON 数组,格式:["声明1", "声明2", ...]"""
# Faithfulness, step 2. Placeholders: {context}, {claim}.
# Asks whether the claim can be derived from the reference material; the model
# is told to answer only "yes" or "no".
_VERIFY_CLAIM_PROMPT = """参考资料:
{context}
声明:{claim}
上述声明是否可以从参考资料中推导出来?只回答 yes 或 no。"""
# Answer relevance. Placeholder: {answer}.
# Asks for 3 questions the answer might be responding to, as a JSON array.
_RELEVANCE_GEN_PROMPT = """基于以下回答,生成 3 个该回答可能在回答的问题。
回答:{answer}
只输出 JSON 数组,格式:["问题1", "问题2", "问题3"]"""
# Answer correctness. Placeholders: {reference}, {answer}.
# Requests a JSON object with keys "score" (0..1), "reason", and the fact
# counts "factual_tp" / "factual_fp" / "factual_fn".
_CORRECTNESS_PROMPT = """请评估以下回答与参考答案的事实一致程度。
参考答案:{reference}
待评估回答:{answer}
请从以下维度评估:
1. 事实一致性:回答中的事实与参考答案是否一致
2. 信息完整性:回答是否覆盖了参考答案的关键信息
3. 有无错误信息:回答是否包含参考答案中没有的错误内容
输出 JSON
{{"score": 0到1之间的小数, "reason": "简短理由", "factual_tp": 正确事实数, "factual_fp": 错误事实数, "factual_fn": 遗漏事实数}}"""
# Groundedness. Placeholders: {numbered_chunks}, {answer}.
# Requests {"claims": [{"text": ..., "source_chunk_index": <number or null>}]},
# where the index refers to the numbered chunk list inserted into the prompt.
_GROUNDEDNESS_PROMPT = """以下是检索到的切片列表(带编号):
{numbered_chunks}
AI 回答:{answer}
请将回答分解为原子声明,并为每条声明标注支撑它的切片编号(无支撑则填 null
输出 JSON{{"claims": [{{"text": "声明内容", "source_chunk_index": 1}}, {{"text": "声明内容", "source_chunk_index": null}}]}}"""
# Context precision. Placeholders: {question}, {ground_truth}, {chunks_text}.
# Requests {"results": [{"index": ..., "useful": true/false, "reason": ...}]},
# one verdict per retrieved chunk.
_CONTEXT_PRECISION_PROMPT = """问题:{question}
参考答案:{ground_truth}
以下是检索系统返回的文档片段列表:
{chunks_text}
请判断每个片段对于回答该问题是否有用。
输出 JSON{{"results": [{{"index": 1, "useful": true, "reason": "简短理由"}}]}}"""
# Context recall. Placeholders: {ground_truth}, {retrieved_context}.
# Requests {"statements": [{"text": ..., "supported": true/false}]}, i.e. the
# reference answer split into statements, each checked against the context.
_CONTEXT_RECALL_PROMPT = """参考答案:{ground_truth}
检索到的文档内容(合并):
{retrieved_context}
请将参考答案拆分为若干独立陈述,判断每个陈述是否能在检索文档中找到支撑。
输出 JSON{{"statements": [{{"text": "陈述内容", "supported": true}}]}}"""
class OpenAICompatibleJudge(LLMJudge):
    """LLM-as-judge backed by any OpenAI-protocol endpoint.

    Works with DeepSeek / Qwen / OpenAI / Azure OpenAI style APIs. The judging
    prompts are written in Chinese, which suits Chinese RAG scenarios.

    Every ``score_*`` coroutine returns ``(score, detail)``: ``score`` is a
    float in ``[0, 1]`` rounded to 4 decimals and ``detail`` is a dict holding
    the evidence behind it (``{}`` when the inputs are empty).
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str,
        embed_base_url: str = "",
        embed_api_key: str = "",
        embed_model: str = "text-embedding-3-small",
    ):
        """Create the chat client and the embedding client.

        Args:
            base_url: Chat endpoint; an empty string is passed to the SDK as
                ``None``, leaving endpoint selection to the SDK.
            api_key: API key for the chat endpoint.
            model: Chat model used for all judging prompts.
            embed_base_url: Embedding endpoint; falls back to ``base_url``.
            embed_api_key: Embedding API key; falls back to ``api_key``.
            embed_model: Embedding model used by ``score_relevance``.
        """
        self.client = AsyncOpenAI(
            base_url=base_url or None,
            api_key=api_key,
        )
        self.model = model
        # Separate embedding client: it may target a different endpoint than
        # the chat LLM.
        self.embed_client = AsyncOpenAI(
            base_url=embed_base_url or base_url or None,
            api_key=embed_api_key or api_key,
        )
        self.embed_model = embed_model

    # ── Payload helpers ──────────────────────────────────────────────────────
    @staticmethod
    def _as_dict(payload) -> dict:
        """Return *payload* if it is a dict, otherwise an empty dict."""
        return payload if isinstance(payload, dict) else {}

    @staticmethod
    def _as_list(payload, *keys: str) -> list:
        """Return *payload* as a list, unwrapping the first non-empty list under *keys*."""
        if isinstance(payload, list):
            return payload
        if isinstance(payload, dict):
            for key in keys:
                value = payload.get(key)
                if isinstance(value, list) and value:
                    return value
        return []

    @staticmethod
    def _as_bool(value) -> bool:
        """Interpret a model-produced flag; strings like ``"false"`` must not be truthy."""
        if isinstance(value, str):
            return value.strip().lower() in {"true", "yes", "1", "是"}
        return bool(value)

    @staticmethod
    def _as_index(value) -> int | None:
        """Coerce a model-produced index to ``int``; ``None`` when it is not a number."""
        if value is None or isinstance(value, bool):
            return None
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return None

    @staticmethod
    def _as_count(value) -> float:
        """Coerce a model-produced fact count to a finite, non-negative float."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return 0.0
        # The chained comparison also rejects NaN and infinity.
        return number if 0.0 < number < float("inf") else 0.0

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding markdown code fence (```json ... ``` or ``` ... ```)."""
        if not text.startswith("```"):
            return text
        # Drop the opening fence line and, if present, the closing fence line.
        lines = text.splitlines()[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        return "\n".join(lines).strip()

    @staticmethod
    def _extract_json(text: str):
        """Return the first JSON object/array embedded in *text*, or ``None``."""
        decoder = json.JSONDecoder()
        for pos, char in enumerate(text):
            if char not in "{[":
                continue
            try:
                # raw_decode stops at the end of the first complete value, so
                # trailing commentary (even with braces in it) is ignored.
                value, _ = decoder.raw_decode(text, pos)
            except json.JSONDecodeError:
                continue
            return value
        return None

    # ── LLM calls ────────────────────────────────────────────────────────────
    async def _call(self, prompt: str) -> str:
        """Send *prompt* as a single user message (temperature 0).

        Returns:
            The stripped text of the first choice, or ``""`` when the response
            carries no choices or no content.
        """
        resp = await self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        if not resp.choices:
            return ""
        return (resp.choices[0].message.content or "").strip()

    async def _call_json(self, prompt: str) -> dict | list:
        """Send *prompt* and parse the reply as JSON.

        Returns:
            The decoded dict or list. ``{}`` is returned when nothing parseable
            is found or when the decoded value is a scalar / ``null``.
        """
        raw = await self._call(prompt)
        try:
            parsed = json.loads(self._strip_code_fence(raw))
        except json.JSONDecodeError:
            # Fall back to the unstripped reply: a one-line fenced answer such
            # as "```json {...} ```" loses its payload when the fence line is
            # dropped.
            parsed = self._extract_json(raw)
        return parsed if isinstance(parsed, (dict, list)) else {}

    # ── Faithfulness (two-step) ──────────────────────────────────────────────
    async def score_faithfulness(self, answer: str, context: list[str]) -> tuple[float, dict]:
        """Score the share of the answer's atomic claims supported by *context*.

        Args:
            answer: The generated answer to evaluate.
            context: Retrieved passages; each is truncated to 800 characters.

        Returns:
            ``(score, {"claims": [{"text", "supported"}, ...]})``; ``(0.0, {})``
            when *answer* or *context* is empty.
        """
        if not answer or not context:
            return 0.0, {}
        # Step 1: decompose the answer into atomic claims.
        raw_claims = await self._call_json(_DECOMPOSE_PROMPT.format(answer=answer))
        claims = self._as_list(raw_claims, "items", "claims")
        if not claims:
            return 0.0, {"claims": []}
        context_text = "\n\n".join(c[:800] for c in context)

        # Step 2: verify every claim against the context, concurrently.
        async def _verify(claim: str) -> bool:
            verdict = await self._call(
                _VERIFY_CLAIM_PROMPT.format(context=context_text, claim=claim)
            )
            return "yes" in verdict.lower()

        results = await asyncio.gather(*(_verify(c) for c in claims))
        score = round(sum(results) / len(claims), 4)
        detail = {
            "claims": [
                {"text": claim, "supported": bool(ok)}
                for claim, ok in zip(claims, results)
            ]
        }
        return score, detail

    # ── Answer Relevance (reverse question generation + semantic similarity) ─
    async def score_relevance(self, question: str, answer: str) -> tuple[float, dict]:
        """Score how well *answer* addresses *question*.

        The model generates questions the answer could be responding to; the
        score is the mean embedding cosine similarity between those and the
        real question.

        Returns:
            ``(score, {"generated_questions", "similarities"})``; ``(0.0, {})``
            when an input is empty or no usable question was generated.
        """
        if not answer or not question:
            return 0.0, {}
        raw = await self._call_json(_RELEVANCE_GEN_PROMPT.format(answer=answer))
        # Only non-empty strings can be sent to the embedding endpoint.
        gen_questions = [
            q
            for q in self._as_list(raw, "items", "questions")
            if isinstance(q, str) and q.strip()
        ]
        if not gen_questions:
            return 0.0, {}
        scores = await asyncio.gather(
            *(self._embedding_similarity(question, q) for q in gen_questions)
        )
        avg = round(sum(scores) / len(scores), 4)
        return avg, {"generated_questions": gen_questions, "similarities": list(scores)}

    async def _embedding_similarity(self, text_a: str, text_b: str) -> float:
        """Return the cosine similarity of two texts, floored at 0 and rounded to 4 decimals."""
        import numpy as np

        resp = await self.embed_client.embeddings.create(
            model=self.embed_model,
            input=[text_a, text_b],
        )
        a = np.array(resp.data[0].embedding)
        b = np.array(resp.data[1].embedding)
        # The epsilon keeps the division defined for zero-norm vectors.
        cos = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-9))
        return round(max(0.0, cos), 4)

    # ── Answer Correctness ───────────────────────────────────────────────────
    async def score_correctness(self, answer: str, reference: str) -> tuple[float, dict]:
        """Score factual agreement between *answer* and *reference*.

        The result is ``0.75 * F1 + 0.25 * model_score``, where F1 is computed
        from the model-reported true-positive / false-positive / false-negative
        fact counts. When all counts are missing or zero, F1 is taken as 0.

        Returns:
            ``(score, detail)`` where ``detail`` is the model's JSON object
            (``{}`` if the reply was not a JSON object or inputs are empty).
        """
        if not answer or not reference:
            return 0.0, {}
        raw = self._as_dict(
            await self._call_json(
                _CORRECTNESS_PROMPT.format(reference=reference, answer=answer)
            )
        )
        try:
            score = float(raw.get("score", 0.0))
        except (TypeError, ValueError):
            score = 0.0
        # Models occasionally answer on another scale (e.g. 85) or with NaN;
        # keep the holistic score inside [0, 1] so the blend stays bounded.
        score = min(max(score, 0.0), 1.0) if score == score else 0.0
        tp = self._as_count(raw.get("factual_tp"))
        fp = self._as_count(raw.get("factual_fp"))
        fn = self._as_count(raw.get("factual_fn"))
        denominator = 2 * tp + fp + fn
        f1 = (2 * tp / denominator) if denominator > 0 else 0.0
        final = round(0.75 * f1 + 0.25 * score, 4)
        return final, raw

    # ── Groundedness (traceability) ──────────────────────────────────────────
    async def score_groundedness(self, answer: str, chunks: list[dict]) -> tuple[float, dict]:
        """Score the share of answer claims that cite a retrieved chunk.

        Args:
            answer: The generated answer.
            chunks: Retrieved chunks; each dict's ``"content"`` value is shown
                to the model, truncated to 500 characters and numbered from 1.

        Returns:
            ``(score, detail)`` where ``detail`` is the model's JSON object.
            A claim counts as grounded only when its ``source_chunk_index``
            refers to one of the numbered chunks.
        """
        if not answer or not chunks:
            return 0.0, {}
        numbered = "\n".join(
            f"[{i}] {str(c.get('content') or '')[:500]}"
            for i, c in enumerate(chunks, start=1)
        )
        raw = self._as_dict(
            await self._call_json(
                _GROUNDEDNESS_PROMPT.format(numbered_chunks=numbered, answer=answer)
            )
        )
        claims = [c for c in self._as_list(raw.get("claims")) if isinstance(c, dict)]
        if not claims:
            return 0.0, raw

        def _cites_real_chunk(ref) -> bool:
            # Accept a single index or a list of indices; reject "null",
            # free text and hallucinated out-of-range numbers.
            if isinstance(ref, (list, tuple)):
                return any(_cites_real_chunk(item) for item in ref)
            index = self._as_index(ref)
            return index is not None and 1 <= index <= len(chunks)

        grounded = sum(1 for c in claims if _cites_real_chunk(c.get("source_chunk_index")))
        return round(grounded / len(claims), 4), raw

    # ── Context Precision ────────────────────────────────────────────────────
    async def score_context_precision(
        self, question: str, ground_truth: str, retrieved_chunks: list[str]
    ) -> tuple[float, dict]:
        """Score ranking quality of the retrieved chunks (mean precision@k over useful ranks).

        Args:
            question: The user question.
            ground_truth: The reference answer.
            retrieved_chunks: Chunks in retrieval order; each is truncated to
                500 characters in the prompt.

        Returns:
            ``(score, detail)`` where ``detail`` is the model's JSON object.
            Chunks the model did not judge are treated as not useful so that
            the remaining verdicts keep their true rank.
        """
        if not retrieved_chunks or not ground_truth:
            return 0.0, {}
        chunks_text = "\n".join(
            f"[{i}] {c[:500]}" for i, c in enumerate(retrieved_chunks, start=1)
        )
        raw = self._as_dict(
            await self._call_json(
                _CONTEXT_PRECISION_PROMPT.format(
                    question=question, ground_truth=ground_truth, chunks_text=chunks_text
                )
            )
        )
        results = self._as_list(raw.get("results"))
        if not results:
            return 0.0, raw
        # One flag per retrieved chunk, addressed by the 1-based index the
        # model echoes back; entries without a usable index fall back to
        # their position in the reply.
        useful_flags = [False] * len(retrieved_chunks)
        for position, item in enumerate(results, start=1):
            if not isinstance(item, dict):
                continue
            index = self._as_index(item.get("index"))
            if index is None:
                index = position
            if 1 <= index <= len(useful_flags):
                useful_flags[index - 1] = self._as_bool(item.get("useful", False))
        hits = 0
        precision_sum = 0.0
        for rank, useful in enumerate(useful_flags, start=1):
            if useful:
                hits += 1
                precision_sum += hits / rank
        score = precision_sum / hits if hits else 0.0
        return round(min(score, 1.0), 4), raw

    # ── Context Recall ───────────────────────────────────────────────────────
    async def score_context_recall(
        self, ground_truth: str, retrieved_chunks: list[str]
    ) -> tuple[float, dict]:
        """Score the share of reference-answer statements supported by the retrieved chunks.

        Args:
            ground_truth: The reference answer.
            retrieved_chunks: Retrieved chunks; each is truncated to 800
                characters before being merged into the prompt.

        Returns:
            ``(score, detail)`` where ``detail`` is the model's JSON object.
        """
        if not retrieved_chunks or not ground_truth:
            return 0.0, {}
        retrieved_context = "\n\n".join(c[:800] for c in retrieved_chunks)
        raw = self._as_dict(
            await self._call_json(
                _CONTEXT_RECALL_PROMPT.format(
                    ground_truth=ground_truth, retrieved_context=retrieved_context
                )
            )
        )
        statements = [s for s in self._as_list(raw.get("statements")) if isinstance(s, dict)]
        if not statements:
            return 0.0, raw
        supported = sum(1 for s in statements if self._as_bool(s.get("supported")))
        return round(supported / len(statements), 4), raw