Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com> Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com> Co-authored-by: Eugene Mironov <helper2424@gmail.com> Co-authored-by: imstevenpmwork <steven.palma@huggingface.co> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
75 lines
2.3 KiB
Python
75 lines
2.3 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
|
|
from lerobot.common.cameras import opencv # noqa: F401
|
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.common.policies.factory import make_policy
|
|
from lerobot.common.robots import ( # noqa: F401
|
|
RobotConfig,
|
|
make_robot_from_config,
|
|
so100_follower,
|
|
)
|
|
from lerobot.common.teleoperators import (
|
|
gamepad, # noqa: F401
|
|
so101_leader, # noqa: F401
|
|
)
|
|
from lerobot.configs import parser
|
|
from lerobot.configs.train import TrainRLServerPipelineConfig
|
|
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
|
|
|
# Configure the root logger at INFO level so the `logging.info` summaries
# emitted by this script are actually displayed.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
def eval_policy(env, policy, n_episodes):
    """Roll out ``policy`` in ``env`` for ``n_episodes`` episodes and log the returns.

    Every episode starts with ``env.reset()``; the current observation is then
    fed to ``policy.select_action`` and the resulting action to ``env.step``
    until the step reports ``terminated`` or ``truncated``. The rewards of an
    episode are summed into a single value.

    Args:
        env: Environment exposing ``reset() -> (obs, info)`` and
            ``step(action) -> (obs, reward, terminated, truncated, info)``.
        policy: Object exposing ``select_action(obs)``.
        n_episodes: Number of episodes to run.

    Returns:
        The list of per-episode reward sums, in the order the episodes were
        run. The list is empty when ``n_episodes`` is not positive.
    """
    episode_returns = []
    for _ in range(n_episodes):
        obs, _ = env.reset()
        episode_return = 0.0
        done = False
        while not done:
            action = policy.select_action(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            episode_return += reward
            done = terminated or truncated
        episode_returns.append(episode_return)

    if not episode_returns:
        # Without this guard the mean below would divide by zero.
        logging.warning("No evaluation episodes were run; nothing to report")
        return episode_returns

    # NOTE(review): the "20 steps" wording is not derived from anything in this
    # function; episodes end whenever the environment terminates or truncates.
    logging.info(f"Success after 20 steps {episode_returns}")
    logging.info(f"success rate {sum(episode_returns) / len(episode_returns)}")
    return episode_returns
|
|
|
|
|
|
@parser.wrap()
def main(cfg: TrainRLServerPipelineConfig):
    """Build the robot environment and policy from ``cfg`` and evaluate the policy.

    The environment is created from ``cfg.env``, the dataset named by
    ``cfg.dataset.repo_id`` is opened so its metadata can be passed to
    ``make_policy``, the checkpoint at
    ``cfg.env.pretrained_policy_name_or_path`` is loaded, and the policy is
    evaluated for 10 episodes with ``eval_policy``.

    Args:
        cfg: Pipeline configuration supplied by ``parser.wrap()``.
    """
    env_cfg = cfg.env
    env = make_robot_env(env_cfg)

    dataset = LeRobotDataset(repo_id=cfg.dataset.repo_id)

    policy = make_policy(
        cfg=cfg.policy,
        ds_meta=dataset.meta,
    )

    # NOTE(review): Hugging Face-style `from_pretrained` is normally a
    # classmethod that builds and returns a new model instead of mutating the
    # receiver, so discarding its result would leave `policy` without the
    # checkpoint weights. Keep the returned object when there is one; if the
    # method loads in place and returns None, the existing policy is kept.
    # Confirm against the policy base class.
    loaded_policy = policy.from_pretrained(env_cfg.pretrained_policy_name_or_path)
    if loaded_policy is not None:
        policy = loaded_policy
    policy.eval()

    eval_policy(env, policy=policy, n_episodes=10)
|
|
|
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|