d-robotics-rdt/RDT-1B/scripts/read_config.py
2025-10-26 16:23:28 +08:00

32 lines
1.3 KiB
Python

import json
import yaml
import sys
def read_config(config_file, yaml_file):
    """Merge training settings from a JSON job config into a YAML config file.

    Despite the name, this function rewrites ``yaml_file`` in place. It loads
    the YAML mapping, overwrites a fixed set of keys with values derived from
    the JSON job config (plus a few hard-coded defaults), and dumps the result
    back to the same path.

    Args:
        config_file: Path to the JSON job config. Keys read: ``task_id``,
            ``gpu_id``, and under ``train``: ``model``, ``input_data_path``,
            ``checkpoint_path``, ``batch_size``, ``epochs``.
        yaml_file: Path to the YAML config. It is read, then overwritten.

    Raises:
        OSError: if either file cannot be opened.
        KeyError: if a required key is missing from the JSON config.
        ValueError / TypeError: if ``train.batch_size`` or ``train.epochs``
            cannot be converted with ``int()``.
    """
    with open(config_file, 'r', encoding='utf-8') as f:
        json_config = json.load(f)
    # NOTE(review): FullLoader is kept so that anything previously written by
    # the yaml.dump call below still round-trips. Prefer yaml.safe_load if this
    # file can ever come from an untrusted source.
    with open(yaml_file, 'r', encoding='utf-8') as f:
        yaml_config = yaml.load(f, Loader=yaml.FullLoader)

    train = json_config["train"]
    # task_id may be numeric in the JSON. str() avoids a TypeError when it is
    # concatenated below, the same way gpu_id is already handled.
    task_id = str(json_config["task_id"])
    batch_size = int(train["batch_size"])
    epochs = int(train["epochs"])

    # Model name and task id are joined with no separator, as in the original.
    yaml_config["model"] = train["model"] + task_id
    yaml_config["data_path"] = train["input_data_path"] + "/data"
    yaml_config["checkpoint_path"] = train["checkpoint_path"] + "/" + task_id
    yaml_config["pretrained_model_name_or_path"] = "/weights/rdt-1b"
    yaml_config["cuda_visible_device"] = str(json_config["gpu_id"])
    print(f"cuda_visible_device: {yaml_config['cuda_visible_device']}")
    yaml_config["train_batch_size"] = batch_size
    yaml_config["sample_batch_size"] = batch_size * 2
    # The JSON "epochs" value is used directly as the training step budget.
    yaml_config["max_train_steps"] = epochs
    # Aim for roughly 10 checkpoints per run. Floor at 1: the previous
    # int(epochs / 10) produced 0 for epochs < 10. This is presumably consumed
    # as an "every N steps" interval by the trainer (TODO confirm), where a
    # period of 0 would break a modulo check. For epochs >= 10 the value is
    # unchanged.
    yaml_config["checkpointing_period"] = max(1, epochs // 10)
    yaml_config["sample_period"] = 200
    yaml_config["checkpoints_total_limit"] = 50

    with open(yaml_file, 'w', encoding='utf-8') as f:
        yaml.dump(yaml_config, f, default_flow_style=False)
    print("Config YAML file updated successfully")
if __name__ == "__main__":
    # CLI: <script> <config_file.json> <yaml_file.yaml>; extra args are ignored.
    # Fail with a usage message (exit status 1) instead of a bare IndexError
    # when arguments are missing.
    if len(sys.argv) < 3:
        sys.exit(f"Usage: {sys.argv[0]} <config_file.json> <yaml_file.yaml>")
    read_config(sys.argv[1], sys.argv[2])