diff --git a/Makefile b/Makefile
index 20d2c55..d9f2f2c 100644
--- a/Makefile
+++ b/Makefile
@@ -22,8 +22,9 @@ test-end-to-end:
${MAKE} test-act-ete-eval
${MAKE} test-diffusion-ete-train
${MAKE} test-diffusion-ete-eval
- ${MAKE} test-tdmpc-ete-train
- ${MAKE} test-tdmpc-ete-eval
+ # TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc
+ # ${MAKE} test-tdmpc-ete-train
+ # ${MAKE} test-tdmpc-ete-eval
${MAKE} test-default-ete-eval
test-act-ete-train:
diff --git a/lerobot/common/datasets/_video_benchmark/README.md b/lerobot/common/datasets/_video_benchmark/README.md
new file mode 100644
index 0000000..10e8d12
--- /dev/null
+++ b/lerobot/common/datasets/_video_benchmark/README.md
@@ -0,0 +1,334 @@
+# Video benchmark
+
+
+## Questions
+
+What is the optimal trade-off between:
+- maximizing loading time with random access,
+- minimizing memory space on disk,
+- maximizing success rate of policies?
+
+How to encode videos?
+- How much compression (`-crf`)? Low compression with `0`, normal compression with `20` or extreme with `56`?
+- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
+- How many key frames (`-g`)? A key frame every `10` frames?
+
+How to decode videos?
+- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
+
+## Metrics
+
+**Percentage of data compression (higher is better)**
+`compression_factor` is the ratio of the memory space on disk taken by the original images to encode, to the memory space taken by the encoded video. For instance, `compression_factor=4` means that the video takes 4 times less memory space on disk compared to the original images.
+
+**Percentage of loading time (higher is better)**
+`load_time_factor` is the ratio of the time it takes to load original images at given timestamps, to the time it takes to decode the exact same frames from the video. Higher is better. For instance, `load_time_factor=0.5` means that decoding from video is 2 times slower than loading the original images.
+
+**Average L2 error per pixel (lower is better)**
+`avg_per_pixel_l2_error` is the average L2 error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
+
+**Loss of a pretrained policy (higher is better)** (not available)
+`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
+
+**Success rate after retraining (higher is better)** (not available)
+`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best.
+
+
+## Variables
+
+**Image content**
+We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, etc. Hence, we run this bechmark on two datasets: `pusht` (simulation) and `umi` (real-world outdoor).
+
+**Requested timestamps**
+In this benchmark, we focus on the loading time of random access, so we are not interested about sequentially loading all frames of a video like in a movie. However, the number of consecutive timestamps requested and their spacing can greatly affect the `load_time_factor`. In fact, it is expected to get faster loading time by decoding a large number of consecutive frames from a video, than to load the same data from individual images. To reflect our robotics use case, we consider a few settings:
+- `single_frame`: 1 frame,
+- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
+- `2_frames_4_space`: 2 consecutive frames with 4 frames of spacing (e.g `[t, t + 4 / fps]`),
+
+**Data augmentations**
+We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robusts (e.g. robust to color changes, compression, etc.).
+
+
+## Results
+
+**`decoder`**
+| repo_id | decoder | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- |
+| lerobot/pusht | torchvision | 0.166 | 0.0000119 |
+| lerobot/pusht | ffmpegio | 0.009 | 0.0001182 |
+| lerobot/pusht | torchaudio | 0.138 | 0.0000359 |
+| lerobot/umi_cup_in_the_wild | torchvision | 0.174 | 0.0000174 |
+| lerobot/umi_cup_in_the_wild | ffmpegio | 0.010 | 0.0000735 |
+| lerobot/umi_cup_in_the_wild | torchaudio | 0.154 | 0.0000340 |
+
+### `1_frame`
+
+**`pix_fmt`**
+| repo_id | pix_fmt | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- | --- |
+| lerobot/pusht | yuv420p | 3.788 | 0.224 | 0.0000760 |
+| lerobot/pusht | yuv444p | 3.646 | 0.185 | 0.0000443 |
+| lerobot/umi_cup_in_the_wild | yuv420p | 14.391 | 0.388 | 0.0000469 |
+| lerobot/umi_cup_in_the_wild | yuv444p | 14.932 | 0.329 | 0.0000397 |
+
+**`g`**
+| repo_id | g | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- | --- |
+| lerobot/pusht | 1 | 2.543 | 0.204 | 0.0000556 |
+| lerobot/pusht | 2 | 3.646 | 0.182 | 0.0000443 |
+| lerobot/pusht | 3 | 4.431 | 0.174 | 0.0000450 |
+| lerobot/pusht | 4 | 5.103 | 0.163 | 0.0000448 |
+| lerobot/pusht | 5 | 5.625 | 0.163 | 0.0000436 |
+| lerobot/pusht | 6 | 5.974 | 0.155 | 0.0000427 |
+| lerobot/pusht | 10 | 6.814 | 0.130 | 0.0000410 |
+| lerobot/pusht | 15 | 7.431 | 0.105 | 0.0000406 |
+| lerobot/pusht | 20 | 7.662 | 0.097 | 0.0000400 |
+| lerobot/pusht | 40 | 8.163 | 0.061 | 0.0000405 |
+| lerobot/pusht | 100 | 8.761 | 0.039 | 0.0000422 |
+| lerobot/pusht | None | 8.909 | 0.024 | 0.0000431 |
+| lerobot/umi_cup_in_the_wild | 1 | 14.411 | 0.444 | 0.0000601 |
+| lerobot/umi_cup_in_the_wild | 2 | 14.932 | 0.345 | 0.0000397 |
+| lerobot/umi_cup_in_the_wild | 3 | 20.174 | 0.282 | 0.0000416 |
+| lerobot/umi_cup_in_the_wild | 4 | 24.889 | 0.271 | 0.0000415 |
+| lerobot/umi_cup_in_the_wild | 5 | 28.825 | 0.260 | 0.0000415 |
+| lerobot/umi_cup_in_the_wild | 6 | 31.635 | 0.249 | 0.0000415 |
+| lerobot/umi_cup_in_the_wild | 10 | 39.418 | 0.195 | 0.0000399 |
+| lerobot/umi_cup_in_the_wild | 15 | 44.577 | 0.169 | 0.0000394 |
+| lerobot/umi_cup_in_the_wild | 20 | 47.907 | 0.140 | 0.0000390 |
+| lerobot/umi_cup_in_the_wild | 40 | 52.554 | 0.096 | 0.0000384 |
+| lerobot/umi_cup_in_the_wild | 100 | 58.241 | 0.046 | 0.0000390 |
+| lerobot/umi_cup_in_the_wild | None | 60.530 | 0.022 | 0.0000400 |
+
+**`crf`**
+| repo_id | crf | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- | --- |
+| lerobot/pusht | 0 | 1.699 | 0.175 | 0.0000035 |
+| lerobot/pusht | 5 | 1.409 | 0.181 | 0.0000080 |
+| lerobot/pusht | 10 | 1.842 | 0.172 | 0.0000123 |
+| lerobot/pusht | 15 | 2.322 | 0.187 | 0.0000211 |
+| lerobot/pusht | 20 | 3.050 | 0.181 | 0.0000346 |
+| lerobot/pusht | None | 3.646 | 0.189 | 0.0000443 |
+| lerobot/pusht | 25 | 3.969 | 0.186 | 0.0000521 |
+| lerobot/pusht | 30 | 5.687 | 0.184 | 0.0000850 |
+| lerobot/pusht | 40 | 10.818 | 0.193 | 0.0001726 |
+| lerobot/pusht | 50 | 18.185 | 0.183 | 0.0002606 |
+| lerobot/umi_cup_in_the_wild | 0 | 1.918 | 0.165 | 0.0000056 |
+| lerobot/umi_cup_in_the_wild | 5 | 3.207 | 0.171 | 0.0000111 |
+| lerobot/umi_cup_in_the_wild | 10 | 4.818 | 0.212 | 0.0000153 |
+| lerobot/umi_cup_in_the_wild | 15 | 7.329 | 0.261 | 0.0000218 |
+| lerobot/umi_cup_in_the_wild | 20 | 11.361 | 0.312 | 0.0000317 |
+| lerobot/umi_cup_in_the_wild | None | 14.932 | 0.339 | 0.0000397 |
+| lerobot/umi_cup_in_the_wild | 25 | 17.741 | 0.297 | 0.0000452 |
+| lerobot/umi_cup_in_the_wild | 30 | 27.983 | 0.406 | 0.0000629 |
+| lerobot/umi_cup_in_the_wild | 40 | 82.449 | 0.468 | 0.0001184 |
+| lerobot/umi_cup_in_the_wild | 50 | 186.145 | 0.515 | 0.0001879 |
+
+**best**
+| repo_id | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- |
+| lerobot/pusht | 3.646 | 0.188 | 0.0000443 |
+| lerobot/umi_cup_in_the_wild | 14.932 | 0.339 | 0.0000397 |
+
+### `2_frames`
+
+**`pix_fmt`**
+| repo_id | pix_fmt | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- | --- |
+| lerobot/pusht | yuv420p | 3.788 | 0.314 | 0.0000799 |
+| lerobot/pusht | yuv444p | 3.646 | 0.303 | 0.0000496 |
+| lerobot/umi_cup_in_the_wild | yuv420p | 14.391 | 0.642 | 0.0000503 |
+| lerobot/umi_cup_in_the_wild | yuv444p | 14.932 | 0.529 | 0.0000436 |
+
+**`g`**
+| repo_id | g | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- | --- |
+| lerobot/pusht | 1 | 2.543 | 0.308 | 0.0000599 |
+| lerobot/pusht | 2 | 3.646 | 0.279 | 0.0000496 |
+| lerobot/pusht | 3 | 4.431 | 0.259 | 0.0000498 |
+| lerobot/pusht | 4 | 5.103 | 0.243 | 0.0000501 |
+| lerobot/pusht | 5 | 5.625 | 0.235 | 0.0000492 |
+| lerobot/pusht | 6 | 5.974 | 0.230 | 0.0000481 |
+| lerobot/pusht | 10 | 6.814 | 0.194 | 0.0000468 |
+| lerobot/pusht | 15 | 7.431 | 0.152 | 0.0000460 |
+| lerobot/pusht | 20 | 7.662 | 0.151 | 0.0000455 |
+| lerobot/pusht | 40 | 8.163 | 0.095 | 0.0000454 |
+| lerobot/pusht | 100 | 8.761 | 0.062 | 0.0000472 |
+| lerobot/pusht | None | 8.909 | 0.037 | 0.0000479 |
+| lerobot/umi_cup_in_the_wild | 1 | 14.411 | 0.638 | 0.0000625 |
+| lerobot/umi_cup_in_the_wild | 2 | 14.932 | 0.537 | 0.0000436 |
+| lerobot/umi_cup_in_the_wild | 3 | 20.174 | 0.493 | 0.0000437 |
+| lerobot/umi_cup_in_the_wild | 4 | 24.889 | 0.458 | 0.0000446 |
+| lerobot/umi_cup_in_the_wild | 5 | 28.825 | 0.438 | 0.0000445 |
+| lerobot/umi_cup_in_the_wild | 6 | 31.635 | 0.424 | 0.0000444 |
+| lerobot/umi_cup_in_the_wild | 10 | 39.418 | 0.345 | 0.0000435 |
+| lerobot/umi_cup_in_the_wild | 15 | 44.577 | 0.313 | 0.0000417 |
+| lerobot/umi_cup_in_the_wild | 20 | 47.907 | 0.264 | 0.0000421 |
+| lerobot/umi_cup_in_the_wild | 40 | 52.554 | 0.185 | 0.0000414 |
+| lerobot/umi_cup_in_the_wild | 100 | 58.241 | 0.090 | 0.0000420 |
+| lerobot/umi_cup_in_the_wild | None | 60.530 | 0.042 | 0.0000424 |
+
+**`crf`**
+| repo_id | crf | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- | --- |
+| lerobot/pusht | 0 | 1.699 | 0.302 | 0.0000097 |
+| lerobot/pusht | 5 | 1.409 | 0.287 | 0.0000142 |
+| lerobot/pusht | 10 | 1.842 | 0.283 | 0.0000184 |
+| lerobot/pusht | 15 | 2.322 | 0.305 | 0.0000268 |
+| lerobot/pusht | 20 | 3.050 | 0.285 | 0.0000402 |
+| lerobot/pusht | None | 3.646 | 0.285 | 0.0000496 |
+| lerobot/pusht | 25 | 3.969 | 0.293 | 0.0000572 |
+| lerobot/pusht | 30 | 5.687 | 0.293 | 0.0000893 |
+| lerobot/pusht | 40 | 10.818 | 0.319 | 0.0001762 |
+| lerobot/pusht | 50 | 18.185 | 0.304 | 0.0002626 |
+| lerobot/umi_cup_in_the_wild | 0 | 1.918 | 0.235 | 0.0000112 |
+| lerobot/umi_cup_in_the_wild | 5 | 3.207 | 0.261 | 0.0000166 |
+| lerobot/umi_cup_in_the_wild | 10 | 4.818 | 0.333 | 0.0000207 |
+| lerobot/umi_cup_in_the_wild | 15 | 7.329 | 0.406 | 0.0000267 |
+| lerobot/umi_cup_in_the_wild | 20 | 11.361 | 0.489 | 0.0000361 |
+| lerobot/umi_cup_in_the_wild | None | 14.932 | 0.537 | 0.0000436 |
+| lerobot/umi_cup_in_the_wild | 25 | 17.741 | 0.578 | 0.0000487 |
+| lerobot/umi_cup_in_the_wild | 30 | 27.983 | 0.453 | 0.0000655 |
+| lerobot/umi_cup_in_the_wild | 40 | 82.449 | 0.767 | 0.0001192 |
+| lerobot/umi_cup_in_the_wild | 50 | 186.145 | 0.816 | 0.0001881 |
+
+**best**
+| repo_id | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- |
+| lerobot/pusht | 3.646 | 0.283 | 0.0000496 |
+| lerobot/umi_cup_in_the_wild | 14.932 | 0.543 | 0.0000436 |
+
+### `2_frames_4_space`
+
+**`pix_fmt`**
+| repo_id | pix_fmt | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- | --- |
+| lerobot/pusht | yuv420p | 3.788 | 0.257 | 0.0000855 |
+| lerobot/pusht | yuv444p | 3.646 | 0.261 | 0.0000556 |
+| lerobot/umi_cup_in_the_wild | yuv420p | 14.391 | 0.493 | 0.0000476 |
+| lerobot/umi_cup_in_the_wild | yuv444p | 14.932 | 0.371 | 0.0000404 |
+
+**`g`**
+| repo_id | g | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- | --- |
+| lerobot/pusht | 1 | 2.543 | 0.226 | 0.0000670 |
+| lerobot/pusht | 2 | 3.646 | 0.222 | 0.0000556 |
+| lerobot/pusht | 3 | 4.431 | 0.217 | 0.0000567 |
+| lerobot/pusht | 4 | 5.103 | 0.204 | 0.0000555 |
+| lerobot/pusht | 5 | 5.625 | 0.179 | 0.0000556 |
+| lerobot/pusht | 6 | 5.974 | 0.188 | 0.0000544 |
+| lerobot/pusht | 10 | 6.814 | 0.160 | 0.0000531 |
+| lerobot/pusht | 15 | 7.431 | 0.150 | 0.0000521 |
+| lerobot/pusht | 20 | 7.662 | 0.123 | 0.0000519 |
+| lerobot/pusht | 40 | 8.163 | 0.092 | 0.0000519 |
+| lerobot/pusht | 100 | 8.761 | 0.053 | 0.0000533 |
+| lerobot/pusht | None | 8.909 | 0.034 | 0.0000541 |
+| lerobot/umi_cup_in_the_wild | 1 | 14.411 | 0.409 | 0.0000607 |
+| lerobot/umi_cup_in_the_wild | 2 | 14.932 | 0.381 | 0.0000404 |
+| lerobot/umi_cup_in_the_wild | 3 | 20.174 | 0.355 | 0.0000418 |
+| lerobot/umi_cup_in_the_wild | 4 | 24.889 | 0.346 | 0.0000425 |
+| lerobot/umi_cup_in_the_wild | 5 | 28.825 | 0.354 | 0.0000419 |
+| lerobot/umi_cup_in_the_wild | 6 | 31.635 | 0.336 | 0.0000419 |
+| lerobot/umi_cup_in_the_wild | 10 | 39.418 | 0.314 | 0.0000402 |
+| lerobot/umi_cup_in_the_wild | 15 | 44.577 | 0.269 | 0.0000397 |
+| lerobot/umi_cup_in_the_wild | 20 | 47.907 | 0.246 | 0.0000395 |
+| lerobot/umi_cup_in_the_wild | 40 | 52.554 | 0.171 | 0.0000390 |
+| lerobot/umi_cup_in_the_wild | 100 | 58.241 | 0.091 | 0.0000399 |
+| lerobot/umi_cup_in_the_wild | None | 60.530 | 0.043 | 0.0000409 |
+
+**`crf`**
+| repo_id | crf | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- | --- |
+| lerobot/pusht | 0 | 1.699 | 0.212 | 0.0000193 |
+| lerobot/pusht | 5 | 1.409 | 0.211 | 0.0000232 |
+| lerobot/pusht | 10 | 1.842 | 0.199 | 0.0000270 |
+| lerobot/pusht | 15 | 2.322 | 0.198 | 0.0000347 |
+| lerobot/pusht | 20 | 3.050 | 0.211 | 0.0000469 |
+| lerobot/pusht | None | 3.646 | 0.206 | 0.0000556 |
+| lerobot/pusht | 25 | 3.969 | 0.210 | 0.0000626 |
+| lerobot/pusht | 30 | 5.687 | 0.223 | 0.0000927 |
+| lerobot/pusht | 40 | 10.818 | 0.227 | 0.0001763 |
+| lerobot/pusht | 50 | 18.185 | 0.223 | 0.0002625 |
+| lerobot/umi_cup_in_the_wild | 0 | 1.918 | 0.147 | 0.0000071 |
+| lerobot/umi_cup_in_the_wild | 5 | 3.207 | 0.182 | 0.0000125 |
+| lerobot/umi_cup_in_the_wild | 10 | 4.818 | 0.222 | 0.0000166 |
+| lerobot/umi_cup_in_the_wild | 15 | 7.329 | 0.270 | 0.0000229 |
+| lerobot/umi_cup_in_the_wild | 20 | 11.361 | 0.325 | 0.0000326 |
+| lerobot/umi_cup_in_the_wild | None | 14.932 | 0.362 | 0.0000404 |
+| lerobot/umi_cup_in_the_wild | 25 | 17.741 | 0.390 | 0.0000459 |
+| lerobot/umi_cup_in_the_wild | 30 | 27.983 | 0.437 | 0.0000633 |
+| lerobot/umi_cup_in_the_wild | 40 | 82.449 | 0.499 | 0.0001186 |
+| lerobot/umi_cup_in_the_wild | 50 | 186.145 | 0.564 | 0.0001879 |
+
+**best**
+| repo_id | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- |
+| lerobot/pusht | 3.646 | 0.224 | 0.0000556 |
+| lerobot/umi_cup_in_the_wild | 14.932 | 0.368 | 0.0000404 |
+
+### `6_frames`
+
+**`pix_fmt`**
+| repo_id | pix_fmt | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- | --- |
+| lerobot/pusht | yuv420p | 3.788 | 0.660 | 0.0000839 |
+| lerobot/pusht | yuv444p | 3.646 | 0.546 | 0.0000542 |
+| lerobot/umi_cup_in_the_wild | yuv420p | 14.391 | 1.225 | 0.0000497 |
+| lerobot/umi_cup_in_the_wild | yuv444p | 14.932 | 0.908 | 0.0000428 |
+
+**`g`**
+| repo_id | g | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- | --- |
+| lerobot/pusht | 1 | 2.543 | 0.552 | 0.0000646 |
+| lerobot/pusht | 2 | 3.646 | 0.534 | 0.0000542 |
+| lerobot/pusht | 3 | 4.431 | 0.563 | 0.0000546 |
+| lerobot/pusht | 4 | 5.103 | 0.537 | 0.0000545 |
+| lerobot/pusht | 5 | 5.625 | 0.477 | 0.0000532 |
+| lerobot/pusht | 6 | 5.974 | 0.515 | 0.0000530 |
+| lerobot/pusht | 10 | 6.814 | 0.410 | 0.0000512 |
+| lerobot/pusht | 15 | 7.431 | 0.405 | 0.0000503 |
+| lerobot/pusht | 20 | 7.662 | 0.345 | 0.0000500 |
+| lerobot/pusht | 40 | 8.163 | 0.247 | 0.0000496 |
+| lerobot/pusht | 100 | 8.761 | 0.147 | 0.0000510 |
+| lerobot/pusht | None | 8.909 | 0.100 | 0.0000519 |
+| lerobot/umi_cup_in_the_wild | 1 | 14.411 | 0.997 | 0.0000620 |
+| lerobot/umi_cup_in_the_wild | 2 | 14.932 | 0.911 | 0.0000428 |
+| lerobot/umi_cup_in_the_wild | 3 | 20.174 | 0.869 | 0.0000433 |
+| lerobot/umi_cup_in_the_wild | 4 | 24.889 | 0.874 | 0.0000438 |
+| lerobot/umi_cup_in_the_wild | 5 | 28.825 | 0.864 | 0.0000439 |
+| lerobot/umi_cup_in_the_wild | 6 | 31.635 | 0.834 | 0.0000440 |
+| lerobot/umi_cup_in_the_wild | 10 | 39.418 | 0.781 | 0.0000421 |
+| lerobot/umi_cup_in_the_wild | 15 | 44.577 | 0.679 | 0.0000411 |
+| lerobot/umi_cup_in_the_wild | 20 | 47.907 | 0.652 | 0.0000410 |
+| lerobot/umi_cup_in_the_wild | 40 | 52.554 | 0.465 | 0.0000404 |
+| lerobot/umi_cup_in_the_wild | 100 | 58.241 | 0.245 | 0.0000413 |
+| lerobot/umi_cup_in_the_wild | None | 60.530 | 0.116 | 0.0000417 |
+
+**`crf`**
+| repo_id | crf | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- | --- |
+| lerobot/pusht | 0 | 1.699 | 0.534 | 0.0000163 |
+| lerobot/pusht | 5 | 1.409 | 0.524 | 0.0000205 |
+| lerobot/pusht | 10 | 1.842 | 0.510 | 0.0000245 |
+| lerobot/pusht | 15 | 2.322 | 0.512 | 0.0000324 |
+| lerobot/pusht | 20 | 3.050 | 0.508 | 0.0000452 |
+| lerobot/pusht | None | 3.646 | 0.518 | 0.0000542 |
+| lerobot/pusht | 25 | 3.969 | 0.534 | 0.0000616 |
+| lerobot/pusht | 30 | 5.687 | 0.530 | 0.0000927 |
+| lerobot/pusht | 40 | 10.818 | 0.552 | 0.0001777 |
+| lerobot/pusht | 50 | 18.185 | 0.564 | 0.0002644 |
+| lerobot/umi_cup_in_the_wild | 0 | 1.918 | 0.401 | 0.0000101 |
+| lerobot/umi_cup_in_the_wild | 5 | 3.207 | 0.499 | 0.0000156 |
+| lerobot/umi_cup_in_the_wild | 10 | 4.818 | 0.599 | 0.0000197 |
+| lerobot/umi_cup_in_the_wild | 15 | 7.329 | 0.704 | 0.0000258 |
+| lerobot/umi_cup_in_the_wild | 20 | 11.361 | 0.834 | 0.0000352 |
+| lerobot/umi_cup_in_the_wild | None | 14.932 | 0.925 | 0.0000428 |
+| lerobot/umi_cup_in_the_wild | 25 | 17.741 | 0.978 | 0.0000480 |
+| lerobot/umi_cup_in_the_wild | 30 | 27.983 | 1.088 | 0.0000648 |
+| lerobot/umi_cup_in_the_wild | 40 | 82.449 | 1.324 | 0.0001190 |
+| lerobot/umi_cup_in_the_wild | 50 | 186.145 | 1.436 | 0.0001880 |
+
+**best**
+| repo_id | compression_factor | load_time_factor | avg_per_pixel_l2_error |
+| --- | --- | --- | --- |
+| lerobot/pusht | 3.646 | 0.546 | 0.0000542 |
+| lerobot/umi_cup_in_the_wild | 14.932 | 0.934 | 0.0000428 |
diff --git a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py
new file mode 100644
index 0000000..b6e83a0
--- /dev/null
+++ b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py
@@ -0,0 +1,360 @@
+import json
+import os
+import random
+import shutil
+import subprocess
+import time
+from pathlib import Path
+
+import einops
+import numpy
+import PIL
+import torch
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.video_utils import (
+ decode_video_frames_torchvision,
+)
+
+
+def get_directory_size(directory):
+ total_size = 0
+ # Iterate over all files and subdirectories recursively
+ for item in directory.rglob("*"):
+ if item.is_file():
+ # Add the file size to the total
+ total_size += item.stat().st_size
+ return total_size
+
+
+def run_video_benchmark(
+ output_dir,
+ cfg,
+ timestamps_mode,
+ seed=1337,
+):
+ output_dir = Path(output_dir)
+ if output_dir.exists():
+ shutil.rmtree(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ repo_id = cfg["repo_id"]
+
+ # TODO(rcadene): rewrite with hardcoding of original images and episodes
+ dataset = LeRobotDataset(
+ repo_id,
+ root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
+ )
+ # Get fps
+ fps = dataset.fps
+
+ # we only load first episode
+ ep_num_images = dataset.episode_data_index["to"][0].item()
+
+ # Save/Load image directory for the first episode
+ imgs_dir = Path(f"tmp/data/images/{repo_id}/observation.image_episode_000000")
+ if not imgs_dir.exists():
+ imgs_dir.mkdir(parents=True, exist_ok=True)
+ hf_dataset = dataset.hf_dataset.with_format(None)
+ imgs_dataset = hf_dataset.select_columns("observation.image")
+
+ for i, item in enumerate(imgs_dataset):
+ img = item["observation.image"]
+ img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
+
+ if i >= ep_num_images - 1:
+ break
+
+ sum_original_frames_size_bytes = get_directory_size(imgs_dir)
+
+ # Encode images into video
+ video_path = output_dir / "episode_0.mp4"
+
+ g = cfg.get("g")
+ crf = cfg.get("crf")
+ pix_fmt = cfg["pix_fmt"]
+
+ cmd = f"ffmpeg -r {fps} "
+ cmd += "-f image2 "
+ cmd += "-loglevel error "
+ cmd += f"-i {str(imgs_dir / 'frame_%06d.png')} "
+ cmd += "-vcodec libx264 "
+ if g is not None:
+ cmd += f"-g {g} " # ensures at least 1 keyframe every 10 frames
+ # cmd += "-keyint_min 10 " set a minimum of 10 frames between 2 key frames
+ # cmd += "-sc_threshold 0 " disable scene change detection to lower the number of key frames
+ if crf is not None:
+ cmd += f"-crf {crf} "
+ cmd += f"-pix_fmt {pix_fmt} "
+ cmd += f"{str(video_path)}"
+ subprocess.run(cmd.split(" "), check=True)
+
+ video_size_bytes = video_path.stat().st_size
+
+ # Set decoder
+
+ decoder = cfg["decoder"]
+ decoder_kwgs = cfg["decoder_kwgs"]
+ device = cfg["device"]
+
+ if decoder == "torchvision":
+ decode_frames_fn = decode_video_frames_torchvision
+ else:
+ raise ValueError(decoder)
+
+ # Estimate average loading time
+
+ def load_original_frames(imgs_dir, timestamps):
+ frames = []
+ for ts in timestamps:
+ idx = int(ts * fps)
+ frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
+ frame = torch.from_numpy(numpy.array(frame))
+ frame = frame.type(torch.float32) / 255
+ frame = einops.rearrange(frame, "h w c -> c h w")
+ frames.append(frame)
+ return frames
+
+ list_avg_load_time = []
+ list_avg_load_time_from_images = []
+ per_pixel_l2_errors = []
+
+ random.seed(seed)
+
+ for t in range(50):
+ # test loading 2 frames that are 4 frames appart, which might be a common setting
+ ts = random.randint(fps, ep_num_images - fps) / fps
+
+ if timestamps_mode == "1_frame":
+ timestamps = [ts]
+ elif timestamps_mode == "2_frames":
+ timestamps = [ts - 1 / fps, ts]
+ elif timestamps_mode == "2_frames_4_space":
+ timestamps = [ts - 4 / fps, ts]
+ elif timestamps_mode == "6_frames":
+ timestamps = [ts - i / fps for i in range(6)][::-1]
+ else:
+ raise ValueError(timestamps_mode)
+
+ num_frames = len(timestamps)
+
+ start_time_s = time.monotonic()
+ frames = decode_frames_fn(
+ video_path, timestamps=timestamps, tolerance_s=1e-4, device=device, **decoder_kwgs
+ )
+ avg_load_time = (time.monotonic() - start_time_s) / num_frames
+ list_avg_load_time.append(avg_load_time)
+
+ start_time_s = time.monotonic()
+ original_frames = load_original_frames(imgs_dir, timestamps)
+ avg_load_time_from_images = (time.monotonic() - start_time_s) / num_frames
+ list_avg_load_time_from_images.append(avg_load_time_from_images)
+
+ # Estimate average L2 error between original frames and decoded frames
+ for i, ts in enumerate(timestamps):
+ # are_close = torch.allclose(frames[i], original_frames[i], atol=0.02)
+ num_pixels = original_frames[i].numel()
+ per_pixel_l2_error = torch.norm(frames[i] - original_frames[i], p=2).item() / num_pixels
+
+ # save decoded frames
+ if t == 0:
+ frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
+ PIL.Image.fromarray(frame_hwc).save(output_dir / f"frame_{i:06d}.png")
+
+ # save original_frames
+ idx = int(ts * fps)
+ if t == 0:
+ original_frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
+ original_frame.save(output_dir / f"original_frame_{i:06d}.png")
+
+ per_pixel_l2_errors.append(per_pixel_l2_error)
+
+ avg_load_time = float(numpy.array(list_avg_load_time).mean())
+ avg_load_time_from_images = float(numpy.array(list_avg_load_time_from_images).mean())
+ avg_per_pixel_l2_error = float(numpy.array(per_pixel_l2_errors).mean())
+
+ # Save benchmark info
+
+ info = {
+ "sum_original_frames_size_bytes": sum_original_frames_size_bytes,
+ "video_size_bytes": video_size_bytes,
+ "avg_load_time_from_images": avg_load_time_from_images,
+ "avg_load_time": avg_load_time,
+ "compression_factor": sum_original_frames_size_bytes / video_size_bytes,
+ "load_time_factor": avg_load_time_from_images / avg_load_time,
+ "avg_per_pixel_l2_error": avg_per_pixel_l2_error,
+ }
+
+ with open(output_dir / "info.json", "w") as f:
+ json.dump(info, f)
+
+ return info
+
+
+def display_markdown_table(headers, rows):
+ for i, row in enumerate(rows):
+ new_row = []
+ for col in row:
+ if col is None:
+ new_col = "None"
+ elif isinstance(col, float):
+ new_col = f"{col:.3f}"
+ if new_col == "0.000":
+ new_col = f"{col:.7f}"
+ elif isinstance(col, int):
+ new_col = f"{col}"
+ else:
+ new_col = col
+ new_row.append(new_col)
+ rows[i] = new_row
+
+ header_line = "| " + " | ".join(headers) + " |"
+ separator_line = "| " + " | ".join(["---" for _ in headers]) + " |"
+ body_lines = ["| " + " | ".join(row) + " |" for row in rows]
+ markdown_table = "\n".join([header_line, separator_line] + body_lines)
+ print(markdown_table)
+ print()
+
+
+def load_info(out_dir):
+ with open(out_dir / "info.json") as f:
+ info = json.load(f)
+ return info
+
+
+def main():
+ out_dir = Path("tmp/run_video_benchmark")
+ dry_run = False
+ repo_ids = ["lerobot/pusht", "lerobot/umi_cup_in_the_wild"]
+ timestamps_modes = [
+ "1_frame",
+ "2_frames",
+ "2_frames_4_space",
+ "6_frames",
+ ]
+ for timestamps_mode in timestamps_modes:
+ bench_dir = out_dir / timestamps_mode
+
+ print(f"### `{timestamps_mode}`")
+ print()
+
+ print("**`pix_fmt`**")
+ headers = ["repo_id", "pix_fmt", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"]
+ rows = []
+ for repo_id in repo_ids:
+ for pix_fmt in ["yuv420p", "yuv444p"]:
+ cfg = {
+ "repo_id": repo_id,
+ # video encoding
+ "g": 2,
+ "crf": None,
+ "pix_fmt": pix_fmt,
+ # video decoding
+ "device": "cpu",
+ "decoder": "torchvision",
+ "decoder_kwgs": {},
+ }
+ if not dry_run:
+ run_video_benchmark(bench_dir / repo_id / f"torchvision_{pix_fmt}", cfg, timestamps_mode)
+ info = load_info(bench_dir / repo_id / f"torchvision_{pix_fmt}")
+ rows.append(
+ [
+ repo_id,
+ pix_fmt,
+ info["compression_factor"],
+ info["load_time_factor"],
+ info["avg_per_pixel_l2_error"],
+ ]
+ )
+ display_markdown_table(headers, rows)
+
+ print("**`g`**")
+ headers = ["repo_id", "g", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"]
+ rows = []
+ for repo_id in repo_ids:
+ for g in [1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None]:
+ cfg = {
+ "repo_id": repo_id,
+ # video encoding
+ "g": g,
+ "pix_fmt": "yuv444p",
+ # video decoding
+ "device": "cpu",
+ "decoder": "torchvision",
+ "decoder_kwgs": {},
+ }
+ if not dry_run:
+ run_video_benchmark(bench_dir / repo_id / f"torchvision_g_{g}", cfg, timestamps_mode)
+ info = load_info(bench_dir / repo_id / f"torchvision_g_{g}")
+ rows.append(
+ [
+ repo_id,
+ g,
+ info["compression_factor"],
+ info["load_time_factor"],
+ info["avg_per_pixel_l2_error"],
+ ]
+ )
+ display_markdown_table(headers, rows)
+
+ print("**`crf`**")
+ headers = ["repo_id", "crf", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"]
+ rows = []
+ for repo_id in repo_ids:
+ for crf in [0, 5, 10, 15, 20, None, 25, 30, 40, 50]:
+ cfg = {
+ "repo_id": repo_id,
+ # video encoding
+ "g": 2,
+ "crf": crf,
+ "pix_fmt": "yuv444p",
+ # video decoding
+ "device": "cpu",
+ "decoder": "torchvision",
+ "decoder_kwgs": {},
+ }
+ if not dry_run:
+ run_video_benchmark(bench_dir / repo_id / f"torchvision_crf_{crf}", cfg, timestamps_mode)
+ info = load_info(bench_dir / repo_id / f"torchvision_crf_{crf}")
+ rows.append(
+ [
+ repo_id,
+ crf,
+ info["compression_factor"],
+ info["load_time_factor"],
+ info["avg_per_pixel_l2_error"],
+ ]
+ )
+ display_markdown_table(headers, rows)
+
+ print("**best**")
+ headers = ["repo_id", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"]
+ rows = []
+ for repo_id in repo_ids:
+ cfg = {
+ "repo_id": repo_id,
+ # video encoding
+ "g": 2,
+ "crf": None,
+ "pix_fmt": "yuv444p",
+ # video decoding
+ "device": "cpu",
+ "decoder": "torchvision",
+ "decoder_kwgs": {},
+ }
+ if not dry_run:
+ run_video_benchmark(bench_dir / repo_id / "torchvision_best", cfg, timestamps_mode)
+ info = load_info(bench_dir / repo_id / "torchvision_best")
+ rows.append(
+ [
+ repo_id,
+ info["compression_factor"],
+ info["load_time_factor"],
+ info["avg_per_pixel_l2_error"],
+ ]
+ )
+ display_markdown_table(headers, rows)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 6eedb23..186c3e4 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -9,14 +9,18 @@ from lerobot.common.datasets.utils import (
load_info,
load_previous_and_future_frames,
load_stats,
+ load_videos,
)
+from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
+
+CODEBASE_VERSION = "v1.2"
class LeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
repo_id: str,
- version: str | None = "v1.1",
+ version: str | None = CODEBASE_VERSION,
root: Path | None = None,
split: str = "train",
transform: callable = None,
@@ -30,18 +34,45 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
+ # TODO(rcadene, aliberts): implement faster transfer
+ # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
self.episode_data_index = load_episode_data_index(repo_id, version, root)
self.stats = load_stats(repo_id, version, root)
self.info = load_info(repo_id, version, root)
+ if self.video:
+ self.videos_dir = load_videos(repo_id, version, root)
@property
def fps(self) -> int:
return self.info["fps"]
+ @property
+ def video(self) -> bool:
+ """Returns True if this dataset loads video frames from mp4 files.
+ Returns False if it only loads images from png files.
+ """
+ return self.info.get("video", False)
+
+ @property
+ def features(self) -> datasets.Features:
+ return self.hf_dataset.features
+
@property
def image_keys(self) -> list[str]:
- return [key for key, feats in self.hf_dataset.features.items() if isinstance(feats, datasets.Image)]
+ image_keys = []
+ for key, feats in self.hf_dataset.features.items():
+ if isinstance(feats, datasets.Image):
+ image_keys.append(key)
+ return image_keys + self.video_frame_keys
+
+ @property
+ def video_frame_keys(self):
+ video_frame_keys = []
+ for key, feats in self.hf_dataset.features.items():
+ if isinstance(feats, VideoFrame):
+ video_frame_keys.append(key)
+ return video_frame_keys
@property
def num_samples(self) -> int:
@@ -51,6 +82,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_index"))
+ @property
+ def tolerance_s(self) -> float:
+ """Tolerance in seconds used to discard loaded frames when their timestamps
+ are not close enough from the requested frames. It is only used when `delta_timestamps`
+ is provided or when loading video frames from mp4 files.
+ """
+ # 1e-4 to account for possible numerical error
+ return 1 / self.fps - 1e-4
+
def __len__(self):
return self.num_samples
@@ -63,10 +103,49 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
- tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
+ self.tolerance_s,
+ )
+
+ if self.video:
+ item = load_from_videos(
+ item,
+ self.video_frame_keys,
+ self.videos_dir,
+ self.tolerance_s,
)
if self.transform is not None:
item = self.transform(item)
return item
+
+ @classmethod
+ def from_preloaded(
+ cls,
+ repo_id: str,
+ version: str | None = CODEBASE_VERSION,
+ root: Path | None = None,
+ split: str = "train",
+ transform: callable = None,
+ delta_timestamps: dict[list[float]] | None = None,
+ # additional preloaded attributes
+ hf_dataset=None,
+ episode_data_index=None,
+ stats=None,
+ info=None,
+ videos_dir=None,
+ ):
+ # create an empty object of type LeRobotDataset
+ obj = cls.__new__(cls)
+ obj.repo_id = repo_id
+ obj.version = version
+ obj.root = root
+ obj.split = split
+ obj.transform = transform
+ obj.delta_timestamps = delta_timestamps
+ obj.hf_dataset = hf_dataset
+ obj.episode_data_index = episode_data_index
+ obj.stats = stats
+ obj.info = info
+ obj.videos_dir = videos_dir
+ return obj
diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
index 70343ea..db9cd03 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
@@ -16,9 +16,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
-
-# TODO(rcadene): enable for PR video dataset
-# from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def check_format(raw_dir) -> bool:
@@ -79,14 +77,17 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
- video_path = out_dir / "videos" / f"{img_key}_episode_{ep_idx:06d}.mp4"
- encode_video_frames(tmp_imgs_dir, video_path, fps) # noqa: F821
+ fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
+ video_path = out_dir / "videos" / fname
+ encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
- # store the episode idx
- ep_dict[img_key] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+ # store the reference to the video frame
+ ep_dict[img_key] = [
+ {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+ ]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@@ -122,7 +123,7 @@ def to_hf_dataset(data_dict, video) -> Dataset:
image_keys = [key for key in data_dict if "observation.images." in key]
for image_key in image_keys:
if video:
- features[image_key] = Value(dtype="int64", id="video")
+ features[image_key] = VideoFrame()
else:
features[image_key] = Image()
diff --git a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py b/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py
new file mode 100644
index 0000000..a7a952f
--- /dev/null
+++ b/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py
@@ -0,0 +1,146 @@
+from copy import deepcopy
+from math import ceil
+
+import datasets
+import einops
+import torch
+import tqdm
+from datasets import Image
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.video_utils import VideoFrame
+
+
+def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
+ """These einops patterns will be used to aggregate batches and compute statistics.
+
+ Note: We assume the images are in channel first format
+ """
+
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ num_workers=num_workers,
+ batch_size=2,
+ shuffle=False,
+ )
+ batch = next(iter(dataloader))
+
+ stats_patterns = {}
+ for key, feats_type in dataset.features.items():
+ # sanity check that tensors are not float64
+ assert batch[key].dtype != torch.float64
+
+ if isinstance(feats_type, (VideoFrame, Image)):
+ # sanity check that images are channel first
+ _, c, h, w = batch[key].shape
+ assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
+
+ # sanity check that images are float32 in range [0,1]
+ assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
+ assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
+ assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
+
+ stats_patterns[key] = "b c h w -> c 1 1"
+ elif batch[key].ndim == 2:
+ stats_patterns[key] = "b c -> c "
+ elif batch[key].ndim == 1:
+ stats_patterns[key] = "b -> 1"
+ else:
+ raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
+
+ return stats_patterns
+
+
+def compute_stats(
+ dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
+):
+ if max_num_samples is None:
+ max_num_samples = len(dataset)
+
+ # for more info on why we need to set the same number of workers, see `load_from_videos`
+ stats_patterns = get_stats_einops_patterns(dataset, num_workers)
+
+ # mean and std will be computed incrementally while max and min will track the running value.
+ mean, std, max, min = {}, {}, {}, {}
+ for key in stats_patterns:
+ mean[key] = torch.tensor(0.0).float()
+ std[key] = torch.tensor(0.0).float()
+ max[key] = torch.tensor(-float("inf")).float()
+ min[key] = torch.tensor(float("inf")).float()
+
+ def create_seeded_dataloader(dataset, batch_size, seed):
+ generator = torch.Generator()
+ generator.manual_seed(seed)
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ num_workers=num_workers,
+ batch_size=batch_size,
+ shuffle=True,
+ drop_last=False,
+ generator=generator,
+ )
+ return dataloader
+
+ # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
+ # surprises when rerunning the sampler.
+ first_batch = None
+ running_item_count = 0 # for online mean computation
+ dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
+ for i, batch in enumerate(
+ tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
+ ):
+ this_batch_size = len(batch["index"])
+ running_item_count += this_batch_size
+ if first_batch is None:
+ first_batch = deepcopy(batch)
+ for key, pattern in stats_patterns.items():
+ batch[key] = batch[key].float()
+ # Numerically stable update step for mean computation.
+ batch_mean = einops.reduce(batch[key], pattern, "mean")
+ # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
+ # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
+ # and x is the current batch mean. Some rearrangement is then required to avoid risking
+ # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
+ # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
+ mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
+ max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
+ min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
+
+ if i == ceil(max_num_samples / batch_size) - 1:
+ break
+
+ first_batch_ = None
+ running_item_count = 0 # for online std computation
+ dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
+ for i, batch in enumerate(
+ tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
+ ):
+ this_batch_size = len(batch["index"])
+ running_item_count += this_batch_size
+ # Sanity check to make sure the batches are still in the same order as before.
+ if first_batch_ is None:
+ first_batch_ = deepcopy(batch)
+ for key in stats_patterns:
+ assert torch.equal(first_batch_[key], first_batch[key])
+ for key, pattern in stats_patterns.items():
+ batch[key] = batch[key].float()
+ # Numerically stable update step for mean computation (where the mean is over squared
+ # residuals).See notes in the mean computation loop above.
+ batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
+ std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
+
+ if i == ceil(max_num_samples / batch_size) - 1:
+ break
+
+ for key in stats_patterns:
+ std[key] = torch.sqrt(std[key])
+
+ stats = {}
+ for key in stats_patterns:
+ stats[key] = {
+ "mean": mean[key],
+ "std": std[key],
+ "max": max[key],
+ "min": min[key],
+ }
+ return stats
diff --git a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
index 4c6eeb5..0c3a8d1 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
@@ -14,9 +14,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
-
-# TODO(rcadene): enable for PR video dataset
-# from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def check_format(raw_dir):
@@ -127,26 +125,28 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
ep_dict = {}
imgs_array = [x.numpy() for x in image]
+ img_key = "observation.image"
if video:
# save png images in temporary directory
tmp_imgs_dir = out_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
- video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
- encode_video_frames(tmp_imgs_dir, video_path, fps) # noqa: F821
+ fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
+ video_path = out_dir / "videos" / fname
+ encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
- # store the episode index
- ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+ # store the reference to the video frame
+ ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
else:
- ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]
+ ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = agent_pos
ep_dict["action"] = actions[id_from:id_to]
- ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+ ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
# ep_dict["next.observation.image"] = image[1:],
@@ -174,7 +174,7 @@ def to_hf_dataset(data_dict, video):
features = {}
if video:
- features["observation.image"] = Value(dtype="int64", id="video")
+ features["observation.image"] = VideoFrame()
else:
features["observation.image"] = Image()
diff --git a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
index fd9100c..0082875 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
@@ -16,9 +16,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
-
-# TODO(rcadene): enable for PR video dataset
-# from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def check_format(raw_dir) -> bool:
@@ -103,25 +101,27 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
# load 57MB of images in RAM (400x224x224x3 uint8)
imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to]
+ img_key = "observation.image"
if video:
# save png images in temporary directory
tmp_imgs_dir = out_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
- video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
- encode_video_frames(tmp_imgs_dir, video_path, fps) # noqa: F821
+ fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
+ video_path = out_dir / "videos" / fname
+ encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
- # store the episode index
- ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+ # store the reference to the video frame
+ ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
else:
- ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]
+ ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = state
- ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+ ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames)
@@ -151,7 +151,7 @@ def to_hf_dataset(data_dict, video):
features = {}
if video:
- features["observation.image"] = Value(dtype="int64", id="video")
+ features["observation.image"] = VideoFrame()
else:
features["observation.image"] = Image()
diff --git a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
index ba16f57..686edf4 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
@@ -14,9 +14,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
-
-# TODO(rcadene): enable for PR video dataset
-# from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def check_format(raw_dir):
@@ -76,26 +74,28 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
ep_dict = {}
imgs_array = [x.numpy() for x in image]
+ img_key = "observation.image"
if video:
# save png images in temporary directory
tmp_imgs_dir = out_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
- video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
- encode_video_frames(tmp_imgs_dir, video_path, fps) # noqa: F821
+ fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
+ video_path = out_dir / "videos" / fname
+ encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
- # store the episode index
- ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+ # store the reference to the video frame
+ ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
else:
- ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]
+ ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = state
ep_dict["action"] = action
- ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+ ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
# ep_dict["next.observation.image"] = next_image
@@ -122,7 +122,7 @@ def to_hf_dataset(data_dict, video):
features = {}
if video:
- features["observation.image"] = Value(dtype="int64", id="video")
+ features["observation.image"] = VideoFrame()
else:
features["observation.image"] = Image()
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index ea03537..96b8fbb 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -1,14 +1,10 @@
import json
-from copy import deepcopy
-from math import ceil
from pathlib import Path
import datasets
-import einops
import torch
-import tqdm
-from datasets import Image, load_dataset, load_from_disk
-from huggingface_hub import hf_hub_download
+from datasets import load_dataset, load_from_disk
+from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
@@ -57,6 +53,9 @@ def hf_transform_to_torch(items_dict):
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+ elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
+ # video frame will be processed downstream
+ pass
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
return items_dict
@@ -127,17 +126,29 @@ def load_info(repo_id, version, root) -> dict:
return info
+def load_videos(repo_id, version, root) -> Path:
+ if root is not None:
+ path = Path(root) / repo_id / "videos"
+ else:
+ # TODO(rcadene): we download the whole repo here. see if we can avoid this
+ repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=version)
+ path = Path(repo_dir) / "videos"
+
+ return path
+
+
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
delta_timestamps: dict[str, list[float]],
- tol: float,
+ tolerance_s: float,
) -> dict[torch.Tensor]:
"""
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
- given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
+ given modality (e.g. "observation.image") a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest
+ frames in the dataset.
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
@@ -156,7 +167,7 @@ def load_previous_and_future_frames(
They indicate the start index and end index of each episode in the dataset.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps.
- - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
+ - tolerance_s (float, optional): The tolerance level (in seconds) used to determine if a data point is close enough to the query
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
smallest expected inter-frame period, but large enough to account for jitter.
@@ -194,11 +205,11 @@ def load_previous_and_future_frames(
# TODO(rcadene): synchronize timestamps + interpolation if needed
- is_pad = min_ > tol
+ is_pad = min_ > tolerance_s
# check violated query timestamps are all outside the episode range
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
- f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
+ f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
"This might be due to synchronization issues with timestamps during data collection."
)
@@ -207,144 +218,18 @@ def load_previous_and_future_frames(
# load frames modality
item[key] = hf_dataset.select_columns(key)[data_ids][key]
- item[key] = torch.stack(item[key])
+
+ if isinstance(item[key][0], dict) and "path" in item[key][0]:
+ # video mode where frame are expressed as dict of path and timestamp
+ item[key] = item[key]
+ else:
+ item[key] = torch.stack(item[key])
+
item[f"{key}_is_pad"] = is_pad
return item
-def get_stats_einops_patterns(hf_dataset):
- """These einops patterns will be used to aggregate batches and compute statistics.
-
- Note: We assume the images of `hf_dataset` are in channel first format
- """
-
- dataloader = torch.utils.data.DataLoader(
- hf_dataset,
- num_workers=0,
- batch_size=2,
- shuffle=False,
- )
- batch = next(iter(dataloader))
-
- stats_patterns = {}
- for key, feats_type in hf_dataset.features.items():
- # sanity check that tensors are not float64
- assert batch[key].dtype != torch.float64
-
- if isinstance(feats_type, Image):
- # sanity check that images are channel first
- _, c, h, w = batch[key].shape
- assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
-
- # sanity check that images are float32 in range [0,1]
- assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
- assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
- assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
-
- stats_patterns[key] = "b c h w -> c 1 1"
- elif batch[key].ndim == 2:
- stats_patterns[key] = "b c -> c "
- elif batch[key].ndim == 1:
- stats_patterns[key] = "b -> 1"
- else:
- raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
-
- return stats_patterns
-
-
-def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
- if max_num_samples is None:
- max_num_samples = len(hf_dataset)
-
- stats_patterns = get_stats_einops_patterns(hf_dataset)
-
- # mean and std will be computed incrementally while max and min will track the running value.
- mean, std, max, min = {}, {}, {}, {}
- for key in stats_patterns:
- mean[key] = torch.tensor(0.0).float()
- std[key] = torch.tensor(0.0).float()
- max[key] = torch.tensor(-float("inf")).float()
- min[key] = torch.tensor(float("inf")).float()
-
- def create_seeded_dataloader(hf_dataset, batch_size, seed):
- generator = torch.Generator()
- generator.manual_seed(seed)
- dataloader = torch.utils.data.DataLoader(
- hf_dataset,
- num_workers=4,
- batch_size=batch_size,
- shuffle=True,
- drop_last=False,
- generator=generator,
- )
- return dataloader
-
- # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
- # surprises when rerunning the sampler.
- first_batch = None
- running_item_count = 0 # for online mean computation
- dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
- for i, batch in enumerate(
- tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
- ):
- this_batch_size = len(batch["index"])
- running_item_count += this_batch_size
- if first_batch is None:
- first_batch = deepcopy(batch)
- for key, pattern in stats_patterns.items():
- batch[key] = batch[key].float()
- # Numerically stable update step for mean computation.
- batch_mean = einops.reduce(batch[key], pattern, "mean")
- # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
- # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
- # and x is the current batch mean. Some rearrangement is then required to avoid risking
- # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
- # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
- mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
- max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
- min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
-
- if i == ceil(max_num_samples / batch_size) - 1:
- break
-
- first_batch_ = None
- running_item_count = 0 # for online std computation
- dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
- for i, batch in enumerate(
- tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
- ):
- this_batch_size = len(batch["index"])
- running_item_count += this_batch_size
- # Sanity check to make sure the batches are still in the same order as before.
- if first_batch_ is None:
- first_batch_ = deepcopy(batch)
- for key in stats_patterns:
- assert torch.equal(first_batch_[key], first_batch[key])
- for key, pattern in stats_patterns.items():
- batch[key] = batch[key].float()
- # Numerically stable update step for mean computation (where the mean is over squared
- # residuals).See notes in the mean computation loop above.
- batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
- std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
-
- if i == ceil(max_num_samples / batch_size) - 1:
- break
-
- for key in stats_patterns:
- std[key] = torch.sqrt(std[key])
-
- stats = {}
- for key in stats_patterns:
- stats[key] = {
- "mean": mean[key],
- "std": std[key],
- "max": max[key],
- "min": min[key],
- }
- return stats
-
-
def cycle(iterable):
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py
new file mode 100644
index 0000000..0252be2
--- /dev/null
+++ b/lerobot/common/datasets/video_utils.py
@@ -0,0 +1,187 @@
+import logging
+import subprocess
+import warnings
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, ClassVar
+
+import pyarrow as pa
+import torch
+import torchvision
+from datasets.features.features import register_feature
+
+
+def load_from_videos(
+ item: dict[str, torch.Tensor], video_frame_keys: list[str], videos_dir: Path, tolerance_s: float
+):
+ """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
+ in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
+ This probably happens because a memory reference to the video loader is created in the main process and a
+ subprocess fails to access it.
+ """
+ # since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
+ data_dir = videos_dir.parent
+
+ for key in video_frame_keys:
+ if isinstance(item[key], list):
+ # load multiple frames at once (expected when delta_timestamps is not None)
+ timestamps = [frame["timestamp"] for frame in item[key]]
+ paths = [frame["path"] for frame in item[key]]
+ if len(set(paths)) > 1:
+ raise NotImplementedError("All video paths are expected to be the same for now.")
+ video_path = data_dir / paths[0]
+
+ frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
+ item[key] = frames
+ else:
+ # load one frame
+ timestamps = [item[key]["timestamp"]]
+ video_path = data_dir / item[key]["path"]
+
+ frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
+ item[key] = frames[0]
+
+ return item
+
+
+def decode_video_frames_torchvision(
+ video_path: str,
+ timestamps: list[float],
+ tolerance_s: float,
+ device: str = "cpu",
+ log_loaded_timestamps: bool = False,
+):
+ """Loads frames associated to the requested timestamps of a video
+
+ Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
+ the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
+ that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
+ and all subsequent frames until reaching the requested frame. The number of key frames in a video
+ can be adjusted during encoding to take into account decoding time and video size in bytes.
+ """
+ video_path = str(video_path)
+
+ # set backend
+ keyframes_only = False
+ if device == "cpu":
+ # explicitely use pyav
+ torchvision.set_video_backend("pyav")
+ keyframes_only = True # pyav doesnt support accuracte seek
+ elif device == "cuda":
+ # TODO(rcadene, aliberts): implement video decoding with GPU
+ # torchvision.set_video_backend("cuda")
+ # torchvision.set_video_backend("video_reader")
+ # requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
+ # check possible bug: https://github.com/pytorch/vision/issues/7745
+ raise NotImplementedError(
+ "Video decoding on gpu with cuda is currently not supported. Use `device='cpu'`."
+ )
+ else:
+ raise ValueError(device)
+
+ # set a video stream reader
+ # TODO(rcadene): also load audio stream at the same time
+ reader = torchvision.io.VideoReader(video_path, "video")
+
+ # set the first and last requested timestamps
+ # Note: previous timestamps are usually loaded, since we need to access the previous key frame
+ first_ts = timestamps[0]
+ last_ts = timestamps[-1]
+
+ # access closest key frame of the first requested frame
+ # Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video)
+ # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
+ reader.seek(first_ts, keyframes_only=keyframes_only)
+
+ # load all frames until last requested frame
+ loaded_frames = []
+ loaded_ts = []
+ for frame in reader:
+ current_ts = frame["pts"]
+ if log_loaded_timestamps:
+ logging.info(f"frame loaded at timestamp={current_ts:.4f}")
+ loaded_frames.append(frame["data"])
+ loaded_ts.append(current_ts)
+ if current_ts >= last_ts:
+ break
+
+ reader.container.close()
+ reader = None
+
+ query_ts = torch.tensor(timestamps)
+ loaded_ts = torch.tensor(loaded_ts)
+
+ # compute distances between each query timestamp and timestamps of all loaded frames
+ dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
+ min_, argmin_ = dist.min(1)
+
+ is_within_tol = min_ < tolerance_s
+ assert is_within_tol.all(), (
+ f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
+ "It means that the closest frame that can be loaded from the video is too far away in time."
+ "This might be due to synchronization issues with timestamps during data collection."
+ "To be safe, we advise to ignore this item during training."
+ )
+
+ # get closest frames to the query timestamps
+ closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
+ closest_ts = loaded_ts[argmin_]
+
+ if log_loaded_timestamps:
+ logging.info(f"{closest_ts=}")
+
+ # convert to the pytorch format which is float32 in [0,1] range (and channel first)
+ closest_frames = closest_frames.type(torch.float32) / 255
+
+ assert len(timestamps) == len(closest_frames)
+ return closest_frames
+
+
+def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int):
+ """More info on ffmpeg arguments tuning on `lerobot/common/datasets/_video_benchmark/README.md`"""
+ video_path = Path(video_path)
+ video_path.parent.mkdir(parents=True, exist_ok=True)
+
+ ffmpeg_cmd = (
+ f"ffmpeg -r {fps} "
+ "-f image2 "
+ "-loglevel error "
+ f"-i {str(imgs_dir / 'frame_%06d.png')} "
+ "-vcodec libx264 "
+ "-g 2 "
+ "-pix_fmt yuv444p "
+ f"{str(video_path)}"
+ )
+ subprocess.run(ffmpeg_cmd.split(" "), check=True)
+
+
+@dataclass
+class VideoFrame:
+ # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
+ """
+ Provides a type for a dataset containing video frames.
+
+ Example:
+
+ ```python
+ data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
+ features = {"image": VideoFrame()}
+ Dataset.from_dict(data_dict, features=Features(features))
+ ```
+ """
+
+ pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
+ _type: str = field(default="VideoFrame", init=False, repr=False)
+
+ def __call__(self):
+ return self.pa_type
+
+
+with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "ignore",
+ "'register_feature' is experimental and might be subject to breaking changes in the future.",
+ category=UserWarning,
+ )
+ # to make VideoFrame available in HuggingFace `datasets`
+ register_feature(VideoFrame, "VideoFrame")
diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 4c27fe7..bf1d51a 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -1,3 +1,6 @@
+# TODO(rcadene, alexander-soare): clean this file
+"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
+
import logging
import os
from pathlib import Path
diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py
index 830b7a0..f5c9c74 100644
--- a/lerobot/scripts/push_dataset_to_hub.py
+++ b/lerobot/scripts/push_dataset_to_hub.py
@@ -1,5 +1,5 @@
"""
-Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
+Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
installation of neural net specific packages like pytorch, tensorflow, jax.
@@ -60,8 +60,10 @@ import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
-from lerobot.common.datasets.utils import compute_stats, flatten_dict
+from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
+from lerobot.common.datasets.utils import flatten_dict
def get_from_raw_to_lerobot_format_fn(raw_format):
@@ -97,24 +99,34 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir):
save_file(episode_data_index, ep_data_idx_path)
-def push_meta_data_to_hub(meta_data_dir, repo_id, revision):
+def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
+ """Expect all meta data files to be all stored in a single "meta_data" directory.
+ On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
+ """
api = HfApi()
+ api.upload_folder(
+ folder_path=meta_data_dir,
+ path_in_repo="meta_data",
+ repo_id=repo_id,
+ revision=revision,
+ repo_type="dataset",
+ allow_patterns=["*.json, *.safetensors"],
+ )
- def upload(filename, revision):
- api.upload_file(
- path_or_fileobj=meta_data_dir / filename,
- path_in_repo=f"meta_data/{filename}",
- repo_id=repo_id,
- revision=revision,
- repo_type="dataset",
- )
- upload("info.json", "main")
- upload("info.json", revision)
- upload("stats.safetensors", "main")
- upload("stats.safetensors", revision)
- upload("episode_data_index.safetensors", "main")
- upload("episode_data_index.safetensors", revision)
+def push_videos_to_hub(repo_id, videos_dir, revision):
+ """Expect mp4 files to be all stored in a single "videos" directory.
+ On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
+ """
+ api = HfApi()
+ api.upload_folder(
+ folder_path=videos_dir,
+ path_in_repo="videos",
+ repo_id=repo_id,
+ revision=revision,
+ repo_type="dataset",
+ allow_patterns="*.mp4",
+ )
def push_dataset_to_hub(
@@ -129,16 +141,21 @@ def push_dataset_to_hub(
save_tests_to_disk: bool,
fps: int | None,
video: bool,
+ batch_size: int,
+ num_workers: int,
debug: bool,
):
+ repo_id = f"{community_id}/{dataset_id}"
+
raw_dir = data_dir / f"{dataset_id}_raw"
- out_dir = data_dir / community_id / dataset_id
+ out_dir = data_dir / repo_id
meta_data_dir = out_dir / "meta_data"
videos_dir = out_dir / "videos"
- tests_out_dir = tests_data_dir / community_id / dataset_id
+ tests_out_dir = tests_data_dir / repo_id
tests_meta_data_dir = tests_out_dir / "meta_data"
+ tests_videos_dir = tests_out_dir / "videos"
if out_dir.exists():
shutil.rmtree(out_dir)
@@ -159,7 +176,15 @@ def push_dataset_to_hub(
# convert dataset from original raw format to LeRobot format
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
- stats = compute_stats(hf_dataset)
+ lerobot_dataset = LeRobotDataset.from_preloaded(
+ repo_id=repo_id,
+ version=revision,
+ hf_dataset=hf_dataset,
+ episode_data_index=episode_data_index,
+ info=info,
+ videos_dir=videos_dir,
+ )
+ stats = compute_stats(lerobot_dataset, batch_size, num_workers)
if save_to_disk:
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
@@ -170,12 +195,15 @@ def push_dataset_to_hub(
save_meta_data(info, stats, episode_data_index, meta_data_dir)
if not dry_run:
- repo_id = f"{community_id}/{dataset_id}"
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
- push_meta_data_to_hub(repo_id, meta_data_dir)
+
+ push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
+ push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
+
if video:
- push_meta_data_to_hub(repo_id, videos_dir)
+ push_videos_to_hub(repo_id, videos_dir, revision="main")
+ push_videos_to_hub(repo_id, videos_dir, revision=revision)
if save_tests_to_disk:
# get the first episode
@@ -186,10 +214,15 @@ def push_dataset_to_hub(
test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
# copy meta data to tests directory
- if Path(tests_meta_data_dir).exists():
- shutil.rmtree(tests_meta_data_dir)
shutil.copytree(meta_data_dir, tests_meta_data_dir)
+ # copy videos of first episode to tests directory
+ episode_index = 0
+ tests_videos_dir.mkdir(parents=True, exist_ok=True)
+ for key in lerobot_dataset.video_frame_keys:
+ fname = f"{key}_episode_{episode_index:06d}.mp4"
+ shutil.copy(videos_dir / fname, tests_videos_dir / fname)
+
def main():
parser = argparse.ArgumentParser()
@@ -255,10 +288,21 @@ def main():
parser.add_argument(
"--video",
type=int,
- # TODO(rcadene): enable when video PR merges
- default=0,
+ default=1,
help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
)
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=32,
+ help="Batch size loaded by DataLoader for computing the dataset statistics.",
+ )
+ parser.add_argument(
+ "--num-workers",
+ type=int,
+ default=16,
+ help="Number of processes of Dataloader for computing the dataset statistics.",
+ )
parser.add_argument(
"--debug",
type=int,
diff --git a/poetry.lock b/poetry.lock
index 89d35a5..cb7cd6d 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -595,6 +595,17 @@ files = [
{file = "debugpy-1.8.1.zip", hash = "sha256:f696d6be15be87aef621917585f9bb94b1dc9e8aced570db1b8a6fc14e8f9b42"},
]
+[[package]]
+name = "decorator"
+version = "4.4.2"
+description = "Decorators for Humans"
+optional = false
+python-versions = ">=2.6, !=3.0.*, !=3.1.*"
+files = [
+ {file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"},
+ {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"},
+]
+
[[package]]
name = "diffusers"
version = "0.27.2"
@@ -1840,6 +1851,30 @@ files = [
intel-openmp = "==2021.*"
tbb = "==2021.*"
+[[package]]
+name = "moviepy"
+version = "1.0.3"
+description = "Video editing with Python"
+optional = false
+python-versions = "*"
+files = [
+ {file = "moviepy-1.0.3.tar.gz", hash = "sha256:2884e35d1788077db3ff89e763c5ba7bfddbd7ae9108c9bc809e7ba58fa433f5"},
+]
+
+[package.dependencies]
+decorator = ">=4.0.2,<5.0"
+imageio = {version = ">=2.5,<3.0", markers = "python_version >= \"3.4\""}
+imageio_ffmpeg = {version = ">=0.2.0", markers = "python_version >= \"3.4\""}
+numpy = {version = ">=1.17.3", markers = "python_version > \"2.7\""}
+proglog = "<=1.0.0"
+requests = ">=2.8.1,<3.0"
+tqdm = ">=4.11.2,<5.0"
+
+[package.extras]
+doc = ["Sphinx (>=1.5.2,<2.0)", "numpydoc (>=0.6.0,<1.0)", "pygame (>=1.9.3,<2.0)", "sphinx_rtd_theme (>=0.1.10b0,<1.0)"]
+optional = ["matplotlib (>=2.0.0,<3.0)", "opencv-python (>=3.0,<4.0)", "scikit-image (>=0.13.0,<1.0)", "scikit-learn", "scipy (>=0.19.0,<1.5)", "youtube_dl"]
+test = ["coverage (<5.0)", "coveralls (>=1.1,<2.0)", "pytest (>=3.0.0,<4.0)", "pytest-cov (>=2.5.1,<3.0)", "requests (>=2.8.1,<3.0)"]
+
[[package]]
name = "mpmath"
version = "1.3.0"
@@ -2594,6 +2629,20 @@ nodeenv = ">=0.11.1"
pyyaml = ">=5.1"
virtualenv = ">=20.10.0"
+[[package]]
+name = "proglog"
+version = "0.1.10"
+description = "Log and progress bar manager for console, notebooks, web..."
+optional = false
+python-versions = "*"
+files = [
+ {file = "proglog-0.1.10-py3-none-any.whl", hash = "sha256:19d5da037e8c813da480b741e3fa71fb1ac0a5b02bf21c41577c7f327485ec50"},
+ {file = "proglog-0.1.10.tar.gz", hash = "sha256:658c28c9c82e4caeb2f25f488fff9ceace22f8d69b15d0c1c86d64275e4ddab4"},
+]
+
+[package.dependencies]
+tqdm = "*"
+
[[package]]
name = "protobuf"
version = "4.25.3"
@@ -2701,6 +2750,44 @@ files = [
{file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"},
]
+[[package]]
+name = "pyav"
+version = "12.0.5"
+description = "Pythonic bindings for FFmpeg's libraries."
+optional = false
+python-versions = ">=3.9"
+files = [
+ {file = "pyav-12.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f19129d01d6be826ccf9b16151b0f52d954c8a797bd0fe3b84664f42c55070e2"},
+ {file = "pyav-12.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c4d6bf60a86cd73d7b195e7e3b6a386771f64524db72604242acc50beeaa7b62"},
+ {file = "pyav-12.0.5-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fc4521f2f8f48e0d30d5a83d898a7059bad49cbcc51cff299df00d554c6cbf26"},
+ {file = "pyav-12.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67eacfa977ac669ee3c9952955bce57ad3e93c3c24a686986b7c80e748fcfdd4"},
+ {file = "pyav-12.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:2a8503ba2464fb2a0a23bdb0ac1743942063f7cf2eb55b5d2477567b33acfc3d"},
+ {file = "pyav-12.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ac20eb76aeec143d571615c2dcd831976a68fc198b9d53b878b26be175a6499b"},
+ {file = "pyav-12.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2110c813aa9b0f2cac979367d69f95cfe94fc1bcef28e2c58cee56bf7f26de34"},
+ {file = "pyav-12.0.5-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426807ce868b7e56effd7f6bb5092a9101e92ecfbadc3849691faf0bab32c21"},
+ {file = "pyav-12.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5bb08a9f2efe5673bf4c1cf8a809062490de7babafd50c0d5b78894d6c288054"},
+ {file = "pyav-12.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:684edd212f876061e191361f92c7120d6bf43ba3f312f5b56acf3afc8d8333f6"},
+ {file = "pyav-12.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:795b3624c8eab6bb8d530d88afcdba744cbb5f8f89d36d3da0265dc388772bde"},
+ {file = "pyav-12.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7f083314a92352ceb13b736a71504dea05534aab912ea5f341c4382482395eb3"},
+ {file = "pyav-12.0.5-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f832618f9bd2f219cec5683939ae76c474ef993b682a67815d8ffb0b377fc17"},
+ {file = "pyav-12.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f315cc0d0f87b53ae6de71df29fbae3cd4bfa995029129000ff9d66886e3bcbe"},
+ {file = "pyav-12.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:c8be9e573183a02e88c09ee9fcee8463c3b79625ff905ae96e05f1a282fe4b13"},
+ {file = "pyav-12.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c3d11e789115704a0a14805f3cb1d9459b9ab03efeb24bb28b8ee1b25a52ce6d"},
+ {file = "pyav-12.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:820bf8ebc82960fd2ae8c1cf1a6d09f6a84abd492d38c4580c37fed082130a22"},
+ {file = "pyav-12.0.5-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eed90bc92f3e9d92ef0119e0e424fd1c58db8b186128e9b9cd9ed0da0360bf13"},
+ {file = "pyav-12.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4f8b5fa78779acea93c986ab8afaaae6a71e3995dceff87d8a969c3a2b8c55c"},
+ {file = "pyav-12.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d8a73d93e3d0377591b08dae057ba8e87211b4a05e6a59a9c90b51b801ce64ea"},
+ {file = "pyav-12.0.5-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8ad7bc5215b15f9da4990d74b4bf4d4dbf93cd61caf42e8b06d84fa1c960e864"},
+ {file = "pyav-12.0.5-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4ca5db3bc68f572f0fe5d316183725270edefa61ddb4032ebda5cd7751e09020"},
+ {file = "pyav-12.0.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1d86d38b90e13250f62a258b90d6641957dab9bc069cbd4929bc7d3d017ec7"},
+ {file = "pyav-12.0.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ccf267724fe1472add37968ff3768e4e5629c125c1c79af957b366fbad3d2e59"},
+ {file = "pyav-12.0.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f7519a05b19123e074e67248ed0f5672df752852cc43505f721ec2db9f80813c"},
+ {file = "pyav-12.0.5-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1ce141031338974567bc1e0504a5355449c61756626a07e3a43ded37a71afe39"},
+ {file = "pyav-12.0.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02f77d361ef728483ffe9430391ee554257c5c0872da8a2276275636226b3a85"},
+ {file = "pyav-12.0.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:647ebc369b1c7bfbdae626048e4d59265c3ab3ceb2e571ac83ddbbeaa70abb22"},
+ {file = "pyav-12.0.5.tar.gz", hash = "sha256:fc65bcb72f3f8040c47a5b5a8025b535c71dcb16f1c8f9ff9bb3bf3af17ac09a"},
+]
+
[[package]]
name = "pycparser"
version = "2.22"
@@ -3782,38 +3869,6 @@ typing-extensions = ">=4.8.0"
opt-einsum = ["opt-einsum (>=3.3)"]
optree = ["optree (>=0.9.1)"]
-[[package]]
-name = "torchaudio"
-version = "2.3.0"
-description = "An audio package for PyTorch"
-optional = false
-python-versions = "*"
-files = [
- {file = "torchaudio-2.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:342108da83aa19a457c9a128b1206fadb603753b51cca022b9f585aac2f4754c"},
- {file = "torchaudio-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:73fedb2c631e01fa10feaac308540b836aefe758e55ca3ee026335e5d01e8e30"},
- {file = "torchaudio-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:e5bb50b7a4874ed97086c9e516dd90b103d954edcb5ed4b36f4fc22c4000a5a7"},
- {file = "torchaudio-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:b4cc9cef5c98ed37e9405c4e0b0e6413bc101f3f49d45dc4f1d4e927757fe41e"},
- {file = "torchaudio-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:341ca3048ce6edcc731519b30187f0b13acb245c4efe16f925f69f9d533546e1"},
- {file = "torchaudio-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:8f2e0a28740bb0ee66369f92c811f33c0a47e6fcfc2de9cee89746472d713906"},
- {file = "torchaudio-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:61edb02ae9c0efea4399f9c1f899601136b24f35d430548284ea8eaf6ccbe3be"},
- {file = "torchaudio-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:04bc960cf1aef3b469b095a432a25496bc28197850fc2d90b7b52d6b5255487b"},
- {file = "torchaudio-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:535144a2fbba95fbb3b883224ffcf44788e4cecbabbe49c4a1ae3e7a74f71485"},
- {file = "torchaudio-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:fb3f52ed1d63b272c240d9bf051705312cb172212051b8a6a2f64d42e3cc1633"},
- {file = "torchaudio-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:668a8b694e5522cff28cd5e02d01aa1b75ce940aa9fb40480892bdc623b1735d"},
- {file = "torchaudio-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:6c1f538018b85d7766835d042e555de2f096f7a69bba6b16031bf42a914dd9e1"},
- {file = "torchaudio-2.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7ba93265455dc363385e98c0cfcaeb586b7401af8a2c824811ee1466134a4f30"},
- {file = "torchaudio-2.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:21bb6d1b384fc8895133f01489133d575d4a715cd81734b89651fb0264bd8b80"},
- {file = "torchaudio-2.3.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:ed1866f508dc689c4f682d330b2ed4c83108d35865e4fb89431819364d8ad9ed"},
- {file = "torchaudio-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:a3cbb230e2bb38ad1a1dd74aea242a154a9f76ab819d9c058b2c5074a9f5d7d2"},
- {file = "torchaudio-2.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f4b933776f20a36af5ddc57968fcb3da34dd03881db8d6760f3e1176803b9cf8"},
- {file = "torchaudio-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:c5e63cc2dbf179088b6cdfd21ecdbb943aa003c780075aa440162f231ee72db2"},
- {file = "torchaudio-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d243bb8a1ee263c2cdafb9feed1569c3742d8135731e8f7818de12f4e0c83e28"},
- {file = "torchaudio-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6cd6d45cf8a45c89953e35434d9a461feb418e51e760adafc606a903dcbb9bd5"},
-]
-
-[package.dependencies]
-torch = "2.3.0"
-
[[package]]
name = "torchvision"
version = "0.18.0"
@@ -4275,4 +4330,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "0f72eb92ac8817a46f0659b4d72647a6b76f6e4ba762d11b280f8a88e6cd4371"
+content-hash = "bcebde3fd603ba6867521297ab9c774286f8f9db87c1eb5a77fb6afb03fbd693"
diff --git a/pyproject.toml b/pyproject.toml
index 1072323..e5107c7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,7 +56,8 @@ pytest = {version = "^8.1.0", optional = true}
pytest-cov = {version = "^5.0.0", optional = true}
datasets = "^2.19.0"
imagecodecs = { version = "^2024.1.1", optional = true }
-torchaudio = "^2.3.0"
+pyav = "^12.0.5"
+moviepy = "^1.0.3"
[tool.poetry.extras]
diff --git a/tests/data/lerobot/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors
index 8ed156c..828c672 100644
Binary files a/tests/data/lerobot/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/episode_data_index.safetensors differ
diff --git a/tests/data/lerobot/aloha_sim_insertion_human/meta_data/info.json b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/info.json
index 02e62b6..279cf2c 100644
--- a/tests/data/lerobot/aloha_sim_insertion_human/meta_data/info.json
+++ b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/info.json
@@ -1,3 +1,4 @@
{
- "fps": 50
+ "fps": 50,
+ "video": 1
}
\ No newline at end of file
diff --git a/tests/data/lerobot/aloha_sim_insertion_human/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/stats.safetensors
index e956450..9b6a7c8 100644
Binary files a/tests/data/lerobot/aloha_sim_insertion_human/meta_data/stats.safetensors and b/tests/data/lerobot/aloha_sim_insertion_human/meta_data/stats.safetensors differ
diff --git a/tests/data/lerobot/aloha_sim_insertion_human/stats.pth b/tests/data/lerobot/aloha_sim_insertion_human/stats.pth
deleted file mode 100644
index a7b9248..0000000
Binary files a/tests/data/lerobot/aloha_sim_insertion_human/stats.pth and /dev/null differ
diff --git a/tests/data/lerobot/aloha_sim_insertion_human/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_insertion_human/train/data-00000-of-00001.arrow
index f6c89ff..d93d0e2 100644
Binary files a/tests/data/lerobot/aloha_sim_insertion_human/train/data-00000-of-00001.arrow and b/tests/data/lerobot/aloha_sim_insertion_human/train/data-00000-of-00001.arrow differ
diff --git a/tests/data/lerobot/aloha_sim_insertion_human/train/dataset_info.json b/tests/data/lerobot/aloha_sim_insertion_human/train/dataset_info.json
index 69f8083..c6f7b93 100644
--- a/tests/data/lerobot/aloha_sim_insertion_human/train/dataset_info.json
+++ b/tests/data/lerobot/aloha_sim_insertion_human/train/dataset_info.json
@@ -3,7 +3,7 @@
"description": "",
"features": {
"observation.images.top": {
- "_type": "Image"
+ "_type": "VideoFrame"
},
"observation.state": {
"feature": {
diff --git a/tests/data/lerobot/aloha_sim_insertion_human/train/state.json b/tests/data/lerobot/aloha_sim_insertion_human/train/state.json
index 153a412..6cd9158 100644
--- a/tests/data/lerobot/aloha_sim_insertion_human/train/state.json
+++ b/tests/data/lerobot/aloha_sim_insertion_human/train/state.json
@@ -4,10 +4,10 @@
"filename": "data-00000-of-00001.arrow"
}
],
- "_fingerprint": "22eeca7a3f4725ee",
+ "_fingerprint": "eb913a2b1a68aa74",
"_format_columns": null,
"_format_kwargs": {},
- "_format_type": "torch",
+ "_format_type": null,
"_output_all_columns": false,
"_split": null
}
\ No newline at end of file
diff --git a/tests/data/lerobot/aloha_sim_insertion_human/videos/observation.images.top_episode_000000.mp4 b/tests/data/lerobot/aloha_sim_insertion_human/videos/observation.images.top_episode_000000.mp4
new file mode 100644
index 0000000..56280d5
Binary files /dev/null and b/tests/data/lerobot/aloha_sim_insertion_human/videos/observation.images.top_episode_000000.mp4 differ
diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors
index 8685de7..1505d61 100644
Binary files a/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/episode_data_index.safetensors differ
diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/info.json b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/info.json
index 02e62b6..279cf2c 100644
--- a/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/info.json
+++ b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/info.json
@@ -1,3 +1,4 @@
{
- "fps": 50
+ "fps": 50,
+ "video": 1
}
\ No newline at end of file
diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/stats.safetensors
index 619ee90..6cce9ff 100644
Binary files a/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/stats.safetensors and b/tests/data/lerobot/aloha_sim_insertion_scripted/meta_data/stats.safetensors differ
diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/stats.pth b/tests/data/lerobot/aloha_sim_insertion_scripted/stats.pth
deleted file mode 100644
index 990d464..0000000
Binary files a/tests/data/lerobot/aloha_sim_insertion_scripted/stats.pth and /dev/null differ
diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow
index f5cdadc..65a231a 100644
Binary files a/tests/data/lerobot/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow and b/tests/data/lerobot/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow differ
diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/train/dataset_info.json b/tests/data/lerobot/aloha_sim_insertion_scripted/train/dataset_info.json
index 69f8083..c6f7b93 100644
--- a/tests/data/lerobot/aloha_sim_insertion_scripted/train/dataset_info.json
+++ b/tests/data/lerobot/aloha_sim_insertion_scripted/train/dataset_info.json
@@ -3,7 +3,7 @@
"description": "",
"features": {
"observation.images.top": {
- "_type": "Image"
+ "_type": "VideoFrame"
},
"observation.state": {
"feature": {
diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/train/state.json b/tests/data/lerobot/aloha_sim_insertion_scripted/train/state.json
index 716aca6..b96705c 100644
--- a/tests/data/lerobot/aloha_sim_insertion_scripted/train/state.json
+++ b/tests/data/lerobot/aloha_sim_insertion_scripted/train/state.json
@@ -4,10 +4,10 @@
"filename": "data-00000-of-00001.arrow"
}
],
- "_fingerprint": "97c28d4ad1536e4c",
+ "_fingerprint": "d20c2acf1e107266",
"_format_columns": null,
"_format_kwargs": {},
- "_format_type": "torch",
+ "_format_type": null,
"_output_all_columns": false,
"_split": null
}
\ No newline at end of file
diff --git a/tests/data/lerobot/aloha_sim_insertion_scripted/videos/observation.images.top_episode_000000.mp4 b/tests/data/lerobot/aloha_sim_insertion_scripted/videos/observation.images.top_episode_000000.mp4
new file mode 100644
index 0000000..f36a0c1
Binary files /dev/null and b/tests/data/lerobot/aloha_sim_insertion_scripted/videos/observation.images.top_episode_000000.mp4 differ
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors
index 8685de7..1505d61 100644
Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/episode_data_index.safetensors differ
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/info.json b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/info.json
index 02e62b6..279cf2c 100644
--- a/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/info.json
+++ b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/info.json
@@ -1,3 +1,4 @@
{
- "fps": 50
+ "fps": 50,
+ "video": 1
}
\ No newline at end of file
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/stats.safetensors
index 998e610..2fe6aff 100644
Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/stats.safetensors and b/tests/data/lerobot/aloha_sim_transfer_cube_human/meta_data/stats.safetensors differ
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/stats.pth b/tests/data/lerobot/aloha_sim_transfer_cube_human/stats.pth
deleted file mode 100644
index 1ae356e..0000000
Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_human/stats.pth and /dev/null differ
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow
index 1bb1f51..a9f60d3 100644
Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow and b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow differ
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/train/dataset_info.json b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/dataset_info.json
index 69f8083..c6f7b93 100644
--- a/tests/data/lerobot/aloha_sim_transfer_cube_human/train/dataset_info.json
+++ b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/dataset_info.json
@@ -3,7 +3,7 @@
"description": "",
"features": {
"observation.images.top": {
- "_type": "Image"
+ "_type": "VideoFrame"
},
"observation.state": {
"feature": {
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/train/state.json b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/state.json
index d9449a3..eb74ba8 100644
--- a/tests/data/lerobot/aloha_sim_transfer_cube_human/train/state.json
+++ b/tests/data/lerobot/aloha_sim_transfer_cube_human/train/state.json
@@ -4,10 +4,10 @@
"filename": "data-00000-of-00001.arrow"
}
],
- "_fingerprint": "cb9349b5c92951e8",
+ "_fingerprint": "243b01eb8a4b184e",
"_format_columns": null,
"_format_kwargs": {},
- "_format_type": "torch",
+ "_format_type": null,
"_output_all_columns": false,
"_split": null
}
\ No newline at end of file
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_human/videos/observation.images.top_episode_000000.mp4 b/tests/data/lerobot/aloha_sim_transfer_cube_human/videos/observation.images.top_episode_000000.mp4
new file mode 100644
index 0000000..12a1e5b
Binary files /dev/null and b/tests/data/lerobot/aloha_sim_transfer_cube_human/videos/observation.images.top_episode_000000.mp4 differ
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors
index 8685de7..1505d61 100644
Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/episode_data_index.safetensors differ
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/info.json b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/info.json
index 02e62b6..279cf2c 100644
--- a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/info.json
+++ b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/info.json
@@ -1,3 +1,4 @@
{
- "fps": 50
+ "fps": 50,
+ "video": 1
}
\ No newline at end of file
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors
index 91696d3..c2ab5b2 100644
Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors and b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/meta_data/stats.safetensors differ
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/stats.pth b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/stats.pth
deleted file mode 100644
index 71547f0..0000000
Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/stats.pth and /dev/null differ
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow
index d658a6d..405509d 100644
Binary files a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow and b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow differ
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/dataset_info.json b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/dataset_info.json
index 69f8083..c6f7b93 100644
--- a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/dataset_info.json
+++ b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/dataset_info.json
@@ -3,7 +3,7 @@
"description": "",
"features": {
"observation.images.top": {
- "_type": "Image"
+ "_type": "VideoFrame"
},
"observation.state": {
"feature": {
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/state.json b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/state.json
index 2d4dfc6..91c4651 100644
--- a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/state.json
+++ b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/train/state.json
@@ -4,10 +4,10 @@
"filename": "data-00000-of-00001.arrow"
}
],
- "_fingerprint": "e4d7ad2b360db1af",
+ "_fingerprint": "eb759bbf60df7be9",
"_format_columns": null,
"_format_kwargs": {},
- "_format_type": "torch",
+ "_format_type": null,
"_output_all_columns": false,
"_split": null
}
\ No newline at end of file
diff --git a/tests/data/lerobot/aloha_sim_transfer_cube_scripted/videos/observation.images.top_episode_000000.mp4 b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/videos/observation.images.top_episode_000000.mp4
new file mode 100644
index 0000000..2d25242
Binary files /dev/null and b/tests/data/lerobot/aloha_sim_transfer_cube_scripted/videos/observation.images.top_episode_000000.mp4 differ
diff --git a/tests/data/lerobot/pusht/meta_data/episode_data_index.safetensors b/tests/data/lerobot/pusht/meta_data/episode_data_index.safetensors
index 9343d2d..3511c26 100644
Binary files a/tests/data/lerobot/pusht/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/pusht/meta_data/episode_data_index.safetensors differ
diff --git a/tests/data/lerobot/pusht/meta_data/info.json b/tests/data/lerobot/pusht/meta_data/info.json
index 5c9a8ae..b7f3971 100644
--- a/tests/data/lerobot/pusht/meta_data/info.json
+++ b/tests/data/lerobot/pusht/meta_data/info.json
@@ -1,3 +1,4 @@
{
- "fps": 10
+ "fps": 10,
+ "video": 1
}
\ No newline at end of file
diff --git a/tests/data/lerobot/pusht/meta_data/stats.safetensors b/tests/data/lerobot/pusht/meta_data/stats.safetensors
index fa2380e..e4ebbef 100644
Binary files a/tests/data/lerobot/pusht/meta_data/stats.safetensors and b/tests/data/lerobot/pusht/meta_data/stats.safetensors differ
diff --git a/tests/data/lerobot/pusht/stats.pth b/tests/data/lerobot/pusht/stats.pth
deleted file mode 100644
index 636985f..0000000
Binary files a/tests/data/lerobot/pusht/stats.pth and /dev/null differ
diff --git a/tests/data/lerobot/pusht/train/data-00000-of-00001.arrow b/tests/data/lerobot/pusht/train/data-00000-of-00001.arrow
index 5972be9..b99aa29 100644
Binary files a/tests/data/lerobot/pusht/train/data-00000-of-00001.arrow and b/tests/data/lerobot/pusht/train/data-00000-of-00001.arrow differ
diff --git a/tests/data/lerobot/pusht/train/dataset_info.json b/tests/data/lerobot/pusht/train/dataset_info.json
index aefe478..a0db336 100644
--- a/tests/data/lerobot/pusht/train/dataset_info.json
+++ b/tests/data/lerobot/pusht/train/dataset_info.json
@@ -3,7 +3,7 @@
"description": "",
"features": {
"observation.image": {
- "_type": "Image"
+ "_type": "VideoFrame"
},
"observation.state": {
"feature": {
diff --git a/tests/data/lerobot/pusht/train/meta_data/episode_data_index.safetensors b/tests/data/lerobot/pusht/train/meta_data/episode_data_index.safetensors
deleted file mode 100644
index 3511c26..0000000
Binary files a/tests/data/lerobot/pusht/train/meta_data/episode_data_index.safetensors and /dev/null differ
diff --git a/tests/data/lerobot/pusht/train/meta_data/info.json b/tests/data/lerobot/pusht/train/meta_data/info.json
deleted file mode 100644
index 5c9a8ae..0000000
--- a/tests/data/lerobot/pusht/train/meta_data/info.json
+++ /dev/null
@@ -1,3 +0,0 @@
-{
- "fps": 10
-}
\ No newline at end of file
diff --git a/tests/data/lerobot/pusht/train/meta_data/stats_action.safetensors b/tests/data/lerobot/pusht/train/meta_data/stats_action.safetensors
deleted file mode 100644
index 2c2553b..0000000
Binary files a/tests/data/lerobot/pusht/train/meta_data/stats_action.safetensors and /dev/null differ
diff --git a/tests/data/lerobot/pusht/train/meta_data/stats_observation.image.safetensors b/tests/data/lerobot/pusht/train/meta_data/stats_observation.image.safetensors
deleted file mode 100644
index 0a145d4..0000000
Binary files a/tests/data/lerobot/pusht/train/meta_data/stats_observation.image.safetensors and /dev/null differ
diff --git a/tests/data/lerobot/pusht/train/meta_data/stats_observation.state.safetensors b/tests/data/lerobot/pusht/train/meta_data/stats_observation.state.safetensors
deleted file mode 100644
index 28ee285..0000000
Binary files a/tests/data/lerobot/pusht/train/meta_data/stats_observation.state.safetensors and /dev/null differ
diff --git a/tests/data/lerobot/pusht/train/state.json b/tests/data/lerobot/pusht/train/state.json
index dda3f88..776f29f 100644
--- a/tests/data/lerobot/pusht/train/state.json
+++ b/tests/data/lerobot/pusht/train/state.json
@@ -4,10 +4,10 @@
"filename": "data-00000-of-00001.arrow"
}
],
- "_fingerprint": "a04a9ce660122e23",
+ "_fingerprint": "3e02d7879f423c56",
"_format_columns": null,
"_format_kwargs": {},
- "_format_type": "torch",
+ "_format_type": null,
"_output_all_columns": false,
"_split": null
}
\ No newline at end of file
diff --git a/tests/data/lerobot/pusht/videos/observation.image_episode_000000.mp4 b/tests/data/lerobot/pusht/videos/observation.image_episode_000000.mp4
new file mode 100644
index 0000000..b2040bd
Binary files /dev/null and b/tests/data/lerobot/pusht/videos/observation.image_episode_000000.mp4 differ
diff --git a/tests/data/lerobot/umi_cup_in_the_wild/meta_data/episode_data_index.safetensors b/tests/data/lerobot/umi_cup_in_the_wild/meta_data/episode_data_index.safetensors
index 0ba6962..1505d61 100644
Binary files a/tests/data/lerobot/umi_cup_in_the_wild/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/umi_cup_in_the_wild/meta_data/episode_data_index.safetensors differ
diff --git a/tests/data/lerobot/umi_cup_in_the_wild/meta_data/info.json b/tests/data/lerobot/umi_cup_in_the_wild/meta_data/info.json
index 5c9a8ae..b7f3971 100644
--- a/tests/data/lerobot/umi_cup_in_the_wild/meta_data/info.json
+++ b/tests/data/lerobot/umi_cup_in_the_wild/meta_data/info.json
@@ -1,3 +1,4 @@
{
- "fps": 10
+ "fps": 10,
+ "video": 1
}
\ No newline at end of file
diff --git a/tests/data/lerobot/umi_cup_in_the_wild/meta_data/stats.safetensors b/tests/data/lerobot/umi_cup_in_the_wild/meta_data/stats.safetensors
index 0f964fa..d936f44 100644
Binary files a/tests/data/lerobot/umi_cup_in_the_wild/meta_data/stats.safetensors and b/tests/data/lerobot/umi_cup_in_the_wild/meta_data/stats.safetensors differ
diff --git a/tests/data/lerobot/umi_cup_in_the_wild/train/data-00000-of-00001.arrow b/tests/data/lerobot/umi_cup_in_the_wild/train/data-00000-of-00001.arrow
index 272e0fb..11f45a5 100644
Binary files a/tests/data/lerobot/umi_cup_in_the_wild/train/data-00000-of-00001.arrow and b/tests/data/lerobot/umi_cup_in_the_wild/train/data-00000-of-00001.arrow differ
diff --git a/tests/data/lerobot/umi_cup_in_the_wild/train/dataset_info.json b/tests/data/lerobot/umi_cup_in_the_wild/train/dataset_info.json
index c05585f..f590f3e 100644
--- a/tests/data/lerobot/umi_cup_in_the_wild/train/dataset_info.json
+++ b/tests/data/lerobot/umi_cup_in_the_wild/train/dataset_info.json
@@ -2,6 +2,9 @@
"citation": "",
"description": "",
"features": {
+ "observation.image": {
+ "_type": "VideoFrame"
+ },
"observation.state": {
"feature": {
"dtype": "float32",
@@ -57,9 +60,6 @@
"index": {
"dtype": "int64",
"_type": "Value"
- },
- "observation.image": {
- "_type": "Image"
}
},
"homepage": "",
diff --git a/tests/data/lerobot/umi_cup_in_the_wild/train/state.json b/tests/data/lerobot/umi_cup_in_the_wild/train/state.json
index f1f1b6e..80e610b 100644
--- a/tests/data/lerobot/umi_cup_in_the_wild/train/state.json
+++ b/tests/data/lerobot/umi_cup_in_the_wild/train/state.json
@@ -4,10 +4,10 @@
"filename": "data-00000-of-00001.arrow"
}
],
- "_fingerprint": "fd95ee932cb1fce2",
+ "_fingerprint": "c8b78ec1bbf7a579",
"_format_columns": null,
"_format_kwargs": {},
- "_format_type": "torch",
+ "_format_type": null,
"_output_all_columns": false,
"_split": null
}
\ No newline at end of file
diff --git a/tests/data/lerobot/umi_cup_in_the_wild/videos/observation.image_episode_000000.mp4 b/tests/data/lerobot/umi_cup_in_the_wild/videos/observation.image_episode_000000.mp4
new file mode 100644
index 0000000..3266cf7
Binary files /dev/null and b/tests/data/lerobot/umi_cup_in_the_wild/videos/observation.image_episode_000000.mp4 differ
diff --git a/tests/data/lerobot/xarm_lift_medium/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_lift_medium/meta_data/episode_data_index.safetensors
index 7216093..f5e09ec 100644
Binary files a/tests/data/lerobot/xarm_lift_medium/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/xarm_lift_medium/meta_data/episode_data_index.safetensors differ
diff --git a/tests/data/lerobot/xarm_lift_medium/meta_data/info.json b/tests/data/lerobot/xarm_lift_medium/meta_data/info.json
index f9d6b30..d73052c 100644
--- a/tests/data/lerobot/xarm_lift_medium/meta_data/info.json
+++ b/tests/data/lerobot/xarm_lift_medium/meta_data/info.json
@@ -1,3 +1,4 @@
{
- "fps": 15
+ "fps": 15,
+ "video": 1
}
\ No newline at end of file
diff --git a/tests/data/lerobot/xarm_lift_medium/meta_data/stats.safetensors b/tests/data/lerobot/xarm_lift_medium/meta_data/stats.safetensors
index bdcc1b0..712c625 100644
Binary files a/tests/data/lerobot/xarm_lift_medium/meta_data/stats.safetensors and b/tests/data/lerobot/xarm_lift_medium/meta_data/stats.safetensors differ
diff --git a/tests/data/lerobot/xarm_lift_medium/stats.pth b/tests/data/lerobot/xarm_lift_medium/stats.pth
deleted file mode 100644
index 3ab4e05..0000000
Binary files a/tests/data/lerobot/xarm_lift_medium/stats.pth and /dev/null differ
diff --git a/tests/data/lerobot/xarm_lift_medium/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_lift_medium/train/data-00000-of-00001.arrow
index d621210..9625a74 100644
Binary files a/tests/data/lerobot/xarm_lift_medium/train/data-00000-of-00001.arrow and b/tests/data/lerobot/xarm_lift_medium/train/data-00000-of-00001.arrow differ
diff --git a/tests/data/lerobot/xarm_lift_medium/train/dataset_info.json b/tests/data/lerobot/xarm_lift_medium/train/dataset_info.json
index 59a43bd..3791dee 100644
--- a/tests/data/lerobot/xarm_lift_medium/train/dataset_info.json
+++ b/tests/data/lerobot/xarm_lift_medium/train/dataset_info.json
@@ -3,7 +3,7 @@
"description": "",
"features": {
"observation.image": {
- "_type": "Image"
+ "_type": "VideoFrame"
},
"observation.state": {
"feature": {
diff --git a/tests/data/lerobot/xarm_lift_medium/train/state.json b/tests/data/lerobot/xarm_lift_medium/train/state.json
index 642fda3..3989b59 100644
--- a/tests/data/lerobot/xarm_lift_medium/train/state.json
+++ b/tests/data/lerobot/xarm_lift_medium/train/state.json
@@ -4,10 +4,10 @@
"filename": "data-00000-of-00001.arrow"
}
],
- "_fingerprint": "cc6afdfcdd6f63ab",
+ "_fingerprint": "720072274a55db4d",
"_format_columns": null,
"_format_kwargs": {},
- "_format_type": "torch",
+ "_format_type": null,
"_output_all_columns": false,
"_split": null
}
\ No newline at end of file
diff --git a/tests/data/lerobot/xarm_lift_medium/videos/observation.image_episode_000000.mp4 b/tests/data/lerobot/xarm_lift_medium/videos/observation.image_episode_000000.mp4
new file mode 100644
index 0000000..618b888
Binary files /dev/null and b/tests/data/lerobot/xarm_lift_medium/videos/observation.image_episode_000000.mp4 differ
diff --git a/tests/data/lerobot/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors
index 7216093..f5e09ec 100644
Binary files a/tests/data/lerobot/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/episode_data_index.safetensors differ
diff --git a/tests/data/lerobot/xarm_lift_medium_replay/meta_data/info.json b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/info.json
index f9d6b30..d73052c 100644
--- a/tests/data/lerobot/xarm_lift_medium_replay/meta_data/info.json
+++ b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/info.json
@@ -1,3 +1,4 @@
{
- "fps": 15
+ "fps": 15,
+ "video": 1
}
\ No newline at end of file
diff --git a/tests/data/lerobot/xarm_lift_medium_replay/meta_data/stats.safetensors b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/stats.safetensors
index 4808895..a7548ba 100644
Binary files a/tests/data/lerobot/xarm_lift_medium_replay/meta_data/stats.safetensors and b/tests/data/lerobot/xarm_lift_medium_replay/meta_data/stats.safetensors differ
diff --git a/tests/data/lerobot/xarm_lift_medium_replay/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_lift_medium_replay/train/data-00000-of-00001.arrow
index b524811..102a615 100644
Binary files a/tests/data/lerobot/xarm_lift_medium_replay/train/data-00000-of-00001.arrow and b/tests/data/lerobot/xarm_lift_medium_replay/train/data-00000-of-00001.arrow differ
diff --git a/tests/data/lerobot/xarm_lift_medium_replay/train/dataset_info.json b/tests/data/lerobot/xarm_lift_medium_replay/train/dataset_info.json
index 59a43bd..69bf84e 100644
--- a/tests/data/lerobot/xarm_lift_medium_replay/train/dataset_info.json
+++ b/tests/data/lerobot/xarm_lift_medium_replay/train/dataset_info.json
@@ -3,7 +3,7 @@
"description": "",
"features": {
"observation.image": {
- "_type": "Image"
+ "_type": "VideoFrame"
},
"observation.state": {
"feature": {
@@ -18,7 +18,7 @@
"dtype": "float32",
"_type": "Value"
},
- "length": 4,
+ "length": 3,
"_type": "Sequence"
},
"episode_index": {
diff --git a/tests/data/lerobot/xarm_lift_medium_replay/train/state.json b/tests/data/lerobot/xarm_lift_medium_replay/train/state.json
index e9b74d7..6522dcb 100644
--- a/tests/data/lerobot/xarm_lift_medium_replay/train/state.json
+++ b/tests/data/lerobot/xarm_lift_medium_replay/train/state.json
@@ -4,10 +4,10 @@
"filename": "data-00000-of-00001.arrow"
}
],
- "_fingerprint": "9f8e1a8c1845df55",
+ "_fingerprint": "9f3d8cbb0b2e74a2",
"_format_columns": null,
"_format_kwargs": {},
- "_format_type": "torch",
+ "_format_type": null,
"_output_all_columns": false,
"_split": null
}
\ No newline at end of file
diff --git a/tests/data/lerobot/xarm_lift_medium_replay/videos/observation.image_episode_000000.mp4 b/tests/data/lerobot/xarm_lift_medium_replay/videos/observation.image_episode_000000.mp4
new file mode 100644
index 0000000..f1089c1
Binary files /dev/null and b/tests/data/lerobot/xarm_lift_medium_replay/videos/observation.image_episode_000000.mp4 differ
diff --git a/tests/data/lerobot/xarm_push_medium/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_push_medium/meta_data/episode_data_index.safetensors
index 7216093..f5e09ec 100644
Binary files a/tests/data/lerobot/xarm_push_medium/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/xarm_push_medium/meta_data/episode_data_index.safetensors differ
diff --git a/tests/data/lerobot/xarm_push_medium/meta_data/info.json b/tests/data/lerobot/xarm_push_medium/meta_data/info.json
index f9d6b30..d73052c 100644
--- a/tests/data/lerobot/xarm_push_medium/meta_data/info.json
+++ b/tests/data/lerobot/xarm_push_medium/meta_data/info.json
@@ -1,3 +1,4 @@
{
- "fps": 15
+ "fps": 15,
+ "video": 1
}
\ No newline at end of file
diff --git a/tests/data/lerobot/xarm_push_medium/meta_data/stats.safetensors b/tests/data/lerobot/xarm_push_medium/meta_data/stats.safetensors
index f216e05..a7548ba 100644
Binary files a/tests/data/lerobot/xarm_push_medium/meta_data/stats.safetensors and b/tests/data/lerobot/xarm_push_medium/meta_data/stats.safetensors differ
diff --git a/tests/data/lerobot/xarm_push_medium/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_push_medium/train/data-00000-of-00001.arrow
index 241117c..102a615 100644
Binary files a/tests/data/lerobot/xarm_push_medium/train/data-00000-of-00001.arrow and b/tests/data/lerobot/xarm_push_medium/train/data-00000-of-00001.arrow differ
diff --git a/tests/data/lerobot/xarm_push_medium/train/dataset_info.json b/tests/data/lerobot/xarm_push_medium/train/dataset_info.json
index 9e47b34..69bf84e 100644
--- a/tests/data/lerobot/xarm_push_medium/train/dataset_info.json
+++ b/tests/data/lerobot/xarm_push_medium/train/dataset_info.json
@@ -3,7 +3,7 @@
"description": "",
"features": {
"observation.image": {
- "_type": "Image"
+ "_type": "VideoFrame"
},
"observation.state": {
"feature": {
diff --git a/tests/data/lerobot/xarm_push_medium/train/state.json b/tests/data/lerobot/xarm_push_medium/train/state.json
index 0ec1f04..6522dcb 100644
--- a/tests/data/lerobot/xarm_push_medium/train/state.json
+++ b/tests/data/lerobot/xarm_push_medium/train/state.json
@@ -4,10 +4,10 @@
"filename": "data-00000-of-00001.arrow"
}
],
- "_fingerprint": "c900258061dd0b3f",
+ "_fingerprint": "9f3d8cbb0b2e74a2",
"_format_columns": null,
"_format_kwargs": {},
- "_format_type": "torch",
+ "_format_type": null,
"_output_all_columns": false,
"_split": null
}
\ No newline at end of file
diff --git a/tests/data/lerobot/xarm_push_medium/videos/observation.image_episode_000000.mp4 b/tests/data/lerobot/xarm_push_medium/videos/observation.image_episode_000000.mp4
new file mode 100644
index 0000000..f1089c1
Binary files /dev/null and b/tests/data/lerobot/xarm_push_medium/videos/observation.image_episode_000000.mp4 differ
diff --git a/tests/data/lerobot/xarm_push_medium_replay/meta_data/episode_data_index.safetensors b/tests/data/lerobot/xarm_push_medium_replay/meta_data/episode_data_index.safetensors
index 7216093..f5e09ec 100644
Binary files a/tests/data/lerobot/xarm_push_medium_replay/meta_data/episode_data_index.safetensors and b/tests/data/lerobot/xarm_push_medium_replay/meta_data/episode_data_index.safetensors differ
diff --git a/tests/data/lerobot/xarm_push_medium_replay/meta_data/info.json b/tests/data/lerobot/xarm_push_medium_replay/meta_data/info.json
index f9d6b30..d73052c 100644
--- a/tests/data/lerobot/xarm_push_medium_replay/meta_data/info.json
+++ b/tests/data/lerobot/xarm_push_medium_replay/meta_data/info.json
@@ -1,3 +1,4 @@
{
- "fps": 15
+ "fps": 15,
+ "video": 1
}
\ No newline at end of file
diff --git a/tests/data/lerobot/xarm_push_medium_replay/meta_data/stats.safetensors b/tests/data/lerobot/xarm_push_medium_replay/meta_data/stats.safetensors
index 0de4755..a7548ba 100644
Binary files a/tests/data/lerobot/xarm_push_medium_replay/meta_data/stats.safetensors and b/tests/data/lerobot/xarm_push_medium_replay/meta_data/stats.safetensors differ
diff --git a/tests/data/lerobot/xarm_push_medium_replay/train/data-00000-of-00001.arrow b/tests/data/lerobot/xarm_push_medium_replay/train/data-00000-of-00001.arrow
index 2e07ea9..102a615 100644
Binary files a/tests/data/lerobot/xarm_push_medium_replay/train/data-00000-of-00001.arrow and b/tests/data/lerobot/xarm_push_medium_replay/train/data-00000-of-00001.arrow differ
diff --git a/tests/data/lerobot/xarm_push_medium_replay/train/dataset_info.json b/tests/data/lerobot/xarm_push_medium_replay/train/dataset_info.json
index 9e47b34..69bf84e 100644
--- a/tests/data/lerobot/xarm_push_medium_replay/train/dataset_info.json
+++ b/tests/data/lerobot/xarm_push_medium_replay/train/dataset_info.json
@@ -3,7 +3,7 @@
"description": "",
"features": {
"observation.image": {
- "_type": "Image"
+ "_type": "VideoFrame"
},
"observation.state": {
"feature": {
diff --git a/tests/data/lerobot/xarm_push_medium_replay/train/state.json b/tests/data/lerobot/xarm_push_medium_replay/train/state.json
index 39ffeaf..6522dcb 100644
--- a/tests/data/lerobot/xarm_push_medium_replay/train/state.json
+++ b/tests/data/lerobot/xarm_push_medium_replay/train/state.json
@@ -4,10 +4,10 @@
"filename": "data-00000-of-00001.arrow"
}
],
- "_fingerprint": "e51c80a33c7688c0",
+ "_fingerprint": "9f3d8cbb0b2e74a2",
"_format_columns": null,
"_format_kwargs": {},
- "_format_type": "torch",
+ "_format_type": null,
"_output_all_columns": false,
"_split": null
}
\ No newline at end of file
diff --git a/tests/data/lerobot/xarm_push_medium_replay/videos/observation.image_episode_000000.mp4 b/tests/data/lerobot/xarm_push_medium_replay/videos/observation.image_episode_000000.mp4
new file mode 100644
index 0000000..f1089c1
Binary files /dev/null and b/tests/data/lerobot/xarm_push_medium_replay/videos/observation.image_episode_000000.mp4 differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_0.safetensors
index 862de6b..2a89688 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_0.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_0.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_1.safetensors
index 56fa9b8..b144d76 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_1.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_1.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_250.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_250.safetensors
index 71497de..9c1ab2f 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_250.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_250.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_251.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_251.safetensors
index 3dd76d1..b631637 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_251.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_251.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_498.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_498.safetensors
index 9182284..f61e9d2 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_498.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_498.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_499.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_499.safetensors
index 3b5f440..80a3642 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_499.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/aloha_sim_insertion_human/frame_499.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_0.safetensors
index 1bb7c06..6a3f3e5 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_0.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_0.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_1.safetensors
index ae46012..69c8995 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_1.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_1.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_159.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_159.safetensors
index 2b5729d..cbf1f1f 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_159.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_159.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_160.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_160.safetensors
index a048c0c..2107611 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_160.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_160.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_80.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_80.safetensors
index e37d54c..94c55fe 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_80.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_80.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_81.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_81.safetensors
index 5cd8451..7f63f83 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_81.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/pusht/frame_81.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_0.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_0.safetensors
index 98f5562..d256267 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_0.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_0.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_1.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_1.safetensors
index c33aa6b..5d0e800 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_1.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_1.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_12.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_12.safetensors
index 2980af2..4c3be9f 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_12.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_12.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_13.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_13.safetensors
index 06d34dc..bc3f3a3 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_13.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_13.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_23.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_23.safetensors
index 39e878c..9683bb8 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_23.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_23.safetensors differ
diff --git a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_24.safetensors b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_24.safetensors
index 2822763..0777e9a 100644
Binary files a/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_24.safetensors and b/tests/data/save_dataset_to_safetensors/lerobot/xarm_lift_medium/frame_24.safetensors differ
diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py
index e3f5a9e..b4b0f76 100644
--- a/tests/scripts/save_dataset_to_safetensors.py
+++ b/tests/scripts/save_dataset_to_safetensors.py
@@ -8,7 +8,7 @@ If you know that your change will break backward compatibility, you should write
doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts.
Example usage:
- `python tests/script/save_dataset_to_safetensors.py`
+ `python tests/scripts/save_dataset_to_safetensors.py`
"""
import os
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 2f5d45a..e4be423 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -15,10 +15,12 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
)
-from lerobot.common.datasets.utils import (
+from lerobot.common.datasets.push_dataset_to_hub.compute_stats import (
compute_stats,
- flatten_dict,
get_stats_einops_patterns,
+)
+from lerobot.common.datasets.utils import (
+ flatten_dict,
hf_transform_to_torch,
load_previous_and_future_frames,
unflatten_dict,
@@ -105,15 +107,15 @@ def test_compute_stats_on_xarm():
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
- computed_stats = compute_stats(dataset.hf_dataset, batch_size=int(len(dataset) * 0.25))
+ computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0)
# get einops patterns to aggregate batches and compute statistics
- stats_patterns = get_stats_einops_patterns(dataset.hf_dataset)
+ stats_patterns = get_stats_einops_patterns(dataset)
# get all frames from the dataset in the same dtype and range as during compute_stats
dataloader = torch.utils.data.DataLoader(
dataset,
- num_workers=8,
+ num_workers=0,
batch_size=len(dataset),
shuffle=False,
)