From 1c6ba4809ab5fc6fe0f1a8e58f7931b7051a733e Mon Sep 17 00:00:00 2001 From: qinyusen Date: Sat, 25 Apr 2026 17:23:57 +0800 Subject: [PATCH] add: stress test (gpu-burn) and RDMA/IB test modules Co-authored-by: Sisyphus --- modules/rdma_test.py | 240 +++++++++++++++++++++++++++++++++++++++++ modules/stress_test.py | 198 ++++++++++++++++++++++++++++++++++ 2 files changed, 438 insertions(+) create mode 100644 modules/rdma_test.py create mode 100644 modules/stress_test.py diff --git a/modules/rdma_test.py b/modules/rdma_test.py new file mode 100644 index 0000000..e1f54f5 --- /dev/null +++ b/modules/rdma_test.py @@ -0,0 +1,240 @@ +"""RDMA / InfiniBand bandwidth and latency test module.""" + +import os +import shutil +import subprocess +from datetime import datetime +from typing import Optional, List + +from rich.console import Console +from rich.table import Table + + +class RDMATest: + + def __init__(self, config: dict): + self.config = config + self.console = Console() + self.rdma_cfg = config.get("rdma", {}) + + def _find_tool(self, name: str) -> Optional[str]: + p = shutil.which(name) + if p: + return p + return None + + def _get_ib_devices(self) -> List[str]: + devices = [] + ib_path = "/sys/class/infiniband" + if os.path.isdir(ib_path): + devices = sorted(os.listdir(ib_path)) + return devices + + def _get_ib_ports(self, device: str) -> List[str]: + ports = [] + ports_dir = f"/sys/class/infiniband/{device}/ports" + if os.path.isdir(ports_dir): + ports = sorted(os.listdir(ports_dir)) + return ports + + def run(self) -> dict: + devices = self._get_ib_devices() + if not devices: + self.console.print("[yellow]No InfiniBand devices found[/yellow]") + return {"error": "no_ib_devices", "passed": False} + + self.console.print(f"[cyan]RDMA Test - Devices: {', '.join(devices)}[/cyan]") + + device_info = self._collect_device_info(devices) + bw_results = self._run_bandwidth_tests(devices) + latency_results = self._run_latency_tests(devices) + + all_passed = all( + r.get("status") == "PASS" + for r in bw_results + latency_results + if isinstance(r, dict) + ) + + return { + "passed": all_passed, + "devices": device_info, + "bandwidth_tests": bw_results, + "latency_tests": latency_results, + "timestamp": datetime.now().isoformat(), + } + + def _collect_device_info(self, devices: List[str]) -> List[dict]: + info = [] + for dev in devices: + dev_info = {"name": dev, "ports": []} + ports = self._get_ib_ports(dev) + for port in ports: + port_info = {"port": port} + rate_path = f"/sys/class/infiniband/{dev}/ports/{port}/rate" + state_path = f"/sys/class/infiniband/{dev}/ports/{port}/state" + phys_state_path = f"/sys/class/infiniband/{dev}/ports/{port}/phys_state" + gid_path = f"/sys/class/infiniband/{dev}/ports/{port}/gids/0" + + for label, path in [("rate", rate_path), ("state", state_path), + ("phys_state", phys_state_path), ("gid", gid_path)]: + try: + with open(path) as f: + port_info[label] = f.read().strip() + except (FileNotFoundError, PermissionError): + port_info[label] = "N/A" + + dev_info["ports"].append(port_info) + info.append(dev_info) + return info + + def _run_ib_command(self, cmd: List[str], timeout: int = 60) -> dict: + try: + r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout) + if r.returncode == 0: + return {"status": "PASS", "output": r.stdout.strip()} + return {"status": "FAIL", "error": r.stderr.strip()[:200]} + except subprocess.TimeoutExpired: + return {"status": "FAIL", "error": "timeout"} + except FileNotFoundError: + return {"status": "SKIP", "error": "tool not found"} + except Exception as e: + return {"status": "FAIL", "error": str(e)} + + def _run_bandwidth_tests(self, devices: List[str]) -> List[dict]: + results = [] + ib_write_bw = self._find_tool("ib_write_bw") + ib_read_bw = self._find_tool("ib_read_bw") + min_bw = self.rdma_cfg.get("min_bandwidth_gbps", 50) + msg_size = self.rdma_cfg.get("msg_size", 65536) + iters = self.rdma_cfg.get("ib_iterations", 1000) + dx = self.rdma_cfg.get("ib_device", None) + port = self.rdma_cfg.get("ib_port", 1) + + for tool, label in [(ib_write_bw, "ib_write_bw"), (ib_read_bw, "ib_read_bw")]: + if not tool: + results.append({"test": label, "status": "SKIP", "error": "not installed"}) + continue + + server_cmd = [tool, "-d", dx or devices[0], "-i", str(port), "-s", str(msg_size)] + client_cmd = server_cmd + ["localhost"] + + server = subprocess.Popen(server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + import time + time.sleep(1) + + try: + client = subprocess.run(client_cmd, capture_output=True, text=True, timeout=60) + server.wait(timeout=10) + + output = client.stdout + server.stdout.read() if server.stdout else "" + bw_mbps = 0 + for line in output.split("\n"): + line = line.strip() + if not line: + continue + parts = line.split() + try: + bw_mbps = max(bw_mbps, float(parts[-1])) + except (ValueError, IndexError): + continue + + bw_gbps = bw_mbps / 1000 if bw_mbps else 0 + status = "PASS" if bw_gbps >= min_bw else "WARN" + results.append({ + "test": label, + "status": status, + "bandwidth_gbps": round(bw_gbps, 2), + "min_required_gbps": min_bw, + }) + except Exception as e: + server.kill() + results.append({"test": label, "status": "FAIL", "error": str(e)}) + + return results + + def _run_latency_tests(self, devices: List[str]) -> List[dict]: + results = [] + ib_write_lat = self._find_tool("ib_write_lat") + ib_read_lat = self._find_tool("ib_read_lat") + max_lat_us = self.rdma_cfg.get("max_latency_us", 10) + dx = self.rdma_cfg.get("ib_device", None) + port = self.rdma_cfg.get("ib_port", 1) + + for tool, label in [(ib_write_lat, "ib_write_lat"), (ib_read_lat, "ib_read_lat")]: + if not tool: + results.append({"test": label, "status": "SKIP", "error": "not installed"}) + continue + + server_cmd = [tool, "-d", dx or devices[0], "-i", str(port)] + client_cmd = server_cmd + ["localhost"] + + server = subprocess.Popen(server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + import time + time.sleep(1) + + try: + client = subprocess.run(client_cmd, capture_output=True, text=True, timeout=60) + server.wait(timeout=10) + + output = client.stdout + server.stdout.read() if server.stdout else "" + lat_us = 0 + for line in output.split("\n"): + parts = line.strip().split() + try: + lat_us = max(lat_us, float(parts[-1])) + except (ValueError, IndexError): + continue + + status = "PASS" if 0 < lat_us <= max_lat_us else ("WARN" if lat_us > 0 else "FAIL") + results.append({ + "test": label, + "status": status, + "latency_us": round(lat_us, 2), + "max_allowed_us": max_lat_us, + }) + except Exception as e: + server.kill() + results.append({"test": label, "status": "FAIL", "error": str(e)}) + + return results + + @staticmethod + def print_results(results: dict, console: Console = None): + c = console or Console() + if "error" in results: + c.print(f"[bold red]Error: {results['error']}[/bold red]") + return + + devices = results.get("devices", []) + c.print(f"\n[bold cyan]RDMA/InfiniBand Test Results[/bold cyan]") + c.print(f" Devices found: {len(devices)}") + + for dev in devices: + c.print(f"\n [bold]{dev['name']}[/bold]") + for p in dev.get("ports", []): + state = p.get("state", "N/A") + color = "green" if "Active" in state else "red" + c.print(f" Port {p['port']}: [{color}]{state}[/{color}] | " + f"Rate: {p.get('rate', 'N/A')} | GID: {p.get('gid', 'N/A')[:20]}") + + bw_tests = results.get("bandwidth_tests", []) + if bw_tests: + c.print("\n [bold]Bandwidth Tests[/bold]") + for t in bw_tests: + status = t.get("status", "SKIP") + sc = "green" if status == "PASS" else ("yellow" if status == "WARN" else "red") + bw = t.get("bandwidth_gbps", 0) + c.print(f" {t['test']}: [{sc}]{status}[/{sc}] " + f"({bw:.2f} GB/s, min: {t.get('min_required_gbps', 'N/A')} GB/s)" if status != "SKIP" + else f" {t['test']}: [dim]SKIPPED[/dim]") + + lat_tests = results.get("latency_tests", []) + if lat_tests: + c.print("\n [bold]Latency Tests[/bold]") + for t in lat_tests: + status = t.get("status", "SKIP") + sc = "green" if status == "PASS" else ("yellow" if status == "WARN" else "red") + lat = t.get("latency_us", 0) + c.print(f" {t['test']}: [{sc}]{status}[/{sc}] " + f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)" if status != "SKIP" + else f" {t['test']}: [dim]SKIPPED[/dim]") diff --git a/modules/stress_test.py b/modules/stress_test.py new file mode 100644 index 0000000..8f04f1c --- /dev/null +++ b/modules/stress_test.py @@ -0,0 +1,198 @@ +"""GPU stress test module — wraps gpu-burn for long-running stability tests.""" + +import glob +import os +import shutil +import subprocess +import time +from datetime import datetime + +from rich.console import Console +from rich.table import Table +from rich.live import Live +from rich.text import Text + + +class StressTest: + + def __init__(self, config: dict): + self.config = config + self.console = Console() + self.stress_cfg = config.get("stress", {}) + self.tools_dir = config.get("tools", {}).get("install_dir", "/opt/h200-test-tools") + + def _find_gpu_burn(self) -> str: + p = shutil.which("gpu_burn") + if p: + return p + + local = os.path.join(self.tools_dir, "gpu-burn", "gpu_burn") + if os.path.isfile(local) and os.access(local, shutil.os.X_OK): + return local + + matches = glob.glob(os.path.join(self.tools_dir, "gpu-burn", "**", "gpu_burn"), recursive=True) + for m in matches: + if os.access(m, shutil.os.X_OK): + return m + return "" + + def run(self) -> dict: + cfg = self.stress_cfg + duration_sec = cfg.get("duration_sec", 60) + use_doubles = cfg.get("use_doubles", False) + use_tensor_cores = cfg.get("use_tensor_cores", True) + memory_pct = cfg.get("memory_pct", 90) + target_gpus = cfg.get("gpus", "all") + + gpu_burn = self._find_gpu_burn() + + if gpu_burn: + return self._run_gpu_burn(gpu_burn, duration_sec, use_doubles, use_tensor_cores, target_gpus) + + self.console.print("[yellow]gpu_burn not found, falling back to PyTorch stress test[/yellow]") + return self._run_pytorch_stress(duration_sec) + + def _run_gpu_burn(self, gpu_burn: str, duration: int, + doubles: bool, tensor_cores: bool, target_gpus: str) -> dict: + self.console.print(f"[cyan]GPU Stress Test via gpu-burn ({duration}s)[/cyan]") + + cmd = [gpu_burn] + if doubles: + cmd.append("-d") + if tensor_cores: + cmd.append("-tc") + if target_gpus != "all": + cmd.extend(["-i", str(target_gpus)]) + cmd.append(str(duration)) + + t0 = time.time() + try: + r = subprocess.run(cmd, capture_output=True, text=True, timeout=duration + 120) + elapsed = round(time.time() - t0, 1) + + output = r.stdout + r.stderr + passed = r.returncode == 0 + + gpu_results = [] + for line in output.split("\n"): + line = line.strip() + if "GPU" in line and ("PASS" in line.upper() or "FAIL" in line.upper()): + gpu_results.append(line) + + return { + "source": "gpu-burn", + "passed": passed, + "duration_sec": duration, + "elapsed_sec": elapsed, + "gpu_results": gpu_results, + "raw_output_tail": output[-500:] if output else "", + "timestamp": datetime.now().isoformat(), + } + + except subprocess.TimeoutExpired: + return { + "source": "gpu-burn", + "passed": False, + "duration_sec": duration, + "error": "timeout", + "timestamp": datetime.now().isoformat(), + } + except Exception as e: + return { + "source": "gpu-burn", + "passed": False, + "error": str(e), + "timestamp": datetime.now().isoformat(), + } + + def _run_pytorch_stress(self, duration: int) -> dict: + try: + import torch + if not torch.cuda.is_available(): + return {"error": "pytorch_not_available"} + except ImportError: + return {"error": "pytorch_not_available"} + + gpu_count = torch.cuda.device_count() + self.console.print(f"[cyan]PyTorch Stress Test ({duration}s, {gpu_count} GPUs)[/cyan]") + + gpu_status = {} + t0 = time.time() + + try: + tensors = {} + for i in range(gpu_count): + with torch.cuda.device(i): + total_mem = torch.cuda.get_device_properties(i).total_mem + alloc_size = int(total_mem * 0.9) // 4 + tensors[i] = torch.randn(alloc_size, device=f"cuda:{i}", dtype=torch.float32) + + while time.time() - t0 < duration: + for i in range(gpu_count): + with torch.cuda.device(i): + tensors[i] = torch.matmul(tensors[i][:2048, :2048], tensors[i][:2048, :2048].T) + torch.cuda.synchronize() + time.sleep(0.1) + + for i in range(gpu_count): + gpu_status[i] = "PASS" + + except RuntimeError as e: + for i in range(gpu_count): + if i not in gpu_status: + gpu_status[i] = "FAIL" + return { + "source": "pytorch", + "passed": False, + "duration_sec": duration, + "error": str(e), + "gpu_status": gpu_status, + } + finally: + tensors.clear() + torch.cuda.empty_cache() + + elapsed = round(time.time() - t0, 1) + return { + "source": "pytorch", + "passed": True, + "duration_sec": duration, + "elapsed_sec": elapsed, + "gpu_status": gpu_status, + "timestamp": datetime.now().isoformat(), + } + + @staticmethod + def print_results(results: dict, console: Console = None): + c = console or Console() + if "error" in results: + c.print(f"[bold red]Error: {results['error']}[/bold red]") + return + + passed = results.get("passed", False) + source = results.get("source", "unknown") + duration = results.get("duration_sec", "?") + elapsed = results.get("elapsed_sec", "?") + + verdict = "[bold green]✓ Stress Test PASSED[/bold green]" if passed else "[bold red]✗ Stress Test FAILED[/bold red]" + c.print(f"\n{verdict} [dim](via {source})[/dim]") + c.print(f" Target duration: {duration}s | Actual: {elapsed}s") + + gpu_results = results.get("gpu_results", []) + if gpu_results: + c.print("\n Per-GPU results:") + for line in gpu_results: + if "FAIL" in line.upper(): + c.print(f" [red]{line}[/red]") + else: + c.print(f" [green]{line}[/green]") + + gpu_status = results.get("gpu_status", {}) + if gpu_status: + c.print("\n Per-GPU status:") + for gid, status in sorted(gpu_status.items()): + color = "green" if status == "PASS" else "red" + c.print(f" GPU {gid}: [{color}]{status}[/{color}]") + + if results.get("error"): + c.print(f" [red]Error: {results['error']}[/red]")