"""Multi-node NCCL benchmark wrapper for nccl-tests via mpirun."""

import os
import re
import shutil
import subprocess
from datetime import datetime
from typing import Optional

from rich.console import Console
from rich.table import Table

from modules.gpu_specs import resolve_tools_dir

# Accepted spellings of a test name in config -> nccl-tests binary name.
_TEST_ALIASES = {
    "allreduce": "all_reduce_perf",
    "all_reduce": "all_reduce_perf",
    "all_reduce_perf": "all_reduce_perf",
    "alltoall": "alltoall_perf",
    "all_to_all": "alltoall_perf",
    "alltoall_perf": "alltoall_perf",
}

# nccl-tests binary name -> short operation label used as the results key.
_OP_LABELS = {
    "all_reduce_perf": "allreduce",
    "alltoall_perf": "alltoall",
}

# Matches the summary line that carries the average bus bandwidth.
_AVG_BUSBW_RE = re.compile(r"Avg bus bandwidth\s*:\s*([0-9.]+)")


def _env_value(value) -> str:
    """Render a config value as an environment-variable string.

    Booleans become "1"/"0" instead of "True"/"False" so that flags written
    as YAML/JSON booleans reach the child processes as integers.
    """
    if isinstance(value, bool):
        return "1" if value else "0"
    return str(value)


