#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
使用本地模型的多模态检索系统
|
||
支持文搜文、文搜图、图搜文、图搜图四种检索模式
|
||
"""
|
||
|
||
import torch
|
||
import numpy as np
|
||
from PIL import Image
|
||
from ops_mm_embedding_v1 import OpsMMEmbeddingV1
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from typing import List, Union, Tuple, Dict, Any, Optional
|
||
import os
|
||
import json
|
||
from pathlib import Path
|
||
import logging
|
||
import gc
|
||
import faiss
|
||
import time
|
||
|
||
# 设置日志
|
||
logging.basicConfig(level=logging.INFO)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 设置离线模式
|
||
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
||
|
||
class MultimodalRetrievalLocal:
    """Multimodal retrieval system backed by a locally stored embedding model.

    Supports four retrieval modes: text-to-text, text-to-image,
    image-to-text and image-to-image.
    """

    def __init__(self,
                 model_path: str = "/root/models/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True,
                 gpu_ids: Optional[List[int]] = None,
                 min_memory_gb: int = 12,
                 index_path: str = "local_faiss_index",
                 shard_model_across_gpus: bool = False,
                 load_in_4bit: bool = False,
                 load_in_8bit: bool = False,
                 torch_dtype: Optional[torch.dtype] = torch.bfloat16):
        """
        Initialize the multimodal retrieval system.

        Args:
            model_path: Local filesystem path of the embedding model.
            use_all_gpus: Use every qualifying GPU when gpu_ids is not given.
            gpu_ids: Explicit list of GPU ids to use (overrides use_all_gpus).
            min_memory_gb: Minimum total GPU memory (GB) for a GPU to qualify.
            index_path: Base path (without extension) for the FAISS index
                and its metadata JSON file.
            shard_model_across_gpus: Shard a single model across all GPUs
                instead of loading one replica per GPU.
            load_in_4bit: Passed through to OpsMMEmbeddingV1 (presumably
                4-bit quantized loading — confirm against that class).
            load_in_8bit: Passed through to OpsMMEmbeddingV1 (presumably
                8-bit quantized loading — confirm against that class).
            torch_dtype: Torch dtype used when loading the model weights.

        Raises:
            FileNotFoundError: If model_path does not exist.
            RuntimeError: If the embedding model fails to load.
        """
        self.model_path = model_path
        self.index_path = index_path
        self.shard_model_across_gpus = shard_model_across_gpus
        self.load_in_4bit = load_in_4bit
        self.load_in_8bit = load_in_8bit
        self.torch_dtype = torch_dtype

        # Fail fast if the model directory is missing.
        if not os.path.exists(model_path):
            logger.error(f"模型路径不存在: {model_path}")
            logger.info("请先下载模型到指定路径")
            raise FileNotFoundError(f"模型路径不存在: {model_path}")

        # Pick the compute device(s); sets self.device / self.use_gpu / self.gpu_ids.
        self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)

        # Free any cached GPU memory before loading the model.
        self._clear_all_gpu_memory()

        # Load the embedding model; sets self.models and self.vector_dim.
        self._load_embedding_model()

        # Load or create the FAISS index; sets self.index and self.metadata.
        self._init_index()

        logger.info(f"多模态检索系统初始化完成,使用本地模型: {model_path}")
        logger.info(f"向量存储路径: {index_path}")
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int):
|
||
"""设置GPU设备"""
|
||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
self.use_gpu = self.device.type == "cuda"
|
||
|
||
if self.use_gpu:
|
||
self.available_gpus = self._get_available_gpus(min_memory_gb)
|
||
|
||
if not self.available_gpus:
|
||
logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB,将使用CPU")
|
||
self.device = torch.device("cpu")
|
||
self.use_gpu = False
|
||
else:
|
||
if gpu_ids:
|
||
self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus]
|
||
if not self.gpu_ids:
|
||
logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足,将使用可用的GPU: {self.available_gpus}")
|
||
self.gpu_ids = self.available_gpus
|
||
elif use_all_gpus:
|
||
self.gpu_ids = self.available_gpus
|
||
else:
|
||
self.gpu_ids = [self.available_gpus[0]]
|
||
|
||
logger.info(f"使用GPU: {self.gpu_ids}")
|
||
self.device = torch.device(f"cuda:{self.gpu_ids[0]}")
|
||
else:
|
||
logger.warning("没有可用的GPU,将使用CPU")
|
||
self.gpu_ids = []
|
||
|
||
def _get_available_gpus(self, min_memory_gb: int) -> List[int]:
|
||
"""获取可用的GPU列表"""
|
||
available_gpus = []
|
||
for i in range(torch.cuda.device_count()):
|
||
total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB
|
||
if total_mem >= min_memory_gb:
|
||
available_gpus.append(i)
|
||
return available_gpus
|
||
|
||
def _clear_all_gpu_memory(self):
|
||
"""清理GPU内存"""
|
||
if torch.cuda.is_available():
|
||
torch.cuda.empty_cache()
|
||
gc.collect()
|
||
|
||
    def _load_embedding_model(self):
        """Load the multimodal embedding model(s) (OpsMMEmbeddingV1).

        Populates self.models (one entry per data-parallel replica, or a
        single sharded model) and self.vector_dim (the model's hidden size).

        Raises:
            RuntimeError: If model loading fails for any reason.
        """
        logger.info(f"加载本地多模态嵌入模型: {self.model_path}")
        try:
            self.models: List[OpsMMEmbeddingV1] = []
            if self.use_gpu and len(self.gpu_ids) > 1 and self.shard_model_across_gpus:
                # Tensor-parallel sharding using device_map='auto' across visible GPUs
                logger.info(f"启用模型跨GPU切片(shard),使用设备: {self.gpu_ids}")
                # Rely on CUDA_VISIBLE_DEVICES to constrain which GPUs are visible, or use accelerate's auto mapping
                ref_model = OpsMMEmbeddingV1(
                    self.model_path,
                    device="cuda",
                    attn_implementation=None,
                    device_map="auto",
                    load_in_4bit=self.load_in_4bit,
                    load_in_8bit=self.load_in_8bit,
                    torch_dtype=self.torch_dtype,
                )
                # For sharded model, we keep a single logical model reference
                self.models.append(ref_model)
            else:
                if self.use_gpu and len(self.gpu_ids) > 1:
                    # Data parallelism: one full replica per GPU.
                    logger.info(f"检测到多GPU,可用: {self.gpu_ids},为每张卡加载一个模型副本(数据并行)")
                    for gid in self.gpu_ids:
                        device_str = f"cuda:{gid}"
                        self.models.append(
                            OpsMMEmbeddingV1(
                                self.model_path,
                                device=device_str,
                                attn_implementation=None,
                                load_in_4bit=self.load_in_4bit,
                                load_in_8bit=self.load_in_8bit,
                                torch_dtype=self.torch_dtype,
                            )
                        )
                    ref_model = self.models[0]
                else:
                    # Single device: one GPU, or CPU when none is usable.
                    device_str = "cuda" if self.use_gpu else "cpu"
                    logger.info(f"使用单设备: {device_str}")
                    ref_model = OpsMMEmbeddingV1(
                        self.model_path,
                        device=device_str,
                        attn_implementation=None,
                        load_in_4bit=self.load_in_4bit,
                        load_in_8bit=self.load_in_8bit,
                        torch_dtype=self.torch_dtype,
                    )
                    self.models.append(ref_model)

            # The embedding dimension comes from the underlying model config.
            self.vector_dim = int(getattr(ref_model.base_model.config, "hidden_size"))
            logger.info(f"向量维度: {self.vector_dim}")
            logger.info("嵌入模型加载成功")
        except Exception as e:
            logger.error(f"嵌入模型加载失败: {str(e)}")
            raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
def _init_index(self):
|
||
"""初始化FAISS索引"""
|
||
index_file = f"{self.index_path}.index"
|
||
if os.path.exists(index_file):
|
||
logger.info(f"加载现有索引: {index_file}")
|
||
try:
|
||
self.index = faiss.read_index(index_file)
|
||
logger.info(f"索引加载成功,包含{self.index.ntotal}个向量")
|
||
except Exception as e:
|
||
logger.error(f"索引加载失败: {str(e)}")
|
||
logger.info("创建新索引...")
|
||
self.index = faiss.IndexFlatL2(self.vector_dim)
|
||
else:
|
||
logger.info(f"创建新索引,维度: {self.vector_dim}")
|
||
self.index = faiss.IndexFlatL2(self.vector_dim)
|
||
|
||
# 加载元数据
|
||
self.metadata = {}
|
||
metadata_file = f"{self.index_path}_metadata.json"
|
||
if os.path.exists(metadata_file):
|
||
try:
|
||
with open(metadata_file, 'r', encoding='utf-8') as f:
|
||
self.metadata = json.load(f)
|
||
logger.info(f"元数据加载成功,包含{len(self.metadata)}条记录")
|
||
except Exception as e:
|
||
logger.error(f"元数据加载失败: {str(e)}")
|
||
|
||
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
||
"""编码文本为向量(使用 OpsMMEmbeddingV1,多GPU并行)"""
|
||
if isinstance(text, str):
|
||
text = [text]
|
||
if len(text) == 0:
|
||
return np.zeros((0, self.vector_dim))
|
||
|
||
# 单条或单模型直接处理
|
||
if len(self.models) == 1 or len(text) == 1:
|
||
with torch.inference_mode():
|
||
emb = self.models[0].get_text_embeddings(texts=text)
|
||
arr = emb.detach().float().cpu().numpy()
|
||
return arr[0] if len(text) == 1 else arr
|
||
|
||
# 多GPU并行:按设备平均切分
|
||
shard_count = len(self.models)
|
||
shards: List[List[str]] = []
|
||
for i in range(shard_count):
|
||
shards.append([])
|
||
for idx, t in enumerate(text):
|
||
shards[idx % shard_count].append(t)
|
||
|
||
results: List[np.ndarray] = [None] * shard_count # type: ignore
|
||
|
||
def run_shard(model: OpsMMEmbeddingV1, shard_texts: List[str], shard_idx: int):
|
||
if len(shard_texts) == 0:
|
||
results[shard_idx] = np.zeros((0, self.vector_dim))
|
||
return
|
||
with torch.inference_mode():
|
||
emb = model.get_text_embeddings(texts=shard_texts)
|
||
results[shard_idx] = emb.detach().float().cpu().numpy()
|
||
|
||
with ThreadPoolExecutor(max_workers=shard_count) as ex:
|
||
for i, (model, shard_texts) in enumerate(zip(self.models, shards)):
|
||
ex.submit(run_shard, model, shard_texts, i)
|
||
|
||
# 还原原始顺序(按 round-robin 拼接)
|
||
merged: List[np.ndarray] = []
|
||
max_len = max(len(s) for s in shards)
|
||
for k in range(max_len):
|
||
for i in range(shard_count):
|
||
if k < len(shards[i]) and results[i].shape[0] > 0:
|
||
merged.append(results[i][k:k+1])
|
||
if len(merged) == 0:
|
||
return np.zeros((0, self.vector_dim))
|
||
return np.vstack(merged)
|
||
|
||
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
|
||
"""编码图像为向量(使用 OpsMMEmbeddingV1,多GPU并行)"""
|
||
try:
|
||
# 规范为列表
|
||
if isinstance(image, Image.Image):
|
||
images: List[Image.Image] = [image]
|
||
else:
|
||
images = image
|
||
if not images:
|
||
logger.error("encode_image: 图像列表为空")
|
||
return np.zeros((0, self.vector_dim))
|
||
|
||
# 强制为 RGB
|
||
rgb_images = [img.convert('RGB') if img.mode != 'RGB' else img for img in images]
|
||
|
||
# 单张或单模型直接处理
|
||
if len(self.models) == 1 or len(rgb_images) == 1:
|
||
with torch.inference_mode():
|
||
emb = self.models[0].get_image_embeddings(images=rgb_images)
|
||
return emb.detach().float().cpu().numpy()
|
||
|
||
# 多GPU并行:按设备平均切分
|
||
shard_count = len(self.models)
|
||
shards: List[List[Image.Image]] = [[] for _ in range(shard_count)]
|
||
for idx, img in enumerate(rgb_images):
|
||
shards[idx % shard_count].append(img)
|
||
|
||
results: List[np.ndarray] = [None] * shard_count # type: ignore
|
||
|
||
def run_shard(model: OpsMMEmbeddingV1, shard_imgs: List[Image.Image], shard_idx: int):
|
||
if len(shard_imgs) == 0:
|
||
results[shard_idx] = np.zeros((0, self.vector_dim))
|
||
return
|
||
with torch.inference_mode():
|
||
emb = model.get_image_embeddings(images=shard_imgs)
|
||
results[shard_idx] = emb.detach().float().cpu().numpy()
|
||
|
||
with ThreadPoolExecutor(max_workers=shard_count) as ex:
|
||
for i, (model, shard_imgs) in enumerate(zip(self.models, shards)):
|
||
ex.submit(run_shard, model, shard_imgs, i)
|
||
|
||
# 还原原始顺序(按 round-robin 拼接)
|
||
merged: List[np.ndarray] = []
|
||
max_len = max(len(s) for s in shards)
|
||
for k in range(max_len):
|
||
for i in range(shard_count):
|
||
if k < len(shards[i]) and results[i].shape[0] > 0:
|
||
merged.append(results[i][k:k+1])
|
||
if len(merged) == 0:
|
||
return np.zeros((0, self.vector_dim))
|
||
return np.vstack(merged)
|
||
except Exception as e:
|
||
logger.error(f"encode_image: 异常: {str(e)}")
|
||
return np.zeros((0, self.vector_dim))
|
||
|
||
def add_texts(
|
||
self,
|
||
texts: List[str],
|
||
metadatas: Optional[List[Dict[str, Any]]] = None
|
||
) -> List[str]:
|
||
"""
|
||
添加文本到检索系统
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
metadatas: 元数据列表,每个元素是一个字典
|
||
|
||
Returns:
|
||
添加的文本ID列表
|
||
"""
|
||
if not texts:
|
||
return []
|
||
|
||
if metadatas is None:
|
||
metadatas = [{} for _ in range(len(texts))]
|
||
|
||
if len(texts) != len(metadatas):
|
||
raise ValueError("texts和metadatas长度必须相同")
|
||
|
||
# 编码文本
|
||
text_embeddings = self.encode_text(texts)
|
||
|
||
# 准备元数据
|
||
start_id = self.index.ntotal
|
||
ids = list(range(start_id, start_id + len(texts)))
|
||
|
||
# 添加到索引
|
||
self.index.add(np.array(text_embeddings).astype('float32'))
|
||
|
||
# 保存元数据
|
||
for i, id in enumerate(ids):
|
||
self.metadata[str(id)] = {
|
||
"text": texts[i],
|
||
"type": "text",
|
||
**metadatas[i]
|
||
}
|
||
|
||
logger.info(f"成功添加{len(ids)}条文本到检索系统")
|
||
return [str(id) for id in ids]
|
||
|
||
def add_images(
|
||
self,
|
||
images: List[Image.Image],
|
||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||
image_paths: Optional[List[str]] = None
|
||
) -> List[str]:
|
||
"""
|
||
添加图像到检索系统
|
||
|
||
Args:
|
||
images: PIL图像列表
|
||
metadatas: 元数据列表,每个元素是一个字典
|
||
image_paths: 图像路径列表,用于保存到元数据
|
||
|
||
Returns:
|
||
添加的图像ID列表
|
||
"""
|
||
try:
|
||
logger.info(f"add_images: 开始添加图像,数量: {len(images) if images else 0}")
|
||
|
||
# 检查图像列表
|
||
if not images or len(images) == 0:
|
||
logger.warning("add_images: 图像列表为空")
|
||
return []
|
||
|
||
# 准备元数据
|
||
if metadatas is None:
|
||
logger.info("add_images: 创建默认元数据")
|
||
metadatas = [{} for _ in range(len(images))]
|
||
|
||
# 检查长度一致性
|
||
if len(images) != len(metadatas):
|
||
logger.error(f"add_images: 长度不一致 - images: {len(images)}, metadatas: {len(metadatas)}")
|
||
raise ValueError("images和metadatas长度必须相同")
|
||
|
||
# 编码图像
|
||
logger.info("add_images: 编码图像")
|
||
image_embeddings = self.encode_image(images)
|
||
|
||
# 检查编码结果
|
||
if image_embeddings.shape[0] == 0:
|
||
logger.error("add_images: 图像编码失败,返回空数组")
|
||
return []
|
||
|
||
# 准备元数据
|
||
start_id = self.index.ntotal
|
||
ids = list(range(start_id, start_id + len(images)))
|
||
logger.info(f"add_images: 生成索引ID: {start_id} - {start_id + len(images) - 1}")
|
||
|
||
# 添加到索引
|
||
logger.info(f"add_images: 添加向量到FAISS索引,形状: {image_embeddings.shape}")
|
||
self.index.add(np.array(image_embeddings).astype('float32'))
|
||
|
||
# 保存元数据
|
||
for i, id in enumerate(ids):
|
||
try:
|
||
metadata = {
|
||
"type": "image",
|
||
"width": images[i].width,
|
||
"height": images[i].height,
|
||
**metadatas[i]
|
||
}
|
||
|
||
if image_paths and i < len(image_paths):
|
||
metadata["path"] = image_paths[i]
|
||
|
||
self.metadata[str(id)] = metadata
|
||
logger.debug(f"add_images: 保存元数据成功 - ID: {id}")
|
||
except Exception as e:
|
||
logger.error(f"add_images: 保存元数据失败 - ID: {id}, 错误: {str(e)}")
|
||
|
||
logger.info(f"add_images: 成功添加{len(ids)}张图像到检索系统")
|
||
return [str(id) for id in ids]
|
||
|
||
except Exception as e:
|
||
logger.error(f"add_images: 添加图像异常: {str(e)}")
|
||
return []
|
||
|
||
def search_by_text(
|
||
self,
|
||
query: str,
|
||
k: int = 5,
|
||
filter_type: Optional[str] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
文本搜索
|
||
|
||
Args:
|
||
query: 查询文本
|
||
k: 返回结果数量
|
||
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
|
||
|
||
Returns:
|
||
搜索结果列表,每个元素包含相似项和分数
|
||
"""
|
||
# 编码查询文本
|
||
query_embedding = self.encode_text(query)
|
||
|
||
# 执行搜索
|
||
return self._search(query_embedding, k, filter_type)
|
||
|
||
def search_by_image(
|
||
self,
|
||
image: Image.Image,
|
||
k: int = 5,
|
||
filter_type: Optional[str] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
图像搜索
|
||
|
||
Args:
|
||
image: 查询图像
|
||
k: 返回结果数量
|
||
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
|
||
|
||
Returns:
|
||
搜索结果列表,每个元素包含相似项和分数
|
||
"""
|
||
# 编码查询图像
|
||
query_embedding = self.encode_image(image)
|
||
|
||
# 执行搜索
|
||
return self._search(query_embedding, k, filter_type)
|
||
|
||
def _search(
|
||
self,
|
||
query_embedding: np.ndarray,
|
||
k: int = 5,
|
||
filter_type: Optional[str] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
执行搜索
|
||
|
||
Args:
|
||
query_embedding: 查询向量
|
||
k: 返回结果数量
|
||
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
|
||
|
||
Returns:
|
||
搜索结果列表
|
||
"""
|
||
if self.index.ntotal == 0:
|
||
return []
|
||
|
||
# 确保查询向量是2D数组
|
||
if len(query_embedding.shape) == 1:
|
||
query_embedding = query_embedding.reshape(1, -1)
|
||
|
||
# 执行搜索,获取更多结果以便过滤
|
||
actual_k = k * 3 if filter_type else k
|
||
actual_k = min(actual_k, self.index.ntotal)
|
||
distances, indices = self.index.search(query_embedding.astype('float32'), actual_k)
|
||
|
||
# 处理结果
|
||
results = []
|
||
for i in range(len(indices[0])):
|
||
idx = indices[0][i]
|
||
if idx < 0: # FAISS可能返回-1表示无效索引
|
||
continue
|
||
|
||
vector_id = str(idx)
|
||
if vector_id in self.metadata:
|
||
item = self.metadata[vector_id]
|
||
|
||
# 如果指定了过滤类型,则只返回该类型的结果
|
||
if filter_type and item.get("type") != filter_type:
|
||
continue
|
||
|
||
# 添加距离和分数
|
||
result = item.copy()
|
||
result["distance"] = float(distances[0][i])
|
||
result["score"] = float(1.0 / (1.0 + distances[0][i]))
|
||
results.append(result)
|
||
|
||
# 如果已经收集了足够的结果,则停止
|
||
if len(results) >= k:
|
||
break
|
||
|
||
return results
|
||
|
||
def save_index(self):
|
||
"""保存索引和元数据"""
|
||
# 保存索引
|
||
index_file = f"{self.index_path}.index"
|
||
try:
|
||
faiss.write_index(self.index, index_file)
|
||
logger.info(f"索引保存成功: {index_file}")
|
||
except Exception as e:
|
||
logger.error(f"索引保存失败: {str(e)}")
|
||
|
||
# 保存元数据
|
||
metadata_file = f"{self.index_path}_metadata.json"
|
||
try:
|
||
with open(metadata_file, 'w', encoding='utf-8') as f:
|
||
json.dump(self.metadata, f, ensure_ascii=False, indent=2)
|
||
logger.info(f"元数据保存成功: {metadata_file}")
|
||
except Exception as e:
|
||
logger.error(f"元数据保存失败: {str(e)}")
|
||
|
||
def get_stats(self) -> Dict[str, Any]:
|
||
"""获取检索系统统计信息"""
|
||
text_count = sum(1 for v in self.metadata.values() if v.get("type") == "text")
|
||
image_count = sum(1 for v in self.metadata.values() if v.get("type") == "image")
|
||
|
||
return {
|
||
"total_vectors": self.index.ntotal,
|
||
"text_count": text_count,
|
||
"image_count": image_count,
|
||
"vector_dimension": self.vector_dim,
|
||
"index_path": self.index_path,
|
||
"model_path": self.model_path
|
||
}
|
||
|
||
def clear_index(self):
|
||
"""清空索引"""
|
||
logger.info(f"清空索引: {self.index_path}")
|
||
|
||
# 重新创建索引
|
||
self.index = faiss.IndexFlatL2(self.vector_dim)
|
||
|
||
# 清空元数据
|
||
self.metadata = {}
|
||
|
||
# 保存空索引
|
||
self.save_index()
|
||
|
||
logger.info(f"索引已清空: {self.index_path}")
|
||
return True
|
||
|
||
def list_items(self) -> List[Dict[str, Any]]:
|
||
"""列出所有索引项"""
|
||
items = []
|
||
|
||
for item_id, metadata in self.metadata.items():
|
||
item = metadata.copy()
|
||
item['id'] = item_id
|
||
items.append(item)
|
||
|
||
return items
|
||
|
||
def __del__(self):
|
||
"""析构函数,确保资源被正确释放并自动保存索引"""
|
||
try:
|
||
if hasattr(self, 'model'):
|
||
del self.model
|
||
self._clear_all_gpu_memory()
|
||
if hasattr(self, 'index') and self.index is not None:
|
||
logger.info("系统关闭前自动保存索引")
|
||
self.save_index()
|
||
except Exception as e:
|
||
logger.error(f"析构时保存索引失败: {str(e)}")
|