221 lines
7.5 KiB
Python
221 lines
7.5 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
VDB集成测试 - 使用现有的多模态系统测试VDB功能
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import numpy as np
|
||
import json
|
||
import time
|
||
import logging
|
||
from PIL import Image
|
||
|
||
# 设置日志
|
||
logging.basicConfig(level=logging.INFO)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
def test_vdb_with_existing_system():
    """Run an end-to-end smoke test of the existing multimodal retrieval system.

    The steps are, in order: instantiate ``MultiGPUMultimodalRetrieval``,
    encode a small batch of texts, build the text index, and run a
    text-to-text query. If image files are found under ``sample_images/``,
    up to three of them are encoded, an image index is built and a
    text-to-image query is run; otherwise the image steps are skipped.

    Returns:
        bool: True when every executed step finished without raising,
        False otherwise (the error message and traceback are printed).
    """
    print("=" * 60)
    print("VDB集成测试 - 使用现有多模态系统")
    print("=" * 60)

    try:
        # Local imports: a missing project module is reported as a failed
        # test by the handler below. ``glob`` is imported once here rather
        # than on every iteration of the extension loop.
        import glob

        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval

        print("1. 初始化多模态检索系统...")
        retrieval_system = MultiGPUMultimodalRetrieval()

        if retrieval_system.model is None:
            raise RuntimeError("模型加载失败")

        print("✅ 多模态系统初始化成功")

        # Test data
        test_texts = [
            "一只可爱的小猫在阳光下睡觉",
            "现代化的城市建筑群",
            "美丽的自然风景和山脉",
            "科技产品和电子设备",
            "传统的中国文化艺术",
        ]

        print("\n2. 测试文本编码功能...")
        text_vectors = retrieval_system.encode_text_batch(test_texts)
        print(f"✅ 文本向量生成成功: {text_vectors.shape}")

        # Build the text index
        print("\n3. 构建文本索引...")
        retrieval_system.build_text_index_parallel(test_texts)
        print("✅ 文本索引构建完成")

        # Text-to-text search
        print("\n4. 测试文搜文功能...")
        query = "小动物"
        results = retrieval_system.search_text_by_text(query, top_k=3)
        print(f"查询: {query}")
        for i, (text, score) in enumerate(results, 1):
            print(f" {i}. {text} (相似度: {score:.4f})")

        # Image tests only run when sample images are available
        image_files = []
        sample_dir = "sample_images"
        if os.path.exists(sample_dir):
            for ext in ("jpg", "jpeg", "png", "gif"):
                pattern = os.path.join(sample_dir, f"*.{ext}")
                image_files.extend(glob.glob(pattern))

        if image_files:
            print("\n5. 测试图像编码功能...")
            # Only the first three images are used
            test_images = image_files[:3]
            image_vectors = retrieval_system.encode_image_batch(test_images)
            print(f"✅ 图像向量生成成功: {image_vectors.shape}")

            print("\n6. 构建图像索引...")
            retrieval_system.build_image_index_parallel(test_images)
            print("✅ 图像索引构建完成")

            # Text-to-image search
            print("\n7. 测试文搜图功能...")
            query = "图片"
            results = retrieval_system.search_images_by_text(query, top_k=2)
            print(f"查询: {query}")
            for i, (image_path, score) in enumerate(results, 1):
                print(f" {i}. {os.path.basename(image_path)} (相似度: {score:.4f})")
        else:
            print("\n5. 跳过图像测试 - 未找到图像文件")

        print("\n✅ 多模态系统功能测试完成!")
        print("系统支持的检索模式:")
        print("- 文搜文: ✅")
        # NOTE(review): only text-to-image is actually exercised above; the
        # image-to-text and image-to-image lines are reported from the mere
        # presence of image files.
        print("- 文搜图: ✅" if image_files else "- 文搜图: ⚠️ (需要图像数据)")
        print("- 图搜文: ✅" if image_files else "- 图搜文: ⚠️ (需要图像数据)")
        print("- 图搜图: ✅" if image_files else "- 图搜图: ⚠️ (需要图像数据)")

        return True

    except Exception as e:
        # Top-level test boundary: report the failure instead of crashing.
        print(f"❌ 测试失败: {e}")
        import traceback

        traceback.print_exc()
        return False
