diff --git a/.gitignore b/.gitignore index 5793aa4..8f4c8e1 100644 --- a/.gitignore +++ b/.gitignore @@ -180,4 +180,7 @@ s100 huggingface_models docker/inputs -docker/outputs \ No newline at end of file +docker/outputs + +# Skip big files in tests folder +tests \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index f52df1b..0000000 --- a/tests/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/artifacts/cameras/image_128x128.png b/tests/artifacts/cameras/image_128x128.png deleted file mode 100644 index b117f49..0000000 --- a/tests/artifacts/cameras/image_128x128.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9dc9df05797dc0e7b92edc845caab2e4c37c3cfcabb4ee6339c67212b5baba3b -size 38023 diff --git a/tests/artifacts/cameras/image_160x120.png b/tests/artifacts/cameras/image_160x120.png deleted file mode 100644 index cdc681d..0000000 --- a/tests/artifacts/cameras/image_160x120.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7e11af87616b83c1cdb30330e951b91e86b51c64a1326e1ba5b4a3fbcdec1a11 -size 55698 diff --git a/tests/artifacts/cameras/image_320x180.png b/tests/artifacts/cameras/image_320x180.png deleted file mode 100644 index 4cfd511..0000000 --- a/tests/artifacts/cameras/image_320x180.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b8840fb643afe903191248703b1f95a57faf5812ecd9978ac502ee939646fdb2 -size 121115 diff --git a/tests/artifacts/cameras/image_480x270.png b/tests/artifacts/cameras/image_480x270.png deleted file mode 100644 index b564d54..0000000 --- a/tests/artifacts/cameras/image_480x270.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f79d14daafb1c0cf2fec5d46ee8029a73fe357402fdd31a7cd4a4794d7319a7c -size 260367 diff --git a/tests/artifacts/cameras/test_rs.bag b/tests/artifacts/cameras/test_rs.bag deleted file mode 100644 index 1b9662c..0000000 --- a/tests/artifacts/cameras/test_rs.bag +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a8d6e64d6cb0e02c94ae125630ee758055bd2e695772c0463a30d63ddc6c5e17 -size 3520862 diff --git a/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_0.safetensors b/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_0.safetensors deleted file mode 100644 index 1b1994c..0000000 --- a/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_0.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6bdf22208d49cd36d24bc844d4d8bda5e321eafe39d2b470e4fc95c7812fdb24 -size 3687117 diff --git a/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_1.safetensors b/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_1.safetensors deleted file mode 100644 index a36663b..0000000 --- a/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_1.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8920d5ebab36ffcba9aa74dcd91677c121f504b4d945b472352d379f9272fabf -size 3687117 diff --git a/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_250.safetensors b/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_250.safetensors deleted file mode 100644 index b6e6e0e..0000000 --- a/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_250.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:35723f2db499da3d9d121aa79d2ff4c748effd7c2ea92f277ec543a82fb843ca -size 3687117 diff --git a/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_251.safetensors b/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_251.safetensors deleted file mode 100644 index ca750b9..0000000 --- a/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_251.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:53172b773d4a78bb3140f10280105c2c4ebcb467f3097579988d42cb87790ab9 -size 3687117 diff --git a/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_498.safetensors b/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_498.safetensors deleted file mode 100644 index 9eb2e14..0000000 --- a/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_498.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:58a5d91573e7dd2352a1454a5c9118c9ad3798428a0104e5e0b57fc01f780ae7 -size 3687117 diff --git a/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_499.safetensors b/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_499.safetensors deleted file mode 100644 index 849c44b..0000000 --- a/tests/artifacts/datasets/lerobot/aloha_sim_insertion_human/frame_499.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:bb65a25e989a32a8b6258d368bd077e4548379c74ab5ada01cc532d658670df0 -size 3687117 diff --git a/tests/artifacts/datasets/lerobot/pusht/frame_0.safetensors b/tests/artifacts/datasets/lerobot/pusht/frame_0.safetensors deleted file mode 100644 index 0a7ced5..0000000 --- a/tests/artifacts/datasets/lerobot/pusht/frame_0.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c3dcff0a705ebfdaf11b7f49ad85b464eff03477ace3d63ce45d6a3a10b429d5 -size 111338 diff --git a/tests/artifacts/datasets/lerobot/pusht/frame_1.safetensors b/tests/artifacts/datasets/lerobot/pusht/frame_1.safetensors deleted file mode 100644 index f999e25..0000000 --- a/tests/artifacts/datasets/lerobot/pusht/frame_1.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d8ab0274761cdd758bafdf274ce3e6398cd6f0df23393971f3e1b6b465d66ef3 -size 111338 diff --git a/tests/artifacts/datasets/lerobot/pusht/frame_159.safetensors b/tests/artifacts/datasets/lerobot/pusht/frame_159.safetensors deleted file mode 100644 index f49a884..0000000 --- a/tests/artifacts/datasets/lerobot/pusht/frame_159.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:aee60956925da9687546aafa770d5e6a04f99576f903b08d0bd5f8003a7f4f3e -size 111338 diff --git a/tests/artifacts/datasets/lerobot/pusht/frame_160.safetensors b/tests/artifacts/datasets/lerobot/pusht/frame_160.safetensors deleted file mode 100644 index dee72c6..0000000 --- a/tests/artifacts/datasets/lerobot/pusht/frame_160.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c8d9f9cc9e232820760fe4a46b47000c921fa5d868420e55d8dbc05dae56e8bd -size 111338 diff --git a/tests/artifacts/datasets/lerobot/pusht/frame_80.safetensors b/tests/artifacts/datasets/lerobot/pusht/frame_80.safetensors deleted file mode 100644 index 9189c4d..0000000 --- a/tests/artifacts/datasets/lerobot/pusht/frame_80.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:01cfe50c537e3aef0cd5947ec0b15b321b54ecb461baf7b4f2506897158eebc8 -size 111338 diff --git a/tests/artifacts/datasets/lerobot/pusht/frame_81.safetensors b/tests/artifacts/datasets/lerobot/pusht/frame_81.safetensors deleted file mode 100644 index 2537af3..0000000 --- a/tests/artifacts/datasets/lerobot/pusht/frame_81.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:96431ca3479eef2379406ef901cad7ba5eac4f7edcc48ecc9e8d1fa0e99d8017 -size 111338 diff --git a/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_0.safetensors b/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_0.safetensors deleted file mode 100644 index 00db26a..0000000 --- a/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_0.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3763d7bff7873cb40ea9d6f2f98d45fcf163addcd2809b6c59f273b6c3627ad5 -size 85353 diff --git a/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_1.safetensors b/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_1.safetensors deleted file mode 100644 index 6f4b0c0..0000000 --- a/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_1.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:24150994c6959631dc081b43e4001a8664e13b194ac194a32100f7d3fd2c0d0f -size 85353 diff --git a/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_12.safetensors b/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_12.safetensors deleted file mode 100644 index fa42365..0000000 --- a/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_12.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c9c3fdf34debe47d4b80570a19e676185449df749f37daa2111184c1f439ae5f -size 85353 diff --git a/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_13.safetensors b/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_13.safetensors deleted file mode 100644 index c010a48..0000000 --- a/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_13.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f8cfbe444c14d643da2faea9f6a402ddb37114ab15395c381f1a7982e541f868 -size 85353 diff --git a/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_23.safetensors b/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_23.safetensors deleted file mode 100644 index 056f9f1..0000000 --- a/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_23.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:07c5c1a63998884ee747a6d0aa8f49217da3c32af2760dad2a9da794d3517003 -size 85353 diff --git a/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_24.safetensors b/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_24.safetensors deleted file mode 100644 index 41a384d..0000000 --- a/tests/artifacts/datasets/lerobot/xarm_lift_medium/frame_24.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9927ec508e3335f8b10cf3682e41dedb7e647f92a2063a4196f1e48749c47bc5 -size 85353 diff --git a/tests/artifacts/datasets/save_dataset_to_safetensors.py b/tests/artifacts/datasets/save_dataset_to_safetensors.py deleted file mode 100644 index 419961b..0000000 --- a/tests/artifacts/datasets/save_dataset_to_safetensors.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility -when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frame from the -dataset into a corresponding safetensors file in a specified output directory. - -If you know that your change will break backward compatibility, you should write a shortlived test by modifying -`tests/test_datasets.py::test_backward_compatibility` accordingly, and make sure this custom test pass. Your custom test -doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts. - -Example usage: - `python tests/artifacts/datasets/save_dataset_to_safetensors.py` -""" - -import shutil -from pathlib import Path - -from safetensors.torch import save_file - -from lerobot.datasets.lerobot_dataset import LeRobotDataset - - -def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"): - repo_dir = Path(output_dir) / repo_id - - if repo_dir.exists(): - shutil.rmtree(repo_dir) - - repo_dir.mkdir(parents=True, exist_ok=True) - dataset = LeRobotDataset( - repo_id=repo_id, - episodes=[0], - ) - - # save 2 first frames of first episode - i = dataset.episode_data_index["from"][0].item() - save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") - save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") - - # save 2 frames at the middle of first episode - i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) - save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") - save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") - - # save 2 last frames of first episode - i = dataset.episode_data_index["to"][0].item() - save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors") - save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors") - - # TODO(rcadene): Enable testing on second and last episode - # We currently cant because our test dataset only contains the first episode - - # # save 2 first frames of second episode - # i = dataset.episode_data_index["from"][1].item() - # save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") - # save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors") - - # # save 2 last frames of second episode - # i = dataset.episode_data_index["to"][1].item() - # save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors") - # save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors") - - # # save 2 last frames of last episode - # i = dataset.episode_data_index["to"][-1].item() - # save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors") - # save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors") - - -if __name__ == "__main__": - for dataset in [ - "lerobot/pusht", - "lerobot/aloha_sim_insertion_human", - "lerobot/xarm_lift_medium", - "lerobot/nyu_franka_play_dataset", - "lerobot/cmu_stretch", - ]: - save_dataset_to_safetensors("tests/artifacts/datasets", repo_id=dataset) diff --git a/tests/artifacts/image_transforms/default_transforms.safetensors b/tests/artifacts/image_transforms/default_transforms.safetensors deleted file mode 100644 index 2c08499..0000000 --- a/tests/artifacts/image_transforms/default_transforms.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6b1e600768a8771c5fe650e038a1193597e3810f032041b2a0d021e4496381c1 -size 3686488 diff --git a/tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py b/tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py deleted file mode 100644 index ce15d16..0000000 --- a/tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pathlib import Path - -import torch -from safetensors.torch import save_file - -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.transforms import ( - ImageTransformConfig, - ImageTransforms, - ImageTransformsConfig, - make_transform_from_config, -) -from lerobot.utils.random_utils import seeded_context - -ARTIFACT_DIR = Path("tests/artifacts/image_transforms") -DATASET_REPO_ID = "lerobot/aloha_static_cups_open" - - -def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path): - cfg = ImageTransformsConfig(enable=True) - default_tf = ImageTransforms(cfg) - - with seeded_context(1337): - img_tf = default_tf(original_frame) - - save_file({"default": img_tf}, output_dir / "default_transforms.safetensors") - - -def save_single_transforms(original_frame: torch.Tensor, output_dir: Path): - transforms = { - ("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]), - ("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]), - ("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]), - ("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]), - ("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]), - } - - frames = {"original_frame": original_frame} - for tf_type, tf_name, min_max_values in transforms.items(): - for min_max in min_max_values: - tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max}) - tf = make_transform_from_config(tf_cfg) - key = f"{tf_name}_{min_max[0]}_{min_max[1]}" - frames[key] = tf(original_frame) - - save_file(frames, output_dir / "single_transforms.safetensors") - - -def main(): - dataset = LeRobotDataset(DATASET_REPO_ID, episodes=[0], image_transforms=None) - output_dir = Path(ARTIFACT_DIR) - output_dir.mkdir(parents=True, exist_ok=True) - original_frame = dataset[0][dataset.meta.camera_keys[0]] - - save_single_transforms(original_frame, output_dir) - save_default_config_transform(original_frame, output_dir) - - -if __name__ == "__main__": - main() diff --git a/tests/artifacts/image_transforms/single_transforms.safetensors b/tests/artifacts/image_transforms/single_transforms.safetensors deleted file mode 100644 index 7a0599d..0000000 --- a/tests/artifacts/image_transforms/single_transforms.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9d4ebab73eabddc58879a4e770289d19e00a1a4cf2fa5fa33cd3a3246992bc90 -size 40551392 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors deleted file mode 100644 index 8bd63e8..0000000 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77 -size 5104 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_/grad_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_/grad_stats.safetensors deleted file mode 100644 index 5209ae6..0000000 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_/grad_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1a7a8b1a457149109f843c32bcbb047d09de2201847b9b79f7501b447f77ecf4 -size 31672 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_/output_dict.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_/output_dict.safetensors deleted file mode 100644 index 736aff9..0000000 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_/output_dict.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5e6ce85296b2009e7c2060d336c0429b1c7197d9adb159e7df0ba18003067b36 -size 68 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors deleted file mode 100644 index 724d22b..0000000 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603 -size 33400 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors deleted file mode 100644 index 6d912d8..0000000 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b -size 515400 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors deleted file mode 100644 index c58bb44..0000000 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:224b5fa4828aa88171b68c036e8919c1eae563e2113f03b6461eadf5bf8525a6 -size 31672 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors deleted file mode 100644 index 9b6ef7f..0000000 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:016d2fa8fe5f58017dfd46f4632fdc19dfd751e32a2c7cde2077c6f95546d6bd -size 68 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors deleted file mode 100644 index cc6b4a2..0000000 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075 -size 33400 diff --git a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors deleted file mode 100644 index 84e14b9..0000000 --- a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c -size 992 diff --git a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors deleted file mode 100644 index 5422979..0000000 --- a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201 -size 47424 diff --git a/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors b/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors deleted file mode 100644 index f293039..0000000 --- a/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5 -size 68 diff --git a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors deleted file mode 100644 index e91cd08..0000000 --- a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22 -size 49120 diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py deleted file mode 100644 index 6ccb47c..0000000 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ /dev/null @@ -1,145 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import shutil -from pathlib import Path - -import torch -from safetensors.torch import save_file - -from lerobot.configs.default import DatasetConfig -from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.factory import make_dataset -from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy, make_policy_config -from lerobot.utils.random_utils import set_seed - - -def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): - set_seed(1337) - train_cfg = TrainPipelineConfig( - # TODO(rcadene, aliberts): remove dataset download - dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]), - policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs), - ) - train_cfg.validate() # Needed for auto-setting some parameters - - dataset = make_dataset(train_cfg) - policy = make_policy(train_cfg.policy, ds_meta=dataset.meta) - policy.train() - - optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy) - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=0, - batch_size=train_cfg.batch_size, - shuffle=False, - ) - - batch = next(iter(dataloader)) - loss, output_dict = policy.forward(batch) - if output_dict is not None: - output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} - output_dict["loss"] = loss - else: - output_dict = {"loss": loss} - - loss.backward() - grad_stats = {} - for key, param in policy.named_parameters(): - if param.requires_grad: - grad_stats[f"{key}_mean"] = param.grad.mean() - grad_stats[f"{key}_std"] = ( - param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0)) - ) - - optimizer.step() - param_stats = {} - for key, param in policy.named_parameters(): - param_stats[f"{key}_mean"] = param.mean() - param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0)) - - optimizer.zero_grad() - policy.reset() - - # HACK: We reload a batch with no delta_indices as `select_action` won't expect a timestamps dimension - # We simulate having an environment using a dataset by setting delta_indices to None and dropping tensors - # indicating padding (those ending with "_is_pad") - dataset.delta_indices = None - batch = next(iter(dataloader)) - obs = {} - for k in batch: - # TODO: regenerate the safetensors - # for backward compatibility - if k.endswith("_is_pad"): - continue - # for backward compatibility - if k == "task": - continue - if k.startswith("observation"): - obs[k] = batch[k] - - if hasattr(train_cfg.policy, "n_action_steps"): - actions_queue = train_cfg.policy.n_action_steps - else: - actions_queue = train_cfg.policy.n_action_repeats - - actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)} - return output_dict, grad_stats, param_stats, actions - - -def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict): - if output_dir.exists(): - print(f"Overwrite existing safetensors in '{output_dir}':") - print(f" - Validate with: `git add {output_dir}`") - print(f" - Revert with: `git checkout -- {output_dir}`") - shutil.rmtree(output_dir) - - output_dir.mkdir(parents=True, exist_ok=True) - output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs) - save_file(output_dict, output_dir / "output_dict.safetensors") - save_file(grad_stats, output_dir / "grad_stats.safetensors") - save_file(param_stats, output_dir / "param_stats.safetensors") - save_file(actions, output_dir / "actions.safetensors") - - -if __name__ == "__main__": - artifacts_cfg = [ - ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"), - ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"), - ( - "lerobot/pusht", - "diffusion", - { - "n_action_steps": 8, - "num_inference_steps": 10, - "down_dims": [128, 256, 512], - }, - "", - ), - ("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""), - ( - "lerobot/aloha_sim_insertion_human", - "act", - {"n_action_steps": 1000, "chunk_size": 1000}, - "1000_steps", - ), - ] - if len(artifacts_cfg) == 0: - raise RuntimeError("No policies were provided!") - for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg: - ds_name = ds_repo_id.split("/")[-1] - output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}" - save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs) diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors deleted file mode 100644 index fa9bf06..0000000 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b -size 200 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors deleted file mode 100644 index 8d90a67..0000000 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10 -size 16904 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors deleted file mode 100644 index cde6c6d..0000000 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b -size 164 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors deleted file mode 100644 index 692377d..0000000 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170 -size 36312 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors deleted file mode 100644 index 7a0b165..0000000 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb -size 200 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors deleted file mode 100644 index 8d90a67..0000000 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10 -size 16904 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors deleted file mode 100644 index cde6c6d..0000000 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b -size 164 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors deleted file mode 100644 index 692377d..0000000 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170 -size 36312 diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py deleted file mode 100644 index 1c0400e..0000000 --- a/tests/async_inference/test_e2e.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2025 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""End-to-end test of the asynchronous inference stack (client ↔ server). - -This test spins up a lightweight gRPC `PolicyServer` instance with a stubbed -policy network and launches a `RobotClient` that uses a `MockRobot`. The goal -is to exercise the full communication loop: - -1. Client sends policy specification → Server -2. Client streams observations → Server -3. Server streams action chunks → Client -4. Client executes received actions - -The test succeeds if at least one action is executed and the server records at -least one predicted timestep - demonstrating that the gRPC round-trip works -end-to-end using real (but lightweight) protocol messages. -""" - -from __future__ import annotations - -import threading -from concurrent import futures - -import pytest -import torch - -# Skip entire module if grpc is not available -pytest.importorskip("grpc") - -# ----------------------------------------------------------------------------- -# End-to-end test -# ----------------------------------------------------------------------------- - - -def test_async_inference_e2e(monkeypatch): - """Tests the full asynchronous inference pipeline.""" - # Import grpc-dependent modules inside the test function - import grpc - - from lerobot.robots.utils import make_robot_from_config - from lerobot.scripts.server.configs import PolicyServerConfig, RobotClientConfig - from lerobot.scripts.server.helpers import map_robot_keys_to_lerobot_features - from lerobot.scripts.server.policy_server import PolicyServer - from lerobot.scripts.server.robot_client import RobotClient - from lerobot.transport import ( - services_pb2, # type: ignore - services_pb2_grpc, # type: ignore - ) - from tests.mocks.mock_robot import MockRobotConfig - - # Create a stub policy similar to test_policy_server.py - class MockPolicy: - """A minimal mock for an actual policy, returning zeros.""" - - class _Config: - robot_type = "dummy_robot" - - @property - def image_features(self): - """Empty image features since this test doesn't use images.""" - return {} - - def __init__(self): - self.config = self._Config() - - def to(self, *args, **kwargs): - return self - - def model(self, batch): - # Return a chunk of 20 dummy actions. - batch_size = len(batch["robot_type"]) - return torch.zeros(batch_size, 20, 6) - - # ------------------------------------------------------------------ - # 1. Create PolicyServer instance with mock policy - # ------------------------------------------------------------------ - policy_server_config = PolicyServerConfig(host="localhost", port=9999) - policy_server = PolicyServer(policy_server_config) - # Replace the real policy with our fast, deterministic stub. - policy_server.policy = MockPolicy() - policy_server.actions_per_chunk = 20 - policy_server.device = "cpu" - - # Set up robot config and features - robot_config = MockRobotConfig() - mock_robot = make_robot_from_config(robot_config) - - lerobot_features = map_robot_keys_to_lerobot_features(mock_robot) - policy_server.lerobot_features = lerobot_features - - # Force server to produce deterministic action chunks in test mode - policy_server.policy_type = "act" - - def _fake_get_action_chunk(_self, _obs, _type="test"): - action_dim = 6 - batch_size = 1 - actions_per_chunk = policy_server.actions_per_chunk - - return torch.zeros(batch_size, actions_per_chunk, action_dim) - - monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True) - - # Bypass potentially heavy model loading inside SendPolicyInstructions - def _fake_send_policy_instructions(self, request, context): # noqa: N802 - return services_pb2.Empty() - - monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True) - - # Build gRPC server running a PolicyServer - server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server")) - services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) - - # Use the host/port specified in the fixture's config - server_address = f"{policy_server.config.host}:{policy_server.config.port}" - server.add_insecure_port(server_address) - server.start() - - # ------------------------------------------------------------------ - # 2. Create a RobotClient around the MockRobot - # ------------------------------------------------------------------ - client_config = RobotClientConfig( - server_address=server_address, - robot=robot_config, - chunk_size_threshold=0.0, - policy_type="test", - pretrained_name_or_path="test", - actions_per_chunk=20, - verify_robot_cameras=False, - ) - - client = RobotClient(client_config) - assert client.start(), "Client failed initial handshake with the server" - - # Track action chunks received without modifying RobotClient - action_chunks_received = {"count": 0} - original_aggregate = client._aggregate_action_queues - - def counting_aggregate(*args, **kwargs): - action_chunks_received["count"] += 1 - return original_aggregate(*args, **kwargs) - - monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate) - - # Start client threads - action_thread = threading.Thread(target=client.receive_actions, daemon=True) - control_thread = threading.Thread(target=client.control_loop, args=({"task": ""}), daemon=True) - action_thread.start() - control_thread.start() - - # ------------------------------------------------------------------ - # 3. System exchanges a few messages - # ------------------------------------------------------------------ - # Wait for 5 seconds - server.wait_for_termination(timeout=5) - - assert action_chunks_received["count"] > 0, "Client did not receive any action chunks" - assert len(policy_server._predicted_timesteps) > 0, "Server did not record any predicted timesteps" - - # ------------------------------------------------------------------ - # 4. Stop the system - # ------------------------------------------------------------------ - client.stop() - action_thread.join() - control_thread.join() - policy_server.stop() - server.stop(grace=None) diff --git a/tests/async_inference/test_helpers.py b/tests/async_inference/test_helpers.py deleted file mode 100644 index e0b7973..0000000 --- a/tests/async_inference/test_helpers.py +++ /dev/null @@ -1,459 +0,0 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import pickle -import time - -import numpy as np -import torch - -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.scripts.server.helpers import ( - FPSTracker, - TimedAction, - TimedObservation, - observations_similar, - prepare_image, - prepare_raw_observation, - raw_observation_to_observation, - resize_robot_observation_image, -) - -# --------------------------------------------------------------------- -# FPSTracker -# --------------------------------------------------------------------- - - -def test_fps_tracker_first_observation(): - """First observation should initialize timestamp and return 0 FPS.""" - tracker = FPSTracker(target_fps=30.0) - timestamp = 1000.0 - - metrics = tracker.calculate_fps_metrics(timestamp) - - assert tracker.first_timestamp == timestamp - assert tracker.total_obs_count == 1 - assert metrics["avg_fps"] == 0.0 - assert metrics["target_fps"] == 30.0 - - -def test_fps_tracker_single_interval(): - """Two observations 1 second apart should give 1 FPS.""" - tracker = FPSTracker(target_fps=30.0) - - # First observation at t=0 - metrics1 = tracker.calculate_fps_metrics(0.0) - assert metrics1["avg_fps"] == 0.0 - - # Second observation at t=1 (1 second later) - metrics2 = tracker.calculate_fps_metrics(1.0) - expected_fps = 1.0 # (2-1) observations / 1.0 seconds = 1 FPS - assert math.isclose(metrics2["avg_fps"], expected_fps, rel_tol=1e-6) - - -def test_fps_tracker_multiple_intervals(): - """Multiple observations should calculate correct average FPS.""" - tracker = FPSTracker(target_fps=30.0) - - # Simulate 5 observations over 2 seconds (should be 2 FPS average) - timestamps = [0.0, 0.5, 1.0, 1.5, 2.0] - - for i, ts in enumerate(timestamps): - metrics = tracker.calculate_fps_metrics(ts) - - if i == 0: - assert metrics["avg_fps"] == 0.0 - elif i == len(timestamps) - 1: - # After 5 observations over 2 seconds: (5-1)/2 = 2 FPS - expected_fps = 2.0 - assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6) - - -def test_fps_tracker_irregular_intervals(): - """FPS calculation should work with irregular time intervals.""" - tracker = FPSTracker(target_fps=30.0) - - # Irregular timestamps: 0, 0.1, 0.5, 2.0, 3.0 seconds - timestamps = [0.0, 0.1, 0.5, 2.0, 3.0] - - for ts in timestamps: - metrics = tracker.calculate_fps_metrics(ts) - - # 5 observations over 3 seconds: (5-1)/3 = 1.333... FPS - expected_fps = 4.0 / 3.0 - assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6) - - -# --------------------------------------------------------------------- -# TimedData helpers -# --------------------------------------------------------------------- - - -def test_timed_action_getters(): - """TimedAction stores & returns timestamp, action tensor and timestep.""" - ts = time.time() - action = torch.arange(10) - ta = TimedAction(timestamp=ts, action=action, timestep=0) - - assert math.isclose(ta.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6) - torch.testing.assert_close(ta.get_action(), action) - assert ta.get_timestep() == 0 - - -def test_timed_observation_getters(): - """TimedObservation stores & returns timestamp, dict and timestep.""" - ts = time.time() - obs_dict = {"observation.state": torch.ones(6)} - to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0) - - assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6) - assert to.get_observation() is obs_dict - assert to.get_timestep() == 0 - - -def test_timed_data_deserialization_data_getters(): - """TimedAction / TimedObservation survive a round-trip through ``pickle``. - - The async-inference stack uses ``pickle.dumps`` to move these objects across - the gRPC boundary (see RobotClient.send_observation and PolicyServer.StreamActions). - This test ensures that the payload keeps its content intact after - the (de)serialization round-trip. - """ - ts = time.time() - - # ------------------------------------------------------------------ - # TimedAction - # ------------------------------------------------------------------ - original_action = torch.randn(6) - ta_in = TimedAction(timestamp=ts, action=original_action, timestep=13) - - # Serialize → bytes → deserialize - ta_bytes = pickle.dumps(ta_in) # nosec - ta_out: TimedAction = pickle.loads(ta_bytes) # nosec B301 - - # Identity & content checks - assert math.isclose(ta_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6) - assert ta_out.get_timestep() == 13 - torch.testing.assert_close(ta_out.get_action(), original_action) - - # ------------------------------------------------------------------ - # TimedObservation - # ------------------------------------------------------------------ - obs_dict = {"observation.state": torch.arange(4).float()} - to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True) - - to_bytes = pickle.dumps(to_in) # nosec - to_out: TimedObservation = pickle.loads(to_bytes) # nosec B301 - - assert math.isclose(to_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6) - assert to_out.get_timestep() == 7 - assert to_out.must_go is True - assert to_out.get_observation().keys() == obs_dict.keys() - torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"]) - - -# --------------------------------------------------------------------- -# observations_similar() -# --------------------------------------------------------------------- - - -def _make_obs(state: torch.Tensor) -> TimedObservation: - """Create a TimedObservation with raw robot observation format.""" - return TimedObservation( - timestamp=time.time(), - observation={ - "shoulder": state[0].item() if len(state) > 0 else 0.0, - "elbow": state[1].item() if len(state) > 1 else 0.0, - "wrist": state[2].item() if len(state) > 2 else 0.0, - "gripper": state[3].item() if len(state) > 3 else 0.0, - }, - timestep=0, - ) - - -def test_observations_similar_true(): - """Distance below atol → observations considered similar.""" - # Create mock lerobot features for the similarity check - lerobot_features = { - "observation.state": { - "dtype": "float32", - "shape": [4], - "names": ["shoulder", "elbow", "wrist", "gripper"], - } - } - - obs1 = _make_obs(torch.zeros(4)) - obs2 = _make_obs(0.5 * torch.ones(4)) - assert observations_similar(obs1, obs2, lerobot_features, atol=2.0) - - obs3 = _make_obs(2.0 * torch.ones(4)) - assert not observations_similar(obs1, obs3, lerobot_features, atol=2.0) - - -# --------------------------------------------------------------------- -# raw_observation_to_observation and helpers -# --------------------------------------------------------------------- - - -def _create_mock_robot_observation(): - """Create a mock robot observation with motor positions and camera images.""" - return { - "shoulder": 1.0, - "elbow": 2.0, - "wrist": 3.0, - "gripper": 0.5, - "laptop": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8), - "phone": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8), - } - - -def _create_mock_lerobot_features(): - """Create mock lerobot features mapping similar to what hw_to_dataset_features returns.""" - return { - "observation.state": { - "dtype": "float32", - "shape": [4], - "names": ["shoulder", "elbow", "wrist", "gripper"], - }, - "observation.images.laptop": { - "dtype": "image", - "shape": [480, 640, 3], - "names": ["height", "width", "channels"], - }, - "observation.images.phone": { - "dtype": "image", - "shape": [480, 640, 3], - "names": ["height", "width", "channels"], - }, - } - - -def _create_mock_policy_image_features(): - """Create mock policy image features with different resolutions.""" - return { - "observation.images.laptop": PolicyFeature( - type=FeatureType.VISUAL, - shape=(3, 224, 224), # Policy expects smaller resolution - ), - "observation.images.phone": PolicyFeature( - type=FeatureType.VISUAL, - shape=(3, 160, 160), # Different resolution for second camera - ), - } - - -def test_prepare_image(): - """Test image preprocessing: int8 → float32, normalization to [0,1].""" - # Create mock int8 image data - image_int8 = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8) - - processed = prepare_image(image_int8) - - # Check dtype conversion - assert processed.dtype == torch.float32 - - # Check normalization range - assert processed.min() >= 0.0 - assert processed.max() <= 1.0 - - # Check that values are scaled correctly (255 → 1.0, 0 → 0.0) - if image_int8.max() == 255: - assert torch.isclose(processed.max(), torch.tensor(1.0), atol=1e-6) - if image_int8.min() == 0: - assert torch.isclose(processed.min(), torch.tensor(0.0), atol=1e-6) - - # Check memory contiguity - assert processed.is_contiguous() - - -def test_resize_robot_observation_image(): - """Test image resizing from robot resolution to policy resolution.""" - # Create mock image: (H=480, W=640, C=3) - original_image = torch.randint(0, 256, size=(480, 640, 3), dtype=torch.uint8) - target_shape = (3, 224, 224) # (C, H, W) - - resized = resize_robot_observation_image(original_image, target_shape) - - # Check output shape matches target - assert resized.shape == target_shape - - # Check that original image had different dimensions - assert original_image.shape != resized.shape - - # Check that resizing preserves value range - assert resized.min() >= 0 - assert resized.max() <= 255 - - -def test_prepare_raw_observation(): - """Test the preparation of raw robot observation to lerobot format.""" - robot_obs = _create_mock_robot_observation() - lerobot_features = _create_mock_lerobot_features() - policy_image_features = _create_mock_policy_image_features() - - prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features) - - # Check that state is properly extracted and batched - assert "observation.state" in prepared - state = prepared["observation.state"] - assert isinstance(state, torch.Tensor) - assert state.shape == (1, 4) # Batched state - - # Check that images are processed and resized - assert "observation.images.laptop" in prepared - assert "observation.images.phone" in prepared - - laptop_img = prepared["observation.images.laptop"] - phone_img = prepared["observation.images.phone"] - - # Check image shapes match policy requirements - assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape - assert phone_img.shape == policy_image_features["observation.images.phone"].shape - - # Check that images are tensors - assert isinstance(laptop_img, torch.Tensor) - assert isinstance(phone_img, torch.Tensor) - - -def test_raw_observation_to_observation_basic(): - """Test the main raw_observation_to_observation function.""" - robot_obs = _create_mock_robot_observation() - lerobot_features = _create_mock_lerobot_features() - policy_image_features = _create_mock_policy_image_features() - device = "cpu" - - observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) - - # Check that all expected keys are present - assert "observation.state" in observation - assert "observation.images.laptop" in observation - assert "observation.images.phone" in observation - - # Check state processing - state = observation["observation.state"] - assert isinstance(state, torch.Tensor) - assert state.device.type == device - assert state.shape == (1, 4) # Batched - - # Check image processing - laptop_img = observation["observation.images.laptop"] - phone_img = observation["observation.images.phone"] - - # Images should have batch dimension: (B, C, H, W) - assert laptop_img.shape == (1, 3, 224, 224) - assert phone_img.shape == (1, 3, 160, 160) - - # Check device placement - assert laptop_img.device.type == device - assert phone_img.device.type == device - - # Check image dtype and range (should be float32 in [0, 1]) - assert laptop_img.dtype == torch.float32 - assert phone_img.dtype == torch.float32 - assert laptop_img.min() >= 0.0 and laptop_img.max() <= 1.0 - assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0 - - -def test_raw_observation_to_observation_with_non_tensor_data(): - """Test that non-tensor data (like task strings) is preserved.""" - robot_obs = _create_mock_robot_observation() - robot_obs["task"] = "pick up the red cube" # Add string instruction - - lerobot_features = _create_mock_lerobot_features() - policy_image_features = _create_mock_policy_image_features() - device = "cpu" - - observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) - - # Check that task string is preserved - assert "task" in observation - assert observation["task"] == "pick up the red cube" - assert isinstance(observation["task"], str) - - -@torch.no_grad() -def test_raw_observation_to_observation_device_handling(): - """Test that tensors are properly moved to the specified device.""" - device = "mps" if torch.backends.mps.is_available() else "cpu" - - robot_obs = _create_mock_robot_observation() - lerobot_features = _create_mock_lerobot_features() - policy_image_features = _create_mock_policy_image_features() - - observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) - - # Check that all tensors are on the correct device - for key, value in observation.items(): - if isinstance(value, torch.Tensor): - assert value.device.type == device, f"Tensor {key} not on {device}" - - -def test_raw_observation_to_observation_deterministic(): - """Test that the function produces consistent results for the same input.""" - robot_obs = _create_mock_robot_observation() - lerobot_features = _create_mock_lerobot_features() - policy_image_features = _create_mock_policy_image_features() - device = "cpu" - - # Run twice with same input - obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) - obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) - - # Results should be identical - assert set(obs1.keys()) == set(obs2.keys()) - - for key in obs1: - if isinstance(obs1[key], torch.Tensor): - torch.testing.assert_close(obs1[key], obs2[key]) - else: - assert obs1[key] == obs2[key] - - -def test_image_processing_pipeline_preserves_content(): - """Test that the image processing pipeline preserves recognizable patterns.""" - # Create an image with a specific pattern - original_img = np.zeros((100, 100, 3), dtype=np.uint8) - original_img[25:75, 25:75, :] = 255 # White square in center - - robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img} - lerobot_features = { - "observation.state": { - "dtype": "float32", - "shape": [4], - "names": ["shoulder", "elbow", "wrist", "gripper"], - }, - "observation.images.laptop": { - "dtype": "image", - "shape": [100, 100, 3], - "names": ["height", "width", "channels"], - }, - } - policy_image_features = { - "observation.images.laptop": PolicyFeature( - type=FeatureType.VISUAL, - shape=(3, 50, 50), # Downsamples from 100x100 - ) - } - - observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu") - - processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim - - # Check that the center region has higher values than corners - # Due to bilinear interpolation, exact values will change but pattern should remain - center_val = processed_img[:, 25, 25].mean() # Center of 50x50 image - corner_val = processed_img[:, 5, 5].mean() # Corner - - assert center_val > corner_val, "Image processing should preserve recognizable patterns" diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py deleted file mode 100644 index 5c795e7..0000000 --- a/tests/async_inference/test_policy_server.py +++ /dev/null @@ -1,215 +0,0 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Unit-tests for the `PolicyServer` core logic. -Monkey-patch the `policy` attribute with a stub so that no real model inference is performed. -""" - -from __future__ import annotations - -import time - -import pytest -import torch - -from lerobot.configs.types import PolicyFeature -from tests.utils import require_package - -# ----------------------------------------------------------------------------- -# Test fixtures -# ----------------------------------------------------------------------------- - - -class MockPolicy: - """A minimal mock for an actual policy, returning zeros. - Refer to tests/policies for tests of the individual policies supported.""" - - class _Config: - robot_type = "dummy_robot" - - @property - def image_features(self) -> dict[str, PolicyFeature]: - """Empty image features since this test doesn't use images.""" - return {} - - def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: - """Return a chunk of 20 dummy actions.""" - batch_size = len(observation["observation.state"]) - return torch.zeros(batch_size, 20, 6) - - def __init__(self): - self.config = self._Config() - - def to(self, *args, **kwargs): - # The server calls `policy.to(device)`. This stub ignores it. - return self - - def model(self, batch: dict) -> torch.Tensor: - # Return a chunk of 20 dummy actions. - batch_size = len(batch["robot_type"]) - return torch.zeros(batch_size, 20, 6) - - -@pytest.fixture -@require_package("grpc") -def policy_server(): - """Fresh `PolicyServer` instance with a stubbed-out policy model.""" - # Import only when the test actually runs (after decorator check) - from lerobot.scripts.server.configs import PolicyServerConfig - from lerobot.scripts.server.policy_server import PolicyServer - - test_config = PolicyServerConfig(host="localhost", port=9999) - server = PolicyServer(test_config) - # Replace the real policy with our fast, deterministic stub. - server.policy = MockPolicy() - server.actions_per_chunk = 20 - server.device = "cpu" - - # Add mock lerobot_features that the observation similarity functions need - server.lerobot_features = { - "observation.state": { - "dtype": "float32", - "shape": [6], - "names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"], - } - } - - return server - - -# ----------------------------------------------------------------------------- -# Helper utilities for tests -# ----------------------------------------------------------------------------- - - -def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False): - """Create a TimedObservation with a given state vector.""" - # Import only when needed - from lerobot.scripts.server.helpers import TimedObservation - - return TimedObservation( - observation={ - "joint1": state[0].item() if len(state) > 0 else 0.0, - "joint2": state[1].item() if len(state) > 1 else 0.0, - "joint3": state[2].item() if len(state) > 2 else 0.0, - "joint4": state[3].item() if len(state) > 3 else 0.0, - "joint5": state[4].item() if len(state) > 4 else 0.0, - "joint6": state[5].item() if len(state) > 5 else 0.0, - }, - timestamp=time.time(), - timestep=timestep, - must_go=must_go, - ) - - -# ----------------------------------------------------------------------------- -# Tests -# ----------------------------------------------------------------------------- - - -def test_time_action_chunk(policy_server): - """Verify that `_time_action_chunk` assigns correct timestamps and timesteps.""" - start_ts = time.time() - start_t = 10 - # A chunk of 3 action tensors. - action_tensors = [torch.randn(6) for _ in range(3)] - - timed_actions = policy_server._time_action_chunk(start_ts, action_tensors, start_t) - - assert len(timed_actions) == 3 - # Check timesteps - assert [ta.get_timestep() for ta in timed_actions] == [10, 11, 12] - # Check timestamps - expected_timestamps = [ - start_ts, - start_ts + policy_server.config.environment_dt, - start_ts + 2 * policy_server.config.environment_dt, - ] - for ta, expected_ts in zip(timed_actions, expected_timestamps, strict=True): - assert abs(ta.get_timestamp() - expected_ts) < 1e-6 - - -def test_maybe_enqueue_observation_must_go(policy_server): - """An observation with `must_go=True` is always enqueued.""" - obs = _make_obs(torch.zeros(6), must_go=True) - assert policy_server._enqueue_observation(obs) is True - assert policy_server.observation_queue.qsize() == 1 - assert policy_server.observation_queue.get_nowait() is obs - - -def test_maybe_enqueue_observation_dissimilar(policy_server): - """A dissimilar observation (not `must_go`) is enqueued.""" - # Set a last predicted observation. - policy_server.last_processed_obs = _make_obs(torch.zeros(6)) - # Create a new, dissimilar observation. - new_obs = _make_obs(torch.ones(6) * 5) # High norm difference - - assert policy_server._enqueue_observation(new_obs) is True - assert policy_server.observation_queue.qsize() == 1 - - -def test_maybe_enqueue_observation_is_skipped(policy_server): - """A similar observation (not `must_go`) is skipped.""" - # Set a last predicted observation. - policy_server.last_processed_obs = _make_obs(torch.zeros(6)) - # Create a new, very similar observation. - new_obs = _make_obs(torch.zeros(6) + 1e-4) - - assert policy_server._enqueue_observation(new_obs) is False - assert policy_server.observation_queue.empty() is True - - -def test_obs_sanity_checks(policy_server): - """Unit-test the private `_obs_sanity_checks` helper.""" - prev = _make_obs(torch.zeros(6), timestep=0) - - # Case 1 – timestep already predicted - policy_server._predicted_timesteps.add(1) - obs_same_ts = _make_obs(torch.ones(6), timestep=1) - assert policy_server._obs_sanity_checks(obs_same_ts, prev) is False - - # Case 2 – observation too similar - policy_server._predicted_timesteps.clear() - obs_similar = _make_obs(torch.zeros(6) + 1e-4, timestep=2) - assert policy_server._obs_sanity_checks(obs_similar, prev) is False - - # Case 3 – genuinely new & dissimilar observation passes - obs_ok = _make_obs(torch.ones(6) * 5, timestep=3) - assert policy_server._obs_sanity_checks(obs_ok, prev) is True - - -def test_predict_action_chunk(monkeypatch, policy_server): - """End-to-end test of `_predict_action_chunk` with a stubbed _get_action_chunk.""" - # Import only when needed - from lerobot.scripts.server.policy_server import PolicyServer - - # Force server to act-style policy; patch method to return deterministic tensor - policy_server.policy_type = "act" - action_dim = 6 - batch_size = 1 - actions_per_chunk = policy_server.actions_per_chunk - - def _fake_get_action_chunk(_self, _obs, _type="act"): - return torch.zeros(batch_size, actions_per_chunk, action_dim) - - monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True) - - obs = _make_obs(torch.zeros(6), timestep=5) - timed_actions = policy_server._predict_action_chunk(obs) - - assert len(timed_actions) == actions_per_chunk - assert [ta.get_timestep() for ta in timed_actions] == list(range(5, 5 + actions_per_chunk)) - - for i, ta in enumerate(timed_actions): - expected_ts = obs.get_timestamp() + i * policy_server.config.environment_dt - assert abs(ta.get_timestamp() - expected_ts) < 1e-6 diff --git a/tests/async_inference/test_robot_client.py b/tests/async_inference/test_robot_client.py deleted file mode 100644 index 51db2c3..0000000 --- a/tests/async_inference/test_robot_client.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC). - -We monkey-patch `lerobot.robots.utils.make_robot_from_config` so that -no real hardware is accessed. Only the queue-update mechanism is verified. -""" - -from __future__ import annotations - -import time -from queue import Queue - -import pytest -import torch - -# Skip entire module if grpc is not available -pytest.importorskip("grpc") - -# ----------------------------------------------------------------------------- -# Test fixtures -# ----------------------------------------------------------------------------- - - -@pytest.fixture() -def robot_client(): - """Fresh `RobotClient` instance for each test case (no threads started). - Uses DummyRobot.""" - # Import only when the test actually runs (after decorator check) - from lerobot.scripts.server.configs import RobotClientConfig - from lerobot.scripts.server.robot_client import RobotClient - from tests.mocks.mock_robot import MockRobotConfig - - test_config = MockRobotConfig() - - # gRPC channel is not actually used in tests, so using a dummy address - test_config = RobotClientConfig( - robot=test_config, - server_address="localhost:9999", - policy_type="test", - pretrained_name_or_path="test", - actions_per_chunk=20, - verify_robot_cameras=False, - ) - - client = RobotClient(test_config) - - # Initialize attributes that are normally set in start() method - client.chunks_received = 0 - client.available_actions_size = [] - - yield client - - if client.robot.is_connected: - client.stop() - - -# ----------------------------------------------------------------------------- -# Helper utilities for tests -# ----------------------------------------------------------------------------- - - -def _make_actions(start_ts: float, start_t: int, count: int): - """Generate `count` consecutive TimedAction objects starting at timestep `start_t`.""" - from lerobot.scripts.server.helpers import TimedAction - - fps = 30 # emulates most common frame-rate - actions = [] - for i in range(count): - timestep = start_t + i - timestamp = start_ts + i * (1 / fps) - action_tensor = torch.full((6,), timestep, dtype=torch.float32) - actions.append(TimedAction(action=action_tensor, timestep=timestep, timestamp=timestamp)) - return actions - - -# ----------------------------------------------------------------------------- -# Tests -# ----------------------------------------------------------------------------- - - -def test_update_action_queue_discards_stale(robot_client): - """`_update_action_queue` must drop actions with `timestep` <= `latest_action`.""" - - # Pretend we already executed up to action #4 - robot_client.latest_action = 4 - - # Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept. - incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7 - - robot_client._aggregate_action_queues(incoming) - - # Extract timesteps from queue - resulting_timesteps = [a.get_timestep() for a in robot_client.action_queue.queue] - - assert resulting_timesteps == [5, 6, 7] - - -@pytest.mark.parametrize( - "weight_old, weight_new", - [ - (1.0, 0.0), - (0.0, 1.0), - (0.5, 0.5), - (0.2, 0.8), - (0.8, 0.2), - (0.1, 0.9), - (0.9, 0.1), - ], -) -def test_aggregate_action_queues_combines_actions_in_overlap( - robot_client, weight_old: float, weight_new: float -): - """`_aggregate_action_queues` must combine actions on overlapping timesteps according - to the provided aggregate_fn, here tested with multiple coefficients.""" - from lerobot.scripts.server.helpers import TimedAction - - robot_client.chunks_received = 0 - - # Pretend we already executed up to action #4, and queue contains actions for timesteps 5..6 - robot_client.latest_action = 4 - current_actions = _make_actions( - start_ts=time.time(), start_t=5, count=2 - ) # actions are [torch.ones(6), torch.ones(6), ...] - current_actions = [ - TimedAction(action=10 * a.get_action(), timestep=a.get_timestep(), timestamp=a.get_timestamp()) - for a in current_actions - ] - - for a in current_actions: - robot_client.action_queue.put(a) - - # Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept. - incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7 - - overlap_timesteps = [5, 6] # properly tested in test_aggregate_action_queues_discards_stale - nonoverlap_timesteps = [7] - - robot_client._aggregate_action_queues( - incoming, aggregate_fn=lambda x1, x2: weight_old * x1 + weight_new * x2 - ) - - queue_overlap_actions = [] - queue_non_overlap_actions = [] - for a in robot_client.action_queue.queue: - if a.get_timestep() in overlap_timesteps: - queue_overlap_actions.append(a) - elif a.get_timestep() in nonoverlap_timesteps: - queue_non_overlap_actions.append(a) - - queue_overlap_actions = sorted(queue_overlap_actions, key=lambda x: x.get_timestep()) - queue_non_overlap_actions = sorted(queue_non_overlap_actions, key=lambda x: x.get_timestep()) - - assert torch.allclose( - queue_overlap_actions[0].get_action(), - weight_old * current_actions[0].get_action() + weight_new * incoming[-3].get_action(), - ) - assert torch.allclose( - queue_overlap_actions[1].get_action(), - weight_old * current_actions[1].get_action() + weight_new * incoming[-2].get_action(), - ) - assert torch.allclose(queue_non_overlap_actions[0].get_action(), incoming[-1].get_action()) - - -@pytest.mark.parametrize( - "chunk_size, queue_len, expected", - [ - (20, 12, False), # 12 / 20 = 0.6 > g=0.5 threshold, not ready to send - (20, 8, True), # 8 / 20 = 0.4 <= g=0.5, ready to send - (10, 5, True), - (10, 6, False), - ], -) -def test_ready_to_send_observation(robot_client, chunk_size: int, queue_len: int, expected: bool): - """Validate `_ready_to_send_observation` ratio logic for various sizes.""" - - robot_client.action_chunk_size = chunk_size - - # Clear any existing actions then fill with `queue_len` dummy entries ---- - robot_client.action_queue = Queue() - - dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len) - for act in dummy_actions: - robot_client.action_queue.put(act) - - assert robot_client._ready_to_send_observation() is expected - - -@pytest.mark.parametrize( - "g_threshold, expected", - [ - # The condition is `queue_size / chunk_size <= g`. - # Here, ratio = 6 / 10 = 0.6. - (0.0, False), # 0.6 <= 0.0 is False - (0.1, False), - (0.2, False), - (0.3, False), - (0.4, False), - (0.5, False), - (0.6, True), # 0.6 <= 0.6 is True - (0.7, True), - (0.8, True), - (0.9, True), - (1.0, True), - ], -) -def test_ready_to_send_observation_with_varying_threshold(robot_client, g_threshold: float, expected: bool): - """Validate `_ready_to_send_observation` with fixed sizes and varying `g`.""" - # Fixed sizes for this test: ratio = 6 / 10 = 0.6 - chunk_size = 10 - queue_len = 6 - - robot_client.action_chunk_size = chunk_size - # This is the parameter we are testing - robot_client._chunk_size_threshold = g_threshold - - # Fill queue with dummy actions - robot_client.action_queue = Queue() - dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len) - for act in dummy_actions: - robot_client.action_queue.put(act) - - assert robot_client._ready_to_send_observation() is expected diff --git a/tests/cameras/test_opencv.py b/tests/cameras/test_opencv.py deleted file mode 100644 index a9c060c..0000000 --- a/tests/cameras/test_opencv.py +++ /dev/null @@ -1,188 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Example of running a specific test: -# ```bash -# pytest tests/cameras/test_opencv.py::test_connect -# ``` - -from pathlib import Path - -import numpy as np -import pytest - -from lerobot.cameras.configs import Cv2Rotation -from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError - -# NOTE(Steven): more tests + assertions? -TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras" -DEFAULT_PNG_FILE_PATH = TEST_ARTIFACTS_DIR / "image_160x120.png" -TEST_IMAGE_SIZES = ["128x128", "160x120", "320x180", "480x270"] -TEST_IMAGE_PATHS = [TEST_ARTIFACTS_DIR / f"image_{size}.png" for size in TEST_IMAGE_SIZES] - - -def test_abc_implementation(): - """Instantiation should raise an error if the class doesn't implement abstract methods/properties.""" - config = OpenCVCameraConfig(index_or_path=0) - - _ = OpenCVCamera(config) - - -def test_connect(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - - camera.connect(warmup=False) - - assert camera.is_connected - - -def test_connect_already_connected(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - camera.connect(warmup=False) - - with pytest.raises(DeviceAlreadyConnectedError): - camera.connect(warmup=False) - - -def test_connect_invalid_camera_path(): - config = OpenCVCameraConfig(index_or_path="nonexistent/camera.png") - camera = OpenCVCamera(config) - - with pytest.raises(ConnectionError): - camera.connect(warmup=False) - - -def test_invalid_width_connect(): - config = OpenCVCameraConfig( - index_or_path=DEFAULT_PNG_FILE_PATH, - width=99999, # Invalid width to trigger error - height=480, - ) - camera = OpenCVCamera(config) - - with pytest.raises(RuntimeError): - camera.connect(warmup=False) - - -@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) -def test_read(index_or_path): - config = OpenCVCameraConfig(index_or_path=index_or_path) - camera = OpenCVCamera(config) - camera.connect(warmup=False) - - img = camera.read() - - assert isinstance(img, np.ndarray) - - -def test_read_before_connect(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - - with pytest.raises(DeviceNotConnectedError): - _ = camera.read() - - -def test_disconnect(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - camera.connect(warmup=False) - - camera.disconnect() - - assert not camera.is_connected - - -def test_disconnect_before_connect(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - - with pytest.raises(DeviceNotConnectedError): - _ = camera.disconnect() - - -@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) -def test_async_read(index_or_path): - config = OpenCVCameraConfig(index_or_path=index_or_path) - camera = OpenCVCamera(config) - camera.connect(warmup=False) - - try: - img = camera.async_read() - - assert camera.thread is not None - assert camera.thread.is_alive() - assert isinstance(img, np.ndarray) - finally: - if camera.is_connected: - camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends - - -def test_async_read_timeout(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - camera.connect(warmup=False) - - try: - with pytest.raises(TimeoutError): - camera.async_read(timeout_ms=0) - finally: - if camera.is_connected: - camera.disconnect() - - -def test_async_read_before_connect(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - - with pytest.raises(DeviceNotConnectedError): - _ = camera.async_read() - - -@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) -@pytest.mark.parametrize( - "rotation", - [ - Cv2Rotation.NO_ROTATION, - Cv2Rotation.ROTATE_90, - Cv2Rotation.ROTATE_180, - Cv2Rotation.ROTATE_270, - ], - ids=["no_rot", "rot90", "rot180", "rot270"], -) -def test_rotation(rotation, index_or_path): - filename = Path(index_or_path).name - dimensions = filename.split("_")[-1].split(".")[0] # Assumes filenames format (_wxh.png) - original_width, original_height = map(int, dimensions.split("x")) - - config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation) - camera = OpenCVCamera(config) - camera.connect(warmup=False) - - img = camera.read() - assert isinstance(img, np.ndarray) - - if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): - assert camera.width == original_height - assert camera.height == original_width - assert img.shape[:2] == (original_width, original_height) - else: - assert camera.width == original_width - assert camera.height == original_height - assert img.shape[:2] == (original_height, original_width) diff --git a/tests/cameras/test_realsense.py b/tests/cameras/test_realsense.py deleted file mode 100644 index 4b3fbae..0000000 --- a/tests/cameras/test_realsense.py +++ /dev/null @@ -1,206 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Example of running a specific test: -# ```bash -# pytest tests/cameras/test_opencv.py::test_connect -# ``` - -from pathlib import Path -from unittest.mock import patch - -import numpy as np -import pytest - -from lerobot.cameras.configs import Cv2Rotation -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError - -pytest.importorskip("pyrealsense2") - -from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig - -TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras" -BAG_FILE_PATH = TEST_ARTIFACTS_DIR / "test_rs.bag" - -# NOTE(Steven): For some reason these tests take ~20sec in macOS but only ~2sec in Linux. - - -def mock_rs_config_enable_device_from_file(rs_config_instance, _sn): - return rs_config_instance.enable_device_from_file(str(BAG_FILE_PATH), repeat_playback=True) - - -def mock_rs_config_enable_device_bad_file(rs_config_instance, _sn): - return rs_config_instance.enable_device_from_file("non_existent_file.bag", repeat_playback=True) - - -@pytest.fixture(name="patch_realsense", autouse=True) -def fixture_patch_realsense(): - """Automatically mock pyrealsense2.config.enable_device for all tests.""" - with patch( - "pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file - ) as mock: - yield mock - - -def test_abc_implementation(): - """Instantiation should raise an error if the class doesn't implement abstract methods/properties.""" - config = RealSenseCameraConfig(serial_number_or_name="042") - _ = RealSenseCamera(config) - - -def test_connect(): - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) - - camera.connect(warmup=False) - assert camera.is_connected - - -def test_connect_already_connected(): - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - with pytest.raises(DeviceAlreadyConnectedError): - camera.connect(warmup=False) - - -def test_connect_invalid_camera_path(patch_realsense): - patch_realsense.side_effect = mock_rs_config_enable_device_bad_file - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) - - with pytest.raises(ConnectionError): - camera.connect(warmup=False) - - -def test_invalid_width_connect(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=99999, height=480, fps=30) - camera = RealSenseCamera(config) - - with pytest.raises(ConnectionError): - camera.connect(warmup=False) - - -def test_read(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - img = camera.read() - assert isinstance(img, np.ndarray) - - -# TODO(Steven): Fix this test for the latest version of pyrealsense2. -@pytest.mark.skip("Skipping test: pyrealsense2 version > 2.55.1.6486") -def test_read_depth(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, use_depth=True) - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - img = camera.read_depth(timeout_ms=2000) # NOTE(Steven): Reading depth takes longer in CI environments. - assert isinstance(img, np.ndarray) - - -def test_read_before_connect(): - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) - - with pytest.raises(DeviceNotConnectedError): - _ = camera.read() - - -def test_disconnect(): - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - camera.disconnect() - - assert not camera.is_connected - - -def test_disconnect_before_connect(): - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) - - with pytest.raises(DeviceNotConnectedError): - camera.disconnect() - - -def test_async_read(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - try: - img = camera.async_read() - - assert camera.thread is not None - assert camera.thread.is_alive() - assert isinstance(img, np.ndarray) - finally: - if camera.is_connected: - camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends - - -def test_async_read_timeout(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - try: - with pytest.raises(TimeoutError): - camera.async_read(timeout_ms=0) - finally: - if camera.is_connected: - camera.disconnect() - - -def test_async_read_before_connect(): - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) - - with pytest.raises(DeviceNotConnectedError): - _ = camera.async_read() - - -@pytest.mark.parametrize( - "rotation", - [ - Cv2Rotation.NO_ROTATION, - Cv2Rotation.ROTATE_90, - Cv2Rotation.ROTATE_180, - Cv2Rotation.ROTATE_270, - ], - ids=["no_rot", "rot90", "rot180", "rot270"], -) -def test_rotation(rotation): - config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation) - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - img = camera.read() - assert isinstance(img, np.ndarray) - - if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): - assert camera.width == 480 - assert camera.height == 640 - assert img.shape[:2] == (640, 480) - else: - assert camera.width == 640 - assert camera.height == 480 - assert img.shape[:2] == (480, 640) diff --git a/tests/configs/test_plugin_loading.py b/tests/configs/test_plugin_loading.py deleted file mode 100644 index 3ec60a4..0000000 --- a/tests/configs/test_plugin_loading.py +++ /dev/null @@ -1,105 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -from collections.abc import Generator -from dataclasses import dataclass -from pathlib import Path - -import pytest - -from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap -from lerobot.envs.configs import EnvConfig - - -def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str: - """Creates a dummy plugin module that implements its own EnvConfig subclass.""" - return f""" -from dataclasses import dataclass -from lerobot.envs.configs import {base_class} - -@{base_class}.register_subclass("{plugin_name}") -@dataclass -class TestPluginConfig: - value: int = 42 - """ - - -@pytest.fixture -def plugin_dir(tmp_path: Path) -> Generator[Path, None, None]: - """Creates a temporary plugin package structure.""" - plugin_pkg = tmp_path / "test_plugin" - plugin_pkg.mkdir() - (plugin_pkg / "__init__.py").touch() - - with open(plugin_pkg / "my_plugin.py", "w") as f: - f.write(create_plugin_code()) - - # Add tmp_path to Python path so we can import from it - sys.path.insert(0, str(tmp_path)) - yield plugin_pkg - sys.path.pop(0) - - -def test_parse_plugin_args(): - cli_args = [ - "--env.type=test", - "--model.discover_packages_path=some.package", - "--env.discover_packages_path=other.package", - ] - plugin_args = parse_plugin_args("discover_packages_path", cli_args) - assert plugin_args == { - "model.discover_packages_path": "some.package", - "env.discover_packages_path": "other.package", - } - - -def test_load_plugin_success(plugin_dir: Path): - # Import should work and register the plugin with the real EnvConfig - load_plugin("test_plugin") - - assert "test_env" in EnvConfig.get_known_choices() - plugin_cls = EnvConfig.get_choice_class("test_env") - plugin_instance = plugin_cls() - assert plugin_instance.value == 42 - - -def test_load_plugin_failure(): - with pytest.raises(PluginLoadError) as exc_info: - load_plugin("nonexistent_plugin") - assert "Failed to load plugin 'nonexistent_plugin'" in str(exc_info.value) - - -def test_wrap_with_plugin(plugin_dir: Path): - @dataclass - class Config: - env: EnvConfig - - @wrap() - def dummy_func(cfg: Config): - return cfg - - # Test loading plugin via CLI args - sys.argv = [ - "dummy_script.py", - "--env.discover_packages_path=test_plugin", - "--env.type=test_env", - ] - - cfg = dummy_func() - assert isinstance(cfg, Config) - assert isinstance(cfg.env, EnvConfig.get_choice_class("test_env")) - assert cfg.env.value == 42 diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 7940cc5..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,88 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import traceback - -import pytest -from serial import SerialException - -from lerobot.configs.types import FeatureType, PolicyFeature -from tests.utils import DEVICE - -# Import fixture modules as plugins -pytest_plugins = [ - "tests.fixtures.dataset_factories", - "tests.fixtures.files", - "tests.fixtures.hub", - "tests.fixtures.optimizers", -] - - -def pytest_collection_finish(): - print(f"\nTesting with {DEVICE=}") - - -def _check_component_availability(component_type, available_components, make_component): - """Generic helper to check if a hardware component is available""" - if component_type not in available_components: - raise ValueError( - f"The {component_type} type is not valid. Expected one of these '{available_components}'" - ) - - try: - component = make_component(component_type) - component.connect() - del component - return True - - except Exception as e: - print(f"\nA {component_type} is not available.") - - if isinstance(e, ModuleNotFoundError): - print(f"\nInstall module '{e.name}'") - elif isinstance(e, SerialException): - print("\nNo physical device detected.") - elif isinstance(e, ValueError) and "camera_index" in str(e): - print("\nNo physical camera detected.") - else: - traceback.print_exc() - - return False - - -@pytest.fixture -def patch_builtins_input(monkeypatch): - def print_text(text=None): - if text is not None: - print(text) - - monkeypatch.setattr("builtins.input", print_text) - - -@pytest.fixture -def policy_feature_factory(): - """PolicyFeature factory""" - - def _pf(ft: FeatureType, shape: tuple[int, ...]) -> PolicyFeature: - return PolicyFeature(type=ft, shape=shape) - - return _pf - - -def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None: - assert isinstance(features, dict) - assert all(isinstance(k, str) for k in features.keys()) - assert all(isinstance(v, PolicyFeature) for v in features.values()) diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py deleted file mode 100644 index 8f8179c..0000000 --- a/tests/datasets/test_compute_stats.py +++ /dev/null @@ -1,309 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from unittest.mock import patch - -import numpy as np -import pytest - -from lerobot.datasets.compute_stats import ( - _assert_type_and_shape, - aggregate_feature_stats, - aggregate_stats, - compute_episode_stats, - estimate_num_samples, - get_feature_stats, - sample_images, - sample_indices, -) - - -def mock_load_image_as_numpy(path, dtype, channel_first): - return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) - - -@pytest.fixture -def sample_array(): - return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - - -def test_estimate_num_samples(): - assert estimate_num_samples(1) == 1 - assert estimate_num_samples(10) == 10 - assert estimate_num_samples(100) == 100 - assert estimate_num_samples(200) == 100 - assert estimate_num_samples(1000) == 177 - assert estimate_num_samples(2000) == 299 - assert estimate_num_samples(5000) == 594 - assert estimate_num_samples(10_000) == 1000 - assert estimate_num_samples(20_000) == 1681 - assert estimate_num_samples(50_000) == 3343 - assert estimate_num_samples(500_000) == 10_000 - - -def test_sample_indices(): - indices = sample_indices(10) - assert len(indices) > 0 - assert indices[0] == 0 - assert indices[-1] == 9 - assert len(indices) == estimate_num_samples(10) - - -@patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy) -def test_sample_images(mock_load): - image_paths = [f"image_{i}.jpg" for i in range(100)] - images = sample_images(image_paths) - assert isinstance(images, np.ndarray) - assert images.shape[1:] == (3, 32, 32) - assert images.dtype == np.uint8 - assert len(images) == estimate_num_samples(100) - - -def test_get_feature_stats_images(): - data = np.random.rand(100, 3, 32, 32) - stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) - assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats - np.testing.assert_equal(stats["count"], np.array([100])) - assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape - - -def test_get_feature_stats_axis_0_keepdims(sample_array): - expected = { - "min": np.array([[1, 2, 3]]), - "max": np.array([[7, 8, 9]]), - "mean": np.array([[4.0, 5.0, 6.0]]), - "std": np.array([[2.44948974, 2.44948974, 2.44948974]]), - "count": np.array([3]), - } - result = get_feature_stats(sample_array, axis=(0,), keepdims=True) - for key in expected: - np.testing.assert_allclose(result[key], expected[key]) - - -def test_get_feature_stats_axis_1(sample_array): - expected = { - "min": np.array([1, 4, 7]), - "max": np.array([3, 6, 9]), - "mean": np.array([2.0, 5.0, 8.0]), - "std": np.array([0.81649658, 0.81649658, 0.81649658]), - "count": np.array([3]), - } - result = get_feature_stats(sample_array, axis=(1,), keepdims=False) - for key in expected: - np.testing.assert_allclose(result[key], expected[key]) - - -def test_get_feature_stats_no_axis(sample_array): - expected = { - "min": np.array(1), - "max": np.array(9), - "mean": np.array(5.0), - "std": np.array(2.5819889), - "count": np.array([3]), - } - result = get_feature_stats(sample_array, axis=None, keepdims=False) - for key in expected: - np.testing.assert_allclose(result[key], expected[key]) - - -def test_get_feature_stats_empty_array(): - array = np.array([]) - with pytest.raises(ValueError): - get_feature_stats(array, axis=(0,), keepdims=True) - - -def test_get_feature_stats_single_value(): - array = np.array([[1337]]) - result = get_feature_stats(array, axis=None, keepdims=True) - np.testing.assert_equal(result["min"], np.array(1337)) - np.testing.assert_equal(result["max"], np.array(1337)) - np.testing.assert_equal(result["mean"], np.array(1337.0)) - np.testing.assert_equal(result["std"], np.array(0.0)) - np.testing.assert_equal(result["count"], np.array([1])) - - -def test_compute_episode_stats(): - episode_data = { - "observation.image": [f"image_{i}.jpg" for i in range(100)], - "observation.state": np.random.rand(100, 10), - } - features = { - "observation.image": {"dtype": "image"}, - "observation.state": {"dtype": "numeric"}, - } - - with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy): - stats = compute_episode_stats(episode_data, features) - - assert "observation.image" in stats and "observation.state" in stats - assert stats["observation.image"]["count"].item() == 100 - assert stats["observation.state"]["count"].item() == 100 - assert stats["observation.image"]["mean"].shape == (3, 1, 1) - - -def test_assert_type_and_shape_valid(): - valid_stats = [ - { - "feature1": { - "min": np.array([1.0]), - "max": np.array([10.0]), - "mean": np.array([5.0]), - "std": np.array([2.0]), - "count": np.array([1]), - } - } - ] - _assert_type_and_shape(valid_stats) - - -def test_assert_type_and_shape_invalid_type(): - invalid_stats = [ - { - "feature1": { - "min": [1.0], # Not a numpy array - "max": np.array([10.0]), - "mean": np.array([5.0]), - "std": np.array([2.0]), - "count": np.array([1]), - } - } - ] - with pytest.raises(ValueError, match="Stats must be composed of numpy array"): - _assert_type_and_shape(invalid_stats) - - -def test_assert_type_and_shape_invalid_shape(): - invalid_stats = [ - { - "feature1": { - "count": np.array([1, 2]), # Wrong shape - } - } - ] - with pytest.raises(ValueError, match=r"Shape of 'count' must be \(1\)"): - _assert_type_and_shape(invalid_stats) - - -def test_aggregate_feature_stats(): - stats_ft_list = [ - { - "min": np.array([1.0]), - "max": np.array([10.0]), - "mean": np.array([5.0]), - "std": np.array([2.0]), - "count": np.array([1]), - }, - { - "min": np.array([2.0]), - "max": np.array([12.0]), - "mean": np.array([6.0]), - "std": np.array([2.5]), - "count": np.array([1]), - }, - ] - result = aggregate_feature_stats(stats_ft_list) - np.testing.assert_allclose(result["min"], np.array([1.0])) - np.testing.assert_allclose(result["max"], np.array([12.0])) - np.testing.assert_allclose(result["mean"], np.array([5.5])) - np.testing.assert_allclose(result["std"], np.array([2.318405]), atol=1e-6) - np.testing.assert_allclose(result["count"], np.array([2])) - - -def test_aggregate_stats(): - all_stats = [ - { - "observation.image": { - "min": [1, 2, 3], - "max": [10, 20, 30], - "mean": [5.5, 10.5, 15.5], - "std": [2.87, 5.87, 8.87], - "count": 10, - }, - "observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10}, - "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6}, - }, - { - "observation.image": { - "min": [2, 1, 0], - "max": [15, 10, 5], - "mean": [8.5, 5.5, 2.5], - "std": [3.42, 2.42, 1.42], - "count": 15, - }, - "observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15}, - "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5}, - }, - ] - - expected_agg_stats = { - "observation.image": { - "min": [1, 1, 0], - "max": [15, 20, 30], - "mean": [7.3, 7.5, 7.7], - "std": [3.5317, 4.8267, 8.5581], - "count": 25, - }, - "observation.state": { - "min": 1, - "max": 15, - "mean": 7.3, - "std": 3.5317, - "count": 25, - }, - "extra_key_0": { - "min": 5, - "max": 25, - "mean": 15.0, - "std": 6.0, - "count": 6, - }, - "extra_key_1": { - "min": 0, - "max": 20, - "mean": 10.0, - "std": 5.0, - "count": 5, - }, - } - - # cast to numpy - for ep_stats in all_stats: - for fkey, stats in ep_stats.items(): - for k in stats: - stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) - if fkey == "observation.image" and k != "count": - stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels - else: - stats[k] = stats[k].reshape(1) - - # cast to numpy - for fkey, stats in expected_agg_stats.items(): - for k in stats: - stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) - if fkey == "observation.image" and k != "count": - stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels - else: - stats[k] = stats[k].reshape(1) - - results = aggregate_stats(all_stats) - - for fkey in expected_agg_stats: - np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"]) - np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"]) - np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"]) - np.testing.assert_allclose( - results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04 - ) - np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"]) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py deleted file mode 100644 index d3b78dd..0000000 --- a/tests/datasets/test_datasets.py +++ /dev/null @@ -1,573 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import logging -import re -from copy import deepcopy -from itertools import chain -from pathlib import Path - -import numpy as np -import pytest -import torch -from huggingface_hub import HfApi -from PIL import Image -from safetensors.torch import load_file - -import lerobot -from lerobot.configs.default import DatasetConfig -from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.factory import make_dataset -from lerobot.datasets.image_writer import image_array_to_pil_image -from lerobot.datasets.lerobot_dataset import ( - LeRobotDataset, - MultiLeRobotDataset, -) -from lerobot.datasets.utils import ( - create_branch, - flatten_dict, - unflatten_dict, -) -from lerobot.envs.factory import make_env_config -from lerobot.policies.factory import make_policy_config -from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID -from tests.utils import require_x86_64_kernel - - -@pytest.fixture -def image_dataset(tmp_path, empty_lerobot_dataset_factory): - features = { - "image": { - "dtype": "image", - "shape": DUMMY_CHW, - "names": [ - "channels", - "height", - "width", - ], - } - } - return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - - -def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): - """ - Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated - objects have the same sets of attributes defined. - """ - # Instantiate both ways - features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} - root_create = tmp_path / "create" - dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create) - - root_init = tmp_path / "init" - dataset_init = lerobot_dataset_factory(root=root_init) - - init_attr = set(vars(dataset_init).keys()) - create_attr = set(vars(dataset_create).keys()) - - assert init_attr == create_attr - - -def test_dataset_initialization(tmp_path, lerobot_dataset_factory): - kwargs = { - "repo_id": DUMMY_REPO_ID, - "total_episodes": 10, - "total_frames": 400, - "episodes": [2, 5, 6], - } - dataset = lerobot_dataset_factory(root=tmp_path / "test", **kwargs) - - assert dataset.repo_id == kwargs["repo_id"] - assert dataset.meta.total_episodes == kwargs["total_episodes"] - assert dataset.meta.total_frames == kwargs["total_frames"] - assert dataset.episodes == kwargs["episodes"] - assert dataset.num_episodes == len(kwargs["episodes"]) - assert dataset.num_frames == len(dataset) - - -def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - with pytest.raises( - ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n" - ): - dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task") - - -def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - with pytest.raises( - ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n" - ): - dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task") - - -def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - with pytest.raises( - ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n" - ): - dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task") - - -def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - with pytest.raises( - ValueError, - match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"), - ): - dataset.add_frame({"state": torch.randn(1)}, task="Dummy task") - - -def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - with pytest.raises( - ValueError, - match=re.escape( - "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '' provided instead.\n" - ), - ): - dataset.add_frame({"state": 1.0}, task="Dummy task") - - -def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - with pytest.raises( - ValueError, - match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"), - ): - dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task") - - -def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - with pytest.raises( - ValueError, - match=re.escape( - "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '' provided instead.\n" - ), - ): - dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task") - - -def test_add_frame(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(1)}, task="Dummy task") - dataset.save_episode() - - assert len(dataset) == 1 - assert dataset[0]["task"] == "Dummy task" - assert dataset[0]["task_index"] == 0 - assert dataset[0]["state"].ndim == 0 - - -def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2)}, task="Dummy task") - dataset.save_episode() - - assert dataset[0]["state"].shape == torch.Size([2]) - - -def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task") - dataset.save_episode() - - assert dataset[0]["state"].shape == torch.Size([2, 4]) - - -def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task") - dataset.save_episode() - - assert dataset[0]["state"].shape == torch.Size([2, 4, 3]) - - -def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task") - dataset.save_episode() - - assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5]) - - -def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task") - dataset.save_episode() - - assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1]) - - -def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task") - dataset.save_episode() - - assert dataset[0]["state"].ndim == 0 - - -def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory): - features = {"caption": {"dtype": "string", "shape": (1,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task") - dataset.save_episode() - - assert dataset[0]["caption"] == "Dummy caption" - - -def test_add_frame_image_wrong_shape(image_dataset): - dataset = image_dataset - with pytest.raises( - ValueError, - match=re.escape( - "The feature 'image' of shape '(3, 128, 96)' does not have the expected shape '(3, 96, 128)' or '(96, 128, 3)'.\n" - ), - ): - c, h, w = DUMMY_CHW - dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task") - - -def test_add_frame_image_wrong_range(image_dataset): - """This test will display the following error message from a thread: - ``` - Error writing image ...test_add_frame_image_wrong_ran0/test/images/image/episode_000000/frame_000000.png: - The image data type is float, which requires values in the range [0.0, 1.0]. However, the provided range is [0.009678772038470007, 254.9776492089887]. - Please adjust the range or provide a uint8 image with values in the range [0, 255] - ``` - Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`. - """ - dataset = image_dataset - dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task") - with pytest.raises(FileNotFoundError): - dataset.save_episode() - - -def test_add_frame_image(image_dataset): - dataset = image_dataset - dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task") - dataset.save_episode() - - assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) - - -def test_add_frame_image_h_w_c(image_dataset): - dataset = image_dataset - dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task") - dataset.save_episode() - - assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) - - -def test_add_frame_image_uint8(image_dataset): - dataset = image_dataset - image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) - dataset.add_frame({"image": image}, task="Dummy task") - dataset.save_episode() - - assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) - - -def test_add_frame_image_pil(image_dataset): - dataset = image_dataset - image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) - dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task") - dataset.save_episode() - - assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) - - -def test_image_array_to_pil_image_wrong_range_float_0_255(): - image = np.random.rand(*DUMMY_HWC) * 255 - with pytest.raises(ValueError): - image_array_to_pil_image(image) - - -# TODO(aliberts): -# - [ ] test various attributes & state from init and create -# - [ ] test init with episodes and check num_frames -# - [ ] test add_episode -# - [ ] test push_to_hub -# - [ ] test smaller methods - - -@pytest.mark.parametrize( - "env_name, repo_id, policy_name", - # Single dataset - lerobot.env_dataset_policy_triplets, - # Multi-dataset - # TODO after fix multidataset - # + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")], -) -def test_factory(env_name, repo_id, policy_name): - """ - Tests that: - - we can create a dataset with the factory. - - for a commonly used set of data keys, the data dimensions are correct. - """ - cfg = TrainPipelineConfig( - # TODO(rcadene, aliberts): remove dataset download - dataset=DatasetConfig(repo_id=repo_id, episodes=[0]), - env=make_env_config(env_name), - policy=make_policy_config(policy_name, push_to_hub=False), - ) - cfg.validate() - - dataset = make_dataset(cfg) - delta_timestamps = dataset.delta_timestamps - camera_keys = dataset.meta.camera_keys - - item = dataset[0] - - keys_ndim_required = [ - ("action", 1, True), - ("episode_index", 0, True), - ("frame_index", 0, True), - ("timestamp", 0, True), - # TODO(rcadene): should we rename it agent_pos? - ("observation.state", 1, True), - ("next.reward", 0, False), - ("next.done", 0, False), - ] - - # test number of dimensions - for key, ndim, required in keys_ndim_required: - if key not in item: - if required: - assert key in item, f"{key}" - else: - logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.') - continue - - if delta_timestamps is not None and key in delta_timestamps: - assert item[key].ndim == ndim + 1, f"{key}" - assert item[key].shape[0] == len(delta_timestamps[key]), f"{key}" - else: - assert item[key].ndim == ndim, f"{key}" - - if key in camera_keys: - assert item[key].dtype == torch.float32, f"{key}" - # TODO(rcadene): we assume for now that image normalization takes place in the model - assert item[key].max() <= 1.0, f"{key}" - assert item[key].min() >= 0.0, f"{key}" - - if delta_timestamps is not None and key in delta_timestamps: - # test t,c,h,w - assert item[key].shape[1] == 3, f"{key}" - else: - # test c,h,w - assert item[key].shape[0] == 3, f"{key}" - - if delta_timestamps is not None: - # test missing keys in delta_timestamps - for key in delta_timestamps: - assert key in item, f"{key}" - - -# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds. -@pytest.mark.skip("TODO after fix multidataset") -def test_multidataset_frames(): - """Check that all dataset frames are incorporated.""" - # Note: use the image variants of the dataset to make the test approx 3x faster. - # Note: We really do need three repo_ids here as at some point this caught an issue with the chaining - # logic that wouldn't be caught with two repo IDs. - repo_ids = [ - "lerobot/aloha_sim_insertion_human_image", - "lerobot/aloha_sim_transfer_cube_human_image", - "lerobot/aloha_sim_insertion_scripted_image", - ] - sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids] - dataset = MultiLeRobotDataset(repo_ids) - assert len(dataset) == sum(len(d) for d in sub_datasets) - assert dataset.num_frames == sum(d.num_frames for d in sub_datasets) - assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets) - - # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and - # check they match. - expected_dataset_indices = [] - for i, sub_dataset in enumerate(sub_datasets): - expected_dataset_indices.extend([i] * len(sub_dataset)) - - for expected_dataset_index, sub_dataset_item, dataset_item in zip( - expected_dataset_indices, chain(*sub_datasets), dataset, strict=True - ): - dataset_index = dataset_item.pop("dataset_index") - assert dataset_index == expected_dataset_index - assert sub_dataset_item.keys() == dataset_item.keys() - for k in sub_dataset_item: - assert torch.equal(sub_dataset_item[k], dataset_item[k]) - - -# TODO(aliberts): Move to more appropriate location -def test_flatten_unflatten_dict(): - d = { - "obs": { - "min": 0, - "max": 1, - "mean": 2, - "std": 3, - }, - "action": { - "min": 4, - "max": 5, - "mean": 6, - "std": 7, - }, - } - - original_d = deepcopy(d) - d = unflatten_dict(flatten_dict(d)) - - # test equality between nested dicts - assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}" - - -@pytest.mark.parametrize( - "repo_id", - [ - "lerobot/pusht", - "lerobot/aloha_sim_insertion_human", - "lerobot/xarm_lift_medium", - # (michel-aractingi) commenting the two datasets from openx as test is failing - # "lerobot/nyu_franka_play_dataset", - # "lerobot/cmu_stretch", - ], -) -@require_x86_64_kernel -def test_backward_compatibility(repo_id): - """The artifacts for this test have been generated by `tests/artifacts/datasets/save_dataset_to_safetensors.py`.""" - - # TODO(rcadene, aliberts): remove dataset download - dataset = LeRobotDataset(repo_id, episodes=[0]) - - test_dir = Path("tests/artifacts/datasets") / repo_id - - def load_and_compare(i): - new_frame = dataset[i] # noqa: B023 - old_frame = load_file(test_dir / f"frame_{i}.safetensors") # noqa: B023 - - # ignore language instructions (if exists) in language conditioned datasets - # TODO (michel-aractingi): transform language obs to language embeddings via tokenizer - new_frame.pop("language_instruction", None) - old_frame.pop("language_instruction", None) - new_frame.pop("task", None) - old_frame.pop("task", None) - - # Remove task_index to allow for backward compatibility - # TODO(rcadene): remove when new features have been generated - if "task_index" not in old_frame: - del new_frame["task_index"] - - new_keys = set(new_frame.keys()) - old_keys = set(old_frame.keys()) - assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same" - - for key in new_frame: - assert torch.isclose(new_frame[key], old_frame[key]).all(), ( - f"{key=} for index={i} does not contain the same value" - ) - - # test2 first frames of first episode - i = dataset.episode_data_index["from"][0].item() - load_and_compare(i) - load_and_compare(i + 1) - - # test 2 frames at the middle of first episode - i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) - load_and_compare(i) - load_and_compare(i + 1) - - # test 2 last frames of first episode - i = dataset.episode_data_index["to"][0].item() - load_and_compare(i - 2) - load_and_compare(i - 1) - - # TODO(rcadene): Enable testing on second and last episode - # We currently cant because our test dataset only contains the first episode - - # # test 2 first frames of second episode - # i = dataset.episode_data_index["from"][1].item() - # load_and_compare(i) - # load_and_compare(i + 1) - - # # test 2 last frames of second episode - # i = dataset.episode_data_index["to"][1].item() - # load_and_compare(i - 2) - # load_and_compare(i - 1) - - # # test 2 last frames of last episode - # i = dataset.episode_data_index["to"][-1].item() - # load_and_compare(i - 2) - # load_and_compare(i - 1) - - -@pytest.mark.skip("Requires internet access") -def test_create_branch(): - api = HfApi() - - repo_id = "cadene/test_create_branch" - repo_type = "dataset" - branch = "test" - ref = f"refs/heads/{branch}" - - # Prepare a repo with a test branch - api.delete_repo(repo_id, repo_type=repo_type, missing_ok=True) - api.create_repo(repo_id, repo_type=repo_type) - create_branch(repo_id, repo_type=repo_type, branch=branch) - - # Make sure the test branch exists - branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches - refs = [branch.ref for branch in branches] - assert ref in refs - - # Overwrite it - create_branch(repo_id, repo_type=repo_type, branch=branch) - - # Clean - api.delete_repo(repo_id, repo_type=repo_type) - - -def test_dataset_feature_with_forward_slash_raises_error(): - # make sure dir does not exist - from lerobot.constants import HF_LEROBOT_HOME - - dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash" - # make sure does not exist - if dataset_dir.exists(): - dataset_dir.rmdir() - - with pytest.raises(ValueError): - LeRobotDataset.create( - repo_id="lerobot/test/with/slash", - fps=30, - features={"a/b": {"dtype": "float32", "shape": 2, "names": None}}, - ) diff --git a/tests/datasets/test_delta_timestamps.py b/tests/datasets/test_delta_timestamps.py deleted file mode 100644 index 786b90c..0000000 --- a/tests/datasets/test_delta_timestamps.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from itertools import accumulate - -import datasets -import numpy as np -import pyarrow.compute as pc -import pytest -import torch - -from lerobot.datasets.utils import ( - check_delta_timestamps, - check_timestamps_sync, - get_delta_indices, -) -from tests.fixtures.constants import DUMMY_MOTOR_FEATURES - - -def calculate_total_episode( - hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True -) -> dict[str, torch.Tensor]: - episode_indices = sorted(hf_dataset.unique("episode_index")) - total_episodes = len(episode_indices) - if raise_if_not_contiguous and episode_indices != list(range(total_episodes)): - raise ValueError("episode_index values are not sorted and contiguous.") - return total_episodes - - -def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]: - episode_lengths = [] - table = hf_dataset.data.table - total_episodes = calculate_total_episode(hf_dataset) - for ep_idx in range(total_episodes): - ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) - episode_lengths.insert(ep_idx, len(ep_table)) - - cumulative_lengths = list(accumulate(episode_lengths)) - return { - "from": np.array([0] + cumulative_lengths[:-1], dtype=np.int64), - "to": np.array(cumulative_lengths, dtype=np.int64), - } - - -@pytest.fixture(scope="module") -def synced_timestamps_factory(hf_dataset_factory): - def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - hf_dataset = hf_dataset_factory(fps=fps) - timestamps = torch.stack(hf_dataset["timestamp"]).numpy() - episode_indices = torch.stack(hf_dataset["episode_index"]).numpy() - episode_data_index = calculate_episode_data_index(hf_dataset) - return timestamps, episode_indices, episode_data_index - - return _create_synced_timestamps - - -@pytest.fixture(scope="module") -def unsynced_timestamps_factory(synced_timestamps_factory): - def _create_unsynced_timestamps( - fps: int = 30, tolerance_s: float = 1e-4 - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps) - timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance - return timestamps, episode_indices, episode_data_index - - return _create_unsynced_timestamps - - -@pytest.fixture(scope="module") -def slightly_off_timestamps_factory(synced_timestamps_factory): - def _create_slightly_off_timestamps( - fps: int = 30, tolerance_s: float = 1e-4 - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps) - timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance - return timestamps, episode_indices, episode_data_index - - return _create_slightly_off_timestamps - - -@pytest.fixture(scope="module") -def valid_delta_timestamps_factory(): - def _create_valid_delta_timestamps( - fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10) - ) -> dict: - delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys} - return delta_timestamps - - return _create_valid_delta_timestamps - - -@pytest.fixture(scope="module") -def invalid_delta_timestamps_factory(valid_delta_timestamps_factory): - def _create_invalid_delta_timestamps( - fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES - ) -> dict: - delta_timestamps = valid_delta_timestamps_factory(fps, keys) - # Modify a single timestamp just outside tolerance - for key in keys: - delta_timestamps[key][3] += tolerance_s * 1.1 - return delta_timestamps - - return _create_invalid_delta_timestamps - - -@pytest.fixture(scope="module") -def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory): - def _create_slightly_off_delta_timestamps( - fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES - ) -> dict: - delta_timestamps = valid_delta_timestamps_factory(fps, keys) - # Modify a single timestamp just inside tolerance - for key in delta_timestamps: - delta_timestamps[key][3] += tolerance_s * 0.9 - delta_timestamps[key][-3] += tolerance_s * 0.9 - return delta_timestamps - - return _create_slightly_off_delta_timestamps - - -@pytest.fixture(scope="module") -def delta_indices_factory(): - def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict: - return {key: list(range(*min_max_range)) for key in keys} - - return _delta_indices - - -def test_check_timestamps_sync_synced(synced_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = synced_timestamps_factory(fps) - result = check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=ep_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - -def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s) - with pytest.raises(ValueError): - check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=ep_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - - -def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s) - result = check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=ep_data_index, - fps=fps, - tolerance_s=tolerance_s, - raise_value_error=False, - ) - assert result is False - - -def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s) - result = check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=ep_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - -def test_check_timestamps_sync_single_timestamp(): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx = np.array([0.0]), np.array([0]) - episode_data_index = {"to": np.array([1]), "from": np.array([0])} - result = check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=episode_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - -def test_check_delta_timestamps_valid(valid_delta_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - valid_delta_timestamps = valid_delta_timestamps_factory(fps) - result = check_delta_timestamps( - delta_timestamps=valid_delta_timestamps, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - -def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s) - result = check_delta_timestamps( - delta_timestamps=slightly_off_delta_timestamps, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - -def test_check_delta_timestamps_invalid(invalid_delta_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s) - with pytest.raises(ValueError): - check_delta_timestamps( - delta_timestamps=invalid_delta_timestamps, - fps=fps, - tolerance_s=tolerance_s, - ) - - -def test_check_delta_timestamps_invalid_no_exception(invalid_delta_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s) - result = check_delta_timestamps( - delta_timestamps=invalid_delta_timestamps, - fps=fps, - tolerance_s=tolerance_s, - raise_value_error=False, - ) - assert result is False - - -def test_check_delta_timestamps_empty(): - delta_timestamps = {} - fps = 30 - tolerance_s = 1e-4 - result = check_delta_timestamps( - delta_timestamps=delta_timestamps, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - -def test_delta_indices(valid_delta_timestamps_factory, delta_indices_factory): - fps = 50 - min_max_range = (-100, 100) - delta_timestamps = valid_delta_timestamps_factory(fps, min_max_range=min_max_range) - expected_delta_indices = delta_indices_factory(min_max_range=min_max_range) - actual_delta_indices = get_delta_indices(delta_timestamps, fps) - assert expected_delta_indices == actual_delta_indices diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py deleted file mode 100644 index 3ab93cb..0000000 --- a/tests/datasets/test_image_transforms.py +++ /dev/null @@ -1,382 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import torch -from packaging import version -from safetensors.torch import load_file -from torchvision.transforms import v2 -from torchvision.transforms.v2 import functional as F # noqa: N812 - -from lerobot.datasets.transforms import ( - ImageTransformConfig, - ImageTransforms, - ImageTransformsConfig, - RandomSubsetApply, - SharpnessJitter, - make_transform_from_config, -) -from lerobot.scripts.visualize_image_transforms import ( - save_all_transforms, - save_each_transform, -) -from lerobot.utils.random_utils import seeded_context -from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR -from tests.utils import require_x86_64_kernel - - -@pytest.fixture -def color_jitters(): - return [ - v2.ColorJitter(brightness=0.5), - v2.ColorJitter(contrast=0.5), - v2.ColorJitter(saturation=0.5), - ] - - -@pytest.fixture -def single_transforms(): - return load_file(ARTIFACT_DIR / "single_transforms.safetensors") - - -@pytest.fixture -def img_tensor(single_transforms): - return single_transforms["original_frame"] - - -@pytest.fixture -def default_transforms(): - return load_file(ARTIFACT_DIR / "default_transforms.safetensors") - - -def test_get_image_transforms_no_transform_enable_false(img_tensor_factory): - img_tensor = img_tensor_factory() - tf_cfg = ImageTransformsConfig() # default is enable=False - tf_actual = ImageTransforms(tf_cfg) - torch.testing.assert_close(tf_actual(img_tensor), img_tensor) - - -def test_get_image_transforms_no_transform_max_num_transforms_0(img_tensor_factory): - img_tensor = img_tensor_factory() - tf_cfg = ImageTransformsConfig(enable=True, max_num_transforms=0) - tf_actual = ImageTransforms(tf_cfg) - torch.testing.assert_close(tf_actual(img_tensor), img_tensor) - - -@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) -def test_get_image_transforms_brightness(img_tensor_factory, min_max): - img_tensor = img_tensor_factory() - tf_cfg = ImageTransformsConfig( - enable=True, - tfs={"brightness": ImageTransformConfig(type="ColorJitter", kwargs={"brightness": min_max})}, - ) - tf_actual = ImageTransforms(tf_cfg) - tf_expected = v2.ColorJitter(brightness=min_max) - torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) - - -@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) -def test_get_image_transforms_contrast(img_tensor_factory, min_max): - img_tensor = img_tensor_factory() - tf_cfg = ImageTransformsConfig( - enable=True, tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})} - ) - tf_actual = ImageTransforms(tf_cfg) - tf_expected = v2.ColorJitter(contrast=min_max) - torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) - - -@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) -def test_get_image_transforms_saturation(img_tensor_factory, min_max): - img_tensor = img_tensor_factory() - tf_cfg = ImageTransformsConfig( - enable=True, - tfs={"saturation": ImageTransformConfig(type="ColorJitter", kwargs={"saturation": min_max})}, - ) - tf_actual = ImageTransforms(tf_cfg) - tf_expected = v2.ColorJitter(saturation=min_max) - torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) - - -@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)]) -def test_get_image_transforms_hue(img_tensor_factory, min_max): - img_tensor = img_tensor_factory() - tf_cfg = ImageTransformsConfig( - enable=True, tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})} - ) - tf_actual = ImageTransforms(tf_cfg) - tf_expected = v2.ColorJitter(hue=min_max) - torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) - - -@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) -def test_get_image_transforms_sharpness(img_tensor_factory, min_max): - img_tensor = img_tensor_factory() - tf_cfg = ImageTransformsConfig( - enable=True, - tfs={"sharpness": ImageTransformConfig(type="SharpnessJitter", kwargs={"sharpness": min_max})}, - ) - tf_actual = ImageTransforms(tf_cfg) - tf_expected = SharpnessJitter(sharpness=min_max) - torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) - - -def test_get_image_transforms_max_num_transforms(img_tensor_factory): - img_tensor = img_tensor_factory() - tf_cfg = ImageTransformsConfig( - enable=True, - max_num_transforms=5, - tfs={ - "brightness": ImageTransformConfig( - weight=1.0, - type="ColorJitter", - kwargs={"brightness": (0.5, 0.5)}, - ), - "contrast": ImageTransformConfig( - weight=1.0, - type="ColorJitter", - kwargs={"contrast": (0.5, 0.5)}, - ), - "saturation": ImageTransformConfig( - weight=1.0, - type="ColorJitter", - kwargs={"saturation": (0.5, 0.5)}, - ), - "hue": ImageTransformConfig( - weight=1.0, - type="ColorJitter", - kwargs={"hue": (0.5, 0.5)}, - ), - "sharpness": ImageTransformConfig( - weight=1.0, - type="SharpnessJitter", - kwargs={"sharpness": (0.5, 0.5)}, - ), - }, - ) - tf_actual = ImageTransforms(tf_cfg) - tf_expected = v2.Compose( - [ - v2.ColorJitter(brightness=(0.5, 0.5)), - v2.ColorJitter(contrast=(0.5, 0.5)), - v2.ColorJitter(saturation=(0.5, 0.5)), - v2.ColorJitter(hue=(0.5, 0.5)), - SharpnessJitter(sharpness=(0.5, 0.5)), - ] - ) - torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) - - -@require_x86_64_kernel -def test_get_image_transforms_random_order(img_tensor_factory): - out_imgs = [] - img_tensor = img_tensor_factory() - tf_cfg = ImageTransformsConfig( - enable=True, - random_order=True, - tfs={ - "brightness": ImageTransformConfig( - weight=1.0, - type="ColorJitter", - kwargs={"brightness": (0.5, 0.5)}, - ), - "contrast": ImageTransformConfig( - weight=1.0, - type="ColorJitter", - kwargs={"contrast": (0.5, 0.5)}, - ), - "saturation": ImageTransformConfig( - weight=1.0, - type="ColorJitter", - kwargs={"saturation": (0.5, 0.5)}, - ), - "hue": ImageTransformConfig( - weight=1.0, - type="ColorJitter", - kwargs={"hue": (0.5, 0.5)}, - ), - "sharpness": ImageTransformConfig( - weight=1.0, - type="SharpnessJitter", - kwargs={"sharpness": (0.5, 0.5)}, - ), - }, - ) - tf = ImageTransforms(tf_cfg) - - with seeded_context(1338): - for _ in range(10): - out_imgs.append(tf(img_tensor)) - - tmp_img_tensor = img_tensor - for sub_tf in tf.tf.selected_transforms: - tmp_img_tensor = sub_tf(tmp_img_tensor) - torch.testing.assert_close(tmp_img_tensor, out_imgs[-1]) - - for i in range(1, len(out_imgs)): - with pytest.raises(AssertionError): - torch.testing.assert_close(out_imgs[0], out_imgs[i]) - - -@pytest.mark.parametrize( - "tf_type, tf_name, min_max_values", - [ - ("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]), - ("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]), - ("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]), - ("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]), - ("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]), - ], -) -def test_backward_compatibility_single_transforms( - img_tensor, tf_type, tf_name, min_max_values, single_transforms -): - for min_max in min_max_values: - tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max}) - tf = make_transform_from_config(tf_cfg) - actual = tf(img_tensor) - key = f"{tf_name}_{min_max[0]}_{min_max[1]}" - expected = single_transforms[key] - torch.testing.assert_close(actual, expected) - - -@require_x86_64_kernel -@pytest.mark.skipif( - version.parse(torch.__version__) < version.parse("2.7.0"), - reason="Test artifacts were generated with PyTorch >= 2.7.0 which has different multinomial behavior", -) -def test_backward_compatibility_default_config(img_tensor, default_transforms): - # NOTE: PyTorch versions have different randomness, it might break this test. - # See this PR: https://github.com/huggingface/lerobot/pull/1127. - - cfg = ImageTransformsConfig(enable=True) - default_tf = ImageTransforms(cfg) - - with seeded_context(1337): - actual = default_tf(img_tensor) - - expected = default_transforms["default"] - - torch.testing.assert_close(actual, expected) - - -@pytest.mark.parametrize("p", [[0, 1], [1, 0]]) -def test_random_subset_apply_single_choice(img_tensor_factory, p): - img_tensor = img_tensor_factory() - flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] - random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False) - actual = random_choice(img_tensor) - - p_horz, _ = p - if p_horz: - torch.testing.assert_close(actual, F.horizontal_flip(img_tensor)) - else: - torch.testing.assert_close(actual, F.vertical_flip(img_tensor)) - - -def test_random_subset_apply_random_order(img_tensor_factory): - img_tensor = img_tensor_factory() - flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] - random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True) - # We can't really check whether the transforms are actually applied in random order. However, - # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform - # applies them in random order, we can use a fixed order to compute the expected value. - actual = random_order(img_tensor) - expected = v2.Compose(flips)(img_tensor) - torch.testing.assert_close(actual, expected) - - -def test_random_subset_apply_valid_transforms(img_tensor_factory, color_jitters): - img_tensor = img_tensor_factory() - transform = RandomSubsetApply(color_jitters) - output = transform(img_tensor) - assert output.shape == img_tensor.shape - - -def test_random_subset_apply_probability_length_mismatch(color_jitters): - with pytest.raises(ValueError): - RandomSubsetApply(color_jitters, p=[0.5, 0.5]) - - -@pytest.mark.parametrize("n_subset", [0, 5]) -def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset): - with pytest.raises(ValueError): - RandomSubsetApply(color_jitters, n_subset=n_subset) - - -def test_sharpness_jitter_valid_range_tuple(img_tensor_factory): - img_tensor = img_tensor_factory() - tf = SharpnessJitter((0.1, 2.0)) - output = tf(img_tensor) - assert output.shape == img_tensor.shape - - -def test_sharpness_jitter_valid_range_float(img_tensor_factory): - img_tensor = img_tensor_factory() - tf = SharpnessJitter(0.5) - output = tf(img_tensor) - assert output.shape == img_tensor.shape - - -def test_sharpness_jitter_invalid_range_min_negative(): - with pytest.raises(ValueError): - SharpnessJitter((-0.1, 2.0)) - - -def test_sharpness_jitter_invalid_range_max_smaller(): - with pytest.raises(ValueError): - SharpnessJitter((2.0, 0.1)) - - -def test_save_all_transforms(img_tensor_factory, tmp_path): - img_tensor = img_tensor_factory() - tf_cfg = ImageTransformsConfig(enable=True) - n_examples = 3 - - save_all_transforms(tf_cfg, img_tensor, tmp_path, n_examples) - - # Check if the combined transforms directory exists and contains the right files - combined_transforms_dir = tmp_path / "all" - assert combined_transforms_dir.exists(), "Combined transforms directory was not created." - assert any(combined_transforms_dir.iterdir()), ( - "No transformed images found in combined transforms directory." - ) - for i in range(1, n_examples + 1): - assert (combined_transforms_dir / f"{i}.png").exists(), ( - f"Combined transform image {i}.png was not found." - ) - - -def test_save_each_transform(img_tensor_factory, tmp_path): - img_tensor = img_tensor_factory() - tf_cfg = ImageTransformsConfig(enable=True) - n_examples = 3 - - save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples) - - # Check if the transformed images exist for each transform type - transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"] - for transform in transforms: - transform_dir = tmp_path / transform - assert transform_dir.exists(), f"{transform} directory was not created." - assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory." - - # Check for specific files within each transform directory - expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"] - for file_name in expected_files: - assert (transform_dir / file_name).exists(), ( - f"{file_name} was not found in {transform} directory." - ) diff --git a/tests/datasets/test_image_writer.py b/tests/datasets/test_image_writer.py deleted file mode 100644 index 99c8b24..0000000 --- a/tests/datasets/test_image_writer.py +++ /dev/null @@ -1,386 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import queue -import time -from multiprocessing import queues -from unittest.mock import MagicMock, patch - -import numpy as np -import pytest -from PIL import Image - -from lerobot.datasets.image_writer import ( - AsyncImageWriter, - image_array_to_pil_image, - safe_stop_image_writer, - write_image, -) -from tests.fixtures.constants import DUMMY_HWC - -DUMMY_IMAGE = "test_image.png" - - -def test_init_threading(): - writer = AsyncImageWriter(num_processes=0, num_threads=2) - try: - assert writer.num_processes == 0 - assert writer.num_threads == 2 - assert isinstance(writer.queue, queue.Queue) - assert len(writer.threads) == 2 - assert len(writer.processes) == 0 - assert all(t.is_alive() for t in writer.threads) - finally: - writer.stop() - - -def test_init_multiprocessing(): - writer = AsyncImageWriter(num_processes=2, num_threads=2) - try: - assert writer.num_processes == 2 - assert writer.num_threads == 2 - assert isinstance(writer.queue, queues.JoinableQueue) - assert len(writer.threads) == 0 - assert len(writer.processes) == 2 - assert all(p.is_alive() for p in writer.processes) - finally: - writer.stop() - - -def test_zero_threads(): - with pytest.raises(ValueError): - AsyncImageWriter(num_processes=0, num_threads=0) - - -def test_image_array_to_pil_image_float_array_wrong_range_0_255(): - image = np.random.rand(*DUMMY_HWC) * 255 - with pytest.raises(ValueError): - image_array_to_pil_image(image) - - -def test_image_array_to_pil_image_float_array_wrong_range_neg_1_1(): - image = np.random.rand(*DUMMY_HWC) * 2 - 1 - with pytest.raises(ValueError): - image_array_to_pil_image(image) - - -def test_image_array_to_pil_image_rgb(img_array_factory): - img_array = img_array_factory(100, 100) - result_image = image_array_to_pil_image(img_array) - assert isinstance(result_image, Image.Image) - assert result_image.size == (100, 100) - assert result_image.mode == "RGB" - - -def test_image_array_to_pil_image_pytorch_format(img_array_factory): - img_array = img_array_factory(100, 100).transpose(2, 0, 1) - result_image = image_array_to_pil_image(img_array) - assert isinstance(result_image, Image.Image) - assert result_image.size == (100, 100) - assert result_image.mode == "RGB" - - -def test_image_array_to_pil_image_single_channel(img_array_factory): - img_array = img_array_factory(channels=1) - with pytest.raises(NotImplementedError): - image_array_to_pil_image(img_array) - - -def test_image_array_to_pil_image_4_channels(img_array_factory): - img_array = img_array_factory(channels=4) - with pytest.raises(NotImplementedError): - image_array_to_pil_image(img_array) - - -def test_image_array_to_pil_image_float_array(img_array_factory): - img_array = img_array_factory(dtype=np.float32) - result_image = image_array_to_pil_image(img_array) - assert isinstance(result_image, Image.Image) - assert result_image.size == (100, 100) - assert result_image.mode == "RGB" - assert np.array(result_image).dtype == np.uint8 - - -def test_image_array_to_pil_image_uint8_array(img_array_factory): - img_array = img_array_factory(dtype=np.float32) - result_image = image_array_to_pil_image(img_array) - assert isinstance(result_image, Image.Image) - assert result_image.size == (100, 100) - assert result_image.mode == "RGB" - assert np.array(result_image).dtype == np.uint8 - - -def test_write_image_numpy(tmp_path, img_array_factory): - image_array = img_array_factory() - fpath = tmp_path / DUMMY_IMAGE - write_image(image_array, fpath) - assert fpath.exists() - saved_image = np.array(Image.open(fpath)) - assert np.array_equal(image_array, saved_image) - - -def test_write_image_image(tmp_path, img_factory): - image_pil = img_factory() - fpath = tmp_path / DUMMY_IMAGE - write_image(image_pil, fpath) - assert fpath.exists() - saved_image = Image.open(fpath) - assert list(saved_image.getdata()) == list(image_pil.getdata()) - assert np.array_equal(image_pil, saved_image) - - -def test_write_image_exception(tmp_path): - image_array = "invalid data" - fpath = tmp_path / DUMMY_IMAGE - with patch("builtins.print") as mock_print: - write_image(image_array, fpath) - mock_print.assert_called() - assert not fpath.exists() - - -def test_save_image_numpy(tmp_path, img_array_factory): - writer = AsyncImageWriter() - try: - image_array = img_array_factory() - fpath = tmp_path / DUMMY_IMAGE - fpath.parent.mkdir(parents=True, exist_ok=True) - writer.save_image(image_array, fpath) - writer.wait_until_done() - assert fpath.exists() - saved_image = np.array(Image.open(fpath)) - assert np.array_equal(image_array, saved_image) - finally: - writer.stop() - - -def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory): - writer = AsyncImageWriter(num_processes=2, num_threads=2) - try: - image_array = img_array_factory() - fpath = tmp_path / DUMMY_IMAGE - writer.save_image(image_array, fpath) - writer.wait_until_done() - assert fpath.exists() - saved_image = np.array(Image.open(fpath)) - assert np.array_equal(image_array, saved_image) - finally: - writer.stop() - - -def test_save_image_torch(tmp_path, img_tensor_factory): - writer = AsyncImageWriter() - try: - image_tensor = img_tensor_factory() - fpath = tmp_path / DUMMY_IMAGE - fpath.parent.mkdir(parents=True, exist_ok=True) - writer.save_image(image_tensor, fpath) - writer.wait_until_done() - assert fpath.exists() - saved_image = np.array(Image.open(fpath)) - expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) - assert np.array_equal(expected_image, saved_image) - finally: - writer.stop() - - -def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory): - writer = AsyncImageWriter(num_processes=2, num_threads=2) - try: - image_tensor = img_tensor_factory() - fpath = tmp_path / DUMMY_IMAGE - writer.save_image(image_tensor, fpath) - writer.wait_until_done() - assert fpath.exists() - saved_image = np.array(Image.open(fpath)) - expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) - assert np.array_equal(expected_image, saved_image) - finally: - writer.stop() - - -def test_save_image_pil(tmp_path, img_factory): - writer = AsyncImageWriter() - try: - image_pil = img_factory() - fpath = tmp_path / DUMMY_IMAGE - fpath.parent.mkdir(parents=True, exist_ok=True) - writer.save_image(image_pil, fpath) - writer.wait_until_done() - assert fpath.exists() - saved_image = Image.open(fpath) - assert list(saved_image.getdata()) == list(image_pil.getdata()) - finally: - writer.stop() - - -def test_save_image_pil_multiprocessing(tmp_path, img_factory): - writer = AsyncImageWriter(num_processes=2, num_threads=2) - try: - image_pil = img_factory() - fpath = tmp_path / DUMMY_IMAGE - writer.save_image(image_pil, fpath) - writer.wait_until_done() - assert fpath.exists() - saved_image = Image.open(fpath) - assert list(saved_image.getdata()) == list(image_pil.getdata()) - finally: - writer.stop() - - -def test_save_image_invalid_data(tmp_path): - writer = AsyncImageWriter() - try: - image_array = "invalid data" - fpath = tmp_path / DUMMY_IMAGE - fpath.parent.mkdir(parents=True, exist_ok=True) - with patch("builtins.print") as mock_print: - writer.save_image(image_array, fpath) - writer.wait_until_done() - mock_print.assert_called() - assert not fpath.exists() - finally: - writer.stop() - - -def test_save_image_after_stop(tmp_path, img_array_factory): - writer = AsyncImageWriter() - writer.stop() - image_array = img_array_factory() - fpath = tmp_path / DUMMY_IMAGE - writer.save_image(image_array, fpath) - time.sleep(1) - assert not fpath.exists() - - -def test_stop(): - writer = AsyncImageWriter(num_processes=0, num_threads=2) - writer.stop() - assert not any(t.is_alive() for t in writer.threads) - - -def test_stop_multiprocessing(): - writer = AsyncImageWriter(num_processes=2, num_threads=2) - writer.stop() - assert not any(p.is_alive() for p in writer.processes) - - -def test_multiple_stops(): - writer = AsyncImageWriter() - writer.stop() - writer.stop() # Should not raise an exception - assert not any(t.is_alive() for t in writer.threads) - - -def test_multiple_stops_multiprocessing(): - writer = AsyncImageWriter(num_processes=2, num_threads=2) - writer.stop() - writer.stop() # Should not raise an exception - assert not any(t.is_alive() for t in writer.threads) - - -def test_wait_until_done(tmp_path, img_array_factory): - writer = AsyncImageWriter(num_processes=0, num_threads=4) - try: - num_images = 100 - image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)] - fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)] - for image_array, fpath in zip(image_arrays, fpaths, strict=True): - fpath.parent.mkdir(parents=True, exist_ok=True) - writer.save_image(image_array, fpath) - writer.wait_until_done() - for i, fpath in enumerate(fpaths): - assert fpath.exists() - saved_image = np.array(Image.open(fpath)) - assert np.array_equal(saved_image, image_arrays[i]) - finally: - writer.stop() - - -def test_wait_until_done_multiprocessing(tmp_path, img_array_factory): - writer = AsyncImageWriter(num_processes=2, num_threads=2) - try: - num_images = 100 - image_arrays = [img_array_factory() for _ in range(num_images)] - fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)] - for image_array, fpath in zip(image_arrays, fpaths, strict=True): - fpath.parent.mkdir(parents=True, exist_ok=True) - writer.save_image(image_array, fpath) - writer.wait_until_done() - for i, fpath in enumerate(fpaths): - assert fpath.exists() - saved_image = np.array(Image.open(fpath)) - assert np.array_equal(saved_image, image_arrays[i]) - finally: - writer.stop() - - -def test_exception_handling(tmp_path, img_array_factory): - writer = AsyncImageWriter() - try: - image_array = img_array_factory() - with ( - patch.object(writer.queue, "put", side_effect=queue.Full("Queue is full")), - pytest.raises(queue.Full) as exc_info, - ): - writer.save_image(image_array, tmp_path / "test.png") - assert str(exc_info.value) == "Queue is full" - finally: - writer.stop() - - -def test_with_different_image_formats(tmp_path, img_array_factory): - writer = AsyncImageWriter() - try: - image_array = img_array_factory() - formats = ["png", "jpeg", "bmp"] - for fmt in formats: - fpath = tmp_path / f"test_image.{fmt}" - write_image(image_array, fpath) - assert fpath.exists() - finally: - writer.stop() - - -def test_safe_stop_image_writer_decorator(): - class MockDataset: - def __init__(self): - self.image_writer = MagicMock(spec=AsyncImageWriter) - - @safe_stop_image_writer - def function_that_raises_exception(dataset=None): - raise Exception("Test exception") - - dataset = MockDataset() - - with pytest.raises(Exception) as exc_info: - function_that_raises_exception(dataset=dataset) - - assert str(exc_info.value) == "Test exception" - dataset.image_writer.stop.assert_called_once() - - -def test_main_process_time(tmp_path, img_tensor_factory): - writer = AsyncImageWriter() - try: - image_tensor = img_tensor_factory() - fpath = tmp_path / DUMMY_IMAGE - start_time = time.perf_counter() - writer.save_image(image_tensor, fpath) - end_time = time.perf_counter() - time_spent = end_time - start_time - # Might need to adjust this threshold depending on hardware - assert time_spent < 0.01, f"Main process time exceeded threshold: {time_spent}s" - writer.wait_until_done() - assert fpath.exists() - finally: - writer.stop() diff --git a/tests/datasets/test_online_buffer.py b/tests/datasets/test_online_buffer.py deleted file mode 100644 index 887da60..0000000 --- a/tests/datasets/test_online_buffer.py +++ /dev/null @@ -1,282 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.d -from copy import deepcopy -from uuid import uuid4 - -import numpy as np -import pytest -import torch - -from lerobot.datasets.online_buffer import OnlineBuffer, compute_sampler_weights - -# Some constants for OnlineBuffer tests. -data_key = "data" -data_shape = (2, 3) # just some arbitrary > 1D shape -buffer_capacity = 100 -fps = 10 - - -def make_new_buffer( - write_dir: str | None = None, delta_timestamps: dict[str, list[float]] | None = None -) -> tuple[OnlineBuffer, str]: - if write_dir is None: - write_dir = f"/tmp/online_buffer_{uuid4().hex}" - buffer = OnlineBuffer( - write_dir, - data_spec={data_key: {"shape": data_shape, "dtype": np.dtype("float32")}}, - buffer_capacity=buffer_capacity, - fps=fps, - delta_timestamps=delta_timestamps, - ) - return buffer, write_dir - - -def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]: - new_data = { - data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape), - OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes), - OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode), - OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes), - OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes), - } - return new_data - - -def test_non_mutate(): - """Checks that the data provided to the add_data method is copied rather than passed by reference. - - This means that mutating the data in the buffer does not mutate the original data. - - NOTE: If this test fails, it means some of the other tests may be compromised. For example, we can't trust - a success case for `test_write_read`. - """ - buffer, _ = make_new_buffer() - new_data = make_spoof_data_frames(2, buffer_capacity // 4) - new_data_copy = deepcopy(new_data) - buffer.add_data(new_data) - buffer._data[data_key][:] += 1 - assert all(np.array_equal(new_data[k], new_data_copy[k]) for k in new_data) - - -def test_index_error_no_data(): - buffer, _ = make_new_buffer() - with pytest.raises(IndexError): - buffer[0] - - -def test_index_error_with_data(): - buffer, _ = make_new_buffer() - n_frames = buffer_capacity // 2 - new_data = make_spoof_data_frames(1, n_frames) - buffer.add_data(new_data) - with pytest.raises(IndexError): - buffer[n_frames] - with pytest.raises(IndexError): - buffer[-n_frames - 1] - - -@pytest.mark.parametrize("do_reload", [False, True]) -def test_write_read(do_reload: bool): - """Checks that data can be added to the buffer and read back. - - If do_reload we delete the buffer object and load the buffer back from disk before reading. - """ - buffer, write_dir = make_new_buffer() - n_episodes = 2 - n_frames_per_episode = buffer_capacity // 4 - new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode) - buffer.add_data(new_data) - - if do_reload: - del buffer - buffer, _ = make_new_buffer(write_dir) - - assert len(buffer) == n_frames_per_episode * n_episodes - for i, item in enumerate(buffer): - assert all(isinstance(item[k], torch.Tensor) for k in item) - assert np.array_equal(item[data_key].numpy(), new_data[data_key][i]) - - -def test_read_data_key(): - """Tests that data can be added to a buffer and all data for a. specific key can be read back.""" - buffer, _ = make_new_buffer() - n_episodes = 2 - n_frames_per_episode = buffer_capacity // 4 - new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode) - buffer.add_data(new_data) - - data_from_buffer = buffer.get_data_by_key(data_key) - assert isinstance(data_from_buffer, torch.Tensor) - assert np.array_equal(data_from_buffer.numpy(), new_data[data_key]) - - -def test_fifo(): - """Checks that if data is added beyond the buffer capacity, we discard the oldest data first.""" - buffer, _ = make_new_buffer() - n_frames_per_episode = buffer_capacity // 4 - n_episodes = 3 - new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode) - buffer.add_data(new_data) - n_more_episodes = 2 - # Developer sanity check (in case someone changes the global `buffer_capacity`). - assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, ( - "Something went wrong with the test code." - ) - more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode) - buffer.add_data(more_new_data) - assert len(buffer) == buffer_capacity, "The buffer should be full." - - expected_data = {} - for k in new_data: - # Concatenate, left-truncate, then roll, to imitate the cyclical FIFO pattern in OnlineBuffer. - expected_data[k] = np.roll( - np.concatenate([new_data[k], more_new_data[k]])[-buffer_capacity:], - shift=len(new_data[k]) + len(more_new_data[k]) - buffer_capacity, - axis=0, - ) - - for i, item in enumerate(buffer): - assert all(isinstance(item[k], torch.Tensor) for k in item) - assert np.array_equal(item[data_key].numpy(), expected_data[data_key][i]) - - -def test_delta_timestamps_within_tolerance(): - """Check that getting an item with delta_timestamps within tolerance succeeds. - - Note: Copied from `test_datasets.py::test_load_previous_and_future_frames_within_tolerance`. - """ - # Sanity check on global fps as we are assuming it is 10 here. - assert fps == 10, "This test assumes fps==10" - buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.139]}) - new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) - buffer.add_data(new_data) - buffer.tolerance_s = 0.04 - item = buffer[2] - data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"] - torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values") - assert not is_pad.any(), "Unexpected padding detected" - - -def test_delta_timestamps_outside_tolerance_inside_episode_range(): - """Check that getting an item with delta_timestamps outside of tolerance fails. - - We expect it to fail if and only if the requested timestamps are within the episode range. - - Note: Copied from - `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_inside_episode_range` - """ - # Sanity check on global fps as we are assuming it is 10 here. - assert fps == 10, "This test assumes fps==10" - buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.141]}) - new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) - buffer.add_data(new_data) - buffer.tolerance_s = 0.04 - with pytest.raises(AssertionError): - buffer[2] - - -def test_delta_timestamps_outside_tolerance_outside_episode_range(): - """Check that copy-padding of timestamps outside of the episode range works. - - Note: Copied from - `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_outside_episode_range` - """ - # Sanity check on global fps as we are assuming it is 10 here. - assert fps == 10, "This test assumes fps==10" - buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.3, -0.24, 0, 0.26, 0.3]}) - new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) - buffer.add_data(new_data) - buffer.tolerance_s = 0.04 - item = buffer[2] - data, is_pad = item["index"], item["index_is_pad"] - assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" - assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), ( - "Padding does not match expected values" - ) - - -# Arbitrarily set small dataset sizes, making sure to have uneven sizes. -@pytest.mark.parametrize("offline_dataset_size", [1, 6]) -@pytest.mark.parametrize("online_dataset_size", [0, 4]) -@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0]) -def test_compute_sampler_weights_trivial( - lerobot_dataset_factory, - tmp_path, - offline_dataset_size: int, - online_dataset_size: int, - online_sampling_ratio: float, -): - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size) - online_dataset, _ = make_new_buffer() - if online_dataset_size > 0: - online_dataset.add_data( - make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2) - ) - - weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio - ) - if offline_dataset_size == 0 or online_dataset_size == 0: - expected_weights = torch.ones(offline_dataset_size + online_dataset_size) - elif online_sampling_ratio == 0: - expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]) - elif online_sampling_ratio == 1: - expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]) - expected_weights /= expected_weights.sum() - torch.testing.assert_close(weights, expected_weights) - - -def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path): - # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) - online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) - online_sampling_ratio = 0.8 - weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio - ) - torch.testing.assert_close( - weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) - ) - - -def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path): - # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) - online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) - weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1 - ) - torch.testing.assert_close( - weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]) - ) - - -def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path): - """Note: test copied from test_sampler.""" - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2) - online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) - - weights = compute_sampler_weights( - offline_dataset, - offline_drop_n_last_frames=1, - online_dataset=online_dataset, - online_sampling_ratio=0.5, - online_drop_n_last_frames=1, - ) - torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])) diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py deleted file mode 100644 index 94576a3..0000000 --- a/tests/datasets/test_sampler.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from datasets import Dataset - -from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index -from lerobot.datasets.sampler import EpisodeAwareSampler -from lerobot.datasets.utils import ( - hf_transform_to_torch, -) - - -def test_drop_n_first_frames(): - dataset = Dataset.from_dict( - { - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - "index": [0, 1, 2, 3, 4, 5], - "episode_index": [0, 0, 1, 2, 2, 2], - }, - ) - dataset.set_transform(hf_transform_to_torch) - episode_data_index = calculate_episode_data_index(dataset) - sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1) - assert sampler.indices == [1, 4, 5] - assert len(sampler) == 3 - assert list(sampler) == [1, 4, 5] - - -def test_drop_n_last_frames(): - dataset = Dataset.from_dict( - { - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - "index": [0, 1, 2, 3, 4, 5], - "episode_index": [0, 0, 1, 2, 2, 2], - }, - ) - dataset.set_transform(hf_transform_to_torch) - episode_data_index = calculate_episode_data_index(dataset) - sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1) - assert sampler.indices == [0, 3, 4] - assert len(sampler) == 3 - assert list(sampler) == [0, 3, 4] - - -def test_episode_indices_to_use(): - dataset = Dataset.from_dict( - { - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - "index": [0, 1, 2, 3, 4, 5], - "episode_index": [0, 0, 1, 2, 2, 2], - }, - ) - dataset.set_transform(hf_transform_to_torch) - episode_data_index = calculate_episode_data_index(dataset) - sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2]) - assert sampler.indices == [0, 1, 3, 4, 5] - assert len(sampler) == 5 - assert list(sampler) == [0, 1, 3, 4, 5] - - -def test_shuffle(): - dataset = Dataset.from_dict( - { - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - "index": [0, 1, 2, 3, 4, 5], - "episode_index": [0, 0, 1, 2, 2, 2], - }, - ) - dataset.set_transform(hf_transform_to_torch) - episode_data_index = calculate_episode_data_index(dataset) - sampler = EpisodeAwareSampler(episode_data_index, shuffle=False) - assert sampler.indices == [0, 1, 2, 3, 4, 5] - assert len(sampler) == 6 - assert list(sampler) == [0, 1, 2, 3, 4, 5] - sampler = EpisodeAwareSampler(episode_data_index, shuffle=True) - assert sampler.indices == [0, 1, 2, 3, 4, 5] - assert len(sampler) == 6 - assert set(sampler) == {0, 1, 2, 3, 4, 5} diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py deleted file mode 100644 index ba16874..0000000 --- a/tests/datasets/test_utils.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from datasets import Dataset -from huggingface_hub import DatasetCard - -from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index -from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch - - -def test_default_parameters(): - card = create_lerobot_dataset_card() - assert isinstance(card, DatasetCard) - assert card.data.tags == ["LeRobot"] - assert card.data.task_categories == ["robotics"] - assert card.data.configs == [ - { - "config_name": "default", - "data_files": "data/*/*.parquet", - } - ] - - -def test_with_tags(): - tags = ["tag1", "tag2"] - card = create_lerobot_dataset_card(tags=tags) - assert card.data.tags == ["LeRobot", "tag1", "tag2"] - - -def test_calculate_episode_data_index(): - dataset = Dataset.from_dict( - { - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - "index": [0, 1, 2, 3, 4, 5], - "episode_index": [0, 0, 1, 2, 2, 2], - }, - ) - dataset.set_transform(hf_transform_to_torch) - episode_data_index = calculate_episode_data_index(dataset) - assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) - assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) diff --git a/tests/datasets/test_visualize_dataset.py b/tests/datasets/test_visualize_dataset.py deleted file mode 100644 index 303342e..0000000 --- a/tests/datasets/test_visualize_dataset.py +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytest - -from lerobot.scripts.visualize_dataset import visualize_dataset - - -@pytest.mark.skip("TODO: add dummy videos") -def test_visualize_local_dataset(tmp_path, lerobot_dataset_factory): - root = tmp_path / "dataset" - output_dir = tmp_path / "outputs" - dataset = lerobot_dataset_factory(root=root) - rrd_path = visualize_dataset( - dataset, - episode_index=0, - batch_size=32, - save=True, - output_dir=output_dir, - ) - assert rrd_path.exists() diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py deleted file mode 100644 index 140e9df..0000000 --- a/tests/envs/test_envs.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import importlib - -import gymnasium as gym -import pytest -import torch -from gymnasium.utils.env_checker import check_env - -import lerobot -from lerobot.envs.factory import make_env, make_env_config -from lerobot.envs.utils import preprocess_observation -from tests.utils import require_env - -OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] - - -@pytest.mark.parametrize("obs_type", OBS_TYPES) -@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs) -@require_env -def test_env(env_name, env_task, obs_type): - if env_name == "aloha" and obs_type == "state": - pytest.skip("`state` observations not available for aloha") - - package_name = f"gym_{env_name}" - importlib.import_module(package_name) - env = gym.make(f"{package_name}/{env_task}", obs_type=obs_type) - check_env(env.unwrapped, skip_render_check=True) - env.close() - - -@pytest.mark.parametrize("env_name", lerobot.available_envs) -@require_env -def test_factory(env_name): - cfg = make_env_config(env_name) - env = make_env(cfg, n_envs=1) - obs, _ = env.reset() - obs = preprocess_observation(obs) - - # test image keys are float32 in range [0,1] - for key in obs: - if "image" not in key: - continue - img = obs[key] - assert img.dtype == torch.float32 - # TODO(rcadene): we assume for now that image normalization takes place in the model - assert img.max() <= 1.0 - assert img.min() >= 0.0 - - env.close() diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py deleted file mode 100644 index aabec69..0000000 --- a/tests/examples/test_examples.py +++ /dev/null @@ -1,147 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -import subprocess -import sys -from pathlib import Path - -import pytest - -from tests.fixtures.constants import DUMMY_REPO_ID -from tests.utils import require_package - - -def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str: - for f, r in finds_and_replaces: - assert f in text - text = text.replace(f, r) - return text - - -# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures -def _run_script(path): - subprocess.run([sys.executable, path], check=True) - - -def _read_file(path): - with open(path) as file: - return file.read() - - -@pytest.mark.skip("TODO Fix and remove subprocess / excec calls") -def test_example_1(tmp_path, lerobot_dataset_factory): - _ = lerobot_dataset_factory(root=tmp_path, repo_id=DUMMY_REPO_ID) - path = "examples/1_load_lerobot_dataset.py" - file_contents = _read_file(path) - file_contents = _find_and_replace( - file_contents, - [ - ('repo_id = "lerobot/pusht"', f'repo_id = "{DUMMY_REPO_ID}"'), - ( - "LeRobotDataset(repo_id", - f"LeRobotDataset(repo_id, root='{str(tmp_path)}'", - ), - ], - ) - exec(file_contents, {}) - assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists() - - -@pytest.mark.skip("TODO Fix and remove subprocess / excec calls") -@require_package("gym_pusht") -def test_examples_basic2_basic3_advanced1(): - """ - Train a model with example 3, check the outputs. - Evaluate the trained model with example 2, check the outputs. - Calculate the validation loss with advanced example 1, check the outputs. - """ - - ### Test example 3 - file_contents = _read_file("examples/3_train_policy.py") - - # Do fewer steps, use smaller batch, use CPU, and don't complicate things with dataloader workers. - file_contents = _find_and_replace( - file_contents, - [ - ("training_steps = 5000", "training_steps = 1"), - ("num_workers=4", "num_workers=0"), - ('device = torch.device("cuda")', 'device = torch.device("cpu")'), - ("batch_size=64", "batch_size=1"), - ], - ) - - # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249. - exec(file_contents, {}) - - for file_name in ["model.safetensors", "config.json"]: - assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() - - ### Test example 2 - file_contents = _read_file("examples/2_evaluate_pretrained_policy.py") - - # Do fewer evals, use CPU, and use the local model. - file_contents = _find_and_replace( - file_contents, - [ - ( - 'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', - "", - ), - ( - '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', - 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', - ), - ('device = torch.device("cuda")', 'device = torch.device("cpu")'), - ("step += 1", "break"), - ], - ) - - exec(file_contents, {}) - - assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists() - - ## Test example 4 - file_contents = _read_file("examples/advanced/2_calculate_validation_loss.py") - - # Run on a single example from the last episode, use CPU, and use the local model. - file_contents = _find_and_replace( - file_contents, - [ - ( - 'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', - "", - ), - ( - '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', - 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', - ), - ("train_episodes = episodes[:num_train_episodes]", "train_episodes = [0]"), - ("val_episodes = episodes[num_train_episodes:]", "val_episodes = [1]"), - ("num_workers=4", "num_workers=0"), - ('device = torch.device("cuda")', 'device = torch.device("cpu")'), - ("batch_size=64", "batch_size=1"), - ], - ) - - # Capture the output of the script - output_buffer = io.StringIO() - sys.stdout = output_buffer - exec(file_contents, {}) - printed_output = output_buffer.getvalue() - # Restore stdout to its original state - sys.stdout = sys.__stdout__ - assert "Average loss on validation set" in printed_output diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py deleted file mode 100644 index d69a463..0000000 --- a/tests/fixtures/constants.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from lerobot.constants import HF_LEROBOT_HOME - -LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing" -DUMMY_REPO_ID = "dummy/repo" -DUMMY_ROBOT_TYPE = "dummy_robot" -DUMMY_MOTOR_FEATURES = { - "action": { - "dtype": "float32", - "shape": (6,), - "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], - }, - "state": { - "dtype": "float32", - "shape": (6,), - "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], - }, -} -DUMMY_CAMERA_FEATURES = { - "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, - "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, -} -DEFAULT_FPS = 30 -DUMMY_VIDEO_INFO = { - "video.fps": DEFAULT_FPS, - "video.codec": "av1", - "video.pix_fmt": "yuv420p", - "video.is_depth_map": False, - "has_audio": False, -} -DUMMY_CHW = (3, 96, 128) -DUMMY_HWC = (96, 128, 3) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py deleted file mode 100644 index 047db33..0000000 --- a/tests/fixtures/dataset_factories.py +++ /dev/null @@ -1,444 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import random -from functools import partial -from pathlib import Path -from typing import Protocol -from unittest.mock import patch - -import datasets -import numpy as np -import PIL.Image -import pytest -import torch - -from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata -from lerobot.datasets.utils import ( - DEFAULT_CHUNK_SIZE, - DEFAULT_FEATURES, - DEFAULT_PARQUET_PATH, - DEFAULT_VIDEO_PATH, - get_hf_features_from_features, - hf_transform_to_torch, -) -from tests.fixtures.constants import ( - DEFAULT_FPS, - DUMMY_CAMERA_FEATURES, - DUMMY_MOTOR_FEATURES, - DUMMY_REPO_ID, - DUMMY_ROBOT_TYPE, - DUMMY_VIDEO_INFO, -) - - -class LeRobotDatasetFactory(Protocol): - def __call__(self, *args, **kwargs) -> LeRobotDataset: ... - - -def get_task_index(task_dicts: dict, task: str) -> int: - tasks = {d["task_index"]: d["task"] for d in task_dicts.values()} - task_to_task_index = {task: task_idx for task_idx, task in tasks.items()} - return task_to_task_index[task] - - -@pytest.fixture(scope="session") -def img_tensor_factory(): - def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor: - return torch.rand((channels, height, width), dtype=dtype) - - return _create_img_tensor - - -@pytest.fixture(scope="session") -def img_array_factory(): - def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray: - if np.issubdtype(dtype, np.unsignedinteger): - # Int array in [0, 255] range - img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype) - elif np.issubdtype(dtype, np.floating): - # Float array in [0, 1] range - img_array = np.random.rand(height, width, channels).astype(dtype) - else: - raise ValueError(dtype) - return img_array - - return _create_img_array - - -@pytest.fixture(scope="session") -def img_factory(img_array_factory): - def _create_img(height=100, width=100) -> PIL.Image.Image: - img_array = img_array_factory(height=height, width=width) - return PIL.Image.fromarray(img_array) - - return _create_img - - -@pytest.fixture(scope="session") -def features_factory(): - def _create_features( - motor_features: dict = DUMMY_MOTOR_FEATURES, - camera_features: dict = DUMMY_CAMERA_FEATURES, - use_videos: bool = True, - ) -> dict: - if use_videos: - camera_ft = { - key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items() - } - else: - camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()} - return { - **motor_features, - **camera_ft, - **DEFAULT_FEATURES, - } - - return _create_features - - -@pytest.fixture(scope="session") -def info_factory(features_factory): - def _create_info( - codebase_version: str = CODEBASE_VERSION, - fps: int = DEFAULT_FPS, - robot_type: str = DUMMY_ROBOT_TYPE, - total_episodes: int = 0, - total_frames: int = 0, - total_tasks: int = 0, - total_videos: int = 0, - total_chunks: int = 0, - chunks_size: int = DEFAULT_CHUNK_SIZE, - data_path: str = DEFAULT_PARQUET_PATH, - video_path: str = DEFAULT_VIDEO_PATH, - motor_features: dict = DUMMY_MOTOR_FEATURES, - camera_features: dict = DUMMY_CAMERA_FEATURES, - use_videos: bool = True, - ) -> dict: - features = features_factory(motor_features, camera_features, use_videos) - return { - "codebase_version": codebase_version, - "robot_type": robot_type, - "total_episodes": total_episodes, - "total_frames": total_frames, - "total_tasks": total_tasks, - "total_videos": total_videos, - "total_chunks": total_chunks, - "chunks_size": chunks_size, - "fps": fps, - "splits": {}, - "data_path": data_path, - "video_path": video_path if use_videos else None, - "features": features, - } - - return _create_info - - -@pytest.fixture(scope="session") -def stats_factory(): - def _create_stats( - features: dict[str] | None = None, - ) -> dict: - stats = {} - for key, ft in features.items(): - shape = ft["shape"] - dtype = ft["dtype"] - if dtype in ["image", "video"]: - stats[key] = { - "max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(), - "mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(), - "min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(), - "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(), - "count": [10], - } - else: - stats[key] = { - "max": np.full(shape, 1, dtype=dtype).tolist(), - "mean": np.full(shape, 0.5, dtype=dtype).tolist(), - "min": np.full(shape, 0, dtype=dtype).tolist(), - "std": np.full(shape, 0.25, dtype=dtype).tolist(), - "count": [10], - } - return stats - - return _create_stats - - -@pytest.fixture(scope="session") -def episodes_stats_factory(stats_factory): - def _create_episodes_stats( - features: dict[str], - total_episodes: int = 3, - ) -> dict: - episodes_stats = {} - for episode_index in range(total_episodes): - episodes_stats[episode_index] = { - "episode_index": episode_index, - "stats": stats_factory(features), - } - return episodes_stats - - return _create_episodes_stats - - -@pytest.fixture(scope="session") -def tasks_factory(): - def _create_tasks(total_tasks: int = 3) -> int: - tasks = {} - for task_index in range(total_tasks): - task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."} - tasks[task_index] = task_dict - return tasks - - return _create_tasks - - -@pytest.fixture(scope="session") -def episodes_factory(tasks_factory): - def _create_episodes( - total_episodes: int = 3, - total_frames: int = 400, - tasks: dict | None = None, - multi_task: bool = False, - ): - if total_episodes <= 0 or total_frames <= 0: - raise ValueError("num_episodes and total_length must be positive integers.") - if total_frames < total_episodes: - raise ValueError("total_length must be greater than or equal to num_episodes.") - - if not tasks: - min_tasks = 2 if multi_task else 1 - total_tasks = random.randint(min_tasks, total_episodes) - tasks = tasks_factory(total_tasks) - - if total_episodes < len(tasks) and not multi_task: - raise ValueError("The number of tasks should be less than the number of episodes.") - - # Generate random lengths that sum up to total_length - lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist() - - tasks_list = [task_dict["task"] for task_dict in tasks.values()] - num_tasks_available = len(tasks_list) - - episodes = {} - remaining_tasks = tasks_list.copy() - for ep_idx in range(total_episodes): - num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1 - tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list - episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample))) - if remaining_tasks: - for task in episode_tasks: - remaining_tasks.remove(task) - - episodes[ep_idx] = { - "episode_index": ep_idx, - "tasks": episode_tasks, - "length": lengths[ep_idx], - } - - return episodes - - return _create_episodes - - -@pytest.fixture(scope="session") -def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory): - def _create_hf_dataset( - features: dict | None = None, - tasks: list[dict] | None = None, - episodes: list[dict] | None = None, - fps: int = DEFAULT_FPS, - ) -> datasets.Dataset: - if not tasks: - tasks = tasks_factory() - if not episodes: - episodes = episodes_factory() - if not features: - features = features_factory() - - timestamp_col = np.array([], dtype=np.float32) - frame_index_col = np.array([], dtype=np.int64) - episode_index_col = np.array([], dtype=np.int64) - task_index = np.array([], dtype=np.int64) - for ep_dict in episodes.values(): - timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) - frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) - episode_index_col = np.concatenate( - (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int)) - ) - ep_task_index = get_task_index(tasks, ep_dict["tasks"][0]) - task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))) - - index_col = np.arange(len(episode_index_col)) - - robot_cols = {} - for key, ft in features.items(): - if ft["dtype"] == "image": - robot_cols[key] = [ - img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0]) - for _ in range(len(index_col)) - ] - elif ft["shape"][0] > 1 and ft["dtype"] != "video": - robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"]) - - hf_features = get_hf_features_from_features(features) - dataset = datasets.Dataset.from_dict( - { - **robot_cols, - "timestamp": timestamp_col, - "frame_index": frame_index_col, - "episode_index": episode_index_col, - "index": index_col, - "task_index": task_index, - }, - features=hf_features, - ) - dataset.set_transform(hf_transform_to_torch) - return dataset - - return _create_hf_dataset - - -@pytest.fixture(scope="session") -def lerobot_dataset_metadata_factory( - info_factory, - stats_factory, - episodes_stats_factory, - tasks_factory, - episodes_factory, - mock_snapshot_download_factory, -): - def _create_lerobot_dataset_metadata( - root: Path, - repo_id: str = DUMMY_REPO_ID, - info: dict | None = None, - stats: dict | None = None, - episodes_stats: list[dict] | None = None, - tasks: list[dict] | None = None, - episodes: list[dict] | None = None, - ) -> LeRobotDatasetMetadata: - if not info: - info = info_factory() - if not stats: - stats = stats_factory(features=info["features"]) - if not episodes_stats: - episodes_stats = episodes_stats_factory( - features=info["features"], total_episodes=info["total_episodes"] - ) - if not tasks: - tasks = tasks_factory(total_tasks=info["total_tasks"]) - if not episodes: - episodes = episodes_factory( - total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks - ) - - mock_snapshot_download = mock_snapshot_download_factory( - info=info, - stats=stats, - episodes_stats=episodes_stats, - tasks=tasks, - episodes=episodes, - ) - with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch, - ): - mock_get_safe_version_patch.side_effect = lambda repo_id, version: version - mock_snapshot_download_patch.side_effect = mock_snapshot_download - - return LeRobotDatasetMetadata(repo_id=repo_id, root=root) - - return _create_lerobot_dataset_metadata - - -@pytest.fixture(scope="session") -def lerobot_dataset_factory( - info_factory, - stats_factory, - episodes_stats_factory, - tasks_factory, - episodes_factory, - hf_dataset_factory, - mock_snapshot_download_factory, - lerobot_dataset_metadata_factory, -) -> LeRobotDatasetFactory: - def _create_lerobot_dataset( - root: Path, - repo_id: str = DUMMY_REPO_ID, - total_episodes: int = 3, - total_frames: int = 150, - total_tasks: int = 1, - multi_task: bool = False, - info: dict | None = None, - stats: dict | None = None, - episodes_stats: list[dict] | None = None, - tasks: list[dict] | None = None, - episode_dicts: list[dict] | None = None, - hf_dataset: datasets.Dataset | None = None, - **kwargs, - ) -> LeRobotDataset: - if not info: - info = info_factory( - total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks - ) - if not stats: - stats = stats_factory(features=info["features"]) - if not episodes_stats: - episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes) - if not tasks: - tasks = tasks_factory(total_tasks=info["total_tasks"]) - if not episode_dicts: - episode_dicts = episodes_factory( - total_episodes=info["total_episodes"], - total_frames=info["total_frames"], - tasks=tasks, - multi_task=multi_task, - ) - if not hf_dataset: - hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"]) - - mock_snapshot_download = mock_snapshot_download_factory( - info=info, - stats=stats, - episodes_stats=episodes_stats, - tasks=tasks, - episodes=episode_dicts, - hf_dataset=hf_dataset, - ) - mock_metadata = lerobot_dataset_metadata_factory( - root=root, - repo_id=repo_id, - info=info, - stats=stats, - episodes_stats=episodes_stats, - tasks=tasks, - episodes=episode_dicts, - ) - with ( - patch("lerobot.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch, - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch, - ): - mock_metadata_patch.return_value = mock_metadata - mock_get_safe_version_patch.side_effect = lambda repo_id, version: version - mock_snapshot_download_patch.side_effect = mock_snapshot_download - - return LeRobotDataset(repo_id=repo_id, root=root, **kwargs) - - return _create_lerobot_dataset - - -@pytest.fixture(scope="session") -def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory: - return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS) diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py deleted file mode 100644 index e0553f7..0000000 --- a/tests/fixtures/files.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -from pathlib import Path - -import datasets -import jsonlines -import pyarrow.compute as pc -import pyarrow.parquet as pq -import pytest - -from lerobot.datasets.utils import ( - EPISODES_PATH, - EPISODES_STATS_PATH, - INFO_PATH, - STATS_PATH, - TASKS_PATH, -) - - -@pytest.fixture(scope="session") -def info_path(info_factory): - def _create_info_json_file(dir: Path, info: dict | None = None) -> Path: - if not info: - info = info_factory() - fpath = dir / INFO_PATH - fpath.parent.mkdir(parents=True, exist_ok=True) - with open(fpath, "w") as f: - json.dump(info, f, indent=4, ensure_ascii=False) - return fpath - - return _create_info_json_file - - -@pytest.fixture(scope="session") -def stats_path(stats_factory): - def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path: - if not stats: - stats = stats_factory() - fpath = dir / STATS_PATH - fpath.parent.mkdir(parents=True, exist_ok=True) - with open(fpath, "w") as f: - json.dump(stats, f, indent=4, ensure_ascii=False) - return fpath - - return _create_stats_json_file - - -@pytest.fixture(scope="session") -def episodes_stats_path(episodes_stats_factory): - def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path: - if not episodes_stats: - episodes_stats = episodes_stats_factory() - fpath = dir / EPISODES_STATS_PATH - fpath.parent.mkdir(parents=True, exist_ok=True) - with jsonlines.open(fpath, "w") as writer: - writer.write_all(episodes_stats.values()) - return fpath - - return _create_episodes_stats_jsonl_file - - -@pytest.fixture(scope="session") -def tasks_path(tasks_factory): - def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path: - if not tasks: - tasks = tasks_factory() - fpath = dir / TASKS_PATH - fpath.parent.mkdir(parents=True, exist_ok=True) - with jsonlines.open(fpath, "w") as writer: - writer.write_all(tasks.values()) - return fpath - - return _create_tasks_jsonl_file - - -@pytest.fixture(scope="session") -def episode_path(episodes_factory): - def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path: - if not episodes: - episodes = episodes_factory() - fpath = dir / EPISODES_PATH - fpath.parent.mkdir(parents=True, exist_ok=True) - with jsonlines.open(fpath, "w") as writer: - writer.write_all(episodes.values()) - return fpath - - return _create_episodes_jsonl_file - - -@pytest.fixture(scope="session") -def single_episode_parquet_path(hf_dataset_factory, info_factory): - def _create_single_episode_parquet( - dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None - ) -> Path: - if not info: - info = info_factory() - if hf_dataset is None: - hf_dataset = hf_dataset_factory() - - data_path = info["data_path"] - chunks_size = info["chunks_size"] - ep_chunk = ep_idx // chunks_size - fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx) - fpath.parent.mkdir(parents=True, exist_ok=True) - table = hf_dataset.data.table - ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) - pq.write_table(ep_table, fpath) - return fpath - - return _create_single_episode_parquet - - -@pytest.fixture(scope="session") -def multi_episode_parquet_path(hf_dataset_factory, info_factory): - def _create_multi_episode_parquet( - dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None - ) -> Path: - if not info: - info = info_factory() - if hf_dataset is None: - hf_dataset = hf_dataset_factory() - - data_path = info["data_path"] - chunks_size = info["chunks_size"] - total_episodes = info["total_episodes"] - for ep_idx in range(total_episodes): - ep_chunk = ep_idx // chunks_size - fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx) - fpath.parent.mkdir(parents=True, exist_ok=True) - table = hf_dataset.data.table - ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) - pq.write_table(ep_table, fpath) - return dir / "data" - - return _create_multi_episode_parquet diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py deleted file mode 100644 index f7c5f5b..0000000 --- a/tests/fixtures/hub.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pathlib import Path - -import datasets -import pytest -from huggingface_hub.utils import filter_repo_objects - -from lerobot.datasets.utils import ( - EPISODES_PATH, - EPISODES_STATS_PATH, - INFO_PATH, - STATS_PATH, - TASKS_PATH, -) -from tests.fixtures.constants import LEROBOT_TEST_DIR - - -@pytest.fixture(scope="session") -def mock_snapshot_download_factory( - info_factory, - info_path, - stats_factory, - stats_path, - episodes_stats_factory, - episodes_stats_path, - tasks_factory, - tasks_path, - episodes_factory, - episode_path, - single_episode_parquet_path, - hf_dataset_factory, -): - """ - This factory allows to patch snapshot_download such that when called, it will create expected files rather - than making calls to the hub api. Its design allows to pass explicitly files which you want to be created. - """ - - def _mock_snapshot_download_func( - info: dict | None = None, - stats: dict | None = None, - episodes_stats: list[dict] | None = None, - tasks: list[dict] | None = None, - episodes: list[dict] | None = None, - hf_dataset: datasets.Dataset | None = None, - ): - if not info: - info = info_factory() - if not stats: - stats = stats_factory(features=info["features"]) - if not episodes_stats: - episodes_stats = episodes_stats_factory( - features=info["features"], total_episodes=info["total_episodes"] - ) - if not tasks: - tasks = tasks_factory(total_tasks=info["total_tasks"]) - if not episodes: - episodes = episodes_factory( - total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks - ) - if not hf_dataset: - hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"]) - - def _extract_episode_index_from_path(fpath: str) -> int: - path = Path(fpath) - if path.suffix == ".parquet" and path.stem.startswith("episode_"): - episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0 - return episode_index - else: - return None - - def _mock_snapshot_download( - repo_id: str, - local_dir: str | Path | None = None, - allow_patterns: str | list[str] | None = None, - ignore_patterns: str | list[str] | None = None, - *args, - **kwargs, - ) -> str: - if not local_dir: - local_dir = LEROBOT_TEST_DIR - - # List all possible files - all_files = [] - meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH] - all_files.extend(meta_files) - - data_files = [] - for episode_dict in episodes.values(): - ep_idx = episode_dict["episode_index"] - ep_chunk = ep_idx // info["chunks_size"] - data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx) - data_files.append(data_path) - all_files.extend(data_files) - - allowed_files = filter_repo_objects( - all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns - ) - - # Create allowed files - for rel_path in allowed_files: - if rel_path.startswith("data/"): - episode_index = _extract_episode_index_from_path(rel_path) - if episode_index is not None: - _ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info) - if rel_path == INFO_PATH: - _ = info_path(local_dir, info) - elif rel_path == STATS_PATH: - _ = stats_path(local_dir, stats) - elif rel_path == EPISODES_STATS_PATH: - _ = episodes_stats_path(local_dir, episodes_stats) - elif rel_path == TASKS_PATH: - _ = tasks_path(local_dir, tasks) - elif rel_path == EPISODES_PATH: - _ = episode_path(local_dir, episodes) - else: - pass - return str(local_dir) - - return _mock_snapshot_download - - return _mock_snapshot_download_func diff --git a/tests/fixtures/optimizers.py b/tests/fixtures/optimizers.py deleted file mode 100644 index a1b4a9d..0000000 --- a/tests/fixtures/optimizers.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytest -import torch - -from lerobot.optim.optimizers import AdamConfig -from lerobot.optim.schedulers import VQBeTSchedulerConfig - - -@pytest.fixture -def model_params(): - return [torch.nn.Parameter(torch.randn(10, 10))] - - -@pytest.fixture -def optimizer(model_params): - optimizer = AdamConfig().build(model_params) - # Dummy step to populate state - loss = sum(param.sum() for param in model_params) - loss.backward() - optimizer.step() - return optimizer - - -@pytest.fixture -def scheduler(optimizer): - config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5) - return config.build(optimizer, num_training_steps=100) diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py deleted file mode 100644 index 84026fc..0000000 --- a/tests/mocks/mock_dynamixel.py +++ /dev/null @@ -1,596 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -from collections.abc import Callable - -import dynamixel_sdk as dxl -import serial -from mock_serial.mock_serial import MockSerial - -from lerobot.motors.dynamixel.dynamixel import _split_into_byte_chunks - -from .mock_serial_patch import WaitableStub - -# https://emanual.robotis.com/docs/en/dxl/crc/ -DXL_CRC_TABLE = [ - 0x0000, 0x8005, 0x800F, 0x000A, 0x801B, 0x001E, 0x0014, 0x8011, - 0x8033, 0x0036, 0x003C, 0x8039, 0x0028, 0x802D, 0x8027, 0x0022, - 0x8063, 0x0066, 0x006C, 0x8069, 0x0078, 0x807D, 0x8077, 0x0072, - 0x0050, 0x8055, 0x805F, 0x005A, 0x804B, 0x004E, 0x0044, 0x8041, - 0x80C3, 0x00C6, 0x00CC, 0x80C9, 0x00D8, 0x80DD, 0x80D7, 0x00D2, - 0x00F0, 0x80F5, 0x80FF, 0x00FA, 0x80EB, 0x00EE, 0x00E4, 0x80E1, - 0x00A0, 0x80A5, 0x80AF, 0x00AA, 0x80BB, 0x00BE, 0x00B4, 0x80B1, - 0x8093, 0x0096, 0x009C, 0x8099, 0x0088, 0x808D, 0x8087, 0x0082, - 0x8183, 0x0186, 0x018C, 0x8189, 0x0198, 0x819D, 0x8197, 0x0192, - 0x01B0, 0x81B5, 0x81BF, 0x01BA, 0x81AB, 0x01AE, 0x01A4, 0x81A1, - 0x01E0, 0x81E5, 0x81EF, 0x01EA, 0x81FB, 0x01FE, 0x01F4, 0x81F1, - 0x81D3, 0x01D6, 0x01DC, 0x81D9, 0x01C8, 0x81CD, 0x81C7, 0x01C2, - 0x0140, 0x8145, 0x814F, 0x014A, 0x815B, 0x015E, 0x0154, 0x8151, - 0x8173, 0x0176, 0x017C, 0x8179, 0x0168, 0x816D, 0x8167, 0x0162, - 0x8123, 0x0126, 0x012C, 0x8129, 0x0138, 0x813D, 0x8137, 0x0132, - 0x0110, 0x8115, 0x811F, 0x011A, 0x810B, 0x010E, 0x0104, 0x8101, - 0x8303, 0x0306, 0x030C, 0x8309, 0x0318, 0x831D, 0x8317, 0x0312, - 0x0330, 0x8335, 0x833F, 0x033A, 0x832B, 0x032E, 0x0324, 0x8321, - 0x0360, 0x8365, 0x836F, 0x036A, 0x837B, 0x037E, 0x0374, 0x8371, - 0x8353, 0x0356, 0x035C, 0x8359, 0x0348, 0x834D, 0x8347, 0x0342, - 0x03C0, 0x83C5, 0x83CF, 0x03CA, 0x83DB, 0x03DE, 0x03D4, 0x83D1, - 0x83F3, 0x03F6, 0x03FC, 0x83F9, 0x03E8, 0x83ED, 0x83E7, 0x03E2, - 0x83A3, 0x03A6, 0x03AC, 0x83A9, 0x03B8, 0x83BD, 0x83B7, 0x03B2, - 0x0390, 0x8395, 0x839F, 0x039A, 0x838B, 0x038E, 0x0384, 0x8381, - 0x0280, 0x8285, 0x828F, 0x028A, 0x829B, 0x029E, 0x0294, 0x8291, - 0x82B3, 0x02B6, 0x02BC, 0x82B9, 0x02A8, 0x82AD, 0x82A7, 0x02A2, - 0x82E3, 0x02E6, 0x02EC, 0x82E9, 0x02F8, 0x82FD, 0x82F7, 0x02F2, - 0x02D0, 0x82D5, 0x82DF, 0x02DA, 0x82CB, 0x02CE, 0x02C4, 0x82C1, - 0x8243, 0x0246, 0x024C, 0x8249, 0x0258, 0x825D, 0x8257, 0x0252, - 0x0270, 0x8275, 0x827F, 0x027A, 0x826B, 0x026E, 0x0264, 0x8261, - 0x0220, 0x8225, 0x822F, 0x022A, 0x823B, 0x023E, 0x0234, 0x8231, - 0x8213, 0x0216, 0x021C, 0x8219, 0x0208, 0x820D, 0x8207, 0x0202 -] # fmt: skip - - -class MockDynamixelPacketv2(abc.ABC): - @classmethod - def build(cls, dxl_id: int, params: list[int], length: int, *args, **kwargs) -> bytes: - packet = cls._build(dxl_id, params, length, *args, **kwargs) - packet = cls._add_stuffing(packet) - packet = cls._add_crc(packet) - return bytes(packet) - - @abc.abstractclassmethod - def _build(cls, dxl_id: int, params: list[int], length: int, *args, **kwargs) -> list[int]: - pass - - @staticmethod - def _add_stuffing(packet: list[int]) -> list[int]: - """ - Byte stuffing is a method of adding additional data to generated instruction packets to ensure that - the packets are processed successfully. When the byte pattern "0xFF 0xFF 0xFD" appears in a packet, - byte stuffing adds 0xFD to the end of the pattern to convert it to “0xFF 0xFF 0xFD 0xFD” to ensure - that it is not interpreted as the header at the start of another packet. - - Source: https://emanual.robotis.com/docs/en/dxl/protocol2/#transmission-process - - Args: - packet (list[int]): The raw packet without stuffing. - - Returns: - list[int]: The packet stuffed if it contained a "0xFF 0xFF 0xFD" byte sequence in its data bytes. - """ - packet_length_in = dxl.DXL_MAKEWORD(packet[dxl.PKT_LENGTH_L], packet[dxl.PKT_LENGTH_H]) - packet_length_out = packet_length_in - - temp = [0] * dxl.TXPACKET_MAX_LEN - - # FF FF FD XX ID LEN_L LEN_H - temp[dxl.PKT_HEADER0 : dxl.PKT_HEADER0 + dxl.PKT_LENGTH_H + 1] = packet[ - dxl.PKT_HEADER0 : dxl.PKT_HEADER0 + dxl.PKT_LENGTH_H + 1 - ] - - index = dxl.PKT_INSTRUCTION - - for i in range(0, packet_length_in - 2): # except CRC - temp[index] = packet[i + dxl.PKT_INSTRUCTION] - index = index + 1 - if ( - packet[i + dxl.PKT_INSTRUCTION] == 0xFD - and packet[i + dxl.PKT_INSTRUCTION - 1] == 0xFF - and packet[i + dxl.PKT_INSTRUCTION - 2] == 0xFF - ): - # FF FF FD - temp[index] = 0xFD - index = index + 1 - packet_length_out = packet_length_out + 1 - - temp[index] = packet[dxl.PKT_INSTRUCTION + packet_length_in - 2] - temp[index + 1] = packet[dxl.PKT_INSTRUCTION + packet_length_in - 1] - index = index + 2 - - if packet_length_in != packet_length_out: - packet = [0] * index - - packet[0:index] = temp[0:index] - - packet[dxl.PKT_LENGTH_L] = dxl.DXL_LOBYTE(packet_length_out) - packet[dxl.PKT_LENGTH_H] = dxl.DXL_HIBYTE(packet_length_out) - - return packet - - @staticmethod - def _add_crc(packet: list[int]) -> list[int]: - """Computes and add CRC to the packet. - - https://emanual.robotis.com/docs/en/dxl/crc/ - https://en.wikipedia.org/wiki/Cyclic_redundancy_check - - Args: - packet (list[int]): The raw packet without CRC (but with placeholders for it). - - Returns: - list[int]: The raw packet with a valid CRC. - """ - crc = 0 - for j in range(len(packet) - 2): - i = ((crc >> 8) ^ packet[j]) & 0xFF - crc = ((crc << 8) ^ DXL_CRC_TABLE[i]) & 0xFFFF - - packet[-2] = dxl.DXL_LOBYTE(crc) - packet[-1] = dxl.DXL_HIBYTE(crc) - - return packet - - -class MockInstructionPacket(MockDynamixelPacketv2): - """ - Helper class to build valid Dynamixel Protocol 2.0 Instruction Packets. - - Protocol 2.0 Instruction Packet structure - https://emanual.robotis.com/docs/en/dxl/protocol2/#instruction-packet - - | Header | Packet ID | Length | Instruction | Params | CRC | - | ------------------- | --------- | ----------- | ----------- | ----------------- | ----------- | - | 0xFF 0xFF 0xFD 0x00 | ID | Len_L Len_H | Instr | Param 1 … Param N | CRC_L CRC_H | - - """ - - @classmethod - def _build(cls, dxl_id: int, params: list[int], length: int, instruction: int) -> list[int]: - length = len(params) + 3 - return [ - 0xFF, 0xFF, 0xFD, 0x00, # header - dxl_id, # servo id - dxl.DXL_LOBYTE(length), # length_l - dxl.DXL_HIBYTE(length), # length_h - instruction, # instruction type - *params, # data bytes - 0x00, 0x00 # placeholder for CRC - ] # fmt: skip - - @classmethod - def ping( - cls, - dxl_id: int, - ) -> bytes: - """ - Builds a "Ping" broadcast instruction. - https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01 - - No parameters required. - """ - return cls.build(dxl_id=dxl_id, params=[], length=3, instruction=dxl.INST_PING) - - @classmethod - def read( - cls, - dxl_id: int, - start_address: int, - data_length: int, - ) -> bytes: - """ - Builds a "Read" instruction. - https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02 - - The parameters for Read (Protocol 2.0) are: - param[0] = start_address L - param[1] = start_address H - param[2] = data_length L - param[3] = data_length H - - And 'length' = data_length + 5, where: - +1 is for instruction byte, - +2 is for the length bytes, - +2 is for the CRC at the end. - """ - params = [ - dxl.DXL_LOBYTE(start_address), - dxl.DXL_HIBYTE(start_address), - dxl.DXL_LOBYTE(data_length), - dxl.DXL_HIBYTE(data_length), - ] - length = len(params) + 3 - # length = data_length + 5 - return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_READ) - - @classmethod - def write( - cls, - dxl_id: int, - value: int, - start_address: int, - data_length: int, - ) -> bytes: - """ - Builds a "Write" instruction. - https://emanual.robotis.com/docs/en/dxl/protocol2/#write-0x03 - - The parameters for Write (Protocol 2.0) are: - param[0] = start_address L - param[1] = start_address H - param[2] = 1st Byte - param[3] = 2nd Byte - ... - param[1+X] = X-th Byte - - And 'length' = data_length + 5, where: - +1 is for instruction byte, - +2 is for the length bytes, - +2 is for the CRC at the end. - """ - data = _split_into_byte_chunks(value, data_length) - params = [ - dxl.DXL_LOBYTE(start_address), - dxl.DXL_HIBYTE(start_address), - *data, - ] - length = data_length + 5 - return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_WRITE) - - @classmethod - def sync_read( - cls, - dxl_ids: list[int], - start_address: int, - data_length: int, - ) -> bytes: - """ - Builds a "Sync_Read" broadcast instruction. - https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-read-0x82 - - The parameters for Sync_Read (Protocol 2.0) are: - param[0] = start_address L - param[1] = start_address H - param[2] = data_length L - param[3] = data_length H - param[4+] = motor IDs to read from - - And 'length' = (number_of_params + 7), where: - +1 is for instruction byte, - +2 is for the address bytes, - +2 is for the length bytes, - +2 is for the CRC at the end. - """ - params = [ - dxl.DXL_LOBYTE(start_address), - dxl.DXL_HIBYTE(start_address), - dxl.DXL_LOBYTE(data_length), - dxl.DXL_HIBYTE(data_length), - *dxl_ids, - ] - length = len(dxl_ids) + 7 - return cls.build( - dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_READ - ) - - @classmethod - def sync_write( - cls, - ids_values: dict[int, int], - start_address: int, - data_length: int, - ) -> bytes: - """ - Builds a "Sync_Write" broadcast instruction. - https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-write-0x83 - - The parameters for Sync_Write (Protocol 2.0) are: - param[0] = start_address L - param[1] = start_address H - param[2] = data_length L - param[3] = data_length H - param[5] = [1st motor] ID - param[5+1] = [1st motor] 1st Byte - param[5+2] = [1st motor] 2nd Byte - ... - param[5+X] = [1st motor] X-th Byte - param[6] = [2nd motor] ID - param[6+1] = [2nd motor] 1st Byte - param[6+2] = [2nd motor] 2nd Byte - ... - param[6+X] = [2nd motor] X-th Byte - - And 'length' = ((number_of_params * 1 + data_length) + 7), where: - +1 is for instruction byte, - +2 is for the address bytes, - +2 is for the length bytes, - +2 is for the CRC at the end. - """ - data = [] - for id_, value in ids_values.items(): - split_value = _split_into_byte_chunks(value, data_length) - data += [id_, *split_value] - params = [ - dxl.DXL_LOBYTE(start_address), - dxl.DXL_HIBYTE(start_address), - dxl.DXL_LOBYTE(data_length), - dxl.DXL_HIBYTE(data_length), - *data, - ] - length = len(ids_values) * (1 + data_length) + 7 - return cls.build( - dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_WRITE - ) - - -class MockStatusPacket(MockDynamixelPacketv2): - """ - Helper class to build valid Dynamixel Protocol 2.0 Status Packets. - - Protocol 2.0 Status Packet structure - https://emanual.robotis.com/docs/en/dxl/protocol2/#status-packet - - | Header | Packet ID | Length | Instruction | Error | Params | CRC | - | ------------------- | --------- | ----------- | ----------- | ----- | ----------------- | ----------- | - | 0xFF 0xFF 0xFD 0x00 | ID | Len_L Len_H | 0x55 | Err | Param 1 … Param N | CRC_L CRC_H | - """ - - @classmethod - def _build(cls, dxl_id: int, params: list[int], length: int, error: int = 0) -> list[int]: - return [ - 0xFF, 0xFF, 0xFD, 0x00, # header - dxl_id, # servo id - dxl.DXL_LOBYTE(length), # length_l - dxl.DXL_HIBYTE(length), # length_h - 0x55, # instruction = 'status' - error, # error - *params, # data bytes - 0x00, 0x00 # placeholder for CRC - ] # fmt: skip - - @classmethod - def ping(cls, dxl_id: int, model_nb: int = 1190, firm_ver: int = 50, error: int = 0) -> bytes: - """ - Builds a 'Ping' status packet. - https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01 - - Args: - dxl_id (int): ID of the servo responding. - model_nb (int, optional): Desired 'model number' to be returned in the packet. Defaults to 1190 - which corresponds to a XL330-M077-T. - firm_ver (int, optional): Desired 'firmware version' to be returned in the packet. - Defaults to 50. - - Returns: - bytes: The raw 'Ping' status packet ready to be sent through serial. - """ - params = [dxl.DXL_LOBYTE(model_nb), dxl.DXL_HIBYTE(model_nb), firm_ver] - length = 7 - return cls.build(dxl_id, params=params, length=length, error=error) - - @classmethod - def read(cls, dxl_id: int, value: int, param_length: int, error: int = 0) -> bytes: - """ - Builds a 'Read' status packet (also works for 'Sync Read') - https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02 - https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-read-0x82 - - Args: - dxl_id (int): ID of the servo responding. - value (int): Desired value to be returned in the packet. - param_length (int): The address length as reported in the control table. - - Returns: - bytes: The raw 'Present_Position' status packet ready to be sent through serial. - """ - params = _split_into_byte_chunks(value, param_length) - length = param_length + 4 - return cls.build(dxl_id, params=params, length=length, error=error) - - -class MockPortHandler(dxl.PortHandler): - """ - This class overwrite the 'setupPort' method of the Dynamixel PortHandler because it can specify - baudrates that are not supported with a serial port on MacOS. - """ - - def setupPort(self, cflag_baud): # noqa: N802 - if self.is_open: - self.closePort() - - self.ser = serial.Serial( - port=self.port_name, - # baudrate=self.baudrate, <- This will fail on MacOS - # parity = serial.PARITY_ODD, - # stopbits = serial.STOPBITS_TWO, - bytesize=serial.EIGHTBITS, - timeout=0, - ) - self.is_open = True - self.ser.reset_input_buffer() - self.tx_time_per_byte = (1000.0 / self.baudrate) * 10.0 - - return True - - -class MockMotors(MockSerial): - """ - This class will simulate physical motors by responding with valid status packets upon receiving some - instruction packets. It is meant to test MotorsBus classes. - """ - - def __init__(self): - super().__init__() - - @property - def stubs(self) -> dict[str, WaitableStub]: - return super().stubs - - def stub(self, *, name=None, **kwargs): - new_stub = WaitableStub(**kwargs) - self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub - return new_stub - - def build_broadcast_ping_stub( - self, ids_models: dict[int, list[int]] | None = None, num_invalid_try: int = 0 - ) -> str: - ping_request = MockInstructionPacket.ping(dxl.BROADCAST_ID) - return_packets = b"".join(MockStatusPacket.ping(id_, model) for id_, model in ids_models.items()) - ping_response = self._build_send_fn(return_packets, num_invalid_try) - - stub_name = "Ping_" + "_".join([str(id_) for id_ in ids_models]) - self.stub( - name=stub_name, - receive_bytes=ping_request, - send_fn=ping_response, - ) - return stub_name - - def build_ping_stub( - self, dxl_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0, error: int = 0 - ) -> str: - ping_request = MockInstructionPacket.ping(dxl_id) - return_packet = MockStatusPacket.ping(dxl_id, model_nb, firm_ver, error) - ping_response = self._build_send_fn(return_packet, num_invalid_try) - stub_name = f"Ping_{dxl_id}" - self.stub( - name=stub_name, - receive_bytes=ping_request, - send_fn=ping_response, - ) - return stub_name - - def build_read_stub( - self, - address: int, - length: int, - dxl_id: int, - value: int, - reply: bool = True, - error: int = 0, - num_invalid_try: int = 0, - ) -> str: - read_request = MockInstructionPacket.read(dxl_id, address, length) - return_packet = MockStatusPacket.read(dxl_id, value, length, error) if reply else b"" - read_response = self._build_send_fn(return_packet, num_invalid_try) - stub_name = f"Read_{address}_{length}_{dxl_id}_{value}_{error}" - self.stub( - name=stub_name, - receive_bytes=read_request, - send_fn=read_response, - ) - return stub_name - - def build_write_stub( - self, - address: int, - length: int, - dxl_id: int, - value: int, - reply: bool = True, - error: int = 0, - num_invalid_try: int = 0, - ) -> str: - sync_read_request = MockInstructionPacket.write(dxl_id, value, address, length) - return_packet = MockStatusPacket.build(dxl_id, params=[], length=4, error=error) if reply else b"" - stub_name = f"Write_{address}_{length}_{dxl_id}" - self.stub( - name=stub_name, - receive_bytes=sync_read_request, - send_fn=self._build_send_fn(return_packet, num_invalid_try), - ) - return stub_name - - def build_sync_read_stub( - self, - address: int, - length: int, - ids_values: dict[int, int], - reply: bool = True, - num_invalid_try: int = 0, - ) -> str: - sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) - return_packets = ( - b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items()) - if reply - else b"" - ) - sync_read_response = self._build_send_fn(return_packets, num_invalid_try) - stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) - self.stub( - name=stub_name, - receive_bytes=sync_read_request, - send_fn=sync_read_response, - ) - return stub_name - - def build_sequential_sync_read_stub( - self, address: int, length: int, ids_values: dict[int, list[int]] | None = None - ) -> str: - sequence_length = len(next(iter(ids_values.values()))) - assert all(len(positions) == sequence_length for positions in ids_values.values()) - sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) - sequential_packets = [] - for count in range(sequence_length): - return_packets = b"".join( - MockStatusPacket.read(id_, positions[count], length) for id_, positions in ids_values.items() - ) - sequential_packets.append(return_packets) - - sync_read_response = self._build_sequential_send_fn(sequential_packets) - stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) - self.stub( - name=stub_name, - receive_bytes=sync_read_request, - send_fn=sync_read_response, - ) - return stub_name - - def build_sync_write_stub( - self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0 - ) -> str: - sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length) - stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) - self.stub( - name=stub_name, - receive_bytes=sync_read_request, - send_fn=self._build_send_fn(b"", num_invalid_try), - ) - return stub_name - - @staticmethod - def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]: - def send_fn(_call_count: int) -> bytes: - if num_invalid_try >= _call_count: - return b"" - return packet - - return send_fn - - @staticmethod - def _build_sequential_send_fn(packets: list[bytes]) -> Callable[[int], bytes]: - def send_fn(_call_count: int) -> bytes: - return packets[_call_count - 1] - - return send_fn diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py deleted file mode 100644 index 33cbc41..0000000 --- a/tests/mocks/mock_feetech.py +++ /dev/null @@ -1,444 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -from collections.abc import Callable - -import scservo_sdk as scs -import serial -from mock_serial import MockSerial - -from lerobot.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout - -from .mock_serial_patch import WaitableStub - - -class MockFeetechPacket(abc.ABC): - @classmethod - def build(cls, scs_id: int, params: list[int], length: int, *args, **kwargs) -> bytes: - packet = cls._build(scs_id, params, length, *args, **kwargs) - packet = cls._add_checksum(packet) - return bytes(packet) - - @abc.abstractclassmethod - def _build(cls, scs_id: int, params: list[int], length: int, *args, **kwargs) -> list[int]: - pass - - @staticmethod - def _add_checksum(packet: list[int]) -> list[int]: - checksum = 0 - for id_ in range(2, len(packet) - 1): # except header & checksum - checksum += packet[id_] - - packet[-1] = ~checksum & 0xFF - - return packet - - -class MockInstructionPacket(MockFeetechPacket): - """ - Helper class to build valid Feetech Instruction Packets. - - Instruction Packet structure - (from https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf) - - | Header | Packet ID | Length | Instruction | Params | Checksum | - | --------- | --------- | ------ | ----------- | ----------------- | -------- | - | 0xFF 0xFF | ID | Len | Instr | Param 1 … Param N | Sum | - - """ - - @classmethod - def _build(cls, scs_id: int, params: list[int], length: int, instruction: int) -> list[int]: - return [ - 0xFF, 0xFF, # header - scs_id, # servo id - length, # length - instruction, # instruction type - *params, # data bytes - 0x00, # placeholder for checksum - ] # fmt: skip - - @classmethod - def ping( - cls, - scs_id: int, - ) -> bytes: - """ - Builds a "Ping" broadcast instruction. - - No parameters required. - """ - return cls.build(scs_id=scs_id, params=[], length=2, instruction=scs.INST_PING) - - @classmethod - def read( - cls, - scs_id: int, - start_address: int, - data_length: int, - ) -> bytes: - """ - Builds a "Read" instruction. - - The parameters for Read are: - param[0] = start_address - param[1] = data_length - - And 'length' = 4, where: - +1 is for instruction byte, - +1 is for the address byte, - +1 is for the length bytes, - +1 is for the checksum at the end. - """ - params = [start_address, data_length] - length = 4 - return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_READ) - - @classmethod - def write( - cls, - scs_id: int, - value: int, - start_address: int, - data_length: int, - ) -> bytes: - """ - Builds a "Write" instruction. - - The parameters for Write are: - param[0] = start_address L - param[1] = start_address H - param[2] = 1st Byte - param[3] = 2nd Byte - ... - param[1+X] = X-th Byte - - And 'length' = data_length + 3, where: - +1 is for instruction byte, - +1 is for the length bytes, - +1 is for the checksum at the end. - """ - data = _split_into_byte_chunks(value, data_length) - params = [start_address, *data] - length = data_length + 3 - return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_WRITE) - - @classmethod - def sync_read( - cls, - scs_ids: list[int], - start_address: int, - data_length: int, - ) -> bytes: - """ - Builds a "Sync_Read" broadcast instruction. - - The parameters for Sync Read are: - param[0] = start_address - param[1] = data_length - param[2+] = motor IDs to read from - - And 'length' = (number_of_params + 4), where: - +1 is for instruction byte, - +1 is for the address byte, - +1 is for the length bytes, - +1 is for the checksum at the end. - """ - params = [start_address, data_length, *scs_ids] - length = len(scs_ids) + 4 - return cls.build( - scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_READ - ) - - @classmethod - def sync_write( - cls, - ids_values: dict[int, int], - start_address: int, - data_length: int, - ) -> bytes: - """ - Builds a "Sync_Write" broadcast instruction. - - The parameters for Sync_Write are: - param[0] = start_address - param[1] = data_length - param[2] = [1st motor] ID - param[2+1] = [1st motor] 1st Byte - param[2+2] = [1st motor] 2nd Byte - ... - param[5+X] = [1st motor] X-th Byte - param[6] = [2nd motor] ID - param[6+1] = [2nd motor] 1st Byte - param[6+2] = [2nd motor] 2nd Byte - ... - param[6+X] = [2nd motor] X-th Byte - - And 'length' = ((number_of_params * 1 + data_length) + 4), where: - +1 is for instruction byte, - +1 is for the address byte, - +1 is for the length bytes, - +1 is for the checksum at the end. - """ - data = [] - for id_, value in ids_values.items(): - split_value = _split_into_byte_chunks(value, data_length) - data += [id_, *split_value] - params = [start_address, data_length, *data] - length = len(ids_values) * (1 + data_length) + 4 - return cls.build( - scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_WRITE - ) - - -class MockStatusPacket(MockFeetechPacket): - """ - Helper class to build valid Feetech Status Packets. - - Status Packet structure - (from https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf) - - | Header | Packet ID | Length | Error | Params | Checksum | - | --------- | --------- | ------ | ----- | ----------------- | -------- | - | 0xFF 0xFF | ID | Len | Err | Param 1 … Param N | Sum | - - """ - - @classmethod - def _build(cls, scs_id: int, params: list[int], length: int, error: int = 0) -> list[int]: - return [ - 0xFF, 0xFF, # header - scs_id, # servo id - length, # length - error, # status - *params, # data bytes - 0x00, # placeholder for checksum - ] # fmt: skip - - @classmethod - def ping(cls, scs_id: int, error: int = 0) -> bytes: - """Builds a 'Ping' status packet. - - Args: - scs_id (int): ID of the servo responding. - error (int, optional): Error to be returned. Defaults to 0 (success). - - Returns: - bytes: The raw 'Ping' status packet ready to be sent through serial. - """ - return cls.build(scs_id, params=[], length=2, error=error) - - @classmethod - def read(cls, scs_id: int, value: int, param_length: int, error: int = 0) -> bytes: - """Builds a 'Read' status packet. - - Args: - scs_id (int): ID of the servo responding. - value (int): Desired value to be returned in the packet. - param_length (int): The address length as reported in the control table. - - Returns: - bytes: The raw 'Sync Read' status packet ready to be sent through serial. - """ - params = _split_into_byte_chunks(value, param_length) - length = param_length + 2 - return cls.build(scs_id, params=params, length=length, error=error) - - -class MockPortHandler(scs.PortHandler): - """ - This class overwrite the 'setupPort' method of the Feetech PortHandler because it can specify - baudrates that are not supported with a serial port on MacOS. - """ - - def setupPort(self, cflag_baud): # noqa: N802 - if self.is_open: - self.closePort() - - self.ser = serial.Serial( - port=self.port_name, - # baudrate=self.baudrate, <- This will fail on MacOS - # parity = serial.PARITY_ODD, - # stopbits = serial.STOPBITS_TWO, - bytesize=serial.EIGHTBITS, - timeout=0, - ) - self.is_open = True - self.ser.reset_input_buffer() - self.tx_time_per_byte = (1000.0 / self.baudrate) * 10.0 - - return True - - def setPacketTimeout(self, packet_length): # noqa: N802 - return patch_setPacketTimeout(self, packet_length) - - -class MockMotors(MockSerial): - """ - This class will simulate physical motors by responding with valid status packets upon receiving some - instruction packets. It is meant to test MotorsBus classes. - """ - - def __init__(self): - super().__init__() - - @property - def stubs(self) -> dict[str, WaitableStub]: - return super().stubs - - def stub(self, *, name=None, **kwargs): - new_stub = WaitableStub(**kwargs) - self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub - return new_stub - - def build_broadcast_ping_stub(self, ids: list[int] | None = None, num_invalid_try: int = 0) -> str: - ping_request = MockInstructionPacket.ping(scs.BROADCAST_ID) - return_packets = b"".join(MockStatusPacket.ping(id_) for id_ in ids) - ping_response = self._build_send_fn(return_packets, num_invalid_try) - stub_name = "Ping_" + "_".join([str(id_) for id_ in ids]) - self.stub( - name=stub_name, - receive_bytes=ping_request, - send_fn=ping_response, - ) - return stub_name - - def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0, error: int = 0) -> str: - ping_request = MockInstructionPacket.ping(scs_id) - return_packet = MockStatusPacket.ping(scs_id, error) - ping_response = self._build_send_fn(return_packet, num_invalid_try) - stub_name = f"Ping_{scs_id}_{error}" - self.stub( - name=stub_name, - receive_bytes=ping_request, - send_fn=ping_response, - ) - return stub_name - - def build_read_stub( - self, - address: int, - length: int, - scs_id: int, - value: int, - reply: bool = True, - error: int = 0, - num_invalid_try: int = 0, - ) -> str: - read_request = MockInstructionPacket.read(scs_id, address, length) - return_packet = MockStatusPacket.read(scs_id, value, length, error) if reply else b"" - read_response = self._build_send_fn(return_packet, num_invalid_try) - stub_name = f"Read_{address}_{length}_{scs_id}_{value}_{error}" - self.stub( - name=stub_name, - receive_bytes=read_request, - send_fn=read_response, - ) - return stub_name - - def build_write_stub( - self, - address: int, - length: int, - scs_id: int, - value: int, - reply: bool = True, - error: int = 0, - num_invalid_try: int = 0, - ) -> str: - sync_read_request = MockInstructionPacket.write(scs_id, value, address, length) - return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error) if reply else b"" - stub_name = f"Write_{address}_{length}_{scs_id}" - self.stub( - name=stub_name, - receive_bytes=sync_read_request, - send_fn=self._build_send_fn(return_packet, num_invalid_try), - ) - return stub_name - - def build_sync_read_stub( - self, - address: int, - length: int, - ids_values: dict[int, int], - reply: bool = True, - num_invalid_try: int = 0, - ) -> str: - sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) - return_packets = ( - b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items()) - if reply - else b"" - ) - sync_read_response = self._build_send_fn(return_packets, num_invalid_try) - stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) - self.stub( - name=stub_name, - receive_bytes=sync_read_request, - send_fn=sync_read_response, - ) - return stub_name - - def build_sequential_sync_read_stub( - self, address: int, length: int, ids_values: dict[int, list[int]] | None = None - ) -> str: - sequence_length = len(next(iter(ids_values.values()))) - assert all(len(positions) == sequence_length for positions in ids_values.values()) - sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) - sequential_packets = [] - for count in range(sequence_length): - return_packets = b"".join( - MockStatusPacket.read(id_, positions[count], length) for id_, positions in ids_values.items() - ) - sequential_packets.append(return_packets) - - sync_read_response = self._build_sequential_send_fn(sequential_packets) - stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) - self.stub( - name=stub_name, - receive_bytes=sync_read_request, - send_fn=sync_read_response, - ) - return stub_name - - def build_sync_write_stub( - self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0 - ) -> str: - sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length) - stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) - self.stub( - name=stub_name, - receive_bytes=sync_read_request, - send_fn=self._build_send_fn(b"", num_invalid_try), - ) - return stub_name - - @staticmethod - def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]: - def send_fn(_call_count: int) -> bytes: - if num_invalid_try >= _call_count: - return b"" - return packet - - return send_fn - - @staticmethod - def _build_sequential_send_fn(packets: list[bytes]) -> Callable[[int], bytes]: - def send_fn(_call_count: int) -> bytes: - return packets[_call_count - 1] - - return send_fn diff --git a/tests/mocks/mock_motors_bus.py b/tests/mocks/mock_motors_bus.py deleted file mode 100644 index a499dbf..0000000 --- a/tests/mocks/mock_motors_bus.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# ruff: noqa: N802 - -from lerobot.motors.motors_bus import ( - Motor, - MotorsBus, -) - -DUMMY_CTRL_TABLE_1 = { - "Firmware_Version": (0, 1), - "Model_Number": (1, 2), - "Present_Position": (3, 4), - "Goal_Position": (11, 2), -} - -DUMMY_CTRL_TABLE_2 = { - "Model_Number": (0, 2), - "Firmware_Version": (2, 1), - "Present_Position": (3, 4), - "Present_Velocity": (7, 4), - "Goal_Position": (11, 4), - "Goal_Velocity": (15, 4), - "Lock": (19, 1), -} - -DUMMY_MODEL_CTRL_TABLE = { - "model_1": DUMMY_CTRL_TABLE_1, - "model_2": DUMMY_CTRL_TABLE_2, - "model_3": DUMMY_CTRL_TABLE_2, -} - -DUMMY_BAUDRATE_TABLE = { - 0: 1_000_000, - 1: 500_000, - 2: 250_000, -} - -DUMMY_MODEL_BAUDRATE_TABLE = { - "model_1": DUMMY_BAUDRATE_TABLE, - "model_2": DUMMY_BAUDRATE_TABLE, - "model_3": DUMMY_BAUDRATE_TABLE, -} - -DUMMY_ENCODING_TABLE = { - "Present_Position": 8, - "Goal_Position": 10, -} - -DUMMY_MODEL_ENCODING_TABLE = { - "model_1": DUMMY_ENCODING_TABLE, - "model_2": DUMMY_ENCODING_TABLE, - "model_3": DUMMY_ENCODING_TABLE, -} - -DUMMY_MODEL_NUMBER_TABLE = { - "model_1": 1234, - "model_2": 5678, - "model_3": 5799, -} - -DUMMY_MODEL_RESOLUTION_TABLE = { - "model_1": 4096, - "model_2": 1024, - "model_3": 4096, -} - - -class MockPortHandler: - def __init__(self, port_name): - self.is_open: bool = False - self.baudrate: int - self.packet_start_time: float - self.packet_timeout: float - self.tx_time_per_byte: float - self.is_using: bool = False - self.port_name: str = port_name - self.ser = None - - def openPort(self): - self.is_open = True - return self.is_open - - def closePort(self): - self.is_open = False - - def clearPort(self): ... - def setPortName(self, port_name): - self.port_name = port_name - - def getPortName(self): - return self.port_name - - def setBaudRate(self, baudrate): - self.baudrate: baudrate - - def getBaudRate(self): - return self.baudrate - - def getBytesAvailable(self): ... - def readPort(self, length): ... - def writePort(self, packet): ... - def setPacketTimeout(self, packet_length): ... - def setPacketTimeoutMillis(self, msec): ... - def isPacketTimeout(self): ... - def getCurrentTime(self): ... - def getTimeSinceStart(self): ... - def setupPort(self, cflag_baud): ... - def getCFlagBaud(self, baudrate): ... - - -class MockMotorsBus(MotorsBus): - available_baudrates = [500_000, 1_000_000] - default_timeout = 1000 - model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE - model_ctrl_table = DUMMY_MODEL_CTRL_TABLE - model_encoding_table = DUMMY_MODEL_ENCODING_TABLE - model_number_table = DUMMY_MODEL_NUMBER_TABLE - model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE - normalized_data = ["Present_Position", "Goal_Position"] - - def __init__(self, port: str, motors: dict[str, Motor]): - super().__init__(port, motors) - self.port_handler = MockPortHandler(port) - - def _assert_protocol_is_compatible(self, instruction_name): ... - def _handshake(self): ... - def _find_single_motor(self, motor, initial_baudrate): ... - def configure_motors(self): ... - def is_calibrated(self): ... - def read_calibration(self): ... - def write_calibration(self, calibration_dict): ... - def disable_torque(self, motors, num_retry): ... - def _disable_torque(self, motor, model, num_retry): ... - def enable_torque(self, motors, num_retry): ... - def _get_half_turn_homings(self, positions): ... - def _encode_sign(self, data_name, ids_values): ... - def _decode_sign(self, data_name, ids_values): ... - def _split_into_byte_chunks(self, value, length): ... - def broadcast_ping(self, num_retry, raise_on_error): ... diff --git a/tests/mocks/mock_robot.py b/tests/mocks/mock_robot.py deleted file mode 100644 index 8108c7c..0000000 --- a/tests/mocks/mock_robot.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random -from dataclasses import dataclass, field -from functools import cached_property -from typing import Any - -from lerobot.cameras import CameraConfig, make_cameras_from_configs -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError -from lerobot.robots import Robot, RobotConfig - - -@RobotConfig.register_subclass("mock_robot") -@dataclass -class MockRobotConfig(RobotConfig): - n_motors: int = 3 - cameras: dict[str, CameraConfig] = field(default_factory=dict) - random_values: bool = True - static_values: list[float] | None = None - calibrated: bool = True - - def __post_init__(self): - if self.n_motors < 1: - raise ValueError(self.n_motors) - - if self.random_values and self.static_values is not None: - raise ValueError("Choose either random values or static values") - - if self.static_values is not None and len(self.static_values) != self.n_motors: - raise ValueError("Specify the same number of static values as motors") - - if len(self.cameras) > 0: - raise NotImplementedError # TODO with the cameras refactor - - -class MockRobot(Robot): - """Mock Robot to be used for testing.""" - - config_class = MockRobotConfig - name = "mock_robot" - - def __init__(self, config: MockRobotConfig): - super().__init__(config) - self.config = config - self._is_connected = False - self._is_calibrated = config.calibrated - self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)] - self.cameras = make_cameras_from_configs(config.cameras) - - @property - def _motors_ft(self) -> dict[str, type]: - return {f"{motor}.pos": float for motor in self.motors} - - @property - def _cameras_ft(self) -> dict[str, tuple]: - return { - cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras - } - - @cached_property - def observation_features(self) -> dict[str, type | tuple]: - return {**self._motors_ft, **self._cameras_ft} - - @cached_property - def action_features(self) -> dict[str, type]: - return self._motors_ft - - @property - def is_connected(self) -> bool: - return self._is_connected - - def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - - self._is_connected = True - if calibrate: - self.calibrate() - - @property - def is_calibrated(self) -> bool: - return self._is_calibrated - - def calibrate(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - self._is_calibrated = True - - def configure(self) -> None: - pass - - def get_observation(self) -> dict[str, Any]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - if self.config.random_values: - return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors} - else: - return { - f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True) - } - - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - return action - - def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - self._is_connected = False diff --git a/tests/mocks/mock_serial_patch.py b/tests/mocks/mock_serial_patch.py deleted file mode 100644 index bde0efa..0000000 --- a/tests/mocks/mock_serial_patch.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import threading -import time - -from mock_serial.mock_serial import Stub - - -class WaitableStub(Stub): - """ - In some situations, a test might be checking if a stub has been called before `MockSerial` thread had time - to read, match, and call the stub. In these situations, the test can fail randomly. - - Use `wait_called()` or `wait_calls()` to block until the stub is called, avoiding race conditions. - - Proposed fix: - https://github.com/benthorner/mock_serial/pull/3 - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._event = threading.Event() - - def call(self): - self._event.set() - return super().call() - - def wait_called(self, timeout: float = 1.0): - return self._event.wait(timeout) - - def wait_calls(self, min_calls: int = 1, timeout: float = 1.0): - start = time.perf_counter() - while time.perf_counter() - start < timeout: - if self.calls >= min_calls: - return self.calls - time.sleep(0.005) - raise TimeoutError(f"Stub not called {min_calls} times within {timeout} seconds.") diff --git a/tests/mocks/mock_teleop.py b/tests/mocks/mock_teleop.py deleted file mode 100644 index e37d4a2..0000000 --- a/tests/mocks/mock_teleop.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random -from dataclasses import dataclass -from functools import cached_property -from typing import Any - -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError -from lerobot.teleoperators import Teleoperator, TeleoperatorConfig - - -@TeleoperatorConfig.register_subclass("mock_teleop") -@dataclass -class MockTeleopConfig(TeleoperatorConfig): - n_motors: int = 3 - random_values: bool = True - static_values: list[float] | None = None - calibrated: bool = True - - def __post_init__(self): - if self.n_motors < 1: - raise ValueError(self.n_motors) - - if self.random_values and self.static_values is not None: - raise ValueError("Choose either random values or static values") - - if self.static_values is not None and len(self.static_values) != self.n_motors: - raise ValueError("Specify the same number of static values as motors") - - -class MockTeleop(Teleoperator): - """Mock Teleoperator to be used for testing.""" - - config_class = MockTeleopConfig - name = "mock_teleop" - - def __init__(self, config: MockTeleopConfig): - super().__init__(config) - self.config = config - self._is_connected = False - self._is_calibrated = config.calibrated - self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)] - - @cached_property - def action_features(self) -> dict[str, type]: - return {f"{motor}.pos": float for motor in self.motors} - - @cached_property - def feedback_features(self) -> dict[str, type]: - return {f"{motor}.pos": float for motor in self.motors} - - @property - def is_connected(self) -> bool: - return self._is_connected - - def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - - self._is_connected = True - if calibrate: - self.calibrate() - - @property - def is_calibrated(self) -> bool: - return self._is_calibrated - - def calibrate(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - self._is_calibrated = True - - def configure(self) -> None: - pass - - def get_action(self) -> dict[str, Any]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - if self.config.random_values: - return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors} - else: - return { - f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True) - } - - def send_feedback(self, feedback: dict[str, Any]) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - self._is_connected = False diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py deleted file mode 100644 index e0dbe71..0000000 --- a/tests/motors/test_dynamixel.py +++ /dev/null @@ -1,416 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import sys -from collections.abc import Generator -from unittest.mock import MagicMock, patch - -import pytest - -from lerobot.motors import Motor, MotorCalibration, MotorNormMode -from lerobot.motors.dynamixel import MODEL_NUMBER_TABLE, DynamixelMotorsBus -from lerobot.motors.dynamixel.tables import X_SERIES_CONTROL_TABLE -from lerobot.utils.encoding_utils import encode_twos_complement - -try: - import dynamixel_sdk as dxl - - from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler -except (ImportError, ModuleNotFoundError): - pytest.skip("dynamixel_sdk not available", allow_module_level=True) - - -@pytest.fixture(autouse=True) -def patch_port_handler(): - if sys.platform == "darwin": - with patch.object(dxl, "PortHandler", MockPortHandler): - yield - else: - yield - - -@pytest.fixture -def mock_motors() -> Generator[MockMotors, None, None]: - motors = MockMotors() - motors.open() - yield motors - motors.close() - - -@pytest.fixture -def dummy_motors() -> dict[str, Motor]: - return { - "dummy_1": Motor(1, "xl430-w250", MotorNormMode.RANGE_M100_100), - "dummy_2": Motor(2, "xm540-w270", MotorNormMode.RANGE_M100_100), - "dummy_3": Motor(3, "xl330-m077", MotorNormMode.RANGE_M100_100), - } - - -@pytest.fixture -def dummy_calibration(dummy_motors) -> dict[str, MotorCalibration]: - drive_modes = [0, 1, 0] - homings = [-709, -2006, 1624] - mins = [43, 27, 145] - maxes = [1335, 3608, 3999] - calibration = {} - for motor, m in dummy_motors.items(): - calibration[motor] = MotorCalibration( - id=m.id, - drive_mode=drive_modes[m.id - 1], - homing_offset=homings[m.id - 1], - range_min=mins[m.id - 1], - range_max=maxes[m.id - 1], - ) - return calibration - - -@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}") -def test_autouse_patch(): - """Ensures that the autouse fixture correctly patches dxl.PortHandler with MockPortHandler.""" - assert dxl.PortHandler is MockPortHandler - - -@pytest.mark.parametrize( - "value, length, expected", - [ - (0x12, 1, [0x12]), - (0x1234, 2, [0x34, 0x12]), - (0x12345678, 4, [0x78, 0x56, 0x34, 0x12]), - ], - ids=[ - "1 byte", - "2 bytes", - "4 bytes", - ], -) # fmt: skip -def test__split_into_byte_chunks(value, length, expected): - bus = DynamixelMotorsBus("", {}) - assert bus._split_into_byte_chunks(value, length) == expected - - -def test_abc_implementation(dummy_motors): - """Instantiation should raise an error if the class doesn't implement abstract methods/properties.""" - DynamixelMotorsBus(port="/dev/dummy-port", motors=dummy_motors) - - -@pytest.mark.parametrize("id_", [1, 2, 3]) -def test_ping(id_, mock_motors, dummy_motors): - expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model] - stub = mock_motors.build_ping_stub(id_, expected_model_nb) - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - ping_model_nb = bus.ping(id_) - - assert ping_model_nb == expected_model_nb - assert mock_motors.stubs[stub].called - - -def test_broadcast_ping(mock_motors, dummy_motors): - models = {m.id: m.model for m in dummy_motors.values()} - expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()} - stub = mock_motors.build_broadcast_ping_stub(expected_model_nbs) - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - ping_model_nbs = bus.broadcast_ping() - - assert ping_model_nbs == expected_model_nbs - assert mock_motors.stubs[stub].called - - -@pytest.mark.parametrize( - "addr, length, id_, value", - [ - (0, 1, 1, 2), - (10, 2, 2, 999), - (42, 4, 3, 1337), - ], -) -def test__read(addr, length, id_, value, mock_motors, dummy_motors): - stub = mock_motors.build_read_stub(addr, length, id_, value) - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - read_value, _, _ = bus._read(addr, length, id_) - - assert mock_motors.stubs[stub].called - assert read_value == value - - -@pytest.mark.parametrize("raise_on_error", (True, False)) -def test__read_error(raise_on_error, mock_motors, dummy_motors): - addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT) - stub = mock_motors.build_read_stub(addr, length, id_, value, error=error) - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - if raise_on_error: - with pytest.raises( - RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!") - ): - bus._read(addr, length, id_, raise_on_error=raise_on_error) - else: - _, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error) - assert read_error == error - - assert mock_motors.stubs[stub].called - - -@pytest.mark.parametrize("raise_on_error", (True, False)) -def test__read_comm(raise_on_error, mock_motors, dummy_motors): - addr, length, id_, value = (10, 4, 1, 1337) - stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False) - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - if raise_on_error: - with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): - bus._read(addr, length, id_, raise_on_error=raise_on_error) - else: - _, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error) - assert read_comm == dxl.COMM_RX_TIMEOUT - - assert mock_motors.stubs[stub].called - - -@pytest.mark.parametrize( - "addr, length, id_, value", - [ - (0, 1, 1, 2), - (10, 2, 2, 999), - (42, 4, 3, 1337), - ], -) -def test__write(addr, length, id_, value, mock_motors, dummy_motors): - stub = mock_motors.build_write_stub(addr, length, id_, value) - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - comm, error = bus._write(addr, length, id_, value) - - assert mock_motors.stubs[stub].called - assert comm == dxl.COMM_SUCCESS - assert error == 0 - - -@pytest.mark.parametrize("raise_on_error", (True, False)) -def test__write_error(raise_on_error, mock_motors, dummy_motors): - addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT) - stub = mock_motors.build_write_stub(addr, length, id_, value, error=error) - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - if raise_on_error: - with pytest.raises( - RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!") - ): - bus._write(addr, length, id_, value, raise_on_error=raise_on_error) - else: - _, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error) - assert write_error == error - - assert mock_motors.stubs[stub].called - - -@pytest.mark.parametrize("raise_on_error", (True, False)) -def test__write_comm(raise_on_error, mock_motors, dummy_motors): - addr, length, id_, value = (10, 4, 1, 1337) - stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False) - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - if raise_on_error: - with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): - bus._write(addr, length, id_, value, raise_on_error=raise_on_error) - else: - write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error) - assert write_comm == dxl.COMM_RX_TIMEOUT - - assert mock_motors.stubs[stub].called - - -@pytest.mark.parametrize( - "addr, length, ids_values", - [ - (0, 1, {1: 4}), - (10, 2, {1: 1337, 2: 42}), - (42, 4, {1: 1337, 2: 42, 3: 4016}), - ], - ids=["1 motor", "2 motors", "3 motors"], -) -def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors): - stub = mock_motors.build_sync_read_stub(addr, length, ids_values) - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - read_values, _ = bus._sync_read(addr, length, list(ids_values)) - - assert mock_motors.stubs[stub].called - assert read_values == ids_values - - -@pytest.mark.parametrize("raise_on_error", (True, False)) -def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): - addr, length, ids_values = (10, 4, {1: 1337}) - stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False) - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - if raise_on_error: - with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): - bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) - else: - _, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) - assert read_comm == dxl.COMM_RX_TIMEOUT - - assert mock_motors.stubs[stub].called - - -@pytest.mark.parametrize( - "addr, length, ids_values", - [ - (0, 1, {1: 4}), - (10, 2, {1: 1337, 2: 42}), - (42, 4, {1: 1337, 2: 42, 3: 4016}), - ], - ids=["1 motor", "2 motors", "3 motors"], -) -def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors): - stub = mock_motors.build_sync_write_stub(addr, length, ids_values) - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - comm = bus._sync_write(addr, length, ids_values) - - assert mock_motors.stubs[stub].wait_called() - assert comm == dxl.COMM_SUCCESS - - -def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration): - drive_modes = {m.id: m.drive_mode for m in dummy_calibration.values()} - encoded_homings = {m.id: encode_twos_complement(m.homing_offset, 4) for m in dummy_calibration.values()} - mins = {m.id: m.range_min for m in dummy_calibration.values()} - maxes = {m.id: m.range_max for m in dummy_calibration.values()} - drive_modes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Drive_Mode"], drive_modes) - offsets_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings) - mins_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins) - maxes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes) - bus = DynamixelMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - calibration=dummy_calibration, - ) - bus.connect(handshake=False) - - is_calibrated = bus.is_calibrated - - assert is_calibrated - assert mock_motors.stubs[drive_modes_stub].called - assert mock_motors.stubs[offsets_stub].called - assert mock_motors.stubs[mins_stub].called - assert mock_motors.stubs[maxes_stub].called - - -def test_reset_calibration(mock_motors, dummy_motors): - write_homing_stubs = [] - write_mins_stubs = [] - write_maxes_stubs = [] - for motor in dummy_motors.values(): - write_homing_stubs.append( - mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0) - ) - write_mins_stubs.append( - mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0) - ) - write_maxes_stubs.append( - mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095) - ) - - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - bus.reset_calibration() - - assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs) - assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs) - assert all(mock_motors.stubs[stub].called for stub in write_maxes_stubs) - - -def test_set_half_turn_homings(mock_motors, dummy_motors): - """ - For this test, we assume that the homing offsets are already 0 such that - Present_Position == Actual_Position - """ - current_positions = { - 1: 1337, - 2: 42, - 3: 3672, - } - expected_homings = { - 1: 710, # 2047 - 1337 - 2: 2005, # 2047 - 42 - 3: -1625, # 2047 - 3672 - } - read_pos_stub = mock_motors.build_sync_read_stub( - *X_SERIES_CONTROL_TABLE["Present_Position"], current_positions - ) - write_homing_stubs = [] - for id_, homing in expected_homings.items(): - encoded_homing = encode_twos_complement(homing, 4) - stub = mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing) - write_homing_stubs.append(stub) - - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - bus.reset_calibration = MagicMock() - - bus.set_half_turn_homings() - - bus.reset_calibration.assert_called_once() - assert mock_motors.stubs[read_pos_stub].called - assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs) - - -def test_record_ranges_of_motion(mock_motors, dummy_motors): - positions = { - 1: [351, 42, 1337], - 2: [28, 3600, 2444], - 3: [4002, 2999, 146], - } - expected_mins = { - "dummy_1": 42, - "dummy_2": 28, - "dummy_3": 146, - } - expected_maxes = { - "dummy_1": 1337, - "dummy_2": 3600, - "dummy_3": 4002, - } - read_pos_stub = mock_motors.build_sequential_sync_read_stub( - *X_SERIES_CONTROL_TABLE["Present_Position"], positions - ) - with patch("lerobot.motors.motors_bus.enter_pressed", side_effect=[False, True]): - bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - mins, maxes = bus.record_ranges_of_motion(display_values=False) - - assert mock_motors.stubs[read_pos_stub].calls == 3 - assert mins == expected_mins - assert maxes == expected_maxes diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py deleted file mode 100644 index 31e4a90..0000000 --- a/tests/motors/test_feetech.py +++ /dev/null @@ -1,459 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import sys -from collections.abc import Generator -from unittest.mock import MagicMock, patch - -import pytest - -from lerobot.motors import Motor, MotorCalibration, MotorNormMode -from lerobot.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus -from lerobot.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE -from lerobot.utils.encoding_utils import encode_sign_magnitude - -try: - import scservo_sdk as scs - - from tests.mocks.mock_feetech import MockMotors, MockPortHandler -except (ImportError, ModuleNotFoundError): - pytest.skip("scservo_sdk not available", allow_module_level=True) - - -@pytest.fixture(autouse=True) -def patch_port_handler(): - if sys.platform == "darwin": - with patch.object(scs, "PortHandler", MockPortHandler): - yield - else: - yield - - -@pytest.fixture -def mock_motors() -> Generator[MockMotors, None, None]: - motors = MockMotors() - motors.open() - yield motors - motors.close() - - -@pytest.fixture -def dummy_motors() -> dict[str, Motor]: - return { - "dummy_1": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100), - "dummy_2": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100), - "dummy_3": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100), - } - - -@pytest.fixture -def dummy_calibration(dummy_motors) -> dict[str, MotorCalibration]: - homings = [-709, -2006, 1624] - mins = [43, 27, 145] - maxes = [1335, 3608, 3999] - calibration = {} - for motor, m in dummy_motors.items(): - calibration[motor] = MotorCalibration( - id=m.id, - drive_mode=0, - homing_offset=homings[m.id - 1], - range_min=mins[m.id - 1], - range_max=maxes[m.id - 1], - ) - return calibration - - -@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}") -def test_autouse_patch(): - """Ensures that the autouse fixture correctly patches scs.PortHandler with MockPortHandler.""" - assert scs.PortHandler is MockPortHandler - - -@pytest.mark.parametrize( - "protocol, value, length, expected", - [ - (0, 0x12, 1, [0x12]), - (1, 0x12, 1, [0x12]), - (0, 0x1234, 2, [0x34, 0x12]), - (1, 0x1234, 2, [0x12, 0x34]), - (0, 0x12345678, 4, [0x78, 0x56, 0x34, 0x12]), - (1, 0x12345678, 4, [0x56, 0x78, 0x12, 0x34]), - ], - ids=[ - "P0: 1 byte", - "P1: 1 byte", - "P0: 2 bytes", - "P1: 2 bytes", - "P0: 4 bytes", - "P1: 4 bytes", - ], -) # fmt: skip -def test__split_into_byte_chunks(protocol, value, length, expected): - bus = FeetechMotorsBus("", {}, protocol_version=protocol) - assert bus._split_into_byte_chunks(value, length) == expected - - -def test_abc_implementation(dummy_motors): - """Instantiation should raise an error if the class doesn't implement abstract methods/properties.""" - FeetechMotorsBus(port="/dev/dummy-port", motors=dummy_motors) - - -@pytest.mark.parametrize("id_", [1, 2, 3]) -def test_ping(id_, mock_motors, dummy_motors): - expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model] - addr, length = MODEL_NUMBER - ping_stub = mock_motors.build_ping_stub(id_) - mobel_nb_stub = mock_motors.build_read_stub(addr, length, id_, expected_model_nb) - bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - bus.connect(handshake=False) - - ping_model_nb = bus.ping(id_) - - assert ping_model_nb == expected_model_nb - assert mock_motors.stubs[ping_stub].called - assert mock_motors.stubs[mobel_nb_stub].called - - -def test_broadcast_ping(mock_motors, dummy_motors): - models = {m.id: m.model for m in dummy_motors.values()} - addr, length = MODEL_NUMBER - ping_stub = mock_motors.build_broadcast_ping_stub(list(models)) - mobel_nb_stubs = [] - expected_model_nbs = {} - for id_, model in models.items(): - model_nb = MODEL_NUMBER_TABLE[model] - stub = mock_motors.build_read_stub(addr, length, id_, model_nb) - expected_model_nbs[id_] = model_nb - mobel_nb_stubs.append(stub) - bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - bus.connect(handshake=False) - - ping_model_nbs = bus.broadcast_ping() - - assert ping_model_nbs == expected_model_nbs - assert mock_motors.stubs[ping_stub].called - assert all(mock_motors.stubs[stub].called for stub in mobel_nb_stubs) - - -@pytest.mark.parametrize( - "addr, length, id_, value", - [ - (0, 1, 1, 2), - (10, 2, 2, 999), - (42, 4, 3, 1337), - ], -) -def test__read(addr, length, id_, value, mock_motors, dummy_motors): - stub = mock_motors.build_read_stub(addr, length, id_, value) - bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - bus.connect(handshake=False) - - read_value, _, _ = bus._read(addr, length, id_) - - assert mock_motors.stubs[stub].called - assert read_value == value - - -@pytest.mark.parametrize("raise_on_error", (True, False)) -def test__read_error(raise_on_error, mock_motors, dummy_motors): - addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE) - stub = mock_motors.build_read_stub(addr, length, id_, value, error=error) - bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - bus.connect(handshake=False) - - if raise_on_error: - with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")): - bus._read(addr, length, id_, raise_on_error=raise_on_error) - else: - _, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error) - assert read_error == error - - assert mock_motors.stubs[stub].called - - -@pytest.mark.parametrize("raise_on_error", (True, False)) -def test__read_comm(raise_on_error, mock_motors, dummy_motors): - addr, length, id_, value = (10, 4, 1, 1337) - stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False) - bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - bus.connect(handshake=False) - - if raise_on_error: - with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): - bus._read(addr, length, id_, raise_on_error=raise_on_error) - else: - _, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error) - assert read_comm == scs.COMM_RX_TIMEOUT - - assert mock_motors.stubs[stub].called - - -@pytest.mark.parametrize( - "addr, length, id_, value", - [ - (0, 1, 1, 2), - (10, 2, 2, 999), - (42, 4, 3, 1337), - ], -) -def test__write(addr, length, id_, value, mock_motors, dummy_motors): - stub = mock_motors.build_write_stub(addr, length, id_, value) - bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - ) - bus.connect(handshake=False) - - comm, error = bus._write(addr, length, id_, value) - - assert mock_motors.stubs[stub].wait_called() - assert comm == scs.COMM_SUCCESS - assert error == 0 - - -@pytest.mark.parametrize("raise_on_error", (True, False)) -def test__write_error(raise_on_error, mock_motors, dummy_motors): - addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE) - stub = mock_motors.build_write_stub(addr, length, id_, value, error=error) - bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - if raise_on_error: - with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")): - bus._write(addr, length, id_, value, raise_on_error=raise_on_error) - else: - _, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error) - assert write_error == error - - assert mock_motors.stubs[stub].called - - -@pytest.mark.parametrize("raise_on_error", (True, False)) -def test__write_comm(raise_on_error, mock_motors, dummy_motors): - addr, length, id_, value = (10, 4, 1, 1337) - stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False) - bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - if raise_on_error: - with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): - bus._write(addr, length, id_, value, raise_on_error=raise_on_error) - else: - write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error) - assert write_comm == scs.COMM_RX_TIMEOUT - - assert mock_motors.stubs[stub].called - - -@pytest.mark.parametrize( - "addr, length, ids_values", - [ - (0, 1, {1: 4}), - (10, 2, {1: 1337, 2: 42}), - (42, 4, {1: 1337, 2: 42, 3: 4016}), - ], - ids=["1 motor", "2 motors", "3 motors"], -) -def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors): - stub = mock_motors.build_sync_read_stub(addr, length, ids_values) - bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - read_values, _ = bus._sync_read(addr, length, list(ids_values)) - - assert mock_motors.stubs[stub].called - assert read_values == ids_values - - -@pytest.mark.parametrize("raise_on_error", (True, False)) -def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): - addr, length, ids_values = (10, 4, {1: 1337}) - stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False) - bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - if raise_on_error: - with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): - bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) - else: - _, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) - assert read_comm == scs.COMM_RX_TIMEOUT - - assert mock_motors.stubs[stub].called - - -@pytest.mark.parametrize( - "addr, length, ids_values", - [ - (0, 1, {1: 4}), - (10, 2, {1: 1337, 2: 42}), - (42, 4, {1: 1337, 2: 42, 3: 4016}), - ], - ids=["1 motor", "2 motors", "3 motors"], -) -def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors): - stub = mock_motors.build_sync_write_stub(addr, length, ids_values) - bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - comm = bus._sync_write(addr, length, ids_values) - - assert mock_motors.stubs[stub].wait_called() - assert comm == scs.COMM_SUCCESS - - -def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration): - mins_stubs, maxes_stubs, homings_stubs = [], [], [] - for cal in dummy_calibration.values(): - mins_stubs.append( - mock_motors.build_read_stub( - *STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], cal.id, cal.range_min - ) - ) - maxes_stubs.append( - mock_motors.build_read_stub( - *STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], cal.id, cal.range_max - ) - ) - homings_stubs.append( - mock_motors.build_read_stub( - *STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], - cal.id, - encode_sign_magnitude(cal.homing_offset, 11), - ) - ) - - bus = FeetechMotorsBus( - port=mock_motors.port, - motors=dummy_motors, - calibration=dummy_calibration, - ) - bus.connect(handshake=False) - - is_calibrated = bus.is_calibrated - - assert is_calibrated - assert all(mock_motors.stubs[stub].called for stub in mins_stubs) - assert all(mock_motors.stubs[stub].called for stub in maxes_stubs) - assert all(mock_motors.stubs[stub].called for stub in homings_stubs) - - -def test_reset_calibration(mock_motors, dummy_motors): - write_homing_stubs = [] - write_mins_stubs = [] - write_maxes_stubs = [] - for motor in dummy_motors.values(): - write_homing_stubs.append( - mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0) - ) - write_mins_stubs.append( - mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0) - ) - write_maxes_stubs.append( - mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095) - ) - - bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - bus.reset_calibration() - - assert all(mock_motors.stubs[stub].wait_called() for stub in write_homing_stubs) - assert all(mock_motors.stubs[stub].wait_called() for stub in write_mins_stubs) - assert all(mock_motors.stubs[stub].wait_called() for stub in write_maxes_stubs) - - -def test_set_half_turn_homings(mock_motors, dummy_motors): - """ - For this test, we assume that the homing offsets are already 0 such that - Present_Position == Actual_Position - """ - current_positions = { - 1: 1337, - 2: 42, - 3: 3672, - } - expected_homings = { - 1: -710, # 1337 - 2047 - 2: -2005, # 42 - 2047 - 3: 1625, # 3672 - 2047 - } - read_pos_stub = mock_motors.build_sync_read_stub( - *STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], current_positions - ) - write_homing_stubs = [] - for id_, homing in expected_homings.items(): - encoded_homing = encode_sign_magnitude(homing, 11) - stub = mock_motors.build_write_stub( - *STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing - ) - write_homing_stubs.append(stub) - - bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - bus.reset_calibration = MagicMock() - - bus.set_half_turn_homings() - - bus.reset_calibration.assert_called_once() - assert mock_motors.stubs[read_pos_stub].called - assert all(mock_motors.stubs[stub].wait_called() for stub in write_homing_stubs) - - -def test_record_ranges_of_motion(mock_motors, dummy_motors): - positions = { - 1: [351, 42, 1337], - 2: [28, 3600, 2444], - 3: [4002, 2999, 146], - } - expected_mins = { - "dummy_1": 42, - "dummy_2": 28, - "dummy_3": 146, - } - expected_maxes = { - "dummy_1": 1337, - "dummy_2": 3600, - "dummy_3": 4002, - } - stub = mock_motors.build_sequential_sync_read_stub( - *STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], positions - ) - with patch("lerobot.motors.motors_bus.enter_pressed", side_effect=[False, True]): - bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) - bus.connect(handshake=False) - - mins, maxes = bus.record_ranges_of_motion(display_values=False) - - assert mock_motors.stubs[stub].calls == 3 - assert mins == expected_mins - assert maxes == expected_maxes diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py deleted file mode 100644 index 27650ef..0000000 --- a/tests/motors/test_motors_bus.py +++ /dev/null @@ -1,358 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from unittest.mock import patch - -import pytest - -from lerobot.motors.motors_bus import ( - Motor, - MotorNormMode, - assert_same_address, - get_address, - get_ctrl_table, -) -from tests.mocks.mock_motors_bus import ( - DUMMY_CTRL_TABLE_1, - DUMMY_CTRL_TABLE_2, - DUMMY_MODEL_CTRL_TABLE, - MockMotorsBus, -) - - -@pytest.fixture -def dummy_motors() -> dict[str, Motor]: - return { - "dummy_1": Motor(1, "model_2", MotorNormMode.RANGE_M100_100), - "dummy_2": Motor(2, "model_3", MotorNormMode.RANGE_M100_100), - "dummy_3": Motor(3, "model_2", MotorNormMode.RANGE_0_100), - } - - -def test_get_ctrl_table(): - model = "model_1" - ctrl_table = get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model) - assert ctrl_table == DUMMY_CTRL_TABLE_1 - - -def test_get_ctrl_table_error(): - model = "model_99" - with pytest.raises(KeyError, match=f"Control table for {model=} not found."): - get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model) - - -def test_get_address(): - addr, n_bytes = get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", "Firmware_Version") - assert addr == 0 - assert n_bytes == 1 - - -def test_get_address_error(): - model = "model_1" - data_name = "Lock" - with pytest.raises(KeyError, match=f"Address for '{data_name}' not found in {model} control table."): - get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", data_name) - - -def test_assert_same_address(): - models = ["model_1", "model_2"] - assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Present_Position") - - -def test_assert_same_length_different_addresses(): - models = ["model_1", "model_2"] - with pytest.raises( - NotImplementedError, - match=re.escape("At least two motor models use a different address"), - ): - assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Model_Number") - - -def test_assert_same_address_different_length(): - models = ["model_1", "model_2"] - with pytest.raises( - NotImplementedError, - match=re.escape("At least two motor models use a different bytes representation"), - ): - assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Goal_Position") - - -def test__serialize_data_invalid_length(): - bus = MockMotorsBus("", {}) - with pytest.raises(NotImplementedError): - bus._serialize_data(100, 3) - - -def test__serialize_data_negative_numbers(): - bus = MockMotorsBus("", {}) - with pytest.raises(ValueError): - bus._serialize_data(-1, 1) - - -def test__serialize_data_large_number(): - bus = MockMotorsBus("", {}) - with pytest.raises(ValueError): - bus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF - - -@pytest.mark.parametrize( - "data_name, id_, value", - [ - ("Firmware_Version", 1, 14), - ("Model_Number", 1, 5678), - ("Present_Position", 2, 1337), - ("Present_Velocity", 3, 42), - ], -) -def test_read(data_name, id_, value, dummy_motors): - bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(handshake=False) - addr, length = DUMMY_CTRL_TABLE_2[data_name] - - with ( - patch.object(MockMotorsBus, "_read", return_value=(value, 0, 0)) as mock__read, - patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign, - patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize, - ): - returned_value = bus.read(data_name, f"dummy_{id_}") - - assert returned_value == value - mock__read.assert_called_once_with( - addr, - length, - id_, - num_retry=0, - raise_on_error=True, - err_msg=f"Failed to read '{data_name}' on {id_=} after 1 tries.", - ) - mock__decode_sign.assert_called_once_with(data_name, {id_: value}) - if data_name in bus.normalized_data: - mock__normalize.assert_called_once_with({id_: value}) - - -@pytest.mark.parametrize( - "data_name, id_, value", - [ - ("Goal_Position", 1, 1337), - ("Goal_Velocity", 2, 3682), - ("Lock", 3, 1), - ], -) -def test_write(data_name, id_, value, dummy_motors): - bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(handshake=False) - addr, length = DUMMY_CTRL_TABLE_2[data_name] - - with ( - patch.object(MockMotorsBus, "_write", return_value=(0, 0)) as mock__write, - patch.object(MockMotorsBus, "_encode_sign", return_value={id_: value}) as mock__encode_sign, - patch.object(MockMotorsBus, "_unnormalize", return_value={id_: value}) as mock__unnormalize, - ): - bus.write(data_name, f"dummy_{id_}", value) - - mock__write.assert_called_once_with( - addr, - length, - id_, - value, - num_retry=0, - raise_on_error=True, - err_msg=f"Failed to write '{data_name}' on {id_=} with '{value}' after 1 tries.", - ) - mock__encode_sign.assert_called_once_with(data_name, {id_: value}) - if data_name in bus.normalized_data: - mock__unnormalize.assert_called_once_with({id_: value}) - - -@pytest.mark.parametrize( - "data_name, id_, value", - [ - ("Firmware_Version", 1, 14), - ("Model_Number", 1, 5678), - ("Present_Position", 2, 1337), - ("Present_Velocity", 3, 42), - ], -) -def test_sync_read_by_str(data_name, id_, value, dummy_motors): - bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(handshake=False) - addr, length = DUMMY_CTRL_TABLE_2[data_name] - ids = [id_] - expected_value = {f"dummy_{id_}": value} - - with ( - patch.object(MockMotorsBus, "_sync_read", return_value=({id_: value}, 0)) as mock__sync_read, - patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign, - patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize, - ): - returned_dict = bus.sync_read(data_name, f"dummy_{id_}") - - assert returned_dict == expected_value - mock__sync_read.assert_called_once_with( - addr, - length, - ids, - num_retry=0, - raise_on_error=True, - err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.", - ) - mock__decode_sign.assert_called_once_with(data_name, {id_: value}) - if data_name in bus.normalized_data: - mock__normalize.assert_called_once_with({id_: value}) - - -@pytest.mark.parametrize( - "data_name, ids_values", - [ - ("Model_Number", {1: 5678}), - ("Present_Position", {1: 1337, 2: 42}), - ("Present_Velocity", {1: 1337, 2: 42, 3: 4016}), - ], - ids=["1 motor", "2 motors", "3 motors"], -) -def test_sync_read_by_list(data_name, ids_values, dummy_motors): - bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(handshake=False) - addr, length = DUMMY_CTRL_TABLE_2[data_name] - ids = list(ids_values) - expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()} - - with ( - patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read, - patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign, - patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize, - ): - returned_dict = bus.sync_read(data_name, [f"dummy_{id_}" for id_ in ids]) - - assert returned_dict == expected_values - mock__sync_read.assert_called_once_with( - addr, - length, - ids, - num_retry=0, - raise_on_error=True, - err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.", - ) - mock__decode_sign.assert_called_once_with(data_name, ids_values) - if data_name in bus.normalized_data: - mock__normalize.assert_called_once_with(ids_values) - - -@pytest.mark.parametrize( - "data_name, ids_values", - [ - ("Model_Number", {1: 5678, 2: 5799, 3: 5678}), - ("Present_Position", {1: 1337, 2: 42, 3: 4016}), - ("Goal_Position", {1: 4008, 2: 199, 3: 3446}), - ], - ids=["Model_Number", "Present_Position", "Goal_Position"], -) -def test_sync_read_by_none(data_name, ids_values, dummy_motors): - bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(handshake=False) - addr, length = DUMMY_CTRL_TABLE_2[data_name] - ids = list(ids_values) - expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()} - - with ( - patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read, - patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign, - patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize, - ): - returned_dict = bus.sync_read(data_name) - - assert returned_dict == expected_values - mock__sync_read.assert_called_once_with( - addr, - length, - ids, - num_retry=0, - raise_on_error=True, - err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.", - ) - mock__decode_sign.assert_called_once_with(data_name, ids_values) - if data_name in bus.normalized_data: - mock__normalize.assert_called_once_with(ids_values) - - -@pytest.mark.parametrize( - "data_name, value", - [ - ("Goal_Position", 500), - ("Goal_Velocity", 4010), - ("Lock", 0), - ], -) -def test_sync_write_by_single_value(data_name, value, dummy_motors): - bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(handshake=False) - addr, length = DUMMY_CTRL_TABLE_2[data_name] - ids_values = {m.id: value for m in dummy_motors.values()} - - with ( - patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write, - patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign, - patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize, - ): - bus.sync_write(data_name, value) - - mock__sync_write.assert_called_once_with( - addr, - length, - ids_values, - num_retry=0, - raise_on_error=True, - err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.", - ) - mock__encode_sign.assert_called_once_with(data_name, ids_values) - if data_name in bus.normalized_data: - mock__unnormalize.assert_called_once_with(ids_values) - - -@pytest.mark.parametrize( - "data_name, ids_values", - [ - ("Goal_Position", {1: 1337, 2: 42, 3: 4016}), - ("Goal_Velocity", {1: 50, 2: 83, 3: 2777}), - ("Lock", {1: 0, 2: 0, 3: 1}), - ], - ids=["Goal_Position", "Goal_Velocity", "Lock"], -) -def test_sync_write_by_value_dict(data_name, ids_values, dummy_motors): - bus = MockMotorsBus("/dev/dummy-port", dummy_motors) - bus.connect(handshake=False) - addr, length = DUMMY_CTRL_TABLE_2[data_name] - values = {f"dummy_{id_}": val for id_, val in ids_values.items()} - - with ( - patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write, - patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign, - patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize, - ): - bus.sync_write(data_name, values) - - mock__sync_write.assert_called_once_with( - addr, - length, - ids_values, - num_retry=0, - raise_on_error=True, - err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.", - ) - mock__encode_sign.assert_called_once_with(data_name, ids_values) - if data_name in bus.normalized_data: - mock__unnormalize.assert_called_once_with(ids_values) diff --git a/tests/optim/test_optimizers.py b/tests/optim/test_optimizers.py deleted file mode 100644 index 4152c7f..0000000 --- a/tests/optim/test_optimizers.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytest -import torch - -from lerobot.constants import ( - OPTIMIZER_PARAM_GROUPS, - OPTIMIZER_STATE, -) -from lerobot.optim.optimizers import ( - AdamConfig, - AdamWConfig, - MultiAdamConfig, - SGDConfig, - load_optimizer_state, - save_optimizer_state, -) - - -@pytest.mark.parametrize( - "config_cls, expected_class", - [ - (AdamConfig, torch.optim.Adam), - (AdamWConfig, torch.optim.AdamW), - (SGDConfig, torch.optim.SGD), - (MultiAdamConfig, dict), - ], -) -def test_optimizer_build(config_cls, expected_class, model_params): - config = config_cls() - if config_cls == MultiAdamConfig: - params_dict = {"default": model_params} - optimizer = config.build(params_dict) - assert isinstance(optimizer, expected_class) - assert isinstance(optimizer["default"], torch.optim.Adam) - assert optimizer["default"].defaults["lr"] == config.lr - else: - optimizer = config.build(model_params) - assert isinstance(optimizer, expected_class) - assert optimizer.defaults["lr"] == config.lr - - -def test_save_optimizer_state(optimizer, tmp_path): - save_optimizer_state(optimizer, tmp_path) - assert (tmp_path / OPTIMIZER_STATE).is_file() - assert (tmp_path / OPTIMIZER_PARAM_GROUPS).is_file() - - -def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path): - save_optimizer_state(optimizer, tmp_path) - loaded_optimizer = AdamConfig().build(model_params) - loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path) - - torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict()) - - -@pytest.fixture -def base_params_dict(): - return { - "actor": [torch.nn.Parameter(torch.randn(10, 10))], - "critic": [torch.nn.Parameter(torch.randn(5, 5))], - "temperature": [torch.nn.Parameter(torch.randn(3, 3))], - } - - -@pytest.mark.parametrize( - "config_params, expected_values", - [ - # Test 1: Basic configuration with different learning rates - ( - { - "lr": 1e-3, - "weight_decay": 1e-4, - "optimizer_groups": { - "actor": {"lr": 1e-4}, - "critic": {"lr": 5e-4}, - "temperature": {"lr": 2e-3}, - }, - }, - { - "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)}, - "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)}, - "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)}, - }, - ), - # Test 2: Different weight decays and beta values - ( - { - "lr": 1e-3, - "weight_decay": 1e-4, - "optimizer_groups": { - "actor": {"lr": 1e-4, "weight_decay": 1e-5}, - "critic": {"lr": 5e-4, "weight_decay": 1e-6}, - "temperature": {"lr": 2e-3, "betas": (0.95, 0.999)}, - }, - }, - { - "actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)}, - "critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)}, - "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)}, - }, - ), - # Test 3: Epsilon parameter customization - ( - { - "lr": 1e-3, - "weight_decay": 1e-4, - "optimizer_groups": { - "actor": {"lr": 1e-4, "eps": 1e-6}, - "critic": {"lr": 5e-4, "eps": 1e-7}, - "temperature": {"lr": 2e-3, "eps": 1e-8}, - }, - }, - { - "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6}, - "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7}, - "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8}, - }, - ), - ], -) -def test_multi_adam_configuration(base_params_dict, config_params, expected_values): - # Create config with the given parameters - config = MultiAdamConfig(**config_params) - optimizers = config.build(base_params_dict) - - # Verify optimizer count and keys - assert len(optimizers) == len(expected_values) - assert set(optimizers.keys()) == set(expected_values.keys()) - - # Check that all optimizers are Adam instances - for opt in optimizers.values(): - assert isinstance(opt, torch.optim.Adam) - - # Verify hyperparameters for each optimizer - for name, expected in expected_values.items(): - optimizer = optimizers[name] - for param, value in expected.items(): - assert optimizer.defaults[param] == value - - -@pytest.fixture -def multi_optimizers(base_params_dict): - config = MultiAdamConfig( - lr=1e-3, - optimizer_groups={ - "actor": {"lr": 1e-4}, - "critic": {"lr": 5e-4}, - "temperature": {"lr": 2e-3}, - }, - ) - return config.build(base_params_dict) - - -def test_save_multi_optimizer_state(multi_optimizers, tmp_path): - # Save optimizer states - save_optimizer_state(multi_optimizers, tmp_path) - - # Verify that directories were created for each optimizer - for name in multi_optimizers: - assert (tmp_path / name).is_dir() - assert (tmp_path / name / OPTIMIZER_STATE).is_file() - assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file() - - -def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path): - # Option 1: Add a minimal backward pass to populate optimizer states - for name, params in base_params_dict.items(): - if name in multi_optimizers: - # Create a dummy loss and do backward - dummy_loss = params[0].sum() - dummy_loss.backward() - # Perform an optimization step - multi_optimizers[name].step() - # Zero gradients for next steps - multi_optimizers[name].zero_grad() - - # Save optimizer states - save_optimizer_state(multi_optimizers, tmp_path) - - # Create new optimizers with the same config - config = MultiAdamConfig( - lr=1e-3, - optimizer_groups={ - "actor": {"lr": 1e-4}, - "critic": {"lr": 5e-4}, - "temperature": {"lr": 2e-3}, - }, - ) - new_optimizers = config.build(base_params_dict) - - # Load optimizer states - loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path) - - # Verify state dictionaries match - for name in multi_optimizers: - torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict()) - - -def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path): - """Test saving and loading optimizer states even when the state is empty (no backward pass).""" - # Create config and build optimizers - config = MultiAdamConfig( - lr=1e-3, - optimizer_groups={ - "actor": {"lr": 1e-4}, - "critic": {"lr": 5e-4}, - "temperature": {"lr": 2e-3}, - }, - ) - optimizers = config.build(base_params_dict) - - # Save optimizer states without any backward pass (empty state) - save_optimizer_state(optimizers, tmp_path) - - # Create new optimizers with the same config - new_optimizers = config.build(base_params_dict) - - # Load optimizer states - loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path) - - # Verify hyperparameters match even with empty state - for name, optimizer in optimizers.items(): - assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"] - assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"] - assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"] - - # Verify state dictionaries match (they will be empty) - torch.testing.assert_close( - optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"] - ) diff --git a/tests/optim/test_schedulers.py b/tests/optim/test_schedulers.py deleted file mode 100644 index 43851c4..0000000 --- a/tests/optim/test_schedulers.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from torch.optim.lr_scheduler import LambdaLR - -from lerobot.constants import SCHEDULER_STATE -from lerobot.optim.schedulers import ( - CosineDecayWithWarmupSchedulerConfig, - DiffuserSchedulerConfig, - VQBeTSchedulerConfig, - load_scheduler_state, - save_scheduler_state, -) - - -def test_diffuser_scheduler(optimizer): - config = DiffuserSchedulerConfig(name="cosine", num_warmup_steps=5) - scheduler = config.build(optimizer, num_training_steps=100) - assert isinstance(scheduler, LambdaLR) - - optimizer.step() # so that we don't get torch warning - scheduler.step() - expected_state_dict = { - "_get_lr_called_within_step": False, - "_last_lr": [0.0002], - "_step_count": 2, - "base_lrs": [0.001], - "last_epoch": 1, - "lr_lambdas": [None], - } - assert scheduler.state_dict() == expected_state_dict - - -def test_vqbet_scheduler(optimizer): - config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5) - scheduler = config.build(optimizer, num_training_steps=100) - assert isinstance(scheduler, LambdaLR) - - optimizer.step() - scheduler.step() - expected_state_dict = { - "_get_lr_called_within_step": False, - "_last_lr": [0.001], - "_step_count": 2, - "base_lrs": [0.001], - "last_epoch": 1, - "lr_lambdas": [None], - } - assert scheduler.state_dict() == expected_state_dict - - -def test_cosine_decay_with_warmup_scheduler(optimizer): - config = CosineDecayWithWarmupSchedulerConfig( - num_warmup_steps=10, num_decay_steps=90, peak_lr=0.01, decay_lr=0.001 - ) - scheduler = config.build(optimizer, num_training_steps=100) - assert isinstance(scheduler, LambdaLR) - - optimizer.step() - scheduler.step() - expected_state_dict = { - "_get_lr_called_within_step": False, - "_last_lr": [0.0001818181818181819], - "_step_count": 2, - "base_lrs": [0.001], - "last_epoch": 1, - "lr_lambdas": [None], - } - assert scheduler.state_dict() == expected_state_dict - - -def test_save_scheduler_state(scheduler, tmp_path): - save_scheduler_state(scheduler, tmp_path) - assert (tmp_path / SCHEDULER_STATE).is_file() - - -def test_save_load_scheduler_state(scheduler, tmp_path): - save_scheduler_state(scheduler, tmp_path) - loaded_scheduler = load_scheduler_state(scheduler, tmp_path) - - assert scheduler.state_dict() == loaded_scheduler.state_dict() diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py deleted file mode 100644 index 0be1b9c..0000000 --- a/tests/policies/hilserl/test_modeling_classifier.py +++ /dev/null @@ -1,139 +0,0 @@ -# !/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig -from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput -from tests.utils import require_package - - -def test_classifier_output(): - output = ClassifierOutput( - logits=torch.tensor([1, 2, 3]), - probabilities=torch.tensor([0.1, 0.2, 0.3]), - hidden_states=None, - ) - - assert ( - f"{output}" - == "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)" - ) - - -@require_package("transformers") -def test_binary_classifier_with_default_params(): - from lerobot.policies.sac.reward_model.modeling_classifier import Classifier - - config = RewardClassifierConfig() - config.input_features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), - } - config.output_features = { - "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)), - } - config.normalization_mapping = { - "VISUAL": NormalizationMode.IDENTITY, - "REWARD": NormalizationMode.IDENTITY, - } - config.num_cameras = 1 - classifier = Classifier(config) - - batch_size = 10 - - input = { - "observation.image": torch.rand((batch_size, 3, 128, 128)), - "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(), - } - - images, labels = classifier.extract_images_and_labels(input) - assert len(images) == 1 - assert images[0].shape == torch.Size([batch_size, 3, 128, 128]) - assert labels.shape == torch.Size([batch_size]) - - output = classifier.predict(images) - - assert output is not None - assert output.logits.size() == torch.Size([batch_size]) - assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" - assert output.probabilities.shape == torch.Size([batch_size]) - assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" - assert output.hidden_states.shape == torch.Size([batch_size, 256]) - assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" - - -@require_package("transformers") -def test_multiclass_classifier(): - from lerobot.policies.sac.reward_model.modeling_classifier import Classifier - - num_classes = 5 - config = RewardClassifierConfig() - config.input_features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), - } - config.output_features = { - "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)), - } - config.num_cameras = 1 - config.num_classes = num_classes - classifier = Classifier(config) - - batch_size = 10 - - input = { - "observation.image": torch.rand((batch_size, 3, 128, 128)), - "next.reward": torch.rand((batch_size, num_classes)), - } - - images, labels = classifier.extract_images_and_labels(input) - assert len(images) == 1 - assert images[0].shape == torch.Size([batch_size, 3, 128, 128]) - assert labels.shape == torch.Size([batch_size, num_classes]) - - output = classifier.predict(images) - - assert output is not None - assert output.logits.shape == torch.Size([batch_size, num_classes]) - assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" - assert output.probabilities.shape == torch.Size([batch_size, num_classes]) - assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" - assert output.hidden_states.shape == torch.Size([batch_size, 256]) - assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" - - -@require_package("transformers") -def test_default_device(): - from lerobot.policies.sac.reward_model.modeling_classifier import Classifier - - config = RewardClassifierConfig() - assert config.device == "cpu" - - classifier = Classifier(config) - for p in classifier.parameters(): - assert p.device == torch.device("cpu") - - -@require_package("transformers") -def test_explicit_device_setup(): - from lerobot.policies.sac.reward_model.modeling_classifier import Classifier - - config = RewardClassifierConfig(device="cpu") - assert config.device == "cpu" - - classifier = Classifier(config) - for p in classifier.parameters(): - assert p.device == torch.device("cpu") diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py deleted file mode 100644 index da7573d..0000000 --- a/tests/policies/test_policies.py +++ /dev/null @@ -1,546 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import inspect -from copy import deepcopy -from pathlib import Path - -import einops -import pytest -import torch -from packaging import version -from safetensors.torch import load_file - -from lerobot import available_policies -from lerobot.configs.default import DatasetConfig -from lerobot.configs.train import TrainPipelineConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.constants import ACTION, OBS_STATE -from lerobot.datasets.factory import make_dataset -from lerobot.datasets.utils import cycle, dataset_to_policy_features -from lerobot.envs.factory import make_env, make_env_config -from lerobot.envs.utils import preprocess_observation -from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.act.configuration_act import ACTConfig -from lerobot.policies.act.modeling_act import ACTTemporalEnsembler -from lerobot.policies.factory import ( - get_policy_class, - make_policy, - make_policy_config, -) -from lerobot.policies.normalize import Normalize, Unnormalize -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.random_utils import seeded_context -from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats -from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel - - -@pytest.fixture -def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path): - # Create only one camera input which is squared to fit all current policy constraints - # e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared - camera_features = { - "observation.images.laptop": { - "shape": (84, 84, 3), - "names": ["height", "width", "channels"], - "info": None, - }, - } - motor_features = { - "action": { - "dtype": "float32", - "shape": (6,), - "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], - }, - "observation.state": { - "dtype": "float32", - "shape": (6,), - "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], - }, - } - info = info_factory( - total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features - ) - ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info) - return ds_meta - - -@pytest.mark.parametrize("policy_name", available_policies) -def test_get_policy_and_config_classes(policy_name: str): - """Check that the correct policy and config classes are returned.""" - policy_cls = get_policy_class(policy_name) - policy_cfg = make_policy_config(policy_name) - assert policy_cls.name == policy_name - assert issubclass( - policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation - ) - - -@pytest.mark.parametrize( - "ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs", - [ - ("lerobot/xarm_lift_medium", "xarm", {}, "tdmpc", {"use_mpc": True}), - ("lerobot/pusht", "pusht", {}, "diffusion", {}), - ("lerobot/pusht", "pusht", {}, "vqbet", {}), - ("lerobot/pusht", "pusht", {}, "act", {}), - ("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}), - ( - "lerobot/aloha_sim_insertion_scripted", - "aloha", - {"task": "AlohaInsertion-v0"}, - "act", - {}, - ), - ( - "lerobot/aloha_sim_insertion_human", - "aloha", - {"task": "AlohaInsertion-v0"}, - "diffusion", - {}, - ), - ( - "lerobot/aloha_sim_transfer_cube_human", - "aloha", - {"task": "AlohaTransferCube-v0"}, - "act", - {}, - ), - ( - "lerobot/aloha_sim_transfer_cube_scripted", - "aloha", - {"task": "AlohaTransferCube-v0"}, - "act", - {}, - ), - ], -) -@require_env -def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): - """ - Tests: - - Making the policy object. - - Checking that the policy follows the correct protocol and subclasses nn.Module - and PyTorchModelHubMixin. - - Updating the policy. - - Using the policy to select actions at inference time. - - Test the action can be applied to the policy - - Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive, - and for now we add tests as we see fit. - """ - - train_cfg = TrainPipelineConfig( - # TODO(rcadene, aliberts): remove dataset download - dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]), - policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs), - env=make_env_config(env_name, **env_kwargs), - ) - train_cfg.validate() - - # Check that we can make the policy object. - dataset = make_dataset(train_cfg) - policy = make_policy(train_cfg.policy, ds_meta=dataset.meta) - assert isinstance(policy, PreTrainedPolicy) - - # Check that we run select_actions and get the appropriate output. - env = make_env(train_cfg.env, n_envs=2) - - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=0, - batch_size=2, - shuffle=True, - pin_memory=DEVICE != "cpu", - drop_last=True, - ) - dl_iter = cycle(dataloader) - - batch = next(dl_iter) - - for key in batch: - if isinstance(batch[key], torch.Tensor): - batch[key] = batch[key].to(DEVICE, non_blocking=True) - - # Test updating the policy (and test that it does not mutate the batch) - batch_ = deepcopy(batch) - policy.forward(batch) - assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass." - assert all( - torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k] - for k in batch - ), "Batch values are not the same after a forward pass." - - # reset the policy and environment - policy.reset() - observation, _ = env.reset(seed=train_cfg.seed) - - # apply transform to normalize the observations - observation = preprocess_observation(observation) - - # send observation to device/gpu - observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation} - - # get the next action for the environment (also check that the observation batch is not modified) - observation_ = deepcopy(observation) - with torch.inference_mode(): - action = policy.select_action(observation).cpu().numpy() - assert set(observation) == set(observation_), ( - "Observation batch keys are not the same after a forward pass." - ) - assert all(torch.equal(observation[k], observation_[k]) for k in observation), ( - "Observation batch values are not the same after a forward pass." - ) - - # Test step through policy - env.step(action) - - -# TODO(rcadene, aliberts): This test is quite end-to-end. Move this test in test_optimizer? -def test_act_backbone_lr(): - """ - Test that the ACT policy can be instantiated with a different learning rate for the backbone. - """ - - cfg = TrainPipelineConfig( - # TODO(rcadene, aliberts): remove dataset download - dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]), - policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False), - ) - cfg.validate() # Needed for auto-setting some parameters - - assert cfg.policy.optimizer_lr == 0.01 - assert cfg.policy.optimizer_lr_backbone == 0.001 - - dataset = make_dataset(cfg) - policy = make_policy(cfg.policy, ds_meta=dataset.meta) - optimizer, _ = make_optimizer_and_scheduler(cfg, policy) - assert len(optimizer.param_groups) == 2 - assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr - assert optimizer.param_groups[1]["lr"] == cfg.policy.optimizer_lr_backbone - assert len(optimizer.param_groups[0]["params"]) == 133 - assert len(optimizer.param_groups[1]["params"]) == 20 - - -@pytest.mark.parametrize("policy_name", available_policies) -def test_policy_defaults(dummy_dataset_metadata, policy_name: str): - """Check that the policy can be instantiated with defaults.""" - policy_cls = get_policy_class(policy_name) - policy_cfg = make_policy_config(policy_name) - features = dataset_to_policy_features(dummy_dataset_metadata.features) - policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} - policy_cfg.input_features = { - key: ft for key, ft in features.items() if key not in policy_cfg.output_features - } - policy_cls(policy_cfg) - - -@pytest.mark.parametrize("policy_name", available_policies) -def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str): - policy_cls = get_policy_class(policy_name) - policy_cfg = make_policy_config(policy_name) - features = dataset_to_policy_features(dummy_dataset_metadata.features) - policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} - policy_cfg.input_features = { - key: ft for key, ft in features.items() if key not in policy_cfg.output_features - } - policy = policy_cls(policy_cfg) - policy.to(policy_cfg.device) - save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}" - policy.save_pretrained(save_dir) - loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg) - torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0) - - -@pytest.mark.parametrize("insert_temporal_dim", [False, True]) -def test_normalize(insert_temporal_dim): - """ - Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise - an exception when the forward pass is called without the stats having been provided. - - TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as - expected. - """ - - input_features = { - "observation.image": PolicyFeature( - type=FeatureType.VISUAL, - shape=(3, 96, 96), - ), - "observation.state": PolicyFeature( - type=FeatureType.STATE, - shape=(10,), - ), - } - output_features = { - "action": PolicyFeature( - type=FeatureType.ACTION, - shape=(5,), - ), - } - - norm_map = { - "VISUAL": NormalizationMode.MEAN_STD, - "STATE": NormalizationMode.MIN_MAX, - "ACTION": NormalizationMode.MIN_MAX, - } - - dataset_stats = { - "observation.image": { - "mean": torch.randn(3, 1, 1), - "std": torch.randn(3, 1, 1), - "min": torch.randn(3, 1, 1), - "max": torch.randn(3, 1, 1), - }, - "observation.state": { - "mean": torch.randn(10), - "std": torch.randn(10), - "min": torch.randn(10), - "max": torch.randn(10), - }, - "action": { - "mean": torch.randn(5), - "std": torch.randn(5), - "min": torch.randn(5), - "max": torch.randn(5), - }, - } - - bsize = 2 - input_batch = { - "observation.image": torch.randn(bsize, 3, 96, 96), - "observation.state": torch.randn(bsize, 10), - } - output_batch = { - "action": torch.randn(bsize, 5), - } - - if insert_temporal_dim: - tdim = 4 - - for key in input_batch: - # [2,3,96,96] -> [2,tdim,3,96,96] - input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1) - - for key in output_batch: - output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1) - - # test without stats - normalize = Normalize(input_features, norm_map, stats=None) - with pytest.raises(AssertionError): - normalize(input_batch) - - # test with stats - normalize = Normalize(input_features, norm_map, stats=dataset_stats) - normalize(input_batch) - - # test loading pretrained models - new_normalize = Normalize(input_features, norm_map, stats=None) - new_normalize.load_state_dict(normalize.state_dict()) - new_normalize(input_batch) - - # test without stats - unnormalize = Unnormalize(output_features, norm_map, stats=None) - with pytest.raises(AssertionError): - unnormalize(output_batch) - - # test with stats - unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats) - unnormalize(output_batch) - - # test loading pretrained models - new_unnormalize = Unnormalize(output_features, norm_map, stats=None) - new_unnormalize.load_state_dict(unnormalize.state_dict()) - unnormalize(output_batch) - - -@pytest.mark.parametrize("multikey", [True, False]) -def test_multikey_construction(multikey: bool): - """ - Asserts that multiple keys with type State/Action are correctly processed by the policy constructor, - preventing erroneous creation of the policy object. - """ - input_features = { - "observation.state": PolicyFeature( - type=FeatureType.STATE, - shape=(10,), - ), - } - output_features = { - "action": PolicyFeature( - type=FeatureType.ACTION, - shape=(5,), - ), - } - - if multikey: - """Simulates the complete state/action is constructed from more granular multiple - keys, of the same type as the overall state/action""" - input_features = {} - input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) - input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) - input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,)) - - output_features = {} - output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,)) - output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,)) - output_features["action"] = PolicyFeature( - type=FeatureType.ACTION, - shape=(5,), - ) - - config = ACTConfig(input_features=input_features, output_features=output_features) - - state_condition = config.robot_state_feature == input_features[OBS_STATE] - action_condition = config.action_feature == output_features[ACTION] - - assert state_condition, ( - f"Discrepancy detected. Robot state feature is {config.robot_state_feature} but policy expects {input_features[OBS_STATE]}" - ) - assert action_condition, ( - f"Discrepancy detected. Action feature is {config.action_feature} but policy expects {output_features[ACTION]}" - ) - - -@pytest.mark.parametrize( - "ds_repo_id, policy_name, policy_kwargs, file_name_extra", - [ - # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it - # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override - # to test with `policy.use_mpc=false`. - ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"), - # ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"), - # TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to - # to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference - # that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass. - # Thus, we deactivate this test for now. - ( - "lerobot/pusht", - "diffusion", - { - "n_action_steps": 8, - "num_inference_steps": 10, - "down_dims": [128, 256, 512], - }, - "", - ), - ("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""), - ( - "lerobot/aloha_sim_insertion_human", - "act", - {"n_action_steps": 1000, "chunk_size": 1000}, - "1000_steps", - ), - ], -) -# As artifacts have been generated on an x86_64 kernel, this test won't -# pass if it's run on another platform due to floating point errors -@require_x86_64_kernel -@require_cpu -def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str): - """ - NOTE: If this test does not pass, and you have intentionally changed something in the policy: - 1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should - include a report on what changed and how that affected the outputs. - 2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and - add the policies you want to update the test artifacts for. - 3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact - should be updated. - 4. Check that this test now passes. - 5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state. - 6. Remember to stage and commit the resulting changes to `tests/artifacts`. - - NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact - is out of date. For example, some PyTorch versions have different randomness, see this PR: - https://github.com/huggingface/lerobot/pull/1127. - - """ - - # NOTE: ACT policy has different randomness, after PyTorch 2.7.0 - if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"): - pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0") - - ds_name = ds_repo_id.split("/")[-1] - artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}" - saved_output_dict = load_file(artifact_dir / "output_dict.safetensors") - saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors") - saved_param_stats = load_file(artifact_dir / "param_stats.safetensors") - saved_actions = load_file(artifact_dir / "actions.safetensors") - - output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs) - - for key in saved_output_dict: - torch.testing.assert_close(output_dict[key], saved_output_dict[key]) - for key in saved_grad_stats: - torch.testing.assert_close(grad_stats[key], saved_grad_stats[key]) - for key in saved_param_stats: - torch.testing.assert_close(param_stats[key], saved_param_stats[key]) - for key in saved_actions: - rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK - torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol) - - -def test_act_temporal_ensembler(): - """Check that the online method in ACTTemporalEnsembler matches a simple offline calculation.""" - temporal_ensemble_coeff = 0.01 - chunk_size = 100 - episode_length = 101 - ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size) - # An batch of arbitrary sequences of 1D actions we wish to compute the average over. We'll keep the - # "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen. - with seeded_context(0): - # Dimension is (batch, episode_length, chunk_size, action_dim(=1)) - # Stepping through the episode_length dim is like running inference at each rollout step and getting - # a different action chunk. - batch_seq = torch.stack( - [ - torch.rand(episode_length, chunk_size) * 0.05 - 0.6, - torch.rand(episode_length, chunk_size) * 0.02 - 0.01, - torch.rand(episode_length, chunk_size) * 0.2 + 0.3, - ], - dim=0, - ).unsqueeze(-1) # unsqueeze for action dim - batch_size = batch_seq.shape[0] - # Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length` - # dimension of `batch_seq`. - weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1) - - # Simulate stepping through a rollout and computing a batch of actions with model on each step. - for i in range(episode_length): - # Mock a batch of actions. - actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i] - online_avg = ensembler.update(actions) - # Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ). - # Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid. - # What we want to do is take diagonal slices across it starting from the left. - # eg: chunk_size=4, episode_length=6 - # ┌───────┐ - # │0 1 2 3│ - # │1 2 3 4│ - # │2 3 4 5│ - # │3 4 5 6│ - # │4 5 6 7│ - # │5 6 7 8│ - # └───────┘ - chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1) - episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :] - seq_slice = batch_seq[:, episode_step_indices, chunk_indices] - offline_avg = ( - einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum() - ) - # Sanity check. The average should be between the extrema. - assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg) - assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max")) - # Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error. - torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4) diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py deleted file mode 100644 index a67815e..0000000 --- a/tests/policies/test_sac_config.py +++ /dev/null @@ -1,217 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.policies.sac.configuration_sac import ( - ActorLearnerConfig, - ActorNetworkConfig, - ConcurrencyConfig, - CriticNetworkConfig, - PolicyConfig, - SACConfig, -) - - -def test_sac_config_default_initialization(): - config = SACConfig() - - assert config.normalization_mapping == { - "VISUAL": NormalizationMode.MEAN_STD, - "STATE": NormalizationMode.MIN_MAX, - "ENV": NormalizationMode.MIN_MAX, - "ACTION": NormalizationMode.MIN_MAX, - } - assert config.dataset_stats == { - "observation.image": { - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - }, - "observation.state": { - "min": [0.0, 0.0], - "max": [1.0, 1.0], - }, - "action": { - "min": [0.0, 0.0, 0.0], - "max": [1.0, 1.0, 1.0], - }, - } - - # Basic parameters - assert config.device == "cpu" - assert config.storage_device == "cpu" - assert config.discount == 0.99 - assert config.temperature_init == 1.0 - assert config.num_critics == 2 - - # Architecture specifics - assert config.vision_encoder_name is None - assert config.freeze_vision_encoder is True - assert config.image_encoder_hidden_dim == 32 - assert config.shared_encoder is True - assert config.num_discrete_actions is None - assert config.image_embedding_pooling_dim == 8 - - # Training parameters - assert config.online_steps == 1000000 - assert config.online_env_seed == 10000 - assert config.online_buffer_capacity == 100000 - assert config.offline_buffer_capacity == 100000 - assert config.async_prefetch is False - assert config.online_step_before_learning == 100 - assert config.policy_update_freq == 1 - - # SAC algorithm parameters - assert config.num_subsample_critics is None - assert config.critic_lr == 3e-4 - assert config.actor_lr == 3e-4 - assert config.temperature_lr == 3e-4 - assert config.critic_target_update_weight == 0.005 - assert config.utd_ratio == 1 - assert config.state_encoder_hidden_dim == 256 - assert config.latent_dim == 256 - assert config.target_entropy is None - assert config.use_backup_entropy is True - assert config.grad_clip_norm == 40.0 - - # Dataset stats defaults - expected_dataset_stats = { - "observation.image": { - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - }, - "observation.state": { - "min": [0.0, 0.0], - "max": [1.0, 1.0], - }, - "action": { - "min": [0.0, 0.0, 0.0], - "max": [1.0, 1.0, 1.0], - }, - } - assert config.dataset_stats == expected_dataset_stats - - # Critic network configuration - assert config.critic_network_kwargs.hidden_dims == [256, 256] - assert config.critic_network_kwargs.activate_final is True - assert config.critic_network_kwargs.final_activation is None - - # Actor network configuration - assert config.actor_network_kwargs.hidden_dims == [256, 256] - assert config.actor_network_kwargs.activate_final is True - - # Policy configuration - assert config.policy_kwargs.use_tanh_squash is True - assert config.policy_kwargs.std_min == 1e-5 - assert config.policy_kwargs.std_max == 10.0 - assert config.policy_kwargs.init_final == 0.05 - - # Discrete critic network configuration - assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256] - assert config.discrete_critic_network_kwargs.activate_final is True - assert config.discrete_critic_network_kwargs.final_activation is None - - # Actor learner configuration - assert config.actor_learner_config.learner_host == "127.0.0.1" - assert config.actor_learner_config.learner_port == 50051 - assert config.actor_learner_config.policy_parameters_push_frequency == 4 - - # Concurrency configuration - assert config.concurrency.actor == "threads" - assert config.concurrency.learner == "threads" - - assert isinstance(config.actor_network_kwargs, ActorNetworkConfig) - assert isinstance(config.critic_network_kwargs, CriticNetworkConfig) - assert isinstance(config.policy_kwargs, PolicyConfig) - assert isinstance(config.actor_learner_config, ActorLearnerConfig) - assert isinstance(config.concurrency, ConcurrencyConfig) - - -def test_critic_network_kwargs(): - config = CriticNetworkConfig() - assert config.hidden_dims == [256, 256] - assert config.activate_final is True - assert config.final_activation is None - - -def test_actor_network_kwargs(): - config = ActorNetworkConfig() - assert config.hidden_dims == [256, 256] - assert config.activate_final is True - - -def test_policy_kwargs(): - config = PolicyConfig() - assert config.use_tanh_squash is True - assert config.std_min == 1e-5 - assert config.std_max == 10.0 - assert config.init_final == 0.05 - - -def test_actor_learner_config(): - config = ActorLearnerConfig() - assert config.learner_host == "127.0.0.1" - assert config.learner_port == 50051 - assert config.policy_parameters_push_frequency == 4 - - -def test_concurrency_config(): - config = ConcurrencyConfig() - assert config.actor == "threads" - assert config.learner == "threads" - - -def test_sac_config_custom_initialization(): - config = SACConfig( - device="cpu", - discount=0.95, - temperature_init=0.5, - num_critics=3, - ) - - assert config.device == "cpu" - assert config.discount == 0.95 - assert config.temperature_init == 0.5 - assert config.num_critics == 3 - - -def test_validate_features(): - config = SACConfig( - input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, - output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, - ) - config.validate_features() - - -def test_validate_features_missing_observation(): - config = SACConfig( - input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, - output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, - ) - with pytest.raises( - ValueError, match="You must provide either 'observation.state' or an image observation" - ): - config.validate_features() - - -def test_validate_features_missing_action(): - config = SACConfig( - input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, - output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, - ) - with pytest.raises(ValueError, match="You must provide 'action' in the output features"): - config.validate_features() diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py deleted file mode 100644 index 7891c2e..0000000 --- a/tests/policies/test_sac_policy.py +++ /dev/null @@ -1,541 +0,0 @@ -# !/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -import pytest -import torch -from torch import Tensor, nn - -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.policies.sac.configuration_sac import SACConfig -from lerobot.policies.sac.modeling_sac import MLP, SACPolicy -from lerobot.utils.random_utils import seeded_context, set_seed - -try: - import transformers # noqa: F401 - - TRANSFORMERS_AVAILABLE = True -except ImportError: - TRANSFORMERS_AVAILABLE = False - - -@pytest.fixture(autouse=True) -def set_random_seed(): - seed = 42 - set_seed(seed) - - -def test_mlp_with_default_args(): - mlp = MLP(input_dim=10, hidden_dims=[256, 256]) - - x = torch.randn(10) - y = mlp(x) - assert y.shape == (256,) - - -def test_mlp_with_batch_dim(): - mlp = MLP(input_dim=10, hidden_dims=[256, 256]) - x = torch.randn(2, 10) - y = mlp(x) - assert y.shape == (2, 256) - - -def test_forward_with_empty_hidden_dims(): - mlp = MLP(input_dim=10, hidden_dims=[]) - x = torch.randn(1, 10) - assert mlp(x).shape == (1, 10) - - -def test_mlp_with_dropout(): - mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1) - x = torch.randn(1, 10) - y = mlp(x) - assert y.shape == (1, 11) - - drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net) - assert drop_out_layers_count == 2 - - -def test_mlp_with_custom_final_activation(): - mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh()) - x = torch.randn(1, 10) - y = mlp(x) - assert y.shape == (1, 256) - assert (y >= -1).all() and (y <= 1).all() - - -def test_sac_policy_with_default_args(): - with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"): - SACPolicy() - - -def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor: - return { - "observation.state": torch.randn(batch_size, state_dim), - } - - -def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor: - return { - "observation.image": torch.randn(batch_size, 3, 84, 84), - "observation.state": torch.randn(batch_size, state_dim), - } - - -def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor: - return torch.randn(batch_size, action_dim) - - -def create_default_train_batch( - batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 -) -> dict[str, Tensor]: - return { - "action": create_dummy_action(batch_size, action_dim), - "reward": torch.randn(batch_size), - "state": create_dummy_state(batch_size, state_dim), - "next_state": create_dummy_state(batch_size, state_dim), - "done": torch.randn(batch_size), - } - - -def create_train_batch_with_visual_input( - batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 -) -> dict[str, Tensor]: - return { - "action": create_dummy_action(batch_size, action_dim), - "reward": torch.randn(batch_size), - "state": create_dummy_with_visual_input(batch_size, state_dim), - "next_state": create_dummy_with_visual_input(batch_size, state_dim), - "done": torch.randn(batch_size), - } - - -def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: - return { - "observation.state": torch.randn(batch_size, state_dim), - } - - -def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: - return { - "observation.state": torch.randn(batch_size, state_dim), - "observation.image": torch.randn(batch_size, 3, 84, 84), - } - - -def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]: - """Create optimizers for the SAC policy.""" - optimizer_actor = torch.optim.Adam( - # Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient - params=[ - p - for n, p in policy.actor.named_parameters() - if not policy.config.shared_encoder or not n.startswith("encoder") - ], - lr=policy.config.actor_lr, - ) - optimizer_critic = torch.optim.Adam( - params=policy.critic_ensemble.parameters(), - lr=policy.config.critic_lr, - ) - optimizer_temperature = torch.optim.Adam( - params=[policy.log_alpha], - lr=policy.config.critic_lr, - ) - - optimizers = { - "actor": optimizer_actor, - "critic": optimizer_critic, - "temperature": optimizer_temperature, - } - - if has_discrete_action: - optimizers["discrete_critic"] = torch.optim.Adam( - params=policy.discrete_critic.parameters(), - lr=policy.config.critic_lr, - ) - - return optimizers - - -def create_default_config( - state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False -) -> SACConfig: - action_dim = continuous_action_dim - if has_discrete_action: - action_dim += 1 - - config = SACConfig( - input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, - output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, - dataset_stats={ - "observation.state": { - "min": [0.0] * state_dim, - "max": [1.0] * state_dim, - }, - "action": { - "min": [0.0] * continuous_action_dim, - "max": [1.0] * continuous_action_dim, - }, - }, - ) - config.validate_features() - return config - - -def create_config_with_visual_input( - state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False -) -> SACConfig: - config = create_default_config( - state_dim=state_dim, - continuous_action_dim=continuous_action_dim, - has_discrete_action=has_discrete_action, - ) - config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) - config.dataset_stats["observation.image"] = { - "mean": torch.randn(3, 1, 1), - "std": torch.randn(3, 1, 1), - } - - # Let make tests a little bit faster - config.state_encoder_hidden_dim = 32 - config.latent_dim = 32 - - config.validate_features() - return config - - -@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) -def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int): - batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) - config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) - - policy = SACPolicy(config=config) - policy.train() - - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None - assert actor_loss.shape == () - - actor_loss.backward() - optimizers["actor"].step() - - temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] - assert temperature_loss.item() is not None - assert temperature_loss.shape == () - - temperature_loss.backward() - optimizers["temperature"].step() - - policy.eval() - with torch.no_grad(): - observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) - selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, action_dim) - - -@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) -def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int): - config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) - policy = SACPolicy(config=config) - - batch = create_train_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim, action_dim=action_dim - ) - - policy.train() - - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None - assert actor_loss.shape == () - - actor_loss.backward() - optimizers["actor"].step() - - temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] - assert temperature_loss.item() is not None - assert temperature_loss.shape == () - - temperature_loss.backward() - optimizers["temperature"].step() - - policy.eval() - with torch.no_grad(): - observation_batch = create_observation_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim - ) - selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, action_dim) - - -# Let's check best candidates for pretrained encoders -@pytest.mark.parametrize( - "batch_size,state_dim,action_dim,vision_encoder_name", - [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], -) -@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed") -def test_sac_policy_with_pretrained_encoder( - batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str -): - config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) - config.vision_encoder_name = vision_encoder_name - policy = SACPolicy(config=config) - policy.train() - - batch = create_train_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim, action_dim=action_dim - ) - - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None - assert actor_loss.shape == () - - -def test_sac_policy_with_shared_encoder(): - batch_size = 2 - action_dim = 10 - state_dim = 10 - config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) - config.shared_encoder = True - - policy = SACPolicy(config=config) - policy.train() - - batch = create_train_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim, action_dim=action_dim - ) - - policy.train() - - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None - assert actor_loss.shape == () - - actor_loss.backward() - optimizers["actor"].step() - - -def test_sac_policy_with_discrete_critic(): - batch_size = 2 - continuous_action_dim = 9 - full_action_dim = continuous_action_dim + 1 # the last action is discrete - state_dim = 10 - config = create_config_with_visual_input( - state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True - ) - - num_discrete_actions = 5 - config.num_discrete_actions = num_discrete_actions - - policy = SACPolicy(config=config) - policy.train() - - batch = create_train_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim - ) - - policy.train() - - optimizers = make_optimizers(policy, has_discrete_action=True) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"] - assert discrete_critic_loss.item() is not None - assert discrete_critic_loss.shape == () - discrete_critic_loss.backward() - optimizers["discrete_critic"].step() - - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - assert actor_loss.item() is not None - assert actor_loss.shape == () - - actor_loss.backward() - optimizers["actor"].step() - - policy.eval() - with torch.no_grad(): - observation_batch = create_observation_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim - ) - selected_action = policy.select_action(observation_batch) - assert selected_action.shape == (batch_size, full_action_dim) - - discrete_actions = selected_action[:, -1].long() - discrete_action_values = set(discrete_actions.tolist()) - - assert all(action in range(num_discrete_actions) for action in discrete_action_values), ( - f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})" - ) - - -def test_sac_policy_with_default_entropy(): - config = create_default_config(continuous_action_dim=10, state_dim=10) - policy = SACPolicy(config=config) - assert policy.target_entropy == -5.0 - - -def test_sac_policy_default_target_entropy_with_discrete_action(): - config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True) - policy = SACPolicy(config=config) - assert policy.target_entropy == -3.0 - - -def test_sac_policy_with_predefined_entropy(): - config = create_default_config(state_dim=10, continuous_action_dim=6) - config.target_entropy = -3.5 - - policy = SACPolicy(config=config) - assert policy.target_entropy == pytest.approx(-3.5) - - -def test_sac_policy_update_temperature(): - config = create_default_config(continuous_action_dim=10, state_dim=10) - policy = SACPolicy(config=config) - - assert policy.temperature == pytest.approx(1.0) - policy.log_alpha.data = torch.tensor([math.log(0.1)]) - policy.update_temperature() - assert policy.temperature == pytest.approx(0.1) - - -def test_sac_policy_update_target_network(): - config = create_default_config(state_dim=10, continuous_action_dim=6) - config.critic_target_update_weight = 1.0 - - policy = SACPolicy(config=config) - policy.train() - - for p in policy.critic_ensemble.parameters(): - p.data = torch.ones_like(p.data) - - policy.update_target_networks() - for p in policy.critic_target.parameters(): - assert torch.allclose(p.data, torch.ones_like(p.data)), ( - f"Target network {p.data} is not equal to {torch.ones_like(p.data)}" - ) - - -@pytest.mark.parametrize("num_critics", [1, 3]) -def test_sac_policy_with_critics_number_of_heads(num_critics: int): - batch_size = 2 - action_dim = 10 - state_dim = 10 - config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) - config.num_critics = num_critics - - policy = SACPolicy(config=config) - policy.train() - - assert len(policy.critic_ensemble.critics) == num_critics - - batch = create_train_batch_with_visual_input( - batch_size=batch_size, state_dim=state_dim, action_dim=action_dim - ) - - policy.train() - - optimizers = make_optimizers(policy) - - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - assert cirtic_loss.item() is not None - assert cirtic_loss.shape == () - cirtic_loss.backward() - optimizers["critic"].step() - - -def test_sac_policy_save_and_load(tmp_path): - root = tmp_path / "test_sac_save_and_load" - - state_dim = 10 - action_dim = 10 - batch_size = 2 - - config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) - policy = SACPolicy(config=config) - policy.eval() - policy.save_pretrained(root) - loaded_policy = SACPolicy.from_pretrained(root, config=config) - loaded_policy.eval() - - batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10) - - with torch.no_grad(): - with seeded_context(12): - # Collect policy values before saving - cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] - actor_loss = policy.forward(batch, model="actor")["loss_actor"] - temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] - - observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) - actions = policy.select_action(observation_batch) - - with seeded_context(12): - # Collect policy values after loading - loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"] - loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"] - loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"] - - loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) - loaded_actions = loaded_policy.select_action(loaded_observation_batch) - - assert policy.state_dict().keys() == loaded_policy.state_dict().keys() - for k in policy.state_dict(): - assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) - - # Compare values before and after saving and loading - # They should be the same - assert torch.allclose(cirtic_loss, loaded_cirtic_loss) - assert torch.allclose(actor_loss, loaded_actor_loss) - assert torch.allclose(temperature_loss, loaded_temperature_loss) - assert torch.allclose(actions, loaded_actions) diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py deleted file mode 100644 index 6389402..0000000 --- a/tests/processor/test_batch_conversion.py +++ /dev/null @@ -1,282 +0,0 @@ -import torch - -from lerobot.processor.pipeline import ( - RobotProcessor, - TransitionKey, - _default_batch_to_transition, - _default_transition_to_batch, -) - - -def _dummy_batch(): - """Create a dummy batch using the new format with observation.* and next.* keys.""" - return { - "observation.image.left": torch.randn(1, 3, 128, 128), - "observation.image.right": torch.randn(1, 3, 128, 128), - "observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]), - "action": torch.tensor([[0.5]]), - "next.reward": 1.0, - "next.done": False, - "next.truncated": False, - "info": {"key": "value"}, - } - - -def test_observation_grouping_roundtrip(): - """Test that observation.* keys are properly grouped and ungrouped.""" - proc = RobotProcessor([]) - batch_in = _dummy_batch() - batch_out = proc(batch_in) - - # Check that all observation.* keys are preserved - original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")} - reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")} - - assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys()) - - # Check tensor values - assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"]) - assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"]) - assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"]) - - # Check other fields - assert torch.allclose(batch_out["action"], batch_in["action"]) - assert batch_out["next.reward"] == batch_in["next.reward"] - assert batch_out["next.done"] == batch_in["next.done"] - assert batch_out["next.truncated"] == batch_in["next.truncated"] - assert batch_out["info"] == batch_in["info"] - - -def test_batch_to_transition_observation_grouping(): - """Test that _default_batch_to_transition correctly groups observation.* keys.""" - batch = { - "observation.image.top": torch.randn(1, 3, 128, 128), - "observation.image.left": torch.randn(1, 3, 128, 128), - "observation.state": [1, 2, 3, 4], - "action": "action_data", - "next.reward": 1.5, - "next.done": True, - "next.truncated": False, - "info": {"episode": 42}, - } - - transition = _default_batch_to_transition(batch) - - # Check observation is a dict with all observation.* keys - assert isinstance(transition[TransitionKey.OBSERVATION], dict) - assert "observation.image.top" in transition[TransitionKey.OBSERVATION] - assert "observation.image.left" in transition[TransitionKey.OBSERVATION] - assert "observation.state" in transition[TransitionKey.OBSERVATION] - - # Check values are preserved - assert torch.allclose( - transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"] - ) - assert torch.allclose( - transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"] - ) - assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] - - # Check other fields - assert transition[TransitionKey.ACTION] == "action_data" - assert transition[TransitionKey.REWARD] == 1.5 - assert transition[TransitionKey.DONE] - assert not transition[TransitionKey.TRUNCATED] - assert transition[TransitionKey.INFO] == {"episode": 42} - assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} - - -def test_transition_to_batch_observation_flattening(): - """Test that _default_transition_to_batch correctly flattens observation dict.""" - observation_dict = { - "observation.image.top": torch.randn(1, 3, 128, 128), - "observation.image.left": torch.randn(1, 3, 128, 128), - "observation.state": [1, 2, 3, 4], - } - - transition = { - TransitionKey.OBSERVATION: observation_dict, - TransitionKey.ACTION: "action_data", - TransitionKey.REWARD: 1.5, - TransitionKey.DONE: True, - TransitionKey.TRUNCATED: False, - TransitionKey.INFO: {"episode": 42}, - TransitionKey.COMPLEMENTARY_DATA: {}, - } - - batch = _default_transition_to_batch(transition) - - # Check that observation.* keys are flattened back to batch - assert "observation.image.top" in batch - assert "observation.image.left" in batch - assert "observation.state" in batch - - # Check values are preserved - assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"]) - assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"]) - assert batch["observation.state"] == [1, 2, 3, 4] - - # Check other fields are mapped to next.* format - assert batch["action"] == "action_data" - assert batch["next.reward"] == 1.5 - assert batch["next.done"] - assert not batch["next.truncated"] - assert batch["info"] == {"episode": 42} - - -def test_no_observation_keys(): - """Test behavior when there are no observation.* keys.""" - batch = { - "action": "action_data", - "next.reward": 2.0, - "next.done": False, - "next.truncated": True, - "info": {"test": "no_obs"}, - } - - transition = _default_batch_to_transition(batch) - - # Observation should be None when no observation.* keys - assert transition[TransitionKey.OBSERVATION] is None - - # Check other fields - assert transition[TransitionKey.ACTION] == "action_data" - assert transition[TransitionKey.REWARD] == 2.0 - assert not transition[TransitionKey.DONE] - assert transition[TransitionKey.TRUNCATED] - assert transition[TransitionKey.INFO] == {"test": "no_obs"} - - # Round trip should work - reconstructed_batch = _default_transition_to_batch(transition) - assert reconstructed_batch["action"] == "action_data" - assert reconstructed_batch["next.reward"] == 2.0 - assert not reconstructed_batch["next.done"] - assert reconstructed_batch["next.truncated"] - assert reconstructed_batch["info"] == {"test": "no_obs"} - - -def test_minimal_batch(): - """Test with minimal batch containing only observation.* and action.""" - batch = {"observation.state": "minimal_state", "action": "minimal_action"} - - transition = _default_batch_to_transition(batch) - - # Check observation - assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} - assert transition[TransitionKey.ACTION] == "minimal_action" - - # Check defaults - assert transition[TransitionKey.REWARD] == 0.0 - assert not transition[TransitionKey.DONE] - assert not transition[TransitionKey.TRUNCATED] - assert transition[TransitionKey.INFO] == {} - assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} - - # Round trip - reconstructed_batch = _default_transition_to_batch(transition) - assert reconstructed_batch["observation.state"] == "minimal_state" - assert reconstructed_batch["action"] == "minimal_action" - assert reconstructed_batch["next.reward"] == 0.0 - assert not reconstructed_batch["next.done"] - assert not reconstructed_batch["next.truncated"] - assert reconstructed_batch["info"] == {} - - -def test_empty_batch(): - """Test behavior with empty batch.""" - batch = {} - - transition = _default_batch_to_transition(batch) - - # All fields should have defaults - assert transition[TransitionKey.OBSERVATION] is None - assert transition[TransitionKey.ACTION] is None - assert transition[TransitionKey.REWARD] == 0.0 - assert not transition[TransitionKey.DONE] - assert not transition[TransitionKey.TRUNCATED] - assert transition[TransitionKey.INFO] == {} - assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} - - # Round trip - reconstructed_batch = _default_transition_to_batch(transition) - assert reconstructed_batch["action"] is None - assert reconstructed_batch["next.reward"] == 0.0 - assert not reconstructed_batch["next.done"] - assert not reconstructed_batch["next.truncated"] - assert reconstructed_batch["info"] == {} - - -def test_complex_nested_observation(): - """Test with complex nested observation data.""" - batch = { - "observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, - "observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, - "observation.state": torch.randn(7), - "action": torch.randn(8), - "next.reward": 3.14, - "next.done": False, - "next.truncated": True, - "info": {"episode_length": 200, "success": True}, - } - - transition = _default_batch_to_transition(batch) - reconstructed_batch = _default_transition_to_batch(transition) - - # Check that all observation keys are preserved - original_obs_keys = {k for k in batch if k.startswith("observation.")} - reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")} - - assert original_obs_keys == reconstructed_obs_keys - - # Check tensor values - assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"]) - - # Check nested dict with tensors - assert torch.allclose( - batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"] - ) - assert torch.allclose( - batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"] - ) - - # Check action tensor - assert torch.allclose(batch["action"], reconstructed_batch["action"]) - - # Check other fields - assert batch["next.reward"] == reconstructed_batch["next.reward"] - assert batch["next.done"] == reconstructed_batch["next.done"] - assert batch["next.truncated"] == reconstructed_batch["next.truncated"] - assert batch["info"] == reconstructed_batch["info"] - - -def test_custom_converter(): - """Test that custom converters can still be used.""" - - def to_tr(batch): - # Custom converter that modifies the reward - tr = _default_batch_to_transition(batch) - # Double the reward - reward = tr.get(TransitionKey.REWARD, 0.0) - new_tr = tr.copy() - new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0 - return new_tr - - def to_batch(tr): - batch = _default_transition_to_batch(tr) - return batch - - processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch) - - batch = { - "observation.state": torch.randn(1, 4), - "action": torch.randn(1, 2), - "next.reward": 1.0, - "next.done": False, - } - - result = processor(batch) - - # Check the reward was doubled by our custom converter - assert result["next.reward"] == 2.0 - assert torch.allclose(result["observation.state"], batch["observation.state"]) - assert torch.allclose(result["action"], batch["action"]) diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py deleted file mode 100644 index 26aea56..0000000 --- a/tests/processor/test_normalize_processor.py +++ /dev/null @@ -1,628 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from unittest.mock import Mock - -import numpy as np -import pytest -import torch - -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.processor.normalize_processor import ( - NormalizerProcessor, - UnnormalizerProcessor, - _convert_stats_to_tensors, -) -from lerobot.processor.pipeline import RobotProcessor, TransitionKey - - -def create_transition( - observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info, - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } - - -def test_numpy_conversion(): - stats = { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - } - } - tensor_stats = _convert_stats_to_tensors(stats) - - assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) - assert torch.allclose(tensor_stats["observation.image"]["mean"], torch.tensor([0.5, 0.5, 0.5])) - assert torch.allclose(tensor_stats["observation.image"]["std"], torch.tensor([0.2, 0.2, 0.2])) - - -def test_tensor_conversion(): - stats = { - "action": { - "mean": torch.tensor([0.0, 0.0]), - "std": torch.tensor([1.0, 1.0]), - } - } - tensor_stats = _convert_stats_to_tensors(stats) - - assert tensor_stats["action"]["mean"].dtype == torch.float32 - assert tensor_stats["action"]["std"].dtype == torch.float32 - - -def test_scalar_conversion(): - stats = { - "reward": { - "mean": 0.5, - "std": 0.1, - } - } - tensor_stats = _convert_stats_to_tensors(stats) - - assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5)) - assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1)) - - -def test_list_conversion(): - stats = { - "observation.state": { - "min": [0.0, -1.0, -2.0], - "max": [1.0, 1.0, 2.0], - } - } - tensor_stats = _convert_stats_to_tensors(stats) - - assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0])) - assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0])) - - -def test_unsupported_type(): - stats = { - "bad_key": { - "mean": "string_value", - } - } - with pytest.raises(TypeError, match="Unsupported type"): - _convert_stats_to_tensors(stats) - - -# Helper functions to create feature maps and norm maps -def _create_observation_features(): - return { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), - } - - -def _create_observation_norm_map(): - return { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.STATE: NormalizationMode.MIN_MAX, - } - - -# Fixtures for observation normalisation tests using NormalizerProcessor -@pytest.fixture -def observation_stats(): - return { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - }, - "observation.state": { - "min": np.array([0.0, -1.0]), - "max": np.array([1.0, 1.0]), - }, - } - - -@pytest.fixture -def observation_normalizer(observation_stats): - """Return a NormalizerProcessor that only has observation stats (no action).""" - features = _create_observation_features() - norm_map = _create_observation_norm_map() - return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats) - - -def test_mean_std_normalization(observation_normalizer): - observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), - } - transition = create_transition(observation=observation) - - normalized_transition = observation_normalizer(transition) - normalized_obs = normalized_transition[TransitionKey.OBSERVATION] - - # Check mean/std normalization - expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 - assert torch.allclose(normalized_obs["observation.image"], expected_image) - - -def test_min_max_normalization(observation_normalizer): - observation = { - "observation.state": torch.tensor([0.5, 0.0]), - } - transition = create_transition(observation=observation) - - normalized_transition = observation_normalizer(transition) - normalized_obs = normalized_transition[TransitionKey.OBSERVATION] - - # Check min/max normalization to [-1, 1] - # For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0 - # For state[1]: 2 * (0.0 - (-1.0)) / (1.0 - (-1.0)) - 1 = 0.0 - expected_state = torch.tensor([0.0, 0.0]) - assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) - - -def test_selective_normalization(observation_stats): - features = _create_observation_features() - norm_map = _create_observation_norm_map() - normalizer = NormalizerProcessor( - features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"} - ) - - observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), - } - transition = create_transition(observation=observation) - - normalized_transition = normalizer(transition) - normalized_obs = normalized_transition[TransitionKey.OBSERVATION] - - # Only image should be normalized - assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2) - # State should remain unchanged - assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"]) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_device_compatibility(observation_stats): - features = _create_observation_features() - norm_map = _create_observation_norm_map() - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats) - observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), - } - transition = create_transition(observation=observation) - - normalized_transition = normalizer(transition) - normalized_obs = normalized_transition[TransitionKey.OBSERVATION] - - assert normalized_obs["observation.image"].device.type == "cuda" - - -def test_from_lerobot_dataset(): - # Mock dataset - mock_dataset = Mock() - mock_dataset.meta.stats = { - "observation.image": {"mean": [0.5], "std": [0.2]}, - "action": {"mean": [0.0], "std": [1.0]}, - } - - features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "action": PolicyFeature(FeatureType.ACTION, (1,)), - } - norm_map = { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.ACTION: NormalizationMode.MEAN_STD, - } - - normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) - - # Both observation and action statistics should be present in tensor stats - assert "observation.image" in normalizer._tensor_stats - assert "action" in normalizer._tensor_stats - - -def test_state_dict_save_load(observation_normalizer): - # Save state - state_dict = observation_normalizer.state_dict() - - # Create new normalizer and load state - features = _create_observation_features() - norm_map = _create_observation_norm_map() - new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) - new_normalizer.load_state_dict(state_dict) - - # Test that it works the same - observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} - transition = create_transition(observation=observation) - - result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION] - result2 = new_normalizer(transition)[TransitionKey.OBSERVATION] - - assert torch.allclose(result1["observation.image"], result2["observation.image"]) - - -# Fixtures for ActionUnnormalizer tests -@pytest.fixture -def action_stats_mean_std(): - return { - "mean": np.array([0.0, 0.0, 0.0]), - "std": np.array([1.0, 2.0, 0.5]), - } - - -@pytest.fixture -def action_stats_min_max(): - return { - "min": np.array([-1.0, -2.0, 0.0]), - "max": np.array([1.0, 2.0, 1.0]), - } - - -def _create_action_features(): - return { - "action": PolicyFeature(FeatureType.ACTION, (3,)), - } - - -def _create_action_norm_map_mean_std(): - return { - FeatureType.ACTION: NormalizationMode.MEAN_STD, - } - - -def _create_action_norm_map_min_max(): - return { - FeatureType.ACTION: NormalizationMode.MIN_MAX, - } - - -def test_mean_std_unnormalization(action_stats_mean_std): - features = _create_action_features() - norm_map = _create_action_norm_map_mean_std() - unnormalizer = UnnormalizerProcessor( - features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} - ) - - normalized_action = torch.tensor([1.0, -0.5, 2.0]) - transition = create_transition(action=normalized_action) - - unnormalized_transition = unnormalizer(transition) - unnormalized_action = unnormalized_transition[TransitionKey.ACTION] - - # action * std + mean - expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0]) - assert torch.allclose(unnormalized_action, expected) - - -def test_min_max_unnormalization(action_stats_min_max): - features = _create_action_features() - norm_map = _create_action_norm_map_min_max() - unnormalizer = UnnormalizerProcessor( - features=features, norm_map=norm_map, stats={"action": action_stats_min_max} - ) - - # Actions in [-1, 1] - normalized_action = torch.tensor([0.0, -1.0, 1.0]) - transition = create_transition(action=normalized_action) - - unnormalized_transition = unnormalizer(transition) - unnormalized_action = unnormalized_transition[TransitionKey.ACTION] - - # Map from [-1, 1] to [min, max] - # (action + 1) / 2 * (max - min) + min - expected = torch.tensor( - [ - (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0), # 0.0 - (-1.0 + 1) / 2 * (2.0 - (-2.0)) + (-2.0), # -2.0 - (1.0 + 1) / 2 * (1.0 - 0.0) + 0.0, # 1.0 - ] - ) - assert torch.allclose(unnormalized_action, expected) - - -def test_numpy_action_input(action_stats_mean_std): - features = _create_action_features() - norm_map = _create_action_norm_map_mean_std() - unnormalizer = UnnormalizerProcessor( - features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} - ) - - normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32) - transition = create_transition(action=normalized_action) - - unnormalized_transition = unnormalizer(transition) - unnormalized_action = unnormalized_transition[TransitionKey.ACTION] - - assert isinstance(unnormalized_action, torch.Tensor) - expected = torch.tensor([1.0, -1.0, 1.0]) - assert torch.allclose(unnormalized_action, expected) - - -def test_none_action(action_stats_mean_std): - features = _create_action_features() - norm_map = _create_action_norm_map_mean_std() - unnormalizer = UnnormalizerProcessor( - features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} - ) - - transition = create_transition() - result = unnormalizer(transition) - - # Should return transition unchanged - assert result == transition - - -def test_action_from_lerobot_dataset(): - mock_dataset = Mock() - mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}} - features = {"action": PolicyFeature(FeatureType.ACTION, (1,))} - norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} - unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) - assert "mean" in unnormalizer._tensor_stats["action"] - - -# Fixtures for NormalizerProcessor tests -@pytest.fixture -def full_stats(): - return { - "observation.image": { - "mean": np.array([0.5, 0.5, 0.5]), - "std": np.array([0.2, 0.2, 0.2]), - }, - "observation.state": { - "min": np.array([0.0, -1.0]), - "max": np.array([1.0, 1.0]), - }, - "action": { - "mean": np.array([0.0, 0.0]), - "std": np.array([1.0, 2.0]), - }, - } - - -def _create_full_features(): - return { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), - } - - -def _create_full_norm_map(): - return { - FeatureType.VISUAL: NormalizationMode.MEAN_STD, - FeatureType.STATE: NormalizationMode.MIN_MAX, - FeatureType.ACTION: NormalizationMode.MEAN_STD, - } - - -@pytest.fixture -def normalizer_processor(full_stats): - features = _create_full_features() - norm_map = _create_full_norm_map() - return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats) - - -def test_combined_normalization(normalizer_processor): - observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), - } - action = torch.tensor([1.0, -0.5]) - transition = create_transition( - observation=observation, - action=action, - reward=1.0, - done=False, - truncated=False, - info={}, - complementary_data={}, - ) - - processed_transition = normalizer_processor(transition) - - # Check normalized observations - processed_obs = processed_transition[TransitionKey.OBSERVATION] - expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 - assert torch.allclose(processed_obs["observation.image"], expected_image) - - # Check normalized action - processed_action = processed_transition[TransitionKey.ACTION] - expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0]) - assert torch.allclose(processed_action, expected_action) - - # Check other fields remain unchanged - assert processed_transition[TransitionKey.REWARD] == 1.0 - assert not processed_transition[TransitionKey.DONE] - - -def test_processor_from_lerobot_dataset(full_stats): - # Mock dataset - mock_dataset = Mock() - mock_dataset.meta.stats = full_stats - - features = _create_full_features() - norm_map = _create_full_norm_map() - - processor = NormalizerProcessor.from_lerobot_dataset( - mock_dataset, features, norm_map, normalize_keys={"observation.image"} - ) - - assert processor.normalize_keys == {"observation.image"} - assert "observation.image" in processor._tensor_stats - assert "action" in processor._tensor_stats - - -def test_get_config(full_stats): - features = _create_full_features() - norm_map = _create_full_norm_map() - processor = NormalizerProcessor( - features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 - ) - - config = processor.get_config() - expected_config = { - "normalize_keys": ["observation.image"], - "eps": 1e-6, - "features": { - "observation.image": {"type": "VISUAL", "shape": (3, 96, 96)}, - "observation.state": {"type": "STATE", "shape": (2,)}, - "action": {"type": "ACTION", "shape": (2,)}, - }, - "norm_map": { - "VISUAL": "MEAN_STD", - "STATE": "MIN_MAX", - "ACTION": "MEAN_STD", - }, - } - assert config == expected_config - - -def test_integration_with_robot_processor(normalizer_processor): - """Test integration with RobotProcessor pipeline""" - robot_processor = RobotProcessor([normalizer_processor]) - - observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), - } - action = torch.tensor([1.0, -0.5]) - transition = create_transition( - observation=observation, - action=action, - reward=1.0, - done=False, - truncated=False, - info={}, - complementary_data={}, - ) - - processed_transition = robot_processor(transition) - - # Verify the processing worked - assert isinstance(processed_transition[TransitionKey.OBSERVATION], dict) - assert isinstance(processed_transition[TransitionKey.ACTION], torch.Tensor) - - -# Edge case tests -def test_empty_observation(): - stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} - norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - - transition = create_transition() - result = normalizer(transition) - - assert result == transition - - -def test_empty_stats(): - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} - norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) - observation = {"observation.image": torch.tensor([0.5])} - transition = create_transition(observation=observation) - - result = normalizer(transition) - # Should return observation unchanged since no stats are available - assert torch.allclose( - result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"] - ) - - -def test_partial_stats(): - """If statistics are incomplete, the value should pass through unchanged.""" - stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} - norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - observation = {"observation.image": torch.tensor([0.7])} - transition = create_transition(observation=observation) - - processed = normalizer(transition)[TransitionKey.OBSERVATION] - assert torch.allclose(processed["observation.image"], observation["observation.image"]) - - -def test_missing_action_stats_no_error(): - mock_dataset = Mock() - mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} - - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} - norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - - processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) - # The tensor stats should not contain the 'action' key - assert "action" not in processor._tensor_stats - - -def test_serialization_roundtrip(full_stats): - """Test that features and norm_map can be serialized and deserialized correctly.""" - features = _create_full_features() - norm_map = _create_full_norm_map() - original_processor = NormalizerProcessor( - features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 - ) - - # Get config (serialization) - config = original_processor.get_config() - - # Create a new processor from the config (deserialization) - new_processor = NormalizerProcessor( - features=config["features"], - norm_map=config["norm_map"], - stats=full_stats, - normalize_keys=set(config["normalize_keys"]), - eps=config["eps"], - ) - - # Test that both processors work the same way - observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), - } - action = torch.tensor([1.0, -0.5]) - transition = create_transition( - observation=observation, - action=action, - reward=1.0, - done=False, - truncated=False, - info={}, - complementary_data={}, - ) - - result1 = original_processor(transition) - result2 = new_processor(transition) - - # Compare results - assert torch.allclose( - result1[TransitionKey.OBSERVATION]["observation.image"], - result2[TransitionKey.OBSERVATION]["observation.image"], - ) - assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) - - # Verify features and norm_map are correctly reconstructed - assert new_processor.features.keys() == original_processor.features.keys() - for key in new_processor.features: - assert new_processor.features[key].type == original_processor.features[key].type - assert new_processor.features[key].shape == original_processor.features[key].shape - - assert new_processor.norm_map == original_processor.norm_map diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py deleted file mode 100644 index e48b6bc..0000000 --- a/tests/processor/test_observation_processor.py +++ /dev/null @@ -1,486 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import pytest -import torch - -from lerobot.configs.types import FeatureType -from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from lerobot.processor import VanillaObservationProcessor -from lerobot.processor.pipeline import TransitionKey -from tests.conftest import assert_contract_is_typed - - -def create_transition( - observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info, - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } - - -def test_process_single_image(): - """Test processing a single image.""" - processor = VanillaObservationProcessor() - - # Create a mock image (H, W, C) format, uint8 - image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) - - observation = {"pixels": image} - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check that the image was processed correctly - assert "observation.image" in processed_obs - processed_img = processed_obs["observation.image"] - - # Check shape: should be (1, 3, 64, 64) - batch, channels, height, width - assert processed_img.shape == (1, 3, 64, 64) - - # Check dtype and range - assert processed_img.dtype == torch.float32 - assert processed_img.min() >= 0.0 - assert processed_img.max() <= 1.0 - - -def test_process_image_dict(): - """Test processing multiple images in a dictionary.""" - processor = VanillaObservationProcessor() - - # Create mock images - image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) - image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8) - - observation = {"pixels": {"camera1": image1, "camera2": image2}} - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check that both images were processed - assert "observation.images.camera1" in processed_obs - assert "observation.images.camera2" in processed_obs - - # Check shapes - assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32) - assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48) - - -def test_process_batched_image(): - """Test processing already batched images.""" - processor = VanillaObservationProcessor() - - # Create a batched image (B, H, W, C) - image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8) - - observation = {"pixels": image} - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check that batch dimension is preserved - assert processed_obs["observation.image"].shape == (2, 3, 64, 64) - - -def test_invalid_image_format(): - """Test error handling for invalid image formats.""" - processor = VanillaObservationProcessor() - - # Test wrong channel order (channels first) - image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8) - observation = {"pixels": image} - transition = create_transition(observation=observation) - - with pytest.raises(ValueError, match="Expected channel-last images"): - processor(transition) - - -def test_invalid_image_dtype(): - """Test error handling for invalid image dtype.""" - processor = VanillaObservationProcessor() - - # Test wrong dtype - image = np.random.rand(64, 64, 3).astype(np.float32) - observation = {"pixels": image} - transition = create_transition(observation=observation) - - with pytest.raises(ValueError, match="Expected torch.uint8 images"): - processor(transition) - - -def test_no_pixels_in_observation(): - """Test processor when no pixels are in observation.""" - processor = VanillaObservationProcessor() - - observation = {"other_data": np.array([1, 2, 3])} - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Should preserve other data unchanged - assert "other_data" in processed_obs - np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3])) - - -def test_none_observation(): - """Test processor with None observation.""" - processor = VanillaObservationProcessor() - - transition = create_transition() - result = processor(transition) - - assert result == transition - - -def test_serialization_methods(): - """Test serialization methods.""" - processor = VanillaObservationProcessor() - - # Test get_config - config = processor.get_config() - assert isinstance(config, dict) - - # Test state_dict - state = processor.state_dict() - assert isinstance(state, dict) - - # Test load_state_dict (should not raise) - processor.load_state_dict(state) - - # Test reset (should not raise) - processor.reset() - - -def test_process_environment_state(): - """Test processing environment_state.""" - processor = VanillaObservationProcessor() - - env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) - observation = {"environment_state": env_state} - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check that environment_state was renamed and processed - assert "observation.environment_state" in processed_obs - assert "environment_state" not in processed_obs - - processed_state = processed_obs["observation.environment_state"] - assert processed_state.shape == (1, 3) # Batch dimension added - assert processed_state.dtype == torch.float32 - torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]])) - - -def test_process_agent_pos(): - """Test processing agent_pos.""" - processor = VanillaObservationProcessor() - - agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) - observation = {"agent_pos": agent_pos} - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check that agent_pos was renamed and processed - assert "observation.state" in processed_obs - assert "agent_pos" not in processed_obs - - processed_state = processed_obs["observation.state"] - assert processed_state.shape == (1, 3) # Batch dimension added - assert processed_state.dtype == torch.float32 - torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]])) - - -def test_process_batched_states(): - """Test processing already batched states.""" - processor = VanillaObservationProcessor() - - env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) - agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32) - - observation = {"environment_state": env_state, "agent_pos": agent_pos} - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check that batch dimensions are preserved - assert processed_obs["observation.environment_state"].shape == (2, 2) - assert processed_obs["observation.state"].shape == (2, 2) - - -def test_process_both_states(): - """Test processing both environment_state and agent_pos.""" - processor = VanillaObservationProcessor() - - env_state = np.array([1.0, 2.0], dtype=np.float32) - agent_pos = np.array([0.5, -0.5], dtype=np.float32) - - observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"} - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check that both states were processed - assert "observation.environment_state" in processed_obs - assert "observation.state" in processed_obs - - # Check that original keys were removed - assert "environment_state" not in processed_obs - assert "agent_pos" not in processed_obs - - # Check that other data was preserved - assert processed_obs["other_data"] == "keep_me" - - -def test_no_states_in_observation(): - """Test processor when no states are in observation.""" - processor = VanillaObservationProcessor() - - observation = {"other_data": np.array([1, 2, 3])} - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Should preserve data unchanged - np.testing.assert_array_equal(processed_obs, observation) - - -def test_complete_observation_processing(): - """Test processing a complete observation with both images and states.""" - processor = VanillaObservationProcessor() - - # Create mock data - image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) - env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) - agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) - - observation = { - "pixels": image, - "environment_state": env_state, - "agent_pos": agent_pos, - "other_data": "preserve_me", - } - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check that image was processed - assert "observation.image" in processed_obs - assert processed_obs["observation.image"].shape == (1, 3, 32, 32) - - # Check that states were processed - assert "observation.environment_state" in processed_obs - assert "observation.state" in processed_obs - - # Check that original keys were removed - assert "pixels" not in processed_obs - assert "environment_state" not in processed_obs - assert "agent_pos" not in processed_obs - - # Check that other data was preserved - assert processed_obs["other_data"] == "preserve_me" - - -def test_image_only_processing(): - """Test processing observation with only images.""" - processor = VanillaObservationProcessor() - - image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) - observation = {"pixels": image} - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - assert "observation.image" in processed_obs - assert len(processed_obs) == 1 - - -def test_state_only_processing(): - """Test processing observation with only states.""" - processor = VanillaObservationProcessor() - - agent_pos = np.array([1.0, 2.0], dtype=np.float32) - observation = {"agent_pos": agent_pos} - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - assert "observation.state" in processed_obs - assert "agent_pos" not in processed_obs - - -def test_empty_observation(): - """Test processing empty observation.""" - processor = VanillaObservationProcessor() - - observation = {} - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - assert processed_obs == {} - - -def test_equivalent_to_original_function(): - """Test that ObservationProcessor produces equivalent results to preprocess_observation.""" - # Import the original function for comparison - from lerobot.envs.utils import preprocess_observation - - processor = VanillaObservationProcessor() - - # Create test data similar to what the original function expects - image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) - env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) - agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) - - observation = {"pixels": image, "environment_state": env_state, "agent_pos": agent_pos} - - # Process with original function - original_result = preprocess_observation(observation) - - # Process with new processor - transition = create_transition(observation=observation) - processor_result = processor(transition)[TransitionKey.OBSERVATION] - - # Compare results - assert set(original_result.keys()) == set(processor_result.keys()) - - for key in original_result: - torch.testing.assert_close(original_result[key], processor_result[key]) - - -def test_equivalent_with_image_dict(): - """Test equivalence with dictionary of images.""" - from lerobot.envs.utils import preprocess_observation - - processor = VanillaObservationProcessor() - - # Create test data with multiple cameras - image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) - image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8) - agent_pos = np.array([1.0, 2.0], dtype=np.float32) - - observation = {"pixels": {"cam1": image1, "cam2": image2}, "agent_pos": agent_pos} - - # Process with original function - original_result = preprocess_observation(observation) - - # Process with new processor - transition = create_transition(observation=observation) - processor_result = processor(transition)[TransitionKey.OBSERVATION] - - # Compare results - assert set(original_result.keys()) == set(processor_result.keys()) - - for key in original_result: - torch.testing.assert_close(original_result[key], processor_result[key]) - - -def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory): - processor = VanillaObservationProcessor() - features = { - "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "keep": policy_feature_factory(FeatureType.ENV, (1,)), - } - out = processor.feature_contract(features.copy()) - - assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"] - assert "pixels" not in out - assert out["keep"] == features["keep"] - assert_contract_is_typed(out) - - -def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory): - processor = VanillaObservationProcessor() - features = { - "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "keep": policy_feature_factory(FeatureType.ENV, (1,)), - } - out = processor.feature_contract(features.copy()) - - assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"] - assert "observation.pixels" not in out - assert out["keep"] == features["keep"] - assert_contract_is_typed(out) - - -def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory): - processor = VanillaObservationProcessor() - features = { - "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "keep": policy_feature_factory(FeatureType.ENV, (7,)), - } - out = processor.feature_contract(features.copy()) - - assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"] - assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"] - assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"] - assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out - assert out["keep"] == features["keep"] - assert_contract_is_typed(out) - - -def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory): - processor = VanillaObservationProcessor() - features = { - "environment_state": policy_feature_factory(FeatureType.STATE, (3,)), - "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), - "keep": policy_feature_factory(FeatureType.ENV, (1,)), - } - out = processor.feature_contract(features.copy()) - - assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"] - assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"] - assert "environment_state" not in out and "agent_pos" not in out - assert out["keep"] == features["keep"] - assert_contract_is_typed(out) - - -def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory): - proc = VanillaObservationProcessor() - features = { - "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), - "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), - } - out = proc.feature_contract(features.copy()) - - assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"] - assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"] - assert "environment_state" not in out and "agent_pos" not in out - assert_contract_is_typed(out) diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py deleted file mode 100644 index 5665d5a..0000000 --- a/tests/processor/test_pipeline.py +++ /dev/null @@ -1,1919 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import tempfile -from collections.abc import Callable -from dataclasses import dataclass -from pathlib import Path -from typing import Any - -import pytest -import torch -import torch.nn as nn - -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor -from lerobot.processor.pipeline import TransitionKey -from tests.conftest import assert_contract_is_typed - - -def create_transition( - observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info if info is not None else {}, - TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, - } - - -@dataclass -class MockStep: - """Mock pipeline step for testing - demonstrates best practices. - - This example shows the proper separation: - - JSON-serializable attributes (name, counter) go in get_config() - - Only torch tensors go in state_dict() - - Note: The counter is part of the configuration, so it will be restored - when the step is recreated from config during loading. - """ - - name: str = "mock_step" - counter: int = 0 - - def __call__(self, transition: EnvTransition) -> EnvTransition: - """Add a counter to the complementary_data.""" - comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - comp_data = {} if comp_data is None else dict(comp_data) # Make a copy - - comp_data[f"{self.name}_counter"] = self.counter - self.counter += 1 - - # Create a new transition with updated complementary_data - new_transition = transition.copy() - new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data - return new_transition - - def get_config(self) -> dict[str, Any]: - # Return all JSON-serializable attributes that should be persisted - # These will be passed to __init__ when loading - return {"name": self.name, "counter": self.counter} - - def state_dict(self) -> dict[str, torch.Tensor]: - # Only return torch tensors (empty in this case since we have no tensor state) - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - # No tensor state to load - pass - - def reset(self) -> None: - self.counter = 0 - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here - return features - - -@dataclass -class MockStepWithoutOptionalMethods: - """Mock step that only implements the required __call__ method.""" - - multiplier: float = 2.0 - - def __call__(self, transition: EnvTransition) -> EnvTransition: - """Multiply reward by multiplier.""" - reward = transition.get(TransitionKey.REWARD) - - if reward is not None: - new_transition = transition.copy() - new_transition[TransitionKey.REWARD] = reward * self.multiplier - return new_transition - - return transition - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here - return features - - -@dataclass -class MockStepWithTensorState: - """Mock step demonstrating mixed JSON attributes and tensor state.""" - - name: str = "tensor_step" - learning_rate: float = 0.01 - window_size: int = 10 - - def __init__(self, name: str = "tensor_step", learning_rate: float = 0.01, window_size: int = 10): - self.name = name - self.learning_rate = learning_rate - self.window_size = window_size - # Tensor state - self.running_mean = torch.zeros(window_size) - self.running_count = torch.tensor(0) - - def __call__(self, transition: EnvTransition) -> EnvTransition: - """Update running statistics.""" - reward = transition.get(TransitionKey.REWARD) - - if reward is not None: - # Update running mean - idx = self.running_count % self.window_size - self.running_mean[idx] = reward - self.running_count += 1 - - return transition - - def get_config(self) -> dict[str, Any]: - # Only JSON-serializable attributes - return { - "name": self.name, - "learning_rate": self.learning_rate, - "window_size": self.window_size, - } - - def state_dict(self) -> dict[str, torch.Tensor]: - # Only tensor state - return { - "running_mean": self.running_mean, - "running_count": self.running_count, - } - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - self.running_mean = state["running_mean"] - self.running_count = state["running_count"] - - def reset(self) -> None: - self.running_mean.zero_() - self.running_count.zero_() - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here - return features - - -def test_empty_pipeline(): - """Test pipeline with no steps.""" - pipeline = RobotProcessor() - - transition = create_transition() - result = pipeline(transition) - - assert result == transition - assert len(pipeline) == 0 - - -def test_single_step_pipeline(): - """Test pipeline with a single step.""" - step = MockStep("test_step") - pipeline = RobotProcessor([step]) - - transition = create_transition() - result = pipeline(transition) - - assert len(pipeline) == 1 - assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0 - - # Call again to test counter increment - result = pipeline(transition) - assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 1 - - -def test_multiple_steps_pipeline(): - """Test pipeline with multiple steps.""" - step1 = MockStep("step1") - step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) - - transition = create_transition() - result = pipeline(transition) - - assert len(pipeline) == 2 - assert result[TransitionKey.COMPLEMENTARY_DATA]["step1_counter"] == 0 - assert result[TransitionKey.COMPLEMENTARY_DATA]["step2_counter"] == 0 - - -def test_invalid_transition_format(): - """Test pipeline with invalid transition format.""" - pipeline = RobotProcessor([MockStep()]) - - # Test with wrong type (tuple instead of dict) - with pytest.raises(ValueError, match="EnvTransition must be a dictionary"): - pipeline((None, None, 0.0, False, False, {}, {})) # Tuple instead of dict - - # Test with wrong type (string) - with pytest.raises(ValueError, match="EnvTransition must be a dictionary"): - pipeline("not a dict") - - -def test_step_through(): - """Test step_through method with dict input.""" - step1 = MockStep("step1") - step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) - - transition = create_transition() - - results = list(pipeline.step_through(transition)) - - assert len(results) == 3 # Original + 2 steps - assert results[0] == transition # Original - assert "step1_counter" in results[1][TransitionKey.COMPLEMENTARY_DATA] # After step1 - assert "step2_counter" in results[2][TransitionKey.COMPLEMENTARY_DATA] # After step2 - - # Ensure all results are dicts (same format as input) - for result in results: - assert isinstance(result, dict) - assert all(isinstance(k, TransitionKey) for k in result.keys()) - - -def test_step_through_with_dict(): - """Test step_through method with dict input.""" - step1 = MockStep("step1") - step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) - - batch = { - "observation.image": None, - "action": None, - "next.reward": 0.0, - "next.done": False, - "next.truncated": False, - "info": {}, - } - - results = list(pipeline.step_through(batch)) - - assert len(results) == 3 # Original + 2 steps - - # Ensure all results are EnvTransition dicts (regardless of input format) - for result in results: - assert isinstance(result, dict) - # Check that keys are TransitionKey enums or at least valid transition keys - for key in result: - assert key in [ - TransitionKey.OBSERVATION, - TransitionKey.ACTION, - TransitionKey.REWARD, - TransitionKey.DONE, - TransitionKey.TRUNCATED, - TransitionKey.INFO, - TransitionKey.COMPLEMENTARY_DATA, - ] - - # Check that the processing worked - verify step counters in complementary_data - assert results[1].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0 - assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0 - assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step2_counter") == 0 - - -def test_step_through_no_hooks(): - """Test that step_through doesn't execute hooks.""" - step = MockStep("test_step") - pipeline = RobotProcessor([step]) - - hook_calls = [] - - def tracking_hook(idx: int, transition: EnvTransition): - hook_calls.append(f"hook_called_step_{idx}") - - # Register hooks - pipeline.register_before_step_hook(tracking_hook) - pipeline.register_after_step_hook(tracking_hook) - - # Use step_through - transition = create_transition() - results = list(pipeline.step_through(transition)) - - # Verify step was executed (counter should increment) - assert len(results) == 2 # Initial + 1 step - assert results[1][TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0 - - # Verify hooks were NOT called - assert len(hook_calls) == 0 - - # Now use __call__ to verify hooks ARE called there - hook_calls.clear() - pipeline(transition) - - # Verify hooks were called (before and after for 1 step = 2 calls) - assert len(hook_calls) == 2 - assert hook_calls == ["hook_called_step_0", "hook_called_step_0"] - - -def test_indexing(): - """Test pipeline indexing.""" - step1 = MockStep("step1") - step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) - - # Test integer indexing - assert pipeline[0] is step1 - assert pipeline[1] is step2 - - # Test slice indexing - sub_pipeline = pipeline[0:1] - assert isinstance(sub_pipeline, RobotProcessor) - assert len(sub_pipeline) == 1 - assert sub_pipeline[0] is step1 - - -def test_hooks(): - """Test before/after step hooks.""" - step = MockStep("test_step") - pipeline = RobotProcessor([step]) - - before_calls = [] - after_calls = [] - - def before_hook(idx: int, transition: EnvTransition): - before_calls.append(idx) - - def after_hook(idx: int, transition: EnvTransition): - after_calls.append(idx) - - pipeline.register_before_step_hook(before_hook) - pipeline.register_after_step_hook(after_hook) - - transition = create_transition() - pipeline(transition) - - assert before_calls == [0] - assert after_calls == [0] - - -def test_unregister_hooks(): - """Test unregistering hooks from the pipeline.""" - step = MockStep("test_step") - pipeline = RobotProcessor([step]) - - # Test before_step_hook - before_calls = [] - - def before_hook(idx: int, transition: EnvTransition): - before_calls.append(idx) - - pipeline.register_before_step_hook(before_hook) - - # Verify hook is registered - transition = create_transition() - pipeline(transition) - assert len(before_calls) == 1 - - # Unregister and verify it's no longer called - pipeline.unregister_before_step_hook(before_hook) - before_calls.clear() - pipeline(transition) - assert len(before_calls) == 0 - - # Test after_step_hook - after_calls = [] - - def after_hook(idx: int, transition: EnvTransition): - after_calls.append(idx) - - pipeline.register_after_step_hook(after_hook) - pipeline(transition) - assert len(after_calls) == 1 - - pipeline.unregister_after_step_hook(after_hook) - after_calls.clear() - pipeline(transition) - assert len(after_calls) == 0 - - -def test_unregister_nonexistent_hook(): - """Test error handling when unregistering hooks that don't exist.""" - pipeline = RobotProcessor([MockStep()]) - - def some_hook(idx: int, transition: EnvTransition): - pass - - def reset_hook(): - pass - - # Test unregistering hooks that were never registered - with pytest.raises(ValueError, match="not found in before_step_hooks"): - pipeline.unregister_before_step_hook(some_hook) - - with pytest.raises(ValueError, match="not found in after_step_hooks"): - pipeline.unregister_after_step_hook(some_hook) - - -def test_multiple_hooks_and_selective_unregister(): - """Test registering multiple hooks and selectively unregistering them.""" - pipeline = RobotProcessor([MockStep("step1"), MockStep("step2")]) - - calls_1 = [] - calls_2 = [] - calls_3 = [] - - def hook1(idx: int, transition: EnvTransition): - calls_1.append(f"hook1_step{idx}") - - def hook2(idx: int, transition: EnvTransition): - calls_2.append(f"hook2_step{idx}") - - def hook3(idx: int, transition: EnvTransition): - calls_3.append(f"hook3_step{idx}") - - # Register multiple hooks - pipeline.register_before_step_hook(hook1) - pipeline.register_before_step_hook(hook2) - pipeline.register_before_step_hook(hook3) - - # Run pipeline - all hooks should be called for both steps - transition = create_transition() - pipeline(transition) - - assert calls_1 == ["hook1_step0", "hook1_step1"] - assert calls_2 == ["hook2_step0", "hook2_step1"] - assert calls_3 == ["hook3_step0", "hook3_step1"] - - # Clear calls - calls_1.clear() - calls_2.clear() - calls_3.clear() - - # Unregister middle hook - pipeline.unregister_before_step_hook(hook2) - - # Run again - only hook1 and hook3 should be called - pipeline(transition) - - assert calls_1 == ["hook1_step0", "hook1_step1"] - assert calls_2 == [] # hook2 was unregistered - assert calls_3 == ["hook3_step0", "hook3_step1"] - - -def test_hook_execution_order_documentation(): - """Test and document that hooks are executed sequentially in registration order.""" - pipeline = RobotProcessor([MockStep("step")]) - - execution_order = [] - - def hook_a(idx: int, transition: EnvTransition): - execution_order.append("A") - - def hook_b(idx: int, transition: EnvTransition): - execution_order.append("B") - - def hook_c(idx: int, transition: EnvTransition): - execution_order.append("C") - - # Register in specific order: A, B, C - pipeline.register_before_step_hook(hook_a) - pipeline.register_before_step_hook(hook_b) - pipeline.register_before_step_hook(hook_c) - - transition = create_transition() - pipeline(transition) - - # Verify execution order matches registration order - assert execution_order == ["A", "B", "C"] - - # Test that after unregistering B and re-registering it, it goes to the end - pipeline.unregister_before_step_hook(hook_b) - execution_order.clear() - - pipeline(transition) - assert execution_order == ["A", "C"] # B is gone - - # Re-register B - it should now be at the end - pipeline.register_before_step_hook(hook_b) - execution_order.clear() - - pipeline(transition) - assert execution_order == ["A", "C", "B"] # B is now last - - -def test_save_and_load_pretrained(): - """Test saving and loading pipeline. - - This test demonstrates that JSON-serializable attributes (like counter) - are saved in the config and restored when the step is recreated. - """ - step1 = MockStep("step1") - step2 = MockStep("step2") - - # Increment counters to have some state - step1.counter = 5 - step2.counter = 10 - - pipeline = RobotProcessor([step1, step2], name="TestPipeline") - - with tempfile.TemporaryDirectory() as tmp_dir: - # Save pipeline - pipeline.save_pretrained(tmp_dir) - - # Check files were created - config_path = Path(tmp_dir) / "testpipeline.json" # Based on name="TestPipeline" - assert config_path.exists() - - # Check config content - with open(config_path) as f: - config = json.load(f) - - assert config["name"] == "TestPipeline" - assert len(config["steps"]) == 2 - - # Verify counters are saved in config, not in separate state files - assert config["steps"][0]["config"]["counter"] == 5 - assert config["steps"][1]["config"]["counter"] == 10 - - # Load pipeline - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) - - assert loaded_pipeline.name == "TestPipeline" - assert len(loaded_pipeline) == 2 - - # Check that counter was restored from config - assert loaded_pipeline.steps[0].counter == 5 - assert loaded_pipeline.steps[1].counter == 10 - - -def test_step_without_optional_methods(): - """Test pipeline with steps that don't implement optional methods.""" - step = MockStepWithoutOptionalMethods(multiplier=3.0) - pipeline = RobotProcessor([step]) - - transition = create_transition(reward=2.0) - result = pipeline(transition) - - assert result[TransitionKey.REWARD] == 6.0 # 2.0 * 3.0 - - # Reset should work even if step doesn't implement reset - pipeline.reset() - - # Save/load should work even without optional methods - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) - assert len(loaded_pipeline) == 1 - - -def test_mixed_json_and_tensor_state(): - """Test step with both JSON attributes and tensor state.""" - step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5) - pipeline = RobotProcessor([step]) - - # Process some transitions with rewards - for i in range(10): - transition = create_transition(reward=float(i)) - pipeline(transition) - - # Check state - assert step.running_count.item() == 10 - assert step.learning_rate == 0.05 - - # Save and load - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Check that both config and state files were created - config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor" - state_path = Path(tmp_dir) / "robotprocessor_step_0.safetensors" - assert config_path.exists() - assert state_path.exists() - - # Load and verify - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) - loaded_step = loaded_pipeline.steps[0] - - # Check JSON attributes were restored - assert loaded_step.name == "stats" - assert loaded_step.learning_rate == 0.05 - assert loaded_step.window_size == 5 - - # Check tensor state was restored - assert loaded_step.running_count.item() == 10 - assert torch.allclose(loaded_step.running_mean, step.running_mean) - - -class MockModuleStep(nn.Module): - """Mock step that inherits from nn.Module to test state_dict handling of module parameters.""" - - def __init__(self, input_dim: int = 10, hidden_dim: int = 5): - super().__init__() - self.input_dim = input_dim - self.hidden_dim = hidden_dim - self.linear = nn.Linear(input_dim, hidden_dim) - self.running_mean = nn.Parameter(torch.zeros(hidden_dim), requires_grad=False) - self.counter = 0 # Non-tensor state - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) - - def __call__(self, transition: EnvTransition) -> EnvTransition: - """Process transition and update running mean.""" - obs = transition.get(TransitionKey.OBSERVATION) - - if obs is not None and isinstance(obs, torch.Tensor): - # Process observation through linear layer - processed = self.forward(obs[:, : self.input_dim]) - - # Update running mean in-place (don't reassign the parameter) - with torch.no_grad(): - self.running_mean.mul_(0.9).add_(processed.mean(dim=0), alpha=0.1) - - self.counter += 1 - - return transition - - def get_config(self) -> dict[str, Any]: - return { - "input_dim": self.input_dim, - "hidden_dim": self.hidden_dim, - "counter": self.counter, - } - - def state_dict(self) -> dict[str, torch.Tensor]: - """Override to return all module parameters and buffers.""" - # Get the module's state dict (includes all parameters and buffers) - return super().state_dict() - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - """Override to load all module parameters and buffers.""" - # Use the module's load_state_dict - super().load_state_dict(state) - - def reset(self) -> None: - self.running_mean.zero_() - self.counter = 0 - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here - return features - - -class MockNonModuleStepWithState: - """Mock step that explicitly does NOT inherit from nn.Module but has tensor state. - - This tests the state_dict/load_state_dict path for regular classes. - """ - - def __init__(self, name: str = "non_module_step", feature_dim: int = 10): - self.name = name - self.feature_dim = feature_dim - - # Initialize tensor state - these are regular tensors, not nn.Parameters - self.weights = torch.randn(feature_dim, feature_dim) - self.bias = torch.zeros(feature_dim) - self.running_stats = torch.zeros(feature_dim) - self.step_count = torch.tensor(0) - - # Non-tensor state - self.config_value = 42 - self.history = [] - - def __call__(self, transition: EnvTransition) -> EnvTransition: - """Process transition using tensor operations.""" - obs = transition.get(TransitionKey.OBSERVATION) - comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - - if obs is not None and isinstance(obs, torch.Tensor) and obs.numel() >= self.feature_dim: - # Perform some tensor operations - flat_obs = obs.flatten()[: self.feature_dim] - - # Simple linear transformation (ensure dimensions match for matmul) - output = torch.matmul(self.weights.T, flat_obs) + self.bias - - # Update running stats - self.running_stats = 0.9 * self.running_stats + 0.1 * output - self.step_count += 1 - - # Add to complementary data - comp_data = {} if comp_data is None else dict(comp_data) - comp_data[f"{self.name}_mean_output"] = output.mean().item() - comp_data[f"{self.name}_steps"] = self.step_count.item() - - # Return updated transition - new_transition = transition.copy() - new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data - return new_transition - - return transition - - def get_config(self) -> dict[str, Any]: - return { - "name": self.name, - "feature_dim": self.feature_dim, - "config_value": self.config_value, - } - - def state_dict(self) -> dict[str, torch.Tensor]: - """Return only tensor state.""" - return { - "weights": self.weights, - "bias": self.bias, - "running_stats": self.running_stats, - "step_count": self.step_count, - } - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - """Load tensor state.""" - self.weights = state["weights"] - self.bias = state["bias"] - self.running_stats = state["running_stats"] - self.step_count = state["step_count"] - - def reset(self) -> None: - """Reset statistics but keep learned parameters.""" - self.running_stats.zero_() - self.step_count.zero_() - self.history.clear() - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here - return features - - -# Tests for overrides functionality -@dataclass -class MockStepWithNonSerializableParam: - """Mock step that requires a non-serializable parameter.""" - - def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None): - self.name = name - # Add type validation for multiplier - if isinstance(multiplier, str): - raise ValueError(f"multiplier must be a number, got string '{multiplier}'") - if not isinstance(multiplier, (int, float)): - raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}") - self.multiplier = float(multiplier) - self.env = env # Non-serializable parameter (like gym.Env) - - def __call__(self, transition: EnvTransition) -> EnvTransition: - reward = transition.get(TransitionKey.REWARD) - comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - - # Use the env parameter if provided - if self.env is not None: - comp_data = {} if comp_data is None else dict(comp_data) - comp_data[f"{self.name}_env_info"] = str(self.env) - - # Apply multiplier to reward - new_transition = transition.copy() - if reward is not None: - new_transition[TransitionKey.REWARD] = reward * self.multiplier - - if comp_data: - new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data - - return new_transition - - def get_config(self) -> dict[str, Any]: - # Note: env is intentionally NOT included here as it's not serializable - return { - "name": self.name, - "multiplier": self.multiplier, - } - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here - return features - - -@ProcessorStepRegistry.register("registered_mock_step") -@dataclass -class RegisteredMockStep: - """Mock step registered in the registry.""" - - value: int = 42 - device: str = "cpu" - - def __call__(self, transition: EnvTransition) -> EnvTransition: - comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - - comp_data = {} if comp_data is None else dict(comp_data) - comp_data["registered_step_value"] = self.value - comp_data["registered_step_device"] = self.device - - new_transition = transition.copy() - new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data - return new_transition - - def get_config(self) -> dict[str, Any]: - return { - "value": self.value, - "device": self.device, - } - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here - return features - - -class MockEnvironment: - """Mock environment for testing non-serializable parameters.""" - - def __init__(self, name: str): - self.name = name - - def __str__(self): - return f"MockEnvironment({self.name})" - - -def test_from_pretrained_with_overrides(): - """Test loading processor with parameter overrides.""" - # Create a processor with steps that need overrides - env_step = MockStepWithNonSerializableParam(name="env_step", multiplier=2.0) - registered_step = RegisteredMockStep(value=100, device="cpu") - - pipeline = RobotProcessor([env_step, registered_step], name="TestOverrides") - - with tempfile.TemporaryDirectory() as tmp_dir: - # Save the pipeline - pipeline.save_pretrained(tmp_dir) - - # Create a mock environment for override - mock_env = MockEnvironment("test_env") - - # Load with overrides - overrides = { - "MockStepWithNonSerializableParam": { - "env": mock_env, - "multiplier": 3.0, # Override the multiplier too - }, - "registered_mock_step": {"device": "cuda", "value": 200}, - } - - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) - - # Verify the pipeline was loaded correctly - assert len(loaded_pipeline) == 2 - assert loaded_pipeline.name == "TestOverrides" - - # Test the loaded steps - transition = create_transition(reward=1.0) - result = loaded_pipeline(transition) - - # Check that overrides were applied - comp_data = result[TransitionKey.COMPLEMENTARY_DATA] - assert "env_step_env_info" in comp_data - assert comp_data["env_step_env_info"] == "MockEnvironment(test_env)" - assert comp_data["registered_step_value"] == 200 - assert comp_data["registered_step_device"] == "cuda" - - # Check that multiplier override was applied - assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 (overridden multiplier) - - -def test_from_pretrained_with_partial_overrides(): - """Test loading processor with overrides for only some steps.""" - step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0) - step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0) - - pipeline = RobotProcessor([step1, step2]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Override only one step - overrides = {"MockStepWithNonSerializableParam": {"multiplier": 5.0}} - - # The current implementation applies overrides to ALL steps with the same class name - # Both steps will get the override - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) - - transition = create_transition(reward=1.0) - result = loaded_pipeline(transition) - - # The reward should be affected by both steps, both getting the override - # First step: 1.0 * 5.0 = 5.0 (overridden) - # Second step: 5.0 * 5.0 = 25.0 (also overridden) - assert result[TransitionKey.REWARD] == 25.0 - - -def test_from_pretrained_invalid_override_key(): - """Test that invalid override keys raise KeyError.""" - step = MockStepWithNonSerializableParam() - pipeline = RobotProcessor([step]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Try to override a non-existent step - overrides = {"NonExistentStep": {"param": "value"}} - - with pytest.raises(KeyError, match="Override keys.*do not match any step"): - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) - - -def test_from_pretrained_multiple_invalid_override_keys(): - """Test that multiple invalid override keys are reported.""" - step = MockStepWithNonSerializableParam() - pipeline = RobotProcessor([step]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Try to override multiple non-existent steps - overrides = {"NonExistentStep1": {"param": "value1"}, "NonExistentStep2": {"param": "value2"}} - - with pytest.raises(KeyError) as exc_info: - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) - - error_msg = str(exc_info.value) - assert "NonExistentStep1" in error_msg - assert "NonExistentStep2" in error_msg - assert "Available step keys" in error_msg - - -def test_from_pretrained_registered_step_override(): - """Test overriding registered steps using registry names.""" - registered_step = RegisteredMockStep(value=50, device="cpu") - pipeline = RobotProcessor([registered_step]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Override using registry name - overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}} - - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) - - # Test that overrides were applied - transition = create_transition() - result = loaded_pipeline(transition) - - comp_data = result[TransitionKey.COMPLEMENTARY_DATA] - assert comp_data["registered_step_value"] == 999 - assert comp_data["registered_step_device"] == "cuda" - - -def test_from_pretrained_mixed_registered_and_unregistered(): - """Test overriding both registered and unregistered steps.""" - unregistered_step = MockStepWithNonSerializableParam(name="unregistered", multiplier=1.0) - registered_step = RegisteredMockStep(value=10, device="cpu") - - pipeline = RobotProcessor([unregistered_step, registered_step]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - mock_env = MockEnvironment("mixed_test") - - overrides = { - "MockStepWithNonSerializableParam": {"env": mock_env, "multiplier": 4.0}, - "registered_mock_step": {"value": 777}, - } - - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) - - # Test both steps - transition = create_transition(reward=2.0) - result = loaded_pipeline(transition) - - comp_data = result[TransitionKey.COMPLEMENTARY_DATA] - assert comp_data["unregistered_env_info"] == "MockEnvironment(mixed_test)" - assert comp_data["registered_step_value"] == 777 - assert result[TransitionKey.REWARD] == 8.0 # 2.0 * 4.0 - - -def test_from_pretrained_no_overrides(): - """Test that from_pretrained works without overrides (backward compatibility).""" - step = MockStepWithNonSerializableParam(name="no_override", multiplier=3.0) - pipeline = RobotProcessor([step]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Load without overrides - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) - - assert len(loaded_pipeline) == 1 - - # Test that the step works (env will be None) - transition = create_transition(reward=1.0) - result = loaded_pipeline(transition) - - assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 - - -def test_from_pretrained_empty_overrides(): - """Test that from_pretrained works with empty overrides dict.""" - step = MockStepWithNonSerializableParam(multiplier=2.0) - pipeline = RobotProcessor([step]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Load with empty overrides - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides={}) - - assert len(loaded_pipeline) == 1 - - # Test that the step works normally - transition = create_transition(reward=1.0) - result = loaded_pipeline(transition) - - assert result[TransitionKey.REWARD] == 2.0 - - -def test_from_pretrained_override_instantiation_error(): - """Test that instantiation errors with overrides are properly reported.""" - step = MockStepWithNonSerializableParam(multiplier=1.0) - pipeline = RobotProcessor([step]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Try to override with invalid parameter type - overrides = { - "MockStepWithNonSerializableParam": { - "multiplier": "invalid_type" # Should be float, not string - } - } - - with pytest.raises(ValueError, match="Failed to instantiate processor step"): - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) - - -def test_from_pretrained_with_state_and_overrides(): - """Test that overrides work correctly with steps that have tensor state.""" - step = MockStepWithTensorState(name="tensor_step", learning_rate=0.01, window_size=5) - pipeline = RobotProcessor([step]) - - # Process some data to create state - for i in range(10): - transition = create_transition(reward=float(i)) - pipeline(transition) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Load with overrides - overrides = { - "MockStepWithTensorState": { - "learning_rate": 0.05, # Override learning rate - "window_size": 3, # Override window size - } - } - - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) - loaded_step = loaded_pipeline.steps[0] - - # Check that config overrides were applied - assert loaded_step.learning_rate == 0.05 - assert loaded_step.window_size == 3 - - # Check that tensor state was preserved - assert loaded_step.running_count.item() == 10 - - # The running_mean should still have the original window_size (5) from saved state - # but the new step will use window_size=3 for future operations - assert loaded_step.running_mean.shape[0] == 5 # From saved state - - -def test_from_pretrained_override_error_messages(): - """Test that error messages for override failures are helpful.""" - step1 = MockStepWithNonSerializableParam(name="step1") - step2 = RegisteredMockStep() - pipeline = RobotProcessor([step1, step2]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Test with invalid override key - overrides = {"WrongStepName": {"param": "value"}} - - with pytest.raises(KeyError) as exc_info: - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) - - error_msg = str(exc_info.value) - assert "WrongStepName" in error_msg - assert "Available step keys" in error_msg - assert "MockStepWithNonSerializableParam" in error_msg - assert "registered_mock_step" in error_msg - - -def test_repr_empty_processor(): - """Test __repr__ with empty processor.""" - pipeline = RobotProcessor() - repr_str = repr(pipeline) - - expected = "RobotProcessor(name='RobotProcessor', steps=0: [])" - assert repr_str == expected - - -def test_repr_single_step(): - """Test __repr__ with single step.""" - step = MockStep("test_step") - pipeline = RobotProcessor([step]) - repr_str = repr(pipeline) - - expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" - assert repr_str == expected - - -def test_repr_multiple_steps_under_limit(): - """Test __repr__ with 2-3 steps (all shown).""" - step1 = MockStep("step1") - step2 = MockStepWithoutOptionalMethods() - pipeline = RobotProcessor([step1, step2]) - repr_str = repr(pipeline) - - expected = "RobotProcessor(name='RobotProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" - assert repr_str == expected - - # Test with 3 steps (boundary case) - step3 = MockStepWithTensorState() - pipeline = RobotProcessor([step1, step2, step3]) - repr_str = repr(pipeline) - - expected = "RobotProcessor(name='RobotProcessor', steps=3: [MockStep, MockStepWithoutOptionalMethods, MockStepWithTensorState])" - assert repr_str == expected - - -def test_repr_many_steps_truncated(): - """Test __repr__ with more than 3 steps (truncated with ellipsis).""" - step1 = MockStep("step1") - step2 = MockStepWithoutOptionalMethods() - step3 = MockStepWithTensorState() - step4 = MockModuleStep() - step5 = MockNonModuleStepWithState() - - pipeline = RobotProcessor([step1, step2, step3, step4, step5]) - repr_str = repr(pipeline) - - expected = "RobotProcessor(name='RobotProcessor', steps=5: [MockStep, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" - assert repr_str == expected - - -def test_repr_with_custom_name(): - """Test __repr__ with custom processor name.""" - step = MockStep("test_step") - pipeline = RobotProcessor([step], name="CustomProcessor") - repr_str = repr(pipeline) - - expected = "RobotProcessor(name='CustomProcessor', steps=1: [MockStep])" - assert repr_str == expected - - -def test_repr_with_seed(): - """Test __repr__ with seed parameter.""" - step = MockStep("test_step") - pipeline = RobotProcessor([step]) - repr_str = repr(pipeline) - - expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" - assert repr_str == expected - - -def test_repr_with_custom_name_and_seed(): - """Test __repr__ with both custom name and seed.""" - step1 = MockStep("step1") - step2 = MockStepWithoutOptionalMethods() - pipeline = RobotProcessor([step1, step2], name="MyProcessor") - repr_str = repr(pipeline) - - expected = "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" - assert repr_str == expected - - -def test_repr_without_seed(): - """Test __repr__ when seed is explicitly None (should not show seed).""" - step = MockStep("test_step") - pipeline = RobotProcessor([step], name="TestProcessor") - repr_str = repr(pipeline) - - expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])" - assert repr_str == expected - - -def test_repr_various_step_types(): - """Test __repr__ with different types of steps to verify class name extraction.""" - step1 = MockStep() - step2 = MockStepWithTensorState() - step3 = MockModuleStep() - step4 = MockNonModuleStepWithState() - - pipeline = RobotProcessor([step1, step2, step3, step4], name="MixedSteps") - repr_str = repr(pipeline) - - expected = "RobotProcessor(name='MixedSteps', steps=4: [MockStep, MockStepWithTensorState, ..., MockNonModuleStepWithState])" - assert repr_str == expected - - -def test_repr_edge_case_long_names(): - """Test __repr__ handles steps with long class names properly.""" - step1 = MockStepWithNonSerializableParam() - step2 = MockStepWithoutOptionalMethods() - step3 = MockStepWithTensorState() - step4 = MockNonModuleStepWithState() - - pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames") - repr_str = repr(pipeline) - - expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" - assert repr_str == expected - - -# Tests for config filename features and multiple processors -def test_save_with_custom_config_filename(): - """Test saving processor with custom config filename.""" - step = MockStep("test") - pipeline = RobotProcessor([step], name="TestProcessor") - - with tempfile.TemporaryDirectory() as tmp_dir: - # Save with custom filename - pipeline.save_pretrained(tmp_dir, config_filename="my_custom_config.json") - - # Check file exists - config_path = Path(tmp_dir) / "my_custom_config.json" - assert config_path.exists() - - # Check content - with open(config_path) as f: - config = json.load(f) - assert config["name"] == "TestProcessor" - - # Load with specific filename - loaded = RobotProcessor.from_pretrained(tmp_dir, config_filename="my_custom_config.json") - assert loaded.name == "TestProcessor" - - -def test_multiple_processors_same_directory(): - """Test saving multiple processors to the same directory with different config files.""" - # Create different processors - preprocessor = RobotProcessor([MockStep("pre1"), MockStep("pre2")], name="preprocessor") - - postprocessor = RobotProcessor([MockStepWithoutOptionalMethods(multiplier=0.5)], name="postprocessor") - - with tempfile.TemporaryDirectory() as tmp_dir: - # Save both to same directory - preprocessor.save_pretrained(tmp_dir) - postprocessor.save_pretrained(tmp_dir) - - # Check both config files exist - assert (Path(tmp_dir) / "preprocessor.json").exists() - assert (Path(tmp_dir) / "postprocessor.json").exists() - - # Load them back - loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json") - loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json") - - assert loaded_pre.name == "preprocessor" - assert loaded_post.name == "postprocessor" - assert len(loaded_pre) == 2 - assert len(loaded_post) == 1 - - -def test_auto_detect_single_config(): - """Test automatic config detection when there's only one JSON file.""" - step = MockStepWithTensorState() - pipeline = RobotProcessor([step], name="SingleConfig") - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Load without specifying config_filename - loaded = RobotProcessor.from_pretrained(tmp_dir) - assert loaded.name == "SingleConfig" - - -def test_error_multiple_configs_no_filename(): - """Test error when multiple configs exist and no filename specified.""" - proc1 = RobotProcessor([MockStep()], name="processor1") - proc2 = RobotProcessor([MockStep()], name="processor2") - - with tempfile.TemporaryDirectory() as tmp_dir: - proc1.save_pretrained(tmp_dir) - proc2.save_pretrained(tmp_dir) - - # Should raise error - with pytest.raises(ValueError, match="Multiple .json files found"): - RobotProcessor.from_pretrained(tmp_dir) - - -def test_state_file_naming_with_indices(): - """Test that state files include pipeline name and step indices to avoid conflicts.""" - # Create multiple steps of same type with state - step1 = MockStepWithTensorState(name="norm1", window_size=5) - step2 = MockStepWithTensorState(name="norm2", window_size=10) - step3 = MockModuleStep(input_dim=5) - - pipeline = RobotProcessor([step1, step2, step3]) - - # Process some data to create state - for i in range(5): - transition = create_transition(observation=torch.randn(2, 5), reward=float(i)) - pipeline(transition) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Check state files have indices - state_files = sorted(Path(tmp_dir).glob("*.safetensors")) - assert len(state_files) == 3 - - # Files should be named with pipeline name prefix and indices - expected_names = [ - "robotprocessor_step_0.safetensors", - "robotprocessor_step_1.safetensors", - "robotprocessor_step_2.safetensors", - ] - actual_names = [f.name for f in state_files] - assert actual_names == expected_names - - -def test_state_file_naming_with_registry(): - """Test state file naming for registered steps includes pipeline name, index and registry name.""" - - # Register a test step - @ProcessorStepRegistry.register("test_stateful_step") - @dataclass - class TestStatefulStep: - value: int = 0 - - def __init__(self, value: int = 0): - self.value = value - self.state_tensor = torch.randn(3, 3) - - def __call__(self, transition: EnvTransition) -> EnvTransition: - return transition - - def get_config(self): - return {"value": self.value} - - def state_dict(self): - return {"state_tensor": self.state_tensor} - - def load_state_dict(self, state): - self.state_tensor = state["state_tensor"] - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here - return features - - try: - # Create pipeline with registered steps - step1 = TestStatefulStep(1) - step2 = TestStatefulStep(2) - pipeline = RobotProcessor([step1, step2]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Check state files - state_files = sorted(Path(tmp_dir).glob("*.safetensors")) - assert len(state_files) == 2 - - # Should include pipeline name, index and registry name - expected_names = [ - "robotprocessor_step_0_test_stateful_step.safetensors", - "robotprocessor_step_1_test_stateful_step.safetensors", - ] - actual_names = [f.name for f in state_files] - assert actual_names == expected_names - - finally: - # Cleanup registry - ProcessorStepRegistry.unregister("test_stateful_step") - - -# More comprehensive override tests -def test_override_with_nested_config(): - """Test overrides with nested configuration dictionaries.""" - - @ProcessorStepRegistry.register("complex_config_step") - @dataclass - class ComplexConfigStep: - name: str = "complex" - simple_param: int = 42 - nested_config: dict = None - - def __post_init__(self): - if self.nested_config is None: - self.nested_config = {"level1": {"level2": "default"}} - - def __call__(self, transition: EnvTransition) -> EnvTransition: - comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - comp_data = dict(comp_data) - comp_data["config_value"] = self.nested_config.get("level1", {}).get("level2", "missing") - - new_transition = transition.copy() - new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data - return new_transition - - def get_config(self): - return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config} - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here - return features - - try: - step = ComplexConfigStep() - pipeline = RobotProcessor([step]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Load with nested override - loaded = RobotProcessor.from_pretrained( - tmp_dir, - overrides={"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}}, - ) - - # Test that override worked - transition = create_transition() - result = loaded(transition) - assert result[TransitionKey.COMPLEMENTARY_DATA]["config_value"] == "overridden" - finally: - ProcessorStepRegistry.unregister("complex_config_step") - - -def test_override_preserves_defaults(): - """Test that overrides only affect specified parameters.""" - step = MockStepWithNonSerializableParam(name="test", multiplier=2.0) - pipeline = RobotProcessor([step]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Override only one parameter - loaded = RobotProcessor.from_pretrained( - tmp_dir, - overrides={ - "MockStepWithNonSerializableParam": { - "multiplier": 5.0 # Only override multiplier - } - }, - ) - - # Check that name was preserved from saved config - loaded_step = loaded.steps[0] - assert loaded_step.name == "test" # Original value - assert loaded_step.multiplier == 5.0 # Overridden value - - -def test_override_type_validation(): - """Test that type errors in overrides are caught properly.""" - step = MockStepWithTensorState(learning_rate=0.01) - pipeline = RobotProcessor([step]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Try to override with wrong type - overrides = { - "MockStepWithTensorState": { - "window_size": "not_an_int" # Should be int - } - } - - with pytest.raises(ValueError, match="Failed to instantiate"): - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) - - -def test_override_with_callables(): - """Test overriding with callable objects.""" - - @ProcessorStepRegistry.register("callable_step") - @dataclass - class CallableStep: - name: str = "callable_step" - transform_fn: Any = None - - def __call__(self, transition: EnvTransition) -> EnvTransition: - obs = transition.get(TransitionKey.OBSERVATION) - if obs is not None and self.transform_fn is not None: - processed_obs = {} - for k, v in obs.items(): - processed_obs[k] = self.transform_fn(v) - - new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = processed_obs - return new_transition - return transition - - def get_config(self): - return {"name": self.name} - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here - return features - - try: - step = CallableStep() - pipeline = RobotProcessor([step]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Define a transform function - def double_values(x): - if isinstance(x, (int, float)): - return x * 2 - elif isinstance(x, torch.Tensor): - return x * 2 - return x - - # Load with callable override - loaded = RobotProcessor.from_pretrained( - tmp_dir, overrides={"callable_step": {"transform_fn": double_values}} - ) - - # Test it works - transition = create_transition(observation={"value": torch.tensor(5.0)}) - result = loaded(transition) - assert result[TransitionKey.OBSERVATION]["value"].item() == 10.0 - finally: - ProcessorStepRegistry.unregister("callable_step") - - -def test_override_multiple_same_class_warning(): - """Test behavior when multiple steps of same class exist.""" - step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0) - step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0) - pipeline = RobotProcessor([step1, step2]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Override affects all instances of the class - loaded = RobotProcessor.from_pretrained( - tmp_dir, overrides={"MockStepWithNonSerializableParam": {"multiplier": 10.0}} - ) - - # Both steps get the same override - assert loaded.steps[0].multiplier == 10.0 - assert loaded.steps[1].multiplier == 10.0 - - # But original names are preserved - assert loaded.steps[0].name == "step1" - assert loaded.steps[1].name == "step2" - - -def test_config_filename_special_characters(): - """Test config filenames with special characters are sanitized.""" - # Processor name with special characters - pipeline = RobotProcessor([MockStep()], name="My/Processor\\With:Special*Chars") - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Check that filename was sanitized - json_files = list(Path(tmp_dir).glob("*.json")) - assert len(json_files) == 1 - - # Should have replaced special chars with underscores - expected_name = "my_processor_with_special_chars.json" - assert json_files[0].name == expected_name - - -def test_state_file_naming_with_multiple_processors(): - """Test that state files are properly prefixed with pipeline names to avoid conflicts.""" - # Create two processors with state - step1 = MockStepWithTensorState(name="norm", window_size=5) - preprocessor = RobotProcessor([step1], name="PreProcessor") - - step2 = MockStepWithTensorState(name="norm", window_size=10) - postprocessor = RobotProcessor([step2], name="PostProcessor") - - # Process some data to create state - for i in range(3): - transition = create_transition(reward=float(i)) - preprocessor(transition) - postprocessor(transition) - - with tempfile.TemporaryDirectory() as tmp_dir: - # Save both processors to the same directory - preprocessor.save_pretrained(tmp_dir) - postprocessor.save_pretrained(tmp_dir) - - # Check that all files exist and are distinct - assert (Path(tmp_dir) / "preprocessor.json").exists() - assert (Path(tmp_dir) / "postprocessor.json").exists() - assert (Path(tmp_dir) / "preprocessor_step_0.safetensors").exists() - assert (Path(tmp_dir) / "postprocessor_step_0.safetensors").exists() - - # Load both back and verify they work correctly - loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json") - loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json") - - assert loaded_pre.name == "PreProcessor" - assert loaded_post.name == "PostProcessor" - assert loaded_pre.steps[0].window_size == 5 - assert loaded_post.steps[0].window_size == 10 - - -def test_override_with_device_strings(): - """Test overriding device parameters with string values.""" - - @ProcessorStepRegistry.register("device_aware_step") - @dataclass - class DeviceAwareStep: - device: str = "cpu" - - def __init__(self, device: str = "cpu"): - self.device = device - self.buffer = torch.zeros(10, device=device) - - def __call__(self, transition: EnvTransition) -> EnvTransition: - return transition - - def get_config(self): - return {"device": str(self.device)} - - def state_dict(self): - return {"buffer": self.buffer} - - def load_state_dict(self, state): - self.buffer = state["buffer"] - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here - return features - - try: - step = DeviceAwareStep(device="cpu") - pipeline = RobotProcessor([step]) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Override device - if torch.cuda.is_available(): - loaded = RobotProcessor.from_pretrained( - tmp_dir, overrides={"device_aware_step": {"device": "cuda:0"}} - ) - - loaded_step = loaded.steps[0] - assert loaded_step.device == "cuda:0" - # Note: buffer will still be on CPU from saved state - # until .to() is called on the processor - - finally: - ProcessorStepRegistry.unregister("device_aware_step") - - -def test_from_pretrained_nonexistent_path(): - """Test error handling when loading from non-existent sources.""" - from huggingface_hub.errors import HfHubHTTPError, HFValidationError - - # Test with an invalid repo ID (too many slashes) - caught by HF validation - with pytest.raises(HFValidationError): - RobotProcessor.from_pretrained("/path/that/does/not/exist") - - # Test with a non-existent but valid Hub repo format - with pytest.raises((FileNotFoundError, HfHubHTTPError)): - RobotProcessor.from_pretrained("nonexistent-user/nonexistent-repo") - - # Test with a local directory that exists but has no config files - with tempfile.TemporaryDirectory() as tmp_dir: - with pytest.raises(FileNotFoundError, match="No .json configuration files found"): - RobotProcessor.from_pretrained(tmp_dir) - - -def test_save_load_with_custom_converter_functions(): - """Test that custom to_transition and to_output functions are NOT saved.""" - - def custom_to_transition(batch): - # Custom conversion logic - return { - TransitionKey.OBSERVATION: batch.get("obs"), - TransitionKey.ACTION: batch.get("act"), - TransitionKey.REWARD: batch.get("rew", 0.0), - TransitionKey.DONE: batch.get("done", False), - TransitionKey.TRUNCATED: batch.get("truncated", False), - TransitionKey.INFO: {}, - TransitionKey.COMPLEMENTARY_DATA: {}, - } - - def custom_to_output(transition): - # Custom output format - return { - "obs": transition.get(TransitionKey.OBSERVATION), - "act": transition.get(TransitionKey.ACTION), - "rew": transition.get(TransitionKey.REWARD), - "done": transition.get(TransitionKey.DONE), - "truncated": transition.get(TransitionKey.TRUNCATED), - } - - # Create processor with custom converters - pipeline = RobotProcessor([MockStep()], to_transition=custom_to_transition, to_output=custom_to_output) - - with tempfile.TemporaryDirectory() as tmp_dir: - pipeline.save_pretrained(tmp_dir) - - # Load - should use default converters - loaded = RobotProcessor.from_pretrained(tmp_dir) - - # Verify it uses default converters by checking with standard batch format - batch = { - "observation.image": torch.randn(1, 3, 32, 32), - "action": torch.randn(1, 7), - "next.reward": torch.tensor([1.0]), - "next.done": torch.tensor([False]), - "next.truncated": torch.tensor([False]), - "info": {}, - } - - # Should work with standard format (wouldn't work with custom converter) - result = loaded(batch) - assert "observation.image" in result # Standard format preserved - - -class NonCompliantStep: - """Intentionally non-compliant: missing feature_contract.""" - - def __call__(self, transition: EnvTransition) -> EnvTransition: - return transition - - -def test_construction_rejects_step_without_feature_contract(): - with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"): - RobotProcessor([NonCompliantStep()]) - - -class NonCallableStep: - """Intentionally non-compliant: missing __call__.""" - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -def test_construction_rejects_step_without_call(): - with pytest.raises(TypeError, match=r"must define __call__"): - RobotProcessor([NonCallableStep()]) - - -@dataclass -class FeatureContractAddStep: - """Adds a PolicyFeature""" - - key: str = "a" - value: PolicyFeature = PolicyFeature(type=FeatureType.STATE, shape=(1,)) - - def __call__(self, transition: EnvTransition) -> EnvTransition: - return transition - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features[self.key] = self.value - return features - - -@dataclass -class FeatureContractMutateStep: - """Mutates a PolicyFeature""" - - key: str = "a" - fn: Callable[[PolicyFeature | None], PolicyFeature] = lambda x: x # noqa: E731 - - def __call__(self, transition: EnvTransition) -> EnvTransition: - return transition - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features[self.key] = self.fn(features.get(self.key)) - return features - - -@dataclass -class FeatureContractBadReturnStep: - """Returns a non-dict""" - - def __call__(self, transition: EnvTransition) -> EnvTransition: - return transition - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return ["not-a-dict"] - - -@dataclass -class FeatureContractRemoveStep: - """Removes a PolicyFeature""" - - key: str - - def __call__(self, transition: EnvTransition) -> EnvTransition: - return transition - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features.pop(self.key, None) - return features - - -def test_feature_contract_orders_and_merges(policy_feature_factory): - p = RobotProcessor( - [ - FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), - FeatureContractMutateStep("a", lambda v: PolicyFeature(type=v.type, shape=(3,))), - FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))), - ] - ) - out = p.feature_contract({}) - - assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,) - assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,) - assert_contract_is_typed(out) - - -def test_feature_contract_respects_initial_without_mutation(policy_feature_factory): - initial = { - "seed": policy_feature_factory(FeatureType.STATE, (7,)), - "nested": policy_feature_factory(FeatureType.ENV, (0,)), - } - p = RobotProcessor( - [ - FeatureContractMutateStep("seed", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 1,))), - FeatureContractMutateStep( - "nested", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 5,)) - ), - ] - ) - out = p.feature_contract(initial_features=initial) - - assert out["seed"].shape == (8,) - assert out["nested"].shape == (5,) - # Initial dict must be preserved - assert initial["seed"].shape == (7,) - assert initial["nested"].shape == (0,) - - assert_contract_is_typed(out) - - -def test_feature_contract_type_error_on_bad_step(): - p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()]) - with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"): - _ = p.feature_contract({}) - - -def test_feature_contract_execution_order_tracking(): - class Track: - def __init__(self, label): - self.label = label - - def __call__(self, transition: EnvTransition) -> EnvTransition: - return transition - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - code = {"A": 1, "B": 2, "C": 3}[self.label] - pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=())) - features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,)) - return features - - out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({}) - assert out["order"].shape == (1, 2, 3) - - -def test_feature_contract_remove_key(policy_feature_factory): - p = RobotProcessor( - [ - FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), - FeatureContractRemoveStep("a"), - ] - ) - out = p.feature_contract({}) - assert "a" not in out - - -def test_feature_contract_remove_from_initial(policy_feature_factory): - initial = { - "keep": policy_feature_factory(FeatureType.STATE, (1,)), - "drop": policy_feature_factory(FeatureType.STATE, (1,)), - } - p = RobotProcessor([FeatureContractRemoveStep("drop")]) - out = p.feature_contract(initial_features=initial) - assert "drop" not in out and out["keep"] == initial["keep"] diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py deleted file mode 100644 index 229d57f..0000000 --- a/tests/processor/test_rename_processor.py +++ /dev/null @@ -1,467 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import tempfile -from pathlib import Path - -import numpy as np -import torch - -from lerobot.configs.types import FeatureType -from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey -from tests.conftest import assert_contract_is_typed - - -def create_transition( - observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info, - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } - - -def test_basic_renaming(): - """Test basic key renaming functionality.""" - rename_map = { - "old_key1": "new_key1", - "old_key2": "new_key2", - } - processor = RenameProcessor(rename_map=rename_map) - - observation = { - "old_key1": torch.tensor([1.0, 2.0]), - "old_key2": np.array([3.0, 4.0]), - "unchanged_key": "keep_me", - } - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check renamed keys - assert "new_key1" in processed_obs - assert "new_key2" in processed_obs - assert "old_key1" not in processed_obs - assert "old_key2" not in processed_obs - - # Check values are preserved - torch.testing.assert_close(processed_obs["new_key1"], torch.tensor([1.0, 2.0])) - np.testing.assert_array_equal(processed_obs["new_key2"], np.array([3.0, 4.0])) - - # Check unchanged key is preserved - assert processed_obs["unchanged_key"] == "keep_me" - - -def test_empty_rename_map(): - """Test processor with empty rename map (should pass through unchanged).""" - processor = RenameProcessor(rename_map={}) - - observation = { - "key1": torch.tensor([1.0]), - "key2": "value2", - } - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # All keys should be unchanged - assert processed_obs.keys() == observation.keys() - torch.testing.assert_close(processed_obs["key1"], observation["key1"]) - assert processed_obs["key2"] == observation["key2"] - - -def test_none_observation(): - """Test processor with None observation.""" - processor = RenameProcessor(rename_map={"old": "new"}) - - transition = create_transition() - result = processor(transition) - - # Should return transition unchanged - assert result == transition - - -def test_overlapping_rename(): - """Test renaming when new names might conflict.""" - rename_map = { - "a": "b", - "b": "c", # This creates a potential conflict - } - processor = RenameProcessor(rename_map=rename_map) - - observation = { - "a": 1, - "b": 2, - "x": 3, - } - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check that renaming happens correctly - assert "a" not in processed_obs - assert processed_obs["b"] == 1 # 'a' renamed to 'b' - assert processed_obs["c"] == 2 # original 'b' renamed to 'c' - assert processed_obs["x"] == 3 - - -def test_partial_rename(): - """Test renaming only some keys.""" - rename_map = { - "observation.state": "observation.proprio_state", - "pixels": "observation.image", - } - processor = RenameProcessor(rename_map=rename_map) - - observation = { - "observation.state": torch.randn(10), - "pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8), - "reward": 1.0, - "info": {"episode": 1}, - } - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check renamed keys - assert "observation.proprio_state" in processed_obs - assert "observation.image" in processed_obs - assert "observation.state" not in processed_obs - assert "pixels" not in processed_obs - - # Check unchanged keys - assert processed_obs["reward"] == 1.0 - assert processed_obs["info"] == {"episode": 1} - - -def test_get_config(): - """Test configuration serialization.""" - rename_map = { - "old1": "new1", - "old2": "new2", - } - processor = RenameProcessor(rename_map=rename_map) - - config = processor.get_config() - assert config == {"rename_map": rename_map} - - -def test_state_dict(): - """Test state dict (should be empty for RenameProcessor).""" - processor = RenameProcessor(rename_map={"old": "new"}) - - state = processor.state_dict() - assert state == {} - - # Load state dict should work even with empty dict - processor.load_state_dict({}) - - -def test_integration_with_robot_processor(): - """Test integration with RobotProcessor pipeline.""" - rename_map = { - "agent_pos": "observation.state", - "pixels": "observation.image", - } - rename_processor = RenameProcessor(rename_map=rename_map) - - pipeline = RobotProcessor([rename_processor]) - - observation = { - "agent_pos": np.array([1.0, 2.0, 3.0]), - "pixels": np.zeros((32, 32, 3), dtype=np.uint8), - "other_data": "preserve_me", - } - transition = create_transition( - observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={} - ) - - result = pipeline(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check renaming worked through pipeline - assert "observation.state" in processed_obs - assert "observation.image" in processed_obs - assert "agent_pos" not in processed_obs - assert "pixels" not in processed_obs - assert processed_obs["other_data"] == "preserve_me" - - # Check other transition elements unchanged - assert result[TransitionKey.REWARD] == 0.5 - assert result[TransitionKey.DONE] is False - - -def test_save_and_load_pretrained(): - """Test saving and loading processor with RobotProcessor.""" - rename_map = { - "old_state": "observation.state", - "old_image": "observation.image", - } - processor = RenameProcessor(rename_map=rename_map) - pipeline = RobotProcessor([processor], name="TestRenameProcessor") - - with tempfile.TemporaryDirectory() as tmp_dir: - # Save pipeline - pipeline.save_pretrained(tmp_dir) - - # Check files were created - config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor" - assert config_path.exists() - - # No state files should be created for RenameProcessor - state_files = list(Path(tmp_dir).glob("*.safetensors")) - assert len(state_files) == 0 - - # Load pipeline - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) - - assert loaded_pipeline.name == "TestRenameProcessor" - assert len(loaded_pipeline) == 1 - - # Check that loaded processor works correctly - loaded_processor = loaded_pipeline.steps[0] - assert isinstance(loaded_processor, RenameProcessor) - assert loaded_processor.rename_map == rename_map - - # Test functionality after loading - observation = {"old_state": [1, 2, 3], "old_image": "image_data"} - transition = create_transition(observation=observation) - - result = loaded_pipeline(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - assert "observation.state" in processed_obs - assert "observation.image" in processed_obs - assert processed_obs["observation.state"] == [1, 2, 3] - assert processed_obs["observation.image"] == "image_data" - - -def test_registry_functionality(): - """Test that RenameProcessor is properly registered.""" - # Check that it's registered - assert "rename_processor" in ProcessorStepRegistry.list() - - # Get from registry - retrieved_class = ProcessorStepRegistry.get("rename_processor") - assert retrieved_class is RenameProcessor - - # Create instance from registry - instance = retrieved_class(rename_map={"old": "new"}) - assert isinstance(instance, RenameProcessor) - assert instance.rename_map == {"old": "new"} - - -def test_registry_based_save_load(): - """Test save/load using registry name instead of module path.""" - processor = RenameProcessor(rename_map={"key1": "renamed_key1"}) - pipeline = RobotProcessor([processor]) - - with tempfile.TemporaryDirectory() as tmp_dir: - # Save and load - pipeline.save_pretrained(tmp_dir) - - # Verify config uses registry name - import json - - with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor" - config = json.load(f) - - assert "registry_name" in config["steps"][0] - assert config["steps"][0]["registry_name"] == "rename_processor" - assert "class" not in config["steps"][0] # Should use registry, not module path - - # Load should work - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) - loaded_processor = loaded_pipeline.steps[0] - assert isinstance(loaded_processor, RenameProcessor) - assert loaded_processor.rename_map == {"key1": "renamed_key1"} - - -def test_chained_rename_processors(): - """Test multiple RenameProcessors in a pipeline.""" - # First processor: rename raw keys to intermediate format - processor1 = RenameProcessor( - rename_map={ - "pos": "agent_position", - "img": "camera_image", - } - ) - - # Second processor: rename to final format - processor2 = RenameProcessor( - rename_map={ - "agent_position": "observation.state", - "camera_image": "observation.image", - } - ) - - pipeline = RobotProcessor([processor1, processor2]) - - observation = { - "pos": np.array([1.0, 2.0]), - "img": "image_data", - "extra": "keep_me", - } - transition = create_transition(observation=observation) - - # Step through to see intermediate results - results = list(pipeline.step_through(transition)) - - # After first processor - assert "agent_position" in results[1][TransitionKey.OBSERVATION] - assert "camera_image" in results[1][TransitionKey.OBSERVATION] - - # After second processor - final_obs = results[2][TransitionKey.OBSERVATION] - assert "observation.state" in final_obs - assert "observation.image" in final_obs - assert final_obs["extra"] == "keep_me" - - # Original keys should be gone - assert "pos" not in final_obs - assert "img" not in final_obs - assert "agent_position" not in final_obs - assert "camera_image" not in final_obs - - -def test_nested_observation_rename(): - """Test renaming with nested observation structures.""" - rename_map = { - "observation.images.left": "observation.camera.left_view", - "observation.images.right": "observation.camera.right_view", - "observation.proprio": "observation.proprioception", - } - processor = RenameProcessor(rename_map=rename_map) - - observation = { - "observation.images.left": torch.randn(3, 64, 64), - "observation.images.right": torch.randn(3, 64, 64), - "observation.proprio": torch.randn(7), - "observation.gripper": torch.tensor([0.0]), # Not renamed - } - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check renames - assert "observation.camera.left_view" in processed_obs - assert "observation.camera.right_view" in processed_obs - assert "observation.proprioception" in processed_obs - - # Check unchanged key - assert "observation.gripper" in processed_obs - - # Check old keys removed - assert "observation.images.left" not in processed_obs - assert "observation.images.right" not in processed_obs - assert "observation.proprio" not in processed_obs - - -def test_value_types_preserved(): - """Test that various value types are preserved during renaming.""" - rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"} - processor = RenameProcessor(rename_map=rename_map) - - tensor_value = torch.randn(3, 3) - array_value = np.random.rand(2, 2) - - observation = { - "old_tensor": tensor_value, - "old_array": array_value, - "old_scalar": 42, - "old_string": "hello", - "old_dict": {"nested": "value"}, - "old_list": [1, 2, 3], - } - transition = create_transition(observation=observation) - - result = processor(transition) - processed_obs = result[TransitionKey.OBSERVATION] - - # Check that values and types are preserved - assert torch.equal(processed_obs["new_tensor"], tensor_value) - assert np.array_equal(processed_obs["new_array"], array_value) - assert processed_obs["new_scalar"] == 42 - assert processed_obs["old_string"] == "hello" - assert processed_obs["old_dict"] == {"nested": "value"} - assert processed_obs["old_list"] == [1, 2, 3] - - -def test_feature_contract_basic_renaming(policy_feature_factory): - processor = RenameProcessor(rename_map={"a": "x", "b": "y"}) - features = { - "a": policy_feature_factory(FeatureType.STATE, (2,)), - "b": policy_feature_factory(FeatureType.ACTION, (3,)), - "c": policy_feature_factory(FeatureType.ENV, (1,)), - } - - out = processor.feature_contract(features.copy()) - - # Values preserved and typed - assert out["x"] == features["a"] - assert out["y"] == features["b"] - assert out["c"] == features["c"] - - assert_contract_is_typed(out) - # Input not mutated - assert set(features) == {"a", "b", "c"} - - -def test_feature_contract_overlapping_keys(policy_feature_factory): - # Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c' - processor = RenameProcessor(rename_map={"a": "b", "b": "c"}) - features = { - "a": policy_feature_factory(FeatureType.STATE, (1,)), - "b": policy_feature_factory(FeatureType.STATE, (2,)), - } - out = processor.feature_contract(features) - - assert set(out) == {"b", "c"} - assert out["b"] == features["a"] # 'a' renamed to'b' - assert out["c"] == features["b"] # 'b' renamed to 'c' - assert_contract_is_typed(out) - - -def test_feature_contract_chained_processors(policy_feature_factory): - # Chain two rename processors at the contract level - processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"}) - processor2 = RenameProcessor( - rename_map={"agent_position": "observation.state", "camera_image": "observation.image"} - ) - pipeline = RobotProcessor([processor1, processor2]) - - spec = { - "pos": policy_feature_factory(FeatureType.STATE, (7,)), - "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "extra": policy_feature_factory(FeatureType.ENV, (1,)), - } - out = pipeline.feature_contract(initial_features=spec) - - assert set(out) == {"observation.state", "observation.image", "extra"} - assert out["observation.state"] == spec["pos"] - assert out["observation.image"] == spec["img"] - assert out["extra"] == spec["extra"] - assert_contract_is_typed(out) diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py deleted file mode 100644 index f078b46..0000000 --- a/tests/rl/test_actor.py +++ /dev/null @@ -1,208 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from concurrent import futures -from unittest.mock import patch - -import pytest -import torch -from torch.multiprocessing import Event, Queue - -from lerobot.utils.transition import Transition -from tests.utils import require_package - - -def create_learner_service_stub(): - import grpc - - from lerobot.transport import services_pb2, services_pb2_grpc - - class MockLearnerService(services_pb2_grpc.LearnerServiceServicer): - def __init__(self): - self.ready_call_count = 0 - self.should_fail = False - - def Ready(self, request, context): # noqa: N802 - self.ready_call_count += 1 - if self.should_fail: - context.set_code(grpc.StatusCode.UNAVAILABLE) - context.set_details("Service unavailable") - raise grpc.RpcError("Service unavailable") - return services_pb2.Empty() - - """Fixture to start a LearnerService gRPC server and provide a connected stub.""" - - servicer = MockLearnerService() - - # Create a gRPC server and add our servicer to it. - server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) - services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server) - port = server.add_insecure_port("[::]:0") # bind to a free port chosen by OS - server.start() # start the server (non-blocking call):contentReference[oaicite:1]{index=1} - - # Create a client channel and stub connected to the server's port. - channel = grpc.insecure_channel(f"localhost:{port}") - return services_pb2_grpc.LearnerServiceStub(channel), servicer, channel, server - - -def close_service_stub(channel, server): - channel.close() - server.stop(None) - - -@require_package("grpc") -def test_establish_learner_connection_success(): - from lerobot.scripts.rl.actor import establish_learner_connection - - """Test successful connection establishment.""" - stub, _servicer, channel, server = create_learner_service_stub() - - shutdown_event = Event() - - # Test successful connection - result = establish_learner_connection(stub, shutdown_event, attempts=5) - - assert result is True - - close_service_stub(channel, server) - - -@require_package("grpc") -def test_establish_learner_connection_failure(): - from lerobot.scripts.rl.actor import establish_learner_connection - - """Test connection failure.""" - stub, servicer, channel, server = create_learner_service_stub() - servicer.should_fail = True - - shutdown_event = Event() - - # Test failed connection - with patch("time.sleep"): # Speed up the test - result = establish_learner_connection(stub, shutdown_event, attempts=2) - - assert result is False - - close_service_stub(channel, server) - - -@require_package("grpc") -def test_push_transitions_to_transport_queue(): - from lerobot.scripts.rl.actor import push_transitions_to_transport_queue - from lerobot.transport.utils import bytes_to_transitions - from tests.transport.test_transport_utils import assert_transitions_equal - - """Test pushing transitions to transport queue.""" - # Create mock transitions - transitions = [] - for i in range(3): - transition = Transition( - state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, - action=torch.randn(5), - reward=torch.tensor(1.0 + i), - done=torch.tensor(False), - truncated=torch.tensor(False), - next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, - complementary_info={"step": torch.tensor(i)}, - ) - transitions.append(transition) - - transitions_queue = Queue() - - # Test pushing transitions - push_transitions_to_transport_queue(transitions, transitions_queue) - - # Verify the data can be retrieved - serialized_data = transitions_queue.get() - assert isinstance(serialized_data, bytes) - deserialized_transitions = bytes_to_transitions(serialized_data) - assert len(deserialized_transitions) == len(transitions) - for i, deserialized_transition in enumerate(deserialized_transitions): - assert_transitions_equal(deserialized_transition, transitions[i]) - - -@require_package("grpc") -@pytest.mark.timeout(3) # force cross-platform watchdog -def test_transitions_stream(): - from lerobot.scripts.rl.actor import transitions_stream - - """Test transitions stream functionality.""" - shutdown_event = Event() - transitions_queue = Queue() - - # Add test data to queue - test_data = [b"transition_data_1", b"transition_data_2", b"transition_data_3"] - for data in test_data: - transitions_queue.put(data) - - # Collect streamed data - streamed_data = [] - stream_generator = transitions_stream(shutdown_event, transitions_queue, 0.1) - - # Process a few items - for i, message in enumerate(stream_generator): - streamed_data.append(message) - if i >= len(test_data) - 1: - shutdown_event.set() - break - - # Verify we got messages - assert len(streamed_data) == len(test_data) - assert streamed_data[0].data == b"transition_data_1" - assert streamed_data[1].data == b"transition_data_2" - assert streamed_data[2].data == b"transition_data_3" - - -@require_package("grpc") -@pytest.mark.timeout(3) # force cross-platform watchdog -def test_interactions_stream(): - from lerobot.scripts.rl.actor import interactions_stream - from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes - - """Test interactions stream functionality.""" - shutdown_event = Event() - interactions_queue = Queue() - - # Create test interaction data (similar structure to what would be sent) - test_interactions = [ - {"episode_reward": 10.5, "step": 1, "policy_fps": 30.2}, - {"episode_reward": 15.2, "step": 2, "policy_fps": 28.7}, - {"episode_reward": 8.7, "step": 3, "policy_fps": 29.1}, - ] - - # Serialize the interaction data as it would be in practice - test_data = [ - interactions_queue.put(python_object_to_bytes(interaction)) for interaction in test_interactions - ] - - # Collect streamed data - streamed_data = [] - stream_generator = interactions_stream(shutdown_event, interactions_queue, 0.1) - - # Process the items - for i, message in enumerate(stream_generator): - streamed_data.append(message) - if i >= len(test_data) - 1: - shutdown_event.set() - break - - # Verify we got messages - assert len(streamed_data) == len(test_data) - - # Verify the messages can be deserialized back to original data - for i, message in enumerate(streamed_data): - deserialized_interaction = bytes_to_python_object(message.data) - assert deserialized_interaction == test_interactions[i] diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py deleted file mode 100644 index b2a7a5d..0000000 --- a/tests/rl/test_actor_learner.py +++ /dev/null @@ -1,297 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import socket -import threading -import time - -import pytest -import torch -from torch.multiprocessing import Event, Queue - -from lerobot.configs.train import TrainRLServerPipelineConfig -from lerobot.policies.sac.configuration_sac import SACConfig -from lerobot.utils.transition import Transition -from tests.utils import require_package - - -def create_test_transitions(count: int = 3) -> list[Transition]: - """Create test transitions for integration testing.""" - transitions = [] - for i in range(count): - transition = Transition( - state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, - action=torch.randn(5), - reward=torch.tensor(1.0 + i), - done=torch.tensor(i == count - 1), # Last transition is done - truncated=torch.tensor(False), - next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, - complementary_info={"step": torch.tensor(i), "episode_id": i // 2}, - ) - transitions.append(transition) - return transitions - - -def create_test_interactions(count: int = 3) -> list[dict]: - """Create test interactions for integration testing.""" - interactions = [] - for i in range(count): - interaction = { - "episode_reward": 10.0 + i * 5, - "step": i * 100, - "policy_fps": 30.0 + i, - "intervention_rate": 0.1 * i, - "episode_length": 200 + i * 50, - } - interactions.append(interaction) - return interactions - - -def find_free_port(): - """Finds a free port on the local machine.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) # Bind to port 0 to let the OS choose a free port - s.listen(1) - port = s.getsockname()[1] - return port - - -@pytest.fixture -def cfg(): - cfg = TrainRLServerPipelineConfig() - - port = find_free_port() - - policy_cfg = SACConfig() - policy_cfg.actor_learner_config.learner_host = "127.0.0.1" - policy_cfg.actor_learner_config.learner_port = port - policy_cfg.concurrency.actor = "threads" - policy_cfg.concurrency.learner = "threads" - policy_cfg.actor_learner_config.queue_get_timeout = 0.1 - - cfg.policy = policy_cfg - - return cfg - - -@require_package("grpc") -@pytest.mark.timeout(10) # force cross-platform watchdog -def test_end_to_end_transitions_flow(cfg): - from lerobot.scripts.rl.actor import ( - establish_learner_connection, - learner_service_client, - push_transitions_to_transport_queue, - send_transitions, - ) - from lerobot.scripts.rl.learner import start_learner - from lerobot.transport.utils import bytes_to_transitions - from tests.transport.test_transport_utils import assert_transitions_equal - - """Test complete transitions flow from actor to learner.""" - transitions_actor_queue = Queue() - transitions_learner_queue = Queue() - - interactions_queue = Queue() - parameters_queue = Queue() - shutdown_event = Event() - - learner_thread = threading.Thread( - target=start_learner, - args=(parameters_queue, transitions_learner_queue, interactions_queue, shutdown_event, cfg), - ) - learner_thread.start() - - policy_cfg = cfg.policy - learner_client, channel = learner_service_client( - host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port - ) - - assert establish_learner_connection(learner_client, shutdown_event, attempts=5) - - send_transitions_thread = threading.Thread( - target=send_transitions, args=(cfg, transitions_actor_queue, shutdown_event, learner_client, channel) - ) - send_transitions_thread.start() - - input_transitions = create_test_transitions(count=5) - - push_transitions_to_transport_queue(input_transitions, transitions_actor_queue) - - # Wait for learner to start - time.sleep(0.1) - - shutdown_event.set() - - # Wait for learner to receive transitions - learner_thread.join() - send_transitions_thread.join() - channel.close() - - received_transitions = [] - while not transitions_learner_queue.empty(): - received_transitions.extend(bytes_to_transitions(transitions_learner_queue.get())) - - assert len(received_transitions) == len(input_transitions) - for i, transition in enumerate(received_transitions): - assert_transitions_equal(transition, input_transitions[i]) - - -@require_package("grpc") -@pytest.mark.timeout(10) -def test_end_to_end_interactions_flow(cfg): - from lerobot.scripts.rl.actor import ( - establish_learner_connection, - learner_service_client, - send_interactions, - ) - from lerobot.scripts.rl.learner import start_learner - from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes - - """Test complete interactions flow from actor to learner.""" - # Queues for actor-learner communication - interactions_actor_queue = Queue() - interactions_learner_queue = Queue() - - # Other queues required by the learner - parameters_queue = Queue() - transitions_learner_queue = Queue() - - shutdown_event = Event() - - # Start the learner in a separate thread - learner_thread = threading.Thread( - target=start_learner, - args=(parameters_queue, transitions_learner_queue, interactions_learner_queue, shutdown_event, cfg), - ) - learner_thread.start() - - # Establish connection from actor to learner - policy_cfg = cfg.policy - learner_client, channel = learner_service_client( - host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port - ) - - assert establish_learner_connection(learner_client, shutdown_event, attempts=5) - - # Start the actor's interaction sending process in a separate thread - send_interactions_thread = threading.Thread( - target=send_interactions, - args=(cfg, interactions_actor_queue, shutdown_event, learner_client, channel), - ) - send_interactions_thread.start() - - # Create and push test interactions to the actor's queue - input_interactions = create_test_interactions(count=5) - for interaction in input_interactions: - interactions_actor_queue.put(python_object_to_bytes(interaction)) - - # Wait for the communication to happen - time.sleep(0.1) - - # Signal shutdown and wait for threads to complete - shutdown_event.set() - learner_thread.join() - send_interactions_thread.join() - channel.close() - - # Verify that the learner received the interactions - received_interactions = [] - while not interactions_learner_queue.empty(): - received_interactions.append(bytes_to_python_object(interactions_learner_queue.get())) - - assert len(received_interactions) == len(input_interactions) - - # Sort by a unique key to handle potential reordering in queues - received_interactions.sort(key=lambda x: x["step"]) - input_interactions.sort(key=lambda x: x["step"]) - - for received, expected in zip(received_interactions, input_interactions, strict=False): - assert received == expected - - -@require_package("grpc") -@pytest.mark.parametrize("data_size", ["small", "large"]) -@pytest.mark.timeout(10) -def test_end_to_end_parameters_flow(cfg, data_size): - from lerobot.scripts.rl.actor import establish_learner_connection, learner_service_client, receive_policy - from lerobot.scripts.rl.learner import start_learner - from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes - - """Test complete parameter flow from learner to actor, with small and large data.""" - # Actor's local queue to receive params - parameters_actor_queue = Queue() - # Learner's queue to send params from - parameters_learner_queue = Queue() - - # Other queues required by the learner - transitions_learner_queue = Queue() - interactions_learner_queue = Queue() - - shutdown_event = Event() - - # Start the learner in a separate thread - learner_thread = threading.Thread( - target=start_learner, - args=( - parameters_learner_queue, - transitions_learner_queue, - interactions_learner_queue, - shutdown_event, - cfg, - ), - ) - learner_thread.start() - - # Establish connection from actor to learner - policy_cfg = cfg.policy - learner_client, channel = learner_service_client( - host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port - ) - - assert establish_learner_connection(learner_client, shutdown_event, attempts=5) - - # Start the actor's parameter receiving process in a separate thread - receive_params_thread = threading.Thread( - target=receive_policy, - args=(cfg, parameters_actor_queue, shutdown_event, learner_client, channel), - ) - receive_params_thread.start() - - # Create test parameters based on parametrization - if data_size == "small": - input_params = {"layer.weight": torch.randn(128, 64)} - else: # "large" - # CHUNK_SIZE is 2MB, so this tensor (4MB) will force chunking - input_params = {"large_layer.weight": torch.randn(1024, 1024)} - - # Simulate learner having new parameters to send - parameters_learner_queue.put(state_to_bytes(input_params)) - - # Wait for the actor to receive the parameters - time.sleep(0.1) - - # Signal shutdown and wait for threads to complete - shutdown_event.set() - learner_thread.join() - receive_params_thread.join() - channel.close() - - # Verify that the actor received the parameters correctly - received_params = bytes_to_state_dict(parameters_actor_queue.get()) - - assert received_params.keys() == input_params.keys() - for key in input_params: - assert torch.allclose(received_params[key], input_params[key]) diff --git a/tests/rl/test_learner_service.py b/tests/rl/test_learner_service.py deleted file mode 100644 index f5e1e8d..0000000 --- a/tests/rl/test_learner_service.py +++ /dev/null @@ -1,374 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import threading -import time -from concurrent import futures -from multiprocessing import Event, Queue - -import pytest - -from tests.utils import require_package # our gRPC servicer class - - -@pytest.fixture(scope="function") -def learner_service_stub(): - shutdown_event = Event() - parameters_queue = Queue() - transitions_queue = Queue() - interactions_queue = Queue() - seconds_between_pushes = 1 - client, channel, server = create_learner_service_stub( - shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes - ) - - yield client # provide the stub to the test function - - close_learner_service_stub(channel, server) - - -@require_package("grpc") -def create_learner_service_stub( - shutdown_event: Event, - parameters_queue: Queue, - transitions_queue: Queue, - interactions_queue: Queue, - seconds_between_pushes: int, - queue_get_timeout: float = 0.1, -): - import grpc - - from lerobot.scripts.rl.learner_service import LearnerService - from lerobot.transport import services_pb2_grpc # generated from .proto - - """Fixture to start a LearnerService gRPC server and provide a connected stub.""" - - servicer = LearnerService( - shutdown_event=shutdown_event, - parameters_queue=parameters_queue, - seconds_between_pushes=seconds_between_pushes, - transition_queue=transitions_queue, - interaction_message_queue=interactions_queue, - queue_get_timeout=queue_get_timeout, - ) - - # Create a gRPC server and add our servicer to it. - server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) - services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server) - port = server.add_insecure_port("[::]:0") # bind to a free port chosen by OS - server.start() # start the server (non-blocking call):contentReference[oaicite:1]{index=1} - - # Create a client channel and stub connected to the server's port. - channel = grpc.insecure_channel(f"localhost:{port}") - return services_pb2_grpc.LearnerServiceStub(channel), channel, server - - -@require_package("grpc") -def close_learner_service_stub(channel, server): - channel.close() - server.stop(None) - - -@pytest.mark.timeout(3) # force cross-platform watchdog -def test_ready_method(learner_service_stub): - from lerobot.transport import services_pb2 - - """Test the ready method of the UserService.""" - request = services_pb2.Empty() - response = learner_service_stub.Ready(request) - assert response == services_pb2.Empty() - - -@require_package("grpc") -@pytest.mark.timeout(3) # force cross-platform watchdog -def test_send_interactions(): - from lerobot.transport import services_pb2 - - shutdown_event = Event() - - parameters_queue = Queue() - transitions_queue = Queue() - interactions_queue = Queue() - seconds_between_pushes = 1 - client, channel, server = create_learner_service_stub( - shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes - ) - - list_of_interaction_messages = [ - services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"1"), - services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"2"), - services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"3"), - services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"4"), - services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"5"), - services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"6"), - services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"7"), - services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"8"), - ] - - def mock_intercations_stream(): - yield from list_of_interaction_messages - - return services_pb2.Empty() - - response = client.SendInteractions(mock_intercations_stream()) - assert response == services_pb2.Empty() - - close_learner_service_stub(channel, server) - - # Extract the data from the interactions queue - interactions = [] - while not interactions_queue.empty(): - interactions.append(interactions_queue.get()) - - assert interactions == [b"123", b"4", b"5", b"678"] - - -@require_package("grpc") -@pytest.mark.timeout(3) # force cross-platform watchdog -def test_send_transitions(): - from lerobot.transport import services_pb2 - - """Test the SendTransitions method with various transition data.""" - shutdown_event = Event() - parameters_queue = Queue() - transitions_queue = Queue() - interactions_queue = Queue() - seconds_between_pushes = 1 - - client, channel, server = create_learner_service_stub( - shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes - ) - - # Create test transition messages - list_of_transition_messages = [ - services_pb2.Transition( - transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"transition_1" - ), - services_pb2.Transition( - transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"transition_2" - ), - services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"transition_3"), - services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"batch_1"), - services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"batch_2"), - ] - - def mock_transitions_stream(): - yield from list_of_transition_messages - - response = client.SendTransitions(mock_transitions_stream()) - assert response == services_pb2.Empty() - - close_learner_service_stub(channel, server) - - # Extract the data from the transitions queue - transitions = [] - while not transitions_queue.empty(): - transitions.append(transitions_queue.get()) - - # Should have assembled the chunked data - assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"] - - -@require_package("grpc") -@pytest.mark.timeout(3) # force cross-platform watchdog -def test_send_transitions_empty_stream(): - from lerobot.transport import services_pb2 - - """Test SendTransitions with empty stream.""" - shutdown_event = Event() - parameters_queue = Queue() - transitions_queue = Queue() - interactions_queue = Queue() - seconds_between_pushes = 1 - - client, channel, server = create_learner_service_stub( - shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes - ) - - def empty_stream(): - return iter([]) - - response = client.SendTransitions(empty_stream()) - assert response == services_pb2.Empty() - - close_learner_service_stub(channel, server) - - # Queue should remain empty - assert transitions_queue.empty() - - -@require_package("grpc") -@pytest.mark.timeout(10) # force cross-platform watchdog -def test_stream_parameters(): - import time - - from lerobot.transport import services_pb2 - - """Test the StreamParameters method.""" - shutdown_event = Event() - parameters_queue = Queue() - transitions_queue = Queue() - interactions_queue = Queue() - seconds_between_pushes = 0.2 # Short delay for testing - - client, channel, server = create_learner_service_stub( - shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes - ) - - # Add test parameters to the queue - test_params = [b"param_batch_1", b"param_batch_2"] - for param in test_params: - parameters_queue.put(param) - - # Start streaming parameters - request = services_pb2.Empty() - stream = client.StreamParameters(request) - - # Collect streamed parameters and timestamps - received_params = [] - timestamps = [] - - for response in stream: - received_params.append(response.data) - timestamps.append(time.time()) - - # We should receive one last item - break - - parameters_queue.put(b"param_batch_3") - - for response in stream: - received_params.append(response.data) - timestamps.append(time.time()) - - # We should receive only one item - break - - shutdown_event.set() - close_learner_service_stub(channel, server) - - assert received_params == [b"param_batch_2", b"param_batch_3"] - - # Check the time difference between the two sends - time_diff = timestamps[1] - timestamps[0] - # Check if the time difference is close to the expected push frequency - assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1) - - -@require_package("grpc") -@pytest.mark.timeout(3) # force cross-platform watchdog -def test_stream_parameters_with_shutdown(): - from lerobot.transport import services_pb2 - - """Test StreamParameters handles shutdown gracefully.""" - shutdown_event = Event() - parameters_queue = Queue() - transitions_queue = Queue() - interactions_queue = Queue() - seconds_between_pushes = 0.1 - queue_get_timeout = 0.001 - - client, channel, server = create_learner_service_stub( - shutdown_event, - parameters_queue, - transitions_queue, - interactions_queue, - seconds_between_pushes, - queue_get_timeout=queue_get_timeout, - ) - - test_params = [b"param_batch_1", b"stop", b"param_batch_3", b"param_batch_4"] - - # create a thread that will put the parameters in the queue - def producer(): - for param in test_params: - parameters_queue.put(param) - time.sleep(0.1) - - producer_thread = threading.Thread(target=producer) - producer_thread.start() - - # Start streaming - request = services_pb2.Empty() - stream = client.StreamParameters(request) - - # Collect streamed parameters - received_params = [] - - for response in stream: - received_params.append(response.data) - - if response.data == b"stop": - shutdown_event.set() - - producer_thread.join() - close_learner_service_stub(channel, server) - - assert received_params == [b"param_batch_1", b"stop"] - - -@require_package("grpc") -@pytest.mark.timeout(3) # force cross-platform watchdog -def test_stream_parameters_waits_and_retries_on_empty_queue(): - import threading - import time - - from lerobot.transport import services_pb2 - - """Test that StreamParameters waits and retries when the queue is empty.""" - shutdown_event = Event() - parameters_queue = Queue() - transitions_queue = Queue() - interactions_queue = Queue() - seconds_between_pushes = 0.05 - queue_get_timeout = 0.01 - - client, channel, server = create_learner_service_stub( - shutdown_event, - parameters_queue, - transitions_queue, - interactions_queue, - seconds_between_pushes, - queue_get_timeout=queue_get_timeout, - ) - - request = services_pb2.Empty() - stream = client.StreamParameters(request) - - received_params = [] - - def producer(): - # Let the consumer start and find an empty queue. - # It will wait `seconds_between_pushes` (0.05s), then `get` will timeout after `queue_get_timeout` (0.01s). - # Total time for the first empty loop is > 0.06s. We wait a bit longer to be safe. - time.sleep(0.06) - parameters_queue.put(b"param_after_wait") - time.sleep(0.05) - parameters_queue.put(b"param_after_wait_2") - - producer_thread = threading.Thread(target=producer) - producer_thread.start() - - # The consumer will block here until the producer sends an item. - for response in stream: - received_params.append(response.data) - if response.data == b"param_after_wait_2": - break # We only need one item for this test. - - shutdown_event.set() - producer_thread.join() - close_learner_service_stub(channel, server) - - assert received_params == [b"param_after_wait", b"param_after_wait_2"] diff --git a/tests/robots/test_so100_follower.py b/tests/robots/test_so100_follower.py deleted file mode 100644 index d76b959..0000000 --- a/tests/robots/test_so100_follower.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from contextlib import contextmanager -from unittest.mock import MagicMock, patch - -import pytest - -from lerobot.robots.so100_follower import ( - SO100Follower, - SO100FollowerConfig, -) - - -def _make_bus_mock() -> MagicMock: - """Return a bus mock with just the attributes used by the robot.""" - bus = MagicMock(name="FeetechBusMock") - bus.is_connected = False - - def _connect(): - bus.is_connected = True - - def _disconnect(_disable=True): - bus.is_connected = False - - bus.connect.side_effect = _connect - bus.disconnect.side_effect = _disconnect - - @contextmanager - def _dummy_cm(): - yield - - bus.torque_disabled.side_effect = _dummy_cm - - return bus - - -@pytest.fixture -def follower(): - bus_mock = _make_bus_mock() - - def _bus_side_effect(*_args, **kwargs): - bus_mock.motors = kwargs["motors"] - motors_order: list[str] = list(bus_mock.motors) - - bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)} - bus_mock.sync_write.return_value = None - bus_mock.write.return_value = None - bus_mock.disable_torque.return_value = None - bus_mock.enable_torque.return_value = None - bus_mock.is_calibrated = True - return bus_mock - - with ( - patch( - "lerobot.robots.so100_follower.so100_follower.FeetechMotorsBus", - side_effect=_bus_side_effect, - ), - patch.object(SO100Follower, "configure", lambda self: None), - ): - cfg = SO100FollowerConfig(port="/dev/null") - robot = SO100Follower(cfg) - yield robot - if robot.is_connected: - robot.disconnect() - - -def test_connect_disconnect(follower): - assert not follower.is_connected - - follower.connect() - assert follower.is_connected - - follower.disconnect() - assert not follower.is_connected - - -def test_get_observation(follower): - follower.connect() - obs = follower.get_observation() - - expected_keys = {f"{m}.pos" for m in follower.bus.motors} - assert set(obs.keys()) == expected_keys - - for idx, motor in enumerate(follower.bus.motors, 1): - assert obs[f"{motor}.pos"] == idx - - -def test_send_action(follower): - follower.connect() - - action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)} - returned = follower.send_action(action) - - assert returned == action - - goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)} - follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos) diff --git a/tests/test_available.py b/tests/test_available.py deleted file mode 100644 index 19e39b2..0000000 --- a/tests/test_available.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import importlib - -import gymnasium as gym -import pytest - -import lerobot -from lerobot.policies.act.modeling_act import ACTPolicy -from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy -from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy -from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy -from tests.utils import require_env - - -@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs) -@require_env -def test_available_env_task(env_name: str, task_name: list): - """ - This test verifies that all environments listed in `lerobot/__init__.py` can - be successfully imported — if they're installed — and that their - `available_tasks_per_env` are valid. - """ - package_name = f"gym_{env_name}" - importlib.import_module(package_name) - gym_handle = f"{package_name}/{task_name}" - assert gym_handle in gym.envs.registry, gym_handle - - -def test_available_policies(): - """ - This test verifies that the class attribute `name` for all policies is - consistent with those listed in `lerobot/__init__.py`. - """ - policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy] - policies = [pol_cls.name for pol_cls in policy_classes] - assert set(policies) == set(lerobot.available_policies), policies - - -def test_print(): - print(lerobot.available_envs) - print(lerobot.available_tasks_per_env) - print(lerobot.available_datasets) - print(lerobot.available_datasets_per_env) - print(lerobot.available_real_world_datasets) - print(lerobot.available_policies) - print(lerobot.available_policies_per_env) diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py deleted file mode 100644 index e45688c..0000000 --- a/tests/test_control_robot.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from lerobot.calibrate import CalibrateConfig, calibrate -from lerobot.record import DatasetRecordConfig, RecordConfig, record -from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay -from lerobot.teleoperate import TeleoperateConfig, teleoperate -from tests.fixtures.constants import DUMMY_REPO_ID -from tests.mocks.mock_robot import MockRobotConfig -from tests.mocks.mock_teleop import MockTeleopConfig - - -def test_calibrate(): - robot_cfg = MockRobotConfig() - cfg = CalibrateConfig(robot=robot_cfg) - calibrate(cfg) - - -def test_teleoperate(): - robot_cfg = MockRobotConfig() - teleop_cfg = MockTeleopConfig() - cfg = TeleoperateConfig( - robot=robot_cfg, - teleop=teleop_cfg, - teleop_time_s=0.1, - ) - teleoperate(cfg) - - -def test_record_and_resume(tmp_path): - robot_cfg = MockRobotConfig() - teleop_cfg = MockTeleopConfig() - dataset_cfg = DatasetRecordConfig( - repo_id=DUMMY_REPO_ID, - single_task="Dummy task", - root=tmp_path / "record", - num_episodes=1, - episode_time_s=0.1, - reset_time_s=0, - push_to_hub=False, - ) - cfg = RecordConfig( - robot=robot_cfg, - dataset=dataset_cfg, - teleop=teleop_cfg, - play_sounds=False, - ) - - dataset = record(cfg) - - assert dataset.fps == 30 - assert dataset.meta.total_episodes == dataset.num_episodes == 1 - assert dataset.meta.total_frames == dataset.num_frames == 3 - assert dataset.meta.total_tasks == 1 - - cfg.resume = True - dataset = record(cfg) - - assert dataset.meta.total_episodes == dataset.num_episodes == 2 - assert dataset.meta.total_frames == dataset.num_frames == 6 - assert dataset.meta.total_tasks == 1 - - -def test_record_and_replay(tmp_path): - robot_cfg = MockRobotConfig() - teleop_cfg = MockTeleopConfig() - record_dataset_cfg = DatasetRecordConfig( - repo_id=DUMMY_REPO_ID, - single_task="Dummy task", - root=tmp_path / "record_and_replay", - num_episodes=1, - episode_time_s=0.1, - push_to_hub=False, - ) - record_cfg = RecordConfig( - robot=robot_cfg, - dataset=record_dataset_cfg, - teleop=teleop_cfg, - play_sounds=False, - ) - replay_dataset_cfg = DatasetReplayConfig( - repo_id=DUMMY_REPO_ID, - episode=0, - root=tmp_path / "record_and_replay", - ) - replay_cfg = ReplayConfig( - robot=robot_cfg, - dataset=replay_dataset_cfg, - play_sounds=False, - ) - - record(record_cfg) - replay(replay_cfg) diff --git a/tests/transport/test_transport_utils.py b/tests/transport/test_transport_utils.py deleted file mode 100644 index 79edad4..0000000 --- a/tests/transport/test_transport_utils.py +++ /dev/null @@ -1,571 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -from multiprocessing import Event, Queue -from pickle import UnpicklingError - -import pytest -import torch - -from lerobot.utils.transition import Transition -from tests.utils import require_cuda, require_package - - -@require_package("grpc") -def test_bytes_buffer_size_empty_buffer(): - from lerobot.transport.utils import bytes_buffer_size - - """Test with an empty buffer.""" - buffer = io.BytesIO() - assert bytes_buffer_size(buffer) == 0 - # Ensure position is reset to beginning - assert buffer.tell() == 0 - - -@require_package("grpc") -def test_bytes_buffer_size_small_buffer(): - from lerobot.transport.utils import bytes_buffer_size - - """Test with a small buffer.""" - buffer = io.BytesIO(b"Hello, World!") - assert bytes_buffer_size(buffer) == 13 - assert buffer.tell() == 0 - - -@require_package("grpc") -def test_bytes_buffer_size_large_buffer(): - from lerobot.transport.utils import CHUNK_SIZE, bytes_buffer_size - - """Test with a large buffer.""" - data = b"x" * (CHUNK_SIZE * 2 + 1000) - buffer = io.BytesIO(data) - assert bytes_buffer_size(buffer) == len(data) - assert buffer.tell() == 0 - - -@require_package("grpc") -def test_send_bytes_in_chunks_empty_data(): - from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 - - """Test sending empty data.""" - message_class = services_pb2.InteractionMessage - chunks = list(send_bytes_in_chunks(b"", message_class)) - assert len(chunks) == 0 - - -@require_package("grpc") -def test_single_chunk_small_data(): - from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 - - """Test data that fits in a single chunk.""" - data = b"Some data" - message_class = services_pb2.InteractionMessage - chunks = list(send_bytes_in_chunks(data, message_class)) - - assert len(chunks) == 1 - assert chunks[0].data == b"Some data" - assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END - - -@require_package("grpc") -def test_not_silent_mode(): - from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 - - """Test not silent mode.""" - data = b"Some data" - message_class = services_pb2.InteractionMessage - chunks = list(send_bytes_in_chunks(data, message_class, silent=False)) - assert len(chunks) == 1 - assert chunks[0].data == b"Some data" - - -@require_package("grpc") -def test_send_bytes_in_chunks_large_data(): - from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 - - """Test sending large data.""" - data = b"x" * (CHUNK_SIZE * 2 + 1000) - message_class = services_pb2.InteractionMessage - chunks = list(send_bytes_in_chunks(data, message_class)) - assert len(chunks) == 3 - assert chunks[0].data == b"x" * CHUNK_SIZE - assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_BEGIN - assert chunks[1].data == b"x" * CHUNK_SIZE - assert chunks[1].transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE - assert chunks[2].data == b"x" * 1000 - assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END - - -@require_package("grpc") -def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): - from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 - - """Test sending large data with exact chunk size.""" - data = b"x" * CHUNK_SIZE - message_class = services_pb2.InteractionMessage - chunks = list(send_bytes_in_chunks(data, message_class)) - assert len(chunks) == 1 - assert chunks[0].data == data - assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END - - -@require_package("grpc") -def test_receive_bytes_in_chunks_empty_data(): - from lerobot.transport.utils import receive_bytes_in_chunks - - """Test receiving empty data.""" - queue = Queue() - shutdown_event = Event() - - # Empty iterator - receive_bytes_in_chunks(iter([]), queue, shutdown_event) - - assert queue.empty() - - -@require_package("grpc") -def test_receive_bytes_in_chunks_single_chunk(): - from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 - - """Test receiving a single chunk message.""" - queue = Queue() - shutdown_event = Event() - - data = b"Single chunk data" - chunks = [ - services_pb2.InteractionMessage(data=data, transfer_state=services_pb2.TransferState.TRANSFER_END) - ] - - receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) - - assert queue.get(timeout=0.01) == data - assert queue.empty() - - -@require_package("grpc") -def test_receive_bytes_in_chunks_single_not_end_chunk(): - from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 - - """Test receiving a single chunk message.""" - queue = Queue() - shutdown_event = Event() - - data = b"Single chunk data" - chunks = [ - services_pb2.InteractionMessage(data=data, transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE) - ] - - receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) - - assert queue.empty() - - -@require_package("grpc") -def test_receive_bytes_in_chunks_multiple_chunks(): - from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 - - """Test receiving a multi-chunk message.""" - queue = Queue() - shutdown_event = Event() - - chunks = [ - services_pb2.InteractionMessage( - data=b"First ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN - ), - services_pb2.InteractionMessage( - data=b"Middle ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE - ), - services_pb2.InteractionMessage(data=b"Last", transfer_state=services_pb2.TransferState.TRANSFER_END), - ] - - receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) - - assert queue.get(timeout=0.01) == b"First Middle Last" - assert queue.empty() - - -@require_package("grpc") -def test_receive_bytes_in_chunks_multiple_messages(): - from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 - - """Test receiving multiple complete messages in sequence.""" - queue = Queue() - shutdown_event = Event() - - chunks = [ - # First message - single chunk - services_pb2.InteractionMessage( - data=b"Message1", transfer_state=services_pb2.TransferState.TRANSFER_END - ), - # Second message - multi chunk - services_pb2.InteractionMessage( - data=b"Start2 ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN - ), - services_pb2.InteractionMessage( - data=b"Middle2 ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE - ), - services_pb2.InteractionMessage(data=b"End2", transfer_state=services_pb2.TransferState.TRANSFER_END), - # Third message - single chunk - services_pb2.InteractionMessage( - data=b"Message3", transfer_state=services_pb2.TransferState.TRANSFER_END - ), - ] - - receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) - - # Should have three messages in queue - assert queue.get(timeout=0.01) == b"Message1" - assert queue.get(timeout=0.01) == b"Start2 Middle2 End2" - assert queue.get(timeout=0.01) == b"Message3" - assert queue.empty() - - -@require_package("grpc") -def test_receive_bytes_in_chunks_shutdown_during_receive(): - from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 - - """Test that shutdown event stops receiving mid-stream.""" - queue = Queue() - shutdown_event = Event() - shutdown_event.set() - - chunks = [ - services_pb2.InteractionMessage( - data=b"First ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN - ), - services_pb2.InteractionMessage( - data=b"Middle ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE - ), - services_pb2.InteractionMessage(data=b"Last", transfer_state=services_pb2.TransferState.TRANSFER_END), - ] - - receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) - - assert queue.empty() - - -@require_package("grpc") -def test_receive_bytes_in_chunks_only_begin_chunk(): - from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 - - """Test receiving only a BEGIN chunk without END.""" - queue = Queue() - shutdown_event = Event() - - chunks = [ - services_pb2.InteractionMessage( - data=b"Start", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN - ), - # No END chunk - ] - - receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) - - assert queue.empty() - - -@require_package("grpc") -def test_receive_bytes_in_chunks_missing_begin(): - from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 - - """Test receiving chunks starting with MIDDLE instead of BEGIN.""" - queue = Queue() - shutdown_event = Event() - - chunks = [ - # Missing BEGIN - services_pb2.InteractionMessage( - data=b"Middle", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE - ), - services_pb2.InteractionMessage(data=b"End", transfer_state=services_pb2.TransferState.TRANSFER_END), - ] - - receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) - - # The implementation continues from where it is, so we should get partial data - assert queue.get(timeout=0.01) == b"MiddleEnd" - assert queue.empty() - - -# Tests for state_to_bytes and bytes_to_state_dict -@require_package("grpc") -def test_state_to_bytes_empty_dict(): - from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes - - """Test converting empty state dict to bytes.""" - state_dict = {} - data = state_to_bytes(state_dict) - reconstructed = bytes_to_state_dict(data) - assert reconstructed == state_dict - - -@require_package("grpc") -def test_bytes_to_state_dict_empty_data(): - from lerobot.transport.utils import bytes_to_state_dict - - """Test converting empty data to state dict.""" - with pytest.raises(EOFError): - bytes_to_state_dict(b"") - - -@require_package("grpc") -def test_state_to_bytes_simple_dict(): - from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes - - """Test converting simple state dict to bytes.""" - state_dict = { - "layer1.weight": torch.randn(10, 5), - "layer1.bias": torch.randn(10), - "layer2.weight": torch.randn(1, 10), - "layer2.bias": torch.randn(1), - } - - data = state_to_bytes(state_dict) - assert isinstance(data, bytes) - assert len(data) > 0 - - reconstructed = bytes_to_state_dict(data) - - assert len(reconstructed) == len(state_dict) - for key in state_dict: - assert key in reconstructed - assert torch.allclose(state_dict[key], reconstructed[key]) - - -@require_package("grpc") -def test_state_to_bytes_various_dtypes(): - from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes - - """Test converting state dict with various tensor dtypes.""" - state_dict = { - "float32": torch.randn(5, 5), - "float64": torch.randn(3, 3).double(), - "int32": torch.randint(0, 100, (4, 4), dtype=torch.int32), - "int64": torch.randint(0, 100, (2, 2), dtype=torch.int64), - "bool": torch.tensor([True, False, True]), - "uint8": torch.randint(0, 255, (3, 3), dtype=torch.uint8), - } - - data = state_to_bytes(state_dict) - reconstructed = bytes_to_state_dict(data) - - for key in state_dict: - assert reconstructed[key].dtype == state_dict[key].dtype - if state_dict[key].dtype == torch.bool: - assert torch.equal(state_dict[key], reconstructed[key]) - else: - assert torch.allclose(state_dict[key], reconstructed[key]) - - -@require_package("grpc") -def test_bytes_to_state_dict_invalid_data(): - from lerobot.transport.utils import bytes_to_state_dict - - """Test bytes_to_state_dict with invalid data.""" - with pytest.raises(UnpicklingError): - bytes_to_state_dict(b"This is not a valid torch save file") - - -@require_cuda -@require_package("grpc") -def test_state_to_bytes_various_dtypes_cuda(): - from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes - - """Test converting state dict with various tensor dtypes.""" - state_dict = { - "float32": torch.randn(5, 5).cuda(), - "float64": torch.randn(3, 3).double().cuda(), - "int32": torch.randint(0, 100, (4, 4), dtype=torch.int32).cuda(), - "int64": torch.randint(0, 100, (2, 2), dtype=torch.int64).cuda(), - "bool": torch.tensor([True, False, True]), - "uint8": torch.randint(0, 255, (3, 3), dtype=torch.uint8), - } - - data = state_to_bytes(state_dict) - reconstructed = bytes_to_state_dict(data) - - for key in state_dict: - assert reconstructed[key].dtype == state_dict[key].dtype - if state_dict[key].dtype == torch.bool: - assert torch.equal(state_dict[key], reconstructed[key]) - else: - assert torch.allclose(state_dict[key], reconstructed[key]) - - -@require_package("grpc") -def test_python_object_to_bytes_none(): - from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes - - """Test converting None to bytes.""" - obj = None - data = python_object_to_bytes(obj) - reconstructed = bytes_to_python_object(data) - assert reconstructed is None - - -@pytest.mark.parametrize( - "obj", - [ - 42, - -123, - 3.14159, - -2.71828, - "Hello, World!", - "Unicode: 你好世界 🌍", - True, - False, - b"byte string", - [], - [1, 2, 3], - [1, "two", 3.0, True, None], - {}, - {"key": "value", "number": 123, "nested": {"a": 1}}, - (), - (1, 2, 3), - ], -) -@require_package("grpc") -def test_python_object_to_bytes_simple_types(obj): - from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes - - """Test converting simple Python types.""" - data = python_object_to_bytes(obj) - reconstructed = bytes_to_python_object(data) - assert reconstructed == obj - assert type(reconstructed) is type(obj) - - -@require_package("grpc") -def test_python_object_to_bytes_with_tensors(): - from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes - - """Test converting objects containing PyTorch tensors.""" - obj = { - "tensor": torch.randn(5, 5), - "list_with_tensor": [1, 2, torch.randn(3, 3), "string"], - "nested": { - "tensor1": torch.randn(2, 2), - "tensor2": torch.tensor([1, 2, 3]), - }, - } - - data = python_object_to_bytes(obj) - reconstructed = bytes_to_python_object(data) - - assert torch.allclose(obj["tensor"], reconstructed["tensor"]) - assert reconstructed["list_with_tensor"][0] == 1 - assert reconstructed["list_with_tensor"][3] == "string" - assert torch.allclose(obj["list_with_tensor"][2], reconstructed["list_with_tensor"][2]) - assert torch.allclose(obj["nested"]["tensor1"], reconstructed["nested"]["tensor1"]) - assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"]) - - -@require_package("grpc") -def test_transitions_to_bytes_empty_list(): - from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes - - """Test converting empty transitions list.""" - transitions = [] - data = transitions_to_bytes(transitions) - reconstructed = bytes_to_transitions(data) - assert reconstructed == transitions - assert isinstance(reconstructed, list) - - -@require_package("grpc") -def test_transitions_to_bytes_single_transition(): - from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes - - """Test converting a single transition.""" - transition = Transition( - state={"image": torch.randn(3, 64, 64), "state": torch.randn(10)}, - action=torch.randn(5), - reward=torch.tensor(1.5), - done=torch.tensor(False), - next_state={"image": torch.randn(3, 64, 64), "state": torch.randn(10)}, - ) - - transitions = [transition] - data = transitions_to_bytes(transitions) - reconstructed = bytes_to_transitions(data) - - assert len(reconstructed) == 1 - - assert_transitions_equal(transitions[0], reconstructed[0]) - - -@require_package("grpc") -def assert_transitions_equal(t1: Transition, t2: Transition): - """Helper to assert two transitions are equal.""" - assert_observation_equal(t1["state"], t2["state"]) - assert torch.allclose(t1["action"], t2["action"]) - assert torch.allclose(t1["reward"], t2["reward"]) - assert torch.equal(t1["done"], t2["done"]) - assert_observation_equal(t1["next_state"], t2["next_state"]) - - -@require_package("grpc") -def assert_observation_equal(o1: dict, o2: dict): - """Helper to assert two observations are equal.""" - assert set(o1.keys()) == set(o2.keys()) - for key in o1: - assert torch.allclose(o1[key], o2[key]) - - -@require_package("grpc") -def test_transitions_to_bytes_multiple_transitions(): - from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes - - """Test converting multiple transitions.""" - transitions = [] - for i in range(5): - transition = Transition( - state={"data": torch.randn(10)}, - action=torch.randn(3), - reward=torch.tensor(float(i)), - done=torch.tensor(i == 4), - next_state={"data": torch.randn(10)}, - ) - transitions.append(transition) - - data = transitions_to_bytes(transitions) - reconstructed = bytes_to_transitions(data) - - assert len(reconstructed) == len(transitions) - for original, reconstructed_item in zip(transitions, reconstructed, strict=False): - assert_transitions_equal(original, reconstructed_item) - - -@require_package("grpc") -def test_receive_bytes_in_chunks_unknown_state(): - from lerobot.transport.utils import receive_bytes_in_chunks - - """Test receive_bytes_in_chunks with an unknown transfer state.""" - - # Mock the gRPC message object, which has `transfer_state` and `data` attributes. - class MockMessage: - def __init__(self, transfer_state, data): - self.transfer_state = transfer_state - self.data = data - - # 10 is not a valid TransferState enum value - bad_iterator = [MockMessage(transfer_state=10, data=b"bad_data")] - output_queue = Queue() - shutdown_event = Event() - - with pytest.raises(ValueError, match="Received unknown transfer state"): - receive_bytes_in_chunks(bad_iterator, output_queue, shutdown_event) diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index 800b7d4..0000000 --- a/tests/utils.py +++ /dev/null @@ -1,184 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import platform -from functools import wraps - -import pytest -import torch - -from lerobot import available_cameras, available_motors, available_robots -from lerobot.utils.import_utils import is_package_available - -DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu" - -TEST_ROBOT_TYPES = [] -for robot_type in available_robots: - TEST_ROBOT_TYPES += [(robot_type, True), (robot_type, False)] - -TEST_CAMERA_TYPES = [] -for camera_type in available_cameras: - TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)] - -TEST_MOTOR_TYPES = [] -for motor_type in available_motors: - TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)] - -# Camera indices used for connecting physical cameras -OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0)) -INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614)) - -DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081") -DYNAMIXEL_MOTORS = { - "shoulder_pan": [1, "xl430-w250"], - "shoulder_lift": [2, "xl430-w250"], - "elbow_flex": [3, "xl330-m288"], - "wrist_flex": [4, "xl330-m288"], - "wrist_roll": [5, "xl330-m288"], - "gripper": [6, "xl330-m288"], -} - -FEETECH_PORT = os.environ.get("LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971") -FEETECH_MOTORS = { - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], -} - - -def require_x86_64_kernel(func): - """ - Decorator that skips the test if plateform device is not an x86_64 cpu. - """ - from functools import wraps - - @wraps(func) - def wrapper(*args, **kwargs): - if platform.machine() != "x86_64": - pytest.skip("requires x86_64 plateform") - return func(*args, **kwargs) - - return wrapper - - -def require_cpu(func): - """ - Decorator that skips the test if device is not cpu. - """ - from functools import wraps - - @wraps(func) - def wrapper(*args, **kwargs): - if DEVICE != "cpu": - pytest.skip("requires cpu") - return func(*args, **kwargs) - - return wrapper - - -def require_cuda(func): - """ - Decorator that skips the test if cuda is not available. - """ - from functools import wraps - - @wraps(func) - def wrapper(*args, **kwargs): - if not torch.cuda.is_available(): - pytest.skip("requires cuda") - return func(*args, **kwargs) - - return wrapper - - -def require_env(func): - """ - Decorator that skips the test if the required environment package is not installed. - As it need 'env_name' in args, it also checks whether it is provided as an argument. - If 'env_name' is None, this check is skipped. - """ - - @wraps(func) - def wrapper(*args, **kwargs): - # Determine if 'env_name' is provided and extract its value - arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] - if "env_name" in arg_names: - # Get the index of 'env_name' and retrieve the value from args - index = arg_names.index("env_name") - env_name = args[index] if len(args) > index else kwargs.get("env_name") - else: - raise ValueError("Function does not have 'env_name' as an argument.") - - # Perform the package check - package_name = f"gym_{env_name}" - if env_name is not None and not is_package_available(package_name): - pytest.skip(f"gym-{env_name} not installed") - - return func(*args, **kwargs) - - return wrapper - - -def require_package_arg(func): - """ - Decorator that skips the test if the required package is not installed. - This is similar to `require_env` but more general in that it can check any package (not just environments). - As it need 'required_packages' in args, it also checks whether it is provided as an argument. - If 'required_packages' is None, this check is skipped. - """ - - @wraps(func) - def wrapper(*args, **kwargs): - # Determine if 'required_packages' is provided and extract its value - arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] - if "required_packages" in arg_names: - # Get the index of 'required_packages' and retrieve the value from args - index = arg_names.index("required_packages") - required_packages = args[index] if len(args) > index else kwargs.get("required_packages") - else: - raise ValueError("Function does not have 'required_packages' as an argument.") - - if required_packages is None: - return func(*args, **kwargs) - - # Perform the package check - for package in required_packages: - if not is_package_available(package): - pytest.skip(f"{package} not installed") - - return func(*args, **kwargs) - - return wrapper - - -def require_package(package_name): - """ - Decorator that skips the test if the specified package is not installed. - """ - - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - if not is_package_available(package_name): - pytest.skip(f"{package_name} not installed") - return func(*args, **kwargs) - - return wrapper - - return decorator diff --git a/tests/utils/test_encoding_utils.py b/tests/utils/test_encoding_utils.py deleted file mode 100644 index 8139428..0000000 --- a/tests/utils/test_encoding_utils.py +++ /dev/null @@ -1,171 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -from lerobot.utils.encoding_utils import ( - decode_sign_magnitude, - decode_twos_complement, - encode_sign_magnitude, - encode_twos_complement, -) - - -@pytest.mark.parametrize( - "value, sign_bit_index, expected", - [ - (5, 4, 5), - (0, 4, 0), - (7, 3, 7), - (-1, 4, 17), - (-8, 4, 24), - (-3, 3, 11), - ], -) -def test_encode_sign_magnitude(value, sign_bit_index, expected): - assert encode_sign_magnitude(value, sign_bit_index) == expected - - -@pytest.mark.parametrize( - "encoded, sign_bit_index, expected", - [ - (5, 4, 5), - (0, 4, 0), - (7, 3, 7), - (17, 4, -1), - (24, 4, -8), - (11, 3, -3), - ], -) -def test_decode_sign_magnitude(encoded, sign_bit_index, expected): - assert decode_sign_magnitude(encoded, sign_bit_index) == expected - - -@pytest.mark.parametrize( - "encoded, sign_bit_index", - [ - (16, 4), - (-9, 3), - ], -) -def test_encode_raises_on_overflow(encoded, sign_bit_index): - with pytest.raises(ValueError): - encode_sign_magnitude(encoded, sign_bit_index) - - -def test_encode_decode_sign_magnitude(): - for sign_bit_index in range(2, 6): - max_val = (1 << sign_bit_index) - 1 - for value in range(-max_val, max_val + 1): - encoded = encode_sign_magnitude(value, sign_bit_index) - decoded = decode_sign_magnitude(encoded, sign_bit_index) - assert decoded == value, f"Failed at value={value}, index={sign_bit_index}" - - -@pytest.mark.parametrize( - "value, n_bytes, expected", - [ - (0, 1, 0), - (5, 1, 5), - (-1, 1, 255), - (-128, 1, 128), - (-2, 1, 254), - (127, 1, 127), - (0, 2, 0), - (5, 2, 5), - (-1, 2, 65_535), - (-32_768, 2, 32_768), - (-2, 2, 65_534), - (32_767, 2, 32_767), - (0, 4, 0), - (5, 4, 5), - (-1, 4, 4_294_967_295), - (-2_147_483_648, 4, 2_147_483_648), - (-2, 4, 4_294_967_294), - (2_147_483_647, 4, 2_147_483_647), - ], -) -def test_encode_twos_complement(value, n_bytes, expected): - assert encode_twos_complement(value, n_bytes) == expected - - -@pytest.mark.parametrize( - "value, n_bytes, expected", - [ - (0, 1, 0), - (5, 1, 5), - (255, 1, -1), - (128, 1, -128), - (254, 1, -2), - (127, 1, 127), - (0, 2, 0), - (5, 2, 5), - (65_535, 2, -1), - (32_768, 2, -32_768), - (65_534, 2, -2), - (32_767, 2, 32_767), - (0, 4, 0), - (5, 4, 5), - (4_294_967_295, 4, -1), - (2_147_483_648, 4, -2_147_483_648), - (4_294_967_294, 4, -2), - (2_147_483_647, 4, 2_147_483_647), - ], -) -def test_decode_twos_complement(value, n_bytes, expected): - assert decode_twos_complement(value, n_bytes) == expected - - -@pytest.mark.parametrize( - "value, n_bytes", - [ - (-129, 1), - (128, 1), - (-32_769, 2), - (32_768, 2), - (-2_147_483_649, 4), - (2_147_483_648, 4), - ], -) -def test_encode_twos_complement_out_of_range(value, n_bytes): - with pytest.raises(ValueError): - encode_twos_complement(value, n_bytes) - - -@pytest.mark.parametrize( - "value, n_bytes", - [ - (-128, 1), - (-1, 1), - (0, 1), - (1, 1), - (127, 1), - (-32_768, 2), - (-1, 2), - (0, 2), - (1, 2), - (32_767, 2), - (-2_147_483_648, 4), - (-1, 4), - (0, 4), - (1, 4), - (2_147_483_647, 4), - ], -) -def test_encode_decode_twos_complement(value, n_bytes): - encoded = encode_twos_complement(value, n_bytes) - decoded = decode_twos_complement(encoded, n_bytes) - assert decoded == value, f"Failed at value={value}, n_bytes={n_bytes}" diff --git a/tests/utils/test_io_utils.py b/tests/utils/test_io_utils.py deleted file mode 100644 index 9768a5e..0000000 --- a/tests/utils/test_io_utils.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -from pathlib import Path -from typing import Any - -import pytest - -from lerobot.utils.io_utils import deserialize_json_into_object - - -@pytest.fixture -def tmp_json_file(tmp_path: Path): - """Writes `data` to a temporary JSON file and returns the file's path.""" - - def _write(data: Any) -> Path: - file_path = tmp_path / "data.json" - with file_path.open("w", encoding="utf-8") as f: - json.dump(data, f) - return file_path - - return _write - - -def test_simple_dict(tmp_json_file): - data = {"name": "Alice", "age": 30} - json_path = tmp_json_file(data) - obj = {"name": "", "age": 0} - assert deserialize_json_into_object(json_path, obj) == data - - -def test_nested_structure(tmp_json_file): - data = {"items": [1, 2, 3], "info": {"active": True}} - json_path = tmp_json_file(data) - obj = {"items": [0, 0, 0], "info": {"active": False}} - assert deserialize_json_into_object(json_path, obj) == data - - -def test_tuple_conversion(tmp_json_file): - data = {"coords": [10.5, 20.5]} - json_path = tmp_json_file(data) - obj = {"coords": (0.0, 0.0)} - result = deserialize_json_into_object(json_path, obj) - assert result["coords"] == (10.5, 20.5) - - -def test_type_mismatch_raises(tmp_json_file): - data = {"numbers": {"bad": "structure"}} - json_path = tmp_json_file(data) - obj = {"numbers": [0, 0]} - with pytest.raises(TypeError): - deserialize_json_into_object(json_path, obj) - - -def test_missing_key_raises(tmp_json_file): - data = {"one": 1} - json_path = tmp_json_file(data) - obj = {"one": 0, "two": 0} - with pytest.raises(ValueError): - deserialize_json_into_object(json_path, obj) - - -def test_extra_key_raises(tmp_json_file): - data = {"one": 1, "two": 2} - json_path = tmp_json_file(data) - obj = {"one": 0} - with pytest.raises(ValueError): - deserialize_json_into_object(json_path, obj) - - -def test_list_length_mismatch_raises(tmp_json_file): - data = {"nums": [1, 2, 3]} - json_path = tmp_json_file(data) - obj = {"nums": [0, 0]} - with pytest.raises(ValueError): - deserialize_json_into_object(json_path, obj) diff --git a/tests/utils/test_logging_utils.py b/tests/utils/test_logging_utils.py deleted file mode 100644 index 927fdc1..0000000 --- a/tests/utils/test_logging_utils.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytest - -from lerobot.utils.logging_utils import AverageMeter, MetricsTracker - - -@pytest.fixture -def mock_metrics(): - return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")} - - -def test_average_meter_initialization(): - meter = AverageMeter("loss", ":.2f") - assert meter.name == "loss" - assert meter.fmt == ":.2f" - assert meter.val == 0.0 - assert meter.avg == 0.0 - assert meter.sum == 0.0 - assert meter.count == 0.0 - - -def test_average_meter_update(): - meter = AverageMeter("accuracy") - meter.update(5, n=2) - assert meter.val == 5 - assert meter.sum == 10 - assert meter.count == 2 - assert meter.avg == 5 - - -def test_average_meter_reset(): - meter = AverageMeter("loss") - meter.update(3, 4) - meter.reset() - assert meter.val == 0.0 - assert meter.avg == 0.0 - assert meter.sum == 0.0 - assert meter.count == 0.0 - - -def test_average_meter_str(): - meter = AverageMeter("metric", ":.1f") - meter.update(4.567, 3) - assert str(meter) == "metric:4.6" - - -def test_metrics_tracker_initialization(mock_metrics): - tracker = MetricsTracker( - batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=10 - ) - assert tracker.steps == 10 - assert tracker.samples == 10 * 32 - assert tracker.episodes == tracker.samples / (1000 / 50) - assert tracker.epochs == tracker.samples / 1000 - assert "loss" in tracker.metrics - assert "accuracy" in tracker.metrics - - -def test_metrics_tracker_step(mock_metrics): - tracker = MetricsTracker( - batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=5 - ) - tracker.step() - assert tracker.steps == 6 - assert tracker.samples == 6 * 32 - assert tracker.episodes == tracker.samples / (1000 / 50) - assert tracker.epochs == tracker.samples / 1000 - - -def test_metrics_tracker_getattr(mock_metrics): - tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) - assert tracker.loss == mock_metrics["loss"] - assert tracker.accuracy == mock_metrics["accuracy"] - with pytest.raises(AttributeError): - _ = tracker.non_existent_metric - - -def test_metrics_tracker_setattr(mock_metrics): - tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) - tracker.loss = 2.0 - assert tracker.loss.val == 2.0 - - -def test_metrics_tracker_str(mock_metrics): - tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) - tracker.loss.update(3.456, 1) - tracker.accuracy.update(0.876, 1) - output = str(tracker) - assert "loss:3.456" in output - assert "accuracy:0.88" in output - - -def test_metrics_tracker_to_dict(mock_metrics): - tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) - tracker.loss.update(5, 2) - metrics_dict = tracker.to_dict() - assert isinstance(metrics_dict, dict) - assert metrics_dict["loss"] == 5 # average value - assert metrics_dict["steps"] == tracker.steps - - -def test_metrics_tracker_reset_averages(mock_metrics): - tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) - tracker.loss.update(10, 3) - tracker.accuracy.update(0.95, 5) - tracker.reset_averages() - assert tracker.loss.avg == 0.0 - assert tracker.accuracy.avg == 0.0 diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py deleted file mode 100644 index 61e6e2c..0000000 --- a/tests/utils/test_process.py +++ /dev/null @@ -1,112 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import multiprocessing -import os -import signal -import threading -from unittest.mock import patch - -import pytest - -from lerobot.utils.process import ProcessSignalHandler - - -# Fixture to reset shutdown_event_counter and original signal handlers before and after each test -@pytest.fixture(autouse=True) -def reset_globals_and_handlers(): - # Store original signal handlers - original_handlers = { - sig: signal.getsignal(sig) - for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT] - if hasattr(signal, sig.name) - } - - yield - - # Restore original signal handlers - for sig, handler in original_handlers.items(): - signal.signal(sig, handler) - - -def test_setup_process_handlers_event_with_threads(): - """Test that setup_process_handlers returns the correct event type.""" - handler = ProcessSignalHandler(use_threads=True) - shutdown_event = handler.shutdown_event - assert isinstance(shutdown_event, threading.Event), "Should be a threading.Event" - assert not shutdown_event.is_set(), "Event should initially be unset" - - -def test_setup_process_handlers_event_with_processes(): - """Test that setup_process_handlers returns the correct event type.""" - handler = ProcessSignalHandler(use_threads=False) - shutdown_event = handler.shutdown_event - assert isinstance(shutdown_event, type(multiprocessing.Event())), "Should be a multiprocessing.Event" - assert not shutdown_event.is_set(), "Event should initially be unset" - - -@pytest.mark.parametrize("use_threads", [True, False]) -@pytest.mark.parametrize( - "sig", - [ - signal.SIGINT, - signal.SIGTERM, - # SIGHUP and SIGQUIT are not reliably available on all platforms (e.g. Windows) - pytest.param( - signal.SIGHUP, - marks=pytest.mark.skipif(not hasattr(signal, "SIGHUP"), reason="SIGHUP not available"), - ), - pytest.param( - signal.SIGQUIT, - marks=pytest.mark.skipif(not hasattr(signal, "SIGQUIT"), reason="SIGQUIT not available"), - ), - ], -) -def test_signal_handler_sets_event(use_threads, sig): - """Test that the signal handler sets the event on receiving a signal.""" - handler = ProcessSignalHandler(use_threads=use_threads) - shutdown_event = handler.shutdown_event - - assert handler.counter == 0 - - os.kill(os.getpid(), sig) - - # In some environments, the signal might take a moment to be handled. - shutdown_event.wait(timeout=1.0) - - assert shutdown_event.is_set(), f"Event should be set after receiving signal {sig}" - - # Ensure the internal counter was incremented - assert handler.counter == 1 - - -@pytest.mark.parametrize("use_threads", [True, False]) -@patch("sys.exit") -def test_force_shutdown_on_second_signal(mock_sys_exit, use_threads): - """Test that a second signal triggers a force shutdown.""" - handler = ProcessSignalHandler(use_threads=use_threads) - - os.kill(os.getpid(), signal.SIGINT) - # Give a moment for the first signal to be processed - import time - - time.sleep(0.1) - os.kill(os.getpid(), signal.SIGINT) - - time.sleep(0.1) - - assert handler.counter == 2 - mock_sys_exit.assert_called_once_with(1) diff --git a/tests/utils/test_queue.py b/tests/utils/test_queue.py deleted file mode 100644 index 6e42acd..0000000 --- a/tests/utils/test_queue.py +++ /dev/null @@ -1,166 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import threading -import time -from queue import Queue - -from torch.multiprocessing import Queue as TorchMPQueue - -from lerobot.utils.queue import get_last_item_from_queue - - -def test_get_last_item_single_item(): - """Test getting the last item when queue has only one item.""" - queue = Queue() - queue.put("single_item") - - result = get_last_item_from_queue(queue) - - assert result == "single_item" - assert queue.empty() - - -def test_get_last_item_multiple_items(): - """Test getting the last item when queue has multiple items.""" - queue = Queue() - items = ["first", "second", "third", "fourth", "last"] - - for item in items: - queue.put(item) - - result = get_last_item_from_queue(queue) - - assert result == "last" - assert queue.empty() - - -def test_get_last_item_multiple_items_with_torch_queue(): - """Test getting the last item when queue has multiple items.""" - queue = TorchMPQueue() - items = ["first", "second", "third", "fourth", "last"] - - for item in items: - queue.put(item) - - result = get_last_item_from_queue(queue) - - assert result == "last" - assert queue.empty() - - -def test_get_last_item_different_types(): - """Test with different data types in the queue.""" - queue = Queue() - items = [1, 2.5, "string", {"key": "value"}, [1, 2, 3], ("tuple", "data")] - - for item in items: - queue.put(item) - - result = get_last_item_from_queue(queue) - - assert result == ("tuple", "data") - assert queue.empty() - - -def test_get_last_item_maxsize_queue(): - """Test with a queue that has a maximum size.""" - queue = Queue(maxsize=5) - - # Fill the queue - for i in range(5): - queue.put(i) - - # Give the queue time to fill - time.sleep(0.1) - - result = get_last_item_from_queue(queue) - - assert result == 4 - assert queue.empty() - - -def test_get_last_item_with_none_values(): - """Test with None values in the queue.""" - queue = Queue() - items = [1, None, 2, None, 3] - - for item in items: - queue.put(item) - - # Give the queue time to fill - time.sleep(0.1) - - result = get_last_item_from_queue(queue) - - assert result == 3 - assert queue.empty() - - -def test_get_last_item_blocking_timeout(): - """Test get_last_item_from_queue returns None on timeout.""" - queue = Queue() - result = get_last_item_from_queue(queue, block=True, timeout=0.1) - assert result is None - - -def test_get_last_item_non_blocking_empty(): - """Test get_last_item_from_queue with block=False on an empty queue returns None.""" - queue = Queue() - result = get_last_item_from_queue(queue, block=False) - assert result is None - - -def test_get_last_item_non_blocking_success(): - """Test get_last_item_from_queue with block=False on a non-empty queue.""" - queue = Queue() - items = ["first", "second", "last"] - for item in items: - queue.put(item) - - # Give the queue time to fill - time.sleep(0.1) - - result = get_last_item_from_queue(queue, block=False) - assert result == "last" - assert queue.empty() - - -def test_get_last_item_blocking_waits_for_item(): - """Test that get_last_item_from_queue waits for an item if block=True.""" - queue = Queue() - result = [] - - def producer(): - queue.put("item1") - queue.put("item2") - - def consumer(): - # This will block until the producer puts the first item - item = get_last_item_from_queue(queue, block=True, timeout=0.2) - result.append(item) - - producer_thread = threading.Thread(target=producer) - consumer_thread = threading.Thread(target=consumer) - - producer_thread.start() - consumer_thread.start() - - producer_thread.join() - consumer_thread.join() - - assert result == ["item2"] - assert queue.empty() diff --git a/tests/utils/test_random_utils.py b/tests/utils/test_random_utils.py deleted file mode 100644 index 5865361..0000000 --- a/tests/utils/test_random_utils.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import random - -import numpy as np -import pytest -import torch - -from lerobot.utils.random_utils import ( - deserialize_numpy_rng_state, - deserialize_python_rng_state, - deserialize_rng_state, - deserialize_torch_rng_state, - get_rng_state, - seeded_context, - serialize_numpy_rng_state, - serialize_python_rng_state, - serialize_rng_state, - serialize_torch_rng_state, - set_rng_state, - set_seed, -) - - -@pytest.fixture -def fixed_seed(): - """Fixture to set a consistent initial seed for each test.""" - set_seed(12345) - yield - - -def test_serialize_deserialize_python_rng(fixed_seed): - # Save state after generating val1 - _ = random.random() - st = serialize_python_rng_state() - # Next random is val2 - val2 = random.random() - # Restore the state, so the next random should match val2 - deserialize_python_rng_state(st) - val3 = random.random() - assert val2 == val3 - - -def test_serialize_deserialize_numpy_rng(fixed_seed): - _ = np.random.rand() - st = serialize_numpy_rng_state() - val2 = np.random.rand() - deserialize_numpy_rng_state(st) - val3 = np.random.rand() - assert val2 == val3 - - -def test_serialize_deserialize_torch_rng(fixed_seed): - _ = torch.rand(1).item() - st = serialize_torch_rng_state() - val2 = torch.rand(1).item() - deserialize_torch_rng_state(st) - val3 = torch.rand(1).item() - assert val2 == val3 - - -def test_serialize_deserialize_rng(fixed_seed): - # Generate one from each library - _ = random.random() - _ = np.random.rand() - _ = torch.rand(1).item() - # Serialize - st = serialize_rng_state() - # Generate second set - val_py2 = random.random() - val_np2 = np.random.rand() - val_th2 = torch.rand(1).item() - # Restore, so the next draws should match val_py2, val_np2, val_th2 - deserialize_rng_state(st) - assert random.random() == val_py2 - assert np.random.rand() == val_np2 - assert torch.rand(1).item() == val_th2 - - -def test_get_set_rng_state(fixed_seed): - st = get_rng_state() - val1 = (random.random(), np.random.rand(), torch.rand(1).item()) - # Change states - random.random() - np.random.rand() - torch.rand(1) - # Restore - set_rng_state(st) - val2 = (random.random(), np.random.rand(), torch.rand(1).item()) - assert val1 == val2 - - -def test_set_seed(): - set_seed(1337) - val1 = (random.random(), np.random.rand(), torch.rand(1).item()) - set_seed(1337) - val2 = (random.random(), np.random.rand(), torch.rand(1).item()) - assert val1 == val2 - - -def test_seeded_context(fixed_seed): - val1 = (random.random(), np.random.rand(), torch.rand(1).item()) - with seeded_context(1337): - seeded_val1 = (random.random(), np.random.rand(), torch.rand(1).item()) - val2 = (random.random(), np.random.rand(), torch.rand(1).item()) - with seeded_context(1337): - seeded_val2 = (random.random(), np.random.rand(), torch.rand(1).item()) - - assert seeded_val1 == seeded_val2 - assert all(a != b for a, b in zip(val1, seeded_val1, strict=True)) # changed inside the context - assert all(a != b for a, b in zip(val2, seeded_val2, strict=True)) # changed again after exiting diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py deleted file mode 100644 index a53d7ba..0000000 --- a/tests/utils/test_replay_buffer.py +++ /dev/null @@ -1,682 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -from collections.abc import Callable - -import pytest -import torch - -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.utils.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized -from tests.fixtures.constants import DUMMY_REPO_ID - - -def state_dims() -> list[str]: - return ["observation.image", "observation.state"] - - -@pytest.fixture -def replay_buffer() -> ReplayBuffer: - return create_empty_replay_buffer() - - -def clone_state(state: dict) -> dict: - return {k: v.clone() for k, v in state.items()} - - -def create_empty_replay_buffer( - optimize_memory: bool = False, - use_drq: bool = False, - image_augmentation_function: Callable | None = None, -) -> ReplayBuffer: - buffer_capacity = 10 - device = "cpu" - return ReplayBuffer( - buffer_capacity, - device, - state_dims(), - optimize_memory=optimize_memory, - use_drq=use_drq, - image_augmentation_function=image_augmentation_function, - ) - - -def create_random_image() -> torch.Tensor: - return torch.rand(3, 84, 84) - - -def create_dummy_transition() -> dict: - return { - "observation.image": create_random_image(), - "action": torch.randn(4), - "reward": torch.tensor(1.0), - "observation.state": torch.randn( - 10, - ), - "done": torch.tensor(False), - "truncated": torch.tensor(False), - "complementary_info": {}, - } - - -def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayBuffer]: - dummy_state_1 = create_dummy_state() - dummy_action_1 = create_dummy_action() - - dummy_state_2 = create_dummy_state() - dummy_action_2 = create_dummy_action() - - dummy_state_3 = create_dummy_state() - dummy_action_3 = create_dummy_action() - - dummy_state_4 = create_dummy_state() - dummy_action_4 = create_dummy_action() - - replay_buffer = create_empty_replay_buffer() - replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) - replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False) - replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True) - replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True) - - root = tmp_path / "test" - return (replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root), replay_buffer) - - -def create_dummy_state() -> dict: - return { - "observation.image": create_random_image(), - "observation.state": torch.randn( - 10, - ), - } - - -def get_tensor_memory_consumption(tensor): - return tensor.nelement() * tensor.element_size() - - -def get_tensors_memory_consumption(obj, visited_addresses): - total_size = 0 - - address = id(obj) - if address in visited_addresses: - return 0 - - visited_addresses.add(address) - - if isinstance(obj, torch.Tensor): - return get_tensor_memory_consumption(obj) - elif isinstance(obj, (list, tuple)): - for item in obj: - total_size += get_tensors_memory_consumption(item, visited_addresses) - elif isinstance(obj, dict): - for value in obj.values(): - total_size += get_tensors_memory_consumption(value, visited_addresses) - elif hasattr(obj, "__dict__"): - # It's an object, we need to get the size of the attributes - for _, attr in vars(obj).items(): - total_size += get_tensors_memory_consumption(attr, visited_addresses) - - return total_size - - -def get_object_memory(obj): - # Track visited addresses to avoid infinite loops - # and cases when two properties point to the same object - visited_addresses = set() - - # Get the size of the object in bytes - total_size = sys.getsizeof(obj) - - # Get the size of the tensor attributes - total_size += get_tensors_memory_consumption(obj, visited_addresses) - - return total_size - - -def create_dummy_action() -> torch.Tensor: - return torch.randn(4) - - -def dict_properties() -> list: - return ["state", "next_state"] - - -@pytest.fixture -def dummy_state() -> dict: - return create_dummy_state() - - -@pytest.fixture -def next_dummy_state() -> dict: - return create_dummy_state() - - -@pytest.fixture -def dummy_action() -> torch.Tensor: - return torch.randn(4) - - -def test_empty_buffer_sample_raises_error(replay_buffer): - assert len(replay_buffer) == 0, "Replay buffer should be empty." - assert replay_buffer.capacity == 10, "Replay buffer capacity should be 10." - with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"): - replay_buffer.sample(1) - - -def test_zero_capacity_buffer_raises_error(): - with pytest.raises(ValueError, match="Capacity must be greater than 0."): - ReplayBuffer(0, "cpu", ["observation", "next_observation"]) - - -def test_add_transition(replay_buffer, dummy_state, dummy_action): - replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) - assert len(replay_buffer) == 1, "Replay buffer should have one transition after adding." - assert torch.equal(replay_buffer.actions[0], dummy_action), ( - "Action should be equal to the first transition." - ) - assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition." - assert not replay_buffer.dones[0], "Done should be False for the first transition." - assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition." - - for dim in state_dims(): - assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), ( - "Observation should be equal to the first transition." - ) - assert torch.equal(replay_buffer.next_states[dim][0], dummy_state[dim]), ( - "Next observation should be equal to the first transition." - ) - - -def test_add_over_capacity(): - replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"]) - dummy_state_1 = create_dummy_state() - dummy_action_1 = create_dummy_action() - - dummy_state_2 = create_dummy_state() - dummy_action_2 = create_dummy_action() - - dummy_state_3 = create_dummy_state() - dummy_action_3 = create_dummy_action() - - replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) - replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False) - replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True) - - assert len(replay_buffer) == 2, "Replay buffer should have 2 transitions after adding 3." - - for dim in state_dims(): - assert torch.equal(replay_buffer.states[dim][0], dummy_state_3[dim]), ( - "Observation should be equal to the first transition." - ) - assert torch.equal(replay_buffer.next_states[dim][0], dummy_state_3[dim]), ( - "Next observation should be equal to the first transition." - ) - - assert torch.equal(replay_buffer.actions[0], dummy_action_3), ( - "Action should be equal to the last transition." - ) - assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition." - assert replay_buffer.dones[0], "Done should be True for the first transition." - assert replay_buffer.truncateds[0], "Truncated should be True for the first transition." - - -def test_sample_from_empty_buffer(replay_buffer): - with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"): - replay_buffer.sample(1) - - -def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state, dummy_action): - replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False) - got_batch_transition = replay_buffer.sample(1) - - expected_batch_transition = BatchTransition( - state=clone_state(dummy_state), - action=dummy_action.clone(), - reward=1.0, - next_state=clone_state(next_dummy_state), - done=False, - truncated=False, - ) - - for buffer_property in dict_properties(): - for k, v in expected_batch_transition[buffer_property].items(): - got_state = got_batch_transition[buffer_property][k] - - assert got_state.shape[0] == 1, f"{k} should have 1 transition." - assert got_state.device.type == "cpu", f"{k} should be on cpu." - - assert torch.equal(got_state[0], v), f"{k} should be equal to the expected batch transition." - - for key, _value in expected_batch_transition.items(): - if key in dict_properties(): - continue - - got_value = got_batch_transition[key] - - v_tensor = expected_batch_transition[key] - if not isinstance(v_tensor, torch.Tensor): - v_tensor = torch.tensor(v_tensor) - - assert got_value.shape[0] == 1, f"{key} should have 1 transition." - assert got_value.device.type == "cpu", f"{key} should be on cpu." - assert torch.equal(got_value[0], v_tensor), f"{key} should be equal to the expected batch transition." - - -def test_sample_with_batch_bigger_than_buffer_size( - replay_buffer, dummy_state, next_dummy_state, dummy_action -): - replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False) - got_batch_transition = replay_buffer.sample(10) - - expected_batch_transition = BatchTransition( - state=dummy_state, - action=dummy_action, - reward=1.0, - next_state=next_dummy_state, - done=False, - truncated=False, - ) - - for buffer_property in dict_properties(): - for k in expected_batch_transition[buffer_property]: - got_state = got_batch_transition[buffer_property][k] - - assert got_state.shape[0] == 1, f"{k} should have 1 transition." - - for key in expected_batch_transition: - if key in dict_properties(): - continue - - got_value = got_batch_transition[key] - assert got_value.shape[0] == 1, f"{key} should have 1 transition." - - -def test_sample_batch(replay_buffer): - dummy_state_1 = create_dummy_state() - dummy_action_1 = create_dummy_action() - - dummy_state_2 = create_dummy_state() - dummy_action_2 = create_dummy_action() - - dummy_state_3 = create_dummy_state() - dummy_action_3 = create_dummy_action() - - dummy_state_4 = create_dummy_state() - dummy_action_4 = create_dummy_action() - - replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) - replay_buffer.add(dummy_state_2, dummy_action_2, 2.0, dummy_state_2, False, False) - replay_buffer.add(dummy_state_3, dummy_action_3, 3.0, dummy_state_3, True, True) - replay_buffer.add(dummy_state_4, dummy_action_4, 4.0, dummy_state_4, True, True) - - dummy_states = [dummy_state_1, dummy_state_2, dummy_state_3, dummy_state_4] - dummy_actions = [dummy_action_1, dummy_action_2, dummy_action_3, dummy_action_4] - - got_batch_transition = replay_buffer.sample(3) - - for buffer_property in dict_properties(): - for k in got_batch_transition[buffer_property]: - got_state = got_batch_transition[buffer_property][k] - - assert got_state.shape[0] == 3, f"{k} should have 3 transition." - - for got_state_item in got_state: - assert any(torch.equal(got_state_item, dummy_state[k]) for dummy_state in dummy_states), ( - f"{k} should be equal to one of the dummy states." - ) - - for got_action_item in got_batch_transition["action"]: - assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), ( - "Actions should be equal to the dummy actions." - ) - - for k in got_batch_transition: - if k in dict_properties() or k == "complementary_info": - continue - - got_value = got_batch_transition[k] - assert got_value.shape[0] == 3, f"{k} should have 3 transition." - - -def test_to_lerobot_dataset_with_empty_buffer(replay_buffer): - with pytest.raises(ValueError, match="The replay buffer is empty. Cannot convert to a dataset."): - replay_buffer.to_lerobot_dataset("dummy_repo") - - -def test_to_lerobot_dataset(tmp_path): - ds, buffer = create_dataset_from_replay_buffer(tmp_path) - - assert len(ds) == len(buffer), "Dataset should have the same size as the Replay Buffer" - assert ds.fps == 1, "FPS should be 1" - assert ds.repo_id == "dummy/repo", "The dataset should have `dummy/repo` repo id" - - for dim in state_dims(): - assert dim in ds.features - assert ds.features[dim]["shape"] == buffer.states[dim][0].shape - - assert ds.num_episodes == 2 - assert ds.num_frames == 4 - - for j, value in enumerate(ds): - print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j])) - - for i in range(len(ds)): - for feature, value in ds[i].items(): - if feature == "action": - assert torch.equal(value, buffer.actions[i]) - elif feature == "next.reward": - assert torch.equal(value, buffer.rewards[i]) - elif feature == "next.done": - assert torch.equal(value, buffer.dones[i]) - elif feature == "observation.image": - # Tenssor -> numpy is not precise, so we have some diff there - # TODO: Check and fix it - torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003) - elif feature == "observation.state": - assert torch.equal(value, buffer.states["observation.state"][i]) - - -def test_from_lerobot_dataset(tmp_path): - dummy_state_1 = create_dummy_state() - dummy_action_1 = create_dummy_action() - - dummy_state_2 = create_dummy_state() - dummy_action_2 = create_dummy_action() - - dummy_state_3 = create_dummy_state() - dummy_action_3 = create_dummy_action() - - dummy_state_4 = create_dummy_state() - dummy_action_4 = create_dummy_action() - - replay_buffer = create_empty_replay_buffer() - replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) - replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False) - replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True) - replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True) - - root = tmp_path / "test" - ds = replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root) - - reconverted_buffer = ReplayBuffer.from_lerobot_dataset( - ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False - ) - - # Check only the part of the buffer that's actually filled with data - assert torch.equal( - reconverted_buffer.actions[: len(replay_buffer)], - replay_buffer.actions[: len(replay_buffer)], - ), "Actions from converted buffer should be equal to the original replay buffer." - assert torch.equal( - reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)] - ), "Rewards from converted buffer should be equal to the original replay buffer." - assert torch.equal( - reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)] - ), "Dones from converted buffer should be equal to the original replay buffer." - - # Lerobot DS haven't supported truncateds yet - expected_truncateds = torch.zeros(len(replay_buffer)).bool() - assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), ( - "Truncateds from converted buffer should be equal False" - ) - - assert torch.equal( - replay_buffer.states["observation.state"][: len(replay_buffer)], - reconverted_buffer.states["observation.state"][: len(replay_buffer)], - ), "State should be the same after converting to dataset and return back" - - for i in range(4): - torch.testing.assert_close( - replay_buffer.states["observation.image"][i], - reconverted_buffer.states["observation.image"][i], - rtol=0.4, - atol=0.004, - ) - - # The 2, 3 frames have done flag, so their values will be equal to the current state - for i in range(2): - # In the current implementation we take the next state from the `states` and ignore `next_states` - next_index = (i + 1) % 4 - - torch.testing.assert_close( - replay_buffer.states["observation.image"][next_index], - reconverted_buffer.next_states["observation.image"][i], - rtol=0.4, - atol=0.004, - ) - - for i in range(2, 4): - assert torch.equal( - replay_buffer.states["observation.state"][i], - reconverted_buffer.next_states["observation.state"][i], - ) - - -def test_buffer_sample_alignment(): - # Initialize buffer - buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu") - - # Fill buffer with patterned data - for i in range(100): - signature = float(i) / 100.0 - state = {"state_value": torch.tensor([[signature]]).float()} - action = torch.tensor([[2.0 * signature]]).float() - reward = 3.0 * signature - - is_end = (i + 1) % 10 == 0 - if is_end: - next_state = {"state_value": torch.tensor([[signature]]).float()} - done = True - else: - next_signature = float(i + 1) / 100.0 - next_state = {"state_value": torch.tensor([[next_signature]]).float()} - done = False - - buffer.add(state, action, reward, next_state, done, False) - - # Sample and verify - batch = buffer.sample(50) - - for i in range(50): - state_sig = batch["state"]["state_value"][i].item() - action_val = batch["action"][i].item() - reward_val = batch["reward"][i].item() - next_state_sig = batch["next_state"]["state_value"][i].item() - is_done = batch["done"][i].item() > 0.5 - - # Verify relationships - assert abs(action_val - 2.0 * state_sig) < 1e-4, ( - f"Action {action_val} should be 2x state signature {state_sig}" - ) - - assert abs(reward_val - 3.0 * state_sig) < 1e-4, ( - f"Reward {reward_val} should be 3x state signature {state_sig}" - ) - - if is_done: - assert abs(next_state_sig - state_sig) < 1e-4, ( - f"For done states, next_state {next_state_sig} should equal state {state_sig}" - ) - else: - # Either it's the next sequential state (+0.01) or same state (for episode boundaries) - valid_next = ( - abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4 - ) - assert valid_next, ( - f"Next state {next_state_sig} should be either state+0.01 or same as state {state_sig}" - ) - - -def test_memory_optimization(): - dummy_state_1 = create_dummy_state() - dummy_action_1 = create_dummy_action() - - dummy_state_2 = create_dummy_state() - dummy_action_2 = create_dummy_action() - - dummy_state_3 = create_dummy_state() - dummy_action_3 = create_dummy_action() - - dummy_state_4 = create_dummy_state() - dummy_action_4 = create_dummy_action() - - replay_buffer = create_empty_replay_buffer() - replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False) - replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False) - replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False) - replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True) - - optimized_replay_buffer = create_empty_replay_buffer(True) - optimized_replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False) - optimized_replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False) - optimized_replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False) - optimized_replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, None, True, True) - - assert get_object_memory(optimized_replay_buffer) < get_object_memory(replay_buffer), ( - "Optimized replay buffer should be smaller than the original replay buffer" - ) - - -def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_function(dummy_state, dummy_action): - def dummy_image_augmentation_function(x): - return torch.ones_like(x) * 10 - - replay_buffer = create_empty_replay_buffer( - use_drq=True, image_augmentation_function=dummy_image_augmentation_function - ) - - replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) - - sampled_transitions = replay_buffer.sample(1) - assert torch.all(sampled_transitions["state"]["observation.image"] == 10), ( - "Image augmentations should be applied" - ) - assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), ( - "Image augmentations should be applied" - ) - - -def test_check_image_augmentations_with_drq_and_default_image_augmentation_function( - dummy_state, dummy_action -): - replay_buffer = create_empty_replay_buffer(use_drq=True) - - replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) - - # Let's check that it doesn't fail and shapes are correct - sampled_transitions = replay_buffer.sample(1) - assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84) - assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84) - - -def test_random_crop_vectorized_basic(): - # Create a batch of 2 images with known patterns - batch_size, channels, height, width = 2, 3, 10, 8 - images = torch.zeros((batch_size, channels, height, width)) - - # Fill with unique values for testing - for b in range(batch_size): - images[b] = b + 1 - - crop_size = (6, 4) # Smaller than original - cropped = random_crop_vectorized(images, crop_size) - - # Check output shape - assert cropped.shape == (batch_size, channels, *crop_size) - - # Check that values are preserved (should be either 1s or 2s for respective batches) - assert torch.all(cropped[0] == 1) - assert torch.all(cropped[1] == 2) - - -def test_random_crop_vectorized_invalid_size(): - images = torch.zeros((2, 3, 10, 8)) - - # Test crop size larger than image - with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"): - random_crop_vectorized(images, (12, 8)) - - with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"): - random_crop_vectorized(images, (10, 10)) - - -def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer: - """Create a small buffer with deterministic 3×128×128 images and 11-D state.""" - buffer = ReplayBuffer( - capacity=capacity, - device="cpu", - state_keys=["observation.image", "observation.state"], - storage_device="cpu", - ) - - for i in range(capacity): - img = torch.ones(3, 128, 128) * i - state_vec = torch.arange(11).float() + i - state = { - "observation.image": img, - "observation.state": state_vec, - } - buffer.add( - state=state, - action=torch.tensor([0.0]), - reward=0.0, - next_state=state, - done=False, - truncated=False, - ) - return buffer - - -def test_async_iterator_shapes_basic(): - buffer = _populate_buffer_for_async_test() - batch_size = 2 - iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1) - batch = next(iterator) - - images = batch["state"]["observation.image"] - states = batch["state"]["observation.state"] - - assert images.shape == (batch_size, 3, 128, 128) - assert states.shape == (batch_size, 11) - - next_images = batch["next_state"]["observation.image"] - next_states = batch["next_state"]["observation.state"] - - assert next_images.shape == (batch_size, 3, 128, 128) - assert next_states.shape == (batch_size, 11) - - -def test_async_iterator_multiple_iterations(): - buffer = _populate_buffer_for_async_test() - batch_size = 2 - iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=2) - - for _ in range(5): - batch = next(iterator) - images = batch["state"]["observation.image"] - states = batch["state"]["observation.state"] - assert images.shape == (batch_size, 3, 128, 128) - assert states.shape == (batch_size, 11) - - next_images = batch["next_state"]["observation.image"] - next_states = batch["next_state"]["observation.state"] - assert next_images.shape == (batch_size, 3, 128, 128) - assert next_states.shape == (batch_size, 11) - - # Ensure iterator can be disposed without blocking - del iterator diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py deleted file mode 100644 index 2d963d7..0000000 --- a/tests/utils/test_train_utils.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pathlib import Path -from unittest.mock import Mock, patch - -from lerobot.constants import ( - CHECKPOINTS_DIR, - LAST_CHECKPOINT_LINK, - OPTIMIZER_PARAM_GROUPS, - OPTIMIZER_STATE, - RNG_STATE, - SCHEDULER_STATE, - TRAINING_STATE_DIR, - TRAINING_STEP, -) -from lerobot.utils.train_utils import ( - get_step_checkpoint_dir, - get_step_identifier, - load_training_state, - load_training_step, - save_checkpoint, - save_training_state, - save_training_step, - update_last_checkpoint, -) - - -def test_get_step_identifier(): - assert get_step_identifier(5, 1000) == "000005" - assert get_step_identifier(123, 100_000) == "000123" - assert get_step_identifier(456789, 1_000_000) == "0456789" - - -def test_get_step_checkpoint_dir(): - output_dir = Path("/checkpoints") - step_dir = get_step_checkpoint_dir(output_dir, 1000, 5) - assert step_dir == output_dir / CHECKPOINTS_DIR / "000005" - - -def test_save_load_training_step(tmp_path): - save_training_step(5000, tmp_path) - assert (tmp_path / TRAINING_STEP).is_file() - - -def test_load_training_step(tmp_path): - step = 5000 - save_training_step(step, tmp_path) - loaded_step = load_training_step(tmp_path) - assert loaded_step == step - - -def test_update_last_checkpoint(tmp_path): - checkpoint = tmp_path / "0005" - checkpoint.mkdir() - update_last_checkpoint(checkpoint) - last_checkpoint = tmp_path / LAST_CHECKPOINT_LINK - assert last_checkpoint.is_symlink() - assert last_checkpoint.resolve() == checkpoint - - -@patch("lerobot.utils.train_utils.save_training_state") -def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer): - policy = Mock() - cfg = Mock() - save_checkpoint(tmp_path, 10, cfg, policy, optimizer) - policy.save_pretrained.assert_called_once() - cfg.save_pretrained.assert_called_once() - mock_save_training_state.assert_called_once() - - -def test_save_training_state(tmp_path, optimizer, scheduler): - save_training_state(tmp_path, 10, optimizer, scheduler) - assert (tmp_path / TRAINING_STATE_DIR).is_dir() - assert (tmp_path / TRAINING_STATE_DIR / TRAINING_STEP).is_file() - assert (tmp_path / TRAINING_STATE_DIR / RNG_STATE).is_file() - assert (tmp_path / TRAINING_STATE_DIR / OPTIMIZER_STATE).is_file() - assert (tmp_path / TRAINING_STATE_DIR / OPTIMIZER_PARAM_GROUPS).is_file() - assert (tmp_path / TRAINING_STATE_DIR / SCHEDULER_STATE).is_file() - - -def test_save_load_training_state(tmp_path, optimizer, scheduler): - save_training_state(tmp_path, 10, optimizer, scheduler) - loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(tmp_path, optimizer, scheduler) - assert loaded_step == 10 - assert loaded_optimizer is optimizer - assert loaded_scheduler is scheduler