AnyTool/toolbench/tooleval/evaluators/registered_cls/rtl.py
2024-02-23 15:13:06 +08:00

246 lines
No EOL
9.5 KiB
Python

import json
import re
import random
import math
from typing import List, Union, Dict, Any, Callable, Optional
from copy import deepcopy
from tenacity import retry, stop_after_attempt
from .utils import register_evaluator,OpenaiPoolRequest
from .tooleval import OpenAINormalizedEvaluator
from enum import Enum
class AnswerStatus(Enum):
Unsure = "Unsure"
Unsolved = "Unsolved"
Solved = "Solved"
Filtered = "Filtered"
class TaskStatus(Enum):
Unsure = "Unsure"
Unsolvable = "Unsolvable"
Solvable = "Solvable"
Filtered = "Filtered"
class AnswerPass(Enum):
Unsure = "Unsure"
Failed = "Failed"
Passed = "Passed"
@register_evaluator
class ReinforceToolLearningEvaluator(OpenAINormalizedEvaluator):
def check_has_hallucination(self,available_tools:List[Dict],answer:Dict[Any,Any])->bool:
available_names = set([tool['name'] for tool in available_tools])
def check_node_valid(node:Dict)->bool:
# print(node)
if node['role'] == "tool":
if isinstance(node['message'], dict):
node['message'] = str(node['message'])
name = re.findall(r"'name':\s*'(.*?)'",node['message'],re.DOTALL)[0]
return name in available_names
return True
def recurssive_check(nodes:Union[List,Dict])->bool:
if isinstance(nodes,Dict):
if not check_node_valid(nodes):
return False
else:
return recurssive_check(nodes['next'])
if isinstance(nodes,List):
for node in nodes:
if not recurssive_check(node):
return False
return True
raise ValueError(f'Unknown node type {type(nodes)}')
return recurssive_check(answer['answer_details'])
def check_is_solved(self,
task_description:Dict,
answer:Dict[Any,Any],
return_reason=False,
) -> Union[AnswerStatus,Optional[str]]:
# empty situation
if answer['final_answer']=='' or 'give_up_and_restart' in answer['final_answer']:
if return_reason:
return AnswerStatus.Unsolved, "Empty final answer!",0
return AnswerStatus.Unsolved, "", 0
# print(answer['final_answer'])
ret = self.function_call(
'check_answer_status',
{
'query':task_description['query'],
'answer':answer['final_answer']
},
return_reason=return_reason
)
answer_status = AnswerStatus(ret['answer_status'])
if answer_status == AnswerStatus.Unsure:
# detailed check here
ret = self.function_call(
'parse_answer_status',
{
'query':task_description['query'],
'answer':json.dumps(answer)
},
return_reason=return_reason
)
if ret['answer_status'] not in ['Unsure','Unsolved', 'Solved']:
ret['answer_status'] = 'Unsure'
answer_status = AnswerStatus(ret['answer_status'])
# print(answer_status)
if return_reason:
return answer_status,ret['reason'], ret['tokens']
return answer_status, "", ret['tokens']
def check_task_solvable(self,
task_description:Dict,
has_been_solved=False,
return_reason=False,
)->Union[TaskStatus,Optional[str]]:
if has_been_solved:
if return_reason:
return TaskStatus.Solvable, 'Task has been solved before.', 0
return TaskStatus.Solvable, '', 0
ret = self.function_call(
'check_task_solvable',
{
'task':json.dumps(task_description)
},
return_reason=return_reason
)
task_status = TaskStatus(ret['task_status'])
if return_reason:
return task_status, ret['reason'], ret['tokens']
return task_status, '', ret['tokens']
def is_passed(self,
task_description:Dict,
answer:Dict[Any,Any],
answer_status:AnswerStatus=None,
task_status:TaskStatus=None,
)->AnswerPass:
tokens = 0
if answer_status is None:
answer_status, _, tokens = self.check_is_solved(task_description,answer)
if answer_status == AnswerStatus.Solved:
return AnswerPass.Passed, tokens
else:
if task_status is None:
task_status, _ = self.check_task_solvable(
task_description,
has_been_solved=answer_status==AnswerStatus.Solved)
if answer_status == AnswerStatus.Unsolved:
if task_status == TaskStatus.Solvable:
return AnswerPass.Failed, tokens
if task_status == TaskStatus.Unsure:
return AnswerPass.Unsure, tokens
if task_status == TaskStatus.Unsolvable:
return AnswerPass.Passed, tokens
elif answer_status == AnswerStatus.Unsure:
if task_status == TaskStatus.Solvable:
return AnswerPass.Unsure, tokens
if task_status == TaskStatus.Unsure:
return AnswerPass.Unsure, tokens
if task_status == TaskStatus.Unsolvable:
return AnswerPass.Passed, tokens
return AnswerPass.Failed, tokens
def check_identity_answers(self,
answers:List[Dict[Any,Any]],
)->bool:
ref_answer = answers[0]
for ans in answers[1:]:
if ans['final_answer']!=ref_answer['final_answer']:
return False
if str(ans['answer_details'])!=str(ref_answer['answer_details']):
return False
return True
@retry(stop=stop_after_attempt(3),reraise=True)
def select_better_answer(self,
task_description:Dict,
task_status:TaskStatus,
ans_idxs:List[int],
answers:List[Dict[Any,Any]],
answer_status:AnswerStatus,
*,
return_reason=True)->int:
answers = deepcopy(answers)
if self.check_identity_answers(answers):
return random.choice(ans_idxs)
judge_focus = {
TaskStatus.Solvable:'Since query is solvable, you should select answer with smaller "total_steps" and informative, accurate "final_answer".',
TaskStatus.Unsure:'Since query is unsure, you should select a more comprehensive exploration for possible solutions.',
TaskStatus.Unsolvable:'Since query is unsolvable, you should select answer with smaller "total_steps" and detailed reasons for failure.'
}
ret = self.function_call(
'select_better_answer', {
'query':task_description['query'],
'answer_0':json.dumps(answers[0]),
'answer_1':json.dumps(answers[1]),
# 'q_status':judge_focus[task_status],
},
return_reason=return_reason
)
index = int(ret['index'])
if index in ans_idxs:
return index
else:
raise ValueError(f'Index {index} not found!')
def normalized_openai_completions(self,task_description:Dict, answers:List[Dict[Any,Any]], task_status:None, answer_statuss:[None, None])->int:
if answer_statuss[0] is None:
print("comparing from scratch...")
status = [self.check_is_solved(task_description,ans)[0] for ans in answers]
else:
status = answer_statuss
# check whether there are answers solve the task
solves = [idx for idx,s in enumerate(status) if s==AnswerStatus.Solved]
if len(solves)==1:
return solves[0]
elif len(solves)>1:
# pick best one
if task_status is None:
task_status, _ = self.check_task_solvable(task_description,has_been_solved=True)
else:
task_status = task_status
return self.select_better_answer(task_description,task_status,solves,[answers[idx] for idx in solves],AnswerStatus.Solved)
# if no answer solves the task, check whether unsure answer exists
unsures = [idx for idx,s in enumerate(status) if s==AnswerStatus.Unsure]
if len(unsures) == 1:
return unsures[0]
elif len(unsures)>1:
# pick best one
if task_status is None:
task_status, _ = self.check_task_solvable(task_description)
else:
task_status = task_status
return self.select_better_answer(task_description,task_status,unsures,[answers[idx] for idx in unsures],AnswerStatus.Unsure)
# if all failed
# pick best one
if task_status is None:
task_status, _ = self.check_task_solvable(task_description)
else:
task_status = task_status
return self.select_better_answer(task_description,task_status,list(range(len(answers))),answers,AnswerStatus.Unsolved)