zulifeng dd77a882f1 feat: 跨机 RDMA 并入 rdma_test.py + H800 算力门槛对齐 H100
- modules/rdma_test.py: 新增 SSH 编排的跨机 RDMA(run_cross_node /
  _cross_node_perftest / 解析器),从 client 端逐设备拉起对端 perftest
  server 跑本地 client,替代已删除的 scripts/rdma_cross_node.sh;两机
  4×NDR400 实测全 PASS(~387-392 Gb/s,~2 µs)。
- configs/default.yaml: 新增 rdma.cross_node 配置块(默认 enabled:false)。
- modules/gpu_specs.py: H800 PASS 门槛对齐 H100 实测地板
  (tf32 400->385, bf16 720->730, fp8 1400->1200);H800=H100 硅片,
  PyTorch tensorwise fp8 天花板 ~1310,原 1400 不可达。

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 19:38:43 +08:00

524 lines
22 KiB
Python

"""RDMA / InfiniBand bandwidth and latency test module."""
import os
import shlex
import shutil
import subprocess
import time
from datetime import datetime
from typing import List, Optional

from rich.console import Console
from rich.table import Table
class RDMATest:
    """RDMA / InfiniBand checks built on sysfs and the perftest tools.

    * ``run()`` inventories /sys/class/infiniband, runs loopback
      ib_write_bw / ib_read_bw / ib_write_lat / ib_read_lat on this host and,
      when ``rdma.cross_node.enabled`` is true, the SSH-orchestrated two-host
      test (``run_cross_node()``).
    * ``print_results()`` renders a result dict using rich markup.

    All settings are read from the ``rdma`` section of the config dict.
    """

    def __init__(self, config: dict):
        self.config = config
        self.console = Console()
        # `or {}` also covers an explicit `rdma:` key whose value is null.
        self.rdma_cfg = config.get("rdma", {}) or {}

    # ------------------------------------------------------------------
    # sysfs / tool helpers
    # ------------------------------------------------------------------
    def _find_tool(self, name: str) -> Optional[str]:
        """Return the full path of executable *name* on PATH, or None."""
        return shutil.which(name)

    def _get_ib_devices(self) -> List[str]:
        """Return sorted RDMA device names under /sys/class/infiniband ([] if absent)."""
        ib_path = "/sys/class/infiniband"
        if not os.path.isdir(ib_path):
            return []
        return sorted(os.listdir(ib_path))

    def _get_ib_ports(self, device: str) -> List[str]:
        """Return sorted port entries (strings) of *device* ([] if absent)."""
        ports_dir = f"/sys/class/infiniband/{device}/ports"
        if not os.path.isdir(ports_dir):
            return []
        return sorted(os.listdir(ports_dir))

    @staticmethod
    def _read_sys(path: str) -> str:
        """Return the stripped contents of *path*, or "" if it cannot be read."""
        try:
            with open(path) as f:
                return f.read().strip()
        except OSError:  # FileNotFoundError / PermissionError are subclasses
            return ""

    def _port_attr(self, dev: str, port: str, name: str) -> str:
        """Read /sys/class/infiniband/<dev>/ports/<port>/<name> ("" on error)."""
        return self._read_sys(f"/sys/class/infiniband/{dev}/ports/{port}/{name}")

    @staticmethod
    def _status_color(status: str) -> str:
        """Rich colour used for a PASS / WARN / any-other status string."""
        if status == "PASS":
            return "green"
        if status == "WARN":
            return "yellow"
        return "red"

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------
    def run(self) -> dict:
        """Run the RDMA checks and return a result dict.

        Returns a ``status: "SKIP"`` dict when there is no RDMA device, no
        InfiniBand link-layer port, or no ACTIVE InfiniBand port. Otherwise the
        loopback tests run on the first ACTIVE InfiniBand device (unless
        ``rdma.ib_device`` overrides it). ``passed`` is True only if every
        loopback test is PASS and, when the cross-node test ran (enabled and
        not skipped), its verdict is PASS too.
        """
        devices = self._get_ib_devices()
        if not devices:
            self.console.print(
                "[yellow]No InfiniBand devices found — skipping RDMA test[/yellow]"
            )
            return {
                "status": "SKIP", "skipped": True,
                "reason": "no IB hardware detected",
                "timestamp": datetime.now().isoformat(),
            }
        # Only consider ports whose link_layer is InfiniBand — Ethernet
        # bond/management interfaces (e.g. mlx5_bond_0) can show ACTIVE state
        # without actually providing IB fabric connectivity.
        ib_ports = []
        active_ib_devices: List[str] = []
        for dev in devices:
            for port in self._get_ib_ports(dev):
                if self._port_attr(dev, port, "link_layer") != "InfiniBand":
                    continue
                ib_ports.append((dev, port))
                state = self._port_attr(dev, port, "state")
                if "ACTIVE" in state.upper() and dev not in active_ib_devices:
                    active_ib_devices.append(dev)
        device_info = self._collect_device_info(devices)
        if not ib_ports:
            self.console.print(
                "[yellow]No InfiniBand-link_layer ports present — "
                "skipping RDMA benchmarks[/yellow]"
            )
            return {
                "status": "SKIP", "skipped": True,
                "reason": "no InfiniBand link_layer ports (only Ethernet/RoCE)",
                "devices": device_info,
                "timestamp": datetime.now().isoformat(),
            }
        if not active_ib_devices:
            self.console.print(
                f"[yellow]{len(ib_ports)} IB port(s) detected but all DOWN — "
                f"fabric not wired, skipping RDMA benchmarks[/yellow]"
            )
            return {
                "status": "SKIP", "skipped": True,
                "reason": f"{len(ib_ports)} IB port(s) found but all DOWN (fabric not wired)",
                "devices": device_info,
                "timestamp": datetime.now().isoformat(),
            }
        self.console.print(f"[cyan]RDMA Test - Devices: {', '.join(devices)}[/cyan]")
        # Loopback tests must target an ACTIVE IB device, not merely the first
        # sysfs entry (which may be a bond/Ethernet or DOWN device).
        bw_results = self._run_bandwidth_tests(active_ib_devices)
        latency_results = self._run_latency_tests(active_ib_devices)
        all_passed = all(
            r.get("status") == "PASS"
            for r in bw_results + latency_results
            if isinstance(r, dict)
        )
        result = {
            "passed": all_passed,
            "devices": device_info,
            "bandwidth_tests": bw_results,
            "latency_tests": latency_results,
            "timestamp": datetime.now().isoformat(),
        }
        # Cross-node (two-host) RDMA, run only when a peer is configured.
        if (self.rdma_cfg.get("cross_node", {}) or {}).get("enabled"):
            cross = self.run_cross_node()
            result["cross_node"] = cross
            # An enabled cross-node test that did run must also pass.
            if not cross.get("skipped") and cross.get("status") != "PASS":
                result["passed"] = False
        return result

    def _collect_device_info(self, devices: List[str]) -> List[dict]:
        """Return per-device sysfs port info (rate/state/phys_state/gid 0).

        Unreadable attributes are reported as "N/A".
        """
        attrs = (("rate", "rate"), ("state", "state"),
                 ("phys_state", "phys_state"), ("gid", "gids/0"))
        info = []
        for dev in devices:
            ports = []
            for port in self._get_ib_ports(dev):
                port_info = {"port": port}
                for label, rel in attrs:
                    path = f"/sys/class/infiniband/{dev}/ports/{port}/{rel}"
                    try:
                        with open(path) as f:
                            port_info[label] = f.read().strip()
                    except OSError:
                        port_info[label] = "N/A"
                ports.append(port_info)
            info.append({"name": dev, "ports": ports})
        return info

    def _run_ib_command(self, cmd: List[str], timeout: int = 60) -> dict:
        """Run *cmd*; map exit code / timeout / missing binary to a status dict."""
        try:
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
            if r.returncode == 0:
                return {"status": "PASS", "output": r.stdout.strip()}
            return {"status": "FAIL", "error": r.stderr.strip()[:200]}
        except subprocess.TimeoutExpired:
            return {"status": "FAIL", "error": "timeout"}
        except FileNotFoundError:
            return {"status": "SKIP", "error": "tool not found"}
        except Exception as e:
            return {"status": "FAIL", "error": str(e)}

    # ------------------------------------------------------------------
    # Local (loopback) perftest
    # ------------------------------------------------------------------
    def _run_local_perftest(self, server_cmd: List[str], client_cmd: List[str]) -> str:
        """Start *server_cmd* in the background, run *client_cmd*, return both stdouts.

        Raises subprocess.TimeoutExpired if the client (60 s) or the server
        (10 s after the client returns) hangs. The server process is always
        reaped, including on error.
        """
        server = subprocess.Popen(server_cmd, stdout=subprocess.PIPE,
                                  stderr=subprocess.PIPE, text=True)
        try:
            time.sleep(1)  # let the server open its OOB socket before connecting
            client = subprocess.run(client_cmd, capture_output=True, text=True, timeout=60)
            # communicate() drains both pipes while waiting, so the server can
            # never block on a full pipe the way a bare wait() could.
            server_out, _ = server.communicate(timeout=10)
            return client.stdout + (server_out or "")
        finally:
            if server.poll() is None:
                server.kill()
                server.communicate()

    @classmethod
    def _local_bw_result(cls, label: str, output: str, min_bw) -> dict:
        """Build a loopback bandwidth result dict from raw perftest *output*.

        perftest is run without --report_gbits here, so the parsed
        "BW average" column is MB/sec; it is divided by 1000 and reported
        (and printed) as GB/s against *min_bw*.
        """
        bw_gbps = cls._parse_perftest_bw(output) / 1000
        result = {
            "test": label,
            "bandwidth_gbps": round(bw_gbps, 2),
            "min_required_gbps": min_bw,
        }
        if bw_gbps <= 0:
            result["status"] = "FAIL"
            result["error"] = "no bandwidth row found in perftest output"
        else:
            result["status"] = "PASS" if bw_gbps >= min_bw else "WARN"
        return result

    @classmethod
    def _local_lat_result(cls, label: str, output: str, max_lat_us) -> dict:
        """Build a loopback latency result dict (t_typical, µs) from *output*."""
        parsed = cls._parse_perftest_lat(output)
        lat_us = parsed["typical"] or 0
        if 0 < lat_us <= max_lat_us:
            status = "PASS"
        elif lat_us > 0:
            status = "WARN"
        else:
            status = "FAIL"
        result = {
            "test": label,
            "status": status,
            "latency_us": round(lat_us, 2),
            "max_allowed_us": max_lat_us,
            "latency_p99_us": parsed["p99"],
        }
        if status == "FAIL":
            result["error"] = "no latency row found in perftest output"
        return result

    def _run_bandwidth_tests(self, devices: List[str]) -> List[dict]:
        """Run loopback ib_write_bw / ib_read_bw on ``rdma.ib_device`` or devices[0]."""
        min_bw = self.rdma_cfg.get("min_bandwidth_gbps", 50)
        msg_size = self.rdma_cfg.get("msg_size", 65536)
        iters = self.rdma_cfg.get("ib_iterations", 1000)
        device = self.rdma_cfg.get("ib_device", None) or (devices[0] if devices else None)
        port = self.rdma_cfg.get("ib_port", 1)
        results = []
        for label in ("ib_write_bw", "ib_read_bw"):
            tool = self._find_tool(label)
            if not tool:
                results.append({"test": label, "status": "SKIP", "error": "not installed"})
                continue
            if not device:
                results.append({"test": label, "status": "SKIP", "error": "no RDMA device"})
                continue
            server_cmd = [tool, "-d", device, "-i", str(port),
                          "-s", str(msg_size), "-n", str(iters)]
            try:
                output = self._run_local_perftest(server_cmd, server_cmd + ["localhost"])
            except Exception as e:  # noqa: BLE001 - any launch/run failure is a FAIL
                results.append({"test": label, "status": "FAIL", "error": str(e)})
                continue
            results.append(self._local_bw_result(label, output, min_bw))
        return results

    def _run_latency_tests(self, devices: List[str]) -> List[dict]:
        """Run loopback ib_write_lat / ib_read_lat on ``rdma.ib_device`` or devices[0]."""
        max_lat_us = self.rdma_cfg.get("max_latency_us", 10)
        device = self.rdma_cfg.get("ib_device", None) or (devices[0] if devices else None)
        port = self.rdma_cfg.get("ib_port", 1)
        results = []
        for label in ("ib_write_lat", "ib_read_lat"):
            tool = self._find_tool(label)
            if not tool:
                results.append({"test": label, "status": "SKIP", "error": "not installed"})
                continue
            if not device:
                results.append({"test": label, "status": "SKIP", "error": "no RDMA device"})
                continue
            server_cmd = [tool, "-d", device, "-i", str(port)]
            try:
                output = self._run_local_perftest(server_cmd, server_cmd + ["localhost"])
            except Exception as e:  # noqa: BLE001 - any launch/run failure is a FAIL
                results.append({"test": label, "status": "FAIL", "error": str(e)})
                continue
            results.append(self._local_lat_result(label, output, max_lat_us))
        return results

    # ------------------------------------------------------------------
    # Cross-node (two-host) RDMA over perftest, orchestrated via SSH.
    # Runs FROM the client host: for each IB device it launches the matching
    # perftest server on the peer over SSH (held open in a live ssh channel),
    # then runs the local client against the peer's OOB address and parses the
    # result. Replaces the old standalone scripts/rdma_cross_node.sh.
    # ------------------------------------------------------------------
    def _active_ib_devices(self) -> List[str]:
        """Local devices with at least one InfiniBand link_layer port in ACTIVE state."""
        out = []
        for dev in self._get_ib_devices():
            for port in self._get_ib_ports(dev):
                ll = self._port_attr(dev, port, "link_layer")
                st = self._port_attr(dev, port, "state")
                if ll == "InfiniBand" and "ACTIVE" in st.upper():
                    out.append(dev)
                    break
        return out

    def run_cross_node(self) -> dict:
        """Run ib_write_bw / ib_write_lat between this host and ``rdma.cross_node.server``.

        Returns a SKIP dict when disabled or misconfigured. Otherwise returns
        a verdict (FAIL if any check failed, else WARN if any warned, else PASS)
        plus per-device results.
        """
        cn = self.rdma_cfg.get("cross_node", {}) or {}
        if not cn.get("enabled"):
            return {"status": "SKIP", "skipped": True,
                    "reason": "rdma.cross_node.enabled is false"}
        server = cn.get("server")
        if not server:
            return {"status": "SKIP", "skipped": True,
                    "reason": "rdma.cross_node.server (peer ssh address) not set"}
        ssh_user = cn.get("ssh_user", "root")
        server_target = server if "@" in server else f"{ssh_user}@{server}"
        # OOB address the client's perftest connects to (defaults to the ssh host).
        server_addr = cn.get("server_addr") or server.split("@")[-1]
        ib_port = cn.get("ib_port", 1)
        gid_index = cn.get("gid_index")
        msg_size = cn.get("msg_size", 1048576)
        iters = cn.get("iters", 5000)
        base_port = cn.get("base_oob_port", 18515)
        warmup = cn.get("server_warmup_sec", 2.0)
        min_bw = cn.get("min_bandwidth_gbps", 350)
        max_lat = cn.get("max_latency_us", 5)
        devices = cn.get("devices") or self._active_ib_devices()
        if isinstance(devices, str):
            # Accept "mlx5_0,mlx5_1" / "mlx5_0 mlx5_1" instead of iterating chars.
            devices = devices.replace(",", " ").split()
        if not devices:
            return {"status": "SKIP", "skipped": True,
                    "reason": "no active InfiniBand devices to test"}
        has_bw = self._find_tool("ib_write_bw") is not None
        has_lat = self._find_tool("ib_write_lat") is not None
        if not has_bw and not has_lat:
            return {"status": "SKIP", "skipped": True,
                    "reason": "perftest (ib_write_bw / ib_write_lat) not installed"}
        self.console.print(
            f"[cyan]Cross-node RDMA — client → {server_addr}, "
            f"devices: {', '.join(devices)}[/cyan]")
        per_device = []
        for idx, dev in enumerate(devices):
            oob = base_port + idx  # distinct OOB TCP port per device
            entry = {"device": dev}
            if has_bw:
                bw = self._cross_node_perftest(
                    "ib_write_bw", dev, server_target, server_addr, ib_port,
                    oob, gid_index, warmup,
                    extra=["--report_gbits", "-s", str(msg_size), "-n", str(iters)],
                    parse="bw")
                entry["bandwidth_gbps"] = bw
                if isinstance(bw, (int, float)) and bw > 0:
                    entry["bw_status"] = "PASS" if bw >= min_bw else "WARN"
                else:
                    # Error string, or 0.0 when no result row was parsed.
                    entry["bw_status"] = "FAIL"
            if has_lat:
                lat = self._cross_node_perftest(
                    "ib_write_lat", dev, server_target, server_addr, ib_port,
                    oob, gid_index, warmup, extra=[], parse="lat")
                if isinstance(lat, dict):
                    t = lat.get("typical")
                    entry["latency_us"] = t
                    entry["latency_p99_us"] = lat.get("p99")
                    is_num = isinstance(t, (int, float))
                    if is_num and 0 < t <= max_lat:
                        entry["lat_status"] = "PASS"
                    elif is_num:
                        entry["lat_status"] = "WARN"
                    else:
                        entry["lat_status"] = "FAIL"
                else:
                    entry["latency_us"] = lat
                    entry["lat_status"] = "FAIL"
            per_device.append(entry)
        statuses = {e[k] for e in per_device for k in ("bw_status", "lat_status") if e.get(k)}
        if "FAIL" in statuses:
            verdict = "FAIL"
        elif "WARN" in statuses:
            verdict = "WARN"
        else:
            verdict = "PASS"
        return {
            "status": verdict,
            "server": server_addr,
            "min_bandwidth_gbps": min_bw,
            "max_latency_us": max_lat,
            "per_device": per_device,
            "timestamp": datetime.now().isoformat(),
        }

    def _cross_node_perftest(self, tool: str, dev: str, server_target: str,
                             server_addr: str, ib_port: int, oob_port: int,
                             gid_index, warmup: float, extra: List[str], parse: str):
        """Start `tool` server on the peer via SSH, run the local client, parse output.

        Returns a float (bw, Gb/s), a dict {typical, p99} (lat, µs), or an error
        string. The error string is returned on timeout, on an exception, or
        when the client exits non-zero without producing a result row.
        """
        tool_path = self._find_tool(tool)
        if not tool_path:
            return f"{tool} not installed"
        flags = ["-d", dev, "-i", str(ib_port), "-p", str(oob_port), "-F"]
        if gid_index is not None:
            flags += ["-x", str(gid_index)]
        flags += extra
        # The remote side runs this through a shell, so quote every token.
        server_cmd = " ".join(shlex.quote(str(a)) for a in [tool] + flags)  # no host arg
        ssh_opts = ["-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=no"]
        server_proc = None
        try:
            # stdin=DEVNULL (like `ssh -n`) keeps ssh from reading our terminal;
            # the server's own output is not used, so discard it.
            server_proc = subprocess.Popen(
                ["ssh", *ssh_opts, server_target, server_cmd],
                stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL)
            time.sleep(warmup)  # let the remote server bind before the client connects
            client = subprocess.run([tool_path] + flags + [server_addr],
                                    capture_output=True, text=True, timeout=120)
            out = client.stdout + "\n" + (client.stderr or "")
            if parse == "lat":
                parsed = self._parse_perftest_lat(out)
                got_result = parsed.get("typical") is not None
            else:
                parsed = self._parse_perftest_bw(out)
                got_result = parsed > 0
            if not got_result and client.returncode != 0:
                tail = (client.stderr or client.stdout or "").strip()[-200:]
                return f"client exited {client.returncode}: {tail}"
            return parsed
        except subprocess.TimeoutExpired:
            return "timeout"
        except Exception as e:  # noqa: BLE001
            return f"error: {e}"
        finally:
            if server_proc is not None and server_proc.poll() is None:
                server_proc.terminate()
                try:
                    server_proc.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    server_proc.kill()
                    server_proc.wait()
            # ib_write_* server normally exits after one run; pkill cleans up a
            # leftover one if the client failed mid-handshake. -x matches the exact
            # process name so it never kills this ssh command itself.
            try:
                subprocess.run(
                    ["ssh", *ssh_opts, server_target, f"pkill -x {shlex.quote(tool)}"],
                    stdin=subprocess.DEVNULL, capture_output=True, timeout=10)
            except Exception:  # noqa: BLE001 - best-effort cleanup only
                pass

    # ------------------------------------------------------------------
    # perftest output parsers
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_perftest_bw(output: str) -> float:
        """Parse ib_*_bw rows (#bytes #iter BW_peak BW_avg ...); return max BW avg.

        The unit is whatever perftest printed (MB/sec by default, Gb/sec with
        --report_gbits). Returns 0.0 when no result row is found.
        """
        best = 0.0
        for line in output.splitlines():
            parts = line.split()
            if len(parts) < 4:
                continue
            try:
                int(parts[0])  # only result rows start with the integer #bytes
                best = max(best, float(parts[3]))  # BW average column
            except ValueError:
                continue
        return round(best, 2) if best else 0.0

    @staticmethod
    def _parse_perftest_lat(output: str) -> dict:
        """Parse the first ib_*_lat row (#bytes #iter t_min t_max t_typical t_avg ... 99%).

        Returns {"typical": µs or None, "p99": µs or None}.
        """
        for line in output.splitlines():
            parts = line.split()
            if len(parts) < 6:
                continue
            try:
                int(parts[0])
                int(parts[1])
                typical = float(parts[4])  # t_typical[usec]
            except ValueError:
                continue
            p99 = None
            if len(parts) >= 8:
                try:
                    p99 = float(parts[7])  # 99% percentile[usec]
                except ValueError:
                    p99 = None
            return {"typical": round(typical, 2),
                    "p99": round(p99, 2) if p99 is not None else None}
        return {"typical": None, "p99": None}

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------
    @staticmethod
    def print_results(results: dict, console: "Optional[Console]" = None):
        """Render a ``run()`` result dict to *console* (a new Console if None)."""
        c = console or Console()
        if results.get("skipped") or results.get("status") == "SKIP":
            c.print(f"\n[bold yellow]RDMA/InfiniBand: SKIPPED[/bold yellow] "
                    f"[dim]({results.get('reason', 'no IB hardware')})[/dim]")
            return
        if "error" in results:
            c.print(f"[bold red]Error: {results['error']}[/bold red]")
            return
        devices = results.get("devices", [])
        c.print("\n[bold cyan]RDMA/InfiniBand Test Results[/bold cyan]")
        c.print(f" Devices found: {len(devices)}")
        for dev in devices:
            c.print(f"\n [bold]{dev['name']}[/bold]")
            for p in dev.get("ports", []):
                state = p.get("state", "N/A")
                # sysfs reports e.g. "4: ACTIVE", so compare case-insensitively.
                color = "green" if "ACTIVE" in state.upper() else "red"
                c.print(f" Port {p['port']}: [{color}]{state}[/{color}] | "
                        f"Rate: {p.get('rate', 'N/A')} | GID: {p.get('gid', 'N/A')[:20]}")
        bw_tests = results.get("bandwidth_tests", [])
        if bw_tests:
            c.print("\n [bold]Bandwidth Tests[/bold]")
            for t in bw_tests:
                status = t.get("status", "SKIP")
                if status == "SKIP":
                    c.print(f" {t['test']}: [dim]SKIPPED[/dim]")
                    continue
                sc = RDMATest._status_color(status)
                bw = t.get("bandwidth_gbps", 0)
                c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
                        f"({bw:.2f} GB/s, min: {t.get('min_required_gbps', 'N/A')} GB/s)")
        lat_tests = results.get("latency_tests", [])
        if lat_tests:
            c.print("\n [bold]Latency Tests[/bold]")
            for t in lat_tests:
                status = t.get("status", "SKIP")
                if status == "SKIP":
                    c.print(f" {t['test']}: [dim]SKIPPED[/dim]")
                    continue
                sc = RDMATest._status_color(status)
                lat = t.get("latency_us", 0)
                c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
                        f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)")
        cn = results.get("cross_node")
        if not cn:
            return
        if cn.get("skipped"):
            c.print(f"\n [bold]Cross-node RDMA[/bold]: [dim]SKIPPED "
                    f"({cn.get('reason', '')})[/dim]")
            return
        v = cn.get("status", "?")
        vc = RDMATest._status_color(v)
        c.print(f"\n [bold]Cross-node RDMA[/bold] (server {cn.get('server')}) "
                f"[{vc}]{v}[/{vc}] "
                f"[dim]min {cn.get('min_bandwidth_gbps')} Gb/s, "
                f"max {cn.get('max_latency_us')} µs[/dim]")
        for e in cn.get("per_device", []):
            bw = e.get("bandwidth_gbps")
            lat = e.get("latency_us")
            bc = RDMATest._status_color(e.get("bw_status", ""))
            lc = RDMATest._status_color(e.get("lat_status", ""))
            bw_s = f"{bw:.1f} Gb/s" if isinstance(bw, (int, float)) else str(bw)
            lat_s = f"{lat:.2f} µs" if isinstance(lat, (int, float)) else str(lat)
            p99 = e.get("latency_p99_us")
            p99_s = f", p99 {p99:.2f}" if isinstance(p99, (int, float)) else ""
            c.print(f" {e['device']}: BW [{bc}]{bw_s}[/{bc}] | "
                    f"lat [{lc}]{lat_s}[/{lc}]{p99_s}")