"""DCGM diagnostic acceptance wrapper."""

import json
import os
import re
import shutil
import signal
import subprocess
from datetime import datetime
from typing import Optional

from rich.console import Console
from rich.table import Table


class DCGMTest:
    """Run ``dcgmi diag`` and turn its output into a pass/fail result dict.

    Settings are read from the ``"dcgm"`` section of the config passed to the
    constructor: ``diag_level`` (default 3), ``timeout_sec`` (default 1200),
    ``expected_num_gpus`` (optional), ``json_output`` (default True) and
    ``require_subtests`` (default True).
    """

    # Seconds to wait for the process group to exit after each signal.
    _KILL_GRACE_SEC = 10

    # Spellings accepted for each canonical status (keys are upper-cased).
    _STATUS_ALIASES = {
        "PASS": "PASS",
        "PASSED": "PASS",
        "OK": "PASS",
        "FAIL": "FAIL",
        "FAILED": "FAIL",
        "ERROR": "ERROR",
        "WARN": "WARN",
        "WARNING": "WARN",
        "SKIP": "SKIP",
        "SKIPPED": "SKIP",
        "NOT_RUN": "SKIP",
        "NOT RUN": "SKIP",
    }

    # "<name> : <status>" / "<name> | <status>" lines of the plain-text report.
    _NAME_THEN_STATUS_RE = re.compile(
        r"(.+?)\s*[:|]\s*(PASS|FAIL|WARN|ERROR|SKIP)\b", re.I
    )
    # "<status> - <name>" style lines, tried when the pattern above misses.
    _STATUS_THEN_NAME_RE = re.compile(
        r"\b(PASS|FAIL|WARN|ERROR|SKIP)\b\s*[-:|]\s*(.+)", re.I
    )

    def __init__(self, config: dict):
        """Store the full config and its ``"dcgm"`` subsection."""
        self.config = config
        self.console = Console()
        self.cfg = config.get("dcgm", {})

    def run(self) -> dict:
        """Execute the diagnostic and return a result dict.

        Returns:
            A dict that always contains ``passed`` and ``timestamp``. When
            ``dcgmi`` is missing or the run times out it carries ``error``
            instead of subtests. Otherwise it carries ``returncode``,
            ``level``, ``command``, ``expected_num_gpus``, ``subtests`` and
            ``raw_output_tail`` (last 8000 characters of stdout + stderr).

            ``passed`` is True only when the return code is 0, every parsed
            subtest is ``PASS`` (WARN/SKIP count as failures) and, unless
            ``require_subtests`` is disabled, at least one subtest was parsed.
        """
        dcgmi = shutil.which("dcgmi")
        if not dcgmi:
            return {
                "passed": False,
                "error": "dcgmi not found",
                "timestamp": datetime.now().isoformat(),
            }

        level = str(self.cfg.get("diag_level", 3))
        # Resolve the reported level before launching anything: a non-numeric
        # level must not blow up with ValueError after a long run completed.
        try:
            reported_level = int(level)
        except ValueError:
            reported_level = level
        timeout = int(self.cfg.get("timeout_sec", 1200))

        expected_gpus = self.cfg.get("expected_num_gpus")
        gpu_count = int(expected_gpus) if expected_gpus else None

        cmd = [dcgmi, "diag", "-r", level]
        if gpu_count is not None:
            cmd.extend(["-n", f"gpu:{gpu_count}"])
        if self.cfg.get("json_output", True):
            cmd.append("-j")

        try:
            r = self._run_with_process_group_timeout(cmd, timeout)
        except subprocess.TimeoutExpired as e:
            # The captured streams may be None or bytes; normalise to text.
            output = (
                self._coerce_text(e.output) + "\n" + self._coerce_text(e.stderr)
            ).strip()
            return {
                "passed": False,
                "error": f"dcgmi diag -r {level} timeout after {timeout}s",
                "command": cmd,
                "raw_output_tail": output[-8000:],
                "timestamp": datetime.now().isoformat(),
            }

        output = r.stdout + "\n" + r.stderr
        # Prefer structured JSON; fall back to scraping the text report.
        subtests = self._parse_json_output(output) or self._parse_output(output)
        strict_statuses = {"PASS"}
        failed = [s for s in subtests if s["status"] not in strict_statuses]
        require_subtests = bool(self.cfg.get("require_subtests", True))
        passed = (
            r.returncode == 0
            and not failed
            and (bool(subtests) or not require_subtests)
        )
        return {
            "passed": passed,
            "returncode": r.returncode,
            "level": reported_level,
            "command": cmd,
            "expected_num_gpus": gpu_count,
            "subtests": subtests,
            "raw_output_tail": output[-8000:],
            "timestamp": datetime.now().isoformat(),
        }

    @staticmethod
    def _coerce_text(value) -> str:
        """Return *value* as ``str``; ``None`` becomes ``""``, bytes are decoded."""
        if value is None:
            return ""
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return str(value)

    @staticmethod
    def _terminate_process_group(proc: subprocess.Popen) -> tuple[str, str]:
        """Stop *proc*'s process group (SIGTERM, then SIGKILL) and drain its pipes.

        Returns:
            ``(stdout, stderr)`` as text. If the pipes still cannot be drained
            after SIGKILL, whatever partial output was captured is returned.
        """
        pending = None
        for sig in (signal.SIGTERM, signal.SIGKILL):
            try:
                # The child was started in its own session, so its pid is
                # also its process-group id.
                os.killpg(proc.pid, sig)
            except ProcessLookupError:
                # The group exited between the timeout and the signal.
                pass
            try:
                stdout, stderr = proc.communicate(timeout=DCGMTest._KILL_GRACE_SEC)
            except subprocess.TimeoutExpired as exc:
                pending = exc
                continue
            return DCGMTest._coerce_text(stdout), DCGMTest._coerce_text(stderr)
        # Not reaped even after SIGKILL: report the partial output we have.
        return (
            DCGMTest._coerce_text(pending.output),
            DCGMTest._coerce_text(pending.stderr),
        )

    @staticmethod
    def _run_with_process_group_timeout(cmd: list[str], timeout: int) -> subprocess.CompletedProcess:
        """Run *cmd* in its own session, killing the whole group on timeout.

        Args:
            cmd: Argument vector to execute.
            timeout: Seconds to wait before terminating the process group.

        Returns:
            A ``CompletedProcess`` with text ``stdout``/``stderr``.

        Raises:
            subprocess.TimeoutExpired: If *cmd* did not finish in time. The
                exception carries the text output collected up to and during
                termination.
        """
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            start_new_session=True,
        )
        try:
            stdout, stderr = proc.communicate(timeout=timeout)
        except subprocess.TimeoutExpired as e:
            stdout, stderr = DCGMTest._terminate_process_group(proc)
            raise subprocess.TimeoutExpired(
                cmd, timeout, output=stdout, stderr=stderr
            ) from e
        return subprocess.CompletedProcess(cmd, proc.returncode, stdout, stderr)

    @classmethod
    def _parse_json_output(cls, output: str) -> list[dict]:
        """Extract subtest results from JSON emitted by ``dcgmi diag -j``.

        The text is parsed as a whole first. If that fails, the JSON object
        starting at the first ``{`` is decoded and anything after it is
        ignored, so log lines around the document do not defeat parsing.

        Returns:
            A list of ``{"name", "status", "raw"}`` dicts; empty when no JSON
            document or no recognisable status is found.
        """
        text = output.strip()
        if not text:
            return []
        try:
            payload = json.loads(text)
        except json.JSONDecodeError:
            start = text.find("{")
            if start < 0:
                return []
            try:
                payload, _ = json.JSONDecoder().raw_decode(text, start)
            except json.JSONDecodeError:
                return []

        dcgm_payload = payload.get("DCGM Diagnostic") if isinstance(payload, dict) else None
        if isinstance(dcgm_payload, dict):
            parsed = cls._parse_dcgm_diagnostic_json(dcgm_payload)
            if parsed:
                return parsed

        # Unknown layout: walk the tree and collect every node that carries a
        # recognisable status string.
        subtests = []

        def walk(node, path: list[str]):
            if isinstance(node, dict):
                node_name = (
                    node.get("name")
                    or node.get("testName")
                    or node.get("test_name")
                    or node.get("category")
                    or node.get("category_name")
                )
                child_path = [*path, str(node_name)] if node_name else path
                status = node.get("status") or node.get("result") or node.get("Result")
                if isinstance(status, str):
                    normalized = cls._normalize_status(status)
                    if normalized:
                        # Unnamed nodes are labelled by their last path parts.
                        name = node_name or " / ".join(path[-3:])
                        subtests.append({
                            "name": str(name)[:160],
                            "status": normalized,
                            "raw": json.dumps(node, default=str)[:1000],
                        })
                for key, value in node.items():
                    walk(value, [*child_path, str(key)])
            elif isinstance(node, list):
                for idx, item in enumerate(node):
                    walk(item, [*path, str(idx)])

        walk(payload, [])
        return subtests

    @staticmethod
    def _as_list(value) -> list:
        """Return *value* if it is a list, otherwise an empty list."""
        return value if isinstance(value, list) else []

    @classmethod
    def _parse_dcgm_diagnostic_json(cls, payload: dict) -> list[dict]:
        """Flatten a ``"DCGM Diagnostic"`` object into subtest records.

        One record is produced per entry in each test's ``results`` plus one
        for its ``test_summary``; entries without a recognisable status, or
        that are not JSON objects, are skipped.
        """
        subtests = []
        for category in cls._as_list(payload.get("test_categories")):
            if not isinstance(category, dict):
                continue
            category_name = str(category.get("category") or "DCGM")
            for test in cls._as_list(category.get("tests")):
                if not isinstance(test, dict):
                    continue
                test_name = str(test.get("name") or "unnamed")
                for result in cls._as_list(test.get("results")):
                    if not isinstance(result, dict):
                        continue
                    status = cls._normalize_status(str(result.get("status", "")))
                    if not status:
                        continue
                    entity_group = result.get("entity_group") or "entity"
                    entity_id = result.get("entity_id", "unknown")
                    name = f"{category_name}/{test_name}/{entity_group}{entity_id}"
                    subtests.append({
                        "name": name[:160],
                        "status": status,
                        "raw": json.dumps(result, default=str)[:1000],
                    })
                summary = test.get("test_summary")
                if not isinstance(summary, dict):
                    continue
                status = cls._normalize_status(str(summary.get("status", "")))
                if status:
                    subtests.append({
                        "name": f"{category_name}/{test_name}/summary"[:160],
                        "status": status,
                        "raw": json.dumps(summary, default=str)[:1000],
                    })
        return subtests

    @staticmethod
    def _normalize_status(status: str) -> str:
        """Map a status spelling to PASS/FAIL/ERROR/WARN/SKIP, or ``""`` if unknown."""
        return DCGMTest._STATUS_ALIASES.get(status.strip().upper(), "")

    @staticmethod
    def _parse_output(output: str) -> list[dict]:
        """Scrape ``name: STATUS`` / ``STATUS - name`` lines from a text report.

        Returns:
            A list of ``{"name", "status", "raw"}`` dicts, one per matching
            line whose name is non-empty and shorter than 160 characters.
        """
        subtests = []
        for line in output.splitlines():
            stripped = line.strip()
            if not stripped:
                continue
            m = DCGMTest._NAME_THEN_STATUS_RE.search(stripped)
            if m:
                # Drop table borders and dot leaders around the name.
                name = m.group(1).strip(" .|-")
                status = DCGMTest._normalize_status(m.group(2))
            else:
                m = DCGMTest._STATUS_THEN_NAME_RE.search(stripped)
                if not m:
                    continue
                status = DCGMTest._normalize_status(m.group(1))
                name = m.group(2).strip()
            if name and len(name) < 160:
                subtests.append({"name": name, "status": status, "raw": stripped})
        return subtests

    @staticmethod
    def print_results(results: dict, console: Optional[Console] = None):
        """Print a result dict from :meth:`run`: error line, or verdict plus table."""
        c = console or Console()
        if results.get("error"):
            c.print(f"[bold red]DCGM error: {results['error']}[/bold red]")
            return

        passed = results.get("passed", False)
        c.print(
            "[bold green]✓ DCGM diag PASSED[/bold green]"
            if passed
            else "[bold red]✗ DCGM diag FAILED[/bold red]"
        )

        subtests = results.get("subtests", [])
        if subtests:
            table = Table(box=None, padding=(0, 1))
            table.add_column("Subtest")
            table.add_column("Status", style="bold")
            for s in subtests:
                table.add_row(s.get("name", ""), s.get("status", ""))
            c.print(table)