feat: 跨机 RDMA 并入 rdma_test.py + H800 算力门槛对齐 H100
- modules/rdma_test.py: 新增 SSH 编排的跨机 RDMA(run_cross_node / _cross_node_perftest / 解析器),从 client 端逐设备拉起对端 perftest server 跑本地 client,替代已删除的 scripts/rdma_cross_node.sh;两机 4×NDR400 实测全 PASS(~387-392 Gb/s,~2 µs)。 - configs/default.yaml: 新增 rdma.cross_node 配置块(默认 enabled:false)。 - modules/gpu_specs.py: H800 PASS 门槛对齐 H100 实测地板 (tf32 400->385, bf16 720->730, fp8 1400->1200);H800=H100 硅片, PyTorch tensorwise fp8 天花板 ~1310,原 1400 不可达。 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
e49ea32094
commit
dd77a882f1
@ -62,6 +62,24 @@ rdma:
|
||||
msg_size: 65536
|
||||
ib_device: null
|
||||
ib_port: 1
|
||||
# Cross-node (two-host) RDMA via perftest, orchestrated over SSH from the CLIENT
|
||||
# node. Replaces the old scripts/rdma_cross_node.sh. Run on the client; it starts
|
||||
# ib_write_bw/ib_write_lat servers on `server` over SSH (passwordless required),
|
||||
# then drives the local client per device.
|
||||
cross_node:
|
||||
enabled: false # set true on the client node to run cross-node RDMA
|
||||
server: null # peer ssh address, e.g. 172.72.8.12 (server node)
|
||||
server_addr: null # OOB addr client connects to (default: = server)
|
||||
ssh_user: root
|
||||
devices: [] # e.g. [mlx5_0, mlx5_1, mlx5_6, mlx5_7]; [] = auto-detect active IB
|
||||
ib_port: 1
|
||||
gid_index: null # -x <n> for RoCE; null for pure InfiniBand
|
||||
msg_size: 1048576 # 1 MiB — large enough to reach NDR400 peak
|
||||
iters: 5000
|
||||
base_oob_port: 18515 # per-device OOB port = base + device index
|
||||
server_warmup_sec: 2.0
|
||||
min_bandwidth_gbps: 350 # per-port PASS floor (NDR400 ≈ 0.9 × 400)
|
||||
max_latency_us: 5
|
||||
|
||||
training:
|
||||
model: gpt2
|
||||
|
||||
@ -99,11 +99,14 @@ GPU_SPECS = {
|
||||
"fp16_tflops": 990, # dense (same as H100)
|
||||
"bf16_tflops": 990, # dense (same as H100)
|
||||
"fp8_tflops": 1979, # dense (same as H100)
|
||||
# Tensor Core peaks identical to H100, so PASS thresholds match v2 calibration.
|
||||
# FP64 deliberately NOT listed — H800 is restricted to ~1 TFLOPS FP64 and
|
||||
# is not a valid HPC target dtype.
|
||||
# Tensor Core peaks identical to H100, so PASS thresholds reuse the H100
|
||||
# eager-cuBLAS calibration (2026-05-25). Measured on 8×H800: fp32 ~52 /
|
||||
# tf32 ~420 / fp16 ~741 / bf16 ~745 / fp8 ~1249 — all clear these. fp8 was
|
||||
# 1400 (an H200/rowwise-scaling figure) which PyTorch tensorwise _scaled_mm
|
||||
# can't reach on H100-class silicon (~1310 ceiling); lowered to 1200 to match
|
||||
# h100. FP64 deliberately NOT listed — H800 is restricted to ~1 TFLOPS FP64.
|
||||
"compute_pass_thresholds_tflops": {
|
||||
"fp32": 50, "tf32": 400, "fp16": 720, "bf16": 720, "fp8": 1400,
|
||||
"fp32": 50, "tf32": 385, "fp16": 720, "bf16": 730, "fp8": 1200,
|
||||
},
|
||||
"tdp_watts": 700,
|
||||
"nvlink_gen": 4,
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
@ -109,13 +110,17 @@ class RDMATest:
|
||||
if isinstance(r, dict)
|
||||
)
|
||||
|
||||
return {
|
||||
result = {
|
||||
"passed": all_passed,
|
||||
"devices": device_info,
|
||||
"bandwidth_tests": bw_results,
|
||||
"latency_tests": latency_results,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
# Cross-node (two-host) RDMA, run only when a peer is configured.
|
||||
if (self.rdma_cfg.get("cross_node", {}) or {}).get("enabled"):
|
||||
result["cross_node"] = self.run_cross_node()
|
||||
return result
|
||||
|
||||
def _collect_device_info(self, devices: List[str]) -> List[dict]:
|
||||
info = []
|
||||
@ -252,6 +257,200 @@ class RDMATest:
|
||||
|
||||
return results
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cross-node (two-host) RDMA over perftest, orchestrated via SSH.
|
||||
# Runs FROM the client host: for each IB device it launches the matching
|
||||
# perftest server on the peer over SSH (held open in a live ssh channel),
|
||||
# then runs the local client against the peer's OOB address and parses the
|
||||
# result. Replaces the old standalone scripts/rdma_cross_node.sh.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _active_ib_devices(self) -> List[str]:
|
||||
"""IB devices whose port 1 is InfiniBand link_layer and ACTIVE."""
|
||||
out = []
|
||||
for dev in self._get_ib_devices():
|
||||
for port in self._get_ib_ports(dev):
|
||||
ll = self._read_sys(f"/sys/class/infiniband/{dev}/ports/{port}/link_layer")
|
||||
st = self._read_sys(f"/sys/class/infiniband/{dev}/ports/{port}/state")
|
||||
if ll == "InfiniBand" and "ACTIVE" in st.upper():
|
||||
out.append(dev)
|
||||
break
|
||||
return out
|
||||
|
||||
def run_cross_node(self) -> dict:
|
||||
cn = self.rdma_cfg.get("cross_node", {}) or {}
|
||||
if not cn.get("enabled"):
|
||||
return {"status": "SKIP", "skipped": True,
|
||||
"reason": "rdma.cross_node.enabled is false"}
|
||||
|
||||
server = cn.get("server")
|
||||
if not server:
|
||||
return {"status": "SKIP", "skipped": True,
|
||||
"reason": "rdma.cross_node.server (peer ssh address) not set"}
|
||||
|
||||
ssh_user = cn.get("ssh_user", "root")
|
||||
server_target = server if "@" in server else f"{ssh_user}@{server}"
|
||||
# OOB address the client's perftest connects to (defaults to the ssh host).
|
||||
server_addr = cn.get("server_addr") or server.split("@")[-1]
|
||||
ib_port = cn.get("ib_port", 1)
|
||||
gid_index = cn.get("gid_index")
|
||||
msg_size = cn.get("msg_size", 1048576)
|
||||
iters = cn.get("iters", 5000)
|
||||
base_port = cn.get("base_oob_port", 18515)
|
||||
warmup = cn.get("server_warmup_sec", 2.0)
|
||||
min_bw = cn.get("min_bandwidth_gbps", 350)
|
||||
max_lat = cn.get("max_latency_us", 5)
|
||||
|
||||
devices = cn.get("devices") or self._active_ib_devices()
|
||||
if not devices:
|
||||
return {"status": "SKIP", "skipped": True,
|
||||
"reason": "no active InfiniBand devices to test"}
|
||||
|
||||
has_bw = self._find_tool("ib_write_bw") is not None
|
||||
has_lat = self._find_tool("ib_write_lat") is not None
|
||||
if not has_bw and not has_lat:
|
||||
return {"status": "SKIP", "skipped": True,
|
||||
"reason": "perftest (ib_write_bw / ib_write_lat) not installed"}
|
||||
|
||||
self.console.print(
|
||||
f"[cyan]Cross-node RDMA — client → {server_addr}, "
|
||||
f"devices: {', '.join(devices)}[/cyan]")
|
||||
|
||||
per_device = []
|
||||
for idx, dev in enumerate(devices):
|
||||
oob = base_port + idx
|
||||
entry = {"device": dev}
|
||||
|
||||
if has_bw:
|
||||
bw = self._cross_node_perftest(
|
||||
"ib_write_bw", dev, server_target, server_addr, ib_port,
|
||||
oob, gid_index, warmup,
|
||||
extra=["--report_gbits", "-s", str(msg_size), "-n", str(iters)],
|
||||
parse="bw")
|
||||
entry["bandwidth_gbps"] = bw
|
||||
if isinstance(bw, (int, float)):
|
||||
entry["bw_status"] = "PASS" if bw >= min_bw else "WARN"
|
||||
else:
|
||||
entry["bw_status"] = "FAIL"
|
||||
|
||||
if has_lat:
|
||||
lat = self._cross_node_perftest(
|
||||
"ib_write_lat", dev, server_target, server_addr, ib_port,
|
||||
oob, gid_index, warmup, extra=[], parse="lat")
|
||||
if isinstance(lat, dict):
|
||||
entry["latency_us"] = lat.get("typical")
|
||||
entry["latency_p99_us"] = lat.get("p99")
|
||||
t = lat.get("typical")
|
||||
entry["lat_status"] = ("PASS" if isinstance(t, (int, float)) and 0 < t <= max_lat
|
||||
else ("WARN" if isinstance(t, (int, float)) else "FAIL"))
|
||||
else:
|
||||
entry["latency_us"] = lat
|
||||
entry["lat_status"] = "FAIL"
|
||||
|
||||
per_device.append(entry)
|
||||
|
||||
statuses = [e.get(k) for e in per_device for k in ("bw_status", "lat_status") if e.get(k)]
|
||||
verdict = "PASS"
|
||||
for s in statuses:
|
||||
if s == "FAIL":
|
||||
verdict = "FAIL"
|
||||
break
|
||||
if s == "WARN" and verdict == "PASS":
|
||||
verdict = "WARN"
|
||||
|
||||
return {
|
||||
"status": verdict,
|
||||
"server": server_addr,
|
||||
"min_bandwidth_gbps": min_bw,
|
||||
"max_latency_us": max_lat,
|
||||
"per_device": per_device,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
def _cross_node_perftest(self, tool: str, dev: str, server_target: str,
|
||||
server_addr: str, ib_port: int, oob_port: int,
|
||||
gid_index, warmup: float, extra: List[str], parse: str):
|
||||
"""Start `tool` server on the peer via SSH, run the local client, parse output.
|
||||
|
||||
Returns a float (bw, Gb/s), a dict {typical, p99} (lat, µs), or an error string.
|
||||
"""
|
||||
tool_path = self._find_tool(tool)
|
||||
if not tool_path:
|
||||
return f"{tool} not installed"
|
||||
|
||||
flags = ["-d", dev, "-i", str(ib_port), "-p", str(oob_port), "-F"]
|
||||
if gid_index is not None:
|
||||
flags += ["-x", str(gid_index)]
|
||||
flags += extra
|
||||
|
||||
server_cmd = " ".join([tool] + flags) # server: no host argument
|
||||
server_proc = None
|
||||
try:
|
||||
server_proc = subprocess.Popen(
|
||||
["ssh", "-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=no",
|
||||
server_target, server_cmd],
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
|
||||
time.sleep(warmup) # let the remote server bind before the client connects
|
||||
|
||||
client = subprocess.run([tool_path] + flags + [server_addr],
|
||||
capture_output=True, text=True, timeout=120)
|
||||
out = client.stdout + "\n" + (client.stderr or "")
|
||||
return self._parse_perftest_lat(out) if parse == "lat" else self._parse_perftest_bw(out)
|
||||
except subprocess.TimeoutExpired:
|
||||
return "timeout"
|
||||
except Exception as e: # noqa: BLE001
|
||||
return f"error: {e}"
|
||||
finally:
|
||||
if server_proc and server_proc.poll() is None:
|
||||
server_proc.terminate()
|
||||
try:
|
||||
server_proc.wait(timeout=5)
|
||||
except Exception:
|
||||
server_proc.kill()
|
||||
# ib_write_* server normally exits after one run; pkill cleans up a
|
||||
# leftover one if the client failed mid-handshake. -x matches the exact
|
||||
# process name so it never kills this ssh command itself.
|
||||
try:
|
||||
subprocess.run(
|
||||
["ssh", "-o", "BatchMode=yes", server_target, f"pkill -x {tool}"],
|
||||
capture_output=True, timeout=10)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _parse_perftest_bw(output: str) -> float:
|
||||
"""Parse ib_write_bw rows (#bytes #iter BW_peak BW_avg ...); return max BW avg."""
|
||||
best = 0.0
|
||||
for line in output.splitlines():
|
||||
parts = line.split()
|
||||
if len(parts) >= 4:
|
||||
try:
|
||||
int(parts[0]) # #bytes column
|
||||
best = max(best, float(parts[3])) # BW average[Gb/sec]
|
||||
except ValueError:
|
||||
continue
|
||||
return round(best, 2) if best else 0.0
|
||||
|
||||
@staticmethod
|
||||
def _parse_perftest_lat(output: str) -> dict:
|
||||
"""Parse ib_write_lat row (#bytes #iter t_min t_max t_typical t_avg ... 99%)."""
|
||||
for line in output.splitlines():
|
||||
parts = line.split()
|
||||
if len(parts) >= 6:
|
||||
try:
|
||||
int(parts[0]); int(parts[1])
|
||||
typical = float(parts[4]) # t_typical[usec]
|
||||
except ValueError:
|
||||
continue
|
||||
p99 = None
|
||||
if len(parts) >= 8:
|
||||
try:
|
||||
p99 = float(parts[7]) # 99% percentile[usec]
|
||||
except ValueError:
|
||||
p99 = None
|
||||
return {"typical": round(typical, 2), "p99": round(p99, 2) if p99 else None}
|
||||
return {"typical": None, "p99": None}
|
||||
|
||||
@staticmethod
|
||||
def print_results(results: dict, console: Console = None):
|
||||
c = console or Console()
|
||||
@ -296,3 +495,29 @@ class RDMATest:
|
||||
c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
|
||||
f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)" if status != "SKIP"
|
||||
else f" {t['test']}: [dim]SKIPPED[/dim]")
|
||||
|
||||
cn = results.get("cross_node")
|
||||
if cn:
|
||||
if cn.get("skipped"):
|
||||
c.print(f"\n [bold]Cross-node RDMA[/bold]: [dim]SKIPPED "
|
||||
f"({cn.get('reason', '')})[/dim]")
|
||||
else:
|
||||
v = cn.get("status", "?")
|
||||
vc = "green" if v == "PASS" else ("yellow" if v == "WARN" else "red")
|
||||
c.print(f"\n [bold]Cross-node RDMA[/bold] (server {cn.get('server')}) "
|
||||
f"[{vc}]{v}[/{vc}] "
|
||||
f"[dim]min {cn.get('min_bandwidth_gbps')} Gb/s, "
|
||||
f"max {cn.get('max_latency_us')} µs[/dim]")
|
||||
for e in cn.get("per_device", []):
|
||||
bw = e.get("bandwidth_gbps")
|
||||
lat = e.get("latency_us")
|
||||
bws = e.get("bw_status", "")
|
||||
lts = e.get("lat_status", "")
|
||||
bc = "green" if bws == "PASS" else ("yellow" if bws == "WARN" else "red")
|
||||
lc = "green" if lts == "PASS" else ("yellow" if lts == "WARN" else "red")
|
||||
bw_s = f"{bw:.1f} Gb/s" if isinstance(bw, (int, float)) else str(bw)
|
||||
lat_s = f"{lat:.2f} µs" if isinstance(lat, (int, float)) else str(lat)
|
||||
p99 = e.get("latency_p99_us")
|
||||
p99_s = f", p99 {p99:.2f}" if isinstance(p99, (int, float)) else ""
|
||||
c.print(f" {e['device']}: BW [{bc}]{bw_s}[/{bc}] | "
|
||||
f"lat [{lc}]{lat_s}[/{lc}]{p99_s}")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user