mmeb/multimodal_retrieval_local.py
2025-09-22 10:13:11 +00:00

608 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
使用本地模型的多模态检索系统
支持文搜文、文搜图、图搜文、图搜图四种检索模式
"""
import torch
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from typing import List, Union, Tuple, Dict, Any, Optional
import os
import json
from pathlib import Path
import logging
import gc
import faiss
import time
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 设置离线模式
os.environ['TRANSFORMERS_OFFLINE'] = '1'
class MultimodalRetrievalLocal:
"""使用本地模型的多模态检索系统"""
def __init__(self,
model_path: str = "/root/models/Ops-MM-embedding-v1-7B",
use_all_gpus: bool = True,
gpu_ids: List[int] = None,
min_memory_gb: int = 12,
index_path: str = "local_faiss_index"):
"""
初始化多模态检索系统
Args:
model_path: 本地模型路径
use_all_gpus: 是否使用所有可用GPU
gpu_ids: 指定使用的GPU ID列表
min_memory_gb: 最小可用内存(GB)
index_path: FAISS索引文件路径
"""
self.model_path = model_path
self.index_path = index_path
# 检查模型路径
if not os.path.exists(model_path):
logger.error(f"模型路径不存在: {model_path}")
logger.info("请先下载模型到指定路径")
raise FileNotFoundError(f"模型路径不存在: {model_path}")
# 设置GPU设备
self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
# 清理GPU内存
self._clear_all_gpu_memory()
# 加载模型和处理器
self._load_model_and_processor()
# 初始化FAISS索引
self._init_index()
logger.info(f"多模态检索系统初始化完成,使用本地模型: {model_path}")
logger.info(f"向量存储路径: {index_path}")
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int):
"""设置GPU设备"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.use_gpu = self.device.type == "cuda"
if self.use_gpu:
self.available_gpus = self._get_available_gpus(min_memory_gb)
if not self.available_gpus:
logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB将使用CPU")
self.device = torch.device("cpu")
self.use_gpu = False
else:
if gpu_ids:
self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus]
if not self.gpu_ids:
logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足将使用可用的GPU: {self.available_gpus}")
self.gpu_ids = self.available_gpus
elif use_all_gpus:
self.gpu_ids = self.available_gpus
else:
self.gpu_ids = [self.available_gpus[0]]
logger.info(f"使用GPU: {self.gpu_ids}")
self.device = torch.device(f"cuda:{self.gpu_ids[0]}")
else:
logger.warning("没有可用的GPU将使用CPU")
self.gpu_ids = []
def _get_available_gpus(self, min_memory_gb: int) -> List[int]:
"""获取可用的GPU列表"""
available_gpus = []
for i in range(torch.cuda.device_count()):
total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB
if total_mem >= min_memory_gb:
available_gpus.append(i)
return available_gpus
def _clear_all_gpu_memory(self):
"""清理GPU内存"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def _load_model_and_processor(self):
"""加载模型和处理器"""
logger.info(f"加载本地模型和处理器: {self.model_path}")
try:
# 加载模型和处理器
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
self.processor = AutoProcessor.from_pretrained(self.model_path)
# 输出处理器信息
logger.info(f"Processor类型: {type(self.processor)}")
logger.info(f"Processor方法: {dir(self.processor)}")
# 检查是否有图像处理器
if hasattr(self.processor, 'image_processor'):
logger.info(f"Image processor类型: {type(self.processor.image_processor)}")
logger.info(f"Image processor方法: {dir(self.processor.image_processor)}")
# 加载模型
self.model = AutoModel.from_pretrained(
self.model_path,
torch_dtype=torch.float16 if self.use_gpu else torch.float32,
device_map="auto" if len(self.gpu_ids) > 1 else None
)
if len(self.gpu_ids) == 1:
self.model.to(self.device)
self.model.eval()
# 获取向量维度
self.vector_dim = self.model.config.hidden_size
logger.info(f"向量维度: {self.vector_dim}")
logger.info("模型和处理器加载成功")
except Exception as e:
logger.error(f"模型加载失败: {str(e)}")
raise RuntimeError(f"模型加载失败: {str(e)}")
def _init_index(self):
"""初始化FAISS索引"""
index_file = f"{self.index_path}.index"
if os.path.exists(index_file):
logger.info(f"加载现有索引: {index_file}")
try:
self.index = faiss.read_index(index_file)
logger.info(f"索引加载成功,包含{self.index.ntotal}个向量")
except Exception as e:
logger.error(f"索引加载失败: {str(e)}")
logger.info("创建新索引...")
self.index = faiss.IndexFlatL2(self.vector_dim)
else:
logger.info(f"创建新索引,维度: {self.vector_dim}")
self.index = faiss.IndexFlatL2(self.vector_dim)
# 加载元数据
self.metadata = {}
metadata_file = f"{self.index_path}_metadata.json"
if os.path.exists(metadata_file):
try:
with open(metadata_file, 'r', encoding='utf-8') as f:
self.metadata = json.load(f)
logger.info(f"元数据加载成功,包含{len(self.metadata)}条记录")
except Exception as e:
logger.error(f"元数据加载失败: {str(e)}")
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
"""编码文本为向量"""
if isinstance(text, str):
text = [text]
inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
# 获取[CLS]标记的隐藏状态作为句子表示
text_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
# 归一化向量
text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
return text_embeddings[0] if len(text) == 1 else text_embeddings
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
"""编码图像为向量"""
try:
logger.info(f"encode_image: 开始编码图像,类型: {type(image)}")
if isinstance(image, Image.Image):
logger.info(f"encode_image: 单个图像,大小: {image.size}")
image = [image]
else:
logger.info(f"encode_image: 图像列表,长度: {len(image)}")
# 检查图像是否为空
if not image or len(image) == 0:
logger.error("encode_image: 图像列表为空")
# 返回一个空的二维数组
return np.zeros((0, self.vector_dim))
# 检查图像是否有效
for i, img in enumerate(image):
if not isinstance(img, Image.Image):
logger.error(f"encode_image: 第{i}个元素不是有效的PIL图像类型: {type(img)}")
# 返回一个空的二维数组
return np.zeros((0, self.vector_dim))
logger.info("encode_image: 处理图像输入")
# 检查图像格式
for i, img in enumerate(image):
logger.info(f"encode_image: 图像 {i} 格式: {img.format}, 模式: {img.mode}, 大小: {img.size}")
# 转换为RGB模式如果不是
if img.mode != 'RGB':
logger.info(f"encode_image: 将图像 {i}{img.mode} 转换为 RGB")
image[i] = img.convert('RGB')
try:
# 直接使用image_processor处理图像
if hasattr(self.processor, 'image_processor'):
logger.info("encode_image: 使用image_processor处理图像")
pixel_values = self.processor.image_processor(images=image, return_tensors="pt").pixel_values
inputs = {"pixel_values": pixel_values}
else:
logger.info("encode_image: 使用processor处理图像")
inputs = self.processor(images=image, return_tensors="pt")
if not inputs or len(inputs) == 0:
logger.error("encode_image: processor返回了空的输入")
return np.zeros((0, self.vector_dim))
logger.info(f"encode_image: 处理后的输入键: {list(inputs.keys())}")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
logger.info("encode_image: 运行模型推理")
logger.info(f"Model类型: {type(self.model)}")
logger.info(f"Model属性: {dir(self.model)}")
# 检查模型结构
try:
logger.info(f"Model配置: {self.model.config}")
logger.info(f"Model配置属性: {dir(self.model.config)}")
else:
visual_outputs = self.model.visual(**inputs)
if hasattr(visual_outputs, 'pooler_output'):
image_embeddings = visual_outputs.pooler_output.cpu().numpy()
elif hasattr(visual_outputs, 'last_hidden_state'):
image_embeddings = visual_outputs.last_hidden_state[:, 0, :].cpu().numpy()
else:
logger.error("encode_image: 无法从视觉模型输出中获取图像向量")
raise ValueError("无法从视觉模型输出中获取图像向量")
else:
# 尝试直接使用模型进行推理
logger.info("encode_image: 尝试直接使用模型进行推理")
with torch.no_grad():
# 使用空文本输入,只提供图像
if 'pixel_values' in inputs:
outputs = self.model(pixel_values=inputs['pixel_values'], input_ids=None)
else:
outputs = self.model(**inputs, input_ids=None)
# 尝试从输出中获取图像向量
if hasattr(outputs, 'image_embeds'):
image_embeddings = outputs.image_embeds.cpu().numpy()
elif hasattr(outputs, 'vision_model_output') and hasattr(outputs.vision_model_output, 'pooler_output'):
image_embeddings = outputs.vision_model_output.pooler_output.cpu().numpy()
elif hasattr(outputs, 'pooler_output'):
image_embeddings = outputs.pooler_output.cpu().numpy()
elif hasattr(outputs, 'last_hidden_state'):
image_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
else:
logger.error("encode_image: 无法从模型输出中获取图像向量")
raise ValueError("无法从模型输出中获取图像向量")
except Exception as e:
logger.error(f"encode_image: 处理图像时出错: {str(e)}")
raise e
return np.zeros((0, self.vector_dim))
# 归一化向量
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
# 始终返回二维数组,即使只有一个图像
if len(image) == 1:
result = np.array([image_embeddings[0]])
logger.info(f"encode_image: 返回单个图像向量,形状: {result.shape}")
return result
else:
logger.info(f"encode_image: 返回多个图像向量,形状: {image_embeddings.shape}")
return image_embeddings
except Exception as e:
logger.error(f"encode_image: 异常: {str(e)}")
# 返回一个空的二维数组
return np.zeros((0, self.vector_dim))
def add_texts(
self,
texts: List[str],
metadatas: Optional[List[Dict[str, Any]]] = None
) -> List[str]:
"""
添加文本到检索系统
Args:
texts: 文本列表
metadatas: 元数据列表,每个元素是一个字典
Returns:
添加的文本ID列表
"""
if not texts:
return []
if metadatas is None:
metadatas = [{} for _ in range(len(texts))]
if len(texts) != len(metadatas):
raise ValueError("texts和metadatas长度必须相同")
# 编码文本
text_embeddings = self.encode_text(texts)
# 准备元数据
start_id = self.index.ntotal
ids = list(range(start_id, start_id + len(texts)))
# 添加到索引
self.index.add(np.array(text_embeddings).astype('float32'))
# 保存元数据
for i, id in enumerate(ids):
self.metadata[str(id)] = {
"text": texts[i],
"type": "text",
**metadatas[i]
}
logger.info(f"成功添加{len(ids)}条文本到检索系统")
return [str(id) for id in ids]
def add_images(
self,
images: List[Image.Image],
metadatas: Optional[List[Dict[str, Any]]] = None,
image_paths: Optional[List[str]] = None
) -> List[str]:
"""
添加图像到检索系统
Args:
images: PIL图像列表
metadatas: 元数据列表,每个元素是一个字典
image_paths: 图像路径列表,用于保存到元数据
Returns:
添加的图像ID列表
"""
try:
logger.info(f"add_images: 开始添加图像,数量: {len(images) if images else 0}")
# 检查图像列表
if not images or len(images) == 0:
logger.warning("add_images: 图像列表为空")
return []
# 准备元数据
if metadatas is None:
logger.info("add_images: 创建默认元数据")
metadatas = [{} for _ in range(len(images))]
# 检查长度一致性
if len(images) != len(metadatas):
logger.error(f"add_images: 长度不一致 - images: {len(images)}, metadatas: {len(metadatas)}")
raise ValueError("images和metadatas长度必须相同")
# 编码图像
logger.info("add_images: 编码图像")
image_embeddings = self.encode_image(images)
# 检查编码结果
if image_embeddings.shape[0] == 0:
logger.error("add_images: 图像编码失败,返回空数组")
return []
# 准备元数据
start_id = self.index.ntotal
ids = list(range(start_id, start_id + len(images)))
logger.info(f"add_images: 生成索引ID: {start_id} - {start_id + len(images) - 1}")
# 添加到索引
logger.info(f"add_images: 添加向量到FAISS索引形状: {image_embeddings.shape}")
self.index.add(np.array(image_embeddings).astype('float32'))
# 保存元数据
for i, id in enumerate(ids):
try:
metadata = {
"type": "image",
"width": images[i].width,
"height": images[i].height,
**metadatas[i]
}
if image_paths and i < len(image_paths):
metadata["path"] = image_paths[i]
self.metadata[str(id)] = metadata
logger.debug(f"add_images: 保存元数据成功 - ID: {id}")
except Exception as e:
logger.error(f"add_images: 保存元数据失败 - ID: {id}, 错误: {str(e)}")
logger.info(f"add_images: 成功添加{len(ids)}张图像到检索系统")
return [str(id) for id in ids]
except Exception as e:
logger.error(f"add_images: 添加图像异常: {str(e)}")
return []
def search_by_text(
self,
query: str,
k: int = 5,
filter_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
文本搜索
Args:
query: 查询文本
k: 返回结果数量
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
Returns:
搜索结果列表,每个元素包含相似项和分数
"""
# 编码查询文本
query_embedding = self.encode_text(query)
# 执行搜索
return self._search(query_embedding, k, filter_type)
def search_by_image(
self,
image: Image.Image,
k: int = 5,
filter_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
图像搜索
Args:
image: 查询图像
k: 返回结果数量
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
Returns:
搜索结果列表,每个元素包含相似项和分数
"""
# 编码查询图像
query_embedding = self.encode_image(image)
# 执行搜索
return self._search(query_embedding, k, filter_type)
def _search(
self,
query_embedding: np.ndarray,
k: int = 5,
filter_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
执行搜索
Args:
query_embedding: 查询向量
k: 返回结果数量
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
Returns:
搜索结果列表
"""
if self.index.ntotal == 0:
return []
# 确保查询向量是2D数组
if len(query_embedding.shape) == 1:
query_embedding = query_embedding.reshape(1, -1)
# 执行搜索,获取更多结果以便过滤
actual_k = k * 3 if filter_type else k
actual_k = min(actual_k, self.index.ntotal)
distances, indices = self.index.search(query_embedding.astype('float32'), actual_k)
# 处理结果
results = []
for i in range(len(indices[0])):
idx = indices[0][i]
if idx < 0: # FAISS可能返回-1表示无效索引
continue
vector_id = str(idx)
if vector_id in self.metadata:
item = self.metadata[vector_id]
# 如果指定了过滤类型,则只返回该类型的结果
if filter_type and item.get("type") != filter_type:
continue
# 添加距离和分数
result = item.copy()
result["distance"] = float(distances[0][i])
result["score"] = float(1.0 / (1.0 + distances[0][i]))
results.append(result)
# 如果已经收集了足够的结果,则停止
if len(results) >= k:
break
return results
def save_index(self):
"""保存索引和元数据"""
# 保存索引
index_file = f"{self.index_path}.index"
try:
faiss.write_index(self.index, index_file)
logger.info(f"索引保存成功: {index_file}")
except Exception as e:
logger.error(f"索引保存失败: {str(e)}")
# 保存元数据
metadata_file = f"{self.index_path}_metadata.json"
try:
with open(metadata_file, 'w', encoding='utf-8') as f:
json.dump(self.metadata, f, ensure_ascii=False, indent=2)
logger.info(f"元数据保存成功: {metadata_file}")
except Exception as e:
logger.error(f"元数据保存失败: {str(e)}")
def get_stats(self) -> Dict[str, Any]:
"""获取检索系统统计信息"""
text_count = sum(1 for v in self.metadata.values() if v.get("type") == "text")
image_count = sum(1 for v in self.metadata.values() if v.get("type") == "image")
return {
"total_vectors": self.index.ntotal,
"text_count": text_count,
"image_count": image_count,
"vector_dimension": self.vector_dim,
"index_path": self.index_path,
"model_path": self.model_path
}
def clear_index(self):
"""清空索引"""
logger.info(f"清空索引: {self.index_path}")
# 重新创建索引
self.index = faiss.IndexFlatL2(self.vector_dim)
# 清空元数据
self.metadata = {}
# 保存空索引
self.save_index()
logger.info(f"索引已清空: {self.index_path}")
return True
def list_items(self) -> List[Dict[str, Any]]:
"""列出所有索引项"""
items = []
for item_id, metadata in self.metadata.items():
item = metadata.copy()
item['id'] = item_id
items.append(item)
return items
def __del__(self):
"""析构函数,确保资源被正确释放并自动保存索引"""
try:
if hasattr(self, 'model'):
del self.model
self._clear_all_gpu_memory()
if hasattr(self, 'index') and self.index is not None:
logger.info("系统关闭前自动保存索引")
self.save_index()
except Exception as e:
logger.error(f"析构时保存索引失败: {str(e)}")