mmeb/multimodal_retrieval_vdb.py
2025-09-01 11:24:01 +00:00

497 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Multimodal retrieval system integrated with Baidu VDB (vector database).

Supports four retrieval modes: text-to-text, text-to-image,
image-to-text, and image-to-image search.
"""
import torch
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from typing import List, Union, Tuple, Dict, Any
import os
import json
import logging
import gc
from baidu_vdb_backend import BaiduVDBBackend
# Configure module-level logging; every class below logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MultimodalRetrievalVDB:
    """Multimodal retrieval system backed by Baidu VDB.

    Wraps a multimodal embedding model plus a Baidu VDB backend and
    exposes text-to-text, text-to-image, image-to-text and
    image-to-image retrieval.
    """

    def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True, gpu_ids: List[int] = None,
                 vdb_config: Dict[str, str] = None):
        """Initialize the multimodal retrieval system.

        Args:
            model_name: Hugging Face identifier of the embedding model.
            use_all_gpus: Use every visible GPU when True.
            gpu_ids: Explicit GPU id list (used when use_all_gpus is False).
            vdb_config: Keyword arguments for BaiduVDBBackend. When None,
                values are read from the BAIDU_VDB_* environment variables,
                falling back to the historical defaults.

        Raises:
            RuntimeError: If CUDA is unavailable or the model fails to load.
        """
        self.model_name = model_name
        # Select GPU devices and free any cached memory before loading weights.
        self._setup_devices(use_all_gpus, gpu_ids)
        self._clear_gpu_memory()
        logger.info(f"正在加载模型到GPU: {self.device_ids}")
        # Model/tokenizer/processor are populated by _load_model().
        self.model = None
        self.tokenizer = None
        self.processor = None
        # BUGFIX: _load_model() returns False on failure; previously the
        # result was ignored and encode calls later crashed on self.model
        # being None. Fail fast here instead.
        if not self._load_model():
            raise RuntimeError(f"模型加载失败: {self.model_name}")
        # Initialize the Baidu VDB backend.
        # SECURITY: credentials should not be hardcoded in source. Prefer the
        # environment; the literal fallbacks keep existing deployments working.
        if vdb_config is None:
            vdb_config = {
                "account": os.environ.get("BAIDU_VDB_ACCOUNT", "root"),
                "api_key": os.environ.get("BAIDU_VDB_API_KEY", "vdb$yjr9ln3n0td"),
                "endpoint": os.environ.get("BAIDU_VDB_ENDPOINT", "http://180.76.96.191:5287"),
                "database_name": os.environ.get("BAIDU_VDB_DATABASE", "multimodal_retrieval"),
            }
        self.vdb = BaiduVDBBackend(**vdb_config)
        logger.info("多模态检索系统初始化完成")
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int]):
    """Choose the CUDA devices this instance will use.

    Sets ``self.device_ids``, ``self.num_gpus`` and ``self.primary_device``.

    Raises:
        RuntimeError: When CUDA is not available at all.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA不可用无法使用GPU")
    total_gpus = torch.cuda.device_count()
    logger.info(f"检测到 {total_gpus} 个GPU")
    # Precedence: all GPUs > caller-specified list > first GPU only.
    if use_all_gpus:
        selected = list(range(total_gpus))
    else:
        selected = gpu_ids if gpu_ids else [0]
    self.device_ids = selected
    self.num_gpus = len(selected)
    self.primary_device = f"cuda:{self.device_ids[0]}"
    logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}")
def _clear_gpu_memory(self):
    """Free cached CUDA memory on every configured GPU, then run the GC."""
    for device in self.device_ids:
        # Make *device* current so the cache/sync calls target it.
        torch.cuda.set_device(device)
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()
    logger.info("GPU内存已清理")
def _load_model(self):
    """Load the embedding model, tokenizer and processor.

    Uses a sharded multi-GPU device map when more than one GPU is
    configured, otherwise loads everything on the primary device.

    Returns:
        True on success, False on failure (the error is logged).
    """
    try:
        # Reduce CUDA allocator fragmentation for large model weights.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
        self._clear_gpu_memory()
        if self.num_gpus > 1:
            # Multi-GPU: shard automatically, capping per-device memory.
            per_device_cap = {gpu: "18GiB" for gpu in self.device_ids}
            self.model = AutoModel.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                device_map="auto",
                max_memory=per_device_cap,
                low_cpu_mem_usage=True,
            )
        else:
            # Single GPU: place the whole model on the primary device.
            self.model = AutoModel.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                device_map=self.primary_device,
            )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True,
        )
        # The processor is optional; fall back to the tokenizer if absent.
        try:
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True,
            )
        except Exception as e:
            logger.warning(f"Processor加载失败使用tokenizer: {e}")
            self.processor = self.tokenizer
        logger.info("模型加载完成")
        return True
    except Exception as e:
        logger.error(f"模型加载失败: {str(e)}")
        return False
