Lerobot/lerobot/configs/policies.py
Pepijn 0b2285d1ec
Feat: Improve hub integration (#1382)
* feat(policies): Initial setup to push policies to hub with tags and model card

* feat: add dataset that is used to train

* Add model template summary

* fix: Update link model_card template

* fix: remove print

* fix: change import name

* fix: add model summary in template

* fix: minor text

* fix: comments Lucain

* fix: feedback steven

* fix: restructure push to hub

* fix: remove unneeded changes

* fix: import

* fix: import 2

* Add MANIFEST.in

* fix: feedback pr

* Fix tests

* tests: Add smolvla end-to-end test

* Fix: smolvla test

* fix test name

* fix policy tests

* Add push to hub false policy tests

* Do push to hub cleaner

* fix(ci): add push_to_hub false in tests

---------

Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-06-26 14:36:16 +02:00

191 lines
7.1 KiB
Python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Type, TypeVar
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.common.optim.optimizers import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
# Generic type variable bound to PreTrainedConfig (i.e. PreTrainedConfig itself or any subclass of it).
# It annotates `PreTrainedConfig.from_pretrained` (`cls: Type[T]` -> `T`) so that the classmethod is
# typed as returning an instance of the class it is called on.
T = TypeVar("T", bound="PreTrainedConfig")
@dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
    """
    Base configuration class for policy models.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        normalization_mapping: A dictionary with key representing the modality and the value specifies
            the normalization mode to apply.
        input_features: A dictionary mapping a name to the `PolicyFeature` describing one policy input.
            The `robot_state_feature`, `env_state_feature` and `image_features` properties select
            entries from it by `FeatureType`.
        output_features: A dictionary mapping a name to the `PolicyFeature` describing one policy
            output. The `action_feature` property selects the `FeatureType.ACTION` entry from it.
        device: Name of the torch device to use. When unset or unavailable, `__post_init__` replaces it
            with the automatically selected device.
        use_amp: Whether to use Automatic Mixed Precision (AMP). Reset to `False` in `__post_init__`
            when AMP is not available on the selected device.
        push_to_hub: Flag for pushing the policy to the Hugging Face hub (consumed outside this class).
        repo_id: Hub repository id (consumed outside this class).
        private: Upload on private repository on the Hugging Face hub.
        tags: Tags to add to the policy on the hub.
        license: License to add to the policy on the hub.
    """

    n_obs_steps: int = 1
    normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)

    input_features: dict[str, PolicyFeature] = field(default_factory=dict)
    output_features: dict[str, PolicyFeature] = field(default_factory=dict)

    device: str | None = None  # cuda | cpu | mps
    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
    # automatic gradient scaling is used.
    use_amp: bool = False

    push_to_hub: bool = True
    repo_id: str | None = None

    # Upload on private repository on the Hugging Face hub.
    private: bool | None = None
    # Add tags to your policy on the hub.
    tags: list[str] | None = None
    # Add license to your policy on the hub.
    license: str | None = None

    def __post_init__(self):
        """Resolve the runtime device and disable AMP when the device does not support it."""
        # Plain instance attribute (not a dataclass field); always starts as None.
        self.pretrained_path = None

        # Fall back to an automatically selected device when none is given or the requested one
        # is not available.
        if not self.device or not is_torch_device_available(self.device):
            auto_device = auto_select_torch_device()
            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
            self.device = auto_device.type

        # Automatically deactivate AMP if necessary
        if self.use_amp and not is_amp_available(self.device):
            logging.warning(
                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
            )
            self.use_amp = False

    @property
    def type(self) -> str:
        """Choice name of this config's class, as reported by `get_choice_name`."""
        return self.get_choice_name(self.__class__)

    @property
    @abc.abstractmethod
    def observation_delta_indices(self) -> list | None:
        """Abstract: subclasses return a list of observation delta indices, or None."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def action_delta_indices(self) -> list | None:
        """Abstract: subclasses return a list of action delta indices, or None."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def reward_delta_indices(self) -> list | None:
        """Abstract: subclasses return a list of reward delta indices, or None."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_optimizer_preset(self) -> OptimizerConfig:
        """Abstract: subclasses return the `OptimizerConfig` preset for this policy."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_scheduler_preset(self) -> LRSchedulerConfig | None:
        """Abstract: subclasses return the `LRSchedulerConfig` preset for this policy, or None."""
        raise NotImplementedError

    @abc.abstractmethod
    def validate_features(self) -> None:
        """Abstract: subclasses validate their configured features."""
        raise NotImplementedError

    @property
    def robot_state_feature(self) -> PolicyFeature | None:
        """First input feature whose type is `FeatureType.STATE`, or None if there is none."""
        for ft in self.input_features.values():
            if ft.type is FeatureType.STATE:
                return ft
        return None

    @property
    def env_state_feature(self) -> PolicyFeature | None:
        """First input feature whose type is `FeatureType.ENV`, or None if there is none."""
        for ft in self.input_features.values():
            if ft.type is FeatureType.ENV:
                return ft
        return None

    @property
    def image_features(self) -> dict[str, PolicyFeature]:
        """All input features whose type is `FeatureType.VISUAL`, keyed as in `input_features`."""
        return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}

    @property
    def action_feature(self) -> PolicyFeature | None:
        """First output feature whose type is `FeatureType.ACTION`, or None if there is none."""
        for ft in self.output_features.values():
            if ft.type is FeatureType.ACTION:
                return ft
        return None

    def _save_pretrained(self, save_directory: Path) -> None:
        """Dump this config as JSON (indent 4) to `save_directory / CONFIG_NAME`."""
        with open(save_directory / CONFIG_NAME, "w", encoding="utf-8") as f, draccus.config_type("json"):
            draccus.dump(self, f, indent=4)

    # NOTE: `typing.Type` is used below instead of the builtin `type[...]` on purpose: inside this class
    # body the name `type` is rebound to the `type` property defined above.
    @classmethod
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **policy_kwargs,
    ) -> T:
        """
        Load a config from a local directory or from a repository on the Hugging Face hub.

        Args:
            pretrained_name_or_path: Either an existing local directory expected to contain
                `CONFIG_NAME`, or anything else, which is treated as a hub repo id.
            force_download, resume_download, proxies, token, cache_dir, local_files_only, revision:
                Forwarded unchanged to `hf_hub_download`; unused for a local directory.
            **policy_kwargs: Only the `cli_overrides` key is read (default `[]`); it is passed to
                `draccus.parse` as `args`.

        Returns:
            The config parsed by `draccus.parse` from the resolved config file and the CLI overrides.

        Raises:
            FileNotFoundError: If downloading `CONFIG_NAME` from the hub fails with `HfHubHTTPError`.
        """
        model_id = str(pretrained_name_or_path)
        config_file: str | None = None
        if Path(model_id).is_dir():
            if CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, CONFIG_NAME)
            else:
                # Best effort: parsing still proceeds below with `config_file=None`. Reported through
                # `logging` (like the other diagnostics of this class) rather than `print`.
                logging.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
                ) from e

        # HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
        # something like --policy.path (in addition to --policy.type)
        cli_overrides = policy_kwargs.pop("cli_overrides", [])
        with draccus.config_type("json"):
            return draccus.parse(cls, config_file, args=cli_overrides)