# test_gpu_scripts/modules/multinode_nccl_test.py
#
# 488 lines
# 19 KiB
# Python

"""Multi-node NCCL benchmark wrapper for nccl-tests via mpirun."""
import os
import re
import shutil
import subprocess
from datetime import datetime
from typing import Optional
from rich.console import Console
from rich.table import Table
from modules.gpu_specs import resolve_tools_dir
_TEST_ALIASES = {
"allreduce": "all_reduce_perf",
"all_reduce": "all_reduce_perf",
"all_reduce_perf": "all_reduce_perf",
"alltoall": "alltoall_perf",
"all_to_all": "alltoall_perf",
"alltoall_perf": "alltoall_perf",
}
_OP_LABELS = {
"all_reduce_perf": "allreduce",
"alltoall_perf": "alltoall",
}
class MultiNodeNCCLTest:
    """Run cross-node NCCL tests with a PDF-style message-size sweep.

    Settings are read from the ``multinode_nccl`` section of the config dict.
    ``run()`` launches every configured nccl-tests binary through ``mpirun``
    for every configured topology and returns a result dict;
    ``print_results()`` renders such a dict as rich tables.
    """

    def __init__(self, config: dict):
        """Keep the config, its ``multinode_nccl`` section and the tools dir.

        Args:
            config: Full configuration dict. A missing or null
                ``multinode_nccl`` section is treated as empty.
        """
        self.config = config
        self.cfg = config.get("multinode_nccl", {}) or {}
        self.tools_dir = resolve_tools_dir(config)
        self.console = Console()

    def _find_mpirun(self) -> Optional[str]:
        """Return the path of an MPI launcher, or None when none is found.

        ``mpirun_path`` from the config wins when it points at an executable
        file; otherwise ``mpirun``/``mpiexec`` on PATH and finally the copy
        under ``<tools_dir>/mpi/bin`` are tried.
        """
        configured = self.cfg.get("mpirun_path")
        if configured and os.path.isfile(str(configured)) and os.access(str(configured), os.X_OK):
            return str(configured)
        bundled = os.path.join(self.tools_dir, "mpi", "bin", "mpirun")
        for candidate in ("mpirun", "mpiexec", bundled):
            located = shutil.which(candidate)
            if located:
                return located
        return None

    def _find_nccl_test(self, binary_name: str) -> Optional[str]:
        """Return the path of an nccl-tests binary, or None when not found.

        Search order: a hit on PATH first, then ``nccl_tests_dir`` from the
        config, then ``<tools_dir>/nccl-tests/build``.
        """
        candidates = []
        on_path = shutil.which(binary_name)
        if on_path:
            candidates.append(on_path)
        configured_dir = self.cfg.get("nccl_tests_dir")
        if configured_dir:
            candidates.append(os.path.join(configured_dir, binary_name))
        candidates.append(os.path.join(self.tools_dir, "nccl-tests", "build", binary_name))
        for path in candidates:
            if path and os.path.isfile(path) and os.access(path, os.X_OK):
                return path
        return None

    def _tests(self) -> list[str]:
        """Return the de-duplicated nccl-tests binary names to run.

        Configured names are resolved case-insensitively through
        ``_TEST_ALIASES``; names that are not in the alias table are dropped.
        """
        configured = self.cfg.get("tests") or ["all_reduce_perf", "alltoall_perf"]
        tests = []
        for name in configured:
            binary = _TEST_ALIASES.get(str(name).lower())
            if binary and binary not in tests:
                tests.append(binary)
        return tests

    def _hosts(self) -> list[dict]:
        """Return configured hosts as ``{"name", "addr", "slots"}`` dicts.

        A host may be given as a plain address string or as a dict whose
        address is under ``addr``, ``host`` or ``ip``. Entries without an
        address, and entries of any other type, are dropped. ``slots``
        defaults to 8 when absent or null.
        """
        normalized = []
        for host in self.cfg.get("hosts") or []:
            if isinstance(host, str):
                # Same shape as the dict form so consumers can rely on "name".
                normalized.append({"name": host, "addr": host, "slots": 8})
            elif isinstance(host, dict):
                addr = host.get("addr") or host.get("host") or host.get("ip")
                slots = host.get("slots")
                normalized.append({
                    "name": host.get("name") or addr,
                    "addr": addr,
                    "slots": int(slots) if slots is not None else 8,
                })
        return [h for h in normalized if h.get("addr")]

    def _topologies(self) -> list[dict]:
        """Return the configured topologies with defaults filled in.

        Defaults to a single ``2 nodes x 8 GPUs`` topology. ``gpn`` is
        accepted as an alias of ``gpus_per_node`` and ``test_env`` as an
        alias of ``op_env``.
        """
        topologies = self.cfg.get("topologies") or [{"nodes": 2, "gpus_per_node": 8}]
        normalized = []
        for topo in topologies:
            nodes = int(topo.get("nodes", 2))
            gpus_per_node = int(topo.get("gpus_per_node", topo.get("gpn", 8)))
            normalized.append({
                "nodes": nodes,
                "gpus_per_node": gpus_per_node,
                "label": topo.get("label") or f"{nodes} nodes x {gpus_per_node} GPUs",
                "cuda_visible_devices": topo.get("cuda_visible_devices"),
                "env": topo.get("env") or {},
                "op_env": topo.get("op_env") or topo.get("test_env") or {},
                "min_peak_busbw_gbps": topo.get("min_peak_busbw_gbps"),
            })
        return normalized

    def _env_exports(
        self,
        topo: Optional[dict] = None,
        label: Optional[str] = None,
        binary: Optional[str] = None,
    ) -> list[tuple[str, str]]:
        """Return the ``(name, value)`` pairs to export to the ranks.

        Layers are applied in increasing priority: the NCCL settings from the
        config, ``LD_PRELOAD``/``LD_LIBRARY_PATH``, ``extra_env``, the
        topology's ``cuda_visible_devices`` and ``env``, and finally the
        topology's ``op_env`` entries keyed by ``label`` and then ``binary``.
        A None value in an override layer removes the variable; variables
        whose final value is None are omitted.
        """
        env_cfg = {
            "NCCL_DEBUG": self.cfg.get("debug", "WARN"),
            "NCCL_SOCKET_IFNAME": self.cfg.get("socket_ifname"),
            "NCCL_IB_GID_INDEX": self.cfg.get("ib_gid_index"),
            "NCCL_IB_SL": self.cfg.get("ib_sl"),
            "NCCL_IB_TC": self.cfg.get("ib_tc"),
            "NCCL_IB_HCA": self.cfg.get("ib_hca"),
            "NCCL_IB_TIMEOUT": self.cfg.get("ib_timeout"),
            "NCCL_IB_QPS_PER_CONNECTION": self.cfg.get("qps_per_connection"),
            "NCCL_MIN_NCHANNELS": self.cfg.get("min_nchannels"),
            "NCCL_NET_PLUGIN": self.cfg.get("net_plugin"),
            "NCCL_NVLS_ENABLE": self.cfg.get("nvls_enable"),
            "NCCL_IB_SPLIT_DATA_ON_QPS": self.cfg.get("split_data_on_qps"),
        }
        mpi_ld_preload = self._mpi_ld_preload()
        if mpi_ld_preload:
            env_cfg["LD_PRELOAD"] = mpi_ld_preload
        ld_library_path = self._prepended_ld_library_path(os.environ.get("LD_LIBRARY_PATH", ""))
        if ld_library_path:
            env_cfg["LD_LIBRARY_PATH"] = ld_library_path

        extra_env = self.cfg.get("extra_env") or {}
        if isinstance(extra_env, dict):
            self._merge_env(env_cfg, extra_env)

        if topo:
            if topo.get("cuda_visible_devices"):
                env_cfg["CUDA_VISIBLE_DEVICES"] = str(topo["cuda_visible_devices"])
            if isinstance(topo.get("env"), dict):
                self._merge_env(env_cfg, topo["env"])
            op_env = topo.get("op_env")
            if isinstance(op_env, dict):
                # Binary-keyed overrides are applied last, so they win over
                # label-keyed ones.
                for key in (label, binary):
                    overrides = op_env.get(key)
                    if isinstance(overrides, dict):
                        self._merge_env(env_cfg, overrides)
        return [(k, str(v)) for k, v in env_cfg.items() if v is not None]

    @staticmethod
    def _merge_env(env_cfg: dict, overrides: dict):
        """Apply ``overrides`` to ``env_cfg`` in place.

        Keys and values are stringified; a None value deletes the key.
        """
        for key, value in overrides.items():
            key = str(key)
            if value is None:
                env_cfg.pop(key, None)
            else:
                env_cfg[key] = str(value)

    def _mpi_ld_preload(self) -> str:
        """Return ``mpi_ld_preload`` as a space-joined string ("" if unset)."""
        preload = self.cfg.get("mpi_ld_preload")
        if isinstance(preload, list):
            return " ".join(str(p) for p in preload if p)
        return str(preload) if preload else ""

    def _runtime_env(self) -> dict:
        """Return the environment for local subprocesses (ssh, mpirun).

        This is a copy of ``os.environ`` with the configured ``LD_PRELOAD``
        and ``LD_LIBRARY_PATH`` additions applied.
        """
        env = os.environ.copy()
        mpi_ld_preload = self._mpi_ld_preload()
        if mpi_ld_preload:
            env["LD_PRELOAD"] = mpi_ld_preload
        ld_library_path = self._prepended_ld_library_path(env.get("LD_LIBRARY_PATH", ""))
        if ld_library_path:
            env["LD_LIBRARY_PATH"] = ld_library_path
        return env

    def _extra_ld_library_path(self) -> str:
        """Return ``extra_ld_library_path`` as a ':'-joined string ("" if unset)."""
        paths = self.cfg.get("extra_ld_library_path")
        if isinstance(paths, list):
            return ":".join(str(p) for p in paths if p)
        return str(paths) if paths else ""

    def _prepended_ld_library_path(self, existing: str) -> str:
        """Return ``extra_ld_library_path`` prepended to ``existing``.

        Returns "" when no extra path is configured, meaning the caller
        should leave ``LD_LIBRARY_PATH`` untouched.
        """
        extra = self._extra_ld_library_path()
        if not extra:
            return ""
        return ":".join([extra] + ([existing] if existing else []))

    def _preflight(self, mpirun: Optional[str], tests: list[str], hosts: list[dict]) -> dict:
        """Check prerequisites before launching anything.

        Returns:
            ``{"checks": [...], "passed": bool}``. ``passed`` requires mpirun,
            at least two hosts and every test binary; the ssh probes are
            advisory (they report WARN, never FAIL) and do not affect it.
        """
        checks = [
            {"name": "mpirun", "status": "PASS" if mpirun else "FAIL", "detail": mpirun or "not found"},
            {"name": "hosts", "status": "PASS" if len(hosts) >= 2 else "FAIL", "detail": f"{len(hosts)} configured"},
        ]
        for binary in tests:
            path = self._find_nccl_test(binary)
            checks.append({"name": binary, "status": "PASS" if path else "FAIL", "detail": path or "not found"})

        if self.cfg.get("ssh_preflight", True):
            user = self.cfg.get("ssh_user", "root")
            env = self._runtime_env()  # identical for every host
            for host in hosts:
                check_name = f"ssh {host['addr']}"
                cmd = [
                    "ssh",
                    "-o", "BatchMode=yes",
                    "-o", "ConnectTimeout=5",
                    "-o", "StrictHostKeyChecking=accept-new",
                    f"{user}@{host['addr']}",
                    "hostname",
                ]
                try:
                    r = subprocess.run(cmd, capture_output=True, text=True, timeout=8, env=env)
                except Exception as e:
                    # Best-effort probe: any failure (timeout, missing ssh
                    # client, ...) is reported as a warning, never raised.
                    checks.append({"name": check_name, "status": "WARN", "detail": str(e)})
                    continue
                checks.append({
                    "name": check_name,
                    "status": "PASS" if r.returncode == 0 else "WARN",
                    "detail": r.stdout.strip() or r.stderr.strip()[:120],
                })
        return {
            "checks": checks,
            "passed": all(c["status"] == "PASS" for c in checks if not c["name"].startswith("ssh ")),
        }

    def run(self) -> dict:
        """Run every configured test on every configured topology.

        Returns:
            A result dict with ``passed``, ``source``, ``mode``, ``hosts``,
            ``preflight``, ``tests`` and ``timestamp``. ``error`` is added
            when preflight fails or when no test was run at all. ``passed``
            is True only if at least one topology ran and all of them
            reported PASS.
        """
        mpirun = self._find_mpirun()
        tests = self._tests()
        hosts = self._hosts()
        topologies = self._topologies()
        preflight = self._preflight(mpirun, tests, hosts)

        report = {
            "passed": False,
            "source": "nccl-tests-mpirun",
            "mode": self.cfg.get("mode", "sweep"),
            "hosts": hosts,
            "preflight": preflight,
            "tests": {},
        }
        if not preflight["passed"]:
            report["error"] = "multinode NCCL preflight failed"
            report["timestamp"] = datetime.now().isoformat()
            return report

        results = {}
        for binary in tests:
            label = _OP_LABELS[binary]
            binary_path = self._find_nccl_test(binary)
            op_results = [
                self._run_topology(mpirun, binary_path, label, hosts, topo)
                for topo in topologies
            ]
            results[label] = {"binary": binary_path, "topologies": op_results}

        statuses = [
            topo.get("status")
            for op in results.values()
            for topo in op.get("topologies", [])
        ]
        report["tests"] = results
        if statuses:
            report["passed"] = all(status == "PASS" for status in statuses)
        else:
            # Every configured test name was unrecognised: nothing ran, so
            # this must not be reported as a (vacuous) pass.
            report["error"] = "no recognized NCCL tests configured"
        report["timestamp"] = datetime.now().isoformat()
        return report

    def _run_topology(self, mpirun: str, binary: str, label: str, hosts: list[dict], topo: dict) -> dict:
        """Run one nccl-tests binary on one topology and grade the result.

        Args:
            mpirun: Path of the MPI launcher.
            binary: Path of the nccl-tests binary.
            label: Operation label used for thresholds and ``op_env``.
            hosts: Normalized hosts; the first ``topo["nodes"]`` are used.
            topo: Normalized topology (see ``_topologies``).

        Returns:
            A result dict whose ``status`` is PASS only when mpirun exits 0,
            a positive bus bandwidth was parsed, no wrong results were
            reported and the peak bus bandwidth meets the threshold. When the
            run cannot start or times out, a shorter FAIL dict carrying an
            ``error`` message is returned instead.
        """
        nodes = topo["nodes"]
        gpus_per_node = topo["gpus_per_node"]
        selected_hosts = hosts[:nodes]
        ranks = nodes * gpus_per_node
        started = datetime.now().isoformat()
        summary = {
            "label": topo["label"],
            "nodes": nodes,
            "gpus_per_node": gpus_per_node,
            "ranks": ranks,
            "hosts": selected_hosts,
        }

        if len(selected_hosts) < nodes:
            # mpirun cannot place `ranks` processes on fewer hosts than the
            # topology asks for; fail with a clear reason instead of launching.
            return {
                **summary,
                "command": "",
                "status": "FAIL",
                "error": f"topology needs {nodes} hosts but only {len(selected_hosts)} configured",
                "started_at": started,
            }

        # `or` (rather than a .get default) so an explicit null in the config
        # falls back to the default instead of becoming the string "None".
        socket_ifname = self.cfg.get("socket_ifname") or "bond0"
        oob_ifname = self.cfg.get("oob_tcp_ifname") or socket_ifname
        host_arg = ",".join(f"{h['addr']}:{gpus_per_node}" for h in selected_hosts)
        cmd = [
            mpirun,
            "--allow-run-as-root",
            "--mca", "btl_openib_warn_no_device_params_found", "0",
            "--mca", "btl_tcp_if_include", str(socket_ifname),
            "--mca", "oob_tcp_if_include", str(oob_ifname),
            "-H", host_arg,
            "--map-by", f"ppr:{gpus_per_node}:node",
            "-np", str(ranks),
        ]
        plm_rsh_args = self.cfg.get("plm_rsh_args")
        if plm_rsh_args:
            cmd.extend(["--mca", "plm_rsh_args", str(plm_rsh_args)])
        for key, value in self._env_exports(topo=topo, label=label, binary=os.path.basename(binary)):
            cmd.extend(["-x", f"{key}={value}"])
        cmd.extend([
            binary,
            "-b", str(self.cfg.get("begin_size", "1k")),
            "-e", str(self.cfg.get("end_size", "16g")),
            "-g", str(self.cfg.get("gpus_per_rank", 1)),
            "-f", str(self.cfg.get("step_factor", 2)),
            "-w", str(self.cfg.get("warmup_iters", 10)),
        ])
        if self.cfg.get("iters") is not None:
            cmd.extend(["-n", str(self.cfg["iters"])])

        timeout = int(self.cfg.get("timeout_sec", 1800))
        command = " ".join(cmd)
        try:
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, env=self._runtime_env())
        except subprocess.TimeoutExpired:
            return {
                **summary,
                "command": command,
                "status": "FAIL",
                "error": f"timeout after {timeout}s",
                "started_at": started,
            }
        except OSError as exc:
            # e.g. the launcher disappeared or is not executable any more.
            return {
                **summary,
                "command": command,
                "status": "FAIL",
                "error": f"failed to launch mpirun: {exc}",
                "started_at": started,
            }

        parsed = self._parse_nccl_output(r.stdout)
        net_diag = self._parse_network_diagnostics(r.stdout + "\n" + r.stderr)
        threshold = self._threshold_for(label, topo)
        wrong = sum(row.get("wrong", 0) for row in parsed["by_size"])
        has_bw = parsed["peak_busbw_gbps"] > 0
        ok = r.returncode == 0 and has_bw and wrong == 0 and parsed["peak_busbw_gbps"] >= threshold
        return {
            **summary,
            "cuda_visible_devices": topo.get("cuda_visible_devices"),
            "command": command,
            "returncode": r.returncode,
            "status": "PASS" if ok else "FAIL",
            "peak_busbw_gbps": parsed["peak_busbw_gbps"],
            "peak_algbw_gbps": parsed["peak_algbw_gbps"],
            "peak_size": parsed["peak_size"],
            "avg_busbw_gbps": parsed["avg_busbw_gbps"],
            "min_required_gbps": threshold,
            "wrong_count": wrong,
            "network": net_diag,
            "by_size": parsed["by_size"],
            "stderr_tail": r.stderr[-1200:],
            "stdout_tail": r.stdout[-1200:],
            "started_at": started,
            "finished_at": datetime.now().isoformat(),
        }

    def _threshold_for(self, label: str, topo: Optional[dict] = None) -> float:
        """Return the minimum peak bus bandwidth required for ``label``.

        A ``min_peak_busbw_gbps`` on the topology takes precedence (either a
        number or a per-label dict). Otherwise the config-level value is
        used: a number, a per-label number, or a per-label dict keyed by
        topology label, ``"<nodes>x<gpus>"``, ``"<nodes> nodes x <gpus> GPUs"``,
        ``"<gpus>"`` or ``"default"``. Returns 0.0 when nothing applies.
        """
        if topo and topo.get("min_peak_busbw_gbps") is not None:
            topo_thresholds = topo.get("min_peak_busbw_gbps")
            if isinstance(topo_thresholds, dict):
                return float(topo_thresholds.get(label, 0) or 0)
            return float(topo_thresholds or 0)

        thresholds = self.cfg.get("min_peak_busbw_gbps") or {}
        if not isinstance(thresholds, dict):
            return float(thresholds or 0)

        op_threshold = thresholds.get(label, 0)
        if not isinstance(op_threshold, dict):
            return float(op_threshold or 0)

        keys = []
        if topo:
            keys.extend([
                topo.get("label"),
                f"{topo.get('nodes')}x{topo.get('gpus_per_node')}",
                f"{topo.get('nodes')} nodes x {topo.get('gpus_per_node')} GPUs",
                str(topo.get("gpus_per_node")),
            ])
        keys.append("default")
        for key in keys:
            if key in op_threshold:
                return float(op_threshold.get(key) or 0)
        return 0.0

    @staticmethod
    def _parse_nccl_output(stdout: str) -> dict:
        """Extract the per-size rows and summary figures from nccl-tests output.

        Data rows are whitespace-split; column 0 is read as the size in bytes
        and columns 5-8 as time, algbw, busbw and #wrong (presumably the
        out-of-place figures of the nccl-tests table - verify against the
        nccl-tests version in use). Lines starting with ``#`` are skipped,
        except the ``Avg bus bandwidth`` summary line.

        Returns:
            Dict with ``peak_busbw_gbps``, ``peak_algbw_gbps``, ``peak_size``,
            ``avg_busbw_gbps`` and ``by_size`` (the parsed rows). Peaks are 0
            and ``peak_size`` is "" when no row could be parsed.
        """
        rows = []
        avg_bus = 0.0
        for line in stdout.splitlines():
            stripped = line.strip()
            if not stripped:
                continue
            avg_match = re.search(r"Avg bus bandwidth\s*:\s*([0-9.]+)", stripped)
            if avg_match:
                avg_bus = float(avg_match.group(1))
                continue
            if stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) < 9:
                continue
            try:
                size_bytes = int(parts[0])
                time_us = float(parts[5])
                algbw = float(parts[6])
                busbw = float(parts[7])
                wrong_field = parts[8]
                if wrong_field.upper() == "N/A":
                    # Shown when result checking is disabled; the bandwidth
                    # figures on the row are still valid.
                    wrong = 0
                else:
                    # Go through float so counts printed in %g style
                    # (e.g. "1e+06") are not rejected.
                    wrong = int(float(wrong_field))
            except (ValueError, IndexError, OverflowError):
                continue
            rows.append({
                "size_bytes": size_bytes,
                "size": _format_size(size_bytes),
                "time_us": time_us,
                "algbw_gbps": algbw,
                "busbw_gbps": busbw,
                "wrong": wrong,
            })
        peak_row = max(rows, key=lambda r: r["busbw_gbps"], default={})
        return {
            "peak_busbw_gbps": round(float(peak_row.get("busbw_gbps", 0)), 2),
            "peak_algbw_gbps": round(float(peak_row.get("algbw_gbps", 0)), 2),
            "peak_size": peak_row.get("size", ""),
            "avg_busbw_gbps": round(avg_bus, 2),
            "by_size": rows,
        }

    @staticmethod
    def _parse_network_diagnostics(output: str) -> dict:
        """Summarise transport hints found in the combined stdout/stderr log.

        ``gpu_direct_rdma`` is DISABLED if any HCA is logged as disabled,
        ENABLED if any HCA is logged as enabled or ``/GDRDMA`` appears,
        NOT_DISABLED_IN_LOG if only a ``Using network`` line was seen, and
        UNKNOWN otherwise.
        """
        networks = sorted(set(re.findall(r"Using network (\S+)", output)))
        gdr_enabled = sorted(set(re.findall(r"GPU Direct RDMA Enabled for HCA \d+ '([^']+)'", output)))
        gdr_disabled = sorted(set(re.findall(r"GPU Direct RDMA Disabled for HCA \d+ '([^']+)'", output)))

        # Distinct "NET/IB : Using ..." messages, in first-seen order.
        ib_using = []
        for line in output.splitlines():
            if "NET/IB : Using" in line:
                text = line.split("NET/IB : ", 1)[-1].strip()
                if text not in ib_using:
                    ib_using.append(text)

        if gdr_disabled:
            gdr_state = "DISABLED"
        elif gdr_enabled or "/GDRDMA" in output:
            gdr_state = "ENABLED"
        elif networks:
            gdr_state = "NOT_DISABLED_IN_LOG"
        else:
            gdr_state = "UNKNOWN"
        return {
            "networks": networks,
            "ib_using": ib_using[:8],
            "gdr_enabled_hcas": gdr_enabled,
            "gdr_disabled_hcas": gdr_disabled,
            "gpu_direct_rdma": gdr_state,
        }

    @staticmethod
    def print_results(results: dict, console: "Optional[Console]" = None):
        """Print a headline, the preflight table and one table per operation.

        Args:
            results: A dict as returned by ``run()``.
            console: Console to print to; a new one is created when omitted.
        """
        c = console or Console()
        if results.get("error"):
            c.print(f"[bold red]Multi-node NCCL failed: {results['error']}[/bold red]")
        else:
            c.print("[bold green]Multi-node NCCL complete[/bold green]" if results.get("passed") else "[bold red]Multi-node NCCL failed[/bold red]")

        preflight = results.get("preflight", {})
        if preflight.get("checks"):
            table = Table(title="Preflight")
            table.add_column("Check")
            table.add_column("Status")
            table.add_column("Detail")
            for check in preflight["checks"]:
                table.add_row(check["name"], check["status"], str(check.get("detail", "")))
            c.print(table)

        for op, data in (results.get("tests") or {}).items():
            table = Table(title=f"Multi-node NCCL {op}")
            table.add_column("Topology")
            table.add_column("Peak Bus BW")
            table.add_column("Peak Size")
            table.add_column("Threshold")
            table.add_column("Status")
            for topo in data.get("topologies", []):
                table.add_row(
                    topo.get("label", ""),
                    f"{topo.get('peak_busbw_gbps', 0):.2f} GB/s",
                    str(topo.get("peak_size", "")),
                    f">= {_format_gbps(topo.get('min_required_gbps', 0))} GB/s" if topo.get("min_required_gbps") else "-",
                    topo.get("status", "?"),
                )
            c.print(table)
def _format_size(size_bytes: int) -> str:
units = [("G", 1024 ** 3), ("M", 1024 ** 2), ("K", 1024)]
for suffix, factor in units:
if size_bytes >= factor and size_bytes % factor == 0:
return f"{size_bytes // factor}{suffix}"
return str(size_bytes)
def _format_gbps(value) -> str:
try:
numeric = float(value)
except (TypeError, ValueError):
return str(value)
if numeric.is_integer():
return f"{numeric:.0f}"
return f"{numeric:.2f}"