from anytool.api_database_function import * import json import os from anytool.prompt_template import * from anytool.verifier import check_task_solvable_by_function, check_task_solvable, check_solved_toolbench, check_task_complete from termcolor import colored from openai_utils import call_gpt import threading from threading import Thread, Semaphore import time import numpy as np from arguments import parse_args args = parse_args() output_dir = args.output_dir raise_error = False max_api_number = args.max_api_number sem = Semaphore(16) # 允许同时运行的最大线程数为16 class DoNothingContextManager: def __enter__(self): pass def __exit__(self, exc_type, exc_val, exc_tb): pass leaf_tool_number = args.leaf_tool_number multi_thread = True if multi_thread: counter_lock = threading.Lock() else: counter_lock = DoNothingContextManager() def Finish(): """Finish the conversation""" return 'finished' def remove_apis(api_list): """remove apis from the current available api list. required input to be list of dictionaries describing with the keys category_name, tool_name, api_name""" print(colored(f'removing apis: {api_list}', 'red')) if len(api_list) == 0: return 'empty api list' if isinstance(api_list, str): api_list = eval(api_list) if not isinstance(api_list, list) or any('category_name' not in ele or 'tool_name' not in ele or 'api_name' not in ele for ele in api_list): return 'illegal input, input should be list, each element in the list should have category_name, tool_name, api_name' if not all([isinstance(ele['category_name'],str) and isinstance(ele['tool_name'],str) and isinstance(ele['api_name'],str) for ele in api_list]): return 'illegal input, category_name, tool_name, api_name should be string' origin_api_list = deepcopy(api_list) # for api in origin_api_list: # self.api_list.remove(api) global global_api_list, global_api_list_detailed for api in api_list: # api.update(get_api_details(api['category_name'], api['tool_name'], api['api_name'])) tool_details = 
get_tool_description(api['category_name'], api['tool_name']) api_details = get_api_details(**api) api['tool_description'] = tool_details['tool_description'] if isinstance(tool_details, dict) else '' api['api_description'] = api_details['description'] if 'description' in api_details else '' try: with counter_lock: if api in global_api_list: global_api_list.remove(api) except: pass for api in origin_api_list: for ele in global_api_list: if ele['category_name'] == api['category_name'] and ele['tool_name'] == api['tool_name'] and ele['api_name'] == api['api_name']: with counter_lock: global_api_list.remove(ele) break return f'APIs removed successfully. Current API number: {len(global_api_list)}. Max API number: {max_api_number}' class Agent(object): def __init__(self) -> None: self.failed_reason = None self.messages = [] self.depth = 0 self.index = 0 self.finish_search = False self.sub_agents = [] def check_if_request_solvable(): global stop, status, total_tokens, call_cnt if stop: return 'Current APIs already sufficient to solve the query.' t_s = time.time() solvable, reason, tokens = check_task_solvable_by_function(query, global_api_list_detailed) total_tokens += tokens call_cnt += 1 print(time.time() - t_s, file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8')) if solvable != 'Unsolvable': with counter_lock: stop = True status = 'The current API list can solve the query.' return f'Current API number: {len(global_api_list)}. Max API number: {max_api_number}' else: with counter_lock: status = f'The current API list cannot solve the query due to the following reason: {reason}' if len(global_api_list) >= max_api_number: with counter_lock: stop = True return f'Current API number: {len(global_api_list)}. Max API number: {max_api_number}. 
The current API list cannot solve the query due to the following reason: {reason}' class Category_Agent(Agent): def __init__(self, query, category=None) -> None: super().__init__() self.category = category self.tools = get_tools_in_category(self.category) self.query = query self.info = f'category: {self.category} assigned' self.api_mapping = { "query_all_categories": query_all_categories, "retrieve_context": retrieve_context, "Finish": Finish, "get_tools_descriptions": get_tools_descriptions, "create_agent_tool_level": self.create_agent_tool_level, } self.functions = [ get_tools_descriptions_function, finish_function ] self.tools = get_tools_in_category(self.category) def resume_search(self): """Assign a category to an agent""" global call_cnt, total_tokens, stop, error_flag if stop or total_tokens > 200000: self.finish_search = True if multi_thread: sem.release() return f'category: {self.category} assigned' print(colored(f'assigning category: {self.category}', 'green')) if len(self.tools) <= leaf_tool_number: self.finish_search = True return f'category: {self.category} assigned' if self.failed_reason is not None: self.messages.append({"role": "user", "content": REFIND_TOOL_PROMPT.replace('{failed_reason}', str(self.failed_reason))}) self.failed_reason = None for i in range(20): if stop or total_tokens > 200000: if multi_thread: sem.release() return f'category: {self.category} assigned' t_s = time.time() try: response = call_gpt( messages=self.messages, functions=self.functions ) except: error_flag = True stop = True continue with counter_lock: total_tokens += response.usage.total_tokens call_cnt += 1 print(time.time() - t_s, file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8')) if isinstance(response, str): continue tool_calls = response.choices[0].message.tool_calls print('Thought:', response.choices[0].message.content) if tool_calls is not None: print('tool call number', len(tool_calls)) # print('message', response.choices[0].message) if tool_calls: 
self.messages.append({ "role": "assistant", "tool_calls": tool_calls, "content": response.choices[0].message.content if response.choices[0].message.content is not None else '', }) for tool_call in tool_calls: function_name = tool_call.function.name function_args = tool_call.function.arguments print('function call:', function_name, function_args) if function_name == 'get_tools_in_category': self.query_tools_call = True if function_name.lower() == 'finish': print(colored(f'category: {self.category} assigned', 'green')) print(colored('category finish search', 'green')) self.messages.append( { "tool_call_id": tool_call.id, "role": "tool", "name": function_name, "content": 'Finished', }) self.finish_search = True if multi_thread: sem.release() return f'category: {self.category} assigned.' else: return f'category: {self.category} assigned. The status of current found apis is: {status}' elif function_name not in self.api_mapping: function_name = 'hullucinating_function_name' tool_call.function.name = function_name function_call_result = "Function name error" self.messages.append( { "tool_call_id": tool_call.id, "role": "tool", "name": function_name, "content": str(function_call_result), } ) else: try: function_call_result = self.api_mapping[function_name](**json.loads(function_args)) if function_name in ['get_apis_in_tool'] and isinstance(function_call_result, str) and 'Illegal tool' in function_call_result: function_call_result = f'Illegal tool. 
The tool should be in the tool list {self.tools}' except Exception as e: print(e, function_name, function_args, file=open(f'{output_dir}/error.txt', 'a', encoding='utf-8')) if raise_error: raise e function_call_result = 'input format error' self.messages.append( { "tool_call_id": tool_call.id, "role": "tool", "name": function_name, "content": str(function_call_result), }) print('function response:', function_call_result) else: # continue self.messages.append({ "role": "assistant", "content": response.choices[0].message.content if response.choices[0].message.content is not None else '', }) self.messages.append({'role': "user", 'content': 'At each step, you should call a function to actually excute your step.'}) print(colored(f'category: {self.category} assigned', 'green')) self.finish_search = True if multi_thread: sem.release() return f'category: {self.category} assigned.' else: return f'category: {self.category} assigned. The status of current found apis is: {status}' def category_search(self): """Assign a category to an agent""" print(colored(f'assigning category: {self.category}', 'green')) self.tools = get_tools_in_category(self.category) if len(self.tools) > leaf_tool_number: self.functions.append(create_agent_tool_level_function) self.messages = [{ "role": "system", "content": CATEGORY_AGENT_PROMPT.replace('{category}', self.category)}, {"role": "user", "content": f"Task description: {self.query}. All the tools: {self.tools}. 
Begin!"}] else: function_call_result = self.create_agent_tool_level(self.category, self.tools) return f'category: {self.category} assigned' global total_tokens, call_cnt, stop, error_flag for i in range(20): if stop or total_tokens > 200000: if multi_thread: sem.release() return f'category: {self.category} assigned' t_s = time.time() try: response = call_gpt( messages=self.messages, functions=self.functions ) except: error_flag = True stop = True continue print(time.time() - t_s, file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8')) if isinstance(response, str): continue with counter_lock: total_tokens += response.usage.total_tokens call_cnt += 1 tool_calls = response.choices[0].message.tool_calls print('Thought:', response.choices[0].message.content) if tool_calls is not None: print('tool call number', len(tool_calls)) if tool_calls: self.messages.append( { "role": "assistant", "tool_calls": tool_calls, "content": response.choices[0].message.content if response.choices[0].message.content is not None else '', } ) for tool_call in tool_calls: function_name = tool_call.function.name function_args = tool_call.function.arguments print('function call:', function_name, function_args) if function_name.lower() == 'finish': print(colored(f'category: {self.category} assigned', 'green')) print(colored('category finish search', 'green')) self.messages.append( { "tool_call_id": tool_call.id, "role": "tool", "name": function_name, "content": 'Finished', }) self.finish_search = True if multi_thread: sem.release() return f'category: {self.category} assigned.' else: return f'category: {self.category} assigned. 
The status of current found apis is: {status}' elif function_name not in self.api_mapping: function_name = 'hullucinating_function_name' tool_call.function.name = function_name function_call_result = "Function name error" else: if function_name == "create_agent_tool_level": print(colored('create_agent_tool_level', 'green')) try: # if True: function_call_result = self.api_mapping[function_name](**json.loads(function_args)) if function_name in ['get_apis_in_tool'] and isinstance(function_call_result, str) and 'Illegal tool' in function_call_result: function_call_result = f'Illegal tool. The tool should be in the tool list {self.tools}' except Exception as e: print(e, function_name, function_args, file=open(f'{output_dir}/error.txt', 'a', encoding='utf-8')) if raise_error: raise e function_call_result = 'input format error' self.messages.append( { "tool_call_id": tool_call.id, "role": "tool", "name": function_name, "content": str(function_call_result), }) print('function response:', function_call_result) else: self.messages.append({ "role": "assistant", "content": response.choices[0].message.content if response.choices[0].message.content is not None else '', }) self.messages.append({'role': "user", 'content': 'At each step, you should call a function to actually excute your step.'}) print(colored(f'category: {self.category} assigned', 'green')) self.finish_search = True if multi_thread: sem.release() return f'category: {self.category} assigned.' else: return f'category: {self.category} assigned. 
The status of current found apis is: {status}' def create_agent_tool_level(self, category: str, tools): """Assign a subset of tools in a category to a agent""" if isinstance(tools, str): tools = eval(tools) illegal_tools = [] for tool in tools: if tool not in self.tools: illegal_tools.append(tool) if len(illegal_tools) > 0: print(colored(f'Illegal tools: {illegal_tools} in category: {category} assigned', 'red')) return f'Illegal tools: {illegal_tools} in category: {category} assigned' if len(tools) > leaf_tool_number: return f'Tool number should not exceed the max tool number of {leaf_tool_number}. Please assign again' global tree with counter_lock: tree[category][str(tools)] = {} global agents, index with counter_lock: agents.append(Tool_Agent(self.query, category, tools)) agents[-1].depth = self.depth + 1 index += 1 agents[-1].index = index self.sub_agents.append(agents[-1]) # yield from agents[-1].tool_search() global threads if multi_thread: thread = threading.Thread(target=agents[-1].tool_search) sem.acquire() thread.start() with counter_lock: threads.append(thread) else: agents[-1].tool_search() if multi_thread: return f'tools {tools} assigned.' else: return f'tools {tools} assigned. 
The status of current found apis is: {status}' class Tool_Agent(Agent): def __init__(self, query, category=None, tools=None) -> None: super().__init__() self.category = category if isinstance(tools, str): tools = eval(tools) self.tools = tools self.functions = [finish_function] self.query = query if isinstance(tools, str): tools = eval(tools) if len(tools) > leaf_tool_number: return f"you should assign less than {leaf_tool_number} tools each time" else: self.functions.extend([ # get_api_details_function, # get_apis_in_tool_function, # get_tool_details_function, check_if_request_solvable_function, # remove_apis_function, add_apis_into_api_pool_function, ]) tools_info = query_all_tool_info(category, tools) self.messages = [{ "role": "system", "content": TOOL_AGENT_PROMPT.replace('{category}', str(category)).replace('{tools}', str(tools))}, {"role": "user", "content": f"Task description: {self.query} All the tool description and the contained api_list as a dict: {tools_info}. Begin!"}] self.api_mapping = { "query_all_categories": query_all_categories, "get_tools_in_category": get_tools_in_category, # "get_apis_in_tool": get_apis_in_tool, "Finish": Finish, # "get_api_details": get_api_details, "create_agent_tool_level": self.create_agent_tool_level, "add_apis_into_api_pool": self.add_apis_into_api_pool, "check_if_request_solvable": check_if_request_solvable, # "remove_apis": self.remove_apis, } def remove_apis(self, api_list): """remove apis from the current available api list. 
required input to be list of dictionaries describing with the keys category_name, tool_name, api_name""" print(colored(f'removing apis: {api_list}', 'red')) if isinstance(api_list, str): api_list = eval(api_list) if not isinstance(api_list, list) or any('category_name' not in ele or 'tool_name' not in ele or 'api_name' not in ele for ele in api_list): return 'illegal input, input should be list, each element in the list should have category_name, tool_name, api_name' if not all([isinstance(ele['category_name'],str) and isinstance(ele['tool_name'],str) and isinstance(ele['api_name'],str) for ele in api_list]): return 'illegal input, category_name, tool_name, api_name should be string' origin_api_list = deepcopy(api_list) global global_api_list, global_api_list_detailed for api in api_list: tool_details = get_tool_description(self.category, api['tool_name']) api_details = get_api_details(**api) api['tool_description'] = tool_details['tool_description'] if isinstance(tool_details, dict) else '' api['api_description'] = api_details['description'] if 'description' in api_details else '' try: with counter_lock: if api in global_api_list: global_api_list.remove(api) except: pass for api in origin_api_list: for ele in global_api_list: if ele['category_name'] == api['category_name'] and ele['tool_name'] == api['tool_name'] and ele['api_name'] == api['api_name']: with counter_lock: global_api_list.remove(ele) break return f'apis removed successfully. Current api number: {len(global_api_list)}. 
Max api number: {max_api_number}' def create_agent_tool_level(self, category: str, tools): """Assign a subset of tools in a category to a agent""" if isinstance(tools, str): tools = eval(tools) illegal_tools = [] for tool in tools: if tool not in self.tools: illegal_tools.append(tool) if len(illegal_tools) > 0: print(colored(f'Illegal tools: {illegal_tools} in category: {category} assigned', 'red')) return f'Illegal tools: {illegal_tools} in category: {category} assigned' global tree with counter_lock: tree[category][str(tools)] = {} global agents, index with counter_lock: agents.append(Tool_Agent(self.query, category, tools)) agents[-1].depth = self.depth + 1 index += 1 agents[-1].index = index self.sub_agents.append(agents[-1]) # generator = agents[-1].tool_search() global threads if multi_thread: thread = threading.Thread(target=agents[-1].tool_search) sem.acquire() thread.start() with counter_lock: threads.append(thread) else: agents[-1].tool_search() if multi_thread: return f'tools {tools} assigned.' else: return f'tools {tools} assigned. The status of current found apis is: {status}' def add_apis_into_api_pool(self, api_list): """add apis to the current available api list. required input to be list of dictionaries describing with the keys category_name, tool_name, api_name""" print(colored(f'adding apis: {api_list}', 'red')) global global_api_list, global_api_list_detailed, stop, status if len(global_api_list) + len(api_list) > max_api_number: return f'API number exceeds the max API number of {max_api_number}, current API number: {len(global_api_list)}, number of APIs to be added: {len(api_list)}. Please reduce the APIs to be added.' 
if isinstance(api_list, str): api_list = eval(api_list) # if len(api_list) > 2: # return 'too many apis to add, please add less than 2 apis each time' if not isinstance(api_list, list) or any('category_name' not in ele or 'tool_name' not in ele or 'api_name' not in ele for ele in api_list): return 'illegal input, input should be list, each element in the list should have category_name, tool_name, api_name' if not all([isinstance(ele['category_name'],str) and isinstance(ele['tool_name'],str) and isinstance(ele['api_name'],str) for ele in api_list]): return 'illegal input, category_name, tool_name, api_name should be string' # with counter_lock: # for api in deepcopy(api_list): # with counter_lock: # if api not in global_api_list: # global_api_list.append(api) # if stop: # return 'adding apis failed. Current apis already sufficient to solve the query. Please add again later.' # with counter_lock: for api in api_list: tool_details = get_tool_description(self.category, api['tool_name']) if tool_details == 'tool name not found': continue if api not in global_api_list: global_api_list.append(deepcopy(api)) api_details = get_api_details(**api) api['tool_description'] = tool_details['tool_description'] if isinstance(tool_details, dict) else '' api['api_description'] = api_details['description'] if 'description' in api_details else '' if api not in global_api_list_detailed: global_api_list_detailed.append(api) if not stop: t_s = time.time() solvable, reason, tokens = check_task_solvable_by_function(self.query, global_api_list_detailed) global total_tokens, call_cnt total_tokens += tokens call_cnt += 1 print(time.time() - t_s, file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8')) if solvable != 'Unsolvable': stop = True status = 'The current api list can solve the query.' return f'APIs added. Current API number: {len(global_api_list)}. Max API number: {max_api_number}' # return 'apis added. The current api list can solve the query. 
If you think you have finished, call the Finish function.' else: status = f'The current API list cannot solve the query due to the following reason: {reason}' if len(global_api_list) >= max_api_number: stop = True # return f'apis added. Current api number: {len(global_api_list)}. Max api number: {max_api_number}' return f'APIs added. Current API number: {len(global_api_list)}. Max API number: {max_api_number}.' # return f'apis added. Current api number: {len(global_api_list)}. Max api number: {max_api_number}. The current api list cannot solve the query due to the following reason: {reason} Please find apis more purposely.' return f'APIs added. Current API number: {len(global_api_list)}. Max API number: {max_api_number}' def resume_search(self): if stop or total_tokens > 200000: self.finish_search = True if multi_thread: sem.release() print(f'tools {self.tools} assigned') return f'tools {self.tools} assigned' # self.functions.append(remove_apis_function) if self.failed_reason is not None: if len(self.tools) > leaf_tool_number: self.messages.append({"role": "user", "content": REFIND_TOOL_PROMPT.replace('{failed_reason}', str(self.failed_reason))}) else: self.messages.append({"role": "user", "content": REFIND_API_PROMPT.replace('{failed_reason}', str(self.failed_reason))}) self.failed_reason = None return self.tool_search() def tool_search(self): global stop, total_tokens, call_cnt, error_flag print(colored(f'assigning tools: {self.tools} in category: {self.category}', 'blue')) for i in range(20): if stop or total_tokens > 200000: print('#'*100) print(colored('stop', 'red')) if multi_thread: sem.release() return f'tools {self.tools} assigned' t_s = time.time() try: response = call_gpt( messages=self.messages, functions=self.functions ) except: error_flag = True stop = True continue print(time.time() - t_s, file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8')) if isinstance(response, str): continue with counter_lock: total_tokens += response.usage.total_tokens 
call_cnt += 1 tool_calls = response.choices[0].message.tool_calls print('Thought:', response.choices[0].message.content) if tool_calls is not None: print('tool call number', len(tool_calls)) if tool_calls: # self.messages.append(response.choices[0].message) self.messages.append( { "role": "assistant", "tool_calls": tool_calls, "content": response.choices[0].message.content if response.choices[0].message.content is not None else '', } ) for tool_call in tool_calls: function_name = tool_call.function.name function_args = tool_call.function.arguments print('function call:', function_name, function_args) if function_name.lower() == 'finish': self.messages.append( { "tool_call_id": tool_call.id, "role": "tool", "name": function_name, "content": 'Finished', }) print(f'tools {self.tools} assigned') print(colored('tool finish search', 'green')) self.finish_search = True if multi_thread: sem.release() return f'tools {self.tools} assigned' else: return f'tools {self.tools} assigned. The status of current found apis is: {status}' if function_name not in self.api_mapping: function_name = 'hullucinating_function_name' tool_call.function.name = function_name function_call_result = "Function name error" elif function_name == 'add_apis_into_api_pool': with counter_lock: try: function_call_result = self.api_mapping[function_name](**json.loads(function_args)) except Exception as e: print(e, function_name, function_args, file=open(f'{output_dir}/error.txt', 'a', encoding='utf-8')) if raise_error: raise e function_call_result = 'input format error' else: try: function_call_result = self.api_mapping[function_name](**json.loads(function_args)) if function_name in ['get_apis_in_tool'] and isinstance(function_call_result, str) and 'Illegal tool' in function_call_result: function_call_result = f'Illegal tool. 
The tool should be in the tool list {self.tools}' except Exception as e: print(e, function_name, function_args, file=open(f'{output_dir}/error.txt', 'a', encoding='utf-8')) if raise_error: raise e function_call_result = 'input format error' self.messages.append( { "tool_call_id": tool_call.id, "role": "tool", "name": function_name, "content": str(function_call_result), }) print('function response:', function_call_result) else: # continue self.messages.append({ "role": "assistant", "content": response.choices[0].message.content if response.choices[0].message.content is not None else '', }) self.messages.append({'role': "user", 'content': 'At each step, you should call a function to actually excute your step.'}) print(f'tools {self.tools} assigned') self.finish_search = True if multi_thread: sem.release() return f'tools {self.tools} assigned' else: return f'tools {self.tools} assigned. The status of current found apis is: {status}' class Main_Search_Agent(Agent): def __init__(self, query) -> None: super().__init__() self.categories = [] self.query = query self.api_mapping = { "query_all_categories": query_all_categories, "get_tools_in_category": get_tools_in_category, "get_apis_in_tool": get_apis_in_tool, # "retrieve_context": retrieve_context, "Finish": Finish, "get_api_details": get_api_details, # "locate_api": locate_api, # "query_tool_details": query_tool_details, "get_tools_descriptions": get_tools_descriptions, "create_agent_category_level": self.create_agent_category_level, } self.functions = [ # get_categories_function.to_json_schema(), # get_tools_in_category_function.to_json_schema(), # locate_api_function, get_tools_in_category_function, get_tools_descriptions_function, create_agent_category_level_function, # retrieve_context_function, ] self.functions.append(finish_function) self.messages = [{ "role": "system", "content": META_AGENT_PROMPT.replace('{categories}', str(query_all_categories()))}, {"role": "user", "content": f"Task description: {query}.\ 
Please determine relevant categories and assign them use the create_agent_category_level function. Begin!"}] # All the categories and the contained tools as a dictionary: {all_cates_all_tools} # "content": f"Task description: {query}. All the categories as well as the contained tools and their descriptions: {category_tool_info}\ def create_agent_category_level(self, category): """Assign a category to an agent""" # print(colored(f'assigning category: {category}', 'green')) global agents, tree, index if category in self.categories: print(colored(f'category: {category} already assigned', 'green')) return f'category: {category} already assigned' with counter_lock: tree[category] = {} if not isinstance(category, str): return f'Error: category: {category} is not str' if category not in query_all_categories(): return f'category: {category} not in database' self.categories.append(category) with counter_lock: agents.append(Category_Agent(self.query, category)) index += 1 agents[-1].depth = self.depth + 1 agents[-1].index = index self.sub_agents.append(agents[-1]) if multi_thread: thread = threading.Thread(target=agents[-1].category_search) sem.acquire() thread.start() with counter_lock: threads.append(thread) else: agents[-1].category_search() if multi_thread: return f'category: {category} assigned.' else: return f'category: {category} assigned. 
The status of current found apis is: {status}' def resume_search(self): if stop or total_tokens > 200000: self.finish_search = True if multi_thread: sem.release() return self.categories if self.failed_reason is not None: self.messages.append({"role": "user", "content": REFIND_CATEGORY_PROMPT.replace('{failed_reason}', str(self.failed_reason))}) self.failed_reason = None return self.assign_main(self.query) def assign_main(self, query): global total_tokens, stop, error_flag, call_cnt for i in range(20): if stop or total_tokens > 200000: if multi_thread: sem.release() return self.categories t_s = time.time() try: response = call_gpt( messages=self.messages, functions=self.functions ) except: error_flag = True stop = True continue print(time.time() - t_s, file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8')) if isinstance(response, str): print(response) print('response is str') continue with counter_lock: total_tokens += response.usage.total_tokens call_cnt += 1 print('#'*100) tool_calls = response.choices[0].message.tool_calls print('Thought:', response.choices[0].message.content) if tool_calls is not None: print('tool call number', len(tool_calls)) if tool_calls: self.messages.append( { "role": "assistant", "tool_calls": tool_calls, "content": response.choices[0].message.content if response.choices[0].message.content is not None else '', } ) for tool_call in tool_calls: function_name = tool_call.function.name function_args = tool_call.function.arguments if function_name.lower() == 'finish': self.messages.append( { "tool_call_id": tool_call.id, "role": "tool", "name": function_name, "content": 'Finished', }) self.finish_search = True print(colored('main finish search', 'green')) if multi_thread: sem.release() return self.categories if function_name not in self.api_mapping: function_name = 'hullucinating_function_name' tool_call.function.name = function_name function_call_result = "Function name error" else: if function_name == "retrieve_context" and 'query' not 
in function_args: function_call_result = self.api_mapping[function_name](query, **json.loads(function_args)) else: try: function_call_result = self.api_mapping[function_name](**json.loads(function_args)) if function_name in ['get_apis_in_tool'] and isinstance(function_call_result, str) and 'Illegal tool' in function_call_result: function_call_result = f'Illegal tool. The tool should be in the tool list {self.tools}' except Exception as e: print(e, function_name, function_args, file=open(f'{output_dir}/error.txt', 'a', encoding='utf-8')) if raise_error: raise e function_call_result = 'input format error' self.messages.append( { "tool_call_id": tool_call.id, "role": "tool", "name": function_name, "content": str(function_call_result), } ) print('function call:', function_name, function_args) print('function response:', function_call_result) else: self.messages.append({ "role": "assistant", "content": response.choices[0].message.content if response.choices[0].message.content is not None else '', }) self.messages.append({'role': "user", 'content': 'At each step, you should call a function to actually excute your step.'}) self.finish_search = True if multi_thread: sem.release() return self.categories create_agent_category_level_function = { 'name': 'create_agent_category_level', 'description': 'Assign a category to an agent', 'parameters': { 'type': 'object', 'properties': { 'category': {'type': 'string'} }, 'required': ['category'] } } create_agent_tool_level_function = { 'name': 'create_agent_tool_level', 'description': 'Assign a subset of tools in a category to an agent', 'parameters': { 'type': 'object', 'properties': { 'category': {'type': 'string'}, 'tools': { 'type': 'array', 'items': {'type': 'string'} } }, 'required': ['category', 'tools'] } } finish_function = { "name": "Finish", "description": "If you think you have finished, call this function.", "parameters": { "type": "object", 'properties': { } } } import time from anytool.dfs_gt import 
solve_given_api_main
# Driver configuration, taken from the command-line arguments parsed at import time.
output_dir = args.output_dir
query_path = args.query_path
# Script entry point. For every query in `query_path`:
#   1. start a Main_Search_Agent that searches for candidate APIs and collects
#      them into the module-global `global_api_list`;
#   2. try to answer the query with the collected APIs (`solve_given_api_main`);
#   3. on failure, prune APIs named in the failure reason, resume the deepest
#      unfinished agents, and repeat until the query is judged 'Solved', the
#      token budget is exceeded, or the API list reaches `max_api_number`.
# Per-query results are written as JSON under `output_dir`; running counters
# and timings are appended to text files there.
#
# NOTE(review): this code runs at module scope on purpose. `stop`, `status`,
# `total_tokens`, `call_cnt`, `global_api_list` and `global_api_list_detailed`
# are module globals that `check_if_request_solvable` / `remove_apis` rebind via
# `global`. Moving this loop into a function would silently break that shared state.
#
# NOTE(review): the `print(..., file=open(...))` and `json.load(open(...))`
# pattern used throughout never closes its file handles explicitly; closing is
# left to garbage collection.
if __name__ == "__main__":
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs('output', exist_ok=True)
    success_cnt = 0
    unsolvable_task_cnt = 0
    # Query ids marked unsolvable by hand. They are presumably integers (see the
    # `int(query_id)` membership test below). The file is rewritten de-duplicated
    # and sorted.
    unsolvable_list = json.load(open('misc/unsolvable.json', 'r', encoding='utf-8'))
    json.dump(sorted(list(set(unsolvable_list))), open('misc/unsolvable.json', 'w', encoding='utf-8'), indent=4)
    total_cnt = 0
    query_data_all = json.load(open(query_path, 'r', encoding='utf-8'))
    for query_data in query_data_all:
        try:
            query_id = query_data['query_id']
            query = query_data['query']
            # Reset all per-query state. Several of these names are module globals
            # that other functions in this file read or rebind, so they must be
            # re-bound here at module scope rather than kept in locals.
            threads = []
            global_api_list = []
            global_api_list_detailed = []
            call_cnt = 0
            total_tokens = 0
            solve_tokens = 0
            agents = []
            index = 0
            failed_reason = None
            stop = False
            error_flag = False
            status = ''
            solved = False
            check_solved = 'Unsolved'
            # NOTE(review): nothing in this block populates `tree`. `parse_tree`
            # below uses its own `tree` parameter, so `assign_results['tree']` is
            # always written as {}.
            tree = {}
            result_list = []
            reason_list = []
            assign_results = {}
            assign_results['api_list'] = []
            assign_results['stop'] = []
            ts = time.time()
            resumed_agents = []
            # Skip hand-labelled unsolvable queries unless explicitly included.
            if not args.include_unsolvable and int(query_id) in unsolvable_list:
                unsolvable_task_cnt += 1
                print('Unsolvable human', unsolvable_task_cnt, success_cnt, total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
                continue
            total_cnt += 1
            task_solvable = 'Solvable'
            solvable_reason = 'Solvable checked by human'
            # Resume support: if a previous run already produced a result for this
            # query, reuse it (optionally re-checking it) instead of re-running.
            if os.path.exists(f'{output_dir}/{query_id}.json'):
                assign_results = json.load(open(f'{output_dir}/{query_id}.json', 'r', encoding='utf-8'))
                if 'last_solve_time' in assign_results:
                    solved = assign_results['solved']
                    check_solved = assign_results['check_solved']
                    last_solve_time = assign_results['last_solve_time']
                    if args.recheck_solved:
                        check_solved, reason, _ = check_solved_toolbench(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', assign_results['query_id'])
                        assign_results['check_solved'] = check_solved
                        assign_results['reason'] = reason
                        json.dump(assign_results, open(f'{output_dir}/{query_id}.json', 'w', encoding='utf-8'), indent=4)
                    api_list = assign_results['api_list'][-1]
                    # NOTE(review): `api2origin` is loaded but never used below.
                    api2origin = json.load(open(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', 'r', encoding='utf-8'))['api2origin']
                    # A stored success only counts if it stayed within the API budget.
                    if check_solved == 'Solved' and len(api_list) <= max_api_number:
                        success_cnt += 1
                        print(query_id, check_solved, unsolvable_task_cnt, success_cnt, total_cnt, success_cnt/total_cnt)
                    # Only queries whose previous run timed out are re-run.
                    if assign_results['result'] != 'Timeout':
                        continue
                    # continue
            print(f'query: {query}')
            # `flag` becomes True once at least one solve attempt has been made.
            flag = False
            runner = Main_Search_Agent(query)
            agents.append(runner)
            if multi_thread:
                thread = threading.Thread(target=runner.assign_main, args=(query,))
                # Bound concurrency with the module-level semaphore. The matching
                # release presumably happens when the agent's method finishes.
                sem.acquire()
                thread.start()
                threads.append(thread)
            else:
                iter_func = runner.assign_main(query)
            messages = None
            cnt = 0
            solve_data = {}
            if multi_thread:
                # Re-scan `threads` until a full pass finds no live thread. The list
                # is re-read each pass, presumably because agents append threads for
                # sub-agents to it while we wait (confirm in the agent classes).
                while True:
                    thread_num = len(threads)
                    has_thread_alive = False
                    for thread in threads:
                        if thread.is_alive():
                            has_thread_alive = True
                            thread.join()
                    if not has_thread_alive:
                        break
            if error_flag:
                raise Exception('GPT Call Error')
            threads = []
            # refind: alternate between solve attempts and resuming unfinished agents.
            # Reset so a value restored from a previous run cannot end the loop
            # before it starts.
            check_solved = ''
            max_depth = max([agent.depth for agent in agents])
            # Continue while some agent is still searching, and always run at least
            # until the first solve attempt has been made.
            while not all([agent.finish_search for agent in agents]) or not flag:
                if check_solved == 'Solved':
                    break
                max_depth = max([agent.depth for agent in agents])
                # Walk up from the deepest level to the deepest level that still has
                # an unfinished agent. depth < 0 means every agent has finished.
                depth = max_depth
                while all([agent.finish_search for agent in agents if agent.depth == depth]) and depth >= 0:
                    depth -= 1
                if depth < 0 and flag:
                    break
                agents_to_resume = [agent for agent in agents if not agent.finish_search and agent.depth == depth]
                # Token budget, enforced only after the first solve attempt.
                if total_tokens > 200000 and flag:
                    solved = False
                    check_solved = 'Timeout'
                    solve_data = {'result': 'Timeout'}
                    break
                cnt += 1
                failed_reason = None
                print('#'*100)
                # Snapshot the API list seen at this iteration.
                assign_results['api_list'].append(deepcopy(global_api_list))
                # Solve when the API list was judged sufficient (`stop`), on the very
                # first pass (`not flag`), or when every agent has finished and at
                # least one API was collected. Note that `and` binds tighter than `or`.
                if stop or not flag or all([agent.finish_search for agent in agents]) and len(global_api_list) > 0:
                    flag = True
                    last_solve_time = cnt
                    t_s = time.time()
                    selected_api_list = deepcopy(global_api_list)
                    solved, solve_data = solve_given_api_main(query, selected_api_list, f'{query_id}_{cnt}', messages)
                    print('solve time:', time.time() - t_s, 'api number:', len(global_api_list),file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8'))
                    result_list.append(deepcopy(solve_data['result']))
                    print(solve_data['result'], solved)
                    # An answer containing any exclusion word is treated as unsolved
                    # without calling the checker. `exclusion_words` is not defined in
                    # the visible code; presumably it comes from a star import.
                    if not solved or any([word in solve_data['result']['final_answer'] for word in exclusion_words]):
                        check_solved = 'Unsolved'
                        reason = solve_data['result']
                    else:
                        check_solved, reason, tokens = check_solved_toolbench(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', query_id, task_solvable, solvable_reason)
                        total_tokens += tokens
                    print(colored((check_solved, reason), 'red'))
                    failed_reason = reason
                    dfs_data = json.load(open(f'{output_dir}/{query_id}_{last_solve_time}_DFS_woFilter_w2.json', 'r', encoding='utf-8'))
                    total_tokens += dfs_data['answer_generation']['total_tokens']
                    solve_tokens += dfs_data['answer_generation']['total_tokens']
                    if check_solved == 'Solved':
                        break
                    # Carry the last training conversation into the next solve attempt.
                    try:
                        messages = dfs_data['answer_generation']['train_messages'][-1]
                    except:
                        messages = None
                    # Prune every API whose standardized name appears in the failure
                    # reason and that is still in the global list.
                    api_list_to_prune = []
                    for standardized_api_name, origin_api in dfs_data['api2origin'].items():
                        if standardized_api_name in str(failed_reason):
                            if origin_api in global_api_list:
                                api_list_to_prune.append(origin_api)
                    print(colored(api_list_to_prune, 'red'))
                    remove_apis(api_list_to_prune)
                    # Give up once the API budget is still exhausted after pruning.
                    if len(global_api_list) >= max_api_number:
                        break
                    # Let the agents search again before the next solve attempt.
                    stop = False
                else:
                    # NOTE(review): `check_if_request_solvable` sets
                    # 'The current API list can solve the query.' (uppercase "API").
                    # This literal uses "api", so the assert can never fail.
                    assert status != 'The current api list can solve the query.'
                    failed_reason = status
                reason_list.append(failed_reason)
                print(colored('Refind Begin', 'red'))
                print(colored(agents_to_resume, 'red'))
                print([agent.finish_search for agent in agents_to_resume])
                threads = []
                resume_cnt = 0
                resumed_agents.append([(str(a), a.index) for a in agents_to_resume])
                # Resume the unfinished agents at the selected depth (in reverse
                # order), handing each the latest failure reason as feedback.
                for agent in reversed(agents_to_resume):
                    if agent.finish_search:
                        continue
                    resume_cnt += 1
                    agent.failed_reason = str(failed_reason)
                    print(colored(('resuming', agent, agent.depth), 'red'))
                    print(colored(('resuming', agent, agent.depth), 'red'), file=open(f'{output_dir}/resume.txt', 'a', encoding='utf-8'))
                    if multi_thread:
                        thread = threading.Thread(target=agent.resume_search)
                        sem.acquire()
                        thread.start()
                        threads.append(thread)
                    else:
                        agent.resume_search()
                if multi_thread:
                    # Join repeatedly until the number of threads stops changing
                    # between passes.
                    while True:
                        thread_num = len(threads)
                        for thread in threads:
                            if thread.is_alive():
                                thread.join()
                        if thread_num == len(threads):
                            break
                if error_flag:
                    raise Exception('GPT Call Error')
                # Ask whether the current API list now suffices. This may set
                # `stop` (and `status`), which triggers a solve attempt next pass.
                if not stop:
                    check_if_request_solvable()
                print(colored(f'status:{status}', 'red'))
                assign_results['stop'].append(stop)
            # Collect the per-query summary for `{output_dir}/{query_id}.json`.
            assign_results['api_complete'] = flag
            find_messages = []
            for agent in agents:
                find_messages.append([str(agent), agent.depth, agent.messages])
            assign_results['tree'] = tree
            assign_results['max_depth'] = max_depth
            assign_results['query'] = query
            assign_results['find_messages'] = find_messages
            assign_results['status'] = status
            assign_results['solved'] = solved
            assign_results['query_id'] = query_id
            assign_results['finish_search'] = [agent.finish_search for agent in agents]
            assign_results['flag'] = flag
            if check_solved == 'Solved':
                success_cnt += 1
            else:
                print(output_dir, 'failed', file=open(f'{output_dir}/failed.txt', 'a', encoding='utf-8'))
            assign_results['loop_times'] = cnt
            assign_results['last_solve_time'] = last_solve_time
            if 'messages' in solve_data:
                assign_results['solve_messages'] = solve_data['messages']
            def parse_tree(node, tree):
                """Recursively serialize the agent hierarchy rooted at `node`.

                Returns `tree` with an entry keyed by str(node) holding the agent's
                index and its children (built the same way). Category and tool count
                are recorded for every agent except the root Main_Search_Agent.
                """
                tree[str(node)] = {}
                if not isinstance(node, Main_Search_Agent):
                    tree[str(node)]['category'] = node.category
                    tree[str(node)]['tools'] = len(node.tools)
                tree[str(node)]['index'] = node.index
                tree[str(node)]['children'] = {}
                for agent in node.sub_agents:
                    tree[str(node)]['children'].update(parse_tree(agent, {}))
                return tree
            agent_tree = parse_tree(runner, {})
            tree_results = {}
            tree_results['agent_tree'] = agent_tree
            tree_results['resume_agents'] = resumed_agents
            tree_results['result_list'] = result_list
            tree_results['reason_list'] = reason_list
            json.dump(tree_results, open(f'{output_dir}/{query_id}_agent_tree.json', 'w', encoding='utf-8'), indent=4)
            assign_results['resume_agents'] = resumed_agents
            assign_results['result_list'] = result_list
            assign_results['reason'] = reason
            assign_results['reason_list'] = reason_list
            assign_results['call_cnt'] = call_cnt
            assign_results['total_tokens'] = total_tokens
            assign_results['solve_tokens'] = solve_tokens
            if 'result' in solve_data:
                assign_results['result'] = solve_data['result']
            assign_results['check_solved'] = check_solved
            json.dump(assign_results, open(f'{output_dir}/{query_id}.json', 'w', encoding='utf-8'), indent=4)
            print(check_solved, total_tokens, time.time() - ts, query_path, file=open(f'{output_dir}/time.txt', 'a', encoding='utf-8'))
            print(query_id, check_solved, success_cnt, total_cnt, success_cnt/total_cnt, file=open(f'{output_dir}/success_cnt.txt', 'a', encoding='utf-8'))
        # NOTE(review): any exception skips the query after printing only its
        # message. This includes 'GPT Call Error' above and KeyError/IndexError
        # from the result files. No traceback is kept and no result JSON is written.
        except Exception as e:
            print(e)
            continue