#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Test the multimodal retrieval system built on a local model and a FAISS vector database.
"""

import os
import sys
import logging
from pathlib import Path
import time

from PIL import Image
import numpy as np

# Enable offline mode BEFORE importing the retrieval module. Environment
# flags of this kind only take effect if they are present when the library
# that reads them is first imported, so assigning it after the import below
# may be too late.
# NOTE(review): presumably multimodal_retrieval_local imports transformers at
# module load time -- confirm against that module.
os.environ['TRANSFORMERS_OFFLINE'] = '1'

from multimodal_retrieval_local import MultimodalRetrievalLocal  # noqa: E402

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def test_text_retrieval():
    """Exercise text indexing and text-to-text search.

    Builds a retrieval system on the ``local_faiss_text_test`` index, adds a
    handful of sample sentences, prints the system statistics, runs a few
    text queries (top 2 hits each), saves the index, and returns the
    retrieval object.
    """
    print("\n=== 测试文本检索 ===")

    # Build the retrieval system
    print("初始化检索系统...")
    engine = MultimodalRetrievalLocal(
        model_path="/root/models/Ops-MM-embedding-v1-7B",
        use_all_gpus=True,
        index_path="local_faiss_text_test",
    )

    # Sample corpus
    corpus = [
        "一只可爱的橘色猫咪在沙发上睡觉",
        "城市夜景中的高楼大厦和车流",
        "阳光明媚的海滩上,人们在冲浪和晒太阳",
        "美味的意大利面配红酒和沙拉",
        "雪山上滑雪的运动员",
    ]

    # Index the corpus
    print("\n添加文本到检索系统...")
    added_ids = engine.add_texts(corpus)
    print(f"添加了{len(added_ids)}条文本")

    # Report statistics
    print(f"检索系统统计信息: {engine.get_stats()}")

    # Run the text queries
    print("\n测试文本搜索...")
    for question in ("一只猫在睡觉", "都市风光", "海边的景色"):
        print(f"\n查询: {question}")
        hits = engine.search_by_text(question, k=2)
        for rank, hit in enumerate(hits, start=1):
            matched_text = hit.get('text', 'N/A')
            score = hit.get('score', 0)
            print(f" 结果 {rank}: {matched_text} (分数: {score:.4f})")

    # Persist the index
    print("\n保存索引...")
    engine.save_index()

    print("\n文本检索测试完成!")
    return engine


def test_image_retrieval():
    """Exercise image indexing and image-to-image search.

    Creates five solid-colour 224x224 RGB images, writes each one to the
    platform temporary directory, adds them (with metadata and file paths)
    to a retrieval system on the ``local_faiss_image_test`` index, prints
    the system statistics, searches with a red query image (top 2 hits,
    restricted to images), saves the index, and returns the retrieval
    object.
    """
    # Local import: the temp directory is resolved portably instead of
    # hard-coding "/tmp", which does not exist on every platform and
    # ignores TMPDIR.
    import tempfile

    print("\n=== 测试图像检索 ===")

    # Build the retrieval system
    print("初始化检索系统...")
    retrieval = MultimodalRetrievalLocal(
        model_path="/root/models/Ops-MM-embedding-v1-7B",
        use_all_gpus=True,
        index_path="local_faiss_image_test",
    )

    # Create the test images
    print("\n创建测试图像...")
    images = []
    image_paths = []
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
    tmp_dir = tempfile.gettempdir()

    for i, color in enumerate(colors):
        img = Image.new('RGB', (224, 224), color=color)
        images.append(img)

        # Save the image to disk so its path can be stored with it
        img_path = os.path.join(tmp_dir, f"test_image_{i}.png")
        img.save(img_path)
        image_paths.append(img_path)
        print(f"创建图像: {img_path}")

    # Add the images
    print("\n添加图像到检索系统...")
    metadatas = [
        {"description": f"测试图像 {number}"}
        for number, _ in enumerate(images, start=1)
    ]
    image_ids = retrieval.add_images(images, metadatas, image_paths)
    print(f"添加了{len(image_ids)}张图像")

    # Report statistics
    stats = retrieval.get_stats()
    print(f"检索系统统计信息: {stats}")

    # Image search
    print("\n测试图像搜索...")
    query_image = Image.new('RGB', (224, 224), color=(255, 0, 0))  # red image

    print("\n使用图像查询图像:")
    results = retrieval.search_by_image(query_image, k=2, filter_type="image")
    for rank, result in enumerate(results, start=1):
        print(f" 结果 {rank}: {result.get('description', 'N/A')} (分数: {result.get('score', 0):.4f})")

    # Persist the index
    print("\n保存索引...")
    retrieval.save_index()

    print("\n图像检索测试完成!")
    return retrieval


def test_cross_modal_retrieval():
    """Exercise text-to-image and image-to-text search.

    Adds five colour-themed sentences and five solid-colour 224x224 RGB
    images (each with a description) to a retrieval system on the
    ``local_faiss_cross_modal_test`` index, prints the system statistics,
    runs one text query restricted to image results and one image query
    restricted to text results (top 2 hits each), saves the index, and
    returns the retrieval object.
    """
    print("\n=== 测试跨模态检索 ===")

    # Build the retrieval system
    print("初始化检索系统...")
    engine = MultimodalRetrievalLocal(
        model_path="/root/models/Ops-MM-embedding-v1-7B",
        use_all_gpus=True,
        index_path="local_faiss_cross_modal_test",
    )

    # Index the sentences
    sentences = [
        "一只红色的苹果",
        "绿色的草地",
        "蓝色的大海",
        "黄色的向日葵",
        "青色的天空",
    ]
    print("\n添加文本到检索系统...")
    sentence_ids = engine.add_texts(sentences)
    print(f"添加了{len(sentence_ids)}条文本")

    # Index the images
    print("\n添加图像到检索系统...")
    palette = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
    labels = ["红色图像", "绿色图像", "蓝色图像", "黄色图像", "青色图像"]
    pictures = [Image.new('RGB', (224, 224), color=rgb) for rgb in palette]
    picture_meta = [{"description": label} for label in labels]
    picture_ids = engine.add_images(pictures, picture_meta)
    print(f"添加了{len(picture_ids)}张图像")

    # Report statistics
    print(f"检索系统统计信息: {engine.get_stats()}")

    # Text query -> image results
    print("\n测试文搜图...")
    text_query = "红色"
    print(f"查询文本: {text_query}")
    image_hits = engine.search_by_text(text_query, k=2, filter_type="image")
    for rank, hit in enumerate(image_hits, start=1):
        label = hit.get('description', 'N/A')
        score = hit.get('score', 0)
        print(f" 结果 {rank}: {label} (分数: {score:.4f})")

    # Image query -> text results
    print("\n测试图搜文...")
    blue_query = Image.new('RGB', (224, 224), color=(0, 0, 255))  # blue image
    print("查询图像: 蓝色图像")
    text_hits = engine.search_by_image(blue_query, k=2, filter_type="text")
    for rank, hit in enumerate(text_hits, start=1):
        sentence = hit.get('text', 'N/A')
        score = hit.get('score', 0)
        print(f" 结果 {rank}: {sentence} (分数: {score:.4f})")

    # Persist the index
    print("\n保存索引...")
    engine.save_index()

    print("\n跨模态检索测试完成!")
    return engine


def main():
    """Validate the local model directory and run all retrieval tests.

    Checks that the model directory and its ``config.json`` exist, then runs
    the text, image, and cross-modal tests in that order. Any exception
    raised by a test is reported with its traceback instead of propagating.

    Returns:
        int: 0 when every test completed, 1 when the model files are
        missing or a test raised. The value is meant to be used as the
        process exit status so a calling shell or CI job can detect failure.
    """
    print("=== 本地多模态检索系统测试 ===")

    # Check the model directory
    model_path = "/root/models/Ops-MM-embedding-v1-7B"
    if not os.path.exists(model_path):
        print(f"错误: 模型路径不存在: {model_path}")
        print("请先下载模型到指定路径")
        return 1

    # Check the model configuration file
    config_file = os.path.join(model_path, "config.json")
    if not os.path.exists(config_file):
        print(f"错误: 模型配置文件不存在: {config_file}")
        print("请确保模型文件已正确下载")
        return 1

    print(f"模型路径验证成功: {model_path}")

    # Run the tests
    try:
        # Text retrieval
        test_text_retrieval()

        # Image retrieval
        test_image_retrieval()

        # Cross-modal retrieval
        test_cross_modal_retrieval()
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the
        # return value rather than letting the script end "successfully".
        import traceback

        print(f"测试过程中发生错误: {e}")
        traceback.print_exc()
        return 1

    print("\n所有测试完成!")
    return 0


if __name__ == "__main__":
    sys.exit(main())