This commit is contained in:
dyabel 2024-03-02 15:56:53 +00:00
parent 1182893a23
commit 5b1cba3d75
6 changed files with 173 additions and 432 deletions

View file

@ -96,186 +96,3 @@ def compute_pass_rate(query_id, example, task_solvable=None, task_solvable_reaso
else: else:
label = "failed" label = "failed"
return query_id, task_solvable, is_solved, label, reason, not_hallucinate, tokens return query_id, task_solvable, is_solved, label, reason, not_hallucinate, tokens
# output_dir = f'result1/generated_solve_given_api_solvable_multicat_complex_r1/stack_reassign_solve_results_turbo_r16'
if __name__ == '__main__':
# reassign = False
test_sets = ["G1_instruction", "G1_tool", "G1_category", "G2_instruction", "G2_category", "G3_instruction"]
# test_sets = ["G1_tool", "G1_category", "G2_instruction", "G2_category", "G3_instruction"]
# test_sets = ['custom_data']
# test_sets = ['G1_instruction']
# test_sets = [ "G1_category","G1_instruction", "G1_tool", "G2_instruction"]
# test_sets = ['G2_instruction', 'G3_instruction']
unsolvable_list = json.load(open("unsolvable.json", "r", encoding="utf-8"))
# unsolvable_list = []
pass_rate_list = []
average_tokens_list = []
for test_set in test_sets:
total_tokens = 0
# query_dir = f'data/test_instruction/{test_set}'
# output_dir = f'result2/test_instruction/{test_set}'
# output_dir = f'result0111/turbo/test_instruction/{test_set}_r1'
# output_dir = f'data/reproduction_data/model_predictions/chatgpt_dfs/{test_set}'
# output_dir = f'data/reproduction_data/model_predictions/toolllama_dfs/{test_set}'
output_dir = f'data/reproduction_data/model_predictions/toolllama_dfs_retriever/{test_set}'
# output_dir = f'data/reproduction_data/model_predictions/gpt-4-0613_dfs/{test_set}'
# output_dir = f'data/reproduction_data/model_predictions/chatgpt_cot/{test_set}'
# output_dir = f'data/reproduction_data/model_predictions/gpt-4-0613_cot/{test_set}'
# 33.5&33.5&41.0&23.5&29.5&3.0 27.3
# output_dir = f'result0111/32k/test_instruction/{test_set}_r1'
# output_dir = f'result0111/32k/max32/test_instruction/{test_set}_r1'
# output_dir = f'result_final/toolbench/{test_set}'
# output_dir = f'result0126/toolbench/{test_set}'
# output_dir = f'repos/toolbench_ori/{test_set}_filtered/gpt4_retriever_dfs'
# output_dir = f'repos/toolbench_ori/{test_set}_filtered/toolllama_retriever_ada_dfs'
# output_dir = f'data/reproduction_data/model_predictions/gpt-35-turbo_dfs/{test_set}'
# output_dir = 'result_final/custom_data/gpt_dfs_retriever'
# output_dir = 'result_final/custom_data/toolllama_dfs_retriever'
# output_dir = 'result_final/custom_data/gpt4_gt_dfs'
# output_dir = 'result0111/32k_aus/custom_data'
if 'reproduction' in output_dir or 'ori' in output_dir:
reassign = False
else:
reassign = True
# reassign = False
if reassign:
test_ids = list(range(200))
else:
test_ids = json.load(open(f'data/test_query_ids/{test_set}.json', 'r', encoding='utf-8'))
if 'cot' in output_dir:
method = 'CoT@1'
else:
method = 'DFS_woFilter_w2'
if not os.path.exists(output_dir):
continue
# evaluation_output_dir = f'result2/test_instruction/{test_set}/pass_rate_result_reeval_32k'
# os.system(f'mv {evaluation_output_dir} {output_dir}')
# evaluation_output_dir = f'result2/test_instruction/{test_set}/pass_rate_result_35'
# evaluation_output_dir = f'{output_dir}/pass_rate_result_reeval_32k_3times_nounsure_aus_r1'
# evaluation_output_dir = f'{output_dir}/pass_rate_result_reeval_32k_r1'
# final
evaluation_output_dir = f'{output_dir}/pass_rate_result_reeval_32k_3times'
# evaluation_output_dir = f'{output_dir}/pass_rate_result_35'
# evaluation_output_dir = f'{output_dir}/pass_rate_result_reeval_32k'
# continue
os.makedirs(evaluation_output_dir, exist_ok=True)
# label_cnt = {}
# answer_dict = {}
if os.path.exists(f"{evaluation_output_dir}/label_cnt.json"):
label_cnt = json.load(open(f"{evaluation_output_dir}/label_cnt.json", "r", encoding="utf-8"))
else:
label_cnt = {}
future = []
if os.path.exists(f"{evaluation_output_dir}/answer_dict.json"):
answer_dict = json.load(open(f"{evaluation_output_dir}/answer_dict.json", "r", encoding="utf-8"))
else:
answer_dict = {}
# result_data = json.load(open(f'data/reproduction_data/model_predictions/gpt-4-0613_dfs/{test_set}.json', 'r', encoding='utf-8'))
referenced_examples = {}
with ThreadPoolExecutor(args.max_eval_threads) as pool:
for i in test_ids:
# print(i)
if reassign:
try:
# print(f'{output_dir}/{i}.json')
data = json.load(open(f'{output_dir}/{i}.json', 'r', encoding='utf-8'))
except:
continue
query_id = data['query_id']
if int(query_id) in unsolvable_list:
continue
else:
query_id = i
if int(query_id) in unsolvable_list:
continue
if 'chatgpt' in output_dir and 'cot' not in output_dir:
data = json.load(open(f'{output_dir}/{i}_ChatGPT_{method}.json', 'r', encoding='utf-8'))
elif 'chatgpt' in output_dir:
data = json.load(open(f'{output_dir}/{i}_{method}.json', 'r', encoding='utf-8'))
else:
data = json.load(open(f'{output_dir}/{i}_{method}.json', 'r', encoding='utf-8'))
if not reassign:
total_tokens += data['answer_generation']['total_tokens']
else:
if 'total_tokens' in data:
total_tokens += data['total_tokens']
if str(query_id) in label_cnt:
continue
if reassign:
# print(i)
if 'last_solve_time' not in data:
try:
data_dict = json.load(open(f'{output_dir}/{i}_DFS_woFilter_w2.json', 'r', encoding='utf-8'))
except:
continue
else:
last_solve_time = data['last_solve_time']
data_dict = json.load(open(f'{output_dir}/{i}_{last_solve_time}_DFS_woFilter_w2.json', 'r', encoding='utf-8'))
else:
data_dict = data
if not data_dict['answer_generation']['valid_data']:
answer_dict[i] = process_invalid_data(method,data_dict)
else:
answer_dict[i] = process_valid_data(method,data_dict['answer_generation'])
example = answer_dict[i]
# query_id = i
# example['available_tools'] = query_data[str(query_id)]['available_tools']
referenced_examples[query_id] = example
for _ in range(args.evaluate_times):
future.append(pool.submit(
compute_pass_rate,
query_id,
example,
'Solvable',
'Task solvable human label'
))
for thd in tqdm(as_completed(future),total=len(future),ncols=100):
query_id, task_solvable, is_solved, machine_label, reason, not_hallucinate, tokens = thd.result()
example = referenced_examples[query_id]
query = example["query"]
tool_names = []
for tool_dict in example["available_tools"]:
tool_name = tool_dict["name"]
tool_names.append(tool_name)
answer_steps, final_step = get_steps(example)
if query_id not in label_cnt:
label_cnt[query_id] = {"passed":0, "failed":0, "unsure":0}
if machine_label == "passed":
label_cnt[query_id]["passed"] += 1
elif machine_label == "failed":
label_cnt[query_id]["failed"] += 1
else:
label_cnt[query_id]["unsure"] += 1
label_cnt[query_id]["query"] = query
label_cnt[query_id]["task_solvable"] = str(task_solvable)
label_cnt[query_id]["tool_names"] = tool_names
label_cnt[query_id]["answer_steps"] = answer_steps
label_cnt[query_id]["final_step"] = final_step
label_cnt[query_id]["is_solved"] = str(is_solved)
label_cnt[query_id]["reason"] = reason
label_cnt[query_id]["not_hallucinate"] = not_hallucinate
json.dump(label_cnt, open(f"{evaluation_output_dir}/label_cnt.json", "w"), ensure_ascii=False, indent=4)
filename = f"{evaluation_output_dir}/label_cnt.csv"
write_results(filename, 'result', label_cnt)
pass_rate = 0
total_num = 0
print('#'*100)
for query_id in label_cnt:
if int(query_id) in unsolvable_list:
continue
if label_cnt[query_id]["failed"] <= label_cnt[query_id]["passed"]:
pass_rate += 1
# if label_cnt[query_id]["unsure"] > 0:
# print('unsure')
total_num += 1
pass_rate /= total_num
pass_rate_list.append(pass_rate)
average_tokens_list.append(total_tokens/total_num)
print(f"Pass rate: {str(pass_rate)} total num {total_num} average tokens {total_tokens/total_num} {test_set}")
json.dump(answer_dict, open(f"{evaluation_output_dir}/answer_dict.json", "w"), ensure_ascii=False, indent=4)
print('&'.join([str(round(x*100,1)) for x in pass_rate_list]),round(np.mean(pass_rate_list)*100,1))
print('&'.join([str(round(x,1)) for x in average_tokens_list]),round(np.mean(average_tokens_list),1))

View file

@ -125,7 +125,6 @@ def solve_given_api_main(query, api_list, i, messages=None):
result_data['result']['reason'] = reason result_data['result']['reason'] = reason
result['answer_generation']['final_answer'] = json.dumps(result_data['result']) result['answer_generation']['final_answer'] = json.dumps(result_data['result'])
if result['answer_generation']['finish_type'] == 'give_answer' and 'final_answer' in result_data['result'] and result_data['result']['final_answer'] != '': if result['answer_generation']['finish_type'] == 'give_answer' and 'final_answer' in result_data['result'] and result_data['result']['final_answer'] != '':
# and not any(word in str(result['answer_generation']['final_answer']).lower() for word in exclusion_words)
solved = True solved = True
else: else:
solved = False solved = False
@ -137,7 +136,6 @@ from arguments import parse_args
args = parse_args() args = parse_args()
output_path = args.output_dir output_path = args.output_dir
# output_path = f'{query_dir}/reassign_toolllama_dfs_r1'
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
dfs_args = dotdict(dict(backbone_model='chatgpt_function', openai_key='', model_path='your_model_path/', tool_root_dir='data/toolenv/tools/', lora=False, lora_path='your_lora_path if lora', max_observation_length=1024, max_source_sequence_length=4096, max_sequence_length=8192, observ_compress_method='truncate', method='DFS_woFilter_w2', input_query_file='data/test_instruction/G1_tool.json', output_answer_file=output_path, toolbench_key=toolbench_key, rapidapi_key='', use_rapidapi_key=False, api_customization=False)) dfs_args = dotdict(dict(backbone_model='chatgpt_function', openai_key='', model_path='your_model_path/', tool_root_dir='data/toolenv/tools/', lora=False, lora_path='your_lora_path if lora', max_observation_length=1024, max_source_sequence_length=4096, max_sequence_length=8192, observ_compress_method='truncate', method='DFS_woFilter_w2', input_query_file='data/test_instruction/G1_tool.json', output_answer_file=output_path, toolbench_key=toolbench_key, rapidapi_key='', use_rapidapi_key=False, api_customization=False))
dfs_runner = pipeline_runner(dfs_args) dfs_runner = pipeline_runner(dfs_args)
@ -156,9 +154,6 @@ if __name__ == '__main__':
for i in range(262): for i in range(262):
t_s = time.time() t_s = time.time()
comparison_data = {} comparison_data = {}
# for file in files:
# if file.endswith('.json'):
# print(file)
data_load = json.load(open(f'{args.query_dir}/{i}.json', 'r', encoding='utf-8')) data_load = json.load(open(f'{args.query_dir}/{i}.json', 'r', encoding='utf-8'))
if str(data_load['query_id']) in solved_dict and solved_dict[str(data_load['query_id'])]['solved'] != 'Solved': if str(data_load['query_id']) in solved_dict and solved_dict[str(data_load['query_id'])]['solved'] != 'Solved':
continue continue
@ -166,8 +161,6 @@ if __name__ == '__main__':
query = data_load['query'] query = data_load['query']
# continue # continue
cnt += 1 cnt += 1
# if cnt > 50:
# break
if os.path.exists(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json')): if os.path.exists(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json')):
data = json.load(open(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json'), 'r', encoding='utf-8')) data = json.load(open(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json'), 'r', encoding='utf-8'))
final_data = json.load(open(os.path.join(output_path, f'{i}.json'), 'r', encoding='utf-8')) final_data = json.load(open(os.path.join(output_path, f'{i}.json'), 'r', encoding='utf-8'))

