# File listing metadata: 228 lines, 6.1 KiB, Python

import json
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable

import aiosqlite
# Location of the SQLite database file: data/rag_eval.db under the parent of
# the directory that contains this module.
DB_PATH = Path(__file__).parent.parent / "data" / "rag_eval.db"
# SQL schema script, expected to sit in the same directory as this module.
SCHEMA_PATH = Path(__file__).parent / "schema.sql"
from contextlib import asynccontextmanager
@asynccontextmanager
async def get_db():
    """Open the database at ``DB_PATH`` and yield a ready-to-use connection.

    The directory that holds the database file is created first if it is
    missing. The connection is opened with a 30 second timeout, uses
    ``aiosqlite.Row`` as its row factory, and has three PRAGMAs applied
    before it is handed to the caller: WAL journal mode, a 30000 ms busy
    timeout, and ``synchronous=NORMAL``.

    Nothing is committed here; the connection is released when the
    surrounding ``async with`` block exits.
    """
    data_dir = DB_PATH.parent
    data_dir.mkdir(parents=True, exist_ok=True)

    # Applied in this order on every new connection.
    startup_pragmas = (
        "PRAGMA journal_mode=WAL",
        "PRAGMA busy_timeout=30000",
        "PRAGMA synchronous=NORMAL",
    )

    async with aiosqlite.connect(DB_PATH, timeout=30.0) as conn:
        conn.row_factory = aiosqlite.Row
        for statement in startup_pragmas:
            await conn.execute(statement)
        yield conn
async def init_db():
    """Create the schema and bring the database up to date.

    Inside a single connection from ``get_db`` this reads the SQL script at
    ``SCHEMA_PATH`` (UTF-8), executes it, runs ``_run_migrations``, and then
    commits.
    """
    async with get_db() as conn:
        # The script is read only after the connection is open, so the
        # database directory exists even if reading the schema fails.
        await conn.executescript(SCHEMA_PATH.read_text(encoding="utf-8"))
        await _run_migrations(conn)
        await conn.commit()
