#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from copy import deepcopy
from pathlib import Path

import hydra
import torch
from omegaconf import DictConfig

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    init_logging,
    set_global_seed,
)
from lerobot.scripts.eval import eval_policy


def make_optimizer_and_scheduler(cfg, policy):
    """Build the optimizer and optional learning-rate scheduler for the configured policy.

    Dispatches on ``cfg.policy.name``.

    Returns:
        ``(optimizer, lr_scheduler)``. ``lr_scheduler`` is ``None`` for policies that don't use one.

    Raises:
        NotImplementedError: if ``cfg.policy.name`` has no optimizer recipe here.
    """
    policy_name = cfg.policy.name
    if policy_name == "act":
        # Parameters whose names start with "backbone" get their own learning rate
        # (cfg.training.lr_backbone). All other trainable parameters use cfg.training.lr.
        optimizer_params_dicts = [
            {
                "params": [
                    p
                    for n, p in policy.named_parameters()
                    if not n.startswith("backbone") and p.requires_grad
                ]
            },
            {
                "params": [
                    p
                    for n, p in policy.named_parameters()
                    if n.startswith("backbone") and p.requires_grad
                ],
                "lr": cfg.training.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(
            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
        )
        lr_scheduler = None
    elif policy_name == "diffusion":
        # Only the parameters of the `policy.diffusion` submodule are optimized.
        optimizer = torch.optim.Adam(
            policy.diffusion.parameters(),
            lr=cfg.training.lr,
            betas=cfg.training.adam_betas,
            eps=cfg.training.adam_eps,
            weight_decay=cfg.training.adam_weight_decay,
        )
        # Imported lazily so diffusers is only required when training the diffusion policy.
        from diffusers.optimization import get_scheduler

        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=cfg.training.lr_warmup_steps,
            num_training_steps=cfg.training.offline_steps,
        )
    elif policy_name == "tdmpc":
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None
    else:
        raise NotImplementedError(f"No optimizer/scheduler recipe for policy {policy_name!r}.")

    return optimizer, lr_scheduler


def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
    """Run one optimization step on ``batch`` and return a dictionary of items for logging.

    The step is: forward pass, backward on ``output_dict["loss"]``, gradient-norm clipping to
    ``grad_clip_norm``, optimizer step, gradient reset, optional scheduler step. For
    ``PolicyWithUpdate`` policies it also calls ``policy.update()``.

    Returns:
        A dict with "loss", "grad_norm" (the pre-clipping total norm), "lr" (of the first
        param group), "update_s" (wall-clock seconds), plus every non-"loss" entry from the
        policy's forward output.
    """
    start_time = time.time()
    policy.train()
    output_dict = policy.forward(batch)
    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
    loss = output_dict["loss"]
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )

    optimizer.step()
    optimizer.zero_grad()

    if lr_scheduler is not None:
        lr_scheduler.step()

    if isinstance(policy, PolicyWithUpdate):
        # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
        policy.update()

    info = {
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        "lr": optimizer.param_groups[0]["lr"],
        "update_s": time.time() - start_time,
        **{k: v for k, v in output_dict.items() if k != "loss"},
    }

    return info


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: DictConfig):
    """Hydra entry point: train using Hydra's run directory and job name."""
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    train(
        cfg,
        out_dir=hydra_cfg.run.dir,
        job_name=hydra_cfg.job.name,
    )


def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
    """Train from a notebook by composing the Hydra config manually.

    The global Hydra instance is cleared first, so this can be called repeatedly.
    """
    from hydra import compose, initialize

    hydra.core.global_hydra.GlobalHydra.instance().clear()
    initialize(config_path=config_path)
    cfg = compose(config_name=config_name)
    train(cfg, out_dir=out_dir, job_name=job_name)


def _training_progress(num_updates, cfg, dataset):
    """Return ``(num_samples, num_episodes, num_epochs)`` seen after ``num_updates`` policy updates.

    A sample is an (observation, action) pair, where the observation and action can span
    multiple timestamps. Each update consumes ``cfg.training.batch_size`` samples.
    ``num_episodes`` is an estimate based on the dataset's average samples per episode.
    """
    num_samples = num_updates * cfg.training.batch_size
    avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
    num_episodes = num_samples / avg_samples_per_ep
    num_epochs = num_samples / dataset.num_samples
    return num_samples, num_episodes, num_epochs


def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
    """Log training metrics to the terminal and to ``logger`` (mode="train").

    ``step`` is the 0-based index of the update that produced ``info``, so ``step + 1``
    updates have completed. ``info`` is mutated in place with the step/progress fields
    before being passed to ``logger.log_dict``.
    """
    loss = info["loss"]
    grad_norm = info["grad_norm"]
    lr = info["lr"]
    update_s = info["update_s"]

    num_samples, num_episodes, num_epochs = _training_progress(step + 1, cfg, dataset)

    log_items = [
        f"step:{format_big_number(step)}",
        # number of samples seen during training
        f"smpl:{format_big_number(num_samples)}",
        # number of episodes seen during training
        f"ep:{format_big_number(num_episodes)}",
        # number of times all unique samples are seen
        f"epch:{num_epochs:.2f}",
        f"loss:{loss:.3f}",
        f"grdn:{grad_norm:.3f}",
        f"lr:{lr:0.1e}",
        # in seconds
        f"updt_s:{update_s:.3f}",
    ]
    logging.info(" ".join(log_items))

    info["step"] = step
    info["num_samples"] = num_samples
    info["num_episodes"] = num_episodes
    info["num_epochs"] = num_epochs
    info["is_offline"] = is_offline

    logger.log_dict(info, step, mode="train")