View file

@ -25,49 +25,41 @@ def check_task_solvable(query):
"content": f"Please check whether the following query is solvable: {query}. Begin!"} "content": f"Please check whether the following query is solvable: {query}. Begin!"}
] ]
for i in range(5): for i in range(5):
# try: response = call_gpt(
if True: messages=messages,
t_s = time.time() functions=[solvable_finish_function]
response = call_gpt( )
messages=messages, tool_calls = response.choices[0].message.tool_calls
functions=[solvable_finish_function] print('Thought:', response.choices[0].message.content)
) if tool_calls:
print(time.time() - t_s) for tool_call in tool_calls:
tool_calls = response.choices[0].message.tool_calls function_name = tool_call.function.name
print('Thought:', response.choices[0].message.content) function_args = tool_call.function.arguments
if tool_calls: if function_name == 'Finish':
for tool_call in tool_calls: try:
function_name = tool_call.function.name solvable, reason = Finish(**json.loads(function_args))
function_args = tool_call.function.arguments except:
if function_name == 'Finish':
try:
solvable, reason = Finish(**json.loads(function_args))
except:
continue
# solvable, reason = Finish(json.loads(function_args))
else:
continue continue
print(solvable, query, file=open('result/solvable.txt', 'a', encoding='utf-6')) # solvable, reason = Finish(json.loads(function_args))
if solvable == 'Unsolvable' and reason is None:
messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'}) else:
if reason is not None: continue
print(reason, file=open('result/solvable.txt', 'a', encoding='utf-8')) print(solvable, query, file=open('result/solvable.txt', 'a', encoding='utf-6'))
else: if solvable == 'Unsolvable' and reason is None:
reason = '' messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'})
return solvable, reason if reason is not None:
else: print(reason, file=open('result/solvable.txt', 'a', encoding='utf-8'))
print('Thought:', response.choices[0].message.content) else:
continue reason = ''
return solvable, reason
else:
print('Thought:', response.choices[0].message.content)
continue
# messages.append({"role": "assistant", "content": response.choices[0].message.get('content', '')}) # messages.append({"role": "assistant", "content": response.choices[0].message.get('content', '')})
# except:
# pass
print('No response from the model', file=open('result/solvable.txt', 'a', encoding='utf-8')) print('No response from the model', file=open('result/solvable.txt', 'a', encoding='utf-8'))
return 'No response', 'No response from the model' return 'No response', 'No response from the model'
def check_task_solvable_by_function(query, functions): def check_task_solvable_by_function(query, functions):
# return 'Solvable', ''
# print(functions)
messages = [{ messages = [{
"role": "system", "role": "system",
"content": CHECK_SOLVABLE_BY_FUNCTION_PROMPT "content": CHECK_SOLVABLE_BY_FUNCTION_PROMPT
@ -76,40 +68,33 @@ def check_task_solvable_by_function(query, functions):
"content": f"Query: {query}. Available_tools: {functions}. Begin!"} "content": f"Query: {query}. Available_tools: {functions}. Begin!"}
] ]
for i in range(5): for i in range(5):
# try: response = call_gpt(
if True: messages=messages,
t_s = time.time() functions=[solvable_finish_function]
response = call_gpt( )
messages=messages, tool_calls = response.choices[0].message.tool_calls
functions=[solvable_finish_function] print('Thought:', response.choices[0].message.content)
) if tool_calls:
print(time.time() - t_s) for tool_call in tool_calls:
tool_calls = response.choices[0].message.tool_calls function_name = tool_call.function.name
print('Thought:', response.choices[0].message.content) function_args = tool_call.function.arguments
if tool_calls: if function_name.lower() == 'finish':
for tool_call in tool_calls: try:
function_name = tool_call.function.name solvable, reason = Finish(**json.loads(function_args))
function_args = tool_call.function.arguments except:
if function_name.lower() == 'finish':
try:
solvable, reason = Finish(**json.loads(function_args))
except:
continue
# solvable, reason = Finish(json.loads(function_args))
else:
continue continue
print(solvable, query, file=open('result/solvable.txt', 'a', encoding='utf-8')) # solvable, reason = Finish(json.loads(function_args))
if solvable == 'Unsolvable' and reason is None:
messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'}) else:
if reason is not None: continue
print(reason, file=open('result/solvable.txt', 'a', encoding='utf-8')) if solvable == 'Unsolvable' and reason is None:
else: messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'})
reason = '' if reason is None:
return solvable, reason, response.usage.total_tokens reason = ''
else: return solvable, reason, response.usage.total_tokens
print('Thought:', response.choices[0].message.content) else:
continue print('Thought:', response.choices[0].message.content)
continue
# messages.append({"role": "assistant", "content": response.choices[0].message.get('content', '')}) # messages.append({"role": "assistant", "content": response.choices[0].message.get('content', '')})
# except: # except:
# pass # pass
@ -141,13 +126,10 @@ def check_task_solved(query, answer):
print(function_name, function_args) print(function_name, function_args)
if function_name.lower() == 'finish': if function_name.lower() == 'finish':
solvable, reason = Finish(**json.loads(function_args)) solvable, reason = Finish(**json.loads(function_args))
print(solvable, query, file=open('result/solved.txt', 'a', encoding='utf-8'))
if solvable == 'Unsolved' and reason is None: if solvable == 'Unsolved' and reason is None:
messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'}) messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'})
continue continue
if reason is not None: if reason is None:
print(reason, file=open('result/solved.txt', 'a', encoding='utf-8'))
else:
reason = '' reason = ''
return solvable, reason return solvable, reason
@ -167,11 +149,9 @@ def check_solved_toolbench(output_path, query_id, task_solvable=None, solvable_t
example = process_invalid_data(method,data_dict) example = process_invalid_data(method,data_dict)
else: else:
example = process_valid_data(method,data_dict['answer_generation']) example = process_valid_data(method,data_dict['answer_generation'])
# example['available_tools'] = query_data[str(ori_query_id)]['available_tools']
future = [] future = []
answer_dict = {'passed':0, 'failed':0} answer_dict = {'passed':0, 'failed':0}
with ThreadPoolExecutor(32) as pool: with ThreadPoolExecutor(32) as pool:
print(task_solvable, solvable_task_reason, file=open(os.path.join(output_dir, 'solvable.txt'), 'a', encoding='utf-8'))
for _ in range(3): for _ in range(3):
future.append(pool.submit( future.append(pool.submit(
compute_pass_rate, compute_pass_rate,

View file

@ -60,15 +60,6 @@ def call_gpt(messages, functions=None, **kwargs):
**kwargs **kwargs
) )
except openai.BadRequestError as e: except openai.BadRequestError as e:
# try:
# response = turbo_client.chat.completions.create(
# seed=123,
# model='gpt-4-turbo',
# messages=messages,
# functions=functions
# )
# except Exception as e:
# raise e
raise e raise e
except openai.RateLimitError as e: except openai.RateLimitError as e:
@ -85,8 +76,6 @@ def call_gpt(messages, functions=None, **kwargs):
try: try:
response, t_real = call_gpt_retry(messages_converted, functions) response, t_real = call_gpt_retry(messages_converted, functions)
# json_content = response.choices[0].message.content # json_content = response.choices[0].message.content
t = time.time() - t_s
print('minus:', t-t_real, file=open(os.path.join(output_dir, "time.txt"), "a"))
# print(response.choices[0].message.function_call) # print(response.choices[0].message.function_call)
if response.choices[0].finish_reason == 'function_call': if response.choices[0].finish_reason == 'function_call':
response_json = json.loads(response.json()) response_json = json.loads(response.json())

View file

@ -963,135 +963,132 @@ if __name__ == "__main__":
# continue # continue
print(f'query: {query}') print(f'query: {query}')
flag = False flag = False
try: runner = Main_Search_Agent(query)
runner = Main_Search_Agent(query) agents.append(runner)
agents.append(runner) if multi_thread:
if multi_thread: thread = threading.Thread(target=runner.assign_main, args=(query,))
thread = threading.Thread(target=runner.assign_main, args=(query,)) sem.acquire()
sem.acquire() thread.start()
thread.start() threads.append(thread)
threads.append(thread) else:
iter_func = runner.assign_main(query)
messages = None
cnt = 0
solve_data = {}
if multi_thread:
while True:
thread_num = len(threads)
has_thread_alive = False
for thread in threads:
if thread.is_alive():
has_thread_alive = True
thread.join()
if not has_thread_alive:
break
if error_flag: raise Exception('GPT Call Error')
threads = []
# refind
check_solved = ''
max_depth = max([agent.depth for agent in agents])
while not all([agent.finish_search for agent in agents]) or not flag:
if check_solved == 'Solved':
break
max_depth = max([agent.depth for agent in agents])
depth = max_depth
while all([agent.finish_search for agent in agents if agent.depth == depth]) and depth >= 0:
depth -= 1
if depth < 0 and flag:
break
agents_to_resume = [agent for agent in agents if not agent.finish_search and agent.depth == depth]
if total_tokens > 200000 and flag:
solved = False
check_solved = 'Timeout'
solve_data = {'result': 'Timeout'}
break
cnt += 1
failed_reason = None
print(len(global_api_list), file=open(f'{output_dir}/api_list_len.txt', 'a', encoding='utf-8'))
print(global_api_list, file=open(f'{output_dir}/api_list.txt', 'a', encoding='utf-8'))
print('#'*100)
assign_results['api_list'].append(deepcopy(global_api_list))
if stop or not flag or all([agent.finish_search for agent in agents]) and len(global_api_list) > 0:
flag = True
last_solve_time = cnt
t_s = time.time()
selected_api_list = deepcopy(global_api_list)
solved, solve_data = solve_given_api_main(query, selected_api_list, f'{query_id}_{cnt}', messages)
print('solve time:', time.time() - t_s, 'api number:', len(global_api_list),file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8'))
result_list.append(deepcopy(solve_data['result']))
print(solve_data['result'], solved)
if not solved or any([word in solve_data['result']['final_answer'] for word in exclusion_words]):
check_solved = 'Unsolved'
reason = solve_data['result']
else:
check_solved, reason, tokens = check_solved_toolbench(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', query_id, task_solvable, solvable_reason)
total_tokens += tokens
print(colored((check_solved, reason), 'red'))
failed_reason = reason
dfs_data = json.load(open(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', 'r', encoding='utf-8'))
total_tokens += dfs_data['answer_generation']['total_tokens']
solve_tokens += dfs_data['answer_generation']['total_tokens']
if check_solved == 'Solved':
break
try:
messages = dfs_data['answer_generation']['train_messages'][-1]
except:
messages = None
api_list_to_prune = []
for standardized_api_name, origin_api in dfs_data['api2origin'].items():
if standardized_api_name in str(failed_reason):
if origin_api in global_api_list:
api_list_to_prune.append(origin_api)
print(colored(api_list_to_prune, 'red'))
print(len(api_list_to_prune))
remove_apis(api_list_to_prune)
if len(global_api_list) >= max_api_number:
break
# print(api_list_to_prune, file=open(f'{output_dir}/prune_api_list.txt', 'a', encoding='utf-8'))
stop = False
else: else:
iter_func = runner.assign_main(query) assert status != 'The current api list can solve the query.'
messages = None failed_reason = status
cnt = 0 reason_list.append(failed_reason)
solve_data = {} print(colored('Refind Begin', 'red'))
print(colored(agents_to_resume, 'red'))
print([agent.finish_search for agent in agents_to_resume])
threads = []
resume_cnt = 0
resumed_agents.append([(str(a), a.index) for a in agents_to_resume])
for agent in reversed(agents_to_resume):
if agent.finish_search: continue
resume_cnt += 1
agent.failed_reason = str(failed_reason)
print(colored(('resuming', agent, agent.depth), 'red'))
print(colored(('resuming', agent, agent.depth), 'red'), file=open(f'{output_dir}/resume.txt', 'a', encoding='utf-8'))
if multi_thread:
thread = threading.Thread(target=agent.resume_search)
sem.acquire()
thread.start()
threads.append(thread)
else:
agent.resume_search()
if multi_thread: if multi_thread:
while True: while True:
thread_num = len(threads) thread_num = len(threads)
has_thread_alive = False
for thread in threads: for thread in threads:
if thread.is_alive(): if thread.is_alive():
has_thread_alive = True
thread.join() thread.join()
if not has_thread_alive: if thread_num == len(threads):
break break
if error_flag: raise Exception('GPT Call Error') if error_flag: raise Exception('GPT Call Error')
threads = [] if not stop:
# refind check_if_request_solvable()
check_solved = '' print(colored(f'status:{status}', 'red'))
max_depth = max([agent.depth for agent in agents]) assign_results['stop'].append(stop)
while not all([agent.finish_search for agent in agents]) or not flag:
if check_solved == 'Solved':
break
max_depth = max([agent.depth for agent in agents])
depth = max_depth
while all([agent.finish_search for agent in agents if agent.depth == depth]) and depth >= 0:
depth -= 1
if depth < 0 and flag:
break
agents_to_resume = [agent for agent in agents if not agent.finish_search and agent.depth == depth]
if total_tokens > 200000 and flag:
solved = False
check_solved = 'Timeout'
solve_data = {'result': 'Timeout'}
break
cnt += 1
failed_reason = None
print(len(global_api_list), file=open(f'{output_dir}/api_list_len.txt', 'a', encoding='utf-8'))
print(global_api_list, file=open(f'{output_dir}/api_list.txt', 'a', encoding='utf-8'))
print('#'*100)
assign_results['api_list'].append(deepcopy(global_api_list))
if stop or not flag or all([agent.finish_search for agent in agents]) and len(global_api_list) > 0:
flag = True
last_solve_time = cnt
t_s = time.time()
selected_api_list = deepcopy(global_api_list)
solved, solve_data = solve_given_api_main(query, selected_api_list, f'{query_id}_{cnt}', messages)
print('solve time:', time.time() - t_s, 'api number:', len(global_api_list),file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8'))
result_list.append(deepcopy(solve_data['result']))
print(solve_data['result'], solved)
if not solved or any([word in solve_data['result']['final_answer'] for word in exclusion_words]):
check_solved = 'Unsolved'
reason = solve_data['result']
else:
check_solved, reason, tokens = check_solved_toolbench(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', query_id, task_solvable, solvable_reason)
total_tokens += tokens
print(colored((check_solved, reason), 'red'))
failed_reason = reason
dfs_data = json.load(open(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', 'r', encoding='utf-8'))
total_tokens += dfs_data['answer_generation']['total_tokens']
solve_tokens += dfs_data['answer_generation']['total_tokens']
if check_solved == 'Solved':
break
try:
messages = dfs_data['answer_generation']['train_messages'][-1]
except:
messages = None
api_list_to_prune = []
for standardized_api_name, origin_api in dfs_data['api2origin'].items():
if standardized_api_name in str(failed_reason):
if origin_api in global_api_list:
api_list_to_prune.append(origin_api)
print(colored(api_list_to_prune, 'red'))
print(len(api_list_to_prune))
remove_apis(api_list_to_prune)
if len(global_api_list) >= max_api_number:
break
# print(api_list_to_prune, file=open(f'{output_dir}/prune_api_list.txt', 'a', encoding='utf-8'))
stop = False
else:
assert status != 'The current api list can solve the query.'
failed_reason = status
reason_list.append(failed_reason)
print(colored('Refind Begin', 'red'))
print(colored(agents_to_resume, 'red'))
print([agent.finish_search for agent in agents_to_resume])
threads = []
resume_cnt = 0
resumed_agents.append([(str(a), a.index) for a in agents_to_resume])
for agent in reversed(agents_to_resume):
if agent.finish_search: continue
resume_cnt += 1
agent.failed_reason = str(failed_reason)
print(colored(('resuming', agent, agent.depth), 'red'))
print(colored(('resuming', agent, agent.depth), 'red'), file=open(f'{output_dir}/resume.txt', 'a', encoding='utf-8'))
if multi_thread:
thread = threading.Thread(target=agent.resume_search)
sem.acquire()
thread.start()
threads.append(thread)
else:
agent.resume_search()
if multi_thread:
while True:
thread_num = len(threads)
for thread in threads:
if thread.is_alive():
thread.join()
if thread_num == len(threads):
break
if error_flag: raise Exception('GPT Call Error')
if not stop:
check_if_request_solvable()
print(colored(f'status:{status}', 'red'))
assign_results['stop'].append(stop)
except KeyboardInterrupt as e:
continue
assign_results['api_complete'] = flag assign_results['api_complete'] = flag

View file

@ -10,25 +10,6 @@ from arguments import parse_args
from config import * from config import *
from openai_utils import call_gpt from openai_utils import call_gpt
args = parse_args() args = parse_args()
output_dir = args.output_dir
# import importlib
# module = importlib.import_module(args.openai_config_path.replace('.py',''))
# for name in dir(module):
# if not name.startswith('_'):
# globals()[name] = getattr(module, name)
# api_key = globals()['api_key']
# api_version = globals()['api_version']
# model_name = globals()['model_name']
# api_base = globals()['api_base']
if api_type == "azure":
from openai import AzureOpenAI as Client
else:
from openai import OpenAI as Client
client = Client(
api_key=api_key,
api_version=api_version,
azure_endpoint = api_base
)
@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3)) @retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request(key, messages, functions=None,function_call=None,key_pos=None, model="gpt-4-32k",stop=None,process_id=0, **args): def chat_completion_request(key, messages, functions=None,function_call=None,key_pos=None, model="gpt-4-32k",stop=None,process_id=0, **args):
use_messages = [] use_messages = []
@ -52,30 +33,14 @@ def chat_completion_request(key, messages, functions=None,function_call=None,key
json_data.update({"function_call": function_call}) json_data.update({"function_call": function_call})
try: try:
# if model in ["gpt-3.5-turbo-16k-0613","gpt-4-0613", "gpt-4-deployment","gpt-4-32k", 'gpt-4-turbo']:
# openai.api_key = key
# else:
# raise NotImplementedError
# ts = time.time()
# print(json_data, file=open('output/gpt_io.txt','a'))
# print(time.time()-ts)
ts = time.time()
# json.dump(json_data['messages'], open(os.path.join(output_dir,'messages.json'),'w'), indent=4)
openai_response = call_gpt( openai_response = call_gpt(
**json_data, **json_data,
) )
# openai_response = client.chat.completions.create(
# **json_data,
# )
# print('solve', time.time()-ts, file=open(os.path.join(output_dir,'time.txt'),'a'))
# json_data = json.loads(str(openai_response))
json_data = json.loads(openai_response.json()) json_data = json.loads(openai_response.json())
json_data["choices"][0]['message'].pop('tool_calls') json_data["choices"][0]['message'].pop('tool_calls')
return json_data return json_data
except Exception as e: except Exception as e:
# print('solve', time.time()-ts, file=open(os.path.join(output_dir,'time.txt'),'a'))
# # json_data = json.loads(str(openai_response))
# json_data = json.loads(openai_response.json()) # json_data = json.loads(openai_response.json())
# json_data["choices"][0]['message'].pop('tool_calls') # json_data["choices"][0]['message'].pop('tool_calls')
print("Unable to generate ChatCompletion response") print("Unable to generate ChatCompletion response")