"""RDMA / InfiniBand bandwidth and latency test module."""

import glob
import os
import shutil
import subprocess
import time
from datetime import datetime
from typing import Optional, List

from rich.console import Console
from rich.table import Table

from modules.gpu_specs import resolve_tools_dir

# Kernel InfiniBand class directory. The per-device entries in it are
# symlinks, which matters for os.walk() (it does not descend into symlinked
# subdirectories unless followlinks=True).
_IB_SYSFS = "/sys/class/infiniband"


class RDMATest:
    """InfiniBand health checks.

    Covers port state/rate, perftest bandwidth and latency, ibping
    reachability and PFC/ECN/CNP/congestion counters. All settings are read
    from the ``rdma`` section of the config dict.
    """

    def __init__(self, config: dict):
        self.config = config
        self.console = Console()
        self.rdma_cfg = config.get("rdma", {})
        self.tools_dir = resolve_tools_dir(config)

    # ------------------------------------------------------------------
    # Tool / sysfs discovery
    # ------------------------------------------------------------------

    def _find_tool(self, name: str) -> Optional[str]:
        """Return an executable path for *name*, or None if none is found.

        Search order: PATH, then a few fixed locations under
        ``self.tools_dir``, then a recursive search below ``self.tools_dir``.
        """
        on_path = shutil.which(name)
        if on_path:
            return on_path
        candidates = [
            os.path.join(self.tools_dir, "perftest", name),
            os.path.join(self.tools_dir, "perftest", "bin", name),
            os.path.join(self.tools_dir, "rdma", name),
            os.path.join(self.tools_dir, name),
        ]
        for path in candidates:
            if self._is_executable(path):
                return path
        # Only fall back to the (slower) recursive glob when the fixed
        # locations miss.
        for path in glob.glob(os.path.join(self.tools_dir, "**", name), recursive=True):
            if self._is_executable(path):
                return path
        return None

    @staticmethod
    def _is_executable(path: str) -> bool:
        """True if *path* is a regular file the current user may execute."""
        return os.path.isfile(path) and os.access(path, os.X_OK)

    def _get_ib_devices(self) -> List[str]:
        """Return the sorted IB device names from sysfs ([] if the class dir is absent)."""
        if not os.path.isdir(_IB_SYSFS):
            return []
        return sorted(os.listdir(_IB_SYSFS))

    def _get_ib_ports(self, device: str) -> List[str]:
        """Return the sorted port names of *device* ([] if it has no ports dir)."""
        ports_dir = f"{_IB_SYSFS}/{device}/ports"
        if not os.path.isdir(ports_dir):
            return []
        return sorted(os.listdir(ports_dir))

    @staticmethod
    def _read_sys(path: str) -> str:
        """Return the stripped contents of *path*, or "" if it cannot be read."""
        try:
            with open(path) as f:
                return f.read().strip()
        except OSError:  # covers FileNotFoundError / PermissionError
            return ""

    def _port_attr(self, dev: str, port: str, attr: str) -> str:
        """Read ``<dev>/ports/<port>/<attr>`` from the IB sysfs tree ("" on error)."""
        return self._read_sys(f"{_IB_SYSFS}/{dev}/ports/{port}/{attr}")

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    def run(self) -> dict:
        """Run every RDMA check and return a result dict.

        A ``{"status": "SKIP", "skipped": True, ...}`` dict is returned in
        three cases: the host has no IB devices, it has no InfiniBand
        link_layer ports, or none of those ports is ACTIVE.

        Otherwise the dict holds ``passed``, the per-check results and a list
        of human-readable ``failures``.
        """
        devices = self._get_ib_devices()
        if not devices:
            self.console.print(
                "[yellow]No InfiniBand devices found — skipping RDMA test[/yellow]"
            )
            return {
                "status": "SKIP",
                "skipped": True,
                "reason": "no IB hardware detected",
                "timestamp": datetime.now().isoformat(),
            }

        # Only consider ports whose link_layer is InfiniBand — Ethernet
        # bond/management interfaces (e.g. mlx5_bond_0) can show ACTIVE state
        # without actually providing IB fabric connectivity.
        # Each port's state is read exactly once. This guarantees that the
        # "any port active?" decision and the list of ports under test agree,
        # so the tests are never started with an empty device list.
        ib_ports: List[tuple[str, str]] = []
        active_pairs: List[tuple[str, str]] = []
        for dev in devices:
            for port in self._get_ib_ports(dev):
                if self._port_attr(dev, port, "link_layer") != "InfiniBand":
                    continue
                ib_ports.append((dev, port))
                if "ACTIVE" in self._port_attr(dev, port, "state").upper():
                    active_pairs.append((dev, port))

        device_info = self._collect_device_info(devices)

        if not ib_ports:
            self.console.print(
                "[yellow]No InfiniBand-link_layer ports present — "
                "skipping RDMA benchmarks[/yellow]"
            )
            return {
                "status": "SKIP",
                "skipped": True,
                "reason": "no InfiniBand link_layer ports (only Ethernet/RoCE)",
                "devices": device_info,
                "timestamp": datetime.now().isoformat(),
            }

        if not active_pairs:
            self.console.print(
                f"[yellow]{len(ib_ports)} IB port(s) detected but all DOWN — "
                f"fabric not wired, skipping RDMA benchmarks[/yellow]"
            )
            return {
                "status": "SKIP",
                "skipped": True,
                "reason": f"{len(ib_ports)} IB port(s) found but all DOWN (fabric not wired)",
                "devices": device_info,
                "timestamp": datetime.now().isoformat(),
            }

        self.console.print(f"[cyan]RDMA Test - Devices: {', '.join(devices)}[/cyan]")

        port_checks = self._evaluate_port_checks(device_info)
        test_devices = [dev for dev, _ in active_pairs]
        bw_results = self._run_bandwidth_tests(test_devices)
        latency_results = self._run_latency_tests(test_devices)
        ibping_results = self._run_ibping_tests(active_pairs)

        counters_enabled = self.rdma_cfg.get("pfc_ecn_counters", True)
        fabric_counters = self._collect_pfc_ecn_counters() if counters_enabled else {}

        failures = self._failure_reasons(
            port_checks, bw_results, latency_results, ibping_results, fabric_counters
        )

        # Counter collection was requested but matched nothing: missing
        # evidence counts as a failure rather than a silent pass.
        fabric_counters_missing = bool(
            counters_enabled and fabric_counters and not fabric_counters.get("counters")
        )
        test_results = [
            r for r in bw_results + latency_results + ibping_results if isinstance(r, dict)
        ]
        all_passed = (
            all(r.get("status") == "PASS" for r in test_results)
            and all(p.get("status") == "PASS" for p in port_checks)
            and not fabric_counters.get("failed", False)
            and not fabric_counters_missing
        )

        return {
            "passed": all_passed,
            "devices": device_info,
            "port_checks": port_checks,
            "bandwidth_tests": bw_results,
            "latency_tests": latency_results,
            "ibping_tests": ibping_results,
            "fabric_counters": fabric_counters,
            "failures": failures,
            "timestamp": datetime.now().isoformat(),
        }

    # ------------------------------------------------------------------
    # Port inventory and checks
    # ------------------------------------------------------------------

    def _collect_device_info(self, devices: List[str]) -> List[dict]:
        """Return one ``{"name", "ports"}`` dict per device.

        Each port dict holds rate, state, phys_state, GID 0 and link_layer.
        Any attribute that is unreadable or empty is recorded as "N/A".
        """
        attrs = (
            ("rate", "rate"),
            ("state", "state"),
            ("phys_state", "phys_state"),
            ("gid", "gids/0"),
            ("link_layer", "link_layer"),
        )
        info = []
        for dev in devices:
            ports = []
            for port in self._get_ib_ports(dev):
                port_info = {"port": port}
                for label, attr in attrs:
                    # _read_sys swallows every OSError, so one odd sysfs file
                    # cannot abort the whole run.
                    port_info[label] = self._port_attr(dev, port, attr) or "N/A"
                ports.append(port_info)
            info.append({"name": dev, "ports": ports})
        return info

    def _evaluate_port_checks(self, device_info: List[dict]) -> List[dict]:
        """Check every InfiniBand port is ACTIVE at >= ``rdma.min_port_rate_gbps``."""
        min_rate = float(self.rdma_cfg.get("min_port_rate_gbps", 400))
        checks = []
        for dev in device_info:
            for port in dev.get("ports", []):
                if port.get("link_layer") != "InfiniBand":
                    continue
                state = port.get("state", "")
                rate = port.get("rate", "")
                rate_gbps = self._parse_rate_gbps(rate)
                ok = "ACTIVE" in state.upper() and rate_gbps >= min_rate
                checks.append({
                    "device": dev.get("name"),
                    "port": port.get("port"),
                    "state": state,
                    "rate": rate,
                    "rate_gbps": rate_gbps,
                    "min_rate_gbps": min_rate,
                    "status": "PASS" if ok else "FAIL",
                })
        return checks

    @staticmethod
    def _parse_rate_gbps(rate: str) -> float:
        """Return the leading number of a sysfs rate string, or 0.0 if unparsable."""
        # Example: "400 Gb/sec (4X NDR)"
        try:
            return float(str(rate).split()[0])
        except (ValueError, IndexError, AttributeError):
            return 0.0

    @staticmethod
    def _failure_reasons(port_checks: List[dict], bw_results: List[dict],
                         latency_results: List[dict], ibping_results: List[dict],
                         fabric_counters: dict) -> List[str]:
        """Turn the non-PASS results into one human-readable line each."""
        failures = []
        for p in port_checks:
            if p.get("status") != "PASS":
                failures.append(
                    f"{p.get('device')} port {p.get('port')} state/rate failed "
                    f"({p.get('state')}, {p.get('rate')}; required >= {p.get('min_rate_gbps')}Gbps ACTIVE)"
                )
        for r in bw_results:
            if r.get("status") == "PASS":
                continue
            if r.get("error"):
                failures.append(f"{r.get('test')} failed: {r.get('error')}")
            elif r.get("role") == "server":
                # Server-mode results carry output, not a measured bandwidth.
                failures.append(f"{r.get('test')} server failed: {(r.get('output_tail') or '')[-120:]}")
            else:
                failures.append(
                    f"{r.get('test')} bandwidth {r.get('bandwidth_gbps', 0)}GB/s "
                    f"< {r.get('min_required_gbps', 'N/A')}GB/s"
                )
        for r in latency_results:
            if r.get("status") == "PASS":
                continue
            if r.get("error"):
                failures.append(f"{r.get('test')} failed: {r.get('error')}")
            elif r.get("role") == "server":
                failures.append(f"{r.get('test')} server failed: {(r.get('output_tail') or '')[-120:]}")
            else:
                failures.append(
                    f"{r.get('test')} latency {r.get('latency_us', 0)}us "
                    f"> {r.get('max_allowed_us', 'N/A')}us"
                )
        for r in ibping_results:
            if r.get("status") != "PASS":
                # ibping prints its loss summary last, so keep the end of the output.
                detail = r.get("error") or (r.get("output_tail") or "").strip()[-120:]
                failures.append(f"{r.get('test')} failed: {detail}")
        if fabric_counters.get("failed"):
            nonzero = [f"{k}={v}" for k, v in fabric_counters.get("counters", {}).items() if v]
            failures.append("non-zero PFC/ECN/CNP/congestion counters: " + ", ".join(nonzero[:10]))
        elif fabric_counters and not fabric_counters.get("counters"):
            failures.append("PFC/ECN/CNP/congestion counters not found; fabric counter evidence missing")
        return failures

    # ------------------------------------------------------------------
    # Process helpers
    # ------------------------------------------------------------------

    def _run_ib_command(self, cmd: List[str], timeout: int = 60) -> dict:
        """Run *cmd*; PASS with stdout on exit 0, FAIL (SKIP if missing) otherwise."""
        try:
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
        except subprocess.TimeoutExpired:
            return {"status": "FAIL", "error": "timeout"}
        except FileNotFoundError:
            return {"status": "SKIP", "error": "tool not found"}
        except Exception as e:
            return {"status": "FAIL", "error": str(e)}
        if r.returncode == 0:
            return {"status": "PASS", "output": r.stdout.strip()}
        return {"status": "FAIL", "error": r.stderr.strip()[:200]}

    @staticmethod
    def _stop_process(proc: subprocess.Popen, grace_sec: float = 5) -> None:
        """Stop *proc* if still running (terminate, then kill), reap it, close its pipes."""
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=grace_sec)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        for stream in (proc.stdout, proc.stderr):
            if stream is not None:
                stream.close()

    def _run_perftest_pair(self, server_cmd: List[str], client_cmd: List[str],
                           local_server: bool) -> tuple[str, Optional[str]]:
        """Run a perftest client, optionally against a locally spawned server.

        Returns a pair:
        * the client's stdout, plus the local server's stdout when one was
          started;
        * the client's error text (stderr tail) if it exited non-zero,
          otherwise None.

        Subprocess errors (e.g. TimeoutExpired) propagate to the caller. A
        local server is always stopped and reaped before this returns or
        raises.
        """
        server = None
        try:
            if local_server:
                server = subprocess.Popen(server_cmd, stdout=subprocess.PIPE,
                                          stderr=subprocess.PIPE, text=True)
                time.sleep(1)  # let the server start listening before the client connects
            client = subprocess.run(client_cmd, capture_output=True, text=True, timeout=60)
            output = client.stdout
            if server is not None:
                # communicate() drains stdout and stderr together, so the server
                # can never block on a full pipe while we wait for it to exit.
                server_out, _ = server.communicate(timeout=10)
                output += server_out or ""
        finally:
            if server is not None:
                self._stop_process(server)
        error = None
        if client.returncode != 0:
            error = (client.stderr.strip() or f"client exited with code {client.returncode}")[:200]
        return output, error

    @staticmethod
    def _parse_bw_mibps(output: str, msg_size) -> float:
        """Return the highest BW-average (MiB/s) over perftest rows for *msg_size* (0 if none)."""
        best = 0
        for line in output.splitlines():
            parts = line.split()
            try:
                if len(parts) >= 5 and int(parts[0]) == int(msg_size):
                    # perftest bandwidth rows:
                    # #bytes #iterations BW peak[MiB/sec] BW average[MiB/sec] MsgRate[Mpps]
                    best = max(best, float(parts[3]))
            except (ValueError, IndexError):
                continue
        return best

    @staticmethod
    def _parse_latency_us(output: str) -> float:
        """Return the highest t_avg (usec) over perftest latency rows (0 if none)."""
        best = 0
        for line in output.splitlines():
            parts = line.split()
            if len(parts) < 6:
                continue
            try:
                # Rows start with two integers (#bytes, #iterations):
                # #bytes #iterations t_min t_max t_typical t_avg t_stdev p99 p99.9
                int(parts[0])
                int(parts[1])
                best = max(best, float(parts[5]))
            except ValueError:
                continue
        return best

    # ------------------------------------------------------------------
    # Benchmarks
    # ------------------------------------------------------------------

    def _run_bandwidth_tests(self, devices: List[str]) -> List[dict]:
        """Run ib_write_bw and ib_read_bw; return one result dict per tool.

        ``rdma.role`` selects the mode:
        * "server" — only listen for a remote client;
        * "client" — connect to ``rdma.server_addr`` / ``RDMA_SERVER_ADDR``;
        * "auto" — connect to that address when set, otherwise run a local
          server + client loopback.
        """
        min_bw = self.rdma_cfg.get("min_bandwidth_gbps", 50)
        msg_size = self.rdma_cfg.get("msg_size", 65536)
        iters = self.rdma_cfg.get("ib_iterations", 1000)
        dx = self.rdma_cfg.get("ib_device", None)
        port = self.rdma_cfg.get("ib_port", 1)
        server_addr = self.rdma_cfg.get("server_addr") or os.environ.get("RDMA_SERVER_ADDR")
        role = self.rdma_cfg.get("role", "auto")

        results = []
        for label in ("ib_write_bw", "ib_read_bw"):
            tool = self._find_tool(label)
            if not tool:
                results.append({"test": label, "status": "FAIL", "error": "not installed"})
                continue
            if role == "client" and not server_addr:
                results.append({
                    "test": label,
                    "status": "FAIL",
                    "error": "rdma.role=client requires rdma.server_addr or RDMA_SERVER_ADDR",
                    "role": "client",
                })
                continue
            device = dx or (devices[0] if devices else None)
            if not device:
                results.append({"test": label, "status": "FAIL", "error": "no active IB device to test"})
                continue
            server_cmd = [tool, "-d", device, "-i", str(port),
                          "-s", str(msg_size), "-n", str(iters)]
            client_cmd = server_cmd + [server_addr or "localhost"]
            if role == "server":
                results.append(self._run_server_mode(label, server_cmd))
                continue
            try:
                output, client_error = self._run_perftest_pair(
                    server_cmd, client_cmd,
                    local_server=not server_addr and role != "client",
                )
                bw_mibps = self._parse_bw_mibps(output, msg_size)
                # perftest reports MiB/s; the threshold is in decimal GB/s.
                bw_gbps = bw_mibps * 1024 * 1024 / 1e9 if bw_mibps else 0
                result = {
                    "test": label,
                    "status": "PASS" if bw_gbps >= min_bw else "FAIL",
                    "bandwidth_gbps": round(bw_gbps, 2),
                    "min_required_gbps": min_bw,
                    "msg_size": msg_size,
                    "role": "client" if server_addr else "local_loopback",
                }
                if not bw_mibps and client_error:
                    # No measurement at all: surface why the client failed.
                    result["error"] = client_error
                results.append(result)
            except Exception as e:
                results.append({"test": label, "status": "FAIL", "error": str(e)})
        return results

    def _run_latency_tests(self, devices: List[str]) -> List[dict]:
        """Run ib_write_lat and ib_read_lat; return one result dict per tool.

        Roles behave as in :meth:`_run_bandwidth_tests`. A test passes when a
        positive t_avg is at or below its per-test limit:
        ``rdma.max_write_latency_us`` / ``rdma.max_read_latency_us``, each
        defaulting to ``rdma.max_latency_us``.
        """
        max_lat_us = self.rdma_cfg.get("max_latency_us", 10)
        max_by_test = {
            "ib_write_lat": self.rdma_cfg.get("max_write_latency_us", max_lat_us),
            "ib_read_lat": self.rdma_cfg.get("max_read_latency_us", max_lat_us),
        }
        dx = self.rdma_cfg.get("ib_device", None)
        port = self.rdma_cfg.get("ib_port", 1)
        msg_size = self.rdma_cfg.get("latency_msg_size", 8)
        iters = self.rdma_cfg.get("ib_iterations", 1000)
        server_addr = self.rdma_cfg.get("server_addr") or os.environ.get("RDMA_SERVER_ADDR")
        role = self.rdma_cfg.get("role", "auto")

        results = []
        for label in ("ib_write_lat", "ib_read_lat"):
            tool = self._find_tool(label)
            if not tool:
                results.append({"test": label, "status": "FAIL", "error": "not installed"})
                continue
            if role == "client" and not server_addr:
                results.append({
                    "test": label,
                    "status": "FAIL",
                    "error": "rdma.role=client requires rdma.server_addr or RDMA_SERVER_ADDR",
                    "role": "client",
                })
                continue
            device = dx or (devices[0] if devices else None)
            if not device:
                results.append({"test": label, "status": "FAIL", "error": "no active IB device to test"})
                continue
            server_cmd = [tool, "-d", device, "-i", str(port),
                          "-s", str(msg_size), "-n", str(iters)]
            client_cmd = server_cmd + [server_addr or "localhost"]
            if role == "server":
                results.append(self._run_server_mode(label, server_cmd))
                continue
            try:
                output, client_error = self._run_perftest_pair(
                    server_cmd, client_cmd,
                    local_server=not server_addr and role != "client",
                )
                lat_us = self._parse_latency_us(output)
                max_allowed = max_by_test[label]
                result = {
                    "test": label,
                    "status": "PASS" if 0 < lat_us <= max_allowed else "FAIL",
                    "latency_us": round(lat_us, 2),
                    "max_allowed_us": max_allowed,
                    "msg_size": msg_size,
                    "role": "client" if server_addr else "local_loopback",
                }
                if not lat_us and client_error:
                    result["error"] = client_error
                results.append(result)
            except Exception as e:
                results.append({"test": label, "status": "FAIL", "error": str(e)})
        return results

    def _run_server_mode(self, label: str, server_cmd: List[str]) -> dict:
        """Run a perftest/ibping server for up to ``rdma.server_timeout_sec``.

        Timing out while waiting for a client is reported as PASS.
        """
        timeout = int(self.rdma_cfg.get("server_timeout_sec", 120))
        try:
            r = subprocess.run(server_cmd, capture_output=True, text=True, timeout=timeout)
            return {
                "test": label,
                "status": "PASS" if r.returncode == 0 else "FAIL",
                "role": "server",
                "server_timeout_sec": timeout,
                "output_tail": (r.stdout + r.stderr)[-500:],
            }
        except subprocess.TimeoutExpired:
            return {
                "test": label,
                "status": "PASS",
                "role": "server",
                "server_timeout_sec": timeout,
                "note": "server ran until timeout waiting for client",
            }
        except Exception as e:
            return {"test": label, "status": "FAIL", "role": "server", "error": str(e)}

    def _run_ibping_tests(self, active_pairs: List[tuple[str, str]]) -> List[dict]:
        """Run ibping from the first active IB port; return a one-element result list.

        Target selection:
        * ``rdma.ibping_target`` / ``IBPING_TARGET`` when set;
        * otherwise, unless role is "client", the port's own LID, pinged
          through a temporary local ``ibping -S`` responder.

        Role "server" only runs the responder.
        """
        tool = self._find_tool("ibping")
        if not tool:
            return [{"test": "ibping", "status": "FAIL", "error": "not installed"}]
        if not active_pairs:
            return [{"test": "ibping", "status": "FAIL", "error": "no active IB ports"}]

        dev, port = active_pairs[0]
        target = (
            self.rdma_cfg.get("ibping_target")
            or os.environ.get("IBPING_TARGET")
        )
        count = int(self.rdma_cfg.get("ibping_count", 5))
        role = self.rdma_cfg.get("role", "auto")
        server_addr = self.rdma_cfg.get("server_addr") or os.environ.get("RDMA_SERVER_ADDR")
        base = [tool, "-C", dev, "-P", str(port)]

        if role == "server":
            return [self._run_server_mode("ibping", [*base, "-S"])]

        loopback = False
        if not target and role != "client":
            target = self._port_attr(dev, port, "lid")
            loopback = True

        # Bail out before spawning anything, so no responder can be left behind.
        if not target:
            reason = "no ibping target/lid"
            if role == "client" or server_addr:
                reason = (
                    "cross-node ibping requires rdma.ibping_target or IBPING_TARGET "
                    "(peer LID/GID; rdma.server_addr is only for perftest TCP bootstrap)"
                )
            return [{"test": "ibping", "status": "FAIL", "error": reason}]

        responder = None
        try:
            if loopback:
                # Responder output is never inspected; discard it rather than
                # filling an unread pipe.
                responder = subprocess.Popen([*base, "-S"], stdout=subprocess.DEVNULL,
                                             stderr=subprocess.DEVNULL)
                time.sleep(1)
            r = subprocess.run([*base, "-c", str(count), str(target)],
                               capture_output=True, text=True, timeout=30)
        except Exception as e:
            return [{"test": "ibping", "status": "FAIL", "error": str(e)}]
        finally:
            if responder is not None:
                self._stop_process(responder)

        output = r.stdout + r.stderr
        failed = r.returncode != 0 or "failed" in output.lower()
        # Label by what was actually pinged: our own LID vs a configured target.
        return [{
            "test": "ibping",
            "status": "FAIL" if failed else "PASS",
            "role": "local_loopback" if loopback else "client",
            "direction": "local_loopback" if loopback else "outbound_to_peer",
            "target": str(target),
            "count": count,
            "output_tail": output[-500:],
        }]

    # ------------------------------------------------------------------
    # Fabric counters
    # ------------------------------------------------------------------

    def _collect_pfc_ecn_counters(self) -> dict:
        """Collect PFC/ECN/CNP/congestion counters from IB sysfs and ``ethtool -S``.

        Returns ``{"failed": bool, "counters": {name: int}}``, where
        ``failed`` is True if any matched counter is non-zero.

        Counter names:
        * sysfs counters are keyed by their path relative to
          /sys/class/infiniband;
        * ethtool counters are keyed ``net/<iface>/<stat>``.
        """
        keywords = ("pfc", "ecn", "cnp", "congestion")
        counters = {}
        failed = False

        def is_congestion_counter(name: str) -> bool:
            lowered = name.lower()
            return any(k in lowered for k in keywords)

        # Walk each device as its own root. The class-dir entries are
        # symlinks, and os.walk() never descends into symlinked directories,
        # so walking /sys/class/infiniband itself would find no files at all.
        for dev in self._get_ib_devices():
            for root, _, files in os.walk(os.path.join(_IB_SYSFS, dev)):
                for name in files:
                    if not is_congestion_counter(name):
                        continue
                    path = os.path.join(root, name)
                    try:
                        num = int(self._read_sys(path))
                    except ValueError:
                        continue
                    counters[os.path.relpath(path, _IB_SYSFS)] = num
                    if num != 0:
                        failed = True

        ethtool = shutil.which("ethtool")
        net_dir = "/sys/class/net"
        if ethtool and os.path.isdir(net_dir):
            for iface in sorted(os.listdir(net_dir)):
                try:
                    r = subprocess.run(
                        [ethtool, "-S", iface],
                        capture_output=True, text=True, timeout=10,
                    )
                except (OSError, ValueError, subprocess.SubprocessError):
                    # Best effort: an interface ethtool cannot query is skipped.
                    continue
                if r.returncode != 0:
                    continue
                for line in r.stdout.splitlines():
                    key, sep, value = line.partition(":")
                    if not sep:
                        continue
                    key = key.strip()
                    if not is_congestion_counter(key):
                        continue
                    try:
                        num = int(value.split()[0])
                    except (ValueError, IndexError):
                        continue
                    counters[f"net/{iface}/{key}"] = num
                    if num != 0:
                        failed = True

        return {"failed": failed, "counters": counters}

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    @staticmethod
    def _status_color(status: str) -> str:
        """Rich colour for a test status: PASS green, WARN yellow, anything else red."""
        if status == "PASS":
            return "green"
        return "yellow" if status == "WARN" else "red"

    @staticmethod
    def print_results(results: dict, console: Optional[Console] = None):
        """Pretty-print a :meth:`run` result dict to *console* (a new Console if None)."""
        c = console or Console()
        if results.get("skipped") or results.get("status") == "SKIP":
            c.print(f"\n[bold yellow]RDMA/InfiniBand: SKIPPED[/bold yellow] "
                    f"[dim]({results.get('reason', 'no IB hardware')})[/dim]")
            return
        if "error" in results:
            c.print(f"[bold red]Error: {results['error']}[/bold red]")
            return

        devices = results.get("devices", [])
        c.print("\n[bold cyan]RDMA/InfiniBand Test Results[/bold cyan]")
        c.print(f" Devices found: {len(devices)}")
        for dev in devices:
            c.print(f"\n [bold]{dev['name']}[/bold]")
            for p in dev.get("ports", []):
                state = p.get("state", "N/A")
                # Same upper-case "ACTIVE" match the port checks use; the
                # former mixed-case "Active" test never matched.
                color = "green" if "ACTIVE" in state.upper() else "red"
                c.print(f" Port {p['port']}: [{color}]{state}[/{color}] | "
                        f"Rate: {p.get('rate', 'N/A')} | GID: {p.get('gid', 'N/A')[:20]}")

        bw_tests = results.get("bandwidth_tests", [])
        if bw_tests:
            c.print("\n [bold]Bandwidth Tests[/bold]")
            for t in bw_tests:
                status = t.get("status", "SKIP")
                if status == "SKIP":
                    c.print(f" {t['test']}: [dim]SKIPPED[/dim]")
                    continue
                sc = RDMATest._status_color(status)
                bw = t.get("bandwidth_gbps", 0)
                c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
                        f"({bw:.2f} GB/s, min: {t.get('min_required_gbps', 'N/A')} GB/s)")

        lat_tests = results.get("latency_tests", [])
        if lat_tests:
            c.print("\n [bold]Latency Tests[/bold]")
            for t in lat_tests:
                status = t.get("status", "SKIP")
                if status == "SKIP":
                    c.print(f" {t['test']}: [dim]SKIPPED[/dim]")
                    continue
                sc = RDMATest._status_color(status)
                lat = t.get("latency_us", 0)
                c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
                        f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)")

        ibping_tests = results.get("ibping_tests", [])
        if ibping_tests:
            c.print("\n [bold]IB Ping Tests[/bold]")
            for t in ibping_tests:
                status = t.get("status", "FAIL")
                sc = "green" if status == "PASS" else "red"
                c.print(f" {t['test']}: [{sc}]{status}[/{sc}] target={t.get('target', 'N/A')}")