#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Minimal Baidu VDB test - works around the "Invalid Index Schema" error.

Uses the simplest possible table structure and does not create any index.
"""

import hashlib
import json
import logging
import os
import sys
import time
import traceback
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
import pymochow
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.model.enum import FieldType
from pymochow.model.schema import Schema, Field
from pymochow.model.table import Row, Partition

# Logging setup: emit INFO and above, and expose a module-level logger that
# every function and method in this file writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class BaiduVDBMinimal:
    """Minimal Baidu VDB backend that works on a table without any index.

    Constructing an instance immediately creates the client, makes sure the
    configured database exists and makes sure the ``simple_vectors`` table
    exists, creating either one when it is missing.  The table holds three
    fields (``id``, ``content``, ``vector``) and is created with an empty
    index list.
    """

    def __init__(self,
                 account: str = "root",
                 api_key: str = "vdb$yjr9ln3n0td",
                 endpoint: str = "http://180.76.96.191:5287",
                 database_name: str = "minimal_test",
                 vector_dimension: int = 128):  # deliberately small dimension
        """Store the connection settings and connect right away.

        Args:
            account: Account name passed to ``BceCredentials``.
            api_key: API key passed to ``BceCredentials``.
            endpoint: URL of the VDB service.
            database_name: Database to use; created when it does not exist.
            vector_dimension: Dimension of the ``vector`` field used when the
                table has to be created.

        Raises:
            Exception: Whatever the client raises while connecting or while
                preparing the database/table (logged, then re-raised).
        """
        # NOTE(review): the default api_key and endpoint look like real
        # credentials committed to source control.  The defaults are kept so
        # existing callers keep working, but they should be rotated and
        # supplied from configuration or the environment instead.
        self.account = account
        self.api_key = api_key
        self.endpoint = endpoint
        self.database_name = database_name
        self.vector_dimension = vector_dimension

        # Name of the single table this backend works with.
        self.test_table_name = "simple_vectors"

        # Handles that _init_connection() fills in.
        self.client = None
        self.db = None
        self.test_table = None

        self._init_connection()

    def _init_connection(self):
        """Create the client, then ensure the database and the table exist.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            logger.info("🔗 初始化最小化VDB连接...")

            # Build the client configuration from the stored settings.
            config = Configuration(
                credentials=BceCredentials(self.account, self.api_key),
                endpoint=self.endpoint
            )

            self.client = pymochow.MochowClient(config)
            logger.info("✅ VDB客户端创建成功")

            # Order matters: the table lookup needs self.db to be set.
            self._ensure_database()
            self._ensure_table()

            logger.info("✅ 最小化VDB后端初始化完成")

        except Exception as e:
            logger.error(f"❌ VDB连接初始化失败: {e}")
            raise

    def _ensure_database(self):
        """Bind ``self.db`` to the configured database, creating it if needed.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            # Look the database up by name among the existing ones.
            db_list = self.client.list_databases()
            db_names = [db.database_name for db in db_list]

            if self.database_name not in db_names:
                logger.info(f"创建数据库: {self.database_name}")
                self.db = self.client.create_database(self.database_name)
            else:
                logger.info(f"使用现有数据库: {self.database_name}")
                self.db = self.client.database(self.database_name)

        except Exception as e:
            logger.error(f"❌ 数据库操作失败: {e}")
            raise

    def _ensure_table(self):
        """Bind ``self.test_table`` to the table, creating it if needed.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            # Names of the tables that already exist in the database.
            table_list = self.db.list_table()
            table_names = [table.table_name for table in table_list]

            if self.test_table_name not in table_names:
                self._create_simple_table()
            else:
                # An existing table is reused as-is; its schema is not
                # compared against self.vector_dimension.
                self.test_table = self.db.table(self.test_table_name)
                logger.info(f"✅ 使用现有表: {self.test_table_name}")

        except Exception as e:
            logger.error(f"❌ 表操作失败: {e}")
            raise

    def _create_simple_table(self):
        """Create the table with the simplest schema and no indexes.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            logger.info(f"创建最简单的表: {self.test_table_name}")

            # Field definitions - the simplest configuration.
            fields = [
                # Primary key and partition key (the original author notes
                # that it must be a STRING field).
                Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True),
                # Text payload.
                Field("content", FieldType.STRING, not_null=True),
                # Vector payload, using the configured (small) dimension.
                Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
            ]

            # No index at all - pass an empty index list.
            indexes = []

            schema = Schema(fields=fields, indexes=indexes)

            # Create the table with a minimal configuration.
            self.test_table = self.db.create_table(
                table_name=self.test_table_name,
                replication=2,  # described by the original author as the minimum replica count
                partition=Partition(partition_num=1),  # single partition
                schema=schema,
                description="最简单的测试表"
            )

            logger.info(f"✅ 简单表创建成功: {self.test_table_name}")

        except Exception as e:
            logger.error(f"❌ 创建简单表失败: {e}")
            raise

    def _generate_id(self, content: str) -> str:
        """Return the first 16 hex characters of the MD5 digest of *content*."""
        # The digest is truncated to keep the ID short.
        return hashlib.md5(content.encode('utf-8')).hexdigest()[:16]

    def store_vectors(self, contents: List[str], vectors: np.ndarray) -> List[str]:
        """Upsert one row per ``(content, vector)`` pair.

        The row ID is derived from the content text and its position in the
        batch (``"{content}_{i}"``).

        Args:
            contents: Texts to store.
            vectors: One vector per text.  Rows may be NumPy arrays (or any
                object with ``tolist()``) or plain sequences of numbers.

        Returns:
            The generated row IDs in input order.  An empty list is returned
            for an empty batch and also when anything fails - errors
            (including a length mismatch) are logged, not raised.
        """
        try:
            if len(contents) != len(vectors):
                raise ValueError("内容数量与向量数量不匹配")

            logger.info(f"存储 {len(contents)} 条向量...")

            # Nothing to write: skip the round trip instead of sending an
            # empty batch to the server.
            if len(contents) == 0:
                return []

            rows = []
            ids = []

            for i, (content, vector) in enumerate(zip(contents, vectors)):
                doc_id = self._generate_id(f"{content}_{i}")
                ids.append(doc_id)

                # Accept plain sequences as well as array-like rows.
                values = vector.tolist() if hasattr(vector, "tolist") else list(vector)

                rows.append(Row(id=doc_id, content=content, vector=values))

            # Single batched write.
            self.test_table.upsert(rows)
            logger.info(f"✅ 成功存储 {len(contents)} 条向量")

            return ids

        except Exception as e:
            logger.error(f"❌ 存储向量失败: {e}")
            return []

    def get_all_data(self) -> List[Dict]:
        """Log the table statistics as a basic verification step.

        No rows are actually fetched: the method only tries to read and log
        the table statistics.

        Returns:
            Always an empty list.
        """
        try:
            logger.info("获取所有数据...")

            # No vector search and no row query here - the result list is
            # never filled.
            results = []

            # Best effort: a failure to read the stats is only a warning.
            try:
                stats = self.test_table.stats()
                logger.info(f"表统计信息: {stats}")
            except Exception as e:
                logger.warning(f"无法获取表统计: {e}")

            return results

        except Exception as e:
            logger.error(f"❌ 获取数据失败: {e}")
            return []

    def get_statistics(self) -> Dict[str, Any]:
        """Return a summary of this backend's configuration and table stats.

        Returns:
            A dict with ``database_name``, ``table_name``,
            ``vector_dimension``, ``status`` and ``has_indexes``, plus either
            ``table_stats`` (result of the table's ``stats()`` call) or
            ``table_stats_error`` (the error text when that call failed).
            On an unexpected failure the dict is
            ``{"status": "error", "error": <message>}``.
        """
        try:
            stats = {
                "database_name": self.database_name,
                "table_name": self.test_table_name,
                "vector_dimension": self.vector_dimension,
                "status": "connected",
                "has_indexes": False
            }

            # Best effort: report the error text instead of failing.
            try:
                table_stats = self.test_table.stats()
                stats["table_stats"] = table_stats
            except Exception as e:
                stats["table_stats_error"] = str(e)

            return stats

        except Exception as e:
            logger.error(f"❌ 获取统计信息失败: {e}")
            return {"status": "error", "error": str(e)}

    def clear_all_data(self):
        """Drop the table and make sure it exists again.

        Failures are logged and never raised: a failed drop is a warning and
        the table check still runs afterwards.
        """
        try:
            logger.info("清空所有数据...")

            # Drop the table; a failure here is only a warning.
            try:
                self.db.drop_table(self.test_table_name)
                logger.info(f"✅ 删除表: {self.test_table_name}")
            except Exception as e:
                logger.warning(f"删除表失败: {e}")

            # Re-create (or re-attach to) the table.
            self._ensure_table()
            logger.info("✅ 数据清空完成")

        except Exception as e:
            logger.error(f"❌ 清空数据失败: {e}")

    def close(self):
        """Close the client connection if one was created; errors are logged."""
        try:
            if self.client:
                self.client.close()
                logger.info("✅ VDB连接已关闭")
        except Exception as e:
            logger.error(f"❌ 关闭连接失败: {e}")


def test_minimal_vdb():
    """Run an end-to-end smoke test against the minimal (index-free) backend.

    Steps: connect, store three random vectors, print the statistics and run
    the data verification call.  Progress is printed to stdout.

    Returns:
        True when every step succeeded, False when any step raised or when
        fewer vectors were stored than were submitted.  The connection is
        closed in both cases.
    """
    print("=" * 60)
    print("测试最小化百度VDB后端(无索引版本)")
    print("=" * 60)

    vdb = None

    try:
        # 1. Connect.
        print("1. 初始化最小化VDB连接...")
        vdb = BaiduVDBMinimal()
        print("✅ 最小化VDB初始化成功")

        # 2. Store a few vectors.
        print("\n2. 测试向量存储...")
        test_contents = [
            "测试文本1",
            "测试文本2",
            "测试文本3"
        ]

        # Random vectors whose width follows the backend's configured
        # dimension rather than a separately hard-coded number.
        test_vectors = np.random.rand(len(test_contents), vdb.vector_dimension).astype(np.float32)

        ids = vdb.store_vectors(test_contents, test_vectors)
        # store_vectors() reports failure by returning an empty list, so a
        # short result must be treated as an error instead of a success.
        if len(ids) != len(test_contents):
            raise RuntimeError(
                f"向量存储失败: 期望 {len(test_contents)} 条, 实际 {len(ids)} 条"
            )
        print(f"✅ 存储了 {len(ids)} 条向量")
        print(f"生成的ID: {ids}")

        # 3. Print the statistics.
        print("\n3. 获取统计信息...")
        stats = vdb.get_statistics()
        print("✅ 统计信息:")
        for key, value in stats.items():
            print(f" {key}: {value}")

        # 4. Run the verification call (its return value is not needed).
        print("\n4. 验证数据存储...")
        vdb.get_all_data()
        print("✅ 数据验证完成")

        print("\n🎉 最小化VDB测试完成!")
        print("✅ 表创建成功(无索引)")
        print("✅ 向量存储成功")
        print("✅ 基本操作正常")
        print("\n📋 下一步:")
        print("1. 表创建成功,说明基本结构没问题")
        print("2. 可以尝试添加向量索引")
        print("3. 测试向量搜索功能")

        return True

    except Exception as e:
        print(f"❌ 测试失败: {e}")
        traceback.print_exc()
        return False

    finally:
        if vdb:
            vdb.close()


if __name__ == "__main__":
    # Propagate the outcome to the shell: exit status 0 on success, 1 on
    # failure, so scripts and CI can detect a failed run.
    sys.exit(0 if test_minimal_vdb() else 1)