AnyTool/toolbench/inference/callbacks/ServerEventCallback.py
2024-02-23 15:13:06 +08:00

188 lines
No EOL
6 KiB
Python

from typing import Any, Dict, List, Union
import queue
class ServerEventCallback():
    """Callback handler that records every event as a dict on a queue.

    Each ``on_*`` method builds a dict holding at least ``method_name`` and
    ``block_id`` plus event-specific fields, and puts it on ``self.queue``.
    Two counters produce the ``"llm-N"`` and ``"tool-N"`` block ids:
    ``llm_block_id`` is bumped by ``on_chain_start`` and ``tool_block_id``
    by ``on_agent_action``; both are reset by ``on_request_start``.
    """

    def __init__(self, queue: queue.Queue, *args, **kwargs):
        """Store the target queue and reset counters and tool descriptions.

        Extra positional/keyword arguments are forwarded to the next
        ``__init__`` in the MRO.
        """
        super().__init__(*args, **kwargs)
        self.queue = queue
        self.llm_block_id = 0
        self.tool_block_id = 0
        # Filled by on_tool_retrieval_end: tool name -> tool entry.
        self.tool_descriptions = {}

    def add_to_queue(self, method_name: str, block_id, **kwargs: Any):
        """Put one event dict on the queue.

        The dict contains ``method_name`` and ``block_id``; any extra keyword
        arguments are merged in afterwards (and would override those two
        keys if they used the same names).
        """
        event = {
            "method_name": method_name,
            "block_id": block_id,
        }
        event.update(kwargs)
        self.queue.put(event)

    def on_tool_retrieval_start(self):
        """Emit the start-of-tool-retrieval event (block ``recommendation-1``)."""
        self.add_to_queue(
            "on_tool_retrieval_start",
            "recommendation-1",
        )
        print("on_tool_retrieval_start method called")

    def on_tool_retrieval_end(self, tools):
        """Emit the retrieved tools and remember them by name.

        ``tools`` must be an iterable of mappings that each have a ``"name"``
        key; the mapping itself is stored as that tool's description.
        NOTE(review): the original comment described entries as
        ``{tool_name, tool_desc}`` while the code indexes ``"name"`` -
        confirm the schema against the caller.
        """
        self.add_to_queue(
            "on_tool_retrieval_end",
            "recommendation-1",
            recommendations=tools,
        )
        self.tool_descriptions = {tool["name"]: tool for tool in tools}
        print("on_tool_retrieval_end method called")

    def on_request_start(self, user_input: str, method: str) -> Any:
        """Reset both block counters and emit the request-start event."""
        self.tool_block_id = 0
        self.llm_block_id = 0
        self.add_to_queue(
            "on_request_start",
            block_id="start",
            user_input=user_input,
            method=method,
        )

    def on_request_end(self, outputs: str, chain: List[Any]):
        """Emit the request-end event carrying the outputs and the chain."""
        self.add_to_queue(
            "on_request_end",
            block_id="end",
            output=outputs,
            chain=chain,
        )

    def on_request_error(self, error: str):
        """Emit the request-error event."""
        self.add_to_queue(
            "on_request_error",
            block_id="error",
            error=error,
        )

    def on_chain_start(self, inputs: str, depth: int) -> Any:
        """Run when chain starts running.

        Advances the LLM block counter and returns the new ``"llm-N"`` block
        id so the caller can hand it back to ``on_chain_end``.
        """
        print("on_chain_start method called")
        self.llm_block_id += 1
        block_id = f"llm-{self.llm_block_id}"
        self.add_to_queue(
            "on_chain_start",
            block_id=block_id,
            messages=inputs,
            depth=depth,
        )
        return block_id

    def on_chain_end(self, block_id: str, depth: int) -> Any:
        """Emit the chain-end event for ``block_id``.

        The caller has to remember the block id itself (presumably the value
        returned by ``on_chain_start``); it is not derived from the counter.
        """
        self.add_to_queue(
            "on_chain_end",
            block_id=block_id,
            depth=depth,
        )
        print("on_chain_end method called")

    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
        """Emit the chain-error event.

        ``add_to_queue`` requires a ``block_id``; the original call omitted
        it and raised ``TypeError`` unless the caller supplied one through
        ``kwargs``. A caller-supplied ``block_id`` is still honoured;
        otherwise the current ``"llm-N"`` block is used.
        """
        kwargs.setdefault("block_id", f"llm-{self.llm_block_id}")
        self.add_to_queue("on_chain_error", error=error, **kwargs)
        print("on_chain_error method called")

    def on_llm_start(
        self, messages: str, depth: int
    ) -> Any:
        """Run when LLM starts running."""
        self.add_to_queue(
            "on_llm_start",
            block_id=f"llm-{self.llm_block_id}",
            messages=messages,
            depth=depth,
        )
        print("on_llm_start method called")

    def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
        """Run on new LLM token. Only available when streaming is enabled.

        As with ``on_chain_error``, a missing ``block_id`` used to raise
        ``TypeError``; it now defaults to the current ``"llm-N"`` block
        unless the caller passes one in ``kwargs``.
        """
        kwargs.setdefault("block_id", f"llm-{self.llm_block_id}")
        self.add_to_queue("on_llm_new_token", token=token, **kwargs)
        print("on_llm_new_token method called")

    def on_llm_end(self, response: str, depth: int) -> Any:
        """Run when LLM ends running."""
        self.add_to_queue(
            "on_llm_end",
            block_id=f"llm-{self.llm_block_id}",
            response=response,
            depth=depth,
        )
        print("on_llm_end method called")

    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt]) -> Any:
        """Run when LLM errors."""
        self.add_to_queue(
            "on_llm_error",
            block_id=f"llm-{self.llm_block_id}",
            message=str(error),
            error=error,
        )
        print("on_llm_error method called")

    def on_agent_action(self, action, action_input, depth: int) -> str:
        """Advance the tool block counter, emit the action, return its block id."""
        self.tool_block_id += 1
        block_id = f"tool-{self.tool_block_id}"
        self.add_to_queue(
            "on_agent_action",
            block_id=block_id,
            action=action,
            action_input=action_input,
            depth=depth,
        )
        print("on_agent_action method called")
        return block_id

    def on_tool_start(self, tool_name: str, tool_input: str, depth: int) -> Any:
        """Emit the tool-start event, attaching the stored tool description.

        When ``tool_name`` was not recorded by ``on_tool_retrieval_end`` a
        placeholder string is sent instead and the miss is printed.
        """
        if tool_name in self.tool_descriptions:
            tool_description = self.tool_descriptions[tool_name]
        else:
            tool_description = "Tool not found in tool descriptions"
            print(self.tool_descriptions)
            print("Key", tool_name, "not found in tool descriptions")
        self.add_to_queue(
            "on_tool_start",
            block_id=f"tool-{self.tool_block_id}",
            tool_name=tool_name,
            tool_description=tool_description,
            tool_input=tool_input,
            depth=depth,
        )
        print("on_tool_start method called")

    def on_tool_end(self, output: str, status: int, depth: int) -> Any:
        """Emit the tool-end event for the current tool block."""
        self.add_to_queue(
            "on_tool_end",
            block_id=f"tool-{self.tool_block_id}",
            output=output,
            status=status,
            depth=depth,
        )
        print("on_tool_end method called")

    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt]) -> Any:
        """Emit the tool-error event for the current tool block.

        The original call passed no ``block_id`` and therefore always raised
        ``TypeError``; the current ``"tool-N"`` id is now sent, matching
        ``on_tool_start`` and ``on_tool_end``.
        """
        self.add_to_queue(
            "on_tool_error",
            block_id=f"tool-{self.tool_block_id}",
            error=error,
        )
        print("on_tool_error method called")

    def on_agent_end(self, block_id: str, depth: int):
        """Emit the agent-end event for the caller-supplied ``block_id``."""
        self.add_to_queue(
            "on_agent_end",
            block_id=block_id,
            depth=depth,
        )
        print("on_agent_end method called")