diff --git a/configs/default.yaml b/configs/default.yaml index 7172b6e..a432c11 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -62,6 +62,24 @@ rdma: msg_size: 65536 ib_device: null ib_port: 1 + # Cross-node (two-host) RDMA via perftest, orchestrated over SSH from the CLIENT + # node. Replaces the old scripts/rdma_cross_node.sh. Run on the client; it starts + # ib_write_bw/ib_write_lat servers on `server` over SSH (passwordless required), + # then drives the local client per device. + cross_node: + enabled: false # set true on the client node to run cross-node RDMA + server: null # peer ssh address, e.g. 172.72.8.12 (server node) + server_addr: null # OOB addr client connects to (default: = server) + ssh_user: root + devices: [] # e.g. [mlx5_0, mlx5_1, mlx5_6, mlx5_7]; [] = auto-detect active IB + ib_port: 1 + gid_index: null # -x for RoCE; null for pure InfiniBand + msg_size: 1048576 # 1 MiB — large enough to reach NDR400 peak + iters: 5000 + base_oob_port: 18515 # per-device OOB port = base + device index + server_warmup_sec: 2.0 + min_bandwidth_gbps: 350 # per-port PASS floor (NDR400 ≈ 0.9 × 400) + max_latency_us: 5 training: model: gpt2 diff --git a/modules/gpu_specs.py b/modules/gpu_specs.py index a08e0d1..2ce9348 100644 --- a/modules/gpu_specs.py +++ b/modules/gpu_specs.py @@ -99,11 +99,14 @@ GPU_SPECS = { "fp16_tflops": 990, # dense (same as H100) "bf16_tflops": 990, # dense (same as H100) "fp8_tflops": 1979, # dense (same as H100) - # Tensor Core peaks identical to H100, so PASS thresholds match v2 calibration. - # FP64 deliberately NOT listed — H800 is restricted to ~1 TFLOPS FP64 and - # is not a valid HPC target dtype. + # Tensor Core peaks identical to H100, so PASS thresholds reuse the H100 + # eager-cuBLAS calibration (2026-05-25). Measured on 8×H800: fp32 ~52 / + # tf32 ~420 / fp16 ~741 / bf16 ~745 / fp8 ~1249 — all clear these. fp8 was + # 1400 (an H200/rowwise-scaling figure) which PyTorch tensorwise _scaled_mm + # can't reach on H100-class silicon (~1310 ceiling); lowered to 1200 to match + # h100. FP64 deliberately NOT listed — H800 is restricted to ~1 TFLOPS FP64. "compute_pass_thresholds_tflops": { - "fp32": 50, "tf32": 400, "fp16": 720, "bf16": 720, "fp8": 1400, + "fp32": 50, "tf32": 385, "fp16": 720, "bf16": 730, "fp8": 1200, }, "tdp_watts": 700, "nvlink_gen": 4, diff --git a/modules/rdma_test.py b/modules/rdma_test.py index 497e9d5..08ea610 100644 --- a/modules/rdma_test.py +++ b/modules/rdma_test.py @@ -3,6 +3,7 @@ import os import shutil import subprocess +import time from datetime import datetime from typing import Optional, List @@ -109,13 +110,17 @@ class RDMATest: if isinstance(r, dict) ) - return { + result = { "passed": all_passed, "devices": device_info, "bandwidth_tests": bw_results, "latency_tests": latency_results, "timestamp": datetime.now().isoformat(), } + # Cross-node (two-host) RDMA, run only when a peer is configured. + if (self.rdma_cfg.get("cross_node", {}) or {}).get("enabled"): + result["cross_node"] = self.run_cross_node() + return result def _collect_device_info(self, devices: List[str]) -> List[dict]: info = [] @@ -252,6 +257,200 @@ class RDMATest: return results + # ------------------------------------------------------------------ + # Cross-node (two-host) RDMA over perftest, orchestrated via SSH. + # Runs FROM the client host: for each IB device it launches the matching + # perftest server on the peer over SSH (held open in a live ssh channel), + # then runs the local client against the peer's OOB address and parses the + # result. Replaces the old standalone scripts/rdma_cross_node.sh. + # ------------------------------------------------------------------ + + def _active_ib_devices(self) -> List[str]: + """IB devices whose port 1 is InfiniBand link_layer and ACTIVE.""" + out = [] + for dev in self._get_ib_devices(): + for port in self._get_ib_ports(dev): + ll = self._read_sys(f"/sys/class/infiniband/{dev}/ports/{port}/link_layer") + st = self._read_sys(f"/sys/class/infiniband/{dev}/ports/{port}/state") + if ll == "InfiniBand" and "ACTIVE" in st.upper(): + out.append(dev) + break + return out + + def run_cross_node(self) -> dict: + cn = self.rdma_cfg.get("cross_node", {}) or {} + if not cn.get("enabled"): + return {"status": "SKIP", "skipped": True, + "reason": "rdma.cross_node.enabled is false"} + + server = cn.get("server") + if not server: + return {"status": "SKIP", "skipped": True, + "reason": "rdma.cross_node.server (peer ssh address) not set"} + + ssh_user = cn.get("ssh_user", "root") + server_target = server if "@" in server else f"{ssh_user}@{server}" + # OOB address the client's perftest connects to (defaults to the ssh host). + server_addr = cn.get("server_addr") or server.split("@")[-1] + ib_port = cn.get("ib_port", 1) + gid_index = cn.get("gid_index") + msg_size = cn.get("msg_size", 1048576) + iters = cn.get("iters", 5000) + base_port = cn.get("base_oob_port", 18515) + warmup = cn.get("server_warmup_sec", 2.0) + min_bw = cn.get("min_bandwidth_gbps", 350) + max_lat = cn.get("max_latency_us", 5) + + devices = cn.get("devices") or self._active_ib_devices() + if not devices: + return {"status": "SKIP", "skipped": True, + "reason": "no active InfiniBand devices to test"} + + has_bw = self._find_tool("ib_write_bw") is not None + has_lat = self._find_tool("ib_write_lat") is not None + if not has_bw and not has_lat: + return {"status": "SKIP", "skipped": True, + "reason": "perftest (ib_write_bw / ib_write_lat) not installed"} + + self.console.print( + f"[cyan]Cross-node RDMA — client → {server_addr}, " + f"devices: {', '.join(devices)}[/cyan]") + + per_device = [] + for idx, dev in enumerate(devices): + oob = base_port + idx + entry = {"device": dev} + + if has_bw: + bw = self._cross_node_perftest( + "ib_write_bw", dev, server_target, server_addr, ib_port, + oob, gid_index, warmup, + extra=["--report_gbits", "-s", str(msg_size), "-n", str(iters)], + parse="bw") + entry["bandwidth_gbps"] = bw + if isinstance(bw, (int, float)): + entry["bw_status"] = "PASS" if bw >= min_bw else "WARN" + else: + entry["bw_status"] = "FAIL" + + if has_lat: + lat = self._cross_node_perftest( + "ib_write_lat", dev, server_target, server_addr, ib_port, + oob, gid_index, warmup, extra=[], parse="lat") + if isinstance(lat, dict): + entry["latency_us"] = lat.get("typical") + entry["latency_p99_us"] = lat.get("p99") + t = lat.get("typical") + entry["lat_status"] = ("PASS" if isinstance(t, (int, float)) and 0 < t <= max_lat + else ("WARN" if isinstance(t, (int, float)) else "FAIL")) + else: + entry["latency_us"] = lat + entry["lat_status"] = "FAIL" + + per_device.append(entry) + + statuses = [e.get(k) for e in per_device for k in ("bw_status", "lat_status") if e.get(k)] + verdict = "PASS" + for s in statuses: + if s == "FAIL": + verdict = "FAIL" + break + if s == "WARN" and verdict == "PASS": + verdict = "WARN" + + return { + "status": verdict, + "server": server_addr, + "min_bandwidth_gbps": min_bw, + "max_latency_us": max_lat, + "per_device": per_device, + "timestamp": datetime.now().isoformat(), + } + + def _cross_node_perftest(self, tool: str, dev: str, server_target: str, + server_addr: str, ib_port: int, oob_port: int, + gid_index, warmup: float, extra: List[str], parse: str): + """Start `tool` server on the peer via SSH, run the local client, parse output. + + Returns a float (bw, Gb/s), a dict {typical, p99} (lat, µs), or an error string. + """ + tool_path = self._find_tool(tool) + if not tool_path: + return f"{tool} not installed" + + flags = ["-d", dev, "-i", str(ib_port), "-p", str(oob_port), "-F"] + if gid_index is not None: + flags += ["-x", str(gid_index)] + flags += extra + + server_cmd = " ".join([tool] + flags) # server: no host argument + server_proc = None + try: + server_proc = subprocess.Popen( + ["ssh", "-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=no", + server_target, server_cmd], + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) + time.sleep(warmup) # let the remote server bind before the client connects + + client = subprocess.run([tool_path] + flags + [server_addr], + capture_output=True, text=True, timeout=120) + out = client.stdout + "\n" + (client.stderr or "") + return self._parse_perftest_lat(out) if parse == "lat" else self._parse_perftest_bw(out) + except subprocess.TimeoutExpired: + return "timeout" + except Exception as e: # noqa: BLE001 + return f"error: {e}" + finally: + if server_proc and server_proc.poll() is None: + server_proc.terminate() + try: + server_proc.wait(timeout=5) + except Exception: + server_proc.kill() + # ib_write_* server normally exits after one run; pkill cleans up a + # leftover one if the client failed mid-handshake. -x matches the exact + # process name so it never kills this ssh command itself. + try: + subprocess.run( + ["ssh", "-o", "BatchMode=yes", server_target, f"pkill -x {tool}"], + capture_output=True, timeout=10) + except Exception: + pass + + @staticmethod + def _parse_perftest_bw(output: str) -> float: + """Parse ib_write_bw rows (#bytes #iter BW_peak BW_avg ...); return max BW avg.""" + best = 0.0 + for line in output.splitlines(): + parts = line.split() + if len(parts) >= 4: + try: + int(parts[0]) # #bytes column + best = max(best, float(parts[3])) # BW average[Gb/sec] + except ValueError: + continue + return round(best, 2) if best else 0.0 + + @staticmethod + def _parse_perftest_lat(output: str) -> dict: + """Parse ib_write_lat row (#bytes #iter t_min t_max t_typical t_avg ... 99%).""" + for line in output.splitlines(): + parts = line.split() + if len(parts) >= 6: + try: + int(parts[0]); int(parts[1]) + typical = float(parts[4]) # t_typical[usec] + except ValueError: + continue + p99 = None + if len(parts) >= 8: + try: + p99 = float(parts[7]) # 99% percentile[usec] + except ValueError: + p99 = None + return {"typical": round(typical, 2), "p99": round(p99, 2) if p99 else None} + return {"typical": None, "p99": None} + @staticmethod def print_results(results: dict, console: Console = None): c = console or Console() @@ -296,3 +495,29 @@ class RDMATest: c.print(f" {t['test']}: [{sc}]{status}[/{sc}] " f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)" if status != "SKIP" else f" {t['test']}: [dim]SKIPPED[/dim]") + + cn = results.get("cross_node") + if cn: + if cn.get("skipped"): + c.print(f"\n [bold]Cross-node RDMA[/bold]: [dim]SKIPPED " + f"({cn.get('reason', '')})[/dim]") + else: + v = cn.get("status", "?") + vc = "green" if v == "PASS" else ("yellow" if v == "WARN" else "red") + c.print(f"\n [bold]Cross-node RDMA[/bold] (server {cn.get('server')}) " + f"[{vc}]{v}[/{vc}] " + f"[dim]min {cn.get('min_bandwidth_gbps')} Gb/s, " + f"max {cn.get('max_latency_us')} µs[/dim]") + for e in cn.get("per_device", []): + bw = e.get("bandwidth_gbps") + lat = e.get("latency_us") + bws = e.get("bw_status", "") + lts = e.get("lat_status", "") + bc = "green" if bws == "PASS" else ("yellow" if bws == "WARN" else "red") + lc = "green" if lts == "PASS" else ("yellow" if lts == "WARN" else "red") + bw_s = f"{bw:.1f} Gb/s" if isinstance(bw, (int, float)) else str(bw) + lat_s = f"{lat:.2f} µs" if isinstance(lat, (int, float)) else str(lat) + p99 = e.get("latency_p99_us") + p99_s = f", p99 {p99:.2f}" if isinstance(p99, (int, float)) else "" + c.print(f" {e['device']}: BW [{bc}]{bw_s}[/{bc}] | " + f"lat [{lc}]{lat_s}[/{lc}]{p99_s}")