50 lines
1.9 KiB
Python
50 lines
1.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
Test the image encoding functionality.
"""
|
|
|
|
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
def test_image_encoding():
    """Smoke-test text and image encoding of the multi-GPU retrieval system.

    Creates a ``MultiGPUMultimodalRetrieval`` instance, encodes one text
    string and two sample images, and prints:

    * the shape and dtype of the text and image embeddings,
    * whether encoding the same image twice yields matching embeddings
      (``np.allclose`` with ``rtol=1e-5``),
    * the cosine similarity between the embeddings of the two images,
    * whether text and image embeddings have the same size along axis 1.

    Nothing is asserted or returned; every result is reported via ``print``.
    """
    print('正在初始化多GPU多模态检索系统...')
    engine = MultiGPUMultimodalRetrieval()

    # Text encoding.
    print('测试文本编码...')
    text_vecs = engine.encode_text_batch(['这是一个测试文本'])
    print(f'文本embedding形状: {text_vecs.shape}')
    print(f'文本embedding数据类型: {text_vecs.dtype}')

    # Image encoding.
    print('测试图像编码...')
    first_paths = ['sample_images/1755677101_1__.jpg']
    image_vecs = engine.encode_image_batch(first_paths)
    print(f'图像embedding形状: {image_vecs.shape}')
    print(f'图像embedding数据类型: {image_vecs.dtype}')

    # Encode the same image a second time and compare both results.
    print('测试embedding一致性...')
    image_vecs_again = engine.encode_image_batch(first_paths)
    is_stable = np.allclose(image_vecs, image_vecs_again, rtol=1e-5)
    print(f'相同图像embedding一致性: {is_stable}')

    # Cosine similarity between the embeddings of two different images.
    print('测试不同图像embedding差异...')
    second_paths = ['sample_images/1755677101_2__.jpg']
    other_vecs = engine.encode_image_batch(second_paths)
    vec_a = image_vecs[0]
    vec_b = other_vecs[0]
    dot_product = np.dot(vec_a, vec_b)
    norm_product = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    cosine = dot_product / norm_product
    print(f'不同图像间相似度: {cosine:.4f}')

    # Text and image embeddings should agree on the size of axis 1.
    dims_match = text_vecs.shape[1] == image_vecs.shape[1]
    print('✅ 文本和图像embedding维度一致' if dims_match else '❌ 文本和图像embedding维度不一致')

    print('测试完成!')
|
|
|
|
# Run the smoke test only when this file is executed directly as a script.
if __name__ == "__main__":
    test_image_encoding()
|