Remove all test code to save space

今日营业中... 2025-10-21 14:44:47 +08:00
parent 1e7bb40565
commit 026bd915b1
110 changed files with 4 additions and 17277 deletions

.gitignore vendored
View File

@@ -181,3 +181,6 @@ s100
huggingface_models
docker/inputs
docker/outputs
# Skip big files in tests folder
tests
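
Note: with `tests` ignored, anything under that folder will no longer be picked up by `git add`. To confirm which pattern matches a given path, `git check-ignore -v` prints the source file, line number, and matching pattern; a minimal sketch via Python's subprocess (the artifact path below is hypothetical):

import subprocess

# Ask git which ignore rule (if any) matches the path; -v prints "source:line:pattern<TAB>path".
result = subprocess.run(
    ["git", "check-ignore", "-v", "tests/artifacts/example.safetensors"],
    capture_output=True,
    text=True,
    check=False,  # exit code 1 simply means "not ignored"
)
print(result.stdout)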

View File

@@ -1,13 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9dc9df05797dc0e7b92edc845caab2e4c37c3cfcabb4ee6339c67212b5baba3b
size 38023
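
The deleted artifacts above and below are Git LFS pointer files rather than the binaries themselves: three `key value` lines recording the spec version, a SHA-256 of the true content, and its size in bytes. A minimal parser for that three-line layout, using the pointer above as input:

def parse_lfs_pointer(text: str) -> dict[str, str]:
    # Each pointer line is "key value"; the oid value carries its hash algorithm as a prefix.
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {"version": fields["version"], "algo": algo, "digest": digest, "size": fields["size"]}

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:9dc9df05797dc0e7b92edc845caab2e4c37c3cfcabb4ee6339c67212b5baba3b
size 38023"""

parsed = parse_lfs_pointer(pointer)
assert parsed["algo"] == "sha256" and parsed["size"] == "38023"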

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e11af87616b83c1cdb30330e951b91e86b51c64a1326e1ba5b4a3fbcdec1a11
size 55698

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b8840fb643afe903191248703b1f95a57faf5812ecd9978ac502ee939646fdb2
size 121115

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f79d14daafb1c0cf2fec5d46ee8029a73fe357402fdd31a7cd4a4794d7319a7c
size 260367

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a8d6e64d6cb0e02c94ae125630ee758055bd2e695772c0463a30d63ddc6c5e17
size 3520862

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6bdf22208d49cd36d24bc844d4d8bda5e321eafe39d2b470e4fc95c7812fdb24
size 3687117

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8920d5ebab36ffcba9aa74dcd91677c121f504b4d945b472352d379f9272fabf
size 3687117

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:35723f2db499da3d9d121aa79d2ff4c748effd7c2ea92f277ec543a82fb843ca
size 3687117

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:53172b773d4a78bb3140f10280105c2c4ebcb467f3097579988d42cb87790ab9
size 3687117

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:58a5d91573e7dd2352a1454a5c9118c9ad3798428a0104e5e0b57fc01f780ae7
size 3687117

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bb65a25e989a32a8b6258d368bd077e4548379c74ab5ada01cc532d658670df0
size 3687117

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c3dcff0a705ebfdaf11b7f49ad85b464eff03477ace3d63ce45d6a3a10b429d5
size 111338

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d8ab0274761cdd758bafdf274ce3e6398cd6f0df23393971f3e1b6b465d66ef3
size 111338

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aee60956925da9687546aafa770d5e6a04f99576f903b08d0bd5f8003a7f4f3e
size 111338

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c8d9f9cc9e232820760fe4a46b47000c921fa5d868420e55d8dbc05dae56e8bd
size 111338

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:01cfe50c537e3aef0cd5947ec0b15b321b54ecb461baf7b4f2506897158eebc8
size 111338

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:96431ca3479eef2379406ef901cad7ba5eac4f7edcc48ecc9e8d1fa0e99d8017
size 111338

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3763d7bff7873cb40ea9d6f2f98d45fcf163addcd2809b6c59f273b6c3627ad5
size 85353

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:24150994c6959631dc081b43e4001a8664e13b194ac194a32100f7d3fd2c0d0f
size 85353

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c9c3fdf34debe47d4b80570a19e676185449df749f37daa2111184c1f439ae5f
size 85353

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f8cfbe444c14d643da2faea9f6a402ddb37114ab15395c381f1a7982e541f868
size 85353

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:07c5c1a63998884ee747a6d0aa8f49217da3c32af2760dad2a9da794d3517003
size 85353

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9927ec508e3335f8b10cf3682e41dedb7e647f92a2063a4196f1e48749c47bc5
size 85353

View File

@@ -1,91 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility
when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frames from the
dataset into corresponding safetensors files in a specified output directory.
If you know that your change will break backward compatibility, you should write a short-lived test by modifying
`tests/test_datasets.py::test_backward_compatibility` accordingly, and make sure this custom test passes. Your custom test
doesn't need to be merged into the `main` branch. Then you need to run this script and update the test artifacts.
Example usage:
`python tests/artifacts/datasets/save_dataset_to_safetensors.py`
"""
import shutil
from pathlib import Path
from safetensors.torch import save_file
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
repo_dir = Path(output_dir) / repo_id
if repo_dir.exists():
shutil.rmtree(repo_dir)
repo_dir.mkdir(parents=True, exist_ok=True)
dataset = LeRobotDataset(
repo_id=repo_id,
episodes=[0],
)
# save the first 2 frames of the first episode
i = dataset.episode_data_index["from"][0].item()
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
# save 2 frames from the middle of the first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
# save the last 2 frames of the first episode
i = dataset.episode_data_index["to"][0].item()
save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")
# TODO(rcadene): Enable testing on second and last episode
# We currently can't because our test dataset only contains the first episode
# # save 2 first frames of second episode
# i = dataset.episode_data_index["from"][1].item()
# save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
# save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
# # save 2 last frames of second episode
# i = dataset.episode_data_index["to"][1].item()
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
# # save 2 last frames of last episode
# i = dataset.episode_data_index["to"][-1].item()
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
if __name__ == "__main__":
for dataset in [
"lerobot/pusht",
"lerobot/aloha_sim_insertion_human",
"lerobot/xarm_lift_medium",
"lerobot/nyu_franka_play_dataset",
"lerobot/cmu_stretch",
]:
save_dataset_to_safetensors("tests/artifacts/datasets", repo_id=dataset)
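
For context, these artifacts are consumed by `tests/test_datasets.py::test_backward_compatibility` (referenced in the docstring above). A minimal sketch of that consuming side, assuming each saved file maps frame keys to tensors as written by `save_file`:

from pathlib import Path

import torch
from safetensors.torch import load_file

from lerobot.datasets.lerobot_dataset import LeRobotDataset


def assert_frames_unchanged(repo_id: str, artifact_dir: str = "tests/artifacts/datasets") -> None:
    dataset = LeRobotDataset(repo_id=repo_id, episodes=[0])
    for path in sorted(Path(artifact_dir, repo_id).glob("frame_*.safetensors")):
        i = int(path.stem.removeprefix("frame_"))  # recover the frame index from the filename
        saved = load_file(path)  # dict[str, torch.Tensor]
        frame = dataset[i]
        for key, expected in saved.items():
            torch.testing.assert_close(frame[key], expected)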

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6b1e600768a8771c5fe650e038a1193597e3810f032041b2a0d021e4496381c1
size 3686488

View File

@@ -1,75 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import torch
from safetensors.torch import save_file
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.transforms import (
ImageTransformConfig,
ImageTransforms,
ImageTransformsConfig,
make_transform_from_config,
)
from lerobot.utils.random_utils import seeded_context
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
DATASET_REPO_ID = "lerobot/aloha_static_cups_open"
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
cfg = ImageTransformsConfig(enable=True)
default_tf = ImageTransforms(cfg)
with seeded_context(1337):
img_tf = default_tf(original_frame)
save_file({"default": img_tf}, output_dir / "default_transforms.safetensors")
def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
transforms = {
("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]),
("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]),
("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]),
("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]),
("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]),
}
frames = {"original_frame": original_frame}
for tf_type, tf_name, min_max_values in transforms:
for min_max in min_max_values:
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
tf = make_transform_from_config(tf_cfg)
key = f"{tf_name}_{min_max[0]}_{min_max[1]}"
frames[key] = tf(original_frame)
save_file(frames, output_dir / "single_transforms.safetensors")
def main():
dataset = LeRobotDataset(DATASET_REPO_ID, episodes=[0], image_transforms=None)
output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
original_frame = dataset[0][dataset.meta.camera_keys[0]]
save_single_transforms(original_frame, output_dir)
save_default_config_transform(original_frame, output_dir)
if __name__ == "__main__":
main()
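
The saved artifacts are only comparable across runs because `seeded_context(1337)` pins the RNG state around the random transform. A minimal illustration of the property this relies on, assuming `seeded_context` seeds the global RNGs for the duration of the block (as its use above implies):

import torch

from lerobot.utils.random_utils import seeded_context

with seeded_context(1337):
    a = torch.rand(3)
with seeded_context(1337):
    b = torch.rand(3)
torch.testing.assert_close(a, b)  # identical draws under the same seed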

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9d4ebab73eabddc58879a4e770289d19e00a1a4cf2fa5fa33cd3a3246992bc90
size 40551392

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
size 5104

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1a7a8b1a457149109f843c32bcbb047d09de2201847b9b79f7501b447f77ecf4
size 31672

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5e6ce85296b2009e7c2060d336c0429b1c7197d9adb159e7df0ba18003067b36
size 68

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
size 33400

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
size 515400

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:224b5fa4828aa88171b68c036e8919c1eae563e2113f03b6461eadf5bf8525a6
size 31672

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:016d2fa8fe5f58017dfd46f4632fdc19dfd751e32a2c7cde2077c6f95546d6bd
size 68

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
size 33400

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c
size 992

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201
size 47424

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
size 68

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22
size 49120

View File

@@ -1,145 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from pathlib import Path
import torch
from safetensors.torch import save_file
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_policy_config
from lerobot.utils.random_utils import set_seed
def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
set_seed(1337)
train_cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
)
train_cfg.validate() # Needed for auto-setting some parameters
dataset = make_dataset(train_cfg)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=train_cfg.batch_size,
shuffle=False,
)
batch = next(iter(dataloader))
loss, output_dict = policy.forward(batch)
if output_dict is not None:
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
output_dict["loss"] = loss
else:
output_dict = {"loss": loss}
loss.backward()
grad_stats = {}
for key, param in policy.named_parameters():
if param.requires_grad:
grad_stats[f"{key}_mean"] = param.grad.mean()
grad_stats[f"{key}_std"] = (
param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0))
)
optimizer.step()
param_stats = {}
for key, param in policy.named_parameters():
param_stats[f"{key}_mean"] = param.mean()
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
optimizer.zero_grad()
policy.reset()
# HACK: We reload a batch with no delta_indices as `select_action` won't expect a timestamps dimension
# We simulate having an environment using a dataset by setting delta_indices to None and dropping tensors
# indicating padding (those ending with "_is_pad")
dataset.delta_indices = None
batch = next(iter(dataloader))
obs = {}
for k in batch:
# TODO: regenerate the safetensors
# for backward compatibility
if k.endswith("_is_pad"):
continue
# for backward compatibility
if k == "task":
continue
if k.startswith("observation"):
obs[k] = batch[k]
if hasattr(train_cfg.policy, "n_action_steps"):
actions_queue = train_cfg.policy.n_action_steps
else:
actions_queue = train_cfg.policy.n_action_repeats
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
return output_dict, grad_stats, param_stats, actions
def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
if output_dir.exists():
print(f"Overwrite existing safetensors in '{output_dir}':")
print(f" - Validate with: `git add {output_dir}`")
print(f" - Revert with: `git checkout -- {output_dir}`")
shutil.rmtree(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
save_file(output_dict, output_dir / "output_dict.safetensors")
save_file(grad_stats, output_dir / "grad_stats.safetensors")
save_file(param_stats, output_dir / "param_stats.safetensors")
save_file(actions, output_dir / "actions.safetensors")
if __name__ == "__main__":
artifacts_cfg = [
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
(
"lerobot/pusht",
"diffusion",
{
"n_action_steps": 8,
"num_inference_steps": 10,
"down_dims": [128, 256, 512],
},
"",
),
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
(
"lerobot/aloha_sim_insertion_human",
"act",
{"n_action_steps": 1000, "chunk_size": 1000},
"1000_steps",
),
]
if len(artifacts_cfg) == 0:
raise RuntimeError("No policies were provided!")
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
ds_name = ds_repo_id.split("/")[-1]
output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)
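
The regression test consuming these files would rerun `get_policy_stats` and compare each stat tensor against its saved counterpart. A minimal sketch of that comparison, assuming the four file names written above:

from pathlib import Path

import torch
from safetensors.torch import load_file


def compare_policy_artifacts(output_dir: Path, current: dict[str, dict[str, torch.Tensor]]) -> None:
    # `current` maps artifact names ("output_dict", "grad_stats", ...) to freshly computed tensors.
    for name, tensors in current.items():
        saved = load_file(output_dir / f"{name}.safetensors")
        assert saved.keys() == tensors.keys(), f"key mismatch in {name}"
        for key in saved:
            torch.testing.assert_close(tensors[key], saved[key])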

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b
size 200

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
size 16904

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
size 164

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
size 36312

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb
size 200

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
size 16904

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
size 164

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
size 36312

View File

@@ -1,177 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end test of the asynchronous inference stack (client ↔ server).
This test spins up a lightweight gRPC `PolicyServer` instance with a stubbed
policy network and launches a `RobotClient` that uses a `MockRobot`. The goal
is to exercise the full communication loop:
1. Client sends policy specification Server
2. Client streams observations Server
3. Server streams action chunks Client
4. Client executes received actions
The test succeeds if at least one action is executed and the server records at
least one predicted timestep, demonstrating that the gRPC round-trip works
end-to-end using real (but lightweight) protocol messages.
"""
from __future__ import annotations
import threading
from concurrent import futures
import pytest
import torch
# Skip entire module if grpc is not available
pytest.importorskip("grpc")
# -----------------------------------------------------------------------------
# End-to-end test
# -----------------------------------------------------------------------------
def test_async_inference_e2e(monkeypatch):
"""Tests the full asynchronous inference pipeline."""
# Import grpc-dependent modules inside the test function
import grpc
from lerobot.robots.utils import make_robot_from_config
from lerobot.scripts.server.configs import PolicyServerConfig, RobotClientConfig
from lerobot.scripts.server.helpers import map_robot_keys_to_lerobot_features
from lerobot.scripts.server.policy_server import PolicyServer
from lerobot.scripts.server.robot_client import RobotClient
from lerobot.transport import (
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from tests.mocks.mock_robot import MockRobotConfig
# Create a stub policy similar to test_policy_server.py
class MockPolicy:
"""A minimal mock for an actual policy, returning zeros."""
class _Config:
robot_type = "dummy_robot"
@property
def image_features(self):
"""Empty image features since this test doesn't use images."""
return {}
def __init__(self):
self.config = self._Config()
def to(self, *args, **kwargs):
return self
def model(self, batch):
# Return a chunk of 20 dummy actions.
batch_size = len(batch["robot_type"])
return torch.zeros(batch_size, 20, 6)
# ------------------------------------------------------------------
# 1. Create PolicyServer instance with mock policy
# ------------------------------------------------------------------
policy_server_config = PolicyServerConfig(host="localhost", port=9999)
policy_server = PolicyServer(policy_server_config)
# Replace the real policy with our fast, deterministic stub.
policy_server.policy = MockPolicy()
policy_server.actions_per_chunk = 20
policy_server.device = "cpu"
# Set up robot config and features
robot_config = MockRobotConfig()
mock_robot = make_robot_from_config(robot_config)
lerobot_features = map_robot_keys_to_lerobot_features(mock_robot)
policy_server.lerobot_features = lerobot_features
# Force server to produce deterministic action chunks in test mode
policy_server.policy_type = "act"
def _fake_get_action_chunk(_self, _obs, _type="test"):
action_dim = 6
batch_size = 1
actions_per_chunk = policy_server.actions_per_chunk
return torch.zeros(batch_size, actions_per_chunk, action_dim)
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
# Bypass potentially heavy model loading inside SendPolicyInstructions
def _fake_send_policy_instructions(self, request, context): # noqa: N802
return services_pb2.Empty()
monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)
# Build gRPC server running a PolicyServer
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
# Use the host/port specified in the server config
server_address = f"{policy_server.config.host}:{policy_server.config.port}"
server.add_insecure_port(server_address)
server.start()
# ------------------------------------------------------------------
# 2. Create a RobotClient around the MockRobot
# ------------------------------------------------------------------
client_config = RobotClientConfig(
server_address=server_address,
robot=robot_config,
chunk_size_threshold=0.0,
policy_type="test",
pretrained_name_or_path="test",
actions_per_chunk=20,
verify_robot_cameras=False,
)
client = RobotClient(client_config)
assert client.start(), "Client failed initial handshake with the server"
# Track action chunks received without modifying RobotClient
action_chunks_received = {"count": 0}
original_aggregate = client._aggregate_action_queues
def counting_aggregate(*args, **kwargs):
action_chunks_received["count"] += 1
return original_aggregate(*args, **kwargs)
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
# Start client threads
action_thread = threading.Thread(target=client.receive_actions, daemon=True)
control_thread = threading.Thread(target=client.control_loop, args=({"task": ""},), daemon=True)
action_thread.start()
control_thread.start()
# ------------------------------------------------------------------
# 3. System exchanges a few messages
# ------------------------------------------------------------------
# Wait for 5 seconds
server.wait_for_termination(timeout=5)
assert action_chunks_received["count"] > 0, "Client did not receive any action chunks"
assert len(policy_server._predicted_timesteps) > 0, "Server did not record any predicted timesteps"
# ------------------------------------------------------------------
# 4. Stop the system
# ------------------------------------------------------------------
client.stop()
action_thread.join()
control_thread.join()
policy_server.stop()
server.stop(grace=None)

View File

@@ -1,459 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import pickle
import time
import numpy as np
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.scripts.server.helpers import (
FPSTracker,
TimedAction,
TimedObservation,
observations_similar,
prepare_image,
prepare_raw_observation,
raw_observation_to_observation,
resize_robot_observation_image,
)
# ---------------------------------------------------------------------
# FPSTracker
# ---------------------------------------------------------------------
def test_fps_tracker_first_observation():
"""First observation should initialize timestamp and return 0 FPS."""
tracker = FPSTracker(target_fps=30.0)
timestamp = 1000.0
metrics = tracker.calculate_fps_metrics(timestamp)
assert tracker.first_timestamp == timestamp
assert tracker.total_obs_count == 1
assert metrics["avg_fps"] == 0.0
assert metrics["target_fps"] == 30.0
def test_fps_tracker_single_interval():
"""Two observations 1 second apart should give 1 FPS."""
tracker = FPSTracker(target_fps=30.0)
# First observation at t=0
metrics1 = tracker.calculate_fps_metrics(0.0)
assert metrics1["avg_fps"] == 0.0
# Second observation at t=1 (1 second later)
metrics2 = tracker.calculate_fps_metrics(1.0)
expected_fps = 1.0 # (2-1) observations / 1.0 seconds = 1 FPS
assert math.isclose(metrics2["avg_fps"], expected_fps, rel_tol=1e-6)
def test_fps_tracker_multiple_intervals():
"""Multiple observations should calculate correct average FPS."""
tracker = FPSTracker(target_fps=30.0)
# Simulate 5 observations over 2 seconds (should be 2 FPS average)
timestamps = [0.0, 0.5, 1.0, 1.5, 2.0]
for i, ts in enumerate(timestamps):
metrics = tracker.calculate_fps_metrics(ts)
if i == 0:
assert metrics["avg_fps"] == 0.0
elif i == len(timestamps) - 1:
# After 5 observations over 2 seconds: (5-1)/2 = 2 FPS
expected_fps = 2.0
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
def test_fps_tracker_irregular_intervals():
"""FPS calculation should work with irregular time intervals."""
tracker = FPSTracker(target_fps=30.0)
# Irregular timestamps: 0, 0.1, 0.5, 2.0, 3.0 seconds
timestamps = [0.0, 0.1, 0.5, 2.0, 3.0]
for ts in timestamps:
metrics = tracker.calculate_fps_metrics(ts)
# 5 observations over 3 seconds: (5-1)/3 = 1.333... FPS
expected_fps = 4.0 / 3.0
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
# ---------------------------------------------------------------------
# TimedData helpers
# ---------------------------------------------------------------------
def test_timed_action_getters():
"""TimedAction stores & returns timestamp, action tensor and timestep."""
ts = time.time()
action = torch.arange(10)
ta = TimedAction(timestamp=ts, action=action, timestep=0)
assert math.isclose(ta.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
torch.testing.assert_close(ta.get_action(), action)
assert ta.get_timestep() == 0
def test_timed_observation_getters():
"""TimedObservation stores & returns timestamp, dict and timestep."""
ts = time.time()
obs_dict = {"observation.state": torch.ones(6)}
to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0)
assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
assert to.get_observation() is obs_dict
assert to.get_timestep() == 0
def test_timed_data_deserialization_data_getters():
"""TimedAction / TimedObservation survive a round-trip through ``pickle``.
The async-inference stack uses ``pickle.dumps`` to move these objects across
the gRPC boundary (see RobotClient.send_observation and PolicyServer.StreamActions).
This test ensures that the payload keeps its content intact after
the (de)serialization round-trip.
"""
ts = time.time()
# ------------------------------------------------------------------
# TimedAction
# ------------------------------------------------------------------
original_action = torch.randn(6)
ta_in = TimedAction(timestamp=ts, action=original_action, timestep=13)
# Serialize → bytes → deserialize
ta_bytes = pickle.dumps(ta_in) # nosec
ta_out: TimedAction = pickle.loads(ta_bytes) # nosec B301
# Identity & content checks
assert math.isclose(ta_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
assert ta_out.get_timestep() == 13
torch.testing.assert_close(ta_out.get_action(), original_action)
# ------------------------------------------------------------------
# TimedObservation
# ------------------------------------------------------------------
obs_dict = {"observation.state": torch.arange(4).float()}
to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True)
to_bytes = pickle.dumps(to_in) # nosec
to_out: TimedObservation = pickle.loads(to_bytes) # nosec B301
assert math.isclose(to_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
assert to_out.get_timestep() == 7
assert to_out.must_go is True
assert to_out.get_observation().keys() == obs_dict.keys()
torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"])
# ---------------------------------------------------------------------
# observations_similar()
# ---------------------------------------------------------------------
def _make_obs(state: torch.Tensor) -> TimedObservation:
"""Create a TimedObservation with raw robot observation format."""
return TimedObservation(
timestamp=time.time(),
observation={
"shoulder": state[0].item() if len(state) > 0 else 0.0,
"elbow": state[1].item() if len(state) > 1 else 0.0,
"wrist": state[2].item() if len(state) > 2 else 0.0,
"gripper": state[3].item() if len(state) > 3 else 0.0,
},
timestep=0,
)
def test_observations_similar_true():
"""Distance below atol → observations considered similar."""
# Create mock lerobot features for the similarity check
lerobot_features = {
"observation.state": {
"dtype": "float32",
"shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"],
}
}
obs1 = _make_obs(torch.zeros(4))
obs2 = _make_obs(0.5 * torch.ones(4))
assert observations_similar(obs1, obs2, lerobot_features, atol=2.0)
obs3 = _make_obs(2.0 * torch.ones(4))
assert not observations_similar(obs1, obs3, lerobot_features, atol=2.0)
# ---------------------------------------------------------------------
# raw_observation_to_observation and helpers
# ---------------------------------------------------------------------
def _create_mock_robot_observation():
"""Create a mock robot observation with motor positions and camera images."""
return {
"shoulder": 1.0,
"elbow": 2.0,
"wrist": 3.0,
"gripper": 0.5,
"laptop": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
"phone": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
}
def _create_mock_lerobot_features():
"""Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
return {
"observation.state": {
"dtype": "float32",
"shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"],
},
"observation.images.laptop": {
"dtype": "image",
"shape": [480, 640, 3],
"names": ["height", "width", "channels"],
},
"observation.images.phone": {
"dtype": "image",
"shape": [480, 640, 3],
"names": ["height", "width", "channels"],
},
}
def _create_mock_policy_image_features():
"""Create mock policy image features with different resolutions."""
return {
"observation.images.laptop": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224), # Policy expects smaller resolution
),
"observation.images.phone": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 160, 160), # Different resolution for second camera
),
}
def test_prepare_image():
"""Test image preprocessing: int8 → float32, normalization to [0,1]."""
# Create mock int8 image data
image_int8 = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
processed = prepare_image(image_int8)
# Check dtype conversion
assert processed.dtype == torch.float32
# Check normalization range
assert processed.min() >= 0.0
assert processed.max() <= 1.0
# Check that values are scaled correctly (255 → 1.0, 0 → 0.0)
if image_int8.max() == 255:
assert torch.isclose(processed.max(), torch.tensor(1.0), atol=1e-6)
if image_int8.min() == 0:
assert torch.isclose(processed.min(), torch.tensor(0.0), atol=1e-6)
# Check memory contiguity
assert processed.is_contiguous()
def test_resize_robot_observation_image():
"""Test image resizing from robot resolution to policy resolution."""
# Create mock image: (H=480, W=640, C=3)
original_image = torch.randint(0, 256, size=(480, 640, 3), dtype=torch.uint8)
target_shape = (3, 224, 224) # (C, H, W)
resized = resize_robot_observation_image(original_image, target_shape)
# Check output shape matches target
assert resized.shape == target_shape
# Check that original image had different dimensions
assert original_image.shape != resized.shape
# Check that resizing preserves value range
assert resized.min() >= 0
assert resized.max() <= 255
def test_prepare_raw_observation():
"""Test the preparation of raw robot observation to lerobot format."""
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features)
# Check that state is properly extracted and batched
assert "observation.state" in prepared
state = prepared["observation.state"]
assert isinstance(state, torch.Tensor)
assert state.shape == (1, 4) # Batched state
# Check that images are processed and resized
assert "observation.images.laptop" in prepared
assert "observation.images.phone" in prepared
laptop_img = prepared["observation.images.laptop"]
phone_img = prepared["observation.images.phone"]
# Check image shapes match policy requirements
assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape
assert phone_img.shape == policy_image_features["observation.images.phone"].shape
# Check that images are tensors
assert isinstance(laptop_img, torch.Tensor)
assert isinstance(phone_img, torch.Tensor)
def test_raw_observation_to_observation_basic():
"""Test the main raw_observation_to_observation function."""
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
device = "cpu"
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Check that all expected keys are present
assert "observation.state" in observation
assert "observation.images.laptop" in observation
assert "observation.images.phone" in observation
# Check state processing
state = observation["observation.state"]
assert isinstance(state, torch.Tensor)
assert state.device.type == device
assert state.shape == (1, 4) # Batched
# Check image processing
laptop_img = observation["observation.images.laptop"]
phone_img = observation["observation.images.phone"]
# Images should have batch dimension: (B, C, H, W)
assert laptop_img.shape == (1, 3, 224, 224)
assert phone_img.shape == (1, 3, 160, 160)
# Check device placement
assert laptop_img.device.type == device
assert phone_img.device.type == device
# Check image dtype and range (should be float32 in [0, 1])
assert laptop_img.dtype == torch.float32
assert phone_img.dtype == torch.float32
assert laptop_img.min() >= 0.0 and laptop_img.max() <= 1.0
assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0
def test_raw_observation_to_observation_with_non_tensor_data():
"""Test that non-tensor data (like task strings) is preserved."""
robot_obs = _create_mock_robot_observation()
robot_obs["task"] = "pick up the red cube" # Add string instruction
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
device = "cpu"
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Check that task string is preserved
assert "task" in observation
assert observation["task"] == "pick up the red cube"
assert isinstance(observation["task"], str)
@torch.no_grad()
def test_raw_observation_to_observation_device_handling():
"""Test that tensors are properly moved to the specified device."""
device = "mps" if torch.backends.mps.is_available() else "cpu"
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Check that all tensors are on the correct device
for key, value in observation.items():
if isinstance(value, torch.Tensor):
assert value.device.type == device, f"Tensor {key} not on {device}"
def test_raw_observation_to_observation_deterministic():
"""Test that the function produces consistent results for the same input."""
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
device = "cpu"
# Run twice with same input
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Results should be identical
assert set(obs1.keys()) == set(obs2.keys())
for key in obs1:
if isinstance(obs1[key], torch.Tensor):
torch.testing.assert_close(obs1[key], obs2[key])
else:
assert obs1[key] == obs2[key]
def test_image_processing_pipeline_preserves_content():
"""Test that the image processing pipeline preserves recognizable patterns."""
# Create an image with a specific pattern
original_img = np.zeros((100, 100, 3), dtype=np.uint8)
original_img[25:75, 25:75, :] = 255 # White square in center
robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img}
lerobot_features = {
"observation.state": {
"dtype": "float32",
"shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"],
},
"observation.images.laptop": {
"dtype": "image",
"shape": [100, 100, 3],
"names": ["height", "width", "channels"],
},
}
policy_image_features = {
"observation.images.laptop": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 50, 50), # Downsamples from 100x100
)
}
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim
# Check that the center region has higher values than corners
# Due to bilinear interpolation, exact values will change but pattern should remain
center_val = processed_img[:, 25, 25].mean() # Center of 50x50 image
corner_val = processed_img[:, 5, 5].mean() # Corner
assert center_val > corner_val, "Image processing should preserve recognizable patterns"
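
For reference, the average these FPS tests assert is intervals over elapsed time: (n − 1) / (t_last − t_first), with 0.0 for a single sample. Restated standalone:

def avg_fps(timestamps: list[float]) -> float:
    # n samples span (n - 1) intervals; a single sample has no measurable rate yet
    if len(timestamps) < 2:
        return 0.0
    return (len(timestamps) - 1) / (timestamps[-1] - timestamps[0])


assert avg_fps([0.0, 0.5, 1.0, 1.5, 2.0]) == 2.0  # matches test_fps_tracker_multiple_intervals
assert abs(avg_fps([0.0, 0.1, 0.5, 2.0, 3.0]) - 4.0 / 3.0) < 1e-9  # irregular intervals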

View File

@@ -1,215 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit-tests for the `PolicyServer` core logic.
Monkey-patch the `policy` attribute with a stub so that no real model inference is performed.
"""
from __future__ import annotations
import time
import pytest
import torch
from lerobot.configs.types import PolicyFeature
from tests.utils import require_package
# -----------------------------------------------------------------------------
# Test fixtures
# -----------------------------------------------------------------------------
class MockPolicy:
"""A minimal mock for an actual policy, returning zeros.
Refer to tests/policies for tests of the individual policies supported."""
class _Config:
robot_type = "dummy_robot"
@property
def image_features(self) -> dict[str, PolicyFeature]:
"""Empty image features since this test doesn't use images."""
return {}
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return a chunk of 20 dummy actions."""
batch_size = len(observation["observation.state"])
return torch.zeros(batch_size, 20, 6)
def __init__(self):
self.config = self._Config()
def to(self, *args, **kwargs):
# The server calls `policy.to(device)`. This stub ignores it.
return self
def model(self, batch: dict) -> torch.Tensor:
# Return a chunk of 20 dummy actions.
batch_size = len(batch["robot_type"])
return torch.zeros(batch_size, 20, 6)
@pytest.fixture
@require_package("grpc")
def policy_server():
"""Fresh `PolicyServer` instance with a stubbed-out policy model."""
# Import only when the test actually runs (after decorator check)
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.policy_server import PolicyServer
test_config = PolicyServerConfig(host="localhost", port=9999)
server = PolicyServer(test_config)
# Replace the real policy with our fast, deterministic stub.
server.policy = MockPolicy()
server.actions_per_chunk = 20
server.device = "cpu"
# Add mock lerobot_features that the observation similarity functions need
server.lerobot_features = {
"observation.state": {
"dtype": "float32",
"shape": [6],
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
}
}
return server
# -----------------------------------------------------------------------------
# Helper utilities for tests
# -----------------------------------------------------------------------------
def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False):
"""Create a TimedObservation with a given state vector."""
# Import only when needed
from lerobot.scripts.server.helpers import TimedObservation
return TimedObservation(
observation={
"joint1": state[0].item() if len(state) > 0 else 0.0,
"joint2": state[1].item() if len(state) > 1 else 0.0,
"joint3": state[2].item() if len(state) > 2 else 0.0,
"joint4": state[3].item() if len(state) > 3 else 0.0,
"joint5": state[4].item() if len(state) > 4 else 0.0,
"joint6": state[5].item() if len(state) > 5 else 0.0,
},
timestamp=time.time(),
timestep=timestep,
must_go=must_go,
)
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
def test_time_action_chunk(policy_server):
"""Verify that `_time_action_chunk` assigns correct timestamps and timesteps."""
start_ts = time.time()
start_t = 10
# A chunk of 3 action tensors.
action_tensors = [torch.randn(6) for _ in range(3)]
timed_actions = policy_server._time_action_chunk(start_ts, action_tensors, start_t)
assert len(timed_actions) == 3
# Check timesteps
assert [ta.get_timestep() for ta in timed_actions] == [10, 11, 12]
# Check timestamps
expected_timestamps = [
start_ts,
start_ts + policy_server.config.environment_dt,
start_ts + 2 * policy_server.config.environment_dt,
]
for ta, expected_ts in zip(timed_actions, expected_timestamps, strict=True):
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
def test_maybe_enqueue_observation_must_go(policy_server):
"""An observation with `must_go=True` is always enqueued."""
obs = _make_obs(torch.zeros(6), must_go=True)
assert policy_server._enqueue_observation(obs) is True
assert policy_server.observation_queue.qsize() == 1
assert policy_server.observation_queue.get_nowait() is obs
def test_maybe_enqueue_observation_dissimilar(policy_server):
"""A dissimilar observation (not `must_go`) is enqueued."""
# Set a last predicted observation.
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
# Create a new, dissimilar observation.
new_obs = _make_obs(torch.ones(6) * 5) # High norm difference
assert policy_server._enqueue_observation(new_obs) is True
assert policy_server.observation_queue.qsize() == 1
def test_maybe_enqueue_observation_is_skipped(policy_server):
"""A similar observation (not `must_go`) is skipped."""
# Set a last predicted observation.
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
# Create a new, very similar observation.
new_obs = _make_obs(torch.zeros(6) + 1e-4)
assert policy_server._enqueue_observation(new_obs) is False
assert policy_server.observation_queue.empty() is True
def test_obs_sanity_checks(policy_server):
"""Unit-test the private `_obs_sanity_checks` helper."""
prev = _make_obs(torch.zeros(6), timestep=0)
# Case 1: timestep already predicted
policy_server._predicted_timesteps.add(1)
obs_same_ts = _make_obs(torch.ones(6), timestep=1)
assert policy_server._obs_sanity_checks(obs_same_ts, prev) is False
# Case 2: observation too similar
policy_server._predicted_timesteps.clear()
obs_similar = _make_obs(torch.zeros(6) + 1e-4, timestep=2)
assert policy_server._obs_sanity_checks(obs_similar, prev) is False
# Case 3: genuinely new & dissimilar observation passes
obs_ok = _make_obs(torch.ones(6) * 5, timestep=3)
assert policy_server._obs_sanity_checks(obs_ok, prev) is True
def test_predict_action_chunk(monkeypatch, policy_server):
"""End-to-end test of `_predict_action_chunk` with a stubbed _get_action_chunk."""
# Import only when needed
from lerobot.scripts.server.policy_server import PolicyServer
# Force server to act-style policy; patch method to return deterministic tensor
policy_server.policy_type = "act"
action_dim = 6
batch_size = 1
actions_per_chunk = policy_server.actions_per_chunk
def _fake_get_action_chunk(_self, _obs, _type="act"):
return torch.zeros(batch_size, actions_per_chunk, action_dim)
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
obs = _make_obs(torch.zeros(6), timestep=5)
timed_actions = policy_server._predict_action_chunk(obs)
assert len(timed_actions) == actions_per_chunk
assert [ta.get_timestep() for ta in timed_actions] == list(range(5, 5 + actions_per_chunk))
for i, ta in enumerate(timed_actions):
expected_ts = obs.get_timestamp() + i * policy_server.config.environment_dt
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
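
As exercised by `test_time_action_chunk` and `test_predict_action_chunk`, timing a chunk amounts to ts_i = start_ts + i * environment_dt and timestep_i = start_t + i. Restated standalone (a sketch of the contract, not the server implementation):

def time_chunk(start_ts: float, start_t: int, n: int, environment_dt: float) -> list[tuple[float, int]]:
    # The i-th action in the chunk is scheduled one environment step after the previous one.
    return [(start_ts + i * environment_dt, start_t + i) for i in range(n)]


stamps = time_chunk(start_ts=100.0, start_t=10, n=3, environment_dt=0.033)
assert [t for _, t in stamps] == [10, 11, 12]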

View File

@@ -1,234 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC).
We monkey-patch `lerobot.robots.utils.make_robot_from_config` so that
no real hardware is accessed. Only the queue-update mechanism is verified.
"""
from __future__ import annotations
import time
from queue import Queue
import pytest
import torch
# Skip entire module if grpc is not available
pytest.importorskip("grpc")
# -----------------------------------------------------------------------------
# Test fixtures
# -----------------------------------------------------------------------------
@pytest.fixture()
def robot_client():
"""Fresh `RobotClient` instance for each test case (no threads started).
Uses a MockRobot."""
# Import only when the test actually runs (after decorator check)
from lerobot.scripts.server.configs import RobotClientConfig
from lerobot.scripts.server.robot_client import RobotClient
from tests.mocks.mock_robot import MockRobotConfig
test_config = MockRobotConfig()
# gRPC channel is not actually used in tests, so using a dummy address
test_config = RobotClientConfig(
robot=test_config,
server_address="localhost:9999",
policy_type="test",
pretrained_name_or_path="test",
actions_per_chunk=20,
verify_robot_cameras=False,
)
client = RobotClient(test_config)
# Initialize attributes that are normally set in start() method
client.chunks_received = 0
client.available_actions_size = []
yield client
if client.robot.is_connected:
client.stop()
# -----------------------------------------------------------------------------
# Helper utilities for tests
# -----------------------------------------------------------------------------
def _make_actions(start_ts: float, start_t: int, count: int):
"""Generate `count` consecutive TimedAction objects starting at timestep `start_t`."""
from lerobot.scripts.server.helpers import TimedAction
fps = 30 # emulates most common frame-rate
actions = []
for i in range(count):
timestep = start_t + i
timestamp = start_ts + i * (1 / fps)
action_tensor = torch.full((6,), timestep, dtype=torch.float32)
actions.append(TimedAction(action=action_tensor, timestep=timestep, timestamp=timestamp))
return actions
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
def test_update_action_queue_discards_stale(robot_client):
"""`_update_action_queue` must drop actions with `timestep` <= `latest_action`."""
# Pretend we already executed up to action #4
robot_client.latest_action = 4
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
robot_client._aggregate_action_queues(incoming)
# Extract timesteps from queue
resulting_timesteps = [a.get_timestep() for a in robot_client.action_queue.queue]
assert resulting_timesteps == [5, 6, 7]
@pytest.mark.parametrize(
"weight_old, weight_new",
[
(1.0, 0.0),
(0.0, 1.0),
(0.5, 0.5),
(0.2, 0.8),
(0.8, 0.2),
(0.1, 0.9),
(0.9, 0.1),
],
)
def test_aggregate_action_queues_combines_actions_in_overlap(
robot_client, weight_old: float, weight_new: float
):
"""`_aggregate_action_queues` must combine actions on overlapping timesteps according
to the provided aggregate_fn, here tested with multiple coefficients."""
from lerobot.scripts.server.helpers import TimedAction
robot_client.chunks_received = 0
# Pretend we already executed up to action #4, and queue contains actions for timesteps 5..6
robot_client.latest_action = 4
current_actions = _make_actions(
start_ts=time.time(), start_t=5, count=2
)  # action tensors are filled with their timestep value (5 and 6)
current_actions = [
TimedAction(action=10 * a.get_action(), timestep=a.get_timestep(), timestamp=a.get_timestamp())
for a in current_actions
]
for a in current_actions:
robot_client.action_queue.put(a)
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
overlap_timesteps = [5, 6]  # stale-timestep filtering is covered in test_update_action_queue_discards_stale
nonoverlap_timesteps = [7]
robot_client._aggregate_action_queues(
incoming, aggregate_fn=lambda x1, x2: weight_old * x1 + weight_new * x2
)
queue_overlap_actions = []
queue_non_overlap_actions = []
for a in robot_client.action_queue.queue:
if a.get_timestep() in overlap_timesteps:
queue_overlap_actions.append(a)
elif a.get_timestep() in nonoverlap_timesteps:
queue_non_overlap_actions.append(a)
queue_overlap_actions = sorted(queue_overlap_actions, key=lambda x: x.get_timestep())
queue_non_overlap_actions = sorted(queue_non_overlap_actions, key=lambda x: x.get_timestep())
assert torch.allclose(
queue_overlap_actions[0].get_action(),
weight_old * current_actions[0].get_action() + weight_new * incoming[-3].get_action(),
)
assert torch.allclose(
queue_overlap_actions[1].get_action(),
weight_old * current_actions[1].get_action() + weight_new * incoming[-2].get_action(),
)
assert torch.allclose(queue_non_overlap_actions[0].get_action(), incoming[-1].get_action())
@pytest.mark.parametrize(
"chunk_size, queue_len, expected",
[
(20, 12, False), # 12 / 20 = 0.6 > g=0.5 threshold, not ready to send
(20, 8, True), # 8 / 20 = 0.4 <= g=0.5, ready to send
(10, 5, True),
(10, 6, False),
],
)
def test_ready_to_send_observation(robot_client, chunk_size: int, queue_len: int, expected: bool):
"""Validate `_ready_to_send_observation` ratio logic for various sizes."""
robot_client.action_chunk_size = chunk_size
# Clear any existing actions then fill with `queue_len` dummy entries ----
robot_client.action_queue = Queue()
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
for act in dummy_actions:
robot_client.action_queue.put(act)
assert robot_client._ready_to_send_observation() is expected
@pytest.mark.parametrize(
"g_threshold, expected",
[
# The condition is `queue_size / chunk_size <= g`.
# Here, ratio = 6 / 10 = 0.6.
(0.0, False), # 0.6 <= 0.0 is False
(0.1, False),
(0.2, False),
(0.3, False),
(0.4, False),
(0.5, False),
(0.6, True), # 0.6 <= 0.6 is True
(0.7, True),
(0.8, True),
(0.9, True),
(1.0, True),
],
)
def test_ready_to_send_observation_with_varying_threshold(robot_client, g_threshold: float, expected: bool):
"""Validate `_ready_to_send_observation` with fixed sizes and varying `g`."""
# Fixed sizes for this test: ratio = 6 / 10 = 0.6
chunk_size = 10
queue_len = 6
robot_client.action_chunk_size = chunk_size
# This is the parameter we are testing
robot_client._chunk_size_threshold = g_threshold
# Fill queue with dummy actions
robot_client.action_queue = Queue()
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
for act in dummy_actions:
robot_client.action_queue.put(act)
assert robot_client._ready_to_send_observation() is expected
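
Both parametrized tests above pin down the same readiness rule: send a new observation once queue_size / chunk_size <= g. Restated standalone:

def ready_to_send(queue_size: int, chunk_size: int, g: float) -> bool:
    # Request the next chunk once the queue has drained to a fraction g of one chunk.
    return queue_size / chunk_size <= g


assert ready_to_send(8, 20, 0.5) is True    # 0.4 <= 0.5
assert ready_to_send(12, 20, 0.5) is False  # 0.6 >  0.5
assert ready_to_send(6, 10, 0.6) is True    # boundary case: 0.6 <= 0.6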

View File

@@ -1,188 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Example of running a specific test:
# ```bash
# pytest tests/cameras/test_opencv.py::test_connect
# ```
from pathlib import Path
import numpy as np
import pytest
from lerobot.cameras.configs import Cv2Rotation
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
# NOTE(Steven): more tests + assertions?
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras"
DEFAULT_PNG_FILE_PATH = TEST_ARTIFACTS_DIR / "image_160x120.png"
TEST_IMAGE_SIZES = ["128x128", "160x120", "320x180", "480x270"]
TEST_IMAGE_PATHS = [TEST_ARTIFACTS_DIR / f"image_{size}.png" for size in TEST_IMAGE_SIZES]
def test_abc_implementation():
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
config = OpenCVCameraConfig(index_or_path=0)
_ = OpenCVCamera(config)
def test_connect():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
camera = OpenCVCamera(config)
camera.connect(warmup=False)
assert camera.is_connected
def test_connect_already_connected():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
camera = OpenCVCamera(config)
camera.connect(warmup=False)
with pytest.raises(DeviceAlreadyConnectedError):
camera.connect(warmup=False)
def test_connect_invalid_camera_path():
config = OpenCVCameraConfig(index_or_path="nonexistent/camera.png")
camera = OpenCVCamera(config)
with pytest.raises(ConnectionError):
camera.connect(warmup=False)
def test_invalid_width_connect():
config = OpenCVCameraConfig(
index_or_path=DEFAULT_PNG_FILE_PATH,
width=99999, # Invalid width to trigger error
height=480,
)
camera = OpenCVCamera(config)
with pytest.raises(RuntimeError):
camera.connect(warmup=False)
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
def test_read(index_or_path):
config = OpenCVCameraConfig(index_or_path=index_or_path)
camera = OpenCVCamera(config)
camera.connect(warmup=False)
img = camera.read()
assert isinstance(img, np.ndarray)
def test_read_before_connect():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
camera = OpenCVCamera(config)
with pytest.raises(DeviceNotConnectedError):
_ = camera.read()
def test_disconnect():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
camera = OpenCVCamera(config)
camera.connect(warmup=False)
camera.disconnect()
assert not camera.is_connected
def test_disconnect_before_connect():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
camera = OpenCVCamera(config)
with pytest.raises(DeviceNotConnectedError):
_ = camera.disconnect()
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
def test_async_read(index_or_path):
config = OpenCVCameraConfig(index_or_path=index_or_path)
camera = OpenCVCamera(config)
camera.connect(warmup=False)
try:
img = camera.async_read()
assert camera.thread is not None
assert camera.thread.is_alive()
assert isinstance(img, np.ndarray)
finally:
if camera.is_connected:
camera.disconnect() # Stop/join the thread; otherwise warnings are emitted when the test ends
def test_async_read_timeout():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
camera = OpenCVCamera(config)
camera.connect(warmup=False)
try:
with pytest.raises(TimeoutError):
camera.async_read(timeout_ms=0)
finally:
if camera.is_connected:
camera.disconnect()
def test_async_read_before_connect():
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
camera = OpenCVCamera(config)
with pytest.raises(DeviceNotConnectedError):
_ = camera.async_read()
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
@pytest.mark.parametrize(
"rotation",
[
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
],
ids=["no_rot", "rot90", "rot180", "rot270"],
)
def test_rotation(rotation, index_or_path):
filename = Path(index_or_path).name
dimensions = filename.split("_")[-1].split(".")[0] # Assumes the filename format "*_{width}x{height}.png"
original_width, original_height = map(int, dimensions.split("x"))
config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation)
camera = OpenCVCamera(config)
camera.connect(warmup=False)
img = camera.read()
assert isinstance(img, np.ndarray)
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
assert camera.width == original_height
assert camera.height == original_width
assert img.shape[:2] == (original_width, original_height)
else:
assert camera.width == original_width
assert camera.height == original_height
assert img.shape[:2] == (original_height, original_width)
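# Note on the assertions above: for 90/270-degree rotations the camera reports
# swapped dimensions (width <-> height), while the returned array keeps
# numpy's (rows, cols) convention, so img.shape[:2] == (new_height, new_width)
# in both branches.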

View File

@ -1,206 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Example of running a specific test:
# ```bash
# pytest tests/cameras/test_realsense.py::test_connect
# ```
from pathlib import Path
from unittest.mock import patch
import numpy as np
import pytest
from lerobot.cameras.configs import Cv2Rotation
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
pytest.importorskip("pyrealsense2")
from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras"
BAG_FILE_PATH = TEST_ARTIFACTS_DIR / "test_rs.bag"
# NOTE(Steven): For some reason these tests take ~20 sec on macOS but only ~2 sec on Linux.
def mock_rs_config_enable_device_from_file(rs_config_instance, _sn):
return rs_config_instance.enable_device_from_file(str(BAG_FILE_PATH), repeat_playback=True)
def mock_rs_config_enable_device_bad_file(rs_config_instance, _sn):
return rs_config_instance.enable_device_from_file("non_existent_file.bag", repeat_playback=True)
@pytest.fixture(name="patch_realsense", autouse=True)
def fixture_patch_realsense():
"""Automatically mock pyrealsense2.config.enable_device for all tests."""
with patch(
"pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file
) as mock:
yield mock
def test_abc_implementation():
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
config = RealSenseCameraConfig(serial_number_or_name="042")
_ = RealSenseCamera(config)
def test_connect():
config = RealSenseCameraConfig(serial_number_or_name="042")
camera = RealSenseCamera(config)
camera.connect(warmup=False)
assert camera.is_connected
def test_connect_already_connected():
config = RealSenseCameraConfig(serial_number_or_name="042")
camera = RealSenseCamera(config)
camera.connect(warmup=False)
with pytest.raises(DeviceAlreadyConnectedError):
camera.connect(warmup=False)
def test_connect_invalid_camera_path(patch_realsense):
patch_realsense.side_effect = mock_rs_config_enable_device_bad_file
config = RealSenseCameraConfig(serial_number_or_name="042")
camera = RealSenseCamera(config)
with pytest.raises(ConnectionError):
camera.connect(warmup=False)
def test_invalid_width_connect():
config = RealSenseCameraConfig(serial_number_or_name="042", width=99999, height=480, fps=30)
camera = RealSenseCamera(config)
with pytest.raises(ConnectionError):
camera.connect(warmup=False)
def test_read():
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
camera = RealSenseCamera(config)
camera.connect(warmup=False)
img = camera.read()
assert isinstance(img, np.ndarray)
# TODO(Steven): Fix this test for the latest version of pyrealsense2.
@pytest.mark.skip("Skipping test: pyrealsense2 version > 2.55.1.6486")
def test_read_depth():
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, use_depth=True)
camera = RealSenseCamera(config)
camera.connect(warmup=False)
img = camera.read_depth(timeout_ms=2000) # NOTE(Steven): Reading depth takes longer in CI environments.
assert isinstance(img, np.ndarray)
def test_read_before_connect():
config = RealSenseCameraConfig(serial_number_or_name="042")
camera = RealSenseCamera(config)
with pytest.raises(DeviceNotConnectedError):
_ = camera.read()
def test_disconnect():
config = RealSenseCameraConfig(serial_number_or_name="042")
camera = RealSenseCamera(config)
camera.connect(warmup=False)
camera.disconnect()
assert not camera.is_connected
def test_disconnect_before_connect():
config = RealSenseCameraConfig(serial_number_or_name="042")
camera = RealSenseCamera(config)
with pytest.raises(DeviceNotConnectedError):
camera.disconnect()
def test_async_read():
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
camera = RealSenseCamera(config)
camera.connect(warmup=False)
try:
img = camera.async_read()
assert camera.thread is not None
assert camera.thread.is_alive()
assert isinstance(img, np.ndarray)
finally:
if camera.is_connected:
camera.disconnect() # Stop/join the thread; otherwise warnings are emitted when the test ends
def test_async_read_timeout():
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
camera = RealSenseCamera(config)
camera.connect(warmup=False)
try:
with pytest.raises(TimeoutError):
camera.async_read(timeout_ms=0)
finally:
if camera.is_connected:
camera.disconnect()
def test_async_read_before_connect():
config = RealSenseCameraConfig(serial_number_or_name="042")
camera = RealSenseCamera(config)
with pytest.raises(DeviceNotConnectedError):
_ = camera.async_read()
@pytest.mark.parametrize(
"rotation",
[
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
],
ids=["no_rot", "rot90", "rot180", "rot270"],
)
def test_rotation(rotation):
config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation)
camera = RealSenseCamera(config)
camera.connect(warmup=False)
img = camera.read()
assert isinstance(img, np.ndarray)
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
assert camera.width == 480
assert camera.height == 640
assert img.shape[:2] == (640, 480)
else:
assert camera.width == 640
assert camera.height == 480
assert img.shape[:2] == (480, 640)

View File

@ -1,105 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections.abc import Generator
from dataclasses import dataclass
from pathlib import Path
import pytest
from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap
from lerobot.envs.configs import EnvConfig
def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str:
"""Creates a dummy plugin module that implements its own EnvConfig subclass."""
return f"""
from dataclasses import dataclass
from lerobot.envs.configs import {base_class}
@{base_class}.register_subclass("{plugin_name}")
@dataclass
class TestPluginConfig:
value: int = 42
"""
@pytest.fixture
def plugin_dir(tmp_path: Path) -> Generator[Path, None, None]:
"""Creates a temporary plugin package structure."""
plugin_pkg = tmp_path / "test_plugin"
plugin_pkg.mkdir()
(plugin_pkg / "__init__.py").touch()
with open(plugin_pkg / "my_plugin.py", "w") as f:
f.write(create_plugin_code())
# Add tmp_path to Python path so we can import from it
sys.path.insert(0, str(tmp_path))
yield plugin_pkg
sys.path.pop(0)
def test_parse_plugin_args():
cli_args = [
"--env.type=test",
"--model.discover_packages_path=some.package",
"--env.discover_packages_path=other.package",
]
plugin_args = parse_plugin_args("discover_packages_path", cli_args)
assert plugin_args == {
"model.discover_packages_path": "some.package",
"env.discover_packages_path": "other.package",
}
def test_load_plugin_success(plugin_dir: Path):
# Import should work and register the plugin with the real EnvConfig
load_plugin("test_plugin")
assert "test_env" in EnvConfig.get_known_choices()
plugin_cls = EnvConfig.get_choice_class("test_env")
plugin_instance = plugin_cls()
assert plugin_instance.value == 42
def test_load_plugin_failure():
with pytest.raises(PluginLoadError) as exc_info:
load_plugin("nonexistent_plugin")
assert "Failed to load plugin 'nonexistent_plugin'" in str(exc_info.value)
def test_wrap_with_plugin(plugin_dir: Path):
@dataclass
class Config:
env: EnvConfig
@wrap()
def dummy_func(cfg: Config):
return cfg
# Test loading plugin via CLI args
sys.argv = [
"dummy_script.py",
"--env.discover_packages_path=test_plugin",
"--env.type=test_env",
]
cfg = dummy_func()
assert isinstance(cfg, Config)
assert isinstance(cfg.env, EnvConfig.get_choice_class("test_env"))
assert cfg.env.value == 42

View File

@ -1,88 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
import pytest
from serial import SerialException
from lerobot.configs.types import FeatureType, PolicyFeature
from tests.utils import DEVICE
# Import fixture modules as plugins
pytest_plugins = [
"tests.fixtures.dataset_factories",
"tests.fixtures.files",
"tests.fixtures.hub",
"tests.fixtures.optimizers",
]
def pytest_collection_finish():
print(f"\nTesting with {DEVICE=}")
def _check_component_availability(component_type, available_components, make_component):
"""Generic helper to check if a hardware component is available"""
if component_type not in available_components:
raise ValueError(
f"The {component_type} type is not valid. Expected one of these '{available_components}'"
)
try:
component = make_component(component_type)
component.connect()
del component
return True
except Exception as e:
print(f"\nA {component_type} is not available.")
if isinstance(e, ModuleNotFoundError):
print(f"\nInstall module '{e.name}'")
elif isinstance(e, SerialException):
print("\nNo physical device detected.")
elif isinstance(e, ValueError) and "camera_index" in str(e):
print("\nNo physical camera detected.")
else:
traceback.print_exc()
return False
@pytest.fixture
def patch_builtins_input(monkeypatch):
def print_text(text=None):
if text is not None:
print(text)
monkeypatch.setattr("builtins.input", print_text)
@pytest.fixture
def policy_feature_factory():
"""PolicyFeature factory"""
def _pf(ft: FeatureType, shape: tuple[int, ...]) -> PolicyFeature:
return PolicyFeature(type=ft, shape=shape)
return _pf
def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None:
assert isinstance(features, dict)
assert all(isinstance(k, str) for k in features.keys())
assert all(isinstance(v, PolicyFeature) for v in features.values())

View File

@ -1,309 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
import numpy as np
import pytest
from lerobot.datasets.compute_stats import (
_assert_type_and_shape,
aggregate_feature_stats,
aggregate_stats,
compute_episode_stats,
estimate_num_samples,
get_feature_stats,
sample_images,
sample_indices,
)
def mock_load_image_as_numpy(path, dtype, channel_first):
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
@pytest.fixture
def sample_array():
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def test_estimate_num_samples():
assert estimate_num_samples(1) == 1
assert estimate_num_samples(10) == 10
assert estimate_num_samples(100) == 100
assert estimate_num_samples(200) == 100
assert estimate_num_samples(1000) == 177
assert estimate_num_samples(2000) == 299
assert estimate_num_samples(5000) == 594
assert estimate_num_samples(10_000) == 1000
assert estimate_num_samples(20_000) == 1681
assert estimate_num_samples(50_000) == 3343
assert estimate_num_samples(500_000) == 10_000
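# The expected values above are consistent with a power-law subsampling
# schedule of roughly the following shape (a sketch inferred from the asserted
# numbers; the exponent and bounds are assumptions, not read from the source):
def _estimate_num_samples_sketch(n: int) -> int:
    # Grow as n**0.75, floored at 100 samples, capped at 10_000 and at n itself.
    return min(max(int(n**0.75), 100), 10_000, n)
# e.g. int(1000 ** 0.75) == 177 and int(20_000 ** 0.75) == 1681, matching the
# assertions.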
def test_sample_indices():
indices = sample_indices(10)
assert len(indices) > 0
assert indices[0] == 0
assert indices[-1] == 9
assert len(indices) == estimate_num_samples(10)
@patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
def test_sample_images(mock_load):
image_paths = [f"image_{i}.jpg" for i in range(100)]
images = sample_images(image_paths)
assert isinstance(images, np.ndarray)
assert images.shape[1:] == (3, 32, 32)
assert images.dtype == np.uint8
assert len(images) == estimate_num_samples(100)
def test_get_feature_stats_images():
data = np.random.rand(100, 3, 32, 32)
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
np.testing.assert_equal(stats["count"], np.array([100]))
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_axis_0_keepdims(sample_array):
expected = {
"min": np.array([[1, 2, 3]]),
"max": np.array([[7, 8, 9]]),
"mean": np.array([[4.0, 5.0, 6.0]]),
"std": np.array([[2.44948974, 2.44948974, 2.44948974]]),
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=(0,), keepdims=True)
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
def test_get_feature_stats_axis_1(sample_array):
expected = {
"min": np.array([1, 4, 7]),
"max": np.array([3, 6, 9]),
"mean": np.array([2.0, 5.0, 8.0]),
"std": np.array([0.81649658, 0.81649658, 0.81649658]),
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=(1,), keepdims=False)
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
def test_get_feature_stats_no_axis(sample_array):
expected = {
"min": np.array(1),
"max": np.array(9),
"mean": np.array(5.0),
"std": np.array(2.5819889),
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=None, keepdims=False)
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
def test_get_feature_stats_empty_array():
array = np.array([])
with pytest.raises(ValueError):
get_feature_stats(array, axis=(0,), keepdims=True)
def test_get_feature_stats_single_value():
array = np.array([[1337]])
result = get_feature_stats(array, axis=None, keepdims=True)
np.testing.assert_equal(result["min"], np.array(1337))
np.testing.assert_equal(result["max"], np.array(1337))
np.testing.assert_equal(result["mean"], np.array(1337.0))
np.testing.assert_equal(result["std"], np.array(0.0))
np.testing.assert_equal(result["count"], np.array([1]))
def test_compute_episode_stats():
episode_data = {
"observation.image": [f"image_{i}.jpg" for i in range(100)],
"observation.state": np.random.rand(100, 10),
}
features = {
"observation.image": {"dtype": "image"},
"observation.state": {"dtype": "numeric"},
}
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
stats = compute_episode_stats(episode_data, features)
assert "observation.image" in stats and "observation.state" in stats
assert stats["observation.image"]["count"].item() == 100
assert stats["observation.state"]["count"].item() == 100
assert stats["observation.image"]["mean"].shape == (3, 1, 1)
def test_assert_type_and_shape_valid():
valid_stats = [
{
"feature1": {
"min": np.array([1.0]),
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([1]),
}
}
]
_assert_type_and_shape(valid_stats)
def test_assert_type_and_shape_invalid_type():
invalid_stats = [
{
"feature1": {
"min": [1.0], # Not a numpy array
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([1]),
}
}
]
with pytest.raises(ValueError, match="Stats must be composed of numpy array"):
_assert_type_and_shape(invalid_stats)
def test_assert_type_and_shape_invalid_shape():
invalid_stats = [
{
"feature1": {
"count": np.array([1, 2]), # Wrong shape
}
}
]
with pytest.raises(ValueError, match=r"Shape of 'count' must be \(1\)"):
_assert_type_and_shape(invalid_stats)
def test_aggregate_feature_stats():
stats_ft_list = [
{
"min": np.array([1.0]),
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([1]),
},
{
"min": np.array([2.0]),
"max": np.array([12.0]),
"mean": np.array([6.0]),
"std": np.array([2.5]),
"count": np.array([1]),
},
]
result = aggregate_feature_stats(stats_ft_list)
np.testing.assert_allclose(result["min"], np.array([1.0]))
np.testing.assert_allclose(result["max"], np.array([12.0]))
np.testing.assert_allclose(result["mean"], np.array([5.5]))
np.testing.assert_allclose(result["std"], np.array([2.318405]), atol=1e-6)
np.testing.assert_allclose(result["count"], np.array([2]))
def test_aggregate_stats():
all_stats = [
{
"observation.image": {
"min": [1, 2, 3],
"max": [10, 20, 30],
"mean": [5.5, 10.5, 15.5],
"std": [2.87, 5.87, 8.87],
"count": 10,
},
"observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
},
{
"observation.image": {
"min": [2, 1, 0],
"max": [15, 10, 5],
"mean": [8.5, 5.5, 2.5],
"std": [3.42, 2.42, 1.42],
"count": 15,
},
"observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
},
]
expected_agg_stats = {
"observation.image": {
"min": [1, 1, 0],
"max": [15, 20, 30],
"mean": [7.3, 7.5, 7.7],
"std": [3.5317, 4.8267, 8.5581],
"count": 25,
},
"observation.state": {
"min": 1,
"max": 15,
"mean": 7.3,
"std": 3.5317,
"count": 25,
},
"extra_key_0": {
"min": 5,
"max": 25,
"mean": 15.0,
"std": 6.0,
"count": 6,
},
"extra_key_1": {
"min": 0,
"max": 20,
"mean": 10.0,
"std": 5.0,
"count": 5,
},
}
# cast the per-episode stats to numpy
for ep_stats in all_stats:
for fkey, stats in ep_stats.items():
for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
if fkey == "observation.image" and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
else:
stats[k] = stats[k].reshape(1)
# cast the expected aggregated stats to numpy
for fkey, stats in expected_agg_stats.items():
for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
if fkey == "observation.image" and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
else:
stats[k] = stats[k].reshape(1)
results = aggregate_stats(all_stats)
for fkey in expected_agg_stats:
np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
np.testing.assert_allclose(
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
)
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])

View File

@ -1,573 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import re
from copy import deepcopy
from itertools import chain
from pathlib import Path
import numpy as np
import pytest
import torch
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file
import lerobot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.image_writer import image_array_to_pil_image
from lerobot.datasets.lerobot_dataset import (
LeRobotDataset,
MultiLeRobotDataset,
)
from lerobot.datasets.utils import (
create_branch,
flatten_dict,
unflatten_dict,
)
from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy_config
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.utils import require_x86_64_kernel
@pytest.fixture
def image_dataset(tmp_path, empty_lerobot_dataset_factory):
features = {
"image": {
"dtype": "image",
"shape": DUMMY_CHW,
"names": [
"channels",
"height",
"width",
],
}
}
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
"""
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
objects have the same sets of attributes defined.
"""
# Instantiate both ways
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys())
assert init_attr == create_attr
def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
kwargs = {
"repo_id": DUMMY_REPO_ID,
"total_episodes": 10,
"total_frames": 400,
"episodes": [2, 5, 6],
}
dataset = lerobot_dataset_factory(root=tmp_path / "test", **kwargs)
assert dataset.repo_id == kwargs["repo_id"]
assert dataset.meta.total_episodes == kwargs["total_episodes"]
assert dataset.meta.total_frames == kwargs["total_frames"]
assert dataset.episodes == kwargs["episodes"]
assert dataset.num_episodes == len(kwargs["episodes"])
assert dataset.num_frames == len(dataset)
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
):
dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task")
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
):
dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task")
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
):
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task")
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
):
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape(
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
),
):
dataset.add_frame({"state": 1.0}, task="Dummy task")
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
):
dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task")
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape(
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
),
):
dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task")
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
dataset.save_episode()
assert len(dataset) == 1
assert dataset[0]["task"] == "Dummy task"
assert dataset[0]["task_index"] == 0
assert dataset[0]["state"].ndim == 0
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2])
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4])
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["state"].ndim == 0
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["caption"] == "Dummy caption"
def test_add_frame_image_wrong_shape(image_dataset):
dataset = image_dataset
with pytest.raises(
ValueError,
match=re.escape(
"The feature 'image' of shape '(3, 128, 96)' does not have the expected shape '(3, 96, 128)' or '(96, 128, 3)'.\n"
),
):
c, h, w = DUMMY_CHW
dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task")
def test_add_frame_image_wrong_range(image_dataset):
"""This test will display the following error message from a thread:
```
Error writing image ...test_add_frame_image_wrong_ran0/test/images/image/episode_000000/frame_000000.png:
The image data type is float, which requires values in the range [0.0, 1.0]. However, the provided range is [0.009678772038470007, 254.9776492089887].
Please adjust the range or provide a uint8 image with values in the range [0, 255]
```
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
"""
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task")
with pytest.raises(FileNotFoundError):
dataset.save_episode()
def test_add_frame_image(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
def test_add_frame_image_h_w_c(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
def test_add_frame_image_uint8(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": image}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
def test_add_frame_image_pil(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task")
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
def test_image_array_to_pil_image_wrong_range_float_0_255():
image = np.random.rand(*DUMMY_HWC) * 255
with pytest.raises(ValueError):
image_array_to_pil_image(image)
# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
# - [ ] test add_episode
# - [ ] test push_to_hub
# - [ ] test smaller methods
@pytest.mark.parametrize(
"env_name, repo_id, policy_name",
# Single dataset
lerobot.env_dataset_policy_triplets,
# Multi-dataset
# TODO after fix multidataset
# + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
)
def test_factory(env_name, repo_id, policy_name):
"""
Tests that:
- we can create a dataset with the factory.
- for a commonly used set of data keys, the data dimensions are correct.
"""
cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
env=make_env_config(env_name),
policy=make_policy_config(policy_name, push_to_hub=False),
)
cfg.validate()
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
camera_keys = dataset.meta.camera_keys
item = dataset[0]
keys_ndim_required = [
("action", 1, True),
("episode_index", 0, True),
("frame_index", 0, True),
("timestamp", 0, True),
# TODO(rcadene): should we rename it agent_pos?
("observation.state", 1, True),
("next.reward", 0, False),
("next.done", 0, False),
]
# test number of dimensions
for key, ndim, required in keys_ndim_required:
if key not in item:
if required:
assert key in item, f"{key}"
else:
logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
continue
if delta_timestamps is not None and key in delta_timestamps:
assert item[key].ndim == ndim + 1, f"{key}"
assert item[key].shape[0] == len(delta_timestamps[key]), f"{key}"
else:
assert item[key].ndim == ndim, f"{key}"
if key in camera_keys:
assert item[key].dtype == torch.float32, f"{key}"
# TODO(rcadene): we assume for now that image normalization takes place in the model
assert item[key].max() <= 1.0, f"{key}"
assert item[key].min() >= 0.0, f"{key}"
if delta_timestamps is not None and key in delta_timestamps:
# test t,c,h,w
assert item[key].shape[1] == 3, f"{key}"
else:
# test c,h,w
assert item[key].shape[0] == 3, f"{key}"
if delta_timestamps is not None:
# test missing keys in delta_timestamps
for key in delta_timestamps:
assert key in item, f"{key}"
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
@pytest.mark.skip("TODO after fix multidataset")
def test_multidataset_frames():
"""Check that all dataset frames are incorporated."""
# Note: use the image variants of the dataset to make the test approx 3x faster.
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
# logic that wouldn't be caught with two repo IDs.
repo_ids = [
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
]
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(repo_ids)
assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
# check they match.
expected_dataset_indices = []
for i, sub_dataset in enumerate(sub_datasets):
expected_dataset_indices.extend([i] * len(sub_dataset))
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
):
dataset_index = dataset_item.pop("dataset_index")
assert dataset_index == expected_dataset_index
assert sub_dataset_item.keys() == dataset_item.keys()
for k in sub_dataset_item:
assert torch.equal(sub_dataset_item[k], dataset_item[k])
# TODO(aliberts): Move to a more appropriate location
def test_flatten_unflatten_dict():
d = {
"obs": {
"min": 0,
"max": 1,
"mean": 2,
"std": 3,
},
"action": {
"min": 4,
"max": 5,
"mean": 6,
"std": 7,
},
}
original_d = deepcopy(d)
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
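# For reference, flatten_dict collapses nested keys into single path-like
# strings (assuming the default "/" separator), e.g.
#   flatten_dict({"obs": {"min": 0}}) == {"obs/min": 0}
# and unflatten_dict inverts that mapping, which is exactly the round trip the
# equality above checks.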
@pytest.mark.parametrize(
"repo_id",
[
"lerobot/pusht",
"lerobot/aloha_sim_insertion_human",
"lerobot/xarm_lift_medium",
# (michel-aractingi) commenting out the two datasets from openx as the test is failing
# "lerobot/nyu_franka_play_dataset",
# "lerobot/cmu_stretch",
],
)
@require_x86_64_kernel
def test_backward_compatibility(repo_id):
"""The artifacts for this test have been generated by `tests/artifacts/datasets/save_dataset_to_safetensors.py`."""
# TODO(rcadene, aliberts): remove dataset download
dataset = LeRobotDataset(repo_id, episodes=[0])
test_dir = Path("tests/artifacts/datasets") / repo_id
def load_and_compare(i):
new_frame = dataset[i] # noqa: B023
old_frame = load_file(test_dir / f"frame_{i}.safetensors") # noqa: B023
# ignore language instructions (if present) in language-conditioned datasets
# TODO (michel-aractingi): transform language obs to language embeddings via tokenizer
new_frame.pop("language_instruction", None)
old_frame.pop("language_instruction", None)
new_frame.pop("task", None)
old_frame.pop("task", None)
# Remove task_index to allow for backward compatibility
# TODO(rcadene): remove when new features have been generated
if "task_index" not in old_frame:
del new_frame["task_index"]
new_keys = set(new_frame.keys())
old_keys = set(old_frame.keys())
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
for key in new_frame:
assert torch.isclose(new_frame[key], old_frame[key]).all(), (
f"{key=} for index={i} does not contain the same value"
)
# test 2 first frames of the first episode
i = dataset.episode_data_index["from"][0].item()
load_and_compare(i)
load_and_compare(i + 1)
# test 2 frames in the middle of the first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
load_and_compare(i)
load_and_compare(i + 1)
# test 2 last frames of the first episode
i = dataset.episode_data_index["to"][0].item()
load_and_compare(i - 2)
load_and_compare(i - 1)
# TODO(rcadene): Enable testing on second and last episode
# We currently can't because our test dataset only contains the first episode
# # test 2 first frames of second episode
# i = dataset.episode_data_index["from"][1].item()
# load_and_compare(i)
# load_and_compare(i + 1)
# # test 2 last frames of second episode
# i = dataset.episode_data_index["to"][1].item()
# load_and_compare(i - 2)
# load_and_compare(i - 1)
# # test 2 last frames of last episode
# i = dataset.episode_data_index["to"][-1].item()
# load_and_compare(i - 2)
# load_and_compare(i - 1)
@pytest.mark.skip("Requires internet access")
def test_create_branch():
api = HfApi()
repo_id = "cadene/test_create_branch"
repo_type = "dataset"
branch = "test"
ref = f"refs/heads/{branch}"
# Prepare a repo with a test branch
api.delete_repo(repo_id, repo_type=repo_type, missing_ok=True)
api.create_repo(repo_id, repo_type=repo_type)
create_branch(repo_id, repo_type=repo_type, branch=branch)
# Make sure the test branch exists
branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
refs = [branch.ref for branch in branches]
assert ref in refs
# Overwrite it
create_branch(repo_id, repo_type=repo_type, branch=branch)
# Clean
api.delete_repo(repo_id, repo_type=repo_type)
def test_dataset_feature_with_forward_slash_raises_error():
from lerobot.constants import HF_LEROBOT_HOME
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
# Make sure the directory does not exist before the test
if dataset_dir.exists():
dataset_dir.rmdir()
with pytest.raises(ValueError):
LeRobotDataset.create(
repo_id="lerobot/test/with/slash",
fps=30,
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
)

View File

@ -1,278 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import accumulate
import datasets
import numpy as np
import pyarrow.compute as pc
import pytest
import torch
from lerobot.datasets.utils import (
check_delta_timestamps,
check_timestamps_sync,
get_delta_indices,
)
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
def calculate_total_episode(
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
) -> int:
episode_indices = sorted(hf_dataset.unique("episode_index"))
total_episodes = len(episode_indices)
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
raise ValueError("episode_index values are not sorted and contiguous.")
return total_episodes
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]:
episode_lengths = []
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
for ep_idx in range(total_episodes):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
cumulative_lengths = list(accumulate(episode_lengths))
return {
"from": np.array([0] + cumulative_lengths[:-1], dtype=np.int64),
"to": np.array(cumulative_lengths, dtype=np.int64),
}
@pytest.fixture(scope="module")
def synced_timestamps_factory(hf_dataset_factory):
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
hf_dataset = hf_dataset_factory(fps=fps)
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
episode_data_index = calculate_episode_data_index(hf_dataset)
return timestamps, episode_indices, episode_data_index
return _create_synced_timestamps
@pytest.fixture(scope="module")
def unsynced_timestamps_factory(synced_timestamps_factory):
def _create_unsynced_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance
return timestamps, episode_indices, episode_data_index
return _create_unsynced_timestamps
@pytest.fixture(scope="module")
def slightly_off_timestamps_factory(synced_timestamps_factory):
def _create_slightly_off_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance
return timestamps, episode_indices, episode_data_index
return _create_slightly_off_timestamps
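# A sketch of the sync check these fixtures exercise (an assumption, not the
# library's verbatim code): within an episode, consecutive timestamps must be
# 1 / fps apart up to tolerance_s,
#   all(abs((t1 - t0) - 1 / fps) <= tolerance_s for t0, t1 in pairwise(ts))
# which is why a perturbation of tolerance_s * 1.1 fails while one of
# tolerance_s * 0.9 passes.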
@pytest.fixture(scope="module")
def valid_delta_timestamps_factory():
def _create_valid_delta_timestamps(
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
) -> dict:
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
return delta_timestamps
return _create_valid_delta_timestamps
@pytest.fixture(scope="module")
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
def _create_invalid_delta_timestamps(
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
) -> dict:
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
# Modify a single timestamp just outside tolerance
for key in keys:
delta_timestamps[key][3] += tolerance_s * 1.1
return delta_timestamps
return _create_invalid_delta_timestamps
@pytest.fixture(scope="module")
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
def _create_slightly_off_delta_timestamps(
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
) -> dict:
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
# Modify a single timestamp just inside tolerance
for key in delta_timestamps:
delta_timestamps[key][3] += tolerance_s * 0.9
delta_timestamps[key][-3] += tolerance_s * 0.9
return delta_timestamps
return _create_slightly_off_delta_timestamps
@pytest.fixture(scope="module")
def delta_indices_factory():
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
return {key: list(range(*min_max_range)) for key in keys}
return _delta_indices
def test_check_timestamps_sync_synced(synced_timestamps_factory):
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx, ep_data_index = synced_timestamps_factory(fps)
result = check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory):
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
with pytest.raises(ValueError):
check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory):
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
result = check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
raise_value_error=False,
)
assert result is False
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
result = check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_timestamps_sync_single_timestamp():
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx = np.array([0.0]), np.array([0])
episode_data_index = {"to": np.array([1]), "from": np.array([0])}
result = check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
valid_delta_timestamps = valid_delta_timestamps_factory(fps)
result = check_delta_timestamps(
delta_timestamps=valid_delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
result = check_delta_timestamps(
delta_timestamps=slightly_off_delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_delta_timestamps_invalid(invalid_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
with pytest.raises(ValueError):
check_delta_timestamps(
delta_timestamps=invalid_delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
)
def test_check_delta_timestamps_invalid_no_exception(invalid_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
result = check_delta_timestamps(
delta_timestamps=invalid_delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
raise_value_error=False,
)
assert result is False
def test_check_delta_timestamps_empty():
delta_timestamps = {}
fps = 30
tolerance_s = 1e-4
result = check_delta_timestamps(
delta_timestamps=delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_delta_indices(valid_delta_timestamps_factory, delta_indices_factory):
fps = 50
min_max_range = (-100, 100)
delta_timestamps = valid_delta_timestamps_factory(fps, min_max_range=min_max_range)
expected_delta_indices = delta_indices_factory(min_max_range=min_max_range)
actual_delta_indices = get_delta_indices(delta_timestamps, fps)
assert expected_delta_indices == actual_delta_indices
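# A sketch of the relationship the factories above encode (inferred, not the
# library's verbatim code): each delta timestamp is an integer frame offset
# divided by fps, so indices are recovered by rounding the product back.
def _get_delta_indices_sketch(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
    return {key: [round(t * fps) for t in ts] for key, ts in delta_timestamps.items()}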

View File

@ -1,382 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from packaging import version
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
from lerobot.datasets.transforms import (
ImageTransformConfig,
ImageTransforms,
ImageTransformsConfig,
RandomSubsetApply,
SharpnessJitter,
make_transform_from_config,
)
from lerobot.scripts.visualize_image_transforms import (
save_all_transforms,
save_each_transform,
)
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
from tests.utils import require_x86_64_kernel
@pytest.fixture
def color_jitters():
return [
v2.ColorJitter(brightness=0.5),
v2.ColorJitter(contrast=0.5),
v2.ColorJitter(saturation=0.5),
]
@pytest.fixture
def single_transforms():
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
@pytest.fixture
def img_tensor(single_transforms):
return single_transforms["original_frame"]
@pytest.fixture
def default_transforms():
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
def test_get_image_transforms_no_transform_enable_false(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig() # default is enable=False
tf_actual = ImageTransforms(tf_cfg)
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
def test_get_image_transforms_no_transform_max_num_transforms_0(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(enable=True, max_num_transforms=0)
tf_actual = ImageTransforms(tf_cfg)
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_brightness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={"brightness": ImageTransformConfig(type="ColorJitter", kwargs={"brightness": min_max})},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(brightness=min_max)
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True, tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})}
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(contrast=min_max)
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_saturation(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={"saturation": ImageTransformConfig(type="ColorJitter", kwargs={"saturation": min_max})},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(saturation=min_max)
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
def test_get_image_transforms_hue(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True, tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})}
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(hue=min_max)
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={"sharpness": ImageTransformConfig(type="SharpnessJitter", kwargs={"sharpness": min_max})},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = SharpnessJitter(sharpness=min_max)
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
max_num_transforms=5,
tfs={
"brightness": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"brightness": (0.5, 0.5)},
),
"contrast": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"contrast": (0.5, 0.5)},
),
"saturation": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"saturation": (0.5, 0.5)},
),
"hue": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"hue": (0.5, 0.5)},
),
"sharpness": ImageTransformConfig(
weight=1.0,
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 0.5)},
),
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.Compose(
[
v2.ColorJitter(brightness=(0.5, 0.5)),
v2.ColorJitter(contrast=(0.5, 0.5)),
v2.ColorJitter(saturation=(0.5, 0.5)),
v2.ColorJitter(hue=(0.5, 0.5)),
SharpnessJitter(sharpness=(0.5, 0.5)),
]
)
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@require_x86_64_kernel
def test_get_image_transforms_random_order(img_tensor_factory):
out_imgs = []
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
random_order=True,
tfs={
"brightness": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"brightness": (0.5, 0.5)},
),
"contrast": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"contrast": (0.5, 0.5)},
),
"saturation": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"saturation": (0.5, 0.5)},
),
"hue": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"hue": (0.5, 0.5)},
),
"sharpness": ImageTransformConfig(
weight=1.0,
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 0.5)},
),
},
)
tf = ImageTransforms(tf_cfg)
with seeded_context(1338):
for _ in range(10):
out_imgs.append(tf(img_tensor))
tmp_img_tensor = img_tensor
for sub_tf in tf.tf.selected_transforms:
tmp_img_tensor = sub_tf(tmp_img_tensor)
torch.testing.assert_close(tmp_img_tensor, out_imgs[-1])
for i in range(1, len(out_imgs)):
with pytest.raises(AssertionError):
torch.testing.assert_close(out_imgs[0], out_imgs[i])
@pytest.mark.parametrize(
"tf_type, tf_name, min_max_values",
[
("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]),
("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]),
("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]),
("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]),
("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]),
],
)
def test_backward_compatibility_single_transforms(
img_tensor, tf_type, tf_name, min_max_values, single_transforms
):
for min_max in min_max_values:
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
tf = make_transform_from_config(tf_cfg)
actual = tf(img_tensor)
key = f"{tf_name}_{min_max[0]}_{min_max[1]}"
expected = single_transforms[key]
torch.testing.assert_close(actual, expected)
@require_x86_64_kernel
@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse("2.7.0"),
reason="Test artifacts were generated with PyTorch >= 2.7.0 which has different multinomial behavior",
)
def test_backward_compatibility_default_config(img_tensor, default_transforms):
# NOTE: PyTorch versions have different randomness, it might break this test.
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
cfg = ImageTransformsConfig(enable=True)
default_tf = ImageTransforms(cfg)
with seeded_context(1337):
actual = default_tf(img_tensor)
expected = default_transforms["default"]
torch.testing.assert_close(actual, expected)
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
def test_random_subset_apply_single_choice(img_tensor_factory, p):
img_tensor = img_tensor_factory()
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
actual = random_choice(img_tensor)
p_horz, _ = p
if p_horz:
torch.testing.assert_close(actual, F.horizontal_flip(img_tensor))
else:
torch.testing.assert_close(actual, F.vertical_flip(img_tensor))
def test_random_subset_apply_random_order(img_tensor_factory):
img_tensor = img_tensor_factory()
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
# We can't really check whether the transforms are actually applied in random order. However,
# horizontal and vertical flips commute, so even if they are applied in random order, a fixed
# order produces the same expected output.
actual = random_order(img_tensor)
expected = v2.Compose(flips)(img_tensor)
torch.testing.assert_close(actual, expected)
def test_random_subset_apply_valid_transforms(img_tensor_factory, color_jitters):
img_tensor = img_tensor_factory()
transform = RandomSubsetApply(color_jitters)
output = transform(img_tensor)
assert output.shape == img_tensor.shape
def test_random_subset_apply_probability_length_mismatch(color_jitters):
with pytest.raises(ValueError):
RandomSubsetApply(color_jitters, p=[0.5, 0.5])
@pytest.mark.parametrize("n_subset", [0, 5])
def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
with pytest.raises(ValueError):
RandomSubsetApply(color_jitters, n_subset=n_subset)
def test_sharpness_jitter_valid_range_tuple(img_tensor_factory):
img_tensor = img_tensor_factory()
tf = SharpnessJitter((0.1, 2.0))
output = tf(img_tensor)
assert output.shape == img_tensor.shape
def test_sharpness_jitter_valid_range_float(img_tensor_factory):
img_tensor = img_tensor_factory()
tf = SharpnessJitter(0.5)
output = tf(img_tensor)
assert output.shape == img_tensor.shape
def test_sharpness_jitter_invalid_range_min_negative():
with pytest.raises(ValueError):
SharpnessJitter((-0.1, 2.0))
def test_sharpness_jitter_invalid_range_max_smaller():
with pytest.raises(ValueError):
SharpnessJitter((2.0, 0.1))
def test_save_all_transforms(img_tensor_factory, tmp_path):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(enable=True)
n_examples = 3
save_all_transforms(tf_cfg, img_tensor, tmp_path, n_examples)
# Check if the combined transforms directory exists and contains the right files
combined_transforms_dir = tmp_path / "all"
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
assert any(combined_transforms_dir.iterdir()), (
"No transformed images found in combined transforms directory."
)
for i in range(1, n_examples + 1):
assert (combined_transforms_dir / f"{i}.png").exists(), (
f"Combined transform image {i}.png was not found."
)
def test_save_each_transform(img_tensor_factory, tmp_path):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(enable=True)
n_examples = 3
save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples)
# Check if the transformed images exist for each transform type
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
for transform in transforms:
transform_dir = tmp_path / transform
assert transform_dir.exists(), f"{transform} directory was not created."
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
# Check for specific files within each transform directory
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
for file_name in expected_files:
assert (transform_dir / file_name).exists(), (
f"{file_name} was not found in {transform} directory."
)

View File

@@ -1,386 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import time
from multiprocessing import queues
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from PIL import Image
from lerobot.datasets.image_writer import (
AsyncImageWriter,
image_array_to_pil_image,
safe_stop_image_writer,
write_image,
)
from tests.fixtures.constants import DUMMY_HWC
DUMMY_IMAGE = "test_image.png"
def test_init_threading():
writer = AsyncImageWriter(num_processes=0, num_threads=2)
try:
assert writer.num_processes == 0
assert writer.num_threads == 2
assert isinstance(writer.queue, queue.Queue)
assert len(writer.threads) == 2
assert len(writer.processes) == 0
assert all(t.is_alive() for t in writer.threads)
finally:
writer.stop()
def test_init_multiprocessing():
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
assert writer.num_processes == 2
assert writer.num_threads == 2
assert isinstance(writer.queue, queues.JoinableQueue)
assert len(writer.threads) == 0
assert len(writer.processes) == 2
assert all(p.is_alive() for p in writer.processes)
finally:
writer.stop()
def test_zero_threads():
with pytest.raises(ValueError):
AsyncImageWriter(num_processes=0, num_threads=0)
def test_image_array_to_pil_image_float_array_wrong_range_0_255():
image = np.random.rand(*DUMMY_HWC) * 255
with pytest.raises(ValueError):
image_array_to_pil_image(image)
def test_image_array_to_pil_image_float_array_wrong_range_neg_1_1():
image = np.random.rand(*DUMMY_HWC) * 2 - 1
with pytest.raises(ValueError):
image_array_to_pil_image(image)
def test_image_array_to_pil_image_rgb(img_array_factory):
img_array = img_array_factory(100, 100)
result_image = image_array_to_pil_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
def test_image_array_to_pil_image_pytorch_format(img_array_factory):
img_array = img_array_factory(100, 100).transpose(2, 0, 1)
result_image = image_array_to_pil_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
def test_image_array_to_pil_image_single_channel(img_array_factory):
img_array = img_array_factory(channels=1)
with pytest.raises(NotImplementedError):
image_array_to_pil_image(img_array)
def test_image_array_to_pil_image_4_channels(img_array_factory):
img_array = img_array_factory(channels=4)
with pytest.raises(NotImplementedError):
image_array_to_pil_image(img_array)
def test_image_array_to_pil_image_float_array(img_array_factory):
img_array = img_array_factory(dtype=np.float32)
result_image = image_array_to_pil_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
assert np.array(result_image).dtype == np.uint8
def test_image_array_to_pil_image_uint8_array(img_array_factory):
img_array = img_array_factory(dtype=np.uint8)
result_image = image_array_to_pil_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
assert np.array(result_image).dtype == np.uint8
def test_write_image_numpy(tmp_path, img_array_factory):
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
write_image(image_array, fpath)
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(image_array, saved_image)
def test_write_image_image(tmp_path, img_factory):
image_pil = img_factory()
fpath = tmp_path / DUMMY_IMAGE
write_image(image_pil, fpath)
assert fpath.exists()
saved_image = Image.open(fpath)
assert list(saved_image.getdata()) == list(image_pil.getdata())
assert np.array_equal(image_pil, saved_image)
def test_write_image_exception(tmp_path):
image_array = "invalid data"
fpath = tmp_path / DUMMY_IMAGE
with patch("builtins.print") as mock_print:
write_image(image_array, fpath)
mock_print.assert_called()
assert not fpath.exists()
def test_save_image_numpy(tmp_path, img_array_factory):
writer = AsyncImageWriter()
try:
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_array, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(image_array, saved_image)
finally:
writer.stop()
def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_array, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(image_array, saved_image)
finally:
writer.stop()
def test_save_image_torch(tmp_path, img_tensor_factory):
writer = AsyncImageWriter()
try:
image_tensor = img_tensor_factory()
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_tensor, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
assert np.array_equal(expected_image, saved_image)
finally:
writer.stop()
def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
image_tensor = img_tensor_factory()
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_tensor, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
assert np.array_equal(expected_image, saved_image)
finally:
writer.stop()
def test_save_image_pil(tmp_path, img_factory):
writer = AsyncImageWriter()
try:
image_pil = img_factory()
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_pil, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = Image.open(fpath)
assert list(saved_image.getdata()) == list(image_pil.getdata())
finally:
writer.stop()
def test_save_image_pil_multiprocessing(tmp_path, img_factory):
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
image_pil = img_factory()
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_pil, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = Image.open(fpath)
assert list(saved_image.getdata()) == list(image_pil.getdata())
finally:
writer.stop()
def test_save_image_invalid_data(tmp_path):
writer = AsyncImageWriter()
try:
image_array = "invalid data"
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
with patch("builtins.print") as mock_print:
writer.save_image(image_array, fpath)
writer.wait_until_done()
mock_print.assert_called()
assert not fpath.exists()
finally:
writer.stop()
def test_save_image_after_stop(tmp_path, img_array_factory):
writer = AsyncImageWriter()
writer.stop()
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_array, fpath)
time.sleep(1)
assert not fpath.exists()
def test_stop():
writer = AsyncImageWriter(num_processes=0, num_threads=2)
writer.stop()
assert not any(t.is_alive() for t in writer.threads)
def test_stop_multiprocessing():
writer = AsyncImageWriter(num_processes=2, num_threads=2)
writer.stop()
assert not any(p.is_alive() for p in writer.processes)
def test_multiple_stops():
writer = AsyncImageWriter()
writer.stop()
writer.stop() # Should not raise an exception
assert not any(t.is_alive() for t in writer.threads)
def test_multiple_stops_multiprocessing():
writer = AsyncImageWriter(num_processes=2, num_threads=2)
writer.stop()
writer.stop() # Should not raise an exception
assert not any(p.is_alive() for p in writer.processes)
def test_wait_until_done(tmp_path, img_array_factory):
writer = AsyncImageWriter(num_processes=0, num_threads=4)
try:
num_images = 100
image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_array, fpath)
writer.wait_until_done()
for i, fpath in enumerate(fpaths):
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(saved_image, image_arrays[i])
finally:
writer.stop()
def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
num_images = 100
image_arrays = [img_array_factory() for _ in range(num_images)]
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_array, fpath)
writer.wait_until_done()
for i, fpath in enumerate(fpaths):
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(saved_image, image_arrays[i])
finally:
writer.stop()
def test_exception_handling(tmp_path, img_array_factory):
writer = AsyncImageWriter()
try:
image_array = img_array_factory()
with (
patch.object(writer.queue, "put", side_effect=queue.Full("Queue is full")),
pytest.raises(queue.Full) as exc_info,
):
writer.save_image(image_array, tmp_path / "test.png")
assert str(exc_info.value) == "Queue is full"
finally:
writer.stop()
def test_with_different_image_formats(tmp_path, img_array_factory):
writer = AsyncImageWriter()
try:
image_array = img_array_factory()
formats = ["png", "jpeg", "bmp"]
for fmt in formats:
fpath = tmp_path / f"test_image.{fmt}"
write_image(image_array, fpath)
assert fpath.exists()
finally:
writer.stop()
def test_safe_stop_image_writer_decorator():
class MockDataset:
def __init__(self):
self.image_writer = MagicMock(spec=AsyncImageWriter)
@safe_stop_image_writer
def function_that_raises_exception(dataset=None):
raise Exception("Test exception")
dataset = MockDataset()
with pytest.raises(Exception) as exc_info:
function_that_raises_exception(dataset=dataset)
assert str(exc_info.value) == "Test exception"
dataset.image_writer.stop.assert_called_once()
def test_main_process_time(tmp_path, img_tensor_factory):
writer = AsyncImageWriter()
try:
image_tensor = img_tensor_factory()
fpath = tmp_path / DUMMY_IMAGE
start_time = time.perf_counter()
writer.save_image(image_tensor, fpath)
end_time = time.perf_counter()
time_spent = end_time - start_time
# Might need to adjust this threshold depending on hardware
assert time_spent < 0.01, f"Main process time exceeded threshold: {time_spent}s"
writer.wait_until_done()
assert fpath.exists()
finally:
writer.stop()

View File

@@ -1,282 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from uuid import uuid4
import numpy as np
import pytest
import torch
from lerobot.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
# Some constants for OnlineBuffer tests.
data_key = "data"
data_shape = (2, 3) # just some arbitrary > 1D shape
buffer_capacity = 100
fps = 10
def make_new_buffer(
write_dir: str | None = None, delta_timestamps: dict[str, list[float]] | None = None
) -> tuple[OnlineBuffer, str]:
if write_dir is None:
write_dir = f"/tmp/online_buffer_{uuid4().hex}"
buffer = OnlineBuffer(
write_dir,
data_spec={data_key: {"shape": data_shape, "dtype": np.dtype("float32")}},
buffer_capacity=buffer_capacity,
fps=fps,
delta_timestamps=delta_timestamps,
)
return buffer, write_dir
def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
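# Illustrative output shape (a sketch for n_episodes=2, n_frames_per_episode=3):
# index [0..5], episode_index [0, 0, 0, 1, 1, 1], frame_index [0, 1, 2, 0, 1, 2],
# and timestamp [0.0, 0.1, 0.2, 0.0, 0.1, 0.2] at fps=10.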
new_data = {
data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
}
return new_data
def test_non_mutate():
"""Checks that the data provided to the add_data method is copied rather than passed by reference.
This means that mutating the data in the buffer does not mutate the original data.
NOTE: If this test fails, it means some of the other tests may be compromised. For example, we can't trust
a success case for `test_write_read`.
"""
buffer, _ = make_new_buffer()
new_data = make_spoof_data_frames(2, buffer_capacity // 4)
new_data_copy = deepcopy(new_data)
buffer.add_data(new_data)
buffer._data[data_key][:] += 1
assert all(np.array_equal(new_data[k], new_data_copy[k]) for k in new_data)
def test_index_error_no_data():
buffer, _ = make_new_buffer()
with pytest.raises(IndexError):
buffer[0]
def test_index_error_with_data():
buffer, _ = make_new_buffer()
n_frames = buffer_capacity // 2
new_data = make_spoof_data_frames(1, n_frames)
buffer.add_data(new_data)
with pytest.raises(IndexError):
buffer[n_frames]
with pytest.raises(IndexError):
buffer[-n_frames - 1]
@pytest.mark.parametrize("do_reload", [False, True])
def test_write_read(do_reload: bool):
"""Checks that data can be added to the buffer and read back.
If do_reload, we delete the buffer object and reload it from disk before reading.
"""
buffer, write_dir = make_new_buffer()
n_episodes = 2
n_frames_per_episode = buffer_capacity // 4
new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
buffer.add_data(new_data)
if do_reload:
del buffer
buffer, _ = make_new_buffer(write_dir)
assert len(buffer) == n_frames_per_episode * n_episodes
for i, item in enumerate(buffer):
assert all(isinstance(item[k], torch.Tensor) for k in item)
assert np.array_equal(item[data_key].numpy(), new_data[data_key][i])
def test_read_data_key():
"""Tests that data can be added to a buffer and all data for a. specific key can be read back."""
buffer, _ = make_new_buffer()
n_episodes = 2
n_frames_per_episode = buffer_capacity // 4
new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
buffer.add_data(new_data)
data_from_buffer = buffer.get_data_by_key(data_key)
assert isinstance(data_from_buffer, torch.Tensor)
assert np.array_equal(data_from_buffer.numpy(), new_data[data_key])
def test_fifo():
"""Checks that if data is added beyond the buffer capacity, we discard the oldest data first."""
buffer, _ = make_new_buffer()
n_frames_per_episode = buffer_capacity // 4
n_episodes = 3
new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
buffer.add_data(new_data)
n_more_episodes = 2
# Developer sanity check (in case someone changes the global `buffer_capacity`).
assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, (
"Something went wrong with the test code."
)
more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
buffer.add_data(more_new_data)
assert len(buffer) == buffer_capacity, "The buffer should be full."
expected_data = {}
for k in new_data:
# Concatenate, left-truncate, then roll, to imitate the cyclical FIFO pattern in OnlineBuffer.
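# Worked example (assuming the buffer overwrites cyclically from index 0): with
# capacity 4, adding [0, 1, 2, 3, 4] then [5, 6] leaves the buffer as [4, 5, 6, 3];
# concatenating gives [0..6], the last 4 entries are [3, 4, 5, 6], and rolling by
# (5 + 2 - 4) = 3 recovers [4, 5, 6, 3].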
expected_data[k] = np.roll(
np.concatenate([new_data[k], more_new_data[k]])[-buffer_capacity:],
shift=len(new_data[k]) + len(more_new_data[k]) - buffer_capacity,
axis=0,
)
for i, item in enumerate(buffer):
assert all(isinstance(item[k], torch.Tensor) for k in item)
assert np.array_equal(item[data_key].numpy(), expected_data[data_key][i])
def test_delta_timestamps_within_tolerance():
"""Check that getting an item with delta_timestamps within tolerance succeeds.
Note: Copied from `test_datasets.py::test_load_previous_and_future_frames_within_tolerance`.
"""
# Sanity check on global fps as we are assuming it is 10 here.
assert fps == 10, "This test assumes fps==10"
buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.139]})
new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
buffer.add_data(new_data)
buffer.tolerance_s = 0.04
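# Worked example: index 2 sits at t=0.2s; deltas [-0.2, 0, 0.139] request timestamps
# [0.0, 0.2, 0.339]. At 10 fps the closest frames are [0, 2, 3], and the largest
# mismatch, |0.339 - 0.3| = 0.039s, stays within tolerance_s=0.04.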
item = buffer[2]
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
assert not is_pad.any(), "Unexpected padding detected"
def test_delta_timestamps_outside_tolerance_inside_episode_range():
"""Check that getting an item with delta_timestamps outside of tolerance fails.
We expect it to fail if and only if the requested timestamps are within the episode range.
Note: Copied from
`test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_inside_episode_range`
"""
# Sanity check on global fps as we are assuming it is 10 here.
assert fps == 10, "This test assumes fps==10"
buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.141]})
new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
buffer.add_data(new_data)
buffer.tolerance_s = 0.04
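# Here the last requested timestamp is 0.2 + 0.141 = 0.341s; the closest frame
# (t=0.3s) is 0.041s away, exceeding tolerance_s=0.04, so indexing must fail.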
with pytest.raises(AssertionError):
buffer[2]
def test_delta_timestamps_outside_tolerance_outside_episode_range():
"""Check that copy-padding of timestamps outside of the episode range works.
Note: Copied from
`test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_outside_episode_range`
"""
# Sanity check on global fps as we are assuming it is 10 here.
assert fps == 10, "This test assumes fps==10"
buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.3, -0.24, 0, 0.26, 0.3]})
new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
buffer.add_data(new_data)
buffer.tolerance_s = 0.04
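# Requested timestamps are 0.2 + deltas = [-0.1, -0.04, 0.2, 0.46, 0.5]; the episode
# spans [0.0, 0.4], so out-of-range queries are copy-padded with the boundary frames
# (0 and 4) and flagged in index_is_pad, as asserted below.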
item = buffer[2]
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
"Padding does not match expected values"
)
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
@pytest.mark.parametrize("offline_dataset_size", [1, 6])
@pytest.mark.parametrize("online_dataset_size", [0, 4])
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
def test_compute_sampler_weights_trivial(
lerobot_dataset_factory,
tmp_path,
offline_dataset_size: int,
online_dataset_size: int,
online_sampling_ratio: float,
):
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
online_dataset, _ = make_new_buffer()
if online_dataset_size > 0:
online_dataset.add_data(
make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
)
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
)
if offline_dataset_size == 0 or online_dataset_size == 0:
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
elif online_sampling_ratio == 0:
expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
elif online_sampling_ratio == 1:
expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
expected_weights /= expected_weights.sum()
torch.testing.assert_close(weights, expected_weights)
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_sampling_ratio = 0.8
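# The 4 offline frames share 1 - 0.8 = 0.2 of the probability mass (0.05 each),
# while the 8 online frames share 0.8 (0.1 each).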
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
)
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
)
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
)
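# Dropping the last frame of each of the 4 online episodes leaves 4 sampleable
# online frames sharing 0.8 (0.2 each); the 4 offline frames still share 0.2.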
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
)
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
"""Note: test copied from test_sampler."""
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
weights = compute_sampler_weights(
offline_dataset,
offline_drop_n_last_frames=1,
online_dataset=online_dataset,
online_sampling_ratio=0.5,
online_drop_n_last_frames=1,
)
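# Offline: 2 frames minus the last one -> frame 0 carries the whole offline half (0.5).
# Online: 8 frames minus the last of each of the 4 episodes -> 4 frames at 0.125 each.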
torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))

View File

@@ -1,90 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import Dataset
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import (
hf_transform_to_torch,
)
def test_drop_n_first_frames():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
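# Episodes cover indices [0, 1], [2], and [3, 4, 5]; dropping the first frame of each
# leaves [1], [], and [4, 5], hence the expected indices [1, 4, 5].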
assert sampler.indices == [1, 4, 5]
assert len(sampler) == 3
assert list(sampler) == [1, 4, 5]
def test_drop_n_last_frames():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
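# Dropping the last frame of each episode instead leaves [0], [], and [3, 4].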
assert sampler.indices == [0, 3, 4]
assert len(sampler) == 3
assert list(sampler) == [0, 3, 4]
def test_episode_indices_to_use():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
assert sampler.indices == [0, 1, 3, 4, 5]
assert len(sampler) == 5
assert list(sampler) == [0, 1, 3, 4, 5]
def test_shuffle():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert list(sampler) == [0, 1, 2, 3, 4, 5]
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert set(sampler) == {0, 1, 2, 3, 4, 5}

View File

@@ -1,55 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import Dataset
from huggingface_hub import DatasetCard
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
def test_default_parameters():
card = create_lerobot_dataset_card()
assert isinstance(card, DatasetCard)
assert card.data.tags == ["LeRobot"]
assert card.data.task_categories == ["robotics"]
assert card.data.configs == [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]
def test_with_tags():
tags = ["tag1", "tag2"]
card = create_lerobot_dataset_card(tags=tags)
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
def test_calculate_episode_data_index():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
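# episode_index [0, 0, 1, 2, 2, 2] yields the half-open ranges [0, 2), [2, 3), [3, 6).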
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))

View File

@@ -1,33 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from lerobot.scripts.visualize_dataset import visualize_dataset
@pytest.mark.skip("TODO: add dummy videos")
def test_visualize_local_dataset(tmp_path, lerobot_dataset_factory):
root = tmp_path / "dataset"
output_dir = tmp_path / "outputs"
dataset = lerobot_dataset_factory(root=root)
rrd_path = visualize_dataset(
dataset,
episode_index=0,
batch_size=32,
save=True,
output_dir=output_dir,
)
assert rrd_path.exists()

View File

@@ -1,63 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import gymnasium as gym
import pytest
import torch
from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import preprocess_observation
from tests.utils import require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
@pytest.mark.parametrize("obs_type", OBS_TYPES)
@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
@require_env
def test_env(env_name, env_task, obs_type):
if env_name == "aloha" and obs_type == "state":
pytest.skip("`state` observations not available for aloha")
package_name = f"gym_{env_name}"
importlib.import_module(package_name)
env = gym.make(f"{package_name}/{env_task}", obs_type=obs_type)
check_env(env.unwrapped, skip_render_check=True)
env.close()
@pytest.mark.parametrize("env_name", lerobot.available_envs)
@require_env
def test_factory(env_name):
cfg = make_env_config(env_name)
env = make_env(cfg, n_envs=1)
obs, _ = env.reset()
obs = preprocess_observation(obs)
# test image keys are float32 in range [0,1]
for key in obs:
if "image" not in key:
continue
img = obs[key]
assert img.dtype == torch.float32
# TODO(rcadene): we assume for now that image normalization takes place in the model
assert img.max() <= 1.0
assert img.min() >= 0.0
env.close()

View File

@@ -1,147 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import subprocess
import sys
from pathlib import Path
import pytest
from tests.fixtures.constants import DUMMY_REPO_ID
from tests.utils import require_package
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
for f, r in finds_and_replaces:
assert f in text
text = text.replace(f, r)
return text
# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures
def _run_script(path):
subprocess.run([sys.executable, path], check=True)
def _read_file(path):
with open(path) as file:
return file.read()
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
def test_example_1(tmp_path, lerobot_dataset_factory):
_ = lerobot_dataset_factory(root=tmp_path, repo_id=DUMMY_REPO_ID)
path = "examples/1_load_lerobot_dataset.py"
file_contents = _read_file(path)
file_contents = _find_and_replace(
file_contents,
[
('repo_id = "lerobot/pusht"', f'repo_id = "{DUMMY_REPO_ID}"'),
(
"LeRobotDataset(repo_id",
f"LeRobotDataset(repo_id, root='{str(tmp_path)}'",
),
],
)
exec(file_contents, {})
assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
@require_package("gym_pusht")
def test_examples_basic2_basic3_advanced1():
"""
Train a model with example 3, check the outputs.
Evaluate the trained model with example 2, check the outputs.
Calculate the validation loss with advanced example 1, check the outputs.
"""
### Test example 3
file_contents = _read_file("examples/3_train_policy.py")
# Do fewer steps, use smaller batch, use CPU, and don't complicate things with dataloader workers.
file_contents = _find_and_replace(
file_contents,
[
("training_steps = 5000", "training_steps = 1"),
("num_workers=4", "num_workers=0"),
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("batch_size=64", "batch_size=1"),
],
)
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
exec(file_contents, {})
for file_name in ["model.safetensors", "config.json"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
### Test example 2
file_contents = _read_file("examples/2_evaluate_pretrained_policy.py")
# Do fewer evals, use CPU, and use the local model.
file_contents = _find_and_replace(
file_contents,
[
(
'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))',
"",
),
(
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
),
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("step += 1", "break"),
],
)
exec(file_contents, {})
assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()
### Test advanced example 1
file_contents = _read_file("examples/advanced/2_calculate_validation_loss.py")
# Run on a single example from the last episode, use CPU, and use the local model.
file_contents = _find_and_replace(
file_contents,
[
(
'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))',
"",
),
(
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
),
("train_episodes = episodes[:num_train_episodes]", "train_episodes = [0]"),
("val_episodes = episodes[num_train_episodes:]", "val_episodes = [1]"),
("num_workers=4", "num_workers=0"),
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("batch_size=64", "batch_size=1"),
],
)
# Capture the output of the script
output_buffer = io.StringIO()
sys.stdout = output_buffer
exec(file_contents, {})
printed_output = output_buffer.getvalue()
# Restore stdout to its original state
sys.stdout = sys.__stdout__
assert "Average loss on validation set" in printed_output

View File

@@ -1,44 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.constants import HF_LEROBOT_HOME
LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"
DUMMY_REPO_ID = "dummy/repo"
DUMMY_ROBOT_TYPE = "dummy_robot"
DUMMY_MOTOR_FEATURES = {
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
"state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
}
DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {
"video.fps": DEFAULT_FPS,
"video.codec": "av1",
"video.pix_fmt": "yuv420p",
"video.is_depth_map": False,
"has_audio": False,
}
DUMMY_CHW = (3, 96, 128)
DUMMY_HWC = (96, 128, 3)

View File

@@ -1,444 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from functools import partial
from pathlib import Path
from typing import Protocol
from unittest.mock import patch
import datasets
import numpy as np
import PIL.Image
import pytest
import torch
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH,
get_hf_features_from_features,
hf_transform_to_torch,
)
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
DUMMY_MOTOR_FEATURES,
DUMMY_REPO_ID,
DUMMY_ROBOT_TYPE,
DUMMY_VIDEO_INFO,
)
class LeRobotDatasetFactory(Protocol):
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
def get_task_index(task_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
return task_to_task_index[task]
@pytest.fixture(scope="session")
def img_tensor_factory():
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
return torch.rand((channels, height, width), dtype=dtype)
return _create_img_tensor
@pytest.fixture(scope="session")
def img_array_factory():
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
if np.issubdtype(dtype, np.unsignedinteger):
# Int array in [0, 255] range
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
elif np.issubdtype(dtype, np.floating):
# Float array in [0, 1] range
img_array = np.random.rand(height, width, channels).astype(dtype)
else:
raise ValueError(dtype)
return img_array
return _create_img_array
@pytest.fixture(scope="session")
def img_factory(img_array_factory):
def _create_img(height=100, width=100) -> PIL.Image.Image:
img_array = img_array_factory(height=height, width=width)
return PIL.Image.fromarray(img_array)
return _create_img
@pytest.fixture(scope="session")
def features_factory():
def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
) -> dict:
if use_videos:
camera_ft = {
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
}
else:
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
return {
**motor_features,
**camera_ft,
**DEFAULT_FEATURES,
}
return _create_features
@pytest.fixture(scope="session")
def info_factory(features_factory):
def _create_info(
codebase_version: str = CODEBASE_VERSION,
fps: int = DEFAULT_FPS,
robot_type: str = DUMMY_ROBOT_TYPE,
total_episodes: int = 0,
total_frames: int = 0,
total_tasks: int = 0,
total_videos: int = 0,
total_chunks: int = 0,
chunks_size: int = DEFAULT_CHUNK_SIZE,
data_path: str = DEFAULT_PARQUET_PATH,
video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
) -> dict:
features = features_factory(motor_features, camera_features, use_videos)
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": total_frames,
"total_tasks": total_tasks,
"total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": chunks_size,
"fps": fps,
"splits": {},
"data_path": data_path,
"video_path": video_path if use_videos else None,
"features": features,
}
return _create_info
@pytest.fixture(scope="session")
def stats_factory():
def _create_stats(
features: dict | None = None,
) -> dict:
stats = {}
# features may be omitted, in which case an empty stats dict is returned.
for key, ft in (features or {}).items():
shape = ft["shape"]
dtype = ft["dtype"]
if dtype in ["image", "video"]:
stats[key] = {
"max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(),
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
"min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
"count": [10],
}
else:
stats[key] = {
"max": np.full(shape, 1, dtype=dtype).tolist(),
"mean": np.full(shape, 0.5, dtype=dtype).tolist(),
"min": np.full(shape, 0, dtype=dtype).tolist(),
"std": np.full(shape, 0.25, dtype=dtype).tolist(),
"count": [10],
}
return stats
return _create_stats
@pytest.fixture(scope="session")
def episodes_stats_factory(stats_factory):
def _create_episodes_stats(
features: dict,
total_episodes: int = 3,
) -> dict:
episodes_stats = {}
for episode_index in range(total_episodes):
episodes_stats[episode_index] = {
"episode_index": episode_index,
"stats": stats_factory(features),
}
return episodes_stats
return _create_episodes_stats
@pytest.fixture(scope="session")
def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> dict:
tasks = {}
for task_index in range(total_tasks):
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
tasks[task_index] = task_dict
return tasks
return _create_tasks
@pytest.fixture(scope="session")
def episodes_factory(tasks_factory):
def _create_episodes(
total_episodes: int = 3,
total_frames: int = 400,
tasks: dict | None = None,
multi_task: bool = False,
):
if total_episodes <= 0 or total_frames <= 0:
raise ValueError("num_episodes and total_length must be positive integers.")
if total_frames < total_episodes:
raise ValueError("total_length must be greater than or equal to num_episodes.")
if not tasks:
min_tasks = 2 if multi_task else 1
total_tasks = random.randint(min_tasks, total_episodes)
tasks = tasks_factory(total_tasks)
if total_episodes < len(tasks) and not multi_task:
raise ValueError("The number of tasks should be less than the number of episodes.")
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
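# e.g. np.random.multinomial(400, [1/3] * 3) might return [135, 130, 135]; the draws
# always sum to total_frames, so the episode lengths partition the dataset.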
tasks_list = [task_dict["task"] for task_dict in tasks.values()]
num_tasks_available = len(tasks_list)
episodes = {}
remaining_tasks = tasks_list.copy()
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
if remaining_tasks:
for task in episode_tasks:
remaining_tasks.remove(task)
episodes[ep_idx] = {
"episode_index": ep_idx,
"tasks": episode_tasks,
"length": lengths[ep_idx],
}
return episodes
return _create_episodes
@pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset(
features: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
fps: int = DEFAULT_FPS,
) -> datasets.Dataset:
if not tasks:
tasks = tasks_factory()
if not episodes:
episodes = episodes_factory()
if not features:
features = features_factory()
timestamp_col = np.array([], dtype=np.float32)
frame_index_col = np.array([], dtype=np.int64)
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes.values():
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
)
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
index_col = np.arange(len(episode_index_col))
robot_cols = {}
for key, ft in features.items():
if ft["dtype"] == "image":
robot_cols[key] = [
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
for _ in range(len(index_col))
]
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
hf_features = get_hf_features_from_features(features)
dataset = datasets.Dataset.from_dict(
{
**robot_cols,
"timestamp": timestamp_col,
"frame_index": frame_index_col,
"episode_index": episode_index_col,
"index": index_col,
"task_index": task_index,
},
features=hf_features,
)
dataset.set_transform(hf_transform_to_torch)
return dataset
return _create_hf_dataset
@pytest.fixture(scope="session")
def lerobot_dataset_metadata_factory(
info_factory,
stats_factory,
episodes_stats_factory,
tasks_factory,
episodes_factory,
mock_snapshot_download_factory,
):
def _create_lerobot_dataset_metadata(
root: Path,
repo_id: str = DUMMY_REPO_ID,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
) -> LeRobotDatasetMetadata:
if not info:
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episodes,
)
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
):
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDatasetMetadata(repo_id=repo_id, root=root)
return _create_lerobot_dataset_metadata
@pytest.fixture(scope="session")
def lerobot_dataset_factory(
info_factory,
stats_factory,
episodes_stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
mock_snapshot_download_factory,
lerobot_dataset_metadata_factory,
) -> LeRobotDatasetFactory:
def _create_lerobot_dataset(
root: Path,
repo_id: str = DUMMY_REPO_ID,
total_episodes: int = 3,
total_frames: int = 150,
total_tasks: int = 1,
multi_task: bool = False,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episode_dicts: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
**kwargs,
) -> LeRobotDataset:
if not info:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
if not stats:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
episode_dicts = episodes_factory(
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
hf_dataset=hf_dataset,
)
mock_metadata = lerobot_dataset_metadata_factory(
root=root,
repo_id=repo_id,
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
)
with (
patch("lerobot.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
):
mock_metadata_patch.return_value = mock_metadata
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
return _create_lerobot_dataset
@pytest.fixture(scope="session")
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)

View File

@@ -1,147 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path
import datasets
import jsonlines
import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest
from lerobot.datasets.utils import (
EPISODES_PATH,
EPISODES_STATS_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
)
@pytest.fixture(scope="session")
def info_path(info_factory):
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
if not info:
info = info_factory()
fpath = dir / INFO_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(info, f, indent=4, ensure_ascii=False)
return fpath
return _create_info_json_file
@pytest.fixture(scope="session")
def stats_path(stats_factory):
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
if not stats:
stats = stats_factory()
fpath = dir / STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(stats, f, indent=4, ensure_ascii=False)
return fpath
return _create_stats_json_file
@pytest.fixture(scope="session")
def episodes_stats_path(episodes_stats_factory):
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
if not episodes_stats:
episodes_stats = episodes_stats_factory()
fpath = dir / EPISODES_STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes_stats.values())
return fpath
return _create_episodes_stats_jsonl_file
@pytest.fixture(scope="session")
def tasks_path(tasks_factory):
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
if not tasks:
tasks = tasks_factory()
fpath = dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks.values())
return fpath
return _create_tasks_jsonl_file
@pytest.fixture(scope="session")
def episode_path(episodes_factory):
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
if not episodes:
episodes = episodes_factory()
fpath = dir / EPISODES_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes.values())
return fpath
return _create_episodes_jsonl_file
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet(
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
data_path = info["data_path"]
chunks_size = info["chunks_size"]
ep_chunk = ep_idx // chunks_size
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
return fpath
return _create_single_episode_parquet
@pytest.fixture(scope="session")
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_multi_episode_parquet(
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
data_path = info["data_path"]
chunks_size = info["chunks_size"]
total_episodes = info["total_episodes"]
for ep_idx in range(total_episodes):
ep_chunk = ep_idx // chunks_size
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
return dir / "data"
return _create_multi_episode_parquet
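A sketch of how a test might verify these parquet fixtures (hypothetical test, assuming pytest's tmp_path):

def test_single_episode_parquet_is_filtered(single_episode_parquet_path, tmp_path):
    fpath = single_episode_parquet_path(tmp_path, ep_idx=0)
    table = pq.read_table(fpath)
    # Only rows belonging to episode 0 should have been written.
    assert set(table["episode_index"].to_pylist()) == {0}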

tests/fixtures/hub.py

@@ -1,133 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import datasets
import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.datasets.utils import (
EPISODES_PATH,
EPISODES_STATS_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
)
from tests.fixtures.constants import LEROBOT_TEST_DIR
@pytest.fixture(scope="session")
def mock_snapshot_download_factory(
info_factory,
info_path,
stats_factory,
stats_path,
episodes_stats_factory,
episodes_stats_path,
tasks_factory,
tasks_path,
episodes_factory,
episode_path,
single_episode_parquet_path,
hf_dataset_factory,
):
"""
    This factory patches snapshot_download so that, when called, it creates the expected files locally
    rather than making calls to the Hub API. Its design lets you explicitly pass the files you want to be
    created.
"""
def _mock_snapshot_download_func(
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
):
if not info:
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
return episode_index
else:
return None
def _mock_snapshot_download(
repo_id: str,
local_dir: str | Path | None = None,
allow_patterns: str | list[str] | None = None,
ignore_patterns: str | list[str] | None = None,
*args,
**kwargs,
) -> str:
if not local_dir:
local_dir = LEROBOT_TEST_DIR
# List all possible files
all_files = []
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
all_files.extend(meta_files)
data_files = []
for episode_dict in episodes.values():
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info["chunks_size"]
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
data_files.append(data_path)
all_files.extend(data_files)
allowed_files = filter_repo_objects(
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
)
# Create allowed files
for rel_path in allowed_files:
if rel_path.startswith("data/"):
episode_index = _extract_episode_index_from_path(rel_path)
if episode_index is not None:
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
if rel_path == INFO_PATH:
_ = info_path(local_dir, info)
elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats)
elif rel_path == EPISODES_STATS_PATH:
_ = episodes_stats_path(local_dir, episodes_stats)
elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, tasks)
elif rel_path == EPISODES_PATH:
_ = episode_path(local_dir, episodes)
else:
pass
return str(local_dir)
return _mock_snapshot_download
return _mock_snapshot_download_func
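A usage sketch (hypothetical test; the "meta/*" pattern assumes the layout in which INFO_PATH resolves to a file under meta/):

def test_mock_snapshot_download_metadata_only(mock_snapshot_download_factory, tmp_path):
    mock_download = mock_snapshot_download_factory()
    # Only metadata files are materialized; data parquet files are filtered out.
    local_dir = mock_download("dummy/repo", local_dir=tmp_path, allow_patterns=["meta/*"])
    assert (Path(local_dir) / INFO_PATH).exists()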


@@ -1,39 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import VQBeTSchedulerConfig
@pytest.fixture
def model_params():
return [torch.nn.Parameter(torch.randn(10, 10))]
@pytest.fixture
def optimizer(model_params):
optimizer = AdamConfig().build(model_params)
# Dummy step to populate state
loss = sum(param.sum() for param in model_params)
loss.backward()
optimizer.step()
return optimizer
@pytest.fixture
def scheduler(optimizer):
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
return config.build(optimizer, num_training_steps=100)
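A sketch of how these fixtures compose (hypothetical test; assumes the config builds a standard torch scheduler exposing get_last_lr()):

def test_scheduler_steps_with_optimizer(optimizer, scheduler):
    lrs = []
    for _ in range(5):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
    # The learning rate should never decrease this early in the schedule.
    assert lrs == sorted(lrs)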


@@ -1,596 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from collections.abc import Callable
import dynamixel_sdk as dxl
import serial
from mock_serial.mock_serial import MockSerial
from lerobot.motors.dynamixel.dynamixel import _split_into_byte_chunks
from .mock_serial_patch import WaitableStub
# https://emanual.robotis.com/docs/en/dxl/crc/
DXL_CRC_TABLE = [
0x0000, 0x8005, 0x800F, 0x000A, 0x801B, 0x001E, 0x0014, 0x8011,
0x8033, 0x0036, 0x003C, 0x8039, 0x0028, 0x802D, 0x8027, 0x0022,
0x8063, 0x0066, 0x006C, 0x8069, 0x0078, 0x807D, 0x8077, 0x0072,
0x0050, 0x8055, 0x805F, 0x005A, 0x804B, 0x004E, 0x0044, 0x8041,
0x80C3, 0x00C6, 0x00CC, 0x80C9, 0x00D8, 0x80DD, 0x80D7, 0x00D2,
0x00F0, 0x80F5, 0x80FF, 0x00FA, 0x80EB, 0x00EE, 0x00E4, 0x80E1,
0x00A0, 0x80A5, 0x80AF, 0x00AA, 0x80BB, 0x00BE, 0x00B4, 0x80B1,
0x8093, 0x0096, 0x009C, 0x8099, 0x0088, 0x808D, 0x8087, 0x0082,
0x8183, 0x0186, 0x018C, 0x8189, 0x0198, 0x819D, 0x8197, 0x0192,
0x01B0, 0x81B5, 0x81BF, 0x01BA, 0x81AB, 0x01AE, 0x01A4, 0x81A1,
0x01E0, 0x81E5, 0x81EF, 0x01EA, 0x81FB, 0x01FE, 0x01F4, 0x81F1,
0x81D3, 0x01D6, 0x01DC, 0x81D9, 0x01C8, 0x81CD, 0x81C7, 0x01C2,
0x0140, 0x8145, 0x814F, 0x014A, 0x815B, 0x015E, 0x0154, 0x8151,
0x8173, 0x0176, 0x017C, 0x8179, 0x0168, 0x816D, 0x8167, 0x0162,
0x8123, 0x0126, 0x012C, 0x8129, 0x0138, 0x813D, 0x8137, 0x0132,
0x0110, 0x8115, 0x811F, 0x011A, 0x810B, 0x010E, 0x0104, 0x8101,
0x8303, 0x0306, 0x030C, 0x8309, 0x0318, 0x831D, 0x8317, 0x0312,
0x0330, 0x8335, 0x833F, 0x033A, 0x832B, 0x032E, 0x0324, 0x8321,
0x0360, 0x8365, 0x836F, 0x036A, 0x837B, 0x037E, 0x0374, 0x8371,
0x8353, 0x0356, 0x035C, 0x8359, 0x0348, 0x834D, 0x8347, 0x0342,
0x03C0, 0x83C5, 0x83CF, 0x03CA, 0x83DB, 0x03DE, 0x03D4, 0x83D1,
0x83F3, 0x03F6, 0x03FC, 0x83F9, 0x03E8, 0x83ED, 0x83E7, 0x03E2,
0x83A3, 0x03A6, 0x03AC, 0x83A9, 0x03B8, 0x83BD, 0x83B7, 0x03B2,
0x0390, 0x8395, 0x839F, 0x039A, 0x838B, 0x038E, 0x0384, 0x8381,
0x0280, 0x8285, 0x828F, 0x028A, 0x829B, 0x029E, 0x0294, 0x8291,
0x82B3, 0x02B6, 0x02BC, 0x82B9, 0x02A8, 0x82AD, 0x82A7, 0x02A2,
0x82E3, 0x02E6, 0x02EC, 0x82E9, 0x02F8, 0x82FD, 0x82F7, 0x02F2,
0x02D0, 0x82D5, 0x82DF, 0x02DA, 0x82CB, 0x02CE, 0x02C4, 0x82C1,
0x8243, 0x0246, 0x024C, 0x8249, 0x0258, 0x825D, 0x8257, 0x0252,
0x0270, 0x8275, 0x827F, 0x027A, 0x826B, 0x026E, 0x0264, 0x8261,
0x0220, 0x8225, 0x822F, 0x022A, 0x823B, 0x023E, 0x0234, 0x8231,
0x8213, 0x0216, 0x021C, 0x8219, 0x0208, 0x820D, 0x8207, 0x0202
] # fmt: skip
class MockDynamixelPacketv2(abc.ABC):
@classmethod
def build(cls, dxl_id: int, params: list[int], length: int, *args, **kwargs) -> bytes:
packet = cls._build(dxl_id, params, length, *args, **kwargs)
packet = cls._add_stuffing(packet)
packet = cls._add_crc(packet)
return bytes(packet)
    @classmethod
    @abc.abstractmethod
def _build(cls, dxl_id: int, params: list[int], length: int, *args, **kwargs) -> list[int]:
pass
@staticmethod
def _add_stuffing(packet: list[int]) -> list[int]:
"""
Byte stuffing is a method of adding additional data to generated instruction packets to ensure that
the packets are processed successfully. When the byte pattern "0xFF 0xFF 0xFD" appears in a packet,
byte stuffing adds 0xFD to the end of the pattern to convert it to 0xFF 0xFF 0xFD 0xFD to ensure
that it is not interpreted as the header at the start of another packet.
Source: https://emanual.robotis.com/docs/en/dxl/protocol2/#transmission-process
Args:
packet (list[int]): The raw packet without stuffing.
Returns:
list[int]: The packet stuffed if it contained a "0xFF 0xFF 0xFD" byte sequence in its data bytes.
"""
packet_length_in = dxl.DXL_MAKEWORD(packet[dxl.PKT_LENGTH_L], packet[dxl.PKT_LENGTH_H])
packet_length_out = packet_length_in
temp = [0] * dxl.TXPACKET_MAX_LEN
# FF FF FD XX ID LEN_L LEN_H
temp[dxl.PKT_HEADER0 : dxl.PKT_HEADER0 + dxl.PKT_LENGTH_H + 1] = packet[
dxl.PKT_HEADER0 : dxl.PKT_HEADER0 + dxl.PKT_LENGTH_H + 1
]
index = dxl.PKT_INSTRUCTION
for i in range(0, packet_length_in - 2): # except CRC
temp[index] = packet[i + dxl.PKT_INSTRUCTION]
index = index + 1
if (
packet[i + dxl.PKT_INSTRUCTION] == 0xFD
and packet[i + dxl.PKT_INSTRUCTION - 1] == 0xFF
and packet[i + dxl.PKT_INSTRUCTION - 2] == 0xFF
):
# FF FF FD
temp[index] = 0xFD
index = index + 1
packet_length_out = packet_length_out + 1
temp[index] = packet[dxl.PKT_INSTRUCTION + packet_length_in - 2]
temp[index + 1] = packet[dxl.PKT_INSTRUCTION + packet_length_in - 1]
index = index + 2
if packet_length_in != packet_length_out:
packet = [0] * index
packet[0:index] = temp[0:index]
packet[dxl.PKT_LENGTH_L] = dxl.DXL_LOBYTE(packet_length_out)
packet[dxl.PKT_LENGTH_H] = dxl.DXL_HIBYTE(packet_length_out)
return packet
@staticmethod
def _add_crc(packet: list[int]) -> list[int]:
"""Computes and add CRC to the packet.
https://emanual.robotis.com/docs/en/dxl/crc/
https://en.wikipedia.org/wiki/Cyclic_redundancy_check
Args:
packet (list[int]): The raw packet without CRC (but with placeholders for it).
Returns:
list[int]: The raw packet with a valid CRC.
"""
crc = 0
for j in range(len(packet) - 2):
i = ((crc >> 8) ^ packet[j]) & 0xFF
crc = ((crc << 8) ^ DXL_CRC_TABLE[i]) & 0xFFFF
packet[-2] = dxl.DXL_LOBYTE(crc)
packet[-1] = dxl.DXL_HIBYTE(crc)
return packet
class MockInstructionPacket(MockDynamixelPacketv2):
"""
Helper class to build valid Dynamixel Protocol 2.0 Instruction Packets.
Protocol 2.0 Instruction Packet structure
https://emanual.robotis.com/docs/en/dxl/protocol2/#instruction-packet
| Header | Packet ID | Length | Instruction | Params | CRC |
| ------------------- | --------- | ----------- | ----------- | ----------------- | ----------- |
| 0xFF 0xFF 0xFD 0x00 | ID | Len_L Len_H | Instr | Param 1 Param N | CRC_L CRC_H |
"""
@classmethod
def _build(cls, dxl_id: int, params: list[int], length: int, instruction: int) -> list[int]:
length = len(params) + 3
return [
0xFF, 0xFF, 0xFD, 0x00, # header
dxl_id, # servo id
dxl.DXL_LOBYTE(length), # length_l
dxl.DXL_HIBYTE(length), # length_h
instruction, # instruction type
*params, # data bytes
0x00, 0x00 # placeholder for CRC
] # fmt: skip
@classmethod
def ping(
cls,
dxl_id: int,
) -> bytes:
"""
Builds a "Ping" broadcast instruction.
https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01
No parameters required.
"""
return cls.build(dxl_id=dxl_id, params=[], length=3, instruction=dxl.INST_PING)
@classmethod
def read(
cls,
dxl_id: int,
start_address: int,
data_length: int,
) -> bytes:
"""
Builds a "Read" instruction.
https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02
The parameters for Read (Protocol 2.0) are:
param[0] = start_address L
param[1] = start_address H
param[2] = data_length L
param[3] = data_length H
        And 'length' = 7, where:
            +4 is for the 4 parameter bytes,
            +1 is for the instruction byte,
            +2 is for the CRC at the end.
"""
params = [
dxl.DXL_LOBYTE(start_address),
dxl.DXL_HIBYTE(start_address),
dxl.DXL_LOBYTE(data_length),
dxl.DXL_HIBYTE(data_length),
]
        length = len(params) + 3  # = 7: 4 param bytes + instruction byte + 2 CRC bytes
return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_READ)
@classmethod
def write(
cls,
dxl_id: int,
value: int,
start_address: int,
data_length: int,
) -> bytes:
"""
Builds a "Write" instruction.
https://emanual.robotis.com/docs/en/dxl/protocol2/#write-0x03
The parameters for Write (Protocol 2.0) are:
param[0] = start_address L
param[1] = start_address H
param[2] = 1st Byte
param[3] = 2nd Byte
...
param[1+X] = X-th Byte
        And 'length' = data_length + 5, where:
            +2 is for the address bytes,
            +1 is for the instruction byte,
            +2 is for the CRC at the end.
"""
data = _split_into_byte_chunks(value, data_length)
params = [
dxl.DXL_LOBYTE(start_address),
dxl.DXL_HIBYTE(start_address),
*data,
]
length = data_length + 5
return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_WRITE)
@classmethod
def sync_read(
cls,
dxl_ids: list[int],
start_address: int,
data_length: int,
) -> bytes:
"""
Builds a "Sync_Read" broadcast instruction.
https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-read-0x82
The parameters for Sync_Read (Protocol 2.0) are:
param[0] = start_address L
param[1] = start_address H
param[2] = data_length L
param[3] = data_length H
param[4+] = motor IDs to read from
        And 'length' = number_of_ids + 7, where:
            +2 is for the address bytes,
            +2 is for the data-length bytes,
            +1 is for the instruction byte,
            +2 is for the CRC at the end.
"""
params = [
dxl.DXL_LOBYTE(start_address),
dxl.DXL_HIBYTE(start_address),
dxl.DXL_LOBYTE(data_length),
dxl.DXL_HIBYTE(data_length),
*dxl_ids,
]
length = len(dxl_ids) + 7
return cls.build(
dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_READ
)
@classmethod
def sync_write(
cls,
ids_values: dict[int, int],
start_address: int,
data_length: int,
) -> bytes:
"""
Builds a "Sync_Write" broadcast instruction.
https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-write-0x83
The parameters for Sync_Write (Protocol 2.0) are:
param[0] = start_address L
param[1] = start_address H
param[2] = data_length L
param[3] = data_length H
            param[4]     = [1st motor] ID
            param[4+1]   = [1st motor] 1st Byte
            param[4+2]   = [1st motor] 2nd Byte
            ...
            param[4+X]   = [1st motor] X-th Byte
            param[4+X+1] = [2nd motor] ID
            param[4+X+2] = [2nd motor] 1st Byte
            ...
        And 'length' = number_of_ids * (1 + data_length) + 7, where:
            +2 is for the address bytes,
            +2 is for the data-length bytes,
            +1 is for the instruction byte,
            +2 is for the CRC at the end.
"""
data = []
for id_, value in ids_values.items():
split_value = _split_into_byte_chunks(value, data_length)
data += [id_, *split_value]
params = [
dxl.DXL_LOBYTE(start_address),
dxl.DXL_HIBYTE(start_address),
dxl.DXL_LOBYTE(data_length),
dxl.DXL_HIBYTE(data_length),
*data,
]
length = len(ids_values) * (1 + data_length) + 7
return cls.build(
dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_WRITE
)
class MockStatusPacket(MockDynamixelPacketv2):
"""
Helper class to build valid Dynamixel Protocol 2.0 Status Packets.
Protocol 2.0 Status Packet structure
https://emanual.robotis.com/docs/en/dxl/protocol2/#status-packet
| Header | Packet ID | Length | Instruction | Error | Params | CRC |
| ------------------- | --------- | ----------- | ----------- | ----- | ----------------- | ----------- |
| 0xFF 0xFF 0xFD 0x00 | ID | Len_L Len_H | 0x55 | Err | Param 1 Param N | CRC_L CRC_H |
"""
@classmethod
def _build(cls, dxl_id: int, params: list[int], length: int, error: int = 0) -> list[int]:
return [
0xFF, 0xFF, 0xFD, 0x00, # header
dxl_id, # servo id
dxl.DXL_LOBYTE(length), # length_l
dxl.DXL_HIBYTE(length), # length_h
0x55, # instruction = 'status'
error, # error
*params, # data bytes
0x00, 0x00 # placeholder for CRC
] # fmt: skip
@classmethod
def ping(cls, dxl_id: int, model_nb: int = 1190, firm_ver: int = 50, error: int = 0) -> bytes:
"""
Builds a 'Ping' status packet.
https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01
Args:
dxl_id (int): ID of the servo responding.
model_nb (int, optional): Desired 'model number' to be returned in the packet. Defaults to 1190
which corresponds to a XL330-M077-T.
firm_ver (int, optional): Desired 'firmware version' to be returned in the packet.
Defaults to 50.
Returns:
bytes: The raw 'Ping' status packet ready to be sent through serial.
"""
params = [dxl.DXL_LOBYTE(model_nb), dxl.DXL_HIBYTE(model_nb), firm_ver]
length = 7
return cls.build(dxl_id, params=params, length=length, error=error)
@classmethod
def read(cls, dxl_id: int, value: int, param_length: int, error: int = 0) -> bytes:
"""
Builds a 'Read' status packet (also works for 'Sync Read')
https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02
https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-read-0x82
Args:
dxl_id (int): ID of the servo responding.
value (int): Desired value to be returned in the packet.
param_length (int): The address length as reported in the control table.
Returns:
            bytes: The raw 'Read' status packet ready to be sent through serial.
"""
params = _split_into_byte_chunks(value, param_length)
length = param_length + 4
return cls.build(dxl_id, params=params, length=length, error=error)
class MockPortHandler(dxl.PortHandler):
"""
    This class overrides the 'setupPort' method of the Dynamixel PortHandler because the original can
    request baudrates that are not supported by serial ports on macOS.
"""
def setupPort(self, cflag_baud): # noqa: N802
if self.is_open:
self.closePort()
self.ser = serial.Serial(
port=self.port_name,
# baudrate=self.baudrate, <- This will fail on MacOS
# parity = serial.PARITY_ODD,
# stopbits = serial.STOPBITS_TWO,
bytesize=serial.EIGHTBITS,
timeout=0,
)
self.is_open = True
self.ser.reset_input_buffer()
self.tx_time_per_byte = (1000.0 / self.baudrate) * 10.0
return True
class MockMotors(MockSerial):
"""
    This class simulates physical motors by responding with valid status packets upon receiving the
    instruction packets it has stubs for. It is meant for testing MotorsBus classes.
"""
def __init__(self):
super().__init__()
@property
def stubs(self) -> dict[str, WaitableStub]:
return super().stubs
def stub(self, *, name=None, **kwargs):
new_stub = WaitableStub(**kwargs)
self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub
return new_stub
def build_broadcast_ping_stub(
        self, ids_models: dict[int, int], num_invalid_try: int = 0
) -> str:
ping_request = MockInstructionPacket.ping(dxl.BROADCAST_ID)
return_packets = b"".join(MockStatusPacket.ping(id_, model) for id_, model in ids_models.items())
ping_response = self._build_send_fn(return_packets, num_invalid_try)
stub_name = "Ping_" + "_".join([str(id_) for id_ in ids_models])
self.stub(
name=stub_name,
receive_bytes=ping_request,
send_fn=ping_response,
)
return stub_name
def build_ping_stub(
self, dxl_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0, error: int = 0
) -> str:
ping_request = MockInstructionPacket.ping(dxl_id)
return_packet = MockStatusPacket.ping(dxl_id, model_nb, firm_ver, error)
ping_response = self._build_send_fn(return_packet, num_invalid_try)
stub_name = f"Ping_{dxl_id}"
self.stub(
name=stub_name,
receive_bytes=ping_request,
send_fn=ping_response,
)
return stub_name
def build_read_stub(
self,
address: int,
length: int,
dxl_id: int,
value: int,
reply: bool = True,
error: int = 0,
num_invalid_try: int = 0,
) -> str:
read_request = MockInstructionPacket.read(dxl_id, address, length)
return_packet = MockStatusPacket.read(dxl_id, value, length, error) if reply else b""
read_response = self._build_send_fn(return_packet, num_invalid_try)
stub_name = f"Read_{address}_{length}_{dxl_id}_{value}_{error}"
self.stub(
name=stub_name,
receive_bytes=read_request,
send_fn=read_response,
)
return stub_name
def build_write_stub(
self,
address: int,
length: int,
dxl_id: int,
value: int,
reply: bool = True,
error: int = 0,
num_invalid_try: int = 0,
) -> str:
        write_request = MockInstructionPacket.write(dxl_id, value, address, length)
return_packet = MockStatusPacket.build(dxl_id, params=[], length=4, error=error) if reply else b""
stub_name = f"Write_{address}_{length}_{dxl_id}"
self.stub(
name=stub_name,
            receive_bytes=write_request,
send_fn=self._build_send_fn(return_packet, num_invalid_try),
)
return stub_name
def build_sync_read_stub(
self,
address: int,
length: int,
ids_values: dict[int, int],
reply: bool = True,
num_invalid_try: int = 0,
) -> str:
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
return_packets = (
b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
if reply
else b""
)
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=sync_read_response,
)
return stub_name
def build_sequential_sync_read_stub(
        self, address: int, length: int, ids_values: dict[int, list[int]]
) -> str:
sequence_length = len(next(iter(ids_values.values())))
assert all(len(positions) == sequence_length for positions in ids_values.values())
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
sequential_packets = []
for count in range(sequence_length):
return_packets = b"".join(
MockStatusPacket.read(id_, positions[count], length) for id_, positions in ids_values.items()
)
sequential_packets.append(return_packets)
sync_read_response = self._build_sequential_send_fn(sequential_packets)
stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=sync_read_response,
)
return stub_name
def build_sync_write_stub(
self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0
) -> str:
        sync_write_request = MockInstructionPacket.sync_write(ids_values, address, length)
stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
            receive_bytes=sync_write_request,
send_fn=self._build_send_fn(b"", num_invalid_try),
)
return stub_name
@staticmethod
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
def send_fn(_call_count: int) -> bytes:
if num_invalid_try >= _call_count:
return b""
return packet
return send_fn
@staticmethod
def _build_sequential_send_fn(packets: list[bytes]) -> Callable[[int], bytes]:
def send_fn(_call_count: int) -> bytes:
return packets[_call_count - 1]
return send_fn
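As a sanity check, the packet builder above reproduces the worked Ping example from the Robotis e-manual; a minimal sketch (expected bytes taken from that documentation):

if __name__ == "__main__":
    # Ping instruction to ID 1: header, ID, LEN_L, LEN_H, INST, CRC_L, CRC_H
    expected = bytes([0xFF, 0xFF, 0xFD, 0x00, 0x01, 0x03, 0x00, 0x01, 0x19, 0x4E])
    assert MockInstructionPacket.ping(1) == expected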


@@ -1,444 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from collections.abc import Callable
import scservo_sdk as scs
import serial
from mock_serial import MockSerial
from lerobot.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout
from .mock_serial_patch import WaitableStub
class MockFeetechPacket(abc.ABC):
@classmethod
def build(cls, scs_id: int, params: list[int], length: int, *args, **kwargs) -> bytes:
packet = cls._build(scs_id, params, length, *args, **kwargs)
packet = cls._add_checksum(packet)
return bytes(packet)
    @classmethod
    @abc.abstractmethod
def _build(cls, scs_id: int, params: list[int], length: int, *args, **kwargs) -> list[int]:
pass
@staticmethod
def _add_checksum(packet: list[int]) -> list[int]:
checksum = 0
for id_ in range(2, len(packet) - 1): # except header & checksum
checksum += packet[id_]
packet[-1] = ~checksum & 0xFF
return packet
class MockInstructionPacket(MockFeetechPacket):
"""
Helper class to build valid Feetech Instruction Packets.
Instruction Packet structure
(from https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf)
| Header | Packet ID | Length | Instruction | Params | Checksum |
| --------- | --------- | ------ | ----------- | ----------------- | -------- |
| 0xFF 0xFF | ID | Len | Instr | Param 1 Param N | Sum |
"""
@classmethod
def _build(cls, scs_id: int, params: list[int], length: int, instruction: int) -> list[int]:
return [
0xFF, 0xFF, # header
scs_id, # servo id
length, # length
instruction, # instruction type
*params, # data bytes
0x00, # placeholder for checksum
] # fmt: skip
@classmethod
def ping(
cls,
scs_id: int,
) -> bytes:
"""
Builds a "Ping" broadcast instruction.
No parameters required.
"""
return cls.build(scs_id=scs_id, params=[], length=2, instruction=scs.INST_PING)
@classmethod
def read(
cls,
scs_id: int,
start_address: int,
data_length: int,
) -> bytes:
"""
Builds a "Read" instruction.
The parameters for Read are:
param[0] = start_address
param[1] = data_length
And 'length' = 4, where:
+1 is for instruction byte,
+1 is for the address byte,
+1 is for the length bytes,
+1 is for the checksum at the end.
"""
params = [start_address, data_length]
length = 4
return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_READ)
@classmethod
def write(
cls,
scs_id: int,
value: int,
start_address: int,
data_length: int,
) -> bytes:
"""
Builds a "Write" instruction.
The parameters for Write are:
            param[0] = start_address
            param[1] = 1st Byte
            param[2] = 2nd Byte
            ...
            param[X] = X-th Byte
        And 'length' = data_length + 3, where:
            +1 is for the instruction byte,
            +1 is for the address byte,
            +1 is for the checksum at the end.
"""
data = _split_into_byte_chunks(value, data_length)
params = [start_address, *data]
length = data_length + 3
return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_WRITE)
@classmethod
def sync_read(
cls,
scs_ids: list[int],
start_address: int,
data_length: int,
) -> bytes:
"""
Builds a "Sync_Read" broadcast instruction.
The parameters for Sync Read are:
param[0] = start_address
param[1] = data_length
param[2+] = motor IDs to read from
        And 'length' = number_of_ids + 4, where:
+1 is for instruction byte,
+1 is for the address byte,
+1 is for the length bytes,
+1 is for the checksum at the end.
"""
params = [start_address, data_length, *scs_ids]
length = len(scs_ids) + 4
return cls.build(
scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_READ
)
@classmethod
def sync_write(
cls,
ids_values: dict[int, int],
start_address: int,
data_length: int,
) -> bytes:
"""
Builds a "Sync_Write" broadcast instruction.
The parameters for Sync_Write are:
param[0] = start_address
param[1] = data_length
param[2] = [1st motor] ID
param[2+1] = [1st motor] 1st Byte
param[2+2] = [1st motor] 2nd Byte
...
            param[2+X]   = [1st motor] X-th Byte
            param[2+X+1] = [2nd motor] ID
            param[2+X+2] = [2nd motor] 1st Byte
            ...
        And 'length' = number_of_ids * (1 + data_length) + 4, where:
            +1 is for the instruction byte,
            +1 is for the address byte,
            +1 is for the data-length byte,
            +1 is for the checksum at the end.
"""
data = []
for id_, value in ids_values.items():
split_value = _split_into_byte_chunks(value, data_length)
data += [id_, *split_value]
params = [start_address, data_length, *data]
length = len(ids_values) * (1 + data_length) + 4
return cls.build(
scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_WRITE
)
class MockStatusPacket(MockFeetechPacket):
"""
Helper class to build valid Feetech Status Packets.
Status Packet structure
(from https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf)
| Header | Packet ID | Length | Error | Params | Checksum |
| --------- | --------- | ------ | ----- | ----------------- | -------- |
| 0xFF 0xFF | ID | Len | Err | Param 1 Param N | Sum |
"""
@classmethod
def _build(cls, scs_id: int, params: list[int], length: int, error: int = 0) -> list[int]:
return [
0xFF, 0xFF, # header
scs_id, # servo id
length, # length
error, # status
*params, # data bytes
0x00, # placeholder for checksum
] # fmt: skip
@classmethod
def ping(cls, scs_id: int, error: int = 0) -> bytes:
"""Builds a 'Ping' status packet.
Args:
scs_id (int): ID of the servo responding.
error (int, optional): Error to be returned. Defaults to 0 (success).
Returns:
bytes: The raw 'Ping' status packet ready to be sent through serial.
"""
return cls.build(scs_id, params=[], length=2, error=error)
@classmethod
def read(cls, scs_id: int, value: int, param_length: int, error: int = 0) -> bytes:
"""Builds a 'Read' status packet.
Args:
scs_id (int): ID of the servo responding.
value (int): Desired value to be returned in the packet.
param_length (int): The address length as reported in the control table.
Returns:
            bytes: The raw 'Read' status packet ready to be sent through serial.
"""
params = _split_into_byte_chunks(value, param_length)
length = param_length + 2
return cls.build(scs_id, params=params, length=length, error=error)
class MockPortHandler(scs.PortHandler):
"""
    This class overrides the 'setupPort' method of the Feetech PortHandler because the original can
    request baudrates that are not supported by serial ports on macOS.
"""
def setupPort(self, cflag_baud): # noqa: N802
if self.is_open:
self.closePort()
self.ser = serial.Serial(
port=self.port_name,
# baudrate=self.baudrate, <- This will fail on MacOS
# parity = serial.PARITY_ODD,
# stopbits = serial.STOPBITS_TWO,
bytesize=serial.EIGHTBITS,
timeout=0,
)
self.is_open = True
self.ser.reset_input_buffer()
self.tx_time_per_byte = (1000.0 / self.baudrate) * 10.0
return True
def setPacketTimeout(self, packet_length): # noqa: N802
return patch_setPacketTimeout(self, packet_length)
class MockMotors(MockSerial):
"""
    This class simulates physical motors by responding with valid status packets upon receiving the
    instruction packets it has stubs for. It is meant for testing MotorsBus classes.
"""
def __init__(self):
super().__init__()
@property
def stubs(self) -> dict[str, WaitableStub]:
return super().stubs
def stub(self, *, name=None, **kwargs):
new_stub = WaitableStub(**kwargs)
self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub
return new_stub
    def build_broadcast_ping_stub(self, ids: list[int], num_invalid_try: int = 0) -> str:
ping_request = MockInstructionPacket.ping(scs.BROADCAST_ID)
return_packets = b"".join(MockStatusPacket.ping(id_) for id_ in ids)
ping_response = self._build_send_fn(return_packets, num_invalid_try)
stub_name = "Ping_" + "_".join([str(id_) for id_ in ids])
self.stub(
name=stub_name,
receive_bytes=ping_request,
send_fn=ping_response,
)
return stub_name
def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0, error: int = 0) -> str:
ping_request = MockInstructionPacket.ping(scs_id)
return_packet = MockStatusPacket.ping(scs_id, error)
ping_response = self._build_send_fn(return_packet, num_invalid_try)
stub_name = f"Ping_{scs_id}_{error}"
self.stub(
name=stub_name,
receive_bytes=ping_request,
send_fn=ping_response,
)
return stub_name
def build_read_stub(
self,
address: int,
length: int,
scs_id: int,
value: int,
reply: bool = True,
error: int = 0,
num_invalid_try: int = 0,
) -> str:
read_request = MockInstructionPacket.read(scs_id, address, length)
return_packet = MockStatusPacket.read(scs_id, value, length, error) if reply else b""
read_response = self._build_send_fn(return_packet, num_invalid_try)
stub_name = f"Read_{address}_{length}_{scs_id}_{value}_{error}"
self.stub(
name=stub_name,
receive_bytes=read_request,
send_fn=read_response,
)
return stub_name
def build_write_stub(
self,
address: int,
length: int,
scs_id: int,
value: int,
reply: bool = True,
error: int = 0,
num_invalid_try: int = 0,
) -> str:
        write_request = MockInstructionPacket.write(scs_id, value, address, length)
return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error) if reply else b""
stub_name = f"Write_{address}_{length}_{scs_id}"
self.stub(
name=stub_name,
            receive_bytes=write_request,
send_fn=self._build_send_fn(return_packet, num_invalid_try),
)
return stub_name
def build_sync_read_stub(
self,
address: int,
length: int,
ids_values: dict[int, int],
reply: bool = True,
num_invalid_try: int = 0,
) -> str:
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
return_packets = (
b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
if reply
else b""
)
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=sync_read_response,
)
return stub_name
def build_sequential_sync_read_stub(
        self, address: int, length: int, ids_values: dict[int, list[int]]
) -> str:
sequence_length = len(next(iter(ids_values.values())))
assert all(len(positions) == sequence_length for positions in ids_values.values())
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
sequential_packets = []
for count in range(sequence_length):
return_packets = b"".join(
MockStatusPacket.read(id_, positions[count], length) for id_, positions in ids_values.items()
)
sequential_packets.append(return_packets)
sync_read_response = self._build_sequential_send_fn(sequential_packets)
stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
receive_bytes=sync_read_request,
send_fn=sync_read_response,
)
return stub_name
def build_sync_write_stub(
self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0
) -> str:
        sync_write_request = MockInstructionPacket.sync_write(ids_values, address, length)
stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
self.stub(
name=stub_name,
            receive_bytes=sync_write_request,
send_fn=self._build_send_fn(b"", num_invalid_try),
)
return stub_name
@staticmethod
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
def send_fn(_call_count: int) -> bytes:
if num_invalid_try >= _call_count:
return b""
return packet
return send_fn
@staticmethod
def _build_sequential_send_fn(packets: list[bytes]) -> Callable[[int], bytes]:
def send_fn(_call_count: int) -> bytes:
return packets[_call_count - 1]
return send_fn
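The checksum rule in _add_checksum can be verified by hand; a minimal sketch, assuming scs.INST_PING == 0x01:

if __name__ == "__main__":
    # Ping to ID 1: FF FF 01 02 01 -> checksum = ~(0x01 + 0x02 + 0x01) & 0xFF = 0xFB
    assert MockInstructionPacket.ping(1) == bytes([0xFF, 0xFF, 0x01, 0x02, 0x01, 0xFB])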


@@ -1,152 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: N802
from lerobot.motors.motors_bus import (
Motor,
MotorsBus,
)
DUMMY_CTRL_TABLE_1 = {
"Firmware_Version": (0, 1),
"Model_Number": (1, 2),
"Present_Position": (3, 4),
"Goal_Position": (11, 2),
}
DUMMY_CTRL_TABLE_2 = {
"Model_Number": (0, 2),
"Firmware_Version": (2, 1),
"Present_Position": (3, 4),
"Present_Velocity": (7, 4),
"Goal_Position": (11, 4),
"Goal_Velocity": (15, 4),
"Lock": (19, 1),
}
DUMMY_MODEL_CTRL_TABLE = {
"model_1": DUMMY_CTRL_TABLE_1,
"model_2": DUMMY_CTRL_TABLE_2,
"model_3": DUMMY_CTRL_TABLE_2,
}
DUMMY_BAUDRATE_TABLE = {
0: 1_000_000,
1: 500_000,
2: 250_000,
}
DUMMY_MODEL_BAUDRATE_TABLE = {
"model_1": DUMMY_BAUDRATE_TABLE,
"model_2": DUMMY_BAUDRATE_TABLE,
"model_3": DUMMY_BAUDRATE_TABLE,
}
DUMMY_ENCODING_TABLE = {
"Present_Position": 8,
"Goal_Position": 10,
}
DUMMY_MODEL_ENCODING_TABLE = {
"model_1": DUMMY_ENCODING_TABLE,
"model_2": DUMMY_ENCODING_TABLE,
"model_3": DUMMY_ENCODING_TABLE,
}
DUMMY_MODEL_NUMBER_TABLE = {
"model_1": 1234,
"model_2": 5678,
"model_3": 5799,
}
DUMMY_MODEL_RESOLUTION_TABLE = {
"model_1": 4096,
"model_2": 1024,
"model_3": 4096,
}
class MockPortHandler:
def __init__(self, port_name):
self.is_open: bool = False
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool = False
self.port_name: str = port_name
self.ser = None
def openPort(self):
self.is_open = True
return self.is_open
def closePort(self):
self.is_open = False
def clearPort(self): ...
def setPortName(self, port_name):
self.port_name = port_name
def getPortName(self):
return self.port_name
def setBaudRate(self, baudrate):
        self.baudrate = baudrate
def getBaudRate(self):
return self.baudrate
def getBytesAvailable(self): ...
def readPort(self, length): ...
def writePort(self, packet): ...
def setPacketTimeout(self, packet_length): ...
def setPacketTimeoutMillis(self, msec): ...
def isPacketTimeout(self): ...
def getCurrentTime(self): ...
def getTimeSinceStart(self): ...
def setupPort(self, cflag_baud): ...
def getCFlagBaud(self, baudrate): ...
class MockMotorsBus(MotorsBus):
available_baudrates = [500_000, 1_000_000]
default_timeout = 1000
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
model_number_table = DUMMY_MODEL_NUMBER_TABLE
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
normalized_data = ["Present_Position", "Goal_Position"]
def __init__(self, port: str, motors: dict[str, Motor]):
super().__init__(port, motors)
self.port_handler = MockPortHandler(port)
def _assert_protocol_is_compatible(self, instruction_name): ...
def _handshake(self): ...
def _find_single_motor(self, motor, initial_baudrate): ...
def configure_motors(self): ...
def is_calibrated(self): ...
def read_calibration(self): ...
def write_calibration(self, calibration_dict): ...
def disable_torque(self, motors, num_retry): ...
def _disable_torque(self, motor, model, num_retry): ...
def enable_torque(self, motors, num_retry): ...
def _get_half_turn_homings(self, positions): ...
def _encode_sign(self, data_name, ids_values): ...
def _decode_sign(self, data_name, ids_values): ...
def _split_into_byte_chunks(self, value, length): ...
def broadcast_ping(self, num_retry, raise_on_error): ...
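A minimal usage sketch (hypothetical; only the dummy control tables above are involved, no real serial port):

if __name__ == "__main__":
    from lerobot.motors import Motor, MotorNormMode

    bus = MockMotorsBus("/dev/dummy-port", {"m1": Motor(1, "model_1", MotorNormMode.RANGE_M100_100)})
    assert bus.port_handler.getPortName() == "/dev/dummy-port"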


@@ -1,128 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
from lerobot.cameras import CameraConfig, make_cameras_from_configs
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.robots import Robot, RobotConfig
@RobotConfig.register_subclass("mock_robot")
@dataclass
class MockRobotConfig(RobotConfig):
n_motors: int = 3
cameras: dict[str, CameraConfig] = field(default_factory=dict)
random_values: bool = True
static_values: list[float] | None = None
calibrated: bool = True
def __post_init__(self):
        if self.n_motors < 1:
            raise ValueError(f"n_motors must be at least 1, got {self.n_motors}")
if self.random_values and self.static_values is not None:
raise ValueError("Choose either random values or static values")
if self.static_values is not None and len(self.static_values) != self.n_motors:
raise ValueError("Specify the same number of static values as motors")
if len(self.cameras) > 0:
raise NotImplementedError # TODO with the cameras refactor
class MockRobot(Robot):
"""Mock Robot to be used for testing."""
config_class = MockRobotConfig
name = "mock_robot"
def __init__(self, config: MockRobotConfig):
super().__init__(config)
self.config = config
self._is_connected = False
self._is_calibrated = config.calibrated
self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)]
self.cameras = make_cameras_from_configs(config.cameras)
@property
def _motors_ft(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.motors}
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self._is_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self._is_connected = True
if calibrate:
self.calibrate()
@property
def is_calibrated(self) -> bool:
return self._is_calibrated
def calibrate(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._is_calibrated = True
def configure(self) -> None:
pass
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.config.random_values:
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
else:
return {
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
}
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
return action
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._is_connected = False
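A usage sketch (hypothetical; assumes the remaining RobotConfig fields all have defaults):

if __name__ == "__main__":
    config = MockRobotConfig(n_motors=2, random_values=False, static_values=[1.0, 2.0])
    robot = MockRobot(config)
    robot.connect()
    # With static values, observations are fully deterministic.
    assert robot.get_observation() == {"motor_1.pos": 1.0, "motor_2.pos": 2.0}
    robot.disconnect()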


@@ -1,51 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from mock_serial.mock_serial import Stub
class WaitableStub(Stub):
"""
    In some situations, a test might check whether a stub has been called before the `MockSerial` thread
    has had time to read, match, and call it. In these situations, the test can fail randomly.
Use `wait_called()` or `wait_calls()` to block until the stub is called, avoiding race conditions.
Proposed fix:
https://github.com/benthorner/mock_serial/pull/3
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._event = threading.Event()
def call(self):
self._event.set()
return super().call()
def wait_called(self, timeout: float = 1.0):
return self._event.wait(timeout)
def wait_calls(self, min_calls: int = 1, timeout: float = 1.0):
start = time.perf_counter()
while time.perf_counter() - start < timeout:
if self.calls >= min_calls:
return self.calls
time.sleep(0.005)
raise TimeoutError(f"Stub not called {min_calls} times within {timeout} seconds.")
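A minimal sketch of the intended usage (hypothetical; assumes mock_serial's Stub accepts receive_bytes/send_bytes and counts calls):

if __name__ == "__main__":
    stub = WaitableStub(receive_bytes=b"\x01", send_bytes=b"\x02")
    stub.call()  # normally invoked by the MockSerial reader thread
    assert stub.wait_called(timeout=0.1)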


@@ -1,110 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from dataclasses import dataclass
from functools import cached_property
from typing import Any
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig
@TeleoperatorConfig.register_subclass("mock_teleop")
@dataclass
class MockTeleopConfig(TeleoperatorConfig):
n_motors: int = 3
random_values: bool = True
static_values: list[float] | None = None
calibrated: bool = True
def __post_init__(self):
        if self.n_motors < 1:
            raise ValueError(f"n_motors must be at least 1, got {self.n_motors}")
if self.random_values and self.static_values is not None:
raise ValueError("Choose either random values or static values")
if self.static_values is not None and len(self.static_values) != self.n_motors:
raise ValueError("Specify the same number of static values as motors")
class MockTeleop(Teleoperator):
"""Mock Teleoperator to be used for testing."""
config_class = MockTeleopConfig
name = "mock_teleop"
def __init__(self, config: MockTeleopConfig):
super().__init__(config)
self.config = config
self._is_connected = False
self._is_calibrated = config.calibrated
self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)]
@cached_property
def action_features(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.motors}
@cached_property
def feedback_features(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.motors}
@property
def is_connected(self) -> bool:
return self._is_connected
def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self._is_connected = True
if calibrate:
self.calibrate()
@property
def is_calibrated(self) -> bool:
return self._is_calibrated
def calibrate(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._is_calibrated = True
def configure(self) -> None:
pass
def get_action(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.config.random_values:
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
else:
return {
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
}
def send_feedback(self, feedback: dict[str, Any]) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._is_connected = False
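And a matching usage sketch for the teleoperator (hypothetical; mirrors the MockRobot example above):

if __name__ == "__main__":
    teleop = MockTeleop(MockTeleopConfig(n_motors=2, random_values=False, static_values=[0.5, -0.5]))
    teleop.connect()
    assert teleop.get_action() == {"motor_1.pos": 0.5, "motor_2.pos": -0.5}
    teleop.disconnect()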


@@ -1,416 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import sys
from collections.abc import Generator
from unittest.mock import MagicMock, patch
import pytest
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import MODEL_NUMBER_TABLE, DynamixelMotorsBus
from lerobot.motors.dynamixel.tables import X_SERIES_CONTROL_TABLE
from lerobot.utils.encoding_utils import encode_twos_complement
try:
import dynamixel_sdk as dxl
from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler
except (ImportError, ModuleNotFoundError):
pytest.skip("dynamixel_sdk not available", allow_module_level=True)
@pytest.fixture(autouse=True)
def patch_port_handler():
if sys.platform == "darwin":
with patch.object(dxl, "PortHandler", MockPortHandler):
yield
else:
yield
@pytest.fixture
def mock_motors() -> Generator[MockMotors, None, None]:
motors = MockMotors()
motors.open()
yield motors
motors.close()
@pytest.fixture
def dummy_motors() -> dict[str, Motor]:
return {
"dummy_1": Motor(1, "xl430-w250", MotorNormMode.RANGE_M100_100),
"dummy_2": Motor(2, "xm540-w270", MotorNormMode.RANGE_M100_100),
"dummy_3": Motor(3, "xl330-m077", MotorNormMode.RANGE_M100_100),
}
@pytest.fixture
def dummy_calibration(dummy_motors) -> dict[str, MotorCalibration]:
drive_modes = [0, 1, 0]
homings = [-709, -2006, 1624]
mins = [43, 27, 145]
maxes = [1335, 3608, 3999]
calibration = {}
for motor, m in dummy_motors.items():
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=drive_modes[m.id - 1],
homing_offset=homings[m.id - 1],
range_min=mins[m.id - 1],
range_max=maxes[m.id - 1],
)
return calibration
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
def test_autouse_patch():
"""Ensures that the autouse fixture correctly patches dxl.PortHandler with MockPortHandler."""
assert dxl.PortHandler is MockPortHandler
@pytest.mark.parametrize(
"value, length, expected",
[
(0x12, 1, [0x12]),
(0x1234, 2, [0x34, 0x12]),
(0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
],
ids=[
"1 byte",
"2 bytes",
"4 bytes",
],
) # fmt: skip
def test__split_into_byte_chunks(value, length, expected):
bus = DynamixelMotorsBus("", {})
assert bus._split_into_byte_chunks(value, length) == expected
def test_abc_implementation(dummy_motors):
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
DynamixelMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
@pytest.mark.parametrize("id_", [1, 2, 3])
def test_ping(id_, mock_motors, dummy_motors):
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
stub = mock_motors.build_ping_stub(id_, expected_model_nb)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
ping_model_nb = bus.ping(id_)
assert ping_model_nb == expected_model_nb
assert mock_motors.stubs[stub].called
def test_broadcast_ping(mock_motors, dummy_motors):
models = {m.id: m.model for m in dummy_motors.values()}
expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()}
stub = mock_motors.build_broadcast_ping_stub(expected_model_nbs)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
ping_model_nbs = bus.broadcast_ping()
assert ping_model_nbs == expected_model_nbs
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"addr, length, id_, value",
[
(0, 1, 1, 2),
(10, 2, 2, 999),
(42, 4, 3, 1337),
],
)
def test__read(addr, length, id_, value, mock_motors, dummy_motors):
stub = mock_motors.build_read_stub(addr, length, id_, value)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
read_value, _, _ = bus._read(addr, length, id_)
assert mock_motors.stubs[stub].called
assert read_value == value
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__read_error(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT)
stub = mock_motors.build_read_stub(addr, length, id_, value, error=error)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(
RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!")
):
bus._read(addr, length, id_, raise_on_error=raise_on_error)
else:
_, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error)
assert read_error == error
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__read_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value = (10, 4, 1, 1337)
stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
bus._read(addr, length, id_, raise_on_error=raise_on_error)
else:
_, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error)
assert read_comm == dxl.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"addr, length, id_, value",
[
(0, 1, 1, 2),
(10, 2, 2, 999),
(42, 4, 3, 1337),
],
)
def test__write(addr, length, id_, value, mock_motors, dummy_motors):
stub = mock_motors.build_write_stub(addr, length, id_, value)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
comm, error = bus._write(addr, length, id_, value)
assert mock_motors.stubs[stub].called
assert comm == dxl.COMM_SUCCESS
assert error == 0
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__write_error(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT)
stub = mock_motors.build_write_stub(addr, length, id_, value, error=error)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(
RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!")
):
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
else:
_, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
assert write_error == error
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__write_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value = (10, 4, 1, 1337)
stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
else:
write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
assert write_comm == dxl.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"addr, length, ids_values",
[
(0, 1, {1: 4}),
(10, 2, {1: 1337, 2: 42}),
(42, 4, {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
)
def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors):
stub = mock_motors.build_sync_read_stub(addr, length, ids_values)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
read_values, _ = bus._sync_read(addr, length, list(ids_values))
assert mock_motors.stubs[stub].called
assert read_values == ids_values
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, ids_values = (10, 4, {1: 1337})
stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
else:
_, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
assert read_comm == dxl.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"addr, length, ids_values",
[
(0, 1, {1: 4}),
(10, 2, {1: 1337, 2: 42}),
(42, 4, {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
)
def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors):
stub = mock_motors.build_sync_write_stub(addr, length, ids_values)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
comm = bus._sync_write(addr, length, ids_values)
assert mock_motors.stubs[stub].wait_called()
assert comm == dxl.COMM_SUCCESS
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
drive_modes = {m.id: m.drive_mode for m in dummy_calibration.values()}
encoded_homings = {m.id: encode_twos_complement(m.homing_offset, 4) for m in dummy_calibration.values()}
mins = {m.id: m.range_min for m in dummy_calibration.values()}
maxes = {m.id: m.range_max for m in dummy_calibration.values()}
drive_modes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Drive_Mode"], drive_modes)
offsets_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings)
mins_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins)
maxes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes)
bus = DynamixelMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
calibration=dummy_calibration,
)
bus.connect(handshake=False)
is_calibrated = bus.is_calibrated
assert is_calibrated
assert mock_motors.stubs[drive_modes_stub].called
assert mock_motors.stubs[offsets_stub].called
assert mock_motors.stubs[mins_stub].called
assert mock_motors.stubs[maxes_stub].called
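# The homing offsets above are read back through `encode_twos_complement(value, n_bytes)`.
# A minimal sketch of the assumed semantics (the real helper lives in
# lerobot.utils.encoding_utils):
def encode_twos_complement_sketch(value: int, n_bytes: int) -> int:
    # Map a signed integer onto its unsigned n-byte two's-complement register representation.
    return value & ((1 << (8 * n_bytes)) - 1)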
def test_reset_calibration(mock_motors, dummy_motors):
write_homing_stubs = []
write_mins_stubs = []
write_maxes_stubs = []
for motor in dummy_motors.values():
write_homing_stubs.append(
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0)
)
write_mins_stubs.append(
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0)
)
write_maxes_stubs.append(
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095)
)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
bus.reset_calibration()
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs)
assert all(mock_motors.stubs[stub].called for stub in write_maxes_stubs)
def test_set_half_turn_homings(mock_motors, dummy_motors):
"""
For this test, we assume that the homing offsets are already 0 such that
Present_Position == Actual_Position
"""
current_positions = {
1: 1337,
2: 42,
3: 3672,
}
expected_homings = {
1: 710, # 2047 - 1337
2: 2005, # 2047 - 42
3: -1625, # 2047 - 3672
}
read_pos_stub = mock_motors.build_sync_read_stub(
*X_SERIES_CONTROL_TABLE["Present_Position"], current_positions
)
write_homing_stubs = []
for id_, homing in expected_homings.items():
encoded_homing = encode_twos_complement(homing, 4)
stub = mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing)
write_homing_stubs.append(stub)
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
bus.reset_calibration = MagicMock()
bus.set_half_turn_homings()
bus.reset_calibration.assert_called_once()
assert mock_motors.stubs[read_pos_stub].called
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
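# The expected homings above come from centering each motor at half a turn. A sketch of the
# assumed arithmetic for a 4096-step encoder, whose mid-range is (4096 - 1) // 2 = 2047:
def half_turn_homing_sketch(present_position: int, resolution: int = 4096) -> int:
    return (resolution - 1) // 2 - present_position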
def test_record_ranges_of_motion(mock_motors, dummy_motors):
positions = {
1: [351, 42, 1337],
2: [28, 3600, 2444],
3: [4002, 2999, 146],
}
expected_mins = {
"dummy_1": 42,
"dummy_2": 28,
"dummy_3": 146,
}
expected_maxes = {
"dummy_1": 1337,
"dummy_2": 3600,
"dummy_3": 4002,
}
read_pos_stub = mock_motors.build_sequential_sync_read_stub(
*X_SERIES_CONTROL_TABLE["Present_Position"], positions
)
with patch("lerobot.motors.motors_bus.enter_pressed", side_effect=[False, True]):
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
mins, maxes = bus.record_ranges_of_motion(display_values=False)
assert mock_motors.stubs[read_pos_stub].calls == 3
assert mins == expected_mins
assert maxes == expected_maxes
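# A plausible sketch of the loop this test exercises: one read seeds the running min/max,
# then positions are re-read until Enter is pressed, which is consistent with the three
# stub calls asserted above (the callables and their signatures are assumptions):
def record_ranges_sketch(read_positions, enter_pressed):
    positions = read_positions()
    mins, maxes = dict(positions), dict(positions)
    while True:
        for motor, pos in read_positions().items():
            mins[motor] = min(mins[motor], pos)
            maxes[motor] = max(maxes[motor], pos)
        if enter_pressed():
            break
    return mins, maxes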

View File

@@ -1,459 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import sys
from collections.abc import Generator
from unittest.mock import MagicMock, patch
import pytest
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus
from lerobot.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE
from lerobot.utils.encoding_utils import encode_sign_magnitude
try:
import scservo_sdk as scs
from tests.mocks.mock_feetech import MockMotors, MockPortHandler
except (ImportError, ModuleNotFoundError):
pytest.skip("scservo_sdk not available", allow_module_level=True)
@pytest.fixture(autouse=True)
def patch_port_handler():
if sys.platform == "darwin":
with patch.object(scs, "PortHandler", MockPortHandler):
yield
else:
yield
@pytest.fixture
def mock_motors() -> Generator[MockMotors, None, None]:
motors = MockMotors()
motors.open()
yield motors
motors.close()
@pytest.fixture
def dummy_motors() -> dict[str, Motor]:
return {
"dummy_1": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100),
"dummy_2": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100),
"dummy_3": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100),
}
@pytest.fixture
def dummy_calibration(dummy_motors) -> dict[str, MotorCalibration]:
homings = [-709, -2006, 1624]
mins = [43, 27, 145]
maxes = [1335, 3608, 3999]
calibration = {}
for motor, m in dummy_motors.items():
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=homings[m.id - 1],
range_min=mins[m.id - 1],
range_max=maxes[m.id - 1],
)
return calibration
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
def test_autouse_patch():
"""Ensures that the autouse fixture correctly patches scs.PortHandler with MockPortHandler."""
assert scs.PortHandler is MockPortHandler
@pytest.mark.parametrize(
"protocol, value, length, expected",
[
(0, 0x12, 1, [0x12]),
(1, 0x12, 1, [0x12]),
(0, 0x1234, 2, [0x34, 0x12]),
(1, 0x1234, 2, [0x12, 0x34]),
(0, 0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
(1, 0x12345678, 4, [0x56, 0x78, 0x12, 0x34]),
],
ids=[
"P0: 1 byte",
"P1: 1 byte",
"P0: 2 bytes",
"P1: 2 bytes",
"P0: 4 bytes",
"P1: 4 bytes",
],
) # fmt: skip
def test__split_into_byte_chunks(protocol, value, length, expected):
bus = FeetechMotorsBus("", {}, protocol_version=protocol)
assert bus._split_into_byte_chunks(value, length) == expected
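# A hedged sketch matching the cases above: protocol 0 is plain little-endian, while
# protocol 1 sends the low 16-bit word first with each word big-endian (assumed helper,
# not the actual FeetechMotorsBus implementation):
def split_into_byte_chunks_sketch(value: int, length: int, protocol_version: int) -> list[int]:
    if protocol_version == 0:
        return [(value >> (8 * i)) & 0xFF for i in range(length)]
    if length == 1:
        return [value & 0xFF]
    if length == 2:
        return [(value >> 8) & 0xFF, value & 0xFF]
    low, high = value & 0xFFFF, (value >> 16) & 0xFFFF
    return [(low >> 8) & 0xFF, low & 0xFF, (high >> 8) & 0xFF, high & 0xFF]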
def test_abc_implementation(dummy_motors):
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
FeetechMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
@pytest.mark.parametrize("id_", [1, 2, 3])
def test_ping(id_, mock_motors, dummy_motors):
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
addr, length = MODEL_NUMBER
ping_stub = mock_motors.build_ping_stub(id_)
model_nb_stub = mock_motors.build_read_stub(addr, length, id_, expected_model_nb)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
bus.connect(handshake=False)
ping_model_nb = bus.ping(id_)
assert ping_model_nb == expected_model_nb
assert mock_motors.stubs[ping_stub].called
assert mock_motors.stubs[model_nb_stub].called
def test_broadcast_ping(mock_motors, dummy_motors):
models = {m.id: m.model for m in dummy_motors.values()}
addr, length = MODEL_NUMBER
ping_stub = mock_motors.build_broadcast_ping_stub(list(models))
model_nb_stubs = []
expected_model_nbs = {}
for id_, model in models.items():
model_nb = MODEL_NUMBER_TABLE[model]
stub = mock_motors.build_read_stub(addr, length, id_, model_nb)
expected_model_nbs[id_] = model_nb
model_nb_stubs.append(stub)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
bus.connect(handshake=False)
ping_model_nbs = bus.broadcast_ping()
assert ping_model_nbs == expected_model_nbs
assert mock_motors.stubs[ping_stub].called
assert all(mock_motors.stubs[stub].called for stub in model_nb_stubs)
@pytest.mark.parametrize(
"addr, length, id_, value",
[
(0, 1, 1, 2),
(10, 2, 2, 999),
(42, 4, 3, 1337),
],
)
def test__read(addr, length, id_, value, mock_motors, dummy_motors):
stub = mock_motors.build_read_stub(addr, length, id_, value)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
bus.connect(handshake=False)
read_value, _, _ = bus._read(addr, length, id_)
assert mock_motors.stubs[stub].called
assert read_value == value
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__read_error(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE)
stub = mock_motors.build_read_stub(addr, length, id_, value, error=error)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")):
bus._read(addr, length, id_, raise_on_error=raise_on_error)
else:
_, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error)
assert read_error == error
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__read_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value = (10, 4, 1, 1337)
stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
bus._read(addr, length, id_, raise_on_error=raise_on_error)
else:
_, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error)
assert read_comm == scs.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"addr, length, id_, value",
[
(0, 1, 1, 2),
(10, 2, 2, 999),
(42, 4, 3, 1337),
],
)
def test__write(addr, length, id_, value, mock_motors, dummy_motors):
stub = mock_motors.build_write_stub(addr, length, id_, value)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
)
bus.connect(handshake=False)
comm, error = bus._write(addr, length, id_, value)
assert mock_motors.stubs[stub].wait_called()
assert comm == scs.COMM_SUCCESS
assert error == 0
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__write_error(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE)
stub = mock_motors.build_write_stub(addr, length, id_, value, error=error)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")):
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
else:
_, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
assert write_error == error
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__write_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, id_, value = (10, 4, 1, 1337)
stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
else:
write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
assert write_comm == scs.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"addr, length, ids_values",
[
(0, 1, {1: 4}),
(10, 2, {1: 1337, 2: 42}),
(42, 4, {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
)
def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors):
stub = mock_motors.build_sync_read_stub(addr, length, ids_values)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
read_values, _ = bus._sync_read(addr, length, list(ids_values))
assert mock_motors.stubs[stub].called
assert read_values == ids_values
@pytest.mark.parametrize("raise_on_error", (True, False))
def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors):
addr, length, ids_values = (10, 4, {1: 1337})
stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
if raise_on_error:
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
else:
_, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
assert read_comm == scs.COMM_RX_TIMEOUT
assert mock_motors.stubs[stub].called
@pytest.mark.parametrize(
"addr, length, ids_values",
[
(0, 1, {1: 4}),
(10, 2, {1: 1337, 2: 42}),
(42, 4, {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
)
def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors):
stub = mock_motors.build_sync_write_stub(addr, length, ids_values)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
comm = bus._sync_write(addr, length, ids_values)
assert mock_motors.stubs[stub].wait_called()
assert comm == scs.COMM_SUCCESS
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
mins_stubs, maxes_stubs, homings_stubs = [], [], []
for cal in dummy_calibration.values():
mins_stubs.append(
mock_motors.build_read_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], cal.id, cal.range_min
)
)
maxes_stubs.append(
mock_motors.build_read_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], cal.id, cal.range_max
)
)
homings_stubs.append(
mock_motors.build_read_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"],
cal.id,
encode_sign_magnitude(cal.homing_offset, 11),
)
)
bus = FeetechMotorsBus(
port=mock_motors.port,
motors=dummy_motors,
calibration=dummy_calibration,
)
bus.connect(handshake=False)
is_calibrated = bus.is_calibrated
assert is_calibrated
assert all(mock_motors.stubs[stub].called for stub in mins_stubs)
assert all(mock_motors.stubs[stub].called for stub in maxes_stubs)
assert all(mock_motors.stubs[stub].called for stub in homings_stubs)
def test_reset_calibration(mock_motors, dummy_motors):
write_homing_stubs = []
write_mins_stubs = []
write_maxes_stubs = []
for motor in dummy_motors.values():
write_homing_stubs.append(
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0)
)
write_mins_stubs.append(
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0)
)
write_maxes_stubs.append(
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095)
)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
bus.reset_calibration()
assert all(mock_motors.stubs[stub].wait_called() for stub in write_homing_stubs)
assert all(mock_motors.stubs[stub].wait_called() for stub in write_mins_stubs)
assert all(mock_motors.stubs[stub].wait_called() for stub in write_maxes_stubs)
def test_set_half_turn_homings(mock_motors, dummy_motors):
"""
For this test, we assume that the homing offsets are already 0 such that
Present_Position == Actual_Position
"""
current_positions = {
1: 1337,
2: 42,
3: 3672,
}
expected_homings = {
1: -710, # 1337 - 2047
2: -2005, # 42 - 2047
3: 1625, # 3672 - 2047
}
read_pos_stub = mock_motors.build_sync_read_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], current_positions
)
write_homing_stubs = []
for id_, homing in expected_homings.items():
encoded_homing = encode_sign_magnitude(homing, 11)
stub = mock_motors.build_write_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing
)
write_homing_stubs.append(stub)
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
bus.reset_calibration = MagicMock()
bus.set_half_turn_homings()
bus.reset_calibration.assert_called_once()
assert mock_motors.stubs[read_pos_stub].called
assert all(mock_motors.stubs[stub].wait_called() for stub in write_homing_stubs)
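# Note the sign convention flips versus the Dynamixel test (present - 2047 instead of
# 2047 - present), and offsets are stored sign-magnitude. A minimal sketch of the assumed
# `encode_sign_magnitude(value, sign_bit_index)` semantics:
def encode_sign_magnitude_sketch(value: int, sign_bit_index: int) -> int:
    # Magnitude in the low bits, sign flag at `sign_bit_index`.
    magnitude = abs(value)
    return magnitude | (1 << sign_bit_index) if value < 0 else magnitude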
def test_record_ranges_of_motion(mock_motors, dummy_motors):
positions = {
1: [351, 42, 1337],
2: [28, 3600, 2444],
3: [4002, 2999, 146],
}
expected_mins = {
"dummy_1": 42,
"dummy_2": 28,
"dummy_3": 146,
}
expected_maxes = {
"dummy_1": 1337,
"dummy_2": 3600,
"dummy_3": 4002,
}
stub = mock_motors.build_sequential_sync_read_stub(
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], positions
)
with patch("lerobot.motors.motors_bus.enter_pressed", side_effect=[False, True]):
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
bus.connect(handshake=False)
mins, maxes = bus.record_ranges_of_motion(display_values=False)
assert mock_motors.stubs[stub].calls == 3
assert mins == expected_mins
assert maxes == expected_maxes

View File

@@ -1,358 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from unittest.mock import patch
import pytest
from lerobot.motors.motors_bus import (
Motor,
MotorNormMode,
assert_same_address,
get_address,
get_ctrl_table,
)
from tests.mocks.mock_motors_bus import (
DUMMY_CTRL_TABLE_1,
DUMMY_CTRL_TABLE_2,
DUMMY_MODEL_CTRL_TABLE,
MockMotorsBus,
)
@pytest.fixture
def dummy_motors() -> dict[str, Motor]:
return {
"dummy_1": Motor(1, "model_2", MotorNormMode.RANGE_M100_100),
"dummy_2": Motor(2, "model_3", MotorNormMode.RANGE_M100_100),
"dummy_3": Motor(3, "model_2", MotorNormMode.RANGE_0_100),
}
def test_get_ctrl_table():
model = "model_1"
ctrl_table = get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model)
assert ctrl_table == DUMMY_CTRL_TABLE_1
def test_get_ctrl_table_error():
model = "model_99"
with pytest.raises(KeyError, match=f"Control table for {model=} not found."):
get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model)
def test_get_address():
addr, n_bytes = get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", "Firmware_Version")
assert addr == 0
assert n_bytes == 1
def test_get_address_error():
model = "model_1"
data_name = "Lock"
with pytest.raises(KeyError, match=f"Address for '{data_name}' not found in {model} control table."):
get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", data_name)
def test_assert_same_address():
models = ["model_1", "model_2"]
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Present_Position")
def test_assert_same_length_different_addresses():
models = ["model_1", "model_2"]
with pytest.raises(
NotImplementedError,
match=re.escape("At least two motor models use a different address"),
):
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Model_Number")
def test_assert_same_address_different_length():
models = ["model_1", "model_2"]
with pytest.raises(
NotImplementedError,
match=re.escape("At least two motor models use a different bytes representation"),
):
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Goal_Position")
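# A hedged sketch of the contract these two tests pin down: every listed model must map
# `data_name` to the same (address, byte-length) pair, with address mismatches reported
# before length mismatches (assumed logic, not the actual implementation):
def assert_same_address_sketch(model_ctrl_table, models, data_name):
    entries = {model_ctrl_table[model][data_name] for model in models}
    if len({addr for addr, _ in entries}) > 1:
        raise NotImplementedError("At least two motor models use a different address")
    if len({n_bytes for _, n_bytes in entries}) > 1:
        raise NotImplementedError("At least two motor models use a different bytes representation")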
def test__serialize_data_invalid_length():
bus = MockMotorsBus("", {})
with pytest.raises(NotImplementedError):
bus._serialize_data(100, 3)
def test__serialize_data_negative_numbers():
bus = MockMotorsBus("", {})
with pytest.raises(ValueError):
bus._serialize_data(-1, 1)
def test__serialize_data_large_number():
bus = MockMotorsBus("", {})
with pytest.raises(ValueError):
bus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF
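# Taken together, the three tests above pin down the assumed contract of `_serialize_data`:
# only 1-, 2- and 4-byte widths are supported and the value must fit the unsigned range.
# A sketch under those assumptions:
def serialize_data_sketch(value: int, length: int) -> list[int]:
    if length not in (1, 2, 4):
        raise NotImplementedError(f"Unsupported byte length: {length}")
    if not 0 <= value < (1 << (8 * length)):
        raise ValueError(f"Value {value} does not fit in {length} unsigned byte(s)")
    return [(value >> (8 * i)) & 0xFF for i in range(length)]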
@pytest.mark.parametrize(
"data_name, id_, value",
[
("Firmware_Version", 1, 14),
("Model_Number", 1, 5678),
("Present_Position", 2, 1337),
("Present_Velocity", 3, 42),
],
)
def test_read(data_name, id_, value, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
with (
patch.object(MockMotorsBus, "_read", return_value=(value, 0, 0)) as mock__read,
patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign,
patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize,
):
returned_value = bus.read(data_name, f"dummy_{id_}")
assert returned_value == value
mock__read.assert_called_once_with(
addr,
length,
id_,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to read '{data_name}' on {id_=} after 1 tries.",
)
mock__decode_sign.assert_called_once_with(data_name, {id_: value})
if data_name in bus.normalized_data:
mock__normalize.assert_called_once_with({id_: value})
@pytest.mark.parametrize(
"data_name, id_, value",
[
("Goal_Position", 1, 1337),
("Goal_Velocity", 2, 3682),
("Lock", 3, 1),
],
)
def test_write(data_name, id_, value, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
with (
patch.object(MockMotorsBus, "_write", return_value=(0, 0)) as mock__write,
patch.object(MockMotorsBus, "_encode_sign", return_value={id_: value}) as mock__encode_sign,
patch.object(MockMotorsBus, "_unnormalize", return_value={id_: value}) as mock__unnormalize,
):
bus.write(data_name, f"dummy_{id_}", value)
mock__write.assert_called_once_with(
addr,
length,
id_,
value,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to write '{data_name}' on {id_=} with '{value}' after 1 tries.",
)
mock__encode_sign.assert_called_once_with(data_name, {id_: value})
if data_name in bus.normalized_data:
mock__unnormalize.assert_called_once_with({id_: value})
@pytest.mark.parametrize(
"data_name, id_, value",
[
("Firmware_Version", 1, 14),
("Model_Number", 1, 5678),
("Present_Position", 2, 1337),
("Present_Velocity", 3, 42),
],
)
def test_sync_read_by_str(data_name, id_, value, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
ids = [id_]
expected_value = {f"dummy_{id_}": value}
with (
patch.object(MockMotorsBus, "_sync_read", return_value=({id_: value}, 0)) as mock__sync_read,
patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign,
patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize,
):
returned_dict = bus.sync_read(data_name, f"dummy_{id_}")
assert returned_dict == expected_value
mock__sync_read.assert_called_once_with(
addr,
length,
ids,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
)
mock__decode_sign.assert_called_once_with(data_name, {id_: value})
if data_name in bus.normalized_data:
mock__normalize.assert_called_once_with({id_: value})
@pytest.mark.parametrize(
"data_name, ids_values",
[
("Model_Number", {1: 5678}),
("Present_Position", {1: 1337, 2: 42}),
("Present_Velocity", {1: 1337, 2: 42, 3: 4016}),
],
ids=["1 motor", "2 motors", "3 motors"],
)
def test_sync_read_by_list(data_name, ids_values, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
ids = list(ids_values)
expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
with (
patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read,
patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign,
patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize,
):
returned_dict = bus.sync_read(data_name, [f"dummy_{id_}" for id_ in ids])
assert returned_dict == expected_values
mock__sync_read.assert_called_once_with(
addr,
length,
ids,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
)
mock__decode_sign.assert_called_once_with(data_name, ids_values)
if data_name in bus.normalized_data:
mock__normalize.assert_called_once_with(ids_values)
@pytest.mark.parametrize(
"data_name, ids_values",
[
("Model_Number", {1: 5678, 2: 5799, 3: 5678}),
("Present_Position", {1: 1337, 2: 42, 3: 4016}),
("Goal_Position", {1: 4008, 2: 199, 3: 3446}),
],
ids=["Model_Number", "Present_Position", "Goal_Position"],
)
def test_sync_read_by_none(data_name, ids_values, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
ids = list(ids_values)
expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
with (
patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read,
patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign,
patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize,
):
returned_dict = bus.sync_read(data_name)
assert returned_dict == expected_values
mock__sync_read.assert_called_once_with(
addr,
length,
ids,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
)
mock__decode_sign.assert_called_once_with(data_name, ids_values)
if data_name in bus.normalized_data:
mock__normalize.assert_called_once_with(ids_values)
@pytest.mark.parametrize(
"data_name, value",
[
("Goal_Position", 500),
("Goal_Velocity", 4010),
("Lock", 0),
],
)
def test_sync_write_by_single_value(data_name, value, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
ids_values = {m.id: value for m in dummy_motors.values()}
with (
patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write,
patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign,
patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize,
):
bus.sync_write(data_name, value)
mock__sync_write.assert_called_once_with(
addr,
length,
ids_values,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.",
)
mock__encode_sign.assert_called_once_with(data_name, ids_values)
if data_name in bus.normalized_data:
mock__unnormalize.assert_called_once_with(ids_values)
@pytest.mark.parametrize(
"data_name, ids_values",
[
("Goal_Position", {1: 1337, 2: 42, 3: 4016}),
("Goal_Velocity", {1: 50, 2: 83, 3: 2777}),
("Lock", {1: 0, 2: 0, 3: 1}),
],
ids=["Goal_Position", "Goal_Velocity", "Lock"],
)
def test_sync_write_by_value_dict(data_name, ids_values, dummy_motors):
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
bus.connect(handshake=False)
addr, length = DUMMY_CTRL_TABLE_2[data_name]
values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
with (
patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write,
patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign,
patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize,
):
bus.sync_write(data_name, values)
mock__sync_write.assert_called_once_with(
addr,
length,
ids_values,
num_retry=0,
raise_on_error=True,
err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.",
)
mock__encode_sign.assert_called_once_with(data_name, ids_values)
if data_name in bus.normalized_data:
mock__unnormalize.assert_called_once_with(ids_values)

View File

@@ -1,242 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from lerobot.constants import (
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
)
from lerobot.optim.optimizers import (
AdamConfig,
AdamWConfig,
MultiAdamConfig,
SGDConfig,
load_optimizer_state,
save_optimizer_state,
)
@pytest.mark.parametrize(
"config_cls, expected_class",
[
(AdamConfig, torch.optim.Adam),
(AdamWConfig, torch.optim.AdamW),
(SGDConfig, torch.optim.SGD),
(MultiAdamConfig, dict),
],
)
def test_optimizer_build(config_cls, expected_class, model_params):
config = config_cls()
if config_cls == MultiAdamConfig:
params_dict = {"default": model_params}
optimizer = config.build(params_dict)
assert isinstance(optimizer, expected_class)
assert isinstance(optimizer["default"], torch.optim.Adam)
assert optimizer["default"].defaults["lr"] == config.lr
else:
optimizer = config.build(model_params)
assert isinstance(optimizer, expected_class)
assert optimizer.defaults["lr"] == config.lr
def test_save_optimizer_state(optimizer, tmp_path):
save_optimizer_state(optimizer, tmp_path)
assert (tmp_path / OPTIMIZER_STATE).is_file()
assert (tmp_path / OPTIMIZER_PARAM_GROUPS).is_file()
def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
save_optimizer_state(optimizer, tmp_path)
loaded_optimizer = AdamConfig().build(model_params)
loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path)
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
@pytest.fixture
def base_params_dict():
return {
"actor": [torch.nn.Parameter(torch.randn(10, 10))],
"critic": [torch.nn.Parameter(torch.randn(5, 5))],
"temperature": [torch.nn.Parameter(torch.randn(3, 3))],
}
@pytest.mark.parametrize(
"config_params, expected_values",
[
# Test 1: Basic configuration with different learning rates
(
{
"lr": 1e-3,
"weight_decay": 1e-4,
"optimizer_groups": {
"actor": {"lr": 1e-4},
"critic": {"lr": 5e-4},
"temperature": {"lr": 2e-3},
},
},
{
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
},
),
# Test 2: Different weight decays and beta values
(
{
"lr": 1e-3,
"weight_decay": 1e-4,
"optimizer_groups": {
"actor": {"lr": 1e-4, "weight_decay": 1e-5},
"critic": {"lr": 5e-4, "weight_decay": 1e-6},
"temperature": {"lr": 2e-3, "betas": (0.95, 0.999)},
},
},
{
"actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)},
"critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)},
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)},
},
),
# Test 3: Epsilon parameter customization
(
{
"lr": 1e-3,
"weight_decay": 1e-4,
"optimizer_groups": {
"actor": {"lr": 1e-4, "eps": 1e-6},
"critic": {"lr": 5e-4, "eps": 1e-7},
"temperature": {"lr": 2e-3, "eps": 1e-8},
},
},
{
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6},
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7},
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8},
},
),
],
)
def test_multi_adam_configuration(base_params_dict, config_params, expected_values):
# Create config with the given parameters
config = MultiAdamConfig(**config_params)
optimizers = config.build(base_params_dict)
# Verify optimizer count and keys
assert len(optimizers) == len(expected_values)
assert set(optimizers.keys()) == set(expected_values.keys())
# Check that all optimizers are Adam instances
for opt in optimizers.values():
assert isinstance(opt, torch.optim.Adam)
# Verify hyperparameters for each optimizer
for name, expected in expected_values.items():
optimizer = optimizers[name]
for param, value in expected.items():
assert optimizer.defaults[param] == value
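# A hedged sketch of what the assertions above imply `MultiAdamConfig.build` does: one
# torch.optim.Adam per named parameter group, with per-group overrides merged over the
# top-level defaults (uses the module-level torch import; names are illustrative only):
def build_multi_adam_sketch(params_dict, lr, weight_decay, optimizer_groups):
    return {
        name: torch.optim.Adam(
            params, **{"lr": lr, "weight_decay": weight_decay, **optimizer_groups.get(name, {})}
        )
        for name, params in params_dict.items()
    }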
@pytest.fixture
def multi_optimizers(base_params_dict):
config = MultiAdamConfig(
lr=1e-3,
optimizer_groups={
"actor": {"lr": 1e-4},
"critic": {"lr": 5e-4},
"temperature": {"lr": 2e-3},
},
)
return config.build(base_params_dict)
def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
# Save optimizer states
save_optimizer_state(multi_optimizers, tmp_path)
# Verify that directories were created for each optimizer
for name in multi_optimizers:
assert (tmp_path / name).is_dir()
assert (tmp_path / name / OPTIMIZER_STATE).is_file()
assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file()
def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path):
# Run a minimal backward pass so the optimizers have non-empty state to save
for name, params in base_params_dict.items():
if name in multi_optimizers:
# Create a dummy loss and do backward
dummy_loss = params[0].sum()
dummy_loss.backward()
# Perform an optimization step
multi_optimizers[name].step()
# Zero gradients for next steps
multi_optimizers[name].zero_grad()
# Save optimizer states
save_optimizer_state(multi_optimizers, tmp_path)
# Create new optimizers with the same config
config = MultiAdamConfig(
lr=1e-3,
optimizer_groups={
"actor": {"lr": 1e-4},
"critic": {"lr": 5e-4},
"temperature": {"lr": 2e-3},
},
)
new_optimizers = config.build(base_params_dict)
# Load optimizer states
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
# Verify state dictionaries match
for name in multi_optimizers:
torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
"""Test saving and loading optimizer states even when the state is empty (no backward pass)."""
# Create config and build optimizers
config = MultiAdamConfig(
lr=1e-3,
optimizer_groups={
"actor": {"lr": 1e-4},
"critic": {"lr": 5e-4},
"temperature": {"lr": 2e-3},
},
)
optimizers = config.build(base_params_dict)
# Save optimizer states without any backward pass (empty state)
save_optimizer_state(optimizers, tmp_path)
# Create new optimizers with the same config
new_optimizers = config.build(base_params_dict)
# Load optimizer states
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
# Verify hyperparameters match even with empty state
for name, optimizer in optimizers.items():
assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
# Verify state dictionaries match (they will be empty)
torch.testing.assert_close(
optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
)

View File

@@ -1,91 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.optim.lr_scheduler import LambdaLR
from lerobot.constants import SCHEDULER_STATE
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
DiffuserSchedulerConfig,
VQBeTSchedulerConfig,
load_scheduler_state,
save_scheduler_state,
)
def test_diffuser_scheduler(optimizer):
config = DiffuserSchedulerConfig(name="cosine", num_warmup_steps=5)
scheduler = config.build(optimizer, num_training_steps=100)
assert isinstance(scheduler, LambdaLR)
optimizer.step() # so that we don't get torch warning
scheduler.step()
expected_state_dict = {
"_get_lr_called_within_step": False,
"_last_lr": [0.0002],
"_step_count": 2,
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
}
assert scheduler.state_dict() == expected_state_dict
def test_vqbet_scheduler(optimizer):
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
scheduler = config.build(optimizer, num_training_steps=100)
assert isinstance(scheduler, LambdaLR)
optimizer.step()
scheduler.step()
expected_state_dict = {
"_get_lr_called_within_step": False,
"_last_lr": [0.001],
"_step_count": 2,
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
}
assert scheduler.state_dict() == expected_state_dict
def test_cosine_decay_with_warmup_scheduler(optimizer):
config = CosineDecayWithWarmupSchedulerConfig(
num_warmup_steps=10, num_decay_steps=90, peak_lr=0.01, decay_lr=0.001
)
scheduler = config.build(optimizer, num_training_steps=100)
assert isinstance(scheduler, LambdaLR)
optimizer.step()
scheduler.step()
expected_state_dict = {
"_get_lr_called_within_step": False,
"_last_lr": [0.0001818181818181819],
"_step_count": 2,
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
}
assert scheduler.state_dict() == expected_state_dict
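# The warmup value asserted above equals base_lr * (1 + 1) / (10 + 1), i.e. a linear warmup
# factor of (step + 1) / (num_warmup_steps + 1). A hedged sketch of such a LambdaLR multiplier,
# with a cosine ramp from peak_lr down to decay_lr afterwards (the parameterization is an
# assumption, not the actual lerobot schedule):
import math

def cosine_decay_with_warmup_lambda_sketch(step, num_warmup_steps, num_decay_steps, peak_lr, decay_lr):
    if step < num_warmup_steps:
        return (step + 1) / (num_warmup_steps + 1)
    progress = min(1.0, (step - num_warmup_steps) / num_decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    # Expressed as a multiplier of peak_lr, interpolating between peak_lr and decay_lr.
    return (decay_lr + (peak_lr - decay_lr) * cosine) / peak_lr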
def test_save_scheduler_state(scheduler, tmp_path):
save_scheduler_state(scheduler, tmp_path)
assert (tmp_path / SCHEDULER_STATE).is_file()
def test_save_load_scheduler_state(scheduler, tmp_path):
save_scheduler_state(scheduler, tmp_path)
loaded_scheduler = load_scheduler_state(scheduler, tmp_path)
assert scheduler.state_dict() == loaded_scheduler.state_dict()

View File

@@ -1,139 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from tests.utils import require_package
def test_classifier_output():
output = ClassifierOutput(
logits=torch.tensor([1, 2, 3]),
probabilities=torch.tensor([0.1, 0.2, 0.3]),
hidden_states=None,
)
assert (
f"{output}"
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
)
@require_package("transformers")
def test_binary_classifier_with_default_params():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig()
config.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
}
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
config.num_cameras = 1
classifier = Classifier(config)
batch_size = 10
input = {
"observation.image": torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
}
images, labels = classifier.extract_images_and_labels(input)
assert len(images) == 1
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
assert labels.shape == torch.Size([batch_size])
output = classifier.predict(images)
assert output is not None
assert output.logits.size() == torch.Size([batch_size])
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
assert output.probabilities.shape == torch.Size([batch_size])
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
assert output.hidden_states.shape == torch.Size([batch_size, 256])
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
@require_package("transformers")
def test_multiclass_classifier():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
num_classes = 5
config = RewardClassifierConfig()
config.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
}
config.num_cameras = 1
config.num_classes = num_classes
classifier = Classifier(config)
batch_size = 10
input = {
"observation.image": torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.rand((batch_size, num_classes)),
}
images, labels = classifier.extract_images_and_labels(input)
assert len(images) == 1
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
assert labels.shape == torch.Size([batch_size, num_classes])
output = classifier.predict(images)
assert output is not None
assert output.logits.shape == torch.Size([batch_size, num_classes])
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
assert output.hidden_states.shape == torch.Size([batch_size, 256])
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
@require_package("transformers")
def test_default_device():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig()
assert config.device == "cpu"
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("cpu")
@require_package("transformers")
def test_explicit_device_setup():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig(device="cpu")
assert config.device == "cpu"
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("cpu")

View File

@@ -1,546 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from copy import deepcopy
from pathlib import Path
import einops
import pytest
import torch
from packaging import version
from safetensors.torch import load_file
from lerobot import available_policies
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.utils import cycle, dataset_to_policy_features
from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import preprocess_observation
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.policies.factory import (
get_policy_class,
make_policy,
make_policy_config,
)
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
@pytest.fixture
def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path):
# Create only one camera input, square-shaped, to fit all current policy constraints,
# e.g. vqbet and tdmpc work with one camera only, and tdmpc requires it to be square
camera_features = {
"observation.images.laptop": {
"shape": (84, 84, 3),
"names": ["height", "width", "channels"],
"info": None,
},
}
motor_features = {
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
"observation.state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
}
info = info_factory(
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
)
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
return ds_meta
@pytest.mark.parametrize("policy_name", available_policies)
def test_get_policy_and_config_classes(policy_name: str):
"""Check that the correct policy and config classes are returned."""
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
assert policy_cls.name == policy_name
assert issubclass(
policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation
)
@pytest.mark.parametrize(
"ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs",
[
("lerobot/xarm_lift_medium", "xarm", {}, "tdmpc", {"use_mpc": True}),
("lerobot/pusht", "pusht", {}, "diffusion", {}),
("lerobot/pusht", "pusht", {}, "vqbet", {}),
("lerobot/pusht", "pusht", {}, "act", {}),
("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}),
(
"lerobot/aloha_sim_insertion_scripted",
"aloha",
{"task": "AlohaInsertion-v0"},
"act",
{},
),
(
"lerobot/aloha_sim_insertion_human",
"aloha",
{"task": "AlohaInsertion-v0"},
"diffusion",
{},
),
(
"lerobot/aloha_sim_transfer_cube_human",
"aloha",
{"task": "AlohaTransferCube-v0"},
"act",
{},
),
(
"lerobot/aloha_sim_transfer_cube_scripted",
"aloha",
{"task": "AlohaTransferCube-v0"},
"act",
{},
),
],
)
@require_env
def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
"""
Tests:
- Making the policy object.
- Checking that the policy follows the correct protocol and subclasses nn.Module
and PyTorchModelHubMixin.
- Updating the policy.
- Using the policy to select actions at inference time.
- Checking that the selected action can be applied in the environment.
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
and for now we add tests as we see fit.
"""
train_cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
env=make_env_config(env_name, **env_kwargs),
)
train_cfg.validate()
# Check that we can make the policy object.
dataset = make_dataset(train_cfg)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
assert isinstance(policy, PreTrainedPolicy)
# Check that we run select_actions and get the appropriate output.
env = make_env(train_cfg.env, n_envs=2)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=2,
shuffle=True,
pin_memory=DEVICE != "cpu",
drop_last=True,
)
dl_iter = cycle(dataloader)
batch = next(dl_iter)
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy (and test that it does not mutate the batch)
batch_ = deepcopy(batch)
policy.forward(batch)
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
assert all(
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
for k in batch
), "Batch values are not the same after a forward pass."
# reset the policy and environment
policy.reset()
observation, _ = env.reset(seed=train_cfg.seed)
# apply transform to normalize the observations
observation = preprocess_observation(observation)
# send observation to device/gpu
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
# get the next action for the environment (also check that the observation batch is not modified)
observation_ = deepcopy(observation)
with torch.inference_mode():
action = policy.select_action(observation).cpu().numpy()
assert set(observation) == set(observation_), (
"Observation batch keys are not the same after a forward pass."
)
assert all(torch.equal(observation[k], observation_[k]) for k in observation), (
"Observation batch values are not the same after a forward pass."
)
# Test step through policy
env.step(action)
# TODO(rcadene, aliberts): This test is quite end-to-end. Move this test in test_optimizer?
def test_act_backbone_lr():
"""
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
"""
cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
)
cfg.validate() # Needed for auto-setting some parameters
assert cfg.policy.optimizer_lr == 0.01
assert cfg.policy.optimizer_lr_backbone == 0.001
dataset = make_dataset(cfg)
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr
assert optimizer.param_groups[1]["lr"] == cfg.policy.optimizer_lr_backbone
assert len(optimizer.param_groups[0]["params"]) == 133
assert len(optimizer.param_groups[1]["params"]) == 20
@pytest.mark.parametrize("policy_name", available_policies)
def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
"""Check that the policy can be instantiated with defaults."""
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
policy_cls(policy_cfg)
@pytest.mark.parametrize("policy_name", available_policies)
def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str):
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
policy = policy_cls(policy_cfg)
policy.to(policy_cfg.device)
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
policy.save_pretrained(save_dir)
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
def test_normalize(insert_temporal_dim):
"""
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
an exception when the forward pass is called without the stats having been provided.
TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
expected.
"""
input_features = {
"observation.image": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 96, 96),
),
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(10,),
),
}
output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
),
}
norm_map = {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
dataset_stats = {
"observation.image": {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
"min": torch.randn(3, 1, 1),
"max": torch.randn(3, 1, 1),
},
"observation.state": {
"mean": torch.randn(10),
"std": torch.randn(10),
"min": torch.randn(10),
"max": torch.randn(10),
},
"action": {
"mean": torch.randn(5),
"std": torch.randn(5),
"min": torch.randn(5),
"max": torch.randn(5),
},
}
bsize = 2
input_batch = {
"observation.image": torch.randn(bsize, 3, 96, 96),
"observation.state": torch.randn(bsize, 10),
}
output_batch = {
"action": torch.randn(bsize, 5),
}
if insert_temporal_dim:
tdim = 4
for key in input_batch:
# [2,3,96,96] -> [2,tdim,3,96,96]
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
for key in output_batch:
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
# test without stats
normalize = Normalize(input_features, norm_map, stats=None)
with pytest.raises(AssertionError):
normalize(input_batch)
# test with stats
normalize = Normalize(input_features, norm_map, stats=dataset_stats)
normalize(input_batch)
# test loading pretrained models
new_normalize = Normalize(input_features, norm_map, stats=None)
new_normalize.load_state_dict(normalize.state_dict())
new_normalize(input_batch)
# test without stats
unnormalize = Unnormalize(output_features, norm_map, stats=None)
with pytest.raises(AssertionError):
unnormalize(output_batch)
# test with stats
unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats)
unnormalize(output_batch)
# test loading pretrained models
new_unnormalize = Unnormalize(output_features, norm_map, stats=None)
new_unnormalize.load_state_dict(unnormalize.state_dict())
new_unnormalize(output_batch)
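# A minimal standalone sketch of the two normalization modes exercised above (an illustration
# assuming plain torch with an eps guard, not the library's Normalize/Unnormalize modules):
def _mean_std_normalize_sketch(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # MEAN_STD: center and scale.
    return (x - mean) / (std + 1e-8)
def _min_max_normalize_sketch(x: torch.Tensor, min_: torch.Tensor, max_: torch.Tensor) -> torch.Tensor:
    # MIN_MAX: map [min, max] to [-1, 1].
    return 2 * (x - min_) / (max_ - min_ + 1e-8) - 1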
@pytest.mark.parametrize("multikey", [True, False])
def test_multikey_construction(multikey: bool):
"""
Asserts that multiple keys with type State/Action are correctly processed by the policy constructor,
preventing erroneous creation of the policy object.
"""
input_features = {
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(10,),
),
}
output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
),
}
if multikey:
"""Simulates the complete state/action is constructed from more granular multiple
keys, of the same type as the overall state/action"""
input_features = {}
input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
output_features = {}
output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))
output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,))
output_features["action"] = PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
)
config = ACTConfig(input_features=input_features, output_features=output_features)
state_condition = config.robot_state_feature == input_features[OBS_STATE]
action_condition = config.action_feature == output_features[ACTION]
assert state_condition, (
f"Discrepancy detected. Robot state feature is {config.robot_state_feature} but policy expects {input_features[OBS_STATE]}"
)
assert action_condition, (
f"Discrepancy detected. Action feature is {config.action_feature} but policy expects {output_features[ACTION]}"
)
@pytest.mark.parametrize(
"ds_repo_id, policy_name, policy_kwargs, file_name_extra",
[
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
# to test with `policy.use_mpc=false`.
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
# ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
# TODO(rcadene): the diffusion model used to normalize the image with mean=0.5 std=0.5, which was a
# hack rather than a proper normalization. In our current codebase we don't normalize at all, but
# there is still a minor difference that fails the test. However, when normalizing the image with
# 0.5/0.5 in the current codebase, the test passes. Thus, we deactivate this test for now.
(
"lerobot/pusht",
"diffusion",
{
"n_action_steps": 8,
"num_inference_steps": 10,
"down_dims": [128, 256, 512],
},
"",
),
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
(
"lerobot/aloha_sim_insertion_human",
"act",
{"n_action_steps": 1000, "chunk_size": 1000},
"1000_steps",
),
],
)
# As artifacts have been generated on an x86_64 kernel, this test won't
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
@require_cpu
def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str):
"""
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
add the policies you want to update the test artifacts for.
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact
should be updated.
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
NOTE: If the test does not pass and you haven't changed the policy, it is likely that the test
artifact is out of date. For example, some PyTorch versions have different randomness; see this PR:
https://github.com/huggingface/lerobot/pull/1127.
"""
# NOTE: The ACT policy has different randomness after PyTorch 2.7.0
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
ds_name = ds_repo_id.split("/")[-1]
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
saved_actions = load_file(artifact_dir / "actions.safetensors")
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
for key in saved_output_dict:
torch.testing.assert_close(output_dict[key], saved_output_dict[key])
for key in saved_grad_stats:
torch.testing.assert_close(grad_stats[key], saved_grad_stats[key])
for key in saved_param_stats:
torch.testing.assert_close(param_stats[key], saved_param_stats[key])
for key in saved_actions:
rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK
torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol)
def test_act_temporal_ensembler():
"""Check that the online method in ACTTemporalEnsembler matches a simple offline calculation."""
temporal_ensemble_coeff = 0.01
chunk_size = 100
episode_length = 101
ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size)
# A batch of arbitrary sequences of 1D actions over which we wish to compute the average. We'll keep
# the "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen.
with seeded_context(0):
# Dimension is (batch, episode_length, chunk_size, action_dim(=1))
# Stepping through the episode_length dim is like running inference at each rollout step and getting
# a different action chunk.
batch_seq = torch.stack(
[
torch.rand(episode_length, chunk_size) * 0.05 - 0.6,
torch.rand(episode_length, chunk_size) * 0.02 - 0.01,
torch.rand(episode_length, chunk_size) * 0.2 + 0.3,
],
dim=0,
).unsqueeze(-1) # unsqueeze for action dim
batch_size = batch_seq.shape[0]
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
# dimension of `batch_seq`.
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
# Simulate stepping through a rollout and computing a batch of actions with model on each step.
for i in range(episode_length):
# Mock a batch of actions.
actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i]
online_avg = ensembler.update(actions)
# Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
# Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
# What we want to do is take diagonal slices across it starting from the left.
# eg: chunk_size=4, episode_length=6
# ┌───────┐
# │0 1 2 3│
# │1 2 3 4│
# │2 3 4 5│
# │3 4 5 6│
# │4 5 6 7│
# │5 6 7 8│
# └───────┘
chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
offline_avg = (
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
)
# Sanity check. The average should be between the extrema.
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
# atol=1e-4 was selected keeping in mind actions in [-1, 1] and expecting at most 0.01% error.
torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)
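# A compact sketch of the ensembling rule verified above (illustrative, not the
# ACTTemporalEnsembler implementation): predictions covering the current step are ordered
# oldest-first along dim 0 and averaged with weights exp(-coeff * i).
def _ensemble_sketch(actions_covering_step: torch.Tensor, coeff: float = 0.01) -> torch.Tensor:
    weights = torch.exp(-coeff * torch.arange(actions_covering_step.shape[0]))
    return (actions_covering_step * weights).sum() / weights.sum()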

View File

@ -1,217 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.configuration_sac import (
ActorLearnerConfig,
ActorNetworkConfig,
ConcurrencyConfig,
CriticNetworkConfig,
PolicyConfig,
SACConfig,
)
def test_sac_config_default_initialization():
config = SACConfig()
assert config.normalization_mapping == {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ENV": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
assert config.dataset_stats == {
"observation.image": {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
}
# Basic parameters
assert config.device == "cpu"
assert config.storage_device == "cpu"
assert config.discount == 0.99
assert config.temperature_init == 1.0
assert config.num_critics == 2
# Architecture specifics
assert config.vision_encoder_name is None
assert config.freeze_vision_encoder is True
assert config.image_encoder_hidden_dim == 32
assert config.shared_encoder is True
assert config.num_discrete_actions is None
assert config.image_embedding_pooling_dim == 8
# Training parameters
assert config.online_steps == 1000000
assert config.online_env_seed == 10000
assert config.online_buffer_capacity == 100000
assert config.offline_buffer_capacity == 100000
assert config.async_prefetch is False
assert config.online_step_before_learning == 100
assert config.policy_update_freq == 1
# SAC algorithm parameters
assert config.num_subsample_critics is None
assert config.critic_lr == 3e-4
assert config.actor_lr == 3e-4
assert config.temperature_lr == 3e-4
assert config.critic_target_update_weight == 0.005
assert config.utd_ratio == 1
assert config.state_encoder_hidden_dim == 256
assert config.latent_dim == 256
assert config.target_entropy is None
assert config.use_backup_entropy is True
assert config.grad_clip_norm == 40.0
# Dataset stats defaults
expected_dataset_stats = {
"observation.image": {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
}
assert config.dataset_stats == expected_dataset_stats
# Critic network configuration
assert config.critic_network_kwargs.hidden_dims == [256, 256]
assert config.critic_network_kwargs.activate_final is True
assert config.critic_network_kwargs.final_activation is None
# Actor network configuration
assert config.actor_network_kwargs.hidden_dims == [256, 256]
assert config.actor_network_kwargs.activate_final is True
# Policy configuration
assert config.policy_kwargs.use_tanh_squash is True
assert config.policy_kwargs.std_min == 1e-5
assert config.policy_kwargs.std_max == 10.0
assert config.policy_kwargs.init_final == 0.05
# Discrete critic network configuration
assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256]
assert config.discrete_critic_network_kwargs.activate_final is True
assert config.discrete_critic_network_kwargs.final_activation is None
# Actor learner configuration
assert config.actor_learner_config.learner_host == "127.0.0.1"
assert config.actor_learner_config.learner_port == 50051
assert config.actor_learner_config.policy_parameters_push_frequency == 4
# Concurrency configuration
assert config.concurrency.actor == "threads"
assert config.concurrency.learner == "threads"
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
assert isinstance(config.policy_kwargs, PolicyConfig)
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
assert isinstance(config.concurrency, ConcurrencyConfig)
def test_critic_network_kwargs():
config = CriticNetworkConfig()
assert config.hidden_dims == [256, 256]
assert config.activate_final is True
assert config.final_activation is None
def test_actor_network_kwargs():
config = ActorNetworkConfig()
assert config.hidden_dims == [256, 256]
assert config.activate_final is True
def test_policy_kwargs():
config = PolicyConfig()
assert config.use_tanh_squash is True
assert config.std_min == 1e-5
assert config.std_max == 10.0
assert config.init_final == 0.05
def test_actor_learner_config():
config = ActorLearnerConfig()
assert config.learner_host == "127.0.0.1"
assert config.learner_port == 50051
assert config.policy_parameters_push_frequency == 4
def test_concurrency_config():
config = ConcurrencyConfig()
assert config.actor == "threads"
assert config.learner == "threads"
def test_sac_config_custom_initialization():
config = SACConfig(
device="cpu",
discount=0.95,
temperature_init=0.5,
num_critics=3,
)
assert config.device == "cpu"
assert config.discount == 0.95
assert config.temperature_init == 0.5
assert config.num_critics == 3
def test_validate_features():
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
config.validate_features()
def test_validate_features_missing_observation():
config = SACConfig(
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(
ValueError, match="You must provide either 'observation.state' or an image observation"
):
config.validate_features()
def test_validate_features_missing_action():
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
config.validate_features()

View File

@ -1,541 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import pytest
import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.utils.random_utils import seeded_context, set_seed
try:
import transformers # noqa: F401
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
@pytest.fixture(autouse=True)
def set_random_seed():
seed = 42
set_seed(seed)
def test_mlp_with_default_args():
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
x = torch.randn(10)
y = mlp(x)
assert y.shape == (256,)
def test_mlp_with_batch_dim():
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
x = torch.randn(2, 10)
y = mlp(x)
assert y.shape == (2, 256)
def test_forward_with_empty_hidden_dims():
mlp = MLP(input_dim=10, hidden_dims=[])
x = torch.randn(1, 10)
assert mlp(x).shape == (1, 10)
def test_mlp_with_dropout():
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
x = torch.randn(1, 10)
y = mlp(x)
assert y.shape == (1, 11)
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
assert drop_out_layers_count == 2
def test_mlp_with_custom_final_activation():
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
x = torch.randn(1, 10)
y = mlp(x)
assert y.shape == (1, 256)
assert (y >= -1).all() and (y <= 1).all()
def test_sac_policy_with_default_args():
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
SACPolicy()
def create_dummy_state(batch_size: int, state_dim: int = 10) -> dict[str, Tensor]:
return {
"observation.state": torch.randn(batch_size, state_dim),
}
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> dict[str, Tensor]:
return {
"observation.image": torch.randn(batch_size, 3, 84, 84),
"observation.state": torch.randn(batch_size, state_dim),
}
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
return torch.randn(batch_size, action_dim)
def create_default_train_batch(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
"action": create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_state(batch_size, state_dim),
"next_state": create_dummy_state(batch_size, state_dim),
"done": torch.randn(batch_size),
}
def create_train_batch_with_visual_input(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
"action": create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_with_visual_input(batch_size, state_dim),
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
"done": torch.randn(batch_size),
}
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
"observation.state": torch.randn(batch_size, state_dim),
}
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
"observation.state": torch.randn(batch_size, state_dim),
"observation.image": torch.randn(batch_size, 3, 84, 84),
}
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
"""Create optimizers for the SAC policy."""
optimizer_actor = torch.optim.Adam(
# Handle the shared-encoder case, where the encoder weights are not optimized with the actor gradient
params=[
p
for n, p in policy.actor.named_parameters()
if not policy.config.shared_encoder or not n.startswith("encoder")
],
lr=policy.config.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(),
lr=policy.config.critic_lr,
)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha],
lr=policy.config.critic_lr,
)
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
if has_discrete_action:
optimizers["discrete_critic"] = torch.optim.Adam(
params=policy.discrete_critic.parameters(),
lr=policy.config.critic_lr,
)
return optimizers
def create_default_config(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
action_dim = continuous_action_dim
if has_discrete_action:
action_dim += 1
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
"observation.state": {
"min": [0.0] * state_dim,
"max": [1.0] * state_dim,
},
"action": {
"min": [0.0] * continuous_action_dim,
"max": [1.0] * continuous_action_dim,
},
},
)
config.validate_features()
return config
def create_config_with_visual_input(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
config = create_default_config(
state_dim=state_dim,
continuous_action_dim=continuous_action_dim,
has_discrete_action=has_discrete_action,
)
config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats["observation.image"] = {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
}
# Make the tests a little bit faster
config.state_encoder_hidden_dim = 32
config.latent_dim = 32
config.validate_features()
return config
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy.train()
optimizers = make_optimizers(policy)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
assert temperature_loss.item() is not None
assert temperature_loss.shape == ()
temperature_loss.backward()
optimizers["temperature"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, action_dim)
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
policy.train()
optimizers = make_optimizers(policy)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
assert temperature_loss.item() is not None
assert temperature_loss.shape == ()
temperature_loss.backward()
optimizers["temperature"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, action_dim)
# Check the best candidates for pretrained encoders
@pytest.mark.parametrize(
"batch_size,state_dim,action_dim,vision_encoder_name",
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
)
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
def test_sac_policy_with_pretrained_encoder(
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.vision_encoder_name = vision_encoder_name
policy = SACPolicy(config=config)
policy.train()
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
optimizers = make_optimizers(policy)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
def test_sac_policy_with_shared_encoder():
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.shared_encoder = True
policy = SACPolicy(config=config)
policy.train()
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
policy.train()
optimizers = make_optimizers(policy)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
def test_sac_policy_with_discrete_critic():
batch_size = 2
continuous_action_dim = 9
full_action_dim = continuous_action_dim + 1 # the last action is discrete
state_dim = 10
config = create_config_with_visual_input(
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
)
num_discrete_actions = 5
config.num_discrete_actions = num_discrete_actions
policy = SACPolicy(config=config)
policy.train()
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
)
policy.train()
optimizers = make_optimizers(policy, has_discrete_action=True)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
assert discrete_critic_loss.item() is not None
assert discrete_critic_loss.shape == ()
discrete_critic_loss.backward()
optimizers["discrete_critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, full_action_dim)
discrete_actions = selected_action[:, -1].long()
discrete_action_values = set(discrete_actions.tolist())
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
)
def test_sac_policy_with_default_entropy():
config = create_default_config(continuous_action_dim=10, state_dim=10)
policy = SACPolicy(config=config)
assert policy.target_entropy == -5.0
def test_sac_policy_default_target_entropy_with_discrete_action():
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
policy = SACPolicy(config=config)
assert policy.target_entropy == -3.0
def test_sac_policy_with_predefined_entropy():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.target_entropy = -3.5
policy = SACPolicy(config=config)
assert policy.target_entropy == pytest.approx(-3.5)
def test_sac_policy_update_temperature():
config = create_default_config(continuous_action_dim=10, state_dim=10)
policy = SACPolicy(config=config)
assert policy.temperature == pytest.approx(1.0)
policy.log_alpha.data = torch.tensor([math.log(0.1)])
policy.update_temperature()
assert policy.temperature == pytest.approx(0.1)
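# The relationship exercised above is temperature = exp(log_alpha), so setting
# log_alpha = log(0.1) yields a temperature of 0.1 after update_temperature().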
def test_sac_policy_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.critic_target_update_weight = 1.0
policy = SACPolicy(config=config)
policy.train()
for p in policy.critic_ensemble.parameters():
p.data = torch.ones_like(p.data)
policy.update_target_networks()
for p in policy.critic_target.parameters():
assert torch.allclose(p.data, torch.ones_like(p.data)), (
f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
)
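# The rule exercised above is a soft (Polyak) target update; with
# critic_target_update_weight == 1.0 it degenerates to a hard copy. A minimal sketch, assuming
# tau is that weight (illustrative, not the production update_target_networks):
def _soft_update_sketch(target: nn.Module, online: nn.Module, tau: float) -> None:
    # target <- tau * online + (1 - tau) * target, parameter-wise
    for tp, op in zip(target.parameters(), online.parameters(), strict=True):
        tp.data.mul_(1 - tau).add_(op.data, alpha=tau)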
@pytest.mark.parametrize("num_critics", [1, 3])
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_critics = num_critics
policy = SACPolicy(config=config)
policy.train()
assert len(policy.critic_ensemble.critics) == num_critics
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
policy.train()
optimizers = make_optimizers(policy)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
def test_sac_policy_save_and_load(tmp_path):
root = tmp_path / "test_sac_save_and_load"
state_dim = 10
action_dim = 10
batch_size = 2
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = SACPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
with torch.no_grad():
with seeded_context(12):
# Collect policy values before saving
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
actions = policy.select_action(observation_batch)
with seeded_context(12):
# Collect policy values after loading
loaded_critic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
# Compare values before and after saving and loading
# They should be the same
assert torch.allclose(critic_loss, loaded_critic_loss)
assert torch.allclose(actor_loss, loaded_actor_loss)
assert torch.allclose(temperature_loss, loaded_temperature_loss)
assert torch.allclose(actions, loaded_actions)

View File

@ -1,282 +0,0 @@
import torch
from lerobot.processor.pipeline import (
RobotProcessor,
TransitionKey,
_default_batch_to_transition,
_default_transition_to_batch,
)
def _dummy_batch():
"""Create a dummy batch using the new format with observation.* and next.* keys."""
return {
"observation.image.left": torch.randn(1, 3, 128, 128),
"observation.image.right": torch.randn(1, 3, 128, 128),
"observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
"action": torch.tensor([[0.5]]),
"next.reward": 1.0,
"next.done": False,
"next.truncated": False,
"info": {"key": "value"},
}
def test_observation_grouping_roundtrip():
"""Test that observation.* keys are properly grouped and ungrouped."""
proc = RobotProcessor([])
batch_in = _dummy_batch()
batch_out = proc(batch_in)
# Check that all observation.* keys are preserved
original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")}
reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")}
assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys())
# Check tensor values
assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"])
assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"])
assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"])
# Check other fields
assert torch.allclose(batch_out["action"], batch_in["action"])
assert batch_out["next.reward"] == batch_in["next.reward"]
assert batch_out["next.done"] == batch_in["next.done"]
assert batch_out["next.truncated"] == batch_in["next.truncated"]
assert batch_out["info"] == batch_in["info"]
def test_batch_to_transition_observation_grouping():
"""Test that _default_batch_to_transition correctly groups observation.* keys."""
batch = {
"observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128),
"observation.state": [1, 2, 3, 4],
"action": "action_data",
"next.reward": 1.5,
"next.done": True,
"next.truncated": False,
"info": {"episode": 42},
}
transition = _default_batch_to_transition(batch)
# Check observation is a dict with all observation.* keys
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
assert "observation.state" in transition[TransitionKey.OBSERVATION]
# Check values are preserved
assert torch.allclose(
transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
)
assert torch.allclose(
transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
)
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
# Check other fields
assert transition[TransitionKey.ACTION] == "action_data"
assert transition[TransitionKey.REWARD] == 1.5
assert transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {"episode": 42}
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
def test_transition_to_batch_observation_flattening():
"""Test that _default_transition_to_batch correctly flattens observation dict."""
observation_dict = {
"observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128),
"observation.state": [1, 2, 3, 4],
}
transition = {
TransitionKey.OBSERVATION: observation_dict,
TransitionKey.ACTION: "action_data",
TransitionKey.REWARD: 1.5,
TransitionKey.DONE: True,
TransitionKey.TRUNCATED: False,
TransitionKey.INFO: {"episode": 42},
TransitionKey.COMPLEMENTARY_DATA: {},
}
batch = _default_transition_to_batch(transition)
# Check that observation.* keys are flattened back to batch
assert "observation.image.top" in batch
assert "observation.image.left" in batch
assert "observation.state" in batch
# Check values are preserved
assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"])
assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"])
assert batch["observation.state"] == [1, 2, 3, 4]
# Check other fields are mapped to next.* format
assert batch["action"] == "action_data"
assert batch["next.reward"] == 1.5
assert batch["next.done"]
assert not batch["next.truncated"]
assert batch["info"] == {"episode": 42}
def test_no_observation_keys():
"""Test behavior when there are no observation.* keys."""
batch = {
"action": "action_data",
"next.reward": 2.0,
"next.done": False,
"next.truncated": True,
"info": {"test": "no_obs"},
}
transition = _default_batch_to_transition(batch)
# Observation should be None when no observation.* keys
assert transition[TransitionKey.OBSERVATION] is None
# Check other fields
assert transition[TransitionKey.ACTION] == "action_data"
assert transition[TransitionKey.REWARD] == 2.0
assert not transition[TransitionKey.DONE]
assert transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
# Round trip should work
reconstructed_batch = _default_transition_to_batch(transition)
assert reconstructed_batch["action"] == "action_data"
assert reconstructed_batch["next.reward"] == 2.0
assert not reconstructed_batch["next.done"]
assert reconstructed_batch["next.truncated"]
assert reconstructed_batch["info"] == {"test": "no_obs"}
def test_minimal_batch():
"""Test with minimal batch containing only observation.* and action."""
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
transition = _default_batch_to_transition(batch)
# Check observation
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
assert transition[TransitionKey.ACTION] == "minimal_action"
# Check defaults
assert transition[TransitionKey.REWARD] == 0.0
assert not transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {}
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
assert reconstructed_batch["observation.state"] == "minimal_state"
assert reconstructed_batch["action"] == "minimal_action"
assert reconstructed_batch["next.reward"] == 0.0
assert not reconstructed_batch["next.done"]
assert not reconstructed_batch["next.truncated"]
assert reconstructed_batch["info"] == {}
def test_empty_batch():
"""Test behavior with empty batch."""
batch = {}
transition = _default_batch_to_transition(batch)
# All fields should have defaults
assert transition[TransitionKey.OBSERVATION] is None
assert transition[TransitionKey.ACTION] is None
assert transition[TransitionKey.REWARD] == 0.0
assert not transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {}
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
assert reconstructed_batch["action"] is None
assert reconstructed_batch["next.reward"] == 0.0
assert not reconstructed_batch["next.done"]
assert not reconstructed_batch["next.truncated"]
assert reconstructed_batch["info"] == {}
def test_complex_nested_observation():
"""Test with complex nested observation data."""
batch = {
"observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
"observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
"observation.state": torch.randn(7),
"action": torch.randn(8),
"next.reward": 3.14,
"next.done": False,
"next.truncated": True,
"info": {"episode_length": 200, "success": True},
}
transition = _default_batch_to_transition(batch)
reconstructed_batch = _default_transition_to_batch(transition)
# Check that all observation keys are preserved
original_obs_keys = {k for k in batch if k.startswith("observation.")}
reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")}
assert original_obs_keys == reconstructed_obs_keys
# Check tensor values
assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"])
# Check nested dict with tensors
assert torch.allclose(
batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"]
)
assert torch.allclose(
batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"]
)
# Check action tensor
assert torch.allclose(batch["action"], reconstructed_batch["action"])
# Check other fields
assert batch["next.reward"] == reconstructed_batch["next.reward"]
assert batch["next.done"] == reconstructed_batch["next.done"]
assert batch["next.truncated"] == reconstructed_batch["next.truncated"]
assert batch["info"] == reconstructed_batch["info"]
def test_custom_converter():
"""Test that custom converters can still be used."""
def to_tr(batch):
# Custom converter that modifies the reward
tr = _default_batch_to_transition(batch)
# Double the reward
reward = tr.get(TransitionKey.REWARD, 0.0)
new_tr = tr.copy()
new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0
return new_tr
def to_batch(tr):
batch = _default_transition_to_batch(tr)
return batch
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
batch = {
"observation.state": torch.randn(1, 4),
"action": torch.randn(1, 2),
"next.reward": 1.0,
"next.done": False,
}
result = processor(batch)
# Check the reward was doubled by our custom converter
assert result["next.reward"] == 2.0
assert torch.allclose(result["observation.state"], batch["observation.state"])
assert torch.allclose(result["action"], batch["action"])

View File

@ -1,628 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
import numpy as np
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.normalize_processor import (
NormalizerProcessor,
UnnormalizerProcessor,
_convert_stats_to_tensors,
)
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
def create_transition(
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
def test_numpy_conversion():
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
}
}
tensor_stats = _convert_stats_to_tensors(stats)
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
assert torch.allclose(tensor_stats["observation.image"]["mean"], torch.tensor([0.5, 0.5, 0.5]))
assert torch.allclose(tensor_stats["observation.image"]["std"], torch.tensor([0.2, 0.2, 0.2]))
def test_tensor_conversion():
stats = {
"action": {
"mean": torch.tensor([0.0, 0.0]),
"std": torch.tensor([1.0, 1.0]),
}
}
tensor_stats = _convert_stats_to_tensors(stats)
assert tensor_stats["action"]["mean"].dtype == torch.float32
assert tensor_stats["action"]["std"].dtype == torch.float32
def test_scalar_conversion():
stats = {
"reward": {
"mean": 0.5,
"std": 0.1,
}
}
tensor_stats = _convert_stats_to_tensors(stats)
assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5))
assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1))
def test_list_conversion():
stats = {
"observation.state": {
"min": [0.0, -1.0, -2.0],
"max": [1.0, 1.0, 2.0],
}
}
tensor_stats = _convert_stats_to_tensors(stats)
assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))
def test_unsupported_type():
stats = {
"bad_key": {
"mean": "string_value",
}
}
with pytest.raises(TypeError, match="Unsupported type"):
_convert_stats_to_tensors(stats)
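# As the cases above show, _convert_stats_to_tensors accepts numpy arrays, tensors, Python
# scalars and lists, and raises TypeError on anything else. An illustrative call:
#   _convert_stats_to_tensors({"action": {"mean": [0.0], "std": 1.0}})
#   # -> {"action": {"mean": tensor([0.]), "std": tensor(1.)}}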
# Helper functions to create feature maps and norm maps
def _create_observation_features():
return {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
def _create_observation_norm_map():
return {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.STATE: NormalizationMode.MIN_MAX,
}
# Fixtures for observation normalization tests using NormalizerProcessor
@pytest.fixture
def observation_stats():
return {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"observation.state": {
"min": np.array([0.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
}
@pytest.fixture
def observation_normalizer(observation_stats):
"""Return a NormalizerProcessor that only has observation stats (no action)."""
features = _create_observation_features()
norm_map = _create_observation_norm_map()
return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
def test_mean_std_normalization(observation_normalizer):
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = observation_normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check mean/std normalization
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
assert torch.allclose(normalized_obs["observation.image"], expected_image)
def test_min_max_normalization(observation_normalizer):
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = observation_normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check min/max normalization to [-1, 1]
# For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0
# For state[1]: 2 * (0.0 - (-1.0)) / (1.0 - (-1.0)) - 1 = 0.0
expected_state = torch.tensor([0.0, 0.0])
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
def test_selective_normalization(observation_stats):
features = _create_observation_features()
norm_map = _create_observation_norm_map()
normalizer = NormalizerProcessor(
features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"}
)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Only image should be normalized
assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2)
# State should remain unchanged
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_device_compatibility(observation_stats):
features = _create_observation_features()
norm_map = _create_observation_norm_map()
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
assert normalized_obs["observation.image"].device.type == "cuda"
def test_from_lerobot_dataset():
# Mock dataset
mock_dataset = Mock()
mock_dataset.meta.stats = {
"observation.image": {"mean": [0.5], "std": [0.2]},
"action": {"mean": [0.0], "std": [1.0]},
}
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"action": PolicyFeature(FeatureType.ACTION, (1,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
# Both observation and action statistics should be present in tensor stats
assert "observation.image" in normalizer._tensor_stats
assert "action" in normalizer._tensor_stats
def test_state_dict_save_load(observation_normalizer):
# Save state
state_dict = observation_normalizer.state_dict()
# Create new normalizer and load state
features = _create_observation_features()
norm_map = _create_observation_norm_map()
new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
new_normalizer.load_state_dict(state_dict)
# Test that it works the same
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
transition = create_transition(observation=observation)
result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION]
result2 = new_normalizer(transition)[TransitionKey.OBSERVATION]
assert torch.allclose(result1["observation.image"], result2["observation.image"])
# Fixtures for UnnormalizerProcessor action tests
@pytest.fixture
def action_stats_mean_std():
return {
"mean": np.array([0.0, 0.0, 0.0]),
"std": np.array([1.0, 2.0, 0.5]),
}
@pytest.fixture
def action_stats_min_max():
return {
"min": np.array([-1.0, -2.0, 0.0]),
"max": np.array([1.0, 2.0, 1.0]),
}
def _create_action_features():
return {
"action": PolicyFeature(FeatureType.ACTION, (3,)),
}
def _create_action_norm_map_mean_std():
return {
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
def _create_action_norm_map_min_max():
return {
FeatureType.ACTION: NormalizationMode.MIN_MAX,
}
def test_mean_std_unnormalization(action_stats_mean_std):
features = _create_action_features()
norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessor(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
normalized_action = torch.tensor([1.0, -0.5, 2.0])
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
# action * std + mean
expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0])
assert torch.allclose(unnormalized_action, expected)
def test_min_max_unnormalization(action_stats_min_max):
features = _create_action_features()
norm_map = _create_action_norm_map_min_max()
unnormalizer = UnnormalizerProcessor(
features=features, norm_map=norm_map, stats={"action": action_stats_min_max}
)
# Actions in [-1, 1]
normalized_action = torch.tensor([0.0, -1.0, 1.0])
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
# Map from [-1, 1] to [min, max]
# (action + 1) / 2 * (max - min) + min
expected = torch.tensor(
[
(0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0), # 0.0
(-1.0 + 1) / 2 * (2.0 - (-2.0)) + (-2.0), # -2.0
(1.0 + 1) / 2 * (1.0 - 0.0) + 0.0, # 1.0
]
)
assert torch.allclose(unnormalized_action, expected)
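# For matching stats, normalization and unnormalization are inverses; an illustrative
# MIN_MAX identity over [-1, 1]:
#   x == ((2 * (x - min) / (max - min) - 1) + 1) / 2 * (max - min) + min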
def test_numpy_action_input(action_stats_mean_std):
features = _create_action_features()
norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessor(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
assert isinstance(unnormalized_action, torch.Tensor)
expected = torch.tensor([1.0, -1.0, 1.0])
assert torch.allclose(unnormalized_action, expected)
def test_none_action(action_stats_mean_std):
features = _create_action_features()
norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessor(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
transition = create_transition()
result = unnormalizer(transition)
# Should return transition unchanged
assert result == transition
def test_action_from_lerobot_dataset():
mock_dataset = Mock()
mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}
features = {"action": PolicyFeature(FeatureType.ACTION, (1,))}
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
assert "mean" in unnormalizer._tensor_stats["action"]
# Fixtures for NormalizerProcessor tests
@pytest.fixture
def full_stats():
return {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"observation.state": {
"min": np.array([0.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
"action": {
"mean": np.array([0.0, 0.0]),
"std": np.array([1.0, 2.0]),
},
}
def _create_full_features():
return {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
def _create_full_norm_map():
return {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.STATE: NormalizationMode.MIN_MAX,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
@pytest.fixture
def normalizer_processor(full_stats):
features = _create_full_features()
norm_map = _create_full_norm_map()
return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats)
def test_combined_normalization(normalizer_processor):
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
processed_transition = normalizer_processor(transition)
# Check normalized observations
processed_obs = processed_transition[TransitionKey.OBSERVATION]
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
assert torch.allclose(processed_obs["observation.image"], expected_image)
# Check normalized action
processed_action = processed_transition[TransitionKey.ACTION]
expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0])
assert torch.allclose(processed_action, expected_action)
# Check other fields remain unchanged
assert processed_transition[TransitionKey.REWARD] == 1.0
assert not processed_transition[TransitionKey.DONE]
def test_processor_from_lerobot_dataset(full_stats):
# Mock dataset
mock_dataset = Mock()
mock_dataset.meta.stats = full_stats
features = _create_full_features()
norm_map = _create_full_norm_map()
processor = NormalizerProcessor.from_lerobot_dataset(
mock_dataset, features, norm_map, normalize_keys={"observation.image"}
)
assert processor.normalize_keys == {"observation.image"}
assert "observation.image" in processor._tensor_stats
assert "action" in processor._tensor_stats
def test_get_config(full_stats):
features = _create_full_features()
norm_map = _create_full_norm_map()
processor = NormalizerProcessor(
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
)
config = processor.get_config()
expected_config = {
"normalize_keys": ["observation.image"],
"eps": 1e-6,
"features": {
"observation.image": {"type": "VISUAL", "shape": (3, 96, 96)},
"observation.state": {"type": "STATE", "shape": (2,)},
"action": {"type": "ACTION", "shape": (2,)},
},
"norm_map": {
"VISUAL": "MEAN_STD",
"STATE": "MIN_MAX",
"ACTION": "MEAN_STD",
},
}
assert config == expected_config
def test_integration_with_robot_processor(normalizer_processor):
"""Test integration with RobotProcessor pipeline"""
robot_processor = RobotProcessor([normalizer_processor])
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
processed_transition = robot_processor(transition)
# Verify the processing worked
assert isinstance(processed_transition[TransitionKey.OBSERVATION], dict)
assert isinstance(processed_transition[TransitionKey.ACTION], torch.Tensor)
# Edge case tests
def test_empty_observation():
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
transition = create_transition()
result = normalizer(transition)
assert result == transition
def test_empty_stats():
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
observation = {"observation.image": torch.tensor([0.5])}
transition = create_transition(observation=observation)
result = normalizer(transition)
# Should return observation unchanged since no stats are available
assert torch.allclose(
result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"]
)
def test_partial_stats():
"""If statistics are incomplete, the value should pass through unchanged."""
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {"observation.image": torch.tensor([0.7])}
transition = create_transition(observation=observation)
processed = normalizer(transition)[TransitionKey.OBSERVATION]
assert torch.allclose(processed["observation.image"], observation["observation.image"])
def test_missing_action_stats_no_error():
mock_dataset = Mock()
mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
# The tensor stats should not contain the 'action' key
assert "action" not in processor._tensor_stats
def test_serialization_roundtrip(full_stats):
"""Test that features and norm_map can be serialized and deserialized correctly."""
features = _create_full_features()
norm_map = _create_full_norm_map()
original_processor = NormalizerProcessor(
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
)
# Get config (serialization)
config = original_processor.get_config()
# Create a new processor from the config (deserialization)
new_processor = NormalizerProcessor(
features=config["features"],
norm_map=config["norm_map"],
stats=full_stats,
normalize_keys=set(config["normalize_keys"]),
eps=config["eps"],
)
# Test that both processors work the same way
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
result1 = original_processor(transition)
result2 = new_processor(transition)
# Compare results
assert torch.allclose(
result1[TransitionKey.OBSERVATION]["observation.image"],
result2[TransitionKey.OBSERVATION]["observation.image"],
)
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
# Verify features and norm_map are correctly reconstructed
assert new_processor.features.keys() == original_processor.features.keys()
for key in new_processor.features:
assert new_processor.features[key].type == original_processor.features[key].type
assert new_processor.features[key].shape == original_processor.features[key].shape
assert new_processor.norm_map == original_processor.norm_map
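# --- Illustrative sketch (not part of the original tests): the normalization
# arithmetic these tests pin down. Function names and the eps placement are
# assumptions; only the MEAN_STD formula is confirmed by the expected values
# above, and the MIN_MAX rescale target ([-1, 1] here) is a common convention.
import torch

def normalize_mean_std(x: torch.Tensor, mean, std, eps: float = 1e-8) -> torch.Tensor:
    # MEAN_STD mode: (x - mean) / (std + eps), as in expected_image above.
    return (x - mean) / (std + eps)

def normalize_min_max(x: torch.Tensor, min_, max_, eps: float = 1e-8) -> torch.Tensor:
    # MIN_MAX mode: rescale to [0, 1], then to [-1, 1] (assumed convention).
    x01 = (x - min_) / (max_ - min_ + eps)
    return 2.0 * x01 - 1.0

# Mirrors test_combined_normalization (eps=0 to match the exact expectation):
img = normalize_mean_std(torch.tensor([0.7, 0.5, 0.3]), 0.5, 0.2, eps=0.0)
assert torch.allclose(img, torch.tensor([1.0, 0.0, -1.0]))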

View File

@@ -1,486 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
from lerobot.configs.types import FeatureType
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor import VanillaObservationProcessor
from lerobot.processor.pipeline import TransitionKey
from tests.conftest import assert_contract_is_typed
def create_transition(
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
def test_process_single_image():
"""Test processing a single image."""
processor = VanillaObservationProcessor()
# Create a mock image (H, W, C) format, uint8
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that the image was processed correctly
assert "observation.image" in processed_obs
processed_img = processed_obs["observation.image"]
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
assert processed_img.shape == (1, 3, 64, 64)
# Check dtype and range
assert processed_img.dtype == torch.float32
assert processed_img.min() >= 0.0
assert processed_img.max() <= 1.0
def test_process_image_dict():
"""Test processing multiple images in a dictionary."""
processor = VanillaObservationProcessor()
# Create mock images
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
observation = {"pixels": {"camera1": image1, "camera2": image2}}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that both images were processed
assert "observation.images.camera1" in processed_obs
assert "observation.images.camera2" in processed_obs
# Check shapes
assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)
def test_process_batched_image():
"""Test processing already batched images."""
processor = VanillaObservationProcessor()
# Create a batched image (B, H, W, C)
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimension is preserved
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
def test_invalid_image_format():
"""Test error handling for invalid image formats."""
processor = VanillaObservationProcessor()
# Test wrong channel order (channels first)
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
observation = {"pixels": image}
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="Expected channel-last images"):
processor(transition)
def test_invalid_image_dtype():
"""Test error handling for invalid image dtype."""
processor = VanillaObservationProcessor()
# Test wrong dtype
image = np.random.rand(64, 64, 3).astype(np.float32)
observation = {"pixels": image}
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
processor(transition)
def test_no_pixels_in_observation():
"""Test processor when no pixels are in observation."""
processor = VanillaObservationProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Should preserve other data unchanged
assert "other_data" in processed_obs
np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3]))
def test_none_observation():
"""Test processor with None observation."""
processor = VanillaObservationProcessor()
transition = create_transition()
result = processor(transition)
assert result == transition
def test_serialization_methods():
"""Test serialization methods."""
processor = VanillaObservationProcessor()
# Test get_config
config = processor.get_config()
assert isinstance(config, dict)
# Test state_dict
state = processor.state_dict()
assert isinstance(state, dict)
# Test load_state_dict (should not raise)
processor.load_state_dict(state)
# Test reset (should not raise)
processor.reset()
def test_process_environment_state():
"""Test processing environment_state."""
processor = VanillaObservationProcessor()
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that environment_state was renamed and processed
assert "observation.environment_state" in processed_obs
assert "environment_state" not in processed_obs
processed_state = processed_obs["observation.environment_state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
def test_process_agent_pos():
"""Test processing agent_pos."""
processor = VanillaObservationProcessor()
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that agent_pos was renamed and processed
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
processed_state = processed_obs["observation.state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
def test_process_batched_states():
"""Test processing already batched states."""
processor = VanillaObservationProcessor()
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
observation = {"environment_state": env_state, "agent_pos": agent_pos}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimensions are preserved
assert processed_obs["observation.environment_state"].shape == (2, 2)
assert processed_obs["observation.state"].shape == (2, 2)
def test_process_both_states():
"""Test processing both environment_state and agent_pos."""
processor = VanillaObservationProcessor()
env_state = np.array([1.0, 2.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that both states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
# Check that original keys were removed
assert "environment_state" not in processed_obs
assert "agent_pos" not in processed_obs
# Check that other data was preserved
assert processed_obs["other_data"] == "keep_me"
def test_no_states_in_observation():
"""Test processor when no states are in observation."""
processor = VanillaObservationProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Should preserve data unchanged
np.testing.assert_array_equal(processed_obs["other_data"], observation["other_data"])
def test_complete_observation_processing():
"""Test processing a complete observation with both images and states."""
processor = VanillaObservationProcessor()
# Create mock data
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {
"pixels": image,
"environment_state": env_state,
"agent_pos": agent_pos,
"other_data": "preserve_me",
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that image was processed
assert "observation.image" in processed_obs
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
# Check that states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
# Check that original keys were removed
assert "pixels" not in processed_obs
assert "environment_state" not in processed_obs
assert "agent_pos" not in processed_obs
# Check that other data was preserved
assert processed_obs["other_data"] == "preserve_me"
def test_image_only_processing():
"""Test processing observation with only images."""
processor = VanillaObservationProcessor()
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.image" in processed_obs
assert len(processed_obs) == 1
def test_state_only_processing():
"""Test processing observation with only states."""
processor = VanillaObservationProcessor()
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
def test_empty_observation():
"""Test processing empty observation."""
processor = VanillaObservationProcessor()
observation = {}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
assert processed_obs == {}
def test_equivalent_to_original_function():
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
# Import the original function for comparison
from lerobot.envs.utils import preprocess_observation
processor = VanillaObservationProcessor()
# Create test data similar to what the original function expects
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"pixels": image, "environment_state": env_state, "agent_pos": agent_pos}
# Process with original function
original_result = preprocess_observation(observation)
# Process with new processor
transition = create_transition(observation=observation)
processor_result = processor(transition)[TransitionKey.OBSERVATION]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
def test_equivalent_with_image_dict():
"""Test equivalence with dictionary of images."""
from lerobot.envs.utils import preprocess_observation
processor = VanillaObservationProcessor()
# Create test data with multiple cameras
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"pixels": {"cam1": image1, "cam2": image2}, "agent_pos": agent_pos}
# Process with original function
original_result = preprocess_observation(observation)
# Process with new processor
transition = create_transition(observation=observation)
processor_result = processor(transition)[TransitionKey.OBSERVATION]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
processor = VanillaObservationProcessor()
features = {
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = processor.feature_contract(features.copy())
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"]
assert "pixels" not in out
assert out["keep"] == features["keep"]
assert_contract_is_typed(out)
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
processor = VanillaObservationProcessor()
features = {
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = processor.feature_contract(features.copy())
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"]
assert "observation.pixels" not in out
assert out["keep"] == features["keep"]
assert_contract_is_typed(out)
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
processor = VanillaObservationProcessor()
features = {
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
}
out = processor.feature_contract(features.copy())
assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"]
assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"]
assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"]
assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out
assert out["keep"] == features["keep"]
assert_contract_is_typed(out)
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
processor = VanillaObservationProcessor()
features = {
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = processor.feature_contract(features.copy())
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"]
assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"]
assert "environment_state" not in out and "agent_pos" not in out
assert out["keep"] == features["keep"]
assert_contract_is_typed(out)
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
proc = VanillaObservationProcessor()
features = {
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
}
out = proc.feature_contract(features.copy())
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"]
assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"]
assert "environment_state" not in out and "agent_pos" not in out
assert_contract_is_typed(out)
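# --- Illustrative sketch (not part of the original tests): the image
# conversion the assertions above describe, as a standalone helper. The
# function name is hypothetical; the behavior (uint8 HWC -> float32 BCHW in
# [0, 1], with format/dtype validation) is taken from the tests.
import numpy as np
import torch

def to_policy_image(img: np.ndarray) -> torch.Tensor:
    t = torch.from_numpy(img)
    if t.ndim == 3:
        t = t.unsqueeze(0)  # (H, W, C) -> (1, H, W, C)
    if t.shape[-1] != 3:
        raise ValueError("Expected channel-last images")
    if t.dtype != torch.uint8:
        raise ValueError("Expected torch.uint8 images")
    t = t.permute(0, 3, 1, 2).contiguous()  # (B, H, W, C) -> (B, C, H, W)
    return t.float() / 255.0

out = to_policy_image(np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8))
assert out.shape == (1, 3, 64, 64) and out.dtype == torch.float32
assert out.min() >= 0.0 and out.max() <= 1.0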

File diff suppressed because it is too large

View File

@@ -1,467 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from pathlib import Path
import numpy as np
import torch
from lerobot.configs.types import FeatureType
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
from tests.conftest import assert_contract_is_typed
def create_transition(
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
def test_basic_renaming():
"""Test basic key renaming functionality."""
rename_map = {
"old_key1": "new_key1",
"old_key2": "new_key2",
}
processor = RenameProcessor(rename_map=rename_map)
observation = {
"old_key1": torch.tensor([1.0, 2.0]),
"old_key2": np.array([3.0, 4.0]),
"unchanged_key": "keep_me",
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check renamed keys
assert "new_key1" in processed_obs
assert "new_key2" in processed_obs
assert "old_key1" not in processed_obs
assert "old_key2" not in processed_obs
# Check values are preserved
torch.testing.assert_close(processed_obs["new_key1"], torch.tensor([1.0, 2.0]))
np.testing.assert_array_equal(processed_obs["new_key2"], np.array([3.0, 4.0]))
# Check unchanged key is preserved
assert processed_obs["unchanged_key"] == "keep_me"
def test_empty_rename_map():
"""Test processor with empty rename map (should pass through unchanged)."""
processor = RenameProcessor(rename_map={})
observation = {
"key1": torch.tensor([1.0]),
"key2": "value2",
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# All keys should be unchanged
assert processed_obs.keys() == observation.keys()
torch.testing.assert_close(processed_obs["key1"], observation["key1"])
assert processed_obs["key2"] == observation["key2"]
def test_none_observation():
"""Test processor with None observation."""
processor = RenameProcessor(rename_map={"old": "new"})
transition = create_transition()
result = processor(transition)
# Should return transition unchanged
assert result == transition
def test_overlapping_rename():
"""Test renaming when new names might conflict."""
rename_map = {
"a": "b",
"b": "c", # This creates a potential conflict
}
processor = RenameProcessor(rename_map=rename_map)
observation = {
"a": 1,
"b": 2,
"x": 3,
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that renaming happens correctly
assert "a" not in processed_obs
assert processed_obs["b"] == 1 # 'a' renamed to 'b'
assert processed_obs["c"] == 2 # original 'b' renamed to 'c'
assert processed_obs["x"] == 3
def test_partial_rename():
"""Test renaming only some keys."""
rename_map = {
"observation.state": "observation.proprio_state",
"pixels": "observation.image",
}
processor = RenameProcessor(rename_map=rename_map)
observation = {
"observation.state": torch.randn(10),
"pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
"reward": 1.0,
"info": {"episode": 1},
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check renamed keys
assert "observation.proprio_state" in processed_obs
assert "observation.image" in processed_obs
assert "observation.state" not in processed_obs
assert "pixels" not in processed_obs
# Check unchanged keys
assert processed_obs["reward"] == 1.0
assert processed_obs["info"] == {"episode": 1}
def test_get_config():
"""Test configuration serialization."""
rename_map = {
"old1": "new1",
"old2": "new2",
}
processor = RenameProcessor(rename_map=rename_map)
config = processor.get_config()
assert config == {"rename_map": rename_map}
def test_state_dict():
"""Test state dict (should be empty for RenameProcessor)."""
processor = RenameProcessor(rename_map={"old": "new"})
state = processor.state_dict()
assert state == {}
# Load state dict should work even with empty dict
processor.load_state_dict({})
def test_integration_with_robot_processor():
"""Test integration with RobotProcessor pipeline."""
rename_map = {
"agent_pos": "observation.state",
"pixels": "observation.image",
}
rename_processor = RenameProcessor(rename_map=rename_map)
pipeline = RobotProcessor([rename_processor])
observation = {
"agent_pos": np.array([1.0, 2.0, 3.0]),
"pixels": np.zeros((32, 32, 3), dtype=np.uint8),
"other_data": "preserve_me",
}
transition = create_transition(
observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={}
)
result = pipeline(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check renaming worked through pipeline
assert "observation.state" in processed_obs
assert "observation.image" in processed_obs
assert "agent_pos" not in processed_obs
assert "pixels" not in processed_obs
assert processed_obs["other_data"] == "preserve_me"
# Check other transition elements unchanged
assert result[TransitionKey.REWARD] == 0.5
assert result[TransitionKey.DONE] is False
def test_save_and_load_pretrained():
"""Test saving and loading processor with RobotProcessor."""
rename_map = {
"old_state": "observation.state",
"old_image": "observation.image",
}
processor = RenameProcessor(rename_map=rename_map)
pipeline = RobotProcessor([processor], name="TestRenameProcessor")
with tempfile.TemporaryDirectory() as tmp_dir:
# Save pipeline
pipeline.save_pretrained(tmp_dir)
# Check files were created
config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor"
assert config_path.exists()
# No state files should be created for RenameProcessor
state_files = list(Path(tmp_dir).glob("*.safetensors"))
assert len(state_files) == 0
# Load pipeline
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
assert loaded_pipeline.name == "TestRenameProcessor"
assert len(loaded_pipeline) == 1
# Check that loaded processor works correctly
loaded_processor = loaded_pipeline.steps[0]
assert isinstance(loaded_processor, RenameProcessor)
assert loaded_processor.rename_map == rename_map
# Test functionality after loading
observation = {"old_state": [1, 2, 3], "old_image": "image_data"}
transition = create_transition(observation=observation)
result = loaded_pipeline(transition)
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs
assert "observation.image" in processed_obs
assert processed_obs["observation.state"] == [1, 2, 3]
assert processed_obs["observation.image"] == "image_data"
def test_registry_functionality():
"""Test that RenameProcessor is properly registered."""
# Check that it's registered
assert "rename_processor" in ProcessorStepRegistry.list()
# Get from registry
retrieved_class = ProcessorStepRegistry.get("rename_processor")
assert retrieved_class is RenameProcessor
# Create instance from registry
instance = retrieved_class(rename_map={"old": "new"})
assert isinstance(instance, RenameProcessor)
assert instance.rename_map == {"old": "new"}
def test_registry_based_save_load():
"""Test save/load using registry name instead of module path."""
processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
pipeline = RobotProcessor([processor])
with tempfile.TemporaryDirectory() as tmp_dir:
# Save and load
pipeline.save_pretrained(tmp_dir)
# Verify config uses registry name
import json
with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor"
config = json.load(f)
assert "registry_name" in config["steps"][0]
assert config["steps"][0]["registry_name"] == "rename_processor"
assert "class" not in config["steps"][0] # Should use registry, not module path
# Load should work
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
loaded_processor = loaded_pipeline.steps[0]
assert isinstance(loaded_processor, RenameProcessor)
assert loaded_processor.rename_map == {"key1": "renamed_key1"}
def test_chained_rename_processors():
"""Test multiple RenameProcessors in a pipeline."""
# First processor: rename raw keys to intermediate format
processor1 = RenameProcessor(
rename_map={
"pos": "agent_position",
"img": "camera_image",
}
)
# Second processor: rename to final format
processor2 = RenameProcessor(
rename_map={
"agent_position": "observation.state",
"camera_image": "observation.image",
}
)
pipeline = RobotProcessor([processor1, processor2])
observation = {
"pos": np.array([1.0, 2.0]),
"img": "image_data",
"extra": "keep_me",
}
transition = create_transition(observation=observation)
# Step through to see intermediate results
results = list(pipeline.step_through(transition))
# After first processor
assert "agent_position" in results[1][TransitionKey.OBSERVATION]
assert "camera_image" in results[1][TransitionKey.OBSERVATION]
# After second processor
final_obs = results[2][TransitionKey.OBSERVATION]
assert "observation.state" in final_obs
assert "observation.image" in final_obs
assert final_obs["extra"] == "keep_me"
# Original keys should be gone
assert "pos" not in final_obs
assert "img" not in final_obs
assert "agent_position" not in final_obs
assert "camera_image" not in final_obs
def test_nested_observation_rename():
"""Test renaming with nested observation structures."""
rename_map = {
"observation.images.left": "observation.camera.left_view",
"observation.images.right": "observation.camera.right_view",
"observation.proprio": "observation.proprioception",
}
processor = RenameProcessor(rename_map=rename_map)
observation = {
"observation.images.left": torch.randn(3, 64, 64),
"observation.images.right": torch.randn(3, 64, 64),
"observation.proprio": torch.randn(7),
"observation.gripper": torch.tensor([0.0]), # Not renamed
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check renames
assert "observation.camera.left_view" in processed_obs
assert "observation.camera.right_view" in processed_obs
assert "observation.proprioception" in processed_obs
# Check unchanged key
assert "observation.gripper" in processed_obs
# Check old keys removed
assert "observation.images.left" not in processed_obs
assert "observation.images.right" not in processed_obs
assert "observation.proprio" not in processed_obs
def test_value_types_preserved():
"""Test that various value types are preserved during renaming."""
rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"}
processor = RenameProcessor(rename_map=rename_map)
tensor_value = torch.randn(3, 3)
array_value = np.random.rand(2, 2)
observation = {
"old_tensor": tensor_value,
"old_array": array_value,
"old_scalar": 42,
"old_string": "hello",
"old_dict": {"nested": "value"},
"old_list": [1, 2, 3],
}
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
# Check that values and types are preserved
assert torch.equal(processed_obs["new_tensor"], tensor_value)
assert np.array_equal(processed_obs["new_array"], array_value)
assert processed_obs["new_scalar"] == 42
assert processed_obs["old_string"] == "hello"
assert processed_obs["old_dict"] == {"nested": "value"}
assert processed_obs["old_list"] == [1, 2, 3]
def test_feature_contract_basic_renaming(policy_feature_factory):
processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
features = {
"a": policy_feature_factory(FeatureType.STATE, (2,)),
"b": policy_feature_factory(FeatureType.ACTION, (3,)),
"c": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = processor.feature_contract(features.copy())
# Values preserved and typed
assert out["x"] == features["a"]
assert out["y"] == features["b"]
assert out["c"] == features["c"]
assert_contract_is_typed(out)
# Input not mutated
assert set(features) == {"a", "b", "c"}
def test_feature_contract_overlapping_keys(policy_feature_factory):
# Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
processor = RenameProcessor(rename_map={"a": "b", "b": "c"})
features = {
"a": policy_feature_factory(FeatureType.STATE, (1,)),
"b": policy_feature_factory(FeatureType.STATE, (2,)),
}
out = processor.feature_contract(features)
assert set(out) == {"b", "c"}
assert out["b"] == features["a"] # 'a' renamed to'b'
assert out["c"] == features["b"] # 'b' renamed to 'c'
assert_contract_is_typed(out)
def test_feature_contract_chained_processors(policy_feature_factory):
# Chain two rename processors at the contract level
processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
processor2 = RenameProcessor(
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
)
pipeline = RobotProcessor([processor1, processor2])
spec = {
"pos": policy_feature_factory(FeatureType.STATE, (7,)),
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"extra": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = pipeline.feature_contract(initial_features=spec)
assert set(out) == {"observation.state", "observation.image", "extra"}
assert out["observation.state"] == spec["pos"]
assert out["observation.image"] == spec["img"]
assert out["extra"] == spec["extra"]
assert_contract_is_typed(out)
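# --- Illustrative sketch (not part of the original tests): the renaming
# semantics test_overlapping_rename and test_feature_contract_overlapping_keys
# pin down. Renames are applied simultaneously from the original dict (so
# {"a": "b", "b": "c"} does not cascade) and unmapped keys pass through.
# The helper name is hypothetical.
def rename_keys(obs: dict, rename_map: dict) -> dict:
    return {rename_map.get(key, key): value for key, value in obs.items()}

assert rename_keys({"a": 1, "b": 2, "x": 3}, {"a": "b", "b": "c"}) == {"b": 1, "c": 2, "x": 3}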

View File

@@ -1,208 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent import futures
from unittest.mock import patch
import pytest
import torch
from torch.multiprocessing import Event, Queue
from lerobot.utils.transition import Transition
from tests.utils import require_package
def create_learner_service_stub():
import grpc
from lerobot.transport import services_pb2, services_pb2_grpc
class MockLearnerService(services_pb2_grpc.LearnerServiceServicer):
def __init__(self):
self.ready_call_count = 0
self.should_fail = False
def Ready(self, request, context): # noqa: N802
self.ready_call_count += 1
if self.should_fail:
context.set_code(grpc.StatusCode.UNAVAILABLE)
context.set_details("Service unavailable")
raise grpc.RpcError("Service unavailable")
return services_pb2.Empty()
"""Fixture to start a LearnerService gRPC server and provide a connected stub."""
servicer = MockLearnerService()
# Create a gRPC server and add our servicer to it.
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server)
port = server.add_insecure_port("[::]:0") # bind to a free port chosen by OS
server.start() # start the server (non-blocking call)
# Create a client channel and stub connected to the server's port.
channel = grpc.insecure_channel(f"localhost:{port}")
return services_pb2_grpc.LearnerServiceStub(channel), servicer, channel, server
def close_service_stub(channel, server):
channel.close()
server.stop(None)
@require_package("grpc")
def test_establish_learner_connection_success():
from lerobot.scripts.rl.actor import establish_learner_connection
"""Test successful connection establishment."""
stub, _servicer, channel, server = create_learner_service_stub()
shutdown_event = Event()
# Test successful connection
result = establish_learner_connection(stub, shutdown_event, attempts=5)
assert result is True
close_service_stub(channel, server)
@require_package("grpc")
def test_establish_learner_connection_failure():
from lerobot.scripts.rl.actor import establish_learner_connection
"""Test connection failure."""
stub, servicer, channel, server = create_learner_service_stub()
servicer.should_fail = True
shutdown_event = Event()
# Test failed connection
with patch("time.sleep"): # Speed up the test
result = establish_learner_connection(stub, shutdown_event, attempts=2)
assert result is False
close_service_stub(channel, server)
@require_package("grpc")
def test_push_transitions_to_transport_queue():
from lerobot.scripts.rl.actor import push_transitions_to_transport_queue
from lerobot.transport.utils import bytes_to_transitions
from tests.transport.test_transport_utils import assert_transitions_equal
"""Test pushing transitions to transport queue."""
# Create mock transitions
transitions = []
for i in range(3):
transition = Transition(
state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
action=torch.randn(5),
reward=torch.tensor(1.0 + i),
done=torch.tensor(False),
truncated=torch.tensor(False),
next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
complementary_info={"step": torch.tensor(i)},
)
transitions.append(transition)
transitions_queue = Queue()
# Test pushing transitions
push_transitions_to_transport_queue(transitions, transitions_queue)
# Verify the data can be retrieved
serialized_data = transitions_queue.get()
assert isinstance(serialized_data, bytes)
deserialized_transitions = bytes_to_transitions(serialized_data)
assert len(deserialized_transitions) == len(transitions)
for i, deserialized_transition in enumerate(deserialized_transitions):
assert_transitions_equal(deserialized_transition, transitions[i])
@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_transitions_stream():
from lerobot.scripts.rl.actor import transitions_stream
"""Test transitions stream functionality."""
shutdown_event = Event()
transitions_queue = Queue()
# Add test data to queue
test_data = [b"transition_data_1", b"transition_data_2", b"transition_data_3"]
for data in test_data:
transitions_queue.put(data)
# Collect streamed data
streamed_data = []
stream_generator = transitions_stream(shutdown_event, transitions_queue, 0.1)
# Process a few items
for i, message in enumerate(stream_generator):
streamed_data.append(message)
if i >= len(test_data) - 1:
shutdown_event.set()
break
# Verify we got messages
assert len(streamed_data) == len(test_data)
assert streamed_data[0].data == b"transition_data_1"
assert streamed_data[1].data == b"transition_data_2"
assert streamed_data[2].data == b"transition_data_3"
@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_interactions_stream():
from lerobot.scripts.rl.actor import interactions_stream
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
"""Test interactions stream functionality."""
shutdown_event = Event()
interactions_queue = Queue()
# Create test interaction data (similar structure to what would be sent)
test_interactions = [
{"episode_reward": 10.5, "step": 1, "policy_fps": 30.2},
{"episode_reward": 15.2, "step": 2, "policy_fps": 28.7},
{"episode_reward": 8.7, "step": 3, "policy_fps": 29.1},
]
# Serialize the interaction data as it would be in practice
for interaction in test_interactions:
    interactions_queue.put(python_object_to_bytes(interaction))
# Collect streamed data
streamed_data = []
stream_generator = interactions_stream(shutdown_event, interactions_queue, 0.1)
# Process the items
for i, message in enumerate(stream_generator):
streamed_data.append(message)
if i >= len(test_interactions) - 1:
shutdown_event.set()
break
# Verify we got messages
assert len(streamed_data) == len(test_interactions)
# Verify the messages can be deserialized back to original data
for i, message in enumerate(streamed_data):
deserialized_interaction = bytes_to_python_object(message.data)
assert deserialized_interaction == test_interactions[i]
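# --- Illustrative sketch (not part of the original tests): the queue-draining
# generator pattern that transitions_stream/interactions_stream exercise above.
# The real functions wrap each payload in a protobuf message with a `.data`
# field; this standalone version yields raw bytes. Names are assumptions.
import queue

def stream_from_queue(shutdown_event, q, get_timeout: float):
    # Yield queued payloads until shutdown; the timed get keeps the loop
    # responsive to the shutdown event instead of blocking forever. Both
    # queue.Queue and multiprocessing.Queue raise queue.Empty on timeout.
    while not shutdown_event.is_set():
        try:
            yield q.get(timeout=get_timeout)
        except queue.Empty:
            continue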

View File

@@ -1,297 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import threading
import time
import pytest
import torch
from torch.multiprocessing import Event, Queue
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.utils.transition import Transition
from tests.utils import require_package
def create_test_transitions(count: int = 3) -> list[Transition]:
"""Create test transitions for integration testing."""
transitions = []
for i in range(count):
transition = Transition(
state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
action=torch.randn(5),
reward=torch.tensor(1.0 + i),
done=torch.tensor(i == count - 1), # Last transition is done
truncated=torch.tensor(False),
next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
complementary_info={"step": torch.tensor(i), "episode_id": i // 2},
)
transitions.append(transition)
return transitions
def create_test_interactions(count: int = 3) -> list[dict]:
"""Create test interactions for integration testing."""
interactions = []
for i in range(count):
interaction = {
"episode_reward": 10.0 + i * 5,
"step": i * 100,
"policy_fps": 30.0 + i,
"intervention_rate": 0.1 * i,
"episode_length": 200 + i * 50,
}
interactions.append(interaction)
return interactions
def find_free_port():
"""Finds a free port on the local machine."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0)) # Bind to port 0 to let the OS choose a free port
s.listen(1)
port = s.getsockname()[1]
return port
@pytest.fixture
def cfg():
cfg = TrainRLServerPipelineConfig()
port = find_free_port()
policy_cfg = SACConfig()
policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
policy_cfg.actor_learner_config.learner_port = port
policy_cfg.concurrency.actor = "threads"
policy_cfg.concurrency.learner = "threads"
policy_cfg.actor_learner_config.queue_get_timeout = 0.1
cfg.policy = policy_cfg
return cfg
@require_package("grpc")
@pytest.mark.timeout(10) # force cross-platform watchdog
def test_end_to_end_transitions_flow(cfg):
from lerobot.scripts.rl.actor import (
establish_learner_connection,
learner_service_client,
push_transitions_to_transport_queue,
send_transitions,
)
from lerobot.scripts.rl.learner import start_learner
from lerobot.transport.utils import bytes_to_transitions
from tests.transport.test_transport_utils import assert_transitions_equal
"""Test complete transitions flow from actor to learner."""
transitions_actor_queue = Queue()
transitions_learner_queue = Queue()
interactions_queue = Queue()
parameters_queue = Queue()
shutdown_event = Event()
learner_thread = threading.Thread(
target=start_learner,
args=(parameters_queue, transitions_learner_queue, interactions_queue, shutdown_event, cfg),
)
learner_thread.start()
policy_cfg = cfg.policy
learner_client, channel = learner_service_client(
host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
)
assert establish_learner_connection(learner_client, shutdown_event, attempts=5)
send_transitions_thread = threading.Thread(
target=send_transitions, args=(cfg, transitions_actor_queue, shutdown_event, learner_client, channel)
)
send_transitions_thread.start()
input_transitions = create_test_transitions(count=5)
push_transitions_to_transport_queue(input_transitions, transitions_actor_queue)
# Wait for learner to start
time.sleep(0.1)
shutdown_event.set()
# Wait for learner to receive transitions
learner_thread.join()
send_transitions_thread.join()
channel.close()
received_transitions = []
while not transitions_learner_queue.empty():
received_transitions.extend(bytes_to_transitions(transitions_learner_queue.get()))
assert len(received_transitions) == len(input_transitions)
for i, transition in enumerate(received_transitions):
assert_transitions_equal(transition, input_transitions[i])
@require_package("grpc")
@pytest.mark.timeout(10)
def test_end_to_end_interactions_flow(cfg):
from lerobot.scripts.rl.actor import (
establish_learner_connection,
learner_service_client,
send_interactions,
)
from lerobot.scripts.rl.learner import start_learner
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
"""Test complete interactions flow from actor to learner."""
# Queues for actor-learner communication
interactions_actor_queue = Queue()
interactions_learner_queue = Queue()
# Other queues required by the learner
parameters_queue = Queue()
transitions_learner_queue = Queue()
shutdown_event = Event()
# Start the learner in a separate thread
learner_thread = threading.Thread(
target=start_learner,
args=(parameters_queue, transitions_learner_queue, interactions_learner_queue, shutdown_event, cfg),
)
learner_thread.start()
# Establish connection from actor to learner
policy_cfg = cfg.policy
learner_client, channel = learner_service_client(
host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
)
assert establish_learner_connection(learner_client, shutdown_event, attempts=5)
# Start the actor's interaction sending process in a separate thread
send_interactions_thread = threading.Thread(
target=send_interactions,
args=(cfg, interactions_actor_queue, shutdown_event, learner_client, channel),
)
send_interactions_thread.start()
# Create and push test interactions to the actor's queue
input_interactions = create_test_interactions(count=5)
for interaction in input_interactions:
interactions_actor_queue.put(python_object_to_bytes(interaction))
# Wait for the communication to happen
time.sleep(0.1)
# Signal shutdown and wait for threads to complete
shutdown_event.set()
learner_thread.join()
send_interactions_thread.join()
channel.close()
# Verify that the learner received the interactions
received_interactions = []
while not interactions_learner_queue.empty():
received_interactions.append(bytes_to_python_object(interactions_learner_queue.get()))
assert len(received_interactions) == len(input_interactions)
# Sort by a unique key to handle potential reordering in queues
received_interactions.sort(key=lambda x: x["step"])
input_interactions.sort(key=lambda x: x["step"])
for received, expected in zip(received_interactions, input_interactions, strict=False):
assert received == expected
@require_package("grpc")
@pytest.mark.parametrize("data_size", ["small", "large"])
@pytest.mark.timeout(10)
def test_end_to_end_parameters_flow(cfg, data_size):
from lerobot.scripts.rl.actor import establish_learner_connection, learner_service_client, receive_policy
from lerobot.scripts.rl.learner import start_learner
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
"""Test complete parameter flow from learner to actor, with small and large data."""
# Actor's local queue to receive params
parameters_actor_queue = Queue()
# Learner's queue to send params from
parameters_learner_queue = Queue()
# Other queues required by the learner
transitions_learner_queue = Queue()
interactions_learner_queue = Queue()
shutdown_event = Event()
# Start the learner in a separate thread
learner_thread = threading.Thread(
target=start_learner,
args=(
parameters_learner_queue,
transitions_learner_queue,
interactions_learner_queue,
shutdown_event,
cfg,
),
)
learner_thread.start()
# Establish connection from actor to learner
policy_cfg = cfg.policy
learner_client, channel = learner_service_client(
host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
)
assert establish_learner_connection(learner_client, shutdown_event, attempts=5)
# Start the actor's parameter receiving process in a separate thread
receive_params_thread = threading.Thread(
target=receive_policy,
args=(cfg, parameters_actor_queue, shutdown_event, learner_client, channel),
)
receive_params_thread.start()
# Create test parameters based on parametrization
if data_size == "small":
input_params = {"layer.weight": torch.randn(128, 64)}
else: # "large"
# CHUNK_SIZE is 2MB, so this tensor (4MB) will force chunking
input_params = {"large_layer.weight": torch.randn(1024, 1024)}
# Simulate learner having new parameters to send
parameters_learner_queue.put(state_to_bytes(input_params))
# Wait for the actor to receive the parameters
time.sleep(0.1)
# Signal shutdown and wait for threads to complete
shutdown_event.set()
learner_thread.join()
receive_params_thread.join()
channel.close()
# Verify that the actor received the parameters correctly
received_params = bytes_to_state_dict(parameters_actor_queue.get())
assert received_params.keys() == input_params.keys()
for key in input_params:
assert torch.allclose(received_params[key], input_params[key])
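# --- Illustrative sketch (not part of the original tests): why the "large"
# case above forces chunking. A float32 tensor of shape (1024, 1024) serializes
# to roughly 4 MiB, so with a 2 MiB chunk size (per the comment in the test)
# the payload must be split before streaming. The helper is hypothetical.
CHUNK_SIZE = 2 * 1024 * 1024  # 2 MiB

def split_into_chunks(payload: bytes, chunk_size: int = CHUNK_SIZE) -> list[bytes]:
    return [payload[i : i + chunk_size] for i in range(0, len(payload), chunk_size)]

assert len(split_into_chunks(b"\x00" * (4 * 1024 * 1024))) == 2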

View File

@@ -1,374 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from concurrent import futures
from multiprocessing import Event, Queue
import pytest
from tests.utils import require_package # our gRPC servicer class
@pytest.fixture(scope="function")
def learner_service_stub():
shutdown_event = Event()
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
seconds_between_pushes = 1
client, channel, server = create_learner_service_stub(
shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
)
yield client # provide the stub to the test function
close_learner_service_stub(channel, server)
@require_package("grpc")
def create_learner_service_stub(
shutdown_event: Event,
parameters_queue: Queue,
transitions_queue: Queue,
interactions_queue: Queue,
seconds_between_pushes: int,
queue_get_timeout: float = 0.1,
):
import grpc
from lerobot.scripts.rl.learner_service import LearnerService
from lerobot.transport import services_pb2_grpc # generated from .proto
"""Fixture to start a LearnerService gRPC server and provide a connected stub."""
servicer = LearnerService(
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
seconds_between_pushes=seconds_between_pushes,
transition_queue=transitions_queue,
interaction_message_queue=interactions_queue,
queue_get_timeout=queue_get_timeout,
)
# Create a gRPC server and add our servicer to it.
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server)
port = server.add_insecure_port("[::]:0") # bind to a free port chosen by OS
server.start() # start the server (non-blocking call)
# Create a client channel and stub connected to the server's port.
channel = grpc.insecure_channel(f"localhost:{port}")
return services_pb2_grpc.LearnerServiceStub(channel), channel, server
@require_package("grpc")
def close_learner_service_stub(channel, server):
channel.close()
server.stop(None)
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_ready_method(learner_service_stub):
from lerobot.transport import services_pb2
"""Test the ready method of the UserService."""
request = services_pb2.Empty()
response = learner_service_stub.Ready(request)
assert response == services_pb2.Empty()
@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_interactions():
from lerobot.transport import services_pb2
shutdown_event = Event()
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
seconds_between_pushes = 1
client, channel, server = create_learner_service_stub(
shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
)
list_of_interaction_messages = [
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"1"),
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"2"),
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"3"),
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"4"),
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"5"),
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"6"),
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"7"),
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"8"),
]
def mock_interactions_stream():
yield from list_of_interaction_messages
return services_pb2.Empty()
response = client.SendInteractions(mock_interactions_stream())
assert response == services_pb2.Empty()
close_learner_service_stub(channel, server)
# Extract the data from the interactions queue
interactions = []
while not interactions_queue.empty():
interactions.append(interactions_queue.get())
assert interactions == [b"123", b"4", b"5", b"678"]
@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_transitions():
from lerobot.transport import services_pb2
"""Test the SendTransitions method with various transition data."""
shutdown_event = Event()
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
seconds_between_pushes = 1
client, channel, server = create_learner_service_stub(
shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
)
# Create test transition messages
list_of_transition_messages = [
services_pb2.Transition(
transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"transition_1"
),
services_pb2.Transition(
transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"transition_2"
),
services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"transition_3"),
services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"batch_1"),
services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"batch_2"),
]
def mock_transitions_stream():
yield from list_of_transition_messages
response = client.SendTransitions(mock_transitions_stream())
assert response == services_pb2.Empty()
close_learner_service_stub(channel, server)
# Extract the data from the transitions queue
transitions = []
while not transitions_queue.empty():
transitions.append(transitions_queue.get())
# Should have assembled the chunked data
assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"]

@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_transitions_empty_stream():
    """Test SendTransitions with an empty stream."""
    from lerobot.transport import services_pb2

shutdown_event = Event()
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
seconds_between_pushes = 1
client, channel, server = create_learner_service_stub(
shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
)
    def empty_stream():
        return iter([])

    response = client.SendTransitions(empty_stream())
assert response == services_pb2.Empty()
close_learner_service_stub(channel, server)
# Queue should remain empty
assert transitions_queue.empty()

@require_package("grpc")
@pytest.mark.timeout(10) # force cross-platform watchdog
def test_stream_parameters():
    """Test the StreamParameters method."""
    import time

    from lerobot.transport import services_pb2

shutdown_event = Event()
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
seconds_between_pushes = 0.2 # Short delay for testing
client, channel, server = create_learner_service_stub(
shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
)
# Add test parameters to the queue
test_params = [b"param_batch_1", b"param_batch_2"]
for param in test_params:
parameters_queue.put(param)
# Start streaming parameters
request = services_pb2.Empty()
stream = client.StreamParameters(request)
# Collect streamed parameters and timestamps
received_params = []
timestamps = []
for response in stream:
received_params.append(response.data)
timestamps.append(time.time())
        # Only the most recent parameters should be streamed: the stale
        # b"param_batch_1" is skipped in favor of b"param_batch_2".
break
parameters_queue.put(b"param_batch_3")
for response in stream:
received_params.append(response.data)
timestamps.append(time.time())
# We should receive only one item
break
shutdown_event.set()
close_learner_service_stub(channel, server)
assert received_params == [b"param_batch_2", b"param_batch_3"]
# Check the time difference between the two sends
time_diff = timestamps[1] - timestamps[0]
# Check if the time difference is close to the expected push frequency
assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1)

@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_stream_parameters_with_shutdown():
    """Test that StreamParameters handles shutdown gracefully."""
    import threading
    import time

    from lerobot.transport import services_pb2

shutdown_event = Event()
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
seconds_between_pushes = 0.1
queue_get_timeout = 0.001
client, channel, server = create_learner_service_stub(
shutdown_event,
parameters_queue,
transitions_queue,
interactions_queue,
seconds_between_pushes,
queue_get_timeout=queue_get_timeout,
)
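    # A small queue_get_timeout bounds each parameter-queue read on the server
    # side, so the streaming loop can notice shutdown_event promptly.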
test_params = [b"param_batch_1", b"stop", b"param_batch_3", b"param_batch_4"]
# create a thread that will put the parameters in the queue
def producer():
for param in test_params:
parameters_queue.put(param)
time.sleep(0.1)
producer_thread = threading.Thread(target=producer)
producer_thread.start()
# Start streaming
request = services_pb2.Empty()
stream = client.StreamParameters(request)
# Collect streamed parameters
received_params = []
for response in stream:
received_params.append(response.data)
if response.data == b"stop":
shutdown_event.set()
producer_thread.join()
close_learner_service_stub(channel, server)
assert received_params == [b"param_batch_1", b"stop"]

@require_package("grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_stream_parameters_waits_and_retries_on_empty_queue():
    """Test that StreamParameters waits and retries when the queue is empty."""
    import threading
    import time

    from lerobot.transport import services_pb2

shutdown_event = Event()
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
seconds_between_pushes = 0.05
queue_get_timeout = 0.01
client, channel, server = create_learner_service_stub(
shutdown_event,
parameters_queue,
transitions_queue,
interactions_queue,
seconds_between_pushes,
queue_get_timeout=queue_get_timeout,
)
request = services_pb2.Empty()
stream = client.StreamParameters(request)
received_params = []
def producer():
# Let the consumer start and find an empty queue.
# It will wait `seconds_between_pushes` (0.05s), then `get` will timeout after `queue_get_timeout` (0.01s).
# Total time for the first empty loop is > 0.06s. We wait a bit longer to be safe.
time.sleep(0.06)
parameters_queue.put(b"param_after_wait")
time.sleep(0.05)
parameters_queue.put(b"param_after_wait_2")
producer_thread = threading.Thread(target=producer)
producer_thread.start()
# The consumer will block here until the producer sends an item.
for response in stream:
received_params.append(response.data)
if response.data == b"param_after_wait_2":
            break  # Both expected items have been received; stop consuming.
shutdown_event.set()
producer_thread.join()
close_learner_service_stub(channel, server)
assert received_params == [b"param_after_wait", b"param_after_wait_2"]

View File

@ -1,111 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from lerobot.robots.so100_follower import (
SO100Follower,
SO100FollowerConfig,
)

def _make_bus_mock() -> MagicMock:
"""Return a bus mock with just the attributes used by the robot."""
bus = MagicMock(name="FeetechBusMock")
bus.is_connected = False

    def _connect():
        bus.is_connected = True

    def _disconnect(_disable=True):
        bus.is_connected = False

    bus.connect.side_effect = _connect
bus.disconnect.side_effect = _disconnect
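
    # The robot calls `bus.torque_disabled()` as a context manager, so the
    # mock returns a no-op context manager.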
@contextmanager
def _dummy_cm():
yield
bus.torque_disabled.side_effect = _dummy_cm
return bus

@pytest.fixture
def follower():
bus_mock = _make_bus_mock()
def _bus_side_effect(*_args, **kwargs):
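        # Capture the motors mapping the robot passes to the bus constructor,
        # then have sync_read report one dummy position per motor (1, 2, ...).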
bus_mock.motors = kwargs["motors"]
motors_order: list[str] = list(bus_mock.motors)
bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)}
bus_mock.sync_write.return_value = None
bus_mock.write.return_value = None
bus_mock.disable_torque.return_value = None
bus_mock.enable_torque.return_value = None
bus_mock.is_calibrated = True
return bus_mock
with (
patch(
"lerobot.robots.so100_follower.so100_follower.FeetechMotorsBus",
side_effect=_bus_side_effect,
),
patch.object(SO100Follower, "configure", lambda self: None),
):
cfg = SO100FollowerConfig(port="/dev/null")
robot = SO100Follower(cfg)
yield robot
if robot.is_connected:
robot.disconnect()

def test_connect_disconnect(follower):
assert not follower.is_connected
follower.connect()
assert follower.is_connected
follower.disconnect()
assert not follower.is_connected

def test_get_observation(follower):
follower.connect()
obs = follower.get_observation()
expected_keys = {f"{m}.pos" for m in follower.bus.motors}
assert set(obs.keys()) == expected_keys
for idx, motor in enumerate(follower.bus.motors, 1):
assert obs[f"{motor}.pos"] == idx

def test_send_action(follower):
follower.connect()
action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)}
returned = follower.send_action(action)
assert returned == action
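    # sync_write receives the goal positions keyed by raw motor name (no
    # ".pos" suffix), with the same values that were sent in `action`.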
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)}
follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)

View File

@ -1,60 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import gymnasium as gym
import pytest
import lerobot
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
from tests.utils import require_env

@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
@require_env
def test_available_env_task(env_name: str, task_name: str):
    """
    This test verifies that every environment listed in `lerobot/__init__.py`
    can be imported when it is installed, and that each task in
    `available_tasks_per_env` maps to a registered gym environment.
    """
package_name = f"gym_{env_name}"
importlib.import_module(package_name)
gym_handle = f"{package_name}/{task_name}"
assert gym_handle in gym.envs.registry, gym_handle

def test_available_policies():
"""
This test verifies that the class attribute `name` for all policies is
consistent with those listed in `lerobot/__init__.py`.
"""
policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy]
policies = [pol_cls.name for pol_cls in policy_classes]
assert set(policies) == set(lerobot.available_policies), policies

def test_print():
print(lerobot.available_envs)
print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets)
print(lerobot.available_datasets_per_env)
print(lerobot.available_real_world_datasets)
print(lerobot.available_policies)
print(lerobot.available_policies_per_env)

View File

@ -1,106 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.calibrate import CalibrateConfig, calibrate
from lerobot.record import DatasetRecordConfig, RecordConfig, record
from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay
from lerobot.teleoperate import TeleoperateConfig, teleoperate
from tests.fixtures.constants import DUMMY_REPO_ID
from tests.mocks.mock_robot import MockRobotConfig
from tests.mocks.mock_teleop import MockTeleopConfig

def test_calibrate():
robot_cfg = MockRobotConfig()
cfg = CalibrateConfig(robot=robot_cfg)
calibrate(cfg)

def test_teleoperate():
robot_cfg = MockRobotConfig()
teleop_cfg = MockTeleopConfig()
cfg = TeleoperateConfig(
robot=robot_cfg,
teleop=teleop_cfg,
teleop_time_s=0.1,
)
teleoperate(cfg)

def test_record_and_resume(tmp_path):
robot_cfg = MockRobotConfig()
teleop_cfg = MockTeleopConfig()
dataset_cfg = DatasetRecordConfig(
repo_id=DUMMY_REPO_ID,
single_task="Dummy task",
root=tmp_path / "record",
num_episodes=1,
episode_time_s=0.1,
reset_time_s=0,
push_to_hub=False,
)
cfg = RecordConfig(
robot=robot_cfg,
dataset=dataset_cfg,
teleop=teleop_cfg,
play_sounds=False,
)
dataset = record(cfg)
assert dataset.fps == 30
assert dataset.meta.total_episodes == dataset.num_episodes == 1
assert dataset.meta.total_frames == dataset.num_frames == 3
assert dataset.meta.total_tasks == 1
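
    # Resuming appends to the existing dataset on disk instead of starting over.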
cfg.resume = True
dataset = record(cfg)
assert dataset.meta.total_episodes == dataset.num_episodes == 2
assert dataset.meta.total_frames == dataset.num_frames == 6
assert dataset.meta.total_tasks == 1

def test_record_and_replay(tmp_path):
robot_cfg = MockRobotConfig()
teleop_cfg = MockTeleopConfig()
record_dataset_cfg = DatasetRecordConfig(
repo_id=DUMMY_REPO_ID,
single_task="Dummy task",
root=tmp_path / "record_and_replay",
num_episodes=1,
episode_time_s=0.1,
push_to_hub=False,
)
record_cfg = RecordConfig(
robot=robot_cfg,
dataset=record_dataset_cfg,
teleop=teleop_cfg,
play_sounds=False,
)
replay_dataset_cfg = DatasetReplayConfig(
repo_id=DUMMY_REPO_ID,
episode=0,
root=tmp_path / "record_and_replay",
)
replay_cfg = ReplayConfig(
robot=robot_cfg,
dataset=replay_dataset_cfg,
play_sounds=False,
)
record(record_cfg)
replay(replay_cfg)

Some files were not shown because too many files have changed in this diff