#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 纯百度VDB多模态检索系统 - 完全替代FAISS 支持文搜文、文搜图、图搜文、图搜图四种检索模式 """ import torch import torch.nn as nn from torch.nn.parallel import DataParallel, DistributedDataParallel import numpy as np from PIL import Image from transformers import AutoModel, AutoProcessor, AutoTokenizer from typing import List, Union, Tuple, Dict, Any import os import json from pathlib import Path import logging import gc from concurrent.futures import ThreadPoolExecutor, as_completed import threading from baidu_vdb_production import BaiduVDBProduction # 设置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class MultimodalRetrievalVDBOnly: """纯百度VDB多模态检索系统,完全替代FAISS""" def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B", use_all_gpus: bool = True, gpu_ids: List[int] = None, min_memory_gb=12): """ 初始化纯VDB多模态检索系统 Args: model_name: 模型名称 use_all_gpus: 是否使用所有可用GPU gpu_ids: 指定使用的GPU ID列表 min_memory_gb: 最小可用内存(GB) """ self.model_name = model_name # 设置GPU设备 self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb) # 清理GPU内存 self._clear_all_gpu_memory() logger.info(f"正在加载模型到多GPU: {self.device_ids}") # 加载模型和处理器 self.model = None self.tokenizer = None self.processor = None self._load_model_multigpu() # 初始化百度VDB后端(替代FAISS索引) logger.info("初始化百度VDB后端...") self.vdb = BaiduVDBProduction() logger.info("✅ 百度VDB后端初始化完成") # 线程锁 self.model_lock = threading.Lock() logger.info("✅ 纯VDB多模态检索系统初始化完成") def _setup_devices(self, use_all_gpus, gpu_ids, min_memory_gb): """设置GPU设备""" if not torch.cuda.is_available(): raise RuntimeError("CUDA不可用,需要GPU支持") total_gpus = torch.cuda.device_count() logger.info(f"检测到 {total_gpus} 个GPU") # 获取可用GPU available_gpus = [] for i in range(total_gpus): memory_gb = torch.cuda.get_device_properties(i).total_memory / (1024**3) free_memory = torch.cuda.memory_reserved(i) / (1024**3) available_memory = memory_gb - free_memory logger.info(f"GPU {i}: {torch.cuda.get_device_properties(i).name} ({memory_gb:.1f}GB)") if available_memory >= min_memory_gb: available_gpus.append(i) logger.info(f"GPU {i}: {available_memory:.0f}MB 可用 (合适)") else: logger.info(f"GPU {i}: {available_memory:.0f}MB 可用 (不足)") if not available_gpus: raise RuntimeError(f"没有找到满足 {min_memory_gb}GB 内存要求的GPU") # 选择使用的GPU if gpu_ids: self.device_ids = [gpu_id for gpu_id in gpu_ids if gpu_id in available_gpus] elif use_all_gpus: self.device_ids = available_gpus else: self.device_ids = [available_gpus[0]] if not self.device_ids: raise RuntimeError("没有可用的GPU设备") # 设置主设备 self.primary_device = f"cuda:{self.device_ids[0]}" torch.cuda.set_device(self.device_ids[0]) logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}") def _clear_all_gpu_memory(self): """清理所有GPU内存""" for device_id in self.device_ids: with torch.cuda.device(device_id): torch.cuda.empty_cache() gc.collect() logger.info("所有GPU内存已清理") def _load_model_multigpu(self): """加载模型到多GPU""" try: # 清理GPU内存 self._clear_all_gpu_memory() logger.info(f"正在加载模型到多GPU: {self.device_ids}") # 加载模型 self.model = AutoModel.from_pretrained( self.model_name, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto" ) # 加载tokenizer和processor self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) logger.info("Tokenizer加载成功") self.processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True) logger.info("Processor加载成功") # 显示设备映射 if hasattr(self.model, 'hf_device_map'): logger.info(f"模型已成功加载到设备: {dict(list(self.model.hf_device_map.items())[:10])}") self.model.eval() logger.info("多GPU模型加载完成") except Exception as e: logger.error(f"模型加载失败: {e}") raise def encode_text_batch(self, texts: List[str], batch_size: int = 8) -> np.ndarray: """批量编码文本""" try: with self.model_lock: all_embeddings = [] for i in range(0, len(texts), batch_size): batch_texts = texts[i:i + batch_size] # 使用processor处理文本 inputs = self.processor( text=batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=512 ) # 将输入移动到主设备 inputs = {k: v.to(self.primary_device) for k, v in inputs.items()} with torch.no_grad(): outputs = self.model(**inputs) embeddings = outputs.last_hidden_state.mean(dim=1) embeddings = embeddings.cpu().numpy() all_embeddings.append(embeddings) return np.vstack(all_embeddings) except Exception as e: logger.error(f"文本编码失败: {e}") return np.zeros((len(texts), 3584), dtype=np.float32) def encode_image_batch(self, images: List[Union[str, Image.Image]], batch_size: int = 4) -> np.ndarray: """批量编码图像""" try: with self.model_lock: processed_images = [] # 处理图像输入 for img in images: if isinstance(img, str): if os.path.exists(img): processed_images.append(Image.open(img).convert('RGB')) else: logger.warning(f"图像文件不存在: {img}") processed_images.append(Image.new('RGB', (224, 224), color='white')) elif isinstance(img, Image.Image): processed_images.append(img.convert('RGB')) else: logger.warning(f"不支持的图像类型: {type(img)}") processed_images.append(Image.new('RGB', (224, 224), color='white')) all_embeddings = [] for i in range(0, len(processed_images), batch_size): batch_images = processed_images[i:i + batch_size] # 使用processor处理图像 inputs = self.processor( images=batch_images, return_tensors="pt", padding=True ) # 将输入移动到主设备 inputs = {k: v.to(self.primary_device) for k, v in inputs.items()} with torch.no_grad(): outputs = self.model(**inputs) embeddings = outputs.last_hidden_state.mean(dim=1) embeddings = embeddings.cpu().numpy() all_embeddings.append(embeddings) return np.vstack(all_embeddings) except Exception as e: logger.error(f"图像编码失败: {e}") embedding_dim = 3584 embeddings = np.zeros((len(images), embedding_dim), dtype=np.float32) return embeddings def build_text_index_parallel(self, texts: List[str], save_path: str = None): """ 构建文本索引(使用VDB替代FAISS) """ try: logger.info(f"正在构建文本索引,共 {len(texts)} 条文本") # 编码文本 embeddings = self.encode_text_batch(texts) # 使用VDB存储 self.vdb.build_text_index(texts, embeddings) logger.info("文本索引构建完成") except Exception as e: logger.error(f"构建文本索引失败: {e}") raise def build_image_index_parallel(self, image_paths: List[str], save_path: str = None): """ 构建图像索引(使用VDB替代FAISS) """ try: logger.info(f"正在构建图像索引,共 {len(image_paths)} 张图像") # 编码图像 embeddings = self.encode_image_batch(image_paths) # 使用VDB存储 self.vdb.build_image_index(image_paths, embeddings) logger.info("图像索引构建完成") except Exception as e: logger.error(f"构建图像索引失败: {e}") raise def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: """文搜文:使用文本查询搜索相似文本""" try: query_embedding = self.encode_text_batch([query]) return self.vdb.search_text_by_text(query_embedding[0], top_k) except Exception as e: logger.error(f"文搜文失败: {e}") return [] def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: """文搜图:使用文本查询搜索相似图像""" try: query_embedding = self.encode_text_batch([query]) return self.vdb.search_images_by_text(query_embedding[0], top_k) except Exception as e: logger.error(f"文搜图失败: {e}") return [] def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]: """图搜图:使用图像查询搜索相似图像""" try: query_embedding = self.encode_image_batch([query_image]) return self.vdb.search_images_by_image(query_embedding[0], top_k) except Exception as e: logger.error(f"图搜图失败: {e}") return [] def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]: """图搜文:使用图像查询搜索相似文本""" try: query_embedding = self.encode_image_batch([query_image]) return self.vdb.search_text_by_image(query_embedding[0], top_k) except Exception as e: logger.error(f"图搜文失败: {e}") return [] # Web应用兼容方法 def search_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: """文搜文:Web应用兼容方法""" return self.search_text_by_text(query, top_k) def search_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]: """图搜图:Web应用兼容方法""" return self.search_images_by_image(query_image, top_k) def search_images_by_text_query(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: """文搜图:Web应用兼容方法""" return self.search_images_by_text(query, top_k) def search_texts_by_image_query(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]: """图搜文:Web应用兼容方法""" return self.search_text_by_image(query_image, top_k) def get_statistics(self) -> Dict[str, Any]: """获取系统统计信息""" try: vdb_stats = self.vdb.get_statistics() stats = { "model_name": self.model_name, "device_ids": self.device_ids, "primary_device": self.primary_device, "backend": "Baidu VDB (No FAISS)", **vdb_stats } return stats except Exception as e: logger.error(f"获取统计信息失败: {e}") return {"status": "error", "error": str(e)} def clear_all_data(self): """清空所有数据""" try: self.vdb.clear_all_data() logger.info("✅ 所有数据已清空") except Exception as e: logger.error(f"❌ 清空数据失败: {e}") def get_gpu_memory_info(self): """获取所有GPU内存使用信息""" memory_info = {} for device_id in self.device_ids: with torch.cuda.device(device_id): allocated = torch.cuda.memory_allocated() / (1024**3) reserved = torch.cuda.memory_reserved() / (1024**3) total = torch.cuda.get_device_properties(device_id).total_memory / (1024**3) memory_info[f"GPU_{device_id}"] = { "allocated_GB": round(allocated, 2), "reserved_GB": round(reserved, 2), "total_GB": round(total, 2), "free_GB": round(total - reserved, 2) } return memory_info def cleanup(self): """清理资源""" try: if self.vdb: self.vdb.close() self._clear_all_gpu_memory() logger.info("✅ 资源清理完成") except Exception as e: logger.error(f"❌ 资源清理失败: {e}") def test_vdb_only_system(): """测试纯VDB多模态检索系统""" print("=" * 60) print("测试纯百度VDB多模态检索系统") print("=" * 60) system = None try: # 1. 初始化系统 print("1. 初始化纯VDB多模态检索系统...") system = MultimodalRetrievalVDBOnly() print("✅ 系统初始化成功") # 2. 构建文本索引 print("\n2. 构建文本索引...") test_texts = [ "人工智能技术的发展趋势", "机器学习在医疗领域的应用", "深度学习算法优化方法", "计算机视觉技术创新", "自然语言处理最新进展" ] system.build_text_index_parallel(test_texts) print("✅ 文本索引构建完成") # 3. 测试文搜文 print("\n3. 测试文搜文...") query = "AI技术" results = system.search_text_by_text(query, top_k=3) print(f"查询: {query}") for i, (text, score) in enumerate(results, 1): print(f" {i}. {text} (相似度: {score:.3f})") # 4. 获取统计信息 print("\n4. 获取统计信息...") stats = system.get_statistics() print("系统统计:") for key, value in stats.items(): print(f" {key}: {value}") print(f"\n🎉 纯VDB系统测试完成!") print("✅ 完全移除FAISS依赖") print("✅ 使用百度VDB作为向量数据库") print("✅ 支持多模态检索功能") return True except Exception as e: print(f"❌ 测试失败: {e}") import traceback traceback.print_exc() return False finally: if system: system.cleanup() if __name__ == "__main__": test_vdb_only_system()