import json
import os
import re
import sys


def extract_metrics_from_log(log_file_path):
    """Scan a training log for printed metric dicts and return the best value of each metric."""
    all_metrics = []
    # One-line dict repr printed during evaluation; the four capture groups are
    # the metric values (plain or scientific notation).
    pattern = re.compile(
        r"\{'agilex_sample_mse':\s*([0-9.eE+-]+),\s*'agilex_sample_l2err':\s*([0-9.eE+-]+),\s*'overall_avg_sample_mse':\s*([0-9.eE+-]+),\s*'overall_avg_sample_l2err':\s*([0-9.eE+-]+)\}"
    )
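    # A log line this pattern matches looks like the following (values are
    # illustrative, not from a real run):
    # {'agilex_sample_mse': 0.004317, 'agilex_sample_l2err': 0.092105, 'overall_avg_sample_mse': 0.005621, 'overall_avg_sample_l2err': 0.104433}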
    try:
        with open(log_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                m = pattern.search(line)
                if m:
                    metrics = (
                        float(m.group(1)),
                        float(m.group(2)),
                        float(m.group(3)),
                        float(m.group(4))
                    )
                    all_metrics.append(metrics)
                    print(f"Found metrics: agilex_sample_mse={metrics[0]}, agilex_sample_l2err={metrics[1]}, "
                          f"overall_avg_sample_mse={metrics[2]}, overall_avg_sample_l2err={metrics[3]}")
    except Exception as e:
        print(f"Failed to read log: {e}")
        return (None, None, None, None)

    if not all_metrics:
        print("No metrics found in the log file")
        return (None, None, None, None)

    print(f"\nFound {len(all_metrics)} metric entries in the log file")

    # Take the minimum of each metric independently across all evaluation rounds.
    best_agilex_mse = min(m[0] for m in all_metrics)
    best_agilex_l2err = min(m[1] for m in all_metrics)
    best_overall_mse = min(m[2] for m in all_metrics)
    best_overall_l2err = min(m[3] for m in all_metrics)

    print("\nBest metrics:")
    print(f"  agilex_sample_mse: {best_agilex_mse}")
    print(f"  agilex_sample_l2err: {best_agilex_l2err}")
    print(f"  overall_avg_sample_mse: {best_overall_mse}")
    print(f"  overall_avg_sample_l2err: {best_overall_l2err}")

    return (best_agilex_mse, best_agilex_l2err, best_overall_mse, best_overall_l2err)


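# Worked example (hypothetical values): if the log contains two metric lines,
#   {'agilex_sample_mse': 0.0051, 'agilex_sample_l2err': 0.0980, 'overall_avg_sample_mse': 0.0062, 'overall_avg_sample_l2err': 0.1100}
#   {'agilex_sample_mse': 0.0043, 'agilex_sample_l2err': 0.1020, 'overall_avg_sample_mse': 0.0058, 'overall_avg_sample_l2err': 0.1040}
# extract_metrics_from_log returns the element-wise minima, which may come
# from different evaluation rounds:
#   (0.0043, 0.0980, 0.0058, 0.1040)
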
def generate_output_json(input_config_file, output_dir, runtime):
    """Build output.json for a run from its config file and the metrics in its log."""
    with open(input_config_file, 'r') as f:
        config = json.load(f)

    log_file = os.path.join(output_dir, 'output.log')
    agilex_sample_mse, agilex_sample_l2err, overall_avg_sample_mse, overall_avg_sample_l2err = extract_metrics_from_log(log_file)

    if None in [agilex_sample_mse, agilex_sample_l2err, overall_avg_sample_mse, overall_avg_sample_l2err]:
        print("Warning: some metrics are missing from the log file.")

    output_json = {
        "task_id": config.get("task_id"),
        "model_type": "RDT-170M",
        # Prefer a top-level "model_name"; otherwise fall back to train.model.
        "model_name": config.get("model_name") if "model_name" in config else config.get("train", {}).get("model"),
        "gpu_id": config.get("gpu_id"),
        "runtime": runtime,
        "log_path": log_file,
        "output_dir": output_dir,
        "model_path": os.path.join(output_dir, 'pytorch_model.bin'),
        "metrics": {
            "agilex_sample_mse": agilex_sample_mse,
            "agilex_sample_l2err": agilex_sample_l2err,
            "overall_avg_sample_mse": overall_avg_sample_mse,
            "overall_avg_sample_l2err": overall_avg_sample_l2err
        }
    }

    # Write output.json with pretty-printed, standards-compliant JSON;
    # missing metrics are serialized as null.
    output_json_path = os.path.join(output_dir, 'output.json')
    with open(output_json_path, 'w') as f:
        json.dump(output_json, f, indent=4, ensure_ascii=False)
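

# The resulting output.json has this shape (field values are illustrative):
# {
#     "task_id": "task_01",
#     "model_type": "RDT-170M",
#     "model_name": "rdt-170m-finetune",
#     "gpu_id": 0,
#     "runtime": "2h13m",
#     "log_path": ".../output.log",
#     "output_dir": "...",
#     "model_path": ".../pytorch_model.bin",
#     "metrics": {
#         "agilex_sample_mse": 0.0043,
#         "agilex_sample_l2err": 0.098,
#         "overall_avg_sample_mse": 0.0058,
#         "overall_avg_sample_l2err": 0.104
#     }
# }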


if __name__ == "__main__":
    if len(sys.argv) != 4:
        print("Usage: python generate_output_json.py <input_config_file> <output_dir> <runtime>")
        sys.exit(1)
    generate_output_json(sys.argv[1], sys.argv[2], sys.argv[3])
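
# Example invocation (paths and runtime string are illustrative):
#   python generate_output_json.py configs/task_01.json outputs/run_01 "2h13m"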