mmeb/multimodal_retrieval_local.py
2025-09-22 19:47:14 +08:00

606 lines
23 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Multimodal retrieval system backed by a locally stored embedding model.

Supports four retrieval modes: text-to-text, text-to-image,
image-to-text and image-to-image.
"""
import torch
import numpy as np
from PIL import Image
from ops_mm_embedding_v1 import OpsMMEmbeddingV1
from concurrent.futures import ThreadPoolExecutor
from typing import List, Union, Tuple, Dict, Any, Optional
import os
import json
from pathlib import Path
import logging
import gc
import faiss
import time
# Configure module-level logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Force offline mode so transformers never tries to reach the Hugging Face hub.
os.environ['TRANSFORMERS_OFFLINE'] = '1'
class MultimodalRetrievalLocal:
    """Multimodal (text + image) retrieval system using a local embedding model.

    Vectors live in a flat-L2 FAISS index persisted at ``index_path``; item
    metadata is kept in a sidecar JSON file next to the index.
    """
    def __init__(self,
                 model_path: str = "/root/models/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True,
                 gpu_ids: Optional[List[int]] = None,
                 min_memory_gb: int = 12,
                 index_path: str = "local_faiss_index",
                 shard_model_across_gpus: bool = False,
                 load_in_4bit: bool = False,
                 load_in_8bit: bool = False,
                 torch_dtype: Optional[torch.dtype] = torch.bfloat16):
        """
        Initialize the multimodal retrieval system.

        Args:
            model_path: Path of the locally downloaded embedding model.
            use_all_gpus: Whether to use every qualifying GPU (data parallel).
            gpu_ids: Explicit GPU ids to use; takes precedence over use_all_gpus.
            min_memory_gb: Minimum total GPU memory (GB) for a GPU to qualify.
            index_path: Base path (no extension) for the FAISS index files.
            shard_model_across_gpus: Shard one model across GPUs instead of
                loading one replica per GPU.
            load_in_4bit: Load model weights with 4-bit quantization.
            load_in_8bit: Load model weights with 8-bit quantization.
            torch_dtype: Torch dtype used to load the model weights.

        Raises:
            FileNotFoundError: If model_path does not exist.
        """
        self.model_path = model_path
        self.index_path = index_path
        self.shard_model_across_gpus = shard_model_across_gpus
        self.load_in_4bit = load_in_4bit
        self.load_in_8bit = load_in_8bit
        self.torch_dtype = torch_dtype
        # Fail fast if the local model directory is missing.
        if not os.path.exists(model_path):
            logger.error(f"模型路径不存在: {model_path}")
            logger.info("请先下载模型到指定路径")
            raise FileNotFoundError(f"模型路径不存在: {model_path}")
        # Pick the GPU/CPU devices to run on.
        self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
        # Release any cached GPU memory before loading the model.
        self._clear_all_gpu_memory()
        # Load the embedding model (replica per GPU, sharded, or single device).
        self._load_embedding_model()
        # Load or create the FAISS index and its metadata sidecar.
        self._init_index()
        logger.info(f"多模态检索系统初始化完成,使用本地模型: {model_path}")
        logger.info(f"向量存储路径: {index_path}")
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int):
"""设置GPU设备"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.use_gpu = self.device.type == "cuda"
if self.use_gpu:
self.available_gpus = self._get_available_gpus(min_memory_gb)
if not self.available_gpus:
logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB将使用CPU")
self.device = torch.device("cpu")
self.use_gpu = False
else:
if gpu_ids:
self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus]
if not self.gpu_ids:
logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足将使用可用的GPU: {self.available_gpus}")
self.gpu_ids = self.available_gpus
elif use_all_gpus:
self.gpu_ids = self.available_gpus
else:
self.gpu_ids = [self.available_gpus[0]]
logger.info(f"使用GPU: {self.gpu_ids}")
self.device = torch.device(f"cuda:{self.gpu_ids[0]}")
else:
logger.warning("没有可用的GPU将使用CPU")
self.gpu_ids = []
def _get_available_gpus(self, min_memory_gb: int) -> List[int]:
"""获取可用的GPU列表"""
available_gpus = []
for i in range(torch.cuda.device_count()):
total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB
if total_mem >= min_memory_gb:
available_gpus.append(i)
return available_gpus
def _clear_all_gpu_memory(self):
"""清理GPU内存"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
    def _load_embedding_model(self):
        """Load the OpsMMEmbeddingV1 multimodal embedding model.

        Depending on configuration this loads either one model sharded over
        all selected GPUs (device_map="auto"), one replica per GPU for data
        parallelism, or a single copy on one device. Sets self.models (list
        of model handles) and self.vector_dim (embedding dimension).

        Raises:
            RuntimeError: If the model cannot be loaded.
        """
        logger.info(f"加载本地多模态嵌入模型: {self.model_path}")
        try:
            self.models: List[OpsMMEmbeddingV1] = []
            if self.use_gpu and len(self.gpu_ids) > 1 and self.shard_model_across_gpus:
                # Tensor-parallel sharding using device_map='auto' across visible GPUs.
                logger.info(f"启用模型跨GPU切片(shard),使用设备: {self.gpu_ids}")
                # Rely on CUDA_VISIBLE_DEVICES to constrain which GPUs are
                # visible, or use accelerate's auto mapping.
                ref_model = OpsMMEmbeddingV1(
                    self.model_path,
                    device="cuda",
                    attn_implementation=None,
                    device_map="auto",
                    load_in_4bit=self.load_in_4bit,
                    load_in_8bit=self.load_in_8bit,
                    torch_dtype=self.torch_dtype,
                )
                # A sharded model is kept as a single logical model reference.
                self.models.append(ref_model)
            else:
                if self.use_gpu and len(self.gpu_ids) > 1:
                    logger.info(f"检测到多GPU可用: {self.gpu_ids},为每张卡加载一个模型副本(数据并行)")
                    for gid in self.gpu_ids:
                        device_str = f"cuda:{gid}"
                        self.models.append(
                            OpsMMEmbeddingV1(
                                self.model_path,
                                device=device_str,
                                attn_implementation=None,
                                load_in_4bit=self.load_in_4bit,
                                load_in_8bit=self.load_in_8bit,
                                torch_dtype=self.torch_dtype,
                            )
                        )
                    ref_model = self.models[0]
                else:
                    # Single-device fallback (one GPU or CPU).
                    device_str = "cuda" if self.use_gpu else "cpu"
                    logger.info(f"使用单设备: {device_str}")
                    ref_model = OpsMMEmbeddingV1(
                        self.model_path,
                        device=device_str,
                        attn_implementation=None,
                        load_in_4bit=self.load_in_4bit,
                        load_in_8bit=self.load_in_8bit,
                        torch_dtype=self.torch_dtype,
                    )
                    self.models.append(ref_model)
            # Embedding dimension comes from the backbone's hidden size.
            self.vector_dim = int(getattr(ref_model.base_model.config, "hidden_size"))
            logger.info(f"向量维度: {self.vector_dim}")
            logger.info("嵌入模型加载成功")
        except Exception as e:
            logger.error(f"嵌入模型加载失败: {str(e)}")
            raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
def _init_index(self):
"""初始化FAISS索引"""
index_file = f"{self.index_path}.index"
if os.path.exists(index_file):
logger.info(f"加载现有索引: {index_file}")
try:
self.index = faiss.read_index(index_file)
logger.info(f"索引加载成功,包含{self.index.ntotal}个向量")
except Exception as e:
logger.error(f"索引加载失败: {str(e)}")
logger.info("创建新索引...")
self.index = faiss.IndexFlatL2(self.vector_dim)
else:
logger.info(f"创建新索引,维度: {self.vector_dim}")
self.index = faiss.IndexFlatL2(self.vector_dim)
# 加载元数据
self.metadata = {}
metadata_file = f"{self.index_path}_metadata.json"
if os.path.exists(metadata_file):
try:
with open(metadata_file, 'r', encoding='utf-8') as f:
self.metadata = json.load(f)
logger.info(f"元数据加载成功,包含{len(self.metadata)}条记录")
except Exception as e:
logger.error(f"元数据加载失败: {str(e)}")
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
"""编码文本为向量(使用 OpsMMEmbeddingV1多GPU并行"""
if isinstance(text, str):
text = [text]
if len(text) == 0:
return np.zeros((0, self.vector_dim))
# 单条或单模型直接处理
if len(self.models) == 1 or len(text) == 1:
with torch.inference_mode():
emb = self.models[0].get_text_embeddings(texts=text)
arr = emb.detach().float().cpu().numpy()
return arr[0] if len(text) == 1 else arr
# 多GPU并行按设备平均切分
shard_count = len(self.models)
shards: List[List[str]] = []
for i in range(shard_count):
shards.append([])
for idx, t in enumerate(text):
shards[idx % shard_count].append(t)
results: List[np.ndarray] = [None] * shard_count # type: ignore
def run_shard(model: OpsMMEmbeddingV1, shard_texts: List[str], shard_idx: int):
if len(shard_texts) == 0:
results[shard_idx] = np.zeros((0, self.vector_dim))
return
with torch.inference_mode():
emb = model.get_text_embeddings(texts=shard_texts)
results[shard_idx] = emb.detach().float().cpu().numpy()
with ThreadPoolExecutor(max_workers=shard_count) as ex:
for i, (model, shard_texts) in enumerate(zip(self.models, shards)):
ex.submit(run_shard, model, shard_texts, i)
# 还原原始顺序(按 round-robin 拼接)
merged: List[np.ndarray] = []
max_len = max(len(s) for s in shards)
for k in range(max_len):
for i in range(shard_count):
if k < len(shards[i]) and results[i].shape[0] > 0:
merged.append(results[i][k:k+1])
if len(merged) == 0:
return np.zeros((0, self.vector_dim))
return np.vstack(merged)
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
"""编码图像为向量(使用 OpsMMEmbeddingV1多GPU并行"""
try:
# 规范为列表
if isinstance(image, Image.Image):
images: List[Image.Image] = [image]
else:
images = image
if not images:
logger.error("encode_image: 图像列表为空")
return np.zeros((0, self.vector_dim))
# 强制为 RGB
rgb_images = [img.convert('RGB') if img.mode != 'RGB' else img for img in images]
# 单张或单模型直接处理
if len(self.models) == 1 or len(rgb_images) == 1:
with torch.inference_mode():
emb = self.models[0].get_image_embeddings(images=rgb_images)
return emb.detach().float().cpu().numpy()
# 多GPU并行按设备平均切分
shard_count = len(self.models)
shards: List[List[Image.Image]] = [[] for _ in range(shard_count)]
for idx, img in enumerate(rgb_images):
shards[idx % shard_count].append(img)
results: List[np.ndarray] = [None] * shard_count # type: ignore
def run_shard(model: OpsMMEmbeddingV1, shard_imgs: List[Image.Image], shard_idx: int):
if len(shard_imgs) == 0:
results[shard_idx] = np.zeros((0, self.vector_dim))
return
with torch.inference_mode():
emb = model.get_image_embeddings(images=shard_imgs)
results[shard_idx] = emb.detach().float().cpu().numpy()
with ThreadPoolExecutor(max_workers=shard_count) as ex:
for i, (model, shard_imgs) in enumerate(zip(self.models, shards)):
ex.submit(run_shard, model, shard_imgs, i)
# 还原原始顺序(按 round-robin 拼接)
merged: List[np.ndarray] = []
max_len = max(len(s) for s in shards)
for k in range(max_len):
for i in range(shard_count):
if k < len(shards[i]) and results[i].shape[0] > 0:
merged.append(results[i][k:k+1])
if len(merged) == 0:
return np.zeros((0, self.vector_dim))
return np.vstack(merged)
except Exception as e:
logger.error(f"encode_image: 异常: {str(e)}")
return np.zeros((0, self.vector_dim))
def add_texts(
self,
texts: List[str],
metadatas: Optional[List[Dict[str, Any]]] = None
) -> List[str]:
"""
添加文本到检索系统
Args:
texts: 文本列表
metadatas: 元数据列表,每个元素是一个字典
Returns:
添加的文本ID列表
"""
if not texts:
return []
if metadatas is None:
metadatas = [{} for _ in range(len(texts))]
if len(texts) != len(metadatas):
raise ValueError("texts和metadatas长度必须相同")
# 编码文本
text_embeddings = self.encode_text(texts)
# 准备元数据
start_id = self.index.ntotal
ids = list(range(start_id, start_id + len(texts)))
# 添加到索引
self.index.add(np.array(text_embeddings).astype('float32'))
# 保存元数据
for i, id in enumerate(ids):
self.metadata[str(id)] = {
"text": texts[i],
"type": "text",
**metadatas[i]
}
logger.info(f"成功添加{len(ids)}条文本到检索系统")
return [str(id) for id in ids]
def add_images(
self,
images: List[Image.Image],
metadatas: Optional[List[Dict[str, Any]]] = None,
image_paths: Optional[List[str]] = None
) -> List[str]:
"""
添加图像到检索系统
Args:
images: PIL图像列表
metadatas: 元数据列表,每个元素是一个字典
image_paths: 图像路径列表,用于保存到元数据
Returns:
添加的图像ID列表
"""
try:
logger.info(f"add_images: 开始添加图像,数量: {len(images) if images else 0}")
# 检查图像列表
if not images or len(images) == 0:
logger.warning("add_images: 图像列表为空")
return []
# 准备元数据
if metadatas is None:
logger.info("add_images: 创建默认元数据")
metadatas = [{} for _ in range(len(images))]
# 检查长度一致性
if len(images) != len(metadatas):
logger.error(f"add_images: 长度不一致 - images: {len(images)}, metadatas: {len(metadatas)}")
raise ValueError("images和metadatas长度必须相同")
# 编码图像
logger.info("add_images: 编码图像")
image_embeddings = self.encode_image(images)
# 检查编码结果
if image_embeddings.shape[0] == 0:
logger.error("add_images: 图像编码失败,返回空数组")
return []
# 准备元数据
start_id = self.index.ntotal
ids = list(range(start_id, start_id + len(images)))
logger.info(f"add_images: 生成索引ID: {start_id} - {start_id + len(images) - 1}")
# 添加到索引
logger.info(f"add_images: 添加向量到FAISS索引形状: {image_embeddings.shape}")
self.index.add(np.array(image_embeddings).astype('float32'))
# 保存元数据
for i, id in enumerate(ids):
try:
metadata = {
"type": "image",
"width": images[i].width,
"height": images[i].height,
**metadatas[i]
}
if image_paths and i < len(image_paths):
metadata["path"] = image_paths[i]
self.metadata[str(id)] = metadata
logger.debug(f"add_images: 保存元数据成功 - ID: {id}")
except Exception as e:
logger.error(f"add_images: 保存元数据失败 - ID: {id}, 错误: {str(e)}")
logger.info(f"add_images: 成功添加{len(ids)}张图像到检索系统")
return [str(id) for id in ids]
except Exception as e:
logger.error(f"add_images: 添加图像异常: {str(e)}")
return []
def search_by_text(
self,
query: str,
k: int = 5,
filter_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
文本搜索
Args:
query: 查询文本
k: 返回结果数量
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
Returns:
搜索结果列表,每个元素包含相似项和分数
"""
# 编码查询文本
query_embedding = self.encode_text(query)
# 执行搜索
return self._search(query_embedding, k, filter_type)
def search_by_image(
self,
image: Image.Image,
k: int = 5,
filter_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
图像搜索
Args:
image: 查询图像
k: 返回结果数量
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
Returns:
搜索结果列表,每个元素包含相似项和分数
"""
# 编码查询图像
query_embedding = self.encode_image(image)
# 执行搜索
return self._search(query_embedding, k, filter_type)
def _search(
self,
query_embedding: np.ndarray,
k: int = 5,
filter_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
执行搜索
Args:
query_embedding: 查询向量
k: 返回结果数量
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
Returns:
搜索结果列表
"""
if self.index.ntotal == 0:
return []
# 确保查询向量是2D数组
if len(query_embedding.shape) == 1:
query_embedding = query_embedding.reshape(1, -1)
# 执行搜索,获取更多结果以便过滤
actual_k = k * 3 if filter_type else k
actual_k = min(actual_k, self.index.ntotal)
distances, indices = self.index.search(query_embedding.astype('float32'), actual_k)
# 处理结果
results = []
for i in range(len(indices[0])):
idx = indices[0][i]
if idx < 0: # FAISS可能返回-1表示无效索引
continue
vector_id = str(idx)
if vector_id in self.metadata:
item = self.metadata[vector_id]
# 如果指定了过滤类型,则只返回该类型的结果
if filter_type and item.get("type") != filter_type:
continue
# 添加距离和分数
result = item.copy()
result["distance"] = float(distances[0][i])
result["score"] = float(1.0 / (1.0 + distances[0][i]))
results.append(result)
# 如果已经收集了足够的结果,则停止
if len(results) >= k:
break
return results
def save_index(self):
"""保存索引和元数据"""
# 保存索引
index_file = f"{self.index_path}.index"
try:
faiss.write_index(self.index, index_file)
logger.info(f"索引保存成功: {index_file}")
except Exception as e:
logger.error(f"索引保存失败: {str(e)}")
# 保存元数据
metadata_file = f"{self.index_path}_metadata.json"
try:
with open(metadata_file, 'w', encoding='utf-8') as f:
json.dump(self.metadata, f, ensure_ascii=False, indent=2)
logger.info(f"元数据保存成功: {metadata_file}")
except Exception as e:
logger.error(f"元数据保存失败: {str(e)}")
def get_stats(self) -> Dict[str, Any]:
"""获取检索系统统计信息"""
text_count = sum(1 for v in self.metadata.values() if v.get("type") == "text")
image_count = sum(1 for v in self.metadata.values() if v.get("type") == "image")
return {
"total_vectors": self.index.ntotal,
"text_count": text_count,
"image_count": image_count,
"vector_dimension": self.vector_dim,
"index_path": self.index_path,
"model_path": self.model_path
}
def clear_index(self):
"""清空索引"""
logger.info(f"清空索引: {self.index_path}")
# 重新创建索引
self.index = faiss.IndexFlatL2(self.vector_dim)
# 清空元数据
self.metadata = {}
# 保存空索引
self.save_index()
logger.info(f"索引已清空: {self.index_path}")
return True
def list_items(self) -> List[Dict[str, Any]]:
"""列出所有索引项"""
items = []
for item_id, metadata in self.metadata.items():
item = metadata.copy()
item['id'] = item_id
items.append(item)
return items
def __del__(self):
"""析构函数,确保资源被正确释放并自动保存索引"""
try:
if hasattr(self, 'model'):
del self.model
self._clear_all_gpu_memory()
if hasattr(self, 'index') and self.index is not None:
logger.info("系统关闭前自动保存索引")
self.save_index()
except Exception as e:
logger.error(f"析构时保存索引失败: {str(e)}")