Initial release: RAG Eval platform
This commit is contained in:
commit
22ef0c8bb1
56
.gitignore
vendored
Normal file
@ -0,0 +1,56 @@
# Python
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*.egg-info/
dist/
build/
.venv/
venv/
env/

# Secrets & local dagent / judge config (use sdk/config.example.yaml)
sdk/config.yaml
.env
.env.*
*.pem

# Database & local runtime data
server/data/
*.db
*.db-journal
*.db-wal
*.db-shm

# Node
frontend/node_modules/
frontend/dist/

# Logs
*.log
/tmp/

# Knowledge base exports & batch run artifacts (do not publish)
docs/exports/
docs/task_groups_plan.json
docs/循环测试_14组分批规则.md
all_chunks.json
chunk_batches_*.json
page*.json
file_ids.txt
file_list.txt
task_groups.db

# Ops scripts with embedded org/env (local batch runs only)
server/scripts/batch_create_by_files.py
server/scripts/batch_create_tasks.py

# OS
.DS_Store
Thumbs.db

# IDE
.vscode/
.idea/
14
Dockerfile.frontend
Normal file
@ -0,0 +1,14 @@
FROM node:20-alpine AS builder

WORKDIR /app
COPY frontend/package.json ./
RUN npm install

COPY frontend/ ./
RUN npm run build

FROM nginx:alpine
COPY --from=builder /app/dist /usr/share/nginx/html
COPY nginx.conf /etc/nginx/conf.d/default.conf

EXPOSE 80
19
Dockerfile.server
Normal file
@ -0,0 +1,19 @@
FROM python:3.10-slim

WORKDIR /app

# Install server dependencies
COPY server/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy SDK (server imports it at runtime)
COPY sdk/ ./sdk/

# Copy server code
COPY server/ ./server/

WORKDIR /app/server

EXPOSE 8003

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8003"]
151
README.md
Normal file
@ -0,0 +1,151 @@
# RAG Eval Framework

A platform-agnostic **RAG evaluation platform** for dagent and any RAG system that exposes a compatible HTTP interface. It provides full-metric evaluation of the retrieval and generation layers, LLM-based automatic question generation, single-hop/multi-hop recall testing, and loop stress testing.

| Usage | Description |
|----------|------|
| **Web UI** | React + Ant Design; configuration, datasets, tasks, and reports in one place |
| **REST API** | FastAPI, 11 route groups, OpenAPI docs at `/docs` |
| **Python SDK** | `EvalRunner` + CLI, embeddable in CI/CD |

📖 **Detailed technical documentation (on the order of 10,000 characters, with architecture and sequence diagrams)**: [docs/RAG-Eval平台技术规格说明书.md](./docs/RAG-Eval平台技术规格说明书.md)
📁 **Plans, batching rules, and data exports**: [docs/](./docs/)

---

## Features

| Module | Capabilities |
|------|------|
| **Comprehensive evaluation** | Hit@K, MRR, NDCG, Context Precision/Recall, Faithfulness, Answer Relevance/Correctness, Groundedness, RAG Score |
| **Datasets** | Manual entry, JSON import, LLM auto-generation from knowledge base files |
| **Single-hop recall** | Upload an MD Q&A set, map it to file_id, run batch semantic recall, and collect hit-rate statistics |
| **Multi-hop recall** | Multi-hop question parsing, per-hop recall, and end-to-end hit determination |
| **Question generation** | Per-chunk LLM question generation, quality scoring, vector deduplication, manual review |
| **Loop testing** | Multi-round closed loop of "generate → deduplicate → single-hop verification", with pause/resume/export |
| **Prompt templates** | Configurable question-generation and judging prompts |

---

## Architecture overview

```
┌─────────────┐     ┌──────────────────┐     ┌─────────────┐
│  React UI   │────▶│  FastAPI :8021   │────▶│  SQLite DB  │
│   (Vite)    │     │ + 11 API routes  │     │    (WAL)    │
└─────────────┘     └────────┬─────────┘     └─────────────┘
                             │ sys.path → sdk/
                             ▼
                    ┌──────────────────┐
                    │   rag_eval SDK   │
                    │  Adapter/Judge/  │
                    │  Runner/Parser   │
                    └────────┬─────────┘
                             │ HTTP
                             ▼
                    ┌──────────────────┐
                    │  dagent / other  │
                    │  RAG platforms   │
                    └──────────────────┘
```

---

## Project structure

```
rag-eval/
├── docs/                  # Technical docs, batching rules, data exports
├── sdk/rag_eval/          # Core evaluation logic (Adapter, Judge, Runner…)
├── server/                # FastAPI backend
│   ├── api/               # REST routes (config/dataset/task/report/…)
│   ├── service/           # Task orchestration (task_service, loop_engine)
│   └── models/            # SQLite schema + migrations
├── frontend/              # React Web UI
├── docker-compose.yml
└── README.md
```

---

## Quick start

### Docker Compose

```bash
cd rag-eval
docker-compose up -d
# Web UI: http://localhost:3000 | API: http://localhost:8003/docs
```

### Local development

```bash
# Backend (default port 8021, configurable)
cd server && pip install -r requirements.txt
uvicorn main:app --host 0.0.0.0 --port 8021 --reload

# Frontend (dev server proxies to 8021)
cd frontend && npm install && npm run dev

# SDK
cd sdk && pip install -e .
rag-eval run --config config.yaml --dataset dataset.json --output report.json
```

In production, the `frontend/dist` build output can be mounted with FastAPI `StaticFiles` so that a single port is exposed.

---

## Typical workflow

1. **Configuration**: add the dagent `base_url` / `org_id` and a Judge (OpenAI-compatible) model
2. **Datasets**: import JSON, add entries manually, or auto-generate with an LLM
3. **Evaluation tasks**: select a subset of metrics, run the batch asynchronously in the background, then view the radar chart and AI interpretation
4. **Single-hop / multi-hop / loop**: see [Technical Specification · Business Flows](./docs/RAG-Eval平台技术规格说明书.md#6-业务流程与时序图)

---

## Metrics quick reference

| Layer | Metric | Type |
|------|------|------|
| Retrieval | Hit Rate@K, MRR@K, NDCG@K | Rule-based (requires `relevant_chunk_ids`) |
| Retrieval | Context Precision / Recall | LLM-as-Judge |
| Generation | Faithfulness, Groundedness | LLM-as-Judge |
| Generation | Answer Relevance | LLM + Embedding |
| Generation | Answer Correctness | LLM-as-Judge (requires a reference answer) |
| Overall | RAG Score, Hallucination Rate | Derived |

Thresholds and interpretation are covered in [Chapter 7](./docs/RAG-Eval平台技术规格说明书.md#7-评测指标体系) of the technical documentation.
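
The rule-based retrieval metrics can be illustrated with a minimal sketch. It assumes binary relevance, unique chunk ids in the ranked list, and the textbook definitions of Hit Rate@K, MRR@K, and NDCG@K; the function names are hypothetical and the SDK's own implementation may differ.

```python
import math


def hit_rate_at_k(retrieved, relevant, k):
    """1.0 if any of the top-k retrieved chunk ids is relevant, else 0.0."""
    return 1.0 if any(cid in relevant for cid in retrieved[:k]) else 0.0


def mrr_at_k(retrieved, relevant, k):
    """Reciprocal rank of the first relevant chunk within the top k, else 0.0."""
    for rank, cid in enumerate(retrieved[:k], start=1):
        if cid in relevant:
            return 1.0 / rank
    return 0.0


def ndcg_at_k(retrieved, relevant, k):
    """Binary-relevance NDCG: DCG of the ranking divided by the ideal DCG."""
    dcg = sum(
        1.0 / math.log2(rank + 1)
        for rank, cid in enumerate(retrieved[:k], start=1)
        if cid in relevant
    )
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
    return dcg / idcg if idcg else 0.0


# Hypothetical example: the only relevant chunk that was retrieved sits at rank 2.
retrieved_ids = ["c3", "c1", "c9"]
relevant_ids = {"c1", "c7"}
print(hit_rate_at_k(retrieved_ids, relevant_ids, 3))
print(mrr_at_k(retrieved_ids, relevant_ids, 3))
print(ndcg_at_k(retrieved_ids, relevant_ids, 3))
```

Here `relevant_ids` plays the role of the `relevant_chunk_ids` field that the table lists as a prerequisite for these metrics.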

---

## Extending to other RAG platforms

Implement `retrieve` and `chat` on `RAGAdapter` to integrate another platform:

```python
from rag_eval.adapters.base import RAGAdapter, RetrievedChunk, AgentResponse

class MyAdapter(RAGAdapter):
    async def retrieve(self, query, knowledge_hub_id, top_k=10, **kwargs) -> list[RetrievedChunk]: ...
    async def chat(self, query, agent_id, **kwargs) -> AgentResponse: ...
```

---

## Related documentation

| Document | Description |
|------|------|
| [RAG-Eval平台技术规格说明书.md](./docs/RAG-Eval平台技术规格说明书.md) | Architecture, sequence diagrams, data model, API |
| [循环测试_14组分批规则.md](./docs/循环测试_14组分批规则.md) | Plan for 42 batches against the remote dagent |
| [TUTORIAL.md](./docs/TUTORIAL.md) | Step-by-step tutorial |
| [rag-eval-framework-design.md](./docs/rag-eval-framework-design.md) | Early framework design draft |

---

## License

Internal project. Follow your organization's code and data security policies before use.
22
docker-compose.yml
Normal file
@ -0,0 +1,22 @@
services:
  server:
    build:
      context: .
      dockerfile: Dockerfile.server
    ports:
      - "8003:8003"
    volumes:
      - ./data:/app/server/data  # SQLite DB persistence
    environment:
      - PYTHONUNBUFFERED=1
    restart: unless-stopped

  frontend:
    build:
      context: .
      dockerfile: Dockerfile.frontend
    ports:
      - "3000:80"
    depends_on:
      - server
    restart: unless-stopped
272
docs/Dagent文件选择器方案.md
Normal file
@ -0,0 +1,272 @@
# Dagent Visual File Selector Plan

## I. Background

When importing from Dagent today, users must type comma-separated file IDs by hand; they cannot see file details or pick files visually. A visual file selector is needed so that users can:
1. View the file list (name, type, size, status)
2. Pick the files they need directly
3. Use select-all, search, and pagination

## II. Current state

### Existing API
- `GET /api/qa-gen/dagent/files?org_id=xxx` - returns a list of 207 files
- Fields: `id, file_name, file_type, file_clean_status, file_bytes, create_time`

### Existing frontend UI
- A plain `Input.TextArea` for entering file IDs
- No visual selection interface

## III. Technical design

### 1. Backend API (unchanged)
The existing API is sufficient; no new endpoint is needed. Each file record contains:
- `id`: unique file identifier (used for selection)
- `file_name`: file name (used for display)
- `file_type`: file type (HTML/PDF/DOCX)
- `file_clean_status`: processing status (used for status hints)
- `file_bytes`: file size (displayed formatted)
- `create_time`: creation time

### 2. Frontend component design

#### 2.1 File selector component
Create a standalone file selector component with the following features:

**UI elements:**
- File list table (multi-select)
- Search box (search by file name)
- Status filter (filter by file_clean_status)
- Select all / invert selection buttons
- Pagination (20 files per page)
- Selected-file count

**Table columns:**
1. Selection column (checkbox)
2. File name (click to view details)
3. File type
4. File size (formatted as KB/MB)
5. Processing status (shown as a tag)
6. Creation time

#### 2.2 File detail modal
Shown when a file name is clicked:
- Basic file information
- Paragraph statistics (if the backend supports it)
- Preview button (if needed)

#### 2.3 Integration with the existing form
- Wrap the selector component in `Form.Item`
- Store the selected file IDs in the hidden `file_ids` field
- Stay backward compatible (manual input is still supported)

### 3. Implementation steps

#### Step 1: Create the file selector component
```typescript
// src/components/DagentFileSelector/index.tsx
import { useState, useEffect } from 'react'
import { Table, Input, Button, Tag, Space, Modal, message, Pagination } from 'antd'
import { qaGenApi } from '../../services/api'

interface FileItem {
  id: string
  file_name: string
  file_type: string
  file_clean_status: string
  file_bytes: number
  create_time: string
}

interface DagentFileSelectorProps {
  orgId: string
  value?: string[] // array of selected file IDs
  onChange?: (fileIds: string[]) => void
}
```

#### Step 2: Update the QaGen page
- Replace the existing `Input.TextArea` with `DagentFileSelector`
- Keep the original `file_ids` field as a hidden field
- Add a button that opens the file selector

#### Step 3: Add interaction logic
- Clicking the "Select files" button opens the selector modal
- Once selection is complete, close the modal and update the hidden field
- Show the number of selected files and a summary of their names

### 4. State design

```typescript
const [files, setFiles] = useState<FileItem[]>([])
const [loading, setLoading] = useState(false)
const [searchText, setSearchText] = useState('')
const [selectedRowKeys, setSelectedRowKeys] = useState<string[]>([])
const [pagination, setPagination] = useState({ current: 1, pageSize: 20, total: 0 })
```

### 5. File size formatting and status display

**File size formatting:**
```typescript
const formatFileSize = (bytes: number) => {
  if (bytes < 1024) return `${bytes} B`
  if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(1)} KB`
  return `${(bytes / (1024 * 1024)).toFixed(1)} MB`
}
```

**Status tag:**
```typescript
const statusTag = (status: string) => {
  // Typed as a Record so that indexing by an arbitrary status string type-checks
  const map: Record<string, { color: string; label: string }> = {
    'CLEAN_FINISH': { color: 'success', label: 'Processed' },
    'CLEAN_PROCESSING': { color: 'processing', label: 'Processing' },
    'CLEAN_FAILED': { color: 'error', label: 'Processing failed' },
    'UPLOAD_FAILED': { color: 'warning', label: 'Upload failed' }
  }
  const cfg = map[status] || { color: 'default', label: status }
  return <Tag color={cfg.color}>{cfg.label}</Tag>
}
```

### 6. Performance optimization

1. **Paginated loading**: load only the current page of files
2. **Virtual scrolling**: consider virtual scrolling when the file count is large (>1000)
3. **Data caching**: cache the file list for 5 minutes
4. **Debounced search**: debounce the search input to avoid frequent requests
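
The selector's search, status filter, and 20-per-page pagination can be sketched as a pure function over the file list. This is an illustration only, written in Python rather than the frontend's TypeScript; `filter_and_paginate` and its return shape are hypothetical, and it assumes the whole list is already in memory, as it is with the existing endpoint.

```python
def filter_and_paginate(files, search_text="", status=None, page=1, page_size=20):
    """Filter by file-name substring and status, then return one page of results."""
    needle = search_text.strip().lower()
    matched = [
        f for f in files
        if needle in f["file_name"].lower()
        and (status is None or f["file_clean_status"] == status)
    ]
    start = (page - 1) * page_size
    return {"total": len(matched), "items": matched[start:start + page_size]}


# Hypothetical data: 207 files, alternating between two statuses.
sample_files = [
    {
        "file_name": f"doc_{i}.md",
        "file_clean_status": "CLEAN_FINISH" if i % 2 == 0 else "CLEAN_FAILED",
    }
    for i in range(207)
]
last_page = filter_and_paginate(sample_files, page=11)
print(last_page["total"], len(last_page["items"]))
```

With 207 files and a page size of 20, the list spans 11 pages and the last page holds the remaining 7 files.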
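
### 7. User experience design

#### 7.1 Selection flow
1. The user enters org_id and queries the statistics
2. The "Select files" button is shown (enabled only after the statistics have loaded)
3. Clicking the button opens the file selector
4. The user selects files and confirms
5. The form is shown again with a summary of the selected files

#### 7.2 Confirmation dialog
When the user confirms the selection, show:
- Number of selected files
- Estimated number of generated questions (file count × average paragraphs per file × questions per paragraph)
- Confirm button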
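
The estimate shown in the confirmation dialog is a plain product. A minimal sketch, assuming the average paragraph count is available from the statistics query; the function name and the sample numbers are hypothetical:

```python
def estimate_question_count(file_count, avg_paragraphs_per_file, questions_per_paragraph):
    """Estimated questions = files x average paragraphs per file x questions per paragraph."""
    return round(file_count * avg_paragraphs_per_file * questions_per_paragraph)


# Hypothetical selection: 12 files, about 40 paragraphs each, 2 questions per paragraph.
print(estimate_question_count(12, 40, 2))
```

The result is only a rough upper bound, since deduplication and quality review later remove some of the generated questions.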

### 8. Possible extensions

#### 8.1 Paragraph preview
If the backend supports it, the following can be added:
- `GET /api/qa-gen/dagent/file/{file_id}/paragraphs` - fetch the list of paragraphs in a file
- Show a paragraph preview when a file is clicked

#### 8.2 Smart filtering
- Filter by file type (HTML/PDF/DOCX)
- Filter by processing status
- Filter by file size

#### 8.3 Bulk operations
- Bulk selection by folder/directory
- Selection by file-name pattern matching

## IV. Implementation plan

### Phase 1: basic file selector (1-2 days)
1. Create the `DagentFileSelector` component
2. Integrate it into the QaGen page
3. Implement basic multi-select

### Phase 2: enhancements (1-2 days)
1. Add search and filtering
2. Add pagination
3. Tune performance and user experience

### Phase 3: advanced features (optional)
1. File detail preview
2. Paragraph statistics
3. Bulk selection mode

## V. API reference

### Existing endpoint
```http
GET /api/qa-gen/dagent/files?org_id=xxx
```

### Response format
```json
{
  "status": 0,
  "data": [
    {
      "id": "file_123",
      "file_name": "linux_development.md",
      "file_type": "html",
      "file_clean_status": "CLEAN_FINISH",
      "file_bytes": 20480,
      "create_time": "2024-01-01 10:00:00"
    }
  ]
}
```

## VI. Frontend component structure

```
QaGen/index.tsx
├── Form.Item name="file_ids"
│   └── <DagentFileSelector>
│       ├── <Button>Select files</Button>
│       ├── <Modal> File selector
│       │   ├── <Input.Search> Search box
│       │   ├── <Table> File list
│       │   │   ├── Selection column
│       │   │   ├── File name
│       │   │   ├── File type
│       │   │   ├── File size
│       │   │   ├── Processing status
│       │   │   └── Creation time
│       │   ├── <Pagination> Pagination
│       │   └── <Space> Action buttons
│       └── Selected-file summary
```

## VII. Notes

1. **Backward compatibility**: keep support for manually entering file IDs
2. **Error handling**: handle network errors and empty states
3. **Mobile layout**: optimize how the table renders on small screens
4. **Accessibility**: support keyboard navigation and screen readers
5. **Internationalization**: support translated labels and hints

## VIII. Test plan

1. **Functional tests**:
   - File list loading
   - Multi-select
   - Search and filtering
   - Page switching
   - Form data synchronization

2. **Performance tests**:
   - Load time for 207 files
   - Search response time
   - Memory usage

3. **Compatibility tests**:
   - Different browsers
   - Different screen sizes
   - Keyboard operation

## IX. Risk assessment

1. **API performance**: loading 207 files at once may be slow → use pagination
2. **Memory usage**: a large number of DOM elements may hurt performance → use virtual scrolling
3. **User experience**: the selection process is complex → simplify the flow
4. **Backward compatibility**: make sure the existing manual input keeps working

## X. Success criteria

1. **Feature completeness**: 100% of the required features are covered
2. **Performance**: file list load time < 2 seconds
3. **User experience**: the selection flow takes ≤ 3 steps
4. **Code quality**: no TypeScript errors, test coverage > 80%
122
docs/EVB知识库单跳召回测试报告.md
Normal file
@ -0,0 +1,122 @@
# EVB Knowledge Base Single-Hop Recall Test Report

**Test date:** 2026-04-21
**Scope:** all 7 modules of the EVB knowledge base
**Method:** single-hop semantic recall (cross_chunk mode, top_k=3)

---

## I. Overview

| Metric | Value |
|------|------|
| Total questions | **12,591** |
| Recall success rate | **100%** (12,591 / 12,591) |
| File hit rate | **63.1%** (7,849 / 12,591) |
| File hit failures | **4,742** |
| Mean best cosine similarity | **0.868** |
| Mean recall latency | **432 ms** |
| Sections covered | **171** |
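
The file hit rate counts a question as a hit when the expected file appears among its top-k recalled chunks. A minimal sketch of that rule, assuming each result records the expected file id and the ranked chunks with their file ids; the field names `expected_file_id`, `chunks`, and `file_id` are hypothetical:

```python
def file_hit(expected_file_id, retrieved_chunks, top_k=3):
    """True if any of the top-k retrieved chunks belongs to the expected file."""
    return any(chunk["file_id"] == expected_file_id for chunk in retrieved_chunks[:top_k])


def file_hit_rate(results, top_k=3):
    """Fraction of questions whose expected file appears in the top-k chunks."""
    if not results:
        return 0.0
    hits = sum(file_hit(r["expected_file_id"], r["chunks"], top_k) for r in results)
    return hits / len(results)


# Hypothetical results: the second question recalls content only from other files.
sample_results = [
    {"expected_file_id": "f1", "chunks": [{"file_id": "f2"}, {"file_id": "f1"}]},
    {"expected_file_id": "f3", "chunks": [{"file_id": "f4"}, {"file_id": "f5"}]},
]
print(file_hit_rate(sample_results))
```

Under this definition a question can recall highly similar content and still count as a miss, which is why the recall success rate and the file hit rate are reported separately.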

> **A recall success rate of 100%** indicates that the semantic index of the knowledge base is complete: every question retrieved related content.
> **The file hit rate of 63.1%** is the core problem: for 36.9% of the questions, the top-k results did not include the expected file, which points to substantial cross-file semantic interference.

---

## II. Per-module statistics

| Module | Questions | Recall rate | File hit rate | Hit failures | Mean similarity | Mean latency | Sections |
|------|--------|--------|-----------|---------|-----------|---------|--------|
| linux_development | 7,455 | 100% | **63.7%** | 2,703 | 0.864 | 433ms | 107 |
| multimedia_development | 2,307 | 100% | **68.8%** | 720 | 0.880 | 425ms | 25 |
| samples | 1,374 | 100% | **53.6%** | 637 | 0.872 | 434ms | 19 |
| toolchain_development | 832 | 100% | **57.3%** | 355 | 0.859 | 423ms | 13 |
| quick_start | 483 | 100% | **47.6%** | 253 | 0.869 | 441ms | 5 |
| preface | 86 | 100% | **30.2%** | 60 | 0.867 | 469ms | 1 |
| common_questions | 54 | 100% | **74.1%** | 14 | 0.887 | 476ms | 1 |

**Best modules:** common_questions (74.1%), multimedia_development (68.8%)
**Worst modules:** preface (30.2%), quick_start (47.6%)

---

## III. Top 20 sections with the lowest file hit rate

| Module | Section path (truncated) | Hit rate | Questions | Mean similarity |
|------|----------------|--------|--------|-----------|
| multimedia_development | multimedia_development / 8-GDC_index_zh_ | **0%** | 68 | 0.868 |
| toolchain_development | toolchain_development / expert / environ | **2%** | 35 | 0.836 |
| samples | samples / sunrise_camera_develop_guide | **12%** | 141 | 0.838 |
| linux_development | linux_development / system_debug / ddr (1) | **13%** | 30 | 0.853 |
| samples | samples / overview | **13%** | 65 | 0.874 |
| linux_development | linux_command_manual (1) | **15%** | 71 | 0.860 |
| linux_development | linux_development / system_debug / ddr (2) | **18%** | 75 | 0.865 |
| quick_start | quick_start / x5_evb_1_b_user_guide | **21%** | 131 | 0.863 |
| toolchain_development | toolchain_development / expert / quick_s | **22%** | 40 | 0.840 |
| linux_development | linux_development / system_debug / ddr (3) | **23%** | 51 | 0.837 |
| samples | samples / sample_osd | **26%** | 50 | 0.867 |
| samples | samples / sample_hbmem | **26%** | 68 | 0.857 |
| linux_development | linux_development / driver_develop_guide (1) | **28%** | 50 | 0.851 |
| preface | preface / overview | **30%** | 86 | 0.867 |
| linux_development | system_component_dev | **31%** | 51 | 0.849 |
| samples | samples / sample_imu | **31%** | 72 | 0.868 |
| linux_development | linux_command_manual (2) | **32%** | 134 | 0.849 |
| quick_start | quick_start / x5_evb_v2p0_user_guide | **33%** | 107 | 0.878 |
| linux_development | linux_development / driver_develop_guide (2) | **33%** | 59 | 0.857 |
| samples | samples / sample_trustzone | **34%** | 50 | 0.881 |

---

## IV. Diagnosis

### 4.1 Root causes of the low file hit rate

**High similarity combined with a low hit rate** (for example the multimedia_development/GDC section: sim=0.868 but a 0% hit rate) shows that the problem is not the quality of the semantic index but the following:

1. **Knowledge base file granularity is too coarse**: several documents have highly similar content (for example different versions of the EVB user guide, or multiple DDR debugging documents), so recall returns content that is semantically close but comes from a different file
2. **Mismatch between section paths and file names**: for some sections (such as GDC and preface/overview), the file name in the knowledge base differs considerably from the MD path, so the file mapping fails
3. **Cross-file semantic interference**: the sample documents in the samples module share a similar structure (each has an overview and an API description), so semantically similar questions are recalled from the wrong file

### 4.2 Per-module analysis

- **linux_development** (largest module, 107 sections): overall hit rate 63.7%; the DDR debugging sections have very low hit rates (13-23%), presumably because several DDR-related documents overlap in content
- **multimedia_development**: the GDC section has a 0% hit rate; its file mapping needs to be checked
- **samples**: low hit rate (53.6%); the highly similar structure of the sample documents is the main cause
- **preface**: only 30.2%; the overview section is generic and is easily confused with the overview content of other documents
- **common_questions**: highest hit rate (74.1%); FAQ-style questions are semantically distinctive

---

## V. Recommendations

### Short term (knowledge base configuration)

| Priority | Recommendation | Expected benefit |
|--------|------|---------|
| 🔴 High | Check and fix the file mapping of the multimedia_development/GDC section | 68 questions with a 0% hit rate |
| 🔴 High | Merge the DDR debugging documents (3 overlapping sections) or add file-identifying metadata | ~156 low-hit questions |
| 🟡 Medium | Add a unique file-level prefix to each samples document (for example, inject the file name into each chunk) | 637 hit failures |
| 🟡 Medium | The two quick_start manual versions (1_b and v2p0) overlap; consider merging them or tagging the version | 384 hit failures |
| 🟢 Low | preface/overview content is too generic; consider adding the document title as a chunk prefix | 60 hit failures |

### Medium term (recall strategy)

1. **Lower top_k**: top_k is currently 3; for high-similarity interference scenarios, try top_k=1 to measure the exact hit rate
2. **File-level filtering**: for sections with a known file mapping, pass `file_id_list` at recall time to restrict the search scope (disable cross_chunk)
3. **Rerank tuning**: introduce a file-source weight in the rerank stage so that chunks from the same file receive a bonus

---

## VI. Test execution

| Module | Start time | End time | Duration |
|------|---------|---------|------|
| linux_development | 03:01:38 | 03:12:27 | **10m 49s** |
| multimedia_development | 03:12:08 | 03:15:26 | **3m 18s** |
| quick_start | 03:21:33 | 03:21:46 | **13s** |
| samples | 03:22:56 | 03:23:26 | **30s** |
| toolchain_development | 03:23:52 | 03:24:12 | **20s** |
| preface | 03:24:25 | 03:24:28 | **3s** |
| common_questions | 03:24:59 | 03:25:01 | **2s** |

Total test time was about **24 minutes**; all 12,591 questions completed without errors.
514
docs/LLM自动生成问题方案.md
Normal file
@ -0,0 +1,514 @@
# LLM Automatic Question Generation + Testing + Review Plan

**Version:** v1.0
**Date:** 2026-04-21
**Goal:** automatically generate test questions from knowledge base MD files and, after deduplication and quality review, feed them directly into single-hop recall testing

---

## I. Overall flow

```
┌──────────────────┐
│ Upload MD files  │
└────────┬─────────┘
         ↓
┌──────────────────────────────────────┐
│ LLM generates Q&A per section        │
│  - N questions per section           │
│  - with reference answers            │
│  - source excerpt recorded           │
└────────┬─────────────────────────────┘
         ↓
┌──────────────────────────────────────┐
│ Review pipeline                      │
│ ┌──────────────────────────────────┐ │
│ │ 1. In-batch deduplication        │ │
│ │    - exact (hash)                │ │
│ │    - semantic (embedding)        │ │
│ └──────────────────────────────────┘ │
│ ┌──────────────────────────────────┐ │
│ │ 2. Cross-history deduplication   │ │
│ │    - against approved questions  │ │
│ └──────────────────────────────────┘ │
│ ┌──────────────────────────────────┐ │
│ │ 3. Automatic quality scoring     │ │
│ │    - answerability               │ │
│ │    - question clarity            │ │
│ │    - answer accuracy             │ │
│ │    - uniqueness                  │ │
│ └──────────────────────────────────┘ │
│ ┌──────────────────────────────────┐ │
│ │ 4. Manual confirm/edit/delete    │ │
│ │    - auto-approve high quality   │ │
│ │    - flag low quality/duplicates │ │
│ └──────────────────────────────────┘ │
└────────┬─────────────────────────────┘
         ↓
┌──────────────────────────────────────┐
│ Export as standard MD format         │
│ (same as single-hop test input)      │
└────────┬─────────────────────────────┘
         ↓
┌──────────────────────────────────────┐
│ Feed into single-hop recall test     │
└──────────────────────────────────────┘
```

---

## II. Module design

### 2.1 Generation module (`/api/qa-gen`)

#### API design

```
POST   /api/qa-gen/task                     # Create a generation task
GET    /api/qa-gen/task/list                # List tasks
GET    /api/qa-gen/task/{id}                # Task details (with progress)
DELETE /api/qa-gen/task/{id}                # Delete a task
GET    /api/qa-gen/task/{id}/questions      # List the generated questions
POST   /api/qa-gen/question/{id}/approve    # Approve a question
POST   /api/qa-gen/question/{id}/reject     # Reject a question
PUT    /api/qa-gen/question/{id}            # Edit a question
POST   /api/qa-gen/task/{id}/export-md      # Export approved questions as MD
```

#### Generation strategy

**Input:**
- MD file (same format as the single-hop test)
- Configuration parameters:
  - `model`: LLM model (default gpt-4o-mini)
  - `questions_per_section`: number of questions generated per section (default 5)
  - `quality_threshold`: quality threshold (default 0.6)
  - `judge_config_id`: scoring model configuration

**Processing flow:**
1. Split the document by `## section`
2. For each section:
   - Extract the section title and content
   - Call the LLM to generate N questions
   - Each question contains:
     - Question text
     - Reference answer
     - Source excerpt for the answer (used for quality review)
3. Run asynchronously in the background, with progress callbacks
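
Step 1 of the processing flow can be sketched as a small splitter. It is a minimal illustration that assumes sections start at level-2 (`## `) headings, keeps deeper headings inside their parent section, and drops any text before the first `## ` heading; `split_sections` is a hypothetical name, and the sketch does not handle `## ` lines that appear inside fenced code blocks.

```python
import re


def split_sections(markdown_text):
    """Split a Markdown document on level-2 headings into (title, content) pairs."""
    sections = []
    title, lines = None, []
    for line in markdown_text.splitlines():
        match = re.match(r"^##\s+(.*)$", line)
        if match:
            # A new section starts: flush the previous one, if any.
            if title is not None:
                sections.append((title, "\n".join(lines).strip()))
            title, lines = match.group(1).strip(), []
        elif title is not None:
            lines.append(line)
    if title is not None:
        sections.append((title, "\n".join(lines).strip()))
    return sections


doc = "# Guide\nintro\n## DDR config\nEdit the file.\n### Details\nMore.\n## GDC\nOverview.\n"
for section_title, section_body in split_sections(doc):
    print(section_title, "->", len(section_body), "characters")
```

Each `(title, content)` pair would then be passed to the LLM call in step 2.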

**Prompt template:**

```
You are an expert at writing test questions for technical documentation.

Task: generate {N} test questions from the technical documentation section below.

Section title: {section_path}
Section content:
{content}

Requirements:
1. Each question must be answerable directly from this section (do not generate questions that can only be answered across documents)
2. Questions should cover the key knowledge points of the section
3. Questions must be clearly worded and unambiguous
4. Answers must be accurate and consistent with the source text
5. Include the source excerpt that supports the answer (used for later review)

Output format (JSON):
[
  {
    "question": "question text",
    "answer": "reference answer",
    "source_chunk": "source excerpt that supports the answer (50-200 characters)"
  },
  ...
]
```
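
The JSON reply requested by the prompt still needs defensive parsing, because models sometimes wrap the array in a Markdown fence or omit fields. A minimal sketch, assuming the three keys from the output format; `parse_generated_questions` is a hypothetical helper, not part of the described API:

```python
import json

REQUIRED_KEYS = ("question", "answer", "source_chunk")
FENCE = "`" * 3  # Markdown code-fence marker, built this way to keep the example readable


def parse_generated_questions(raw):
    """Parse the LLM reply into a list of question dicts, dropping malformed items."""
    text = raw.strip()
    if text.startswith(FENCE):
        # Drop the opening fence line (which may carry a language tag) and the closing fence.
        text = text.split("\n", 1)[1] if "\n" in text else ""
        text = text.rsplit(FENCE, 1)[0]
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return []
    if not isinstance(data, list):
        return []
    return [
        {key: item[key].strip() for key in REQUIRED_KEYS}
        for item in data
        if isinstance(item, dict)
        and all(isinstance(item.get(key), str) and item[key].strip() for key in REQUIRED_KEYS)
    ]


reply = FENCE + 'json\n[{"question": "Q?", "answer": "A", "source_chunk": "S"}, {"question": ""}]\n' + FENCE
print(parse_generated_questions(reply))
```

Items with missing or empty fields are dropped rather than repaired, so that incomplete questions never reach the review pool.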

---

### 2.2 Deduplication module

#### Two-layer deduplication

| Layer | Method | Threshold | Description |
|------|------|------|------|
| **Exact deduplication** | Hash of the question text | Identical | Quickly filters out exact duplicates |
| **Semantic deduplication** | Embedding cosine similarity | > 0.92 | Identifies semantically similar questions |

#### Deduplication scope

1. **In-batch deduplication**: questions within the current generation task are checked against each other
2. **Cross-history deduplication**: questions are checked against the approved questions in the `qa_approved_question` table

#### Implementation details

**Embedding computation:**
- Use `text-embedding-3-small` or the configured embedding model
- Compute and store the embedding as soon as a question is generated
- Store the embedding as a JSON string (1536-dimensional vector)

**Deduplication flow:**
```python
import hashlib

# get_embedding, cosine_similarity and the mark_* helpers are defined elsewhere.

# 1. Exact deduplication: hash of the normalized question text
question_hash = hashlib.md5(question.strip().lower().encode()).hexdigest()
if question_hash in existing_hashes:
    mark_as_duplicate()

# 2. Semantic deduplication: cosine similarity against the stored embeddings
question_embedding = get_embedding(question)
similarities = cosine_similarity(question_embedding, all_embeddings)
if max(similarities) > 0.92:
    mark_as_similar(most_similar_question_id)
```
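
A self-contained version of the two layers can be sketched in pure Python. It assumes embeddings are plain lists of floats and that each existing question is a dict with `id`, `hash`, and `embedding` keys; these field names and the `find_duplicate` helper are hypothetical, and a real deployment would more likely use a vectorized library for the similarity search.

```python
import hashlib
import math

SIMILARITY_THRESHOLD = 0.92  # semantic-duplicate threshold from the two-layer table


def question_hash(text):
    """MD5 of the stripped, lower-cased question text (exact-duplicate key)."""
    return hashlib.md5(text.strip().lower().encode("utf-8")).hexdigest()


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors; 0.0 if either has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def find_duplicate(question, embedding, existing):
    """Return (duplicate_id, similarity), or (None, best_similarity) if not a duplicate."""
    new_hash = question_hash(question)
    best_id, best_sim = None, 0.0
    for item in existing:
        if item["hash"] == new_hash:
            return item["id"], 1.0  # layer 1: exact match
        sim = cosine_similarity(embedding, item["embedding"])
        if sim > best_sim:
            best_id, best_sim = item["id"], sim
    if best_sim > SIMILARITY_THRESHOLD:
        return best_id, best_sim  # layer 2: semantic match
    return None, best_sim


# Hypothetical pool with one approved question and a toy 3-dimensional embedding.
pool = [{"id": "q1", "hash": question_hash("How do I configure DDR?"), "embedding": [1.0, 0.0, 0.0]}]
print(find_duplicate("  how do i configure ddr?  ", [0.0, 1.0, 0.0], pool))
print(find_duplicate("What is GDC?", [0.0, 1.0, 0.0], pool))
```

The returned pair maps naturally onto the `dup_of` and `dup_similarity` columns of the `qa_gen_question` table.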

---

### 2.3 Quality review module

#### Automatic quality scoring

Each generated question is scored automatically (0-1) across the following dimensions:

| Dimension | Weight | Scoring method |
|------|------|---------|
| **Answerability** | 30% | LLM judgment: can the answer be derived from source_chunk? |
| **Question clarity** | 25% | LLM judgment: is the question unambiguous and clearly worded? |
| **Answer accuracy** | 30% | LLM judgment: is the reference answer consistent with source_chunk? |
| **Uniqueness** | 15% | Computed: semantic distance to the most similar question (1 - max_similarity) |

**Quality scoring prompt:**

```
Evaluate the quality of the following test question and score it from 0 to 1.

Question: {question}
Reference answer: {answer}
Source text for the answer: {source_chunk}

Evaluation dimensions:
1. Answerability (0-1): can the answer be derived from the source text?
2. Question clarity (0-1): is the question clear and unambiguous?
3. Answer accuracy (0-1): is the reference answer consistent with the source text?

Output format (JSON):
{
  "answerable": 0.9,
  "clarity": 0.85,
  "accuracy": 0.95,
  "reasoning": "short explanation"
}
```
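
The overall score combines the three LLM-judged dimensions with the computed uniqueness term. A minimal sketch, assuming a plain weighted sum with the weights from the dimension table; the function name and the clamping of `max_similarity` to [0, 1] are assumptions:

```python
WEIGHTS = {"answerable": 0.30, "clarity": 0.25, "accuracy": 0.30, "uniqueness": 0.15}


def quality_score(detail, max_similarity):
    """Weighted sum of the judge's scores plus uniqueness = 1 - max_similarity."""
    scores = {key: float(detail[key]) for key in ("answerable", "clarity", "accuracy")}
    scores["uniqueness"] = 1.0 - max(0.0, min(1.0, max_similarity))
    return sum(WEIGHTS[name] * scores[name] for name in WEIGHTS)


# Judge output from the example above, with a hypothetical nearest-neighbor similarity of 0.5.
judge_output = {"answerable": 0.9, "clarity": 0.85, "accuracy": 0.95, "reasoning": "short explanation"}
print(round(quality_score(judge_output, 0.5), 4))
```

For the example judge output this gives 0.27 + 0.2125 + 0.285 + 0.075 = 0.8425.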

#### Review status transitions

```
pending (awaiting review)
  ↓
  ├─→ approved → moved into the qa_approved_question table
  ├─→ rejected → excluded from testing
  └─→ edited → embedding and quality score are recomputed
```

**Auto-approve rule:**
- `quality_score >= threshold` (default 0.6)
- and `dup_of IS NULL` (not a duplicate)
- the question is automatically marked as `approved`

**Manual review required:**
- `quality_score < threshold`
- or `dup_of IS NOT NULL` (suspected duplicate)
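
The two rules reduce to a single predicate. A minimal sketch, assuming a missing score is treated like a low score and therefore routed to manual review; `initial_status` is a hypothetical name:

```python
def initial_status(quality_score, dup_of, threshold=0.6):
    """Return 'approved' for high-quality, non-duplicate questions, else 'pending'."""
    if dup_of is None and quality_score is not None and quality_score >= threshold:
        return "approved"
    return "pending"


print(initial_status(0.85, None))   # high quality, not a duplicate
print(initial_status(0.45, "q1"))   # low quality and a suspected duplicate
```

A question that scores well but has `dup_of` set still goes to manual review, matching the second rule.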
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2.4 数据库设计
|
||||||
|
|
||||||
|
#### 新增表
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- 生成任务表
|
||||||
|
CREATE TABLE qa_gen_task (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
name TEXT,
|
||||||
|
status TEXT NOT NULL DEFAULT 'pending', -- pending/running/done/failed
|
||||||
|
model TEXT NOT NULL, -- 使用的 LLM 模型
|
||||||
|
judge_config_id TEXT, -- 评分模型配置
|
||||||
|
questions_per_section INTEGER DEFAULT 5,
|
||||||
|
quality_threshold REAL DEFAULT 0.6,
|
||||||
|
progress INTEGER DEFAULT 0,
|
||||||
|
total INTEGER DEFAULT 0,
|
||||||
|
error_message TEXT,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
finished_at TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
-- 生成的问题表(待审核池)
|
||||||
|
CREATE TABLE qa_gen_question (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
task_id TEXT NOT NULL,
|
||||||
|
section_path TEXT NOT NULL,
|
||||||
|
question TEXT NOT NULL,
|
||||||
|
reference_answer TEXT NOT NULL,
|
||||||
|
source_chunk TEXT, -- 答案来源原文片段
|
||||||
|
quality_score REAL, -- 自动质量评分(0-1)
|
||||||
|
quality_detail TEXT, -- JSON: {answerable, clarity, accuracy, reasoning}
|
||||||
|
dup_of TEXT, -- 重复问题的 id(如果是重复的)
|
||||||
|
dup_similarity REAL, -- 与重复问题的相似度
|
||||||
|
status TEXT NOT NULL DEFAULT 'pending', -- pending/approved/rejected/edited
|
||||||
|
embedding TEXT, -- JSON 向量,用于查重
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
updated_at TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
-- 已审核通过的问题库(用于查重基准 + 导出测试)
|
||||||
|
CREATE TABLE qa_approved_question (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
gen_question_id TEXT NOT NULL, -- 关联 qa_gen_question.id
|
||||||
|
section_path TEXT NOT NULL,
|
||||||
|
question TEXT NOT NULL,
|
||||||
|
reference_answer TEXT NOT NULL,
|
||||||
|
embedding TEXT NOT NULL, -- 用于后续查重
|
||||||
|
source_task_id TEXT NOT NULL,
|
||||||
|
quality_score REAL,
|
||||||
|
approved_at TEXT NOT NULL,
|
||||||
|
approved_by TEXT DEFAULT 'auto' -- auto/manual
|
||||||
|
);
|
||||||
|
|
||||||
|
-- 索引
|
||||||
|
CREATE INDEX idx_qa_gen_question_task_id ON qa_gen_question(task_id);
|
||||||
|
CREATE INDEX idx_qa_gen_question_status ON qa_gen_question(status);
|
||||||
|
CREATE INDEX idx_qa_approved_question_section ON qa_approved_question(section_path);
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 三、前端设计
|
||||||
|
|
||||||
|
### 3.1 页面结构
|
||||||
|
|
||||||
|
新增"问题生成"一级菜单,包含两个子页面:
|
||||||
|
|
||||||
|
```
|
||||||
|
问题生成
|
||||||
|
├─ 生成任务
|
||||||
|
└─ 问题审核
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.2 生成任务页
|
||||||
|
|
||||||
|
**布局:** 类似单跳测试的任务列表页
|
||||||
|
|
||||||
|
**功能:**
|
||||||
|
- 上传 MD 文件
|
||||||
|
- 配置生成参数:
|
||||||
|
- 模型选择(下拉)
|
||||||
|
- 每章节问题数(数字输入,默认 5)
|
||||||
|
- 质量阈值(滑块,0-1,默认 0.6)
|
||||||
|
- 评分模型配置(下拉,复用 judge_config)
|
||||||
|
- 任务列表:
|
||||||
|
- 任务名称、状态、进度、创建时间
|
||||||
|
- 操作:查看问题、删除任务
|
||||||
|
|
||||||
|
**任务状态展示:**
|
||||||
|
```
|
||||||
|
┌────────────────────────────────────────────────────┐
|
||||||
|
│ 任务名称:evb_linux_development │
|
||||||
|
│ 状态:运行中 进度:45/107 章节 │
|
||||||
|
│ 已生成:225 个问题 自动通过:180 待审核:45 │
|
||||||
|
│ [查看问题] [停止任务] │
|
||||||
|
└────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.3 Question Review Page (core interaction)

**Layout:** two-column split view

```
┌───────────────────────────────────────────────────────────────────────────────┐
│ Filter: [All] [Pending] [Duplicate] [Low quality] [Approved] [Rejected]       │
│ Task: [select a task]                                                         │
├────────────┬──────────────────────────────────────────────────────────────────┤
│            │ Bulk: [Approve all] [Approve high quality (>=0.6)] [Export MD]   │
│            ├──────────────────────────────────────────────────────────────────┤
│ Sections   │ Questions                                                        │
│            │ ┌────────────────────────────────────────────────────────────┐   │
│ [ ] All    │ │ ✅ Q1: How do I configure DDR parameters?    Quality: 0.85 │   │
│ [ ] ch1    │ │ A: Edit the xxx config file...                             │   │
│   (12/15)  │ │ Source: linux_development/ddr/config                       │   │
│            │ │ [Approve] [Reject] [Edit]                                  │   │
│ [ ] ch2    │ └────────────────────────────────────────────────────────────┘   │
│   (8/10)   │ ┌────────────────────────────────────────────────────────────┐   │
│            │ │ ⚠️ Q2: Where is the DDR config file?         Quality: 0.45 │   │
│ [ ] ch3    │ │ A: In /etc/ddr.conf                                        │   │
│   (5/8)    │ │ ⚠️ Similarity to "Q1": 0.94 (possible duplicate)           │   │
│            │ │ [Approve] [Reject] [Edit] [View original]                  │   │
│            │ └────────────────────────────────────────────────────────────┘   │
│            │ ┌────────────────────────────────────────────────────────────┐   │
│            │ │ ❌ Q3: xxx?                                  Quality: 0.32 │   │
│            │ │ A: xxx                                                     │   │
│            │ │ ⚠️ Low quality: question is unclear                        │   │
│            │ │ [Approve] [Reject] [Edit]                                  │   │
│            │ └────────────────────────────────────────────────────────────┘   │
└────────────┴──────────────────────────────────────────────────────────────────┘
```

**Interaction details:**

1. **Question card status markers:**
   - ✅ Green: approved (quality_score >= threshold and not a duplicate)
   - ⚠️ Yellow: pending review (low quality or suspected duplicate)
   - ❌ Red: rejected

2. **Bulk actions:**
   - "Approve all": mark every pending question in the current filter result as approved
   - "Approve high quality": approve only questions with quality_score >= threshold that are not duplicates
   - "Export MD": export the approved questions in the standard MD format

3. **Editing a question:**
   - Opens a dialog where the question and answer can be changed
   - On save, the embedding and quality score are recomputed
   - The status becomes `edited`

4. **Viewing the original question:**
   - Clicking "View original" jumps to the card of the question this one duplicates
   - The similar parts are highlighted
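
The "Approve high quality" rule above can be sketched in a few lines. This is only an illustration: the `GenQuestion` class and its field names are assumptions made for the example, not the platform's actual model, and the 0.6 default simply mirrors the slider default from section 3.2.

```python
from dataclasses import dataclass


@dataclass
class GenQuestion:
    """Hypothetical in-memory view of one generated question."""
    id: str
    quality_score: float
    is_duplicate: bool
    status: str = "pending"  # pending / approved / rejected / edited


def is_high_quality(q: GenQuestion, threshold: float = 0.6) -> bool:
    # Same rule as the green card: score at or above threshold, not a duplicate.
    return q.quality_score >= threshold and not q.is_duplicate


def approve_high_quality(questions, threshold: float = 0.6):
    """Approve pending questions that pass the rule; return the ids that changed."""
    approved_ids = []
    for q in questions:
        if q.status == "pending" and is_high_quality(q, threshold):
            q.status = "approved"
            approved_ids.append(q.id)
    return approved_ids
```

Questions that fail the rule stay `pending`, so they remain visible under the "Pending" filter for manual review.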
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.4 Exported MD Format

The exported MD file uses exactly the same format as the single-hop test input:

```markdown
## section_path / doc_name

## Q1: question text
**A1:** reference answer

## Q2: question text
**A2:** reference answer

---

## section_path2 / doc_name2

## Q1: question text
**A1:** reference answer
```

The exported file can be uploaded directly to the "single-hop recall test" module.
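
A minimal sketch of an exporter that produces the layout shown above. The input shape (a list of dicts with `section_path`, `doc_name`, `question`, `reference_answer`) is an assumption for illustration; in particular `doc_name` is not a column of `qa_approved_question` and would have to come from elsewhere.

```python
def export_markdown(questions):
    """Render approved questions in the single-hop test input format.

    `questions` is an iterable of dicts (hypothetical shape) with the keys
    section_path, doc_name, question, reference_answer. Input order is kept.
    """
    groups = {}
    for q in questions:
        groups.setdefault((q["section_path"], q["doc_name"]), []).append(q)

    blocks = []
    for (section_path, doc_name), items in groups.items():
        lines = [f"## {section_path} / {doc_name}", ""]
        # Question numbering restarts at 1 inside every section block.
        for n, q in enumerate(items, start=1):
            lines.append(f"## Q{n}: {q['question']}")
            lines.append(f"**A{n}:** {q['reference_answer']}")
            lines.append("")
        blocks.append("\n".join(lines).rstrip())

    # Section blocks are separated by a horizontal rule.
    return "\n\n---\n\n".join(blocks) + "\n"
```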
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Implementation Priorities

### P0 (core features, 1-2 weeks)

| Module | Feature | Effort |
|------|------|--------|
| Backend | Generation task API (upload MD → LLM generates Q&A → persist) | 1-2 days |
| Backend | Question list API + approve/reject/edit APIs | 0.5 day |
| Backend | Export MD API | 0.5 day |
| Frontend | Generation tasks page (upload + configuration + task list) | 1 day |
| Frontend | Question review page (list + basic interactions) | 1-2 days |
| Database | 3 new tables + schema migration | 0.5 day |

### P1 (deduplication + quality scoring, 1 week)

| Module | Feature | Effort |
|------|------|--------|
| Backend | In-batch deduplication (hash + embedding) | 1 day |
| Backend | Automatic quality scoring (LLM scoring) | 1 day |
| Frontend | Question card status markers (quality score, duplicate flag) | 0.5 day |
| Frontend | Bulk actions (approve all, approve high quality) | 0.5 day |

### P2 (deduplication against history + optimization, 3-5 days)

| Module | Feature | Effort |
|------|------|--------|
| Backend | Deduplication against the historical question bank | 0.5 day |
| Frontend | "View original" navigation | 0.5 day |
| Frontend | Edit question dialog | 0.5 day |
| Optimization | Batched embedding computation | 0.5 day |
| Optimization | Concurrency control for generation tasks | 0.5 day |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Technology Choices

### LLM Models

| Purpose | Recommended model | Alternatives |
|------|---------|------|
| Question generation | gpt-4o-mini | gpt-4o, claude-3.5-sonnet |
| Quality scoring | gpt-4o-mini | gpt-4o |
| Embedding | text-embedding-3-small | text-embedding-3-large |

### Dependencies

- **Backend:** reuse the OpenAI configuration stored in the existing `judge_config` table
- **Embedding:** the OpenAI SDK, or `sentence-transformers` if local deployment is required
- **Similarity:** `numpy.dot` + `numpy.linalg.norm` (cosine similarity)
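
The hash + embedding duplicate check can be sketched as below. To keep the example dependency-free it computes cosine similarity with the standard library; a production version would more likely use the `numpy` calls listed above. The 0.9 threshold, the whitespace/case normalization, and the `(id, question, embedding)` tuple shape are assumptions for illustration only.

```python
import hashlib
import math


def text_hash(question: str) -> str:
    """Hash of the question with whitespace removed and case folded."""
    normalized = "".join(question.split()).lower()
    return hashlib.sha256(normalized.encode("utf-8")).hexdigest()


def cosine_similarity(a, b) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def find_duplicates(items, threshold: float = 0.9):
    """items: list of (id, question, embedding) in generation order.

    Returns {duplicate_id: (original_id, similarity)}. Exact matches are
    caught by the hash first; the embedding comparison handles paraphrases.
    """
    seen_hashes = {}  # hash -> id of the first question with that text
    kept = []         # (id, embedding) of questions accepted as originals
    duplicates = {}
    for qid, question, emb in items:
        h = text_hash(question)
        if h in seen_hashes:
            duplicates[qid] = (seen_hashes[h], 1.0)
            continue
        best_id, best_sim = None, 0.0
        for kept_id, kept_emb in kept:
            sim = cosine_similarity(emb, kept_emb)
            if sim > best_sim:
                best_id, best_sim = kept_id, sim
        if best_id is not None and best_sim >= threshold:
            duplicates[qid] = (best_id, best_sim)
        else:
            seen_hashes[h] = qid
            kept.append((qid, emb))
    return duplicates
```

The returned `original_id` is what the "View original" button on the review page would navigate to.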
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Risks and Caveats

### 6.1 Cost Control

- **Question generation:** each section is roughly 500-2000 input tokens; generating 5 questions takes about 500 output tokens
  - Estimate: 12,000 questions (2,400 sections × 5) ≈ 3M input tokens + 1M output tokens
  - Cost (gpt-4o-mini): about $0.6
- **Quality scoring:** about 300 input tokens + 100 output tokens per question
  - Estimate: 12,000 questions ≈ 3.6M input tokens + 1.2M output tokens
  - Cost (gpt-4o-mini): about $0.7
- **Embedding:** about 20 tokens per question
  - Estimate: 12,000 questions ≈ 240K tokens
  - Cost (text-embedding-3-small): about $0.005

**Total cost:** about $1.3 per 12,000 questions
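
The token volumes behind this estimate can be reproduced with a small calculator. This is a sketch under stated assumptions: 1,250 input tokens per section is simply the midpoint of the 500-2000 range, and the per-million-token prices are parameters the caller must supply from the provider's current price list, since the dollar figures depend entirely on them.

```python
def estimate_tokens(sections: int = 2400, questions_per_section: int = 5) -> dict:
    """Token volumes for one full run, using the per-item sizes quoted above."""
    questions = sections * questions_per_section
    return {
        "questions": questions,
        "generation_in": sections * 1250,  # assumed midpoint of 500-2000
        "generation_out": sections * 500,
        "scoring_in": questions * 300,
        "scoring_out": questions * 100,
        "embedding": questions * 20,
    }


def llm_cost_usd(input_tokens: int, output_tokens: int,
                 price_in_per_m: float, price_out_per_m: float) -> float:
    """Cost in USD given prices per one million tokens (caller-supplied)."""
    return (input_tokens / 1e6) * price_in_per_m + (output_tokens / 1e6) * price_out_per_m
```

For example, `llm_cost_usd(t["generation_in"], t["generation_out"], price_in, price_out)` with `t = estimate_tokens()` gives the generation cost for whatever prices apply at the time.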
|
||||||
|
|
||||||
|
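### 6.2 Performance Optimization

- **Concurrency control:** generation tasks use `asyncio.Semaphore` to cap concurrency (default 5)
- **Batched embedding:** compute embeddings for at most 100 questions per call
- **Deduplication:** use vectorized numpy operations instead of Python loops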
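
The first two points can be sketched with the standard library alone. `generate_one` stands in for whatever coroutine calls the LLM for a single section; it is a placeholder name, not an existing function in the codebase.

```python
import asyncio


def batched(items, size: int = 100):
    """Split a list into consecutive chunks of at most `size` items."""
    return [items[i:i + size] for i in range(0, len(items), size)]


async def generate_all(sections, generate_one, max_concurrency: int = 5):
    """Run generate_one(section) for every section, at most max_concurrency at once.

    Results come back in the same order as `sections`.
    """
    semaphore = asyncio.Semaphore(max_concurrency)

    async def guarded(section):
        async with semaphore:  # blocks while max_concurrency calls are in flight
            return await generate_one(section)

    return await asyncio.gather(*(guarded(s) for s in sections))
```

`batched(questions, 100)` yields the groups that would each go into a single embedding request.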
|
||||||
|
|
||||||
|
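### 6.3 Data Consistency

- **Transactions:** approve/reject operations run inside a database transaction
- **Idempotency:** when a generation task is submitted again, check whether an identical task already exists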
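
A sketch of a transactional, idempotent approve operation using `sqlite3`. The `qa_approved_question` columns follow the schema in this document; the `qa_gen_question` columns other than `id`, `task_id`, and `status` are assumed for the example, as is the `appr_` id prefix.

```python
import sqlite3
from datetime import datetime, timezone

# Minimal schema for the example; qa_gen_question's columns are partly assumed.
SCHEMA = """
CREATE TABLE qa_gen_question (
    id TEXT PRIMARY KEY, task_id TEXT NOT NULL, section_path TEXT NOT NULL,
    question TEXT NOT NULL, reference_answer TEXT NOT NULL,
    embedding TEXT NOT NULL, quality_score REAL, status TEXT NOT NULL
);
CREATE TABLE qa_approved_question (
    id TEXT PRIMARY KEY, gen_question_id TEXT NOT NULL, section_path TEXT NOT NULL,
    question TEXT NOT NULL, reference_answer TEXT NOT NULL, embedding TEXT NOT NULL,
    source_task_id TEXT NOT NULL, quality_score REAL, approved_at TEXT NOT NULL,
    approved_by TEXT DEFAULT 'auto'
);
"""


def approve_question(conn: sqlite3.Connection, question_id: str,
                     approved_by: str = "manual") -> bool:
    """Approve one question atomically. Returns False if it was already approved."""
    with conn:  # commits on success, rolls back if any statement raises
        row = conn.execute(
            "SELECT task_id, section_path, question, reference_answer, embedding,"
            " quality_score, status FROM qa_gen_question WHERE id = ?",
            (question_id,),
        ).fetchone()
        if row is None:
            raise KeyError(question_id)
        task_id, section_path, question, answer, embedding, score, status = row
        if status == "approved":
            return False  # repeated submission is a no-op
        conn.execute("UPDATE qa_gen_question SET status = 'approved' WHERE id = ?",
                     (question_id,))
        conn.execute(
            "INSERT INTO qa_approved_question (id, gen_question_id, section_path,"
            " question, reference_answer, embedding, source_task_id, quality_score,"
            " approved_at, approved_by) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            ("appr_" + question_id, question_id, section_path, question, answer,
             embedding, task_id, score,
             datetime.now(timezone.utc).isoformat(), approved_by),
        )
    return True
```

If the insert fails, the status update is rolled back with it, so a question can never be marked approved without a matching row in the approved bank.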
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
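## 7. Future Extensions

### 7.1 Advanced Features

- **Difficulty grading:** automatically label each question as easy, medium, or hard
- **Topic tags:** automatically extract the knowledge topics a question touches
- **Multi-turn questions:** generate complex questions that require several rounds of interaction
- **Negative samples:** generate deliberately wrong answers to test model robustness

### 7.2 Integration Improvements

- **Link with the single-hop test:** create a single-hop test task automatically once questions are approved
- **Test result feedback:** automatically flag questions that fail the single-hop test as "needs improvement"
- **Continuous iteration:** adjust the generation strategy automatically based on test results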
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
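## 8. Summary

This proposal delivers a complete "generate → deduplicate → review → test" loop. Its main strengths:

1. **High automation:** the design targets auto-approval for roughly 90% of questions (the high-quality ones), leaving about 10% for manual review
2. **Controlled quality:** multi-dimensional quality scoring plus deduplication keep question quality in check
3. **Seamless integration:** the export format is fully compatible with the existing single-hop test
4. **Extensible:** the modular design makes later extensions straightforward

**Expected effect:** question production throughput rises by roughly 100×, from about 10 hand-written questions per hour to 1000+ LLM-generated questions per hour (review included).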
|
||||||
1684
docs/RAG-Eval平台技术规格说明书.md
Normal file
File diff suppressed because it is too large
Load Diff
38
docs/README.md
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
# RAG Eval Documentation Index

This directory collects **technical documentation** and **data assets** (proposals, rules, configuration, exported results). For the project overview, see the [`README.md`](../README.md) at the repository root.

---

## Batching and Data

| Document | Description |
|------|------|
| [循环测试_14组分批规则.md](./循环测试_14组分批规则.md) | Human-readable batching rules for the loop test (14 groups × 42 batches). Kept locally and not committed when it contains environment details |
| task_groups_plan.json / exports/ | Local data assets (**not tracked in Git**, see `.gitignore`) |

---

## Architecture and Design

| Document | Description |
|------|------|
| [**RAG-Eval平台技术规格说明书.md**](./RAG-Eval平台技术规格说明书.md) | **Full-length technical specification** (architecture diagrams, sequence diagrams, data model, API, metrics) |
| [rag-eval-framework-design.md](./rag-eval-framework-design.md) | Overall design of the evaluation framework (early draft) |
| [TUTORIAL.md](./TUTORIAL.md) | Usage tutorial |
| [config.example.yaml](./config.example.yaml) | SDK configuration example (a copy; `sdk/config.example.yaml` remains the authoritative file at runtime) |

---

## Proposals and Reports

| Document | Description |
|------|------|
| [LLM自动生成问题方案.md](./LLM自动生成问题方案.md) | Workflow for LLM-based automatic question generation |
| [多模态问答集生成方案.md](./多模态问答集生成方案.md) | Multimodal Q&A set generation |
| [基于Dagent平台的多模态问答集生成方案.md](./基于Dagent平台的多模态问答集生成方案.md) | Multimodal proposal built on the Dagent platform |
| [Dagent文件选择器方案.md](./Dagent文件选择器方案.md) | File selector |
| [EVB知识库单跳召回测试报告.md](./EVB知识库单跳召回测试报告.md) | EVB single-hop recall test report |
| [验证报告.md](./验证报告.md) | Verification report |
| [multi-hop-example.md](./multi-hop-example.md) | Sample MD format for the multi-hop test |
|
||||||
25
docs/config.example.yaml
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Platform connection settings
platform:
  base_url: "http://localhost:8000"
  org_id: "your_org_id"
  token: ""  # fill in if the platform requires an auth token

# Judge LLM settings (OpenAI-compatible API)
judge:
  base_url: "https://api.openai.com/v1"
  api_key: "sk-your-key"
  model: "gpt-4o"

# Evaluation parameters
eval:
  agent_id: "your_agent_id"
  knowledge_hub_id: "your_hub_id"
  top_k: 10
  eval_retrieval: true
  eval_generation: true
  file_id_list:
    - "file_id_1"
    - "file_id_2"
  concurrency: 3
  questions_per_chunk: 2
  max_chunks: 50
|
||||||
23
docs/multi-hop-example.md
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
## MH1
|
||||||
|
**类型:** comparison
|
||||||
|
**问题:** RDK X3 和 RDK X5 的 CPU 核心数和主频分别是多少,有何差异?
|
||||||
|
**答案:** RDK X3 搭载 4 核 ARM Cortex-A53,主频 1.2GHz;RDK X5 搭载 8 核 ARM Cortex-A55,主频 1.5GHz,X5 核心数翻倍且主频更高。
|
||||||
|
**Hop1:** hardware / rdk_x3_spec | 提供 RDK X3 的 CPU 规格参数
|
||||||
|
**Hop2:** hardware / rdk_x5_spec | 提供 RDK X5 的 CPU 规格参数
|
||||||
|
---
|
||||||
|
|
||||||
|
## MH2
|
||||||
|
**类型:** reasoning
|
||||||
|
**问题:** 使用 RDK 开发板进行 BPU 推理时,需要先完成哪些环境准备步骤?
|
||||||
|
**答案:** 需要先完成系统烧录、驱动安装,再配置 Python 环境,最后安装 horizon_bpu 推理库。
|
||||||
|
**Hop1:** quick_start / system_install | 提供系统烧录和驱动安装步骤
|
||||||
|
**Hop2:** linux_development / bpu_develop | 提供 BPU 推理环境配置和库安装步骤
|
||||||
|
---
|
||||||
|
|
||||||
|
## MH3
|
||||||
|
**类型:** aggregation
|
||||||
|
**问题:** RDK 平台支持哪些多媒体编解码格式,对应的硬件加速模块是什么?
|
||||||
|
**答案:** 支持 H.264/H.265 编解码,由 VPU 硬件模块加速;支持 JPEG 编解码,由 JPU 模块加速。
|
||||||
|
**Hop1:** multimedia_development / codec_overview | 提供支持的编解码格式列表
|
||||||
|
**Hop2:** hardware / hardware_modules | 提供 VPU/JPU 硬件模块说明
|
||||||
|
---
|
||||||
646
docs/rag-eval-framework-design.md
Normal file
@ -0,0 +1,646 @@
|
|||||||
|
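# RAG Evaluation Framework Design

> Version: v1.0
> Date: 2026-04-13
> Context: a standalone RAG evaluation framework designed for the dagent agent platform

---

## 1. Background and Goals

### Why a standalone framework

The dagent platform already has full RAG capability (knowledge-base chunking, vector retrieval, ReAct Agent) but lacks a systematic way to evaluate it. Evaluation is built as a **standalone framework** rather than embedded in the existing backend for these reasons:

- **Platform independent**: a standardized Adapter interface lets it evaluate any RAG system, not only dagent
- **Independently deployed**: it does not affect production services, scales on its own, and evaluation jobs do not consume business resources
- **Free choice of stack**: it can use whichever tools and models suit evaluation best
- **Reusable**: other projects can plug into it as well

### Goals

1. Provide a complete metric suite for both the **retrieval layer** and the **generation layer**
2. Support integration into CI/CD through a **Python SDK**
3. Provide a **Web UI** so non-technical users can run evaluations and read reports
4. Connect to the dagent platform while staying extensible to other platforms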
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Evaluation Metrics

### 2.1 Retrieval Evaluation

Measures the recall quality of knowledge-base chunks. Hit Rate, MRR, and NDCG are **purely computed and need no LLM**; Context Precision and Context Recall rely on an LLM judge.

| Metric | Full name | Description | How it is computed |
|------|------|------|----------|
| **Hit Rate@K** | Hit rate | Whether the Top-K results contain at least one relevant chunk | Binary per sample, averaged over all samples |
| **MRR@K** | Mean Reciprocal Rank | Mean of the reciprocal rank of the first relevant chunk | `MRR = mean(1 / rank_i)`, where rank_i is the position of the first relevant chunk |
| **NDCG@K** | Normalized Discounted Cumulative Gain | Rank-weighted relevance score; the most comprehensive retrieval metric | `NDCG = DCG / IDCG`; DCG gives more weight to relevant results ranked higher |
| **Context Precision** | Context precision | Share of retrieved chunks that are actually relevant (signal-to-noise ratio) | LLM-as-judge decides whether each retrieved chunk is relevant |
| **Context Recall** | Context recall | Share of the information needed for the answer that the retrieved chunks cover | The LLM splits the reference answer into atomic claims and checks whether each claim is covered by the retrieved content |

**Metric formulas**

```
Hit Rate@K = (1/|Q|) * Σ 1[∃ relevant chunk in top-K results]

MRR@K = (1/|Q|) * Σ (1 / rank_i)
        rank_i = position of first relevant chunk for query i

DCG@K  = Σ_{i=1}^{K} rel_i / log2(i+1)
NDCG@K = DCG@K / IDCG@K
         IDCG = DCG of ideal (perfect) ranking

Context Precision = |relevant ∩ retrieved| / |retrieved|
Context Recall    = |ground truth claims covered by context| / |total ground truth claims|
```
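
The three LLM-free metrics translate directly into code. The sketch below assumes binary relevance (a chunk is either in the labeled set or not) and that `retrieved` contains no repeated chunk ids; the function names are illustrative and not taken from the SDK.

```python
import math


def hit_rate_at_k(retrieved, relevant, k: int) -> float:
    """1.0 if any of the top-k retrieved chunk ids is relevant, else 0.0."""
    return 1.0 if any(cid in relevant for cid in retrieved[:k]) else 0.0


def reciprocal_rank_at_k(retrieved, relevant, k: int) -> float:
    """1 / rank of the first relevant chunk within the top k, or 0.0 if none."""
    for rank, cid in enumerate(retrieved[:k], start=1):
        if cid in relevant:
            return 1.0 / rank
    return 0.0


def ndcg_at_k(retrieved, relevant, k: int) -> float:
    """NDCG with binary gains: rel_i is 1 for relevant chunks and 0 otherwise."""
    dcg = sum(
        1.0 / math.log2(i + 1)
        for i, cid in enumerate(retrieved[:k], start=1)
        if cid in relevant
    )
    # Ideal ranking puts all relevant chunks (up to k of them) first.
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(i + 1) for i in range(1, ideal_hits + 1))
    return dcg / idcg if idcg else 0.0


def mean(values) -> float:
    values = list(values)
    return sum(values) / len(values) if values else 0.0
```

Per-sample values are then averaged over the test set, for example `mean(reciprocal_rank_at_k(r, rel, 10) for r, rel in samples)` gives MRR@10.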
|
||||||
|
|
||||||
|
### 2.2 Generation Evaluation

Measures the quality of the Agent's reply given the retrieved content. **Requires an LLM judge.**

| Metric | Description | How it is computed | Needs a reference answer |
|------|------|----------|-----------------|
| **Faithfulness** | Whether every claim in the answer is supported by the retrieved content, i.e. no hallucination | The LLM splits the answer into atomic claims → checks each claim against the context → supported / total | No |
| **Answer Relevance** | Whether the answer addresses the question | The LLM generates questions back from the answer → embedding similarity against the original question | No |
| **Answer Correctness** | How well the answer agrees factually with the reference answer | LLM judge score, weighted with embedding similarity | Yes |
| **Groundedness** | Whether every claim in the answer can be traced to a specific chunk | LLM-as-judge with chain-of-thought | No |

**How Faithfulness is computed (the most important metric)**

```
1. The LLM splits the answer into a list of atomic claims
   Example: "Answer: Beijing is the capital of China, with a population of about 22 million"
   → ["Beijing is the capital of China", "Beijing has a population of about 22 million"]

2. For each claim, the LLM decides: can it be derived from the retrieved context?
   → [True, False]  (the second claim cannot be derived from the context = hallucination)

3. Faithfulness = supported claims / total claims = 1/2 = 0.5
```
|
||||||
|
|
||||||
|
### 2.3 End-to-End Aggregate Metrics

| Metric | How it is computed | Description |
|------|----------|------|
| **RAG Score** | Harmonic mean of Faithfulness, Answer Relevance, Context Precision, and Context Recall | Overall score; any single weak metric pulls it down |
| **Hallucination Rate** | Samples with hallucination (Faithfulness < threshold) / total samples | How often hallucination occurs |
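
Both aggregates are short to compute. In this sketch the function names are illustrative, and the 0.8 default threshold is an assumed value: the document does not fix the Faithfulness threshold.

```python
def rag_score(faithfulness: float, answer_relevance: float,
              context_precision: float, context_recall: float) -> float:
    """Harmonic mean of the four component scores (0.0 if any component is 0)."""
    values = [faithfulness, answer_relevance, context_precision, context_recall]
    if any(v <= 0 for v in values):
        return 0.0
    return len(values) / sum(1.0 / v for v in values)


def hallucination_rate(faithfulness_scores, threshold: float = 0.8) -> float:
    """Share of samples whose Faithfulness falls below the (assumed) threshold."""
    scores = list(faithfulness_scores)
    if not scores:
        return 0.0
    return sum(1 for s in scores if s < threshold) / len(scores)
```

The harmonic mean is what makes a single weak metric drag the total down: three scores of 0.9 and one of 0.3 give about 0.6, whereas the arithmetic mean would be 0.75.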
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. System Architecture

### 3.1 Overall Architecture

```
┌──────────────────────────────────────────────────────────────────────┐
│                          RAG Eval Framework                          │
│                                                                      │
│  ┌────────────────┐    ┌──────────────────┐    ┌──────────────────┐  │
│  │  Python SDK    │    │  FastAPI Server  │    │  React Web UI    │  │
│  │  (core logic)  │ ←→ │  (REST API)      │ ←→ │  (visual reports)│  │
│  │  CLI support   │    │  Task queue      │    │  Dataset mgmt    │  │
│  └────────────────┘    └──────────────────┘    └──────────────────┘  │
│          ↓                     ↓                                     │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                          Core modules                          │  │
│  │     Adapters  │  Evaluators  │  LLM Judge  │  Dataset Gen      │  │
│  └────────────────────────────────────────────────────────────────┘  │
└──────────────────────────────────────────────────────────────────────┘
                    ↕  HTTP API (standardized Adapter interface)
┌──────────────────────────────────────────────────────────────────────┐
│                dagent platform / any other RAG system                │
└──────────────────────────────────────────────────────────────────────┘
```

### 3.2 Data Flow

```
Test set (question + relevant_chunk_ids + reference_answer)
        ↓
EvalRunner.run(dataset, agent_id, knowledge_hub_id)
        ↓
┌──────────────────────────────────────────────────────┐
│  for each sample:                                    │
│                                                      │
│  Step 1: adapter.retrieve(question)                  │
│    → get the Top-K retrieved chunks                  │
│    → compute Hit Rate / MRR / NDCG (vs. labels)      │
│                                                      │
│  Step 2: adapter.chat(question)                      │
│    → get the Agent reply + cited chunks              │
│    → judge.score_faithfulness(answer, context)       │
│    → judge.score_relevance(question, answer)         │
│    → judge.score_correctness(answer, reference)      │
└──────────────────────────────────────────────────────┘
        ↓
EvalReport (per-sample details + summary statistics + trend comparison)
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Project Structure

```
rag-eval/
├── sdk/                          # Python SDK (core)
│   ├── rag_eval/
│   │   ├── __init__.py
│   │   ├── runner.py             # evaluation task runner (entry point)
│   │   ├── adapters/             # platform adapters
│   │   │   ├── base.py           # abstract interface definition
│   │   │   └── dagent.py         # dagent adapter implementation
│   │   ├── evaluators/           # evaluators
│   │   │   ├── retrieval.py      # retrieval layer: Hit Rate / MRR / NDCG
│   │   │   └── generation.py     # generation layer: Faithfulness / Relevance / Correctness
│   │   ├── judge/                # LLM Judge
│   │   │   ├── base.py           # abstract interface
│   │   │   └── openai_compatible.py  # works with DeepSeek / Qwen / OpenAI
│   │   ├── dataset/              # test set management
│   │   │   ├── schema.py         # data structures (Pydantic)
│   │   │   └── generator.py      # LLM-based test set generation
│   │   └── report.py             # report generation and formatting
│   ├── pyproject.toml
│   └── README.md
│
├── server/                       # FastAPI backend
│   ├── main.py
│   ├── api/
│   │   ├── dataset.py            # test set CRUD
│   │   ├── task.py               # evaluation task management
│   │   ├── report.py             # report queries
│   │   └── config.py             # platform connection & Judge configuration
│   ├── service/
│   │   ├── task_service.py
│   │   └── report_service.py
│   ├── models/                   # database models (SQLite / PostgreSQL)
│   │   └── schema.sql
│   └── requirements.txt
│
├── frontend/                     # React frontend
│   ├── src/
│   │   ├── pages/
│   │   │   ├── Dataset/          # test set management (upload / generate / annotate)
│   │   │   ├── Task/             # evaluation tasks (configure / submit / progress)
│   │   │   └── Report/           # reports & visualization (radar chart / trend chart)
│   │   └── components/
│   └── package.json
│
└── docker-compose.yml            # one-command deployment
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Core Interface Design

### 5.1 Adapter Abstract Interface

```python
# sdk/rag_eval/adapters/base.py

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class RetrievedChunk:
    chunk_id: str
    content: str
    score: float     # similarity score
    headers: str     # heading of the section the chunk belongs to
    file_id: str


@dataclass
class AgentResponse:
    answer: str
    retrieved_chunks: list[RetrievedChunk]  # chunks the Agent actually used
    latency_ms: int


class RAGAdapter(ABC):
    """
    Every RAG platform has to implement these two methods.
    The framework talks to the platform only through this interface and
    does not depend on the platform's internals.
    """

    @abstractmethod
    async def retrieve(
        self,
        query: str,
        knowledge_hub_id: str,
        top_k: int = 10,
        **kwargs
    ) -> list[RetrievedChunk]:
        """Call the platform's retrieval API and return the retrieved chunks."""
        ...

    @abstractmethod
    async def chat(
        self,
        query: str,
        agent_id: str,
        **kwargs
    ) -> AgentResponse:
        """Call the platform's Agent chat API and return the reply plus cited chunks."""
        ...
```
|
||||||
|
|
||||||
|
### 5.2 dagent Adapter

```python
# sdk/rag_eval/adapters/dagent.py

import aiohttp

from .base import RAGAdapter, RetrievedChunk


class DagentAdapter(RAGAdapter):
    """
    Adapter for the dagent platform.
    It calls dagent over its HTTP API and does not depend on dagent's internal code.
    """

    def __init__(self, base_url: str, org_id: str, token: str):
        self.base_url = base_url
        self.org_id = org_id
        self.headers = {"Authorization": f"Bearer {token}"}

    async def retrieve(self, query, knowledge_hub_id, top_k=10, **kwargs):
        # Call the dagent knowledge-base retrieval endpoint:
        # POST /dagent/knowledge/retrieve
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.base_url}/dagent/knowledge/retrieve",
                json={"query": query, "knowledge_hub_id": knowledge_hub_id,
                      "top_k": top_k, "org_id": self.org_id},
                headers=self.headers
            ) as resp:
                resp.raise_for_status()  # fail loudly on HTTP errors
                data = await resp.json()
        return [RetrievedChunk(**chunk) for chunk in data["chunks"]]

    async def chat(self, query, agent_id, **kwargs):
        # Call the dagent Agent chat endpoint (SSE stream; assemble the full reply):
        # POST /dagent/agent/chat
        ...
```
|
||||||
|
|
||||||
|
### 5.3 LLM Judge

```python
# sdk/rag_eval/judge/openai_compatible.py

from openai import AsyncOpenAI

from .base import LLMJudge  # abstract interface, expected in judge/base.py


class OpenAICompatibleJudge(LLMJudge):
    """
    Works with any model served over the OpenAI protocol:
    DeepSeek / Qwen / OpenAI / Azure OpenAI.
    The judging prompts are written in Chinese, which suits Chinese RAG scenarios.
    """

    def __init__(self, base_url: str, api_key: str, model: str):
        self.client = AsyncOpenAI(base_url=base_url, api_key=api_key)
        self.model = model

    # self._call(prompt) and self._call_json(prompt) are helper methods of
    # this class that are not shown in this excerpt.

    async def score_faithfulness(self, answer: str, context: list[str]) -> float:
        """
        How it works:
        1. Ask the LLM to split the answer into a list of atomic claims
        2. For each claim, decide whether it can be derived from the context
        3. Return supported claims / total claims
        """
        context_text = "\n\n".join(context)

        # Step 1: split into atomic claims
        decompose_prompt = f"""
        请将以下回答分解为独立的原子声明列表,每条声明是一个不可再分的事实陈述。
        回答:{answer}
        输出格式:JSON 数组,如 ["声明1", "声明2", ...]
        """
        claims = await self._call_json(decompose_prompt)

        # Step 2: check each claim against the context
        supported = 0
        for claim in claims:
            verify_prompt = f"""
            参考资料:
            {context_text}

            声明:{claim}

            问题:上述声明是否可以从参考资料中推导出来?
            只回答 yes 或 no。
            """
            result = await self._call(verify_prompt)
            # Match a leading "yes" so that a stray substring elsewhere does not count.
            if result.strip().lower().startswith("yes"):
                supported += 1

        return supported / len(claims) if claims else 0.0

    async def score_relevance(self, question: str, answer: str) -> float:
        """
        How it works:
        1. Ask the LLM to generate N questions back from the answer
        2. Compute the embedding similarity between those questions and the original question
        3. Return the mean
        """
        ...

    async def score_correctness(self, answer: str, reference: str) -> float:
        """
        How it works: the LLM compares answer and reference and returns a score from 0 to 1.
        """
        prompt = f"""
        请评估以下回答与参考答案的事实一致程度,给出 0 到 1 之间的分数。
        1.0 = 完全一致,0.0 = 完全错误或无关。

        参考答案:{reference}
        待评估回答:{answer}

        只输出一个 0 到 1 之间的小数。
        """
        result = await self._call(prompt)
        # float() raises ValueError if the reply is not a bare number;
        # the clamp keeps out-of-range replies inside [0, 1].
        return min(1.0, max(0.0, float(result.strip())))
```
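
Judge models do not always follow the "output only a number" or "output only JSON" instruction, so it may help to parse replies defensively. The two helpers below are a sketch of that idea; `parse_score` and `parse_claims` are hypothetical names and are not part of the SDK shown above.

```python
import json
import re


def parse_score(reply: str) -> float:
    """Extract the first number from an LLM reply and clamp it to [0, 1]."""
    match = re.search(r"-?\d*\.?\d+", reply)
    if match is None:
        raise ValueError(f"no numeric score found in reply: {reply!r}")
    return min(1.0, max(0.0, float(match.group())))


def parse_claims(reply: str) -> list:
    """Extract a JSON array of non-empty strings from an LLM reply.

    Returns an empty list when no valid array is found, which callers can
    treat the same way as "no claims".
    """
    start, end = reply.find("["), reply.rfind("]")
    if start == -1 or end <= start:
        return []
    try:
        data = json.loads(reply[start:end + 1])
    except json.JSONDecodeError:
        return []
    if not isinstance(data, list):
        return []
    return [c.strip() for c in data if isinstance(c, str) and c.strip()]
```

Taking the text between the first `[` and the last `]` tolerates replies wrapped in a Markdown code fence or preceded by a short explanation.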
|
||||||
|
|
||||||
|
### 5.4 Test Set Data Structures

```python
# sdk/rag_eval/dataset/schema.py

from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class EvalSample:
    id: str
    question: str                      # test question
    reference_answer: str              # reference (gold) answer
    relevant_chunk_ids: list[str]      # labeled relevant chunk ids (for retrieval evaluation)
    knowledge_hub_id: str              # knowledge base the sample belongs to
    source_file_id: str | None = None  # source file (optional)
    metadata: dict = field(default_factory=dict)


@dataclass
class EvalDataset:
    id: str
    name: str
    description: str
    samples: list[EvalSample]
    created_at: datetime
```
|
||||||
|
|
||||||
|
### 5.5 SDK Usage Example

```python
import asyncio

from rag_eval import EvalRunner
from rag_eval.adapters import DagentAdapter
from rag_eval.judge import OpenAICompatibleJudge


async def main():
    # Configure the adapter (connects to dagent)
    adapter = DagentAdapter(
        base_url="http://dagent-backend:8000",
        org_id="org_xxx",
        token="your-token"
    )

    # Configure the LLM Judge (independent of dagent; DeepSeek here)
    judge = OpenAICompatibleJudge(
        base_url="https://api.deepseek.com/v1",
        api_key="sk-xxx",
        model="deepseek-chat"
    )

    # Run the evaluation
    runner = EvalRunner(adapter=adapter, judge=judge)
    report = await runner.run(
        dataset="./my_dataset.json",
        agent_id="agent_xxx",
        knowledge_hub_id="hub_xxx",
        top_k=10,
    )

    # Inspect the result
    print(report.summary())
    # ┌─────────────────────────────────────────┐
    # │        Evaluation report summary        │
    # ├──────────────────────┬──────────────────┤
    # │ Samples              │ 200              │
    # │ Hit Rate@10          │ 0.87             │
    # │ MRR@10               │ 0.72             │
    # │ NDCG@10              │ 0.81             │
    # │ Context Precision    │ 0.76             │
    # │ Context Recall       │ 0.83             │
    # │ Faithfulness         │ 0.91             │
    # │ Answer Relevance     │ 0.88             │
    # │ Answer Correctness   │ 0.79             │
    # │ RAG Score            │ 0.84             │
    # │ Hallucination Rate   │ 4.5%             │
    # └──────────────────────┴──────────────────┘

    report.save("./eval_report_20260413.json")


asyncio.run(main())
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Building the Test Set

### 6.1 Data Structure

Each test sample:
```json
{
  "id": "sample_001",
  "question": "什么是向量数据库?",
  "reference_answer": "向量数据库是专门存储和检索高维向量的数据库系统...",
  "relevant_chunk_ids": ["chunk_abc123", "chunk_def456"],
  "knowledge_hub_id": "hub_xxx",
  "source_file_id": "file_yyy"
}
```
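
A sample in this shape can be checked and loaded with a few lines of code. The sketch below redefines `EvalSample` locally so that it runs on its own, and `parse_sample` is a hypothetical helper name. It deliberately allows an empty `relevant_chunk_ids` list, on the assumption that "unanswerable" questions have no relevant chunk.

```python
import json
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class EvalSample:
    id: str
    question: str
    reference_answer: str
    relevant_chunk_ids: list
    knowledge_hub_id: str
    source_file_id: Optional[str] = None
    metadata: dict = field(default_factory=dict)


REQUIRED_FIELDS = ("id", "question", "reference_answer",
                   "relevant_chunk_ids", "knowledge_hub_id")


def parse_sample(raw: dict) -> EvalSample:
    """Validate one decoded JSON object and convert it to an EvalSample."""
    missing = [name for name in REQUIRED_FIELDS if name not in raw]
    if missing:
        raise ValueError(f"sample is missing required fields: {missing}")
    if not isinstance(raw["relevant_chunk_ids"], list):
        raise ValueError("relevant_chunk_ids must be a list")
    return EvalSample(
        id=raw["id"],
        question=raw["question"],
        reference_answer=raw["reference_answer"],
        relevant_chunk_ids=list(raw["relevant_chunk_ids"]),
        knowledge_hub_id=raw["knowledge_hub_id"],
        source_file_id=raw.get("source_file_id"),
        metadata=dict(raw.get("metadata") or {}),
    )
```

Usage is `parse_sample(json.loads(text))` for a single object, or a list comprehension over a decoded JSON array.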
|
||||||
|
|
||||||
|
### 6.2 How to Build It

**Option A: LLM-generated (recommended as a starting point)**

```python
from rag_eval.dataset import DatasetGenerator

generator = DatasetGenerator(judge=judge, adapter=adapter)
dataset = await generator.generate(
    knowledge_hub_id="hub_xxx",
    questions_per_chunk=2,
    question_types=["factual", "reasoning", "comparison", "unanswerable"]
)
# Generates questions + reference answers and labels relevant_chunk_ids automatically
```

How it works:
1. Iterate over every chunk in the knowledge base
2. For each chunk, have the LLM generate 2-3 questions of different types
3. Have the LLM write a reference answer based on the chunk content
4. Label `relevant_chunk_ids` automatically with the chunk the question was generated from
5. Manually spot-check 10-20% of the samples to filter out low-quality ones
|
||||||
|
|
||||||
|
**Option B: manual annotation (highest quality)**

The Web UI provides an annotation screen:
- Enter a question
- Search for and mark the relevant chunks
- Write the reference answer

**Suggested coverage of question types**

| Type | Example | Suggested share |
|------|------|----------|
| Factual lookup | "What is X?" | 40% |
| Multi-hop reasoning | "How are X and Y related?" | 20% |
| Comparison | "What is the difference between X and Y?" | 20% |
| Unanswerable | Information that is not in the documents | 10% |
| Summary | "Summarize the main points of X" | 10% |

Recommended test set size: **200-500 samples**. Below 100 samples the results carry little statistical weight.
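
Turning the suggested shares into per-type sample counts is a small rounding problem. The sketch below uses the largest-remainder method with integer percentages so that the counts always add up to the requested total; the type keys are illustrative names, not identifiers from the SDK.

```python
SUGGESTED_SHARES = {  # percentages from the table above
    "factual": 40,
    "multi_hop": 20,
    "comparison": 20,
    "unanswerable": 10,
    "summary": 10,
}


def allocate_by_share(total: int, shares: dict) -> dict:
    """Split `total` samples across types by integer percentage shares.

    Uses the largest-remainder method: floor every quota, then hand the
    leftover samples to the types with the largest fractional parts.
    """
    if sum(shares.values()) != 100:
        raise ValueError("shares must add up to 100")
    counts = {name: total * pct // 100 for name, pct in shares.items()}
    remainders = {name: total * pct % 100 for name, pct in shares.items()}
    leftover = total - sum(counts.values())
    ranked = sorted(shares, key=lambda name: remainders[name], reverse=True)
    for name in ranked[:leftover]:
        counts[name] += 1
    return counts
```

For a 200-sample test set this yields 80 factual, 40 multi-hop, 40 comparison, 20 unanswerable, and 20 summary questions.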
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
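## 7. Web UI Feature Plan

| Page | Core features |
|------|----------|
| **Test set management** | Upload a JSON test set, LLM-based generation, manual annotation screen, sample preview |
| **Evaluation tasks** | Configure the Adapter (platform connection), configure the Judge model, submit tasks, live progress |
| **Evaluation reports** | Radar chart of metric scores, per-sample detail table, trend comparison across runs, drill-down into problem samples |
| **Configuration** | Platform connection settings (URL/Token), Judge model settings (API Key/Model) |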
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Database Design (Server Side)

```sql
-- Platform connection settings
CREATE TABLE platform_config (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    type TEXT NOT NULL,          -- 'dagent' | 'custom'
    base_url TEXT NOT NULL,
    org_id TEXT,
    token TEXT,
    created_at DATETIME
);

-- Judge model settings
CREATE TABLE judge_config (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    base_url TEXT NOT NULL,
    api_key TEXT NOT NULL,
    model TEXT NOT NULL,
    created_at DATETIME
);

-- Test sets
CREATE TABLE eval_dataset (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    description TEXT,
    sample_count INTEGER,
    created_at DATETIME
);

-- Test samples
CREATE TABLE eval_sample (
    id TEXT PRIMARY KEY,
    dataset_id TEXT NOT NULL,
    question TEXT NOT NULL,
    reference_answer TEXT NOT NULL,
    relevant_chunk_ids TEXT NOT NULL,  -- JSON array
    knowledge_hub_id TEXT NOT NULL,
    source_file_id TEXT,
    metadata TEXT                      -- JSON
);

-- Evaluation tasks
CREATE TABLE eval_task (
    id TEXT PRIMARY KEY,
    name TEXT,
    dataset_id TEXT NOT NULL,
    platform_config_id TEXT NOT NULL,
    judge_config_id TEXT NOT NULL,
    agent_id TEXT NOT NULL,
    knowledge_hub_id TEXT NOT NULL,
    top_k INTEGER DEFAULT 10,
    status TEXT NOT NULL,              -- pending | running | done | failed
    progress INTEGER DEFAULT 0,
    created_at DATETIME,
    finished_at DATETIME
);

-- Per-sample evaluation results
CREATE TABLE eval_result (
    id TEXT PRIMARY KEY,
    task_id TEXT NOT NULL,
    sample_id TEXT NOT NULL,
    retrieved_chunks TEXT,             -- JSON
    agent_answer TEXT,
    hit_rate REAL,
    mrr REAL,
    ndcg REAL,
    context_precision REAL,
    context_recall REAL,
    faithfulness REAL,
    answer_relevance REAL,
    answer_correctness REAL,
    judge_detail TEXT                  -- JSON, the LLM judge's reasoning trace
);

-- Aggregated evaluation reports
CREATE TABLE eval_report (
    id TEXT PRIMARY KEY,
    task_id TEXT NOT NULL UNIQUE,
    sample_count INTEGER,
    avg_hit_rate REAL,
    avg_mrr REAL,
    avg_ndcg REAL,
    avg_context_precision REAL,
    avg_context_recall REAL,
    avg_faithfulness REAL,
    avg_answer_relevance REAL,
    avg_answer_correctness REAL,
    rag_score REAL,
    hallucination_rate REAL,
    created_at DATETIME
);
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Development Priorities

| Phase | Scope | Notes |
|------|------|------|
| **Phase 1** | SDK core: Adapter interface + retrieval evaluator | No LLM dependency, fastest to validate; Hit Rate/MRR/NDCG |
| **Phase 2** | dagent Adapter implementation | Connects to the existing platform's HTTP API |
| **Phase 3** | LLM Judge module | Faithfulness / Relevance / Correctness |
| **Phase 4** | Automatic test set generator | Lowers annotation cost |
| **Phase 5** | FastAPI Server | Wraps the SDK as a web service with asynchronous tasks |
| **Phase 6** | React frontend | Report visualization, test set management |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. Technology Choices

| Module | Technology | Rationale |
|------|------|------|
| SDK | Python 3.10+, asyncio, Pydantic | Matches dagent; async enables concurrent evaluation |
| Server | FastAPI + SQLite (development) / PostgreSQL (production) | Lightweight and easy to deploy |
| Task queue | asyncio.Queue (lightweight) / Celery (production) | Evaluation tasks run long and must execute asynchronously |
| Frontend | React + TypeScript + Ant Design | Same stack as the dagent frontend |
| LLM Judge | OpenAI SDK (works with DeepSeek/Qwen) | One interface, easy to switch models |
| Deployment | Docker Compose | Starts server + frontend with one command |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 11. Integration with the dagent Platform

The framework calls dagent over its **HTTP API** and does not depend on dagent's internal code.

dagent needs to provide the following (or the framework calls the existing endpoints):

1. **Retrieval endpoint**: `POST /dagent/knowledge/retrieve`
   - Input: query, knowledge_hub_id, top_k, org_id
   - Output: list of chunks (chunk_id, content, score, headers, file_id)

2. **Chat endpoint**: `POST /dagent/agent/chat` (the existing SSE endpoint)
   - Input: question, agent_id, org_id
   - Output: reply text + cited chunk information

If the existing dagent endpoints do not fully cover these needs, dagent can add a dedicated **evaluation endpoint** that returns more detail about the retrieval process (for example each chunk's cosine distance and rerank score).
|
||||||
566
docs/基于Dagent平台的多模态问答集生成方案.md
Normal file
@ -0,0 +1,566 @@
|
|||||||
|
# Multimodal Q&A Set Generation on the Dagent Platform

**Goal:** use the knowledge-base processing the dagent backend already provides to generate a high-quality Q&A set that includes image information

---

## 1. Existing Capabilities of the Dagent Platform

### 1.1 Core Capabilities

| Capability | Where it lives | Notes |
|------|---------|------|
| **HTML → Markdown conversion** | `pdf_service.py` calls the marker service | Supports PDF/DOCX/RST → MD |
| **Image OCR + semantic description** | `pic_to_text.py` | Uses GPT-4V to turn images into text and stores the result in the database |
| **Markdown paragraph splitting** | `split_markdown_filter.py` | Splits paragraphs by heading level |
| **Image path handling** | `md_service.py` | Relative path → absolute BOS path |
| **Vector index storage** | `store_*_semantic_index.py` | Vectorizes paragraphs / questions / tables |
| **Knowledge-base retrieval** | `knowledge_md_retrieve_service.py` | Semantic search |

### 1.2 Database Structure (OceanBase, MySQL-compatible)

**Connection details:**
|
||||||
|
```
Host: 120.48.66.228
Port: 23306
User: dagent
Password: <redacted; keep it in local, untracked config>
Database: dagent_platform
```

**Core tables:**

**knowledge_file**: raw file metadata
```
id, org_id, file_md5, file_name, file_type, file_bytes, file_url, file_clean_status
```

**knowledge_md_header_split**: paragraph split results (the most important table)
```
id, org_id, file_id, file_name, headers
paragraph_context               -- paragraph text
paragraph_img_num               -- number of images in the paragraph
paragraph_pic_semantics_context -- image OCR + semantic description (already processed by GPT-4V)
paragraph_question              -- paragraph questions already generated by Dagent
paragraph_summary               -- paragraph summary
paragraph_keywords              -- keywords
```

**knowledge_md_paragraph_active_context**: active paragraph context (with vectors)
```
id, file_id, headers, active_context, active_context_vector
```

### 1.3 Key Findings

**Dagent has already:**
- Converted 209 HTML files to Markdown
- Uploaded 1142 images to BOS and generated semantic descriptions for them with GPT-4V
- Split paragraphs by heading level
- Populated `paragraph_question`, `paragraph_summary`, and `paragraph_keywords` for every paragraph

**Conclusion: there is no need to reprocess the HTML; reading the database directly is enough.**

---

## 2. Design

### 2.1 Overall Flow

```
Dagent database (knowledge_md_header_split)
        ↓
Extract paragraph data
  - paragraph_context (text)
  - paragraph_pic_semantics_context (image semantics, already available)
  - paragraph_question (seed questions, already available)
        ↓
┌──────────────────────────────────────────────────────────┐
│  QA generation (three types)                             │
│  1. Text-only questions (from paragraph_context)         │
│  2. Image-text questions (text + image semantics)        │
│  3. Expanded seed questions (built on existing ones)     │
└──────────────────────────────────────────────────────────┘
        ↓
Store in the RAG Eval database (qa_gen_question)
        ↓
Review → export to MD → single-hop recall test
```

### 2.2 Comparison: Processing from Scratch vs. Reusing Dagent

| Dimension | Process HTML from scratch | Reuse the Dagent database |
|------|--------------|-------------------|
| Development effort | 2-3 weeks | 3-5 days |
| Image OCR cost | 1142 images × $0.008 ≈ $9 | $0 (already done) |
| QA generation cost | $4 | $4 |
| Data reliability | Needs validation | Validated in production |
| **Total cost** | **$13 + 2-3 weeks** | **$4 + 3-5 days** |
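
The cost rows in the table can be reproduced with a few lines of arithmetic. This is only a worked check of the figures quoted in the table (image count, per-image OCR price, QA generation cost); the variable names are illustrative.

```python
# Figures taken from the comparison table.
IMAGE_COUNT = 1142
OCR_COST_PER_IMAGE = 0.008   # USD per image
QA_GENERATION_COST = 4.0     # USD, identical in both approaches

ocr_cost = IMAGE_COUNT * OCR_COST_PER_IMAGE          # about 9.14 USD
from_scratch_total = ocr_cost + QA_GENERATION_COST   # about 13 USD
reuse_dagent_total = 0.0 + QA_GENERATION_COST        # OCR is already done

print(f"OCR cost:       ${ocr_cost:.2f}")
print(f"From scratch:   ${from_scratch_total:.2f}")
print(f"Reusing Dagent: ${reuse_dagent_total:.2f}")
```

The table rounds the OCR cost down to $9, which is why the from-scratch total appears as $13.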

---

## 3. Implementation

### 3.1 Backend: Add Dagent as a Data Source

**New file:** `server/api/qa_gen_dagent.py`

```python
"""
Import knowledge-base data from the Dagent database and generate a multimodal QA set.
"""
import asyncio
import json
import os
from typing import Optional

import aiomysql
from fastapi import APIRouter, Form

from ..models.db import get_db, _now, _id

router = APIRouter(prefix="/api/qa-gen", tags=["Question Generation - Dagent"])

# The password is read from the environment instead of being hard-coded.
# DAGENT_DB_PASSWORD is an illustrative variable name, not an existing convention.
DAGENT_DB = {
    "host": "120.48.66.228",
    "port": 23306,
    "user": "dagent",
    "password": os.environ.get("DAGENT_DB_PASSWORD", ""),
    "db": "dagent_platform",
    "charset": "utf8mb4",
}


async def get_dagent_conn():
    """Open a new connection to the Dagent database."""
    return await aiomysql.connect(**DAGENT_DB)


# Hold references to background tasks so they are not garbage-collected mid-run.
_background_tasks: set = set()


@router.post("/task/from-dagent")
async def create_task_from_dagent(
    org_id: str = Form(...),
    name: str = Form(""),
    judge_config_id: str = Form(...),
    file_ids: str = Form(""),  # comma-separated file_id values; empty means all files
    questions_per_section: int = Form(5),
    quality_threshold: float = Form(0.6),
    include_multimodal: bool = Form(True),
):
    """Create a QA generation task from the Dagent database."""
    task_id = _id()
    file_id_list = [f.strip() for f in file_ids.split(",") if f.strip()]

    async with get_db() as db:
        await db.execute(
            """INSERT INTO qa_gen_task
               (id,name,judge_config_id,questions_per_section,quality_threshold,status,created_at)
               VALUES (?,?,?,?,?,?,?)""",
            (task_id, name or f"Dagent import ({org_id[:8]}...)",
             judge_config_id, questions_per_section, quality_threshold, "pending", _now()),
        )
        await db.commit()

    task = asyncio.create_task(_run_dagent_task(
        task_id, org_id, file_id_list, judge_config_id,
        questions_per_section, quality_threshold, include_multimodal,
    ))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return {"status": 0, "data": {"id": task_id}}


@router.get("/dagent/files")
async def list_dagent_files(org_id: str):
    """List an organization's files in Dagent.

    The query returns file_clean_status but does not filter on it.
    """
    conn = await get_dagent_conn()
    try:
        cursor = await conn.cursor(aiomysql.DictCursor)
        await cursor.execute(
            """SELECT id, file_name, file_type, file_clean_status,
                      file_bytes, create_time
               FROM knowledge_file
               WHERE org_id = %s AND delete_time IS NULL
               ORDER BY create_time DESC""",
            (org_id,),
        )
        rows = await cursor.fetchall()
        await cursor.close()
    finally:
        conn.close()  # release the connection even if the query fails
    return {"status": 0, "data": [dict(r) for r in rows]}


@router.get("/dagent/stats")
async def get_dagent_stats(org_id: str):
    """Get statistics for the Dagent knowledge base."""
    conn = await get_dagent_conn()
    try:
        cursor = await conn.cursor(aiomysql.DictCursor)
        await cursor.execute(
            """SELECT
                 COUNT(DISTINCT f.id) as file_count,
                 COUNT(h.id) as paragraph_count,
                 SUM(h.paragraph_img_num) as total_images,
                 SUM(CASE WHEN h.paragraph_pic_semantics_context IS NOT NULL
                          AND h.paragraph_img_num > 0 THEN 1 ELSE 0 END) as paragraphs_with_pic_text,
                 SUM(CASE WHEN h.paragraph_question IS NOT NULL THEN 1 ELSE 0 END) as paragraphs_with_question
               FROM knowledge_file f
               LEFT JOIN knowledge_md_header_split h
                 ON f.id = h.file_id AND h.delete_time IS NULL
               WHERE f.org_id = %s AND f.delete_time IS NULL
                 AND f.file_clean_status = 'CLEAN_FINISH'""",
            (org_id,),
        )
        row = await cursor.fetchone()
        await cursor.close()
    finally:
        conn.close()
    return {"status": 0, "data": dict(row) if row else {}}


# ── Internal: background task ──────────────────────────────────────────────────

async def _fetch_paragraphs(org_id: str, file_id_list: list[str]) -> list[dict]:
    """Fetch paragraph data from the Dagent database."""
    conn = await get_dagent_conn()
    try:
        cursor = await conn.cursor(aiomysql.DictCursor)

        sql = """
            SELECT h.id, h.file_id, h.file_name, h.headers,
                   h.paragraph_context, h.paragraph_img_num,
                   h.paragraph_pic_semantics_context,
                   h.paragraph_question, h.paragraph_summary, h.paragraph_keywords
            FROM knowledge_md_header_split h
            JOIN knowledge_file f ON f.id = h.file_id
            WHERE h.org_id = %s
              AND h.delete_time IS NULL
              AND f.delete_time IS NULL
              AND f.file_clean_status = 'CLEAN_FINISH'
        """
        params = [org_id]

        if file_id_list:
            # One %s placeholder per file id keeps the query parameterized.
            placeholders = ",".join(["%s"] * len(file_id_list))
            sql += f" AND h.file_id IN ({placeholders})"
            params.extend(file_id_list)

        sql += " ORDER BY h.file_name, h.headers"

        await cursor.execute(sql, params)
        rows = await cursor.fetchall()
        await cursor.close()
    finally:
        conn.close()
    return [dict(r) for r in rows]


async def _generate_questions_for_paragraph(
    para: dict, cfg: dict, n: int, include_multimodal: bool
) -> list[dict]:
    """Generate QA pairs for a single paragraph."""
    import re

    import aiohttp

    base_url = cfg.get("base_url", "").rstrip("/")
    api_key = cfg.get("api_key", "")
    model = cfg.get("model", "gpt-4o-mini")

    text = (para.get("paragraph_context") or "").strip()
    pic_semantics = (para.get("paragraph_pic_semantics_context") or "").strip()
    seed_question = (para.get("paragraph_question") or "").strip()
    headers = (para.get("headers") or "").strip()
    # paragraph_img_num can be NULL in the database, so coerce None to 0.
    has_image = bool(pic_semantics and (para.get("paragraph_img_num") or 0) > 0)

    if not text:
        return []

    # Build the prompt
    pic_section = ""
    if has_image and include_multimodal:
        pic_section = f"""
**Image semantic description (images already recognized by AI):**
{pic_semantics[:800]}
"""

    seed_section = ""
    if seed_question:
        seed_section = f"\n**Existing seed questions (avoid duplicates; extend from other angles):** {seed_question}"

    prompt = f"""You are an expert at generating QA pairs for technical documentation. Generate {n} test questions from the content below.

**Section path:** {headers}

**Text content:**
{text[:2500]}
{pic_section}{seed_section}

**Requirements:**
1. Each question must be directly answerable from this section
2. Cover the key knowledge points and avoid trivial yes/no questions
3. If an image semantic description is present, generate at least 1 image-text question (the question should mention "as shown in the figure", "in the figure", or similar)
4. Answers must be accurate and of moderate length (1-3 sentences)
5. source_chunk is the original text fragment the answer comes from (50-150 characters)
6. has_image marks whether the question depends on image information
7. quality_score is a quality rating (0-1)

Output only a JSON array:
[
  {{
    "question": "question text",
    "answer": "reference answer",
    "source_chunk": "original text fragment the answer comes from",
    "has_image": false,
    "quality_score": 0.9
  }}
]"""

    headers_http = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.3,
    }

    try:
        async with aiohttp.ClientSession(headers=headers_http) as session:
            async with session.post(
                f"{base_url}/chat/completions",
                json=payload,
                timeout=aiohttp.ClientTimeout(total=60),
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()

        text_resp = data["choices"][0]["message"]["content"].strip()
        # Take the outermost [...] span in case the model wraps the JSON in prose.
        m = re.search(r"\[.*\]", text_resp, re.DOTALL)
        if not m:
            return []
        questions = json.loads(m.group())
        result = []
        for q in questions:
            if isinstance(q, dict) and q.get("question") and q.get("answer"):
                result.append({
                    "question": str(q["question"]).strip(),
                    "answer": str(q["answer"]).strip(),
                    "source_chunk": str(q.get("source_chunk", "")).strip(),
                    "has_image": bool(q.get("has_image", False)),
                    "quality_score": float(q.get("quality_score", 0.8)),
                    "source_image_desc": pic_semantics[:300] if q.get("has_image") else "",
                })
        return result
    except Exception:
        # Best effort: a failed call or malformed response yields no questions
        # for this paragraph instead of failing the whole task.
        return []


async def _run_dagent_task(
    task_id: str,
    org_id: str,
    file_id_list: list[str],
    judge_config_id: str,
    questions_per_section: int,
    quality_threshold: float,
    include_multimodal: bool,
):
    try:
        # 1. Fetch paragraphs
        paragraphs = await _fetch_paragraphs(org_id, file_id_list)
        total = len(paragraphs)

        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET status='running', total=? WHERE id=?",
                (total, task_id),
            )
            await db.commit()

        # 2. Load the LLM configuration
        async with get_db() as db:
            cfg_rows = await db.execute_fetchall(
                "SELECT * FROM judge_config WHERE id=?", (judge_config_id,)
            )
        if not cfg_rows:
            raise ValueError("judge_config not found")
        cfg = dict(cfg_rows[0])

        # 3. Generate concurrently (at most 5 paragraphs at a time)
        sem = asyncio.Semaphore(5)
        done = 0
        FLUSH_SIZE = 10
        write_buf = []

        async def process_one(para: dict):
            nonlocal done
            async with sem:
                questions = await _generate_questions_for_paragraph(
                    para, cfg, questions_per_section, include_multimodal
                )
            done += 1
            write_buf.extend([(para, q) for q in questions])

            # Flush when the buffer is full or the last paragraph has finished.
            if len(write_buf) >= FLUSH_SIZE or done == total:
                batch = write_buf.copy()
                write_buf.clear()
                async with get_db() as db2:
                    for p, q in batch:
                        qid = _id()
                        status = "approved" if q["quality_score"] >= quality_threshold else "pending"
                        await db2.execute(
                            """INSERT INTO qa_gen_question
                               (id,task_id,section_path,question,reference_answer,source_chunk,
                                quality_score,status,created_at)
                               VALUES (?,?,?,?,?,?,?,?,?)""",
                            (qid, task_id, p["headers"],
                             q["question"], q["answer"], q["source_chunk"],
                             q["quality_score"], status, _now()),
                        )
                    # Sync the approved count
                    count_rows = await db2.execute_fetchall(
                        "SELECT COUNT(*) as cnt FROM qa_gen_question WHERE task_id=? AND status='approved'",
                        (task_id,),
                    )
                    approved = dict(count_rows[0])["cnt"] if count_rows else 0
                    await db2.execute(
                        "UPDATE qa_gen_task SET progress=?, approved=? WHERE id=?",
                        (done, approved, task_id),
                    )
                    await db2.commit()

        await asyncio.gather(*[process_one(p) for p in paragraphs])

        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET status='done', finished_at=? WHERE id=?",
                (_now(), task_id),
            )
            await db.commit()

    except Exception as exc:
        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET status='failed', error_message=? WHERE id=?",
                (str(exc), task_id),
            )
            await db.commit()
```
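
The `re.search(r"\[.*\]", ...)` call in `_generate_questions_for_paragraph` is greedy, so any later `]` in the model's reply (a footnote marker, a trailing note) ends up inside the match and makes `json.loads` fail. One possible alternative, sketched below, scans for the first `[` that actually starts a valid JSON array using the standard library's `json.JSONDecoder.raw_decode`. The function name `extract_json_array` is hypothetical and not part of the module above.

```python
import json


def extract_json_array(text: str) -> list:
    """Return the first JSON array embedded in text, or [] if none parses."""
    decoder = json.JSONDecoder()
    start = text.find("[")
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start`
            # and ignores whatever follows it.
            value, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            value = None
        if isinstance(value, list):
            return value
        start = text.find("[", start + 1)
    return []
```

Because parsing stops at the end of the first complete array, trailing prose or code-fence markers after the JSON no longer matter.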

### 3.2 Register the Router

Add the following to `server/main.py`:

```python
from .api import config, dataset, task, report, single_jump, qa_gen, qa_gen_dagent

app.include_router(qa_gen_dagent.router)
```

### 3.3 Frontend: Add an "Import from Dagent" Entry Point

In the new-task dialog in `QaGen/index.tsx`, add a data source switch:

```tsx
// Data source selection
<Form.Item label="Data source">
  <Radio.Group value={dataSource} onChange={e => setDataSource(e.target.value)}>
    <Radio value="file">Upload MD file</Radio>
    <Radio value="dagent">Import from Dagent knowledge base</Radio>
  </Radio.Group>
</Form.Item>

{dataSource === 'dagent' ? (
  <>
    <Form.Item name="org_id" label="Dagent organization ID" rules={[{ required: true }]}>
      <Input placeholder="cd6e121594984516..." />
    </Form.Item>
    <Form.Item name="file_ids" label="File IDs (optional)"
      tooltip="Leave empty to import every processed file in the organization">
      <Input.TextArea rows={2} placeholder="Separate multiple IDs with commas; leave empty to import everything" />
    </Form.Item>
    <Form.Item name="include_multimodal" label="Generate image-text questions" valuePropName="checked"
      tooltip="Use the image semantic descriptions Dagent has already generated to create image-text questions">
      <Switch defaultChecked />
    </Form.Item>
    {/* Statistics */}
    {dagentStats && (
      <div style={{ background: '#f6ffed', border: '1px solid #b7eb8f', borderRadius: 6, padding: '8px 12px', marginBottom: 16 }}>
        <Space split={<Divider type="vertical" />}>
          <span>Files: <b>{dagentStats.file_count}</b></span>
          <span>Paragraphs: <b>{dagentStats.paragraph_count}</b></span>
          <span>Paragraphs with images: <b>{dagentStats.paragraphs_with_pic_text}</b></span>
          <span>Total images: <b>{dagentStats.total_images}</b></span>
        </Space>
      </div>
    )}
  </>
) : (
  // Existing file upload UI
  <Form.Item label="Knowledge base MD file" required>
    <Upload ... />
  </Form.Item>
)}
```

---

## 4. Validation Steps

### Step 1: Query the database to confirm the data is complete

```sql
-- List the files in the EVB knowledge base
SELECT id, file_name, file_type, file_clean_status
FROM knowledge_file
WHERE org_id = 'cd6e121594984516bde17ae9aeb0eb45a01e6d28143034608c4985aea369deec'
  AND delete_time IS NULL
ORDER BY file_name;

-- Paragraph statistics (including image processing status)
SELECT
    f.file_name,
    COUNT(h.id) as paragraphs,
    SUM(h.paragraph_img_num) as images,
    SUM(CASE WHEN h.paragraph_pic_semantics_context IS NOT NULL THEN 1 ELSE 0 END) as pic_text_done,
    SUM(CASE WHEN h.paragraph_question IS NOT NULL THEN 1 ELSE 0 END) as has_question
FROM knowledge_file f
JOIN knowledge_md_header_split h ON f.id = h.file_id AND h.delete_time IS NULL
WHERE f.org_id = 'cd6e121594984516bde17ae9aeb0eb45a01e6d28143034608c4985aea369deec'
  AND f.delete_time IS NULL
GROUP BY f.file_name
ORDER BY f.file_name;
```

### Step 2: Spot-check the quality of the image semantics

```sql
-- Randomly sample 10 paragraphs that contain images and check the quality of their image descriptions
SELECT headers, LEFT(paragraph_context, 200) as text_preview,
       LEFT(paragraph_pic_semantics_context, 300) as pic_text_preview
FROM knowledge_md_header_split
WHERE org_id = 'cd6e121594984516bde17ae9aeb0eb45a01e6d28143034608c4985aea369deec'
  AND paragraph_img_num > 0
  AND paragraph_pic_semantics_context IS NOT NULL
  AND delete_time IS NULL
ORDER BY RAND()
LIMIT 10;
```

### Step 3: Small-batch pilot

Start with a single file (for example `common_questions`) as a pilot, generate about 50 QA pairs, and review their quality manually before running on the full knowledge base.

---

## 5. Expected Output

| Module | Paragraphs | Paragraphs with images | Expected QA pairs |
|------|--------|---------|-----------|
| linux_development | ~500 | ~200 | ~2500 |
| multimedia_development | ~150 | ~80 | ~750 |
| samples | ~100 | ~50 | ~500 |
| toolchain_development | ~80 | ~30 | ~400 |
| quick_start | ~30 | ~15 | ~150 |
| preface + common_questions | ~20 | ~5 | ~100 |
| **Total** | **~880** | **~380** | **~4400** |

Multimodal (image-text) questions are expected to account for **20-30%** of the total (about 880-1320 questions).
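
The totals follow from the paragraph estimates multiplied by the default of 5 questions per paragraph (`questions_per_section`). A small worked check, using only numbers from the table and illustrative variable names:

```python
# Paragraph estimates per module, taken from the table above.
paragraphs = {
    "linux_development": 500,
    "multimedia_development": 150,
    "samples": 100,
    "toolchain_development": 80,
    "quick_start": 30,
    "preface + common_questions": 20,
}
QUESTIONS_PER_PARAGRAPH = 5  # default value of questions_per_section

total_paragraphs = sum(paragraphs.values())
total_questions = total_paragraphs * QUESTIONS_PER_PARAGRAPH
multimodal_low = round(total_questions * 0.20)
multimodal_high = round(total_questions * 0.30)

print(total_paragraphs, total_questions, multimodal_low, multimodal_high)
```

These are upper-bound estimates: paragraphs with empty text produce no questions, and questions below the quality threshold stay in the `pending` state until reviewed.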

---

## 6. Installing Dependencies

```bash
pip install aiomysql
```

(Other dependencies such as aiohttp and fastapi are already installed.)

621
docs/多模态问答集生成方案.md
Normal file
@ -0,0 +1,621 @

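# Multimodal QA Set Generation from an HTML Knowledge Base

**Knowledge base characteristics:**
- Format: HTML documentation generated by Sphinx (209 HTML files)
- Images: 1142 images (PNG/JPG) stored in `_images/` directories
- Structure: organized as a directory hierarchy, with many configuration-screen screenshots, architecture diagrams, flowcharts, and similar images

**Goal:** Generate a high-quality multimodal QA set that makes full use of both text and image information

---
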
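## 1. Core Challenges

### 1.1 Current Problems

| Problem | Impact |
|------|------|
| **Image information is lost** | The existing MD QA set is text-only; key information in configuration screenshots, architecture diagrams, and similar images is missing |
| **Weak image-text linkage** | Images are separated from the text, so "as shown in the figure" questions cannot be generated |
| **Limited question quality** | Text-only questions cannot cover scenarios such as "where to click in the UI" or "how the modules in the architecture diagram relate" |

### 1.2 Image Type Analysis

Based on a sample analysis, the images in the knowledge base fall mainly into the following types:

| Image type | Estimated share | Examples | QA value |
|---------|---------|------|---------|
| **Configuration screenshots** | ~40% | menuconfig screens, parameter configuration pages | ⭐⭐⭐⭐⭐ High value; supports operational questions |
| **Architecture diagrams / flowcharts** | ~30% | System architecture, data flow, module relationships | ⭐⭐⭐⭐⭐ High value; supports comprehension questions |
| **Code screenshots** | ~15% | Code snippets, configuration file examples | ⭐⭐⭐ Medium value; already covered by the text |
| **Hardware interface diagrams** | ~10% | Pin definitions, circuit connections | ⭐⭐⭐⭐ High value; hard to describe in text alone |
| **Other** | ~5% | Logos, decorative images | ⭐ Low value |

---
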
## 2. Design

### 2.1 Overall Flow

```
HTML knowledge base
        ↓
┌──────────────────────────────────────────────┐
│ Stage 1: HTML → structured Markdown          │
│  - Extract text content                      │
│  - Keep image placeholders [IMAGE: xxx.png]  │
│  - Keep the section hierarchy                │
└──────────────────────────────────────────────┘
        ↓
┌──────────────────────────────────────────────┐
│ Stage 2: image classification + description  │
│  - Multimodal LLM identifies the image type  │
│  - Generate an image description (caption)   │
│  - Extract key information from the image    │
└──────────────────────────────────────────────┘
        ↓
┌──────────────────────────────────────────────┐
│ Stage 3: multimodal QA generation            │
│  - Text-only questions (from text)           │
│  - Image-text questions (image + context)    │
│  - Image-understanding (from description)    │
└──────────────────────────────────────────────┘
        ↓
┌──────────────────────────────────────────────┐
│ Stage 4: QA set review and refinement        │
│  - Deduplication (text + image similarity)   │
│  - Quality scoring                           │
│  - Manual review                             │
└──────────────────────────────────────────────┘
        ↓
Multimodal QA set (MD + image references)
```

---

## 3. Technical Implementation

### 3.1 Stage 1: HTML → Structured Markdown

**Goal:** Convert HTML to Markdown that keeps image placeholders

**Implementation:**

```python
from bs4 import BeautifulSoup
from pathlib import Path


def html_to_markdown_with_images(html_path: Path, base_path: Path) -> str:
    """
    Convert HTML to Markdown, keeping image placeholders.

    Output format:
        ## Section title

        Text content...

        ![image-20220518111319607](_images/image-20220518111319607.png)
        *Figure: Uboot menuconfig configuration screen*

        More text content...
    """
    html = html_path.read_text(encoding='utf-8')
    soup = BeautifulSoup(html, 'html.parser')

    # Extract the main content area
    main = soup.find('div', role='main') or soup.find('section')
    if main is None:
        return ""  # no recognizable content container

    md_lines = []

    for elem in main.descendants:
        if elem.name in ('h1', 'h2', 'h3', 'h4'):
            # Section heading
            level = int(elem.name[1])
            title = elem.get_text(strip=True)
            md_lines.append(f"{'#' * level} {title}\n")

        elif elem.name == 'p':
            # Paragraph (may contain images)
            if elem.find('img'):
                # Handle images
                for img in elem.find_all('img'):
                    src = img.get('src', '')
                    alt = img.get('alt', '')
                    # Resolve the relative src, then express it relative to base_path
                    # (both sides are resolved so relative_to does not raise)
                    img_path = (html_path.parent / src).resolve()
                    rel_path = img_path.relative_to(base_path.resolve())
                    md_lines.append(f"![{alt}]({rel_path})")
                    # Add a caption inferred from the surrounding context
                    caption = infer_image_caption(elem, img)
                    if caption:
                        md_lines.append(f"*Figure: {caption}*\n")
            else:
                # Plain text paragraph
                text = elem.get_text(strip=True)
                if text:
                    md_lines.append(f"{text}\n")

        elif elem.name == 'pre':
            # Code block
            code = elem.get_text(strip=True)
            md_lines.append(f"```\n{code}\n```\n")

        elif elem.name == 'li':
            # List item
            text = elem.get_text(strip=True)
            if text:
                md_lines.append(f"- {text}")

    return '\n'.join(md_lines)


def infer_image_caption(parent_elem, img_elem) -> str:
    """Infer an image caption from the context around the image."""
    # Strategy 1: use the end of the previous paragraph.
    # The keywords are Chinese cues from the source docs:
    # "如图" / "如下图" (as shown in the figure), "界面" (screen/UI).
    prev = parent_elem.find_previous_sibling('p')
    if prev:
        text = prev.get_text(strip=True)
        if '如图' in text or '如下图' in text or '界面' in text:
            return text[-50:]  # last 50 characters

    # Strategy 2: use the alt attribute
    alt = img_elem.get('alt', '')
    if alt and not alt.startswith('image-'):
        return alt

    # Strategy 3: use the start of the following paragraph
    next_elem = parent_elem.find_next_sibling('p')
    if next_elem:
        text = next_elem.get_text(strip=True)
        if text:
            return text[:50]

    return ""
```

**Example output:**

```markdown
## 4.3.2. Configuring Uboot and Kernel Options

In embedded system development, configuring the feature options of Uboot and the Kernel...

### Configuring with the xbuild command

After the command succeeds, the system launches a graphical configuration interface.

![image-20220518111319607](_images/image-20220518111319607.png)
*Figure: Uboot menuconfig configuration screen, where features can be enabled or disabled*

After finishing the configuration, choose Exit to quit...

![image-20220518111357379](_images/image-20220518111357379.png)
*Figure: save-configuration prompt*
```

---

### 3.2 Stage 2: Image Classification and Description Generation

**Goal:** Use a multimodal LLM to generate a structured description for every image

**Implementation:**

```python
import base64
import json
from pathlib import Path


async def analyze_image_with_llm(
    image_path: Path,
    context_before: str,
    context_after: str,
    llm_config: dict,
) -> dict:
    """
    Analyze an image with a multimodal LLM.

    Returns:
    {
        "type": "config_ui",  # image type
        "description": "Screenshot of the Uboot menuconfig screen, showing...",
        "key_elements": ["File System support", "Network support", ...],
        "qa_value": 5,  # QA value score, 1-5
        "suggested_questions": [
            "How do you open the Uboot configuration screen?",
            "How do you save changes in the configuration screen?"
        ]
    }
    """
    # Read the image and encode it as base64
    img_data = image_path.read_bytes()
    img_b64 = base64.b64encode(img_data).decode()

    prompt = f"""You are an expert at analyzing images in technical documentation. Analyze the image below and provide structured information.

**Context before the image:**
{context_before[-500:]}

**Context after the image:**
{context_after[:500]}

Analyze the image and return JSON in this format:
{{
  "type": "image type (config_ui/architecture/flowchart/code/hardware/other)",
  "description": "detailed description of the image content (100-200 characters)",
  "key_elements": ["list of key elements in the image"],
  "qa_value": "QA value score 1-5 (5 = high value)",
  "suggested_questions": ["example questions that could be generated from this image"]
}}

**Image type definitions:**
- config_ui: configuration screenshots (menuconfig, parameter settings pages, etc.)
- architecture: architecture diagrams, system block diagrams
- flowchart: flowcharts, sequence diagrams
- code: code screenshots
- hardware: hardware interface diagrams, pin definitions
- other: other types

**Scoring criteria:**
- 5: contains key operation steps or architecture information that can only be understood from the image
- 4: supplementary image that aids understanding but is not required
- 3: code or configuration example already covered by the text, though the image is more intuitive
- 2: decorative image of low value
- 1: no substantive content
"""

    # Call a multimodal LLM (GPT-4V / Claude 3.5 Sonnet).
    # call_multimodal_llm is a helper that is not defined in this document.
    response = await call_multimodal_llm(
        prompt=prompt,
        image_base64=img_b64,
        config=llm_config,
    )

    return json.loads(response)
```

**Batch processing strategy:**

```python
import asyncio
import hashlib
from pathlib import Path

# In-memory cache keyed by image content hash (assumed to be module-level state).
image_cache: dict[str, dict] = {}


async def batch_analyze_images(
    md_content: str,
    image_paths: list[Path],
    llm_config: dict,
    concurrency: int = 5,
) -> dict[str, dict]:
    """
    Analyze images in batch; returns {image_path: analysis_result}.

    Optimizations:
    1. Concurrent calls (concurrency=5)
    2. Cache analyzed images (keyed by file hash)
    3. Skip detailed analysis for low-value images (record the type only)
    """
    sem = asyncio.Semaphore(concurrency)

    async def analyze_one(img_path: Path):
        # Check the cache
        cache_key = hashlib.md5(img_path.read_bytes()).hexdigest()
        if cache_key in image_cache:
            return image_cache[cache_key]

        # Extract context (extract_image_context is a helper not defined in this document)
        context_before, context_after = extract_image_context(md_content, img_path)

        async with sem:
            result = await analyze_image_with_llm(
                img_path, context_before, context_after, llm_config
            )
            image_cache[cache_key] = result
            return result

    tasks = [analyze_one(p) for p in image_paths]
    analyses = await asyncio.gather(*tasks)

    # Keys are converted to strings to match the declared return type.
    return {str(p): a for p, a in zip(image_paths, analyses)}
```
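
`batch_analyze_images` relies on an `extract_image_context` helper that is not shown. A minimal sketch of what it might look like follows, assuming images appear in the Markdown as standard `![alt](path)` references and can be located by file name. The `window` parameter and the matching strategy are assumptions for illustration, not the project's actual implementation.

```python
from pathlib import Path


def extract_image_context(md_content: str, img_path: Path, window: int = 500) -> tuple[str, str]:
    """Return the text before and after the first Markdown reference to img_path."""
    marker = img_path.name
    pos = md_content.find(marker)
    if pos == -1:
        return "", ""  # image is not referenced in this document

    # Widen the match to the whole ![alt](path) reference so that the
    # reference itself is excluded from both context strings.
    start = md_content.rfind("![", 0, pos)
    end = md_content.find(")", pos)
    start = pos if start == -1 else start
    end = pos + len(marker) if end == -1 else end + 1

    before = md_content[max(0, start - window):start].strip()
    after = md_content[end:end + window].strip()
    return before, after
```

Matching on the file name alone can collide when two directories contain images with the same name; matching on the full relative path would be stricter if that turns out to matter.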

---

### 3.3 Stage 3: Multimodal QA Generation

**Goal:** Generate three kinds of questions: text-only, image-text, and image-understanding

**3.3.1 Text-only question generation**

Reuse the existing (already implemented) question generation logic, which works from the text content.

**3.3.2 Image-text question generation**

```python
import json


async def generate_image_text_questions(
    section_path: str,
    text_content: str,
    images: list[dict],  # [{path, analysis, context}]
    llm_config: dict,
    n: int = 3,
) -> list[dict]:
    """
    Generate questions that combine text and images.

    Example question types:
    - "In the configuration screen shown in the figure, how do you enable feature XXX?"
    - "According to the architecture diagram, how are modules XXX and YYY related?"
    - "What causes the error message shown in the figure?"
    """
    # Keep only high-value images (qa_value >= 4); the LLM may return the score as a string
    high_value_images = [img for img in images if int(img['analysis']['qa_value']) >= 4]

    if not high_value_images:
        return []

    prompt = f"""You are an expert at generating QA pairs for technical documentation. Based on the text and image information below, generate {n} image-text questions.

**Section path:** {section_path}

**Text content:**
{text_content[:2000]}

**Image information:**
"""

    for i, img in enumerate(high_value_images[:3]):  # at most 3 images
        prompt += f"""
Figure {i+1}: {img['path'].name}
- Type: {img['analysis']['type']}
- Description: {img['analysis']['description']}
- Key elements: {', '.join(img['analysis']['key_elements'][:5])}
"""

    # Plain (non-f) string: the braces in the JSON example are literal.
    prompt += """
**Requirements:**
1. Each question must need both the text and the image to answer (neither alone is enough)
2. The question explicitly mentions "as shown in the figure", "in the figure", "according to the architecture diagram", or similar
3. The answer must draw on specific elements in the image (button positions, module names, process steps, etc.)
4. Each question comes with:
   - A reference answer
   - The file name of the associated image
   - The answer source (text fragment + image description)

Output a JSON array:
[
  {
    "question": "In the configuration screen shown in the figure, how do you...",
    "answer": "The figure shows that...",
    "image_ref": "image-20220518111319607.png",
    "source_text": "source text fragment",
    "source_image_desc": "image description fragment",
    "quality_score": 0.9
  }
]
"""

    # call_llm is a helper that is not defined in this document.
    response = await call_llm(prompt, llm_config)
    return json.loads(response)
```

**3.3.3 图像理解问题生成**
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def generate_image_understanding_questions(
|
||||||
|
image_path: Path,
|
||||||
|
image_analysis: dict,
|
||||||
|
context: str,
|
||||||
|
llm_config: dict,
|
||||||
|
n: int = 2,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""
|
||||||
|
生成纯图像理解问题(主要针对架构图、流程图)
|
||||||
|
|
||||||
|
示例问题类型:
|
||||||
|
- "架构图中有哪些主要模块?"
|
||||||
|
- "数据流向是怎样的?"
|
||||||
|
- "XXX 模块的输入输出是什么?"
|
||||||
|
"""
|
||||||
|
if image_analysis['type'] not in ('architecture', 'flowchart', 'hardware'):
|
||||||
|
return [] # 只对特定类型图像生成
|
||||||
|
|
||||||
|
# 读取图像
|
||||||
|
img_b64 = base64.b64encode(image_path.read_bytes()).decode()
|
||||||
|
|
||||||
|
prompt = f"""基于以下架构图/流程图,生成 {n} 个理解性问题。
|
||||||
|
|
||||||
|
**图像类型:** {image_analysis['type']}
|
||||||
|
**图像描述:** {image_analysis['description']}
|
||||||
|
**上下文:** {context[:500]}
|
||||||
|
|
||||||
|
**要求:**
|
||||||
|
1. 问题聚焦于图像中的结构、关系、流程
|
||||||
|
2. 答案必须通过仔细观察图像才能得出
|
||||||
|
3. 避免过于简单的"有哪些模块"类问题,要问模块间的关系、数据流向等
|
||||||
|
|
||||||
|
输出 JSON 数组:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"question": "架构图中 XXX 模块与 YYY 模块通过什么方式通信?",
|
||||||
|
"answer": "通过 ZZZ 接口进行通信,数据流向为...",
|
||||||
|
"image_ref": "{image_path.name}",
|
||||||
|
"quality_score": 0.85
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
response = await call_multimodal_llm(prompt, img_b64, llm_config)
|
||||||
|
return json.loads(response)
|
||||||
|
```
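
Both generators end with `json.loads(response)`, which raises as soon as the model wraps its reply in a Markdown fence or adds a sentence before the array. The following is a minimal sketch of a more tolerant parser, not part of the original design: the name `parse_llm_json_array` and the `required` field list are assumptions introduced here for illustration.

```python
import json


def parse_llm_json_array(response: str, required=("question", "answer")) -> list[dict]:
    """Best-effort parse of an LLM reply that is expected to contain one JSON array.

    Hypothetical helper: keeps only dict items that carry every required key
    and returns an empty list when nothing parseable is found.
    """
    text = response.strip()
    # Keep only the outermost [...] span so fences and surrounding prose are ignored.
    start, end = text.find("["), text.rfind("]")
    if start != -1 and end > start:
        text = text[start:end + 1]
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return []
    if not isinstance(data, list):
        return []
    return [
        item for item in data
        if isinstance(item, dict) and all(key in item for key in required)
    ]
```

Returning an empty list instead of raising lets a batch run skip one bad reply and continue; whether that is preferable to failing loudly depends on how the task status is reported.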

---

### 3.4 Stage 4: Multimodal QA Set Format

**Output format:** extends the existing MD format with image references.

```markdown
## linux_development/driver_develop_guide/uboot_config

## Q1: How do I open U-Boot's graphical configuration interface?
**A1:** Run `make ARCH=arm menuconfig` in the U-Boot directory to launch the graphical configuration interface.

## Q2: In the configuration interface shown in the figure, how do I save the modified configuration?
**IMAGE:** linux_development/driver_develop_guide/_images/image-20220518111319607.png
**A2:** Select Exit in the configuration interface; the system asks whether to save the changes, and choosing Yes writes the configuration to the .config file. Alternatively, select the "< Save >" button marked by the red box in the figure.

## Q3: According to the architecture diagram, what is the data path between the BPU module and DDR?
**IMAGE:** linux_development/system_architecture/_images/bpu_architecture.png
**A3:** The BPU module exchanges data with DDR over the AXI bus and supports DMA for high-speed transfers.

---
```
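
To show how a consumer might read this extended format, here is a minimal sketch of a parser. It assumes exactly the layout in the example above (`## Qn:` headings, an optional `**IMAGE:**` line, and a single-line `**An:**` answer); `QaItem` and `parse_qa_markdown` are hypothetical names, not part of the existing code base.

```python
import re
from dataclasses import dataclass
from typing import Optional

_Q_RE = re.compile(r"^## Q(\d+):\s*(.+)$")
_A_RE = re.compile(r"^\*\*A(\d+):\*\*\s*(.+)$")
_IMG_RE = re.compile(r"^\*\*IMAGE:\*\*\s*(.+)$")


@dataclass
class QaItem:
    index: int
    question: str
    answer: str = ""
    image_ref: Optional[str] = None  # None for text-only questions


def parse_qa_markdown(text: str) -> list[QaItem]:
    """Parse the QA Markdown format into a list of QaItem objects."""
    items: list[QaItem] = []
    current: Optional[QaItem] = None
    for raw in text.splitlines():
        line = raw.strip()
        if m := _Q_RE.match(line):
            current = QaItem(index=int(m.group(1)), question=m.group(2))
            items.append(current)
        elif current is None:
            continue  # section heading or blank line before the first question
        elif m := _IMG_RE.match(line):
            current.image_ref = m.group(1)
        elif m := _A_RE.match(line):
            current.answer = m.group(2)
    return items
```

The section heading (`## linux_development/...`) is skipped because it does not match the `## Qn:` pattern, so the image line stays optional and existing text-only files parse unchanged.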

**Database extension:**

```sql
-- Extend the qa_gen_question table
ALTER TABLE qa_gen_question ADD COLUMN image_ref TEXT;      -- path of the associated image
ALTER TABLE qa_gen_question ADD COLUMN question_type TEXT;  -- text/image_text/image_only
```
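
Running `ALTER TABLE ... ADD COLUMN` a second time fails because the column already exists, so a migration script usually checks first. The sketch below assumes the evaluation database is SQLite (suggested by the `*.db` and `*.db-wal` entries in `.gitignore`, but not confirmed by this document); `add_missing_columns` is a hypothetical helper name.

```python
import sqlite3

# Columns proposed in this section; both are nullable TEXT.
NEW_COLUMNS = {"image_ref": "TEXT", "question_type": "TEXT"}


def add_missing_columns(conn: sqlite3.Connection, table: str = "qa_gen_question") -> list[str]:
    """Add the new columns only if they are absent; return the names actually added."""
    # PRAGMA table_info yields one row per column; index 1 is the column name.
    existing = {row[1] for row in conn.execute(f"PRAGMA table_info({table})")}
    added = []
    for name, sql_type in NEW_COLUMNS.items():
        if name not in existing:
            conn.execute(f"ALTER TABLE {table} ADD COLUMN {name} {sql_type}")
            added.append(name)
    conn.commit()
    return added
```

Because the check is per column, the function can be called on every server start without raising on an already-migrated database.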

---

## 4. Implementation Plan

### 4.1 Phase 1: Infrastructure (1-2 days)

| Task | Output |
|------|------|
| HTML → Markdown converter | Python script with image placeholder support |
| Image analysis API wrapper | Calls GPT-4V / Claude 3.5 Sonnet |
| Database extension | New `image_ref` and `question_type` columns |

### 4.2 Phase 2: Image Analysis (2-3 days)

| Task | Output |
|------|------|
| Batch image classification | Type labels for all 1142 images |
| High-value image filtering | Images with qa_value >= 4 (about 400-500 images) |
| Image description generation | Detailed descriptions for the high-value images |

**Cost estimate:**
- GPT-4V: $0.01/image × 1142 = **$11.42**
- Claude 3.5 Sonnet: $0.008/image × 1142 = **$9.14**
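
The estimates above are a per-image price multiplied by the image count. A minimal sketch of that arithmetic follows, using the document's own per-image figures; the dictionary keys and the function name `analysis_cost` are hypothetical labels, and real billing depends on image size and token usage.

```python
from decimal import Decimal, ROUND_HALF_UP

# Per-image prices as estimated in this plan (USD); not official pricing.
PRICE_PER_IMAGE = {
    "gpt-4v": Decimal("0.01"),
    "claude-3.5-sonnet": Decimal("0.008"),
}


def analysis_cost(model: str, image_count: int) -> Decimal:
    """Return the estimated image-analysis cost in USD, rounded to cents."""
    total = PRICE_PER_IMAGE[model] * image_count
    return total.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
```

`Decimal` is used instead of `float` so that rounding to cents is exact and reproducible.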

### 4.3 Phase 3: Multimodal QA Generation (3-5 days)

| Task | Output |
|------|------|
| Text-only question generation | Reuses the existing logic |
| Image-text question generation | Covers the 400-500 high-value images |
| Image-understanding question generation | Covers architecture diagrams and flowcharts (about 100-150 images) |
| QA set review | Deduplication, quality scoring, manual review |

**Expected output:**
- Text-only questions: ~1000 (same as the existing approach)
- Image-text questions: ~800 (2 per high-value image)
- Image-understanding questions: ~200 (2 per architecture diagram)
- **Total: ~2000 multimodal QA pairs**

### 4.4 Phase 4: Integration and Testing (2-3 days)

| Task | Output |
|------|------|
| Image preview in the frontend | Review page shows the associated image |
| Export format extension | Export MD with image references |
| Single-hop test adaptation | Support for multimodal recall testing |

---

## 5. Technology Choices

### 5.1 Multimodal LLM Selection

| Model | Strengths | Weaknesses | Recommended use |
|------|------|------|---------|
| **GPT-4V** | Strong image understanding, stable API | Higher cost ($0.01/image) | Image classification, architecture diagram understanding |
| **Claude 3.5 Sonnet** | Lower cost ($0.008/image), good Chinese support | Slightly weaker at fine image detail | Configuration screenshots, flowcharts |
| **Qwen-VL** | Open source and free, can be deployed locally | Needs a GPU, slow inference | Cost-sensitive scenarios |

**Recommended combination:**
- Image classification: Claude 3.5 Sonnet (low cost, fast)
- Architecture diagram understanding: GPT-4V (high accuracy)
- QA generation: Claude 3.5 Sonnet (good Chinese generation quality)

### 5.2 HTML Parsing Libraries

- **BeautifulSoup4**: simple to use, suited to structured HTML
- **html2text**: fast conversion, but weak image handling
- **Recommended: BeautifulSoup4 + custom logic**
---

## 6. Expected Results

### 6.1 QA Set Quality Improvement

| Dimension | Existing approach (text only) | Multimodal approach | Gain |
|------|------------------|-----------|------|
| Number of questions | ~1000 | ~2000 | **+100%** |
| Scenarios covered | Concepts, configuration, commands | + UI operations, architecture understanding, hardware interfaces | **+3 categories** |
| Recall accuracy | 63% | 75%+ expected | **+12 percentage points** |
| User experience | Text-only QA | Text with images, more intuitive | **Clearly improved** |

### 6.2 Benchmark Value

- **More comprehensive**: covers both the text and image modalities
- **More realistic**: closer to how users actually work (reading the docs and looking at the figures)
- **More challenging**: tests the RAG system's multimodal recall capability
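---

## 7. Risks and Mitigations

| Risk | Impact | Mitigation |
|------|------|---------|
| High image analysis cost | Budget overrun | 1. Filter for high-value images first<br>2. Use Claude 3.5 Sonnet to lower cost<br>3. Cache analysis results |
| Inaccurate image descriptions | Lower QA quality | 1. Manually spot-check 10% of samples<br>2. Skip low-confidence images<br>3. Show the image to reviewers for verification |
| Multimodal recall testing is complex | Hard to implement | 1. Phase 1: generate the QA set first<br>2. Phase 2: adapt the recall test afterwards<br>3. Validate with text-only tests first |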
---

## 8. Quick-Start Plan (MVP)

If time is tight, start with a **minimum viable plan**:

### MVP Scope

1. **Process configuration screenshots only** (~400 images, about 40% of the total)
2. **Generate image-text questions only** (no pure image-understanding questions)
3. **Manually select 50 high-value images** as a pilot

### MVP Schedule (3-5 days)

| Day | Task |
|-----|------|
| Day 1 | HTML → MD conversion + manual selection of 50 images |
| Day 2 | Image analysis (50 images) + description generation |
| Day 3 | Image-text QA generation (~100 pairs) |
| Day 4 | Review + export |
| Day 5 | Single-hop test to validate the results |

### MVP Cost

- Image analysis: 50 × $0.008 = **$0.4**
- QA generation: 100 pairs × $0.002 = **$0.2**
- **Total: ~$0.6**
---

## 9. Next Steps

1. **Confirm the plan**: adopt the multimodal approach? Full rollout or MVP?
2. **Prepare the environment**:
   - Request a GPT-4V or Claude 3.5 Sonnet API key
   - Prepare image storage (local or OSS)
3. **Development scheduling**:
   - Who owns HTML parsing?
   - Who owns image analysis?
   - Who owns QA generation?
4. **Budget approval**: about $20 for the full plan, about $1 for the MVP

---

**Summary:** The multimodal approach can significantly improve the quality and coverage of the QA set. The recommendation is to validate it with the MVP first and then decide whether to roll it out in full.
172
docs/验证报告.md
Normal file
@ -0,0 +1,172 @@
# Verification Report: "Import from Dagent Knowledge Base" Feature

**Verification date:** April 21, 2026
**Verified by:** Claude Code
**Project path:** d:/project/dagent/rag-eval

## 1. Overview

This verification covers the implementation of the "Import from Dagent Knowledge Base" feature: the backend API, the frontend UI, and data source switching. It checks code syntax, database connectivity, API routes, and frontend UI rendering.

## 2. Results

### 1. Backend API implementation ✅

**File:** `server/api/qa_gen_dagent.py`

**Checks:**
- ✅ Syntax check passed with no errors
- ✅ All dependencies import correctly (aiomysql, fastapi, asyncio, etc.)
- ✅ Database connection configuration is correct
- ✅ Three core API endpoints are implemented:
  - `GET /api/qa-gen/dagent/stats` - query Dagent knowledge base statistics
  - `GET /api/qa-gen/dagent/files` - list files that have finished processing
  - `POST /api/qa-gen/task/from-dagent` - create an import task
- ✅ Background task logic is complete, including:
  - Connecting to the Dagent database
  - Extracting paragraph data
  - Calling the LLM to generate QA pairs
  - Storing results in the `qa_gen_question` table
- ✅ Reuses the existing `_sync_approved_count` function

### 2. Route registration ✅

**File:** `server/main.py`

**Checks:**
- ✅ The `qa_gen_dagent` module is imported correctly
- ✅ The router is registered on the FastAPI application
- ✅ Route prefix and tags are set correctly

### 3. Frontend API service update ✅

**File:** `frontend/src/services/api.ts`

**Checks:**
- ✅ Three Dagent-related API functions were added:
  - `createTaskFromDagent()`
  - `getDagentStats()`
  - `listDagentFiles()`
- ✅ API paths match the backend
- ✅ Parameters are passed correctly

### 4. Frontend UI changes ✅

**File:** `frontend/src/pages/QaGen/index.tsx`

**Checks:**
- ✅ Data source switch added (Radio.Group)
- ✅ UI elements added for Dagent mode:
  - org_id input (with a query button)
  - Multi-line file ID input
  - Toggle for generating image-text questions
  - Statistics display area
- ✅ Conditional rendering logic is correct
- ✅ Form validation logic is complete

### 5. Database connection test ✅

**Tests:**
- ✅ Connected to the Dagent database successfully
- ✅ Queried the EVB knowledge base statistics successfully
- ✅ Returned data matches the design document:
  - Files: 207 ✓
  - Paragraphs: 4883 ✓
  - Total images: 1226 ✓
  - Paragraphs containing images: 762 ✓
  - Paragraphs with seed questions: 4883 ✓

### 6. API route test ✅

**Tests:**
- ✅ `/api/health` health check passed
- ✅ `/api/qa-gen/dagent/stats` returns correct statistics
- ✅ `/api/qa-gen/dagent/files` returns the file list (207 files)
- ✅ Parameter validation works (a request without org_id returns a 422 error)

## 3. Issues Found

### 1. Potential circular import risk ⚠️

**Description:**
Line 311 of `qa_gen_dagent.py` imports the `_sync_approved_count` function from `.qa_gen`. There is no circular import today, but one could appear if `qa_gen.py` ever imports the `qa_gen_dagent` module.

**Suggested fix:**
Move `_sync_approved_count` into a shared utility module, or import it lazily inside the function that calls it. Switching to an absolute import would not help, because the two modules would still depend on each other.

### 2. Windows console encoding issue ⚠️

**Description:**
The Windows console defaults to GBK encoding, so Unicode characters such as ✅ and ❌ are displayed as garbled text.

**Impact:**
This only affects console output; functionality is not affected.
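
One possible mitigation, not part of the verified code, is to switch the standard streams to UTF-8 at the start of a check script. The sketch below uses `TextIOWrapper.reconfigure`, available since Python 3.7; `force_utf8` is a hypothetical helper name, and whether the glyphs actually render still depends on the terminal's code page and font.

```python
import io
import sys


def force_utf8(stream):
    """Switch a text stream to UTF-8 when it supports reconfigure(); return the stream."""
    reconfigure = getattr(stream, "reconfigure", None)
    if callable(reconfigure):
        # errors="replace" avoids UnicodeEncodeError on characters that cannot be encoded
        reconfigure(encoding="utf-8", errors="replace")
    return stream


if __name__ == "__main__":
    force_utf8(sys.stdout)
    force_utf8(sys.stderr)
    print("✅ console streams now encode output as UTF-8")
```

Setting the `PYTHONIOENCODING=utf-8` environment variable before launching Python has a similar effect without touching the code.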

## 4. Feature Completeness Check

| Feature | Status | Notes |
|---------|------|------|
| Dagent database connection | ✅ | Tested and passing |
| Statistics query | ✅ | Returns correct data |
| File list query | ✅ | Returns 207 files |
| Task creation | ✅ | Logic is complete |
| Background generation task | ✅ | Includes concurrency control |
| Data storage | ✅ | Reuses the existing table structure |
| Frontend data source switch | ✅ | UI is complete |
| Form validation | ✅ | Consistent between frontend and backend |
| Error handling | ✅ | Includes exception handling |

## 5. Pre-Deployment Checklist

### Backend:
- [x] `aiomysql` is installed (`pip install aiomysql`)
- [x] `fastapi` and `uvicorn` are installed
- [x] Database connection configuration is correct
- [x] Routes are registered correctly
- [x] Compatible with the existing qa_gen module

### Frontend:
- [x] TypeScript compiles without errors
- [x] API service is updated correctly
- [x] All UI components are imported
- [x] Conditional rendering logic is correct

### Database:
- [x] `qa_gen_task` table exists with the correct structure
- [x] `qa_gen_question` table exists with the correct structure
- [x] `judge_config` table has a usable configuration

## 6. Follow-Up Recommendations

### 1. Immediate:
- Start the backend service and test the full flow
- Create a small test task (1-2 files)
- Verify the quality of the generated QA pairs

### 2. Short-term improvements:
- Add real-time task progress updates
- Improve timeout handling for LLM calls
- Add more error logging

### 3. Long-term plans:
- Support incremental import (only newly added paragraphs)
- Add automatic QA quality evaluation
- Support importing multiple knowledge bases in parallel

## 7. Conclusion

✅ **Verification passed**

All core functionality is implemented correctly:
1. The backend API is complete and working
2. The database connection test passed
3. The frontend UI changes are correct
4. The data source switching logic is complete
5. Statistics queries are accurate

The system can now import data from the Dagent knowledge base and generate a multimodal QA set. A small-scale pilot test is recommended before production use.

---

**Verified by:** Claude Code
**Date:** April 21, 2026
**Version:** v1.0
12
frontend/index.html
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="zh-CN">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<title>RAG Eval Framework</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="root"></div>
|
||||||
|
<script type="module" src="/src/main.tsx"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
4278
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
27
frontend/package.json
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
{
|
||||||
|
"name": "rag-eval-frontend",
|
||||||
|
"version": "0.1.0",
|
||||||
|
"private": true,
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "tsc && vite build",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"react": "^18.2.0",
|
||||||
|
"react-dom": "^18.2.0",
|
||||||
|
"react-router-dom": "^6.22.0",
|
||||||
|
"antd": "^5.14.0",
|
||||||
|
"@ant-design/icons": "^5.3.0",
|
||||||
|
"@ant-design/charts": "^2.1.0",
|
||||||
|
"axios": "^1.6.0",
|
||||||
|
"dayjs": "^1.11.10"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/react": "^18.2.0",
|
||||||
|
"@types/react-dom": "^18.2.0",
|
||||||
|
"@vitejs/plugin-react": "^4.2.0",
|
||||||
|
"typescript": "^5.3.0",
|
||||||
|
"vite": "^5.1.0"
|
||||||
|
}
|
||||||
|
}
|
||||||
81
frontend/src/App.tsx
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import { BrowserRouter, Routes, Route, NavLink, Navigate, useLocation } from 'react-router-dom'
|
||||||
|
import { Layout, Menu } from 'antd'
|
||||||
|
import {
|
||||||
|
DatabaseOutlined,
|
||||||
|
PlayCircleOutlined,
|
||||||
|
BarChartOutlined,
|
||||||
|
SettingOutlined,
|
||||||
|
AimOutlined,
|
||||||
|
BulbOutlined,
|
||||||
|
ForkOutlined,
|
||||||
|
} from '@ant-design/icons'
|
||||||
|
import Dataset from './pages/Dataset'
|
||||||
|
import DatasetDetail from './pages/Dataset/detail'
|
||||||
|
import Task from './pages/Task'
|
||||||
|
import Report from './pages/Report'
|
||||||
|
import Config from './pages/Config'
|
||||||
|
import SingleJump from './pages/SingleJump'
|
||||||
|
import QaGen from './pages/QaGen'
|
||||||
|
import MultiHop from './pages/MultiHop'
|
||||||
|
|
||||||
|
const { Sider, Content } = Layout
|
||||||
|
|
||||||
|
const NAV = [
|
||||||
|
{ key: '/dataset', icon: <DatabaseOutlined />, label: '测试集' },
|
||||||
|
{ key: '/task', icon: <PlayCircleOutlined />, label: '评测任务' },
|
||||||
|
{ key: '/single-jump', icon: <AimOutlined />, label: '单跳召回测试' },
|
||||||
|
{ key: '/multi-hop', icon: <ForkOutlined />, label: '多跳召回测试' },
|
||||||
|
{ key: '/qa-gen', icon: <BulbOutlined />, label: '问题生成' },
|
||||||
|
{ key: '/config', icon: <SettingOutlined />, label: '配置管理' },
|
||||||
|
]
|
||||||
|
|
||||||
|
function AppLayout() {
|
||||||
|
const location = useLocation()
|
||||||
|
const currentPath = location.pathname.split('/').slice(0, 2).join('/')
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Layout style={{ minHeight: '100vh' }}>
|
||||||
|
<Sider theme="dark" width={200}>
|
||||||
|
<div style={{ color: '#fff', fontWeight: 700, fontSize: 16, padding: '20px 24px 12px' }}>
|
||||||
|
RAG Eval
|
||||||
|
</div>
|
||||||
|
<Menu
|
||||||
|
theme="dark"
|
||||||
|
mode="inline"
|
||||||
|
selectedKeys={[currentPath]}
|
||||||
|
items={NAV.map(n => ({
|
||||||
|
key: n.key,
|
||||||
|
icon: n.icon,
|
||||||
|
label: <NavLink to={n.key}>{n.label}</NavLink>,
|
||||||
|
}))}
|
||||||
|
/>
|
||||||
|
</Sider>
|
||||||
|
<Layout>
|
||||||
|
<Content style={{ padding: 24, background: '#f5f5f5', minHeight: '100vh' }}>
|
||||||
|
<div style={{ background: '#fff', padding: 24, borderRadius: 8, minHeight: '100%' }}>
|
||||||
|
<Routes>
|
||||||
|
<Route path="/" element={<Navigate to="/dataset" replace />} />
|
||||||
|
<Route path="/dataset" element={<Dataset />} />
|
||||||
|
<Route path="/dataset/:id" element={<DatasetDetail />} />
|
||||||
|
<Route path="/task" element={<Task />} />
|
||||||
|
<Route path="/report/:taskId" element={<Report />} />
|
||||||
|
<Route path="/single-jump" element={<SingleJump />} />
|
||||||
|
<Route path="/multi-hop" element={<MultiHop />} />
|
||||||
|
<Route path="/qa-gen" element={<QaGen />} />
|
||||||
|
<Route path="/config" element={<Config />} />
|
||||||
|
</Routes>
|
||||||
|
</div>
|
||||||
|
</Content>
|
||||||
|
</Layout>
|
||||||
|
</Layout>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function App() {
|
||||||
|
return (
|
||||||
|
<BrowserRouter>
|
||||||
|
<AppLayout />
|
||||||
|
</BrowserRouter>
|
||||||
|
)
|
||||||
|
}
|
||||||
311
frontend/src/components/DagentFileSelector/index.tsx
Normal file
@ -0,0 +1,311 @@
|
|||||||
|
import React, { useState, useEffect } from 'react'
|
||||||
|
import { Table, Input, Button, Tag, Space, message, Pagination, Typography } from 'antd'
|
||||||
|
import { SearchOutlined, ReloadOutlined, CheckCircleOutlined } from '@ant-design/icons'
|
||||||
|
import { qaGenApi } from '../../services/api'
|
||||||
|
|
||||||
|
const { Text } = Typography
|
||||||
|
|
||||||
|
interface FileItem {
|
||||||
|
id: string
|
||||||
|
file_name: string
|
||||||
|
file_type: string
|
||||||
|
file_clean_status: string
|
||||||
|
file_bytes: number
|
||||||
|
create_time: string
|
||||||
|
}
|
||||||
|
|
||||||
|
interface DagentFileSelectorProps {
|
||||||
|
orgId: string
|
||||||
|
envUrl?: string // Dagent environment URL
|
||||||
|
value?: string | string[] // selected file IDs (comma-separated string or array)
|
||||||
|
onChange?: (fileIds: string | string[]) => void
|
||||||
|
disabled?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
const DagentFileSelector: React.FC<DagentFileSelectorProps> = ({
|
||||||
|
orgId,
|
||||||
|
envUrl = '',
|
||||||
|
value, // no default: a fresh [] on every render would retrigger the [value] effect below and loop
|
||||||
|
onChange,
|
||||||
|
disabled = false,
|
||||||
|
}) => {
|
||||||
|
const [files, setFiles] = useState<FileItem[]>([])
|
||||||
|
const [loading, setLoading] = useState(false)
|
||||||
|
const [searchText, setSearchText] = useState('')
|
||||||
|
// Normalize value to an array
|
||||||
|
const valueToArray = (val: string | string[] | undefined): string[] => {
|
||||||
|
if (!val) return []
|
||||||
|
if (Array.isArray(val)) return val
|
||||||
|
return val.split(',').map(id => id.trim()).filter(id => id.length > 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
const [selectedRowKeys, setSelectedRowKeys] = useState<string[]>(valueToArray(value))
|
||||||
|
const [pagination, setPagination] = useState({
|
||||||
|
current: 1,
|
||||||
|
pageSize: 20,
|
||||||
|
total: 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Load the file list
|
||||||
|
const loadFiles = async (page = 1, pageSize = 20) => {
|
||||||
|
if (!orgId || orgId.length < 8) return
|
||||||
|
|
||||||
|
setLoading(true)
|
||||||
|
try {
|
||||||
|
const res = await qaGenApi.listDagentFiles(orgId, envUrl) as any
|
||||||
|
const fileList = res.data || []
|
||||||
|
setFiles(fileList)
|
||||||
|
setPagination(prev => ({
|
||||||
|
...prev,
|
||||||
|
total: fileList.length,
|
||||||
|
current: page,
|
||||||
|
pageSize,
|
||||||
|
}))
|
||||||
|
} catch (e: any) {
|
||||||
|
console.error('加载文件列表失败:', e)
|
||||||
|
message.error(`加载文件列表失败: ${e.message || '未知错误'}`)
|
||||||
|
setFiles([])
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load on mount and whenever orgId or envUrl changes
|
||||||
|
useEffect(() => {
|
||||||
|
if (orgId && orgId.length >= 8) {
|
||||||
|
loadFiles()
|
||||||
|
} else {
|
||||||
|
setFiles([])
|
||||||
|
setSelectedRowKeys([])
|
||||||
|
}
|
||||||
|
}, [orgId, envUrl])
|
||||||
|
|
||||||
|
// Sync the external value prop into the local selection state
|
||||||
|
useEffect(() => {
|
||||||
|
setSelectedRowKeys(valueToArray(value))
|
||||||
|
}, [value])
|
||||||
|
|
||||||
|
// Handle selection changes
|
||||||
|
const handleSelectChange = (selectedKeys: string[]) => {
|
||||||
|
setSelectedRowKeys(selectedKeys)
|
||||||
|
if (onChange) {
|
||||||
|
// For backward compatibility, return a comma-separated string
|
||||||
|
onChange(selectedKeys.join(','))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select all / deselect all
|
||||||
|
const handleSelectAll = () => {
|
||||||
|
if (selectedRowKeys.length === filteredFiles.length) {
|
||||||
|
// Deselect all
|
||||||
|
handleSelectChange([])
|
||||||
|
} else {
|
||||||
|
// Select all
|
||||||
|
const allIds = filteredFiles.map(file => file.id)
|
||||||
|
handleSelectChange(allIds)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format the file size
|
||||||
|
const formatFileSize = (bytes: number) => {
|
||||||
|
if (bytes < 1024) return `${bytes} B`
|
||||||
|
if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(1)} KB`
|
||||||
|
return `${(bytes / (1024 * 1024)).toFixed(1)} MB`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status tag
|
||||||
|
const statusTag = (status: string) => {
|
||||||
|
const map: Record<string, { color: string, label: string }> = {
|
||||||
|
'CLEAN_FINISH': { color: 'success', label: '已处理' },
|
||||||
|
'CLEAN_PROCESSING': { color: 'processing', label: '处理中' },
|
||||||
|
'CLEAN_FAILED': { color: 'error', label: '处理失败' },
|
||||||
|
'UPLOAD_FAILED': { color: 'warning', label: '上传失败' },
|
||||||
|
'UPLOAD_SUCCESS': { color: 'default', label: '已上传' },
|
||||||
|
}
|
||||||
|
const cfg = map[status] || { color: 'default', label: status }
|
||||||
|
return <Tag color={cfg.color}>{cfg.label}</Tag>
|
||||||
|
}
|
||||||
|
|
||||||
|
// File type tag
|
||||||
|
const fileTypeTag = (fileType: string) => {
|
||||||
|
const map: Record<string, { color: string, label: string }> = {
|
||||||
|
'html': { color: 'blue', label: 'HTML' },
|
||||||
|
'pdf': { color: 'red', label: 'PDF' },
|
||||||
|
'docx': { color: 'green', label: 'DOCX' },
|
||||||
|
'md': { color: 'purple', label: 'Markdown' },
|
||||||
|
}
|
||||||
|
const cfg = map[fileType.toLowerCase()] || { color: 'default', label: fileType }
|
||||||
|
return <Tag color={cfg.color}>{cfg.label}</Tag>
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter by search text
|
||||||
|
const filteredFiles = files.filter(file =>
|
||||||
|
file.file_name.toLowerCase().includes(searchText.toLowerCase()) ||
|
||||||
|
file.id.toLowerCase().includes(searchText.toLowerCase())
|
||||||
|
)
|
||||||
|
|
||||||
|
// Slice out the current page
|
||||||
|
const startIndex = (pagination.current - 1) * pagination.pageSize
|
||||||
|
const endIndex = startIndex + pagination.pageSize
|
||||||
|
const pageData = filteredFiles.slice(startIndex, endIndex)
|
||||||
|
|
||||||
|
const columns = [
|
||||||
|
{
|
||||||
|
title: (
|
||||||
|
<div style={{ display: 'flex', alignItems: 'center', gap: 8 }}>
|
||||||
|
<span>选择</span>
|
||||||
|
{filteredFiles.length > 0 && (
|
||||||
|
<Button
|
||||||
|
size="small"
|
||||||
|
type="link"
|
||||||
|
onClick={handleSelectAll}
|
||||||
|
style={{ padding: 0, height: 'auto' }}
|
||||||
|
>
|
||||||
|
{selectedRowKeys.length === filteredFiles.length ? '取消全选' : '全选'}
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
),
|
||||||
|
key: 'selection',
|
||||||
|
width: 80,
|
||||||
|
render: (_: any, record: FileItem) => (
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={selectedRowKeys.includes(record.id)}
|
||||||
|
onChange={(e) => {
|
||||||
|
const newSelectedKeys = e.target.checked
|
||||||
|
? [...selectedRowKeys, record.id]
|
||||||
|
: selectedRowKeys.filter(key => key !== record.id)
|
||||||
|
handleSelectChange(newSelectedKeys)
|
||||||
|
}}
|
||||||
|
disabled={disabled}
|
||||||
|
/>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '文件名',
|
||||||
|
dataIndex: 'file_name',
|
||||||
|
key: 'file_name',
|
||||||
|
ellipsis: true,
|
||||||
|
width: 200,
|
||||||
|
render: (text: string) => (
|
||||||
|
<Text strong style={{ fontSize: 13 }}>{text}</Text>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '类型',
|
||||||
|
dataIndex: 'file_type',
|
||||||
|
key: 'file_type',
|
||||||
|
width: 80,
|
||||||
|
render: (type: string) => fileTypeTag(type),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '大小',
|
||||||
|
dataIndex: 'file_bytes',
|
||||||
|
key: 'file_bytes',
|
||||||
|
width: 90,
|
||||||
|
render: (bytes: number) => (
|
||||||
|
<Text type="secondary" style={{ fontSize: 12 }}>
|
||||||
|
{formatFileSize(bytes)}
|
||||||
|
</Text>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '状态',
|
||||||
|
dataIndex: 'file_clean_status',
|
||||||
|
key: 'file_clean_status',
|
||||||
|
width: 90,
|
||||||
|
render: (status: string) => statusTag(status),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '创建时间',
|
||||||
|
dataIndex: 'create_time',
|
||||||
|
key: 'create_time',
|
||||||
|
width: 120,
|
||||||
|
render: (time: string) => (
|
||||||
|
<Text type="secondary" style={{ fontSize: 12 }}>
|
||||||
|
{time ? time.slice(0, 10) : '-'}
|
||||||
|
</Text>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div style={{ border: '1px solid #d9d9d9', borderRadius: 6, padding: 16 }}>
|
||||||
|
{/* Toolbar */}
|
||||||
|
<div style={{ marginBottom: 16, display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||||
|
<Space>
|
||||||
|
<Input
|
||||||
|
placeholder="搜索文件名或ID"
|
||||||
|
prefix={<SearchOutlined />}
|
||||||
|
value={searchText}
|
||||||
|
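onChange={(e) => {
setSearchText(e.target.value)
// Return to page 1 so the filtered list is not sliced past its end
setPagination(prev => ({ ...prev, current: 1 }))
}}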
|
||||||
|
style={{ width: 200 }}
|
||||||
|
disabled={disabled || !orgId}
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
icon={<ReloadOutlined />}
|
||||||
|
onClick={() => loadFiles()}
|
||||||
|
loading={loading}
|
||||||
|
disabled={disabled || !orgId}
|
||||||
|
>
|
||||||
|
刷新
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Text type="secondary" style={{ fontSize: 12, marginRight: 8 }}>
|
||||||
|
共 {filteredFiles.length} 个文件
|
||||||
|
</Text>
|
||||||
|
<Tag color="blue">
|
||||||
|
<CheckCircleOutlined /> 已选择 {selectedRowKeys.length} 个
|
||||||
|
</Tag>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* File list table */}
|
||||||
|
<Table
|
||||||
|
size="small"
|
||||||
|
rowKey="id"
|
||||||
|
columns={columns}
|
||||||
|
dataSource={pageData}
|
||||||
|
loading={loading}
|
||||||
|
pagination={false}
|
||||||
|
scroll={{ y: 300 }}
|
||||||
|
rowSelection={{
|
||||||
|
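selectedRowKeys,
// dataSource holds only the current page; keep keys selected on other pages
preserveSelectedRowKeys: true,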
|
||||||
|
onChange: (selectedKeys) => handleSelectChange(selectedKeys as string[]),
|
||||||
|
getCheckboxProps: () => ({ disabled }),
|
||||||
|
}}
|
||||||
|
rowClassName={(record) => selectedRowKeys.includes(record.id) ? 'ant-table-row-selected' : ''}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* Pagination */}
|
||||||
|
{filteredFiles.length > pagination.pageSize && (
|
||||||
|
<div style={{ marginTop: 16, display: 'flex', justifyContent: 'flex-end' }}>
|
||||||
|
<Pagination
|
||||||
|
size="small"
|
||||||
|
current={pagination.current}
|
||||||
|
pageSize={pagination.pageSize}
|
||||||
|
total={filteredFiles.length}
|
||||||
|
onChange={(page, pageSize) => {
|
||||||
|
setPagination({ ...pagination, current: page, pageSize })
|
||||||
|
}}
|
||||||
|
showSizeChanger
|
||||||
|
pageSizeOptions={['10', '20', '50', '100']}
|
||||||
|
showTotal={(total) => `共 ${total} 个文件`}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Empty state */}
|
||||||
|
{!loading && filteredFiles.length === 0 && (
|
||||||
|
<div style={{ textAlign: 'center', padding: '40px 0', color: '#999' }}>
|
||||||
|
{orgId && orgId.length >= 8 ? '暂无文件数据' : '请输入组织ID查询文件'}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default DagentFileSelector
|
||||||
289
frontend/src/components/DagentTreeSelector/index.tsx
Normal file
@ -0,0 +1,289 @@
|
|||||||
|
import React, { useState, useEffect } from 'react'
|
||||||
|
import { Tree, Card, Tag, Space, Typography, Button, Input, message } from 'antd'
|
||||||
|
import { ReloadOutlined, FolderOutlined, FileOutlined, FileTextOutlined, ClusterOutlined } from '@ant-design/icons'
|
||||||
|
import { qaGenApi } from '../../services/api'
|
||||||
|
|
||||||
|
const { Text } = Typography
|
||||||
|
const { Search } = Input
|
||||||
|
|
||||||
|
interface TreeNode {
|
||||||
|
key: string
|
||||||
|
title: string
|
||||||
|
type: 'major_chapter' | 'minor_chapter' | 'file' | 'chunk'
|
||||||
|
file_id?: string
|
||||||
|
chunk_id?: string
|
||||||
|
chunk_count?: number
|
||||||
|
file_type?: string
|
||||||
|
status?: string
|
||||||
|
preview?: string
|
||||||
|
has_image?: boolean
|
||||||
|
children?: TreeNode[]
|
||||||
|
}
|
||||||
|
|
||||||
|
interface DagentTreeSelectorProps {
|
||||||
|
orgId: string
|
||||||
|
envUrl?: string
|
||||||
|
value?: string[] // list of selected file IDs
|
||||||
|
onChange?: (fileIds: string[]) => void
|
||||||
|
disabled?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
const DagentTreeSelector: React.FC<DagentTreeSelectorProps> = ({
|
||||||
|
orgId,
|
||||||
|
envUrl = '',
|
||||||
|
value = [],
|
||||||
|
onChange,
|
||||||
|
disabled = false,
|
||||||
|
}) => {
|
||||||
|
const [treeData, setTreeData] = useState<TreeNode[]>([])
|
||||||
|
const [loading, setLoading] = useState(false)
|
||||||
|
const [expandedKeys, setExpandedKeys] = useState<string[]>([])
|
||||||
|
const [checkedKeys, setCheckedKeys] = useState<string[]>([])
|
||||||
|
const [searchText, setSearchText] = useState('')
|
||||||
|
|
||||||
|
// Load the tree data
|
||||||
|
const loadTreeData = async () => {
|
||||||
|
if (!orgId || orgId.length < 8) return
|
||||||
|
|
||||||
|
setLoading(true)
|
||||||
|
try {
|
||||||
|
const res: any = await qaGenApi.getDagentTree(orgId, envUrl)
|
||||||
|
if (res.status === 0) {
|
||||||
|
setTreeData(res.data || [])
|
||||||
|
// Expand the first level by default
|
||||||
|
if (res.data && res.data.length > 0) {
|
||||||
|
setExpandedKeys(res.data.map((n: TreeNode) => n.key))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
message.error(res.message || '加载树形数据失败')
|
||||||
|
}
|
||||||
|
} catch (e: any) {
|
||||||
|
console.error('加载树形数据失败:', e)
|
||||||
|
message.error(`加载失败: ${e.message || '未知错误'}`)
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load on mount and whenever orgId or envUrl changes
|
||||||
|
useEffect(() => {
|
||||||
|
if (orgId && orgId.length >= 8) {
|
||||||
|
loadTreeData()
|
||||||
|
} else {
|
||||||
|
setTreeData([])
|
||||||
|
}
|
||||||
|
}, [orgId, envUrl])
|
||||||
|
|
||||||
|
// Sync the external value prop into the checked keys
|
||||||
|
useEffect(() => {
|
||||||
|
if (value) {
|
||||||
|
// Convert raw file IDs into tree keys of the form file:<id>
|
||||||
|
const keys = value.map(id => `file:${id}`)
|
||||||
|
setCheckedKeys(keys)
|
||||||
|
}
|
||||||
|
}, [value])
|
||||||
|
|
||||||
|
// Collect the keys of all file nodes under a node
|
||||||
|
const getAllFileKeys = (node: TreeNode): string[] => {
|
||||||
|
const keys: string[] = []
|
||||||
|
if (node.type === 'file' && node.file_id) {
|
||||||
|
keys.push(node.key)
|
||||||
|
}
|
||||||
|
if (node.children) {
|
||||||
|
node.children.forEach(child => {
|
||||||
|
keys.push(...getAllFileKeys(child))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle check changes
|
||||||
|
const handleCheck = (checked: any, info: any) => {
|
||||||
|
const keys = checked as string[]
|
||||||
|
setCheckedKeys(keys)
|
||||||
|
|
||||||
|
// Extract the file IDs
|
||||||
|
const fileIds: string[] = []
|
||||||
|
keys.forEach((key: string) => {
|
||||||
|
if (key.startsWith('file:')) {
|
||||||
|
fileIds.push(key.replace('file:', ''))
|
||||||
|
} else if (key.startsWith('major:') || key.startsWith('minor:')) {
|
||||||
|
// If a chapter is checked, collect every file beneath it
|
||||||
|
const findNode = (nodes: TreeNode[], targetKey: string): TreeNode | null => {
|
||||||
|
for (const node of nodes) {
|
||||||
|
if (node.key === targetKey) return node
|
||||||
|
if (node.children) {
|
||||||
|
const found = findNode(node.children, targetKey)
|
||||||
|
if (found) return found
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
const node = findNode(treeData, key)
|
||||||
|
if (node) {
|
||||||
|
const fileKeys = getAllFileKeys(node)
|
||||||
|
fileKeys.forEach(k => fileIds.push(k.replace('file:', '')))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Deduplicate
|
||||||
|
const uniqueFileIds = [...new Set(fileIds)]
|
||||||
|
onChange?.(uniqueFileIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter the tree by search text
|
||||||
|
const filterTreeData = (data: TreeNode[], search: string): TreeNode[] => {
|
||||||
|
if (!search) return data
|
||||||
|
|
||||||
|
return data.map(node => {
|
||||||
|
const filteredChildren = node.children ? filterTreeData(node.children, search) : []
|
||||||
|
const matchTitle = node.title.toLowerCase().includes(search.toLowerCase())
|
||||||
|
|
||||||
|
if (matchTitle || filteredChildren.length > 0) {
|
||||||
|
return {
|
||||||
|
...node,
|
||||||
|
children: filteredChildren
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}).filter(Boolean) as TreeNode[]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom title rendering
|
||||||
|
const titleRender = (nodeData: TreeNode) => {
|
||||||
|
const { type, title, chunk_count, file_type, status, preview, has_image } = nodeData
|
||||||
|
|
||||||
|
const getIcon = () => {
|
||||||
|
switch (type) {
|
||||||
|
case 'major_chapter':
|
||||||
|
return <ClusterOutlined style={{ color: '#1890ff', marginRight: 4 }} />
|
||||||
|
case 'minor_chapter':
|
||||||
|
return <FolderOutlined style={{ color: '#faad14', marginRight: 4 }} />
|
||||||
|
case 'file':
|
||||||
|
return <FileTextOutlined style={{ color: '#52c41a', marginRight: 4 }} />
|
||||||
|
case 'chunk':
|
||||||
|
return <FileOutlined style={{ color: '#722ed1', marginRight: 4, fontSize: 12 }} />
|
||||||
|
default:
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const getTag = () => {
|
||||||
|
if (type === 'file') {
|
||||||
|
const color = status === 'clean_finish' ? 'success' : status === 'clean_processing' ? 'processing' : 'default'
|
||||||
|
return (
|
||||||
|
<Space size={4}>
|
||||||
|
<Tag color="blue">{file_type?.toUpperCase() || 'FILE'}</Tag>
|
||||||
|
{chunk_count !== undefined && <Tag color="cyan">{chunk_count} 切片</Tag>}
|
||||||
|
</Space>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if (type === 'chunk' && has_image) {
|
||||||
|
return <Tag color="orange">含图片</Tag>
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<span style={{ display: 'inline-flex', alignItems: 'center', gap: 4 }}>
|
||||||
|
{getIcon()}
|
||||||
|
<Text
|
||||||
|
style={{
|
||||||
|
fontSize: type === 'chunk' ? 12 : 13,
|
||||||
|
fontWeight: type === 'major_chapter' ? 600 : type === 'minor_chapter' ? 500 : 'normal'
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{title}
|
||||||
|
</Text>
|
||||||
|
{getTag()}
|
||||||
|
{type === 'chunk' && preview && (
|
||||||
|
<Text type="secondary" style={{ fontSize: 11, marginLeft: 8 }}>
|
||||||
|
{preview}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</span>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计信息
|
||||||
|
const getStats = () => {
|
||||||
|
const stats = { files: 0, chunks: 0, selectedFiles: 0 }
|
||||||
|
|
||||||
|
const traverse = (nodes: TreeNode[]) => {
|
||||||
|
nodes.forEach(node => {
|
||||||
|
if (node.type === 'file') {
|
||||||
|
stats.files++
|
||||||
|
stats.chunks += node.chunk_count || 0
|
||||||
|
if (checkedKeys.includes(node.key)) {
|
||||||
|
stats.selectedFiles++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (node.children) traverse(node.children)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
traverse(treeData)
|
||||||
|
return stats
|
||||||
|
}
|
||||||
|
|
||||||
|
const stats = getStats()
|
||||||
|
const filteredTreeData = searchText ? filterTreeData(treeData, searchText) : treeData
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Card
|
||||||
|
size="small"
|
||||||
|
loading={loading}
|
||||||
|
title={
|
||||||
|
<Space>
|
||||||
|
<span>知识库文件树</span>
|
||||||
|
<Tag color="blue">{stats.files} 文件</Tag>
|
||||||
|
<Tag color="cyan">{stats.chunks} 切片</Tag>
|
||||||
|
<Tag color="green">已选 {stats.selectedFiles} 文件</Tag>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
extra={
|
||||||
|
<Space>
|
||||||
|
<Search
|
||||||
|
placeholder="搜索文件或章节"
|
||||||
|
allowClear
|
||||||
|
size="small"
|
||||||
|
style={{ width: 180 }}
|
||||||
|
value={searchText}
|
||||||
|
onChange={(e) => setSearchText(e.target.value)}
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
icon={<ReloadOutlined />}
|
||||||
|
size="small"
|
||||||
|
onClick={loadTreeData}
|
||||||
|
loading={loading}
|
||||||
|
disabled={disabled || !orgId}
|
||||||
|
>
|
||||||
|
刷新
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{treeData.length > 0 ? (
|
||||||
|
<Tree
|
||||||
|
checkable
|
||||||
|
checkStrictly={false}
|
||||||
|
checkedKeys={checkedKeys}
|
||||||
|
expandedKeys={expandedKeys}
|
||||||
|
onExpand={(keys) => setExpandedKeys(keys as string[])}
|
||||||
|
onCheck={handleCheck}
|
||||||
|
treeData={filteredTreeData}
|
||||||
|
titleRender={titleRender}
|
||||||
|
style={{ maxHeight: 400, overflow: 'auto' }}
|
||||||
|
disabled={disabled}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<div style={{ textAlign: 'center', padding: 40, color: '#999' }}>
|
||||||
|
{orgId && orgId.length >= 8 ? '暂无数据,请刷新重试' : '请输入组织ID'}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Card>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default DagentTreeSelector
|
||||||
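The search filter in this component keeps a node when its own title or any descendant matches, but a chapter whose title matches is returned with only its matching children. The following is a minimal sketch of an alternative that keeps the whole subtree of a matching node. It is not part of the commit: `SimpleNode`, `filterKeepSubtree`, and the demo data are hypothetical names chosen for illustration, and the real `TreeNode` carries more fields.

```typescript
// Hypothetical, simplified node shape; the real TreeNode has more fields.
interface SimpleNode {
  key: string
  title: string
  children?: SimpleNode[]
}

// Keep a node with its full subtree when its title matches; otherwise keep it
// only if some descendant matches, and then only with the matching branches.
function filterKeepSubtree(nodes: SimpleNode[], search: string): SimpleNode[] {
  const needle = search.trim().toLowerCase()
  if (!needle) return nodes

  const result: SimpleNode[] = []
  for (const node of nodes) {
    if (node.title.toLowerCase().indexOf(needle) !== -1) {
      result.push(node) // title matched: keep every descendant
      continue
    }
    const children = node.children ? filterKeepSubtree(node.children, search) : []
    if (children.length > 0) {
      result.push({ ...node, children })
    }
  }
  return result
}

// Demo data (illustrative only).
const demoTree: SimpleNode[] = [
  {
    key: 'major:1',
    title: 'Install guide',
    children: [
      { key: 'file:a', title: 'linux.md' },
      { key: 'file:b', title: 'windows.md' },
    ],
  },
  {
    key: 'major:2',
    title: 'FAQ',
    children: [{ key: 'file:c', title: 'install-errors.md' }],
  },
]

const filtered = filterKeepSubtree(demoTree, 'install')
```

With this variant, searching for a chapter name still shows the files under that chapter, which may be closer to what a user expects from a file picker. Whether that is the desired behaviour here is a product decision.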
34
frontend/src/constants/metrics.ts
Normal file
@ -0,0 +1,34 @@
export interface MetricMeta {
  key: string
  en: string
  cn: string
  group: 'retrieval' | 'generation'
  desc: string
}

export const METRICS: Record<string, MetricMeta> = {
  hit_rate: { key: 'hit_rate', en: 'Hit Rate@K', cn: '命中率', group: 'retrieval', desc: '检索结果中包含相关文档的比例' },
  mrr: { key: 'mrr', en: 'MRR@K', cn: '平均倒数排名', group: 'retrieval', desc: '第一个相关文档排名位置的倒数均值' },
  ndcg: { key: 'ndcg', en: 'NDCG@K', cn: '归一化折损累积增益', group: 'retrieval', desc: '考虑排名位置的检索质量综合评分' },
  context_precision: { key: 'context_precision', en: 'Context Precision', cn: '上下文精确度', group: 'retrieval', desc: '检索到的文档中与问题相关的比例' },
  context_recall: { key: 'context_recall', en: 'Context Recall', cn: '上下文召回率', group: 'retrieval', desc: '参考答案中的信息被检索文档覆盖的比例' },
  faithfulness: { key: 'faithfulness', en: 'Faithfulness', cn: '忠实度', group: 'generation', desc: '回答内容是否忠实于检索到的上下文' },
  answer_relevance: { key: 'answer_relevance', en: 'Answer Relevance', cn: '回答相关性', group: 'generation', desc: '回答与原始问题的相关程度' },
  answer_correctness: { key: 'answer_correctness', en: 'Answer Correctness', cn: '回答正确性', group: 'generation', desc: '回答与参考答案的事实一致程度' },
  groundedness: { key: 'groundedness', en: 'Groundedness', cn: '可溯源性', group: 'generation', desc: '回答中的声明能否追溯到检索文档' },
}

export const RETRIEVAL_METRICS = Object.values(METRICS).filter(m => m.group === 'retrieval')
export const GENERATION_METRICS = Object.values(METRICS).filter(m => m.group === 'generation')
export const ALL_METRIC_KEYS = Object.keys(METRICS)

/** Return the display label for a key: Chinese name followed by the English name in parentheses */
export function metricLabel(key: string): string {
  const m = METRICS[key]
  return m ? `${m.cn} (${m.en})` : key
}

/** Return the short Chinese name for a key */
export function metricCn(key: string): string {
  return METRICS[key]?.cn ?? key
}
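The metric registry above maps each key to display names and a group. A minimal sketch of how a page might use such a registry to bucket raw scores into retrieval and generation sections is shown here. It assumes a reduced copy of the registry (`SAMPLE_METRICS`, two entries only), and `groupScores` is a hypothetical helper that does not appear in the commit.

```typescript
type Group = 'retrieval' | 'generation'

interface Meta {
  key: string
  en: string
  cn: string
  group: Group
}

// Two entries copied from METRICS; the remaining ones are omitted for brevity.
const SAMPLE_METRICS: Record<string, Meta> = {
  hit_rate: { key: 'hit_rate', en: 'Hit Rate@K', cn: '命中率', group: 'retrieval' },
  faithfulness: { key: 'faithfulness', en: 'Faithfulness', cn: '忠实度', group: 'generation' },
}

// Same formatting rule as metricLabel: "<cn> (<en>)", falling back to the raw key.
function label(key: string): string {
  const m = SAMPLE_METRICS[key]
  return m ? `${m.cn} (${m.en})` : key
}

interface ScoreRow {
  label: string
  value: number
}

// Hypothetical helper: bucket raw scores by metric group, skipping unknown keys.
function groupScores(scores: Record<string, number>): Record<Group, ScoreRow[]> {
  const out: Record<Group, ScoreRow[]> = { retrieval: [], generation: [] }
  for (const key of Object.keys(scores)) {
    const meta = SAMPLE_METRICS[key]
    if (!meta) continue
    out[meta.group].push({ label: label(key), value: scores[key] })
  }
  return out
}

const grouped = groupScores({ hit_rate: 0.8, faithfulness: 0.9, unknown_metric: 0.1 })
```

Keeping the group on the metadata rather than in two separate lists means a new metric only needs one registry entry to appear in the right section.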
10
frontend/src/main.tsx
Normal file
@ -0,0 +1,10 @@
import React from 'react'
import ReactDOM from 'react-dom/client'
import App from './App'
import 'antd/dist/reset.css'

ReactDOM.createRoot(document.getElementById('root')!).render(
  <React.StrictMode>
    <App />
  </React.StrictMode>
)
203
frontend/src/pages/Config/index.tsx
Normal file
@ -0,0 +1,203 @@
import React, { useEffect, useState } from 'react'
import { Table, Button, Modal, Form, Input, Select, Popconfirm, message, Tag, Space } from 'antd'
import { PlusOutlined, DeleteOutlined } from '@ant-design/icons'
import { configApi } from '../../services/api'

const { Option } = Select

export default function Config() {
  const [platforms, setPlatforms] = useState<any[]>([])
  const [judges, setJudges] = useState<any[]>([])
  const [platformModal, setPlatformModal] = useState(false)
  const [judgeModal, setJudgeModal] = useState(false)
  const [form] = Form.useForm()
  const [judgeForm] = Form.useForm()
  const [selectedPlatformKeys, setSelectedPlatformKeys] = useState<React.Key[]>([])
  const [selectedJudgeKeys, setSelectedJudgeKeys] = useState<React.Key[]>([])

  const load = async () => {
    const [p, j] = await Promise.all([configApi.listPlatforms(), configApi.listJudges()])
    setPlatforms((p as any).data || [])
    setJudges((j as any).data || [])
  }

  useEffect(() => { load() }, [])

  const savePlatform = async () => {
    const vals = await form.validateFields()
    await configApi.createPlatform(vals)
    message.success('平台配置已保存')
    setPlatformModal(false)
    form.resetFields()
    load()
  }

  const saveJudge = async () => {
    const vals = await judgeForm.validateFields()
    await configApi.createJudge(vals)
    message.success('Judge 配置已保存')
    setJudgeModal(false)
    judgeForm.resetFields()
    load()
  }

  // ── Batch-delete platform configs ──────────────────────────────────────────
  const handleBatchDeletePlatform = async () => {
    if (selectedPlatformKeys.length === 0) {
      message.warning('请先选择要删除的平台配置')
      return
    }
    Modal.confirm({
      title: `确认删除选中的 ${selectedPlatformKeys.length} 个平台配置?`,
      content: '删除后将无法恢复。',
      okText: '确认删除',
      okType: 'danger',
      cancelText: '取消',
      async onOk() {
        try {
          await Promise.all(selectedPlatformKeys.map(id => configApi.deletePlatform(id as string)))
          message.success(`成功删除 ${selectedPlatformKeys.length} 个平台配置`)
          setSelectedPlatformKeys([])
          load()
        } catch (e: any) {
          message.error(e?.message || '批量删除失败')
        }
      },
    })
  }

  // ── Batch-delete Judge configs ─────────────────────────────────────────────
  const handleBatchDeleteJudge = async () => {
    if (selectedJudgeKeys.length === 0) {
      message.warning('请先选择要删除的 Judge 配置')
      return
    }
    Modal.confirm({
      title: `确认删除选中的 ${selectedJudgeKeys.length} 个 Judge 配置?`,
      content: '删除后将无法恢复。',
      okText: '确认删除',
      okType: 'danger',
      cancelText: '取消',
      async onOk() {
        try {
          await Promise.all(selectedJudgeKeys.map(id => configApi.deleteJudge(id as string)))
          message.success(`成功删除 ${selectedJudgeKeys.length} 个 Judge 配置`)
          setSelectedJudgeKeys([])
          load()
        } catch (e: any) {
          message.error(e?.message || '批量删除失败')
        }
      },
    })
  }

  const platformCols = [
    { title: '名称', dataIndex: 'name' },
    { title: '类型', dataIndex: 'type', render: (v: string) => <Tag color="blue">{v}</Tag> },
    { title: 'Base URL', dataIndex: 'base_url' },
    { title: 'Org ID', dataIndex: 'org_id' },
    { title: '创建时间', dataIndex: 'created_at', render: (v: string) => v?.slice(0, 19) },
    {
      title: '操作', render: (_: any, r: any) => (
        <Popconfirm title="确认删除?" onConfirm={() => configApi.deletePlatform(r.id).then(load)}>
          <Button danger size="small" icon={<DeleteOutlined />} />
        </Popconfirm>
      )
    },
  ]

  const judgeCols = [
    { title: '名称', dataIndex: 'name' },
    { title: '模型', dataIndex: 'model', render: (v: string) => <Tag color="purple">{v}</Tag> },
    { title: 'Base URL', dataIndex: 'base_url' },
    { title: 'Embed 模型', dataIndex: 'embed_model', render: (v: string) => v ? <Tag color="cyan">{v}</Tag> : '-' },
    { title: '创建时间', dataIndex: 'created_at', render: (v: string) => v?.slice(0, 19) },
    {
      title: '操作', render: (_: any, r: any) => (
        <Popconfirm title="确认删除?" onConfirm={() => configApi.deleteJudge(r.id).then(load)}>
          <Button danger size="small" icon={<DeleteOutlined />} />
        </Popconfirm>
      )
    },
  ]

  return (
    <div>
      <h2>配置管理</h2>

      <div style={{ marginBottom: 32 }}>
        <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 12 }}>
          <h3 style={{ margin: 0 }}>平台连接配置</h3>
          <Space>
            {selectedPlatformKeys.length > 0 && (
              <Button danger icon={<DeleteOutlined />} onClick={handleBatchDeletePlatform}>
                批量删除 ({selectedPlatformKeys.length})
              </Button>
            )}
            <Button type="primary" icon={<PlusOutlined />} onClick={() => setPlatformModal(true)}>新增平台</Button>
          </Space>
        </div>
        <Table
          rowKey="id"
          dataSource={platforms}
          columns={platformCols}
          pagination={false}
          size="small"
          rowSelection={{
            selectedRowKeys: selectedPlatformKeys,
            onChange: setSelectedPlatformKeys,
          }}
        />
      </div>

      <div>
        <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 12 }}>
          <h3 style={{ margin: 0 }}>Judge 模型配置</h3>
          <Space>
            {selectedJudgeKeys.length > 0 && (
              <Button danger icon={<DeleteOutlined />} onClick={handleBatchDeleteJudge}>
                批量删除 ({selectedJudgeKeys.length})
              </Button>
            )}
            <Button type="primary" icon={<PlusOutlined />} onClick={() => setJudgeModal(true)}>新增 Judge</Button>
          </Space>
        </div>
        <Table
          rowKey="id"
          dataSource={judges}
          columns={judgeCols}
          pagination={false}
          size="small"
          rowSelection={{
            selectedRowKeys: selectedJudgeKeys,
            onChange: setSelectedJudgeKeys,
          }}
        />
      </div>

      <Modal title="新增平台配置" open={platformModal} onOk={savePlatform} onCancel={() => setPlatformModal(false)}>
        <Form form={form} layout="vertical">
          <Form.Item name="name" label="名称" rules={[{ required: true }]}><Input /></Form.Item>
          <Form.Item name="type" label="类型" initialValue="dagent">
            <Select><Option value="dagent">dagent</Option><Option value="custom">custom</Option></Select>
          </Form.Item>
          <Form.Item name="base_url" label="Base URL" rules={[{ required: true }]}><Input placeholder="http://localhost:8000" /></Form.Item>
          <Form.Item name="org_id" label="Org ID" rules={[{ required: true }]}><Input placeholder="a4d49699ba313815..." /></Form.Item>
          <Form.Item name="token" label="Token"><Input.Password /></Form.Item>
        </Form>
      </Modal>

      <Modal title="新增 Judge 配置" open={judgeModal} onOk={saveJudge} onCancel={() => setJudgeModal(false)} width={560}>
        <Form form={judgeForm} layout="vertical">
          <Form.Item name="name" label="名称" rules={[{ required: true }]}><Input /></Form.Item>
          <Form.Item name="base_url" label="Base URL (OpenAI 兼容)" rules={[{ required: true }]}><Input placeholder="https://api.openai.com/v1" /></Form.Item>
          <Form.Item name="api_key" label="API Key" rules={[{ required: true }]}><Input.Password /></Form.Item>
          <Form.Item name="model" label="模型" rules={[{ required: true }]}><Input placeholder="gpt-4o" /></Form.Item>
          <Form.Item name="embed_base_url" label="Embedding Base URL(可选,不填则复用上方 Base URL)"><Input placeholder="https://dashscope.aliyuncs.com/compatible-mode/v1" /></Form.Item>
          <Form.Item name="embed_api_key" label="Embedding API Key(可选,不填则复用上方 Key)"><Input.Password /></Form.Item>
          <Form.Item name="embed_model" label="Embedding 模型" initialValue="text-embedding-3-small"><Input placeholder="text-embedding-v2" /></Form.Item>
        </Form>
      </Modal>
    </div>
  )
}
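The batch-delete handlers in this page pass every delete call to `Promise.all`, which rejects as soon as one call fails even though other deletions may already have succeeded. A minimal sketch of one way to report partial failures is given here. It assumes the delete call returns a Promise, as `configApi.deletePlatform` appears to; `deleteMany`, `DeleteOutcome`, and `fakeDelete` are hypothetical names used only for illustration.

```typescript
interface DeleteOutcome {
  succeeded: string[]
  failed: string[]
}

// Run every delete and record which ids failed, instead of rejecting on the first error.
// `deleteOne` stands in for a call such as configApi.deletePlatform (assumed to return a Promise).
async function deleteMany(
  ids: string[],
  deleteOne: (id: string) => Promise<unknown>,
): Promise<DeleteOutcome> {
  const results = await Promise.all(
    ids.map(id =>
      deleteOne(id).then(
        () => ({ id, ok: true }),
        () => ({ id, ok: false }),
      ),
    ),
  )
  return {
    succeeded: results.filter(r => r.ok).map(r => r.id),
    failed: results.filter(r => !r.ok).map(r => r.id),
  }
}

// Fake API for demonstration only: deleting 'p2' fails.
const fakeDelete = (id: string): Promise<void> =>
  id === 'p2' ? Promise.reject(new Error('in use')) : Promise.resolve()

const outcomePromise = deleteMany(['p1', 'p2', 'p3'], fakeDelete)
```

A handler built on this could reload the table in every case and show a warning that lists the failed ids, so the list on screen stays consistent with the server after a partial failure.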
325
frontend/src/pages/Dataset/detail.tsx
Normal file
@ -0,0 +1,325 @@
import React, { useEffect, useState, useRef, useCallback } from 'react'
import { Table, Button, Modal, Form, Input, Select, Tag, Space, message, Descriptions, Progress, Checkbox, Tooltip, Alert } from 'antd'
import { PlusOutlined, ThunderboltOutlined, SearchOutlined, ReloadOutlined } from '@ant-design/icons'
import { useParams, useNavigate } from 'react-router-dom'
import { datasetApi, configApi } from '../../services/api'

const { Option } = Select

export default function DatasetDetail() {
  const { id } = useParams<{ id: string }>()
  const navigate = useNavigate()
  const [dataset, setDataset] = useState<any>(null)
  const [samples, setSamples] = useState<any[]>([])
  const [addModal, setAddModal] = useState(false)
  const [genModal, setGenModal] = useState(false)
  const [platforms, setPlatforms] = useState<any[]>([])
  const [judges, setJudges] = useState<any[]>([])
  const [form] = Form.useForm()
  const [genForm] = Form.useForm()

  // Chunk preview state
  const [chunks, setChunks] = useState<any[]>([])
  const [chunksLoading, setChunksLoading] = useState(false)
  const [selectedChunkIds, setSelectedChunkIds] = useState<string[]>([])

  // Generate progress state
  const [genTaskId, setGenTaskId] = useState<string | null>(null)
  const [genProgress, setGenProgress] = useState<{ progress: number; total: number; status: string } | null>(null)
  const pollRef = useRef<ReturnType<typeof setInterval> | null>(null)

  const load = async () => {
    const res = await datasetApi.get(id!) as any
    setDataset(res.data)
    setSamples(res.data?.samples || [])
  }

  useEffect(() => {
    load()
    configApi.listPlatforms().then((r: any) => setPlatforms(r.data || []))
    configApi.listJudges().then((r: any) => setJudges(r.data || []))
    return () => { if (pollRef.current) clearInterval(pollRef.current) }
  }, [id])

  // Fetch chunks when platform + hub_id are both set
  const fetchChunks = useCallback(async () => {
    const platformId = genForm.getFieldValue('platform_config_id')
    const hubId = genForm.getFieldValue('knowledge_hub_id')
    if (!platformId || !hubId) {
      message.warning('请先选择平台配置并填写知识库 ID')
      return
    }
    setChunksLoading(true)
    setChunks([])
    setSelectedChunkIds([])
    try {
      const res = await datasetApi.chunksPreview(platformId, hubId) as any
      const data = res.data || []
      setChunks(data)
      setSelectedChunkIds(data.map((c: any) => c.id))
      if (data.length === 0) {
        message.info('未找到切片,请检查知识库 ID 是否正确')
      }
    } catch {
      message.error('获取切片失败')
    } finally {
      setChunksLoading(false)
    }
  }, [genForm])

  // Poll generate progress
  const startPolling = useCallback((taskId: string) => {
    if (pollRef.current) clearInterval(pollRef.current)
    pollRef.current = setInterval(async () => {
      try {
        const res = await datasetApi.getGenerateProgress(taskId) as any
        const data = res.data
        setGenProgress({ progress: data.progress || 0, total: data.total || 0, status: data.status })
        if (data.status === 'done' || data.status === 'failed') {
          if (pollRef.current) clearInterval(pollRef.current)
          pollRef.current = null
          if (data.status === 'done') {
            message.success('样本生成完成')
            load()
          } else {
            message.error(`生成失败: ${data.error_message || '未知错误'}`)
          }
        }
      } catch {
        // ignore poll errors
      }
    }, 2000)
  }, [])

  const addSample = async () => {
    const vals = await form.validateFields()
    await datasetApi.addSample({
      ...vals,
      dataset_id: id,
      relevant_chunk_ids: vals.relevant_chunk_ids
        ? vals.relevant_chunk_ids.split('\n').map((s: string) => s.trim()).filter(Boolean)
        : [],
    })
    message.success('样本已添加')
    setAddModal(false)
    form.resetFields()
    load()
  }

  const startGenerate = async () => {
    const vals = await genForm.validateFields()
    if (selectedChunkIds.length === 0 && chunks.length > 0) {
      message.warning('请至少选择一个切片')
      return
    }
    // Derive file_id_list from selected chunks
    const fileIds = [...new Set(
      chunks.filter(c => selectedChunkIds.includes(c.id)).map((c: any) => c.file_id)
    )]
    const res = await datasetApi.generate({
      ...vals,
      dataset_id: id,
      file_id_list: fileIds.length > 0 ? fileIds : [vals.knowledge_hub_id],
      chunk_ids: selectedChunkIds,
    }) as any
    const taskId = res.data?.gen_task_id
    if (taskId) {
      setGenTaskId(taskId)
      setGenProgress({ progress: 0, total: selectedChunkIds.length || 0, status: 'pending' })
      startPolling(taskId)
      message.success('生成任务已启动')
    }
  }

  const closeGenModal = () => {
    setGenModal(false)
    setChunks([])
    setSelectedChunkIds([])
    setGenTaskId(null)
    setGenProgress(null)
    genForm.resetFields()
    if (pollRef.current) { clearInterval(pollRef.current); pollRef.current = null }
  }

  const columns = [
    { title: '问题', dataIndex: 'question', ellipsis: true, width: '30%' },
    { title: '参考答案', dataIndex: 'reference_answer', ellipsis: true, width: '30%' },
    { title: '知识库 ID', dataIndex: 'knowledge_hub_id', ellipsis: true },
    {
      title: '类型', dataIndex: 'metadata',
      render: (m: any) => m?.type ? <Tag>{m.type}</Tag> : '-'
    },
    {
      title: '难度', dataIndex: 'metadata',
      key: 'difficulty',
      render: (m: any) => {
        const color: any = { easy: 'green', medium: 'orange', hard: 'red' }
        return m?.difficulty ? <Tag color={color[m.difficulty]}>{m.difficulty}</Tag> : '-'
      }
    },
  ]

  const chunkColumns = [
    {
      title: () => (
        <Checkbox
          checked={selectedChunkIds.length === chunks.length && chunks.length > 0}
          indeterminate={selectedChunkIds.length > 0 && selectedChunkIds.length < chunks.length}
          onChange={e => setSelectedChunkIds(e.target.checked ? chunks.map(c => c.id) : [])}
        />
      ),
      dataIndex: 'id',
      width: 40,
      render: (cid: string) => (
        <Checkbox
          checked={selectedChunkIds.includes(cid)}
          onChange={e => {
            setSelectedChunkIds(prev =>
              e.target.checked ? [...prev, cid] : prev.filter(x => x !== cid)
            )
          }}
        />
      ),
    },
    {
      title: '切片内容',
      dataIndex: 'content',
      ellipsis: true,
      render: (text: string) => (
        <Tooltip title={text} placement="topLeft" overlayStyle={{ maxWidth: 500 }}>
          <span>{text?.slice(0, 120)}{text?.length > 120 ? '...' : ''}</span>
        </Tooltip>
      ),
    },
    { title: '文件 ID', dataIndex: 'file_id', ellipsis: true, width: 120 },
  ]

  const isGenerating = genProgress && (genProgress.status === 'pending' || genProgress.status === 'running')

  return (
    <div>
      <Button type="link" onClick={() => navigate('/dataset')} style={{ paddingLeft: 0 }}>← 返回列表</Button>
      {dataset && (
        <Descriptions title={dataset.name} bordered size="small" style={{ marginBottom: 16 }}>
          <Descriptions.Item label="描述">{dataset.description || '-'}</Descriptions.Item>
          <Descriptions.Item label="样本数">{dataset.sample_count}</Descriptions.Item>
          <Descriptions.Item label="创建时间">{dataset.created_at?.slice(0, 19)}</Descriptions.Item>
        </Descriptions>
      )}

      <div style={{ display: 'flex', justifyContent: 'flex-end', gap: 8, marginBottom: 12 }}>
        <Button icon={<ThunderboltOutlined />} onClick={() => setGenModal(true)}>LLM 自动生成</Button>
        <Button type="primary" icon={<PlusOutlined />} onClick={() => setAddModal(true)}>手动添加样本</Button>
      </div>

      <Table rowKey="id" dataSource={samples} columns={columns} size="small" />

      {/* Add sample modal */}
      <Modal title="添加样本" open={addModal} onOk={addSample} onCancel={() => setAddModal(false)} width={600}>
        <Form form={form} layout="vertical">
          <Form.Item name="question" label="问题" rules={[{ required: true }]}>
            <Input.TextArea rows={2} />
          </Form.Item>
          <Form.Item name="reference_answer" label="参考答案" rules={[{ required: true }]}>
            <Input.TextArea rows={3} />
          </Form.Item>
          <Form.Item name="knowledge_hub_id" label="知识库 ID" rules={[{ required: true }]}>
            <Input />
          </Form.Item>
          <Form.Item name="relevant_chunk_ids" label="相关 Chunk IDs(每行一个)">
            <Input.TextArea rows={3} placeholder="chunk_id_1 chunk_id_2" />
          </Form.Item>
        </Form>
      </Modal>

      {/* Generate modal */}
      <Modal
        title="LLM 自动生成样本"
        open={genModal}
        onCancel={closeGenModal}
        width={800}
        footer={
          isGenerating ? null : [
            <Button key="cancel" onClick={closeGenModal}>取消</Button>,
            <Button key="ok" type="primary" onClick={startGenerate}
              disabled={chunks.length > 0 && selectedChunkIds.length === 0}>
              开始生成
            </Button>,
          ]
        }
      >
        {/* Progress bar */}
        {genProgress && (
          <div style={{ marginBottom: 16 }}>
            <Alert
              type={genProgress.status === 'done' ? 'success' : genProgress.status === 'failed' ? 'error' : 'info'}
              message={
                genProgress.status === 'done' ? '生成完成' :
                genProgress.status === 'failed' ? '生成失败' :
                `正在生成中... (${genProgress.progress}/${genProgress.total})`
              }
              showIcon
            />
            <Progress
              percent={genProgress.total > 0 ? Math.round(genProgress.progress / genProgress.total * 100) : 0}
              status={genProgress.status === 'failed' ? 'exception' : genProgress.status === 'done' ? 'success' : 'active'}
              style={{ marginTop: 8 }}
            />
            {genProgress.status === 'done' && (
              <Button type="link" onClick={() => { closeGenModal(); load() }}>关闭并刷新</Button>
            )}
          </div>
        )}

        {!isGenerating && (
          <Form form={genForm} layout="vertical">
            <Form.Item name="platform_config_id" label="平台配置" rules={[{ required: true }]}>
              <Select placeholder="选择平台">
                {platforms.map((p: any) => <Option key={p.id} value={p.id}>{p.name}</Option>)}
              </Select>
            </Form.Item>
            <Form.Item name="judge_config_id" label="Judge 模型" rules={[{ required: true }]}>
              <Select placeholder="选择 Judge">
                {judges.map((j: any) => <Option key={j.id} value={j.id}>{j.name} ({j.model})</Option>)}
              </Select>
            </Form.Item>
            {/* The outer item only supplies the label; the inner noStyle item binds the field,
                so the field name is registered once rather than on both items. */}
            <Form.Item label="知识库 ID" required>
              <Space.Compact style={{ width: '100%' }}>
                <Form.Item name="knowledge_hub_id" noStyle rules={[{ required: true }]}>
                  <Input style={{ flex: 1 }} placeholder="输入知识库 ID" />
                </Form.Item>
                <Button icon={<SearchOutlined />} loading={chunksLoading} onClick={fetchChunks}>
                  加载切片
                </Button>
              </Space.Compact>
            </Form.Item>

            {/* Chunk preview table */}
            {chunks.length > 0 && (
              <div style={{ marginBottom: 16 }}>
                <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 8 }}>
                  <span>共 {chunks.length} 个切片,已选 {selectedChunkIds.length} 个</span>
                  <Button size="small" icon={<ReloadOutlined />} onClick={fetchChunks}>刷新</Button>
                </div>
                <Table
                  rowKey="id"
                  dataSource={chunks}
                  columns={chunkColumns}
                  size="small"
                  pagination={{ pageSize: 5, size: 'small' }}
                  scroll={{ y: 240 }}
                />
              </div>
            )}

            <Form.Item name="questions_per_chunk" label="每个切片生成问题数" initialValue={2}>
              <Select>
                {[1, 2, 3, 4].map(n => <Option key={n} value={n}>{n}</Option>)}
              </Select>
            </Form.Item>
          </Form>
        )}
      </Modal>
    </div>
  )
}
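Two small computations in this page are done inline: splitting the "one chunk ID per line" textarea into an array, and turning `progress`/`total` into a percentage for the progress bar. A minimal sketch of both as standalone helpers follows. `parseChunkIds` and `progressPercent` are hypothetical names, and the de-duplication and clamping to 0..100 are additions that the original code does not perform.

```typescript
// Split a textarea value into trimmed, non-empty chunk IDs.
// Also accepts Windows line endings and drops repeated IDs (both are assumptions
// beyond what the page does inline).
function parseChunkIds(raw: string | undefined): string[] {
  if (!raw) return []
  const ids = raw
    .split(/\r?\n/)
    .map(s => s.trim())
    .filter(s => s.length > 0)
  return ids.filter((id, index) => ids.indexOf(id) === index)
}

// Convert a progress counter into an integer percentage.
// Returns 0 when the total is unknown and clamps the result to 0..100.
function progressPercent(progress: number, total: number): number {
  if (total <= 0) return 0
  const pct = Math.round((progress / total) * 100)
  return Math.min(100, Math.max(0, pct))
}

const parsedIds = parseChunkIds(' a \n\nb\r\na ')
const oneThird = progressPercent(1, 3)
```

Pulling these out of the component would make them easy to unit-test without rendering the page. Clamping guards against a backend that reports more completed items than the initial total.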
123
frontend/src/pages/Dataset/index.tsx
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
import React, { useEffect, useState } from 'react'
|
||||||
|
import { Table, Button, Modal, Form, Input, Upload, Popconfirm, message, Tag, Space, Tooltip } from 'antd'
|
||||||
|
import { PlusOutlined, UploadOutlined, DeleteOutlined, EyeOutlined } from '@ant-design/icons'
|
||||||
|
import { useNavigate } from 'react-router-dom'
|
||||||
|
import { datasetApi } from '../../services/api'
|
||||||
|
|
||||||
|
export default function Dataset() {
|
||||||
|
const [datasets, setDatasets] = useState<any[]>([])
|
||||||
|
const [createModal, setCreateModal] = useState(false)
|
||||||
|
const [form] = Form.useForm()
|
||||||
|
const navigate = useNavigate()
|
||||||
|
const [selectedRowKeys, setSelectedRowKeys] = useState<React.Key[]>([])
|
||||||
|
|
||||||
|
const load = async () => {
|
||||||
|
const res = await datasetApi.list() as any
|
||||||
|
setDatasets(res.data || [])
|
||||||
|
}
|
||||||
|
|
||||||
|
useEffect(() => { load() }, [])
|
||||||
|
|
||||||
|
const create = async () => {
|
||||||
|
const vals = await form.validateFields()
|
||||||
|
await datasetApi.create(vals)
|
||||||
|
message.success('数据集已创建')
|
||||||
|
setCreateModal(false)
|
||||||
|
form.resetFields()
|
||||||
|
load()
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleImport = async (file: File) => {
|
||||||
|
try {
|
||||||
|
await datasetApi.import(file)
|
||||||
|
message.success('导入成功')
|
||||||
|
load()
|
||||||
|
} catch {
|
||||||
|
message.error('导入失败')
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 批量删除 ────────────────────────────────────────────────────────────────
|
||||||
|
const handleBatchDelete = async () => {
|
||||||
|
if (selectedRowKeys.length === 0) {
|
||||||
|
message.warning('请先选择要删除的数据集')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Modal.confirm({
|
||||||
|
title: `确认删除选中的 ${selectedRowKeys.length} 个数据集?`,
|
||||||
|
content: '删除后将无法恢复,相关样本也会被删除。',
|
||||||
|
okText: '确认删除',
|
||||||
|
okType: 'danger',
|
||||||
|
cancelText: '取消',
|
||||||
|
      async onOk() {
        try {
          await Promise.all(selectedRowKeys.map(id => datasetApi.delete(id as string)))
          message.success(`成功删除 ${selectedRowKeys.length} 个数据集`)
          setSelectedRowKeys([])
        } catch (e: any) {
          message.error(e?.message || '批量删除失败')
        } finally {
          // Reload even after a partial failure so the table reflects what was actually deleted.
          load()
        }
      },
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const columns = [
|
||||||
|
{ title: '名称', dataIndex: 'name', render: (v: string, r: any) => (
|
||||||
|
<a onClick={() => navigate(`/dataset/${r.id}`)}>{v}</a>
|
||||||
|
)},
|
||||||
|
{ title: '描述', dataIndex: 'description', ellipsis: true },
|
||||||
|
{ title: '样本数', dataIndex: 'sample_count', render: (v: number) => <Tag color="blue">{v}</Tag> },
|
||||||
|
{ title: '创建时间', dataIndex: 'created_at', render: (v: string) => v?.slice(0, 19) },
|
||||||
|
{
|
||||||
|
title: '操作',
|
||||||
|
render: (_: any, r: any) => (
|
||||||
|
<Space>
|
||||||
|
<Tooltip title="查看样本">
|
||||||
|
<Button size="small" icon={<EyeOutlined />} onClick={() => navigate(`/dataset/${r.id}`)} />
|
||||||
|
</Tooltip>
|
||||||
|
<Popconfirm title="确认删除该数据集及所有样本?" onConfirm={() => datasetApi.delete(r.id).then(load)}>
|
||||||
|
<Button danger size="small" icon={<DeleteOutlined />} />
|
||||||
|
</Popconfirm>
|
||||||
|
</Space>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 16 }}>
|
||||||
|
<h2 style={{ margin: 0 }}>测试集管理</h2>
|
||||||
|
<Space>
|
||||||
|
{selectedRowKeys.length > 0 && (
|
||||||
|
<Button danger icon={<DeleteOutlined />} onClick={handleBatchDelete}>
|
||||||
|
批量删除 ({selectedRowKeys.length})
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
<Upload beforeUpload={handleImport} showUploadList={false} accept=".json">
|
||||||
|
<Button icon={<UploadOutlined />}>导入 JSON</Button>
|
||||||
|
</Upload>
|
||||||
|
<Button type="primary" icon={<PlusOutlined />} onClick={() => setCreateModal(true)}>新建数据集</Button>
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Table
|
||||||
|
rowKey="id"
|
||||||
|
dataSource={datasets}
|
||||||
|
columns={columns}
|
||||||
|
rowSelection={{
|
||||||
|
selectedRowKeys,
|
||||||
|
onChange: setSelectedRowKeys,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Modal title="新建数据集" open={createModal} onOk={create} onCancel={() => setCreateModal(false)}>
|
||||||
|
<Form form={form} layout="vertical">
|
||||||
|
<Form.Item name="name" label="名称" rules={[{ required: true }]}><Input /></Form.Item>
|
||||||
|
<Form.Item name="description" label="描述"><Input.TextArea rows={3} /></Form.Item>
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
899
frontend/src/pages/MultiHop/GenTab.tsx
Normal file
@ -0,0 +1,899 @@
|
|||||||
|
import React, { useEffect, useRef, useState } from 'react'
|
||||||
|
import {
|
||||||
|
Table, Button, Modal, Form, Input, InputNumber, Select, Upload,
|
||||||
|
Tag, Progress, Drawer, Space, Tooltip, Typography, message, Popconfirm,
|
||||||
|
Segmented, Empty, Pagination, Spin, Card, Row, Col, Statistic, Radio, Switch,
|
||||||
|
} from 'antd'
|
||||||
|
import {
|
||||||
|
PlusOutlined, DeleteOutlined, ReloadOutlined, UploadOutlined,
|
||||||
|
SyncOutlined, CheckCircleOutlined, CloseCircleOutlined,
|
||||||
|
WarningOutlined, DownloadOutlined, CheckOutlined, CloseOutlined,
|
||||||
|
EditOutlined, ThunderboltOutlined, DatabaseOutlined, AimOutlined, SearchOutlined,
|
||||||
|
} from '@ant-design/icons'
|
||||||
|
import { multiHopGenApi, multiHopApi, configApi, promptTemplateApi } from '../../services/api'
|
||||||
|
import DagentFileSelector from '../../components/DagentFileSelector'
|
||||||
|
|
||||||
|
const { Text } = Typography
|
||||||
|
|
||||||
|
function StatusTag({ status }: { status: string }) {
  const map: Record<string, { color: string; icon?: React.ReactNode; label: string }> = {
    pending: { color: 'default', label: '等待中' },
    running: { color: 'processing', icon: <SyncOutlined spin />, label: '生成中' },
    done: { color: 'success', label: '完成' },
    failed: { color: 'error', label: '失败' },
  }
  const cfg = map[status] || { color: 'default', label: status }
  return <Tag color={cfg.color} icon={cfg.icon}>{cfg.label}</Tag>
}

function QStatusTag({ status }: { status: string }) {
  const map: Record<string, { color: string; icon: React.ReactNode; label: string }> = {
    pending: { color: 'default', icon: <WarningOutlined />, label: '待审核' },
    approved: { color: 'success', icon: <CheckCircleOutlined />, label: '已通过' },
    rejected: { color: 'error', icon: <CloseCircleOutlined />, label: '已拒绝' },
  }
  const cfg = map[status] || { color: 'default', icon: null, label: status }
  return <Tag color={cfg.color} icon={cfg.icon}>{cfg.label}</Tag>
}

function TypeTag({ type }: { type: string }) {
  const map: Record<string, string> = {
    comparison: '比较型',
    reasoning: '推理型',
    aggregation: '聚合型',
  }
  return <Tag>{map[type] || type}</Tag>
}
|
||||||
|
|
||||||
|
function EditModal({ question, onOk, onCancel }: { question: any; onOk: (v: any) => void; onCancel: () => void }) {
|
||||||
|
const [form] = Form.useForm()
|
||||||
|
useEffect(() => {
|
||||||
|
form.setFieldsValue({ question: question.question, answer: question.answer, type: question.type })
|
||||||
|
}, [question])
|
||||||
|
return (
|
||||||
|
<Modal title="编辑多跳问题" open onOk={() => form.validateFields().then(onOk)} onCancel={onCancel} width={640}>
|
||||||
|
<Form form={form} layout="vertical">
|
||||||
|
<Form.Item name="type" label="类型">
|
||||||
|
<Select options={[
|
||||||
|
{ label: '推理型', value: 'reasoning' },
|
||||||
|
{ label: '比较型', value: 'comparison' },
|
||||||
|
{ label: '聚合型', value: 'aggregation' },
|
||||||
|
]} />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="question" label="问题" rules={[{ required: true }]}>
|
||||||
|
<Input.TextArea rows={3} />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="answer" label="参考答案" rules={[{ required: true }]}>
|
||||||
|
<Input.TextArea rows={4} />
|
||||||
|
</Form.Item>
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function GenTab() {
|
||||||
|
const [tasks, setTasks] = useState<any[]>([])
|
||||||
|
const [loading, setLoading] = useState(false)
|
||||||
|
const [createModal, setCreateModal] = useState(false)
|
||||||
|
const [form] = Form.useForm()
|
||||||
|
const [fileList, setFileList] = useState<any[]>([])
|
||||||
|
const [submitting, setSubmitting] = useState(false)
|
||||||
|
const [judgeOptions, setJudgeOptions] = useState<{ label: string; value: string }[]>([])
|
||||||
|
const pollingRef = useRef<ReturnType<typeof setInterval> | null>(null)
|
||||||
|
|
||||||
|
// Dagent data source
|
||||||
|
const [dataSource, setDataSource] = useState<'file' | 'dagent'>('file')
|
||||||
|
const [dagentStats, setDagentStats] = useState<any>(null)
|
||||||
|
const [loadingStats, setLoadingStats] = useState(false)
|
||||||
|
|
||||||
|
const [reviewDrawer, setReviewDrawer] = useState<string | null>(null)
|
||||||
|
const [reviewTask, setReviewTask] = useState<any>(null)
|
||||||
|
const [questions, setQuestions] = useState<any[]>([])
|
||||||
|
const [questionTotal, setQuestionTotal] = useState(0)
|
||||||
|
const [questionPage, setQuestionPage] = useState(1)
|
||||||
|
const [questionLoading, setQuestionLoading] = useState(false)
|
||||||
|
const [statusFilter, setStatusFilter] = useState('all')
|
||||||
|
const [editingQ, setEditingQ] = useState<any>(null)
|
||||||
|
const [detailQ, setDetailQ] = useState<any>(null)
|
||||||
|
const PAGE_SIZE = 30
|
||||||
|
|
||||||
|
// Prompt templates
|
||||||
|
const [templates, setTemplates] = useState<any[]>([])
|
||||||
|
const [templateDrawer, setTemplateDrawer] = useState(false)
|
||||||
|
const [editingTemplate, setEditingTemplate] = useState<any>(null) // null = create new, object = edit existing
|
||||||
|
const [templateForm] = Form.useForm()
|
||||||
|
const [templateSubmitting, setTemplateSubmitting] = useState(false)
|
||||||
|
const [selectedTemplateContent, setSelectedTemplateContent] = useState<string | null>(null)
|
||||||
|
|
||||||
|
const loadTasks = async () => {
|
||||||
|
setLoading(true)
|
||||||
|
try {
|
||||||
|
const res = await multiHopGenApi.listTasks() as any
|
||||||
|
setTasks(res.data || [])
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const loadTemplates = async () => {
|
||||||
|
try {
|
||||||
|
const res = await promptTemplateApi.list() as any
|
||||||
|
setTemplates(res.data || [])
|
||||||
|
} catch { /* ignore */ }
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleTemplateSave = async () => {
|
||||||
|
const vals = await templateForm.validateFields()
|
||||||
|
setTemplateSubmitting(true)
|
||||||
|
try {
|
||||||
|
if (editingTemplate?.id) {
|
||||||
|
await promptTemplateApi.update(editingTemplate.id, vals)
|
||||||
|
message.success('已更新')
|
||||||
|
} else {
|
||||||
|
await promptTemplateApi.create(vals)
|
||||||
|
message.success('已创建')
|
||||||
|
}
|
||||||
|
templateForm.resetFields()
|
||||||
|
setEditingTemplate(null)
|
||||||
|
loadTemplates()
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(e?.response?.data?.detail || e?.message || '保存失败')
|
||||||
|
} finally {
|
||||||
|
setTemplateSubmitting(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
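  const handleTemplateDelete = async (id: string) => {
    try {
      await promptTemplateApi.delete(id)
      message.success('已删除')
      loadTemplates()
    } catch (e: any) {
      // Surface the failure instead of leaving an unhandled promise rejection.
      message.error(e?.response?.data?.detail || e?.message || '删除失败')
    }
  }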
|
||||||
|
|
||||||
|
const handleImportDefault = async () => {
|
||||||
|
try {
|
||||||
|
const res = await promptTemplateApi.getDefault() as any
|
||||||
|
templateForm.setFieldValue('content', res.data?.content || '')
|
||||||
|
} catch { /* ignore */ }
|
||||||
|
}
|
||||||
|
|
||||||
|
const loadJudgeOptions = async () => {
|
||||||
|
try {
|
||||||
|
const res = await configApi.listJudges() as any
|
||||||
|
setJudgeOptions((res.data || []).map((j: any) => ({
|
||||||
|
label: `${j.name} (${j.model})`,
|
||||||
|
value: j.id,
|
||||||
|
})))
|
||||||
|
} catch { /* ignore */ }
|
||||||
|
}
|
||||||
|
|
||||||
|
  useEffect(() => {
    loadTasks()
    loadJudgeOptions()
    loadTemplates()
    // Poll every 3 s, reloading only while some task is still pending or running.
    // The functional update is used solely to read the latest task list from inside
    // the interval callback; it returns `prev` unchanged.
    pollingRef.current = setInterval(() => {
      setTasks(prev => {
        const hasRunning = prev.some(t => t.status === 'running' || t.status === 'pending')
        if (hasRunning) loadTasks()
        return prev
      })
    }, 3000)
    // Stop polling when the component unmounts.
    return () => { if (pollingRef.current) clearInterval(pollingRef.current) }
  }, [])
|
||||||
|
|
||||||
|
const loadDagentStats = async (orgId: string) => {
|
||||||
|
if (!orgId || orgId.length < 8) return
|
||||||
|
const envUrl = form.getFieldValue('env_url') || ''
|
||||||
|
setLoadingStats(true)
|
||||||
|
try {
|
||||||
|
const res = await multiHopGenApi.getDagentStats(orgId, envUrl) as any
|
||||||
|
setDagentStats(res.data || null)
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(`加载统计信息失败: ${e.message || '未知错误'}`)
|
||||||
|
setDagentStats(null)
|
||||||
|
} finally {
|
||||||
|
setLoadingStats(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleCreate = async () => {
|
||||||
|
const vals = await form.validateFields()
|
||||||
|
setSubmitting(true)
|
||||||
|
try {
|
||||||
|
const fd = new FormData()
|
||||||
|
fd.append('prompt_template_id', vals.prompt_template_id || '')
|
||||||
|
if (dataSource === 'file') {
|
||||||
|
if (!fileList.length) { message.error('请上传知识库 MD 文件'); return }
|
||||||
|
fd.append('file', fileList[0].originFileObj)
|
||||||
|
fd.append('name', vals.name || fileList[0].name)
|
||||||
|
fd.append('judge_config_id', vals.judge_config_id)
|
||||||
|
fd.append('hops_per_question', String(vals.hops_per_question ?? 2))
|
||||||
|
fd.append('questions_per_group', String(vals.questions_per_group ?? 3))
|
||||||
|
fd.append('quality_threshold', String(vals.quality_threshold ?? 0.6))
|
||||||
|
await multiHopGenApi.createTask(fd)
|
||||||
|
} else {
|
||||||
|
fd.append('org_id', vals.org_id)
|
||||||
|
fd.append('env_url', vals.env_url || '')
|
||||||
|
fd.append('name', vals.name || `Dagent多跳(${vals.org_id.slice(0, 8)}...)`)
|
||||||
|
fd.append('judge_config_id', vals.judge_config_id)
|
||||||
|
fd.append('file_ids', vals.file_ids || '')
|
||||||
|
fd.append('hops_per_question', String(vals.hops_per_question ?? 2))
|
||||||
|
fd.append('questions_per_group', String(vals.questions_per_group ?? 3))
|
||||||
|
fd.append('quality_threshold', String(vals.quality_threshold ?? 0.6))
|
||||||
|
await multiHopGenApi.createTaskFromDagent(fd)
|
||||||
|
}
|
||||||
|
message.success('生成任务已创建')
|
||||||
|
setCreateModal(false)
|
||||||
|
form.resetFields()
|
||||||
|
setSelectedTemplateContent(null)
|
||||||
|
setFileList([])
|
||||||
|
setDagentStats(null)
|
||||||
|
setDataSource('file')
|
||||||
|
loadTasks()
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(e?.response?.data?.detail || e?.message || '创建失败')
|
||||||
|
} finally {
|
||||||
|
setSubmitting(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleDelete = async (id: string) => {
|
||||||
|
try {
|
||||||
|
await multiHopGenApi.deleteTask(id)
|
||||||
|
message.success('已删除')
|
||||||
|
loadTasks()
|
||||||
|
if (reviewDrawer === id) setReviewDrawer(null)
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(e?.message || '删除失败')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const openReview = async (taskId: string) => {
|
||||||
|
setReviewDrawer(taskId)
|
||||||
|
setStatusFilter('all')
|
||||||
|
setQuestions([])
|
||||||
|
setQuestionPage(1)
|
||||||
|
try {
|
||||||
|
const res = await multiHopGenApi.getTask(taskId) as any
|
||||||
|
setReviewTask(res.data)
|
||||||
|
} catch {
|
||||||
|
message.error('加载失败')
|
||||||
|
setReviewDrawer(null)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const loadQuestions = async (taskId: string, page = 1) => {
|
||||||
|
setQuestionLoading(true)
|
||||||
|
try {
|
||||||
|
const res = await multiHopGenApi.listQuestions(taskId, {
|
||||||
|
status: statusFilter === 'all' ? undefined : statusFilter,
|
||||||
|
page,
|
||||||
|
page_size: PAGE_SIZE,
|
||||||
|
}) as any
|
||||||
|
setQuestions(res.data?.items || [])
|
||||||
|
setQuestionTotal(res.data?.total || 0)
|
||||||
|
setQuestionPage(page)
|
||||||
|
} finally {
|
||||||
|
setQuestionLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (reviewDrawer) loadQuestions(reviewDrawer, 1)
|
||||||
|
}, [reviewDrawer, statusFilter])
|
||||||
|
|
||||||
|
const refreshReview = async () => {
|
||||||
|
if (!reviewDrawer) return
|
||||||
|
const res = await multiHopGenApi.getTask(reviewDrawer) as any
|
||||||
|
setReviewTask(res.data)
|
||||||
|
loadQuestions(reviewDrawer, questionPage)
|
||||||
|
}
|
||||||
|
|
||||||
|
  const handleApprove = async (id: string) => {
    try {
      await multiHopGenApi.approveQuestion(id)
      refreshReview()
    } catch (e: any) {
      message.error(e?.response?.data?.detail || e?.message || '操作失败')
    }
  }

  const handleReject = async (id: string) => {
    try {
      await multiHopGenApi.rejectQuestion(id)
      refreshReview()
    } catch (e: any) {
      message.error(e?.response?.data?.detail || e?.message || '操作失败')
    }
  }
|
||||||
|
|
||||||
|
const handleEdit = async (vals: any) => {
|
||||||
|
if (!editingQ) return
|
||||||
|
await multiHopGenApi.editQuestion(editingQ.id, vals)
|
||||||
|
setEditingQ(null)
|
||||||
|
message.success('已保存并通过')
|
||||||
|
refreshReview()
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleBatchApprove = async (minQuality: number) => {
|
||||||
|
if (!reviewDrawer) return
|
||||||
|
await multiHopGenApi.batchApprove(reviewDrawer, minQuality)
|
||||||
|
message.success('批量通过完成')
|
||||||
|
refreshReview()
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleExport = () => {
|
||||||
|
if (!reviewDrawer) return
|
||||||
|
const url = multiHopGenApi.exportMd(reviewDrawer)
|
||||||
|
const link = document.createElement('a')
|
||||||
|
link.href = url
|
||||||
|
link.download = `multi_hop_${reviewDrawer.slice(0, 8)}.md`
|
||||||
|
document.body.appendChild(link)
|
||||||
|
link.click()
|
||||||
|
document.body.removeChild(link)
|
||||||
|
}
|
||||||
|
|
||||||
|
const [testModal, setTestModal] = useState(false)
|
||||||
|
const [testForm] = Form.useForm()
|
||||||
|
const [testSubmitting, setTestSubmitting] = useState(false)
|
||||||
|
const [testAgentOptions, setTestAgentOptions] = useState<{ label: string; value: string }[]>([])
|
||||||
|
const [testAgentLoading, setTestAgentLoading] = useState(false)
|
||||||
|
|
||||||
|
const loadTestAgents = async () => {
|
||||||
|
const envUrl = testForm.getFieldValue('env_url')
|
||||||
|
const orgId = testForm.getFieldValue('org_id')
|
||||||
|
const dUserId = testForm.getFieldValue('d_user_id') || 'test'
|
||||||
|
if (!envUrl || !orgId) { message.warning('请先填写环境地址和 Org ID'); return }
|
||||||
|
setTestAgentLoading(true)
|
||||||
|
try {
|
||||||
|
const res = await multiHopApi.listDagentAgents(envUrl, orgId, dUserId) as any
|
||||||
|
const agents = res.data || []
|
||||||
|
if (!agents.length) { message.warning('未找到可用的 Agent'); return }
|
||||||
|
setTestAgentOptions(agents.map((a: any) => ({ label: `${a.name || a.id} (${String(a.id).slice(0, 8)}...)`, value: a.id })))
|
||||||
|
message.success(`找到 ${agents.length} 个 Agent`)
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(e?.response?.data?.detail || e?.message || '拉取 Agent 列表失败')
|
||||||
|
} finally {
|
||||||
|
setTestAgentLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleCreateTest = async () => {
|
||||||
|
if (!reviewDrawer) return
|
||||||
|
const vals = await testForm.validateFields()
|
||||||
|
setTestSubmitting(true)
|
||||||
|
try {
|
||||||
|
const res = await multiHopGenApi.createTest(reviewDrawer, {
|
||||||
|
env_url: vals.env_url,
|
||||||
|
org_id: vals.org_id,
|
||||||
|
agent_id: vals.agent_id,
|
||||||
|
llm_type: vals.llm_type || 'deepseek_v3',
|
||||||
|
d_user_id: vals.d_user_id || 'test',
|
||||||
|
top_k: vals.top_k ?? 10,
|
||||||
|
concurrency: vals.concurrency ?? 5,
|
||||||
|
name: vals.name || '',
|
||||||
|
}) as any
|
||||||
|
message.success(`召回测试已创建,共 ${res.data?.question_count ?? 0} 题,请切换到「召回测试」Tab 查看进度`)
|
||||||
|
setTestModal(false)
|
||||||
|
testForm.resetFields()
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(e?.response?.data?.detail || e?.message || '创建失败')
|
||||||
|
} finally {
|
||||||
|
setTestSubmitting(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const columns = [
|
||||||
|
{ title: '任务名称', dataIndex: 'name', ellipsis: true },
|
||||||
|
{ title: '状态', dataIndex: 'status', width: 100, render: (v: string) => <StatusTag status={v} /> },
|
||||||
|
{
|
||||||
|
title: '进度', width: 160,
|
||||||
|
render: (_: any, r: any) => r.status === 'running'
|
||||||
|
? <Progress percent={r.total ? Math.round(r.progress / r.total * 100) : 0} size="small" />
|
||||||
|
: r.status === 'done'
|
||||||
|
? <Text type="success">{r.total} 组完成</Text>
|
||||||
|
: r.status === 'failed'
|
||||||
|
? <Tooltip title={r.error_message}><Text type="danger">失败</Text></Tooltip>
|
||||||
|
: '-',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '已通过', width: 90,
|
||||||
|
render: (_: any, r: any) => r.status === 'done'
|
||||||
|
? <Text type="success">{r.approved ?? 0} 题</Text>
|
||||||
|
: '-',
|
||||||
|
},
|
||||||
|
{ title: '创建时间', dataIndex: 'created_at', width: 160, render: (v: string) => v?.slice(0, 19) },
|
||||||
|
{
|
||||||
|
title: '操作', width: 160,
|
||||||
|
render: (_: any, r: any) => (
|
||||||
|
<Space>
|
||||||
|
<Button size="small" disabled={r.status !== 'done'} onClick={() => openReview(r.id)}>审核</Button>
|
||||||
|
<Button size="small" danger onClick={() => handleDelete(r.id)}>删除</Button>
|
||||||
|
</Space>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
const statusOptions = [
|
||||||
|
{ label: '全部', value: 'all' },
|
||||||
|
{ label: '待审核', value: 'pending' },
|
||||||
|
{ label: '已通过', value: 'approved' },
|
||||||
|
{ label: '已拒绝', value: 'rejected' },
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'flex-end', marginBottom: 16 }}>
|
||||||
|
<Space>
|
||||||
|
<Button icon={<ReloadOutlined />} onClick={loadTasks}>刷新</Button>
|
||||||
|
<Button icon={<EditOutlined />} onClick={() => { setTemplateDrawer(true); loadTemplates() }}>提示词模板</Button>
|
||||||
|
<Button type="primary" icon={<PlusOutlined />} onClick={() => setCreateModal(true)}>新建生成任务</Button>
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Table
|
||||||
|
rowKey="id"
|
||||||
|
dataSource={tasks}
|
||||||
|
columns={columns}
|
||||||
|
loading={loading}
|
||||||
|
size="small"
|
||||||
|
pagination={{ pageSize: 20 }}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* Create-task modal */}
|
||||||
|
<Modal
|
||||||
|
title="新建多跳 Case 生成任务"
|
||||||
|
open={createModal}
|
||||||
|
onOk={handleCreate}
|
||||||
|
onCancel={() => {
|
||||||
|
setCreateModal(false); form.resetFields(); setFileList([])
|
||||||
|
setDagentStats(null); setDataSource('file'); setSelectedTemplateContent(null)
|
||||||
|
}}
|
||||||
|
confirmLoading={submitting}
|
||||||
|
width={560}
|
||||||
|
>
|
||||||
|
<Form form={form} layout="vertical"
|
||||||
|
initialValues={{ hops_per_question: 2, questions_per_group: 3, quality_threshold: 0.6 }}>
|
||||||
|
|
||||||
|
{/* Data source toggle */}
|
||||||
|
<Form.Item label="数据来源">
|
||||||
|
<Radio.Group value={dataSource} onChange={e => { setDataSource(e.target.value); setDagentStats(null) }}>
|
||||||
|
<Radio value="file"><UploadOutlined /> 上传 MD 文件</Radio>
|
||||||
|
<Radio value="dagent"><DatabaseOutlined /> 从 Dagent 知识库导入</Radio>
|
||||||
|
</Radio.Group>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item name="name" label="任务名称">
|
||||||
|
<Input placeholder={dataSource === 'file' ? '可选,默认使用文件名' : '可选,默认使用组织 ID'} />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="judge_config_id" label="LLM 配置" rules={[{ required: true, message: '请选择 LLM 配置' }]}>
|
||||||
|
<Select options={judgeOptions} placeholder="请选择用于生成的 LLM" />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item
|
||||||
|
name="prompt_template_id"
|
||||||
|
label={
|
||||||
|
<Space size={4}>
|
||||||
|
<span>提示词模板</span>
|
||||||
|
<Button
|
||||||
|
type="link"
|
||||||
|
size="small"
|
||||||
|
style={{ padding: 0, height: 'auto', fontSize: 12 }}
|
||||||
|
onClick={() => { setTemplateDrawer(true); loadTemplates() }}
|
||||||
|
>
|
||||||
|
管理模板
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<Select
|
||||||
|
placeholder="使用默认(不选则使用内置提示词)"
|
||||||
|
allowClear
|
||||||
|
options={[
|
||||||
|
...templates.map(t => ({ label: t.name, value: t.id })),
|
||||||
|
]}
|
||||||
|
onChange={(val) => {
|
||||||
|
const tpl = templates.find(t => t.id === val)
|
||||||
|
setSelectedTemplateContent(tpl?.content || null)
|
||||||
|
}}
|
||||||
|
onClear={() => setSelectedTemplateContent(null)}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
{selectedTemplateContent && (
|
||||||
|
<div style={{
|
||||||
|
marginTop: -12, marginBottom: 12, padding: '8px 10px',
|
||||||
|
background: '#f6f8fa', border: '1px solid #e8e8e8', borderRadius: 4,
|
||||||
|
fontSize: 12, color: '#555', whiteSpace: 'pre-wrap', maxHeight: 120, overflowY: 'auto',
|
||||||
|
}}>
|
||||||
|
{selectedTemplateContent}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
<Row gutter={12}>
|
||||||
|
<Col span={8}>
|
||||||
|
<Form.Item name="hops_per_question" label="每题 Hop 数" tooltip="每个问题需要跨越的章节/文件数,建议 2-3">
|
||||||
|
<InputNumber min={2} max={5} style={{ width: '100%' }} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={8}>
|
||||||
|
<Form.Item name="questions_per_group" label="每组问题数">
|
||||||
|
<InputNumber min={1} max={10} style={{ width: '100%' }} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={8}>
|
||||||
|
<Form.Item name="quality_threshold" label="自动通过阈值">
|
||||||
|
<InputNumber min={0} max={1} step={0.1} style={{ width: '100%' }} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
|
||||||
|
{dataSource === 'file' ? (
|
||||||
|
<Form.Item label="知识库 MD 文件" required>
|
||||||
|
<Upload
|
||||||
|
accept=".md"
|
||||||
|
maxCount={1}
|
||||||
|
fileList={fileList}
|
||||||
|
beforeUpload={() => false}
|
||||||
|
onChange={({ fileList: fl }) => setFileList(fl)}
|
||||||
|
>
|
||||||
|
<Button icon={<UploadOutlined />}>选择文件</Button>
|
||||||
|
</Upload>
|
||||||
|
<div style={{ marginTop: 4, color: '#888', fontSize: 12 }}>
|
||||||
|
按 ## 标题切分章节,LLM 跨章节生成多跳问题
|
||||||
|
</div>
|
||||||
|
</Form.Item>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<Form.Item name="env_url" label="Dagent 环境地址" rules={[{ required: true, message: '请输入环境地址' }]}>
|
||||||
|
<Input placeholder="https://dagent.d-robotics.cc" />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item name="org_id" label="Dagent 组织 ID" rules={[{ required: true, message: '请输入组织 ID' }]}>
|
||||||
|
<Input.Search
|
||||||
|
placeholder="输入 org_id 后点击查询统计"
|
||||||
|
enterButton="查询"
|
||||||
|
loading={loadingStats}
|
||||||
|
onSearch={loadDagentStats}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
{dagentStats && (
|
||||||
|
<div style={{ background: '#f6ffed', border: '1px solid #b7eb8f', borderRadius: 6, padding: '10px 14px', marginBottom: 16 }}>
|
||||||
|
<Row gutter={16}>
|
||||||
|
<Col span={8}><Statistic title="文件数" value={dagentStats.file_count ?? 0} valueStyle={{ fontSize: 18 }} /></Col>
|
||||||
|
<Col span={8}><Statistic title="段落数" value={dagentStats.paragraph_count ?? 0} valueStyle={{ fontSize: 18 }} /></Col>
|
||||||
|
<Col span={8}><Statistic title="含图段落" value={dagentStats.paragraphs_with_pic_text ?? 0} valueStyle={{ fontSize: 18 }} /></Col>
|
||||||
|
</Row>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
<Form.Item name="file_ids" label="选择文件" tooltip="留空则使用全部已处理文件,每个文件取最具代表性的段落参与多跳组合">
|
||||||
|
<DagentFileSelector
|
||||||
|
orgId={form.getFieldValue('org_id') || ''}
|
||||||
|
envUrl={form.getFieldValue('env_url') || ''}
|
||||||
|
disabled={!form.getFieldValue('org_id')}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
|
||||||
|
{/* Review drawer */}
|
||||||
|
<Drawer
|
||||||
|
title={`多跳 Case 审核 — ${reviewTask?.name || ''}`}
|
||||||
|
open={!!reviewDrawer}
|
||||||
|
onClose={() => { setReviewDrawer(null); setReviewTask(null) }}
|
||||||
|
width="80%"
|
||||||
|
extra={
|
||||||
|
<Space>
|
||||||
|
<Button icon={<DownloadOutlined />} onClick={handleExport} disabled={!reviewTask?.approved}>
|
||||||
|
导出 {reviewTask?.approved ? `(${reviewTask.approved} 题)` : ''}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
type="primary"
|
||||||
|
icon={<AimOutlined />}
|
||||||
|
disabled={!reviewTask?.approved}
|
||||||
|
onClick={() => {
|
||||||
|
testForm.setFieldsValue({ name: `${reviewTask?.name || ''}-召回测试` })
|
||||||
|
setTestModal(true)
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
创建召回测试
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{!reviewTask ? <Spin /> : (
|
||||||
|
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
|
||||||
|
<Card size="small" style={{ marginBottom: 12 }}>
|
||||||
|
<Row gutter={16}>
|
||||||
|
<Col span={6}><Statistic title="总组数" value={reviewTask.total || 0} /></Col>
|
||||||
|
<Col span={6}><Statistic title="已通过" value={reviewTask.approved || 0} valueStyle={{ color: '#52c41a' }} /></Col>
|
||||||
|
<Col span={6}><Statistic title="进度" value={reviewTask.total ? Math.round(reviewTask.progress / reviewTask.total * 100) : 0} suffix="%" /></Col>
|
||||||
|
<Col span={6}><StatusTag status={reviewTask.status} /></Col>
|
||||||
|
</Row>
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
<div style={{ marginBottom: 12, display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||||
|
<Segmented
|
||||||
|
options={statusOptions}
|
||||||
|
value={statusFilter}
|
||||||
|
onChange={v => { setStatusFilter(v as string); setQuestionPage(1) }}
|
||||||
|
/>
|
||||||
|
<Space>
|
||||||
|
<Button size="small" icon={<ThunderboltOutlined />} onClick={() => handleBatchApprove(0.6)}>
|
||||||
|
通过高质量(≥0.6)
|
||||||
|
</Button>
|
||||||
|
<Button size="small" onClick={() => handleBatchApprove(0)}>全部通过</Button>
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div style={{ flex: 1, overflow: 'auto' }}>
|
||||||
|
<Spin spinning={questionLoading}>
|
||||||
|
{questions.length === 0 && !questionLoading
|
||||||
|
? <Empty description="暂无问题" />
|
||||||
|
: questions.map(q => (
|
||||||
|
<Card
|
||||||
|
key={q.id}
|
||||||
|
size="small"
|
||||||
|
style={{
|
||||||
|
marginBottom: 8,
|
||||||
|
borderColor: q.status === 'approved' ? '#b7eb8f'
|
||||||
|
: q.status === 'rejected' ? '#ffa39e' : undefined,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'space-between', gap: 8 }}>
|
||||||
|
<div style={{ flex: 1, minWidth: 0 }}>
|
||||||
|
<div style={{ marginBottom: 4 }}>
|
||||||
|
<Space size={4}>
|
||||||
|
<TypeTag type={q.type} />
|
||||||
|
<Text strong>{q.qid}</Text>
|
||||||
|
{q.quality_score != null && (
|
||||||
|
<Text type="secondary" style={{ fontSize: 12 }}>
|
||||||
|
质量 {q.quality_score.toFixed(2)}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
<div style={{ fontWeight: 500, marginBottom: 4 }}>Q: {q.question}</div>
|
||||||
|
<div style={{ color: '#555', fontSize: 13, marginBottom: 6 }}>
|
||||||
|
<Text type="secondary">A: </Text>{q.answer}
|
||||||
|
</div>
|
||||||
|
<div style={{ display: 'flex', gap: 6, flexWrap: 'wrap', marginBottom: 4 }}>
|
||||||
|
{(q.hops || []).map((h: any, i: number) => (
|
||||||
|
<Tag key={i} style={{ fontSize: 11 }}>
|
||||||
|
Hop{i + 1}: {h.section_path?.split('/').pop() || h.section_path}
|
||||||
|
</Tag>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
<QStatusTag status={q.status} />
|
||||||
|
</div>
|
||||||
|
<Space direction="vertical" size={4} style={{ flexShrink: 0 }}>
|
||||||
|
{q.status !== 'approved' && (
|
||||||
|
<Button size="small" type="primary" icon={<CheckOutlined />}
|
||||||
|
onClick={() => handleApprove(q.id)}>通过</Button>
|
||||||
|
)}
|
||||||
|
{q.status !== 'rejected' && (
|
||||||
|
<Button size="small" danger icon={<CloseOutlined />}
|
||||||
|
onClick={() => handleReject(q.id)}>拒绝</Button>
|
||||||
|
)}
|
||||||
|
<Button size="small" icon={<EditOutlined />} onClick={() => setEditingQ(q)}>编辑</Button>
|
||||||
|
<Button size="small" onClick={() => setDetailQ(q)}>详情</Button>
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
</Card>
|
||||||
|
))
|
||||||
|
}
|
||||||
|
</Spin>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{questionTotal > PAGE_SIZE && (
|
||||||
|
<div style={{ marginTop: 12, textAlign: 'right' }}>
|
||||||
|
<Pagination
|
||||||
|
size="small"
|
||||||
|
current={questionPage}
|
||||||
|
pageSize={PAGE_SIZE}
|
||||||
|
total={questionTotal}
|
||||||
|
onChange={p => { setQuestionPage(p); loadQuestions(reviewDrawer!, p) }}
|
||||||
|
showTotal={t => `共 ${t} 条`}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Drawer>
|
||||||
|
|
||||||
|
{/* Question detail drawer */}
|
||||||
|
<Drawer
|
||||||
|
title={`${detailQ?.qid ?? ''} 详情`}
|
||||||
|
width={520}
|
||||||
|
open={!!detailQ}
|
||||||
|
onClose={() => setDetailQ(null)}
|
||||||
|
>
|
||||||
|
{detailQ && (
|
||||||
|
<div>
|
||||||
|
<Card size="small" title="问题" style={{ marginBottom: 12 }}>
|
||||||
|
<Space style={{ marginBottom: 8 }}><TypeTag type={detailQ.type} /></Space>
|
||||||
|
<div style={{ fontWeight: 500, marginBottom: 8 }}>{detailQ.question}</div>
|
||||||
|
<Text type="secondary">参考答案:{detailQ.answer}</Text>
|
||||||
|
</Card>
|
||||||
|
<Card size="small" title="Hop 来源章节">
|
||||||
|
{(detailQ.hops || []).map((h: any, i: number) => (
|
||||||
|
<div key={i} style={{ marginBottom: 8, padding: '6px 8px', background: '#fafafa', borderRadius: 4, border: '1px solid #f0f0f0' }}>
|
||||||
|
<Text strong>Hop{i + 1}:</Text>
|
||||||
|
<Text code style={{ fontSize: 12 }}>{h.section_path}</Text>
|
||||||
|
{h.contribution && (
|
||||||
|
<div style={{ fontSize: 12, color: '#888', marginTop: 2 }}>{h.contribution}</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Drawer>
|
||||||
|
|
||||||
|
{editingQ && (
|
||||||
|
<EditModal question={editingQ} onOk={handleEdit} onCancel={() => setEditingQ(null)} />
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Create recall test modal */}
|
||||||
|
<Modal
|
||||||
|
title="创建召回测试"
|
||||||
|
open={testModal}
|
||||||
|
onOk={handleCreateTest}
|
||||||
|
onCancel={() => { setTestModal(false); testForm.resetFields(); setTestAgentOptions([]) }}
|
||||||
|
confirmLoading={testSubmitting}
|
||||||
|
width={480}
|
||||||
|
>
|
||||||
|
<div style={{ marginBottom: 12, color: '#666', fontSize: 13 }}>
|
||||||
|
将已通过的 <strong>{reviewTask?.approved ?? 0}</strong> 个多跳问题直接创建为召回测试任务
|
||||||
|
</div>
|
||||||
|
<Form form={testForm} layout="vertical" initialValues={{ d_user_id: 'test', top_k: 10, concurrency: 5, llm_type: 'deepseek_v3' }}>
|
||||||
|
<Form.Item name="name" label="测试任务名称">
|
||||||
|
<Input placeholder="可选" />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="env_url" label="Agent 环境地址" rules={[{ required: true }]}>
|
||||||
|
<Input placeholder="https://dagent.d-robotics.cc/dagent" />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="org_id" label="Org ID" rules={[{ required: true }]}>
|
||||||
|
<Input placeholder="a4d49699ba313815..." />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="agent_id" label="Agent" rules={[{ required: true, message: '请选择 Agent' }]}>
|
||||||
|
<Select
|
||||||
|
placeholder="填写环境地址和 Org ID 后点击下方按钮查询"
|
||||||
|
options={testAgentOptions}
|
||||||
|
notFoundContent={testAgentOptions.length === 0 ? '点击下方「查询 Agent」' : '无匹配'}
|
||||||
|
dropdownRender={menu => (
|
||||||
|
<div>
|
||||||
|
{menu}
|
||||||
|
<div style={{ padding: '6px 8px', borderTop: '1px solid #f0f0f0' }}>
|
||||||
|
<Button
|
||||||
|
size="small"
|
||||||
|
icon={<SearchOutlined />}
|
||||||
|
loading={testAgentLoading}
|
||||||
|
onClick={loadTestAgents}
|
||||||
|
block
|
||||||
|
>
|
||||||
|
查询 Agent 列表
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="llm_type" label="LLM 类型" tooltip="Agent 使用的 LLM,不同模型可用性取决于远程环境">
|
||||||
|
<Select options={[
|
||||||
|
{ label: 'DeepSeek V3', value: 'deepseek_v3' },
|
||||||
|
{ label: 'DeepSeek R1', value: 'deepseek-r1' },
|
||||||
|
{ label: 'Volc DeepSeek V3', value: 'volc_deepseek_v3_250324' },
|
||||||
|
{ label: 'Azure GPT-4o', value: 'azure_openai_4o' },
|
||||||
|
{ label: 'Azure GPT-4.1', value: 'azure/gpt-4.1' },
|
||||||
|
{ label: 'Claude 3.5 Sonnet', value: 'aws/claude-3-5-sonnet' },
|
||||||
|
]} />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="d_user_id" label="User ID">
|
||||||
|
<Input />
|
||||||
|
</Form.Item>
|
||||||
|
<Row gutter={12}>
|
||||||
|
<Col span={12}>
|
||||||
|
<Form.Item name="concurrency" label="并发数">
|
||||||
|
<InputNumber min={1} max={10} style={{ width: '100%' }} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
|
||||||
|
{/* Prompt template management Drawer */}
|
||||||
|
<Drawer
|
||||||
|
title="提示词模板管理"
|
||||||
|
width={720}
|
||||||
|
open={templateDrawer}
|
||||||
|
onClose={() => { setTemplateDrawer(false); setEditingTemplate(null); templateForm.resetFields() }}
|
||||||
|
extra={
|
||||||
|
<Button type="primary" icon={<PlusOutlined />} onClick={() => { setEditingTemplate({}); templateForm.resetFields() }}>
|
||||||
|
新建模板
|
||||||
|
</Button>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<Row gutter={16} style={{ height: '100%' }}>
|
||||||
|
{/* Left: template list */}
|
||||||
|
<Col span={editingTemplate !== null ? 10 : 24}>
|
||||||
|
{templates.length === 0 ? (
|
||||||
|
<Empty description="暂无模板,点击右上角新建" />
|
||||||
|
) : (
|
||||||
|
templates.map(t => (
|
||||||
|
<Card
|
||||||
|
key={t.id}
|
||||||
|
size="small"
|
||||||
|
style={{ marginBottom: 8, cursor: 'pointer', border: editingTemplate?.id === t.id ? '1px solid #1677ff' : undefined }}
|
||||||
|
onClick={() => { setEditingTemplate(t); templateForm.setFieldsValue({ name: t.name, description: t.description, content: t.content }) }}
|
||||||
|
actions={[
|
||||||
|
<Button
|
||||||
|
key="edit"
|
||||||
|
type="link"
|
||||||
|
size="small"
|
||||||
|
icon={<EditOutlined />}
|
||||||
|
onClick={e => { e.stopPropagation(); setEditingTemplate(t); templateForm.setFieldsValue({ name: t.name, description: t.description, content: t.content }) }}
|
||||||
|
>
|
||||||
|
编辑
|
||||||
|
</Button>,
|
||||||
|
<Popconfirm
|
||||||
|
key="del"
|
||||||
|
title="确认删除此模板?"
|
||||||
|
onConfirm={e => { e?.stopPropagation(); handleTemplateDelete(t.id) }}
|
||||||
|
>
|
||||||
|
<Button type="link" size="small" danger icon={<DeleteOutlined />} onClick={e => e.stopPropagation()}>删除</Button>
|
||||||
|
</Popconfirm>,
|
||||||
|
]}
|
||||||
|
>
|
||||||
|
<div style={{ fontWeight: 500 }}>{t.name}</div>
|
||||||
|
{t.description && <Text type="secondary" style={{ fontSize: 12 }}>{t.description}</Text>}
|
||||||
|
<div style={{ marginTop: 6, fontSize: 12, color: '#888', whiteSpace: 'pre-wrap', maxHeight: 60, overflow: 'hidden' }}>
|
||||||
|
{t.content}
|
||||||
|
</div>
|
||||||
|
</Card>
|
||||||
|
))
|
||||||
|
)}
|
||||||
|
</Col>
|
||||||
|
|
||||||
|
{/* Right: editing area */}
|
||||||
|
{editingTemplate !== null && (
|
||||||
|
<Col span={14}>
|
||||||
|
<Card
|
||||||
|
size="small"
|
||||||
|
title={editingTemplate.id ? '编辑模板' : '新建模板'}
|
||||||
|
extra={
|
||||||
|
<Button size="small" onClick={() => { setEditingTemplate(null); templateForm.resetFields() }}>取消</Button>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<Form form={templateForm} layout="vertical">
|
||||||
|
<Form.Item name="name" label="模板名称" rules={[{ required: true, message: '请输入名称' }]}>
|
||||||
|
<Input placeholder="例如:偏操作步骤型" />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="description" label="描述(可选)">
|
||||||
|
<Input placeholder="简要说明此模板的用途" />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item
|
||||||
|
name="content"
|
||||||
|
label={
|
||||||
|
<Space size={4}>
|
||||||
|
<span>生成要求</span>
|
||||||
|
<Button
|
||||||
|
type="link"
|
||||||
|
size="small"
|
||||||
|
style={{ padding: 0, height: 'auto', fontSize: 12 }}
|
||||||
|
onClick={handleImportDefault}
|
||||||
|
>
|
||||||
|
导入默认模板
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
rules={[{ required: true, message: '请输入生成要求' }]}
|
||||||
|
tooltip="只需填写生成要求,系统会自动拼接角色定义、章节内容和输出格式"
|
||||||
|
>
|
||||||
|
<Input.TextArea
|
||||||
|
rows={10}
|
||||||
|
placeholder={'1. 每个问题必须真正跨越多个章节...\n2. 问题类型可以是...\n3. ...'}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
<Button type="primary" loading={templateSubmitting} onClick={handleTemplateSave} block>
|
||||||
|
保存
|
||||||
|
</Button>
|
||||||
|
</Form>
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
)}
|
||||||
|
</Row>
|
||||||
|
</Drawer>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
642
frontend/src/pages/MultiHop/index.tsx
Normal file
@ -0,0 +1,642 @@
|
|||||||
|
import React, { useEffect, useRef, useState } from 'react'
|
||||||
|
import {
|
||||||
|
Table, Button, Modal, Form, Input, InputNumber, Upload, Tag, Progress,
|
||||||
|
Drawer, Card, Row, Col, Statistic, Space, Tooltip, Typography, message,
|
||||||
|
Collapse, Badge, Segmented, Select,
|
||||||
|
} from 'antd'
|
||||||
|
import {
|
||||||
|
PlusOutlined, DeleteOutlined, ReloadOutlined, UploadOutlined,
|
||||||
|
SyncOutlined, CheckCircleOutlined, CloseCircleOutlined, MinusCircleOutlined,
|
||||||
|
AimOutlined, BulbOutlined, SearchOutlined,
|
||||||
|
} from '@ant-design/icons'
|
||||||
|
import { multiHopApi } from '../../services/api'
|
||||||
|
import GenTab from './GenTab'
|
||||||
|
|
||||||
|
const { Text } = Typography
|
||||||
|
|
||||||
|
function StatusTag({ status }: { status: string }) {
|
||||||
|
const map: Record<string, { color: string; icon?: React.ReactNode; label: string }> = {
|
||||||
|
pending: { color: 'default', label: '等待中' },
|
||||||
|
running: { color: 'processing', icon: <SyncOutlined spin />, label: '运行中' },
|
||||||
|
done: { color: 'success', label: '完成' },
|
||||||
|
failed: { color: 'error', label: '失败' },
|
||||||
|
}
|
||||||
|
const cfg = map[status] || { color: 'default', label: status }
|
||||||
|
return <Tag color={cfg.color} icon={cfg.icon}>{cfg.label}</Tag>
|
||||||
|
}
|
||||||
|
|
||||||
|
function HitTag({ full, partial }: { full: boolean; partial: boolean }) {
|
||||||
|
if (full) return <Tag color="success" icon={<CheckCircleOutlined />}>全命中</Tag>
|
||||||
|
if (partial) return <Tag color="warning" icon={<MinusCircleOutlined />}>部分命中</Tag>
|
||||||
|
return <Tag color="error" icon={<CloseCircleOutlined />}>未命中</Tag>
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function MultiHop() {
|
||||||
|
const [tasks, setTasks] = useState<any[]>([])
|
||||||
|
const [loading, setLoading] = useState(false)
|
||||||
|
const [createModal, setCreateModal] = useState(false)
|
||||||
|
const [form] = Form.useForm()
|
||||||
|
const [submitting, setSubmitting] = useState(false)
|
||||||
|
const [fileList, setFileList] = useState<any[]>([])
|
||||||
|
const pollingRef = useRef<ReturnType<typeof setInterval> | null>(null)
|
||||||
|
|
||||||
|
// Agent list
|
||||||
|
const [agentOptions, setAgentOptions] = useState<{ label: string; value: string; desc?: string }[]>([])
|
||||||
|
const [loadingAgents, setLoadingAgents] = useState(false)
|
||||||
|
|
||||||
|
// Report Drawer
|
||||||
|
const [drawerTaskId, setDrawerTaskId] = useState<string | null>(null)
|
||||||
|
const [summary, setSummary] = useState<any>(null)
|
||||||
|
const [results, setResults] = useState<any[]>([])
|
||||||
|
const [drawerLoading, setDrawerLoading] = useState(false)
|
||||||
|
|
||||||
|
// Detail Drawer
|
||||||
|
const [detailResult, setDetailResult] = useState<any>(null)
|
||||||
|
|
||||||
|
// Batch delete
|
||||||
|
const [selectedKeys, setSelectedKeys] = useState<React.Key[]>([])
|
||||||
|
|
||||||
|
const loadTasks = async () => {
|
||||||
|
setLoading(true)
|
||||||
|
try {
|
||||||
|
const res = await multiHopApi.listTasks() as any
|
||||||
|
setTasks(res.data || [])
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const loadAgents = async () => {
|
||||||
|
const envUrl = form.getFieldValue('env_url')
|
||||||
|
const orgId = form.getFieldValue('org_id')
|
||||||
|
const dUserId = form.getFieldValue('d_user_id') || 'test'
|
||||||
|
if (!envUrl || !orgId) { message.warning('请先填写环境地址和 Org ID'); return }
|
||||||
|
setLoadingAgents(true)
|
||||||
|
try {
|
||||||
|
const res = await multiHopApi.listDagentAgents(envUrl, orgId, dUserId) as any
|
||||||
|
const agents = res.data || []
|
||||||
|
if (!agents.length) { setAgentOptions([]); message.warning('未找到可用的 Agent'); return }
|
||||||
|
setAgentOptions(agents.map((a: any) => ({
|
||||||
|
label: a.name || a.id,
|
||||||
|
value: a.id,
|
||||||
|
desc: a.description || a.type || '',
|
||||||
|
})))
|
||||||
|
message.success(`找到 ${agents.length} 个 Agent`)
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(e?.response?.data?.detail || e?.message || '拉取 Agent 列表失败')
|
||||||
|
} finally {
|
||||||
|
setLoadingAgents(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const loadReport = async (taskId: string) => {
|
||||||
|
setDrawerLoading(true)
|
||||||
|
try {
|
||||||
|
const [sumRes, resRes] = await Promise.all([
|
||||||
|
multiHopApi.getSummary(taskId) as any,
|
||||||
|
multiHopApi.getResults(taskId) as any,
|
||||||
|
])
|
||||||
|
setSummary(sumRes.data)
|
||||||
|
setResults(resRes.data || [])
|
||||||
|
} finally {
|
||||||
|
setDrawerLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
loadTasks()
|
||||||
|
pollingRef.current = setInterval(() => {
|
||||||
|
setTasks(prev => {
|
||||||
|
const hasRunning = prev.some(t => t.status === 'running' || t.status === 'pending')
|
||||||
|
if (hasRunning) loadTasks()
|
||||||
|
return prev
|
||||||
|
})
|
||||||
|
}, 3000)
|
||||||
|
return () => { if (pollingRef.current) clearInterval(pollingRef.current) }
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const handleCreate = async () => {
|
||||||
|
const vals = await form.validateFields()
|
||||||
|
if (!fileList.length) { message.error('请上传多跳问答 MD 文件'); return }
|
||||||
|
setSubmitting(true)
|
||||||
|
try {
|
||||||
|
const fd = new FormData()
|
||||||
|
fd.append('file', fileList[0].originFileObj)
|
||||||
|
fd.append('name', vals.name || fileList[0].name)
|
||||||
|
fd.append('env_url', vals.env_url)
|
||||||
|
fd.append('org_id', vals.org_id)
|
||||||
|
fd.append('agent_id', vals.agent_id)
|
||||||
|
fd.append('llm_type', vals.llm_type || 'deepseek_v3')
|
||||||
|
fd.append('d_user_id', vals.d_user_id || 'test')
|
||||||
|
fd.append('top_k', String(vals.top_k ?? 10))
|
||||||
|
fd.append('concurrency', String(vals.concurrency ?? 5))
|
||||||
|
await multiHopApi.createTask(fd)
|
||||||
|
message.success('任务已创建')
|
||||||
|
setCreateModal(false)
|
||||||
|
form.resetFields()
|
||||||
|
setFileList([])
|
||||||
|
loadTasks()
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(e?.response?.data?.detail || e?.message || '创建失败')
|
||||||
|
} finally {
|
||||||
|
setSubmitting(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleDelete = async (id: string) => {
|
||||||
|
try {
|
||||||
|
await multiHopApi.deleteTask(id)
|
||||||
|
message.success('已删除')
|
||||||
|
loadTasks()
|
||||||
|
if (drawerTaskId === id) setDrawerTaskId(null)
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(e?.response?.data?.detail || e?.message || '删除失败')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleBatchDelete = () => {
|
||||||
|
if (!selectedKeys.length) { message.warning('请先选择任务'); return }
|
||||||
|
Modal.confirm({
|
||||||
|
title: `确认删除选中的 ${selectedKeys.length} 个任务?`,
|
||||||
|
okType: 'danger',
|
||||||
|
okText: '确认删除',
|
||||||
|
cancelText: '取消',
|
||||||
|
async onOk() {
|
||||||
|
await Promise.all(selectedKeys.map(id => multiHopApi.deleteTask(id as string)))
|
||||||
|
message.success('批量删除成功')
|
||||||
|
setSelectedKeys([])
|
||||||
|
loadTasks()
|
||||||
|
if (drawerTaskId && selectedKeys.includes(drawerTaskId)) setDrawerTaskId(null)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const openReport = (taskId: string) => {
|
||||||
|
setDrawerTaskId(taskId)
|
||||||
|
loadReport(taskId)
|
||||||
|
}
|
||||||
|
|
||||||
|
const columns = [
|
||||||
|
{ title: '任务名称', dataIndex: 'name', ellipsis: true },
|
||||||
|
{ title: '环境地址', dataIndex: 'env_url', ellipsis: true, width: 200 },
|
||||||
|
{ title: 'Org ID', dataIndex: 'org_id', ellipsis: true, width: 160 },
|
||||||
|
{
|
||||||
|
title: '状态', dataIndex: 'status', width: 100,
|
||||||
|
render: (v: string) => <StatusTag status={v} />,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '进度', width: 160,
|
||||||
|
render: (_: any, r: any) => r.status === 'running'
|
||||||
|
? <Progress percent={r.total ? Math.round(r.progress / r.total * 100) : 0} size="small" />
|
||||||
|
: r.status === 'done'
|
||||||
|
? <Text type="success">{r.total} 题完成</Text>
|
||||||
|
: r.status === 'failed'
|
||||||
|
? <Text type="danger">失败</Text>
|
||||||
|
: '-',
|
||||||
|
},
|
||||||
|
{ title: '创建时间', dataIndex: 'created_at', width: 160, render: (v: string) => v?.slice(0, 19) },
|
||||||
|
{
|
||||||
|
title: '操作', width: 140,
|
||||||
|
render: (_: any, r: any) => (
|
||||||
|
<Space>
|
||||||
|
<Button size="small" disabled={r.status !== 'done'} onClick={() => openReport(r.id)}>报告</Button>
|
||||||
|
<Button size="small" danger onClick={() => handleDelete(r.id)}>删除</Button>
|
||||||
|
</Space>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
const drawerTask = tasks.find(t => t.id === drawerTaskId)
|
||||||
|
const [activeTab, setActiveTab] = useState<'test' | 'gen'>('test')
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{/* Title bar */}
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 16 }}>
|
||||||
|
<Segmented
|
||||||
|
value={activeTab}
|
||||||
|
onChange={v => setActiveTab(v as 'test' | 'gen')}
|
||||||
|
options={[
|
||||||
|
{ label: '召回测试', value: 'test', icon: <AimOutlined /> },
|
||||||
|
{ label: '生成 Case', value: 'gen', icon: <BulbOutlined /> },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
{activeTab === 'test' && <Space>
|
||||||
|
{selectedKeys.length > 0 && (
|
||||||
|
<Button danger icon={<DeleteOutlined />} onClick={handleBatchDelete}>
|
||||||
|
删除选中 ({selectedKeys.length})
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
<Button icon={<ReloadOutlined />} onClick={loadTasks}>刷新</Button>
|
||||||
|
<Button type="primary" icon={<PlusOutlined />} onClick={() => setCreateModal(true)}>新建测试</Button>
|
||||||
|
</Space>}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{activeTab === 'gen' ? <GenTab /> : (
|
||||||
|
<>
|
||||||
|
{/* Task list */}
|
||||||
|
<Table
|
||||||
|
rowKey="id"
|
||||||
|
dataSource={tasks}
|
||||||
|
columns={columns}
|
||||||
|
loading={loading}
|
||||||
|
size="small"
|
||||||
|
rowSelection={{ selectedRowKeys: selectedKeys, onChange: setSelectedKeys }}
|
||||||
|
pagination={{ pageSize: 20 }}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* Create task modal */}
|
||||||
|
<Modal
|
||||||
|
title="新建多跳召回测试"
|
||||||
|
open={createModal}
|
||||||
|
onOk={handleCreate}
|
||||||
|
onCancel={() => { setCreateModal(false); form.resetFields(); setFileList([]); setAgentOptions([]) }}
|
||||||
|
confirmLoading={submitting}
|
||||||
|
width={520}
|
||||||
|
>
|
||||||
|
<Form form={form} layout="vertical"
|
||||||
|
initialValues={{ d_user_id: 'test', concurrency: 5, llm_type: 'deepseek_v3' }}>
|
||||||
|
<Form.Item name="name" label="任务名称">
|
||||||
|
<Input placeholder="可选,默认使用文件名" />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="env_url" label="Agent 环境地址" rules={[{ required: true }]}>
|
||||||
|
<Input placeholder="https://your-dagent-env.com" />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="org_id" label="Org ID" rules={[{ required: true }]}>
|
||||||
|
<Input placeholder="cd6e121594984516..." />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="d_user_id" label="User ID">
|
||||||
|
<Input />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item
|
||||||
|
name="agent_id"
|
||||||
|
label="Agent"
|
||||||
|
rules={[{ required: true, message: '请选择 Agent' }]}
|
||||||
|
tooltip="测试会调用 /agent/chat,由 Agent 自主决定搜几次知识库"
|
||||||
|
>
|
||||||
|
<Select
|
||||||
|
placeholder="填写环境地址和 Org ID 后点击下方按钮查询"
|
||||||
|
options={agentOptions.map(a => ({
|
||||||
|
label: (
|
||||||
|
<Space size={4}>
|
||||||
|
<span>{a.label}</span>
|
||||||
|
{a.desc && <Text type="secondary" style={{ fontSize: 11 }}>{a.desc}</Text>}
|
||||||
|
</Space>
|
||||||
|
),
|
||||||
|
value: a.value,
|
||||||
|
}))}
|
||||||
|
notFoundContent={agentOptions.length === 0 ? '点击下方「查询 Agent」' : '无匹配'}
|
||||||
|
dropdownRender={menu => (
|
||||||
|
<div>
|
||||||
|
{menu}
|
||||||
|
<div style={{ padding: '6px 8px', borderTop: '1px solid #f0f0f0' }}>
|
||||||
|
<Button
|
||||||
|
size="small"
|
||||||
|
icon={<SearchOutlined />}
|
||||||
|
loading={loadingAgents}
|
||||||
|
onClick={loadAgents}
|
||||||
|
block
|
||||||
|
>
|
||||||
|
查询 Agent 列表
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item
|
||||||
|
name="llm_type"
|
||||||
|
label="LLM 类型"
|
||||||
|
tooltip="Agent 使用的 LLM,不同模型可用性取决于远程环境配置"
|
||||||
|
>
|
||||||
|
<Select
|
||||||
|
options={[
|
||||||
|
{ label: 'DeepSeek V3', value: 'deepseek_v3' },
|
||||||
|
{ label: 'DeepSeek R1', value: 'deepseek-r1' },
|
||||||
|
{ label: 'Volc DeepSeek V3', value: 'volc_deepseek_v3_250324' },
|
||||||
|
{ label: 'Azure GPT-4o', value: 'azure_openai_4o' },
|
||||||
|
{ label: 'Azure GPT-4.1', value: 'azure/gpt-4.1' },
|
||||||
|
{ label: 'Claude 3.5 Sonnet', value: 'aws/claude-3-5-sonnet' },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
<Row gutter={12}>
|
||||||
|
<Col span={12}>
|
||||||
|
<Form.Item name="concurrency" label="并发数">
|
||||||
|
<InputNumber min={1} max={10} style={{ width: '100%' }} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
<Form.Item label="多跳问答 MD 文件" required>
|
||||||
|
<Upload
|
||||||
|
accept=".md"
|
||||||
|
maxCount={1}
|
||||||
|
fileList={fileList}
|
||||||
|
beforeUpload={() => false}
|
||||||
|
onChange={({ fileList: fl }) => setFileList(fl)}
|
||||||
|
>
|
||||||
|
<Button icon={<UploadOutlined />}>选择文件</Button>
|
||||||
|
</Upload>
|
||||||
|
<div style={{ marginTop: 4, color: '#888', fontSize: 12 }}>
|
||||||
|
格式参考:## MH1 / **问题:** / **答案:** / **Hop1:** section_path | 说明
|
||||||
|
</div>
|
||||||
|
</Form.Item>
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
|
||||||
|
{/* Report Drawer */}
|
||||||
|
<Drawer
|
||||||
|
title={drawerTask?.name}
|
||||||
|
width="80%"
|
||||||
|
open={!!drawerTaskId}
|
||||||
|
onClose={() => setDrawerTaskId(null)}
|
||||||
|
>
|
||||||
|
{summary && (
|
||||||
|
<>
|
||||||
|
<Card size="small" style={{ marginBottom: 16 }}>
|
||||||
|
<Row gutter={16}>
|
||||||
|
<Col span={4}><Statistic title="总问题数" value={summary.total} /></Col>
|
||||||
|
<Col span={4}>
|
||||||
|
<Statistic title="全命中率" value={(summary.full_hit_rate * 100).toFixed(1)} suffix="%" valueStyle={{ color: '#52c41a' }} />
|
||||||
|
</Col>
|
||||||
|
<Col span={4}>
|
||||||
|
<Statistic title="部分命中率" value={(summary.partial_hit_rate * 100).toFixed(1)} suffix="%" valueStyle={{ color: '#faad14' }} />
|
||||||
|
</Col>
|
||||||
|
<Col span={4}>
|
||||||
|
<Statistic title="平均Hop命中" value={(summary.avg_hop_hit_rate * 100).toFixed(1)} suffix="%" />
|
||||||
|
</Col>
|
||||||
|
<Col span={4}>
|
||||||
|
<Statistic title="平均相似度" value={summary.avg_cosine_sim ?? '-'} precision={4} />
|
||||||
|
</Col>
|
||||||
|
<Col span={4}>
|
||||||
|
<Statistic title="平均延迟" value={summary.avg_latency_ms?.toFixed(0) ?? '-'} suffix="ms" />
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
{(summary.empty_count > 0 || summary.error_count > 0) && (
|
||||||
|
<div style={{ marginTop: 8 }}>
|
||||||
|
{summary.empty_count > 0 && <Tag color="warning">空召回 {summary.empty_count} 题</Tag>}
|
||||||
|
{summary.error_count > 0 && <Tag color="error">错误 {summary.error_count} 题</Tag>}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Card>
|
||||||
|
<Table
|
||||||
|
rowKey="id"
|
||||||
|
dataSource={results}
|
||||||
|
loading={drawerLoading}
|
||||||
|
size="small"
|
||||||
|
pagination={{ pageSize: 20 }}
|
||||||
|
columns={[
|
||||||
|
{ title: 'ID', dataIndex: 'qid', width: 70 },
|
||||||
|
{ title: '类型', dataIndex: 'type', width: 90,
|
||||||
|
render: (v: string) => {
|
||||||
|
const map: Record<string, string> = { comparison: '比较型', reasoning: '推理型', aggregation: '聚合型' }
|
||||||
|
return <Tag>{map[v] || v}</Tag>
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ title: '问题', dataIndex: 'question', ellipsis: true },
|
||||||
|
{ title: '命中', width: 110, render: (_: any, r: any) => <HitTag full={r.full_hit === 1} partial={r.partial_hit === 1} /> },
|
||||||
|
{ title: 'Hop命中', width: 80, render: (_: any, r: any) => `${r.hop_hit_count}/${r.hop_count}` },
|
||||||
|
{ title: '最佳相似度', dataIndex: 'best_cosine_sim', width: 100, render: (v: number) => v != null ? v.toFixed(4) : '-' },
|
||||||
|
{ title: '延迟', dataIndex: 'latency_ms', width: 80, render: (v: number) => v != null ? `${v}ms` : '-' },
|
||||||
|
{ title: '操作', width: 70, render: (_: any, r: any) => <Button size="small" onClick={() => setDetailResult(r)}>详情</Button> },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Drawer>
|
||||||
|
|
||||||
|
{/* Question detail Drawer */}
|
||||||
|
<Drawer
|
||||||
|
title={
|
||||||
|
<Space>
|
||||||
|
<span>{detailResult?.qid}</span>
|
||||||
|
{detailResult && (() => {
|
||||||
|
const typeMap: Record<string, { label: string; color: string; desc: string }> = {
|
||||||
|
comparison: { label: '比较型', color: 'blue', desc: '需对比多个文档中的同类信息' },
|
||||||
|
reasoning: { label: '推理型', color: 'purple', desc: '需从多个文档逐步推导出结论' },
|
||||||
|
aggregation: { label: '聚合型', color: 'cyan', desc: '需从多个文档收集同类信息汇总' },
|
||||||
|
}
|
||||||
|
const t = typeMap[detailResult.type] || { label: detailResult.type, color: 'default', desc: '' }
|
||||||
|
return (
|
||||||
|
<Tooltip title={t.desc}>
|
||||||
|
<Tag color={t.color} style={{ cursor: 'help' }}>{t.label}</Tag>
|
||||||
|
</Tooltip>
|
||||||
|
)
|
||||||
|
})()}
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
width={620}
|
||||||
|
open={!!detailResult}
|
||||||
|
onClose={() => setDetailResult(null)}
|
||||||
|
>
|
||||||
|
{detailResult && (() => {
|
||||||
|
const hops: any[] = detailResult.hops || []
|
||||||
|
const actualHops: any[] = detailResult.actual_hops || []
|
||||||
|
const retrieved: any[] = detailResult.retrieved || []
|
||||||
|
|
||||||
|
// Expected hop → its 1-based rank in the merged retrieved list (0 = not retrieved)
|
||||||
|
const hopRankMap: Record<number, number> = {}
|
||||||
|
hops.forEach((h, hi) => {
|
||||||
|
if (!h.file_id) return
|
||||||
|
const rank = retrieved.findIndex((r: any) => r.file_id === h.file_id)
|
||||||
|
hopRankMap[hi] = rank >= 0 ? rank + 1 : 0
|
||||||
|
})
|
||||||
|
|
||||||
|
// For each chunk in the merged retrieved list, the expected hops (1-based) it matches
|
||||||
|
const chunkHopMap: Record<number, number[]> = {}
|
||||||
|
retrieved.forEach((chunk: any, ci) => {
|
||||||
|
hops.forEach((h, hi) => {
|
||||||
|
if (h.file_id && h.file_id === chunk.file_id) {
|
||||||
|
if (!chunkHopMap[ci]) chunkHopMap[ci] = []
|
||||||
|
chunkHopMap[ci].push(hi + 1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Diagnosis: why the Agent returned no retrieval results
|
||||||
|
const noActualHopReason: string | null = (() => {
|
||||||
|
if (actualHops.length > 0) return null
|
||||||
|
if (detailResult.error) return null // errors are displayed separately
|
||||||
|
return 'Agent 未返回任何召回结果,可能原因:Agent ID 配置错误、网络超时,或该问题触发了 Agent 的拒答逻辑。请检查任务配置后重新运行。'
|
||||||
|
})()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{/* Question & answer */}
|
||||||
|
<Card size="small" style={{ marginBottom: 12 }}>
|
||||||
|
<div style={{ fontWeight: 500, marginBottom: 6 }}>{detailResult.question}</div>
|
||||||
|
<Text type="secondary" style={{ fontSize: 12 }}>参考答案:{detailResult.answer}</Text>
|
||||||
|
{detailResult.agent_answer && (
|
||||||
|
<div style={{ marginTop: 8, padding: '6px 10px', background: '#f0f5ff', borderRadius: 4, fontSize: 12 }}>
|
||||||
|
<Text type="secondary">Agent 回答:</Text>{detailResult.agent_answer}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
{/* Expected hop chain */}
|
||||||
|
<Card
|
||||||
|
size="small"
|
||||||
|
style={{ marginBottom: 12 }}
|
||||||
|
title={
|
||||||
|
<Space>
|
||||||
|
<span>期望跳链({hops.length} 跳)</span>
|
||||||
|
<Text type="secondary" style={{ fontSize: 12, fontWeight: 400 }}>
|
||||||
|
— 回答此问题需要覆盖的文档
|
||||||
|
</Text>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{hops.map((h: any, i: number) => {
|
||||||
|
const hit = h.hit
|
||||||
|
const hitAtHop = h.hit_at_hop
|
||||||
|
// Refine the miss reason
|
||||||
|
const missReason: string = (() => {
|
||||||
|
if (hit) return `第 ${hitAtHop} 跳命中`
|
||||||
|
if (!h.file_id) return '文件映射失败'
|
||||||
|
if (actualHops.length === 0) return 'Agent 无召回'
|
||||||
|
return '未召回'
|
||||||
|
})()
|
||||||
|
// Orange for file-mapping failures, red for other misses
|
||||||
|
const missColor = !h.file_id ? '#fa8c16' : '#ff4d4f'
|
||||||
|
const rankColor = hit ? '#52c41a' : missColor
|
||||||
|
// Use an orange-toned background when file mapping failed
|
||||||
|
const bgColor = hit ? '#f6ffed' : (!h.file_id ? '#fff7e6' : '#fff2f0')
|
||||||
|
const borderColor = hit ? '#b7eb8f' : (!h.file_id ? '#ffd591' : '#ffccc7')
|
||||||
|
return (
|
||||||
|
<div key={i} style={{ display: 'flex', gap: 12, alignItems: 'flex-start' }}>
|
||||||
|
<div style={{ display: 'flex', flexDirection: 'column', alignItems: 'center', flexShrink: 0 }}>
|
||||||
|
<div style={{
|
||||||
|
width: 26, height: 26, borderRadius: '50%', display: 'flex',
|
||||||
|
alignItems: 'center', justifyContent: 'center', fontSize: 12, fontWeight: 600,
|
||||||
|
background: bgColor,
|
||||||
|
border: `2px solid ${rankColor}`,
|
||||||
|
color: rankColor,
|
||||||
|
}}>{i + 1}</div>
|
||||||
|
{i < hops.length - 1 && (
|
||||||
|
<div style={{ width: 2, flex: 1, minHeight: 16, background: '#f0f0f0', margin: '3px 0' }} />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div style={{
|
||||||
|
flex: 1, padding: '5px 10px', borderRadius: 6, marginBottom: 8,
|
||||||
|
background: bgColor,
|
||||||
|
border: `1px solid ${borderColor}`,
|
||||||
|
}}>
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 3 }}>
|
||||||
|
<Space size={4}>
|
||||||
|
{hit ? <CheckCircleOutlined style={{ color: '#52c41a' }} /> : <CloseCircleOutlined style={{ color: missColor }} />}
|
||||||
|
<Text strong style={{ fontSize: 13 }}>Hop {i + 1}</Text>
|
||||||
|
<Text style={{ fontSize: 12 }}>{h.file_name || h.section_path}</Text>
|
||||||
|
</Space>
|
||||||
|
<span style={{ fontSize: 12, color: rankColor, fontWeight: 500 }}>{missReason}</span>
|
||||||
|
</div>
|
||||||
|
{!h.file_id && (
|
||||||
|
<div style={{ fontSize: 11, color: '#ad6800', marginBottom: 3 }}>
|
||||||
|
⚠️ section_path「{h.section_path}」未能匹配到知识库中的任何文件,命中判断已跳过此跳
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{h.contribution && (
|
||||||
|
<div style={{ fontSize: 12, color: '#666' }}>📌 {h.contribution}</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
{/* Diagnostic hint when the Agent retrieved nothing */}
|
||||||
|
{noActualHopReason && (
|
||||||
|
<Card
|
||||||
|
size="small"
|
||||||
|
style={{ marginBottom: 12, borderColor: '#faad14', background: '#fffbe6' }}
|
||||||
|
title={<Text style={{ color: '#ad6800' }}>⚠️ Agent 召回诊断</Text>}
|
||||||
|
>
|
||||||
|
<Text style={{ fontSize: 12, color: '#ad6800' }}>{noActualHopReason}</Text>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Per-hop retrieval details */}
|
||||||
|
{actualHops.length > 0 && actualHops.map((ah: any, hopIdx: number) => {
|
||||||
|
const docs: any[] = ah.retrieved || []
|
||||||
|
const hitHopNums = hops
|
||||||
|
.map((h: any, hi: number) => h.file_id && docs.some((d: any) => d.file_id === h.file_id) ? hi + 1 : null)
|
||||||
|
.filter(Boolean)
|
||||||
|
const hopColors = ['#1890ff', '#722ed1', '#13c2c2', '#fa8c16', '#eb2f96']
|
||||||
|
const color = hopColors[hopIdx % hopColors.length]
|
||||||
|
return (
|
||||||
|
<Card
|
||||||
|
key={hopIdx}
|
||||||
|
size="small"
|
||||||
|
style={{ marginBottom: 12, borderColor: color }}
|
||||||
|
title={
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||||
|
<Space size={6}>
|
||||||
|
<div style={{
|
||||||
|
width: 24, height: 24, borderRadius: '50%', display: 'inline-flex',
|
||||||
|
alignItems: 'center', justifyContent: 'center', fontSize: 12, fontWeight: 700,
|
||||||
|
background: color, color: '#fff',
|
||||||
|
}}>{ah.hop_index}</div>
|
||||||
|
<span style={{ fontWeight: 600 }}>第 {ah.hop_index} 跳</span>
|
||||||
|
{hitHopNums.map((n: any) => (
|
||||||
|
<Tag key={n} color="success" style={{ fontSize: 11, margin: 0 }}>命中期望 Hop{n}</Tag>
|
||||||
|
))}
|
||||||
|
</Space>
|
||||||
|
<Text type="secondary" style={{ fontSize: 12 }}>{docs.length} 条召回</Text>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{ah.query && (
|
||||||
|
<div style={{ fontSize: 12, color: '#555', marginBottom: 8, padding: '4px 8px', background: '#fafafa', borderRadius: 4 }}>
|
||||||
|
🔍 Query:{ah.query.length > 120 ? ah.query.slice(0, 120) + '...' : ah.query}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{docs.length === 0
|
||||||
|
? <Text type="secondary" style={{ fontSize: 12 }}>无召回结果</Text>
|
||||||
|
: docs.map((d: any, di: number) => {
|
||||||
|
const sim = d.cosine_distance_1 != null ? (1 - d.cosine_distance_1).toFixed(4) : null
|
||||||
|
const isExpected = hops.some((h: any) => h.file_id && h.file_id === d.file_id)
|
||||||
|
const matchedHops = hops
|
||||||
|
.map((h: any, hi: number) => h.file_id && h.file_id === d.file_id ? hi + 1 : null)
|
||||||
|
.filter(Boolean)
|
||||||
|
return (
|
||||||
|
<div key={di} style={{
|
||||||
|
marginBottom: 6, padding: '5px 8px', borderRadius: 4,
|
||||||
|
background: isExpected ? '#e6f7ff' : '#fafafa',
|
||||||
|
border: `1px solid ${isExpected ? '#91d5ff' : '#f0f0f0'}`,
|
||||||
|
}}>
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
|
||||||
|
<Space size={4}>
|
||||||
|
<Text style={{ fontSize: 12, color: '#999' }}>#{di + 1}</Text>
|
||||||
|
{matchedHops.map((n: any) => (
|
||||||
|
<Tag key={n} color="blue" style={{ fontSize: 10, margin: 0, lineHeight: '16px' }}>Hop{n}</Tag>
|
||||||
|
))}
|
||||||
|
<Text style={{ fontSize: 12 }}>{d.file_name || d.headers || d.file_id || '未知文件'}</Text>
|
||||||
|
</Space>
|
||||||
|
{sim && <Text type="secondary" style={{ fontSize: 11 }}>相似度 {sim}</Text>}
|
||||||
|
</div>
|
||||||
|
{d.paragraph_content && (
|
||||||
|
<div style={{ fontSize: 11, color: '#666', marginTop: 3, maxHeight: 48, overflow: 'hidden', lineHeight: 1.4 }}>
|
||||||
|
{d.paragraph_content.slice(0, 150)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
</Card>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
|
||||||
|
{detailResult.error && (
|
||||||
|
<Card size="small" title="错误" style={{ marginBottom: 12 }}>
|
||||||
|
<Text type="danger">{detailResult.error}</Text>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
})()}
|
||||||
|
</Drawer>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
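The per-hop rendering above marks which expected hops each retrieval step hit and converts `cosine_distance_1` into a similarity score. A minimal standalone sketch of that logic follows. The `ExpectedHop` and `RetrievedDoc` interfaces and both function names are assumptions made for illustration; the page itself types these values as `any` and computes them inline.

```typescript
// Hypothetical shapes for illustration; the page types these values as `any`.
interface ExpectedHop { file_id?: string }
interface RetrievedDoc { file_id?: string; cosine_distance_1?: number | null }

// 1-based indices of the expected hops whose file appears among the retrieved docs.
function hitHopNumbers(hops: ExpectedHop[], docs: RetrievedDoc[]): number[] {
  const nums: number[] = []
  hops.forEach((h, i) => {
    if (h.file_id && docs.some(d => d.file_id === h.file_id)) nums.push(i + 1)
  })
  return nums
}

// Cosine similarity as a 4-decimal string, or null when the distance is absent.
function similarityLabel(doc: RetrievedDoc): string | null {
  return doc.cosine_distance_1 != null ? (1 - doc.cosine_distance_1).toFixed(4) : null
}
```

Pulling the matching into small pure functions like these would make it testable without rendering the drawer, although the commit keeps it inside the JSX.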
1728
frontend/src/pages/QaGen/index.tsx
Normal file
File diff suppressed because it is too large
283
frontend/src/pages/Report/index.tsx
Normal file
@ -0,0 +1,283 @@
|
|||||||
|
import React, { useEffect, useState } from 'react'
|
||||||
|
import { Table, Button, Tabs, Tag, Statistic, Row, Col, Card, Drawer, Typography, Spin, Empty, Alert, Tooltip } from 'antd'
|
||||||
|
import { ArrowLeftOutlined, QuestionCircleOutlined } from '@ant-design/icons'
|
||||||
|
import { useParams, useNavigate } from 'react-router-dom'
|
||||||
|
import { Radar } from '@ant-design/charts'
|
||||||
|
import { reportApi, taskApi } from '../../services/api'
|
||||||
|
import { metricLabel, metricCn, METRICS } from '../../constants/metrics'
|
||||||
|
|
||||||
|
const { Text, Paragraph } = Typography
|
||||||
|
|
||||||
|
function MetricCard({ metricKey, value, color }: { metricKey: string; value: number | null; color: string }) {
|
||||||
|
const metric = METRICS[metricKey]
|
||||||
|
return (
|
||||||
|
<Card size="small" style={{ textAlign: 'center' }}>
|
||||||
|
<Statistic
|
||||||
|
title={
|
||||||
|
<div>
|
||||||
|
<div style={{ fontWeight: 500 }}>{metricLabel(metricKey)}</div>
|
||||||
|
{metric && (
|
||||||
|
<div style={{ fontSize: 11, color: '#888', fontWeight: 400, marginTop: 2, lineHeight: 1.4 }}>
|
||||||
|
{metric.desc}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
value={value != null ? (value * 100).toFixed(1) : 'N/A'}
|
||||||
|
suffix={value != null ? '%' : ''}
|
||||||
|
valueStyle={{ color, fontSize: 22 }}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function Report() {
|
||||||
|
const { taskId } = useParams<{ taskId: string }>()
|
||||||
|
const navigate = useNavigate()
|
||||||
|
const [report, setReport] = useState<any>(null)
|
||||||
|
const [items, setItems] = useState<any[]>([])
|
||||||
|
const [task, setTask] = useState<any>(null)
|
||||||
|
const [loading, setLoading] = useState(true)
|
||||||
|
const [drawer, setDrawer] = useState<any>(null)
|
||||||
|
|
||||||
|
  useEffect(() => {
    // Load the report, its per-sample items, and the task in parallel.
    Promise.all([
      reportApi.get(taskId!),
      reportApi.items(taskId!),
      taskApi.get(taskId!),
    ]).then(([r, i, t]: any[]) => {
      setReport(r.data)
      setItems(i.data?.records || [])
      setTask(t.data)
    }).catch(() => {
      // On failure, leave `report` null so the empty state below is rendered
      // instead of surfacing an unhandled promise rejection.
      setReport(null)
    }).finally(() => setLoading(false))
  }, [taskId])
|
||||||
|
|
||||||
|
if (loading) return <Spin style={{ display: 'block', marginTop: 80 }} />
|
||||||
|
if (!report) return <Empty description="报告不存在或任务尚未完成" style={{ marginTop: 80 }} />
|
||||||
|
|
||||||
|
const selectedMetrics = task?.selected_metrics || []
|
||||||
|
const shouldShow = (key: string) => selectedMetrics.length === 0 || selectedMetrics.includes(key)
|
||||||
|
|
||||||
|
// Radar chart data - only show selected metrics
|
||||||
|
const radarData = [
|
||||||
|
{ metric: metricCn('hit_rate'), value: report.avg_hit_rate ?? 0, key: 'hit_rate' },
|
||||||
|
{ metric: metricCn('mrr'), value: report.avg_mrr ?? 0, key: 'mrr' },
|
||||||
|
{ metric: metricCn('ndcg'), value: report.avg_ndcg ?? 0, key: 'ndcg' },
|
||||||
|
{ metric: metricCn('context_precision'), value: report.avg_context_precision ?? 0, key: 'context_precision' },
|
||||||
|
{ metric: metricCn('context_recall'), value: report.avg_context_recall ?? 0, key: 'context_recall' },
|
||||||
|
{ metric: metricCn('faithfulness'), value: report.avg_faithfulness ?? 0, key: 'faithfulness' },
|
||||||
|
{ metric: metricCn('answer_relevance'), value: report.avg_answer_relevance ?? 0, key: 'answer_relevance' },
|
||||||
|
{ metric: metricCn('answer_correctness'), value: report.avg_answer_correctness ?? 0, key: 'answer_correctness' },
|
||||||
|
{ metric: metricCn('groundedness'), value: report.avg_groundedness ?? 0, key: 'groundedness' },
|
||||||
|
].filter(d => d.value > 0 && shouldShow(d.key))
|
||||||
|
|
||||||
|
const radarConfig = {
|
||||||
|
data: radarData,
|
||||||
|
xField: 'metric',
|
||||||
|
yField: 'value',
|
||||||
|
area: { style: { fillOpacity: 0.3 } },
|
||||||
|
scale: { y: { domain: [0, 1] } },
|
||||||
|
axis: { y: { tickCount: 5 } },
|
||||||
|
height: 320,
|
||||||
|
}
|
||||||
|
|
||||||
|
const itemColumns = [
|
||||||
|
{ title: '问题', dataIndex: 'question', ellipsis: true, width: '25%' },
|
||||||
|
shouldShow('hit_rate') && {
|
||||||
|
title: metricCn('hit_rate'), dataIndex: 'hit_rate',
|
||||||
|
render: (v: number | null) => v != null ? <Tag color={v >= 0.8 ? 'green' : v >= 0.5 ? 'orange' : 'red'}>{(v * 100).toFixed(0)}%</Tag> : '-',
|
||||||
|
},
|
||||||
|
shouldShow('mrr') && {
|
||||||
|
title: metricCn('mrr'), dataIndex: 'mrr',
|
||||||
|
render: (v: number | null) => v != null ? (v).toFixed(3) : '-',
|
||||||
|
},
|
||||||
|
shouldShow('ndcg') && {
|
||||||
|
title: metricCn('ndcg'), dataIndex: 'ndcg',
|
||||||
|
render: (v: number | null) => v != null ? (v).toFixed(3) : '-',
|
||||||
|
},
|
||||||
|
shouldShow('faithfulness') && {
|
||||||
|
title: metricCn('faithfulness'), dataIndex: 'faithfulness',
|
||||||
|
render: (v: number | null) => v != null
|
||||||
|
? <Tag color={v >= 0.8 ? 'green' : v >= 0.6 ? 'orange' : 'red'}>{(v * 100).toFixed(0)}%</Tag>
|
||||||
|
: '-',
|
||||||
|
},
|
||||||
|
shouldShow('answer_relevance') && {
|
||||||
|
title: metricCn('answer_relevance'), dataIndex: 'answer_relevance',
|
||||||
|
render: (v: number | null) => v != null ? (v).toFixed(3) : '-',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '状态', dataIndex: 'error',
|
||||||
|
render: (v: string | null) => v ? <Tag color="red">失败</Tag> : <Tag color="green">正常</Tag>,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '详情',
|
||||||
|
render: (_: any, r: any) => (
|
||||||
|
<Button size="small" onClick={() => setDrawer(r)}>查看</Button>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
].filter(Boolean)
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<Button type="link" icon={<ArrowLeftOutlined />} onClick={() => navigate('/task')} style={{ paddingLeft: 0 }}>
|
||||||
|
返回任务列表
|
||||||
|
</Button>
|
||||||
|
|
||||||
|
<h2 style={{ marginTop: 8 }}>
|
||||||
|
评测报告 — {task?.name || taskId?.slice(0, 12)}
|
||||||
|
<Tag color="success" style={{ marginLeft: 12, fontSize: 14 }}>
|
||||||
|
{report.sample_count} 条样本
|
||||||
|
</Tag>
|
||||||
|
</h2>
|
||||||
|
|
||||||
|
{/* Composite scores */}
|
||||||
|
<Row gutter={16} style={{ marginBottom: 24 }}>
|
||||||
|
<Col span={6}>
|
||||||
|
<Card style={{ background: '#f0f5ff', border: '1px solid #adc6ff' }}>
|
||||||
|
<Statistic
|
||||||
|
title="RAG Score(综合评分)"
|
||||||
|
value={report.rag_score != null ? (report.rag_score * 100).toFixed(1) : 'N/A'}
|
||||||
|
suffix={report.rag_score != null ? '%' : ''}
|
||||||
|
valueStyle={{ color: '#1677ff', fontSize: 28 }}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
<Col span={6}>
|
||||||
|
<Card style={{ background: '#fff7e6', border: '1px solid #ffd591' }}>
|
||||||
|
<Statistic
|
||||||
|
title="幻觉发生率"
|
||||||
|
value={report.hallucination_rate != null ? (report.hallucination_rate * 100).toFixed(1) : 'N/A'}
|
||||||
|
suffix={report.hallucination_rate != null ? '%' : ''}
|
||||||
|
valueStyle={{ color: report.hallucination_rate > 0.2 ? '#cf1322' : '#389e0d', fontSize: 28 }}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
|
||||||
|
{/* Interpretation */}
|
||||||
|
{report.interpretation && (
|
||||||
|
<Alert
|
||||||
|
message="评测结果解读"
|
||||||
|
description={
|
||||||
|
<div style={{ whiteSpace: 'pre-wrap', lineHeight: 1.8 }}>
|
||||||
|
{report.interpretation}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
type="info"
|
||||||
|
showIcon
|
||||||
|
style={{ marginBottom: 24 }}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<Tabs
|
||||||
|
items={[
|
||||||
|
{
|
||||||
|
key: 'overview',
|
||||||
|
label: '指标总览',
|
||||||
|
children: (
|
||||||
|
<Row gutter={[16, 16]}>
|
||||||
|
<Col span={12}>
|
||||||
|
<Card title="雷达图" size="small">
|
||||||
|
{radarData.length > 0 ? <Radar {...radarConfig} /> : <Empty description="暂无数据" />}
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
<Col span={12}>
|
||||||
|
<Row gutter={[12, 12]}>
|
||||||
|
{shouldShow('hit_rate') && <Col span={12}><MetricCard metricKey="hit_rate" value={report.avg_hit_rate} color="#1677ff" /></Col>}
|
||||||
|
{shouldShow('mrr') && <Col span={12}><MetricCard metricKey="mrr" value={report.avg_mrr} color="#1677ff" /></Col>}
|
||||||
|
{shouldShow('ndcg') && <Col span={12}><MetricCard metricKey="ndcg" value={report.avg_ndcg} color="#1677ff" /></Col>}
|
||||||
|
{shouldShow('context_precision') && <Col span={12}><MetricCard metricKey="context_precision" value={report.avg_context_precision} color="#722ed1" /></Col>}
|
||||||
|
{shouldShow('context_recall') && <Col span={12}><MetricCard metricKey="context_recall" value={report.avg_context_recall} color="#722ed1" /></Col>}
|
||||||
|
{shouldShow('faithfulness') && <Col span={12}><MetricCard metricKey="faithfulness" value={report.avg_faithfulness} color="#52c41a" /></Col>}
|
||||||
|
{shouldShow('answer_relevance') && <Col span={12}><MetricCard metricKey="answer_relevance" value={report.avg_answer_relevance} color="#52c41a" /></Col>}
|
||||||
|
{shouldShow('answer_correctness') && <Col span={12}><MetricCard metricKey="answer_correctness" value={report.avg_answer_correctness} color="#52c41a" /></Col>}
|
||||||
|
{shouldShow('groundedness') && <Col span={12}><MetricCard metricKey="groundedness" value={report.avg_groundedness} color="#fa8c16" /></Col>}
|
||||||
|
</Row>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 'items',
|
||||||
|
label: `样本明细 (${items.length})`,
|
||||||
|
children: (
|
||||||
|
<Table
|
||||||
|
rowKey="id"
|
||||||
|
dataSource={items}
|
||||||
|
columns={itemColumns}
|
||||||
|
size="small"
|
||||||
|
scroll={{ x: 900 }}
|
||||||
|
rowClassName={(r) => r.error ? 'ant-table-row-error' : ''}
|
||||||
|
/>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* Sample detail drawer */}
|
||||||
|
<Drawer
|
||||||
|
title="样本详情"
|
||||||
|
open={!!drawer}
|
||||||
|
onClose={() => setDrawer(null)}
|
||||||
|
width={640}
|
||||||
|
>
|
||||||
|
{drawer && (
|
||||||
|
<div>
|
||||||
|
<Paragraph><Text strong>问题:</Text>{drawer.question}</Paragraph>
|
||||||
|
<Paragraph><Text strong>参考答案:</Text>{drawer.reference_answer}</Paragraph>
|
||||||
|
<Paragraph><Text strong>Agent 回答:</Text>{drawer.agent_answer || '-'}</Paragraph>
|
||||||
|
|
||||||
|
{(shouldShow('hit_rate') || shouldShow('mrr') || shouldShow('ndcg') || shouldShow('context_precision') || shouldShow('context_recall')) && (
|
||||||
|
<Card title="检索指标" size="small" style={{ marginBottom: 12 }}>
|
||||||
|
<Row gutter={16}>
|
||||||
|
{[
|
||||||
|
shouldShow('hit_rate') && [metricLabel('hit_rate'), drawer.hit_rate],
|
||||||
|
shouldShow('mrr') && [metricLabel('mrr'), drawer.mrr],
|
||||||
|
shouldShow('ndcg') && [metricLabel('ndcg'), drawer.ndcg],
|
||||||
|
shouldShow('context_precision') && [metricLabel('context_precision'), drawer.context_precision],
|
||||||
|
shouldShow('context_recall') && [metricLabel('context_recall'), drawer.context_recall],
|
||||||
|
].filter(Boolean).map(([k, v]) => (
|
||||||
|
<Col span={8} key={k as string}>
|
||||||
|
<Statistic title={k as string} value={v != null ? (v as number).toFixed(3) : 'N/A'} />
|
||||||
|
</Col>
|
||||||
|
))}
|
||||||
|
</Row>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{(shouldShow('faithfulness') || shouldShow('answer_relevance') || shouldShow('answer_correctness') || shouldShow('groundedness')) && (
|
||||||
|
<Card title="生成指标" size="small" style={{ marginBottom: 12 }}>
|
||||||
|
<Row gutter={16}>
|
||||||
|
{[
|
||||||
|
shouldShow('faithfulness') && [metricLabel('faithfulness'), drawer.faithfulness],
|
||||||
|
shouldShow('answer_relevance') && [metricLabel('answer_relevance'), drawer.answer_relevance],
|
||||||
|
shouldShow('answer_correctness') && [metricLabel('answer_correctness'), drawer.answer_correctness],
|
||||||
|
shouldShow('groundedness') && [metricLabel('groundedness'), drawer.groundedness],
|
||||||
|
].filter(Boolean).map(([k, v]) => (
|
||||||
|
<Col span={6} key={k as string}>
|
||||||
|
<Statistic title={k as string} value={v != null ? (v as number).toFixed(3) : 'N/A'} />
|
||||||
|
</Col>
|
||||||
|
))}
|
||||||
|
</Row>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{drawer.judge_detail && Object.keys(drawer.judge_detail).length > 0 && (
|
||||||
|
<Card title="Judge 推理过程" size="small">
|
||||||
|
<pre style={{ fontSize: 12, maxHeight: 300, overflow: 'auto', background: '#f5f5f5', padding: 8 }}>
|
||||||
|
{JSON.stringify(drawer.judge_detail, null, 2)}
|
||||||
|
</pre>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{drawer.error && (
|
||||||
|
<Card title="错误信息" size="small" style={{ borderColor: '#ff4d4f' }}>
|
||||||
|
<Text type="danger">{drawer.error}</Text>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Drawer>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
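The report page shows a metric only if the task selected it, treating an empty selection as "show everything", and the radar chart additionally drops metrics whose average is 0 or missing. A small sketch of that filtering, under assumed names: `makeShouldShow` and `radarPoints` are hypothetical helpers written for this illustration and are not part of the commit, which inlines the same checks.

```typescript
// Hypothetical helper: an empty selection means every metric is shown.
function makeShouldShow(selected: string[]): (key: string) => boolean {
  return key => selected.length === 0 || selected.includes(key)
}

interface RadarPoint { key: string; value: number }

// Build radar points from `avg_<key>` fields, dropping unselected or zero-valued metrics.
function radarPoints(
  report: Record<string, number | null | undefined>,
  keys: string[],
  selected: string[],
): RadarPoint[] {
  const shouldShow = makeShouldShow(selected)
  return keys
    .map(key => ({ key, value: report[`avg_${key}`] ?? 0 }))
    .filter(p => p.value > 0 && shouldShow(p.key))
}
```

One consequence of the `value > 0` filter is that a metric whose true average is exactly 0 disappears from the chart, just like a metric that was never computed.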
762
frontend/src/pages/SingleJump/index.tsx
Normal file
@ -0,0 +1,762 @@
|
|||||||
|
import React, { useEffect, useState, useRef } from 'react'
|
||||||
|
import {
|
||||||
|
Table, Button, Modal, Form, Input, InputNumber, Switch, Upload,
|
||||||
|
message, Popconfirm, Tag, Space, Progress, Tooltip, Drawer,
|
||||||
|
Row, Col, Card, Statistic, Divider, Typography, Empty, Spin, Select
|
||||||
|
} from 'antd'
|
||||||
|
import {
|
||||||
|
PlusOutlined, DeleteOutlined, EyeOutlined, ReloadOutlined,
|
||||||
|
UploadOutlined, QuestionCircleOutlined, CheckCircleOutlined,
|
||||||
|
CloseCircleOutlined, WarningOutlined, DownloadOutlined
|
||||||
|
} from '@ant-design/icons'
|
||||||
|
import { singleJumpApi, multiHopApi } from '../../services/api'
|
||||||
|
|
||||||
|
const { Text, Paragraph } = Typography
|
||||||
|
|
||||||
|
// ── Metric descriptions ──────────────────────────────────────────────────────
|
||||||
|
const METRIC_TIPS: Record<string, string> = {
|
||||||
|
recall_rate: '有召回结果的问题数 / 总问题数。越高说明知识库覆盖越全面。',
|
||||||
|
file_hit_rate: '召回结果中包含预期文件的问题数 / 有召回结果的问题数。越高说明单跳定位越准确。',
|
||||||
|
avg_cosine_sim: '召回结果与问题的平均余弦相似度(0~1)。越高说明语义匹配越好。',
|
||||||
|
avg_latency_ms: '每次召回的平均耗时(毫秒)。',
|
||||||
|
section_match_rate: '成功映射到知识库文件的章节数 / 总章节数。',
|
||||||
|
}
|
||||||
|
|
||||||
|
function MetricTip({ metricKey }: { metricKey: string }) {
|
||||||
|
return METRIC_TIPS[metricKey] ? (
|
||||||
|
<Tooltip title={METRIC_TIPS[metricKey]}>
|
||||||
|
<QuestionCircleOutlined style={{ marginLeft: 4, color: '#999', fontSize: 12 }} />
|
||||||
|
</Tooltip>
|
||||||
|
) : null
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Status tag ───────────────────────────────────────────────────────────────
|
||||||
|
function StatusTag({ status }: { status: string }) {
|
||||||
|
const map: Record<string, { color: string; label: string }> = {
|
||||||
|
pending: { color: 'default', label: '等待中' },
|
||||||
|
running: { color: 'processing', label: '运行中' },
|
||||||
|
done: { color: 'success', label: '完成' },
|
||||||
|
failed: { color: 'error', label: '失败' },
|
||||||
|
}
|
||||||
|
const cfg = map[status] || { color: 'default', label: status }
|
||||||
|
return <Tag color={cfg.color}>{cfg.label}</Tag>
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Summary cards ────────────────────────────────────────────────────────────
|
||||||
|
function SummaryCards({ summary }: { summary: any }) {
|
||||||
|
const pct = (v: number | null) => v != null ? `${(v * 100).toFixed(1)}%` : 'N/A'
|
||||||
|
const cards = [
|
||||||
|
{ key: 'recall_rate', label: '召回率', value: pct(summary.recall_rate), color: summary.recall_rate >= 0.8 ? '#52c41a' : '#faad14' },
|
||||||
|
{ key: 'file_hit_rate', label: '文件命中率', value: pct(summary.file_hit_rate), color: summary.file_hit_rate >= 0.7 ? '#52c41a' : '#faad14' },
|
||||||
|
{ key: 'avg_cosine_sim', label: '平均余弦相似度', value: summary.avg_cosine_sim != null ? summary.avg_cosine_sim.toFixed(4) : 'N/A', color: '#1677ff' },
|
||||||
|
{ key: 'avg_latency_ms', label: '平均延迟', value: summary.avg_latency_ms != null ? `${summary.avg_latency_ms.toFixed(0)}ms` : 'N/A', color: '#722ed1' },
|
||||||
|
{ key: 'section_match_rate', label: '章节匹配率', value: summary.total_sections ? `${summary.matched_sections}/${summary.total_sections}` : 'N/A', color: '#13c2c2' },
|
||||||
|
]
|
||||||
|
return (
|
||||||
|
<Row gutter={[12, 12]} style={{ marginBottom: 16 }}>
|
||||||
|
{cards.map(c => (
|
||||||
|
<Col key={c.key} xs={12} sm={8} md={6} lg={4}>
|
||||||
|
<Card size="small" style={{ textAlign: 'center' }}>
|
||||||
|
<div style={{ fontSize: 22, fontWeight: 600, color: c.color }}>{c.value}</div>
|
||||||
|
<div style={{ fontSize: 12, color: '#666', marginTop: 4 }}>
|
||||||
|
{c.label}<MetricTip metricKey={c.key} />
|
||||||
|
</div>
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
))}
|
||||||
|
<Col xs={12} sm={8} md={6} lg={4}>
|
||||||
|
<Card size="small" style={{ textAlign: 'center' }}>
|
||||||
|
<div style={{ fontSize: 22, fontWeight: 600 }}>{summary.total_questions ?? '-'}</div>
|
||||||
|
<div style={{ fontSize: 12, color: '#666', marginTop: 4 }}>总问题数</div>
|
||||||
|
</Card>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Main page ────────────────────────────────────────────────────────────────
|
||||||
|
export default function SingleJump() {
|
||||||
|
const [tasks, setTasks] = useState<any[]>([])
|
||||||
|
const [loading, setLoading] = useState(false)
|
||||||
|
const [createModal, setCreateModal] = useState(false)
|
||||||
|
const [form] = Form.useForm()
|
||||||
|
const [fileList, setFileList] = useState<any[]>([])
|
||||||
|
const [folderFiles, setFolderFiles] = useState<any[]>([])
|
||||||
|
const [submitting, setSubmitting] = useState(false)
|
||||||
|
const mergedFileRef = useRef<File | null>(null)
|
||||||
|
const [selectedRowKeys, setSelectedRowKeys] = useState<React.Key[]>([])
|
||||||
|
|
||||||
|
// Report drawer
|
||||||
|
const [reportDrawer, setReportDrawer] = useState<string | null>(null)
|
||||||
|
const [summary, setSummary] = useState<any>(null)
|
||||||
|
const [sections, setSections] = useState<any[]>([])
|
||||||
|
const [selectedSection, setSelectedSection] = useState<string | null>(null)
|
||||||
|
const [results, setResults] = useState<any[]>([])
|
||||||
|
const [resultLoading, setResultLoading] = useState(false)
|
||||||
|
const [detailDrawer, setDetailDrawer] = useState<any>(null)
|
||||||
|
const [agentIdForRecall, setAgentIdForRecall] = useState('')
|
||||||
|
const [agentRecallLoading, setAgentRecallLoading] = useState(false)
|
||||||
|
const [agentRecallItems, setAgentRecallItems] = useState<any[]>([])
|
||||||
|
const [agentOptions, setAgentOptions] = useState<{ label: string; value: string }[]>([])
|
||||||
|
const [agentOptionsLoading, setAgentOptionsLoading] = useState(false)
|
||||||
|
// Agent options shown when creating a task
|
||||||
|
const [createAgentOptions, setCreateAgentOptions] = useState<{ label: string; value: string }[]>([])
|
||||||
|
const [createAgentOptionsLoading, setCreateAgentOptionsLoading] = useState(false)
|
||||||
|
const orgIdValue = Form.useWatch('org_id', form)
|
||||||
|
const envUrlValue = Form.useWatch('env_url', form)
|
||||||
|
|
||||||
|
const pollingRef = useRef<ReturnType<typeof setInterval> | null>(null)
|
||||||
|
|
||||||
|
const loadTasks = async () => {
|
||||||
|
setLoading(true)
|
||||||
|
try {
|
||||||
|
const res = await singleJumpApi.listTasks() as any
|
||||||
|
setTasks(res.data || [])
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
  // Keep the latest task list in a ref so the polling timer can read it
  // without running side effects inside a state updater.
  const tasksRef = useRef<any[]>([])
  tasksRef.current = tasks

  useEffect(() => {
    loadTasks()
    // Poll every 3 seconds, but only refetch while some task is still active.
    pollingRef.current = setInterval(() => {
      const hasRunning = tasksRef.current.some(t => t.status === 'running' || t.status === 'pending')
      if (hasRunning) loadTasks()
    }, 3000)
    return () => { if (pollingRef.current) clearInterval(pollingRef.current) }
  }, [])
|
||||||
|
|
||||||
|
const handleCreate = async () => {
|
||||||
|
const vals = await form.validateFields()
|
||||||
|
if (!fileList.length && !mergedFileRef.current) {
|
||||||
|
message.error('请上传问答集文件或选择文件夹');
|
||||||
|
return
|
||||||
|
}
|
||||||
|
setSubmitting(true)
|
||||||
|
try {
|
||||||
|
const fd = new FormData()
|
||||||
|
|
||||||
|
// Folder uploads use the merged file; single-file uploads use the original file
|
||||||
|
const uploadFile = mergedFileRef.current || fileList[0].originFileObj
|
||||||
|
fd.append('file', uploadFile)
|
||||||
|
fd.append('name', vals.name || (folderFiles.length > 0 ? `批量任务(${folderFiles.length}个文件)` : ''))
|
||||||
|
fd.append('env_url', vals.env_url)
|
||||||
|
fd.append('org_id', vals.org_id)
|
||||||
|
fd.append('d_user_id', vals.d_user_id || 'test')
|
||||||
|
fd.append('agent_id', vals.agent_id || '')
|
||||||
|
fd.append('top_k', String(vals.top_k ?? 64))
|
||||||
|
fd.append('recall_top_k', String(vals.recall_top_k ?? 64))
|
||||||
|
fd.append('concurrency', String(vals.concurrency ?? 5))
|
||||||
|
fd.append('cross_chunk', String(vals.cross_chunk ?? true))
|
||||||
|
|
||||||
|
await singleJumpApi.createTask(fd)
|
||||||
|
|
||||||
|
message.success('任务已创建,正在后台运行')
|
||||||
|
setCreateModal(false)
|
||||||
|
form.resetFields()
|
||||||
|
setFileList([])
|
||||||
|
setFolderFiles([])
|
||||||
|
mergedFileRef.current = null
|
||||||
|
loadTasks()
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(e?.message || '创建失败')
|
||||||
|
} finally {
|
||||||
|
setSubmitting(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleFolderSelect = async (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
|
const files = Array.from(e.target.files || [])
|
||||||
|
const mdFiles = files.filter(f => f.name.endsWith('.md'))
|
||||||
|
if (mdFiles.length === 0) {
|
||||||
|
message.warning('文件夹中没有 MD 文件')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Read all files in parallel on the client and merge them into a single File to avoid a slow multi-part upload
|
||||||
|
const texts = await Promise.all(mdFiles.map(f => f.text()))
|
||||||
|
const merged = new File([texts.join('\n')], `batch_${mdFiles.length}files.md`, { type: 'text/markdown' })
|
||||||
|
mergedFileRef.current = merged
|
||||||
|
setFolderFiles(mdFiles)
|
||||||
|
setFileList([])
|
||||||
|
message.success(`已选择 ${mdFiles.length} 个 MD 文件,将合并为单个文件上传`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Batch delete ─────────────────────────────────────────────────────────────
|
||||||
|
const handleBatchDelete = async () => {
|
||||||
|
if (selectedRowKeys.length === 0) {
|
||||||
|
message.warning('请先选择要删除的任务')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Modal.confirm({
|
||||||
|
title: `确认删除选中的 ${selectedRowKeys.length} 个任务?`,
|
||||||
|
content: '删除后将无法恢复,相关测试结果也会被删除。',
|
||||||
|
okText: '确认删除',
|
||||||
|
okType: 'danger',
|
||||||
|
cancelText: '取消',
|
||||||
|
async onOk() {
|
||||||
|
try {
|
||||||
|
await Promise.all(selectedRowKeys.map(id => singleJumpApi.deleteTask(id as string)))
|
||||||
|
message.success(`成功删除 ${selectedRowKeys.length} 个任务`)
|
||||||
|
setSelectedRowKeys([])
|
||||||
|
loadTasks()
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(e?.message || '批量删除失败')
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleExportFailed = () => {
|
||||||
|
if (!reportDrawer) return
|
||||||
|
const url = singleJumpApi.exportFailedMd(reportDrawer)
|
||||||
|
window.open(url, '_blank')
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleExportFileMiss = () => {
|
||||||
|
if (!reportDrawer) return
|
||||||
|
const url = singleJumpApi.exportFileMissMd(reportDrawer)
|
||||||
|
window.open(url, '_blank')
|
||||||
|
}
|
||||||
|
|
||||||
|
const openReport = async (taskId: string) => {
|
||||||
|
setReportDrawer(taskId)
|
||||||
|
setSummary(null)
|
||||||
|
setSections([])
|
||||||
|
setSelectedSection(null)
|
||||||
|
setResults([])
|
||||||
|
try {
|
||||||
|
const [sumRes, secRes] = await Promise.all([
|
||||||
|
singleJumpApi.getSummary(taskId) as any,
|
||||||
|
singleJumpApi.getSections(taskId) as any,
|
||||||
|
])
|
||||||
|
setSummary(sumRes.data)
|
||||||
|
setSections(secRes.data || [])
|
||||||
|
} catch (e: any) {
|
||||||
|
setReportDrawer(null)
|
||||||
|
message.error(e?.response?.data?.detail || e?.message || '加载测试报告失败')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const loadResults = async (taskId: string, section: string | null) => {
|
||||||
|
setResultLoading(true)
|
||||||
|
try {
|
||||||
|
const res = await singleJumpApi.getResults(taskId, section || undefined) as any
|
||||||
|
setResults(res.data || [])
|
||||||
|
} finally {
|
||||||
|
setResultLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleSectionChange = (val: string | null) => {
|
||||||
|
setSelectedSection(val)
|
||||||
|
if (reportDrawer) loadResults(reportDrawer, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
const openDetail = (row: any) => {
|
||||||
|
setDetailDrawer(row)
|
||||||
|
setAgentRecallItems([])
|
||||||
|
if (reportDrawer) loadAgentOptions(reportDrawer)
|
||||||
|
}
|
||||||
|
|
||||||
|
const loadAgentOptions = async (taskId: string) => {
|
||||||
|
setAgentOptionsLoading(true)
|
||||||
|
try {
|
||||||
|
const res = await singleJumpApi.listAgents(taskId) as any
|
||||||
|
const opts = (res?.data || []).map((a: any) => ({ label: `${a.name} (${a.id.slice(0, 8)}...)`, value: a.id }))
|
||||||
|
setAgentOptions(opts)
|
||||||
|
} catch {
|
||||||
|
setAgentOptions([])
|
||||||
|
} finally {
|
||||||
|
setAgentOptionsLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const loadAgentRecall = async () => {
|
||||||
|
if (!reportDrawer || !detailDrawer?.id) return
|
||||||
|
if (!agentIdForRecall.trim()) {
|
||||||
|
message.warning('请先填写 Agent ID')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
setAgentRecallLoading(true)
|
||||||
|
try {
|
||||||
|
const res = await singleJumpApi.getAgentRecall(reportDrawer, detailDrawer.id, agentIdForRecall.trim()) as any
|
||||||
|
setAgentRecallItems(res?.data?.items || [])
|
||||||
|
message.success(`已拉取 ${res?.data?.items?.length || 0} 条在线 Agent 召回结果`)
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(e?.response?.data?.detail || e?.message || '拉取在线 Agent 召回失败')
|
||||||
|
} finally {
|
||||||
|
setAgentRecallLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the agent list for the create-task form
|
||||||
|
const loadCreateAgentOptions = async () => {
|
||||||
|
if (!orgIdValue || !envUrlValue) return
|
||||||
|
setCreateAgentOptionsLoading(true)
|
||||||
|
try {
|
||||||
|
const res = await multiHopApi.listDagentAgents(envUrlValue, orgIdValue) as any
|
||||||
|
const opts = (res?.data || []).map((a: any) => ({ label: `${a.name} (${a.id.slice(0, 8)}...)`, value: a.id }))
|
||||||
|
setCreateAgentOptions(opts)
|
||||||
|
} catch {
|
||||||
|
setCreateAgentOptions([])
|
||||||
|
} finally {
|
||||||
|
setCreateAgentOptionsLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload the agent list whenever org_id or env_url changes while the create modal is open
|
||||||
|
useEffect(() => {
|
||||||
|
if (orgIdValue && envUrlValue && createModal) {
|
||||||
|
loadCreateAgentOptions()
|
||||||
|
}
|
||||||
|
}, [orgIdValue, envUrlValue, createModal])
|
||||||
|
|
||||||
|
// ── Task list columns ────────────────────────────────────────────────────────
|
||||||
|
const taskColumns = [
|
||||||
|
{ title: '任务名称', dataIndex: 'name', ellipsis: true, width: 180 },
|
||||||
|
{ title: '环境地址', dataIndex: 'env_url', ellipsis: true },
|
||||||
|
{ title: 'Org ID', dataIndex: 'org_id', ellipsis: true, width: 160,
|
||||||
|
render: (v: string) => <Text code style={{ fontSize: 11 }}>{v?.slice(0, 16)}…</Text> },
|
||||||
|
{ title: '状态', dataIndex: 'status', width: 90, render: (v: string) => <StatusTag status={v} /> },
|
||||||
|
{
|
||||||
|
title: '进度', width: 140,
|
||||||
|
render: (_: any, r: any) => r.status === 'running'
|
||||||
|
? <Progress percent={r.total ? Math.round(r.progress / r.total * 100) : 0} size="small" />
|
||||||
|
: r.status === 'done' ? <Text type="success">{r.total} 条完成</Text>
|
||||||
|
: r.status === 'failed' ? <Tooltip title={r.error_message}><Text type="danger">失败</Text></Tooltip>
|
||||||
|
: <Text type="secondary">-</Text>
|
||||||
|
},
|
||||||
|
{ title: '创建时间', dataIndex: 'created_at', width: 160,
|
||||||
|
render: (v: string) => v?.slice(0, 19) },
|
||||||
|
{
|
||||||
|
title: '操作', width: 120,
|
||||||
|
render: (_: any, r: any) => (
|
||||||
|
<Space>
|
||||||
|
<Button size="small" icon={<EyeOutlined />} disabled={r.status !== 'done'}
|
||||||
|
onClick={() => openReport(r.id)}>报告</Button>
|
||||||
|
<Popconfirm title="确认删除?" onConfirm={async () => {
|
||||||
|
await singleJumpApi.deleteTask(r.id)
|
||||||
|
loadTasks()
|
||||||
|
}}>
|
||||||
|
<Button size="small" danger icon={<DeleteOutlined />} />
|
||||||
|
</Popconfirm>
|
||||||
|
</Space>
|
||||||
|
)
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
// ── Section list columns ─────────────────────────────────────────────────────
|
||||||
|
const sectionColumns = [
|
||||||
|
{ title: '章节路径', dataIndex: 'section_path', ellipsis: true },
|
||||||
|
{ title: '对应文件', dataIndex: 'file_name', ellipsis: true,
|
||||||
|
render: (v: string) => v
|
||||||
|
? <Tooltip title={v}><Text code style={{ fontSize: 11 }}>{v}</Text></Tooltip>
|
||||||
|
: <Text type="secondary">未匹配</Text>
|
||||||
|
},
|
||||||
|
{ title: '匹配方式', dataIndex: 'match_type', width: 90,
|
||||||
|
render: (v: string) => v
|
||||||
|
? <Tag color={v === 'exact' ? 'green' : v === 'path_contains' ? 'blue' : v === 'basename' ? 'cyan' : 'orange'}>{v}</Tag>
|
||||||
|
: <Tag color="red">未匹配</Tag>
|
||||||
|
},
|
||||||
|
{ title: '问题数', dataIndex: 'total', width: 70 },
|
||||||
|
{ title: '召回数', dataIndex: 'recalled', width: 70,
|
||||||
|
render: (v: number, r: any) => <Text type={v === r.total ? 'success' : 'warning'}>{v}</Text> },
|
||||||
|
{ title: '文件命中', dataIndex: 'file_hits', width: 80,
|
||||||
|
render: (v: number, r: any) => r.recalled
|
||||||
|
? <Text type={v / r.recalled >= 0.7 ? 'success' : 'warning'}>{v}/{r.recalled}</Text>
|
||||||
|
: '-'
|
||||||
|
},
|
||||||
|
{ title: '平均相似度', dataIndex: 'avg_sim', width: 100,
|
||||||
|
render: (v: number) => v != null ? v.toFixed(4) : '-' },
|
||||||
|
{
|
||||||
|
title: '操作', width: 80,
|
||||||
|
render: (_: any, r: any) => (
|
||||||
|
<Button size="small" type="link" onClick={() => handleSectionChange(r.section_path)}>
|
||||||
|
查看问题
|
||||||
|
</Button>
|
||||||
|
)
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
// ── Question result columns ─────────────────────────────────────────────────
|
||||||
|
const resultColumns = [
|
||||||
|
{ title: 'ID', dataIndex: 'qid', width: 60 },
|
||||||
|
{ title: '问题', dataIndex: 'question', ellipsis: true },
|
||||||
|
{
|
||||||
|
title: '召回状态', width: 90,
|
||||||
|
render: (_: any, r: any) => r.error
|
||||||
|
? <Tooltip title={r.error}><Tag color="red" icon={<CloseCircleOutlined />}>错误</Tag></Tooltip>
|
||||||
|
: r.retrieved?.length
|
||||||
|
? <Tag color="green" icon={<CheckCircleOutlined />}>{r.retrieved.length} 条</Tag>
|
||||||
|
: <Tag color="orange" icon={<WarningOutlined />}>空</Tag>
|
||||||
|
},
|
||||||
|
{ title: '文件命中', dataIndex: 'is_file_hit', width: 80,
|
||||||
|
render: (v: number, r: any) => !r.file_id ? <Text type="secondary">-</Text>
|
||||||
|
: v ? <Tag color="green">命中</Tag> : <Tag color="orange">未命中</Tag>
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '切片命中', width: 200,
|
||||||
|
render: (_: any, r: any) => {
|
||||||
|
if (!r.expected_chunk_id) return <Text type="secondary">-</Text>
|
||||||
|
const chunkName = r.expected_chunk_name || r.expected_chunk_id?.slice(0, 16) + '...'
|
||||||
|
if (r.is_chunk_hit) {
|
||||||
|
return <Tooltip title={chunkName}><Tag color="green">{chunkName.slice(0, 20)} 命中(Top{r.chunk_hit_rank})</Tag></Tooltip>
|
||||||
|
}
|
||||||
|
return <Tooltip title={chunkName}><Tag color="orange">{chunkName.slice(0, 20)} 未命中</Tag></Tooltip>
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ title: 'Top1召回文件', width: 180,
|
||||||
|
render: (_: any, r: any) => {
|
||||||
|
const top1 = r.retrieved?.[0]
|
||||||
|
const fileName = top1?.display_file_name || top1?.file_name
|
||||||
|
return fileName ? <Tooltip title={fileName}><Text code style={{ fontSize: 11 }}>{fileName}</Text></Tooltip> : <Text type="secondary">-</Text>
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ title: '最佳相似度', dataIndex: 'best_cosine_sim', width: 100,
|
||||||
|
render: (v: number) => v != null
|
||||||
|
? <Text type={v >= 0.8 ? 'success' : v >= 0.6 ? 'warning' : 'danger'}>{v.toFixed(4)}</Text>
|
||||||
|
: '-'
|
||||||
|
},
|
||||||
|
{ title: '延迟', dataIndex: 'latency_ms', width: 70,
|
||||||
|
render: (v: number) => v != null ? `${v}ms` : '-' },
|
||||||
|
{
|
||||||
|
title: '详情', width: 60,
|
||||||
|
render: (_: any, r: any) => (
|
||||||
|
<Button size="small" type="link" onClick={() => openDetail(r)}>查看</Button>
|
||||||
|
)
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{/* Title bar */}
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 16 }}>
|
||||||
|
<Typography.Title level={4} style={{ margin: 0 }}>单跳召回测试</Typography.Title>
|
||||||
|
<Space>
|
||||||
|
{selectedRowKeys.length > 0 && (
|
||||||
|
<Button danger icon={<DeleteOutlined />} onClick={handleBatchDelete}>
|
||||||
|
批量删除 ({selectedRowKeys.length})
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
<Button icon={<ReloadOutlined />} onClick={loadTasks}>刷新</Button>
|
||||||
|
<Button type="primary" icon={<PlusOutlined />} onClick={() => setCreateModal(true)}>新建测试</Button>
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Table
|
||||||
|
rowKey="id"
|
||||||
|
dataSource={tasks}
|
||||||
|
columns={taskColumns}
|
||||||
|
loading={loading}
|
||||||
|
size="small"
|
||||||
|
pagination={{ pageSize: 10 }}
|
||||||
|
rowSelection={{
|
||||||
|
selectedRowKeys,
|
||||||
|
onChange: setSelectedRowKeys,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* New task modal */}
|
||||||
|
<Modal
|
||||||
|
title="新建单跳召回测试"
|
||||||
|
open={createModal}
|
||||||
|
onOk={handleCreate}
|
||||||
|
onCancel={() => { setCreateModal(false); form.resetFields(); setFileList([]) }}
|
||||||
|
confirmLoading={submitting}
|
||||||
|
width={560}
|
||||||
|
>
|
||||||
|
<Form form={form} layout="vertical" initialValues={{ top_k: 64, recall_top_k: 64, concurrency: 20, cross_chunk: true, d_user_id: 'test' }}>
|
||||||
|
<Form.Item name="name" label="任务名称">
|
||||||
|
<Input placeholder="可选,默认使用文件名" />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="env_url" label="Agent 环境地址" rules={[{ required: true }]}
|
||||||
|
tooltip="dagent 服务地址,如 https://dagent.d-robotics.cc">
|
||||||
|
<Input placeholder="https://dagent.d-robotics.cc" />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="org_id" label="Org ID" rules={[{ required: true }]}
|
||||||
|
tooltip="知识库所属的组织 ID">
|
||||||
|
<Input placeholder="a4d49699ba313815..." />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="agent_id" label="Agent(可选)"
|
||||||
|
tooltip="选择要使用的 Agent 版本进行召回测试,为空时直接调用知识库搜索 API">
|
||||||
|
<Select
|
||||||
|
placeholder="请选择 Agent(可选)"
|
||||||
|
allowClear
|
||||||
|
showSearch
|
||||||
|
options={createAgentOptions}
|
||||||
|
loading={createAgentOptionsLoading}
|
||||||
|
disabled={!orgIdValue || !envUrlValue}
|
||||||
|
notFoundContent={!orgIdValue || !envUrlValue ? '请先填写 Org ID 和环境地址' : '未找到 Agent'}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="d_user_id" label="User ID"
|
||||||
|
tooltip="请求头 d-user-id,默认 test">
|
||||||
|
<Input />
|
||||||
|
</Form.Item>
|
||||||
|
<Row gutter={12}>
|
||||||
|
<Col span={6}>
|
||||||
|
<Form.Item name="top_k" label={<span>命中判断 Top K <MetricTip metricKey="recall_rate" /></span>}>
|
||||||
|
<InputNumber min={1} max={200} style={{ width: '100%' }} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={6}>
|
||||||
|
<Form.Item name="recall_top_k" label={<span>召回数量 Top K <Tooltip title="调用召回API时请求的结果数量,建议设置较大值以获取更多召回切片用于分析"><QuestionCircleOutlined style={{ color: '#999' }} /></Tooltip></span>}>
|
||||||
|
<InputNumber min={1} max={500} style={{ width: '100%' }} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={6}>
|
||||||
|
<Form.Item name="concurrency" label="并发数">
|
||||||
|
<InputNumber min={1} max={50} style={{ width: '100%' }} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={6}>
|
||||||
|
<Form.Item name="cross_chunk" label={<span>跨切片模式 <Tooltip title="关闭后限定在对应文件内召回(当前 dagent 版本建议开启)"><QuestionCircleOutlined style={{ color: '#999' }} /></Tooltip></span>} valuePropName="checked">
|
||||||
|
<Switch />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
<Form.Item label="问答集文件(MD 格式)" required>
|
||||||
|
<Space direction="vertical" style={{ width: '100%' }}>
|
||||||
|
<Space wrap>
|
||||||
|
<Upload
|
||||||
|
accept=".md"
|
||||||
|
maxCount={1}
|
||||||
|
fileList={fileList}
|
||||||
|
beforeUpload={() => false}
|
||||||
|
onChange={({ fileList: fl }) => { setFileList(fl); setFolderFiles([]) }}
|
||||||
|
>
|
||||||
|
<Button icon={<UploadOutlined />}>选择单个文件</Button>
|
||||||
|
</Upload>
|
||||||
|
<label>
|
||||||
|
<Button
|
||||||
|
icon={<UploadOutlined />}
|
||||||
|
onClick={() => document.getElementById('folder-input')?.click()}
|
||||||
|
>
|
||||||
|
选择文件夹
|
||||||
|
</Button>
|
||||||
|
<input
|
||||||
|
id="folder-input"
|
||||||
|
type="file"
|
||||||
|
style={{ display: 'none' }}
|
||||||
|
// @ts-ignore
|
||||||
|
webkitdirectory=""
|
||||||
|
multiple
|
||||||
|
onChange={handleFolderSelect}
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
</Space>
|
||||||
|
{folderFiles.length > 0 && (
|
||||||
|
<div style={{ fontSize: 12, color: '#1677ff' }}>
|
||||||
|
已选择文件夹,共 {folderFiles.length} 个 MD 文件:
|
||||||
|
{folderFiles.slice(0, 5).map(f => (
|
||||||
|
<div key={f.webkitRelativePath || f.name} style={{ color: '#666', paddingLeft: 8 }}>· {f.webkitRelativePath || f.name}</div>
|
||||||
|
))}
|
||||||
|
{folderFiles.length > 5 && <div style={{ color: '#999', paddingLeft: 8 }}>...还有 {folderFiles.length - 5} 个文件</div>}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
<div style={{ fontSize: 12, color: '#999' }}>
|
||||||
|
支持 EVB 知识库问答集格式(## chapter / doc_name + Q/A 结构)
|
||||||
|
</div>
|
||||||
|
</Space>
|
||||||
|
</Form.Item>
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
|
||||||
|
{/* Report drawer */}
|
||||||
|
<Drawer
|
||||||
|
title={`测试报告 — ${tasks.find(t => t.id === reportDrawer)?.name || ''}`}
|
||||||
|
open={!!reportDrawer}
|
||||||
|
onClose={() => { setReportDrawer(null); setSelectedSection(null); setResults([]) }}
|
||||||
|
width="85%"
|
||||||
|
styles={{ body: { padding: '16px 24px' } }}
|
||||||
|
>
|
||||||
|
{!summary ? <Spin /> : (
|
||||||
|
<>
|
||||||
|
<SummaryCards summary={summary} />
|
||||||
|
<div style={{ marginBottom: 12 }}>
|
||||||
|
<Space>
|
||||||
|
<Button
|
||||||
|
icon={<DownloadOutlined />}
|
||||||
|
onClick={handleExportFailed}
|
||||||
|
disabled={!summary?.empty_questions}
|
||||||
|
>
|
||||||
|
导出召回失败问题 {summary?.empty_questions ? `(${summary.empty_questions} 条)` : ''}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
icon={<DownloadOutlined />}
|
||||||
|
onClick={handleExportFileMiss}
|
||||||
|
disabled={!summary?.file_miss_questions}
|
||||||
|
>
|
||||||
|
导出文件命中失败问题 {summary?.file_miss_questions ? `(${summary.file_miss_questions} 条)` : ''}
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Divider orientation="left">章节统计</Divider>
|
||||||
|
<div style={{ marginBottom: 8, display: 'flex', gap: 8, alignItems: 'center' }}>
|
||||||
|
<Text type="secondary">共 {sections.length} 个章节,点击「查看问题」可按章节筛选</Text>
|
||||||
|
{selectedSection && (
|
||||||
|
<Button size="small" onClick={() => handleSectionChange(null)}>清除筛选</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<Table
|
||||||
|
rowKey="section_path"
|
||||||
|
dataSource={sections}
|
||||||
|
columns={sectionColumns}
|
||||||
|
size="small"
|
||||||
|
pagination={{ pageSize: 10 }}
|
||||||
|
rowClassName={(r) => r.section_path === selectedSection ? 'ant-table-row-selected' : ''}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Divider orientation="left">
|
||||||
|
{selectedSection ? `问题详情 — ${selectedSection}` : '问题详情(点击章节行查看)'}
|
||||||
|
</Divider>
|
||||||
|
{selectedSection ? (
|
||||||
|
<Spin spinning={resultLoading}>
|
||||||
|
<Table
|
||||||
|
rowKey="id"
|
||||||
|
dataSource={results}
|
||||||
|
columns={resultColumns}
|
||||||
|
size="small"
|
||||||
|
pagination={{ pageSize: 20 }}
|
||||||
|
/>
|
||||||
|
</Spin>
|
||||||
|
) : (
|
||||||
|
<Empty description="请在章节表格中点击「查看问题」" />
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Drawer>
|
||||||
|
|
||||||
|
{/* Question detail drawer */}
|
||||||
|
<Drawer
|
||||||
|
title={`问题详情 — ${detailDrawer?.qid}`}
|
||||||
|
open={!!detailDrawer}
|
||||||
|
onClose={() => { setDetailDrawer(null); setAgentRecallItems([]) }}
|
||||||
|
width={560}
|
||||||
|
>
|
||||||
|
{detailDrawer && (
|
||||||
|
<div>
|
||||||
|
<Paragraph><Text strong>问题:</Text>{detailDrawer.question}</Paragraph>
|
||||||
|
<Paragraph><Text strong>参考答案:</Text>{detailDrawer.reference_answer}</Paragraph>
|
||||||
|
<Paragraph>
|
||||||
|
<Text strong>预期文件:</Text>
|
||||||
|
{(detailDrawer.expected_file_name || detailDrawer.file_name)
|
||||||
|
? <Text code style={{ fontSize: 11 }}>{detailDrawer.expected_file_name || detailDrawer.file_name}</Text>
|
||||||
|
: <Text type="secondary">未匹配</Text>
|
||||||
|
}
|
||||||
|
{detailDrawer.match_type && <Tag style={{ marginLeft: 8 }}>{detailDrawer.match_type}</Tag>}
|
||||||
|
</Paragraph>
|
||||||
|
{detailDrawer.file_id && (
|
||||||
|
<Paragraph>
|
||||||
|
<Text strong>预期文件ID:</Text>
|
||||||
|
<Text code style={{ fontSize: 11 }}>{detailDrawer.file_id}</Text>
|
||||||
|
</Paragraph>
|
||||||
|
)}
|
||||||
|
{/* Expected chunk info */}
|
||||||
|
{detailDrawer.expected_chunk_id && (
|
||||||
|
<Paragraph>
|
||||||
|
<Text strong>预期切片:</Text>
|
||||||
|
<Text code style={{ fontSize: 11 }}>{detailDrawer.expected_chunk_name || detailDrawer.section_path || '未知'}</Text>
|
||||||
|
<Tag color={detailDrawer.is_chunk_hit ? 'green' : 'orange'} style={{ marginLeft: 8 }}>
|
||||||
|
{detailDrawer.is_chunk_hit ? `命中 (Top${detailDrawer.chunk_hit_rank})` : '未命中'}
|
||||||
|
</Tag>
|
||||||
|
</Paragraph>
|
||||||
|
)}
|
||||||
|
{detailDrawer.error && (
|
||||||
|
<Paragraph><Text type="danger">错误:{detailDrawer.error}</Text></Paragraph>
|
||||||
|
)}
|
||||||
|
<Divider>召回结果({detailDrawer.retrieved?.length || 0} 条)</Divider>
|
||||||
|
{(detailDrawer.retrieved || []).map((chunk: any, i: number) => (
|
||||||
|
<Card key={i} size="small" style={{ marginBottom: 8 }}
|
||||||
|
title={
|
||||||
|
<Space wrap>
|
||||||
|
<Text>#{i + 1}</Text>
|
||||||
|
<Text type="secondary" style={{ fontSize: 11 }}>
|
||||||
|
相似度: {chunk.cosine_distance_1 != null ? (1 - chunk.cosine_distance_1).toFixed(4) : '-'}
|
||||||
|
</Text>
|
||||||
|
{/* Chunk hit marker */}
|
||||||
|
{detailDrawer.expected_chunk_id && chunk.id === detailDrawer.expected_chunk_id ? (
|
||||||
|
<Tag color="green" icon={<CheckCircleOutlined />}>命中预期切片</Tag>
|
||||||
|
) : detailDrawer.file_id && chunk.file_id === detailDrawer.file_id ? (
|
||||||
|
<Tag color="blue" icon={<CheckCircleOutlined />}>命中预期文件</Tag>
|
||||||
|
) : (
|
||||||
|
<Tag color="orange">其他</Tag>
|
||||||
|
)}
|
||||||
|
{chunk.file_id && (
|
||||||
|
<Text type="secondary" style={{ fontSize: 10 }}>
|
||||||
|
文件: {chunk.display_file_name || chunk.file_name || '未知文件'}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
{chunk.file_id && (
|
||||||
|
<Tooltip title={chunk.file_id}>
|
||||||
|
<Text type="secondary" style={{ fontSize: 10 }}>
|
||||||
|
ID: {chunk.file_id}
|
||||||
|
</Text>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<Text style={{ fontSize: 12 }}>{chunk.active_paragraph_context?.slice(0, 300)}</Text>
|
||||||
|
{chunk.headers && <div style={{ marginTop: 4, fontSize: 11, color: '#999' }}>标题: {chunk.headers}</div>}
|
||||||
|
</Card>
|
||||||
|
))}
|
||||||
|
|
||||||
|
<Divider>在线 Agent 召回结果对照</Divider>
|
||||||
|
<Space direction="vertical" style={{ width: '100%' }} size={8}>
|
||||||
|
<Text type="secondary">优先从下拉选择 Agent(也支持直接手填),拉取该问题在 Agent 链路中的真实召回结果。</Text>
|
||||||
|
<Space.Compact style={{ width: '100%' }}>
|
||||||
|
<Select
|
||||||
|
showSearch
|
||||||
|
allowClear
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
placeholder="请选择 agent"
|
||||||
|
value={agentIdForRecall || undefined}
|
||||||
|
options={agentOptions}
|
||||||
|
loading={agentOptionsLoading}
|
||||||
|
onChange={(v) => setAgentIdForRecall(v || '')}
|
||||||
|
filterOption={(input, option) => ((option?.label as string) || '').toLowerCase().includes(input.toLowerCase())}
|
||||||
|
/>
|
||||||
|
<Button loading={agentRecallLoading} onClick={loadAgentRecall}>
|
||||||
|
拉取在线召回
|
||||||
|
</Button>
|
||||||
|
</Space.Compact>
|
||||||
|
<Input
|
||||||
|
placeholder="如果下拉没有,手动输入 agent_id"
|
||||||
|
value={agentIdForRecall}
|
||||||
|
onChange={(e) => setAgentIdForRecall(e.target.value)}
|
||||||
|
/>
|
||||||
|
</Space>
|
||||||
|
<div style={{ marginTop: 12 }}>
|
||||||
|
{agentRecallLoading ? (
|
||||||
|
<Spin />
|
||||||
|
) : agentRecallItems.length ? (
|
||||||
|
agentRecallItems.map((item: any, i: number) => (
|
||||||
|
<Card key={`${item.file_id || 'f'}-${i}`} size="small" style={{ marginBottom: 8 }}
|
||||||
|
title={
|
||||||
|
<Space wrap>
|
||||||
|
<Text>#{i + 1}</Text>
|
||||||
|
<Text code style={{ fontSize: 11 }}>{item.file_name || '未知文件名'}</Text>
|
||||||
|
{item.file_id && <Text type="secondary" style={{ fontSize: 10 }}>ID: {item.file_id}</Text>}
|
||||||
|
{item.similarity != null && <Tag color="blue">相似度 {item.similarity}</Tag>}
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<Text style={{ fontSize: 12 }}>{item.content?.slice(0, 300) || '-'}</Text>
|
||||||
|
{item.headers && <div style={{ marginTop: 4, fontSize: 11, color: '#999' }}>标题: {item.headers}</div>}
|
||||||
|
</Card>
|
||||||
|
))
|
||||||
|
) : (
|
||||||
|
<Empty description="暂未拉取在线 Agent 召回结果" />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Drawer>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
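The score columns in the page above pick an Ant Design `Text` type inline with nested ternaries. A minimal sketch of the same thresholds factored into pure helpers, which may be easier to unit-test. The names `similarityLevel`, `fileHitLevel`, and `progressPercent` are hypothetical and not part of this commit; only the threshold values are taken from the renderers:

```typescript
// Hypothetical helpers; thresholds copied from the inline column renderers.
type TextLevel = 'success' | 'warning' | 'danger'

// Best cosine similarity: >= 0.8 is success, >= 0.6 is warning, otherwise danger.
function similarityLevel(sim: number): TextLevel {
  if (sim >= 0.8) return 'success'
  if (sim >= 0.6) return 'warning'
  return 'danger'
}

// File-hit ratio among recalled questions: >= 0.7 counts as success.
function fileHitLevel(fileHits: number, recalled: number): TextLevel | null {
  if (!recalled) return null // nothing recalled, so the column renders '-'
  return fileHits / recalled >= 0.7 ? 'success' : 'warning'
}

// Progress bar percentage, guarding against a zero or missing total.
function progressPercent(progress: number, total: number): number {
  return total ? Math.round((progress / total) * 100) : 0
}
```

A renderer could then read `<Text type={similarityLevel(v)}>`, keeping the cut-off values in one place if they ever need tuning.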
219
frontend/src/pages/Task/index.tsx
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
import React, { useEffect, useState } from 'react'
|
||||||
|
import { Table, Button, Modal, Form, Input, Select, InputNumber, Tag, Space, Popconfirm, message, Checkbox, Divider } from 'antd'
|
||||||
|
import { PlusOutlined, DeleteOutlined, EyeOutlined, ReloadOutlined } from '@ant-design/icons'
|
||||||
|
import { useNavigate, useSearchParams } from 'react-router-dom'
|
||||||
|
import { taskApi, datasetApi, configApi } from '../../services/api'
|
||||||
|
import { RETRIEVAL_METRICS, GENERATION_METRICS, ALL_METRIC_KEYS } from '../../constants/metrics'
|
||||||
|
|
||||||
|
const { Option } = Select
|
||||||
|
|
||||||
|
const STATUS_COLOR: Record<string, string> = {
|
||||||
|
pending: 'default',
|
||||||
|
running: 'processing',
|
||||||
|
done: 'success',
|
||||||
|
failed: 'error',
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function Task() {
|
||||||
|
const [tasks, setTasks] = useState<any[]>([])
|
||||||
|
const [datasets, setDatasets] = useState<any[]>([])
|
||||||
|
const [platforms, setPlatforms] = useState<any[]>([])
|
||||||
|
const [judges, setJudges] = useState<any[]>([])
|
||||||
|
const [modal, setModal] = useState(false)
|
||||||
|
const [form] = Form.useForm()
|
||||||
|
const navigate = useNavigate()
|
||||||
|
const [searchParams] = useSearchParams()
|
||||||
|
const [selectedRowKeys, setSelectedRowKeys] = useState<React.Key[]>([])
|
||||||
|
|
||||||
|
const load = async () => {
|
||||||
|
const res = await taskApi.list() as any
|
||||||
|
setTasks(res.data || [])
|
||||||
|
}
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
load()
|
||||||
|
datasetApi.list().then((r: any) => {
|
||||||
|
const ds = r.data || []
|
||||||
|
setDatasets(ds)
|
||||||
|
const datasetId = searchParams.get('dataset_id')
|
||||||
|
if (datasetId) {
|
||||||
|
// Check that the dataset exists
|
||||||
|
const found = ds.find((d: any) => d.id === datasetId)
|
||||||
|
if (found) {
|
||||||
|
form.setFieldsValue({ dataset_id: datasetId })
|
||||||
|
// Automatically open the new-task modal
|
||||||
|
setModal(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
configApi.listPlatforms().then((r: any) => setPlatforms(r.data || []))
|
||||||
|
configApi.listJudges().then((r: any) => setJudges(r.data || []))
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const runTask = async () => {
|
||||||
|
const vals = await form.validateFields()
|
||||||
|
await taskApi.run(vals)
|
||||||
|
message.success('评测任务已启动')
|
||||||
|
setModal(false)
|
||||||
|
form.resetFields()
|
||||||
|
load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Batch delete ────────────────────────────────────────────────────────────
|
||||||
|
const handleBatchDelete = async () => {
|
||||||
|
if (selectedRowKeys.length === 0) {
|
||||||
|
message.warning('请先选择要删除的任务')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Modal.confirm({
|
||||||
|
title: `确认删除选中的 ${selectedRowKeys.length} 个评测任务?`,
|
||||||
|
content: '删除后将无法恢复,相关评测结果也会被删除。',
|
||||||
|
okText: '确认删除',
|
||||||
|
okType: 'danger',
|
||||||
|
cancelText: '取消',
|
||||||
|
async onOk() {
|
||||||
|
try {
|
||||||
|
await Promise.all(selectedRowKeys.map(id => taskApi.delete(id as string)))
|
||||||
|
message.success(`成功删除 ${selectedRowKeys.length} 个任务`)
|
||||||
|
setSelectedRowKeys([])
|
||||||
|
load()
|
||||||
|
} catch (e: any) {
|
||||||
|
message.error(e?.message || '批量删除失败')
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const columns = [
|
||||||
|
{
|
||||||
|
title: '任务名称', dataIndex: 'name',
|
||||||
|
render: (v: string, r: any) => v || r.id.slice(0, 8) + '...',
|
||||||
|
},
|
||||||
|
{ title: '数据集', dataIndex: 'dataset_id', ellipsis: true },
|
||||||
|
{
|
||||||
|
title: '评测指标',
|
||||||
|
render: (_: any, r: any) => {
|
||||||
|
const metrics = r.selected_metrics || []
|
||||||
|
if (metrics.length === 0) {
|
||||||
|
// Backward compatibility: show retrieval/generation tags
|
||||||
|
const tags = []
|
||||||
|
if (r.eval_retrieval) tags.push(<Tag key="r" color="blue">检索</Tag>)
|
||||||
|
if (r.eval_generation) tags.push(<Tag key="g" color="purple">生成</Tag>)
|
||||||
|
return <>{tags}</>
|
||||||
|
}
|
||||||
|
return <Tag>{metrics.length} 项指标</Tag>
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '状态', dataIndex: 'status',
|
||||||
|
render: (v: string) => <Tag color={STATUS_COLOR[v] || 'default'}>{v}</Tag>,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: '进度', render: (_: any, r: any) =>
|
||||||
|
r.total > 0 ? `${r.progress} / ${r.total}` : '-',
|
||||||
|
},
|
||||||
|
{ title: '创建时间', dataIndex: 'created_at', render: (v: string) => v?.slice(0, 19) },
|
||||||
|
{
|
||||||
|
title: '操作',
|
||||||
|
render: (_: any, r: any) => (
|
||||||
|
<Space>
|
||||||
|
{r.status === 'done' && (
|
||||||
|
<Button size="small" icon={<EyeOutlined />} onClick={() => navigate(`/report/${r.id}`)}>
|
||||||
|
报告
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
<Popconfirm title="确认删除该任务及结果?" onConfirm={() => taskApi.delete(r.id).then(load)}>
|
||||||
|
<Button danger size="small" icon={<DeleteOutlined />} />
|
||||||
|
</Popconfirm>
|
||||||
|
</Space>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 16 }}>
|
||||||
|
<h2 style={{ margin: 0 }}>评测任务</h2>
|
||||||
|
<Space>
|
||||||
|
{selectedRowKeys.length > 0 && (
|
||||||
|
<Button danger icon={<DeleteOutlined />} onClick={handleBatchDelete}>
|
||||||
|
批量删除 ({selectedRowKeys.length})
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
<Button icon={<ReloadOutlined />} onClick={load}>刷新</Button>
|
||||||
|
<Button type="primary" icon={<PlusOutlined />} onClick={() => setModal(true)}>新建任务</Button>
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Table
|
||||||
|
rowKey="id"
|
||||||
|
dataSource={tasks}
|
||||||
|
columns={columns}
|
||||||
|
rowSelection={{
|
||||||
|
selectedRowKeys,
|
||||||
|
onChange: setSelectedRowKeys,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Modal title="新建评测任务" open={modal} onOk={runTask} onCancel={() => { setModal(false); form.resetFields() }} width={700}>
|
||||||
|
<Form form={form} layout="vertical">
|
||||||
|
<Form.Item name="name" label="任务名称(可选)"><Input /></Form.Item>
|
||||||
|
<Form.Item name="dataset_id" label="测试集" rules={[{ required: true }]}>
|
||||||
|
<Select placeholder="选择测试集">
|
||||||
|
{datasets.map((d: any) => (
|
||||||
|
<Option key={d.id} value={d.id}>{d.name} ({d.sample_count} 条)</Option>
|
||||||
|
))}
|
||||||
|
</Select>
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="platform_config_id" label="平台配置" rules={[{ required: true }]}>
|
||||||
|
<Select placeholder="选择平台">
|
||||||
|
{platforms.map((p: any) => <Option key={p.id} value={p.id}>{p.name}</Option>)}
|
||||||
|
</Select>
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="judge_config_id" label="Judge 模型" rules={[{ required: true }]}>
|
||||||
|
<Select placeholder="选择 Judge">
|
||||||
|
{judges.map((j: any) => <Option key={j.id} value={j.id}>{j.name} ({j.model})</Option>)}
|
||||||
|
</Select>
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item name="agent_id" label="Agent ID" rules={[{ required: true }]}><Input /></Form.Item>
|
||||||
|
<Form.Item name="knowledge_hub_id" label="知识库 ID" rules={[{ required: true }]}><Input /></Form.Item>
|
||||||
|
<Form.Item name="top_k" label="Top K" initialValue={10}>
|
||||||
|
<InputNumber min={1} max={50} />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Divider orientation="left">评测指标选择</Divider>
|
||||||
|
<Form.Item name="selected_metrics" label="选择要评测的指标" initialValue={ALL_METRIC_KEYS}>
|
||||||
|
<Checkbox.Group style={{ width: '100%' }}>
|
||||||
|
<div style={{ marginBottom: 12 }}>
|
||||||
|
<div style={{ fontWeight: 'bold', marginBottom: 8, color: '#1677ff' }}>检索层指标</div>
|
||||||
|
<Space direction="vertical" size={4}>
|
||||||
|
{RETRIEVAL_METRICS.map(m => (
|
||||||
|
<Checkbox key={m.key} value={m.key}>
|
||||||
|
<span style={{ fontWeight: 500 }}>{m.cn} ({m.en})</span>
|
||||||
|
<span style={{ marginLeft: 8, fontSize: 12, color: '#888' }}>{m.desc}</span>
|
||||||
|
</Checkbox>
|
||||||
|
))}
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<div style={{ fontWeight: 'bold', marginBottom: 8, color: '#722ed1' }}>生成层指标</div>
|
||||||
|
<Space direction="vertical" size={4}>
|
||||||
|
{GENERATION_METRICS.map(m => (
|
||||||
|
<Checkbox key={m.key} value={m.key}>
|
||||||
|
<span style={{ fontWeight: 500 }}>{m.cn} ({m.en})</span>
|
||||||
|
<span style={{ marginLeft: 8, fontSize: 12, color: '#888' }}>{m.desc}</span>
|
||||||
|
</Checkbox>
|
||||||
|
))}
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
</Checkbox.Group>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item name="concurrency" label="并发数" initialValue={3}>
|
||||||
|
<InputNumber min={1} max={10} />
|
||||||
|
</Form.Item>
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
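`handleBatchDelete` in the file above uses `Promise.all`, so a single failed request rejects the whole batch even though other deletions may already have succeeded, and the list is not reloaded in that case. A hedged sketch of an alternative built on `Promise.allSettled` that reports partial results. `deleteMany` and its `deleteOne` parameter are hypothetical names, with `deleteOne` standing in for a call such as `taskApi.delete`:

```typescript
// Hypothetical helper: attempt every deletion and report which ids failed.
async function deleteMany(
  ids: string[],
  deleteOne: (id: string) => Promise<unknown>,
): Promise<{ deleted: string[]; failed: string[] }> {
  // allSettled never rejects; each outcome records fulfilled or rejected.
  const outcomes = await Promise.allSettled(ids.map(id => deleteOne(id)))
  const deleted: string[] = []
  const failed: string[] = []
  outcomes.forEach((outcome, i) => {
    if (outcome.status === 'fulfilled') {
      deleted.push(ids[i])
    } else {
      failed.push(ids[i])
    }
  })
  return { deleted, failed }
}
```

In the page, `failed.length` could drive a warning message, while the task list is reloaded either way so that it reflects the deletions that did go through.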
157
frontend/src/services/api.ts
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
import http from './http'
|
||||||
|
|
||||||
|
export const configApi = {
|
||||||
|
listPlatforms: () => http.get('/config/platform'),
|
||||||
|
createPlatform: (data: any) => http.post('/config/platform', data),
|
||||||
|
deletePlatform: (id: string) => http.delete(`/config/platform/${id}`),
|
||||||
|
listJudges: () => http.get('/config/judge'),
|
||||||
|
createJudge: (data: any) => http.post('/config/judge', data),
|
||||||
|
deleteJudge: (id: string) => http.delete(`/config/judge/${id}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
export const datasetApi = {
|
||||||
|
list: () => http.get('/dataset/list'),
|
||||||
|
get: (id: string) => http.get(`/dataset/${id}`),
|
||||||
|
create: (data: any) => http.post('/dataset/create', data),
|
||||||
|
delete: (id: string) => http.delete(`/dataset/${id}`),
|
||||||
|
addSample: (data: any) => http.post('/dataset/sample/add', data),
|
||||||
|
generate: (data: any) => http.post('/dataset/generate', data),
|
||||||
|
getGenerateProgress: (genTaskId: string) => http.get(`/dataset/generate/${genTaskId}`),
|
||||||
|
chunksPreview: (platformConfigId: string, knowledgeHubId: string) =>
|
||||||
|
http.get(`/dataset/chunks-preview?platform_config_id=${encodeURIComponent(platformConfigId)}&knowledge_hub_id=${encodeURIComponent(knowledgeHubId)}`),
|
||||||
|
import: (file: File) => {
|
||||||
|
const form = new FormData()
|
||||||
|
form.append('file', file)
|
||||||
|
return http.post('/dataset/import', form)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
export const taskApi = {
|
||||||
|
list: () => http.get('/task/list'),
|
||||||
|
get: (id: string) => http.get(`/task/${id}`),
|
||||||
|
run: (data: any) => http.post('/task/run', data),
|
||||||
|
delete: (id: string) => http.delete(`/task/${id}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
export const reportApi = {
|
||||||
|
get: (taskId: string) => http.get(`/report/${taskId}`),
|
||||||
|
items: (taskId: string) => http.get(`/report/${taskId}/items`),
|
||||||
|
}
|
||||||
|
|
||||||
|
export const singleJumpApi = {
|
||||||
|
createTask: (formData: FormData) => http.post('/single-jump/task', formData),
|
||||||
|
createTaskBatch: (formData: FormData) => http.post('/single-jump/task/batch', formData),
|
||||||
|
listTasks: () => http.get('/single-jump/task/list'),
|
||||||
|
getTask: (id: string) => http.get(`/single-jump/task/${id}`),
|
||||||
|
deleteTask: (id: string) => http.delete(`/single-jump/task/${id}`),
|
||||||
|
getSummary: (id: string) => http.get(`/single-jump/task/${id}/summary`),
|
||||||
|
getSections: (id: string) => http.get(`/single-jump/task/${id}/sections`),
|
||||||
|
getResults: (id: string, section?: string) =>
|
||||||
|
http.get(`/single-jump/task/${id}/results${section ? `?section=${encodeURIComponent(section)}` : ''}`),
|
||||||
|
getAgentRecall: (taskId: string, resultId: string, agentId: string) =>
|
||||||
|
http.get(`/single-jump/task/${taskId}/agent-recall?result_id=${encodeURIComponent(resultId)}&agent_id=${encodeURIComponent(agentId)}`),
|
||||||
|
listAgents: (taskId: string) => http.get(`/single-jump/task/${taskId}/agents`),
|
||||||
|
exportFailedMd: (taskId: string) => `/api/single-jump/task/${taskId}/export-failed-md`,
|
||||||
|
exportFileMissMd: (taskId: string) => `/api/single-jump/task/${taskId}/export-file-miss-md`,
|
||||||
|
}
|
||||||
|
|
||||||
|
export const qaGenApi = {
|
||||||
|
createTask: (formData: FormData) => http.post('/qa-gen/task', formData),
|
||||||
|
createTaskFromDagent: (formData: FormData) => http.post('/qa-gen/task/from-dagent', formData),
|
||||||
|
getDagentStats: (orgId: string, envUrl?: string) => http.get(`/qa-gen/dagent/stats?org_id=${encodeURIComponent(orgId)}${envUrl ? `&env_url=${encodeURIComponent(envUrl)}` : ''}`),
|
||||||
|
listDagentFiles: (orgId: string, envUrl?: string) => http.get(`/qa-gen/dagent/files?org_id=${encodeURIComponent(orgId)}${envUrl ? `&env_url=${encodeURIComponent(envUrl)}` : ''}`),
|
||||||
|
getDagentTree: (orgId: string, envUrl?: string) => http.get(`/qa-gen/dagent/tree?org_id=${encodeURIComponent(orgId)}${envUrl ? `&env_url=${encodeURIComponent(envUrl)}` : ''}`),
|
||||||
|
listTasks: () => http.get('/qa-gen/task/list'),
|
||||||
|
  getTask: (id: string) => http.get(`/qa-gen/task/${id}`),
  deleteTask: (id: string) => http.delete(`/qa-gen/task/${id}`),
  listQuestions: (taskId: string, params?: { status?: string; section?: string; page?: number; page_size?: number }) => {
    const q = new URLSearchParams()
    if (params?.status) q.set('status', params.status)
    if (params?.section) q.set('section', params.section)
    if (params?.page) q.set('page', String(params.page))
    if (params?.page_size) q.set('page_size', String(params.page_size))
    const qs = q.toString()
    return http.get(`/qa-gen/task/${taskId}/questions${qs ? `?${qs}` : ''}`)
  },
  listSections: (taskId: string) => http.get(`/qa-gen/task/${taskId}/sections`),
  approveQuestion: (id: string) => http.post(`/qa-gen/question/${id}/approve`),
  rejectQuestion: (id: string) => http.post(`/qa-gen/question/${id}/reject`),
  editQuestion: (id: string, data: { question?: string; reference_answer?: string }) =>
    http.put(`/qa-gen/question/${id}`, data),
  batchApprove: (taskId: string, minQuality = 0) =>
    http.post(`/qa-gen/task/${taskId}/batch-approve?min_quality=${minQuality}`),
  exportMd: (taskId: string) => `/api/qa-gen/task/${taskId}/export-md`,
  createDataset: (taskId: string, data: { name: string; knowledge_hub_id?: string; description?: string }) =>
    http.post(`/qa-gen/task/${taskId}/create-dataset`, data),
}

export const loopApi = {
  createTask: (formData: FormData) => http.post('/loop/task', formData),
  listTasks: () => http.get('/loop/task/list'),
  getTask: (id: string) => http.get(`/loop/task/${id}`),
  pauseTask: (id: string) => http.post(`/loop/task/${id}/pause`),
  resumeTask: (id: string) => http.post(`/loop/task/${id}/resume`),
  stopTask: (id: string) => http.post(`/loop/task/${id}/stop`),
  deleteTask: (id: string) => http.delete(`/loop/task/${id}`),
  getRounds: (id: string) => http.get(`/loop/task/${id}/rounds`),
  getQuestions: (id: string, params?: { status?: string; category?: string; page?: number; page_size?: number }) => {
    const q = new URLSearchParams()
    if (params?.status) q.set('status', params.status)
    if (params?.category) q.set('category', params.category)
    if (params?.page) q.set('page', String(params.page))
    if (params?.page_size) q.set('page_size', String(params.page_size))
    const qs = q.toString()
    return http.get(`/loop/task/${id}/questions${qs ? `?${qs}` : ''}`)
  },
  export: (id: string, category: string, format: 'md' | 'json' = 'md') =>
    `/api/loop/task/${id}/export?category=${encodeURIComponent(category)}&format=${format}`,
}

export const multiHopApi = {
  createTask: (formData: FormData) => http.post('/multi-hop/task', formData),
  listTasks: () => http.get('/multi-hop/task/list'),
  getTask: (id: string) => http.get(`/multi-hop/task/${id}`),
  deleteTask: (id: string) => http.delete(`/multi-hop/task/${id}`),
  getResults: (id: string) => http.get(`/multi-hop/task/${id}/results`),
  getSummary: (id: string) => http.get(`/multi-hop/task/${id}/summary`),
  listDagentAgents: (envUrl: string, orgId: string, dUserId = 'test') =>
    http.get(`/multi-hop/dagent/agents?env_url=${encodeURIComponent(envUrl)}&org_id=${encodeURIComponent(orgId)}&d_user_id=${encodeURIComponent(dUserId)}`),
}

export const promptTemplateApi = {
  list: () => http.get('/prompt-template/list'),
  getDefault: () => http.get('/prompt-template/default'),
  create: (data: { name: string; description?: string; content: string }) =>
    http.post('/prompt-template', data),
  update: (id: string, data: { name: string; description?: string; content: string }) =>
    http.put(`/prompt-template/${id}`, data),
  delete: (id: string) => http.delete(`/prompt-template/${id}`),
}

export const multiHopGenApi = {
  createTask: (formData: FormData) => http.post('/multi-hop-gen/task', formData),
  createTaskFromDagent: (formData: FormData) => http.post('/multi-hop-gen/task/from-dagent', formData),
  getDagentStats: (orgId: string, envUrl?: string) => http.get(`/multi-hop-gen/dagent/stats?org_id=${encodeURIComponent(orgId)}${envUrl ? `&env_url=${encodeURIComponent(envUrl)}` : ''}`),
  listDagentFiles: (orgId: string, envUrl?: string) => http.get(`/multi-hop-gen/dagent/files?org_id=${encodeURIComponent(orgId)}${envUrl ? `&env_url=${encodeURIComponent(envUrl)}` : ''}`),
  listTasks: () => http.get('/multi-hop-gen/task/list'),
  getTask: (id: string) => http.get(`/multi-hop-gen/task/${id}`),
  deleteTask: (id: string) => http.delete(`/multi-hop-gen/task/${id}`),
  listQuestions: (taskId: string, params?: { status?: string; page?: number; page_size?: number }) => {
    const q = new URLSearchParams()
    if (params?.status) q.set('status', params.status)
    if (params?.page) q.set('page', String(params.page))
    if (params?.page_size) q.set('page_size', String(params.page_size))
    const qs = q.toString()
    return http.get(`/multi-hop-gen/task/${taskId}/questions${qs ? `?${qs}` : ''}`)
  },
  approveQuestion: (id: string) => http.post(`/multi-hop-gen/question/${id}/approve`),
  rejectQuestion: (id: string) => http.post(`/multi-hop-gen/question/${id}/reject`),
  editQuestion: (id: string, data: { question?: string; answer?: string; type?: string }) =>
    http.put(`/multi-hop-gen/question/${id}`, data),
  batchApprove: (taskId: string, minQuality = 0) =>
    http.post(`/multi-hop-gen/task/${taskId}/batch-approve?min_quality=${minQuality}`),
  exportMd: (taskId: string) => `/api/multi-hop-gen/task/${taskId}/export-md`,
  createTest: (taskId: string, data: { env_url: string; org_id: string; agent_id: string; llm_type?: string; d_user_id?: string; top_k?: number; concurrency?: number; name?: string }) =>
    http.post(`/multi-hop-gen/task/${taskId}/create-test`, data),
}
10
frontend/src/services/http.ts
Normal file
@ -0,0 +1,10 @@
import axios from 'axios'

const http = axios.create({ baseURL: '/api' })

http.interceptors.response.use(
  (res) => res.data,
  (err) => Promise.reject(err)
)

export default http
17
frontend/tsconfig.json
Normal file
@ -0,0 +1,17 @@
{
  "compilerOptions": {
    "target": "ES2020",
    "useDefineForClassFields": true,
    "lib": ["ES2020", "DOM", "DOM.Iterable"],
    "module": "ESNext",
    "skipLibCheck": true,
    "moduleResolution": "bundler",
    "allowImportingTsExtensions": true,
    "resolveJsonModule": true,
    "isolatedModules": true,
    "noEmit": true,
    "jsx": "react-jsx",
    "strict": true
  },
  "include": ["src"]
}
15
frontend/vite.config.ts
Normal file
@ -0,0 +1,15 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'

export default defineConfig({
  plugins: [react()],
  server: {
    port: 5173,
    proxy: {
      '/api': 'http://localhost:8021',
    },
  },
  build: {
    outDir: 'dist',
  },
})
19
nginx.conf
Normal file
@ -0,0 +1,19 @@
server {
    listen 80;
    server_name _;

    root /usr/share/nginx/html;
    index index.html;

    # SPA fallback
    location / {
        try_files $uri $uri/ /index.html;
    }

    # Proxy API to backend
    location /api/ {
        proxy_pass http://server:8003;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
    }
}
25
sdk/config.example.yaml
Normal file
@ -0,0 +1,25 @@
# Platform connection settings
platform:
  base_url: "http://localhost:8000"
  org_id: "your_org_id"
  token: ""  # set this if the platform requires an auth token

# Judge LLM settings (OpenAI-compatible API)
judge:
  base_url: "https://api.openai.com/v1"
  api_key: "sk-your-key"
  model: "gpt-4o"

# Evaluation parameters
eval:
  agent_id: "your_agent_id"
  knowledge_hub_id: "your_hub_id"
  top_k: 10
  eval_retrieval: true
  eval_generation: true
  file_id_list:
    - "file_id_1"
    - "file_id_2"
  concurrency: 3
  questions_per_chunk: 2
  max_chunks: 50
1449
sdk/poetry.lock
generated
Normal file
File diff suppressed because it is too large
21
sdk/pyproject.toml
Normal file
@ -0,0 +1,21 @@
[tool.poetry]
name = "rag-eval"
version = "0.1.0"
description = "Platform-agnostic RAG evaluation framework"
authors = []
packages = [{ include = "rag_eval" }]

[tool.poetry.dependencies]
python = "^3.10"
openai = "^1.67.0"
aiohttp = "^3.9.0"
numpy = ">=2.0"
pydantic = "^2.0"
pyyaml = "^6.0.3"

[tool.poetry.scripts]
rag-eval = "rag_eval.cli:main"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
4
sdk/rag_eval/__init__.py
Normal file
@ -0,0 +1,4 @@
from .runner import EvalRunner
from .dataset.schema import EvalSample, EvalDataset

__all__ = ["EvalRunner", "EvalSample", "EvalDataset"]
4
sdk/rag_eval/adapters/__init__.py
Normal file
@ -0,0 +1,4 @@
from .base import RAGAdapter, RetrievedChunk, AgentResponse
from .dagent import DagentAdapter

__all__ = ["RAGAdapter", "RetrievedChunk", "AgentResponse", "DagentAdapter"]
46
sdk/rag_eval/adapters/base.py
Normal file
@ -0,0 +1,46 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class RetrievedChunk:
    chunk_id: str
    content: str
    score: float
    headers: str = ""
    file_id: str = ""


@dataclass
class AgentResponse:
    answer: str
    retrieved_chunks: list[RetrievedChunk] = field(default_factory=list)
    latency_ms: int = 0


class RAGAdapter(ABC):
    """
    Every RAG platform must implement these two methods.
    The framework talks to the platform only through this interface and does not depend on platform internals.
    """

    @abstractmethod
    async def retrieve(
        self,
        query: str,
        knowledge_hub_id: str,
        top_k: int = 10,
        **kwargs,
    ) -> list[RetrievedChunk]:
        """Call the platform's retrieval API and return the recalled chunks."""
        ...

    @abstractmethod
    async def chat(
        self,
        query: str,
        agent_id: str,
        **kwargs,
    ) -> AgentResponse:
        """Call the platform's agent chat API and return the reply with the chunks it cites."""
        ...
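The `RAGAdapter` interface above can be exercised without a live platform. The following is a minimal sketch, not part of this commit: `InMemoryAdapter` and its word-overlap scoring are hypothetical and exist only to show what a second implementation of `retrieve` and `chat` might look like. The dataclasses and the base class are re-declared so the block runs on its own.

```python
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class RetrievedChunk:
    chunk_id: str
    content: str
    score: float
    headers: str = ""
    file_id: str = ""


@dataclass
class AgentResponse:
    answer: str
    retrieved_chunks: list[RetrievedChunk] = field(default_factory=list)
    latency_ms: int = 0


class RAGAdapter(ABC):
    @abstractmethod
    async def retrieve(self, query: str, knowledge_hub_id: str, top_k: int = 10, **kwargs) -> list[RetrievedChunk]: ...

    @abstractmethod
    async def chat(self, query: str, agent_id: str, **kwargs) -> AgentResponse: ...


class InMemoryAdapter(RAGAdapter):
    """Hypothetical adapter: ranks stored texts by word overlap with the query."""

    def __init__(self, docs: dict[str, str]):
        self.docs = docs  # chunk_id -> text

    async def retrieve(self, query, knowledge_hub_id, top_k=10, **kwargs):
        words = set(query.lower().split())
        scored = []
        for cid, text in self.docs.items():
            overlap = len(words & set(text.lower().split()))
            if overlap:  # overlap > 0 implies words is non-empty, so the division is safe
                scored.append(RetrievedChunk(chunk_id=cid, content=text, score=overlap / len(words)))
        scored.sort(key=lambda c: c.score, reverse=True)
        return scored[:top_k]

    async def chat(self, query, agent_id, **kwargs):
        # Answer with the best-matching chunk; a real platform would call an LLM here.
        chunks = await self.retrieve(query, knowledge_hub_id="", top_k=1)
        answer = chunks[0].content if chunks else ""
        return AgentResponse(answer=answer, retrieved_chunks=chunks)


adapter = InMemoryAdapter({"c1": "nginx listens on port 80", "c2": "vite serves on port 5173"})
hits = asyncio.run(adapter.retrieve("which port does nginx use", knowledge_hub_id="demo", top_k=1))
print([c.chunk_id for c in hits])
```

A stand-in like this may be useful for unit-testing the metrics and the runner offline, since nothing in the interface requires HTTP.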
138
sdk/rag_eval/adapters/dagent.py
Normal file
@ -0,0 +1,138 @@
import json
import time
import aiohttp
from .base import RAGAdapter, RetrievedChunk, AgentResponse


class DagentAdapter(RAGAdapter):
    """
    Adapter for the dagent platform.
    It calls the platform over its HTTP API and does not depend on dagent's internal code.
    """

    def __init__(self, base_url: str, org_id: str, token: str = ""):
        self.base_url = base_url.rstrip("/")
        self.org_id = org_id
        self.headers = {"Authorization": f"Bearer {token}"} if token else {}

    async def retrieve(
        self,
        query: str,
        knowledge_hub_id: str,
        top_k: int = 10,
        file_id_list: list[str] | None = None,
        **kwargs,
    ) -> list[RetrievedChunk]:
        payload = {
            "query": query,
            "org_id": self.org_id,
            "top_k": top_k,
        }
        if knowledge_hub_id:
            payload["knowledge_hub_id"] = knowledge_hub_id
        if file_id_list:
            payload["file_id_list"] = file_id_list

        async with aiohttp.ClientSession(headers=self.headers) as session:
            async with session.post(
                f"{self.base_url}/dagent/knowledge/hub/semantic_search_knowledge/detail",
                json=payload,
                timeout=aiohttp.ClientTimeout(total=30),
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()

        result_data = data.get("data") or {}
        standard = result_data.get("standard_answer_results") or []
        related = result_data.get("related_knowledge_rerank_results_top") or []
        all_items = standard + related

        chunks = []
        for item in all_items:
            chunks.append(RetrievedChunk(
                chunk_id=item.get("knowledge_md_header_split_id") or item.get("id", ""),
                content=item.get("active_paragraph_context") or item.get("active_context") or "",
                score=1.0 - (item.get("cosine_distance_1") or 0.0),
                headers=item.get("headers") or "",
                file_id=item.get("file_id") or "",
            ))
        return chunks[:top_k]

    async def chat(
        self,
        query: str,
        agent_id: str,
        llm_type: str = "azure_openai_4o",
        **kwargs,
    ) -> AgentResponse:
        import uuid
        payload = {
            "chat_id": str(uuid.uuid4()),
            "task": query,
            "agent_id": agent_id,
            "org_id": self.org_id,
            "llm_type": llm_type,
            "chat_messages": [{"role": "user", "content": query}],
        }

        answer_parts: list[str] = []
        start = time.monotonic()

        async with aiohttp.ClientSession(headers=self.headers) as session:
            async with session.post(
                f"{self.base_url}/dagent/agent/chat",
                json=payload,
                headers={**self.headers, "Accept": "text/event-stream"},
                timeout=aiohttp.ClientTimeout(total=120),
            ) as resp:
                resp.raise_for_status()
                async for raw_line in resp.content:
                    line = raw_line.decode("utf-8").strip()
                    if not line.startswith("data:"):
                        continue
                    data_str = line[5:].strip()
                    if not data_str or data_str == "[DONE]":
                        continue
                    try:
                        chunk = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    msg_type = chunk.get("message_type", "")
                    if chunk.get("is_chunk_data") or msg_type in ("", "CHUNK"):
                        content = chunk.get("data", "")
                        if isinstance(content, str):
                            answer_parts.append(content)
                    elif msg_type == "EVENT":
                        event = chunk.get("data", {})
                        if isinstance(event, dict) and event.get("event_finish"):
                            break

        latency_ms = int((time.monotonic() - start) * 1000)
        return AgentResponse(
            answer="".join(answer_parts).strip(),
            retrieved_chunks=[],
            latency_ms=latency_ms,
        )

    async def get_chunks_for_file(
        self,
        file_id: str,
        page_size: int = 100,
    ) -> list[dict]:
        """Fetch the first page (up to page_size) of a file's chunks; used for dataset generation."""
        payload = {
            "file_id": file_id,
            "org_id": self.org_id,
            "page": 1,
            "page_size": page_size,
        }
        async with aiohttp.ClientSession(headers=self.headers) as session:
            async with session.post(
                f"{self.base_url}/dagent/knowledge/chunk/page",
                json=payload,
                timeout=aiohttp.ClientTimeout(total=30),
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()
        # API returns data.data.list, not data.data.records
        return (data.get("data") or {}).get("list") or []
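The stream handling in `DagentAdapter.chat` is easier to reason about when it is separated from the network call. The sketch below pulls the same branching into a pure function. `collect_answer` is a hypothetical helper and the sample lines are invented from the field names the adapter reads (`message_type`, `is_chunk_data`, `data`, `event_finish`); they are not captured dagent output. The `isinstance(chunk, dict)` guard is an addition that the adapter itself does not have.

```python
import json


def collect_answer(lines: list[str]) -> str:
    """Join streamed text pieces, following the same rules as the adapter's loop."""
    parts: list[str] = []
    for raw in lines:
        line = raw.strip()
        if not line.startswith("data:"):
            continue  # skip comments, blank lines and other SSE fields
        data_str = line[5:].strip()
        if not data_str or data_str == "[DONE]":
            continue
        try:
            chunk = json.loads(data_str)
        except json.JSONDecodeError:
            continue  # tolerate malformed frames
        if not isinstance(chunk, dict):
            continue  # extra guard (assumption): ignore bare JSON scalars or arrays
        msg_type = chunk.get("message_type", "")
        if chunk.get("is_chunk_data") or msg_type in ("", "CHUNK"):
            content = chunk.get("data", "")
            if isinstance(content, str):
                parts.append(content)
        elif msg_type == "EVENT":
            event = chunk.get("data", {})
            if isinstance(event, dict) and event.get("event_finish"):
                break  # the finish event ends the answer
    return "".join(parts).strip()


sample = [
    'data: {"message_type": "CHUNK", "data": "Hello"}',
    ": keep-alive comment",
    'data: {"is_chunk_data": true, "data": ", world"}',
    "data: not-json",
    'data: {"message_type": "EVENT", "data": {"event_finish": true}}',
    'data: {"message_type": "CHUNK", "data": " ignored"}',
]
print(collect_answer(sample))  # Hello, world
```

Keeping the parser pure in this way would let it be tested against recorded frames without opening an `aiohttp` session.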
97
sdk/rag_eval/cli.py
Normal file
@ -0,0 +1,97 @@
"""
CLI entry point: rag-eval run --config config.yaml --dataset dataset.json
"""
import asyncio
import argparse
import json
import yaml
import sys
from pathlib import Path


def main():
    parser = argparse.ArgumentParser(prog="rag-eval", description="RAG Evaluation Framework")
    sub = parser.add_subparsers(dest="command")

    # rag-eval run
    run_p = sub.add_parser("run", help="Run evaluation")
    run_p.add_argument("--config", required=True, help="Path to YAML config file")
    run_p.add_argument("--dataset", required=True, help="Path to dataset JSON file")
    run_p.add_argument("--output", default="eval_report.json", help="Output report path")

    # rag-eval generate
    gen_p = sub.add_parser("generate", help="Generate dataset from knowledge base")
    gen_p.add_argument("--config", required=True, help="Path to YAML config file")
    gen_p.add_argument("--output", default="dataset.json", help="Output dataset path")

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        sys.exit(1)

    asyncio.run(_dispatch(args))


async def _dispatch(args):
    config = _load_config(args.config)

    from rag_eval.adapters.dagent import DagentAdapter
    from rag_eval.judge.openai_compatible import OpenAICompatibleJudge

    adapter = DagentAdapter(
        base_url=config["platform"]["base_url"],
        org_id=config["platform"]["org_id"],
        token=config["platform"].get("token", ""),
    )
    judge = OpenAICompatibleJudge(
        base_url=config["judge"]["base_url"],
        api_key=config["judge"]["api_key"],
        model=config["judge"]["model"],
        embed_base_url=config["judge"].get("embed_base_url", ""),
        embed_api_key=config["judge"].get("embed_api_key", ""),
        embed_model=config["judge"].get("embed_model", "text-embedding-3-small"),
    )

    if args.command == "run":
        from rag_eval.runner import EvalRunner, RunConfig
        from rag_eval.dataset.schema import EvalDataset

        run_cfg = RunConfig(
            agent_id=config["eval"]["agent_id"],
            knowledge_hub_id=config["eval"]["knowledge_hub_id"],
            top_k=config["eval"].get("top_k", 10),
            eval_retrieval=config["eval"].get("eval_retrieval", True),
            eval_generation=config["eval"].get("eval_generation", True),
            file_id_list=config["eval"].get("file_id_list"),
            concurrency=config["eval"].get("concurrency", 3),
        )

        runner = EvalRunner(adapter=adapter, judge=judge)

        def _progress(done, total):
            print(f"\r Progress: {done}/{total}", end="", flush=True)

        print(f"Running evaluation on {args.dataset} ...")
        report = await runner.run(args.dataset, run_cfg, progress_cb=_progress)
        print()
        print(report.summary())
        report.save(args.output)

    elif args.command == "generate":
        from rag_eval.dataset.generator import DatasetGenerator

        gen = DatasetGenerator(judge=judge, adapter=adapter)
        dataset = await gen.generate(
            knowledge_hub_id=config["eval"]["knowledge_hub_id"],
            file_id_list=config["eval"]["file_id_list"],
            questions_per_chunk=config["eval"].get("questions_per_chunk", 2),
            max_chunks=config["eval"].get("max_chunks", 50),
        )
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(dataset.to_dict(), f, ensure_ascii=False, indent=2)
        print(f"Generated {len(dataset.samples)} samples → {args.output}")


def _load_config(path: str) -> dict:
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)
4
sdk/rag_eval/dataset/__init__.py
Normal file
@ -0,0 +1,4 @@
from .schema import EvalSample, EvalDataset
from .generator import DatasetGenerator

__all__ = ["EvalSample", "EvalDataset", "DatasetGenerator"]
123
sdk/rag_eval/dataset/generator.py
Normal file
@ -0,0 +1,123 @@
import asyncio
import uuid
from .schema import EvalSample, EvalDataset
from ..judge.base import LLMJudge

_GEN_PROMPT = """你是一个专业的问答数据集构建专家。
基于以下文档片段,生成 {n} 个高质量的问题和对应的参考答案,用于评测知识库检索系统。

要求:
1. 问题必须能从文档中找到明确答案
2. 包含不同类型:事实性(factual)、推理性(reasoning)、比较性(comparison)
3. 同时生成一个该文档无法回答的问题(unanswerable),answer 填 "该文档中未提及此信息"
4. 参考答案要简洁准确

文档标题:{headers}
文档内容:
{content}

严格按以下 JSON 格式输出:
{{
  "items": [
    {{
      "question": "问题文本",
      "answer": "参考答案",
      "type": "factual | reasoning | comparison | unanswerable",
      "difficulty": "easy | medium | hard"
    }}
  ]
}}"""


class DatasetGenerator:
    def __init__(self, judge: LLMJudge, adapter=None):
        self.judge = judge
        self.adapter = adapter

    async def generate(
        self,
        knowledge_hub_id: str,
        file_id_list: list[str],
        questions_per_chunk: int = 2,
        max_chunks: int = 50,
        dataset_name: str = "Auto Generated Dataset",
        chunk_ids: list[str] | None = None,
        progress_cb=None,
    ) -> EvalDataset:
        """
        Walk the knowledge-base chunks, generate Q&A pairs with the LLM, and return an EvalDataset.
        progress_cb(done, total): optional progress callback; it is awaited, so it must be async
        chunk_ids: if given, only these chunks are kept (they are still fetched via file_id_list)
        """
        samples: list[EvalSample] = []

        # Collect all chunks to process
        all_chunks: list[dict] = []
        if chunk_ids:
            # Fetch chunks from every file in file_id_list, then keep only the requested chunk_ids
            for file_id in file_id_list:
                raw = await self.adapter.get_chunks_for_file(file_id, page_size=max_chunks)
                all_chunks.extend(raw)
            all_chunks = [c for c in all_chunks if c.get("id") in chunk_ids]
        else:
            for file_id in file_id_list:
                raw = await self.adapter.get_chunks_for_file(file_id, page_size=max_chunks)
                all_chunks.extend(raw)

        total = len(all_chunks)
        done = 0

        for chunk in all_chunks:
            content = (
                chunk.get("content")
                or chunk.get("paragraph_context")
                or chunk.get("large_paragraph_llm_summary")
                or ""
            )
            headers = chunk.get("headers") or ""
            if not content.strip():
                done += 1
                if progress_cb:
                    await progress_cb(done, total)
                continue

            prompt = _GEN_PROMPT.format(
                n=questions_per_chunk,
                headers=headers,
                content=content[:2000],
            )
            try:
                raw = await self.judge._call_json(prompt)
                for item in raw.get("items", []):
                    if not item.get("question") or not item.get("answer"):
                        continue
                    samples.append(EvalSample(
                        id=uuid.uuid4().hex,
                        question=item["question"],
                        reference_answer=item["answer"],
                        relevant_chunk_ids=[chunk["id"]] if chunk.get("id") else [],
                        knowledge_hub_id=knowledge_hub_id,
                        source_file_id=chunk.get("file_id", ""),
                        metadata={
                            "type": item.get("type", "factual"),
                            "difficulty": item.get("difficulty", "medium"),
                            "chunk_id": chunk.get("id", ""),
                            "chunk_headers": chunk.get("headers", ""),
                            "chunk_content_preview": content[:500] if content else "",
                            "file_name": chunk.get("file_name", ""),
                        },
                    ))
            except Exception:
                pass

            done += 1
            if progress_cb:
                await progress_cb(done, total)
            await asyncio.sleep(0.1)

        return EvalDataset(
            id=uuid.uuid4().hex,
            name=dataset_name,
            description=f"Auto generated from {total} chunk(s), {len(samples)} samples",
            samples=samples,
        )
63
sdk/rag_eval/dataset/schema.py
Normal file
@ -0,0 +1,63 @@
from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class EvalSample:
    id: str
    question: str
    reference_answer: str
    relevant_chunk_ids: list[str]
    knowledge_hub_id: str
    source_file_id: str | None = None
    metadata: dict = field(default_factory=dict)


@dataclass
class EvalDataset:
    id: str
    name: str
    description: str
    samples: list[EvalSample]
    created_at: datetime = field(default_factory=datetime.utcnow)

    def to_dict(self) -> dict:
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "created_at": self.created_at.isoformat(),
            "samples": [
                {
                    "id": s.id,
                    "question": s.question,
                    "reference_answer": s.reference_answer,
                    "relevant_chunk_ids": s.relevant_chunk_ids,
                    "knowledge_hub_id": s.knowledge_hub_id,
                    "source_file_id": s.source_file_id,
                    "metadata": s.metadata,
                }
                for s in self.samples
            ],
        }

    @classmethod
    def from_dict(cls, data: dict) -> "EvalDataset":
        samples = [
            EvalSample(
                id=s["id"],
                question=s["question"],
                reference_answer=s.get("reference_answer", ""),
                relevant_chunk_ids=s.get("relevant_chunk_ids", []),
                knowledge_hub_id=s.get("knowledge_hub_id", ""),
                source_file_id=s.get("source_file_id"),
                metadata=s.get("metadata", {}),
            )
            for s in data.get("samples", [])
        ]
        return cls(
            id=data["id"],
            name=data["name"],
            description=data.get("description", ""),
            samples=samples,
        )
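`EvalDataset.to_dict` writes `created_at`, but `from_dict` does not read it back, so a reloaded dataset receives a fresh timestamp. If the original time matters, one option is sketched below. `parse_created_at` is a hypothetical helper, not part of this commit, and the timezone-aware fallback is a design assumption (the dataclass default, `datetime.utcnow`, produces naive values).

```python
from datetime import datetime, timezone


def parse_created_at(data: dict) -> datetime:
    """Hypothetical helper: reuse the serialized timestamp when one is present."""
    raw = data.get("created_at")
    if raw:
        # fromisoformat is the inverse of the isoformat() call used in to_dict
        return datetime.fromisoformat(raw)
    # Fallback for older files without the field (assumption: aware UTC time)
    return datetime.now(timezone.utc)


stamp = datetime(2024, 1, 2, 3, 4, 5)
restored = parse_created_at({"created_at": stamp.isoformat()})
print(restored == stamp)  # True
```

The result could then be passed as `created_at=` in the `cls(...)` call. Mixing naive and aware datetimes in one dataset would make comparisons fail, so a real change would need to settle on one convention.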
3
sdk/rag_eval/evaluators/__init__.py
Normal file
@ -0,0 +1,3 @@
from .retrieval import hit_rate, mrr, ndcg

__all__ = ["hit_rate", "mrr", "ndcg"]
32
sdk/rag_eval/evaluators/retrieval.py
Normal file
@ -0,0 +1,32 @@
import math


def hit_rate(retrieved_ids: list[str], relevant_ids: list[str]) -> float:
    if not relevant_ids:
        return 0.0
    return 1.0 if set(relevant_ids).intersection(retrieved_ids) else 0.0


def mrr(retrieved_ids: list[str], relevant_ids: list[str]) -> float:
    if not relevant_ids:
        return 0.0
    relevant_set = set(relevant_ids)
    for rank, rid in enumerate(retrieved_ids, start=1):
        if rid in relevant_set:
            return 1.0 / rank
    return 0.0


def ndcg(retrieved_ids: list[str], relevant_ids: list[str], k: int = 10) -> float:
    if not relevant_ids:
        return 0.0
    relevant_set = set(relevant_ids)
    top_k = retrieved_ids[:k]
    dcg = sum(
        1.0 / math.log2(i + 2)
        for i, rid in enumerate(top_k)
        if rid in relevant_set
    )
    ideal_hits = min(len(relevant_set), k)
    idcg = sum(1.0 / math.log2(i + 2) for i in range(ideal_hits))
    return round(dcg / idcg, 4) if idcg > 0 else 0.0
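A small worked example may help when sanity-checking these three metrics. The functions are repeated in compact form so the block runs on its own; the chunk IDs are made up. With relevant chunks at ranks 2 and 4, the hit rate is 1, MRR is 1/2, and nDCG is (1/log2(3) + 1/log2(5)) / (1/log2(2) + 1/log2(3)), which rounds to 0.6509.

```python
import math


def hit_rate(retrieved_ids, relevant_ids):
    if not relevant_ids:
        return 0.0
    return 1.0 if set(relevant_ids).intersection(retrieved_ids) else 0.0


def mrr(retrieved_ids, relevant_ids):
    relevant_set = set(relevant_ids)
    for rank, rid in enumerate(retrieved_ids, start=1):
        if rid in relevant_set:
            return 1.0 / rank  # reciprocal rank of the first relevant hit
    return 0.0


def ndcg(retrieved_ids, relevant_ids, k=10):
    if not relevant_ids:
        return 0.0
    relevant_set = set(relevant_ids)
    # Binary relevance: each relevant hit at 0-based position i contributes 1/log2(i + 2)
    dcg = sum(1.0 / math.log2(i + 2) for i, rid in enumerate(retrieved_ids[:k]) if rid in relevant_set)
    idcg = sum(1.0 / math.log2(i + 2) for i in range(min(len(relevant_set), k)))
    return round(dcg / idcg, 4) if idcg > 0 else 0.0


retrieved = ["c7", "c2", "c9", "c4"]  # ranked output of a retriever (made-up IDs)
relevant = ["c2", "c4"]               # ground-truth chunks for the question

print(hit_rate(retrieved, relevant))  # 1.0
print(mrr(retrieved, relevant))       # 0.5
print(ndcg(retrieved, relevant))      # 0.6509
```

Note that all three return 0.0 when `relevant_ids` is empty, so unanswerable samples with no ground-truth chunk pull the averages down unless they are filtered out before aggregation.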
4
sdk/rag_eval/judge/__init__.py
Normal file
@ -0,0 +1,4 @@
from .base import LLMJudge
from .openai_compatible import OpenAICompatibleJudge

__all__ = ["LLMJudge", "OpenAICompatibleJudge"]
22
sdk/rag_eval/judge/base.py
Normal file
@ -0,0 +1,22 @@
import asyncio
import json
from abc import ABC, abstractmethod
from openai import AsyncOpenAI


class LLMJudge(ABC):
    @abstractmethod
    async def score_faithfulness(self, answer: str, context: list[str]) -> tuple[float, dict]:
        ...

    @abstractmethod
    async def score_relevance(self, question: str, answer: str) -> tuple[float, dict]:
        ...

    @abstractmethod
    async def score_correctness(self, answer: str, reference: str) -> tuple[float, dict]:
        ...

    @abstractmethod
    async def score_groundedness(self, answer: str, chunks: list[dict]) -> tuple[float, dict]:
        ...
288
sdk/rag_eval/judge/openai_compatible.py
Normal file
288
sdk/rag_eval/judge/openai_compatible.py
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
from .base import LLMJudge
|
||||||
|
|
||||||
|
# ── Prompts ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_DECOMPOSE_PROMPT = """请将以下回答分解为独立的原子声明列表,每条声明是一个不可再分的事实陈述。
|
||||||
|
回答:{answer}
|
||||||
|
只输出 JSON 数组,格式:["声明1", "声明2", ...]"""
|
||||||
|
|
||||||
|
_VERIFY_CLAIM_PROMPT = """参考资料:
|
||||||
|
{context}
|
||||||
|
|
||||||
|
声明:{claim}
|
||||||
|
|
||||||
|
上述声明是否可以从参考资料中推导出来?只回答 yes 或 no。"""
|
||||||
|
|
||||||
|
_RELEVANCE_GEN_PROMPT = """基于以下回答,生成 3 个该回答可能在回答的问题。
|
||||||
|
回答:{answer}
|
||||||
|
只输出 JSON 数组,格式:["问题1", "问题2", "问题3"]"""
|
||||||
|
|
||||||
|
_CORRECTNESS_PROMPT = """请评估以下回答与参考答案的事实一致程度。
|
||||||
|
|
||||||
|
参考答案:{reference}
|
||||||
|
待评估回答:{answer}
|
||||||
|
|
||||||
|
请从以下维度评估:
|
||||||
|
1. 事实一致性:回答中的事实与参考答案是否一致
|
||||||
|
2. 信息完整性:回答是否覆盖了参考答案的关键信息
|
||||||
|
3. 有无错误信息:回答是否包含参考答案中没有的错误内容
|
||||||
|
|
||||||
|
输出 JSON:
|
||||||
|
{{"score": 0到1之间的小数, "reason": "简短理由", "factual_tp": 正确事实数, "factual_fp": 错误事实数, "factual_fn": 遗漏事实数}}"""
|
||||||
|
|
||||||
|
_GROUNDEDNESS_PROMPT = """以下是检索到的切片列表(带编号):
|
||||||
|
{numbered_chunks}
|
||||||
|
|
||||||
|
AI 回答:{answer}
|
||||||
|
|
||||||
|
请将回答分解为原子声明,并为每条声明标注支撑它的切片编号(无支撑则填 null)。
|
||||||
|
输出 JSON:{{"claims": [{{"text": "声明内容", "source_chunk_index": 1}}, {{"text": "声明内容", "source_chunk_index": null}}]}}"""
|
||||||
|
|
||||||
|
_CONTEXT_PRECISION_PROMPT = """问题:{question}
|
||||||
|
参考答案:{ground_truth}
|
||||||
|
|
||||||
|
以下是检索系统返回的文档片段列表:
|
||||||
|
{chunks_text}
|
||||||
|
|
||||||
|
请判断每个片段对于回答该问题是否有用。
|
||||||
|
输出 JSON:{{"results": [{{"index": 1, "useful": true, "reason": "简短理由"}}]}}"""
|
||||||
|
|
||||||
|
_CONTEXT_RECALL_PROMPT = """参考答案:{ground_truth}
|
||||||
|
|
||||||
|
检索到的文档内容(合并):
|
||||||
|
{retrieved_context}
|
||||||
|
|
||||||
|
请将参考答案拆分为若干独立陈述,判断每个陈述是否能在检索文档中找到支撑。
|
||||||
|
输出 JSON:{{"statements": [{{"text": "陈述内容", "supported": true}}]}}"""
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAICompatibleJudge(LLMJudge):
    """
    Works with any model served over the OpenAI protocol: DeepSeek / Qwen / OpenAI / Azure OpenAI.
    The judging logic uses Chinese prompts, which suits Chinese-language RAG scenarios.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str,
        embed_base_url: str = "",
        embed_api_key: str = "",
        embed_model: str = "text-embedding-3-small",
    ):
        self.client = AsyncOpenAI(
            base_url=base_url or None,
            api_key=api_key,
        )
        self.model = model
        # Separate embedding client (may use a different endpoint from the LLM)
        self.embed_client = AsyncOpenAI(
            base_url=embed_base_url or base_url or None,
            api_key=embed_api_key or api_key,
        )
        self.embed_model = embed_model

    async def _call(self, prompt: str) -> str:
        resp = await self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        return (resp.choices[0].message.content or "").strip()

    async def _call_json(self, prompt: str) -> dict | list:
        resp = await self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        raw = (resp.choices[0].message.content or "").strip()
        # Strip a markdown code-fence wrapper (```json ... ``` or ``` ... ```)
        if raw.startswith("```"):
            lines = raw.splitlines()
            # Drop the first line (```json or ```) and the last line (```)
            inner = lines[1:] if lines[0].startswith("```") else lines
            if inner and inner[-1].strip() == "```":
                inner = inner[:-1]
            raw = "\n".join(inner).strip()
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            # Try to extract the first JSON object or array
            import re
            m = re.search(r'(\{[\s\S]*\}|\[[\s\S]*\])', raw)
            if m:
                try:
                    return json.loads(m.group(1))
                except json.JSONDecodeError:
                    pass
            return {}

    # ── Faithfulness (two-step method) ───────────────────────────────────────

    async def score_faithfulness(self, answer: str, context: list[str]) -> tuple[float, dict]:
        if not answer or not context:
            return 0.0, {}

        # Step 1: decompose the answer into atomic claims
        raw_claims = await self._call_json(
            _DECOMPOSE_PROMPT.format(answer=answer)
        )
        if isinstance(raw_claims, list):
            claims = raw_claims
        else:
            claims = raw_claims.get("items", []) or raw_claims.get("claims", [])

        if not claims:
            return 0.0, {"claims": []}

        context_text = "\n\n".join(c[:800] for c in context)

        # Step 2: verify each claim (concurrently)
        async def _verify(claim: str) -> bool:
            result = await self._call(
                _VERIFY_CLAIM_PROMPT.format(context=context_text, claim=claim)
            )
            return "yes" in result.lower()

        results = await asyncio.gather(*[_verify(c) for c in claims])
        supported = sum(results)
        score = round(supported / len(claims), 4)

        detail = {
            "claims": [
                {"text": c, "supported": bool(r)}
                for c, r in zip(claims, results)
            ]
        }
        return score, detail

    # ── Answer Relevance (reverse question generation + semantic similarity) ─

    async def score_relevance(self, question: str, answer: str) -> tuple[float, dict]:
        if not answer:
            return 0.0, {}

        raw = await self._call_json(
            _RELEVANCE_GEN_PROMPT.format(answer=answer)
        )
        if isinstance(raw, list):
            gen_questions = raw
        else:
            gen_questions = raw.get("items", []) or raw.get("questions", [])

        if not gen_questions:
            return 0.0, {}

        # Score with embedding cosine similarity
        scores = await asyncio.gather(*[
            self._embedding_similarity(question, q) for q in gen_questions
        ])
        avg = round(sum(scores) / len(scores), 4)
        return avg, {"generated_questions": gen_questions, "similarities": list(scores)}

    async def _embedding_similarity(self, text_a: str, text_b: str) -> float:
        import numpy as np
        resp = await self.embed_client.embeddings.create(
            model=self.embed_model,
            input=[text_a, text_b],
        )
        a = np.array(resp.data[0].embedding)
        b = np.array(resp.data[1].embedding)
        cos = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-9))
        return round(max(0.0, cos), 4)

    # ── Answer Correctness ───────────────────────────────────────────────────

    async def score_correctness(self, answer: str, reference: str) -> tuple[float, dict]:
        if not answer or not reference:
            return 0.0, {}

        raw = await self._call_json(
            _CORRECTNESS_PROMPT.format(reference=reference, answer=answer)
        )
        try:
            score = float(raw.get("score", 0.0))
        except (TypeError, ValueError):
            score = 0.0

        tp = raw.get("factual_tp", 0) or 0
        fp = raw.get("factual_fp", 0) or 0
        fn = raw.get("factual_fn", 0) or 0
        # Fact-level F1 from the judge's TP/FP/FN counts
        f1 = (2 * tp / (2 * tp + fp + fn)) if (2 * tp + fp + fn) > 0 else 0.0
        # Final score: 75% fact-level F1, 25% the judge's holistic score
        final = round(0.75 * f1 + 0.25 * score, 4)
        return final, raw

    # ── Groundedness (traceability) ──────────────────────────────────────────

    async def score_groundedness(self, answer: str, chunks: list[dict]) -> tuple[float, dict]:
        if not answer or not chunks:
            return 0.0, {}

        numbered = "\n".join(
            f"[{i+1}] {c.get('content', '')[:500]}" for i, c in enumerate(chunks)
        )
        raw = await self._call_json(
            _GROUNDEDNESS_PROMPT.format(numbered_chunks=numbered, answer=answer)
        )
        claims = raw.get("claims", [])
        if not claims:
            return 0.0, raw

        # A claim counts as grounded when the judge cites a supporting chunk index
        grounded = sum(1 for c in claims if c.get("source_chunk_index") is not None)
        score = round(grounded / len(claims), 4)
        return score, raw

    # ── Context Precision ────────────────────────────────────────────────────

    async def score_context_precision(
        self, question: str, ground_truth: str, retrieved_chunks: list[str]
    ) -> tuple[float, dict]:
        if not retrieved_chunks or not ground_truth:
            return 0.0, {}

        chunks_text = "\n".join(f"[{i+1}] {c[:500]}" for i, c in enumerate(retrieved_chunks))
        raw = await self._call_json(
            _CONTEXT_PRECISION_PROMPT.format(
                question=question, ground_truth=ground_truth, chunks_text=chunks_text
            )
        )
        results = raw.get("results", [])
        if not results:
            return 0.0, raw

        useful_flags = [
            r.get("useful", False)
            for r in sorted(results, key=lambda x: x.get("index", 0))
        ]
        # Weighted precision@k
        score = sum(
            (sum(useful_flags[:k+1]) / (k+1)) * useful_flags[k]
            for k in range(len(useful_flags))
        ) / max(sum(useful_flags), 1)
        return round(min(score, 1.0), 4), raw

    # ── Context Recall ───────────────────────────────────────────────────────

    async def score_context_recall(
        self, ground_truth: str, retrieved_chunks: list[str]
    ) -> tuple[float, dict]:
        if not retrieved_chunks or not ground_truth:
            return 0.0, {}

        retrieved_context = "\n\n".join(c[:800] for c in retrieved_chunks)
        raw = await self._call_json(
            _CONTEXT_RECALL_PROMPT.format(
                ground_truth=ground_truth, retrieved_context=retrieved_context
            )
        )
        statements = raw.get("statements", [])
        if not statements:
            return 0.0, raw

        supported = sum(1 for s in statements if s.get("supported"))
        return round(supported / len(statements), 4), raw
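The rank-aware score in `score_context_precision` and the blended score in `score_correctness` can both be checked by hand. The following is a minimal standalone sketch of the two formulas, assuming plain Python lists and integers in place of the judge's JSON output. The helper names `weighted_precision_at_k` and `blended_correctness` are hypothetical and are not part of the SDK.

```python
def weighted_precision_at_k(useful_flags: list[bool]) -> float:
    """Average precision@k, taken only over the ranks that hold a useful chunk."""
    useful_total = sum(useful_flags)
    if useful_total == 0:
        return 0.0
    score = 0.0
    hits = 0
    for rank, useful in enumerate(useful_flags, start=1):
        if useful:
            hits += 1
            score += hits / rank  # precision@rank, counted only at useful ranks
    return round(score / useful_total, 4)


def blended_correctness(tp: int, fp: int, fn: int, judge_score: float) -> float:
    """75% fact-level F1 plus 25% of the judge's holistic score."""
    denom = 2 * tp + fp + fn
    f1 = 2 * tp / denom if denom > 0 else 0.0
    return round(0.75 * f1 + 0.25 * judge_score, 4)


if __name__ == "__main__":
    print(weighted_precision_at_k([True, False, True]))
    print(blended_correctness(3, 1, 1, 0.8))
```

With useful chunks at ranks 1 and 3, the first call evaluates (1/1 + 2/3) / 2; moving the second useful chunk up to rank 2 would raise the score to 1.0, which is how the metric rewards ranking useful chunks early.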
0
sdk/rag_eval/multi_hop/__init__.py
Normal file
115
sdk/rag_eval/multi_hop/cli.py
Normal file
@ -0,0 +1,115 @@
"""
Multi-hop recall test CLI.

Usage:
    python -m rag_eval.multi_hop.cli \\
        --env-url https://your-dagent-env.com \\
        --org-id cd6e121594984516... \\
        --qa-file path/to/multi_hop.md \\
        --top-k 10 \\
        --concurrency 5 \\
        --output report.json
"""
import argparse
import asyncio
import json
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from rag_eval.multi_hop.parser import parse_multi_hop_file
from rag_eval.multi_hop.tester import MultiHopTester
from rag_eval.multi_hop.report import build_report
from rag_eval.single_jump.mapper import FileMapper


async def run(args):
    # 1. Parse the MD file
    print(f"[1/4] 解析多跳问答文件: {args.qa_file}")
    case = parse_multi_hop_file(args.qa_file)
    qa_pairs = case.qa_pairs
    if not qa_pairs:
        print("ERROR: 未解析到任何多跳问答对,请检查文件格式")
        sys.exit(1)
    print(f" 共 {len(qa_pairs)} 个问题,"
          f"hop 数分布: {_hop_dist(qa_pairs)}")

    # 2. Fetch the knowledge-base file list and build the section_path -> file_id mapping
    print(f"[2/4] 拉取知识库文件列表...")
    mapper = FileMapper(args.env_url, args.org_id, args.d_user_id)
    file_count = await mapper.load_files()
    print(f" 共 {file_count} 个文件")

    # Collect the section_path of every hop and map them all in one pass
    all_paths = {hop.section_path for qa in qa_pairs for hop in qa.hops}
    file_map = {path: mapper.map_section_to_file(path) for path in all_paths}

    mapped = sum(1 for v in file_map.values() if v)
    unmapped = sum(1 for v in file_map.values() if not v)
    print(f" 映射成功: {mapped} 未映射: {unmapped}")
    if unmapped:
        for path, v in file_map.items():
            if not v:
                print(f" [未映射] {path}")

    # 3. Run the multi-hop recall test
    print(f"[3/4] 执行召回测试 (top_k={args.top_k}, concurrency={args.concurrency})...")
    tester = MultiHopTester(args.env_url, args.org_id, args.d_user_id)

    done_count = 0

    async def progress_cb(result, done, total):
        nonlocal done_count
        done_count = done
        status = "全命中" if result.full_hit else (
            f"部分命中({result.hop_hit_count}/{result.hop_count})" if result.partial_hit else "未命中"
        )
        if result.error:
            status = f"ERROR: {result.error[:40]}"
        print(f" [{done:>4}/{total}] {result.qid} {status}")

    results = await tester.run(
        qa_pairs,
        file_map,
        top_k=args.top_k,
        concurrency=args.concurrency,
        result_cb=progress_cb,
    )

    # 4. Build the report
    print(f"[4/4] 生成报告...")
    report = build_report(results, args.env_url, args.org_id, args.top_k)
    print()
    print(report.summary())

    if args.output:
        out_path = Path(args.output)
        out_path.write_text(
            json.dumps(report.to_dict(), ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        print(f"\n报告已保存: {out_path}")


def _hop_dist(qa_pairs) -> str:
    from collections import Counter
    c = Counter(len(qa.hops) for qa in qa_pairs)
    return " ".join(f"{k}跳×{v}" for k, v in sorted(c.items()))


def main():
    parser = argparse.ArgumentParser(description="多跳召回测试")
    parser.add_argument("--env-url", required=True, help="Dagent 环境地址")
    parser.add_argument("--org-id", required=True, help="组织 ID")
    parser.add_argument("--d-user-id", default="test", help="d-user-id 请求头")
    parser.add_argument("--qa-file", required=True, help="多跳问答 MD 文件路径")
    parser.add_argument("--top-k", type=int, default=10, help="召回数量(建议 ≥10)")
    parser.add_argument("--concurrency", type=int, default=5, help="并发数")
    parser.add_argument("--output", default=None, help="报告输出路径(JSON)")
    args = parser.parse_args()
    asyncio.run(run(args))


if __name__ == "__main__":
    main()
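The hop distribution printed in step 1 of `run` is a count of questions grouped by how many hops each one declares. A minimal sketch follows, assuming `SimpleNamespace` stand-ins for the parsed QA pairs (the real objects are `MultiHopQAPair` instances) and a hypothetical function name `hop_distribution`:

```python
from collections import Counter
from types import SimpleNamespace


def hop_distribution(qa_pairs) -> str:
    """Summarize how many questions have 2 hops, 3 hops, and so on."""
    counts = Counter(len(qa.hops) for qa in qa_pairs)
    # Sort by hop count so the summary reads in ascending order
    return " ".join(f"{hops}跳×{n}" for hops, n in sorted(counts.items()))


if __name__ == "__main__":
    pairs = [
        SimpleNamespace(hops=["a", "b"]),
        SimpleNamespace(hops=["a", "b", "c"]),
        SimpleNamespace(hops=["a", "b"]),
    ]
    print(hop_distribution(pairs))
```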
23
sdk/rag_eval/multi_hop/example.md
Normal file
@ -0,0 +1,23 @@
## MH1
**类型:** comparison
**问题:** RDK X3 和 RDK X5 的 CPU 核心数和主频分别是多少,有何差异?
**答案:** RDK X3 搭载 4 核 ARM Cortex-A53,主频 1.2GHz;RDK X5 搭载 8 核 ARM Cortex-A55,主频 1.5GHz,X5 核心数翻倍且主频更高。
**Hop1:** hardware / rdk_x3_spec | 提供 RDK X3 的 CPU 规格参数
**Hop2:** hardware / rdk_x5_spec | 提供 RDK X5 的 CPU 规格参数
---

## MH2
**类型:** reasoning
**问题:** 使用 RDK 开发板进行 BPU 推理时,需要先完成哪些环境准备步骤?
**答案:** 需要先完成系统烧录、驱动安装,再配置 Python 环境,最后安装 horizon_bpu 推理库。
**Hop1:** quick_start / system_install | 提供系统烧录和驱动安装步骤
**Hop2:** linux_development / bpu_develop | 提供 BPU 推理环境配置和库安装步骤
---

## MH3
**类型:** aggregation
**问题:** RDK 平台支持哪些多媒体编解码格式,对应的硬件加速模块是什么?
**答案:** 支持 H.264/H.265 编解码,由 VPU 硬件模块加速;支持 JPEG 编解码,由 JPU 模块加速。
**Hop1:** multimedia_development / codec_overview | 提供支持的编解码格式列表
**Hop2:** hardware / hardware_modules | 提供 VPU/JPU 硬件模块说明
---
130
sdk/rag_eval/multi_hop/parser.py
Normal file
@ -0,0 +1,130 @@
"""
Parser for multi-hop QA Markdown files.

File format:
    ## MH1
    **类型:** comparison
    **问题:** How do the interface specs of product A and product B differ?
    **答案:** Product A ..., product B ...
    **Hop1:** linux_development / bsp_develop | This chunk provides product A's interface specs
    **Hop2:** hardware / interface_spec | This chunk provides product B's interface specs
    ---
"""
import re
from dataclasses import dataclass, field


@dataclass
class Hop:
    section_path: str   # Path identifier of the knowledge-base file; same format as the single-hop section_path
    contribution: str   # What information this hop provides
    chunk_id: str = ""  # Expected chunk ID (paragraph_chunk_id); if empty, falls back to file-level hits only


@dataclass
class MultiHopQAPair:
    qid: str                 # MH1, MH2, ...
    question: str
    answer: str
    hops: list[Hop]          # at least 2
    type: str = "reasoning"  # comparison / reasoning / aggregation


@dataclass
class MultiHopCase:
    """A group of multi-hop QA pairs, corresponding to one MD file."""
    qa_pairs: list[MultiHopQAPair] = field(default_factory=list)


def parse_multi_hop_file(filepath: str) -> MultiHopCase:
    with open(filepath, encoding="utf-8") as f:
        content = f.read()
    return parse_multi_hop_text(content)


def parse_multi_hop_text(content: str) -> MultiHopCase:
    """Parse multi-hop QA pairs from text content."""
    case = MultiHopCase()
    current: dict | None = None

    def _flush():
        # Emit the block collected so far; incomplete blocks are silently dropped
        if not current:
            return
        qid = current.get("qid", "")
        question = current.get("question", "").strip()
        answer = current.get("answer", "").strip()
        hops = current.get("hops", [])
        qtype = current.get("type", "reasoning")
        if qid and question and answer and len(hops) >= 2:
            case.qa_pairs.append(MultiHopQAPair(
                qid=qid,
                question=question,
                answer=answer,
                hops=hops,
                type=qtype,
            ))

    for line in content.splitlines():
        # New question block: ## MH1
        m = re.match(r"^## (MH\d+)\s*$", line)
        if m:
            _flush()
            current = {"qid": m.group(1), "hops": []}
            continue

        if current is None:
            continue

        # Type
        m = re.match(r"^\*\*类型[::]\*\*\s*(.+)$", line)
        if m:
            current["type"] = m.group(1).strip()
            continue

        # Question
        m = re.match(r"^\*\*问题[::]\*\*\s*(.+)$", line)
        if m:
            current["question"] = m.group(1).strip()
            continue

        # Answer
        m = re.match(r"^\*\*答案[::]\*\*\s*(.+)$", line)
        if m:
            current["answer"] = m.group(1).strip()
            continue

        # Hop: **Hop1:** section_path | contribution [| chunk_id]
        m = re.match(r"^\*\*Hop\d+[::]\*\*\s*(.+)$", line)
        if m:
            raw = m.group(1).strip()
            parts = [p.strip() for p in raw.split("|")]
            path = parts[0] if parts else ""
            contrib = parts[1] if len(parts) > 1 else ""
            chunk_id = parts[2] if len(parts) > 2 else ""
            current["hops"].append(Hop(
                section_path=path,
                contribution=contrib,
                chunk_id=chunk_id,
            ))
            continue

    _flush()
    return case


def dump_multi_hop_md(qa_pairs: list[MultiHopQAPair]) -> str:
    """Serialize multi-hop QA pairs to the MD format (for generation/export)."""
    lines = []
    for qa in qa_pairs:
        lines.append(f"## {qa.qid}")
        lines.append(f"**类型:** {qa.type}")
        lines.append(f"**问题:** {qa.question}")
        lines.append(f"**答案:** {qa.answer}")
        for i, hop in enumerate(qa.hops, 1):
            if hop.chunk_id:
                lines.append(f"**Hop{i}:** {hop.section_path} | {hop.contribution} | {hop.chunk_id}")
            else:
                lines.append(f"**Hop{i}:** {hop.section_path} | {hop.contribution}")
        lines.append("---")
        lines.append("")
    return "\n".join(lines)
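The `**HopN:**` line format handled by the parser (`section_path | contribution [| chunk_id]`) can be exercised in isolation. This is a minimal standalone sketch rather than the module's own code: the `Hop` dataclass is re-declared locally so the block runs on its own, `parse_hop_line` is a hypothetical helper name, and accepting the fullwidth colon U+FF1A alongside the ASCII colon is an assumption made here.

```python
from __future__ import annotations

import re
from dataclasses import dataclass

# Matches "**Hop1:** ..." with an ASCII colon or (assumed here) a fullwidth colon
HOP_LINE = re.compile(r"^\*\*Hop\d+[:\uFF1A]\*\*\s*(.+)$")


@dataclass
class Hop:
    section_path: str
    contribution: str = ""
    chunk_id: str = ""


def parse_hop_line(line: str) -> Hop | None:
    """Return a Hop for a '**HopN:** path | contribution [| chunk_id]' line, else None."""
    m = HOP_LINE.match(line)
    if not m:
        return None
    parts = [p.strip() for p in m.group(1).split("|")]
    return Hop(
        section_path=parts[0],
        contribution=parts[1] if len(parts) > 1 else "",
        chunk_id=parts[2] if len(parts) > 2 else "",
    )


if __name__ == "__main__":
    print(parse_hop_line("**Hop1:** hardware / rdk_x3_spec | CPU specs | chunk-42"))
```

Because fields are split on `|`, a contribution that itself contains a pipe character would be cut short and its remainder read as the chunk ID; QA files are best written without pipes in the free-text field.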
178
sdk/rag_eval/multi_hop/report.py
Normal file
@ -0,0 +1,178 @@
"""
Multi-hop recall test report generation.
"""
from dataclasses import dataclass, field
from .tester import MultiHopResult


@dataclass
class MultiHopReport:
    env_url: str
    org_id: str
    top_k: int
    total: int
    error_count: int
    empty_count: int         # questions whose retrieved list is empty
    full_hit_count: int      # every hop was hit
    partial_hit_count: int   # at least one hop was hit (includes full hits)
    avg_hop_hit_rate: float  # mean per-question fraction of hops hit
    avg_latency_ms: float
    avg_best_sim: float | None
    by_type: dict            # {type: {total, full_hit, partial_hit}}
    results: list[MultiHopResult] = field(default_factory=list)

    @property
    def full_hit_rate(self) -> float:
        return round(self.full_hit_count / self.total, 4) if self.total else 0.0

    @property
    def partial_hit_rate(self) -> float:
        return round(self.partial_hit_count / self.total, 4) if self.total else 0.0

    @property
    def empty_rate(self) -> float:
        return round(self.empty_count / self.total, 4) if self.total else 0.0

    def summary(self) -> str:
        lines = [
            "=" * 60,
            "多跳召回测试报告",
            "=" * 60,
            f"环境: {self.env_url}",
            f"组织: {self.org_id}",
            f"top_k: {self.top_k}",
            f"总问题数: {self.total}",
            f"全命中率: {self.full_hit_rate:.1%} ({self.full_hit_count}/{self.total})",
            f"部分命中率: {self.partial_hit_rate:.1%} ({self.partial_hit_count}/{self.total})",
            f"空召回率: {self.empty_rate:.1%} ({self.empty_count}/{self.total})",
            f"平均hop命中: {self.avg_hop_hit_rate:.1%}",
            f"平均延迟: {self.avg_latency_ms:.0f} ms",
        ]
        if self.avg_best_sim is not None:
            lines.append(f"平均最佳相似度: {self.avg_best_sim:.4f}")
        if self.error_count:
            lines.append(f"错误数: {self.error_count}")

        if self.by_type:
            lines.append("")
            lines.append("按类型统计:")
            for qtype, stat in self.by_type.items():
                t = stat["total"]
                fh = stat["full_hit"]
                ph = stat["partial_hit"]
                lines.append(
                    f" {qtype:<15} 共{t:>4}题 全命中{fh/t:.1%} 部分命中{ph/t:.1%}"
                )

        lines.append("=" * 60)
        return "\n".join(lines)

    def to_dict(self) -> dict:
        return {
            "env_url": self.env_url,
            "org_id": self.org_id,
            "top_k": self.top_k,
            "total": self.total,
            "full_hit_count": self.full_hit_count,
            "full_hit_rate": self.full_hit_rate,
            "partial_hit_count": self.partial_hit_count,
            "partial_hit_rate": self.partial_hit_rate,
            "empty_count": self.empty_count,
            "empty_rate": self.empty_rate,
            "error_count": self.error_count,
            "avg_hop_hit_rate": self.avg_hop_hit_rate,
            "avg_latency_ms": self.avg_latency_ms,
            "avg_best_sim": self.avg_best_sim,
            "by_type": self.by_type,
            "results": [_result_to_dict(r) for r in self.results],
        }


def _result_to_dict(r: MultiHopResult) -> dict:
    return {
        "qid": r.qid,
        "question": r.question,
        "type": r.type,
        "full_hit": r.full_hit,
        "partial_hit": r.partial_hit,
        "hop_count": r.hop_count,
        "hop_hit_count": r.hop_hit_count,
        "latency_ms": r.latency_ms,
        "best_cosine_sim": r.best_cosine_sim,
        "error": r.error,
        "hops": [
            {
                "section_path": h.section_path,
                "file_id": h.file_id,
                "file_name": h.file_name,
                "hit": h.hit,
                "contribution": h.contribution,
            }
            for h in r.hop_results
        ],
        "retrieved_file_ids": list(r.retrieved_file_ids),
    }


def build_report(
    results: list[MultiHopResult],
    env_url: str,
    org_id: str,
    top_k: int,
) -> MultiHopReport:
    total = len(results)
    if total == 0:
        return MultiHopReport(
            env_url=env_url, org_id=org_id, top_k=top_k,
            total=0, error_count=0, empty_count=0,
            full_hit_count=0, partial_hit_count=0,
            avg_hop_hit_rate=0.0, avg_latency_ms=0.0,
            avg_best_sim=None, by_type={}, results=[],
        )

    error_count = sum(1 for r in results if r.error)
    # MultiHopResult defines no `is_empty` attribute; an empty `retrieved` list is the empty-recall case
    empty_count = sum(1 for r in results if not r.retrieved and not r.error)
    full_hit_count = sum(1 for r in results if r.full_hit)
    partial_hit_count = sum(1 for r in results if r.partial_hit)

    # Average hop hit rate (only hops that have a file_id mapping are counted)
    hop_hit_rates = []
    for r in results:
        mappable = [h for h in r.hop_results if h.file_id]
        if mappable:
            hop_hit_rates.append(sum(1 for h in mappable if h.hit) / len(mappable))
    avg_hop_hit_rate = sum(hop_hit_rates) / len(hop_hit_rates) if hop_hit_rates else 0.0

    valid = [r for r in results if not r.error]
    avg_latency_ms = sum(r.latency_ms for r in valid) / len(valid) if valid else 0.0

    sims = [r.best_cosine_sim for r in valid if r.best_cosine_sim is not None]
    avg_best_sim = round(sum(sims) / len(sims), 4) if sims else None

    # Per-type statistics
    by_type: dict = {}
    for r in results:
        t = r.type
        if t not in by_type:
            by_type[t] = {"total": 0, "full_hit": 0, "partial_hit": 0}
        by_type[t]["total"] += 1
        if r.full_hit:
            by_type[t]["full_hit"] += 1
        if r.partial_hit:
            by_type[t]["partial_hit"] += 1

    return MultiHopReport(
        env_url=env_url,
        org_id=org_id,
        top_k=top_k,
        total=total,
        error_count=error_count,
        empty_count=empty_count,
        full_hit_count=full_hit_count,
        partial_hit_count=partial_hit_count,
        avg_hop_hit_rate=round(avg_hop_hit_rate, 4),
        avg_latency_ms=round(avg_latency_ms, 1),
        avg_best_sim=avg_best_sim,
        by_type=by_type,
        results=results,
    )
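The average hop hit rate in `build_report` is a mean of per-question ratios, and hops without a `file_id` mapping are left out of both numerator and denominator. A minimal sketch of that calculation follows, assuming a simplified stand-in dataclass `HopOutcome` (hypothetical, carrying only the two fields the formula reads) in place of `HopResult`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class HopOutcome:
    file_id: Optional[str]  # None when the section_path could not be mapped to a file
    hit: bool


def avg_hop_hit_rate(results: list[list[HopOutcome]]) -> float:
    """Mean over questions of (hit hops / mappable hops); unmappable hops are ignored."""
    rates = []
    for hops in results:
        mappable = [h for h in hops if h.file_id]
        if mappable:  # questions with no mappable hop do not enter the mean
            rates.append(sum(1 for h in mappable if h.hit) / len(mappable))
    return round(sum(rates) / len(rates), 4) if rates else 0.0


if __name__ == "__main__":
    q1 = [HopOutcome("a", True), HopOutcome(None, False), HopOutcome("b", False)]
    q2 = [HopOutcome("c", True), HopOutcome("d", True)]
    print(avg_hop_hit_rate([q1, q2]))
```

Here `q1` contributes 1/2 (its unmapped hop is skipped) and `q2` contributes 2/2, so an unmapped expected file neither penalizes nor rewards the retriever.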
46
sdk/rag_eval/multi_hop/test_parser.py
Normal file
@ -0,0 +1,46 @@
"""
Quick test of the multi-hop module's parsing and data structures.
"""
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from rag_eval.multi_hop.parser import parse_multi_hop_file, dump_multi_hop_md


def test_parser():
    print("=" * 60)
    print("测试多跳 MD 文件解析")
    print("=" * 60)

    example_file = Path(__file__).parent / "example.md"
    if not example_file.exists():
        print(f"ERROR: 示例文件不存在: {example_file}")
        return

    case = parse_multi_hop_file(str(example_file))
    print(f"\n解析结果: 共 {len(case.qa_pairs)} 个问题\n")

    for qa in case.qa_pairs:
        print(f"问题 {qa.qid} ({qa.type}):")
        print(f" Q: {qa.question}")
        print(f" A: {qa.answer[:80]}...")
        print(f" Hops ({len(qa.hops)}):")
        for i, hop in enumerate(qa.hops, 1):
            print(f" {i}. {hop.section_path}")
            print(f" → {hop.contribution}")
        print()

    # Test serialization
    print("=" * 60)
    print("测试序列化")
    print("=" * 60)
    md_text = dump_multi_hop_md(case.qa_pairs)
    print(md_text[:500])
    print("...")
    print("\nOK: 解析和序列化测试通过")


if __name__ == "__main__":
    test_parser()
334
sdk/rag_eval/multi_hop/tester.py
Normal file
@ -0,0 +1,334 @@
"""
Multi-hop recall test runner v4

Strategy: call dagent's /agent/chat SSE endpoint and let the Agent decide on its own how many searches to run and which queries to use.
Parse the TOOL_END events in the SSE stream, collect the documents retrieved at each hop, and compare them with the expected hops.
"""
import asyncio
import json
import time
from dataclasses import dataclass, field

import aiohttp

from .parser import MultiHopQAPair, Hop


@dataclass
class HopResult:
    section_path: str
    file_id: str | None
    file_name: str | None
    contribution: str
    expected_chunk_id: str = ""  # expected chunk ID
    hit: bool = False            # file-level hit
    hit_at_hop: int | None = None
    chunk_hit: bool = False      # chunk-level hit
    chunk_hit_at_hop: int | None = None


@dataclass
class ActualHop:
    """One hop actually executed by the Agent."""
    hop_index: int
    query: str
    retrieved: list[dict]


@dataclass
class MultiHopResult:
    qid: str
    question: str
    answer: str
    type: str
    top_k: int
    hop_results: list[HopResult]
    actual_hops: list[ActualHop] = field(default_factory=list)
    agent_answer: str = ""
    latency_ms: int = 0
    error: str | None = None

    @property
    def hop_count(self) -> int:
        return len(self.hop_results)

    @property
    def actual_hop_count(self) -> int:
        return len(self.actual_hops)

    @property
    def hop_hit_count(self) -> int:
        return sum(1 for h in self.hop_results if h.hit)

    @property
    def chunk_hit_count(self) -> int:
        return sum(1 for h in self.hop_results if h.chunk_hit)

    @property
    def full_hit(self) -> bool:
        mappable = [h for h in self.hop_results if h.file_id]
        return len(mappable) > 0 and all(h.hit for h in mappable)

    @property
    def full_chunk_hit(self) -> bool:
        mappable = [h for h in self.hop_results if h.expected_chunk_id]
        return len(mappable) > 0 and all(h.chunk_hit for h in mappable)

    @property
    def partial_hit(self) -> bool:
        return any(h.hit for h in self.hop_results)

    @property
    def partial_chunk_hit(self) -> bool:
        return any(h.chunk_hit for h in self.hop_results)

    @property
    def retrieved(self) -> list[dict]:
        """Retrieved results from all hops, merged and de-duplicated."""
        seen: set[str] = set()
        merged = []
        for ah in self.actual_hops:
            for doc in ah.retrieved:
                key = doc.get("file_id", "") + doc.get("headers", "")
                if key not in seen:
                    seen.add(key)
                    merged.append(doc)
        return merged

    @property
    def retrieved_file_ids(self) -> set[str]:
        return {r.get("file_id", "") for r in self.retrieved if r.get("file_id")}

    @property
    def best_cosine_sim(self) -> float | None:
        sims = [1.0 - r.get("cosine_distance_1", 1.0)
                for r in self.retrieved if r.get("cosine_distance_1") is not None]
        return round(max(sims), 4) if sims else None


async def _parse_agent_chat_sse(
    session: aiohttp.ClientSession,
    url: str,
    payload: dict,
    timeout_s: int = 300,
) -> tuple[list[ActualHop], str]:
    """
    Call the /agent/chat SSE endpoint and parse the events in the stream.

    Returns: (actual_hops, agent_answer)

    SSE format: one `data: {...}` message per line; lines are separated by a single \n (not \n\n).
    """
    import re as _re

    actual_hops: list[ActualHop] = []
    answer_chunks: list[str] = []
    tool_query = ""
    hop_index = 0

    async with session.post(
        url, json=payload,
        timeout=aiohttp.ClientTimeout(total=timeout_s),
    ) as resp:
        resp.raise_for_status()
        # Read line by line: the server sends one data: message per line
        line_buf = ""
        async for raw in resp.content:
            line_buf += raw.decode("utf-8", errors="replace")
            # Split on newlines, keeping any incomplete trailing line in the buffer
            while "\n" in line_buf:
                line, line_buf = line_buf.split("\n", 1)
                line = line.rstrip("\r")
                if not line.startswith("data:"):
                    continue
                data_str = line[5:].strip()
                if not data_str or data_str == "[DONE]":
                    continue
                try:
                    parsed = json.loads(data_str)
                except json.JSONDecodeError:
                    continue

                mt = parsed.get("message_type", "")
                is_chunk = parsed.get("is_chunk_data", False)
                data = parsed.get("data", "")

                # Collect the Agent's final answer
                if is_chunk and mt not in ("THINKING_CHUNK", "EVENT"):
                    if isinstance(data, str):
                        answer_chunks.append(data)

                # Collect the query argument streamed in TOOL_CHUNK messages
                if mt == "TOOL_CHUNK" and is_chunk and isinstance(data, str):
                    tool_query += data

                # Parse EVENT messages
                if mt == "EVENT" and not is_chunk:
                    try:
                        ed = json.loads(data) if isinstance(data, str) else data
                    except (json.JSONDecodeError, TypeError):
                        continue
                    if not isinstance(ed, dict):
                        continue
                    ename = ed.get("event_name", "")

                    if ename == "TOOL_START":
                        tool_query = ""

                    elif ename == "TOOL_END":
                        edata = ed.get("event_data")
                        docs = []
                        if isinstance(edata, dict) and "items" in edata:
                            for item in edata["items"]:
                                file_id = str(item.get("file_id") or "")
                                chunk_id = str(item.get("paragraph_chunk_id") or "")
                                # Skip external-link tools (no file_id/chunk_id)
                                if not file_id and not chunk_id:
                                    continue
                                docs.append({
                                    "file_id": file_id,
                                    "headers": item.get("headers", ""),
                                    "paragraph_md5": item.get("paragraph_md5", ""),
                                    "paragraph_chunk_id": chunk_id,
                                })

                        # Only record hops that actually retrieved knowledge chunks
                        if docs:
                            hop_index += 1
                            query_match = _re.search(
                                r"<query>(.*?)</query>", tool_query, _re.DOTALL
                            )
                            query_text = (
                                query_match.group(1).strip()
                                if query_match
                                else tool_query.strip()
                            )
                            actual_hops.append(ActualHop(
                                hop_index=hop_index,
                                query=query_text,
                                retrieved=docs,
                            ))
                        tool_query = ""

    agent_answer = "".join(answer_chunks).strip()
    return actual_hops, agent_answer


class MultiHopTester:
    def __init__(self, env_url: str, org_id: str, d_user_id: str = "test",
                 agent_id: str = "", llm_type: str = "deepseek_v3"):
        self.env_url = env_url.rstrip("/")
        self.org_id = org_id
        self.agent_id = agent_id
        self.llm_type = llm_type
        self.headers = {
            "Content-Type": "application/json",
            "d-user-id": d_user_id,
            "org-id": org_id,
        }

    async def run(
        self,
        qa_pairs: list[MultiHopQAPair],
        file_map: dict[str, dict | None],
        top_k: int = 10,
        concurrency: int = 5,
        result_cb=None,
    ) -> list[MultiHopResult]:
        results: list[MultiHopResult] = []
        sem = asyncio.Semaphore(concurrency)
        total = len(qa_pairs)
        done = 0

        # NOTE: ssl=False disables TLS certificate verification
        connector = aiohttp.TCPConnector(ssl=False)
        async with aiohttp.ClientSession(
            headers=self.headers, connector=connector
        ) as session:

            async def _test_one(qa: MultiHopQAPair) -> MultiHopResult:
                nonlocal done

                hop_results = []
                for hop in qa.hops:
                    mapping = file_map.get(hop.section_path)
                    hop_results.append(HopResult(
                        section_path=hop.section_path,
                        file_id=mapping["file_id"] if mapping else None,
                        file_name=mapping["file_name"] if mapping else None,
                        contribution=hop.contribution,
                        expected_chunk_id=hop.chunk_id or "",
                    ))

                result = MultiHopResult(
                    qid=qa.qid,
                    question=qa.question,
                    answer=qa.answer,
                    type=qa.type,
                    top_k=top_k,
                    hop_results=hop_results,
                )

                async with sem:
                    start = time.monotonic()
                    try:
                        import uuid
                        # Build the chat URL: if env_url ends with /dagent, append
                        # /agent/chat; otherwise append /dagent/agent/chat
                        base = self.env_url.rstrip("/")
                        if base.endswith("/dagent"):
                            chat_url = f"{base}/agent/chat"
                        else:
                            chat_url = f"{base}/dagent/agent/chat"
                        payload = {
                            "task": qa.question,
                            "agent_id": self.agent_id,
                            "chat_id": uuid.uuid4().hex,
                            "llm_type": self.llm_type,
                        }

                        actual_hops, agent_answer = await _parse_agent_chat_sse(
                            session, chat_url, payload, timeout_s=300,
                        )
                        result.actual_hops = actual_hops
                        result.agent_answer = agent_answer
                        result.latency_ms = int(
                            (time.monotonic() - start) * 1000
                        )

                        # File-level hit: does the expected file appear in the
                        # documents retrieved by any hop?
                        for hr in result.hop_results:
                            if hr.file_id:
                                for ah in actual_hops:
                                    if any(
                                        d.get("file_id") == hr.file_id
                                        for d in ah.retrieved
                                    ):
                                        hr.hit = True
                                        hr.hit_at_hop = ah.hop_index
                                        break
                            # Chunk-level hit: does the expected chunk_id appear in
                            # the documents retrieved by any hop?
                            if hr.expected_chunk_id:
                                for ah in actual_hops:
                                    if any(
                                        d.get("paragraph_chunk_id") == hr.expected_chunk_id
                                        for d in ah.retrieved
                                    ):
                                        hr.chunk_hit = True
                                        hr.chunk_hit_at_hop = ah.hop_index
                                        break

                    except Exception as e:
                        result.error = str(e)
                        result.latency_ms = int(
                            (time.monotonic() - start) * 1000
                        )

                done += 1
                if result_cb:
                    await result_cb(result, done, total)
                return result

            tasks = [_test_one(qa) for qa in qa_pairs]
            for coro in asyncio.as_completed(tasks):
                results.append(await coro)

        return results
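The line-buffering step in `_parse_agent_chat_sse` can be tried in isolation. The following is a minimal sketch, not the module's API: `iter_sse_json` is a hypothetical helper, and the `TEXT_CHUNK` message type in the sample stream is an assumed value used only for illustration. It shows how a `data:` message that arrives split across two reads is reassembled before it is parsed.

```python
import json
from typing import Iterable, Iterator


def iter_sse_json(chunks: Iterable[bytes]) -> Iterator[dict]:
    """Yield the JSON payload of every complete `data:` line in a byte stream.

    Hypothetical helper: it mirrors the buffering loop above without aiohttp.
    """
    buf = ""
    for raw in chunks:
        buf += raw.decode("utf-8", errors="replace")
        # Only complete lines are consumed; a partial tail stays in the buffer.
        while "\n" in buf:
            line, buf = buf.split("\n", 1)
            line = line.rstrip("\r")
            if not line.startswith("data:"):
                continue
            body = line[5:].strip()
            if not body or body == "[DONE]":
                continue
            try:
                yield json.loads(body)
            except json.JSONDecodeError:
                continue  # malformed lines are skipped, as in the parser above


# One message split across two reads, followed by the end-of-stream marker.
# "TEXT_CHUNK" is an assumed message type, not taken from the service.
stream = [
    b'data: {"message_type": "TEXT_',
    b'CHUNK", "data": "hi"}\n',
    b"data: [DONE]\n",
]
messages = list(iter_sse_json(stream))
```

Keeping the unfinished tail in the buffer is what makes the parser independent of how the transport happens to slice the stream.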
128
sdk/rag_eval/report.py
Normal file
@ -0,0 +1,128 @@
from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class SampleResult:
    sample_id: str
    question: str
    reference_answer: str
    # Retrieval
    retrieved_chunk_ids: list[str] = field(default_factory=list)
    retrieved_chunks: list[str] = field(default_factory=list)
    hit_rate: float | None = None
    mrr: float | None = None
    ndcg: float | None = None
    context_precision: float | None = None
    context_recall: float | None = None
    # Generation
    agent_answer: str = ""
    faithfulness: float | None = None
    answer_relevance: float | None = None
    answer_correctness: float | None = None
    groundedness: float | None = None
    latency_ms: int = 0
    # Raw judge output
    judge_detail: dict = field(default_factory=dict)
    error: str | None = None


@dataclass
class EvalReport:
    task_id: str
    dataset_name: str
    sample_count: int
    results: list[SampleResult]
    # Retrieval averages
    avg_hit_rate: float | None = None
    avg_mrr: float | None = None
    avg_ndcg: float | None = None
    avg_context_precision: float | None = None
    avg_context_recall: float | None = None
    # Generation averages
    avg_faithfulness: float | None = None
    avg_answer_relevance: float | None = None
    avg_answer_correctness: float | None = None
    avg_groundedness: float | None = None
    # Composite
    rag_score: float | None = None
    hallucination_rate: float | None = None
    created_at: datetime = field(default_factory=datetime.utcnow)

    def summary(self) -> str:
        lines = [
            "┌─────────────────────────────────────────┐",
            f"│{'Evaluation Report Summary':^41}│",
            "├──────────────────────┬──────────────────┤",
            f"│ {'Samples':<20} │ {self.sample_count:<16} │",
        ]
        def _row(label, val):
            v = f"{val:.4f}" if val is not None else "N/A"
            return f"│ {label:<20} │ {v:<16} │"

        lines += [
            _row("Hit Rate@K", self.avg_hit_rate),
            _row("MRR@K", self.avg_mrr),
            _row("NDCG@K", self.avg_ndcg),
            _row("Context Precision", self.avg_context_precision),
            _row("Context Recall", self.avg_context_recall),
            _row("Faithfulness", self.avg_faithfulness),
            _row("Answer Relevance", self.avg_answer_relevance),
            _row("Answer Correctness", self.avg_answer_correctness),
            _row("Groundedness", self.avg_groundedness),
            _row("RAG Score", self.rag_score),
            _row("Hallucination Rate", self.hallucination_rate),
            "└──────────────────────┴──────────────────┘",
        ]
        return "\n".join(lines)

    def to_dict(self) -> dict:
        return {
            "task_id": self.task_id,
            "dataset_name": self.dataset_name,
            "sample_count": self.sample_count,
            "created_at": self.created_at.isoformat(),
            "retrieval": {
                "avg_hit_rate": self.avg_hit_rate,
                "avg_mrr": self.avg_mrr,
                "avg_ndcg": self.avg_ndcg,
                "avg_context_precision": self.avg_context_precision,
                "avg_context_recall": self.avg_context_recall,
            },
            "generation": {
                "avg_faithfulness": self.avg_faithfulness,
                "avg_answer_relevance": self.avg_answer_relevance,
                "avg_answer_correctness": self.avg_answer_correctness,
                "avg_groundedness": self.avg_groundedness,
            },
            "composite": {
                "rag_score": self.rag_score,
                "hallucination_rate": self.hallucination_rate,
            },
            "results": [
                {
                    "sample_id": r.sample_id,
                    "question": r.question,
                    "agent_answer": r.agent_answer,
                    "retrieved_chunk_ids": r.retrieved_chunk_ids,
                    "hit_rate": r.hit_rate,
                    "mrr": r.mrr,
                    "ndcg": r.ndcg,
                    "context_precision": r.context_precision,
                    "context_recall": r.context_recall,
                    "faithfulness": r.faithfulness,
                    "answer_relevance": r.answer_relevance,
                    "answer_correctness": r.answer_correctness,
                    "groundedness": r.groundedness,
                    "latency_ms": r.latency_ms,
                    "error": r.error,
                }
                for r in self.results
            ],
        }

    def save(self, path: str):
        import json
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)
        print(f"Report saved to {path}")
257
sdk/rag_eval/runner.py
Normal file
@ -0,0 +1,257 @@
import asyncio
import uuid
from dataclasses import dataclass
from typing import Callable

from .adapters.base import RAGAdapter
from .judge.base import LLMJudge
from .evaluators.retrieval import hit_rate, mrr, ndcg
from .dataset.schema import EvalDataset, EvalSample
from .report import EvalReport, SampleResult


RETRIEVAL_METRIC_KEYS = {"hit_rate", "mrr", "ndcg", "context_precision", "context_recall"}
GENERATION_METRIC_KEYS = {"faithfulness", "answer_relevance", "answer_correctness", "groundedness"}


@dataclass
class RunConfig:
    agent_id: str
    knowledge_hub_id: str
    top_k: int = 10
    eval_retrieval: bool = True
    eval_generation: bool = True
    selected_metrics: list[str] | None = None
    file_id_list: list[str] | None = None
    concurrency: int = 3  # number of samples evaluated concurrently
    faithfulness_threshold: float = 0.7  # faithfulness below this counts as a hallucination

    def should_eval(self, metric_key: str) -> bool:
        """Return whether the given metric should be computed."""
        if self.selected_metrics:
            return metric_key in self.selected_metrics
        # Backward compatibility: when selected_metrics is not set, fall back to
        # the eval_retrieval/eval_generation switches
        if metric_key in RETRIEVAL_METRIC_KEYS:
            return self.eval_retrieval
        if metric_key in GENERATION_METRIC_KEYS:
            return self.eval_generation
        return True

    @property
    def need_retrieval(self) -> bool:
        if self.selected_metrics:
            return bool(set(self.selected_metrics) & RETRIEVAL_METRIC_KEYS)
        return self.eval_retrieval

    @property
    def need_generation(self) -> bool:
        if self.selected_metrics:
            return bool(set(self.selected_metrics) & GENERATION_METRIC_KEYS)
        return self.eval_generation


class EvalRunner:
    def __init__(self, adapter: RAGAdapter, judge: LLMJudge):
        self.adapter = adapter
        self.judge = judge

    async def run(
        self,
        dataset: EvalDataset | str,
        config: RunConfig,
        progress_cb: Callable[[int, int], None] | None = None,
    ) -> EvalReport:
        """
        Run the full evaluation pipeline.

        Args:
            dataset: an EvalDataset object or the path to a JSON file
            config: evaluation configuration
            progress_cb: progress callback (finished, total)
        """
        if isinstance(dataset, str):
            import json
            with open(dataset, encoding="utf-8") as f:
                dataset = EvalDataset.from_dict(json.load(f))

        samples = dataset.samples
        total = len(samples)
        results: list[SampleResult] = []
        finished = 0

        sem = asyncio.Semaphore(config.concurrency)

        async def _eval_one(sample: EvalSample) -> SampleResult:
            async with sem:
                return await self._eval_sample(sample, config)

        tasks = [_eval_one(s) for s in samples]

        for coro in asyncio.as_completed(tasks):
            result = await coro
            results.append(result)
            finished += 1
            if progress_cb:
                progress_cb(finished, total)

        return self._build_report(
            task_id=uuid.uuid4().hex,
            dataset=dataset,
            results=results,
            config=config,
        )

    async def _eval_sample(self, sample: EvalSample, config: RunConfig) -> SampleResult:
        result = SampleResult(
            sample_id=sample.id,
            question=sample.question,
            reference_answer=sample.reference_answer,
        )
        try:
            # ── Step 1: Retrieval ─────────────────────────────────────────
            if config.need_retrieval:
                chunks = await self.adapter.retrieve(
                    query=sample.question,
                    knowledge_hub_id=config.knowledge_hub_id,
                    top_k=config.top_k,
                    file_id_list=config.file_id_list,
                )
                result.retrieved_chunk_ids = [c.chunk_id for c in chunks]
                result.retrieved_chunks = [c.content for c in chunks]

                # Rule-based metrics
                if sample.relevant_chunk_ids:
                    if config.should_eval("hit_rate"):
                        result.hit_rate = hit_rate(result.retrieved_chunk_ids, sample.relevant_chunk_ids)
                    if config.should_eval("mrr"):
                        result.mrr = mrr(result.retrieved_chunk_ids, sample.relevant_chunk_ids)
                    if config.should_eval("ndcg"):
                        result.ndcg = ndcg(result.retrieved_chunk_ids, sample.relevant_chunk_ids, k=config.top_k)

                # LLM-as-Judge retrieval metrics
                if sample.reference_answer and result.retrieved_chunks:
                    if config.should_eval("context_precision"):
                        cp, raw_cp = await self.judge.score_context_precision(
                            sample.question, sample.reference_answer, result.retrieved_chunks
                        )
                        result.context_precision = cp
                        result.judge_detail["context_precision"] = raw_cp

                    if config.should_eval("context_recall"):
                        cr, raw_cr = await self.judge.score_context_recall(
                            sample.reference_answer, result.retrieved_chunks
                        )
                        result.context_recall = cr
                        result.judge_detail["context_recall"] = raw_cr

            # ── Step 2: Generation ────────────────────────────────────────
            if config.need_generation:
                agent_resp = await self.adapter.chat(
                    query=sample.question,
                    agent_id=config.agent_id,
                )
                result.agent_answer = agent_resp.answer
                result.latency_ms = agent_resp.latency_ms

                # If the retrieval step was skipped, retrieve once here so the
                # generation metrics have context to be judged against
                if not result.retrieved_chunks:
                    try:
                        chunks = await self.adapter.retrieve(
                            query=sample.question,
                            knowledge_hub_id=config.knowledge_hub_id,
                            top_k=config.top_k,
                            file_id_list=config.file_id_list,
                        )
                        result.retrieved_chunk_ids = [c.chunk_id for c in chunks]
                        result.retrieved_chunks = [c.content for c in chunks]
                    except Exception:
                        pass

                if result.agent_answer and result.retrieved_chunks:
                    if config.should_eval("faithfulness"):
                        faith, raw_faith = await self.judge.score_faithfulness(
                            result.agent_answer, result.retrieved_chunks
                        )
                        result.faithfulness = faith
                        result.judge_detail["faithfulness"] = raw_faith

                    if config.should_eval("answer_relevance"):
                        rel, raw_rel = await self.judge.score_relevance(
                            sample.question, result.agent_answer
                        )
                        result.answer_relevance = rel
                        result.judge_detail["answer_relevance"] = raw_rel

                    if config.should_eval("groundedness"):
                        ground, raw_ground = await self.judge.score_groundedness(
                            result.agent_answer,
                            [{"content": c} for c in result.retrieved_chunks],
                        )
                        result.groundedness = ground
                        result.judge_detail["groundedness"] = raw_ground

                    if config.should_eval("answer_correctness") and sample.reference_answer:
                        corr, raw_corr = await self.judge.score_correctness(
                            result.agent_answer, sample.reference_answer
                        )
                        result.answer_correctness = corr
                        result.judge_detail["answer_correctness"] = raw_corr

        except Exception as exc:
            result.error = str(exc)

        return result

    def _build_report(
        self,
        task_id: str,
        dataset: EvalDataset,
        results: list[SampleResult],
        config: RunConfig,
    ) -> EvalReport:
        def _avg(vals: list[float]) -> float | None:
            v = [x for x in vals if x is not None]
            return round(sum(v) / len(v), 4) if v else None

        def _collect(attr: str) -> list[float]:
            return [getattr(r, attr) for r in results if getattr(r, attr) is not None]

        avg_hit_rate = _avg(_collect("hit_rate"))
        avg_mrr = _avg(_collect("mrr"))
        avg_ndcg = _avg(_collect("ndcg"))
        avg_ctx_prec = _avg(_collect("context_precision"))
        avg_ctx_rec = _avg(_collect("context_recall"))
        avg_faithfulness = _avg(_collect("faithfulness"))
        avg_answer_relevance = _avg(_collect("answer_relevance"))
        avg_answer_correctness = _avg(_collect("answer_correctness"))
        avg_groundedness = _avg(_collect("groundedness"))

        # RAG Score: harmonic mean of the four core averages (faithfulness,
        # answer relevance, context precision, context recall); averages that
        # are missing or zero are left out of the mean
        core = [s for s in [avg_faithfulness, avg_answer_relevance, avg_ctx_prec, avg_ctx_rec]
                if s is not None and s > 0]
        rag_score = round(len(core) / sum(1 / s for s in core), 4) if core else None

        # Hallucination rate: share of scored samples whose faithfulness is
        # below config.faithfulness_threshold
        faith_vals = _collect("faithfulness")
        hallucination_rate = (
            round(sum(1 for f in faith_vals if f < config.faithfulness_threshold) / len(faith_vals), 4)
            if faith_vals else None
        )

        return EvalReport(
            task_id=task_id,
            dataset_name=dataset.name,
            sample_count=len(results),
            results=results,
            avg_hit_rate=avg_hit_rate,
            avg_mrr=avg_mrr,
            avg_ndcg=avg_ndcg,
            avg_context_precision=avg_ctx_prec,
            avg_context_recall=avg_ctx_rec,
            avg_faithfulness=avg_faithfulness,
            avg_answer_relevance=avg_answer_relevance,
            avg_answer_correctness=avg_answer_correctness,
            avg_groundedness=avg_groundedness,
            rag_score=rag_score,
            hallucination_rate=hallucination_rate,
        )
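The two composite figures produced by `_build_report` are easy to check by hand. Below is a small standalone sketch of the same arithmetic; `harmonic_mean` and `hallucination_rate` are hypothetical helper names introduced here for illustration, and the input numbers are made up.

```python
def harmonic_mean(scores):
    """Harmonic mean of the positive, non-missing scores, or None if none remain."""
    core = [s for s in scores if s is not None and s > 0]
    return round(len(core) / sum(1 / s for s in core), 4) if core else None


def hallucination_rate(faithfulness, threshold=0.7):
    """Share of scored samples whose faithfulness is strictly below the threshold."""
    vals = [f for f in faithfulness if f is not None]
    return round(sum(1 for f in vals if f < threshold) / len(vals), 4) if vals else None


# Illustrative inputs only: three available averages, one missing.
rag = harmonic_mean([0.9, 0.8, None, 0.5])

# Five samples, one without a faithfulness score; 0.65 and 0.4 fall below 0.7.
rate = hallucination_rate([0.9, 0.65, 0.4, None, 0.7])
```

Because the harmonic mean is dominated by its smallest term, one weak metric pulls the score down more than an arithmetic mean would. Note also that a zero average is skipped rather than forcing the score to zero, so a report with a zero metric can still show a positive RAG score.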
354
sdk/rag_eval/semantic_coverage.py
Normal file
@ -0,0 +1,354 @@
"""
Semantic coverage monitoring module.

A semantic coverage scheme based on nearest-neighbour distance:
- Compute the semantic distance between new questions and the existing question set
- When the average distance falls below the threshold, the chunk's question space is considered sufficiently explored
- Used to decide when loop testing should stop
"""
import asyncio
import json
import numpy as np
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from datetime import datetime


@dataclass
class SemanticCoverageResult:
    """Semantic coverage result."""
    chunk_id: str
    total_questions: int
    avg_neighbor_distance: float
    min_neighbor_distance: float
    coverage_score: float  # 0-1; higher means more thorough coverage
    is_converged: bool
    recommended_action: str  # 'continue', 'stop', 'reduce'


class SemanticCoverageMonitor:
    """
    Semantic coverage monitor.

    Algorithm:
    1. Represent the meaning of each question with an embedding
    2. For each question, compute its minimum distance to the other questions
    3. When the average minimum distance < threshold, treat the chunk as converged
    """

    def __init__(
        self,
        threshold: float = 0.15,
        min_questions: int = 3,
        max_questions: int = 20,
        embedding_model: str = "text-embedding-3-small",
    ):
        self.threshold = threshold
        self.min_questions = min_questions
        self.max_questions = max_questions
        self.embedding_model = embedding_model
        self._embeddings_cache: Dict[str, List[float]] = {}

    async def _get_embedding(self, text: str, client) -> List[float]:
        """Return the embedding of the text, using the in-memory cache."""
        # Key on the text itself: this matches the Dict[str, ...] annotation
        # and avoids hash collisions between different questions
        cache_key = text
        if cache_key in self._embeddings_cache:
            return self._embeddings_cache[cache_key]

        resp = await client.embeddings.create(
            model=self.embedding_model,
            input=text,
        )
        embedding = resp.data[0].embedding
        self._embeddings_cache[cache_key] = embedding
        return embedding

    def _cosine_distance(self, a: List[float], b: List[float]) -> float:
        """Compute the cosine distance (1 - cosine similarity)."""
        a_np = np.array(a)
        b_np = np.array(b)
        cos_sim = np.dot(a_np, b_np) / (np.linalg.norm(a_np) * np.linalg.norm(b_np) + 1e-9)
        return float(1.0 - cos_sim)

    def _calculate_coverage_metrics(
        self,
        embeddings: List[List[float]]
    ) -> Tuple[float, float]:
        """
        Compute the coverage metrics.

        Returns: (mean nearest-neighbour distance, smallest nearest-neighbour distance)
        """
        if len(embeddings) < 2:
            return 1.0, 1.0  # too few questions: fall back to a distance of 1.0

        min_distances = []

        for i, emb_i in enumerate(embeddings):
            # Distances to every other question
            other_distances = []
            for j, emb_j in enumerate(embeddings):
                if i != j:
                    dist = self._cosine_distance(emb_i, emb_j)
                    other_distances.append(dist)

            if other_distances:
                min_dist = min(other_distances)
                min_distances.append(min_dist)

        # Convert to a plain float so later comparisons yield a built-in bool,
        # which json.dumps can serialize (numpy.bool_ cannot be serialized)
        avg_min_distance = float(np.mean(min_distances)) if min_distances else 1.0
        min_distance = min(min_distances) if min_distances else 1.0

        return avg_min_distance, min_distance

    async def evaluate_chunk_coverage(
        self,
        chunk_id: str,
        questions: List[str],
        client,
    ) -> SemanticCoverageResult:
        """
        Evaluate the semantic coverage of a single chunk.

        Args:
            chunk_id: chunk ID
            questions: the questions already generated for this chunk
            client: OpenAI client used to fetch embeddings

        Returns:
            SemanticCoverageResult: the coverage evaluation result
        """
        total = len(questions)

        # Not enough questions yet
        if total < self.min_questions:
            return SemanticCoverageResult(
                chunk_id=chunk_id,
                total_questions=total,
                avg_neighbor_distance=1.0,
                min_neighbor_distance=1.0,
                coverage_score=0.0,
                is_converged=False,
                recommended_action='continue',
            )

        # Question count has reached the upper limit
        if total >= self.max_questions:
            return SemanticCoverageResult(
                chunk_id=chunk_id,
                total_questions=total,
                avg_neighbor_distance=0.0,
                min_neighbor_distance=0.0,
                coverage_score=1.0,
                is_converged=True,
                recommended_action='stop',
            )

        # Compute embeddings
        embeddings = []
        for q in questions:
            emb = await self._get_embedding(q, client)
            embeddings.append(emb)

        # Compute the coverage metrics
        avg_dist, min_dist = self._calculate_coverage_metrics(embeddings)

        # Compute the coverage score (0-1):
        # the smaller the distance, the higher the coverage
        coverage_score = max(0.0, 1.0 - (avg_dist / self.threshold))

        # Decide whether the chunk has converged
        is_converged = avg_dist < self.threshold

        # Recommended action
        if is_converged:
            recommended_action = 'stop'
        elif total > self.max_questions * 0.8:
            recommended_action = 'reduce'  # generate fewer questions
        else:
            recommended_action = 'continue'

        return SemanticCoverageResult(
            chunk_id=chunk_id,
            total_questions=total,
            avg_neighbor_distance=avg_dist,
            min_neighbor_distance=min_dist,
            coverage_score=coverage_score,
            is_converged=is_converged,
            recommended_action=recommended_action,
        )

    async def evaluate_batch_coverage(
        self,
        chunk_questions: Dict[str, List[str]],
        client,
    ) -> Dict[str, SemanticCoverageResult]:
        """
        Evaluate the coverage of a batch of chunks.

        Args:
            chunk_questions: {chunk_id: [question1, question2, ...]}
            client: OpenAI client

        Returns:
            {chunk_id: SemanticCoverageResult}
        """
        results = {}
        for chunk_id, questions in chunk_questions.items():
            result = await self.evaluate_chunk_coverage(chunk_id, questions, client)
            results[chunk_id] = result
        return results

    def get_batch_summary(
        self,
        results: Dict[str, SemanticCoverageResult]
    ) -> Dict:
        """Summarize coverage across the batch."""
        total_chunks = len(results)
        converged_chunks = sum(1 for r in results.values() if r.is_converged)
        total_questions = sum(r.total_questions for r in results.values())
        avg_coverage = np.mean([r.coverage_score for r in results.values()]) if results else 0.0

        return {
            "total_chunks": total_chunks,
            "converged_chunks": converged_chunks,
            "convergence_rate": converged_chunks / total_chunks if total_chunks > 0 else 0.0,
            "total_questions": total_questions,
            "avg_coverage_score": avg_coverage,
            "should_stop": converged_chunks / total_chunks > 0.9 if total_chunks > 0 else False,
        }


class LoopConvergenceChecker:
    """
    Convergence checker for loop tasks.

    Integrated into loop_engine to decide whether the loop should stop.
    """

    def __init__(self, monitor: SemanticCoverageMonitor):
        self.monitor = monitor

    async def check_convergence(
        self,
        qa_task_id: str,
        client,
    ) -> Tuple[bool, Dict]:
        """
        Check whether the loop task has converged.

        Returns:
            (whether it has converged, details)
        """
        # Fetch all questions for this task from the database
        from server.models.db import get_db

        chunk_questions = {}
        async with get_db() as db:
            rows = await db.execute_fetchall(
                """SELECT chunk_id, question
                FROM qa_gen_question
                WHERE task_id=? AND status='approved' AND chunk_id IS NOT NULL""",
                (qa_task_id,)
            )

        for row in rows:
            chunk_id = row["chunk_id"]
            question = row["question"]
            if chunk_id not in chunk_questions:
                chunk_questions[chunk_id] = []
            chunk_questions[chunk_id].append(question)

        if not chunk_questions:
            return False, {"reason": "no_questions_yet"}

        # Evaluate coverage
        results = await self.monitor.evaluate_batch_coverage(chunk_questions, client)
        summary = self.monitor.get_batch_summary(results)

        # Apply the convergence condition
        should_stop = summary["should_stop"]

        details = {
            "summary": summary,
            "chunk_details": {
                chunk_id: {
                    "questions": r.total_questions,
                    "coverage_score": r.coverage_score,
                    "is_converged": r.is_converged,
                    "action": r.recommended_action,
                }
                for chunk_id, r in results.items()
            },
        }

        return should_stop, details


# Example code for integrating into loop_engine (for reference)
LOOP_ENGINE_INTEGRATION = '''
# Add to the _do_run_loop function in loop_engine.py

async def _check_semantic_convergence(
    self,
    qa_task_id: str,
    llm_client,
) -> Tuple[bool, Dict]:
    """Check whether semantic coverage has converged."""
    from .semantic_coverage import SemanticCoverageMonitor, LoopConvergenceChecker

    monitor = SemanticCoverageMonitor(
        threshold=0.15,
        min_questions=3,
        max_questions=20,
    )
    checker = LoopConvergenceChecker(monitor)

    should_stop, details = await checker.check_convergence(qa_task_id, llm_client)
    return should_stop, details

# Call at the end of each round
should_stop, convergence_details = await self._check_semantic_convergence(
    qa_task_id, llm_client
)
if should_stop:
    print(f"[Loop] Semantic convergence reached, stopping...")
    break
'''


# Command-line tool
async def main():
    """Analyze the semantic coverage of the current task."""
    import argparse

    parser = argparse.ArgumentParser(description="Semantic coverage analysis tool")
    parser.add_argument("--task-id", help="QA generation task ID")
    parser.add_argument("--threshold", type=float, default=0.15, help="Convergence threshold")
    parser.add_argument("--base-url", default="https://api.openai.com/v1", help="API base URL")
    parser.add_argument("--api-key", required=True, help="API key")

    args = parser.parse_args()

    from openai import AsyncOpenAI

    client = AsyncOpenAI(
        base_url=args.base_url,
        api_key=args.api_key,
    )

    monitor = SemanticCoverageMonitor(threshold=args.threshold)
    checker = LoopConvergenceChecker(monitor)

    if args.task_id:
        should_stop, details = await checker.check_convergence(args.task_id, client)
        print(json.dumps(details, indent=2, ensure_ascii=False))
        print(f"\nRecommendation: {'stop' if should_stop else 'continue'} generating questions")
    else:
        print("Please provide the --task-id argument")


if __name__ == "__main__":
    asyncio.run(main())
0
sdk/rag_eval/single_jump/__init__.py
Normal file
137
sdk/rag_eval/single_jump/cli.py
Normal file
@ -0,0 +1,137 @@
"""
CLI entry point for single-jump recall testing.

Usage:
    python -m rag_eval.single_jump.cli \
        --env-url https://cloud-dev.d-robotics.cc \
        --org-id dc778d0ae0aade4c33e19342ddd4fe72e68021623de5ff0e7c6b63dc04c7a1a7 \
        --qa-file "D:/evb知识库/EVB知识库完整问答集.md" \
        --top-k 5 \
        --output report.json
"""
import asyncio
import argparse
import sys
from pathlib import Path


async def run(args):
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))

    from rag_eval.single_jump.parser import parse_qa_file
    from rag_eval.single_jump.mapper import FileMapper
    from rag_eval.single_jump.tester import RecallTester
    from rag_eval.single_jump.quality import check_recall_quality
    from rag_eval.single_jump.report import build_report

    # ── Step 1: Parse the MD file ─────────────────────────────────
    print(f"Parsing QA file: {args.qa_file}")
    sections = parse_qa_file(args.qa_file)
    total_qa = sum(len(s.qa_pairs) for s in sections)
    print(f" {len(sections)} sections, {total_qa} QA pairs")

    # Limit the number of questions (for debugging)
    if args.max_questions and args.max_questions > 0:
        count = 0
        trimmed = []
        for s in sections:
            if count >= args.max_questions:
                break
            keep = s.qa_pairs[:max(0, args.max_questions - count)]
            if keep:
                s.qa_pairs = keep
                trimmed.append(s)
                count += len(keep)
        sections = trimmed
        total_qa = sum(len(s.qa_pairs) for s in sections)
        print(f" Limited to {total_qa} QA pairs (--max-questions {args.max_questions})")

    # ── Step 2: Map sections to files ─────────────────────────────
    print("\nFetching the knowledge base file list...")
    mapper = FileMapper(
        env_url=args.env_url,
        org_id=args.org_id,
        d_user_id=args.user_id,
    )
    file_count = await mapper.load_files()
    print(f" {file_count} files")

    file_map: dict[str, dict | None] = {}
    unmatched = []
    for s in sections:
        if s.section_path not in file_map:
            result = mapper.map_section_to_file(s.section_path)
            file_map[s.section_path] = result
            if not result:
                unmatched.append(s.section_path)

    matched = len(file_map) - len(unmatched)
    print(f" Mapped: {matched}/{len(file_map)} sections")
    if unmatched:
        print(f" Unmatched sections ({len(unmatched)}): {unmatched[:5]}{'...' if len(unmatched) > 5 else ''}")

    # ── Step 3: Run the recall test ───────────────────────────────
    print(f"\nStarting recall test (top_k={args.top_k}, concurrency={args.concurrency}, cross_chunk={args.cross_chunk})...")
    tester = RecallTester(
        env_url=args.env_url,
        org_id=args.org_id,
        d_user_id=args.user_id,
    )

    finished = 0
    def progress(done, total):
        nonlocal finished
        finished = done
        print(f"\r Progress: {done}/{total}", end="", flush=True)

    results = await tester.run(
        sections=sections,
        file_map=file_map,
        top_k=args.top_k,
        concurrency=args.concurrency,
        cross_chunk=args.cross_chunk,
        progress_cb=progress,
    )
    print(f"\r Done: {len(results)} results")

    # ── Step 4: Quality check ─────────────────────────────────────
    quality_info = check_recall_quality(results)

    # ── Step 5: Build the report ──────────────────────────────────
    report = build_report(
        results=results,
        env_url=args.env_url,
        org_id=args.org_id,
        qa_file=args.qa_file,
        top_k=args.top_k,
        cross_chunk=args.cross_chunk,
        quality_info=quality_info,
    )

    print("\n" + report.summary_text())

    report.save(args.output)
    print(f"\nReport saved: {args.output}")


def main():
    parser = argparse.ArgumentParser(
        prog="single-jump-eval",
        description="Automated single-jump knowledge base recall testing",
    )
    parser.add_argument("--env-url", required=True, help="dagent environment URL, e.g. https://cloud-dev.d-robotics.cc")
    parser.add_argument("--org-id", required=True, help="Organization ID")
    parser.add_argument("--user-id", default="test", help="d-user-id request header (default: test)")
    parser.add_argument("--qa-file", required=True, help="Path to the QA set MD file")
    parser.add_argument("--top-k", type=int, default=5, help="Number of results to recall (default: 5)")
    parser.add_argument("--concurrency", type=int, default=5, help="Concurrency (default: 5)")
    parser.add_argument("--cross-chunk", action="store_true", help="Cross-chunk mode (do not restrict by file_id)")
    parser.add_argument("--max-questions", type=int, default=0, help="Limit the number of test questions (0 = no limit; for debugging)")
    parser.add_argument("--output", default="single_jump_report.json", help="Output report path")

    args = parser.parse_args()
    asyncio.run(run(args))


if __name__ == "__main__":
    main()
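The `--max-questions` trimming in `run` can be read in isolation. The following is a minimal sketch of the same idea, assuming sections are plain lists of question strings; `trim_sections` is a hypothetical helper for illustration and is not part of the commit.

```python
def trim_sections(sections: list[list[str]], max_questions: int) -> list[list[str]]:
    """Keep at most max_questions questions in total, preserving section order."""
    if max_questions <= 0:  # 0 means "no limit", mirroring the CLI default
        return sections
    trimmed: list[list[str]] = []
    count = 0
    for questions in sections:
        if count >= max_questions:
            break
        # Take only as many questions as the remaining budget allows.
        keep = questions[: max_questions - count]
        if keep:  # sections left with no questions are dropped entirely
            trimmed.append(keep)
            count += len(keep)
    return trimmed
```

Sections that end up empty are dropped rather than kept as empty entries, which is why the CLI recomputes the total after trimming.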
111
sdk/rag_eval/single_jump/mapper.py
Normal file
@ -0,0 +1,111 @@
"""
Map a doc_name from the MD file to a file_id in the dagent knowledge base.

Mapping rules (highest priority first):
  1. Exact match: file_name == doc_name
  2. Containment match: file_name contains doc_name
  3. Fuzzy match: keywords of doc_name appear in file_name
"""
import aiohttp
from difflib import SequenceMatcher


class FileMapper:
    def __init__(self, env_url: str, org_id: str, d_user_id: str = "test"):
        self.env_url = env_url.rstrip("/")
        self.org_id = org_id
        self.headers = {
            "Content-Type": "application/json",
            "d-user-id": d_user_id,
            "org-id": org_id,
        }
        self.files: list[dict] = []

    async def load_files(self) -> int:
        """Fetch the full list of knowledge base files, page by page."""
        url = f"{self.env_url}/dagent/knowledge/file/page"
        all_files = []
        page = 1
        page_size = 100

        async with aiohttp.ClientSession(headers=self.headers) as session:
            while True:
                payload = {
                    "current": page,
                    "page_size": page_size,
                    "org_id": self.org_id,
                }
                async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=30)) as resp:
                    resp.raise_for_status()
                    data = await resp.json()
                # Tolerate a null "data" or "list" field in the response
                file_list = (data.get("data") or {}).get("list") or []
                if not file_list:
                    break
                all_files.extend(file_list)
                if len(file_list) < page_size:
                    break  # a short page is the last page
                page += 1

        self.files = all_files
        return len(all_files)

    def map_section_to_file(self, section_path: str) -> dict | None:
        """
        Map a section_path (e.g. "linux_development / bsp_develop") to a file_id.

        File name format: linux_development/bsp_develop.md
        section_path format: linux_development / bsp_develop

        Matching rules (highest priority first):
        1. Exact path match: section_path with spaces removed equals the file name (without extension)
        2. Path containment match: the file name (without extension) contains the normalized
           section_path, or the normalized section_path contains the file name
        3. Exact last-segment match: last segment of the file name (without extension) == last segment of section_path
        4. Fuzzy match
        """
        if not self.files:
            return None

        # Normalize section_path: strip spaces around each segment and lowercase
        # "linux_development / bsp_develop" -> "linux_development/bsp_develop"
        normalized = "/".join(p.strip() for p in section_path.split("/")).lower()
        doc_name = section_path.split("/")[-1].strip().lower()

        # 1. Exact path match (without extension)
        for f in self.files:
            fname_base = f["file_name"].rsplit(".", 1)[0].lower()
            if fname_base == normalized:
                return {"file_id": f["id"], "file_name": f["file_name"], "match_type": "exact"}

        # 2. Path containment match (either direction)
        for f in self.files:
            fname_base = f["file_name"].rsplit(".", 1)[0].lower()
            if normalized in fname_base or fname_base in normalized:
                return {"file_id": f["id"], "file_name": f["file_name"], "match_type": "path_contains"}

        # 3. Exact last-segment match
        for f in self.files:
            fname_base = f["file_name"].rsplit(".", 1)[0].lower()
            fname_last = fname_base.split("/")[-1]
            if fname_last == doc_name:
                return {"file_id": f["id"], "file_name": f["file_name"], "match_type": "basename"}

        # 4. Fuzzy match (similarity > 0.6)
        best_match = None
        best_score = 0.6
        for f in self.files:
            fname_base = f["file_name"].rsplit(".", 1)[0].lower()
            score = SequenceMatcher(None, normalized, fname_base).ratio()
            if score > best_score:
                best_score = score
                best_match = {
                    "file_id": f["id"],
                    "file_name": f["file_name"],
                    "match_type": "fuzzy",
                    "score": round(score, 3),
                }

        return best_match

    def map_doc_to_file(self, doc_name: str) -> dict | None:
        """Backward-compatible alias; delegates to map_section_to_file."""
        return self.map_section_to_file(doc_name)
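The four-stage matching cascade in `map_section_to_file` can be tried without a running dagent instance. The sketch below reimplements the same priority order over a plain list of file names; `match_section` and its `min_score` parameter are hypothetical names introduced for illustration, and it returns a `(file_name, match_type)` tuple rather than the dict used in the commit.

```python
from difflib import SequenceMatcher


def match_section(section_path: str, file_names: list[str], min_score: float = 0.6):
    """Return (file_name, match_type) for the best match, or None."""
    normalized = "/".join(p.strip() for p in section_path.split("/")).lower()
    last_segment = normalized.rsplit("/", 1)[-1]
    # Pair each file name with its lowercase form without the extension.
    bases = [(name, name.rsplit(".", 1)[0].lower()) for name in file_names]

    for name, base in bases:  # 1. exact path match
        if base == normalized:
            return name, "exact"
    for name, base in bases:  # 2. containment in either direction
        if normalized in base or base in normalized:
            return name, "path_contains"
    for name, base in bases:  # 3. last path segment matches
        if base.rsplit("/", 1)[-1] == last_segment:
            return name, "basename"

    best, best_score = None, min_score  # 4. fuzzy match above the threshold
    for name, base in bases:
        score = SequenceMatcher(None, normalized, base).ratio()
        if score > best_score:
            best, best_score = name, score
    return (best, "fuzzy") if best else None
```

Running each stage over the whole file list before falling through to the next keeps a weak match on an early file from shadowing an exact match on a later one.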
133
sdk/rag_eval/single_jump/parser.py
Normal file
@ -0,0 +1,133 @@
"""
Parse the EVB knowledge base QA set MD file and extract structured QA pairs.

File format:
    # 第N章 <chapter name>
    ## chapter_path / doc_name    ← knowledge base file identifier
    # <document title>
    > <note that the QA pairs were generated automatically by an LLM>
    ---
    ## Q1: <question>
    **A1:** <answer>
"""
import re
from dataclasses import dataclass, field


@dataclass
class QAPair:
    qid: str  # Q1, Q2 ...
    question: str
    answer: str
    expected_chunk_id: str | None = None  # ID of the chunk expected to be hit, parsed from MD metadata


@dataclass
class Section:
    chapter: str  # e.g. 第一章 前言
    section_path: str  # e.g. preface / overview
    doc_name: str  # e.g. overview (last segment, used to match the file name)
    doc_title: str  # e.g. 1. 前言
    qa_pairs: list[QAPair] = field(default_factory=list)
    raw_chunk_headers: str | None = None  # original chunk headers (parsed from metadata)


def parse_qa_file(filepath: str) -> list[Section]:
    with open(filepath, encoding="utf-8") as f:
        content = f.read()
    return parse_qa_file_text(content)


def parse_qa_file_text(content: str) -> list[Section]:
    """Parse QA pairs from text content (used for API uploads)."""
    sections: list[Section] = []
    current_chapter = ""
    current_section: Section | None = None
    current_q: str | None = None
    current_q_text: str | None = None
    current_q_chunk_id: str | None = None  # chunk_id expected for the current QA pair
    answer_lines: list[str] = []

    def _flush_qa():
        nonlocal current_q, current_q_text, answer_lines, current_q_chunk_id
        if current_section and current_q and current_q_text:
            ans = " ".join(answer_lines).strip()
            # Strip the **A1:** prefix
            ans = re.sub(r"^\*\*A\d+:\*\*\s*", "", ans)
            current_section.qa_pairs.append(QAPair(
                qid=current_q,
                question=current_q_text,
                answer=ans,
                expected_chunk_id=current_q_chunk_id,
            ))
        current_q = None
        current_q_text = None
        answer_lines = []
        current_q_chunk_id = None

    for line in content.splitlines():
        # Chapter heading: # 第N章 ...
        m = re.match(r"^# (第.+章.+)$", line)
        if m:
            current_chapter = m.group(1).strip()
            continue

        # Knowledge base identifier: ## chapter / doc_name (excluding question lines such as "## Q1: ...")
        # Commas, backticks, parentheses, question marks and other symbols common in chunk headers
        # are allowed, so Chinese paths do not have to be sanitized to underscores before parsing
        m = re.match(r"^## (?!Q\d+:)(.+)$", line)
        if m:
            _flush_qa()
            if current_section:
                sections.append(current_section)
            path = m.group(1).strip()
            parts = [p.strip() for p in path.split("/")]
            doc_name = parts[-1] if parts else path
            current_section = Section(
                chapter=current_chapter,
                section_path=path,
                doc_name=doc_name,
                doc_title="",
            )
            continue

        # Metadata line: > 原始切片标题: xxx (original chunk headers)
        m = re.match(r"^> 原始切片标题: (.+)$", line)
        if m and current_section:
            current_section.raw_chunk_headers = m.group(1).strip()
            continue

        # Document title: # N. title
        m = re.match(r"^# (\d[\d\.]*\s+.+)$", line)
        if m and current_section and not current_section.doc_title:
            current_section.doc_title = m.group(1).strip()
            continue

        # Question line: ## Q1: question text
        m = re.match(r"^## (Q\d+):\s*(.+)$", line)
        if m:
            _flush_qa()
            current_q = m.group(1)
            current_q_text = m.group(2).strip()
            continue

        # chunk_id metadata line: > chunk_id: xxx
        m = re.match(r"^> chunk_id:\s*(\S+)$", line)
        if m and current_q:
            current_q_chunk_id = m.group(1).strip()
            continue

        # Answer line: **A1:** answer text
        if current_q and re.match(r"^\*\*A\d+:\*\*", line):
            ans = re.sub(r"^\*\*A\d+:\*\*\s*", "", line).strip()
            answer_lines = [ans]
            continue

        # Answer continuation line (non-empty, not a separator, not a new heading)
        if current_q and answer_lines is not None and line.strip() and not line.startswith("#") and line != "---":
            answer_lines.append(line.strip())

    _flush_qa()
    if current_section:
        sections.append(current_section)

    return sections
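The parser distinguishes section headings from question headings with a negative lookahead, since both begin with `## `. The sketch below isolates that line classification using the same three regular expressions; `classify` is a hypothetical helper for illustration and does not exist in the commit.

```python
import re

SECTION_RE = re.compile(r"^## (?!Q\d+:)(.+)$")    # "## path / doc_name", but not "## Q1: ..."
QUESTION_RE = re.compile(r"^## (Q\d+):\s*(.+)$")  # "## Q1: question text"
ANSWER_RE = re.compile(r"^\*\*A\d+:\*\*\s*(.*)$")  # "**A1:** answer text"


def classify(line: str):
    """Return a (kind, payload) tuple describing one line of the QA file."""
    m = SECTION_RE.match(line)
    if m:
        return "section", m.group(1).strip()
    m = QUESTION_RE.match(line)
    if m:
        return "question", (m.group(1), m.group(2).strip())
    m = ANSWER_RE.match(line)
    if m:
        return "answer", m.group(1).strip()
    return "other", line
```

Because the lookahead rejects `Q<digits>:` right after `## `, the section pattern can safely be tested first without swallowing question lines.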
79
sdk/rag_eval/single_jump/quality.py
Normal file
@ -0,0 +1,79 @@
"""
Quality checker for test samples.
"""
from .parser import QAPair, Section
from .tester import RecallResult


def check_qa_quality(qa: QAPair) -> dict:
    """
    Check the quality of a single QA pair.
    Returns: {"is_valid": bool, "issues": [str]}
    """
    issues = []

    # Question completeness
    if len(qa.question) < 5:
        issues.append("question too short")
    # Accept both the ASCII and the full-width question mark
    if not qa.question.endswith(("?", "?")):
        issues.append("question does not end with a question mark")

    # Answer completeness
    if len(qa.answer) < 10:
        issues.append("answer too short")

    # QA consistency (the answer should contain keywords from the question)
    q_words = set(qa.question.replace("?", "").replace("?", "").split())
    a_words = set(qa.answer.split())
    if len(q_words & a_words) == 0:
        issues.append("answer shares no keywords with the question")

    return {
        "is_valid": len(issues) == 0,
        "issues": issues,
    }


def check_recall_quality(results: list[RecallResult]) -> dict:
    """
    Validate sample quality in reverse, using the recall results.
    Returns: {"low_quality": [RecallResult], "suspicious": [RecallResult]}
    """
    low_quality = []
    suspicious = []

    for r in results:
        if r.error or r.is_empty:
            continue

        # Very low recall similarity (< 0.5); compare against None so that 0.0 is not skipped
        if r.best_cosine_sim is not None and r.best_cosine_sim < 0.5:
            low_quality.append(r)

        # Retrieved files do not include the expected file (cross-file recall)
        if r.file_id and r.file_id not in r.retrieved_file_ids:
            suspicious.append(r)

    return {
        "low_quality": low_quality,
        "suspicious": suspicious,
    }


def detect_duplicates(sections: list[Section], threshold: float = 0.9) -> list[tuple[str, str]]:
    """
    Detect duplicate questions (simple string-similarity comparison).
    Returns: [("path1/qid1", "path2/qid2"), ...]
    """
    from difflib import SequenceMatcher

    all_qa = [(s.section_path, qa) for s in sections for qa in s.qa_pairs]
    duplicates = []

    for i, (path1, qa1) in enumerate(all_qa):
        for path2, qa2 in all_qa[i + 1:]:
            sim = SequenceMatcher(None, qa1.question, qa2.question).ratio()
            if sim > threshold:
                duplicates.append((f"{path1}/{qa1.qid}", f"{path2}/{qa2.qid}"))

    return duplicates
206
sdk/rag_eval/single_jump/report.py
Normal file
@ -0,0 +1,206 @@
"""
Report generator: aggregate recall test results into a structured report.
"""
import json
from dataclasses import dataclass, field
from datetime import datetime
from .tester import RecallResult


@dataclass
class SectionStats:
    section_path: str
    doc_name: str
    file_id: str | None
    match_type: str | None
    total: int = 0
    recalled: int = 0  # number of questions with recall results
    empty: int = 0  # number of empty recalls
    errors: int = 0
    avg_cosine_sim: float | None = None
    avg_latency_ms: float | None = None


@dataclass
class SingleJumpReport:
    env_url: str
    org_id: str
    qa_file: str
    top_k: int
    cross_chunk: bool
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())

    total_questions: int = 0
    total_sections: int = 0
    matched_sections: int = 0  # number of sections mapped to a file_id
    unmatched_sections: int = 0
    recalled_questions: int = 0  # number of questions with recall results
    empty_questions: int = 0
    error_questions: int = 0

    recall_rate: float | None = None  # recalled / total
    empty_rate: float | None = None
    section_match_rate: float | None = None
    avg_cosine_sim: float | None = None
    avg_latency_ms: float | None = None

    section_stats: list[SectionStats] = field(default_factory=list)
    low_quality_results: list[dict] = field(default_factory=list)
    suspicious_results: list[dict] = field(default_factory=list)
    unmatched_section_list: list[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        d = {
            "env_url": self.env_url,
            "org_id": self.org_id,
            "qa_file": self.qa_file,
            "top_k": self.top_k,
            "cross_chunk": self.cross_chunk,
            "created_at": self.created_at,
            "summary": {
                "total_questions": self.total_questions,
                "total_sections": self.total_sections,
                "matched_sections": self.matched_sections,
                "unmatched_sections": self.unmatched_sections,
                "recalled_questions": self.recalled_questions,
                "empty_questions": self.empty_questions,
                "error_questions": self.error_questions,
                "recall_rate": self.recall_rate,
                "empty_rate": self.empty_rate,
                "section_match_rate": self.section_match_rate,
                "avg_cosine_sim": self.avg_cosine_sim,
                "avg_latency_ms": self.avg_latency_ms,
            },
            "section_stats": [
                {
                    "section_path": s.section_path,
                    "doc_name": s.doc_name,
                    "file_id": s.file_id,
                    "match_type": s.match_type,
                    "total": s.total,
                    "recalled": s.recalled,
                    "empty": s.empty,
                    "errors": s.errors,
                    "avg_cosine_sim": s.avg_cosine_sim,
                    "avg_latency_ms": s.avg_latency_ms,
                }
                for s in self.section_stats
            ],
            "unmatched_sections": self.unmatched_section_list,
            "low_quality_count": len(self.low_quality_results),
            "suspicious_count": len(self.suspicious_results),
        }
        return d

    def save(self, path: str):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)

    def summary_text(self) -> str:
        lines = [
            "=" * 60,
            " Single-Jump Recall Test Report",
            "=" * 60,
            f" Environment URL      : {self.env_url}",
            f" Total questions      : {self.total_questions}",
            f" Total sections       : {self.total_sections}",
            f" Section match rate   : {self.section_match_rate:.1%}" if self.section_match_rate is not None else " Section match rate   : N/A",
            f" Recall rate          : {self.recall_rate:.1%}" if self.recall_rate is not None else " Recall rate          : N/A",
            f" Empty recall rate    : {self.empty_rate:.1%}" if self.empty_rate is not None else " Empty recall rate    : N/A",
            f" Avg cosine similarity: {self.avg_cosine_sim:.4f}" if self.avg_cosine_sim is not None else " Avg cosine similarity: N/A",
            f" Avg latency          : {self.avg_latency_ms:.0f}ms" if self.avg_latency_ms is not None else " Avg latency          : N/A",
            f" Low-quality samples  : {len(self.low_quality_results)}",
            f" Suspicious samples   : {len(self.suspicious_results)}",
            "=" * 60,
        ]
        if self.unmatched_section_list:
            lines.append(f" Unmatched sections ({len(self.unmatched_section_list)}):")
            for s in self.unmatched_section_list[:10]:
                lines.append(f"   - {s}")
            if len(self.unmatched_section_list) > 10:
                lines.append(f"   ... {len(self.unmatched_section_list)} in total")
        return "\n".join(lines)


def build_report(
    results: list[RecallResult],
    env_url: str,
    org_id: str,
    qa_file: str,
    top_k: int,
    cross_chunk: bool,
    quality_info: dict | None = None,
) -> SingleJumpReport:
    report = SingleJumpReport(
        env_url=env_url,
        org_id=org_id,
        qa_file=qa_file,
        top_k=top_k,
        cross_chunk=cross_chunk,
    )

    # Group by section
    section_map: dict[str, SectionStats] = {}
    for r in results:
        key = r.section_path
        if key not in section_map:
            section_map[key] = SectionStats(
                section_path=r.section_path,
                doc_name=r.doc_name,
                file_id=r.file_id,
                match_type=r.match_type,
            )
        s = section_map[key]
        s.total += 1
        if r.error:
            s.errors += 1
        elif r.is_empty:
            s.empty += 1
        else:
            s.recalled += 1

    # Compute per-section averages
    for key, s in section_map.items():
        sec_results = [r for r in results if r.section_path == key and not r.error and not r.is_empty]
        sims = [r.best_cosine_sim for r in sec_results if r.best_cosine_sim is not None]
        lats = [r.latency_ms for r in sec_results if r.latency_ms]
        s.avg_cosine_sim = round(sum(sims) / len(sims), 4) if sims else None
        s.avg_latency_ms = round(sum(lats) / len(lats), 1) if lats else None

    report.section_stats = list(section_map.values())
    report.total_sections = len(section_map)
    report.matched_sections = sum(1 for s in report.section_stats if s.file_id)
    report.unmatched_sections = report.total_sections - report.matched_sections
    report.unmatched_section_list = [
        s.section_path for s in report.section_stats if not s.file_id
    ]

    # Global statistics
    report.total_questions = len(results)
    report.recalled_questions = sum(1 for r in results if not r.error and not r.is_empty)
    report.empty_questions = sum(1 for r in results if not r.error and r.is_empty)
    report.error_questions = sum(1 for r in results if r.error)

    if report.total_questions > 0:
        report.recall_rate = round(report.recalled_questions / report.total_questions, 4)
        report.empty_rate = round(report.empty_questions / report.total_questions, 4)
    if report.total_sections > 0:
        report.section_match_rate = round(report.matched_sections / report.total_sections, 4)

    all_sims = [r.best_cosine_sim for r in results if r.best_cosine_sim is not None]
    all_lats = [r.latency_ms for r in results if r.latency_ms]
    report.avg_cosine_sim = round(sum(all_sims) / len(all_sims), 4) if all_sims else None
    report.avg_latency_ms = round(sum(all_lats) / len(all_lats), 1) if all_lats else None

    if quality_info:
        report.low_quality_results = [
            {"section": r.section_path, "qid": r.qid, "question": r.question, "sim": r.best_cosine_sim}
            for r in quality_info.get("low_quality", [])
        ]
        report.suspicious_results = [
            {"section": r.section_path, "qid": r.qid, "question": r.question,
             "expected_file": r.file_id, "retrieved_files": r.retrieved_file_ids}
            for r in quality_info.get("suspicious", [])
        ]

    return report
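The global rates in `build_report` use the total number of questions as the denominator, so errored requests lower the recall rate just as empty recalls do. The sketch below shows that arithmetic on a simplified input, assuming each result is reduced to one of the strings "ok", "empty", or "error"; `summarize` is a hypothetical helper and is not part of the commit.

```python
def summarize(outcomes: list[str]) -> dict:
    """Compute recall and empty rates from per-question outcomes."""
    total = len(outcomes)
    recalled = outcomes.count("ok")
    empty = outcomes.count("empty")
    errors = outcomes.count("error")
    return {
        "total": total,
        "recalled": recalled,
        "empty": empty,
        "errors": errors,
        # Rates stay None when there are no questions, avoiding division by zero.
        "recall_rate": round(recalled / total, 4) if total else None,
        "empty_rate": round(empty / total, 4) if total else None,
    }
```

With two successful recalls, one empty recall, and one error, the recall rate is 0.5 and the empty rate is 0.25.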
358
sdk/rag_eval/single_jump/tester.py
Normal file
@ -0,0 +1,358 @@
"""
Recall test executor: call the dagent semantic recall API for each QA pair and record the results.
"""
import asyncio
import json
import sys
import time
from dataclasses import dataclass, field
import aiohttp

# Fix Windows GBK encoding issue (skipped when a stream has been replaced
# by an object without reconfigure(), e.g. during output capture)
for _stream in (sys.stdout, sys.stderr):
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8", errors="replace")

from .parser import Section, QAPair


@dataclass
class RecallResult:
    section_path: str
    doc_name: str
    file_id: str | None
    match_type: str | None  # exact / path_contains / basename / fuzzy, or None if unmatched
    qid: str
    question: str
    reference_answer: str
    top_k: int  # top_k value used for hit judgement
    hit_top_k: int  # top_k threshold for deciding whether the chunk is hit (may differ from the recall top_k)
    retrieved: list[dict] = field(default_factory=list)  # retrieved chunks (all of them, not truncated)
    latency_ms: int = 0
    error: str | None = None
    expected_chunk_id: str | None = None  # ID of the chunk expected to be hit
    raw_chunk_headers: str | None = None  # original chunk headers (parsed from metadata)

    # Computed properties
    @property
    def best_cosine_sim(self) -> float | None:
        sims = [1.0 - r.get("cosine_distance_1", 1.0) for r in self.retrieved if r.get("cosine_distance_1") is not None]
        return round(max(sims), 4) if sims else None

    @property
    def avg_cosine_sim(self) -> float | None:
        sims = [1.0 - r.get("cosine_distance_1", 1.0) for r in self.retrieved if r.get("cosine_distance_1") is not None]
        return round(sum(sims) / len(sims), 4) if sims else None

    @property
    def is_empty(self) -> bool:
        return len(self.retrieved) == 0

    @property
    def retrieved_file_ids(self) -> list[str]:
        return list({r.get("file_id", "") for r in self.retrieved if r.get("file_id")})

    @property
    def retrieved_chunk_ids(self) -> list[str]:
        """Return the IDs of all retrieved chunks."""
        chunk_ids = []
        for r in self.retrieved:
            chunk_id = r.get("knowledge_md_header_split_id") or r.get("id") or r.get("chunk_id")
            if chunk_id:
                chunk_ids.append(chunk_id)
        return chunk_ids

    @property
    def is_chunk_hit(self) -> bool:
        """Check whether the expected chunk is among the first hit_top_k retrieved results."""
        if not self.expected_chunk_id:
            return False
        return self.expected_chunk_id in self.retrieved_chunk_ids[:self.hit_top_k]

    @property
    def chunk_hit_rank(self) -> int | None:
        """Return the 1-based rank of the expected chunk in the results, or None if it is not hit.

        Only the first hit_top_k results are searched; anything beyond that counts as a miss.
        """
        if not self.expected_chunk_id:
            return None
        try:
            idx = self.retrieved_chunk_ids[:self.hit_top_k].index(self.expected_chunk_id)
            return idx + 1
        except ValueError:
            return None

    @property
    def is_file_hit(self) -> bool:
        """Check whether the expected file is among the first hit_top_k retrieved results."""
        if not self.file_id:
            return False
        # Collect the file_ids of the first hit_top_k results
        top_file_ids = []
        for r in self.retrieved[:self.hit_top_k]:
            fid = r.get("file_id")
            if fid:
                top_file_ids.append(fid)
        return self.file_id in top_file_ids


class RecallTester:
|
||||||
|
def __init__(self, env_url: str, org_id: str, d_user_id: str = "test"):
|
||||||
|
self.env_url = env_url.rstrip("/")
|
||||||
|
self.org_id = org_id
|
||||||
|
self.headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"d-user-id": d_user_id,
|
||||||
|
"org-id": org_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
    async def _recall_one(
        self,
        session: aiohttp.ClientSession,
        question: str,
        file_id_list: list[str] | None,
        recall_top_k: int,  # top_k sent to the API; can be set large to fetch all results
        agent_id: str = "",  # agent ID used for the recall test
    ) -> tuple[list[dict], int]:
        # If agent_id is provided, recall through the agent chat API
        if agent_id:
            return await self._recall_via_agent(session, question, agent_id, recall_top_k)

        # Otherwise call the knowledge-base search API directly
        url = f"{self.env_url}/dagent/knowledge/hub/semantic_search_knowledge/detail"
        payload: dict = {
            "query": question,
            "org_id": self.org_id,
            "top_k": recall_top_k,
        }
        if file_id_list:
            payload["file_id_list"] = file_id_list

        start = time.monotonic()
        # 60-second timeout per attempt, with retries on timeout
        max_retries = 3
        last_error = None

        for attempt in range(max_retries):
            try:
                async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=60)) as resp:
                    resp.raise_for_status()
                    data = await resp.json()
                    break  # success: leave the retry loop
            except asyncio.TimeoutError as e:
                last_error = e
                print(f"[DEBUG] Recall timeout (attempt {attempt+1}/{max_retries}) for: {question[:50]}...")
                if attempt < max_retries - 1:
                    await asyncio.sleep(2 ** attempt)  # exponential backoff: 1s, then 2s
                else:
                    raise  # final attempt failed: re-raise
            except Exception as e:
                raise  # re-raise any other exception immediately

        latency_ms = int((time.monotonic() - start) * 1000)

        # Check the business error code returned by the API
        code = data.get("code")
        if code is not None and code != 0:
            msg = data.get("msg", "Unknown error")
            raise Exception(f"API error: code={code}, msg={msg}")

        result_data = data.get("data", {}) or {}

        # Debug: print diagnostics when the result is empty
        if not result_data or (not result_data.get("standard_answer_results") and not result_data.get("related_knowledge_rerank_results_top")):
            print(f"[DEBUG] Empty/No results for question: {question[:50]}...")
            print(f"[DEBUG] Response code: {data.get('code')}, msg: {data.get('msg')}")
            print(f"[DEBUG] org_id used: {self.org_id}")
            print(f"[DEBUG] Request payload: {payload}")
            print(f"[DEBUG] Response data keys: {list(data.keys())}")
            if result_data:
                print(f"[DEBUG] result_data keys: {list(result_data.keys())}")

        standard = result_data.get("standard_answer_results") or []
        rerank_top = result_data.get("related_knowledge_rerank_results_top") or []
        all_items = standard + rerank_top

        # Debug: log when nothing was recalled
        if len(all_items) == 0:
            print(f"[DEBUG] No recall results for: {question[:50]}... (standard={len(standard)}, rerank={len(rerank_top)})")

        return all_items, latency_ms

    async def _recall_via_agent(
        self,
        session: aiohttp.ClientSession,
        question: str,
        agent_id: str,
        recall_top_k: int,
    ) -> tuple[list[dict], int]:
        """Fetch recall results through the Agent chat SSE endpoint.

        Parsing strategy:
        - Read the SSE stream line by line (the server separates events with a single newline, not a blank line)
        - Each EVENT whose event_name == "TOOL_END" carries a batch of chunks in event_data.items
        - The agent may call tools over several rounds; accumulate on every TOOL_END and dedupe by (file_id, paragraph_chunk_id)
        - Keep first-appearance order (as a pseudo rank) for hit-rank statistics
        """
        import uuid
        payload = {
            "chat_id": uuid.uuid4().hex,
            "task": question,
            "agent_id": agent_id,
            "llm_type": "deepseek_v3",
        }

        start = time.monotonic()
        items: list[dict] = []
        seen: set[tuple[str, str]] = set()

        try:
            async with session.post(
                f"{self.env_url}/dagent/agent/chat",
                json=payload,
                headers={"Accept": "text/event-stream"},
                timeout=aiohttp.ClientTimeout(total=300),
            ) as resp:
                resp.raise_for_status()
                line_buf = ""
                async for raw in resp.content:
                    line_buf += raw.decode("utf-8", errors="replace")
                    while "\n" in line_buf:
                        line, line_buf = line_buf.split("\n", 1)
                        line = line.rstrip("\r")
                        if not line.startswith("data:"):
                            continue
                        data_str = line[5:].strip()
                        if not data_str or data_str == "[DONE]":
                            continue
                        try:
                            chunk = json.loads(data_str)
                        except json.JSONDecodeError:
                            continue
                        if chunk.get("message_type") != "EVENT" or chunk.get("is_chunk_data"):
                            continue
                        event_data_raw = chunk.get("data")
                        if isinstance(event_data_raw, str):
                            try:
                                event_data = json.loads(event_data_raw)
                            except json.JSONDecodeError:
                                continue
                        else:
                            event_data = event_data_raw
                        if not isinstance(event_data, dict):
                            continue
                        if event_data.get("event_name") != "TOOL_END":
                            continue
                        tool_event_data = event_data.get("event_data")
                        if not isinstance(tool_event_data, dict):
                            continue
                        reference_items = tool_event_data.get("items") or []
                        if not isinstance(reference_items, list):
                            continue
                        for item in reference_items:
                            if not isinstance(item, dict):
                                continue
                            file_id = str(item.get("file_id") or "")
                            chunk_id = str(
                                item.get("paragraph_chunk_id")
                                or item.get("knowledge_md_header_split_id")
                                or ""
                            )
                            # Skip external-link entries that carry neither file_id nor chunk_id (only file_name + url)
                            if not file_id and not chunk_id:
                                continue
                            key = (file_id, chunk_id)
                            if key in seen:
                                continue
                            seen.add(key)
                            items.append({
                                "file_id": file_id,
                                "file_name": "",
                                "headers": str(item.get("headers") or ""),
                                "content": item.get("active_paragraph_context")
                                           or item.get("active_context") or "",
                                "knowledge_md_header_split_id": chunk_id,
                                "id": chunk_id,
                                "paragraph_md5": str(item.get("paragraph_md5") or ""),
                                "cosine_distance_1": None,
                            })
        except Exception as e:
            print(f"[DEBUG] Agent recall error: {e}")

        latency_ms = int((time.monotonic() - start) * 1000)
        return items[:recall_top_k], latency_ms

    async def run(
        self,
        sections: list[Section],
        file_map: dict[str, dict | None],
        top_k: int = 5,  # top_k threshold used to decide a hit
        recall_top_k: int = 100,  # top_k sent to the API; defaults to 100 to fetch more results
        concurrency: int = 20,  # default concurrency raised to 20
        cross_chunk: bool = False,  # kept for backward compatibility; no longer controls the search scope
        result_cb=None,
        progress_cb=None,  # kept for backward compatibility
        chunk_map: dict[str, str] | None = None,  # question -> expected_chunk_id
        agent_id: str = "",  # agent ID used for the recall test
    ) -> list[RecallResult]:
        results: list[RecallResult] = []
        sem = asyncio.Semaphore(concurrency)
        total = sum(len(s.qa_pairs) for s in sections)
        done = 0

        async with aiohttp.ClientSession(headers=self.headers) as session:
            async def _test_one(section: Section, qa: QAPair) -> RecallResult:
                nonlocal done
                mapping = file_map.get(section.section_path)
                file_id = mapping["file_id"] if mapping else None
                match_type = mapping["match_type"] if mapping else "unmatched"

                # Prefer the chunk_id already injected on the QAPair; fall back to chunk_map
                expected_chunk_id = qa.expected_chunk_id or (
                    chunk_map.get(qa.question) if chunk_map else None
                )

                result = RecallResult(
                    section_path=section.section_path,
                    doc_name=section.doc_name,
                    file_id=file_id,
                    match_type=match_type,
                    qid=qa.qid,
                    question=qa.question,
                    reference_answer=qa.answer,
                    top_k=top_k,
                    hit_top_k=top_k,  # threshold used to decide a hit
                    expected_chunk_id=expected_chunk_id,
                    raw_chunk_headers=section.raw_chunk_headers,
                )

                # Always search the whole knowledge base (no file_id_list); chunk hit is the primary metric
                # Use a large recall_top_k to fetch all recall results
                async with sem:
                    try:
                        chunks, latency = await self._recall_one(session, qa.question, None, recall_top_k, agent_id)
                        result.retrieved = chunks
                        result.latency_ms = latency
                        # Debug: log when nothing was recalled
                        if len(chunks) == 0:
                            print(f"[DEBUG] Empty recall for question: {qa.question[:60]}... (section: {section.section_path[:40]}...)")
                    except Exception as e:
                        result.error = str(e)
                        print(f"[DEBUG] Recall error for question: {qa.question[:60]}... Error: {e}")

                done += 1
                if result_cb:
                    await result_cb(result, done, total)
                elif progress_cb and (done % 10 == 0 or done == total):
                    await progress_cb(done, total)
                return result

            tasks = [
                _test_one(section, qa)
                for section in sections
                for qa in section.qa_pairs
            ]
            for coro in asyncio.as_completed(tasks):
                results.append(await coro)

        return results
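The SSE parsing strategy described in the `_recall_via_agent` docstring can be tried offline, without a running agent. The following is a minimal sketch, assuming the event shape that docstring describes; `extract_tool_end_items`, `_event`, and `SAMPLE` are hypothetical names introduced only for illustration and are not part of the SDK.

```python
import json


def extract_tool_end_items(sse_text: str) -> list[dict]:
    """Collect deduplicated items from TOOL_END events in a raw SSE body (illustrative helper)."""
    items: list[dict] = []
    seen: set[tuple[str, str]] = set()
    for line in sse_text.split("\n"):
        line = line.rstrip("\r")
        if not line.startswith("data:"):
            continue
        data_str = line[5:].strip()
        if not data_str or data_str == "[DONE]":
            continue
        try:
            chunk = json.loads(data_str)
        except json.JSONDecodeError:
            continue  # ignore lines that are not valid JSON
        if not isinstance(chunk, dict) or chunk.get("message_type") != "EVENT":
            continue
        event = chunk.get("data")
        if isinstance(event, str):  # the event body may itself be JSON-encoded
            try:
                event = json.loads(event)
            except json.JSONDecodeError:
                continue
        if not isinstance(event, dict) or event.get("event_name") != "TOOL_END":
            continue
        tool_data = event.get("event_data")
        if not isinstance(tool_data, dict):
            continue
        for item in tool_data.get("items") or []:
            if not isinstance(item, dict):
                continue
            key = (str(item.get("file_id") or ""), str(item.get("paragraph_chunk_id") or ""))
            if key == ("", "") or key in seen:
                continue  # skip id-less external links and repeats
            seen.add(key)
            items.append(item)  # first-appearance order doubles as a pseudo rank
    return items


def _event(items: list[dict]) -> str:
    """Build one hypothetical SSE data line carrying a TOOL_END event."""
    body = {"event_name": "TOOL_END", "event_data": {"items": items}}
    return "data: " + json.dumps({"message_type": "EVENT", "data": json.dumps(body)})


# Hypothetical stream: two tool rounds, one malformed line, one external link.
SAMPLE = "\n".join([
    _event([{"file_id": "f1", "paragraph_chunk_id": "c1"},
            {"file_id": "f1", "paragraph_chunk_id": "c2"}]),
    "data: not-json",
    _event([{"file_id": "f1", "paragraph_chunk_id": "c1"},
            {"file_name": "external", "url": "https://example.com"}]),
    "data: [DONE]",
])

recalled = extract_tool_end_items(SAMPLE)
```

Here the second round repeats chunk `c1` and adds an entry without ids, so only the two chunks from the first round survive, in their original order.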
0
server/__init__.py
Normal file
0
server/api/__init__.py
Normal file
88
server/api/config.py
Normal file
@ -0,0 +1,88 @@
import json
import sys
from pathlib import Path
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Optional

# Add the parent directory to sys.path so sibling packages (e.g. models) can be imported
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.db import get_db, _now, _id

router = APIRouter(prefix="/api/config", tags=["配置管理"])  # tag text: "Configuration management"


# ── Platform Config ───────────────────────────────────────────────────────────

class PlatformConfigReq(BaseModel):
    name: str
    type: str = "dagent"
    base_url: str
    org_id: Optional[str] = None
    token: Optional[str] = None


@router.post("/platform")
async def create_platform_config(req: PlatformConfigReq):
    async with get_db() as db:
        row_id = _id()
        await db.execute(
            "INSERT INTO platform_config (id,name,type,base_url,org_id,token,created_at) VALUES (?,?,?,?,?,?,?)",
            (row_id, req.name, req.type, req.base_url, req.org_id, req.token, _now()),
        )
        await db.commit()
    return {"status": 0, "data": {"id": row_id}}


@router.get("/platform")
async def list_platform_configs():
    async with get_db() as db:
        rows = await db.execute_fetchall("SELECT * FROM platform_config ORDER BY created_at DESC")
    return {"status": 0, "data": [dict(r) for r in rows]}


@router.delete("/platform/{config_id}")
async def delete_platform_config(config_id: str):
    async with get_db() as db:
        await db.execute("DELETE FROM platform_config WHERE id=?", (config_id,))
        await db.commit()
    return {"status": 0, "data": True}


# ── Judge Config ──────────────────────────────────────────────────────────────

class JudgeConfigReq(BaseModel):
    name: str
    base_url: str
    api_key: str
    model: str
    embed_base_url: Optional[str] = ""
    embed_api_key: Optional[str] = ""
    embed_model: Optional[str] = "text-embedding-3-small"


@router.post("/judge")
async def create_judge_config(req: JudgeConfigReq):
    async with get_db() as db:
        row_id = _id()
        await db.execute(
            "INSERT INTO judge_config (id,name,base_url,api_key,model,embed_base_url,embed_api_key,embed_model,created_at) VALUES (?,?,?,?,?,?,?,?,?)",
            (row_id, req.name, req.base_url, req.api_key, req.model, req.embed_base_url, req.embed_api_key, req.embed_model, _now()),
        )
        await db.commit()
    return {"status": 0, "data": {"id": row_id}}


@router.get("/judge")
async def list_judge_configs():
    async with get_db() as db:
        rows = await db.execute_fetchall("SELECT id,name,base_url,model,embed_base_url,embed_model,created_at FROM judge_config ORDER BY created_at DESC")
    return {"status": 0, "data": [dict(r) for r in rows]}


@router.delete("/judge/{config_id}")
async def delete_judge_config(config_id: str):
    async with get_db() as db:
        await db.execute("DELETE FROM judge_config WHERE id=?", (config_id,))
        await db.commit()
    return {"status": 0, "data": True}
250
server/api/dataset.py
Normal file
@ -0,0 +1,250 @@
import json
import sys
from pathlib import Path
from fastapi import APIRouter, HTTPException, UploadFile, File
from pydantic import BaseModel
from typing import Optional

# Add the parent directory to sys.path so sibling packages (e.g. models) can be imported
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.db import get_db, _now, _id

router = APIRouter(prefix="/api/dataset", tags=["测试集管理"])  # tag text: "Test set management"


class CreateDatasetReq(BaseModel):
    name: str
    description: Optional[str] = ""


class AddSampleReq(BaseModel):
    dataset_id: str
    question: str
    reference_answer: str
    relevant_chunk_ids: list[str] = []
    knowledge_hub_id: str
    source_file_id: Optional[str] = None
    metadata: dict = {}


class GenerateReq(BaseModel):
    dataset_id: str
    platform_config_id: str
    judge_config_id: str
    knowledge_hub_id: str
    file_id_list: list[str]
    chunk_ids: list[str] = []
    questions_per_chunk: int = 2
    max_chunks: int = 50


@router.post("/create")
async def create_dataset(req: CreateDatasetReq):
    async with get_db() as db:
        row_id = _id()
        await db.execute(
            "INSERT INTO eval_dataset (id,name,description,sample_count,created_at) VALUES (?,?,?,0,?)",
            (row_id, req.name, req.description, _now()),
        )
        await db.commit()
    return {"status": 0, "data": {"id": row_id}}


@router.get("/list")
async def list_datasets():
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM eval_dataset ORDER BY created_at DESC"
        )
    return {"status": 0, "data": [dict(r) for r in rows]}


# ── Fixed-path routes MUST come before /{dataset_id} ────────────────────────

@router.get("/chunks-preview")
async def chunks_preview(platform_config_id: str, knowledge_hub_id: str):
    """Proxy: fetch chunks from the RAG platform for preview/selection"""
    import aiohttp
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM platform_config WHERE id=?", (platform_config_id,)
        )
    if not rows:
        raise HTTPException(status_code=404, detail="Platform config not found")
    cfg = dict(rows[0])
    base_url = cfg["base_url"].rstrip("/")
    # org_id is nullable in platform_config; header values must be strings
    org_id = cfg.get("org_id") or ""

    # Build headers with org-id for dagent API
    headers = {
        "Content-Type": "application/json",
        "org-id": org_id,
        "d-user-id": "test",
    }

    all_chunks = []

    # Use dagent file/page endpoint to get all files, then fetch chunks for each
    try:
        async with aiohttp.ClientSession(headers=headers) as session:
            page = 1
            page_size = 100
            while True:
                async with session.post(
                    f"{base_url}/dagent/knowledge/file/page",
                    json={"current": page, "page_size": page_size, "org_id": org_id},
                    timeout=aiohttp.ClientTimeout(total=15),
                ) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        files = data.get("data", {}).get("list", [])
                        if not files:
                            break
                        # Fetch chunks for each file
                        for f in files:
                            try:
                                async with session.post(
                                    f"{base_url}/dagent/knowledge/chunk/page",
                                    json={"file_id": f["id"], "org_id": org_id, "page_size": 200},
                                    timeout=aiohttp.ClientTimeout(total=15),
                                ) as cr:
                                    if cr.status == 200:
                                        cd = await cr.json()
                                        for c in cd.get("data", {}).get("list", []):
                                            c["file_id"] = c.get("file_id", f["id"])
                                            c["file_name"] = f.get("file_name", "")
                                            c["content"] = c.get("paragraph_context") or c.get("content", "")
                                            all_chunks.append(c)
                            except Exception:
                                continue
                        if len(files) < page_size:
                            break
                        page += 1
                    else:
                        break
        if all_chunks:
            return {"status": 0, "data": all_chunks}
    except Exception as e:
        print(f"[chunks_preview] Error: {e}")

    return {"status": 0, "data": []}


@router.post("/generate")
async def generate_dataset(req: GenerateReq):
    """Trigger async LLM generation — returns gen_task_id for progress tracking"""
    import asyncio
    from ..service.task_service import run_generate_task

    gen_task_id = _id()
    async with get_db() as db:
        await db.execute(
            "INSERT INTO generate_task (id,dataset_id,status,created_at) VALUES (?,?,'pending',?)",
            (gen_task_id, req.dataset_id, _now()),
        )
        await db.commit()

    params = req.dict()
    params["gen_task_id"] = gen_task_id
    asyncio.create_task(run_generate_task(params))
    return {"status": 0, "data": {"gen_task_id": gen_task_id}}


@router.get("/generate/{gen_task_id}")
async def get_generate_progress(gen_task_id: str):
    """Poll generation task progress"""
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM generate_task WHERE id=?", (gen_task_id,)
        )
    if not rows:
        raise HTTPException(status_code=404, detail="Generate task not found")
    return {"status": 0, "data": dict(rows[0])}


@router.get("/generate-tasks/{dataset_id}")
async def list_generate_tasks(dataset_id: str):
    """List all generate tasks for a dataset"""
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM generate_task WHERE dataset_id=? ORDER BY created_at DESC",
            (dataset_id,),
        )
    return {"status": 0, "data": [dict(r) for r in rows]}


@router.post("/sample/add")
async def add_sample(req: AddSampleReq):
    async with get_db() as db:
        row_id = _id()
        await db.execute(
            """INSERT INTO eval_sample
            (id,dataset_id,question,reference_answer,relevant_chunk_ids,knowledge_hub_id,source_file_id,metadata)
            VALUES (?,?,?,?,?,?,?,?)""",
            (row_id, req.dataset_id, req.question, req.reference_answer,
             json.dumps(req.relevant_chunk_ids, ensure_ascii=False),
             req.knowledge_hub_id, req.source_file_id,
             json.dumps(req.metadata, ensure_ascii=False)),
        )
        await db.execute(
            "UPDATE eval_dataset SET sample_count=sample_count+1 WHERE id=?",
            (req.dataset_id,),
        )
        await db.commit()
    return {"status": 0, "data": {"id": row_id}}


@router.post("/import")
async def import_dataset(file: UploadFile = File(...)):
    """Upload a JSON file exported by the SDK (EvalDataset.to_dict())"""
    content = await file.read()
    data = json.loads(content)

    async with get_db() as db:
        ds_id = data.get("id") or _id()
        await db.execute(
            "INSERT OR REPLACE INTO eval_dataset (id,name,description,sample_count,created_at) VALUES (?,?,?,?,?)",
            (ds_id, data["name"], data.get("description", ""), len(data.get("samples", [])), _now()),
        )
        for s in data.get("samples", []):
            await db.execute(
                """INSERT OR REPLACE INTO eval_sample
                (id,dataset_id,question,reference_answer,relevant_chunk_ids,knowledge_hub_id,source_file_id,metadata)
                VALUES (?,?,?,?,?,?,?,?)""",
                (s.get("id") or _id(), ds_id, s["question"], s.get("reference_answer", ""),
                 json.dumps(s.get("relevant_chunk_ids", []), ensure_ascii=False),
                 s.get("knowledge_hub_id", ""), s.get("source_file_id"),
                 json.dumps(s.get("metadata", {}), ensure_ascii=False)),
            )
        await db.commit()
    return {"status": 0, "data": {"id": ds_id, "imported": len(data.get("samples", []))}}


# ── Dynamic path routes MUST come last ──────────────────────────────────────

@router.get("/{dataset_id}")
async def get_dataset(dataset_id: str):
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM eval_sample WHERE dataset_id=?", (dataset_id,)
        )
        ds = await db.execute_fetchall(
            "SELECT * FROM eval_dataset WHERE id=?", (dataset_id,)
        )
    if not ds:
        raise HTTPException(status_code=404, detail="Dataset not found")
    samples = [
        {**dict(r), "relevant_chunk_ids": json.loads(r["relevant_chunk_ids"]),
         "metadata": json.loads(r["metadata"])}
        for r in rows
    ]
    return {"status": 0, "data": {**dict(ds[0]), "samples": samples}}


@router.delete("/{dataset_id}")
async def delete_dataset(dataset_id: str):
    async with get_db() as db:
        await db.execute("DELETE FROM eval_sample WHERE dataset_id=?", (dataset_id,))
        await db.execute("DELETE FROM eval_dataset WHERE id=?", (dataset_id,))
        await db.commit()
    return {"status": 0, "data": True}
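The `/import` handler in `server/api/dataset.py` reads a JSON document with a required top-level `name`, an optional `description`, and a `samples` list in which only `question` is mandatory. A minimal sketch of building such a file follows; the field names are taken from the handler, while `build_import_payload` and the sample values are hypothetical, and the real `EvalDataset.to_dict()` output may carry additional fields.

```python
import json


def build_import_payload(name: str, samples: list[dict], description: str = "") -> dict:
    """Build a dict in the shape read by POST /api/dataset/import (illustrative helper)."""
    normalized = []
    for s in samples:
        if "question" not in s:
            # The handler indexes s["question"] directly, so it must be present.
            raise ValueError("each sample needs a 'question'")
        normalized.append({
            "question": s["question"],
            "reference_answer": s.get("reference_answer", ""),
            "relevant_chunk_ids": list(s.get("relevant_chunk_ids", [])),
            "knowledge_hub_id": s.get("knowledge_hub_id", ""),
            "source_file_id": s.get("source_file_id"),
            "metadata": dict(s.get("metadata", {})),
        })
    # "id" is left out on purpose: the handler then generates one itself.
    return {"name": name, "description": description, "samples": normalized}


# Hypothetical sample content, for illustration only.
payload = build_import_payload(
    "demo-set",
    [{
        "question": "What is the return window?",
        "reference_answer": "30 days",
        "relevant_chunk_ids": ["chunk-1"],
    }],
)
encoded = json.dumps(payload, ensure_ascii=False)
```

The `encoded` string can be written to a `.json` file and uploaded as the `file` form field. Because the handler uses `INSERT OR REPLACE`, re-importing a payload that does include an `id` overwrites the existing dataset row rather than creating a new one.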
629
server/api/loop.py
Normal file
@ -0,0 +1,629 @@
"""
Loop task API - Automated QA generation and testing with pause/resume.
"""
import asyncio
import json
from io import BytesIO
from typing import Optional

from fastapi import APIRouter, Form, HTTPException, Query
from fastapi.responses import JSONResponse, StreamingResponse

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent / "sdk"))

from models.db import get_db, _id, _now
from service.loop_recall_md import DEFAULT_LLM_NOTE, append_recall_md_section
from service.loop_engine import (
    run_loop_task, pause_loop, resume_loop, stop_loop,
    _loop_controls, _update_loop_stats
)

router = APIRouter(prefix="/api/loop", tags=["Loop Task"])


@router.post("/task")
async def create_loop_task(
    name: str = Form(...),
    org_id: str = Form(...),
    judge_config_id: str = Form(...),
    file_ids: str = Form(""),  # comma-separated
    questions_per_section: int = Form(5),
    quality_threshold: float = Form(0.6),
    include_multimodal: bool = Form(True),
    env_url: str = Form(...),
    d_user_id: str = Form("test"),
    agent_id: str = Form(""),  # agent ID used for the recall test
    top_k: int = Form(64),
    recall_top_k: int = Form(64),
    concurrency: int = Form(20),
    cross_chunk: bool = Form(True),
    max_rounds: int = Form(0),
    max_questions: int = Form(0),
    global_dedup: bool = Form(False),  # whether to deduplicate globally (across tasks)
    expected_chunk_count: int = Form(0),  # total chunks in this batch, matching chunk_batches_plan.chunk_count; when > 0, fetch completeness is verified
):
    """Create and start a loop task.

    Args:
        top_k: threshold used to decide whether a chunk/file is hit (default 64)
        recall_top_k: top_k requested from the recall API (default 64)
        agent_id: agent ID used for the recall test (optional; when empty, the knowledge-base search is called directly)
        expected_chunk_count: optional; when it matches the batch chunk_count, an incomplete fetch is retried and eventually fails, so chunks are never silently missing
    """

    task_id = _id()
    file_id_list = [f.strip() for f in file_ids.split(",") if f.strip()]
    ecc = int(expected_chunk_count) if expected_chunk_count and int(expected_chunk_count) > 0 else None

    async with get_db() as db:
        await db.execute(
            """INSERT INTO loop_task
            (id,name,org_id,judge_config_id,file_ids,questions_per_section,quality_threshold,
            include_multimodal,env_url,d_user_id,agent_id,top_k,recall_top_k,concurrency,cross_chunk,
            status,current_round,max_rounds,max_questions,total_generated,total_approved,
            total_duplicates,total_tested,total_recalled,total_file_hit,total_file_miss,
            total_recall_failed,global_dedup,expected_chunk_count,created_at)
            VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
            (task_id, name, org_id, judge_config_id, ",".join(file_id_list),
             questions_per_section, quality_threshold, int(include_multimodal),
             env_url, d_user_id, agent_id, top_k, recall_top_k, concurrency, int(cross_chunk),
             "pending", 0, max_rounds, max_questions,
             0, 0, 0, 0, 0, 0, 0, 0, int(global_dedup), ecc, _now()),
        )
        await db.commit()

    # Start the loop in background
    asyncio.create_task(run_loop_task(
        loop_task_id=task_id,
        org_id=org_id,
        file_ids=file_id_list,
        judge_config_id=judge_config_id,
        questions_per_section=questions_per_section,
        quality_threshold=quality_threshold,
        include_multimodal=include_multimodal,
        env_url=env_url,
        d_user_id=d_user_id,
        agent_id=agent_id,
        top_k=top_k,
        recall_top_k=recall_top_k,
        concurrency=concurrency,
        cross_chunk=cross_chunk,
        max_rounds=max_rounds,
        max_questions=max_questions,
        global_dedup=global_dedup,
    ))

    return {"status": 0, "data": {"id": task_id}}


@router.get("/task/list")
async def list_loop_tasks(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
):
    """List all loop tasks with pagination."""
    offset = (page - 1) * page_size

    async with get_db() as db:
        rows = await db.execute_fetchall(
            """SELECT * FROM loop_task
            ORDER BY created_at DESC
            LIMIT ? OFFSET ?""",
            (page_size, offset),
        )
        total = await db.execute_fetchall(
            "SELECT COUNT(*) as cnt FROM loop_task"
        )

    tasks = []
    for row in rows:
        task = dict(row)
        # Calculate derived metrics
        total_tested = task.get("total_tested") or 0
        total_recalled = task.get("total_recalled") or 0
        total_file_hit = task.get("total_file_hit") or 0
        total_file_miss = task.get("total_file_miss") or 0

        task["recall_rate"] = round(total_recalled / total_tested, 4) if total_tested > 0 else 0
        task["file_hit_rate"] = round(total_file_hit / total_recalled, 4) if total_recalled > 0 else 0
        task["file_miss_rate"] = round(total_file_miss / total_recalled, 4) if total_recalled > 0 else 0

        tasks.append(task)

    return {
        "status": 0,
        "data": {
            "total": total[0]["cnt"] if total else 0,
            "items": tasks,
        },
    }


@router.get("/task/{task_id}")
async def get_loop_task(task_id: str):
    """Get loop task details with cumulative stats."""
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM loop_task WHERE id=?", (task_id,)
        )

    if not rows:
        raise HTTPException(status_code=404, detail="Task not found")

    task = dict(rows[0])

    # Calculate rates
    total_tested = task.get("total_tested") or 0
    total_recalled = task.get("total_recalled") or 0
    total_file_hit = task.get("total_file_hit") or 0
    total_file_miss = task.get("total_file_miss") or 0

    task["recall_rate"] = round(total_recalled / total_tested, 4) if total_tested > 0 else 0
    task["file_hit_rate"] = round(total_file_hit / total_recalled, 4) if total_recalled > 0 else 0
    task["file_miss_rate"] = round(total_file_miss / total_recalled, 4) if total_recalled > 0 else 0

    return {"status": 0, "data": task}


@router.post("/task/{task_id}/pause")
async def pause_task(task_id: str):
    """Pause a running loop task."""
    result = await pause_loop(task_id)
    if not result:
        raise HTTPException(status_code=400, detail="Task not running")

    # Return the updated task state
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM loop_task WHERE id=?", (task_id,)
        )
    if not rows:
        raise HTTPException(status_code=404, detail="Task not found")

    task = dict(rows[0])
    return {"status": 0, "data": task}


@router.post("/task/{task_id}/resume")
|
||||||
|
async def resume_task(task_id: str):
|
||||||
|
"""Resume a paused loop task."""
|
||||||
|
async with get_db() as db:
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
"SELECT status FROM loop_task WHERE id=?", (task_id,)
|
||||||
|
)
|
||||||
|
if not rows:
|
||||||
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
|
if dict(rows[0])["status"] != "paused":
|
||||||
|
raise HTTPException(status_code=400, detail="Task not paused")
|
||||||
|
|
||||||
|
# 立即把状态改成 running,让前端马上看到反馈
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_task SET status='running', paused_at=NULL WHERE id=?",
|
||||||
|
(task_id,),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# 尝试唤醒内存中的任务
|
||||||
|
result = await resume_loop(task_id)
|
||||||
|
if not result:
|
||||||
|
# 内存中没有(服务重启过),重新启动任务
|
||||||
|
async with get_db() as db:
|
||||||
|
task_rows = await db.execute_fetchall(
|
||||||
|
"SELECT * FROM loop_task WHERE id=?", (task_id,)
|
||||||
|
)
|
||||||
|
task = dict(task_rows[0])
|
||||||
|
file_ids = [f.strip() for f in (task.get("file_ids") or "").split(",") if f.strip()]
|
||||||
|
|
||||||
|
asyncio.create_task(run_loop_task(
|
||||||
|
loop_task_id=task_id,
|
||||||
|
org_id=task["org_id"],
|
||||||
|
file_ids=file_ids,
|
||||||
|
judge_config_id=task["judge_config_id"],
|
||||||
|
questions_per_section=task["questions_per_section"],
|
||||||
|
quality_threshold=task["quality_threshold"],
|
||||||
|
include_multimodal=bool(task["include_multimodal"]),
|
||||||
|
env_url=task["env_url"],
|
||||||
|
d_user_id=task["d_user_id"],
|
||||||
|
agent_id=task.get("agent_id", ""),
|
||||||
|
top_k=task["top_k"],
|
||||||
|
recall_top_k=task.get("recall_top_k", 64),
|
||||||
|
concurrency=task["concurrency"],
|
||||||
|
cross_chunk=bool(task["cross_chunk"]),
|
||||||
|
max_rounds=task["max_rounds"],
|
||||||
|
max_questions=task["max_questions"],
|
||||||
|
global_dedup=bool(task.get("global_dedup", 0)),
|
||||||
|
))
|
||||||
|
|
||||||
|
# Return the updated task state
|
||||||
|
async with get_db() as db:
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
"SELECT * FROM loop_task WHERE id=?", (task_id,)
|
||||||
|
)
|
||||||
|
task = dict(rows[0])
|
||||||
|
return {"status": 0, "data": task}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/task/{task_id}/stop")
|
||||||
|
async def stop_task(task_id: str):
|
||||||
|
"""Stop a loop task permanently."""
|
||||||
|
# Check task exists and is running or paused
|
||||||
|
async with get_db() as db:
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
"SELECT status FROM loop_task WHERE id=?", (task_id,)
|
||||||
|
)
|
||||||
|
if not rows:
|
||||||
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
|
status = rows[0]["status"]
|
||||||
|
if status not in ("running", "paused"):
|
||||||
|
raise HTTPException(status_code=400, detail="Task not running or paused")
|
||||||
|
|
||||||
|
# Try to stop via control structure (if running)
|
||||||
|
from service.loop_engine import _loop_controls
|
||||||
|
ctrl = _loop_controls.get(task_id)
|
||||||
|
if ctrl:
|
||||||
|
ctrl["stop"] = True
|
||||||
|
ctrl["pause_event"].set()
|
||||||
|
|
||||||
|
# Update database status regardless
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_task SET status='stopped', finished_at=? WHERE id=?",
|
||||||
|
(_now(), task_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"status": 0, "data": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/task/{task_id}")
|
||||||
|
async def delete_task(task_id: str):
|
||||||
|
"""Delete loop task and all related data."""
|
||||||
|
|
||||||
|
# First stop any running background task
|
||||||
|
from service.loop_engine import _loop_controls
|
||||||
|
ctrl = _loop_controls.get(task_id)
|
||||||
|
if ctrl:
|
||||||
|
ctrl["stop"] = True
|
||||||
|
ctrl["pause_event"].set()
|
||||||
|
_loop_controls.pop(task_id, None)
|
||||||
|
|
||||||
|
async with get_db() as db:
|
||||||
|
# Get all rounds to delete related tasks
|
||||||
|
rounds = await db.execute_fetchall(
|
||||||
|
"SELECT qa_gen_task_id, single_jump_task_id FROM loop_round WHERE loop_task_id=?",
|
||||||
|
(task_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
for r in rounds:
|
||||||
|
qa_id = r["qa_gen_task_id"]
|
||||||
|
sj_id = r["single_jump_task_id"]
|
||||||
|
|
||||||
|
# Delete QA questions
|
||||||
|
if qa_id:
|
||||||
|
await db.execute(
|
||||||
|
"DELETE FROM qa_gen_question WHERE task_id=?", (qa_id,)
|
||||||
|
)
|
||||||
|
await db.execute(
|
||||||
|
"DELETE FROM qa_gen_task WHERE id=?", (qa_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete single-jump results
|
||||||
|
if sj_id:
|
||||||
|
await db.execute(
|
||||||
|
"DELETE FROM single_jump_result WHERE task_id=?", (sj_id,)
|
||||||
|
)
|
||||||
|
await db.execute(
|
||||||
|
"DELETE FROM single_jump_task WHERE id=?", (sj_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete rounds
|
||||||
|
await db.execute(
|
||||||
|
"DELETE FROM loop_round WHERE loop_task_id=?", (task_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete task
|
||||||
|
await db.execute(
|
||||||
|
"DELETE FROM loop_task WHERE id=?", (task_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"status": 0, "data": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/task/{task_id}/rounds")
|
||||||
|
async def get_rounds(task_id: str):
|
||||||
|
"""Get all rounds for a loop task."""
|
||||||
|
async with get_db() as db:
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
"""SELECT * FROM loop_round
|
||||||
|
WHERE loop_task_id=?
|
||||||
|
ORDER BY round_number""",
|
||||||
|
(task_id,),
|
||||||
|
)
|
||||||
|
# Convert rows to dicts while connection is still open
|
||||||
|
rounds = [dict(r) for r in rows]
|
||||||
|
|
||||||
|
return {"status": 0, "data": rounds}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/task/{task_id}/questions")
|
||||||
|
async def get_questions(
|
||||||
|
task_id: str,
|
||||||
|
status: Optional[str] = Query(None), # approved, rejected, duplicate
|
||||||
|
category: Optional[str] = Query(None), # hit, file_miss, recall_failed
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
page_size: int = Query(20, ge=1, le=100),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get questions across all rounds.
|
||||||
|
|
||||||
|
- status: filter by qa_gen_question status
|
||||||
|
- category: filter by test result category
|
||||||
|
"""
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# Build query
|
||||||
|
where_clauses = ["lr.loop_task_id = ?"]
|
||||||
|
params = [task_id]
|
||||||
|
|
||||||
|
if status:
|
||||||
|
if status == "duplicate":
|
||||||
|
where_clauses.append("q.dup_of IS NOT NULL")
|
||||||
|
else:
|
||||||
|
where_clauses.append("q.status = ?")
|
||||||
|
params.append(status)
|
||||||
|
|
||||||
|
if category:
|
||||||
|
if category == "hit":
|
||||||
|
where_clauses.append("r.is_file_hit = 1")
|
||||||
|
elif category == "file_miss":
|
||||||
|
where_clauses.append("r.is_file_hit = 0 AND COALESCE(json_array_length(r.retrieved), 0) > 0")
|
||||||
|
elif category == "recall_failed":
|
||||||
|
where_clauses.append("COALESCE(json_array_length(r.retrieved), 0) = 0 AND r.error IS NULL")
|
||||||
|
|
||||||
|
where_sql = " AND ".join(where_clauses)
|
||||||
|
|
||||||
|
async with get_db() as db:
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
f"""SELECT
|
||||||
|
q.id, q.section_path, q.question, q.reference_answer,
|
||||||
|
q.source_chunk, q.quality_score, q.status,
|
||||||
|
q.dup_of, q.dup_similarity,
|
||||||
|
q.chunk_headers, q.chunk_id, q.file_name,
|
||||||
|
lr.round_number,
|
||||||
|
r.is_file_hit, r.retrieved, r.best_cosine_sim, r.latency_ms, r.error,
|
||||||
|
r.expected_chunk_id, r.is_chunk_hit, r.chunk_hit_rank
|
||||||
|
FROM qa_gen_question q
|
||||||
|
JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
|
||||||
|
LEFT JOIN single_jump_result r ON r.rowid = (
|
||||||
|
SELECT r2.rowid FROM single_jump_result r2
|
||||||
|
WHERE r2.task_id = lr.single_jump_task_id AND r2.question = q.question
|
||||||
|
ORDER BY r2.rowid DESC LIMIT 1
|
||||||
|
)
|
||||||
|
WHERE {where_sql}
|
||||||
|
ORDER BY lr.round_number DESC, q.created_at DESC
|
||||||
|
LIMIT ? OFFSET ?""",
|
||||||
|
(*params, page_size, offset),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert rows to dicts while connection is still open
|
||||||
|
questions = [dict(r) for r in rows]
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
total_rows = await db.execute_fetchall(
|
||||||
|
f"""SELECT COUNT(DISTINCT q.id) as cnt
|
||||||
|
FROM qa_gen_question q
|
||||||
|
JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
|
||||||
|
LEFT JOIN single_jump_result r ON r.rowid = (
|
||||||
|
SELECT r2.rowid FROM single_jump_result r2
|
||||||
|
WHERE r2.task_id = lr.single_jump_task_id AND r2.question = q.question
|
||||||
|
ORDER BY r2.rowid DESC LIMIT 1
|
||||||
|
)
|
||||||
|
WHERE {where_sql}""",
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": 0,
|
||||||
|
"data": {
|
||||||
|
"total": total_rows[0]["cnt"] if total_rows else 0,
|
||||||
|
"items": questions,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/task/{task_id}/export")
|
||||||
|
async def export_questions(
|
||||||
|
task_id: str,
|
||||||
|
category: str = Query("all"), # all, hit, file_miss, recall_failed
|
||||||
|
format: str = Query("md"), # md, json
|
||||||
|
):
|
||||||
|
"""Export questions to MD or JSON format."""
|
||||||
|
|
||||||
|
async with get_db() as db:
|
||||||
|
# Check if we have qa_gen_task_id in loop_round
|
||||||
|
has_qa_task = await db.execute_fetchall(
|
||||||
|
"""SELECT COUNT(*) as cnt FROM loop_round
|
||||||
|
WHERE loop_task_id=? AND qa_gen_task_id IS NOT NULL""",
|
||||||
|
(task_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
use_qa_task = has_qa_task[0]["cnt"] > 0 if has_qa_task else False
|
||||||
|
|
||||||
|
# Build where clause based on category
|
||||||
|
if use_qa_task:
|
||||||
|
# New tasks: query from qa_gen_question and join single_jump_result for expected_chunk_id
|
||||||
|
if category == "hit":
|
||||||
|
where_clause = "r.is_file_hit = 1"
|
||||||
|
elif category == "file_miss":
|
||||||
|
where_clause = "r.is_file_hit = 0 AND COALESCE(json_array_length(r.retrieved), 0) > 0"
|
||||||
|
elif category == "recall_failed":
|
||||||
|
where_clause = "COALESCE(json_array_length(r.retrieved), 0) = 0 AND r.error IS NULL"
|
||||||
|
else: # all
|
||||||
|
where_clause = "1=1"
|
||||||
|
|
||||||
|
# Note: do not JOIN qa_gen_question ON chunk_id; multiple questions under the same chunk would multiply rows and duplicate entries in the export.
|
||||||
|
# If single_jump_result has several rows for the same question text within one task, take only the latest one (largest rowid).
|
||||||
|
db_rows = await db.execute_fetchall(
|
||||||
|
f"""SELECT
|
||||||
|
q.id as qa_question_id,
|
||||||
|
q.section_path, q.file_name, q.question, q.reference_answer,
|
||||||
|
q.source_chunk, q.quality_score, q.status,
|
||||||
|
q.dup_of, q.dup_similarity,
|
||||||
|
q.chunk_headers, q.chunk_id,
|
||||||
|
lr.round_number,
|
||||||
|
r.is_file_hit, r.retrieved, r.best_cosine_sim,
|
||||||
|
r.expected_chunk_id,
|
||||||
|
(SELECT q2b.chunk_headers FROM qa_gen_question q2b
|
||||||
|
WHERE q2b.chunk_id = r.expected_chunk_id
|
||||||
|
AND q2b.chunk_id IS NOT NULL AND trim(COALESCE(q2b.chunk_headers, '')) != ''
|
||||||
|
LIMIT 1) AS expected_chunk_name
|
||||||
|
FROM qa_gen_question q
|
||||||
|
JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
|
||||||
|
LEFT JOIN single_jump_result r ON r.rowid = (
|
||||||
|
SELECT r2.rowid FROM single_jump_result r2
|
||||||
|
WHERE r2.task_id = lr.single_jump_task_id AND r2.question = q.question
|
||||||
|
ORDER BY r2.rowid DESC LIMIT 1
|
||||||
|
)
|
||||||
|
WHERE lr.loop_task_id = ? AND q.status = 'approved' AND {where_clause}
|
||||||
|
ORDER BY lr.round_number, q.chunk_headers, q.created_at""",
|
||||||
|
(task_id,),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Old tasks: query from single_jump_result directly
|
||||||
|
if category == "hit":
|
||||||
|
where_clause = "r.is_file_hit = 1"
|
||||||
|
elif category == "file_miss":
|
||||||
|
where_clause = "r.is_file_hit = 0 AND COALESCE(json_array_length(r.retrieved), 0) > 0"
|
||||||
|
elif category == "recall_failed":
|
||||||
|
where_clause = "COALESCE(json_array_length(r.retrieved), 0) = 0 AND r.error IS NULL"
|
||||||
|
else: # all
|
||||||
|
where_clause = "1=1"
|
||||||
|
|
||||||
|
db_rows = await db.execute_fetchall(
|
||||||
|
f"""SELECT
|
||||||
|
r.rowid as result_rowid,
|
||||||
|
r.section_path, r.file_name, r.question, r.reference_answer,
|
||||||
|
'' as source_chunk, 1.0 as quality_score, 'approved' as status,
|
||||||
|
NULL as dup_of, NULL as dup_similarity,
|
||||||
|
COALESCE(r.raw_chunk_headers, r.section_path) as chunk_headers,
|
||||||
|
r.expected_chunk_id as chunk_id,
|
||||||
|
lr.round_number,
|
||||||
|
r.is_file_hit, r.retrieved, r.best_cosine_sim,
|
||||||
|
r.expected_chunk_id,
|
||||||
|
(SELECT qb.chunk_headers FROM qa_gen_question qb
|
||||||
|
WHERE qb.chunk_id = r.expected_chunk_id LIMIT 1) AS expected_chunk_name
|
||||||
|
FROM single_jump_result r
|
||||||
|
JOIN loop_round lr ON r.task_id = lr.single_jump_task_id
|
||||||
|
WHERE lr.loop_task_id = ? AND {where_clause}
|
||||||
|
ORDER BY lr.round_number, r.section_path""",
|
||||||
|
(task_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert rows to dicts while connection is still open
|
||||||
|
rows = [dict(row) for row in db_rows]
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
# Return empty response if no data
|
||||||
|
from fastapi.responses import PlainTextResponse
|
||||||
|
return PlainTextResponse(
|
||||||
|
"没有符合条件的数据",
|
||||||
|
status_code=404
|
||||||
|
)
|
||||||
|
|
||||||
|
# Group by section
|
||||||
|
from collections import defaultdict
|
||||||
|
sections: dict[str, list] = defaultdict(list)
|
||||||
|
for row in rows:
|
||||||
|
# Use chunk_headers as the grouping key if available, otherwise use section_path
|
||||||
|
section_key = row.get("chunk_headers") or row.get("section_path") or row.get("file_name") or "default"
|
||||||
|
sections[section_key].append(row)
|
||||||
|
|
||||||
|
if format == "json":
|
||||||
|
# JSON export
|
||||||
|
data = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"category": category,
|
||||||
|
"exported_at": _now(),
|
||||||
|
"questions": [],
|
||||||
|
}
|
||||||
|
for section_path, items in sections.items():
|
||||||
|
for item in items:
|
||||||
|
data["questions"].append({
|
||||||
|
"section_path": section_path,
|
||||||
|
"file_name": item.get("file_name"),
|
||||||
|
"round": item["round_number"],
|
||||||
|
"question": item["question"],
|
||||||
|
"reference_answer": item["reference_answer"],
|
||||||
|
"source_chunk": item["source_chunk"],
|
||||||
|
"quality_score": item["quality_score"],
|
||||||
|
"status": item["status"],
|
||||||
|
"is_duplicate": bool(item.get("dup_of")),
|
||||||
|
"dup_similarity": item.get("dup_similarity"),
|
||||||
|
"is_file_hit": bool(item.get("is_file_hit")),
|
||||||
|
"recall_results": json.loads(item["retrieved"]) if item.get("retrieved") else [],
|
||||||
|
"best_cosine_sim": item["best_cosine_sim"],
|
||||||
|
"expected_chunk_id": item.get("expected_chunk_id"),
|
||||||
|
"expected_chunk_name": item.get("expected_chunk_name"),
|
||||||
|
"chunk_id": item.get("chunk_id") or item.get("expected_chunk_id"),
|
||||||
|
})
|
||||||
|
|
||||||
|
content = json.dumps(data, ensure_ascii=False, indent=2)
|
||||||
|
filename = f"loop_{task_id}_{category}.json"
|
||||||
|
media_type = "application/json"
|
||||||
|
|
||||||
|
else:
|
||||||
|
# MD export: uses the same loop_recall_md format as the single-jump parser, the in-loop single-jump MD, and the offline script
|
||||||
|
lines: list[str] = []
|
||||||
|
|
||||||
|
def _after_answer(_i: int, item: dict):
|
||||||
|
if item.get("expected_chunk_name"):
|
||||||
|
yield f"> 预期切片: {item['expected_chunk_name']}"
|
||||||
|
sc = item.get("source_chunk")
|
||||||
|
if sc:
|
||||||
|
yield f"> Source: {str(sc)[:200]}..."
|
||||||
|
|
||||||
|
section_index = 0
|
||||||
|
for section_key, items in sections.items():
|
||||||
|
section_index += 1
|
||||||
|
file_name = (items[0].get("file_name") or "").strip()
|
||||||
|
slice_title = (items[0].get("chunk_headers") or "").strip() or section_key
|
||||||
|
meta = [f"> 代表轮次: {items[0]['round_number']}", DEFAULT_LLM_NOTE]
|
||||||
|
if category != "all":
|
||||||
|
meta.insert(0, f"> 导出分类: {category}")
|
||||||
|
qa_items = [
|
||||||
|
{
|
||||||
|
"question": it["question"],
|
||||||
|
"reference_answer": it["reference_answer"],
|
||||||
|
"chunk_id": (it.get("chunk_id") or it.get("expected_chunk_id") or ""),
|
||||||
|
}
|
||||||
|
for it in items
|
||||||
|
]
|
||||||
|
append_recall_md_section(
|
||||||
|
lines,
|
||||||
|
section_index,
|
||||||
|
file_name=file_name,
|
||||||
|
slice_title=slice_title,
|
||||||
|
qa_items=qa_items,
|
||||||
|
meta_lines=meta,
|
||||||
|
after_answer_lines=_after_answer,
|
||||||
|
)
|
||||||
|
|
||||||
|
content = "\n".join(lines)
|
||||||
|
filename = f"loop_{task_id}_{category}.md"
|
||||||
|
media_type = "text/markdown"
|
||||||
|
|
||||||
|
from urllib.parse import quote
|
||||||
|
filename_encoded = quote(filename)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
BytesIO(content.encode("utf-8")),
|
||||||
|
media_type=media_type,
|
||||||
|
headers={"Content-Disposition": f"attachment; filename*=UTF-8''{filename_encoded}"},
|
||||||
|
)
|
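The `get_questions` endpoint above builds its filter by joining fixed SQL fragments and passing every user-supplied value as a bound parameter, then reuses the same WHERE clause for the page query and the count query. The following is a minimal standalone sketch of that pattern, assuming the stdlib `sqlite3` module and a simplified, hypothetical `question` table (not the project's real schema); `build_where`, `fetch_page`, and `demo` are illustrative names only.

```python
import sqlite3
from typing import Optional


def build_where(task_id: str, status: Optional[str] = None) -> tuple[str, list]:
    """Return (where_sql, params). Only fixed SQL text is concatenated;
    caller-supplied values always travel as bound parameters."""
    clauses = ["task_id = ?"]
    params: list = [task_id]
    if status == "duplicate":
        # "duplicate" is a derived filter, not a stored status value
        clauses.append("dup_of IS NOT NULL")
    elif status:
        clauses.append("status = ?")
        params.append(status)
    return " AND ".join(clauses), params


def fetch_page(conn, task_id, status=None, page=1, page_size=20):
    """Run the page query and the count query with the same WHERE clause."""
    where_sql, params = build_where(task_id, status)
    offset = (page - 1) * page_size
    items = conn.execute(
        f"SELECT id, status FROM question WHERE {where_sql} "
        "ORDER BY id LIMIT ? OFFSET ?",
        (*params, page_size, offset),
    ).fetchall()
    total = conn.execute(
        f"SELECT COUNT(*) FROM question WHERE {where_sql}", params
    ).fetchone()[0]
    return {"total": total, "items": items}


def demo():
    """Create an in-memory database with a few hypothetical rows."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE question ("
        "id INTEGER PRIMARY KEY, task_id TEXT, status TEXT, dup_of INTEGER)"
    )
    rows = [
        ("t1", "approved", None),
        ("t1", "rejected", None),
        ("t1", "approved", 1),
        ("t2", "approved", None),
    ]
    conn.executemany(
        "INSERT INTO question (task_id, status, dup_of) VALUES (?,?,?)", rows
    )
    return conn
```

Because the f-string only ever interpolates fragments written in the source, the query text stays free of user input even though it is assembled dynamically.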
||||||
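The `stop_task` and `delete_task` handlers above set `ctrl["stop"] = True` and then call `ctrl["pause_event"].set()`, which suggests that a paused worker is blocked on the event and must be woken before it can notice the stop flag. The worker side lives in `service.loop_engine` and is not shown here, so the sketch below is only an assumption about how such a loop might look; `worker` and `demo` are hypothetical names.

```python
import asyncio


async def worker(ctrl: dict, log: list, max_steps: int = 100) -> None:
    """Hypothetical worker loop: waits while paused, exits when stopped."""
    for step in range(max_steps):
        # Blocks while the event is cleared (paused); returns at once when set.
        await ctrl["pause_event"].wait()
        if ctrl["stop"]:
            log.append("stopped")
            return
        log.append(step)
        await asyncio.sleep(0)  # stand-in for one round of real work
    log.append("finished")


async def demo() -> list:
    ctrl = {"stop": False, "pause_event": asyncio.Event()}
    ctrl["pause_event"].set()  # start in the running state
    log: list = []
    task = asyncio.create_task(worker(ctrl, log))
    await asyncio.sleep(0)       # let the worker make some progress
    ctrl["pause_event"].clear()  # pause
    await asyncio.sleep(0)
    # Stop while paused: set the flag first, then wake the worker.
    ctrl["stop"] = True
    ctrl["pause_event"].set()
    await task
    return log
```

Setting the flag before the event matters in this sketch: if the event were set first, the worker could wake, see `stop` still false, and run one more step.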
282
server/api/multi_hop.py
Normal file
@ -0,0 +1,282 @@
|
|||||||
|
"""
|
||||||
|
Multi-hop recall test API
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "sdk"))
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from models.db import get_db, _now, _id
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/multi-hop", tags=["多跳召回测试"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/dagent/agents")
|
||||||
|
async def list_dagent_agents(env_url: str, org_id: str, d_user_id: str = "test"):
|
||||||
|
"""Fetch the list of available agents from the dagent platform."""
|
||||||
|
import aiohttp
|
||||||
|
url = f"{env_url.rstrip('/')}/dagent/agent/page"
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"d-user-id": d_user_id,
|
||||||
|
"org-id": org_id,
|
||||||
|
}
|
||||||
|
payload = {"current": 1, "page_size": 100, "org_id": org_id}
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(headers=headers) as session:
|
||||||
|
async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = await resp.json()
|
||||||
|
agents = data.get("data", {}).get("list", [])
|
||||||
|
return {"status": 0, "data": [
|
||||||
|
{"id": a.get("id"), "name": a.get("agent_name"), "type": a.get("agent_type"), "description": a.get("agent_description")}
|
||||||
|
for a in agents
|
||||||
|
]}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=502, detail=f"无法连接 dagent: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/task")
|
||||||
|
async def create_task(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
name: str = Form(""),
|
||||||
|
env_url: str = Form(...),
|
||||||
|
org_id: str = Form(...),
|
||||||
|
d_user_id: str = Form("test"),
|
||||||
|
agent_id: str = Form(...),
|
||||||
|
llm_type: str = Form("deepseek_v3"),
|
||||||
|
top_k: int = Form(10),
|
||||||
|
concurrency: int = Form(5),
|
||||||
|
):
|
||||||
|
content = await file.read()
|
||||||
|
qa_text = content.decode("utf-8")
|
||||||
|
|
||||||
|
task_id = _id()
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"""INSERT INTO multi_hop_task
|
||||||
|
(id,name,env_url,org_id,d_user_id,agent_id,llm_type,top_k,concurrency,status,created_at)
|
||||||
|
VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
|
||||||
|
(task_id, name or file.filename, env_url, org_id,
|
||||||
|
d_user_id, agent_id, llm_type, top_k, concurrency, "pending", _now()),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
asyncio.create_task(_run_task(
|
||||||
|
task_id, qa_text, env_url, org_id, d_user_id, agent_id, llm_type, top_k, concurrency
|
||||||
|
))
|
||||||
|
return {"status": 0, "data": {"id": task_id}}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/task/list")
|
||||||
|
async def list_tasks():
|
||||||
|
async with get_db() as db:
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
"SELECT * FROM multi_hop_task ORDER BY created_at DESC"
|
||||||
|
)
|
||||||
|
return {"status": 0, "data": [dict(r) for r in rows]}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/task/{task_id}")
|
||||||
|
async def get_task(task_id: str):
|
||||||
|
async with get_db() as db:
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
"SELECT * FROM multi_hop_task WHERE id=?", (task_id,)
|
||||||
|
)
|
||||||
|
if not rows:
|
||||||
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
return {"status": 0, "data": dict(rows[0])}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/task/{task_id}")
|
||||||
|
async def delete_task(task_id: str):
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute("DELETE FROM multi_hop_result WHERE task_id=?", (task_id,))
|
||||||
|
await db.execute("DELETE FROM multi_hop_task WHERE id=?", (task_id,))
|
||||||
|
await db.commit()
|
||||||
|
return {"status": 0}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/task/{task_id}/results")
|
||||||
|
async def get_results(task_id: str):
|
||||||
|
async with get_db() as db:
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
"SELECT * FROM multi_hop_result WHERE task_id=? ORDER BY qid",
|
||||||
|
(task_id,),
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for r in rows:
|
||||||
|
d = dict(r)
|
||||||
|
d["hops"] = json.loads(d.get("hops") or "[]")
|
||||||
|
d["actual_hops"] = json.loads(d.get("actual_hops") or "[]")
|
||||||
|
d["retrieved"] = json.loads(d.get("retrieved") or "[]")
|
||||||
|
results.append(d)
|
||||||
|
return {"status": 0, "data": results}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/task/{task_id}/summary")
|
||||||
|
async def get_summary(task_id: str):
|
||||||
|
async with get_db() as db:
|
||||||
|
task_rows = await db.execute_fetchall(
|
||||||
|
"SELECT * FROM multi_hop_task WHERE id=?", (task_id,)
|
||||||
|
)
|
||||||
|
if not task_rows:
|
||||||
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
task = dict(task_rows[0])
|
||||||
|
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
"""SELECT
|
||||||
|
COUNT(*) as total,
|
||||||
|
SUM(CASE WHEN error IS NOT NULL THEN 1 ELSE 0 END) as errors,
|
||||||
|
SUM(full_hit) as full_hit_count,
|
||||||
|
SUM(partial_hit) as partial_hit_count,
|
||||||
|
SUM(full_chunk_hit) as full_chunk_hit_count,
|
||||||
|
SUM(partial_chunk_hit) as partial_chunk_hit_count,
|
||||||
|
AVG(CASE WHEN hop_count > 0 THEN CAST(hop_hit_count AS REAL) / hop_count ELSE 0 END) as avg_hop_hit_rate,
|
||||||
|
AVG(CASE WHEN hop_count > 0 THEN CAST(chunk_hit_count AS REAL) / hop_count ELSE 0 END) as avg_chunk_hit_rate,
|
||||||
|
AVG(latency_ms) as avg_latency_ms
|
||||||
|
FROM multi_hop_result WHERE task_id=?""",
|
||||||
|
(task_id,),
|
||||||
|
)
|
||||||
|
stats = dict(rows[0]) if rows else {}
|
||||||
|
|
||||||
|
total = stats.get("total") or 0
|
||||||
|
full_hit = stats.get("full_hit_count") or 0
|
||||||
|
partial_hit = stats.get("partial_hit_count") or 0
|
||||||
|
full_chunk_hit = stats.get("full_chunk_hit_count") or 0
|
||||||
|
partial_chunk_hit = stats.get("partial_chunk_hit_count") or 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": 0,
|
||||||
|
"data": {
|
||||||
|
"task": task,
|
||||||
|
"total": total,
|
||||||
|
"full_hit_count": full_hit,
|
||||||
|
"full_hit_rate": round(full_hit / total, 4) if total else 0.0,
|
||||||
|
"partial_hit_count": partial_hit,
|
||||||
|
"partial_hit_rate": round(partial_hit / total, 4) if total else 0.0,
|
||||||
|
"full_chunk_hit_count": full_chunk_hit,
|
||||||
|
"full_chunk_hit_rate": round(full_chunk_hit / total, 4) if total else 0.0,
|
||||||
|
"partial_chunk_hit_count": partial_chunk_hit,
|
||||||
|
"partial_chunk_hit_rate": round(partial_chunk_hit / total, 4) if total else 0.0,
|
||||||
|
"error_count": stats.get("errors") or 0,
|
||||||
|
"avg_hop_hit_rate": round(stats.get("avg_hop_hit_rate") or 0.0, 4),
|
||||||
|
"avg_chunk_hit_rate": round(stats.get("avg_chunk_hit_rate") or 0.0, 4),
|
||||||
|
"avg_latency_ms": round(stats.get("avg_latency_ms") or 0.0, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Background execution ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _run_task(task_id: str, qa_text: str, env_url: str, org_id: str,
|
||||||
|
d_user_id: str, agent_id: str, llm_type: str,
|
||||||
|
top_k: int, concurrency: int):
|
||||||
|
try:
|
||||||
|
from rag_eval.multi_hop.parser import parse_multi_hop_text
|
||||||
|
from rag_eval.multi_hop.tester import MultiHopTester
|
||||||
|
from rag_eval.single_jump.mapper import FileMapper
|
||||||
|
|
||||||
|
case = parse_multi_hop_text(qa_text)
|
||||||
|
qa_pairs = case.qa_pairs
|
||||||
|
if not qa_pairs:
|
||||||
|
raise ValueError("未解析到任何多跳问答对")
|
||||||
|
|
||||||
|
total = len(qa_pairs)
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE multi_hop_task SET status='running', total=? WHERE id=?",
|
||||||
|
(total, task_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
mapper = FileMapper(env_url, org_id, d_user_id)
|
||||||
|
await mapper.load_files()
|
||||||
|
all_paths = {hop.section_path for qa in qa_pairs for hop in qa.hops}
|
||||||
|
file_map = {path: mapper.map_section_to_file(path) for path in all_paths}
|
||||||
|
|
||||||
|
tester = MultiHopTester(
|
||||||
|
env_url, org_id, d_user_id,
|
||||||
|
agent_id=agent_id, llm_type=llm_type or "deepseek_v3",
|
||||||
|
)
|
||||||
|
|
||||||
|
write_buf = []
|
||||||
|
FLUSH_SIZE = 20
|
||||||
|
|
||||||
|
async def flush_buf(buf: list, progress: int):
|
||||||
|
async with get_db() as db2:
|
||||||
|
for r in buf:
|
||||||
|
await db2.execute(
|
||||||
|
"""INSERT INTO multi_hop_result
|
||||||
|
(id,task_id,qid,question,answer,type,top_k,
|
||||||
|
hops,actual_hops,retrieved,agent_answer,
|
||||||
|
latency_ms,error,best_cosine_sim,
|
||||||
|
full_hit,partial_hit,hop_count,hop_hit_count,
|
||||||
|
chunk_hit_count,full_chunk_hit,partial_chunk_hit)
|
||||||
|
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
|
||||||
|
(
|
||||||
|
_id(), task_id, r.qid, r.question, r.answer, r.type, r.top_k,
|
||||||
|
json.dumps([{
|
||||||
|
"section_path": h.section_path,
|
||||||
|
"file_id": h.file_id,
|
||||||
|
"file_name": h.file_name,
|
||||||
|
"hit": h.hit,
|
||||||
|
"hit_at_hop": h.hit_at_hop,
|
||||||
|
"contribution": h.contribution,
|
||||||
|
"expected_chunk_id": h.expected_chunk_id,
|
||||||
|
"chunk_hit": h.chunk_hit,
|
||||||
|
"chunk_hit_at_hop": h.chunk_hit_at_hop,
|
||||||
|
} for h in r.hop_results], ensure_ascii=False),
|
||||||
|
json.dumps([{
|
||||||
|
"hop_index": ah.hop_index,
|
||||||
|
"query": ah.query,
|
||||||
|
"retrieved": ah.retrieved,
|
||||||
|
} for ah in r.actual_hops], ensure_ascii=False),
|
||||||
|
json.dumps(r.retrieved, ensure_ascii=False),
|
||||||
|
r.agent_answer or "",
|
||||||
|
r.latency_ms, r.error, r.best_cosine_sim,
|
||||||
|
int(r.full_hit), int(r.partial_hit),
|
||||||
|
r.hop_count, r.hop_hit_count,
|
||||||
|
r.chunk_hit_count,
|
||||||
|
int(r.full_chunk_hit), int(r.partial_chunk_hit),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await db2.execute(
|
||||||
|
"UPDATE multi_hop_task SET progress=? WHERE id=?", (progress, task_id)
|
||||||
|
)
|
||||||
|
await db2.commit()
|
||||||
|
|
||||||
|
async def on_result(r, done, _total):
|
||||||
|
write_buf.append(r)
|
||||||
|
if len(write_buf) >= FLUSH_SIZE or done == _total:
|
||||||
|
buf = write_buf[:]
|
||||||
|
write_buf.clear()
|
||||||
|
await flush_buf(buf, done)
|
||||||
|
|
||||||
|
await tester.run(
|
||||||
|
qa_pairs, file_map,
|
||||||
|
top_k=top_k, concurrency=concurrency,
|
||||||
|
result_cb=on_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
if write_buf:
|
||||||
|
await flush_buf(write_buf, total)
|
||||||
|
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE multi_hop_task SET status='done', finished_at=? WHERE id=?",
|
||||||
|
(_now(), task_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE multi_hop_task SET status='failed', error_message=? WHERE id=?",
|
||||||
|
(str(exc), task_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
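The `_run_task` function above batches result writes: `on_result` appends to a buffer and flushes once `FLUSH_SIZE` results have accumulated or the last result has arrived. Below is a minimal standalone sketch of that buffering pattern, with the database write replaced by an in-memory list; `run_with_buffer` and the inner `flush` are hypothetical names, and the sequential driver loop stands in for the tester's callback mechanism, which is not shown in this diff.

```python
import asyncio

FLUSH_SIZE = 3  # the code above uses 20; kept small here for illustration


async def run_with_buffer(results: list, flush_size: int = FLUSH_SIZE) -> list:
    """Feed results through a buffered callback and return the flushed batches."""
    buf: list = []
    batches: list = []

    async def flush(batch: list, progress: int) -> None:
        await asyncio.sleep(0)  # stand-in for the INSERT + progress UPDATE
        batches.append((list(batch), progress))

    async def on_result(r, done: int, total: int) -> None:
        buf.append(r)
        if len(buf) >= flush_size or done == total:
            # Snapshot and clear before awaiting, so results appended while
            # the write is in flight land in the next batch, not this one.
            batch = buf[:]
            buf.clear()
            await flush(batch, done)

    total = len(results)
    for i, r in enumerate(results, start=1):
        await on_result(r, i, total)
    return batches
```

With seven results and a flush size of three, the sketch produces two full batches and one final partial batch triggered by `done == total`.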
919
server/api/multi_hop_gen.py
Normal file
@ -0,0 +1,919 @@
|
|||||||
|
"""
|
||||||
|
Multi-hop QA generation API
|
||||||
|
|
||||||
|
Two data sources are supported:
|
||||||
|
1. Upload a knowledge base MD file (same format as qa_gen)
|
||||||
|
2. Pull paragraphs from the remote Dagent database and group them by file to generate cross-file multi-hop QA pairs
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
from models.db import get_db, _now, _id
|
||||||
|
from api.qa_gen_dagent import get_dagent_conn, _fetch_paragraphs
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/multi-hop-gen", tags=["多跳问答生成"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task CRUD ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/task")
|
||||||
|
async def create_task(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
name: str = Form(""),
|
||||||
|
judge_config_id: str = Form(...),
|
||||||
|
hops_per_question: int = Form(2),
|
||||||
|
questions_per_group: int = Form(3),
|
||||||
|
quality_threshold: float = Form(0.6),
|
||||||
|
prompt_template_id: str = Form(""),
|
||||||
|
):
|
||||||
|
content = await file.read()
|
||||||
|
md_text = content.decode("utf-8")
|
||||||
|
|
||||||
|
task_id = _id()
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"""INSERT INTO multi_hop_gen_task
|
||||||
|
(id,name,source,judge_config_id,hops_per_question,questions_per_group,
|
||||||
|
quality_threshold,prompt_template_id,status,created_at)
|
||||||
|
VALUES (?,?,?,?,?,?,?,?,?,?)""",
|
||||||
|
(task_id, name or file.filename, "file", judge_config_id,
|
||||||
|
hops_per_question, questions_per_group, quality_threshold,
|
||||||
|
prompt_template_id or None, "pending", _now()),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
asyncio.create_task(_run_task(task_id, md_text))
|
||||||
|
return {"status": 0, "data": {"id": task_id}}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Dagent data source endpoints ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/dagent/stats")
|
||||||
|
async def get_dagent_stats(org_id: str, env_url: str = ""):
|
||||||
|
"""Get Dagent knowledge base statistics (via the HTTP API)."""
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
base_url = (env_url or "https://dagent.d-robotics.cc").rstrip("/")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"org-id": org_id,
|
||||||
|
"d-user-id": "test",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(headers=headers) as session:
|
||||||
|
page = 1
|
||||||
|
page_size = 100
|
||||||
|
total_files = 0
|
||||||
|
total_paragraphs = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
async with session.post(
|
||||||
|
f"{base_url}/dagent/knowledge/file/page",
|
||||||
|
json={"current": page, "page_size": page_size, "org_id": org_id},
|
||||||
|
timeout=aiohttp.ClientTimeout(total=15),
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
break
|
||||||
|
data = await resp.json()
|
||||||
|
files = data.get("data", {}).get("list", [])
|
||||||
|
if not files:
|
||||||
|
break
|
||||||
|
|
||||||
|
total_files += len(files)
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
try:
|
||||||
|
async with session.post(
|
||||||
|
f"{base_url}/dagent/knowledge/chunk/page",
|
||||||
|
json={"file_id": f["id"], "org_id": org_id, "page": 1, "page_size": 1},
|
||||||
|
timeout=aiohttp.ClientTimeout(total=10),
|
||||||
|
) as cr:
|
||||||
|
if cr.status == 200:
|
||||||
|
cd = await cr.json()
|
||||||
|
total_paragraphs += cd.get("data", {}).get("total", 0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if len(files) < page_size:
|
||||||
|
break
|
||||||
|
page += 1
|
||||||
|
|
||||||
|
return {"status": 0, "data": {
|
||||||
|
"file_count": total_files,
|
||||||
|
"paragraph_count": total_paragraphs,
|
||||||
|
"total_images": 0,
|
||||||
|
"paragraphs_with_pic_text": 0,
|
||||||
|
}}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[get_dagent_stats] Error: {e}")
|
||||||
|
return {"status": 0, "data": {}}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/dagent/files")
async def list_dagent_files(org_id: str, env_url: str = ""):
    """List the files of an organization in Dagent (via the HTTP API).

    No status filter is applied here; each entry carries its
    `file_clean_status` in lowercase so the caller can filter.
    """
    import aiohttp

    base_url = (env_url or "https://dagent.d-robotics.cc").rstrip("/")

    headers = {
        "Content-Type": "application/json",
        "org-id": org_id,
        "d-user-id": "test",
    }

    all_files = []

    try:
        async with aiohttp.ClientSession(headers=headers) as session:
            page = 1
            page_size = 100

            while True:
                async with session.post(
                    f"{base_url}/dagent/knowledge/file/page",
                    json={"current": page, "page_size": page_size, "org_id": org_id},
                    timeout=aiohttp.ClientTimeout(total=15),
                ) as resp:
                    if resp.status != 200:
                        break
                    data = await resp.json()
                    files = data.get("data", {}).get("list", [])
                    if not files:
                        break

                for f in files:
                    all_files.append({
                        "id": f.get("id"),
                        "file_name": f.get("file_name"),
                        "file_type": f.get("file_type"),
                        # `or ""` also covers an explicit null, which would break .lower().
                        "file_clean_status": (f.get("file_clean_status") or "").lower(),
                        "file_bytes": f.get("file_bytes", 0),
                        "create_time": f.get("create_time"),
                    })

                if len(files) < page_size:
                    break
                page += 1

        return {"status": 0, "data": all_files}
    except Exception as e:
        print(f"[list_dagent_files] Error: {e}")
        return {"status": 0, "data": []}


@router.post("/task/from-dagent")
async def create_task_from_dagent(
    org_id: str = Form(...),
    env_url: str = Form(""),
    name: str = Form(""),
    judge_config_id: str = Form(...),
    file_ids: str = Form(""),
    hops_per_question: int = Form(2),
    questions_per_group: int = Form(3),
    quality_threshold: float = Form(0.6),
    prompt_template_id: str = Form(""),
):
    """Create a multi-hop QA generation task from a Dagent knowledge base."""
    task_id = _id()
    # file_ids arrives as a comma-separated string; blank entries are dropped.
    file_id_list = [f.strip() for f in file_ids.split(",") if f.strip()]

    async with get_db() as db:
        await db.execute(
            """INSERT INTO multi_hop_gen_task
               (id,name,source,judge_config_id,org_id,file_ids,
                hops_per_question,questions_per_group,quality_threshold,
                prompt_template_id,status,created_at)
               VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""",
            (task_id, name or f"Dagent多跳({org_id[:8]}...)", "dagent",
             judge_config_id, org_id, file_ids,
             hops_per_question, questions_per_group, quality_threshold,
             prompt_template_id or None, "pending", _now()),
        )
        await db.commit()

    # Fire and forget: generation runs in the background and reports its
    # progress through the multi_hop_gen_task row.
    asyncio.create_task(_run_dagent_task(task_id, org_id, file_id_list, env_url))
    return {"status": 0, "data": {"id": task_id}}


@router.get("/task/list")
async def list_tasks():
    """List all generation tasks, newest first."""
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM multi_hop_gen_task ORDER BY created_at DESC"
        )
    return {"status": 0, "data": [dict(r) for r in rows]}


@router.get("/task/{task_id}")
async def get_task(task_id: str):
    """Return one generation task, or 404 if it does not exist."""
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM multi_hop_gen_task WHERE id=?", (task_id,)
        )
    if not rows:
        raise HTTPException(status_code=404, detail="Task not found")
    return {"status": 0, "data": dict(rows[0])}


@router.delete("/task/{task_id}")
async def delete_task(task_id: str):
    """Delete a task together with its generated questions."""
    async with get_db() as db:
        await db.execute("DELETE FROM multi_hop_gen_question WHERE task_id=?", (task_id,))
        await db.execute("DELETE FROM multi_hop_gen_task WHERE id=?", (task_id,))
        await db.commit()
    return {"status": 0}


# ── Question list ─────────────────────────────────────────────────────────────

@router.get("/task/{task_id}/questions")
async def list_questions(
    task_id: str,
    status: Optional[str] = None,
    page: int = 1,
    page_size: int = 50,
):
    """Return one page of a task's questions, optionally filtered by status."""
    # The WHERE clause is assembled from fixed fragments only; values go through params.
    conditions = ["task_id=?"]
    params: list = [task_id]
    if status:
        conditions.append("status=?")
        params.append(status)
    where = " AND ".join(conditions)
    offset = (page - 1) * page_size

    async with get_db() as db:
        count_rows = await db.execute_fetchall(
            f"SELECT COUNT(*) as cnt FROM multi_hop_gen_question WHERE {where}", params
        )
        total = dict(count_rows[0])["cnt"]
        rows = await db.execute_fetchall(
            f"""SELECT * FROM multi_hop_gen_question WHERE {where}
                ORDER BY created_at LIMIT ? OFFSET ?""",
            params + [page_size, offset],
        )

    items = []
    for r in rows:
        d = dict(r)
        # hops and source_sections are stored as JSON text.
        d["hops"] = json.loads(d.get("hops") or "[]")
        d["source_sections"] = json.loads(d.get("source_sections") or "[]")
        items.append(d)

    return {"status": 0, "data": {"total": total, "items": items}}


# ── Review actions ────────────────────────────────────────────────────────────

@router.post("/question/{question_id}/approve")
async def approve_question(question_id: str):
    """Mark a question as approved and refresh the task's approved count."""
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT task_id FROM multi_hop_gen_question WHERE id=?", (question_id,)
        )
        if not rows:
            raise HTTPException(status_code=404)
        task_id = dict(rows[0])["task_id"]
        await db.execute(
            "UPDATE multi_hop_gen_question SET status='approved', updated_at=? WHERE id=?",
            (_now(), question_id),
        )
        await _sync_approved(db, task_id)
        await db.commit()
    return {"status": 0}


@router.post("/question/{question_id}/reject")
async def reject_question(question_id: str):
    """Mark a question as rejected and refresh the task's approved count."""
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT task_id FROM multi_hop_gen_question WHERE id=?", (question_id,)
        )
        if not rows:
            raise HTTPException(status_code=404)
        task_id = dict(rows[0])["task_id"]
        await db.execute(
            "UPDATE multi_hop_gen_question SET status='rejected', updated_at=? WHERE id=?",
            (_now(), question_id),
        )
        await _sync_approved(db, task_id)
        await db.commit()
    return {"status": 0}


class QuestionEditReq(BaseModel):
    question: Optional[str] = None
    answer: Optional[str] = None
    type: Optional[str] = None


@router.put("/question/{question_id}")
async def edit_question(question_id: str, req: QuestionEditReq):
    """Update a question's text, answer, or type; an edited question is marked approved."""
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT task_id FROM multi_hop_gen_question WHERE id=?", (question_id,)
        )
        if not rows:
            raise HTTPException(status_code=404)
        task_id = dict(rows[0])["task_id"]
        # Only the fields present in the request are written.
        updates, params = [], []
        if req.question is not None:
            updates.append("question=?")
            params.append(req.question)
        if req.answer is not None:
            updates.append("answer=?")
            params.append(req.answer)
        if req.type is not None:
            updates.append("type=?")
            params.append(req.type)
        if not updates:
            raise HTTPException(status_code=400, detail="No fields to update")
        updates += ["status='approved'", "updated_at=?"]
        params += [_now(), question_id]
        await db.execute(
            f"UPDATE multi_hop_gen_question SET {', '.join(updates)} WHERE id=?", params
        )
        await _sync_approved(db, task_id)
        await db.commit()
    return {"status": 0}


@router.post("/task/{task_id}/batch-approve")
async def batch_approve(task_id: str, min_quality: float = 0.0):
    """Approve every pending question whose quality_score is NULL or >= min_quality."""
    async with get_db() as db:
        await db.execute(
            """UPDATE multi_hop_gen_question SET status='approved', updated_at=?
               WHERE task_id=? AND status='pending'
               AND (quality_score IS NULL OR quality_score >= ?)""",
            (_now(), task_id, min_quality),
        )
        await _sync_approved(db, task_id)
        await db.commit()
    return {"status": 0}


# ── Export MD ─────────────────────────────────────────────────────────────────

@router.get("/task/{task_id}/export-md")
async def export_md(task_id: str):
    """Export approved multi-hop QA pairs in the standard MD format (directly usable for multi-hop recall tests)."""
    async with get_db() as db:
        task_rows = await db.execute_fetchall(
            "SELECT name FROM multi_hop_gen_task WHERE id=?", (task_id,)
        )
        if not task_rows:
            raise HTTPException(status_code=404)
        task_name = dict(task_rows[0]).get("name", task_id)

        rows = await db.execute_fetchall(
            """SELECT * FROM multi_hop_gen_question
               WHERE task_id=? AND status='approved'
               ORDER BY created_at""",
            (task_id,),
        )
        # Convert rows to dicts while connection is still open
        row_dicts = [dict(r) for r in rows]

    if not row_dicts:
        raise HTTPException(status_code=404, detail="没有已通过的问题")

    # One "## <qid>" block per question; the Chinese field labels are part of
    # the MD format and are therefore left as they are.
    lines = []
    for i, d in enumerate(row_dicts, 1):
        hops = json.loads(d.get("hops") or "[]")
        qid = d.get("qid") or f"MH{i}"
        lines.append(f"## {qid}")
        lines.append(f"**类型:** {d.get('type', 'reasoning')}")
        lines.append(f"**问题:** {d['question']}")
        lines.append(f"**答案:** {d['answer']}")
        for j, hop in enumerate(hops, 1):
            section = hop.get("section_path", "")
            contrib = hop.get("contribution", "")
            chunk_id = hop.get("chunk_id") or hop.get("paragraph_chunk_id") or ""
            if chunk_id:
                lines.append(f"**Hop{j}:** {section} | {contrib} | {chunk_id}")
            else:
                lines.append(f"**Hop{j}:** {section} | {contrib}")
        lines.append("---")
        lines.append("")

    md_content = "\n".join(lines)
    # Percent-encode the file name so non-ASCII task names survive in the header.
    filename_encoded = quote(f"multi_hop_{task_name}.md".replace(" ", "_"))
    return StreamingResponse(
        iter([md_content.encode("utf-8")]),
        media_type="text/markdown",
        headers={"Content-Disposition": f"attachment; filename*=UTF-8''{filename_encoded}"},
    )


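`export_md` builds its download header with the `filename*=UTF-8''...` form, which carries a percent-encoded UTF-8 file name. A small sketch of that step in isolation follows; `content_disposition` is a hypothetical helper introduced for illustration, and the sample file name is made up.

```python
from urllib.parse import quote


def content_disposition(filename: str) -> str:
    """Build an attachment header value with a percent-encoded UTF-8 name."""
    # Spaces become underscores first, mirroring the endpoint above.
    encoded = quote(filename.replace(" ", "_"))
    return f"attachment; filename*=UTF-8''{encoded}"


header = content_disposition("multi hop 测试.md")
# → attachment; filename*=UTF-8''multi_hop_%E6%B5%8B%E8%AF%95.md
```

Note that `quote` leaves `/` unescaped by default; passing `safe=""` would also encode a slash if task names may contain one.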
class CreateTestReq(BaseModel):
    env_url: str
    org_id: str
    agent_id: str
    llm_type: str = "deepseek_v3"
    d_user_id: str = "test"
    top_k: int = 10
    concurrency: int = 5
    name: str = ""


@router.post("/task/{task_id}/create-test")
async def create_test_from_gen(task_id: str, req: CreateTestReq):
    """Create a recall test task directly from the approved multi-hop QA pairs."""
    async with get_db() as db:
        task_rows = await db.execute_fetchall(
            "SELECT name FROM multi_hop_gen_task WHERE id=?", (task_id,)
        )
        if not task_rows:
            raise HTTPException(status_code=404, detail="生成任务不存在")
        task_name = dict(task_rows[0]).get("name", task_id)

        rows = await db.execute_fetchall(
            """SELECT * FROM multi_hop_gen_question
               WHERE task_id=? AND status='approved'
               ORDER BY created_at""",
            (task_id,),
        )
        # Convert rows to dicts while connection is still open
        row_dicts = [dict(r) for r in rows]

    if not row_dicts:
        raise HTTPException(status_code=400, detail="没有已通过的问题,请先审核通过至少一个问题")

    # Build the MD content (same layout as export_md)
    lines = []
    for i, d in enumerate(row_dicts, 1):
        hops = json.loads(d.get("hops") or "[]")
        qid = d.get("qid") or f"MH{i}"
        lines.append(f"## {qid}")
        lines.append(f"**类型:** {d.get('type', 'reasoning')}")
        lines.append(f"**问题:** {d['question']}")
        lines.append(f"**答案:** {d['answer']}")
        for j, hop in enumerate(hops, 1):
            section = hop.get("section_path", "")
            contrib = hop.get("contribution", "")
            chunk_id = hop.get("chunk_id") or hop.get("paragraph_chunk_id") or ""
            if chunk_id:
                lines.append(f"**Hop{j}:** {section} | {contrib} | {chunk_id}")
            else:
                lines.append(f"**Hop{j}:** {section} | {contrib}")
        lines.append("---")
        lines.append("")
    md_content = "\n".join(lines)

    # Insert directly into multi_hop_task and start the background job
    from api.multi_hop import _run_task as _run_test_task
    test_name = req.name or f"{task_name}-召回测试"
    test_task_id = _id()

    async with get_db() as db:
        await db.execute(
            """INSERT INTO multi_hop_task
               (id,name,env_url,org_id,d_user_id,agent_id,llm_type,top_k,concurrency,status,created_at)
               VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
            (test_task_id, test_name, req.env_url, req.org_id,
             req.d_user_id, req.agent_id, req.llm_type,
             req.top_k, req.concurrency, "pending", _now()),
        )
        await db.commit()

    asyncio.create_task(_run_test_task(
        test_task_id, md_content, req.env_url, req.org_id,
        req.d_user_id, req.agent_id, req.llm_type,
        req.top_k, req.concurrency,
    ))

    return {"status": 0, "data": {"test_task_id": test_task_id, "question_count": len(row_dicts)}}


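`export_md` and `create_test_from_gen` each carry their own copy of the loop that renders approved questions as Markdown. One possible consolidation is a single pure function that both endpoints call. The sketch below assumes rows are plain dicts with `hops` stored as JSON text, as in the endpoints above; `render_questions_md` is a hypothetical name, not part of this commit.

```python
import json


def render_questions_md(rows: list[dict]) -> str:
    """Render approved question rows into the MD layout used by the recall test."""
    lines: list[str] = []
    for i, d in enumerate(rows, 1):
        hops = json.loads(d.get("hops") or "[]")
        qid = d.get("qid") or f"MH{i}"
        lines.append(f"## {qid}")
        lines.append(f"**类型:** {d.get('type', 'reasoning')}")
        lines.append(f"**问题:** {d['question']}")
        lines.append(f"**答案:** {d['answer']}")
        for j, hop in enumerate(hops, 1):
            section = hop.get("section_path", "")
            contrib = hop.get("contribution", "")
            chunk_id = hop.get("chunk_id") or hop.get("paragraph_chunk_id") or ""
            # The chunk id column is only emitted when one is present.
            parts = [section, contrib] + ([chunk_id] if chunk_id else [])
            lines.append(f"**Hop{j}:** " + " | ".join(parts))
        lines.append("---")
        lines.append("")
    return "\n".join(lines)
```

Keeping the renderer free of database and HTTP concerns also makes the output format easy to unit-test.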
# ── Internal: run a generation task ───────────────────────────────────────────

async def _run_task(task_id: str, md_text: str):
    """Generate multi-hop questions from an uploaded Markdown document."""
    try:
        # Load the task configuration
        async with get_db() as db:
            cfg_rows = await db.execute_fetchall(
                "SELECT t.*, j.base_url, j.api_key, j.model "
                "FROM multi_hop_gen_task t JOIN judge_config j ON t.judge_config_id=j.id "
                "WHERE t.id=?",
                (task_id,),
            )
            if not cfg_rows:
                raise ValueError("judge_config not found")
            cfg = dict(cfg_rows[0])
        hops_per_question = cfg["hops_per_question"]
        questions_per_group = cfg["questions_per_group"]
        quality_threshold = cfg["quality_threshold"]
        requirements = await _load_requirements(cfg.get("prompt_template_id"))

        # Split the document into sections
        sections = _parse_knowledge_md(md_text)
        if len(sections) < hops_per_question:
            raise ValueError(f"文档章节数({len(sections)})少于 hops_per_question({hops_per_question}),无法生成多跳问题")

        # Group the sections, hops_per_question per group (sliding window, then random fill)
        groups = _make_groups(sections, hops_per_question)
        total = len(groups)

        async with get_db() as db:
            await db.execute(
                "UPDATE multi_hop_gen_task SET status='running', total=? WHERE id=?",
                (total, task_id),
            )
            await db.commit()

        # At most 3 groups are processed concurrently.
        sem = asyncio.Semaphore(3)
        done = 0
        question_counter = [0]

        async def gen_group(group: list[tuple[str, str]]):
            nonlocal done
            async with sem:
                questions = await _generate_multi_hop_questions(
                    cfg=cfg,
                    sections=group,
                    n=questions_per_group,
                    hops=hops_per_question,
                    requirements=requirements,
                )
                async with get_db() as db2:
                    for q in questions:
                        question_counter[0] += 1
                        qid = f"MH{question_counter[0]}"
                        quality_score = q.get("quality_score", 0.8)
                        # Questions at or above the threshold are approved automatically.
                        status = "approved" if quality_score >= quality_threshold else "pending"
                        source_sections = [s for s, _ in group]
                        await db2.execute(
                            """INSERT INTO multi_hop_gen_question
                               (id,task_id,qid,question,answer,type,hops,source_sections,
                                quality_score,status,created_at)
                               VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
                            (
                                _id(), task_id, qid,
                                q["question"], q["answer"], q.get("type", "reasoning"),
                                json.dumps(q.get("hops", []), ensure_ascii=False),
                                json.dumps(source_sections, ensure_ascii=False),
                                quality_score, status, _now(),
                            ),
                        )
                    done += 1
                    await db2.execute(
                        "UPDATE multi_hop_gen_task SET progress=? WHERE id=?", (done, task_id)
                    )
                    await _sync_approved(db2, task_id)
                    await db2.commit()

        await asyncio.gather(*[gen_group(g) for g in groups])

        async with get_db() as db:
            await db.execute(
                "UPDATE multi_hop_gen_task SET status='done', finished_at=? WHERE id=?",
                (_now(), task_id),
            )
            await db.commit()

    except Exception as exc:
        # Any failure marks the task as failed and records the message.
        async with get_db() as db:
            await db.execute(
                "UPDATE multi_hop_gen_task SET status='failed', error_message=? WHERE id=?",
                (str(exc), task_id),
            )
            await db.commit()


def _parse_knowledge_md(md_text: str) -> list[tuple[str, str]]:
    """Split the document on Markdown headings (levels 1 to 4) and return a list of (section_path, content).

    section_path joins the enclosing heading titles with "/". Text that
    appears before the first heading, and sections with empty bodies, are
    not emitted.
    """
    lines = md_text.splitlines()
    sections: list[tuple[str, str]] = []
    current_path: list[str] = []
    current_lines: list[str] = []
    current_level = 0

    for line in lines:
        m = re.match(r'^(#{1,4})\s+(.+)', line)
        if m:
            # A new heading closes the section collected so far.
            if current_path and current_lines:
                content = "\n".join(current_lines).strip()
                if content:
                    sections.append(("/".join(current_path), content))
            level = len(m.group(1))
            title = m.group(2).strip()
            if level > current_level:
                # Deeper heading: extend the current path.
                current_path.append(title)
            else:
                # Same or shallower heading: cut back to the parent, then add the title.
                current_path = current_path[:level - 1] + [title]
            current_level = level
            current_lines = []
        else:
            current_lines.append(line)

    # Flush the last section.
    if current_path and current_lines:
        content = "\n".join(current_lines).strip()
        if content:
            sections.append(("/".join(current_path), content))

    return sections


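The heading-path bookkeeping in `_parse_knowledge_md` is easiest to follow on a small input. The sketch below isolates just that part and returns the path produced at each heading. `heading_paths` is a hypothetical helper written for illustration; it uses a single slice for every heading, which should yield the same paths as the two-branch form above because the path is never longer than the current heading level.

```python
import re


def heading_paths(md_text: str) -> list[str]:
    """Return the slash-joined heading path seen at each Markdown heading."""
    path: list[str] = []
    out: list[str] = []
    for line in md_text.splitlines():
        m = re.match(r'^(#{1,4})\s+(.+)', line)
        if not m:
            continue
        level = len(m.group(1))
        # Keep the ancestors above this level, then add the new title.
        path = path[:level - 1] + [m.group(2).strip()]
        out.append("/".join(path))
    return out


doc = "# Guide\nintro\n## Install\nsteps\n### Linux\napt\n## Usage\nrun"
paths = heading_paths(doc)
# → ['Guide', 'Guide/Install', 'Guide/Install/Linux', 'Guide/Usage']
```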
def _make_groups(
    sections: list[tuple[str, str]],
    hops: int,
) -> list[list[tuple[str, str]]]:
    """
    Combine the section list into multi-hop groups.

    Strategy: first a sliding window over adjacent sections, then random
    sampling of `hops` distinct sections. Random filling stops once
    min(len * 2, 60) groups exist; the sliding-window pass itself is not capped.
    """
    import random
    n = len(sections)
    max_groups = min(n * 2, 60)
    groups = []
    # Groups are deduplicated by their set of section paths.
    seen: set[frozenset] = set()

    # Sliding window first (adjacent sections are more likely to be related)
    for i in range(n - hops + 1):
        group = sections[i:i + hops]
        key = frozenset(s for s, _ in group)
        if key not in seen:
            seen.add(key)
            groups.append(group)

    # Then fill up with random combinations
    attempts = 0
    while len(groups) < max_groups and attempts < max_groups * 3:
        attempts += 1
        idxs = random.sample(range(n), min(hops, n))
        group = [sections[i] for i in sorted(idxs)]
        key = frozenset(s for s, _ in group)
        if key not in seen:
            seen.add(key)
            groups.append(group)

    return groups


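Because `_make_groups` draws from the global `random` module, two runs over the same document can produce different groups. If reproducible grouping is wanted (for tests or for re-running a task), one option is to pass in a seeded generator. The sketch below restates the same two-phase strategy with that change; `make_groups_seeded`, its `seed` and `cap` parameters are hypothetical and not part of this commit.

```python
import random


def make_groups_seeded(sections, hops, seed=0, cap=60):
    """Sliding-window groups first, then seeded random fill up to min(n * 2, cap)."""
    rng = random.Random(seed)  # private generator, independent of global state
    n = len(sections)
    max_groups = min(n * 2, cap)
    groups, seen = [], set()

    def add(group):
        # Deduplicate by the set of section paths in the group.
        key = frozenset(path for path, _ in group)
        if key not in seen:
            seen.add(key)
            groups.append(group)

    for i in range(n - hops + 1):
        add(sections[i:i + hops])

    attempts = 0
    while len(groups) < max_groups and attempts < max_groups * 3:
        attempts += 1
        idxs = rng.sample(range(n), min(hops, n))
        add([sections[i] for i in sorted(idxs)])
    return groups
```

Calling it twice with the same `seed` returns identical groups, while the window groups at the front do not depend on the seed at all.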
async def _load_requirements(prompt_template_id: str | None) -> str:
    """Load the prompt template content from the database; fall back to the built-in default when no template is given or found."""
    from api.prompt_template import DEFAULT_CONTENT
    if not prompt_template_id:
        return DEFAULT_CONTENT
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT content FROM prompt_template WHERE id=?", (prompt_template_id,)
        )
    if rows:
        return dict(rows[0])["content"]
    return DEFAULT_CONTENT


async def _generate_multi_hop_questions(
    cfg: dict,
    sections: list[tuple[str, str]],
    n: int,
    hops: int,
    requirements: str = "",
) -> list[dict]:
    """Call the LLM to generate multi-hop QA pairs."""
    import aiohttp

    base_url = cfg.get("base_url", "").rstrip("/")
    api_key = cfg.get("api_key", "")
    model = cfg.get("model", "gpt-4o-mini")

    # Build the section descriptions (each section is capped at 1500 characters)
    section_blocks = []
    for i, (path, content) in enumerate(sections, 1):
        truncated = content[:1500]
        section_blocks.append(f"【章节{i}】路径:{path}\n{truncated}")
    sections_text = "\n\n".join(section_blocks)

    # The prompt is sent in Chinese; its text is part of the runtime behavior.
    prompt = f"""你是一个专业的技术文档多跳问答生成专家。

以下是来自同一知识库的 {hops} 个不同章节,请生成 {n} 个需要同时参考这 {hops} 个章节才能完整回答的多跳问题。

{sections_text}

要求:
{requirements}

只输出 JSON 数组,不要有其他内容:
[
  {{
    "question": "问题文本",
    "answer": "综合多个章节的参考答案",
    "type": "comparison",
    "quality_score": 0.85,
    "hops": [
      {{"section_path": "{sections[0][0] if sections else ''}", "contribution": "该章节提供了..."}},
      {{"section_path": "{sections[1][0] if len(sections) > 1 else ''}", "contribution": "该章节提供了..."}}
    ]
  }}
]"""

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.4,
    }

    try:
        async with aiohttp.ClientSession(headers=headers) as session:
            async with session.post(
                f"{base_url}/chat/completions",
                json=payload,
                timeout=aiohttp.ClientTimeout(total=90),
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()

        text = data["choices"][0]["message"]["content"].strip()
        # Take the outermost [...] span so that surrounding prose or code fences are ignored.
        m = re.search(r'\[.*\]', text, re.DOTALL)
        if not m:
            return []
        questions = json.loads(m.group())
        result = []
        for q in questions:
            if not isinstance(q, dict):
                continue
            if not q.get("question") or not q.get("answer"):
                continue
            hops_data = q.get("hops", [])
            # Validate the hop count: at least two hops are required
            if len(hops_data) < 2:
                continue
            result.append({
                "question": str(q["question"]).strip(),
                "answer": str(q["answer"]).strip(),
                "type": str(q.get("type", "reasoning")).strip(),
                "quality_score": float(q.get("quality_score", 0.8)),
                "hops": [
                    {
                        "section_path": str(h.get("section_path", "")).strip(),
                        "contribution": str(h.get("contribution", "")).strip(),
                    }
                    for h in hops_data if isinstance(h, dict)
                ],
            })
        return result
    except Exception:
        # Any request or parsing failure yields no questions for this group.
        return []


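`_generate_multi_hop_questions` recovers the JSON array from the model reply with a greedy `\[.*\]` search, which tolerates prose or a code fence around the array. A minimal sketch of that extraction step on its own follows, with the parse error handled locally instead of by the outer `except`. `extract_json_array` is a hypothetical helper, not part of this commit.

```python
import json
import re


def extract_json_array(text: str) -> list:
    """Return the first-to-last bracketed span of `text` parsed as a JSON list, else []."""
    m = re.search(r'\[.*\]', text, re.DOTALL)
    if not m:
        return []
    try:
        parsed = json.loads(m.group())
    except json.JSONDecodeError:
        # Malformed JSON is treated the same as "no array found".
        return []
    return parsed if isinstance(parsed, list) else []
```

Since the pattern is greedy, a reply containing two separate arrays is matched from the first `[` to the last `]` and will usually fail to parse, in which case this sketch returns an empty list.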
async def _sync_approved(db, task_id: str):
    """Recount the task's approved questions and store the result on the task row.

    The caller is responsible for committing.
    """
    rows = await db.execute_fetchall(
        "SELECT COUNT(*) as cnt FROM multi_hop_gen_question WHERE task_id=? AND status='approved'",
        (task_id,),
    )
    approved = dict(rows[0])["cnt"] if rows else 0
    await db.execute(
        "UPDATE multi_hop_gen_task SET approved=? WHERE id=?", (approved, task_id)
    )


async def _run_dagent_task(task_id: str, org_id: str, file_id_list: list[str], env_url: str = ""):
    """
    Pull paragraphs from Dagent, group them by file, and generate cross-file multi-hop QA pairs.

    Grouping strategy:
    - Aggregate paragraphs by file_name into file-level sections
    - Build each group from hops_per_question sections of different files (via _make_groups)
    - Call the LLM to generate cross-file multi-hop questions
    """
    try:
        # Load the task configuration
        async with get_db() as db:
            cfg_rows = await db.execute_fetchall(
                "SELECT t.*, j.base_url, j.api_key, j.model "
                "FROM multi_hop_gen_task t JOIN judge_config j ON t.judge_config_id=j.id "
                "WHERE t.id=?",
                (task_id,),
            )
            if not cfg_rows:
                raise ValueError("judge_config not found")
            cfg = dict(cfg_rows[0])
        hops_per_question = cfg["hops_per_question"]
        questions_per_group = cfg["questions_per_group"]
        quality_threshold = cfg["quality_threshold"]
        requirements = await _load_requirements(cfg.get("prompt_template_id"))

        # 1. Pull paragraphs from Dagent
        paragraphs = await _fetch_paragraphs(org_id, file_id_list, env_url)
        if not paragraphs:
            raise ValueError("未获取到任何段落,请检查 org_id 和文件选择")

        # 2. Aggregate paragraphs by file -> file_sections: {file_name: [(section_path, content), ...]}
        from collections import defaultdict
        file_sections: dict[str, list[tuple[str, str]]] = defaultdict(list)
        for para in paragraphs:
            file_name = para.get("file_name") or para.get("file_id", "unknown")
            headers = (para.get("headers") or "").strip()
            text = (para.get("paragraph_context") or "").strip()
            pic = (para.get("paragraph_pic_semantics_context") or "").strip()
            if not text:
                continue
            content = text
            if pic:
                # Append at most 500 characters of the image description.
                content += f"\n\n[图片描述] {pic[:500]}"
            section_path = f"{file_name}/{headers}" if headers else file_name
            file_sections[file_name].append((section_path, content[:2000]))

        # Pick one representative paragraph per file
        file_repr: dict[str, tuple[str, str]] = {}
        for fname, secs in file_sections.items():
            # The paragraph with the longest content represents the file
            best = max(secs, key=lambda x: len(x[1]))
            file_repr[fname] = best

        file_names = list(file_repr.keys())
        if len(file_names) < hops_per_question:
            raise ValueError(
                f"文件数({len(file_names)})少于 hops_per_question({hops_per_question}),"
                "请减少 Hop 数或选择更多文件"
            )

        # 3. Build the cross-file groups
        sections_flat = list(file_repr.values())  # [(section_path, content), ...]
        groups = _make_groups(sections_flat, hops_per_question)
        total = len(groups)

        async with get_db() as db:
            await db.execute(
                "UPDATE multi_hop_gen_task SET status='running', total=? WHERE id=?",
                (total, task_id),
            )
            await db.commit()

        # 4. Generate concurrently (at most 3 groups at a time)
        sem = asyncio.Semaphore(3)
        done = 0
        question_counter = [0]

        async def gen_group(group: list[tuple[str, str]]):
            nonlocal done
            async with sem:
                questions = await _generate_multi_hop_questions(
                    cfg=cfg,
                    sections=group,
                    n=questions_per_group,
                    hops=hops_per_question,
                    requirements=requirements,
                )
                async with get_db() as db2:
                    for q in questions:
                        question_counter[0] += 1
                        qid = f"MH{question_counter[0]}"
                        quality_score = q.get("quality_score", 0.8)
                        # Questions at or above the threshold are approved automatically.
                        status = "approved" if quality_score >= quality_threshold else "pending"
                        source_sections = [s for s, _ in group]
                        await db2.execute(
                            """INSERT INTO multi_hop_gen_question
                               (id,task_id,qid,question,answer,type,hops,source_sections,
                                quality_score,status,created_at)
                               VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
                            (
                                _id(), task_id, qid,
                                q["question"], q["answer"], q.get("type", "reasoning"),
                                json.dumps(q.get("hops", []), ensure_ascii=False),
                                json.dumps(source_sections, ensure_ascii=False),
                                quality_score, status, _now(),
                            ),
                        )
                    done += 1
                    await db2.execute(
                        "UPDATE multi_hop_gen_task SET progress=? WHERE id=?", (done, task_id)
                    )
                    await _sync_approved(db2, task_id)
                    await db2.commit()

        await asyncio.gather(*[gen_group(g) for g in groups])

        async with get_db() as db:
            await db.execute(
                "UPDATE multi_hop_gen_task SET status='done', finished_at=? WHERE id=?",
                (_now(), task_id),
            )
            await db.commit()

    except Exception as exc:
        # Any failure marks the task as failed and records the message.
        async with get_db() as db:
            await db.execute(
                "UPDATE multi_hop_gen_task SET status='failed', error_message=? WHERE id=?",
                (str(exc), task_id),
            )
            await db.commit()
78
server/api/prompt_template.py
Normal file
@ -0,0 +1,78 @@
"""
Prompt template management API
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Optional

from models.db import get_db, _now, _id

router = APIRouter(prefix="/api/prompt-template", tags=["提示词模板"])

# Built-in default requirements, injected verbatim into the (Chinese) generation
# prompt. In short: questions must truly span several sections; allowed types are
# comparison, reasoning, and aggregation; answers must combine all sections; each
# hop states its section's contribution; quality_score is a self-assessment in 0-1.
DEFAULT_CONTENT = """1. 每个问题必须真正跨越多个章节,单独看任何一个章节都无法完整回答
2. 问题类型可以是:comparison(比较型)、reasoning(推理型)、aggregation(聚合型)
3. 答案要综合所有章节的信息,准确完整
4. 每个 hop 说明该章节对回答问题的具体贡献
5. quality_score 为你对该问题质量的评估(0-1)"""


@router.get("/default")
async def get_default():
    """Return the built-in default prompt content."""
    return {"status": 0, "data": {"content": DEFAULT_CONTENT}}


@router.get("/list")
async def list_templates():
    """List all prompt templates, newest first."""
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM prompt_template ORDER BY created_at DESC"
        )
    return {"status": 0, "data": [dict(r) for r in rows]}


class TemplateReq(BaseModel):
    name: str
    description: Optional[str] = None
    content: str


@router.post("")
async def create_template(req: TemplateReq):
    """Create a prompt template; blank content is rejected."""
    if not req.content.strip():
        raise HTTPException(status_code=400, detail="content 不能为空")
    row_id = _id()
    now = _now()
    async with get_db() as db:
        await db.execute(
            "INSERT INTO prompt_template (id,name,description,content,created_at,updated_at) VALUES (?,?,?,?,?,?)",
            (row_id, req.name, req.description, req.content, now, now),
        )
        await db.commit()
    return {"status": 0, "data": {"id": row_id}}


@router.put("/{template_id}")
|
||||||
|
async def update_template(template_id: str, req: TemplateReq):
|
||||||
|
if not req.content.strip():
|
||||||
|
raise HTTPException(status_code=400, detail="content 不能为空")
|
||||||
|
async with get_db() as db:
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
"SELECT id FROM prompt_template WHERE id=?", (template_id,)
|
||||||
|
)
|
||||||
|
if not rows:
|
||||||
|
raise HTTPException(status_code=404, detail="模板不存在")
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE prompt_template SET name=?,description=?,content=?,updated_at=? WHERE id=?",
|
||||||
|
(req.name, req.description, req.content, _now(), template_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return {"status": 0, "data": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{template_id}")
|
||||||
|
async def delete_template(template_id: str):
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute("DELETE FROM prompt_template WHERE id=?", (template_id,))
|
||||||
|
await db.commit()
|
||||||
|
return {"status": 0, "data": True}
|
||||||
606
server/api/qa_gen.py
Normal file
@ -0,0 +1,606 @@
"""
Question generation API
"""
import asyncio
import json
import re
import sys
from pathlib import Path
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional
from urllib.parse import quote

# Add parent directory to sys.path for relative imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.db import get_db, _now, _id

router = APIRouter(prefix="/api/qa-gen", tags=["问题生成"])

# Strong references to background tasks; the event loop keeps only weak ones
_background_tasks: set[asyncio.Task] = set()


# ── Task CRUD ─────────────────────────────────────────────────────────────────

@router.post("/task")
async def create_task(
    file: UploadFile = File(...),
    name: str = Form(""),
    judge_config_id: str = Form(...),
    questions_per_section: int = Form(5),
    quality_threshold: float = Form(0.6),
):
    content = await file.read()
    try:
        md_text = content.decode("utf-8")
    except UnicodeDecodeError:
        # Report a client error instead of failing with a 500
        raise HTTPException(status_code=400, detail="File must be UTF-8 encoded")

    task_id = _id()
    async with get_db() as db:
        await db.execute(
            """INSERT INTO qa_gen_task
               (id,name,judge_config_id,questions_per_section,quality_threshold,status,created_at)
               VALUES (?,?,?,?,?,?,?)""",
            (task_id, name or file.filename, judge_config_id,
             questions_per_section, quality_threshold, "pending", _now()),
        )
        await db.commit()

    # Keep a reference so the task cannot be garbage-collected mid-run
    bg_task = asyncio.create_task(_run_task(task_id, md_text))
    _background_tasks.add(bg_task)
    bg_task.add_done_callback(_background_tasks.discard)
    return {"status": 0, "data": {"id": task_id}}

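`create_task` starts the generation coroutine in the background with `asyncio.create_task`. The asyncio documentation recommends holding a reference to such tasks, because the event loop itself keeps only weak references. A minimal sketch of that pattern follows; `spawn`, `background_tasks` and `job` are hypothetical names used only for illustration, not part of this commit.

```python
import asyncio

# Strong references to in-flight tasks (hypothetical module-level registry)
background_tasks: set[asyncio.Task] = set()


def spawn(coro) -> asyncio.Task:
    """Schedule a coroutine and keep it referenced until it finishes."""
    task = asyncio.create_task(coro)
    background_tasks.add(task)
    # Drop the reference automatically once the task is done
    task.add_done_callback(background_tasks.discard)
    return task


results = []


async def job(n: int) -> None:
    await asyncio.sleep(0)
    results.append(n)


async def main() -> None:
    task = spawn(job(1))
    assert task in background_tasks
    await task
    await asyncio.sleep(0)  # give the done callback a chance to run


asyncio.run(main())
print(results, len(background_tasks))
```

The done callback removes finished tasks, so the set does not grow without bound.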
@router.get("/task/list")
async def list_tasks():
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM qa_gen_task ORDER BY created_at DESC"
        )
        return {"status": 0, "data": [dict(r) for r in rows]}


class CreateDatasetReq(BaseModel):
    name: str
    knowledge_hub_id: str = ""
    description: str = ""


@router.post("/task/{task_id}/create-dataset")
async def create_dataset_from_qa_gen(task_id: str, req: CreateDatasetReq):
    """Create an evaluation dataset from a QA generation task."""
    async with get_db() as db:
        # Check that the task exists
        rows = await db.execute_fetchall(
            "SELECT * FROM qa_gen_task WHERE id=?", (task_id,)
        )
        if not rows:
            raise HTTPException(status_code=404, detail="QA 生成任务不存在")

        # Fetch the approved questions
        question_rows = await db.execute_fetchall(
            "SELECT * FROM qa_gen_question WHERE task_id=? AND status='approved'",
            (task_id,)
        )
        if not question_rows:
            raise HTTPException(status_code=400, detail="没有已通过的问题")

        # Create the dataset
        dataset_id = _id()
        await db.execute(
            "INSERT INTO eval_dataset (id,name,description,sample_count,created_at) VALUES (?,?,?,?,?)",
            (dataset_id, req.name, req.description, len(question_rows), _now()),
        )

        # Add one sample per approved question
        for q in question_rows:
            q_dict = dict(q)
            sample_id = _id()
            await db.execute(
                """INSERT INTO eval_sample
                   (id,dataset_id,question,reference_answer,relevant_chunk_ids,knowledge_hub_id,source_file_id,metadata)
                   VALUES (?,?,?,?,?,?,?,?)""",
                (sample_id, dataset_id, q_dict["question"], q_dict["reference_answer"],
                 json.dumps([], ensure_ascii=False), req.knowledge_hub_id,
                 None, json.dumps({"source": "qa_gen", "qa_gen_task_id": task_id}, ensure_ascii=False)),
            )

        await db.commit()

    return {"status": 0, "data": {"dataset_id": dataset_id, "sample_count": len(question_rows)}}

@router.get("/task/{task_id}")
async def get_task(task_id: str):
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM qa_gen_task WHERE id=?", (task_id,)
        )
        if not rows:
            raise HTTPException(status_code=404, detail="Task not found")
        return {"status": 0, "data": dict(rows[0])}


@router.delete("/task/{task_id}")
async def delete_task(task_id: str):
    async with get_db() as db:
        await db.execute("DELETE FROM qa_gen_question WHERE task_id=?", (task_id,))
        await db.execute("DELETE FROM qa_gen_task WHERE id=?", (task_id,))
        await db.commit()
    return {"status": 0, "data": True}

# ── Question list ─────────────────────────────────────────────────────────────

@router.get("/task/{task_id}/questions")
async def list_questions(
    task_id: str,
    status: Optional[str] = None,
    section: Optional[str] = None,
    page: int = 1,
    page_size: int = 50,
):
    # Build the WHERE clause from the optional filters; values stay parameterized
    conditions = ["task_id=?"]
    params: list = [task_id]
    if status:
        conditions.append("status=?")
        params.append(status)
    if section:
        conditions.append("section_path=?")
        params.append(section)
    where = " AND ".join(conditions)
    offset = (page - 1) * page_size

    async with get_db() as db:
        count_rows = await db.execute_fetchall(
            f"SELECT COUNT(*) as cnt FROM qa_gen_question WHERE {where}", params
        )
        total = dict(count_rows[0])["cnt"]
        rows = await db.execute_fetchall(
            f"""SELECT id,task_id,section_path,question,reference_answer,source_chunk,
                       quality_score,quality_detail,dup_of,dup_similarity,status,created_at,updated_at,
                       chunk_headers,chunk_id,file_id,file_name
                FROM qa_gen_question WHERE {where}
                ORDER BY section_path, created_at
                LIMIT ? OFFSET ?""",
            params + [page_size, offset],
        )

    items = []
    for r in rows:
        d = dict(r)
        if d.get("quality_detail"):
            # quality_detail is stored as JSON text; leave it as-is if it does not parse
            try:
                d["quality_detail"] = json.loads(d["quality_detail"])
            except Exception:
                pass
        items.append(d)

    return {"status": 0, "data": {"total": total, "items": items}}


@router.get("/task/{task_id}/sections")
async def list_sections(task_id: str):
    """Return per-section question statistics for a task."""
    async with get_db() as db:
        rows = await db.execute_fetchall(
            """SELECT section_path,
                      COUNT(*) as total,
                      SUM(CASE WHEN status='approved' THEN 1 ELSE 0 END) as approved,
                      SUM(CASE WHEN status='rejected' THEN 1 ELSE 0 END) as rejected,
                      SUM(CASE WHEN status='pending' THEN 1 ELSE 0 END) as pending,
                      SUM(CASE WHEN dup_of IS NOT NULL THEN 1 ELSE 0 END) as duplicates,
                      AVG(quality_score) as avg_quality
               FROM qa_gen_question WHERE task_id=?
               GROUP BY section_path ORDER BY section_path""",
            (task_id,),
        )
        return {"status": 0, "data": [dict(r) for r in rows]}

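`list_questions` assembles its WHERE clause from optional filters and pages the result with `LIMIT ? OFFSET ?`. The sketch below reproduces that idea against an in-memory SQLite database using only the standard library. The table `q` and the helper `build_page_query` are hypothetical and simplified for illustration; they are not the project's schema.

```python
import sqlite3


def build_page_query(task_id, status=None, page=1, page_size=50):
    """Return (sql, params) for one page of a task's questions."""
    conditions, params = ["task_id=?"], [task_id]
    if status:
        conditions.append("status=?")
        params.append(status)
    where = " AND ".join(conditions)
    offset = (page - 1) * page_size  # page numbers are 1-based
    sql = f"SELECT id FROM q WHERE {where} ORDER BY id LIMIT ? OFFSET ?"
    return sql, params + [page_size, offset]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE q (id INTEGER, task_id TEXT, status TEXT)")
conn.executemany(
    "INSERT INTO q VALUES (?,?,?)",
    [(i, "t1", "approved" if i % 2 else "pending") for i in range(1, 11)],
)

# Approved ids are 1, 3, 5, 7, 9; page 2 with page_size 2 skips the first two
sql, params = build_page_query("t1", status="approved", page=2, page_size=2)
page_ids = [row[0] for row in conn.execute(sql, params)]
print(page_ids)  # [5, 7]
```

Only the fixed condition strings are interpolated into the SQL text; every user-supplied value travels as a bound parameter.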
# ── Review actions ────────────────────────────────────────────────────────────

@router.post("/question/{question_id}/approve")
async def approve_question(question_id: str):
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT task_id FROM qa_gen_question WHERE id=?", (question_id,)
        )
        if not rows:
            raise HTTPException(status_code=404, detail="Question not found")
        task_id = dict(rows[0])["task_id"]
        await db.execute(
            "UPDATE qa_gen_question SET status='approved', updated_at=? WHERE id=?",
            (_now(), question_id),
        )
        await _sync_approved_count(db, task_id)
        await db.commit()
    return {"status": 0, "data": True}


@router.post("/question/{question_id}/reject")
async def reject_question(question_id: str):
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT task_id FROM qa_gen_question WHERE id=?", (question_id,)
        )
        if not rows:
            raise HTTPException(status_code=404, detail="Question not found")
        task_id = dict(rows[0])["task_id"]
        await db.execute(
            "UPDATE qa_gen_question SET status='rejected', updated_at=? WHERE id=?",
            (_now(), question_id),
        )
        await _sync_approved_count(db, task_id)
        await db.commit()
    return {"status": 0, "data": True}


class QuestionEditReq(BaseModel):
    question: Optional[str] = None
    reference_answer: Optional[str] = None


@router.put("/question/{question_id}")
async def edit_question(question_id: str, req: QuestionEditReq):
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT task_id FROM qa_gen_question WHERE id=?", (question_id,)
        )
        if not rows:
            raise HTTPException(status_code=404, detail="Question not found")
        task_id = dict(rows[0])["task_id"]
        updates = []
        params = []
        if req.question is not None:
            updates.append("question=?")
            params.append(req.question)
        if req.reference_answer is not None:
            updates.append("reference_answer=?")
            params.append(req.reference_answer)
        if not updates:
            raise HTTPException(status_code=400, detail="No fields to update")
        # An edited question is marked approved
        updates.append("status='approved'")
        updates.append("updated_at=?")
        params.append(_now())
        params.append(question_id)
        await db.execute(
            f"UPDATE qa_gen_question SET {', '.join(updates)} WHERE id=?", params
        )
        await _sync_approved_count(db, task_id)
        await db.commit()
    return {"status": 0, "data": True}


@router.post("/task/{task_id}/batch-approve")
async def batch_approve(task_id: str, min_quality: float = 0.0):
    """Batch approve: approve pending, non-duplicate questions whose quality_score >= min_quality.

    Questions without a quality_score (NULL) are approved as well.
    """
    async with get_db() as db:
        await db.execute(
            """UPDATE qa_gen_question SET status='approved', updated_at=?
               WHERE task_id=? AND status='pending' AND dup_of IS NULL
               AND (quality_score IS NULL OR quality_score >= ?)""",
            (_now(), task_id, min_quality),
        )
        await _sync_approved_count(db, task_id)
        await db.commit()
    return {"status": 0, "data": True}

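The `batch_approve` UPDATE combines three conditions: the row must be pending, must not be marked as a duplicate, and must either have no score or a score at or above the threshold. The following sketch runs the same WHERE clause on a small in-memory table to show which rows change. The table is a cut-down, hypothetical version of `qa_gen_question` that keeps only the columns the clause touches.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE qa_gen_question (id TEXT, status TEXT, dup_of TEXT, quality_score REAL)"
)
conn.executemany(
    "INSERT INTO qa_gen_question VALUES (?,?,?,?)",
    [
        ("a", "pending", None, 0.9),    # score above threshold: approved
        ("b", "pending", None, 0.3),    # score below threshold: stays pending
        ("c", "pending", None, None),   # no score at all: approved as well
        ("d", "pending", "a", 0.95),    # duplicate of "a": stays pending
        ("e", "rejected", None, 0.9),   # not pending: untouched
    ],
)

conn.execute(
    """UPDATE qa_gen_question SET status='approved'
       WHERE status='pending' AND dup_of IS NULL
       AND (quality_score IS NULL OR quality_score >= ?)""",
    (0.6,),
)

statuses = dict(conn.execute("SELECT id, status FROM qa_gen_question"))
print(statuses)
```

Row `c` is the case worth noticing: an unscored question passes the filter regardless of `min_quality`.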
# ── Export MD ─────────────────────────────────────────────────────────────────

@router.get("/task/{task_id}/export-md")
async def export_md(task_id: str):
    """Export approved questions as standard Markdown (the same format the single-jump test takes as input)."""
    async with get_db() as db:
        task_rows = await db.execute_fetchall(
            "SELECT name FROM qa_gen_task WHERE id=?", (task_id,)
        )
        if not task_rows:
            raise HTTPException(status_code=404, detail="Task not found")
        # Fall back to the task id when the stored name is NULL or empty
        task_name = dict(task_rows[0]).get("name") or task_id

        rows = await db.execute_fetchall(
            """SELECT section_path, qid, question, reference_answer, file_name
               FROM (
                   SELECT section_path, question, reference_answer, file_name,
                          ROW_NUMBER() OVER (PARTITION BY section_path ORDER BY created_at) as rn,
                          'Q' || ROW_NUMBER() OVER (PARTITION BY section_path ORDER BY created_at) as qid
                   FROM qa_gen_question
                   WHERE task_id=? AND status='approved'
               )
               ORDER BY section_path, rn""",
            (task_id,),
        )
        # Convert rows to dicts while connection is still open
        row_dicts = [dict(r) for r in rows]

    if not row_dicts:
        raise HTTPException(status_code=404, detail="没有已通过的问题")

    from collections import defaultdict
    sections: dict[str, list] = defaultdict(list)
    section_file_names: dict[str, str] = {}
    for d in row_dicts:
        sections[d["section_path"]].append(d)
        if d.get("file_name") and d["section_path"] not in section_file_names:
            section_file_names[d["section_path"]] = d["file_name"]

    lines = []

    def clean_for_parser(text: str) -> str:
        """Sanitize text so it matches the parser's regular expressions; Chinese characters are kept."""
        if not text:
            return "default"
        # Keep Chinese characters, digits, letters, underscores, slashes, spaces,
        # dots and hyphens; replace everything else with "_"
        cleaned = re.sub(r'[^一-龥a-zA-Z0-9_/ .\-]', '_', text)
        cleaned = cleaned.strip()
        if cleaned.startswith('.'):
            cleaned = '_' + cleaned[1:]
        return cleaned if cleaned else "default_section"

    section_index = 0
    for section_path, items in sections.items():
        section_index += 1
        file_name = section_file_names.get(section_path)

        if file_name:
            # Use Dagent's file_name as the section identifier
            doc_name = file_name.rsplit(".", 1)[0] if "." in file_name else file_name
            chapter_title = f"第{section_index}章 {doc_name.split('/')[-1]}"
            lines.append(f"# {chapter_title}")
            lines.append(f"## {file_name} / {doc_name}")
            lines.append(f"# {section_index}. {doc_name.split('/')[-1]}_Document")
        else:
            # Fallback: without a file_name, use the sanitized section_path
            clean_section_path = clean_for_parser(section_path)
            raw_doc_name = section_path.split("/")[-1] if "/" in section_path else section_path
            clean_doc_name = clean_for_parser(raw_doc_name)
            chapter_title = f"第{section_index}章 {clean_doc_name}"
            lines.append(f"# {chapter_title}")
            lines.append(f"## {clean_section_path} / {clean_doc_name}")
            lines.append(f"# {section_index}. {clean_doc_name}_Document")

        lines.append("> Generated from QA generation task")
        lines.append("---")
        lines.append("")
        for item in items:
            qid = item["qid"]
            aid = qid.replace("Q", "A")
            lines.append(f"## {qid}: {item['question']}")
            lines.append(f"**{aid}:** {item['reference_answer']}")
            lines.append("")
        lines.append("---")
        lines.append("")

    md_content = "\n".join(lines)
    filename_encoded = quote(f"qa_{task_name}.md".replace(" ", "_"))
    return StreamingResponse(
        iter([md_content.encode("utf-8")]),
        media_type="text/markdown",
        headers={"Content-Disposition": f"attachment; filename*=UTF-8''{filename_encoded}"},
    )

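`export_md` percent-encodes the download file name and sends it with the `filename*=UTF-8''...` form of `Content-Disposition`, which keeps the header ASCII-only even when the task name contains Chinese characters. A small sketch of just that step, using the standard library; `content_disposition` is a hypothetical helper name and the task name is made up.

```python
from urllib.parse import quote, unquote


def content_disposition(task_name: str) -> str:
    """Build an attachment header whose file name survives non-ASCII characters."""
    filename = f"qa_{task_name}.md".replace(" ", "_")
    # quote() encodes the name as UTF-8 and percent-escapes every unsafe byte
    return f"attachment; filename*=UTF-8''{quote(filename)}"


header = content_disposition("循环 测试")
print(header)

# The encoded part decodes back to the original file name
decoded = unquote(header.split("''", 1)[1])
print(decoded)
```

Spaces are replaced with underscores before encoding, so the saved file name contains no `%20` sequences.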
# ── Internal: run the generation task ─────────────────────────────────────────

async def _run_task(task_id: str, md_text: str):
    try:
        # The single_jump parser handles Q&A files, but md_text here is the raw
        # knowledge-base text, so a custom parser splits it into sections by heading.
        # (The unused import of that parser was dropped.)
        sections = _parse_knowledge_md(md_text)
        total = len(sections)

        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET status='running', total=? WHERE id=?",
                (total, task_id),
            )
            await db.commit()

        # Load the judge_config joined to this task
        async with get_db() as db:
            cfg_rows = await db.execute_fetchall(
                "SELECT * FROM qa_gen_task t JOIN judge_config j ON t.judge_config_id=j.id WHERE t.id=?",
                (task_id,),
            )
        if not cfg_rows:
            raise ValueError("judge_config not found")
        cfg = dict(cfg_rows[0])
        questions_per_section = cfg["questions_per_section"]
        quality_threshold = cfg["quality_threshold"]

        # Generate section by section, at most 3 sections concurrently
        sem = asyncio.Semaphore(3)
        done = 0

        async def gen_section(section_path: str, content: str):
            nonlocal done
            async with sem:
                questions = await _generate_questions(
                    cfg=cfg,
                    section_path=section_path,
                    content=content,
                    n=questions_per_section,
                )
                async with get_db() as db2:
                    for q in questions:
                        qid = _id()
                        # Simple quality score: use the score the LLM reports for now;
                        # this can be extended later
                        quality_score = q.get("quality_score", 0.8)
                        status = "approved" if quality_score >= quality_threshold else "pending"
                        await db2.execute(
                            """INSERT INTO qa_gen_question
                               (id,task_id,section_path,question,reference_answer,source_chunk,
                                quality_score,status,created_at)
                               VALUES (?,?,?,?,?,?,?,?,?)""",
                            (qid, task_id, section_path,
                             q["question"], q["answer"], q.get("source_chunk", ""),
                             quality_score, status, _now()),
                        )
                    done += 1
                    await db2.execute(
                        "UPDATE qa_gen_task SET progress=? WHERE id=?", (done, task_id)
                    )
                    await _sync_approved_count(db2, task_id)
                    await db2.commit()

        await asyncio.gather(*[gen_section(sp, ct) for sp, ct in sections])

        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET status='done', finished_at=? WHERE id=?",
                (_now(), task_id),
            )
            await db.commit()

    except Exception as exc:
        async with get_db() as db:
            await db.execute(
                "UPDATE qa_gen_task SET status='failed', error_message=? WHERE id=?",
                (str(exc), task_id),
            )
            await db.commit()

def _parse_knowledge_md(md_text: str) -> list[tuple[str, str]]:
    """
    Split a knowledge-base Markdown file on its headings (levels 1-4) into a
    list of (section_path, content) tuples.
    Nested headings are supported; their titles are joined with "/".
    """
    sections: list[tuple[str, str]] = []
    # Stack of (level, title) for the headings that enclose the current position.
    # Tracking the level explicitly keeps paths correct when a document starts
    # below "#" or skips a level (e.g. "## A", "### B", "## C" yields "C", not "A/C").
    stack: list[tuple[int, str]] = []
    current_lines: list[str] = []

    def flush() -> None:
        # Save the section collected so far, if it has a heading and any content
        content = "\n".join(current_lines).strip()
        if stack and content:
            sections.append(("/".join(title for _, title in stack), content))

    for line in md_text.splitlines():
        m = re.match(r'^(#{1,4})\s+(.+)', line)
        if m:
            flush()
            level = len(m.group(1))
            title = m.group(2).strip()
            # Drop headings at the same or a deeper level, then push this one
            while stack and stack[-1][0] >= level:
                stack.pop()
            stack.append((level, title))
            current_lines = []
        else:
            current_lines.append(line)

    # Last section
    flush()
    return sections

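`_parse_knowledge_md` turns a Markdown file into `(section_path, content)` pairs, where the path joins the titles of the enclosing headings with `/`. One common way to track that path is a stack of `(level, title)` entries, sketched below as a standalone function. The name `split_by_headings` and the sample text are hypothetical, and this is an illustration of the technique rather than a copy of the file's code.

```python
import re

HEADING = re.compile(r"^(#{1,4})\s+(.+)")


def split_by_headings(md_text: str) -> list[tuple[str, str]]:
    """Return (section_path, content) pairs; path parts are joined with '/'."""
    sections = []
    stack = []  # (level, title) of the headings enclosing the current line
    body = []

    def flush():
        text = "\n".join(body).strip()
        if stack and text:
            sections.append(("/".join(title for _, title in stack), text))

    for line in md_text.splitlines():
        m = HEADING.match(line)
        if m:
            flush()
            level = len(m.group(1))
            # A new heading closes every heading at the same or a deeper level
            while stack and stack[-1][0] >= level:
                stack.pop()
            stack.append((level, m.group(2).strip()))
            body = []
        else:
            body.append(line)
    flush()
    return sections


sample = "## Install\npip install x\n### Linux\napt notes\n## Usage\nrun it\n"
print(split_by_headings(sample))
# [('Install', 'pip install x'), ('Install/Linux', 'apt notes'), ('Usage', 'run it')]
```

Text that appears before the first heading has no path and is dropped, and headings with no body text produce no section.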
async def _generate_questions(
    cfg: dict,
    section_path: str,
    content: str,
    n: int,
) -> list[dict]:
    """Call the LLM to generate questions; returns [{question, answer, source_chunk, quality_score}]."""
    import aiohttp

    base_url = cfg.get("base_url", "").rstrip("/")
    api_key = cfg.get("api_key", "")
    model = cfg.get("model", "gpt-4o-mini")

    # Truncate overly long content (slicing is a no-op for shorter text)
    content_truncated = content[:3000]

    prompt = f"""你是一个专业的技术文档测试问题生成专家。

根据以下技术文档章节内容,生成 {n} 个测试问题。

章节路径:{section_path}
章节内容:
{content_truncated}

要求:
1. 问题必须能从该章节内容直接回答,不要生成需要跨文档才能回答的问题
2. 问题应覆盖章节的关键知识点,避免过于简单的是非题
3. 问题表述清晰,无歧义
4. 答案准确,与原文一致,长度适中(1-3句话)
5. source_chunk 为答案来源的原文片段(50-150字)
6. quality_score 为你对该问题质量的评估(0-1,1为最高质量)

只输出 JSON 数组,不要有其他内容:
[
  {{
    "question": "问题文本",
    "answer": "参考答案",
    "source_chunk": "答案来源原文片段",
    "quality_score": 0.9
  }}
]"""

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.3,
    }

    try:
        async with aiohttp.ClientSession(headers=headers) as session:
            async with session.post(
                f"{base_url}/chat/completions",
                json=payload,
                timeout=aiohttp.ClientTimeout(total=60),
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()

        text = data["choices"][0]["message"]["content"].strip()
        # Extract the JSON array (first "[" through last "]")
        m = re.search(r'\[.*\]', text, re.DOTALL)
        if not m:
            return []
        questions = json.loads(m.group())
        # Validate fields
        result = []
        for q in questions:
            if isinstance(q, dict) and q.get("question") and q.get("answer"):
                result.append({
                    "question": str(q["question"]).strip(),
                    "answer": str(q["answer"]).strip(),
                    "source_chunk": str(q.get("source_chunk", "")).strip(),
                    "quality_score": float(q.get("quality_score", 0.8)),
                })
        return result
    except Exception:
        # A failed generation must not abort the whole task; return an empty list
        return []

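`_generate_questions` asks the model for a bare JSON array but still has to cope with replies that wrap the array in extra text. The sketch below isolates that parsing step: take everything from the first `[` to the last `]`, decode it, and keep only items that carry both a question and an answer. `extract_questions` and the sample reply are hypothetical and only illustrate the approach.

```python
import json
import re


def extract_questions(reply: str) -> list[dict]:
    """Pull a JSON array out of an LLM reply and keep only well-formed items."""
    m = re.search(r"\[.*\]", reply, re.DOTALL)  # greedy: first "[" to last "]"
    if not m:
        return []
    try:
        items = json.loads(m.group())
    except json.JSONDecodeError:
        return []
    result = []
    for q in items:
        if isinstance(q, dict) and q.get("question") and q.get("answer"):
            result.append({
                "question": str(q["question"]).strip(),
                "answer": str(q["answer"]).strip(),
                "quality_score": float(q.get("quality_score", 0.8)),
            })
    return result


reply = (
    "Sure, here is the list:\n"
    '[{"question": " What is X? ", "answer": "Y", "quality_score": 0.9},'
    ' {"question": ""}]\n'
    "Hope this helps."
)
print(extract_questions(reply))
# [{'question': 'What is X?', 'answer': 'Y', 'quality_score': 0.9}]
```

Because the pattern is greedy, a reply containing two separate arrays would be captured as one span and fail to decode, in which case the function returns an empty list.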
async def _sync_approved_count(db, task_id: str):
    """Recompute and store the qa_gen_task.approved count."""
    rows = await db.execute_fetchall(
        "SELECT COUNT(*) as cnt FROM qa_gen_question WHERE task_id=? AND status='approved'",
        (task_id,),
    )
    approved = dict(rows[0])["cnt"] if rows else 0
    await db.execute(
        "UPDATE qa_gen_task SET approved=? WHERE id=?", (approved, task_id)
    )
1174
server/api/qa_gen_dagent.py
Normal file
File diff suppressed because it is too large
36
server/api/report.py
Normal file
@ -0,0 +1,36 @@
import json
import sys
from pathlib import Path
from fastapi import APIRouter, HTTPException

# Add parent directory to sys.path for relative imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.db import get_db

router = APIRouter(prefix="/api/report", tags=["评测报告"])


@router.get("/{task_id}")
async def get_report(task_id: str):
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM eval_report WHERE task_id=?", (task_id,)
        )
        if not rows:
            raise HTTPException(status_code=404, detail="Report not found. Task may still be running.")
        return {"status": 0, "data": dict(rows[0])}


@router.get("/{task_id}/items")
async def get_report_items(task_id: str):
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM eval_result WHERE task_id=? ORDER BY rowid ASC", (task_id,)
        )
        items = []
        for r in rows:
            d = dict(r)
            d["retrieved_chunks"] = json.loads(d["retrieved_chunks"] or "[]")
            d["judge_detail"] = json.loads(d["judge_detail"] or "{}")
            items.append(d)
        return {"status": 0, "data": {"total": len(items), "records": items}}
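`get_report_items` stores `retrieved_chunks` and `judge_detail` as JSON text and decodes them on the way out, substituting an empty list or object when the column is NULL or empty. A standalone sketch of that decode step follows; `decode_result_row` and the sample row are hypothetical.

```python
import json


def decode_result_row(row: dict) -> dict:
    """Decode the JSON text columns of one result row, tolerating NULL or ''."""
    d = dict(row)
    # `or` supplies a valid JSON literal when the stored value is None or empty
    d["retrieved_chunks"] = json.loads(d["retrieved_chunks"] or "[]")
    d["judge_detail"] = json.loads(d["judge_detail"] or "{}")
    return d


row = {"id": "r1", "retrieved_chunks": None, "judge_detail": '{"score": 0.8}'}
decoded_row = decode_result_row(row)
print(decoded_row)
```

The fallback covers missing values only; text that is present but not valid JSON still raises `json.JSONDecodeError`.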
1165
server/api/single_jump.py
Normal file
File diff suppressed because it is too large
90
server/api/task.py
Normal file
@ -0,0 +1,90 @@
import asyncio
import json
import sys
from pathlib import Path
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Optional

# Add parent directory to sys.path for relative imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.db import get_db, _now, _id

router = APIRouter(prefix="/api/task", tags=["评测任务"])

# Strong references to running evaluation tasks; the event loop keeps only weak ones
_background_tasks: set[asyncio.Task] = set()


class RunTaskReq(BaseModel):
    name: Optional[str] = None
    dataset_id: str
    platform_config_id: str
    judge_config_id: str
    agent_id: str
    knowledge_hub_id: str
    file_id_list: list[str] = []
    top_k: int = 10
    eval_retrieval: bool = True
    eval_generation: bool = True
    selected_metrics: list[str] = []
    concurrency: int = 3


def _task_dict(r) -> dict:
    # Decode the JSON text columns of an eval_task row
    d = dict(r)
    d["file_id_list"] = json.loads(r["file_id_list"] or "[]")
    d["selected_metrics"] = json.loads(r["selected_metrics"] or "[]")
    return d


@router.post("/run")
async def run_task(req: RunTaskReq):
    async with get_db() as db:
        task_id = _id()
        await db.execute(
            """INSERT INTO eval_task
               (id,name,dataset_id,platform_config_id,judge_config_id,agent_id,
                knowledge_hub_id,file_id_list,top_k,eval_retrieval,eval_generation,
                selected_metrics,concurrency,status,progress,total,created_at)
               VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,'pending',0,0,?)""",
            (task_id, req.name, req.dataset_id, req.platform_config_id,
             req.judge_config_id, req.agent_id, req.knowledge_hub_id,
             json.dumps(req.file_id_list), req.top_k,
             int(req.eval_retrieval), int(req.eval_generation),
             json.dumps(req.selected_metrics),
             req.concurrency, _now()),
        )
        await db.commit()

    import importlib
    task_svc = importlib.import_module("service.task_service")
    # Keep a reference so the task cannot be garbage-collected mid-run
    bg_task = asyncio.create_task(task_svc.run_eval_task(task_id))
    _background_tasks.add(bg_task)
    bg_task.add_done_callback(_background_tasks.discard)
    return {"status": 0, "data": {"id": task_id}}


@router.get("/list")
async def list_tasks():
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM eval_task ORDER BY created_at DESC"
        )
        return {"status": 0, "data": [_task_dict(r) for r in rows]}


@router.get("/{task_id}")
async def get_task(task_id: str):
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT * FROM eval_task WHERE id=?", (task_id,)
        )
        if not rows:
            raise HTTPException(status_code=404, detail="Task not found")
        return {"status": 0, "data": _task_dict(rows[0])}


@router.delete("/{task_id}")
async def delete_task(task_id: str):
    async with get_db() as db:
        await db.execute("DELETE FROM eval_result WHERE task_id=?", (task_id,))
        await db.execute("DELETE FROM eval_report WHERE task_id=?", (task_id,))
        await db.execute("DELETE FROM eval_task WHERE id=?", (task_id,))
        await db.commit()
    return {"status": 0, "data": True}
68
server/main.py
Normal file
@ -0,0 +1,68 @@
import sys
import io
from pathlib import Path
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles

# Fix Windows console GBK encoding: force stdout/stderr to UTF-8 with replace
# so print() of non-GBK chars (e.g. ‑, ᵀ) never raises.
if sys.platform == "win32":
    try:
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace", line_buffering=True)
        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace", line_buffering=True)
    except Exception:
        pass

sys.path.insert(0, str(Path(__file__).parent.parent / "sdk"))

from models.db import init_db
from api import config, dataset, task, report, single_jump, qa_gen, qa_gen_dagent, loop, multi_hop, multi_hop_gen, prompt_template
from service.loop_engine import recover_orphaned_loops


@asynccontextmanager
async def lifespan(app: FastAPI):
    await init_db()
    # Recover orphaned loop tasks (set 'running' to 'paused' on startup)
    await recover_orphaned_loops()
    yield


app = FastAPI(title="RAG Eval Framework", version="0.1.0", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(config.router)
app.include_router(dataset.router)
app.include_router(task.router)
app.include_router(report.router)
app.include_router(single_jump.router)
app.include_router(qa_gen.router)
app.include_router(qa_gen_dagent.router)
app.include_router(loop.router)
app.include_router(multi_hop.router)
app.include_router(multi_hop_gen.router)
app.include_router(prompt_template.router)


@app.get("/api/health")
async def health():
    return {"status": "ok"}


# Serve frontend static files (built React app)
frontend_dist = Path(__file__).parent.parent / "frontend" / "dist"
if frontend_dist.exists():
    app.mount("/", StaticFiles(directory=str(frontend_dist), html=True), name="frontend")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8021, reload=True)
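The Windows console fix in `main.py` rewraps `stdout`/`stderr` with `errors="replace"` so that an unencodable character degrades to a placeholder instead of raising. A small sketch of that behaviour, assuming ASCII as a stand-in for a narrow console code page and an in-memory `BytesIO` as a stand-in for the console buffer; `make_lossy_writer` is a hypothetical helper, not a function from the repository:

```python
import io


def make_lossy_writer(raw: io.BytesIO, encoding: str) -> io.TextIOWrapper:
    """Wrap a binary stream so unencodable characters become '?' instead of raising."""
    return io.TextIOWrapper(raw, encoding=encoding, errors="replace", line_buffering=True)


def demo() -> tuple[bool, bytes]:
    text = "A\u1d40"  # the second character is not representable in ASCII

    # A strict wrapper raises as soon as the text has to be encoded.
    strict = io.TextIOWrapper(io.BytesIO(), encoding="ascii", errors="strict")
    try:
        strict.write(text)
        strict.flush()
        raised = False
    except UnicodeEncodeError:
        raised = True

    # The lossy wrapper substitutes '?' and keeps going.
    buf = io.BytesIO()
    lossy = make_lossy_writer(buf, "ascii")
    lossy.write(text)
    lossy.flush()
    return raised, buf.getvalue()
```

In the server itself the target encoding is UTF-8, which can encode nearly everything, so `errors="replace"` mostly acts as a last-resort guard (for example against lone surrogates).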
0
server/models/__init__.py
Normal file
227
server/models/db.py
Normal file
@ -0,0 +1,227 @@
import aiosqlite
import json
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import Iterable

DB_PATH = Path(__file__).parent.parent / "data" / "rag_eval.db"
SCHEMA_PATH = Path(__file__).parent / "schema.sql"


@asynccontextmanager
async def get_db():
    """Async context manager that yields a configured aiosqlite connection."""
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    async with aiosqlite.connect(DB_PATH, timeout=30.0) as db:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute("PRAGMA busy_timeout=30000")
        await db.execute("PRAGMA synchronous=NORMAL")
        yield db


async def init_db():
    async with get_db() as db:
        sql = SCHEMA_PATH.read_text(encoding="utf-8")
        await db.executescript(sql)
        await _run_migrations(db)
        await db.commit()


async def _run_migrations(db: aiosqlite.Connection):
    """Apply forward-only lightweight migrations for existing local DBs."""
    await _ensure_columns(
        db,
        "single_jump_result",
        (
            ("file_name", "TEXT"),
            ("match_type", "TEXT"),
            ("is_file_hit", "INTEGER DEFAULT 0"),
            ("expected_chunk_id", "TEXT"),
            ("is_chunk_hit", "INTEGER DEFAULT 0"),
            ("chunk_hit_rank", "INTEGER"),
            ("retrieved_chunk_ids", "TEXT"),
        ),
    )
    await _ensure_columns(
        db,
        "single_jump_task",
        (
            ("progress", "INTEGER DEFAULT 0"),
            ("total", "INTEGER DEFAULT 0"),
            ("error_message", "TEXT"),
            ("finished_at", "TEXT"),
            ("md_content", "TEXT"),
        ),
    )
    # qa_gen tables migration
    await _ensure_columns(
        db,
        "qa_gen_question",
        (
            ("source_chunk", "TEXT"),
            ("quality_score", "REAL"),
            ("quality_detail", "TEXT"),
            ("dup_of", "TEXT"),
            ("dup_similarity", "REAL"),
            ("embedding", "TEXT"),
            ("updated_at", "TEXT"),
            ("file_id", "TEXT"),
            ("file_name", "TEXT"),
            ("chunk_id", "TEXT"),
            ("chunk_headers", "TEXT"),
            ("chunk_content_preview", "TEXT"),
        ),
    )
    await _ensure_columns(
        db,
        "qa_gen_task",
        (
            ("approved", "INTEGER DEFAULT 0"),
        ),
    )
    await _ensure_columns(
        db,
        "loop_round",
        (
            ("dedup_progress", "TEXT"),
        ),
    )
    # multi_hop_gen_task: add new columns for dagent source
    await _ensure_columns(
        db,
        "multi_hop_gen_task",
        (
            ("source", "TEXT NOT NULL DEFAULT 'file'"),
            ("org_id", "TEXT"),
            ("file_ids", "TEXT DEFAULT ''"),
        ),
    )
    # multi_hop_task: add judge_config_id and llm_type columns
    await _ensure_columns(
        db,
        "multi_hop_task",
        (
            ("judge_config_id", "TEXT DEFAULT ''"),
            ("llm_type", "TEXT DEFAULT 'deepseek_v3'"),
        ),
    )
    # multi_hop_task: add agent_id
    await _ensure_columns(
        db,
        "multi_hop_task",
        (
            ("agent_id", "TEXT DEFAULT ''"),
        ),
    )
    # multi_hop_result: add actual_hops, agent_answer, and chunk-level hit counters
    await _ensure_columns(
        db,
        "multi_hop_result",
        (
            ("actual_hops", "TEXT DEFAULT '[]'"),
            ("agent_answer", "TEXT DEFAULT ''"),
            ("chunk_hit_count", "INTEGER DEFAULT 0"),
            ("full_chunk_hit", "INTEGER DEFAULT 0"),
            ("partial_chunk_hit", "INTEGER DEFAULT 0"),
        ),
    )
    # multi_hop_gen_task: add prompt_template_id
    await _ensure_columns(
        db,
        "multi_hop_gen_task",
        (
            ("prompt_template_id", "TEXT"),
        ),
    )
    # loop_task: add global_dedup flag
    await _ensure_columns(
        db,
        "loop_task",
        (
            ("global_dedup", "INTEGER DEFAULT 0"),
        ),
    )
    # loop_round: add chunk_hit for chunk-level hit tracking
    await _ensure_columns(
        db,
        "loop_round",
        (
            ("chunk_hit", "INTEGER DEFAULT 0"),
        ),
    )
    # loop_task: add total_chunk_hit for chunk-level aggregation
    await _ensure_columns(
        db,
        "loop_task",
        (
            ("total_chunk_hit", "INTEGER DEFAULT 0"),
        ),
    )
    # single_jump_task: add recall_top_k (number of recall results) and hit_top_k
    await _ensure_columns(
        db,
        "single_jump_task",
        (
            ("recall_top_k", "INTEGER DEFAULT 64"),
            ("hit_top_k", "INTEGER DEFAULT 64"),
        ),
    )
    # single_jump_result: add hit_top_k for chunk hit calculation
    await _ensure_columns(
        db,
        "single_jump_result",
        (
            ("hit_top_k", "INTEGER DEFAULT 64"),
        ),
    )
    # single_jump_result: add raw_chunk_headers for original section title
    await _ensure_columns(
        db,
        "single_jump_result",
        (
            ("raw_chunk_headers", "TEXT"),
        ),
    )
    # loop_task: add recall_top_k (number of recall results)
    await _ensure_columns(
        db,
        "loop_task",
        (
            ("recall_top_k", "INTEGER DEFAULT 64"),
        ),
    )
    # loop_task: total chunk count from batch planning, used to verify that the
    # fetch was complete (matches chunk_batches_plan.chunk_count)
    await _ensure_columns(
        db,
        "loop_task",
        (
            ("expected_chunk_count", "INTEGER"),
        ),
    )


async def _ensure_columns(
    db: aiosqlite.Connection,
    table_name: str,
    columns: Iterable[tuple[str, str]],
):
    """Ensure table has required columns; add missing ones via ALTER TABLE."""
    # table_name and column definitions are interpolated into SQL, so they must
    # only ever come from the trusted constants in _run_migrations().
    rows = await db.execute_fetchall(f"PRAGMA table_info({table_name})")
    existing = {row["name"] for row in rows}
    for column_name, column_def in columns:
        if column_name in existing:
            continue
        await db.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_def}")


def _now() -> str:
    return datetime.utcnow().isoformat()


def _id() -> str:
    return uuid.uuid4().hex
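The migration helper in `db.py` is idempotent: it inspects `PRAGMA table_info` and only issues `ALTER TABLE ... ADD COLUMN` for columns that are missing, so it can run on every startup. A minimal synchronous sketch of that pattern, assuming the standard-library `sqlite3` module and an in-memory database; `ensure_columns` and its return value are illustrative additions, not the repository's API:

```python
import sqlite3
from typing import Iterable


def ensure_columns(
    conn: sqlite3.Connection,
    table_name: str,
    columns: Iterable[tuple[str, str]],
) -> list[str]:
    """Add any missing columns; return the names that were actually added."""
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
    existing = {row[1] for row in conn.execute(f"PRAGMA table_info({table_name})")}
    added = []
    for name, definition in columns:
        if name in existing:
            continue
        # Identifiers cannot be bound as parameters, so only trusted
        # constants should ever reach this f-string.
        conn.execute(f"ALTER TABLE {table_name} ADD COLUMN {name} {definition}")
        added.append(name)
    return added


def demo() -> tuple[list[str], list[str]]:
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE loop_round (id TEXT PRIMARY KEY, status TEXT)")
    wanted = (("status", "TEXT"), ("chunk_hit", "INTEGER DEFAULT 0"))
    first = ensure_columns(conn, "loop_round", wanted)   # adds chunk_hit
    second = ensure_columns(conn, "loop_round", wanted)  # nothing left to add
    conn.close()
    return first, second
```

The second call is a no-op, which is what makes it safe to list the same table several times in the migration sequence.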
340
server/models/schema.sql
Normal file
@ -0,0 +1,340 @@
-- RAG Eval Framework: SQLite schema
-- server/models/schema.sql

CREATE TABLE IF NOT EXISTS platform_config (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    type TEXT NOT NULL DEFAULT 'dagent',
    base_url TEXT NOT NULL,
    org_id TEXT,
    token TEXT,
    created_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS judge_config (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    base_url TEXT NOT NULL,
    api_key TEXT NOT NULL,
    model TEXT NOT NULL,
    embed_base_url TEXT DEFAULT '',
    embed_api_key TEXT DEFAULT '',
    embed_model TEXT DEFAULT 'text-embedding-3-small',
    created_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS eval_dataset (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    description TEXT,
    sample_count INTEGER DEFAULT 0,
    created_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS eval_sample (
    id TEXT PRIMARY KEY,
    dataset_id TEXT NOT NULL,
    question TEXT NOT NULL,
    reference_answer TEXT NOT NULL,
    relevant_chunk_ids TEXT NOT NULL DEFAULT '[]',
    knowledge_hub_id TEXT NOT NULL,
    source_file_id TEXT,
    metadata TEXT DEFAULT '{}'
);

CREATE TABLE IF NOT EXISTS eval_task (
    id TEXT PRIMARY KEY,
    name TEXT,
    dataset_id TEXT NOT NULL,
    platform_config_id TEXT NOT NULL,
    judge_config_id TEXT NOT NULL,
    agent_id TEXT NOT NULL,
    knowledge_hub_id TEXT NOT NULL,
    file_id_list TEXT DEFAULT '[]',
    top_k INTEGER DEFAULT 10,
    eval_retrieval INTEGER DEFAULT 1,
    eval_generation INTEGER DEFAULT 1,
    concurrency INTEGER DEFAULT 3,
    selected_metrics TEXT DEFAULT '[]',
    status TEXT NOT NULL DEFAULT 'pending',
    progress INTEGER DEFAULT 0,
    total INTEGER DEFAULT 0,
    error_message TEXT,
    created_at TEXT NOT NULL,
    finished_at TEXT
);

CREATE TABLE IF NOT EXISTS eval_result (
    id TEXT PRIMARY KEY,
    task_id TEXT NOT NULL,
    sample_id TEXT NOT NULL,
    question TEXT,
    reference_answer TEXT,
    retrieved_chunks TEXT,
    agent_answer TEXT,
    hit_rate REAL,
    mrr REAL,
    ndcg REAL,
    context_precision REAL,
    context_recall REAL,
    faithfulness REAL,
    answer_relevance REAL,
    answer_correctness REAL,
    groundedness REAL,
    latency_ms INTEGER,
    judge_detail TEXT,
    error TEXT
);

CREATE TABLE IF NOT EXISTS generate_task (
    id TEXT PRIMARY KEY,
    dataset_id TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'pending',
    progress INTEGER DEFAULT 0,
    total INTEGER DEFAULT 0,
    error_message TEXT,
    created_at TEXT NOT NULL,
    finished_at TEXT
);

CREATE TABLE IF NOT EXISTS single_jump_task (
    id TEXT PRIMARY KEY,
    name TEXT,
    env_url TEXT NOT NULL,
    org_id TEXT NOT NULL,
    d_user_id TEXT DEFAULT 'test',
    agent_id TEXT DEFAULT '', -- agent ID used for recall testing
    top_k INTEGER DEFAULT 64,
    concurrency INTEGER DEFAULT 5,
    cross_chunk INTEGER DEFAULT 1,
    status TEXT NOT NULL DEFAULT 'pending',
    progress INTEGER DEFAULT 0,
    total INTEGER DEFAULT 0,
    error_message TEXT,
    created_at TEXT NOT NULL,
    finished_at TEXT
);

CREATE TABLE IF NOT EXISTS single_jump_result (
    id TEXT PRIMARY KEY,
    task_id TEXT NOT NULL,
    section_path TEXT,
    doc_name TEXT,
    file_id TEXT,
    file_name TEXT,
    match_type TEXT,
    qid TEXT,
    question TEXT,
    reference_answer TEXT,
    top_k INTEGER,
    retrieved TEXT DEFAULT '[]',
    latency_ms INTEGER DEFAULT 0,
    error TEXT,
    best_cosine_sim REAL,
    avg_cosine_sim REAL,
    is_file_hit INTEGER DEFAULT 0,
    expected_chunk_id TEXT, -- ID of the chunk expected to be hit
    is_chunk_hit INTEGER DEFAULT 0, -- whether the expected chunk was hit
    chunk_hit_rank INTEGER, -- rank at which the chunk was hit
    retrieved_chunk_ids TEXT, -- JSON array: IDs of all retrieved chunks
    raw_chunk_headers TEXT -- original chunk headers (parsed from metadata)
);

-- Indexes for single_jump_result
CREATE INDEX IF NOT EXISTS idx_single_jump_result_task_id ON single_jump_result(task_id);
CREATE INDEX IF NOT EXISTS idx_single_jump_result_section_path ON single_jump_result(section_path);
CREATE INDEX IF NOT EXISTS idx_single_jump_result_is_file_hit ON single_jump_result(is_file_hit);
CREATE INDEX IF NOT EXISTS idx_single_jump_result_error ON single_jump_result(error);
CREATE INDEX IF NOT EXISTS idx_single_jump_result_task_section ON single_jump_result(task_id, section_path);

CREATE TABLE IF NOT EXISTS multi_hop_task (
    id TEXT PRIMARY KEY,
    name TEXT,
    env_url TEXT NOT NULL,
    org_id TEXT NOT NULL,
    d_user_id TEXT DEFAULT 'test',
    agent_id TEXT DEFAULT '',
    judge_config_id TEXT DEFAULT '',
    top_k INTEGER DEFAULT 10,
    concurrency INTEGER DEFAULT 5,
    status TEXT NOT NULL DEFAULT 'pending',
    progress INTEGER DEFAULT 0,
    total INTEGER DEFAULT 0,
    error_message TEXT,
    created_at TEXT NOT NULL,
    finished_at TEXT
);

CREATE TABLE IF NOT EXISTS multi_hop_result (
    id TEXT PRIMARY KEY,
    task_id TEXT NOT NULL,
    qid TEXT,
    question TEXT,
    answer TEXT,
    type TEXT,
    top_k INTEGER,
    hops TEXT DEFAULT '[]', -- JSON: [{section_path, file_id, file_name, hit, contribution}]
    actual_hops TEXT DEFAULT '[]', -- JSON: [{hop_index, query, retrieved:[{file_id,headers,file_name}]}]
    retrieved TEXT DEFAULT '[]', -- JSON: retrieval results merged and deduplicated across all hops (kept for legacy logic)
    agent_answer TEXT DEFAULT '', -- final answer from the agent
    latency_ms INTEGER DEFAULT 0,
    error TEXT,
    best_cosine_sim REAL,
    full_hit INTEGER DEFAULT 0,
    partial_hit INTEGER DEFAULT 0,
    hop_count INTEGER DEFAULT 0,
    hop_hit_count INTEGER DEFAULT 0
);

CREATE TABLE IF NOT EXISTS qa_gen_task (
    id TEXT PRIMARY KEY,
    name TEXT,
    status TEXT NOT NULL DEFAULT 'pending',
    judge_config_id TEXT NOT NULL,
    questions_per_section INTEGER DEFAULT 5,
    quality_threshold REAL DEFAULT 0.6,
    progress INTEGER DEFAULT 0,
    total INTEGER DEFAULT 0,
    approved INTEGER DEFAULT 0,
    error_message TEXT,
    created_at TEXT NOT NULL,
    finished_at TEXT
);

CREATE TABLE IF NOT EXISTS qa_gen_question (
    id TEXT PRIMARY KEY,
    task_id TEXT NOT NULL,
    section_path TEXT NOT NULL,
    question TEXT NOT NULL,
    reference_answer TEXT NOT NULL,
    source_chunk TEXT,
    quality_score REAL,
    quality_detail TEXT,
    dup_of TEXT,
    dup_similarity REAL,
    status TEXT NOT NULL DEFAULT 'pending',
    embedding TEXT,
    created_at TEXT NOT NULL,
    updated_at TEXT,
    file_id TEXT,
    file_name TEXT,
    chunk_id TEXT, -- chunk ID, used to trace the chunk the question came from
    chunk_headers TEXT, -- chunk header path
    chunk_content_preview TEXT -- chunk content preview (first 500 characters)
);

CREATE TABLE IF NOT EXISTS eval_report (
    id TEXT PRIMARY KEY,
    task_id TEXT NOT NULL UNIQUE,
    sample_count INTEGER,
    avg_hit_rate REAL,
    avg_mrr REAL,
    avg_ndcg REAL,
    avg_context_precision REAL,
    avg_context_recall REAL,
    avg_faithfulness REAL,
    avg_answer_relevance REAL,
    avg_answer_correctness REAL,
    avg_groundedness REAL,
    rag_score REAL,
    hallucination_rate REAL,
    interpretation TEXT,
    created_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS loop_task (
    id TEXT PRIMARY KEY,
    name TEXT,
    org_id TEXT NOT NULL,
    judge_config_id TEXT NOT NULL,
    file_ids TEXT DEFAULT '',
    questions_per_section INTEGER DEFAULT 5,
    quality_threshold REAL DEFAULT 0.6,
    include_multimodal INTEGER DEFAULT 1,
    env_url TEXT NOT NULL,
    d_user_id TEXT DEFAULT 'test',
    agent_id TEXT DEFAULT '', -- agent ID used for recall testing
    top_k INTEGER DEFAULT 64,
    concurrency INTEGER DEFAULT 20,
    cross_chunk INTEGER DEFAULT 1,
    status TEXT NOT NULL DEFAULT 'pending',
    current_round INTEGER DEFAULT 0,
    max_rounds INTEGER DEFAULT 0,
    max_questions INTEGER DEFAULT 0,
    total_generated INTEGER DEFAULT 0,
    total_approved INTEGER DEFAULT 0,
    total_duplicates INTEGER DEFAULT 0,
    total_tested INTEGER DEFAULT 0,
    total_recalled INTEGER DEFAULT 0,
    total_file_hit INTEGER DEFAULT 0,
    total_file_miss INTEGER DEFAULT 0,
    total_recall_failed INTEGER DEFAULT 0,
    error_message TEXT,
    global_dedup INTEGER DEFAULT 0, -- whether to deduplicate globally (across tasks)
    expected_chunk_count INTEGER, -- total chunk count from batch planning; matches chunk_batches_plan.chunk_count
    created_at TEXT NOT NULL,
    paused_at TEXT,
    finished_at TEXT
);

CREATE TABLE IF NOT EXISTS loop_round (
    id TEXT PRIMARY KEY,
    loop_task_id TEXT NOT NULL,
    round_number INTEGER NOT NULL,
    qa_gen_task_id TEXT,
    single_jump_task_id TEXT,
    status TEXT NOT NULL DEFAULT 'pending',
    generated INTEGER DEFAULT 0,
    approved INTEGER DEFAULT 0,
    duplicates INTEGER DEFAULT 0,
    tested INTEGER DEFAULT 0,
    recalled INTEGER DEFAULT 0,
    file_hit INTEGER DEFAULT 0,
    dedup_progress TEXT,
    started_at TEXT,
    finished_at TEXT
);

CREATE TABLE IF NOT EXISTS multi_hop_gen_task (
    id TEXT PRIMARY KEY,
    name TEXT,
    status TEXT NOT NULL DEFAULT 'pending',
    source TEXT NOT NULL DEFAULT 'file', -- 'file' | 'dagent'
    judge_config_id TEXT NOT NULL,
    org_id TEXT,
    file_ids TEXT DEFAULT '',
    hops_per_question INTEGER DEFAULT 2,
    questions_per_group INTEGER DEFAULT 3,
    quality_threshold REAL DEFAULT 0.6,
    progress INTEGER DEFAULT 0,
    total INTEGER DEFAULT 0,
    approved INTEGER DEFAULT 0,
    error_message TEXT,
    created_at TEXT NOT NULL,
    finished_at TEXT
);

CREATE TABLE IF NOT EXISTS prompt_template (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    description TEXT,
    content TEXT NOT NULL,
    created_at TEXT NOT NULL,
    updated_at TEXT
);

CREATE TABLE IF NOT EXISTS multi_hop_gen_question (
    id TEXT PRIMARY KEY,
    task_id TEXT NOT NULL,
    qid TEXT,
    question TEXT NOT NULL,
    answer TEXT NOT NULL,
    type TEXT DEFAULT 'reasoning',
    hops TEXT DEFAULT '[]',
    source_sections TEXT DEFAULT '[]',
    quality_score REAL,
    quality_detail TEXT,
    status TEXT NOT NULL DEFAULT 'pending',
    created_at TEXT NOT NULL,
    updated_at TEXT
);
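The `single_jump_result` table stores per-question hit flags (`is_file_hit`, `is_chunk_hit`) as 0/1 integers, so task-level hit rates fall out of a plain `AVG`. A hedged sketch of such an aggregation, assuming the standard-library `sqlite3` module, a trimmed copy of the table, and the (unconfirmed) choice to exclude rows whose `error` column is set; `hit_rates` is a hypothetical helper, not a function from the repository:

```python
import sqlite3


def hit_rates(conn: sqlite3.Connection, task_id: str) -> dict:
    """Average the 0/1 hit flags over a task's error-free rows."""
    row = conn.execute(
        """SELECT COUNT(*), AVG(is_file_hit), AVG(is_chunk_hit)
           FROM single_jump_result
           WHERE task_id = ? AND error IS NULL""",
        (task_id,),
    ).fetchone()
    return {"samples": row[0], "file_hit_rate": row[1], "chunk_hit_rate": row[2]}


def demo() -> dict:
    conn = sqlite3.connect(":memory:")
    # Trimmed table (assumption): only the columns the query reads.
    conn.execute(
        """CREATE TABLE single_jump_result (
               id TEXT PRIMARY KEY,
               task_id TEXT NOT NULL,
               error TEXT,
               is_file_hit INTEGER DEFAULT 0,
               is_chunk_hit INTEGER DEFAULT 0
           )"""
    )
    conn.executemany(
        "INSERT INTO single_jump_result VALUES (?, ?, ?, ?, ?)",
        [
            ("r1", "t1", None, 1, 1),
            ("r2", "t1", None, 1, 0),
            ("r3", "t1", None, 0, 0),
            ("r4", "t1", None, 1, 1),
            ("r5", "t1", "timeout", 0, 0),  # failed call, excluded from the averages
            ("r6", "t2", None, 0, 0),       # different task
        ],
    )
    result = hit_rates(conn, "t1")
    conn.close()
    return result
```

The `task_id` filter is served by the `idx_single_jump_result_task_id` index defined in the schema.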
8
server/requirements.txt
Normal file
@ -0,0 +1,8 @@
fastapi==0.115.0
uvicorn[standard]==0.34.0
aiosqlite==0.20.0
python-multipart==0.0.20
aiohttp>=3.9.0
openai==1.67.0
numpy>=2.0
pydantic>=2.0.0
265
server/scripts/export_loop_all_groups.py
Normal file
@ -0,0 +1,265 @@
#!/usr/bin/env python3
"""Export all loop-test Q&A batches for remote dagent from SQLite (fast path)."""
from __future__ import annotations

import json
import sqlite3
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path

ROOT = Path(__file__).resolve().parent.parent.parent
DB_PATH = ROOT / "server" / "data" / "rag_eval.db"
PLAN_PATH = ROOT / "docs" / "task_groups_plan.json"
EXPORT_DIR = ROOT / "docs" / "exports"
sys.path.insert(0, str(ROOT / "server"))
sys.path.insert(0, str(ROOT / "server" / "service"))
from loop_recall_md import DEFAULT_LLM_NOTE, append_recall_md_section  # noqa: E402


def get_task_questions_fast(conn: sqlite3.Connection, task_id: str) -> list[dict]:
    """Approved Q&A from qa_gen_question; fallback to single_jump_result for legacy tasks."""
    # Rows are accessed by column name, so conn.row_factory must be sqlite3.Row.
    cursor = conn.cursor()
    cursor.execute(
        """SELECT COUNT(*) as cnt FROM loop_round
           WHERE loop_task_id=? AND qa_gen_task_id IS NOT NULL""",
        (task_id,),
    )
    if cursor.fetchone()["cnt"] > 0:
        cursor.execute(
            """SELECT
                   q.id as qa_question_id,
                   q.section_path, q.file_name, q.question, q.reference_answer,
                   q.source_chunk, q.quality_score, q.status,
                   q.dup_of, q.dup_similarity,
                   q.chunk_headers, q.chunk_id, q.file_id,
                   lr.round_number
               FROM qa_gen_question q
               JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
               WHERE lr.loop_task_id = ? AND q.status = 'approved'
               ORDER BY lr.round_number, q.chunk_headers, q.created_at""",
            (task_id,),
        )
        return [dict(row) for row in cursor.fetchall()]

    # Legacy tasks: no qa_gen_task linked to any round, so read the recall results instead.
    cursor.execute(
        """SELECT
               r.section_path, r.file_name, r.question, r.reference_answer,
               COALESCE(r.raw_chunk_headers, r.section_path) as chunk_headers,
               r.expected_chunk_id as chunk_id,
               lr.round_number
           FROM single_jump_result r
           JOIN loop_round lr ON r.task_id = lr.single_jump_task_id
           WHERE lr.loop_task_id = ?
           ORDER BY lr.round_number, r.section_path""",
        (task_id,),
    )
    rows = []
    for row in cursor.fetchall():
        d = dict(row)
        d.setdefault("quality_score", 1.0)
        d.setdefault("status", "approved")
        rows.append(d)
    return rows


def rows_to_md(rows: list[dict]) -> str:
    if not rows:
        return ""
    # Group questions by chunk headers, falling back to section path, then file name.
    sections: dict[str, list] = defaultdict(list)
    for row in rows:
        key = row.get("chunk_headers") or row.get("section_path") or row.get("file_name") or "default"
        sections[key].append(row)

    lines: list[str] = []
    for section_index, (section_key, items) in enumerate(sections.items(), 1):
        file_name = (items[0].get("file_name") or "").strip()
        slice_title = (items[0].get("chunk_headers") or "").strip() or section_key
        # "代表轮次" = "representative round" (round number of the first item in the section).
        meta = [f"> 代表轮次: {items[0]['round_number']}", DEFAULT_LLM_NOTE]
        qa_items = [
            {
                "question": it["question"],
                "reference_answer": it["reference_answer"],
                "chunk_id": (it.get("chunk_id") or ""),
            }
            for it in items
        ]
        append_recall_md_section(
            lines, section_index,
            file_name=file_name, slice_title=slice_title,
            qa_items=qa_items, meta_lines=meta,
        )
    return "\n".join(lines)


def rows_to_json_questions(rows: list[dict]) -> list[dict]:
    return [
        {
            "section_path": r.get("section_path"),
            "file_name": r.get("file_name"),
            "file_id": r.get("file_id"),
            "chunk_headers": r.get("chunk_headers"),
            "chunk_id": r.get("chunk_id"),
            "round": r.get("round_number"),
            "question": r["question"],
            "reference_answer": r["reference_answer"],
            "source_chunk": r.get("source_chunk"),
            "quality_score": r.get("quality_score"),
            "status": r.get("status"),
            "is_duplicate": bool(r.get("dup_of")),
            "dup_similarity": r.get("dup_similarity"),
            "qa_question_id": r.get("qa_question_id"),
        }
        for r in rows
    ]


def resolve_task_id_from_db(conn: sqlite3.Connection, group_id: int, batch_id: int) -> dict | None:
    """Pick the loop_task with most approved questions when duplicates exist."""
    # Task names in the DB follow "循环测试_组{group}_批次{batch}" ("loop test, group N, batch M").
    name = f"循环测试_组{group_id}_批次{batch_id}"
    cur = conn.cursor()
    cur.execute(
        """SELECT id, name, status, total_approved, env_url, created_at
           FROM loop_task
           WHERE name=? AND env_url LIKE '%dagent%'
           ORDER BY total_approved DESC, created_at DESC
           LIMIT 1""",
        (name,),
    )
    row = cur.fetchone()
    return dict(row) if row else None


def build_export_plan(conn: sqlite3.Connection, plan: dict) -> list[dict]:
    """Merge task_groups_plan with DB tasks for pending groups missing task_ids."""
    groups_by_id = {g["task_group_id"]: dict(g) for g in plan.get("task_groups") or []}
    # Task groups 1..14 are always exported; groups absent from the plan get a stub entry.
    for gid in range(1, 15):
        if gid not in groups_by_id:
            groups_by_id[gid] = {"task_group_id": gid, "batch_ids": [], "status": "unknown", "task_ids": []}

    for gid, group in groups_by_id.items():
        batch_ids = list(group.get("batch_ids") or [])
        plan_tasks = {t["batch_id"]: t for t in (group.get("task_ids") or [])}

        # Infer batch ids from DB when plan only has pending stub
        if not batch_ids:
            cur = conn.cursor()
            # The batch number is whatever follows the two-character marker "批次" in the task name.
            cur.execute(
                """SELECT DISTINCT CAST(substr(name, instr(name, '批次') + 2) AS INTEGER) AS bid
                   FROM loop_task
                   WHERE name LIKE ? AND env_url LIKE '%dagent%'
                   ORDER BY bid""",
                (f"循环测试_组{gid}_批次%",),
            )
            batch_ids = [r["bid"] for r in cur.fetchall() if r["bid"]]

        merged_tasks = []
        for bid in sorted(batch_ids):
            # Prefer the task recorded in the plan; otherwise look it up in the DB by name.
            if bid in plan_tasks and plan_tasks[bid].get("task_id"):
                merged_tasks.append(plan_tasks[bid])
                continue
            db_task = resolve_task_id_from_db(conn, gid, bid)
            if db_task:
                merged_tasks.append({
                    "batch_id": bid,
                    "task_id": db_task["id"],
                    "task_name": db_task["name"],
                    "db_status": db_task["status"],
                    "total_approved": db_task["total_approved"],
                })
        group["task_ids"] = merged_tasks
        group["batch_ids"] = batch_ids

    return [groups_by_id[i] for i in sorted(groups_by_id)]


def main():
    if not DB_PATH.exists():
        print(f"Database not found: {DB_PATH}")
        sys.exit(1)

    plan = json.loads(PLAN_PATH.read_text(encoding="utf-8")) if PLAN_PATH.exists() else {}
    exported_at = datetime.now().isoformat()
    env = plan.get("environment", "")
    org_id = plan.get("org_id", "")

    # parents=True: docs/ may not exist in a fresh checkout.
    EXPORT_DIR.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA query_only=ON")
    export_groups = build_export_plan(conn, plan)

    # Markdown header: title, export time, environment, org ID, and a note that
    # only approved Q&A (qa_gen_question.status=approved) are included.
    md_parts = [
        "# 远程 dagent 循环测试 — 全部组别批次问答集汇总\n",
        f"\n> 导出时间: {exported_at}\n> 环境: {env}\n> 组织ID: {org_id}\n> 说明: 已批准问答(qa_gen_question.status=approved)\n\n---\n",
    ]
    json_export = {
        "exported_at": exported_at,
        "environment": env,
        "org_id": org_id,
        "source_db": str(DB_PATH),
        "task_groups": [],
        "summary": {"groups": 0, "batches": 0, "batches_with_data": 0, "total_questions": 0},
    }

    total_batches = batches_with_data = total_questions = 0

    for group in export_groups:
        gid = group.get("task_group_id")
        gstatus = group.get("status", "unknown")
        group_entry = {
            "task_group_id": gid,
            "status": gstatus,
            "batch_ids": group.get("batch_ids", []),
            "total_chunks": group.get("total_chunks"),
            "total_files": group.get("total_files"),
            "completed_at": group.get("completed_at"),
            "batches": [],
        }
        json_export["task_groups"].append(group_entry)
        json_export["summary"]["groups"] += 1

        # Section heading per task group ("任务组" = task group, "批次" = batches).
        md_parts.append(f"\n# 任务组 {gid}({gstatus})\n批次: {group.get('batch_ids', [])}\n")

        for ti in group.get("task_ids") or []:
            task_id, task_name, batch_id = ti["task_id"], ti.get("task_name"), ti.get("batch_id")
            total_batches += 1
            print(f"组{gid} 批次{batch_id} {task_name}", flush=True)
            rows = get_task_questions_fast(conn, task_id)
            n = len(rows)
            total_questions += n
            group_entry["batches"].append({
                "batch_id": batch_id,
                "task_id": task_id,
                "task_name": task_name,
                "chunk_count": ti.get("chunk_count"),
                "question_count": n,
                "questions": rows_to_json_questions(rows),
            })
            if n:
                batches_with_data += 1
                md_parts.append(f"\n## 批次 {batch_id}: {task_name}({n} 题)\n\n{rows_to_md(rows)}\n\n---\n")
            else:
                # "无数据" = no data for this batch.
                md_parts.append(f"\n## 批次 {batch_id}: {task_name}(无数据)\n\n---\n")

    conn.close()
    json_export["summary"].update(
        batches=total_batches,
        batches_with_data=batches_with_data,
        total_questions=total_questions,
    )

    md_path = EXPORT_DIR / "loop_dagent_全部组别批次_问答集汇总.md"
    json_path = EXPORT_DIR / "loop_dagent_全部组别批次_问答集汇总.json"
    md_path.write_text("".join(md_parts), encoding="utf-8")
    json_path.write_text(json.dumps(json_export, ensure_ascii=False, indent=2), encoding="utf-8")

    # Summary line: groups, batches with data / total batches, total questions.
    print("=" * 60)
    print(f"完成: {json_export['summary']['groups']} 组, {batches_with_data}/{total_batches} 批有数据, {total_questions} 题")
    print(md_path)
    print(json_path)


if __name__ == "__main__":
    main()
96
server/scripts/export_loop_batches_recall_md.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Export the questions of selected loop-task batches from rag_eval.db as Markdown for single-hop recall testing.
|
||||||
|
|
||||||
|
Default export: 循环测试_组1_批次1–4 plus 组2_批次5–8; the layout matches `service.loop_recall_md` and the HTTP `/api/loop/.../export` endpoint.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import sys
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
SERVER_ROOT = Path(__file__).resolve().parent.parent
|
||||||
|
sys.path.insert(0, str(SERVER_ROOT))
|
||||||
|
|
||||||
|
from service.loop_recall_md import DEFAULT_LLM_NOTE, append_recall_md_section # noqa: E402
|
||||||
|
|
||||||
|
DB_PATH = SERVER_ROOT / "data" / "rag_eval.db"
|
||||||
|
OUT_PATH = Path(__file__).resolve().parent.parent.parent / "exports" / "loop_组1组2_共8批次_召回测试问答集.md"
|
||||||
|
|
||||||
|
# 循环测试_组1_批次1–4 + 组2_批次5–8 (matches the `name` values in the database)
|
||||||
|
LOOP_TASK_IDS = (
|
||||||
|
"ed60fd467c364945b259ad8835458aa1", # 组1_批次1
|
||||||
|
"e40ddda0d73b4ba690399ebc00c2308f", # 组1_批次2
|
||||||
|
"1dbd2454ac024775a7c00dc376be308d", # 组1_批次3
|
||||||
|
"6f51d327d1aa451883e75ec6067e79d9", # 组1_批次4
|
||||||
|
"7e0a679c851547f68c63e073bd2c8716", # 组2_批次5
|
||||||
|
"9f52a2a526be477c8dfdae27ec978eda", # 组2_批次6
|
||||||
|
"8105a23ee907456ba45ebcd8f3b4ed1b", # 组2_批次7
|
||||||
|
"9d4fcbc5731347a3b5133b72488af6cc", # 组2_批次8
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
placeholders = ",".join("?" * len(LOOP_TASK_IDS))
|
||||||
|
sql = f"""
|
||||||
|
SELECT q.section_path, q.chunk_headers, q.question, q.reference_answer, q.file_name, q.chunk_id,
|
||||||
|
q.created_at
|
||||||
|
FROM qa_gen_question q
|
||||||
|
JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
|
||||||
|
JOIN loop_task lt ON lr.loop_task_id = lt.id
|
||||||
|
WHERE lr.loop_task_id IN ({placeholders})
|
||||||
|
AND q.status = 'approved'
|
||||||
|
AND (q.dup_of IS NULL OR q.dup_of = '')
|
||||||
|
ORDER BY q.chunk_headers, q.section_path, q.created_at
|
||||||
|
"""
|
||||||
|
conn = sqlite3.connect(str(DB_PATH))
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
cur = conn.execute(sql, LOOP_TASK_IDS)
|
||||||
|
|
||||||
|
by_group: dict[str, list[dict]] = defaultdict(list)
|
||||||
|
seen_q: set[tuple[str, str]] = set()
|
||||||
|
for row in cur:
|
||||||
|
d = dict(row)
|
||||||
|
gk = (d.get("chunk_headers") or "").strip() or (d.get("section_path") or "default")
|
||||||
|
key = (gk, d["question"] or "")
|
||||||
|
if key in seen_q:
|
||||||
|
continue
|
||||||
|
seen_q.add(key)
|
||||||
|
by_group[gk].append(d)
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
lines: list[str] = []
|
||||||
|
lines.append("# 循环测试组1+组2 共8批次 召回测试问答集")
|
||||||
|
lines.append("")
|
||||||
|
lines.append(
|
||||||
|
"> 由 `export_loop_batches_recall_md.py` 汇总;分组键与循环导出一致(chunk_headers 优先);"
|
||||||
|
"`##` 行在有 file_name 时为 `file_name / doc_name`。"
|
||||||
|
)
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
section_idx = 0
|
||||||
|
for gk in sorted(by_group.keys(), key=lambda x: (x or "").lower()):
|
||||||
|
rows = by_group[gk]
|
||||||
|
if not rows:
|
||||||
|
continue
|
||||||
|
section_idx += 1
|
||||||
|
file_name = (rows[0].get("file_name") or "").strip()
|
||||||
|
slice_title = (rows[0].get("chunk_headers") or "").strip() or (rows[0].get("section_path") or gk)
|
||||||
|
append_recall_md_section(
|
||||||
|
lines,
|
||||||
|
section_idx,
|
||||||
|
file_name=file_name,
|
||||||
|
slice_title=slice_title,
|
||||||
|
qa_items=rows,
|
||||||
|
meta_lines=[DEFAULT_LLM_NOTE],
|
||||||
|
)
|
||||||
|
|
||||||
|
OUT_PATH.write_text("\n".join(lines), encoding="utf-8")
|
||||||
|
print(f"Wrote {OUT_PATH} ({section_idx} sections, {len(seen_q)} unique Q&A)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
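The export script above builds a parameterized `IN (...)` clause and keeps only the first occurrence of each (group, question) pair. A minimal sketch of that pattern follows, assuming a simplified `question` table and a hypothetical `fetch_grouped` helper; neither is part of the repository.

```python
import sqlite3
from collections import defaultdict


def fetch_grouped(conn, task_ids):
    """Fetch approved questions for the given task ids, grouped by header."""
    # One "?" per id keeps the query parameterized for any list length.
    placeholders = ",".join("?" * len(task_ids))
    sql = (
        "SELECT chunk_headers, section_path, question FROM question "
        f"WHERE task_id IN ({placeholders}) AND status = 'approved' "
        "ORDER BY chunk_headers, section_path, id"
    )
    by_group = defaultdict(list)
    seen = set()
    for headers, section, question in conn.execute(sql, tuple(task_ids)):
        # Prefer chunk_headers as the group key, fall back to section_path.
        group_key = (headers or "").strip() or (section or "default")
        key = (group_key, question or "")
        if key in seen:  # keep only the first occurrence within a group
            continue
        seen.add(key)
        by_group[group_key].append(question)
    return dict(by_group)


# Hypothetical in-memory data, for illustration only.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE question (id INTEGER PRIMARY KEY, task_id TEXT, status TEXT, "
    "chunk_headers TEXT, section_path TEXT, question TEXT)"
)
conn.executemany(
    "INSERT INTO question (task_id, status, chunk_headers, section_path, question) "
    "VALUES (?,?,?,?,?)",
    [
        ("t1", "approved", "Intro", "a/b", "What is X?"),
        ("t2", "approved", "Intro", "a/b", "What is X?"),  # duplicate in the same group
        ("t2", "approved", "", "a/c", "How does Y work?"),  # falls back to section_path
        ("t3", "approved", "Intro", "a/b", "From another task"),
        ("t1", "rejected", "Intro", "a/b", "Rejected question"),
    ],
)
grouped = fetch_grouped(conn, ["t1", "t2"])
conn.close()
```

Binding the ids as parameters avoids string interpolation of values; only the number of placeholders is formatted into the SQL text.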
0
server/service/__init__.py
Normal file
201
server/service/dedup.py
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Question deduplication service.
|
||||||
|
|
||||||
|
Two-stage deduplication: regex normalization, then embedding cosine similarity:
|
||||||
|
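1) Regex normalization: strip punctuation, whitespace, and common Chinese question particles; strings that are then identical count as duplicates (sim=1.0).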
|
||||||
|
2) Embedding similarity: for questions that still differ after normalization, embed them in batches and compute cosine similarity; a score >= similarity_threshold counts as a duplicate.
|
||||||
|
|
||||||
|
Compared with LLM-based deduplication, this is faster, cheaper, deterministic, and batchable.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# Whitespace plus ASCII and CJK/full-width punctuation
|
||||||
|
_PUNCT_RE = re.compile(
|
||||||
|
r'[\s\u2000-\u206f\u3000-\u303f\uff00-\uffef'
|
||||||
|
r'\-_=+*&^%$#@!\\/?.,;:\'"`~<>()\[\]{}]+'
|
||||||
|
)
|
||||||
|
# Trailing question particles and modal particles
|
||||||
|
_TAIL_PARTICLE_RE = re.compile(r'(?:吗|呢|啊|呀|哪|么|嘛|吧)+[??。!!]*$')
|
||||||
|
# Leading polite or filler phrases
|
||||||
|
_LEADING_ASK_RE = re.compile(r'^(?:请问一下|请问|问一下|那么|然后|所以)')
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize(text: str) -> str:
|
||||||
|
"""Canonical form of a question's text (used for regex deduplication)."""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
s = text.strip().lower()
|
||||||
|
s = _LEADING_ASK_RE.sub("", s)
|
||||||
|
s = _TAIL_PARTICLE_RE.sub("", s)
|
||||||
|
s = _PUNCT_RE.sub("", s)
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def _regex_duplicate_id(
|
||||||
|
new_question: str,
|
||||||
|
existing_questions: list[tuple[str, str]],
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Return the id of an existing question whose normalized text equals the new one, else None."""
|
||||||
|
norm_new = _normalize(new_question)
|
||||||
|
if not norm_new:
|
||||||
|
return None
|
||||||
|
for qid, existing_q in existing_questions:
|
||||||
|
if _normalize(existing_q) == norm_new:
|
||||||
|
return qid
|
||||||
|
return None
|
||||||
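The normalize-then-compare step above can be sketched in isolation. This is a minimal illustration, not the module itself: the names `normalize` and `find_exact_duplicate`, the shortened particle lists, and the sample questions are all assumptions made for the example.

```python
import re

# Whitespace plus ASCII, general, CJK, and full-width punctuation.
PUNCT_RE = re.compile(
    r'[\s\u2000-\u206f\u3000-\u303f\uff00-\uffef'
    r'\-_=+*&^%$#@!\\/?.,;:\'"`~<>()\[\]{}]+'
)
# Shortened particle lists, for illustration only.
TAIL_RE = re.compile(r'(?:吗|呢|啊|呀|吧)+[??。!!]*$')
LEAD_RE = re.compile(r'^(?:请问一下|请问|问一下)')


def normalize(text):
    """Lowercase, drop leading filler and trailing particles, strip punctuation."""
    s = text.strip().lower()
    s = LEAD_RE.sub("", s)
    s = TAIL_RE.sub("", s)
    return PUNCT_RE.sub("", s)


def find_exact_duplicate(new_question, existing):
    """Return the id of the first existing question with the same normalized text."""
    target = normalize(new_question)
    if not target:
        return None
    for qid, text in existing:
        if normalize(text) == target:
            return qid
    return None


existing = [
    ("q1", "请问,退款流程是什么?"),
    ("q2", "How do I reset my password?"),
]
dup_a = find_exact_duplicate("退款流程是什么呢?", existing)
dup_b = find_exact_duplicate("how do i reset my password", existing)
dup_c = find_exact_duplicate("运费怎么算?", existing)
```

Because both sides go through the same normalization, the comparison tolerates differences in casing, punctuation width, and polite framing, while any change in the remaining characters still counts as a different question.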
|
|
||||||
|
|
||||||
|
async def _embed_texts(
|
||||||
|
embed_client,
|
||||||
|
model: str,
|
||||||
|
texts: list[str],
|
||||||
|
batch_size: int = 64,
|
||||||
|
) -> list[np.ndarray]:
|
||||||
|
"""Embed texts in batches; return L2-normalized vectors in the same order as the input."""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
out: list[np.ndarray] = []
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch = texts[i:i + batch_size]
|
||||||
|
resp = await embed_client.embeddings.create(model=model, input=batch)
|
||||||
|
for item in resp.data:
|
||||||
|
v = np.asarray(item.embedding, dtype=np.float32)
|
||||||
|
n = np.linalg.norm(v)
|
||||||
|
if n > 0:
|
||||||
|
v = v / n
|
||||||
|
out.append(v)
|
||||||
|
return out
|
||||||
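Since `_embed_texts` returns L2-normalized vectors, a plain dot product between two of them equals their cosine similarity. The sketch below shows that idea with the standard library only (the module itself uses NumPy); `l2_normalize`, `best_match`, and the toy vectors are hypothetical names and data for illustration.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length; zero vectors are returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def best_match(new_vec, candidates, threshold=0.85):
    """Return (candidate_id, similarity) for the closest candidate at or above
    the threshold, otherwise (None, 0.0)."""
    nv = l2_normalize(new_vec)
    best_id, best_sim = None, 0.0
    for cid, vec in candidates:
        # Dot product of unit vectors is the cosine of the angle between them.
        sim = sum(a * b for a, b in zip(nv, l2_normalize(vec)))
        if sim > best_sim:
            best_id, best_sim = cid, sim
    if best_id is not None and best_sim >= threshold:
        return best_id, round(best_sim, 4)
    return None, 0.0


candidates = [("q1", [1.0, 0.0]), ("q2", [0.6, 0.8])]
hit = best_match([3.0, 4.0], candidates)    # same direction as q2
miss = best_match([0.0, -1.0], candidates)  # no candidate is similar
```

Normalizing once up front, as the module does, means each later comparison costs only a dot product.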
|
|
||||||
|
|
||||||
|
async def deduplicate_questions_by_chunk(
|
||||||
|
new_questions_by_chunk: dict[str, list[dict]], # {chunk_id: [{id, question, ...}]}
|
||||||
|
existing_questions_by_chunk: dict[str, list[tuple[str, str]]], # {chunk_id: [(id, question)]}
|
||||||
|
embed_client,
|
||||||
|
embed_model: str,
|
||||||
|
similarity_threshold: float = 0.85,
|
||||||
|
max_parallel_chunks: int = 5,
|
||||||
|
stop_check: Optional[Callable[[], bool]] = None,
|
||||||
|
pause_check: Optional[Callable[[], bool]] = None, # async callable (it is awaited); resolves to True when the run should stop
|
||||||
|
on_progress: Optional[Callable] = None, # async callback(done, total)
|
||||||
|
) -> dict[str, tuple[Optional[str], float]]:
|
||||||
|
"""
|
||||||
|
Deduplicate in parallel, one chunk at a time.
|
||||||
|
|
||||||
|
For each chunk:
|
||||||
|
- First run exact matching on regex-normalized text (new vs. existing, and new vs. earlier new questions in the same batch).
|
||||||
|
- Embed the remaining questions in batches and compute cosine similarity against existing questions and earlier questions in the batch,
|
||||||
|
keep the maximum, and flag a duplicate when it is >= threshold.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{new_question_id: (dup_of_id_or_None, similarity)}
|
||||||
|
"""
|
||||||
|
chunk_sem = asyncio.Semaphore(max_parallel_chunks)
|
||||||
|
results: dict[str, tuple[Optional[str], float]] = {}
|
||||||
|
stopped = False
|
||||||
|
done_count = 0
|
||||||
|
total = sum(len(qs) for qs in new_questions_by_chunk.values())
|
||||||
|
progress_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def bump_progress(n: int):
|
||||||
|
nonlocal done_count
|
||||||
|
async with progress_lock:
|
||||||
|
done_count += n
|
||||||
|
if on_progress:
|
||||||
|
await on_progress(done_count, total)
|
||||||
|
|
||||||
|
async def dedup_one_chunk(chunk_id: str, new_questions: list[dict]):
|
||||||
|
nonlocal stopped
|
||||||
|
if stopped or (stop_check and stop_check()):
|
||||||
|
stopped = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check pause before starting chunk
|
||||||
|
if pause_check and await pause_check():
|
||||||
|
stopped = True
|
||||||
|
return
|
||||||
|
|
||||||
|
existing = existing_questions_by_chunk.get(chunk_id, [])
|
||||||
|
|
||||||
|
async with chunk_sem:
|
||||||
|
if stopped or (stop_check and stop_check()):
|
||||||
|
stopped = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check pause again after acquiring semaphore
|
||||||
|
if pause_check and await pause_check():
|
||||||
|
stopped = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Step 1: regex-normalized exact match ─────────────────────────────────
|
||||||
|
seen_norm: dict[str, str] = {} # normalized text -> id of the first new question with that form
|
||||||
|
remaining: list[dict] = []
|
||||||
|
|
||||||
|
for q in new_questions:
|
||||||
|
# Compare against existing questions after normalization
|
||||||
|
ex_id = _regex_duplicate_id(q["question"], existing)
|
||||||
|
if ex_id:
|
||||||
|
results[q["id"]] = (ex_id, 1.0)
|
||||||
|
continue
|
||||||
|
# Compare against earlier new questions in the same batch
|
||||||
|
norm = _normalize(q["question"])
|
||||||
|
if norm and norm in seen_norm:
|
||||||
|
results[q["id"]] = (seen_norm[norm], 1.0)
|
||||||
|
continue
|
||||||
|
if norm:
|
||||||
|
seen_norm[norm] = q["id"]
|
||||||
|
remaining.append(q)
|
||||||
|
|
||||||
|
# ── Step 2: embedding similarity ─────────────────────────────────
|
||||||
|
if remaining:
|
||||||
|
try:
|
||||||
|
new_texts = [q["question"] for q in remaining]
|
||||||
|
new_ids = [q["id"] for q in remaining]
|
||||||
|
existing_texts = [q for _, q in existing]
|
||||||
|
existing_ids = [qid for qid, _ in existing]
|
||||||
|
|
||||||
|
all_vecs = await _embed_texts(
|
||||||
|
embed_client, embed_model, new_texts + existing_texts
|
||||||
|
)
|
||||||
|
new_vecs = all_vecs[:len(new_texts)]
|
||||||
|
ex_vecs = all_vecs[len(new_texts):]
|
||||||
|
|
||||||
|
for i, nv in enumerate(new_vecs):
|
||||||
|
best_id: Optional[str] = None
|
||||||
|
best_sim = 0.0
|
||||||
|
# vs. existing questions
|
||||||
|
for ex_id, ev in zip(existing_ids, ex_vecs):
|
||||||
|
sim = float(np.dot(nv, ev))
|
||||||
|
if sim > best_sim:
|
||||||
|
best_sim = sim
|
||||||
|
best_id = ex_id
|
||||||
|
# vs. earlier new questions in the same batch (catches near-duplicates within the batch)
|
||||||
|
for j in range(i):
|
||||||
|
sim = float(np.dot(nv, new_vecs[j]))
|
||||||
|
if sim > best_sim:
|
||||||
|
best_sim = sim
|
||||||
|
best_id = new_ids[j]
|
||||||
|
|
||||||
|
if best_id is not None and best_sim >= similarity_threshold:
|
||||||
|
results[new_ids[i]] = (best_id, round(best_sim, 4))
|
||||||
|
else:
|
||||||
|
results[new_ids[i]] = (None, 0.0)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] Vector dedup failed for chunk {chunk_id}: {e}")
|
||||||
|
for q in remaining:
|
||||||
|
results.setdefault(q["id"], (None, 0.0))
|
||||||
|
|
||||||
|
await bump_progress(len(new_questions))
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
dedup_one_chunk(chunk_id, questions)
|
||||||
|
for chunk_id, questions in new_questions_by_chunk.items()
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
return results
|
||||||
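`deduplicate_questions_by_chunk` bounds its parallelism with a semaphore, fans out with `asyncio.gather`, and serializes progress updates behind a lock. A stripped-down sketch of that concurrency pattern follows; `process_groups`, the dummy `worker`, and the demo data are hypothetical and stand in for the real per-chunk dedup work.

```python
import asyncio


async def process_groups(groups, worker, max_parallel=5, on_progress=None):
    """Run worker(key, items) for every group with bounded parallelism."""
    sem = asyncio.Semaphore(max_parallel)
    lock = asyncio.Lock()
    results = {}
    done = 0
    total = sum(len(items) for items in groups.values())

    async def run_one(key, items):
        nonlocal done
        async with sem:  # at most max_parallel groups are processed at once
            results[key] = await worker(key, items)
        async with lock:  # serialize counter updates and progress callbacks
            done += len(items)
            if on_progress:
                await on_progress(done, total)

    await asyncio.gather(*(run_one(k, v) for k, v in groups.items()))
    return results


async def _demo():
    active = peak = 0
    progress = []

    async def worker(key, items):
        nonlocal active, peak
        active += 1
        peak = max(peak, active)  # track how many workers overlap
        await asyncio.sleep(0.01)
        active -= 1
        return len(items)

    async def on_progress(done, total):
        progress.append((done, total))

    groups = {f"chunk-{i}": ["q"] * (i + 1) for i in range(6)}
    results = await process_groups(groups, worker, max_parallel=2, on_progress=on_progress)
    return results, peak, progress


demo_results, demo_peak, demo_progress = asyncio.run(_demo())
```

The lock matters because the progress callback is awaited: without it, two groups finishing close together could interleave and report counts out of order.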
928
server/service/loop_engine.py
Normal file
@ -0,0 +1,928 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Loop task execution engine with pause/resume support.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# Fix Windows GBK encoding issue
|
||||||
|
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
||||||
|
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
||||||
|
|
||||||
|
from models.db import get_db, _id, _now
|
||||||
|
from service.loop_recall_md import DEFAULT_LLM_NOTE, append_recall_md_section
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level control dictionary for pause/resume/stop
|
||||||
|
# key=loop_task_id, value={"pause_event": asyncio.Event, "stop": bool}
|
||||||
|
_loop_controls: dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_pause(loop_task_id: str) -> bool:
|
||||||
|
"""Check if task should pause. Returns True if stopped."""
|
||||||
|
ctrl = _loop_controls.get(loop_task_id)
|
||||||
|
if not ctrl:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if ctrl["stop"]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Wait for pause_event (will block if event is cleared)
|
||||||
|
await ctrl["pause_event"].wait()
|
||||||
|
return ctrl["stop"]
|
||||||
|
|
||||||
|
|
||||||
|
def _init_control(loop_task_id: str) -> None:
|
||||||
|
"""Initialize control structure for a loop task."""
|
||||||
|
event = asyncio.Event()
|
||||||
|
event.set() # Initially not paused
|
||||||
|
_loop_controls[loop_task_id] = {
|
||||||
|
"pause_event": event,
|
||||||
|
"stop": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _clear_control(loop_task_id: str) -> None:
|
||||||
|
"""Clean up control structure."""
|
||||||
|
_loop_controls.pop(loop_task_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def pause_loop(loop_task_id: str) -> bool:
|
||||||
|
"""Pause a running loop task."""
|
||||||
|
ctrl = _loop_controls.get(loop_task_id)
|
||||||
|
if not ctrl:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Write to the database immediately so the frontend shows the "paused" status
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_task SET status='paused', paused_at=? WHERE id=?",
|
||||||
|
(_now(), loop_task_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Clear the event; the background task will stop at the next stage boundary
|
||||||
|
ctrl["pause_event"].clear()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def resume_loop(loop_task_id: str) -> bool:
|
||||||
|
"""Resume a paused loop task."""
|
||||||
|
ctrl = _loop_controls.get(loop_task_id)
|
||||||
|
if not ctrl:
|
||||||
|
return False
|
||||||
|
|
||||||
|
ctrl["pause_event"].set()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def stop_loop(loop_task_id: str) -> bool:
|
||||||
|
"""Stop a loop task permanently."""
|
||||||
|
ctrl = _loop_controls.get(loop_task_id)
|
||||||
|
if not ctrl:
|
||||||
|
return False
|
||||||
|
|
||||||
|
ctrl["stop"] = True
|
||||||
|
ctrl["pause_event"].set() # Unblock if paused
|
||||||
|
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_task SET status='stopped', finished_at=? WHERE id=?",
|
||||||
|
(_now(), loop_task_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return True
|
||||||
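The pause, resume, and stop functions above rely on one `asyncio.Event` per task plus a stop flag: a set event means "running", a cleared event parks the worker at its next checkpoint, and stopping sets the event again so a paused worker wakes up and sees the flag. A minimal self-contained sketch of that protocol, with a hypothetical `LoopControl` class and no database writes:

```python
import asyncio


class LoopControl:
    """Pause/stop signalling for one background task (illustrative only)."""

    def __init__(self):
        self.pause_event = asyncio.Event()
        self.pause_event.set()  # set = running, cleared = paused
        self.stop = False

    def pause(self):
        self.pause_event.clear()

    def resume(self):
        self.pause_event.set()

    def request_stop(self):
        self.stop = True
        self.pause_event.set()  # wake a paused worker so it can see the flag

    async def checkpoint(self):
        """Block while paused; return True if the worker should stop."""
        if self.stop:
            return True
        await self.pause_event.wait()
        return self.stop


async def worker(ctrl, log):
    for stage in ("generate", "dedup", "test"):
        if await ctrl.checkpoint():  # stage boundary
            log.append("stopped")
            return
        log.append(stage)
        await asyncio.sleep(0.05)
    log.append("done")


async def _demo():
    ctrl, log = LoopControl(), []
    task = asyncio.create_task(worker(ctrl, log))
    await asyncio.sleep(0.01)  # worker is inside the first stage
    ctrl.pause()
    await asyncio.sleep(0.15)  # worker is now parked at the next checkpoint
    paused_log = list(log)
    ctrl.request_stop()
    await task
    return paused_log, log


paused_log, final_log = asyncio.run(_demo())
```

A pause only takes effect at a checkpoint, so work already in progress finishes its current stage first; this mirrors the stage-boundary behavior described in the comments above.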
|
|
||||||
|
|
||||||
|
async def run_loop_task(
|
||||||
|
loop_task_id: str,
|
||||||
|
org_id: str,
|
||||||
|
file_ids: list[str],
|
||||||
|
judge_config_id: str,
|
||||||
|
questions_per_section: int,
|
||||||
|
quality_threshold: float,
|
||||||
|
include_multimodal: bool,
|
||||||
|
env_url: str,
|
||||||
|
d_user_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
top_k: int,
|
||||||
|
recall_top_k: int,
|
||||||
|
concurrency: int,
|
||||||
|
cross_chunk: bool,
|
||||||
|
max_rounds: int,
|
||||||
|
max_questions: int,
|
||||||
|
global_dedup: bool = False, # whether to deduplicate globally (across tasks)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Main loop execution engine.
|
||||||
|
|
||||||
|
Each round:
|
||||||
|
1. Fetch existing questions from all previous rounds
|
||||||
|
2. Generate new questions (avoiding existing angles)
|
||||||
|
3. Deduplicate (regex normalization + embedding cosine similarity)
|
||||||
|
4. Create single-jump test
|
||||||
|
5. Wait for test completion
|
||||||
|
6. Update stats and check termination conditions
|
||||||
|
"""
|
||||||
|
_init_control(loop_task_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await _do_run_loop(
|
||||||
|
loop_task_id, org_id, file_ids, judge_config_id,
|
||||||
|
questions_per_section, quality_threshold, include_multimodal,
|
||||||
|
env_url, d_user_id, agent_id, top_k, recall_top_k, concurrency, cross_chunk,
|
||||||
|
max_rounds, max_questions, global_dedup
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Mark as failed
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_task SET status='failed', error_message=? WHERE id=?",
|
||||||
|
(str(e), loop_task_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
finally:
|
||||||
|
_clear_control(loop_task_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _do_run_loop(
|
||||||
|
loop_task_id: str,
|
||||||
|
org_id: str,
|
||||||
|
file_ids: list[str],
|
||||||
|
judge_config_id: str,
|
||||||
|
questions_per_section: int,
|
||||||
|
quality_threshold: float,
|
||||||
|
include_multimodal: bool,
|
||||||
|
env_url: str,
|
||||||
|
d_user_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
top_k: int,
|
||||||
|
recall_top_k: int,
|
||||||
|
concurrency: int,
|
||||||
|
cross_chunk: bool,
|
||||||
|
max_rounds: int,
|
||||||
|
max_questions: int,
|
||||||
|
global_dedup: bool = False,
|
||||||
|
):
|
||||||
|
"""Internal loop implementation."""
|
||||||
|
|
||||||
|
# Get the loop task name and the batch's expected chunk count (aligned with chunk_batches_plan.chunk_count; used to verify that the fetch is complete)
|
||||||
|
async with get_db() as db:
|
||||||
|
task_rows = await db.execute_fetchall(
|
||||||
|
"SELECT name, expected_chunk_count FROM loop_task WHERE id=?", (loop_task_id,)
|
||||||
|
)
|
||||||
|
_tr = dict(task_rows[0]) if task_rows else {}
|
||||||
|
loop_task_name = _tr.get("name") or loop_task_id[:8]
|
||||||
|
_ecc = _tr.get("expected_chunk_count")
|
||||||
|
try:
|
||||||
|
expected_chunk_count = int(_ecc) if _ecc is not None and int(_ecc) > 0 else None
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
expected_chunk_count = None
|
||||||
|
|
||||||
|
# Get judge config for LLM client
|
||||||
|
async with get_db() as db:
|
||||||
|
cfg_rows = await db.execute_fetchall(
|
||||||
|
"SELECT * FROM judge_config WHERE id=?", (judge_config_id,)
|
||||||
|
)
|
||||||
|
if not cfg_rows:
|
||||||
|
raise ValueError("judge_config not found")
|
||||||
|
judge_cfg = dict(cfg_rows[0])
|
||||||
|
|
||||||
|
# Initialize the embedding client for dedup (embedding-similarity dedup; the LLM is no longer used)
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
embed_base = (judge_cfg.get("embed_base_url") or judge_cfg["base_url"]).rstrip("/")
|
||||||
|
embed_key = judge_cfg.get("embed_api_key") or judge_cfg["api_key"]
|
||||||
|
embed_client = AsyncOpenAI(
|
||||||
|
base_url=embed_base,
|
||||||
|
api_key=embed_key,
|
||||||
|
)
|
||||||
|
embed_model = judge_cfg.get("embed_model") or "text-embedding-3-small"
|
||||||
|
|
||||||
|
# Update status to running
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_task SET status='running' WHERE id=?",
|
||||||
|
(loop_task_id,),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
consecutive_empty_rounds = 0
|
||||||
|
|
||||||
|
def stop_check():
|
||||||
|
ctrl = _loop_controls.get(loop_task_id)
|
||||||
|
if ctrl is None or ctrl.get("stop", False):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def async_pause_check():
|
||||||
|
"""Check if paused and wait for resume. Returns True if should stop."""
|
||||||
|
ctrl = _loop_controls.get(loop_task_id)
|
||||||
|
if not ctrl:
|
||||||
|
return False
|
||||||
|
if ctrl.get("stop", False):
|
||||||
|
return True
|
||||||
|
# Check pause and wait if needed
|
||||||
|
if not ctrl["pause_event"].is_set():
|
||||||
|
await ctrl["pause_event"].wait()
|
||||||
|
if ctrl.get("stop", False):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def check_pause_between_stages() -> bool:
|
||||||
|
"""Wait at a stage boundary while paused; return True if the loop should stop."""
|
||||||
|
ctrl = _loop_controls.get(loop_task_id)
|
||||||
|
if not ctrl:
|
||||||
|
return False
|
||||||
|
if ctrl["stop"]:
|
||||||
|
return True
|
||||||
|
# A cleared pause_event means the user requested a pause
|
||||||
|
# pause_loop has already updated the database, so just wait for resume here
|
||||||
|
if not ctrl["pause_event"].is_set():
|
||||||
|
await ctrl["pause_event"].wait() # block until resumed
|
||||||
|
if ctrl["stop"]:
|
||||||
|
return True
|
||||||
|
# After resuming, set the status back to running
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_task SET status='running', paused_at=NULL WHERE id=?",
|
||||||
|
(loop_task_id,),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Decide which round and which stage to start from
|
||||||
|
# Check the last round's status to decide whether to continue it or start a new round
|
||||||
|
async with get_db() as db:
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
"""SELECT id, round_number, status, qa_gen_task_id, single_jump_task_id
|
||||||
|
FROM loop_round
|
||||||
|
WHERE loop_task_id=?
|
||||||
|
ORDER BY round_number DESC LIMIT 1""",
|
||||||
|
(loop_task_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
# resume_round: the round to continue; None means start with a new round
|
||||||
|
resume_round = None
|
||||||
|
if rows:
|
||||||
|
last = dict(rows[0])
|
||||||
|
if last["status"] != "done":
|
||||||
|
resume_round = last # continue this round from one of its stages
|
||||||
|
round_number = last["round_number"] - 1 # the loop adds 1, which lands back on this round
|
||||||
|
else:
|
||||||
|
round_number = last["round_number"] # start from the next round
|
||||||
|
else:
|
||||||
|
round_number = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Stage boundary: check for pause/stop
|
||||||
|
if await check_pause_between_stages():
|
||||||
|
return
|
||||||
|
|
||||||
|
round_number += 1
|
||||||
|
|
||||||
|
# Check max_rounds
|
||||||
|
if max_rounds > 0 and round_number > max_rounds:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check max_questions
|
||||||
|
if max_questions > 0:
|
||||||
|
async with get_db() as db:
|
||||||
|
row = await db.execute_fetchall(
|
||||||
|
"SELECT total_approved FROM loop_task WHERE id=?", (loop_task_id,)
|
||||||
|
)
|
||||||
|
current_total = (row[0]["total_approved"] or 0) if row else 0
|
||||||
|
if current_total >= max_questions:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Decide whether to continue the interrupted round or create a new one
|
||||||
|
if resume_round and resume_round["round_number"] == round_number:
|
||||||
|
# Continue the interrupted round, reusing its round_id and qa_gen_task_id
|
||||||
|
round_id = resume_round["id"]
|
||||||
|
resume_stage = resume_round["status"] # qa_generating / deduplicating / testing
|
||||||
|
qa_task_id = resume_round["qa_gen_task_id"]
|
||||||
|
resume_round = None # use only once
|
||||||
|
else:
|
||||||
|
# Create a new round
|
||||||
|
resume_stage = None
|
||||||
|
round_id = _id()
|
||||||
|
qa_task_id = None
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"""INSERT INTO loop_round
|
||||||
|
(id, loop_task_id, round_number, status, started_at)
|
||||||
|
VALUES (?,?,?,?,?)""",
|
||||||
|
(round_id, loop_task_id, round_number, "qa_generating", _now()),
|
||||||
|
)
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_task SET current_round=? WHERE id=?",
|
||||||
|
(round_number, loop_task_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# 1. Get existing questions from all previous rounds
|
||||||
|
section_existing_questions = await _get_existing_questions(loop_task_id, global_dedup=global_dedup)
|
||||||
|
all_existing_questions = []
|
||||||
|
for questions in section_existing_questions.values():
|
||||||
|
all_existing_questions.extend(questions)
|
||||||
|
|
||||||
|
# For QA generation, only pass question text (not ids)
|
||||||
|
section_existing_text = {
|
||||||
|
sp: [q["question"] for q in qs]
|
||||||
|
for sp, qs in section_existing_questions.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. QA generation stage
|
||||||
|
# When resuming from the deduplicating or testing stage, skip QA generation
|
||||||
|
if resume_stage in ("deduplicating", "testing"):
|
||||||
|
# qa_task_id already exists, so skip generation
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# QA generation must run (new round, or resuming from the qa_generating stage)
|
||||||
|
if qa_task_id is None:
|
||||||
|
qa_task_id = _id()
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"""INSERT INTO qa_gen_task
|
||||||
|
(id,name,status,judge_config_id,questions_per_section,quality_threshold,
|
||||||
|
progress,total,created_at)
|
||||||
|
VALUES (?,?,?,?,?,?,?,?,?)""",
|
||||||
|
(qa_task_id, f"{loop_task_name}-问题生成-第{round_number}轮", "pending",
|
||||||
|
judge_config_id, questions_per_section, quality_threshold,
|
||||||
|
0, 0, _now()),
|
||||||
|
)
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_round SET qa_gen_task_id=?, status='qa_generating' WHERE id=?",
|
||||||
|
(qa_task_id, round_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
else:
|
||||||
|
# resume_stage == 'qa_generating': the qa_task exists but did not finish, so run it again
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_round SET status='qa_generating' WHERE id=?",
|
||||||
|
(round_id,),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
from api.qa_gen_dagent import _run_dagent_task
|
||||||
|
try:
|
||||||
|
await _run_dagent_task(
|
||||||
|
task_id=qa_task_id,
|
||||||
|
org_id=org_id,
|
||||||
|
file_id_list=file_ids,
|
||||||
|
judge_config_id=judge_config_id,
|
||||||
|
questions_per_section=questions_per_section,
|
||||||
|
quality_threshold=quality_threshold,
|
||||||
|
include_multimodal=include_multimodal,
|
||||||
|
section_existing_questions=section_existing_text,
|
||||||
|
stop_check=stop_check,
|
||||||
|
pause_check=async_pause_check,
|
||||||
|
env_url=env_url,
|
||||||
|
expected_chunk_count=expected_chunk_count,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_round SET status='failed', finished_at=? WHERE id=?",
|
||||||
|
(_now(), round_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Stage boundary: check for pause after QA generation
|
||||||
|
if await check_pause_between_stages():
|
||||||
|
return
|
||||||
|
|
||||||
|
# 3. Deduplication stage
|
||||||
|
if resume_stage != "testing":
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_round SET status='deduplicating' WHERE id=?",
|
||||||
|
(round_id,),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Fetch new questions grouped by chunk
|
||||||
|
new_questions_by_chunk = await _get_new_questions_by_chunk(qa_task_id)
|
||||||
|
|
||||||
|
# Fetch existing questions grouped by chunk (for dedup), excluding this round's qa_task_id so the batch is not compared with itself
|
||||||
|
existing_by_chunk = await _get_existing_questions_by_chunk(
|
||||||
|
loop_task_id,
|
||||||
|
exclude_qa_task_id=qa_task_id,
|
||||||
|
global_dedup=global_dedup,
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_questions_by_chunk:
|
||||||
|
from service.dedup import deduplicate_questions_by_chunk
|
||||||
|
|
||||||
|
async def on_dedup_progress(done: int, total: int):
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_round SET dedup_progress=? WHERE id=?",
|
||||||
|
(f"{done}/{total}", round_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Deduplicate in parallel per chunk (regex normalization + embedding cosine similarity)
|
||||||
|
dup_results = await deduplicate_questions_by_chunk(
|
||||||
|
new_questions_by_chunk,
|
||||||
|
existing_by_chunk,
|
||||||
|
embed_client,
|
||||||
|
embed_model,
|
||||||
|
similarity_threshold=0.85,
|
||||||
|
max_parallel_chunks=5,
|
||||||
|
stop_check=stop_check,
|
||||||
|
pause_check=async_pause_check,
|
||||||
|
on_progress=on_dedup_progress,
|
||||||
|
)
|
||||||
|
|
||||||
|
if stop_check():
|
||||||
|
return
|
||||||
|
|
||||||
|
async with get_db() as db:
|
||||||
|
for qid, (dup_of, sim) in dup_results.items():
|
||||||
|
if dup_of:
|
||||||
|
await db.execute(
|
||||||
|
"""UPDATE qa_gen_question
|
||||||
|
SET dup_of=?, dup_similarity=?, status='rejected'
|
||||||
|
WHERE id=?""",
|
||||||
|
(dup_of, sim, qid),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Stage boundary: check for pause after deduplication
|
||||||
|
if await check_pause_between_stages():
|
||||||
|
return
|
||||||
|
|
||||||
|
# Collect this round's counts
|
||||||
|
async with get_db() as db:
|
||||||
|
counts = await db.execute_fetchall(
|
||||||
|
"""SELECT
|
||||||
|
COUNT(*) as generated,
|
||||||
|
SUM(CASE WHEN status='approved' THEN 1 ELSE 0 END) as approved,
|
||||||
|
SUM(CASE WHEN dup_of IS NOT NULL THEN 1 ELSE 0 END) as duplicates
|
||||||
|
FROM qa_gen_question WHERE task_id=?""",
|
||||||
|
(qa_task_id,),
|
||||||
|
)
|
||||||
|
gen_count = counts[0]["generated"] if counts else 0
|
||||||
|
app_count = counts[0]["approved"] if counts else 0
|
||||||
|
dup_count = counts[0]["duplicates"] if counts else 0
|
||||||
|
# SUM returns NULL when no rows match; coerce to 0 to avoid comparing against None later
|
||||||
|
gen_count = gen_count or 0
|
||||||
|
app_count = app_count or 0
|
||||||
|
dup_count = dup_count or 0
|
||||||
|
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"""UPDATE loop_round
|
||||||
|
SET generated=?, approved=?, duplicates=?, status='testing'
|
||||||
|
WHERE id=?""",
|
||||||
|
(gen_count, app_count, dup_count, round_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Convergence check
|
||||||
|
if app_count == 0:
|
||||||
|
consecutive_empty_rounds += 1
|
||||||
|
if consecutive_empty_rounds >= 2:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
consecutive_empty_rounds = 0
|
||||||
|
|
||||||
|
# 4. Recall test stage
|
||||||
|
if app_count > 0:
|
||||||
|
await _run_single_jump_for_round(
|
||||||
|
loop_task_id, loop_task_name, round_number, round_id, qa_task_id,
|
||||||
|
env_url, org_id, d_user_id, agent_id, top_k, recall_top_k, concurrency, cross_chunk
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stage boundary: check for pause after the recall test
|
||||||
|
if await check_pause_between_stages():
|
||||||
|
return
|
||||||
|
|
||||||
|
# 5. Update cumulative statistics
|
||||||
|
await _update_loop_stats(loop_task_id)
|
||||||
|
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_round SET status='done', finished_at=? WHERE id=?",
|
||||||
|
(_now(), round_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Loop finished normally
|
||||||
|
async with get_db() as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE loop_task SET status='done', finished_at=? WHERE id=?",
|
||||||
|
(_now(), loop_task_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
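The resume logic in `_do_run_loop` comes down to two decisions: which round number to start counting from, and which stages of an interrupted round still need to run. The sketch below isolates both as pure functions; `plan_start`, `stages_to_run`, and `STAGES` are hypothetical names introduced for illustration, and the stage order is inferred from the status values used above.

```python
from typing import Optional

STAGES = ["qa_generating", "deduplicating", "testing"]


def plan_start(last_round: Optional[dict]) -> tuple[int, Optional[dict]]:
    """Return (round_number before the loop's increment, round to resume or None)."""
    if last_round is None:
        return 0, None
    if last_round["status"] != "done":
        # The loop increments before use, so step back one to land on this round again.
        return last_round["round_number"] - 1, last_round
    return last_round["round_number"], None


def stages_to_run(resume_stage: Optional[str]) -> list[str]:
    """Stages still to execute, given the stage a round was interrupted in."""
    if resume_stage not in STAGES:
        return list(STAGES)  # new round (or unknown status): run everything
    return STAGES[STAGES.index(resume_stage):]


interrupted = {"round_number": 3, "status": "deduplicating"}
start, resume = plan_start(interrupted)
remaining = stages_to_run(resume["status"])
```

Keeping these decisions free of database access makes the off-by-one on `round_number` easy to test on its own.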
|
|
||||||
|
|
||||||
|
async def _get_existing_questions(loop_task_id: str, global_dedup: bool = False) -> dict[str, list[dict]]:
|
||||||
|
"""Get all approved questions, grouped by section_path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loop_task_id: Current loop task ID
|
||||||
|
global_dedup: If True, get all approved questions from database (cross-task dedup)
|
||||||
|
If False, only get questions from this loop task (default)
|
||||||
|
"""
|
||||||
|
async with get_db() as db:
|
||||||
|
if global_dedup:
|
||||||
|
# Global dedup: fetch all approved questions (across tasks)
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
"""SELECT q.id, q.section_path, q.question
|
||||||
|
FROM qa_gen_question q
|
||||||
|
WHERE q.status = 'approved'
|
||||||
|
ORDER BY q.created_at""",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Per-task dedup: fetch only this loop task's questions
|
||||||
|
rows = await db.execute_fetchall(
|
||||||
|
"""SELECT q.id, q.section_path, q.question
|
||||||
|
FROM qa_gen_question q
|
||||||
|
JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
|
||||||
|
WHERE lr.loop_task_id = ? AND q.status = 'approved'
|
||||||
|
ORDER BY q.created_at""",
|
||||||
|
(loop_task_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result: dict[str, list] = {}
|
||||||
|
for row in rows:
|
||||||
|
sp = row["section_path"]
|
||||||
|
if sp not in result:
|
||||||
|
result[sp] = []
|
||||||
|
result[sp].append({"id": row["id"], "question": row["question"]})
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_new_questions(qa_task_id: str) -> list[dict]:
    """Get all questions from a QA task."""
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT id, question FROM qa_gen_question WHERE task_id=?",
            (qa_task_id,),
        )
    return [{"id": r["id"], "question": r["question"]} for r in rows]


async def _get_new_questions_by_chunk(qa_task_id: str) -> dict[str, list[dict]]:
    """Get the new questions, grouped by chunk.

    Returns:
        {chunk_id: [{id, question, ...}]}
    """
    async with get_db() as db:
        rows = await db.execute_fetchall(
            """SELECT id, question, chunk_id, section_path
               FROM qa_gen_question
               WHERE task_id=?""",
            (qa_task_id,),
        )

    result: dict[str, list] = {}
    for row in rows:
        # Key by chunk_id, falling back to section_path, then to "default"
        chunk_id = row["chunk_id"] or row["section_path"] or "default"
        if chunk_id not in result:
            result[chunk_id] = []
        result[chunk_id].append({
            "id": row["id"],
            "question": row["question"],
            "chunk_id": row["chunk_id"],
            "section_path": row["section_path"],
        })

    return result


async def _get_existing_questions_by_chunk(
    loop_task_id: str,
    exclude_qa_task_id: str | None = None,
    global_dedup: bool = False,
) -> dict[str, list[tuple[str, str]]]:
    """Get existing questions grouped by chunk (used for duplicate checking).

    Args:
        loop_task_id: ID of the current loop task
        exclude_qa_task_id: qa_gen_task_id to exclude (the batch generated in this round,
                            so that it is not compared against itself)
        global_dedup: Whether to deduplicate globally (across tasks)

    Returns:
        {chunk_id: [(id, question)]}
    """
    async with get_db() as db:
        if global_dedup:
            # Global dedup: fetch all approved questions, excluding this round's qa_task
            if exclude_qa_task_id:
                rows = await db.execute_fetchall(
                    """SELECT id, chunk_id, section_path, question
                       FROM qa_gen_question
                       WHERE status = 'approved' AND task_id != ?
                       ORDER BY created_at""",
                    (exclude_qa_task_id,),
                )
            else:
                rows = await db.execute_fetchall(
                    """SELECT id, chunk_id, section_path, question
                       FROM qa_gen_question
                       WHERE status = 'approved'
                       ORDER BY created_at""",
                )
        else:
            # Per-task dedup: fetch only this loop task's questions, excluding this round's qa_task
            if exclude_qa_task_id:
                rows = await db.execute_fetchall(
                    """SELECT q.id, q.chunk_id, q.section_path, q.question
                       FROM qa_gen_question q
                       JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
                       WHERE lr.loop_task_id = ?
                         AND q.status = 'approved'
                         AND q.task_id != ?
                       ORDER BY q.created_at""",
                    (loop_task_id, exclude_qa_task_id),
                )
            else:
                rows = await db.execute_fetchall(
                    """SELECT q.id, q.chunk_id, q.section_path, q.question
                       FROM qa_gen_question q
                       JOIN loop_round lr ON q.task_id = lr.qa_gen_task_id
                       WHERE lr.loop_task_id = ? AND q.status = 'approved'
                       ORDER BY q.created_at""",
                    (loop_task_id,),
                )

    result: dict[str, list] = {}
    for row in rows:
        chunk_id = row["chunk_id"] or row["section_path"] or "default"
        if chunk_id not in result:
            result[chunk_id] = []
        result[chunk_id].append((row["id"], row["question"]))

    return result


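The chunk-level helpers above share one grouping rule: a row is keyed by `chunk_id`, falling back to `section_path` and finally to the literal `"default"`. A minimal, database-free sketch of that rule follows. It assumes rows behave like plain dicts, and `group_by_chunk` is a hypothetical name that does not exist in the module.

```python
from collections import defaultdict


def group_by_chunk(rows: list[dict]) -> dict[str, list[tuple[str, str]]]:
    """Group (id, question) pairs under chunk_id, then section_path, then "default"."""
    grouped: dict[str, list[tuple[str, str]]] = defaultdict(list)
    for row in rows:
        # None and empty strings both fall through to the next candidate key.
        key = row.get("chunk_id") or row.get("section_path") or "default"
        grouped[key].append((row["id"], row["question"]))
    return dict(grouped)


# Made-up sample rows, for illustration only.
rows = [
    {"id": "q1", "chunk_id": "c1", "section_path": "a/b", "question": "What is X?"},
    {"id": "q2", "chunk_id": None, "section_path": "a/b", "question": "What is Y?"},
    {"id": "q3", "chunk_id": "", "section_path": None, "question": "What is Z?"},
]
print(group_by_chunk(rows))
```

Because the fallback uses `or`, an empty-string `chunk_id` is treated the same as a missing one, which is also how the helpers above behave.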
async def _run_single_jump_for_round(
    loop_task_id: str,
    loop_task_name: str,
    round_number: int,
    round_id: str,
    qa_task_id: str,
    env_url: str,
    org_id: str,
    d_user_id: str,
    agent_id: str,
    top_k: int,
    recall_top_k: int,
    concurrency: int,
    cross_chunk: bool,
):
    """Run single-jump test for a round's approved questions."""

    def stop_check():
        # A missing control entry is treated the same as an explicit stop request
        ctrl = _loop_controls.get(loop_task_id)
        return ctrl is None or ctrl.get("stop", False)

    # Check stop before starting
    if stop_check():
        return

    # Create single-jump task
    sj_task_id = _id()
    async with get_db() as db:
        await db.execute(
            """INSERT INTO single_jump_task
               (id,name,env_url,org_id,d_user_id,agent_id,top_k,recall_top_k,concurrency,cross_chunk,
                status,progress,total,created_at,hit_top_k)
               VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
            (sj_task_id, f"{loop_task_name}-单跳测试-第{round_number}轮", env_url, org_id, d_user_id,
             agent_id, top_k, recall_top_k, concurrency, int(cross_chunk), "pending", 0, 0, _now(), top_k),
        )
        await db.execute(
            "UPDATE loop_round SET single_jump_task_id=? WHERE id=?",
            (sj_task_id, round_id),
        )
        await db.commit()

    # Build MD content from approved questions
    # Query approved questions from this QA task
    async with get_db() as db:
        rows = await db.execute_fetchall(
            """SELECT section_path, file_name, file_id, question, reference_answer, chunk_id, chunk_headers
               FROM qa_gen_question
               WHERE task_id=? AND status='approved'
               ORDER BY chunk_headers, created_at""",
            (qa_task_id,),
        )

    if not rows:
        # No approved questions, skip test
        return

    # Check stop before running test
    if stop_check():
        return

    # Group by chunk_headers (use section_path as fallback)
    from collections import defaultdict
    sections_dict: dict[str, list] = defaultdict(list)
    question_chunk_map: dict[str, str] = {}  # question -> chunk_id
    # section_key -> {file_id, file_name} from qa_gen_question
    section_file_info: dict[str, dict] = {}

    for row in rows:
        # Use chunk_headers as the grouping key if available, otherwise use section_path
        section_key = row["chunk_headers"] if row["chunk_headers"] else row["section_path"]
        if not section_key:
            section_key = row["file_name"] or "default"
        sections_dict[section_key].append({
            "question": row["question"],
            "reference_answer": row["reference_answer"],
            "file_name": row["file_name"],
            "chunk_headers": row["chunk_headers"],
            "chunk_id": row["chunk_id"],
        })
        # Build question to chunk_id mapping
        if row["chunk_id"] and row["question"]:
            question_chunk_map[row["question"]] = row["chunk_id"]
        # Remember file info for this section_key (first non-empty file_id wins)
        if row["file_id"] and section_key not in section_file_info:
            section_file_info[section_key] = {
                "file_id": row["file_id"],
                "file_name": row["file_name"] or "",
            }

    # Generate MD (shares loop_recall_md with the HTTP export and the offline scripts)
    prebuilt_file_map: dict[str, dict] = {}
    md_lines: list[str] = []

    section_index = 0
    for section_key, items in sections_dict.items():
        section_index += 1
        file_name = (items[0].get("file_name") or "").strip()
        slice_title = (items[0].get("chunk_headers") or "").strip() or section_key

        parsed_section_path = append_recall_md_section(
            md_lines,
            section_index,
            file_name=file_name,
            slice_title=slice_title,
            qa_items=items,
            meta_lines=[DEFAULT_LLM_NOTE],
        )
        finfo = section_file_info.get(section_key)
        if finfo:
            prebuilt_file_map[parsed_section_path] = {
                "file_id": finfo["file_id"],
                "file_name": finfo["file_name"],
                "match_type": "exact",
            }

    md_content = "\n".join(md_lines)

    # Check stop before running test
    if stop_check():
        return

    # Run single-jump test
    from api.single_jump import _run_task

    # Import necessary modules
    import sys
    from pathlib import Path
    sys.path.insert(0, str(Path(__file__).parent.parent / "sdk"))

    await _run_task(
        task_id=sj_task_id,
        qa_text=md_content,
        env_url=env_url,
        org_id=org_id,
        d_user_id=d_user_id,
        agent_id=agent_id,
        hit_top_k=top_k,
        recall_top_k=recall_top_k,
        concurrency=concurrency,
        cross_chunk=cross_chunk,
        prebuilt_file_map=prebuilt_file_map if prebuilt_file_map else None,
        prebuilt_chunk_map=question_chunk_map if question_chunk_map else None,
    )

    # After test completes, aggregate stats from single_jump_result
    async with get_db() as db:
        # Wait a bit for the test to complete (polling)
        max_wait = 1800  # Max 30 minutes wait for large tasks
        waited = 0
        while waited < max_wait:
            # Check stop during polling
            if stop_check():
                return

            row = await db.execute_fetchall(
                "SELECT status FROM single_jump_task WHERE id=?",
                (sj_task_id,)
            )
            if row and row[0]["status"] in ("done", "failed"):
                break
            await asyncio.sleep(2)
            waited += 2

        # Aggregate stats
        stats_rows = await db.execute_fetchall(
            """SELECT
                 COUNT(*) as tested,
                 SUM(CASE WHEN error IS NULL AND COALESCE(json_array_length(retrieved), 0) > 0 THEN 1 ELSE 0 END) as recalled,
                 SUM(CASE WHEN is_file_hit = 1 THEN 1 ELSE 0 END) as file_hit,
                 SUM(CASE WHEN is_chunk_hit = 1 THEN 1 ELSE 0 END) as chunk_hit
               FROM single_jump_result
               WHERE task_id=?""",
            (sj_task_id,)
        )

        if stats_rows:
            stats = dict(stats_rows[0])
            await db.execute(
                """UPDATE loop_round
                   SET tested=?, recalled=?, file_hit=?, chunk_hit=?
                   WHERE id=?""",
                (stats.get("tested") or 0, stats.get("recalled") or 0,
                 stats.get("file_hit") or 0, stats.get("chunk_hit") or 0,
                 round_id),
            )
            await db.commit()


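The status polling at the end of `_run_single_jump_for_round` (sleep 2 seconds, give up after 1800, bail out on a stop request) can be isolated into a small helper. The sketch below is one possible shape and not code from this repository: `wait_for_status`, `fetch_status` and `should_stop` are hypothetical names, and the fake status source stands in for the `single_jump_task` query.

```python
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable


async def wait_for_status(
    fetch_status: Callable[[], Awaitable[str | None]],
    should_stop: Callable[[], bool],
    *,
    terminal: tuple[str, ...] = ("done", "failed"),
    interval: float = 2.0,
    max_wait: float = 1800.0,
) -> str | None:
    """Poll until a terminal status, a stop request, or the deadline.

    Returns the terminal status, or None when stopped or timed out.
    """
    waited = 0.0
    while waited < max_wait:
        if should_stop():
            return None
        status = await fetch_status()
        if status in terminal:
            return status
        await asyncio.sleep(interval)
        waited += interval
    return None


async def _demo() -> str | None:
    # Fake status source: two non-terminal polls, then "done".
    statuses = iter(["pending", "running", "done"])

    async def fake_fetch() -> str | None:
        return next(statuses)

    return await wait_for_status(fake_fetch, lambda: False, interval=0.01, max_wait=1.0)


result = asyncio.run(_demo())
print(result)
```

Returning `None` on timeout lets the caller tell "finished" apart from "gave up waiting"; the original code does not make that distinction and aggregates whatever results exist once the loop ends.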
async def _update_loop_stats(loop_task_id: str):
    """Update cumulative stats from all rounds."""
    async with get_db() as db:
        # Aggregate from loop_round
        rows = await db.execute_fetchall(
            """SELECT
                 SUM(generated) as total_generated,
                 SUM(approved) as total_approved,
                 SUM(duplicates) as total_duplicates,
                 SUM(tested) as total_tested,
                 SUM(recalled) as total_recalled,
                 SUM(file_hit) as total_file_hit,
                 SUM(chunk_hit) as total_chunk_hit
               FROM loop_round WHERE loop_task_id=?""",
            (loop_task_id,),
        )

        stats = dict(rows[0]) if rows else {}

        # Count file_miss and recall_failed from single_jump_result
        miss_rows = await db.execute_fetchall(
            """SELECT
                 SUM(CASE WHEN r.is_file_hit=0 AND COALESCE(json_array_length(r.retrieved), 0)>0 THEN 1 ELSE 0 END) as file_miss,
                 SUM(CASE WHEN COALESCE(json_array_length(r.retrieved), 0)=0 AND r.error IS NULL THEN 1 ELSE 0 END) as recall_failed
               FROM single_jump_result r
               JOIN loop_round lr ON r.task_id = lr.single_jump_task_id
               WHERE lr.loop_task_id=?""",
            (loop_task_id,),
        )

        miss_stats = dict(miss_rows[0]) if miss_rows else {}

        await db.execute(
            """UPDATE loop_task SET
                 total_generated=?,
                 total_approved=?,
                 total_duplicates=?,
                 total_tested=?,
                 total_recalled=?,
                 total_file_hit=?,
                 total_file_miss=?,
                 total_recall_failed=?,
                 total_chunk_hit=?
               WHERE id=?""",
            (
                stats.get("total_generated") or 0,
                stats.get("total_approved") or 0,
                stats.get("total_duplicates") or 0,
                stats.get("total_tested") or 0,
                stats.get("total_recalled") or 0,
                stats.get("total_file_hit") or 0,
                miss_stats.get("file_miss") or 0,
                miss_stats.get("recall_failed") or 0,
                stats.get("total_chunk_hit") or 0,
                loop_task_id,
            ),
        )
        await db.commit()


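The repeated `stats.get(...) or 0` in `_update_loop_stats` matters because SQLite's `SUM` returns NULL, not 0, when no rows match or every value is NULL. A small sketch with the standard-library `sqlite3` module shows the effect; the two-column `loop_round` table here is a simplified stand-in, not the project's real schema, and `totals` is a hypothetical helper.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
# Simplified stand-in table; the real loop_round has more columns.
conn.execute("CREATE TABLE loop_round (loop_task_id TEXT, tested INTEGER, file_hit INTEGER)")


def totals(loop_task_id: str) -> dict[str, int]:
    row = conn.execute(
        "SELECT SUM(tested) AS total_tested, SUM(file_hit) AS total_file_hit "
        "FROM loop_round WHERE loop_task_id=?",
        (loop_task_id,),
    ).fetchone()
    # SUM over zero rows (or only NULLs) yields NULL, i.e. None in Python.
    return {key: value or 0 for key, value in dict(row).items()}


empty = totals("t1")  # no rounds recorded yet
conn.executemany(
    "INSERT INTO loop_round VALUES (?,?,?)",
    [("t1", 10, 7), ("t1", 5, None)],
)
filled = totals("t1")
print(empty, filled)
```

An alternative with the same effect would be `COALESCE(SUM(tested), 0)` in the SQL itself; the source keeps the defaulting on the Python side.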
async def recover_orphaned_loops():
    """On startup, set any 'running' loop tasks to 'paused'."""
    async with get_db() as db:
        rows = await db.execute_fetchall(
            "SELECT id FROM loop_task WHERE status='running'"
        )
        for row in rows:
            await db.execute(
                "UPDATE loop_task SET status='paused', paused_at=? WHERE id=?",
                (_now(), row["id"]),
            )
        await db.commit()
99
server/service/loop_recall_md.py
Normal file
@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
"""
Loop-test helpers: generate Markdown fragments that match the single-jump recall parser.

Conventions (aligned with rag_eval.single_jump.parser):
- when `file_name` is present, the `##` line is `{file_name} / {doc_name}`, which helps FileMapper;
- the full Chinese slice title is written on the `# 第N章` line and the `> 原始切片标题` line;
- each Q&A pair carries an optional `> chunk_id:` line, used for chunk-level hit verification.
"""
from __future__ import annotations

import re
from collections.abc import Callable, Iterable

DEFAULT_LLM_NOTE = "> 由 LLM 自动生成的问答对"


def doc_name_from_file_name(file_name: str) -> str:
    """Base name of a knowledge-base path without its extension; right-hand side of `## xxx.md / xxx`."""
    fn = (file_name or "").strip()
    if not fn:
        return "document"
    base = fn.rsplit("/", 1)[-1]
    return base.rsplit(".", 1)[0] if "." in base else base


def chapter_title_suffix(slice_title: str, max_len: int = 80) -> str:
    """Short display title for the chapter line `# 第N章 …`."""
    s = (slice_title or "").strip() or "未命名切片"
    s = re.sub(r"\s+", " ", s)
    return s if len(s) <= max_len else s[: max_len - 1] + "…"


def recall_parsed_section_path(file_name: str, slice_title: str) -> tuple[str, str]:
    """
    Returns:
        parsed: body of the `##` line (the parsed section_path, identical to the prebuilt_file_map key)
        doc_suffix: last name segment, used in `# N. {doc_suffix}_Document`
    """
    fn = (file_name or "").strip()
    st = (slice_title or "").strip() or "default"
    if fn:
        doc_name = doc_name_from_file_name(fn)
        parsed = f"{fn} / {doc_name}"
        doc_suffix = doc_name.split("/")[-1]
        return parsed, doc_suffix
    raw_doc = st.split("/")[-1].strip() if "/" in st else st
    parsed = f"{st} / {raw_doc}"
    doc_suffix = raw_doc
    return parsed, doc_suffix


def append_recall_md_section(
    lines: list[str],
    section_index: int,
    *,
    file_name: str,
    slice_title: str,
    qa_items: list[dict],
    meta_lines: list[str] | None = None,
    after_answer_lines: Callable[[int, dict], Iterable[str]] | None = None,
) -> str:
    """
    Append one complete section to lines and return the section_path used for parsing
    (identical to the `##` line).

    qa_items: each item has question and reference_answer; chunk_id is optional.

    meta_lines: written after `# N. xxx_Document` and before `---`; when None, only
    DEFAULT_LLM_NOTE is written.

    after_answer_lines: extra lines inserted after each `**An:**` line and before the
    blank line that ends that Q&A block.
    """
    parsed, doc_suffix = recall_parsed_section_path(file_name, slice_title)
    ch = chapter_title_suffix(slice_title)
    lines.append(f"# 第{section_index}章 {ch}")
    lines.append(f"## {parsed}")
    st = (slice_title or "").strip()
    if st:
        lines.append(f"> 原始切片标题: {st}")
    lines.append(f"# {section_index}. {doc_suffix}_Document")
    for meta in meta_lines if meta_lines is not None else [DEFAULT_LLM_NOTE]:
        lines.append(meta)
    lines.append("---")
    lines.append("")

    for i, item in enumerate(qa_items, 1):
        lines.append(f"## Q{i}: {item['question']}")
        cid = (item.get("chunk_id") or "").strip()
        if cid:
            lines.append(f"> chunk_id: {cid}")
        lines.append(f"**A{i}:** {item['reference_answer']}")
        if after_answer_lines:
            for L in after_answer_lines(i, item):
                if L:
                    lines.append(L)
        lines.append("")

    lines.append("---")
    lines.append("")
    return parsed
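To make the naming helpers in `loop_recall_md.py` concrete, the standalone sketch below copies `doc_name_from_file_name` and `chapter_title_suffix` from the file above and runs them on made-up inputs. The sample path and titles are illustrative only and do not come from the project.

```python
import re


def doc_name_from_file_name(file_name: str) -> str:
    fn = (file_name or "").strip()
    if not fn:
        return "document"
    base = fn.rsplit("/", 1)[-1]  # drop any directory part
    return base.rsplit(".", 1)[0] if "." in base else base  # drop only the last extension


def chapter_title_suffix(slice_title: str, max_len: int = 80) -> str:
    s = (slice_title or "").strip() or "未命名切片"
    s = re.sub(r"\s+", " ", s)  # collapse newlines and repeated spaces
    return s if len(s) <= max_len else s[: max_len - 1] + "…"


doc = doc_name_from_file_name("kb/manuals/install.guide.md")
short = chapter_title_suffix("  Install \n steps  ")
clipped = chapter_title_suffix("x" * 100, max_len=10)
print(doc, short, clipped, sep=" | ")
```

Only the final extension is removed, so a dotted base name such as `install.guide.md` keeps its inner dot, and a truncated title always fits within `max_len` characters including the trailing ellipsis.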
305
server/service/task_service.py
Normal file
@ -0,0 +1,305 @@
import asyncio
import json
import sys
from pathlib import Path

# Make sdk and server root importable
_server_root = Path(__file__).parent.parent
sys.path.insert(0, str(_server_root))
sys.path.insert(0, str(_server_root.parent / "sdk"))

from rag_eval.adapters.dagent import DagentAdapter
from rag_eval.judge.openai_compatible import OpenAICompatibleJudge
from rag_eval.runner import EvalRunner, RunConfig
from rag_eval.dataset.schema import EvalDataset, EvalSample
from rag_eval.dataset.generator import DatasetGenerator
from models.db import get_db, _now, _id


async def _get_platform_config(db, config_id: str) -> dict:
    rows = await db.execute_fetchall(
        "SELECT * FROM platform_config WHERE id=?", (config_id,)
    )
    if not rows:
        raise ValueError(f"Platform config {config_id} not found")
    return dict(rows[0])


async def _get_judge_config(db, config_id: str) -> dict:
    rows = await db.execute_fetchall(
        "SELECT * FROM judge_config WHERE id=?", (config_id,)
    )
    if not rows:
        raise ValueError(f"Judge config {config_id} not found")
    return dict(rows[0])


async def _load_dataset(db, dataset_id: str) -> EvalDataset:
    ds_rows = await db.execute_fetchall(
        "SELECT * FROM eval_dataset WHERE id=?", (dataset_id,)
    )
    if not ds_rows:
        raise ValueError(f"Dataset {dataset_id} not found")
    ds = dict(ds_rows[0])

    sample_rows = await db.execute_fetchall(
        "SELECT * FROM eval_sample WHERE dataset_id=?", (dataset_id,)
    )
    samples = [
        EvalSample(
            id=r["id"],
            question=r["question"],
            reference_answer=r["reference_answer"],
            relevant_chunk_ids=json.loads(r["relevant_chunk_ids"] or "[]"),
            knowledge_hub_id=r["knowledge_hub_id"],
            source_file_id=r["source_file_id"],
            metadata=json.loads(r["metadata"] or "{}"),
        )
        for r in sample_rows
    ]
    return EvalDataset(id=ds["id"], name=ds["name"], description=ds.get("description", ""), samples=samples)


async def run_eval_task(task_id: str):
    """Background coroutine: runs the full eval loop for a task."""
    async with get_db() as db:
        task_rows = await db.execute_fetchall(
            "SELECT * FROM eval_task WHERE id=?", (task_id,)
        )
        if not task_rows:
            return
        task = dict(task_rows[0])

        await db.execute(
            "UPDATE eval_task SET status='running' WHERE id=?", (task_id,)
        )
        await db.commit()

        try:
            platform_cfg = await _get_platform_config(db, task["platform_config_id"])
            judge_cfg = await _get_judge_config(db, task["judge_config_id"])
            dataset = await _load_dataset(db, task["dataset_id"])
        except Exception as exc:
            await db.execute(
                "UPDATE eval_task SET status='failed', error_message=? WHERE id=?",
                (str(exc), task_id),
            )
            await db.commit()
            return

    adapter = DagentAdapter(
        base_url=platform_cfg["base_url"],
        org_id=platform_cfg.get("org_id", ""),
        token=platform_cfg.get("token", ""),
    )
    judge = OpenAICompatibleJudge(
        base_url=judge_cfg["base_url"],
        api_key=judge_cfg["api_key"],
        model=judge_cfg["model"],
        embed_base_url=judge_cfg.get("embed_base_url", ""),
        embed_api_key=judge_cfg.get("embed_api_key", ""),
        embed_model=judge_cfg.get("embed_model", "text-embedding-3-small"),
    )
    run_cfg = RunConfig(
        agent_id=task["agent_id"],
        knowledge_hub_id=task["knowledge_hub_id"],
        top_k=task["top_k"],
        eval_retrieval=bool(task["eval_retrieval"]),
        eval_generation=bool(task["eval_generation"]),
        selected_metrics=json.loads(task.get("selected_metrics") or "[]") or None,
        file_id_list=json.loads(task["file_id_list"] or "[]") or None,
        concurrency=task["concurrency"],
    )

    finished = 0
    total = len(dataset.samples)

    async with get_db() as db:
        await db.execute(
            "UPDATE eval_task SET total=? WHERE id=?", (total, task_id)
        )
        await db.commit()

    async def _progress(done, _total):
        nonlocal finished
        finished = done
        async with get_db() as db:
            await db.execute(
                "UPDATE eval_task SET progress=? WHERE id=?", (done, task_id)
            )
            await db.commit()

    runner = EvalRunner(adapter=adapter, judge=judge)

    try:
        report = await runner.run(dataset, run_cfg, progress_cb=lambda d, t: asyncio.create_task(_progress(d, t)))
    except Exception as exc:
        async with get_db() as db:
            await db.execute(
                "UPDATE eval_task SET status='failed', error_message=? WHERE id=?",
                (str(exc), task_id),
            )
            await db.commit()
        return

    # Generate interpretation using judge LLM
    interpretation = ""
    try:
        # Format metrics for prompt
        def fmt(val, fmt_str='.2%'):
            return f"{val:{fmt_str}}" if val is not None else 'N/A'

        interp_prompt = f"""请对以下 RAG 系统评测结果进行解读分析,用 2-3 段中文总结:

评测样本数:{report.sample_count}

检索层指标:
- 命中率 (Hit Rate): {fmt(report.avg_hit_rate)}
- 平均倒数排名 (MRR): {fmt(report.avg_mrr, '.4f')}
- 归一化折损累积增益 (NDCG): {fmt(report.avg_ndcg, '.4f')}
- 上下文精确度 (Context Precision): {fmt(report.avg_context_precision)}
- 上下文召回率 (Context Recall): {fmt(report.avg_context_recall)}

生成层指标:
- 忠实度 (Faithfulness): {fmt(report.avg_faithfulness)}
- 回答相关性 (Answer Relevance): {fmt(report.avg_answer_relevance, '.4f')}
- 回答正确性 (Answer Correctness): {fmt(report.avg_answer_correctness, '.4f')}
- 可溯源性 (Groundedness): {fmt(report.avg_groundedness)}

综合指标:
- RAG Score: {fmt(report.rag_score)}
- 幻觉发生率: {fmt(report.hallucination_rate)}

请从以下角度分析:
1. 整体表现评价(优势和亮点)
2. 存在的主要问题和不足
3. 具体改进建议

要求:语言简洁专业,每段 2-3 句话,总字数 200-300 字。"""

        interpretation = await judge._call(interp_prompt)
    except Exception:
        interpretation = "评测结果解释生成失败"

    # Persist results and report
    async with get_db() as db:
        for r in report.results:
            await db.execute(
                """INSERT INTO eval_result
                   (id,task_id,sample_id,question,reference_answer,retrieved_chunks,
                    agent_answer,hit_rate,mrr,ndcg,context_precision,context_recall,
                    faithfulness,answer_relevance,answer_correctness,groundedness,
                    latency_ms,judge_detail,error)
                   VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
                (
                    _id(), task_id, r.sample_id, r.question, r.reference_answer,
                    json.dumps(r.retrieved_chunks, ensure_ascii=False),
                    r.agent_answer, r.hit_rate, r.mrr, r.ndcg,
                    r.context_precision, r.context_recall,
                    r.faithfulness, r.answer_relevance, r.answer_correctness,
                    r.groundedness, r.latency_ms,
                    json.dumps(r.judge_detail, ensure_ascii=False),
                    r.error,
                ),
            )

        await db.execute(
            """INSERT OR REPLACE INTO eval_report
               (id,task_id,sample_count,avg_hit_rate,avg_mrr,avg_ndcg,
                avg_context_precision,avg_context_recall,avg_faithfulness,
                avg_answer_relevance,avg_answer_correctness,avg_groundedness,
                rag_score,hallucination_rate,interpretation,created_at)
               VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
            (
                _id(), task_id, report.sample_count,
                report.avg_hit_rate, report.avg_mrr, report.avg_ndcg,
                report.avg_context_precision, report.avg_context_recall,
                report.avg_faithfulness, report.avg_answer_relevance,
                report.avg_answer_correctness, report.avg_groundedness,
                report.rag_score, report.hallucination_rate, interpretation, _now(),
            ),
        )

        await db.execute(
            "UPDATE eval_task SET status='done', finished_at=?, progress=total WHERE id=?",
            (_now(), task_id),
        )
        await db.commit()


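The nested `fmt` helper in `run_eval_task` keeps the interpretation prompt from failing when a metric was not computed. The standalone sketch below reproduces it with type hints added; the sample values are made up.

```python
from __future__ import annotations


def fmt(val: float | None, fmt_str: str = ".2%") -> str:
    """Format a metric with the given format spec, or 'N/A' when it is missing."""
    # `is not None` rather than truthiness, so a genuine 0.0 is still formatted.
    return f"{val:{fmt_str}}" if val is not None else "N/A"


hit_rate = fmt(0.8567)
mrr = fmt(0.5, ".4f")
zero = fmt(0.0)
missing = fmt(None)
print(hit_rate, mrr, zero, missing)
```

Without the `None` guard, formatting a missing metric with `.2%` would raise a `TypeError`, and the surrounding `except Exception` would then replace the whole interpretation with the fallback message.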
async def run_generate_task(params: dict):
    """Background coroutine: generates dataset samples via LLM."""
    gen_task_id = params.get("gen_task_id")

    async def _update_gen_progress(done: int, total: int):
        if not gen_task_id:
            return
        async with get_db() as db:
            await db.execute(
                "UPDATE generate_task SET progress=?, total=?, status='running' WHERE id=?",
                (done, total, gen_task_id),
            )
            await db.commit()

    async with get_db() as db:
        platform_cfg = await _get_platform_config(db, params["platform_config_id"])
        judge_cfg = await _get_judge_config(db, params["judge_config_id"])

    adapter = DagentAdapter(
        base_url=platform_cfg["base_url"],
        org_id=platform_cfg.get("org_id", ""),
        token=platform_cfg.get("token", ""),
    )
    judge = OpenAICompatibleJudge(
        base_url=judge_cfg["base_url"],
        api_key=judge_cfg["api_key"],
        model=judge_cfg["model"],
        embed_base_url=judge_cfg.get("embed_base_url", ""),
        embed_api_key=judge_cfg.get("embed_api_key", ""),
        embed_model=judge_cfg.get("embed_model", "text-embedding-3-small"),
    )

    try:
        gen = DatasetGenerator(judge=judge, adapter=adapter)
        dataset = await gen.generate(
            knowledge_hub_id=params["knowledge_hub_id"],
            file_id_list=params["file_id_list"],
            questions_per_chunk=params.get("questions_per_chunk", 2),
            max_chunks=params.get("max_chunks", 50),
            chunk_ids=params.get("chunk_ids") or None,
            progress_cb=_update_gen_progress,
        )
    except Exception as exc:
        if gen_task_id:
            async with get_db() as db:
                await db.execute(
                    "UPDATE generate_task SET status='failed', error_message=?, finished_at=? WHERE id=?",
                    (str(exc), _now(), gen_task_id),
                )
                await db.commit()
        return

    async with get_db() as db:
        for s in dataset.samples:
            await db.execute(
                """INSERT INTO eval_sample
                   (id,dataset_id,question,reference_answer,relevant_chunk_ids,
                    knowledge_hub_id,source_file_id,metadata)
                   VALUES (?,?,?,?,?,?,?,?)""",
                (
                    s.id, params["dataset_id"], s.question, s.reference_answer,
                    json.dumps(s.relevant_chunk_ids, ensure_ascii=False),
                    s.knowledge_hub_id, s.source_file_id,
                    json.dumps(s.metadata, ensure_ascii=False),
                ),
            )
        await db.execute(
            "UPDATE eval_dataset SET sample_count=sample_count+? WHERE id=?",
            (len(dataset.samples), params["dataset_id"]),
        )
        if gen_task_id:
            await db.execute(
                "UPDATE generate_task SET status='done', progress=total, finished_at=? WHERE id=?",
                (_now(), gen_task_id),
            )
        await db.commit()
4
server/start_server.bat
Normal file
@ -0,0 +1,4 @@
@echo off
chcp 65001 >nul
set PYTHONIOENCODING=utf-8
python -m uvicorn main:app --host 0.0.0.0 --port 8021