AnyTool/toolbench/inference/Algorithms/single_chain.py
2024-02-23 15:13:06 +08:00

189 lines
8.1 KiB
Python

import re
from toolbench.inference.Tree.Tree import my_tree, tree_node
from toolbench.inference.Prompts.ReAct_prompts import FORMAT_INSTRUCTIONS_SYSTEM_FUNCTION, FORMAT_INSTRUCTIONS_USER_FUNCTION
from toolbench.inference.Algorithms.base_search import base_search_method
from copy import deepcopy
class single_chain(base_search_method):
    """Chain-of-thought (CoT) search: grow a single chain of nodes per try.

    Each try builds a fresh ``my_tree`` and repeatedly asks ``self.llm`` for the
    next assistant message, turning it into Thought / Action / Action Input
    nodes until a node is pruned or terminal. ``start`` runs up to ``pass_at``
    tries and stops early once ``answer`` tries have succeeded.
    """

    def __init__(self, llm, io_func, extra_prefix="", process_id=0, start_message_list=None):
        """Store the LLM wrapper and environment, then reset the search state.

        Args:
            llm: model wrapper; this class calls its ``change_messages`` and
                ``parse`` methods.
            io_func: environment; this class reads its ``task_description``,
                ``input_description`` and ``functions`` and deep-copies it as
                the root state of every try.
            extra_prefix: stored for the Reflection algorithm; not read here.
            process_id: worker id; the per-try banner is printed only when it
                is 0, and it is forwarded to ``llm.parse`` and node printing.
            start_message_list: optional opening conversation (used by the
                Reflection algorithm) that replaces the default system/user
                prompts.
        """
        super().__init__(llm, io_func, process_id, callbacks=None)
        self.io_func = io_func
        self.llm = llm
        self.extra_prefix = extra_prefix
        self.start_message_list = start_message_list
        self.process_id = process_id
        self.restart()

    def restart(self):
        """Reset all per-search counters and result lists."""
        self.status = 0  # set to 1 once any try succeeds
        self.try_list = []  # to_json_single() snapshot of each try
        self.terminal_node = []  # last node of each try
        self.query_count = 0  # number of interactions with openai
        self.total_tokens = 0
        self.success_count = 0

    def to_json(self, answer=False, process=True):
        """Serialize the search result.

        Args:
            answer: add an ``answer_generation`` section built from the first
                terminal node that was not pruned.
            process: add the per-try details (``win``, ``try_count``,
                ``trys``, ``compare_candidates``, ``forward_args``). This reads
                ``self.forward_args``, which is only set by ``start``.

        Returns:
            A dict with the requested sections (empty if both are False).
        """
        json_obj = {}
        if process:
            json_obj = {
                "win": self.status == 1,
                "try_count": len(self.try_list),
                "trys": self.try_list,
                # An unpruned terminal node is one that produced a final answer.
                "compare_candidates": [
                    node.get_chain_result_from_this_node(use_messages=False)
                    for node in self.terminal_node
                    if not node.pruned
                ],
                "forward_args": self.forward_args,
            }
        if answer:
            answer_generation = {
                "valid_data": False,
                "final_answer": "",
                "function": self.io_func.functions,
                "query_count": self.query_count,
                "total_tokens": self.total_tokens,
                "train_messages": [],
                "chain": [],
            }
            winner = next((node for node in self.terminal_node if not node.pruned), None)
            if winner is not None:
                answer_generation["valid_data"] = True
                answer_generation["final_answer"] = winner.description
                answer_generation["train_messages"] = winner.get_train_messages_from_this_node()
            json_obj["answer_generation"] = answer_generation
        return json_obj

    def to_json_single(self):
        """Summarize the most recent try.

        The nodes are stored as a tree, but ``do_chain`` only ever attaches one
        child to the current tip, so the path to the last terminal node is the
        whole chain.
        """
        return {
            "chain": self.terminal_node[-1].get_chain_result_from_this_node(),
            "win": self.status == 1,
        }

    def start(self, single_chain_max_step, pass_at=1, answer=1):
        """Run up to ``pass_at`` independent tries.

        Args:
            single_chain_max_step: depth at which a non-terminal node is pruned.
            pass_at: maximum number of tries.
            answer: number of successful tries after which to stop early.

        Returns:
            1 as soon as ``answer`` tries have succeeded, otherwise 0.
        """
        self.forward_args = {
            "single_chain_max_step": single_chain_max_step,
            "pass_at": pass_at,
            "answer": answer,
        }
        for attempt in range(pass_at):
            if self.process_id == 0:
                print(f"[single_chain]try for the {attempt + 1} time")
            self.tree = my_tree()
            self.tree.root.node_type = "Action Input"
            self.tree.root.io_state = deepcopy(self.io_func)
            out_node = self.do_chain(self.tree.root, single_chain_max_step)
            self.terminal_node.append(out_node)

            succeeded = out_node.io_state.check_success() == 1
            if succeeded:
                # Update status *before* snapshotting so this try's record
                # reports the success it just achieved.
                self.status = 1
                self.success_count += 1
            self.try_list.append(self.to_json_single())
            if succeeded and self.success_count >= answer:
                return 1
        return 0

    @staticmethod
    def _new_node(node_type, description):
        """Return a fresh, unattached tree_node with the given type and description."""
        node = tree_node()
        node.node_type = node_type
        node.description = description
        return node

    def _attach_child(self, parent, child, io_state=None):
        """Hang ``child`` under ``parent``, print it, and return it as the new tip.

        The child takes ``io_state`` (or a deep copy of the parent's state),
        a shallow copy of the parent's message list, and its terminal flag
        from ``io_state.check_success()``.
        """
        child.io_state = deepcopy(parent.io_state) if io_state is None else io_state
        child.is_terminal = child.io_state.check_success() != 0
        child.messages = parent.messages.copy()
        child.father = parent
        parent.children.append(child)
        child.print(self.process_id)
        return child

    def do_chain(self, now_node, single_chain_max_step):
        """Grow one chain from ``self.tree.root`` until a node is pruned or terminal.

        Args:
            now_node: kept for interface compatibility; the chain always starts
                from ``self.tree.root`` (which ``start`` passes in anyway).
            single_chain_max_step: depth at which a non-terminal node is pruned.

        Returns:
            The last node of the chain (pruned or terminal).
        """
        root = self.tree.root
        if self.start_message_list is None:
            system = FORMAT_INSTRUCTIONS_SYSTEM_FUNCTION.replace(
                "{task_description}", self.io_func.task_description)
            root.messages.append({"role": "system", "content": system})
            user = FORMAT_INSTRUCTIONS_USER_FUNCTION.replace(
                "{input_description}", self.io_func.input_description)
            root.messages.append({"role": "user", "content": user})
        else:
            # In the Reflection algorithm the caller supplies earlier trials and
            # reflections as the opening messages. Copy the list so anything
            # appended to the root cannot leak back into the caller's list,
            # which later tries reuse.
            root.messages = list(self.start_message_list)

        now_node = root
        while True:
            self.llm.change_messages(now_node.messages)
            new_message, error_code, total_tokens = self.llm.parse(
                functions=self.io_func.functions, process_id=self.process_id)
            self.total_tokens += total_tokens
            self.query_count += 1
            assert new_message["role"] == "assistant"

            if new_message.get("content") is not None:
                now_node = self._attach_child(
                    now_node, self._new_node("Thought", new_message["content"]))

            if error_code != 0:
                now_node.observation_code = error_code
                now_node.pruned = True

            # A missing key and an explicit None both mean "no function call".
            function_call = new_message.get("function_call")
            if function_call is not None:
                function_name = function_call["name"]
                now_node = self._attach_child(
                    now_node, self._new_node("Action", function_name))

                function_input = function_call["arguments"]
                action_input = self._new_node("Action Input", function_input)
                io_state = deepcopy(now_node.io_state)
                observation, status = io_state.step(
                    action_name=function_name, action_input=function_input)
                # Set the observation before attaching so it is part of the
                # node when it is printed.
                action_input.observation = observation
                action_input.observation_code = status
                now_node = self._attach_child(now_node, action_input, io_state)

                # Return codes come from Downstream_tasks/rapidapi.
                if status == 4:
                    now_node.pruned = True
                elif status == 1:  # hallucinated API name
                    function_call["name"] = "invalid_hallucination_function_name"

            now_node.messages.append(new_message)
            if function_call is not None:
                # now_node is the Action Input just created: answer the call.
                now_node.messages.append({
                    "role": "function",
                    "name": function_call["name"],
                    "content": now_node.observation,
                })
            # NOTE(review): a reply with neither content nor a function call
            # does not add a node, so the depth budget below cannot end the
            # loop on its own; only an error code or a later reply will.

            if now_node.get_depth() >= single_chain_max_step and not now_node.is_terminal:
                now_node.pruned = True
            if now_node.pruned or now_node.is_terminal:
                return now_node