499 lines
20 KiB
Python
499 lines
20 KiB
Python
"""NCCL multi-GPU communication test — wraps official nccl-tests."""
|
|
|
|
import glob
|
|
import os
|
|
import re
|
|
import shutil
|
|
import subprocess
|
|
import statistics
|
|
import sys
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
from rich.console import Console
|
|
from rich.table import Table
|
|
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
|
|
|
|
from modules.gpu_specs import detect_gpu_type, get_gpu_specs, resolve_tools_dir
|
|
|
|
# torch is optional here: this module only uses it to count CUDA devices.
# TORCH_AVAILABLE is True only when torch imports AND reports a usable CUDA
# runtime; every use of `torch` below must be guarded by this flag.
TORCH_AVAILABLE = False
try:
    import torch
except (ImportError, OSError):
    # ImportError: torch is not installed.
    # OSError: presumably raised when torch is installed but one of its
    # shared libraries cannot be loaded — treated the same as "not available".
    torch = None  # keep the name bound so accidental use fails clearly
else:
    TORCH_AVAILABLE = bool(torch.cuda.is_available())
|
|
|
|
|
|
# Per-operation bus-bandwidth floors, expressed as a fraction of the GPU's
# NVLink bidirectional bandwidth (looked up by lower-cased operation label).
# Values are aligned with the H100 production acceptance criteria
# (acceptance doc §5).  AllToAll runs ~10-20% lower than AllReduce on 8-GPU
# NVSwitch, so its fraction is set lower; broadcast/sendrecv sit in between.
_OP_BW_FRACTIONS = {
    "allreduce": 0.45,
    "allgather": 0.45,
    "reducescatter": 0.45,
    "broadcast": 0.40,
    "sendrecv": 0.40,
    "alltoall": 0.35,
}
|
|
|
|
|
|
class NCCLTest:
    """Benchmark NCCL collectives with nccl-tests and grade the bus bandwidth.

    Every enabled collective is run over a matrix of message sizes and
    repeats.  The measured bus bandwidth is compared with a per-operation
    floor.  If no nccl-tests binary yields a measurement at all, a
    torchrun-based connectivity smoke test is run instead; that fallback never
    counts as a pass because it measures no bandwidth.
    """

    # Collectives exercised by the torchrun smoke script, in execution order.
    _SMOKE_OPS = ("allreduce", "broadcast", "allgather", "reducescatter", "alltoall")
    # One result line of the smoke script: "<op>:ok" or "<op>:fail:<detail>".
    _SMOKE_LINE = re.compile(r"^(?P<op>[a-z_]+):(?P<result>ok|fail)(?::(?P<detail>.*))?$")

    def __init__(self, config: dict):
        """Store *config* and resolve the tools directory and GPU specs."""
        self.config = config
        self.console = Console()
        self.nccl_cfg = config.get("nccl", {})
        self.tools_dir = resolve_tools_dir(config)
        self.gpu_type = detect_gpu_type()
        self.specs = get_gpu_specs(self.gpu_type)

    def _find_nccl_test(self, name: str) -> Optional[str]:
        """Return the path of the nccl-tests binary *name*, or None.

        Search order: ``PATH``, then ``<tools_dir>/nccl-tests/build``, then any
        executable file of that name below ``<tools_dir>/nccl-tests``.
        """
        on_path = shutil.which(name)
        if on_path:
            return on_path

        tests_root = os.path.join(self.tools_dir, "nccl-tests")
        local = os.path.join(tests_root, "build", name)
        if os.path.isfile(local) and os.access(local, os.X_OK):
            return local

        # Escape the root so glob metacharacters in tools_dir are literal;
        # sort so the choice is deterministic when several builds exist.
        pattern = os.path.join(glob.escape(tests_root), "**", name)
        for match in sorted(glob.glob(pattern, recursive=True)):
            if os.path.isfile(match) and os.access(match, os.X_OK):
                return match
        return None

    def _find_mpirun(self) -> Optional[str]:
        """Return the first MPI launcher found (PATH first, then tools_dir), or None."""
        bundled = os.path.join(self.tools_dir, "mpi", "bin", "mpirun")
        for cmd in ("mpirun", "mpiexec", bundled):
            found = shutil.which(cmd)
            if found:
                return found
        return None

    def _message_sizes(self) -> list[str]:
        """Return the configured message sizes as a list of strings.

        Accepts a list (items may be ints, e.g. from YAML) or a single scalar;
        falls back to ``["1M", "256M", "2G"]`` when unset or empty.
        """
        sizes = self.nccl_cfg.get("message_sizes") or ["1M", "256M", "2G"]
        if isinstance(sizes, (str, int)):
            # A bare scalar would otherwise be iterated character by character.
            sizes = [sizes]
        # subprocess arguments must be strings.
        return [str(s) for s in sizes]

    def _repeats(self) -> int:
        """Return how many times each message size is run (default 3)."""
        return int(self.nccl_cfg.get("repeats", 3))

    def _max_stddev_pct(self) -> float:
        """Return the maximum allowed run-to-run std-dev in percent (default 3)."""
        return float(self.nccl_cfg.get("max_stddev_pct", 3))

    def _runtime_env(self) -> dict:
        """Build the child-process environment.

        Copies ``os.environ``, forces ``NCCL_DEBUG=WARN`` and prepends every
        existing NCCL library directory (``NCCL_HOME``/config ``nccl_home``,
        ``nvidia/nccl/lib`` under each ``sys.path`` entry and under the
        interpreter's venv) to ``LD_LIBRARY_PATH``.
        """
        env = {**os.environ, "NCCL_DEBUG": "WARN"}

        candidates = []
        nccl_home = env.get("NCCL_HOME") or self.nccl_cfg.get("nccl_home")
        if nccl_home:
            candidates.append(os.path.join(str(nccl_home), "lib"))
        candidates.extend(os.path.join(p, "nvidia", "nccl", "lib") for p in sys.path)
        venv_root = os.path.dirname(os.path.dirname(sys.executable))
        candidates.extend(
            glob.glob(os.path.join(venv_root, "lib", "python*", "site-packages", "nvidia", "nccl", "lib"))
        )

        # dict.fromkeys de-duplicates while keeping first-seen order.
        lib_dirs = [d for d in dict.fromkeys(candidates) if d and os.path.isdir(d)]
        if lib_dirs:
            existing = env.get("LD_LIBRARY_PATH", "")
            env["LD_LIBRARY_PATH"] = ":".join(lib_dirs + ([existing] if existing else []))
        return env

    @staticmethod
    def _has_measurement(result: dict) -> bool:
        """Return True when *result* comes from a run that actually measured something."""
        return result.get("status") not in ("SKIP", None) and "error" not in result

    def run(self) -> dict:
        """Run every enabled collective and return the aggregated report.

        Returns:
            ``{"error": "need_at_least_2_gpus", ...}`` when fewer than two GPUs
            are visible; the torchrun fallback report when no nccl-tests binary
            produced a measurement; otherwise a report with ``passed``,
            ``tests`` (one entry per requested collective), thresholds, GPU
            count, timestamp and detected GPU type.

        ``passed`` is True only if every requested collective has status PASS.
        Collectives that were skipped or errored are recorded in ``tests`` and
        fail the run: an unmeasured collective must not satisfy acceptance.
        """
        gpu_count = torch.cuda.device_count() if TORCH_AVAILABLE else 0
        if gpu_count < 2:
            self.console.print(f"[yellow]NCCL test requires at least 2 GPUs (found {gpu_count})[/yellow]")
            return {"error": "need_at_least_2_gpus", "gpu_count": gpu_count}

        # (config key, enabled by default, nccl-tests binary, display label)
        toggles = (
            ("test_allreduce", True, "all_reduce_perf", "AllReduce"),
            ("test_alltoall", True, "alltoall_perf", "AllToAll"),
            ("test_broadcast", True, "broadcast_perf", "Broadcast"),
            ("test_reduce_scatter", False, "reduce_scatter_perf", "ReduceScatter"),
            ("test_allgather", False, "all_gather_perf", "AllGather"),
            ("test_sendrecv", False, "sendrecv_perf", "SendRecv"),
        )
        tests = [
            (binary, label)
            for key, default, binary, label in toggles
            if self.nccl_cfg.get(key, default)
        ]

        nvlink_bw = self.specs.get("nvlink_bandwidth_gbps", 0)
        # User-provided override applies uniformly across all ops; otherwise
        # each op gets its own threshold from _OP_BW_FRACTIONS.
        user_override = self.nccl_cfg.get("min_bandwidth_gbps")

        def threshold_for(label: str) -> float:
            if user_override:
                return float(user_override)
            if nvlink_bw <= 0:
                return 10.0  # conservative floor
            frac = _OP_BW_FRACTIONS.get(label.lower(), 0.45)
            return round(nvlink_bw * frac)

        if self.gpu_type == "unknown":
            self.console.print("[yellow]Unknown GPU — using conservative bandwidth thresholds[/yellow]")

        # Strategy: try the nccl-tests binary directly (single-node, -g N),
        # then mpirun, then the torchrun fallback.
        results = {}
        any_binary_worked = False
        mpirun = None
        mpirun_looked_up = False

        with Progress(
            SpinnerColumn(), TextColumn("[progress.description]{task.description}"),
            TimeElapsedColumn(), console=self.console,
        ) as progress:
            task = progress.add_task("NCCL tests...", total=len(tests))

            for binary, label in tests:
                progress.update(task, description=f"NCCL {label}...")
                op_min_bw = threshold_for(label)
                result = self._run_one_nccl_test_direct(binary, label, gpu_count, op_min_bw)

                # Retry through mpirun only if the binary exists (SKIP means
                # it was not found, so mpirun could not find it either).
                if not self._has_measurement(result) and result.get("status") != "SKIP":
                    if not mpirun_looked_up:
                        mpirun = self._find_mpirun()
                        mpirun_looked_up = True
                    if mpirun:
                        result = self._run_one_nccl_test_mpirun(
                            binary, label, gpu_count, mpirun, op_min_bw
                        )

                if self._has_measurement(result):
                    any_binary_worked = True
                # Always record the last attempt so a skipped or errored
                # collective cannot silently vanish from the verdict.
                results[label.lower()] = result
                progress.advance(task)

        if not any_binary_worked:
            self.console.print("[yellow]nccl-tests binaries failed, falling back to torchrun[/yellow]")
            return self._run_torchrun_fallback(gpu_count)

        all_passed = bool(results) and all(
            isinstance(r, dict) and r.get("status") == "PASS" for r in results.values()
        )

        return {
            "passed": all_passed,
            "source": "nccl-tests",
            "min_bandwidth_gbps": {
                lbl.lower(): threshold_for(lbl) for _, lbl in tests
            },
            "tests": results,
            "gpu_count": gpu_count,
            "timestamp": datetime.now().isoformat(),
            "detected_gpu_type": self.gpu_type,
        }

    def _run_one_nccl_test_direct(self, binary_name: str, label: str,
                                  gpu_count: int, min_bw: float) -> dict:
        """Run nccl-tests binary directly with -g N (no mpirun needed for single-node)."""
        binary = self._find_nccl_test(binary_name)
        if not binary:
            return {"status": "SKIP", "error": f"{binary_name} not found"}

        return self._run_nccl_matrix([binary, "-g", str(gpu_count)], min_bw)

    def _run_one_nccl_test_mpirun(self, binary_name: str, label: str,
                                  gpu_count: int, mpirun: str, min_bw: float) -> dict:
        """Run nccl-tests via mpirun with one process per GPU (``-g 1``)."""
        binary = self._find_nccl_test(binary_name)
        if not binary:
            return {"status": "SKIP", "error": f"{binary_name} not found"}

        visible = ",".join(str(i) for i in range(gpu_count))
        # NOTE(review): "-x" and "--allow-run-as-root" look Open MPI specific;
        # confirm behaviour if _find_mpirun() resolves to another MPI's mpiexec.
        cmd = [
            mpirun,
            "-np", str(gpu_count),
            "--allow-run-as-root",
            "-x", "NCCL_DEBUG=WARN",
            "-x", "CUDA_VISIBLE_DEVICES=" + visible,
            binary,
            "-g", "1",
        ]
        return self._run_nccl_matrix(cmd, min_bw)

    @staticmethod
    def _summarize_size(size: str, runs: list, min_bw: float, max_stddev_pct: float) -> dict:
        """Fold the bus-bandwidth values of all repeats of one size into a verdict.

        A size passes when its worst run reaches *min_bw* and the population
        standard deviation (as a percentage of the mean) is within
        *max_stddev_pct*.  No successful runs means FAIL.
        """
        if not runs:
            return {"size": size, "status": "FAIL", "runs_busbw_gbps": []}

        worst = min(runs)
        mean = sum(runs) / len(runs)
        std_pct = (statistics.pstdev(runs) / mean * 100) if len(runs) > 1 and mean else 0
        ok = worst >= min_bw and std_pct <= max_stddev_pct
        return {
            "size": size,
            "runs_busbw_gbps": [round(v, 1) for v in runs],
            "worst_busbw_gbps": round(worst, 1),
            "mean_busbw_gbps": round(mean, 1),
            "stddev_pct": round(std_pct, 2),
            "status": "PASS" if ok else "FAIL",
        }

    def _run_nccl_matrix(self, base_cmd: list[str], min_bw: float) -> dict:
        """Run *base_cmd* for every message size × repeat and grade the result.

        Returns a dict with ``status`` (PASS/FAIL), best/worst bus bandwidth,
        the per-size summaries and the list of failed invocations.  A timeout
        or any unexpected exception yields ``{"status": "FAIL", "error": ...}``.
        """
        size_results = []
        failures = []
        env = self._runtime_env()

        try:
            repeats = self._repeats()
            max_stddev = self._max_stddev_pct()
            for size in self._message_sizes():
                # -b == -e pins the run to exactly one message size.
                cmd = [*base_cmd, "-b", size, "-e", size, "-f", "2", "-w", "5", "-n", "20"]
                runs = []
                for _ in range(repeats):
                    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=300, env=env)
                    combined = proc.stdout + proc.stderr
                    if "CUDA driver version is insufficient" in combined or "Test NCCL failure" in combined:
                        failures.append({"size": size, "error": "NCCL/CUDA/library failure"})
                        continue
                    if proc.returncode != 0:
                        failures.append({"size": size, "error": proc.stderr[:300]})
                        continue
                    parsed = self._parse_nccl_output(proc.stdout, min_bw)
                    runs.append(parsed.get("best_busbw_gbps", 0))
                size_results.append(self._summarize_size(size, runs, min_bw, max_stddev))
        except subprocess.TimeoutExpired:
            return {"status": "FAIL", "error": "timeout"}
        except Exception as e:
            # Boundary handler: a broken launcher or bad config must be
            # reported as a failed test, not crash the whole suite.
            return {"status": "FAIL", "error": str(e)}

        best_bus = max((r.get("mean_busbw_gbps", 0) for r in size_results), default=0)
        worst_bus = min(
            (r.get("worst_busbw_gbps", 0) for r in size_results if r.get("runs_busbw_gbps")),
            default=0,
        )
        # Any failed invocation fails the collective, even if other repeats passed.
        passed = bool(size_results) and all(r.get("status") == "PASS" for r in size_results) and not failures
        return {
            "status": "PASS" if passed else "FAIL",
            "best_busbw_gbps": round(best_bus, 1),
            "worst_busbw_gbps": round(worst_bus, 1),
            "min_required_gbps": min_bw,
            "max_stddev_pct": self._max_stddev_pct(),
            "by_size": size_results,
            "failures": failures,
        }

    @staticmethod
    def _parse_nccl_output(stdout: str, min_bw: float) -> dict:
        """Parse nccl-tests tabular output and extract bandwidth results.

        Only the out-of-place columns are read.  ``status`` is PASS when the
        best bus bandwidth reaches *min_bw*, otherwise WARN; ``by_size`` holds
        at most the last five parsed rows.
        """
        best_algbw = 0.0
        best_busbw = 0.0
        rows = []

        for raw in stdout.split("\n"):
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            parts = line.split()
            # nccl-tests data lines:
            # size count type redop root time algbw busbw #wrong [time algbw busbw #wrong]
            if len(parts) < 9:
                continue
            try:
                size = int(parts[0])
                # parts[2] is the dtype string ('float'/'int32'/etc.), not a number.
                # out-of-place columns: time=parts[5], algbw=parts[6], busbw=parts[7]
                time_us = float(parts[5])
                algbw = float(parts[6])
                busbw = float(parts[7])
            except (ValueError, IndexError):
                continue
            rows.append({
                "size": size,
                "time_us": time_us,
                "algbw_gbps": algbw,
                "busbw_gbps": busbw,
            })
            best_busbw = max(best_busbw, busbw)
            best_algbw = max(best_algbw, algbw)

        return {
            "status": "PASS" if best_busbw >= min_bw else "WARN",
            "best_algbw_gbps": round(best_algbw, 1),
            "best_busbw_gbps": round(best_busbw, 1),
            "min_required_gbps": min_bw,
            "by_size": rows[-5:],
        }

    @classmethod
    def _parse_torchrun_output(cls, stdout: str) -> dict:
        """Turn the smoke script's ``op:ok`` / ``op:fail:detail`` lines into results.

        Lines that are not a result for a known operation (launcher noise,
        multi-line error text) are ignored.  An operation that reported
        nothing is recorded as FAIL so a crashed script cannot look healthy.
        """
        tests = {}
        for raw in stdout.split("\n"):
            m = cls._SMOKE_LINE.match(raw.strip())
            if not m or m.group("op") not in cls._SMOKE_OPS:
                continue
            ok = m.group("result") == "ok"
            tests[m.group("op")] = {
                "status": "PASS" if ok else "FAIL",
                "error": None if ok else m.group("detail"),
            }
        for op in cls._SMOKE_OPS:
            tests.setdefault(op, {"status": "FAIL", "error": "no result reported"})
        return tests

    def _run_torchrun_fallback(self, gpu_count: int) -> dict:
        """Basic NCCL connectivity test via torchrun — verifies NCCL works but does not benchmark performance."""
        self.console.print("[yellow]nccl-tests not available, running basic NCCL connectivity check[/yellow]")

        # The tensor length is a multiple of the world size so x.chunk(ws)
        # yields equal chunks for reduce_scatter / all_to_all on any GPU count.
        code = f"""
import torch, torch.distributed as dist, os
os.environ.setdefault("MASTER_ADDR","127.0.0.1")
os.environ.setdefault("MASTER_PORT","29500")
rank=int(os.environ.get("LOCAL_RANK",0))
ws={gpu_count}
dist.init_process_group("nccl",rank=rank,world_size=ws)
torch.cuda.set_device(rank)

x=torch.randn((1024*1024//ws)*ws,device=f"cuda:{{rank}}",dtype=torch.float32)

# Test AllReduce
try:
    dist.all_reduce(x.clone())
    if rank==0: print("allreduce:ok")
except Exception as e:
    if rank==0: print(f"allreduce:fail:{{e}}")

# Test Broadcast
try:
    dist.broadcast(x.clone(),src=0)
    if rank==0: print("broadcast:ok")
except Exception as e:
    if rank==0: print(f"broadcast:fail:{{e}}")

# Test AllGather
try:
    tensor_list=[torch.empty_like(x) for _ in range(ws)]
    dist.all_gather(tensor_list,x.clone())
    if rank==0: print("allgather:ok")
except Exception as e:
    if rank==0: print(f"allgather:fail:{{e}}")

# Test ReduceScatter
try:
    chunks=list(x.chunk(ws))
    output=torch.empty_like(chunks[0])
    dist.reduce_scatter(output,chunks)
    if rank==0: print("reducescatter:ok")
except Exception as e:
    if rank==0: print(f"reducescatter:fail:{{e}}")

# Test AllToAll
try:
    chunks=list(x.chunk(ws))
    output_list=[torch.empty_like(c) for c in chunks]
    dist.all_to_all(output_list,chunks)
    if rank==0: print("alltoall:ok")
except Exception as e:
    if rank==0: print(f"alltoall:fail:{{e}}")

dist.destroy_process_group()
"""
        import tempfile

        script_path = None
        try:
            with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp:
                tmp.write(code)
                script_path = tmp.name

            # Prefer torchrun from the same venv as the running Python
            venv_torchrun = os.path.join(os.path.dirname(sys.executable), "torchrun")
            torchrun_cmd = venv_torchrun if os.path.isfile(venv_torchrun) else "torchrun"

            r = subprocess.run(
                [torchrun_cmd, f"--nproc_per_node={gpu_count}", script_path],
                capture_output=True, text=True, timeout=120,
                env=self._runtime_env(),
            )

            tests = self._parse_torchrun_output(r.stdout)
            # A non-zero exit means some rank failed even if rank 0 printed ok.
            all_passed = r.returncode == 0 and all(
                t["status"] == "PASS" for t in tests.values()
            )

            return {
                # torchrun fallback is a functional smoke only. It never proves
                # production bus bandwidth, so it must not satisfy acceptance.
                "passed": False,
                "functional_passed": all_passed,
                "source": "torchrun_fallback",
                "tests": tests,
                "gpu_count": gpu_count,
                "error": None if all_passed else "torchrun functional NCCL smoke failed",
                "acceptance_gap": "nccl-tests bus bandwidth was not measured",
            }
        except Exception as e:
            return {"passed": False, "source": "torchrun_fallback", "error": str(e)}
        finally:
            # Remove the generated script on every path (success, timeout, error).
            if script_path:
                try:
                    os.unlink(script_path)
                except OSError:
                    pass  # best-effort cleanup; nothing useful to report

    @staticmethod
    def print_results(results: dict, console: "Optional[Console]" = None):
        """Pretty-print a report produced by :meth:`run`.

        Reports that carry an error and no per-operation results print just
        the error.  Torchrun-fallback reports print the connectivity verdict;
        nccl-tests reports print a per-operation bandwidth table.
        """
        c = console or Console()
        # The torchrun report always has an "error" key (None on success), so
        # test the value, and still show per-op details when they exist.
        if results.get("error") and not results.get("tests"):
            c.print(f"[bold red]Error: {results['error']}[/bold red]")
            return

        passed = results.get("passed", False)
        source = results.get("source", "unknown")
        tests = results.get("tests", {})

        if source == "torchrun_fallback":
            # Connectivity check mode
            functional = results.get("functional_passed", passed)
            verdict = "[bold yellow]⚠ NCCL bus BW NOT VERIFIED[/bold yellow]" if functional else "[bold red]✗ NCCL Connectivity FAILED[/bold red]"
            c.print(f"{verdict} [dim](basic check via torchrun)[/dim]")

            if tests:
                c.print("\n[dim]Operations tested:[/dim]")
                for op_name, result in tests.items():
                    if not isinstance(result, dict):
                        continue
                    status = result.get("status", "FAIL")
                    s_color = "green" if status == "PASS" else "red"
                    error = result.get("error")
                    if error:
                        c.print(f"  [{s_color}]{op_name}[/{s_color}] — {error}")
                    else:
                        c.print(f"  [{s_color}]{op_name}[/{s_color}]")

            c.print("\n[yellow]Note: functional connectivity test only (no bus bandwidth data; acceptance FAIL)[/yellow]")
            return

        # nccl-tests mode
        verdict = "[bold green]✓ NCCL tests PASSED[/bold green]" if passed else "[bold yellow]⚠ NCCL tests WARNING[/bold yellow]"
        c.print(f"{verdict} [dim](via {source})[/dim]")

        for op_name, result in tests.items():
            if not isinstance(result, dict):
                continue
            c.print(f"\n[bold cyan]{op_name.upper()}[/bold cyan]")
            status = result.get("status", "FAIL")
            s_color = "green" if status == "PASS" else ("yellow" if status == "WARN" else "red")
            c.print(f"  Status: [{s_color}]{status}[/{s_color}]  "
                    f"Best bus BW: {result.get('best_busbw_gbps', 'N/A')} GB/s  "
                    f"(min: {result.get('min_required_gbps', 'N/A')} GB/s)")
            if result.get("error"):
                # Skipped / errored collectives carry the reason instead of data.
                c.print(f"  [red]Error: {result['error']}[/red]")

            by_size = result.get("by_size", [])
            if by_size:
                t = Table(box=None, padding=(0, 1))
                t.add_column("Size", style="bold", justify="right")
                t.add_column("Worst Bus BW", justify="right")
                t.add_column("Mean Bus BW", justify="right")
                t.add_column("StdDev", justify="right")
                t.add_column("Status", justify="right")
                for r in by_size:
                    t.add_row(
                        str(r.get("size", "")),
                        f"{r.get('worst_busbw_gbps', 0):.1f}",
                        f"{r.get('mean_busbw_gbps', 0):.1f}",
                        f"{r.get('stddev_pct', 0):.2f}%",
                        r.get("status", "?"),
                    )
                c.print(t)
|