# File: test_gpu_scripts/modules/dcgm_test.py
# (232 lines, 8.8 KiB, Python)

"""DCGM diagnostic acceptance wrapper."""
import json
import os
import re
import shutil
import signal
import subprocess
from datetime import datetime
from typing import Optional
from rich.console import Console
from rich.table import Table
class DCGMTest:
def __init__(self, config: dict):
self.config = config
self.console = Console()
self.cfg = config.get("dcgm", {})
def run(self) -> dict:
dcgmi = shutil.which("dcgmi")
if not dcgmi:
return {
"passed": False,
"error": "dcgmi not found",
"timestamp": datetime.now().isoformat(),
}
level = str(self.cfg.get("diag_level", 3))
timeout = int(self.cfg.get("timeout_sec", 1200))
cmd = [dcgmi, "diag", "-r", level]
expected_gpus = self.cfg.get("expected_num_gpus")
if expected_gpus:
cmd.extend(["-n", f"gpu:{int(expected_gpus)}"])
if self.cfg.get("json_output", True):
cmd.append("-j")
try:
r = self._run_with_process_group_timeout(cmd, timeout)
except subprocess.TimeoutExpired as e:
output = ((e.output or "") + "\n" + (e.stderr or "")).strip()
return {
"passed": False,
"error": f"dcgmi diag -r {level} timeout after {timeout}s",
"command": cmd,
"raw_output_tail": output[-8000:],
"timestamp": datetime.now().isoformat(),
}
output = r.stdout + "\n" + r.stderr
subtests = self._parse_json_output(output) or self._parse_output(output)
strict_statuses = {"PASS"}
failed = [s for s in subtests if s["status"] not in strict_statuses]
require_subtests = bool(self.cfg.get("require_subtests", True))
passed = r.returncode == 0 and not failed and (bool(subtests) or not require_subtests)
return {
"passed": passed,
"returncode": r.returncode,
"level": int(level),
"command": cmd,
"expected_num_gpus": int(expected_gpus) if expected_gpus else None,
"subtests": subtests,
"raw_output_tail": output[-8000:],
"timestamp": datetime.now().isoformat(),
}
@staticmethod
def _run_with_process_group_timeout(cmd: list[str], timeout: int) -> subprocess.CompletedProcess:
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
start_new_session=True,
)
try:
stdout, stderr = proc.communicate(timeout=timeout)
except subprocess.TimeoutExpired as e:
try:
os.killpg(proc.pid, signal.SIGTERM)
stdout, stderr = proc.communicate(timeout=10)
except subprocess.TimeoutExpired:
os.killpg(proc.pid, signal.SIGKILL)
stdout, stderr = proc.communicate(timeout=10)
raise subprocess.TimeoutExpired(cmd, timeout, output=stdout, stderr=stderr) from e
return subprocess.CompletedProcess(cmd, proc.returncode, stdout, stderr)
@classmethod
def _parse_json_output(cls, output: str) -> list[dict]:
text = output.strip()
if not text:
return []
try:
payload = json.loads(text)
except json.JSONDecodeError:
m = re.search(r"(\{.*\})", text, re.S)
if not m:
return []
try:
payload = json.loads(m.group(1))
except json.JSONDecodeError:
return []
dcgm_payload = payload.get("DCGM Diagnostic") if isinstance(payload, dict) else None
if isinstance(dcgm_payload, dict):
parsed = cls._parse_dcgm_diagnostic_json(dcgm_payload)
if parsed:
return parsed
subtests = []
def walk(node, path: list[str]):
if isinstance(node, dict):
node_name = (
node.get("name")
or node.get("testName")
or node.get("test_name")
or node.get("category")
or node.get("category_name")
)
child_path = [*path, str(node_name)] if node_name else path
status = node.get("status") or node.get("result") or node.get("Result")
if isinstance(status, str):
name = (
node_name
or " / ".join(path[-3:])
)
normalized = cls._normalize_status(status)
if normalized:
subtests.append({
"name": str(name)[:160],
"status": normalized,
"raw": json.dumps(node, default=str)[:1000],
})
for key, value in node.items():
walk(value, [*child_path, str(key)])
elif isinstance(node, list):
for idx, item in enumerate(node):
walk(item, [*path, str(idx)])
walk(payload, [])
return subtests
@classmethod
def _parse_dcgm_diagnostic_json(cls, payload: dict) -> list[dict]:
subtests = []
for category in payload.get("test_categories", []) or []:
category_name = str(category.get("category") or "DCGM")
for test in category.get("tests", []) or []:
test_name = str(test.get("name") or "unnamed")
for result in test.get("results", []) or []:
status = cls._normalize_status(str(result.get("status", "")))
if not status:
continue
entity_group = result.get("entity_group") or "entity"
entity_id = result.get("entity_id", "unknown")
name = f"{category_name}/{test_name}/{entity_group}{entity_id}"
subtests.append({
"name": name[:160],
"status": status,
"raw": json.dumps(result, default=str)[:1000],
})
summary = test.get("test_summary") or {}
status = cls._normalize_status(str(summary.get("status", "")))
if status:
subtests.append({
"name": f"{category_name}/{test_name}/summary"[:160],
"status": status,
"raw": json.dumps(summary, default=str)[:1000],
})
return subtests
@staticmethod
def _normalize_status(status: str) -> str:
s = status.strip().upper()
aliases = {
"PASS": "PASS",
"PASSED": "PASS",
"OK": "PASS",
"FAIL": "FAIL",
"FAILED": "FAIL",
"ERROR": "ERROR",
"WARN": "WARN",
"WARNING": "WARN",
"SKIP": "SKIP",
"SKIPPED": "SKIP",
"NOT_RUN": "SKIP",
"NOT RUN": "SKIP",
}
return aliases.get(s, s if s in {"PASS", "FAIL", "ERROR", "WARN", "SKIP"} else "")
@staticmethod
def _parse_output(output: str) -> list[dict]:
subtests = []
for line in output.splitlines():
stripped = line.strip()
if not stripped:
continue
m = re.search(r"(.+?)\s*[:|]\s*(PASS|FAIL|WARN|ERROR|SKIP)\b", stripped, re.I)
if not m:
m = re.search(r"\b(PASS|FAIL|WARN|ERROR|SKIP)\b\s*[-:|]\s*(.+)", stripped, re.I)
if m:
status = DCGMTest._normalize_status(m.group(1))
name = m.group(2).strip()
else:
continue
else:
name = m.group(1).strip(" .|-")
status = DCGMTest._normalize_status(m.group(2))
if name and len(name) < 160:
subtests.append({"name": name, "status": status, "raw": stripped})
return subtests
@staticmethod
def print_results(results: dict, console: Optional[Console] = None):
c = console or Console()
if results.get("error"):
c.print(f"[bold red]DCGM error: {results['error']}[/bold red]")
return
passed = results.get("passed", False)
c.print("[bold green]✓ DCGM diag PASSED[/bold green]" if passed else "[bold red]✗ DCGM diag FAILED[/bold red]")
subtests = results.get("subtests", [])
if subtests:
table = Table(box=None, padding=(0, 1))
table.add_column("Subtest")
table.add_column("Status", style="bold")
for s in subtests:
table.add_row(s.get("name", ""), s.get("status", ""))
c.print(table)