#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
修复版百度VDB后端 - 解决Invalid Index Schema错误

基于官方文档规范重新设计表结构

Fixed Baidu VectorDB (Mochow) backend: resolves the "Invalid Index Schema"
error by redesigning the table schema according to the official
documentation.
"""
import hashlib
import json
import logging
import os
import sys
import time
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
import pymochow
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.model.enum import FieldType, IndexType, MetricType, TableState
from pymochow.model.schema import Schema, Field, VectorIndex, HNSWParams
from pymochow.model.table import Row, Partition
from pymochow.model.table import VectorTopkSearchRequest, VectorSearchConfig, FloatVector
# Logging setup: a module-level logger, plus a root configuration at INFO
# level (logging.basicConfig is a no-op if the root logger already has
# handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class BaiduVDBFixed:
    """Baidu VectorDB (pymochow) backend for text and image embeddings.

    Manages one database that holds two tables:

    * text table  -- ``id`` (STRING, primary + partition key), ``content``
      (STRING) and ``vector`` (FLOAT_VECTOR);
    * image table -- ``id``, ``image_path`` (STRING) and ``vector``.

    Each table has a single HNSW index on ``vector`` using the cosine metric.
    Row ids are the MD5 hex digest of the text / image path. Storing the same
    key twice therefore overwrites the existing row through ``upsert``.

    Connection settings that are not passed explicitly are read from the
    environment variables named by ``ENV_ACCOUNT`` / ``ENV_API_KEY`` /
    ``ENV_ENDPOINT``. If those are unset, the legacy ``DEFAULT_*`` values
    are used.
    """

    # Environment variables consulted when a connection setting is omitted.
    ENV_ACCOUNT = "BAIDU_VDB_ACCOUNT"
    ENV_API_KEY = "BAIDU_VDB_API_KEY"
    ENV_ENDPOINT = "BAIDU_VDB_ENDPOINT"

    # Legacy hard-coded defaults, kept so existing zero-argument callers work.
    # NOTE(review): this is a real credential committed to source control;
    # rotate it and provide BAIDU_VDB_API_KEY via the environment instead.
    DEFAULT_ACCOUNT = "root"
    DEFAULT_API_KEY = "vdb$yjr9ln3n0td"
    DEFAULT_ENDPOINT = "http://180.76.96.191:5287"

    # Bounds (seconds) for polling table creation / deletion. The pymochow
    # examples poll because these operations finish asynchronously.
    TABLE_WAIT_TIMEOUT = 300.0
    TABLE_POLL_INTERVAL = 1.0

    def __init__(self,
                 account: Optional[str] = None,
                 api_key: Optional[str] = None,
                 endpoint: Optional[str] = None,
                 database_name: str = "multimodal_fixed",
                 vector_dimension: int = 3584):
        """Connect to VectorDB and make sure the database and both tables exist.

        Args:
            account: Account name. ``None`` -> ``$BAIDU_VDB_ACCOUNT`` ->
                ``DEFAULT_ACCOUNT``.
            api_key: API key. ``None`` -> ``$BAIDU_VDB_API_KEY`` ->
                ``DEFAULT_API_KEY``.
            endpoint: Service endpoint URL. ``None`` -> ``$BAIDU_VDB_ENDPOINT``
                -> ``DEFAULT_ENDPOINT``.
            database_name: Database to use; created if it does not exist.
            vector_dimension: Dimension of the ``vector`` field for newly
                created tables.

        Raises:
            Exception: Anything raised by the SDK while connecting or creating
                the database/tables. The client is closed before re-raising.
        """
        self.account = self._resolve_setting(account, self.ENV_ACCOUNT, self.DEFAULT_ACCOUNT)
        self.api_key = self._resolve_setting(api_key, self.ENV_API_KEY, self.DEFAULT_API_KEY)
        self.endpoint = self._resolve_setting(endpoint, self.ENV_ENDPOINT, self.DEFAULT_ENDPOINT)
        self.database_name = database_name
        self.vector_dimension = vector_dimension

        # Table names.
        self.text_table_name = "text_vectors_v2"
        self.image_table_name = "image_vectors_v2"

        # Populated by _init_connection().
        self.client = None
        self.db = None
        self.text_table = None
        self.image_table = None

        self._init_connection()

    def __enter__(self) -> "BaiduVDBFixed":
        """Support ``with BaiduVDBFixed(...) as vdb:``; the client is closed on exit."""
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        self.close()
        return False  # never suppress exceptions

    @staticmethod
    def _resolve_setting(value: Optional[str], env_var: str, default: str) -> str:
        """Return *value* if given, else the non-empty env var *env_var*, else *default*."""
        if value is not None:
            return value
        return os.environ.get(env_var) or default

    def _init_connection(self):
        """Create the client, then ensure the database and tables exist.

        If any step fails, the partially initialised client is closed before
        the exception propagates. The caller never receives the object in
        that case, so it could not close the client itself.
        """
        try:
            logger.info("🔗 初始化百度VDB连接...")

            config = Configuration(
                credentials=BceCredentials(self.account, self.api_key),
                endpoint=self.endpoint,
            )
            self.client = pymochow.MochowClient(config)
            logger.info("✅ VDB客户端创建成功")

            self._ensure_database()
            self._ensure_tables()

            logger.info("✅ VDB后端初始化完成")
        except Exception as e:
            logger.error(f"❌ VDB连接初始化失败: {e}")
            self.close()
            raise

    def _ensure_database(self):
        """Bind ``self.db`` to ``self.database_name``, creating the database if needed."""
        try:
            existing = {db.database_name for db in self.client.list_databases()}
            if self.database_name not in existing:
                logger.info(f"创建数据库: {self.database_name}")
                self.db = self.client.create_database(self.database_name)
            else:
                logger.info(f"使用现有数据库: {self.database_name}")
                self.db = self.client.database(self.database_name)
        except Exception as e:
            logger.error(f"❌ 数据库操作失败: {e}")
            raise

    def _ensure_tables(self):
        """Bind ``self.text_table`` / ``self.image_table``, creating missing tables."""
        try:
            existing = {table.table_name for table in self.db.list_table()}

            if self.text_table_name not in existing:
                self._create_text_table_fixed()
            else:
                self.text_table = self.db.table(self.text_table_name)
                logger.info(f"✅ 使用现有文本表: {self.text_table_name}")

            if self.image_table_name not in existing:
                self._create_image_table_fixed()
            else:
                self.image_table = self.db.table(self.image_table_name)
                logger.info(f"✅ 使用现有图像表: {self.image_table_name}")
        except Exception as e:
            logger.error(f"❌ 表操作失败: {e}")
            raise

    def _create_vector_table(self, table_name: str, payload_field: str,
                             index_name: str, description: str):
        """Create table ``(id, <payload_field>, vector)`` and wait until it is NORMAL.

        Returns:
            The table object returned by ``Database.create_table``.
        """
        fields = [
            # Primary key and partition key -- STRING, as the schema requires.
            Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True),
            # Payload column (text content or image path), stored as STRING.
            Field(payload_field, FieldType.STRING, not_null=True),
            # Vector column -- the dimension must be declared explicitly.
            Field("vector", FieldType.FLOAT_VECTOR, not_null=True,
                  dimension=self.vector_dimension),
        ]
        # Only a vector index; no secondary indexes.
        indexes = [
            VectorIndex(
                index_name=index_name,
                index_type=IndexType.HNSW,
                field="vector",
                metric_type=MetricType.COSINE,
                params=HNSWParams(m=16, efconstruction=200),
                auto_build=True,
            )
        ]
        table = self.db.create_table(
            table_name=table_name,
            replication=2,  # replica count chosen by the original design
            partition=Partition(partition_num=1),  # single partition
            schema=Schema(fields=fields, indexes=indexes),
            description=description,
        )
        # Writing to a table that is still being created fails, so block here.
        self._wait_for_table_ready(table_name)
        return table

    def _create_text_table_fixed(self):
        """Create the text vector table and bind it to ``self.text_table``."""
        try:
            logger.info(f"创建修复版文本向量表: {self.text_table_name}")
            self.text_table = self._create_vector_table(
                self.text_table_name, "content", "text_vector_index", "修复版文本向量表")
            logger.info(f"✅ 文本表创建成功: {self.text_table_name}")
        except Exception as e:
            logger.error(f"❌ 创建文本表失败: {e}")
            raise

    def _create_image_table_fixed(self):
        """Create the image vector table and bind it to ``self.image_table``."""
        try:
            logger.info(f"创建修复版图像向量表: {self.image_table_name}")
            self.image_table = self._create_vector_table(
                self.image_table_name, "image_path", "image_vector_index", "修复版图像向量表")
            logger.info(f"✅ 图像表创建成功: {self.image_table_name}")
        except Exception as e:
            logger.error(f"❌ 创建图像表失败: {e}")
            raise

    def _wait_for_table_ready(self, table_name: str):
        """Poll ``describe_table`` until the table state is NORMAL.

        Raises:
            TimeoutError: If the table is not NORMAL within ``TABLE_WAIT_TIMEOUT``.
        """
        deadline = time.monotonic() + self.TABLE_WAIT_TIMEOUT
        while self.db.describe_table(table_name).state != TableState.NORMAL:
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"表 {table_name} 在 {self.TABLE_WAIT_TIMEOUT} 秒内未进入 NORMAL 状态")
            time.sleep(self.TABLE_POLL_INTERVAL)

    def _wait_for_table_dropped(self, table_name: str):
        """Poll ``list_table`` until *table_name* no longer appears.

        Raises:
            TimeoutError: If the table is still listed after ``TABLE_WAIT_TIMEOUT``.
        """
        deadline = time.monotonic() + self.TABLE_WAIT_TIMEOUT
        while table_name in {t.table_name for t in self.db.list_table()}:
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"表 {table_name} 在 {self.TABLE_WAIT_TIMEOUT} 秒内未删除完成")
            time.sleep(self.TABLE_POLL_INTERVAL)

    def _generate_id(self, content: str) -> str:
        """Return the MD5 hex digest of *content* (UTF-8), used as the row id."""
        return hashlib.md5(content.encode('utf-8')).hexdigest()

    @staticmethod
    def _to_float_list(vector: Any) -> List[float]:
        """Flatten an ndarray / sequence of numbers into a plain list of Python floats.

        Accepts lists as well as arrays. A ``(1, dim)`` query also becomes a
        flat list instead of a nested one.
        """
        return np.asarray(vector, dtype=float).ravel().tolist()

    @staticmethod
    def _parse_search_rows(response: Any, payload_field: str) -> List[Tuple[str, str, float]]:
        """Flatten a vector-search response into ``(id, payload, score)`` tuples.

        Per the VectorDB search API, the response carries a ``rows`` list whose
        items look like ``{"row": {...}, "score": ..., "distance": ...}``.
        Also accepted: a dict with a ``rows`` key, and a plain list of flat row
        dicts with the score under ``_score`` (the shape the earlier parser
        assumed).
        """
        if response is None:
            return []
        if isinstance(response, dict):
            items = response.get("rows") or []
        else:
            items = getattr(response, "rows", response) or []

        parsed = []
        for item in items:
            row = item.get("row") or item
            score = item.get("score")
            if score is None:
                score = item.get("_score", 0.0)
            parsed.append((row.get("id", ""), row.get(payload_field, ""), float(score)))
        return parsed

    @staticmethod
    def _row_count(stats_response: Any) -> Optional[Any]:
        """Extract a row count from a ``Table.stats()`` response, or ``None`` if absent.

        The exact response shape is not visible here, so both attribute-style
        and dict-style access are tried, with snake_case and camelCase names.
        """
        keys = ("row_count", "rowCount")
        if isinstance(stats_response, dict):
            for key in keys:
                if key in stats_response:
                    return stats_response[key]
            return None
        for key in keys:
            value = getattr(stats_response, key, None)
            if value is not None:
                return value
        return None

    def _store_vectors(self, table, payload_field: str, keys: List[str],
                       vectors: Any, label: str) -> List[str]:
        """Upsert one row per (key, vector) pair into *table*.

        Args:
            table: Target pymochow table.
            payload_field: Column receiving each key ("content" / "image_path").
            keys: Texts or image paths; each key's MD5 becomes the row id.
            vectors: One vector per key (2-D ndarray or sequence of sequences).
            label: "文本" / "图像", used only in log messages.

        Returns:
            The row ids in input order, or ``[]`` on failure or empty input.
            Errors are logged, not raised.
        """
        try:
            if len(keys) != len(vectors):
                raise ValueError(f"{label}数量与向量数量不匹配")
            if len(keys) == 0:
                return []  # nothing to write; avoid an empty upsert request

            logger.info(f"存储 {len(keys)} 条{label}向量...")

            ids = [self._generate_id(key) for key in keys]
            rows = [
                Row(**{"id": doc_id, payload_field: key, "vector": self._to_float_list(vec)})
                for doc_id, key, vec in zip(ids, keys, vectors)
            ]

            # Batch upsert: existing ids are overwritten.
            table.upsert(rows)
            logger.info(f"✅ 成功存储 {len(keys)} 条{label}向量")
            return ids
        except Exception as e:
            logger.error(f"❌ 存储{label}向量失败: {e}")
            return []

    def store_text_vectors(self, texts: List[str], vectors: np.ndarray) -> List[str]:
        """Store text embeddings; return row ids (``[]`` on failure)."""
        return self._store_vectors(self.text_table, "content", texts, vectors, "文本")

    def store_image_vectors(self, image_paths: List[str], vectors: np.ndarray) -> List[str]:
        """Store image embeddings keyed by image path; return row ids (``[]`` on failure)."""
        return self._store_vectors(self.image_table, "image_path", image_paths, vectors, "图像")

    def _search_vectors(self, table, payload_field: str, query_vector: Any,
                        top_k: int, label: str) -> List[Tuple[str, str, float]]:
        """Run a top-k HNSW search on *table*.

        Returns:
            ``(id, payload, score)`` tuples, or ``[]`` when ``top_k <= 0`` or
            the search fails. Errors are logged, not raised.
        """
        if top_k <= 0:
            return []
        try:
            logger.info(f"搜索{label}向量,top_k={top_k}")

            request = VectorTopkSearchRequest(
                vector_field="vector",
                vector=FloatVector(self._to_float_list(query_vector)),
                limit=top_k,
                # HNSW's search breadth should not be smaller than the number
                # of requested results.
                config=VectorSearchConfig(ef=max(200, top_k)),
            )
            response = table.vector_search(request=request)
            results = self._parse_search_rows(response, payload_field)

            logger.info(f"✅ {label}向量搜索完成,返回 {len(results)} 条结果")
            return results
        except Exception as e:
            logger.error(f"❌ {label}向量搜索失败: {e}")
            return []

    def search_text_vectors(self, query_vector: np.ndarray, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """Return up to *top_k* ``(id, content, score)`` matches from the text table."""
        return self._search_vectors(self.text_table, "content", query_vector, top_k, "文本")

    def search_image_vectors(self, query_vector: np.ndarray, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """Return up to *top_k* ``(id, image_path, score)`` matches from the image table."""
        return self._search_vectors(self.image_table, "image_path", query_vector, top_k, "图像")

    def _table_row_count(self, table, label: str) -> Any:
        """Return *table*'s row count, or "unknown" if it cannot be determined."""
        try:
            count = self._row_count(table.stats())
        except Exception as e:
            logger.warning(f"获取{label}表统计信息失败: {e}")
            return "unknown"
        return "unknown" if count is None else count

    def get_statistics(self) -> Dict[str, Any]:
        """Return connection settings plus per-table row counts ("unknown" if unavailable)."""
        try:
            stats = {
                "database_name": self.database_name,
                "text_table": self.text_table_name,
                "image_table": self.image_table_name,
                "vector_dimension": self.vector_dimension,
                "status": "connected",
            }
            stats["text_count"] = self._table_row_count(self.text_table, "文本")
            stats["image_count"] = self._table_row_count(self.image_table, "图像")
            return stats
        except Exception as e:
            logger.error(f"❌ 获取统计信息失败: {e}")
            return {"status": "error", "error": str(e)}

    def _drop_table_if_exists(self, table_name: str, label: str):
        """Drop *table_name* if it is listed, then wait until the drop completes."""
        if table_name not in {t.table_name for t in self.db.list_table()}:
            return
        self.db.drop_table(table_name)
        logger.info(f"✅ 删除{label}表: {table_name}")
        self._wait_for_table_dropped(table_name)

    def clear_all_data(self):
        """Drop both tables and recreate them empty.

        Best-effort: failures are logged instead of raised. Unlike before, a
        failed drop is no longer silently ignored and then reported as a
        successful clear.
        """
        try:
            logger.info("清空所有数据...")

            self._drop_table_if_exists(self.text_table_name, "文本")
            self._drop_table_if_exists(self.image_table_name, "图像")
            # The old handles refer to dropped tables.
            self.text_table = None
            self.image_table = None

            self._ensure_tables()
            logger.info("✅ 数据清空完成")
        except Exception as e:
            logger.error(f"❌ 清空数据失败: {e}")

    def close(self):
        """Close the client connection. Safe to call more than once."""
        try:
            if self.client is not None:
                self.client.close()
                logger.info("✅ VDB连接已关闭")
        except Exception as e:
            logger.error(f"❌ 关闭连接失败: {e}")
        finally:
            self.client = None
