mmeb/test_image_encoding.py
2025-08-20 10:01:03 +00:00

50 lines
1.9 KiB
Python

#!/usr/bin/env python3
"""
测试图像编码功能
"""
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
import numpy as np
from PIL import Image
def test_image_encoding():
    """Smoke-test text and image encoding of the multi-GPU retrieval system.

    Creates a ``MultiGPUMultimodalRetrieval`` instance, encodes one text and
    two sample images, and prints:

    * the shape and dtype of the text and image embeddings,
    * whether encoding the same image twice gives matching embeddings
      (``np.allclose`` with ``rtol=1e-5``),
    * the cosine similarity between the embeddings of two different images,
    * whether text and image embeddings agree on the size of axis 1.

    The results are only reported on stdout; nothing is asserted or returned.
    """
    first_image_paths = ['sample_images/1755677101_1__.jpg']
    second_image_paths = ['sample_images/1755677101_2__.jpg']

    print('正在初始化多GPU多模态检索系统...')
    engine = MultiGPUMultimodalRetrieval()

    # Text encoding.
    print('测试文本编码...')
    text_vecs = engine.encode_text_batch(['这是一个测试文本'])
    print(f'文本embedding形状: {text_vecs.shape}')
    print(f'文本embedding数据类型: {text_vecs.dtype}')

    # Image encoding.
    print('测试图像编码...')
    first_vecs = engine.encode_image_batch(first_image_paths)
    print(f'图像embedding形状: {first_vecs.shape}')
    print(f'图像embedding数据类型: {first_vecs.dtype}')

    # Encode the same image a second time and compare with the first run.
    print('测试embedding一致性...')
    repeat_vecs = engine.encode_image_batch(first_image_paths)
    is_stable = np.allclose(first_vecs, repeat_vecs, rtol=1e-5)
    print(f'相同图像embedding一致性: {is_stable}')

    # Cosine similarity between the embeddings of two different images.
    print('测试不同图像embedding差异...')
    other_vecs = engine.encode_image_batch(second_image_paths)
    vec_a = first_vecs[0]
    vec_b = other_vecs[0]
    norm_product = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    cosine = np.dot(vec_a, vec_b) / norm_product
    print(f'不同图像间相似度: {cosine:.4f}')

    # Text and image embeddings should have the same size along axis 1.
    same_width = text_vecs.shape[1] == first_vecs.shape[1]
    print(
        '✅ 文本和图像embedding维度一致'
        if same_width
        else '❌ 文本和图像embedding维度不一致'
    )

    print('测试完成!')
if __name__ == '__main__':
    # Script entry point: run the encoding smoke test only when this file is
    # executed directly, not when it is imported.
    test_image_encoding()