clean
This commit is contained in:
parent
f5d0c5a322
commit
b584120953
19 changed files with 120 additions and 255051 deletions
48
README.md
48
README.md
|
|
@ -14,19 +14,44 @@ Require Python 3.9+
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
# OpenAI GPT API config
|
# 🔆 Preparation
|
||||||
Fill your OpenAI GPT API config and toolbench key into the config.py (see config_example.py).
|
|
||||||
|
|
||||||
|
**OPENAI API config and the ToolBench key**
|
||||||
|
|
||||||
|
Fill your OpenAI GPT-4 API config and toolbench key into the config.py (see config_example.py).
|
||||||
|
|
||||||
|
Fill out the [form](https://docs.google.com/forms/d/e/1FAIpQLSdqHypmYanWU8ZhuUcrEuM5eFB03WqaqYJzvKUxUe1HzUBB3A/viewform?usp=send_form) to get the toolbench key.
|
||||||
|
|
||||||
# 🔆 Data Preparation
|
|
||||||
**ToolBench**
|
**ToolBench**
|
||||||
|
|
||||||
Refer to [ToolBench](https://github.com/OpenBMB/ToolBench).
|
Download the ToolBench data using the following link: [Google Drive](https://drive.google.com/drive/folders/1yBUQ732mPu-KclJnuQELEhtKakdXFc3J) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/c9e50625743b40bfbe10/).
|
||||||
|
|
||||||
|
The file structure is as follows:
|
||||||
|
```
|
||||||
|
├── /data/
|
||||||
|
│ ├── /instruction/
|
||||||
|
│ ├── /answer/
|
||||||
|
│ ├── /toolenv/
|
||||||
|
│ ├── /retrieval/
|
||||||
|
│ ├── /test_instruction/
|
||||||
|
│ ├── /test_query_ids/
|
||||||
|
│ ├── /retrieval_test_query_ids/
|
||||||
|
│ ├── toolllama_G123_dfs_train.json
|
||||||
|
│ └── toolllama_G123_dfs_eval.json
|
||||||
|
├── /reproduction_data/
|
||||||
|
│ ├── /chatgpt_cot/
|
||||||
|
│ ├── /chatgpt_dfs/
|
||||||
|
│ ├── ...
|
||||||
|
│ └── /toolllama_dfs/
|
||||||
|
```
|
||||||
|
|
||||||
|
For more details, please refer to [ToolBench](https://github.com/OpenBMB/ToolBench).
|
||||||
|
|
||||||
**Prepare the API data**
|
**Prepare the API data**
|
||||||
|
|
||||||
You should prepare the ToolBench data first. Make sure you have the directory of data/toolenv/tools
|
You should prepare the ToolBench data first. Make sure you have the directory of data/toolenv/tools
|
||||||
```
|
```
|
||||||
|
export PYTHONPATH=./
|
||||||
python scripts/extract_api_details.py
|
python scripts/extract_api_details.py
|
||||||
python scripts/extract_category_tool_details.py
|
python scripts/extract_category_tool_details.py
|
||||||
python scripts/extract_tool_database.py
|
python scripts/extract_tool_database.py
|
||||||
|
|
@ -36,23 +61,26 @@ python scripts/extract_tool_database.py
|
||||||
|
|
||||||
Generation script
|
Generation script
|
||||||
```
|
```
|
||||||
python scripts/data_generation_by_gpt4.py
|
export PYTHONPATH=./
|
||||||
|
python scripts/anytoolbench_generation.py --output_path atb_data/anytoolbench_new.json
|
||||||
```
|
```
|
||||||
|
|
||||||
We provide sample data in [anytoolbench.json]() file.
|
We provide sample data in the [anytoolbench.json](./atb_data/anytoolbench.json) file.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 🚗 Run AnyTool
|
# 🚗 Run AnyTool
|
||||||
Fill your OpenAI GPT API config and toolbench key into the config.py (see config_example.py).
|
Fill your OpenAI GPT API config and toolbench key into the config.py (see config_example.py). We use Azure OpenAI for all our experiments. You can modify it according to your own configuration.
|
||||||
|
|
||||||
Experiment on ToolBench
|
Experiment on ToolBench, taking G1-I as an example.
|
||||||
```
|
```
|
||||||
python anytool.py --output_dir result/test_instruction/G1_instruction --query_path data/test_instruction/G1_instruction.json --max_api_number 64
|
export PYTHONPATH=./
|
||||||
|
python scripts/main.py --output_dir result/test_instruction/G1_instruction --query_path data/test_instruction/G1_instruction.json --max_api_number 64
|
||||||
```
|
```
|
||||||
Experiment on AnyToolBench
|
Experiment on AnyToolBench
|
||||||
```
|
```
|
||||||
python anytool.py --output_dir result/anytoolbench --query_path anytoolbench.json --max_api_number 64
|
export PYTHONPATH=./
|
||||||
|
python scripts/main.py --output_dir result/anytoolbench --query_path anytoolbench.json --max_api_number 64
|
||||||
```
|
```
|
||||||
|
|
||||||
# 👨🏫 Acknowledgement
|
# 👨🏫 Acknowledgement
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,8 @@ def parse_args():
|
||||||
parser.add_argument('--max_eval_threads', type=int, default=20, required=False, help='max threads nums')
|
parser.add_argument('--max_eval_threads', type=int, default=20, required=False, help='max threads nums')
|
||||||
parser.add_argument('--evaluate_times', type=int, default=7, required=False, help='how many times to predict with the evaluator for each solution path.')
|
parser.add_argument('--evaluate_times', type=int, default=7, required=False, help='how many times to predict with the evaluator for each solution path.')
|
||||||
parser.add_argument("--query_path", type=str, default='', help="Path to the query directory")
|
parser.add_argument("--query_path", type=str, default='', help="Path to the query directory")
|
||||||
parser.add_argument("--output_dir", type=str, default='', help="Path for the output file")
|
parser.add_argument("--output_dir", type=str, default='./', help="Directory for the output file")
|
||||||
|
parser.add_argument("--output_path", type=str, default='./tmp.json', help="Path for the output file")
|
||||||
parser.add_argument("--check_solvable", action='store_true', default=False, help="check solvable")
|
parser.add_argument("--check_solvable", action='store_true', default=False, help="check solvable")
|
||||||
parser.add_argument("--recheck_solved", action='store_true', default=False, help="check solvable")
|
parser.add_argument("--recheck_solved", action='store_true', default=False, help="check solvable")
|
||||||
parser.add_argument("--include_unsolvable", action='store_true', default=False, help="whether skip unsolvable")
|
parser.add_argument("--include_unsolvable", action='store_true', default=False, help="whether skip unsolvable")
|
||||||
|
|
@ -28,7 +29,7 @@ def parse_args():
|
||||||
|
|
||||||
# 添加整数参数
|
# 添加整数参数
|
||||||
parser.add_argument("--max_api_number", type=int, default=64, help="Maximum number of API calls")
|
parser.add_argument("--max_api_number", type=int, default=64, help="Maximum number of API calls")
|
||||||
parser.add_argument("--all_api_number", type=int, default=17000, help="Total number of API calls")
|
parser.add_argument("--all_api_number", type=int, default=16545, help="Total number of API calls")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
evaluators = [load_registered_automatic_evaluator(evaluator_name=args.evaluator, evaluators_cfg_path=os.path.join('toolbench/tooleval','evaluators')) for _ in range(args.max_eval_threads)]
|
evaluators = [load_registered_automatic_evaluator(evaluator_name=args.evaluator, evaluators_cfg_path=os.path.join('toolbench/tooleval','evaluators')) for _ in range(args.max_eval_threads)]
|
||||||
|
|
@ -1,6 +1,4 @@
|
||||||
#encoding:utf-8
|
#encoding:utf-8
|
||||||
|
|
||||||
import openai
|
|
||||||
import os
|
import os
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
import re
|
import re
|
||||||
|
|
@ -9,18 +7,12 @@ import time
|
||||||
import requests
|
import requests
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from api_database_function import *
|
from anytool.api_database_function import *
|
||||||
from verifier import check_solved_toolbench
|
from anytool.verifier import check_solved_toolbench
|
||||||
import os
|
import os
|
||||||
from rapidapi import pipeline_runner
|
from anytool.rapidapi import pipeline_runner
|
||||||
|
|
||||||
from typing import Any, Callable
|
|
||||||
from openai_function_calling import FunctionInferer
|
|
||||||
import openai
|
import openai
|
||||||
import json
|
import json
|
||||||
# query_data = json.load(open('G1_instruction_query_failed.json', 'r', encoding='utf-8'))
|
|
||||||
# Define example functions.
|
|
||||||
from flask import Flask, jsonify, request
|
|
||||||
|
|
||||||
class dotdict(dict):
|
class dotdict(dict):
|
||||||
"""dot.notation access to dictionary attributes"""
|
"""dot.notation access to dictionary attributes"""
|
||||||
|
|
@ -1,14 +1,17 @@
|
||||||
import openai
|
import openai
|
||||||
from openai_function_calling import FunctionInferer
|
from openai_function_calling import FunctionInferer
|
||||||
import json
|
import json
|
||||||
from prompt_template import *
|
from anytool.prompt_template import *
|
||||||
from tenacity import retry, wait_random_exponential, stop_after_attempt
|
from tenacity import retry, wait_random_exponential, stop_after_attempt
|
||||||
from concurrent.futures import ThreadPoolExecutor,as_completed
|
from concurrent.futures import ThreadPoolExecutor,as_completed
|
||||||
from openai_utils import call_gpt
|
from openai_utils import call_gpt
|
||||||
import time
|
import time
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
# from path_config import *
|
|
||||||
from arguments import parse_args
|
from arguments import parse_args
|
||||||
|
from anytool.check_solved import compute_pass_rate, process_invalid_data, process_valid_data
|
||||||
|
import os
|
||||||
|
from tqdm import tqdm
|
||||||
|
import random
|
||||||
import importlib
|
import importlib
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
output_dir = args.output_dir
|
output_dir = args.output_dir
|
||||||
|
|
@ -121,8 +124,6 @@ def check_task_solvable_by_function(query, functions):
|
||||||
return 'Unsure', 'Connection to the assessing model timeout. You can call the check_current_api_suffucient function to check whether the current APIs is sufficient to solve the query.', response.usage.total_tokens
|
return 'Unsure', 'Connection to the assessing model timeout. You can call the check_current_api_suffucient function to check whether the current APIs is sufficient to solve the query.', response.usage.total_tokens
|
||||||
|
|
||||||
def check_task_solved(query, answer):
|
def check_task_solved(query, answer):
|
||||||
# return 'Solvable', ''
|
|
||||||
# print(functions)
|
|
||||||
messages = [{
|
messages = [{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": CHECK_SOLVED_PROMPT
|
"content": CHECK_SOLVED_PROMPT
|
||||||
|
|
@ -132,54 +133,39 @@ def check_task_solved(query, answer):
|
||||||
]
|
]
|
||||||
print(colored('begin check solved', 'red'))
|
print(colored('begin check solved', 'red'))
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
# try:
|
response = call_gpt(
|
||||||
if True:
|
messages=messages,
|
||||||
response = call_gpt(
|
functions=[solve_finish_function]
|
||||||
messages=messages,
|
)
|
||||||
functions=[solve_finish_function]
|
if isinstance(response, str):
|
||||||
)
|
return 'Timeout', 'Timeout'
|
||||||
if isinstance(response, str):
|
tool_calls = response.choices[0].message.tool_calls
|
||||||
return 'Timeout', 'Timeout'
|
print('Thought:', response.choices[0].message.content)
|
||||||
tool_calls = response.choices[0].message.tool_calls
|
if tool_calls:
|
||||||
print('Thought:', response.choices[0].message.content)
|
for tool_call in tool_calls:
|
||||||
if tool_calls:
|
function_name = tool_call.function.name
|
||||||
# messages.append(
|
function_args = tool_call.function.arguments
|
||||||
# {
|
print(function_name, function_args)
|
||||||
# "role": "assistant",
|
if function_name.lower() == 'finish':
|
||||||
# "tool_calls": tool_calls,
|
solvable, reason = Finish(**json.loads(function_args))
|
||||||
# "content": response.choices[0].message.content if response.choices[0].message.content else ''
|
print(solvable, query, file=open('result/solved.txt', 'a', encoding='utf-8'))
|
||||||
# }
|
if solvable == 'Unsolved' and reason is None:
|
||||||
# )
|
messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'})
|
||||||
for tool_call in tool_calls:
|
continue
|
||||||
function_name = tool_call.function.name
|
if reason is not None:
|
||||||
function_args = tool_call.function.arguments
|
print(reason, file=open('result/solved.txt', 'a', encoding='utf-8'))
|
||||||
print(function_name, function_args)
|
else:
|
||||||
if function_name.lower() == 'finish':
|
reason = ''
|
||||||
solvable, reason = Finish(**json.loads(function_args))
|
return solvable, reason
|
||||||
print(solvable, query, file=open('result/solved.txt', 'a', encoding='utf-8'))
|
|
||||||
if solvable == 'Unsolved' and reason is None:
|
else:
|
||||||
messages.append({"role": "user", "content": 'You must give reason if the answer is Unsolvable'})
|
# continue
|
||||||
continue
|
messages.append({"role": "assistant", "content": '' if response.choices[0].message.content is None else response.choices[0].message.content})
|
||||||
if reason is not None:
|
messages.append({"role": "user", "content": "You must call the Finish function but you didn't"})
|
||||||
print(reason, file=open('result/solved.txt', 'a', encoding='utf-8'))
|
|
||||||
else:
|
|
||||||
reason = ''
|
|
||||||
return solvable, reason
|
|
||||||
|
|
||||||
else:
|
|
||||||
# continue
|
|
||||||
messages.append({"role": "assistant", "content": '' if response.choices[0].message.content is None else response.choices[0].message.content})
|
|
||||||
messages.append({"role": "user", "content": "You must call the Finish function but you didn't"})
|
|
||||||
# except:
|
|
||||||
# pass
|
|
||||||
print('No response from the model', file=open('result/solvable.txt', 'a', encoding='utf-8'))
|
print('No response from the model', file=open('result/solvable.txt', 'a', encoding='utf-8'))
|
||||||
print('No response from the model')
|
print('No response from the model')
|
||||||
return 'No response', 'No response from the model'
|
return 'No response', 'No response from the model'
|
||||||
|
|
||||||
from check_solved import compute_pass_rate, process_invalid_data, process_valid_data
|
|
||||||
import os
|
|
||||||
from tqdm import tqdm
|
|
||||||
import random
|
|
||||||
def check_solved_toolbench(output_path, query_id, task_solvable=None, solvable_task_reason=None):
|
def check_solved_toolbench(output_path, query_id, task_solvable=None, solvable_task_reason=None):
|
||||||
print('begin check solved')
|
print('begin check solved')
|
||||||
data_dict = json.load(open(output_path, 'r', encoding='utf-8'))
|
data_dict = json.load(open(output_path, 'r', encoding='utf-8'))
|
||||||
|
|
@ -221,7 +207,6 @@ def check_solved_toolbench(output_path, query_id, task_solvable=None, solvable_t
|
||||||
|
|
||||||
|
|
||||||
def check_task_complete(query, functions):
|
def check_task_complete(query, functions):
|
||||||
# return 'Solvable', ''
|
|
||||||
messages = [{
|
messages = [{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": CHECK_COMPLETE_PROMPT
|
"content": CHECK_COMPLETE_PROMPT
|
||||||
|
|
@ -5,7 +5,8 @@ def parse_args():
|
||||||
|
|
||||||
# 添加字符串参数
|
# 添加字符串参数
|
||||||
parser.add_argument("--query_path", type=str, default='', help="Path to the query data")
|
parser.add_argument("--query_path", type=str, default='', help="Path to the query data")
|
||||||
parser.add_argument("--output_dir", type=str, default='', help="Path for the output file")
|
parser.add_argument("--output_dir", type=str, default='./', help="Directory for the output file")
|
||||||
|
parser.add_argument("--output_path", type=str, default='./tmp.json', help="Path for the output file")
|
||||||
parser.add_argument("--model", type=str, default='32k', help="openai model name")
|
parser.add_argument("--model", type=str, default='32k', help="openai model name")
|
||||||
parser.add_argument("--solver", type=str, default='dfs', help="solver")
|
parser.add_argument("--solver", type=str, default='dfs', help="solver")
|
||||||
|
|
||||||
|
|
@ -16,7 +17,7 @@ def parse_args():
|
||||||
parser.add_argument("--include_unsolvable", action='store_true', default=False, help="whether skip unsolvable")
|
parser.add_argument("--include_unsolvable", action='store_true', default=False, help="whether skip unsolvable")
|
||||||
parser.add_argument("--use_original_prompt", action='store_true', default=False, help="whether use original prompt")
|
parser.add_argument("--use_original_prompt", action='store_true', default=False, help="whether use original prompt")
|
||||||
parser.add_argument("--leaf_tool_number", type=int, default=5, help="Maximum number of leaf tools")
|
parser.add_argument("--leaf_tool_number", type=int, default=5, help="Maximum number of leaf tools")
|
||||||
parser.add_argument("--all_api_number", type=int, default=17000, help="Total number of API calls")
|
parser.add_argument("--all_api_number", type=int, default=16545, help="Total number of API calls")
|
||||||
|
|
||||||
# 解析命令行参数
|
# 解析命令行参数
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,6 @@ api_version = ""
|
||||||
model_name = ""
|
model_name = ""
|
||||||
api_key = ""
|
api_key = ""
|
||||||
api_base = ""
|
api_base = ""
|
||||||
api_type = "azure"
|
api_type = "azure" # leave it blank if you do not use Azure
|
||||||
toolbench_key = ""
|
toolbench_key = ""
|
||||||
|
|
||||||
File diff suppressed because it is too large
Load diff
232228
data_for_retrieval.json
232228
data_for_retrieval.json
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -9,12 +9,12 @@ import time
|
||||||
import requests
|
import requests
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
import random
|
import random
|
||||||
from api_database_function import *
|
from anytool.api_database_function import *
|
||||||
from server import get_rapidapi_response
|
from server import get_rapidapi_response
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from verifier import check_task_complete, check_task_solved
|
from anytool.verifier import check_task_complete, check_task_solved
|
||||||
from prompt_template import FORMAT_INSTRUCTIONS_DATA_GENERATION
|
from anytool.prompt_template import FORMAT_INSTRUCTIONS_DATA_GENERATION
|
||||||
from openai_utils import call_gpt
|
from openai_utils import call_gpt
|
||||||
enc = tiktoken.get_encoding("cl100k_base")
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
assert enc.decode(enc.encode("hello world")) == "hello world"
|
assert enc.decode(enc.encode("hello world")) == "hello world"
|
||||||
|
|
@ -25,15 +25,6 @@ assert enc.decode(enc.encode("hello world")) == "hello world"
|
||||||
token_cnt = 0
|
token_cnt = 0
|
||||||
error_list = ['Too many requests error...', 'Rate limit...', 'Unsubscribed', 'Unauthorized', 'not working error...', 'Quota','quota', 'Blocked', 'Rate limit', 'Unauthorized error']
|
error_list = ['Too many requests error...', 'Rate limit...', 'Unsubscribed', 'Unauthorized', 'not working error...', 'Quota','quota', 'Blocked', 'Rate limit', 'Unauthorized error']
|
||||||
|
|
||||||
# def retrieve_context(search_string=None):
|
|
||||||
# """retrieve the context containing the search_string"""
|
|
||||||
# context = ragproxyagent.generate_init_message(problem=query,n_results=5, search_string=search_string)
|
|
||||||
# return summarize_context(query, context.split('Context is')[1][:8000])
|
|
||||||
# To help you explore the api database, you can leverage the retrieve_context meta function, which retrieves the relevant context in the database based on your query. And you can specify
|
|
||||||
# the search_string that the context must contain. The retrieved context may contain the potential category_names, tool_names and api_names you are interested in.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FORMAT_INSTRUCTIONS_CONTINUAL_DATA_GENERATION = """
|
FORMAT_INSTRUCTIONS_CONTINUAL_DATA_GENERATION = """
|
||||||
You have access to a database of tools and functions (apis). Function is same to api in our context.
|
You have access to a database of tools and functions (apis). Function is same to api in our context.
|
||||||
You need to help me extend a user query which can be answered by the apis in the database.
|
You need to help me extend a user query which can be answered by the apis in the database.
|
||||||
|
|
@ -61,14 +52,7 @@ The answer should directly answer the query instead of giving a plan.
|
||||||
You should call the initial meta functions no more than 20 times.
|
You should call the initial meta functions no more than 20 times.
|
||||||
The extended part should consist of a minimum of thirty words.
|
The extended part should consist of a minimum of thirty words.
|
||||||
"""
|
"""
|
||||||
# "\nPlease produce three queries in line with the given requirements and inputs. These three queries should display a diverse range of sentence structures: some queries should be in the form of imperative sentences, others declarative, and yet others, interrogative. Equally, they should encompass a variety of tones, with some being polite, others straightforward. Ensure they vary in length and contain a wide range of subjects: myself, my friends, family, and company. Aim to include a number of engaging queries as long as they relate to API calls. Try to avoid explicitly specifying which API to employ in the query. Each query should consist of a minimum of thirty words
|
|
||||||
# At each step, you need to give your thought to analyze the status now and what to do next, with a function call to actually excute your step.
|
|
||||||
# All the thought is short, at most in 5 sentence.
|
|
||||||
|
|
||||||
|
|
||||||
# These ten queries should display a diverse range of sentence structures: some queries should be in the form of imperative sentences,
|
|
||||||
# others declarative, and yet others, interrogative. Equally, they should encompass a variety of tones, with some being polite,
|
|
||||||
# others straightforward.
|
|
||||||
FORMAT_INSTRUCTIONS_DATA_GENERATION_OPTIMIZED="""
|
FORMAT_INSTRUCTIONS_DATA_GENERATION_OPTIMIZED="""
|
||||||
You are an advanced AutoGPT interface designed for dynamic interaction with a comprehensive database of tools and APIs. Your primary function is to assist in generating user queries that can be resolved using the appropriate APIs within the database. To navigate this task efficiently, you have access to five initial meta APIs: query_all_categories, query_tools_in_category, query_apis_in_tool, query_tool_details, and get_api_details. Additionally, you have the capability to test APIs with the add_apis function and can finalize a process with the Finish function.
|
You are an advanced AutoGPT interface designed for dynamic interaction with a comprehensive database of tools and APIs. Your primary function is to assist in generating user queries that can be resolved using the appropriate APIs within the database. To navigate this task efficiently, you have access to five initial meta APIs: query_all_categories, query_tools_in_category, query_apis_in_tool, query_tool_details, and get_api_details. Additionally, you have the capability to test APIs with the add_apis function and can finalize a process with the Finish function.
|
||||||
|
|
||||||
|
|
@ -661,7 +645,7 @@ def generate_main():
|
||||||
return result['answer'], messages
|
return result['answer'], messages
|
||||||
|
|
||||||
exclusion_words = ["sorry", "apologize", "apology", "unfortunately", "couldn't"]
|
exclusion_words = ["sorry", "apologize", "apology", "unfortunately", "couldn't"]
|
||||||
def generate_return_api_main(query, answer):
|
def generate_return_api_main():
|
||||||
data = {}
|
data = {}
|
||||||
global functions, tool_names, cate_names, generated_query_list, raw_api_list, call_cnt
|
global functions, tool_names, cate_names, generated_query_list, raw_api_list, call_cnt
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -704,60 +688,49 @@ def generate_return_api_main(query, answer):
|
||||||
if 'openai' in result:
|
if 'openai' in result:
|
||||||
return result, messages, raw_api_list
|
return result, messages, raw_api_list
|
||||||
generated_query_list.append(result['query'])
|
generated_query_list.append(result['query'])
|
||||||
query = result['query']
|
|
||||||
answer = result['answer']
|
|
||||||
if not any([word in result['answer'].lower() for word in exclusion_words]):
|
if not any([word in result['answer'].lower() for word in exclusion_words]):
|
||||||
return result, messages, raw_api_list
|
return result, messages, raw_api_list
|
||||||
# return result['query'], result['answer'], messages, [{'api_name': functions[k]['name'], 'tool_name': tool_names[k], 'category_name': cate_names[k] }for k in range(6, len(functions))]
|
|
||||||
# except:
|
|
||||||
# pass
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
exclusion_words = ["sorry", "apologize", "apology", "unfortunately", "couldn't"]
|
exclusion_words = ["sorry", "apologize", "apology", "unfortunately", "couldn't"]
|
||||||
# output_dir = 'result1/custom_data'
|
output_path = args.output_path
|
||||||
output_dir = 'result1/custom_data_0129'
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
generated_query_list = []
|
generated_query_list = []
|
||||||
query = ''
|
query = ''
|
||||||
answer = ''
|
answer = ''
|
||||||
for i in range(1000):
|
for i in range(1000):
|
||||||
t_s = time.time()
|
t_s = time.time()
|
||||||
print('#' * 100)
|
print('#' * 100)
|
||||||
print(i)
|
print('Generate the data', i)
|
||||||
|
|
||||||
data = {}
|
data = {}
|
||||||
output_path = f'{output_dir}/{i}.json'
|
|
||||||
if os.path.exists(output_path):
|
|
||||||
continue
|
|
||||||
query = ''
|
|
||||||
answer = ''
|
|
||||||
plan = ''
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
# try:
|
result, generate_messages, api_list = generate_return_api_main()
|
||||||
result, generate_messages, api_list = generate_return_api_main(query, answer)
|
if isinstance(result, dict):
|
||||||
|
query = result['query']
|
||||||
|
answer = result['answer']
|
||||||
|
plan = result['plan']
|
||||||
|
solved, reason = check_task_solved(data['query'], data['final_answer'])
|
||||||
|
if solved != 'Solved':
|
||||||
|
continue
|
||||||
break
|
break
|
||||||
except:
|
except Exception as e:
|
||||||
|
raise e
|
||||||
continue
|
continue
|
||||||
# except:
|
|
||||||
# pass
|
|
||||||
|
|
||||||
|
|
||||||
generated_query_list.append(query)
|
|
||||||
if isinstance(result, dict):
|
|
||||||
query = result['query']
|
|
||||||
answer = result['answer']
|
|
||||||
plan = result['plan']
|
|
||||||
data['query'] = query
|
|
||||||
data['plan'] = plan
|
|
||||||
data['gt_api_list'] = api_list
|
|
||||||
data['final_answer'] = answer
|
|
||||||
# for message in generate_messages:
|
# for message in generate_messages:
|
||||||
# if message['role'] == 'assistant':
|
# if message['role'] == 'assistant':
|
||||||
# if 'tool_calls' in message:
|
# if 'tool_calls' in message:
|
||||||
# message['tool_calls'] = [tool_call.json() for tool_call in message['tool_calls']]
|
# message['tool_calls'] = [tool_call.json() for tool_call in message['tool_calls']]
|
||||||
data['generate_messages'] = generate_messages
|
data['generate_messages'] = generate_messages
|
||||||
print(query, file=open(os.path.join(output_dir, f'generated_query_given_api_list.txt'),'a'))
|
|
||||||
json.dump(data, open(output_path, 'w'), indent=4)
|
generated_query_list.append({
|
||||||
# print(time.time() - t_s, file=open(os.path.join(output_dir, f'time.txt'),'a'))
|
'query': query,
|
||||||
|
'final_answer': answer,
|
||||||
|
'gt_api_list': api_list,
|
||||||
|
'query_id': str(2000000+i)
|
||||||
|
})
|
||||||
|
json.dump(generated_query_list, open(output_path, 'w'), indent=4)
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import openai
|
import openai
|
||||||
from api_database_function import *
|
from anytool.api_database_function import *
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from prompt_template import *
|
from anytool.prompt_template import *
|
||||||
from verifier import check_task_solvable_by_function, check_task_solvable, check_solved_toolbench, check_task_complete
|
from anytool.verifier import check_task_solvable_by_function, check_task_solvable, check_solved_toolbench, check_task_complete
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from openai_utils import call_gpt
|
from openai_utils import call_gpt
|
||||||
import threading
|
import threading
|
||||||
|
|
@ -894,9 +894,7 @@ finish_function = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
import time
|
import time
|
||||||
# from refind_api_cot_gpt4 import solve_given_api_main
|
from anytool.dfs_gt import solve_given_api_main
|
||||||
from dfs_gt import solve_given_api_main
|
|
||||||
# from path_config import *
|
|
||||||
output_dir = args.output_dir
|
output_dir = args.output_dir
|
||||||
query_path = args.query_path
|
query_path = args.query_path
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
@ -904,12 +902,7 @@ if __name__ == "__main__":
|
||||||
success_cnt = 0
|
success_cnt = 0
|
||||||
pass_cnt = 0
|
pass_cnt = 0
|
||||||
unsolvable_task_cnt = 0
|
unsolvable_task_cnt = 0
|
||||||
unsolvable_list = json.load(open('unsolvable.json', 'r', encoding='utf-8'))
|
unsolvable_list = json.load(open('misc/unsolvable.json', 'r', encoding='utf-8'))
|
||||||
if 'custom' in query_path:
|
|
||||||
solved_dict = json.load(open('solved_dict.json', 'r', encoding='utf-8'))
|
|
||||||
for query_id in solved_dict:
|
|
||||||
if solved_dict[query_id]['solved'] != 'Solved':
|
|
||||||
unsolvable_list.append(int(query_id))
|
|
||||||
total_cnt = 0
|
total_cnt = 0
|
||||||
query_data_all = json.load(open(query_path, 'r', encoding='utf-8'))
|
query_data_all = json.load(open(query_path, 'r', encoding='utf-8'))
|
||||||
for query_data in query_data_all:
|
for query_data in query_data_all:
|
||||||
8
time.txt
Normal file
8
time.txt
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
minus: 7.772445678710938e-05
|
||||||
|
minus: 6.413459777832031e-05
|
||||||
|
minus: 7.367134094238281e-05
|
||||||
|
minus: 6.532669067382812e-05
|
||||||
|
minus: 6.079673767089844e-05
|
||||||
|
minus: 6.532669067382812e-05
|
||||||
|
minus: 6.508827209472656e-05
|
||||||
|
minus: 6.937980651855469e-05
|
||||||
Loading…
Reference in a new issue