# test_gpu_scripts/modules/multinode_nccl_test.py (534 lines, 21 KiB, Python)

"""Multi-node NCCL benchmark wrapper for nccl-tests via mpirun."""
import json
import os
import re
import shutil
import subprocess
from datetime import datetime
from typing import Optional
from rich.console import Console
from rich.table import Table
from modules.gpu_specs import resolve_tools_dir
_TEST_ALIASES = {
"allreduce": "all_reduce_perf",
"all_reduce": "all_reduce_perf",
"all_reduce_perf": "all_reduce_perf",
"allgather": "all_gather_perf",
"all_gather": "all_gather_perf",
"all_gather_perf": "all_gather_perf",
"alltoall": "alltoall_perf",
"all_to_all": "alltoall_perf",
"alltoall_perf": "alltoall_perf",
"broadcast": "broadcast_perf",
"broadcast_perf": "broadcast_perf",
"reducescatter": "reduce_scatter_perf",
"reduce_scatter": "reduce_scatter_perf",
"reduce_scatter_perf": "reduce_scatter_perf",
"sendrecv": "sendrecv_perf",
"send_recv": "sendrecv_perf",
"sendrecv_perf": "sendrecv_perf",
}
_OP_LABELS = {
"all_reduce_perf": "allreduce",
"all_gather_perf": "allgather",
"alltoall_perf": "alltoall",
"broadcast_perf": "broadcast",
"reduce_scatter_perf": "reducescatter",
"sendrecv_perf": "sendrecv",
}
class MultiNodeNCCLTest:
    """Run cross-node NCCL tests with a PDF-style message-size sweep.

    All settings are read from the ``multinode_nccl`` section of the config
    dict passed to the constructor. Each configured nccl-tests binary is run
    through ``mpirun`` once per configured topology, and the parsed bandwidth
    figures are compared against optional minimum thresholds.
    """

    def __init__(self, config: dict):
        """Store the config and resolve the tools and artifact directories.

        The ``MULTINODE_NCCL_ARTIFACT_DIR`` environment variable takes
        precedence over the ``artifact_dir`` config key.
        """
        self.config = config
        self.cfg = config.get("multinode_nccl", {}) or {}
        self.tools_dir = resolve_tools_dir(config)
        self.console = Console()
        self.artifact_dir = os.environ.get("MULTINODE_NCCL_ARTIFACT_DIR") or self.cfg.get("artifact_dir")

    def _find_mpirun(self) -> Optional[str]:
        """Return the path of an MPI launcher, or None when none is found.

        A configured ``mpirun_path`` wins if it is an executable file;
        otherwise ``mpirun``, ``mpiexec`` and the copy under the tools
        directory are tried in that order.
        """
        configured = self.cfg.get("mpirun_path")
        if configured and os.path.isfile(str(configured)) and os.access(str(configured), os.X_OK):
            return str(configured)
        for cmd in ["mpirun", "mpiexec", os.path.join(self.tools_dir, "mpi", "bin", "mpirun")]:
            found = shutil.which(cmd)
            if found:
                return found
        return None

    def _find_nccl_test(self, binary_name: str) -> Optional[str]:
        """Return the path of an executable nccl-tests binary, or None.

        Search order: a hit on ``PATH``, then the configured
        ``nccl_tests_dir``, then ``<tools_dir>/nccl-tests/build``.
        """
        configured = self.cfg.get("nccl_tests_dir")
        candidates = []
        if configured:
            candidates.append(os.path.join(configured, binary_name))
        candidates.append(os.path.join(self.tools_dir, "nccl-tests", "build", binary_name))
        found = shutil.which(binary_name)
        if found:
            candidates.insert(0, found)
        for path in candidates:
            if path and os.path.isfile(path) and os.access(path, os.X_OK):
                return path
        return None

    def _tests(self) -> list[str]:
        """Return the de-duplicated list of nccl-tests binary names to run.

        Names are resolved through ``_TEST_ALIASES``; names that are not
        recognised are dropped, so the result may be empty.
        """
        configured = self.cfg.get("tests") or ["all_reduce_perf", "alltoall_perf"]
        tests = []
        for name in configured:
            binary = _TEST_ALIASES.get(str(name).lower())
            if binary and binary not in tests:
                tests.append(binary)
        return tests

    def _hosts(self) -> list[dict]:
        """Return the configured hosts as ``{"name", "addr", "slots"}`` dicts.

        A host may be given as a bare address string or as a dict whose
        address is under ``addr``, ``host`` or ``ip``. Entries without an
        address (or of any other type) are skipped. ``slots`` defaults to 8.
        """
        normalized = []
        for host in self.cfg.get("hosts") or []:
            if isinstance(host, str):
                entry = {"name": host, "addr": host, "slots": 8}
            elif isinstance(host, dict):
                # Resolve the address first so the name fallback also works
                # for entries that use the "host" or "ip" spelling.
                addr = host.get("addr") or host.get("host") or host.get("ip")
                slots = host.get("slots")
                entry = {
                    "name": host.get("name") or addr,
                    "addr": addr,
                    "slots": int(slots) if slots is not None else 8,
                }
            else:
                continue
            if entry["addr"]:
                normalized.append(entry)
        return normalized

    def _topologies(self) -> list[dict]:
        """Return the configured topologies with defaults filled in.

        Defaults to a single 2-node x 8-GPU topology when none is configured.
        """
        topologies = self.cfg.get("topologies") or [{"nodes": 2, "gpus_per_node": 8}]
        normalized = []
        for topo in topologies:
            nodes = int(topo.get("nodes", 2))
            gpus_per_node = int(topo.get("gpus_per_node", topo.get("gpn", 8)))
            normalized.append({
                "nodes": nodes,
                "gpus_per_node": gpus_per_node,
                "label": topo.get("label") or f"{nodes} nodes x {gpus_per_node} GPUs",
                "cuda_visible_devices": topo.get("cuda_visible_devices"),
                "env": topo.get("env") or {},
                "op_env": topo.get("op_env") or topo.get("test_env") or {},
                "min_peak_busbw_gbps": topo.get("min_peak_busbw_gbps"),
            })
        return normalized

    def _env_exports(
        self,
        topo: Optional[dict] = None,
        label: Optional[str] = None,
        binary: Optional[str] = None,
    ) -> list[tuple[str, str]]:
        """Return the ``(name, value)`` pairs to export to ranks via ``-x``.

        Layers are applied in increasing priority: built-in NCCL settings,
        ``extra_env``, the topology's ``env``, then the topology's ``op_env``
        entries keyed by operation label and by binary name. A value of None
        in any override layer removes the variable.
        """
        env_cfg = {
            "NCCL_DEBUG": self.cfg.get("debug", "WARN"),
            "NCCL_SOCKET_IFNAME": self.cfg.get("socket_ifname"),
            "NCCL_IB_GID_INDEX": self.cfg.get("ib_gid_index"),
            "NCCL_IB_SL": self.cfg.get("ib_sl"),
            "NCCL_IB_TC": self.cfg.get("ib_tc"),
            "NCCL_IB_HCA": self.cfg.get("ib_hca"),
            "NCCL_IB_TIMEOUT": self.cfg.get("ib_timeout"),
            "NCCL_IB_QPS_PER_CONNECTION": self.cfg.get("qps_per_connection"),
            "NCCL_MIN_NCHANNELS": self.cfg.get("min_nchannels"),
            "NCCL_NET_PLUGIN": self.cfg.get("net_plugin"),
            "NCCL_NVLS_ENABLE": self.cfg.get("nvls_enable"),
            "NCCL_IB_SPLIT_DATA_ON_QPS": self.cfg.get("split_data_on_qps"),
        }
        mpi_ld_preload = self._mpi_ld_preload()
        if mpi_ld_preload:
            env_cfg["LD_PRELOAD"] = mpi_ld_preload
        ld_library_path = self._ld_library_path_with_extras(os.environ.get("LD_LIBRARY_PATH", ""))
        if ld_library_path:
            env_cfg["LD_LIBRARY_PATH"] = ld_library_path
        extra_env = self.cfg.get("extra_env") or {}
        if isinstance(extra_env, dict):
            self._merge_env(env_cfg, extra_env)
        if topo:
            if topo.get("cuda_visible_devices"):
                env_cfg["CUDA_VISIBLE_DEVICES"] = str(topo["cuda_visible_devices"])
            if isinstance(topo.get("env"), dict):
                self._merge_env(env_cfg, topo["env"])
            op_env = topo.get("op_env")
            if isinstance(op_env, dict):
                for key in (label, binary):
                    overrides = op_env.get(key)
                    if isinstance(overrides, dict):
                        self._merge_env(env_cfg, overrides)
        return [(k, str(v)) for k, v in env_cfg.items() if v is not None]

    @staticmethod
    def _merge_env(env_cfg: dict, overrides: dict):
        """Apply ``overrides`` to ``env_cfg`` in place; None removes a key."""
        for key, value in overrides.items():
            key = str(key)
            if value is None:
                env_cfg.pop(key, None)
            else:
                env_cfg[key] = str(value)

    def _mpi_ld_preload(self) -> str:
        """Return the ``mpi_ld_preload`` setting as a space-joined string."""
        preload = self.cfg.get("mpi_ld_preload")
        if isinstance(preload, list):
            return " ".join(str(p) for p in preload if p)
        return str(preload) if preload else ""

    def _runtime_env(self) -> dict:
        """Return the environment used for local ``ssh``/``mpirun`` calls.

        This is a copy of the current process environment with the configured
        ``LD_PRELOAD`` and extra ``LD_LIBRARY_PATH`` entries applied.
        """
        env = os.environ.copy()
        mpi_ld_preload = self._mpi_ld_preload()
        if mpi_ld_preload:
            env["LD_PRELOAD"] = mpi_ld_preload
        ld_library_path = self._ld_library_path_with_extras(env.get("LD_LIBRARY_PATH", ""))
        if ld_library_path:
            env["LD_LIBRARY_PATH"] = ld_library_path
        return env

    def _extra_ld_library_path(self) -> str:
        """Return the ``extra_ld_library_path`` setting as a ':'-joined string."""
        paths = self.cfg.get("extra_ld_library_path")
        if isinstance(paths, list):
            return ":".join(str(p) for p in paths if p)
        return str(paths) if paths else ""

    def _ld_library_path_with_extras(self, existing: str) -> str:
        """Return ``existing`` with the configured extra paths prepended.

        Returns an empty string when no extra paths are configured, meaning
        the caller should leave ``LD_LIBRARY_PATH`` untouched.
        """
        extra = self._extra_ld_library_path()
        if not extra:
            return ""
        return ":".join([extra] + ([existing] if existing else []))

    def _preflight(self, mpirun: Optional[str], tests: list[str], hosts: list[dict]) -> dict:
        """Check that the launcher, hosts and test binaries are available.

        Returns ``{"checks": [...], "passed": bool}``. SSH reachability checks
        are advisory: they are reported as PASS/WARN and never affect
        ``passed``.
        """
        checks = []
        checks.append({"name": "mpirun", "status": "PASS" if mpirun else "FAIL", "detail": mpirun or "not found"})
        checks.append({"name": "hosts", "status": "PASS" if len(hosts) >= 2 else "FAIL", "detail": f"{len(hosts)} configured"})
        if not tests:
            # Without this an all-unrecognised "tests" list would run nothing
            # and still be reported as a pass.
            checks.append({"name": "tests", "status": "FAIL", "detail": "no recognised tests configured"})
        for binary in tests:
            path = self._find_nccl_test(binary)
            checks.append({"name": binary, "status": "PASS" if path else "FAIL", "detail": path or "not found"})
        if self.cfg.get("ssh_preflight", True):
            user = self.cfg.get("ssh_user", "root")
            runtime_env = self._runtime_env()
            for host in hosts:
                addr = host["addr"]
                cmd = [
                    "ssh",
                    "-o", "BatchMode=yes",
                    "-o", "ConnectTimeout=5",
                    "-o", "StrictHostKeyChecking=accept-new",
                    f"{user}@{addr}",
                    "hostname",
                ]
                try:
                    r = subprocess.run(cmd, capture_output=True, text=True, timeout=8, env=runtime_env)
                    detail = r.stdout.strip() or r.stderr.strip()[:120]
                    checks.append({
                        "name": f"ssh {addr}",
                        "status": "PASS" if r.returncode == 0 else "WARN",
                        "detail": detail,
                    })
                except Exception as e:  # deliberate best-effort: any failure is only a warning
                    checks.append({"name": f"ssh {addr}", "status": "WARN", "detail": str(e)})
        return {
            "checks": checks,
            "passed": all(c["status"] == "PASS" for c in checks if not c["name"].startswith("ssh ")),
        }

    def run(self) -> dict:
        """Run every configured test on every topology and return a summary.

        When preflight fails no test is launched and the summary carries an
        ``error`` entry. Otherwise ``passed`` is True only if at least one
        test ran and every topology result has status PASS.
        """
        mpirun = self._find_mpirun()
        tests = self._tests()
        hosts = self._hosts()
        topologies = self._topologies()
        preflight = self._preflight(mpirun, tests, hosts)
        if not preflight["passed"]:
            return {
                "passed": False,
                "source": "nccl-tests-mpirun",
                "mode": self.cfg.get("mode", "sweep"),
                "hosts": hosts,
                "preflight": preflight,
                "tests": {},
                "error": "multinode NCCL preflight failed",
                "timestamp": datetime.now().isoformat(),
            }
        results = {}
        for binary in tests:
            label = _OP_LABELS[binary]
            binary_path = self._find_nccl_test(binary)
            op_results = [
                self._run_topology(mpirun, binary_path, label, hosts, topo)
                for topo in topologies
            ]
            results[label] = {"binary": binary_path, "topologies": op_results}
        # bool(results) guards against a vacuous pass: all() of nothing is True.
        passed = bool(results) and all(
            topo.get("status") == "PASS"
            for op in results.values()
            for topo in op.get("topologies", [])
        )
        return {
            "passed": passed,
            "source": "nccl-tests-mpirun",
            "mode": self.cfg.get("mode", "sweep"),
            "hosts": hosts,
            "preflight": preflight,
            "tests": results,
            "artifact_dir": self.artifact_dir,
            "timestamp": datetime.now().isoformat(),
        }

    def _run_topology(self, mpirun: str, binary: str, label: str, hosts: list[dict], topo: dict) -> dict:
        """Run one nccl-tests binary on one topology and return its result.

        Args:
            mpirun: Path of the MPI launcher.
            binary: Path of the nccl-tests binary.
            label: Operation label used for thresholds and ``op_env`` lookups.
            hosts: Normalized hosts; the first ``topo["nodes"]`` are used.
            topo: Normalized topology dict.

        Returns:
            A result dict whose ``status`` is "PASS" or "FAIL". Launch
            failures, timeouts and a shortage of hosts are reported as FAIL
            results with an ``error`` entry rather than raised.
        """
        nodes = topo["nodes"]
        gpus_per_node = topo["gpus_per_node"]
        selected_hosts = hosts[:nodes]
        ranks = nodes * gpus_per_node
        started = datetime.now().isoformat()

        def failure(error: str, command: str, stdout: str = "", stderr: str = "") -> dict:
            # Build, persist and return a FAIL result for a run with no usable output.
            failed = {
                "label": topo["label"],
                "nodes": nodes,
                "gpus_per_node": gpus_per_node,
                "ranks": ranks,
                "hosts": selected_hosts,
                "command": command,
                "status": "FAIL",
                "error": error,
                "started_at": started,
            }
            if stdout or stderr:
                failed["stdout_tail"] = stdout[-1200:]
                failed["stderr_tail"] = stderr[-1200:]
            self._write_artifacts(label, topo, failed, stdout, stderr)
            return failed

        if len(selected_hosts) < nodes:
            # hosts[:nodes] silently truncates; launching with fewer hosts
            # than requested would oversubscribe or fail obscurely in mpirun.
            return failure(
                f"topology needs {nodes} hosts but only {len(selected_hosts)} are configured",
                "",
            )

        host_arg = ",".join(f"{h['addr']}:{gpus_per_node}" for h in selected_hosts)
        # "or" (not a .get default) so an explicit null in the config still
        # falls back instead of becoming the literal string "None".
        socket_ifname = self.cfg.get("socket_ifname") or "bond0"
        oob_ifname = self.cfg.get("oob_tcp_ifname") or socket_ifname
        cmd = [
            mpirun,
            "--allow-run-as-root",
            "--mca", "btl_openib_warn_no_device_params_found", "0",
            "--mca", "btl_tcp_if_include", str(socket_ifname),
            "--mca", "oob_tcp_if_include", str(oob_ifname),
            "-H", host_arg,
            "--map-by", f"ppr:{gpus_per_node}:node",
            "-np", str(ranks),
        ]
        plm_rsh_args = self.cfg.get("plm_rsh_args")
        if plm_rsh_args:
            cmd.extend(["--mca", "plm_rsh_args", str(plm_rsh_args)])
        for key, value in self._env_exports(topo=topo, label=label, binary=os.path.basename(binary)):
            cmd.extend(["-x", f"{key}={value}"])
        cmd.extend([
            binary,
            "-b", str(self.cfg.get("begin_size", "1k")),
            "-e", str(self.cfg.get("end_size", "16g")),
            "-g", str(self.cfg.get("gpus_per_rank", 1)),
            "-f", str(self.cfg.get("step_factor", 2)),
            "-w", str(self.cfg.get("warmup_iters", 10)),
        ])
        if self.cfg.get("iters") is not None:
            cmd.extend(["-n", str(self.cfg["iters"])])
        timeout = int(self.cfg.get("timeout_sec", 1800))
        command = " ".join(cmd)
        try:
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, env=self._runtime_env())
        except subprocess.TimeoutExpired as exc:
            # Keep whatever was captured before the timeout for diagnosis.
            return failure(
                f"timeout after {timeout}s",
                command,
                self._as_text(exc.stdout),
                self._as_text(exc.stderr),
            )
        except OSError as exc:
            # E.g. the launcher disappeared or is not executable.
            return failure(f"failed to launch mpirun: {exc}", command)

        parsed = self._parse_nccl_output(r.stdout)
        net_diag = self._parse_network_diagnostics(r.stdout + "\n" + r.stderr)
        threshold = self._threshold_for(label, topo)
        wrong = sum(row.get("wrong", 0) for row in parsed["by_size"])
        has_bw = parsed["peak_busbw_gbps"] > 0
        status = "PASS" if r.returncode == 0 and has_bw and wrong == 0 and parsed["peak_busbw_gbps"] >= threshold else "FAIL"
        result = {
            "label": topo["label"],
            "nodes": nodes,
            "gpus_per_node": gpus_per_node,
            "ranks": ranks,
            "hosts": selected_hosts,
            "cuda_visible_devices": topo.get("cuda_visible_devices"),
            "command": command,
            "returncode": r.returncode,
            "status": status,
            "peak_busbw_gbps": parsed["peak_busbw_gbps"],
            "peak_algbw_gbps": parsed["peak_algbw_gbps"],
            "peak_size": parsed["peak_size"],
            "avg_busbw_gbps": parsed["avg_busbw_gbps"],
            "min_required_gbps": threshold,
            "wrong_count": wrong,
            "network": net_diag,
            "by_size": parsed["by_size"],
            "stderr_tail": r.stderr[-1200:],
            "stdout_tail": r.stdout[-1200:],
            "started_at": started,
            "finished_at": datetime.now().isoformat(),
        }
        self._write_artifacts(label, topo, result, r.stdout, r.stderr)
        return result

    @staticmethod
    def _as_text(stream) -> str:
        """Return captured process output as str (None -> "", bytes decoded).

        ``subprocess.TimeoutExpired`` may carry bytes even when ``text=True``
        was requested.
        """
        if stream is None:
            return ""
        if isinstance(stream, bytes):
            return stream.decode("utf-8", errors="replace")
        return str(stream)

    def _write_artifacts(self, label: str, topo: dict, result: dict, stdout: str, stderr: str):
        """Write the command, raw output and JSON result for one run.

        Does nothing when no artifact directory is configured. On success
        ``result["artifact_prefix"]`` is set; if writing fails the error is
        stored in ``result["artifact_error"]`` instead of being raised, so a
        full disk cannot discard a completed benchmark.
        """
        if not self.artifact_dir:
            return
        try:
            os.makedirs(self.artifact_dir, exist_ok=True)
            prefix = _safe_name(f"{label}_{topo.get('nodes')}x{topo.get('gpus_per_node')}_{topo.get('label')}")
            base = os.path.join(self.artifact_dir, prefix)
            with open(base + ".cmd.txt", "w", encoding="utf-8") as f:
                f.write(result.get("command", ""))
                f.write("\n")
            with open(base + ".stdout.txt", "w", encoding="utf-8") as f:
                f.write(stdout)
            with open(base + ".stderr.txt", "w", encoding="utf-8") as f:
                f.write(stderr)
            # The tails duplicate the full stdout/stderr files written above.
            artifact_result = {k: v for k, v in result.items() if k not in ("stdout_tail", "stderr_tail")}
            with open(base + ".json", "w", encoding="utf-8") as f:
                json.dump(artifact_result, f, indent=2, default=str)
        except OSError as exc:
            result["artifact_error"] = str(exc)
            return
        result["artifact_prefix"] = base

    def _threshold_for(self, label: str, topo: Optional[dict] = None) -> float:
        """Return the minimum peak bus bandwidth required for ``label``.

        A ``min_peak_busbw_gbps`` on the topology (scalar or per-label dict)
        takes precedence. Otherwise the global setting is used; it may be a
        scalar, a per-label dict, or a per-label dict of per-topology values
        keyed by topology label, ``"<nodes>x<gpus>"``,
        ``"<nodes> nodes x <gpus> GPUs"``, the GPU count, or ``"default"``.
        Returns 0.0 when nothing applies.
        """
        if topo and topo.get("min_peak_busbw_gbps") is not None:
            topo_thresholds = topo.get("min_peak_busbw_gbps")
            if isinstance(topo_thresholds, dict):
                return float(topo_thresholds.get(label, 0) or 0)
            return float(topo_thresholds or 0)
        thresholds = self.cfg.get("min_peak_busbw_gbps") or {}
        if isinstance(thresholds, dict):
            op_threshold = thresholds.get(label, 0)
            if isinstance(op_threshold, dict):
                keys = []
                if topo:
                    keys.extend([
                        topo.get("label"),
                        f"{topo.get('nodes')}x{topo.get('gpus_per_node')}",
                        f"{topo.get('nodes')} nodes x {topo.get('gpus_per_node')} GPUs",
                        str(topo.get("gpus_per_node")),
                    ])
                keys.append("default")
                for key in keys:
                    if key in op_threshold:
                        return float(op_threshold.get(key) or 0)
                return 0.0
            return float(op_threshold or 0)
        return float(thresholds or 0)

    @staticmethod
    def _parse_nccl_output(stdout: str) -> dict:
        """Parse nccl-tests stdout into per-size rows and peak figures.

        Data rows are whitespace-split; column 0 is the size in bytes and
        columns 5-8 are read as time, algbw, busbw and #wrong. Lines that do
        not fit that shape are ignored.
        """
        rows = []
        avg_bus = 0.0
        for line in stdout.splitlines():
            stripped = line.strip()
            if not stripped:
                continue
            # Checked before the '#' filter because the summary line is
            # itself a comment line.
            avg_match = re.search(r"Avg bus bandwidth\s*:\s*([0-9.]+)", stripped)
            if avg_match:
                avg_bus = float(avg_match.group(1))
                continue
            if stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) < 9:
                continue
            try:
                size_bytes = int(parts[0])
                time_us = float(parts[5])
                algbw = float(parts[6])
                busbw = float(parts[7])
            except ValueError:
                continue
            try:
                # The count may be printed in %g form (e.g. "1e+06"), which
                # int() alone rejects.
                wrong = int(float(parts[8]))
            except (ValueError, OverflowError):
                # "N/A" is printed when data checking is disabled; keep the
                # bandwidth row rather than dropping it.
                wrong = 0
            rows.append({
                "size_bytes": size_bytes,
                "size": _format_size(size_bytes),
                "time_us": time_us,
                "algbw_gbps": algbw,
                "busbw_gbps": busbw,
                "wrong": wrong,
            })
        peak_row = max(rows, key=lambda r: r["busbw_gbps"], default={})
        return {
            "peak_busbw_gbps": round(float(peak_row.get("busbw_gbps", 0)), 2),
            "peak_algbw_gbps": round(float(peak_row.get("algbw_gbps", 0)), 2),
            "peak_size": peak_row.get("size", ""),
            "avg_busbw_gbps": round(avg_bus, 2),
            "by_size": rows,
        }

    @staticmethod
    def _parse_network_diagnostics(output: str) -> dict:
        """Extract network and GPU Direct RDMA hints from NCCL log output.

        ``gpu_direct_rdma`` is "DISABLED" if any HCA is logged as disabled,
        "ENABLED" if any is logged as enabled or "/GDRDMA" appears,
        "NOT_DISABLED_IN_LOG" if only a network line was seen, and "UNKNOWN"
        otherwise.
        """
        networks = sorted(set(re.findall(r"Using network (\S+)", output)))
        gdr_enabled = sorted(set(re.findall(r"GPU Direct RDMA Enabled for HCA \d+ '([^']+)'", output)))
        gdr_disabled = sorted(set(re.findall(r"GPU Direct RDMA Disabled for HCA \d+ '([^']+)'", output)))
        ib_using = []
        for line in output.splitlines():
            if "NET/IB : Using" in line:
                text = line.split("NET/IB : ", 1)[-1].strip()
                if text not in ib_using:
                    ib_using.append(text)
        if gdr_disabled:
            gdr_state = "DISABLED"
        elif gdr_enabled or "/GDRDMA" in output:
            gdr_state = "ENABLED"
        elif networks:
            gdr_state = "NOT_DISABLED_IN_LOG"
        else:
            gdr_state = "UNKNOWN"
        return {
            "networks": networks,
            "ib_using": ib_using[:8],
            "gdr_enabled_hcas": gdr_enabled,
            "gdr_disabled_hcas": gdr_disabled,
            "gpu_direct_rdma": gdr_state,
        }

    @staticmethod
    def print_results(results: dict, console: Optional["Console"] = None):
        """Print the preflight table and one result table per operation.

        Uses ``console`` when given, otherwise a new ``Console``.
        """
        c = console or Console()
        if results.get("error"):
            c.print(f"[bold red]Multi-node NCCL failed: {results['error']}[/bold red]")
        else:
            c.print("[bold green]Multi-node NCCL complete[/bold green]" if results.get("passed") else "[bold red]Multi-node NCCL failed[/bold red]")
        preflight = results.get("preflight", {})
        if preflight.get("checks"):
            table = Table(title="Preflight")
            table.add_column("Check")
            table.add_column("Status")
            table.add_column("Detail")
            for check in preflight["checks"]:
                table.add_row(check["name"], check["status"], str(check.get("detail", "")))
            c.print(table)
        for op, data in (results.get("tests") or {}).items():
            table = Table(title=f"Multi-node NCCL {op}")
            table.add_column("Topology")
            table.add_column("Peak Bus BW")
            table.add_column("Peak Size")
            table.add_column("Threshold")
            table.add_column("Status")
            for topo in data.get("topologies", []):
                table.add_row(
                    topo.get("label", ""),
                    f"{topo.get('peak_busbw_gbps', 0):.2f} GB/s",
                    str(topo.get("peak_size", "")),
                    f">= {_format_gbps(topo.get('min_required_gbps', 0))} GB/s" if topo.get("min_required_gbps") else "-",
                    topo.get("status", "?"),
                )
            c.print(table)
def _format_size(size_bytes: int) -> str:
units = [("G", 1024 ** 3), ("M", 1024 ** 2), ("K", 1024)]
for suffix, factor in units:
if size_bytes >= factor and size_bytes % factor == 0:
return f"{size_bytes // factor}{suffix}"
return str(size_bytes)
def _format_gbps(value) -> str:
try:
numeric = float(value)
except (TypeError, ValueError):
return str(value)
if numeric.is_integer():
return f"{numeric:.0f}"
return f"{numeric:.2f}"
def _safe_name(value: str) -> str:
text = re.sub(r"[^A-Za-z0-9_.-]+", "_", value.strip())
text = re.sub(r"_+", "_", text).strip("_")
return text[:160] or "case"