clean
This commit is contained in:
parent
5b1cba3d75
commit
2ac6a4c53d
6 changed files with 426 additions and 1226 deletions
|
|
@ -25,8 +25,7 @@ Fill out the [form](https://docs.google.com/forms/d/e/1FAIpQLSdqHypmYanWU8ZhuUcr
|
|||
**ToolBench**
|
||||
|
||||
Download the ToolBench data using either of the following links: [Google Drive](https://drive.google.com/drive/folders/1yBUQ732mPu-KclJnuQELEhtKakdXFc3J) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/c9e50625743b40bfbe10/).
|
||||
|
||||
The file structure is as follows:
|
||||
Decompress data.zip; the resulting file structure is as follows:
|
||||
```
|
||||
├── /data/
|
||||
│ ├── /instruction/
|
||||
|
|
|
|||
|
|
@ -306,9 +306,8 @@ Please check whether the given task solvable with following rules:
|
|||
1. If the `query` provide invalid information (e.g. invalid email address or phone number), return "Unsolvable"
|
||||
2. If the `query` needs more information to solve (e.g. the target restaurant name in a navigation task), return "Unsolvable"
|
||||
3. If you are unable to draw a conclusion, return "Unsure"
|
||||
5. Otherwise, return "Solvable"
|
||||
4. Otherwise, return "Solvable"
|
||||
Remember, you should assume you have all the tools to solve the query but you do not need to answer the query at this time.
|
||||
|
||||
You must call the Finish function at one step.
|
||||
"""
|
||||
# 4. If the query is illegal or unethical or sensitive, return "Unsure"
|
||||
"""
|
||||
1624
misc/unsolvable.json
1624
misc/unsolvable.json
File diff suppressed because it is too large
Load diff
|
|
@ -27,6 +27,4 @@ def extract_tool_data():
|
|||
tool_data[root.split('/')[-1]][tool_name] = {"tool_description": tool_description}
|
||||
return tool_data
|
||||
tool_data = extract_tool_data()
|
||||
print(tool_data.keys())
|
||||
json.dump(tool_data, open("category_tool_details.json", "w", encoding='utf-8'), indent=4)
|
||||
# json.dump(tool_data, open("category_tool_details_add_nonfree.json", "w", encoding='utf-8'), indent=4)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
[]
|
||||
["xxxxx"]
|
||||
|
|
@ -899,9 +899,9 @@ if __name__ == "__main__":
|
|||
os.makedirs(output_dir, exist_ok=True)
|
||||
os.makedirs('output', exist_ok=True)
|
||||
success_cnt = 0
|
||||
pass_cnt = 0
|
||||
unsolvable_task_cnt = 0
|
||||
unsolvable_list = json.load(open('misc/unsolvable.json', 'r', encoding='utf-8'))
|
||||
json.dump(sorted(list(set(unsolvable_list))), open('misc/unsolvable.json', 'w', encoding='utf-8'), indent=4)
|
||||
total_cnt = 0
|
||||
query_data_all = json.load(open(query_path, 'r', encoding='utf-8'))
|
||||
for query_data in query_data_all:
|
||||
|
|
@ -930,11 +930,9 @@ if __name__ == "__main__":
|
|||
assign_results['stop'] = []
|
||||
ts = time.time()
|
||||
resumed_agents = []
|
||||
print(query_id, query, file=open(f'{output_dir}/query.txt', 'a', encoding='utf-8'))
|
||||
if not args.include_unsolvable and int(query_id) in unsolvable_list:
|
||||
unsolvable_task_cnt += 1
|
||||
print(unsolvable_task_cnt)
|
||||
print('Unsolvable human', unsolvable_task_cnt, pass_cnt, success_cnt, total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
|
||||
print('Unsolvable human', unsolvable_task_cnt, success_cnt, total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
|
||||
continue
|
||||
total_cnt += 1
|
||||
task_solvable = 'Solvable'
|
||||
|
|
@ -945,8 +943,6 @@ if __name__ == "__main__":
|
|||
solved = assign_results['solved']
|
||||
check_solved = assign_results['check_solved']
|
||||
last_solve_time = assign_results['last_solve_time']
|
||||
if solved:
|
||||
pass_cnt += 1
|
||||
if args.recheck_solved:
|
||||
check_solved, reason, _ = check_solved_toolbench(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', assign_results['query_id'])
|
||||
assign_results['check_solved'] = check_solved
|
||||
|
|
@ -1007,8 +1003,6 @@ if __name__ == "__main__":
|
|||
break
|
||||
cnt += 1
|
||||
failed_reason = None
|
||||
print(len(global_api_list), file=open(f'{output_dir}/api_list_len.txt', 'a', encoding='utf-8'))
|
||||
print(global_api_list, file=open(f'{output_dir}/api_list.txt', 'a', encoding='utf-8'))
|
||||
print('#'*100)
|
||||
assign_results['api_list'].append(deepcopy(global_api_list))
|
||||
if stop or not flag or all([agent.finish_search for agent in agents]) and len(global_api_list) > 0:
|
||||
|
|
@ -1045,11 +1039,9 @@ if __name__ == "__main__":
|
|||
if origin_api in global_api_list:
|
||||
api_list_to_prune.append(origin_api)
|
||||
print(colored(api_list_to_prune, 'red'))
|
||||
print(len(api_list_to_prune))
|
||||
remove_apis(api_list_to_prune)
|
||||
if len(global_api_list) >= max_api_number:
|
||||
break
|
||||
# print(api_list_to_prune, file=open(f'{output_dir}/prune_api_list.txt', 'a', encoding='utf-8'))
|
||||
stop = False
|
||||
else:
|
||||
assert status != 'The current api list can solve the query.'
|
||||
|
|
@ -1108,8 +1100,6 @@ if __name__ == "__main__":
|
|||
success_cnt += 1
|
||||
else:
|
||||
print(output_dir, 'failed', file=open(f'{output_dir}/failed.txt', 'a', encoding='utf-8'))
|
||||
if solved:
|
||||
pass_cnt += 1
|
||||
assign_results['loop_times'] = cnt
|
||||
assign_results['last_solve_time'] = last_solve_time
|
||||
if 'messages' in solve_data:
|
||||
|
|
@ -1144,7 +1134,7 @@ if __name__ == "__main__":
|
|||
assign_results['check_solved'] = check_solved
|
||||
json.dump(assign_results, open(f'{output_dir}/{query_id}.json', 'w', encoding='utf-8'), indent=4)
|
||||
print(check_solved, total_tokens, time.time() - ts, query_path, file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8'))
|
||||
print(query_id, task_solvable, cnt, check_solved, unsolvable_task_cnt, pass_cnt, success_cnt, total_cnt, success_cnt/total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
|
||||
print(query_id, check_solved, success_cnt, total_cnt, success_cnt/total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
continue
|
||||
Loading…
Reference in a new issue