feat: 跨机 RDMA 并入 rdma_test.py + H800 算力门槛对齐 H100
- modules/rdma_test.py: 新增 SSH 编排的跨机 RDMA(run_cross_node / _cross_node_perftest / 解析器),从 client 端逐设备拉起对端 perftest server 跑本地 client,替代已删除的 scripts/rdma_cross_node.sh;两机 4×NDR400 实测全 PASS(~387-392 Gb/s,~2 µs)。 - configs/default.yaml: 新增 rdma.cross_node 配置块(默认 enabled:false)。 - modules/gpu_specs.py: H800 PASS 门槛对齐 H100 实测地板 (tf32 400->385, bf16 720->730, fp8 1400->1200);H800=H100 硅片, PyTorch tensorwise fp8 天花板 ~1310,原 1400 不可达。 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
e49ea32094
commit
dd77a882f1
@ -62,6 +62,24 @@ rdma:
|
|||||||
msg_size: 65536
|
msg_size: 65536
|
||||||
ib_device: null
|
ib_device: null
|
||||||
ib_port: 1
|
ib_port: 1
|
||||||
|
# Cross-node (two-host) RDMA via perftest, orchestrated over SSH from the CLIENT
|
||||||
|
# node. Replaces the old scripts/rdma_cross_node.sh. Run on the client; it starts
|
||||||
|
# ib_write_bw/ib_write_lat servers on `server` over SSH (passwordless required),
|
||||||
|
# then drives the local client per device.
|
||||||
|
cross_node:
|
||||||
|
enabled: false # set true on the client node to run cross-node RDMA
|
||||||
|
server: null # peer ssh address, e.g. 172.72.8.12 (server node)
|
||||||
|
server_addr: null # OOB addr client connects to (default: = server)
|
||||||
|
ssh_user: root
|
||||||
|
devices: [] # e.g. [mlx5_0, mlx5_1, mlx5_6, mlx5_7]; [] = auto-detect active IB
|
||||||
|
ib_port: 1
|
||||||
|
gid_index: null # -x <n> for RoCE; null for pure InfiniBand
|
||||||
|
msg_size: 1048576 # 1 MiB — large enough to reach NDR400 peak
|
||||||
|
iters: 5000
|
||||||
|
base_oob_port: 18515 # per-device OOB port = base + device index
|
||||||
|
server_warmup_sec: 2.0
|
||||||
|
min_bandwidth_gbps: 350 # per-port PASS floor (NDR400 ≈ 0.9 × 400)
|
||||||
|
max_latency_us: 5
|
||||||
|
|
||||||
training:
|
training:
|
||||||
model: gpt2
|
model: gpt2
|
||||||
|
|||||||
@ -99,11 +99,14 @@ GPU_SPECS = {
|
|||||||
"fp16_tflops": 990, # dense (same as H100)
|
"fp16_tflops": 990, # dense (same as H100)
|
||||||
"bf16_tflops": 990, # dense (same as H100)
|
"bf16_tflops": 990, # dense (same as H100)
|
||||||
"fp8_tflops": 1979, # dense (same as H100)
|
"fp8_tflops": 1979, # dense (same as H100)
|
||||||
# Tensor Core peaks identical to H100, so PASS thresholds match v2 calibration.
|
# Tensor Core peaks identical to H100, so PASS thresholds reuse the H100
|
||||||
# FP64 deliberately NOT listed — H800 is restricted to ~1 TFLOPS FP64 and
|
# eager-cuBLAS calibration (2026-05-25). Measured on 8×H800: fp32 ~52 /
|
||||||
# is not a valid HPC target dtype.
|
# tf32 ~420 / fp16 ~741 / bf16 ~745 / fp8 ~1249 — all clear these. fp8 was
|
||||||
|
# 1400 (an H200/rowwise-scaling figure) which PyTorch tensorwise _scaled_mm
|
||||||
|
# can't reach on H100-class silicon (~1310 ceiling); lowered to 1200 to match
|
||||||
|
# h100. FP64 deliberately NOT listed — H800 is restricted to ~1 TFLOPS FP64.
|
||||||
"compute_pass_thresholds_tflops": {
|
"compute_pass_thresholds_tflops": {
|
||||||
"fp32": 50, "tf32": 400, "fp16": 720, "bf16": 720, "fp8": 1400,
|
"fp32": 50, "tf32": 385, "fp16": 720, "bf16": 730, "fp8": 1200,
|
||||||
},
|
},
|
||||||
"tdp_watts": 700,
|
"tdp_watts": 700,
|
||||||
"nvlink_gen": 4,
|
"nvlink_gen": 4,
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
@ -109,13 +110,17 @@ class RDMATest:
|
|||||||
if isinstance(r, dict)
|
if isinstance(r, dict)
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
result = {
|
||||||
"passed": all_passed,
|
"passed": all_passed,
|
||||||
"devices": device_info,
|
"devices": device_info,
|
||||||
"bandwidth_tests": bw_results,
|
"bandwidth_tests": bw_results,
|
||||||
"latency_tests": latency_results,
|
"latency_tests": latency_results,
|
||||||
"timestamp": datetime.now().isoformat(),
|
"timestamp": datetime.now().isoformat(),
|
||||||
}
|
}
|
||||||
|
# Cross-node (two-host) RDMA, run only when a peer is configured.
|
||||||
|
if (self.rdma_cfg.get("cross_node", {}) or {}).get("enabled"):
|
||||||
|
result["cross_node"] = self.run_cross_node()
|
||||||
|
return result
|
||||||
|
|
||||||
def _collect_device_info(self, devices: List[str]) -> List[dict]:
|
def _collect_device_info(self, devices: List[str]) -> List[dict]:
|
||||||
info = []
|
info = []
|
||||||
@ -252,6 +257,200 @@ class RDMATest:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Cross-node (two-host) RDMA over perftest, orchestrated via SSH.
|
||||||
|
# Runs FROM the client host: for each IB device it launches the matching
|
||||||
|
# perftest server on the peer over SSH (held open in a live ssh channel),
|
||||||
|
# then runs the local client against the peer's OOB address and parses the
|
||||||
|
# result. Replaces the old standalone scripts/rdma_cross_node.sh.
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _active_ib_devices(self) -> List[str]:
|
||||||
|
"""IB devices whose port 1 is InfiniBand link_layer and ACTIVE."""
|
||||||
|
out = []
|
||||||
|
for dev in self._get_ib_devices():
|
||||||
|
for port in self._get_ib_ports(dev):
|
||||||
|
ll = self._read_sys(f"/sys/class/infiniband/{dev}/ports/{port}/link_layer")
|
||||||
|
st = self._read_sys(f"/sys/class/infiniband/{dev}/ports/{port}/state")
|
||||||
|
if ll == "InfiniBand" and "ACTIVE" in st.upper():
|
||||||
|
out.append(dev)
|
||||||
|
break
|
||||||
|
return out
|
||||||
|
|
||||||
|
def run_cross_node(self) -> dict:
|
||||||
|
cn = self.rdma_cfg.get("cross_node", {}) or {}
|
||||||
|
if not cn.get("enabled"):
|
||||||
|
return {"status": "SKIP", "skipped": True,
|
||||||
|
"reason": "rdma.cross_node.enabled is false"}
|
||||||
|
|
||||||
|
server = cn.get("server")
|
||||||
|
if not server:
|
||||||
|
return {"status": "SKIP", "skipped": True,
|
||||||
|
"reason": "rdma.cross_node.server (peer ssh address) not set"}
|
||||||
|
|
||||||
|
ssh_user = cn.get("ssh_user", "root")
|
||||||
|
server_target = server if "@" in server else f"{ssh_user}@{server}"
|
||||||
|
# OOB address the client's perftest connects to (defaults to the ssh host).
|
||||||
|
server_addr = cn.get("server_addr") or server.split("@")[-1]
|
||||||
|
ib_port = cn.get("ib_port", 1)
|
||||||
|
gid_index = cn.get("gid_index")
|
||||||
|
msg_size = cn.get("msg_size", 1048576)
|
||||||
|
iters = cn.get("iters", 5000)
|
||||||
|
base_port = cn.get("base_oob_port", 18515)
|
||||||
|
warmup = cn.get("server_warmup_sec", 2.0)
|
||||||
|
min_bw = cn.get("min_bandwidth_gbps", 350)
|
||||||
|
max_lat = cn.get("max_latency_us", 5)
|
||||||
|
|
||||||
|
devices = cn.get("devices") or self._active_ib_devices()
|
||||||
|
if not devices:
|
||||||
|
return {"status": "SKIP", "skipped": True,
|
||||||
|
"reason": "no active InfiniBand devices to test"}
|
||||||
|
|
||||||
|
has_bw = self._find_tool("ib_write_bw") is not None
|
||||||
|
has_lat = self._find_tool("ib_write_lat") is not None
|
||||||
|
if not has_bw and not has_lat:
|
||||||
|
return {"status": "SKIP", "skipped": True,
|
||||||
|
"reason": "perftest (ib_write_bw / ib_write_lat) not installed"}
|
||||||
|
|
||||||
|
self.console.print(
|
||||||
|
f"[cyan]Cross-node RDMA — client → {server_addr}, "
|
||||||
|
f"devices: {', '.join(devices)}[/cyan]")
|
||||||
|
|
||||||
|
per_device = []
|
||||||
|
for idx, dev in enumerate(devices):
|
||||||
|
oob = base_port + idx
|
||||||
|
entry = {"device": dev}
|
||||||
|
|
||||||
|
if has_bw:
|
||||||
|
bw = self._cross_node_perftest(
|
||||||
|
"ib_write_bw", dev, server_target, server_addr, ib_port,
|
||||||
|
oob, gid_index, warmup,
|
||||||
|
extra=["--report_gbits", "-s", str(msg_size), "-n", str(iters)],
|
||||||
|
parse="bw")
|
||||||
|
entry["bandwidth_gbps"] = bw
|
||||||
|
if isinstance(bw, (int, float)):
|
||||||
|
entry["bw_status"] = "PASS" if bw >= min_bw else "WARN"
|
||||||
|
else:
|
||||||
|
entry["bw_status"] = "FAIL"
|
||||||
|
|
||||||
|
if has_lat:
|
||||||
|
lat = self._cross_node_perftest(
|
||||||
|
"ib_write_lat", dev, server_target, server_addr, ib_port,
|
||||||
|
oob, gid_index, warmup, extra=[], parse="lat")
|
||||||
|
if isinstance(lat, dict):
|
||||||
|
entry["latency_us"] = lat.get("typical")
|
||||||
|
entry["latency_p99_us"] = lat.get("p99")
|
||||||
|
t = lat.get("typical")
|
||||||
|
entry["lat_status"] = ("PASS" if isinstance(t, (int, float)) and 0 < t <= max_lat
|
||||||
|
else ("WARN" if isinstance(t, (int, float)) else "FAIL"))
|
||||||
|
else:
|
||||||
|
entry["latency_us"] = lat
|
||||||
|
entry["lat_status"] = "FAIL"
|
||||||
|
|
||||||
|
per_device.append(entry)
|
||||||
|
|
||||||
|
statuses = [e.get(k) for e in per_device for k in ("bw_status", "lat_status") if e.get(k)]
|
||||||
|
verdict = "PASS"
|
||||||
|
for s in statuses:
|
||||||
|
if s == "FAIL":
|
||||||
|
verdict = "FAIL"
|
||||||
|
break
|
||||||
|
if s == "WARN" and verdict == "PASS":
|
||||||
|
verdict = "WARN"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": verdict,
|
||||||
|
"server": server_addr,
|
||||||
|
"min_bandwidth_gbps": min_bw,
|
||||||
|
"max_latency_us": max_lat,
|
||||||
|
"per_device": per_device,
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _cross_node_perftest(self, tool: str, dev: str, server_target: str,
|
||||||
|
server_addr: str, ib_port: int, oob_port: int,
|
||||||
|
gid_index, warmup: float, extra: List[str], parse: str):
|
||||||
|
"""Start `tool` server on the peer via SSH, run the local client, parse output.
|
||||||
|
|
||||||
|
Returns a float (bw, Gb/s), a dict {typical, p99} (lat, µs), or an error string.
|
||||||
|
"""
|
||||||
|
tool_path = self._find_tool(tool)
|
||||||
|
if not tool_path:
|
||||||
|
return f"{tool} not installed"
|
||||||
|
|
||||||
|
flags = ["-d", dev, "-i", str(ib_port), "-p", str(oob_port), "-F"]
|
||||||
|
if gid_index is not None:
|
||||||
|
flags += ["-x", str(gid_index)]
|
||||||
|
flags += extra
|
||||||
|
|
||||||
|
server_cmd = " ".join([tool] + flags) # server: no host argument
|
||||||
|
server_proc = None
|
||||||
|
try:
|
||||||
|
server_proc = subprocess.Popen(
|
||||||
|
["ssh", "-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=no",
|
||||||
|
server_target, server_cmd],
|
||||||
|
stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
|
||||||
|
time.sleep(warmup) # let the remote server bind before the client connects
|
||||||
|
|
||||||
|
client = subprocess.run([tool_path] + flags + [server_addr],
|
||||||
|
capture_output=True, text=True, timeout=120)
|
||||||
|
out = client.stdout + "\n" + (client.stderr or "")
|
||||||
|
return self._parse_perftest_lat(out) if parse == "lat" else self._parse_perftest_bw(out)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
return "timeout"
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
return f"error: {e}"
|
||||||
|
finally:
|
||||||
|
if server_proc and server_proc.poll() is None:
|
||||||
|
server_proc.terminate()
|
||||||
|
try:
|
||||||
|
server_proc.wait(timeout=5)
|
||||||
|
except Exception:
|
||||||
|
server_proc.kill()
|
||||||
|
# ib_write_* server normally exits after one run; pkill cleans up a
|
||||||
|
# leftover one if the client failed mid-handshake. -x matches the exact
|
||||||
|
# process name so it never kills this ssh command itself.
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
["ssh", "-o", "BatchMode=yes", server_target, f"pkill -x {tool}"],
|
||||||
|
capture_output=True, timeout=10)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_perftest_bw(output: str) -> float:
|
||||||
|
"""Parse ib_write_bw rows (#bytes #iter BW_peak BW_avg ...); return max BW avg."""
|
||||||
|
best = 0.0
|
||||||
|
for line in output.splitlines():
|
||||||
|
parts = line.split()
|
||||||
|
if len(parts) >= 4:
|
||||||
|
try:
|
||||||
|
int(parts[0]) # #bytes column
|
||||||
|
best = max(best, float(parts[3])) # BW average[Gb/sec]
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
return round(best, 2) if best else 0.0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_perftest_lat(output: str) -> dict:
|
||||||
|
"""Parse ib_write_lat row (#bytes #iter t_min t_max t_typical t_avg ... 99%)."""
|
||||||
|
for line in output.splitlines():
|
||||||
|
parts = line.split()
|
||||||
|
if len(parts) >= 6:
|
||||||
|
try:
|
||||||
|
int(parts[0]); int(parts[1])
|
||||||
|
typical = float(parts[4]) # t_typical[usec]
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
p99 = None
|
||||||
|
if len(parts) >= 8:
|
||||||
|
try:
|
||||||
|
p99 = float(parts[7]) # 99% percentile[usec]
|
||||||
|
except ValueError:
|
||||||
|
p99 = None
|
||||||
|
return {"typical": round(typical, 2), "p99": round(p99, 2) if p99 else None}
|
||||||
|
return {"typical": None, "p99": None}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def print_results(results: dict, console: Console = None):
|
def print_results(results: dict, console: Console = None):
|
||||||
c = console or Console()
|
c = console or Console()
|
||||||
@ -296,3 +495,29 @@ class RDMATest:
|
|||||||
c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
|
c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
|
||||||
f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)" if status != "SKIP"
|
f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)" if status != "SKIP"
|
||||||
else f" {t['test']}: [dim]SKIPPED[/dim]")
|
else f" {t['test']}: [dim]SKIPPED[/dim]")
|
||||||
|
|
||||||
|
cn = results.get("cross_node")
|
||||||
|
if cn:
|
||||||
|
if cn.get("skipped"):
|
||||||
|
c.print(f"\n [bold]Cross-node RDMA[/bold]: [dim]SKIPPED "
|
||||||
|
f"({cn.get('reason', '')})[/dim]")
|
||||||
|
else:
|
||||||
|
v = cn.get("status", "?")
|
||||||
|
vc = "green" if v == "PASS" else ("yellow" if v == "WARN" else "red")
|
||||||
|
c.print(f"\n [bold]Cross-node RDMA[/bold] (server {cn.get('server')}) "
|
||||||
|
f"[{vc}]{v}[/{vc}] "
|
||||||
|
f"[dim]min {cn.get('min_bandwidth_gbps')} Gb/s, "
|
||||||
|
f"max {cn.get('max_latency_us')} µs[/dim]")
|
||||||
|
for e in cn.get("per_device", []):
|
||||||
|
bw = e.get("bandwidth_gbps")
|
||||||
|
lat = e.get("latency_us")
|
||||||
|
bws = e.get("bw_status", "")
|
||||||
|
lts = e.get("lat_status", "")
|
||||||
|
bc = "green" if bws == "PASS" else ("yellow" if bws == "WARN" else "red")
|
||||||
|
lc = "green" if lts == "PASS" else ("yellow" if lts == "WARN" else "red")
|
||||||
|
bw_s = f"{bw:.1f} Gb/s" if isinstance(bw, (int, float)) else str(bw)
|
||||||
|
lat_s = f"{lat:.2f} µs" if isinstance(lat, (int, float)) else str(lat)
|
||||||
|
p99 = e.get("latency_p99_us")
|
||||||
|
p99_s = f", p99 {p99:.2f}" if isinstance(p99, (int, float)) else ""
|
||||||
|
c.print(f" {e['device']}: BW [{bc}]{bw_s}[/{bc}] | "
|
||||||
|
f"lat [{lc}]{lat_s}[/{lc}]{p99_s}")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user