231 lines
10 KiB
Python
231 lines
10 KiB
Python
#encoding:utf-8
|
|
|
|
import openai
|
|
import os
|
|
from typing import List, Dict, Any
|
|
import re
|
|
from tqdm import tqdm
|
|
import time
|
|
import requests
|
|
from termcolor import colored
|
|
from copy import deepcopy
|
|
from api_database_function import *
|
|
from verifier import check_solved_toolbench
|
|
import os
|
|
from rapidapi import pipeline_runner
|
|
|
|
from typing import Any, Callable
|
|
from openai_function_calling import FunctionInferer
|
|
import openai
|
|
import json
|
|
# query_data = json.load(open('G1_instruction_query_failed.json', 'r', encoding='utf-8'))
|
|
# Define example functions.
|
|
from flask import Flask, jsonify, request
|
|
|
|
class dotdict(dict):
|
|
"""dot.notation access to dictionary attributes"""
|
|
__getattr__ = dict.get
|
|
__setattr__ = dict.__setitem__
|
|
__delattr__ = dict.__delitem__
|
|
# For pipeline environment preparation
|
|
def get_white_list(tool_root_dir):
|
|
# print(tool_root_dir)
|
|
white_list_dir = os.path.join(tool_root_dir)
|
|
white_list = {}
|
|
for cate in tqdm(os.listdir(white_list_dir)):
|
|
if not os.path.isdir(os.path.join(white_list_dir,cate)):
|
|
continue
|
|
for file in os.listdir(os.path.join(white_list_dir,cate)):
|
|
if not file.endswith(".json"):
|
|
continue
|
|
standard_tool_name = file.split(".")[0]
|
|
# print(standard_tool_name)
|
|
with open(os.path.join(white_list_dir,cate,file)) as reader:
|
|
js_data = json.load(reader)
|
|
# print(js_data)
|
|
try:
|
|
origin_tool_name = js_data["tool_name"]
|
|
except:
|
|
print('#'*100)
|
|
print('error:', 'js_data', js_data[0])
|
|
|
|
white_list[standardize(origin_tool_name)] = {"description": js_data["tool_description"], "standard_tool_name": standard_tool_name}
|
|
return white_list
|
|
|
|
def standardize(string):
|
|
# print(string)
|
|
if not isinstance(string, str):
|
|
print('*'*100)
|
|
print(string)
|
|
res = re.compile("[^\\u4e00-\\u9fa5^a-z^A-Z^0-9^_]")
|
|
string = res.sub("_", string)
|
|
string = re.sub(r"(_)\1+","_", string).lower()
|
|
while True:
|
|
if len(string) == 0:
|
|
return string
|
|
if string[0] == "_":
|
|
string = string[1:]
|
|
else:
|
|
break
|
|
while True:
|
|
if len(string) == 0:
|
|
return string
|
|
if string[-1] == "_":
|
|
string = string[:-1]
|
|
else:
|
|
break
|
|
if string[0].isdigit():
|
|
string = "get_" + string
|
|
return string
|
|
|
|
tool_root_dir = "data/toolenv/tools"
|
|
white_list = get_white_list(tool_root_dir)
|
|
def contain(candidate_list, white_list):
|
|
output = []
|
|
for cand in candidate_list:
|
|
if cand not in white_list.keys():
|
|
return False
|
|
# print(white_list[cand])
|
|
output.append(white_list[cand])
|
|
return output
|
|
def change_name(name):
|
|
change_list = ["from", "class", "return", "false", "true", "id", "and"]
|
|
if name in change_list:
|
|
name = "is_" + name
|
|
return name
|
|
|
|
def solve_given_api_main(query, api_list, i, messages=None):
|
|
answer_dir = dfs_args.output_answer_file
|
|
if not os.path.exists(answer_dir):
|
|
os.mkdir(answer_dir)
|
|
if os.path.exists(os.path.join(answer_dir, f'{i}_DFS_woFilter_w2.json')):
|
|
os.remove(os.path.join(answer_dir, f'{i}_DFS_woFilter_w2.json'))
|
|
method = dfs_args.method
|
|
backbone_model = dfs_runner.backbone_model
|
|
data_dict = {}
|
|
result_data = {}
|
|
data_dict['query'] = query
|
|
data_dict['api_list'] = api_list
|
|
origin_tool_names = [standardize(cont["tool_name"]) for cont in api_list]
|
|
tool_des = contain(origin_tool_names,white_list)
|
|
if tool_des == False:
|
|
result_data = {'result': 'no tool description'}
|
|
return False, result_data
|
|
tool_des = [[cont["standard_tool_name"], cont["description"]] for cont in tool_des]
|
|
task = (method, backbone_model, i, data_dict, dfs_args, answer_dir, tool_des)
|
|
for _ in range(3):
|
|
dfs_runner.run(task, messages)
|
|
result = json.load(open(os.path.join(answer_dir, f'{i}_DFS_woFilter_w2.json'), 'r', encoding='utf-8'))
|
|
try:
|
|
result_data['result'] = json.loads(result['answer_generation']['final_answer'])
|
|
except:
|
|
print(result['answer_generation']['final_answer'])
|
|
final_answer_str = result['answer_generation']['final_answer']
|
|
return_type = final_answer_str[final_answer_str.find('"return_type": "')+len('"return_type": "'):final_answer_str.find('",')]
|
|
result_data['result'] = {
|
|
"return_type": return_type,
|
|
}
|
|
if '"final_answer": "' in final_answer_str:
|
|
final_answer = final_answer_str[final_answer_str.find('"final_answer": "')+len('"final_answer": "'):]
|
|
result_data['result']['final_answer'] = final_answer
|
|
elif return_type == 'give_answer':
|
|
continue
|
|
if '"reason": "' in final_answer_str:
|
|
reason = final_answer_str[final_answer_str.find('"reason": "')+len('"reason": "'):]
|
|
result_data['result']['reason'] = reason
|
|
result['answer_generation']['final_answer'] = json.dumps(result_data['result'])
|
|
if result['answer_generation']['finish_type'] == 'give_answer' and 'final_answer' in result_data['result'] and result_data['result']['final_answer'] != '':
|
|
# and not any(word in str(result['answer_generation']['final_answer']).lower() for word in exclusion_words)
|
|
solved = True
|
|
else:
|
|
solved = False
|
|
return solved, result_data
|
|
result_data['result']['final_answer'] = ''
|
|
return False, result_data
|
|
|
|
from arguments import parse_args
|
|
args = parse_args()
|
|
output_path = args.output_dir
|
|
|
|
# output_path = f'{query_dir}/reassign_toolllama_dfs_r1'
|
|
os.makedirs(output_path, exist_ok=True)
|
|
dfs_args = dotdict(dict(backbone_model='chatgpt_function', openai_key='', model_path='your_model_path/', tool_root_dir='data/toolenv/tools/', lora=False, lora_path='your_lora_path if lora', max_observation_length=1024, max_source_sequence_length=4096, max_sequence_length=8192, observ_compress_method='truncate', method='DFS_woFilter_w2', input_query_file='data/test_instruction/G1_tool.json', output_answer_file=output_path, toolbench_key=toolbench_key, rapidapi_key='', use_rapidapi_key=False, api_customization=False))
|
|
dfs_runner = pipeline_runner(dfs_args)
|
|
|
|
if __name__ == '__main__':
|
|
retrieved_api_nums = 10
|
|
query_list = []
|
|
cnt = 0
|
|
success = 0
|
|
no_return_type_cnt = 0
|
|
failed = []
|
|
task_solvable = 'Solvable'
|
|
solvable_reason = 'Solvable checked by human'
|
|
# for root, dirs, files in os.walk('result/generated_solve_given_api_solvable2'):
|
|
solved_dict = json.load(open('solved_dict.json', 'r', encoding='utf-8'))
|
|
for i in range(262):
|
|
t_s = time.time()
|
|
comparison_data = {}
|
|
# for file in files:
|
|
# if file.endswith('.json'):
|
|
# print(file)
|
|
data_load = json.load(open(f'{args.query_dir}/{i}.json', 'r', encoding='utf-8'))
|
|
if str(data_load['query_id']) in solved_dict and solved_dict[str(data_load['query_id'])]['solved'] != 'Solved':
|
|
continue
|
|
|
|
query = data_load['query']
|
|
# continue
|
|
cnt += 1
|
|
# if cnt > 50:
|
|
# break
|
|
if os.path.exists(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json')):
|
|
data = json.load(open(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json'), 'r', encoding='utf-8'))
|
|
final_data = json.load(open(os.path.join(output_path, f'{i}.json'), 'r', encoding='utf-8'))
|
|
if data['answer_generation']['finish_type'] != 'give_answer':
|
|
print(i)
|
|
if 'final_answer' in data['answer_generation'] and not any(word in data['answer_generation']['final_answer'].lower() for word in exclusion_words):
|
|
if 'check_solved' in final_data:
|
|
check_solved = final_data['check_solved']
|
|
reason = final_data['reason']
|
|
else:
|
|
check_solved, reason = check_solved_toolbench(f'{output_dir}/{i}_DFS_woFilter_w2.json', i, data_load['query_id'], task_solvable, solvable_reason)
|
|
if check_solved == 'Solved':
|
|
success += 1
|
|
else:
|
|
check_solved = 'Unsolved'
|
|
else:
|
|
check_solved = 'Unsolved'
|
|
print(output_path, i, file=open(os.path.join(output_path, 'failed.txt'), 'a'))
|
|
print(success, cnt, i+1, file=open(os.path.join(output_path, 'success_cnt.txt'), 'a'))
|
|
continue
|
|
find_api_messages_to_save = []
|
|
messages_to_save = []
|
|
try:
|
|
gt_api_list = [{'category_name': api.get('category_name', ''), 'tool_name':api.get('tool_name', ''),'api_name':api.get('api_name', '') }for api in data_load['api_list'][-1]]
|
|
# gt_api_list = [{'category_name': api.get('category_name', ''), 'tool_name':api.get('tool_name', ''),'api_name':api.get('api_name', '') }for api in data_load['gt_api_list']]
|
|
comparison_data['gt_api_list'] = gt_api_list
|
|
except:
|
|
pass
|
|
print('#'*100, file=open(os.path.join(output_path, 'time.txt'), 'a'))
|
|
solved, result_data = solve_given_api_main(query, data_load['gt_api_list'], i)
|
|
if solved:
|
|
check_solved, reason, _ = check_solved_toolbench(f'{output_dir}/{i}_DFS_woFilter_w2.json', i, data_load['query_id'], task_solvable, solvable_reason)
|
|
if check_solved == 'Solved':
|
|
success += 1
|
|
else:
|
|
check_solved = 'Unsolved'
|
|
reason = ''
|
|
print(output_path, i, file=open(os.path.join(output_path, 'failed.txt'),'a'))
|
|
print(success, cnt, i+1, file=open(os.path.join(output_path, 'success_cnt.txt'), 'a'))
|
|
final_data = {}
|
|
final_data['query_id'] = data_load['query_id']
|
|
final_data['query'] = data_load['query']
|
|
final_data['gt_api_list'] = data_load['gt_api_list']
|
|
final_data['gt_answer'] = data_load['final_answer']
|
|
final_data['result'] = result_data
|
|
final_data['check_solved'] = check_solved
|
|
final_data['reason'] = reason
|
|
json.dump(final_data, open(os.path.join(output_path, f'{i}.json'), 'w', encoding='utf-8'), ensure_ascii=False, indent=4)
|
|
|
|
print(f'time: {time.time()-t_s}', file=open(os.path.join(output_path, 'time.txt'),'a'))
|
|
|