#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Multimodal retrieval system integrated with Baidu VDB.

Supports four retrieval modes: text-to-text, text-to-image,
image-to-text and image-to-image.
"""

import gc
import json
import logging
import os
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer

from baidu_vdb_backend import BaiduVDBBackend

# Module-level logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MultimodalRetrievalVDB:
    """Multimodal retrieval system integrated with Baidu VDB.

    Pairs a HuggingFace multimodal embedding model with a Baidu VDB
    backend and supports four retrieval modes: text-to-text,
    text-to-image, image-to-text and image-to-image.
    """

    # Width of the zero-vector fallback emitted when image encoding fails.
    # NOTE(review): assumed to match the model's hidden size (3584 for the
    # default 7B checkpoint) — confirm against the checkpoint config.
    _FALLBACK_EMBED_DIM = 3584

    def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True, gpu_ids: List[int] = None,
                 vdb_config: Dict[str, str] = None):
        """Initialize the multimodal retrieval system.

        Args:
            model_name: HuggingFace identifier of the embedding model.
            use_all_gpus: If True, shard the model across all visible GPUs.
            gpu_ids: Explicit GPU ids (used only when use_all_gpus is False).
            vdb_config: Baidu VDB connection settings. When None, values are
                taken from the VDB_ACCOUNT / VDB_API_KEY / VDB_ENDPOINT /
                VDB_DATABASE environment variables, falling back to the
                historical built-in defaults.

        Raises:
            RuntimeError: if CUDA is unavailable or the model fails to load.
        """
        self.model_name = model_name

        # Pick devices, then drop any stale allocations before loading.
        self._setup_devices(use_all_gpus, gpu_ids)
        self._clear_gpu_memory()

        logger.info(f"正在加载模型到GPU: {self.device_ids}")

        self.model = None
        self.tokenizer = None
        self.processor = None
        # Bug fix: the boolean result of _load_model() was previously
        # ignored, so a failed load left a half-initialized object that
        # crashed on first use. Fail fast instead.
        if not self._load_model():
            raise RuntimeError(f"模型加载失败: {self.model_name}")

        # SECURITY: credentials were hard-coded here. Environment variables
        # now take precedence; the literal fallbacks are kept only for
        # backward compatibility and should be rotated/removed.
        if vdb_config is None:
            vdb_config = {
                "account": os.environ.get("VDB_ACCOUNT", "root"),
                "api_key": os.environ.get("VDB_API_KEY", "vdb$yjr9ln3n0td"),
                "endpoint": os.environ.get("VDB_ENDPOINT", "http://180.76.96.191:5287"),
                "database_name": os.environ.get("VDB_DATABASE", "multimodal_retrieval"),
            }

        self.vdb = BaiduVDBBackend(**vdb_config)

        logger.info("多模态检索系统初始化完成")

    def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int]):
        """Select the GPU ids and primary device to use.

        Raises:
            RuntimeError: when CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA不可用,无法使用GPU")

        total_gpus = torch.cuda.device_count()
        logger.info(f"检测到 {total_gpus} 个GPU")

        if use_all_gpus:
            self.device_ids = list(range(total_gpus))
        elif gpu_ids:
            self.device_ids = gpu_ids
        else:
            self.device_ids = [0]

        self.num_gpus = len(self.device_ids)
        # The first selected GPU hosts inputs and receives pooled outputs.
        self.primary_device = f"cuda:{self.device_ids[0]}"

        logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}")

    def _clear_gpu_memory(self):
        """Release cached CUDA memory on every selected device."""
        for gpu_id in self.device_ids:
            torch.cuda.set_device(gpu_id)
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
        logger.info("GPU内存已清理")

    def _load_model(self):
        """Load the model, tokenizer and processor.

        Returns:
            True on success, False on failure (the caller — __init__ —
            raises on False).
        """
        try:
            # Reduce CUDA allocator fragmentation for large sharded models.
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

            self._clear_gpu_memory()

            if self.num_gpus > 1:
                # Multi-GPU: let device_map="auto" shard the weights, capping
                # each device. NOTE(review): the 18GiB cap assumes ~20GB+
                # cards — confirm against the deployment hardware.
                max_memory = {i: "18GiB" for i in self.device_ids}

                self.model = AutoModel.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    max_memory=max_memory,
                    low_cpu_mem_usage=True
                )
            else:
                # Single GPU: place everything on the primary device.
                self.model = AutoModel.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                    device_map=self.primary_device
                )

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            try:
                self.processor = AutoProcessor.from_pretrained(
                    self.model_name,
                    trust_remote_code=True
                )
            except Exception as e:
                # Some checkpoints ship no processor; the tokenizer is a
                # workable text-only substitute.
                logger.warning(f"Processor加载失败,使用tokenizer: {e}")
                self.processor = self.tokenizer

            logger.info("模型加载完成")
            return True

        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            return False

    def encode_text_batch(self, texts: List[str]) -> np.ndarray:
        """Encode a batch of texts into float32 embedding vectors.

        Args:
            texts: Input strings; an empty list yields an empty array.

        Returns:
            Array of shape (len(texts), hidden_size), dtype float32,
            produced by mean-pooling the model's last hidden state.
        """
        if not texts:
            return np.array([])

        with torch.no_grad():
            # Tokenize with padding and truncation to a fixed max length.
            inputs = self.tokenizer(
                text=texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )

            # Move the whole batch onto the primary device.
            inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}

            # Forward pass + mean pooling over the sequence dimension.
            outputs = self.model(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)

            # Release intermediates eagerly to limit GPU memory pressure.
            del inputs, outputs
            torch.cuda.empty_cache()

        return embeddings.cpu().numpy().astype(np.float32)

    def encode_image_batch(self, images: List[Union[str, Image.Image]]) -> np.ndarray:
        """Encode a batch of images into float32 embedding vectors.

        Args:
            images: File paths or PIL images; an empty list yields an
                empty array.

        Returns:
            Array of shape (len(images), hidden_size), dtype float32. On
            encoding failure a zero matrix of width _FALLBACK_EMBED_DIM is
            returned instead of raising (deliberate best-effort behavior).
        """
        if not images:
            return np.array([])

        # Normalize every input to an RGB PIL image.
        processed_images = []
        for img in images:
            if isinstance(img, str):
                img = Image.open(img).convert('RGB')
            elif isinstance(img, Image.Image):
                img = img.convert('RGB')
            processed_images.append(img)

        try:
            logger.info(f"处理 {len(processed_images)} 张图像")

            # Build one single-turn chat per image so the multimodal model
            # embeds the image in a question context.
            conversations = []
            for i in range(len(processed_images)):
                conversation = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": processed_images[i]},
                            {"type": "text", "text": "What is in this image?"}
                        ]
                    }
                ]
                conversations.append(conversation)

            try:
                # Render each chat template to a plain prompt string.
                texts = []
                for conv in conversations:
                    text = self.processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
                    texts.append(text)

                # Process prompts and images together.
                inputs = self.processor(
                    text=texts,
                    images=processed_images,
                    return_tensors="pt",
                    padding=True
                )

                inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}

                # Forward pass + mean pooling, no gradients needed.
                with torch.no_grad():
                    outputs = self.model(**inputs)
                    embeddings = outputs.last_hidden_state.mean(dim=1)

                embeddings = embeddings.cpu().numpy().astype(np.float32)

            except Exception as inner_e:
                # Deliberate best-effort: fall back to zero vectors so one
                # bad batch does not abort a long ingestion run.
                logger.warning(f"多模态模型图像编码失败: {inner_e}")
                embeddings = np.zeros((len(processed_images), self._FALLBACK_EMBED_DIM), dtype=np.float32)

            logger.info(f"生成图像embeddings: {embeddings.shape}")
            return embeddings

        except Exception as e:
            logger.error(f"图像编码失败: {e}")
            # Same zero-vector fallback as above.
            return np.zeros((len(processed_images), self._FALLBACK_EMBED_DIM), dtype=np.float32)

    def store_texts(self, texts: List[str], metadata: List[Dict] = None) -> List[str]:
        """Encode and persist a list of texts in the VDB.

        Args:
            texts: Texts to store.
            metadata: Optional per-text metadata, parallel to ``texts``.

        Returns:
            IDs assigned by the VDB for successfully stored entries;
            failed batches are logged and skipped.
        """
        logger.info(f"正在存储 {len(texts)} 条文本数据")

        # Text batches are cheap relative to image batches.
        batch_size = 16
        all_ids = []

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_metadata = metadata[i:i+batch_size] if metadata else None

            try:
                vectors = self.encode_text_batch(batch_texts)
                ids = self.vdb.store_text_vectors(batch_texts, vectors, batch_metadata)
                all_ids.extend(ids)
                logger.info(f"已处理 {i + len(batch_texts)}/{len(texts)} 条文本")
            except Exception as e:
                # Best-effort ingestion: skip the failing batch and go on.
                logger.error(f"处理文本批次时出错: {e}")
                continue

        logger.info(f"✅ 文本存储完成,共 {len(all_ids)} 条")
        return all_ids

    def store_images(self, image_paths: List[str], metadata: List[Dict] = None) -> List[str]:
        """Encode and persist a list of images in the VDB.

        Args:
            image_paths: Paths of images to store.
            metadata: Optional per-image metadata, parallel to paths.

        Returns:
            IDs assigned by the VDB for successfully stored entries;
            failed batches are logged and skipped.
        """
        logger.info(f"正在存储 {len(image_paths)} 张图像数据")

        # Images are heavier than text, so use a smaller batch size.
        batch_size = 8
        all_ids = []

        for i in range(0, len(image_paths), batch_size):
            batch_images = image_paths[i:i+batch_size]
            batch_metadata = metadata[i:i+batch_size] if metadata else None

            try:
                vectors = self.encode_image_batch(batch_images)
                ids = self.vdb.store_image_vectors(batch_images, vectors, batch_metadata)
                all_ids.extend(ids)
                logger.info(f"已处理 {i + len(batch_images)}/{len(image_paths)} 张图像")
            except Exception as e:
                # Best-effort ingestion: skip the failing batch and go on.
                logger.error(f"处理图像批次时出错: {e}")
                continue

        logger.info(f"✅ 图像存储完成,共 {len(all_ids)} 条")
        return all_ids

    @staticmethod
    def _format_text_results(results) -> List[Tuple[str, float]]:
        """Project raw VDB text hits (id, text, score, meta) to (text, score)."""
        return [(text_content, score) for _doc_id, text_content, score, _meta in results]

    @staticmethod
    def _format_image_results(results) -> List[Tuple[str, float]]:
        """Project raw VDB image hits (id, path, name, score, meta) to (path, score)."""
        return [(image_path, score) for _doc_id, image_path, _name, score, _meta in results]

    def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-text search: return up to top_k (text, score) pairs."""
        logger.info(f"执行文搜文查询: {query}")
        query_vector = self.encode_text_batch([query])[0]
        results = self.vdb.search_text_vectors(query_vector, top_k)
        return self._format_text_results(results)

    def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-image search: return up to top_k (image_path, score) pairs."""
        logger.info(f"执行文搜图查询: {query}")
        query_vector = self.encode_text_batch([query])[0]
        results = self.vdb.search_image_vectors(query_vector, top_k)
        return self._format_image_results(results)

    def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-text search: return up to top_k (text, score) pairs."""
        logger.info("执行图搜文查询")
        query_vector = self.encode_image_batch([query_image])[0]
        results = self.vdb.search_text_vectors(query_vector, top_k)
        return self._format_text_results(results)

    def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-image search: return up to top_k (image_path, score) pairs."""
        logger.info("执行图搜图查询")
        query_vector = self.encode_image_batch([query_image])[0]
        results = self.vdb.search_image_vectors(query_vector, top_k)
        return self._format_image_results(results)

    # Web-app compatible method aliases.
    def search_text_to_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Web-app compatible alias for search_text_by_text."""
        return self.search_text_by_text(query, top_k)

    def search_text_to_image(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Web-app compatible alias for search_images_by_text."""
        return self.search_images_by_text(query, top_k)

    def search_image_to_text(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Web-app compatible alias for search_text_by_image."""
        return self.search_text_by_image(query_image, top_k)

    def search_image_to_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Web-app compatible alias for search_images_by_image."""
        return self.search_images_by_image(query_image, top_k)

    def get_statistics(self) -> Dict[str, Any]:
        """Return backend statistics as reported by the VDB."""
        return self.vdb.get_statistics()

    def clear_all_data(self):
        """Delete all stored data from the VDB."""
        self.vdb.clear_all_data()

    def close(self):
        """Release the VDB connection and free GPU memory."""
        # getattr guards against partially constructed instances
        # (e.g. when __init__ raised before self.vdb was assigned).
        if getattr(self, "vdb", None):
            self.vdb.close()
        self._clear_gpu_memory()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always release resources when leaving the context.
        self.close()
def check_system_info():
    """Print a brief report on the local CUDA / GPU environment."""
    print("=== 多模态检索系统信息 ===")

    if not torch.cuda.is_available():
        # Nothing else is meaningful without CUDA.
        print("❌ CUDA不可用")
        return

    device_total = torch.cuda.device_count()
    print(f"✅ 检测到 {device_total} 个GPU")
    print(f"CUDA版本: {torch.version.cuda}")
    print(f"PyTorch版本: {torch.__version__}")

    # One line per device: name and total memory in GiB.
    for idx in range(device_total):
        mem_gib = torch.cuda.get_device_properties(idx).total_memory / 1024**3
        print(f"GPU {idx}: {torch.cuda.get_device_name(idx)} ({mem_gib:.1f}GB)")

    print("========================")
if __name__ == "__main__":
    # Report the local environment before attempting initialization.
    check_system_info()

    print("\n正在初始化多模态检索系统...")

    try:
        retrieval_system = MultimodalRetrievalVDB()
        print("✅ 系统初始化成功!")

        # Show what the backing database currently holds.
        stats = retrieval_system.get_statistics()
        print(f"\n📊 数据库统计信息: {stats}")

        print("\n🚀 多模态检索系统就绪!")
        usage_lines = (
            "支持的检索模式:",
            "1. 文搜文: search_text_by_text()",
            "2. 文搜图: search_images_by_text()",
            "3. 图搜文: search_text_by_image()",
            "4. 图搜图: search_images_by_image()",
            "5. 存储文本: store_texts()",
            "6. 存储图像: store_images()",
        )
        for line in usage_lines:
            print(line)

    except Exception as e:
        print(f"❌ 系统初始化失败: {e}")
        import traceback
        traceback.print_exc()