#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Pure Baidu-VDB multimodal retrieval system - complete FAISS replacement.

Supports four retrieval modes: text-to-text, text-to-image,
image-to-text and image-to-image.
"""

import gc
import json
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.nn.parallel import DataParallel, DistributedDataParallel
from transformers import AutoModel, AutoProcessor, AutoTokenizer

from baidu_vdb_production import BaiduVDBProduction

# Module-level logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MultimodalRetrievalVDBOnly:
    """Multimodal retrieval system backed purely by Baidu VDB (replaces FAISS).

    Supports four retrieval modes: text-to-text, text-to-image,
    image-to-text and image-to-image.  This class loads the embedding
    model onto one or more GPUs and computes embeddings; all vector
    storage and similarity search is delegated to ``BaiduVDBProduction``.
    """

    # Output dimensionality of the embedding model; used for the zero-vector
    # fallback returned when encoding fails, keeping text and image encoders
    # consistent.  (Previously the literal 3584 was duplicated in both.)
    EMBEDDING_DIM = 3584

    def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True, gpu_ids: List[int] = None, min_memory_gb=12):
        """Initialise the VDB-only multimodal retrieval system.

        Args:
            model_name: HuggingFace identifier of the embedding model.
            use_all_gpus: when True, use every GPU that passes the memory check.
            gpu_ids: explicit GPU ids to use (still filtered by the memory
                check); takes precedence over ``use_all_gpus`` when given.
            min_memory_gb: minimum available GPU memory (GB) for a GPU to be
                considered usable.

        Raises:
            RuntimeError: if CUDA is unavailable or no suitable GPU is found.
        """
        self.model_name = model_name

        # Select GPUs first -- everything below depends on self.device_ids.
        self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)

        # Free as much GPU memory as possible before loading the model.
        self._clear_all_gpu_memory()

        logger.info(f"正在加载模型到多GPU: {self.device_ids}")

        # Populated by _load_model_multigpu().
        self.model = None
        self.tokenizer = None
        self.processor = None
        self._load_model_multigpu()

        # Baidu VDB backend takes the place of a local FAISS index.
        logger.info("初始化百度VDB后端...")
        self.vdb = BaiduVDBProduction()
        logger.info("✅ 百度VDB后端初始化完成")

        # Serialises model access from the encode_* methods across threads.
        self.model_lock = threading.Lock()

        logger.info("✅ 纯VDB多模态检索系统初始化完成")

    def _setup_devices(self, use_all_gpus, gpu_ids, min_memory_gb):
        """Choose the GPU ids to run on and set the primary CUDA device.

        Raises:
            RuntimeError: if CUDA is unavailable, no GPU meets the memory
                requirement, or the requested ``gpu_ids`` are all unusable.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA不可用,需要GPU支持")

        total_gpus = torch.cuda.device_count()
        logger.info(f"检测到 {total_gpus} 个GPU")

        # Probe each GPU's (approximate) free memory.
        available_gpus = []
        for i in range(total_gpus):
            props = torch.cuda.get_device_properties(i)
            memory_gb = props.total_memory / (1024 ** 3)
            # NOTE(review): memory_reserved() only reflects this process's
            # caching allocator, not other processes' usage, so this is an
            # optimistic estimate of free memory -- confirm acceptable.
            reserved_gb = torch.cuda.memory_reserved(i) / (1024 ** 3)
            available_memory = memory_gb - reserved_gb

            logger.info(f"GPU {i}: {props.name} ({memory_gb:.1f}GB)")

            # Bug fix: available_memory is in GB but was logged as "MB".
            if available_memory >= min_memory_gb:
                available_gpus.append(i)
                logger.info(f"GPU {i}: {available_memory:.0f}GB 可用 (合适)")
            else:
                logger.info(f"GPU {i}: {available_memory:.0f}GB 可用 (不足)")

        if not available_gpus:
            raise RuntimeError(f"没有找到满足 {min_memory_gb}GB 内存要求的GPU")

        # Narrow down to the GPUs actually requested.
        if gpu_ids:
            self.device_ids = [gpu_id for gpu_id in gpu_ids if gpu_id in available_gpus]
        elif use_all_gpus:
            self.device_ids = available_gpus
        else:
            self.device_ids = [available_gpus[0]]

        if not self.device_ids:
            raise RuntimeError("没有可用的GPU设备")

        # First selected GPU acts as the primary device for model inputs.
        self.primary_device = f"cuda:{self.device_ids[0]}"
        torch.cuda.set_device(self.device_ids[0])

        logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}")

    def _clear_all_gpu_memory(self):
        """Release cached CUDA memory on every selected GPU."""
        for device_id in self.device_ids:
            with torch.cuda.device(device_id):
                torch.cuda.empty_cache()
        # Collect Python garbage so released tensors actually free memory.
        gc.collect()
        logger.info("所有GPU内存已清理")

    def _load_model_multigpu(self):
        """Load the model, tokenizer and processor onto the selected GPUs.

        The model is sharded across GPUs via ``device_map="auto"`` and put
        into eval mode.

        Raises:
            Exception: re-raises whatever ``from_pretrained`` fails with,
                after logging it.
        """
        try:
            self._clear_all_gpu_memory()

            logger.info(f"正在加载模型到多GPU: {self.device_ids}")

            # device_map="auto" lets accelerate shard the fp16 weights
            # across all visible GPUs.
            self.model = AutoModel.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                device_map="auto"
            )

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
            logger.info("Tokenizer加载成功")

            self.processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
            logger.info("Processor加载成功")

            # Log a prefix of the layer -> device placement for debugging.
            if hasattr(self.model, 'hf_device_map'):
                logger.info(f"模型已成功加载到设备: {dict(list(self.model.hf_device_map.items())[:10])}")

            self.model.eval()
            logger.info("多GPU模型加载完成")

        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            raise

    def encode_text_batch(self, texts: List[str], batch_size: int = 8) -> np.ndarray:
        """Encode a list of texts into embedding vectors.

        Args:
            texts: input strings.
            batch_size: number of texts encoded per forward pass.

        Returns:
            float array of shape ``(len(texts), EMBEDDING_DIM)``.  On any
            failure an all-zero array is returned (best-effort behaviour).
        """
        try:
            with self.model_lock:
                all_embeddings = []

                for i in range(0, len(texts), batch_size):
                    batch_texts = texts[i:i + batch_size]

                    inputs = self.processor(
                        text=batch_texts,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=512
                    )

                    # Inputs go to the primary device; accelerate handles
                    # cross-device movement inside the sharded model.
                    inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}

                    with torch.no_grad():
                        outputs = self.model(**inputs)
                        # Mean-pool the last hidden state over the sequence.
                        embeddings = outputs.last_hidden_state.mean(dim=1)
                        all_embeddings.append(embeddings.cpu().numpy())

                return np.vstack(all_embeddings)

        except Exception as e:
            logger.error(f"文本编码失败: {e}")
            # Best-effort fallback: zero vectors of the model's dimension.
            return np.zeros((len(texts), self.EMBEDDING_DIM), dtype=np.float32)

    def encode_image_batch(self, images: List[Union[str, Image.Image]], batch_size: int = 4) -> np.ndarray:
        """Encode images (file paths or PIL images) into embedding vectors.

        Unreadable or unsupported entries are replaced with a blank white
        224x224 placeholder so the output row count always equals
        ``len(images)``.

        Args:
            images: file paths and/or PIL image objects.
            batch_size: number of images encoded per forward pass.

        Returns:
            float array of shape ``(len(images), EMBEDDING_DIM)``; all-zero
            on failure.
        """
        try:
            with self.model_lock:
                processed_images = [self._to_rgb_image(img) for img in images]

                all_embeddings = []

                for i in range(0, len(processed_images), batch_size):
                    batch_images = processed_images[i:i + batch_size]

                    inputs = self.processor(
                        images=batch_images,
                        return_tensors="pt",
                        padding=True
                    )

                    inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}

                    with torch.no_grad():
                        outputs = self.model(**inputs)
                        # Mean-pool the last hidden state, as for text.
                        embeddings = outputs.last_hidden_state.mean(dim=1)
                        all_embeddings.append(embeddings.cpu().numpy())

                return np.vstack(all_embeddings)

        except Exception as e:
            logger.error(f"图像编码失败: {e}")
            return np.zeros((len(images), self.EMBEDDING_DIM), dtype=np.float32)

    def _to_rgb_image(self, img: Union[str, Image.Image]) -> Image.Image:
        """Normalise one image input to an RGB PIL image.

        Missing files and unsupported types yield a white 224x224
        placeholder (with a warning) instead of raising.
        """
        if isinstance(img, str):
            if os.path.exists(img):
                return Image.open(img).convert('RGB')
            logger.warning(f"图像文件不存在: {img}")
            return Image.new('RGB', (224, 224), color='white')
        if isinstance(img, Image.Image):
            return img.convert('RGB')
        logger.warning(f"不支持的图像类型: {type(img)}")
        return Image.new('RGB', (224, 224), color='white')

    def build_text_index_parallel(self, texts: List[str], save_path: str = None):
        """Encode ``texts`` and store them in the VDB text index.

        Args:
            texts: documents to index.
            save_path: unused; kept for interface compatibility with the
                previous FAISS-based implementation.

        Raises:
            Exception: re-raised after logging if indexing fails.
        """
        try:
            logger.info(f"正在构建文本索引,共 {len(texts)} 条文本")

            embeddings = self.encode_text_batch(texts)

            # VDB stores both the raw texts and their vectors.
            self.vdb.build_text_index(texts, embeddings)

            logger.info("文本索引构建完成")

        except Exception as e:
            logger.error(f"构建文本索引失败: {e}")
            raise

    def build_image_index_parallel(self, image_paths: List[str], save_path: str = None):
        """Encode the images at ``image_paths`` and store them in the VDB
        image index.

        Args:
            image_paths: paths of the images to index.
            save_path: unused; kept for interface compatibility with the
                previous FAISS-based implementation.

        Raises:
            Exception: re-raised after logging if indexing fails.
        """
        try:
            logger.info(f"正在构建图像索引,共 {len(image_paths)} 张图像")

            embeddings = self.encode_image_batch(image_paths)

            # VDB stores both the image paths and their vectors.
            self.vdb.build_image_index(image_paths, embeddings)

            logger.info("图像索引构建完成")

        except Exception as e:
            logger.error(f"构建图像索引失败: {e}")
            raise

    def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-text search: find the texts most similar to ``query``.

        Returns:
            list of ``(text, score)`` pairs; empty list on error.
        """
        try:
            query_embedding = self.encode_text_batch([query])
            return self.vdb.search_text_by_text(query_embedding[0], top_k)
        except Exception as e:
            logger.error(f"文搜文失败: {e}")
            return []

    def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-image search: find the images most similar to ``query``.

        Returns:
            list of ``(image_path, score)`` pairs; empty list on error.
        """
        try:
            query_embedding = self.encode_text_batch([query])
            return self.vdb.search_images_by_text(query_embedding[0], top_k)
        except Exception as e:
            logger.error(f"文搜图失败: {e}")
            return []

    def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-image search: find the images most similar to
        ``query_image`` (a path or PIL image).

        Returns:
            list of ``(image_path, score)`` pairs; empty list on error.
        """
        try:
            query_embedding = self.encode_image_batch([query_image])
            return self.vdb.search_images_by_image(query_embedding[0], top_k)
        except Exception as e:
            logger.error(f"图搜图失败: {e}")
            return []

    def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-text search: find the texts most similar to
        ``query_image`` (a path or PIL image).

        Returns:
            list of ``(text, score)`` pairs; empty list on error.
        """
        try:
            query_embedding = self.encode_image_batch([query_image])
            return self.vdb.search_text_by_image(query_embedding[0], top_k)
        except Exception as e:
            logger.error(f"图搜文失败: {e}")
            return []

    # --- Web-application compatibility aliases -------------------------------

    def search_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-text search (web-app compatible alias)."""
        return self.search_text_by_text(query, top_k)

    def search_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-image search (web-app compatible alias)."""
        return self.search_images_by_image(query_image, top_k)

    def search_images_by_text_query(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-image search (web-app compatible alias)."""
        return self.search_images_by_text(query, top_k)

    def search_texts_by_image_query(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-text search (web-app compatible alias)."""
        return self.search_text_by_image(query_image, top_k)

    def get_statistics(self) -> Dict[str, Any]:
        """Return combined system and VDB backend statistics.

        Returns:
            dict with the model/device configuration merged with the VDB
            backend's statistics, or ``{"status": "error", "error": ...}``
            on failure.
        """
        try:
            vdb_stats = self.vdb.get_statistics()

            stats = {
                "model_name": self.model_name,
                "device_ids": self.device_ids,
                "primary_device": self.primary_device,
                "backend": "Baidu VDB (No FAISS)",
                **vdb_stats
            }

            return stats

        except Exception as e:
            logger.error(f"获取统计信息失败: {e}")
            return {"status": "error", "error": str(e)}

    def clear_all_data(self):
        """Delete all indexed data from the VDB backend (best-effort;
        failures are logged, not raised)."""
        try:
            self.vdb.clear_all_data()
            logger.info("✅ 所有数据已清空")
        except Exception as e:
            logger.error(f"❌ 清空数据失败: {e}")

    def get_gpu_memory_info(self):
        """Report memory usage for every selected GPU.

        Returns:
            dict keyed by ``"GPU_<id>"`` with allocated/reserved/total/free
            memory in GB, rounded to two decimals.
        """
        memory_info = {}
        for device_id in self.device_ids:
            with torch.cuda.device(device_id):
                # No-arg queries refer to the device selected by the
                # surrounding context manager.
                allocated = torch.cuda.memory_allocated() / (1024 ** 3)
                reserved = torch.cuda.memory_reserved() / (1024 ** 3)
                total = torch.cuda.get_device_properties(device_id).total_memory / (1024 ** 3)

                memory_info[f"GPU_{device_id}"] = {
                    "allocated_GB": round(allocated, 2),
                    "reserved_GB": round(reserved, 2),
                    "total_GB": round(total, 2),
                    "free_GB": round(total - reserved, 2)
                }

        return memory_info

    def cleanup(self):
        """Close the VDB connection and release GPU memory.

        Safe to call even after a partially failed ``__init__`` (attributes
        may not have been assigned yet); failures are logged, not raised.
        """
        try:
            # getattr: __init__ may have raised before self.vdb was assigned.
            vdb = getattr(self, "vdb", None)
            if vdb:
                vdb.close()

            if getattr(self, "device_ids", None):
                self._clear_all_gpu_memory()
            logger.info("✅ 资源清理完成")
        except Exception as e:
            logger.error(f"❌ 资源清理失败: {e}")
def test_vdb_only_system():
    """End-to-end smoke test of the VDB-only multimodal retrieval system."""
    banner = "=" * 60
    print(banner)
    print("测试纯百度VDB多模态检索系统")
    print(banner)

    system = None

    try:
        # Step 1: bring the system up (loads the model, connects to VDB).
        print("1. 初始化纯VDB多模态检索系统...")
        system = MultimodalRetrievalVDBOnly()
        print("✅ 系统初始化成功")

        # Step 2: index a small corpus of sample sentences.
        print("\n2. 构建文本索引...")
        test_texts = [
            "人工智能技术的发展趋势",
            "机器学习在医疗领域的应用",
            "深度学习算法优化方法",
            "计算机视觉技术创新",
            "自然语言处理最新进展",
        ]

        system.build_text_index_parallel(test_texts)
        print("✅ 文本索引构建完成")

        # Step 3: run one text-to-text query against the fresh index.
        print("\n3. 测试文搜文...")
        query = "AI技术"
        results = system.search_text_by_text(query, top_k=3)
        print(f"查询: {query}")
        for rank, (hit_text, score) in enumerate(results, 1):
            print(f" {rank}. {hit_text} (相似度: {score:.3f})")

        # Step 4: dump backend statistics.
        print("\n4. 获取统计信息...")
        stats = system.get_statistics()
        print("系统统计:")
        for key, value in stats.items():
            print(f" {key}: {value}")

        print(f"\n🎉 纯VDB系统测试完成!")
        print("✅ 完全移除FAISS依赖")
        print("✅ 使用百度VDB作为向量数据库")
        print("✅ 支持多模态检索功能")

        return True

    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        # Always release model/VDB resources, pass or fail.
        if system:
            system.cleanup()
# Script entry point: run the smoke test when executed directly.
if __name__ == "__main__":
    test_vdb_only_system()