|
||
|
||
def test_vdb_connection_only():
    """Check connectivity to the VDB service only.

    Creates a ``pymochow`` client, lists the databases, creates a scratch
    database named ``test_multimodal_<unix-timestamp>`` and drops it again.
    The client is closed even when one of these calls raises.

    Connection settings default to the values that were previously
    hard-coded here and can be overridden with the ``VDB_ACCOUNT``,
    ``VDB_API_KEY`` and ``VDB_ENDPOINT`` environment variables.

    Returns:
        bool: True when all steps succeed, False on any exception (the
        error message is printed).
    """
    print("\n" + "=" * 60)
    print("VDB连接测试")
    print("=" * 60)

    try:
        import pymochow
        from pymochow.configuration import Configuration
        from pymochow.auth.bce_credentials import BceCredentials

        # VDB configuration.
        # NOTE(review): an API key is committed in source as the fallback
        # value; it should be rotated and supplied via VDB_API_KEY instead.
        account = os.environ.get("VDB_ACCOUNT", "root")
        api_key = os.environ.get("VDB_API_KEY", "vdb$yjr9ln3n0td")
        endpoint = os.environ.get("VDB_ENDPOINT", "http://180.76.96.191:5287")

        print("1. 测试VDB连接...")
        config = Configuration(
            credentials=BceCredentials(account, api_key),
            endpoint=endpoint,
        )
        client = pymochow.MochowClient(config)

        try:
            print("2. 查询数据库列表...")
            db_list = client.list_databases()
            print(f"✅ VDB连接成功,数据库数量: {len(db_list)}")

            # Create a scratch database with a timestamped name
            test_db_name = f"test_multimodal_{int(time.time())}"
            print(f"3. 创建测试数据库: {test_db_name}")
            client.create_database(test_db_name)
            print("✅ 测试数据库创建成功")

            # Remove the scratch database again
            print("4. 清理测试数据库...")
            client.drop_database(test_db_name)
            print("✅ 测试数据库清理完成")
        finally:
            # Release the client even if one of the calls above failed.
            client.close()

        print("✅ VDB连接测试完成")

        return True

    except Exception as e:
        # Top-level test boundary: report the failure instead of crashing.
        print(f"❌ VDB连接测试失败: {e}")
        return False
|
||
|
||
def create_sample_data(text_dir="text_data"):
    """Write the built-in sample texts to ``<text_dir>/sample_texts.json``.

    The directory is created if it does not exist. The file is written as
    UTF-8 JSON (a list of strings) with non-ASCII characters kept as-is.

    Args:
        text_dir: Directory that receives ``sample_texts.json``. Defaults to
            ``"text_data"``, the location this function always used before.

    Returns:
        bool: True when the file was written, False on any error (the
        error message is printed).
    """
    print("\n" + "=" * 60)
    print("创建示例数据")
    print("=" * 60)

    try:
        # Sample text data
        sample_texts = [
            "一只橙色的小猫在花园里玩耍",
            "现代化的摩天大楼群在夜晚闪闪发光",
            "壮丽的山脉和清澈的湖水",
            "最新的智能手机和平板电脑",
            "传统的中国书法艺术作品",
            "美味的中式料理和茶文化",
            "春天的樱花盛开景象",
            "海边的日落和波浪",
            "森林中的小径和绿色植物",
            "城市公园里的人们在锻炼",
        ]

        # Persist the texts
        os.makedirs(text_dir, exist_ok=True)

        text_file = os.path.join(text_dir, "sample_texts.json")
        with open(text_file, 'w', encoding='utf-8') as f:
            json.dump(sample_texts, f, ensure_ascii=False, indent=2)

        print(f"✅ 创建示例文本数据: {len(sample_texts)} 条")
        print(f" 保存位置: {text_file}")

        return True

    except Exception as e:
        # Best-effort setup step: report the failure and let the caller decide.
        print(f"❌ 创建示例数据失败: {e}")
        return False
|
||
|
||
if __name__ == "__main__":
    # Script entry point: create sample data, check the VDB connection and,
    # only if that works, run the multimodal system test. The process exit
    # status is non-zero when either test fails so shells/CI can detect it.
    print("🚀 开始VDB集成测试")

    # 1. Create sample data (setup step; its result does not gate the tests)
    create_sample_data()

    # 2. Test the VDB connection
    vdb_ok = test_vdb_connection_only()

    # 3. Test the multimodal system
    multimodal_ok = False
    if vdb_ok:
        multimodal_ok = test_vdb_with_existing_system()

        if multimodal_ok:
            print("\n🎉 所有测试通过!")
            print("✅ VDB连接正常")
            print("✅ 多模态系统功能正常")
            print("✅ 向量编码和检索功能正常")
            print("\n📋 下一步建议:")
            print("1. 上传更多图像和文本数据")
            print("2. 启动Web应用进行交互式测试")
            print("3. 测试跨模态检索功能")
        else:
            print("\n⚠️ 多模态系统测试失败")
    else:
        print("\n❌ VDB连接测试失败,请检查配置")

    print(f"\n测试完成时间: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Previously the script always exited with status 0, even on failure.
    sys.exit(0 if (vdb_ok and multimodal_ok) else 1)
|