199 lines
6.6 KiB
Python
199 lines
6.6 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
带索引的百度VDB测试 - 在表创建完成后添加索引
|
||
"""
|
||
|
||
import time
|
||
import logging
|
||
from baidu_vdb_minimal import BaiduVDBMinimal
|
||
import numpy as np
|
||
|
||
import pymochow
|
||
from pymochow.configuration import Configuration
|
||
from pymochow.auth.bce_credentials import BceCredentials
|
||
from pymochow.model.schema import VectorIndex, HNSWParams
|
||
from pymochow.model.enum import IndexType, MetricType
|
||
from pymochow.model.table import VectorTopkSearchRequest, VectorSearchConfig, FloatVector
|
||
|
||
# Logging setup: create this module's logger and configure the root logger
# to emit records at INFO level and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
||
|
||
class BaiduVDBWithIndex(BaiduVDBMinimal):
    """VDB helper that adds a vector index to an already-created table.

    Builds on ``BaiduVDBMinimal`` and uses ``self.test_table``,
    ``self.vector_dimension`` and ``self.store_vectors`` from it (the base
    class is defined in another module and is not visible here).
    """

    def wait_table_ready(self, max_wait_seconds=60):
        """Poll until a probe row can be stored in the table.

        Each attempt reads the table stats (for logging only) and then tries
        to store one random vector with the content ``"test"`` through
        ``self.store_vectors``. A non-empty result is treated as "ready".
        Note that a successful probe therefore leaves a row in the table.

        Args:
            max_wait_seconds: Maximum number of attempts. Unsuccessful
                attempts are spaced one second apart.

        Returns:
            True as soon as a probe row is stored; False when the attempts
            are exhausted or an error other than "Table Not Ready" occurs.
        """
        logger.info("等待表创建完成...")

        for attempt in range(1, max_wait_seconds + 1):
            try:
                stats = self.test_table.stats()
                # The stats object's type is not visible from this file, so
                # read the message defensively instead of assuming a dict.
                if hasattr(stats, "get"):
                    status_msg = stats.get('msg', 'Unknown')
                else:
                    status_msg = getattr(stats, 'msg', 'Unknown')
                logger.info(f"表状态检查 {attempt}/{max_wait_seconds}: {status_msg}")

                # Probe the table by storing a single random vector.
                probe = np.random.rand(self.vector_dimension).astype(np.float32)
                if self.store_vectors(["test"], probe.reshape(1, -1)):
                    logger.info("✅ 表已就绪,可以存储数据")
                    return True
            except Exception as e:
                if "Table Not Ready" not in str(e):
                    logger.error(f"其他错误: {e}")
                    break
                logger.info(f"表仍在创建中... ({attempt}/{max_wait_seconds})")

            # Pause between attempts. This also covers the case where the
            # probe returned an empty result without raising, which would
            # otherwise retry in a tight loop with no delay.
            if attempt < max_wait_seconds:
                time.sleep(1)

        logger.warning("⚠️ 表可能仍未就绪")
        return False

    def add_vector_index(self):
        """Add an HNSW vector index on the ``vector`` field of the table.

        The index is named ``vector_hnsw_idx``, uses the cosine metric with
        ``m=16`` and ``efconstruction=200``, and is created with
        ``auto_build=True``.

        Returns:
            True if the call succeeded, False if any exception was raised
            (the exception is logged, not propagated).
        """
        try:
            logger.info("为表添加向量索引...")

            hnsw_index = VectorIndex(
                index_name="vector_hnsw_idx",
                index_type=IndexType.HNSW,
                field="vector",
                metric_type=MetricType.COSINE,
                params=HNSWParams(m=16, efconstruction=200),
                auto_build=True,
            )

            # NOTE(review): verify that the table object exposes `add_index`;
            # any failure here is caught below and reported as False.
            self.test_table.add_index(hnsw_index)
            logger.info("✅ 向量索引添加成功")
            return True
        except Exception as e:
            logger.error(f"❌ 添加向量索引失败: {e}")
            return False

    def search_vectors(self, query_vector: np.ndarray, top_k: int = 3) -> list:
        """Run a top-k vector search against the ``vector`` field.

        Args:
            query_vector: Query embedding; it is flattened to a plain list of
                floats before being sent.
            top_k: Maximum number of hits to request.

        Returns:
            A list of ``(doc_id, content, score)`` tuples, or an empty list
            if the search raised (the exception is logged, not propagated).
        """
        try:
            logger.info(f"搜索向量,top_k={top_k}")

            request = VectorTopkSearchRequest(
                vector_field="vector",
                vector=FloatVector(np.asarray(query_vector).ravel().tolist()),
                limit=top_k,
                config=VectorSearchConfig(ef=200),
            )

            response = self.test_table.vector_search(request=request)

            # The response may be a plain iterable of hits or an object that
            # carries them in `.rows` (presumably the SDK response shape --
            # TODO confirm against the installed pymochow version).
            hits = getattr(response, "rows", response) or []

            search_results = []
            for hit in hits:
                # Fields may sit directly on the hit or under a nested "row".
                row = hit.get('row', hit)
                score = hit.get('_score', hit.get('score', 0.0))
                search_results.append(
                    (row.get('id', ''), row.get('content', ''), float(score))
                )

            logger.info(f"✅ 向量搜索完成,返回 {len(search_results)} 条结果")
            return search_results
        except Exception as e:
            logger.error(f"❌ 向量搜索失败: {e}")
            return []
|
||
|
||
def test_vdb_with_index():
    """Run an end-to-end smoke test of the indexed VDB against the service.

    Steps: connect, wait for the table, store five random vectors, add the
    vector index, run one search using the first stored vector as the query,
    and print the final statistics. Progress is printed to stdout.

    Returns:
        False if storing the vectors fails or any exception is raised;
        True otherwise (index or search failures are only reported, they do
        not change the return value).
    """
    print("=" * 60)
    print("测试带索引的百度VDB")
    print("=" * 60)

    vdb = None

    try:
        # 1. Initialise the VDB connection (reuses the index-less base class).
        print("1. 初始化VDB连接...")
        vdb = BaiduVDBWithIndex()
        print("✅ VDB初始化成功")

        # 2. Wait for the table to become usable.
        print("\n2. 等待表创建完成...")
        if vdb.wait_table_ready(30):
            print("✅ 表已就绪")
        else:
            print("⚠️ 表可能仍在创建中,继续测试...")

        # 3. Store the test data.
        print("\n3. 存储测试向量...")
        test_contents = [
            "这是第一个测试文档",
            "这是第二个测试文档",
            "这是第三个测试文档",
            "这是第四个测试文档",
            "这是第五个测试文档",
        ]

        # Use the instance's configured dimension so the vectors match the
        # table; fall back to the previous hard-coded 128 if it is absent.
        dimension = getattr(vdb, "vector_dimension", 128)
        test_vectors = np.random.rand(len(test_contents), dimension).astype(np.float32)

        ids = vdb.store_vectors(test_contents, test_vectors)
        print(f"✅ 存储了 {len(ids)} 条向量")

        if not ids:
            print("⚠️ 数据存储失败,跳过后续测试")
            return False

        # 4. Add the vector index.
        print("\n4. 添加向量索引...")
        if vdb.add_vector_index():
            print("✅ 向量索引添加成功")

            # Give the service a fixed amount of time to build the index.
            print("等待索引构建...")
            time.sleep(10)

            # 5. Search, using the first stored vector as the query.
            print("\n5. 测试向量搜索...")
            query_vector = test_vectors[0]

            results = vdb.search_vectors(query_vector, top_k=3)

            if results:
                print(f"搜索结果 ({len(results)} 条):")
                for i, (doc_id, content, score) in enumerate(results, 1):
                    print(f"  {i}. {content} (相似度: {score:.4f})")
                print("✅ 向量搜索成功")
            else:
                print("⚠️ 向量搜索失败或无结果")
        else:
            print("❌ 向量索引添加失败")

        # 6. Print the final statistics.
        print("\n6. 获取统计信息...")
        stats = vdb.get_statistics()
        print("最终统计:")
        for key, value in stats.items():
            print(f"  {key}: {value}")

        print("\n🎉 带索引VDB测试完成!")
        return True

    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        # Close the connection whenever it was created, regardless of how
        # the object evaluates in a boolean context.
        if vdb is not None:
            vdb.close()
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the indexed-VDB smoke test once.
    # Its boolean result is not used here.
    test_vdb_with_index()
|