AnyTool/anytool/dfs_gt.py
2024-02-28 10:18:12 +00:00

221 lines
10 KiB
Python

#encoding:utf-8
import os
from typing import List, Dict, Any
import re
from tqdm import tqdm
import time
from termcolor import colored
from copy import deepcopy
from anytool.api_database_function import *
from anytool.verifier import check_solved_toolbench
import os
from anytool.rapidapi import pipeline_runner
import json
class dotdict(dict):
"""dot.notation access to dictionary attributes"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
# For pipeline environment preparation
def get_white_list(tool_root_dir):
# print(tool_root_dir)
white_list_dir = os.path.join(tool_root_dir)
white_list = {}
for cate in tqdm(os.listdir(white_list_dir)):
if not os.path.isdir(os.path.join(white_list_dir,cate)):
continue
for file in os.listdir(os.path.join(white_list_dir,cate)):
if not file.endswith(".json"):
continue
standard_tool_name = file.split(".")[0]
# print(standard_tool_name)
with open(os.path.join(white_list_dir,cate,file)) as reader:
js_data = json.load(reader)
# print(js_data)
try:
origin_tool_name = js_data["tool_name"]
except:
print('#'*100)
print('error:', 'js_data', js_data[0])
white_list[standardize(origin_tool_name)] = {"description": js_data["tool_description"], "standard_tool_name": standard_tool_name}
return white_list
def standardize(string):
# print(string)
if not isinstance(string, str):
print('*'*100)
print(string)
res = re.compile("[^\\u4e00-\\u9fa5^a-z^A-Z^0-9^_]")
string = res.sub("_", string)
string = re.sub(r"(_)\1+","_", string).lower()
while True:
if len(string) == 0:
return string
if string[0] == "_":
string = string[1:]
else:
break
while True:
if len(string) == 0:
return string
if string[-1] == "_":
string = string[:-1]
else:
break
if string[0].isdigit():
string = "get_" + string
return string
tool_root_dir = "data/toolenv/tools"
white_list = get_white_list(tool_root_dir)
def contain(candidate_list, white_list):
output = []
for cand in candidate_list:
if cand not in white_list.keys():
return False
# print(white_list[cand])
output.append(white_list[cand])
return output
def change_name(name):
change_list = ["from", "class", "return", "false", "true", "id", "and"]
if name in change_list:
name = "is_" + name
return name
def solve_given_api_main(query, api_list, i, messages=None):
answer_dir = dfs_args.output_answer_file
if not os.path.exists(answer_dir):
os.mkdir(answer_dir)
if os.path.exists(os.path.join(answer_dir, f'{i}_DFS_woFilter_w2.json')):
os.remove(os.path.join(answer_dir, f'{i}_DFS_woFilter_w2.json'))
method = dfs_args.method
backbone_model = dfs_runner.backbone_model
data_dict = {}
result_data = {}
data_dict['query'] = query
data_dict['api_list'] = api_list
origin_tool_names = [standardize(cont["tool_name"]) for cont in api_list]
tool_des = contain(origin_tool_names,white_list)
if tool_des == False:
result_data = {'result': 'no tool description'}
return False, result_data
tool_des = [[cont["standard_tool_name"], cont["description"]] for cont in tool_des]
task = (method, backbone_model, i, data_dict, dfs_args, answer_dir, tool_des)
for _ in range(3):
dfs_runner.run(task, messages)
result = json.load(open(os.path.join(answer_dir, f'{i}_DFS_woFilter_w2.json'), 'r', encoding='utf-8'))
try:
result_data['result'] = json.loads(result['answer_generation']['final_answer'])
except:
print(result['answer_generation']['final_answer'])
final_answer_str = result['answer_generation']['final_answer']
return_type = final_answer_str[final_answer_str.find('"return_type": "')+len('"return_type": "'):final_answer_str.find('",')]
result_data['result'] = {
"return_type": return_type,
}
if '"final_answer": "' in final_answer_str:
final_answer = final_answer_str[final_answer_str.find('"final_answer": "')+len('"final_answer": "'):]
result_data['result']['final_answer'] = final_answer
elif return_type == 'give_answer':
continue
if '"reason": "' in final_answer_str:
reason = final_answer_str[final_answer_str.find('"reason": "')+len('"reason": "'):]
result_data['result']['reason'] = reason
result['answer_generation']['final_answer'] = json.dumps(result_data['result'])
if result['answer_generation']['finish_type'] == 'give_answer' and 'final_answer' in result_data['result'] and result_data['result']['final_answer'] != '':
# and not any(word in str(result['answer_generation']['final_answer']).lower() for word in exclusion_words)
solved = True
else:
solved = False
return solved, result_data
result_data['result']['final_answer'] = ''
return False, result_data
from arguments import parse_args
args = parse_args()
output_path = args.output_dir
# output_path = f'{query_dir}/reassign_toolllama_dfs_r1'
os.makedirs(output_path, exist_ok=True)
dfs_args = dotdict(dict(backbone_model='chatgpt_function', openai_key='', model_path='your_model_path/', tool_root_dir='data/toolenv/tools/', lora=False, lora_path='your_lora_path if lora', max_observation_length=1024, max_source_sequence_length=4096, max_sequence_length=8192, observ_compress_method='truncate', method='DFS_woFilter_w2', input_query_file='data/test_instruction/G1_tool.json', output_answer_file=output_path, toolbench_key=toolbench_key, rapidapi_key='', use_rapidapi_key=False, api_customization=False))
dfs_runner = pipeline_runner(dfs_args)
if __name__ == '__main__':
retrieved_api_nums = 10
query_list = []
cnt = 0
success = 0
no_return_type_cnt = 0
failed = []
task_solvable = 'Solvable'
solvable_reason = 'Solvable checked by human'
# for root, dirs, files in os.walk('result/generated_solve_given_api_solvable2'):
solved_dict = json.load(open('solved_dict.json', 'r', encoding='utf-8'))
for i in range(262):
t_s = time.time()
comparison_data = {}
# for file in files:
# if file.endswith('.json'):
# print(file)
data_load = json.load(open(f'{args.query_dir}/{i}.json', 'r', encoding='utf-8'))
if str(data_load['query_id']) in solved_dict and solved_dict[str(data_load['query_id'])]['solved'] != 'Solved':
continue
query = data_load['query']
# continue
cnt += 1
# if cnt > 50:
# break
if os.path.exists(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json')):
data = json.load(open(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json'), 'r', encoding='utf-8'))
final_data = json.load(open(os.path.join(output_path, f'{i}.json'), 'r', encoding='utf-8'))
if data['answer_generation']['finish_type'] != 'give_answer':
print(i)
if 'final_answer' in data['answer_generation'] and not any(word in data['answer_generation']['final_answer'].lower() for word in exclusion_words):
if 'check_solved' in final_data:
check_solved = final_data['check_solved']
reason = final_data['reason']
else:
check_solved, reason = check_solved_toolbench(f'{output_dir}/{i}_DFS_woFilter_w2.json', i, data_load['query_id'], task_solvable, solvable_reason)
if check_solved == 'Solved':
success += 1
else:
check_solved = 'Unsolved'
else:
check_solved = 'Unsolved'
print(output_path, i, file=open(os.path.join(output_path, 'failed.txt'), 'a'))
print(success, cnt, i+1, file=open(os.path.join(output_path, 'success_cnt.txt'), 'a'))
continue
find_api_messages_to_save = []
messages_to_save = []
try:
gt_api_list = [{'category_name': api.get('category_name', ''), 'tool_name':api.get('tool_name', ''),'api_name':api.get('api_name', '') }for api in data_load['api_list'][-1]]
# gt_api_list = [{'category_name': api.get('category_name', ''), 'tool_name':api.get('tool_name', ''),'api_name':api.get('api_name', '') }for api in data_load['gt_api_list']]
comparison_data['gt_api_list'] = gt_api_list
except:
pass
print('#'*100, file=open(os.path.join(output_path, 'time.txt'), 'a'))
solved, result_data = solve_given_api_main(query, data_load['gt_api_list'], i)
if solved:
check_solved, reason, _ = check_solved_toolbench(f'{output_dir}/{i}_DFS_woFilter_w2.json', i, data_load['query_id'], task_solvable, solvable_reason)
if check_solved == 'Solved':
success += 1
else:
check_solved = 'Unsolved'
reason = ''
print(output_path, i, file=open(os.path.join(output_path, 'failed.txt'),'a'))
print(success, cnt, i+1, file=open(os.path.join(output_path, 'success_cnt.txt'), 'a'))
final_data = {}
final_data['query_id'] = data_load['query_id']
final_data['query'] = data_load['query']
final_data['gt_api_list'] = data_load['gt_api_list']
final_data['gt_answer'] = data_load['final_answer']
final_data['result'] = result_data
final_data['check_solved'] = check_solved
final_data['reason'] = reason
json.dump(final_data, open(os.path.join(output_path, f'{i}.json'), 'w', encoding='utf-8'), ensure_ascii=False, indent=4)
print(f'time: {time.time()-t_s}', file=open(os.path.join(output_path, 'time.txt'),'a'))