mmeb/multimodal_retrieval_faiss.py
2025-09-22 10:13:11 +00:00

371 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
基于FAISS的多模态检索系统
支持文搜文、文搜图、图搜文、图搜图四种检索模式
"""
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from typing import List, Union, Tuple, Dict, Any, Optional
import os
import json
from pathlib import Path
import logging
import gc
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from faiss_vector_store import FaissVectorStore
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MultimodalRetrievalFAISS:
"""基于FAISS的多模态检索系统"""
def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
use_all_gpus: bool = True, gpu_ids: List[int] = None,
min_memory_gb: int = 12, index_path: str = "faiss_index"):
"""
初始化多模态检索系统
Args:
model_name: 模型名称
use_all_gpus: 是否使用所有可用GPU
gpu_ids: 指定使用的GPU ID列表
min_memory_gb: 最小可用内存(GB)
index_path: FAISS索引文件路径
"""
self.model_name = model_name
self.index_path = index_path
# 设置GPU设备
self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
# 清理GPU内存
self._clear_all_gpu_memory()
# 加载模型和处理器
self._load_model_and_processor()
# 初始化FAISS向量存储
self.vector_store = FaissVectorStore(
index_path=index_path,
dimension=3584 # OpenSearch-AI/Ops-MM-embedding-v1-7B的向量维度
)
logger.info(f"多模态检索系统初始化完成,使用模型: {model_name}")
logger.info(f"向量存储路径: {index_path}")
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int):
"""设置GPU设备"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.use_gpu = self.device.type == "cuda"
if self.use_gpu:
self.available_gpus = self._get_available_gpus(min_memory_gb)
if not self.available_gpus:
logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB将使用CPU")
self.device = torch.device("cpu")
self.use_gpu = False
else:
if gpu_ids:
self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus]
if not self.gpu_ids:
logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足将使用可用的GPU: {self.available_gpus}")
self.gpu_ids = self.available_gpus
elif use_all_gpus:
self.gpu_ids = self.available_gpus
else:
self.gpu_ids = [self.available_gpus[0]]
logger.info(f"使用GPU: {self.gpu_ids}")
self.device = torch.device(f"cuda:{self.gpu_ids[0]}")
def _get_available_gpus(self, min_memory_gb: int) -> List[int]:
"""获取可用的GPU列表"""
available_gpus = []
for i in range(torch.cuda.device_count()):
total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB
if total_mem >= min_memory_gb:
available_gpus.append(i)
return available_gpus
def _clear_all_gpu_memory(self):
"""清理GPU内存"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def _load_model_and_processor(self):
"""加载模型和处理器"""
logger.info(f"加载模型和处理器: {self.model_name}")
# 加载tokenizer和processor
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.processor = AutoProcessor.from_pretrained(self.model_name)
# 加载模型
self.model = AutoModel.from_pretrained(
self.model_name,
torch_dtype=torch.float16 if self.use_gpu else torch.float32,
device_map="auto" if len(self.gpu_ids) > 1 else None
)
# 如果使用多GPU包装模型
if len(self.gpu_ids) > 1:
self.model = DataParallel(self.model, device_ids=self.gpu_ids)
self.model.eval()
self.model.to(self.device)
logger.info("模型和处理器加载完成")
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
"""编码文本为向量"""
if isinstance(text, str):
text = [text]
inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
# 获取[CLS]标记的隐藏状态作为句子表示
text_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
# 归一化向量
text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
return text_embeddings[0] if len(text) == 1 else text_embeddings
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
"""编码图像为向量"""
if isinstance(image, Image.Image):
image = [image]
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model.vision_model(**inputs)
# 获取[CLS]标记的隐藏状态作为图像表示
image_embeddings = outputs.pooler_output.cpu().numpy()
# 归一化向量
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
return image_embeddings[0] if len(image) == 1 else image_embeddings
def add_texts(
self,
texts: List[str],
metadatas: Optional[List[Dict[str, Any]]] = None
) -> List[str]:
"""
添加文本到检索系统
Args:
texts: 文本列表
metadatas: 元数据列表,每个元素是一个字典
Returns:
添加的文本ID列表
"""
if not texts:
return []
if metadatas is None:
metadatas = [{} for _ in range(len(texts))]
if len(texts) != len(metadatas):
raise ValueError("texts和metadatas长度必须相同")
# 编码文本
text_embeddings = self.encode_text(texts)
# 准备元数据
for i, text in enumerate(texts):
metadatas[i].update({
"text": text,
"type": "text"
})
# 添加到向量存储
vector_ids = self.vector_store.add_vectors(text_embeddings, metadatas)
logger.info(f"成功添加{len(vector_ids)}条文本到检索系统")
return vector_ids
def add_images(
self,
images: List[Image.Image],
metadatas: Optional[List[Dict[str, Any]]] = None
) -> List[str]:
"""
添加图像到检索系统
Args:
images: PIL图像列表
metadatas: 元数据列表,每个元素是一个字典
Returns:
添加的图像ID列表
"""
if not images:
return []
if metadatas is None:
metadatas = [{} for _ in range(len(images))]
if len(images) != len(metadatas):
raise ValueError("images和metadatas长度必须相同")
# 编码图像
image_embeddings = self.encode_image(images)
# 准备元数据
for i, image in enumerate(images):
metadatas[i].update({
"type": "image",
"width": image.width,
"height": image.height
})
# 添加到向量存储
vector_ids = self.vector_store.add_vectors(image_embeddings, metadatas)
logger.info(f"成功添加{len(vector_ids)}张图像到检索系统")
return vector_ids
def search_by_text(
self,
query: str,
k: int = 5,
filter_condition: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
文本搜索
Args:
query: 查询文本
k: 返回结果数量
filter_condition: 过滤条件
Returns:
搜索结果列表,每个元素包含相似项和分数
"""
# 编码查询文本
query_embedding = self.encode_text(query)
# 执行搜索
results, distances = self.vector_store.search(query_embedding, k)
# 处理结果
search_results = []
for i, (result, distance) in enumerate(zip(results, distances)):
result["score"] = 1.0 / (1.0 + distance) # 将距离转换为相似度分数
search_results.append(result)
return search_results
def search_by_image(
self,
image: Image.Image,
k: int = 5,
filter_condition: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
图像搜索
Args:
image: 查询图像
k: 返回结果数量
filter_condition: 过滤条件
Returns:
搜索结果列表,每个元素包含相似项和分数
"""
# 编码查询图像
query_embedding = self.encode_image(image)
# 执行搜索
results, distances = self.vector_store.search(query_embedding, k)
# 处理结果
search_results = []
for i, (result, distance) in enumerate(zip(results, distances)):
result["score"] = 1.0 / (1.0 + distance) # 将距离转换为相似度分数
search_results.append(result)
return search_results
def get_vector_count(self) -> int:
"""获取向量数量"""
return self.vector_store.get_vector_count()
def save_index(self):
"""保存索引"""
self.vector_store.save_index()
logger.info("索引已保存")
def __del__(self):
"""析构函数,确保资源被正确释放"""
if hasattr(self, 'model'):
del self.model
self._clear_all_gpu_memory()
if hasattr(self, 'vector_store'):
self.save_index()
def test_faiss_system():
"""测试FAISS多模态检索系统"""
import time
from PIL import Image
import numpy as np
# 初始化检索系统
print("初始化多模态检索系统...")
retrieval = MultimodalRetrievalFAISS(
model_name="OpenSearch-AI/Ops-MM-embedding-v1-7B",
use_all_gpus=True,
index_path="faiss_index_test"
)
# 测试文本
texts = [
"一只可爱的橘色猫咪在沙发上睡觉",
"城市夜景中的高楼大厦和车流",
"阳光明媚的海滩上,人们在冲浪和晒太阳",
"美味的意大利面配红酒和沙拉",
"雪山上滑雪的运动员"
]
# 添加文本
print("\n添加文本到检索系统...")
text_ids = retrieval.add_texts(texts)
print(f"添加了{len(text_ids)}条文本")
# 测试文本搜索
print("\n测试文本搜索...")
query_text = "一只猫在睡觉"
print(f"查询: {query_text}")
results = retrieval.search_by_text(query_text, k=2)
for i, result in enumerate(results):
print(f"结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")
# 测试保存和加载
print("\n保存索引...")
retrieval.save_index()
print("\n测试完成!")
if __name__ == "__main__":
test_faiss_system()