#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
End-to-end tests for the multimodal retrieval system backed by a local model
and a FAISS vector index.

Runs text, image and cross-modal retrieval checks against
``MultimodalRetrievalLocal``; each check writes its own index directory.
"""

import os
import sys
import logging
import tempfile
import traceback
from pathlib import Path
import time

# Offline mode must be enabled BEFORE importing anything that may pull in
# transformers / huggingface_hub: those libraries read this variable when they
# are first imported, so setting it after the import below has no effect.
# (multimodal_retrieval_local presumably imports transformers - verify.)
os.environ['TRANSFORMERS_OFFLINE'] = '1'

from PIL import Image  # noqa: E402  (must follow the offline-mode setup)
import numpy as np  # noqa: E402
from multimodal_retrieval_local import MultimodalRetrievalLocal  # noqa: E402

# Logging setup
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Single source of truth for the local model directory: main() validates this
# path and passes it to every test, so the check and the tests cannot diverge.
MODEL_PATH = "/root/models/Ops-MM-embedding-v1-7B"

# Solid RGB colours used to synthesise test images: red, green, blue, yellow, cyan.
_TEST_COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
_IMAGE_SIZE = (224, 224)


def _print_results(results, field):
    """Print one line per search hit, showing ``hit[field]`` and ``hit['score']``."""
    for i, result in enumerate(results):
        print(f" 结果 {i+1}: {result.get(field, 'N/A')} (分数: {result.get('score', 0):.4f})")


def test_text_retrieval(model_path=MODEL_PATH):
    """Index a few sentences, run text-to-text queries, and save the index.

    Args:
        model_path: Directory of the local embedding model.

    Returns:
        The ``MultimodalRetrievalLocal`` instance used by the test.
    """
    print("\n=== 测试文本检索 ===")

    print("初始化检索系统...")
    retrieval = MultimodalRetrievalLocal(
        model_path=model_path,
        use_all_gpus=True,
        index_path="local_faiss_text_test"
    )

    texts = [
        "一只可爱的橘色猫咪在沙发上睡觉",
        "城市夜景中的高楼大厦和车流",
        "阳光明媚的海滩上,人们在冲浪和晒太阳",
        "美味的意大利面配红酒和沙拉",
        "雪山上滑雪的运动员"
    ]

    print("\n添加文本到检索系统...")
    text_ids = retrieval.add_texts(texts)
    print(f"添加了{len(text_ids)}条文本")

    stats = retrieval.get_stats()
    print(f"检索系统统计信息: {stats}")

    print("\n测试文本搜索...")
    queries = ["一只猫在睡觉", "都市风光", "海边的景色"]
    for query in queries:
        print(f"\n查询: {query}")
        results = retrieval.search_by_text(query, k=2)
        _print_results(results, 'text')

    print("\n保存索引...")
    retrieval.save_index()

    print("\n文本检索测试完成!")
    return retrieval


def test_image_retrieval(model_path=MODEL_PATH, image_dir=None):
    """Index solid-colour images, run an image-to-image query, and save the index.

    Args:
        model_path: Directory of the local embedding model.
        image_dir: Directory where the generated PNGs are written. Defaults to
            the platform temp directory (``tempfile.gettempdir()``). The files
            are not deleted because their paths are passed to ``add_images``.

    Returns:
        The ``MultimodalRetrievalLocal`` instance used by the test.
    """
    print("\n=== 测试图像检索 ===")

    print("初始化检索系统...")
    retrieval = MultimodalRetrievalLocal(
        model_path=model_path,
        use_all_gpus=True,
        index_path="local_faiss_image_test"
    )

    if image_dir is None:
        image_dir = tempfile.gettempdir()

    print("\n创建测试图像...")
    images = []
    image_paths = []
    for i, color in enumerate(_TEST_COLORS):
        img = Image.new('RGB', _IMAGE_SIZE, color=color)
        images.append(img)
        img_path = os.path.join(image_dir, f"test_image_{i}.png")
        img.save(img_path)
        image_paths.append(img_path)
        print(f"创建图像: {img_path}")

    print("\n添加图像到检索系统...")
    metadatas = [{"description": f"测试图像 {i+1}"} for i in range(len(images))]
    image_ids = retrieval.add_images(images, metadatas, image_paths)
    print(f"添加了{len(image_ids)}张图像")

    stats = retrieval.get_stats()
    print(f"检索系统统计信息: {stats}")

    print("\n测试图像搜索...")
    query_image = Image.new('RGB', _IMAGE_SIZE, color=(255, 0, 0))  # red query image
    print("\n使用图像查询图像:")
    results = retrieval.search_by_image(query_image, k=2, filter_type="image")
    _print_results(results, 'description')

    print("\n保存索引...")
    retrieval.save_index()

    print("\n图像检索测试完成!")
    return retrieval


def test_cross_modal_retrieval(model_path=MODEL_PATH):
    """Index texts and images together, then run text-to-image and image-to-text queries.

    Args:
        model_path: Directory of the local embedding model.

    Returns:
        The ``MultimodalRetrievalLocal`` instance used by the test.
    """
    print("\n=== 测试跨模态检索 ===")

    print("初始化检索系统...")
    retrieval = MultimodalRetrievalLocal(
        model_path=model_path,
        use_all_gpus=True,
        index_path="local_faiss_cross_modal_test"
    )

    texts = [
        "一只红色的苹果",
        "绿色的草地",
        "蓝色的大海",
        "黄色的向日葵",
        "青色的天空"
    ]
    print("\n添加文本到检索系统...")
    text_ids = retrieval.add_texts(texts)
    print(f"添加了{len(text_ids)}条文本")

    print("\n添加图像到检索系统...")
    # Descriptions are listed in the same order as _TEST_COLORS.
    descriptions = ["红色图像", "绿色图像", "蓝色图像", "黄色图像", "青色图像"]
    images = [Image.new('RGB', _IMAGE_SIZE, color=color) for color in _TEST_COLORS]
    metadatas = [{"description": desc} for desc in descriptions]
    image_ids = retrieval.add_images(images, metadatas)
    print(f"添加了{len(image_ids)}张图像")

    stats = retrieval.get_stats()
    print(f"检索系统统计信息: {stats}")

    # Text -> image search
    print("\n测试文搜图...")
    query_text = "红色"
    print(f"查询文本: {query_text}")
    results = retrieval.search_by_text(query_text, k=2, filter_type="image")
    _print_results(results, 'description')

    # Image -> text search
    print("\n测试图搜文...")
    query_image = Image.new('RGB', _IMAGE_SIZE, color=(0, 0, 255))  # blue query image
    print("查询图像: 蓝色图像")
    results = retrieval.search_by_image(query_image, k=2, filter_type="text")
    _print_results(results, 'text')

    print("\n保存索引...")
    retrieval.save_index()

    print("\n跨模态检索测试完成!")
    return retrieval


def main():
    """Validate the model directory, then run every retrieval test.

    Each test runs independently, so a failure in one does not prevent the
    others from running.

    Returns:
        0 if every test passed; 1 if the model is missing or any test raised.
    """
    print("=== 本地多模态检索系统测试 ===")

    model_path = MODEL_PATH
    if not os.path.exists(model_path):
        print(f"错误: 模型路径不存在: {model_path}")
        print("请先下载模型到指定路径")
        return 1

    config_file = os.path.join(model_path, "config.json")
    if not os.path.exists(config_file):
        print(f"错误: 模型配置文件不存在: {config_file}")
        print("请确保模型文件已正确下载")
        return 1

    print(f"模型路径验证成功: {model_path}")

    tests = [test_text_retrieval, test_image_retrieval, test_cross_modal_retrieval]
    failed = []
    for test in tests:
        try:
            test(model_path)
        except Exception as e:  # top-level boundary: report and continue
            print(f"测试过程中发生错误: {str(e)}")
            traceback.print_exc()
            failed.append(test.__name__)

    if failed:
        print(f"\n失败的测试: {', '.join(failed)}")
        return 1

    print("\n所有测试完成!")
    return 0


if __name__ == "__main__":
    # Propagate the result as the process exit status so CI can detect failures.
    sys.exit(main())