406 lines
19 KiB
Python
406 lines
19 KiB
Python
import re
|
|
from toolbench.inference.Tree.Tree import my_tree, tree_node
|
|
from toolbench.inference.Prompts.ReAct_prompts import FORMAT_INSTRUCTIONS_SYSTEM_FUNCTION, FORMAT_INSTRUCTIONS_SYSTEM_FUNCTION_ADAPTED, FORMAT_INSTRUCTIONS_USER_FUNCTION
|
|
from toolbench.inference.Prompts.Tree_search_prompts import DIVERSITY_PROMPT
|
|
from toolbench.inference.Algorithms.base_search import base_search_method
|
|
from copy import deepcopy
|
|
from toolbench.inference.LLM_rank.rank_candidate import sum_based_rankn, rank2_subfix
|
|
import json
|
|
import random
|
|
import os
|
|
from arguments import parse_args
|
|
# Path of the log file that this module appends the LLM prompts to.
file_name = 'llama_io.txt'

# Command-line options, parsed once when the module is imported; the search
# code below reads args.use_original_prompt and args.solver.
args = parse_args()
|
|
|
|
class DFS_tree_search(base_search_method):
    """Depth-first search over LLM tool-calling trajectories.

    Every expansion asks the LLM for the next step and turns its reply into a
    chain of child nodes: a "Thought" node for the text content, then an
    "Action" and an "Action Input" node when the reply contains a function
    call. The variant is selected by ``with_filter`` in :meth:`start`:

    * ``with_filter=True`` (normal DFS): first generate ``tree_beam_size``
      children of a node, rank them with the LLM, then expand them best-first.
    * ``with_filter=False`` (DFSDT): expand every child as soon as it is
      generated, i.e. a pre-order traversal.
    """

    def __init__(self, llm, io_func, process_id=0, callbacks=None):
        super().__init__(llm, io_func, process_id, callbacks)
        self.io_func = io_func
        self.llm = llm
        self.process_id = process_id
        self.restart()

        self.callbacks = callbacks if callbacks is not None else []

    def restart(self):
        """Reset all per-run search state (counters and collected leaf nodes)."""
        self.status = 0  # set to 1 once a terminal (answer) node is reached
        self.terminal_node = []  # terminal nodes, in the order they were reached
        self.give_up_node = []  # pruned leaves whose observation_code was 4
        self.now_expand_num = 0  # running counter stamped onto nodes as expand_num
        self.query_count = 0  # LLM calls so far, including ranking calls
        self.total_tokens = 0  # tokens reported by the LLM so far

    def send_agent_chain_end(self, depth, agent_block_ids, chain_block_ids):
        """Close the chain (and, when one was opened, the agent) callback blocks.

        ``chain_block_ids`` / ``agent_block_ids`` hold the ids returned by each
        callback's ``on_chain_start`` / ``on_agent_action``, index-aligned with
        ``self.callbacks``; ``agent_block_ids`` is empty when the step made no
        function call.
        """
        for i, callback in enumerate(self.callbacks):
            callback.on_chain_end(depth=depth, block_id=chain_block_ids[i])
            if i < len(agent_block_ids):
                callback.on_agent_end(depth=depth, block_id=agent_block_ids[i])

    def to_json(self, answer=False, process=True):
        """Serialize the result of the last :meth:`start` run.

        Args:
            answer: add an ``answer_generation`` section built from the first
                non-pruned terminal node, falling back to a randomly chosen
                give-up node when there is none.
            process: add the full tree, the forward args and the chain of every
                non-pruned terminal node (``compare_candidates``).

        Returns:
            A dict; empty when both flags are false.
        """
        if process:
            json_obj = {
                "win": self.status == 1,
                "tree": self.tree.to_json_recursive(),
                "forward_args": self.forward_args,
                "compare_candidates": [
                    node.get_chain_result_from_this_node(use_messages=False)
                    for node in self.terminal_node
                    if not node.pruned  # has answer
                ],
            }
        else:
            json_obj = {}

        if answer:
            generation = {
                "valid_data": False,
                "query_count": self.query_count,
                "total_tokens": self.total_tokens,
                "final_answer": "",
                "finish_type": "give_answer",
                "function": self.io_func.functions,
                "chain": [],
            }
            json_obj["answer_generation"] = generation

            answered = next(
                (node for node in self.terminal_node if not node.pruned), None)
            if answered is not None:
                # Terminal descriptions mentioning 'give_up' are reported as a give-up.
                if 'give_up' in answered.description.lower():
                    generation["finish_type"] = "give_up"
                else:
                    generation["finish_type"] = "give_answer"
                generation["final_answer"] = answered.description
                generation["valid_data"] = True
                generation["train_messages"] = answered.get_train_messages_from_this_node()
            elif self.give_up_node:
                # No final answer: fall back to a random give-up leaf.
                chosen = random.choice(self.give_up_node)
                generation["valid_data"] = True
                generation["finish_type"] = "give_up"
                generation["final_answer"] = chosen.description
                generation["train_messages"] = chosen.get_train_messages_from_this_node()
        return json_obj

    def start(self, single_chain_max_step, tree_beam_size, max_query_count, answer=1, with_filter=True, messages=None):
        """Build a fresh tree and run the search from its root.

        Args:
            single_chain_max_step: maximum tree depth, compared against
                ``node.get_depth()``.
            tree_beam_size: how many children are generated per expanded node.
            max_query_count: soft LLM-call budget. Once reached, every further
                step appends a user message asking the model to finish or give
                up; the search itself is not cut off.
            answer: stop once this many terminal nodes have been found.
            with_filter: True for ranked DFS, False for DFSDT (pre-order).
            messages: optional earlier conversation to resume from. It is
                copied, never modified in place. In the copy the first entry is
                replaced by the current system prompt, the last entry is
                dropped, results of functions that are no longer available are
                removed together with the message before each, and earlier
                "maximum query count" reminders are removed.

        Returns:
            Whatever :meth:`DFS` returns for the root.
        """
        # Forward args are deliberately not recorded; to_json() reports None.
        self.forward_args = None
        self.tree = my_tree()
        self.tree.root.node_type = "Action Input"
        self.tree.root.io_state = deepcopy(self.io_func)

        if args.use_original_prompt:
            system = FORMAT_INSTRUCTIONS_SYSTEM_FUNCTION
        else:
            system = FORMAT_INSTRUCTIONS_SYSTEM_FUNCTION_ADAPTED
        system = system.replace("{task_description}",
                                self.io_func.task_description)
        user = FORMAT_INSTRUCTIONS_USER_FUNCTION.replace(
            "{input_description}", self.io_func.input_description)

        if messages is None:
            self.tree.root.messages.append({"role": "system", "content": system})
            self.tree.root.messages.append({"role": "user", "content": user})
        else:
            self.tree.root.messages = self._prepare_resumed_messages(messages, system)

        # Separator between runs in the prompt log.
        with open(file_name, 'a') as log_file:
            print('#' * 100, file=log_file)

        return self.DFS(self.tree.root, single_chain_max_step, tree_beam_size, max_query_count, answer, with_filter)

    def _prepare_resumed_messages(self, messages, system):
        """Return a cleaned copy of an earlier conversation for the root node."""
        # Copy so the caller's list (and its dicts) are never mutated by the search.
        messages = deepcopy(messages)
        messages[0] = {"role": "system", "content": system}
        messages.pop()
        available = {function["name"] for function in self.io_func.functions}
        # Walk a snapshot from the end: popping index i (and i-1) never shifts
        # any smaller index that is still to be visited.
        for i, message in reversed(list(enumerate(messages))):
            if message["role"] == "function" and message["name"] not in available:
                messages.pop(i)
                messages.pop(i - 1)  # presumably the assistant turn that made the call
            if message["role"] == "user" and 'maximum query count' in message["content"]:
                messages.pop(i)
        return messages

    def _append_diversity_message(self, node):
        """Ask the model to differ from ``node``'s existing candidates.

        Appends a DIVERSITY_PROMPT user message listing every child's first
        "Action Input" (name, arguments, output and compute_weight()) to
        ``node.messages``.

        Returns:
            True when a message was appended (the caller marks it invalid after
            the LLM call), False when there was nothing to list.
        """
        candidates = []
        for child in node.children:
            # Follow the first-child chain down to this candidate's tool call.
            leaf = child
            while not leaf.is_terminal and leaf.node_type != "Action Input" and len(leaf.children) > 0:
                leaf = leaf.children[0]
            if leaf.node_type == "Action Input":
                candidates.append({
                    "name": leaf.father.description,
                    "arguments": leaf.description,
                    "function_output": leaf.observation,
                    "mento-carlo-action-value": leaf.compute_weight(),
                })
        if not candidates:
            return False

        former_candidates_des = f"{json.dumps(candidates,indent=2)}\n"
        if node.observation != "":
            former_candidates_des += f"again, your former observation: {node.observation}\n"
        diverse_prompt = DIVERSITY_PROMPT.replace(
            "{previous_candidate}", former_candidates_des)
        node.messages.append({"role": "user", "content": diverse_prompt})
        return True

    @staticmethod
    def _log_prompt(messages):
        """Append ``messages`` to ``file_name`` as indented JSON.

        Works on a deep copy and blanks any ``function_call`` payload, so the
        messages themselves are left untouched.
        """
        with open(file_name, 'a') as log_file:
            for message in deepcopy(messages):
                if 'function_call' in message:
                    message['function_call'] = {}
                print(json.dumps(message, indent=4), file=log_file)

    def _spawn_child(self, parent, node_type, description):
        """Attach and return a new ``node_type`` child of ``parent``.

        The child gets its own deep copy of the parent's io_state (with the
        retriever dropped) and of the parent's message history.
        """
        child = tree_node()
        child.node_type = node_type
        child.description = description
        child_io_state = deepcopy(parent.io_state)
        child_io_state.retriever = None
        child.io_state = child_io_state
        child.is_terminal = child_io_state.check_success() != 0
        child.messages = deepcopy(parent.messages)
        child.father = parent
        parent.children.append(child)
        child.print(self.process_id)
        return child

    def DFS(self, now_node, single_chain_max_step, tree_beam_size, max_query_count, answer, with_filter=True):
        """Search below ``now_node`` and return how many levels to backtrack.

        A return value of 1 lets the caller continue with its next candidate.
        Larger values unwind several levels at once: after a final answer
        (``final_answer_back_length``) or a give-up leaf
        (``prune_back_length``) the search jumps back further. In a sense, the
        larger this value, the more diverse the search; it becomes
        GreedySearch@n when the value grows to infinity. 10000 is returned once
        ``answer`` terminal nodes have been collected, unwinding the search.
        """
        # How far to go back after a final answer / a give-up. With the 'dfs'
        # solver the search backtracks a little; any other solver uses a huge
        # value, so the algorithm degrades to CoT.
        if args.solver == 'dfs':
            final_answer_back_length = 2
            prune_back_length = 2
        else:
            final_answer_back_length = 10000
            prune_back_length = 10000

        now_node.expand_num = self.now_expand_num
        self.now_expand_num += 1
        if now_node.get_depth() >= single_chain_max_step or now_node.pruned or now_node.is_terminal:
            if now_node.is_terminal:  # final answer
                self.status = 1
                self.terminal_node.append(now_node)
                return final_answer_back_length
            now_node.pruned = True
            if now_node.observation_code == 4:
                self.give_up_node.append(now_node)
                return prune_back_length
            return 1

        next_tree_split_nodes = []
        for _ in range(tree_beam_size):
            temp_now_node = now_node

            # If the node already has children, prompt the model to generate
            # something different from all existing candidates.
            delete_former_diversity_message = self._append_diversity_message(temp_now_node)

            # Each LLM step adds up to three nodes (Thought, Action, Action Input).
            now_depth = temp_now_node.get_depth() // 3
            chain_block_ids = [
                callback.on_chain_start(depth=now_depth, inputs=temp_now_node.messages)
                for callback in self.callbacks
            ]
            agent_block_ids = []
            self.llm.change_messages(temp_now_node.messages)
            for callback in self.callbacks:
                callback.on_llm_start(depth=now_depth, messages=temp_now_node.messages)
            new_message, error_code, total_tokens = self.llm.parse(
                self.io_func.functions, process_id=self.process_id)
            self._log_prompt(temp_now_node.messages)
            for callback in self.callbacks:
                callback.on_llm_end(depth=now_depth, response=new_message)
            self.query_count += 1
            self.total_tokens += total_tokens

            # Exclude the diversity message, because it would influence child nodes.
            if delete_former_diversity_message:
                temp_now_node.messages[-1]["valid"] = False

            # Turn the OpenAI-style reply into nodes, like the CoT method.
            assert new_message["role"] == "assistant"
            if new_message.get("content") is not None:
                temp_now_node = self._spawn_child(
                    temp_now_node, "Thought", new_message["content"])
                if error_code != 0:
                    temp_now_node.observation_code = error_code
                    temp_now_node.pruned = True

            if "function_call" in new_message:
                function_name = new_message["function_call"]["name"]
                function_input = new_message["function_call"]["arguments"]
                agent_block_ids = [
                    callback.on_agent_action(
                        depth=now_depth, action=function_name, action_input=function_input)
                    for callback in self.callbacks
                ]
                temp_now_node = self._spawn_child(temp_now_node, "Action", function_name)

                # The Action Input node is built by hand: its io_state must run
                # the tool call before success is checked.
                temp_node = tree_node()
                temp_node.node_type = "Action Input"
                temp_node.description = function_input
                child_io_state = deepcopy(temp_now_node.io_state)
                child_io_state.retriever = None

                for callback in self.callbacks:
                    callback.on_tool_start(
                        depth=now_depth, tool_name=temp_now_node.description, tool_input=function_input)
                observation, status = child_io_state.step(
                    action_name=temp_now_node.description, action_input=function_input)
                if status == 1:
                    print(observation)
                temp_node.observation = observation
                temp_node.observation_code = status

                temp_node.io_state = child_io_state
                temp_node.is_terminal = child_io_state.check_success() != 0
                temp_node.messages = deepcopy(temp_now_node.messages)
                temp_node.father = temp_now_node
                temp_now_node.children.append(temp_node)
                temp_node.print(self.process_id)
                temp_now_node = temp_node
                for callback in self.callbacks:
                    callback.on_tool_end(depth=now_depth, output=observation, status=status)

                # Return codes are defined in Downstream_tasks/rapid_api.
                if status == 4:
                    temp_now_node.pruned = True
                elif status == 1:  # hallucinated API name
                    os.makedirs('output', exist_ok=True)
                    with open('output/hallucination.txt', 'a') as hallucination_log:
                        print(new_message["function_call"]["name"], file=hallucination_log)
                    new_message["function_call"]["name"] = "invalid_hallucination_function_name"
                elif status == 3:  # final answer
                    temp_now_node.is_terminal = True
                    temp_now_node.make_finish(final_answer_back_length)

            # NOTE(review): if the reply had neither content nor a function_call,
            # temp_now_node is still now_node here, so the reply is appended to
            # now_node itself and, for an "Action Input" now_node (e.g. the root),
            # the function-message lookup below raises KeyError. Confirm whether
            # self.llm.parse can return such a reply.
            temp_now_node.messages.append(new_message)
            if temp_now_node.node_type == "Action Input":
                temp_now_node.messages.append({
                    "role": "function",
                    "name": new_message["function_call"]["name"],
                    "content": temp_now_node.observation,
                })
            if self.query_count >= max_query_count:
                # Soft budget: remind the model to finish rather than aborting the search.
                temp_now_node.messages.append({
                    "role": "user",
                    "content": "you have reached the maximum query count, please call the finish function to give the answer or give up without restart.",
                })

            return_value = None
            if not with_filter:  # DFSDT: descend immediately (pre-order)
                result = self.DFS(temp_now_node, single_chain_max_step,
                                  tree_beam_size, max_query_count, answer, with_filter)
                if len(self.terminal_node) >= answer:
                    return_value = 10000
                elif result > 1:
                    return_value = result - 1
            else:
                next_tree_split_nodes.append(temp_now_node)
            self.send_agent_chain_end(now_depth, agent_block_ids, chain_block_ids)
            if return_value is not None:
                return return_value

        # Normal DFS: with several children, ask the LLM to rank them and expand
        # the best first. Note that this costs extra LLM calls.
        if len(next_tree_split_nodes) > 1:
            LLM_rank_args = {
                "functions": self.io_func.functions,
                "process_id": self.process_id,
                "task_description": self.io_func.task_description,
                "rank_func": rank2_subfix,
            }
            scores, rank_query_count, total_tokens = sum_based_rankn(
                self.llm, LLM_rank_args=LLM_rank_args, candidates=next_tree_split_nodes)
            self.query_count += rank_query_count
            self.total_tokens += total_tokens
            for score, node in zip(scores, next_tree_split_nodes):
                node.prior_score = score
            # Highest score first; the sort is stable, so ties keep generation order.
            next_tree_split_nodes = sorted(
                next_tree_split_nodes, key=lambda node: node.prior_score, reverse=True)

        # Expand the children in order until one asks to backtrack further.
        for node in next_tree_split_nodes:
            result = self.DFS(
                node, single_chain_max_step, tree_beam_size, max_query_count, answer)
            if len(self.terminal_node) >= answer:
                return 10000
            if result > 1:
                now_node.make_finish(2)
                return result - 1

        return 1
|