129 lines
3.3 KiB
Python
129 lines
3.3 KiB
Python
import sys
|
|
sys.path.append('./')
|
|
|
|
import torch
|
|
import sapien.core as sapien
|
|
import os
|
|
import numpy as np
|
|
from envs import *
|
|
import pathlib
|
|
import argparse
|
|
|
|
import yaml
|
|
from datetime import datetime
|
|
import importlib
|
|
|
|
def class_decorator(task_name):
    """Instantiate the environment class registered under ``task_name``.

    Imports the module ``envs.<task_name>`` and instantiates the class of the
    same name defined in it.

    Args:
        task_name: Name of both the module inside the ``envs`` package and the
            class inside that module.

    Returns:
        A new instance of ``envs.<task_name>.<task_name>``.

    Raises:
        SystemExit: with message "No Task" when the module or the class does
            not exist.
        Exception: errors raised while importing the task module's own
            dependencies, or by the class constructor, are propagated unchanged
            instead of being misreported as a missing task.
    """
    module_name = f'envs.{task_name}'
    try:
        envs_module = importlib.import_module(module_name)
    except ModuleNotFoundError as err:
        # Only the task module itself being absent means "no such task"; a
        # missing package imported *by* the task module is a real error.
        if err.name != module_name:
            raise
        raise SystemExit("No Task") from err

    env_class = getattr(envs_module, task_name, None)
    if env_class is None:
        raise SystemExit("No Task")

    # Constructed outside any try: constructor failures must stay visible.
    return env_class()
|
|
|
|
|
|
def load_model(model_path, map_location=None, weights_only=False):
    """Load a whole pickled model object and switch it to evaluation mode.

    Args:
        model_path: Path to a file written with ``torch.save(model, path)``.
        map_location: Forwarded to ``torch.load``. For example, ``'cpu'`` loads
            a checkpoint saved on GPU onto a CPU-only machine. ``None`` keeps
            the default device placement.
        weights_only: Forwarded to ``torch.load``. It defaults to False because
            a full pickled module cannot be loaded with ``weights_only=True``,
            which is torch's own default since PyTorch 2.6. Requires
            PyTorch >= 1.13, where the argument exists.

    Returns:
        The loaded model, after ``.eval()`` has been called on it.

    Raises:
        TypeError: if the file does not hold an object with an ``eval()``
            method, e.g. it contains a bare ``state_dict``.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects, so only load
    # checkpoints that come from a trusted source.
    model = torch.load(model_path, map_location=map_location, weights_only=weights_only)
    if not callable(getattr(model, 'eval', None)):
        raise TypeError(
            f"{model_path!r} does not contain a model object (got "
            f"{type(model).__name__}); load state dicts with "
            f"model.load_state_dict(...) instead"
        )
    model.eval()
    return model
|
|
|
|
# Name of the task under evaluation: assigned by main() and read by
# test_policy() when it prints progress.
TASK = None
|
|
|
|
def main(args):
    """Evaluate a policy on one task and write the success rate to a file.

    The result is written to ``result/<task_name>/result.txt``, overwriting
    any previous result for that task.

    Args:
        args: Either an ``argparse.Namespace`` (as built by the ``__main__``
            block) or a dict of settings. It must provide ``task_name``. The
            whole mapping is forwarded as keyword arguments to the
            environment's ``setup_demo`` through ``test_policy``.
    """
    global TASK

    # test_policy both indexes args and **-unpacks it, so it needs a real
    # dict. Copy, rather than alias the Namespace's __dict__ or the caller's
    # dict, so edits made downstream stay local.
    args = dict(vars(args)) if isinstance(args, argparse.Namespace) else dict(args)
    # test_policy reads args['render_freq'] unconditionally.
    args.setdefault('render_freq', 0)

    TASK = args['task_name']
    print('Task name:', TASK)

    task = class_decorator(TASK)

    st_seed = 100000
    test_num = 100
    suc_nums = []

    policy = YOUR_POLICY()  # TODO: init your policy

    _, suc_num = test_policy(task, args, policy, st_seed, test_num=test_num)
    suc_nums.append(suc_num)

    save_dir = f'result/{TASK}'
    file_path = os.path.join(save_dir, 'result.txt')
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # exist_ok replaces the racy "check exists, then create" sequence.
    os.makedirs(save_dir, exist_ok=True)

    # NOTE: mode 'w' overwrites the result of any previous run of this task.
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(f'Timestamp: {current_time}\n\n')
        file.write('Success Rate:\n')
        file.write('\n'.join(map(str, np.array(suc_nums) / test_num)))
        file.write('\n\n')

    print(f'Data has been saved to {file_path}')
|
|
|
|
|
|
def test_policy(Demo_class, args, policy, st_seed, test_num=20):
|
|
global TASK
|
|
epid = 0
|
|
seed_list=[]
|
|
suc_num = 0
|
|
expert_check = True
|
|
print("Task name: ",args["task_name"])
|
|
|
|
|
|
Demo_class.suc = 0
|
|
Demo_class.test_num =0
|
|
|
|
now_id = 0
|
|
succ_seed = 0
|
|
suc_test_seed_list = []
|
|
|
|
|
|
now_seed = st_seed
|
|
while succ_seed < test_num:
|
|
render_freq = args['render_freq']
|
|
args['render_freq'] = 0
|
|
|
|
if expert_check: # if even expert can't finish the task, skip this seed.
|
|
try:
|
|
Demo_class.setup_demo(now_ep_num=now_id, seed = now_seed, ** args)
|
|
Demo_class.play_once()
|
|
Demo_class.close()
|
|
except:
|
|
Demo_class.close()
|
|
now_seed += 1
|
|
args['render_freq'] = render_freq
|
|
print('error occurs !')
|
|
continue
|
|
|
|
if (not expert_check) or ( Demo_class.plan_success and Demo_class.check_success() ):
|
|
succ_seed +=1
|
|
suc_test_seed_list.append(now_seed)
|
|
else:
|
|
now_seed += 1
|
|
args['render_freq'] = render_freq
|
|
continue
|
|
|
|
|
|
args['render_freq'] = render_freq
|
|
|
|
Demo_class.setup_demo(now_ep_num=now_id, seed = now_seed, is_test = True, ** args)
|
|
Demo_class.apply_policy_demo(policy)
|
|
|
|
now_id += 1
|
|
Demo_class.close()
|
|
if Demo_class.render_freq:
|
|
Demo_class.viewer.close()
|
|
|
|
print(f"{TASK} success rate: {Demo_class.suc}/{Demo_class.test_num}, current seed: {now_seed}\n")
|
|
now_seed += 1
|
|
|
|
return now_seed, Demo_class.suc
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the task name (plus an optional render
    # frequency) and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('task_name', type=str, help="input task name")
    # test_policy reads args['render_freq'] unconditionally; without this
    # option the parsed Namespace would not carry it.
    parser.add_argument(
        '--render_freq',
        type=int,
        default=0,
        help="render frequency forwarded to setup_demo during policy evaluation (0 = no viewer)",
    )
    args = parser.parse_args()
    main(args)
|