# Source provenance (from web viewer): mmeb/multimodal_retrieval_vdb_only.py,
# 2025-09-01, 444 lines, 17 KiB. The viewer's "ambiguous Unicode" warning refers
# to the Chinese text and emoji used in log/print messages below; intentional.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
纯百度VDB多模态检索系统 - 完全替代FAISS
支持文搜文、文搜图、图搜文、图搜图四种检索模式
"""
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from typing import List, Union, Tuple, Dict, Any
import os
import json
from pathlib import Path
import logging
import gc
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from baidu_vdb_production import BaiduVDBProduction
# Configure module-level logging at INFO; all classes below share this logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MultimodalRetrievalVDBOnly:
    """Multimodal retrieval system backed entirely by Baidu VDB (FAISS fully replaced).

    Supports four retrieval modes: text->text, text->image, image->text and
    image->image. A single embedding model (loaded across multiple GPUs via
    ``device_map="auto"``) encodes both modalities; the Baidu VDB backend
    stores and searches the vectors.
    """

    # Fallback embedding dimensionality used when encoding fails or the input
    # is empty. NOTE(review): assumes the hidden size of
    # Ops-MM-embedding-v1-7B is 3584 — confirm against the model config.
    _FALLBACK_EMBED_DIM = 3584

    def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True, gpu_ids: List[int] = None,
                 min_memory_gb: float = 12):
        """Initialize the VDB-only multimodal retrieval system.

        Args:
            model_name: Hugging Face model identifier to load.
            use_all_gpus: Use every GPU that passes the memory check.
            gpu_ids: Explicit GPU ids to use (filtered against available GPUs).
            min_memory_gb: Minimum available GPU memory (GB) required per device.

        Raises:
            RuntimeError: If CUDA is unavailable or no GPU meets the memory bar.
        """
        self.model_name = model_name
        # Select GPU devices; sets self.device_ids / self.primary_device.
        self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
        # Free cached GPU memory before loading the (large) model.
        self._clear_all_gpu_memory()
        logger.info(f"正在加载模型到多GPU: {self.device_ids}")
        # Populated by _load_model_multigpu().
        self.model = None
        self.tokenizer = None
        self.processor = None
        self._load_model_multigpu()
        # Baidu VDB backend replaces the former FAISS index.
        logger.info("初始化百度VDB后端...")
        self.vdb = BaiduVDBProduction()
        logger.info("✅ 百度VDB后端初始化完成")
        # Serializes model access across threads (encoders take this lock).
        self.model_lock = threading.Lock()
        logger.info("✅ 纯VDB多模态检索系统初始化完成")

    def _setup_devices(self, use_all_gpus, gpu_ids, min_memory_gb):
        """Choose GPU ids to use and set the primary CUDA device.

        Raises:
            RuntimeError: If CUDA is unavailable, no GPU has enough memory,
                or the requested gpu_ids filter leaves nothing usable.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA不可用需要GPU支持")
        total_gpus = torch.cuda.device_count()
        logger.info(f"检测到 {total_gpus} 个GPU")
        # Probe each GPU's headroom.
        available_gpus = []
        for i in range(total_gpus):
            memory_gb = torch.cuda.get_device_properties(i).total_memory / (1024**3)
            # NOTE(review): memory_reserved() only reflects this process's
            # caching allocator, not other processes — a rough estimate only.
            free_memory = torch.cuda.memory_reserved(i) / (1024**3)
            available_memory = memory_gb - free_memory
            logger.info(f"GPU {i}: {torch.cuda.get_device_properties(i).name} ({memory_gb:.1f}GB)")
            if available_memory >= min_memory_gb:
                available_gpus.append(i)
                # BUGFIX: value is in GB; the message previously said "MB".
                logger.info(f"GPU {i}: {available_memory:.0f}GB 可用 (合适)")
            else:
                logger.info(f"GPU {i}: {available_memory:.0f}GB 可用 (不足)")
        if not available_gpus:
            raise RuntimeError(f"没有找到满足 {min_memory_gb}GB 内存要求的GPU")
        # Pick the working set: explicit ids > all available > first available.
        if gpu_ids:
            self.device_ids = [gpu_id for gpu_id in gpu_ids if gpu_id in available_gpus]
        elif use_all_gpus:
            self.device_ids = available_gpus
        else:
            self.device_ids = [available_gpus[0]]
        if not self.device_ids:
            raise RuntimeError("没有可用的GPU设备")
        # First selected GPU hosts the inputs before dispatch.
        self.primary_device = f"cuda:{self.device_ids[0]}"
        torch.cuda.set_device(self.device_ids[0])
        logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}")

    def _clear_all_gpu_memory(self):
        """Release cached CUDA memory on every selected GPU and run the GC."""
        for device_id in self.device_ids:
            with torch.cuda.device(device_id):
                torch.cuda.empty_cache()
        gc.collect()
        logger.info("所有GPU内存已清理")

    def _load_model_multigpu(self):
        """Load model, tokenizer and processor; shards the model via device_map.

        Raises:
            Exception: Re-raises any loading failure after logging it.
        """
        try:
            self._clear_all_gpu_memory()
            logger.info(f"正在加载模型到多GPU: {self.device_ids}")
            # fp16 + automatic device placement across the selected GPUs.
            self.model = AutoModel.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
            logger.info("Tokenizer加载成功")
            self.processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
            logger.info("Processor加载成功")
            # Log a sample of the layer->device mapping for diagnostics.
            if hasattr(self.model, 'hf_device_map'):
                logger.info(f"模型已成功加载到设备: {dict(list(self.model.hf_device_map.items())[:10])}")
            self.model.eval()
            logger.info("多GPU模型加载完成")
        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            raise

    def encode_text_batch(self, texts: List[str], batch_size: int = 8) -> np.ndarray:
        """Encode texts into mean-pooled embeddings.

        Args:
            texts: Input strings.
            batch_size: Number of texts per forward pass.

        Returns:
            Array of shape (len(texts), dim). On failure (or empty input) a
            zero array of shape (len(texts), _FALLBACK_EMBED_DIM) is returned.
        """
        # Robustness: avoid np.vstack([]) raising on an empty request.
        if not texts:
            return np.zeros((0, self._FALLBACK_EMBED_DIM), dtype=np.float32)
        try:
            with self.model_lock:
                all_embeddings = []
                for i in range(0, len(texts), batch_size):
                    batch_texts = texts[i:i + batch_size]
                    inputs = self.processor(
                        text=batch_texts,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=512
                    )
                    # Inputs go to the primary device; device_map dispatches.
                    inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
                    with torch.no_grad():
                        outputs = self.model(**inputs)
                        # Mean-pool over the sequence axis.
                        embeddings = outputs.last_hidden_state.mean(dim=1)
                        all_embeddings.append(embeddings.cpu().numpy())
                return np.vstack(all_embeddings)
        except Exception as e:
            logger.error(f"文本编码失败: {e}")
            return np.zeros((len(texts), self._FALLBACK_EMBED_DIM), dtype=np.float32)

    def encode_image_batch(self, images: List[Union[str, Image.Image]], batch_size: int = 4) -> np.ndarray:
        """Encode images (paths or PIL images) into mean-pooled embeddings.

        Missing files and unsupported types are replaced with a white 224x224
        placeholder so output rows stay aligned with the input list.

        Returns:
            Array of shape (len(images), dim); zeros fallback on failure.
        """
        if not images:
            return np.zeros((0, self._FALLBACK_EMBED_DIM), dtype=np.float32)
        try:
            with self.model_lock:
                # Normalize every input to an RGB PIL image.
                processed_images = []
                for img in images:
                    if isinstance(img, str):
                        if os.path.exists(img):
                            processed_images.append(Image.open(img).convert('RGB'))
                        else:
                            logger.warning(f"图像文件不存在: {img}")
                            processed_images.append(Image.new('RGB', (224, 224), color='white'))
                    elif isinstance(img, Image.Image):
                        processed_images.append(img.convert('RGB'))
                    else:
                        logger.warning(f"不支持的图像类型: {type(img)}")
                        processed_images.append(Image.new('RGB', (224, 224), color='white'))
                all_embeddings = []
                for i in range(0, len(processed_images), batch_size):
                    batch_images = processed_images[i:i + batch_size]
                    inputs = self.processor(
                        images=batch_images,
                        return_tensors="pt",
                        padding=True
                    )
                    inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
                    with torch.no_grad():
                        outputs = self.model(**inputs)
                        embeddings = outputs.last_hidden_state.mean(dim=1)
                        all_embeddings.append(embeddings.cpu().numpy())
                return np.vstack(all_embeddings)
        except Exception as e:
            logger.error(f"图像编码失败: {e}")
            return np.zeros((len(images), self._FALLBACK_EMBED_DIM), dtype=np.float32)

    def build_text_index_parallel(self, texts: List[str], save_path: str = None):
        """Encode texts and store them in the VDB text index.

        Args:
            texts: Documents to index.
            save_path: Unused; kept for interface compatibility with the
                former FAISS implementation.
        """
        try:
            logger.info(f"正在构建文本索引,共 {len(texts)} 条文本")
            embeddings = self.encode_text_batch(texts)
            self.vdb.build_text_index(texts, embeddings)
            logger.info("文本索引构建完成")
        except Exception as e:
            logger.error(f"构建文本索引失败: {e}")
            raise

    def build_image_index_parallel(self, image_paths: List[str], save_path: str = None):
        """Encode images and store them in the VDB image index.

        Args:
            image_paths: Paths of images to index.
            save_path: Unused; kept for interface compatibility with the
                former FAISS implementation.
        """
        try:
            logger.info(f"正在构建图像索引,共 {len(image_paths)} 张图像")
            embeddings = self.encode_image_batch(image_paths)
            self.vdb.build_image_index(image_paths, embeddings)
            logger.info("图像索引构建完成")
        except Exception as e:
            logger.error(f"构建图像索引失败: {e}")
            raise

    def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text->text: return up to top_k (text, score) pairs; [] on failure."""
        try:
            query_embedding = self.encode_text_batch([query])
            return self.vdb.search_text_by_text(query_embedding[0], top_k)
        except Exception as e:
            logger.error(f"文搜文失败: {e}")
            return []

    def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text->image: return up to top_k (image_path, score) pairs; [] on failure."""
        try:
            query_embedding = self.encode_text_batch([query])
            return self.vdb.search_images_by_text(query_embedding[0], top_k)
        except Exception as e:
            logger.error(f"文搜图失败: {e}")
            return []

    def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image->image: return up to top_k (image_path, score) pairs; [] on failure."""
        try:
            query_embedding = self.encode_image_batch([query_image])
            return self.vdb.search_images_by_image(query_embedding[0], top_k)
        except Exception as e:
            logger.error(f"图搜图失败: {e}")
            return []

    def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image->text: return up to top_k (text, score) pairs; [] on failure."""
        try:
            query_embedding = self.encode_image_batch([query_image])
            return self.vdb.search_text_by_image(query_embedding[0], top_k)
        except Exception as e:
            logger.error(f"图搜文失败: {e}")
            return []

    # --- Web-app compatibility aliases -----------------------------------

    def search_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Web-app alias for search_text_by_text."""
        return self.search_text_by_text(query, top_k)

    def search_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Web-app alias for search_images_by_image."""
        return self.search_images_by_image(query_image, top_k)

    def search_images_by_text_query(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Web-app alias for search_images_by_text."""
        return self.search_images_by_text(query, top_k)

    def search_texts_by_image_query(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Web-app alias for search_text_by_image."""
        return self.search_text_by_image(query_image, top_k)

    def get_statistics(self) -> Dict[str, Any]:
        """Return system stats merged with the VDB backend's own statistics."""
        try:
            vdb_stats = self.vdb.get_statistics()
            stats = {
                "model_name": self.model_name,
                "device_ids": self.device_ids,
                "primary_device": self.primary_device,
                "backend": "Baidu VDB (No FAISS)",
                **vdb_stats
            }
            return stats
        except Exception as e:
            logger.error(f"获取统计信息失败: {e}")
            return {"status": "error", "error": str(e)}

    def clear_all_data(self):
        """Delete all indexed data from the VDB backend (best-effort)."""
        try:
            self.vdb.clear_all_data()
            logger.info("✅ 所有数据已清空")
        except Exception as e:
            logger.error(f"❌ 清空数据失败: {e}")

    def get_gpu_memory_info(self):
        """Return per-GPU memory usage (allocated/reserved/total/free, in GB)."""
        memory_info = {}
        for device_id in self.device_ids:
            with torch.cuda.device(device_id):
                allocated = torch.cuda.memory_allocated() / (1024**3)
                reserved = torch.cuda.memory_reserved() / (1024**3)
                total = torch.cuda.get_device_properties(device_id).total_memory / (1024**3)
                memory_info[f"GPU_{device_id}"] = {
                    "allocated_GB": round(allocated, 2),
                    "reserved_GB": round(reserved, 2),
                    "total_GB": round(total, 2),
                    "free_GB": round(total - reserved, 2)
                }
        return memory_info

    def cleanup(self):
        """Close the VDB connection and free GPU memory (best-effort)."""
        try:
            if self.vdb:
                self.vdb.close()
            self._clear_all_gpu_memory()
            logger.info("✅ 资源清理完成")
        except Exception as e:
            logger.error(f"❌ 资源清理失败: {e}")
def test_vdb_only_system():
    """End-to-end smoke test of the VDB-only multimodal retrieval system.

    Returns True on success, False if any step raised; always cleans up.
    """
    print("=" * 60)
    print("测试纯百度VDB多模态检索系统")
    print("=" * 60)
    retriever = None
    try:
        # Step 1: bring the system up (loads model, connects to VDB).
        print("1. 初始化纯VDB多模态检索系统...")
        retriever = MultimodalRetrievalVDBOnly()
        print("✅ 系统初始化成功")

        # Step 2: index a small text corpus.
        print("\n2. 构建文本索引...")
        corpus = [
            "人工智能技术的发展趋势",
            "机器学习在医疗领域的应用",
            "深度学习算法优化方法",
            "计算机视觉技术创新",
            "自然语言处理最新进展",
        ]
        retriever.build_text_index_parallel(corpus)
        print("✅ 文本索引构建完成")

        # Step 3: text-to-text search.
        print("\n3. 测试文搜文...")
        query = "AI技术"
        hits = retriever.search_text_by_text(query, top_k=3)
        print(f"查询: {query}")
        for rank, (text, score) in enumerate(hits, 1):
            print(f" {rank}. {text} (相似度: {score:.3f})")

        # Step 4: dump system statistics.
        print("\n4. 获取统计信息...")
        report = retriever.get_statistics()
        print("系统统计:")
        for key, value in report.items():
            print(f" {key}: {value}")

        print(f"\n🎉 纯VDB系统测试完成")
        print("✅ 完全移除FAISS依赖")
        print("✅ 使用百度VDB作为向量数据库")
        print("✅ 支持多模态检索功能")
        return True
    except Exception as exc:
        print(f"❌ 测试失败: {exc}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        if retriever:
            retriever.cleanup()
# Script entry point: run the end-to-end smoke test.
if __name__ == "__main__":
    test_vdb_only_system()