This commit is contained in:
Sam Partee 2024-05-10 19:05:15 -07:00
parent 6272a426f1
commit 2e4542c260

View file

@ -308,33 +308,6 @@ class ToolFlow:
self.openai_client = openai.Client(api_key=model_api_key)
def infer_flow(self, user_query: str) -> FlowSchema:
"""
Infer the tool flow based on the user query.
Args:
user_query (str): The user's query string.
Returns:
FlowSchema: The inferred tool flow schema.
"""
messages = self.__create_prompt(user_query)
func_spec = pydantic_to_openai_tool(FlowSchema)
tool = json.loads(func_spec)
# Call the OpenAI model with the tools and messages
completion = self.openai_client.chat.completions.create(
model=self.model,
messages=messages,
tools=[tool],
tool_choice="required"
)
predicted_args = completion.choices[0].message.tool_calls[0].function.arguments
print(predicted_args)
return predicted_args
def execute_flow(self, flow_schema: Dict[str, Any], user_query: str) -> Any:
"""
Executes the tool flow based on the provided schema. This method performs a breadth-first search (BFS)