AnyTool/toolbench/inference/Algorithms/DFS.py
2024-02-23 15:13:06 +08:00

406 lines
19 KiB
Python

import re
from toolbench.inference.Tree.Tree import my_tree, tree_node
from toolbench.inference.Prompts.ReAct_prompts import FORMAT_INSTRUCTIONS_SYSTEM_FUNCTION, FORMAT_INSTRUCTIONS_SYSTEM_FUNCTION_ADAPTED, FORMAT_INSTRUCTIONS_USER_FUNCTION
from toolbench.inference.Prompts.Tree_search_prompts import DIVERSITY_PROMPT
from toolbench.inference.Algorithms.base_search import base_search_method
from copy import deepcopy
from toolbench.inference.LLM_rank.rank_candidate import sum_based_rankn, rank2_subfix
import json
import random
import os
from arguments import parse_args
args = parse_args()
file_name = 'llama_io.txt'
class DFS_tree_search(base_search_method):
def __init__(self, llm, io_func, process_id=0, callbacks=None):
super(DFS_tree_search, self).__init__(
llm, io_func, process_id, callbacks)
"""Depth-first search.
with_filter=True: Every time a child node is generated, choose the best multiple iterations to go.
with_filter=False: Do as Preorder traversal.
"""
self.io_func = io_func
self.llm = llm
self.process_id = process_id
self.restart()
self.callbacks = callbacks if callbacks is not None else []
def restart(self):
self.status = 0
self.terminal_node = []
self.give_up_node = []
self.now_expand_num = 0
self.query_count = 0
self.total_tokens = 0
def send_agent_chain_end(self, depth, agent_block_ids, chain_block_ids):
for i in range(len(self.callbacks)):
callback = self.callbacks[i]
callback.on_chain_end(
depth=depth,
block_id=chain_block_ids[i]
)
if i < len(agent_block_ids):
callback.on_agent_end(
depth=depth,
block_id=agent_block_ids[i]
)
def to_json(self, answer=False, process=True):
if process:
json_obj = {
"win": self.status == 1,
"tree": self.tree.to_json_recursive(),
"forward_args": self.forward_args,
"compare_candidates": [],
}
for node in self.terminal_node:
if node.pruned == False: # has answer
json_obj["compare_candidates"].append(
node.get_chain_result_from_this_node(use_messages=False))
else:
json_obj = {}
if answer:
json_obj["answer_generation"] = {
"valid_data": False,
"query_count": self.query_count,
"total_tokens": self.total_tokens,
"final_answer": "",
"finish_type": "give_answer",
"function": self.io_func.functions,
"chain": [],
}
for node in self.terminal_node:
if node.pruned == False:
if 'give_up' in node.description.lower():
json_obj["answer_generation"]["finish_type"] = "give_up"
else:
json_obj["answer_generation"]["finish_type"] = "give_answer"
json_obj["answer_generation"]["final_answer"] = node.description
json_obj["answer_generation"]["valid_data"] = True
json_obj["answer_generation"]["train_messages"] = node.get_train_messages_from_this_node(
)
break
# do not have final answer, look for give_up
if json_obj["answer_generation"]["valid_data"] == False:
if len(self.give_up_node) > 0:
random_pos = random.randint(0, len(self.give_up_node) - 1)
choose_give_up_node = self.give_up_node[random_pos]
json_obj["answer_generation"]["valid_data"] = True
json_obj["answer_generation"]["finish_type"] = "give_up"
json_obj["answer_generation"]["final_answer"] = choose_give_up_node.description
json_obj["answer_generation"]["train_messages"] = choose_give_up_node.get_train_messages_from_this_node()
return json_obj
def start(self, single_chain_max_step, tree_beam_size, max_query_count, answer=1, with_filter=True, messages=None):
""" single_chain_max_step: The maximum depth of the tree
tree_beam_size: How many children nodes for one node are generated per layer
answer = n means the Algo exits when find n "give_answer" nodes
max_query_count: the Algo exits when OpenAI-query exists this value
with_filter: This is the difference between normal DFS(with_filter=True) and DFSDT(with_filter=False).
"""
self.forward_args = locals()
if "self" in self.forward_args.keys():
# self.forward_args.pop("self")
self.forward_args=None
self.tree = my_tree()
self.tree.root.node_type = "Action Input"
self.tree.root.io_state = deepcopy(self.io_func)
if args.use_original_prompt:
system = FORMAT_INSTRUCTIONS_SYSTEM_FUNCTION
else:
system = FORMAT_INSTRUCTIONS_SYSTEM_FUNCTION_ADAPTED
system = system.replace("{task_description}",
self.io_func.task_description)
user = FORMAT_INSTRUCTIONS_USER_FUNCTION
user = user.replace("{input_description}",
self.io_func.input_description)
if messages is None:
self.tree.root.messages.append({"role": "system", "content": system})
self.tree.root.messages.append({"role": "user", "content": user})
else:
messages[0] = {"role": "system", "content": system}
messages.pop()
function_names = []
for function in self.io_func.functions:
function_names.append(function["name"])
for i, message in reversed(list(enumerate(messages))):
if message["role"] == "function":
if message["name"] not in function_names:
messages.pop(i)
messages.pop(i-1)
if message["role"] == "user":
if 'maximum query count' in message["content"]:
messages.pop(i)
self.tree.root.messages = messages
print('#'*100, file=open(file_name,'a') )
return self.DFS(self.tree.root, single_chain_max_step, tree_beam_size, max_query_count, answer, with_filter)
def DFS(self, now_node, single_chain_max_step, tree_beam_size, max_query_count, answer, with_filter=True):
"""Returns the number of grids to go back. When a child node of a node generates a final answer or give up, it should go back a few more grids
In a sense, the larger this value is, the more diverse it is, and it is GreedySearch@n when it is enlarged to infinity.
"""
# this two value declares the rate to go back, Algo degrades to CoT when the value=Inf
if args.solver == 'dfs':
final_answer_back_length = 2
prune_back_length = 2
else:
final_answer_back_length = 10000
prune_back_length = 10000
now_node.expand_num = self.now_expand_num
self.now_expand_num += 1
if now_node.get_depth() >= single_chain_max_step or now_node.pruned or now_node.is_terminal:
if now_node.is_terminal: # final answer
self.status = 1
self.terminal_node.append(now_node)
return final_answer_back_length
else:
now_node.pruned = True
if now_node.observation_code == 4:
self.give_up_node.append(now_node)
return prune_back_length
else:
return 1
next_tree_split_nodes = []
for i in range(tree_beam_size):
temp_now_node = now_node
"""If a node have children now, We will prompt the model to generate different nodes than all the existing nodes"""
delete_former_diversity_message = False
diversity_message = None
if len(temp_now_node.children) > 0:
former_candidates_des = ""
js_list = []
for k, child in enumerate(temp_now_node.children):
temp_node = child
while not temp_node.is_terminal and temp_node.node_type != "Action Input" and len(temp_node.children) > 0:
temp_node = temp_node.children[0]
if temp_node.node_type == "Action Input":
obj_dict = {
"name": temp_node.father.description,
"arguments": temp_node.description,
"function_output": temp_node.observation,
"mento-carlo-action-value": temp_node.compute_weight(),
}
js_list.append(obj_dict)
if len(js_list) > 0:
former_candidates_des = former_candidates_des + \
f"{json.dumps(js_list,indent=2)}\n"
if temp_now_node.observation != "":
former_candidates_des = former_candidates_des + \
f"again, your former observation: {temp_now_node.observation}\n"
diverse_prompt = DIVERSITY_PROMPT
diverse_prompt = diverse_prompt.replace(
"{previous_candidate}", former_candidates_des)
diversity_message = {
"role": "user", "content": diverse_prompt}
temp_now_node.messages.append(diversity_message)
delete_former_diversity_message = True
# on_chain_start
now_depth = temp_now_node.get_depth() // 3
chain_block_ids = [callback.on_chain_start(
depth=now_depth,
inputs=temp_now_node.messages
) for callback in self.callbacks]
agent_block_ids = []
self.llm.change_messages(temp_now_node.messages)
# on_llm_start
[callback.on_llm_start(
depth=now_depth,
messages=temp_now_node.messages
) for callback in self.callbacks]
new_message, error_code, total_tokens = self.llm.parse(
self.io_func.functions, process_id=self.process_id)
# print('-'*100, file=open(file_name,'a'))
# print('input', file=open(file_name,'a') )
a = deepcopy(temp_now_node.messages)
for aa in a:
# pprint.pprint(get_pretty_print(json.dumps(aa, indent=4)))
if 'function_call' in aa:
aa['function_call'] = {}
print(json.dumps(aa, indent=4), file=open(file_name,'a'))
# print('output', file=open(file_name,'a') )
# print(new_message, file=open(file_name,'a') )
# on_llm_end
[callback.on_llm_end(
depth=now_depth,
response=new_message
) for callback in self.callbacks]
self.query_count += 1
self.total_tokens += total_tokens
# if self.query_count >= max_query_count: # a big return value will cause the Algo to exit
# return 100000
# We need to exclude the diversity_message, because it will influence child nodes
if delete_former_diversity_message:
temp_now_node.messages[-1]["valid"] = False
# parse nodes from OpenAI-message like CoT method
assert new_message["role"] == "assistant"
if "content" in new_message.keys() and new_message["content"] != None:
temp_node = tree_node()
temp_node.node_type = "Thought"
temp_node.description = new_message["content"]
child_io_state = deepcopy(temp_now_node.io_state)
child_io_state.retriever=None
temp_node.io_state = child_io_state
temp_node.is_terminal = child_io_state.check_success() != 0
temp_node.messages = deepcopy(temp_now_node.messages)
temp_node.father = temp_now_node
temp_now_node.children.append(temp_node)
temp_node.print(self.process_id)
temp_now_node = temp_node
if error_code != 0:
temp_now_node.observation_code = error_code
temp_now_node.pruned = True
if "function_call" in new_message.keys():
# on_agent_action
agent_block_ids = [callback.on_agent_action(
depth=now_depth,
action=new_message["function_call"]["name"],
action_input=new_message["function_call"]["arguments"]
) for callback in self.callbacks]
function_name = new_message["function_call"]["name"]
temp_node = tree_node()
temp_node.node_type = "Action"
temp_node.description = function_name
child_io_state = deepcopy(temp_now_node.io_state)
child_io_state.retriever=None
temp_node.io_state = child_io_state
temp_node.is_terminal = child_io_state.check_success() != 0
temp_node.messages = deepcopy(temp_now_node.messages)
temp_node.father = temp_now_node
temp_now_node.children.append(temp_node)
temp_node.print(self.process_id)
temp_now_node = temp_node
function_input = new_message["function_call"]["arguments"]
temp_node = tree_node()
temp_node.node_type = "Action Input"
temp_node.description = function_input
child_io_state = deepcopy(temp_now_node.io_state)
child_io_state.retriever=None
# on_tool_start
[callback.on_tool_start(
depth=now_depth,
tool_name=temp_now_node.description,
tool_input=function_input
) for callback in self.callbacks]
observation, status = child_io_state.step(
action_name=temp_now_node.description, action_input=function_input)
if status == 1:
print(observation)
temp_node.observation = observation
temp_node.observation_code = status
temp_node.io_state = child_io_state
temp_node.is_terminal = child_io_state.check_success() != 0
temp_node.messages = deepcopy(temp_now_node.messages)
temp_node.father = temp_now_node
temp_now_node.children.append(temp_node)
temp_node.print(self.process_id)
temp_now_node = temp_node
# on_tool_end
[callback.on_tool_end(
depth=now_depth,
output=observation,
status=status
) for callback in self.callbacks]
if status != 0:
# return code defination can be seen in Downstream_tasks/rapid_api
if status == 4:
temp_now_node.pruned = True
elif status == 1: # hallucination api name
assert "function_call" in new_message.keys()
os.makedirs('output', exist_ok=True)
print(new_message["function_call"]["name"], file=open('output/hallucination.txt','a'))
new_message["function_call"]["name"] = "invalid_hallucination_function_name"
elif status == 3: # final answer
temp_now_node.is_terminal = True
temp_now_node.make_finish(final_answer_back_length)
temp_now_node.messages.append(new_message)
if temp_now_node.node_type == "Action Input":
temp_now_node.messages.append({
"role": "function",
"name": new_message["function_call"]["name"],
"content": temp_now_node.observation,
})
if self.query_count >= max_query_count: # a big return value will cause the Algo to exit
temp_now_node.messages.append({
"role": "user",
"content": "you have reached the maximum query count, please call the finish function to give the answer or give up without restart.",
})
return_value = None
if not with_filter: # DFSDT
result = self.DFS(temp_now_node, single_chain_max_step,
tree_beam_size, max_query_count, answer, with_filter)
if len(self.terminal_node) >= answer:
return_value = 10000
elif result > 1:
return_value = result-1
else:
next_tree_split_nodes.append(temp_now_node)
self.send_agent_chain_end(
now_depth, agent_block_ids, chain_block_ids)
if return_value is not None:
return return_value
# Sort the generated next_tree_split_nodes nodes when normal DFS
if len(next_tree_split_nodes) > 1:
# When using normal DFS, if we have many child nodes, we will refer to LLM to compare and choose the best one to expand first
# remember, this operator will cost extra OpenAI calls.
LLM_rank_args = {
"functions": self.io_func.functions,
"process_id": self.process_id,
"task_description": self.io_func.task_description,
"rank_func": rank2_subfix,
}
scores, rank_query_count, total_tokens = sum_based_rankn(
self.llm, LLM_rank_args=LLM_rank_args, candidates=next_tree_split_nodes)
self.query_count += rank_query_count
self.total_tokens += total_tokens
for score, node in zip(scores, next_tree_split_nodes):
node.prior_score = score
zip_value = list(
zip(next_tree_split_nodes, range(len(next_tree_split_nodes))))
zip_value.sort(
key=lambda x: x[0].prior_score, reverse=True) # 先做score高的
next_tree_split_nodes, filtered_order = zip(*zip_value)
# if self.process_id == 0:
# print(f"score={scores}, filtered order: {filtered_order}")
'''
Choose one to expand
'''
for i in range(len(next_tree_split_nodes)):
result = self.DFS(
next_tree_split_nodes[i], single_chain_max_step, tree_beam_size, max_query_count, answer)
if len(self.terminal_node) >= answer:
return 10000
elif result > 1:
now_node.make_finish(2)
return result - 1
return 1