236 lines
7.0 KiB
Python
236 lines
7.0 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
快速测试脚本 - 验证多模态检索系统功能
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import logging
|
||
import traceback
|
||
from pathlib import Path
|
||
|
||
# 设置日志
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger(__name__)
|
||
|
||
def test_imports():
|
||
"""测试关键模块导入"""
|
||
logger.info("🔍 测试模块导入...")
|
||
|
||
try:
|
||
import torch
|
||
logger.info(f"✅ PyTorch {torch.__version__}")
|
||
|
||
import transformers
|
||
logger.info(f"✅ Transformers {transformers.__version__}")
|
||
|
||
import numpy as np
|
||
logger.info(f"✅ NumPy {np.__version__}")
|
||
|
||
from PIL import Image
|
||
logger.info("✅ Pillow")
|
||
|
||
import flask
|
||
logger.info(f"✅ Flask {flask.__version__}")
|
||
|
||
try:
|
||
import pymochow
|
||
logger.info("✅ PyMochow (百度VDB SDK)")
|
||
except ImportError:
|
||
logger.warning("⚠️ PyMochow 未安装,需要运行: pip install pymochow")
|
||
|
||
try:
|
||
import pymongo
|
||
logger.info("✅ PyMongo")
|
||
except ImportError:
|
||
logger.warning("⚠️ PyMongo 未安装,需要运行: pip install pymongo")
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 模块导入失败: {str(e)}")
|
||
return False
|
||
|
||
def test_gpu_availability():
|
||
"""测试GPU可用性"""
|
||
logger.info("🖥️ 检查GPU环境...")
|
||
|
||
try:
|
||
import torch
|
||
|
||
if torch.cuda.is_available():
|
||
gpu_count = torch.cuda.device_count()
|
||
logger.info(f"✅ 检测到 {gpu_count} 个GPU")
|
||
|
||
for i in range(gpu_count):
|
||
gpu_name = torch.cuda.get_device_name(i)
|
||
gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
|
||
logger.info(f" GPU {i}: {gpu_name} ({gpu_memory:.1f}GB)")
|
||
|
||
return True
|
||
else:
|
||
logger.info("ℹ️ 未检测到GPU,将使用CPU")
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ GPU检查失败: {str(e)}")
|
||
return False
|
||
|
||
def test_baidu_vdb_connection():
|
||
"""测试百度VDB连接"""
|
||
logger.info("🔗 测试百度VDB连接...")
|
||
|
||
try:
|
||
import pymochow
|
||
from pymochow.configuration import Configuration
|
||
from pymochow.auth.bce_credentials import BceCredentials
|
||
|
||
# 连接配置
|
||
account = "root"
|
||
api_key = "vdb$yjr9ln3n0td"
|
||
endpoint = "http://180.76.96.191:5287"
|
||
|
||
config = Configuration(
|
||
credentials=BceCredentials(account, api_key),
|
||
endpoint=endpoint
|
||
)
|
||
|
||
client = pymochow.MochowClient(config)
|
||
|
||
# 测试连接 - 列出数据库
|
||
databases = client.list_databases()
|
||
logger.info(f"✅ VDB连接成功,发现 {len(databases)} 个数据库")
|
||
|
||
client.close()
|
||
return True
|
||
|
||
except ImportError:
|
||
logger.error("❌ PyMochow 未安装,无法测试VDB连接")
|
||
return False
|
||
except Exception as e:
|
||
logger.error(f"❌ VDB连接失败: {str(e)}")
|
||
return False
|
||
|
||
def test_model_loading():
|
||
"""测试模型加载"""
|
||
logger.info("🤖 测试模型加载...")
|
||
|
||
try:
|
||
from ops_mm_embedding_v1 import OpsMMEmbeddingV1
|
||
|
||
logger.info("正在初始化模型...")
|
||
model = OpsMMEmbeddingV1()
|
||
|
||
# 测试文本编码
|
||
test_texts = ["测试文本"]
|
||
embeddings = model.embed(texts=test_texts)
|
||
|
||
logger.info(f"✅ 模型加载成功,向量维度: {embeddings.shape}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 模型加载失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return False
|
||
|
||
def test_web_app_import():
|
||
"""测试Web应用导入"""
|
||
logger.info("🌐 测试Web应用模块...")
|
||
|
||
try:
|
||
# 测试导入主要模块
|
||
from multimodal_retrieval_vdb_only import MultimodalRetrievalVDBOnly
|
||
logger.info("✅ 多模态检索系统模块")
|
||
|
||
from baidu_vdb_production import BaiduVDBProduction
|
||
logger.info("✅ 百度VDB后端模块")
|
||
|
||
# 测试Web应用文件存在
|
||
web_app_file = Path("web_app_vdb_production.py")
|
||
if web_app_file.exists():
|
||
logger.info("✅ Web应用文件存在")
|
||
else:
|
||
logger.error("❌ Web应用文件不存在")
|
||
return False
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ Web应用模块测试失败: {str(e)}")
|
||
return False
|
||
|
||
def create_test_directories():
|
||
"""创建必要的测试目录"""
|
||
logger.info("📁 创建测试目录...")
|
||
|
||
directories = ["uploads", "sample_images", "text_data"]
|
||
|
||
for dir_name in directories:
|
||
dir_path = Path(dir_name)
|
||
dir_path.mkdir(exist_ok=True)
|
||
logger.info(f"✅ 目录已创建: {dir_name}")
|
||
|
||
def main():
|
||
"""主测试函数"""
|
||
logger.info("🚀 开始快速测试...")
|
||
logger.info("=" * 50)
|
||
|
||
test_results = {}
|
||
|
||
# 1. 测试模块导入
|
||
test_results["imports"] = test_imports()
|
||
|
||
# 2. 测试GPU环境
|
||
test_results["gpu"] = test_gpu_availability()
|
||
|
||
# 3. 测试VDB连接
|
||
test_results["vdb"] = test_baidu_vdb_connection()
|
||
|
||
# 4. 测试Web应用模块
|
||
test_results["web_modules"] = test_web_app_import()
|
||
|
||
# 5. 创建测试目录
|
||
create_test_directories()
|
||
|
||
# 6. 尝试测试模型加载(可选)
|
||
if test_results["imports"]:
|
||
logger.info("\n⚠️ 模型加载测试需要较长时间,是否跳过?")
|
||
logger.info("如需测试模型,请单独运行模型测试")
|
||
# test_results["model"] = test_model_loading()
|
||
|
||
# 输出测试结果
|
||
logger.info("\n" + "=" * 50)
|
||
logger.info("📊 测试结果汇总:")
|
||
logger.info("=" * 50)
|
||
|
||
for test_name, result in test_results.items():
|
||
status = "✅ 通过" if result else "❌ 失败"
|
||
test_display = {
|
||
"imports": "模块导入",
|
||
"gpu": "GPU环境",
|
||
"vdb": "VDB连接",
|
||
"web_modules": "Web模块",
|
||
"model": "模型加载"
|
||
}.get(test_name, test_name)
|
||
|
||
logger.info(f"{test_display}: {status}")
|
||
|
||
# 计算成功率
|
||
success_count = sum(test_results.values())
|
||
total_count = len(test_results)
|
||
success_rate = (success_count / total_count) * 100
|
||
|
||
logger.info(f"\n总体成功率: {success_count}/{total_count} ({success_rate:.1f}%)")
|
||
|
||
if success_rate >= 75:
|
||
logger.info("🎉 系统基本就绪!可以启动Web应用进行完整测试")
|
||
logger.info("运行命令: python web_app_vdb_production.py")
|
||
else:
|
||
logger.warning("⚠️ 系统存在问题,请检查失败的测试项")
|
||
|
||
return test_results
|
||
|
||
if __name__ == "__main__":
|
||
main()
|