AnyTool/toolbench/inference/LLM_rank/rank_candidate.py
2024-02-23 15:13:06 +08:00

101 lines
3.6 KiB
Python

'''
Score the candidates for a query by comparing them pairwise with an LLM.
'''
from toolbench.inference.Prompts.rank_prompts import LLM_PAIRWISE_RANK_SUBFIX_SYSTEM_PROMPT, LLM_PAIRWISE_RANK_USER_PROMPT
import random
from toolbench.inference.Tree.Tree import tree_node
def rank2symmetry(llm_interface, LLM_rank_args, cand1,cand2):
    '''
    Compare two candidates in both presentation orders and merge the verdicts.

    LLM_rank_args["rank_func"] is called once as (cand1, cand2) and once as
    (cand2, cand1). Each call returns (first_is_better, query_count,
    total_tokens), where first_is_better is 1 when the candidate passed first
    wins and 0 otherwise. Asking in both orders means a comparator that
    favours one position ends in a tie instead of a win.

    Returns (verdict, query_count, total_tokens): verdict is 1 if cand1 won
    both rounds, -1 if cand2 won both rounds, and 0 if the rounds disagree;
    the two counters are summed over both calls.
    '''
    compare = LLM_rank_args["rank_func"]
    # votes[0] counts rounds won by cand1, votes[1] rounds won by cand2.
    votes = [0, 0]

    # Round 1: cand1 is shown first, so a result of 1 is a vote for cand1.
    first_wins_fwd, queries_fwd, tokens_fwd = compare(llm_interface, LLM_rank_args, cand1, cand2)
    votes[1 - first_wins_fwd] += 1

    # Round 2: cand2 is shown first, so a result of 1 is a vote for cand2.
    first_wins_rev, queries_rev, tokens_rev = compare(llm_interface, LLM_rank_args, cand2, cand1)
    votes[first_wins_rev] += 1

    queries_used = queries_fwd + queries_rev
    tokens_used = tokens_fwd + tokens_rev
    # Sign of the vote difference: 1, -1 or 0.
    verdict = (votes[0] > votes[1]) - (votes[0] < votes[1])
    return verdict, queries_used, tokens_used
def rank2_subfix(llm_interface,LLM_rank_args, cand1,cand2):
    '''
    Ask the LLM once which of two candidates is better, assuming the two
    candidates share a long common prefix.

    The common ancestor of the two nodes is looked up with
    tree_node.find_ancestor_intersection. The text for the shared part and
    for each candidate's own part (as produced by
    get_former_trice_from_this_node) is substituted into
    LLM_PAIRWISE_RANK_SUBFIX_SYSTEM_PROMPT, with cand1 as candidate A and
    cand2 as candidate B.

    Args:
        llm_interface: object providing change_messages(...) and parse(...).
        LLM_rank_args: dict; the keys "task_description", "functions" and
            "process_id" are read here.
        cand1: node shown to the model as candidate A.
        cand2: node shown to the model as candidate B.

    Returns:
        (a_is_better, query_count, total_tokens). a_is_better is 1 when the
        last non-whitespace character of the reply is "a" (case-insensitive)
        and 0 otherwise, including when the reply is empty or missing.
        query_count is always 1. total_tokens is passed through from
        llm_interface.parse.

    Raises:
        AssertionError: if the two candidates have no common ancestor.
    '''
    common_ancestor = tree_node.find_ancestor_intersection(cand1,cand2)
    assert common_ancestor is not None, "candidates do not share a common ancestor"

    # Shared part up to the common ancestor, then each candidate's own part after it.
    intersect_trice = common_ancestor.get_former_trice_from_this_node(end_node=None)
    trice_1 = cand1.get_former_trice_from_this_node(end_node=common_ancestor)
    trice_2 = cand2.get_former_trice_from_this_node(end_node=common_ancestor)

    system_message = LLM_PAIRWISE_RANK_SUBFIX_SYSTEM_PROMPT
    system_message = system_message.replace("{task_description}", LLM_rank_args["task_description"])
    system_message = system_message.replace("{intersect_trice}", intersect_trice)
    system_message = system_message.replace("{candidate_A}",trice_1)
    system_message = system_message.replace("{candidate_B}",trice_2)
    llm_interface.change_messages([{"role":"system","content":system_message},
                                   {"role":"user","content":LLM_PAIRWISE_RANK_USER_PROMPT},
                                   ])
    output, _error_code, total_tokens = llm_interface.parse(functions=LLM_rank_args["functions"],function_call="none",process_id=LLM_rank_args["process_id"])

    # An empty, whitespace-only or None reply used to crash on the [-1] index
    # (IndexError / AttributeError). endswith() on an empty string is simply
    # False, so such replies now fall through to the "not A" result.
    reply = (output["content"] or "").strip().lower()
    if reply.endswith("a"):
        return 1, 1, total_tokens
    return 0, 1, total_tokens
def sum_based_rankn(llm_interface,LLM_rank_args, candidates):
    '''
    Round-robin scoring: every unordered pair of candidates is compared once
    with rank2symmetry and the points are accumulated per candidate.

    The winner of a pair gets 1 point; on a tie both get 0.5.

    Returns (scores, total_querys, total_tokens), where scores[k] is the
    point total of candidates[k] and the other two values are the sums of
    the query and token counts reported by each pairwise comparison.
    '''
    count = len(candidates)
    points = [0] * count
    queries_spent = 0
    tokens_spent = 0

    # All index pairs (first, second) with first < second, in row-major order.
    index_pairs = [(first, second) for first in range(count) for second in range(first + 1, count)]
    for first, second in index_pairs:
        outcome, used_queries, used_tokens = rank2symmetry(
            llm_interface, LLM_rank_args, candidates[first], candidates[second]
        )
        queries_spent += used_queries
        tokens_spent += used_tokens

        if outcome > 0:
            gain_first, gain_second = 1, 0
        elif outcome < 0:
            gain_first, gain_second = 0, 1
        else:
            gain_first, gain_second = 0.5, 0.5
        points[first] += gain_first
        points[second] += gain_second

    return points, queries_spent, tokens_spent
if __name__ == "__main__":
    # Fixed seed so that anything random run from this entry point is repeatable.
    random.seed(42)

    # A previously tried candidate list, kept for reference:
    #   ["234", "66.5", "77.1", "88.967", "pi"]   (optionally also "e", "ln(2)")
    candidates = ["77.1", "88.967", "pi", "66.5", "234", "ln(2)"]

    # Recorded figures, one per starting_delta value; what the percentages
    # measure is not shown in this file.
    '''
    starting_delta:
    50 -> 42.85%
    100 -> 35.99%
    150 -> 29.66%
    200 -> 24.03%
    '''