#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 快速测试脚本 - 验证多模态检索系统功能 """ import os import sys import logging import traceback from pathlib import Path # 设置日志 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) def test_imports(): """测试关键模块导入""" logger.info("🔍 测试模块导入...") try: import torch logger.info(f"✅ PyTorch {torch.__version__}") import transformers logger.info(f"✅ Transformers {transformers.__version__}") import numpy as np logger.info(f"✅ NumPy {np.__version__}") from PIL import Image logger.info("✅ Pillow") import flask logger.info(f"✅ Flask {flask.__version__}") try: import pymochow logger.info("✅ PyMochow (百度VDB SDK)") except ImportError: logger.warning("⚠️ PyMochow 未安装,需要运行: pip install pymochow") try: import pymongo logger.info("✅ PyMongo") except ImportError: logger.warning("⚠️ PyMongo 未安装,需要运行: pip install pymongo") return True except Exception as e: logger.error(f"❌ 模块导入失败: {str(e)}") return False def test_gpu_availability(): """测试GPU可用性""" logger.info("🖥️ 检查GPU环境...") try: import torch if torch.cuda.is_available(): gpu_count = torch.cuda.device_count() logger.info(f"✅ 检测到 {gpu_count} 个GPU") for i in range(gpu_count): gpu_name = torch.cuda.get_device_name(i) gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3 logger.info(f" GPU {i}: {gpu_name} ({gpu_memory:.1f}GB)") return True else: logger.info("ℹ️ 未检测到GPU,将使用CPU") return False except Exception as e: logger.error(f"❌ GPU检查失败: {str(e)}") return False def test_baidu_vdb_connection(): """测试百度VDB连接""" logger.info("🔗 测试百度VDB连接...") try: import pymochow from pymochow.configuration import Configuration from pymochow.auth.bce_credentials import BceCredentials # 连接配置 account = "root" api_key = "vdb$yjr9ln3n0td" endpoint = "http://180.76.96.191:5287" config = Configuration( credentials=BceCredentials(account, api_key), endpoint=endpoint ) client = pymochow.MochowClient(config) # 测试连接 - 列出数据库 databases = client.list_databases() logger.info(f"✅ VDB连接成功,发现 {len(databases)} 个数据库") client.close() return True except ImportError: logger.error("❌ PyMochow 未安装,无法测试VDB连接") return False except Exception as e: logger.error(f"❌ VDB连接失败: {str(e)}") return False def test_model_loading(): """测试模型加载""" logger.info("🤖 测试模型加载...") try: from ops_mm_embedding_v1 import OpsMMEmbeddingV1 logger.info("正在初始化模型...") model = OpsMMEmbeddingV1() # 测试文本编码 test_texts = ["测试文本"] embeddings = model.embed(texts=test_texts) logger.info(f"✅ 模型加载成功,向量维度: {embeddings.shape}") return True except Exception as e: logger.error(f"❌ 模型加载失败: {str(e)}") logger.error(traceback.format_exc()) return False def test_web_app_import(): """测试Web应用导入""" logger.info("🌐 测试Web应用模块...") try: # 测试导入主要模块 from multimodal_retrieval_vdb_only import MultimodalRetrievalVDBOnly logger.info("✅ 多模态检索系统模块") from baidu_vdb_production import BaiduVDBProduction logger.info("✅ 百度VDB后端模块") # 测试Web应用文件存在 web_app_file = Path("web_app_vdb_production.py") if web_app_file.exists(): logger.info("✅ Web应用文件存在") else: logger.error("❌ Web应用文件不存在") return False return True except Exception as e: logger.error(f"❌ Web应用模块测试失败: {str(e)}") return False def create_test_directories(): """创建必要的测试目录""" logger.info("📁 创建测试目录...") directories = ["uploads", "sample_images", "text_data"] for dir_name in directories: dir_path = Path(dir_name) dir_path.mkdir(exist_ok=True) logger.info(f"✅ 目录已创建: {dir_name}") def main(): """主测试函数""" logger.info("🚀 开始快速测试...") logger.info("=" * 50) test_results = {} # 1. 测试模块导入 test_results["imports"] = test_imports() # 2. 测试GPU环境 test_results["gpu"] = test_gpu_availability() # 3. 测试VDB连接 test_results["vdb"] = test_baidu_vdb_connection() # 4. 测试Web应用模块 test_results["web_modules"] = test_web_app_import() # 5. 创建测试目录 create_test_directories() # 6. 尝试测试模型加载(可选) if test_results["imports"]: logger.info("\n⚠️ 模型加载测试需要较长时间,是否跳过?") logger.info("如需测试模型,请单独运行模型测试") # test_results["model"] = test_model_loading() # 输出测试结果 logger.info("\n" + "=" * 50) logger.info("📊 测试结果汇总:") logger.info("=" * 50) for test_name, result in test_results.items(): status = "✅ 通过" if result else "❌ 失败" test_display = { "imports": "模块导入", "gpu": "GPU环境", "vdb": "VDB连接", "web_modules": "Web模块", "model": "模型加载" }.get(test_name, test_name) logger.info(f"{test_display}: {status}") # 计算成功率 success_count = sum(test_results.values()) total_count = len(test_results) success_rate = (success_count / total_count) * 100 logger.info(f"\n总体成功率: {success_count}/{total_count} ({success_rate:.1f}%)") if success_rate >= 75: logger.info("🎉 系统基本就绪!可以启动Web应用进行完整测试") logger.info("运行命令: python web_app_vdb_production.py") else: logger.warning("⚠️ 系统存在问题,请检查失败的测试项") return test_results if __name__ == "__main__": main()