feat: 跨机 RDMA 并入 rdma_test.py + H800 算力门槛对齐 H100

- modules/rdma_test.py: 新增 SSH 编排的跨机 RDMA(run_cross_node /
  _cross_node_perftest / 解析器),从 client 端逐设备拉起对端 perftest
  server 跑本地 client,替代已删除的 scripts/rdma_cross_node.sh;两机
  4×NDR400 实测全 PASS(~387-392 Gb/s,~2 µs)。
- configs/default.yaml: 新增 rdma.cross_node 配置块(默认 enabled:false)。
- modules/gpu_specs.py: H800 PASS 门槛对齐 H100 实测地板
  (tf32 400->385, bf16 720->730, fp8 1400->1200);H800=H100 硅片,
  PyTorch tensorwise fp8 天花板 ~1310,原 1400 不可达。

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
zulifeng 2026-05-25 19:38:43 +08:00
parent e49ea32094
commit dd77a882f1
3 changed files with 251 additions and 5 deletions

View File

@ -62,6 +62,24 @@ rdma:
msg_size: 65536 msg_size: 65536
ib_device: null ib_device: null
ib_port: 1 ib_port: 1
# Cross-node (two-host) RDMA via perftest, orchestrated over SSH from the CLIENT
# node. Replaces the old scripts/rdma_cross_node.sh. Run on the client; it starts
# ib_write_bw/ib_write_lat servers on `server` over SSH (passwordless required),
# then drives the local client per device.
cross_node:
enabled: false # set true on the client node to run cross-node RDMA
server: null # peer ssh address, e.g. 172.72.8.12 (server node)
server_addr: null # OOB addr client connects to (default: = server)
ssh_user: root
devices: [] # e.g. [mlx5_0, mlx5_1, mlx5_6, mlx5_7]; [] = auto-detect active IB
ib_port: 1
gid_index: null # -x <n> for RoCE; null for pure InfiniBand
msg_size: 1048576 # 1 MiB — large enough to reach NDR400 peak
iters: 5000
base_oob_port: 18515 # per-device OOB port = base + device index
server_warmup_sec: 2.0
min_bandwidth_gbps: 350 # per-port PASS floor (NDR400 ≈ 0.9 × 400)
max_latency_us: 5
training: training:
model: gpt2 model: gpt2

View File

