41 lines
1.5 KiB
Python
41 lines
1.5 KiB
Python
import os
|
|
import yaml
|
|
import argparse
|
|
from datetime import datetime
|
|
|
|
def build_finetune_config(model_name):
    """Return the finetune configuration mapping for ``model_name``.

    The mapping holds the model name, the derived training-data and
    checkpoint paths, and a fixed set of default hyperparameters. Key
    order is the order in which entries are written to the YAML file.

    Args:
        model_name: Name used for the ``model`` entry and as the last
            component of ``data_path`` and ``checkpoint_path``.

    Returns:
        A ``dict`` ready to be serialized with ``yaml.dump``.
    """
    finetune_data_path = os.path.join("training_data/", f"{model_name}")
    checkpoint_path = os.path.join("checkpoints/", f"{model_name}")
    return {
        "model": model_name,
        "data_path": finetune_data_path,
        "checkpoint_path": checkpoint_path,
        "pretrained_model_name_or_path": "../weights/RDT/rdt-1b",
        # Placeholder value; an earlier note here read "args.gpu_use", so this
        # was presumably meant to come from a CLI option -- TODO confirm.
        "cuda_visible_device": "...",
        "train_batch_size": 32,
        "sample_batch_size": 64,
        "max_train_steps": 20000,
        "checkpointing_period": 2500,
        "sample_period": 100,
        "checkpoints_total_limit": 40,
        "learning_rate": 1e-4,
        "dataloader_num_workers": 8,
        "state_noise_snr": 40,
        "gradient_accumulation_steps": 1,
    }


def main(argv=None):
    """Write ``model_config/<model_name>.yml`` and create the data directory.

    Args:
        argv: Optional list of command-line arguments. When ``None``,
            argparse reads ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Generate finetune config.")
    parser.add_argument("model_name", type=str, help="The name of the task (e.g., beat_block_hammer)")
    args = parser.parse_args(argv)
    model_name = args.model_name

    config = build_finetune_config(model_name)

    # Create the output directory first: open(..., "w") does not create
    # missing parent directories and would raise FileNotFoundError.
    config_dir = "model_config/"
    os.makedirs(config_dir, exist_ok=True)
    task_config_path = os.path.join(config_dir, f"{model_name}.yml")

    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    time_comment = f"# Generated on {current_time}\n"

    with open(task_config_path, "w", encoding="utf-8") as f:
        f.write(time_comment)
        # sort_keys=False keeps the entries in the order defined above.
        yaml.dump(config, f, default_flow_style=False, sort_keys=False)

    # exist_ok=True replaces the separate os.path.exists() check and avoids
    # failing if the directory appears between the check and the creation.
    os.makedirs(config["data_path"], exist_ok=True)


if __name__ == "__main__":
    main()