"""Multi-node NCCL benchmark wrapper for nccl-tests via mpirun.""" import json import os import re import shutil import subprocess from datetime import datetime from typing import Optional from rich.console import Console from rich.table import Table from modules.gpu_specs import resolve_tools_dir _TEST_ALIASES = { "allreduce": "all_reduce_perf", "all_reduce": "all_reduce_perf", "all_reduce_perf": "all_reduce_perf", "allgather": "all_gather_perf", "all_gather": "all_gather_perf", "all_gather_perf": "all_gather_perf", "alltoall": "alltoall_perf", "all_to_all": "alltoall_perf", "alltoall_perf": "alltoall_perf", "broadcast": "broadcast_perf", "broadcast_perf": "broadcast_perf", "reducescatter": "reduce_scatter_perf", "reduce_scatter": "reduce_scatter_perf", "reduce_scatter_perf": "reduce_scatter_perf", "sendrecv": "sendrecv_perf", "send_recv": "sendrecv_perf", "sendrecv_perf": "sendrecv_perf", } _OP_LABELS = { "all_reduce_perf": "allreduce", "all_gather_perf": "allgather", "alltoall_perf": "alltoall", "broadcast_perf": "broadcast", "reduce_scatter_perf": "reducescatter", "sendrecv_perf": "sendrecv", } class MultiNodeNCCLTest: """Run cross-node NCCL tests with a PDF-style message-size sweep.""" def __init__(self, config: dict): self.config = config self.cfg = config.get("multinode_nccl", {}) or {} self.tools_dir = resolve_tools_dir(config) self.console = Console() self.artifact_dir = os.environ.get("MULTINODE_NCCL_ARTIFACT_DIR") or self.cfg.get("artifact_dir") def _find_mpirun(self) -> Optional[str]: configured = self.cfg.get("mpirun_path") if configured and os.path.isfile(str(configured)) and os.access(str(configured), os.X_OK): return str(configured) for cmd in ["mpirun", "mpiexec", os.path.join(self.tools_dir, "mpi", "bin", "mpirun")]: found = shutil.which(cmd) if found: return found return None def _find_nccl_test(self, binary_name: str) -> Optional[str]: configured = self.cfg.get("nccl_tests_dir") candidates = [] if configured: candidates.append(os.path.join(configured, binary_name)) candidates.append(os.path.join(self.tools_dir, "nccl-tests", "build", binary_name)) found = shutil.which(binary_name) if found: candidates.insert(0, found) for path in candidates: if path and os.path.isfile(path) and os.access(path, os.X_OK): return path return None def _tests(self) -> list[str]: configured = self.cfg.get("tests") or ["all_reduce_perf", "alltoall_perf"] tests = [] for name in configured: binary = _TEST_ALIASES.get(str(name).lower()) if binary and binary not in tests: tests.append(binary) return tests def _hosts(self) -> list[dict]: hosts = self.cfg.get("hosts") or [] normalized = [] for host in hosts: if isinstance(host, str): normalized.append({"addr": host, "slots": 8}) elif isinstance(host, dict): normalized.append({ "name": host.get("name") or host.get("addr"), "addr": host.get("addr") or host.get("host") or host.get("ip"), "slots": int(host.get("slots", 8)), }) return [h for h in normalized if h.get("addr")] def _topologies(self) -> list[dict]: topologies = self.cfg.get("topologies") or [{"nodes": 2, "gpus_per_node": 8}] normalized = [] for topo in topologies: nodes = int(topo.get("nodes", 2)) gpus_per_node = int(topo.get("gpus_per_node", topo.get("gpn", 8))) normalized.append({ "nodes": nodes, "gpus_per_node": gpus_per_node, "label": topo.get("label") or f"{nodes} nodes x {gpus_per_node} GPUs", "cuda_visible_devices": topo.get("cuda_visible_devices"), "env": topo.get("env") or {}, "op_env": topo.get("op_env") or topo.get("test_env") or {}, "min_peak_busbw_gbps": topo.get("min_peak_busbw_gbps"), }) return normalized def _env_exports(self, topo: dict = None, label: str = None, binary: str = None) -> list[tuple[str, str]]: env_cfg = { "NCCL_DEBUG": self.cfg.get("debug", "WARN"), "NCCL_SOCKET_IFNAME": self.cfg.get("socket_ifname"), "NCCL_IB_GID_INDEX": self.cfg.get("ib_gid_index"), "NCCL_IB_SL": self.cfg.get("ib_sl"), "NCCL_IB_TC": self.cfg.get("ib_tc"), "NCCL_IB_HCA": self.cfg.get("ib_hca"), "NCCL_IB_TIMEOUT": self.cfg.get("ib_timeout"), "NCCL_IB_QPS_PER_CONNECTION": self.cfg.get("qps_per_connection"), "NCCL_MIN_NCHANNELS": self.cfg.get("min_nchannels"), "NCCL_NET_PLUGIN": self.cfg.get("net_plugin"), "NCCL_NVLS_ENABLE": self.cfg.get("nvls_enable"), "NCCL_IB_SPLIT_DATA_ON_QPS": self.cfg.get("split_data_on_qps"), } mpi_ld_preload = self._mpi_ld_preload() if mpi_ld_preload: env_cfg["LD_PRELOAD"] = mpi_ld_preload extra_ld_library_path = self._extra_ld_library_path() if extra_ld_library_path: existing = os.environ.get("LD_LIBRARY_PATH", "") env_cfg["LD_LIBRARY_PATH"] = ":".join( [extra_ld_library_path] + ([existing] if existing else []) ) extra_env = self.cfg.get("extra_env") or {} if isinstance(extra_env, dict): self._merge_env(env_cfg, extra_env) if topo: if topo.get("cuda_visible_devices"): env_cfg["CUDA_VISIBLE_DEVICES"] = str(topo["cuda_visible_devices"]) if isinstance(topo.get("env"), dict): self._merge_env(env_cfg, topo["env"]) op_env = topo.get("op_env") if isinstance(op_env, dict): for key in (label, binary): overrides = op_env.get(key) if isinstance(overrides, dict): self._merge_env(env_cfg, overrides) return [(k, str(v)) for k, v in env_cfg.items() if v is not None] @staticmethod def _merge_env(env_cfg: dict, overrides: dict): for key, value in overrides.items(): key = str(key) if value is None: env_cfg.pop(key, None) else: env_cfg[key] = str(value) def _mpi_ld_preload(self) -> str: preload = self.cfg.get("mpi_ld_preload") if isinstance(preload, list): return " ".join(str(p) for p in preload if p) return str(preload) if preload else "" def _runtime_env(self) -> dict: env = os.environ.copy() mpi_ld_preload = self._mpi_ld_preload() if mpi_ld_preload: env["LD_PRELOAD"] = mpi_ld_preload extra_ld_library_path = self._extra_ld_library_path() if extra_ld_library_path: existing = env.get("LD_LIBRARY_PATH", "") env["LD_LIBRARY_PATH"] = ":".join( [extra_ld_library_path] + ([existing] if existing else []) ) return env def _extra_ld_library_path(self) -> str: paths = self.cfg.get("extra_ld_library_path") if isinstance(paths, list): return ":".join(str(p) for p in paths if p) return str(paths) if paths else "" def _preflight(self, mpirun: Optional[str], tests: list[str], hosts: list[dict]) -> dict: checks = [] checks.append({"name": "mpirun", "status": "PASS" if mpirun else "FAIL", "detail": mpirun or "not found"}) checks.append({"name": "hosts", "status": "PASS" if len(hosts) >= 2 else "FAIL", "detail": f"{len(hosts)} configured"}) for binary in tests: path = self._find_nccl_test(binary) checks.append({"name": binary, "status": "PASS" if path else "FAIL", "detail": path or "not found"}) if self.cfg.get("ssh_preflight", True): user = self.cfg.get("ssh_user", "root") for host in hosts: target = f"{user}@{host['addr']}" cmd = [ "ssh", "-o", "BatchMode=yes", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=accept-new", target, "hostname", ] try: r = subprocess.run(cmd, capture_output=True, text=True, timeout=8, env=self._runtime_env()) detail = r.stdout.strip() or r.stderr.strip()[:120] checks.append({ "name": f"ssh {host['addr']}", "status": "PASS" if r.returncode == 0 else "WARN", "detail": detail, }) except Exception as e: checks.append({"name": f"ssh {host['addr']}", "status": "WARN", "detail": str(e)}) return { "checks": checks, "passed": all(c["status"] == "PASS" for c in checks if not c["name"].startswith("ssh ")), } def run(self) -> dict: mpirun = self._find_mpirun() tests = self._tests() hosts = self._hosts() topologies = self._topologies() preflight = self._preflight(mpirun, tests, hosts) if not preflight["passed"]: return { "passed": False, "source": "nccl-tests-mpirun", "mode": self.cfg.get("mode", "sweep"), "hosts": hosts, "preflight": preflight, "tests": {}, "error": "multinode NCCL preflight failed", "timestamp": datetime.now().isoformat(), } results = {} for binary in tests: label = _OP_LABELS[binary] binary_path = self._find_nccl_test(binary) op_results = [] for topo in topologies: op_results.append(self._run_topology(mpirun, binary_path, label, hosts, topo)) results[label] = {"binary": binary_path, "topologies": op_results} passed = all( topo.get("status") == "PASS" for op in results.values() for topo in op.get("topologies", []) ) return { "passed": passed, "source": "nccl-tests-mpirun", "mode": self.cfg.get("mode", "sweep"), "hosts": hosts, "preflight": preflight, "tests": results, "artifact_dir": self.artifact_dir, "timestamp": datetime.now().isoformat(), } def _run_topology(self, mpirun: str, binary: str, label: str, hosts: list[dict], topo: dict) -> dict: nodes = topo["nodes"] gpus_per_node = topo["gpus_per_node"] selected_hosts = hosts[:nodes] host_arg = ",".join(f"{h['addr']}:{gpus_per_node}" for h in selected_hosts) ranks = nodes * gpus_per_node cmd = [ mpirun, "--allow-run-as-root", "--mca", "btl_openib_warn_no_device_params_found", "0", "--mca", "btl_tcp_if_include", str(self.cfg.get("socket_ifname", "bond0")), "--mca", "oob_tcp_if_include", str(self.cfg.get("oob_tcp_ifname", self.cfg.get("socket_ifname", "bond0"))), "-H", host_arg, "--map-by", f"ppr:{gpus_per_node}:node", "-np", str(ranks), ] plm_rsh_args = self.cfg.get("plm_rsh_args") if plm_rsh_args: cmd.extend(["--mca", "plm_rsh_args", str(plm_rsh_args)]) for key, value in self._env_exports(topo=topo, label=label, binary=os.path.basename(binary)): cmd.extend(["-x", f"{key}={value}"]) cmd.extend([ binary, "-b", str(self.cfg.get("begin_size", "1k")), "-e", str(self.cfg.get("end_size", "16g")), "-g", str(self.cfg.get("gpus_per_rank", 1)), "-f", str(self.cfg.get("step_factor", 2)), "-w", str(self.cfg.get("warmup_iters", 10)), ]) if self.cfg.get("iters") is not None: cmd.extend(["-n", str(self.cfg["iters"])]) timeout = int(self.cfg.get("timeout_sec", 1800)) started = datetime.now().isoformat() try: r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, env=self._runtime_env()) except subprocess.TimeoutExpired: result = { "label": topo["label"], "nodes": nodes, "gpus_per_node": gpus_per_node, "ranks": ranks, "hosts": selected_hosts, "command": " ".join(cmd), "status": "FAIL", "error": f"timeout after {timeout}s", "started_at": started, } self._write_artifacts(label, topo, result, "", "") return result parsed = self._parse_nccl_output(r.stdout) net_diag = self._parse_network_diagnostics(r.stdout + "\n" + r.stderr) threshold = self._threshold_for(label, topo) wrong = sum(row.get("wrong", 0) for row in parsed["by_size"]) has_bw = parsed["peak_busbw_gbps"] > 0 status = "PASS" if r.returncode == 0 and has_bw and wrong == 0 and parsed["peak_busbw_gbps"] >= threshold else "FAIL" result = { "label": topo["label"], "nodes": nodes, "gpus_per_node": gpus_per_node, "ranks": ranks, "hosts": selected_hosts, "cuda_visible_devices": topo.get("cuda_visible_devices"), "command": " ".join(cmd), "returncode": r.returncode, "status": status, "peak_busbw_gbps": parsed["peak_busbw_gbps"], "peak_algbw_gbps": parsed["peak_algbw_gbps"], "peak_size": parsed["peak_size"], "avg_busbw_gbps": parsed["avg_busbw_gbps"], "min_required_gbps": threshold, "wrong_count": wrong, "network": net_diag, "by_size": parsed["by_size"], "stderr_tail": r.stderr[-1200:], "stdout_tail": r.stdout[-1200:], "started_at": started, "finished_at": datetime.now().isoformat(), } self._write_artifacts(label, topo, result, r.stdout, r.stderr) return result def _write_artifacts(self, label: str, topo: dict, result: dict, stdout: str, stderr: str): if not self.artifact_dir: return os.makedirs(self.artifact_dir, exist_ok=True) prefix = _safe_name(f"{label}_{topo.get('nodes')}x{topo.get('gpus_per_node')}_{topo.get('label')}") base = os.path.join(self.artifact_dir, prefix) with open(base + ".cmd.txt", "w") as f: f.write(result.get("command", "")) f.write("\n") with open(base + ".stdout.txt", "w") as f: f.write(stdout) with open(base + ".stderr.txt", "w") as f: f.write(stderr) artifact_result = {k: v for k, v in result.items() if k not in ("stdout_tail", "stderr_tail")} with open(base + ".json", "w") as f: json.dump(artifact_result, f, indent=2, default=str) result["artifact_prefix"] = base def _threshold_for(self, label: str, topo: dict = None) -> float: if topo and topo.get("min_peak_busbw_gbps") is not None: topo_thresholds = topo.get("min_peak_busbw_gbps") if isinstance(topo_thresholds, dict): return float(topo_thresholds.get(label, 0) or 0) return float(topo_thresholds or 0) thresholds = self.cfg.get("min_peak_busbw_gbps") or {} if isinstance(thresholds, dict): op_threshold = thresholds.get(label, 0) if isinstance(op_threshold, dict): keys = [] if topo: keys.extend([ topo.get("label"), f"{topo.get('nodes')}x{topo.get('gpus_per_node')}", f"{topo.get('nodes')} nodes x {topo.get('gpus_per_node')} GPUs", str(topo.get("gpus_per_node")), ]) keys.append("default") for key in keys: if key in op_threshold: return float(op_threshold.get(key) or 0) return 0.0 return float(op_threshold or 0) return float(thresholds or 0) @staticmethod def _parse_nccl_output(stdout: str) -> dict: rows = [] avg_bus = 0.0 for line in stdout.splitlines(): stripped = line.strip() if not stripped: continue avg_match = re.search(r"Avg bus bandwidth\s*:\s*([0-9.]+)", stripped) if avg_match: avg_bus = float(avg_match.group(1)) continue if stripped.startswith("#"): continue parts = stripped.split() if len(parts) < 9: continue try: size_bytes = int(parts[0]) time_us = float(parts[5]) algbw = float(parts[6]) busbw = float(parts[7]) wrong = int(parts[8]) except (ValueError, IndexError): continue rows.append({ "size_bytes": size_bytes, "size": _format_size(size_bytes), "time_us": time_us, "algbw_gbps": algbw, "busbw_gbps": busbw, "wrong": wrong, }) peak_row = max(rows, key=lambda r: r["busbw_gbps"], default={}) return { "peak_busbw_gbps": round(float(peak_row.get("busbw_gbps", 0)), 2), "peak_algbw_gbps": round(float(peak_row.get("algbw_gbps", 0)), 2), "peak_size": peak_row.get("size", ""), "avg_busbw_gbps": round(avg_bus, 2), "by_size": rows, } @staticmethod def _parse_network_diagnostics(output: str) -> dict: networks = sorted(set(re.findall(r"Using network (\S+)", output))) gdr_enabled = sorted(set(re.findall(r"GPU Direct RDMA Enabled for HCA \d+ '([^']+)'", output))) gdr_disabled = sorted(set(re.findall(r"GPU Direct RDMA Disabled for HCA \d+ '([^']+)'", output))) ib_using = [] for line in output.splitlines(): if "NET/IB : Using" in line: text = line.split("NET/IB : ", 1)[-1].strip() if text not in ib_using: ib_using.append(text) if gdr_disabled: gdr_state = "DISABLED" elif gdr_enabled or "/GDRDMA" in output: gdr_state = "ENABLED" elif networks: gdr_state = "NOT_DISABLED_IN_LOG" else: gdr_state = "UNKNOWN" return { "networks": networks, "ib_using": ib_using[:8], "gdr_enabled_hcas": gdr_enabled, "gdr_disabled_hcas": gdr_disabled, "gpu_direct_rdma": gdr_state, } @staticmethod def print_results(results: dict, console: Console = None): c = console or Console() if results.get("error"): c.print(f"[bold red]Multi-node NCCL failed: {results['error']}[/bold red]") else: c.print("[bold green]Multi-node NCCL complete[/bold green]" if results.get("passed") else "[bold red]Multi-node NCCL failed[/bold red]") preflight = results.get("preflight", {}) if preflight.get("checks"): table = Table(title="Preflight") table.add_column("Check") table.add_column("Status") table.add_column("Detail") for check in preflight["checks"]: table.add_row(check["name"], check["status"], str(check.get("detail", ""))) c.print(table) for op, data in (results.get("tests") or {}).items(): table = Table(title=f"Multi-node NCCL {op}") table.add_column("Topology") table.add_column("Peak Bus BW") table.add_column("Peak Size") table.add_column("Threshold") table.add_column("Status") for topo in data.get("topologies", []): table.add_row( topo.get("label", ""), f"{topo.get('peak_busbw_gbps', 0):.2f} GB/s", str(topo.get("peak_size", "")), f">= {_format_gbps(topo.get('min_required_gbps', 0))} GB/s" if topo.get("min_required_gbps") else "-", topo.get("status", "?"), ) c.print(table) def _format_size(size_bytes: int) -> str: units = [("G", 1024 ** 3), ("M", 1024 ** 2), ("K", 1024)] for suffix, factor in units: if size_bytes >= factor and size_bytes % factor == 0: return f"{size_bytes // factor}{suffix}" return str(size_bytes) def _format_gbps(value) -> str: try: numeric = float(value) except (TypeError, ValueError): return str(value) if numeric.is_integer(): return f"{numeric:.0f}" return f"{numeric:.2f}" def _safe_name(value: str) -> str: text = re.sub(r"[^A-Za-z0-9_.-]+", "_", value.strip()) text = re.sub(r"_+", "_", text).strip("_") return text[:160] or "case"