test_gpu_scripts/modules/nccl_test.py
zulifeng 375d439abb feat: 新增 H20 支持、优化算力测试精度并修复多项稳定性问题
- gpu_specs: 新增 H20/H20-3e (中国合规版 H200) 规格定义,并修复
  GPU 名称匹配顺序,避免 "H200" 被 "H20" 子串误匹配
- benchmark(compute): 引入 L2 cache 规避的 matrix pool 轮换 +
  可选 torch.compile(max-autotune),FP8 增加 _scaled_mm 探测,
  显著提升 FP16/BF16/FP8 实测吞吐准确性
- benchmark(memory): nvbandwidth 增加 --disableAffinity 规避
  fabricmanager NVML 不兼容;全 0 结果时自动回退到 PyTorch;
  D2D 平均值排除对角线零值
- nccl: 各通信操作 (AllReduce/AllToAll/Broadcast 等) 使用独立
  带宽阈值比例,避免 AllToAll 误报 WARN
- rdma: 仅按 link_layer=InfiniBand 过滤端口,无 IB 硬件或全 DOWN
  时直接 SKIP 而非报错
- stress: 计算矩阵尺寸封顶 4096,并改为先并发派发再统一同步,
  修复 8 卡串行执行导致 duration 严重超时的问题
- report: 兼容 RDMA SKIP 状态与 PyTorch 回退场景的 Memory 判定,
  避免回退结果被误判为 FAIL
- config: 新增 benchmark.compute.use_compile 开关

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-12 21:41:46 +08:00

462 lines
18 KiB
Python

"""NCCL multi-GPU communication test — wraps official nccl-tests."""
import glob
import os
import re
import shutil
import subprocess
from datetime import datetime
from typing import Optional
from rich.console import Console
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from modules.gpu_specs import detect_gpu_type, get_gpu_specs, resolve_tools_dir
# torch is an optional dependency: it is only used to count visible CUDA
# devices. TORCH_AVAILABLE stays False unless torch imports cleanly AND
# reports a usable CUDA runtime.
TORCH_AVAILABLE = False
try:
    import torch

    if torch.cuda.is_available():
        TORCH_AVAILABLE = True
except (ImportError, OSError):
    # ImportError: torch is not installed.
    # OSError: torch is installed but one of its shared libraries failed to
    # load. Either way the module must still import so the caller can report
    # "no GPUs" instead of crashing.
    pass
