clean
This commit is contained in:
parent
268f1c0c18
commit
ef1a552417
9 changed files with 30 additions and 203 deletions
26
README.md
26
README.md
|
|
@ -18,9 +18,9 @@ pip install -r requirements.txt
|
|||
|
||||
**OPENAI API config and the ToolBench key**
|
||||
|
||||
Fill your OpenAI GPT-4 API config and toolbench key into the config.py (see config_example.py).
|
||||
Fill your OpenAI GPT-4 API config and toolbench key into the config.py (see config_example.py). We use Azure OpenAI for all our experiments. You can modify it according to your own configuration.
|
||||
|
||||
Fill out the [form](https://docs.google.com/forms/d/e/1FAIpQLSdqHypmYanWU8ZhuUcrEuM5eFB03WqaqYJzvKUxUe1HzUBB3A/viewform?usp=send_form) to get the toolbench key.
|
||||
Fill out the [form](https://docs.google.com/forms/d/e/1FAIpQLSdqHypmYanWU8ZhuUcrEuM5eFB03WqaqYJzvKUxUe1HzUBB3A/viewform?usp=send_form) to get the toolbench key. If you want to use your own RapidAPI key, you can put your key in the rapidapi_key_list.json
|
||||
|
||||
**ToolBench**
|
||||
|
||||
|
|
@ -67,10 +67,28 @@ python scripts/anytoolbench_generation.py --output_path atb_data/anytoolbench_ne
|
|||
|
||||
We provide sample data in [anytoolbench.json](./atb_data/anytoolbench.json) file.
|
||||
|
||||
The data looks like
|
||||
```json
|
||||
"query": "Can you provide detailed information about \"The Incredible Hulk\" movie that was released in 2008, including its plot, genres, and how it's evaluated by audiences, and also tell me the current timezone for Los Angeles, USA?",
|
||||
"final_answer": "The Incredible Hulk (2008) is about scientist Bruce Banner who searches for an antidote to his unbridled rage, the Hulk, but faces new foes when forced back to civilization. GENRES: Sci-Fi, Action, Adventure. AUDIENCE SCORE: 6.2/10. The current timezone for Los Angeles, USA, is America/Los_Angeles.",
|
||||
"query_id": "1000006",
|
||||
"gt_api_list": [
|
||||
{
|
||||
"category_name": "Movies",
|
||||
"tool_name": "Advanced Movie Search",
|
||||
"api_name": "Search by Name"
|
||||
},
|
||||
{
|
||||
"category_name": "Location",
|
||||
"tool_name": "Timezone By API-Ninjas",
|
||||
"api_name": "/v1/timezone"
|
||||
}
|
||||
],
|
||||
|
||||
```
|
||||
|
||||
|
||||
# 🚗 Run AnyTool
|
||||
Fill your OpenAI GPT API config and toolbench key into the config.py (see config_example.py). We use Azure OpenAI for all our experiments. You can modify it according to your own configuration.
|
||||
|
||||
Experiment on ToolBench, take G1-I as an example.
|
||||
```
|
||||
|
|
@ -83,6 +101,8 @@ export PYTHONPATH=./
|
|||
python scripts/main.py --output_dir result/anytoolbench --query_path anytoolbench.json -max_api_number 64
|
||||
```
|
||||
|
||||
The pass rate can be found in the success_cnt.txt under the output directory.
|
||||
|
||||
# 👨🏫 Acknowledgement
|
||||
This repo is built on [ToolBench](https://github.com/OpenBMB/ToolBench).
|
||||
|
||||
|
|
|
|||
|
|
@ -2,10 +2,8 @@ import json
|
|||
from copy import deepcopy
|
||||
from autogen.retrieve_utils import TEXT_FORMATS
|
||||
# from openai_function_calling import FunctionInferer
|
||||
import autogen
|
||||
# import autogen
|
||||
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
|
||||
import chromadb
|
||||
import openai
|
||||
import random
|
||||
import re
|
||||
import os
|
||||
|
|
|
|||
|
|
@ -4,14 +4,12 @@ from typing import List, Dict, Any
|
|||
import re
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
import requests
|
||||
from termcolor import colored
|
||||
from copy import deepcopy
|
||||
from anytool.api_database_function import *
|
||||
from anytool.verifier import check_solved_toolbench
|
||||
import os
|
||||
from anytool.rapidapi import pipeline_runner
|
||||
import openai
|
||||
import json
|
||||
|
||||
class dotdict(dict):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import time
|
|||
import requests
|
||||
from tqdm import tqdm
|
||||
from termcolor import colored
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from toolbench.inference.LLM.chatgpt_function_model import ChatGPTFunction, GPT4Function
|
||||
from toolbench.inference.LLM.davinci_model import Davinci
|
||||
|
|
|
|||
|
|
@ -1,8 +1,5 @@
|
|||
import openai
|
||||
from openai_function_calling import FunctionInferer
|
||||
import json
|
||||
from anytool.prompt_template import *
|
||||
from tenacity import retry, wait_random_exponential, stop_after_attempt
|
||||
from concurrent.futures import ThreadPoolExecutor,as_completed
|
||||
from openai_utils import call_gpt
|
||||
import time
|
||||
|
|
@ -12,13 +9,9 @@ from anytool.check_solved import compute_pass_rate, process_invalid_data, proces
|
|||
import os
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
import importlib
|
||||
args = parse_args()
|
||||
output_dir = args.output_dir
|
||||
|
||||
|
||||
|
||||
|
||||
def Finish(answer:str, reason:str=None):
|
||||
"""Finish the conversation"""
|
||||
return answer, reason
|
||||
|
|
|
|||
|
|
@ -32,3 +32,4 @@ sentence_transformers
|
|||
pypdf
|
||||
chromadb
|
||||
IPython
|
||||
pyautogen
|
||||
|
|
@ -10,7 +10,7 @@ import requests
|
|||
from termcolor import colored
|
||||
import random
|
||||
from anytool.api_database_function import *
|
||||
from server import get_rapidapi_response
|
||||
from toolbench.inference.server import get_rapidapi_response
|
||||
import tiktoken
|
||||
from copy import deepcopy
|
||||
from anytool.verifier import check_task_complete, check_task_solved
|
||||
|
|
@ -712,7 +712,9 @@ if __name__ == '__main__':
|
|||
query = result['query']
|
||||
answer = result['answer']
|
||||
plan = result['plan']
|
||||
solved, reason = check_task_solved(data['query'], data['final_answer'])
|
||||
else:
|
||||
continue
|
||||
solved, reason = check_task_solved(query, answer)
|
||||
if solved != 'Solved':
|
||||
continue
|
||||
break
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import openai
|
||||
from anytool.api_database_function import *
|
||||
import json
|
||||
import os
|
||||
|
|
@ -9,7 +8,6 @@ from openai_utils import call_gpt
|
|||
import threading
|
||||
from threading import Thread, Semaphore
|
||||
import time
|
||||
import collections
|
||||
import numpy as np
|
||||
from arguments import parse_args
|
||||
args = parse_args()
|
||||
|
|
|
|||
182
server.py
182
server.py
|
|
@ -1,182 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
import json
|
||||
import os
|
||||
from typing import Union
|
||||
from toolbench.utils import standardize, change_name
|
||||
import random
|
||||
|
||||
|
||||
class Info(BaseModel):
|
||||
category: str
|
||||
tool_name: str
|
||||
api_name: str
|
||||
tool_input: Union[str, dict]
|
||||
strip: str
|
||||
|
||||
def prepare_tool_name_and_url(tools_root, info):
|
||||
category = info.category
|
||||
standard_category = category.replace(" ", "_").replace(",", "_").replace("/", "_")
|
||||
while " " in standard_category or "," in standard_category:
|
||||
standard_category = standard_category.replace(" ", "_").replace(",", "_")
|
||||
standard_category = standard_category.replace("__", "_")
|
||||
|
||||
tool_name = info.tool_name
|
||||
api_name = change_name(standardize(info.api_name))
|
||||
if not tool_name.endswith(f"_for_{standard_category}"):
|
||||
tool_name = standardize(info.tool_name)
|
||||
code_string = f"""from {tools_root}.{standard_category}.{tool_name}.api import {api_name}"""
|
||||
tool_name += f"_for_{standard_category}"
|
||||
else:
|
||||
tmp_tool_name = standardize(tool_name.replace(f"_for_{standard_category}", ""))
|
||||
code_string = f"""from {tools_root}.{standard_category}.{tmp_tool_name}.api import {api_name}"""
|
||||
return tool_name, standard_category, api_name, code_string
|
||||
|
||||
def process_error(response):
|
||||
save_cache_flag = False
|
||||
switch_flag = False
|
||||
if "The request to the API has timed out. Please try again later, or if the issue persists" in str(response):
|
||||
return_dict = {"error": "API temporarily not working error...", "response": response}
|
||||
|
||||
if "Your Client (working) ---> Gateway (working) ---> API (not working)" in str(response):
|
||||
return_dict = {"error": "API not working error...", "response": response}
|
||||
|
||||
elif "Unauthorized" in str(response) or "unauthorized" in str(response):
|
||||
save_cache_flag = True
|
||||
return_dict = {"error": "Unauthorized error...", "response": response}
|
||||
|
||||
elif "You are not subscribed to this API." in str(response):
|
||||
switch_flag = True
|
||||
return_dict = {"error": "Unsubscribed error...", "response": response}
|
||||
|
||||
elif "Too many requests" in str(response):
|
||||
switch_flag = True
|
||||
return_dict = {"error": "Too many requests error...", "response": response}
|
||||
|
||||
elif "You have exceeded" in str(response) or "you are being rate limited" in str(response):
|
||||
switch_flag = True
|
||||
return_dict = {"error": "Rate limit error...", "response": response}
|
||||
|
||||
elif "Access restricted. Check credits balance or enter the correct API key." in str(response):
|
||||
switch_flag = True
|
||||
return_dict = {"error": "Rate limit error...", "response": response}
|
||||
|
||||
elif "Oops, an error in the gateway has occurred." in str(response):
|
||||
switch_flag = True
|
||||
return_dict = {"error": "Gateway error...", "response": response}
|
||||
|
||||
elif "Blocked User. Please contact your API provider." in str(response):
|
||||
switch_flag = True
|
||||
return_dict = {"error": "Blocked error...", "response": response}
|
||||
|
||||
elif "error" in str(response):
|
||||
return_dict = {"error": "Message error...", "response": response}
|
||||
|
||||
else:
|
||||
save_cache_flag = True
|
||||
return_dict = {"error": "", "response": response}
|
||||
return return_dict, save_cache_flag, switch_flag
|
||||
|
||||
def run(toolbench_code_string, toolbench_api_name, toolbench_input_params_str):
|
||||
# get observation
|
||||
success_flag = False
|
||||
switch_flag = False
|
||||
save_cache = False
|
||||
# print('#'*100)
|
||||
print(toolbench_code_string, file=open("output/log.txt", "a"))
|
||||
# from data.toolenv.tools.Data.refactor_numbers_in_human_readable_form_like_1k_or_1m.api import number
|
||||
# from data.toolenv.tools.Data.get_twitter_mentions.api import getmentions
|
||||
exec(toolbench_code_string)
|
||||
# print('*'*100)
|
||||
try:
|
||||
eval_func_str = f"{toolbench_api_name}({toolbench_input_params_str})"
|
||||
new_func = eval(eval_func_str)
|
||||
response, save_cache, switch_flag = process_error(new_func)
|
||||
success_flag = True
|
||||
except Exception as e:
|
||||
response = {"error": f"Function executing {toolbench_code_string} error...\n{e}", "response": ""}
|
||||
save_cache = False
|
||||
return success_flag, switch_flag, response, save_cache
|
||||
|
||||
|
||||
def dict_shorten(origin: dict, schema: dict):
|
||||
for key, value in list(origin.items()):
|
||||
if key not in schema:
|
||||
del origin[key]
|
||||
else:
|
||||
if isinstance(value, dict):
|
||||
dict_shorten(value, schema[key]) # schema[key] should be a dict
|
||||
elif isinstance(value, list):
|
||||
if value:
|
||||
if isinstance(value[0], dict):
|
||||
for item in value:
|
||||
dict_shorten(item, schema[key][0]) # schema[key] should be a list with only one dict element
|
||||
return origin
|
||||
|
||||
def observation_shorten(schema_root, response_dict, category, tool_name, api_name, strip_method):
|
||||
print(random.random())
|
||||
if strip_method == "filter" or (strip_method == "random" and random.random() > 0.5):
|
||||
if isinstance(response_dict["response"], dict):
|
||||
if os.path.exists(os.path.join(schema_root, category)):
|
||||
if os.path.exists(os.path.join(schema_root, category, tool_name+".json")):
|
||||
schema_dicts = json.load(open(os.path.join(schema_root, category, tool_name+".json"), "r"))
|
||||
api_list = schema_dicts["api_list"]
|
||||
schema = None
|
||||
for schema_dict in api_list:
|
||||
schema_api_name = change_name(standardize(schema_dict["name"]))
|
||||
if schema_api_name == api_name and len(schema_dict["schema"]) > 0:
|
||||
schema = schema_dict["schema"]
|
||||
break
|
||||
if schema is not None:
|
||||
response_dict["response"] = dict_shorten(response_dict["response"], schema)
|
||||
return str(response_dict["response"])
|
||||
|
||||
|
||||
def get_rapidapi_response(input_dict: dict, api_customization: bool=False, tools_root: str="data.toolenv.tools", schema_root: str="data/toolenv/response_examples"):
|
||||
info = Info
|
||||
info.category = input_dict['category']
|
||||
info.tool_name = input_dict['tool_name']
|
||||
info.api_name = input_dict['api_name']
|
||||
info.tool_input = input_dict['tool_input']
|
||||
info.strip = input_dict['strip']
|
||||
rapidapi_key = input_dict['rapidapi_key']
|
||||
|
||||
tool_name, standard_category, api_name, code_string = prepare_tool_name_and_url(tools_root, info)
|
||||
tool_input = info.tool_input
|
||||
|
||||
strip_method = info.strip
|
||||
|
||||
try:
|
||||
tool_input = json.loads(tool_input)
|
||||
except Exception as e:
|
||||
if tool_input == "":
|
||||
tool_input = {}
|
||||
else:
|
||||
print(f"Can not parse tool input into json: {tool_input}")
|
||||
response_dict = {"error": f"Tool input parse error...\n", "response": ""}
|
||||
return response_dict
|
||||
|
||||
input_params_str = ""
|
||||
if len(tool_input) > 0:
|
||||
for key, value in tool_input.items():
|
||||
if isinstance(value, str):
|
||||
input_params_str += f'{key}="{value}", '
|
||||
else:
|
||||
input_params_str += f'{key}={value}, '
|
||||
if not api_customization:
|
||||
input_params_str += f"toolbench_rapidapi_key='{rapidapi_key}'"
|
||||
success_flag, switch_flag, response_dict, save_cache = run(code_string, api_name, input_params_str)
|
||||
observation = observation_shorten(schema_root, response_dict, standard_category, tool_name.replace(f"_for_{standard_category}", ""), api_name, strip_method)
|
||||
result = str(observation)[:2048]
|
||||
return {"error": response_dict['error'], "response": result}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = get_rapidapi_response({
|
||||
"category": "Social",
|
||||
"tool_name": "olato_quotes",
|
||||
"api_name": "love_quote",
|
||||
"tool_input": '{}',
|
||||
"strip": "filter",
|
||||
"rapidapi_key": ""
|
||||
})
|
||||
print(result)
|
||||
Loading…
Reference in a new issue