350 lines
13 KiB
Python
350 lines
13 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
百度VDB连接可用性测试
|
||
测试连接、数据库操作、表操作和向量检索功能
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import time
|
||
import logging
|
||
import traceback
|
||
from typing import List, Dict, Any
|
||
|
||
import pymochow
|
||
from pymochow.configuration import Configuration
|
||
from pymochow.auth.bce_credentials import BceCredentials
|
||
from pymochow.model.schema import Schema, Field, VectorIndex, HNSWParams
|
||
from pymochow.model.enum import FieldType, IndexType, MetricType, TableState
|
||
from pymochow.model.table import Row, Partition
|
||
from pymochow.model.table import VectorTopkSearchRequest, VectorSearchConfig, FloatVector
|
||
import numpy as np
|
||
|
||
# Module-wide logger used by every test step.
logger = logging.getLogger(__name__)
# Configure the root logger to emit INFO-level progress messages
# (logging.basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
|
||
|
||
class BaiduVDBConnectionTest:
    """End-to-end availability check for a Baidu VectorDB (Mochow) instance.

    ``run_full_test`` executes, in order: client creation plus a probe
    request, database create/verify, table create/verify, vector
    upsert/query/search/stats, and cleanup of everything the test created.

    Connection settings come from the constructor arguments, then from the
    ``BAIDU_VDB_ACCOUNT`` / ``BAIDU_VDB_API_KEY`` / ``BAIDU_VDB_ENDPOINT``
    environment variables, then from the class-level defaults below.
    """

    # Fallback connection settings, used when neither an argument nor an
    # environment variable is supplied.
    # NOTE(review): a live API key committed to source is a credential leak;
    # rotate it and supply it via BAIDU_VDB_API_KEY instead. The literal is
    # kept only so existing invocations keep working.
    _DEFAULT_ACCOUNT = "root"
    _DEFAULT_API_KEY = "vdb$yjr9ln3n0td"
    _DEFAULT_ENDPOINT = "http://180.76.96.191:5287"

    # Dimension of the test vector field. It is shared by the schema and the
    # randomly generated rows/queries so they cannot disagree.
    _DIMENSION = 128

    # Polling parameters used while waiting for asynchronous table changes.
    _WAIT_TIMEOUT_S = 30
    _POLL_INTERVAL_S = 2

    # (summary label, results key), in execution/report order.
    _SUMMARY_ITEMS = (
        ("VDB连接", "connection"),
        ("数据库操作", "database_ops"),
        ("表操作", "table_ops"),
        ("向量操作", "vector_ops"),
        ("数据清理", "cleanup"),
    )

    def __init__(self, account=None, api_key=None, endpoint=None):
        """Store connection settings; no network activity happens here.

        Args:
            account: Account name. Defaults to ``$BAIDU_VDB_ACCOUNT`` or "root".
            api_key: API key. Defaults to ``$BAIDU_VDB_API_KEY`` or the legacy
                built-in key.
            endpoint: Server URL. Defaults to ``$BAIDU_VDB_ENDPOINT`` or the
                legacy built-in endpoint.
        """
        self.account = self._setting(account, "BAIDU_VDB_ACCOUNT", self._DEFAULT_ACCOUNT)
        self.api_key = self._setting(api_key, "BAIDU_VDB_API_KEY", self._DEFAULT_API_KEY)
        self.endpoint = self._setting(endpoint, "BAIDU_VDB_ENDPOINT", self._DEFAULT_ENDPOINT)

        self.client = None
        self.test_db_name = "test_connection_db"
        self.test_table_name = "test_vectors"

    # ------------------------------------------------------------------ #
    # Small helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _setting(explicit, env_var: str, default: str):
        """Return *explicit* if not None, else ``$env_var``, else *default*."""
        if explicit is not None:
            return explicit
        return os.environ.get(env_var, default)

    @staticmethod
    def _field(obj: Any, name: str, default: Any = None) -> Any:
        """Read *name* from a dict (by key) or any other object (by attribute).

        SDK responses may expose payload fields either way, so this keeps the
        logging code independent of the exact response shape.
        """
        if isinstance(obj, dict):
            return obj.get(name, default)
        return getattr(obj, name, default)

    def _close_client(self) -> None:
        """Close the client if one exists and forget it; close errors are only logged."""
        if self.client is None:
            return
        try:
            self.client.close()
            logger.info("✅ VDB连接已关闭")
        except Exception as e:
            logger.warning(f"⚠️ 关闭VDB连接失败: {str(e)}")
        finally:
            self.client = None

    def _wait_for_table_removal(self, db) -> bool:
        """Poll ``db.list_table()`` until the test table is no longer listed.

        Returns:
            True once the table is gone, or False if it is still listed after
            ``_WAIT_TIMEOUT_S`` seconds.
        """
        deadline = time.monotonic() + self._WAIT_TIMEOUT_S
        while True:
            names = [t.table_name for t in db.list_table()]
            if self.test_table_name not in names:
                return True
            if time.monotonic() >= deadline:
                logger.warning(f"⚠️ 等待表删除超时: {self.test_table_name}")
                return False
            time.sleep(self._POLL_INTERVAL_S)

    def _drop_test_database(self) -> bool:
        """Drop the test table (if present), then the test database.

        The table is dropped first, and the method waits until it disappears,
        because Mochow is documented to reject dropping a database that still
        contains tables (confirm for your server version).

        Returns:
            True if the test database existed and was dropped; False if it
            did not exist.

        Raises:
            Whatever the SDK raises when a listing or drop request fails.
        """
        existing = [d.database_name for d in self.client.list_databases()]
        if self.test_db_name not in existing:
            return False

        db = self.client.database(self.test_db_name)
        if self.test_table_name in [t.table_name for t in db.list_table()]:
            db.drop_table(self.test_table_name)
            self._wait_for_table_removal(db)

        self.client.drop_database(self.test_db_name)
        return True

    def _log_summary(self, results: Dict[str, bool]) -> None:
        """Log a per-step pass/fail table and the overall success rate."""
        logger.info("\n" + "=" * 50)
        logger.info("📊 测试结果汇总:")
        logger.info("=" * 50)

        for label, key in self._SUMMARY_ITEMS:
            status = "✅ 成功" if results[key] else "❌ 失败"
            logger.info(f"{label}: {status}")

        success_count = sum(results.values())
        total_count = len(results)
        success_rate = (success_count / total_count) * 100

        logger.info(f"\n总体成功率: {success_count}/{total_count} ({success_rate:.1f}%)")

        if success_rate >= 80:
            logger.info("🎉 百度VDB连接测试基本通过!")
        else:
            logger.warning("⚠️ 百度VDB连接存在问题,需要进一步检查")

    # ------------------------------------------------------------------ #
    # Test steps
    # ------------------------------------------------------------------ #

    def connect(self) -> bool:
        """Create the Mochow client and verify it with one server request.

        Returns:
            True if the client was created and ``list_databases`` succeeded.
            Otherwise returns False, logs the error and closes any partially
            created client.
        """
        try:
            logger.info("正在测试百度VDB连接...")
            logger.info(f"端点: {self.endpoint}")
            logger.info(f"账户: {self.account}")

            config = Configuration(
                credentials=BceCredentials(self.account, self.api_key),
                endpoint=self.endpoint,
            )
            self.client = pymochow.MochowClient(config)

            # Constructing the client is not guaranteed to contact the server,
            # so issue one cheap request. This makes an unreachable endpoint or
            # bad credentials fail here rather than in a later step.
            self.client.list_databases()

            logger.info("✅ VDB客户端创建成功")
            return True

        except Exception as e:
            logger.error(f"❌ VDB连接失败: {str(e)}")
            logger.error(traceback.format_exc())
            self._close_client()
            return False

    def test_database_operations(self) -> bool:
        """List databases, (re)create the test database, and verify it exists.

        A test database left over from an earlier run is removed first,
        together with its test table.

        Returns:
            True only if the freshly created database appears in a second
            listing.
        """
        try:
            logger.info("\n=== 测试数据库操作 ===")

            # 1. List existing databases.
            logger.info("1. 查询数据库列表...")
            databases = self.client.list_databases()
            logger.info(f"现有数据库数量: {len(databases)}")
            for db in databases:
                logger.info(f" - {db.database_name}")

            # 2. Create the test database.
            logger.info(f"2. 创建测试数据库: {self.test_db_name}")
            try:
                # Remove leftovers from an aborted earlier run. A failure here
                # is logged and otherwise ignored; create_database below then
                # reports the real problem.
                if self._drop_test_database():
                    logger.info("删除了已存在的测试数据库")
            except Exception as e:
                logger.warning(f"⚠️ 删除已存在的测试数据库失败: {str(e)}")

            try:
                db = self.client.create_database(self.test_db_name)
                logger.info(f"✅ 数据库创建成功: {db.database_name}")
            except Exception as e:
                logger.error(f"❌ 数据库创建失败: {str(e)}")
                return False

            # 3. Verify that the database now exists.
            logger.info("3. 验证数据库创建...")
            db_names = [d.database_name for d in self.client.list_databases()]
            if self.test_db_name in db_names:
                logger.info("✅ 数据库验证成功")
                return True
            logger.error("❌ 数据库验证失败")
            return False

        except Exception as e:
            logger.error(f"❌ 数据库操作测试失败: {str(e)}")
            logger.error(traceback.format_exc())
            return False

    def test_table_operations(self) -> bool:
        """Create the test table with an HNSW index and verify it is listed.

        Waits up to ``_WAIT_TIMEOUT_S`` seconds for the table to reach
        ``TableState.NORMAL``. A timeout is only a warning; the listing check
        still decides the result.
        """
        try:
            logger.info("\n=== 测试表操作 ===")

            db = self.client.database(self.test_db_name)

            # 1. Define the schema: string primary/partition key, text, vector.
            logger.info("1. 定义表结构...")
            fields = [
                Field("id", FieldType.STRING, primary_key=True,
                      partition_key=True, auto_increment=False, not_null=True),
                Field("content", FieldType.STRING, not_null=True),
                Field("vector", FieldType.FLOAT_VECTOR, not_null=True,
                      dimension=self._DIMENSION),
            ]
            indexes = [
                VectorIndex(
                    index_name="vector_idx",
                    index_type=IndexType.HNSW,
                    field="vector",
                    metric_type=MetricType.L2,
                    params=HNSWParams(m=16, efconstruction=100),
                    auto_build=True,
                )
            ]

            # 2. Create the table (single replica and partition; test only).
            logger.info(f"2. 创建表: {self.test_table_name}")
            table = db.create_table(
                table_name=self.test_table_name,
                replication=1,
                partition=Partition(partition_num=1),
                schema=Schema(fields=fields, indexes=indexes),
            )
            logger.info(f"✅ 表创建成功: {table.table_name}")

            # 3. Table creation is asynchronous; poll until it is NORMAL.
            logger.info("3. 等待表状态正常...")
            waited = 0
            while waited < self._WAIT_TIMEOUT_S:
                table_info = db.describe_table(self.test_table_name)
                logger.info(f"表状态: {table_info.state}")
                if table_info.state == TableState.NORMAL:
                    logger.info("✅ 表状态正常")
                    break
                time.sleep(self._POLL_INTERVAL_S)
                waited += self._POLL_INTERVAL_S
            else:
                logger.warning("⚠️ 表状态等待超时,继续测试...")

            # 4. Verify that the table shows up in the listing.
            logger.info("4. 查询表列表...")
            tables = db.list_table()
            table_names = [t.table_name for t in tables]
            logger.info(f"表数量: {len(tables)}")
            for name in table_names:
                logger.info(f" - {name}")

            if self.test_table_name in table_names:
                logger.info("✅ 表操作测试成功")
                return True
            logger.error("❌ 表验证失败")
            return False

        except Exception as e:
            logger.error(f"❌ 表操作测试失败: {str(e)}")
            logger.error(traceback.format_exc())
            return False

    def test_vector_operations(self) -> bool:
        """Upsert random rows, then exercise point query, top-k search and stats.

        Returns:
            True when every SDK call completes without raising. An empty
            point-query result is logged as a warning but not treated as a
            failure.
        """
        try:
            logger.info("\n=== 测试向量操作 ===")

            db = self.client.database(self.test_db_name)
            table = db.table(self.test_table_name)

            # 1. Insert a handful of random vectors.
            logger.info("1. 插入测试向量...")
            rows = [
                Row(
                    id=f"test_{i:03d}",
                    content=f"测试内容_{i}",
                    vector=np.random.rand(self._DIMENSION).astype(np.float32).tolist(),
                )
                for i in range(5)
            ]
            table.upsert(rows)
            logger.info(f"✅ 插入了 {len(rows)} 个向量")

            # 2. Point query by primary key.
            logger.info("2. 测试向量查询...")
            result = table.query(primary_key={'id': 'test_001'}, retrieve_vector=True)
            # Assumes the response wraps the record in a ``row`` field (as in the
            # pymochow examples). Falls back to the response itself in case it
            # exposes the fields directly.
            row = self._field(result, "row", result)
            if row:
                logger.info("✅ 向量查询成功")
                logger.info(
                    f"查询结果: ID={self._field(row, 'id')}, "
                    f"Content={self._field(row, 'content')}"
                )
            else:
                logger.warning("⚠️ 向量查询无结果")

            # 3. Top-k vector search.
            logger.info("3. 测试向量检索...")
            query_vector = np.random.rand(self._DIMENSION).astype(np.float32).tolist()
            search_request = VectorTopkSearchRequest(
                vector_field="vector",
                vector=FloatVector(query_vector),
                limit=3,
                config=VectorSearchConfig(ef=100),
            )
            search_results = table.vector_search(request=search_request)
            # Assumes hits live in a ``rows`` field of the response, each hit
            # carrying ``row`` and ``distance``. Falls back to treating the
            # response itself as the list of hits.
            hits = list(self._field(search_results, "rows", search_results) or [])
            logger.info(f"✅ 向量检索成功,返回 {len(hits)} 个结果")

            for i, hit in enumerate(hits, start=1):
                hit_row = self._field(hit, "row", hit)
                distance = self._field(hit, "distance")
                shown = f"{distance:.4f}" if isinstance(distance, (int, float)) else distance
                # The index uses MetricType.L2, so this is a distance
                # (lower is closer), not a similarity score.
                logger.info(f" 结果 {i}: ID={self._field(hit_row, 'id')}, 距离={shown}")

            # 4. Table statistics.
            logger.info("4. 查询表统计信息...")
            stats = table.stats()
            logger.info(f"记录数: {self._field(stats, 'rowCount')}")
            logger.info(f"内存大小: {self._field(stats, 'memorySizeInByte')} bytes")
            logger.info(f"磁盘大小: {self._field(stats, 'diskSizeInByte')} bytes")

            return True

        except Exception as e:
            logger.error(f"❌ 向量操作测试失败: {str(e)}")
            logger.error(traceback.format_exc())
            return False

    def cleanup(self) -> bool:
        """Remove the test table and database, then close the client.

        The client is closed even when dropping fails.

        Returns:
            True if the test data was removed or was already absent; False if
            there is no client or a drop request failed.
        """
        logger.info("\n=== 清理测试数据 ===")
        if self.client is None:
            logger.error("❌ 清理失败: VDB客户端未创建")
            return False

        try:
            logger.info(f"删除测试数据库: {self.test_db_name}")
            self._drop_test_database()
            logger.info("✅ 测试数据清理完成")
            return True
        except Exception as e:
            logger.error(f"❌ 清理失败: {str(e)}")
            return False
        finally:
            self._close_client()

    def run_full_test(self) -> Dict[str, bool]:
        """Run every step in order and log a summary.

        Each step runs only if the previous one succeeded. Cleanup, however,
        runs whenever a client was created, so partial test data is removed
        and the connection is always closed.

        Returns:
            Mapping of step key to success flag for "connection",
            "database_ops", "table_ops", "vector_ops" and "cleanup".
        """
        results = {key: False for _, key in self._SUMMARY_ITEMS}

        try:
            logger.info("🚀 开始百度VDB连接可用性测试")
            logger.info("=" * 50)

            results["connection"] = self.connect()
            if results["connection"]:
                try:
                    results["database_ops"] = self.test_database_operations()
                    if results["database_ops"]:
                        results["table_ops"] = self.test_table_operations()
                    if results["table_ops"]:
                        results["vector_ops"] = self.test_vector_operations()
                finally:
                    results["cleanup"] = self.cleanup()

        except Exception as e:
            logger.error(f"❌ 测试过程中发生错误: {str(e)}")
            logger.error(traceback.format_exc())

        finally:
            self._log_summary(results)

        # Returned outside ``finally`` so that exceptions which are not
        # subclasses of Exception (e.g. KeyboardInterrupt) still propagate.
        return results
|
||
|
||
def main() -> bool:
    """Run the full availability test and report whether the core steps passed.

    Returns:
        True when both the connection step and the database-operation step
        succeeded. The remaining steps are logged by the tester but do not
        affect this value.
    """
    tester = BaiduVDBConnectionTest()
    results = tester.run_full_test()

    return results["connection"] and results["database_ops"]


if __name__ == "__main__":
    # Surface the outcome as the process exit status so shell scripts and CI
    # can detect a failed check.
    sys.exit(0 if main() else 1)
|