- gpu_specs: 新增 H20/H20-3e (中国合规版 H200) 规格定义,并修复 GPU 名称匹配顺序,避免 "H200" 被 "H20" 子串误匹配 - benchmark(compute): 引入 L2 cache 规避的 matrix pool 轮换 + 可选 torch.compile(max-autotune),FP8 增加 _scaled_mm 探测, 显著提升 FP16/BF16/FP8 实测吞吐准确性 - benchmark(memory): nvbandwidth 增加 --disableAffinity 规避 fabricmanager NVML 不兼容;全 0 结果时自动回退到 PyTorch; D2D 平均值排除对角线零值 - nccl: 各通信操作 (AllReduce/AllToAll/Broadcast 等) 使用独立带宽阈值比例,避免 AllToAll 误报 WARN - rdma: 仅按 link_layer=InfiniBand 过滤端口,无 IB 硬件或全 DOWN 时直接 SKIP 而非报错 - stress: 计算矩阵尺寸封顶 4096,并改为先并发派发再统一同步, 修复 8 卡串行执行导致 duration 严重超时的问题 - report: 兼容 RDMA SKIP 状态与 PyTorch 回退场景的 Memory 判定, 避免回退结果被误判为 FAIL - config: 新增 benchmark.compute.use_compile 开关 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
299 lines
12 KiB
Python
299 lines
12 KiB
Python
"""RDMA / InfiniBand bandwidth and latency test module."""
|
|
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
from datetime import datetime
|
|
from typing import Optional, List
|
|
|
|
from rich.console import Console
|
|
from rich.table import Table
|
|
|
|
|
|
class RDMATest:
    """RDMA / InfiniBand health check built on sysfs and the perftest tools.

    Settings come from ``config["rdma"]``: ``min_bandwidth_gbps`` (default 50),
    ``msg_size`` (65536), ``ib_iterations`` (1000), ``ib_device`` (default: the
    first device that has an ACTIVE InfiniBand port), ``ib_port`` (1) and
    ``max_latency_us`` (10).
    """

    # sysfs root that exposes one directory per RDMA device.
    _IB_SYSFS = "/sys/class/infiniband"

    def __init__(self, config: dict):
        self.config = config
        self.console = Console()
        # `or {}` also covers an `rdma:` key that is present but empty (None).
        self.rdma_cfg = config.get("rdma") or {}

    def _find_tool(self, name: str) -> Optional[str]:
        """Return the full path of executable *name* found on PATH, else None."""
        return shutil.which(name)

    def _get_ib_devices(self) -> List[str]:
        """Return the sorted RDMA device names from sysfs, or [] if none exist."""
        if not os.path.isdir(self._IB_SYSFS):
            return []
        return sorted(os.listdir(self._IB_SYSFS))

    def _get_ib_ports(self, device: str) -> List[str]:
        """Return the sorted port numbers (as strings) of *device*, or []."""
        ports_dir = f"{self._IB_SYSFS}/{device}/ports"
        if not os.path.isdir(ports_dir):
            return []
        return sorted(os.listdir(ports_dir))

    def _port_file(self, device: str, port: str, attr: str) -> str:
        """Return the sysfs path of attribute *attr* for *device* / *port*."""
        return f"{self._IB_SYSFS}/{device}/ports/{port}/{attr}"

    @staticmethod
    def _read_sys(path: str) -> str:
        """Return the stripped content of *path*, or "" if it cannot be read."""
        try:
            with open(path) as f:
                return f.read().strip()
        except OSError:  # covers FileNotFoundError and PermissionError
            return ""

    def run(self) -> dict:
        """Detect InfiniBand hardware and run the RDMA bandwidth/latency tests.

        Returns a SKIP result (``status``, ``skipped``, ``reason``, ``timestamp``
        and, once devices are known, ``devices``) when there is no RDMA device,
        no port with ``link_layer == InfiniBand``, or no such port in ACTIVE
        state.  Otherwise returns ``passed`` (True only if every test entry is
        PASS), ``devices``, ``bandwidth_tests``, ``latency_tests`` and
        ``timestamp``.
        """
        devices = self._get_ib_devices()
        if not devices:
            self.console.print(
                "[yellow]No InfiniBand devices found — skipping RDMA test[/yellow]"
            )
            return {
                "status": "SKIP", "skipped": True,
                "reason": "no IB hardware detected",
                "timestamp": datetime.now().isoformat(),
            }

        # Only consider ports whose link_layer is InfiniBand — Ethernet
        # bond/management interfaces (e.g. mlx5_bond_0) can show ACTIVE state
        # without actually providing IB fabric connectivity.
        ib_ports = []
        active_ib_devices = []  # devices owning >= 1 ACTIVE IB port, sorted order
        for dev in devices:
            for port in self._get_ib_ports(dev):
                link_layer = self._read_sys(self._port_file(dev, port, "link_layer"))
                if link_layer != "InfiniBand":
                    continue
                ib_ports.append((dev, port))
                state = self._read_sys(self._port_file(dev, port, "state"))
                if "ACTIVE" in state.upper() and dev not in active_ib_devices:
                    active_ib_devices.append(dev)

        device_info = self._collect_device_info(devices)
        if not ib_ports:
            self.console.print(
                "[yellow]No InfiniBand-link_layer ports present — "
                "skipping RDMA benchmarks[/yellow]"
            )
            return {
                "status": "SKIP", "skipped": True,
                "reason": "no InfiniBand link_layer ports (only Ethernet/RoCE)",
                "devices": device_info,
                "timestamp": datetime.now().isoformat(),
            }
        if not active_ib_devices:
            self.console.print(
                f"[yellow]{len(ib_ports)} IB port(s) detected but all DOWN — "
                "fabric not wired, skipping RDMA benchmarks[/yellow]"
            )
            return {
                "status": "SKIP", "skipped": True,
                "reason": f"{len(ib_ports)} IB port(s) found but all DOWN (fabric not wired)",
                "devices": device_info,
                "timestamp": datetime.now().isoformat(),
            }

        self.console.print(f"[cyan]RDMA Test - Devices: {', '.join(devices)}[/cyan]")

        # The tests default to the first list entry, so hand them only devices
        # with an ACTIVE IB port; otherwise an Ethernet/RoCE device or one whose
        # ports are all DOWN could be picked just because it sorts first.
        bw_results = self._run_bandwidth_tests(active_ib_devices)
        latency_results = self._run_latency_tests(active_ib_devices)

        all_passed = all(
            r.get("status") == "PASS"
            for r in bw_results + latency_results
            if isinstance(r, dict)
        )

        return {
            "passed": all_passed,
            "devices": device_info,
            "bandwidth_tests": bw_results,
            "latency_tests": latency_results,
            "timestamp": datetime.now().isoformat(),
        }

    def _collect_device_info(self, devices: List[str]) -> List[dict]:
        """Describe every port of *devices*: rate, state, phys_state, first GID.

        Attributes that cannot be read are reported as "N/A".  Any ``OSError``
        is caught (not only missing-file/permission errors) because other errno
        values can come back from sysfs reads, and one unreadable attribute
        must not abort the whole RDMA check.
        """
        info = []
        for dev in devices:
            ports = []
            for port in self._get_ib_ports(dev):
                port_info = {"port": port}
                for label, attr in (("rate", "rate"), ("state", "state"),
                                    ("phys_state", "phys_state"), ("gid", "gids/0")):
                    try:
                        with open(self._port_file(dev, port, attr)) as f:
                            port_info[label] = f.read().strip()
                    except OSError:
                        port_info[label] = "N/A"
                ports.append(port_info)
            info.append({"name": dev, "ports": ports})
        return info

    def _run_ib_command(self, cmd: List[str], timeout: int = 60) -> dict:
        """Run *cmd* once and map the outcome to a PASS/FAIL/SKIP result dict.

        PASS carries the stripped stdout; FAIL carries up to 200 characters of
        stderr, "timeout", or the exception text; SKIP means the executable was
        not found.
        """
        try:
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
            if r.returncode == 0:
                return {"status": "PASS", "output": r.stdout.strip()}
            return {"status": "FAIL", "error": r.stderr.strip()[:200]}
        except subprocess.TimeoutExpired:
            return {"status": "FAIL", "error": "timeout"}
        except FileNotFoundError:
            return {"status": "SKIP", "error": "tool not found"}
        except Exception as e:
            return {"status": "FAIL", "error": str(e)}

    def _run_perftest_pair(self, cmd: List[str], timeout: int = 60) -> tuple:
        """Run a perftest server from *cmd* and a loopback client against it.

        The client command is ``cmd + ["localhost"]``.  Returns
        ``(client_stdout + server_stdout, client_stderr[:200])``.
        ``subprocess.TimeoutExpired`` / ``OSError`` propagate to the caller; on
        any failure the server is killed and reaped so no process is leaked.
        """
        import time

        server = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                  stderr=subprocess.PIPE, text=True)
        try:
            time.sleep(1)  # give the server time to open its listening socket
            client = subprocess.run(cmd + ["localhost"], capture_output=True,
                                    text=True, timeout=timeout)
            # communicate() drains the pipes while waiting, unlike wait().
            server_out, _ = server.communicate(timeout=10)
        except BaseException:
            server.kill()
            server.communicate()  # reap the process and close its pipes
            raise
        return client.stdout + (server_out or ""), client.stderr.strip()[:200]

    @staticmethod
    def _parse_perftest_column(output: str, prefixes: tuple) -> float:
        """Return the largest value in one column of a perftest result table.

        perftest prints a header row starting with ``#bytes`` followed by data
        rows starting with the message size, e.g.::

            #bytes #iterations BW peak[MB/sec] BW average[MB/sec] MsgRate[Mpps]
            65536  1000        11782.43        11781.47           0.188504

        The column is the first header token starting with one of *prefixes*
        (tried in order).  Bare ``BW`` tokens are dropped from the header so
        that "BW average[MB/sec]" occupies one slot, as it does in data rows.
        Returns 0.0 if no matching header and data row are found.
        """
        column = None
        best = 0.0
        for line in output.splitlines():
            tokens = line.split()
            if not tokens:
                continue
            if tokens[0] == "#bytes":
                names = [t for t in tokens if t != "BW"]
                column = next(
                    (i for prefix in prefixes
                     for i, name in enumerate(names) if name.startswith(prefix)),
                    None,
                )
                continue
            if column is None or column >= len(tokens):
                continue
            try:
                int(tokens[0])  # data rows start with the message size in bytes
                best = max(best, float(tokens[column]))
            except ValueError:
                continue  # separators, warnings and other non-data lines
        return best

    def _run_bandwidth_tests(self, devices: List[str]) -> List[dict]:
        """Run ib_write_bw and ib_read_bw as loopback server/client pairs.

        The device is ``rdma.ib_device`` or else ``devices[0]``.
        ``bandwidth_gbps`` is perftest's "BW average[MB/sec]" divided by 1000
        (printed as GB/s).  Status: PASS if >= ``min_bandwidth_gbps``, WARN if
        lower, FAIL if the run errored or no result row was found, SKIP if the
        tool is not installed.
        """
        results = []
        min_bw = self.rdma_cfg.get("min_bandwidth_gbps", 50)
        msg_size = self.rdma_cfg.get("msg_size", 65536)
        iters = self.rdma_cfg.get("ib_iterations", 1000)
        device = self.rdma_cfg.get("ib_device") or (devices[0] if devices else None)
        port = self.rdma_cfg.get("ib_port", 1)

        for label in ("ib_write_bw", "ib_read_bw"):
            tool = self._find_tool(label)
            if not tool:
                results.append({"test": label, "status": "SKIP", "error": "not installed"})
                continue

            failed = {"test": label, "status": "FAIL",
                      "bandwidth_gbps": 0.0, "min_required_gbps": min_bw}
            if device is None:
                results.append({**failed, "error": "no IB device available"})
                continue

            # -n passes the configured ib_iterations to perftest.
            cmd = [tool, "-d", device, "-i", str(port),
                   "-s", str(msg_size), "-n", str(iters)]
            try:
                output, stderr = self._run_perftest_pair(cmd)
            except (OSError, subprocess.SubprocessError) as e:
                results.append({**failed, "error": str(e)})
                continue

            bw_gbps = self._parse_perftest_column(output, ("average[",)) / 1000
            if bw_gbps <= 0:
                results.append({**failed,
                                "error": stderr or "no bandwidth result in perftest output"})
                continue

            results.append({
                "test": label,
                "status": "PASS" if bw_gbps >= min_bw else "WARN",
                "bandwidth_gbps": round(bw_gbps, 2),
                "min_required_gbps": min_bw,
            })

        return results

    def _run_latency_tests(self, devices: List[str]) -> List[dict]:
        """Run ib_write_lat and ib_read_lat as loopback server/client pairs.

        The device is ``rdma.ib_device`` or else ``devices[0]``.
        ``latency_us`` is perftest's t_avg column (t_typical on versions without
        t_avg).  Status: PASS if 0 < latency <= ``max_latency_us``, WARN if
        higher, FAIL if the run errored or no result row was found, SKIP if the
        tool is not installed.
        """
        results = []
        max_lat_us = self.rdma_cfg.get("max_latency_us", 10)
        device = self.rdma_cfg.get("ib_device") or (devices[0] if devices else None)
        port = self.rdma_cfg.get("ib_port", 1)

        for label in ("ib_write_lat", "ib_read_lat"):
            tool = self._find_tool(label)
            if not tool:
                results.append({"test": label, "status": "SKIP", "error": "not installed"})
                continue
            if device is None:
                results.append({"test": label, "status": "FAIL",
                                "error": "no IB device available"})
                continue

            cmd = [tool, "-d", device, "-i", str(port)]
            try:
                output, stderr = self._run_perftest_pair(cmd)
            except (OSError, subprocess.SubprocessError) as e:
                results.append({"test": label, "status": "FAIL", "error": str(e)})
                continue

            lat_us = self._parse_perftest_column(output, ("t_avg[", "t_typical["))
            if 0 < lat_us <= max_lat_us:
                status = "PASS"
            elif lat_us > 0:
                status = "WARN"
            else:
                status = "FAIL"
            entry = {
                "test": label,
                "status": status,
                "latency_us": round(lat_us, 2),
                "max_allowed_us": max_lat_us,
            }
            if status == "FAIL":
                entry["error"] = stderr or "no latency result in perftest output"
            results.append(entry)

        return results

    @staticmethod
    def print_results(results: dict, console: "Optional[Console]" = None):
        """Pretty-print a result dict produced by :meth:`run`.

        Uses *console* if given, otherwise a new ``rich`` Console.
        """
        c = console or Console()
        if results.get("skipped") or results.get("status") == "SKIP":
            c.print(f"\n[bold yellow]RDMA/InfiniBand: SKIPPED[/bold yellow] "
                    f"[dim]({results.get('reason', 'no IB hardware')})[/dim]")
            return
        if "error" in results:
            c.print(f"[bold red]Error: {results['error']}[/bold red]")
            return

        devices = results.get("devices", [])
        c.print("\n[bold cyan]RDMA/InfiniBand Test Results[/bold cyan]")
        c.print(f" Devices found: {len(devices)}")

        for dev in devices:
            c.print(f"\n [bold]{dev['name']}[/bold]")
            for p in dev.get("ports", []):
                state = p.get("state", "N/A")
                # sysfs reports the state upper-case (e.g. "4: ACTIVE").
                color = "green" if "ACTIVE" in state.upper() else "red"
                c.print(f" Port {p['port']}: [{color}]{state}[/{color}] | "
                        f"Rate: {p.get('rate', 'N/A')} | GID: {p.get('gid', 'N/A')[:20]}")

        bw_tests = results.get("bandwidth_tests", [])
        if bw_tests:
            c.print("\n [bold]Bandwidth Tests[/bold]")
            for t in bw_tests:
                status = t.get("status", "SKIP")
                sc = "green" if status == "PASS" else ("yellow" if status == "WARN" else "red")
                bw = t.get("bandwidth_gbps", 0)
                c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
                        f"({bw:.2f} GB/s, min: {t.get('min_required_gbps', 'N/A')} GB/s)" if status != "SKIP"
                        else f" {t['test']}: [dim]SKIPPED[/dim]")

        lat_tests = results.get("latency_tests", [])
        if lat_tests:
            c.print("\n [bold]Latency Tests[/bold]")
            for t in lat_tests:
                status = t.get("status", "SKIP")
                sc = "green" if status == "PASS" else ("yellow" if status == "WARN" else "red")
                lat = t.get("latency_us", 0)
                c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
                        f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)" if status != "SKIP"
                        else f" {t['test']}: [dim]SKIPPED[/dim]")
|