mmeb/baidu_vdb_fixed.py
2025-09-01 11:24:01 +00:00

483 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
修复版百度VDB后端 - 解决Invalid Index Schema错误
基于官方文档规范重新设计表结构
"""
import os
import sys
import numpy as np
import json
import hashlib
import time
import logging
from typing import List, Tuple, Dict, Any, Optional
import pymochow
from pymochow.configuration import Configuration
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.model.schema import Schema, Field, VectorIndex, HNSWParams
from pymochow.model.enum import FieldType, IndexType, MetricType
from pymochow.model.table import Row, Partition
from pymochow.model.table import VectorTopkSearchRequest, VectorSearchConfig, FloatVector
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class BaiduVDBFixed:
"""修复版百度VDB后端类"""
def __init__(self,
account: str = "root",
api_key: str = "vdb$yjr9ln3n0td",
endpoint: str = "http://180.76.96.191:5287",
database_name: str = "multimodal_fixed",
vector_dimension: int = 3584):
"""
初始化VDB连接
Args:
account: 账户名
api_key: API密钥
endpoint: 服务端点
database_name: 数据库名称
vector_dimension: 向量维度
"""
self.account = account
self.api_key = api_key
self.endpoint = endpoint
self.database_name = database_name
self.vector_dimension = vector_dimension
# 表名
self.text_table_name = "text_vectors_v2"
self.image_table_name = "image_vectors_v2"
# 初始化连接
self.client = None
self.db = None
self.text_table = None
self.image_table = None
self._init_connection()
def _init_connection(self):
"""初始化数据库连接"""
try:
logger.info("🔗 初始化百度VDB连接...")
# 创建配置
config = Configuration(
credentials=BceCredentials(self.account, self.api_key),
endpoint=self.endpoint
)
# 创建客户端
self.client = pymochow.MochowClient(config)
logger.info("✅ VDB客户端创建成功")
# 确保数据库存在
self._ensure_database()
# 确保表存在
self._ensure_tables()
logger.info("✅ VDB后端初始化完成")
except Exception as e:
logger.error(f"❌ VDB连接初始化失败: {e}")
raise
def _ensure_database(self):
"""确保数据库存在"""
try:
# 检查数据库是否存在
db_list = self.client.list_databases()
db_names = [db.database_name for db in db_list]
if self.database_name not in db_names:
logger.info(f"创建数据库: {self.database_name}")
self.db = self.client.create_database(self.database_name)
else:
logger.info(f"使用现有数据库: {self.database_name}")
self.db = self.client.database(self.database_name)
except Exception as e:
logger.error(f"❌ 数据库操作失败: {e}")
raise
def _ensure_tables(self):
"""确保表存在"""
try:
# 获取现有表列表
table_list = self.db.list_table()
table_names = [table.table_name for table in table_list]
# 创建文本表
if self.text_table_name not in table_names:
self._create_text_table_fixed()
else:
self.text_table = self.db.table(self.text_table_name)
logger.info(f"✅ 使用现有文本表: {self.text_table_name}")
# 创建图像表
if self.image_table_name not in table_names:
self._create_image_table_fixed()
else:
self.image_table = self.db.table(self.image_table_name)
logger.info(f"✅ 使用现有图像表: {self.image_table_name}")
except Exception as e:
logger.error(f"❌ 表操作失败: {e}")
raise
def _create_text_table_fixed(self):
"""创建修复版文本向量表"""
try:
logger.info(f"创建修复版文本向量表: {self.text_table_name}")
# 定义字段 - 严格按照官方文档规范
fields = [
# 主键和分区键 - 必须是STRING类型
Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True),
# 文本内容 - 使用STRING而不是TEXT
Field("content", FieldType.STRING, not_null=True),
# 向量字段 - 必须指定维度
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
]
# 定义索引 - 只创建向量索引,避免复杂的二级索引
indexes = [
VectorIndex(
index_name="text_vector_index",
index_type=IndexType.HNSW,
field="vector",
metric_type=MetricType.COSINE,
params=HNSWParams(m=16, efconstruction=200), # 使用较小的参数
auto_build=True
)
]
# 创建Schema
schema = Schema(fields=fields, indexes=indexes)
# 创建表 - 使用较小的副本数和分区数
self.text_table = self.db.create_table(
table_name=self.text_table_name,
replication=2, # 最小副本数
partition=Partition(partition_num=1), # 单分区
schema=schema,
description="修复版文本向量表"
)
logger.info(f"✅ 文本表创建成功: {self.text_table_name}")
except Exception as e:
logger.error(f"❌ 创建文本表失败: {e}")
raise
def _create_image_table_fixed(self):
"""创建修复版图像向量表"""
try:
logger.info(f"创建修复版图像向量表: {self.image_table_name}")
# 定义字段 - 严格按照官方文档规范
fields = [
# 主键和分区键
Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True),
# 图像路径
Field("image_path", FieldType.STRING, not_null=True),
# 向量字段
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
]
# 定义索引 - 只创建向量索引
indexes = [
VectorIndex(
index_name="image_vector_index",
index_type=IndexType.HNSW,
field="vector",
metric_type=MetricType.COSINE,
params=HNSWParams(m=16, efconstruction=200),
auto_build=True
)
]
# 创建Schema
schema = Schema(fields=fields, indexes=indexes)
# 创建表
self.image_table = self.db.create_table(
table_name=self.image_table_name,
replication=2,
partition=Partition(partition_num=1),
schema=schema,
description="修复版图像向量表"
)
logger.info(f"✅ 图像表创建成功: {self.image_table_name}")
except Exception as e:
logger.error(f"❌ 创建图像表失败: {e}")
raise
def _generate_id(self, content: str) -> str:
"""生成唯一ID"""
return hashlib.md5(content.encode('utf-8')).hexdigest()
def store_text_vectors(self, texts: List[str], vectors: np.ndarray) -> List[str]:
"""存储文本向量"""
try:
if len(texts) != len(vectors):
raise ValueError("文本数量与向量数量不匹配")
logger.info(f"存储 {len(texts)} 条文本向量...")
rows = []
ids = []
for i, (text, vector) in enumerate(zip(texts, vectors)):
doc_id = self._generate_id(text)
ids.append(doc_id)
row = Row(
id=doc_id,
content=text,
vector=vector.tolist()
)
rows.append(row)
# 批量插入
self.text_table.upsert(rows)
logger.info(f"✅ 成功存储 {len(texts)} 条文本向量")
return ids
except Exception as e:
logger.error(f"❌ 存储文本向量失败: {e}")
return []
def store_image_vectors(self, image_paths: List[str], vectors: np.ndarray) -> List[str]:
"""存储图像向量"""
try:
if len(image_paths) != len(vectors):
raise ValueError("图像数量与向量数量不匹配")
logger.info(f"存储 {len(image_paths)} 条图像向量...")
rows = []
ids = []
for i, (image_path, vector) in enumerate(zip(image_paths, vectors)):
doc_id = self._generate_id(image_path)
ids.append(doc_id)
row = Row(
id=doc_id,
image_path=image_path,
vector=vector.tolist()
)
rows.append(row)
# 批量插入
self.image_table.upsert(rows)
logger.info(f"✅ 成功存储 {len(image_paths)} 条图像向量")
return ids
except Exception as e:
logger.error(f"❌ 存储图像向量失败: {e}")
return []
def search_text_vectors(self, query_vector: np.ndarray, top_k: int = 5) -> List[Tuple[str, str, float]]:
"""搜索文本向量"""
try:
logger.info(f"搜索文本向量top_k={top_k}")
# 创建搜索请求
request = VectorTopkSearchRequest(
vector_field="vector",
vector=FloatVector(query_vector.tolist()),
limit=top_k,
config=VectorSearchConfig(ef=200)
)
# 执行搜索
results = self.text_table.vector_search(request=request)
# 解析结果
search_results = []
for result in results:
doc_id = result.get('id', '')
content = result.get('content', '')
score = result.get('_score', 0.0)
search_results.append((doc_id, content, float(score)))
logger.info(f"✅ 文本向量搜索完成,返回 {len(search_results)} 条结果")
return search_results
except Exception as e:
logger.error(f"❌ 文本向量搜索失败: {e}")
return []
def search_image_vectors(self, query_vector: np.ndarray, top_k: int = 5) -> List[Tuple[str, str, float]]:
"""搜索图像向量"""
try:
logger.info(f"搜索图像向量top_k={top_k}")
# 创建搜索请求
request = VectorTopkSearchRequest(
vector_field="vector",
vector=FloatVector(query_vector.tolist()),
limit=top_k,
config=VectorSearchConfig(ef=200)
)
# 执行搜索
results = self.image_table.vector_search(request=request)
# 解析结果
search_results = []
for result in results:
doc_id = result.get('id', '')
image_path = result.get('image_path', '')
score = result.get('_score', 0.0)
search_results.append((doc_id, image_path, float(score)))
logger.info(f"✅ 图像向量搜索完成,返回 {len(search_results)} 条结果")
return search_results
except Exception as e:
logger.error(f"❌ 图像向量搜索失败: {e}")
return []
def get_statistics(self) -> Dict[str, Any]:
"""获取统计信息"""
try:
stats = {
"database_name": self.database_name,
"text_table": self.text_table_name,
"image_table": self.image_table_name,
"vector_dimension": self.vector_dimension,
"status": "connected"
}
# 尝试获取表统计信息
try:
text_stats = self.text_table.stats()
stats["text_count"] = text_stats.get("row_count", 0)
except:
stats["text_count"] = "unknown"
try:
image_stats = self.image_table.stats()
stats["image_count"] = image_stats.get("row_count", 0)
except:
stats["image_count"] = "unknown"
return stats
except Exception as e:
logger.error(f"❌ 获取统计信息失败: {e}")
return {"status": "error", "error": str(e)}
def clear_all_data(self):
"""清空所有数据"""
try:
logger.info("清空所有数据...")
# 删除表(如果存在)
try:
self.db.drop_table(self.text_table_name)
logger.info(f"✅ 删除文本表: {self.text_table_name}")
except:
pass
try:
self.db.drop_table(self.image_table_name)
logger.info(f"✅ 删除图像表: {self.image_table_name}")
except:
pass
# 重新创建表
self._ensure_tables()
logger.info("✅ 数据清空完成")
except Exception as e:
logger.error(f"❌ 清空数据失败: {e}")
def close(self):
"""关闭连接"""
try:
if self.client:
self.client.close()
logger.info("✅ VDB连接已关闭")
except Exception as e:
logger.error(f"❌ 关闭连接失败: {e}")
def test_fixed_vdb():
"""测试修复版VDB"""
print("=" * 60)
print("测试修复版百度VDB后端")
print("=" * 60)
vdb = None
try:
# 1. 初始化VDB
print("1. 初始化VDB连接...")
vdb = BaiduVDBFixed()
print("✅ VDB初始化成功")
# 2. 测试文本向量存储
print("\n2. 测试文本向量存储...")
test_texts = [
"这是一个测试文本",
"另一个测试文本",
"第三个测试文本"
]
# 生成随机向量用于测试
test_vectors = np.random.rand(len(test_texts), 3584).astype(np.float32)
text_ids = vdb.store_text_vectors(test_texts, test_vectors)
print(f"✅ 存储了 {len(text_ids)} 条文本向量")
# 3. 测试文本向量搜索
print("\n3. 测试文本向量搜索...")
query_vector = np.random.rand(3584).astype(np.float32)
search_results = vdb.search_text_vectors(query_vector, top_k=2)
print(f"搜索结果 ({len(search_results)} 条):")
for i, (doc_id, content, score) in enumerate(search_results, 1):
print(f" {i}. {content[:30]}... (相似度: {score:.4f})")
# 4. 获取统计信息
print("\n4. 获取统计信息...")
stats = vdb.get_statistics()
print(f"✅ 统计信息: {stats}")
print(f"\n🎉 修复版VDB测试完成")
print("✅ 表创建成功")
print("✅ 向量存储成功")
print("✅ 向量搜索成功")
return True
except Exception as e:
print(f"❌ 测试失败: {e}")
import traceback
traceback.print_exc()
return False
finally:
if vdb:
vdb.close()
if __name__ == "__main__":
test_fixed_vdb()