# Listing metadata: 488 lines, 21 KiB, Python
"""GPU stress test module — gpu-burn or PyTorch GEMM with telemetry."""
|
|
|
|
import glob
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import threading
|
|
import time
|
|
from datetime import datetime
|
|
|
|
from rich.console import Console
|
|
from rich.table import Table
|
|
from rich.live import Live
|
|
from rich.text import Text
|
|
|
|
from modules.gpu_specs import resolve_tools_dir
|
|
|
|
|
|
class StressTest:
|
|
|
|
def __init__(self, config: dict):
|
|
self.config = config
|
|
self.console = Console()
|
|
self.stress_cfg = config.get("stress", {})
|
|
self.tools_dir = resolve_tools_dir(config)
|
|
|
|
def _find_gpu_burn(self) -> str:
|
|
p = shutil.which("gpu_burn")
|
|
if p:
|
|
return p
|
|
|
|
local = os.path.join(self.tools_dir, "gpu-burn", "gpu_burn")
|
|
if os.path.isfile(local) and os.access(local, shutil.os.X_OK):
|
|
return local
|
|
|
|
matches = glob.glob(os.path.join(self.tools_dir, "gpu-burn", "**", "gpu_burn"), recursive=True)
|
|
for m in matches:
|
|
if os.access(m, shutil.os.X_OK):
|
|
return m
|
|
return ""
|
|
|
|
def run(self) -> dict:
|
|
cfg = self.stress_cfg
|
|
duration_sec = cfg.get("duration_sec", 60)
|
|
use_doubles = cfg.get("use_doubles", False)
|
|
use_tensor_cores = cfg.get("use_tensor_cores", True)
|
|
memory_pct = cfg.get("memory_pct", 90)
|
|
target_gpus = cfg.get("gpus", "all")
|
|
|
|
gpu_burn = self._find_gpu_burn() if cfg.get("use_gpu_burn", False) else ""
|
|
|
|
if gpu_burn:
|
|
# Try gpu-burn first
|
|
result = self._run_gpu_burn(gpu_burn, duration_sec, use_doubles, use_tensor_cores, target_gpus)
|
|
|
|
# If gpu-burn fails (e.g. OOM), auto-fallback to PyTorch
|
|
if not result.get("passed") and result.get("elapsed_sec", 0) < duration_sec * 0.5:
|
|
self.console.print("\n[yellow]gpu-burn exited early (possible OOM), switching to PyTorch stress test[/yellow]")
|
|
self.console.print("[dim]PyTorch mode dynamically adapts to available memory[/dim]\n")
|
|
return self._run_pytorch_stress(duration_sec, memory_pct)
|
|
|
|
return result
|
|
|
|
self.console.print("[yellow]Using PyTorch stress test[/yellow]")
|
|
return self._run_pytorch_stress(duration_sec, memory_pct)
|
|
|
|
def _run_gpu_burn(self, gpu_burn: str, duration: int,
|
|
doubles: bool, tensor_cores: bool, target_gpus: str) -> dict:
|
|
self.console.print(f"[cyan]GPU Stress Test via gpu-burn ({duration}s)[/cyan]")
|
|
|
|
cmd = [gpu_burn]
|
|
if doubles:
|
|
cmd.append("-d")
|
|
if tensor_cores:
|
|
cmd.append("-tc")
|
|
if target_gpus != "all":
|
|
cmd.extend(["-i", str(target_gpus)])
|
|
cmd.append(str(duration))
|
|
|
|
t0 = time.time()
|
|
xid_before = self._collect_xid_events()
|
|
interval = int(self.stress_cfg.get("telemetry_interval_sec", 1))
|
|
telemetry = []
|
|
stop_sampling = threading.Event()
|
|
sampler = threading.Thread(
|
|
target=self._sample_telemetry,
|
|
args=(telemetry, stop_sampling, interval),
|
|
daemon=True,
|
|
)
|
|
sampler.start()
|
|
try:
|
|
r = subprocess.run(cmd, capture_output=True, text=True, timeout=duration + 120)
|
|
elapsed = round(time.time() - t0, 1)
|
|
stop_sampling.set()
|
|
sampler.join(timeout=interval + 1)
|
|
|
|
output = r.stdout + r.stderr
|
|
xid_events = self._new_xid_events(xid_before, self._collect_xid_events())
|
|
telemetry_summary = self._evaluate_telemetry(telemetry, [], xid_events)
|
|
passed = r.returncode == 0 and telemetry_summary.get("passed", False)
|
|
|
|
gpu_results = []
|
|
for line in output.split("\n"):
|
|
line = line.strip()
|
|
if "GPU" in line and ("PASS" in line.upper() or "FAIL" in line.upper()):
|
|
gpu_results.append(line)
|
|
|
|
return {
|
|
"source": "gpu-burn",
|
|
"passed": passed,
|
|
"duration_sec": duration,
|
|
"elapsed_sec": elapsed,
|
|
"gpu_results": gpu_results,
|
|
"telemetry": telemetry_summary,
|
|
"raw_output_tail": output[-500:] if output else "",
|
|
"timestamp": datetime.now().isoformat(),
|
|
}
|
|
|
|
except subprocess.TimeoutExpired:
|
|
stop_sampling.set()
|
|
return {
|
|
"source": "gpu-burn",
|
|
"passed": False,
|
|
"duration_sec": duration,
|
|
"error": "timeout",
|
|
"telemetry": self._evaluate_telemetry(
|
|
telemetry, [], self._new_xid_events(xid_before, self._collect_xid_events())
|
|
),
|
|
"timestamp": datetime.now().isoformat(),
|
|
}
|
|
except Exception as e:
|
|
stop_sampling.set()
|
|
return {
|
|
"source": "gpu-burn",
|
|
"passed": False,
|
|
"error": str(e),
|
|
"telemetry": self._evaluate_telemetry(
|
|
telemetry, [], self._new_xid_events(xid_before, self._collect_xid_events())
|
|
),
|
|
"timestamp": datetime.now().isoformat(),
|
|
}
|
|
finally:
|
|
stop_sampling.set()
|
|
|
|
def _run_pytorch_stress(self, duration: int, memory_pct: int = 90) -> dict:
|
|
try:
|
|
import torch
|
|
if not torch.cuda.is_available():
|
|
return {"error": "pytorch_not_available"}
|
|
except ImportError:
|
|
return {"error": "pytorch_not_available"}
|
|
|
|
gpu_count = torch.cuda.device_count()
|
|
self.console.print(f"[cyan]PyTorch Stress Test ({duration}s, {gpu_count} GPUs, target {memory_pct}% memory)[/cyan]")
|
|
|
|
dtype_name = self.stress_cfg.get("dtype", "bf16")
|
|
matrix_size = int(self.stress_cfg.get("matrix_size", 8192))
|
|
interval = int(self.stress_cfg.get("telemetry_interval_sec", 1))
|
|
dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
|
|
dtype = dtype_map.get(dtype_name, torch.bfloat16)
|
|
|
|
gpu_status = {}
|
|
telemetry = []
|
|
stop_sampling = threading.Event()
|
|
t0 = time.time()
|
|
xid_before = self._collect_xid_events()
|
|
|
|
try:
|
|
sampler = threading.Thread(
|
|
target=self._sample_telemetry,
|
|
args=(telemetry, stop_sampling, interval),
|
|
daemon=True,
|
|
)
|
|
sampler.start()
|
|
tensors = {}
|
|
ballast = {}
|
|
pass_tflops = []
|
|
for i in range(gpu_count):
|
|
with torch.cuda.device(i):
|
|
free_mem, total_mem = torch.cuda.mem_get_info(i)
|
|
side = matrix_size
|
|
elem = torch.tensor([], dtype=dtype).element_size()
|
|
compute_bytes = side * side * elem * 3
|
|
target_mem = min(int(total_mem * memory_pct / 100), int(free_mem * 0.90))
|
|
ballast_bytes = max(0, target_mem - compute_bytes)
|
|
if ballast_bytes:
|
|
ballast_elems = ballast_bytes // 2
|
|
ballast[i] = torch.empty(ballast_elems, device=f"cuda:{i}", dtype=torch.float16)
|
|
actual_mem_mb = (compute_bytes + ballast_bytes) / 1024 / 1024
|
|
total_mem_mb = total_mem / 1024 / 1024
|
|
free_mem_mb = free_mem / 1024 / 1024
|
|
|
|
self.console.print(
|
|
f" [dim]GPU {i}: total {total_mem_mb:.0f}MB, free {free_mem_mb:.0f}MB, "
|
|
f"alloc {actual_mem_mb:.0f}MB ({actual_mem_mb/total_mem_mb*100:.0f}%) - "
|
|
f"{dtype_name} matrix {side}x{side}[/dim]"
|
|
)
|
|
tensors[i] = (
|
|
torch.randn(side, side, device=f"cuda:{i}", dtype=dtype),
|
|
torch.randn(side, side, device=f"cuda:{i}", dtype=dtype),
|
|
torch.empty(side, side, device=f"cuda:{i}", dtype=dtype),
|
|
)
|
|
|
|
self.console.print(f"\n[cyan]Starting stress test for {duration} seconds...[/cyan]")
|
|
|
|
elapsed_check = 0
|
|
while time.time() - t0 < duration:
|
|
loop_start = time.perf_counter()
|
|
# Dispatch matmul on all GPUs in parallel — do NOT synchronize between
|
|
# GPUs, otherwise the 8 GPUs run serially and overshoot the duration.
|
|
for i in range(gpu_count):
|
|
with torch.cuda.device(i):
|
|
a, b, out = tensors[i]
|
|
torch.matmul(a, b, out=out)
|
|
# Single sync per pass — waits for all 8 streams concurrently
|
|
for i in range(gpu_count):
|
|
with torch.cuda.device(i):
|
|
torch.cuda.synchronize()
|
|
loop_elapsed = time.perf_counter() - loop_start
|
|
current_elapsed = time.time() - t0
|
|
if loop_elapsed > 0:
|
|
flops = gpu_count * 2 * (matrix_size ** 3)
|
|
pass_tflops.append({
|
|
"elapsed_sec": current_elapsed,
|
|
"tflops": flops / loop_elapsed / 1e12,
|
|
})
|
|
|
|
# Show progress every 10 seconds
|
|
if int(current_elapsed) != int(elapsed_check) and int(current_elapsed) % 10 == 0:
|
|
self.console.print(f" [dim]Running {int(current_elapsed)}s / {duration}s[/dim]")
|
|
elapsed_check = current_elapsed
|
|
|
|
for i in range(gpu_count):
|
|
gpu_status[i] = "PASS"
|
|
|
|
except RuntimeError as e:
|
|
error_msg = str(e)
|
|
self.console.print(f"\n[red]Stress test error: {error_msg}[/red]")
|
|
for i in range(gpu_count):
|
|
if i not in gpu_status:
|
|
gpu_status[i] = "FAIL"
|
|
return {
|
|
"source": "pytorch",
|
|
"passed": False,
|
|
"duration_sec": duration,
|
|
"error": error_msg,
|
|
"gpu_status": gpu_status,
|
|
"telemetry": self._evaluate_telemetry(
|
|
telemetry, pass_tflops if "pass_tflops" in locals() else [],
|
|
self._new_xid_events(xid_before, self._collect_xid_events()),
|
|
),
|
|
}
|
|
finally:
|
|
stop_sampling.set()
|
|
tensors.clear()
|
|
ballast.clear()
|
|
torch.cuda.empty_cache()
|
|
|
|
elapsed = round(time.time() - t0, 1)
|
|
xid_events = self._new_xid_events(xid_before, self._collect_xid_events())
|
|
telemetry_summary = self._evaluate_telemetry(telemetry, pass_tflops, xid_events)
|
|
passed = all(v == "PASS" for v in gpu_status.values()) and telemetry_summary.get("passed", False)
|
|
return {
|
|
"source": "pytorch",
|
|
"passed": passed,
|
|
"duration_sec": duration,
|
|
"elapsed_sec": elapsed,
|
|
"gpu_status": gpu_status,
|
|
"telemetry": telemetry_summary,
|
|
"timestamp": datetime.now().isoformat(),
|
|
}
|
|
|
|
def _sample_telemetry(self, telemetry: list, stop_event: threading.Event, interval: int):
|
|
query = "index,temperature.gpu,power.draw,clocks_throttle_reasons.active"
|
|
while not stop_event.is_set():
|
|
try:
|
|
r = subprocess.run(
|
|
["nvidia-smi", f"--query-gpu={query}", "--format=csv,noheader,nounits"],
|
|
capture_output=True, text=True, timeout=10,
|
|
)
|
|
if r.returncode == 0:
|
|
sample = {"time": time.time(), "gpus": []}
|
|
for line in r.stdout.splitlines():
|
|
parts = [p.strip() for p in line.split(",")]
|
|
if len(parts) >= 4:
|
|
sample["gpus"].append({
|
|
"index": int(parts[0]),
|
|
"temp_c": float(parts[1]),
|
|
"power_w": float(parts[2]),
|
|
"throttle": parts[3],
|
|
})
|
|
telemetry.append(sample)
|
|
except Exception:
|
|
pass
|
|
stop_event.wait(interval)
|
|
|
|
def _collect_xid_events(self) -> list[str]:
|
|
try:
|
|
r = subprocess.run(
|
|
["dmesg", "--color=never"],
|
|
capture_output=True, text=True, timeout=10,
|
|
)
|
|
if r.returncode != 0:
|
|
return []
|
|
return [
|
|
line.strip()
|
|
for line in r.stdout.splitlines()
|
|
if any(token in line.upper() for token in ("XID", "NVRM: XID"))
|
|
]
|
|
except Exception:
|
|
return []
|
|
|
|
@staticmethod
|
|
def _new_xid_events(before: list[str], after: list[str]) -> list[str]:
|
|
seen = set(before)
|
|
return [line for line in after if line not in seen]
|
|
|
|
def _evaluate_telemetry(self, telemetry: list, pass_tflops: list, xid_events: list[str] | None = None) -> dict:
|
|
cfg = self.stress_cfg
|
|
max_temp = float(cfg.get("max_temp_c", 80))
|
|
max_delta = float(cfg.get("max_temp_delta_c", 5))
|
|
min_power = float(cfg.get("min_power_watts", 630))
|
|
max_jitter = float(cfg.get("max_tflops_jitter_pct", 5))
|
|
require_jitter = bool(cfg.get("require_tflops_jitter", True))
|
|
duration = float(cfg.get("duration_sec", 60))
|
|
requested_warmup = float(cfg.get("warmup_sec", 60))
|
|
warmup_sec = min(requested_warmup, max(0.0, duration * 0.2))
|
|
min_steady_samples = int(cfg.get("min_steady_samples", 10))
|
|
temps = {}
|
|
powers = {}
|
|
throttle_bad = []
|
|
xid_events = xid_events or []
|
|
steady_telemetry = [
|
|
sample for sample in telemetry
|
|
if sample.get("time", 0) - telemetry[0].get("time", 0) >= warmup_sec
|
|
] if telemetry else []
|
|
evaluation_samples = steady_telemetry if len(steady_telemetry) >= min_steady_samples else telemetry
|
|
for sample in evaluation_samples:
|
|
for g in sample.get("gpus", []):
|
|
idx = g["index"]
|
|
temps.setdefault(idx, []).append(g["temp_c"])
|
|
powers.setdefault(idx, []).append(g["power_w"])
|
|
try:
|
|
bitmask = int(str(g["throttle"]), 16)
|
|
except ValueError:
|
|
bitmask = 0
|
|
real_throttle = bitmask & ~0x1
|
|
if real_throttle:
|
|
throttle_bad.append({
|
|
"gpu": idx,
|
|
"throttle": g["throttle"],
|
|
"real_throttle": f"0x{real_throttle:x}",
|
|
})
|
|
max_temps = {idx: max(vals) for idx, vals in temps.items() if vals}
|
|
avg_powers = {idx: sum(vals) / len(vals) for idx, vals in powers.items() if vals}
|
|
temp_delta = (max(max_temps.values()) - min(max_temps.values())) if len(max_temps) >= 2 else 0
|
|
jitter = 0
|
|
steady_tflops = []
|
|
for item in pass_tflops:
|
|
if isinstance(item, dict):
|
|
if float(item.get("elapsed_sec", 0)) >= warmup_sec:
|
|
steady_tflops.append(float(item.get("tflops", 0)))
|
|
else:
|
|
steady_tflops.append(float(item))
|
|
if len(steady_tflops) < 2 and pass_tflops:
|
|
steady_tflops = [
|
|
float(item.get("tflops", 0)) if isinstance(item, dict) else float(item)
|
|
for item in pass_tflops
|
|
]
|
|
if steady_tflops:
|
|
mean = sum(steady_tflops) / len(steady_tflops)
|
|
jitter = max(abs(v - mean) / mean * 100 for v in steady_tflops) if mean else 0
|
|
failures = []
|
|
temp_failures = {idx: v for idx, v in max_temps.items() if v > max_temp}
|
|
power_failures = {idx: v for idx, v in avg_powers.items() if v < min_power}
|
|
if not evaluation_samples:
|
|
failures.append("no telemetry samples available for evaluation")
|
|
if temp_failures:
|
|
failures.append(
|
|
"max temperature above threshold: "
|
|
+ ", ".join(f"GPU {idx} {val:.1f}C" for idx, val in sorted(temp_failures.items()))
|
|
)
|
|
if temp_delta > max_delta:
|
|
failures.append(f"GPU temperature delta {temp_delta:.1f}C exceeds {max_delta:.1f}C")
|
|
if power_failures:
|
|
failures.append(
|
|
"average steady-state power below threshold: "
|
|
+ ", ".join(f"GPU {idx} {val:.1f}W" for idx, val in sorted(power_failures.items()))
|
|
)
|
|
if throttle_bad:
|
|
failures.append(
|
|
f"non-idle throttle reasons observed in {len(throttle_bad)} samples "
|
|
f"(first: GPU {throttle_bad[0]['gpu']} {throttle_bad[0]['real_throttle']})"
|
|
)
|
|
if xid_events:
|
|
failures.append(f"{len(xid_events)} new XID/NVRM XID events observed")
|
|
if require_jitter and len(steady_tflops) < 2:
|
|
failures.append(
|
|
f"insufficient steady TFLOPS samples for jitter evaluation: {len(steady_tflops)} < 2"
|
|
)
|
|
if jitter > max_jitter:
|
|
failures.append(f"TFLOPS jitter {jitter:.2f}% exceeds {max_jitter:.2f}%")
|
|
passed = (
|
|
bool(evaluation_samples)
|
|
and all(v <= max_temp for v in max_temps.values())
|
|
and temp_delta <= max_delta
|
|
and all(v >= min_power for v in avg_powers.values())
|
|
and not throttle_bad
|
|
and not xid_events
|
|
and (not require_jitter or len(steady_tflops) >= 2)
|
|
and jitter <= max_jitter
|
|
)
|
|
return {
|
|
"passed": passed,
|
|
"samples": len(telemetry),
|
|
"steady_samples": len(evaluation_samples),
|
|
"warmup_sec": round(warmup_sec, 1),
|
|
"max_temp_c": {k: round(v, 1) for k, v in max_temps.items()},
|
|
"avg_power_w": {k: round(v, 1) for k, v in avg_powers.items()},
|
|
"temp_delta_c": round(temp_delta, 1),
|
|
"throttle_events": throttle_bad[:20],
|
|
"throttle_event_count": len(throttle_bad),
|
|
"xid_events": xid_events[-20:],
|
|
"tflops_jitter_pct": round(jitter, 2),
|
|
"steady_tflops_samples": len(steady_tflops),
|
|
"failures": failures,
|
|
"thresholds": {
|
|
"max_temp_c": max_temp,
|
|
"max_temp_delta_c": max_delta,
|
|
"min_power_w": min_power,
|
|
"max_tflops_jitter_pct": max_jitter,
|
|
"require_tflops_jitter": require_jitter,
|
|
"warmup_sec": requested_warmup,
|
|
"min_steady_samples": min_steady_samples,
|
|
},
|
|
}
|
|
|
|
@staticmethod
|
|
def print_results(results: dict, console: Console = None):
|
|
c = console or Console()
|
|
if "error" in results:
|
|
c.print(f"[bold red]Error: {results['error']}[/bold red]")
|
|
return
|
|
|
|
passed = results.get("passed", False)
|
|
source = results.get("source", "unknown")
|
|
duration = results.get("duration_sec", "?")
|
|
elapsed = results.get("elapsed_sec", "?")
|
|
|
|
verdict = "[bold green]✓ Stress Test PASSED[/bold green]" if passed else "[bold red]✗ Stress Test FAILED[/bold red]"
|
|
c.print(f"\n{verdict} [dim](via {source})[/dim]")
|
|
c.print(f" Target duration: {duration}s | Actual: {elapsed}s")
|
|
|
|
gpu_results = results.get("gpu_results", [])
|
|
if gpu_results:
|
|
c.print("\n Per-GPU results:")
|
|
for line in gpu_results:
|
|
if "FAIL" in line.upper():
|
|
c.print(f" [red]{line}[/red]")
|
|
else:
|
|
c.print(f" [green]{line}[/green]")
|
|
|
|
gpu_status = results.get("gpu_status", {})
|
|
if gpu_status:
|
|
c.print("\n Per-GPU status:")
|
|
for gid, status in sorted(gpu_status.items()):
|
|
color = "green" if status == "PASS" else "red"
|
|
c.print(f" GPU {gid}: [{color}]{status}[/{color}]")
|
|
|
|
telemetry = results.get("telemetry") or {}
|
|
if telemetry:
|
|
c.print("\n Telemetry:")
|
|
c.print(f" Samples: {telemetry.get('samples', 0)} total, {telemetry.get('steady_samples', 0)} evaluated after {telemetry.get('warmup_sec', 0)}s warmup")
|
|
c.print(f" Avg steady power: {telemetry.get('avg_power_w', {})}")
|
|
c.print(f" Max steady temp: {telemetry.get('max_temp_c', {})}")
|
|
c.print(f" Temp delta: {telemetry.get('temp_delta_c', 'N/A')} C")
|
|
c.print(f" TFLOPS jitter: {telemetry.get('tflops_jitter_pct', 'N/A')}%")
|
|
c.print(f" Throttle events: {telemetry.get('throttle_event_count', len(telemetry.get('throttle_events', [])))}")
|
|
c.print(f" XID events: {len(telemetry.get('xid_events', []))}")
|
|
failures = telemetry.get("failures", [])
|
|
if failures:
|
|
c.print(" [red]Failure reasons:[/red]")
|
|
for reason in failures:
|
|
c.print(f" [red]- {reason}[/red]")
|
|
|
|
if results.get("error"):
|
|
c.print(f" [red]Error: {results['error']}[/red]")
|