#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Multimodal retrieval system integrated with Baidu VDB.

Supports four retrieval modes: text-to-text, text-to-image,
image-to-text and image-to-image.
"""

import gc
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer

from baidu_vdb_backend import BaiduVDBBackend

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MultimodalRetrievalVDB:
    """Multimodal retrieval system integrated with Baidu VDB.

    Text and images are encoded with a Hugging Face model (mean-pooled
    ``last_hidden_state``) and the resulting vectors are stored in / searched
    through a ``BaiduVDBBackend``. When the backend cannot be created,
    ``self.vdb`` is ``None`` and store/search calls return empty results.
    """

    # Width of the all-zero fallback matrix returned when image encoding fails.
    _FALLBACK_EMBEDDING_DIM = 3584

    # Per-GPU cap passed as ``max_memory`` when the model is spread over GPUs.
    _MAX_MEMORY_PER_GPU = "18GiB"

    def __init__(self,
                 model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True,
                 gpu_ids: Optional[List[int]] = None,
                 vdb_config: Optional[Dict[str, str]] = None):
        """
        Initialise the multimodal retrieval system.

        Args:
            model_name: Model identifier handed to ``from_pretrained``.
            use_all_gpus: Use every visible GPU. Takes precedence over
                ``gpu_ids``.
            gpu_ids: Explicit GPU ids; only honoured when ``use_all_gpus`` is
                False.
            vdb_config: Keyword arguments for ``BaiduVDBBackend``. When None,
                a default configuration is built (see
                ``_default_vdb_config``).

        Raises:
            RuntimeError: If CUDA is not available.
        """
        self.model_name = model_name

        # Select GPU devices
        self._setup_devices(use_all_gpus, gpu_ids)

        # Free cached GPU memory before loading the model
        self._clear_gpu_memory()

        logger.info(f"正在加载模型到GPU: {self.device_ids}")

        # Load the model, tokenizer and processor
        self.model = None
        self.tokenizer = None
        self.processor = None
        # _load_model reports failure through its return value instead of
        # raising; surface it here so a broken model is not reported as a
        # successful initialisation. Construction still succeeds so that the
        # VDB-only features (statistics, clearing) remain usable.
        if not self._load_model():
            logger.error("❌ 模型加载失败,编码与检索功能将不可用")

        # Initialise the Baidu VDB backend
        if vdb_config is None:
            vdb_config = self._default_vdb_config()

        try:
            self.vdb = BaiduVDBBackend(**vdb_config)
            logger.info("✅ VDB后端初始化成功")
        except Exception as e:
            logger.error(f"❌ VDB后端初始化失败: {e}")
            # Degrade to "no VDB" mode instead of failing construction
            self.vdb = None
            logger.warning("⚠️ 系统将在无VDB模式下运行,数据将不会持久化")

        logger.info("多模态检索系统初始化完成")

    @staticmethod
    def _default_vdb_config() -> Dict[str, str]:
        """Build the VDB configuration used when the caller supplies none.

        Each value can be overridden through an environment variable
        (``BAIDU_VDB_ACCOUNT``, ``BAIDU_VDB_API_KEY``, ``BAIDU_VDB_ENDPOINT``,
        ``BAIDU_VDB_DATABASE``); the historical literals remain as fallbacks
        so existing deployments keep working.
        """
        # SECURITY NOTE: the fallback account/API key below are credentials
        # committed to source control. Prefer the environment variables (or an
        # explicit ``vdb_config``) and rotate/remove the literal when possible.
        return {
            "account": os.environ.get("BAIDU_VDB_ACCOUNT", "root"),
            "api_key": os.environ.get("BAIDU_VDB_API_KEY", "vdb$yjr9ln3n0td"),
            "endpoint": os.environ.get("BAIDU_VDB_ENDPOINT", "http://180.76.96.191:5287"),
            "database_name": os.environ.get("BAIDU_VDB_DATABASE", "multimodal_retrieval"),
        }

    def _setup_devices(self, use_all_gpus: bool, gpu_ids: Optional[List[int]]):
        """Choose the GPUs to use and record them on the instance.

        Sets ``device_ids``, ``num_gpus`` and ``primary_device``.

        Raises:
            RuntimeError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA不可用,无法使用GPU")

        total_gpus = torch.cuda.device_count()
        logger.info(f"检测到 {total_gpus} 个GPU")

        if use_all_gpus:
            if gpu_ids:
                # Selection is unchanged (use_all_gpus wins); just make the
                # otherwise silent override visible to the caller.
                logger.warning(f"use_all_gpus=True,已忽略指定的gpu_ids: {gpu_ids}")
            self.device_ids = list(range(total_gpus))
        elif gpu_ids:
            self.device_ids = gpu_ids
        else:
            self.device_ids = [0]

        self.num_gpus = len(self.device_ids)
        self.primary_device = f"cuda:{self.device_ids[0]}"

        logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}")

    def _clear_gpu_memory(self):
        """Empty the CUDA cache on every selected GPU, then run the GC."""
        for gpu_id in self.device_ids:
            torch.cuda.set_device(gpu_id)
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
        logger.info("GPU内存已清理")

    def _load_pretrained_model(self, local_files_only: bool):
        """Call ``AutoModel.from_pretrained`` for the configured devices.

        With more than one GPU the model is loaded with ``device_map="auto"``
        and a per-GPU ``max_memory`` cap; otherwise it is placed on the
        primary device.

        Args:
            local_files_only: Forwarded to ``from_pretrained``.
        """
        common_kwargs = dict(
            trust_remote_code=True,
            torch_dtype=torch.float16,
            local_files_only=local_files_only,
        )
        if self.num_gpus > 1:
            max_memory = {i: self._MAX_MEMORY_PER_GPU for i in self.device_ids}
            return AutoModel.from_pretrained(
                self.model_name,
                device_map="auto",
                max_memory=max_memory,
                low_cpu_mem_usage=True,
                **common_kwargs
            )
        return AutoModel.from_pretrained(
            self.model_name,
            device_map=self.primary_device,
            **common_kwargs
        )

    def _load_model(self) -> bool:
        """Load the model, tokenizer and processor.

        Every component is first requested with ``local_files_only=False``
        and, on failure, retried with ``local_files_only=True``. If the
        processor cannot be loaded either way, the tokenizer is used in its
        place.

        Returns:
            True on success, False if anything failed (the error is logged,
            not raised).
        """
        try:
            # NOTE(review): these variables are set after torch/transformers
            # were imported (and after CUDA was already touched); whether the
            # libraries still pick them up at this point should be confirmed.
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

            # Free cached GPU memory
            self._clear_gpu_memory()

            # Offline-mode environment variables
            os.environ['TRANSFORMERS_OFFLINE'] = '1'
            os.environ['HF_HUB_OFFLINE'] = '1'

            # Model: allow downloading first, then fall back to the local cache
            try:
                self.model = self._load_pretrained_model(local_files_only=False)
                logger.info("模型从网络加载成功")
            except Exception as network_error:
                logger.warning(f"网络加载失败,尝试本地缓存: {network_error}")
                try:
                    self.model = self._load_pretrained_model(local_files_only=True)
                    logger.info("模型从本地缓存加载成功")
                except Exception as local_error:
                    logger.error(f"本地缓存加载也失败: {local_error}")
                    raise

            # Tokenizer
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    local_files_only=False
                )
            except Exception as e:
                logger.warning(f"Tokenizer网络加载失败,尝试本地: {e}")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    local_files_only=True
                )

            # Processor, with the tokenizer as the last resort
            try:
                self.processor = AutoProcessor.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    local_files_only=False
                )
            except Exception as e:
                logger.warning(f"Processor加载失败,使用tokenizer: {e}")
                try:
                    self.processor = AutoProcessor.from_pretrained(
                        self.model_name,
                        trust_remote_code=True,
                        local_files_only=True
                    )
                except Exception as e2:
                    logger.warning(f"Processor本地加载也失败,使用tokenizer: {e2}")
                    self.processor = self.tokenizer

            logger.info("模型加载完成")
            return True

        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            return False

    @staticmethod
    def _mean_pool(hidden_states, attention_mask=None):
        """Average hidden states over dimension 1, skipping padded positions.

        A plain ``mean(dim=1)`` also averages padding positions, which makes
        an item's embedding depend on the other items in its batch. When a
        mask whose first two dimensions match ``hidden_states`` is given, only
        positions with a non-zero mask contribute. Sums are accumulated in
        float32 rather than in the model's half precision.

        Args:
            hidden_states: Tensor indexed as ``[batch, position, feature]``.
            attention_mask: Optional ``[batch, position]`` mask.

        Returns:
            A float32 tensor with dimension 1 reduced away.
        """
        hidden = hidden_states.float()
        if attention_mask is None or tuple(attention_mask.shape) != tuple(hidden.shape[:2]):
            # No usable mask: keep the unmasked average
            return hidden.mean(dim=1)

        # With device_map="auto" the output may live on a different GPU than
        # the inputs, so move the mask next to the hidden states.
        mask = attention_mask.to(device=hidden.device, dtype=hidden.dtype).unsqueeze(-1)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)  # guard against all-zero rows
        return summed / counts

    def encode_text_batch(self, texts: List[str]) -> np.ndarray:
        """
        Encode a batch of texts into vectors.

        Args:
            texts: Texts to encode (truncated to 512 tokens each).

        Returns:
            A float32 array with one row per text; an empty array when
            ``texts`` is empty.

        Raises:
            RuntimeError: If the model or tokenizer has not been loaded.
        """
        if not texts:
            return np.array([])

        if self.model is None or self.tokenizer is None:
            raise RuntimeError("模型未加载,无法编码文本")

        with torch.no_grad():
            # Tokenize
            inputs = self.tokenizer(
                text=texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )

            # Move the inputs to the primary device
            inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}

            # Forward pass and padding-aware mean pooling
            outputs = self.model(**inputs)
            embeddings = self._mean_pool(
                outputs.last_hidden_state, inputs.get("attention_mask")
            )

            # Release GPU memory
            del inputs, outputs
            torch.cuda.empty_cache()

        return embeddings.cpu().numpy().astype(np.float32)

    @staticmethod
    def _load_rgb_image(img):
        """Return ``img`` as an RGB PIL image.

        Strings are treated as file paths and opened; PIL images are
        converted; any other value is passed through unchanged.
        """
        if isinstance(img, str):
            # The context manager closes the underlying file once the pixel
            # data has been converted into a new in-memory image.
            with Image.open(img) as opened:
                return opened.convert('RGB')
        if isinstance(img, Image.Image):
            return img.convert('RGB')
        return img

    def _encode_images_with_model(self, processed_images) -> np.ndarray:
        """Run the processor and model over already-loaded images.

        Each image is wrapped in a single-turn chat message with a fixed text
        prompt, rendered through the processor's chat template, and the
        model's ``last_hidden_state`` is mean-pooled.
        """
        conversations = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": "What is in this image?"}
                    ]
                }
            ]
            for image in processed_images
        ]

        texts = [
            self.processor.apply_chat_template(
                conv, tokenize=False, add_generation_prompt=False
            )
            for conv in conversations
        ]

        # Process text and images together
        inputs = self.processor(
            text=texts,
            images=processed_images,
            return_tensors="pt",
            padding=True
        )

        # Move to GPU
        inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            embeddings = self._mean_pool(
                outputs.last_hidden_state, inputs.get("attention_mask")
            )

        return embeddings.cpu().numpy().astype(np.float32)

    def encode_image_batch(self, images: List[Union[str, Image.Image]]) -> np.ndarray:
        """
        Encode a batch of images into vectors.

        Args:
            images: Image file paths or PIL images.

        Returns:
            A float32 array with one row per image; an empty array when
            ``images`` is empty. If the model fails to encode the batch, an
            all-zero array of width ``_FALLBACK_EMBEDDING_DIM`` is returned
            and a warning is logged.

        Raises:
            Exception: Errors raised while opening/converting an image are
                not caught here.
        """
        if not images:
            return np.array([])

        # Load / normalise the images
        processed_images = [self._load_rgb_image(img) for img in images]

        logger.info(f"处理 {len(processed_images)} 张图像")

        try:
            embeddings = self._encode_images_with_model(processed_images)
        except Exception as inner_e:
            logger.warning(f"多模态模型图像编码失败: {inner_e}")
            # Zero vectors as a fallback.
            # NOTE(review): callers such as store_images will persist these
            # zero vectors; confirm that this best-effort behaviour is wanted.
            embeddings = np.zeros(
                (len(processed_images), self._FALLBACK_EMBEDDING_DIM),
                dtype=np.float32
            )

        logger.info(f"生成图像embeddings: {embeddings.shape}")
        return embeddings

    def store_texts(self, texts: List[str], metadata: Optional[List[Dict]] = None) -> List[str]:
        """
        Encode texts and store them in the VDB.

        Texts are processed in batches of 16. A batch that fails is logged
        and skipped; the remaining batches are still processed.

        Args:
            texts: Texts to store.
            metadata: Optional metadata, sliced in parallel with ``texts``.

        Returns:
            Ids returned by the backend for the stored batches; an empty list
            when the VDB is unavailable.
        """
        if self.vdb is None:
            logger.warning("VDB不可用,文本数据将不会持久化存储")
            return []

        logger.info(f"正在存储 {len(texts)} 条文本数据")

        batch_size = 16
        all_ids = []

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_metadata = metadata[i:i+batch_size] if metadata else None

            try:
                # Encode, then store in the VDB
                vectors = self.encode_text_batch(batch_texts)
                ids = self.vdb.store_text_vectors(batch_texts, vectors, batch_metadata)
                all_ids.extend(ids)

                logger.info(f"已处理 {i + len(batch_texts)}/{len(texts)} 条文本")

            except Exception as e:
                logger.error(f"处理文本批次时出错: {e}")
                continue

        logger.info(f"✅ 文本存储完成,共 {len(all_ids)} 条")
        return all_ids

    def store_images(self, image_paths: List[str], metadata: Optional[List[Dict]] = None) -> List[str]:
        """
        Encode images and store them in the VDB.

        Images are processed in batches of 8 (smaller than the text batch
        size). A batch that fails is logged and skipped.

        Args:
            image_paths: Paths of the images to store.
            metadata: Optional metadata, sliced in parallel with
                ``image_paths``.

        Returns:
            Ids returned by the backend for the stored batches; an empty list
            when the VDB is unavailable.
        """
        if self.vdb is None:
            logger.warning("VDB不可用,图像数据将不会持久化存储")
            return []

        logger.info(f"正在存储 {len(image_paths)} 张图像数据")

        batch_size = 8
        all_ids = []

        for i in range(0, len(image_paths), batch_size):
            batch_images = image_paths[i:i+batch_size]
            batch_metadata = metadata[i:i+batch_size] if metadata else None

            try:
                # Encode, then store in the VDB
                vectors = self.encode_image_batch(batch_images)
                ids = self.vdb.store_image_vectors(batch_images, vectors, batch_metadata)
                all_ids.extend(ids)

                logger.info(f"已处理 {i + len(batch_images)}/{len(image_paths)} 张图像")

            except Exception as e:
                logger.error(f"处理图像批次时出错: {e}")
                continue

        logger.info(f"✅ 图像存储完成,共 {len(all_ids)} 条")
        return all_ids

    def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-text: return ``(text, score)`` pairs for a text query."""
        if self.vdb is None:
            logger.warning("VDB不可用,无法执行搜索")
            return []

        logger.info(f"执行文搜文查询: {query}")

        # Encode the query, then search the text vectors
        query_vector = self.encode_text_batch([query])[0]
        results = self.vdb.search_text_vectors(query_vector, top_k)

        # Each hit is unpacked as (doc_id, text, score, metadata)
        return [
            (text_content, score)
            for _doc_id, text_content, score, _metadata in results
        ]

    def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-image: return ``(image_path, score)`` pairs for a text query."""
        if self.vdb is None:
            logger.warning("VDB不可用,无法执行搜索")
            return []

        logger.info(f"执行文搜图查询: {query}")

        # Encode the query, then search the image vectors
        query_vector = self.encode_text_batch([query])[0]
        results = self.vdb.search_image_vectors(query_vector, top_k)

        # Each hit is unpacked as (doc_id, path, name, score, metadata)
        return [
            (image_path, score)
            for _doc_id, image_path, _image_name, score, _metadata in results
        ]

    def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-text: return ``(text, score)`` pairs for an image query."""
        if self.vdb is None:
            logger.warning("VDB不可用,无法执行搜索")
            return []

        logger.info("执行图搜文查询")

        # Encode the query image, then search the text vectors
        query_vector = self.encode_image_batch([query_image])[0]
        results = self.vdb.search_text_vectors(query_vector, top_k)

        return [
            (text_content, score)
            for _doc_id, text_content, score, _metadata in results
        ]

    def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-image: return ``(image_path, score)`` pairs for an image query."""
        if self.vdb is None:
            logger.warning("VDB不可用,无法执行搜索")
            return []

        logger.info("执行图搜图查询")

        # Encode the query image, then search the image vectors
        query_vector = self.encode_image_batch([query_image])[0]
        results = self.vdb.search_image_vectors(query_vector, top_k)

        return [
            (image_path, score)
            for _doc_id, image_path, _image_name, score, _metadata in results
        ]

    # Method names kept for compatibility with the web application
    def search_text_to_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-text: web-app compatible alias of ``search_text_by_text``."""
        return self.search_text_by_text(query, top_k)

    def search_text_to_image(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-image: web-app compatible alias of ``search_images_by_text``."""
        return self.search_images_by_text(query, top_k)

    def search_image_to_text(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-text: web-app compatible alias of ``search_text_by_image``."""
        return self.search_text_by_image(query_image, top_k)

    def search_image_to_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-image: web-app compatible alias of ``search_images_by_image``."""
        return self.search_images_by_image(query_image, top_k)

    def get_statistics(self) -> Dict[str, Any]:
        """Return the backend's statistics, or an error dict without a VDB."""
        if self.vdb is None:
            return {"error": "VDB不可用"}
        return self.vdb.get_statistics()

    def clear_all_data(self):
        """Ask the backend to clear all data; no-op (with a warning) without a VDB."""
        if self.vdb is None:
            logger.warning("VDB不可用,无法清空数据")
            return
        self.vdb.clear_all_data()

    def close(self):
        """Close the VDB backend (if any) and free cached GPU memory."""
        if self.vdb:
            self.vdb.close()
        self._clear_gpu_memory()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


def check_system_info():
    """Print CUDA / PyTorch versions and the name and memory of each GPU."""
    print("=== 多模态检索系统信息 ===")

    if not torch.cuda.is_available():
        print("❌ CUDA不可用")
        return

    gpu_count = torch.cuda.device_count()
    print(f"✅ 检测到 {gpu_count} 个GPU")
    print(f"CUDA版本: {torch.version.cuda}")
    print(f"PyTorch版本: {torch.__version__}")

    for i in range(gpu_count):
        gpu_name = torch.cuda.get_device_name(i)
        gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
        print(f"GPU {i}: {gpu_name} ({gpu_memory:.1f}GB)")

    print("========================")


def main():
    """Script entry point: print system info, then build the system once."""
    # Check the system environment
    check_system_info()

    # Example usage
    print("\n正在初始化多模态检索系统...")

    try:
        # The context manager closes the VDB connection and frees GPU memory
        # on exit, including when an error occurs after construction.
        with MultimodalRetrievalVDB() as retrieval_system:
            print("✅ 系统初始化成功!")

            # Show statistics
            stats = retrieval_system.get_statistics()
            print(f"\n📊 数据库统计信息: {stats}")

            print("\n🚀 多模态检索系统就绪!")
            print("支持的检索模式:")
            print("1. 文搜文: search_text_by_text()")
            print("2. 文搜图: search_images_by_text()")
            print("3. 图搜文: search_text_by_image()")
            print("4. 图搜图: search_images_by_image()")
            print("5. 存储文本: store_texts()")
            print("6. 存储图像: store_images()")

    except Exception as e:
        print(f"❌ 系统初始化失败: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()