d-robotics-rdt/rdt1b-train/scripts/generate_output_json.py
2025-11-02 12:25:30 +08:00

85 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import sys
import re
def extract_metrics_from_log(log_file_path):
    """Scan a training log for evaluation-metric dicts and return the best values.

    A line counts as a metrics line when it contains a dict literal of the form
    ``{'agilex_sample_mse': X, 'agilex_sample_l2err': X,
    'overall_avg_sample_mse': X, 'overall_avg_sample_l2err': X}``
    with the keys in exactly that order.

    Args:
        log_file_path: Path of the log file to scan.

    Returns:
        A 4-tuple ``(agilex_sample_mse, agilex_sample_l2err,
        overall_avg_sample_mse, overall_avg_sample_l2err)`` where each entry is
        the minimum of that metric over all matching lines. The minima are
        taken independently, so the four values may come from different
        evaluation steps. Returns ``(None, None, None, None)`` when the file
        cannot be read or contains no parsable metrics line.
    """
    pattern = re.compile(
        r"\{'agilex_sample_mse':\s*([0-9.eE+-]+),\s*'agilex_sample_l2err':\s*([0-9.eE+-]+),\s*'overall_avg_sample_mse':\s*([0-9.eE+-]+),\s*'overall_avg_sample_l2err':\s*([0-9.eE+-]+)\}"
    )
    all_metrics = []
    try:
        # errors='replace': a stray non-UTF-8 byte somewhere in the log must
        # not abort the scan and throw away every metric found so far.
        with open(log_file_path, 'r', encoding='utf-8', errors='replace') as f:
            for line in f:
                m = pattern.search(line)
                if not m:
                    continue
                try:
                    metrics = tuple(float(g) for g in m.groups())
                except ValueError:
                    # The character class also admits malformed numbers such
                    # as "1.2.3"; skip just this line instead of failing
                    # the whole scan.
                    print(f"Skipping malformed metrics line: {line.strip()}")
                    continue
                all_metrics.append(metrics)
                print(f"Find Metrics: agilex_sample_mse={metrics[0]}, agilex_sample_l2err={metrics[1]}, "
                      f"overall_avg_sample_mse={metrics[2]}, overall_avg_sample_l2err={metrics[3]}")
    except OSError as e:
        # Missing file, permission problem, path is a directory, ...
        print(f"Failed to read log: {e}")
        return (None, None, None, None)

    if not all_metrics:
        print("No metrics found in the log file")
        return (None, None, None, None)

    print(f"\nTotal {len(all_metrics)} metrics found in the log file")
    best_agilex_mse = min(m[0] for m in all_metrics)
    best_agilex_l2err = min(m[1] for m in all_metrics)
    best_overall_mse = min(m[2] for m in all_metrics)
    best_overall_l2err = min(m[3] for m in all_metrics)
    print("\nBest metrics:")
    print(f" agilex_sample_mse: {best_agilex_mse}")
    print(f" agilex_sample_l2err: {best_agilex_l2err}")
    print(f" overall_avg_sample_mse: {best_overall_mse}")
    print(f" overall_avg_sample_l2err: {best_overall_l2err}")
    return (best_agilex_mse, best_agilex_l2err, best_overall_mse, best_overall_l2err)
def generate_output_json(input_config_file, output_dir, runtime):
    """Write ``<output_dir>/output.json`` summarising a finished training run.

    Args:
        input_config_file: Path of the JSON task configuration. ``task_id``,
            ``gpu_id`` and ``model_name`` are read from its top level; when
            ``model_name`` is absent, ``train.model`` is used instead.
        output_dir: Run output directory. Metrics are parsed from
            ``output.log`` inside it, and ``output.json`` is written into it.
        runtime: Stored verbatim under ``"runtime"`` (from the command line
            this is the raw argv string).

    Metrics that cannot be found in the log are written as JSON ``null``.
    """
    # 'utf-8-sig' reads plain UTF-8 and also tolerates a leading BOM, which
    # json.load would otherwise reject.
    with open(input_config_file, 'r', encoding='utf-8-sig') as f:
        config = json.load(f)

    log_file = os.path.join(output_dir, 'output.log')
    metrics = extract_metrics_from_log(log_file)
    agilex_sample_mse, agilex_sample_l2err, overall_avg_sample_mse, overall_avg_sample_l2err = metrics
    if any(value is None for value in metrics):
        print("Warning: Some metrics are missing in the log file.")

    if "model_name" in config:
        model_name = config.get("model_name")
    else:
        # `or {}` also covers an explicit `"train": null` in the config.
        model_name = (config.get("train") or {}).get("model")

    output_json = {
        "task_id": config.get("task_id"),
        "model_type": "RDT-1B",
        "model_name": model_name,
        "gpu_id": config.get("gpu_id"),
        "runtime": runtime,
        "log_path": log_file,
        "output_dir": output_dir,
        "model_path": os.path.join(output_dir, 'pytorch_model.bin'),
        "metrics": {
            "agilex_sample_mse": agilex_sample_mse,
            "agilex_sample_l2err": agilex_sample_l2err,
            "overall_avg_sample_mse": overall_avg_sample_mse,
            "overall_avg_sample_l2err": overall_avg_sample_l2err
        }
    }

    # Write output.json pretty-printed; missing metrics (None) serialise as
    # standard JSON null. The explicit UTF-8 encoding is required because
    # ensure_ascii=False emits raw non-ASCII characters, which would fail
    # under a non-UTF-8 locale default encoding.
    output_json_path = os.path.join(output_dir, 'output.json')
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(output_json, f, indent=4, ensure_ascii=False)
if __name__ == "__main__":
    # Command-line entry point: exactly three positional arguments are
    # expected (config path, output directory, runtime).
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: python generate_output_json.py <input_config_file> <output_dir> <runtime>")
        sys.exit(1)
    config_path, out_dir, elapsed = cli_args
    generate_output_json(config_path, out_dir, elapsed)