feat: 新增多机 nccl test 测试脚本

This commit is contained in:
zulifeng 2026-05-25 14:19:02 +08:00
parent fc97a768cf
commit e49ea32094
5 changed files with 791 additions and 104 deletions

View File

@ -1,4 +1,4 @@
# GPU type: auto-detect or override to a100/a800/h100/h200/b200/b300 # GPU type: auto-detect or override to a100/a800/h100/h800/h200/h20/b200/b300
gpu_type: auto gpu_type: auto
benchmark: benchmark:
@ -14,10 +14,25 @@ benchmark:
- fp16 - fp16
- bf16 - bf16
- fp8 - fp8
matrix_size: 8192 # MAMF-style shape sweep: measure each dtype at every shape below and keep the max
warmup: 50 # TFLOPS (the realistic achievable peak). A single fixed shape under-reports by
iterations: 500 # ~7-12% and can't meet the MAMF-calibrated thresholds in gpu_specs.py.
use_compile: true # Each entry is either N (square N×N×N) or [M, N, K]. K-heavy non-square shapes
# (e.g. 2048×2048×13312) hit the true Hopper MAMF — bf16 ~790 vs ~755 square.
# Empty list => single matrix_size shape (legacy behaviour).
sweep_sizes:
- 3584
- 4608
- 5376
- 8192
- 11520
- [2048, 2048, 13312]
- [2048, 2048, 16384]
matrix_size: 8192 # fallback shape when sweep_sizes is empty
warmup: 20
iterations: 80
# NOTE: torch.compile was dropped — on H100 eager cuBLAS beats Triton for plain
# GEMM, and compiling would re-autotune per shape and make the sweep very slow.
health: health:
temp_warning: 75 temp_warning: 75
@ -34,7 +49,7 @@ nccl:
test_sendrecv: false test_sendrecv: false
stress: stress:
duration_sec: 60 duration_sec: 600 # 10 min — reaches thermal steady state, validates throttle/jitter beyond warmup
use_doubles: false use_doubles: false
use_tensor_cores: true use_tensor_cores: true
memory_pct: 90 memory_pct: 90

View File

