241 lines
9.9 KiB
Python
241 lines
9.9 KiB
Python
import os
|
||
import subprocess
|
||
import sys
|
||
import time
|
||
import platform
|
||
import re
|
||
import logging
|
||
|
||
# 获取当前脚本所在目录的父目录(项目根目录)
|
||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||
project_root = os.path.dirname(script_dir)
|
||
|
||
# 配置日志 - 使用相对路径
|
||
log_file = os.path.join(project_root, 'environment_check.log')
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||
handlers=[
|
||
logging.StreamHandler(),
|
||
logging.FileHandler(log_file)
|
||
]
|
||
)
|
||
|
||
def run_command(command, description=None, show_output=False):
|
||
"""运行命令并打印状态"""
|
||
if description:
|
||
logging.info(f"=== {description} ===")
|
||
logging.info(f"执行: {command}")
|
||
|
||
try:
|
||
# 明确指定使用bash执行命令
|
||
result = subprocess.run(['/bin/bash', '-c', command],
|
||
check=True,
|
||
stdout=subprocess.PIPE,
|
||
stderr=subprocess.PIPE)
|
||
output = result.stdout.decode('utf-8')
|
||
logging.info(f"命令执行成功: {command}")
|
||
if show_output:
|
||
logging.info(f"输出:\n{output}")
|
||
return True, output
|
||
except subprocess.CalledProcessError as e:
|
||
error = e.stderr.decode('utf-8')
|
||
logging.error(f"命令执行失败: {command}")
|
||
logging.error(f"错误信息: {error}")
|
||
return False, error
|
||
|
||
def check_dependencies():
|
||
"""检查系统是否满足RoboTwin的依赖要求"""
|
||
results = []
|
||
all_passed = True
|
||
|
||
# 打印依赖要求
|
||
logging.info("\n=== RoboTwin Dependencies ===")
|
||
logging.info("""
|
||
Python versions:
|
||
* Python 3.8, 3.10
|
||
|
||
Operating systems:
|
||
* Linux: Ubuntu 18.04+, Centos 7+
|
||
|
||
Hardware:
|
||
* Rendering: NVIDIA or AMD GPU
|
||
* Recommended CUDA Version: 12.1
|
||
* Ray tracing: NVIDIA RTX GPU or AMD equivalent
|
||
* Ray-tracing Denoising: NVIDIA GPU
|
||
* GPU Simulation: NVIDIA GPU
|
||
|
||
Software:
|
||
* Ray tracing: NVIDIA Driver >= 470
|
||
* Denoising (OIDN): NVIDIA Driver >= 520
|
||
* CUDA Version: 12.1
|
||
* Conda: Required
|
||
""")
|
||
|
||
# 检查操作系统
|
||
os_check = {"name": "操作系统检查", "passed": False, "details": ""}
|
||
os_name = platform.system()
|
||
if os_name != "Linux":
|
||
os_check["details"] = f"不支持的操作系统: {os_name},需要 Linux"
|
||
else:
|
||
# 检查Linux发行版
|
||
success, distro_info = run_command("cat /etc/os-release", show_output=False)
|
||
if success:
|
||
# 检查是否是Ubuntu或CentOS
|
||
if "Ubuntu" in distro_info:
|
||
# 提取Ubuntu版本
|
||
match = re.search(r'VERSION_ID="(\d+\.\d+)"', distro_info)
|
||
if match and float(match.group(1)) >= 18.04:
|
||
os_check["passed"] = True
|
||
os_check["details"] = f"Ubuntu {match.group(1)},满足要求 (Ubuntu 18.04+)"
|
||
else:
|
||
os_check["details"] = "Ubuntu版本过低,需要 Ubuntu 18.04+"
|
||
elif "CentOS" in distro_info:
|
||
# 提取CentOS版本
|
||
match = re.search(r'VERSION_ID="(\d+)"', distro_info)
|
||
if match and int(match.group(1)) >= 7:
|
||
os_check["passed"] = True
|
||
os_check["details"] = f"CentOS {match.group(1)},满足要求 (CentOS 7+)"
|
||
else:
|
||
os_check["details"] = "CentOS版本过低,需要 CentOS 7+"
|
||
else:
|
||
os_check["details"] = "不支持的Linux发行版,需要 Ubuntu 18.04+ 或 CentOS 7+"
|
||
else:
|
||
os_check["details"] = "无法确定Linux发行版"
|
||
results.append(os_check)
|
||
all_passed = all_passed and os_check["passed"]
|
||
|
||
# 检查Conda是否已安装
|
||
conda_check = {"name": "Conda检查", "passed": False, "details": ""}
|
||
success, conda_info = run_command("conda --version", "检查Conda是否已安装", show_output=False)
|
||
if success:
|
||
# 提取Conda版本
|
||
match = re.search(r"conda (\d+\.\d+\.\d+)", conda_info)
|
||
if match:
|
||
conda_version = match.group(1)
|
||
conda_check["passed"] = True
|
||
conda_check["details"] = f"Conda已安装 (版本 {conda_version})"
|
||
|
||
# 检查conda初始化状态
|
||
success, conda_init = run_command("conda info | grep -i 'initialized'", "检查Conda初始化状态", show_output=False)
|
||
if success and "yes" in conda_init.lower():
|
||
conda_check["details"] += ", 已初始化"
|
||
else:
|
||
conda_check["details"] += ", 但可能未正确初始化,请运行 'conda init bash'"
|
||
else:
|
||
conda_check["details"] = "无法确定Conda版本,但似乎已安装"
|
||
else:
|
||
conda_check["details"] = "未安装Conda,请先安装Conda"
|
||
results.append(conda_check)
|
||
all_passed = all_passed and conda_check["passed"]
|
||
|
||
# 检查NVIDIA GPU
|
||
gpu_check = {"name": "GPU检查", "passed": False, "details": ""}
|
||
success, gpu_info = run_command("nvidia-smi", "检查NVIDIA GPU", show_output=False)
|
||
if success:
|
||
# 检查是否检测到NVIDIA GPU
|
||
if "NVIDIA" in gpu_info:
|
||
gpu_check["passed"] = True
|
||
gpu_check["details"] = "检测到NVIDIA GPU"
|
||
else:
|
||
gpu_check["details"] = "未检测到NVIDIA GPU"
|
||
else:
|
||
# 尝试检查AMD GPU
|
||
success, amd_gpu_info = run_command("lspci | grep -i amd", "检查AMD GPU", show_output=False)
|
||
if success and "AMD" in amd_gpu_info and ("Graphics" in amd_gpu_info or "VGA" in amd_gpu_info):
|
||
gpu_check["passed"] = True
|
||
gpu_check["details"] = "检测到AMD GPU"
|
||
else:
|
||
gpu_check["details"] = "未检测到NVIDIA或AMD GPU,需要NVIDIA或AMD GPU"
|
||
results.append(gpu_check)
|
||
all_passed = all_passed and gpu_check["passed"]
|
||
|
||
# 检查NVIDIA驱动版本
|
||
driver_check = {"name": "NVIDIA驱动版本检查", "passed": False, "details": ""}
|
||
if success and "NVIDIA" in gpu_info:
|
||
match = re.search(r"Driver Version: (\d+\.\d+)", gpu_info)
|
||
if match:
|
||
driver_version = float(match.group(1))
|
||
if driver_version >= 470:
|
||
driver_check["passed"] = True
|
||
driver_check["details"] = f"NVIDIA驱动版本 {driver_version},满足要求 (>=470)"
|
||
|
||
# 检查是否满足光线追踪去噪要求
|
||
if driver_version >= 520:
|
||
driver_check["details"] += ",同时满足去噪要求 (>=520)"
|
||
else:
|
||
driver_check["details"] += ",但不满足去噪要求 (>=520)"
|
||
else:
|
||
driver_check["details"] = f"NVIDIA驱动版本 {driver_version} 过低,需要 >=470"
|
||
else:
|
||
driver_check["details"] = "无法确定NVIDIA驱动版本"
|
||
elif gpu_check["passed"]:
|
||
driver_check["passed"] = True
|
||
driver_check["details"] = "使用AMD GPU,跳过NVIDIA驱动版本检查"
|
||
else:
|
||
driver_check["details"] = "未检测到NVIDIA GPU,无法检查驱动版本"
|
||
results.append(driver_check)
|
||
all_passed = all_passed and driver_check["passed"]
|
||
|
||
# 检查CUDA版本
|
||
cuda_check = {"name": "CUDA版本检查", "passed": False, "details": ""}
|
||
if success and "NVIDIA" in gpu_info:
|
||
# 尝试获取CUDA版本
|
||
success, cuda_info = run_command("nvcc --version", "检查CUDA版本", show_output=False)
|
||
if success:
|
||
match = re.search(r"release (\d+\.\d+)", cuda_info)
|
||
if match:
|
||
cuda_version = float(match.group(1))
|
||
if abs(cuda_version - 12.1) < 0.1: # 允许小的版本差异,比如12.1和12.0
|
||
cuda_check["passed"] = True
|
||
cuda_check["details"] = f"CUDA版本 {cuda_version},满足建议版本 (12.1)"
|
||
else:
|
||
cuda_check["passed"] = True # 仍然通过,但发出警告
|
||
cuda_check["details"] = f"CUDA版本 {cuda_version},不是推荐的 12.1,可能会遇到兼容性问题"
|
||
else:
|
||
cuda_check["details"] = "无法确定CUDA版本"
|
||
else:
|
||
# 尝试使用nvidia-smi确定CUDA版本
|
||
match = re.search(r"CUDA Version: (\d+\.\d+)", gpu_info)
|
||
if match:
|
||
cuda_version = float(match.group(1))
|
||
if abs(cuda_version - 12.1) < 0.1:
|
||
cuda_check["passed"] = True
|
||
cuda_check["details"] = f"CUDA版本 {cuda_version},满足建议版本 (12.1)"
|
||
else:
|
||
cuda_check["passed"] = True # 仍然通过,但发出警告
|
||
cuda_check["details"] = f"CUDA版本 {cuda_version},不是推荐的 12.1,可能会遇到兼容性问题"
|
||
else:
|
||
cuda_check["details"] = "未安装CUDA或无法确定版本"
|
||
elif gpu_check["passed"]:
|
||
cuda_check["passed"] = True
|
||
cuda_check["details"] = "使用AMD GPU,跳过CUDA版本检查"
|
||
else:
|
||
cuda_check["details"] = "未检测到NVIDIA GPU,无法检查CUDA版本"
|
||
results.append(cuda_check)
|
||
all_passed = all_passed and cuda_check["passed"]
|
||
|
||
# 打印检查结果
|
||
logging.info("\n=== 环境依赖检查结果 ===")
|
||
for result in results:
|
||
status = "✓" if result["passed"] else "✗"
|
||
logging.info(f"{status} {result['name']}: {result['details']}")
|
||
|
||
if not all_passed:
|
||
logging.warning("\n警告:某些依赖项不满足要求,环境配置可能无法正常完成或运行。")
|
||
|
||
return all_passed, results
|
||
|
||
if __name__ == "__main__":
|
||
all_passed, results = check_dependencies()
|
||
# 将结果保存到项目根目录
|
||
import json
|
||
results_file = os.path.join(project_root, 'environment_check_results.json')
|
||
with open(results_file, 'w') as f:
|
||
json.dump({"dependencies_met": all_passed, "check_results": results}, f)
|
||
|
||
# 退出状态码
|
||
if all_passed:
|
||
sys.exit(0)
|
||
else:
|
||
sys.exit(1) |