mmeb/baidu_vdb_with_index.py
2025-09-01 11:24:01 +00:00

199 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
带索引的百度VDB测试 - 在表创建完成后添加索引
"""
import time
import logging
from baidu_vdb_minimal import BaiduVDBMinimal
import numpy as np
import pymochow
from pymochow.configuration import Configuration
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.model.schema import VectorIndex, HNSWParams
from pymochow.model.enum import IndexType, MetricType
from pymochow.model.table import VectorTopkSearchRequest, VectorSearchConfig, FloatVector
# Logging setup: emit INFO-level progress messages; this module logs through
# its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
class BaiduVDBWithIndex(BaiduVDBMinimal):
    """VDB client that adds an HNSW vector index to an already-created table.

    Extends ``BaiduVDBMinimal`` (imported from ``baidu_vdb_minimal``) and uses
    the ``test_table``, ``vector_dimension`` and ``store_vectors`` members it
    provides. The ``vector``, ``id`` and ``content`` field names used below are
    assumed to match the schema the base class creates -- TODO confirm.
    """

    def wait_table_ready(self, max_wait_seconds=60):
        """Poll until the table accepts writes.

        Readiness is probed by actually storing one random vector with content
        ``"test"``; note that this probe row is left in the table.

        Args:
            max_wait_seconds: Maximum number of probe attempts, spaced about
                one second apart.

        Returns:
            True as soon as a probe write succeeds; False once the attempts are
            exhausted or on any error other than "Table Not Ready".
        """
        logger.info("等待表创建完成...")
        for i in range(max_wait_seconds):
            try:
                stats = self.test_table.stats()
                # Only dict-like stats expose .get(); log anything else as-is
                # instead of letting a logging line abort the readiness probe.
                status = stats.get('msg', 'Unknown') if hasattr(stats, 'get') else stats
                logger.info(f"表状态检查 {i+1}/{max_wait_seconds}: {status}")
                # Probe readiness with a real single-row write.
                test_vector = np.random.rand(self.vector_dimension).astype(np.float32)
                test_ids = self.store_vectors(["test"], test_vector.reshape(1, -1))
                if test_ids:
                    logger.info("✅ 表已就绪,可以存储数据")
                    return True
            except Exception as e:
                if "Table Not Ready" not in str(e):
                    logger.error(f"其他错误: {e}")
                    break
                logger.info(f"表仍在创建中... ({i+1}/{max_wait_seconds})")
            # Wait after EVERY unsuccessful attempt. store_vectors may signal
            # failure with an empty result instead of raising; without this
            # pause that path would spend all attempts instantly.
            if i + 1 < max_wait_seconds:
                time.sleep(1)
        logger.warning("⚠️ 表可能仍未就绪")
        return False

    def add_vector_index(self):
        """Create an HNSW index (cosine metric, m=16, efconstruction=200) on the
        ``vector`` field of ``test_table``.

        Returns:
            True if the SDK call succeeded, False if it raised (the error is
            logged).
        """
        try:
            logger.info("为表添加向量索引...")
            vector_index = VectorIndex(
                index_name="vector_hnsw_idx",
                index_type=IndexType.HNSW,
                field="vector",
                metric_type=MetricType.COSINE,
                params=HNSWParams(m=16, efconstruction=200),
                auto_build=True,
            )
            # The pymochow Table API creates indexes on an existing table via
            # create_indexes(list). Fall back to add_index only for table
            # objects that do not provide it.
            create_indexes = getattr(self.test_table, "create_indexes", None)
            if create_indexes is not None:
                create_indexes([vector_index])
            else:
                self.test_table.add_index(vector_index)
            logger.info("✅ 向量索引添加成功")
            return True
        except Exception as e:
            logger.error(f"❌ 添加向量索引失败: {e}")
            return False

    def search_vectors(self, query_vector: np.ndarray, top_k: int = 3) -> list:
        """Run a top-k vector search against the ``vector`` field.

        Args:
            query_vector: Query embedding. Any array-like is accepted and is
                flattened, so both ``(dim,)`` and ``(1, dim)`` inputs work.
            top_k: Maximum number of hits to return.

        Returns:
            A list of ``(id, content, score)`` tuples, or ``[]`` if the search
            raised (the error is logged).
        """
        try:
            logger.info(f"搜索向量top_k={top_k}")
            request = VectorTopkSearchRequest(
                vector_field="vector",
                vector=FloatVector(np.asarray(query_vector).ravel().tolist()),
                limit=top_k,
                config=VectorSearchConfig(ef=200),
            )
            results = self.test_table.vector_search(request=request)
            search_results = self._parse_search_rows(results)
            logger.info(f"✅ 向量搜索完成,返回 {len(search_results)} 条结果")
            return search_results
        except Exception as e:
            logger.error(f"❌ 向量搜索失败: {e}")
            return []

    @staticmethod
    def _parse_search_rows(results):
        """Normalize a vector_search response into ``(id, content, score)`` tuples.

        Accepts either an object with a ``rows`` attribute (the pymochow search
        response is expected to look like this -- verify against the SDK
        version in use) or a plain iterable of row dicts. Each row may carry
        the stored fields directly or nested under ``"row"``. The similarity is
        read from ``"score"``, falling back to ``"_score"``, then ``0.0``.
        """
        rows = getattr(results, "rows", results)
        parsed = []
        for item in rows or ():
            fields = item.get("row")
            if not isinstance(fields, dict):
                fields = item
            score = item.get("score")
            if score is None:
                score = item.get("_score", 0.0)
            parsed.append((fields.get("id", ""), fields.get("content", ""), float(score)))
        return parsed
def test_vdb_with_index():
    """End-to-end smoke test of BaiduVDBWithIndex.

    The test connects, waits for the table, stores five random vectors, adds
    the HNSW index, searches with the first stored vector, and prints the
    statistics.

    Returns:
        True if data was stored, the index was created and the search returned
        results; False otherwise, including when any step raises. The
        connection is always closed.
    """
    print("=" * 60)
    print("测试带索引的百度VDB")
    print("=" * 60)
    vdb = None
    success = True
    try:
        # 1. Initialise the VDB (reuses the index-less base implementation).
        print("1. 初始化VDB连接...")
        vdb = BaiduVDBWithIndex()
        print("✅ VDB初始化成功")

        # 2. Wait for the table; continue even if it may not be ready yet.
        print("\n2. 等待表创建完成...")
        if vdb.wait_table_ready(30):
            print("✅ 表已就绪")
        else:
            print("⚠️ 表可能仍在创建中,继续测试...")

        # 3. Store test data.
        print("\n3. 存储测试向量...")
        test_contents = [
            "这是第一个测试文档",
            "这是第二个测试文档",
            "这是第三个测试文档",
            "这是第四个测试文档",
            "这是第五个测试文档",
        ]
        # Use the table's configured dimension rather than a hard-coded 128,
        # so the test stays valid if the schema changes.
        test_vectors = np.random.rand(len(test_contents), vdb.vector_dimension).astype(np.float32)
        ids = vdb.store_vectors(test_contents, test_vectors)
        # Check before reporting, so a failed store is not printed as a success
        # (and a None result does not crash on len()).
        if not ids:
            print("⚠️ 数据存储失败,跳过后续测试")
            return False
        print(f"✅ 存储了 {len(ids)} 条向量")

        # 4. Add the vector index.
        print("\n4. 添加向量索引...")
        if vdb.add_vector_index():
            print("✅ 向量索引添加成功")
            # Fixed grace period for the auto-built index to become usable.
            print("等待索引构建...")
            time.sleep(10)

            # 5. Search using the first stored vector as the query.
            print("\n5. 测试向量搜索...")
            query_vector = test_vectors[0]
            results = vdb.search_vectors(query_vector, top_k=3)
            if results:
                print(f"搜索结果 ({len(results)} 条):")
                for rank, (_doc_id, content, score) in enumerate(results, 1):
                    print(f" {rank}. {content} (相似度: {score:.4f})")
                print("✅ 向量搜索成功")
            else:
                print("⚠️ 向量搜索失败或无结果")
                success = False
        else:
            print("❌ 向量索引添加失败")
            success = False

        # 6. Final statistics.
        print("\n6. 获取统计信息...")
        stats = vdb.get_statistics()
        print("最终统计:")
        for key, value in stats.items():
            print(f" {key}: {value}")
        print("\n🎉 带索引VDB测试完成")
        return success
    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        if vdb:
            vdb.close()
if __name__ == "__main__":
    # Propagate the smoke-test result as the process exit status
    # (0 = passed, 1 = failed) so shells and CI can detect failures.
    raise SystemExit(0 if test_vdb_with_index() else 1)