def log_eval_info(logger, info, step, cfg, dataset, is_offline):
    """Log evaluation metrics to the terminal and to ``logger`` (mode="eval").

    ``step`` is the number of training updates completed when the evaluation ran; the
    training loop passes ``step + 1`` after the ``step``-th update. ``info`` is mutated in
    place with the step/progress fields before being passed to ``logger.log_dict``.
    """
    eval_s = info["eval_s"]
    avg_sum_reward = info["avg_sum_reward"]
    pc_success = info["pc_success"]

    # `step` already counts completed updates, so no +1 here (unlike log_train_info).
    num_samples, num_episodes, num_epochs = _training_progress(step, cfg, dataset)

    log_items = [
        f"step:{format_big_number(step)}",
        # number of samples seen during training
        f"smpl:{format_big_number(num_samples)}",
        # number of episodes seen during training
        f"ep:{format_big_number(num_episodes)}",
        # number of times all unique samples are seen
        f"epch:{num_epochs:.2f}",
        f"∑rwrd:{avg_sum_reward:.3f}",
        f"success:{pc_success:.1f}%",
        f"eval_s:{eval_s:.3f}",
    ]
    logging.info(" ".join(log_items))

    info["step"] = step
    info["num_samples"] = num_samples
    info["num_episodes"] = num_episodes
    info["num_epochs"] = num_epochs
    info["is_offline"] = is_offline

    logger.log_dict(info, step, mode="eval")


def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    """Train a policy offline on ``make_dataset(cfg)``, periodically evaluating and checkpointing.

    Raises:
        NotImplementedError: if ``out_dir`` or ``job_name`` is missing, or if
            ``cfg.training.online_steps > 0`` (online training is not implemented yet).
    """
    if out_dir is None:
        raise NotImplementedError("`out_dir` must be provided.")
    if job_name is None:
        raise NotImplementedError("`job_name` must be provided.")

    init_logging()

    if cfg.training.online_steps > 0:
        raise NotImplementedError("Online training is not implemented yet.")

    # Check device is available
    get_safe_torch_device(cfg.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_global_seed(cfg.seed)

    logging.info("make_dataset")
    offline_dataset = make_dataset(cfg)

    logging.info("make_env")
    eval_env = make_env(cfg)

    # Ensure the evaluation environment is closed even if anything below raises.
    try:
        logging.info("make_policy")
        policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)

        # Create optimizer and scheduler
        # Temporary hack to move optimizer out of policy
        optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

        num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
        num_total_params = sum(p.numel() for p in policy.parameters())

        # log metrics to terminal and wandb
        logger = Logger(out_dir, job_name, cfg)

        log_output_dir(out_dir)
        logging.info(f"{cfg.env.task=}")
        logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
        logging.info(f"{cfg.training.online_steps=}")
        logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
        logging.info(f"{offline_dataset.num_episodes=}")
        logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
        logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

        # Defined before the helper below, which reads it.
        is_offline = True

        # Checkpoints use the step as identifier, zero-padded to at least 6 digits, with more
        # digits if needed (6 is a minimum for consistency without being overkill).
        checkpoint_id_width = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))

        # Note: this helper will be used in offline and online training loops.
        def evaluate_and_checkpoint_if_needed(step):
            if step % cfg.training.eval_freq == 0:
                logging.info(f"Eval policy at step {step}")
                eval_info = eval_policy(
                    eval_env,
                    policy,
                    cfg.eval.n_episodes,
                    video_dir=Path(out_dir) / "eval",
                    max_episodes_rendered=4,
                    start_seed=cfg.seed,
                )
                log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
                # Guard against evaluations that rendered no video.
                if cfg.wandb.enable and eval_info["video_paths"]:
                    logger.log_video(eval_info["video_paths"][0], step, mode="eval")
                logging.info("Resume training")

            if cfg.training.save_model and step % cfg.training.save_freq == 0:
                logging.info(f"Checkpoint policy after step {step}")
                logger.save_model(policy, identifier=str(step).zfill(checkpoint_id_width))
                logging.info("Resume training")

        # create dataloader for offline training
        dataloader = torch.utils.data.DataLoader(
            offline_dataset,
            num_workers=4,
            batch_size=cfg.training.batch_size,
            shuffle=True,
            pin_memory=cfg.device != "cpu",
            drop_last=False,
        )
        dl_iter = cycle(dataloader)

        policy.train()
        for step in range(cfg.training.offline_steps):
            if step == 0:
                logging.info("Start offline training on a fixed dataset")
            batch = next(dl_iter)

            for key in batch:
                batch[key] = batch[key].to(cfg.device, non_blocking=True)

            train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)

            # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
            if step % cfg.training.log_freq == 0:
                log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)

            # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has
            # completed, so we pass in step + 1.
            evaluate_and_checkpoint_if_needed(step + 1)

        # TODO: online training (an online dataset concatenated with the offline one and sampled
        # with a WeightedRandomSampler) is not implemented. `train` raises NotImplementedError
        # above when cfg.training.online_steps > 0.
    finally:
        eval_env.close()

    logging.info("End of training")


if __name__ == "__main__":
    train_cli()