#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Multimodal retrieval system backed by a local embedding model.

Supports four retrieval modes: text-to-text, text-to-image,
image-to-text and image-to-image.
"""
import torch
import numpy as np
from PIL import Image
from ops_mm_embedding_v1 import OpsMMEmbeddingV1
from typing import List, Union, Tuple, Dict, Any, Optional
import os
import json
from pathlib import Path
import logging
import gc
import faiss
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Force offline mode so transformers never tries to reach the Hub
os.environ['TRANSFORMERS_OFFLINE'] = '1'


class MultimodalRetrievalLocal:
    """Multimodal retrieval system using a locally stored embedding model."""

    def __init__(self,
                 model_path: str = "/root/models/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True,
                 gpu_ids: Optional[List[int]] = None,
                 min_memory_gb: int = 12,
                 index_path: str = "local_faiss_index"):
        """Initialize the multimodal retrieval system.

        Args:
            model_path: Local path of the embedding model.
            use_all_gpus: Whether to use every available GPU.
            gpu_ids: Explicit list of GPU ids to use (filtered against
                the GPUs that satisfy the memory requirement).
            min_memory_gb: Minimum GPU memory in GB a device must have.
            index_path: Base path (no extension) of the FAISS index files.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            RuntimeError: If the embedding model fails to load.
        """
        self.model_path = model_path
        self.index_path = index_path

        # Fail fast if the model directory is missing.
        if not os.path.exists(model_path):
            logger.error(f"模型路径不存在: {model_path}")
            logger.info("请先下载模型到指定路径")
            raise FileNotFoundError(f"模型路径不存在: {model_path}")

        # Select devices, free GPU memory, load model, then open the index.
        self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
        self._clear_all_gpu_memory()
        self._load_embedding_model()
        self._init_index()

        logger.info(f"多模态检索系统初始化完成,使用本地模型: {model_path}")
        logger.info(f"向量存储路径: {index_path}")

    def _setup_devices(self, use_all_gpus: bool,
                       gpu_ids: Optional[List[int]],
                       min_memory_gb: int) -> None:
        """Choose compute device(s); fall back to CPU when no GPU qualifies."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.use_gpu = self.device.type == "cuda"
        if self.use_gpu:
            self.available_gpus = self._get_available_gpus(min_memory_gb)
            if not self.available_gpus:
                logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB,将使用CPU")
                self.device = torch.device("cpu")
                self.use_gpu = False
            else:
                if gpu_ids:
                    # Keep only the requested ids that actually qualify.
                    self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus]
                    if not self.gpu_ids:
                        logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足,将使用可用的GPU: {self.available_gpus}")
                        self.gpu_ids = self.available_gpus
                elif use_all_gpus:
                    self.gpu_ids = self.available_gpus
                else:
                    self.gpu_ids = [self.available_gpus[0]]
                logger.info(f"使用GPU: {self.gpu_ids}")
                # Primary device is the first selected GPU.
                self.device = torch.device(f"cuda:{self.gpu_ids[0]}")
        else:
            logger.warning("没有可用的GPU,将使用CPU")
            self.gpu_ids = []

    def _get_available_gpus(self, min_memory_gb: int) -> List[int]:
        """Return ids of GPUs whose memory meets ``min_memory_gb``.

        NOTE(review): this checks *total* device memory, not currently
        free memory, although the parameter is described as minimum
        available memory — confirm whether free memory should be used.
        """
        available_gpus = []
        for i in range(torch.cuda.device_count()):
            total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3)  # GB
            if total_mem >= min_memory_gb:
                available_gpus.append(i)
        return available_gpus

    def _clear_all_gpu_memory(self) -> None:
        """Release cached GPU memory and trigger Python garbage collection."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def _load_embedding_model(self) -> None:
        """Load the OpsMMEmbeddingV1 multimodal embedding model.

        Raises:
            RuntimeError: If model construction fails for any reason.
        """
        logger.info(f"加载本地多模态嵌入模型: {self.model_path}")
        try:
            device_str = "cuda" if self.use_gpu else "cpu"
            self.model = OpsMMEmbeddingV1(
                self.model_path,
                device=device_str,
                attn_implementation=None,
            )
            # Embedding dimensionality comes from the backbone config.
            self.vector_dim = int(self.model.base_model.config.hidden_size)
            logger.info(f"向量维度: {self.vector_dim}")
            logger.info("嵌入模型加载成功")
        except Exception as e:
            logger.error(f"嵌入模型加载失败: {str(e)}")
            raise RuntimeError(f"嵌入模型加载失败: {str(e)}")

    def _init_index(self) -> None:
        """Load the FAISS index and metadata from disk, or create new ones."""
        index_file = f"{self.index_path}.index"
        if os.path.exists(index_file):
            logger.info(f"加载现有索引: {index_file}")
            try:
                self.index = faiss.read_index(index_file)
                logger.info(f"索引加载成功,包含{self.index.ntotal}个向量")
            except Exception as e:
                # A corrupt index is replaced by a fresh empty one.
                logger.error(f"索引加载失败: {str(e)}")
                logger.info("创建新索引...")
                self.index = faiss.IndexFlatL2(self.vector_dim)
        else:
            logger.info(f"创建新索引,维度: {self.vector_dim}")
            self.index = faiss.IndexFlatL2(self.vector_dim)

        # Metadata maps str(vector id) -> dict describing the item.
        self.metadata = {}
        metadata_file = f"{self.index_path}_metadata.json"
        if os.path.exists(metadata_file):
            try:
                with open(metadata_file, 'r', encoding='utf-8') as f:
                    self.metadata = json.load(f)
                logger.info(f"元数据加载成功,包含{len(self.metadata)}条记录")
            except Exception as e:
                logger.error(f"元数据加载失败: {str(e)}")

    def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
        """Encode text into embedding vectors using OpsMMEmbeddingV1.

        Args:
            text: A single string or a list of strings.

        Returns:
            A 1-D vector when exactly one text is given, otherwise a
            2-D array of shape (n, vector_dim).
        """
        if isinstance(text, str):
            text = [text]
        with torch.inference_mode():
            emb = self.model.get_text_embeddings(texts=text)
            text_embeddings = emb.detach().float().cpu().numpy()
        # emb is already L2-normalized by the model; keep it unchanged.
        return text_embeddings[0] if len(text) == 1 else text_embeddings

    def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
        """Encode image(s) into embedding vectors using OpsMMEmbeddingV1.

        Args:
            image: A single PIL image or a list of PIL images.

        Returns:
            A 2-D array of shape (n, vector_dim); an empty (0, vector_dim)
            array on empty input or any failure (best-effort contract).
        """
        try:
            # Normalize to a list.
            images: List[Image.Image]
            if isinstance(image, Image.Image):
                images = [image]
            else:
                images = image

            if not images:
                logger.error("encode_image: 图像列表为空")
                return np.zeros((0, self.vector_dim))

            # The model expects RGB input.
            rgb_images = [img.convert('RGB') if img.mode != 'RGB' else img for img in images]

            with torch.inference_mode():
                emb = self.model.get_image_embeddings(images=rgb_images)
                image_embeddings = emb.detach().float().cpu().numpy()
            return image_embeddings
        except Exception as e:
            logger.error(f"encode_image: 异常: {str(e)}")
            return np.zeros((0, self.vector_dim))

    def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None
    ) -> List[str]:
        """Add texts to the retrieval system.

        Args:
            texts: List of texts to index.
            metadatas: Optional per-text metadata dicts.

        Returns:
            The string ids assigned to the added texts.

        Raises:
            ValueError: If ``texts`` and ``metadatas`` lengths differ.
        """
        if not texts:
            return []

        if metadatas is None:
            metadatas = [{} for _ in range(len(texts))]

        if len(texts) != len(metadatas):
            raise ValueError("texts和metadatas长度必须相同")

        # encode_text returns a 1-D vector for a single text; FAISS needs
        # a 2-D (n, d) float32 array, so force 2-D here (fixes the
        # previously broken single-text case).
        text_embeddings = np.atleast_2d(
            np.asarray(self.encode_text(texts), dtype='float32')
        )

        # Assign sequential ids continuing from the current index size.
        start_id = self.index.ntotal
        ids = list(range(start_id, start_id + len(texts)))

        self.index.add(text_embeddings)

        # Persist per-item metadata keyed by the stringified vector id.
        for i, item_id in enumerate(ids):
            self.metadata[str(item_id)] = {
                "text": texts[i],
                "type": "text",
                **metadatas[i]
            }

        logger.info(f"成功添加{len(ids)}条文本到检索系统")
        return [str(item_id) for item_id in ids]

    def add_images(
        self,
        images: List[Image.Image],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        image_paths: Optional[List[str]] = None
    ) -> List[str]:
        """Add images to the retrieval system.

        Args:
            images: List of PIL images.
            metadatas: Optional per-image metadata dicts.
            image_paths: Optional file paths stored in the metadata.

        Returns:
            The string ids assigned to the added images; an empty list
            on empty input, encoding failure, or unexpected errors.

        Raises:
            ValueError: If ``images`` and ``metadatas`` lengths differ.
        """
        try:
            logger.info(f"add_images: 开始添加图像,数量: {len(images) if images else 0}")

            if not images or len(images) == 0:
                logger.warning("add_images: 图像列表为空")
                return []

            if metadatas is None:
                logger.info("add_images: 创建默认元数据")
                metadatas = [{} for _ in range(len(images))]

            if len(images) != len(metadatas):
                logger.error(f"add_images: 长度不一致 - images: {len(images)}, metadatas: {len(metadatas)}")
                raise ValueError("images和metadatas长度必须相同")

            logger.info("add_images: 编码图像")
            image_embeddings = self.encode_image(images)

            # encode_image signals failure with an empty (0, d) array.
            if image_embeddings.shape[0] == 0:
                logger.error("add_images: 图像编码失败,返回空数组")
                return []

            start_id = self.index.ntotal
            ids = list(range(start_id, start_id + len(images)))
            logger.info(f"add_images: 生成索引ID: {start_id} - {start_id + len(images) - 1}")

            logger.info(f"add_images: 添加向量到FAISS索引,形状: {image_embeddings.shape}")
            self.index.add(np.array(image_embeddings).astype('float32'))

            # Store metadata; failures per item are logged but not fatal.
            for i, item_id in enumerate(ids):
                try:
                    metadata = {
                        "type": "image",
                        "width": images[i].width,
                        "height": images[i].height,
                        **metadatas[i]
                    }
                    if image_paths and i < len(image_paths):
                        metadata["path"] = image_paths[i]
                    self.metadata[str(item_id)] = metadata
                    logger.debug(f"add_images: 保存元数据成功 - ID: {item_id}")
                except Exception as e:
                    logger.error(f"add_images: 保存元数据失败 - ID: {item_id}, 错误: {str(e)}")

            logger.info(f"add_images: 成功添加{len(ids)}张图像到检索系统")
            return [str(item_id) for item_id in ids]

        except Exception as e:
            logger.error(f"add_images: 添加图像异常: {str(e)}")
            return []

    def search_by_text(
        self,
        query: str,
        k: int = 5,
        filter_type: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Search the index with a text query.

        Args:
            query: Query text.
            k: Number of results to return.
            filter_type: Restrict results to "text" or "image"; None for all.

        Returns:
            List of result dicts, each with metadata plus distance/score.
        """
        query_embedding = self.encode_text(query)
        return self._search(query_embedding, k, filter_type)

    def search_by_image(
        self,
        image: Image.Image,
        k: int = 5,
        filter_type: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Search the index with an image query.

        Args:
            image: Query image.
            k: Number of results to return.
            filter_type: Restrict results to "text" or "image"; None for all.

        Returns:
            List of result dicts, each with metadata plus distance/score.
        """
        query_embedding = self.encode_image(image)
        return self._search(query_embedding, k, filter_type)

    def _search(
        self,
        query_embedding: np.ndarray,
        k: int = 5,
        filter_type: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Run a nearest-neighbor search over the FAISS index.

        Args:
            query_embedding: Query vector, 1-D or (1, d).
            k: Number of results to return.
            filter_type: Restrict results to "text" or "image"; None for all.

        Returns:
            Up to ``k`` metadata dicts, each extended with "distance"
            (L2) and "score" (1 / (1 + distance)).
        """
        if self.index.ntotal == 0:
            return []

        # FAISS requires a 2-D query matrix.
        if len(query_embedding.shape) == 1:
            query_embedding = query_embedding.reshape(1, -1)

        # Over-fetch when filtering so enough results survive the filter.
        actual_k = k * 3 if filter_type else k
        actual_k = min(actual_k, self.index.ntotal)

        distances, indices = self.index.search(query_embedding.astype('float32'), actual_k)

        results = []
        for i in range(len(indices[0])):
            idx = indices[0][i]
            if idx < 0:  # FAISS uses -1 for "no result"
                continue

            vector_id = str(idx)
            if vector_id in self.metadata:
                item = self.metadata[vector_id]

                if filter_type and item.get("type") != filter_type:
                    continue

                result = item.copy()
                result["distance"] = float(distances[0][i])
                result["score"] = float(1.0 / (1.0 + distances[0][i]))
                results.append(result)

                if len(results) >= k:
                    break

        return results

    def save_index(self) -> None:
        """Persist the FAISS index and metadata to disk (best-effort)."""
        index_file = f"{self.index_path}.index"
        try:
            faiss.write_index(self.index, index_file)
            logger.info(f"索引保存成功: {index_file}")
        except Exception as e:
            logger.error(f"索引保存失败: {str(e)}")

        metadata_file = f"{self.index_path}_metadata.json"
        try:
            with open(metadata_file, 'w', encoding='utf-8') as f:
                json.dump(self.metadata, f, ensure_ascii=False, indent=2)
            logger.info(f"元数据保存成功: {metadata_file}")
        except Exception as e:
            logger.error(f"元数据保存失败: {str(e)}")

    def get_stats(self) -> Dict[str, Any]:
        """Return counts and configuration info about the index."""
        text_count = sum(1 for v in self.metadata.values() if v.get("type") == "text")
        image_count = sum(1 for v in self.metadata.values() if v.get("type") == "image")
        return {
            "total_vectors": self.index.ntotal,
            "text_count": text_count,
            "image_count": image_count,
            "vector_dimension": self.vector_dim,
            "index_path": self.index_path,
            "model_path": self.model_path
        }

    def clear_index(self) -> bool:
        """Drop all vectors and metadata, then persist the empty index."""
        logger.info(f"清空索引: {self.index_path}")
        self.index = faiss.IndexFlatL2(self.vector_dim)
        self.metadata = {}
        self.save_index()
        logger.info(f"索引已清空: {self.index_path}")
        return True

    def list_items(self) -> List[Dict[str, Any]]:
        """Return every indexed item's metadata, each with its 'id'."""
        items = []
        for item_id, metadata in self.metadata.items():
            item = metadata.copy()
            item['id'] = item_id
            items.append(item)
        return items

    def __del__(self):
        """Release model resources and auto-save the index on teardown.

        NOTE(review): __del__ may run during interpreter shutdown when
        logging/faiss are partially torn down; errors are swallowed by
        design, but an explicit save_index() call is more reliable.
        """
        try:
            if hasattr(self, 'model'):
                del self.model
                self._clear_all_gpu_memory()

            if hasattr(self, 'index') and self.index is not None:
                logger.info("系统关闭前自动保存索引")
                self.save_index()
        except Exception as e:
            logger.error(f"析构时保存索引失败: {str(e)}")