# test_gpu_scripts/modules/stress_test.py
#
# 488 lines
# 21 KiB
# Python

"""GPU stress test module — gpu-burn or PyTorch GEMM with telemetry."""
import glob
import os
import re
import shutil
import subprocess
import threading
import time
from datetime import datetime

from rich.console import Console
from rich.live import Live
from rich.table import Table
from rich.text import Text

from modules.gpu_specs import resolve_tools_dir
class StressTest:
    """GPU stress test driver: gpu-burn or a PyTorch GEMM loop, graded by telemetry.

    Settings are read from the ``stress`` section of ``config``. Every run
    samples ``nvidia-smi`` in a background thread and diffs ``dmesg`` XID lines
    taken before and after the run; :meth:`_evaluate_telemetry` turns that data
    into a pass/fail verdict.
    """

    # NVIDIA driver XID reports contain "NVRM: Xid ...". Matching "xid" as a
    # whole word (case-insensitive) avoids counting unrelated words that merely
    # contain those letters (e.g. "oxide").
    _XID_PATTERN = re.compile(r"\bxid\b", re.IGNORECASE)

    # Per-GPU end-of-run summary line of gpu-burn, assumed to read "GPU <n>: OK"
    # or "GPU <n>: FAULTY". NOTE(review): confirm against the gpu-burn build in use.
    _GPU_BURN_SUMMARY = re.compile(r"\bGPU\s*\d+\s*:\s*(OK|FAULTY)\b", re.IGNORECASE)

    def __init__(self, config: dict):
        """Store ``config`` and resolve the tools directory that is searched for gpu-burn."""
        self.config = config
        self.console = Console()
        # Stress-test settings; every key read from it has a default below.
        self.stress_cfg = config.get("stress", {})
        self.tools_dir = resolve_tools_dir(config)

    def _find_gpu_burn(self) -> str:
        """Return the path of an executable ``gpu_burn`` file, or ``""`` if none is found.

        Search order: ``$PATH``, then ``<tools_dir>/gpu-burn/gpu_burn``, then any
        ``gpu_burn`` file anywhere below ``<tools_dir>/gpu-burn``.
        """
        on_path = shutil.which("gpu_burn")
        if on_path:
            return on_path
        burn_dir = os.path.join(self.tools_dir, "gpu-burn")
        local = os.path.join(burn_dir, "gpu_burn")
        if os.path.isfile(local) and os.access(local, os.X_OK):
            return local
        for match in glob.glob(os.path.join(burn_dir, "**", "gpu_burn"), recursive=True):
            # Directories are "executable" too (search permission); only accept files.
            if os.path.isfile(match) and os.access(match, os.X_OK):
                return match
        return ""

    def _telemetry_interval(self) -> float:
        """Return ``stress.telemetry_interval_sec`` in seconds, floored at 0.1 s.

        The floor keeps a zero or sub-second setting from turning the sampler
        into a busy loop of back-to-back ``nvidia-smi`` calls.
        """
        return max(0.1, float(self.stress_cfg.get("telemetry_interval_sec", 1)))

    @staticmethod
    def _stop_sampler(sampler: threading.Thread, stop_event: threading.Event,
                      interval: float) -> None:
        """Signal the telemetry sampler to stop and wait for it to finish its current sample.

        The join timeout covers one interval plus the 10 s timeout of the
        ``nvidia-smi`` call in :meth:`_sample_telemetry`, so the telemetry list
        is normally final by the time it is graded.
        """
        stop_event.set()
        if sampler.is_alive():  # also False for a thread that never started
            sampler.join(timeout=interval + 11)

    def run(self) -> dict:
        """Run the configured stress test and return its result dictionary.

        gpu-burn is used only when ``stress.use_gpu_burn`` is true and a binary
        is found. If gpu-burn does not pass and reports less than half of
        ``duration_sec`` elapsed (a timeout or launch error reports no elapsed
        time at all), the PyTorch GEMM test is run instead.
        """
        cfg = self.stress_cfg
        duration_sec = cfg.get("duration_sec", 60)
        memory_pct = cfg.get("memory_pct", 90)
        gpu_burn = self._find_gpu_burn() if cfg.get("use_gpu_burn", False) else ""
        if gpu_burn:
            result = self._run_gpu_burn(
                gpu_burn,
                duration_sec,
                cfg.get("use_doubles", False),
                cfg.get("use_tensor_cores", True),
                cfg.get("gpus", "all"),
            )
            # An early failure (e.g. out of memory) falls back to PyTorch, whose
            # allocations are sized from the memory that is actually free.
            if not result.get("passed") and result.get("elapsed_sec", 0) < duration_sec * 0.5:
                self.console.print("\n[yellow]gpu-burn exited early (possible OOM), switching to PyTorch stress test[/yellow]")
                self.console.print("[dim]PyTorch mode dynamically adapts to available memory[/dim]\n")
                return self._run_pytorch_stress(duration_sec, memory_pct)
            return result
        self.console.print("[yellow]Using PyTorch stress test[/yellow]")
        return self._run_pytorch_stress(duration_sec, memory_pct)

    def _run_gpu_burn(self, gpu_burn: str, duration: int,
                      doubles: bool, tensor_cores: bool, target_gpus: str) -> dict:
        """Run gpu-burn for ``duration`` seconds while sampling telemetry.

        Args:
            gpu_burn: Path to the gpu-burn executable.
            duration: Requested burn time in seconds, passed as the last argument.
            doubles: When true, pass ``-d``.
            tensor_cores: When true, pass ``-tc``.
            target_gpus: ``"all"``, or a value passed to gpu-burn via ``-i``.

        Returns:
            A result dict with ``source="gpu-burn"``. ``passed`` requires exit
            code 0, no failing/faulty per-GPU summary line, and passing
            telemetry. On a timeout or launch error the dict carries ``error``
            and no ``elapsed_sec``, which makes :meth:`run` fall back to PyTorch.
        """
        self.console.print(f"[cyan]GPU Stress Test via gpu-burn ({duration}s)[/cyan]")
        cmd = [gpu_burn]
        if doubles:
            cmd.append("-d")
        if tensor_cores:
            cmd.append("-tc")
        if target_gpus != "all":
            cmd.extend(["-i", str(target_gpus)])
        cmd.append(str(duration))

        t0 = time.time()
        xid_before = self._collect_xid_events()
        interval = self._telemetry_interval()
        telemetry: list = []
        stop_sampling = threading.Event()
        sampler = threading.Thread(
            target=self._sample_telemetry,
            args=(telemetry, stop_sampling, interval),
            daemon=True,
        )
        sampler.start()

        proc = None
        elapsed = None
        error = None
        try:
            # 120 s of headroom beyond the requested burn time before giving up.
            proc = subprocess.run(cmd, capture_output=True, text=True, timeout=duration + 120)
            elapsed = round(time.time() - t0, 1)
        except subprocess.TimeoutExpired:
            error = "timeout"
        except Exception as e:  # e.g. the binary vanished or cannot be executed
            error = str(e)
        finally:
            # Stop the sampler on every path so `telemetry` is final before grading.
            self._stop_sampler(sampler, stop_sampling, interval)

        xid_events = self._new_xid_events(xid_before, self._collect_xid_events())
        telemetry_summary = self._evaluate_telemetry(telemetry, [], xid_events)

        if error is not None:
            return {
                "source": "gpu-burn",
                "passed": False,
                "duration_sec": duration,
                "error": error,
                "telemetry": telemetry_summary,
                "timestamp": datetime.now().isoformat(),
            }

        output = (proc.stdout or "") + (proc.stderr or "")
        gpu_results = self._gpu_burn_summary_lines(output)
        faulty_lines = [
            line for line in gpu_results
            if "FAIL" in line.upper() or "FAULTY" in line.upper()
        ]
        passed = (
            proc.returncode == 0
            and not faulty_lines  # a GPU reported as failing fails the run
            and telemetry_summary.get("passed", False)
        )
        return {
            "source": "gpu-burn",
            "passed": passed,
            "duration_sec": duration,
            "elapsed_sec": elapsed,
            "gpu_results": gpu_results,
            "telemetry": telemetry_summary,
            "raw_output_tail": output[-500:] if output else "",
            "timestamp": datetime.now().isoformat(),
        }

    @classmethod
    def _gpu_burn_summary_lines(cls, output: str) -> list[str]:
        """Return the stripped per-GPU verdict lines from gpu-burn's combined output.

        A line qualifies if it contains "GPU" and either PASS/FAIL (any case) or
        matches the "GPU <n>: OK|FAULTY" summary pattern.
        """
        lines = []
        for raw in output.splitlines():
            line = raw.strip()
            if "GPU" not in line:
                continue
            upper = line.upper()
            if "PASS" in upper or "FAIL" in upper or cls._GPU_BURN_SUMMARY.search(line):
                lines.append(line)
        return lines

    def _run_pytorch_stress(self, duration: int, memory_pct: int = 90) -> dict:
        """Run a multi-GPU GEMM loop for ``duration`` seconds while sampling telemetry.

        Each GPU gets three ``matrix_size`` x ``matrix_size`` tensors of the
        configured ``dtype`` plus an fp16 "ballast" tensor that brings its
        allocation up to ``memory_pct`` % of total memory, capped at 90 % of the
        memory free at start. One pass launches ``a @ b`` on every GPU and then
        synchronizes all of them; per-pass TFLOPS feed the jitter check in
        :meth:`_evaluate_telemetry`.

        Returns:
            ``{"error": "pytorch_not_available"}`` when torch or CUDA is missing;
            otherwise a result dict with ``source="pytorch"``. A ``RuntimeError``
            during setup or the loop marks every GPU ``FAIL`` and is reported
            under ``error`` together with the telemetry collected so far.
        """
        try:
            import torch
        except ImportError:
            return {"error": "pytorch_not_available"}
        if not torch.cuda.is_available():
            return {"error": "pytorch_not_available"}

        gpu_count = torch.cuda.device_count()
        self.console.print(f"[cyan]PyTorch Stress Test ({duration}s, {gpu_count} GPUs, target {memory_pct}% memory)[/cyan]")
        dtype_name = self.stress_cfg.get("dtype", "bf16")
        matrix_size = int(self.stress_cfg.get("matrix_size", 8192))
        interval = self._telemetry_interval()
        dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
        dtype = dtype_map.get(dtype_name, torch.bfloat16)
        elem_size = torch.tensor([], dtype=dtype).element_size()
        # One pass = one n x n GEMM (2 * n^3 FLOPs) on each GPU.
        pass_flops = gpu_count * 2 * (matrix_size ** 3)

        # Created before `try` so the except/finally paths can always use them.
        gpu_status: dict = {}
        telemetry: list = []
        tensors: dict = {}
        ballast: dict = {}
        pass_tflops: list = []
        error_msg = None
        stop_sampling = threading.Event()
        sampler = threading.Thread(
            target=self._sample_telemetry,
            args=(telemetry, stop_sampling, interval),
            daemon=True,
        )
        t0 = time.time()
        xid_before = self._collect_xid_events()
        try:
            sampler.start()
            side = matrix_size
            for i in range(gpu_count):
                device = f"cuda:{i}"
                with torch.cuda.device(i):
                    free_mem, total_mem = torch.cuda.mem_get_info(i)
                    compute_bytes = side * side * elem_size * 3  # a, b, out
                    target_mem = min(int(total_mem * memory_pct / 100), int(free_mem * 0.90))
                    ballast_bytes = max(0, target_mem - compute_bytes)
                    if ballast_bytes:
                        # fp16 ballast: 2 bytes per element.
                        ballast[i] = torch.empty(ballast_bytes // 2, device=device, dtype=torch.float16)
                    actual_mem_mb = (compute_bytes + ballast_bytes) / 1024 / 1024
                    total_mem_mb = total_mem / 1024 / 1024
                    free_mem_mb = free_mem / 1024 / 1024
                    self.console.print(
                        f" [dim]GPU {i}: total {total_mem_mb:.0f}MB, free {free_mem_mb:.0f}MB, "
                        f"alloc {actual_mem_mb:.0f}MB ({actual_mem_mb/total_mem_mb*100:.0f}%) - "
                        f"{dtype_name} matrix {side}x{side}[/dim]"
                    )
                    tensors[i] = (
                        torch.randn(side, side, device=device, dtype=dtype),
                        torch.randn(side, side, device=device, dtype=dtype),
                        torch.empty(side, side, device=device, dtype=dtype),
                    )

            self.console.print(f"\n[cyan]Starting stress test for {duration} seconds...[/cyan]")
            last_report_sec = 0
            while time.time() - t0 < duration:
                loop_start = time.perf_counter()
                # Launch the matmul on every GPU before synchronizing any of them;
                # syncing inside this loop would run the GPUs one after another.
                for i in range(gpu_count):
                    with torch.cuda.device(i):
                        a, b, out = tensors[i]
                        torch.matmul(a, b, out=out)
                # One synchronization round per pass, after all launches.
                for i in range(gpu_count):
                    with torch.cuda.device(i):
                        torch.cuda.synchronize()
                loop_elapsed = time.perf_counter() - loop_start
                current_elapsed = time.time() - t0
                if loop_elapsed > 0:
                    pass_tflops.append({
                        "elapsed_sec": current_elapsed,
                        "tflops": pass_flops / loop_elapsed / 1e12,
                    })
                # Report progress once each time a 10 s boundary is crossed.
                current_sec = int(current_elapsed)
                if current_sec != last_report_sec and current_sec % 10 == 0:
                    self.console.print(f" [dim]Running {int(current_elapsed)}s / {duration}s[/dim]")
                last_report_sec = current_sec
            for i in range(gpu_count):
                gpu_status[i] = "PASS"
        except RuntimeError as e:
            error_msg = str(e)
            self.console.print(f"\n[red]Stress test error: {error_msg}[/red]")
            for i in range(gpu_count):
                gpu_status.setdefault(i, "FAIL")
        finally:
            self._stop_sampler(sampler, stop_sampling, interval)
            tensors.clear()
            ballast.clear()
            try:
                torch.cuda.empty_cache()
            except RuntimeError:
                # Assumption: empty_cache() can raise if the CUDA context is already
                # broken; cleanup must not replace the result being returned.
                pass

        elapsed = round(time.time() - t0, 1)
        xid_events = self._new_xid_events(xid_before, self._collect_xid_events())
        telemetry_summary = self._evaluate_telemetry(telemetry, pass_tflops, xid_events)
        if error_msg is not None:
            return {
                "source": "pytorch",
                "passed": False,
                "duration_sec": duration,
                "elapsed_sec": elapsed,
                "error": error_msg,
                "gpu_status": gpu_status,
                "telemetry": telemetry_summary,
                "timestamp": datetime.now().isoformat(),
            }
        passed = all(v == "PASS" for v in gpu_status.values()) and telemetry_summary.get("passed", False)
        return {
            "source": "pytorch",
            "passed": passed,
            "duration_sec": duration,
            "elapsed_sec": elapsed,
            "gpu_status": gpu_status,
            "telemetry": telemetry_summary,
            "timestamp": datetime.now().isoformat(),
        }

    def _sample_telemetry(self, telemetry: list, stop_event: threading.Event, interval: float):
        """Append an ``nvidia-smi`` sample to ``telemetry`` every ``interval`` s until ``stop_event`` is set.

        Each sample is ``{"time": <time.time()>, "gpus": [{"index", "temp_c",
        "power_w", "throttle"}, ...]}``. Sampling is best-effort: a failed,
        timed-out or unparsable ``nvidia-smi`` call drops that whole sample.
        """
        query = "index,temperature.gpu,power.draw,clocks_throttle_reasons.active"
        cmd = ["nvidia-smi", f"--query-gpu={query}", "--format=csv,noheader,nounits"]
        while not stop_event.is_set():
            try:
                r = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
                if r.returncode == 0:
                    sample = {"time": time.time(), "gpus": []}
                    for line in r.stdout.splitlines():
                        parts = [p.strip() for p in line.split(",")]
                        if len(parts) >= 4:
                            sample["gpus"].append({
                                "index": int(parts[0]),
                                "temp_c": float(parts[1]),
                                "power_w": float(parts[2]),
                                "throttle": parts[3],
                            })
                    telemetry.append(sample)
            except (OSError, subprocess.SubprocessError, ValueError):
                # Missing nvidia-smi, timeout, or a non-numeric field such as
                # "[N/A]": skip this sample and keep sampling.
                pass
            stop_event.wait(interval)

    def _collect_xid_events(self) -> list[str]:
        """Return the current ``dmesg`` lines that mention an XID.

        Best-effort: returns ``[]`` when ``dmesg`` is missing, exits non-zero or
        times out, which silently disables XID detection for the run.
        """
        try:
            r = subprocess.run(
                ["dmesg", "--color=never"],
                capture_output=True, text=True, errors="replace", timeout=10,
            )
        except (OSError, subprocess.SubprocessError):
            return []
        if r.returncode != 0:
            return []
        return [
            line.strip()
            for line in r.stdout.splitlines()
            if self._XID_PATTERN.search(line)
        ]

    @staticmethod
    def _new_xid_events(before: list[str], after: list[str]) -> list[str]:
        """Return the lines of ``after`` that were not present in ``before``, in order."""
        seen = set(before)
        return [line for line in after if line not in seen]

    def _evaluate_telemetry(self, telemetry: list, pass_tflops: list,
                            xid_events: list[str] | None = None) -> dict:
        """Grade telemetry, per-pass TFLOPS and new XID events against the ``stress`` thresholds.

        Thresholds (config key, default): ``max_temp_c`` 80,
        ``max_temp_delta_c`` 5, ``min_power_watts`` 630,
        ``max_tflops_jitter_pct`` 5, ``require_tflops_jitter`` True,
        ``warmup_sec`` 60 (capped at 20 % of ``duration_sec``),
        ``min_steady_samples`` 10.

        Samples taken during warm-up are ignored unless fewer than
        ``min_steady_samples`` remain, in which case all samples are used. The
        same warm-up cut applies to ``pass_tflops`` entries (dicts with
        ``elapsed_sec``/``tflops``; bare numbers are always kept), falling back
        to every entry if fewer than two survive.

        Returns:
            A summary dict; ``passed`` is true exactly when ``failures`` is empty.
        """
        cfg = self.stress_cfg
        max_temp = float(cfg.get("max_temp_c", 80))
        max_delta = float(cfg.get("max_temp_delta_c", 5))
        min_power = float(cfg.get("min_power_watts", 630))
        max_jitter = float(cfg.get("max_tflops_jitter_pct", 5))
        require_jitter = bool(cfg.get("require_tflops_jitter", True))
        duration = float(cfg.get("duration_sec", 60))
        requested_warmup = float(cfg.get("warmup_sec", 60))
        warmup_sec = min(requested_warmup, max(0.0, duration * 0.2))
        min_steady_samples = int(cfg.get("min_steady_samples", 10))
        xid_events = xid_events or []

        steady_telemetry = []
        if telemetry:
            start = telemetry[0].get("time", 0)
            steady_telemetry = [
                sample for sample in telemetry
                if sample.get("time", 0) - start >= warmup_sec
            ]
        evaluation_samples = (
            steady_telemetry if len(steady_telemetry) >= min_steady_samples else telemetry
        )

        temps: dict = {}
        powers: dict = {}
        throttle_bad = []
        for sample in evaluation_samples:
            for g in sample.get("gpus", []):
                idx = g["index"]
                temps.setdefault(idx, []).append(g["temp_c"])
                powers.setdefault(idx, []).append(g["power_w"])
                try:
                    bitmask = int(str(g["throttle"]), 16)
                except ValueError:
                    bitmask = 0
                # Bit 0x1 is ignored (treated as the idle reason); any other set
                # bit counts as a throttle event.
                real_throttle = bitmask & ~0x1
                if real_throttle:
                    throttle_bad.append({
                        "gpu": idx,
                        "throttle": g["throttle"],
                        "real_throttle": f"0x{real_throttle:x}",
                    })

        max_temps = {idx: max(vals) for idx, vals in temps.items() if vals}
        avg_powers = {idx: sum(vals) / len(vals) for idx, vals in powers.items() if vals}
        temp_delta = (max(max_temps.values()) - min(max_temps.values())) if len(max_temps) >= 2 else 0

        steady_tflops = []
        for item in pass_tflops:
            if isinstance(item, dict):
                if float(item.get("elapsed_sec", 0)) >= warmup_sec:
                    steady_tflops.append(float(item.get("tflops", 0)))
            else:
                steady_tflops.append(float(item))
        if len(steady_tflops) < 2 and pass_tflops:
            steady_tflops = [
                float(item.get("tflops", 0)) if isinstance(item, dict) else float(item)
                for item in pass_tflops
            ]
        jitter = 0
        if steady_tflops:
            mean = sum(steady_tflops) / len(steady_tflops)
            # Largest deviation of any pass from the mean, in percent.
            jitter = max(abs(v - mean) / mean * 100 for v in steady_tflops) if mean else 0

        failures = []
        temp_failures = {idx: v for idx, v in max_temps.items() if v > max_temp}
        power_failures = {idx: v for idx, v in avg_powers.items() if v < min_power}
        if not evaluation_samples:
            failures.append("no telemetry samples available for evaluation")
        if temp_failures:
            failures.append(
                "max temperature above threshold: "
                + ", ".join(f"GPU {idx} {val:.1f}C" for idx, val in sorted(temp_failures.items()))
            )
        if temp_delta > max_delta:
            failures.append(f"GPU temperature delta {temp_delta:.1f}C exceeds {max_delta:.1f}C")
        if power_failures:
            failures.append(
                "average steady-state power below threshold: "
                + ", ".join(f"GPU {idx} {val:.1f}W" for idx, val in sorted(power_failures.items()))
            )
        if throttle_bad:
            failures.append(
                f"non-idle throttle reasons observed in {len(throttle_bad)} samples "
                f"(first: GPU {throttle_bad[0]['gpu']} {throttle_bad[0]['real_throttle']})"
            )
        if xid_events:
            failures.append(f"{len(xid_events)} new XID/NVRM XID events observed")
        if require_jitter and len(steady_tflops) < 2:
            failures.append(
                f"insufficient steady TFLOPS samples for jitter evaluation: {len(steady_tflops)} < 2"
            )
        if jitter > max_jitter:
            failures.append(f"TFLOPS jitter {jitter:.2f}% exceeds {max_jitter:.2f}%")

        # Every gate above records a failure reason, so the verdict derives from them.
        passed = not failures
        return {
            "passed": passed,
            "samples": len(telemetry),
            "steady_samples": len(evaluation_samples),
            "warmup_sec": round(warmup_sec, 1),
            "max_temp_c": {k: round(v, 1) for k, v in max_temps.items()},
            "avg_power_w": {k: round(v, 1) for k, v in avg_powers.items()},
            "temp_delta_c": round(temp_delta, 1),
            "throttle_events": throttle_bad[:20],
            "throttle_event_count": len(throttle_bad),
            "xid_events": xid_events[-20:],
            "tflops_jitter_pct": round(jitter, 2),
            "steady_tflops_samples": len(steady_tflops),
            "failures": failures,
            "thresholds": {
                "max_temp_c": max_temp,
                "max_temp_delta_c": max_delta,
                "min_power_w": min_power,
                "max_tflops_jitter_pct": max_jitter,
                "require_tflops_jitter": require_jitter,
                "warmup_sec": requested_warmup,
                "min_steady_samples": min_steady_samples,
            },
        }

    @staticmethod
    def print_results(results: dict, console: Console | None = None):
        """Pretty-print a result dict from :meth:`run` to ``console`` (a new Console if omitted).

        A result carrying an ``error`` but no ``source`` (the test never ran,
        e.g. PyTorch/CUDA unavailable) prints only the error. Otherwise the
        verdict, per-GPU lines, telemetry summary, failure reasons and any
        ``error`` are all shown, so a timed-out or crashed run still reports the
        telemetry it collected.
        """
        c = console or Console()
        if "error" in results and "source" not in results:
            c.print(f"[bold red]Error: {results['error']}[/bold red]")
            return
        passed = results.get("passed", False)
        source = results.get("source", "unknown")
        duration = results.get("duration_sec", "?")
        elapsed = results.get("elapsed_sec", "?")
        verdict = "[bold green]✓ Stress Test PASSED[/bold green]" if passed else "[bold red]✗ Stress Test FAILED[/bold red]"
        c.print(f"\n{verdict} [dim](via {source})[/dim]")
        c.print(f" Target duration: {duration}s | Actual: {elapsed}s")

        gpu_results = results.get("gpu_results", [])
        if gpu_results:
            c.print("\n Per-GPU results:")
            for line in gpu_results:
                upper = line.upper()
                if "FAIL" in upper or "FAULTY" in upper:
                    c.print(f" [red]{line}[/red]")
                else:
                    c.print(f" [green]{line}[/green]")

        gpu_status = results.get("gpu_status", {})
        if gpu_status:
            c.print("\n Per-GPU status:")
            for gid, status in sorted(gpu_status.items()):
                color = "green" if status == "PASS" else "red"
                c.print(f" GPU {gid}: [{color}]{status}[/{color}]")

        telemetry = results.get("telemetry") or {}
        if telemetry:
            c.print("\n Telemetry:")
            c.print(f" Samples: {telemetry.get('samples', 0)} total, {telemetry.get('steady_samples', 0)} evaluated after {telemetry.get('warmup_sec', 0)}s warmup")
            c.print(f" Avg steady power: {telemetry.get('avg_power_w', {})}")
            c.print(f" Max steady temp: {telemetry.get('max_temp_c', {})}")
            c.print(f" Temp delta: {telemetry.get('temp_delta_c', 'N/A')} C")
            c.print(f" TFLOPS jitter: {telemetry.get('tflops_jitter_pct', 'N/A')}%")
            c.print(f" Throttle events: {telemetry.get('throttle_event_count', len(telemetry.get('throttle_events', [])))}")
            c.print(f" XID events: {len(telemetry.get('xid_events', []))}")
            failures = telemetry.get("failures", [])
            if failures:
                c.print(" [red]Failure reasons:[/red]")
                for reason in failures:
                    c.print(f" [red]- {reason}[/red]")

        if results.get("error"):
            c.print(f" [red]Error: {results['error']}[/red]")