This commit is contained in:
dyabel 2024-03-02 16:25:20 +00:00
parent 5b1cba3d75
commit 2ac6a4c53d
6 changed files with 426 additions and 1226 deletions

View file

@ -25,8 +25,7 @@ Fill out the [form](https://docs.google.com/forms/d/e/1FAIpQLSdqHypmYanWU8ZhuUcr
**ToolBench** **ToolBench**
Download the ToolBench data using the following link: [Google Drive](https://drive.google.com/drive/folders/1yBUQ732mPu-KclJnuQELEhtKakdXFc3J) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/c9e50625743b40bfbe10/). Download the ToolBench data using the following link: [Google Drive](https://drive.google.com/drive/folders/1yBUQ732mPu-KclJnuQELEhtKakdXFc3J) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/c9e50625743b40bfbe10/).
Decompress the data.zip and the file structure is as follows:
The file structure is as follows:
``` ```
├── /data/ ├── /data/
│ ├── /instruction/ │ ├── /instruction/

View file

@ -306,9 +306,8 @@ Please check whether the given task solvable with following rules:
1. If the `query` provide invalid information (e.g. invalid email address or phone number), return "Unsolvable" 1. If the `query` provide invalid information (e.g. invalid email address or phone number), return "Unsolvable"
2. If the `query` needs more information to solve (e.g. the target restaurant name in a navigation task), return "Unsolvable" 2. If the `query` needs more information to solve (e.g. the target restaurant name in a navigation task), return "Unsolvable"
3. If you are unable to draw a conclusion, return "Unsure" 3. If you are unable to draw a conclusion, return "Unsure"
5. Otherwise, return "Solvable" 4. Otherwise, return "Solvable"
Remember, you should assume you have all the tools to solve the query but you do not need to answer the query at this time. Remember, you should assume you have all the tools to solve the query but you do not need to answer the query at this time.
You must call the Finish function at one step. You must call the Finish function at one step.
""" """
# 4. If the query is illegal or unethical or sensitive, return "Unsure"

File diff suppressed because it is too large Load diff

View file

@ -27,6 +27,4 @@ def extract_tool_data():
tool_data[root.split('/')[-1]][tool_name] = {"tool_description": tool_description} tool_data[root.split('/')[-1]][tool_name] = {"tool_description": tool_description}
return tool_data return tool_data
tool_data = extract_tool_data() tool_data = extract_tool_data()
print(tool_data.keys())
json.dump(tool_data, open("category_tool_details.json", "w", encoding='utf-8'), indent=4) json.dump(tool_data, open("category_tool_details.json", "w", encoding='utf-8'), indent=4)
# json.dump(tool_data, open("category_tool_details_add_nonfree.json", "w", encoding='utf-8'), indent=4)

View file

@ -1 +1 @@
[] ["xxxxx"]

View file

@ -899,9 +899,9 @@ if __name__ == "__main__":
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
os.makedirs('output', exist_ok=True) os.makedirs('output', exist_ok=True)
success_cnt = 0 success_cnt = 0
pass_cnt = 0
unsolvable_task_cnt = 0 unsolvable_task_cnt = 0
unsolvable_list = json.load(open('misc/unsolvable.json', 'r', encoding='utf-8')) unsolvable_list = json.load(open('misc/unsolvable.json', 'r', encoding='utf-8'))
json.dump(sorted(list(set(unsolvable_list))), open('misc/unsolvable.json', 'w', encoding='utf-8'), indent=4)
total_cnt = 0 total_cnt = 0
query_data_all = json.load(open(query_path, 'r', encoding='utf-8')) query_data_all = json.load(open(query_path, 'r', encoding='utf-8'))
for query_data in query_data_all: for query_data in query_data_all:
@ -930,11 +930,9 @@ if __name__ == "__main__":
assign_results['stop'] = [] assign_results['stop'] = []
ts = time.time() ts = time.time()
resumed_agents = [] resumed_agents = []
print(query_id, query, file=open(f'{output_dir}/query.txt', 'a', encoding='utf-8'))
if not args.include_unsolvable and int(query_id) in unsolvable_list: if not args.include_unsolvable and int(query_id) in unsolvable_list:
unsolvable_task_cnt += 1 unsolvable_task_cnt += 1
print(unsolvable_task_cnt) print('Unsolvable human', unsolvable_task_cnt, success_cnt, total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
print('Unsolvable human', unsolvable_task_cnt, pass_cnt, success_cnt, total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
continue continue
total_cnt += 1 total_cnt += 1
task_solvable = 'Solvable' task_solvable = 'Solvable'
@ -945,8 +943,6 @@ if __name__ == "__main__":
solved = assign_results['solved'] solved = assign_results['solved']
check_solved = assign_results['check_solved'] check_solved = assign_results['check_solved']
last_solve_time = assign_results['last_solve_time'] last_solve_time = assign_results['last_solve_time']
if solved:
pass_cnt += 1
if args.recheck_solved: if args.recheck_solved:
check_solved, reason, _ = check_solved_toolbench(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', assign_results['query_id']) check_solved, reason, _ = check_solved_toolbench(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', assign_results['query_id'])
assign_results['check_solved'] = check_solved assign_results['check_solved'] = check_solved
@ -1007,8 +1003,6 @@ if __name__ == "__main__":
break break
cnt += 1 cnt += 1
failed_reason = None failed_reason = None
print(len(global_api_list), file=open(f'{output_dir}/api_list_len.txt', 'a', encoding='utf-8'))
print(global_api_list, file=open(f'{output_dir}/api_list.txt', 'a', encoding='utf-8'))
print('#'*100) print('#'*100)
assign_results['api_list'].append(deepcopy(global_api_list)) assign_results['api_list'].append(deepcopy(global_api_list))
if stop or not flag or all([agent.finish_search for agent in agents]) and len(global_api_list) > 0: if stop or not flag or all([agent.finish_search for agent in agents]) and len(global_api_list) > 0:
@ -1045,11 +1039,9 @@ if __name__ == "__main__":
if origin_api in global_api_list: if origin_api in global_api_list:
api_list_to_prune.append(origin_api) api_list_to_prune.append(origin_api)
print(colored(api_list_to_prune, 'red')) print(colored(api_list_to_prune, 'red'))
print(len(api_list_to_prune))
remove_apis(api_list_to_prune) remove_apis(api_list_to_prune)
if len(global_api_list) >= max_api_number: if len(global_api_list) >= max_api_number:
break break
# print(api_list_to_prune, file=open(f'{output_dir}/prune_api_list.txt', 'a', encoding='utf-8'))
stop = False stop = False
else: else:
assert status != 'The current api list can solve the query.' assert status != 'The current api list can solve the query.'
@ -1108,8 +1100,6 @@ if __name__ == "__main__":
success_cnt += 1 success_cnt += 1
else: else:
print(output_dir, 'failed', file=open(f'{output_dir}/failed.txt', 'a', encoding='utf-8')) print(output_dir, 'failed', file=open(f'{output_dir}/failed.txt', 'a', encoding='utf-8'))
if solved:
pass_cnt += 1
assign_results['loop_times'] = cnt assign_results['loop_times'] = cnt
assign_results['last_solve_time'] = last_solve_time assign_results['last_solve_time'] = last_solve_time
if 'messages' in solve_data: if 'messages' in solve_data:
@ -1144,7 +1134,7 @@ if __name__ == "__main__":
assign_results['check_solved'] = check_solved assign_results['check_solved'] = check_solved
json.dump(assign_results, open(f'{output_dir}/{query_id}.json', 'w', encoding='utf-8'), indent=4) json.dump(assign_results, open(f'{output_dir}/{query_id}.json', 'w', encoding='utf-8'), indent=4)
print(check_solved, total_tokens, time.time() - ts, query_path, file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8')) print(check_solved, total_tokens, time.time() - ts, query_path, file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8'))
print(query_id, task_solvable, cnt, check_solved, unsolvable_task_cnt, pass_cnt, success_cnt, total_cnt, success_cnt/total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8')) print(query_id, check_solved, success_cnt, total_cnt, success_cnt/total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
except Exception as e: except Exception as e:
print(e) print(e)
continue continue