mmeb/test_local_retrieval.py
2025-09-22 10:13:11 +00:00

230 lines
7.1 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Test the multimodal retrieval system built on a local model and a FAISS
vector database.
"""
import logging
import os
import sys
import time
from pathlib import Path

# Enable offline mode BEFORE importing anything that may pull in the
# Hugging Face stack. Those libraries evaluate TRANSFORMERS_OFFLINE when they
# are first imported, so setting it after the imports below (as was done
# previously) can be too late to have any effect.
# NOTE(review): assumes multimodal_retrieval_local imports transformers at
# module import time -- confirm against that module.
os.environ['TRANSFORMERS_OFFLINE'] = '1'

import numpy as np  # noqa: E402
from PIL import Image  # noqa: E402

from multimodal_retrieval_local import MultimodalRetrievalLocal  # noqa: E402

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_text_retrieval():
    """Run the text-only retrieval scenario.

    Creates a MultimodalRetrievalLocal instance, indexes five sample
    sentences, prints the statistics reported by the system, issues three
    text queries (top 2 hits each), prints the hits and saves the index.

    Returns:
        The MultimodalRetrievalLocal instance used for the scenario.
    """
    print("\n=== 测试文本检索 ===")

    # Bring up the retrieval system.
    print("初始化检索系统...")
    settings = {
        "model_path": "/root/models/Ops-MM-embedding-v1-7B",
        "use_all_gpus": True,
        "index_path": "local_faiss_text_test",
    }
    retrieval = MultimodalRetrievalLocal(**settings)

    # Sample sentences to index.
    texts = [
        "一只可爱的橘色猫咪在沙发上睡觉",
        "城市夜景中的高楼大厦和车流",
        "阳光明媚的海滩上,人们在冲浪和晒太阳",
        "美味的意大利面配红酒和沙拉",
        "雪山上滑雪的运动员"
    ]

    # Index the sentences.
    print("\n添加文本到检索系统...")
    text_ids = retrieval.add_texts(texts)
    print(f"添加了{len(text_ids)}条文本")

    # Report statistics.
    stats = retrieval.get_stats()
    print(f"检索系统统计信息: {stats}")

    # Text-to-text search, two hits per query.
    print("\n测试文本搜索...")
    for query in ("一只猫在睡觉", "都市风光", "海边的景色"):
        print(f"\n查询: {query}")
        for i, result in enumerate(retrieval.search_by_text(query, k=2)):
            print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")

    # Persist the index.
    print("\n保存索引...")
    retrieval.save_index()

    print("\n文本检索测试完成!")
    return retrieval
def test_image_retrieval():
    """Run the image-only retrieval scenario.

    Creates a MultimodalRetrievalLocal instance, generates five solid-colour
    224x224 RGB images, writes them as PNG files into the system temporary
    directory, indexes them together with a description and their file
    path, prints the statistics reported by the system, searches with a red
    query image (top 2 hits, restricted to images), prints the hits and
    saves the index.

    The PNG files are intentionally left on disk: their paths are handed to
    ``add_images`` and the index is saved afterwards.

    Returns:
        The MultimodalRetrievalLocal instance used for the scenario.
    """
    # Local import so this function does not depend on a file-level import.
    import tempfile

    print("\n=== 测试图像检索 ===")

    # Initialise the retrieval system.
    print("初始化检索系统...")
    retrieval = MultimodalRetrievalLocal(
        model_path="/root/models/Ops-MM-embedding-v1-7B",
        use_all_gpus=True,
        index_path="local_faiss_image_test"
    )

    # Create the test images.
    print("\n创建测试图像...")
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
    # Use the platform's temp directory instead of a hard-coded "/tmp" so the
    # script also works where /tmp does not exist or TMPDIR points elsewhere.
    # On a default Linux setup this still resolves to /tmp.
    tmp_dir = tempfile.gettempdir()
    images = []
    image_paths = []
    for i, color in enumerate(colors):
        img = Image.new('RGB', (224, 224), color=color)
        images.append(img)
        # Save the image; the path is kept as a plain string, as before.
        img_path = os.path.join(tmp_dir, f"test_image_{i}.png")
        img.save(img_path)
        image_paths.append(img_path)
        print(f"创建图像: {img_path}")

    # Index the images.
    print("\n添加图像到检索系统...")
    metadatas = [{"description": f"测试图像 {i+1}"} for i in range(len(images))]
    image_ids = retrieval.add_images(images, metadatas, image_paths)
    print(f"添加了{len(image_ids)}张图像")

    # Report statistics.
    stats = retrieval.get_stats()
    print(f"检索系统统计信息: {stats}")

    # Image-to-image search with a red query image.
    print("\n测试图像搜索...")
    query_image = Image.new('RGB', (224, 224), color=(255, 0, 0))
    print("\n使用图像查询图像:")
    results = retrieval.search_by_image(query_image, k=2, filter_type="image")
    for i, result in enumerate(results):
        print(f" 结果 {i+1}: {result.get('description', 'N/A')} (分数: {result.get('score', 0):.4f})")

    # Persist the index.
    print("\n保存索引...")
    retrieval.save_index()

    print("\n图像检索测试完成!")
    return retrieval
def test_cross_modal_retrieval():
    """Run the cross-modal retrieval scenario.

    Creates a MultimodalRetrievalLocal instance, indexes five colour-themed
    sentences and five solid-colour 224x224 RGB images (each with a
    description), prints the statistics reported by the system, then runs a
    text query restricted to images and an image query restricted to texts
    (top 2 hits each), prints the hits and saves the index.

    Returns:
        The MultimodalRetrievalLocal instance used for the scenario.
    """
    print("\n=== 测试跨模态检索 ===")

    # Bring up the retrieval system.
    print("初始化检索系统...")
    settings = {
        "model_path": "/root/models/Ops-MM-embedding-v1-7B",
        "use_all_gpus": True,
        "index_path": "local_faiss_cross_modal_test",
    }
    retrieval = MultimodalRetrievalLocal(**settings)

    # Index the sentences.
    texts = [
        "一只红色的苹果",
        "绿色的草地",
        "蓝色的大海",
        "黄色的向日葵",
        "青色的天空"
    ]
    print("\n添加文本到检索系统...")
    text_ids = retrieval.add_texts(texts)
    print(f"添加了{len(text_ids)}条文本")

    # Index the images; each colour is paired with its description.
    print("\n添加图像到检索系统...")
    swatches = [
        ((255, 0, 0), "红色图像"),
        ((0, 255, 0), "绿色图像"),
        ((0, 0, 255), "蓝色图像"),
        ((255, 255, 0), "黄色图像"),
        ((0, 255, 255), "青色图像"),
    ]
    images = [Image.new('RGB', (224, 224), color=rgb) for rgb, _ in swatches]
    metadatas = [{"description": label} for _, label in swatches]
    image_ids = retrieval.add_images(images, metadatas)
    print(f"添加了{len(image_ids)}张图像")

    # Report statistics.
    stats = retrieval.get_stats()
    print(f"检索系统统计信息: {stats}")

    # Text query, image results.
    print("\n测试文搜图...")
    query_text = "红色"
    print(f"查询文本: {query_text}")
    for i, result in enumerate(
        retrieval.search_by_text(query_text, k=2, filter_type="image")
    ):
        print(f" 结果 {i+1}: {result.get('description', 'N/A')} (分数: {result.get('score', 0):.4f})")

    # Image query (blue), text results.
    print("\n测试图搜文...")
    query_image = Image.new('RGB', (224, 224), color=(0, 0, 255))
    print("查询图像: 蓝色图像")
    for i, result in enumerate(
        retrieval.search_by_image(query_image, k=2, filter_type="text")
    ):
        print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")

    # Persist the index.
    print("\n保存索引...")
    retrieval.save_index()

    print("\n跨模态检索测试完成!")
    return retrieval
def main():
    """Validate the local model directory, then run all retrieval scenarios.

    Checks that the model directory and its ``config.json`` exist, then runs
    the text, image and cross-modal scenarios in that order. Any exception
    raised by a scenario is reported with its traceback.

    Returns:
        int: 0 when every scenario ran without raising; 1 when the model
        directory or ``config.json`` is missing, or when a scenario raised.
        Previously every failure path returned None, so the process still
        exited with status 0.
    """
    print("=== 本地多模态检索系统测试 ===")

    # Check the model directory.
    model_path = "/root/models/Ops-MM-embedding-v1-7B"
    if not os.path.exists(model_path):
        print(f"错误: 模型路径不存在: {model_path}")
        print("请先下载模型到指定路径")
        return 1

    # Check the model configuration file.
    config_file = os.path.join(model_path, "config.json")
    if not os.path.exists(config_file):
        print(f"错误: 模型配置文件不存在: {config_file}")
        print("请确保模型文件已正确下载")
        return 1

    print(f"模型路径验证成功: {model_path}")

    # Run the scenarios.
    try:
        test_text_retrieval()
        test_image_retrieval()
        test_cross_modal_retrieval()
    except Exception as e:
        # Top-level boundary: report the failure and signal it to the caller
        # instead of letting the script end as if it had succeeded.
        print(f"测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
        return 1
    else:
        print("\n所有测试完成!")
        return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells and CI can detect failures.
    sys.exit(main())