105 lines
3.7 KiB
Python
105 lines
3.7 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
测试所有四种多模态检索模式
|
|
"""
|
|
|
|
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
|
import numpy as np
|
|
from PIL import Image
|
|
import os
|
|
|
|
def test_all_retrieval_modes():
|
|
print('正在初始化多GPU多模态检索系统...')
|
|
retrieval = MultiGPUMultimodalRetrieval()
|
|
|
|
# 准备测试数据
|
|
test_texts = [
|
|
"一只可爱的小猫",
|
|
"美丽的风景照片",
|
|
"现代建筑设计",
|
|
"colorful flowers in garden"
|
|
]
|
|
|
|
test_images = [
|
|
'sample_images/1755677101_1__.jpg',
|
|
'sample_images/1755677101_2__.jpg',
|
|
'sample_images/1755677101_3__.jpg',
|
|
'sample_images/1755677101_4__.jpg'
|
|
]
|
|
|
|
# 验证测试图像存在
|
|
existing_images = [img for img in test_images if os.path.exists(img)]
|
|
if not existing_images:
|
|
print("❌ 没有找到测试图像文件")
|
|
return
|
|
|
|
print(f"找到 {len(existing_images)} 张测试图像")
|
|
|
|
try:
|
|
# 1. 构建文本索引
|
|
print('\n=== 构建文本索引 ===')
|
|
retrieval.build_text_index_parallel(test_texts)
|
|
print('✅ 文本索引构建完成')
|
|
|
|
# 2. 构建图像索引
|
|
print('\n=== 构建图像索引 ===')
|
|
retrieval.build_image_index_parallel(existing_images)
|
|
print('✅ 图像索引构建完成')
|
|
|
|
# 3. 测试文本到文本检索
|
|
print('\n=== 测试文本到文本检索 ===')
|
|
query = "小动物"
|
|
results = retrieval.search_text_by_text(query, top_k=3)
|
|
print(f'查询: "{query}"')
|
|
for i, (text, score) in enumerate(results):
|
|
print(f' {i+1}. {text} (相似度: {score:.4f})')
|
|
|
|
# 4. 测试文本到图像检索
|
|
print('\n=== 测试文本到图像检索 ===')
|
|
query = "beautiful image"
|
|
results = retrieval.search_images_by_text(query, top_k=3)
|
|
print(f'查询: "{query}"')
|
|
for i, (image_path, score) in enumerate(results):
|
|
print(f' {i+1}. {image_path} (相似度: {score:.4f})')
|
|
|
|
# 5. 测试图像到文本检索
|
|
print('\n=== 测试图像到文本检索 ===')
|
|
query_image = existing_images[0]
|
|
results = retrieval.search_text_by_image(query_image, top_k=3)
|
|
print(f'查询图像: {query_image}')
|
|
for i, (text, score) in enumerate(results):
|
|
print(f' {i+1}. {text} (相似度: {score:.4f})')
|
|
|
|
# 6. 测试图像到图像检索
|
|
print('\n=== 测试图像到图像检索 ===')
|
|
query_image = existing_images[0]
|
|
results = retrieval.search_images_by_image(query_image, top_k=3)
|
|
print(f'查询图像: {query_image}')
|
|
for i, (image_path, score) in enumerate(results):
|
|
print(f' {i+1}. {image_path} (相似度: {score:.4f})')
|
|
|
|
print('\n✅ 所有四种检索模式测试完成!')
|
|
|
|
# 7. 测试Web应用兼容的方法名
|
|
print('\n=== 测试Web应用兼容方法 ===')
|
|
try:
|
|
results = retrieval.search_text_to_image("test query", top_k=2)
|
|
print('✅ search_text_to_image 方法正常')
|
|
|
|
results = retrieval.search_image_to_text(existing_images[0], top_k=2)
|
|
print('✅ search_image_to_text 方法正常')
|
|
|
|
results = retrieval.search_image_to_image(existing_images[0], top_k=2)
|
|
print('✅ search_image_to_image 方法正常')
|
|
|
|
except Exception as e:
|
|
print(f'❌ Web应用兼容方法测试失败: {e}')
|
|
|
|
except Exception as e:
|
|
print(f'❌ 测试过程中出现错误: {e}')
|
|
import traceback
|
|
traceback.print_exc()
|
|
|
|
if __name__ == "__main__":
|
|
test_all_retrieval_modes()
|