Lerobot/lerobot/scripts/rl/eval_policy.py
Pepijn 438334d58e
Add sim tutorial, fix lekiwi motor config, add notebook links (#1275)
Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com>
Co-authored-by: Eugene Mironov <helper2424@gmail.com>
Co-authored-by: imstevenpmwork <steven.palma@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-06-13 18:48:39 +02:00

75 lines
2.3 KiB
Python

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from lerobot.common.cameras import opencv # noqa: F401
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.robots import ( # noqa: F401
RobotConfig,
make_robot_from_config,
so100_follower,
)
from lerobot.common.teleoperators import (
gamepad, # noqa: F401
so101_leader, # noqa: F401
)
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.scripts.rl.gym_manipulator import make_robot_env
# Configure the root logger so the episode summaries emitted below are visible.
logging.basicConfig(level=logging.INFO)
def eval_policy(env, policy, n_episodes):
    """Roll out ``policy`` in ``env`` and log the reward of each episode.

    Args:
        env: Gymnasium-style environment exposing ``reset() -> (obs, info)``
            and ``step(action) -> (obs, reward, terminated, truncated, info)``.
        policy: Object exposing ``select_action(obs) -> action``.
        n_episodes: Number of episodes to run.

    Returns:
        list: Accumulated reward for each episode, in order. (Previously the
        function returned ``None``; returning the rewards is backward-compatible
        and makes the result inspectable by callers.)
    """
    episode_rewards = []
    for _ in range(n_episodes):
        obs, _ = env.reset()
        episode_reward = 0.0
        while True:
            action = policy.select_action(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            episode_reward += reward
            # An episode ends when the env signals either termination or truncation.
            if terminated or truncated:
                break
        episode_rewards.append(episode_reward)
    # Lazy %-style args: the message is only formatted when INFO is enabled.
    logging.info("Per-episode rewards: %s", episode_rewards)
    # Guard against ZeroDivisionError when n_episodes == 0.
    if episode_rewards:
        logging.info("Mean episode reward: %s", sum(episode_rewards) / len(episode_rewards))
    return episode_rewards
@parser.wrap()
def main(cfg: TrainRLServerPipelineConfig):
    """Build the robot environment and policy from ``cfg``, then evaluate.

    ``parser.wrap()`` populates ``cfg`` from the command line. Runs a fixed
    10-episode evaluation via ``eval_policy``.
    """
    env_cfg = cfg.env
    env = make_robot_env(env_cfg)
    dataset_cfg = cfg.dataset
    # The dataset is loaded only to obtain its metadata (``dataset.meta``)
    # for constructing the policy; its samples are not iterated here.
    dataset = LeRobotDataset(repo_id=dataset_cfg.repo_id)
    dataset_meta = dataset.meta
    policy = make_policy(
        cfg=cfg.policy,
        # env_cfg=cfg.env,
        ds_meta=dataset_meta,
    )
    # NOTE(review): ``from_pretrained`` is conventionally a classmethod that
    # RETURNS a loaded policy; the return value is discarded here, so the
    # pretrained weights may never actually be applied to ``policy``.
    # Confirm against the policy API — likely should be
    # ``policy = policy.from_pretrained(...)``.
    policy.from_pretrained(env_cfg.pretrained_policy_name_or_path)
    policy.eval()
    eval_policy(env, policy=policy, n_episodes=10)
if __name__ == "__main__":
main()