484 lines
17 KiB
Python
484 lines
17 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
百度VDB向量数据库后端
|
||
支持多模态向量存储和检索
|
||
"""
|
||
|
||
import pymochow
|
||
from pymochow.configuration import Configuration
|
||
from pymochow.auth.bce_credentials import BceCredentials
|
||
from pymochow.model.schema import Schema, Field, VectorIndex, HNSWParams, SecondaryIndex
|
||
from pymochow.model.enum import FieldType, IndexType, MetricType
|
||
from pymochow.model.table import Partition, Row, VectorTopkSearchRequest, FloatVector, VectorSearchConfig
|
||
import numpy as np
|
||
import json
|
||
import time
|
||
import logging
|
||
from typing import List, Tuple, Union, Dict, Any
|
||
import hashlib
|
||
import os
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class BaiduVDBBackend:
    """Baidu VectorDB (Mochow) backend for multimodal vector storage and retrieval.

    Maintains two tables inside one database:

    * ``text_vectors``  -- columns ``id``, ``text_content``, ``vector``
    * ``image_vectors`` -- columns ``id``, ``image_path``, ``vector``

    Both tables index ``vector`` (a FLOAT_VECTOR of ``vector_dimension``
    elements) with HNSW and the cosine metric. Connection, database and tables
    are set up eagerly in ``__init__``. The instance is also a context manager
    that closes the client on exit.
    """

    # NOTE(review): the default ``api_key`` and ``endpoint`` embed a credential
    # and a server address in source code. Prefer passing them explicitly
    # (e.g. from environment/config) and rotate this key.
    def __init__(self, account: str = "root", api_key: str = "vdb$yjr9ln3n0td",
                 endpoint: str = "http://180.76.96.191:5287",
                 database_name: str = "multimodal_retrieval"):
        """Connect to Baidu VDB and make sure the database and tables exist.

        Args:
            account: Account (user) name.
            api_key: API key for the account.
            endpoint: Server endpoint URL.
            database_name: Database to use; it is created if missing.

        Raises:
            Exception: Whatever the SDK raises while connecting or while
                creating/opening the database and tables. If setup fails after
                the client was created, the client is closed before re-raising.
        """
        self.account = account
        self.api_key = api_key
        self.endpoint = endpoint
        self.database_name = database_name

        # Handles filled in by _connect / _ensure_database / _ensure_tables.
        self.client = None
        self.db = None
        self.text_table = None
        self.image_table = None

        # Table names.
        self.text_table_name = "text_vectors"
        self.image_table_name = "image_vectors"

        # Embedding size of the Ops-MM-embedding-v1-7B model; vectors passed to
        # store_* / search_* must have this many elements.
        self.vector_dimension = 3584

        self._connect()
        try:
            self._ensure_database()
            self._ensure_tables()
        except Exception:
            # Do not leak the connected client when initialisation fails.
            self.close()
            raise

    def _connect(self):
        """Create the pymochow client for ``self.endpoint``."""
        try:
            config = Configuration(
                credentials=BceCredentials(self.account, self.api_key),
                endpoint=self.endpoint,
            )
            self.client = pymochow.MochowClient(config)
            logger.info(f"✅ 成功连接到百度VDB: {self.endpoint}")
        except Exception as e:
            logger.error(f"❌ 连接百度VDB失败: {e}")
            raise

    def _ensure_database(self):
        """Open ``self.database_name``, creating it first if it does not exist."""
        try:
            existing_dbs = {db.database_name for db in self.client.list_databases()}

            if self.database_name not in existing_dbs:
                logger.info(f"创建数据库: {self.database_name}")
                self.db = self.client.create_database(self.database_name)
            else:
                logger.info(f"使用现有数据库: {self.database_name}")
                self.db = self.client.database(self.database_name)
        except Exception as e:
            logger.error(f"❌ 数据库操作失败: {e}")
            raise

    def _ensure_tables(self):
        """Open the text and image tables, creating any that do not exist yet."""
        try:
            existing_table_names = {table.table_name for table in self.db.list_table()}

            if self.text_table_name not in existing_table_names:
                self._create_text_table()
            else:
                self.text_table = self.db.table(self.text_table_name)
                logger.info(f"使用现有文本表: {self.text_table_name}")

            if self.image_table_name not in existing_table_names:
                self._create_image_table()
            else:
                self.image_table = self.db.table(self.image_table_name)
                logger.info(f"使用现有图像表: {self.image_table_name}")
        except Exception as e:
            logger.error(f"❌ 表操作失败: {e}")
            raise

    def _create_vector_table(self, table_name: str, content_field: str, index_name: str):
        """Create a table ``(id, <content_field>, vector)`` and return its handle.

        ``id`` is the string primary key. ``vector`` is indexed with HNSW
        (m=16, efconstruction=100) using the cosine metric. The table is
        single-replica.
        """
        fields = [
            Field("id", FieldType.STRING, primary_key=True, not_null=True),
            Field(content_field, FieldType.STRING, not_null=True),
            Field("vector", FieldType.FLOAT_VECTOR, not_null=True,
                  dimension=self.vector_dimension),
        ]
        indexes = [
            VectorIndex(
                index_name=index_name,
                index_type=IndexType.HNSW,
                field="vector",
                metric_type=MetricType.COSINE,
                params=HNSWParams(m=16, efconstruction=100),
                auto_build=True,
            )
        ]
        table = self.db.create_table(
            table_name=table_name,
            replication=1,  # single replica
            schema=Schema(fields=fields, indexes=indexes),
        )
        self._wait_table_ready(table_name)
        return table

    def _wait_table_ready(self, table_name: str, timeout: float = 60.0, interval: float = 1.0):
        """Poll a freshly created table until it reports the NORMAL state.

        The pymochow examples wait for ``describe_table(...).state`` to become
        ``TableState.NORMAL`` before writing to a new table. This is best
        effort: if the state cannot be queried, or the timeout expires, a
        warning is logged and creation proceeds.
        """
        try:
            from pymochow.model.enum import TableState
        except ImportError:
            logger.warning(f"⚠️ 无法检查表状态(缺少 TableState): {table_name}")
            return

        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                state = self.db.describe_table(table_name).state
            except Exception as e:
                logger.warning(f"⚠️ 查询表状态失败 {table_name}: {e}")
                return
            if state == TableState.NORMAL:
                return
            time.sleep(interval)
        logger.warning(f"⚠️ 等待表就绪超时: {table_name}")

    def _create_text_table(self):
        """Create the text vector table and store its handle in ``self.text_table``."""
        try:
            logger.info(f"创建文本向量表: {self.text_table_name}")
            self.text_table = self._create_vector_table(
                self.text_table_name, "text_content", "text_vector_idx")
            logger.info("✅ 文本向量表创建成功")
        except Exception as e:
            logger.error(f"❌ 创建文本表失败: {e}")
            raise

    def _create_image_table(self):
        """Create the image vector table and store its handle in ``self.image_table``."""
        try:
            logger.info(f"创建图像向量表: {self.image_table_name}")
            self.image_table = self._create_vector_table(
                self.image_table_name, "image_path", "image_vector_idx")
            logger.info("✅ 图像向量表创建成功")
        except Exception as e:
            logger.error(f"❌ 创建图像表失败: {e}")
            raise

    def _generate_id(self, content: str) -> str:
        """Return a fresh 32-character hex id for a row derived from ``content``.

        The content and the current time are combined with 16 random bytes.
        Identical contents stored within the same clock tick (e.g. duplicates
        in one batch) therefore still get distinct ids instead of overwriting
        each other on upsert.
        """
        salt = os.urandom(16).hex()
        return hashlib.md5(f"{content}_{time.time()}_{salt}".encode()).hexdigest()

    def _upsert_rows(self, table, content_field: str, contents, vectors) -> List[str]:
        """Upsert one row per (content, vector) pair in a single call; return the new ids."""
        ids = []
        rows = []
        for content, vector in zip(contents, vectors):
            doc_id = self._generate_id(content)
            ids.append(doc_id)
            # np.asarray lets callers pass plain lists as well as ndarrays.
            rows.append(Row(**{
                "id": doc_id,
                content_field: content,
                "vector": np.asarray(vector).tolist(),
            }))
        table.upsert(rows)
        return ids

    def store_text_vectors(self, texts: List[str], vectors: np.ndarray, metadata: List[Dict] = None) -> List[str]:
        """Store text embeddings.

        Args:
            texts: Texts to store (one row each).
            vectors: One embedding per text, each of ``vector_dimension`` floats.
            metadata: Accepted for API compatibility but NOT persisted: the
                table schema has no metadata column.

        Returns:
            The generated ids, in input order; ``[]`` for empty input.

        Raises:
            ValueError: If ``texts`` and ``vectors`` differ in length.
            Exception: Whatever the SDK raises on upsert (logged, then re-raised).
        """
        if len(texts) != len(vectors):
            raise ValueError("文本数量与向量数量不匹配")
        if len(texts) == 0:
            return []

        try:
            ids = self._upsert_rows(self.text_table, "text_content", texts, vectors)
            logger.info(f"✅ 成功存储 {len(texts)} 条文本向量")
            return ids
        except Exception as e:
            logger.error(f"❌ 存储文本向量失败: {e}")
            raise

    def store_image_vectors(self, image_paths: List[str], vectors: np.ndarray, metadata: List[Dict] = None) -> List[str]:
        """Store image embeddings keyed by image path.

        Args:
            image_paths: Image paths to store (one row each).
            vectors: One embedding per image, each of ``vector_dimension`` floats.
            metadata: Accepted for API compatibility but NOT persisted: the
                table schema has no metadata column.

        Returns:
            The generated ids, in input order; ``[]`` for empty input.

        Raises:
            ValueError: If ``image_paths`` and ``vectors`` differ in length.
            Exception: Whatever the SDK raises on upsert (logged, then re-raised).
        """
        if len(image_paths) != len(vectors):
            raise ValueError("图像数量与向量数量不匹配")
        if len(image_paths) == 0:
            return []

        try:
            ids = self._upsert_rows(self.image_table, "image_path", image_paths, vectors)
            logger.info(f"✅ 成功存储 {len(image_paths)} 条图像向量")
            return ids
        except Exception as e:
            logger.error(f"❌ 存储图像向量失败: {e}")
            raise

    @staticmethod
    def _build_search_request(query_vector, top_k: int, filter_condition: str):
        """Build a top-k HNSW search request against the ``vector`` field."""
        return VectorTopkSearchRequest(
            vector_field="vector",
            vector=FloatVector(np.asarray(query_vector).tolist()),
            limit=top_k,
            filter=filter_condition,
            # HNSW can only return up to `ef` candidates, so keep ef >= top_k.
            config=VectorSearchConfig(ef=max(200, top_k)),
        )

    @staticmethod
    def _iter_hits(results):
        """Yield ``(row_fields, score)`` pairs from a ``vector_search`` response.

        Accepts:

        * an object with a ``.rows`` attribute, or a dict with a ``"rows"`` key,
          whose hits look like ``{"row": {...}, "score"/"distance": ...}`` (the
          shape shown in the pymochow examples);
        * a plain iterable of flat dicts carrying ``_score``.

        The score is the first non-None of ``_score``, ``score``, ``distance``,
        defaulting to 0.0.
        NOTE(review): confirm which field carries the cosine similarity for the
        SDK/server version in use.
        """
        if results is None:
            return
        if isinstance(results, dict):
            hits = results.get("rows") or []
        else:
            hits = getattr(results, "rows", results) or []
        for hit in hits:
            row = hit.get("row") or hit
            score = next(
                (hit[key] for key in ("_score", "score", "distance") if hit.get(key) is not None),
                0.0,
            )
            yield row, float(score)

    def search_text_vectors(self, query_vector: np.ndarray, top_k: int = 5,
                            filter_condition: str = None) -> List[Tuple[str, str, float, Dict]]:
        """Search the text table for the ``top_k`` nearest vectors.

        Args:
            query_vector: Query embedding (``vector_dimension`` floats).
            top_k: Maximum number of results.
            filter_condition: Optional Mochow scalar filter expression.

        Returns:
            A list of ``(id, text_content, score, metadata)`` tuples. ``metadata``
            is always ``{}`` because it is not stored. Returns ``[]`` on any
            error (the error is logged).
        """
        try:
            request = self._build_search_request(query_vector, top_k, filter_condition)
            results = self.text_table.vector_search(request=request)

            search_results = [
                (row.get('id', ''), row.get('text_content', ''), score, {})
                for row, score in self._iter_hits(results)
            ]

            logger.info(f"✅ 文本向量搜索完成,返回 {len(search_results)} 条结果")
            return search_results
        except Exception as e:
            logger.error(f"❌ 文本向量搜索失败: {e}")
            return []

    def search_image_vectors(self, query_vector: np.ndarray, top_k: int = 5,
                             filter_condition: str = None) -> List[Tuple[str, str, str, float, Dict]]:
        """Search the image table for the ``top_k`` nearest vectors.

        Args:
            query_vector: Query embedding (``vector_dimension`` floats).
            top_k: Maximum number of results.
            filter_condition: Optional Mochow scalar filter expression.

        Returns:
            A list of ``(id, image_path, image_name, score, metadata)`` tuples.
            ``image_name`` is the basename of ``image_path``, and ``metadata``
            is always ``{}``. Returns ``[]`` on any error (the error is logged).
        """
        try:
            request = self._build_search_request(query_vector, top_k, filter_condition)
            results = self.image_table.vector_search(request=request)

            search_results = []
            for row, score in self._iter_hits(results):
                image_path = row.get('image_path', '')
                search_results.append(
                    (row.get('id', ''), image_path, os.path.basename(image_path), score, {})
                )

            logger.info(f"✅ 图像向量搜索完成,返回 {len(search_results)} 条结果")
            return search_results
        except Exception as e:
            logger.error(f"❌ 图像向量搜索失败: {e}")
            return []

    @staticmethod
    def _table_stats(table) -> Dict[str, Any]:
        """Summarise ``table.stats()`` as a row count plus memory/disk usage in MiB."""
        raw = table.stats()

        def read(camel: str, snake: str):
            # Dict-like responses are read by camelCase key. Otherwise fall back
            # to attributes. NOTE(review): confirm the attribute names against
            # the pymochow version in use.
            if hasattr(raw, "get"):
                value = raw.get(camel, 0)
            else:
                value = getattr(raw, snake, getattr(raw, camel, 0))
            return value or 0

        mib = 1024 * 1024
        return {
            'row_count': read('rowCount', 'row_count'),
            'memory_size_mb': read('memorySizeInByte', 'memory_size_in_byte') / mib,
            'disk_size_mb': read('diskSizeInByte', 'disk_size_in_byte') / mib,
        }

    def get_statistics(self) -> Dict[str, Any]:
        """Return per-table statistics keyed by ``text_table`` / ``image_table``.

        Returns ``{}`` if any stats call fails (the error is logged).
        """
        try:
            stats = {}
            for key, table in (("text_table", self.text_table), ("image_table", self.image_table)):
                if table is not None:
                    stats[key] = self._table_stats(table)
            return stats
        except Exception as e:
            logger.error(f"❌ 获取统计信息失败: {e}")
            return {}

    def clear_all_data(self):
        """Delete every row from both tables (filter ``"*"``).

        Raises:
            Exception: Whatever the SDK raises on delete (logged, then re-raised).
        """
        try:
            if self.text_table:
                self.text_table.delete(filter="*")
                logger.info("✅ 文本表数据已清空")

            if self.image_table:
                self.image_table.delete(filter="*")
                logger.info("✅ 图像表数据已清空")
        except Exception as e:
            logger.error(f"❌ 清空数据失败: {e}")
            raise

    def close(self):
        """Close the client connection. Safe to call more than once; errors are logged."""
        try:
            if self.client:
                self.client.close()
                # Drop the handle so repeated close() calls are no-ops.
                self.client = None
                logger.info("✅ 百度VDB连接已关闭")
        except Exception as e:
            logger.error(f"❌ 关闭连接失败: {e}")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always close; returning None lets any exception propagate.
        self.close()
|
||
|
||
|
||
def test_vdb_backend():
    """Manual end-to-end smoke test against a live Baidu VDB server.

    Stores three random text vectors, runs a top-3 search and prints the table
    statistics. It then deletes only the rows it inserted. Needs network access
    to the server configured by BaiduVDBBackend's defaults. Failures are
    printed with a traceback rather than raised.
    """
    print("=" * 60)
    print("测试百度VDB后端功能")
    print("=" * 60)

    try:
        # The context manager closes the client even if a step fails.
        with BaiduVDBBackend() as vdb:
            text_ids = []
            try:
                test_texts = [
                    "这是一个测试文本",
                    "多模态检索系统",
                    "向量数据库测试",
                ]

                # Random test vectors sized to the backend's configured dimension.
                dim = vdb.vector_dimension
                test_vectors = np.random.rand(len(test_texts), dim).astype(np.float32)

                print("1. 存储文本向量...")
                text_ids = vdb.store_text_vectors(test_texts, test_vectors)
                print(f" 存储成功,ID: {text_ids}")

                print("\n2. 搜索文本向量...")
                query_vector = np.random.rand(dim).astype(np.float32)
                results = vdb.search_text_vectors(query_vector, top_k=3)
                print(f" 搜索结果: {len(results)} 条")
                for i, (doc_id, text, score, meta) in enumerate(results, 1):
                    print(f" {i}. {text[:30]}... (相似度: {score:.4f})")

                print("\n3. 数据库统计信息...")
                stats = vdb.get_statistics()
                print(f" 统计信息: {stats}")
            finally:
                # Delete only the rows this test created. clear_all_data()
                # would wipe every row in both tables, which is destructive
                # against a shared database.
                print("\n4. 清理测试数据...")
                for doc_id in text_ids:
                    vdb.text_table.delete(primary_key={"id": doc_id})
                print(" 清理完成")

        print("\n✅ 百度VDB后端测试完成!")

    except Exception as e:
        print(f"\n❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
|
||
|
||
|
||
if __name__ == "__main__":
    # With no handler configured, Python's logging only prints WARNING and
    # above. That would hide the backend's INFO-level progress messages when
    # this module is run as a script.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    test_vdb_backend()
|