"""NCCL multi-GPU communication test — wraps official nccl-tests.""" import glob import os import re import shutil import subprocess import statistics import sys from datetime import datetime from typing import Optional from rich.console import Console from rich.table import Table from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn from modules.gpu_specs import detect_gpu_type, get_gpu_specs, resolve_tools_dir TORCH_AVAILABLE = False try: import torch if torch.cuda.is_available(): TORCH_AVAILABLE = True except ImportError: pass # Per-operation bandwidth thresholds, as a fraction of NVLink bidirectional BW. # Values aligned with the H100 production acceptance criteria (acceptance doc §5). # AllToAll runs ~10-20% lower than AllReduce on 8-GPU NVSwitch, so its fraction is # set lower; broadcast/sendrecv sit between. _OP_BW_FRACTIONS = { "allreduce": 0.45, "allgather": 0.45, "reducescatter": 0.45, "broadcast": 0.40, "sendrecv": 0.40, "alltoall": 0.35, } class NCCLTest: def __init__(self, config: dict): self.config = config self.console = Console() self.nccl_cfg = config.get("nccl", {}) self.tools_dir = resolve_tools_dir(config) self.gpu_type = detect_gpu_type() self.specs = get_gpu_specs(self.gpu_type) def _find_nccl_test(self, name: str) -> Optional[str]: p = shutil.which(name) if p: return p build_dir = os.path.join(self.tools_dir, "nccl-tests", "build") local = os.path.join(build_dir, name) if os.path.isfile(local) and os.access(local, shutil.os.X_OK): return local matches = glob.glob(os.path.join(self.tools_dir, "nccl-tests", "**", name), recursive=True) for m in matches: if os.access(m, shutil.os.X_OK): return m return None def _find_mpirun(self) -> Optional[str]: for cmd in ["mpirun", "mpiexec", os.path.join(self.tools_dir, "mpi", "bin", "mpirun")]: p = shutil.which(cmd) if p: return p return None def _message_sizes(self) -> list[str]: return list(self.nccl_cfg.get("message_sizes") or ["1M", "256M", "2G"]) def _repeats(self) -> int: return int(self.nccl_cfg.get("repeats", 3)) def _max_stddev_pct(self) -> float: return float(self.nccl_cfg.get("max_stddev_pct", 3)) def _runtime_env(self) -> dict: env = {**os.environ, "NCCL_DEBUG": "WARN"} lib_dirs = [] nccl_home = env.get("NCCL_HOME") or self.nccl_cfg.get("nccl_home") if nccl_home: lib_dirs.append(os.path.join(str(nccl_home), "lib")) for path in sys.path: lib_dirs.append(os.path.join(path, "nvidia", "nccl", "lib")) venv_root = os.path.dirname(os.path.dirname(sys.executable)) lib_dirs.extend(glob.glob(os.path.join(venv_root, "lib", "python*", "site-packages", "nvidia", "nccl", "lib"))) existing = env.get("LD_LIBRARY_PATH", "") valid_dirs = [] for d in lib_dirs: if d and os.path.isdir(d) and d not in valid_dirs: valid_dirs.append(d) if valid_dirs: env["LD_LIBRARY_PATH"] = ":".join(valid_dirs + ([existing] if existing else [])) return env def run(self) -> dict: gpu_count = 0 if TORCH_AVAILABLE: gpu_count = torch.cuda.device_count() if gpu_count < 2: self.console.print(f"[yellow]NCCL test requires at least 2 GPUs (found {gpu_count})[/yellow]") return {"error": "need_at_least_2_gpus", "gpu_count": gpu_count} tests = [] if self.nccl_cfg.get("test_allreduce", True): tests.append(("all_reduce_perf", "AllReduce")) if self.nccl_cfg.get("test_alltoall", True): tests.append(("alltoall_perf", "AllToAll")) if self.nccl_cfg.get("test_broadcast", True): tests.append(("broadcast_perf", "Broadcast")) if self.nccl_cfg.get("test_reduce_scatter", False): tests.append(("reduce_scatter_perf", "ReduceScatter")) if self.nccl_cfg.get("test_allgather", False): tests.append(("all_gather_perf", "AllGather")) if self.nccl_cfg.get("test_sendrecv", False): tests.append(("sendrecv_perf", "SendRecv")) nvlink_bw = self.specs.get("nvlink_bandwidth_gbps", 0) # User-provided override applies uniformly across all ops; otherwise # each op gets its own threshold from _OP_BW_FRACTIONS. user_override = self.nccl_cfg.get("min_bandwidth_gbps") def threshold_for(label: str) -> float: if user_override: return float(user_override) if nvlink_bw <= 0: return 10.0 # conservative floor frac = _OP_BW_FRACTIONS.get(label.lower(), 0.45) return round(nvlink_bw * frac) if self.gpu_type == "unknown": self.console.print("[yellow]Unknown GPU — using conservative bandwidth thresholds[/yellow]") # Strategy: try nccl-tests binary directly (single-node, -g N), # then mpirun, then torchrun fallback results = {} any_binary_worked = False with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), TimeElapsedColumn(), console=self.console, ) as progress: task = progress.add_task("NCCL tests...", total=len(tests)) for binary, label in tests: progress.update(task, description=f"NCCL {label}...") op_min_bw = threshold_for(label) result = self._run_one_nccl_test_direct( binary, label, gpu_count, op_min_bw ) if result.get("status") not in ("SKIP", None) and "error" not in result: any_binary_worked = True results[label.lower()] = result else: # Try mpirun fallback mpirun = self._find_mpirun() if mpirun: result = self._run_one_nccl_test_mpirun( binary, label, gpu_count, mpirun, op_min_bw ) if result.get("status") not in ("SKIP", None) and "error" not in result: any_binary_worked = True results[label.lower()] = result progress.advance(task) if not any_binary_worked: self.console.print("[yellow]nccl-tests binaries failed, falling back to torchrun[/yellow]") return self._run_torchrun_fallback(gpu_count) all_passed = all( r.get("status") == "PASS" for r in results.values() if isinstance(r, dict) and "status" in r ) return { "passed": all_passed, "source": "nccl-tests", "min_bandwidth_gbps": { lbl.lower(): threshold_for(lbl) for _, lbl in tests }, "tests": results, "gpu_count": gpu_count, "timestamp": datetime.now().isoformat(), "detected_gpu_type": self.gpu_type, } def _run_one_nccl_test_direct(self, binary_name: str, label: str, gpu_count: int, min_bw: float) -> dict: """Run nccl-tests binary directly with -g N (no mpirun needed for single-node).""" binary = self._find_nccl_test(binary_name) if not binary: return {"status": "SKIP", "error": f"{binary_name} not found"} return self._run_nccl_matrix([binary, "-g", str(gpu_count)], min_bw) def _run_one_nccl_test_mpirun(self, binary_name: str, label: str, gpu_count: int, mpirun: str, min_bw: float) -> dict: """Run nccl-tests via mpirun (multi-node or per-GPU-process mode).""" binary = self._find_nccl_test(binary_name) if not binary: return {"status": "SKIP", "error": f"{binary_name} not found"} cmd = [ mpirun, "-np", str(gpu_count), "--allow-run-as-root", "-x", "NCCL_DEBUG=WARN", "-x", "CUDA_VISIBLE_DEVICES=" + ",".join(str(i) for i in range(gpu_count)), binary, "-g", "1", ] return self._run_nccl_matrix(cmd, min_bw) def _run_nccl_matrix(self, base_cmd: list[str], min_bw: float) -> dict: size_results = [] failures = [] env = self._runtime_env() try: for size in self._message_sizes(): runs = [] for _ in range(self._repeats()): cmd = [*base_cmd, "-b", size, "-e", size, "-f", "2", "-w", "5", "-n", "20"] r = subprocess.run(cmd, capture_output=True, text=True, timeout=300, env=env) combined = r.stdout + r.stderr if "CUDA driver version is insufficient" in combined or "Test NCCL failure" in combined: failures.append({"size": size, "error": "NCCL/CUDA/library failure"}) continue if r.returncode != 0: failures.append({"size": size, "error": r.stderr[:300]}) continue parsed = self._parse_nccl_output(r.stdout, min_bw) runs.append(parsed.get("best_busbw_gbps", 0)) if runs: worst = min(runs) mean = sum(runs) / len(runs) std_pct = (statistics.pstdev(runs) / mean * 100) if len(runs) > 1 and mean else 0 size_results.append({ "size": size, "runs_busbw_gbps": [round(v, 1) for v in runs], "worst_busbw_gbps": round(worst, 1), "mean_busbw_gbps": round(mean, 1), "stddev_pct": round(std_pct, 2), "status": "PASS" if worst >= min_bw and std_pct <= self._max_stddev_pct() else "FAIL", }) else: size_results.append({"size": size, "status": "FAIL", "runs_busbw_gbps": []}) except subprocess.TimeoutExpired: return {"status": "FAIL", "error": "timeout"} except Exception as e: return {"status": "FAIL", "error": str(e)} best_bus = max((r.get("mean_busbw_gbps", 0) for r in size_results), default=0) worst_bus = min((r.get("worst_busbw_gbps", 0) for r in size_results if r.get("runs_busbw_gbps")), default=0) passed = bool(size_results) and all(r.get("status") == "PASS" for r in size_results) and not failures return { "status": "PASS" if passed else "FAIL", "best_busbw_gbps": round(best_bus, 1), "worst_busbw_gbps": round(worst_bus, 1), "min_required_gbps": min_bw, "max_stddev_pct": self._max_stddev_pct(), "by_size": size_results, "failures": failures, } @staticmethod def _parse_nccl_output(stdout: str, min_bw: float) -> dict: """Parse nccl-tests tabular output and extract bandwidth results.""" best_algbw = 0.0 best_busbw = 0.0 size_results = [] for line in stdout.split("\n"): line = line.strip() if not line or line.startswith("#"): continue parts = line.split() # nccl-tests data lines: size count type redop root time algbw busbw #wrong [time algbw busbw #wrong] if len(parts) >= 9: try: size = int(parts[0]) # parts[2] is dtype string ('float'/'int32'/etc.), not a number # out-of-place columns: time=parts[5], algbw=parts[6], busbw=parts[7] time_us = float(parts[5]) algbw = float(parts[6]) busbw = float(parts[7]) size_results.append({ "size": size, "time_us": time_us, "algbw_gbps": algbw, "busbw_gbps": busbw, }) if busbw > best_busbw: best_busbw = busbw if algbw > best_algbw: best_algbw = algbw except (ValueError, IndexError): continue status = "PASS" if best_busbw >= min_bw else "WARN" return { "status": status, "best_algbw_gbps": round(best_algbw, 1), "best_busbw_gbps": round(best_busbw, 1), "min_required_gbps": min_bw, "by_size": size_results[-5:] if size_results else [], } def _run_torchrun_fallback(self, gpu_count: int) -> dict: """Basic NCCL connectivity test via torchrun — verifies NCCL works but does not benchmark performance.""" self.console.print("[yellow]nccl-tests not available, running basic NCCL connectivity check[/yellow]") code = f""" import torch, torch.distributed as dist, os os.environ.setdefault("MASTER_ADDR","127.0.0.1") os.environ.setdefault("MASTER_PORT","29500") rank=int(os.environ.get("LOCAL_RANK",0)) ws={gpu_count} dist.init_process_group("nccl",rank=rank,world_size=ws) torch.cuda.set_device(rank) x=torch.randn(1024*1024,device=f"cuda:{{rank}}",dtype=torch.float32) # Test AllReduce try: dist.all_reduce(x.clone()) if rank==0: print("allreduce:ok") except Exception as e: if rank==0: print(f"allreduce:fail:{{e}}") # Test Broadcast try: dist.broadcast(x.clone(),src=0) if rank==0: print("broadcast:ok") except Exception as e: if rank==0: print(f"broadcast:fail:{{e}}") # Test AllGather try: tensor_list=[torch.empty_like(x) for _ in range(ws)] dist.all_gather(tensor_list,x.clone()) if rank==0: print("allgather:ok") except Exception as e: if rank==0: print(f"allgather:fail:{{e}}") # Test ReduceScatter try: chunks=list(x.chunk(ws)) output=torch.empty_like(chunks[0]) dist.reduce_scatter(output,chunks) if rank==0: print("reducescatter:ok") except Exception as e: if rank==0: print(f"reducescatter:fail:{{e}}") # Test AllToAll try: chunks=list(x.chunk(ws)) output_list=[torch.empty_like(c) for c in chunks] dist.all_to_all(output_list,chunks) if rank==0: print("alltoall:ok") except Exception as e: if rank==0: print(f"alltoall:fail:{{e}}") dist.destroy_process_group() """ import tempfile tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) tmp.write(code) tmp.close() try: # Prefer torchrun from the same venv as the running Python import sys venv_torchrun = os.path.join(os.path.dirname(sys.executable), "torchrun") torchrun_cmd = venv_torchrun if os.path.isfile(venv_torchrun) else "torchrun" r = subprocess.run( [torchrun_cmd, f"--nproc_per_node={gpu_count}", tmp.name], capture_output=True, text=True, timeout=120, env=self._runtime_env(), ) os.unlink(tmp.name) # Parse connectivity results — format: op_name:ok or op_name:fail:error tests = {} all_passed = True for line in r.stdout.split("\n"): line = line.strip() if not line: continue parts = line.split(":") op_name = parts[0] result = parts[1] if len(parts) > 1 else "unknown" if result == "ok": status = "PASS" else: status = "FAIL" all_passed = False tests[op_name] = { "status": status, "error": ":".join(parts[2:]) if len(parts) > 2 and result == "fail" else None, } return { # torchrun fallback is a functional smoke only. It never proves # production bus bandwidth, so it must not satisfy acceptance. "passed": False, "functional_passed": all_passed, "source": "torchrun_fallback", "tests": tests, "gpu_count": gpu_count, "error": None if all_passed else "torchrun functional NCCL smoke failed", "acceptance_gap": "nccl-tests bus bandwidth was not measured", } except Exception as e: return {"passed": False, "source": "torchrun_fallback", "error": str(e)} @staticmethod def print_results(results: dict, console: Console = None): c = console or Console() if "error" in results: c.print(f"[bold red]Error: {results['error']}[/bold red]") return passed = results.get("passed", False) source = results.get("source", "unknown") if source == "torchrun_fallback": # Connectivity check mode functional = results.get("functional_passed", passed) verdict = "[bold yellow]⚠ NCCL bus BW NOT VERIFIED[/bold yellow]" if functional else "[bold red]✗ NCCL Connectivity FAILED[/bold red]" c.print(f"{verdict} [dim](basic check via torchrun)[/dim]") tests = results.get("tests", {}) if tests: c.print("\n[dim]Operations tested:[/dim]") for op_name, result in tests.items(): if not isinstance(result, dict): continue status = result.get("status", "FAIL") s_color = "green" if status == "PASS" else "red" error = result.get("error") if error: c.print(f" [{s_color}]{op_name}[/{s_color}] — {error}") else: c.print(f" [{s_color}]{op_name}[/{s_color}]") c.print("\n[yellow]Note: functional connectivity test only (no bus bandwidth data; acceptance FAIL)[/yellow]") else: # nccl-tests mode verdict = "[bold green]✓ NCCL tests PASSED[/bold green]" if passed else "[bold yellow]⚠ NCCL tests WARNING[/bold yellow]" c.print(f"{verdict} [dim](via {source})[/dim]") tests = results.get("tests", {}) for op_name, result in tests.items(): if not isinstance(result, dict): continue c.print(f"\n[bold cyan]{op_name.upper()}[/bold cyan]") status = result.get("status", "FAIL") s_color = "green" if status == "PASS" else ("yellow" if status == "WARN" else "red") c.print(f" Status: [{s_color}]{status}[/{s_color}] " f"Best bus BW: {result.get('best_busbw_gbps', 'N/A')} GB/s " f"(min: {result.get('min_required_gbps', 'N/A')} GB/s)") by_size = result.get("by_size", []) if by_size: t = Table(box=None, padding=(0, 1)) t.add_column("Size", style="bold", justify="right") t.add_column("Worst Bus BW", justify="right") t.add_column("Mean Bus BW", justify="right") t.add_column("StdDev", justify="right") t.add_column("Status", justify="right") for r in by_size: t.add_row( str(r.get("size", "")), f"{r.get('worst_busbw_gbps', 0):.1f}", f"{r.get('mean_busbw_gbps', 0):.1f}", f"{r.get('stddev_pct', 0):.2f}%", r.get("status", "?"), ) c.print(t)