diff --git a/anytool/check_solved.py b/anytool/check_solved.py index 79f7bc1..25fb553 100644 --- a/anytool/check_solved.py +++ b/anytool/check_solved.py @@ -96,186 +96,3 @@ def compute_pass_rate(query_id, example, task_solvable=None, task_solvable_reaso else: label = "failed" return query_id, task_solvable, is_solved, label, reason, not_hallucinate, tokens -# output_dir = f'result1/generated_solve_given_api_solvable_multicat_complex_r1/stack_reassign_solve_results_turbo_r16' -if __name__ == '__main__': - # reassign = False - test_sets = ["G1_instruction", "G1_tool", "G1_category", "G2_instruction", "G2_category", "G3_instruction"] - # test_sets = ["G1_tool", "G1_category", "G2_instruction", "G2_category", "G3_instruction"] - # test_sets = ['custom_data'] - # test_sets = ['G1_instruction'] - # test_sets = [ "G1_category","G1_instruction", "G1_tool", "G2_instruction"] - # test_sets = ['G2_instruction', 'G3_instruction'] - unsolvable_list = json.load(open("unsolvable.json", "r", encoding="utf-8")) - # unsolvable_list = [] - pass_rate_list = [] - average_tokens_list = [] - for test_set in test_sets: - total_tokens = 0 - # query_dir = f'data/test_instruction/{test_set}' - # output_dir = f'result2/test_instruction/{test_set}' - # output_dir = f'result0111/turbo/test_instruction/{test_set}_r1' - # output_dir = f'data/reproduction_data/model_predictions/chatgpt_dfs/{test_set}' - # output_dir = f'data/reproduction_data/model_predictions/toolllama_dfs/{test_set}' - output_dir = f'data/reproduction_data/model_predictions/toolllama_dfs_retriever/{test_set}' - # output_dir = f'data/reproduction_data/model_predictions/gpt-4-0613_dfs/{test_set}' - # output_dir = f'data/reproduction_data/model_predictions/chatgpt_cot/{test_set}' - # output_dir = f'data/reproduction_data/model_predictions/gpt-4-0613_cot/{test_set}' - # 33.5&33.5&41.0&23.5&29.5&3.0 27.3 - # output_dir = f'result0111/32k/test_instruction/{test_set}_r1' - # output_dir = f'result0111/32k/max32/test_instruction/{test_set}_r1' - # output_dir = f'result_final/toolbench/{test_set}' - # output_dir = f'result0126/toolbench/{test_set}' - # output_dir = f'repos/toolbench_ori/{test_set}_filtered/gpt4_retriever_dfs' - # output_dir = f'repos/toolbench_ori/{test_set}_filtered/toolllama_retriever_ada_dfs' - # output_dir = f'data/reproduction_data/model_predictions/gpt-35-turbo_dfs/{test_set}' - # output_dir = 'result_final/custom_data/gpt_dfs_retriever' - # output_dir = 'result_final/custom_data/toolllama_dfs_retriever' - # output_dir = 'result_final/custom_data/gpt4_gt_dfs' - # output_dir = 'result0111/32k_aus/custom_data' - if 'reproduction' in output_dir or 'ori' in output_dir: - reassign = False - else: - reassign = True - # reassign = False - if reassign: - test_ids = list(range(200)) - else: - test_ids = json.load(open(f'data/test_query_ids/{test_set}.json', 'r', encoding='utf-8')) - if 'cot' in output_dir: - method = 'CoT@1' - else: - method = 'DFS_woFilter_w2' - if not os.path.exists(output_dir): - continue - # evaluation_output_dir = f'result2/test_instruction/{test_set}/pass_rate_result_reeval_32k' - # os.system(f'mv {evaluation_output_dir} {output_dir}') - # evaluation_output_dir = f'result2/test_instruction/{test_set}/pass_rate_result_35' - # evaluation_output_dir = f'{output_dir}/pass_rate_result_reeval_32k_3times_nounsure_aus_r1' - # evaluation_output_dir = f'{output_dir}/pass_rate_result_reeval_32k_r1' - # final - evaluation_output_dir = f'{output_dir}/pass_rate_result_reeval_32k_3times' - # evaluation_output_dir = f'{output_dir}/pass_rate_result_35' - # evaluation_output_dir = f'{output_dir}/pass_rate_result_reeval_32k' - # continue - os.makedirs(evaluation_output_dir, exist_ok=True) - # label_cnt = {} - # answer_dict = {} - if os.path.exists(f"{evaluation_output_dir}/label_cnt.json"): - label_cnt = json.load(open(f"{evaluation_output_dir}/label_cnt.json", "r", encoding="utf-8")) - else: - label_cnt = {} - future = [] - if os.path.exists(f"{evaluation_output_dir}/answer_dict.json"): - answer_dict = json.load(open(f"{evaluation_output_dir}/answer_dict.json", "r", encoding="utf-8")) - else: - answer_dict = {} - # result_data = json.load(open(f'data/reproduction_data/model_predictions/gpt-4-0613_dfs/{test_set}.json', 'r', encoding='utf-8')) - referenced_examples = {} - - with ThreadPoolExecutor(args.max_eval_threads) as pool: - for i in test_ids: - # print(i) - if reassign: - try: - # print(f'{output_dir}/{i}.json') - data = json.load(open(f'{output_dir}/{i}.json', 'r', encoding='utf-8')) - except: - continue - query_id = data['query_id'] - if int(query_id) in unsolvable_list: - continue - else: - query_id = i - if int(query_id) in unsolvable_list: - continue - if 'chatgpt' in output_dir and 'cot' not in output_dir: - data = json.load(open(f'{output_dir}/{i}_ChatGPT_{method}.json', 'r', encoding='utf-8')) - elif 'chatgpt' in output_dir: - data = json.load(open(f'{output_dir}/{i}_{method}.json', 'r', encoding='utf-8')) - else: - data = json.load(open(f'{output_dir}/{i}_{method}.json', 'r', encoding='utf-8')) - - if not reassign: - total_tokens += data['answer_generation']['total_tokens'] - else: - if 'total_tokens' in data: - total_tokens += data['total_tokens'] - if str(query_id) in label_cnt: - continue - if reassign: - # print(i) - if 'last_solve_time' not in data: - try: - data_dict = json.load(open(f'{output_dir}/{i}_DFS_woFilter_w2.json', 'r', encoding='utf-8')) - except: - continue - else: - last_solve_time = data['last_solve_time'] - data_dict = json.load(open(f'{output_dir}/{i}_{last_solve_time}_DFS_woFilter_w2.json', 'r', encoding='utf-8')) - else: - data_dict = data - if not data_dict['answer_generation']['valid_data']: - answer_dict[i] = process_invalid_data(method,data_dict) - else: - answer_dict[i] = process_valid_data(method,data_dict['answer_generation']) - example = answer_dict[i] - # query_id = i - # example['available_tools'] = query_data[str(query_id)]['available_tools'] - referenced_examples[query_id] = example - for _ in range(args.evaluate_times): - future.append(pool.submit( - compute_pass_rate, - query_id, - example, - 'Solvable', - 'Task solvable human label' - )) - for thd in tqdm(as_completed(future),total=len(future),ncols=100): - query_id, task_solvable, is_solved, machine_label, reason, not_hallucinate, tokens = thd.result() - example = referenced_examples[query_id] - query = example["query"] - tool_names = [] - for tool_dict in example["available_tools"]: - tool_name = tool_dict["name"] - tool_names.append(tool_name) - answer_steps, final_step = get_steps(example) - if query_id not in label_cnt: - label_cnt[query_id] = {"passed":0, "failed":0, "unsure":0} - if machine_label == "passed": - label_cnt[query_id]["passed"] += 1 - elif machine_label == "failed": - label_cnt[query_id]["failed"] += 1 - else: - label_cnt[query_id]["unsure"] += 1 - label_cnt[query_id]["query"] = query - label_cnt[query_id]["task_solvable"] = str(task_solvable) - label_cnt[query_id]["tool_names"] = tool_names - label_cnt[query_id]["answer_steps"] = answer_steps - label_cnt[query_id]["final_step"] = final_step - label_cnt[query_id]["is_solved"] = str(is_solved) - label_cnt[query_id]["reason"] = reason - label_cnt[query_id]["not_hallucinate"] = not_hallucinate - json.dump(label_cnt, open(f"{evaluation_output_dir}/label_cnt.json", "w"), ensure_ascii=False, indent=4) - filename = f"{evaluation_output_dir}/label_cnt.csv" - write_results(filename, 'result', label_cnt) - pass_rate = 0 - total_num = 0 - print('#'*100) - for query_id in label_cnt: - if int(query_id) in unsolvable_list: - continue - if label_cnt[query_id]["failed"] <= label_cnt[query_id]["passed"]: - pass_rate += 1 - # if label_cnt[query_id]["unsure"] > 0: - # print('unsure') - total_num += 1 - pass_rate /= total_num - pass_rate_list.append(pass_rate) - average_tokens_list.append(total_tokens/total_num) - print(f"Pass rate: {str(pass_rate)} total num {total_num} average tokens {total_tokens/total_num} {test_set}") - json.dump(answer_dict, open(f"{evaluation_output_dir}/answer_dict.json", "w"), ensure_ascii=False, indent=4) - print('&'.join([str(round(x*100,1)) for x in pass_rate_list]),round(np.mean(pass_rate_list)*100,1)) - print('&'.join([str(round(x,1)) for x in average_tokens_list]),round(np.mean(average_tokens_list),1)) - - - diff --git a/anytool/dfs_gt.py b/anytool/dfs_gt.py index 6a3d2cf..3363bec 100644 --- a/anytool/dfs_gt.py +++ b/anytool/dfs_gt.py @@ -125,7 +125,6 @@ def solve_given_api_main(query, api_list, i, messages=None): result_data['result']['reason'] = reason result['answer_generation']['final_answer'] = json.dumps(result_data['result']) if result['answer_generation']['finish_type'] == 'give_answer' and 'final_answer' in result_data['result'] and result_data['result']['final_answer'] != '': - # and not any(word in str(result['answer_generation']['final_answer']).lower() for word in exclusion_words) solved = True else: solved = False @@ -137,7 +136,6 @@ from arguments import parse_args args = parse_args() output_path = args.output_dir -# output_path = f'{query_dir}/reassign_toolllama_dfs_r1' os.makedirs(output_path, exist_ok=True) dfs_args = dotdict(dict(backbone_model='chatgpt_function', openai_key='', model_path='your_model_path/', tool_root_dir='data/toolenv/tools/', lora=False, lora_path='your_lora_path if lora', max_observation_length=1024, max_source_sequence_length=4096, max_sequence_length=8192, observ_compress_method='truncate', method='DFS_woFilter_w2', input_query_file='data/test_instruction/G1_tool.json', output_answer_file=output_path, toolbench_key=toolbench_key, rapidapi_key='', use_rapidapi_key=False, api_customization=False)) dfs_runner = pipeline_runner(dfs_args) @@ -156,9 +154,6 @@ if __name__ == '__main__': for i in range(262): t_s = time.time() comparison_data = {} - # for file in files: - # if file.endswith('.json'): - # print(file) data_load = json.load(open(f'{args.query_dir}/{i}.json', 'r', encoding='utf-8')) if str(data_load['query_id']) in solved_dict and solved_dict[str(data_load['query_id'])]['solved'] != 'Solved': continue @@ -166,8 +161,6 @@ if __name__ == '__main__': query = data_load['query'] # continue cnt += 1 - # if cnt > 50: - # break if os.path.exists(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json')): data = json.load(open(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json'), 'r', encoding='utf-8')) final_data = json.load(open(os.path.join(output_path, f'{i}.json'), 'r', encoding='utf-8')) diff --git a/anytool/verifier.py b/anytool/verifier.py index 6ba3eef..7b00d39 100644 --- a/anytool/verifier.py +++ b/anytool/verifier.py @@ -25,49 +25,41 @@ def check_task_solvable(query): "content": f"Please check whether the following query is solvable: {query}. Begin!"} ] for i in range(5): - # try: - if True: - t_s = time.time() - response = call_gpt( - messages=messages, - functions=[solvable_finish_function] - ) - print(time.time() - t_s) - tool_calls = response.choices[0].message.tool_calls - print('Thought:', response.choices[0].message.content) - if tool_calls: - for tool_call in tool_calls: - function_name = tool_call.function.name - function_args = tool_call.function.arguments - if function_name == 'Finish': - try: - solvable, reason = Finish(**json.loads(function_args)) - except: - continue - # solvable, reason = Finish(json.loads(function_args)) - - else: + response = call_gpt( + messages=messages, + functions=[solvable_finish_function] + ) + tool_calls = response.choices[0].message.tool_calls + print('Thought:', response.choices[0].message.content) + if tool_calls: + for tool_call in tool_calls: + function_name = tool_call.function.name + function_args = tool_call.function.arguments + if function_name == 'Finish': + try: + solvable, reason = Finish(**json.loads(function_args)) + except: continue - print(solvable, query, file=open('result/solvable.txt', 'a', encoding='utf-6')) - if solvable == 'Unsolvable' and reason is None: - messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'}) - if reason is not None: - print(reason, file=open('result/solvable.txt', 'a', encoding='utf-8')) - else: - reason = '' - return solvable, reason - else: - print('Thought:', response.choices[0].message.content) - continue + # solvable, reason = Finish(json.loads(function_args)) + + else: + continue + print(solvable, query, file=open('result/solvable.txt', 'a', encoding='utf-6')) + if solvable == 'Unsolvable' and reason is None: + messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'}) + if reason is not None: + print(reason, file=open('result/solvable.txt', 'a', encoding='utf-8')) + else: + reason = '' + return solvable, reason + else: + print('Thought:', response.choices[0].message.content) + continue # messages.append({"role": "assistant", "content": response.choices[0].message.get('content', '')}) - # except: - # pass print('No response from the model', file=open('result/solvable.txt', 'a', encoding='utf-8')) return 'No response', 'No response from the model' def check_task_solvable_by_function(query, functions): - # return 'Solvable', '' - # print(functions) messages = [{ "role": "system", "content": CHECK_SOLVABLE_BY_FUNCTION_PROMPT @@ -76,40 +68,33 @@ def check_task_solvable_by_function(query, functions): "content": f"Query: {query}. Available_tools: {functions}. Begin!"} ] for i in range(5): - # try: - if True: - t_s = time.time() - response = call_gpt( - messages=messages, - functions=[solvable_finish_function] - ) - print(time.time() - t_s) - tool_calls = response.choices[0].message.tool_calls - print('Thought:', response.choices[0].message.content) - if tool_calls: - for tool_call in tool_calls: - function_name = tool_call.function.name - function_args = tool_call.function.arguments - if function_name.lower() == 'finish': - try: - solvable, reason = Finish(**json.loads(function_args)) - except: - continue - # solvable, reason = Finish(json.loads(function_args)) - - else: + response = call_gpt( + messages=messages, + functions=[solvable_finish_function] + ) + tool_calls = response.choices[0].message.tool_calls + print('Thought:', response.choices[0].message.content) + if tool_calls: + for tool_call in tool_calls: + function_name = tool_call.function.name + function_args = tool_call.function.arguments + if function_name.lower() == 'finish': + try: + solvable, reason = Finish(**json.loads(function_args)) + except: continue - print(solvable, query, file=open('result/solvable.txt', 'a', encoding='utf-8')) - if solvable == 'Unsolvable' and reason is None: - messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'}) - if reason is not None: - print(reason, file=open('result/solvable.txt', 'a', encoding='utf-8')) - else: - reason = '' - return solvable, reason, response.usage.total_tokens - else: - print('Thought:', response.choices[0].message.content) - continue + # solvable, reason = Finish(json.loads(function_args)) + + else: + continue + if solvable == 'Unsolvable' and reason is None: + messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'}) + if reason is None: + reason = '' + return solvable, reason, response.usage.total_tokens + else: + print('Thought:', response.choices[0].message.content) + continue # messages.append({"role": "assistant", "content": response.choices[0].message.get('content', '')}) # except: # pass @@ -141,13 +126,10 @@ def check_task_solved(query, answer): print(function_name, function_args) if function_name.lower() == 'finish': solvable, reason = Finish(**json.loads(function_args)) - print(solvable, query, file=open('result/solved.txt', 'a', encoding='utf-8')) if solvable == 'Unsolved' and reason is None: messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'}) continue - if reason is not None: - print(reason, file=open('result/solved.txt', 'a', encoding='utf-8')) - else: + if reason is None: reason = '' return solvable, reason @@ -167,11 +149,9 @@ def check_solved_toolbench(output_path, query_id, task_solvable=None, solvable_t example = process_invalid_data(method,data_dict) else: example = process_valid_data(method,data_dict['answer_generation']) - # example['available_tools'] = query_data[str(ori_query_id)]['available_tools'] future = [] answer_dict = {'passed':0, 'failed':0} with ThreadPoolExecutor(32) as pool: - print(task_solvable, solvable_task_reason, file=open(os.path.join(output_dir, 'solvable.txt'), 'a', encoding='utf-8')) for _ in range(3): future.append(pool.submit( compute_pass_rate, diff --git a/openai_utils.py b/openai_utils.py index 2e0524e..988fb5a 100644 --- a/openai_utils.py +++ b/openai_utils.py @@ -60,15 +60,6 @@ def call_gpt(messages, functions=None, **kwargs): **kwargs ) except openai.BadRequestError as e: - # try: - # response = turbo_client.chat.completions.create( - # seed=123, - # model='gpt-4-turbo', - # messages=messages, - # functions=functions - # ) - # except Exception as e: - # raise e raise e except openai.RateLimitError as e: @@ -85,8 +76,6 @@ def call_gpt(messages, functions=None, **kwargs): try: response, t_real = call_gpt_retry(messages_converted, functions) # json_content = response.choices[0].message.content - t = time.time() - t_s - print('minus:', t-t_real, file=open(os.path.join(output_dir, "time.txt"), "a")) # print(response.choices[0].message.function_call) if response.choices[0].finish_reason == 'function_call': response_json = json.loads(response.json()) diff --git a/scripts/main.py b/scripts/main.py index d44c90a..b4a2c3f 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -963,135 +963,132 @@ if __name__ == "__main__": # continue print(f'query: {query}') flag = False - try: - runner = Main_Search_Agent(query) - agents.append(runner) - if multi_thread: - thread = threading.Thread(target=runner.assign_main, args=(query,)) - sem.acquire() - thread.start() - threads.append(thread) + runner = Main_Search_Agent(query) + agents.append(runner) + if multi_thread: + thread = threading.Thread(target=runner.assign_main, args=(query,)) + sem.acquire() + thread.start() + threads.append(thread) + else: + iter_func = runner.assign_main(query) + messages = None + cnt = 0 + solve_data = {} + if multi_thread: + while True: + thread_num = len(threads) + has_thread_alive = False + for thread in threads: + if thread.is_alive(): + has_thread_alive = True + thread.join() + if not has_thread_alive: + break + if error_flag: raise Exception('GPT Call Error') + threads = [] + # refind + check_solved = '' + max_depth = max([agent.depth for agent in agents]) + while not all([agent.finish_search for agent in agents]) or not flag: + if check_solved == 'Solved': + break + max_depth = max([agent.depth for agent in agents]) + depth = max_depth + while all([agent.finish_search for agent in agents if agent.depth == depth]) and depth >= 0: + depth -= 1 + if depth < 0 and flag: + break + agents_to_resume = [agent for agent in agents if not agent.finish_search and agent.depth == depth] + if total_tokens > 200000 and flag: + solved = False + check_solved = 'Timeout' + solve_data = {'result': 'Timeout'} + break + cnt += 1 + failed_reason = None + print(len(global_api_list), file=open(f'{output_dir}/api_list_len.txt', 'a', encoding='utf-8')) + print(global_api_list, file=open(f'{output_dir}/api_list.txt', 'a', encoding='utf-8')) + print('#'*100) + assign_results['api_list'].append(deepcopy(global_api_list)) + if stop or not flag or all([agent.finish_search for agent in agents]) and len(global_api_list) > 0: + flag = True + last_solve_time = cnt + t_s = time.time() + selected_api_list = deepcopy(global_api_list) + + solved, solve_data = solve_given_api_main(query, selected_api_list, f'{query_id}_{cnt}', messages) + + print('solve time:', time.time() - t_s, 'api number:', len(global_api_list),file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8')) + result_list.append(deepcopy(solve_data['result'])) + print(solve_data['result'], solved) + if not solved or any([word in solve_data['result']['final_answer'] for word in exclusion_words]): + check_solved = 'Unsolved' + reason = solve_data['result'] + else: + check_solved, reason, tokens = check_solved_toolbench(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', query_id, task_solvable, solvable_reason) + total_tokens += tokens + print(colored((check_solved, reason), 'red')) + failed_reason = reason + dfs_data = json.load(open(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', 'r', encoding='utf-8')) + total_tokens += dfs_data['answer_generation']['total_tokens'] + solve_tokens += dfs_data['answer_generation']['total_tokens'] + if check_solved == 'Solved': + break + try: + messages = dfs_data['answer_generation']['train_messages'][-1] + except: + messages = None + api_list_to_prune = [] + for standardized_api_name, origin_api in dfs_data['api2origin'].items(): + if standardized_api_name in str(failed_reason): + if origin_api in global_api_list: + api_list_to_prune.append(origin_api) + print(colored(api_list_to_prune, 'red')) + print(len(api_list_to_prune)) + remove_apis(api_list_to_prune) + if len(global_api_list) >= max_api_number: + break + # print(api_list_to_prune, file=open(f'{output_dir}/prune_api_list.txt', 'a', encoding='utf-8')) + stop = False else: - iter_func = runner.assign_main(query) - messages = None - cnt = 0 - solve_data = {} + assert status != 'The current api list can solve the query.' + failed_reason = status + reason_list.append(failed_reason) + print(colored('Refind Begin', 'red')) + print(colored(agents_to_resume, 'red')) + print([agent.finish_search for agent in agents_to_resume]) + + + threads = [] + resume_cnt = 0 + resumed_agents.append([(str(a), a.index) for a in agents_to_resume]) + for agent in reversed(agents_to_resume): + if agent.finish_search: continue + resume_cnt += 1 + agent.failed_reason = str(failed_reason) + print(colored(('resuming', agent, agent.depth), 'red')) + print(colored(('resuming', agent, agent.depth), 'red'), file=open(f'{output_dir}/resume.txt', 'a', encoding='utf-8')) + if multi_thread: + thread = threading.Thread(target=agent.resume_search) + sem.acquire() + thread.start() + threads.append(thread) + else: + agent.resume_search() if multi_thread: while True: thread_num = len(threads) - has_thread_alive = False for thread in threads: if thread.is_alive(): - has_thread_alive = True thread.join() - if not has_thread_alive: + if thread_num == len(threads): break if error_flag: raise Exception('GPT Call Error') - threads = [] - # refind - check_solved = '' - max_depth = max([agent.depth for agent in agents]) - while not all([agent.finish_search for agent in agents]) or not flag: - if check_solved == 'Solved': - break - max_depth = max([agent.depth for agent in agents]) - depth = max_depth - while all([agent.finish_search for agent in agents if agent.depth == depth]) and depth >= 0: - depth -= 1 - if depth < 0 and flag: - break - agents_to_resume = [agent for agent in agents if not agent.finish_search and agent.depth == depth] - if total_tokens > 200000 and flag: - solved = False - check_solved = 'Timeout' - solve_data = {'result': 'Timeout'} - break - cnt += 1 - failed_reason = None - print(len(global_api_list), file=open(f'{output_dir}/api_list_len.txt', 'a', encoding='utf-8')) - print(global_api_list, file=open(f'{output_dir}/api_list.txt', 'a', encoding='utf-8')) - print('#'*100) - assign_results['api_list'].append(deepcopy(global_api_list)) - if stop or not flag or all([agent.finish_search for agent in agents]) and len(global_api_list) > 0: - flag = True - last_solve_time = cnt - t_s = time.time() - selected_api_list = deepcopy(global_api_list) - - solved, solve_data = solve_given_api_main(query, selected_api_list, f'{query_id}_{cnt}', messages) - - print('solve time:', time.time() - t_s, 'api number:', len(global_api_list),file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8')) - result_list.append(deepcopy(solve_data['result'])) - print(solve_data['result'], solved) - if not solved or any([word in solve_data['result']['final_answer'] for word in exclusion_words]): - check_solved = 'Unsolved' - reason = solve_data['result'] - else: - check_solved, reason, tokens = check_solved_toolbench(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', query_id, task_solvable, solvable_reason) - total_tokens += tokens - print(colored((check_solved, reason), 'red')) - failed_reason = reason - dfs_data = json.load(open(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', 'r', encoding='utf-8')) - total_tokens += dfs_data['answer_generation']['total_tokens'] - solve_tokens += dfs_data['answer_generation']['total_tokens'] - if check_solved == 'Solved': - break - try: - messages = dfs_data['answer_generation']['train_messages'][-1] - except: - messages = None - api_list_to_prune = [] - for standardized_api_name, origin_api in dfs_data['api2origin'].items(): - if standardized_api_name in str(failed_reason): - if origin_api in global_api_list: - api_list_to_prune.append(origin_api) - print(colored(api_list_to_prune, 'red')) - print(len(api_list_to_prune)) - remove_apis(api_list_to_prune) - if len(global_api_list) >= max_api_number: - break - # print(api_list_to_prune, file=open(f'{output_dir}/prune_api_list.txt', 'a', encoding='utf-8')) - stop = False - else: - assert status != 'The current api list can solve the query.' - failed_reason = status - reason_list.append(failed_reason) - print(colored('Refind Begin', 'red')) - print(colored(agents_to_resume, 'red')) - print([agent.finish_search for agent in agents_to_resume]) - - - threads = [] - resume_cnt = 0 - resumed_agents.append([(str(a), a.index) for a in agents_to_resume]) - for agent in reversed(agents_to_resume): - if agent.finish_search: continue - resume_cnt += 1 - agent.failed_reason = str(failed_reason) - print(colored(('resuming', agent, agent.depth), 'red')) - print(colored(('resuming', agent, agent.depth), 'red'), file=open(f'{output_dir}/resume.txt', 'a', encoding='utf-8')) - if multi_thread: - thread = threading.Thread(target=agent.resume_search) - sem.acquire() - thread.start() - threads.append(thread) - else: - agent.resume_search() - if multi_thread: - while True: - thread_num = len(threads) - for thread in threads: - if thread.is_alive(): - thread.join() - if thread_num == len(threads): - break - if error_flag: raise Exception('GPT Call Error') - if not stop: - check_if_request_solvable() - print(colored(f'status:{status}', 'red')) - assign_results['stop'].append(stop) - except KeyboardInterrupt as e: - continue + if not stop: + check_if_request_solvable() + print(colored(f'status:{status}', 'red')) + assign_results['stop'].append(stop) assign_results['api_complete'] = flag diff --git a/toolbench/inference/LLM/chatgpt_function_model.py b/toolbench/inference/LLM/chatgpt_function_model.py index 4c8a382..c034638 100644 --- a/toolbench/inference/LLM/chatgpt_function_model.py +++ b/toolbench/inference/LLM/chatgpt_function_model.py @@ -10,25 +10,6 @@ from arguments import parse_args from config import * from openai_utils import call_gpt args = parse_args() -output_dir = args.output_dir -# import importlib -# module = importlib.import_module(args.openai_config_path.replace('.py','')) -# for name in dir(module): -# if not name.startswith('_'): -# globals()[name] = getattr(module, name) -# api_key = globals()['api_key'] -# api_version = globals()['api_version'] -# model_name = globals()['model_name'] -# api_base = globals()['api_base'] -if api_type == "azure": - from openai import AzureOpenAI as Client -else: - from openai import OpenAI as Client -client = Client( - api_key=api_key, - api_version=api_version, - azure_endpoint = api_base - ) @retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3)) def chat_completion_request(key, messages, functions=None,function_call=None,key_pos=None, model="gpt-4-32k",stop=None,process_id=0, **args): use_messages = [] @@ -52,30 +33,14 @@ def chat_completion_request(key, messages, functions=None,function_call=None,key json_data.update({"function_call": function_call}) try: - # if model in ["gpt-3.5-turbo-16k-0613","gpt-4-0613", "gpt-4-deployment","gpt-4-32k", 'gpt-4-turbo']: - # openai.api_key = key - # else: - # raise NotImplementedError - # ts = time.time() - # print(json_data, file=open('output/gpt_io.txt','a')) - # print(time.time()-ts) - ts = time.time() - # json.dump(json_data['messages'], open(os.path.join(output_dir,'messages.json'),'w'), indent=4) openai_response = call_gpt( **json_data, ) - # openai_response = client.chat.completions.create( - # **json_data, - # ) - # print('solve', time.time()-ts, file=open(os.path.join(output_dir,'time.txt'),'a')) - # json_data = json.loads(str(openai_response)) json_data = json.loads(openai_response.json()) json_data["choices"][0]['message'].pop('tool_calls') return json_data except Exception as e: - # print('solve', time.time()-ts, file=open(os.path.join(output_dir,'time.txt'),'a')) - # # json_data = json.loads(str(openai_response)) # json_data = json.loads(openai_response.json()) # json_data["choices"][0]['message'].pop('tool_calls') print("Unable to generate ChatCompletion response")