232 lines
8.8 KiB
Python
232 lines
8.8 KiB
Python
"""DCGM diagnostic acceptance wrapper."""
|
|
|
|
import json
|
|
import os
|
|
import re
|
|
import shutil
|
|
import signal
|
|
import subprocess
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
from rich.console import Console
|
|
from rich.table import Table
|
|
|
|
|
|
class DCGMTest:
|
|
def __init__(self, config: dict):
|
|
self.config = config
|
|
self.console = Console()
|
|
self.cfg = config.get("dcgm", {})
|
|
|
|
def run(self) -> dict:
|
|
dcgmi = shutil.which("dcgmi")
|
|
if not dcgmi:
|
|
return {
|
|
"passed": False,
|
|
"error": "dcgmi not found",
|
|
"timestamp": datetime.now().isoformat(),
|
|
}
|
|
|
|
level = str(self.cfg.get("diag_level", 3))
|
|
timeout = int(self.cfg.get("timeout_sec", 1200))
|
|
cmd = [dcgmi, "diag", "-r", level]
|
|
expected_gpus = self.cfg.get("expected_num_gpus")
|
|
if expected_gpus:
|
|
cmd.extend(["-n", f"gpu:{int(expected_gpus)}"])
|
|
if self.cfg.get("json_output", True):
|
|
cmd.append("-j")
|
|
|
|
try:
|
|
r = self._run_with_process_group_timeout(cmd, timeout)
|
|
except subprocess.TimeoutExpired as e:
|
|
output = ((e.output or "") + "\n" + (e.stderr or "")).strip()
|
|
return {
|
|
"passed": False,
|
|
"error": f"dcgmi diag -r {level} timeout after {timeout}s",
|
|
"command": cmd,
|
|
"raw_output_tail": output[-8000:],
|
|
"timestamp": datetime.now().isoformat(),
|
|
}
|
|
|
|
output = r.stdout + "\n" + r.stderr
|
|
subtests = self._parse_json_output(output) or self._parse_output(output)
|
|
strict_statuses = {"PASS"}
|
|
failed = [s for s in subtests if s["status"] not in strict_statuses]
|
|
require_subtests = bool(self.cfg.get("require_subtests", True))
|
|
passed = r.returncode == 0 and not failed and (bool(subtests) or not require_subtests)
|
|
return {
|
|
"passed": passed,
|
|
"returncode": r.returncode,
|
|
"level": int(level),
|
|
"command": cmd,
|
|
"expected_num_gpus": int(expected_gpus) if expected_gpus else None,
|
|
"subtests": subtests,
|
|
"raw_output_tail": output[-8000:],
|
|
"timestamp": datetime.now().isoformat(),
|
|
}
|
|
|
|
@staticmethod
|
|
def _run_with_process_group_timeout(cmd: list[str], timeout: int) -> subprocess.CompletedProcess:
|
|
proc = subprocess.Popen(
|
|
cmd,
|
|
stdout=subprocess.PIPE,
|
|
stderr=subprocess.PIPE,
|
|
text=True,
|
|
start_new_session=True,
|
|
)
|
|
try:
|
|
stdout, stderr = proc.communicate(timeout=timeout)
|
|
except subprocess.TimeoutExpired as e:
|
|
try:
|
|
os.killpg(proc.pid, signal.SIGTERM)
|
|
stdout, stderr = proc.communicate(timeout=10)
|
|
except subprocess.TimeoutExpired:
|
|
os.killpg(proc.pid, signal.SIGKILL)
|
|
stdout, stderr = proc.communicate(timeout=10)
|
|
raise subprocess.TimeoutExpired(cmd, timeout, output=stdout, stderr=stderr) from e
|
|
return subprocess.CompletedProcess(cmd, proc.returncode, stdout, stderr)
|
|
|
|
@classmethod
|
|
def _parse_json_output(cls, output: str) -> list[dict]:
|
|
text = output.strip()
|
|
if not text:
|
|
return []
|
|
try:
|
|
payload = json.loads(text)
|
|
except json.JSONDecodeError:
|
|
m = re.search(r"(\{.*\})", text, re.S)
|
|
if not m:
|
|
return []
|
|
try:
|
|
payload = json.loads(m.group(1))
|
|
except json.JSONDecodeError:
|
|
return []
|
|
|
|
dcgm_payload = payload.get("DCGM Diagnostic") if isinstance(payload, dict) else None
|
|
if isinstance(dcgm_payload, dict):
|
|
parsed = cls._parse_dcgm_diagnostic_json(dcgm_payload)
|
|
if parsed:
|
|
return parsed
|
|
|
|
subtests = []
|
|
|
|
def walk(node, path: list[str]):
|
|
if isinstance(node, dict):
|
|
node_name = (
|
|
node.get("name")
|
|
or node.get("testName")
|
|
or node.get("test_name")
|
|
or node.get("category")
|
|
or node.get("category_name")
|
|
)
|
|
child_path = [*path, str(node_name)] if node_name else path
|
|
status = node.get("status") or node.get("result") or node.get("Result")
|
|
if isinstance(status, str):
|
|
name = (
|
|
node_name
|
|
or " / ".join(path[-3:])
|
|
)
|
|
normalized = cls._normalize_status(status)
|
|
if normalized:
|
|
subtests.append({
|
|
"name": str(name)[:160],
|
|
"status": normalized,
|
|
"raw": json.dumps(node, default=str)[:1000],
|
|
})
|
|
for key, value in node.items():
|
|
walk(value, [*child_path, str(key)])
|
|
elif isinstance(node, list):
|
|
for idx, item in enumerate(node):
|
|
walk(item, [*path, str(idx)])
|
|
|
|
walk(payload, [])
|
|
return subtests
|
|
|
|
@classmethod
|
|
def _parse_dcgm_diagnostic_json(cls, payload: dict) -> list[dict]:
|
|
subtests = []
|
|
for category in payload.get("test_categories", []) or []:
|
|
category_name = str(category.get("category") or "DCGM")
|
|
for test in category.get("tests", []) or []:
|
|
test_name = str(test.get("name") or "unnamed")
|
|
for result in test.get("results", []) or []:
|
|
status = cls._normalize_status(str(result.get("status", "")))
|
|
if not status:
|
|
continue
|
|
entity_group = result.get("entity_group") or "entity"
|
|
entity_id = result.get("entity_id", "unknown")
|
|
name = f"{category_name}/{test_name}/{entity_group}{entity_id}"
|
|
subtests.append({
|
|
"name": name[:160],
|
|
"status": status,
|
|
"raw": json.dumps(result, default=str)[:1000],
|
|
})
|
|
summary = test.get("test_summary") or {}
|
|
status = cls._normalize_status(str(summary.get("status", "")))
|
|
if status:
|
|
subtests.append({
|
|
"name": f"{category_name}/{test_name}/summary"[:160],
|
|
"status": status,
|
|
"raw": json.dumps(summary, default=str)[:1000],
|
|
})
|
|
return subtests
|
|
|
|
@staticmethod
|
|
def _normalize_status(status: str) -> str:
|
|
s = status.strip().upper()
|
|
aliases = {
|
|
"PASS": "PASS",
|
|
"PASSED": "PASS",
|
|
"OK": "PASS",
|
|
"FAIL": "FAIL",
|
|
"FAILED": "FAIL",
|
|
"ERROR": "ERROR",
|
|
"WARN": "WARN",
|
|
"WARNING": "WARN",
|
|
"SKIP": "SKIP",
|
|
"SKIPPED": "SKIP",
|
|
"NOT_RUN": "SKIP",
|
|
"NOT RUN": "SKIP",
|
|
}
|
|
return aliases.get(s, s if s in {"PASS", "FAIL", "ERROR", "WARN", "SKIP"} else "")
|
|
|
|
@staticmethod
|
|
def _parse_output(output: str) -> list[dict]:
|
|
subtests = []
|
|
for line in output.splitlines():
|
|
stripped = line.strip()
|
|
if not stripped:
|
|
continue
|
|
m = re.search(r"(.+?)\s*[:|]\s*(PASS|FAIL|WARN|ERROR|SKIP)\b", stripped, re.I)
|
|
if not m:
|
|
m = re.search(r"\b(PASS|FAIL|WARN|ERROR|SKIP)\b\s*[-:|]\s*(.+)", stripped, re.I)
|
|
if m:
|
|
status = DCGMTest._normalize_status(m.group(1))
|
|
name = m.group(2).strip()
|
|
else:
|
|
continue
|
|
else:
|
|
name = m.group(1).strip(" .|-")
|
|
status = DCGMTest._normalize_status(m.group(2))
|
|
if name and len(name) < 160:
|
|
subtests.append({"name": name, "status": status, "raw": stripped})
|
|
return subtests
|
|
|
|
@staticmethod
|
|
def print_results(results: dict, console: Optional[Console] = None):
|
|
c = console or Console()
|
|
if results.get("error"):
|
|
c.print(f"[bold red]DCGM error: {results['error']}[/bold red]")
|
|
return
|
|
passed = results.get("passed", False)
|
|
c.print("[bold green]✓ DCGM diag PASSED[/bold green]" if passed else "[bold red]✗ DCGM diag FAILED[/bold red]")
|
|
subtests = results.get("subtests", [])
|
|
if subtests:
|
|
table = Table(box=None, padding=(0, 1))
|
|
table.add_column("Subtest")
|
|
table.add_column("Status", style="bold")
|
|
for s in subtests:
|
|
table.add_row(s.get("name", ""), s.get("status", ""))
|
|
c.print(table)
|