add: stress test (gpu-burn) and RDMA/IB test modules
Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
parent
eac1438227
commit
1c6ba4809a
240
modules/rdma_test.py
Normal file
240
modules/rdma_test.py
Normal file
@ -0,0 +1,240 @@
|
||||
"""RDMA / InfiniBand bandwidth and latency test module."""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
|
||||
class RDMATest:
|
||||
|
||||
def __init__(self, config: dict):
    """Store the configuration and extract its ``rdma`` section.

    Args:
        config: Top-level configuration mapping. Its optional ``"rdma"``
            entry holds the RDMA-specific settings; an empty dict is used
            when the entry is absent.
    """
    rdma_section = config.get("rdma", {})
    self.config = config
    self.rdma_cfg = rdma_section
    self.console = Console()
|
||||
|
||||
def _find_tool(self, name: str) -> Optional[str]:
|
||||
p = shutil.which(name)
|
||||
if p:
|
||||
return p
|
||||
return None
|
||||
|
||||
def _get_ib_devices(self) -> List[str]:
|
||||
devices = []
|
||||
ib_path = "/sys/class/infiniband"
|
||||
if os.path.isdir(ib_path):
|
||||
devices = sorted(os.listdir(ib_path))
|
||||
return devices
|
||||
|
||||
def _get_ib_ports(self, device: str) -> List[str]:
|
||||
ports = []
|
||||
ports_dir = f"/sys/class/infiniband/{device}/ports"
|
||||
if os.path.isdir(ports_dir):
|
||||
ports = sorted(os.listdir(ports_dir))
|
||||
return ports
|
||||
|
||||
def run(self) -> dict:
    """Run the full RDMA check: device inventory, bandwidth, then latency.

    Returns:
        ``{"error": "no_ib_devices", "passed": False}`` when no device is
        found. Otherwise a dict with ``passed``, ``devices``,
        ``bandwidth_tests``, ``latency_tests`` and an ISO ``timestamp``.
        ``passed`` is true only when every test entry has status
        ``"PASS"``.
    """
    found = self._get_ib_devices()
    if not found:
        self.console.print("[yellow]No InfiniBand devices found[/yellow]")
        return {"error": "no_ib_devices", "passed": False}

    self.console.print(f"[cyan]RDMA Test - Devices: {', '.join(found)}[/cyan]")

    inventory = self._collect_device_info(found)
    bandwidth = self._run_bandwidth_tests(found)
    latency = self._run_latency_tests(found)

    # Any status other than PASS (WARN, FAIL, SKIP) makes the run fail.
    verdict = True
    for entry in bandwidth + latency:
        if isinstance(entry, dict) and entry.get("status") != "PASS":
            verdict = False
            break

    report = {"passed": verdict}
    report["devices"] = inventory
    report["bandwidth_tests"] = bandwidth
    report["latency_tests"] = latency
    report["timestamp"] = datetime.now().isoformat()
    return report
