# mmeb/test_all_retrieval_modes.py
# 2025-08-20 10:01:03 +00:00
#
# 105 lines
# 3.7 KiB
# Python

#!/usr/bin/env python3
"""
Exercise all four multimodal retrieval modes: text-to-text, text-to-image,
image-to-text and image-to-image.
"""
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
import numpy as np
from PIL import Image
import os
def _print_ranked(results):
    """Print each ``(item, score)`` pair in *results* as a 1-based ranked line."""
    for rank, (item, score) in enumerate(results, start=1):
        print(f' {rank}. {item} (相似度: {score:.4f})')


def test_all_retrieval_modes():
    """Smoke-test the four retrieval modes of ``MultiGPUMultimodalRetrieval``.

    Builds a text index and an image index from small hard-coded fixtures,
    runs one query per retrieval mode (text->text, text->image, image->text,
    image->image) and prints the ranked results. Afterwards it calls the
    ``search_text_to_image`` / ``search_image_to_text`` /
    ``search_image_to_image`` method names to confirm they are callable.

    Each search result is consumed as an iterable of ``(item, score)`` pairs.

    Returns:
        None. Progress and results are reported on stdout only.

    Notes:
        * If none of the fixture images exist on disk, the function prints a
          message and returns without constructing the retrieval system.
        * Exceptions raised while indexing or searching are caught, printed
          together with a traceback, and not re-raised. Exceptions raised by
          the ``MultiGPUMultimodalRetrieval()`` constructor propagate.
    """
    # Test fixtures.
    test_texts = [
        "一只可爱的小猫",
        "美丽的风景照片",
        "现代建筑设计",
        "colorful flowers in garden",
    ]
    test_images = [
        'sample_images/1755677101_1__.jpg',
        'sample_images/1755677101_2__.jpg',
        'sample_images/1755677101_3__.jpg',
        'sample_images/1755677101_4__.jpg',
    ]

    # Keep only the fixture images that are actually present. This check runs
    # before the retrieval system is constructed so that a run without any
    # images exits immediately instead of initialising the system for nothing.
    # NOTE(review): assumes the constructor does not itself create the sample
    # images - confirm against multimodal_retrieval_multigpu.
    existing_images = [img for img in test_images if os.path.exists(img)]
    if not existing_images:
        print("❌ 没有找到测试图像文件")
        return
    print(f"找到 {len(existing_images)} 张测试图像")

    print('正在初始化多GPU多模态检索系统...')
    retrieval = MultiGPUMultimodalRetrieval()

    try:
        # 1. Build the text index.
        print('\n=== 构建文本索引 ===')
        retrieval.build_text_index_parallel(test_texts)
        print('✅ 文本索引构建完成')

        # 2. Build the image index.
        print('\n=== 构建图像索引 ===')
        retrieval.build_image_index_parallel(existing_images)
        print('✅ 图像索引构建完成')

        # 3. Text -> text retrieval.
        print('\n=== 测试文本到文本检索 ===')
        query = "小动物"
        results = retrieval.search_text_by_text(query, top_k=3)
        print(f'查询: "{query}"')
        _print_ranked(results)

        # 4. Text -> image retrieval.
        print('\n=== 测试文本到图像检索 ===')
        query = "beautiful image"
        results = retrieval.search_images_by_text(query, top_k=3)
        print(f'查询: "{query}"')
        _print_ranked(results)

        # The first available image is the query for both image-driven modes.
        query_image = existing_images[0]

        # 5. Image -> text retrieval.
        print('\n=== 测试图像到文本检索 ===')
        results = retrieval.search_text_by_image(query_image, top_k=3)
        print(f'查询图像: {query_image}')
        _print_ranked(results)

        # 6. Image -> image retrieval.
        print('\n=== 测试图像到图像检索 ===')
        results = retrieval.search_images_by_image(query_image, top_k=3)
        print(f'查询图像: {query_image}')
        _print_ranked(results)

        print('\n✅ 所有四种检索模式测试完成!')

        # 7. Method names used by the web application. Only callability is
        # checked here; the returned results are intentionally discarded.
        print('\n=== 测试Web应用兼容方法 ===')
        try:
            retrieval.search_text_to_image("test query", top_k=2)
            print('✅ search_text_to_image 方法正常')
            retrieval.search_image_to_text(query_image, top_k=2)
            print('✅ search_image_to_text 方法正常')
            retrieval.search_image_to_image(query_image, top_k=2)
            print('✅ search_image_to_image 方法正常')
        except Exception as exc:
            # A failing compatibility method is reported without discarding
            # the results already printed for the four core modes.
            print(f'❌ Web应用兼容方法测试失败: {exc}')
    except Exception as exc:
        # Top-level boundary of the smoke test: report the failure with its
        # traceback rather than aborting the script.
        print(f'❌ 测试过程中出现错误: {exc}')
        import traceback
        traceback.print_exc()
# Run the retrieval smoke test only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    test_all_retrieval_modes()