RoboTwin_image/script/task_generation.py
2025-07-02 03:13:07 +00:00

108 lines
4.3 KiB
Python

import sys
import os
# 添加项目根目录到路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gpt_api.gpt_agent import *
from gpt_api.prompt import *
from gpt_api.task_info import *
from test_gpt_code import *
import argparse
def generate_code(task_info, las_error = None, message:list = None):
task_discription = task_info['task_description']
current_code = task_info['current_code']
available_env_function = str(AVAILABLE_ENV_FUNCTOIN)
function_example = str(FUNCTION_EXAMPLE)
task_name = task_info['task_name']
available_constants = str(AVAILABLE_CONSTANTS)
# Generate Code
if las_error is not None:
actor_name_keys, actor_data_keys, actor_points_discription = get_actor_keys_and_points_discription(f"gpt_{task_name}")
Prompt = f"The code is unsuccessful, \nLast Error Message: \n{las_error}\n\n\
Task Discription: \n{task_discription}\n\n\
The Actor Points Discription: {actor_points_discription}"
else:
res = f'''
from .base_task import Base_task
from .{task_name} import {task_name}
from .utils import *
import sapien
class gpt_{task_name}({task_name}):
def play_once(self):
pass
'''
file_name = f"envs/gpt_{task_name}.py"
with open(file_name, 'w') as file:
file.write(res)
actor_name_keys, actor_data_keys, actor_points_discription = get_actor_keys_and_points_discription(f"gpt_{task_name}")
Prompt = f"{BASIC_INFO}\n\n\
Task Discription: \n{task_discription}\n\n\
Available API: \n{available_env_function}\n\n\
Function Example: \n{function_example}\n\n\
Available Constants: \n{available_constants}\n\n\
The Actor Name List: {actor_name_keys}\n\n\
The Actor Data List: {actor_data_keys}\n\n\
The Actor Points Discription: {actor_points_discription}\n\n\
Current Code:\n{current_code}"
message.append({"role": "user", "content": Prompt})
# Start Generation Process
res = generate(message)
res = f'''
from .base_task import Base_task
from .{task_name} import {task_name}
from .utils import *
import sapien
class gpt_{task_name}({task_name}):
''' + res[res.find('def play_once'):res.rfind("```")]
# pdb.set_trace()
file_name = f"envs/gpt_{task_name}.py"
with open(file_name, 'w') as file:
file.write(res)
# return success rate, error message, error count
success_rate, error_message, error_count = test_run(f"gpt_{task_name}")
return res, success_rate, error_message, error_count
def main(task_info_dic):
# keys: "task_name", "task_description", "current_code"
task_info = now_task_info = task_info_dic
messages=[{"role": "system", "content": "You need to generate relevant code for some robot tasks in a robot simulation environment based on the provided API."}]
generate_num = 5
success_threshold = 0.5
las_error_message = None
suc_list = []
for id in range(generate_num):
print("Generate code for task: ", task_info['task_name'], f"({id+1}/{generate_num})")
res_code, success_rate, las_error_message, error_count = generate_code(now_task_info, las_error_message, messages)
suc_list.append(success_rate)
if success_rate >= success_threshold:
print("Success generate code for task: ", task_info['task_name'])
break
print("Failed to generate code for task: ", task_info['task_name'], f"{id}\nError massage: \n{las_error_message}")
now_task_info["task_description"] = " Failed to generate code, error message: " + las_error_message + ", error count: " + str(error_count)
# las_code = task_info["current_code"][:task_info["current_code"].find('def play_once')]
now_task_info["current_code"] = res_code
print(success_rate)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('task_name', type=str)
now_task = None
try:
task_name = parser.parse_args().task_name.upper()
exec(f'now_task = {task_name}')
except:
raise ValueError("The task name is wrong.")
main(now_task)