import json import yaml import sys def read_config(config_file, yaml_file): with open(config_file, 'r') as f: json_config = json.load(f) with open(yaml_file, 'r') as f: yaml_config = yaml.load(f, Loader=yaml.FullLoader) yaml_config["model"] = json_config["train"]["model"] + json_config["task_id"] yaml_config["data_path"] = json_config["train"]["input_data_path"] + "/data" yaml_config["checkpoint_path"] = json_config["train"]["checkpoint_path"] + "/" + json_config["task_id"] yaml_config["pretrained_model_name_or_path"] = json_config["train"]["input_data_path"] + "/weights/rdt-1b" yaml_config["cuda_visible_device"] = str(json_config["gpu_id"]) print(f"cuda_visible_device: {yaml_config['cuda_visible_device']}") yaml_config["train_batch_size"] = int(json_config["train"]["batch_size"]) yaml_config["sample_batch_size"] = int(json_config["train"]["batch_size"]) * 2 yaml_config["max_train_steps"] = int(json_config["train"]["epochs"]) yaml_config["checkpointing_period"] = int(int(json_config["train"]["epochs"]) / 10) yaml_config["sample_period"] = 200 yaml_config["checkpoints_total_limit"] = 50 with open(yaml_file, 'w') as f: yaml.dump(yaml_config, f, default_flow_style=False) print("Config YAML file updated successfully") if __name__ == "__main__": read_config(sys.argv[1], sys.argv[2])