"""Summarize an RDT-1B training run into ``output.json``.

Usage::

    python generate_output_json.py <input_config_file> <output_dir> <runtime>

Reads the run's JSON config, scans ``<output_dir>/output.log`` for
evaluation-metric lines, and writes ``<output_dir>/output.json`` containing
run metadata plus the minimum value observed for each metric.
"""
import json
import os
import re
import sys

# Matches one evaluation-metrics dict as printed in the log, e.g.
# {'agilex_sample_mse': 0.1, 'agilex_sample_l2err': 0.2,
#  'overall_avg_sample_mse': 0.3, 'overall_avg_sample_l2err': 0.4}
# Keys must appear in exactly this order and use single quotes.
_METRICS_PATTERN = re.compile(
    r"\{'agilex_sample_mse':\s*([0-9.eE+-]+),\s*'agilex_sample_l2err':\s*([0-9.eE+-]+),\s*'overall_avg_sample_mse':\s*([0-9.eE+-]+),\s*'overall_avg_sample_l2err':\s*([0-9.eE+-]+)\}"
)

# Metric names in the same order as the capture groups above.
_METRIC_NAMES = (
    "agilex_sample_mse",
    "agilex_sample_l2err",
    "overall_avg_sample_mse",
    "overall_avg_sample_l2err",
)

# Returned when no metrics could be obtained.
_MISSING_METRICS = (None, None, None, None)


def _parse_metrics_line(line):
    """Return the four metrics on *line* as floats, or None if absent/malformed."""
    match = _METRICS_PATTERN.search(line)
    if match is None:
        return None
    try:
        return tuple(float(value) for value in match.groups())
    except ValueError:
        # The number character class also accepts non-numbers such as "e" or
        # "+-"; skip such a line rather than discarding the whole log.
        return None


def extract_metrics_from_log(log_file_path):
    """Scan a training log and return the best (minimum) value of each metric.

    Each metric's minimum is taken independently, so the four returned values
    may come from different log lines.

    Args:
        log_file_path: Path to the log file to scan.

    Returns:
        A 4-tuple ``(agilex_sample_mse, agilex_sample_l2err,
        overall_avg_sample_mse, overall_avg_sample_l2err)`` of floats, or
        ``(None, None, None, None)`` if the file cannot be opened or contains
        no well-formed metrics line.
    """
    all_metrics = []
    try:
        # errors="replace": an undecodable byte elsewhere in the log must not
        # prevent the (ASCII) metric lines from being read.
        with open(log_file_path, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                metrics = _parse_metrics_line(line)
                if metrics is None:
                    continue
                all_metrics.append(metrics)
                print(f"Find Metrics: agilex_sample_mse={metrics[0]}, agilex_sample_l2err={metrics[1]}, "
                      f"overall_avg_sample_mse={metrics[2]}, overall_avg_sample_l2err={metrics[3]}")
    except OSError as e:
        print(f"Failed to read log: {e}")
        return _MISSING_METRICS

    if not all_metrics:
        print("No metrics found in the log file")
        return _MISSING_METRICS

    print(f"\nTotal {len(all_metrics)} metrics found in the log file")

    # zip(*rows) transposes the rows into one column per metric.
    best = tuple(min(column) for column in zip(*all_metrics))

    print("\nBest metrics:")
    for name, value in zip(_METRIC_NAMES, best):
        print(f" {name}: {value}")
    return best


def generate_output_json(input_config_file, output_dir, runtime):
    """Write ``<output_dir>/output.json`` summarizing a training run.

    Args:
        input_config_file: Path to the run's JSON config. ``task_id``,
            ``gpu_id`` and ``model_name`` are read from it; when
            ``model_name`` is absent, ``train.model`` is used instead.
        output_dir: Run output directory; must contain ``output.log`` for
            metrics to be found, and receives ``output.json``.
        runtime: Stored verbatim under ``"runtime"``.

    Returns:
        The path of the written ``output.json``.

    Raises:
        OSError: If the config cannot be read or the output cannot be written.
        json.JSONDecodeError: If the config file is not valid JSON.
    """
    with open(input_config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    log_file = os.path.join(output_dir, "output.log")
    metrics = extract_metrics_from_log(log_file)
    if None in metrics:
        print("Warning: Some metrics are missing in the log file.")

    if "model_name" in config:
        model_name = config.get("model_name")
    else:
        # "or {}" also covers an explicit "train": null in the config.
        model_name = (config.get("train") or {}).get("model")

    output_json = {
        "task_id": config.get("task_id"),
        "model_type": "RDT-1B",
        "model_name": model_name,
        "gpu_id": config.get("gpu_id"),
        "runtime": runtime,
        "log_path": log_file,
        "output_dir": output_dir,
        "model_path": os.path.join(output_dir, "pytorch_model.bin"),
        "metrics": dict(zip(_METRIC_NAMES, metrics)),
    }

    # Write output.json pretty-printed; None values serialize as JSON null.
    # ensure_ascii=False keeps non-ASCII text readable, so the file must be
    # opened as UTF-8 explicitly rather than in the locale's encoding.
    output_json_path = os.path.join(output_dir, "output.json")
    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(output_json, f, indent=4, ensure_ascii=False)
    return output_json_path


def main(argv=None):
    """Command-line entry point; returns the process exit code."""
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) != 3:
        print("Usage: python generate_output_json.py <input_config_file> <output_dir> <runtime>",
              file=sys.stderr)
        return 1
    generate_output_json(args[0], args[1], args[2])
    return 0


if __name__ == "__main__":
    sys.exit(main())