608 lines
24 KiB
Python
608 lines
24 KiB
Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Multimodal retrieval system using a locally stored model.

Supports four retrieval modes: text-to-text, text-to-image,
image-to-text, and image-to-image search.
"""
|
||
|
||
import torch
|
||
import numpy as np
|
||
from PIL import Image
|
||
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
||
from typing import List, Union, Tuple, Dict, Any, Optional
|
||
import os
|
||
import json
|
||
from pathlib import Path
|
||
import logging
|
||
import gc
|
||
import faiss
|
||
import time
|
||
|
||
# Configure module-level logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Force transformers into offline mode so no hub downloads are attempted
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
||
class MultimodalRetrievalLocal:
    """Multimodal retrieval system backed by a locally stored embedding model.

    Wraps a local HuggingFace-style model plus a FAISS index to support
    indexing of texts and images and cross-modal similarity search.
    """

    def __init__(self,
                 model_path: str = "/root/models/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True,
                 gpu_ids: Optional[List[int]] = None,
                 min_memory_gb: int = 12,
                 index_path: str = "local_faiss_index"):
        """
        Initialize the multimodal retrieval system.

        Args:
            model_path: Path to the local model directory.
            use_all_gpus: Whether to use every qualifying GPU.
            gpu_ids: Explicit list of GPU ids to use (overrides use_all_gpus).
            min_memory_gb: Minimum total GPU memory (GB) for a GPU to qualify.
            index_path: Base path (no extension) for the FAISS index files.

        Raises:
            FileNotFoundError: If model_path does not exist.
            RuntimeError: If the model fails to load.
        """
        self.model_path = model_path
        self.index_path = index_path

        # Fail fast when the model directory is missing
        if not os.path.exists(model_path):
            logger.error(f"模型路径不存在: {model_path}")
            logger.info("请先下载模型到指定路径")
            raise FileNotFoundError(f"模型路径不存在: {model_path}")

        # Select GPU/CPU devices
        self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)

        # Free any cached GPU memory before loading the model
        self._clear_all_gpu_memory()

        # Load the model, tokenizer and processor from disk
        self._load_model_and_processor()

        # Load or create the FAISS index and its metadata
        self._init_index()

        logger.info(f"多模态检索系统初始化完成,使用本地模型: {model_path}")
        logger.info(f"向量存储路径: {index_path}")
|
||
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int):
|
||
"""设置GPU设备"""
|
||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
self.use_gpu = self.device.type == "cuda"
|
||
|
||
if self.use_gpu:
|
||
self.available_gpus = self._get_available_gpus(min_memory_gb)
|
||
|
||
if not self.available_gpus:
|
||
logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB,将使用CPU")
|
||
self.device = torch.device("cpu")
|
||
self.use_gpu = False
|
||
else:
|
||
if gpu_ids:
|
||
self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus]
|
||
if not self.gpu_ids:
|
||
logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足,将使用可用的GPU: {self.available_gpus}")
|
||
self.gpu_ids = self.available_gpus
|
||
elif use_all_gpus:
|
||
self.gpu_ids = self.available_gpus
|
||
else:
|
||
self.gpu_ids = [self.available_gpus[0]]
|
||
|
||
logger.info(f"使用GPU: {self.gpu_ids}")
|
||
self.device = torch.device(f"cuda:{self.gpu_ids[0]}")
|
||
else:
|
||
logger.warning("没有可用的GPU,将使用CPU")
|
||
self.gpu_ids = []
|
||
|
||
def _get_available_gpus(self, min_memory_gb: int) -> List[int]:
|
||
"""获取可用的GPU列表"""
|
||
available_gpus = []
|
||
for i in range(torch.cuda.device_count()):
|
||
total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB
|
||
if total_mem >= min_memory_gb:
|
||
available_gpus.append(i)
|
||
return available_gpus
|
||
|
||
def _clear_all_gpu_memory(self):
|
||
"""清理GPU内存"""
|
||
if torch.cuda.is_available():
|
||
torch.cuda.empty_cache()
|
||
gc.collect()
|
||
|
||
    def _load_model_and_processor(self):
        """Load the tokenizer, processor and model from self.model_path.

        Sets self.tokenizer, self.processor, self.model and self.vector_dim.

        Raises:
            RuntimeError: If anything fails while loading.
        """
        logger.info(f"加载本地模型和处理器: {self.model_path}")

        try:
            # Load tokenizer and processor from the local directory
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            self.processor = AutoProcessor.from_pretrained(self.model_path)

            # Verbose processor introspection (useful when debugging new models)
            logger.info(f"Processor类型: {type(self.processor)}")
            logger.info(f"Processor方法: {dir(self.processor)}")

            # Report the nested image processor when present
            if hasattr(self.processor, 'image_processor'):
                logger.info(f"Image processor类型: {type(self.processor.image_processor)}")
                logger.info(f"Image processor方法: {dir(self.processor.image_processor)}")

            # Load the model; shard across GPUs only when more than one is used
            self.model = AutoModel.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16 if self.use_gpu else torch.float32,
                device_map="auto" if len(self.gpu_ids) > 1 else None
            )

            # Single-GPU case: move the whole model to that device explicitly
            if len(self.gpu_ids) == 1:
                self.model.to(self.device)

            self.model.eval()

            # Embedding dimensionality comes from the model config
            self.vector_dim = self.model.config.hidden_size
            logger.info(f"向量维度: {self.vector_dim}")

            logger.info("模型和处理器加载成功")

        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise RuntimeError(f"模型加载失败: {str(e)}")
|
||
    def _init_index(self):
        """Load the FAISS index and metadata from disk, or create fresh ones.

        Reads "<index_path>.index" and "<index_path>_metadata.json" when they
        exist; otherwise starts with an empty IndexFlatL2 of self.vector_dim.
        """
        index_file = f"{self.index_path}.index"
        if os.path.exists(index_file):
            logger.info(f"加载现有索引: {index_file}")
            try:
                self.index = faiss.read_index(index_file)
                logger.info(f"索引加载成功,包含{self.index.ntotal}个向量")
            except Exception as e:
                # Fall back to an empty index when the file is unreadable
                logger.error(f"索引加载失败: {str(e)}")
                logger.info("创建新索引...")
                self.index = faiss.IndexFlatL2(self.vector_dim)
        else:
            logger.info(f"创建新索引,维度: {self.vector_dim}")
            self.index = faiss.IndexFlatL2(self.vector_dim)

        # Load persisted metadata (vector id -> record) when available
        self.metadata = {}
        metadata_file = f"{self.index_path}_metadata.json"
        if os.path.exists(metadata_file):
            try:
                with open(metadata_file, 'r', encoding='utf-8') as f:
                    self.metadata = json.load(f)
                logger.info(f"元数据加载成功,包含{len(self.metadata)}条记录")
            except Exception as e:
                # Best-effort: a broken metadata file leaves self.metadata empty
                logger.error(f"元数据加载失败: {str(e)}")
||
|
||
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
||
"""编码文本为向量"""
|
||
if isinstance(text, str):
|
||
text = [text]
|
||
|
||
inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
|
||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||
|
||
with torch.no_grad():
|
||
outputs = self.model(**inputs)
|
||
# 获取[CLS]标记的隐藏状态作为句子表示
|
||
text_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||
|
||
# 归一化向量
|
||
text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
|
||
return text_embeddings[0] if len(text) == 1 else text_embeddings
|
||
|
||
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
|
||
"""编码图像为向量"""
|
||
try:
|
||
logger.info(f"encode_image: 开始编码图像,类型: {type(image)}")
|
||
|
||
if isinstance(image, Image.Image):
|
||
logger.info(f"encode_image: 单个图像,大小: {image.size}")
|
||
image = [image]
|
||
else:
|
||
logger.info(f"encode_image: 图像列表,长度: {len(image)}")
|
||
|
||
# 检查图像是否为空
|
||
if not image or len(image) == 0:
|
||
logger.error("encode_image: 图像列表为空")
|
||
# 返回一个空的二维数组
|
||
return np.zeros((0, self.vector_dim))
|
||
|
||
# 检查图像是否有效
|
||
for i, img in enumerate(image):
|
||
if not isinstance(img, Image.Image):
|
||
logger.error(f"encode_image: 第{i}个元素不是有效的PIL图像,类型: {type(img)}")
|
||
# 返回一个空的二维数组
|
||
return np.zeros((0, self.vector_dim))
|
||
|
||
logger.info("encode_image: 处理图像输入")
|
||
|
||
# 检查图像格式
|
||
for i, img in enumerate(image):
|
||
logger.info(f"encode_image: 图像 {i} 格式: {img.format}, 模式: {img.mode}, 大小: {img.size}")
|
||
# 转换为RGB模式,如果不是
|
||
if img.mode != 'RGB':
|
||
logger.info(f"encode_image: 将图像 {i} 从 {img.mode} 转换为 RGB")
|
||
image[i] = img.convert('RGB')
|
||
|
||
try:
|
||
# 直接使用image_processor处理图像
|
||
if hasattr(self.processor, 'image_processor'):
|
||
logger.info("encode_image: 使用image_processor处理图像")
|
||
pixel_values = self.processor.image_processor(images=image, return_tensors="pt").pixel_values
|
||
inputs = {"pixel_values": pixel_values}
|
||
else:
|
||
logger.info("encode_image: 使用processor处理图像")
|
||
inputs = self.processor(images=image, return_tensors="pt")
|
||
|
||
if not inputs or len(inputs) == 0:
|
||
logger.error("encode_image: processor返回了空的输入")
|
||
return np.zeros((0, self.vector_dim))
|
||
|
||
logger.info(f"encode_image: 处理后的输入键: {list(inputs.keys())}")
|
||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||
|
||
logger.info("encode_image: 运行模型推理")
|
||
logger.info(f"Model类型: {type(self.model)}")
|
||
logger.info(f"Model属性: {dir(self.model)}")
|
||
|
||
# 检查模型结构
|
||
try:
|
||
logger.info(f"Model配置: {self.model.config}")
|
||
logger.info(f"Model配置属性: {dir(self.model.config)}")
|
||
else:
|
||
visual_outputs = self.model.visual(**inputs)
|
||
|
||
if hasattr(visual_outputs, 'pooler_output'):
|
||
image_embeddings = visual_outputs.pooler_output.cpu().numpy()
|
||
elif hasattr(visual_outputs, 'last_hidden_state'):
|
||
image_embeddings = visual_outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||
else:
|
||
logger.error("encode_image: 无法从视觉模型输出中获取图像向量")
|
||
raise ValueError("无法从视觉模型输出中获取图像向量")
|
||
else:
|
||
# 尝试直接使用模型进行推理
|
||
logger.info("encode_image: 尝试直接使用模型进行推理")
|
||
with torch.no_grad():
|
||
# 使用空文本输入,只提供图像
|
||
if 'pixel_values' in inputs:
|
||
outputs = self.model(pixel_values=inputs['pixel_values'], input_ids=None)
|
||
else:
|
||
outputs = self.model(**inputs, input_ids=None)
|
||
|
||
# 尝试从输出中获取图像向量
|
||
if hasattr(outputs, 'image_embeds'):
|
||
image_embeddings = outputs.image_embeds.cpu().numpy()
|
||
elif hasattr(outputs, 'vision_model_output') and hasattr(outputs.vision_model_output, 'pooler_output'):
|
||
image_embeddings = outputs.vision_model_output.pooler_output.cpu().numpy()
|
||
elif hasattr(outputs, 'pooler_output'):
|
||
image_embeddings = outputs.pooler_output.cpu().numpy()
|
||
elif hasattr(outputs, 'last_hidden_state'):
|
||
image_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||
else:
|
||
logger.error("encode_image: 无法从模型输出中获取图像向量")
|
||
raise ValueError("无法从模型输出中获取图像向量")
|
||
except Exception as e:
|
||
logger.error(f"encode_image: 处理图像时出错: {str(e)}")
|
||
raise e
|
||
return np.zeros((0, self.vector_dim))
|
||
|
||
# 归一化向量
|
||
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
|
||
|
||
# 始终返回二维数组,即使只有一个图像
|
||
if len(image) == 1:
|
||
result = np.array([image_embeddings[0]])
|
||
logger.info(f"encode_image: 返回单个图像向量,形状: {result.shape}")
|
||
return result
|
||
else:
|
||
logger.info(f"encode_image: 返回多个图像向量,形状: {image_embeddings.shape}")
|
||
return image_embeddings
|
||
|
||
except Exception as e:
|
||
logger.error(f"encode_image: 异常: {str(e)}")
|
||
# 返回一个空的二维数组
|
||
return np.zeros((0, self.vector_dim))
|
||
|
||
def add_texts(
|
||
self,
|
||
texts: List[str],
|
||
metadatas: Optional[List[Dict[str, Any]]] = None
|
||
) -> List[str]:
|
||
"""
|
||
添加文本到检索系统
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
metadatas: 元数据列表,每个元素是一个字典
|
||
|
||
Returns:
|
||
添加的文本ID列表
|
||
"""
|
||
if not texts:
|
||
return []
|
||
|
||
if metadatas is None:
|
||
metadatas = [{} for _ in range(len(texts))]
|
||
|
||
if len(texts) != len(metadatas):
|
||
raise ValueError("texts和metadatas长度必须相同")
|
||
|
||
# 编码文本
|
||
text_embeddings = self.encode_text(texts)
|
||
|
||
# 准备元数据
|
||
start_id = self.index.ntotal
|
||
ids = list(range(start_id, start_id + len(texts)))
|
||
|
||
# 添加到索引
|
||
self.index.add(np.array(text_embeddings).astype('float32'))
|
||
|
||
# 保存元数据
|
||
for i, id in enumerate(ids):
|
||
self.metadata[str(id)] = {
|
||
"text": texts[i],
|
||
"type": "text",
|
||
**metadatas[i]
|
||
}
|
||
|
||
logger.info(f"成功添加{len(ids)}条文本到检索系统")
|
||
return [str(id) for id in ids]
|
||
|
||
def add_images(
|
||
self,
|
||
images: List[Image.Image],
|
||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||
image_paths: Optional[List[str]] = None
|
||
) -> List[str]:
|
||
"""
|
||
添加图像到检索系统
|
||
|
||
Args:
|
||
images: PIL图像列表
|
||
metadatas: 元数据列表,每个元素是一个字典
|
||
image_paths: 图像路径列表,用于保存到元数据
|
||
|
||
Returns:
|
||
添加的图像ID列表
|
||
"""
|
||
try:
|
||
logger.info(f"add_images: 开始添加图像,数量: {len(images) if images else 0}")
|
||
|
||
# 检查图像列表
|
||
if not images or len(images) == 0:
|
||
logger.warning("add_images: 图像列表为空")
|
||
return []
|
||
|
||
# 准备元数据
|
||
if metadatas is None:
|
||
logger.info("add_images: 创建默认元数据")
|
||
metadatas = [{} for _ in range(len(images))]
|
||
|
||
# 检查长度一致性
|
||
if len(images) != len(metadatas):
|
||
logger.error(f"add_images: 长度不一致 - images: {len(images)}, metadatas: {len(metadatas)}")
|
||
raise ValueError("images和metadatas长度必须相同")
|
||
|
||
# 编码图像
|
||
logger.info("add_images: 编码图像")
|
||
image_embeddings = self.encode_image(images)
|
||
|
||
# 检查编码结果
|
||
if image_embeddings.shape[0] == 0:
|
||
logger.error("add_images: 图像编码失败,返回空数组")
|
||
return []
|
||
|
||
# 准备元数据
|
||
start_id = self.index.ntotal
|
||
ids = list(range(start_id, start_id + len(images)))
|
||
logger.info(f"add_images: 生成索引ID: {start_id} - {start_id + len(images) - 1}")
|
||
|
||
# 添加到索引
|
||
logger.info(f"add_images: 添加向量到FAISS索引,形状: {image_embeddings.shape}")
|
||
self.index.add(np.array(image_embeddings).astype('float32'))
|
||
|
||
# 保存元数据
|
||
for i, id in enumerate(ids):
|
||
try:
|
||
metadata = {
|
||
"type": "image",
|
||
"width": images[i].width,
|
||
"height": images[i].height,
|
||
**metadatas[i]
|
||
}
|
||
|
||
if image_paths and i < len(image_paths):
|
||
metadata["path"] = image_paths[i]
|
||
|
||
self.metadata[str(id)] = metadata
|
||
logger.debug(f"add_images: 保存元数据成功 - ID: {id}")
|
||
except Exception as e:
|
||
logger.error(f"add_images: 保存元数据失败 - ID: {id}, 错误: {str(e)}")
|
||
|
||
logger.info(f"add_images: 成功添加{len(ids)}张图像到检索系统")
|
||
return [str(id) for id in ids]
|
||
|
||
except Exception as e:
|
||
logger.error(f"add_images: 添加图像异常: {str(e)}")
|
||
return []
|
||
|
||
def search_by_text(
|
||
self,
|
||
query: str,
|
||
k: int = 5,
|
||
filter_type: Optional[str] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
文本搜索
|
||
|
||
Args:
|
||
query: 查询文本
|
||
k: 返回结果数量
|
||
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
|
||
|
||
Returns:
|
||
搜索结果列表,每个元素包含相似项和分数
|
||
"""
|
||
# 编码查询文本
|
||
query_embedding = self.encode_text(query)
|
||
|
||
# 执行搜索
|
||
return self._search(query_embedding, k, filter_type)
|
||
|
||
def search_by_image(
|
||
self,
|
||
image: Image.Image,
|
||
k: int = 5,
|
||
filter_type: Optional[str] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
图像搜索
|
||
|
||
Args:
|
||
image: 查询图像
|
||
k: 返回结果数量
|
||
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
|
||
|
||
Returns:
|
||
搜索结果列表,每个元素包含相似项和分数
|
||
"""
|
||
# 编码查询图像
|
||
query_embedding = self.encode_image(image)
|
||
|
||
# 执行搜索
|
||
return self._search(query_embedding, k, filter_type)
|
||
|
||
def _search(
|
||
self,
|
||
query_embedding: np.ndarray,
|
||
k: int = 5,
|
||
filter_type: Optional[str] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
执行搜索
|
||
|
||
Args:
|
||
query_embedding: 查询向量
|
||
k: 返回结果数量
|
||
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
|
||
|
||
Returns:
|
||
搜索结果列表
|
||
"""
|
||
if self.index.ntotal == 0:
|
||
return []
|
||
|
||
# 确保查询向量是2D数组
|
||
if len(query_embedding.shape) == 1:
|
||
query_embedding = query_embedding.reshape(1, -1)
|
||
|
||
# 执行搜索,获取更多结果以便过滤
|
||
actual_k = k * 3 if filter_type else k
|
||
actual_k = min(actual_k, self.index.ntotal)
|
||
distances, indices = self.index.search(query_embedding.astype('float32'), actual_k)
|
||
|
||
# 处理结果
|
||
results = []
|
||
for i in range(len(indices[0])):
|
||
idx = indices[0][i]
|
||
if idx < 0: # FAISS可能返回-1表示无效索引
|
||
continue
|
||
|
||
vector_id = str(idx)
|
||
if vector_id in self.metadata:
|
||
item = self.metadata[vector_id]
|
||
|
||
# 如果指定了过滤类型,则只返回该类型的结果
|
||
if filter_type and item.get("type") != filter_type:
|
||
continue
|
||
|
||
# 添加距离和分数
|
||
result = item.copy()
|
||
result["distance"] = float(distances[0][i])
|
||
result["score"] = float(1.0 / (1.0 + distances[0][i]))
|
||
results.append(result)
|
||
|
||
# 如果已经收集了足够的结果,则停止
|
||
if len(results) >= k:
|
||
break
|
||
|
||
return results
|
||
|
||
    def save_index(self):
        """Persist the FAISS index and the metadata JSON to disk.

        Failures are logged but never raised (best-effort persistence).
        """
        # Write the FAISS index
        index_file = f"{self.index_path}.index"
        try:
            faiss.write_index(self.index, index_file)
            logger.info(f"索引保存成功: {index_file}")
        except Exception as e:
            logger.error(f"索引保存失败: {str(e)}")

        # Write the metadata JSON next to the index file
        metadata_file = f"{self.index_path}_metadata.json"
        try:
            with open(metadata_file, 'w', encoding='utf-8') as f:
                json.dump(self.metadata, f, ensure_ascii=False, indent=2)
            logger.info(f"元数据保存成功: {metadata_file}")
        except Exception as e:
            logger.error(f"元数据保存失败: {str(e)}")
|
||
|
||
def get_stats(self) -> Dict[str, Any]:
|
||
"""获取检索系统统计信息"""
|
||
text_count = sum(1 for v in self.metadata.values() if v.get("type") == "text")
|
||
image_count = sum(1 for v in self.metadata.values() if v.get("type") == "image")
|
||
|
||
return {
|
||
"total_vectors": self.index.ntotal,
|
||
"text_count": text_count,
|
||
"image_count": image_count,
|
||
"vector_dimension": self.vector_dim,
|
||
"index_path": self.index_path,
|
||
"model_path": self.model_path
|
||
}
|
||
|
||
    def clear_index(self):
        """Reset the index and metadata to empty and persist the empty state.

        Returns:
            True (always; save failures are only logged by save_index).
        """
        logger.info(f"清空索引: {self.index_path}")

        # Replace the FAISS index with a fresh empty one of the same dimension
        self.index = faiss.IndexFlatL2(self.vector_dim)

        # Drop all stored metadata
        self.metadata = {}

        # Persist the now-empty index and metadata
        self.save_index()

        logger.info(f"索引已清空: {self.index_path}")
        return True
|
||
|
||
def list_items(self) -> List[Dict[str, Any]]:
|
||
"""列出所有索引项"""
|
||
items = []
|
||
|
||
for item_id, metadata in self.metadata.items():
|
||
item = metadata.copy()
|
||
item['id'] = item_id
|
||
items.append(item)
|
||
|
||
return items
|
||
|
||
    def __del__(self):
        """Destructor: release the model, free GPU memory and save the index.

        All failures are swallowed (logged only) because __del__ may run
        during interpreter shutdown when globals are partially torn down.
        """
        try:
            if hasattr(self, 'model'):
                # Drop the model reference before clearing CUDA caches
                del self.model
                self._clear_all_gpu_memory()
            if hasattr(self, 'index') and self.index is not None:
                logger.info("系统关闭前自动保存索引")
                self.save_index()
        except Exception as e:
            logger.error(f"析构时保存索引失败: {str(e)}")