import sys sys.path.append('./') import torch import sapien.core as sapien import os import numpy as np from envs import * import pathlib import argparse import yaml from datetime import datetime import importlib def class_decorator(task_name): envs_module = importlib.import_module(f'envs.{task_name}') try: env_class = getattr(envs_module, task_name) env_instance = env_class() except: raise SystemExit("No Task") return env_instance def load_model(model_path): model = torch.load(model_path) model.eval() return model TASK = None def main(args): global TASK TASK = args.task_name print('Task name:', TASK) task = class_decorator(args['task_name']) st_seed = 100000 suc_nums = [] test_num = 100 policy = YOUR_POLICY() # TODO: init your policy st_seed, suc_num = test_policy(task, args, policy, st_seed, test_num=test_num) suc_nums.append(suc_num) save_dir = f'result/{TASK}' file_path = os.path.join(save_dir, f'result.txt') current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") if not os.path.exists(os.path.dirname(file_path)): os.makedirs(os.path.dirname(file_path)) with open(file_path, 'w') as file: file.write(f'Timestamp: {current_time}\n\n') file.write(f'Success Rate:\n') file.write(f'\n'.join(map(str, np.array(suc_nums) / test_num))) file.write('\n\n') print(f'Data has been saved to {file_path}') def test_policy(Demo_class, args, policy, st_seed, test_num=20): global TASK epid = 0 seed_list=[] suc_num = 0 expert_check = True print("Task name: ",args["task_name"]) Demo_class.suc = 0 Demo_class.test_num =0 now_id = 0 succ_seed = 0 suc_test_seed_list = [] now_seed = st_seed while succ_seed < test_num: render_freq = args['render_freq'] args['render_freq'] = 0 if expert_check: # if even expert can't finish the task, skip this seed. try: Demo_class.setup_demo(now_ep_num=now_id, seed = now_seed, ** args) Demo_class.play_once() Demo_class.close() except: Demo_class.close() now_seed += 1 args['render_freq'] = render_freq print('error occurs !') continue if (not expert_check) or ( Demo_class.plan_success and Demo_class.check_success() ): succ_seed +=1 suc_test_seed_list.append(now_seed) else: now_seed += 1 args['render_freq'] = render_freq continue args['render_freq'] = render_freq Demo_class.setup_demo(now_ep_num=now_id, seed = now_seed, is_test = True, ** args) Demo_class.apply_policy_demo(policy) now_id += 1 Demo_class.close() if Demo_class.render_freq: Demo_class.viewer.close() print(f"{TASK} success rate: {Demo_class.suc}/{Demo_class.test_num}, current seed: {now_seed}\n") now_seed += 1 return now_seed, Demo_class.suc if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('task_name', type=str, help="input task name") args = parser.parse_args() main(args)