@ -99,11 +99,14 @@ GPU_SPECS = {
"fp16_tflops": 990, # dense (same as H100) "fp16_tflops": 990, # dense (same as H100)
"bf16_tflops": 990, # dense (same as H100) "bf16_tflops": 990, # dense (same as H100)
"fp8_tflops": 1979, # dense (same as H100) "fp8_tflops": 1979, # dense (same as H100)
# Tensor Core peaks identical to H100, so PASS thresholds match v2 calibration. # Tensor Core peaks identical to H100, so PASS thresholds reuse the H100
# FP64 deliberately NOT listed — H800 is restricted to ~1 TFLOPS FP64 and # eager-cuBLAS calibration (2026-05-25). Measured on 8×H800: fp32 ~52 /
# is not a valid HPC target dtype. # tf32 ~420 / fp16 ~741 / bf16 ~745 / fp8 ~1249 — all clear these. fp8 was
# 1400 (an H200/rowwise-scaling figure) which PyTorch tensorwise _scaled_mm
# can't reach on H100-class silicon (~1310 ceiling); lowered to 1200 to match
# h100. FP64 deliberately NOT listed — H800 is restricted to ~1 TFLOPS FP64.
"compute_pass_thresholds_tflops": { "compute_pass_thresholds_tflops": {
"fp32": 50, "tf32": 400, "fp16": 720, "bf16": 720, "fp8": 1400, "fp32": 50, "tf32": 385, "fp16": 720, "bf16": 730, "fp8": 1200,
}, },
"tdp_watts": 700, "tdp_watts": 700,
"nvlink_gen": 4, "nvlink_gen": 4,

View File

@ -3,6 +3,7 @@
import os import os
import shutil import shutil
import subprocess import subprocess
import time
from datetime import datetime from datetime import datetime
from typing import Optional, List from typing import Optional, List
@ -109,13 +110,17 @@ class RDMATest:
if isinstance(r, dict) if isinstance(r, dict)
) )
return { result = {
"passed": all_passed, "passed": all_passed,
"devices": device_info, "devices": device_info,
"bandwidth_tests": bw_results, "bandwidth_tests": bw_results,
"latency_tests": latency_results, "latency_tests": latency_results,
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
} }
# Cross-node (two-host) RDMA, run only when a peer is configured.
if (self.rdma_cfg.get("cross_node", {}) or {}).get("enabled"):
result["cross_node"] = self.run_cross_node()
return result
def _collect_device_info(self, devices: List[str]) -> List[dict]: def _collect_device_info(self, devices: List[str]) -> List[dict]:
info = [] info = []
@ -252,6 +257,200 @@ class RDMATest:
return results return results
# ------------------------------------------------------------------
# Cross-node (two-host) RDMA over perftest, orchestrated via SSH.
# Runs FROM the client host: for each IB device it launches the matching
# perftest server on the peer over SSH (held open in a live ssh channel),
# then runs the local client against the peer's OOB address and parses the
# result. Replaces the old standalone scripts/rdma_cross_node.sh.
# ------------------------------------------------------------------
def _active_ib_devices(self) -> List[str]:
"""IB devices whose port 1 is InfiniBand link_layer and ACTIVE."""
out = []
for dev in self._get_ib_devices():
for port in self._get_ib_ports(dev):
ll = self._read_sys(f"/sys/class/infiniband/{dev}/ports/{port}/link_layer")
st = self._read_sys(f"/sys/class/infiniband/{dev}/ports/{port}/state")
if ll == "InfiniBand" and "ACTIVE" in st.upper():
out.append(dev)
break
return out
def run_cross_node(self) -> dict:
cn = self.rdma_cfg.get("cross_node", {}) or {}
if not cn.get("enabled"):
return {"status": "SKIP", "skipped": True,
"reason": "rdma.cross_node.enabled is false"}
server = cn.get("server")
if not server:
return {"status": "SKIP", "skipped": True,
"reason": "rdma.cross_node.server (peer ssh address) not set"}
ssh_user = cn.get("ssh_user", "root")
server_target = server if "@" in server else f"{ssh_user}@{server}"
# OOB address the client's perftest connects to (defaults to the ssh host).
server_addr = cn.get("server_addr") or server.split("@")[-1]
ib_port = cn.get("ib_port", 1)
gid_index = cn.get("gid_index")
msg_size = cn.get("msg_size", 1048576)
iters = cn.get("iters", 5000)
base_port = cn.get("base_oob_port", 18515)
warmup = cn.get("server_warmup_sec", 2.0)
min_bw = cn.get("min_bandwidth_gbps", 350)
max_lat = cn.get("max_latency_us", 5)
devices = cn.get("devices") or self._active_ib_devices()
if not devices:
return {"status": "SKIP", "skipped": True,
"reason": "no active InfiniBand devices to test"}
has_bw = self._find_tool("ib_write_bw") is not None
has_lat = self._find_tool("ib_write_lat") is not None
if not has_bw and not has_lat:
return {"status": "SKIP", "skipped": True,
"reason": "perftest (ib_write_bw / ib_write_lat) not installed"}
self.console.print(
f"[cyan]Cross-node RDMA — client → {server_addr}, "
f"devices: {', '.join(devices)}[/cyan]")
per_device = []
for idx, dev in enumerate(devices):
oob = base_port + idx
entry = {"device": dev}
if has_bw:
bw = self._cross_node_perftest(
"ib_write_bw", dev, server_target, server_addr, ib_port,
oob, gid_index, warmup,
extra=["--report_gbits", "-s", str(msg_size), "-n", str(iters)],
parse="bw")
entry["bandwidth_gbps"] = bw
if isinstance(bw, (int, float)):
entry["bw_status"] = "PASS" if bw >= min_bw else "WARN"
else:
entry["bw_status"] = "FAIL"
if has_lat:
lat = self._cross_node_perftest(
"ib_write_lat", dev, server_target, server_addr, ib_port,
oob, gid_index, warmup, extra=[], parse="lat")
if isinstance(lat, dict):
entry["latency_us"] = lat.get("typical")
entry["latency_p99_us"] = lat.get("p99")
t = lat.get("typical")
entry["lat_status"] = ("PASS" if isinstance(t, (int, float)) and 0 < t <= max_lat
else ("WARN" if isinstance(t, (int, float)) else "FAIL"))
else:
entry["latency_us"] = lat
entry["lat_status"] = "FAIL"
per_device.append(entry)
statuses = [e.get(k) for e in per_device for k in ("bw_status", "lat_status") if e.get(k)]
verdict = "PASS"
for s in statuses:
if s == "FAIL":
verdict = "FAIL"
break
if s == "WARN" and verdict == "PASS":
verdict = "WARN"
return {
"status": verdict,
"server": server_addr,
"min_bandwidth_gbps": min_bw,
"max_latency_us": max_lat,
"per_device": per_device,
"timestamp": datetime.now().isoformat(),
}
def _cross_node_perftest(self, tool: str, dev: str, server_target: str,
server_addr: str, ib_port: int, oob_port: int,
gid_index, warmup: float, extra: List[str], parse: str):
"""Start `tool` server on the peer via SSH, run the local client, parse output.
Returns a float (bw, Gb/s), a dict {typical, p99} (lat, µs), or an error string.
"""
tool_path = self._find_tool(tool)
if not tool_path:
return f"{tool} not installed"
flags = ["-d", dev, "-i", str(ib_port), "-p", str(oob_port), "-F"]
if gid_index is not None:
flags += ["-x", str(gid_index)]
flags += extra
server_cmd = " ".join([tool] + flags) # server: no host argument
server_proc = None
try:
server_proc = subprocess.Popen(
["ssh", "-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=no",
server_target, server_cmd],
stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
time.sleep(warmup) # let the remote server bind before the client connects
client = subprocess.run([tool_path] + flags + [server_addr],
capture_output=True, text=True, timeout=120)
out = client.stdout + "\n" + (client.stderr or "")
return self._parse_perftest_lat(out) if parse == "lat" else self._parse_perftest_bw(out)
except subprocess.TimeoutExpired:
return "timeout"
except Exception as e: # noqa: BLE001
return f"error: {e}"
finally:
if server_proc and server_proc.poll() is None:
server_proc.terminate()
try:
server_proc.wait(timeout=5)
except Exception:
server_proc.kill()
# ib_write_* server normally exits after one run; pkill cleans up a
# leftover one if the client failed mid-handshake. -x matches the exact
# process name so it never kills this ssh command itself.
try:
subprocess.run(
["ssh", "-o", "BatchMode=yes", server_target, f"pkill -x {tool}"],
capture_output=True, timeout=10)
except Exception:
pass
@staticmethod
def _parse_perftest_bw(output: str) -> float:
"""Parse ib_write_bw rows (#bytes #iter BW_peak BW_avg ...); return max BW avg."""
best = 0.0
for line in output.splitlines():
parts = line.split()
if len(parts) >= 4:
try:
int(parts[0]) # #bytes column
best = max(best, float(parts[3])) # BW average[Gb/sec]
except ValueError:
continue
return round(best, 2) if best else 0.0
@staticmethod
def _parse_perftest_lat(output: str) -> dict:
"""Parse ib_write_lat row (#bytes #iter t_min t_max t_typical t_avg ... 99%)."""
for line in output.splitlines():
parts = line.split()
if len(parts) >= 6:
try:
int(parts[0]); int(parts[1])
typical = float(parts[4]) # t_typical[usec]
except ValueError:
continue
p99 = None
if len(parts) >= 8:
try:
p99 = float(parts[7]) # 99% percentile[usec]
except ValueError:
p99 = None
return {"typical": round(typical, 2), "p99": round(p99, 2) if p99 else None}
return {"typical": None, "p99": None}
@staticmethod @staticmethod
def print_results(results: dict, console: Console = None): def print_results(results: dict, console: Console = None):
c = console or Console() c = console or Console()
@ -296,3 +495,29 @@ class RDMATest:
c.print(f" {t['test']}: [{sc}]{status}[/{sc}] " c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)" if status != "SKIP" f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)" if status != "SKIP"
else f" {t['test']}: [dim]SKIPPED[/dim]") else f" {t['test']}: [dim]SKIPPED[/dim]")
cn = results.get("cross_node")
if cn:
if cn.get("skipped"):
c.print(f"\n [bold]Cross-node RDMA[/bold]: [dim]SKIPPED "
f"({cn.get('reason', '')})[/dim]")
else:
v = cn.get("status", "?")
vc = "green" if v == "PASS" else ("yellow" if v == "WARN" else "red")
c.print(f"\n [bold]Cross-node RDMA[/bold] (server {cn.get('server')}) "
f"[{vc}]{v}[/{vc}] "
f"[dim]min {cn.get('min_bandwidth_gbps')} Gb/s, "
f"max {cn.get('max_latency_us')} µs[/dim]")
for e in cn.get("per_device", []):
bw = e.get("bandwidth_gbps")
lat = e.get("latency_us")
bws = e.get("bw_status", "")
lts = e.get("lat_status", "")
bc = "green" if bws == "PASS" else ("yellow" if bws == "WARN" else "red")
lc = "green" if lts == "PASS" else ("yellow" if lts == "WARN" else "red")
bw_s = f"{bw:.1f} Gb/s" if isinstance(bw, (int, float)) else str(bw)
lat_s = f"{lat:.2f} µs" if isinstance(lat, (int, float)) else str(lat)
p99 = e.get("latency_p99_us")
p99_s = f", p99 {p99:.2f}" if isinstance(p99, (int, float)) else ""
c.print(f" {e['device']}: BW [{bc}]{bw_s}[/{bc}] | "
f"lat [{lc}]{lat_s}[/{lc}]{p99_s}")