"""RDMA / InfiniBand bandwidth and latency test module."""

import os
import shlex
import shutil
import subprocess
import time
from datetime import datetime
from typing import Optional, List

from rich.console import Console
from rich.markup import escape
from rich.table import Table


class RDMATest:
    """Local (loopback) and cross-node RDMA health checks built on perftest.

    Device and port facts are read from ``/sys/class/infiniband``; the
    benchmarks shell out to the perftest tools (``ib_write_bw``,
    ``ib_read_bw``, ``ib_write_lat``, ``ib_read_lat``) found on ``PATH``.
    All tunables come from the ``rdma`` section of the config dict.
    """

    def __init__(self, config: dict):
        """Store the full config and cache its ``rdma`` section."""
        self.config = config
        self.console = Console()
        # `or {}` also covers an explicit `rdma: null` in the config.
        self.rdma_cfg = config.get("rdma") or {}

    def _find_tool(self, name: str) -> Optional[str]:
        """Return the absolute path of executable *name*, or None if absent."""
        return shutil.which(name) or None

    def _get_ib_devices(self) -> List[str]:
        """Return the sorted device names listed under /sys/class/infiniband."""
        ib_path = "/sys/class/infiniband"
        if os.path.isdir(ib_path):
            return sorted(os.listdir(ib_path))
        return []

    def _get_ib_ports(self, device: str) -> List[str]:
        """Return the sorted port entries of *device* (empty if none)."""
        ports_dir = f"/sys/class/infiniband/{device}/ports"
        if os.path.isdir(ports_dir):
            return sorted(os.listdir(ports_dir))
        return []

    @staticmethod
    def _port_path(device: str, port: str, attr: str) -> str:
        """Build the sysfs path of attribute *attr* for *device*/*port*."""
        return f"/sys/class/infiniband/{device}/ports/{port}/{attr}"

    @staticmethod
    def _read_sys(path: str) -> str:
        """Read and strip a sysfs file; return "" when it cannot be read."""
        try:
            with open(path) as f:
                return f.read().strip()
        except OSError:  # includes FileNotFoundError / PermissionError
            return ""

    def run(self) -> dict:
        """Run the local RDMA checks and, if enabled, the cross-node checks.

        Returns a SKIP dict (``status``/``skipped``/``reason``) when there is
        no usable InfiniBand port; otherwise a dict with ``passed``,
        ``devices``, ``bandwidth_tests``, ``latency_tests``, ``timestamp``
        and, when ``rdma.cross_node.enabled`` is set, ``cross_node``.
        """
        devices = self._get_ib_devices()
        if not devices:
            self.console.print(
                "[yellow]No InfiniBand devices found — skipping RDMA test[/yellow]"
            )
            return {
                "status": "SKIP",
                "skipped": True,
                "reason": "no IB hardware detected",
                "timestamp": datetime.now().isoformat(),
            }

        # Only consider ports whose link_layer is InfiniBand — Ethernet
        # bond/management interfaces (e.g. mlx5_bond_0) can show ACTIVE state
        # without actually providing IB fabric connectivity.
        ib_ports = []
        active_devices: List[str] = []
        for dev in devices:
            for port in self._get_ib_ports(dev):
                link_layer = self._read_sys(self._port_path(dev, port, "link_layer"))
                if link_layer != "InfiniBand":
                    continue
                ib_ports.append((dev, port))
                state = self._read_sys(self._port_path(dev, port, "state"))
                if "ACTIVE" in state.upper() and dev not in active_devices:
                    active_devices.append(dev)

        device_info = self._collect_device_info(devices)

        if not ib_ports:
            self.console.print(
                "[yellow]No InfiniBand-link_layer ports present — "
                "skipping RDMA benchmarks[/yellow]"
            )
            return {
                "status": "SKIP",
                "skipped": True,
                "reason": "no InfiniBand link_layer ports (only Ethernet/RoCE)",
                "devices": device_info,
                "timestamp": datetime.now().isoformat(),
            }

        if not active_devices:
            self.console.print(
                f"[yellow]{len(ib_ports)} IB port(s) detected but all DOWN — "
                f"fabric not wired, skipping RDMA benchmarks[/yellow]"
            )
            return {
                "status": "SKIP",
                "skipped": True,
                "reason": f"{len(ib_ports)} IB port(s) found but all DOWN (fabric not wired)",
                "devices": device_info,
                "timestamp": datetime.now().isoformat(),
            }

        # Benchmark only devices that have an ACTIVE InfiniBand port; the raw
        # sysfs list may start with an Ethernet/RoCE device that was filtered
        # out above and must not become the default test device.
        self.console.print(
            f"[cyan]RDMA Test - Devices: {', '.join(active_devices)}[/cyan]")
        bw_results = self._run_bandwidth_tests(active_devices)
        latency_results = self._run_latency_tests(active_devices)

        all_passed = all(
            r.get("status") == "PASS"
            for r in bw_results + latency_results
            if isinstance(r, dict)
        )

        result = {
            "passed": all_passed,
            "devices": device_info,
            "bandwidth_tests": bw_results,
            "latency_tests": latency_results,
            "timestamp": datetime.now().isoformat(),
        }

        # Cross-node (two-host) RDMA, run only when a peer is configured.
        if (self.rdma_cfg.get("cross_node", {}) or {}).get("enabled"):
            result["cross_node"] = self.run_cross_node()

        return result

    def _collect_device_info(self, devices: List[str]) -> List[dict]:
        """Return per-device port details (rate, state, phys_state, gid 0).

        Attributes that cannot be read are reported as "N/A".
        """
        info = []
        for dev in devices:
            dev_info = {"name": dev, "ports": []}
            for port in self._get_ib_ports(dev):
                port_info = {"port": port}
                for label, attr in [("rate", "rate"), ("state", "state"),
                                    ("phys_state", "phys_state"), ("gid", "gids/0")]:
                    # _read_sys swallows every OSError, not only
                    # missing-file / permission errors.
                    value = self._read_sys(self._port_path(dev, port, attr))
                    port_info[label] = value or "N/A"
                dev_info["ports"].append(port_info)
            info.append(dev_info)
        return info

    def _run_ib_command(self, cmd: List[str], timeout: int = 60) -> dict:
        """Run *cmd* and map the outcome to a PASS/FAIL/SKIP status dict."""
        try:
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
            if r.returncode == 0:
                return {"status": "PASS", "output": r.stdout.strip()}
            return {"status": "FAIL", "error": r.stderr.strip()[:200]}
        except subprocess.TimeoutExpired:
            return {"status": "FAIL", "error": "timeout"}
        except FileNotFoundError:
            return {"status": "SKIP", "error": "tool not found"}
        except Exception as e:
            return {"status": "FAIL", "error": str(e)}

    def _run_loopback_perftest(self, cmd: List[str], timeout: int = 60):
        """Run a perftest server plus a client against ``localhost``.

        *cmd* is the server command line; the client is the same command with
        ``localhost`` appended. Returns ``(output, error)`` where *output* is
        the client stdout followed by the server stdout and *error* is "" when
        the client exited 0, else a short description. Exceptions (spawn
        failure, timeout) propagate to the caller after the server has been
        killed and reaped.
        """
        # The context manager closes the pipes and reaps the child on exit.
        with subprocess.Popen(cmd, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, text=True) as server:
            try:
                time.sleep(1)  # give the server time to start listening
                client = subprocess.run(cmd + ["localhost"], capture_output=True,
                                        text=True, timeout=timeout)
                # communicate() drains both pipes while waiting, so a chatty
                # server cannot block on a full pipe.
                server_out, _ = server.communicate(timeout=10)
            finally:
                if server.poll() is None:
                    server.kill()
        output = client.stdout + "\n" + (server_out or "")
        error = ""
        if client.returncode != 0:
            error = client.stderr.strip()[:200] or f"exit code {client.returncode}"
        return output, error

    def _run_bandwidth_tests(self, devices: List[str]) -> List[dict]:
        """Run ib_write_bw / ib_read_bw in loopback on the first device.

        The device is ``rdma.ib_device`` if configured, else ``devices[0]``.
        Each result dict carries ``test`` and ``status`` (PASS / WARN / FAIL /
        SKIP) and, when measured, ``bandwidth_gbps`` (average, Gb/s) and
        ``min_required_gbps``.
        """
        results = []
        min_bw = self.rdma_cfg.get("min_bandwidth_gbps", 50)
        msg_size = self.rdma_cfg.get("msg_size", 65536)
        iters = self.rdma_cfg.get("ib_iterations", 1000)
        port = self.rdma_cfg.get("ib_port", 1)
        device = self.rdma_cfg.get("ib_device", None) or (devices[0] if devices else None)

        for label in ("ib_write_bw", "ib_read_bw"):
            tool = self._find_tool(label)
            if not tool:
                results.append({"test": label, "status": "SKIP", "error": "not installed"})
                continue
            if not device:
                results.append({"test": label, "status": "SKIP",
                                "error": "no IB device available"})
                continue

            # --report_gbits makes perftest print Gb/s, the unit of the
            # min_bandwidth_gbps threshold (the default report is MB/s).
            cmd = [tool, "-d", device, "-i", str(port), "-s", str(msg_size),
                   "-n", str(iters), "--report_gbits"]
            try:
                output, error = self._run_loopback_perftest(cmd)
            except Exception as e:
                results.append({"test": label, "status": "FAIL", "error": str(e)})
                continue

            # Parse only real result rows; banner lines such as
            # "TX depth : 128" must not be mistaken for a measurement.
            bw_gbps = self._parse_perftest_bw(output)
            if bw_gbps <= 0:
                results.append({
                    "test": label,
                    "status": "FAIL",
                    "bandwidth_gbps": 0.0,
                    "min_required_gbps": min_bw,
                    "error": error or "no bandwidth result in perftest output",
                })
                continue
            results.append({
                "test": label,
                "status": "PASS" if bw_gbps >= min_bw else "WARN",
                "bandwidth_gbps": bw_gbps,
                "min_required_gbps": min_bw,
            })
        return results

    def _run_latency_tests(self, devices: List[str]) -> List[dict]:
        """Run ib_write_lat / ib_read_lat in loopback on the first device.

        The device is ``rdma.ib_device`` if configured, else ``devices[0]``.
        Each result dict carries ``test`` and ``status`` (PASS / WARN / FAIL /
        SKIP) and, when measured, ``latency_us`` (t_typical, µs) and
        ``max_allowed_us``.
        """
        results = []
        max_lat_us = self.rdma_cfg.get("max_latency_us", 10)
        port = self.rdma_cfg.get("ib_port", 1)
        device = self.rdma_cfg.get("ib_device", None) or (devices[0] if devices else None)

        for label in ("ib_write_lat", "ib_read_lat"):
            tool = self._find_tool(label)
            if not tool:
                results.append({"test": label, "status": "SKIP", "error": "not installed"})
                continue
            if not device:
                results.append({"test": label, "status": "SKIP",
                                "error": "no IB device available"})
                continue

            cmd = [tool, "-d", device, "-i", str(port)]
            try:
                output, error = self._run_loopback_perftest(cmd)
            except Exception as e:
                results.append({"test": label, "status": "FAIL", "error": str(e)})
                continue

            typical = self._parse_perftest_lat(output).get("typical")
            if not typical or typical <= 0:
                results.append({
                    "test": label,
                    "status": "FAIL",
                    "latency_us": 0.0,
                    "max_allowed_us": max_lat_us,
                    "error": error or "no latency result in perftest output",
                })
                continue
            results.append({
                "test": label,
                "status": "PASS" if typical <= max_lat_us else "WARN",
                "latency_us": typical,
                "max_allowed_us": max_lat_us,
            })
        return results

    # ------------------------------------------------------------------
    # Cross-node (two-host) RDMA over perftest, orchestrated via SSH.
    # Runs FROM the client host: for each IB device it launches the matching
    # perftest server on the peer over SSH (held open in a live ssh channel),
    # then runs the local client against the peer's OOB address and parses the
    # result. Replaces the old standalone scripts/rdma_cross_node.sh.
    # ------------------------------------------------------------------
    def _active_ib_devices(self) -> List[str]:
        """IB devices with at least one port that is InfiniBand link_layer and ACTIVE."""
        out = []
        for dev in self._get_ib_devices():
            for port in self._get_ib_ports(dev):
                ll = self._read_sys(self._port_path(dev, port, "link_layer"))
                st = self._read_sys(self._port_path(dev, port, "state"))
                if ll == "InfiniBand" and "ACTIVE" in st.upper():
                    out.append(dev)
                    break  # one qualifying port is enough for this device
        return out

    def run_cross_node(self) -> dict:
        """Run ib_write_bw / ib_write_lat from this host against a peer host.

        Driven by ``rdma.cross_node``. Returns a SKIP dict when the feature is
        disabled, no peer is set, no device qualifies, or perftest is missing.
        Otherwise returns ``status`` (FAIL if any measurement failed, else
        WARN if any missed its threshold, else PASS), the thresholds, and one
        ``per_device`` entry per tested device.
        """
        cn = self.rdma_cfg.get("cross_node", {}) or {}
        if not cn.get("enabled"):
            return {"status": "SKIP", "skipped": True,
                    "reason": "rdma.cross_node.enabled is false"}
        server = cn.get("server")
        if not server:
            return {"status": "SKIP", "skipped": True,
                    "reason": "rdma.cross_node.server (peer ssh address) not set"}

        ssh_user = cn.get("ssh_user", "root")
        server_target = server if "@" in server else f"{ssh_user}@{server}"
        # OOB address the client's perftest connects to (defaults to the ssh host).
        server_addr = cn.get("server_addr") or server.split("@")[-1]
        ib_port = cn.get("ib_port", 1)
        gid_index = cn.get("gid_index")
        msg_size = cn.get("msg_size", 1048576)
        iters = cn.get("iters", 5000)
        base_port = cn.get("base_oob_port", 18515)
        warmup = cn.get("server_warmup_sec", 2.0)
        min_bw = cn.get("min_bandwidth_gbps", 350)
        max_lat = cn.get("max_latency_us", 5)

        devices = cn.get("devices") or self._active_ib_devices()
        if not devices:
            return {"status": "SKIP", "skipped": True,
                    "reason": "no active InfiniBand devices to test"}

        has_bw = self._find_tool("ib_write_bw") is not None
        has_lat = self._find_tool("ib_write_lat") is not None
        if not has_bw and not has_lat:
            return {"status": "SKIP", "skipped": True,
                    "reason": "perftest (ib_write_bw / ib_write_lat) not installed"}

        self.console.print(
            f"[cyan]Cross-node RDMA — client → {server_addr}, "
            f"devices: {', '.join(devices)}[/cyan]")

        per_device = []
        for idx, dev in enumerate(devices):
            oob = base_port + idx  # distinct OOB TCP port per device
            entry = {"device": dev}
            if has_bw:
                bw = self._cross_node_perftest(
                    "ib_write_bw", dev, server_target, server_addr, ib_port, oob,
                    gid_index, warmup,
                    extra=["--report_gbits", "-s", str(msg_size), "-n", str(iters)],
                    parse="bw")
                entry["bandwidth_gbps"] = bw
                # A non-numeric value is an error string; a zero value means
                # no result row was parsed. Both are failures, not warnings.
                if isinstance(bw, (int, float)) and bw > 0:
                    entry["bw_status"] = "PASS" if bw >= min_bw else "WARN"
                else:
                    entry["bw_status"] = "FAIL"
            if has_lat:
                lat = self._cross_node_perftest(
                    "ib_write_lat", dev, server_target, server_addr, ib_port, oob,
                    gid_index, warmup, extra=[], parse="lat")
                if isinstance(lat, dict):
                    t = lat.get("typical")
                    entry["latency_us"] = t
                    entry["latency_p99_us"] = lat.get("p99")
                    if isinstance(t, (int, float)) and t > 0:
                        entry["lat_status"] = "PASS" if t <= max_lat else "WARN"
                    else:
                        entry["lat_status"] = "FAIL"
                else:
                    entry["latency_us"] = lat
                    entry["lat_status"] = "FAIL"
            per_device.append(entry)

        statuses = [e.get(k) for e in per_device
                    for k in ("bw_status", "lat_status") if e.get(k)]
        if "FAIL" in statuses:
            verdict = "FAIL"
        elif "WARN" in statuses:
            verdict = "WARN"
        else:
            verdict = "PASS"

        return {
            "status": verdict,
            "server": server_addr,
            "min_bandwidth_gbps": min_bw,
            "max_latency_us": max_lat,
            "per_device": per_device,
            "timestamp": datetime.now().isoformat(),
        }

    def _cross_node_perftest(self, tool: str, dev: str, server_target: str,
                             server_addr: str, ib_port: int, oob_port: int,
                             gid_index, warmup: float, extra: List[str], parse: str):
        """Start `tool` server on the peer via SSH, run the local client, parse output.

        Returns a float (bw, Gb/s), a dict {typical, p99} (lat, µs), or an
        error string (tool missing, timeout, spawn error, or a client that
        exited non-zero without producing a result row).
        """
        tool_path = self._find_tool(tool)
        if not tool_path:
            return f"{tool} not installed"

        flags = ["-d", dev, "-i", str(ib_port), "-p", str(oob_port), "-F"]
        if gid_index is not None:
            flags += ["-x", str(gid_index)]
        flags += extra

        ssh_base = ["ssh", "-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=no",
                    server_target]
        # ssh hands the command to the remote shell as one string, so every
        # word is quoted. Server side: no host argument.
        server_cmd = " ".join(shlex.quote(arg) for arg in [tool] + flags)

        server_proc = None
        try:
            server_proc = subprocess.Popen(
                ssh_base + [server_cmd],
                stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
            time.sleep(warmup)  # let the remote server bind before the client connects

            client = subprocess.run([tool_path] + flags + [server_addr],
                                    capture_output=True, text=True, timeout=120)
            out = client.stdout + "\n" + (client.stderr or "")
            if parse == "lat":
                parsed = self._parse_perftest_lat(out)
                empty = parsed.get("typical") is None
            else:
                parsed = self._parse_perftest_bw(out)
                empty = not parsed
            if empty and client.returncode != 0:
                # Surface why the client failed instead of a silent 0 / None.
                tail = (client.stderr or client.stdout).strip().splitlines()
                detail = f": {tail[-1][:200]}" if tail else ""
                return f"error: {tool} exited {client.returncode}{detail}"
            return parsed
        except subprocess.TimeoutExpired:
            return "timeout"
        except Exception as e:  # noqa: BLE001
            return f"error: {e}"
        finally:
            if server_proc is not None:
                if server_proc.poll() is None:
                    server_proc.terminate()
                    try:
                        server_proc.wait(timeout=5)
                    except subprocess.TimeoutExpired:
                        server_proc.kill()
                        server_proc.wait()  # reap so no zombie is left behind
                if server_proc.stdout:
                    server_proc.stdout.close()
            # ib_write_* server normally exits after one run; pkill cleans up a
            # leftover one if the client failed mid-handshake. -x matches the exact
            # process name so it never kills this ssh command itself.
            try:
                subprocess.run(ssh_base + [f"pkill -x {shlex.quote(tool)}"],
                               capture_output=True, timeout=10)
            except (subprocess.SubprocessError, OSError):
                pass  # best-effort cleanup; the measurement result still stands

    @staticmethod
    def _parse_perftest_bw(output: str) -> float:
        """Parse ib_write_bw rows (#bytes #iter BW_peak BW_avg ...); return max BW avg.

        Returns 0.0 when no result row is found.
        """
        best = 0.0
        for line in output.splitlines():
            parts = line.split()
            if len(parts) < 4:
                continue
            try:
                int(parts[0])  # #bytes column; rejects headers and banners
                best = max(best, float(parts[3]))  # BW average[Gb/sec]
            except ValueError:
                continue
        return round(best, 2)

    @staticmethod
    def _parse_perftest_lat(output: str) -> dict:
        """Parse ib_write_lat row (#bytes #iter t_min t_max t_typical t_avg ... 99%).

        Returns ``{"typical": float|None, "p99": float|None}`` for the first
        result row, with both values None when no row is found.
        """
        for line in output.splitlines():
            parts = line.split()
            if len(parts) < 6:
                continue
            try:
                int(parts[0])
                int(parts[1])
                typical = float(parts[4])  # t_typical[usec]
            except ValueError:
                continue
            p99 = None
            if len(parts) >= 8:
                try:
                    p99 = float(parts[7])  # 99% percentile[usec]
                except ValueError:
                    p99 = None
            return {"typical": round(typical, 2),
                    "p99": round(p99, 2) if p99 is not None else None}
        return {"typical": None, "p99": None}

    @staticmethod
    def _status_color(status: str) -> str:
        """Map PASS / WARN / anything else to green / yellow / red."""
        if status == "PASS":
            return "green"
        if status == "WARN":
            return "yellow"
        return "red"

    @staticmethod
    def print_results(results: dict, console: Optional[Console] = None):
        """Pretty-print a result dict produced by ``run()``.

        Free-form error/reason text is escaped so bracketed fragments such as
        ``[Errno 2]`` are shown literally instead of being read as markup.
        """
        c = console or Console()
        color_of = RDMATest._status_color

        if results.get("skipped") or results.get("status") == "SKIP":
            reason = escape(str(results.get('reason', 'no IB hardware')))
            c.print(f"\n[bold yellow]RDMA/InfiniBand: SKIPPED[/bold yellow] "
                    f"[dim]({reason})[/dim]")
            return
        if "error" in results:
            c.print(f"[bold red]Error: {escape(str(results['error']))}[/bold red]")
            return

        devices = results.get("devices", [])
        c.print(f"\n[bold cyan]RDMA/InfiniBand Test Results[/bold cyan]")
        c.print(f" Devices found: {len(devices)}")

        for dev in devices:
            c.print(f"\n [bold]{dev['name']}[/bold]")
            for p in dev.get("ports", []):
                state = p.get("state", "N/A")
                # sysfs reports e.g. "4: ACTIVE", so compare case-insensitively.
                color = "green" if "ACTIVE" in state.upper() else "red"
                c.print(f" Port {p['port']}: [{color}]{state}[/{color}] | "
                        f"Rate: {p.get('rate', 'N/A')} | GID: {p.get('gid', 'N/A')[:20]}")

        bw_tests = results.get("bandwidth_tests", [])
        if bw_tests:
            c.print("\n [bold]Bandwidth Tests[/bold]")
            for t in bw_tests:
                status = t.get("status", "SKIP")
                if status == "SKIP":
                    c.print(f" {t['test']}: [dim]SKIPPED[/dim]")
                    continue
                sc = color_of(status)
                if "error" in t:
                    c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
                            f"({escape(str(t['error']))})")
                    continue
                bw = t.get("bandwidth_gbps", 0)
                c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
                        f"({bw:.2f} Gb/s, min: {t.get('min_required_gbps', 'N/A')} Gb/s)")

        lat_tests = results.get("latency_tests", [])
        if lat_tests:
            c.print("\n [bold]Latency Tests[/bold]")
            for t in lat_tests:
                status = t.get("status", "SKIP")
                if status == "SKIP":
                    c.print(f" {t['test']}: [dim]SKIPPED[/dim]")
                    continue
                sc = color_of(status)
                if "error" in t:
                    c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
                            f"({escape(str(t['error']))})")
                    continue
                lat = t.get("latency_us", 0)
                c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
                        f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)")

        cn = results.get("cross_node")
        if cn:
            if cn.get("skipped"):
                c.print(f"\n [bold]Cross-node RDMA[/bold]: [dim]SKIPPED "
                        f"({escape(str(cn.get('reason', '')))})[/dim]")
            else:
                v = cn.get("status", "?")
                vc = color_of(v)
                c.print(f"\n [bold]Cross-node RDMA[/bold] (server {cn.get('server')}) "
                        f"[{vc}]{v}[/{vc}] "
                        f"[dim]min {cn.get('min_bandwidth_gbps')} Gb/s, "
                        f"max {cn.get('max_latency_us')} µs[/dim]")
                for e in cn.get("per_device", []):
                    bw = e.get("bandwidth_gbps")
                    lat = e.get("latency_us")
                    bc = color_of(e.get("bw_status", ""))
                    lc = color_of(e.get("lat_status", ""))
                    # Non-numeric values are error strings from the run.
                    bw_s = (f"{bw:.1f} Gb/s" if isinstance(bw, (int, float))
                            else escape(str(bw)))
                    lat_s = (f"{lat:.2f} µs" if isinstance(lat, (int, float))
                             else escape(str(lat)))
                    p99 = e.get("latency_p99_us")
                    p99_s = f", p99 {p99:.2f}" if isinstance(p99, (int, float)) else ""
                    c.print(f" {e['device']}: BW [{bc}]{bw_s}[/{bc}] | "
                            f"lat [{lc}]{lat_s}[/{lc}]{p99_s}")