"""Update a training YAML config with values from a JSON task config.

Usage: python <this_script> <config.json> <config.yaml>
"""
import json
|
|
import yaml
|
|
import sys
|
|
|
|
def read_config(config_file, yaml_file):
    """Override training settings in a YAML config using a JSON task config.

    Loads ``config_file`` (JSON) and ``yaml_file`` (YAML), sets a fixed set of
    keys on the YAML mapping (see the assignments below), and writes the
    result back to ``yaml_file`` in place. Keys already present in the YAML
    file that are not overridden here are kept.

    Args:
        config_file: Path to the JSON file. It must contain ``task_id``,
            ``gpu_id`` and a ``train`` object with ``model``,
            ``input_data_path``, ``checkpoint_path``, ``batch_size`` and
            ``epochs``.
        yaml_file: Path to the YAML file to update. An empty file is treated
            as an empty mapping.

    Raises:
        KeyError: If a required key is missing from the JSON config.
        ValueError: If ``batch_size`` or ``epochs`` cannot be converted by
            ``int()``.
        TypeError: If the top level of the YAML document does not support key
            assignment (for example a list or a scalar).
    """
    with open(config_file, "r", encoding="utf-8") as f:
        json_config = json.load(f)
    with open(yaml_file, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader can construct more than plain data types;
        # only point this at trusted YAML, or switch to yaml.safe_load if the
        # config never uses Python-specific tags.
        yaml_config = yaml.load(f, Loader=yaml.FullLoader)
    if yaml_config is None:
        # yaml.load returns None for an empty document; start from an empty
        # mapping instead of failing on the first assignment below.
        yaml_config = {}

    train = json_config["train"]
    task_id = json_config["task_id"]

    yaml_config["model"] = train["model"] + task_id
    yaml_config["data_path"] = train["input_data_path"] + "/data"
    yaml_config["checkpoint_path"] = train["checkpoint_path"] + "/" + task_id
    yaml_config["pretrained_model_name_or_path"] = "/weights/rdt-1b"
    yaml_config["cuda_visible_device"] = str(json_config["gpu_id"])
    print(f"cuda_visible_device: {yaml_config['cuda_visible_device']}")

    batch_size = int(train["batch_size"])
    yaml_config["train_batch_size"] = batch_size
    yaml_config["sample_batch_size"] = batch_size * 2

    # The JSON "epochs" value is written out as the number of training steps.
    max_train_steps = int(train["epochs"])
    yaml_config["max_train_steps"] = max_train_steps
    # Checkpoint ten times over the run. Clamp to at least 1: with fewer than
    # 10 steps the quotient is 0, and a period of 0 is presumably invalid for
    # a consumer that tests `step % period` -- TODO confirm against trainer.
    yaml_config["checkpointing_period"] = max(1, max_train_steps // 10)
    yaml_config["sample_period"] = 200
    yaml_config["checkpoints_total_limit"] = 50

    # Serialize before opening for writing: opening with 'w' truncates the
    # file, so a serialization error must not leave it empty or partial.
    serialized = yaml.dump(yaml_config, default_flow_style=False)
    with open(yaml_file, "w", encoding="utf-8") as f:
        f.write(serialized)

    print("Config YAML file updated successfully")
|
|
|
|
if __name__ == "__main__":
    # Expect two positional arguments: the JSON task config and the YAML
    # config to update. Exit with a usage message (written to stderr, exit
    # status 1) instead of an IndexError traceback when they are missing.
    if len(sys.argv) < 3:
        sys.exit(f"Usage: {sys.argv[0]} <config.json> <config.yaml>")
    read_config(sys.argv[1], sys.argv[2])
|