# AnyTool/toolbench/inference/LLM/retriever.py
# 2024-02-23 15:13:06 +08:00
#
# 56 lines
# No EOL
# 2.3 KiB
# Python

import time
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import json
import re
from toolbench.utils import standardize, standardize_category, change_name, process_retrieval_ducoment
class ToolRetriever:
    """Embedding-based retriever that maps a text query to candidate tools.

    Construction is eager: the corpus TSV is read and processed, the
    SentenceTransformer model is loaded, and every corpus document is encoded
    once, so that ``retrieving`` only has to encode the query.

    Attributes set by ``__init__``:
        corpus_tsv_path: path of the tab-separated corpus file.
        model_path: name or path handed to ``SentenceTransformer``.
        corpus: list of corpus documents, in the order used for the
            embeddings (a hit's ``corpus_id`` indexes into this list).
        corpus2tool: mapping from a corpus document to a tab-joined
            ``category\\ttool_name\\tapi_name`` string.
        embedder: the loaded ``SentenceTransformer`` instance.
        corpus_embeddings: result of encoding ``corpus`` with ``embedder``.
    """

    def __init__(self, corpus_tsv_path="", model_path=""):
        """Load the corpus and the model, then pre-compute corpus embeddings.

        Args:
            corpus_tsv_path: path of the tab-separated corpus file.
            model_path: name or path handed to ``SentenceTransformer``.
        """
        self.corpus_tsv_path = corpus_tsv_path
        self.model_path = model_path
        # Order matters: the embeddings need both the corpus and the embedder.
        self.corpus, self.corpus2tool = self.build_retrieval_corpus()
        self.embedder = self.build_retrieval_embedder()
        self.corpus_embeddings = self.build_corpus_embeddings()

    def build_retrieval_corpus(self):
        """Read the corpus TSV and return ``(corpus, corpus2tool)``.

        Returns:
            A tuple ``(corpus, corpus2tool)`` where ``corpus`` is the list of
            documents produced by ``process_retrieval_ducoment`` (in key
            order) and ``corpus2tool`` is passed through unchanged from that
            helper.
        """
        print("Building corpus...")
        documents_df = pd.read_csv(self.corpus_tsv_path, sep='\t')
        corpus_by_id, corpus2tool = process_retrieval_ducoment(documents_df)
        # Flatten the id -> document mapping into a positional list so that
        # list indices can serve as corpus ids for the semantic search.
        corpus = [corpus_by_id[cid] for cid in corpus_by_id.keys()]
        return corpus, corpus2tool

    def build_retrieval_embedder(self):
        """Instantiate the ``SentenceTransformer`` located at ``model_path``."""
        print("Building embedder...")
        return SentenceTransformer(self.model_path)

    def build_corpus_embeddings(self):
        """Encode every corpus document with the embedder and return the result."""
        print("Building corpus embeddings with embedder...")
        return self.embedder.encode(self.corpus, convert_to_tensor=True)

    def retrieving(self, query, top_k=5, excluded_tools=None):
        """Return the tools whose corpus documents best match ``query``.

        The search deliberately over-fetches ``10 * top_k`` hits (ranked by
        cosine similarity) and the result is NOT truncated to ``top_k``:
        every non-excluded hit is returned, in ranking order.

        Args:
            query: query text handed to the embedder.
            top_k: base number of results; ten times this many hits are
                requested from the semantic search.
            excluded_tools: optional mapping ``category -> container of tool
                names``; hits whose category and tool name match an entry are
                skipped. ``None`` (the default) excludes nothing.

        Returns:
            A list of dicts with the keys ``"category"``, ``"tool_name"`` and
            ``"api_name"``.
        """
        # A None sentinel replaces the former mutable ``{}`` default.
        if excluded_tools is None:
            excluded_tools = {}
        print("Retrieving...")
        query_embedding = self.embedder.encode(query, convert_to_tensor=True)
        hits = util.semantic_search(
            query_embedding,
            self.corpus_embeddings,
            top_k=10 * top_k,
            score_function=util.cos_sim,
        )
        # semantic_search returns one hit list per query; a single query was given.
        return self._hits_to_tools(hits[0], excluded_tools)

    def _hits_to_tools(self, hits, excluded_tools):
        """Translate search hits into tool dicts, skipping excluded tools."""
        retrieved_tools = []
        for hit in hits:
            document = self.corpus[hit['corpus_id']]
            # Names are returned exactly as stored in corpus2tool; no
            # standardize()/change_name() normalization is applied here.
            category, tool_name, api_name = self.corpus2tool[document].split('\t')
            if category in excluded_tools and tool_name in excluded_tools[category]:
                continue
            retrieved_tools.append({
                "category": category,
                "tool_name": tool_name,
                "api_name": api_name
            })
        return retrieved_tools