|
||||
|
||||
def _collect_device_info(self, devices: List[str]) -> List[dict]:
|
||||
info = []
|
||||
for dev in devices:
|
||||
dev_info = {"name": dev, "ports": []}
|
||||
ports = self._get_ib_ports(dev)
|
||||
for port in ports:
|
||||
port_info = {"port": port}
|
||||
rate_path = f"/sys/class/infiniband/{dev}/ports/{port}/rate"
|
||||
state_path = f"/sys/class/infiniband/{dev}/ports/{port}/state"
|
||||
phys_state_path = f"/sys/class/infiniband/{dev}/ports/{port}/phys_state"
|
||||
gid_path = f"/sys/class/infiniband/{dev}/ports/{port}/gids/0"
|
||||
|
||||
for label, path in [("rate", rate_path), ("state", state_path),
|
||||
("phys_state", phys_state_path), ("gid", gid_path)]:
|
||||
try:
|
||||
with open(path) as f:
|
||||
port_info[label] = f.read().strip()
|
||||
except (FileNotFoundError, PermissionError):
|
||||
port_info[label] = "N/A"
|
||||
|
||||
dev_info["ports"].append(port_info)
|
||||
info.append(dev_info)
|
||||
return info
|
||||
|
||||
def _run_ib_command(self, cmd: List[str], timeout: int = 60) -> dict:
|
||||
try:
|
||||
r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
|
||||
if r.returncode == 0:
|
||||
return {"status": "PASS", "output": r.stdout.strip()}
|
||||
return {"status": "FAIL", "error": r.stderr.strip()[:200]}
|
||||
except subprocess.TimeoutExpired:
|
||||
return {"status": "FAIL", "error": "timeout"}
|
||||
except FileNotFoundError:
|
||||
return {"status": "SKIP", "error": "tool not found"}
|
||||
except Exception as e:
|
||||
return {"status": "FAIL", "error": str(e)}
|
||||
|
||||
def _run_bandwidth_tests(self, devices: List[str]) -> List[dict]:
|
||||
results = []
|
||||
ib_write_bw = self._find_tool("ib_write_bw")
|
||||
ib_read_bw = self._find_tool("ib_read_bw")
|
||||
min_bw = self.rdma_cfg.get("min_bandwidth_gbps", 50)
|
||||
msg_size = self.rdma_cfg.get("msg_size", 65536)
|
||||
iters = self.rdma_cfg.get("ib_iterations", 1000)
|
||||
dx = self.rdma_cfg.get("ib_device", None)
|
||||
port = self.rdma_cfg.get("ib_port", 1)
|
||||
|
||||
for tool, label in [(ib_write_bw, "ib_write_bw"), (ib_read_bw, "ib_read_bw")]:
|
||||
if not tool:
|
||||
results.append({"test": label, "status": "SKIP", "error": "not installed"})
|
||||
continue
|
||||
|
||||
server_cmd = [tool, "-d", dx or devices[0], "-i", str(port), "-s", str(msg_size)]
|
||||
client_cmd = server_cmd + ["localhost"]
|
||||
|
||||
server = subprocess.Popen(server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
import time
|
||||
time.sleep(1)
|
||||
|
||||
try:
|
||||
client = subprocess.run(client_cmd, capture_output=True, text=True, timeout=60)
|
||||
server.wait(timeout=10)
|
||||
|
||||
output = client.stdout + server.stdout.read() if server.stdout else ""
|
||||
bw_mbps = 0
|
||||
for line in output.split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parts = line.split()
|
||||
try:
|
||||
bw_mbps = max(bw_mbps, float(parts[-1]))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
bw_gbps = bw_mbps / 1000 if bw_mbps else 0
|
||||
status = "PASS" if bw_gbps >= min_bw else "WARN"
|
||||
results.append({
|
||||
"test": label,
|
||||
"status": status,
|
||||
"bandwidth_gbps": round(bw_gbps, 2),
|
||||
"min_required_gbps": min_bw,
|
||||
})
|
||||
except Exception as e:
|
||||
server.kill()
|
||||
results.append({"test": label, "status": "FAIL", "error": str(e)})
|
||||
|
||||
return results
|
||||
|
||||
def _run_latency_tests(self, devices: List[str]) -> List[dict]:
|
||||
results = []
|
||||
ib_write_lat = self._find_tool("ib_write_lat")
|
||||
ib_read_lat = self._find_tool("ib_read_lat")
|
||||
max_lat_us = self.rdma_cfg.get("max_latency_us", 10)
|
||||
dx = self.rdma_cfg.get("ib_device", None)
|
||||
port = self.rdma_cfg.get("ib_port", 1)
|
||||
|
||||
for tool, label in [(ib_write_lat, "ib_write_lat"), (ib_read_lat, "ib_read_lat")]:
|
||||
if not tool:
|
||||
results.append({"test": label, "status": "SKIP", "error": "not installed"})
|
||||
continue
|
||||
|
||||
server_cmd = [tool, "-d", dx or devices[0], "-i", str(port)]
|
||||
client_cmd = server_cmd + ["localhost"]
|
||||
|
||||
server = subprocess.Popen(server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
import time
|
||||
time.sleep(1)
|
||||
|
||||
try:
|
||||
client = subprocess.run(client_cmd, capture_output=True, text=True, timeout=60)
|
||||
server.wait(timeout=10)
|
||||
|
||||
output = client.stdout + server.stdout.read() if server.stdout else ""
|
||||
lat_us = 0
|
||||
for line in output.split("\n"):
|
||||
parts = line.strip().split()
|
||||
try:
|
||||
lat_us = max(lat_us, float(parts[-1]))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
status = "PASS" if 0 < lat_us <= max_lat_us else ("WARN" if lat_us > 0 else "FAIL")
|
||||
results.append({
|
||||
"test": label,
|
||||
"status": status,
|
||||
"latency_us": round(lat_us, 2),
|
||||
"max_allowed_us": max_lat_us,
|
||||
})
|
||||
except Exception as e:
|
||||
server.kill()
|
||||
results.append({"test": label, "status": "FAIL", "error": str(e)})
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def print_results(results: dict, console: Console = None):
|
||||
c = console or Console()
|
||||
if "error" in results:
|
||||
c.print(f"[bold red]Error: {results['error']}[/bold red]")
|
||||
return
|
||||
|
||||
devices = results.get("devices", [])
|
||||
c.print(f"\n[bold cyan]RDMA/InfiniBand Test Results[/bold cyan]")
|
||||
c.print(f" Devices found: {len(devices)}")
|
||||
|
||||
for dev in devices:
|
||||
c.print(f"\n [bold]{dev['name']}[/bold]")
|
||||
for p in dev.get("ports", []):
|
||||
state = p.get("state", "N/A")
|
||||
color = "green" if "Active" in state else "red"
|
||||
c.print(f" Port {p['port']}: [{color}]{state}[/{color}] | "
|
||||
f"Rate: {p.get('rate', 'N/A')} | GID: {p.get('gid', 'N/A')[:20]}")
|
||||
|
||||
bw_tests = results.get("bandwidth_tests", [])
|
||||
if bw_tests:
|
||||
c.print("\n [bold]Bandwidth Tests[/bold]")
|
||||
for t in bw_tests:
|
||||
status = t.get("status", "SKIP")
|
||||
sc = "green" if status == "PASS" else ("yellow" if status == "WARN" else "red")
|
||||
bw = t.get("bandwidth_gbps", 0)
|
||||
c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
|
||||
f"({bw:.2f} GB/s, min: {t.get('min_required_gbps', 'N/A')} GB/s)" if status != "SKIP"
|
||||
else f" {t['test']}: [dim]SKIPPED[/dim]")
|
||||
|
||||
lat_tests = results.get("latency_tests", [])
|
||||
if lat_tests:
|
||||
c.print("\n [bold]Latency Tests[/bold]")
|
||||
for t in lat_tests:
|
||||
status = t.get("status", "SKIP")
|
||||
sc = "green" if status == "PASS" else ("yellow" if status == "WARN" else "red")
|
||||
lat = t.get("latency_us", 0)
|
||||
c.print(f" {t['test']}: [{sc}]{status}[/{sc}] "
|
||||
f"({lat:.2f} us, max: {t.get('max_allowed_us', 'N/A')} us)" if status != "SKIP"
|
||||
else f" {t['test']}: [dim]SKIPPED[/dim]")
|
||||
198
modules/stress_test.py
Normal file
198
modules/stress_test.py
Normal file
@ -0,0 +1,198 @@
|
||||
"""GPU stress test module — wraps gpu-burn for long-running stability tests."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.live import Live
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
class StressTest:
|
||||
|
||||
def __init__(self, config: dict):
    """Store the configuration and extract the stress-test settings.

    Args:
        config: Top-level configuration mapping. ``"stress"`` holds the
            stress-test settings (empty dict when absent) and
            ``"tools"["install_dir"]`` names the tool directory, which
            defaults to ``/opt/h200-test-tools``.
    """
    tools_section = config.get("tools", {})
    stress_section = config.get("stress", {})
    self.config = config
    self.stress_cfg = stress_section
    self.tools_dir = tools_section.get("install_dir", "/opt/h200-test-tools")
    self.console = Console()
|
||||
|
||||
def _find_gpu_burn(self) -> str:
|
||||
p = shutil.which("gpu_burn")
|
||||
if p:
|
||||
return p
|
||||
|
||||
local = os.path.join(self.tools_dir, "gpu-burn", "gpu_burn")
|
||||
if os.path.isfile(local) and os.access(local, shutil.os.X_OK):
|
||||
return local
|
||||
|
||||
matches = glob.glob(os.path.join(self.tools_dir, "gpu-burn", "**", "gpu_burn"), recursive=True)
|
||||
for m in matches:
|
||||
if os.access(m, shutil.os.X_OK):
|
||||
return m
|
||||
return ""
|
||||
|
||||
def run(self) -> dict:
    """Run the stress test, using gpu-burn when it can be found.

    Settings read from ``self.stress_cfg``: ``duration_sec`` (default 60),
    ``use_doubles`` (False), ``use_tensor_cores`` (True) and ``gpus``
    (``"all"``).

    Returns:
        The result dict of ``_run_gpu_burn``, or of ``_run_pytorch_stress``
        when no gpu_burn executable is found.
    """
    settings = self.stress_cfg
    seconds = settings.get("duration_sec", 60)
    doubles = settings.get("use_doubles", False)
    tensor_cores = settings.get("use_tensor_cores", True)
    gpus = settings.get("gpus", "all")

    binary = self._find_gpu_burn()
    if not binary:
        self.console.print("[yellow]gpu_burn not found, falling back to PyTorch stress test[/yellow]")
        return self._run_pytorch_stress(seconds)

    return self._run_gpu_burn(binary, seconds, doubles, tensor_cores, gpus)
|
||||
|
||||
def _run_gpu_burn(self, gpu_burn: str, duration: int,
|
||||
doubles: bool, tensor_cores: bool, target_gpus: str) -> dict:
|
||||
self.console.print(f"[cyan]GPU Stress Test via gpu-burn ({duration}s)[/cyan]")
|
||||
|
||||
cmd = [gpu_burn]
|
||||
if doubles:
|
||||
cmd.append("-d")
|
||||
if tensor_cores:
|
||||
cmd.append("-tc")
|
||||
if target_gpus != "all":
|
||||
cmd.extend(["-i", str(target_gpus)])
|
||||
cmd.append(str(duration))
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
r = subprocess.run(cmd, capture_output=True, text=True, timeout=duration + 120)
|
||||
elapsed = round(time.time() - t0, 1)
|
||||
|
||||
output = r.stdout + r.stderr
|
||||
passed = r.returncode == 0
|
||||
|
||||
gpu_results = []
|
||||
for line in output.split("\n"):
|
||||
line = line.strip()
|
||||
if "GPU" in line and ("PASS" in line.upper() or "FAIL" in line.upper()):
|
||||
gpu_results.append(line)
|
||||
|
||||
return {
|
||||
"source": "gpu-burn",
|
||||
"passed": passed,
|
||||
"duration_sec": duration,
|
||||
"elapsed_sec": elapsed,
|
||||
"gpu_results": gpu_results,
|
||||
"raw_output_tail": output[-500:] if output else "",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"source": "gpu-burn",
|
||||
"passed": False,
|
||||
"duration_sec": duration,
|
||||
"error": "timeout",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"source": "gpu-burn",
|
||||
"passed": False,
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
def _run_pytorch_stress(self, duration: int) -> dict:
|
||||
try:
|
||||
import torch
|
||||
if not torch.cuda.is_available():
|
||||
return {"error": "pytorch_not_available"}
|
||||
except ImportError:
|
||||
return {"error": "pytorch_not_available"}
|
||||
|
||||
gpu_count = torch.cuda.device_count()
|
||||
self.console.print(f"[cyan]PyTorch Stress Test ({duration}s, {gpu_count} GPUs)[/cyan]")
|
||||
|
||||
gpu_status = {}
|
||||
t0 = time.time()
|
||||
|
||||
try:
|
||||
tensors = {}
|
||||
for i in range(gpu_count):
|
||||
with torch.cuda.device(i):
|
||||
total_mem = torch.cuda.get_device_properties(i).total_mem
|
||||
alloc_size = int(total_mem * 0.9) // 4
|
||||
tensors[i] = torch.randn(alloc_size, device=f"cuda:{i}", dtype=torch.float32)
|
||||
|
||||
while time.time() - t0 < duration:
|
||||
for i in range(gpu_count):
|
||||
with torch.cuda.device(i):
|
||||
tensors[i] = torch.matmul(tensors[i][:2048, :2048], tensors[i][:2048, :2048].T)
|
||||
torch.cuda.synchronize()
|
||||
time.sleep(0.1)
|
||||
|
||||
for i in range(gpu_count):
|
||||
gpu_status[i] = "PASS"
|
||||
|
||||
except RuntimeError as e:
|
||||
for i in range(gpu_count):
|
||||
if i not in gpu_status:
|
||||
gpu_status[i] = "FAIL"
|
||||
return {
|
||||
"source": "pytorch",
|
||||
"passed": False,
|
||||
"duration_sec": duration,
|
||||
"error": str(e),
|
||||
"gpu_status": gpu_status,
|
||||
}
|
||||
finally:
|
||||
tensors.clear()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
elapsed = round(time.time() - t0, 1)
|
||||
return {
|
||||
"source": "pytorch",
|
||||
"passed": True,
|
||||
"duration_sec": duration,
|
||||
"elapsed_sec": elapsed,
|
||||
"gpu_status": gpu_status,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def print_results(results: dict, console: Console = None):
|
||||
c = console or Console()
|
||||
if "error" in results:
|
||||
c.print(f"[bold red]Error: {results['error']}[/bold red]")
|
||||
return
|
||||
|
||||
passed = results.get("passed", False)
|
||||
source = results.get("source", "unknown")
|
||||
duration = results.get("duration_sec", "?")
|
||||
elapsed = results.get("elapsed_sec", "?")
|
||||
|
||||
verdict = "[bold green]✓ Stress Test PASSED[/bold green]" if passed else "[bold red]✗ Stress Test FAILED[/bold red]"
|
||||
c.print(f"\n{verdict} [dim](via {source})[/dim]")
|
||||
c.print(f" Target duration: {duration}s | Actual: {elapsed}s")
|
||||
|
||||
gpu_results = results.get("gpu_results", [])
|
||||
if gpu_results:
|
||||
c.print("\n Per-GPU results:")
|
||||
for line in gpu_results:
|
||||
if "FAIL" in line.upper():
|
||||
c.print(f" [red]{line}[/red]")
|
||||
else:
|
||||
c.print(f" [green]{line}[/green]")
|
||||
|
||||
gpu_status = results.get("gpu_status", {})
|
||||
if gpu_status:
|
||||
c.print("\n Per-GPU status:")
|
||||
for gid, status in sorted(gpu_status.items()):
|
||||
color = "green" if status == "PASS" else "red"
|
||||
c.print(f" GPU {gid}: [{color}]{status}[/{color}]")
|
||||
|
||||
if results.get("error"):
|
||||
c.print(f" [red]Error: {results['error']}[/red]")
|
||||
Loading…
x
Reference in New Issue
Block a user