clean
This commit is contained in:
parent
1182893a23
commit
5b1cba3d75
6 changed files with 173 additions and 432 deletions
|
|
@ -96,186 +96,3 @@ def compute_pass_rate(query_id, example, task_solvable=None, task_solvable_reaso
|
||||||
else:
|
else:
|
||||||
label = "failed"
|
label = "failed"
|
||||||
return query_id, task_solvable, is_solved, label, reason, not_hallucinate, tokens
|
return query_id, task_solvable, is_solved, label, reason, not_hallucinate, tokens
|
||||||
# output_dir = f'result1/generated_solve_given_api_solvable_multicat_complex_r1/stack_reassign_solve_results_turbo_r16'
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# reassign = False
|
|
||||||
test_sets = ["G1_instruction", "G1_tool", "G1_category", "G2_instruction", "G2_category", "G3_instruction"]
|
|
||||||
# test_sets = ["G1_tool", "G1_category", "G2_instruction", "G2_category", "G3_instruction"]
|
|
||||||
# test_sets = ['custom_data']
|
|
||||||
# test_sets = ['G1_instruction']
|
|
||||||
# test_sets = [ "G1_category","G1_instruction", "G1_tool", "G2_instruction"]
|
|
||||||
# test_sets = ['G2_instruction', 'G3_instruction']
|
|
||||||
unsolvable_list = json.load(open("unsolvable.json", "r", encoding="utf-8"))
|
|
||||||
# unsolvable_list = []
|
|
||||||
pass_rate_list = []
|
|
||||||
average_tokens_list = []
|
|
||||||
for test_set in test_sets:
|
|
||||||
total_tokens = 0
|
|
||||||
# query_dir = f'data/test_instruction/{test_set}'
|
|
||||||
# output_dir = f'result2/test_instruction/{test_set}'
|
|
||||||
# output_dir = f'result0111/turbo/test_instruction/{test_set}_r1'
|
|
||||||
# output_dir = f'data/reproduction_data/model_predictions/chatgpt_dfs/{test_set}'
|
|
||||||
# output_dir = f'data/reproduction_data/model_predictions/toolllama_dfs/{test_set}'
|
|
||||||
output_dir = f'data/reproduction_data/model_predictions/toolllama_dfs_retriever/{test_set}'
|
|
||||||
# output_dir = f'data/reproduction_data/model_predictions/gpt-4-0613_dfs/{test_set}'
|
|
||||||
# output_dir = f'data/reproduction_data/model_predictions/chatgpt_cot/{test_set}'
|
|
||||||
# output_dir = f'data/reproduction_data/model_predictions/gpt-4-0613_cot/{test_set}'
|
|
||||||
# 33.5&33.5&41.0&23.5&29.5&3.0 27.3
|
|
||||||
# output_dir = f'result0111/32k/test_instruction/{test_set}_r1'
|
|
||||||
# output_dir = f'result0111/32k/max32/test_instruction/{test_set}_r1'
|
|
||||||
# output_dir = f'result_final/toolbench/{test_set}'
|
|
||||||
# output_dir = f'result0126/toolbench/{test_set}'
|
|
||||||
# output_dir = f'repos/toolbench_ori/{test_set}_filtered/gpt4_retriever_dfs'
|
|
||||||
# output_dir = f'repos/toolbench_ori/{test_set}_filtered/toolllama_retriever_ada_dfs'
|
|
||||||
# output_dir = f'data/reproduction_data/model_predictions/gpt-35-turbo_dfs/{test_set}'
|
|
||||||
# output_dir = 'result_final/custom_data/gpt_dfs_retriever'
|
|
||||||
# output_dir = 'result_final/custom_data/toolllama_dfs_retriever'
|
|
||||||
# output_dir = 'result_final/custom_data/gpt4_gt_dfs'
|
|
||||||
# output_dir = 'result0111/32k_aus/custom_data'
|
|
||||||
if 'reproduction' in output_dir or 'ori' in output_dir:
|
|
||||||
reassign = False
|
|
||||||
else:
|
|
||||||
reassign = True
|
|
||||||
# reassign = False
|
|
||||||
if reassign:
|
|
||||||
test_ids = list(range(200))
|
|
||||||
else:
|
|
||||||
test_ids = json.load(open(f'data/test_query_ids/{test_set}.json', 'r', encoding='utf-8'))
|
|
||||||
if 'cot' in output_dir:
|
|
||||||
method = 'CoT@1'
|
|
||||||
else:
|
|
||||||
method = 'DFS_woFilter_w2'
|
|
||||||
if not os.path.exists(output_dir):
|
|
||||||
continue
|
|
||||||
# evaluation_output_dir = f'result2/test_instruction/{test_set}/pass_rate_result_reeval_32k'
|
|
||||||
# os.system(f'mv {evaluation_output_dir} {output_dir}')
|
|
||||||
# evaluation_output_dir = f'result2/test_instruction/{test_set}/pass_rate_result_35'
|
|
||||||
# evaluation_output_dir = f'{output_dir}/pass_rate_result_reeval_32k_3times_nounsure_aus_r1'
|
|
||||||
# evaluation_output_dir = f'{output_dir}/pass_rate_result_reeval_32k_r1'
|
|
||||||
# final
|
|
||||||
evaluation_output_dir = f'{output_dir}/pass_rate_result_reeval_32k_3times'
|
|
||||||
# evaluation_output_dir = f'{output_dir}/pass_rate_result_35'
|
|
||||||
# evaluation_output_dir = f'{output_dir}/pass_rate_result_reeval_32k'
|
|
||||||
# continue
|
|
||||||
os.makedirs(evaluation_output_dir, exist_ok=True)
|
|
||||||
# label_cnt = {}
|
|
||||||
# answer_dict = {}
|
|
||||||
if os.path.exists(f"{evaluation_output_dir}/label_cnt.json"):
|
|
||||||
label_cnt = json.load(open(f"{evaluation_output_dir}/label_cnt.json", "r", encoding="utf-8"))
|
|
||||||
else:
|
|
||||||
label_cnt = {}
|
|
||||||
future = []
|
|
||||||
if os.path.exists(f"{evaluation_output_dir}/answer_dict.json"):
|
|
||||||
answer_dict = json.load(open(f"{evaluation_output_dir}/answer_dict.json", "r", encoding="utf-8"))
|
|
||||||
else:
|
|
||||||
answer_dict = {}
|
|
||||||
# result_data = json.load(open(f'data/reproduction_data/model_predictions/gpt-4-0613_dfs/{test_set}.json', 'r', encoding='utf-8'))
|
|
||||||
referenced_examples = {}
|
|
||||||
|
|
||||||
with ThreadPoolExecutor(args.max_eval_threads) as pool:
|
|
||||||
for i in test_ids:
|
|
||||||
# print(i)
|
|
||||||
if reassign:
|
|
||||||
try:
|
|
||||||
# print(f'{output_dir}/{i}.json')
|
|
||||||
data = json.load(open(f'{output_dir}/{i}.json', 'r', encoding='utf-8'))
|
|
||||||
except:
|
|
||||||
continue
|
|
||||||
query_id = data['query_id']
|
|
||||||
if int(query_id) in unsolvable_list:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
query_id = i
|
|
||||||
if int(query_id) in unsolvable_list:
|
|
||||||
continue
|
|
||||||
if 'chatgpt' in output_dir and 'cot' not in output_dir:
|
|
||||||
data = json.load(open(f'{output_dir}/{i}_ChatGPT_{method}.json', 'r', encoding='utf-8'))
|
|
||||||
elif 'chatgpt' in output_dir:
|
|
||||||
data = json.load(open(f'{output_dir}/{i}_{method}.json', 'r', encoding='utf-8'))
|
|
||||||
else:
|
|
||||||
data = json.load(open(f'{output_dir}/{i}_{method}.json', 'r', encoding='utf-8'))
|
|
||||||
|
|
||||||
if not reassign:
|
|
||||||
total_tokens += data['answer_generation']['total_tokens']
|
|
||||||
else:
|
|
||||||
if 'total_tokens' in data:
|
|
||||||
total_tokens += data['total_tokens']
|
|
||||||
if str(query_id) in label_cnt:
|
|
||||||
continue
|
|
||||||
if reassign:
|
|
||||||
# print(i)
|
|
||||||
if 'last_solve_time' not in data:
|
|
||||||
try:
|
|
||||||
data_dict = json.load(open(f'{output_dir}/{i}_DFS_woFilter_w2.json', 'r', encoding='utf-8'))
|
|
||||||
except:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
last_solve_time = data['last_solve_time']
|
|
||||||
data_dict = json.load(open(f'{output_dir}/{i}_{last_solve_time}_DFS_woFilter_w2.json', 'r', encoding='utf-8'))
|
|
||||||
else:
|
|
||||||
data_dict = data
|
|
||||||
if not data_dict['answer_generation']['valid_data']:
|
|
||||||
answer_dict[i] = process_invalid_data(method,data_dict)
|
|
||||||
else:
|
|
||||||
answer_dict[i] = process_valid_data(method,data_dict['answer_generation'])
|
|
||||||
example = answer_dict[i]
|
|
||||||
# query_id = i
|
|
||||||
# example['available_tools'] = query_data[str(query_id)]['available_tools']
|
|
||||||
referenced_examples[query_id] = example
|
|
||||||
for _ in range(args.evaluate_times):
|
|
||||||
future.append(pool.submit(
|
|
||||||
compute_pass_rate,
|
|
||||||
query_id,
|
|
||||||
example,
|
|
||||||
'Solvable',
|
|
||||||
'Task solvable human label'
|
|
||||||
))
|
|
||||||
for thd in tqdm(as_completed(future),total=len(future),ncols=100):
|
|
||||||
query_id, task_solvable, is_solved, machine_label, reason, not_hallucinate, tokens = thd.result()
|
|
||||||
example = referenced_examples[query_id]
|
|
||||||
query = example["query"]
|
|
||||||
tool_names = []
|
|
||||||
for tool_dict in example["available_tools"]:
|
|
||||||
tool_name = tool_dict["name"]
|
|
||||||
tool_names.append(tool_name)
|
|
||||||
answer_steps, final_step = get_steps(example)
|
|
||||||
if query_id not in label_cnt:
|
|
||||||
label_cnt[query_id] = {"passed":0, "failed":0, "unsure":0}
|
|
||||||
if machine_label == "passed":
|
|
||||||
label_cnt[query_id]["passed"] += 1
|
|
||||||
elif machine_label == "failed":
|
|
||||||
label_cnt[query_id]["failed"] += 1
|
|
||||||
else:
|
|
||||||
label_cnt[query_id]["unsure"] += 1
|
|
||||||
label_cnt[query_id]["query"] = query
|
|
||||||
label_cnt[query_id]["task_solvable"] = str(task_solvable)
|
|
||||||
label_cnt[query_id]["tool_names"] = tool_names
|
|
||||||
label_cnt[query_id]["answer_steps"] = answer_steps
|
|
||||||
label_cnt[query_id]["final_step"] = final_step
|
|
||||||
label_cnt[query_id]["is_solved"] = str(is_solved)
|
|
||||||
label_cnt[query_id]["reason"] = reason
|
|
||||||
label_cnt[query_id]["not_hallucinate"] = not_hallucinate
|
|
||||||
json.dump(label_cnt, open(f"{evaluation_output_dir}/label_cnt.json", "w"), ensure_ascii=False, indent=4)
|
|
||||||
filename = f"{evaluation_output_dir}/label_cnt.csv"
|
|
||||||
write_results(filename, 'result', label_cnt)
|
|
||||||
pass_rate = 0
|
|
||||||
total_num = 0
|
|
||||||
print('#'*100)
|
|
||||||
for query_id in label_cnt:
|
|
||||||
if int(query_id) in unsolvable_list:
|
|
||||||
continue
|
|
||||||
if label_cnt[query_id]["failed"] <= label_cnt[query_id]["passed"]:
|
|
||||||
pass_rate += 1
|
|
||||||
# if label_cnt[query_id]["unsure"] > 0:
|
|
||||||
# print('unsure')
|
|
||||||
total_num += 1
|
|
||||||
pass_rate /= total_num
|
|
||||||
pass_rate_list.append(pass_rate)
|
|
||||||
average_tokens_list.append(total_tokens/total_num)
|
|
||||||
print(f"Pass rate: {str(pass_rate)} total num {total_num} average tokens {total_tokens/total_num} {test_set}")
|
|
||||||
json.dump(answer_dict, open(f"{evaluation_output_dir}/answer_dict.json", "w"), ensure_ascii=False, indent=4)
|
|
||||||
print('&'.join([str(round(x*100,1)) for x in pass_rate_list]),round(np.mean(pass_rate_list)*100,1))
|
|
||||||
print('&'.join([str(round(x,1)) for x in average_tokens_list]),round(np.mean(average_tokens_list),1))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -125,7 +125,6 @@ def solve_given_api_main(query, api_list, i, messages=None):
|
||||||
result_data['result']['reason'] = reason
|
result_data['result']['reason'] = reason
|
||||||
result['answer_generation']['final_answer'] = json.dumps(result_data['result'])
|
result['answer_generation']['final_answer'] = json.dumps(result_data['result'])
|
||||||
if result['answer_generation']['finish_type'] == 'give_answer' and 'final_answer' in result_data['result'] and result_data['result']['final_answer'] != '':
|
if result['answer_generation']['finish_type'] == 'give_answer' and 'final_answer' in result_data['result'] and result_data['result']['final_answer'] != '':
|
||||||
# and not any(word in str(result['answer_generation']['final_answer']).lower() for word in exclusion_words)
|
|
||||||
solved = True
|
solved = True
|
||||||
else:
|
else:
|
||||||
solved = False
|
solved = False
|
||||||
|
|
@ -137,7 +136,6 @@ from arguments import parse_args
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
output_path = args.output_dir
|
output_path = args.output_dir
|
||||||
|
|
||||||
# output_path = f'{query_dir}/reassign_toolllama_dfs_r1'
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
os.makedirs(output_path, exist_ok=True)
|
||||||
dfs_args = dotdict(dict(backbone_model='chatgpt_function', openai_key='', model_path='your_model_path/', tool_root_dir='data/toolenv/tools/', lora=False, lora_path='your_lora_path if lora', max_observation_length=1024, max_source_sequence_length=4096, max_sequence_length=8192, observ_compress_method='truncate', method='DFS_woFilter_w2', input_query_file='data/test_instruction/G1_tool.json', output_answer_file=output_path, toolbench_key=toolbench_key, rapidapi_key='', use_rapidapi_key=False, api_customization=False))
|
dfs_args = dotdict(dict(backbone_model='chatgpt_function', openai_key='', model_path='your_model_path/', tool_root_dir='data/toolenv/tools/', lora=False, lora_path='your_lora_path if lora', max_observation_length=1024, max_source_sequence_length=4096, max_sequence_length=8192, observ_compress_method='truncate', method='DFS_woFilter_w2', input_query_file='data/test_instruction/G1_tool.json', output_answer_file=output_path, toolbench_key=toolbench_key, rapidapi_key='', use_rapidapi_key=False, api_customization=False))
|
||||||
dfs_runner = pipeline_runner(dfs_args)
|
dfs_runner = pipeline_runner(dfs_args)
|
||||||
|
|
@ -156,9 +154,6 @@ if __name__ == '__main__':
|
||||||
for i in range(262):
|
for i in range(262):
|
||||||
t_s = time.time()
|
t_s = time.time()
|
||||||
comparison_data = {}
|
comparison_data = {}
|
||||||
# for file in files:
|
|
||||||
# if file.endswith('.json'):
|
|
||||||
# print(file)
|
|
||||||
data_load = json.load(open(f'{args.query_dir}/{i}.json', 'r', encoding='utf-8'))
|
data_load = json.load(open(f'{args.query_dir}/{i}.json', 'r', encoding='utf-8'))
|
||||||
if str(data_load['query_id']) in solved_dict and solved_dict[str(data_load['query_id'])]['solved'] != 'Solved':
|
if str(data_load['query_id']) in solved_dict and solved_dict[str(data_load['query_id'])]['solved'] != 'Solved':
|
||||||
continue
|
continue
|
||||||
|
|
@ -166,8 +161,6 @@ if __name__ == '__main__':
|
||||||
query = data_load['query']
|
query = data_load['query']
|
||||||
# continue
|
# continue
|
||||||
cnt += 1
|
cnt += 1
|
||||||
# if cnt > 50:
|
|
||||||
# break
|
|
||||||
if os.path.exists(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json')):
|
if os.path.exists(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json')):
|
||||||
data = json.load(open(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json'), 'r', encoding='utf-8'))
|
data = json.load(open(os.path.join(output_path, f'{i}_DFS_woFilter_w2.json'), 'r', encoding='utf-8'))
|
||||||
final_data = json.load(open(os.path.join(output_path, f'{i}.json'), 'r', encoding='utf-8'))
|
final_data = json.load(open(os.path.join(output_path, f'{i}.json'), 'r', encoding='utf-8'))
|
||||||
|
|
|
||||||
|
|
@ -25,49 +25,41 @@ def check_task_solvable(query):
|
||||||
"content": f"Please check whether the following query is solvable: {query}. Begin!"}
|
"content": f"Please check whether the following query is solvable: {query}. Begin!"}
|
||||||
]
|
]
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
# try:
|
response = call_gpt(
|
||||||
if True:
|
messages=messages,
|
||||||
t_s = time.time()
|
functions=[solvable_finish_function]
|
||||||
response = call_gpt(
|
)
|
||||||
messages=messages,
|
tool_calls = response.choices[0].message.tool_calls
|
||||||
functions=[solvable_finish_function]
|
print('Thought:', response.choices[0].message.content)
|
||||||
)
|
if tool_calls:
|
||||||
print(time.time() - t_s)
|
for tool_call in tool_calls:
|
||||||
tool_calls = response.choices[0].message.tool_calls
|
function_name = tool_call.function.name
|
||||||
print('Thought:', response.choices[0].message.content)
|
function_args = tool_call.function.arguments
|
||||||
if tool_calls:
|
if function_name == 'Finish':
|
||||||
for tool_call in tool_calls:
|
try:
|
||||||
function_name = tool_call.function.name
|
solvable, reason = Finish(**json.loads(function_args))
|
||||||
function_args = tool_call.function.arguments
|
except:
|
||||||
if function_name == 'Finish':
|
|
||||||
try:
|
|
||||||
solvable, reason = Finish(**json.loads(function_args))
|
|
||||||
except:
|
|
||||||
continue
|
|
||||||
# solvable, reason = Finish(json.loads(function_args))
|
|
||||||
|
|
||||||
else:
|
|
||||||
continue
|
continue
|
||||||
print(solvable, query, file=open('result/solvable.txt', 'a', encoding='utf-6'))
|
# solvable, reason = Finish(json.loads(function_args))
|
||||||
if solvable == 'Unsolvable' and reason is None:
|
|
||||||
messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'})
|
else:
|
||||||
if reason is not None:
|
continue
|
||||||
print(reason, file=open('result/solvable.txt', 'a', encoding='utf-8'))
|
print(solvable, query, file=open('result/solvable.txt', 'a', encoding='utf-6'))
|
||||||
else:
|
if solvable == 'Unsolvable' and reason is None:
|
||||||
reason = ''
|
messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'})
|
||||||
return solvable, reason
|
if reason is not None:
|
||||||
else:
|
print(reason, file=open('result/solvable.txt', 'a', encoding='utf-8'))
|
||||||
print('Thought:', response.choices[0].message.content)
|
else:
|
||||||
continue
|
reason = ''
|
||||||
|
return solvable, reason
|
||||||
|
else:
|
||||||
|
print('Thought:', response.choices[0].message.content)
|
||||||
|
continue
|
||||||
# messages.append({"role": "assistant", "content": response.choices[0].message.get('content', '')})
|
# messages.append({"role": "assistant", "content": response.choices[0].message.get('content', '')})
|
||||||
# except:
|
|
||||||
# pass
|
|
||||||
print('No response from the model', file=open('result/solvable.txt', 'a', encoding='utf-8'))
|
print('No response from the model', file=open('result/solvable.txt', 'a', encoding='utf-8'))
|
||||||
return 'No response', 'No response from the model'
|
return 'No response', 'No response from the model'
|
||||||
|
|
||||||
def check_task_solvable_by_function(query, functions):
|
def check_task_solvable_by_function(query, functions):
|
||||||
# return 'Solvable', ''
|
|
||||||
# print(functions)
|
|
||||||
messages = [{
|
messages = [{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": CHECK_SOLVABLE_BY_FUNCTION_PROMPT
|
"content": CHECK_SOLVABLE_BY_FUNCTION_PROMPT
|
||||||
|
|
@ -76,40 +68,33 @@ def check_task_solvable_by_function(query, functions):
|
||||||
"content": f"Query: {query}. Available_tools: {functions}. Begin!"}
|
"content": f"Query: {query}. Available_tools: {functions}. Begin!"}
|
||||||
]
|
]
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
# try:
|
response = call_gpt(
|
||||||
if True:
|
messages=messages,
|
||||||
t_s = time.time()
|
functions=[solvable_finish_function]
|
||||||
response = call_gpt(
|
)
|
||||||
messages=messages,
|
tool_calls = response.choices[0].message.tool_calls
|
||||||
functions=[solvable_finish_function]
|
print('Thought:', response.choices[0].message.content)
|
||||||
)
|
if tool_calls:
|
||||||
print(time.time() - t_s)
|
for tool_call in tool_calls:
|
||||||
tool_calls = response.choices[0].message.tool_calls
|
function_name = tool_call.function.name
|
||||||
print('Thought:', response.choices[0].message.content)
|
function_args = tool_call.function.arguments
|
||||||
if tool_calls:
|
if function_name.lower() == 'finish':
|
||||||
for tool_call in tool_calls:
|
try:
|
||||||
function_name = tool_call.function.name
|
solvable, reason = Finish(**json.loads(function_args))
|
||||||
function_args = tool_call.function.arguments
|
except:
|
||||||
if function_name.lower() == 'finish':
|
|
||||||
try:
|
|
||||||
solvable, reason = Finish(**json.loads(function_args))
|
|
||||||
except:
|
|
||||||
continue
|
|
||||||
# solvable, reason = Finish(json.loads(function_args))
|
|
||||||
|
|
||||||
else:
|
|
||||||
continue
|
continue
|
||||||
print(solvable, query, file=open('result/solvable.txt', 'a', encoding='utf-8'))
|
# solvable, reason = Finish(json.loads(function_args))
|
||||||
if solvable == 'Unsolvable' and reason is None:
|
|
||||||
messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'})
|
else:
|
||||||
if reason is not None:
|
continue
|
||||||
print(reason, file=open('result/solvable.txt', 'a', encoding='utf-8'))
|
if solvable == 'Unsolvable' and reason is None:
|
||||||
else:
|
messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'})
|
||||||
reason = ''
|
if reason is None:
|
||||||
return solvable, reason, response.usage.total_tokens
|
reason = ''
|
||||||
else:
|
return solvable, reason, response.usage.total_tokens
|
||||||
print('Thought:', response.choices[0].message.content)
|
else:
|
||||||
continue
|
print('Thought:', response.choices[0].message.content)
|
||||||
|
continue
|
||||||
# messages.append({"role": "assistant", "content": response.choices[0].message.get('content', '')})
|
# messages.append({"role": "assistant", "content": response.choices[0].message.get('content', '')})
|
||||||
# except:
|
# except:
|
||||||
# pass
|
# pass
|
||||||
|
|
@ -141,13 +126,10 @@ def check_task_solved(query, answer):
|
||||||
print(function_name, function_args)
|
print(function_name, function_args)
|
||||||
if function_name.lower() == 'finish':
|
if function_name.lower() == 'finish':
|
||||||
solvable, reason = Finish(**json.loads(function_args))
|
solvable, reason = Finish(**json.loads(function_args))
|
||||||
print(solvable, query, file=open('result/solved.txt', 'a', encoding='utf-8'))
|
|
||||||
if solvable == 'Unsolved' and reason is None:
|
if solvable == 'Unsolved' and reason is None:
|
||||||
messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'})
|
messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'})
|
||||||
continue
|
continue
|
||||||
if reason is not None:
|
if reason is None:
|
||||||
print(reason, file=open('result/solved.txt', 'a', encoding='utf-8'))
|
|
||||||
else:
|
|
||||||
reason = ''
|
reason = ''
|
||||||
return solvable, reason
|
return solvable, reason
|
||||||
|
|
||||||
|
|
@ -167,11 +149,9 @@ def check_solved_toolbench(output_path, query_id, task_solvable=None, solvable_t
|
||||||
example = process_invalid_data(method,data_dict)
|
example = process_invalid_data(method,data_dict)
|
||||||
else:
|
else:
|
||||||
example = process_valid_data(method,data_dict['answer_generation'])
|
example = process_valid_data(method,data_dict['answer_generation'])
|
||||||
# example['available_tools'] = query_data[str(ori_query_id)]['available_tools']
|
|
||||||
future = []
|
future = []
|
||||||
answer_dict = {'passed':0, 'failed':0}
|
answer_dict = {'passed':0, 'failed':0}
|
||||||
with ThreadPoolExecutor(32) as pool:
|
with ThreadPoolExecutor(32) as pool:
|
||||||
print(task_solvable, solvable_task_reason, file=open(os.path.join(output_dir, 'solvable.txt'), 'a', encoding='utf-8'))
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
future.append(pool.submit(
|
future.append(pool.submit(
|
||||||
compute_pass_rate,
|
compute_pass_rate,
|
||||||
|
|
|
||||||
|
|
@ -60,15 +60,6 @@ def call_gpt(messages, functions=None, **kwargs):
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
except openai.BadRequestError as e:
|
except openai.BadRequestError as e:
|
||||||
# try:
|
|
||||||
# response = turbo_client.chat.completions.create(
|
|
||||||
# seed=123,
|
|
||||||
# model='gpt-4-turbo',
|
|
||||||
# messages=messages,
|
|
||||||
# functions=functions
|
|
||||||
# )
|
|
||||||
# except Exception as e:
|
|
||||||
# raise e
|
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
except openai.RateLimitError as e:
|
except openai.RateLimitError as e:
|
||||||
|
|
@ -85,8 +76,6 @@ def call_gpt(messages, functions=None, **kwargs):
|
||||||
try:
|
try:
|
||||||
response, t_real = call_gpt_retry(messages_converted, functions)
|
response, t_real = call_gpt_retry(messages_converted, functions)
|
||||||
# json_content = response.choices[0].message.content
|
# json_content = response.choices[0].message.content
|
||||||
t = time.time() - t_s
|
|
||||||
print('minus:', t-t_real, file=open(os.path.join(output_dir, "time.txt"), "a"))
|
|
||||||
# print(response.choices[0].message.function_call)
|
# print(response.choices[0].message.function_call)
|
||||||
if response.choices[0].finish_reason == 'function_call':
|
if response.choices[0].finish_reason == 'function_call':
|
||||||
response_json = json.loads(response.json())
|
response_json = json.loads(response.json())
|
||||||
|
|
|
||||||
237
scripts/main.py
237
scripts/main.py
|
|
@ -963,135 +963,132 @@ if __name__ == "__main__":
|
||||||
# continue
|
# continue
|
||||||
print(f'query: {query}')
|
print(f'query: {query}')
|
||||||
flag = False
|
flag = False
|
||||||
try:
|
runner = Main_Search_Agent(query)
|
||||||
runner = Main_Search_Agent(query)
|
agents.append(runner)
|
||||||
agents.append(runner)
|
if multi_thread:
|
||||||
if multi_thread:
|
thread = threading.Thread(target=runner.assign_main, args=(query,))
|
||||||
thread = threading.Thread(target=runner.assign_main, args=(query,))
|
sem.acquire()
|
||||||
sem.acquire()
|
thread.start()
|
||||||
thread.start()
|
threads.append(thread)
|
||||||
threads.append(thread)
|
else:
|
||||||
|
iter_func = runner.assign_main(query)
|
||||||
|
messages = None
|
||||||
|
cnt = 0
|
||||||
|
solve_data = {}
|
||||||
|
if multi_thread:
|
||||||
|
while True:
|
||||||
|
thread_num = len(threads)
|
||||||
|
has_thread_alive = False
|
||||||
|
for thread in threads:
|
||||||
|
if thread.is_alive():
|
||||||
|
has_thread_alive = True
|
||||||
|
thread.join()
|
||||||
|
if not has_thread_alive:
|
||||||
|
break
|
||||||
|
if error_flag: raise Exception('GPT Call Error')
|
||||||
|
threads = []
|
||||||
|
# refind
|
||||||
|
check_solved = ''
|
||||||
|
max_depth = max([agent.depth for agent in agents])
|
||||||
|
while not all([agent.finish_search for agent in agents]) or not flag:
|
||||||
|
if check_solved == 'Solved':
|
||||||
|
break
|
||||||
|
max_depth = max([agent.depth for agent in agents])
|
||||||
|
depth = max_depth
|
||||||
|
while all([agent.finish_search for agent in agents if agent.depth == depth]) and depth >= 0:
|
||||||
|
depth -= 1
|
||||||
|
if depth < 0 and flag:
|
||||||
|
break
|
||||||
|
agents_to_resume = [agent for agent in agents if not agent.finish_search and agent.depth == depth]
|
||||||
|
if total_tokens > 200000 and flag:
|
||||||
|
solved = False
|
||||||
|
check_solved = 'Timeout'
|
||||||
|
solve_data = {'result': 'Timeout'}
|
||||||
|
break
|
||||||
|
cnt += 1
|
||||||
|
failed_reason = None
|
||||||
|
print(len(global_api_list), file=open(f'{output_dir}/api_list_len.txt', 'a', encoding='utf-8'))
|
||||||
|
print(global_api_list, file=open(f'{output_dir}/api_list.txt', 'a', encoding='utf-8'))
|
||||||
|
print('#'*100)
|
||||||
|
assign_results['api_list'].append(deepcopy(global_api_list))
|
||||||
|
if stop or not flag or all([agent.finish_search for agent in agents]) and len(global_api_list) > 0:
|
||||||
|
flag = True
|
||||||
|
last_solve_time = cnt
|
||||||
|
t_s = time.time()
|
||||||
|
selected_api_list = deepcopy(global_api_list)
|
||||||
|
|
||||||
|
solved, solve_data = solve_given_api_main(query, selected_api_list, f'{query_id}_{cnt}', messages)
|
||||||
|
|
||||||
|
print('solve time:', time.time() - t_s, 'api number:', len(global_api_list),file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8'))
|
||||||
|
result_list.append(deepcopy(solve_data['result']))
|
||||||
|
print(solve_data['result'], solved)
|
||||||
|
if not solved or any([word in solve_data['result']['final_answer'] for word in exclusion_words]):
|
||||||
|
check_solved = 'Unsolved'
|
||||||
|
reason = solve_data['result']
|
||||||
|
else:
|
||||||
|
check_solved, reason, tokens = check_solved_toolbench(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', query_id, task_solvable, solvable_reason)
|
||||||
|
total_tokens += tokens
|
||||||
|
print(colored((check_solved, reason), 'red'))
|
||||||
|
failed_reason = reason
|
||||||
|
dfs_data = json.load(open(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', 'r', encoding='utf-8'))
|
||||||
|
total_tokens += dfs_data['answer_generation']['total_tokens']
|
||||||
|
solve_tokens += dfs_data['answer_generation']['total_tokens']
|
||||||
|
if check_solved == 'Solved':
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
messages = dfs_data['answer_generation']['train_messages'][-1]
|
||||||
|
except:
|
||||||
|
messages = None
|
||||||
|
api_list_to_prune = []
|
||||||
|
for standardized_api_name, origin_api in dfs_data['api2origin'].items():
|
||||||
|
if standardized_api_name in str(failed_reason):
|
||||||
|
if origin_api in global_api_list:
|
||||||
|
api_list_to_prune.append(origin_api)
|
||||||
|
print(colored(api_list_to_prune, 'red'))
|
||||||
|
print(len(api_list_to_prune))
|
||||||
|
remove_apis(api_list_to_prune)
|
||||||
|
if len(global_api_list) >= max_api_number:
|
||||||
|
break
|
||||||
|
# print(api_list_to_prune, file=open(f'{output_dir}/prune_api_list.txt', 'a', encoding='utf-8'))
|
||||||
|
stop = False
|
||||||
else:
|
else:
|
||||||
iter_func = runner.assign_main(query)
|
assert status != 'The current api list can solve the query.'
|
||||||
messages = None
|
failed_reason = status
|
||||||
cnt = 0
|
reason_list.append(failed_reason)
|
||||||
solve_data = {}
|
print(colored('Refind Begin', 'red'))
|
||||||
|
print(colored(agents_to_resume, 'red'))
|
||||||
|
print([agent.finish_search for agent in agents_to_resume])
|
||||||
|
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
resume_cnt = 0
|
||||||
|
resumed_agents.append([(str(a), a.index) for a in agents_to_resume])
|
||||||
|
for agent in reversed(agents_to_resume):
|
||||||
|
if agent.finish_search: continue
|
||||||
|
resume_cnt += 1
|
||||||
|
agent.failed_reason = str(failed_reason)
|
||||||
|
print(colored(('resuming', agent, agent.depth), 'red'))
|
||||||
|
print(colored(('resuming', agent, agent.depth), 'red'), file=open(f'{output_dir}/resume.txt', 'a', encoding='utf-8'))
|
||||||
|
if multi_thread:
|
||||||
|
thread = threading.Thread(target=agent.resume_search)
|
||||||
|
sem.acquire()
|
||||||
|
thread.start()
|
||||||
|
threads.append(thread)
|
||||||
|
else:
|
||||||
|
agent.resume_search()
|
||||||
if multi_thread:
|
if multi_thread:
|
||||||
while True:
|
while True:
|
||||||
thread_num = len(threads)
|
thread_num = len(threads)
|
||||||
has_thread_alive = False
|
|
||||||
for thread in threads:
|
for thread in threads:
|
||||||
if thread.is_alive():
|
if thread.is_alive():
|
||||||
has_thread_alive = True
|
|
||||||
thread.join()
|
thread.join()
|
||||||
if not has_thread_alive:
|
if thread_num == len(threads):
|
||||||
break
|
break
|
||||||
if error_flag: raise Exception('GPT Call Error')
|
if error_flag: raise Exception('GPT Call Error')
|
||||||
threads = []
|
if not stop:
|
||||||
# refind
|
check_if_request_solvable()
|
||||||
check_solved = ''
|
print(colored(f'status:{status}', 'red'))
|
||||||
max_depth = max([agent.depth for agent in agents])
|
assign_results['stop'].append(stop)
|
||||||
while not all([agent.finish_search for agent in agents]) or not flag:
|
|
||||||
if check_solved == 'Solved':
|
|
||||||
break
|
|
||||||
max_depth = max([agent.depth for agent in agents])
|
|
||||||
depth = max_depth
|
|
||||||
while all([agent.finish_search for agent in agents if agent.depth == depth]) and depth >= 0:
|
|
||||||
depth -= 1
|
|
||||||
if depth < 0 and flag:
|
|
||||||
break
|
|
||||||
agents_to_resume = [agent for agent in agents if not agent.finish_search and agent.depth == depth]
|
|
||||||
if total_tokens > 200000 and flag:
|
|
||||||
solved = False
|
|
||||||
check_solved = 'Timeout'
|
|
||||||
solve_data = {'result': 'Timeout'}
|
|
||||||
break
|
|
||||||
cnt += 1
|
|
||||||
failed_reason = None
|
|
||||||
print(len(global_api_list), file=open(f'{output_dir}/api_list_len.txt', 'a', encoding='utf-8'))
|
|
||||||
print(global_api_list, file=open(f'{output_dir}/api_list.txt', 'a', encoding='utf-8'))
|
|
||||||
print('#'*100)
|
|
||||||
assign_results['api_list'].append(deepcopy(global_api_list))
|
|
||||||
if stop or not flag or all([agent.finish_search for agent in agents]) and len(global_api_list) > 0:
|
|
||||||
flag = True
|
|
||||||
last_solve_time = cnt
|
|
||||||
t_s = time.time()
|
|
||||||
selected_api_list = deepcopy(global_api_list)
|
|
||||||
|
|
||||||
solved, solve_data = solve_given_api_main(query, selected_api_list, f'{query_id}_{cnt}', messages)
|
|
||||||
|
|
||||||
print('solve time:', time.time() - t_s, 'api number:', len(global_api_list),file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8'))
|
|
||||||
result_list.append(deepcopy(solve_data['result']))
|
|
||||||
print(solve_data['result'], solved)
|
|
||||||
if not solved or any([word in solve_data['result']['final_answer'] for word in exclusion_words]):
|
|
||||||
check_solved = 'Unsolved'
|
|
||||||
reason = solve_data['result']
|
|
||||||
else:
|
|
||||||
check_solved, reason, tokens = check_solved_toolbench(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', query_id, task_solvable, solvable_reason)
|
|
||||||
total_tokens += tokens
|
|
||||||
print(colored((check_solved, reason), 'red'))
|
|
||||||
failed_reason = reason
|
|
||||||
dfs_data = json.load(open(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', 'r', encoding='utf-8'))
|
|
||||||
total_tokens += dfs_data['answer_generation']['total_tokens']
|
|
||||||
solve_tokens += dfs_data['answer_generation']['total_tokens']
|
|
||||||
if check_solved == 'Solved':
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
messages = dfs_data['answer_generation']['train_messages'][-1]
|
|
||||||
except:
|
|
||||||
messages = None
|
|
||||||
api_list_to_prune = []
|
|
||||||
for standardized_api_name, origin_api in dfs_data['api2origin'].items():
|
|
||||||
if standardized_api_name in str(failed_reason):
|
|
||||||
if origin_api in global_api_list:
|
|
||||||
api_list_to_prune.append(origin_api)
|
|
||||||
print(colored(api_list_to_prune, 'red'))
|
|
||||||
print(len(api_list_to_prune))
|
|
||||||
remove_apis(api_list_to_prune)
|
|
||||||
if len(global_api_list) >= max_api_number:
|
|
||||||
break
|
|
||||||
# print(api_list_to_prune, file=open(f'{output_dir}/prune_api_list.txt', 'a', encoding='utf-8'))
|
|
||||||
stop = False
|
|
||||||
else:
|
|
||||||
assert status != 'The current api list can solve the query.'
|
|
||||||
failed_reason = status
|
|
||||||
reason_list.append(failed_reason)
|
|
||||||
print(colored('Refind Begin', 'red'))
|
|
||||||
print(colored(agents_to_resume, 'red'))
|
|
||||||
print([agent.finish_search for agent in agents_to_resume])
|
|
||||||
|
|
||||||
|
|
||||||
threads = []
|
|
||||||
resume_cnt = 0
|
|
||||||
resumed_agents.append([(str(a), a.index) for a in agents_to_resume])
|
|
||||||
for agent in reversed(agents_to_resume):
|
|
||||||
if agent.finish_search: continue
|
|
||||||
resume_cnt += 1
|
|
||||||
agent.failed_reason = str(failed_reason)
|
|
||||||
print(colored(('resuming', agent, agent.depth), 'red'))
|
|
||||||
print(colored(('resuming', agent, agent.depth), 'red'), file=open(f'{output_dir}/resume.txt', 'a', encoding='utf-8'))
|
|
||||||
if multi_thread:
|
|
||||||
thread = threading.Thread(target=agent.resume_search)
|
|
||||||
sem.acquire()
|
|
||||||
thread.start()
|
|
||||||
threads.append(thread)
|
|
||||||
else:
|
|
||||||
agent.resume_search()
|
|
||||||
if multi_thread:
|
|
||||||
while True:
|
|
||||||
thread_num = len(threads)
|
|
||||||
for thread in threads:
|
|
||||||
if thread.is_alive():
|
|
||||||
thread.join()
|
|
||||||
if thread_num == len(threads):
|
|
||||||
break
|
|
||||||
if error_flag: raise Exception('GPT Call Error')
|
|
||||||
if not stop:
|
|
||||||
check_if_request_solvable()
|
|
||||||
print(colored(f'status:{status}', 'red'))
|
|
||||||
assign_results['stop'].append(stop)
|
|
||||||
except KeyboardInterrupt as e:
|
|
||||||
continue
|
|
||||||
|
|
||||||
assign_results['api_complete'] = flag
|
assign_results['api_complete'] = flag
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,25 +10,6 @@ from arguments import parse_args
|
||||||
from config import *
|
from config import *
|
||||||
from openai_utils import call_gpt
|
from openai_utils import call_gpt
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
output_dir = args.output_dir
|
|
||||||
# import importlib
|
|
||||||
# module = importlib.import_module(args.openai_config_path.replace('.py',''))
|
|
||||||
# for name in dir(module):
|
|
||||||
# if not name.startswith('_'):
|
|
||||||
# globals()[name] = getattr(module, name)
|
|
||||||
# api_key = globals()['api_key']
|
|
||||||
# api_version = globals()['api_version']
|
|
||||||
# model_name = globals()['model_name']
|
|
||||||
# api_base = globals()['api_base']
|
|
||||||
if api_type == "azure":
|
|
||||||
from openai import AzureOpenAI as Client
|
|
||||||
else:
|
|
||||||
from openai import OpenAI as Client
|
|
||||||
client = Client(
|
|
||||||
api_key=api_key,
|
|
||||||
api_version=api_version,
|
|
||||||
azure_endpoint = api_base
|
|
||||||
)
|
|
||||||
@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
|
@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
|
||||||
def chat_completion_request(key, messages, functions=None,function_call=None,key_pos=None, model="gpt-4-32k",stop=None,process_id=0, **args):
|
def chat_completion_request(key, messages, functions=None,function_call=None,key_pos=None, model="gpt-4-32k",stop=None,process_id=0, **args):
|
||||||
use_messages = []
|
use_messages = []
|
||||||
|
|
@ -52,30 +33,14 @@ def chat_completion_request(key, messages, functions=None,function_call=None,key
|
||||||
json_data.update({"function_call": function_call})
|
json_data.update({"function_call": function_call})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# if model in ["gpt-3.5-turbo-16k-0613","gpt-4-0613", "gpt-4-deployment","gpt-4-32k", 'gpt-4-turbo']:
|
|
||||||
# openai.api_key = key
|
|
||||||
# else:
|
|
||||||
# raise NotImplementedError
|
|
||||||
# ts = time.time()
|
|
||||||
# print(json_data, file=open('output/gpt_io.txt','a'))
|
|
||||||
# print(time.time()-ts)
|
|
||||||
ts = time.time()
|
|
||||||
# json.dump(json_data['messages'], open(os.path.join(output_dir,'messages.json'),'w'), indent=4)
|
|
||||||
openai_response = call_gpt(
|
openai_response = call_gpt(
|
||||||
**json_data,
|
**json_data,
|
||||||
)
|
)
|
||||||
# openai_response = client.chat.completions.create(
|
|
||||||
# **json_data,
|
|
||||||
# )
|
|
||||||
# print('solve', time.time()-ts, file=open(os.path.join(output_dir,'time.txt'),'a'))
|
|
||||||
# json_data = json.loads(str(openai_response))
|
|
||||||
json_data = json.loads(openai_response.json())
|
json_data = json.loads(openai_response.json())
|
||||||
json_data["choices"][0]['message'].pop('tool_calls')
|
json_data["choices"][0]['message'].pop('tool_calls')
|
||||||
return json_data
|
return json_data
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# print('solve', time.time()-ts, file=open(os.path.join(output_dir,'time.txt'),'a'))
|
|
||||||
# # json_data = json.loads(str(openai_response))
|
|
||||||
# json_data = json.loads(openai_response.json())
|
# json_data = json.loads(openai_response.json())
|
||||||
# json_data["choices"][0]['message'].pop('tool_calls')
|
# json_data["choices"][0]['message'].pop('tool_calls')
|
||||||
print("Unable to generate ChatCompletion response")
|
print("Unable to generate ChatCompletion response")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue