108 lines
4.3 KiB
Python
108 lines
4.3 KiB
Python
import sys
|
|
import os
|
|
|
|
# 添加项目根目录到路径
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from gpt_api.gpt_agent import *
|
|
from gpt_api.prompt import *
|
|
from gpt_api.task_info import *
|
|
from test_gpt_code import *
|
|
|
|
import argparse
|
|
|
|
def generate_code(task_info, las_error = None, message:list = None):
|
|
task_discription = task_info['task_description']
|
|
current_code = task_info['current_code']
|
|
available_env_function = str(AVAILABLE_ENV_FUNCTOIN)
|
|
function_example = str(FUNCTION_EXAMPLE)
|
|
task_name = task_info['task_name']
|
|
available_constants = str(AVAILABLE_CONSTANTS)
|
|
|
|
# Generate Code
|
|
if las_error is not None:
|
|
actor_name_keys, actor_data_keys, actor_points_discription = get_actor_keys_and_points_discription(f"gpt_{task_name}")
|
|
Prompt = f"The code is unsuccessful, \nLast Error Message: \n{las_error}\n\n\
|
|
Task Discription: \n{task_discription}\n\n\
|
|
The Actor Points Discription: {actor_points_discription}"
|
|
else:
|
|
res = f'''
|
|
from .base_task import Base_task
|
|
from .{task_name} import {task_name}
|
|
from .utils import *
|
|
import sapien
|
|
|
|
class gpt_{task_name}({task_name}):
|
|
def play_once(self):
|
|
pass
|
|
'''
|
|
file_name = f"envs/gpt_{task_name}.py"
|
|
with open(file_name, 'w') as file:
|
|
file.write(res)
|
|
actor_name_keys, actor_data_keys, actor_points_discription = get_actor_keys_and_points_discription(f"gpt_{task_name}")
|
|
Prompt = f"{BASIC_INFO}\n\n\
|
|
Task Discription: \n{task_discription}\n\n\
|
|
Available API: \n{available_env_function}\n\n\
|
|
Function Example: \n{function_example}\n\n\
|
|
Available Constants: \n{available_constants}\n\n\
|
|
The Actor Name List: {actor_name_keys}\n\n\
|
|
The Actor Data List: {actor_data_keys}\n\n\
|
|
The Actor Points Discription: {actor_points_discription}\n\n\
|
|
Current Code:\n{current_code}"
|
|
message.append({"role": "user", "content": Prompt})
|
|
|
|
# Start Generation Process
|
|
res = generate(message)
|
|
res = f'''
|
|
from .base_task import Base_task
|
|
from .{task_name} import {task_name}
|
|
from .utils import *
|
|
import sapien
|
|
|
|
class gpt_{task_name}({task_name}):
|
|
''' + res[res.find('def play_once'):res.rfind("```")]
|
|
# pdb.set_trace()
|
|
file_name = f"envs/gpt_{task_name}.py"
|
|
with open(file_name, 'w') as file:
|
|
file.write(res)
|
|
|
|
# return success rate, error message, error count
|
|
success_rate, error_message, error_count = test_run(f"gpt_{task_name}")
|
|
return res, success_rate, error_message, error_count
|
|
|
|
def main(task_info_dic):
|
|
# keys: "task_name", "task_description", "current_code"
|
|
|
|
task_info = now_task_info = task_info_dic
|
|
messages=[{"role": "system", "content": "You need to generate relevant code for some robot tasks in a robot simulation environment based on the provided API."}]
|
|
generate_num = 5
|
|
success_threshold = 0.5
|
|
las_error_message = None
|
|
|
|
suc_list = []
|
|
for id in range(generate_num):
|
|
print("Generate code for task: ", task_info['task_name'], f"({id+1}/{generate_num})")
|
|
res_code, success_rate, las_error_message, error_count = generate_code(now_task_info, las_error_message, messages)
|
|
|
|
suc_list.append(success_rate)
|
|
if success_rate >= success_threshold:
|
|
print("Success generate code for task: ", task_info['task_name'])
|
|
break
|
|
print("Failed to generate code for task: ", task_info['task_name'], f"{id}\nError massage: \n{las_error_message}")
|
|
now_task_info["task_description"] = " Failed to generate code, error message: " + las_error_message + ", error count: " + str(error_count)
|
|
# las_code = task_info["current_code"][:task_info["current_code"].find('def play_once')]
|
|
now_task_info["current_code"] = res_code
|
|
print(success_rate)
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser(description='Process some integers.')
|
|
parser.add_argument('task_name', type=str)
|
|
now_task = None
|
|
|
|
try:
|
|
task_name = parser.parse_args().task_name.upper()
|
|
exec(f'now_task = {task_name}')
|
|
except:
|
|
raise ValueError("The task name is wrong.")
|
|
|
|
main(now_task) |