RoboTwin_image/script/env_prerequisite_checker.py
2025-07-02 03:13:07 +00:00

241 lines
9.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import subprocess
import sys
import time
import platform
import re
import logging
# 获取当前脚本所在目录的父目录(项目根目录)
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
# 配置日志 - 使用相对路径
log_file = os.path.join(project_root, 'environment_check.log')
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler(log_file)
]
)
def run_command(command, description=None, show_output=False):
"""运行命令并打印状态"""
if description:
logging.info(f"=== {description} ===")
logging.info(f"执行: {command}")
try:
# 明确指定使用bash执行命令
result = subprocess.run(['/bin/bash', '-c', command],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
output = result.stdout.decode('utf-8')
logging.info(f"命令执行成功: {command}")
if show_output:
logging.info(f"输出:\n{output}")
return True, output
except subprocess.CalledProcessError as e:
error = e.stderr.decode('utf-8')
logging.error(f"命令执行失败: {command}")
logging.error(f"错误信息: {error}")
return False, error
def check_dependencies():
"""检查系统是否满足RoboTwin的依赖要求"""
results = []
all_passed = True
# 打印依赖要求
logging.info("\n=== RoboTwin Dependencies ===")
logging.info("""
Python versions:
* Python 3.8, 3.10
Operating systems:
* Linux: Ubuntu 18.04+, Centos 7+
Hardware:
* Rendering: NVIDIA or AMD GPU
* Recommended CUDA Version: 12.1
* Ray tracing: NVIDIA RTX GPU or AMD equivalent
* Ray-tracing Denoising: NVIDIA GPU
* GPU Simulation: NVIDIA GPU
Software:
* Ray tracing: NVIDIA Driver >= 470
* Denoising (OIDN): NVIDIA Driver >= 520
* CUDA Version: 12.1
* Conda: Required
""")
# 检查操作系统
os_check = {"name": "操作系统检查", "passed": False, "details": ""}
os_name = platform.system()
if os_name != "Linux":
os_check["details"] = f"不支持的操作系统: {os_name},需要 Linux"
else:
# 检查Linux发行版
success, distro_info = run_command("cat /etc/os-release", show_output=False)
if success:
# 检查是否是Ubuntu或CentOS
if "Ubuntu" in distro_info:
# 提取Ubuntu版本
match = re.search(r'VERSION_ID="(\d+\.\d+)"', distro_info)
if match and float(match.group(1)) >= 18.04:
os_check["passed"] = True
os_check["details"] = f"Ubuntu {match.group(1)},满足要求 (Ubuntu 18.04+)"
else:
os_check["details"] = "Ubuntu版本过低需要 Ubuntu 18.04+"
elif "CentOS" in distro_info:
# 提取CentOS版本
match = re.search(r'VERSION_ID="(\d+)"', distro_info)
if match and int(match.group(1)) >= 7:
os_check["passed"] = True
os_check["details"] = f"CentOS {match.group(1)},满足要求 (CentOS 7+)"
else:
os_check["details"] = "CentOS版本过低需要 CentOS 7+"
else:
os_check["details"] = "不支持的Linux发行版需要 Ubuntu 18.04+ 或 CentOS 7+"
else:
os_check["details"] = "无法确定Linux发行版"
results.append(os_check)
all_passed = all_passed and os_check["passed"]
# 检查Conda是否已安装
conda_check = {"name": "Conda检查", "passed": False, "details": ""}
success, conda_info = run_command("conda --version", "检查Conda是否已安装", show_output=False)
if success:
# 提取Conda版本
match = re.search(r"conda (\d+\.\d+\.\d+)", conda_info)
if match:
conda_version = match.group(1)
conda_check["passed"] = True
conda_check["details"] = f"Conda已安装 (版本 {conda_version})"
# 检查conda初始化状态
success, conda_init = run_command("conda info | grep -i 'initialized'", "检查Conda初始化状态", show_output=False)
if success and "yes" in conda_init.lower():
conda_check["details"] += ", 已初始化"
else:
conda_check["details"] += ", 但可能未正确初始化,请运行 'conda init bash'"
else:
conda_check["details"] = "无法确定Conda版本但似乎已安装"
else:
conda_check["details"] = "未安装Conda请先安装Conda"
results.append(conda_check)
all_passed = all_passed and conda_check["passed"]
# 检查NVIDIA GPU
gpu_check = {"name": "GPU检查", "passed": False, "details": ""}
success, gpu_info = run_command("nvidia-smi", "检查NVIDIA GPU", show_output=False)
if success:
# 检查是否检测到NVIDIA GPU
if "NVIDIA" in gpu_info:
gpu_check["passed"] = True
gpu_check["details"] = "检测到NVIDIA GPU"
else:
gpu_check["details"] = "未检测到NVIDIA GPU"
else:
# 尝试检查AMD GPU
success, amd_gpu_info = run_command("lspci | grep -i amd", "检查AMD GPU", show_output=False)
if success and "AMD" in amd_gpu_info and ("Graphics" in amd_gpu_info or "VGA" in amd_gpu_info):
gpu_check["passed"] = True
gpu_check["details"] = "检测到AMD GPU"
else:
gpu_check["details"] = "未检测到NVIDIA或AMD GPU需要NVIDIA或AMD GPU"
results.append(gpu_check)
all_passed = all_passed and gpu_check["passed"]
# 检查NVIDIA驱动版本
driver_check = {"name": "NVIDIA驱动版本检查", "passed": False, "details": ""}
if success and "NVIDIA" in gpu_info:
match = re.search(r"Driver Version: (\d+\.\d+)", gpu_info)
if match:
driver_version = float(match.group(1))
if driver_version >= 470:
driver_check["passed"] = True
driver_check["details"] = f"NVIDIA驱动版本 {driver_version},满足要求 (>=470)"
# 检查是否满足光线追踪去噪要求
if driver_version >= 520:
driver_check["details"] += ",同时满足去噪要求 (>=520)"
else:
driver_check["details"] += ",但不满足去噪要求 (>=520)"
else:
driver_check["details"] = f"NVIDIA驱动版本 {driver_version} 过低,需要 >=470"
else:
driver_check["details"] = "无法确定NVIDIA驱动版本"
elif gpu_check["passed"]:
driver_check["passed"] = True
driver_check["details"] = "使用AMD GPU跳过NVIDIA驱动版本检查"
else:
driver_check["details"] = "未检测到NVIDIA GPU无法检查驱动版本"
results.append(driver_check)
all_passed = all_passed and driver_check["passed"]
# 检查CUDA版本
cuda_check = {"name": "CUDA版本检查", "passed": False, "details": ""}
if success and "NVIDIA" in gpu_info:
# 尝试获取CUDA版本
success, cuda_info = run_command("nvcc --version", "检查CUDA版本", show_output=False)
if success:
match = re.search(r"release (\d+\.\d+)", cuda_info)
if match:
cuda_version = float(match.group(1))
if abs(cuda_version - 12.1) < 0.1: # 允许小的版本差异比如12.1和12.0
cuda_check["passed"] = True
cuda_check["details"] = f"CUDA版本 {cuda_version},满足建议版本 (12.1)"
else:
cuda_check["passed"] = True # 仍然通过,但发出警告
cuda_check["details"] = f"CUDA版本 {cuda_version},不是推荐的 12.1,可能会遇到兼容性问题"
else:
cuda_check["details"] = "无法确定CUDA版本"
else:
# 尝试使用nvidia-smi确定CUDA版本
match = re.search(r"CUDA Version: (\d+\.\d+)", gpu_info)
if match:
cuda_version = float(match.group(1))
if abs(cuda_version - 12.1) < 0.1:
cuda_check["passed"] = True
cuda_check["details"] = f"CUDA版本 {cuda_version},满足建议版本 (12.1)"
else:
cuda_check["passed"] = True # 仍然通过,但发出警告
cuda_check["details"] = f"CUDA版本 {cuda_version},不是推荐的 12.1,可能会遇到兼容性问题"
else:
cuda_check["details"] = "未安装CUDA或无法确定版本"
elif gpu_check["passed"]:
cuda_check["passed"] = True
cuda_check["details"] = "使用AMD GPU跳过CUDA版本检查"
else:
cuda_check["details"] = "未检测到NVIDIA GPU无法检查CUDA版本"
results.append(cuda_check)
all_passed = all_passed and cuda_check["passed"]
# 打印检查结果
logging.info("\n=== 环境依赖检查结果 ===")
for result in results:
status = "" if result["passed"] else ""
logging.info(f"{status} {result['name']}: {result['details']}")
if not all_passed:
logging.warning("\n警告:某些依赖项不满足要求,环境配置可能无法正常完成或运行。")
return all_passed, results
if __name__ == "__main__":
all_passed, results = check_dependencies()
# 将结果保存到项目根目录
import json
results_file = os.path.join(project_root, 'environment_check_results.json')
with open(results_file, 'w') as f:
json.dump({"dependencies_met": all_passed, "check_results": results}, f)
# 退出状态码
if all_passed:
sys.exit(0)
else:
sys.exit(1)