def encode_text_batch(self, texts: List[str]) -> np.ndarray:
    """Encode a batch of texts into float32 embedding vectors.

    Args:
        texts: Input strings to embed.

    Returns:
        Array of shape (len(texts), hidden_dim); an empty array when
        *texts* is empty.
    """
    if not texts:
        return np.array([])
    with torch.no_grad():
        # Tokenize the whole batch at once.
        batch = self.tokenizer(
            text=texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        # Run the forward pass on the primary device.
        batch = {key: tensor.to(self.primary_device) for key, tensor in batch.items()}
        model_out = self.model(**batch)
        # NOTE(review): mean over the sequence axis includes padding tokens;
        # attention-mask weighted pooling may be intended — confirm before changing.
        pooled = model_out.last_hidden_state.mean(dim=1)
        # Drop intermediates before returning to keep GPU memory low.
        del batch, model_out
        torch.cuda.empty_cache()
        return pooled.cpu().numpy().astype(np.float32)
def encode_image_batch(self, images: List[Union[str, Image.Image]]) -> np.ndarray:
    """Encode a batch of images into float32 embedding vectors.

    Args:
        images: File paths and/or PIL images, mixed freely.

    Returns:
        Array of shape (len(images), dim). When encoding fails, zero
        vectors of the model's 3584-dim hidden size are returned instead.
    """
    if not images:
        return np.array([])
    # Normalize every input to an RGB PIL image.
    processed_images = []
    for item in images:
        if isinstance(item, str):
            item = Image.open(item).convert('RGB')
        elif isinstance(item, Image.Image):
            item = item.convert('RGB')
        processed_images.append(item)
    try:
        logger.info(f"处理 {len(processed_images)} 张图像")
        # Wrap each image in a one-turn chat prompt for the multimodal model.
        conversations = [
            [{
                "role": "user",
                "content": [
                    {"type": "image", "image": pil_img},
                    {"type": "text", "text": "What is in this image?"},
                ],
            }]
            for pil_img in processed_images
        ]
        try:
            # Render chat templates, then batch text + images together.
            texts = [
                self.processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
                for conv in conversations
            ]
            inputs = self.processor(
                text=texts,
                images=processed_images,
                return_tensors="pt",
                padding=True
            )
            inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                embeddings = outputs.last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy().astype(np.float32)
        except Exception as inner_e:
            logger.warning(f"多模态模型图像编码失败: {inner_e}")
            # Fallback: zero vectors sized to the model's hidden dimension.
            embeddings = np.zeros((len(processed_images), 3584), dtype=np.float32)
        logger.info(f"生成图像embeddings: {embeddings.shape}")
        return embeddings
    except Exception as e:
        logger.error(f"图像编码失败: {e}")
        # Outer fallback mirrors the inner one so callers always get vectors.
        embeddings = np.zeros((len(processed_images), 3584), dtype=np.float32)
        return embeddings
def store_texts(self, texts: List[str], metadata: List[Dict] = None) -> List[str]:
    """Embed *texts* in batches and persist them to the VDB.

    Args:
        texts: Texts to store.
        metadata: Optional per-text metadata, parallel to *texts*.

    Returns:
        IDs assigned by the backend; failed batches are logged and skipped.
    """
    logger.info(f"正在存储 {len(texts)} 条文本数据")
    batch_size = 16  # text encoding tolerates larger batches than images
    all_ids = []
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        batch_metadata = metadata[i:i + batch_size] if metadata else None
        try:
            embeds = self.encode_text_batch(batch_texts)
            stored_ids = self.vdb.store_text_vectors(batch_texts, embeds, batch_metadata)
            all_ids.extend(stored_ids)
            logger.info(f"已处理 {i + len(batch_texts)}/{len(texts)} 条文本")
        except Exception as e:
            # Best-effort: log the failure and move on to the next batch.
            logger.error(f"处理文本批次时出错: {e}")
            continue
    logger.info(f"✅ 文本存储完成,共 {len(all_ids)}")
    return all_ids
def store_images(self, image_paths: List[str], metadata: List[Dict] = None) -> List[str]:
    """Embed images in batches and persist them to the VDB.

    Args:
        image_paths: Paths of the images to store.
        metadata: Optional per-image metadata, parallel to *image_paths*.

    Returns:
        IDs assigned by the backend; failed batches are logged and skipped.
    """
    logger.info(f"正在存储 {len(image_paths)} 张图像数据")
    batch_size = 8  # images are heavier to encode, so use smaller batches
    all_ids = []
    for i in range(0, len(image_paths), batch_size):
        batch_images = image_paths[i:i + batch_size]
        batch_metadata = metadata[i:i + batch_size] if metadata else None
        try:
            embeds = self.encode_image_batch(batch_images)
            stored_ids = self.vdb.store_image_vectors(batch_images, embeds, batch_metadata)
            all_ids.extend(stored_ids)
            logger.info(f"已处理 {i + len(batch_images)}/{len(image_paths)} 张图像")
        except Exception as e:
            # Best-effort: log the failure and move on to the next batch.
            logger.error(f"处理图像批次时出错: {e}")
            continue
    logger.info(f"✅ 图像存储完成,共 {len(all_ids)}")
    return all_ids
def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
    """Text-to-text retrieval: find stored texts most similar to *query*.

    Returns:
        List of (text_content, score) pairs.
    """
    logger.info(f"执行文搜文查询: {query}")
    # Embed the query, then vector-search the text collection.
    query_vector = self.encode_text_batch([query])[0]
    hits = self.vdb.search_text_vectors(query_vector, top_k)
    # Each hit is (doc_id, text_content, score, metadata); keep text + score.
    return [(text_content, score) for _doc_id, text_content, score, _metadata in hits]
def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
    """Text-to-image retrieval: find stored images matching *query*.

    Returns:
        List of (image_path, score) pairs.
    """
    logger.info(f"执行文搜图查询: {query}")
    # Embed the query text, then vector-search the image collection.
    query_vector = self.encode_text_batch([query])[0]
    hits = self.vdb.search_image_vectors(query_vector, top_k)
    # Each hit is (doc_id, image_path, image_name, score, metadata).
    return [(image_path, score) for _doc_id, image_path, _image_name, score, _metadata in hits]
def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
    """Image-to-text retrieval: find stored texts similar to *query_image*.

    Returns:
        List of (text_content, score) pairs.
    """
    logger.info(f"执行图搜文查询")
    # Embed the query image, then vector-search the text collection.
    query_vector = self.encode_image_batch([query_image])[0]
    hits = self.vdb.search_text_vectors(query_vector, top_k)
    # Each hit is (doc_id, text_content, score, metadata).
    return [(text_content, score) for _doc_id, text_content, score, _metadata in hits]
def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
    """Image-to-image retrieval: find stored images similar to *query_image*.

    Returns:
        List of (image_path, score) pairs.
    """
    logger.info(f"执行图搜图查询")
    # Embed the query image, then vector-search the image collection.
    query_vector = self.encode_image_batch([query_image])[0]
    hits = self.vdb.search_image_vectors(query_vector, top_k)
    # Each hit is (doc_id, image_path, image_name, score, metadata).
    return [(image_path, score) for _doc_id, image_path, _image_name, score, _metadata in hits]
# Alias methods kept for compatibility with the web application's naming.
def search_text_to_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
    """Text-to-text search (web-app compatible alias)."""
    return self.search_text_by_text(query, top_k)

def search_text_to_image(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
    """Text-to-image search (web-app compatible alias)."""
    return self.search_images_by_text(query, top_k)

def search_image_to_text(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
    """Image-to-text search (web-app compatible alias)."""
    return self.search_text_by_image(query_image, top_k)

def search_image_to_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
    """Image-to-image search (web-app compatible alias)."""
    return self.search_images_by_image(query_image, top_k)
def get_statistics(self) -> Dict[str, Any]:
    """Return storage statistics as reported by the VDB backend."""
    stats = self.vdb.get_statistics()
    return stats
def clear_all_data(self):
    """Delete every stored text and image vector from the backend."""
    self.vdb.clear_all_data()

def close(self):
    """Release the VDB connection and free cached GPU memory."""
    if self.vdb:
        self.vdb.close()
    self._clear_gpu_memory()

def __enter__(self):
    """Allow use as a context manager: returns the system itself."""
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Close the system when the context exits."""
    self.close()
def check_system_info():
    """Print a summary of the CUDA environment (GPU count, versions, memory)."""
    print("=== 多模态检索系统信息 ===")
    if not torch.cuda.is_available():
        print("❌ CUDA不可用")
        return
    gpu_count = torch.cuda.device_count()
    # Collect all report lines first, then emit them in one call.
    lines = [
        f"✅ 检测到 {gpu_count} 个GPU",
        f"CUDA版本: {torch.version.cuda}",
        f"PyTorch版本: {torch.__version__}",
    ]
    for i in range(gpu_count):
        gpu_name = torch.cuda.get_device_name(i)
        gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
        lines.append(f"GPU {i}: {gpu_name} ({gpu_memory:.1f}GB)")
    lines.append("========================")
    print("\n".join(lines))
if __name__ == "__main__":
    # Report the local CUDA environment before anything heavy happens.
    check_system_info()
    print("\n正在初始化多模态检索系统...")
    try:
        retrieval_system = MultimodalRetrievalVDB()
        print("✅ 系统初始化成功!")
        # Show what the backing database currently holds.
        stats = retrieval_system.get_statistics()
        print(f"\n📊 数据库统计信息: {stats}")
        print("\n🚀 多模态检索系统就绪!")
        print("支持的检索模式:")
        for usage_line in (
            "1. 文搜文: search_text_by_text()",
            "2. 文搜图: search_images_by_text()",
            "3. 图搜文: search_text_by_image()",
            "4. 图搜图: search_images_by_image()",
            "5. 存储文本: store_texts()",
            "6. 存储图像: store_images()",
        ):
            print(usage_line)
    except Exception as e:
        print(f"❌ 系统初始化失败: {e}")
        import traceback
        traceback.print_exc()