移除所有测试代码以节省空间
This commit is contained in:
parent
1e7bb40565
commit
026bd915b1
5
.gitignore
vendored
5
.gitignore
vendored
@ -180,4 +180,7 @@ s100
|
|||||||
|
|
||||||
huggingface_models
|
huggingface_models
|
||||||
docker/inputs
|
docker/inputs
|
||||||
docker/outputs
|
docker/outputs
|
||||||
|
|
||||||
|
# Skip big files in tests folder
|
||||||
|
tests
|
||||||
@ -1,13 +0,0 @@
|
|||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:9dc9df05797dc0e7b92edc845caab2e4c37c3cfcabb4ee6339c67212b5baba3b
|
|
||||||
size 38023
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:7e11af87616b83c1cdb30330e951b91e86b51c64a1326e1ba5b4a3fbcdec1a11
|
|
||||||
size 55698
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:b8840fb643afe903191248703b1f95a57faf5812ecd9978ac502ee939646fdb2
|
|
||||||
size 121115
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:f79d14daafb1c0cf2fec5d46ee8029a73fe357402fdd31a7cd4a4794d7319a7c
|
|
||||||
size 260367
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:a8d6e64d6cb0e02c94ae125630ee758055bd2e695772c0463a30d63ddc6c5e17
|
|
||||||
size 3520862
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:6bdf22208d49cd36d24bc844d4d8bda5e321eafe39d2b470e4fc95c7812fdb24
|
|
||||||
size 3687117
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:8920d5ebab36ffcba9aa74dcd91677c121f504b4d945b472352d379f9272fabf
|
|
||||||
size 3687117
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:35723f2db499da3d9d121aa79d2ff4c748effd7c2ea92f277ec543a82fb843ca
|
|
||||||
size 3687117
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:53172b773d4a78bb3140f10280105c2c4ebcb467f3097579988d42cb87790ab9
|
|
||||||
size 3687117
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:58a5d91573e7dd2352a1454a5c9118c9ad3798428a0104e5e0b57fc01f780ae7
|
|
||||||
size 3687117
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:bb65a25e989a32a8b6258d368bd077e4548379c74ab5ada01cc532d658670df0
|
|
||||||
size 3687117
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:c3dcff0a705ebfdaf11b7f49ad85b464eff03477ace3d63ce45d6a3a10b429d5
|
|
||||||
size 111338
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:d8ab0274761cdd758bafdf274ce3e6398cd6f0df23393971f3e1b6b465d66ef3
|
|
||||||
size 111338
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:aee60956925da9687546aafa770d5e6a04f99576f903b08d0bd5f8003a7f4f3e
|
|
||||||
size 111338
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:c8d9f9cc9e232820760fe4a46b47000c921fa5d868420e55d8dbc05dae56e8bd
|
|
||||||
size 111338
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:01cfe50c537e3aef0cd5947ec0b15b321b54ecb461baf7b4f2506897158eebc8
|
|
||||||
size 111338
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:96431ca3479eef2379406ef901cad7ba5eac4f7edcc48ecc9e8d1fa0e99d8017
|
|
||||||
size 111338
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:3763d7bff7873cb40ea9d6f2f98d45fcf163addcd2809b6c59f273b6c3627ad5
|
|
||||||
size 85353
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:24150994c6959631dc081b43e4001a8664e13b194ac194a32100f7d3fd2c0d0f
|
|
||||||
size 85353
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:c9c3fdf34debe47d4b80570a19e676185449df749f37daa2111184c1f439ae5f
|
|
||||||
size 85353
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:f8cfbe444c14d643da2faea9f6a402ddb37114ab15395c381f1a7982e541f868
|
|
||||||
size 85353
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:07c5c1a63998884ee747a6d0aa8f49217da3c32af2760dad2a9da794d3517003
|
|
||||||
size 85353
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:9927ec508e3335f8b10cf3682e41dedb7e647f92a2063a4196f1e48749c47bc5
|
|
||||||
size 85353
|
|
||||||
@ -1,91 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility
|
|
||||||
when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frame from the
|
|
||||||
dataset into a corresponding safetensors file in a specified output directory.
|
|
||||||
|
|
||||||
If you know that your change will break backward compatibility, you should write a shortlived test by modifying
|
|
||||||
`tests/test_datasets.py::test_backward_compatibility` accordingly, and make sure this custom test pass. Your custom test
|
|
||||||
doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts.
|
|
||||||
|
|
||||||
Example usage:
|
|
||||||
`python tests/artifacts/datasets/save_dataset_to_safetensors.py`
|
|
||||||
"""
|
|
||||||
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
|
|
||||||
|
|
||||||
def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
|
||||||
repo_dir = Path(output_dir) / repo_id
|
|
||||||
|
|
||||||
if repo_dir.exists():
|
|
||||||
shutil.rmtree(repo_dir)
|
|
||||||
|
|
||||||
repo_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
dataset = LeRobotDataset(
|
|
||||||
repo_id=repo_id,
|
|
||||||
episodes=[0],
|
|
||||||
)
|
|
||||||
|
|
||||||
# save 2 first frames of first episode
|
|
||||||
i = dataset.episode_data_index["from"][0].item()
|
|
||||||
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
|
||||||
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
|
|
||||||
|
|
||||||
# save 2 frames at the middle of first episode
|
|
||||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
|
||||||
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
|
||||||
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
|
|
||||||
|
|
||||||
# save 2 last frames of first episode
|
|
||||||
i = dataset.episode_data_index["to"][0].item()
|
|
||||||
save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
|
|
||||||
save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")
|
|
||||||
|
|
||||||
# TODO(rcadene): Enable testing on second and last episode
|
|
||||||
# We currently cant because our test dataset only contains the first episode
|
|
||||||
|
|
||||||
# # save 2 first frames of second episode
|
|
||||||
# i = dataset.episode_data_index["from"][1].item()
|
|
||||||
# save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
|
||||||
# save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
|
|
||||||
|
|
||||||
# # save 2 last frames of second episode
|
|
||||||
# i = dataset.episode_data_index["to"][1].item()
|
|
||||||
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
|
|
||||||
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
|
|
||||||
|
|
||||||
# # save 2 last frames of last episode
|
|
||||||
# i = dataset.episode_data_index["to"][-1].item()
|
|
||||||
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
|
|
||||||
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
for dataset in [
|
|
||||||
"lerobot/pusht",
|
|
||||||
"lerobot/aloha_sim_insertion_human",
|
|
||||||
"lerobot/xarm_lift_medium",
|
|
||||||
"lerobot/nyu_franka_play_dataset",
|
|
||||||
"lerobot/cmu_stretch",
|
|
||||||
]:
|
|
||||||
save_dataset_to_safetensors("tests/artifacts/datasets", repo_id=dataset)
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:6b1e600768a8771c5fe650e038a1193597e3810f032041b2a0d021e4496381c1
|
|
||||||
size 3686488
|
|
||||||
@ -1,75 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
from lerobot.datasets.transforms import (
|
|
||||||
ImageTransformConfig,
|
|
||||||
ImageTransforms,
|
|
||||||
ImageTransformsConfig,
|
|
||||||
make_transform_from_config,
|
|
||||||
)
|
|
||||||
from lerobot.utils.random_utils import seeded_context
|
|
||||||
|
|
||||||
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
|
|
||||||
DATASET_REPO_ID = "lerobot/aloha_static_cups_open"
|
|
||||||
|
|
||||||
|
|
||||||
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
|
|
||||||
cfg = ImageTransformsConfig(enable=True)
|
|
||||||
default_tf = ImageTransforms(cfg)
|
|
||||||
|
|
||||||
with seeded_context(1337):
|
|
||||||
img_tf = default_tf(original_frame)
|
|
||||||
|
|
||||||
save_file({"default": img_tf}, output_dir / "default_transforms.safetensors")
|
|
||||||
|
|
||||||
|
|
||||||
def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
|
|
||||||
transforms = {
|
|
||||||
("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]),
|
|
||||||
("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]),
|
|
||||||
("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]),
|
|
||||||
("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]),
|
|
||||||
("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
|
||||||
}
|
|
||||||
|
|
||||||
frames = {"original_frame": original_frame}
|
|
||||||
for tf_type, tf_name, min_max_values in transforms.items():
|
|
||||||
for min_max in min_max_values:
|
|
||||||
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
|
|
||||||
tf = make_transform_from_config(tf_cfg)
|
|
||||||
key = f"{tf_name}_{min_max[0]}_{min_max[1]}"
|
|
||||||
frames[key] = tf(original_frame)
|
|
||||||
|
|
||||||
save_file(frames, output_dir / "single_transforms.safetensors")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
dataset = LeRobotDataset(DATASET_REPO_ID, episodes=[0], image_transforms=None)
|
|
||||||
output_dir = Path(ARTIFACT_DIR)
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
original_frame = dataset[0][dataset.meta.camera_keys[0]]
|
|
||||||
|
|
||||||
save_single_transforms(original_frame, output_dir)
|
|
||||||
save_default_config_transform(original_frame, output_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:9d4ebab73eabddc58879a4e770289d19e00a1a4cf2fa5fa33cd3a3246992bc90
|
|
||||||
size 40551392
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
|
|
||||||
size 5104
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:1a7a8b1a457149109f843c32bcbb047d09de2201847b9b79f7501b447f77ecf4
|
|
||||||
size 31672
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:5e6ce85296b2009e7c2060d336c0429b1c7197d9adb159e7df0ba18003067b36
|
|
||||||
size 68
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
|
|
||||||
size 33400
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
|
|
||||||
size 515400
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:224b5fa4828aa88171b68c036e8919c1eae563e2113f03b6461eadf5bf8525a6
|
|
||||||
size 31672
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:016d2fa8fe5f58017dfd46f4632fdc19dfd751e32a2c7cde2077c6f95546d6bd
|
|
||||||
size 68
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
|
|
||||||
size 33400
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c
|
|
||||||
size 992
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201
|
|
||||||
size 47424
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
|
|
||||||
size 68
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22
|
|
||||||
size 49120
|
|
||||||
@ -1,145 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
|
|
||||||
from lerobot.configs.default import DatasetConfig
|
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
|
||||||
from lerobot.datasets.factory import make_dataset
|
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
|
||||||
from lerobot.policies.factory import make_policy, make_policy_config
|
|
||||||
from lerobot.utils.random_utils import set_seed
|
|
||||||
|
|
||||||
|
|
||||||
def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
|
||||||
set_seed(1337)
|
|
||||||
train_cfg = TrainPipelineConfig(
|
|
||||||
# TODO(rcadene, aliberts): remove dataset download
|
|
||||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
|
||||||
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
|
|
||||||
)
|
|
||||||
train_cfg.validate() # Needed for auto-setting some parameters
|
|
||||||
|
|
||||||
dataset = make_dataset(train_cfg)
|
|
||||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
dataset,
|
|
||||||
num_workers=0,
|
|
||||||
batch_size=train_cfg.batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
batch = next(iter(dataloader))
|
|
||||||
loss, output_dict = policy.forward(batch)
|
|
||||||
if output_dict is not None:
|
|
||||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
|
||||||
output_dict["loss"] = loss
|
|
||||||
else:
|
|
||||||
output_dict = {"loss": loss}
|
|
||||||
|
|
||||||
loss.backward()
|
|
||||||
grad_stats = {}
|
|
||||||
for key, param in policy.named_parameters():
|
|
||||||
if param.requires_grad:
|
|
||||||
grad_stats[f"{key}_mean"] = param.grad.mean()
|
|
||||||
grad_stats[f"{key}_std"] = (
|
|
||||||
param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0))
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
param_stats = {}
|
|
||||||
for key, param in policy.named_parameters():
|
|
||||||
param_stats[f"{key}_mean"] = param.mean()
|
|
||||||
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
policy.reset()
|
|
||||||
|
|
||||||
# HACK: We reload a batch with no delta_indices as `select_action` won't expect a timestamps dimension
|
|
||||||
# We simulate having an environment using a dataset by setting delta_indices to None and dropping tensors
|
|
||||||
# indicating padding (those ending with "_is_pad")
|
|
||||||
dataset.delta_indices = None
|
|
||||||
batch = next(iter(dataloader))
|
|
||||||
obs = {}
|
|
||||||
for k in batch:
|
|
||||||
# TODO: regenerate the safetensors
|
|
||||||
# for backward compatibility
|
|
||||||
if k.endswith("_is_pad"):
|
|
||||||
continue
|
|
||||||
# for backward compatibility
|
|
||||||
if k == "task":
|
|
||||||
continue
|
|
||||||
if k.startswith("observation"):
|
|
||||||
obs[k] = batch[k]
|
|
||||||
|
|
||||||
if hasattr(train_cfg.policy, "n_action_steps"):
|
|
||||||
actions_queue = train_cfg.policy.n_action_steps
|
|
||||||
else:
|
|
||||||
actions_queue = train_cfg.policy.n_action_repeats
|
|
||||||
|
|
||||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
|
||||||
return output_dict, grad_stats, param_stats, actions
|
|
||||||
|
|
||||||
|
|
||||||
def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
|
||||||
if output_dir.exists():
|
|
||||||
print(f"Overwrite existing safetensors in '{output_dir}':")
|
|
||||||
print(f" - Validate with: `git add {output_dir}`")
|
|
||||||
print(f" - Revert with: `git checkout -- {output_dir}`")
|
|
||||||
shutil.rmtree(output_dir)
|
|
||||||
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
|
|
||||||
save_file(output_dict, output_dir / "output_dict.safetensors")
|
|
||||||
save_file(grad_stats, output_dir / "grad_stats.safetensors")
|
|
||||||
save_file(param_stats, output_dir / "param_stats.safetensors")
|
|
||||||
save_file(actions, output_dir / "actions.safetensors")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
artifacts_cfg = [
|
|
||||||
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
|
|
||||||
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
|
|
||||||
(
|
|
||||||
"lerobot/pusht",
|
|
||||||
"diffusion",
|
|
||||||
{
|
|
||||||
"n_action_steps": 8,
|
|
||||||
"num_inference_steps": 10,
|
|
||||||
"down_dims": [128, 256, 512],
|
|
||||||
},
|
|
||||||
"",
|
|
||||||
),
|
|
||||||
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
|
|
||||||
(
|
|
||||||
"lerobot/aloha_sim_insertion_human",
|
|
||||||
"act",
|
|
||||||
{"n_action_steps": 1000, "chunk_size": 1000},
|
|
||||||
"1000_steps",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
if len(artifacts_cfg) == 0:
|
|
||||||
raise RuntimeError("No policies were provided!")
|
|
||||||
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
|
|
||||||
ds_name = ds_repo_id.split("/")[-1]
|
|
||||||
output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
|
|
||||||
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b
|
|
||||||
size 200
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
|
|
||||||
size 16904
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
|
|
||||||
size 164
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
|
|
||||||
size 36312
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb
|
|
||||||
size 200
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
|
|
||||||
size 16904
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
|
|
||||||
size 164
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
|
|
||||||
size 36312
|
|
||||||
@ -1,177 +0,0 @@
|
|||||||
# Copyright 2025 The HuggingFace Inc. team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""End-to-end test of the asynchronous inference stack (client ↔ server).
|
|
||||||
|
|
||||||
This test spins up a lightweight gRPC `PolicyServer` instance with a stubbed
|
|
||||||
policy network and launches a `RobotClient` that uses a `MockRobot`. The goal
|
|
||||||
is to exercise the full communication loop:
|
|
||||||
|
|
||||||
1. Client sends policy specification → Server
|
|
||||||
2. Client streams observations → Server
|
|
||||||
3. Server streams action chunks → Client
|
|
||||||
4. Client executes received actions
|
|
||||||
|
|
||||||
The test succeeds if at least one action is executed and the server records at
|
|
||||||
least one predicted timestep - demonstrating that the gRPC round-trip works
|
|
||||||
end-to-end using real (but lightweight) protocol messages.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import threading
|
|
||||||
from concurrent import futures
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# Skip entire module if grpc is not available
|
|
||||||
pytest.importorskip("grpc")
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# End-to-end test
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_async_inference_e2e(monkeypatch):
|
|
||||||
"""Tests the full asynchronous inference pipeline."""
|
|
||||||
# Import grpc-dependent modules inside the test function
|
|
||||||
import grpc
|
|
||||||
|
|
||||||
from lerobot.robots.utils import make_robot_from_config
|
|
||||||
from lerobot.scripts.server.configs import PolicyServerConfig, RobotClientConfig
|
|
||||||
from lerobot.scripts.server.helpers import map_robot_keys_to_lerobot_features
|
|
||||||
from lerobot.scripts.server.policy_server import PolicyServer
|
|
||||||
from lerobot.scripts.server.robot_client import RobotClient
|
|
||||||
from lerobot.transport import (
|
|
||||||
services_pb2, # type: ignore
|
|
||||||
services_pb2_grpc, # type: ignore
|
|
||||||
)
|
|
||||||
from tests.mocks.mock_robot import MockRobotConfig
|
|
||||||
|
|
||||||
# Create a stub policy similar to test_policy_server.py
|
|
||||||
class MockPolicy:
|
|
||||||
"""A minimal mock for an actual policy, returning zeros."""
|
|
||||||
|
|
||||||
class _Config:
|
|
||||||
robot_type = "dummy_robot"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def image_features(self):
|
|
||||||
"""Empty image features since this test doesn't use images."""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.config = self._Config()
|
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def model(self, batch):
|
|
||||||
# Return a chunk of 20 dummy actions.
|
|
||||||
batch_size = len(batch["robot_type"])
|
|
||||||
return torch.zeros(batch_size, 20, 6)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 1. Create PolicyServer instance with mock policy
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
policy_server_config = PolicyServerConfig(host="localhost", port=9999)
|
|
||||||
policy_server = PolicyServer(policy_server_config)
|
|
||||||
# Replace the real policy with our fast, deterministic stub.
|
|
||||||
policy_server.policy = MockPolicy()
|
|
||||||
policy_server.actions_per_chunk = 20
|
|
||||||
policy_server.device = "cpu"
|
|
||||||
|
|
||||||
# Set up robot config and features
|
|
||||||
robot_config = MockRobotConfig()
|
|
||||||
mock_robot = make_robot_from_config(robot_config)
|
|
||||||
|
|
||||||
lerobot_features = map_robot_keys_to_lerobot_features(mock_robot)
|
|
||||||
policy_server.lerobot_features = lerobot_features
|
|
||||||
|
|
||||||
# Force server to produce deterministic action chunks in test mode
|
|
||||||
policy_server.policy_type = "act"
|
|
||||||
|
|
||||||
def _fake_get_action_chunk(_self, _obs, _type="test"):
|
|
||||||
action_dim = 6
|
|
||||||
batch_size = 1
|
|
||||||
actions_per_chunk = policy_server.actions_per_chunk
|
|
||||||
|
|
||||||
return torch.zeros(batch_size, actions_per_chunk, action_dim)
|
|
||||||
|
|
||||||
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
|
|
||||||
|
|
||||||
# Bypass potentially heavy model loading inside SendPolicyInstructions
|
|
||||||
def _fake_send_policy_instructions(self, request, context): # noqa: N802
|
|
||||||
return services_pb2.Empty()
|
|
||||||
|
|
||||||
monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)
|
|
||||||
|
|
||||||
# Build gRPC server running a PolicyServer
|
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
|
|
||||||
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
|
||||||
|
|
||||||
# Use the host/port specified in the fixture's config
|
|
||||||
server_address = f"{policy_server.config.host}:{policy_server.config.port}"
|
|
||||||
server.add_insecure_port(server_address)
|
|
||||||
server.start()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 2. Create a RobotClient around the MockRobot
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
client_config = RobotClientConfig(
|
|
||||||
server_address=server_address,
|
|
||||||
robot=robot_config,
|
|
||||||
chunk_size_threshold=0.0,
|
|
||||||
policy_type="test",
|
|
||||||
pretrained_name_or_path="test",
|
|
||||||
actions_per_chunk=20,
|
|
||||||
verify_robot_cameras=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
client = RobotClient(client_config)
|
|
||||||
assert client.start(), "Client failed initial handshake with the server"
|
|
||||||
|
|
||||||
# Track action chunks received without modifying RobotClient
|
|
||||||
action_chunks_received = {"count": 0}
|
|
||||||
original_aggregate = client._aggregate_action_queues
|
|
||||||
|
|
||||||
def counting_aggregate(*args, **kwargs):
|
|
||||||
action_chunks_received["count"] += 1
|
|
||||||
return original_aggregate(*args, **kwargs)
|
|
||||||
|
|
||||||
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
|
|
||||||
|
|
||||||
# Start client threads
|
|
||||||
action_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
|
||||||
control_thread = threading.Thread(target=client.control_loop, args=({"task": ""}), daemon=True)
|
|
||||||
action_thread.start()
|
|
||||||
control_thread.start()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 3. System exchanges a few messages
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Wait for 5 seconds
|
|
||||||
server.wait_for_termination(timeout=5)
|
|
||||||
|
|
||||||
assert action_chunks_received["count"] > 0, "Client did not receive any action chunks"
|
|
||||||
assert len(policy_server._predicted_timesteps) > 0, "Server did not record any predicted timesteps"
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 4. Stop the system
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
client.stop()
|
|
||||||
action_thread.join()
|
|
||||||
control_thread.join()
|
|
||||||
policy_server.stop()
|
|
||||||
server.stop(grace=None)
|
|
||||||
@ -1,459 +0,0 @@
|
|||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import math
|
|
||||||
import pickle
|
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
|
||||||
from lerobot.scripts.server.helpers import (
|
|
||||||
FPSTracker,
|
|
||||||
TimedAction,
|
|
||||||
TimedObservation,
|
|
||||||
observations_similar,
|
|
||||||
prepare_image,
|
|
||||||
prepare_raw_observation,
|
|
||||||
raw_observation_to_observation,
|
|
||||||
resize_robot_observation_image,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------
|
|
||||||
# FPSTracker
|
|
||||||
# ---------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_fps_tracker_first_observation():
|
|
||||||
"""First observation should initialize timestamp and return 0 FPS."""
|
|
||||||
tracker = FPSTracker(target_fps=30.0)
|
|
||||||
timestamp = 1000.0
|
|
||||||
|
|
||||||
metrics = tracker.calculate_fps_metrics(timestamp)
|
|
||||||
|
|
||||||
assert tracker.first_timestamp == timestamp
|
|
||||||
assert tracker.total_obs_count == 1
|
|
||||||
assert metrics["avg_fps"] == 0.0
|
|
||||||
assert metrics["target_fps"] == 30.0
|
|
||||||
|
|
||||||
|
|
||||||
def test_fps_tracker_single_interval():
|
|
||||||
"""Two observations 1 second apart should give 1 FPS."""
|
|
||||||
tracker = FPSTracker(target_fps=30.0)
|
|
||||||
|
|
||||||
# First observation at t=0
|
|
||||||
metrics1 = tracker.calculate_fps_metrics(0.0)
|
|
||||||
assert metrics1["avg_fps"] == 0.0
|
|
||||||
|
|
||||||
# Second observation at t=1 (1 second later)
|
|
||||||
metrics2 = tracker.calculate_fps_metrics(1.0)
|
|
||||||
expected_fps = 1.0 # (2-1) observations / 1.0 seconds = 1 FPS
|
|
||||||
assert math.isclose(metrics2["avg_fps"], expected_fps, rel_tol=1e-6)
|
|
||||||
|
|
||||||
|
|
||||||
def test_fps_tracker_multiple_intervals():
|
|
||||||
"""Multiple observations should calculate correct average FPS."""
|
|
||||||
tracker = FPSTracker(target_fps=30.0)
|
|
||||||
|
|
||||||
# Simulate 5 observations over 2 seconds (should be 2 FPS average)
|
|
||||||
timestamps = [0.0, 0.5, 1.0, 1.5, 2.0]
|
|
||||||
|
|
||||||
for i, ts in enumerate(timestamps):
|
|
||||||
metrics = tracker.calculate_fps_metrics(ts)
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
assert metrics["avg_fps"] == 0.0
|
|
||||||
elif i == len(timestamps) - 1:
|
|
||||||
# After 5 observations over 2 seconds: (5-1)/2 = 2 FPS
|
|
||||||
expected_fps = 2.0
|
|
||||||
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
|
|
||||||
|
|
||||||
|
|
||||||
def test_fps_tracker_irregular_intervals():
|
|
||||||
"""FPS calculation should work with irregular time intervals."""
|
|
||||||
tracker = FPSTracker(target_fps=30.0)
|
|
||||||
|
|
||||||
# Irregular timestamps: 0, 0.1, 0.5, 2.0, 3.0 seconds
|
|
||||||
timestamps = [0.0, 0.1, 0.5, 2.0, 3.0]
|
|
||||||
|
|
||||||
for ts in timestamps:
|
|
||||||
metrics = tracker.calculate_fps_metrics(ts)
|
|
||||||
|
|
||||||
# 5 observations over 3 seconds: (5-1)/3 = 1.333... FPS
|
|
||||||
expected_fps = 4.0 / 3.0
|
|
||||||
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------
|
|
||||||
# TimedData helpers
|
|
||||||
# ---------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_timed_action_getters():
|
|
||||||
"""TimedAction stores & returns timestamp, action tensor and timestep."""
|
|
||||||
ts = time.time()
|
|
||||||
action = torch.arange(10)
|
|
||||||
ta = TimedAction(timestamp=ts, action=action, timestep=0)
|
|
||||||
|
|
||||||
assert math.isclose(ta.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
|
||||||
torch.testing.assert_close(ta.get_action(), action)
|
|
||||||
assert ta.get_timestep() == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_timed_observation_getters():
|
|
||||||
"""TimedObservation stores & returns timestamp, dict and timestep."""
|
|
||||||
ts = time.time()
|
|
||||||
obs_dict = {"observation.state": torch.ones(6)}
|
|
||||||
to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0)
|
|
||||||
|
|
||||||
assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
|
||||||
assert to.get_observation() is obs_dict
|
|
||||||
assert to.get_timestep() == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_timed_data_deserialization_data_getters():
|
|
||||||
"""TimedAction / TimedObservation survive a round-trip through ``pickle``.
|
|
||||||
|
|
||||||
The async-inference stack uses ``pickle.dumps`` to move these objects across
|
|
||||||
the gRPC boundary (see RobotClient.send_observation and PolicyServer.StreamActions).
|
|
||||||
This test ensures that the payload keeps its content intact after
|
|
||||||
the (de)serialization round-trip.
|
|
||||||
"""
|
|
||||||
ts = time.time()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# TimedAction
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
original_action = torch.randn(6)
|
|
||||||
ta_in = TimedAction(timestamp=ts, action=original_action, timestep=13)
|
|
||||||
|
|
||||||
# Serialize → bytes → deserialize
|
|
||||||
ta_bytes = pickle.dumps(ta_in) # nosec
|
|
||||||
ta_out: TimedAction = pickle.loads(ta_bytes) # nosec B301
|
|
||||||
|
|
||||||
# Identity & content checks
|
|
||||||
assert math.isclose(ta_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
|
||||||
assert ta_out.get_timestep() == 13
|
|
||||||
torch.testing.assert_close(ta_out.get_action(), original_action)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# TimedObservation
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
obs_dict = {"observation.state": torch.arange(4).float()}
|
|
||||||
to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True)
|
|
||||||
|
|
||||||
to_bytes = pickle.dumps(to_in) # nosec
|
|
||||||
to_out: TimedObservation = pickle.loads(to_bytes) # nosec B301
|
|
||||||
|
|
||||||
assert math.isclose(to_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
|
||||||
assert to_out.get_timestep() == 7
|
|
||||||
assert to_out.must_go is True
|
|
||||||
assert to_out.get_observation().keys() == obs_dict.keys()
|
|
||||||
torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------
|
|
||||||
# observations_similar()
|
|
||||||
# ---------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _make_obs(state: torch.Tensor) -> TimedObservation:
|
|
||||||
"""Create a TimedObservation with raw robot observation format."""
|
|
||||||
return TimedObservation(
|
|
||||||
timestamp=time.time(),
|
|
||||||
observation={
|
|
||||||
"shoulder": state[0].item() if len(state) > 0 else 0.0,
|
|
||||||
"elbow": state[1].item() if len(state) > 1 else 0.0,
|
|
||||||
"wrist": state[2].item() if len(state) > 2 else 0.0,
|
|
||||||
"gripper": state[3].item() if len(state) > 3 else 0.0,
|
|
||||||
},
|
|
||||||
timestep=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_observations_similar_true():
|
|
||||||
"""Distance below atol → observations considered similar."""
|
|
||||||
# Create mock lerobot features for the similarity check
|
|
||||||
lerobot_features = {
|
|
||||||
"observation.state": {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": [4],
|
|
||||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
obs1 = _make_obs(torch.zeros(4))
|
|
||||||
obs2 = _make_obs(0.5 * torch.ones(4))
|
|
||||||
assert observations_similar(obs1, obs2, lerobot_features, atol=2.0)
|
|
||||||
|
|
||||||
obs3 = _make_obs(2.0 * torch.ones(4))
|
|
||||||
assert not observations_similar(obs1, obs3, lerobot_features, atol=2.0)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------
|
|
||||||
# raw_observation_to_observation and helpers
|
|
||||||
# ---------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _create_mock_robot_observation():
|
|
||||||
"""Create a mock robot observation with motor positions and camera images."""
|
|
||||||
return {
|
|
||||||
"shoulder": 1.0,
|
|
||||||
"elbow": 2.0,
|
|
||||||
"wrist": 3.0,
|
|
||||||
"gripper": 0.5,
|
|
||||||
"laptop": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
|
|
||||||
"phone": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _create_mock_lerobot_features():
|
|
||||||
"""Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
|
|
||||||
return {
|
|
||||||
"observation.state": {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": [4],
|
|
||||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
|
||||||
},
|
|
||||||
"observation.images.laptop": {
|
|
||||||
"dtype": "image",
|
|
||||||
"shape": [480, 640, 3],
|
|
||||||
"names": ["height", "width", "channels"],
|
|
||||||
},
|
|
||||||
"observation.images.phone": {
|
|
||||||
"dtype": "image",
|
|
||||||
"shape": [480, 640, 3],
|
|
||||||
"names": ["height", "width", "channels"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _create_mock_policy_image_features():
|
|
||||||
"""Create mock policy image features with different resolutions."""
|
|
||||||
return {
|
|
||||||
"observation.images.laptop": PolicyFeature(
|
|
||||||
type=FeatureType.VISUAL,
|
|
||||||
shape=(3, 224, 224), # Policy expects smaller resolution
|
|
||||||
),
|
|
||||||
"observation.images.phone": PolicyFeature(
|
|
||||||
type=FeatureType.VISUAL,
|
|
||||||
shape=(3, 160, 160), # Different resolution for second camera
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_image():
|
|
||||||
"""Test image preprocessing: int8 → float32, normalization to [0,1]."""
|
|
||||||
# Create mock int8 image data
|
|
||||||
image_int8 = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
|
|
||||||
|
|
||||||
processed = prepare_image(image_int8)
|
|
||||||
|
|
||||||
# Check dtype conversion
|
|
||||||
assert processed.dtype == torch.float32
|
|
||||||
|
|
||||||
# Check normalization range
|
|
||||||
assert processed.min() >= 0.0
|
|
||||||
assert processed.max() <= 1.0
|
|
||||||
|
|
||||||
# Check that values are scaled correctly (255 → 1.0, 0 → 0.0)
|
|
||||||
if image_int8.max() == 255:
|
|
||||||
assert torch.isclose(processed.max(), torch.tensor(1.0), atol=1e-6)
|
|
||||||
if image_int8.min() == 0:
|
|
||||||
assert torch.isclose(processed.min(), torch.tensor(0.0), atol=1e-6)
|
|
||||||
|
|
||||||
# Check memory contiguity
|
|
||||||
assert processed.is_contiguous()
|
|
||||||
|
|
||||||
|
|
||||||
def test_resize_robot_observation_image():
|
|
||||||
"""Test image resizing from robot resolution to policy resolution."""
|
|
||||||
# Create mock image: (H=480, W=640, C=3)
|
|
||||||
original_image = torch.randint(0, 256, size=(480, 640, 3), dtype=torch.uint8)
|
|
||||||
target_shape = (3, 224, 224) # (C, H, W)
|
|
||||||
|
|
||||||
resized = resize_robot_observation_image(original_image, target_shape)
|
|
||||||
|
|
||||||
# Check output shape matches target
|
|
||||||
assert resized.shape == target_shape
|
|
||||||
|
|
||||||
# Check that original image had different dimensions
|
|
||||||
assert original_image.shape != resized.shape
|
|
||||||
|
|
||||||
# Check that resizing preserves value range
|
|
||||||
assert resized.min() >= 0
|
|
||||||
assert resized.max() <= 255
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_raw_observation():
|
|
||||||
"""Test the preparation of raw robot observation to lerobot format."""
|
|
||||||
robot_obs = _create_mock_robot_observation()
|
|
||||||
lerobot_features = _create_mock_lerobot_features()
|
|
||||||
policy_image_features = _create_mock_policy_image_features()
|
|
||||||
|
|
||||||
prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features)
|
|
||||||
|
|
||||||
# Check that state is properly extracted and batched
|
|
||||||
assert "observation.state" in prepared
|
|
||||||
state = prepared["observation.state"]
|
|
||||||
assert isinstance(state, torch.Tensor)
|
|
||||||
assert state.shape == (1, 4) # Batched state
|
|
||||||
|
|
||||||
# Check that images are processed and resized
|
|
||||||
assert "observation.images.laptop" in prepared
|
|
||||||
assert "observation.images.phone" in prepared
|
|
||||||
|
|
||||||
laptop_img = prepared["observation.images.laptop"]
|
|
||||||
phone_img = prepared["observation.images.phone"]
|
|
||||||
|
|
||||||
# Check image shapes match policy requirements
|
|
||||||
assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape
|
|
||||||
assert phone_img.shape == policy_image_features["observation.images.phone"].shape
|
|
||||||
|
|
||||||
# Check that images are tensors
|
|
||||||
assert isinstance(laptop_img, torch.Tensor)
|
|
||||||
assert isinstance(phone_img, torch.Tensor)
|
|
||||||
|
|
||||||
|
|
||||||
def test_raw_observation_to_observation_basic():
|
|
||||||
"""Test the main raw_observation_to_observation function."""
|
|
||||||
robot_obs = _create_mock_robot_observation()
|
|
||||||
lerobot_features = _create_mock_lerobot_features()
|
|
||||||
policy_image_features = _create_mock_policy_image_features()
|
|
||||||
device = "cpu"
|
|
||||||
|
|
||||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
|
||||||
|
|
||||||
# Check that all expected keys are present
|
|
||||||
assert "observation.state" in observation
|
|
||||||
assert "observation.images.laptop" in observation
|
|
||||||
assert "observation.images.phone" in observation
|
|
||||||
|
|
||||||
# Check state processing
|
|
||||||
state = observation["observation.state"]
|
|
||||||
assert isinstance(state, torch.Tensor)
|
|
||||||
assert state.device.type == device
|
|
||||||
assert state.shape == (1, 4) # Batched
|
|
||||||
|
|
||||||
# Check image processing
|
|
||||||
laptop_img = observation["observation.images.laptop"]
|
|
||||||
phone_img = observation["observation.images.phone"]
|
|
||||||
|
|
||||||
# Images should have batch dimension: (B, C, H, W)
|
|
||||||
assert laptop_img.shape == (1, 3, 224, 224)
|
|
||||||
assert phone_img.shape == (1, 3, 160, 160)
|
|
||||||
|
|
||||||
# Check device placement
|
|
||||||
assert laptop_img.device.type == device
|
|
||||||
assert phone_img.device.type == device
|
|
||||||
|
|
||||||
# Check image dtype and range (should be float32 in [0, 1])
|
|
||||||
assert laptop_img.dtype == torch.float32
|
|
||||||
assert phone_img.dtype == torch.float32
|
|
||||||
assert laptop_img.min() >= 0.0 and laptop_img.max() <= 1.0
|
|
||||||
assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0
|
|
||||||
|
|
||||||
|
|
||||||
def test_raw_observation_to_observation_with_non_tensor_data():
|
|
||||||
"""Test that non-tensor data (like task strings) is preserved."""
|
|
||||||
robot_obs = _create_mock_robot_observation()
|
|
||||||
robot_obs["task"] = "pick up the red cube" # Add string instruction
|
|
||||||
|
|
||||||
lerobot_features = _create_mock_lerobot_features()
|
|
||||||
policy_image_features = _create_mock_policy_image_features()
|
|
||||||
device = "cpu"
|
|
||||||
|
|
||||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
|
||||||
|
|
||||||
# Check that task string is preserved
|
|
||||||
assert "task" in observation
|
|
||||||
assert observation["task"] == "pick up the red cube"
|
|
||||||
assert isinstance(observation["task"], str)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def test_raw_observation_to_observation_device_handling():
|
|
||||||
"""Test that tensors are properly moved to the specified device."""
|
|
||||||
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
|
||||||
|
|
||||||
robot_obs = _create_mock_robot_observation()
|
|
||||||
lerobot_features = _create_mock_lerobot_features()
|
|
||||||
policy_image_features = _create_mock_policy_image_features()
|
|
||||||
|
|
||||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
|
||||||
|
|
||||||
# Check that all tensors are on the correct device
|
|
||||||
for key, value in observation.items():
|
|
||||||
if isinstance(value, torch.Tensor):
|
|
||||||
assert value.device.type == device, f"Tensor {key} not on {device}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_raw_observation_to_observation_deterministic():
|
|
||||||
"""Test that the function produces consistent results for the same input."""
|
|
||||||
robot_obs = _create_mock_robot_observation()
|
|
||||||
lerobot_features = _create_mock_lerobot_features()
|
|
||||||
policy_image_features = _create_mock_policy_image_features()
|
|
||||||
device = "cpu"
|
|
||||||
|
|
||||||
# Run twice with same input
|
|
||||||
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
|
||||||
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
|
||||||
|
|
||||||
# Results should be identical
|
|
||||||
assert set(obs1.keys()) == set(obs2.keys())
|
|
||||||
|
|
||||||
for key in obs1:
|
|
||||||
if isinstance(obs1[key], torch.Tensor):
|
|
||||||
torch.testing.assert_close(obs1[key], obs2[key])
|
|
||||||
else:
|
|
||||||
assert obs1[key] == obs2[key]
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_processing_pipeline_preserves_content():
|
|
||||||
"""Test that the image processing pipeline preserves recognizable patterns."""
|
|
||||||
# Create an image with a specific pattern
|
|
||||||
original_img = np.zeros((100, 100, 3), dtype=np.uint8)
|
|
||||||
original_img[25:75, 25:75, :] = 255 # White square in center
|
|
||||||
|
|
||||||
robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img}
|
|
||||||
lerobot_features = {
|
|
||||||
"observation.state": {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": [4],
|
|
||||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
|
||||||
},
|
|
||||||
"observation.images.laptop": {
|
|
||||||
"dtype": "image",
|
|
||||||
"shape": [100, 100, 3],
|
|
||||||
"names": ["height", "width", "channels"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
policy_image_features = {
|
|
||||||
"observation.images.laptop": PolicyFeature(
|
|
||||||
type=FeatureType.VISUAL,
|
|
||||||
shape=(3, 50, 50), # Downsamples from 100x100
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
|
|
||||||
|
|
||||||
processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim
|
|
||||||
|
|
||||||
# Check that the center region has higher values than corners
|
|
||||||
# Due to bilinear interpolation, exact values will change but pattern should remain
|
|
||||||
center_val = processed_img[:, 25, 25].mean() # Center of 50x50 image
|
|
||||||
corner_val = processed_img[:, 5, 5].mean() # Corner
|
|
||||||
|
|
||||||
assert center_val > corner_val, "Image processing should preserve recognizable patterns"
|
|
||||||
@ -1,215 +0,0 @@
|
|||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Unit-tests for the `PolicyServer` core logic.
|
|
||||||
Monkey-patch the `policy` attribute with a stub so that no real model inference is performed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.configs.types import PolicyFeature
|
|
||||||
from tests.utils import require_package
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Test fixtures
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class MockPolicy:
|
|
||||||
"""A minimal mock for an actual policy, returning zeros.
|
|
||||||
Refer to tests/policies for tests of the individual policies supported."""
|
|
||||||
|
|
||||||
class _Config:
|
|
||||||
robot_type = "dummy_robot"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def image_features(self) -> dict[str, PolicyFeature]:
|
|
||||||
"""Empty image features since this test doesn't use images."""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
||||||
"""Return a chunk of 20 dummy actions."""
|
|
||||||
batch_size = len(observation["observation.state"])
|
|
||||||
return torch.zeros(batch_size, 20, 6)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.config = self._Config()
|
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
|
||||||
# The server calls `policy.to(device)`. This stub ignores it.
|
|
||||||
return self
|
|
||||||
|
|
||||||
def model(self, batch: dict) -> torch.Tensor:
|
|
||||||
# Return a chunk of 20 dummy actions.
|
|
||||||
batch_size = len(batch["robot_type"])
|
|
||||||
return torch.zeros(batch_size, 20, 6)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
@require_package("grpc")
|
|
||||||
def policy_server():
|
|
||||||
"""Fresh `PolicyServer` instance with a stubbed-out policy model."""
|
|
||||||
# Import only when the test actually runs (after decorator check)
|
|
||||||
from lerobot.scripts.server.configs import PolicyServerConfig
|
|
||||||
from lerobot.scripts.server.policy_server import PolicyServer
|
|
||||||
|
|
||||||
test_config = PolicyServerConfig(host="localhost", port=9999)
|
|
||||||
server = PolicyServer(test_config)
|
|
||||||
# Replace the real policy with our fast, deterministic stub.
|
|
||||||
server.policy = MockPolicy()
|
|
||||||
server.actions_per_chunk = 20
|
|
||||||
server.device = "cpu"
|
|
||||||
|
|
||||||
# Add mock lerobot_features that the observation similarity functions need
|
|
||||||
server.lerobot_features = {
|
|
||||||
"observation.state": {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": [6],
|
|
||||||
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return server
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Helper utilities for tests
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False):
|
|
||||||
"""Create a TimedObservation with a given state vector."""
|
|
||||||
# Import only when needed
|
|
||||||
from lerobot.scripts.server.helpers import TimedObservation
|
|
||||||
|
|
||||||
return TimedObservation(
|
|
||||||
observation={
|
|
||||||
"joint1": state[0].item() if len(state) > 0 else 0.0,
|
|
||||||
"joint2": state[1].item() if len(state) > 1 else 0.0,
|
|
||||||
"joint3": state[2].item() if len(state) > 2 else 0.0,
|
|
||||||
"joint4": state[3].item() if len(state) > 3 else 0.0,
|
|
||||||
"joint5": state[4].item() if len(state) > 4 else 0.0,
|
|
||||||
"joint6": state[5].item() if len(state) > 5 else 0.0,
|
|
||||||
},
|
|
||||||
timestamp=time.time(),
|
|
||||||
timestep=timestep,
|
|
||||||
must_go=must_go,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Tests
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_time_action_chunk(policy_server):
|
|
||||||
"""Verify that `_time_action_chunk` assigns correct timestamps and timesteps."""
|
|
||||||
start_ts = time.time()
|
|
||||||
start_t = 10
|
|
||||||
# A chunk of 3 action tensors.
|
|
||||||
action_tensors = [torch.randn(6) for _ in range(3)]
|
|
||||||
|
|
||||||
timed_actions = policy_server._time_action_chunk(start_ts, action_tensors, start_t)
|
|
||||||
|
|
||||||
assert len(timed_actions) == 3
|
|
||||||
# Check timesteps
|
|
||||||
assert [ta.get_timestep() for ta in timed_actions] == [10, 11, 12]
|
|
||||||
# Check timestamps
|
|
||||||
expected_timestamps = [
|
|
||||||
start_ts,
|
|
||||||
start_ts + policy_server.config.environment_dt,
|
|
||||||
start_ts + 2 * policy_server.config.environment_dt,
|
|
||||||
]
|
|
||||||
for ta, expected_ts in zip(timed_actions, expected_timestamps, strict=True):
|
|
||||||
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
|
|
||||||
|
|
||||||
|
|
||||||
def test_maybe_enqueue_observation_must_go(policy_server):
|
|
||||||
"""An observation with `must_go=True` is always enqueued."""
|
|
||||||
obs = _make_obs(torch.zeros(6), must_go=True)
|
|
||||||
assert policy_server._enqueue_observation(obs) is True
|
|
||||||
assert policy_server.observation_queue.qsize() == 1
|
|
||||||
assert policy_server.observation_queue.get_nowait() is obs
|
|
||||||
|
|
||||||
|
|
||||||
def test_maybe_enqueue_observation_dissimilar(policy_server):
|
|
||||||
"""A dissimilar observation (not `must_go`) is enqueued."""
|
|
||||||
# Set a last predicted observation.
|
|
||||||
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
|
|
||||||
# Create a new, dissimilar observation.
|
|
||||||
new_obs = _make_obs(torch.ones(6) * 5) # High norm difference
|
|
||||||
|
|
||||||
assert policy_server._enqueue_observation(new_obs) is True
|
|
||||||
assert policy_server.observation_queue.qsize() == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_maybe_enqueue_observation_is_skipped(policy_server):
|
|
||||||
"""A similar observation (not `must_go`) is skipped."""
|
|
||||||
# Set a last predicted observation.
|
|
||||||
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
|
|
||||||
# Create a new, very similar observation.
|
|
||||||
new_obs = _make_obs(torch.zeros(6) + 1e-4)
|
|
||||||
|
|
||||||
assert policy_server._enqueue_observation(new_obs) is False
|
|
||||||
assert policy_server.observation_queue.empty() is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_obs_sanity_checks(policy_server):
|
|
||||||
"""Unit-test the private `_obs_sanity_checks` helper."""
|
|
||||||
prev = _make_obs(torch.zeros(6), timestep=0)
|
|
||||||
|
|
||||||
# Case 1 – timestep already predicted
|
|
||||||
policy_server._predicted_timesteps.add(1)
|
|
||||||
obs_same_ts = _make_obs(torch.ones(6), timestep=1)
|
|
||||||
assert policy_server._obs_sanity_checks(obs_same_ts, prev) is False
|
|
||||||
|
|
||||||
# Case 2 – observation too similar
|
|
||||||
policy_server._predicted_timesteps.clear()
|
|
||||||
obs_similar = _make_obs(torch.zeros(6) + 1e-4, timestep=2)
|
|
||||||
assert policy_server._obs_sanity_checks(obs_similar, prev) is False
|
|
||||||
|
|
||||||
# Case 3 – genuinely new & dissimilar observation passes
|
|
||||||
obs_ok = _make_obs(torch.ones(6) * 5, timestep=3)
|
|
||||||
assert policy_server._obs_sanity_checks(obs_ok, prev) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_predict_action_chunk(monkeypatch, policy_server):
|
|
||||||
"""End-to-end test of `_predict_action_chunk` with a stubbed _get_action_chunk."""
|
|
||||||
# Import only when needed
|
|
||||||
from lerobot.scripts.server.policy_server import PolicyServer
|
|
||||||
|
|
||||||
# Force server to act-style policy; patch method to return deterministic tensor
|
|
||||||
policy_server.policy_type = "act"
|
|
||||||
action_dim = 6
|
|
||||||
batch_size = 1
|
|
||||||
actions_per_chunk = policy_server.actions_per_chunk
|
|
||||||
|
|
||||||
def _fake_get_action_chunk(_self, _obs, _type="act"):
|
|
||||||
return torch.zeros(batch_size, actions_per_chunk, action_dim)
|
|
||||||
|
|
||||||
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
|
|
||||||
|
|
||||||
obs = _make_obs(torch.zeros(6), timestep=5)
|
|
||||||
timed_actions = policy_server._predict_action_chunk(obs)
|
|
||||||
|
|
||||||
assert len(timed_actions) == actions_per_chunk
|
|
||||||
assert [ta.get_timestep() for ta in timed_actions] == list(range(5, 5 + actions_per_chunk))
|
|
||||||
|
|
||||||
for i, ta in enumerate(timed_actions):
|
|
||||||
expected_ts = obs.get_timestamp() + i * policy_server.config.environment_dt
|
|
||||||
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
|
|
||||||
@ -1,234 +0,0 @@
|
|||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC).
|
|
||||||
|
|
||||||
We monkey-patch `lerobot.robots.utils.make_robot_from_config` so that
|
|
||||||
no real hardware is accessed. Only the queue-update mechanism is verified.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import time
|
|
||||||
from queue import Queue
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# Skip entire module if grpc is not available
|
|
||||||
pytest.importorskip("grpc")
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Test fixtures
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def robot_client():
|
|
||||||
"""Fresh `RobotClient` instance for each test case (no threads started).
|
|
||||||
Uses DummyRobot."""
|
|
||||||
# Import only when the test actually runs (after decorator check)
|
|
||||||
from lerobot.scripts.server.configs import RobotClientConfig
|
|
||||||
from lerobot.scripts.server.robot_client import RobotClient
|
|
||||||
from tests.mocks.mock_robot import MockRobotConfig
|
|
||||||
|
|
||||||
test_config = MockRobotConfig()
|
|
||||||
|
|
||||||
# gRPC channel is not actually used in tests, so using a dummy address
|
|
||||||
test_config = RobotClientConfig(
|
|
||||||
robot=test_config,
|
|
||||||
server_address="localhost:9999",
|
|
||||||
policy_type="test",
|
|
||||||
pretrained_name_or_path="test",
|
|
||||||
actions_per_chunk=20,
|
|
||||||
verify_robot_cameras=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
client = RobotClient(test_config)
|
|
||||||
|
|
||||||
# Initialize attributes that are normally set in start() method
|
|
||||||
client.chunks_received = 0
|
|
||||||
client.available_actions_size = []
|
|
||||||
|
|
||||||
yield client
|
|
||||||
|
|
||||||
if client.robot.is_connected:
|
|
||||||
client.stop()
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Helper utilities for tests
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _make_actions(start_ts: float, start_t: int, count: int):
|
|
||||||
"""Generate `count` consecutive TimedAction objects starting at timestep `start_t`."""
|
|
||||||
from lerobot.scripts.server.helpers import TimedAction
|
|
||||||
|
|
||||||
fps = 30 # emulates most common frame-rate
|
|
||||||
actions = []
|
|
||||||
for i in range(count):
|
|
||||||
timestep = start_t + i
|
|
||||||
timestamp = start_ts + i * (1 / fps)
|
|
||||||
action_tensor = torch.full((6,), timestep, dtype=torch.float32)
|
|
||||||
actions.append(TimedAction(action=action_tensor, timestep=timestep, timestamp=timestamp))
|
|
||||||
return actions
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Tests
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_action_queue_discards_stale(robot_client):
|
|
||||||
"""`_update_action_queue` must drop actions with `timestep` <= `latest_action`."""
|
|
||||||
|
|
||||||
# Pretend we already executed up to action #4
|
|
||||||
robot_client.latest_action = 4
|
|
||||||
|
|
||||||
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
|
|
||||||
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
|
|
||||||
|
|
||||||
robot_client._aggregate_action_queues(incoming)
|
|
||||||
|
|
||||||
# Extract timesteps from queue
|
|
||||||
resulting_timesteps = [a.get_timestep() for a in robot_client.action_queue.queue]
|
|
||||||
|
|
||||||
assert resulting_timesteps == [5, 6, 7]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"weight_old, weight_new",
|
|
||||||
[
|
|
||||||
(1.0, 0.0),
|
|
||||||
(0.0, 1.0),
|
|
||||||
(0.5, 0.5),
|
|
||||||
(0.2, 0.8),
|
|
||||||
(0.8, 0.2),
|
|
||||||
(0.1, 0.9),
|
|
||||||
(0.9, 0.1),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_aggregate_action_queues_combines_actions_in_overlap(
|
|
||||||
robot_client, weight_old: float, weight_new: float
|
|
||||||
):
|
|
||||||
"""`_aggregate_action_queues` must combine actions on overlapping timesteps according
|
|
||||||
to the provided aggregate_fn, here tested with multiple coefficients."""
|
|
||||||
from lerobot.scripts.server.helpers import TimedAction
|
|
||||||
|
|
||||||
robot_client.chunks_received = 0
|
|
||||||
|
|
||||||
# Pretend we already executed up to action #4, and queue contains actions for timesteps 5..6
|
|
||||||
robot_client.latest_action = 4
|
|
||||||
current_actions = _make_actions(
|
|
||||||
start_ts=time.time(), start_t=5, count=2
|
|
||||||
) # actions are [torch.ones(6), torch.ones(6), ...]
|
|
||||||
current_actions = [
|
|
||||||
TimedAction(action=10 * a.get_action(), timestep=a.get_timestep(), timestamp=a.get_timestamp())
|
|
||||||
for a in current_actions
|
|
||||||
]
|
|
||||||
|
|
||||||
for a in current_actions:
|
|
||||||
robot_client.action_queue.put(a)
|
|
||||||
|
|
||||||
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
|
|
||||||
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
|
|
||||||
|
|
||||||
overlap_timesteps = [5, 6] # properly tested in test_aggregate_action_queues_discards_stale
|
|
||||||
nonoverlap_timesteps = [7]
|
|
||||||
|
|
||||||
robot_client._aggregate_action_queues(
|
|
||||||
incoming, aggregate_fn=lambda x1, x2: weight_old * x1 + weight_new * x2
|
|
||||||
)
|
|
||||||
|
|
||||||
queue_overlap_actions = []
|
|
||||||
queue_non_overlap_actions = []
|
|
||||||
for a in robot_client.action_queue.queue:
|
|
||||||
if a.get_timestep() in overlap_timesteps:
|
|
||||||
queue_overlap_actions.append(a)
|
|
||||||
elif a.get_timestep() in nonoverlap_timesteps:
|
|
||||||
queue_non_overlap_actions.append(a)
|
|
||||||
|
|
||||||
queue_overlap_actions = sorted(queue_overlap_actions, key=lambda x: x.get_timestep())
|
|
||||||
queue_non_overlap_actions = sorted(queue_non_overlap_actions, key=lambda x: x.get_timestep())
|
|
||||||
|
|
||||||
assert torch.allclose(
|
|
||||||
queue_overlap_actions[0].get_action(),
|
|
||||||
weight_old * current_actions[0].get_action() + weight_new * incoming[-3].get_action(),
|
|
||||||
)
|
|
||||||
assert torch.allclose(
|
|
||||||
queue_overlap_actions[1].get_action(),
|
|
||||||
weight_old * current_actions[1].get_action() + weight_new * incoming[-2].get_action(),
|
|
||||||
)
|
|
||||||
assert torch.allclose(queue_non_overlap_actions[0].get_action(), incoming[-1].get_action())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"chunk_size, queue_len, expected",
|
|
||||||
[
|
|
||||||
(20, 12, False), # 12 / 20 = 0.6 > g=0.5 threshold, not ready to send
|
|
||||||
(20, 8, True), # 8 / 20 = 0.4 <= g=0.5, ready to send
|
|
||||||
(10, 5, True),
|
|
||||||
(10, 6, False),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_ready_to_send_observation(robot_client, chunk_size: int, queue_len: int, expected: bool):
|
|
||||||
"""Validate `_ready_to_send_observation` ratio logic for various sizes."""
|
|
||||||
|
|
||||||
robot_client.action_chunk_size = chunk_size
|
|
||||||
|
|
||||||
# Clear any existing actions then fill with `queue_len` dummy entries ----
|
|
||||||
robot_client.action_queue = Queue()
|
|
||||||
|
|
||||||
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
|
|
||||||
for act in dummy_actions:
|
|
||||||
robot_client.action_queue.put(act)
|
|
||||||
|
|
||||||
assert robot_client._ready_to_send_observation() is expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"g_threshold, expected",
|
|
||||||
[
|
|
||||||
# The condition is `queue_size / chunk_size <= g`.
|
|
||||||
# Here, ratio = 6 / 10 = 0.6.
|
|
||||||
(0.0, False), # 0.6 <= 0.0 is False
|
|
||||||
(0.1, False),
|
|
||||||
(0.2, False),
|
|
||||||
(0.3, False),
|
|
||||||
(0.4, False),
|
|
||||||
(0.5, False),
|
|
||||||
(0.6, True), # 0.6 <= 0.6 is True
|
|
||||||
(0.7, True),
|
|
||||||
(0.8, True),
|
|
||||||
(0.9, True),
|
|
||||||
(1.0, True),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_ready_to_send_observation_with_varying_threshold(robot_client, g_threshold: float, expected: bool):
|
|
||||||
"""Validate `_ready_to_send_observation` with fixed sizes and varying `g`."""
|
|
||||||
# Fixed sizes for this test: ratio = 6 / 10 = 0.6
|
|
||||||
chunk_size = 10
|
|
||||||
queue_len = 6
|
|
||||||
|
|
||||||
robot_client.action_chunk_size = chunk_size
|
|
||||||
# This is the parameter we are testing
|
|
||||||
robot_client._chunk_size_threshold = g_threshold
|
|
||||||
|
|
||||||
# Fill queue with dummy actions
|
|
||||||
robot_client.action_queue = Queue()
|
|
||||||
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
|
|
||||||
for act in dummy_actions:
|
|
||||||
robot_client.action_queue.put(act)
|
|
||||||
|
|
||||||
assert robot_client._ready_to_send_observation() is expected
|
|
||||||
@ -1,188 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# Example of running a specific test:
|
|
||||||
# ```bash
|
|
||||||
# pytest tests/cameras/test_opencv.py::test_connect
|
|
||||||
# ```
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lerobot.cameras.configs import Cv2Rotation
|
|
||||||
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
|
|
||||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
|
||||||
|
|
||||||
# NOTE(Steven): more tests + assertions?
|
|
||||||
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras"
|
|
||||||
DEFAULT_PNG_FILE_PATH = TEST_ARTIFACTS_DIR / "image_160x120.png"
|
|
||||||
TEST_IMAGE_SIZES = ["128x128", "160x120", "320x180", "480x270"]
|
|
||||||
TEST_IMAGE_PATHS = [TEST_ARTIFACTS_DIR / f"image_{size}.png" for size in TEST_IMAGE_SIZES]
|
|
||||||
|
|
||||||
|
|
||||||
def test_abc_implementation():
|
|
||||||
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
|
||||||
config = OpenCVCameraConfig(index_or_path=0)
|
|
||||||
|
|
||||||
_ = OpenCVCamera(config)
|
|
||||||
|
|
||||||
|
|
||||||
def test_connect():
|
|
||||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
|
||||||
camera = OpenCVCamera(config)
|
|
||||||
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
assert camera.is_connected
|
|
||||||
|
|
||||||
|
|
||||||
def test_connect_already_connected():
|
|
||||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
|
||||||
camera = OpenCVCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
with pytest.raises(DeviceAlreadyConnectedError):
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_connect_invalid_camera_path():
|
|
||||||
config = OpenCVCameraConfig(index_or_path="nonexistent/camera.png")
|
|
||||||
camera = OpenCVCamera(config)
|
|
||||||
|
|
||||||
with pytest.raises(ConnectionError):
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_width_connect():
|
|
||||||
config = OpenCVCameraConfig(
|
|
||||||
index_or_path=DEFAULT_PNG_FILE_PATH,
|
|
||||||
width=99999, # Invalid width to trigger error
|
|
||||||
height=480,
|
|
||||||
)
|
|
||||||
camera = OpenCVCamera(config)
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
|
|
||||||
def test_read(index_or_path):
|
|
||||||
config = OpenCVCameraConfig(index_or_path=index_or_path)
|
|
||||||
camera = OpenCVCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
img = camera.read()
|
|
||||||
|
|
||||||
assert isinstance(img, np.ndarray)
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_before_connect():
|
|
||||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
|
||||||
camera = OpenCVCamera(config)
|
|
||||||
|
|
||||||
with pytest.raises(DeviceNotConnectedError):
|
|
||||||
_ = camera.read()
|
|
||||||
|
|
||||||
|
|
||||||
def test_disconnect():
|
|
||||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
|
||||||
camera = OpenCVCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
camera.disconnect()
|
|
||||||
|
|
||||||
assert not camera.is_connected
|
|
||||||
|
|
||||||
|
|
||||||
def test_disconnect_before_connect():
|
|
||||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
|
||||||
camera = OpenCVCamera(config)
|
|
||||||
|
|
||||||
with pytest.raises(DeviceNotConnectedError):
|
|
||||||
_ = camera.disconnect()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
|
|
||||||
def test_async_read(index_or_path):
|
|
||||||
config = OpenCVCameraConfig(index_or_path=index_or_path)
|
|
||||||
camera = OpenCVCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
try:
|
|
||||||
img = camera.async_read()
|
|
||||||
|
|
||||||
assert camera.thread is not None
|
|
||||||
assert camera.thread.is_alive()
|
|
||||||
assert isinstance(img, np.ndarray)
|
|
||||||
finally:
|
|
||||||
if camera.is_connected:
|
|
||||||
camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends
|
|
||||||
|
|
||||||
|
|
||||||
def test_async_read_timeout():
|
|
||||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
|
||||||
camera = OpenCVCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with pytest.raises(TimeoutError):
|
|
||||||
camera.async_read(timeout_ms=0)
|
|
||||||
finally:
|
|
||||||
if camera.is_connected:
|
|
||||||
camera.disconnect()
|
|
||||||
|
|
||||||
|
|
||||||
def test_async_read_before_connect():
|
|
||||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
|
||||||
camera = OpenCVCamera(config)
|
|
||||||
|
|
||||||
with pytest.raises(DeviceNotConnectedError):
|
|
||||||
_ = camera.async_read()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"rotation",
|
|
||||||
[
|
|
||||||
Cv2Rotation.NO_ROTATION,
|
|
||||||
Cv2Rotation.ROTATE_90,
|
|
||||||
Cv2Rotation.ROTATE_180,
|
|
||||||
Cv2Rotation.ROTATE_270,
|
|
||||||
],
|
|
||||||
ids=["no_rot", "rot90", "rot180", "rot270"],
|
|
||||||
)
|
|
||||||
def test_rotation(rotation, index_or_path):
|
|
||||||
filename = Path(index_or_path).name
|
|
||||||
dimensions = filename.split("_")[-1].split(".")[0] # Assumes filenames format (_wxh.png)
|
|
||||||
original_width, original_height = map(int, dimensions.split("x"))
|
|
||||||
|
|
||||||
config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation)
|
|
||||||
camera = OpenCVCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
img = camera.read()
|
|
||||||
assert isinstance(img, np.ndarray)
|
|
||||||
|
|
||||||
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
|
|
||||||
assert camera.width == original_height
|
|
||||||
assert camera.height == original_width
|
|
||||||
assert img.shape[:2] == (original_width, original_height)
|
|
||||||
else:
|
|
||||||
assert camera.width == original_width
|
|
||||||
assert camera.height == original_height
|
|
||||||
assert img.shape[:2] == (original_height, original_width)
|
|
||||||
@ -1,206 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# Example of running a specific test:
|
|
||||||
# ```bash
|
|
||||||
# pytest tests/cameras/test_opencv.py::test_connect
|
|
||||||
# ```
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lerobot.cameras.configs import Cv2Rotation
|
|
||||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
|
||||||
|
|
||||||
pytest.importorskip("pyrealsense2")
|
|
||||||
|
|
||||||
from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
|
|
||||||
|
|
||||||
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras"
|
|
||||||
BAG_FILE_PATH = TEST_ARTIFACTS_DIR / "test_rs.bag"
|
|
||||||
|
|
||||||
# NOTE(Steven): For some reason these tests take ~20sec in macOS but only ~2sec in Linux.
|
|
||||||
|
|
||||||
|
|
||||||
def mock_rs_config_enable_device_from_file(rs_config_instance, _sn):
|
|
||||||
return rs_config_instance.enable_device_from_file(str(BAG_FILE_PATH), repeat_playback=True)
|
|
||||||
|
|
||||||
|
|
||||||
def mock_rs_config_enable_device_bad_file(rs_config_instance, _sn):
|
|
||||||
return rs_config_instance.enable_device_from_file("non_existent_file.bag", repeat_playback=True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="patch_realsense", autouse=True)
|
|
||||||
def fixture_patch_realsense():
|
|
||||||
"""Automatically mock pyrealsense2.config.enable_device for all tests."""
|
|
||||||
with patch(
|
|
||||||
"pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file
|
|
||||||
) as mock:
|
|
||||||
yield mock
|
|
||||||
|
|
||||||
|
|
||||||
def test_abc_implementation():
|
|
||||||
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
|
||||||
_ = RealSenseCamera(config)
|
|
||||||
|
|
||||||
|
|
||||||
def test_connect():
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
assert camera.is_connected
|
|
||||||
|
|
||||||
|
|
||||||
def test_connect_already_connected():
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
with pytest.raises(DeviceAlreadyConnectedError):
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_connect_invalid_camera_path(patch_realsense):
|
|
||||||
patch_realsense.side_effect = mock_rs_config_enable_device_bad_file
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
|
|
||||||
with pytest.raises(ConnectionError):
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_width_connect():
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=99999, height=480, fps=30)
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
|
|
||||||
with pytest.raises(ConnectionError):
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_read():
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
img = camera.read()
|
|
||||||
assert isinstance(img, np.ndarray)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(Steven): Fix this test for the latest version of pyrealsense2.
|
|
||||||
@pytest.mark.skip("Skipping test: pyrealsense2 version > 2.55.1.6486")
|
|
||||||
def test_read_depth():
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, use_depth=True)
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
img = camera.read_depth(timeout_ms=2000) # NOTE(Steven): Reading depth takes longer in CI environments.
|
|
||||||
assert isinstance(img, np.ndarray)
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_before_connect():
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
|
|
||||||
with pytest.raises(DeviceNotConnectedError):
|
|
||||||
_ = camera.read()
|
|
||||||
|
|
||||||
|
|
||||||
def test_disconnect():
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
camera.disconnect()
|
|
||||||
|
|
||||||
assert not camera.is_connected
|
|
||||||
|
|
||||||
|
|
||||||
def test_disconnect_before_connect():
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
|
|
||||||
with pytest.raises(DeviceNotConnectedError):
|
|
||||||
camera.disconnect()
|
|
||||||
|
|
||||||
|
|
||||||
def test_async_read():
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
try:
|
|
||||||
img = camera.async_read()
|
|
||||||
|
|
||||||
assert camera.thread is not None
|
|
||||||
assert camera.thread.is_alive()
|
|
||||||
assert isinstance(img, np.ndarray)
|
|
||||||
finally:
|
|
||||||
if camera.is_connected:
|
|
||||||
camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends
|
|
||||||
|
|
||||||
|
|
||||||
def test_async_read_timeout():
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with pytest.raises(TimeoutError):
|
|
||||||
camera.async_read(timeout_ms=0)
|
|
||||||
finally:
|
|
||||||
if camera.is_connected:
|
|
||||||
camera.disconnect()
|
|
||||||
|
|
||||||
|
|
||||||
def test_async_read_before_connect():
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
|
|
||||||
with pytest.raises(DeviceNotConnectedError):
|
|
||||||
_ = camera.async_read()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"rotation",
|
|
||||||
[
|
|
||||||
Cv2Rotation.NO_ROTATION,
|
|
||||||
Cv2Rotation.ROTATE_90,
|
|
||||||
Cv2Rotation.ROTATE_180,
|
|
||||||
Cv2Rotation.ROTATE_270,
|
|
||||||
],
|
|
||||||
ids=["no_rot", "rot90", "rot180", "rot270"],
|
|
||||||
)
|
|
||||||
def test_rotation(rotation):
|
|
||||||
config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation)
|
|
||||||
camera = RealSenseCamera(config)
|
|
||||||
camera.connect(warmup=False)
|
|
||||||
|
|
||||||
img = camera.read()
|
|
||||||
assert isinstance(img, np.ndarray)
|
|
||||||
|
|
||||||
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
|
|
||||||
assert camera.width == 480
|
|
||||||
assert camera.height == 640
|
|
||||||
assert img.shape[:2] == (640, 480)
|
|
||||||
else:
|
|
||||||
assert camera.width == 640
|
|
||||||
assert camera.height == 480
|
|
||||||
assert img.shape[:2] == (480, 640)
|
|
||||||
@ -1,105 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from collections.abc import Generator
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap
|
|
||||||
from lerobot.envs.configs import EnvConfig
|
|
||||||
|
|
||||||
|
|
||||||
def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str:
|
|
||||||
"""Creates a dummy plugin module that implements its own EnvConfig subclass."""
|
|
||||||
return f"""
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from lerobot.envs.configs import {base_class}
|
|
||||||
|
|
||||||
@{base_class}.register_subclass("{plugin_name}")
|
|
||||||
@dataclass
|
|
||||||
class TestPluginConfig:
|
|
||||||
value: int = 42
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def plugin_dir(tmp_path: Path) -> Generator[Path, None, None]:
|
|
||||||
"""Creates a temporary plugin package structure."""
|
|
||||||
plugin_pkg = tmp_path / "test_plugin"
|
|
||||||
plugin_pkg.mkdir()
|
|
||||||
(plugin_pkg / "__init__.py").touch()
|
|
||||||
|
|
||||||
with open(plugin_pkg / "my_plugin.py", "w") as f:
|
|
||||||
f.write(create_plugin_code())
|
|
||||||
|
|
||||||
# Add tmp_path to Python path so we can import from it
|
|
||||||
sys.path.insert(0, str(tmp_path))
|
|
||||||
yield plugin_pkg
|
|
||||||
sys.path.pop(0)
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_plugin_args():
|
|
||||||
cli_args = [
|
|
||||||
"--env.type=test",
|
|
||||||
"--model.discover_packages_path=some.package",
|
|
||||||
"--env.discover_packages_path=other.package",
|
|
||||||
]
|
|
||||||
plugin_args = parse_plugin_args("discover_packages_path", cli_args)
|
|
||||||
assert plugin_args == {
|
|
||||||
"model.discover_packages_path": "some.package",
|
|
||||||
"env.discover_packages_path": "other.package",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_plugin_success(plugin_dir: Path):
|
|
||||||
# Import should work and register the plugin with the real EnvConfig
|
|
||||||
load_plugin("test_plugin")
|
|
||||||
|
|
||||||
assert "test_env" in EnvConfig.get_known_choices()
|
|
||||||
plugin_cls = EnvConfig.get_choice_class("test_env")
|
|
||||||
plugin_instance = plugin_cls()
|
|
||||||
assert plugin_instance.value == 42
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_plugin_failure():
|
|
||||||
with pytest.raises(PluginLoadError) as exc_info:
|
|
||||||
load_plugin("nonexistent_plugin")
|
|
||||||
assert "Failed to load plugin 'nonexistent_plugin'" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_wrap_with_plugin(plugin_dir: Path):
|
|
||||||
@dataclass
|
|
||||||
class Config:
|
|
||||||
env: EnvConfig
|
|
||||||
|
|
||||||
@wrap()
|
|
||||||
def dummy_func(cfg: Config):
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
# Test loading plugin via CLI args
|
|
||||||
sys.argv = [
|
|
||||||
"dummy_script.py",
|
|
||||||
"--env.discover_packages_path=test_plugin",
|
|
||||||
"--env.type=test_env",
|
|
||||||
]
|
|
||||||
|
|
||||||
cfg = dummy_func()
|
|
||||||
assert isinstance(cfg, Config)
|
|
||||||
assert isinstance(cfg.env, EnvConfig.get_choice_class("test_env"))
|
|
||||||
assert cfg.env.value == 42
|
|
||||||
@ -1,88 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from serial import SerialException
|
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
|
||||||
from tests.utils import DEVICE
|
|
||||||
|
|
||||||
# Import fixture modules as plugins
|
|
||||||
pytest_plugins = [
|
|
||||||
"tests.fixtures.dataset_factories",
|
|
||||||
"tests.fixtures.files",
|
|
||||||
"tests.fixtures.hub",
|
|
||||||
"tests.fixtures.optimizers",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_collection_finish():
|
|
||||||
print(f"\nTesting with {DEVICE=}")
|
|
||||||
|
|
||||||
|
|
||||||
def _check_component_availability(component_type, available_components, make_component):
|
|
||||||
"""Generic helper to check if a hardware component is available"""
|
|
||||||
if component_type not in available_components:
|
|
||||||
raise ValueError(
|
|
||||||
f"The {component_type} type is not valid. Expected one of these '{available_components}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
component = make_component(component_type)
|
|
||||||
component.connect()
|
|
||||||
del component
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\nA {component_type} is not available.")
|
|
||||||
|
|
||||||
if isinstance(e, ModuleNotFoundError):
|
|
||||||
print(f"\nInstall module '{e.name}'")
|
|
||||||
elif isinstance(e, SerialException):
|
|
||||||
print("\nNo physical device detected.")
|
|
||||||
elif isinstance(e, ValueError) and "camera_index" in str(e):
|
|
||||||
print("\nNo physical camera detected.")
|
|
||||||
else:
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def patch_builtins_input(monkeypatch):
|
|
||||||
def print_text(text=None):
|
|
||||||
if text is not None:
|
|
||||||
print(text)
|
|
||||||
|
|
||||||
monkeypatch.setattr("builtins.input", print_text)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def policy_feature_factory():
|
|
||||||
"""PolicyFeature factory"""
|
|
||||||
|
|
||||||
def _pf(ft: FeatureType, shape: tuple[int, ...]) -> PolicyFeature:
|
|
||||||
return PolicyFeature(type=ft, shape=shape)
|
|
||||||
|
|
||||||
return _pf
|
|
||||||
|
|
||||||
|
|
||||||
def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None:
|
|
||||||
assert isinstance(features, dict)
|
|
||||||
assert all(isinstance(k, str) for k in features.keys())
|
|
||||||
assert all(isinstance(v, PolicyFeature) for v in features.values())
|
|
||||||
@ -1,309 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lerobot.datasets.compute_stats import (
|
|
||||||
_assert_type_and_shape,
|
|
||||||
aggregate_feature_stats,
|
|
||||||
aggregate_stats,
|
|
||||||
compute_episode_stats,
|
|
||||||
estimate_num_samples,
|
|
||||||
get_feature_stats,
|
|
||||||
sample_images,
|
|
||||||
sample_indices,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
|
||||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_array():
|
|
||||||
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
|
||||||
|
|
||||||
|
|
||||||
def test_estimate_num_samples():
|
|
||||||
assert estimate_num_samples(1) == 1
|
|
||||||
assert estimate_num_samples(10) == 10
|
|
||||||
assert estimate_num_samples(100) == 100
|
|
||||||
assert estimate_num_samples(200) == 100
|
|
||||||
assert estimate_num_samples(1000) == 177
|
|
||||||
assert estimate_num_samples(2000) == 299
|
|
||||||
assert estimate_num_samples(5000) == 594
|
|
||||||
assert estimate_num_samples(10_000) == 1000
|
|
||||||
assert estimate_num_samples(20_000) == 1681
|
|
||||||
assert estimate_num_samples(50_000) == 3343
|
|
||||||
assert estimate_num_samples(500_000) == 10_000
|
|
||||||
|
|
||||||
|
|
||||||
def test_sample_indices():
|
|
||||||
indices = sample_indices(10)
|
|
||||||
assert len(indices) > 0
|
|
||||||
assert indices[0] == 0
|
|
||||||
assert indices[-1] == 9
|
|
||||||
assert len(indices) == estimate_num_samples(10)
|
|
||||||
|
|
||||||
|
|
||||||
@patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
|
|
||||||
def test_sample_images(mock_load):
|
|
||||||
image_paths = [f"image_{i}.jpg" for i in range(100)]
|
|
||||||
images = sample_images(image_paths)
|
|
||||||
assert isinstance(images, np.ndarray)
|
|
||||||
assert images.shape[1:] == (3, 32, 32)
|
|
||||||
assert images.dtype == np.uint8
|
|
||||||
assert len(images) == estimate_num_samples(100)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_feature_stats_images():
|
|
||||||
data = np.random.rand(100, 3, 32, 32)
|
|
||||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
|
||||||
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
|
|
||||||
np.testing.assert_equal(stats["count"], np.array([100]))
|
|
||||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
|
||||||
expected = {
|
|
||||||
"min": np.array([[1, 2, 3]]),
|
|
||||||
"max": np.array([[7, 8, 9]]),
|
|
||||||
"mean": np.array([[4.0, 5.0, 6.0]]),
|
|
||||||
"std": np.array([[2.44948974, 2.44948974, 2.44948974]]),
|
|
||||||
"count": np.array([3]),
|
|
||||||
}
|
|
||||||
result = get_feature_stats(sample_array, axis=(0,), keepdims=True)
|
|
||||||
for key in expected:
|
|
||||||
np.testing.assert_allclose(result[key], expected[key])
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_feature_stats_axis_1(sample_array):
|
|
||||||
expected = {
|
|
||||||
"min": np.array([1, 4, 7]),
|
|
||||||
"max": np.array([3, 6, 9]),
|
|
||||||
"mean": np.array([2.0, 5.0, 8.0]),
|
|
||||||
"std": np.array([0.81649658, 0.81649658, 0.81649658]),
|
|
||||||
"count": np.array([3]),
|
|
||||||
}
|
|
||||||
result = get_feature_stats(sample_array, axis=(1,), keepdims=False)
|
|
||||||
for key in expected:
|
|
||||||
np.testing.assert_allclose(result[key], expected[key])
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_feature_stats_no_axis(sample_array):
|
|
||||||
expected = {
|
|
||||||
"min": np.array(1),
|
|
||||||
"max": np.array(9),
|
|
||||||
"mean": np.array(5.0),
|
|
||||||
"std": np.array(2.5819889),
|
|
||||||
"count": np.array([3]),
|
|
||||||
}
|
|
||||||
result = get_feature_stats(sample_array, axis=None, keepdims=False)
|
|
||||||
for key in expected:
|
|
||||||
np.testing.assert_allclose(result[key], expected[key])
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_feature_stats_empty_array():
|
|
||||||
array = np.array([])
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
get_feature_stats(array, axis=(0,), keepdims=True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_feature_stats_single_value():
|
|
||||||
array = np.array([[1337]])
|
|
||||||
result = get_feature_stats(array, axis=None, keepdims=True)
|
|
||||||
np.testing.assert_equal(result["min"], np.array(1337))
|
|
||||||
np.testing.assert_equal(result["max"], np.array(1337))
|
|
||||||
np.testing.assert_equal(result["mean"], np.array(1337.0))
|
|
||||||
np.testing.assert_equal(result["std"], np.array(0.0))
|
|
||||||
np.testing.assert_equal(result["count"], np.array([1]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_compute_episode_stats():
|
|
||||||
episode_data = {
|
|
||||||
"observation.image": [f"image_{i}.jpg" for i in range(100)],
|
|
||||||
"observation.state": np.random.rand(100, 10),
|
|
||||||
}
|
|
||||||
features = {
|
|
||||||
"observation.image": {"dtype": "image"},
|
|
||||||
"observation.state": {"dtype": "numeric"},
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
|
|
||||||
stats = compute_episode_stats(episode_data, features)
|
|
||||||
|
|
||||||
assert "observation.image" in stats and "observation.state" in stats
|
|
||||||
assert stats["observation.image"]["count"].item() == 100
|
|
||||||
assert stats["observation.state"]["count"].item() == 100
|
|
||||||
assert stats["observation.image"]["mean"].shape == (3, 1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_assert_type_and_shape_valid():
|
|
||||||
valid_stats = [
|
|
||||||
{
|
|
||||||
"feature1": {
|
|
||||||
"min": np.array([1.0]),
|
|
||||||
"max": np.array([10.0]),
|
|
||||||
"mean": np.array([5.0]),
|
|
||||||
"std": np.array([2.0]),
|
|
||||||
"count": np.array([1]),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
_assert_type_and_shape(valid_stats)
|
|
||||||
|
|
||||||
|
|
||||||
def test_assert_type_and_shape_invalid_type():
|
|
||||||
invalid_stats = [
|
|
||||||
{
|
|
||||||
"feature1": {
|
|
||||||
"min": [1.0], # Not a numpy array
|
|
||||||
"max": np.array([10.0]),
|
|
||||||
"mean": np.array([5.0]),
|
|
||||||
"std": np.array([2.0]),
|
|
||||||
"count": np.array([1]),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
with pytest.raises(ValueError, match="Stats must be composed of numpy array"):
|
|
||||||
_assert_type_and_shape(invalid_stats)
|
|
||||||
|
|
||||||
|
|
||||||
def test_assert_type_and_shape_invalid_shape():
|
|
||||||
invalid_stats = [
|
|
||||||
{
|
|
||||||
"feature1": {
|
|
||||||
"count": np.array([1, 2]), # Wrong shape
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
with pytest.raises(ValueError, match=r"Shape of 'count' must be \(1\)"):
|
|
||||||
_assert_type_and_shape(invalid_stats)
|
|
||||||
|
|
||||||
|
|
||||||
def test_aggregate_feature_stats():
|
|
||||||
stats_ft_list = [
|
|
||||||
{
|
|
||||||
"min": np.array([1.0]),
|
|
||||||
"max": np.array([10.0]),
|
|
||||||
"mean": np.array([5.0]),
|
|
||||||
"std": np.array([2.0]),
|
|
||||||
"count": np.array([1]),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"min": np.array([2.0]),
|
|
||||||
"max": np.array([12.0]),
|
|
||||||
"mean": np.array([6.0]),
|
|
||||||
"std": np.array([2.5]),
|
|
||||||
"count": np.array([1]),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
result = aggregate_feature_stats(stats_ft_list)
|
|
||||||
np.testing.assert_allclose(result["min"], np.array([1.0]))
|
|
||||||
np.testing.assert_allclose(result["max"], np.array([12.0]))
|
|
||||||
np.testing.assert_allclose(result["mean"], np.array([5.5]))
|
|
||||||
np.testing.assert_allclose(result["std"], np.array([2.318405]), atol=1e-6)
|
|
||||||
np.testing.assert_allclose(result["count"], np.array([2]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_aggregate_stats():
|
|
||||||
all_stats = [
|
|
||||||
{
|
|
||||||
"observation.image": {
|
|
||||||
"min": [1, 2, 3],
|
|
||||||
"max": [10, 20, 30],
|
|
||||||
"mean": [5.5, 10.5, 15.5],
|
|
||||||
"std": [2.87, 5.87, 8.87],
|
|
||||||
"count": 10,
|
|
||||||
},
|
|
||||||
"observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
|
|
||||||
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"observation.image": {
|
|
||||||
"min": [2, 1, 0],
|
|
||||||
"max": [15, 10, 5],
|
|
||||||
"mean": [8.5, 5.5, 2.5],
|
|
||||||
"std": [3.42, 2.42, 1.42],
|
|
||||||
"count": 15,
|
|
||||||
},
|
|
||||||
"observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
|
|
||||||
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
expected_agg_stats = {
|
|
||||||
"observation.image": {
|
|
||||||
"min": [1, 1, 0],
|
|
||||||
"max": [15, 20, 30],
|
|
||||||
"mean": [7.3, 7.5, 7.7],
|
|
||||||
"std": [3.5317, 4.8267, 8.5581],
|
|
||||||
"count": 25,
|
|
||||||
},
|
|
||||||
"observation.state": {
|
|
||||||
"min": 1,
|
|
||||||
"max": 15,
|
|
||||||
"mean": 7.3,
|
|
||||||
"std": 3.5317,
|
|
||||||
"count": 25,
|
|
||||||
},
|
|
||||||
"extra_key_0": {
|
|
||||||
"min": 5,
|
|
||||||
"max": 25,
|
|
||||||
"mean": 15.0,
|
|
||||||
"std": 6.0,
|
|
||||||
"count": 6,
|
|
||||||
},
|
|
||||||
"extra_key_1": {
|
|
||||||
"min": 0,
|
|
||||||
"max": 20,
|
|
||||||
"mean": 10.0,
|
|
||||||
"std": 5.0,
|
|
||||||
"count": 5,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# cast to numpy
|
|
||||||
for ep_stats in all_stats:
|
|
||||||
for fkey, stats in ep_stats.items():
|
|
||||||
for k in stats:
|
|
||||||
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
|
||||||
if fkey == "observation.image" and k != "count":
|
|
||||||
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
|
||||||
else:
|
|
||||||
stats[k] = stats[k].reshape(1)
|
|
||||||
|
|
||||||
# cast to numpy
|
|
||||||
for fkey, stats in expected_agg_stats.items():
|
|
||||||
for k in stats:
|
|
||||||
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
|
||||||
if fkey == "observation.image" and k != "count":
|
|
||||||
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
|
||||||
else:
|
|
||||||
stats[k] = stats[k].reshape(1)
|
|
||||||
|
|
||||||
results = aggregate_stats(all_stats)
|
|
||||||
|
|
||||||
for fkey in expected_agg_stats:
|
|
||||||
np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
|
|
||||||
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
|
|
||||||
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
|
|
||||||
np.testing.assert_allclose(
|
|
||||||
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
|
|
||||||
)
|
|
||||||
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])
|
|
||||||
@ -1,573 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from copy import deepcopy
|
|
||||||
from itertools import chain
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from huggingface_hub import HfApi
|
|
||||||
from PIL import Image
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
|
|
||||||
import lerobot
|
|
||||||
from lerobot.configs.default import DatasetConfig
|
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
|
||||||
from lerobot.datasets.factory import make_dataset
|
|
||||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
|
||||||
from lerobot.datasets.lerobot_dataset import (
|
|
||||||
LeRobotDataset,
|
|
||||||
MultiLeRobotDataset,
|
|
||||||
)
|
|
||||||
from lerobot.datasets.utils import (
|
|
||||||
create_branch,
|
|
||||||
flatten_dict,
|
|
||||||
unflatten_dict,
|
|
||||||
)
|
|
||||||
from lerobot.envs.factory import make_env_config
|
|
||||||
from lerobot.policies.factory import make_policy_config
|
|
||||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
|
||||||
from tests.utils import require_x86_64_kernel
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {
|
|
||||||
"image": {
|
|
||||||
"dtype": "image",
|
|
||||||
"shape": DUMMY_CHW,
|
|
||||||
"names": [
|
|
||||||
"channels",
|
|
||||||
"height",
|
|
||||||
"width",
|
|
||||||
],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
|
|
||||||
|
|
||||||
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
|
||||||
"""
|
|
||||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
|
||||||
objects have the same sets of attributes defined.
|
|
||||||
"""
|
|
||||||
# Instantiate both ways
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
|
||||||
root_create = tmp_path / "create"
|
|
||||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create)
|
|
||||||
|
|
||||||
root_init = tmp_path / "init"
|
|
||||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
|
||||||
|
|
||||||
init_attr = set(vars(dataset_init).keys())
|
|
||||||
create_attr = set(vars(dataset_create).keys())
|
|
||||||
|
|
||||||
assert init_attr == create_attr
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
|
|
||||||
kwargs = {
|
|
||||||
"repo_id": DUMMY_REPO_ID,
|
|
||||||
"total_episodes": 10,
|
|
||||||
"total_frames": 400,
|
|
||||||
"episodes": [2, 5, 6],
|
|
||||||
}
|
|
||||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", **kwargs)
|
|
||||||
|
|
||||||
assert dataset.repo_id == kwargs["repo_id"]
|
|
||||||
assert dataset.meta.total_episodes == kwargs["total_episodes"]
|
|
||||||
assert dataset.meta.total_frames == kwargs["total_frames"]
|
|
||||||
assert dataset.episodes == kwargs["episodes"]
|
|
||||||
assert dataset.num_episodes == len(kwargs["episodes"])
|
|
||||||
assert dataset.num_frames == len(dataset)
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
|
|
||||||
):
|
|
||||||
dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task")
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
|
|
||||||
):
|
|
||||||
dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task")
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
|
|
||||||
):
|
|
||||||
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task")
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
|
|
||||||
):
|
|
||||||
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=re.escape(
|
|
||||||
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
|
|
||||||
),
|
|
||||||
):
|
|
||||||
dataset.add_frame({"state": 1.0}, task="Dummy task")
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
|
|
||||||
):
|
|
||||||
dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task")
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=re.escape(
|
|
||||||
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
|
|
||||||
),
|
|
||||||
):
|
|
||||||
dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task")
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
assert len(dataset) == 1
|
|
||||||
assert dataset[0]["task"] == "Dummy task"
|
|
||||||
assert dataset[0]["task_index"] == 0
|
|
||||||
assert dataset[0]["state"].ndim == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
dataset.add_frame({"state": torch.randn(2)}, task="Dummy task")
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2])
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task")
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4])
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task")
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task")
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task")
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task")
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
assert dataset[0]["state"].ndim == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task")
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
assert dataset[0]["caption"] == "Dummy caption"
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_image_wrong_shape(image_dataset):
|
|
||||||
dataset = image_dataset
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=re.escape(
|
|
||||||
"The feature 'image' of shape '(3, 128, 96)' does not have the expected shape '(3, 96, 128)' or '(96, 128, 3)'.\n"
|
|
||||||
),
|
|
||||||
):
|
|
||||||
c, h, w = DUMMY_CHW
|
|
||||||
dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task")
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_image_wrong_range(image_dataset):
|
|
||||||
"""This test will display the following error message from a thread:
|
|
||||||
```
|
|
||||||
Error writing image ...test_add_frame_image_wrong_ran0/test/images/image/episode_000000/frame_000000.png:
|
|
||||||
The image data type is float, which requires values in the range [0.0, 1.0]. However, the provided range is [0.009678772038470007, 254.9776492089887].
|
|
||||||
Please adjust the range or provide a uint8 image with values in the range [0, 255]
|
|
||||||
```
|
|
||||||
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
|
|
||||||
"""
|
|
||||||
dataset = image_dataset
|
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task")
|
|
||||||
with pytest.raises(FileNotFoundError):
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_image(image_dataset):
|
|
||||||
dataset = image_dataset
|
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task")
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_image_h_w_c(image_dataset):
|
|
||||||
dataset = image_dataset
|
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task")
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_image_uint8(image_dataset):
|
|
||||||
dataset = image_dataset
|
|
||||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
|
||||||
dataset.add_frame({"image": image}, task="Dummy task")
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_image_pil(image_dataset):
|
|
||||||
dataset = image_dataset
|
|
||||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
|
||||||
dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task")
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_array_to_pil_image_wrong_range_float_0_255():
|
|
||||||
image = np.random.rand(*DUMMY_HWC) * 255
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
image_array_to_pil_image(image)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts):
|
|
||||||
# - [ ] test various attributes & state from init and create
|
|
||||||
# - [ ] test init with episodes and check num_frames
|
|
||||||
# - [ ] test add_episode
|
|
||||||
# - [ ] test push_to_hub
|
|
||||||
# - [ ] test smaller methods
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"env_name, repo_id, policy_name",
|
|
||||||
# Single dataset
|
|
||||||
lerobot.env_dataset_policy_triplets,
|
|
||||||
# Multi-dataset
|
|
||||||
# TODO after fix multidataset
|
|
||||||
# + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
|
|
||||||
)
|
|
||||||
def test_factory(env_name, repo_id, policy_name):
|
|
||||||
"""
|
|
||||||
Tests that:
|
|
||||||
- we can create a dataset with the factory.
|
|
||||||
- for a commonly used set of data keys, the data dimensions are correct.
|
|
||||||
"""
|
|
||||||
cfg = TrainPipelineConfig(
|
|
||||||
# TODO(rcadene, aliberts): remove dataset download
|
|
||||||
dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
|
|
||||||
env=make_env_config(env_name),
|
|
||||||
policy=make_policy_config(policy_name, push_to_hub=False),
|
|
||||||
)
|
|
||||||
cfg.validate()
|
|
||||||
|
|
||||||
dataset = make_dataset(cfg)
|
|
||||||
delta_timestamps = dataset.delta_timestamps
|
|
||||||
camera_keys = dataset.meta.camera_keys
|
|
||||||
|
|
||||||
item = dataset[0]
|
|
||||||
|
|
||||||
keys_ndim_required = [
|
|
||||||
("action", 1, True),
|
|
||||||
("episode_index", 0, True),
|
|
||||||
("frame_index", 0, True),
|
|
||||||
("timestamp", 0, True),
|
|
||||||
# TODO(rcadene): should we rename it agent_pos?
|
|
||||||
("observation.state", 1, True),
|
|
||||||
("next.reward", 0, False),
|
|
||||||
("next.done", 0, False),
|
|
||||||
]
|
|
||||||
|
|
||||||
# test number of dimensions
|
|
||||||
for key, ndim, required in keys_ndim_required:
|
|
||||||
if key not in item:
|
|
||||||
if required:
|
|
||||||
assert key in item, f"{key}"
|
|
||||||
else:
|
|
||||||
logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
|
|
||||||
continue
|
|
||||||
|
|
||||||
if delta_timestamps is not None and key in delta_timestamps:
|
|
||||||
assert item[key].ndim == ndim + 1, f"{key}"
|
|
||||||
assert item[key].shape[0] == len(delta_timestamps[key]), f"{key}"
|
|
||||||
else:
|
|
||||||
assert item[key].ndim == ndim, f"{key}"
|
|
||||||
|
|
||||||
if key in camera_keys:
|
|
||||||
assert item[key].dtype == torch.float32, f"{key}"
|
|
||||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
|
||||||
assert item[key].max() <= 1.0, f"{key}"
|
|
||||||
assert item[key].min() >= 0.0, f"{key}"
|
|
||||||
|
|
||||||
if delta_timestamps is not None and key in delta_timestamps:
|
|
||||||
# test t,c,h,w
|
|
||||||
assert item[key].shape[1] == 3, f"{key}"
|
|
||||||
else:
|
|
||||||
# test c,h,w
|
|
||||||
assert item[key].shape[0] == 3, f"{key}"
|
|
||||||
|
|
||||||
if delta_timestamps is not None:
|
|
||||||
# test missing keys in delta_timestamps
|
|
||||||
for key in delta_timestamps:
|
|
||||||
assert key in item, f"{key}"
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
|
||||||
@pytest.mark.skip("TODO after fix multidataset")
|
|
||||||
def test_multidataset_frames():
|
|
||||||
"""Check that all dataset frames are incorporated."""
|
|
||||||
# Note: use the image variants of the dataset to make the test approx 3x faster.
|
|
||||||
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
|
|
||||||
# logic that wouldn't be caught with two repo IDs.
|
|
||||||
repo_ids = [
|
|
||||||
"lerobot/aloha_sim_insertion_human_image",
|
|
||||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
|
||||||
"lerobot/aloha_sim_insertion_scripted_image",
|
|
||||||
]
|
|
||||||
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
|
||||||
dataset = MultiLeRobotDataset(repo_ids)
|
|
||||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
|
||||||
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
|
||||||
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
|
||||||
|
|
||||||
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
|
|
||||||
# check they match.
|
|
||||||
expected_dataset_indices = []
|
|
||||||
for i, sub_dataset in enumerate(sub_datasets):
|
|
||||||
expected_dataset_indices.extend([i] * len(sub_dataset))
|
|
||||||
|
|
||||||
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
|
|
||||||
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
|
|
||||||
):
|
|
||||||
dataset_index = dataset_item.pop("dataset_index")
|
|
||||||
assert dataset_index == expected_dataset_index
|
|
||||||
assert sub_dataset_item.keys() == dataset_item.keys()
|
|
||||||
for k in sub_dataset_item:
|
|
||||||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): Move to more appropriate location
|
|
||||||
def test_flatten_unflatten_dict():
|
|
||||||
d = {
|
|
||||||
"obs": {
|
|
||||||
"min": 0,
|
|
||||||
"max": 1,
|
|
||||||
"mean": 2,
|
|
||||||
"std": 3,
|
|
||||||
},
|
|
||||||
"action": {
|
|
||||||
"min": 4,
|
|
||||||
"max": 5,
|
|
||||||
"mean": 6,
|
|
||||||
"std": 7,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
original_d = deepcopy(d)
|
|
||||||
d = unflatten_dict(flatten_dict(d))
|
|
||||||
|
|
||||||
# test equality between nested dicts
|
|
||||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"repo_id",
|
|
||||||
[
|
|
||||||
"lerobot/pusht",
|
|
||||||
"lerobot/aloha_sim_insertion_human",
|
|
||||||
"lerobot/xarm_lift_medium",
|
|
||||||
# (michel-aractingi) commenting the two datasets from openx as test is failing
|
|
||||||
# "lerobot/nyu_franka_play_dataset",
|
|
||||||
# "lerobot/cmu_stretch",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@require_x86_64_kernel
|
|
||||||
def test_backward_compatibility(repo_id):
|
|
||||||
"""The artifacts for this test have been generated by `tests/artifacts/datasets/save_dataset_to_safetensors.py`."""
|
|
||||||
|
|
||||||
# TODO(rcadene, aliberts): remove dataset download
|
|
||||||
dataset = LeRobotDataset(repo_id, episodes=[0])
|
|
||||||
|
|
||||||
test_dir = Path("tests/artifacts/datasets") / repo_id
|
|
||||||
|
|
||||||
def load_and_compare(i):
|
|
||||||
new_frame = dataset[i] # noqa: B023
|
|
||||||
old_frame = load_file(test_dir / f"frame_{i}.safetensors") # noqa: B023
|
|
||||||
|
|
||||||
# ignore language instructions (if exists) in language conditioned datasets
|
|
||||||
# TODO (michel-aractingi): transform language obs to language embeddings via tokenizer
|
|
||||||
new_frame.pop("language_instruction", None)
|
|
||||||
old_frame.pop("language_instruction", None)
|
|
||||||
new_frame.pop("task", None)
|
|
||||||
old_frame.pop("task", None)
|
|
||||||
|
|
||||||
# Remove task_index to allow for backward compatibility
|
|
||||||
# TODO(rcadene): remove when new features have been generated
|
|
||||||
if "task_index" not in old_frame:
|
|
||||||
del new_frame["task_index"]
|
|
||||||
|
|
||||||
new_keys = set(new_frame.keys())
|
|
||||||
old_keys = set(old_frame.keys())
|
|
||||||
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
|
|
||||||
|
|
||||||
for key in new_frame:
|
|
||||||
assert torch.isclose(new_frame[key], old_frame[key]).all(), (
|
|
||||||
f"{key=} for index={i} does not contain the same value"
|
|
||||||
)
|
|
||||||
|
|
||||||
# test2 first frames of first episode
|
|
||||||
i = dataset.episode_data_index["from"][0].item()
|
|
||||||
load_and_compare(i)
|
|
||||||
load_and_compare(i + 1)
|
|
||||||
|
|
||||||
# test 2 frames at the middle of first episode
|
|
||||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
|
||||||
load_and_compare(i)
|
|
||||||
load_and_compare(i + 1)
|
|
||||||
|
|
||||||
# test 2 last frames of first episode
|
|
||||||
i = dataset.episode_data_index["to"][0].item()
|
|
||||||
load_and_compare(i - 2)
|
|
||||||
load_and_compare(i - 1)
|
|
||||||
|
|
||||||
# TODO(rcadene): Enable testing on second and last episode
|
|
||||||
# We currently cant because our test dataset only contains the first episode
|
|
||||||
|
|
||||||
# # test 2 first frames of second episode
|
|
||||||
# i = dataset.episode_data_index["from"][1].item()
|
|
||||||
# load_and_compare(i)
|
|
||||||
# load_and_compare(i + 1)
|
|
||||||
|
|
||||||
# # test 2 last frames of second episode
|
|
||||||
# i = dataset.episode_data_index["to"][1].item()
|
|
||||||
# load_and_compare(i - 2)
|
|
||||||
# load_and_compare(i - 1)
|
|
||||||
|
|
||||||
# # test 2 last frames of last episode
|
|
||||||
# i = dataset.episode_data_index["to"][-1].item()
|
|
||||||
# load_and_compare(i - 2)
|
|
||||||
# load_and_compare(i - 1)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("Requires internet access")
|
|
||||||
def test_create_branch():
|
|
||||||
api = HfApi()
|
|
||||||
|
|
||||||
repo_id = "cadene/test_create_branch"
|
|
||||||
repo_type = "dataset"
|
|
||||||
branch = "test"
|
|
||||||
ref = f"refs/heads/{branch}"
|
|
||||||
|
|
||||||
# Prepare a repo with a test branch
|
|
||||||
api.delete_repo(repo_id, repo_type=repo_type, missing_ok=True)
|
|
||||||
api.create_repo(repo_id, repo_type=repo_type)
|
|
||||||
create_branch(repo_id, repo_type=repo_type, branch=branch)
|
|
||||||
|
|
||||||
# Make sure the test branch exists
|
|
||||||
branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
|
|
||||||
refs = [branch.ref for branch in branches]
|
|
||||||
assert ref in refs
|
|
||||||
|
|
||||||
# Overwrite it
|
|
||||||
create_branch(repo_id, repo_type=repo_type, branch=branch)
|
|
||||||
|
|
||||||
# Clean
|
|
||||||
api.delete_repo(repo_id, repo_type=repo_type)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_feature_with_forward_slash_raises_error():
|
|
||||||
# make sure dir does not exist
|
|
||||||
from lerobot.constants import HF_LEROBOT_HOME
|
|
||||||
|
|
||||||
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
|
|
||||||
# make sure does not exist
|
|
||||||
if dataset_dir.exists():
|
|
||||||
dataset_dir.rmdir()
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
LeRobotDataset.create(
|
|
||||||
repo_id="lerobot/test/with/slash",
|
|
||||||
fps=30,
|
|
||||||
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
|
|
||||||
)
|
|
||||||
@ -1,278 +0,0 @@
|
|||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from itertools import accumulate
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
import numpy as np
|
|
||||||
import pyarrow.compute as pc
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.datasets.utils import (
|
|
||||||
check_delta_timestamps,
|
|
||||||
check_timestamps_sync,
|
|
||||||
get_delta_indices,
|
|
||||||
)
|
|
||||||
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_total_episode(
|
|
||||||
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
|
|
||||||
) -> dict[str, torch.Tensor]:
|
|
||||||
episode_indices = sorted(hf_dataset.unique("episode_index"))
|
|
||||||
total_episodes = len(episode_indices)
|
|
||||||
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
|
|
||||||
raise ValueError("episode_index values are not sorted and contiguous.")
|
|
||||||
return total_episodes
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]:
|
|
||||||
episode_lengths = []
|
|
||||||
table = hf_dataset.data.table
|
|
||||||
total_episodes = calculate_total_episode(hf_dataset)
|
|
||||||
for ep_idx in range(total_episodes):
|
|
||||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
|
||||||
episode_lengths.insert(ep_idx, len(ep_table))
|
|
||||||
|
|
||||||
cumulative_lengths = list(accumulate(episode_lengths))
|
|
||||||
return {
|
|
||||||
"from": np.array([0] + cumulative_lengths[:-1], dtype=np.int64),
|
|
||||||
"to": np.array(cumulative_lengths, dtype=np.int64),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def synced_timestamps_factory(hf_dataset_factory):
|
|
||||||
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
hf_dataset = hf_dataset_factory(fps=fps)
|
|
||||||
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
|
|
||||||
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
|
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
|
||||||
return timestamps, episode_indices, episode_data_index
|
|
||||||
|
|
||||||
return _create_synced_timestamps
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def unsynced_timestamps_factory(synced_timestamps_factory):
|
|
||||||
def _create_unsynced_timestamps(
|
|
||||||
fps: int = 30, tolerance_s: float = 1e-4
|
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
|
|
||||||
timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance
|
|
||||||
return timestamps, episode_indices, episode_data_index
|
|
||||||
|
|
||||||
return _create_unsynced_timestamps
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def slightly_off_timestamps_factory(synced_timestamps_factory):
|
|
||||||
def _create_slightly_off_timestamps(
|
|
||||||
fps: int = 30, tolerance_s: float = 1e-4
|
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
|
|
||||||
timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance
|
|
||||||
return timestamps, episode_indices, episode_data_index
|
|
||||||
|
|
||||||
return _create_slightly_off_timestamps
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def valid_delta_timestamps_factory():
|
|
||||||
def _create_valid_delta_timestamps(
|
|
||||||
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
|
|
||||||
) -> dict:
|
|
||||||
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
|
|
||||||
return delta_timestamps
|
|
||||||
|
|
||||||
return _create_valid_delta_timestamps
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
|
||||||
def _create_invalid_delta_timestamps(
|
|
||||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
|
||||||
) -> dict:
|
|
||||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
|
||||||
# Modify a single timestamp just outside tolerance
|
|
||||||
for key in keys:
|
|
||||||
delta_timestamps[key][3] += tolerance_s * 1.1
|
|
||||||
return delta_timestamps
|
|
||||||
|
|
||||||
return _create_invalid_delta_timestamps
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
|
||||||
def _create_slightly_off_delta_timestamps(
|
|
||||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
|
||||||
) -> dict:
|
|
||||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
|
||||||
# Modify a single timestamp just inside tolerance
|
|
||||||
for key in delta_timestamps:
|
|
||||||
delta_timestamps[key][3] += tolerance_s * 0.9
|
|
||||||
delta_timestamps[key][-3] += tolerance_s * 0.9
|
|
||||||
return delta_timestamps
|
|
||||||
|
|
||||||
return _create_slightly_off_delta_timestamps
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def delta_indices_factory():
|
|
||||||
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
|
|
||||||
return {key: list(range(*min_max_range)) for key in keys}
|
|
||||||
|
|
||||||
return _delta_indices
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_timestamps_sync_synced(synced_timestamps_factory):
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
timestamps, ep_idx, ep_data_index = synced_timestamps_factory(fps)
|
|
||||||
result = check_timestamps_sync(
|
|
||||||
timestamps=timestamps,
|
|
||||||
episode_indices=ep_idx,
|
|
||||||
episode_data_index=ep_data_index,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
)
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory):
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
check_timestamps_sync(
|
|
||||||
timestamps=timestamps,
|
|
||||||
episode_indices=ep_idx,
|
|
||||||
episode_data_index=ep_data_index,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory):
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
|
|
||||||
result = check_timestamps_sync(
|
|
||||||
timestamps=timestamps,
|
|
||||||
episode_indices=ep_idx,
|
|
||||||
episode_data_index=ep_data_index,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
raise_value_error=False,
|
|
||||||
)
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
|
|
||||||
result = check_timestamps_sync(
|
|
||||||
timestamps=timestamps,
|
|
||||||
episode_indices=ep_idx,
|
|
||||||
episode_data_index=ep_data_index,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
)
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_timestamps_sync_single_timestamp():
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
timestamps, ep_idx = np.array([0.0]), np.array([0])
|
|
||||||
episode_data_index = {"to": np.array([1]), "from": np.array([0])}
|
|
||||||
result = check_timestamps_sync(
|
|
||||||
timestamps=timestamps,
|
|
||||||
episode_indices=ep_idx,
|
|
||||||
episode_data_index=episode_data_index,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
)
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
valid_delta_timestamps = valid_delta_timestamps_factory(fps)
|
|
||||||
result = check_delta_timestamps(
|
|
||||||
delta_timestamps=valid_delta_timestamps,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
)
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
|
|
||||||
result = check_delta_timestamps(
|
|
||||||
delta_timestamps=slightly_off_delta_timestamps,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
)
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_delta_timestamps_invalid(invalid_delta_timestamps_factory):
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
check_delta_timestamps(
|
|
||||||
delta_timestamps=invalid_delta_timestamps,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_delta_timestamps_invalid_no_exception(invalid_delta_timestamps_factory):
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
|
|
||||||
result = check_delta_timestamps(
|
|
||||||
delta_timestamps=invalid_delta_timestamps,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
raise_value_error=False,
|
|
||||||
)
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_delta_timestamps_empty():
|
|
||||||
delta_timestamps = {}
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
result = check_delta_timestamps(
|
|
||||||
delta_timestamps=delta_timestamps,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
)
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_delta_indices(valid_delta_timestamps_factory, delta_indices_factory):
|
|
||||||
fps = 50
|
|
||||||
min_max_range = (-100, 100)
|
|
||||||
delta_timestamps = valid_delta_timestamps_factory(fps, min_max_range=min_max_range)
|
|
||||||
expected_delta_indices = delta_indices_factory(min_max_range=min_max_range)
|
|
||||||
actual_delta_indices = get_delta_indices(delta_timestamps, fps)
|
|
||||||
assert expected_delta_indices == actual_delta_indices
|
|
||||||
@ -1,382 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from packaging import version
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
from torchvision.transforms import v2
|
|
||||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
|
||||||
|
|
||||||
from lerobot.datasets.transforms import (
|
|
||||||
ImageTransformConfig,
|
|
||||||
ImageTransforms,
|
|
||||||
ImageTransformsConfig,
|
|
||||||
RandomSubsetApply,
|
|
||||||
SharpnessJitter,
|
|
||||||
make_transform_from_config,
|
|
||||||
)
|
|
||||||
from lerobot.scripts.visualize_image_transforms import (
|
|
||||||
save_all_transforms,
|
|
||||||
save_each_transform,
|
|
||||||
)
|
|
||||||
from lerobot.utils.random_utils import seeded_context
|
|
||||||
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
|
|
||||||
from tests.utils import require_x86_64_kernel
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def color_jitters():
|
|
||||||
return [
|
|
||||||
v2.ColorJitter(brightness=0.5),
|
|
||||||
v2.ColorJitter(contrast=0.5),
|
|
||||||
v2.ColorJitter(saturation=0.5),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def single_transforms():
|
|
||||||
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def img_tensor(single_transforms):
|
|
||||||
return single_transforms["original_frame"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def default_transforms():
|
|
||||||
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_image_transforms_no_transform_enable_false(img_tensor_factory):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf_cfg = ImageTransformsConfig() # default is enable=False
|
|
||||||
tf_actual = ImageTransforms(tf_cfg)
|
|
||||||
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_image_transforms_no_transform_max_num_transforms_0(img_tensor_factory):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf_cfg = ImageTransformsConfig(enable=True, max_num_transforms=0)
|
|
||||||
tf_actual = ImageTransforms(tf_cfg)
|
|
||||||
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
|
||||||
def test_get_image_transforms_brightness(img_tensor_factory, min_max):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf_cfg = ImageTransformsConfig(
|
|
||||||
enable=True,
|
|
||||||
tfs={"brightness": ImageTransformConfig(type="ColorJitter", kwargs={"brightness": min_max})},
|
|
||||||
)
|
|
||||||
tf_actual = ImageTransforms(tf_cfg)
|
|
||||||
tf_expected = v2.ColorJitter(brightness=min_max)
|
|
||||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
|
||||||
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf_cfg = ImageTransformsConfig(
|
|
||||||
enable=True, tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})}
|
|
||||||
)
|
|
||||||
tf_actual = ImageTransforms(tf_cfg)
|
|
||||||
tf_expected = v2.ColorJitter(contrast=min_max)
|
|
||||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
|
||||||
def test_get_image_transforms_saturation(img_tensor_factory, min_max):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf_cfg = ImageTransformsConfig(
|
|
||||||
enable=True,
|
|
||||||
tfs={"saturation": ImageTransformConfig(type="ColorJitter", kwargs={"saturation": min_max})},
|
|
||||||
)
|
|
||||||
tf_actual = ImageTransforms(tf_cfg)
|
|
||||||
tf_expected = v2.ColorJitter(saturation=min_max)
|
|
||||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
|
|
||||||
def test_get_image_transforms_hue(img_tensor_factory, min_max):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf_cfg = ImageTransformsConfig(
|
|
||||||
enable=True, tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})}
|
|
||||||
)
|
|
||||||
tf_actual = ImageTransforms(tf_cfg)
|
|
||||||
tf_expected = v2.ColorJitter(hue=min_max)
|
|
||||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
|
||||||
def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf_cfg = ImageTransformsConfig(
|
|
||||||
enable=True,
|
|
||||||
tfs={"sharpness": ImageTransformConfig(type="SharpnessJitter", kwargs={"sharpness": min_max})},
|
|
||||||
)
|
|
||||||
tf_actual = ImageTransforms(tf_cfg)
|
|
||||||
tf_expected = SharpnessJitter(sharpness=min_max)
|
|
||||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf_cfg = ImageTransformsConfig(
|
|
||||||
enable=True,
|
|
||||||
max_num_transforms=5,
|
|
||||||
tfs={
|
|
||||||
"brightness": ImageTransformConfig(
|
|
||||||
weight=1.0,
|
|
||||||
type="ColorJitter",
|
|
||||||
kwargs={"brightness": (0.5, 0.5)},
|
|
||||||
),
|
|
||||||
"contrast": ImageTransformConfig(
|
|
||||||
weight=1.0,
|
|
||||||
type="ColorJitter",
|
|
||||||
kwargs={"contrast": (0.5, 0.5)},
|
|
||||||
),
|
|
||||||
"saturation": ImageTransformConfig(
|
|
||||||
weight=1.0,
|
|
||||||
type="ColorJitter",
|
|
||||||
kwargs={"saturation": (0.5, 0.5)},
|
|
||||||
),
|
|
||||||
"hue": ImageTransformConfig(
|
|
||||||
weight=1.0,
|
|
||||||
type="ColorJitter",
|
|
||||||
kwargs={"hue": (0.5, 0.5)},
|
|
||||||
),
|
|
||||||
"sharpness": ImageTransformConfig(
|
|
||||||
weight=1.0,
|
|
||||||
type="SharpnessJitter",
|
|
||||||
kwargs={"sharpness": (0.5, 0.5)},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
tf_actual = ImageTransforms(tf_cfg)
|
|
||||||
tf_expected = v2.Compose(
|
|
||||||
[
|
|
||||||
v2.ColorJitter(brightness=(0.5, 0.5)),
|
|
||||||
v2.ColorJitter(contrast=(0.5, 0.5)),
|
|
||||||
v2.ColorJitter(saturation=(0.5, 0.5)),
|
|
||||||
v2.ColorJitter(hue=(0.5, 0.5)),
|
|
||||||
SharpnessJitter(sharpness=(0.5, 0.5)),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
|
||||||
|
|
||||||
|
|
||||||
@require_x86_64_kernel
|
|
||||||
def test_get_image_transforms_random_order(img_tensor_factory):
|
|
||||||
out_imgs = []
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf_cfg = ImageTransformsConfig(
|
|
||||||
enable=True,
|
|
||||||
random_order=True,
|
|
||||||
tfs={
|
|
||||||
"brightness": ImageTransformConfig(
|
|
||||||
weight=1.0,
|
|
||||||
type="ColorJitter",
|
|
||||||
kwargs={"brightness": (0.5, 0.5)},
|
|
||||||
),
|
|
||||||
"contrast": ImageTransformConfig(
|
|
||||||
weight=1.0,
|
|
||||||
type="ColorJitter",
|
|
||||||
kwargs={"contrast": (0.5, 0.5)},
|
|
||||||
),
|
|
||||||
"saturation": ImageTransformConfig(
|
|
||||||
weight=1.0,
|
|
||||||
type="ColorJitter",
|
|
||||||
kwargs={"saturation": (0.5, 0.5)},
|
|
||||||
),
|
|
||||||
"hue": ImageTransformConfig(
|
|
||||||
weight=1.0,
|
|
||||||
type="ColorJitter",
|
|
||||||
kwargs={"hue": (0.5, 0.5)},
|
|
||||||
),
|
|
||||||
"sharpness": ImageTransformConfig(
|
|
||||||
weight=1.0,
|
|
||||||
type="SharpnessJitter",
|
|
||||||
kwargs={"sharpness": (0.5, 0.5)},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
tf = ImageTransforms(tf_cfg)
|
|
||||||
|
|
||||||
with seeded_context(1338):
|
|
||||||
for _ in range(10):
|
|
||||||
out_imgs.append(tf(img_tensor))
|
|
||||||
|
|
||||||
tmp_img_tensor = img_tensor
|
|
||||||
for sub_tf in tf.tf.selected_transforms:
|
|
||||||
tmp_img_tensor = sub_tf(tmp_img_tensor)
|
|
||||||
torch.testing.assert_close(tmp_img_tensor, out_imgs[-1])
|
|
||||||
|
|
||||||
for i in range(1, len(out_imgs)):
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
torch.testing.assert_close(out_imgs[0], out_imgs[i])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"tf_type, tf_name, min_max_values",
|
|
||||||
[
|
|
||||||
("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]),
|
|
||||||
("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]),
|
|
||||||
("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]),
|
|
||||||
("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]),
|
|
||||||
("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_backward_compatibility_single_transforms(
|
|
||||||
img_tensor, tf_type, tf_name, min_max_values, single_transforms
|
|
||||||
):
|
|
||||||
for min_max in min_max_values:
|
|
||||||
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
|
|
||||||
tf = make_transform_from_config(tf_cfg)
|
|
||||||
actual = tf(img_tensor)
|
|
||||||
key = f"{tf_name}_{min_max[0]}_{min_max[1]}"
|
|
||||||
expected = single_transforms[key]
|
|
||||||
torch.testing.assert_close(actual, expected)
|
|
||||||
|
|
||||||
|
|
||||||
@require_x86_64_kernel
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
version.parse(torch.__version__) < version.parse("2.7.0"),
|
|
||||||
reason="Test artifacts were generated with PyTorch >= 2.7.0 which has different multinomial behavior",
|
|
||||||
)
|
|
||||||
def test_backward_compatibility_default_config(img_tensor, default_transforms):
|
|
||||||
# NOTE: PyTorch versions have different randomness, it might break this test.
|
|
||||||
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
|
|
||||||
|
|
||||||
cfg = ImageTransformsConfig(enable=True)
|
|
||||||
default_tf = ImageTransforms(cfg)
|
|
||||||
|
|
||||||
with seeded_context(1337):
|
|
||||||
actual = default_tf(img_tensor)
|
|
||||||
|
|
||||||
expected = default_transforms["default"]
|
|
||||||
|
|
||||||
torch.testing.assert_close(actual, expected)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
|
|
||||||
def test_random_subset_apply_single_choice(img_tensor_factory, p):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
|
||||||
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
|
|
||||||
actual = random_choice(img_tensor)
|
|
||||||
|
|
||||||
p_horz, _ = p
|
|
||||||
if p_horz:
|
|
||||||
torch.testing.assert_close(actual, F.horizontal_flip(img_tensor))
|
|
||||||
else:
|
|
||||||
torch.testing.assert_close(actual, F.vertical_flip(img_tensor))
|
|
||||||
|
|
||||||
|
|
||||||
def test_random_subset_apply_random_order(img_tensor_factory):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
|
||||||
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
|
|
||||||
# We can't really check whether the transforms are actually applied in random order. However,
|
|
||||||
# horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
|
|
||||||
# applies them in random order, we can use a fixed order to compute the expected value.
|
|
||||||
actual = random_order(img_tensor)
|
|
||||||
expected = v2.Compose(flips)(img_tensor)
|
|
||||||
torch.testing.assert_close(actual, expected)
|
|
||||||
|
|
||||||
|
|
||||||
def test_random_subset_apply_valid_transforms(img_tensor_factory, color_jitters):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
transform = RandomSubsetApply(color_jitters)
|
|
||||||
output = transform(img_tensor)
|
|
||||||
assert output.shape == img_tensor.shape
|
|
||||||
|
|
||||||
|
|
||||||
def test_random_subset_apply_probability_length_mismatch(color_jitters):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
RandomSubsetApply(color_jitters, p=[0.5, 0.5])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n_subset", [0, 5])
|
|
||||||
def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
RandomSubsetApply(color_jitters, n_subset=n_subset)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sharpness_jitter_valid_range_tuple(img_tensor_factory):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf = SharpnessJitter((0.1, 2.0))
|
|
||||||
output = tf(img_tensor)
|
|
||||||
assert output.shape == img_tensor.shape
|
|
||||||
|
|
||||||
|
|
||||||
def test_sharpness_jitter_valid_range_float(img_tensor_factory):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf = SharpnessJitter(0.5)
|
|
||||||
output = tf(img_tensor)
|
|
||||||
assert output.shape == img_tensor.shape
|
|
||||||
|
|
||||||
|
|
||||||
def test_sharpness_jitter_invalid_range_min_negative():
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
SharpnessJitter((-0.1, 2.0))
|
|
||||||
|
|
||||||
|
|
||||||
def test_sharpness_jitter_invalid_range_max_smaller():
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
SharpnessJitter((2.0, 0.1))
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_all_transforms(img_tensor_factory, tmp_path):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf_cfg = ImageTransformsConfig(enable=True)
|
|
||||||
n_examples = 3
|
|
||||||
|
|
||||||
save_all_transforms(tf_cfg, img_tensor, tmp_path, n_examples)
|
|
||||||
|
|
||||||
# Check if the combined transforms directory exists and contains the right files
|
|
||||||
combined_transforms_dir = tmp_path / "all"
|
|
||||||
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
|
|
||||||
assert any(combined_transforms_dir.iterdir()), (
|
|
||||||
"No transformed images found in combined transforms directory."
|
|
||||||
)
|
|
||||||
for i in range(1, n_examples + 1):
|
|
||||||
assert (combined_transforms_dir / f"{i}.png").exists(), (
|
|
||||||
f"Combined transform image {i}.png was not found."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_each_transform(img_tensor_factory, tmp_path):
|
|
||||||
img_tensor = img_tensor_factory()
|
|
||||||
tf_cfg = ImageTransformsConfig(enable=True)
|
|
||||||
n_examples = 3
|
|
||||||
|
|
||||||
save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples)
|
|
||||||
|
|
||||||
# Check if the transformed images exist for each transform type
|
|
||||||
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
|
|
||||||
for transform in transforms:
|
|
||||||
transform_dir = tmp_path / transform
|
|
||||||
assert transform_dir.exists(), f"{transform} directory was not created."
|
|
||||||
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
|
|
||||||
|
|
||||||
# Check for specific files within each transform directory
|
|
||||||
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
|
|
||||||
for file_name in expected_files:
|
|
||||||
assert (transform_dir / file_name).exists(), (
|
|
||||||
f"{file_name} was not found in {transform} directory."
|
|
||||||
)
|
|
||||||
@ -1,386 +0,0 @@
|
|||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import queue
|
|
||||||
import time
|
|
||||||
from multiprocessing import queues
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from lerobot.datasets.image_writer import (
|
|
||||||
AsyncImageWriter,
|
|
||||||
image_array_to_pil_image,
|
|
||||||
safe_stop_image_writer,
|
|
||||||
write_image,
|
|
||||||
)
|
|
||||||
from tests.fixtures.constants import DUMMY_HWC
|
|
||||||
|
|
||||||
DUMMY_IMAGE = "test_image.png"
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_threading():
|
|
||||||
writer = AsyncImageWriter(num_processes=0, num_threads=2)
|
|
||||||
try:
|
|
||||||
assert writer.num_processes == 0
|
|
||||||
assert writer.num_threads == 2
|
|
||||||
assert isinstance(writer.queue, queue.Queue)
|
|
||||||
assert len(writer.threads) == 2
|
|
||||||
assert len(writer.processes) == 0
|
|
||||||
assert all(t.is_alive() for t in writer.threads)
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_multiprocessing():
|
|
||||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
|
||||||
try:
|
|
||||||
assert writer.num_processes == 2
|
|
||||||
assert writer.num_threads == 2
|
|
||||||
assert isinstance(writer.queue, queues.JoinableQueue)
|
|
||||||
assert len(writer.threads) == 0
|
|
||||||
assert len(writer.processes) == 2
|
|
||||||
assert all(p.is_alive() for p in writer.processes)
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_zero_threads():
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
AsyncImageWriter(num_processes=0, num_threads=0)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_array_to_pil_image_float_array_wrong_range_0_255():
|
|
||||||
image = np.random.rand(*DUMMY_HWC) * 255
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
image_array_to_pil_image(image)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_array_to_pil_image_float_array_wrong_range_neg_1_1():
|
|
||||||
image = np.random.rand(*DUMMY_HWC) * 2 - 1
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
image_array_to_pil_image(image)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_array_to_pil_image_rgb(img_array_factory):
|
|
||||||
img_array = img_array_factory(100, 100)
|
|
||||||
result_image = image_array_to_pil_image(img_array)
|
|
||||||
assert isinstance(result_image, Image.Image)
|
|
||||||
assert result_image.size == (100, 100)
|
|
||||||
assert result_image.mode == "RGB"
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_array_to_pil_image_pytorch_format(img_array_factory):
|
|
||||||
img_array = img_array_factory(100, 100).transpose(2, 0, 1)
|
|
||||||
result_image = image_array_to_pil_image(img_array)
|
|
||||||
assert isinstance(result_image, Image.Image)
|
|
||||||
assert result_image.size == (100, 100)
|
|
||||||
assert result_image.mode == "RGB"
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_array_to_pil_image_single_channel(img_array_factory):
|
|
||||||
img_array = img_array_factory(channels=1)
|
|
||||||
with pytest.raises(NotImplementedError):
|
|
||||||
image_array_to_pil_image(img_array)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_array_to_pil_image_4_channels(img_array_factory):
|
|
||||||
img_array = img_array_factory(channels=4)
|
|
||||||
with pytest.raises(NotImplementedError):
|
|
||||||
image_array_to_pil_image(img_array)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_array_to_pil_image_float_array(img_array_factory):
|
|
||||||
img_array = img_array_factory(dtype=np.float32)
|
|
||||||
result_image = image_array_to_pil_image(img_array)
|
|
||||||
assert isinstance(result_image, Image.Image)
|
|
||||||
assert result_image.size == (100, 100)
|
|
||||||
assert result_image.mode == "RGB"
|
|
||||||
assert np.array(result_image).dtype == np.uint8
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_array_to_pil_image_uint8_array(img_array_factory):
|
|
||||||
img_array = img_array_factory(dtype=np.float32)
|
|
||||||
result_image = image_array_to_pil_image(img_array)
|
|
||||||
assert isinstance(result_image, Image.Image)
|
|
||||||
assert result_image.size == (100, 100)
|
|
||||||
assert result_image.mode == "RGB"
|
|
||||||
assert np.array(result_image).dtype == np.uint8
|
|
||||||
|
|
||||||
|
|
||||||
def test_write_image_numpy(tmp_path, img_array_factory):
|
|
||||||
image_array = img_array_factory()
|
|
||||||
fpath = tmp_path / DUMMY_IMAGE
|
|
||||||
write_image(image_array, fpath)
|
|
||||||
assert fpath.exists()
|
|
||||||
saved_image = np.array(Image.open(fpath))
|
|
||||||
assert np.array_equal(image_array, saved_image)
|
|
||||||
|
|
||||||
|
|
||||||
def test_write_image_image(tmp_path, img_factory):
|
|
||||||
image_pil = img_factory()
|
|
||||||
fpath = tmp_path / DUMMY_IMAGE
|
|
||||||
write_image(image_pil, fpath)
|
|
||||||
assert fpath.exists()
|
|
||||||
saved_image = Image.open(fpath)
|
|
||||||
assert list(saved_image.getdata()) == list(image_pil.getdata())
|
|
||||||
assert np.array_equal(image_pil, saved_image)
|
|
||||||
|
|
||||||
|
|
||||||
def test_write_image_exception(tmp_path):
|
|
||||||
image_array = "invalid data"
|
|
||||||
fpath = tmp_path / DUMMY_IMAGE
|
|
||||||
with patch("builtins.print") as mock_print:
|
|
||||||
write_image(image_array, fpath)
|
|
||||||
mock_print.assert_called()
|
|
||||||
assert not fpath.exists()
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_image_numpy(tmp_path, img_array_factory):
|
|
||||||
writer = AsyncImageWriter()
|
|
||||||
try:
|
|
||||||
image_array = img_array_factory()
|
|
||||||
fpath = tmp_path / DUMMY_IMAGE
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
writer.save_image(image_array, fpath)
|
|
||||||
writer.wait_until_done()
|
|
||||||
assert fpath.exists()
|
|
||||||
saved_image = np.array(Image.open(fpath))
|
|
||||||
assert np.array_equal(image_array, saved_image)
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
|
|
||||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
|
||||||
try:
|
|
||||||
image_array = img_array_factory()
|
|
||||||
fpath = tmp_path / DUMMY_IMAGE
|
|
||||||
writer.save_image(image_array, fpath)
|
|
||||||
writer.wait_until_done()
|
|
||||||
assert fpath.exists()
|
|
||||||
saved_image = np.array(Image.open(fpath))
|
|
||||||
assert np.array_equal(image_array, saved_image)
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_image_torch(tmp_path, img_tensor_factory):
|
|
||||||
writer = AsyncImageWriter()
|
|
||||||
try:
|
|
||||||
image_tensor = img_tensor_factory()
|
|
||||||
fpath = tmp_path / DUMMY_IMAGE
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
writer.save_image(image_tensor, fpath)
|
|
||||||
writer.wait_until_done()
|
|
||||||
assert fpath.exists()
|
|
||||||
saved_image = np.array(Image.open(fpath))
|
|
||||||
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
|
||||||
assert np.array_equal(expected_image, saved_image)
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
|
|
||||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
|
||||||
try:
|
|
||||||
image_tensor = img_tensor_factory()
|
|
||||||
fpath = tmp_path / DUMMY_IMAGE
|
|
||||||
writer.save_image(image_tensor, fpath)
|
|
||||||
writer.wait_until_done()
|
|
||||||
assert fpath.exists()
|
|
||||||
saved_image = np.array(Image.open(fpath))
|
|
||||||
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
|
||||||
assert np.array_equal(expected_image, saved_image)
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_image_pil(tmp_path, img_factory):
|
|
||||||
writer = AsyncImageWriter()
|
|
||||||
try:
|
|
||||||
image_pil = img_factory()
|
|
||||||
fpath = tmp_path / DUMMY_IMAGE
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
writer.save_image(image_pil, fpath)
|
|
||||||
writer.wait_until_done()
|
|
||||||
assert fpath.exists()
|
|
||||||
saved_image = Image.open(fpath)
|
|
||||||
assert list(saved_image.getdata()) == list(image_pil.getdata())
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_image_pil_multiprocessing(tmp_path, img_factory):
|
|
||||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
|
||||||
try:
|
|
||||||
image_pil = img_factory()
|
|
||||||
fpath = tmp_path / DUMMY_IMAGE
|
|
||||||
writer.save_image(image_pil, fpath)
|
|
||||||
writer.wait_until_done()
|
|
||||||
assert fpath.exists()
|
|
||||||
saved_image = Image.open(fpath)
|
|
||||||
assert list(saved_image.getdata()) == list(image_pil.getdata())
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_image_invalid_data(tmp_path):
|
|
||||||
writer = AsyncImageWriter()
|
|
||||||
try:
|
|
||||||
image_array = "invalid data"
|
|
||||||
fpath = tmp_path / DUMMY_IMAGE
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with patch("builtins.print") as mock_print:
|
|
||||||
writer.save_image(image_array, fpath)
|
|
||||||
writer.wait_until_done()
|
|
||||||
mock_print.assert_called()
|
|
||||||
assert not fpath.exists()
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_image_after_stop(tmp_path, img_array_factory):
|
|
||||||
writer = AsyncImageWriter()
|
|
||||||
writer.stop()
|
|
||||||
image_array = img_array_factory()
|
|
||||||
fpath = tmp_path / DUMMY_IMAGE
|
|
||||||
writer.save_image(image_array, fpath)
|
|
||||||
time.sleep(1)
|
|
||||||
assert not fpath.exists()
|
|
||||||
|
|
||||||
|
|
||||||
def test_stop():
|
|
||||||
writer = AsyncImageWriter(num_processes=0, num_threads=2)
|
|
||||||
writer.stop()
|
|
||||||
assert not any(t.is_alive() for t in writer.threads)
|
|
||||||
|
|
||||||
|
|
||||||
def test_stop_multiprocessing():
|
|
||||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
|
||||||
writer.stop()
|
|
||||||
assert not any(p.is_alive() for p in writer.processes)
|
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_stops():
|
|
||||||
writer = AsyncImageWriter()
|
|
||||||
writer.stop()
|
|
||||||
writer.stop() # Should not raise an exception
|
|
||||||
assert not any(t.is_alive() for t in writer.threads)
|
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_stops_multiprocessing():
|
|
||||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
|
||||||
writer.stop()
|
|
||||||
writer.stop() # Should not raise an exception
|
|
||||||
assert not any(t.is_alive() for t in writer.threads)
|
|
||||||
|
|
||||||
|
|
||||||
def test_wait_until_done(tmp_path, img_array_factory):
|
|
||||||
writer = AsyncImageWriter(num_processes=0, num_threads=4)
|
|
||||||
try:
|
|
||||||
num_images = 100
|
|
||||||
image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
|
|
||||||
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
|
||||||
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
writer.save_image(image_array, fpath)
|
|
||||||
writer.wait_until_done()
|
|
||||||
for i, fpath in enumerate(fpaths):
|
|
||||||
assert fpath.exists()
|
|
||||||
saved_image = np.array(Image.open(fpath))
|
|
||||||
assert np.array_equal(saved_image, image_arrays[i])
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
|
|
||||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
|
||||||
try:
|
|
||||||
num_images = 100
|
|
||||||
image_arrays = [img_array_factory() for _ in range(num_images)]
|
|
||||||
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
|
||||||
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
writer.save_image(image_array, fpath)
|
|
||||||
writer.wait_until_done()
|
|
||||||
for i, fpath in enumerate(fpaths):
|
|
||||||
assert fpath.exists()
|
|
||||||
saved_image = np.array(Image.open(fpath))
|
|
||||||
assert np.array_equal(saved_image, image_arrays[i])
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_exception_handling(tmp_path, img_array_factory):
|
|
||||||
writer = AsyncImageWriter()
|
|
||||||
try:
|
|
||||||
image_array = img_array_factory()
|
|
||||||
with (
|
|
||||||
patch.object(writer.queue, "put", side_effect=queue.Full("Queue is full")),
|
|
||||||
pytest.raises(queue.Full) as exc_info,
|
|
||||||
):
|
|
||||||
writer.save_image(image_array, tmp_path / "test.png")
|
|
||||||
assert str(exc_info.value) == "Queue is full"
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_with_different_image_formats(tmp_path, img_array_factory):
|
|
||||||
writer = AsyncImageWriter()
|
|
||||||
try:
|
|
||||||
image_array = img_array_factory()
|
|
||||||
formats = ["png", "jpeg", "bmp"]
|
|
||||||
for fmt in formats:
|
|
||||||
fpath = tmp_path / f"test_image.{fmt}"
|
|
||||||
write_image(image_array, fpath)
|
|
||||||
assert fpath.exists()
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_safe_stop_image_writer_decorator():
|
|
||||||
class MockDataset:
|
|
||||||
def __init__(self):
|
|
||||||
self.image_writer = MagicMock(spec=AsyncImageWriter)
|
|
||||||
|
|
||||||
@safe_stop_image_writer
|
|
||||||
def function_that_raises_exception(dataset=None):
|
|
||||||
raise Exception("Test exception")
|
|
||||||
|
|
||||||
dataset = MockDataset()
|
|
||||||
|
|
||||||
with pytest.raises(Exception) as exc_info:
|
|
||||||
function_that_raises_exception(dataset=dataset)
|
|
||||||
|
|
||||||
assert str(exc_info.value) == "Test exception"
|
|
||||||
dataset.image_writer.stop.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_main_process_time(tmp_path, img_tensor_factory):
|
|
||||||
writer = AsyncImageWriter()
|
|
||||||
try:
|
|
||||||
image_tensor = img_tensor_factory()
|
|
||||||
fpath = tmp_path / DUMMY_IMAGE
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
writer.save_image(image_tensor, fpath)
|
|
||||||
end_time = time.perf_counter()
|
|
||||||
time_spent = end_time - start_time
|
|
||||||
# Might need to adjust this threshold depending on hardware
|
|
||||||
assert time_spent < 0.01, f"Main process time exceeded threshold: {time_spent}s"
|
|
||||||
writer.wait_until_done()
|
|
||||||
assert fpath.exists()
|
|
||||||
finally:
|
|
||||||
writer.stop()
|
|
||||||
@ -1,282 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.d
|
|
||||||
from copy import deepcopy
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
|
|
||||||
|
|
||||||
# Some constants for OnlineBuffer tests.
|
|
||||||
data_key = "data"
|
|
||||||
data_shape = (2, 3) # just some arbitrary > 1D shape
|
|
||||||
buffer_capacity = 100
|
|
||||||
fps = 10
|
|
||||||
|
|
||||||
|
|
||||||
def make_new_buffer(
|
|
||||||
write_dir: str | None = None, delta_timestamps: dict[str, list[float]] | None = None
|
|
||||||
) -> tuple[OnlineBuffer, str]:
|
|
||||||
if write_dir is None:
|
|
||||||
write_dir = f"/tmp/online_buffer_{uuid4().hex}"
|
|
||||||
buffer = OnlineBuffer(
|
|
||||||
write_dir,
|
|
||||||
data_spec={data_key: {"shape": data_shape, "dtype": np.dtype("float32")}},
|
|
||||||
buffer_capacity=buffer_capacity,
|
|
||||||
fps=fps,
|
|
||||||
delta_timestamps=delta_timestamps,
|
|
||||||
)
|
|
||||||
return buffer, write_dir
|
|
||||||
|
|
||||||
|
|
||||||
def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
|
|
||||||
new_data = {
|
|
||||||
data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
|
|
||||||
OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
|
|
||||||
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
|
|
||||||
OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
|
|
||||||
OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
|
|
||||||
}
|
|
||||||
return new_data
|
|
||||||
|
|
||||||
|
|
||||||
def test_non_mutate():
|
|
||||||
"""Checks that the data provided to the add_data method is copied rather than passed by reference.
|
|
||||||
|
|
||||||
This means that mutating the data in the buffer does not mutate the original data.
|
|
||||||
|
|
||||||
NOTE: If this test fails, it means some of the other tests may be compromised. For example, we can't trust
|
|
||||||
a success case for `test_write_read`.
|
|
||||||
"""
|
|
||||||
buffer, _ = make_new_buffer()
|
|
||||||
new_data = make_spoof_data_frames(2, buffer_capacity // 4)
|
|
||||||
new_data_copy = deepcopy(new_data)
|
|
||||||
buffer.add_data(new_data)
|
|
||||||
buffer._data[data_key][:] += 1
|
|
||||||
assert all(np.array_equal(new_data[k], new_data_copy[k]) for k in new_data)
|
|
||||||
|
|
||||||
|
|
||||||
def test_index_error_no_data():
|
|
||||||
buffer, _ = make_new_buffer()
|
|
||||||
with pytest.raises(IndexError):
|
|
||||||
buffer[0]
|
|
||||||
|
|
||||||
|
|
||||||
def test_index_error_with_data():
|
|
||||||
buffer, _ = make_new_buffer()
|
|
||||||
n_frames = buffer_capacity // 2
|
|
||||||
new_data = make_spoof_data_frames(1, n_frames)
|
|
||||||
buffer.add_data(new_data)
|
|
||||||
with pytest.raises(IndexError):
|
|
||||||
buffer[n_frames]
|
|
||||||
with pytest.raises(IndexError):
|
|
||||||
buffer[-n_frames - 1]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("do_reload", [False, True])
|
|
||||||
def test_write_read(do_reload: bool):
|
|
||||||
"""Checks that data can be added to the buffer and read back.
|
|
||||||
|
|
||||||
If do_reload we delete the buffer object and load the buffer back from disk before reading.
|
|
||||||
"""
|
|
||||||
buffer, write_dir = make_new_buffer()
|
|
||||||
n_episodes = 2
|
|
||||||
n_frames_per_episode = buffer_capacity // 4
|
|
||||||
new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
|
|
||||||
buffer.add_data(new_data)
|
|
||||||
|
|
||||||
if do_reload:
|
|
||||||
del buffer
|
|
||||||
buffer, _ = make_new_buffer(write_dir)
|
|
||||||
|
|
||||||
assert len(buffer) == n_frames_per_episode * n_episodes
|
|
||||||
for i, item in enumerate(buffer):
|
|
||||||
assert all(isinstance(item[k], torch.Tensor) for k in item)
|
|
||||||
assert np.array_equal(item[data_key].numpy(), new_data[data_key][i])
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_data_key():
|
|
||||||
"""Tests that data can be added to a buffer and all data for a. specific key can be read back."""
|
|
||||||
buffer, _ = make_new_buffer()
|
|
||||||
n_episodes = 2
|
|
||||||
n_frames_per_episode = buffer_capacity // 4
|
|
||||||
new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
|
|
||||||
buffer.add_data(new_data)
|
|
||||||
|
|
||||||
data_from_buffer = buffer.get_data_by_key(data_key)
|
|
||||||
assert isinstance(data_from_buffer, torch.Tensor)
|
|
||||||
assert np.array_equal(data_from_buffer.numpy(), new_data[data_key])
|
|
||||||
|
|
||||||
|
|
||||||
def test_fifo():
|
|
||||||
"""Checks that if data is added beyond the buffer capacity, we discard the oldest data first."""
|
|
||||||
buffer, _ = make_new_buffer()
|
|
||||||
n_frames_per_episode = buffer_capacity // 4
|
|
||||||
n_episodes = 3
|
|
||||||
new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
|
|
||||||
buffer.add_data(new_data)
|
|
||||||
n_more_episodes = 2
|
|
||||||
# Developer sanity check (in case someone changes the global `buffer_capacity`).
|
|
||||||
assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, (
|
|
||||||
"Something went wrong with the test code."
|
|
||||||
)
|
|
||||||
more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
|
|
||||||
buffer.add_data(more_new_data)
|
|
||||||
assert len(buffer) == buffer_capacity, "The buffer should be full."
|
|
||||||
|
|
||||||
expected_data = {}
|
|
||||||
for k in new_data:
|
|
||||||
# Concatenate, left-truncate, then roll, to imitate the cyclical FIFO pattern in OnlineBuffer.
|
|
||||||
expected_data[k] = np.roll(
|
|
||||||
np.concatenate([new_data[k], more_new_data[k]])[-buffer_capacity:],
|
|
||||||
shift=len(new_data[k]) + len(more_new_data[k]) - buffer_capacity,
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, item in enumerate(buffer):
|
|
||||||
assert all(isinstance(item[k], torch.Tensor) for k in item)
|
|
||||||
assert np.array_equal(item[data_key].numpy(), expected_data[data_key][i])
|
|
||||||
|
|
||||||
|
|
||||||
def test_delta_timestamps_within_tolerance():
|
|
||||||
"""Check that getting an item with delta_timestamps within tolerance succeeds.
|
|
||||||
|
|
||||||
Note: Copied from `test_datasets.py::test_load_previous_and_future_frames_within_tolerance`.
|
|
||||||
"""
|
|
||||||
# Sanity check on global fps as we are assuming it is 10 here.
|
|
||||||
assert fps == 10, "This test assumes fps==10"
|
|
||||||
buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.139]})
|
|
||||||
new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
|
|
||||||
buffer.add_data(new_data)
|
|
||||||
buffer.tolerance_s = 0.04
|
|
||||||
item = buffer[2]
|
|
||||||
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
|
|
||||||
torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
|
|
||||||
assert not is_pad.any(), "Unexpected padding detected"
|
|
||||||
|
|
||||||
|
|
||||||
def test_delta_timestamps_outside_tolerance_inside_episode_range():
|
|
||||||
"""Check that getting an item with delta_timestamps outside of tolerance fails.
|
|
||||||
|
|
||||||
We expect it to fail if and only if the requested timestamps are within the episode range.
|
|
||||||
|
|
||||||
Note: Copied from
|
|
||||||
`test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_inside_episode_range`
|
|
||||||
"""
|
|
||||||
# Sanity check on global fps as we are assuming it is 10 here.
|
|
||||||
assert fps == 10, "This test assumes fps==10"
|
|
||||||
buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.141]})
|
|
||||||
new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
|
|
||||||
buffer.add_data(new_data)
|
|
||||||
buffer.tolerance_s = 0.04
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
buffer[2]
|
|
||||||
|
|
||||||
|
|
||||||
def test_delta_timestamps_outside_tolerance_outside_episode_range():
|
|
||||||
"""Check that copy-padding of timestamps outside of the episode range works.
|
|
||||||
|
|
||||||
Note: Copied from
|
|
||||||
`test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_outside_episode_range`
|
|
||||||
"""
|
|
||||||
# Sanity check on global fps as we are assuming it is 10 here.
|
|
||||||
assert fps == 10, "This test assumes fps==10"
|
|
||||||
buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.3, -0.24, 0, 0.26, 0.3]})
|
|
||||||
new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
|
|
||||||
buffer.add_data(new_data)
|
|
||||||
buffer.tolerance_s = 0.04
|
|
||||||
item = buffer[2]
|
|
||||||
data, is_pad = item["index"], item["index_is_pad"]
|
|
||||||
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
|
|
||||||
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
|
|
||||||
"Padding does not match expected values"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
|
||||||
@pytest.mark.parametrize("offline_dataset_size", [1, 6])
|
|
||||||
@pytest.mark.parametrize("online_dataset_size", [0, 4])
|
|
||||||
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
|
|
||||||
def test_compute_sampler_weights_trivial(
|
|
||||||
lerobot_dataset_factory,
|
|
||||||
tmp_path,
|
|
||||||
offline_dataset_size: int,
|
|
||||||
online_dataset_size: int,
|
|
||||||
online_sampling_ratio: float,
|
|
||||||
):
|
|
||||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
|
|
||||||
online_dataset, _ = make_new_buffer()
|
|
||||||
if online_dataset_size > 0:
|
|
||||||
online_dataset.add_data(
|
|
||||||
make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
|
|
||||||
)
|
|
||||||
|
|
||||||
weights = compute_sampler_weights(
|
|
||||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
|
||||||
)
|
|
||||||
if offline_dataset_size == 0 or online_dataset_size == 0:
|
|
||||||
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
|
|
||||||
elif online_sampling_ratio == 0:
|
|
||||||
expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
|
|
||||||
elif online_sampling_ratio == 1:
|
|
||||||
expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
|
|
||||||
expected_weights /= expected_weights.sum()
|
|
||||||
torch.testing.assert_close(weights, expected_weights)
|
|
||||||
|
|
||||||
|
|
||||||
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
|
|
||||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
|
||||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
|
|
||||||
online_dataset, _ = make_new_buffer()
|
|
||||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
|
||||||
online_sampling_ratio = 0.8
|
|
||||||
weights = compute_sampler_weights(
|
|
||||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
|
||||||
)
|
|
||||||
torch.testing.assert_close(
|
|
||||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
|
|
||||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
|
||||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
|
|
||||||
online_dataset, _ = make_new_buffer()
|
|
||||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
|
||||||
weights = compute_sampler_weights(
|
|
||||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
|
|
||||||
)
|
|
||||||
torch.testing.assert_close(
|
|
||||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
|
|
||||||
"""Note: test copied from test_sampler."""
|
|
||||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
|
|
||||||
online_dataset, _ = make_new_buffer()
|
|
||||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
|
||||||
|
|
||||||
weights = compute_sampler_weights(
|
|
||||||
offline_dataset,
|
|
||||||
offline_drop_n_last_frames=1,
|
|
||||||
online_dataset=online_dataset,
|
|
||||||
online_sampling_ratio=0.5,
|
|
||||||
online_drop_n_last_frames=1,
|
|
||||||
)
|
|
||||||
torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
|
|
||||||
@ -1,90 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from datasets import Dataset
|
|
||||||
|
|
||||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
|
||||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
|
||||||
from lerobot.datasets.utils import (
|
|
||||||
hf_transform_to_torch,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_drop_n_first_frames():
|
|
||||||
dataset = Dataset.from_dict(
|
|
||||||
{
|
|
||||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
|
||||||
"index": [0, 1, 2, 3, 4, 5],
|
|
||||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
dataset.set_transform(hf_transform_to_torch)
|
|
||||||
episode_data_index = calculate_episode_data_index(dataset)
|
|
||||||
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
|
|
||||||
assert sampler.indices == [1, 4, 5]
|
|
||||||
assert len(sampler) == 3
|
|
||||||
assert list(sampler) == [1, 4, 5]
|
|
||||||
|
|
||||||
|
|
||||||
def test_drop_n_last_frames():
|
|
||||||
dataset = Dataset.from_dict(
|
|
||||||
{
|
|
||||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
|
||||||
"index": [0, 1, 2, 3, 4, 5],
|
|
||||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
dataset.set_transform(hf_transform_to_torch)
|
|
||||||
episode_data_index = calculate_episode_data_index(dataset)
|
|
||||||
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
|
|
||||||
assert sampler.indices == [0, 3, 4]
|
|
||||||
assert len(sampler) == 3
|
|
||||||
assert list(sampler) == [0, 3, 4]
|
|
||||||
|
|
||||||
|
|
||||||
def test_episode_indices_to_use():
|
|
||||||
dataset = Dataset.from_dict(
|
|
||||||
{
|
|
||||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
|
||||||
"index": [0, 1, 2, 3, 4, 5],
|
|
||||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
dataset.set_transform(hf_transform_to_torch)
|
|
||||||
episode_data_index = calculate_episode_data_index(dataset)
|
|
||||||
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
|
|
||||||
assert sampler.indices == [0, 1, 3, 4, 5]
|
|
||||||
assert len(sampler) == 5
|
|
||||||
assert list(sampler) == [0, 1, 3, 4, 5]
|
|
||||||
|
|
||||||
|
|
||||||
def test_shuffle():
|
|
||||||
dataset = Dataset.from_dict(
|
|
||||||
{
|
|
||||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
|
||||||
"index": [0, 1, 2, 3, 4, 5],
|
|
||||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
dataset.set_transform(hf_transform_to_torch)
|
|
||||||
episode_data_index = calculate_episode_data_index(dataset)
|
|
||||||
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
|
|
||||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
|
||||||
assert len(sampler) == 6
|
|
||||||
assert list(sampler) == [0, 1, 2, 3, 4, 5]
|
|
||||||
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
|
|
||||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
|
||||||
assert len(sampler) == 6
|
|
||||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
|
||||||
@ -1,55 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from datasets import Dataset
|
|
||||||
from huggingface_hub import DatasetCard
|
|
||||||
|
|
||||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
|
||||||
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_parameters():
|
|
||||||
card = create_lerobot_dataset_card()
|
|
||||||
assert isinstance(card, DatasetCard)
|
|
||||||
assert card.data.tags == ["LeRobot"]
|
|
||||||
assert card.data.task_categories == ["robotics"]
|
|
||||||
assert card.data.configs == [
|
|
||||||
{
|
|
||||||
"config_name": "default",
|
|
||||||
"data_files": "data/*/*.parquet",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_with_tags():
|
|
||||||
tags = ["tag1", "tag2"]
|
|
||||||
card = create_lerobot_dataset_card(tags=tags)
|
|
||||||
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_calculate_episode_data_index():
|
|
||||||
dataset = Dataset.from_dict(
|
|
||||||
{
|
|
||||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
|
||||||
"index": [0, 1, 2, 3, 4, 5],
|
|
||||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
dataset.set_transform(hf_transform_to_torch)
|
|
||||||
episode_data_index = calculate_episode_data_index(dataset)
|
|
||||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
|
||||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
|
||||||
@ -1,33 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lerobot.scripts.visualize_dataset import visualize_dataset
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("TODO: add dummy videos")
|
|
||||||
def test_visualize_local_dataset(tmp_path, lerobot_dataset_factory):
|
|
||||||
root = tmp_path / "dataset"
|
|
||||||
output_dir = tmp_path / "outputs"
|
|
||||||
dataset = lerobot_dataset_factory(root=root)
|
|
||||||
rrd_path = visualize_dataset(
|
|
||||||
dataset,
|
|
||||||
episode_index=0,
|
|
||||||
batch_size=32,
|
|
||||||
save=True,
|
|
||||||
output_dir=output_dir,
|
|
||||||
)
|
|
||||||
assert rrd_path.exists()
|
|
||||||
@ -1,63 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from gymnasium.utils.env_checker import check_env
|
|
||||||
|
|
||||||
import lerobot
|
|
||||||
from lerobot.envs.factory import make_env, make_env_config
|
|
||||||
from lerobot.envs.utils import preprocess_observation
|
|
||||||
from tests.utils import require_env
|
|
||||||
|
|
||||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("obs_type", OBS_TYPES)
|
|
||||||
@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
|
|
||||||
@require_env
|
|
||||||
def test_env(env_name, env_task, obs_type):
|
|
||||||
if env_name == "aloha" and obs_type == "state":
|
|
||||||
pytest.skip("`state` observations not available for aloha")
|
|
||||||
|
|
||||||
package_name = f"gym_{env_name}"
|
|
||||||
importlib.import_module(package_name)
|
|
||||||
env = gym.make(f"{package_name}/{env_task}", obs_type=obs_type)
|
|
||||||
check_env(env.unwrapped, skip_render_check=True)
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_name", lerobot.available_envs)
|
|
||||||
@require_env
|
|
||||||
def test_factory(env_name):
|
|
||||||
cfg = make_env_config(env_name)
|
|
||||||
env = make_env(cfg, n_envs=1)
|
|
||||||
obs, _ = env.reset()
|
|
||||||
obs = preprocess_observation(obs)
|
|
||||||
|
|
||||||
# test image keys are float32 in range [0,1]
|
|
||||||
for key in obs:
|
|
||||||
if "image" not in key:
|
|
||||||
continue
|
|
||||||
img = obs[key]
|
|
||||||
assert img.dtype == torch.float32
|
|
||||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
|
||||||
assert img.max() <= 1.0
|
|
||||||
assert img.min() >= 0.0
|
|
||||||
|
|
||||||
env.close()
|
|
||||||
@ -1,147 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import io
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
|
||||||
from tests.utils import require_package
|
|
||||||
|
|
||||||
|
|
||||||
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
|
|
||||||
for f, r in finds_and_replaces:
|
|
||||||
assert f in text
|
|
||||||
text = text.replace(f, r)
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures
|
|
||||||
def _run_script(path):
|
|
||||||
subprocess.run([sys.executable, path], check=True)
|
|
||||||
|
|
||||||
|
|
||||||
def _read_file(path):
|
|
||||||
with open(path) as file:
|
|
||||||
return file.read()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
|
|
||||||
def test_example_1(tmp_path, lerobot_dataset_factory):
|
|
||||||
_ = lerobot_dataset_factory(root=tmp_path, repo_id=DUMMY_REPO_ID)
|
|
||||||
path = "examples/1_load_lerobot_dataset.py"
|
|
||||||
file_contents = _read_file(path)
|
|
||||||
file_contents = _find_and_replace(
|
|
||||||
file_contents,
|
|
||||||
[
|
|
||||||
('repo_id = "lerobot/pusht"', f'repo_id = "{DUMMY_REPO_ID}"'),
|
|
||||||
(
|
|
||||||
"LeRobotDataset(repo_id",
|
|
||||||
f"LeRobotDataset(repo_id, root='{str(tmp_path)}'",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
exec(file_contents, {})
|
|
||||||
assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
|
|
||||||
@require_package("gym_pusht")
|
|
||||||
def test_examples_basic2_basic3_advanced1():
|
|
||||||
"""
|
|
||||||
Train a model with example 3, check the outputs.
|
|
||||||
Evaluate the trained model with example 2, check the outputs.
|
|
||||||
Calculate the validation loss with advanced example 1, check the outputs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
### Test example 3
|
|
||||||
file_contents = _read_file("examples/3_train_policy.py")
|
|
||||||
|
|
||||||
# Do fewer steps, use smaller batch, use CPU, and don't complicate things with dataloader workers.
|
|
||||||
file_contents = _find_and_replace(
|
|
||||||
file_contents,
|
|
||||||
[
|
|
||||||
("training_steps = 5000", "training_steps = 1"),
|
|
||||||
("num_workers=4", "num_workers=0"),
|
|
||||||
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
|
|
||||||
("batch_size=64", "batch_size=1"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
|
|
||||||
exec(file_contents, {})
|
|
||||||
|
|
||||||
for file_name in ["model.safetensors", "config.json"]:
|
|
||||||
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
|
|
||||||
|
|
||||||
### Test example 2
|
|
||||||
file_contents = _read_file("examples/2_evaluate_pretrained_policy.py")
|
|
||||||
|
|
||||||
# Do fewer evals, use CPU, and use the local model.
|
|
||||||
file_contents = _find_and_replace(
|
|
||||||
file_contents,
|
|
||||||
[
|
|
||||||
(
|
|
||||||
'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))',
|
|
||||||
"",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
|
|
||||||
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
|
|
||||||
),
|
|
||||||
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
|
|
||||||
("step += 1", "break"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
exec(file_contents, {})
|
|
||||||
|
|
||||||
assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()
|
|
||||||
|
|
||||||
## Test example 4
|
|
||||||
file_contents = _read_file("examples/advanced/2_calculate_validation_loss.py")
|
|
||||||
|
|
||||||
# Run on a single example from the last episode, use CPU, and use the local model.
|
|
||||||
file_contents = _find_and_replace(
|
|
||||||
file_contents,
|
|
||||||
[
|
|
||||||
(
|
|
||||||
'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))',
|
|
||||||
"",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
|
|
||||||
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
|
|
||||||
),
|
|
||||||
("train_episodes = episodes[:num_train_episodes]", "train_episodes = [0]"),
|
|
||||||
("val_episodes = episodes[num_train_episodes:]", "val_episodes = [1]"),
|
|
||||||
("num_workers=4", "num_workers=0"),
|
|
||||||
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
|
|
||||||
("batch_size=64", "batch_size=1"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Capture the output of the script
|
|
||||||
output_buffer = io.StringIO()
|
|
||||||
sys.stdout = output_buffer
|
|
||||||
exec(file_contents, {})
|
|
||||||
printed_output = output_buffer.getvalue()
|
|
||||||
# Restore stdout to its original state
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
assert "Average loss on validation set" in printed_output
|
|
||||||
44
tests/fixtures/constants.py
vendored
44
tests/fixtures/constants.py
vendored
@ -1,44 +0,0 @@
|
|||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from lerobot.constants import HF_LEROBOT_HOME
|
|
||||||
|
|
||||||
LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"
|
|
||||||
DUMMY_REPO_ID = "dummy/repo"
|
|
||||||
DUMMY_ROBOT_TYPE = "dummy_robot"
|
|
||||||
DUMMY_MOTOR_FEATURES = {
|
|
||||||
"action": {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": (6,),
|
|
||||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
|
||||||
},
|
|
||||||
"state": {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": (6,),
|
|
||||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
DUMMY_CAMERA_FEATURES = {
|
|
||||||
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
|
||||||
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
|
||||||
}
|
|
||||||
DEFAULT_FPS = 30
|
|
||||||
DUMMY_VIDEO_INFO = {
|
|
||||||
"video.fps": DEFAULT_FPS,
|
|
||||||
"video.codec": "av1",
|
|
||||||
"video.pix_fmt": "yuv420p",
|
|
||||||
"video.is_depth_map": False,
|
|
||||||
"has_audio": False,
|
|
||||||
}
|
|
||||||
DUMMY_CHW = (3, 96, 128)
|
|
||||||
DUMMY_HWC = (96, 128, 3)
|
|
||||||
444
tests/fixtures/dataset_factories.py
vendored
444
tests/fixtures/dataset_factories.py
vendored
@ -1,444 +0,0 @@
|
|||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import random
|
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Protocol
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
import numpy as np
|
|
||||||
import PIL.Image
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
|
||||||
from lerobot.datasets.utils import (
|
|
||||||
DEFAULT_CHUNK_SIZE,
|
|
||||||
DEFAULT_FEATURES,
|
|
||||||
DEFAULT_PARQUET_PATH,
|
|
||||||
DEFAULT_VIDEO_PATH,
|
|
||||||
get_hf_features_from_features,
|
|
||||||
hf_transform_to_torch,
|
|
||||||
)
|
|
||||||
from tests.fixtures.constants import (
|
|
||||||
DEFAULT_FPS,
|
|
||||||
DUMMY_CAMERA_FEATURES,
|
|
||||||
DUMMY_MOTOR_FEATURES,
|
|
||||||
DUMMY_REPO_ID,
|
|
||||||
DUMMY_ROBOT_TYPE,
|
|
||||||
DUMMY_VIDEO_INFO,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LeRobotDatasetFactory(Protocol):
|
|
||||||
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
|
|
||||||
|
|
||||||
|
|
||||||
def get_task_index(task_dicts: dict, task: str) -> int:
|
|
||||||
tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
|
|
||||||
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
|
|
||||||
return task_to_task_index[task]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def img_tensor_factory():
|
|
||||||
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
|
||||||
return torch.rand((channels, height, width), dtype=dtype)
|
|
||||||
|
|
||||||
return _create_img_tensor
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def img_array_factory():
|
|
||||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
|
||||||
if np.issubdtype(dtype, np.unsignedinteger):
|
|
||||||
# Int array in [0, 255] range
|
|
||||||
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
|
||||||
elif np.issubdtype(dtype, np.floating):
|
|
||||||
# Float array in [0, 1] range
|
|
||||||
img_array = np.random.rand(height, width, channels).astype(dtype)
|
|
||||||
else:
|
|
||||||
raise ValueError(dtype)
|
|
||||||
return img_array
|
|
||||||
|
|
||||||
return _create_img_array
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def img_factory(img_array_factory):
|
|
||||||
def _create_img(height=100, width=100) -> PIL.Image.Image:
|
|
||||||
img_array = img_array_factory(height=height, width=width)
|
|
||||||
return PIL.Image.fromarray(img_array)
|
|
||||||
|
|
||||||
return _create_img
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def features_factory():
|
|
||||||
def _create_features(
|
|
||||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
|
||||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
|
||||||
use_videos: bool = True,
|
|
||||||
) -> dict:
|
|
||||||
if use_videos:
|
|
||||||
camera_ft = {
|
|
||||||
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
|
|
||||||
return {
|
|
||||||
**motor_features,
|
|
||||||
**camera_ft,
|
|
||||||
**DEFAULT_FEATURES,
|
|
||||||
}
|
|
||||||
|
|
||||||
return _create_features
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def info_factory(features_factory):
|
|
||||||
def _create_info(
|
|
||||||
codebase_version: str = CODEBASE_VERSION,
|
|
||||||
fps: int = DEFAULT_FPS,
|
|
||||||
robot_type: str = DUMMY_ROBOT_TYPE,
|
|
||||||
total_episodes: int = 0,
|
|
||||||
total_frames: int = 0,
|
|
||||||
total_tasks: int = 0,
|
|
||||||
total_videos: int = 0,
|
|
||||||
total_chunks: int = 0,
|
|
||||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
|
||||||
data_path: str = DEFAULT_PARQUET_PATH,
|
|
||||||
video_path: str = DEFAULT_VIDEO_PATH,
|
|
||||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
|
||||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
|
||||||
use_videos: bool = True,
|
|
||||||
) -> dict:
|
|
||||||
features = features_factory(motor_features, camera_features, use_videos)
|
|
||||||
return {
|
|
||||||
"codebase_version": codebase_version,
|
|
||||||
"robot_type": robot_type,
|
|
||||||
"total_episodes": total_episodes,
|
|
||||||
"total_frames": total_frames,
|
|
||||||
"total_tasks": total_tasks,
|
|
||||||
"total_videos": total_videos,
|
|
||||||
"total_chunks": total_chunks,
|
|
||||||
"chunks_size": chunks_size,
|
|
||||||
"fps": fps,
|
|
||||||
"splits": {},
|
|
||||||
"data_path": data_path,
|
|
||||||
"video_path": video_path if use_videos else None,
|
|
||||||
"features": features,
|
|
||||||
}
|
|
||||||
|
|
||||||
return _create_info
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def stats_factory():
|
|
||||||
def _create_stats(
|
|
||||||
features: dict[str] | None = None,
|
|
||||||
) -> dict:
|
|
||||||
stats = {}
|
|
||||||
for key, ft in features.items():
|
|
||||||
shape = ft["shape"]
|
|
||||||
dtype = ft["dtype"]
|
|
||||||
if dtype in ["image", "video"]:
|
|
||||||
stats[key] = {
|
|
||||||
"max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(),
|
|
||||||
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
|
|
||||||
"min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
|
|
||||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
|
|
||||||
"count": [10],
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
stats[key] = {
|
|
||||||
"max": np.full(shape, 1, dtype=dtype).tolist(),
|
|
||||||
"mean": np.full(shape, 0.5, dtype=dtype).tolist(),
|
|
||||||
"min": np.full(shape, 0, dtype=dtype).tolist(),
|
|
||||||
"std": np.full(shape, 0.25, dtype=dtype).tolist(),
|
|
||||||
"count": [10],
|
|
||||||
}
|
|
||||||
return stats
|
|
||||||
|
|
||||||
return _create_stats
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def episodes_stats_factory(stats_factory):
|
|
||||||
def _create_episodes_stats(
|
|
||||||
features: dict[str],
|
|
||||||
total_episodes: int = 3,
|
|
||||||
) -> dict:
|
|
||||||
episodes_stats = {}
|
|
||||||
for episode_index in range(total_episodes):
|
|
||||||
episodes_stats[episode_index] = {
|
|
||||||
"episode_index": episode_index,
|
|
||||||
"stats": stats_factory(features),
|
|
||||||
}
|
|
||||||
return episodes_stats
|
|
||||||
|
|
||||||
return _create_episodes_stats
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def tasks_factory():
|
|
||||||
def _create_tasks(total_tasks: int = 3) -> int:
|
|
||||||
tasks = {}
|
|
||||||
for task_index in range(total_tasks):
|
|
||||||
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
|
|
||||||
tasks[task_index] = task_dict
|
|
||||||
return tasks
|
|
||||||
|
|
||||||
return _create_tasks
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def episodes_factory(tasks_factory):
|
|
||||||
def _create_episodes(
|
|
||||||
total_episodes: int = 3,
|
|
||||||
total_frames: int = 400,
|
|
||||||
tasks: dict | None = None,
|
|
||||||
multi_task: bool = False,
|
|
||||||
):
|
|
||||||
if total_episodes <= 0 or total_frames <= 0:
|
|
||||||
raise ValueError("num_episodes and total_length must be positive integers.")
|
|
||||||
if total_frames < total_episodes:
|
|
||||||
raise ValueError("total_length must be greater than or equal to num_episodes.")
|
|
||||||
|
|
||||||
if not tasks:
|
|
||||||
min_tasks = 2 if multi_task else 1
|
|
||||||
total_tasks = random.randint(min_tasks, total_episodes)
|
|
||||||
tasks = tasks_factory(total_tasks)
|
|
||||||
|
|
||||||
if total_episodes < len(tasks) and not multi_task:
|
|
||||||
raise ValueError("The number of tasks should be less than the number of episodes.")
|
|
||||||
|
|
||||||
# Generate random lengths that sum up to total_length
|
|
||||||
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
|
|
||||||
|
|
||||||
tasks_list = [task_dict["task"] for task_dict in tasks.values()]
|
|
||||||
num_tasks_available = len(tasks_list)
|
|
||||||
|
|
||||||
episodes = {}
|
|
||||||
remaining_tasks = tasks_list.copy()
|
|
||||||
for ep_idx in range(total_episodes):
|
|
||||||
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
|
|
||||||
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
|
|
||||||
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
|
|
||||||
if remaining_tasks:
|
|
||||||
for task in episode_tasks:
|
|
||||||
remaining_tasks.remove(task)
|
|
||||||
|
|
||||||
episodes[ep_idx] = {
|
|
||||||
"episode_index": ep_idx,
|
|
||||||
"tasks": episode_tasks,
|
|
||||||
"length": lengths[ep_idx],
|
|
||||||
}
|
|
||||||
|
|
||||||
return episodes
|
|
||||||
|
|
||||||
return _create_episodes
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
|
||||||
def _create_hf_dataset(
|
|
||||||
features: dict | None = None,
|
|
||||||
tasks: list[dict] | None = None,
|
|
||||||
episodes: list[dict] | None = None,
|
|
||||||
fps: int = DEFAULT_FPS,
|
|
||||||
) -> datasets.Dataset:
|
|
||||||
if not tasks:
|
|
||||||
tasks = tasks_factory()
|
|
||||||
if not episodes:
|
|
||||||
episodes = episodes_factory()
|
|
||||||
if not features:
|
|
||||||
features = features_factory()
|
|
||||||
|
|
||||||
timestamp_col = np.array([], dtype=np.float32)
|
|
||||||
frame_index_col = np.array([], dtype=np.int64)
|
|
||||||
episode_index_col = np.array([], dtype=np.int64)
|
|
||||||
task_index = np.array([], dtype=np.int64)
|
|
||||||
for ep_dict in episodes.values():
|
|
||||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
|
||||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
|
||||||
episode_index_col = np.concatenate(
|
|
||||||
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
|
|
||||||
)
|
|
||||||
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
|
|
||||||
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
|
|
||||||
|
|
||||||
index_col = np.arange(len(episode_index_col))
|
|
||||||
|
|
||||||
robot_cols = {}
|
|
||||||
for key, ft in features.items():
|
|
||||||
if ft["dtype"] == "image":
|
|
||||||
robot_cols[key] = [
|
|
||||||
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
|
|
||||||
for _ in range(len(index_col))
|
|
||||||
]
|
|
||||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
|
||||||
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
|
|
||||||
|
|
||||||
hf_features = get_hf_features_from_features(features)
|
|
||||||
dataset = datasets.Dataset.from_dict(
|
|
||||||
{
|
|
||||||
**robot_cols,
|
|
||||||
"timestamp": timestamp_col,
|
|
||||||
"frame_index": frame_index_col,
|
|
||||||
"episode_index": episode_index_col,
|
|
||||||
"index": index_col,
|
|
||||||
"task_index": task_index,
|
|
||||||
},
|
|
||||||
features=hf_features,
|
|
||||||
)
|
|
||||||
dataset.set_transform(hf_transform_to_torch)
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
return _create_hf_dataset
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def lerobot_dataset_metadata_factory(
|
|
||||||
info_factory,
|
|
||||||
stats_factory,
|
|
||||||
episodes_stats_factory,
|
|
||||||
tasks_factory,
|
|
||||||
episodes_factory,
|
|
||||||
mock_snapshot_download_factory,
|
|
||||||
):
|
|
||||||
def _create_lerobot_dataset_metadata(
|
|
||||||
root: Path,
|
|
||||||
repo_id: str = DUMMY_REPO_ID,
|
|
||||||
info: dict | None = None,
|
|
||||||
stats: dict | None = None,
|
|
||||||
episodes_stats: list[dict] | None = None,
|
|
||||||
tasks: list[dict] | None = None,
|
|
||||||
episodes: list[dict] | None = None,
|
|
||||||
) -> LeRobotDatasetMetadata:
|
|
||||||
if not info:
|
|
||||||
info = info_factory()
|
|
||||||
if not stats:
|
|
||||||
stats = stats_factory(features=info["features"])
|
|
||||||
if not episodes_stats:
|
|
||||||
episodes_stats = episodes_stats_factory(
|
|
||||||
features=info["features"], total_episodes=info["total_episodes"]
|
|
||||||
)
|
|
||||||
if not tasks:
|
|
||||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
|
||||||
if not episodes:
|
|
||||||
episodes = episodes_factory(
|
|
||||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_snapshot_download = mock_snapshot_download_factory(
|
|
||||||
info=info,
|
|
||||||
stats=stats,
|
|
||||||
episodes_stats=episodes_stats,
|
|
||||||
tasks=tasks,
|
|
||||||
episodes=episodes,
|
|
||||||
)
|
|
||||||
with (
|
|
||||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
|
||||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
|
|
||||||
):
|
|
||||||
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
|
|
||||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
|
||||||
|
|
||||||
return LeRobotDatasetMetadata(repo_id=repo_id, root=root)
|
|
||||||
|
|
||||||
return _create_lerobot_dataset_metadata
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def lerobot_dataset_factory(
|
|
||||||
info_factory,
|
|
||||||
stats_factory,
|
|
||||||
episodes_stats_factory,
|
|
||||||
tasks_factory,
|
|
||||||
episodes_factory,
|
|
||||||
hf_dataset_factory,
|
|
||||||
mock_snapshot_download_factory,
|
|
||||||
lerobot_dataset_metadata_factory,
|
|
||||||
) -> LeRobotDatasetFactory:
|
|
||||||
def _create_lerobot_dataset(
|
|
||||||
root: Path,
|
|
||||||
repo_id: str = DUMMY_REPO_ID,
|
|
||||||
total_episodes: int = 3,
|
|
||||||
total_frames: int = 150,
|
|
||||||
total_tasks: int = 1,
|
|
||||||
multi_task: bool = False,
|
|
||||||
info: dict | None = None,
|
|
||||||
stats: dict | None = None,
|
|
||||||
episodes_stats: list[dict] | None = None,
|
|
||||||
tasks: list[dict] | None = None,
|
|
||||||
episode_dicts: list[dict] | None = None,
|
|
||||||
hf_dataset: datasets.Dataset | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> LeRobotDataset:
|
|
||||||
if not info:
|
|
||||||
info = info_factory(
|
|
||||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
|
||||||
)
|
|
||||||
if not stats:
|
|
||||||
stats = stats_factory(features=info["features"])
|
|
||||||
if not episodes_stats:
|
|
||||||
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
|
|
||||||
if not tasks:
|
|
||||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
|
||||||
if not episode_dicts:
|
|
||||||
episode_dicts = episodes_factory(
|
|
||||||
total_episodes=info["total_episodes"],
|
|
||||||
total_frames=info["total_frames"],
|
|
||||||
tasks=tasks,
|
|
||||||
multi_task=multi_task,
|
|
||||||
)
|
|
||||||
if not hf_dataset:
|
|
||||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
|
|
||||||
|
|
||||||
mock_snapshot_download = mock_snapshot_download_factory(
|
|
||||||
info=info,
|
|
||||||
stats=stats,
|
|
||||||
episodes_stats=episodes_stats,
|
|
||||||
tasks=tasks,
|
|
||||||
episodes=episode_dicts,
|
|
||||||
hf_dataset=hf_dataset,
|
|
||||||
)
|
|
||||||
mock_metadata = lerobot_dataset_metadata_factory(
|
|
||||||
root=root,
|
|
||||||
repo_id=repo_id,
|
|
||||||
info=info,
|
|
||||||
stats=stats,
|
|
||||||
episodes_stats=episodes_stats,
|
|
||||||
tasks=tasks,
|
|
||||||
episodes=episode_dicts,
|
|
||||||
)
|
|
||||||
with (
|
|
||||||
patch("lerobot.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
|
||||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
|
||||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
|
|
||||||
):
|
|
||||||
mock_metadata_patch.return_value = mock_metadata
|
|
||||||
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
|
|
||||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
|
||||||
|
|
||||||
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
|
|
||||||
|
|
||||||
return _create_lerobot_dataset
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
|
||||||
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
|
|
||||||
147
tests/fixtures/files.py
vendored
147
tests/fixtures/files.py
vendored
@ -1,147 +0,0 @@
|
|||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
import jsonlines
|
|
||||||
import pyarrow.compute as pc
|
|
||||||
import pyarrow.parquet as pq
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lerobot.datasets.utils import (
|
|
||||||
EPISODES_PATH,
|
|
||||||
EPISODES_STATS_PATH,
|
|
||||||
INFO_PATH,
|
|
||||||
STATS_PATH,
|
|
||||||
TASKS_PATH,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def info_path(info_factory):
|
|
||||||
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
|
|
||||||
if not info:
|
|
||||||
info = info_factory()
|
|
||||||
fpath = dir / INFO_PATH
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(fpath, "w") as f:
|
|
||||||
json.dump(info, f, indent=4, ensure_ascii=False)
|
|
||||||
return fpath
|
|
||||||
|
|
||||||
return _create_info_json_file
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def stats_path(stats_factory):
|
|
||||||
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
|
|
||||||
if not stats:
|
|
||||||
stats = stats_factory()
|
|
||||||
fpath = dir / STATS_PATH
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(fpath, "w") as f:
|
|
||||||
json.dump(stats, f, indent=4, ensure_ascii=False)
|
|
||||||
return fpath
|
|
||||||
|
|
||||||
return _create_stats_json_file
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def episodes_stats_path(episodes_stats_factory):
|
|
||||||
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
|
|
||||||
if not episodes_stats:
|
|
||||||
episodes_stats = episodes_stats_factory()
|
|
||||||
fpath = dir / EPISODES_STATS_PATH
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with jsonlines.open(fpath, "w") as writer:
|
|
||||||
writer.write_all(episodes_stats.values())
|
|
||||||
return fpath
|
|
||||||
|
|
||||||
return _create_episodes_stats_jsonl_file
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def tasks_path(tasks_factory):
|
|
||||||
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
|
|
||||||
if not tasks:
|
|
||||||
tasks = tasks_factory()
|
|
||||||
fpath = dir / TASKS_PATH
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with jsonlines.open(fpath, "w") as writer:
|
|
||||||
writer.write_all(tasks.values())
|
|
||||||
return fpath
|
|
||||||
|
|
||||||
return _create_tasks_jsonl_file
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def episode_path(episodes_factory):
|
|
||||||
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
|
|
||||||
if not episodes:
|
|
||||||
episodes = episodes_factory()
|
|
||||||
fpath = dir / EPISODES_PATH
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with jsonlines.open(fpath, "w") as writer:
|
|
||||||
writer.write_all(episodes.values())
|
|
||||||
return fpath
|
|
||||||
|
|
||||||
return _create_episodes_jsonl_file
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
|
||||||
def _create_single_episode_parquet(
|
|
||||||
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
|
||||||
) -> Path:
|
|
||||||
if not info:
|
|
||||||
info = info_factory()
|
|
||||||
if hf_dataset is None:
|
|
||||||
hf_dataset = hf_dataset_factory()
|
|
||||||
|
|
||||||
data_path = info["data_path"]
|
|
||||||
chunks_size = info["chunks_size"]
|
|
||||||
ep_chunk = ep_idx // chunks_size
|
|
||||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
table = hf_dataset.data.table
|
|
||||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
|
||||||
pq.write_table(ep_table, fpath)
|
|
||||||
return fpath
|
|
||||||
|
|
||||||
return _create_single_episode_parquet
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
|
||||||
def _create_multi_episode_parquet(
|
|
||||||
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
|
||||||
) -> Path:
|
|
||||||
if not info:
|
|
||||||
info = info_factory()
|
|
||||||
if hf_dataset is None:
|
|
||||||
hf_dataset = hf_dataset_factory()
|
|
||||||
|
|
||||||
data_path = info["data_path"]
|
|
||||||
chunks_size = info["chunks_size"]
|
|
||||||
total_episodes = info["total_episodes"]
|
|
||||||
for ep_idx in range(total_episodes):
|
|
||||||
ep_chunk = ep_idx // chunks_size
|
|
||||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
table = hf_dataset.data.table
|
|
||||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
|
||||||
pq.write_table(ep_table, fpath)
|
|
||||||
return dir / "data"
|
|
||||||
|
|
||||||
return _create_multi_episode_parquet
|
|
||||||
133
tests/fixtures/hub.py
vendored
133
tests/fixtures/hub.py
vendored
@ -1,133 +0,0 @@
|
|||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
import pytest
|
|
||||||
from huggingface_hub.utils import filter_repo_objects
|
|
||||||
|
|
||||||
from lerobot.datasets.utils import (
|
|
||||||
EPISODES_PATH,
|
|
||||||
EPISODES_STATS_PATH,
|
|
||||||
INFO_PATH,
|
|
||||||
STATS_PATH,
|
|
||||||
TASKS_PATH,
|
|
||||||
)
|
|
||||||
from tests.fixtures.constants import LEROBOT_TEST_DIR
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def mock_snapshot_download_factory(
|
|
||||||
info_factory,
|
|
||||||
info_path,
|
|
||||||
stats_factory,
|
|
||||||
stats_path,
|
|
||||||
episodes_stats_factory,
|
|
||||||
episodes_stats_path,
|
|
||||||
tasks_factory,
|
|
||||||
tasks_path,
|
|
||||||
episodes_factory,
|
|
||||||
episode_path,
|
|
||||||
single_episode_parquet_path,
|
|
||||||
hf_dataset_factory,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
This factory allows to patch snapshot_download such that when called, it will create expected files rather
|
|
||||||
than making calls to the hub api. Its design allows to pass explicitly files which you want to be created.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _mock_snapshot_download_func(
|
|
||||||
info: dict | None = None,
|
|
||||||
stats: dict | None = None,
|
|
||||||
episodes_stats: list[dict] | None = None,
|
|
||||||
tasks: list[dict] | None = None,
|
|
||||||
episodes: list[dict] | None = None,
|
|
||||||
hf_dataset: datasets.Dataset | None = None,
|
|
||||||
):
|
|
||||||
if not info:
|
|
||||||
info = info_factory()
|
|
||||||
if not stats:
|
|
||||||
stats = stats_factory(features=info["features"])
|
|
||||||
if not episodes_stats:
|
|
||||||
episodes_stats = episodes_stats_factory(
|
|
||||||
features=info["features"], total_episodes=info["total_episodes"]
|
|
||||||
)
|
|
||||||
if not tasks:
|
|
||||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
|
||||||
if not episodes:
|
|
||||||
episodes = episodes_factory(
|
|
||||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
|
||||||
)
|
|
||||||
if not hf_dataset:
|
|
||||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
|
|
||||||
|
|
||||||
def _extract_episode_index_from_path(fpath: str) -> int:
|
|
||||||
path = Path(fpath)
|
|
||||||
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
|
|
||||||
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
|
|
||||||
return episode_index
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _mock_snapshot_download(
|
|
||||||
repo_id: str,
|
|
||||||
local_dir: str | Path | None = None,
|
|
||||||
allow_patterns: str | list[str] | None = None,
|
|
||||||
ignore_patterns: str | list[str] | None = None,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
|
||||||
if not local_dir:
|
|
||||||
local_dir = LEROBOT_TEST_DIR
|
|
||||||
|
|
||||||
# List all possible files
|
|
||||||
all_files = []
|
|
||||||
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
|
|
||||||
all_files.extend(meta_files)
|
|
||||||
|
|
||||||
data_files = []
|
|
||||||
for episode_dict in episodes.values():
|
|
||||||
ep_idx = episode_dict["episode_index"]
|
|
||||||
ep_chunk = ep_idx // info["chunks_size"]
|
|
||||||
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
|
||||||
data_files.append(data_path)
|
|
||||||
all_files.extend(data_files)
|
|
||||||
|
|
||||||
allowed_files = filter_repo_objects(
|
|
||||||
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create allowed files
|
|
||||||
for rel_path in allowed_files:
|
|
||||||
if rel_path.startswith("data/"):
|
|
||||||
episode_index = _extract_episode_index_from_path(rel_path)
|
|
||||||
if episode_index is not None:
|
|
||||||
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
|
|
||||||
if rel_path == INFO_PATH:
|
|
||||||
_ = info_path(local_dir, info)
|
|
||||||
elif rel_path == STATS_PATH:
|
|
||||||
_ = stats_path(local_dir, stats)
|
|
||||||
elif rel_path == EPISODES_STATS_PATH:
|
|
||||||
_ = episodes_stats_path(local_dir, episodes_stats)
|
|
||||||
elif rel_path == TASKS_PATH:
|
|
||||||
_ = tasks_path(local_dir, tasks)
|
|
||||||
elif rel_path == EPISODES_PATH:
|
|
||||||
_ = episode_path(local_dir, episodes)
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
return str(local_dir)
|
|
||||||
|
|
||||||
return _mock_snapshot_download
|
|
||||||
|
|
||||||
return _mock_snapshot_download_func
|
|
||||||
39
tests/fixtures/optimizers.py
vendored
39
tests/fixtures/optimizers.py
vendored
@ -1,39 +0,0 @@
|
|||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.optim.optimizers import AdamConfig
|
|
||||||
from lerobot.optim.schedulers import VQBeTSchedulerConfig
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def model_params():
|
|
||||||
return [torch.nn.Parameter(torch.randn(10, 10))]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def optimizer(model_params):
|
|
||||||
optimizer = AdamConfig().build(model_params)
|
|
||||||
# Dummy step to populate state
|
|
||||||
loss = sum(param.sum() for param in model_params)
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def scheduler(optimizer):
|
|
||||||
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
|
|
||||||
return config.build(optimizer, num_training_steps=100)
|
|
||||||
@ -1,596 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import abc
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import dynamixel_sdk as dxl
|
|
||||||
import serial
|
|
||||||
from mock_serial.mock_serial import MockSerial
|
|
||||||
|
|
||||||
from lerobot.motors.dynamixel.dynamixel import _split_into_byte_chunks
|
|
||||||
|
|
||||||
from .mock_serial_patch import WaitableStub
|
|
||||||
|
|
||||||
# https://emanual.robotis.com/docs/en/dxl/crc/
|
|
||||||
DXL_CRC_TABLE = [
|
|
||||||
0x0000, 0x8005, 0x800F, 0x000A, 0x801B, 0x001E, 0x0014, 0x8011,
|
|
||||||
0x8033, 0x0036, 0x003C, 0x8039, 0x0028, 0x802D, 0x8027, 0x0022,
|
|
||||||
0x8063, 0x0066, 0x006C, 0x8069, 0x0078, 0x807D, 0x8077, 0x0072,
|
|
||||||
0x0050, 0x8055, 0x805F, 0x005A, 0x804B, 0x004E, 0x0044, 0x8041,
|
|
||||||
0x80C3, 0x00C6, 0x00CC, 0x80C9, 0x00D8, 0x80DD, 0x80D7, 0x00D2,
|
|
||||||
0x00F0, 0x80F5, 0x80FF, 0x00FA, 0x80EB, 0x00EE, 0x00E4, 0x80E1,
|
|
||||||
0x00A0, 0x80A5, 0x80AF, 0x00AA, 0x80BB, 0x00BE, 0x00B4, 0x80B1,
|
|
||||||
0x8093, 0x0096, 0x009C, 0x8099, 0x0088, 0x808D, 0x8087, 0x0082,
|
|
||||||
0x8183, 0x0186, 0x018C, 0x8189, 0x0198, 0x819D, 0x8197, 0x0192,
|
|
||||||
0x01B0, 0x81B5, 0x81BF, 0x01BA, 0x81AB, 0x01AE, 0x01A4, 0x81A1,
|
|
||||||
0x01E0, 0x81E5, 0x81EF, 0x01EA, 0x81FB, 0x01FE, 0x01F4, 0x81F1,
|
|
||||||
0x81D3, 0x01D6, 0x01DC, 0x81D9, 0x01C8, 0x81CD, 0x81C7, 0x01C2,
|
|
||||||
0x0140, 0x8145, 0x814F, 0x014A, 0x815B, 0x015E, 0x0154, 0x8151,
|
|
||||||
0x8173, 0x0176, 0x017C, 0x8179, 0x0168, 0x816D, 0x8167, 0x0162,
|
|
||||||
0x8123, 0x0126, 0x012C, 0x8129, 0x0138, 0x813D, 0x8137, 0x0132,
|
|
||||||
0x0110, 0x8115, 0x811F, 0x011A, 0x810B, 0x010E, 0x0104, 0x8101,
|
|
||||||
0x8303, 0x0306, 0x030C, 0x8309, 0x0318, 0x831D, 0x8317, 0x0312,
|
|
||||||
0x0330, 0x8335, 0x833F, 0x033A, 0x832B, 0x032E, 0x0324, 0x8321,
|
|
||||||
0x0360, 0x8365, 0x836F, 0x036A, 0x837B, 0x037E, 0x0374, 0x8371,
|
|
||||||
0x8353, 0x0356, 0x035C, 0x8359, 0x0348, 0x834D, 0x8347, 0x0342,
|
|
||||||
0x03C0, 0x83C5, 0x83CF, 0x03CA, 0x83DB, 0x03DE, 0x03D4, 0x83D1,
|
|
||||||
0x83F3, 0x03F6, 0x03FC, 0x83F9, 0x03E8, 0x83ED, 0x83E7, 0x03E2,
|
|
||||||
0x83A3, 0x03A6, 0x03AC, 0x83A9, 0x03B8, 0x83BD, 0x83B7, 0x03B2,
|
|
||||||
0x0390, 0x8395, 0x839F, 0x039A, 0x838B, 0x038E, 0x0384, 0x8381,
|
|
||||||
0x0280, 0x8285, 0x828F, 0x028A, 0x829B, 0x029E, 0x0294, 0x8291,
|
|
||||||
0x82B3, 0x02B6, 0x02BC, 0x82B9, 0x02A8, 0x82AD, 0x82A7, 0x02A2,
|
|
||||||
0x82E3, 0x02E6, 0x02EC, 0x82E9, 0x02F8, 0x82FD, 0x82F7, 0x02F2,
|
|
||||||
0x02D0, 0x82D5, 0x82DF, 0x02DA, 0x82CB, 0x02CE, 0x02C4, 0x82C1,
|
|
||||||
0x8243, 0x0246, 0x024C, 0x8249, 0x0258, 0x825D, 0x8257, 0x0252,
|
|
||||||
0x0270, 0x8275, 0x827F, 0x027A, 0x826B, 0x026E, 0x0264, 0x8261,
|
|
||||||
0x0220, 0x8225, 0x822F, 0x022A, 0x823B, 0x023E, 0x0234, 0x8231,
|
|
||||||
0x8213, 0x0216, 0x021C, 0x8219, 0x0208, 0x820D, 0x8207, 0x0202
|
|
||||||
] # fmt: skip
|
|
||||||
|
|
||||||
|
|
||||||
class MockDynamixelPacketv2(abc.ABC):
|
|
||||||
@classmethod
|
|
||||||
def build(cls, dxl_id: int, params: list[int], length: int, *args, **kwargs) -> bytes:
|
|
||||||
packet = cls._build(dxl_id, params, length, *args, **kwargs)
|
|
||||||
packet = cls._add_stuffing(packet)
|
|
||||||
packet = cls._add_crc(packet)
|
|
||||||
return bytes(packet)
|
|
||||||
|
|
||||||
@abc.abstractclassmethod
|
|
||||||
def _build(cls, dxl_id: int, params: list[int], length: int, *args, **kwargs) -> list[int]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _add_stuffing(packet: list[int]) -> list[int]:
|
|
||||||
"""
|
|
||||||
Byte stuffing is a method of adding additional data to generated instruction packets to ensure that
|
|
||||||
the packets are processed successfully. When the byte pattern "0xFF 0xFF 0xFD" appears in a packet,
|
|
||||||
byte stuffing adds 0xFD to the end of the pattern to convert it to “0xFF 0xFF 0xFD 0xFD” to ensure
|
|
||||||
that it is not interpreted as the header at the start of another packet.
|
|
||||||
|
|
||||||
Source: https://emanual.robotis.com/docs/en/dxl/protocol2/#transmission-process
|
|
||||||
|
|
||||||
Args:
|
|
||||||
packet (list[int]): The raw packet without stuffing.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[int]: The packet stuffed if it contained a "0xFF 0xFF 0xFD" byte sequence in its data bytes.
|
|
||||||
"""
|
|
||||||
packet_length_in = dxl.DXL_MAKEWORD(packet[dxl.PKT_LENGTH_L], packet[dxl.PKT_LENGTH_H])
|
|
||||||
packet_length_out = packet_length_in
|
|
||||||
|
|
||||||
temp = [0] * dxl.TXPACKET_MAX_LEN
|
|
||||||
|
|
||||||
# FF FF FD XX ID LEN_L LEN_H
|
|
||||||
temp[dxl.PKT_HEADER0 : dxl.PKT_HEADER0 + dxl.PKT_LENGTH_H + 1] = packet[
|
|
||||||
dxl.PKT_HEADER0 : dxl.PKT_HEADER0 + dxl.PKT_LENGTH_H + 1
|
|
||||||
]
|
|
||||||
|
|
||||||
index = dxl.PKT_INSTRUCTION
|
|
||||||
|
|
||||||
for i in range(0, packet_length_in - 2): # except CRC
|
|
||||||
temp[index] = packet[i + dxl.PKT_INSTRUCTION]
|
|
||||||
index = index + 1
|
|
||||||
if (
|
|
||||||
packet[i + dxl.PKT_INSTRUCTION] == 0xFD
|
|
||||||
and packet[i + dxl.PKT_INSTRUCTION - 1] == 0xFF
|
|
||||||
and packet[i + dxl.PKT_INSTRUCTION - 2] == 0xFF
|
|
||||||
):
|
|
||||||
# FF FF FD
|
|
||||||
temp[index] = 0xFD
|
|
||||||
index = index + 1
|
|
||||||
packet_length_out = packet_length_out + 1
|
|
||||||
|
|
||||||
temp[index] = packet[dxl.PKT_INSTRUCTION + packet_length_in - 2]
|
|
||||||
temp[index + 1] = packet[dxl.PKT_INSTRUCTION + packet_length_in - 1]
|
|
||||||
index = index + 2
|
|
||||||
|
|
||||||
if packet_length_in != packet_length_out:
|
|
||||||
packet = [0] * index
|
|
||||||
|
|
||||||
packet[0:index] = temp[0:index]
|
|
||||||
|
|
||||||
packet[dxl.PKT_LENGTH_L] = dxl.DXL_LOBYTE(packet_length_out)
|
|
||||||
packet[dxl.PKT_LENGTH_H] = dxl.DXL_HIBYTE(packet_length_out)
|
|
||||||
|
|
||||||
return packet
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _add_crc(packet: list[int]) -> list[int]:
|
|
||||||
"""Computes and add CRC to the packet.
|
|
||||||
|
|
||||||
https://emanual.robotis.com/docs/en/dxl/crc/
|
|
||||||
https://en.wikipedia.org/wiki/Cyclic_redundancy_check
|
|
||||||
|
|
||||||
Args:
|
|
||||||
packet (list[int]): The raw packet without CRC (but with placeholders for it).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[int]: The raw packet with a valid CRC.
|
|
||||||
"""
|
|
||||||
crc = 0
|
|
||||||
for j in range(len(packet) - 2):
|
|
||||||
i = ((crc >> 8) ^ packet[j]) & 0xFF
|
|
||||||
crc = ((crc << 8) ^ DXL_CRC_TABLE[i]) & 0xFFFF
|
|
||||||
|
|
||||||
packet[-2] = dxl.DXL_LOBYTE(crc)
|
|
||||||
packet[-1] = dxl.DXL_HIBYTE(crc)
|
|
||||||
|
|
||||||
return packet
|
|
||||||
|
|
||||||
|
|
||||||
class MockInstructionPacket(MockDynamixelPacketv2):
|
|
||||||
"""
|
|
||||||
Helper class to build valid Dynamixel Protocol 2.0 Instruction Packets.
|
|
||||||
|
|
||||||
Protocol 2.0 Instruction Packet structure
|
|
||||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#instruction-packet
|
|
||||||
|
|
||||||
| Header | Packet ID | Length | Instruction | Params | CRC |
|
|
||||||
| ------------------- | --------- | ----------- | ----------- | ----------------- | ----------- |
|
|
||||||
| 0xFF 0xFF 0xFD 0x00 | ID | Len_L Len_H | Instr | Param 1 … Param N | CRC_L CRC_H |
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _build(cls, dxl_id: int, params: list[int], length: int, instruction: int) -> list[int]:
|
|
||||||
length = len(params) + 3
|
|
||||||
return [
|
|
||||||
0xFF, 0xFF, 0xFD, 0x00, # header
|
|
||||||
dxl_id, # servo id
|
|
||||||
dxl.DXL_LOBYTE(length), # length_l
|
|
||||||
dxl.DXL_HIBYTE(length), # length_h
|
|
||||||
instruction, # instruction type
|
|
||||||
*params, # data bytes
|
|
||||||
0x00, 0x00 # placeholder for CRC
|
|
||||||
] # fmt: skip
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def ping(
|
|
||||||
cls,
|
|
||||||
dxl_id: int,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Builds a "Ping" broadcast instruction.
|
|
||||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01
|
|
||||||
|
|
||||||
No parameters required.
|
|
||||||
"""
|
|
||||||
return cls.build(dxl_id=dxl_id, params=[], length=3, instruction=dxl.INST_PING)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def read(
|
|
||||||
cls,
|
|
||||||
dxl_id: int,
|
|
||||||
start_address: int,
|
|
||||||
data_length: int,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Builds a "Read" instruction.
|
|
||||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02
|
|
||||||
|
|
||||||
The parameters for Read (Protocol 2.0) are:
|
|
||||||
param[0] = start_address L
|
|
||||||
param[1] = start_address H
|
|
||||||
param[2] = data_length L
|
|
||||||
param[3] = data_length H
|
|
||||||
|
|
||||||
And 'length' = data_length + 5, where:
|
|
||||||
+1 is for instruction byte,
|
|
||||||
+2 is for the length bytes,
|
|
||||||
+2 is for the CRC at the end.
|
|
||||||
"""
|
|
||||||
params = [
|
|
||||||
dxl.DXL_LOBYTE(start_address),
|
|
||||||
dxl.DXL_HIBYTE(start_address),
|
|
||||||
dxl.DXL_LOBYTE(data_length),
|
|
||||||
dxl.DXL_HIBYTE(data_length),
|
|
||||||
]
|
|
||||||
length = len(params) + 3
|
|
||||||
# length = data_length + 5
|
|
||||||
return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_READ)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def write(
|
|
||||||
cls,
|
|
||||||
dxl_id: int,
|
|
||||||
value: int,
|
|
||||||
start_address: int,
|
|
||||||
data_length: int,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Builds a "Write" instruction.
|
|
||||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#write-0x03
|
|
||||||
|
|
||||||
The parameters for Write (Protocol 2.0) are:
|
|
||||||
param[0] = start_address L
|
|
||||||
param[1] = start_address H
|
|
||||||
param[2] = 1st Byte
|
|
||||||
param[3] = 2nd Byte
|
|
||||||
...
|
|
||||||
param[1+X] = X-th Byte
|
|
||||||
|
|
||||||
And 'length' = data_length + 5, where:
|
|
||||||
+1 is for instruction byte,
|
|
||||||
+2 is for the length bytes,
|
|
||||||
+2 is for the CRC at the end.
|
|
||||||
"""
|
|
||||||
data = _split_into_byte_chunks(value, data_length)
|
|
||||||
params = [
|
|
||||||
dxl.DXL_LOBYTE(start_address),
|
|
||||||
dxl.DXL_HIBYTE(start_address),
|
|
||||||
*data,
|
|
||||||
]
|
|
||||||
length = data_length + 5
|
|
||||||
return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_WRITE)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def sync_read(
|
|
||||||
cls,
|
|
||||||
dxl_ids: list[int],
|
|
||||||
start_address: int,
|
|
||||||
data_length: int,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Builds a "Sync_Read" broadcast instruction.
|
|
||||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-read-0x82
|
|
||||||
|
|
||||||
The parameters for Sync_Read (Protocol 2.0) are:
|
|
||||||
param[0] = start_address L
|
|
||||||
param[1] = start_address H
|
|
||||||
param[2] = data_length L
|
|
||||||
param[3] = data_length H
|
|
||||||
param[4+] = motor IDs to read from
|
|
||||||
|
|
||||||
And 'length' = (number_of_params + 7), where:
|
|
||||||
+1 is for instruction byte,
|
|
||||||
+2 is for the address bytes,
|
|
||||||
+2 is for the length bytes,
|
|
||||||
+2 is for the CRC at the end.
|
|
||||||
"""
|
|
||||||
params = [
|
|
||||||
dxl.DXL_LOBYTE(start_address),
|
|
||||||
dxl.DXL_HIBYTE(start_address),
|
|
||||||
dxl.DXL_LOBYTE(data_length),
|
|
||||||
dxl.DXL_HIBYTE(data_length),
|
|
||||||
*dxl_ids,
|
|
||||||
]
|
|
||||||
length = len(dxl_ids) + 7
|
|
||||||
return cls.build(
|
|
||||||
dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_READ
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def sync_write(
|
|
||||||
cls,
|
|
||||||
ids_values: dict[int, int],
|
|
||||||
start_address: int,
|
|
||||||
data_length: int,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Builds a "Sync_Write" broadcast instruction.
|
|
||||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-write-0x83
|
|
||||||
|
|
||||||
The parameters for Sync_Write (Protocol 2.0) are:
|
|
||||||
param[0] = start_address L
|
|
||||||
param[1] = start_address H
|
|
||||||
param[2] = data_length L
|
|
||||||
param[3] = data_length H
|
|
||||||
param[5] = [1st motor] ID
|
|
||||||
param[5+1] = [1st motor] 1st Byte
|
|
||||||
param[5+2] = [1st motor] 2nd Byte
|
|
||||||
...
|
|
||||||
param[5+X] = [1st motor] X-th Byte
|
|
||||||
param[6] = [2nd motor] ID
|
|
||||||
param[6+1] = [2nd motor] 1st Byte
|
|
||||||
param[6+2] = [2nd motor] 2nd Byte
|
|
||||||
...
|
|
||||||
param[6+X] = [2nd motor] X-th Byte
|
|
||||||
|
|
||||||
And 'length' = ((number_of_params * 1 + data_length) + 7), where:
|
|
||||||
+1 is for instruction byte,
|
|
||||||
+2 is for the address bytes,
|
|
||||||
+2 is for the length bytes,
|
|
||||||
+2 is for the CRC at the end.
|
|
||||||
"""
|
|
||||||
data = []
|
|
||||||
for id_, value in ids_values.items():
|
|
||||||
split_value = _split_into_byte_chunks(value, data_length)
|
|
||||||
data += [id_, *split_value]
|
|
||||||
params = [
|
|
||||||
dxl.DXL_LOBYTE(start_address),
|
|
||||||
dxl.DXL_HIBYTE(start_address),
|
|
||||||
dxl.DXL_LOBYTE(data_length),
|
|
||||||
dxl.DXL_HIBYTE(data_length),
|
|
||||||
*data,
|
|
||||||
]
|
|
||||||
length = len(ids_values) * (1 + data_length) + 7
|
|
||||||
return cls.build(
|
|
||||||
dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_WRITE
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MockStatusPacket(MockDynamixelPacketv2):
|
|
||||||
"""
|
|
||||||
Helper class to build valid Dynamixel Protocol 2.0 Status Packets.
|
|
||||||
|
|
||||||
Protocol 2.0 Status Packet structure
|
|
||||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#status-packet
|
|
||||||
|
|
||||||
| Header | Packet ID | Length | Instruction | Error | Params | CRC |
|
|
||||||
| ------------------- | --------- | ----------- | ----------- | ----- | ----------------- | ----------- |
|
|
||||||
| 0xFF 0xFF 0xFD 0x00 | ID | Len_L Len_H | 0x55 | Err | Param 1 … Param N | CRC_L CRC_H |
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _build(cls, dxl_id: int, params: list[int], length: int, error: int = 0) -> list[int]:
|
|
||||||
return [
|
|
||||||
0xFF, 0xFF, 0xFD, 0x00, # header
|
|
||||||
dxl_id, # servo id
|
|
||||||
dxl.DXL_LOBYTE(length), # length_l
|
|
||||||
dxl.DXL_HIBYTE(length), # length_h
|
|
||||||
0x55, # instruction = 'status'
|
|
||||||
error, # error
|
|
||||||
*params, # data bytes
|
|
||||||
0x00, 0x00 # placeholder for CRC
|
|
||||||
] # fmt: skip
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def ping(cls, dxl_id: int, model_nb: int = 1190, firm_ver: int = 50, error: int = 0) -> bytes:
|
|
||||||
"""
|
|
||||||
Builds a 'Ping' status packet.
|
|
||||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dxl_id (int): ID of the servo responding.
|
|
||||||
model_nb (int, optional): Desired 'model number' to be returned in the packet. Defaults to 1190
|
|
||||||
which corresponds to a XL330-M077-T.
|
|
||||||
firm_ver (int, optional): Desired 'firmware version' to be returned in the packet.
|
|
||||||
Defaults to 50.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bytes: The raw 'Ping' status packet ready to be sent through serial.
|
|
||||||
"""
|
|
||||||
params = [dxl.DXL_LOBYTE(model_nb), dxl.DXL_HIBYTE(model_nb), firm_ver]
|
|
||||||
length = 7
|
|
||||||
return cls.build(dxl_id, params=params, length=length, error=error)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def read(cls, dxl_id: int, value: int, param_length: int, error: int = 0) -> bytes:
|
|
||||||
"""
|
|
||||||
Builds a 'Read' status packet (also works for 'Sync Read')
|
|
||||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02
|
|
||||||
https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-read-0x82
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dxl_id (int): ID of the servo responding.
|
|
||||||
value (int): Desired value to be returned in the packet.
|
|
||||||
param_length (int): The address length as reported in the control table.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bytes: The raw 'Present_Position' status packet ready to be sent through serial.
|
|
||||||
"""
|
|
||||||
params = _split_into_byte_chunks(value, param_length)
|
|
||||||
length = param_length + 4
|
|
||||||
return cls.build(dxl_id, params=params, length=length, error=error)
|
|
||||||
|
|
||||||
|
|
||||||
class MockPortHandler(dxl.PortHandler):
|
|
||||||
"""
|
|
||||||
This class overwrite the 'setupPort' method of the Dynamixel PortHandler because it can specify
|
|
||||||
baudrates that are not supported with a serial port on MacOS.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def setupPort(self, cflag_baud): # noqa: N802
|
|
||||||
if self.is_open:
|
|
||||||
self.closePort()
|
|
||||||
|
|
||||||
self.ser = serial.Serial(
|
|
||||||
port=self.port_name,
|
|
||||||
# baudrate=self.baudrate, <- This will fail on MacOS
|
|
||||||
# parity = serial.PARITY_ODD,
|
|
||||||
# stopbits = serial.STOPBITS_TWO,
|
|
||||||
bytesize=serial.EIGHTBITS,
|
|
||||||
timeout=0,
|
|
||||||
)
|
|
||||||
self.is_open = True
|
|
||||||
self.ser.reset_input_buffer()
|
|
||||||
self.tx_time_per_byte = (1000.0 / self.baudrate) * 10.0
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class MockMotors(MockSerial):
|
|
||||||
"""
|
|
||||||
This class will simulate physical motors by responding with valid status packets upon receiving some
|
|
||||||
instruction packets. It is meant to test MotorsBus classes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def stubs(self) -> dict[str, WaitableStub]:
|
|
||||||
return super().stubs
|
|
||||||
|
|
||||||
def stub(self, *, name=None, **kwargs):
|
|
||||||
new_stub = WaitableStub(**kwargs)
|
|
||||||
self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub
|
|
||||||
return new_stub
|
|
||||||
|
|
||||||
def build_broadcast_ping_stub(
|
|
||||||
self, ids_models: dict[int, list[int]] | None = None, num_invalid_try: int = 0
|
|
||||||
) -> str:
|
|
||||||
ping_request = MockInstructionPacket.ping(dxl.BROADCAST_ID)
|
|
||||||
return_packets = b"".join(MockStatusPacket.ping(id_, model) for id_, model in ids_models.items())
|
|
||||||
ping_response = self._build_send_fn(return_packets, num_invalid_try)
|
|
||||||
|
|
||||||
stub_name = "Ping_" + "_".join([str(id_) for id_ in ids_models])
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=ping_request,
|
|
||||||
send_fn=ping_response,
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
def build_ping_stub(
|
|
||||||
self, dxl_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0, error: int = 0
|
|
||||||
) -> str:
|
|
||||||
ping_request = MockInstructionPacket.ping(dxl_id)
|
|
||||||
return_packet = MockStatusPacket.ping(dxl_id, model_nb, firm_ver, error)
|
|
||||||
ping_response = self._build_send_fn(return_packet, num_invalid_try)
|
|
||||||
stub_name = f"Ping_{dxl_id}"
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=ping_request,
|
|
||||||
send_fn=ping_response,
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
def build_read_stub(
|
|
||||||
self,
|
|
||||||
address: int,
|
|
||||||
length: int,
|
|
||||||
dxl_id: int,
|
|
||||||
value: int,
|
|
||||||
reply: bool = True,
|
|
||||||
error: int = 0,
|
|
||||||
num_invalid_try: int = 0,
|
|
||||||
) -> str:
|
|
||||||
read_request = MockInstructionPacket.read(dxl_id, address, length)
|
|
||||||
return_packet = MockStatusPacket.read(dxl_id, value, length, error) if reply else b""
|
|
||||||
read_response = self._build_send_fn(return_packet, num_invalid_try)
|
|
||||||
stub_name = f"Read_{address}_{length}_{dxl_id}_{value}_{error}"
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=read_request,
|
|
||||||
send_fn=read_response,
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
def build_write_stub(
|
|
||||||
self,
|
|
||||||
address: int,
|
|
||||||
length: int,
|
|
||||||
dxl_id: int,
|
|
||||||
value: int,
|
|
||||||
reply: bool = True,
|
|
||||||
error: int = 0,
|
|
||||||
num_invalid_try: int = 0,
|
|
||||||
) -> str:
|
|
||||||
sync_read_request = MockInstructionPacket.write(dxl_id, value, address, length)
|
|
||||||
return_packet = MockStatusPacket.build(dxl_id, params=[], length=4, error=error) if reply else b""
|
|
||||||
stub_name = f"Write_{address}_{length}_{dxl_id}"
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=sync_read_request,
|
|
||||||
send_fn=self._build_send_fn(return_packet, num_invalid_try),
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
def build_sync_read_stub(
|
|
||||||
self,
|
|
||||||
address: int,
|
|
||||||
length: int,
|
|
||||||
ids_values: dict[int, int],
|
|
||||||
reply: bool = True,
|
|
||||||
num_invalid_try: int = 0,
|
|
||||||
) -> str:
|
|
||||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
|
||||||
return_packets = (
|
|
||||||
b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
|
|
||||||
if reply
|
|
||||||
else b""
|
|
||||||
)
|
|
||||||
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
|
|
||||||
stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=sync_read_request,
|
|
||||||
send_fn=sync_read_response,
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
def build_sequential_sync_read_stub(
|
|
||||||
self, address: int, length: int, ids_values: dict[int, list[int]] | None = None
|
|
||||||
) -> str:
|
|
||||||
sequence_length = len(next(iter(ids_values.values())))
|
|
||||||
assert all(len(positions) == sequence_length for positions in ids_values.values())
|
|
||||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
|
||||||
sequential_packets = []
|
|
||||||
for count in range(sequence_length):
|
|
||||||
return_packets = b"".join(
|
|
||||||
MockStatusPacket.read(id_, positions[count], length) for id_, positions in ids_values.items()
|
|
||||||
)
|
|
||||||
sequential_packets.append(return_packets)
|
|
||||||
|
|
||||||
sync_read_response = self._build_sequential_send_fn(sequential_packets)
|
|
||||||
stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=sync_read_request,
|
|
||||||
send_fn=sync_read_response,
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
def build_sync_write_stub(
|
|
||||||
self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0
|
|
||||||
) -> str:
|
|
||||||
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
|
|
||||||
stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=sync_read_request,
|
|
||||||
send_fn=self._build_send_fn(b"", num_invalid_try),
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
|
|
||||||
def send_fn(_call_count: int) -> bytes:
|
|
||||||
if num_invalid_try >= _call_count:
|
|
||||||
return b""
|
|
||||||
return packet
|
|
||||||
|
|
||||||
return send_fn
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_sequential_send_fn(packets: list[bytes]) -> Callable[[int], bytes]:
|
|
||||||
def send_fn(_call_count: int) -> bytes:
|
|
||||||
return packets[_call_count - 1]
|
|
||||||
|
|
||||||
return send_fn
|
|
||||||
@ -1,444 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import abc
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import scservo_sdk as scs
|
|
||||||
import serial
|
|
||||||
from mock_serial import MockSerial
|
|
||||||
|
|
||||||
from lerobot.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout
|
|
||||||
|
|
||||||
from .mock_serial_patch import WaitableStub
|
|
||||||
|
|
||||||
|
|
||||||
class MockFeetechPacket(abc.ABC):
|
|
||||||
@classmethod
|
|
||||||
def build(cls, scs_id: int, params: list[int], length: int, *args, **kwargs) -> bytes:
|
|
||||||
packet = cls._build(scs_id, params, length, *args, **kwargs)
|
|
||||||
packet = cls._add_checksum(packet)
|
|
||||||
return bytes(packet)
|
|
||||||
|
|
||||||
@abc.abstractclassmethod
|
|
||||||
def _build(cls, scs_id: int, params: list[int], length: int, *args, **kwargs) -> list[int]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _add_checksum(packet: list[int]) -> list[int]:
|
|
||||||
checksum = 0
|
|
||||||
for id_ in range(2, len(packet) - 1): # except header & checksum
|
|
||||||
checksum += packet[id_]
|
|
||||||
|
|
||||||
packet[-1] = ~checksum & 0xFF
|
|
||||||
|
|
||||||
return packet
|
|
||||||
|
|
||||||
|
|
||||||
class MockInstructionPacket(MockFeetechPacket):
|
|
||||||
"""
|
|
||||||
Helper class to build valid Feetech Instruction Packets.
|
|
||||||
|
|
||||||
Instruction Packet structure
|
|
||||||
(from https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf)
|
|
||||||
|
|
||||||
| Header | Packet ID | Length | Instruction | Params | Checksum |
|
|
||||||
| --------- | --------- | ------ | ----------- | ----------------- | -------- |
|
|
||||||
| 0xFF 0xFF | ID | Len | Instr | Param 1 … Param N | Sum |
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _build(cls, scs_id: int, params: list[int], length: int, instruction: int) -> list[int]:
|
|
||||||
return [
|
|
||||||
0xFF, 0xFF, # header
|
|
||||||
scs_id, # servo id
|
|
||||||
length, # length
|
|
||||||
instruction, # instruction type
|
|
||||||
*params, # data bytes
|
|
||||||
0x00, # placeholder for checksum
|
|
||||||
] # fmt: skip
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def ping(
|
|
||||||
cls,
|
|
||||||
scs_id: int,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Builds a "Ping" broadcast instruction.
|
|
||||||
|
|
||||||
No parameters required.
|
|
||||||
"""
|
|
||||||
return cls.build(scs_id=scs_id, params=[], length=2, instruction=scs.INST_PING)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def read(
|
|
||||||
cls,
|
|
||||||
scs_id: int,
|
|
||||||
start_address: int,
|
|
||||||
data_length: int,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Builds a "Read" instruction.
|
|
||||||
|
|
||||||
The parameters for Read are:
|
|
||||||
param[0] = start_address
|
|
||||||
param[1] = data_length
|
|
||||||
|
|
||||||
And 'length' = 4, where:
|
|
||||||
+1 is for instruction byte,
|
|
||||||
+1 is for the address byte,
|
|
||||||
+1 is for the length bytes,
|
|
||||||
+1 is for the checksum at the end.
|
|
||||||
"""
|
|
||||||
params = [start_address, data_length]
|
|
||||||
length = 4
|
|
||||||
return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_READ)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def write(
|
|
||||||
cls,
|
|
||||||
scs_id: int,
|
|
||||||
value: int,
|
|
||||||
start_address: int,
|
|
||||||
data_length: int,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Builds a "Write" instruction.
|
|
||||||
|
|
||||||
The parameters for Write are:
|
|
||||||
param[0] = start_address L
|
|
||||||
param[1] = start_address H
|
|
||||||
param[2] = 1st Byte
|
|
||||||
param[3] = 2nd Byte
|
|
||||||
...
|
|
||||||
param[1+X] = X-th Byte
|
|
||||||
|
|
||||||
And 'length' = data_length + 3, where:
|
|
||||||
+1 is for instruction byte,
|
|
||||||
+1 is for the length bytes,
|
|
||||||
+1 is for the checksum at the end.
|
|
||||||
"""
|
|
||||||
data = _split_into_byte_chunks(value, data_length)
|
|
||||||
params = [start_address, *data]
|
|
||||||
length = data_length + 3
|
|
||||||
return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_WRITE)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def sync_read(
|
|
||||||
cls,
|
|
||||||
scs_ids: list[int],
|
|
||||||
start_address: int,
|
|
||||||
data_length: int,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Builds a "Sync_Read" broadcast instruction.
|
|
||||||
|
|
||||||
The parameters for Sync Read are:
|
|
||||||
param[0] = start_address
|
|
||||||
param[1] = data_length
|
|
||||||
param[2+] = motor IDs to read from
|
|
||||||
|
|
||||||
And 'length' = (number_of_params + 4), where:
|
|
||||||
+1 is for instruction byte,
|
|
||||||
+1 is for the address byte,
|
|
||||||
+1 is for the length bytes,
|
|
||||||
+1 is for the checksum at the end.
|
|
||||||
"""
|
|
||||||
params = [start_address, data_length, *scs_ids]
|
|
||||||
length = len(scs_ids) + 4
|
|
||||||
return cls.build(
|
|
||||||
scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_READ
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def sync_write(
|
|
||||||
cls,
|
|
||||||
ids_values: dict[int, int],
|
|
||||||
start_address: int,
|
|
||||||
data_length: int,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Builds a "Sync_Write" broadcast instruction.
|
|
||||||
|
|
||||||
The parameters for Sync_Write are:
|
|
||||||
param[0] = start_address
|
|
||||||
param[1] = data_length
|
|
||||||
param[2] = [1st motor] ID
|
|
||||||
param[2+1] = [1st motor] 1st Byte
|
|
||||||
param[2+2] = [1st motor] 2nd Byte
|
|
||||||
...
|
|
||||||
param[5+X] = [1st motor] X-th Byte
|
|
||||||
param[6] = [2nd motor] ID
|
|
||||||
param[6+1] = [2nd motor] 1st Byte
|
|
||||||
param[6+2] = [2nd motor] 2nd Byte
|
|
||||||
...
|
|
||||||
param[6+X] = [2nd motor] X-th Byte
|
|
||||||
|
|
||||||
And 'length' = ((number_of_params * 1 + data_length) + 4), where:
|
|
||||||
+1 is for instruction byte,
|
|
||||||
+1 is for the address byte,
|
|
||||||
+1 is for the length bytes,
|
|
||||||
+1 is for the checksum at the end.
|
|
||||||
"""
|
|
||||||
data = []
|
|
||||||
for id_, value in ids_values.items():
|
|
||||||
split_value = _split_into_byte_chunks(value, data_length)
|
|
||||||
data += [id_, *split_value]
|
|
||||||
params = [start_address, data_length, *data]
|
|
||||||
length = len(ids_values) * (1 + data_length) + 4
|
|
||||||
return cls.build(
|
|
||||||
scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_WRITE
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MockStatusPacket(MockFeetechPacket):
|
|
||||||
"""
|
|
||||||
Helper class to build valid Feetech Status Packets.
|
|
||||||
|
|
||||||
Status Packet structure
|
|
||||||
(from https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf)
|
|
||||||
|
|
||||||
| Header | Packet ID | Length | Error | Params | Checksum |
|
|
||||||
| --------- | --------- | ------ | ----- | ----------------- | -------- |
|
|
||||||
| 0xFF 0xFF | ID | Len | Err | Param 1 … Param N | Sum |
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _build(cls, scs_id: int, params: list[int], length: int, error: int = 0) -> list[int]:
|
|
||||||
return [
|
|
||||||
0xFF, 0xFF, # header
|
|
||||||
scs_id, # servo id
|
|
||||||
length, # length
|
|
||||||
error, # status
|
|
||||||
*params, # data bytes
|
|
||||||
0x00, # placeholder for checksum
|
|
||||||
] # fmt: skip
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def ping(cls, scs_id: int, error: int = 0) -> bytes:
|
|
||||||
"""Builds a 'Ping' status packet.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scs_id (int): ID of the servo responding.
|
|
||||||
error (int, optional): Error to be returned. Defaults to 0 (success).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bytes: The raw 'Ping' status packet ready to be sent through serial.
|
|
||||||
"""
|
|
||||||
return cls.build(scs_id, params=[], length=2, error=error)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def read(cls, scs_id: int, value: int, param_length: int, error: int = 0) -> bytes:
|
|
||||||
"""Builds a 'Read' status packet.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scs_id (int): ID of the servo responding.
|
|
||||||
value (int): Desired value to be returned in the packet.
|
|
||||||
param_length (int): The address length as reported in the control table.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bytes: The raw 'Sync Read' status packet ready to be sent through serial.
|
|
||||||
"""
|
|
||||||
params = _split_into_byte_chunks(value, param_length)
|
|
||||||
length = param_length + 2
|
|
||||||
return cls.build(scs_id, params=params, length=length, error=error)
|
|
||||||
|
|
||||||
|
|
||||||
class MockPortHandler(scs.PortHandler):
|
|
||||||
"""
|
|
||||||
This class overwrite the 'setupPort' method of the Feetech PortHandler because it can specify
|
|
||||||
baudrates that are not supported with a serial port on MacOS.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def setupPort(self, cflag_baud): # noqa: N802
|
|
||||||
if self.is_open:
|
|
||||||
self.closePort()
|
|
||||||
|
|
||||||
self.ser = serial.Serial(
|
|
||||||
port=self.port_name,
|
|
||||||
# baudrate=self.baudrate, <- This will fail on MacOS
|
|
||||||
# parity = serial.PARITY_ODD,
|
|
||||||
# stopbits = serial.STOPBITS_TWO,
|
|
||||||
bytesize=serial.EIGHTBITS,
|
|
||||||
timeout=0,
|
|
||||||
)
|
|
||||||
self.is_open = True
|
|
||||||
self.ser.reset_input_buffer()
|
|
||||||
self.tx_time_per_byte = (1000.0 / self.baudrate) * 10.0
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def setPacketTimeout(self, packet_length): # noqa: N802
|
|
||||||
return patch_setPacketTimeout(self, packet_length)
|
|
||||||
|
|
||||||
|
|
||||||
class MockMotors(MockSerial):
|
|
||||||
"""
|
|
||||||
This class will simulate physical motors by responding with valid status packets upon receiving some
|
|
||||||
instruction packets. It is meant to test MotorsBus classes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def stubs(self) -> dict[str, WaitableStub]:
|
|
||||||
return super().stubs
|
|
||||||
|
|
||||||
def stub(self, *, name=None, **kwargs):
|
|
||||||
new_stub = WaitableStub(**kwargs)
|
|
||||||
self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub
|
|
||||||
return new_stub
|
|
||||||
|
|
||||||
def build_broadcast_ping_stub(self, ids: list[int] | None = None, num_invalid_try: int = 0) -> str:
|
|
||||||
ping_request = MockInstructionPacket.ping(scs.BROADCAST_ID)
|
|
||||||
return_packets = b"".join(MockStatusPacket.ping(id_) for id_ in ids)
|
|
||||||
ping_response = self._build_send_fn(return_packets, num_invalid_try)
|
|
||||||
stub_name = "Ping_" + "_".join([str(id_) for id_ in ids])
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=ping_request,
|
|
||||||
send_fn=ping_response,
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0, error: int = 0) -> str:
|
|
||||||
ping_request = MockInstructionPacket.ping(scs_id)
|
|
||||||
return_packet = MockStatusPacket.ping(scs_id, error)
|
|
||||||
ping_response = self._build_send_fn(return_packet, num_invalid_try)
|
|
||||||
stub_name = f"Ping_{scs_id}_{error}"
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=ping_request,
|
|
||||||
send_fn=ping_response,
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
def build_read_stub(
|
|
||||||
self,
|
|
||||||
address: int,
|
|
||||||
length: int,
|
|
||||||
scs_id: int,
|
|
||||||
value: int,
|
|
||||||
reply: bool = True,
|
|
||||||
error: int = 0,
|
|
||||||
num_invalid_try: int = 0,
|
|
||||||
) -> str:
|
|
||||||
read_request = MockInstructionPacket.read(scs_id, address, length)
|
|
||||||
return_packet = MockStatusPacket.read(scs_id, value, length, error) if reply else b""
|
|
||||||
read_response = self._build_send_fn(return_packet, num_invalid_try)
|
|
||||||
stub_name = f"Read_{address}_{length}_{scs_id}_{value}_{error}"
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=read_request,
|
|
||||||
send_fn=read_response,
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
def build_write_stub(
|
|
||||||
self,
|
|
||||||
address: int,
|
|
||||||
length: int,
|
|
||||||
scs_id: int,
|
|
||||||
value: int,
|
|
||||||
reply: bool = True,
|
|
||||||
error: int = 0,
|
|
||||||
num_invalid_try: int = 0,
|
|
||||||
) -> str:
|
|
||||||
sync_read_request = MockInstructionPacket.write(scs_id, value, address, length)
|
|
||||||
return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error) if reply else b""
|
|
||||||
stub_name = f"Write_{address}_{length}_{scs_id}"
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=sync_read_request,
|
|
||||||
send_fn=self._build_send_fn(return_packet, num_invalid_try),
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
def build_sync_read_stub(
|
|
||||||
self,
|
|
||||||
address: int,
|
|
||||||
length: int,
|
|
||||||
ids_values: dict[int, int],
|
|
||||||
reply: bool = True,
|
|
||||||
num_invalid_try: int = 0,
|
|
||||||
) -> str:
|
|
||||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
|
||||||
return_packets = (
|
|
||||||
b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items())
|
|
||||||
if reply
|
|
||||||
else b""
|
|
||||||
)
|
|
||||||
sync_read_response = self._build_send_fn(return_packets, num_invalid_try)
|
|
||||||
stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=sync_read_request,
|
|
||||||
send_fn=sync_read_response,
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
def build_sequential_sync_read_stub(
|
|
||||||
self, address: int, length: int, ids_values: dict[int, list[int]] | None = None
|
|
||||||
) -> str:
|
|
||||||
sequence_length = len(next(iter(ids_values.values())))
|
|
||||||
assert all(len(positions) == sequence_length for positions in ids_values.values())
|
|
||||||
sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length)
|
|
||||||
sequential_packets = []
|
|
||||||
for count in range(sequence_length):
|
|
||||||
return_packets = b"".join(
|
|
||||||
MockStatusPacket.read(id_, positions[count], length) for id_, positions in ids_values.items()
|
|
||||||
)
|
|
||||||
sequential_packets.append(return_packets)
|
|
||||||
|
|
||||||
sync_read_response = self._build_sequential_send_fn(sequential_packets)
|
|
||||||
stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=sync_read_request,
|
|
||||||
send_fn=sync_read_response,
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
def build_sync_write_stub(
|
|
||||||
self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0
|
|
||||||
) -> str:
|
|
||||||
sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length)
|
|
||||||
stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values])
|
|
||||||
self.stub(
|
|
||||||
name=stub_name,
|
|
||||||
receive_bytes=sync_read_request,
|
|
||||||
send_fn=self._build_send_fn(b"", num_invalid_try),
|
|
||||||
)
|
|
||||||
return stub_name
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]:
|
|
||||||
def send_fn(_call_count: int) -> bytes:
|
|
||||||
if num_invalid_try >= _call_count:
|
|
||||||
return b""
|
|
||||||
return packet
|
|
||||||
|
|
||||||
return send_fn
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_sequential_send_fn(packets: list[bytes]) -> Callable[[int], bytes]:
|
|
||||||
def send_fn(_call_count: int) -> bytes:
|
|
||||||
return packets[_call_count - 1]
|
|
||||||
|
|
||||||
return send_fn
|
|
||||||
@ -1,152 +0,0 @@
|
|||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# ruff: noqa: N802
|
|
||||||
|
|
||||||
from lerobot.motors.motors_bus import (
|
|
||||||
Motor,
|
|
||||||
MotorsBus,
|
|
||||||
)
|
|
||||||
|
|
||||||
DUMMY_CTRL_TABLE_1 = {
|
|
||||||
"Firmware_Version": (0, 1),
|
|
||||||
"Model_Number": (1, 2),
|
|
||||||
"Present_Position": (3, 4),
|
|
||||||
"Goal_Position": (11, 2),
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_CTRL_TABLE_2 = {
|
|
||||||
"Model_Number": (0, 2),
|
|
||||||
"Firmware_Version": (2, 1),
|
|
||||||
"Present_Position": (3, 4),
|
|
||||||
"Present_Velocity": (7, 4),
|
|
||||||
"Goal_Position": (11, 4),
|
|
||||||
"Goal_Velocity": (15, 4),
|
|
||||||
"Lock": (19, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_CTRL_TABLE = {
|
|
||||||
"model_1": DUMMY_CTRL_TABLE_1,
|
|
||||||
"model_2": DUMMY_CTRL_TABLE_2,
|
|
||||||
"model_3": DUMMY_CTRL_TABLE_2,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_BAUDRATE_TABLE = {
|
|
||||||
0: 1_000_000,
|
|
||||||
1: 500_000,
|
|
||||||
2: 250_000,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_BAUDRATE_TABLE = {
|
|
||||||
"model_1": DUMMY_BAUDRATE_TABLE,
|
|
||||||
"model_2": DUMMY_BAUDRATE_TABLE,
|
|
||||||
"model_3": DUMMY_BAUDRATE_TABLE,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_ENCODING_TABLE = {
|
|
||||||
"Present_Position": 8,
|
|
||||||
"Goal_Position": 10,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_ENCODING_TABLE = {
|
|
||||||
"model_1": DUMMY_ENCODING_TABLE,
|
|
||||||
"model_2": DUMMY_ENCODING_TABLE,
|
|
||||||
"model_3": DUMMY_ENCODING_TABLE,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_NUMBER_TABLE = {
|
|
||||||
"model_1": 1234,
|
|
||||||
"model_2": 5678,
|
|
||||||
"model_3": 5799,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_MODEL_RESOLUTION_TABLE = {
|
|
||||||
"model_1": 4096,
|
|
||||||
"model_2": 1024,
|
|
||||||
"model_3": 4096,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class MockPortHandler:
|
|
||||||
def __init__(self, port_name):
|
|
||||||
self.is_open: bool = False
|
|
||||||
self.baudrate: int
|
|
||||||
self.packet_start_time: float
|
|
||||||
self.packet_timeout: float
|
|
||||||
self.tx_time_per_byte: float
|
|
||||||
self.is_using: bool = False
|
|
||||||
self.port_name: str = port_name
|
|
||||||
self.ser = None
|
|
||||||
|
|
||||||
def openPort(self):
|
|
||||||
self.is_open = True
|
|
||||||
return self.is_open
|
|
||||||
|
|
||||||
def closePort(self):
|
|
||||||
self.is_open = False
|
|
||||||
|
|
||||||
def clearPort(self): ...
|
|
||||||
def setPortName(self, port_name):
|
|
||||||
self.port_name = port_name
|
|
||||||
|
|
||||||
def getPortName(self):
|
|
||||||
return self.port_name
|
|
||||||
|
|
||||||
def setBaudRate(self, baudrate):
|
|
||||||
self.baudrate: baudrate
|
|
||||||
|
|
||||||
def getBaudRate(self):
|
|
||||||
return self.baudrate
|
|
||||||
|
|
||||||
def getBytesAvailable(self): ...
|
|
||||||
def readPort(self, length): ...
|
|
||||||
def writePort(self, packet): ...
|
|
||||||
def setPacketTimeout(self, packet_length): ...
|
|
||||||
def setPacketTimeoutMillis(self, msec): ...
|
|
||||||
def isPacketTimeout(self): ...
|
|
||||||
def getCurrentTime(self): ...
|
|
||||||
def getTimeSinceStart(self): ...
|
|
||||||
def setupPort(self, cflag_baud): ...
|
|
||||||
def getCFlagBaud(self, baudrate): ...
|
|
||||||
|
|
||||||
|
|
||||||
class MockMotorsBus(MotorsBus):
|
|
||||||
available_baudrates = [500_000, 1_000_000]
|
|
||||||
default_timeout = 1000
|
|
||||||
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
|
|
||||||
model_ctrl_table = DUMMY_MODEL_CTRL_TABLE
|
|
||||||
model_encoding_table = DUMMY_MODEL_ENCODING_TABLE
|
|
||||||
model_number_table = DUMMY_MODEL_NUMBER_TABLE
|
|
||||||
model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE
|
|
||||||
normalized_data = ["Present_Position", "Goal_Position"]
|
|
||||||
|
|
||||||
def __init__(self, port: str, motors: dict[str, Motor]):
|
|
||||||
super().__init__(port, motors)
|
|
||||||
self.port_handler = MockPortHandler(port)
|
|
||||||
|
|
||||||
def _assert_protocol_is_compatible(self, instruction_name): ...
|
|
||||||
def _handshake(self): ...
|
|
||||||
def _find_single_motor(self, motor, initial_baudrate): ...
|
|
||||||
def configure_motors(self): ...
|
|
||||||
def is_calibrated(self): ...
|
|
||||||
def read_calibration(self): ...
|
|
||||||
def write_calibration(self, calibration_dict): ...
|
|
||||||
def disable_torque(self, motors, num_retry): ...
|
|
||||||
def _disable_torque(self, motor, model, num_retry): ...
|
|
||||||
def enable_torque(self, motors, num_retry): ...
|
|
||||||
def _get_half_turn_homings(self, positions): ...
|
|
||||||
def _encode_sign(self, data_name, ids_values): ...
|
|
||||||
def _decode_sign(self, data_name, ids_values): ...
|
|
||||||
def _split_into_byte_chunks(self, value, length): ...
|
|
||||||
def broadcast_ping(self, num_retry, raise_on_error): ...
|
|
||||||
@ -1,128 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import random
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from functools import cached_property
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from lerobot.cameras import CameraConfig, make_cameras_from_configs
|
|
||||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
|
||||||
from lerobot.robots import Robot, RobotConfig
|
|
||||||
|
|
||||||
|
|
||||||
@RobotConfig.register_subclass("mock_robot")
|
|
||||||
@dataclass
|
|
||||||
class MockRobotConfig(RobotConfig):
|
|
||||||
n_motors: int = 3
|
|
||||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
|
||||||
random_values: bool = True
|
|
||||||
static_values: list[float] | None = None
|
|
||||||
calibrated: bool = True
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.n_motors < 1:
|
|
||||||
raise ValueError(self.n_motors)
|
|
||||||
|
|
||||||
if self.random_values and self.static_values is not None:
|
|
||||||
raise ValueError("Choose either random values or static values")
|
|
||||||
|
|
||||||
if self.static_values is not None and len(self.static_values) != self.n_motors:
|
|
||||||
raise ValueError("Specify the same number of static values as motors")
|
|
||||||
|
|
||||||
if len(self.cameras) > 0:
|
|
||||||
raise NotImplementedError # TODO with the cameras refactor
|
|
||||||
|
|
||||||
|
|
||||||
class MockRobot(Robot):
|
|
||||||
"""Mock Robot to be used for testing."""
|
|
||||||
|
|
||||||
config_class = MockRobotConfig
|
|
||||||
name = "mock_robot"
|
|
||||||
|
|
||||||
def __init__(self, config: MockRobotConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
self.config = config
|
|
||||||
self._is_connected = False
|
|
||||||
self._is_calibrated = config.calibrated
|
|
||||||
self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)]
|
|
||||||
self.cameras = make_cameras_from_configs(config.cameras)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _motors_ft(self) -> dict[str, type]:
|
|
||||||
return {f"{motor}.pos": float for motor in self.motors}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _cameras_ft(self) -> dict[str, tuple]:
|
|
||||||
return {
|
|
||||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
|
||||||
}
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def observation_features(self) -> dict[str, type | tuple]:
|
|
||||||
return {**self._motors_ft, **self._cameras_ft}
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def action_features(self) -> dict[str, type]:
|
|
||||||
return self._motors_ft
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_connected(self) -> bool:
|
|
||||||
return self._is_connected
|
|
||||||
|
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
|
||||||
if self.is_connected:
|
|
||||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
|
||||||
|
|
||||||
self._is_connected = True
|
|
||||||
if calibrate:
|
|
||||||
self.calibrate()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_calibrated(self) -> bool:
|
|
||||||
return self._is_calibrated
|
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
|
||||||
|
|
||||||
self._is_calibrated = True
|
|
||||||
|
|
||||||
def configure(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_observation(self) -> dict[str, Any]:
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
|
||||||
|
|
||||||
if self.config.random_values:
|
|
||||||
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
|
|
||||||
}
|
|
||||||
|
|
||||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
|
||||||
|
|
||||||
return action
|
|
||||||
|
|
||||||
def disconnect(self) -> None:
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
|
||||||
|
|
||||||
self._is_connected = False
|
|
||||||
@ -1,51 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
|
|
||||||
from mock_serial.mock_serial import Stub
|
|
||||||
|
|
||||||
|
|
||||||
class WaitableStub(Stub):
|
|
||||||
"""
|
|
||||||
In some situations, a test might be checking if a stub has been called before `MockSerial` thread had time
|
|
||||||
to read, match, and call the stub. In these situations, the test can fail randomly.
|
|
||||||
|
|
||||||
Use `wait_called()` or `wait_calls()` to block until the stub is called, avoiding race conditions.
|
|
||||||
|
|
||||||
Proposed fix:
|
|
||||||
https://github.com/benthorner/mock_serial/pull/3
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._event = threading.Event()
|
|
||||||
|
|
||||||
def call(self):
|
|
||||||
self._event.set()
|
|
||||||
return super().call()
|
|
||||||
|
|
||||||
def wait_called(self, timeout: float = 1.0):
|
|
||||||
return self._event.wait(timeout)
|
|
||||||
|
|
||||||
def wait_calls(self, min_calls: int = 1, timeout: float = 1.0):
|
|
||||||
start = time.perf_counter()
|
|
||||||
while time.perf_counter() - start < timeout:
|
|
||||||
if self.calls >= min_calls:
|
|
||||||
return self.calls
|
|
||||||
time.sleep(0.005)
|
|
||||||
raise TimeoutError(f"Stub not called {min_calls} times within {timeout} seconds.")
|
|
||||||
@ -1,110 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import random
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from functools import cached_property
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
|
||||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig
|
|
||||||
|
|
||||||
|
|
||||||
@TeleoperatorConfig.register_subclass("mock_teleop")
|
|
||||||
@dataclass
|
|
||||||
class MockTeleopConfig(TeleoperatorConfig):
|
|
||||||
n_motors: int = 3
|
|
||||||
random_values: bool = True
|
|
||||||
static_values: list[float] | None = None
|
|
||||||
calibrated: bool = True
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.n_motors < 1:
|
|
||||||
raise ValueError(self.n_motors)
|
|
||||||
|
|
||||||
if self.random_values and self.static_values is not None:
|
|
||||||
raise ValueError("Choose either random values or static values")
|
|
||||||
|
|
||||||
if self.static_values is not None and len(self.static_values) != self.n_motors:
|
|
||||||
raise ValueError("Specify the same number of static values as motors")
|
|
||||||
|
|
||||||
|
|
||||||
class MockTeleop(Teleoperator):
|
|
||||||
"""Mock Teleoperator to be used for testing."""
|
|
||||||
|
|
||||||
config_class = MockTeleopConfig
|
|
||||||
name = "mock_teleop"
|
|
||||||
|
|
||||||
def __init__(self, config: MockTeleopConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
self.config = config
|
|
||||||
self._is_connected = False
|
|
||||||
self._is_calibrated = config.calibrated
|
|
||||||
self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)]
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def action_features(self) -> dict[str, type]:
|
|
||||||
return {f"{motor}.pos": float for motor in self.motors}
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def feedback_features(self) -> dict[str, type]:
|
|
||||||
return {f"{motor}.pos": float for motor in self.motors}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_connected(self) -> bool:
|
|
||||||
return self._is_connected
|
|
||||||
|
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
|
||||||
if self.is_connected:
|
|
||||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
|
||||||
|
|
||||||
self._is_connected = True
|
|
||||||
if calibrate:
|
|
||||||
self.calibrate()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_calibrated(self) -> bool:
|
|
||||||
return self._is_calibrated
|
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
|
||||||
|
|
||||||
self._is_calibrated = True
|
|
||||||
|
|
||||||
def configure(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_action(self) -> dict[str, Any]:
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
|
||||||
|
|
||||||
if self.config.random_values:
|
|
||||||
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
|
|
||||||
}
|
|
||||||
|
|
||||||
def send_feedback(self, feedback: dict[str, Any]) -> None:
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
|
||||||
|
|
||||||
def disconnect(self) -> None:
|
|
||||||
if not self.is_connected:
|
|
||||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
|
||||||
|
|
||||||
self._is_connected = False
|
|
||||||
@ -1,416 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from collections.abc import Generator
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
|
||||||
from lerobot.motors.dynamixel import MODEL_NUMBER_TABLE, DynamixelMotorsBus
|
|
||||||
from lerobot.motors.dynamixel.tables import X_SERIES_CONTROL_TABLE
|
|
||||||
from lerobot.utils.encoding_utils import encode_twos_complement
|
|
||||||
|
|
||||||
try:
|
|
||||||
import dynamixel_sdk as dxl
|
|
||||||
|
|
||||||
from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler
|
|
||||||
except (ImportError, ModuleNotFoundError):
|
|
||||||
pytest.skip("dynamixel_sdk not available", allow_module_level=True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def patch_port_handler():
|
|
||||||
if sys.platform == "darwin":
|
|
||||||
with patch.object(dxl, "PortHandler", MockPortHandler):
|
|
||||||
yield
|
|
||||||
else:
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_motors() -> Generator[MockMotors, None, None]:
|
|
||||||
motors = MockMotors()
|
|
||||||
motors.open()
|
|
||||||
yield motors
|
|
||||||
motors.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def dummy_motors() -> dict[str, Motor]:
|
|
||||||
return {
|
|
||||||
"dummy_1": Motor(1, "xl430-w250", MotorNormMode.RANGE_M100_100),
|
|
||||||
"dummy_2": Motor(2, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
|
||||||
"dummy_3": Motor(3, "xl330-m077", MotorNormMode.RANGE_M100_100),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def dummy_calibration(dummy_motors) -> dict[str, MotorCalibration]:
|
|
||||||
drive_modes = [0, 1, 0]
|
|
||||||
homings = [-709, -2006, 1624]
|
|
||||||
mins = [43, 27, 145]
|
|
||||||
maxes = [1335, 3608, 3999]
|
|
||||||
calibration = {}
|
|
||||||
for motor, m in dummy_motors.items():
|
|
||||||
calibration[motor] = MotorCalibration(
|
|
||||||
id=m.id,
|
|
||||||
drive_mode=drive_modes[m.id - 1],
|
|
||||||
homing_offset=homings[m.id - 1],
|
|
||||||
range_min=mins[m.id - 1],
|
|
||||||
range_max=maxes[m.id - 1],
|
|
||||||
)
|
|
||||||
return calibration
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
|
|
||||||
def test_autouse_patch():
|
|
||||||
"""Ensures that the autouse fixture correctly patches dxl.PortHandler with MockPortHandler."""
|
|
||||||
assert dxl.PortHandler is MockPortHandler
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"value, length, expected",
|
|
||||||
[
|
|
||||||
(0x12, 1, [0x12]),
|
|
||||||
(0x1234, 2, [0x34, 0x12]),
|
|
||||||
(0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
|
|
||||||
],
|
|
||||||
ids=[
|
|
||||||
"1 byte",
|
|
||||||
"2 bytes",
|
|
||||||
"4 bytes",
|
|
||||||
],
|
|
||||||
) # fmt: skip
|
|
||||||
def test__split_into_byte_chunks(value, length, expected):
|
|
||||||
bus = DynamixelMotorsBus("", {})
|
|
||||||
assert bus._split_into_byte_chunks(value, length) == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_abc_implementation(dummy_motors):
|
|
||||||
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
|
||||||
DynamixelMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("id_", [1, 2, 3])
|
|
||||||
def test_ping(id_, mock_motors, dummy_motors):
|
|
||||||
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
|
|
||||||
stub = mock_motors.build_ping_stub(id_, expected_model_nb)
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
ping_model_nb = bus.ping(id_)
|
|
||||||
|
|
||||||
assert ping_model_nb == expected_model_nb
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
|
|
||||||
|
|
||||||
def test_broadcast_ping(mock_motors, dummy_motors):
|
|
||||||
models = {m.id: m.model for m in dummy_motors.values()}
|
|
||||||
expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()}
|
|
||||||
stub = mock_motors.build_broadcast_ping_stub(expected_model_nbs)
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
ping_model_nbs = bus.broadcast_ping()
|
|
||||||
|
|
||||||
assert ping_model_nbs == expected_model_nbs
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"addr, length, id_, value",
|
|
||||||
[
|
|
||||||
(0, 1, 1, 2),
|
|
||||||
(10, 2, 2, 999),
|
|
||||||
(42, 4, 3, 1337),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test__read(addr, length, id_, value, mock_motors, dummy_motors):
|
|
||||||
stub = mock_motors.build_read_stub(addr, length, id_, value)
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
read_value, _, _ = bus._read(addr, length, id_)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
assert read_value == value
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
|
||||||
def test__read_error(raise_on_error, mock_motors, dummy_motors):
|
|
||||||
addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT)
|
|
||||||
stub = mock_motors.build_read_stub(addr, length, id_, value, error=error)
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
if raise_on_error:
|
|
||||||
with pytest.raises(
|
|
||||||
RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!")
|
|
||||||
):
|
|
||||||
bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
|
||||||
else:
|
|
||||||
_, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
|
||||||
assert read_error == error
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
|
||||||
def test__read_comm(raise_on_error, mock_motors, dummy_motors):
|
|
||||||
addr, length, id_, value = (10, 4, 1, 1337)
|
|
||||||
stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False)
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
if raise_on_error:
|
|
||||||
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
|
||||||
bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
|
||||||
else:
|
|
||||||
_, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
|
||||||
assert read_comm == dxl.COMM_RX_TIMEOUT
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"addr, length, id_, value",
|
|
||||||
[
|
|
||||||
(0, 1, 1, 2),
|
|
||||||
(10, 2, 2, 999),
|
|
||||||
(42, 4, 3, 1337),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test__write(addr, length, id_, value, mock_motors, dummy_motors):
|
|
||||||
stub = mock_motors.build_write_stub(addr, length, id_, value)
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
comm, error = bus._write(addr, length, id_, value)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
assert comm == dxl.COMM_SUCCESS
|
|
||||||
assert error == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
|
||||||
def test__write_error(raise_on_error, mock_motors, dummy_motors):
|
|
||||||
addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT)
|
|
||||||
stub = mock_motors.build_write_stub(addr, length, id_, value, error=error)
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
if raise_on_error:
|
|
||||||
with pytest.raises(
|
|
||||||
RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!")
|
|
||||||
):
|
|
||||||
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
|
||||||
else:
|
|
||||||
_, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
|
||||||
assert write_error == error
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
|
||||||
def test__write_comm(raise_on_error, mock_motors, dummy_motors):
|
|
||||||
addr, length, id_, value = (10, 4, 1, 1337)
|
|
||||||
stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False)
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
if raise_on_error:
|
|
||||||
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
|
||||||
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
|
||||||
else:
|
|
||||||
write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
|
||||||
assert write_comm == dxl.COMM_RX_TIMEOUT
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"addr, length, ids_values",
|
|
||||||
[
|
|
||||||
(0, 1, {1: 4}),
|
|
||||||
(10, 2, {1: 1337, 2: 42}),
|
|
||||||
(42, 4, {1: 1337, 2: 42, 3: 4016}),
|
|
||||||
],
|
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
|
||||||
)
|
|
||||||
def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors):
|
|
||||||
stub = mock_motors.build_sync_read_stub(addr, length, ids_values)
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
read_values, _ = bus._sync_read(addr, length, list(ids_values))
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
assert read_values == ids_values
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
|
||||||
def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors):
|
|
||||||
addr, length, ids_values = (10, 4, {1: 1337})
|
|
||||||
stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False)
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
if raise_on_error:
|
|
||||||
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
|
||||||
bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
|
|
||||||
else:
|
|
||||||
_, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
|
|
||||||
assert read_comm == dxl.COMM_RX_TIMEOUT
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"addr, length, ids_values",
|
|
||||||
[
|
|
||||||
(0, 1, {1: 4}),
|
|
||||||
(10, 2, {1: 1337, 2: 42}),
|
|
||||||
(42, 4, {1: 1337, 2: 42, 3: 4016}),
|
|
||||||
],
|
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
|
||||||
)
|
|
||||||
def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors):
|
|
||||||
stub = mock_motors.build_sync_write_stub(addr, length, ids_values)
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
comm = bus._sync_write(addr, length, ids_values)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].wait_called()
|
|
||||||
assert comm == dxl.COMM_SUCCESS
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
|
|
||||||
drive_modes = {m.id: m.drive_mode for m in dummy_calibration.values()}
|
|
||||||
encoded_homings = {m.id: encode_twos_complement(m.homing_offset, 4) for m in dummy_calibration.values()}
|
|
||||||
mins = {m.id: m.range_min for m in dummy_calibration.values()}
|
|
||||||
maxes = {m.id: m.range_max for m in dummy_calibration.values()}
|
|
||||||
drive_modes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Drive_Mode"], drive_modes)
|
|
||||||
offsets_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings)
|
|
||||||
mins_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins)
|
|
||||||
maxes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes)
|
|
||||||
bus = DynamixelMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
calibration=dummy_calibration,
|
|
||||||
)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
is_calibrated = bus.is_calibrated
|
|
||||||
|
|
||||||
assert is_calibrated
|
|
||||||
assert mock_motors.stubs[drive_modes_stub].called
|
|
||||||
assert mock_motors.stubs[offsets_stub].called
|
|
||||||
assert mock_motors.stubs[mins_stub].called
|
|
||||||
assert mock_motors.stubs[maxes_stub].called
|
|
||||||
|
|
||||||
|
|
||||||
def test_reset_calibration(mock_motors, dummy_motors):
|
|
||||||
write_homing_stubs = []
|
|
||||||
write_mins_stubs = []
|
|
||||||
write_maxes_stubs = []
|
|
||||||
for motor in dummy_motors.values():
|
|
||||||
write_homing_stubs.append(
|
|
||||||
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0)
|
|
||||||
)
|
|
||||||
write_mins_stubs.append(
|
|
||||||
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0)
|
|
||||||
)
|
|
||||||
write_maxes_stubs.append(
|
|
||||||
mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095)
|
|
||||||
)
|
|
||||||
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
bus.reset_calibration()
|
|
||||||
|
|
||||||
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
|
||||||
assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs)
|
|
||||||
assert all(mock_motors.stubs[stub].called for stub in write_maxes_stubs)
|
|
||||||
|
|
||||||
|
|
||||||
def test_set_half_turn_homings(mock_motors, dummy_motors):
|
|
||||||
"""
|
|
||||||
For this test, we assume that the homing offsets are already 0 such that
|
|
||||||
Present_Position == Actual_Position
|
|
||||||
"""
|
|
||||||
current_positions = {
|
|
||||||
1: 1337,
|
|
||||||
2: 42,
|
|
||||||
3: 3672,
|
|
||||||
}
|
|
||||||
expected_homings = {
|
|
||||||
1: 710, # 2047 - 1337
|
|
||||||
2: 2005, # 2047 - 42
|
|
||||||
3: -1625, # 2047 - 3672
|
|
||||||
}
|
|
||||||
read_pos_stub = mock_motors.build_sync_read_stub(
|
|
||||||
*X_SERIES_CONTROL_TABLE["Present_Position"], current_positions
|
|
||||||
)
|
|
||||||
write_homing_stubs = []
|
|
||||||
for id_, homing in expected_homings.items():
|
|
||||||
encoded_homing = encode_twos_complement(homing, 4)
|
|
||||||
stub = mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing)
|
|
||||||
write_homing_stubs.append(stub)
|
|
||||||
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
bus.reset_calibration = MagicMock()
|
|
||||||
|
|
||||||
bus.set_half_turn_homings()
|
|
||||||
|
|
||||||
bus.reset_calibration.assert_called_once()
|
|
||||||
assert mock_motors.stubs[read_pos_stub].called
|
|
||||||
assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs)
|
|
||||||
|
|
||||||
|
|
||||||
def test_record_ranges_of_motion(mock_motors, dummy_motors):
|
|
||||||
positions = {
|
|
||||||
1: [351, 42, 1337],
|
|
||||||
2: [28, 3600, 2444],
|
|
||||||
3: [4002, 2999, 146],
|
|
||||||
}
|
|
||||||
expected_mins = {
|
|
||||||
"dummy_1": 42,
|
|
||||||
"dummy_2": 28,
|
|
||||||
"dummy_3": 146,
|
|
||||||
}
|
|
||||||
expected_maxes = {
|
|
||||||
"dummy_1": 1337,
|
|
||||||
"dummy_2": 3600,
|
|
||||||
"dummy_3": 4002,
|
|
||||||
}
|
|
||||||
read_pos_stub = mock_motors.build_sequential_sync_read_stub(
|
|
||||||
*X_SERIES_CONTROL_TABLE["Present_Position"], positions
|
|
||||||
)
|
|
||||||
with patch("lerobot.motors.motors_bus.enter_pressed", side_effect=[False, True]):
|
|
||||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
mins, maxes = bus.record_ranges_of_motion(display_values=False)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[read_pos_stub].calls == 3
|
|
||||||
assert mins == expected_mins
|
|
||||||
assert maxes == expected_maxes
|
|
||||||
@ -1,459 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from collections.abc import Generator
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
|
||||||
from lerobot.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus
|
|
||||||
from lerobot.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE
|
|
||||||
from lerobot.utils.encoding_utils import encode_sign_magnitude
|
|
||||||
|
|
||||||
try:
|
|
||||||
import scservo_sdk as scs
|
|
||||||
|
|
||||||
from tests.mocks.mock_feetech import MockMotors, MockPortHandler
|
|
||||||
except (ImportError, ModuleNotFoundError):
|
|
||||||
pytest.skip("scservo_sdk not available", allow_module_level=True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def patch_port_handler():
|
|
||||||
if sys.platform == "darwin":
|
|
||||||
with patch.object(scs, "PortHandler", MockPortHandler):
|
|
||||||
yield
|
|
||||||
else:
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_motors() -> Generator[MockMotors, None, None]:
|
|
||||||
motors = MockMotors()
|
|
||||||
motors.open()
|
|
||||||
yield motors
|
|
||||||
motors.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def dummy_motors() -> dict[str, Motor]:
|
|
||||||
return {
|
|
||||||
"dummy_1": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100),
|
|
||||||
"dummy_2": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100),
|
|
||||||
"dummy_3": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def dummy_calibration(dummy_motors) -> dict[str, MotorCalibration]:
|
|
||||||
homings = [-709, -2006, 1624]
|
|
||||||
mins = [43, 27, 145]
|
|
||||||
maxes = [1335, 3608, 3999]
|
|
||||||
calibration = {}
|
|
||||||
for motor, m in dummy_motors.items():
|
|
||||||
calibration[motor] = MotorCalibration(
|
|
||||||
id=m.id,
|
|
||||||
drive_mode=0,
|
|
||||||
homing_offset=homings[m.id - 1],
|
|
||||||
range_min=mins[m.id - 1],
|
|
||||||
range_max=maxes[m.id - 1],
|
|
||||||
)
|
|
||||||
return calibration
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}")
|
|
||||||
def test_autouse_patch():
|
|
||||||
"""Ensures that the autouse fixture correctly patches scs.PortHandler with MockPortHandler."""
|
|
||||||
assert scs.PortHandler is MockPortHandler
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"protocol, value, length, expected",
|
|
||||||
[
|
|
||||||
(0, 0x12, 1, [0x12]),
|
|
||||||
(1, 0x12, 1, [0x12]),
|
|
||||||
(0, 0x1234, 2, [0x34, 0x12]),
|
|
||||||
(1, 0x1234, 2, [0x12, 0x34]),
|
|
||||||
(0, 0x12345678, 4, [0x78, 0x56, 0x34, 0x12]),
|
|
||||||
(1, 0x12345678, 4, [0x56, 0x78, 0x12, 0x34]),
|
|
||||||
],
|
|
||||||
ids=[
|
|
||||||
"P0: 1 byte",
|
|
||||||
"P1: 1 byte",
|
|
||||||
"P0: 2 bytes",
|
|
||||||
"P1: 2 bytes",
|
|
||||||
"P0: 4 bytes",
|
|
||||||
"P1: 4 bytes",
|
|
||||||
],
|
|
||||||
) # fmt: skip
|
|
||||||
def test__split_into_byte_chunks(protocol, value, length, expected):
|
|
||||||
bus = FeetechMotorsBus("", {}, protocol_version=protocol)
|
|
||||||
assert bus._split_into_byte_chunks(value, length) == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_abc_implementation(dummy_motors):
|
|
||||||
"""Instantiation should raise an error if the class doesn't implement abstract methods/properties."""
|
|
||||||
FeetechMotorsBus(port="/dev/dummy-port", motors=dummy_motors)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("id_", [1, 2, 3])
|
|
||||||
def test_ping(id_, mock_motors, dummy_motors):
|
|
||||||
expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model]
|
|
||||||
addr, length = MODEL_NUMBER
|
|
||||||
ping_stub = mock_motors.build_ping_stub(id_)
|
|
||||||
mobel_nb_stub = mock_motors.build_read_stub(addr, length, id_, expected_model_nb)
|
|
||||||
bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
ping_model_nb = bus.ping(id_)
|
|
||||||
|
|
||||||
assert ping_model_nb == expected_model_nb
|
|
||||||
assert mock_motors.stubs[ping_stub].called
|
|
||||||
assert mock_motors.stubs[mobel_nb_stub].called
|
|
||||||
|
|
||||||
|
|
||||||
def test_broadcast_ping(mock_motors, dummy_motors):
|
|
||||||
models = {m.id: m.model for m in dummy_motors.values()}
|
|
||||||
addr, length = MODEL_NUMBER
|
|
||||||
ping_stub = mock_motors.build_broadcast_ping_stub(list(models))
|
|
||||||
mobel_nb_stubs = []
|
|
||||||
expected_model_nbs = {}
|
|
||||||
for id_, model in models.items():
|
|
||||||
model_nb = MODEL_NUMBER_TABLE[model]
|
|
||||||
stub = mock_motors.build_read_stub(addr, length, id_, model_nb)
|
|
||||||
expected_model_nbs[id_] = model_nb
|
|
||||||
mobel_nb_stubs.append(stub)
|
|
||||||
bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
ping_model_nbs = bus.broadcast_ping()
|
|
||||||
|
|
||||||
assert ping_model_nbs == expected_model_nbs
|
|
||||||
assert mock_motors.stubs[ping_stub].called
|
|
||||||
assert all(mock_motors.stubs[stub].called for stub in mobel_nb_stubs)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"addr, length, id_, value",
|
|
||||||
[
|
|
||||||
(0, 1, 1, 2),
|
|
||||||
(10, 2, 2, 999),
|
|
||||||
(42, 4, 3, 1337),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test__read(addr, length, id_, value, mock_motors, dummy_motors):
|
|
||||||
stub = mock_motors.build_read_stub(addr, length, id_, value)
|
|
||||||
bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
read_value, _, _ = bus._read(addr, length, id_)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
assert read_value == value
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
|
||||||
def test__read_error(raise_on_error, mock_motors, dummy_motors):
|
|
||||||
addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE)
|
|
||||||
stub = mock_motors.build_read_stub(addr, length, id_, value, error=error)
|
|
||||||
bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
if raise_on_error:
|
|
||||||
with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")):
|
|
||||||
bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
|
||||||
else:
|
|
||||||
_, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
|
||||||
assert read_error == error
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
|
||||||
def test__read_comm(raise_on_error, mock_motors, dummy_motors):
|
|
||||||
addr, length, id_, value = (10, 4, 1, 1337)
|
|
||||||
stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False)
|
|
||||||
bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
if raise_on_error:
|
|
||||||
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
|
||||||
bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
|
||||||
else:
|
|
||||||
_, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error)
|
|
||||||
assert read_comm == scs.COMM_RX_TIMEOUT
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"addr, length, id_, value",
|
|
||||||
[
|
|
||||||
(0, 1, 1, 2),
|
|
||||||
(10, 2, 2, 999),
|
|
||||||
(42, 4, 3, 1337),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test__write(addr, length, id_, value, mock_motors, dummy_motors):
|
|
||||||
stub = mock_motors.build_write_stub(addr, length, id_, value)
|
|
||||||
bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
comm, error = bus._write(addr, length, id_, value)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].wait_called()
|
|
||||||
assert comm == scs.COMM_SUCCESS
|
|
||||||
assert error == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
|
||||||
def test__write_error(raise_on_error, mock_motors, dummy_motors):
|
|
||||||
addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE)
|
|
||||||
stub = mock_motors.build_write_stub(addr, length, id_, value, error=error)
|
|
||||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
if raise_on_error:
|
|
||||||
with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")):
|
|
||||||
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
|
||||||
else:
|
|
||||||
_, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
|
||||||
assert write_error == error
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
|
||||||
def test__write_comm(raise_on_error, mock_motors, dummy_motors):
|
|
||||||
addr, length, id_, value = (10, 4, 1, 1337)
|
|
||||||
stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False)
|
|
||||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
if raise_on_error:
|
|
||||||
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
|
||||||
bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
|
||||||
else:
|
|
||||||
write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error)
|
|
||||||
assert write_comm == scs.COMM_RX_TIMEOUT
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"addr, length, ids_values",
|
|
||||||
[
|
|
||||||
(0, 1, {1: 4}),
|
|
||||||
(10, 2, {1: 1337, 2: 42}),
|
|
||||||
(42, 4, {1: 1337, 2: 42, 3: 4016}),
|
|
||||||
],
|
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
|
||||||
)
|
|
||||||
def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors):
|
|
||||||
stub = mock_motors.build_sync_read_stub(addr, length, ids_values)
|
|
||||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
read_values, _ = bus._sync_read(addr, length, list(ids_values))
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
assert read_values == ids_values
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("raise_on_error", (True, False))
|
|
||||||
def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors):
|
|
||||||
addr, length, ids_values = (10, 4, {1: 1337})
|
|
||||||
stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False)
|
|
||||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
if raise_on_error:
|
|
||||||
with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")):
|
|
||||||
bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
|
|
||||||
else:
|
|
||||||
_, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error)
|
|
||||||
assert read_comm == scs.COMM_RX_TIMEOUT
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].called
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"addr, length, ids_values",
|
|
||||||
[
|
|
||||||
(0, 1, {1: 4}),
|
|
||||||
(10, 2, {1: 1337, 2: 42}),
|
|
||||||
(42, 4, {1: 1337, 2: 42, 3: 4016}),
|
|
||||||
],
|
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
|
||||||
)
|
|
||||||
def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors):
|
|
||||||
stub = mock_motors.build_sync_write_stub(addr, length, ids_values)
|
|
||||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
comm = bus._sync_write(addr, length, ids_values)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].wait_called()
|
|
||||||
assert comm == scs.COMM_SUCCESS
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration):
|
|
||||||
mins_stubs, maxes_stubs, homings_stubs = [], [], []
|
|
||||||
for cal in dummy_calibration.values():
|
|
||||||
mins_stubs.append(
|
|
||||||
mock_motors.build_read_stub(
|
|
||||||
*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], cal.id, cal.range_min
|
|
||||||
)
|
|
||||||
)
|
|
||||||
maxes_stubs.append(
|
|
||||||
mock_motors.build_read_stub(
|
|
||||||
*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], cal.id, cal.range_max
|
|
||||||
)
|
|
||||||
)
|
|
||||||
homings_stubs.append(
|
|
||||||
mock_motors.build_read_stub(
|
|
||||||
*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"],
|
|
||||||
cal.id,
|
|
||||||
encode_sign_magnitude(cal.homing_offset, 11),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
bus = FeetechMotorsBus(
|
|
||||||
port=mock_motors.port,
|
|
||||||
motors=dummy_motors,
|
|
||||||
calibration=dummy_calibration,
|
|
||||||
)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
is_calibrated = bus.is_calibrated
|
|
||||||
|
|
||||||
assert is_calibrated
|
|
||||||
assert all(mock_motors.stubs[stub].called for stub in mins_stubs)
|
|
||||||
assert all(mock_motors.stubs[stub].called for stub in maxes_stubs)
|
|
||||||
assert all(mock_motors.stubs[stub].called for stub in homings_stubs)
|
|
||||||
|
|
||||||
|
|
||||||
def test_reset_calibration(mock_motors, dummy_motors):
|
|
||||||
write_homing_stubs = []
|
|
||||||
write_mins_stubs = []
|
|
||||||
write_maxes_stubs = []
|
|
||||||
for motor in dummy_motors.values():
|
|
||||||
write_homing_stubs.append(
|
|
||||||
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0)
|
|
||||||
)
|
|
||||||
write_mins_stubs.append(
|
|
||||||
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0)
|
|
||||||
)
|
|
||||||
write_maxes_stubs.append(
|
|
||||||
mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095)
|
|
||||||
)
|
|
||||||
|
|
||||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
bus.reset_calibration()
|
|
||||||
|
|
||||||
assert all(mock_motors.stubs[stub].wait_called() for stub in write_homing_stubs)
|
|
||||||
assert all(mock_motors.stubs[stub].wait_called() for stub in write_mins_stubs)
|
|
||||||
assert all(mock_motors.stubs[stub].wait_called() for stub in write_maxes_stubs)
|
|
||||||
|
|
||||||
|
|
||||||
def test_set_half_turn_homings(mock_motors, dummy_motors):
|
|
||||||
"""
|
|
||||||
For this test, we assume that the homing offsets are already 0 such that
|
|
||||||
Present_Position == Actual_Position
|
|
||||||
"""
|
|
||||||
current_positions = {
|
|
||||||
1: 1337,
|
|
||||||
2: 42,
|
|
||||||
3: 3672,
|
|
||||||
}
|
|
||||||
expected_homings = {
|
|
||||||
1: -710, # 1337 - 2047
|
|
||||||
2: -2005, # 42 - 2047
|
|
||||||
3: 1625, # 3672 - 2047
|
|
||||||
}
|
|
||||||
read_pos_stub = mock_motors.build_sync_read_stub(
|
|
||||||
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], current_positions
|
|
||||||
)
|
|
||||||
write_homing_stubs = []
|
|
||||||
for id_, homing in expected_homings.items():
|
|
||||||
encoded_homing = encode_sign_magnitude(homing, 11)
|
|
||||||
stub = mock_motors.build_write_stub(
|
|
||||||
*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing
|
|
||||||
)
|
|
||||||
write_homing_stubs.append(stub)
|
|
||||||
|
|
||||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
bus.reset_calibration = MagicMock()
|
|
||||||
|
|
||||||
bus.set_half_turn_homings()
|
|
||||||
|
|
||||||
bus.reset_calibration.assert_called_once()
|
|
||||||
assert mock_motors.stubs[read_pos_stub].called
|
|
||||||
assert all(mock_motors.stubs[stub].wait_called() for stub in write_homing_stubs)
|
|
||||||
|
|
||||||
|
|
||||||
def test_record_ranges_of_motion(mock_motors, dummy_motors):
|
|
||||||
positions = {
|
|
||||||
1: [351, 42, 1337],
|
|
||||||
2: [28, 3600, 2444],
|
|
||||||
3: [4002, 2999, 146],
|
|
||||||
}
|
|
||||||
expected_mins = {
|
|
||||||
"dummy_1": 42,
|
|
||||||
"dummy_2": 28,
|
|
||||||
"dummy_3": 146,
|
|
||||||
}
|
|
||||||
expected_maxes = {
|
|
||||||
"dummy_1": 1337,
|
|
||||||
"dummy_2": 3600,
|
|
||||||
"dummy_3": 4002,
|
|
||||||
}
|
|
||||||
stub = mock_motors.build_sequential_sync_read_stub(
|
|
||||||
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], positions
|
|
||||||
)
|
|
||||||
with patch("lerobot.motors.motors_bus.enter_pressed", side_effect=[False, True]):
|
|
||||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
|
|
||||||
mins, maxes = bus.record_ranges_of_motion(display_values=False)
|
|
||||||
|
|
||||||
assert mock_motors.stubs[stub].calls == 3
|
|
||||||
assert mins == expected_mins
|
|
||||||
assert maxes == expected_maxes
|
|
||||||
@ -1,358 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import re
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lerobot.motors.motors_bus import (
|
|
||||||
Motor,
|
|
||||||
MotorNormMode,
|
|
||||||
assert_same_address,
|
|
||||||
get_address,
|
|
||||||
get_ctrl_table,
|
|
||||||
)
|
|
||||||
from tests.mocks.mock_motors_bus import (
|
|
||||||
DUMMY_CTRL_TABLE_1,
|
|
||||||
DUMMY_CTRL_TABLE_2,
|
|
||||||
DUMMY_MODEL_CTRL_TABLE,
|
|
||||||
MockMotorsBus,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def dummy_motors() -> dict[str, Motor]:
|
|
||||||
return {
|
|
||||||
"dummy_1": Motor(1, "model_2", MotorNormMode.RANGE_M100_100),
|
|
||||||
"dummy_2": Motor(2, "model_3", MotorNormMode.RANGE_M100_100),
|
|
||||||
"dummy_3": Motor(3, "model_2", MotorNormMode.RANGE_0_100),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_ctrl_table():
|
|
||||||
model = "model_1"
|
|
||||||
ctrl_table = get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model)
|
|
||||||
assert ctrl_table == DUMMY_CTRL_TABLE_1
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_ctrl_table_error():
|
|
||||||
model = "model_99"
|
|
||||||
with pytest.raises(KeyError, match=f"Control table for {model=} not found."):
|
|
||||||
get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_address():
|
|
||||||
addr, n_bytes = get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", "Firmware_Version")
|
|
||||||
assert addr == 0
|
|
||||||
assert n_bytes == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_address_error():
|
|
||||||
model = "model_1"
|
|
||||||
data_name = "Lock"
|
|
||||||
with pytest.raises(KeyError, match=f"Address for '{data_name}' not found in {model} control table."):
|
|
||||||
get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", data_name)
|
|
||||||
|
|
||||||
|
|
||||||
def test_assert_same_address():
|
|
||||||
models = ["model_1", "model_2"]
|
|
||||||
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Present_Position")
|
|
||||||
|
|
||||||
|
|
||||||
def test_assert_same_length_different_addresses():
|
|
||||||
models = ["model_1", "model_2"]
|
|
||||||
with pytest.raises(
|
|
||||||
NotImplementedError,
|
|
||||||
match=re.escape("At least two motor models use a different address"),
|
|
||||||
):
|
|
||||||
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Model_Number")
|
|
||||||
|
|
||||||
|
|
||||||
def test_assert_same_address_different_length():
|
|
||||||
models = ["model_1", "model_2"]
|
|
||||||
with pytest.raises(
|
|
||||||
NotImplementedError,
|
|
||||||
match=re.escape("At least two motor models use a different bytes representation"),
|
|
||||||
):
|
|
||||||
assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Goal_Position")
|
|
||||||
|
|
||||||
|
|
||||||
def test__serialize_data_invalid_length():
|
|
||||||
bus = MockMotorsBus("", {})
|
|
||||||
with pytest.raises(NotImplementedError):
|
|
||||||
bus._serialize_data(100, 3)
|
|
||||||
|
|
||||||
|
|
||||||
def test__serialize_data_negative_numbers():
|
|
||||||
bus = MockMotorsBus("", {})
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
bus._serialize_data(-1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def test__serialize_data_large_number():
|
|
||||||
bus = MockMotorsBus("", {})
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
bus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"data_name, id_, value",
|
|
||||||
[
|
|
||||||
("Firmware_Version", 1, 14),
|
|
||||||
("Model_Number", 1, 5678),
|
|
||||||
("Present_Position", 2, 1337),
|
|
||||||
("Present_Velocity", 3, 42),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_read(data_name, id_, value, dummy_motors):
|
|
||||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MockMotorsBus, "_read", return_value=(value, 0, 0)) as mock__read,
|
|
||||||
patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign,
|
|
||||||
patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize,
|
|
||||||
):
|
|
||||||
returned_value = bus.read(data_name, f"dummy_{id_}")
|
|
||||||
|
|
||||||
assert returned_value == value
|
|
||||||
mock__read.assert_called_once_with(
|
|
||||||
addr,
|
|
||||||
length,
|
|
||||||
id_,
|
|
||||||
num_retry=0,
|
|
||||||
raise_on_error=True,
|
|
||||||
err_msg=f"Failed to read '{data_name}' on {id_=} after 1 tries.",
|
|
||||||
)
|
|
||||||
mock__decode_sign.assert_called_once_with(data_name, {id_: value})
|
|
||||||
if data_name in bus.normalized_data:
|
|
||||||
mock__normalize.assert_called_once_with({id_: value})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"data_name, id_, value",
|
|
||||||
[
|
|
||||||
("Goal_Position", 1, 1337),
|
|
||||||
("Goal_Velocity", 2, 3682),
|
|
||||||
("Lock", 3, 1),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_write(data_name, id_, value, dummy_motors):
|
|
||||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MockMotorsBus, "_write", return_value=(0, 0)) as mock__write,
|
|
||||||
patch.object(MockMotorsBus, "_encode_sign", return_value={id_: value}) as mock__encode_sign,
|
|
||||||
patch.object(MockMotorsBus, "_unnormalize", return_value={id_: value}) as mock__unnormalize,
|
|
||||||
):
|
|
||||||
bus.write(data_name, f"dummy_{id_}", value)
|
|
||||||
|
|
||||||
mock__write.assert_called_once_with(
|
|
||||||
addr,
|
|
||||||
length,
|
|
||||||
id_,
|
|
||||||
value,
|
|
||||||
num_retry=0,
|
|
||||||
raise_on_error=True,
|
|
||||||
err_msg=f"Failed to write '{data_name}' on {id_=} with '{value}' after 1 tries.",
|
|
||||||
)
|
|
||||||
mock__encode_sign.assert_called_once_with(data_name, {id_: value})
|
|
||||||
if data_name in bus.normalized_data:
|
|
||||||
mock__unnormalize.assert_called_once_with({id_: value})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"data_name, id_, value",
|
|
||||||
[
|
|
||||||
("Firmware_Version", 1, 14),
|
|
||||||
("Model_Number", 1, 5678),
|
|
||||||
("Present_Position", 2, 1337),
|
|
||||||
("Present_Velocity", 3, 42),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_sync_read_by_str(data_name, id_, value, dummy_motors):
|
|
||||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
|
||||||
ids = [id_]
|
|
||||||
expected_value = {f"dummy_{id_}": value}
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MockMotorsBus, "_sync_read", return_value=({id_: value}, 0)) as mock__sync_read,
|
|
||||||
patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign,
|
|
||||||
patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize,
|
|
||||||
):
|
|
||||||
returned_dict = bus.sync_read(data_name, f"dummy_{id_}")
|
|
||||||
|
|
||||||
assert returned_dict == expected_value
|
|
||||||
mock__sync_read.assert_called_once_with(
|
|
||||||
addr,
|
|
||||||
length,
|
|
||||||
ids,
|
|
||||||
num_retry=0,
|
|
||||||
raise_on_error=True,
|
|
||||||
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
|
|
||||||
)
|
|
||||||
mock__decode_sign.assert_called_once_with(data_name, {id_: value})
|
|
||||||
if data_name in bus.normalized_data:
|
|
||||||
mock__normalize.assert_called_once_with({id_: value})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"data_name, ids_values",
|
|
||||||
[
|
|
||||||
("Model_Number", {1: 5678}),
|
|
||||||
("Present_Position", {1: 1337, 2: 42}),
|
|
||||||
("Present_Velocity", {1: 1337, 2: 42, 3: 4016}),
|
|
||||||
],
|
|
||||||
ids=["1 motor", "2 motors", "3 motors"],
|
|
||||||
)
|
|
||||||
def test_sync_read_by_list(data_name, ids_values, dummy_motors):
|
|
||||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
|
||||||
ids = list(ids_values)
|
|
||||||
expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read,
|
|
||||||
patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign,
|
|
||||||
patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize,
|
|
||||||
):
|
|
||||||
returned_dict = bus.sync_read(data_name, [f"dummy_{id_}" for id_ in ids])
|
|
||||||
|
|
||||||
assert returned_dict == expected_values
|
|
||||||
mock__sync_read.assert_called_once_with(
|
|
||||||
addr,
|
|
||||||
length,
|
|
||||||
ids,
|
|
||||||
num_retry=0,
|
|
||||||
raise_on_error=True,
|
|
||||||
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
|
|
||||||
)
|
|
||||||
mock__decode_sign.assert_called_once_with(data_name, ids_values)
|
|
||||||
if data_name in bus.normalized_data:
|
|
||||||
mock__normalize.assert_called_once_with(ids_values)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"data_name, ids_values",
|
|
||||||
[
|
|
||||||
("Model_Number", {1: 5678, 2: 5799, 3: 5678}),
|
|
||||||
("Present_Position", {1: 1337, 2: 42, 3: 4016}),
|
|
||||||
("Goal_Position", {1: 4008, 2: 199, 3: 3446}),
|
|
||||||
],
|
|
||||||
ids=["Model_Number", "Present_Position", "Goal_Position"],
|
|
||||||
)
|
|
||||||
def test_sync_read_by_none(data_name, ids_values, dummy_motors):
|
|
||||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
|
||||||
ids = list(ids_values)
|
|
||||||
expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read,
|
|
||||||
patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign,
|
|
||||||
patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize,
|
|
||||||
):
|
|
||||||
returned_dict = bus.sync_read(data_name)
|
|
||||||
|
|
||||||
assert returned_dict == expected_values
|
|
||||||
mock__sync_read.assert_called_once_with(
|
|
||||||
addr,
|
|
||||||
length,
|
|
||||||
ids,
|
|
||||||
num_retry=0,
|
|
||||||
raise_on_error=True,
|
|
||||||
err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.",
|
|
||||||
)
|
|
||||||
mock__decode_sign.assert_called_once_with(data_name, ids_values)
|
|
||||||
if data_name in bus.normalized_data:
|
|
||||||
mock__normalize.assert_called_once_with(ids_values)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"data_name, value",
|
|
||||||
[
|
|
||||||
("Goal_Position", 500),
|
|
||||||
("Goal_Velocity", 4010),
|
|
||||||
("Lock", 0),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_sync_write_by_single_value(data_name, value, dummy_motors):
|
|
||||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
|
||||||
ids_values = {m.id: value for m in dummy_motors.values()}
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write,
|
|
||||||
patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign,
|
|
||||||
patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize,
|
|
||||||
):
|
|
||||||
bus.sync_write(data_name, value)
|
|
||||||
|
|
||||||
mock__sync_write.assert_called_once_with(
|
|
||||||
addr,
|
|
||||||
length,
|
|
||||||
ids_values,
|
|
||||||
num_retry=0,
|
|
||||||
raise_on_error=True,
|
|
||||||
err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.",
|
|
||||||
)
|
|
||||||
mock__encode_sign.assert_called_once_with(data_name, ids_values)
|
|
||||||
if data_name in bus.normalized_data:
|
|
||||||
mock__unnormalize.assert_called_once_with(ids_values)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"data_name, ids_values",
|
|
||||||
[
|
|
||||||
("Goal_Position", {1: 1337, 2: 42, 3: 4016}),
|
|
||||||
("Goal_Velocity", {1: 50, 2: 83, 3: 2777}),
|
|
||||||
("Lock", {1: 0, 2: 0, 3: 1}),
|
|
||||||
],
|
|
||||||
ids=["Goal_Position", "Goal_Velocity", "Lock"],
|
|
||||||
)
|
|
||||||
def test_sync_write_by_value_dict(data_name, ids_values, dummy_motors):
|
|
||||||
bus = MockMotorsBus("/dev/dummy-port", dummy_motors)
|
|
||||||
bus.connect(handshake=False)
|
|
||||||
addr, length = DUMMY_CTRL_TABLE_2[data_name]
|
|
||||||
values = {f"dummy_{id_}": val for id_, val in ids_values.items()}
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write,
|
|
||||||
patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign,
|
|
||||||
patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize,
|
|
||||||
):
|
|
||||||
bus.sync_write(data_name, values)
|
|
||||||
|
|
||||||
mock__sync_write.assert_called_once_with(
|
|
||||||
addr,
|
|
||||||
length,
|
|
||||||
ids_values,
|
|
||||||
num_retry=0,
|
|
||||||
raise_on_error=True,
|
|
||||||
err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.",
|
|
||||||
)
|
|
||||||
mock__encode_sign.assert_called_once_with(data_name, ids_values)
|
|
||||||
if data_name in bus.normalized_data:
|
|
||||||
mock__unnormalize.assert_called_once_with(ids_values)
|
|
||||||
@ -1,242 +0,0 @@
|
|||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.constants import (
|
|
||||||
OPTIMIZER_PARAM_GROUPS,
|
|
||||||
OPTIMIZER_STATE,
|
|
||||||
)
|
|
||||||
from lerobot.optim.optimizers import (
|
|
||||||
AdamConfig,
|
|
||||||
AdamWConfig,
|
|
||||||
MultiAdamConfig,
|
|
||||||
SGDConfig,
|
|
||||||
load_optimizer_state,
|
|
||||||
save_optimizer_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"config_cls, expected_class",
|
|
||||||
[
|
|
||||||
(AdamConfig, torch.optim.Adam),
|
|
||||||
(AdamWConfig, torch.optim.AdamW),
|
|
||||||
(SGDConfig, torch.optim.SGD),
|
|
||||||
(MultiAdamConfig, dict),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_optimizer_build(config_cls, expected_class, model_params):
|
|
||||||
config = config_cls()
|
|
||||||
if config_cls == MultiAdamConfig:
|
|
||||||
params_dict = {"default": model_params}
|
|
||||||
optimizer = config.build(params_dict)
|
|
||||||
assert isinstance(optimizer, expected_class)
|
|
||||||
assert isinstance(optimizer["default"], torch.optim.Adam)
|
|
||||||
assert optimizer["default"].defaults["lr"] == config.lr
|
|
||||||
else:
|
|
||||||
optimizer = config.build(model_params)
|
|
||||||
assert isinstance(optimizer, expected_class)
|
|
||||||
assert optimizer.defaults["lr"] == config.lr
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_optimizer_state(optimizer, tmp_path):
|
|
||||||
save_optimizer_state(optimizer, tmp_path)
|
|
||||||
assert (tmp_path / OPTIMIZER_STATE).is_file()
|
|
||||||
assert (tmp_path / OPTIMIZER_PARAM_GROUPS).is_file()
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
|
|
||||||
save_optimizer_state(optimizer, tmp_path)
|
|
||||||
loaded_optimizer = AdamConfig().build(model_params)
|
|
||||||
loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path)
|
|
||||||
|
|
||||||
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def base_params_dict():
|
|
||||||
return {
|
|
||||||
"actor": [torch.nn.Parameter(torch.randn(10, 10))],
|
|
||||||
"critic": [torch.nn.Parameter(torch.randn(5, 5))],
|
|
||||||
"temperature": [torch.nn.Parameter(torch.randn(3, 3))],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"config_params, expected_values",
|
|
||||||
[
|
|
||||||
# Test 1: Basic configuration with different learning rates
|
|
||||||
(
|
|
||||||
{
|
|
||||||
"lr": 1e-3,
|
|
||||||
"weight_decay": 1e-4,
|
|
||||||
"optimizer_groups": {
|
|
||||||
"actor": {"lr": 1e-4},
|
|
||||||
"critic": {"lr": 5e-4},
|
|
||||||
"temperature": {"lr": 2e-3},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
|
||||||
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
|
||||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
# Test 2: Different weight decays and beta values
|
|
||||||
(
|
|
||||||
{
|
|
||||||
"lr": 1e-3,
|
|
||||||
"weight_decay": 1e-4,
|
|
||||||
"optimizer_groups": {
|
|
||||||
"actor": {"lr": 1e-4, "weight_decay": 1e-5},
|
|
||||||
"critic": {"lr": 5e-4, "weight_decay": 1e-6},
|
|
||||||
"temperature": {"lr": 2e-3, "betas": (0.95, 0.999)},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)},
|
|
||||||
"critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)},
|
|
||||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
# Test 3: Epsilon parameter customization
|
|
||||||
(
|
|
||||||
{
|
|
||||||
"lr": 1e-3,
|
|
||||||
"weight_decay": 1e-4,
|
|
||||||
"optimizer_groups": {
|
|
||||||
"actor": {"lr": 1e-4, "eps": 1e-6},
|
|
||||||
"critic": {"lr": 5e-4, "eps": 1e-7},
|
|
||||||
"temperature": {"lr": 2e-3, "eps": 1e-8},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6},
|
|
||||||
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7},
|
|
||||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_multi_adam_configuration(base_params_dict, config_params, expected_values):
|
|
||||||
# Create config with the given parameters
|
|
||||||
config = MultiAdamConfig(**config_params)
|
|
||||||
optimizers = config.build(base_params_dict)
|
|
||||||
|
|
||||||
# Verify optimizer count and keys
|
|
||||||
assert len(optimizers) == len(expected_values)
|
|
||||||
assert set(optimizers.keys()) == set(expected_values.keys())
|
|
||||||
|
|
||||||
# Check that all optimizers are Adam instances
|
|
||||||
for opt in optimizers.values():
|
|
||||||
assert isinstance(opt, torch.optim.Adam)
|
|
||||||
|
|
||||||
# Verify hyperparameters for each optimizer
|
|
||||||
for name, expected in expected_values.items():
|
|
||||||
optimizer = optimizers[name]
|
|
||||||
for param, value in expected.items():
|
|
||||||
assert optimizer.defaults[param] == value
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def multi_optimizers(base_params_dict):
|
|
||||||
config = MultiAdamConfig(
|
|
||||||
lr=1e-3,
|
|
||||||
optimizer_groups={
|
|
||||||
"actor": {"lr": 1e-4},
|
|
||||||
"critic": {"lr": 5e-4},
|
|
||||||
"temperature": {"lr": 2e-3},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return config.build(base_params_dict)
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
|
|
||||||
# Save optimizer states
|
|
||||||
save_optimizer_state(multi_optimizers, tmp_path)
|
|
||||||
|
|
||||||
# Verify that directories were created for each optimizer
|
|
||||||
for name in multi_optimizers:
|
|
||||||
assert (tmp_path / name).is_dir()
|
|
||||||
assert (tmp_path / name / OPTIMIZER_STATE).is_file()
|
|
||||||
assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file()
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path):
|
|
||||||
# Option 1: Add a minimal backward pass to populate optimizer states
|
|
||||||
for name, params in base_params_dict.items():
|
|
||||||
if name in multi_optimizers:
|
|
||||||
# Create a dummy loss and do backward
|
|
||||||
dummy_loss = params[0].sum()
|
|
||||||
dummy_loss.backward()
|
|
||||||
# Perform an optimization step
|
|
||||||
multi_optimizers[name].step()
|
|
||||||
# Zero gradients for next steps
|
|
||||||
multi_optimizers[name].zero_grad()
|
|
||||||
|
|
||||||
# Save optimizer states
|
|
||||||
save_optimizer_state(multi_optimizers, tmp_path)
|
|
||||||
|
|
||||||
# Create new optimizers with the same config
|
|
||||||
config = MultiAdamConfig(
|
|
||||||
lr=1e-3,
|
|
||||||
optimizer_groups={
|
|
||||||
"actor": {"lr": 1e-4},
|
|
||||||
"critic": {"lr": 5e-4},
|
|
||||||
"temperature": {"lr": 2e-3},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
new_optimizers = config.build(base_params_dict)
|
|
||||||
|
|
||||||
# Load optimizer states
|
|
||||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
|
||||||
|
|
||||||
# Verify state dictionaries match
|
|
||||||
for name in multi_optimizers:
|
|
||||||
torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
|
|
||||||
"""Test saving and loading optimizer states even when the state is empty (no backward pass)."""
|
|
||||||
# Create config and build optimizers
|
|
||||||
config = MultiAdamConfig(
|
|
||||||
lr=1e-3,
|
|
||||||
optimizer_groups={
|
|
||||||
"actor": {"lr": 1e-4},
|
|
||||||
"critic": {"lr": 5e-4},
|
|
||||||
"temperature": {"lr": 2e-3},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
optimizers = config.build(base_params_dict)
|
|
||||||
|
|
||||||
# Save optimizer states without any backward pass (empty state)
|
|
||||||
save_optimizer_state(optimizers, tmp_path)
|
|
||||||
|
|
||||||
# Create new optimizers with the same config
|
|
||||||
new_optimizers = config.build(base_params_dict)
|
|
||||||
|
|
||||||
# Load optimizer states
|
|
||||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
|
||||||
|
|
||||||
# Verify hyperparameters match even with empty state
|
|
||||||
for name, optimizer in optimizers.items():
|
|
||||||
assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
|
|
||||||
assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
|
|
||||||
assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
|
|
||||||
|
|
||||||
# Verify state dictionaries match (they will be empty)
|
|
||||||
torch.testing.assert_close(
|
|
||||||
optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
|
|
||||||
)
|
|
||||||
@ -1,91 +0,0 @@
|
|||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
|
||||||
|
|
||||||
from lerobot.constants import SCHEDULER_STATE
|
|
||||||
from lerobot.optim.schedulers import (
|
|
||||||
CosineDecayWithWarmupSchedulerConfig,
|
|
||||||
DiffuserSchedulerConfig,
|
|
||||||
VQBeTSchedulerConfig,
|
|
||||||
load_scheduler_state,
|
|
||||||
save_scheduler_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_diffuser_scheduler(optimizer):
|
|
||||||
config = DiffuserSchedulerConfig(name="cosine", num_warmup_steps=5)
|
|
||||||
scheduler = config.build(optimizer, num_training_steps=100)
|
|
||||||
assert isinstance(scheduler, LambdaLR)
|
|
||||||
|
|
||||||
optimizer.step() # so that we don't get torch warning
|
|
||||||
scheduler.step()
|
|
||||||
expected_state_dict = {
|
|
||||||
"_get_lr_called_within_step": False,
|
|
||||||
"_last_lr": [0.0002],
|
|
||||||
"_step_count": 2,
|
|
||||||
"base_lrs": [0.001],
|
|
||||||
"last_epoch": 1,
|
|
||||||
"lr_lambdas": [None],
|
|
||||||
}
|
|
||||||
assert scheduler.state_dict() == expected_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def test_vqbet_scheduler(optimizer):
|
|
||||||
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
|
|
||||||
scheduler = config.build(optimizer, num_training_steps=100)
|
|
||||||
assert isinstance(scheduler, LambdaLR)
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
scheduler.step()
|
|
||||||
expected_state_dict = {
|
|
||||||
"_get_lr_called_within_step": False,
|
|
||||||
"_last_lr": [0.001],
|
|
||||||
"_step_count": 2,
|
|
||||||
"base_lrs": [0.001],
|
|
||||||
"last_epoch": 1,
|
|
||||||
"lr_lambdas": [None],
|
|
||||||
}
|
|
||||||
assert scheduler.state_dict() == expected_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def test_cosine_decay_with_warmup_scheduler(optimizer):
|
|
||||||
config = CosineDecayWithWarmupSchedulerConfig(
|
|
||||||
num_warmup_steps=10, num_decay_steps=90, peak_lr=0.01, decay_lr=0.001
|
|
||||||
)
|
|
||||||
scheduler = config.build(optimizer, num_training_steps=100)
|
|
||||||
assert isinstance(scheduler, LambdaLR)
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
scheduler.step()
|
|
||||||
expected_state_dict = {
|
|
||||||
"_get_lr_called_within_step": False,
|
|
||||||
"_last_lr": [0.0001818181818181819],
|
|
||||||
"_step_count": 2,
|
|
||||||
"base_lrs": [0.001],
|
|
||||||
"last_epoch": 1,
|
|
||||||
"lr_lambdas": [None],
|
|
||||||
}
|
|
||||||
assert scheduler.state_dict() == expected_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_scheduler_state(scheduler, tmp_path):
|
|
||||||
save_scheduler_state(scheduler, tmp_path)
|
|
||||||
assert (tmp_path / SCHEDULER_STATE).is_file()
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_load_scheduler_state(scheduler, tmp_path):
|
|
||||||
save_scheduler_state(scheduler, tmp_path)
|
|
||||||
loaded_scheduler = load_scheduler_state(scheduler, tmp_path)
|
|
||||||
|
|
||||||
assert scheduler.state_dict() == loaded_scheduler.state_dict()
|
|
||||||
@ -1,139 +0,0 @@
|
|||||||
# !/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
|
||||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
|
|
||||||
from tests.utils import require_package
|
|
||||||
|
|
||||||
|
|
||||||
def test_classifier_output():
|
|
||||||
output = ClassifierOutput(
|
|
||||||
logits=torch.tensor([1, 2, 3]),
|
|
||||||
probabilities=torch.tensor([0.1, 0.2, 0.3]),
|
|
||||||
hidden_states=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert (
|
|
||||||
f"{output}"
|
|
||||||
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
|
||||||
def test_binary_classifier_with_default_params():
|
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
|
||||||
|
|
||||||
config = RewardClassifierConfig()
|
|
||||||
config.input_features = {
|
|
||||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
|
||||||
}
|
|
||||||
config.output_features = {
|
|
||||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
|
||||||
}
|
|
||||||
config.normalization_mapping = {
|
|
||||||
"VISUAL": NormalizationMode.IDENTITY,
|
|
||||||
"REWARD": NormalizationMode.IDENTITY,
|
|
||||||
}
|
|
||||||
config.num_cameras = 1
|
|
||||||
classifier = Classifier(config)
|
|
||||||
|
|
||||||
batch_size = 10
|
|
||||||
|
|
||||||
input = {
|
|
||||||
"observation.image": torch.rand((batch_size, 3, 128, 128)),
|
|
||||||
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
|
|
||||||
}
|
|
||||||
|
|
||||||
images, labels = classifier.extract_images_and_labels(input)
|
|
||||||
assert len(images) == 1
|
|
||||||
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
|
|
||||||
assert labels.shape == torch.Size([batch_size])
|
|
||||||
|
|
||||||
output = classifier.predict(images)
|
|
||||||
|
|
||||||
assert output is not None
|
|
||||||
assert output.logits.size() == torch.Size([batch_size])
|
|
||||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
|
||||||
assert output.probabilities.shape == torch.Size([batch_size])
|
|
||||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
|
||||||
assert output.hidden_states.shape == torch.Size([batch_size, 256])
|
|
||||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
|
||||||
def test_multiclass_classifier():
|
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
|
||||||
|
|
||||||
num_classes = 5
|
|
||||||
config = RewardClassifierConfig()
|
|
||||||
config.input_features = {
|
|
||||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
|
||||||
}
|
|
||||||
config.output_features = {
|
|
||||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
|
|
||||||
}
|
|
||||||
config.num_cameras = 1
|
|
||||||
config.num_classes = num_classes
|
|
||||||
classifier = Classifier(config)
|
|
||||||
|
|
||||||
batch_size = 10
|
|
||||||
|
|
||||||
input = {
|
|
||||||
"observation.image": torch.rand((batch_size, 3, 128, 128)),
|
|
||||||
"next.reward": torch.rand((batch_size, num_classes)),
|
|
||||||
}
|
|
||||||
|
|
||||||
images, labels = classifier.extract_images_and_labels(input)
|
|
||||||
assert len(images) == 1
|
|
||||||
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
|
|
||||||
assert labels.shape == torch.Size([batch_size, num_classes])
|
|
||||||
|
|
||||||
output = classifier.predict(images)
|
|
||||||
|
|
||||||
assert output is not None
|
|
||||||
assert output.logits.shape == torch.Size([batch_size, num_classes])
|
|
||||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
|
||||||
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
|
|
||||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
|
||||||
assert output.hidden_states.shape == torch.Size([batch_size, 256])
|
|
||||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
|
||||||
def test_default_device():
|
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
|
||||||
|
|
||||||
config = RewardClassifierConfig()
|
|
||||||
assert config.device == "cpu"
|
|
||||||
|
|
||||||
classifier = Classifier(config)
|
|
||||||
for p in classifier.parameters():
|
|
||||||
assert p.device == torch.device("cpu")
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
|
||||||
def test_explicit_device_setup():
|
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
|
||||||
|
|
||||||
config = RewardClassifierConfig(device="cpu")
|
|
||||||
assert config.device == "cpu"
|
|
||||||
|
|
||||||
classifier = Classifier(config)
|
|
||||||
for p in classifier.parameters():
|
|
||||||
assert p.device == torch.device("cpu")
|
|
||||||
@ -1,546 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import inspect
|
|
||||||
from copy import deepcopy
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import einops
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from packaging import version
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
|
|
||||||
from lerobot import available_policies
|
|
||||||
from lerobot.configs.default import DatasetConfig
|
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
|
||||||
from lerobot.constants import ACTION, OBS_STATE
|
|
||||||
from lerobot.datasets.factory import make_dataset
|
|
||||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
|
||||||
from lerobot.envs.factory import make_env, make_env_config
|
|
||||||
from lerobot.envs.utils import preprocess_observation
|
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
|
||||||
from lerobot.policies.act.configuration_act import ACTConfig
|
|
||||||
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
|
|
||||||
from lerobot.policies.factory import (
|
|
||||||
get_policy_class,
|
|
||||||
make_policy,
|
|
||||||
make_policy_config,
|
|
||||||
)
|
|
||||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
|
||||||
from lerobot.utils.random_utils import seeded_context
|
|
||||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
|
||||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path):
|
|
||||||
# Create only one camera input which is squared to fit all current policy constraints
|
|
||||||
# e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared
|
|
||||||
camera_features = {
|
|
||||||
"observation.images.laptop": {
|
|
||||||
"shape": (84, 84, 3),
|
|
||||||
"names": ["height", "width", "channels"],
|
|
||||||
"info": None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
motor_features = {
|
|
||||||
"action": {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": (6,),
|
|
||||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
|
||||||
},
|
|
||||||
"observation.state": {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": (6,),
|
|
||||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
info = info_factory(
|
|
||||||
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
|
|
||||||
)
|
|
||||||
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
|
|
||||||
return ds_meta
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("policy_name", available_policies)
|
|
||||||
def test_get_policy_and_config_classes(policy_name: str):
|
|
||||||
"""Check that the correct policy and config classes are returned."""
|
|
||||||
policy_cls = get_policy_class(policy_name)
|
|
||||||
policy_cfg = make_policy_config(policy_name)
|
|
||||||
assert policy_cls.name == policy_name
|
|
||||||
assert issubclass(
|
|
||||||
policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs",
|
|
||||||
[
|
|
||||||
("lerobot/xarm_lift_medium", "xarm", {}, "tdmpc", {"use_mpc": True}),
|
|
||||||
("lerobot/pusht", "pusht", {}, "diffusion", {}),
|
|
||||||
("lerobot/pusht", "pusht", {}, "vqbet", {}),
|
|
||||||
("lerobot/pusht", "pusht", {}, "act", {}),
|
|
||||||
("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}),
|
|
||||||
(
|
|
||||||
"lerobot/aloha_sim_insertion_scripted",
|
|
||||||
"aloha",
|
|
||||||
{"task": "AlohaInsertion-v0"},
|
|
||||||
"act",
|
|
||||||
{},
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"lerobot/aloha_sim_insertion_human",
|
|
||||||
"aloha",
|
|
||||||
{"task": "AlohaInsertion-v0"},
|
|
||||||
"diffusion",
|
|
||||||
{},
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"lerobot/aloha_sim_transfer_cube_human",
|
|
||||||
"aloha",
|
|
||||||
{"task": "AlohaTransferCube-v0"},
|
|
||||||
"act",
|
|
||||||
{},
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"lerobot/aloha_sim_transfer_cube_scripted",
|
|
||||||
"aloha",
|
|
||||||
{"task": "AlohaTransferCube-v0"},
|
|
||||||
"act",
|
|
||||||
{},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@require_env
|
|
||||||
def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
|
||||||
"""
|
|
||||||
Tests:
|
|
||||||
- Making the policy object.
|
|
||||||
- Checking that the policy follows the correct protocol and subclasses nn.Module
|
|
||||||
and PyTorchModelHubMixin.
|
|
||||||
- Updating the policy.
|
|
||||||
- Using the policy to select actions at inference time.
|
|
||||||
- Test the action can be applied to the policy
|
|
||||||
|
|
||||||
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
|
||||||
and for now we add tests as we see fit.
|
|
||||||
"""
|
|
||||||
|
|
||||||
train_cfg = TrainPipelineConfig(
|
|
||||||
# TODO(rcadene, aliberts): remove dataset download
|
|
||||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
|
||||||
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
|
|
||||||
env=make_env_config(env_name, **env_kwargs),
|
|
||||||
)
|
|
||||||
train_cfg.validate()
|
|
||||||
|
|
||||||
# Check that we can make the policy object.
|
|
||||||
dataset = make_dataset(train_cfg)
|
|
||||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
|
||||||
assert isinstance(policy, PreTrainedPolicy)
|
|
||||||
|
|
||||||
# Check that we run select_actions and get the appropriate output.
|
|
||||||
env = make_env(train_cfg.env, n_envs=2)
|
|
||||||
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
dataset,
|
|
||||||
num_workers=0,
|
|
||||||
batch_size=2,
|
|
||||||
shuffle=True,
|
|
||||||
pin_memory=DEVICE != "cpu",
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
dl_iter = cycle(dataloader)
|
|
||||||
|
|
||||||
batch = next(dl_iter)
|
|
||||||
|
|
||||||
for key in batch:
|
|
||||||
if isinstance(batch[key], torch.Tensor):
|
|
||||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
|
||||||
|
|
||||||
# Test updating the policy (and test that it does not mutate the batch)
|
|
||||||
batch_ = deepcopy(batch)
|
|
||||||
policy.forward(batch)
|
|
||||||
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
|
|
||||||
assert all(
|
|
||||||
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
|
|
||||||
for k in batch
|
|
||||||
), "Batch values are not the same after a forward pass."
|
|
||||||
|
|
||||||
# reset the policy and environment
|
|
||||||
policy.reset()
|
|
||||||
observation, _ = env.reset(seed=train_cfg.seed)
|
|
||||||
|
|
||||||
# apply transform to normalize the observations
|
|
||||||
observation = preprocess_observation(observation)
|
|
||||||
|
|
||||||
# send observation to device/gpu
|
|
||||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
|
||||||
|
|
||||||
# get the next action for the environment (also check that the observation batch is not modified)
|
|
||||||
observation_ = deepcopy(observation)
|
|
||||||
with torch.inference_mode():
|
|
||||||
action = policy.select_action(observation).cpu().numpy()
|
|
||||||
assert set(observation) == set(observation_), (
|
|
||||||
"Observation batch keys are not the same after a forward pass."
|
|
||||||
)
|
|
||||||
assert all(torch.equal(observation[k], observation_[k]) for k in observation), (
|
|
||||||
"Observation batch values are not the same after a forward pass."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test step through policy
|
|
||||||
env.step(action)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(rcadene, aliberts): This test is quite end-to-end. Move this test in test_optimizer?
|
|
||||||
def test_act_backbone_lr():
|
|
||||||
"""
|
|
||||||
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
|
|
||||||
"""
|
|
||||||
|
|
||||||
cfg = TrainPipelineConfig(
|
|
||||||
# TODO(rcadene, aliberts): remove dataset download
|
|
||||||
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
|
||||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
|
|
||||||
)
|
|
||||||
cfg.validate() # Needed for auto-setting some parameters
|
|
||||||
|
|
||||||
assert cfg.policy.optimizer_lr == 0.01
|
|
||||||
assert cfg.policy.optimizer_lr_backbone == 0.001
|
|
||||||
|
|
||||||
dataset = make_dataset(cfg)
|
|
||||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
|
||||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
|
||||||
assert len(optimizer.param_groups) == 2
|
|
||||||
assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr
|
|
||||||
assert optimizer.param_groups[1]["lr"] == cfg.policy.optimizer_lr_backbone
|
|
||||||
assert len(optimizer.param_groups[0]["params"]) == 133
|
|
||||||
assert len(optimizer.param_groups[1]["params"]) == 20
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("policy_name", available_policies)
|
|
||||||
def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
|
|
||||||
"""Check that the policy can be instantiated with defaults."""
|
|
||||||
policy_cls = get_policy_class(policy_name)
|
|
||||||
policy_cfg = make_policy_config(policy_name)
|
|
||||||
features = dataset_to_policy_features(dummy_dataset_metadata.features)
|
|
||||||
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
|
||||||
policy_cfg.input_features = {
|
|
||||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
|
||||||
}
|
|
||||||
policy_cls(policy_cfg)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("policy_name", available_policies)
|
|
||||||
def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str):
|
|
||||||
policy_cls = get_policy_class(policy_name)
|
|
||||||
policy_cfg = make_policy_config(policy_name)
|
|
||||||
features = dataset_to_policy_features(dummy_dataset_metadata.features)
|
|
||||||
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
|
||||||
policy_cfg.input_features = {
|
|
||||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
|
||||||
}
|
|
||||||
policy = policy_cls(policy_cfg)
|
|
||||||
policy.to(policy_cfg.device)
|
|
||||||
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
|
|
||||||
policy.save_pretrained(save_dir)
|
|
||||||
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
|
||||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
|
||||||
def test_normalize(insert_temporal_dim):
|
|
||||||
"""
|
|
||||||
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
|
|
||||||
an exception when the forward pass is called without the stats having been provided.
|
|
||||||
|
|
||||||
TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
|
|
||||||
expected.
|
|
||||||
"""
|
|
||||||
|
|
||||||
input_features = {
|
|
||||||
"observation.image": PolicyFeature(
|
|
||||||
type=FeatureType.VISUAL,
|
|
||||||
shape=(3, 96, 96),
|
|
||||||
),
|
|
||||||
"observation.state": PolicyFeature(
|
|
||||||
type=FeatureType.STATE,
|
|
||||||
shape=(10,),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
output_features = {
|
|
||||||
"action": PolicyFeature(
|
|
||||||
type=FeatureType.ACTION,
|
|
||||||
shape=(5,),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
norm_map = {
|
|
||||||
"VISUAL": NormalizationMode.MEAN_STD,
|
|
||||||
"STATE": NormalizationMode.MIN_MAX,
|
|
||||||
"ACTION": NormalizationMode.MIN_MAX,
|
|
||||||
}
|
|
||||||
|
|
||||||
dataset_stats = {
|
|
||||||
"observation.image": {
|
|
||||||
"mean": torch.randn(3, 1, 1),
|
|
||||||
"std": torch.randn(3, 1, 1),
|
|
||||||
"min": torch.randn(3, 1, 1),
|
|
||||||
"max": torch.randn(3, 1, 1),
|
|
||||||
},
|
|
||||||
"observation.state": {
|
|
||||||
"mean": torch.randn(10),
|
|
||||||
"std": torch.randn(10),
|
|
||||||
"min": torch.randn(10),
|
|
||||||
"max": torch.randn(10),
|
|
||||||
},
|
|
||||||
"action": {
|
|
||||||
"mean": torch.randn(5),
|
|
||||||
"std": torch.randn(5),
|
|
||||||
"min": torch.randn(5),
|
|
||||||
"max": torch.randn(5),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
bsize = 2
|
|
||||||
input_batch = {
|
|
||||||
"observation.image": torch.randn(bsize, 3, 96, 96),
|
|
||||||
"observation.state": torch.randn(bsize, 10),
|
|
||||||
}
|
|
||||||
output_batch = {
|
|
||||||
"action": torch.randn(bsize, 5),
|
|
||||||
}
|
|
||||||
|
|
||||||
if insert_temporal_dim:
|
|
||||||
tdim = 4
|
|
||||||
|
|
||||||
for key in input_batch:
|
|
||||||
# [2,3,96,96] -> [2,tdim,3,96,96]
|
|
||||||
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
|
|
||||||
|
|
||||||
for key in output_batch:
|
|
||||||
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
|
|
||||||
|
|
||||||
# test without stats
|
|
||||||
normalize = Normalize(input_features, norm_map, stats=None)
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
normalize(input_batch)
|
|
||||||
|
|
||||||
# test with stats
|
|
||||||
normalize = Normalize(input_features, norm_map, stats=dataset_stats)
|
|
||||||
normalize(input_batch)
|
|
||||||
|
|
||||||
# test loading pretrained models
|
|
||||||
new_normalize = Normalize(input_features, norm_map, stats=None)
|
|
||||||
new_normalize.load_state_dict(normalize.state_dict())
|
|
||||||
new_normalize(input_batch)
|
|
||||||
|
|
||||||
# test without stats
|
|
||||||
unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
unnormalize(output_batch)
|
|
||||||
|
|
||||||
# test with stats
|
|
||||||
unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats)
|
|
||||||
unnormalize(output_batch)
|
|
||||||
|
|
||||||
# test loading pretrained models
|
|
||||||
new_unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
|
||||||
new_unnormalize.load_state_dict(unnormalize.state_dict())
|
|
||||||
unnormalize(output_batch)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("multikey", [True, False])
|
|
||||||
def test_multikey_construction(multikey: bool):
|
|
||||||
"""
|
|
||||||
Asserts that multiple keys with type State/Action are correctly processed by the policy constructor,
|
|
||||||
preventing erroneous creation of the policy object.
|
|
||||||
"""
|
|
||||||
input_features = {
|
|
||||||
"observation.state": PolicyFeature(
|
|
||||||
type=FeatureType.STATE,
|
|
||||||
shape=(10,),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
output_features = {
|
|
||||||
"action": PolicyFeature(
|
|
||||||
type=FeatureType.ACTION,
|
|
||||||
shape=(5,),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
if multikey:
|
|
||||||
"""Simulates the complete state/action is constructed from more granular multiple
|
|
||||||
keys, of the same type as the overall state/action"""
|
|
||||||
input_features = {}
|
|
||||||
input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
|
|
||||||
input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
|
|
||||||
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
|
|
||||||
|
|
||||||
output_features = {}
|
|
||||||
output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))
|
|
||||||
output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,))
|
|
||||||
output_features["action"] = PolicyFeature(
|
|
||||||
type=FeatureType.ACTION,
|
|
||||||
shape=(5,),
|
|
||||||
)
|
|
||||||
|
|
||||||
config = ACTConfig(input_features=input_features, output_features=output_features)
|
|
||||||
|
|
||||||
state_condition = config.robot_state_feature == input_features[OBS_STATE]
|
|
||||||
action_condition = config.action_feature == output_features[ACTION]
|
|
||||||
|
|
||||||
assert state_condition, (
|
|
||||||
f"Discrepancy detected. Robot state feature is {config.robot_state_feature} but policy expects {input_features[OBS_STATE]}"
|
|
||||||
)
|
|
||||||
assert action_condition, (
|
|
||||||
f"Discrepancy detected. Action feature is {config.action_feature} but policy expects {output_features[ACTION]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"ds_repo_id, policy_name, policy_kwargs, file_name_extra",
|
|
||||||
[
|
|
||||||
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
|
|
||||||
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
|
|
||||||
# to test with `policy.use_mpc=false`.
|
|
||||||
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
|
|
||||||
# ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
|
|
||||||
# TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
|
|
||||||
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
|
|
||||||
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
|
|
||||||
# Thus, we deactivate this test for now.
|
|
||||||
(
|
|
||||||
"lerobot/pusht",
|
|
||||||
"diffusion",
|
|
||||||
{
|
|
||||||
"n_action_steps": 8,
|
|
||||||
"num_inference_steps": 10,
|
|
||||||
"down_dims": [128, 256, 512],
|
|
||||||
},
|
|
||||||
"",
|
|
||||||
),
|
|
||||||
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
|
|
||||||
(
|
|
||||||
"lerobot/aloha_sim_insertion_human",
|
|
||||||
"act",
|
|
||||||
{"n_action_steps": 1000, "chunk_size": 1000},
|
|
||||||
"1000_steps",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
# As artifacts have been generated on an x86_64 kernel, this test won't
|
|
||||||
# pass if it's run on another platform due to floating point errors
|
|
||||||
@require_x86_64_kernel
|
|
||||||
@require_cpu
|
|
||||||
def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str):
|
|
||||||
"""
|
|
||||||
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
|
|
||||||
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
|
|
||||||
include a report on what changed and how that affected the outputs.
|
|
||||||
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
|
|
||||||
add the policies you want to update the test artifacts for.
|
|
||||||
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact
|
|
||||||
should be updated.
|
|
||||||
4. Check that this test now passes.
|
|
||||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
|
||||||
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
|
|
||||||
|
|
||||||
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
|
|
||||||
is out of date. For example, some PyTorch versions have different randomness, see this PR:
|
|
||||||
https://github.com/huggingface/lerobot/pull/1127.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
# NOTE: ACT policy has different randomness, after PyTorch 2.7.0
|
|
||||||
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
|
|
||||||
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
|
|
||||||
|
|
||||||
ds_name = ds_repo_id.split("/")[-1]
|
|
||||||
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
|
||||||
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
|
|
||||||
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
|
|
||||||
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
|
|
||||||
saved_actions = load_file(artifact_dir / "actions.safetensors")
|
|
||||||
|
|
||||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
|
|
||||||
|
|
||||||
for key in saved_output_dict:
|
|
||||||
torch.testing.assert_close(output_dict[key], saved_output_dict[key])
|
|
||||||
for key in saved_grad_stats:
|
|
||||||
torch.testing.assert_close(grad_stats[key], saved_grad_stats[key])
|
|
||||||
for key in saved_param_stats:
|
|
||||||
torch.testing.assert_close(param_stats[key], saved_param_stats[key])
|
|
||||||
for key in saved_actions:
|
|
||||||
rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK
|
|
||||||
torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol)
|
|
||||||
|
|
||||||
|
|
||||||
def test_act_temporal_ensembler():
|
|
||||||
"""Check that the online method in ACTTemporalEnsembler matches a simple offline calculation."""
|
|
||||||
temporal_ensemble_coeff = 0.01
|
|
||||||
chunk_size = 100
|
|
||||||
episode_length = 101
|
|
||||||
ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size)
|
|
||||||
# An batch of arbitrary sequences of 1D actions we wish to compute the average over. We'll keep the
|
|
||||||
# "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen.
|
|
||||||
with seeded_context(0):
|
|
||||||
# Dimension is (batch, episode_length, chunk_size, action_dim(=1))
|
|
||||||
# Stepping through the episode_length dim is like running inference at each rollout step and getting
|
|
||||||
# a different action chunk.
|
|
||||||
batch_seq = torch.stack(
|
|
||||||
[
|
|
||||||
torch.rand(episode_length, chunk_size) * 0.05 - 0.6,
|
|
||||||
torch.rand(episode_length, chunk_size) * 0.02 - 0.01,
|
|
||||||
torch.rand(episode_length, chunk_size) * 0.2 + 0.3,
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
).unsqueeze(-1) # unsqueeze for action dim
|
|
||||||
batch_size = batch_seq.shape[0]
|
|
||||||
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
|
|
||||||
# dimension of `batch_seq`.
|
|
||||||
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
|
|
||||||
|
|
||||||
# Simulate stepping through a rollout and computing a batch of actions with model on each step.
|
|
||||||
for i in range(episode_length):
|
|
||||||
# Mock a batch of actions.
|
|
||||||
actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i]
|
|
||||||
online_avg = ensembler.update(actions)
|
|
||||||
# Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
|
|
||||||
# Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
|
|
||||||
# What we want to do is take diagonal slices across it starting from the left.
|
|
||||||
# eg: chunk_size=4, episode_length=6
|
|
||||||
# ┌───────┐
|
|
||||||
# │0 1 2 3│
|
|
||||||
# │1 2 3 4│
|
|
||||||
# │2 3 4 5│
|
|
||||||
# │3 4 5 6│
|
|
||||||
# │4 5 6 7│
|
|
||||||
# │5 6 7 8│
|
|
||||||
# └───────┘
|
|
||||||
chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
|
|
||||||
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
|
|
||||||
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
|
|
||||||
offline_avg = (
|
|
||||||
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
|
|
||||||
)
|
|
||||||
# Sanity check. The average should be between the extrema.
|
|
||||||
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
|
|
||||||
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
|
|
||||||
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
|
|
||||||
torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)
|
|
||||||
@ -1,217 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
|
||||||
from lerobot.policies.sac.configuration_sac import (
|
|
||||||
ActorLearnerConfig,
|
|
||||||
ActorNetworkConfig,
|
|
||||||
ConcurrencyConfig,
|
|
||||||
CriticNetworkConfig,
|
|
||||||
PolicyConfig,
|
|
||||||
SACConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sac_config_default_initialization():
|
|
||||||
config = SACConfig()
|
|
||||||
|
|
||||||
assert config.normalization_mapping == {
|
|
||||||
"VISUAL": NormalizationMode.MEAN_STD,
|
|
||||||
"STATE": NormalizationMode.MIN_MAX,
|
|
||||||
"ENV": NormalizationMode.MIN_MAX,
|
|
||||||
"ACTION": NormalizationMode.MIN_MAX,
|
|
||||||
}
|
|
||||||
assert config.dataset_stats == {
|
|
||||||
"observation.image": {
|
|
||||||
"mean": [0.485, 0.456, 0.406],
|
|
||||||
"std": [0.229, 0.224, 0.225],
|
|
||||||
},
|
|
||||||
"observation.state": {
|
|
||||||
"min": [0.0, 0.0],
|
|
||||||
"max": [1.0, 1.0],
|
|
||||||
},
|
|
||||||
"action": {
|
|
||||||
"min": [0.0, 0.0, 0.0],
|
|
||||||
"max": [1.0, 1.0, 1.0],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Basic parameters
|
|
||||||
assert config.device == "cpu"
|
|
||||||
assert config.storage_device == "cpu"
|
|
||||||
assert config.discount == 0.99
|
|
||||||
assert config.temperature_init == 1.0
|
|
||||||
assert config.num_critics == 2
|
|
||||||
|
|
||||||
# Architecture specifics
|
|
||||||
assert config.vision_encoder_name is None
|
|
||||||
assert config.freeze_vision_encoder is True
|
|
||||||
assert config.image_encoder_hidden_dim == 32
|
|
||||||
assert config.shared_encoder is True
|
|
||||||
assert config.num_discrete_actions is None
|
|
||||||
assert config.image_embedding_pooling_dim == 8
|
|
||||||
|
|
||||||
# Training parameters
|
|
||||||
assert config.online_steps == 1000000
|
|
||||||
assert config.online_env_seed == 10000
|
|
||||||
assert config.online_buffer_capacity == 100000
|
|
||||||
assert config.offline_buffer_capacity == 100000
|
|
||||||
assert config.async_prefetch is False
|
|
||||||
assert config.online_step_before_learning == 100
|
|
||||||
assert config.policy_update_freq == 1
|
|
||||||
|
|
||||||
# SAC algorithm parameters
|
|
||||||
assert config.num_subsample_critics is None
|
|
||||||
assert config.critic_lr == 3e-4
|
|
||||||
assert config.actor_lr == 3e-4
|
|
||||||
assert config.temperature_lr == 3e-4
|
|
||||||
assert config.critic_target_update_weight == 0.005
|
|
||||||
assert config.utd_ratio == 1
|
|
||||||
assert config.state_encoder_hidden_dim == 256
|
|
||||||
assert config.latent_dim == 256
|
|
||||||
assert config.target_entropy is None
|
|
||||||
assert config.use_backup_entropy is True
|
|
||||||
assert config.grad_clip_norm == 40.0
|
|
||||||
|
|
||||||
# Dataset stats defaults
|
|
||||||
expected_dataset_stats = {
|
|
||||||
"observation.image": {
|
|
||||||
"mean": [0.485, 0.456, 0.406],
|
|
||||||
"std": [0.229, 0.224, 0.225],
|
|
||||||
},
|
|
||||||
"observation.state": {
|
|
||||||
"min": [0.0, 0.0],
|
|
||||||
"max": [1.0, 1.0],
|
|
||||||
},
|
|
||||||
"action": {
|
|
||||||
"min": [0.0, 0.0, 0.0],
|
|
||||||
"max": [1.0, 1.0, 1.0],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
assert config.dataset_stats == expected_dataset_stats
|
|
||||||
|
|
||||||
# Critic network configuration
|
|
||||||
assert config.critic_network_kwargs.hidden_dims == [256, 256]
|
|
||||||
assert config.critic_network_kwargs.activate_final is True
|
|
||||||
assert config.critic_network_kwargs.final_activation is None
|
|
||||||
|
|
||||||
# Actor network configuration
|
|
||||||
assert config.actor_network_kwargs.hidden_dims == [256, 256]
|
|
||||||
assert config.actor_network_kwargs.activate_final is True
|
|
||||||
|
|
||||||
# Policy configuration
|
|
||||||
assert config.policy_kwargs.use_tanh_squash is True
|
|
||||||
assert config.policy_kwargs.std_min == 1e-5
|
|
||||||
assert config.policy_kwargs.std_max == 10.0
|
|
||||||
assert config.policy_kwargs.init_final == 0.05
|
|
||||||
|
|
||||||
# Discrete critic network configuration
|
|
||||||
assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256]
|
|
||||||
assert config.discrete_critic_network_kwargs.activate_final is True
|
|
||||||
assert config.discrete_critic_network_kwargs.final_activation is None
|
|
||||||
|
|
||||||
# Actor learner configuration
|
|
||||||
assert config.actor_learner_config.learner_host == "127.0.0.1"
|
|
||||||
assert config.actor_learner_config.learner_port == 50051
|
|
||||||
assert config.actor_learner_config.policy_parameters_push_frequency == 4
|
|
||||||
|
|
||||||
# Concurrency configuration
|
|
||||||
assert config.concurrency.actor == "threads"
|
|
||||||
assert config.concurrency.learner == "threads"
|
|
||||||
|
|
||||||
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
|
|
||||||
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
|
|
||||||
assert isinstance(config.policy_kwargs, PolicyConfig)
|
|
||||||
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
|
|
||||||
assert isinstance(config.concurrency, ConcurrencyConfig)
|
|
||||||
|
|
||||||
|
|
||||||
def test_critic_network_kwargs():
|
|
||||||
config = CriticNetworkConfig()
|
|
||||||
assert config.hidden_dims == [256, 256]
|
|
||||||
assert config.activate_final is True
|
|
||||||
assert config.final_activation is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_actor_network_kwargs():
|
|
||||||
config = ActorNetworkConfig()
|
|
||||||
assert config.hidden_dims == [256, 256]
|
|
||||||
assert config.activate_final is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_policy_kwargs():
|
|
||||||
config = PolicyConfig()
|
|
||||||
assert config.use_tanh_squash is True
|
|
||||||
assert config.std_min == 1e-5
|
|
||||||
assert config.std_max == 10.0
|
|
||||||
assert config.init_final == 0.05
|
|
||||||
|
|
||||||
|
|
||||||
def test_actor_learner_config():
|
|
||||||
config = ActorLearnerConfig()
|
|
||||||
assert config.learner_host == "127.0.0.1"
|
|
||||||
assert config.learner_port == 50051
|
|
||||||
assert config.policy_parameters_push_frequency == 4
|
|
||||||
|
|
||||||
|
|
||||||
def test_concurrency_config():
|
|
||||||
config = ConcurrencyConfig()
|
|
||||||
assert config.actor == "threads"
|
|
||||||
assert config.learner == "threads"
|
|
||||||
|
|
||||||
|
|
||||||
def test_sac_config_custom_initialization():
|
|
||||||
config = SACConfig(
|
|
||||||
device="cpu",
|
|
||||||
discount=0.95,
|
|
||||||
temperature_init=0.5,
|
|
||||||
num_critics=3,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert config.device == "cpu"
|
|
||||||
assert config.discount == 0.95
|
|
||||||
assert config.temperature_init == 0.5
|
|
||||||
assert config.num_critics == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_features():
|
|
||||||
config = SACConfig(
|
|
||||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
|
||||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
|
||||||
)
|
|
||||||
config.validate_features()
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_features_missing_observation():
|
|
||||||
config = SACConfig(
|
|
||||||
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
|
||||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
|
||||||
)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match="You must provide either 'observation.state' or an image observation"
|
|
||||||
):
|
|
||||||
config.validate_features()
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_features_missing_action():
|
|
||||||
config = SACConfig(
|
|
||||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
|
||||||
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
|
||||||
)
|
|
||||||
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
|
|
||||||
config.validate_features()
|
|
||||||
@ -1,541 +0,0 @@
|
|||||||
# !/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import math
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
|
||||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
|
||||||
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
|
|
||||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
|
||||||
|
|
||||||
try:
|
|
||||||
import transformers # noqa: F401
|
|
||||||
|
|
||||||
TRANSFORMERS_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
TRANSFORMERS_AVAILABLE = False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def set_random_seed():
|
|
||||||
seed = 42
|
|
||||||
set_seed(seed)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mlp_with_default_args():
|
|
||||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
|
||||||
|
|
||||||
x = torch.randn(10)
|
|
||||||
y = mlp(x)
|
|
||||||
assert y.shape == (256,)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mlp_with_batch_dim():
|
|
||||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
|
||||||
x = torch.randn(2, 10)
|
|
||||||
y = mlp(x)
|
|
||||||
assert y.shape == (2, 256)
|
|
||||||
|
|
||||||
|
|
||||||
def test_forward_with_empty_hidden_dims():
|
|
||||||
mlp = MLP(input_dim=10, hidden_dims=[])
|
|
||||||
x = torch.randn(1, 10)
|
|
||||||
assert mlp(x).shape == (1, 10)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mlp_with_dropout():
|
|
||||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
|
|
||||||
x = torch.randn(1, 10)
|
|
||||||
y = mlp(x)
|
|
||||||
assert y.shape == (1, 11)
|
|
||||||
|
|
||||||
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
|
|
||||||
assert drop_out_layers_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_mlp_with_custom_final_activation():
|
|
||||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
|
|
||||||
x = torch.randn(1, 10)
|
|
||||||
y = mlp(x)
|
|
||||||
assert y.shape == (1, 256)
|
|
||||||
assert (y >= -1).all() and (y <= 1).all()
|
|
||||||
|
|
||||||
|
|
||||||
def test_sac_policy_with_default_args():
|
|
||||||
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
|
|
||||||
SACPolicy()
|
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
|
|
||||||
return {
|
|
||||||
"observation.state": torch.randn(batch_size, state_dim),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
|
|
||||||
return {
|
|
||||||
"observation.image": torch.randn(batch_size, 3, 84, 84),
|
|
||||||
"observation.state": torch.randn(batch_size, state_dim),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
|
|
||||||
return torch.randn(batch_size, action_dim)
|
|
||||||
|
|
||||||
|
|
||||||
def create_default_train_batch(
|
|
||||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
|
||||||
) -> dict[str, Tensor]:
|
|
||||||
return {
|
|
||||||
"action": create_dummy_action(batch_size, action_dim),
|
|
||||||
"reward": torch.randn(batch_size),
|
|
||||||
"state": create_dummy_state(batch_size, state_dim),
|
|
||||||
"next_state": create_dummy_state(batch_size, state_dim),
|
|
||||||
"done": torch.randn(batch_size),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_train_batch_with_visual_input(
|
|
||||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
|
||||||
) -> dict[str, Tensor]:
|
|
||||||
return {
|
|
||||||
"action": create_dummy_action(batch_size, action_dim),
|
|
||||||
"reward": torch.randn(batch_size),
|
|
||||||
"state": create_dummy_with_visual_input(batch_size, state_dim),
|
|
||||||
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
|
|
||||||
"done": torch.randn(batch_size),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
|
||||||
return {
|
|
||||||
"observation.state": torch.randn(batch_size, state_dim),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
|
||||||
return {
|
|
||||||
"observation.state": torch.randn(batch_size, state_dim),
|
|
||||||
"observation.image": torch.randn(batch_size, 3, 84, 84),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
|
|
||||||
"""Create optimizers for the SAC policy."""
|
|
||||||
optimizer_actor = torch.optim.Adam(
|
|
||||||
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
|
|
||||||
params=[
|
|
||||||
p
|
|
||||||
for n, p in policy.actor.named_parameters()
|
|
||||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
|
||||||
],
|
|
||||||
lr=policy.config.actor_lr,
|
|
||||||
)
|
|
||||||
optimizer_critic = torch.optim.Adam(
|
|
||||||
params=policy.critic_ensemble.parameters(),
|
|
||||||
lr=policy.config.critic_lr,
|
|
||||||
)
|
|
||||||
optimizer_temperature = torch.optim.Adam(
|
|
||||||
params=[policy.log_alpha],
|
|
||||||
lr=policy.config.critic_lr,
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizers = {
|
|
||||||
"actor": optimizer_actor,
|
|
||||||
"critic": optimizer_critic,
|
|
||||||
"temperature": optimizer_temperature,
|
|
||||||
}
|
|
||||||
|
|
||||||
if has_discrete_action:
|
|
||||||
optimizers["discrete_critic"] = torch.optim.Adam(
|
|
||||||
params=policy.discrete_critic.parameters(),
|
|
||||||
lr=policy.config.critic_lr,
|
|
||||||
)
|
|
||||||
|
|
||||||
return optimizers
|
|
||||||
|
|
||||||
|
|
||||||
def create_default_config(
|
|
||||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
|
||||||
) -> SACConfig:
|
|
||||||
action_dim = continuous_action_dim
|
|
||||||
if has_discrete_action:
|
|
||||||
action_dim += 1
|
|
||||||
|
|
||||||
config = SACConfig(
|
|
||||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
|
||||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
|
|
||||||
dataset_stats={
|
|
||||||
"observation.state": {
|
|
||||||
"min": [0.0] * state_dim,
|
|
||||||
"max": [1.0] * state_dim,
|
|
||||||
},
|
|
||||||
"action": {
|
|
||||||
"min": [0.0] * continuous_action_dim,
|
|
||||||
"max": [1.0] * continuous_action_dim,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
config.validate_features()
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def create_config_with_visual_input(
|
|
||||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
|
||||||
) -> SACConfig:
|
|
||||||
config = create_default_config(
|
|
||||||
state_dim=state_dim,
|
|
||||||
continuous_action_dim=continuous_action_dim,
|
|
||||||
has_discrete_action=has_discrete_action,
|
|
||||||
)
|
|
||||||
config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
|
||||||
config.dataset_stats["observation.image"] = {
|
|
||||||
"mean": torch.randn(3, 1, 1),
|
|
||||||
"std": torch.randn(3, 1, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Let make tests a little bit faster
|
|
||||||
config.state_encoder_hidden_dim = 32
|
|
||||||
config.latent_dim = 32
|
|
||||||
|
|
||||||
config.validate_features()
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
|
||||||
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
|
|
||||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
|
||||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
|
||||||
|
|
||||||
policy = SACPolicy(config=config)
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
optimizers = make_optimizers(policy)
|
|
||||||
|
|
||||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
|
||||||
assert cirtic_loss.item() is not None
|
|
||||||
assert cirtic_loss.shape == ()
|
|
||||||
cirtic_loss.backward()
|
|
||||||
optimizers["critic"].step()
|
|
||||||
|
|
||||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
|
||||||
assert actor_loss.item() is not None
|
|
||||||
assert actor_loss.shape == ()
|
|
||||||
|
|
||||||
actor_loss.backward()
|
|
||||||
optimizers["actor"].step()
|
|
||||||
|
|
||||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
|
||||||
assert temperature_loss.item() is not None
|
|
||||||
assert temperature_loss.shape == ()
|
|
||||||
|
|
||||||
temperature_loss.backward()
|
|
||||||
optimizers["temperature"].step()
|
|
||||||
|
|
||||||
policy.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
|
||||||
selected_action = policy.select_action(observation_batch)
|
|
||||||
assert selected_action.shape == (batch_size, action_dim)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
|
||||||
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
|
||||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
|
||||||
policy = SACPolicy(config=config)
|
|
||||||
|
|
||||||
batch = create_train_batch_with_visual_input(
|
|
||||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
optimizers = make_optimizers(policy)
|
|
||||||
|
|
||||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
|
||||||
assert cirtic_loss.item() is not None
|
|
||||||
assert cirtic_loss.shape == ()
|
|
||||||
cirtic_loss.backward()
|
|
||||||
optimizers["critic"].step()
|
|
||||||
|
|
||||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
|
||||||
assert actor_loss.item() is not None
|
|
||||||
assert actor_loss.shape == ()
|
|
||||||
|
|
||||||
actor_loss.backward()
|
|
||||||
optimizers["actor"].step()
|
|
||||||
|
|
||||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
|
||||||
assert temperature_loss.item() is not None
|
|
||||||
assert temperature_loss.shape == ()
|
|
||||||
|
|
||||||
temperature_loss.backward()
|
|
||||||
optimizers["temperature"].step()
|
|
||||||
|
|
||||||
policy.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
observation_batch = create_observation_batch_with_visual_input(
|
|
||||||
batch_size=batch_size, state_dim=state_dim
|
|
||||||
)
|
|
||||||
selected_action = policy.select_action(observation_batch)
|
|
||||||
assert selected_action.shape == (batch_size, action_dim)
|
|
||||||
|
|
||||||
|
|
||||||
# Let's check best candidates for pretrained encoders
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"batch_size,state_dim,action_dim,vision_encoder_name",
|
|
||||||
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
|
||||||
)
|
|
||||||
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
|
||||||
def test_sac_policy_with_pretrained_encoder(
|
|
||||||
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
|
||||||
):
|
|
||||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
|
||||||
config.vision_encoder_name = vision_encoder_name
|
|
||||||
policy = SACPolicy(config=config)
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
batch = create_train_batch_with_visual_input(
|
|
||||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizers = make_optimizers(policy)
|
|
||||||
|
|
||||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
|
||||||
assert cirtic_loss.item() is not None
|
|
||||||
assert cirtic_loss.shape == ()
|
|
||||||
cirtic_loss.backward()
|
|
||||||
optimizers["critic"].step()
|
|
||||||
|
|
||||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
|
||||||
assert actor_loss.item() is not None
|
|
||||||
assert actor_loss.shape == ()
|
|
||||||
|
|
||||||
|
|
||||||
def test_sac_policy_with_shared_encoder():
|
|
||||||
batch_size = 2
|
|
||||||
action_dim = 10
|
|
||||||
state_dim = 10
|
|
||||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
|
||||||
config.shared_encoder = True
|
|
||||||
|
|
||||||
policy = SACPolicy(config=config)
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
batch = create_train_batch_with_visual_input(
|
|
||||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
optimizers = make_optimizers(policy)
|
|
||||||
|
|
||||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
|
||||||
assert cirtic_loss.item() is not None
|
|
||||||
assert cirtic_loss.shape == ()
|
|
||||||
cirtic_loss.backward()
|
|
||||||
optimizers["critic"].step()
|
|
||||||
|
|
||||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
|
||||||
assert actor_loss.item() is not None
|
|
||||||
assert actor_loss.shape == ()
|
|
||||||
|
|
||||||
actor_loss.backward()
|
|
||||||
optimizers["actor"].step()
|
|
||||||
|
|
||||||
|
|
||||||
def test_sac_policy_with_discrete_critic():
|
|
||||||
batch_size = 2
|
|
||||||
continuous_action_dim = 9
|
|
||||||
full_action_dim = continuous_action_dim + 1 # the last action is discrete
|
|
||||||
state_dim = 10
|
|
||||||
config = create_config_with_visual_input(
|
|
||||||
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
|
|
||||||
)
|
|
||||||
|
|
||||||
num_discrete_actions = 5
|
|
||||||
config.num_discrete_actions = num_discrete_actions
|
|
||||||
|
|
||||||
policy = SACPolicy(config=config)
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
batch = create_train_batch_with_visual_input(
|
|
||||||
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
optimizers = make_optimizers(policy, has_discrete_action=True)
|
|
||||||
|
|
||||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
|
||||||
assert cirtic_loss.item() is not None
|
|
||||||
assert cirtic_loss.shape == ()
|
|
||||||
cirtic_loss.backward()
|
|
||||||
optimizers["critic"].step()
|
|
||||||
|
|
||||||
discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
|
|
||||||
assert discrete_critic_loss.item() is not None
|
|
||||||
assert discrete_critic_loss.shape == ()
|
|
||||||
discrete_critic_loss.backward()
|
|
||||||
optimizers["discrete_critic"].step()
|
|
||||||
|
|
||||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
|
||||||
assert actor_loss.item() is not None
|
|
||||||
assert actor_loss.shape == ()
|
|
||||||
|
|
||||||
actor_loss.backward()
|
|
||||||
optimizers["actor"].step()
|
|
||||||
|
|
||||||
policy.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
observation_batch = create_observation_batch_with_visual_input(
|
|
||||||
batch_size=batch_size, state_dim=state_dim
|
|
||||||
)
|
|
||||||
selected_action = policy.select_action(observation_batch)
|
|
||||||
assert selected_action.shape == (batch_size, full_action_dim)
|
|
||||||
|
|
||||||
discrete_actions = selected_action[:, -1].long()
|
|
||||||
discrete_action_values = set(discrete_actions.tolist())
|
|
||||||
|
|
||||||
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
|
|
||||||
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sac_policy_with_default_entropy():
|
|
||||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
|
||||||
policy = SACPolicy(config=config)
|
|
||||||
assert policy.target_entropy == -5.0
|
|
||||||
|
|
||||||
|
|
||||||
def test_sac_policy_default_target_entropy_with_discrete_action():
|
|
||||||
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
|
||||||
policy = SACPolicy(config=config)
|
|
||||||
assert policy.target_entropy == -3.0
|
|
||||||
|
|
||||||
|
|
||||||
def test_sac_policy_with_predefined_entropy():
|
|
||||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
|
||||||
config.target_entropy = -3.5
|
|
||||||
|
|
||||||
policy = SACPolicy(config=config)
|
|
||||||
assert policy.target_entropy == pytest.approx(-3.5)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sac_policy_update_temperature():
|
|
||||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
|
||||||
policy = SACPolicy(config=config)
|
|
||||||
|
|
||||||
assert policy.temperature == pytest.approx(1.0)
|
|
||||||
policy.log_alpha.data = torch.tensor([math.log(0.1)])
|
|
||||||
policy.update_temperature()
|
|
||||||
assert policy.temperature == pytest.approx(0.1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sac_policy_update_target_network():
|
|
||||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
|
||||||
config.critic_target_update_weight = 1.0
|
|
||||||
|
|
||||||
policy = SACPolicy(config=config)
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
for p in policy.critic_ensemble.parameters():
|
|
||||||
p.data = torch.ones_like(p.data)
|
|
||||||
|
|
||||||
policy.update_target_networks()
|
|
||||||
for p in policy.critic_target.parameters():
|
|
||||||
assert torch.allclose(p.data, torch.ones_like(p.data)), (
|
|
||||||
f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_critics", [1, 3])
|
|
||||||
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
|
|
||||||
batch_size = 2
|
|
||||||
action_dim = 10
|
|
||||||
state_dim = 10
|
|
||||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
|
||||||
config.num_critics = num_critics
|
|
||||||
|
|
||||||
policy = SACPolicy(config=config)
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
assert len(policy.critic_ensemble.critics) == num_critics
|
|
||||||
|
|
||||||
batch = create_train_batch_with_visual_input(
|
|
||||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
optimizers = make_optimizers(policy)
|
|
||||||
|
|
||||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
|
||||||
assert cirtic_loss.item() is not None
|
|
||||||
assert cirtic_loss.shape == ()
|
|
||||||
cirtic_loss.backward()
|
|
||||||
optimizers["critic"].step()
|
|
||||||
|
|
||||||
|
|
||||||
def test_sac_policy_save_and_load(tmp_path):
|
|
||||||
root = tmp_path / "test_sac_save_and_load"
|
|
||||||
|
|
||||||
state_dim = 10
|
|
||||||
action_dim = 10
|
|
||||||
batch_size = 2
|
|
||||||
|
|
||||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
|
||||||
policy = SACPolicy(config=config)
|
|
||||||
policy.eval()
|
|
||||||
policy.save_pretrained(root)
|
|
||||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
|
||||||
loaded_policy.eval()
|
|
||||||
|
|
||||||
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
with seeded_context(12):
|
|
||||||
# Collect policy values before saving
|
|
||||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
|
||||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
|
||||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
|
||||||
|
|
||||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
|
||||||
actions = policy.select_action(observation_batch)
|
|
||||||
|
|
||||||
with seeded_context(12):
|
|
||||||
# Collect policy values after loading
|
|
||||||
loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
|
|
||||||
loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
|
|
||||||
loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
|
|
||||||
|
|
||||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
|
||||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
|
||||||
|
|
||||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
|
||||||
for k in policy.state_dict():
|
|
||||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
|
||||||
|
|
||||||
# Compare values before and after saving and loading
|
|
||||||
# They should be the same
|
|
||||||
assert torch.allclose(cirtic_loss, loaded_cirtic_loss)
|
|
||||||
assert torch.allclose(actor_loss, loaded_actor_loss)
|
|
||||||
assert torch.allclose(temperature_loss, loaded_temperature_loss)
|
|
||||||
assert torch.allclose(actions, loaded_actions)
|
|
||||||
@ -1,282 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.processor.pipeline import (
|
|
||||||
RobotProcessor,
|
|
||||||
TransitionKey,
|
|
||||||
_default_batch_to_transition,
|
|
||||||
_default_transition_to_batch,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _dummy_batch():
|
|
||||||
"""Create a dummy batch using the new format with observation.* and next.* keys."""
|
|
||||||
return {
|
|
||||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
|
||||||
"observation.image.right": torch.randn(1, 3, 128, 128),
|
|
||||||
"observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
|
|
||||||
"action": torch.tensor([[0.5]]),
|
|
||||||
"next.reward": 1.0,
|
|
||||||
"next.done": False,
|
|
||||||
"next.truncated": False,
|
|
||||||
"info": {"key": "value"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_observation_grouping_roundtrip():
|
|
||||||
"""Test that observation.* keys are properly grouped and ungrouped."""
|
|
||||||
proc = RobotProcessor([])
|
|
||||||
batch_in = _dummy_batch()
|
|
||||||
batch_out = proc(batch_in)
|
|
||||||
|
|
||||||
# Check that all observation.* keys are preserved
|
|
||||||
original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")}
|
|
||||||
reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")}
|
|
||||||
|
|
||||||
assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys())
|
|
||||||
|
|
||||||
# Check tensor values
|
|
||||||
assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"])
|
|
||||||
assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"])
|
|
||||||
assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"])
|
|
||||||
|
|
||||||
# Check other fields
|
|
||||||
assert torch.allclose(batch_out["action"], batch_in["action"])
|
|
||||||
assert batch_out["next.reward"] == batch_in["next.reward"]
|
|
||||||
assert batch_out["next.done"] == batch_in["next.done"]
|
|
||||||
assert batch_out["next.truncated"] == batch_in["next.truncated"]
|
|
||||||
assert batch_out["info"] == batch_in["info"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_batch_to_transition_observation_grouping():
|
|
||||||
"""Test that _default_batch_to_transition correctly groups observation.* keys."""
|
|
||||||
batch = {
|
|
||||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
|
||||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
|
||||||
"observation.state": [1, 2, 3, 4],
|
|
||||||
"action": "action_data",
|
|
||||||
"next.reward": 1.5,
|
|
||||||
"next.done": True,
|
|
||||||
"next.truncated": False,
|
|
||||||
"info": {"episode": 42},
|
|
||||||
}
|
|
||||||
|
|
||||||
transition = _default_batch_to_transition(batch)
|
|
||||||
|
|
||||||
# Check observation is a dict with all observation.* keys
|
|
||||||
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
|
|
||||||
assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
|
|
||||||
assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
|
|
||||||
assert "observation.state" in transition[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check values are preserved
|
|
||||||
assert torch.allclose(
|
|
||||||
transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
|
|
||||||
)
|
|
||||||
assert torch.allclose(
|
|
||||||
transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
|
|
||||||
)
|
|
||||||
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
|
||||||
|
|
||||||
# Check other fields
|
|
||||||
assert transition[TransitionKey.ACTION] == "action_data"
|
|
||||||
assert transition[TransitionKey.REWARD] == 1.5
|
|
||||||
assert transition[TransitionKey.DONE]
|
|
||||||
assert not transition[TransitionKey.TRUNCATED]
|
|
||||||
assert transition[TransitionKey.INFO] == {"episode": 42}
|
|
||||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_transition_to_batch_observation_flattening():
|
|
||||||
"""Test that _default_transition_to_batch correctly flattens observation dict."""
|
|
||||||
observation_dict = {
|
|
||||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
|
||||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
|
||||||
"observation.state": [1, 2, 3, 4],
|
|
||||||
}
|
|
||||||
|
|
||||||
transition = {
|
|
||||||
TransitionKey.OBSERVATION: observation_dict,
|
|
||||||
TransitionKey.ACTION: "action_data",
|
|
||||||
TransitionKey.REWARD: 1.5,
|
|
||||||
TransitionKey.DONE: True,
|
|
||||||
TransitionKey.TRUNCATED: False,
|
|
||||||
TransitionKey.INFO: {"episode": 42},
|
|
||||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
|
||||||
}
|
|
||||||
|
|
||||||
batch = _default_transition_to_batch(transition)
|
|
||||||
|
|
||||||
# Check that observation.* keys are flattened back to batch
|
|
||||||
assert "observation.image.top" in batch
|
|
||||||
assert "observation.image.left" in batch
|
|
||||||
assert "observation.state" in batch
|
|
||||||
|
|
||||||
# Check values are preserved
|
|
||||||
assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"])
|
|
||||||
assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"])
|
|
||||||
assert batch["observation.state"] == [1, 2, 3, 4]
|
|
||||||
|
|
||||||
# Check other fields are mapped to next.* format
|
|
||||||
assert batch["action"] == "action_data"
|
|
||||||
assert batch["next.reward"] == 1.5
|
|
||||||
assert batch["next.done"]
|
|
||||||
assert not batch["next.truncated"]
|
|
||||||
assert batch["info"] == {"episode": 42}
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_observation_keys():
|
|
||||||
"""Test behavior when there are no observation.* keys."""
|
|
||||||
batch = {
|
|
||||||
"action": "action_data",
|
|
||||||
"next.reward": 2.0,
|
|
||||||
"next.done": False,
|
|
||||||
"next.truncated": True,
|
|
||||||
"info": {"test": "no_obs"},
|
|
||||||
}
|
|
||||||
|
|
||||||
transition = _default_batch_to_transition(batch)
|
|
||||||
|
|
||||||
# Observation should be None when no observation.* keys
|
|
||||||
assert transition[TransitionKey.OBSERVATION] is None
|
|
||||||
|
|
||||||
# Check other fields
|
|
||||||
assert transition[TransitionKey.ACTION] == "action_data"
|
|
||||||
assert transition[TransitionKey.REWARD] == 2.0
|
|
||||||
assert not transition[TransitionKey.DONE]
|
|
||||||
assert transition[TransitionKey.TRUNCATED]
|
|
||||||
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
|
|
||||||
|
|
||||||
# Round trip should work
|
|
||||||
reconstructed_batch = _default_transition_to_batch(transition)
|
|
||||||
assert reconstructed_batch["action"] == "action_data"
|
|
||||||
assert reconstructed_batch["next.reward"] == 2.0
|
|
||||||
assert not reconstructed_batch["next.done"]
|
|
||||||
assert reconstructed_batch["next.truncated"]
|
|
||||||
assert reconstructed_batch["info"] == {"test": "no_obs"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_minimal_batch():
|
|
||||||
"""Test with minimal batch containing only observation.* and action."""
|
|
||||||
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
|
|
||||||
|
|
||||||
transition = _default_batch_to_transition(batch)
|
|
||||||
|
|
||||||
# Check observation
|
|
||||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
|
||||||
assert transition[TransitionKey.ACTION] == "minimal_action"
|
|
||||||
|
|
||||||
# Check defaults
|
|
||||||
assert transition[TransitionKey.REWARD] == 0.0
|
|
||||||
assert not transition[TransitionKey.DONE]
|
|
||||||
assert not transition[TransitionKey.TRUNCATED]
|
|
||||||
assert transition[TransitionKey.INFO] == {}
|
|
||||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
|
||||||
|
|
||||||
# Round trip
|
|
||||||
reconstructed_batch = _default_transition_to_batch(transition)
|
|
||||||
assert reconstructed_batch["observation.state"] == "minimal_state"
|
|
||||||
assert reconstructed_batch["action"] == "minimal_action"
|
|
||||||
assert reconstructed_batch["next.reward"] == 0.0
|
|
||||||
assert not reconstructed_batch["next.done"]
|
|
||||||
assert not reconstructed_batch["next.truncated"]
|
|
||||||
assert reconstructed_batch["info"] == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_batch():
|
|
||||||
"""Test behavior with empty batch."""
|
|
||||||
batch = {}
|
|
||||||
|
|
||||||
transition = _default_batch_to_transition(batch)
|
|
||||||
|
|
||||||
# All fields should have defaults
|
|
||||||
assert transition[TransitionKey.OBSERVATION] is None
|
|
||||||
assert transition[TransitionKey.ACTION] is None
|
|
||||||
assert transition[TransitionKey.REWARD] == 0.0
|
|
||||||
assert not transition[TransitionKey.DONE]
|
|
||||||
assert not transition[TransitionKey.TRUNCATED]
|
|
||||||
assert transition[TransitionKey.INFO] == {}
|
|
||||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
|
||||||
|
|
||||||
# Round trip
|
|
||||||
reconstructed_batch = _default_transition_to_batch(transition)
|
|
||||||
assert reconstructed_batch["action"] is None
|
|
||||||
assert reconstructed_batch["next.reward"] == 0.0
|
|
||||||
assert not reconstructed_batch["next.done"]
|
|
||||||
assert not reconstructed_batch["next.truncated"]
|
|
||||||
assert reconstructed_batch["info"] == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_complex_nested_observation():
|
|
||||||
"""Test with complex nested observation data."""
|
|
||||||
batch = {
|
|
||||||
"observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
|
|
||||||
"observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
|
|
||||||
"observation.state": torch.randn(7),
|
|
||||||
"action": torch.randn(8),
|
|
||||||
"next.reward": 3.14,
|
|
||||||
"next.done": False,
|
|
||||||
"next.truncated": True,
|
|
||||||
"info": {"episode_length": 200, "success": True},
|
|
||||||
}
|
|
||||||
|
|
||||||
transition = _default_batch_to_transition(batch)
|
|
||||||
reconstructed_batch = _default_transition_to_batch(transition)
|
|
||||||
|
|
||||||
# Check that all observation keys are preserved
|
|
||||||
original_obs_keys = {k for k in batch if k.startswith("observation.")}
|
|
||||||
reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")}
|
|
||||||
|
|
||||||
assert original_obs_keys == reconstructed_obs_keys
|
|
||||||
|
|
||||||
# Check tensor values
|
|
||||||
assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"])
|
|
||||||
|
|
||||||
# Check nested dict with tensors
|
|
||||||
assert torch.allclose(
|
|
||||||
batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"]
|
|
||||||
)
|
|
||||||
assert torch.allclose(
|
|
||||||
batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check action tensor
|
|
||||||
assert torch.allclose(batch["action"], reconstructed_batch["action"])
|
|
||||||
|
|
||||||
# Check other fields
|
|
||||||
assert batch["next.reward"] == reconstructed_batch["next.reward"]
|
|
||||||
assert batch["next.done"] == reconstructed_batch["next.done"]
|
|
||||||
assert batch["next.truncated"] == reconstructed_batch["next.truncated"]
|
|
||||||
assert batch["info"] == reconstructed_batch["info"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_custom_converter():
|
|
||||||
"""Test that custom converters can still be used."""
|
|
||||||
|
|
||||||
def to_tr(batch):
|
|
||||||
# Custom converter that modifies the reward
|
|
||||||
tr = _default_batch_to_transition(batch)
|
|
||||||
# Double the reward
|
|
||||||
reward = tr.get(TransitionKey.REWARD, 0.0)
|
|
||||||
new_tr = tr.copy()
|
|
||||||
new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0
|
|
||||||
return new_tr
|
|
||||||
|
|
||||||
def to_batch(tr):
|
|
||||||
batch = _default_transition_to_batch(tr)
|
|
||||||
return batch
|
|
||||||
|
|
||||||
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
|
|
||||||
|
|
||||||
batch = {
|
|
||||||
"observation.state": torch.randn(1, 4),
|
|
||||||
"action": torch.randn(1, 2),
|
|
||||||
"next.reward": 1.0,
|
|
||||||
"next.done": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
result = processor(batch)
|
|
||||||
|
|
||||||
# Check the reward was doubled by our custom converter
|
|
||||||
assert result["next.reward"] == 2.0
|
|
||||||
assert torch.allclose(result["observation.state"], batch["observation.state"])
|
|
||||||
assert torch.allclose(result["action"], batch["action"])
|
|
||||||
@ -1,628 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from unittest.mock import Mock
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
|
||||||
from lerobot.processor.normalize_processor import (
|
|
||||||
NormalizerProcessor,
|
|
||||||
UnnormalizerProcessor,
|
|
||||||
_convert_stats_to_tensors,
|
|
||||||
)
|
|
||||||
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
|
|
||||||
|
|
||||||
|
|
||||||
def create_transition(
|
|
||||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
|
||||||
):
|
|
||||||
"""Helper to create an EnvTransition dictionary."""
|
|
||||||
return {
|
|
||||||
TransitionKey.OBSERVATION: observation,
|
|
||||||
TransitionKey.ACTION: action,
|
|
||||||
TransitionKey.REWARD: reward,
|
|
||||||
TransitionKey.DONE: done,
|
|
||||||
TransitionKey.TRUNCATED: truncated,
|
|
||||||
TransitionKey.INFO: info,
|
|
||||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_numpy_conversion():
|
|
||||||
stats = {
|
|
||||||
"observation.image": {
|
|
||||||
"mean": np.array([0.5, 0.5, 0.5]),
|
|
||||||
"std": np.array([0.2, 0.2, 0.2]),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tensor_stats = _convert_stats_to_tensors(stats)
|
|
||||||
|
|
||||||
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
|
|
||||||
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
|
|
||||||
assert torch.allclose(tensor_stats["observation.image"]["mean"], torch.tensor([0.5, 0.5, 0.5]))
|
|
||||||
assert torch.allclose(tensor_stats["observation.image"]["std"], torch.tensor([0.2, 0.2, 0.2]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_conversion():
|
|
||||||
stats = {
|
|
||||||
"action": {
|
|
||||||
"mean": torch.tensor([0.0, 0.0]),
|
|
||||||
"std": torch.tensor([1.0, 1.0]),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tensor_stats = _convert_stats_to_tensors(stats)
|
|
||||||
|
|
||||||
assert tensor_stats["action"]["mean"].dtype == torch.float32
|
|
||||||
assert tensor_stats["action"]["std"].dtype == torch.float32
|
|
||||||
|
|
||||||
|
|
||||||
def test_scalar_conversion():
|
|
||||||
stats = {
|
|
||||||
"reward": {
|
|
||||||
"mean": 0.5,
|
|
||||||
"std": 0.1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tensor_stats = _convert_stats_to_tensors(stats)
|
|
||||||
|
|
||||||
assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5))
|
|
||||||
assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1))
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_conversion():
|
|
||||||
stats = {
|
|
||||||
"observation.state": {
|
|
||||||
"min": [0.0, -1.0, -2.0],
|
|
||||||
"max": [1.0, 1.0, 2.0],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tensor_stats = _convert_stats_to_tensors(stats)
|
|
||||||
|
|
||||||
assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
|
|
||||||
assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_unsupported_type():
|
|
||||||
stats = {
|
|
||||||
"bad_key": {
|
|
||||||
"mean": "string_value",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
with pytest.raises(TypeError, match="Unsupported type"):
|
|
||||||
_convert_stats_to_tensors(stats)
|
|
||||||
|
|
||||||
|
|
||||||
# Helper functions to create feature maps and norm maps
|
|
||||||
def _create_observation_features():
|
|
||||||
return {
|
|
||||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
|
||||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _create_observation_norm_map():
|
|
||||||
return {
|
|
||||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
|
||||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Fixtures for observation normalisation tests using NormalizerProcessor
|
|
||||||
@pytest.fixture
|
|
||||||
def observation_stats():
|
|
||||||
return {
|
|
||||||
"observation.image": {
|
|
||||||
"mean": np.array([0.5, 0.5, 0.5]),
|
|
||||||
"std": np.array([0.2, 0.2, 0.2]),
|
|
||||||
},
|
|
||||||
"observation.state": {
|
|
||||||
"min": np.array([0.0, -1.0]),
|
|
||||||
"max": np.array([1.0, 1.0]),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def observation_normalizer(observation_stats):
|
|
||||||
"""Return a NormalizerProcessor that only has observation stats (no action)."""
|
|
||||||
features = _create_observation_features()
|
|
||||||
norm_map = _create_observation_norm_map()
|
|
||||||
return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mean_std_normalization(observation_normalizer):
|
|
||||||
observation = {
|
|
||||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
|
||||||
"observation.state": torch.tensor([0.5, 0.0]),
|
|
||||||
}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
normalized_transition = observation_normalizer(transition)
|
|
||||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check mean/std normalization
|
|
||||||
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
|
||||||
assert torch.allclose(normalized_obs["observation.image"], expected_image)
|
|
||||||
|
|
||||||
|
|
||||||
def test_min_max_normalization(observation_normalizer):
|
|
||||||
observation = {
|
|
||||||
"observation.state": torch.tensor([0.5, 0.0]),
|
|
||||||
}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
normalized_transition = observation_normalizer(transition)
|
|
||||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check min/max normalization to [-1, 1]
|
|
||||||
# For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0
|
|
||||||
# For state[1]: 2 * (0.0 - (-1.0)) / (1.0 - (-1.0)) - 1 = 0.0
|
|
||||||
expected_state = torch.tensor([0.0, 0.0])
|
|
||||||
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
|
|
||||||
|
|
||||||
|
|
||||||
def test_selective_normalization(observation_stats):
|
|
||||||
features = _create_observation_features()
|
|
||||||
norm_map = _create_observation_norm_map()
|
|
||||||
normalizer = NormalizerProcessor(
|
|
||||||
features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"}
|
|
||||||
)
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
|
||||||
"observation.state": torch.tensor([0.5, 0.0]),
|
|
||||||
}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
normalized_transition = normalizer(transition)
|
|
||||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Only image should be normalized
|
|
||||||
assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2)
|
|
||||||
# State should remain unchanged
|
|
||||||
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
|
||||||
def test_device_compatibility(observation_stats):
|
|
||||||
features = _create_observation_features()
|
|
||||||
norm_map = _create_observation_norm_map()
|
|
||||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
|
|
||||||
observation = {
|
|
||||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
|
|
||||||
}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
normalized_transition = normalizer(transition)
|
|
||||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
assert normalized_obs["observation.image"].device.type == "cuda"
|
|
||||||
|
|
||||||
|
|
||||||
def test_from_lerobot_dataset():
|
|
||||||
# Mock dataset
|
|
||||||
mock_dataset = Mock()
|
|
||||||
mock_dataset.meta.stats = {
|
|
||||||
"observation.image": {"mean": [0.5], "std": [0.2]},
|
|
||||||
"action": {"mean": [0.0], "std": [1.0]},
|
|
||||||
}
|
|
||||||
|
|
||||||
features = {
|
|
||||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
|
||||||
"action": PolicyFeature(FeatureType.ACTION, (1,)),
|
|
||||||
}
|
|
||||||
norm_map = {
|
|
||||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
|
||||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
|
||||||
}
|
|
||||||
|
|
||||||
normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
|
||||||
|
|
||||||
# Both observation and action statistics should be present in tensor stats
|
|
||||||
assert "observation.image" in normalizer._tensor_stats
|
|
||||||
assert "action" in normalizer._tensor_stats
|
|
||||||
|
|
||||||
|
|
||||||
def test_state_dict_save_load(observation_normalizer):
|
|
||||||
# Save state
|
|
||||||
state_dict = observation_normalizer.state_dict()
|
|
||||||
|
|
||||||
# Create new normalizer and load state
|
|
||||||
features = _create_observation_features()
|
|
||||||
norm_map = _create_observation_norm_map()
|
|
||||||
new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
|
|
||||||
new_normalizer.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
# Test that it works the same
|
|
||||||
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION]
|
|
||||||
result2 = new_normalizer(transition)[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
assert torch.allclose(result1["observation.image"], result2["observation.image"])
|
|
||||||
|
|
||||||
|
|
||||||
# Fixtures for ActionUnnormalizer tests
|
|
||||||
@pytest.fixture
|
|
||||||
def action_stats_mean_std():
|
|
||||||
return {
|
|
||||||
"mean": np.array([0.0, 0.0, 0.0]),
|
|
||||||
"std": np.array([1.0, 2.0, 0.5]),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def action_stats_min_max():
|
|
||||||
return {
|
|
||||||
"min": np.array([-1.0, -2.0, 0.0]),
|
|
||||||
"max": np.array([1.0, 2.0, 1.0]),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _create_action_features():
|
|
||||||
return {
|
|
||||||
"action": PolicyFeature(FeatureType.ACTION, (3,)),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _create_action_norm_map_mean_std():
|
|
||||||
return {
|
|
||||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _create_action_norm_map_min_max():
|
|
||||||
return {
|
|
||||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_mean_std_unnormalization(action_stats_mean_std):
|
|
||||||
features = _create_action_features()
|
|
||||||
norm_map = _create_action_norm_map_mean_std()
|
|
||||||
unnormalizer = UnnormalizerProcessor(
|
|
||||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
|
||||||
)
|
|
||||||
|
|
||||||
normalized_action = torch.tensor([1.0, -0.5, 2.0])
|
|
||||||
transition = create_transition(action=normalized_action)
|
|
||||||
|
|
||||||
unnormalized_transition = unnormalizer(transition)
|
|
||||||
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
|
|
||||||
|
|
||||||
# action * std + mean
|
|
||||||
expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0])
|
|
||||||
assert torch.allclose(unnormalized_action, expected)
|
|
||||||
|
|
||||||
|
|
||||||
def test_min_max_unnormalization(action_stats_min_max):
|
|
||||||
features = _create_action_features()
|
|
||||||
norm_map = _create_action_norm_map_min_max()
|
|
||||||
unnormalizer = UnnormalizerProcessor(
|
|
||||||
features=features, norm_map=norm_map, stats={"action": action_stats_min_max}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Actions in [-1, 1]
|
|
||||||
normalized_action = torch.tensor([0.0, -1.0, 1.0])
|
|
||||||
transition = create_transition(action=normalized_action)
|
|
||||||
|
|
||||||
unnormalized_transition = unnormalizer(transition)
|
|
||||||
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
|
|
||||||
|
|
||||||
# Map from [-1, 1] to [min, max]
|
|
||||||
# (action + 1) / 2 * (max - min) + min
|
|
||||||
expected = torch.tensor(
|
|
||||||
[
|
|
||||||
(0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0), # 0.0
|
|
||||||
(-1.0 + 1) / 2 * (2.0 - (-2.0)) + (-2.0), # -2.0
|
|
||||||
(1.0 + 1) / 2 * (1.0 - 0.0) + 0.0, # 1.0
|
|
||||||
]
|
|
||||||
)
|
|
||||||
assert torch.allclose(unnormalized_action, expected)
|
|
||||||
|
|
||||||
|
|
||||||
def test_numpy_action_input(action_stats_mean_std):
|
|
||||||
features = _create_action_features()
|
|
||||||
norm_map = _create_action_norm_map_mean_std()
|
|
||||||
unnormalizer = UnnormalizerProcessor(
|
|
||||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
|
||||||
)
|
|
||||||
|
|
||||||
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
|
|
||||||
transition = create_transition(action=normalized_action)
|
|
||||||
|
|
||||||
unnormalized_transition = unnormalizer(transition)
|
|
||||||
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
|
|
||||||
|
|
||||||
assert isinstance(unnormalized_action, torch.Tensor)
|
|
||||||
expected = torch.tensor([1.0, -1.0, 1.0])
|
|
||||||
assert torch.allclose(unnormalized_action, expected)
|
|
||||||
|
|
||||||
|
|
||||||
def test_none_action(action_stats_mean_std):
|
|
||||||
features = _create_action_features()
|
|
||||||
norm_map = _create_action_norm_map_mean_std()
|
|
||||||
unnormalizer = UnnormalizerProcessor(
|
|
||||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
|
||||||
)
|
|
||||||
|
|
||||||
transition = create_transition()
|
|
||||||
result = unnormalizer(transition)
|
|
||||||
|
|
||||||
# Should return transition unchanged
|
|
||||||
assert result == transition
|
|
||||||
|
|
||||||
|
|
||||||
def test_action_from_lerobot_dataset():
|
|
||||||
mock_dataset = Mock()
|
|
||||||
mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}
|
|
||||||
features = {"action": PolicyFeature(FeatureType.ACTION, (1,))}
|
|
||||||
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
|
||||||
unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
|
||||||
assert "mean" in unnormalizer._tensor_stats["action"]
|
|
||||||
|
|
||||||
|
|
||||||
# Fixtures for NormalizerProcessor tests
|
|
||||||
@pytest.fixture
|
|
||||||
def full_stats():
|
|
||||||
return {
|
|
||||||
"observation.image": {
|
|
||||||
"mean": np.array([0.5, 0.5, 0.5]),
|
|
||||||
"std": np.array([0.2, 0.2, 0.2]),
|
|
||||||
},
|
|
||||||
"observation.state": {
|
|
||||||
"min": np.array([0.0, -1.0]),
|
|
||||||
"max": np.array([1.0, 1.0]),
|
|
||||||
},
|
|
||||||
"action": {
|
|
||||||
"mean": np.array([0.0, 0.0]),
|
|
||||||
"std": np.array([1.0, 2.0]),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _create_full_features():
|
|
||||||
return {
|
|
||||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
|
||||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
|
||||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _create_full_norm_map():
|
|
||||||
return {
|
|
||||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
|
||||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
|
||||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def normalizer_processor(full_stats):
|
|
||||||
features = _create_full_features()
|
|
||||||
norm_map = _create_full_norm_map()
|
|
||||||
return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats)
|
|
||||||
|
|
||||||
|
|
||||||
def test_combined_normalization(normalizer_processor):
|
|
||||||
observation = {
|
|
||||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
|
||||||
"observation.state": torch.tensor([0.5, 0.0]),
|
|
||||||
}
|
|
||||||
action = torch.tensor([1.0, -0.5])
|
|
||||||
transition = create_transition(
|
|
||||||
observation=observation,
|
|
||||||
action=action,
|
|
||||||
reward=1.0,
|
|
||||||
done=False,
|
|
||||||
truncated=False,
|
|
||||||
info={},
|
|
||||||
complementary_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
processed_transition = normalizer_processor(transition)
|
|
||||||
|
|
||||||
# Check normalized observations
|
|
||||||
processed_obs = processed_transition[TransitionKey.OBSERVATION]
|
|
||||||
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
|
||||||
assert torch.allclose(processed_obs["observation.image"], expected_image)
|
|
||||||
|
|
||||||
# Check normalized action
|
|
||||||
processed_action = processed_transition[TransitionKey.ACTION]
|
|
||||||
expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0])
|
|
||||||
assert torch.allclose(processed_action, expected_action)
|
|
||||||
|
|
||||||
# Check other fields remain unchanged
|
|
||||||
assert processed_transition[TransitionKey.REWARD] == 1.0
|
|
||||||
assert not processed_transition[TransitionKey.DONE]
|
|
||||||
|
|
||||||
|
|
||||||
def test_processor_from_lerobot_dataset(full_stats):
|
|
||||||
# Mock dataset
|
|
||||||
mock_dataset = Mock()
|
|
||||||
mock_dataset.meta.stats = full_stats
|
|
||||||
|
|
||||||
features = _create_full_features()
|
|
||||||
norm_map = _create_full_norm_map()
|
|
||||||
|
|
||||||
processor = NormalizerProcessor.from_lerobot_dataset(
|
|
||||||
mock_dataset, features, norm_map, normalize_keys={"observation.image"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert processor.normalize_keys == {"observation.image"}
|
|
||||||
assert "observation.image" in processor._tensor_stats
|
|
||||||
assert "action" in processor._tensor_stats
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_config(full_stats):
|
|
||||||
features = _create_full_features()
|
|
||||||
norm_map = _create_full_norm_map()
|
|
||||||
processor = NormalizerProcessor(
|
|
||||||
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
|
|
||||||
)
|
|
||||||
|
|
||||||
config = processor.get_config()
|
|
||||||
expected_config = {
|
|
||||||
"normalize_keys": ["observation.image"],
|
|
||||||
"eps": 1e-6,
|
|
||||||
"features": {
|
|
||||||
"observation.image": {"type": "VISUAL", "shape": (3, 96, 96)},
|
|
||||||
"observation.state": {"type": "STATE", "shape": (2,)},
|
|
||||||
"action": {"type": "ACTION", "shape": (2,)},
|
|
||||||
},
|
|
||||||
"norm_map": {
|
|
||||||
"VISUAL": "MEAN_STD",
|
|
||||||
"STATE": "MIN_MAX",
|
|
||||||
"ACTION": "MEAN_STD",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
assert config == expected_config
|
|
||||||
|
|
||||||
|
|
||||||
def test_integration_with_robot_processor(normalizer_processor):
|
|
||||||
"""Test integration with RobotProcessor pipeline"""
|
|
||||||
robot_processor = RobotProcessor([normalizer_processor])
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
|
||||||
"observation.state": torch.tensor([0.5, 0.0]),
|
|
||||||
}
|
|
||||||
action = torch.tensor([1.0, -0.5])
|
|
||||||
transition = create_transition(
|
|
||||||
observation=observation,
|
|
||||||
action=action,
|
|
||||||
reward=1.0,
|
|
||||||
done=False,
|
|
||||||
truncated=False,
|
|
||||||
info={},
|
|
||||||
complementary_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
processed_transition = robot_processor(transition)
|
|
||||||
|
|
||||||
# Verify the processing worked
|
|
||||||
assert isinstance(processed_transition[TransitionKey.OBSERVATION], dict)
|
|
||||||
assert isinstance(processed_transition[TransitionKey.ACTION], torch.Tensor)
|
|
||||||
|
|
||||||
|
|
||||||
# Edge case tests
|
|
||||||
def test_empty_observation():
|
|
||||||
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
|
||||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
|
||||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
|
||||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
|
||||||
|
|
||||||
transition = create_transition()
|
|
||||||
result = normalizer(transition)
|
|
||||||
|
|
||||||
assert result == transition
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_stats():
|
|
||||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
|
||||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
|
||||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
|
|
||||||
observation = {"observation.image": torch.tensor([0.5])}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = normalizer(transition)
|
|
||||||
# Should return observation unchanged since no stats are available
|
|
||||||
assert torch.allclose(
|
|
||||||
result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_partial_stats():
|
|
||||||
"""If statistics are incomplete, the value should pass through unchanged."""
|
|
||||||
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
|
|
||||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
|
||||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
|
||||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
|
||||||
observation = {"observation.image": torch.tensor([0.7])}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
processed = normalizer(transition)[TransitionKey.OBSERVATION]
|
|
||||||
assert torch.allclose(processed["observation.image"], observation["observation.image"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_missing_action_stats_no_error():
|
|
||||||
mock_dataset = Mock()
|
|
||||||
mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
|
||||||
|
|
||||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
|
||||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
|
||||||
|
|
||||||
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
|
||||||
# The tensor stats should not contain the 'action' key
|
|
||||||
assert "action" not in processor._tensor_stats
|
|
||||||
|
|
||||||
|
|
||||||
def test_serialization_roundtrip(full_stats):
|
|
||||||
"""Test that features and norm_map can be serialized and deserialized correctly."""
|
|
||||||
features = _create_full_features()
|
|
||||||
norm_map = _create_full_norm_map()
|
|
||||||
original_processor = NormalizerProcessor(
|
|
||||||
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get config (serialization)
|
|
||||||
config = original_processor.get_config()
|
|
||||||
|
|
||||||
# Create a new processor from the config (deserialization)
|
|
||||||
new_processor = NormalizerProcessor(
|
|
||||||
features=config["features"],
|
|
||||||
norm_map=config["norm_map"],
|
|
||||||
stats=full_stats,
|
|
||||||
normalize_keys=set(config["normalize_keys"]),
|
|
||||||
eps=config["eps"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test that both processors work the same way
|
|
||||||
observation = {
|
|
||||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
|
||||||
"observation.state": torch.tensor([0.5, 0.0]),
|
|
||||||
}
|
|
||||||
action = torch.tensor([1.0, -0.5])
|
|
||||||
transition = create_transition(
|
|
||||||
observation=observation,
|
|
||||||
action=action,
|
|
||||||
reward=1.0,
|
|
||||||
done=False,
|
|
||||||
truncated=False,
|
|
||||||
info={},
|
|
||||||
complementary_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
result1 = original_processor(transition)
|
|
||||||
result2 = new_processor(transition)
|
|
||||||
|
|
||||||
# Compare results
|
|
||||||
assert torch.allclose(
|
|
||||||
result1[TransitionKey.OBSERVATION]["observation.image"],
|
|
||||||
result2[TransitionKey.OBSERVATION]["observation.image"],
|
|
||||||
)
|
|
||||||
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
|
|
||||||
|
|
||||||
# Verify features and norm_map are correctly reconstructed
|
|
||||||
assert new_processor.features.keys() == original_processor.features.keys()
|
|
||||||
for key in new_processor.features:
|
|
||||||
assert new_processor.features[key].type == original_processor.features[key].type
|
|
||||||
assert new_processor.features[key].shape == original_processor.features[key].shape
|
|
||||||
|
|
||||||
assert new_processor.norm_map == original_processor.norm_map
|
|
||||||
@ -1,486 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType
|
|
||||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
|
||||||
from lerobot.processor import VanillaObservationProcessor
|
|
||||||
from lerobot.processor.pipeline import TransitionKey
|
|
||||||
from tests.conftest import assert_contract_is_typed
|
|
||||||
|
|
||||||
|
|
||||||
def create_transition(
|
|
||||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
|
||||||
):
|
|
||||||
"""Helper to create an EnvTransition dictionary."""
|
|
||||||
return {
|
|
||||||
TransitionKey.OBSERVATION: observation,
|
|
||||||
TransitionKey.ACTION: action,
|
|
||||||
TransitionKey.REWARD: reward,
|
|
||||||
TransitionKey.DONE: done,
|
|
||||||
TransitionKey.TRUNCATED: truncated,
|
|
||||||
TransitionKey.INFO: info,
|
|
||||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_single_image():
|
|
||||||
"""Test processing a single image."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
# Create a mock image (H, W, C) format, uint8
|
|
||||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
|
||||||
|
|
||||||
observation = {"pixels": image}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check that the image was processed correctly
|
|
||||||
assert "observation.image" in processed_obs
|
|
||||||
processed_img = processed_obs["observation.image"]
|
|
||||||
|
|
||||||
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
|
|
||||||
assert processed_img.shape == (1, 3, 64, 64)
|
|
||||||
|
|
||||||
# Check dtype and range
|
|
||||||
assert processed_img.dtype == torch.float32
|
|
||||||
assert processed_img.min() >= 0.0
|
|
||||||
assert processed_img.max() <= 1.0
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_image_dict():
|
|
||||||
"""Test processing multiple images in a dictionary."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
# Create mock images
|
|
||||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
|
||||||
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
|
|
||||||
|
|
||||||
observation = {"pixels": {"camera1": image1, "camera2": image2}}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check that both images were processed
|
|
||||||
assert "observation.images.camera1" in processed_obs
|
|
||||||
assert "observation.images.camera2" in processed_obs
|
|
||||||
|
|
||||||
# Check shapes
|
|
||||||
assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
|
|
||||||
assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_batched_image():
|
|
||||||
"""Test processing already batched images."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
# Create a batched image (B, H, W, C)
|
|
||||||
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
|
|
||||||
|
|
||||||
observation = {"pixels": image}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check that batch dimension is preserved
|
|
||||||
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
|
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_image_format():
|
|
||||||
"""Test error handling for invalid image formats."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
# Test wrong channel order (channels first)
|
|
||||||
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
|
|
||||||
observation = {"pixels": image}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Expected channel-last images"):
|
|
||||||
processor(transition)
|
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_image_dtype():
|
|
||||||
"""Test error handling for invalid image dtype."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
# Test wrong dtype
|
|
||||||
image = np.random.rand(64, 64, 3).astype(np.float32)
|
|
||||||
observation = {"pixels": image}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
|
|
||||||
processor(transition)
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_pixels_in_observation():
|
|
||||||
"""Test processor when no pixels are in observation."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
observation = {"other_data": np.array([1, 2, 3])}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Should preserve other data unchanged
|
|
||||||
assert "other_data" in processed_obs
|
|
||||||
np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_none_observation():
|
|
||||||
"""Test processor with None observation."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
transition = create_transition()
|
|
||||||
result = processor(transition)
|
|
||||||
|
|
||||||
assert result == transition
|
|
||||||
|
|
||||||
|
|
||||||
def test_serialization_methods():
|
|
||||||
"""Test serialization methods."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
# Test get_config
|
|
||||||
config = processor.get_config()
|
|
||||||
assert isinstance(config, dict)
|
|
||||||
|
|
||||||
# Test state_dict
|
|
||||||
state = processor.state_dict()
|
|
||||||
assert isinstance(state, dict)
|
|
||||||
|
|
||||||
# Test load_state_dict (should not raise)
|
|
||||||
processor.load_state_dict(state)
|
|
||||||
|
|
||||||
# Test reset (should not raise)
|
|
||||||
processor.reset()
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_environment_state():
|
|
||||||
"""Test processing environment_state."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
|
||||||
observation = {"environment_state": env_state}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check that environment_state was renamed and processed
|
|
||||||
assert "observation.environment_state" in processed_obs
|
|
||||||
assert "environment_state" not in processed_obs
|
|
||||||
|
|
||||||
processed_state = processed_obs["observation.environment_state"]
|
|
||||||
assert processed_state.shape == (1, 3) # Batch dimension added
|
|
||||||
assert processed_state.dtype == torch.float32
|
|
||||||
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_agent_pos():
|
|
||||||
"""Test processing agent_pos."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
|
||||||
observation = {"agent_pos": agent_pos}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check that agent_pos was renamed and processed
|
|
||||||
assert "observation.state" in processed_obs
|
|
||||||
assert "agent_pos" not in processed_obs
|
|
||||||
|
|
||||||
processed_state = processed_obs["observation.state"]
|
|
||||||
assert processed_state.shape == (1, 3) # Batch dimension added
|
|
||||||
assert processed_state.dtype == torch.float32
|
|
||||||
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_batched_states():
|
|
||||||
"""Test processing already batched states."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
|
|
||||||
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
|
|
||||||
|
|
||||||
observation = {"environment_state": env_state, "agent_pos": agent_pos}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check that batch dimensions are preserved
|
|
||||||
assert processed_obs["observation.environment_state"].shape == (2, 2)
|
|
||||||
assert processed_obs["observation.state"].shape == (2, 2)
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_both_states():
|
|
||||||
"""Test processing both environment_state and agent_pos."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
env_state = np.array([1.0, 2.0], dtype=np.float32)
|
|
||||||
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
|
|
||||||
|
|
||||||
observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check that both states were processed
|
|
||||||
assert "observation.environment_state" in processed_obs
|
|
||||||
assert "observation.state" in processed_obs
|
|
||||||
|
|
||||||
# Check that original keys were removed
|
|
||||||
assert "environment_state" not in processed_obs
|
|
||||||
assert "agent_pos" not in processed_obs
|
|
||||||
|
|
||||||
# Check that other data was preserved
|
|
||||||
assert processed_obs["other_data"] == "keep_me"
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_states_in_observation():
|
|
||||||
"""Test processor when no states are in observation."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
observation = {"other_data": np.array([1, 2, 3])}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Should preserve data unchanged
|
|
||||||
np.testing.assert_array_equal(processed_obs, observation)
|
|
||||||
|
|
||||||
|
|
||||||
def test_complete_observation_processing():
|
|
||||||
"""Test processing a complete observation with both images and states."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
# Create mock data
|
|
||||||
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
|
||||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
|
||||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
"pixels": image,
|
|
||||||
"environment_state": env_state,
|
|
||||||
"agent_pos": agent_pos,
|
|
||||||
"other_data": "preserve_me",
|
|
||||||
}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check that image was processed
|
|
||||||
assert "observation.image" in processed_obs
|
|
||||||
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
|
|
||||||
|
|
||||||
# Check that states were processed
|
|
||||||
assert "observation.environment_state" in processed_obs
|
|
||||||
assert "observation.state" in processed_obs
|
|
||||||
|
|
||||||
# Check that original keys were removed
|
|
||||||
assert "pixels" not in processed_obs
|
|
||||||
assert "environment_state" not in processed_obs
|
|
||||||
assert "agent_pos" not in processed_obs
|
|
||||||
|
|
||||||
# Check that other data was preserved
|
|
||||||
assert processed_obs["other_data"] == "preserve_me"
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_only_processing():
|
|
||||||
"""Test processing observation with only images."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
|
||||||
observation = {"pixels": image}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
assert "observation.image" in processed_obs
|
|
||||||
assert len(processed_obs) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_state_only_processing():
|
|
||||||
"""Test processing observation with only states."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
|
||||||
observation = {"agent_pos": agent_pos}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
assert "observation.state" in processed_obs
|
|
||||||
assert "agent_pos" not in processed_obs
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_observation():
|
|
||||||
"""Test processing empty observation."""
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
observation = {}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
assert processed_obs == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_equivalent_to_original_function():
|
|
||||||
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
|
|
||||||
# Import the original function for comparison
|
|
||||||
from lerobot.envs.utils import preprocess_observation
|
|
||||||
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
# Create test data similar to what the original function expects
|
|
||||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
|
||||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
|
||||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
|
||||||
|
|
||||||
observation = {"pixels": image, "environment_state": env_state, "agent_pos": agent_pos}
|
|
||||||
|
|
||||||
# Process with original function
|
|
||||||
original_result = preprocess_observation(observation)
|
|
||||||
|
|
||||||
# Process with new processor
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
processor_result = processor(transition)[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Compare results
|
|
||||||
assert set(original_result.keys()) == set(processor_result.keys())
|
|
||||||
|
|
||||||
for key in original_result:
|
|
||||||
torch.testing.assert_close(original_result[key], processor_result[key])
|
|
||||||
|
|
||||||
|
|
||||||
def test_equivalent_with_image_dict():
|
|
||||||
"""Test equivalence with dictionary of images."""
|
|
||||||
from lerobot.envs.utils import preprocess_observation
|
|
||||||
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
|
|
||||||
# Create test data with multiple cameras
|
|
||||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
|
||||||
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
|
|
||||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
|
||||||
|
|
||||||
observation = {"pixels": {"cam1": image1, "cam2": image2}, "agent_pos": agent_pos}
|
|
||||||
|
|
||||||
# Process with original function
|
|
||||||
original_result = preprocess_observation(observation)
|
|
||||||
|
|
||||||
# Process with new processor
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
processor_result = processor(transition)[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Compare results
|
|
||||||
assert set(original_result.keys()) == set(processor_result.keys())
|
|
||||||
|
|
||||||
for key in original_result:
|
|
||||||
torch.testing.assert_close(original_result[key], processor_result[key])
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
features = {
|
|
||||||
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
|
||||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
|
||||||
}
|
|
||||||
out = processor.feature_contract(features.copy())
|
|
||||||
|
|
||||||
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"]
|
|
||||||
assert "pixels" not in out
|
|
||||||
assert out["keep"] == features["keep"]
|
|
||||||
assert_contract_is_typed(out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
features = {
|
|
||||||
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
|
||||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
|
||||||
}
|
|
||||||
out = processor.feature_contract(features.copy())
|
|
||||||
|
|
||||||
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"]
|
|
||||||
assert "observation.pixels" not in out
|
|
||||||
assert out["keep"] == features["keep"]
|
|
||||||
assert_contract_is_typed(out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
features = {
|
|
||||||
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
|
||||||
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
|
||||||
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
|
||||||
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
|
|
||||||
}
|
|
||||||
out = processor.feature_contract(features.copy())
|
|
||||||
|
|
||||||
assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"]
|
|
||||||
assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"]
|
|
||||||
assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"]
|
|
||||||
assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out
|
|
||||||
assert out["keep"] == features["keep"]
|
|
||||||
assert_contract_is_typed(out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
|
|
||||||
processor = VanillaObservationProcessor()
|
|
||||||
features = {
|
|
||||||
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
|
|
||||||
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
|
||||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
|
||||||
}
|
|
||||||
out = processor.feature_contract(features.copy())
|
|
||||||
|
|
||||||
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"]
|
|
||||||
assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"]
|
|
||||||
assert "environment_state" not in out and "agent_pos" not in out
|
|
||||||
assert out["keep"] == features["keep"]
|
|
||||||
assert_contract_is_typed(out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
|
|
||||||
proc = VanillaObservationProcessor()
|
|
||||||
features = {
|
|
||||||
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
|
||||||
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
|
||||||
}
|
|
||||||
out = proc.feature_contract(features.copy())
|
|
||||||
|
|
||||||
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"]
|
|
||||||
assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"]
|
|
||||||
assert "environment_state" not in out and "agent_pos" not in out
|
|
||||||
assert_contract_is_typed(out)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@ -1,467 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType
|
|
||||||
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
|
|
||||||
from tests.conftest import assert_contract_is_typed
|
|
||||||
|
|
||||||
|
|
||||||
def create_transition(
|
|
||||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
|
||||||
):
|
|
||||||
"""Helper to create an EnvTransition dictionary."""
|
|
||||||
return {
|
|
||||||
TransitionKey.OBSERVATION: observation,
|
|
||||||
TransitionKey.ACTION: action,
|
|
||||||
TransitionKey.REWARD: reward,
|
|
||||||
TransitionKey.DONE: done,
|
|
||||||
TransitionKey.TRUNCATED: truncated,
|
|
||||||
TransitionKey.INFO: info,
|
|
||||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_basic_renaming():
|
|
||||||
"""Test basic key renaming functionality."""
|
|
||||||
rename_map = {
|
|
||||||
"old_key1": "new_key1",
|
|
||||||
"old_key2": "new_key2",
|
|
||||||
}
|
|
||||||
processor = RenameProcessor(rename_map=rename_map)
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
"old_key1": torch.tensor([1.0, 2.0]),
|
|
||||||
"old_key2": np.array([3.0, 4.0]),
|
|
||||||
"unchanged_key": "keep_me",
|
|
||||||
}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check renamed keys
|
|
||||||
assert "new_key1" in processed_obs
|
|
||||||
assert "new_key2" in processed_obs
|
|
||||||
assert "old_key1" not in processed_obs
|
|
||||||
assert "old_key2" not in processed_obs
|
|
||||||
|
|
||||||
# Check values are preserved
|
|
||||||
torch.testing.assert_close(processed_obs["new_key1"], torch.tensor([1.0, 2.0]))
|
|
||||||
np.testing.assert_array_equal(processed_obs["new_key2"], np.array([3.0, 4.0]))
|
|
||||||
|
|
||||||
# Check unchanged key is preserved
|
|
||||||
assert processed_obs["unchanged_key"] == "keep_me"
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_rename_map():
|
|
||||||
"""Test processor with empty rename map (should pass through unchanged)."""
|
|
||||||
processor = RenameProcessor(rename_map={})
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
"key1": torch.tensor([1.0]),
|
|
||||||
"key2": "value2",
|
|
||||||
}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# All keys should be unchanged
|
|
||||||
assert processed_obs.keys() == observation.keys()
|
|
||||||
torch.testing.assert_close(processed_obs["key1"], observation["key1"])
|
|
||||||
assert processed_obs["key2"] == observation["key2"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_none_observation():
|
|
||||||
"""Test processor with None observation."""
|
|
||||||
processor = RenameProcessor(rename_map={"old": "new"})
|
|
||||||
|
|
||||||
transition = create_transition()
|
|
||||||
result = processor(transition)
|
|
||||||
|
|
||||||
# Should return transition unchanged
|
|
||||||
assert result == transition
|
|
||||||
|
|
||||||
|
|
||||||
def test_overlapping_rename():
|
|
||||||
"""Test renaming when new names might conflict."""
|
|
||||||
rename_map = {
|
|
||||||
"a": "b",
|
|
||||||
"b": "c", # This creates a potential conflict
|
|
||||||
}
|
|
||||||
processor = RenameProcessor(rename_map=rename_map)
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
"a": 1,
|
|
||||||
"b": 2,
|
|
||||||
"x": 3,
|
|
||||||
}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check that renaming happens correctly
|
|
||||||
assert "a" not in processed_obs
|
|
||||||
assert processed_obs["b"] == 1 # 'a' renamed to 'b'
|
|
||||||
assert processed_obs["c"] == 2 # original 'b' renamed to 'c'
|
|
||||||
assert processed_obs["x"] == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_partial_rename():
|
|
||||||
"""Test renaming only some keys."""
|
|
||||||
rename_map = {
|
|
||||||
"observation.state": "observation.proprio_state",
|
|
||||||
"pixels": "observation.image",
|
|
||||||
}
|
|
||||||
processor = RenameProcessor(rename_map=rename_map)
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
"observation.state": torch.randn(10),
|
|
||||||
"pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
|
|
||||||
"reward": 1.0,
|
|
||||||
"info": {"episode": 1},
|
|
||||||
}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check renamed keys
|
|
||||||
assert "observation.proprio_state" in processed_obs
|
|
||||||
assert "observation.image" in processed_obs
|
|
||||||
assert "observation.state" not in processed_obs
|
|
||||||
assert "pixels" not in processed_obs
|
|
||||||
|
|
||||||
# Check unchanged keys
|
|
||||||
assert processed_obs["reward"] == 1.0
|
|
||||||
assert processed_obs["info"] == {"episode": 1}
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_config():
|
|
||||||
"""Test configuration serialization."""
|
|
||||||
rename_map = {
|
|
||||||
"old1": "new1",
|
|
||||||
"old2": "new2",
|
|
||||||
}
|
|
||||||
processor = RenameProcessor(rename_map=rename_map)
|
|
||||||
|
|
||||||
config = processor.get_config()
|
|
||||||
assert config == {"rename_map": rename_map}
|
|
||||||
|
|
||||||
|
|
||||||
def test_state_dict():
|
|
||||||
"""Test state dict (should be empty for RenameProcessor)."""
|
|
||||||
processor = RenameProcessor(rename_map={"old": "new"})
|
|
||||||
|
|
||||||
state = processor.state_dict()
|
|
||||||
assert state == {}
|
|
||||||
|
|
||||||
# Load state dict should work even with empty dict
|
|
||||||
processor.load_state_dict({})
|
|
||||||
|
|
||||||
|
|
||||||
def test_integration_with_robot_processor():
|
|
||||||
"""Test integration with RobotProcessor pipeline."""
|
|
||||||
rename_map = {
|
|
||||||
"agent_pos": "observation.state",
|
|
||||||
"pixels": "observation.image",
|
|
||||||
}
|
|
||||||
rename_processor = RenameProcessor(rename_map=rename_map)
|
|
||||||
|
|
||||||
pipeline = RobotProcessor([rename_processor])
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
"agent_pos": np.array([1.0, 2.0, 3.0]),
|
|
||||||
"pixels": np.zeros((32, 32, 3), dtype=np.uint8),
|
|
||||||
"other_data": "preserve_me",
|
|
||||||
}
|
|
||||||
transition = create_transition(
|
|
||||||
observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={}
|
|
||||||
)
|
|
||||||
|
|
||||||
result = pipeline(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check renaming worked through pipeline
|
|
||||||
assert "observation.state" in processed_obs
|
|
||||||
assert "observation.image" in processed_obs
|
|
||||||
assert "agent_pos" not in processed_obs
|
|
||||||
assert "pixels" not in processed_obs
|
|
||||||
assert processed_obs["other_data"] == "preserve_me"
|
|
||||||
|
|
||||||
# Check other transition elements unchanged
|
|
||||||
assert result[TransitionKey.REWARD] == 0.5
|
|
||||||
assert result[TransitionKey.DONE] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_and_load_pretrained():
|
|
||||||
"""Test saving and loading processor with RobotProcessor."""
|
|
||||||
rename_map = {
|
|
||||||
"old_state": "observation.state",
|
|
||||||
"old_image": "observation.image",
|
|
||||||
}
|
|
||||||
processor = RenameProcessor(rename_map=rename_map)
|
|
||||||
pipeline = RobotProcessor([processor], name="TestRenameProcessor")
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
# Save pipeline
|
|
||||||
pipeline.save_pretrained(tmp_dir)
|
|
||||||
|
|
||||||
# Check files were created
|
|
||||||
config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor"
|
|
||||||
assert config_path.exists()
|
|
||||||
|
|
||||||
# No state files should be created for RenameProcessor
|
|
||||||
state_files = list(Path(tmp_dir).glob("*.safetensors"))
|
|
||||||
assert len(state_files) == 0
|
|
||||||
|
|
||||||
# Load pipeline
|
|
||||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
|
||||||
|
|
||||||
assert loaded_pipeline.name == "TestRenameProcessor"
|
|
||||||
assert len(loaded_pipeline) == 1
|
|
||||||
|
|
||||||
# Check that loaded processor works correctly
|
|
||||||
loaded_processor = loaded_pipeline.steps[0]
|
|
||||||
assert isinstance(loaded_processor, RenameProcessor)
|
|
||||||
assert loaded_processor.rename_map == rename_map
|
|
||||||
|
|
||||||
# Test functionality after loading
|
|
||||||
observation = {"old_state": [1, 2, 3], "old_image": "image_data"}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = loaded_pipeline(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
assert "observation.state" in processed_obs
|
|
||||||
assert "observation.image" in processed_obs
|
|
||||||
assert processed_obs["observation.state"] == [1, 2, 3]
|
|
||||||
assert processed_obs["observation.image"] == "image_data"
|
|
||||||
|
|
||||||
|
|
||||||
def test_registry_functionality():
|
|
||||||
"""Test that RenameProcessor is properly registered."""
|
|
||||||
# Check that it's registered
|
|
||||||
assert "rename_processor" in ProcessorStepRegistry.list()
|
|
||||||
|
|
||||||
# Get from registry
|
|
||||||
retrieved_class = ProcessorStepRegistry.get("rename_processor")
|
|
||||||
assert retrieved_class is RenameProcessor
|
|
||||||
|
|
||||||
# Create instance from registry
|
|
||||||
instance = retrieved_class(rename_map={"old": "new"})
|
|
||||||
assert isinstance(instance, RenameProcessor)
|
|
||||||
assert instance.rename_map == {"old": "new"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_registry_based_save_load():
|
|
||||||
"""Test save/load using registry name instead of module path."""
|
|
||||||
processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
|
|
||||||
pipeline = RobotProcessor([processor])
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
# Save and load
|
|
||||||
pipeline.save_pretrained(tmp_dir)
|
|
||||||
|
|
||||||
# Verify config uses registry name
|
|
||||||
import json
|
|
||||||
|
|
||||||
with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor"
|
|
||||||
config = json.load(f)
|
|
||||||
|
|
||||||
assert "registry_name" in config["steps"][0]
|
|
||||||
assert config["steps"][0]["registry_name"] == "rename_processor"
|
|
||||||
assert "class" not in config["steps"][0] # Should use registry, not module path
|
|
||||||
|
|
||||||
# Load should work
|
|
||||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
|
||||||
loaded_processor = loaded_pipeline.steps[0]
|
|
||||||
assert isinstance(loaded_processor, RenameProcessor)
|
|
||||||
assert loaded_processor.rename_map == {"key1": "renamed_key1"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_chained_rename_processors():
|
|
||||||
"""Test multiple RenameProcessors in a pipeline."""
|
|
||||||
# First processor: rename raw keys to intermediate format
|
|
||||||
processor1 = RenameProcessor(
|
|
||||||
rename_map={
|
|
||||||
"pos": "agent_position",
|
|
||||||
"img": "camera_image",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Second processor: rename to final format
|
|
||||||
processor2 = RenameProcessor(
|
|
||||||
rename_map={
|
|
||||||
"agent_position": "observation.state",
|
|
||||||
"camera_image": "observation.image",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline = RobotProcessor([processor1, processor2])
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
"pos": np.array([1.0, 2.0]),
|
|
||||||
"img": "image_data",
|
|
||||||
"extra": "keep_me",
|
|
||||||
}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
# Step through to see intermediate results
|
|
||||||
results = list(pipeline.step_through(transition))
|
|
||||||
|
|
||||||
# After first processor
|
|
||||||
assert "agent_position" in results[1][TransitionKey.OBSERVATION]
|
|
||||||
assert "camera_image" in results[1][TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# After second processor
|
|
||||||
final_obs = results[2][TransitionKey.OBSERVATION]
|
|
||||||
assert "observation.state" in final_obs
|
|
||||||
assert "observation.image" in final_obs
|
|
||||||
assert final_obs["extra"] == "keep_me"
|
|
||||||
|
|
||||||
# Original keys should be gone
|
|
||||||
assert "pos" not in final_obs
|
|
||||||
assert "img" not in final_obs
|
|
||||||
assert "agent_position" not in final_obs
|
|
||||||
assert "camera_image" not in final_obs
|
|
||||||
|
|
||||||
|
|
||||||
def test_nested_observation_rename():
|
|
||||||
"""Test renaming with nested observation structures."""
|
|
||||||
rename_map = {
|
|
||||||
"observation.images.left": "observation.camera.left_view",
|
|
||||||
"observation.images.right": "observation.camera.right_view",
|
|
||||||
"observation.proprio": "observation.proprioception",
|
|
||||||
}
|
|
||||||
processor = RenameProcessor(rename_map=rename_map)
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
"observation.images.left": torch.randn(3, 64, 64),
|
|
||||||
"observation.images.right": torch.randn(3, 64, 64),
|
|
||||||
"observation.proprio": torch.randn(7),
|
|
||||||
"observation.gripper": torch.tensor([0.0]), # Not renamed
|
|
||||||
}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check renames
|
|
||||||
assert "observation.camera.left_view" in processed_obs
|
|
||||||
assert "observation.camera.right_view" in processed_obs
|
|
||||||
assert "observation.proprioception" in processed_obs
|
|
||||||
|
|
||||||
# Check unchanged key
|
|
||||||
assert "observation.gripper" in processed_obs
|
|
||||||
|
|
||||||
# Check old keys removed
|
|
||||||
assert "observation.images.left" not in processed_obs
|
|
||||||
assert "observation.images.right" not in processed_obs
|
|
||||||
assert "observation.proprio" not in processed_obs
|
|
||||||
|
|
||||||
|
|
||||||
def test_value_types_preserved():
|
|
||||||
"""Test that various value types are preserved during renaming."""
|
|
||||||
rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"}
|
|
||||||
processor = RenameProcessor(rename_map=rename_map)
|
|
||||||
|
|
||||||
tensor_value = torch.randn(3, 3)
|
|
||||||
array_value = np.random.rand(2, 2)
|
|
||||||
|
|
||||||
observation = {
|
|
||||||
"old_tensor": tensor_value,
|
|
||||||
"old_array": array_value,
|
|
||||||
"old_scalar": 42,
|
|
||||||
"old_string": "hello",
|
|
||||||
"old_dict": {"nested": "value"},
|
|
||||||
"old_list": [1, 2, 3],
|
|
||||||
}
|
|
||||||
transition = create_transition(observation=observation)
|
|
||||||
|
|
||||||
result = processor(transition)
|
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# Check that values and types are preserved
|
|
||||||
assert torch.equal(processed_obs["new_tensor"], tensor_value)
|
|
||||||
assert np.array_equal(processed_obs["new_array"], array_value)
|
|
||||||
assert processed_obs["new_scalar"] == 42
|
|
||||||
assert processed_obs["old_string"] == "hello"
|
|
||||||
assert processed_obs["old_dict"] == {"nested": "value"}
|
|
||||||
assert processed_obs["old_list"] == [1, 2, 3]
|
|
||||||
|
|
||||||
|
|
||||||
def test_feature_contract_basic_renaming(policy_feature_factory):
|
|
||||||
processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
|
|
||||||
features = {
|
|
||||||
"a": policy_feature_factory(FeatureType.STATE, (2,)),
|
|
||||||
"b": policy_feature_factory(FeatureType.ACTION, (3,)),
|
|
||||||
"c": policy_feature_factory(FeatureType.ENV, (1,)),
|
|
||||||
}
|
|
||||||
|
|
||||||
out = processor.feature_contract(features.copy())
|
|
||||||
|
|
||||||
# Values preserved and typed
|
|
||||||
assert out["x"] == features["a"]
|
|
||||||
assert out["y"] == features["b"]
|
|
||||||
assert out["c"] == features["c"]
|
|
||||||
|
|
||||||
assert_contract_is_typed(out)
|
|
||||||
# Input not mutated
|
|
||||||
assert set(features) == {"a", "b", "c"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_feature_contract_overlapping_keys(policy_feature_factory):
|
|
||||||
# Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
|
|
||||||
processor = RenameProcessor(rename_map={"a": "b", "b": "c"})
|
|
||||||
features = {
|
|
||||||
"a": policy_feature_factory(FeatureType.STATE, (1,)),
|
|
||||||
"b": policy_feature_factory(FeatureType.STATE, (2,)),
|
|
||||||
}
|
|
||||||
out = processor.feature_contract(features)
|
|
||||||
|
|
||||||
assert set(out) == {"b", "c"}
|
|
||||||
assert out["b"] == features["a"] # 'a' renamed to'b'
|
|
||||||
assert out["c"] == features["b"] # 'b' renamed to 'c'
|
|
||||||
assert_contract_is_typed(out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_feature_contract_chained_processors(policy_feature_factory):
|
|
||||||
# Chain two rename processors at the contract level
|
|
||||||
processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
|
|
||||||
processor2 = RenameProcessor(
|
|
||||||
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
|
|
||||||
)
|
|
||||||
pipeline = RobotProcessor([processor1, processor2])
|
|
||||||
|
|
||||||
spec = {
|
|
||||||
"pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
|
||||||
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
|
||||||
"extra": policy_feature_factory(FeatureType.ENV, (1,)),
|
|
||||||
}
|
|
||||||
out = pipeline.feature_contract(initial_features=spec)
|
|
||||||
|
|
||||||
assert set(out) == {"observation.state", "observation.image", "extra"}
|
|
||||||
assert out["observation.state"] == spec["pos"]
|
|
||||||
assert out["observation.image"] == spec["img"]
|
|
||||||
assert out["extra"] == spec["extra"]
|
|
||||||
assert_contract_is_typed(out)
|
|
||||||
@ -1,208 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from concurrent import futures
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from torch.multiprocessing import Event, Queue
|
|
||||||
|
|
||||||
from lerobot.utils.transition import Transition
|
|
||||||
from tests.utils import require_package
|
|
||||||
|
|
||||||
|
|
||||||
def create_learner_service_stub():
|
|
||||||
import grpc
|
|
||||||
|
|
||||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
|
||||||
|
|
||||||
class MockLearnerService(services_pb2_grpc.LearnerServiceServicer):
|
|
||||||
def __init__(self):
|
|
||||||
self.ready_call_count = 0
|
|
||||||
self.should_fail = False
|
|
||||||
|
|
||||||
def Ready(self, request, context): # noqa: N802
|
|
||||||
self.ready_call_count += 1
|
|
||||||
if self.should_fail:
|
|
||||||
context.set_code(grpc.StatusCode.UNAVAILABLE)
|
|
||||||
context.set_details("Service unavailable")
|
|
||||||
raise grpc.RpcError("Service unavailable")
|
|
||||||
return services_pb2.Empty()
|
|
||||||
|
|
||||||
"""Fixture to start a LearnerService gRPC server and provide a connected stub."""
|
|
||||||
|
|
||||||
servicer = MockLearnerService()
|
|
||||||
|
|
||||||
# Create a gRPC server and add our servicer to it.
|
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
|
|
||||||
services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server)
|
|
||||||
port = server.add_insecure_port("[::]:0") # bind to a free port chosen by OS
|
|
||||||
server.start() # start the server (non-blocking call):contentReference[oaicite:1]{index=1}
|
|
||||||
|
|
||||||
# Create a client channel and stub connected to the server's port.
|
|
||||||
channel = grpc.insecure_channel(f"localhost:{port}")
|
|
||||||
return services_pb2_grpc.LearnerServiceStub(channel), servicer, channel, server
|
|
||||||
|
|
||||||
|
|
||||||
def close_service_stub(channel, server):
|
|
||||||
channel.close()
|
|
||||||
server.stop(None)
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
def test_establish_learner_connection_success():
|
|
||||||
from lerobot.scripts.rl.actor import establish_learner_connection
|
|
||||||
|
|
||||||
"""Test successful connection establishment."""
|
|
||||||
stub, _servicer, channel, server = create_learner_service_stub()
|
|
||||||
|
|
||||||
shutdown_event = Event()
|
|
||||||
|
|
||||||
# Test successful connection
|
|
||||||
result = establish_learner_connection(stub, shutdown_event, attempts=5)
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
close_service_stub(channel, server)
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
def test_establish_learner_connection_failure():
|
|
||||||
from lerobot.scripts.rl.actor import establish_learner_connection
|
|
||||||
|
|
||||||
"""Test connection failure."""
|
|
||||||
stub, servicer, channel, server = create_learner_service_stub()
|
|
||||||
servicer.should_fail = True
|
|
||||||
|
|
||||||
shutdown_event = Event()
|
|
||||||
|
|
||||||
# Test failed connection
|
|
||||||
with patch("time.sleep"): # Speed up the test
|
|
||||||
result = establish_learner_connection(stub, shutdown_event, attempts=2)
|
|
||||||
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
close_service_stub(channel, server)
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
def test_push_transitions_to_transport_queue():
|
|
||||||
from lerobot.scripts.rl.actor import push_transitions_to_transport_queue
|
|
||||||
from lerobot.transport.utils import bytes_to_transitions
|
|
||||||
from tests.transport.test_transport_utils import assert_transitions_equal
|
|
||||||
|
|
||||||
"""Test pushing transitions to transport queue."""
|
|
||||||
# Create mock transitions
|
|
||||||
transitions = []
|
|
||||||
for i in range(3):
|
|
||||||
transition = Transition(
|
|
||||||
state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
|
|
||||||
action=torch.randn(5),
|
|
||||||
reward=torch.tensor(1.0 + i),
|
|
||||||
done=torch.tensor(False),
|
|
||||||
truncated=torch.tensor(False),
|
|
||||||
next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
|
|
||||||
complementary_info={"step": torch.tensor(i)},
|
|
||||||
)
|
|
||||||
transitions.append(transition)
|
|
||||||
|
|
||||||
transitions_queue = Queue()
|
|
||||||
|
|
||||||
# Test pushing transitions
|
|
||||||
push_transitions_to_transport_queue(transitions, transitions_queue)
|
|
||||||
|
|
||||||
# Verify the data can be retrieved
|
|
||||||
serialized_data = transitions_queue.get()
|
|
||||||
assert isinstance(serialized_data, bytes)
|
|
||||||
deserialized_transitions = bytes_to_transitions(serialized_data)
|
|
||||||
assert len(deserialized_transitions) == len(transitions)
|
|
||||||
for i, deserialized_transition in enumerate(deserialized_transitions):
|
|
||||||
assert_transitions_equal(deserialized_transition, transitions[i])
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
|
||||||
def test_transitions_stream():
|
|
||||||
from lerobot.scripts.rl.actor import transitions_stream
|
|
||||||
|
|
||||||
"""Test transitions stream functionality."""
|
|
||||||
shutdown_event = Event()
|
|
||||||
transitions_queue = Queue()
|
|
||||||
|
|
||||||
# Add test data to queue
|
|
||||||
test_data = [b"transition_data_1", b"transition_data_2", b"transition_data_3"]
|
|
||||||
for data in test_data:
|
|
||||||
transitions_queue.put(data)
|
|
||||||
|
|
||||||
# Collect streamed data
|
|
||||||
streamed_data = []
|
|
||||||
stream_generator = transitions_stream(shutdown_event, transitions_queue, 0.1)
|
|
||||||
|
|
||||||
# Process a few items
|
|
||||||
for i, message in enumerate(stream_generator):
|
|
||||||
streamed_data.append(message)
|
|
||||||
if i >= len(test_data) - 1:
|
|
||||||
shutdown_event.set()
|
|
||||||
break
|
|
||||||
|
|
||||||
# Verify we got messages
|
|
||||||
assert len(streamed_data) == len(test_data)
|
|
||||||
assert streamed_data[0].data == b"transition_data_1"
|
|
||||||
assert streamed_data[1].data == b"transition_data_2"
|
|
||||||
assert streamed_data[2].data == b"transition_data_3"
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
|
||||||
def test_interactions_stream():
|
|
||||||
from lerobot.scripts.rl.actor import interactions_stream
|
|
||||||
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
|
|
||||||
|
|
||||||
"""Test interactions stream functionality."""
|
|
||||||
shutdown_event = Event()
|
|
||||||
interactions_queue = Queue()
|
|
||||||
|
|
||||||
# Create test interaction data (similar structure to what would be sent)
|
|
||||||
test_interactions = [
|
|
||||||
{"episode_reward": 10.5, "step": 1, "policy_fps": 30.2},
|
|
||||||
{"episode_reward": 15.2, "step": 2, "policy_fps": 28.7},
|
|
||||||
{"episode_reward": 8.7, "step": 3, "policy_fps": 29.1},
|
|
||||||
]
|
|
||||||
|
|
||||||
# Serialize the interaction data as it would be in practice
|
|
||||||
test_data = [
|
|
||||||
interactions_queue.put(python_object_to_bytes(interaction)) for interaction in test_interactions
|
|
||||||
]
|
|
||||||
|
|
||||||
# Collect streamed data
|
|
||||||
streamed_data = []
|
|
||||||
stream_generator = interactions_stream(shutdown_event, interactions_queue, 0.1)
|
|
||||||
|
|
||||||
# Process the items
|
|
||||||
for i, message in enumerate(stream_generator):
|
|
||||||
streamed_data.append(message)
|
|
||||||
if i >= len(test_data) - 1:
|
|
||||||
shutdown_event.set()
|
|
||||||
break
|
|
||||||
|
|
||||||
# Verify we got messages
|
|
||||||
assert len(streamed_data) == len(test_data)
|
|
||||||
|
|
||||||
# Verify the messages can be deserialized back to original data
|
|
||||||
for i, message in enumerate(streamed_data):
|
|
||||||
deserialized_interaction = bytes_to_python_object(message.data)
|
|
||||||
assert deserialized_interaction == test_interactions[i]
|
|
||||||
@ -1,297 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import socket
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from torch.multiprocessing import Event, Queue
|
|
||||||
|
|
||||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
|
||||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
|
||||||
from lerobot.utils.transition import Transition
|
|
||||||
from tests.utils import require_package
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_transitions(count: int = 3) -> list[Transition]:
|
|
||||||
"""Create test transitions for integration testing."""
|
|
||||||
transitions = []
|
|
||||||
for i in range(count):
|
|
||||||
transition = Transition(
|
|
||||||
state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
|
|
||||||
action=torch.randn(5),
|
|
||||||
reward=torch.tensor(1.0 + i),
|
|
||||||
done=torch.tensor(i == count - 1), # Last transition is done
|
|
||||||
truncated=torch.tensor(False),
|
|
||||||
next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
|
|
||||||
complementary_info={"step": torch.tensor(i), "episode_id": i // 2},
|
|
||||||
)
|
|
||||||
transitions.append(transition)
|
|
||||||
return transitions
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_interactions(count: int = 3) -> list[dict]:
|
|
||||||
"""Create test interactions for integration testing."""
|
|
||||||
interactions = []
|
|
||||||
for i in range(count):
|
|
||||||
interaction = {
|
|
||||||
"episode_reward": 10.0 + i * 5,
|
|
||||||
"step": i * 100,
|
|
||||||
"policy_fps": 30.0 + i,
|
|
||||||
"intervention_rate": 0.1 * i,
|
|
||||||
"episode_length": 200 + i * 50,
|
|
||||||
}
|
|
||||||
interactions.append(interaction)
|
|
||||||
return interactions
|
|
||||||
|
|
||||||
|
|
||||||
def find_free_port():
|
|
||||||
"""Finds a free port on the local machine."""
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
s.bind(("", 0)) # Bind to port 0 to let the OS choose a free port
|
|
||||||
s.listen(1)
|
|
||||||
port = s.getsockname()[1]
|
|
||||||
return port
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def cfg():
|
|
||||||
cfg = TrainRLServerPipelineConfig()
|
|
||||||
|
|
||||||
port = find_free_port()
|
|
||||||
|
|
||||||
policy_cfg = SACConfig()
|
|
||||||
policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
|
|
||||||
policy_cfg.actor_learner_config.learner_port = port
|
|
||||||
policy_cfg.concurrency.actor = "threads"
|
|
||||||
policy_cfg.concurrency.learner = "threads"
|
|
||||||
policy_cfg.actor_learner_config.queue_get_timeout = 0.1
|
|
||||||
|
|
||||||
cfg.policy = policy_cfg
|
|
||||||
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
@pytest.mark.timeout(10) # force cross-platform watchdog
|
|
||||||
def test_end_to_end_transitions_flow(cfg):
|
|
||||||
from lerobot.scripts.rl.actor import (
|
|
||||||
establish_learner_connection,
|
|
||||||
learner_service_client,
|
|
||||||
push_transitions_to_transport_queue,
|
|
||||||
send_transitions,
|
|
||||||
)
|
|
||||||
from lerobot.scripts.rl.learner import start_learner
|
|
||||||
from lerobot.transport.utils import bytes_to_transitions
|
|
||||||
from tests.transport.test_transport_utils import assert_transitions_equal
|
|
||||||
|
|
||||||
"""Test complete transitions flow from actor to learner."""
|
|
||||||
transitions_actor_queue = Queue()
|
|
||||||
transitions_learner_queue = Queue()
|
|
||||||
|
|
||||||
interactions_queue = Queue()
|
|
||||||
parameters_queue = Queue()
|
|
||||||
shutdown_event = Event()
|
|
||||||
|
|
||||||
learner_thread = threading.Thread(
|
|
||||||
target=start_learner,
|
|
||||||
args=(parameters_queue, transitions_learner_queue, interactions_queue, shutdown_event, cfg),
|
|
||||||
)
|
|
||||||
learner_thread.start()
|
|
||||||
|
|
||||||
policy_cfg = cfg.policy
|
|
||||||
learner_client, channel = learner_service_client(
|
|
||||||
host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
|
|
||||||
)
|
|
||||||
|
|
||||||
assert establish_learner_connection(learner_client, shutdown_event, attempts=5)
|
|
||||||
|
|
||||||
send_transitions_thread = threading.Thread(
|
|
||||||
target=send_transitions, args=(cfg, transitions_actor_queue, shutdown_event, learner_client, channel)
|
|
||||||
)
|
|
||||||
send_transitions_thread.start()
|
|
||||||
|
|
||||||
input_transitions = create_test_transitions(count=5)
|
|
||||||
|
|
||||||
push_transitions_to_transport_queue(input_transitions, transitions_actor_queue)
|
|
||||||
|
|
||||||
# Wait for learner to start
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
shutdown_event.set()
|
|
||||||
|
|
||||||
# Wait for learner to receive transitions
|
|
||||||
learner_thread.join()
|
|
||||||
send_transitions_thread.join()
|
|
||||||
channel.close()
|
|
||||||
|
|
||||||
received_transitions = []
|
|
||||||
while not transitions_learner_queue.empty():
|
|
||||||
received_transitions.extend(bytes_to_transitions(transitions_learner_queue.get()))
|
|
||||||
|
|
||||||
assert len(received_transitions) == len(input_transitions)
|
|
||||||
for i, transition in enumerate(received_transitions):
|
|
||||||
assert_transitions_equal(transition, input_transitions[i])
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
@pytest.mark.timeout(10)
|
|
||||||
def test_end_to_end_interactions_flow(cfg):
|
|
||||||
from lerobot.scripts.rl.actor import (
|
|
||||||
establish_learner_connection,
|
|
||||||
learner_service_client,
|
|
||||||
send_interactions,
|
|
||||||
)
|
|
||||||
from lerobot.scripts.rl.learner import start_learner
|
|
||||||
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
|
|
||||||
|
|
||||||
"""Test complete interactions flow from actor to learner."""
|
|
||||||
# Queues for actor-learner communication
|
|
||||||
interactions_actor_queue = Queue()
|
|
||||||
interactions_learner_queue = Queue()
|
|
||||||
|
|
||||||
# Other queues required by the learner
|
|
||||||
parameters_queue = Queue()
|
|
||||||
transitions_learner_queue = Queue()
|
|
||||||
|
|
||||||
shutdown_event = Event()
|
|
||||||
|
|
||||||
# Start the learner in a separate thread
|
|
||||||
learner_thread = threading.Thread(
|
|
||||||
target=start_learner,
|
|
||||||
args=(parameters_queue, transitions_learner_queue, interactions_learner_queue, shutdown_event, cfg),
|
|
||||||
)
|
|
||||||
learner_thread.start()
|
|
||||||
|
|
||||||
# Establish connection from actor to learner
|
|
||||||
policy_cfg = cfg.policy
|
|
||||||
learner_client, channel = learner_service_client(
|
|
||||||
host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
|
|
||||||
)
|
|
||||||
|
|
||||||
assert establish_learner_connection(learner_client, shutdown_event, attempts=5)
|
|
||||||
|
|
||||||
# Start the actor's interaction sending process in a separate thread
|
|
||||||
send_interactions_thread = threading.Thread(
|
|
||||||
target=send_interactions,
|
|
||||||
args=(cfg, interactions_actor_queue, shutdown_event, learner_client, channel),
|
|
||||||
)
|
|
||||||
send_interactions_thread.start()
|
|
||||||
|
|
||||||
# Create and push test interactions to the actor's queue
|
|
||||||
input_interactions = create_test_interactions(count=5)
|
|
||||||
for interaction in input_interactions:
|
|
||||||
interactions_actor_queue.put(python_object_to_bytes(interaction))
|
|
||||||
|
|
||||||
# Wait for the communication to happen
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
# Signal shutdown and wait for threads to complete
|
|
||||||
shutdown_event.set()
|
|
||||||
learner_thread.join()
|
|
||||||
send_interactions_thread.join()
|
|
||||||
channel.close()
|
|
||||||
|
|
||||||
# Verify that the learner received the interactions
|
|
||||||
received_interactions = []
|
|
||||||
while not interactions_learner_queue.empty():
|
|
||||||
received_interactions.append(bytes_to_python_object(interactions_learner_queue.get()))
|
|
||||||
|
|
||||||
assert len(received_interactions) == len(input_interactions)
|
|
||||||
|
|
||||||
# Sort by a unique key to handle potential reordering in queues
|
|
||||||
received_interactions.sort(key=lambda x: x["step"])
|
|
||||||
input_interactions.sort(key=lambda x: x["step"])
|
|
||||||
|
|
||||||
for received, expected in zip(received_interactions, input_interactions, strict=False):
|
|
||||||
assert received == expected
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
@pytest.mark.parametrize("data_size", ["small", "large"])
|
|
||||||
@pytest.mark.timeout(10)
|
|
||||||
def test_end_to_end_parameters_flow(cfg, data_size):
|
|
||||||
from lerobot.scripts.rl.actor import establish_learner_connection, learner_service_client, receive_policy
|
|
||||||
from lerobot.scripts.rl.learner import start_learner
|
|
||||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
|
||||||
|
|
||||||
"""Test complete parameter flow from learner to actor, with small and large data."""
|
|
||||||
# Actor's local queue to receive params
|
|
||||||
parameters_actor_queue = Queue()
|
|
||||||
# Learner's queue to send params from
|
|
||||||
parameters_learner_queue = Queue()
|
|
||||||
|
|
||||||
# Other queues required by the learner
|
|
||||||
transitions_learner_queue = Queue()
|
|
||||||
interactions_learner_queue = Queue()
|
|
||||||
|
|
||||||
shutdown_event = Event()
|
|
||||||
|
|
||||||
# Start the learner in a separate thread
|
|
||||||
learner_thread = threading.Thread(
|
|
||||||
target=start_learner,
|
|
||||||
args=(
|
|
||||||
parameters_learner_queue,
|
|
||||||
transitions_learner_queue,
|
|
||||||
interactions_learner_queue,
|
|
||||||
shutdown_event,
|
|
||||||
cfg,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
learner_thread.start()
|
|
||||||
|
|
||||||
# Establish connection from actor to learner
|
|
||||||
policy_cfg = cfg.policy
|
|
||||||
learner_client, channel = learner_service_client(
|
|
||||||
host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
|
|
||||||
)
|
|
||||||
|
|
||||||
assert establish_learner_connection(learner_client, shutdown_event, attempts=5)
|
|
||||||
|
|
||||||
# Start the actor's parameter receiving process in a separate thread
|
|
||||||
receive_params_thread = threading.Thread(
|
|
||||||
target=receive_policy,
|
|
||||||
args=(cfg, parameters_actor_queue, shutdown_event, learner_client, channel),
|
|
||||||
)
|
|
||||||
receive_params_thread.start()
|
|
||||||
|
|
||||||
# Create test parameters based on parametrization
|
|
||||||
if data_size == "small":
|
|
||||||
input_params = {"layer.weight": torch.randn(128, 64)}
|
|
||||||
else: # "large"
|
|
||||||
# CHUNK_SIZE is 2MB, so this tensor (4MB) will force chunking
|
|
||||||
input_params = {"large_layer.weight": torch.randn(1024, 1024)}
|
|
||||||
|
|
||||||
# Simulate learner having new parameters to send
|
|
||||||
parameters_learner_queue.put(state_to_bytes(input_params))
|
|
||||||
|
|
||||||
# Wait for the actor to receive the parameters
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
# Signal shutdown and wait for threads to complete
|
|
||||||
shutdown_event.set()
|
|
||||||
learner_thread.join()
|
|
||||||
receive_params_thread.join()
|
|
||||||
channel.close()
|
|
||||||
|
|
||||||
# Verify that the actor received the parameters correctly
|
|
||||||
received_params = bytes_to_state_dict(parameters_actor_queue.get())
|
|
||||||
|
|
||||||
assert received_params.keys() == input_params.keys()
|
|
||||||
for key in input_params:
|
|
||||||
assert torch.allclose(received_params[key], input_params[key])
|
|
||||||
@ -1,374 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from concurrent import futures
|
|
||||||
from multiprocessing import Event, Queue
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from tests.utils import require_package # our gRPC servicer class
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
def learner_service_stub():
|
|
||||||
shutdown_event = Event()
|
|
||||||
parameters_queue = Queue()
|
|
||||||
transitions_queue = Queue()
|
|
||||||
interactions_queue = Queue()
|
|
||||||
seconds_between_pushes = 1
|
|
||||||
client, channel, server = create_learner_service_stub(
|
|
||||||
shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
|
|
||||||
)
|
|
||||||
|
|
||||||
yield client # provide the stub to the test function
|
|
||||||
|
|
||||||
close_learner_service_stub(channel, server)
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
def create_learner_service_stub(
|
|
||||||
shutdown_event: Event,
|
|
||||||
parameters_queue: Queue,
|
|
||||||
transitions_queue: Queue,
|
|
||||||
interactions_queue: Queue,
|
|
||||||
seconds_between_pushes: int,
|
|
||||||
queue_get_timeout: float = 0.1,
|
|
||||||
):
|
|
||||||
import grpc
|
|
||||||
|
|
||||||
from lerobot.scripts.rl.learner_service import LearnerService
|
|
||||||
from lerobot.transport import services_pb2_grpc # generated from .proto
|
|
||||||
|
|
||||||
"""Fixture to start a LearnerService gRPC server and provide a connected stub."""
|
|
||||||
|
|
||||||
servicer = LearnerService(
|
|
||||||
shutdown_event=shutdown_event,
|
|
||||||
parameters_queue=parameters_queue,
|
|
||||||
seconds_between_pushes=seconds_between_pushes,
|
|
||||||
transition_queue=transitions_queue,
|
|
||||||
interaction_message_queue=interactions_queue,
|
|
||||||
queue_get_timeout=queue_get_timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a gRPC server and add our servicer to it.
|
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
|
|
||||||
services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server)
|
|
||||||
port = server.add_insecure_port("[::]:0") # bind to a free port chosen by OS
|
|
||||||
server.start() # start the server (non-blocking call):contentReference[oaicite:1]{index=1}
|
|
||||||
|
|
||||||
# Create a client channel and stub connected to the server's port.
|
|
||||||
channel = grpc.insecure_channel(f"localhost:{port}")
|
|
||||||
return services_pb2_grpc.LearnerServiceStub(channel), channel, server
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
def close_learner_service_stub(channel, server):
|
|
||||||
channel.close()
|
|
||||||
server.stop(None)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
|
||||||
def test_ready_method(learner_service_stub):
|
|
||||||
from lerobot.transport import services_pb2
|
|
||||||
|
|
||||||
"""Test the ready method of the UserService."""
|
|
||||||
request = services_pb2.Empty()
|
|
||||||
response = learner_service_stub.Ready(request)
|
|
||||||
assert response == services_pb2.Empty()
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
|
||||||
def test_send_interactions():
|
|
||||||
from lerobot.transport import services_pb2
|
|
||||||
|
|
||||||
shutdown_event = Event()
|
|
||||||
|
|
||||||
parameters_queue = Queue()
|
|
||||||
transitions_queue = Queue()
|
|
||||||
interactions_queue = Queue()
|
|
||||||
seconds_between_pushes = 1
|
|
||||||
client, channel, server = create_learner_service_stub(
|
|
||||||
shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
|
|
||||||
)
|
|
||||||
|
|
||||||
list_of_interaction_messages = [
|
|
||||||
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"1"),
|
|
||||||
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"2"),
|
|
||||||
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"3"),
|
|
||||||
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"4"),
|
|
||||||
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"5"),
|
|
||||||
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"6"),
|
|
||||||
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"7"),
|
|
||||||
services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"8"),
|
|
||||||
]
|
|
||||||
|
|
||||||
def mock_intercations_stream():
|
|
||||||
yield from list_of_interaction_messages
|
|
||||||
|
|
||||||
return services_pb2.Empty()
|
|
||||||
|
|
||||||
response = client.SendInteractions(mock_intercations_stream())
|
|
||||||
assert response == services_pb2.Empty()
|
|
||||||
|
|
||||||
close_learner_service_stub(channel, server)
|
|
||||||
|
|
||||||
# Extract the data from the interactions queue
|
|
||||||
interactions = []
|
|
||||||
while not interactions_queue.empty():
|
|
||||||
interactions.append(interactions_queue.get())
|
|
||||||
|
|
||||||
assert interactions == [b"123", b"4", b"5", b"678"]
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
|
||||||
def test_send_transitions():
|
|
||||||
from lerobot.transport import services_pb2
|
|
||||||
|
|
||||||
"""Test the SendTransitions method with various transition data."""
|
|
||||||
shutdown_event = Event()
|
|
||||||
parameters_queue = Queue()
|
|
||||||
transitions_queue = Queue()
|
|
||||||
interactions_queue = Queue()
|
|
||||||
seconds_between_pushes = 1
|
|
||||||
|
|
||||||
client, channel, server = create_learner_service_stub(
|
|
||||||
shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create test transition messages
|
|
||||||
list_of_transition_messages = [
|
|
||||||
services_pb2.Transition(
|
|
||||||
transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"transition_1"
|
|
||||||
),
|
|
||||||
services_pb2.Transition(
|
|
||||||
transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"transition_2"
|
|
||||||
),
|
|
||||||
services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"transition_3"),
|
|
||||||
services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"batch_1"),
|
|
||||||
services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"batch_2"),
|
|
||||||
]
|
|
||||||
|
|
||||||
def mock_transitions_stream():
|
|
||||||
yield from list_of_transition_messages
|
|
||||||
|
|
||||||
response = client.SendTransitions(mock_transitions_stream())
|
|
||||||
assert response == services_pb2.Empty()
|
|
||||||
|
|
||||||
close_learner_service_stub(channel, server)
|
|
||||||
|
|
||||||
# Extract the data from the transitions queue
|
|
||||||
transitions = []
|
|
||||||
while not transitions_queue.empty():
|
|
||||||
transitions.append(transitions_queue.get())
|
|
||||||
|
|
||||||
# Should have assembled the chunked data
|
|
||||||
assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"]
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
|
||||||
def test_send_transitions_empty_stream():
|
|
||||||
from lerobot.transport import services_pb2
|
|
||||||
|
|
||||||
"""Test SendTransitions with empty stream."""
|
|
||||||
shutdown_event = Event()
|
|
||||||
parameters_queue = Queue()
|
|
||||||
transitions_queue = Queue()
|
|
||||||
interactions_queue = Queue()
|
|
||||||
seconds_between_pushes = 1
|
|
||||||
|
|
||||||
client, channel, server = create_learner_service_stub(
|
|
||||||
shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
|
|
||||||
)
|
|
||||||
|
|
||||||
def empty_stream():
|
|
||||||
return iter([])
|
|
||||||
|
|
||||||
response = client.SendTransitions(empty_stream())
|
|
||||||
assert response == services_pb2.Empty()
|
|
||||||
|
|
||||||
close_learner_service_stub(channel, server)
|
|
||||||
|
|
||||||
# Queue should remain empty
|
|
||||||
assert transitions_queue.empty()
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
@pytest.mark.timeout(10) # force cross-platform watchdog
|
|
||||||
def test_stream_parameters():
|
|
||||||
import time
|
|
||||||
|
|
||||||
from lerobot.transport import services_pb2
|
|
||||||
|
|
||||||
"""Test the StreamParameters method."""
|
|
||||||
shutdown_event = Event()
|
|
||||||
parameters_queue = Queue()
|
|
||||||
transitions_queue = Queue()
|
|
||||||
interactions_queue = Queue()
|
|
||||||
seconds_between_pushes = 0.2 # Short delay for testing
|
|
||||||
|
|
||||||
client, channel, server = create_learner_service_stub(
|
|
||||||
shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add test parameters to the queue
|
|
||||||
test_params = [b"param_batch_1", b"param_batch_2"]
|
|
||||||
for param in test_params:
|
|
||||||
parameters_queue.put(param)
|
|
||||||
|
|
||||||
# Start streaming parameters
|
|
||||||
request = services_pb2.Empty()
|
|
||||||
stream = client.StreamParameters(request)
|
|
||||||
|
|
||||||
# Collect streamed parameters and timestamps
|
|
||||||
received_params = []
|
|
||||||
timestamps = []
|
|
||||||
|
|
||||||
for response in stream:
|
|
||||||
received_params.append(response.data)
|
|
||||||
timestamps.append(time.time())
|
|
||||||
|
|
||||||
# We should receive one last item
|
|
||||||
break
|
|
||||||
|
|
||||||
parameters_queue.put(b"param_batch_3")
|
|
||||||
|
|
||||||
for response in stream:
|
|
||||||
received_params.append(response.data)
|
|
||||||
timestamps.append(time.time())
|
|
||||||
|
|
||||||
# We should receive only one item
|
|
||||||
break
|
|
||||||
|
|
||||||
shutdown_event.set()
|
|
||||||
close_learner_service_stub(channel, server)
|
|
||||||
|
|
||||||
assert received_params == [b"param_batch_2", b"param_batch_3"]
|
|
||||||
|
|
||||||
# Check the time difference between the two sends
|
|
||||||
time_diff = timestamps[1] - timestamps[0]
|
|
||||||
# Check if the time difference is close to the expected push frequency
|
|
||||||
assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1)
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
|
||||||
def test_stream_parameters_with_shutdown():
|
|
||||||
from lerobot.transport import services_pb2
|
|
||||||
|
|
||||||
"""Test StreamParameters handles shutdown gracefully."""
|
|
||||||
shutdown_event = Event()
|
|
||||||
parameters_queue = Queue()
|
|
||||||
transitions_queue = Queue()
|
|
||||||
interactions_queue = Queue()
|
|
||||||
seconds_between_pushes = 0.1
|
|
||||||
queue_get_timeout = 0.001
|
|
||||||
|
|
||||||
client, channel, server = create_learner_service_stub(
|
|
||||||
shutdown_event,
|
|
||||||
parameters_queue,
|
|
||||||
transitions_queue,
|
|
||||||
interactions_queue,
|
|
||||||
seconds_between_pushes,
|
|
||||||
queue_get_timeout=queue_get_timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
test_params = [b"param_batch_1", b"stop", b"param_batch_3", b"param_batch_4"]
|
|
||||||
|
|
||||||
# create a thread that will put the parameters in the queue
|
|
||||||
def producer():
|
|
||||||
for param in test_params:
|
|
||||||
parameters_queue.put(param)
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
producer_thread = threading.Thread(target=producer)
|
|
||||||
producer_thread.start()
|
|
||||||
|
|
||||||
# Start streaming
|
|
||||||
request = services_pb2.Empty()
|
|
||||||
stream = client.StreamParameters(request)
|
|
||||||
|
|
||||||
# Collect streamed parameters
|
|
||||||
received_params = []
|
|
||||||
|
|
||||||
for response in stream:
|
|
||||||
received_params.append(response.data)
|
|
||||||
|
|
||||||
if response.data == b"stop":
|
|
||||||
shutdown_event.set()
|
|
||||||
|
|
||||||
producer_thread.join()
|
|
||||||
close_learner_service_stub(channel, server)
|
|
||||||
|
|
||||||
assert received_params == [b"param_batch_1", b"stop"]
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("grpc")
|
|
||||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
|
||||||
def test_stream_parameters_waits_and_retries_on_empty_queue():
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
|
|
||||||
from lerobot.transport import services_pb2
|
|
||||||
|
|
||||||
"""Test that StreamParameters waits and retries when the queue is empty."""
|
|
||||||
shutdown_event = Event()
|
|
||||||
parameters_queue = Queue()
|
|
||||||
transitions_queue = Queue()
|
|
||||||
interactions_queue = Queue()
|
|
||||||
seconds_between_pushes = 0.05
|
|
||||||
queue_get_timeout = 0.01
|
|
||||||
|
|
||||||
client, channel, server = create_learner_service_stub(
|
|
||||||
shutdown_event,
|
|
||||||
parameters_queue,
|
|
||||||
transitions_queue,
|
|
||||||
interactions_queue,
|
|
||||||
seconds_between_pushes,
|
|
||||||
queue_get_timeout=queue_get_timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
request = services_pb2.Empty()
|
|
||||||
stream = client.StreamParameters(request)
|
|
||||||
|
|
||||||
received_params = []
|
|
||||||
|
|
||||||
def producer():
|
|
||||||
# Let the consumer start and find an empty queue.
|
|
||||||
# It will wait `seconds_between_pushes` (0.05s), then `get` will timeout after `queue_get_timeout` (0.01s).
|
|
||||||
# Total time for the first empty loop is > 0.06s. We wait a bit longer to be safe.
|
|
||||||
time.sleep(0.06)
|
|
||||||
parameters_queue.put(b"param_after_wait")
|
|
||||||
time.sleep(0.05)
|
|
||||||
parameters_queue.put(b"param_after_wait_2")
|
|
||||||
|
|
||||||
producer_thread = threading.Thread(target=producer)
|
|
||||||
producer_thread.start()
|
|
||||||
|
|
||||||
# The consumer will block here until the producer sends an item.
|
|
||||||
for response in stream:
|
|
||||||
received_params.append(response.data)
|
|
||||||
if response.data == b"param_after_wait_2":
|
|
||||||
break # We only need one item for this test.
|
|
||||||
|
|
||||||
shutdown_event.set()
|
|
||||||
producer_thread.join()
|
|
||||||
close_learner_service_stub(channel, server)
|
|
||||||
|
|
||||||
assert received_params == [b"param_after_wait", b"param_after_wait_2"]
|
|
||||||
@ -1,111 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lerobot.robots.so100_follower import (
|
|
||||||
SO100Follower,
|
|
||||||
SO100FollowerConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_bus_mock() -> MagicMock:
|
|
||||||
"""Return a bus mock with just the attributes used by the robot."""
|
|
||||||
bus = MagicMock(name="FeetechBusMock")
|
|
||||||
bus.is_connected = False
|
|
||||||
|
|
||||||
def _connect():
|
|
||||||
bus.is_connected = True
|
|
||||||
|
|
||||||
def _disconnect(_disable=True):
|
|
||||||
bus.is_connected = False
|
|
||||||
|
|
||||||
bus.connect.side_effect = _connect
|
|
||||||
bus.disconnect.side_effect = _disconnect
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _dummy_cm():
|
|
||||||
yield
|
|
||||||
|
|
||||||
bus.torque_disabled.side_effect = _dummy_cm
|
|
||||||
|
|
||||||
return bus
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def follower():
|
|
||||||
bus_mock = _make_bus_mock()
|
|
||||||
|
|
||||||
def _bus_side_effect(*_args, **kwargs):
|
|
||||||
bus_mock.motors = kwargs["motors"]
|
|
||||||
motors_order: list[str] = list(bus_mock.motors)
|
|
||||||
|
|
||||||
bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)}
|
|
||||||
bus_mock.sync_write.return_value = None
|
|
||||||
bus_mock.write.return_value = None
|
|
||||||
bus_mock.disable_torque.return_value = None
|
|
||||||
bus_mock.enable_torque.return_value = None
|
|
||||||
bus_mock.is_calibrated = True
|
|
||||||
return bus_mock
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"lerobot.robots.so100_follower.so100_follower.FeetechMotorsBus",
|
|
||||||
side_effect=_bus_side_effect,
|
|
||||||
),
|
|
||||||
patch.object(SO100Follower, "configure", lambda self: None),
|
|
||||||
):
|
|
||||||
cfg = SO100FollowerConfig(port="/dev/null")
|
|
||||||
robot = SO100Follower(cfg)
|
|
||||||
yield robot
|
|
||||||
if robot.is_connected:
|
|
||||||
robot.disconnect()
|
|
||||||
|
|
||||||
|
|
||||||
def test_connect_disconnect(follower):
|
|
||||||
assert not follower.is_connected
|
|
||||||
|
|
||||||
follower.connect()
|
|
||||||
assert follower.is_connected
|
|
||||||
|
|
||||||
follower.disconnect()
|
|
||||||
assert not follower.is_connected
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_observation(follower):
|
|
||||||
follower.connect()
|
|
||||||
obs = follower.get_observation()
|
|
||||||
|
|
||||||
expected_keys = {f"{m}.pos" for m in follower.bus.motors}
|
|
||||||
assert set(obs.keys()) == expected_keys
|
|
||||||
|
|
||||||
for idx, motor in enumerate(follower.bus.motors, 1):
|
|
||||||
assert obs[f"{motor}.pos"] == idx
|
|
||||||
|
|
||||||
|
|
||||||
def test_send_action(follower):
|
|
||||||
follower.connect()
|
|
||||||
|
|
||||||
action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)}
|
|
||||||
returned = follower.send_action(action)
|
|
||||||
|
|
||||||
assert returned == action
|
|
||||||
|
|
||||||
goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)}
|
|
||||||
follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos)
|
|
||||||
@ -1,60 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import lerobot
|
|
||||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
|
||||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
|
||||||
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
|
||||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
|
||||||
from tests.utils import require_env
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
|
|
||||||
@require_env
|
|
||||||
def test_available_env_task(env_name: str, task_name: list):
|
|
||||||
"""
|
|
||||||
This test verifies that all environments listed in `lerobot/__init__.py` can
|
|
||||||
be successfully imported — if they're installed — and that their
|
|
||||||
`available_tasks_per_env` are valid.
|
|
||||||
"""
|
|
||||||
package_name = f"gym_{env_name}"
|
|
||||||
importlib.import_module(package_name)
|
|
||||||
gym_handle = f"{package_name}/{task_name}"
|
|
||||||
assert gym_handle in gym.envs.registry, gym_handle
|
|
||||||
|
|
||||||
|
|
||||||
def test_available_policies():
|
|
||||||
"""
|
|
||||||
This test verifies that the class attribute `name` for all policies is
|
|
||||||
consistent with those listed in `lerobot/__init__.py`.
|
|
||||||
"""
|
|
||||||
policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy]
|
|
||||||
policies = [pol_cls.name for pol_cls in policy_classes]
|
|
||||||
assert set(policies) == set(lerobot.available_policies), policies
|
|
||||||
|
|
||||||
|
|
||||||
def test_print():
|
|
||||||
print(lerobot.available_envs)
|
|
||||||
print(lerobot.available_tasks_per_env)
|
|
||||||
print(lerobot.available_datasets)
|
|
||||||
print(lerobot.available_datasets_per_env)
|
|
||||||
print(lerobot.available_real_world_datasets)
|
|
||||||
print(lerobot.available_policies)
|
|
||||||
print(lerobot.available_policies_per_env)
|
|
||||||
@ -1,106 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from lerobot.calibrate import CalibrateConfig, calibrate
|
|
||||||
from lerobot.record import DatasetRecordConfig, RecordConfig, record
|
|
||||||
from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay
|
|
||||||
from lerobot.teleoperate import TeleoperateConfig, teleoperate
|
|
||||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
|
||||||
from tests.mocks.mock_robot import MockRobotConfig
|
|
||||||
from tests.mocks.mock_teleop import MockTeleopConfig
|
|
||||||
|
|
||||||
|
|
||||||
def test_calibrate():
|
|
||||||
robot_cfg = MockRobotConfig()
|
|
||||||
cfg = CalibrateConfig(robot=robot_cfg)
|
|
||||||
calibrate(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
def test_teleoperate():
|
|
||||||
robot_cfg = MockRobotConfig()
|
|
||||||
teleop_cfg = MockTeleopConfig()
|
|
||||||
cfg = TeleoperateConfig(
|
|
||||||
robot=robot_cfg,
|
|
||||||
teleop=teleop_cfg,
|
|
||||||
teleop_time_s=0.1,
|
|
||||||
)
|
|
||||||
teleoperate(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
def test_record_and_resume(tmp_path):
|
|
||||||
robot_cfg = MockRobotConfig()
|
|
||||||
teleop_cfg = MockTeleopConfig()
|
|
||||||
dataset_cfg = DatasetRecordConfig(
|
|
||||||
repo_id=DUMMY_REPO_ID,
|
|
||||||
single_task="Dummy task",
|
|
||||||
root=tmp_path / "record",
|
|
||||||
num_episodes=1,
|
|
||||||
episode_time_s=0.1,
|
|
||||||
reset_time_s=0,
|
|
||||||
push_to_hub=False,
|
|
||||||
)
|
|
||||||
cfg = RecordConfig(
|
|
||||||
robot=robot_cfg,
|
|
||||||
dataset=dataset_cfg,
|
|
||||||
teleop=teleop_cfg,
|
|
||||||
play_sounds=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = record(cfg)
|
|
||||||
|
|
||||||
assert dataset.fps == 30
|
|
||||||
assert dataset.meta.total_episodes == dataset.num_episodes == 1
|
|
||||||
assert dataset.meta.total_frames == dataset.num_frames == 3
|
|
||||||
assert dataset.meta.total_tasks == 1
|
|
||||||
|
|
||||||
cfg.resume = True
|
|
||||||
dataset = record(cfg)
|
|
||||||
|
|
||||||
assert dataset.meta.total_episodes == dataset.num_episodes == 2
|
|
||||||
assert dataset.meta.total_frames == dataset.num_frames == 6
|
|
||||||
assert dataset.meta.total_tasks == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_record_and_replay(tmp_path):
|
|
||||||
robot_cfg = MockRobotConfig()
|
|
||||||
teleop_cfg = MockTeleopConfig()
|
|
||||||
record_dataset_cfg = DatasetRecordConfig(
|
|
||||||
repo_id=DUMMY_REPO_ID,
|
|
||||||
single_task="Dummy task",
|
|
||||||
root=tmp_path / "record_and_replay",
|
|
||||||
num_episodes=1,
|
|
||||||
episode_time_s=0.1,
|
|
||||||
push_to_hub=False,
|
|
||||||
)
|
|
||||||
record_cfg = RecordConfig(
|
|
||||||
robot=robot_cfg,
|
|
||||||
dataset=record_dataset_cfg,
|
|
||||||
teleop=teleop_cfg,
|
|
||||||
play_sounds=False,
|
|
||||||
)
|
|
||||||
replay_dataset_cfg = DatasetReplayConfig(
|
|
||||||
repo_id=DUMMY_REPO_ID,
|
|
||||||
episode=0,
|
|
||||||
root=tmp_path / "record_and_replay",
|
|
||||||
)
|
|
||||||
replay_cfg = ReplayConfig(
|
|
||||||
robot=robot_cfg,
|
|
||||||
dataset=replay_dataset_cfg,
|
|
||||||
play_sounds=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
record(record_cfg)
|
|
||||||
replay(replay_cfg)
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user