From 9eb9f77a92915876d78d3d785772faa73bd22e8d Mon Sep 17 00:00:00 2001 From: Eric Gustin <34000337+EricGustin@users.noreply.github.com> Date: Wed, 25 Sep 2024 13:43:04 -0700 Subject: [PATCH] Poll for auth completion in `arcade chat` (#66) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR re-organizes arcade chat implementation and also adds polling for auth response to complete (if any) ### arcade chat streaming with auth: ![Screenshot 2024-09-25 at 11 28 45 AM](https://github.com/user-attachments/assets/c351fce7-060a-4060-b215-6b5d05028216) ### arcade chat without streaming with auth: ![Screenshot 2024-09-25 at 11 33 08 AM](https://github.com/user-attachments/assets/29a6c5ad-857c-47e6-92d9-52ec87ff88c9) --------- Co-authored-by: Nate Barbettini --- arcade/arcade/cli/main.py | 58 +++++------------ arcade/arcade/cli/utils.py | 123 +++++++++++++++++++++++++++++++++++-- 2 files changed, 133 insertions(+), 48 deletions(-) diff --git a/arcade/arcade/cli/main.py b/arcade/arcade/cli/main.py index b8624ffc..9090fb59 100644 --- a/arcade/arcade/cli/main.py +++ b/arcade/arcade/cli/main.py @@ -9,7 +9,6 @@ from urllib.parse import urlencode import typer from rich.console import Console -from rich.markdown import Markdown from rich.markup import escape from rich.table import Table from rich.text import Text @@ -21,11 +20,11 @@ from arcade.cli.utils import ( apply_config_overrides, create_cli_catalog, display_eval_results, - display_streamed_markdown, display_tool_messages, - get_tool_messages, - markdownify_urls, + handle_chat_interaction, + is_authorization_pending, validate_and_get_config, + wait_for_authorization_completion, ) from arcade.client import Arcade from arcade.client.errors import EngineNotHealthyError, EngineOfflineError @@ -245,50 +244,23 @@ def chat( history.append({"role": "user", "content": user_input}) - tool_messages: list[dict] = [] + chat_result = handle_chat_interaction(client, model, history, user_email, stream) + history = chat_result.history + tool_messages = chat_result.tool_messages + tool_authorization = chat_result.tool_authorization - if stream: - # TODO Fix this in the client so users don't deal with these - # typing issues - stream_response = client.chat.completions.create( # type: ignore[call-overload] - model=model, - messages=history, - tool_choice="generate", - user=user_email, - stream=True, - ) - role, message_content, tool_messages = display_streamed_markdown( - stream_response, model - ) - - history += tool_messages - else: - response = client.chat.completions.create( # type: ignore[call-overload] - model=model, - messages=history, - tool_choice="generate", - user=user_email, - stream=False, - ) - message_content = response.choices[0].message.content or "" - - tool_messages = get_tool_messages(response.choices[0]) - history += tool_messages - - role = response.choices[0].message.role - if role == "assistant": - message_content = markdownify_urls(message_content) - console.print( - f"\n[bold blue]Assistant ({model}):[/bold blue] ", Markdown(message_content) - ) - else: - console.print(f"\n[bold magenta]{role}:[/bold magenta] {message_content}") + # wait for tool authorizations to complete, if any + if is_authorization_pending(tool_authorization): + with console.status("Waiting for you to authorize the action...", spinner="dots"): + wait_for_authorization_completion(client, tool_authorization) + # re-run the chat request now that authorization is complete + chat_result = handle_chat_interaction(client, model, history, user_email, stream) + history = chat_result.history + tool_messages = chat_result.tool_messages if debug: display_tool_messages(tool_messages) - history.append({"role": role, "content": message_content}) - except KeyboardInterrupt: console.print("Chat stopped by user.", style="bold blue") typer.Exit() diff --git a/arcade/arcade/cli/utils.py b/arcade/arcade/cli/utils.py index fd6b306c..be970157 100644 --- a/arcade/arcade/cli/utils.py +++ b/arcade/arcade/cli/utils.py @@ -1,12 +1,18 @@ -from typing import TYPE_CHECKING, Any +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Union import typer from openai.resources.chat.completions import ChatCompletionChunk, Stream +from openai.types.chat.chat_completion import Choice as ChatCompletionChoice +from openai.types.chat.chat_completion_chunk import Choice as ChatCompletionChunkChoice from rich.console import Console from rich.markdown import Markdown from typer.core import TyperGroup from typer.models import Context +from arcade.client.client import Arcade +from arcade.client.schema import AuthResponse from arcade.core.catalog import ToolCatalog from arcade.core.config_model import Config from arcade.core.errors import ToolkitLoadError @@ -75,9 +81,15 @@ def get_tool_messages(choice: dict) -> list[dict]: return [] -def display_streamed_markdown( - stream: Stream[ChatCompletionChunk], model: str -) -> tuple[str, str, list]: +@dataclass +class StreamingResult: + role: str + full_message: str + tool_messages: list + tool_authorization: dict | None + + +def handle_streaming_content(stream: Stream[ChatCompletionChunk], model: str) -> StreamingResult: """ Display the streamed markdown chunks as a single line. """ @@ -85,6 +97,7 @@ def display_streamed_markdown( full_message = "" tool_messages = [] + tool_authorization = None role = "" with Live(console=console, refresh_per_second=10) as live: for chunk in stream: @@ -101,13 +114,14 @@ def display_streamed_markdown( # Display and get tool messages if they exist tool_messages += get_tool_messages(choice) # type: ignore[arg-type] + tool_authorization = get_tool_authorization(choice) # Markdownify URLs in the final message if applicable if role == "assistant": full_message = markdownify_urls(full_message) live.update(Markdown(full_message)) - return role, full_message, tool_messages + return StreamingResult(role, full_message, tool_messages, tool_authorization) def markdownify_urls(message: str) -> str: @@ -268,3 +282,102 @@ def _format_evaluation(evaluation: "EvaluationResult") -> str: f"\n Actual: {actual}" ) return "\n".join(result_lines) + + +@dataclass +class ChatInteractionResult: + history: list[dict] + tool_messages: list[dict] + tool_authorization: dict | None + + +def handle_chat_interaction( + client: Arcade, model: str, history: list[dict], user_email: str | None, stream: bool = False +) -> ChatInteractionResult: + """ + Handle a single chat-request/chat-response interaction for both streamed and non-streamed responses. + Handling the chat response includes: + - Streaming the response if the stream flag is set + - Displaying the response in the console + - Getting the tool messages and tool authorization from the response + - Updating the history with the response, tool calls, and tool responses + """ + if stream: + # TODO Fix this in the client so users don't deal with these + # typing issues + response = client.chat.completions.create( # type: ignore[call-overload] + model=model, + messages=history, + tool_choice="generate", + user=user_email, + stream=True, + ) + streaming_result = handle_streaming_content(response, model) + role, message_content = streaming_result.role, streaming_result.full_message + tool_messages, tool_authorization = ( + streaming_result.tool_messages, + streaming_result.tool_authorization, + ) + else: + response = client.chat.completions.create( # type: ignore[call-overload] + model=model, + messages=history, + tool_choice="generate", + user=user_email, + stream=False, + ) + message_content = response.choices[0].message.content or "" + + # Get extra fields from the response + tool_messages = get_tool_messages(response.choices[0]) + tool_authorization = get_tool_authorization(response.choices[0]) + + role = response.choices[0].message.role + if role == "assistant": + message_content = markdownify_urls(message_content) + console.print( + f"\n[bold blue]Assistant ({model}):[/bold blue] ", Markdown(message_content) + ) + else: + console.print(f"\n[bold magenta]{role}:[/bold magenta] {message_content}") + + history += tool_messages + history.append({"role": role, "content": message_content}) + + return ChatInteractionResult(history, tool_messages, tool_authorization) + + +def wait_for_authorization_completion(client: Arcade, tool_authorization: dict | None) -> None: + """ + Wait for the authorization for a tool call to complete i.e., wait for the user to click on + the approval link and authorize Arcade. + """ + if tool_authorization is None: + return + auth_response = AuthResponse.model_validate(tool_authorization) + + while auth_response.status != "completed": + time.sleep(0.5) + auth_response = client.auth.status(auth_response) + + +def get_tool_authorization( + choice: Union[ChatCompletionChoice, ChatCompletionChunkChoice], +) -> dict | None: + """ + Get the tool authorization from a chat response's choice. + """ + if hasattr(choice, "tool_authorizations") and choice.tool_authorizations: + return choice.tool_authorizations[0] # type: ignore[no-any-return] + return None + + +def is_authorization_pending(tool_authorization: dict | None) -> bool: + """ + Check if the authorization for a tool call is pending. + Expects a chat response's choice.tool_authorizations as input. + """ + is_auth_pending = ( + tool_authorization is not None and tool_authorization.get("status", "") == "pending" + ) + return is_auth_pending