add: stress test (gpu-burn) and RDMA/IB test modules

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
qinyusen 2026-04-25 17:23:57 +08:00
parent eac1438227
commit 1c6ba4809a
2 changed files with 438 additions and 0 deletions

240
modules/rdma_test.py Normal file
View File

@ -0,0 +1,240 @@
"""RDMA / InfiniBand bandwidth and latency test module."""
import os
import shutil
import subprocess
from datetime import datetime
from typing import Optional, List
from rich.console import Console
from rich.table import Table
class RDMATest:
    """Single-node RDMA / InfiniBand health check.

    Enumerates devices under ``/sys/class/infiniband``, records a few per-port
    sysfs attributes, and runs the perftest tools (``ib_write_bw``,
    ``ib_read_bw``, ``ib_write_lat``, ``ib_read_lat``) in loopback against
    ``localhost``. Settings come from the ``"rdma"`` section of the config
    dict passed to the constructor.
    """

    def __init__(self, config: dict):
        """Store the config, create a console, and cache the ``"rdma"`` section."""
        self.config = config
        self.console = Console()
        self.rdma_cfg = config.get("rdma", {})

    def _find_tool(self, name: str) -> Optional[str]:
        """Return the path of executable *name* on ``PATH``, or ``None``."""
        return shutil.which(name)

    def _get_ib_devices(self) -> List[str]:
        """Return the sorted device names found in ``/sys/class/infiniband``."""
        ib_path = "/sys/class/infiniband"
        if not os.path.isdir(ib_path):
            return []
        return sorted(os.listdir(ib_path))

    def _get_ib_ports(self, device: str) -> List[str]:
        """Return the sorted port entries listed under *device* in sysfs."""
        ports_dir = f"/sys/class/infiniband/{device}/ports"
        if not os.path.isdir(ports_dir):
            return []
        return sorted(os.listdir(ports_dir))

    def run(self) -> dict:
        """Run device discovery plus all bandwidth and latency tests.

        Returns:
            ``{"error": "no_ib_devices", "passed": False}`` when no device is
            present; otherwise a dict with ``passed``, ``devices``,
            ``bandwidth_tests``, ``latency_tests`` and ``timestamp``.
            ``passed`` is true only if every individual test reports
            ``"PASS"`` (so ``SKIP``, ``WARN`` and ``FAIL`` all count against it).
        """
        devices = self._get_ib_devices()
        if not devices:
            self.console.print("[yellow]No InfiniBand devices found[/yellow]")
            return {"error": "no_ib_devices", "passed": False}
        self.console.print(f"[cyan]RDMA Test - Devices: {', '.join(devices)}[/cyan]")
        device_info = self._collect_device_info(devices)
        bw_results = self._run_bandwidth_tests(devices)
        latency_results = self._run_latency_tests(devices)
        all_passed = all(
            r.get("status") == "PASS" for r in bw_results + latency_results
        )
        return {
            "passed": all_passed,
            "devices": device_info,
            "bandwidth_tests": bw_results,
            "latency_tests": latency_results,
            "timestamp": datetime.now().isoformat(),
        }

    def _collect_device_info(self, devices: List[str]) -> List[dict]:
        """Read ``rate``, ``state``, ``phys_state`` and GID 0 for every port.

        Attributes that cannot be read are reported as ``"N/A"``.
        """
        info = []
        for dev in devices:
            dev_info = {"name": dev, "ports": []}
            for port in self._get_ib_ports(dev):
                base = f"/sys/class/infiniband/{dev}/ports/{port}"
                port_info = {"port": port}
                for label, path in (
                    ("rate", f"{base}/rate"),
                    ("state", f"{base}/state"),
                    ("phys_state", f"{base}/phys_state"),
                    ("gid", f"{base}/gids/0"),
                ):
                    try:
                        with open(path) as f:
                            port_info[label] = f.read().strip()
                    except OSError:
                        # Covers missing files and permission errors as before,
                        # plus any other read error the sysfs entry may raise.
                        port_info[label] = "N/A"
                dev_info["ports"].append(port_info)
            info.append(dev_info)
        return info

    def _run_ib_command(self, cmd: List[str], timeout: int = 60) -> dict:
        """Run *cmd* and map the outcome to a ``status`` dict.

        Returns ``PASS`` with stripped stdout on exit code 0, ``FAIL`` with the
        first 200 characters of stderr otherwise, ``FAIL``/``"timeout"`` on
        timeout, and ``SKIP`` when the executable does not exist.
        """
        try:
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
        except subprocess.TimeoutExpired:
            return {"status": "FAIL", "error": "timeout"}
        except FileNotFoundError:
            return {"status": "SKIP", "error": "tool not found"}
        except Exception as e:
            return {"status": "FAIL", "error": str(e)}
        if r.returncode == 0:
            return {"status": "PASS", "output": r.stdout.strip()}
        return {"status": "FAIL", "error": r.stderr.strip()[:200]}

    def _run_loopback(self, tool: str, args: List[str], timeout: int = 60) -> str:
        """Run *tool* as server and as client against ``localhost``.

        Returns:
            Client stdout followed by server stdout.

        Raises:
            RuntimeError: the client exited with a non-zero status.
            subprocess.TimeoutExpired: client or server did not finish in time.
            OSError: the tool could not be started.
        """
        import time

        server_cmd = [tool, *args]
        server = subprocess.Popen(
            server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        )
        try:
            # Give the server a moment to start listening before the client connects.
            time.sleep(1)
            client = subprocess.run(
                server_cmd + ["localhost"],
                capture_output=True,
                text=True,
                timeout=timeout,
            )
            # communicate() drains both pipes, reaps the process and closes the pipes.
            server_out, _ = server.communicate(timeout=10)
        except BaseException:
            # Never leave a listening server (or a zombie) behind.
            server.kill()
            server.communicate()
            raise
        if client.returncode != 0:
            detail = client.stderr.strip()[:200]
            raise RuntimeError(
                detail or f"client exited with code {client.returncode}"
            )
        return client.stdout + (server_out or "")

    @staticmethod
    def _result_rows(output: str) -> List[List[float]]:
        """Return the purely numeric result rows of perftest output.

        A result row starts with two integers (``#bytes`` and ``#iterations``)
        and contains only numbers; banner lines such as ``TX depth : 128`` are
        therefore ignored.
        """
        rows = []
        for line in output.splitlines():
            tokens = line.split()
            if len(tokens) < 3 or not (tokens[0].isdigit() and tokens[1].isdigit()):
                continue
            try:
                rows.append([float(tok) for tok in tokens])
            except ValueError:
                continue
        return rows

    @staticmethod
    def _parse_bandwidth_mbps(output: str) -> float:
        """Return the highest "BW average" value found in *output*, or 0.0.

        Expects the perftest column order
        ``#bytes #iterations BW-peak BW-average [MsgRate]``.
        """
        averages = [row[3] for row in RDMATest._result_rows(output) if len(row) >= 4]
        return max(averages, default=0.0)

    @staticmethod
    def _parse_latency_us(output: str) -> float:
        """Return the highest average latency found in *output*, or 0.0.

        Expects the perftest column order
        ``#bytes #iterations t_min t_max t_typical [t_avg ...]``; ``t_avg`` is
        used when present, ``t_typical`` otherwise.
        """
        values = []
        for row in RDMATest._result_rows(output):
            if len(row) >= 6:
                values.append(row[5])
            elif len(row) == 5:
                values.append(row[4])
        return max(values, default=0.0)

    def _run_bandwidth_tests(self, devices: List[str]) -> List[dict]:
        """Run ``ib_write_bw`` and ``ib_read_bw`` in loopback.

        Uses ``ib_device`` (default: first entry of *devices*), ``ib_port``,
        ``msg_size`` and ``ib_iterations`` from the config. Each result is
        ``PASS`` when the measured value reaches ``min_bandwidth_gbps``,
        ``WARN`` when it is below, ``FAIL`` when the run failed or produced no
        measurement, and ``SKIP`` when the tool is not installed.
        """
        min_bw = self.rdma_cfg.get("min_bandwidth_gbps", 50)
        msg_size = self.rdma_cfg.get("msg_size", 65536)
        iters = self.rdma_cfg.get("ib_iterations", 1000)
        device = self.rdma_cfg.get("ib_device", None) or devices[0]
        port = self.rdma_cfg.get("ib_port", 1)
        args = ["-d", device, "-i", str(port), "-s", str(msg_size), "-n", str(iters)]

        results = []
        for label in ("ib_write_bw", "ib_read_bw"):
            tool = self._find_tool(label)
            if not tool:
                results.append({"test": label, "status": "SKIP", "error": "not installed"})
                continue
            try:
                output = self._run_loopback(tool, args)
            except Exception as e:
                results.append({"test": label, "status": "FAIL", "error": str(e)})
                continue
            # The value is divided by 1000 as in the original conversion
            # (tool reports MB/sec by default).
            bw_gbps = self._parse_bandwidth_mbps(output) / 1000
            entry = {
                "test": label,
                "status": "PASS" if bw_gbps >= min_bw else "WARN",
                "bandwidth_gbps": round(bw_gbps, 2),
                "min_required_gbps": min_bw,
            }
            if bw_gbps <= 0:
                # No result row at all is a failed run, not merely a slow one.
                entry["status"] = "FAIL"
                entry["error"] = "no bandwidth result in tool output"
            results.append(entry)
        return results

    def _run_latency_tests(self, devices: List[str]) -> List[dict]:
        """Run ``ib_write_lat`` and ``ib_read_lat`` in loopback.

        Each result is ``PASS`` when the measured latency is positive and at
        most ``max_latency_us``, ``WARN`` when it is higher, ``FAIL`` when the
        run failed or produced no measurement, and ``SKIP`` when the tool is
        not installed.
        """
        max_lat_us = self.rdma_cfg.get("max_latency_us", 10)
        device = self.rdma_cfg.get("ib_device", None) or devices[0]
        port = self.rdma_cfg.get("ib_port", 1)
        args = ["-d", device, "-i", str(port)]

        results = []
        for label in ("ib_write_lat", "ib_read_lat"):
            tool = self._find_tool(label)
            if not tool:
                results.append({"test": label, "status": "SKIP", "error": "not installed"})
                continue
            try:
                output = self._run_loopback(tool, args)
            except Exception as e:
                results.append({"test": label, "status": "FAIL", "error": str(e)})
                continue
            lat_us = self._parse_latency_us(output)
            if lat_us <= 0:
                status = "FAIL"
            elif lat_us <= max_lat_us:
                status = "PASS"
            else:
                status = "WARN"
            results.append({
                "test": label,
                "status": status,
                "latency_us": round(lat_us, 2),
                "max_allowed_us": max_lat_us,
            })
        return results

    @staticmethod
    def print_results(results: dict, console: "Optional[Console]" = None):
        """Pretty-print the dict returned by :meth:`run`.

        Args:
            results: result dict; if it has an ``"error"`` key only that error
                is printed.
            console: console to print to; a new ``Console`` is created when
                omitted.
        """
        c = console or Console()
        if "error" in results:
            c.print(f"[bold red]Error: {results['error']}[/bold red]")
            return

        def status_color(status: str) -> str:
            if status == "PASS":
                return "green"
            return "yellow" if status == "WARN" else "red"

        devices = results.get("devices", [])
        c.print("\n[bold cyan]RDMA/InfiniBand Test Results[/bold cyan]")
        c.print(f"  Devices found: {len(devices)}")
        for dev in devices:
            c.print(f"\n  [bold]{dev['name']}[/bold]")
            for p in dev.get("ports", []):
                state = p.get("state", "N/A")
                # Compare case-insensitively: the sysfs state file spells the
                # active state in upper case (e.g. "4: ACTIVE").
                color = "green" if "ACTIVE" in state.upper() else "red"
                c.print(f"    Port {p['port']}: [{color}]{state}[/{color}] | "
                        f"Rate: {p.get('rate', 'N/A')} | GID: {p.get('gid', 'N/A')[:20]}")

        bw_tests = results.get("bandwidth_tests", [])
        if bw_tests:
            c.print("\n  [bold]Bandwidth Tests[/bold]")
            for t in bw_tests:
                status = t.get("status", "SKIP")
                if status == "SKIP":
                    c.print(f"    {t['test']}: [dim]SKIPPED[/dim]")
                    continue
                sc = status_color(status)
                if "bandwidth_gbps" not in t and t.get("error"):
                    # A failed run has no measurement; show why it failed instead.
                    c.print(f"    {t['test']}: [{sc}]{status}[/{sc}] ({t['error']})")
                    continue
                bw = t.get("bandwidth_gbps", 0)
                c.print(f"    {t['test']}: [{sc}]{status}[/{sc}] "
                        f"({bw:.2f} GB/s, min: {t.get('min_required_gbps', 'N/A')} GB/s)")

        lat_tests = results.get("latency_tests", [])
        if lat_tests:
            c.print("\n  [bold]Latency Tests[/bold]")
            for t in lat_tests:
                status = t.get("status", "SKIP")
                if status == "SKIP":
                    c.print(f"    {t['test']}: [dim]SKIPPED[/dim]")
                    continue
                sc = status_color(status)
                if "latency_us" not in t and t.get("error"):
                    c.print(f"    {t['test']}: [{sc}]{status}[/{sc}] ({t['error']})")
                    continue
                lat = t.get("latency_us", 0)
                c.print(f"    {t['test']}: [{sc}]{status}[/{sc}] "
                        f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)")