# Per-operation bandwidth thresholds, as a fraction of NVLink bidirectional BW.
# AllReduce uses ring algorithm and saturates ring BW; AllToAll requires full-mesh
# transfers and on 8-GPU NVSwitch typically runs 10-20% lower than AllReduce.
# Public H100/H200 8-GPU benchmarks show AllToAll bus BW in the 300-380 GB/s range
# vs AllReduce in 400-500 GB/s. Using a single 40% threshold for both produced
# false positives for AllToAll.
_OP_BW_FRACTIONS = {
"allreduce": 0.40,
"alltoall": 0.30,
"broadcast": 0.35,
"reducescatter": 0.38,
"allgather": 0.38,
"sendrecv": 0.35,
}
class NCCLTest:
    """Single-node NCCL collective test driver built around nccl-tests.

    For every enabled collective the matching ``*_perf`` binary is run
    directly (``-g N``); if that does not yield a usable result it is retried
    through ``mpirun``. The best bus bandwidth is compared against a
    per-operation threshold. When no binary produces a usable result at all,
    a torchrun-based connectivity check is run instead.
    """

    # (config key, enabled by default, nccl-tests binary, display label)
    _AVAILABLE_TESTS = (
        ("test_allreduce", True, "all_reduce_perf", "AllReduce"),
        ("test_alltoall", True, "alltoall_perf", "AllToAll"),
        ("test_broadcast", True, "broadcast_perf", "Broadcast"),
        ("test_reduce_scatter", False, "reduce_scatter_perf", "ReduceScatter"),
        # nccl-tests names this binary "all_gather_perf"; "allgather_perf"
        # does not exist, so the test used to be skipped unconditionally.
        ("test_allgather", False, "all_gather_perf", "AllGather"),
        ("test_sendrecv", False, "sendrecv_perf", "SendRecv"),
    )

    # Operation names printed by the torchrun connectivity script.
    _FALLBACK_OPS = ("allreduce", "broadcast", "allgather", "reducescatter", "alltoall")

    # Wall-clock limit (seconds) for one nccl-tests invocation.
    _NCCL_TIMEOUT_S = 180

    def __init__(self, config: dict):
        """Store *config* and resolve tool directory, GPU type and GPU specs."""
        self.config = config
        self.console = Console()
        self.nccl_cfg = config.get("nccl", {})
        self.tools_dir = resolve_tools_dir(config)
        self.gpu_type = detect_gpu_type()
        self.specs = get_gpu_specs(self.gpu_type)

    def _find_nccl_test(self, name: str) -> Optional[str]:
        """Return the path of the nccl-tests binary *name*, or None.

        Search order: ``PATH``, then ``<tools_dir>/nccl-tests/build``, then a
        recursive search below ``<tools_dir>/nccl-tests``.
        """
        on_path = shutil.which(name)
        if on_path:
            return on_path

        tests_root = os.path.join(self.tools_dir, "nccl-tests")
        in_build = os.path.join(tests_root, "build", name)
        if os.path.isfile(in_build) and os.access(in_build, os.X_OK):
            return in_build

        pattern = os.path.join(tests_root, "**", name)
        for candidate in glob.glob(pattern, recursive=True):
            # Directories usually carry the execute bit too, so require a
            # regular file explicitly.
            if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
                return candidate
        return None

    def _find_mpirun(self) -> Optional[str]:
        """Return the first MPI launcher found, or None.

        Tries ``mpirun`` and ``mpiexec`` on ``PATH``, then
        ``<tools_dir>/mpi/bin/mpirun``.
        """
        bundled = os.path.join(self.tools_dir, "mpi", "bin", "mpirun")
        for candidate in ("mpirun", "mpiexec", bundled):
            found = shutil.which(candidate)
            if found:
                return found
        return None

    @staticmethod
    def _is_usable(result: dict) -> bool:
        """True when *result* has a real status (not SKIP) and no error."""
        return result.get("status") not in ("SKIP", None) and "error" not in result

    def run(self) -> dict:
        """Run all enabled NCCL tests and return a summary dict.

        Returns ``{"error": ..., "gpu_count": ...}`` when fewer than two GPUs
        are visible, the torchrun fallback result when no nccl-tests binary
        produced a usable result, and otherwise a dict with the keys
        ``passed``, ``source``, ``min_bandwidth_gbps``, ``tests``,
        ``gpu_count``, ``timestamp`` and ``detected_gpu_type``.
        """
        gpu_count = torch.cuda.device_count() if TORCH_AVAILABLE else 0
        if gpu_count < 2:
            self.console.print(f"[yellow]NCCL test requires at least 2 GPUs (found {gpu_count})[/yellow]")
            return {"error": "need_at_least_2_gpus", "gpu_count": gpu_count}

        tests = [
            (binary, label)
            for key, enabled_by_default, binary, label in self._AVAILABLE_TESTS
            if self.nccl_cfg.get(key, enabled_by_default)
        ]

        # "or 0" guards against a spec entry that is present but None.
        nvlink_bw = self.specs.get("nvlink_bandwidth_gbps", 0) or 0
        # User-provided override applies uniformly across all ops; otherwise
        # each op gets its own threshold from _OP_BW_FRACTIONS.
        user_override = self.nccl_cfg.get("min_bandwidth_gbps")

        def threshold_for(label: str) -> float:
            if user_override:
                return float(user_override)
            if nvlink_bw <= 0:
                return 10.0  # conservative floor
            frac = _OP_BW_FRACTIONS.get(label.lower(), 0.40)
            return round(nvlink_bw * frac)

        if self.gpu_type == "unknown":
            self.console.print("[yellow]Unknown GPU — using conservative bandwidth thresholds[/yellow]")

        # Strategy: try nccl-tests binary directly (single-node, -g N),
        # then mpirun, then torchrun fallback
        results = {}
        any_binary_worked = False
        with Progress(
            SpinnerColumn(), TextColumn("[progress.description]{task.description}"),
            TimeElapsedColumn(), console=self.console,
        ) as progress:
            task = progress.add_task("NCCL tests...", total=len(tests))
            for binary, label in tests:
                progress.update(task, description=f"NCCL {label}...")
                op_min_bw = threshold_for(label)
                result = self._run_one_nccl_test_direct(
                    binary, label, gpu_count, op_min_bw
                )
                if not self._is_usable(result):
                    # Try mpirun fallback
                    mpirun = self._find_mpirun()
                    if mpirun:
                        result = self._run_one_nccl_test_mpirun(
                            binary, label, gpu_count, mpirun, op_min_bw
                        )
                if self._is_usable(result):
                    any_binary_worked = True
                # Always record the outcome so that a failed or skipped
                # operation stays visible in the report.
                results[label.lower()] = result
                progress.advance(task)

        if not any_binary_worked:
            self.console.print("[yellow]nccl-tests binaries failed, falling back to torchrun[/yellow]")
            return self._run_torchrun_fallback(gpu_count)

        all_passed = all(
            r.get("status") == "PASS"
            for r in results.values()
            if isinstance(r, dict) and "status" in r
        )
        return {
            "passed": all_passed,
            "source": "nccl-tests",
            "min_bandwidth_gbps": {
                lbl.lower(): threshold_for(lbl) for _, lbl in tests
            },
            "tests": results,
            "gpu_count": gpu_count,
            "timestamp": datetime.now().isoformat(),
            "detected_gpu_type": self.gpu_type,
        }

    def _execute_nccl_cmd(self, cmd: list, min_bw: float) -> dict:
        """Run one nccl-tests command line and turn its outcome into a result dict.

        Shared by the direct and mpirun runners. Returns the parsed bandwidth
        result on success, or ``{"status": "FAIL", "error": ...}`` on timeout,
        launch failure, a known NCCL/CUDA error marker, or a non-zero exit.
        """
        env = os.environ.copy()
        env["NCCL_DEBUG"] = "WARN"
        try:
            proc = subprocess.run(
                cmd, capture_output=True, text=True,
                timeout=self._NCCL_TIMEOUT_S, env=env,
            )
        except subprocess.TimeoutExpired:
            return {"status": "FAIL", "error": "timeout"}
        except (OSError, ValueError, subprocess.SubprocessError) as exc:
            # Launch failure (missing/not executable binary) or undecodable output.
            return {"status": "FAIL", "error": str(exc)}

        combined = proc.stdout + proc.stderr
        # Check for NCCL/CUDA compatibility errors
        if "CUDA driver version is insufficient" in combined or \
                "Test NCCL failure" in combined:
            error_msg = "NCCL/CUDA driver version mismatch" \
                if "CUDA driver version" in combined \
                else "NCCL test failure (library incompatibility)"
            return {"status": "FAIL", "error": error_msg}

        if proc.returncode != 0:
            # Never report an empty error string: fall back to stdout, then
            # to the bare exit code.
            detail = proc.stderr.strip() or proc.stdout.strip() \
                or f"exit code {proc.returncode}"
            return {"status": "FAIL", "error": detail[:300]}

        return self._parse_nccl_output(proc.stdout, min_bw)

    def _run_one_nccl_test_direct(self, binary_name: str, label: str,
                                  gpu_count: int, min_bw: float) -> dict:
        """Run nccl-tests binary directly with -g N (no mpirun needed for single-node).

        *label* is accepted for signature parity with the mpirun runner and is
        not used. Returns a SKIP result when the binary cannot be found.
        """
        binary = self._find_nccl_test(binary_name)
        if not binary:
            return {"status": "SKIP", "error": f"{binary_name} not found"}
        cmd = [
            binary,
            "-b", "8M",
            "-e", "8G",
            "-f", "2",
            "-g", str(gpu_count),
            "-w", "5",
            "-n", "20",
        ]
        return self._execute_nccl_cmd(cmd, min_bw)

    def _run_one_nccl_test_mpirun(self, binary_name: str, label: str,
                                  gpu_count: int, mpirun: str, min_bw: float) -> dict:
        """Run nccl-tests via mpirun (multi-node or per-GPU-process mode).

        Launches *gpu_count* ranks with one GPU each (``-g 1``). *label* is
        not used. Returns a SKIP result when the binary cannot be found.
        """
        binary = self._find_nccl_test(binary_name)
        if not binary:
            return {"status": "SKIP", "error": f"{binary_name} not found"}
        # Keep an inherited device mask: gpu_count was derived from it, so
        # replacing it with 0..N-1 could point the ranks at different GPUs.
        visible = os.environ.get("CUDA_VISIBLE_DEVICES") \
            or ",".join(str(i) for i in range(gpu_count))
        cmd = [
            mpirun,
            "-np", str(gpu_count),
            "--allow-run-as-root",
            "-x", "NCCL_DEBUG=WARN",
            "-x", "CUDA_VISIBLE_DEVICES=" + visible,
            binary,
            "-b", "8",
            "-e", "256M",
            "-f", "2",
            "-g", "1",
            "-w", "5",
            "-n", "20",
        ]
        return self._execute_nccl_cmd(cmd, min_bw)

    @staticmethod
    def _parse_nccl_output(stdout: str, min_bw: float) -> dict:
        """Parse nccl-tests tabular output and extract bandwidth results.

        Only the out-of-place columns are read. Status is PASS when the best
        bus bandwidth reaches *min_bw*, WARN when it does not, and FAIL (with
        an ``error`` entry) when no data row could be parsed at all.
        ``by_size`` holds the last five parsed rows.
        """
        rows = []
        for raw in stdout.splitlines():
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            fields = line.split()
            # nccl-tests data lines: size count type redop root time algbw busbw #wrong [time algbw busbw #wrong]
            if len(fields) < 9:
                continue
            try:
                # fields[2] is the dtype string ('float'/'int32'/etc.), not a number
                # out-of-place columns: time=fields[5], algbw=fields[6], busbw=fields[7]
                row = {
                    "size": int(fields[0]),
                    "time_us": float(fields[5]),
                    "algbw_gbps": float(fields[6]),
                    "busbw_gbps": float(fields[7]),
                }
            except ValueError:
                continue
            rows.append(row)

        best_algbw = max([0.0] + [r["algbw_gbps"] for r in rows])
        best_busbw = max([0.0] + [r["busbw_gbps"] for r in rows])
        result = {
            "status": "PASS" if best_busbw >= min_bw else "WARN",
            "best_algbw_gbps": round(best_algbw, 1),
            "best_busbw_gbps": round(best_busbw, 1),
            "min_required_gbps": min_bw,
            "by_size": rows[-5:],
        }
        if not rows:
            # No measurements at all is a failed run, not a slow one.
            result["status"] = "FAIL"
            result["error"] = "no bandwidth rows found in nccl-tests output"
        return result

    @staticmethod
    def _parse_torchrun_output(stdout: str) -> dict:
        """Extract ``op:ok`` / ``op:fail:error`` lines from the fallback script output.

        Lines whose first field is not one of ``_FALLBACK_OPS`` (log output,
        continuation lines of multi-line errors) are ignored. Returns a dict
        mapping op name to ``{"status": ..., "error": ...}``.
        """
        tests = {}
        for raw in stdout.splitlines():
            op_name, sep, rest = raw.strip().partition(":")
            if not sep or op_name not in NCCLTest._FALLBACK_OPS:
                continue
            outcome, _, detail = rest.partition(":")
            if outcome == "ok":
                tests[op_name] = {"status": "PASS", "error": None}
            else:
                tests[op_name] = {
                    "status": "FAIL",
                    "error": detail if outcome == "fail" and detail else None,
                }
        return tests

    def _run_torchrun_fallback(self, gpu_count: int) -> dict:
        """Basic NCCL connectivity test via torchrun — verifies NCCL works but does not benchmark performance.

        Writes a small script to a temporary file, launches it with
        ``torchrun --nproc_per_node=<gpu_count>`` and parses the per-operation
        lines that rank 0 prints. The result is ``passed`` only when torchrun
        exits with code 0 and every reported operation succeeded.
        """
        import sys
        import tempfile
        import textwrap

        self.console.print("[yellow]nccl-tests not available, running basic NCCL connectivity check[/yellow]")
        # The tensor length is rounded down to a multiple of the world size so
        # that chunk(ws) yields equal parts for any GPU count.
        code = textwrap.dedent(f"""
            import torch, torch.distributed as dist, os
            os.environ.setdefault("MASTER_ADDR","127.0.0.1")
            os.environ.setdefault("MASTER_PORT","29500")
            rank=int(os.environ.get("LOCAL_RANK",0))
            ws={gpu_count}
            dist.init_process_group("nccl",rank=rank,world_size=ws)
            torch.cuda.set_device(rank)
            n=(1024*1024//ws)*ws
            x=torch.randn(n,device=f"cuda:{{rank}}",dtype=torch.float32)
            # Test AllReduce
            try:
                dist.all_reduce(x.clone())
                if rank==0: print("allreduce:ok")
            except Exception as e:
                if rank==0: print(f"allreduce:fail:{{e}}")
            # Test Broadcast
            try:
                dist.broadcast(x.clone(),src=0)
                if rank==0: print("broadcast:ok")
            except Exception as e:
                if rank==0: print(f"broadcast:fail:{{e}}")
            # Test AllGather
            try:
                tensor_list=[torch.empty_like(x) for _ in range(ws)]
                dist.all_gather(tensor_list,x.clone())
                if rank==0: print("allgather:ok")
            except Exception as e:
                if rank==0: print(f"allgather:fail:{{e}}")
            # Test ReduceScatter
            try:
                chunks=list(x.chunk(ws))
                output=torch.empty_like(chunks[0])
                dist.reduce_scatter(output,chunks)
                if rank==0: print("reducescatter:ok")
            except Exception as e:
                if rank==0: print(f"reducescatter:fail:{{e}}")
            # Test AllToAll
            try:
                chunks=list(x.chunk(ws))
                output_list=[torch.empty_like(c) for c in chunks]
                dist.all_to_all(output_list,chunks)
                if rank==0: print("alltoall:ok")
            except Exception as e:
                if rank==0: print(f"alltoall:fail:{{e}}")
            dist.destroy_process_group()
        """)

        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp:
            tmp.write(code)
            script_path = tmp.name

        try:
            # Prefer torchrun from the same venv as the running Python
            venv_torchrun = os.path.join(os.path.dirname(sys.executable), "torchrun")
            torchrun_cmd = venv_torchrun if os.path.isfile(venv_torchrun) else "torchrun"
            proc = subprocess.run(
                [torchrun_cmd, f"--nproc_per_node={gpu_count}", script_path],
                capture_output=True, text=True, timeout=120,
                env={**os.environ, "NCCL_DEBUG": "WARN"},
            )
        except Exception as e:
            # Best effort: any launch problem or timeout is reported, not raised.
            return {"passed": False, "source": "torchrun_fallback", "error": str(e)}
        finally:
            # Remove the script on every path, including timeout and launch failure.
            try:
                os.unlink(script_path)
            except OSError:
                pass

        # Parse connectivity results — format: op_name:ok or op_name:fail:error
        tests = self._parse_torchrun_output(proc.stdout)
        if not tests:
            # Nothing was reported (e.g. torchrun crashed before any
            # collective ran): that is a failure, not a vacuous pass.
            detail = proc.stderr.strip()[-300:] or f"exit code {proc.returncode}"
            return {
                "passed": False,
                "source": "torchrun_fallback",
                "error": f"torchrun produced no results: {detail}",
                "gpu_count": gpu_count,
            }

        all_passed = proc.returncode == 0 and all(
            entry["status"] == "PASS" for entry in tests.values()
        )
        return {
            "passed": all_passed,
            "source": "torchrun_fallback",
            "tests": tests,
            "gpu_count": gpu_count,
        }

    @staticmethod
    def _format_size(num_bytes: int) -> str:
        """Format a byte count as a whole number with a B/K/M/G suffix (1024-based)."""
        if num_bytes >= 1 << 30:
            return f"{num_bytes / (1 << 30):.0f}G"
        if num_bytes >= 1 << 20:
            return f"{num_bytes / (1 << 20):.0f}M"
        if num_bytes >= 1 << 10:
            return f"{num_bytes / (1 << 10):.0f}K"
        return f"{num_bytes}B"

    @staticmethod
    def print_results(results: dict, console: Optional["Console"] = None):
        """Pretty-print a result dict produced by ``run()``.

        Handles three shapes: an error dict (prints the error only), the
        torchrun connectivity result, and the nccl-tests benchmark result
        (per-operation status, best bus bandwidth and a per-size table).
        A fresh ``Console`` is created when *console* is None.
        """
        # Error text comes from external tools and may contain square
        # brackets; escape it so Rich does not read it as markup.
        from rich.markup import escape

        c = console or Console()
        if "error" in results:
            c.print(f"[bold red]Error: {escape(str(results['error']))}[/bold red]")
            return

        passed = results.get("passed", False)
        source = results.get("source", "unknown")
        tests = results.get("tests", {})

        if source == "torchrun_fallback":
            # Connectivity check mode
            verdict = "[bold green]✓ NCCL Connectivity OK[/bold green]" if passed else "[bold red]✗ NCCL Connectivity FAILED[/bold red]"
            c.print(f"{verdict} [dim](basic check via torchrun)[/dim]")
            if tests:
                c.print("\n[dim]Operations tested:[/dim]")
                for op_name, result in tests.items():
                    if not isinstance(result, dict):
                        continue
                    status = result.get("status", "FAIL")
                    s_color = "green" if status == "PASS" else "red"
                    error = result.get("error")
                    if error:
                        c.print(f" [{s_color}]{op_name}[/{s_color}] — {escape(str(error))}")
                    else:
                        c.print(f" [{s_color}]{op_name}[/{s_color}]")
            c.print("\n[yellow]Note: functional connectivity test only (no performance data)[/yellow]")
            return

        # nccl-tests mode
        statuses = [r.get("status") for r in tests.values() if isinstance(r, dict)]
        if passed:
            verdict = "[bold green]✓ NCCL tests PASSED[/bold green]"
        elif "FAIL" in statuses:
            # A hard failure (timeout, crash, ...) is more than a warning.
            verdict = "[bold red]✗ NCCL tests FAILED[/bold red]"
        else:
            verdict = "[bold yellow]⚠ NCCL tests WARNING[/bold yellow]"
        c.print(f"{verdict} [dim](via {source})[/dim]")

        for op_name, result in tests.items():
            if not isinstance(result, dict):
                continue
            c.print(f"\n[bold cyan]{op_name.upper()}[/bold cyan]")
            status = result.get("status", "FAIL")
            s_color = "green" if status == "PASS" else ("yellow" if status == "WARN" else "red")
            c.print(f" Status: [{s_color}]{status}[/{s_color}] "
                    f"Best bus BW: {result.get('best_busbw_gbps', 'N/A')} GB/s "
                    f"(min: {result.get('min_required_gbps', 'N/A')} GB/s)")
            error = result.get("error")
            if error:
                c.print(f" Error: {escape(str(error))}")
            by_size = result.get("by_size", [])
            if by_size:
                t = Table(box=None, padding=(0, 1))
                t.add_column("Size", style="bold", justify="right")
                t.add_column("Time (us)", justify="right")
                t.add_column("Alg BW (GB/s)", justify="right")
                t.add_column("Bus BW (GB/s)", justify="right")
                for r in by_size:
                    t.add_row(
                        NCCLTest._format_size(r.get("size", 0)),
                        f"{r.get('time_us', 0):.1f}",
                        f"{r.get('algbw_gbps', 0):.1f}",
                        f"{r.get('busbw_gbps', 0):.1f}",
                    )
                c.print(t)