AnyTool/toolbench/inference/toolbench_server.py
2024-02-23 15:13:06 +08:00

176 lines
6.8 KiB
Python

from flask import Flask, Response, stream_with_context, request
from flask_cors import CORS, cross_origin
from callbacks.ServerEventCallback import ServerEventCallback
import argparse
from toolbench.inference.Downstream_tasks.rapidapi import pipeline_runner
import subprocess
import concurrent.futures
import json
import signal
import time
from queue import Queue
import copy
import time
app = Flask(__name__)
cors = CORS(app)
class Model:
def __init__(self, gpu=0):
self.inuse = False
print("Initializing...")
starting_time = time.time()
self.args = self.get_args()
self.pipeline = pipeline_runner(self.args, add_retrieval=False, server=True)
print("Loading model...")
self.llm = self.pipeline.get_backbone_model()
print("Model loaded in {} seconds".format(time.time() - starting_time))
starting_time = time.time()
print("Loading retriever...")
self.retriever = self.pipeline.get_retriever()
print("Retriever loaded in {} seconds".format(time.time() - starting_time))
self.query_id = 0
# self.process_num = self.args.process_num
self.queue = Queue()
self.callback = ServerEventCallback(self.queue)
self.occupied = False
print("Server ready")
def run_pipeline(self, user_input, method, top_k):
# empty the queue
while not self.queue.empty():
self.queue.get()
self.query_id += 1
temp_args = copy.deepcopy(self.args)
temp_args.retrieved_api_nums = top_k
temp_args.method = method
data_dict = {
"query": user_input,
}
self.pipeline.run_single_task(
method=method,
backbone_model=self.llm,
query_id=self.query_id,
data_dict=data_dict,
output_dir_path=self.args.output_answer_file,
retriever=self.retriever,
args=temp_args,
tool_des=None,
callbacks=[self.callback]
)
def get_queue(self):
while not self.queue.empty():
yield self.queue.get()
def get_args(self):
parser = argparse.ArgumentParser()
parser.add_argument('--corpus_tsv_path', type=str, default="your_retrival_corpus_path/", required=False,
help='')
parser.add_argument('--retrieval_model_path', type=str, default="your_model_path/", required=False, help='')
parser.add_argument('--retrieved_api_nums', type=int, default=5, required=False, help='')
parser.add_argument('--backbone_model', type=str, default="toolllama", required=False,
help='chatgpt_function or davinci or toolllama')
parser.add_argument('--openai_key', type=str, default="", required=False,
help='openai key for chatgpt_function or davinci model')
parser.add_argument('--model_path', type=str, default="your_model_path/", required=True, help='')
parser.add_argument('--tool_root_dir', type=str, default="your_tools_path/", required=True, help='')
parser.add_argument("--lora", action="store_true", help="Load lora model or not.")
parser.add_argument('--lora_path', type=str, default="your_lora_path if lora", required=False, help='')
parser.add_argument('--max_observation_length', type=int, default=1024, required=False,
help='maximum observation length')
parser.add_argument('--observ_compress_method', type=str, default="truncate", choices=["truncate", "filter", "random"],
required=False, help='observation compress method')
parser.add_argument('--method', type=str, default="CoT@1", required=False,
help='method for answer generation: CoT@n,Reflexion@n,BFS,DFS,UCT_vote')
parser.add_argument('--input_query_file', type=str, default="", required=False, help='input path')
parser.add_argument('--output_answer_file', type=str, default="", required=False, help='output path')
parser.add_argument('--toolbench_key', type=str, default="", required=False, help='your toolbench key')
parser.add_argument('--rapidapi_key', type=str, default="",required=False, help='your rapidapi key to request rapidapi service')
parser.add_argument('--use_rapidapi_key', action="store_true", help="To use customized rapidapi service or not.")
parser.add_argument('--api_customization', action="store_true", help="To use customized api or not.")
args = parser.parse_args()
return args
model = Model()
@app.route('/stream', methods=['GET', 'POST'])
@cross_origin()
def stream():
data = json.loads(request.data)
user_input = data["text"]
top_k = data["top_k"]
method = data["method"]
print("Called stream")
global model
def generate(model):
print("Called generate")
if model.inuse:
# send 409 error
return Response(json.dumps({
"method_name": "error",
"error": "Model in use"
}), status=409, mimetype='application/json')
return
model.inuse = True
# run model.run_agent in the background
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(model.run_pipeline, user_input, method, top_k)
# keep waiting for the queue to be empty
while True:
if model.queue.empty():
if future.done():
print("Finished with future")
break
time.sleep(0.01)
continue
else:
obj = model.queue.get()
if obj["method_name"] == "unknown": continue
if obj["method_name"] == "on_request_end":
yield json.dumps(obj)
break
try:
yield json.dumps(obj) + "\n"
except Exception as e:
model.inuse = False
print(obj)
print(e)
try:
future.result()
except Exception as e:
model.inuse = False
print(e)
model.inuse = False
return
return Response(stream_with_context(generate(model)))
@app.route('/methods', methods=['GET'])
@cross_origin()
def methods():
# return a list of available methods
return Response(json.dumps({
{
"methods": ["DFS_woFilter_w2"]
}
}), status=200, mimetype='application/json')
def handle_keyboard_interrupt(signal, frame):
global model
exit(0)
signal.signal(signal.SIGINT, handle_keyboard_interrupt)
if __name__ == '__main__':
app.run(use_reloader=False, host="0.0.0.0", debug=True, port=5000)