RoboTwin_image/script/eval_policy.py
2025-07-02 03:13:07 +00:00

129 lines
3.3 KiB
Python

import sys
sys.path.append('./')
import torch
import sapien.core as sapien
import os
import numpy as np
from envs import *
import pathlib
import argparse
import yaml
from datetime import datetime
import importlib
def class_decorator(task_name):
    """Import ``envs.<task_name>`` and return a new instance of its task class.

    The module ``envs.<task_name>`` is expected to define an attribute named
    exactly ``task_name``; that attribute is called with no arguments and the
    resulting object is returned.

    Args:
        task_name: name of both the task module under ``envs`` and the
            attribute (class) inside it.

    Returns:
        The object produced by calling ``envs.<task_name>.<task_name>()``.

    Raises:
        SystemExit: with message "No Task" when the module has no attribute
            named ``task_name``.
        ImportError: propagated from ``importlib.import_module`` when
            ``envs.<task_name>`` cannot be imported.
        Exception: anything raised by the task constructor propagates
            unchanged rather than being reported as a missing task.
    """
    envs_module = importlib.import_module(f'envs.{task_name}')
    # Only a missing attribute means "no such task". Errors raised while
    # constructing the environment are real failures and must not be masked
    # as a missing task (nor should Ctrl-C be turned into "No Task").
    try:
        env_class = getattr(envs_module, task_name)
    except AttributeError as err:
        raise SystemExit("No Task") from err
    return env_class()
