56 lines
No EOL
2.3 KiB
Python
56 lines
No EOL
2.3 KiB
Python
import time
|
|
import pandas as pd
|
|
from sentence_transformers import SentenceTransformer, util
|
|
import json
|
|
import re
|
|
from toolbench.utils import standardize, standardize_category, change_name, process_retrieval_ducoment
|
|
|
|
|
|
class ToolRetriever:
    """Embedding-based retriever over a tab-separated tool corpus.

    On construction the retriever reads the corpus TSV, loads a
    ``SentenceTransformer`` model and encodes every corpus document once.
    Queries are then matched against those cached embeddings with cosine
    similarity.

    Attributes set by ``__init__``:
        corpus_tsv_path: path of the TSV file the corpus is read from.
        model_path: value passed to ``SentenceTransformer`` to load the model.
        corpus: list of corpus documents, in the key order of the mapping
            returned by ``process_retrieval_ducoment``.
        corpus2tool: lookup from a corpus document to a
            ``"category\\ttool_name\\tapi_name"`` string (this layout is what
            ``retrieving`` unpacks).
        embedder: the loaded ``SentenceTransformer`` instance.
        corpus_embeddings: result of encoding ``corpus`` with ``embedder``.
    """

    def __init__(self, corpus_tsv_path="", model_path=""):
        """Load the corpus, the embedding model and the corpus embeddings.

        All three steps run eagerly, in this order, because each one depends
        on attributes produced by the previous one.

        Args:
            corpus_tsv_path: path to the tab-separated corpus file.
            model_path: model name or path handed to ``SentenceTransformer``.
        """
        self.corpus_tsv_path = corpus_tsv_path
        self.model_path = model_path
        self.corpus, self.corpus2tool = self.build_retrieval_corpus()
        self.embedder = self.build_retrieval_embedder()
        self.corpus_embeddings = self.build_corpus_embeddings()

    def build_retrieval_corpus(self):
        """Read the corpus TSV and turn it into retrieval structures.

        Returns:
            A ``(corpus, corpus2tool)`` tuple: ``corpus`` is the list of
            documents taken from the mapping produced by
            ``process_retrieval_ducoment`` (in that mapping's key order) and
            ``corpus2tool`` is returned by that helper unchanged.
        """
        print("Building corpus...")
        documents_df = pd.read_csv(self.corpus_tsv_path, sep='\t')
        corpus_by_id, corpus2tool = process_retrieval_ducoment(documents_df)
        # Flatten id -> document into a plain list; the list position is what
        # semantic search later reports back as ``corpus_id``.
        corpus = [corpus_by_id[cid] for cid in corpus_by_id.keys()]
        return corpus, corpus2tool

    def build_retrieval_embedder(self):
        """Instantiate the ``SentenceTransformer`` from ``self.model_path``.

        Returns:
            The loaded ``SentenceTransformer`` instance.
        """
        print("Building embedder...")
        return SentenceTransformer(self.model_path)

    def build_corpus_embeddings(self):
        """Encode every corpus document with ``self.embedder``.

        Returns:
            Whatever ``self.embedder.encode`` returns for ``self.corpus`` when
            called with ``convert_to_tensor=True``.
        """
        print("Building corpus embeddings with embedder...")
        return self.embedder.encode(self.corpus, convert_to_tensor=True)

    def retrieving(self, query, top_k=5, excluded_tools=None):
        """Return the tools whose corpus documents best match ``query``.

        The search asks for ``10 * top_k`` candidates and returns every
        candidate that is not excluded, so the result may contain more than
        ``top_k`` entries.
        NOTE(review): the result has never been truncated to ``top_k``;
        callers that need exactly ``top_k`` items must slice it themselves.

        Args:
            query: text to encode and search for.
            top_k: base number of wanted results; ten times this many
                candidates are requested from the semantic search.
            excluded_tools: optional mapping ``category -> container of tool
                names``; a hit is skipped when its category is a key and its
                tool name is contained in that key's value. ``None`` (the
                default) excludes nothing.

        Returns:
            A list of dicts with the keys ``"category"``, ``"tool_name"`` and
            ``"api_name"``, in the order the hits were returned by
            ``util.semantic_search``.

        Raises:
            ValueError: if a ``corpus2tool`` entry does not split into exactly
                three tab-separated fields.
        """
        print("Retrieving...")
        # A fresh dict per call instead of a shared mutable default argument.
        excluded = {} if excluded_tools is None else excluded_tools

        query_embedding = self.embedder.encode(query, convert_to_tensor=True)
        hits = util.semantic_search(
            query_embedding,
            self.corpus_embeddings,
            top_k=10 * top_k,
            score_function=util.cos_sim,
        )

        retrieved_tools = []
        # One query was encoded, so only the first hit list is consumed.
        for hit in hits[0]:
            document = self.corpus[hit['corpus_id']]
            # Names are used exactly as stored in corpus2tool; no
            # standardization is applied to them here.
            category, tool_name, api_name = self.corpus2tool[document].split('\t')
            if category in excluded and tool_name in excluded[category]:
                continue
            retrieved_tools.append({
                "category": category,
                "tool_name": tool_name,
                "api_name": api_name
            })
        return retrieved_tools