198
modules/stress_test.py Normal file
View File

@ -0,0 +1,198 @@
"""GPU stress test module — wraps gpu-burn for long-running stability tests."""
import glob
import os
import shutil
import subprocess
import time
from datetime import datetime
from rich.console import Console
from rich.table import Table
from rich.live import Live
from rich.text import Text
class StressTest:
    """GPU stability stress test.

    Runs the external ``gpu_burn`` binary when it can be located and falls
    back to a PyTorch matmul loop otherwise. Settings come from the
    ``"stress"`` section of the config dict; the tool install directory comes
    from ``config["tools"]["install_dir"]``.
    """

    def __init__(self, config: dict):
        """Store the config, create a console, and cache the relevant settings."""
        self.config = config
        self.console = Console()
        self.stress_cfg = config.get("stress", {})
        self.tools_dir = config.get("tools", {}).get("install_dir", "/opt/h200-test-tools")

    def _find_gpu_burn(self) -> str:
        """Locate an executable ``gpu_burn``.

        Search order: ``PATH``, ``<tools_dir>/gpu-burn/gpu_burn``, then any
        ``gpu_burn`` below ``<tools_dir>/gpu-burn``. Returns ``""`` when none
        is found.
        """
        on_path = shutil.which("gpu_burn")
        if on_path:
            return on_path
        root = os.path.join(self.tools_dir, "gpu-burn")
        local = os.path.join(root, "gpu_burn")
        if os.path.isfile(local) and os.access(local, os.X_OK):
            return local
        for candidate in glob.glob(os.path.join(root, "**", "gpu_burn"), recursive=True):
            if os.access(candidate, os.X_OK):
                return candidate
        return ""

    def run(self) -> dict:
        """Run the stress test with the configured settings and return its result dict.

        Config keys (with defaults): ``duration_sec`` (60), ``use_doubles``
        (False), ``use_tensor_cores`` (True), ``memory_pct`` (90), ``gpus``
        ("all").
        """
        cfg = self.stress_cfg
        duration_sec = cfg.get("duration_sec", 60)
        use_doubles = cfg.get("use_doubles", False)
        use_tensor_cores = cfg.get("use_tensor_cores", True)
        memory_pct = cfg.get("memory_pct", 90)
        target_gpus = cfg.get("gpus", "all")
        gpu_burn = self._find_gpu_burn()
        if gpu_burn:
            return self._run_gpu_burn(
                gpu_burn, duration_sec, use_doubles, use_tensor_cores, target_gpus, memory_pct
            )
        self.console.print("[yellow]gpu_burn not found, falling back to PyTorch stress test[/yellow]")
        return self._run_pytorch_stress(duration_sec, memory_pct)

    @staticmethod
    def _build_gpu_burn_cmd(gpu_burn: str, duration: int, doubles: bool,
                            tensor_cores: bool, target_gpus, memory_pct=None) -> list:
        """Assemble the gpu_burn command line; the duration is always the last argument."""
        cmd = [gpu_burn]
        if doubles:
            cmd.append("-d")
        if tensor_cores:
            cmd.append("-tc")
        if memory_pct is not None:
            try:
                pct = int(memory_pct)
            except (TypeError, ValueError):
                pct = 0
            # Out-of-range values are ignored so the tool keeps its own default.
            if 0 < pct <= 100:
                cmd.extend(["-m", f"{pct}%"])
        if target_gpus != "all":
            cmd.extend(["-i", str(target_gpus)])
        cmd.append(str(duration))
        return cmd

    @staticmethod
    def _parse_gpu_burn_results(output: str) -> list:
        """Return the per-GPU verdict lines of gpu_burn output.

        A line qualifies when it mentions ``GPU`` and either contains
        PASS/FAIL (any case) or ends in ``: OK`` / ``: FAULTY``.
        """
        verdict_lines = []
        for raw in output.split("\n"):
            line = raw.strip()
            if "GPU" not in line:
                continue
            upper = line.upper()
            tail = upper.rpartition(":")[2].strip()
            if "PASS" in upper or "FAIL" in upper or tail in ("OK", "FAULTY"):
                verdict_lines.append(line)
        return verdict_lines

    def _run_gpu_burn(self, gpu_burn: str, duration: int,
                      doubles: bool, tensor_cores: bool, target_gpus: str,
                      memory_pct=None) -> dict:
        """Run gpu_burn for *duration* seconds and summarise its outcome.

        Args:
            gpu_burn: path of the gpu_burn executable.
            duration: test length in seconds; the process is given an extra
                120 seconds before it is treated as timed out.
            doubles: pass ``-d``.
            tensor_cores: pass ``-tc``.
            target_gpus: ``"all"`` or a GPU index passed via ``-i``.
            memory_pct: optional percentage passed as ``-m <pct>%``.

        Returns:
            A dict with ``source``, ``passed`` and ``timestamp``; on a
            completed run also ``duration_sec``, ``elapsed_sec``,
            ``gpu_results`` and ``raw_output_tail``; on failure ``error``.
            ``passed`` requires exit code 0 and no FAIL/FAULTY verdict line.
        """
        self.console.print(f"[cyan]GPU Stress Test via gpu-burn ({duration}s)[/cyan]")
        cmd = self._build_gpu_burn_cmd(
            gpu_burn, duration, doubles, tensor_cores, target_gpus, memory_pct
        )
        # NOTE(review): gpu-burn appears to load its compare.ptx kernel relative
        # to the working directory; run from the binary's directory when that
        # file sits next to it. Confirm against the installed build.
        bin_dir = os.path.dirname(os.path.abspath(gpu_burn))
        cwd = None
        if os.path.isfile(os.path.join(bin_dir, "compare.ptx")):
            cwd = bin_dir
            cmd[0] = os.path.abspath(gpu_burn)
        t0 = time.time()
        try:
            r = subprocess.run(cmd, capture_output=True, text=True,
                               timeout=duration + 120, cwd=cwd)
        except subprocess.TimeoutExpired:
            return {
                "source": "gpu-burn",
                "passed": False,
                "duration_sec": duration,
                "error": "timeout",
                "timestamp": datetime.now().isoformat(),
            }
        except Exception as e:
            return {
                "source": "gpu-burn",
                "passed": False,
                "error": str(e),
                "timestamp": datetime.now().isoformat(),
            }
        elapsed = round(time.time() - t0, 1)
        output = r.stdout + r.stderr
        gpu_results = self._parse_gpu_burn_results(output)
        any_bad = any(
            "FAIL" in line.upper() or "FAULTY" in line.upper() for line in gpu_results
        )
        return {
            "source": "gpu-burn",
            "passed": r.returncode == 0 and not any_bad,
            "duration_sec": duration,
            "elapsed_sec": elapsed,
            "gpu_results": gpu_results,
            "raw_output_tail": output[-500:] if output else "",
            "timestamp": datetime.now().isoformat(),
        }

    def _run_pytorch_stress(self, duration: int, memory_pct=90) -> dict:
        """Fallback stress test: hold GPU memory and run matmuls on every GPU.

        Args:
            duration: how long to keep issuing work, in seconds.
            memory_pct: share of each GPU's total memory to keep allocated
                while the test runs (clamped to 0-100).

        Returns:
            A result dict with ``source="pytorch"``, ``passed`` and
            ``timestamp``; ``error`` is ``"pytorch_not_available"`` when torch
            or CUDA is missing, or the exception text when a ``RuntimeError``
            occurs during the test (all GPUs are then marked ``FAIL``).
        """
        unavailable = {
            "source": "pytorch",
            "passed": False,
            "error": "pytorch_not_available",
        }
        try:
            import torch
        except ImportError:
            return dict(unavailable, timestamp=datetime.now().isoformat())
        if not torch.cuda.is_available():
            return dict(unavailable, timestamp=datetime.now().isoformat())

        gpu_count = torch.cuda.device_count()
        self.console.print(f"[cyan]PyTorch Stress Test ({duration}s, {gpu_count} GPUs)[/cyan]")
        try:
            fraction = min(max(float(memory_pct), 0.0), 100.0) / 100.0
        except (TypeError, ValueError):
            fraction = 0.9
        dim = 2048      # side length of the square matmul operands
        burst = 16      # matmuls queued per GPU between synchronisations
        gpu_status = {}
        ballast = {}    # large per-GPU buffers that keep memory occupied
        operands = {}   # per-GPU (a, b, out) matrices reused every iteration
        t0 = time.time()
        try:
            for i in range(gpu_count):
                dev = f"cuda:{i}"
                a = torch.randn(dim, dim, device=dev, dtype=torch.float32)
                b = torch.randn(dim, dim, device=dev, dtype=torch.float32)
                out = torch.empty(dim, dim, device=dev, dtype=torch.float32)
                operands[i] = (a, b, out)
                total_bytes = torch.cuda.get_device_properties(i).total_memory
                # Leave room for the three operand matrices inside the budget.
                nbytes = max(int(total_bytes * fraction) - 3 * dim * dim * 4, 0)
                ballast[i] = torch.zeros(nbytes, device=dev, dtype=torch.uint8)
            while time.time() - t0 < duration:
                # Queue work on every GPU first (launches are asynchronous),
                # then wait, so all devices are busy at the same time.
                for i in range(gpu_count):
                    a, b, out = operands[i]
                    for _ in range(burst):
                        # Fixed inputs and a reused output keep values bounded.
                        torch.matmul(a, b, out=out)
                for i in range(gpu_count):
                    torch.cuda.synchronize(i)
            for i in range(gpu_count):
                gpu_status[i] = "PASS"
        except RuntimeError as e:
            for i in range(gpu_count):
                gpu_status.setdefault(i, "FAIL")
            return {
                "source": "pytorch",
                "passed": False,
                "duration_sec": duration,
                "error": str(e),
                "gpu_status": gpu_status,
                "timestamp": datetime.now().isoformat(),
            }
        finally:
            operands.clear()
            ballast.clear()
            torch.cuda.empty_cache()
        elapsed = round(time.time() - t0, 1)
        return {
            "source": "pytorch",
            "passed": True,
            "duration_sec": duration,
            "elapsed_sec": elapsed,
            "gpu_status": gpu_status,
            "timestamp": datetime.now().isoformat(),
        }

    @staticmethod
    def print_results(results: dict, console: "Console" = None):
        """Pretty-print a result dict produced by :meth:`run`.

        When ``results`` has an ``"error"`` key the error is printed first; the
        verdict and per-GPU details follow only if per-GPU data is present.

        Args:
            results: result dict.
            console: console to print to; a new ``Console`` is created when
                omitted.
        """
        c = console or Console()
        gpu_results = results.get("gpu_results", [])
        gpu_status = results.get("gpu_status", {})
        if "error" in results:
            c.print(f"[bold red]Error: {results['error']}[/bold red]")
            if not (gpu_results or gpu_status):
                return
        passed = results.get("passed", False)
        source = results.get("source", "unknown")
        duration = results.get("duration_sec", "?")
        elapsed = results.get("elapsed_sec", "?")
        verdict = "[bold green]✓ Stress Test PASSED[/bold green]" if passed else "[bold red]✗ Stress Test FAILED[/bold red]"
        c.print(f"\n{verdict}  [dim](via {source})[/dim]")
        c.print(f"  Target duration: {duration}s | Actual: {elapsed}s")
        if gpu_results:
            c.print("\n  Per-GPU results:")
            for line in gpu_results:
                upper = line.upper()
                if "FAIL" in upper or "FAULTY" in upper:
                    c.print(f"    [red]{line}[/red]")
                else:
                    c.print(f"    [green]{line}[/green]")
        if gpu_status:
            c.print("\n  Per-GPU status:")
            for gid, status in sorted(gpu_status.items()):
                color = "green" if status == "PASS" else "red"
                c.print(f"    GPU {gid}: [{color}]{status}[/{color}]")