This commit is contained in:
dyabel 2024-03-02 16:25:20 +00:00
parent 5b1cba3d75
commit 2ac6a4c53d
6 changed files with 426 additions and 1226 deletions

View file

@ -25,8 +25,7 @@ Fill out the [form](https://docs.google.com/forms/d/e/1FAIpQLSdqHypmYanWU8ZhuUcr
**ToolBench**
Download the ToolBench data from either of the following links: [Google Drive](https://drive.google.com/drive/folders/1yBUQ732mPu-KclJnuQELEhtKakdXFc3J) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/c9e50625743b40bfbe10/).
The file structure is as follows:
Decompress `data.zip`; the resulting file structure is as follows:
```
├── /data/
│ ├── /instruction/

View file

@ -306,9 +306,8 @@ Please check whether the given task solvable with following rules:
1. If the `query` provide invalid information (e.g. invalid email address or phone number), return "Unsolvable"
2. If the `query` needs more information to solve (e.g. the target restaurant name in a navigation task), return "Unsolvable"
3. If you are unable to draw a conclusion, return "Unsure"
5. Otherwise, return "Solvable"
4. Otherwise, return "Solvable"
Remember, you should assume you have all the tools to solve the query but you do not need to answer the query at this time.
You must call the Finish function at one step.
"""
# 4. If the query is illegal or unethical or sensitive, return "Unsure"

File diff suppressed because it is too large Load diff

View file

@ -27,6 +27,4 @@ def extract_tool_data():
tool_data[root.split('/')[-1]][tool_name] = {"tool_description": tool_description}
return tool_data
tool_data = extract_tool_data()
print(tool_data.keys())
json.dump(tool_data, open("category_tool_details.json", "w", encoding='utf-8'), indent=4)
# json.dump(tool_data, open("category_tool_details_add_nonfree.json", "w", encoding='utf-8'), indent=4)

View file

@ -1 +1 @@
[]
["xxxxx"]

View file

@ -899,9 +899,9 @@ if __name__ == "__main__":
os.makedirs(output_dir, exist_ok=True)
os.makedirs('output', exist_ok=True)
success_cnt = 0
pass_cnt = 0
unsolvable_task_cnt = 0
unsolvable_list = json.load(open('misc/unsolvable.json', 'r', encoding='utf-8'))
json.dump(sorted(list(set(unsolvable_list))), open('misc/unsolvable.json', 'w', encoding='utf-8'), indent=4)
total_cnt = 0
query_data_all = json.load(open(query_path, 'r', encoding='utf-8'))
for query_data in query_data_all:
@ -930,11 +930,9 @@ if __name__ == "__main__":
assign_results['stop'] = []
ts = time.time()
resumed_agents = []
print(query_id, query, file=open(f'{output_dir}/query.txt', 'a', encoding='utf-8'))
if not args.include_unsolvable and int(query_id) in unsolvable_list:
unsolvable_task_cnt += 1
print(unsolvable_task_cnt)
print('Unsolvable human', unsolvable_task_cnt, pass_cnt, success_cnt, total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
print('Unsolvable human', unsolvable_task_cnt, success_cnt, total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
continue
total_cnt += 1
task_solvable = 'Solvable'
@ -945,8 +943,6 @@ if __name__ == "__main__":
solved = assign_results['solved']
check_solved = assign_results['check_solved']
last_solve_time = assign_results['last_solve_time']
if solved:
pass_cnt += 1
if args.recheck_solved:
check_solved, reason, _ = check_solved_toolbench(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', assign_results['query_id'])
assign_results['check_solved'] = check_solved
@ -1007,8 +1003,6 @@ if __name__ == "__main__":
break
cnt += 1
failed_reason = None
print(len(global_api_list), file=open(f'{output_dir}/api_list_len.txt', 'a', encoding='utf-8'))
print(global_api_list, file=open(f'{output_dir}/api_list.txt', 'a', encoding='utf-8'))
print('#'*100)
assign_results['api_list'].append(deepcopy(global_api_list))
if stop or not flag or all([agent.finish_search for agent in agents]) and len(global_api_list) > 0:
@ -1045,11 +1039,9 @@ if __name__ == "__main__":
if origin_api in global_api_list:
api_list_to_prune.append(origin_api)
print(colored(api_list_to_prune, 'red'))
print(len(api_list_to_prune))
remove_apis(api_list_to_prune)
if len(global_api_list) >= max_api_number:
break
# print(api_list_to_prune, file=open(f'{output_dir}/prune_api_list.txt', 'a', encoding='utf-8'))
stop = False
else:
assert status != 'The current api list can solve the query.'
@ -1108,8 +1100,6 @@ if __name__ == "__main__":
success_cnt += 1
else:
print(output_dir, 'failed', file=open(f'{output_dir}/failed.txt', 'a', encoding='utf-8'))
if solved:
pass_cnt += 1
assign_results['loop_times'] = cnt
assign_results['last_solve_time'] = last_solve_time
if 'messages' in solve_data:
@ -1144,7 +1134,7 @@ if __name__ == "__main__":
assign_results['check_solved'] = check_solved
json.dump(assign_results, open(f'{output_dir}/{query_id}.json', 'w', encoding='utf-8'), indent=4)
print(check_solved, total_tokens, time.time() - ts, query_path, file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8'))
print(query_id, task_solvable, cnt, check_solved, unsolvable_task_cnt, pass_cnt, success_cnt, total_cnt, success_cnt/total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
print(query_id, check_solved, success_cnt, total_cnt, success_cnt/total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
except Exception as e:
print(e)
continue