class MultiNodeNCCLTest:
    """Run cross-node NCCL tests with a PDF-style message-size sweep."""

    def __init__(self, config: dict):
        """Store the full config and extract the ``multinode_nccl`` section."""
        self.config = config
        self.cfg = config.get("multinode_nccl", {}) or {}
        self.tools_dir = resolve_tools_dir(config)
        self.console = Console()

    def _find_mpirun(self) -> Optional[str]:
        """Return the path of an MPI launcher, or None when none is found.

        The configured ``mpirun_path`` wins if it is an executable file;
        otherwise ``mpirun``/``mpiexec`` on PATH and then the bundled
        ``<tools_dir>/mpi/bin/mpirun`` are tried in that order.
        """
        configured = self.cfg.get("mpirun_path")
        if configured:
            configured = str(configured)
            if os.path.isfile(configured) and os.access(configured, os.X_OK):
                return configured
        bundled = os.path.join(self.tools_dir, "mpi", "bin", "mpirun")
        for cmd in ("mpirun", "mpiexec", bundled):
            found = shutil.which(cmd)
            if found:
                return found
        return None

    def _find_nccl_test(self, binary_name: str) -> Optional[str]:
        """Return the path of an executable nccl-tests binary, or None.

        Search order: PATH, then the configured ``nccl_tests_dir``, then
        ``<tools_dir>/nccl-tests/build``.
        """
        configured = self.cfg.get("nccl_tests_dir")
        candidates = []
        if configured:
            candidates.append(os.path.join(configured, binary_name))
        candidates.append(
            os.path.join(self.tools_dir, "nccl-tests", "build", binary_name)
        )
        found = shutil.which(binary_name)
        if found:
            # NOTE(review): a binary on PATH takes precedence over the
            # configured directory; kept as-is, confirm this is intended.
            candidates.insert(0, found)
        for path in candidates:
            if path and os.path.isfile(path) and os.access(path, os.X_OK):
                return path
        return None

    def _tests(self) -> list[str]:
        """Return the de-duplicated list of nccl-tests binaries to run.

        Names are resolved through ``_TEST_ALIASES``; unrecognised names are
        dropped, so the result may be empty.
        """
        configured = self.cfg.get("tests") or ["all_reduce_perf", "alltoall_perf"]
        tests = []
        for name in configured:
            binary = _TEST_ALIASES.get(str(name).lower())
            if binary and binary not in tests:
                tests.append(binary)
        return tests

    def _hosts(self) -> list[dict]:
        """Normalize configured hosts to ``{"name", "addr", "slots"}`` dicts.

        A host may be a plain address string or a dict whose address is
        given under ``addr``, ``host`` or ``ip``. Entries without an address
        are discarded; entries of any other type are ignored.
        """
        hosts = self.cfg.get("hosts") or []
        normalized = []
        for host in hosts:
            if isinstance(host, str):
                normalized.append({"name": host, "addr": host, "slots": 8})
            elif isinstance(host, dict):
                addr = host.get("addr") or host.get("host") or host.get("ip")
                slots = host.get("slots")
                normalized.append({
                    # Fall back to the resolved address, whichever key it
                    # came from, so "name" is never None for a usable host.
                    "name": host.get("name") or addr,
                    "addr": addr,
                    "slots": 8 if slots is None else int(slots),
                })
        return [h for h in normalized if h.get("addr")]

    def _topologies(self) -> list[dict]:
        """Normalize configured topologies to nodes/gpus_per_node/label dicts.

        Defaults to a single 2-node x 8-GPU topology when none is configured.
        ``gpn`` is accepted as an alias of ``gpus_per_node``.
        """
        topologies = self.cfg.get("topologies") or [{"nodes": 2, "gpus_per_node": 8}]
        normalized = []
        for topo in topologies:
            nodes = int(topo.get("nodes", 2))
            gpus_per_node = int(topo.get("gpus_per_node", topo.get("gpn", 8)))
            normalized.append({
                "nodes": nodes,
                "gpus_per_node": gpus_per_node,
                "label": topo.get("label") or f"{nodes} nodes x {gpus_per_node} GPUs",
            })
        return normalized

    def _env_exports(self) -> list[tuple[str, str]]:
        """Return ``(name, value)`` pairs to forward to ranks via ``mpirun -x``.

        Settings whose config value is None are omitted. Boolean values are
        rendered as "1"/"0". ``LD_PRELOAD`` and ``LD_LIBRARY_PATH`` are added
        only when the corresponding config options are set.
        """
        env_cfg = {
            "NCCL_DEBUG": self.cfg.get("debug", "WARN"),
            "NCCL_SOCKET_IFNAME": self.cfg.get("socket_ifname"),
            "NCCL_IB_GID_INDEX": self.cfg.get("ib_gid_index"),
            "NCCL_IB_SL": self.cfg.get("ib_sl"),
            "NCCL_IB_TC": self.cfg.get("ib_tc"),
            "NCCL_IB_HCA": self.cfg.get("ib_hca"),
            "NCCL_IB_TIMEOUT": self.cfg.get("ib_timeout"),
            "NCCL_IB_QPS_PER_CONNECTION": self.cfg.get("qps_per_connection"),
            "NCCL_MIN_NCHANNELS": self.cfg.get("min_nchannels"),
            "NCCL_NET_PLUGIN": self.cfg.get("net_plugin"),
            "NCCL_NVLS_ENABLE": self.cfg.get("nvls_enable"),
            "NCCL_IB_SPLIT_DATA_ON_QPS": self.cfg.get("split_data_on_qps"),
        }
        mpi_ld_preload = self._mpi_ld_preload()
        if mpi_ld_preload:
            env_cfg["LD_PRELOAD"] = mpi_ld_preload
        extra_ld_library_path = self._extra_ld_library_path()
        if extra_ld_library_path:
            existing = os.environ.get("LD_LIBRARY_PATH", "")
            env_cfg["LD_LIBRARY_PATH"] = ":".join(
                [extra_ld_library_path] + ([existing] if existing else [])
            )
        return [(k, _env_value(v)) for k, v in env_cfg.items() if v is not None]

    def _mpi_ld_preload(self) -> str:
        """Return the configured ``mpi_ld_preload`` as one space-joined string."""
        preload = self.cfg.get("mpi_ld_preload")
        if isinstance(preload, list):
            return " ".join(str(p) for p in preload if p)
        return str(preload) if preload else ""

    def _runtime_env(self) -> dict:
        """Return a copy of ``os.environ`` for locally spawned processes.

        ``LD_PRELOAD`` is overridden and ``LD_LIBRARY_PATH`` is prefixed when
        the corresponding config options are set.
        """
        env = os.environ.copy()
        mpi_ld_preload = self._mpi_ld_preload()
        if mpi_ld_preload:
            env["LD_PRELOAD"] = mpi_ld_preload
        extra_ld_library_path = self._extra_ld_library_path()
        if extra_ld_library_path:
            existing = env.get("LD_LIBRARY_PATH", "")
            env["LD_LIBRARY_PATH"] = ":".join(
                [extra_ld_library_path] + ([existing] if existing else [])
            )
        return env

    def _extra_ld_library_path(self) -> str:
        """Return the configured ``extra_ld_library_path`` joined with ':'."""
        paths = self.cfg.get("extra_ld_library_path")
        if isinstance(paths, list):
            return ":".join(str(p) for p in paths if p)
        return str(paths) if paths else ""

    def _preflight(self, mpirun: Optional[str], tests: list[str], hosts: list[dict]) -> dict:
        """Check prerequisites before launching any benchmark.

        Returns ``{"checks": [...], "passed": bool}``. Each check is a dict
        with ``name``, ``status`` (PASS/FAIL/WARN) and ``detail``. SSH
        reachability checks are advisory: they can only produce PASS or WARN
        and are excluded from ``passed``.
        """
        checks = []
        checks.append({
            "name": "mpirun",
            "status": "PASS" if mpirun else "FAIL",
            "detail": mpirun or "not found",
        })
        checks.append({
            "name": "hosts",
            "status": "PASS" if len(hosts) >= 2 else "FAIL",
            "detail": f"{len(hosts)} configured",
        })
        if not tests:
            # Without this, an all-unrecognised "tests" list would run
            # nothing and still be reported as a pass.
            checks.append({
                "name": "tests",
                "status": "FAIL",
                "detail": "no recognized tests configured",
            })
        for binary in tests:
            path = self._find_nccl_test(binary)
            checks.append({
                "name": binary,
                "status": "PASS" if path else "FAIL",
                "detail": path or "not found",
            })

        if self.cfg.get("ssh_preflight", True):
            user = self.cfg.get("ssh_user", "root")
            env = self._runtime_env()
            for host in hosts:
                target = f"{user}@{host['addr']}"
                cmd = ["ssh", "-o", "BatchMode=yes", "-o", "ConnectTimeout=5", target, "hostname"]
                try:
                    r = subprocess.run(cmd, capture_output=True, text=True, timeout=8, env=env)
                    detail = r.stdout.strip() or r.stderr.strip()[:120]
                    checks.append({
                        "name": f"ssh {host['addr']}",
                        "status": "PASS" if r.returncode == 0 else "WARN",
                        "detail": detail,
                    })
                except Exception as e:
                    # Deliberately broad: the SSH probe is best-effort and
                    # must never abort the preflight.
                    checks.append({
                        "name": f"ssh {host['addr']}",
                        "status": "WARN",
                        "detail": str(e),
                    })

        return {
            "checks": checks,
            "passed": all(
                c["status"] == "PASS"
                for c in checks
                if not c["name"].startswith("ssh ")
            ),
        }

    def run(self) -> dict:
        """Run every configured test on every configured topology.

        Returns a result dict with ``passed``, ``source``, ``mode``,
        ``hosts``, ``preflight``, ``tests`` and ``timestamp``. When preflight
        fails, ``tests`` is empty and an ``error`` key is included. ``passed``
        is True only if at least one topology ran and all of them passed.
        """
        mpirun = self._find_mpirun()
        tests = self._tests()
        hosts = self._hosts()
        topologies = self._topologies()
        preflight = self._preflight(mpirun, tests, hosts)
        if not preflight["passed"]:
            return {
                "passed": False,
                "source": "nccl-tests-mpirun",
                "mode": self.cfg.get("mode", "sweep"),
                "hosts": hosts,
                "preflight": preflight,
                "tests": {},
                "error": "multinode NCCL preflight failed",
                "timestamp": datetime.now().isoformat(),
            }

        results = {}
        for binary in tests:
            label = _OP_LABELS[binary]
            binary_path = self._find_nccl_test(binary)
            op_results = [
                self._run_topology(mpirun, binary_path, label, hosts, topo)
                for topo in topologies
            ]
            results[label] = {"binary": binary_path, "topologies": op_results}

        topo_results = [
            topo
            for op in results.values()
            for topo in op.get("topologies", [])
        ]
        # all() of an empty sequence is True; require that something ran.
        passed = bool(topo_results) and all(
            topo.get("status") == "PASS" for topo in topo_results
        )
        return {
            "passed": passed,
            "source": "nccl-tests-mpirun",
            "mode": self.cfg.get("mode", "sweep"),
            "hosts": hosts,
            "preflight": preflight,
            "tests": results,
            "timestamp": datetime.now().isoformat(),
        }

    def _build_command(self, mpirun: str, binary: str, host_arg: str,
                       gpus_per_node: int, ranks: int) -> list[str]:
        """Assemble the mpirun argv for one nccl-tests invocation."""
        cmd = [
            mpirun,
            "--allow-run-as-root",
            "--mca", "btl_openib_warn_no_device_params_found", "0",
            # "or" rather than a .get() default: a key that is present but
            # null must not be passed to mpirun as the string "None".
            "--mca", "btl_tcp_if_include", str(self.cfg.get("socket_ifname") or "bond0"),
            "-H", host_arg,
            "--map-by", f"ppr:{gpus_per_node}:node",
            "-np", str(ranks),
        ]
        for key, value in self._env_exports():
            cmd.extend(["-x", f"{key}={value}"])
        cmd.extend([
            binary,
            "-b", str(self.cfg.get("begin_size", "1k")),
            "-e", str(self.cfg.get("end_size", "16g")),
            "-g", str(self.cfg.get("gpus_per_rank", 1)),
            "-f", str(self.cfg.get("step_factor", 2)),
            "-w", str(self.cfg.get("warmup_iters", 10)),
        ])
        if self.cfg.get("iters") is not None:
            cmd.extend(["-n", str(self.cfg["iters"])])
        return cmd

    def _run_topology(self, mpirun: str, binary: str, label: str, hosts: list[dict], topo: dict) -> dict:
        """Run one nccl-tests binary on one topology and summarize the result.

        Args:
            mpirun: Path of the MPI launcher.
            binary: Path of the nccl-tests executable.
            label: Operation label, used to look up the bandwidth threshold.
            hosts: Normalized host list; the first ``topo["nodes"]`` are used.
            topo: Normalized topology with ``nodes``, ``gpus_per_node``, ``label``.

        Returns:
            A result dict whose ``status`` is "PASS" only when mpirun exits
            with 0, some bus bandwidth was measured, no wrong results were
            reported and the peak bus bandwidth meets the threshold. Launch
            failures, timeouts and a shortage of hosts yield ``status``
            "FAIL" together with an ``error`` message.
        """
        nodes = topo["nodes"]
        gpus_per_node = topo["gpus_per_node"]
        selected_hosts = hosts[:nodes]
        host_arg = ",".join(f"{h['addr']}:{gpus_per_node}" for h in selected_hosts)
        ranks = nodes * gpus_per_node
        cmd = self._build_command(mpirun, binary, host_arg, gpus_per_node, ranks)

        # Fields shared by every outcome, in the order they are reported.
        base = {
            "label": topo["label"],
            "nodes": nodes,
            "gpus_per_node": gpus_per_node,
            "ranks": ranks,
            "hosts": selected_hosts,
            "command": " ".join(cmd),
        }
        started = datetime.now().isoformat()

        if len(selected_hosts) < nodes:
            # Launching anyway would request nodes*gpus_per_node ranks on
            # fewer hosts than the topology describes.
            return {
                **base,
                "status": "FAIL",
                "error": f"topology needs {nodes} hosts but only {len(hosts)} configured",
                "started_at": started,
            }

        timeout = int(self.cfg.get("timeout_sec", 1800))
        try:
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout,
                               env=self._runtime_env())
        except subprocess.TimeoutExpired:
            return {
                **base,
                "status": "FAIL",
                "error": f"timeout after {timeout}s",
                "started_at": started,
            }
        except OSError as e:
            # E.g. the launcher disappeared or is not executable; report it
            # as a failed topology rather than aborting the whole run.
            return {
                **base,
                "status": "FAIL",
                "error": f"failed to launch mpirun: {e}",
                "started_at": started,
            }

        parsed = self._parse_nccl_output(r.stdout)
        threshold = self._threshold_for(label)
        wrong = sum(row.get("wrong", 0) for row in parsed["by_size"])
        has_bw = parsed["peak_busbw_gbps"] > 0
        ok = (
            r.returncode == 0
            and has_bw
            and wrong == 0
            and parsed["peak_busbw_gbps"] >= threshold
        )
        return {
            **base,
            "returncode": r.returncode,
            "status": "PASS" if ok else "FAIL",
            "peak_busbw_gbps": parsed["peak_busbw_gbps"],
            "peak_algbw_gbps": parsed["peak_algbw_gbps"],
            "peak_size": parsed["peak_size"],
            "avg_busbw_gbps": parsed["avg_busbw_gbps"],
            "min_required_gbps": threshold,
            "wrong_count": wrong,
            "by_size": parsed["by_size"],
            "stderr_tail": r.stderr[-1200:],
            "stdout_tail": r.stdout[-1200:],
            "started_at": started,
            "finished_at": datetime.now().isoformat(),
        }

    def _threshold_for(self, label: str) -> float:
        """Return the minimum peak bus bandwidth required for ``label``.

        ``min_peak_busbw_gbps`` may be a per-label dict or a single number;
        a missing or falsy value means no threshold (0.0).
        """
        thresholds = self.cfg.get("min_peak_busbw_gbps") or {}
        if isinstance(thresholds, dict):
            return float(thresholds.get(label, 0) or 0)
        return float(thresholds or 0)

    @staticmethod
    def _parse_nccl_output(stdout: str) -> dict:
        """Extract per-size rows and peak/average bandwidth from test output.

        Data rows are whitespace-split lines with at least nine fields where
        field 0 is the size in bytes and fields 5-8 are time, algbw, busbw
        and the wrong-result count (assumed to be the out-of-place columns
        of the standard nccl-tests table; later fields are ignored). A
        wrong-result field of "N/A" is counted as 0. Lines that do not fit
        this shape are skipped.

        Returns:
            Dict with ``peak_busbw_gbps``, ``peak_algbw_gbps``, ``peak_size``,
            ``avg_busbw_gbps`` and ``by_size`` (the list of parsed rows).
        """
        rows = []
        avg_bus = 0.0
        for line in stdout.splitlines():
            stripped = line.strip()
            if not stripped:
                continue
            # Checked before the '#' filter because the summary line may
            # itself be a '#'-prefixed comment line.
            avg_match = _AVG_BUSBW_RE.search(stripped)
            if avg_match:
                avg_bus = float(avg_match.group(1))
                continue
            if stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) < 9:
                continue
            try:
                size_bytes = int(parts[0])
                time_us = float(parts[5])
                algbw = float(parts[6])
                busbw = float(parts[7])
                # "N/A" appears when no correctness result is available;
                # keep the bandwidth figures rather than dropping the row.
                wrong = 0 if parts[8].upper() == "N/A" else int(parts[8])
            except (ValueError, IndexError):
                continue
            rows.append({
                "size_bytes": size_bytes,
                "size": _format_size(size_bytes),
                "time_us": time_us,
                "algbw_gbps": algbw,
                "busbw_gbps": busbw,
                "wrong": wrong,
            })

        peak_row = max(rows, key=lambda r: r["busbw_gbps"], default={})
        return {
            "peak_busbw_gbps": round(float(peak_row.get("busbw_gbps", 0)), 2),
            "peak_algbw_gbps": round(float(peak_row.get("algbw_gbps", 0)), 2),
            "peak_size": peak_row.get("size", ""),
            "avg_busbw_gbps": round(avg_bus, 2),
            "by_size": rows,
        }

    @staticmethod
    def print_results(results: dict, console: Optional[Console] = None):
        """Render a result dict from ``run`` as rich tables.

        Prints a headline, then the preflight checks (if any), then one
        table per operation with a row per topology. A new ``Console`` is
        created when none is supplied.
        """
        c = console or Console()
        if results.get("error"):
            c.print(f"[bold red]Multi-node NCCL failed: {results['error']}[/bold red]")
        else:
            c.print(
                "[bold green]Multi-node NCCL complete[/bold green]"
                if results.get("passed")
                else "[bold red]Multi-node NCCL failed[/bold red]"
            )

        preflight = results.get("preflight", {})
        if preflight.get("checks"):
            table = Table(title="Preflight")
            table.add_column("Check")
            table.add_column("Status")
            table.add_column("Detail")
            for check in preflight["checks"]:
                table.add_row(check["name"], check["status"], str(check.get("detail", "")))
            c.print(table)

        for op, data in (results.get("tests") or {}).items():
            table = Table(title=f"Multi-node NCCL {op}")
            table.add_column("Topology")
            table.add_column("Peak Bus BW")
            table.add_column("Peak Size")
            table.add_column("Threshold")
            table.add_column("Status")
            for topo in data.get("topologies", []):
                table.add_row(
                    topo.get("label", ""),
                    f"{topo.get('peak_busbw_gbps', 0):.2f} GB/s",
                    str(topo.get("peak_size", "")),
                    f">= {topo.get('min_required_gbps', 0):.0f} GB/s"
                    if topo.get("min_required_gbps") else "-",
                    topo.get("status", "?"),
                )
            c.print(table)


def _format_size(size_bytes: int) -> str:
    """Format a byte count as a whole number of G/M/K when it divides evenly.

    Uses binary multiples (1024-based); any other value is returned as the
    plain decimal byte count.
    """
    units = [("G", 1024 ** 3), ("M", 1024 ** 2), ("K", 1024)]
    for suffix, factor in units:
        if size_bytes >= factor and size_bytes % factor == 0:
            return f"{size_bytes // factor}{suffix}"
    return str(size_bytes)