@ -312,10 +312,31 @@ class Benchmark:
def run_compute_benchmark(self, dtypes: Optional[List[str]] = None) -> dict: def run_compute_benchmark(self, dtypes: Optional[List[str]] = None) -> dict:
comp_cfg = self.bench_cfg.get("compute", {}) comp_cfg = self.bench_cfg.get("compute", {})
configured_dtypes = dtypes or comp_cfg.get("dtypes", ["fp32", "tf32", "fp16", "bf16", "fp8"]) configured_dtypes = dtypes or comp_cfg.get("dtypes", ["fp32", "tf32", "fp16", "bf16", "fp8"])
matrix_size = comp_cfg.get("matrix_size", 4096)
warmup = comp_cfg.get("warmup", 10) # MAMF-style shape sweep (à la stas00's mamf-finder): a single fixed matmul
iterations = comp_cfg.get("iterations", 100) # shape under-reports the achievable peak by ~7-12% and therefore can't meet
use_compile = comp_cfg.get("use_compile", False) # the MAMF-calibrated PASS thresholds in gpu_specs.compute_pass_thresholds_tflops.
# So for each dtype we time several matmul shapes and keep the MAXIMUM TFLOPS
# (the realistic peak). matrix_size is the fallback when sweep_sizes is empty.
matrix_size = comp_cfg.get("matrix_size", 8192)
sweep_sizes = comp_cfg.get("sweep_sizes") or [matrix_size]
warmup = comp_cfg.get("warmup", 20)
iterations = comp_cfg.get("iterations", 80)
# Each sweep entry is either an int N (square N×N×N) or an [M, N, K] triple.
# Non-square / K-heavy shapes (e.g. 2048×2048×13312) reach the true MAMF peak
# on Hopper — square-only tops out ~5% lower — so the default set mixes both.
def _to_shape(entry):
if isinstance(entry, (list, tuple)):
if len(entry) == 3:
return tuple(int(x) for x in entry)
if len(entry) == 1:
n = int(entry[0])
return (n, n, n)
raise ValueError(f"sweep size {entry!r} must be an int or [M, N, K]")
n = int(entry)
return (n, n, n)
shapes = [_to_shape(e) for e in sweep_sizes]
if not TORCH_AVAILABLE: if not TORCH_AVAILABLE:
self.console.print("[yellow]PyTorch not available - skipping compute benchmark[/yellow]") self.console.print("[yellow]PyTorch not available - skipping compute benchmark[/yellow]")
@ -323,25 +344,11 @@ class Benchmark:
gpu_count = torch.cuda.device_count() gpu_count = torch.cuda.device_count()
self.console.print(f"[cyan]Compute Benchmark - {gpu_count} GPU(s)[/cyan]") self.console.print(f"[cyan]Compute Benchmark - {gpu_count} GPU(s)[/cyan]")
if len(sweep_sizes) > 1:
# torch.compile(max-autotune) benchmarks cuBLAS vs Triton kernels and picks self.console.print(
# the fastest for this GPU/shape, typically improving efficiency by 8-15%. f"[cyan] MAMF shape sweep over {len(sweep_sizes)} sizes: "
# compile_warmup must be larger than warmup to absorb JIT + autotuning time. f"{', '.join(str(s) for s in sweep_sizes)}[/cyan]"
mm_fn = torch.matmul )
compile_warmup = warmup
if use_compile:
try:
_compiled = torch.compile(torch.matmul, mode="max-autotune")
# Trial call to trigger JIT and verify compilation succeeds before the dtype loop.
_t = torch.randn(64, 64, device="cuda", dtype=torch.float32)
_compiled(_t, _t)
torch.cuda.synchronize()
del _t
mm_fn = _compiled
compile_warmup = max(warmup, 50)
self.console.print("[cyan] torch.compile(max-autotune) enabled[/cyan]")
except Exception as e:
self.console.print(f"[yellow] torch.compile unavailable ({type(e).__name__}), using eager[/yellow]")
dtype_map = { dtype_map = {
"fp32": (torch.float32, self.specs["fp32_tflops"]), "fp32": (torch.float32, self.specs["fp32_tflops"]),
@ -352,6 +359,7 @@ class Benchmark:
} }
results_by_dtype = {} results_by_dtype = {}
best_shapes = {}
per_gpu_results = [{"index": i} for i in range(gpu_count)] per_gpu_results = [{"index": i} for i in range(gpu_count)]
with Progress( with Progress(
@ -376,84 +384,39 @@ class Benchmark:
dtype_val, peak_tflops = dtype_map[dtype_name] dtype_val, peak_tflops = dtype_map[dtype_name]
try: # allow_tf32 only affects float32 matmuls: ON for the TF32 run, OFF for
if dtype_name == "tf32": # the true-FP32 run so the two stay distinct.
old_tf32 = torch.backends.cuda.matmul.allow_tf32 old_tf32 = torch.backends.cuda.matmul.allow_tf32
if dtype_name == "tf32":
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
dtype_val = torch.float32 dtype_val = torch.float32
elif dtype_name == "fp32":
torch.backends.cuda.matmul.allow_tf32 = False
M = N = K = matrix_size best_tflops, best_shape, last_err = 0.0, None, None
for (M, N, K) in shapes:
# Allocate enough matrix pairs so total memory exceeds GPU L2 cache
# (H100/H200 L2 = 50 MB), preventing cross-iteration cache reuse.
elem_bytes = 1 if dtype_name == "fp8" else torch.tensor([], dtype=dtype_val).element_size()
pair_bytes = 2 * M * K * elem_bytes
num_pools = max(4, -(-256 * 1024 * 1024 // pair_bytes)) # ceil(256MB / pair)
pools_a = pools_b = c = None
if dtype_name == "fp8":
pools_a = [torch.randn(M, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn) for _ in range(num_pools)]
pools_b = [torch.randn(N, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn) for _ in range(num_pools)]
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")
def _fp8_mm(i):
return torch._scaled_mm(pools_a[i], pools_b[i].T, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
# Probe: verify _scaled_mm is functional before the timed loop.
# It requires PyTorch >= 2.1 + CUDA >= 12.0 + sm90 (Hopper).
if not hasattr(torch, "_scaled_mm"):
raise RuntimeError("torch._scaled_mm unavailable — upgrade to PyTorch >= 2.1")
try: try:
_probe = _fp8_mm(0) t = self._bench_matmul_once(dtype_name, dtype_val, M, N, K, warmup, iterations)
torch.cuda.synchronize() if t > best_tflops:
del _probe best_tflops, best_shape = t, (M, N, K)
except Exception as probe_err: except Exception as e: # noqa: BLE001 - record and try the next shape
raise RuntimeError(f"FP8 _scaled_mm probe failed: {probe_err}") from probe_err last_err = e
for i in range(warmup):
_fp8_mm(i % num_pools)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(iterations):
c = _fp8_mm(i % num_pools)
end_event.record()
torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)
else:
pools_a = [torch.randn(M, K, device="cuda", dtype=dtype_val) for _ in range(num_pools)]
pools_b = [torch.randn(K, N, device="cuda", dtype=dtype_val) for _ in range(num_pools)]
indexed_a = [pools_a[i % num_pools] for i in range(compile_warmup + iterations)]
indexed_b = [pools_b[i % num_pools] for i in range(compile_warmup + iterations)]
for i in range(compile_warmup):
mm_fn(indexed_a[i], indexed_b[i])
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(compile_warmup, compile_warmup + iterations):
c = mm_fn(indexed_a[i], indexed_b[i])
end_event.record()
torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)
flops = 2 * M * N * K * iterations
tflops = flops / (elapsed_ms / 1000) / 1e12
results_by_dtype[dtype_name] = round(tflops, 1)
for pg in per_gpu_results:
pg[dtype_name] = round(tflops, 1)
if dtype_name == "tf32":
torch.backends.cuda.matmul.allow_tf32 = old_tf32 torch.backends.cuda.matmul.allow_tf32 = old_tf32
del pools_a, pools_b, c if best_shape is None:
torch.cuda.empty_cache() results_by_dtype[dtype_name] = f"error: {last_err}"
self.console.print(f"[yellow] {dtype_name}: {last_err}[/yellow]")
except Exception as e: else:
results_by_dtype[dtype_name] = f"error: {e}" shape_str = "x".join(str(d) for d in best_shape)
self.console.print(f"[yellow] {dtype_name}: {e}[/yellow]") results_by_dtype[dtype_name] = round(best_tflops, 1)
best_shapes[dtype_name] = shape_str
for pg in per_gpu_results:
pg[dtype_name] = round(best_tflops, 1)
if len(shapes) > 1:
self.console.print(
f"[dim] {dtype_name}: {best_tflops:.1f} TFLOPS @ {shape_str}[/dim]"
)
progress.advance(task) progress.advance(task)
@ -476,12 +439,67 @@ class Benchmark:
self.specs.get("compute_pass_thresholds_tflops") or {} self.specs.get("compute_pass_thresholds_tflops") or {}
), ),
"per_gpu": per_gpu_results, "per_gpu": per_gpu_results,
"sweep_sizes": list(sweep_sizes),
"best_shapes": best_shapes,
"matrix_size": matrix_size, "matrix_size": matrix_size,
"warmup": warmup, "warmup": warmup,
"iterations": iterations, "iterations": iterations,
} }
} }
def _bench_matmul_once(self, dtype_name: str, dtype_val, M: int, N: int, K: int,
warmup: int, iterations: int) -> float:
"""Time one (M×K)·(K×N) matmul for a dtype and return achieved TFLOPS.
Uses an L2-cache-busting pool of matrix pairs (total > 256 MB) so operands
can't be served from L2 across iterations, and CUDA events for timing. FP8
goes through torch._scaled_mm (e4m3); all others through torch.matmul eager
cuBLAS, which on H100 beats torch.compile/Triton for plain GEMM and avoids the
per-shape recompile cost that would make a sweep pathologically slow.
"""
elem_bytes = 1 if dtype_name == "fp8" else torch.tensor([], dtype=dtype_val).element_size()
pair_bytes = (M * K + K * N) * elem_bytes
num_pools = max(4, -(-256 * 1024 * 1024 // pair_bytes)) # ceil(256MB / pair)
if dtype_name == "fp8":
if not hasattr(torch, "_scaled_mm"):
raise RuntimeError("torch._scaled_mm unavailable — upgrade to PyTorch >= 2.1")
pools_a = [torch.randn(M, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn) for _ in range(num_pools)]
pools_b = [torch.randn(N, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn) for _ in range(num_pools)]
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")
def op(i):
return torch._scaled_mm(pools_a[i], pools_b[i].T, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
else:
pools_a = [torch.randn(M, K, device="cuda", dtype=dtype_val) for _ in range(num_pools)]
pools_b = [torch.randn(K, N, device="cuda", dtype=dtype_val) for _ in range(num_pools)]
def op(i):
return torch.matmul(pools_a[i], pools_b[i])
try:
# Probe once so a broken/unsupported kernel raises before the timed loop.
_probe = op(0)
torch.cuda.synchronize()
del _probe
for i in range(warmup):
op(i % num_pools)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(iterations):
op(i % num_pools)
end_event.record()
torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)
finally:
del pools_a, pools_b
torch.cuda.empty_cache()
return (2 * M * N * K * iterations) / (elapsed_ms / 1000) / 1e12
@staticmethod @staticmethod
def print_results(results: dict, console: Console = None): def print_results(results: dict, console: Console = None):
c = console or Console() c = console or Console()
@ -564,3 +582,78 @@ class Benchmark:
table.add_row(dt.upper(), f"{achieved:.1f}", f"{pk:.0f}", table.add_row(dt.upper(), f"{achieved:.1f}", f"{pk:.0f}",
f"[{ec}]{ef:.1f}%[/{ec}]") f"[{ec}]{ef:.1f}%[/{ec}]")
c.print(table) c.print(table)
@staticmethod
def judge_compute(results: dict) -> dict:
"""Judge compute results against pass_thresholds_tflops.
Single source of truth for the PASS/WARN/FAIL rule (same one report.py uses):
achieved >= thr -> PASS; >= 0.9*thr -> WARN; else FAIL. A string achieved value
(skipped/error) -> SKIP. A dtype without a threshold falls back to efficiency
(>=80 PASS / >=50 WARN / else FAIL).
Returns {"rows": [(dtype, achieved, threshold, status), ...], "verdict": str}.
"""
comp = results.get("compute", results)
per_dtype = comp.get("per_dtype_tflops", {})
thresholds = comp.get("pass_thresholds_tflops", {}) or {}
eff = comp.get("efficiency_pct", {})
rank = {"PASS": 0, "WARN": 1, "FAIL": 2, "SKIP": 0}
rows, verdict = [], "PASS"
for dt, val in per_dtype.items():
thr = thresholds.get(dt)
if isinstance(val, str):
status = "SKIP"
elif thr:
status = "PASS" if val >= thr else ("WARN" if val >= thr * 0.9 else "FAIL")
else:
e = eff.get(dt, 0)
status = "PASS" if e >= 80 else ("WARN" if e >= 50 else "FAIL")
rows.append((dt, val, thr, status))
if rank[status] > rank[verdict]:
verdict = status
return {"rows": rows, "verdict": verdict}
@staticmethod
def print_compute_verdict(results: dict, console: Console = None) -> str:
"""Print the PASS/WARN/FAIL table for compute results; return the verdict."""
c = console or Console()
judged = Benchmark.judge_compute(results)
color = {"PASS": "green", "WARN": "yellow", "FAIL": "red", "SKIP": "dim"}
c.print("\n[bold cyan]Compute Verdict (vs thresholds)[/bold cyan]")
for dt, val, thr, status in judged["rows"]:
val_s = f"{val:.1f}" if isinstance(val, (int, float)) else str(val)
thr_s = f">= {thr}" if thr else "(efficiency)"
c.print(f" {dt.upper():>4}: {val_s:>8} {thr_s:<12} [{color[status]}]{status}[/{color[status]}]")
v = judged["verdict"]
c.print(f" [bold]VERDICT: [{color[v]}]{v}[/{color[v]}][/bold]")
return v
def _run_cli() -> None:
"""`python -m modules.benchmark` — run ONLY the compute-throughput benchmark."""
import argparse
from pathlib import Path
import yaml
repo_root = Path(__file__).resolve().parent.parent
parser = argparse.ArgumentParser(description="Run the compute-throughput benchmark only.")
parser.add_argument("--config", default=str(repo_root / "configs" / "default.yaml"),
help="path to config YAML (default: configs/default.yaml)")
parser.add_argument("--json", action="store_true", help="also print raw JSON of the compute results")
args = parser.parse_args()
with open(args.config) as f:
config = yaml.safe_load(f) or {}
results = Benchmark(config).run_compute_benchmark()
Benchmark.print_results(results)
Benchmark.print_compute_verdict(results)
if args.json:
print("JSON_RESULT:" + json.dumps(results["compute"]))
if __name__ == "__main__":
_run_cli()

View File

@ -11,6 +11,7 @@ GPU_NAME_PATTERNS = {
"A100": "a100", "A100": "a100",
"A800": "a800", "A800": "a800",
"H100": "h100", "H100": "h100",
"H800": "h800", # H800 = H100 SXM with NVLink halved (400 GB/s) and FP64 restricted
"H200": "h200", "H200": "h200",
"H20": "h20", # H20 / H20-3e is the China-compliance export variant, REDUCED peaks "H20": "h20", # H20 / H20-3e is the China-compliance export variant, REDUCED peaks
"B200": "b200", "B200": "b200",
@ -36,7 +37,14 @@ GPU_SPECS = {
"bf16_tflops": 990, # dense "bf16_tflops": 990, # dense
"fp8_tflops": 1979, # dense "fp8_tflops": 1979, # dense
"compute_pass_thresholds_tflops": { "compute_pass_thresholds_tflops": {
"fp32": 54, "tf32": 444, "fp16": 734, "bf16": 745, "fp8": 1400, # Recalibrated 2026-05-25 to the H100 eager-cuBLAS achievable floor (each
# threshold ~2-4% below the sustained value measured across 16 GPUs via the
# MAMF shape sweep: fp32 ~52 / tf32 ~405 / fp16 ~732-748 / bf16 ~747-758 /
# fp8 ~1248-1271). The old marketing/MAMF-derived values (fp32 54, tf32 444,
# fp16 734, bf16 745, fp8 1400) sat ON or ABOVE what PyTorch cuBLAS reaches
# on H100, so healthy cards flaked to WARN/FAIL. fp8 1400 in particular was
# an H200/rowwise-scaling figure; H100 tensorwise _scaled_mm tops out ~1310.
"fp32": 50, "tf32": 385, "fp16": 720, "bf16": 730, "fp8": 1200,
# FP64 63 / INT8 1536 — listed for documentation; benchmark module # FP64 63 / INT8 1536 — listed for documentation; benchmark module
# doesn't currently exercise these dtypes. # doesn't currently exercise these dtypes.
}, },
@ -59,10 +67,48 @@ GPU_SPECS = {
"fp16_tflops": 990, # dense "fp16_tflops": 990, # dense
"bf16_tflops": 990, # dense "bf16_tflops": 990, # dense
"fp8_tflops": 1979, # dense "fp8_tflops": 1979, # dense
# PASS thresholds aligned with H200_production_acceptance.md v2 (2026-05-21):
# calibrated against Semianalysis & stas00 MAMF — H200 shares H100 SMs so
# achievable TFLOPS in PyTorch is in the same band.
"compute_pass_thresholds_tflops": {
"fp32": 50, "tf32": 400, "fp16": 720, "bf16": 720, "fp8": 1400,
},
"tdp_watts": 700, "tdp_watts": 700,
"nvlink_gen": 4, "nvlink_gen": 4,
"nvlink_bandwidth_gbps": 900, "nvlink_bandwidth_gbps": 900,
"pcie_gen": 5, "pcie_gen": 5,
"min_driver_version": "545",
"min_cuda_version": "12.4",
},
"h800": {
# H800 = China-compliance export variant of H100 SXM5. SAME chip / SMs /
# clocks / HBM as H100 SXM5 — Tensor Core peaks (FP16 / BF16 / FP8 / TF32 /
# FP32) are identical to H100. Two restrictions vs H100:
# 1. NVLink bandwidth halved: 400 GB/s bidirectional (vs H100 900 GB/s)
# 2. FP64 throughput severely cut to ~1 TFLOPS (vs H100 34/67 TFLOPS)
# All other interfaces (PCIe Gen5, NVSwitch, HBM3 80GB @ 3.35 TB/s) match H100.
# NCCL multi-GPU thresholds MUST be downscaled because NVLink BW is halved.
"full_name": "NVIDIA H800 SXM5",
"architecture": "Hopper",
"compute_capability": 9.0,
"hbm_capacity_gb": 80,
"hbm_type": "HBM3",
"memory_bandwidth_gbps": 3350, # GB/s (3.35 TB/s) — same as H100 SXM
"fp32_tflops": 67,
"tf32_tflops": 495, # dense (same as H100)
"fp16_tflops": 990, # dense (same as H100)
"bf16_tflops": 990, # dense (same as H100)
"fp8_tflops": 1979, # dense (same as H100)
# Tensor Core peaks identical to H100, so PASS thresholds match v2 calibration.
# FP64 deliberately NOT listed — H800 is restricted to ~1 TFLOPS FP64 and
# is not a valid HPC target dtype.
"compute_pass_thresholds_tflops": {
"fp32": 50, "tf32": 400, "fp16": 720, "bf16": 720, "fp8": 1400,
},
"tdp_watts": 700,
"nvlink_gen": 4,
"nvlink_bandwidth_gbps": 400, # bidirectional — HALF of H100 (export restriction)
"pcie_gen": 5,
"min_driver_version": "535", "min_driver_version": "535",
"min_cuda_version": "12.1", "min_cuda_version": "12.1",
}, },

View File

@ -0,0 +1,533 @@
"""Multi-node NCCL benchmark wrapper for nccl-tests via mpirun."""
import json
import os
import re
import shutil
import subprocess
from datetime import datetime
from typing import Optional
from rich.console import Console
from rich.table import Table
from modules.gpu_specs import resolve_tools_dir
_TEST_ALIASES = {
"allreduce": "all_reduce_perf",
"all_reduce": "all_reduce_perf",
"all_reduce_perf": "all_reduce_perf",
"allgather": "all_gather_perf",
"all_gather": "all_gather_perf",
"all_gather_perf": "all_gather_perf",
"alltoall": "alltoall_perf",
"all_to_all": "alltoall_perf",
"alltoall_perf": "alltoall_perf",
"broadcast": "broadcast_perf",
"broadcast_perf": "broadcast_perf",
"reducescatter": "reduce_scatter_perf",
"reduce_scatter": "reduce_scatter_perf",
"reduce_scatter_perf": "reduce_scatter_perf",
"sendrecv": "sendrecv_perf",
"send_recv": "sendrecv_perf",
"sendrecv_perf": "sendrecv_perf",
}
_OP_LABELS = {
"all_reduce_perf": "allreduce",
"all_gather_perf": "allgather",
"alltoall_perf": "alltoall",
"broadcast_perf": "broadcast",
"reduce_scatter_perf": "reducescatter",
"sendrecv_perf": "sendrecv",
}
class MultiNodeNCCLTest:
"""Run cross-node NCCL tests with a PDF-style message-size sweep."""
def __init__(self, config: dict):
self.config = config
self.cfg = config.get("multinode_nccl", {}) or {}
self.tools_dir = resolve_tools_dir(config)
self.console = Console()
self.artifact_dir = os.environ.get("MULTINODE_NCCL_ARTIFACT_DIR") or self.cfg.get("artifact_dir")
def _find_mpirun(self) -> Optional[str]:
configured = self.cfg.get("mpirun_path")
if configured and os.path.isfile(str(configured)) and os.access(str(configured), os.X_OK):
return str(configured)
for cmd in ["mpirun", "mpiexec", os.path.join(self.tools_dir, "mpi", "bin", "mpirun")]:
found = shutil.which(cmd)
if found:
return found
return None
def _find_nccl_test(self, binary_name: str) -> Optional[str]:
configured = self.cfg.get("nccl_tests_dir")
candidates = []
if configured:
candidates.append(os.path.join(configured, binary_name))
candidates.append(os.path.join(self.tools_dir, "nccl-tests", "build", binary_name))
found = shutil.which(binary_name)
if found:
candidates.insert(0, found)
for path in candidates:
if path and os.path.isfile(path) and os.access(path, os.X_OK):
return path
return None
def _tests(self) -> list[str]:
configured = self.cfg.get("tests") or ["all_reduce_perf", "alltoall_perf"]
tests = []
for name in configured:
binary = _TEST_ALIASES.get(str(name).lower())
if binary and binary not in tests:
tests.append(binary)
return tests
def _hosts(self) -> list[dict]:
hosts = self.cfg.get("hosts") or []
normalized = []
for host in hosts:
if isinstance(host, str):
normalized.append({"addr": host, "slots": 8})
elif isinstance(host, dict):
normalized.append({
"name": host.get("name") or host.get("addr"),
"addr": host.get("addr") or host.get("host") or host.get("ip"),
"slots": int(host.get("slots", 8)),
})
return [h for h in normalized if h.get("addr")]
def _topologies(self) -> list[dict]:
topologies = self.cfg.get("topologies") or [{"nodes": 2, "gpus_per_node": 8}]
normalized = []
for topo in topologies:
nodes = int(topo.get("nodes", 2))
gpus_per_node = int(topo.get("gpus_per_node", topo.get("gpn", 8)))
normalized.append({
"nodes": nodes,
"gpus_per_node": gpus_per_node,
"label": topo.get("label") or f"{nodes} nodes x {gpus_per_node} GPUs",
"cuda_visible_devices": topo.get("cuda_visible_devices"),
"env": topo.get("env") or {},
"op_env": topo.get("op_env") or topo.get("test_env") or {},
"min_peak_busbw_gbps": topo.get("min_peak_busbw_gbps"),
})
return normalized
def _env_exports(self, topo: dict = None, label: str = None, binary: str = None) -> list[tuple[str, str]]:
env_cfg = {
"NCCL_DEBUG": self.cfg.get("debug", "WARN"),
"NCCL_SOCKET_IFNAME": self.cfg.get("socket_ifname"),
"NCCL_IB_GID_INDEX": self.cfg.get("ib_gid_index"),
"NCCL_IB_SL": self.cfg.get("ib_sl"),
"NCCL_IB_TC": self.cfg.get("ib_tc"),
"NCCL_IB_HCA": self.cfg.get("ib_hca"),
"NCCL_IB_TIMEOUT": self.cfg.get("ib_timeout"),
"NCCL_IB_QPS_PER_CONNECTION": self.cfg.get("qps_per_connection"),
"NCCL_MIN_NCHANNELS": self.cfg.get("min_nchannels"),
"NCCL_NET_PLUGIN": self.cfg.get("net_plugin"),
"NCCL_NVLS_ENABLE": self.cfg.get("nvls_enable"),
"NCCL_IB_SPLIT_DATA_ON_QPS": self.cfg.get("split_data_on_qps"),
}
mpi_ld_preload = self._mpi_ld_preload()
if mpi_ld_preload:
env_cfg["LD_PRELOAD"] = mpi_ld_preload
extra_ld_library_path = self._extra_ld_library_path()
if extra_ld_library_path:
existing = os.environ.get("LD_LIBRARY_PATH", "")
env_cfg["LD_LIBRARY_PATH"] = ":".join(
[extra_ld_library_path] + ([existing] if existing else [])
)
extra_env = self.cfg.get("extra_env") or {}
if isinstance(extra_env, dict):
self._merge_env(env_cfg, extra_env)
if topo:
if topo.get("cuda_visible_devices"):
env_cfg["CUDA_VISIBLE_DEVICES"] = str(topo["cuda_visible_devices"])
if isinstance(topo.get("env"), dict):
self._merge_env(env_cfg, topo["env"])
op_env = topo.get("op_env")
if isinstance(op_env, dict):
for key in (label, binary):
overrides = op_env.get(key)
if isinstance(overrides, dict):
self._merge_env(env_cfg, overrides)
return [(k, str(v)) for k, v in env_cfg.items() if v is not None]
@staticmethod
def _merge_env(env_cfg: dict, overrides: dict):
for key, value in overrides.items():
key = str(key)
if value is None:
env_cfg.pop(key, None)
else:
env_cfg[key] = str(value)
def _mpi_ld_preload(self) -> str:
preload = self.cfg.get("mpi_ld_preload")
if isinstance(preload, list):
return " ".join(str(p) for p in preload if p)
return str(preload) if preload else ""
def _runtime_env(self) -> dict:
env = os.environ.copy()
mpi_ld_preload = self._mpi_ld_preload()
if mpi_ld_preload:
env["LD_PRELOAD"] = mpi_ld_preload
extra_ld_library_path = self._extra_ld_library_path()
if extra_ld_library_path:
existing = env.get("LD_LIBRARY_PATH", "")
env["LD_LIBRARY_PATH"] = ":".join(
[extra_ld_library_path] + ([existing] if existing else [])
)
return env
def _extra_ld_library_path(self) -> str:
paths = self.cfg.get("extra_ld_library_path")
if isinstance(paths, list):
return ":".join(str(p) for p in paths if p)
return str(paths) if paths else ""
def _preflight(self, mpirun: Optional[str], tests: list[str], hosts: list[dict]) -> dict:
checks = []
checks.append({"name": "mpirun", "status": "PASS" if mpirun else "FAIL", "detail": mpirun or "not found"})
checks.append({"name": "hosts", "status": "PASS" if len(hosts) >= 2 else "FAIL", "detail": f"{len(hosts)} configured"})
for binary in tests:
path = self._find_nccl_test(binary)
checks.append({"name": binary, "status": "PASS" if path else "FAIL", "detail": path or "not found"})
if self.cfg.get("ssh_preflight", True):
user = self.cfg.get("ssh_user", "root")
for host in hosts:
target = f"{user}@{host['addr']}"
cmd = [
"ssh",
"-o", "BatchMode=yes",
"-o", "ConnectTimeout=5",
"-o", "StrictHostKeyChecking=accept-new",
target,
"hostname",
]
try:
r = subprocess.run(cmd, capture_output=True, text=True, timeout=8, env=self._runtime_env())
detail = r.stdout.strip() or r.stderr.strip()[:120]
checks.append({
"name": f"ssh {host['addr']}",
"status": "PASS" if r.returncode == 0 else "WARN",
"detail": detail,
})
except Exception as e:
checks.append({"name": f"ssh {host['addr']}", "status": "WARN", "detail": str(e)})
return {
"checks": checks,
"passed": all(c["status"] == "PASS" for c in checks if not c["name"].startswith("ssh ")),
}
def run(self) -> dict:
mpirun = self._find_mpirun()
tests = self._tests()
hosts = self._hosts()
topologies = self._topologies()
preflight = self._preflight(mpirun, tests, hosts)
if not preflight["passed"]:
return {
"passed": False,
"source": "nccl-tests-mpirun",
"mode": self.cfg.get("mode", "sweep"),
"hosts": hosts,
"preflight": preflight,
"tests": {},
"error": "multinode NCCL preflight failed",
"timestamp": datetime.now().isoformat(),
}
results = {}
for binary in tests:
label = _OP_LABELS[binary]
binary_path = self._find_nccl_test(binary)
op_results = []
for topo in topologies:
op_results.append(self._run_topology(mpirun, binary_path, label, hosts, topo))
results[label] = {"binary": binary_path, "topologies": op_results}
passed = all(
topo.get("status") == "PASS"
for op in results.values()
for topo in op.get("topologies", [])
)
return {
"passed": passed,
"source": "nccl-tests-mpirun",
"mode": self.cfg.get("mode", "sweep"),
"hosts": hosts,
"preflight": preflight,
"tests": results,
"artifact_dir": self.artifact_dir,
"timestamp": datetime.now().isoformat(),
}
def _run_topology(self, mpirun: str, binary: str, label: str, hosts: list[dict], topo: dict) -> dict:
nodes = topo["nodes"]
gpus_per_node = topo["gpus_per_node"]
selected_hosts = hosts[:nodes]
host_arg = ",".join(f"{h['addr']}:{gpus_per_node}" for h in selected_hosts)
ranks = nodes * gpus_per_node
cmd = [
mpirun,
"--allow-run-as-root",
"--mca", "btl_openib_warn_no_device_params_found", "0",
"--mca", "btl_tcp_if_include", str(self.cfg.get("socket_ifname", "bond0")),
"--mca", "oob_tcp_if_include", str(self.cfg.get("oob_tcp_ifname", self.cfg.get("socket_ifname", "bond0"))),
"-H", host_arg,
"--map-by", f"ppr:{gpus_per_node}:node",
"-np", str(ranks),
]
plm_rsh_args = self.cfg.get("plm_rsh_args")
if plm_rsh_args:
cmd.extend(["--mca", "plm_rsh_args", str(plm_rsh_args)])
for key, value in self._env_exports(topo=topo, label=label, binary=os.path.basename(binary)):
cmd.extend(["-x", f"{key}={value}"])
cmd.extend([
binary,
"-b", str(self.cfg.get("begin_size", "1k")),
"-e", str(self.cfg.get("end_size", "16g")),
"-g", str(self.cfg.get("gpus_per_rank", 1)),
"-f", str(self.cfg.get("step_factor", 2)),
"-w", str(self.cfg.get("warmup_iters", 10)),
])
if self.cfg.get("iters") is not None:
cmd.extend(["-n", str(self.cfg["iters"])])
timeout = int(self.cfg.get("timeout_sec", 1800))
started = datetime.now().isoformat()
try:
r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, env=self._runtime_env())
except subprocess.TimeoutExpired:
result = {
"label": topo["label"],
"nodes": nodes,
"gpus_per_node": gpus_per_node,
"ranks": ranks,
"hosts": selected_hosts,
"command": " ".join(cmd),
"status": "FAIL",
"error": f"timeout after {timeout}s",
"started_at": started,
}
self._write_artifacts(label, topo, result, "", "")
return result
parsed = self._parse_nccl_output(r.stdout)
net_diag = self._parse_network_diagnostics(r.stdout + "\n" + r.stderr)
threshold = self._threshold_for(label, topo)
wrong = sum(row.get("wrong", 0) for row in parsed["by_size"])
has_bw = parsed["peak_busbw_gbps"] > 0
status = "PASS" if r.returncode == 0 and has_bw and wrong == 0 and parsed["peak_busbw_gbps"] >= threshold else "FAIL"
result = {
"label": topo["label"],
"nodes": nodes,
"gpus_per_node": gpus_per_node,
"ranks": ranks,
"hosts": selected_hosts,
"cuda_visible_devices": topo.get("cuda_visible_devices"),
"command": " ".join(cmd),
"returncode": r.returncode,
"status": status,
"peak_busbw_gbps": parsed["peak_busbw_gbps"],
"peak_algbw_gbps": parsed["peak_algbw_gbps"],
"peak_size": parsed["peak_size"],
"avg_busbw_gbps": parsed["avg_busbw_gbps"],
"min_required_gbps": threshold,
"wrong_count": wrong,
"network": net_diag,
"by_size": parsed["by_size"],
"stderr_tail": r.stderr[-1200:],
"stdout_tail": r.stdout[-1200:],
"started_at": started,
"finished_at": datetime.now().isoformat(),
}
self._write_artifacts(label, topo, result, r.stdout, r.stderr)
return result
def _write_artifacts(self, label: str, topo: dict, result: dict, stdout: str, stderr: str):
if not self.artifact_dir:
return
os.makedirs(self.artifact_dir, exist_ok=True)
prefix = _safe_name(f"{label}_{topo.get('nodes')}x{topo.get('gpus_per_node')}_{topo.get('label')}")
base = os.path.join(self.artifact_dir, prefix)
with open(base + ".cmd.txt", "w") as f:
f.write(result.get("command", ""))
f.write("\n")
with open(base + ".stdout.txt", "w") as f:
f.write(stdout)
with open(base + ".stderr.txt", "w") as f:
f.write(stderr)
artifact_result = {k: v for k, v in result.items() if k not in ("stdout_tail", "stderr_tail")}
with open(base + ".json", "w") as f:
json.dump(artifact_result, f, indent=2, default=str)
result["artifact_prefix"] = base
def _threshold_for(self, label: str, topo: dict = None) -> float:
if topo and topo.get("min_peak_busbw_gbps") is not None:
topo_thresholds = topo.get("min_peak_busbw_gbps")
if isinstance(topo_thresholds, dict):
return float(topo_thresholds.get(label, 0) or 0)
return float(topo_thresholds or 0)
thresholds = self.cfg.get("min_peak_busbw_gbps") or {}
if isinstance(thresholds, dict):
op_threshold = thresholds.get(label, 0)
if isinstance(op_threshold, dict):
keys = []
if topo:
keys.extend([
topo.get("label"),
f"{topo.get('nodes')}x{topo.get('gpus_per_node')}",
f"{topo.get('nodes')} nodes x {topo.get('gpus_per_node')} GPUs",
str(topo.get("gpus_per_node")),
])
keys.append("default")
for key in keys:
if key in op_threshold:
return float(op_threshold.get(key) or 0)
return 0.0
return float(op_threshold or 0)
return float(thresholds or 0)
@staticmethod
def _parse_nccl_output(stdout: str) -> dict:
rows = []
avg_bus = 0.0
for line in stdout.splitlines():
stripped = line.strip()
if not stripped:
continue
avg_match = re.search(r"Avg bus bandwidth\s*:\s*([0-9.]+)", stripped)
if avg_match:
avg_bus = float(avg_match.group(1))
continue
if stripped.startswith("#"):
continue
parts = stripped.split()
if len(parts) < 9:
continue
try:
size_bytes = int(parts[0])
time_us = float(parts[5])
algbw = float(parts[6])
busbw = float(parts[7])
wrong = int(parts[8])
except (ValueError, IndexError):
continue
rows.append({
"size_bytes": size_bytes,
"size": _format_size(size_bytes),
"time_us": time_us,
"algbw_gbps": algbw,
"busbw_gbps": busbw,
"wrong": wrong,
})
peak_row = max(rows, key=lambda r: r["busbw_gbps"], default={})
return {
"peak_busbw_gbps": round(float(peak_row.get("busbw_gbps", 0)), 2),
"peak_algbw_gbps": round(float(peak_row.get("algbw_gbps", 0)), 2),
"peak_size": peak_row.get("size", ""),
"avg_busbw_gbps": round(avg_bus, 2),
"by_size": rows,
}
@staticmethod
def _parse_network_diagnostics(output: str) -> dict:
networks = sorted(set(re.findall(r"Using network (\S+)", output)))
gdr_enabled = sorted(set(re.findall(r"GPU Direct RDMA Enabled for HCA \d+ '([^']+)'", output)))
gdr_disabled = sorted(set(re.findall(r"GPU Direct RDMA Disabled for HCA \d+ '([^']+)'", output)))
ib_using = []
for line in output.splitlines():
if "NET/IB : Using" in line:
text = line.split("NET/IB : ", 1)[-1].strip()
if text not in ib_using:
ib_using.append(text)
if gdr_disabled:
gdr_state = "DISABLED"
elif gdr_enabled or "/GDRDMA" in output:
gdr_state = "ENABLED"
elif networks:
gdr_state = "NOT_DISABLED_IN_LOG"
else:
gdr_state = "UNKNOWN"
return {
"networks": networks,
"ib_using": ib_using[:8],
"gdr_enabled_hcas": gdr_enabled,
"gdr_disabled_hcas": gdr_disabled,
"gpu_direct_rdma": gdr_state,
}
@staticmethod
def print_results(results: dict, console: Console = None):
c = console or Console()
if results.get("error"):
c.print(f"[bold red]Multi-node NCCL failed: {results['error']}[/bold red]")
else:
c.print("[bold green]Multi-node NCCL complete[/bold green]" if results.get("passed") else "[bold red]Multi-node NCCL failed[/bold red]")
preflight = results.get("preflight", {})
if preflight.get("checks"):
table = Table(title="Preflight")
table.add_column("Check")
table.add_column("Status")
table.add_column("Detail")
for check in preflight["checks"]:
table.add_row(check["name"], check["status"], str(check.get("detail", "")))
c.print(table)
for op, data in (results.get("tests") or {}).items():
table = Table(title=f"Multi-node NCCL {op}")
table.add_column("Topology")
table.add_column("Peak Bus BW")
table.add_column("Peak Size")
table.add_column("Threshold")
table.add_column("Status")
for topo in data.get("topologies", []):
table.add_row(
topo.get("label", ""),
f"{topo.get('peak_busbw_gbps', 0):.2f} GB/s",
str(topo.get("peak_size", "")),
f">= {_format_gbps(topo.get('min_required_gbps', 0))} GB/s" if topo.get("min_required_gbps") else "-",
topo.get("status", "?"),
)
c.print(table)
def _format_size(size_bytes: int) -> str:
units = [("G", 1024 ** 3), ("M", 1024 ** 2), ("K", 1024)]
for suffix, factor in units:
if size_bytes >= factor and size_bytes % factor == 0:
return f"{size_bytes // factor}{suffix}"
return str(size_bytes)
def _format_gbps(value) -> str:
try:
numeric = float(value)
except (TypeError, ValueError):
return str(value)
if numeric.is_integer():
return f"{numeric:.0f}"
return f"{numeric:.2f}"
def _safe_name(value: str) -> str:
text = re.sub(r"[^A-Za-z0-9_.-]+", "_", value.strip())
text = re.sub(r"_+", "_", text).strip("_")
return text[:160] or "case"