async def _run_migrations(db: aiosqlite.Connection):
    """Add columns that older local databases may be missing.

    Every entry in ``steps`` pairs a table name with the ``(column, SQL
    definition)`` tuples that table must have. The entries are passed to
    ``_ensure_columns`` one at a time, in the order listed. The same table
    can appear more than once; each appearance is a separate step.

    Args:
        db: Open connection on which the migrations are executed. This
            function does not commit.
    """
    steps = (
        # single_jump_result: file/chunk hit tracking columns
        (
            "single_jump_result",
            (
                ("file_name", "TEXT"),
                ("match_type", "TEXT"),
                ("is_file_hit", "INTEGER DEFAULT 0"),
                ("expected_chunk_id", "TEXT"),
                ("is_chunk_hit", "INTEGER DEFAULT 0"),
                ("chunk_hit_rank", "INTEGER"),
                ("retrieved_chunk_ids", "TEXT"),
            ),
        ),
        # single_jump_task: progress and completion columns
        (
            "single_jump_task",
            (
                ("progress", "INTEGER DEFAULT 0"),
                ("total", "INTEGER DEFAULT 0"),
                ("error_message", "TEXT"),
                ("finished_at", "TEXT"),
                ("md_content", "TEXT"),
            ),
        ),
        # qa_gen tables migration
        (
            "qa_gen_question",
            (
                ("source_chunk", "TEXT"),
                ("quality_score", "REAL"),
                ("quality_detail", "TEXT"),
                ("dup_of", "TEXT"),
                ("dup_similarity", "REAL"),
                ("embedding", "TEXT"),
                ("updated_at", "TEXT"),
                ("file_id", "TEXT"),
                ("file_name", "TEXT"),
                ("chunk_id", "TEXT"),
                ("chunk_headers", "TEXT"),
                ("chunk_content_preview", "TEXT"),
            ),
        ),
        (
            "qa_gen_task",
            (
                ("approved", "INTEGER DEFAULT 0"),
            ),
        ),
        (
            "loop_round",
            (
                ("dedup_progress", "TEXT"),
            ),
        ),
        # multi_hop_gen_task: add new columns for dagent source
        (
            "multi_hop_gen_task",
            (
                ("source", "TEXT NOT NULL DEFAULT 'file'"),
                ("org_id", "TEXT"),
                ("file_ids", "TEXT DEFAULT ''"),
            ),
        ),
        # multi_hop_task: add judge_config_id and llm_type
        (
            "multi_hop_task",
            (
                ("judge_config_id", "TEXT DEFAULT ''"),
                ("llm_type", "TEXT DEFAULT 'deepseek_v3'"),
            ),
        ),
        # multi_hop_task: add agent_id
        (
            "multi_hop_task",
            (
                ("agent_id", "TEXT DEFAULT ''"),
            ),
        ),
        # multi_hop_result: add actual_hops, agent_answer and chunk hit counters
        (
            "multi_hop_result",
            (
                ("actual_hops", "TEXT DEFAULT '[]'"),
                ("agent_answer", "TEXT DEFAULT ''"),
                ("chunk_hit_count", "INTEGER DEFAULT 0"),
                ("full_chunk_hit", "INTEGER DEFAULT 0"),
                ("partial_chunk_hit", "INTEGER DEFAULT 0"),
            ),
        ),
        # multi_hop_gen_task: add prompt_template_id
        (
            "multi_hop_gen_task",
            (
                ("prompt_template_id", "TEXT"),
            ),
        ),
        # loop_task: add global_dedup flag
        (
            "loop_task",
            (
                ("global_dedup", "INTEGER DEFAULT 0"),
            ),
        ),
        # loop_round: add chunk_hit for chunk-level hit tracking
        (
            "loop_round",
            (
                ("chunk_hit", "INTEGER DEFAULT 0"),
            ),
        ),
        # loop_task: add total_chunk_hit for chunk-level aggregation
        (
            "loop_task",
            (
                ("total_chunk_hit", "INTEGER DEFAULT 0"),
            ),
        ),
        # single_jump_task: add recall_top_k and hit_top_k
        (
            "single_jump_task",
            (
                ("recall_top_k", "INTEGER DEFAULT 64"),
                ("hit_top_k", "INTEGER DEFAULT 64"),
            ),
        ),
        # single_jump_result: add hit_top_k for chunk hit calculation
        (
            "single_jump_result",
            (
                ("hit_top_k", "INTEGER DEFAULT 64"),
            ),
        ),
        # single_jump_result: add raw_chunk_headers for original section title
        (
            "single_jump_result",
            (
                ("raw_chunk_headers", "TEXT"),
            ),
        ),
        # loop_task: add recall_top_k
        (
            "loop_task",
            (
                ("recall_top_k", "INTEGER DEFAULT 64"),
            ),
        ),
        # loop_task: total number of chunks in the batch plan, used to verify
        # that the fetch is complete (matches chunk_batches_plan.chunk_count)
        (
            "loop_task",
            (
                ("expected_chunk_count", "INTEGER"),
            ),
        ),
    )

    # Steps run strictly in listed order; a failure stops the remaining ones.
    for table, wanted_columns in steps:
        await _ensure_columns(db, table, wanted_columns)
async def _ensure_columns(
    db: aiosqlite.Connection,
    table_name: str,
    columns: Iterable[tuple[str, str]],
):
    """Add every column in ``columns`` that ``table_name`` does not have yet.

    The current columns are read with ``PRAGMA table_info``; each missing
    column is then added with ``ALTER TABLE ... ADD COLUMN``, in the order
    given. Columns that already exist are left untouched.

    Args:
        db: Open connection. Rows must be addressable by column name
            (``row["name"]``), e.g. via ``aiosqlite.Row``.
        table_name: Table to inspect and alter. It is interpolated directly
            into the SQL text, so it must be a trusted identifier.
        columns: ``(column_name, column_definition)`` pairs. Both parts are
            interpolated directly into the SQL text and must be trusted.

    This function does not commit; that is left to the caller.
    """
    rows = await db.execute_fetchall(f"PRAGMA table_info({table_name})")
    # SQLite treats column names case-insensitively, so compare lowercased
    # names; otherwise "File_Name" vs "file_name" would trigger an ALTER TABLE
    # that fails with "duplicate column name".
    known = {row["name"].lower() for row in rows}
    for column_name, column_def in columns:
        key = column_name.lower()
        if key in known:
            continue
        await db.execute(
            f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_def}"
        )
        # Remember the addition so a name repeated later in ``columns`` is
        # skipped instead of being added a second time.
        known.add(key)
def _now() -> str:
    """Return the current UTC time as an ISO-8601 string without an offset.

    The result has the form ``YYYY-MM-DDTHH:MM:SS[.ffffff]`` and carries no
    ``+00:00`` suffix.
    """
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # reading and drop tzinfo so the text format stays exactly as before.
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
def _id() -> str:
    """Return a new random identifier: a UUID4 as 32 lowercase hex digits."""
    fresh = uuid.uuid4()
    return fresh.hex