def load_model(model_path, map_location=None, weights_only=None):
    """Load a pickled model with ``torch.load`` and switch it to eval mode.

    Args:
        model_path: path (or file-like object) accepted by ``torch.load``.
        map_location: forwarded to ``torch.load``. For example ``'cpu'`` lets a
            checkpoint saved on a GPU load on a CPU-only machine. ``None`` (the
            default) keeps ``torch.load``'s own default behaviour.
        weights_only: forwarded to ``torch.load`` only when not ``None``, so
            older PyTorch versions without this keyword keep working. Newer
            PyTorch releases default it to ``True``, which refuses to unpickle
            a whole ``nn.Module``; pass ``False`` for such checkpoints.

    Returns:
        The loaded object, after ``.eval()`` has been called on it.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from trusted sources.
    load_kwargs = {'map_location': map_location}
    if weights_only is not None:
        load_kwargs['weights_only'] = weights_only
    model = torch.load(model_path, **load_kwargs)
    model.eval()
    return model
# Module-level record of the task currently being evaluated. It starts as
# None; main() assigns it from the task_name argument.
TASK = None
def main(args, policy=None):
    """Evaluate a policy on one task and write its success rate to disk.

    Args:
        args: task settings. This is either a mapping, used as-is, or an
            ``argparse.Namespace``, which is first converted to a dict copy.
            It must contain ``task_name``. ``test_policy`` additionally reads
            ``render_freq`` and ``**``-unpacks every entry into ``setup_demo``.
            NOTE(review): the CLI in ``__main__`` only supplies ``task_name``,
            so the remaining settings (at least ``render_freq``) still have to
            come from somewhere -- TODO load the full task config.
        policy: object handed to ``test_policy``. When omitted, the
            ``YOUR_POLICY()`` placeholder is constructed.

    Side effects:
        Sets the module-level ``TASK`` and writes ``result/<task>/result.txt``,
        overwriting any previous file.
    """
    global TASK
    # __main__ passes an argparse.Namespace, but the code below and
    # test_policy index args like a dict and **-unpack it. Copy so that
    # test_policy's temporary edits do not leak back into the Namespace.
    if isinstance(args, argparse.Namespace):
        args = dict(vars(args))
    TASK = args['task_name']
    print('Task name:', TASK)
    task = class_decorator(TASK)
    st_seed = 100000
    test_num = 100
    if policy is None:
        policy = YOUR_POLICY()  # TODO: init your policy
    # test_policy also returns the next unused seed, which is not needed here.
    _, suc_num = test_policy(task, args, policy, st_seed, test_num=test_num)
    suc_nums = [suc_num]
    save_dir = f'result/{TASK}'
    file_path = os.path.join(save_dir, 'result.txt')
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(f'Timestamp: {current_time}\n\n')
        file.write('Success Rate:\n')
        file.write('\n'.join(map(str, np.array(suc_nums) / test_num)))
        file.write('\n\n')
    print(f'Data has been saved to {file_path}')
def _expert_solves_seed(Demo_class, args, now_id, now_seed):
    """Return True if the expert rollout plans and passes check_success on now_seed.

    Rendering is disabled for this check by temporarily setting
    ``args['render_freq']`` to 0. The caller's value is restored before
    returning, whether or not the rollout raised.

    Any ``Exception`` from the rollout is reported and treated as failure, so
    the caller skips the seed. ``KeyboardInterrupt`` and ``SystemExit`` still
    propagate, so the evaluation loop remains interruptible.
    """
    render_freq = args['render_freq']
    args['render_freq'] = 0
    try:
        Demo_class.setup_demo(now_ep_num=now_id, seed=now_seed, **args)
        Demo_class.play_once()
        Demo_class.close()
    except Exception as err:
        Demo_class.close()
        print('error occurs !', repr(err))
        return False
    finally:
        args['render_freq'] = render_freq
    return bool(Demo_class.plan_success and Demo_class.check_success())


def test_policy(Demo_class, args, policy, st_seed, test_num=20, expert_check=True):
    """Run ``policy`` for ``test_num`` episodes, searching seeds from ``st_seed``.

    When ``expert_check`` is true (the default), each candidate seed is first
    rolled out with the expert (``play_once``). Seeds on which the expert
    errors, fails to plan, or fails ``check_success`` are skipped, so the
    policy is only run on seeds the expert can finish.

    Args:
        Demo_class: task environment instance (as built by ``class_decorator``).
            Its ``suc`` and ``test_num`` counters are reset to 0 here.
        args: mutable mapping of task settings. It must contain ``task_name``
            and ``render_freq`` and is ``**``-unpacked into ``setup_demo``.
        policy: object handed to ``Demo_class.apply_policy_demo``.
        st_seed: first seed to try.
        test_num: number of policy episodes to run.
        expert_check: whether to pre-filter seeds with the expert rollout.

    Returns:
        ``(next_seed, successes)``. ``next_seed`` is one past the last seed
        tried. ``successes`` is ``Demo_class.suc``, presumably incremented
        inside ``apply_policy_demo`` (not visible here).
    """
    print("Task name: ", args["task_name"])
    Demo_class.suc = 0
    Demo_class.test_num = 0
    now_id = 0
    succ_seed = 0
    now_seed = st_seed
    while succ_seed < test_num:
        # If even the expert can't finish the task on this seed, skip it.
        if expert_check and not _expert_solves_seed(Demo_class, args, now_id, now_seed):
            now_seed += 1
            continue
        succ_seed += 1
        Demo_class.setup_demo(now_ep_num=now_id, seed=now_seed, is_test=True, **args)
        # Close the env even if the policy rollout raises, so it is not leaked.
        try:
            Demo_class.apply_policy_demo(policy)
        finally:
            Demo_class.close()
            if Demo_class.render_freq:
                Demo_class.viewer.close()
        now_id += 1
        print(f"{args['task_name']} success rate: {Demo_class.suc}/{Demo_class.test_num}, current seed: {now_seed}\n")
        now_seed += 1
    return now_seed, Demo_class.suc


# Not a unit test despite the ``test_`` prefix; keep pytest from collecting it.
test_policy.__test__ = False
if __name__ == "__main__":
    # Command-line entry point: the task name is the only CLI argument, and
    # the parsed argparse.Namespace is handed straight to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'task_name',
        type=str,
        help="input task name",
    )
    cli_args = parser.parse_args()
    main(cli_args)