"""Launch a LeRobot training run from the embolab build-task config.

Reads the task config JSON, validates the ``train`` section, converts the
requested number of epochs into a step budget (``lerobot-train`` is driven by
``--steps``, not epochs), then runs ``lerobot-train`` as a child process and
exits with its return code.
"""

import json
import os
import shlex
import subprocess
import sys

CONFIG_PATH = "/workspace/embolab/params/build_task.json"
SUPPORTED_MODELS = ("act", "smolvla")


def _steps_per_epoch(total_frames, batch_size):
    """Return how many batches are needed to cover ``total_frames`` once.

    Computes ceil(total_frames / batch_size) with exact integer arithmetic,
    floored at 1 so the derived ``--save_freq`` is never zero.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")
    # -(-a // b) is integer ceil division; the previous `a // b + 1`
    # overcounted by one whenever a was an exact multiple of b.
    return max(1, -(-total_frames // batch_size))


with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    task_configs = json.load(f)

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Lerobot supports only one GPU for training
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Explicit checks instead of `assert`, which is stripped under `python -O`.
if "train" not in task_configs:
    raise ValueError("Not a valid train config: missing 'train' section")
train_cfg = task_configs["train"]
if train_cfg["model"] not in SUPPORTED_MODELS:
    raise ValueError(
        "Only act and smolvla are supported for training, "
        f"got {train_cfg['model']!r}"
    )

# Train the selected architecture from scratch. To fine-tune SmolVLA from its
# published base weights instead, pass "--policy.path=lerobot/smolvla_base".
use_policy = f"--policy.type={train_cfg['model']}"

task_id = task_configs["task_id"]
data_path = train_cfg["input_data_path"]
ckpt_path = train_cfg["checkpoint_path"]
bs = train_cfg["batch_size"]
epochs = train_cfg["epochs"]

use_resume = train_cfg.get("resume", False)
if use_resume:
    # Load policy weights from <checkpoint_path>/pretrained_model.
    # Presumably checkpoint_path then points at a saved checkpoint directory,
    # e.g. ${checkpoint_path}/checkpoints/last — confirm against the config.
    # NOTE(review): this uses --policy.path rather than lerobot-train's
    # resume option; confirm it gives the intended resume semantics.
    policy_arg = f"--policy.path={os.path.join(ckpt_path, 'pretrained_model')}"
else:
    policy_arg = use_policy

with open(os.path.join(data_path, "meta", "info.json"), "r", encoding="utf-8") as f:
    dataset_info = json.load(f)
total_frames = dataset_info["total_frames"]

steps_per_epoch = _steps_per_epoch(total_frames, bs)
steps = steps_per_epoch * epochs
print(
    "Lerobot only support steps, calculating steps from epochs...",
    f"Steps per epoch: {steps_per_epoch}, Total steps: {steps}",
)

# Build argv as a list and run without a shell: config-supplied paths are
# passed verbatim (spaces/metacharacters cannot split or inject arguments).
train_cmd = [
    "lerobot-train",
    policy_arg,
    "--policy.push_to_hub=false",
    f"--dataset.repo_id=D-Robotics/{task_id}",
    f"--dataset.root={data_path}",
    f"--batch_size={bs}",
    f"--output_dir={ckpt_path}",
    f"--steps={steps}",
    # Save a checkpoint once per epoch.
    f"--save_freq={steps_per_epoch}",
]

print("Executing command:\n", shlex.join(train_cmd))

# The child inherits os.environ, including the variables set above.
completed = subprocess.run(train_cmd)
sys.exit(completed.returncode)