def test_fixed_vdb() -> bool:
    """Smoke-test BaiduVDBFixed against the configured live VectorDB instance.

    The test connects, upserts a few random text vectors, runs one top-k
    search and prints table statistics. ``store_text_vectors`` and
    ``search_text_vectors`` swallow their own errors and return empty lists,
    so their return values are checked explicitly rather than assuming success.

    Returns:
        True if storage and search both produced results, False otherwise.
    """
    print("=" * 60)
    print("测试修复版百度VDB后端")
    print("=" * 60)

    vdb = None
    try:
        # 1. Connect (also creates the database/tables if missing).
        print("1. 初始化VDB连接...")
        vdb = BaiduVDBFixed()
        print("✅ VDB初始化成功")

        # 2. Store a few text vectors.
        print("\n2. 测试文本向量存储...")
        test_texts = [
            "这是一个测试文本",
            "另一个测试文本",
            "第三个测试文本",
        ]
        # Match the table schema instead of hard-coding the dimension.
        dim = vdb.vector_dimension
        test_vectors = np.random.rand(len(test_texts), dim).astype(np.float32)

        text_ids = vdb.store_text_vectors(test_texts, test_vectors)
        if len(text_ids) != len(test_texts):
            raise RuntimeError(
                f"文本向量存储失败: 期望 {len(test_texts)} 条, 实际 {len(text_ids)} 条")
        print(f"✅ 存储了 {len(text_ids)} 条文本向量")

        # 3. Search with a random query vector.
        print("\n3. 测试文本向量搜索...")
        query_vector = np.random.rand(dim).astype(np.float32)
        search_results = vdb.search_text_vectors(query_vector, top_k=2)

        print(f"搜索结果 ({len(search_results)} 条):")
        for i, (_doc_id, content, score) in enumerate(search_results, 1):
            print(f" {i}. {content[:30]}... (相似度: {score:.4f})")
        search_ok = bool(search_results)
        if not search_ok:
            print("⚠️ 搜索未返回结果(索引可能尚未构建完成)")

        # 4. Statistics.
        print("\n4. 获取统计信息...")
        stats = vdb.get_statistics()
        print(f"✅ 统计信息: {stats}")

        print("\n🎉 修复版VDB测试完成!")
        print("✅ 表创建成功")
        print("✅ 向量存储成功")
        print("✅ 向量搜索成功" if search_ok else "❌ 向量搜索未返回结果")

        return search_ok

    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        if vdb is not None:
            vdb.close()
if __name__ == "__main__":
    # Propagate the smoke-test result as the process exit status so shell/CI
    # callers can detect a failed run.
    sys.exit(0 if test_fixed_vdb() else 1)