#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
基于FAISS的多模态检索系统
支持文搜文、文搜图、图搜文、图搜图四种检索模式
"""

import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from typing import List, Union, Tuple, Dict, Any, Optional
import os
import json
from pathlib import Path
import logging
import gc
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

from faiss_vector_store import FaissVectorStore

# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MultimodalRetrievalFAISS:
    """FAISS-based multimodal retrieval system.

    Supports four retrieval modes: text-to-text, text-to-image,
    image-to-text and image-to-image.
    """

    def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True, gpu_ids: Optional[List[int]] = None,
                 min_memory_gb: int = 12, index_path: str = "faiss_index",
                 embedding_dim: int = 3584):
        """Initialize the multimodal retrieval system.

        Args:
            model_name: HuggingFace model name to load.
            use_all_gpus: Whether to use every eligible GPU.
            gpu_ids: Explicit list of GPU ids to use; overrides use_all_gpus.
            min_memory_gb: Minimum GPU memory (GB) for a GPU to be eligible.
            index_path: Path of the FAISS index files.
            embedding_dim: Dimensionality of stored vectors. Defaults to 3584,
                the embedding size of Ops-MM-embedding-v1-7B; pass a different
                value when a different model is used (previously hard-coded).
        """
        self.model_name = model_name
        self.index_path = index_path

        # Pick devices first so model loading knows where to place weights.
        self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)

        # Free any cached GPU memory before loading a large model.
        self._clear_all_gpu_memory()

        # Load tokenizer, processor and the embedding model.
        self._load_model_and_processor()

        # Initialize the FAISS vector store.
        self.vector_store = FaissVectorStore(
            index_path=index_path,
            dimension=embedding_dim,
        )

        logger.info(f"多模态检索系统初始化完成,使用模型: {model_name}")
        logger.info(f"向量存储路径: {index_path}")
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int):
|
||
"""设置GPU设备"""
|
||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
self.use_gpu = self.device.type == "cuda"
|
||
|
||
if self.use_gpu:
|
||
self.available_gpus = self._get_available_gpus(min_memory_gb)
|
||
|
||
if not self.available_gpus:
|
||
logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB,将使用CPU")
|
||
self.device = torch.device("cpu")
|
||
self.use_gpu = False
|
||
else:
|
||
if gpu_ids:
|
||
self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus]
|
||
if not self.gpu_ids:
|
||
logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足,将使用可用的GPU: {self.available_gpus}")
|
||
self.gpu_ids = self.available_gpus
|
||
elif use_all_gpus:
|
||
self.gpu_ids = self.available_gpus
|
||
else:
|
||
self.gpu_ids = [self.available_gpus[0]]
|
||
|
||
logger.info(f"使用GPU: {self.gpu_ids}")
|
||
self.device = torch.device(f"cuda:{self.gpu_ids[0]}")
|
||
|
||
def _get_available_gpus(self, min_memory_gb: int) -> List[int]:
|
||
"""获取可用的GPU列表"""
|
||
available_gpus = []
|
||
for i in range(torch.cuda.device_count()):
|
||
total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB
|
||
if total_mem >= min_memory_gb:
|
||
available_gpus.append(i)
|
||
return available_gpus
|
||
|
||
def _clear_all_gpu_memory(self):
|
||
"""清理GPU内存"""
|
||
if torch.cuda.is_available():
|
||
torch.cuda.empty_cache()
|
||
gc.collect()
|
||
|
||
def _load_model_and_processor(self):
|
||
"""加载模型和处理器"""
|
||
logger.info(f"加载模型和处理器: {self.model_name}")
|
||
|
||
# 加载tokenizer和processor
|
||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||
self.processor = AutoProcessor.from_pretrained(self.model_name)
|
||
|
||
# 加载模型
|
||
self.model = AutoModel.from_pretrained(
|
||
self.model_name,
|
||
torch_dtype=torch.float16 if self.use_gpu else torch.float32,
|
||
device_map="auto" if len(self.gpu_ids) > 1 else None
|
||
)
|
||
|
||
# 如果使用多GPU,包装模型
|
||
if len(self.gpu_ids) > 1:
|
||
self.model = DataParallel(self.model, device_ids=self.gpu_ids)
|
||
|
||
self.model.eval()
|
||
self.model.to(self.device)
|
||
|
||
logger.info("模型和处理器加载完成")
|
||
|
||
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
||
"""编码文本为向量"""
|
||
if isinstance(text, str):
|
||
text = [text]
|
||
|
||
inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
|
||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||
|
||
with torch.no_grad():
|
||
outputs = self.model(**inputs)
|
||
# 获取[CLS]标记的隐藏状态作为句子表示
|
||
text_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||
|
||
# 归一化向量
|
||
text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
|
||
return text_embeddings[0] if len(text) == 1 else text_embeddings
|
||
|
||
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
|
||
"""编码图像为向量"""
|
||
if isinstance(image, Image.Image):
|
||
image = [image]
|
||
|
||
inputs = self.processor(images=image, return_tensors="pt")
|
||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||
|
||
with torch.no_grad():
|
||
outputs = self.model.vision_model(**inputs)
|
||
# 获取[CLS]标记的隐藏状态作为图像表示
|
||
image_embeddings = outputs.pooler_output.cpu().numpy()
|
||
|
||
# 归一化向量
|
||
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
|
||
return image_embeddings[0] if len(image) == 1 else image_embeddings
|
||
|
||
def add_texts(
|
||
self,
|
||
texts: List[str],
|
||
metadatas: Optional[List[Dict[str, Any]]] = None
|
||
) -> List[str]:
|
||
"""
|
||
添加文本到检索系统
|
||
|
||
Args:
|
||
texts: 文本列表
|
||
metadatas: 元数据列表,每个元素是一个字典
|
||
|
||
Returns:
|
||
添加的文本ID列表
|
||
"""
|
||
if not texts:
|
||
return []
|
||
|
||
if metadatas is None:
|
||
metadatas = [{} for _ in range(len(texts))]
|
||
|
||
if len(texts) != len(metadatas):
|
||
raise ValueError("texts和metadatas长度必须相同")
|
||
|
||
# 编码文本
|
||
text_embeddings = self.encode_text(texts)
|
||
|
||
# 准备元数据
|
||
for i, text in enumerate(texts):
|
||
metadatas[i].update({
|
||
"text": text,
|
||
"type": "text"
|
||
})
|
||
|
||
# 添加到向量存储
|
||
vector_ids = self.vector_store.add_vectors(text_embeddings, metadatas)
|
||
|
||
logger.info(f"成功添加{len(vector_ids)}条文本到检索系统")
|
||
return vector_ids
|
||
|
||
def add_images(
|
||
self,
|
||
images: List[Image.Image],
|
||
metadatas: Optional[List[Dict[str, Any]]] = None
|
||
) -> List[str]:
|
||
"""
|
||
添加图像到检索系统
|
||
|
||
Args:
|
||
images: PIL图像列表
|
||
metadatas: 元数据列表,每个元素是一个字典
|
||
|
||
Returns:
|
||
添加的图像ID列表
|
||
"""
|
||
if not images:
|
||
return []
|
||
|
||
if metadatas is None:
|
||
metadatas = [{} for _ in range(len(images))]
|
||
|
||
if len(images) != len(metadatas):
|
||
raise ValueError("images和metadatas长度必须相同")
|
||
|
||
# 编码图像
|
||
image_embeddings = self.encode_image(images)
|
||
|
||
# 准备元数据
|
||
for i, image in enumerate(images):
|
||
metadatas[i].update({
|
||
"type": "image",
|
||
"width": image.width,
|
||
"height": image.height
|
||
})
|
||
|
||
# 添加到向量存储
|
||
vector_ids = self.vector_store.add_vectors(image_embeddings, metadatas)
|
||
|
||
logger.info(f"成功添加{len(vector_ids)}张图像到检索系统")
|
||
return vector_ids
|
||
|
||
def search_by_text(
|
||
self,
|
||
query: str,
|
||
k: int = 5,
|
||
filter_condition: Optional[Dict[str, Any]] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
文本搜索
|
||
|
||
Args:
|
||
query: 查询文本
|
||
k: 返回结果数量
|
||
filter_condition: 过滤条件
|
||
|
||
Returns:
|
||
搜索结果列表,每个元素包含相似项和分数
|
||
"""
|
||
# 编码查询文本
|
||
query_embedding = self.encode_text(query)
|
||
|
||
# 执行搜索
|
||
results, distances = self.vector_store.search(query_embedding, k)
|
||
|
||
# 处理结果
|
||
search_results = []
|
||
for i, (result, distance) in enumerate(zip(results, distances)):
|
||
result["score"] = 1.0 / (1.0 + distance) # 将距离转换为相似度分数
|
||
search_results.append(result)
|
||
|
||
return search_results
|
||
|
||
def search_by_image(
|
||
self,
|
||
image: Image.Image,
|
||
k: int = 5,
|
||
filter_condition: Optional[Dict[str, Any]] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
图像搜索
|
||
|
||
Args:
|
||
image: 查询图像
|
||
k: 返回结果数量
|
||
filter_condition: 过滤条件
|
||
|
||
Returns:
|
||
搜索结果列表,每个元素包含相似项和分数
|
||
"""
|
||
# 编码查询图像
|
||
query_embedding = self.encode_image(image)
|
||
|
||
# 执行搜索
|
||
results, distances = self.vector_store.search(query_embedding, k)
|
||
|
||
# 处理结果
|
||
search_results = []
|
||
for i, (result, distance) in enumerate(zip(results, distances)):
|
||
result["score"] = 1.0 / (1.0 + distance) # 将距离转换为相似度分数
|
||
search_results.append(result)
|
||
|
||
return search_results
|
||
|
||
def get_vector_count(self) -> int:
|
||
"""获取向量数量"""
|
||
return self.vector_store.get_vector_count()
|
||
|
||
def save_index(self):
|
||
"""保存索引"""
|
||
self.vector_store.save_index()
|
||
logger.info("索引已保存")
|
||
|
||
def __del__(self):
|
||
"""析构函数,确保资源被正确释放"""
|
||
if hasattr(self, 'model'):
|
||
del self.model
|
||
self._clear_all_gpu_memory()
|
||
if hasattr(self, 'vector_store'):
|
||
self.save_index()
|
||
|
||
|
||
def test_faiss_system():
    """Smoke-test the FAISS multimodal retrieval system end to end.

    Loads the real embedding model, so this needs network access and
    substantial memory; it is a manual integration check, not a unit test.
    (The previous version imported time, PIL.Image and numpy locally without
    using them — removed.)
    """
    # Initialize the retrieval system.
    print("初始化多模态检索系统...")
    retrieval = MultimodalRetrievalFAISS(
        model_name="OpenSearch-AI/Ops-MM-embedding-v1-7B",
        use_all_gpus=True,
        index_path="faiss_index_test"
    )

    # Sample corpus.
    texts = [
        "一只可爱的橘色猫咪在沙发上睡觉",
        "城市夜景中的高楼大厦和车流",
        "阳光明媚的海滩上,人们在冲浪和晒太阳",
        "美味的意大利面配红酒和沙拉",
        "雪山上滑雪的运动员"
    ]

    # Index the texts.
    print("\n添加文本到检索系统...")
    text_ids = retrieval.add_texts(texts)
    print(f"添加了{len(text_ids)}条文本")

    # Text-to-text search.
    print("\n测试文本搜索...")
    query_text = "一只猫在睡觉"
    print(f"查询: {query_text}")
    results = retrieval.search_by_text(query_text, k=2)
    for i, result in enumerate(results):
        print(f"结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")

    # Persist the index.
    print("\n保存索引...")
    retrieval.save_index()

    print("\n测试完成!")
# Run the integration smoke test when executed as a script.
if __name__ == "__main__":
    test_faiss_system()