test_gpu_scripts/modules/rdma_test.py

606 lines
26 KiB
Python

"""RDMA / InfiniBand bandwidth and latency test module."""
import glob
import os
import shutil
import subprocess
from datetime import datetime
from typing import Optional, List
from rich.console import Console
from rich.table import Table
from modules.gpu_specs import resolve_tools_dir
class RDMATest:
def __init__(self, config: dict):
    """Store the configuration and prepare the helpers used by the tests.

    Args:
        config: Full test-suite configuration. Its optional ``"rdma"``
            entry holds the RDMA-specific settings read by the other
            methods of this class.
    """
    self.config = config
    self.console = Console()
    # ``or {}`` also covers an explicit ``rdma: null`` entry, which would
    # otherwise leave ``rdma_cfg`` as None and break every ``.get()`` call.
    self.rdma_cfg = config.get("rdma") or {}
    self.tools_dir = resolve_tools_dir(config)
def _find_tool(self, name: str) -> Optional[str]:
p = shutil.which(name)
if p:
return p
candidates = [
os.path.join(self.tools_dir, "perftest", name),
os.path.join(self.tools_dir, "perftest", "bin", name),
os.path.join(self.tools_dir, "rdma", name),
os.path.join(self.tools_dir, name),
]
for path in candidates:
if os.path.isfile(path) and os.access(path, os.X_OK):
return path
for path in glob.glob(os.path.join(self.tools_dir, "**", name), recursive=True):
if os.path.isfile(path) and os.access(path, os.X_OK):
return path
return None
def _get_ib_devices(self) -> List[str]:
devices = []
ib_path = "/sys/class/infiniband"
if os.path.isdir(ib_path):
devices = sorted(os.listdir(ib_path))
return devices
def _get_ib_ports(self, device: str) -> List[str]:
ports = []
ports_dir = f"/sys/class/infiniband/{device}/ports"
if os.path.isdir(ports_dir):
ports = sorted(os.listdir(ports_dir))
return ports
@staticmethod
def _read_sys(path: str) -> str:
try:
with open(path) as f:
return f.read().strip()
except (FileNotFoundError, PermissionError, OSError):
return ""
def run(self) -> dict:
devices = self._get_ib_devices()
if not devices:
self.console.print(
"[yellow]No InfiniBand devices found — skipping RDMA test[/yellow]"
)
return {
"status": "SKIP", "skipped": True,
"reason": "no IB hardware detected",
"timestamp": datetime.now().isoformat(),
}
# Only consider ports whose link_layer is InfiniBand — Ethernet
# bond/management interfaces (e.g. mlx5_bond_0) can show ACTIVE state
# without actually providing IB fabric connectivity.
ib_devices = []
active_ib_port = False
for dev in devices:
for port in self._get_ib_ports(dev):
link_layer = self._read_sys(
f"/sys/class/infiniband/{dev}/ports/{port}/link_layer")
if link_layer != "InfiniBand":
continue
ib_devices.append((dev, port))
state = self._read_sys(
f"/sys/class/infiniband/{dev}/ports/{port}/state")
if "ACTIVE" in state.upper():
active_ib_port = True
device_info = self._collect_device_info(devices)
if not ib_devices:
self.console.print(
"[yellow]No InfiniBand-link_layer ports present — "
"skipping RDMA benchmarks[/yellow]"
)
return {
"status": "SKIP", "skipped": True,
"reason": "no InfiniBand link_layer ports (only Ethernet/RoCE)",
"devices": device_info,
"timestamp": datetime.now().isoformat(),
}
if not active_ib_port:
self.console.print(
f"[yellow]{len(ib_devices)} IB port(s) detected but all DOWN — "
f"fabric not wired, skipping RDMA benchmarks[/yellow]"
)
return {
"status": "SKIP", "skipped": True,
"reason": f"{len(ib_devices)} IB port(s) found but all DOWN (fabric not wired)",
"devices": device_info,
"timestamp": datetime.now().isoformat(),
}
self.console.print(f"[cyan]RDMA Test - Devices: {', '.join(devices)}[/cyan]")
active_pairs = [
(dev, port) for dev, port in ib_devices
if "ACTIVE" in self._read_sys(f"/sys/class/infiniband/{dev}/ports/{port}/state").upper()
]
port_checks = self._evaluate_port_checks(device_info)
test_devices = [dev for dev, _ in active_pairs]
bw_results = self._run_bandwidth_tests(test_devices)
latency_results = self._run_latency_tests(test_devices)
ibping_results = self._run_ibping_tests(active_pairs)
fabric_counters = self._collect_pfc_ecn_counters() if self.rdma_cfg.get("pfc_ecn_counters", True) else {}
failures = self._failure_reasons(port_checks, bw_results, latency_results, ibping_results, fabric_counters)
fabric_counters_missing = (
self.rdma_cfg.get("pfc_ecn_counters", True)
and fabric_counters
and not fabric_counters.get("counters")
)
all_passed = all(
r.get("status") == "PASS"
for r in bw_results + latency_results + ibping_results
if isinstance(r, dict)
) and all(p.get("status") == "PASS" for p in port_checks) and not fabric_counters.get("failed", False) and not fabric_counters_missing
return {
"passed": all_passed,
"devices": device_info,
"port_checks": port_checks,
"bandwidth_tests": bw_results,
"latency_tests": latency_results,
"ibping_tests": ibping_results,
"fabric_counters": fabric_counters,
"failures": failures,
"timestamp": datetime.now().isoformat(),
}
def _collect_device_info(self, devices: List[str]) -> List[dict]:
info = []
for dev in devices:
dev_info = {"name": dev, "ports": []}
ports = self._get_ib_ports(dev)
for port in ports:
port_info = {"port": port}
rate_path = f"/sys/class/infiniband/{dev}/ports/{port}/rate"
state_path = f"/sys/class/infiniband/{dev}/ports/{port}/state"
phys_state_path = f"/sys/class/infiniband/{dev}/ports/{port}/phys_state"
gid_path = f"/sys/class/infiniband/{dev}/ports/{port}/gids/0"
for label, path in [("rate", rate_path), ("state", state_path),
("phys_state", phys_state_path), ("gid", gid_path)]:
try:
with open(path) as f:
port_info[label] = f.read().strip()
except (FileNotFoundError, PermissionError):
port_info[label] = "N/A"
port_info["link_layer"] = self._read_sys(
f"/sys/class/infiniband/{dev}/ports/{port}/link_layer"
) or "N/A"
dev_info["ports"].append(port_info)
info.append(dev_info)
return info
def _evaluate_port_checks(self, device_info: List[dict]) -> List[dict]:
checks = []
min_rate = float(self.rdma_cfg.get("min_port_rate_gbps", 400))
for dev in device_info:
for port in dev.get("ports", []):
if port.get("link_layer") != "InfiniBand":
continue
state = port.get("state", "")
rate = port.get("rate", "")
rate_gbps = self._parse_rate_gbps(rate)
status = "PASS" if "ACTIVE" in state.upper() and rate_gbps >= min_rate else "FAIL"
checks.append({
"device": dev.get("name"),
"port": port.get("port"),
"state": state,
"rate": rate,
"rate_gbps": rate_gbps,
"min_rate_gbps": min_rate,
"status": status,
})
return checks
@staticmethod
def _parse_rate_gbps(rate: str) -> float:
# Example: "400 Gb/sec (4X NDR)"
try:
return float(str(rate).split()[0])
except (ValueError, IndexError, AttributeError):
return 0.0
@staticmethod
def _failure_reasons(port_checks: List[dict], bw_results: List[dict],
latency_results: List[dict], ibping_results: List[dict],
fabric_counters: dict) -> List[str]:
failures = []
for p in port_checks:
if p.get("status") != "PASS":
failures.append(
f"{p.get('device')} port {p.get('port')} state/rate failed "
f"({p.get('state')}, {p.get('rate')}; required >= {p.get('min_rate_gbps')}Gbps ACTIVE)"
)
for r in bw_results:
if r.get("status") != "PASS":
if r.get("error"):
failures.append(f"{r.get('test')} failed: {r.get('error')}")
else:
failures.append(
f"{r.get('test')} bandwidth {r.get('bandwidth_gbps', 0)}GB/s "
f"< {r.get('min_required_gbps', 'N/A')}GB/s"
)
for r in latency_results:
if r.get("status") != "PASS":
if r.get("error"):
failures.append(f"{r.get('test')} failed: {r.get('error')}")
else:
failures.append(
f"{r.get('test')} latency {r.get('latency_us', 0)}us "
f"> {r.get('max_allowed_us', 'N/A')}us"
)
for r in ibping_results:
if r.get("status") != "PASS":
failures.append(f"{r.get('test')} failed: {r.get('error') or r.get('output_tail', '')[:120]}")
if fabric_counters.get("failed"):
nonzero = [f"{k}={v}" for k, v in fabric_counters.get("counters", {}).items() if v]
failures.append("non-zero PFC/ECN/CNP/congestion counters: " + ", ".join(nonzero[:10]))
elif fabric_counters and not fabric_counters.get("counters"):
failures.append("PFC/ECN/CNP/congestion counters not found; fabric counter evidence missing")
return failures
def _run_ib_command(self, cmd: List[str], timeout: int = 60) -> dict:
try:
r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
if r.returncode == 0:
return {"status": "PASS", "output": r.stdout.strip()}
return {"status": "FAIL", "error": r.stderr.strip()[:200]}
except subprocess.TimeoutExpired:
return {"status": "FAIL", "error": "timeout"}
except FileNotFoundError:
return {"status": "SKIP", "error": "tool not found"}
except Exception as e:
return {"status": "FAIL", "error": str(e)}
def _run_bandwidth_tests(self, devices: List[str]) -> List[dict]:
results = []
ib_write_bw = self._find_tool("ib_write_bw")
ib_read_bw = self._find_tool("ib_read_bw")
min_bw = self.rdma_cfg.get("min_bandwidth_gbps", 50)
msg_size = self.rdma_cfg.get("msg_size", 65536)
iters = self.rdma_cfg.get("ib_iterations", 1000)
dx = self.rdma_cfg.get("ib_device", None)
port = self.rdma_cfg.get("ib_port", 1)
server_addr = self.rdma_cfg.get("server_addr") or os.environ.get("RDMA_SERVER_ADDR")
role = self.rdma_cfg.get("role", "auto")
for tool, label in [(ib_write_bw, "ib_write_bw"), (ib_read_bw, "ib_read_bw")]:
if not tool:
results.append({"test": label, "status": "FAIL", "error": "not installed"})
continue
if role == "client" and not server_addr:
results.append({
"test": label,
"status": "FAIL",
"error": "rdma.role=client requires rdma.server_addr or RDMA_SERVER_ADDR",
"role": "client",
})
continue
server_cmd = [tool, "-d", dx or devices[0], "-i", str(port), "-s", str(msg_size), "-n", str(iters)]
client_cmd = server_cmd + [server_addr or "localhost"]
if role == "server":
results.append(self._run_server_mode(label, server_cmd))
continue
server = None
if not server_addr and role != "client":
server = subprocess.Popen(server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
import time
time.sleep(1)
try:
client = subprocess.run(client_cmd, capture_output=True, text=True, timeout=60)
if server:
server.wait(timeout=10)
output = client.stdout
if server and server.stdout:
output += server.stdout.read()
bw_mibps = 0
for line in output.split("\n"):
line = line.strip()
if not line:
continue
parts = line.split()
try:
if len(parts) >= 5 and int(parts[0]) == int(msg_size):
# perftest bandwidth rows:
# #bytes #iterations BW peak[MiB/sec] BW average[MiB/sec] MsgRate[Mpps]
bw_mibps = max(bw_mibps, float(parts[3]))
except (ValueError, IndexError):
continue
bw_gbps = bw_mibps * 1024 * 1024 / 1e9 if bw_mibps else 0
status = "PASS" if bw_gbps >= min_bw else "FAIL"
results.append({
"test": label,
"status": status,
"bandwidth_gbps": round(bw_gbps, 2),
"min_required_gbps": min_bw,
"msg_size": msg_size,
"role": "client" if server_addr else "local_loopback",
})
except Exception as e:
if server:
server.kill()
results.append({"test": label, "status": "FAIL", "error": str(e)})
return results
def _run_latency_tests(self, devices: List[str]) -> List[dict]:
results = []
ib_write_lat = self._find_tool("ib_write_lat")
ib_read_lat = self._find_tool("ib_read_lat")
max_lat_us = self.rdma_cfg.get("max_latency_us", 10)
max_by_test = {
"ib_write_lat": self.rdma_cfg.get("max_write_latency_us", max_lat_us),
"ib_read_lat": self.rdma_cfg.get("max_read_latency_us", max_lat_us),
}
dx = self.rdma_cfg.get("ib_device", None)
port = self.rdma_cfg.get("ib_port", 1)
msg_size = self.rdma_cfg.get("latency_msg_size", 8)
iters = self.rdma_cfg.get("ib_iterations", 1000)
server_addr = self.rdma_cfg.get("server_addr") or os.environ.get("RDMA_SERVER_ADDR")
role = self.rdma_cfg.get("role", "auto")
for tool, label in [(ib_write_lat, "ib_write_lat"), (ib_read_lat, "ib_read_lat")]:
if not tool:
results.append({"test": label, "status": "FAIL", "error": "not installed"})
continue
if role == "client" and not server_addr:
results.append({
"test": label,
"status": "FAIL",
"error": "rdma.role=client requires rdma.server_addr or RDMA_SERVER_ADDR",
"role": "client",
})
continue
server_cmd = [tool, "-d", dx or devices[0], "-i", str(port), "-s", str(msg_size), "-n", str(iters)]
client_cmd = server_cmd + [server_addr or "localhost"]
if role == "server":
results.append(self._run_server_mode(label, server_cmd))
continue
server = None
if not server_addr and role != "client":
server = subprocess.Popen(server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
import time
time.sleep(1)
try:
client = subprocess.run(client_cmd, capture_output=True, text=True, timeout=60)
if server:
server.wait(timeout=10)
output = client.stdout
if server and server.stdout:
output += server.stdout.read()
lat_us = 0
for line in output.split("\n"):
parts = line.strip().split()
try:
if len(parts) >= 6:
int(parts[0])
int(parts[1])
# perftest latency rows:
# #bytes #iterations t_min t_max t_typical t_avg t_stdev p99 p99.9
lat_us = max(lat_us, float(parts[5]))
except (ValueError, IndexError):
continue
max_allowed = max_by_test[label]
status = "PASS" if 0 < lat_us <= max_allowed else "FAIL"
results.append({
"test": label,
"status": status,
"latency_us": round(lat_us, 2),
"max_allowed_us": max_allowed,
"msg_size": msg_size,
"role": "client" if server_addr else "local_loopback",
})
except Exception as e:
if server:
server.kill()
results.append({"test": label, "status": "FAIL", "error": str(e)})
return results
def _run_server_mode(self, label: str, server_cmd: List[str]) -> dict:
timeout = int(self.rdma_cfg.get("server_timeout_sec", 120))
try:
r = subprocess.run(server_cmd, capture_output=True, text=True, timeout=timeout)
return {
"test": label,
"status": "PASS" if r.returncode == 0 else "FAIL",
"role": "server",
"server_timeout_sec": timeout,
"output_tail": (r.stdout + r.stderr)[-500:],
}
except subprocess.TimeoutExpired:
return {
"test": label,
"status": "PASS",
"role": "server",
"server_timeout_sec": timeout,
"note": "server ran until timeout waiting for client",
}
except Exception as e:
return {"test": label, "status": "FAIL", "role": "server", "error": str(e)}
def _run_ibping_tests(self, active_pairs: List[tuple[str, str]]) -> List[dict]:
tool = self._find_tool("ibping")
if not tool:
return [{"test": "ibping", "status": "FAIL", "error": "not installed"}]
if not active_pairs:
return [{"test": "ibping", "status": "FAIL", "error": "no active IB ports"}]
dev, port = active_pairs[0]
target = (
self.rdma_cfg.get("ibping_target")
or os.environ.get("IBPING_TARGET")
)
count = int(self.rdma_cfg.get("ibping_count", 5))
role = self.rdma_cfg.get("role", "auto")
server_addr = self.rdma_cfg.get("server_addr") or os.environ.get("RDMA_SERVER_ADDR")
base = [tool, "-C", dev, "-P", str(port)]
if role == "server":
return [self._run_server_mode("ibping", [*base, "-S"])]
server = None
if not target and role != "client":
target = self._read_sys(f"/sys/class/infiniband/{dev}/ports/{port}/lid")
server = subprocess.Popen([*base, "-S"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
import time
time.sleep(1)
if not target:
reason = "no ibping target/lid"
if role == "client" or server_addr:
reason = (
"cross-node ibping requires rdma.ibping_target or IBPING_TARGET "
"(peer LID/GID; rdma.server_addr is only for perftest TCP bootstrap)"
)
return [{"test": "ibping", "status": "FAIL", "error": reason}]
try:
r = subprocess.run([*base, "-c", str(count), str(target)], capture_output=True, text=True, timeout=30)
if server:
server.terminate()
try:
server.wait(timeout=5)
except subprocess.TimeoutExpired:
server.kill()
output = r.stdout + r.stderr
failed = r.returncode != 0 or "failed" in output.lower()
return [{
"test": "ibping",
"status": "FAIL" if failed else "PASS",
"role": "client" if server_addr or role == "client" else "local_loopback",
"direction": "outbound_to_peer" if server_addr or role == "client" else "local_loopback",
"target": str(target),
"count": count,
"output_tail": output[-500:],
}]
except Exception as e:
if server:
server.kill()
return [{"test": "ibping", "status": "FAIL", "error": str(e)}]
def _collect_pfc_ecn_counters(self) -> dict:
counters = {}
failed = False
keywords = ("pfc", "ecn", "cnp", "congestion")
for root, _, files in os.walk("/sys/class/infiniband"):
for name in files:
lower = name.lower()
if not any(k in lower for k in keywords):
continue
path = os.path.join(root, name)
val = self._read_sys(path)
try:
num = int(val)
except ValueError:
continue
rel = path.replace("/sys/class/infiniband/", "")
counters[rel] = num
if num != 0:
failed = True
ethtool = shutil.which("ethtool")
net_dir = "/sys/class/net"
if ethtool and os.path.isdir(net_dir):
for iface in sorted(os.listdir(net_dir)):
try:
r = subprocess.run(
[ethtool, "-S", iface],
capture_output=True,
text=True,
timeout=10,
)
except Exception:
continue
if r.returncode != 0:
continue
for line in r.stdout.splitlines():
if ":" not in line:
continue
key, value = line.split(":", 1)
key = key.strip()
lower = key.lower()
if not any(k in lower for k in keywords):
continue
try:
num = int(value.strip().split()[0])
except (ValueError, IndexError):
continue
counters[f"net/{iface}/{key}"] = num
if num != 0:
failed = True
return {"failed": failed, "counters": counters}
@staticmethod
def print_results(results: dict, console: Console = None):
c = console or Console()
if results.get("skipped") or results.get("status") == "SKIP":
c.print(f"\n[bold yellow]RDMA/InfiniBand: SKIPPED[/bold yellow] "
f"[dim]({results.get('reason', 'no IB hardware')})[/dim]")
return
if "error" in results:
c.print(f"[bold red]Error: {results['error']}[/bold red]")
return
devices = results.get("devices", [])
c.print(f"\n[bold cyan]RDMA/InfiniBand Test Results[/bold cyan]")
c.print(f" Devices found: {len(devices)}")
for dev in devices:
c.print(f"\n [bold]{dev['name']}[/bold]")
for p in dev.get("ports", []):
state = p.get("state", "N/A")
color = "green" if "Active" in state else "red"
c.print(f" Port {p['port']}: [{color}]{state}[/{color}] | "
f"Rate: {p.get('rate', 'N/A')} | GID: {p.get('gid', 'N/A')[:20]}")
bw_tests = results.get("bandwidth_tests", [])
if bw_tests:
c.print("\n [bold]Bandwidth Tests[/bold]")
for t in bw_tests:
status = t.get("status", "SKIP")
sc = "green" if status == "PASS" else ("yellow" if status == "WARN" else "red")
bw = t.get("bandwidth_gbps", 0)
c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
f"({bw:.2f} GB/s, min: {t.get('min_required_gbps', 'N/A')} GB/s)" if status != "SKIP"
else f" {t['test']}: [dim]SKIPPED[/dim]")
lat_tests = results.get("latency_tests", [])
if lat_tests:
c.print("\n [bold]Latency Tests[/bold]")
for t in lat_tests:
status = t.get("status", "SKIP")
sc = "green" if status == "PASS" else ("yellow" if status == "WARN" else "red")
lat = t.get("latency_us", 0)
c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)" if status != "SKIP"
else f" {t['test']}: [dim]SKIPPED[/dim]")
ibping_tests = results.get("ibping_tests", [])
if ibping_tests:
c.print("\n [bold]IB Ping Tests[/bold]")
for t in ibping_tests:
status = t.get("status", "FAIL")
sc = "green" if status == "PASS" else "red"
c.print(f" {t['test']}: [{sc}]{status}[/{sc}] target={t.get('target', 'N/A')}")