Poll for auth completion in arcade chat (#66)

This PR re-organizes arcade chat implementation and also adds polling
for auth response to complete (if any)


### arcade chat streaming with auth:
![Screenshot 2024-09-25 at 11 28
45 AM](https://github.com/user-attachments/assets/c351fce7-060a-4060-b215-6b5d05028216)

### arcade chat without streaming with auth:
![Screenshot 2024-09-25 at 11 33
08 AM](https://github.com/user-attachments/assets/29a6c5ad-857c-47e6-92d9-52ec87ff88c9)

---------

Co-authored-by: Nate Barbettini <nathanaelb@gmail.com>
This commit is contained in:
Eric Gustin 2024-09-25 13:43:04 -07:00 committed by GitHub
parent 94f77f26af
commit 9eb9f77a92
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 133 additions and 48 deletions

View file

@ -9,7 +9,6 @@ from urllib.parse import urlencode
import typer
from rich.console import Console
from rich.markdown import Markdown
from rich.markup import escape
from rich.table import Table
from rich.text import Text
@ -21,11 +20,11 @@ from arcade.cli.utils import (
apply_config_overrides,
create_cli_catalog,
display_eval_results,
display_streamed_markdown,
display_tool_messages,
get_tool_messages,
markdownify_urls,
handle_chat_interaction,
is_authorization_pending,
validate_and_get_config,
wait_for_authorization_completion,
)
from arcade.client import Arcade
from arcade.client.errors import EngineNotHealthyError, EngineOfflineError
@ -245,50 +244,23 @@ def chat(
history.append({"role": "user", "content": user_input})
tool_messages: list[dict] = []
chat_result = handle_chat_interaction(client, model, history, user_email, stream)
history = chat_result.history
tool_messages = chat_result.tool_messages
tool_authorization = chat_result.tool_authorization
if stream:
# TODO Fix this in the client so users don't deal with these
# typing issues
stream_response = client.chat.completions.create( # type: ignore[call-overload]
model=model,
messages=history,
tool_choice="generate",
user=user_email,
stream=True,
)
role, message_content, tool_messages = display_streamed_markdown(
stream_response, model
)
history += tool_messages
else:
response = client.chat.completions.create( # type: ignore[call-overload]
model=model,
messages=history,
tool_choice="generate",
user=user_email,
stream=False,
)
message_content = response.choices[0].message.content or ""
tool_messages = get_tool_messages(response.choices[0])
history += tool_messages
role = response.choices[0].message.role
if role == "assistant":
message_content = markdownify_urls(message_content)
console.print(
f"\n[bold blue]Assistant ({model}):[/bold blue] ", Markdown(message_content)
)
else:
console.print(f"\n[bold magenta]{role}:[/bold magenta] {message_content}")
# wait for tool authorizations to complete, if any
if is_authorization_pending(tool_authorization):
with console.status("Waiting for you to authorize the action...", spinner="dots"):
wait_for_authorization_completion(client, tool_authorization)
# re-run the chat request now that authorization is complete
chat_result = handle_chat_interaction(client, model, history, user_email, stream)
history = chat_result.history
tool_messages = chat_result.tool_messages
if debug:
display_tool_messages(tool_messages)
history.append({"role": role, "content": message_content})
except KeyboardInterrupt:
console.print("Chat stopped by user.", style="bold blue")
typer.Exit()

View file

@ -1,12 +1,18 @@
from typing import TYPE_CHECKING, Any
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Union
import typer
from openai.resources.chat.completions import ChatCompletionChunk, Stream
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
from openai.types.chat.chat_completion_chunk import Choice as ChatCompletionChunkChoice
from rich.console import Console
from rich.markdown import Markdown
from typer.core import TyperGroup
from typer.models import Context
from arcade.client.client import Arcade
from arcade.client.schema import AuthResponse
from arcade.core.catalog import ToolCatalog
from arcade.core.config_model import Config
from arcade.core.errors import ToolkitLoadError
@ -75,9 +81,15 @@ def get_tool_messages(choice: dict) -> list[dict]:
return []
def display_streamed_markdown(
stream: Stream[ChatCompletionChunk], model: str
) -> tuple[str, str, list]:
@dataclass
class StreamingResult:
role: str
full_message: str
tool_messages: list
tool_authorization: dict | None
def handle_streaming_content(stream: Stream[ChatCompletionChunk], model: str) -> StreamingResult:
"""
Display the streamed markdown chunks as a single line.
"""
@ -85,6 +97,7 @@ def display_streamed_markdown(
full_message = ""
tool_messages = []
tool_authorization = None
role = ""
with Live(console=console, refresh_per_second=10) as live:
for chunk in stream:
@ -101,13 +114,14 @@ def display_streamed_markdown(
# Display and get tool messages if they exist
tool_messages += get_tool_messages(choice) # type: ignore[arg-type]
tool_authorization = get_tool_authorization(choice)
# Markdownify URLs in the final message if applicable
if role == "assistant":
full_message = markdownify_urls(full_message)
live.update(Markdown(full_message))
return role, full_message, tool_messages
return StreamingResult(role, full_message, tool_messages, tool_authorization)
def markdownify_urls(message: str) -> str:
@ -268,3 +282,102 @@ def _format_evaluation(evaluation: "EvaluationResult") -> str:
f"\n Actual: {actual}"
)
return "\n".join(result_lines)
@dataclass
class ChatInteractionResult:
history: list[dict]
tool_messages: list[dict]
tool_authorization: dict | None
def handle_chat_interaction(
client: Arcade, model: str, history: list[dict], user_email: str | None, stream: bool = False
) -> ChatInteractionResult:
"""
Handle a single chat-request/chat-response interaction for both streamed and non-streamed responses.
Handling the chat response includes:
- Streaming the response if the stream flag is set
- Displaying the response in the console
- Getting the tool messages and tool authorization from the response
- Updating the history with the response, tool calls, and tool responses
"""
if stream:
# TODO Fix this in the client so users don't deal with these
# typing issues
response = client.chat.completions.create( # type: ignore[call-overload]
model=model,
messages=history,
tool_choice="generate",
user=user_email,
stream=True,
)
streaming_result = handle_streaming_content(response, model)
role, message_content = streaming_result.role, streaming_result.full_message
tool_messages, tool_authorization = (
streaming_result.tool_messages,
streaming_result.tool_authorization,
)
else:
response = client.chat.completions.create( # type: ignore[call-overload]
model=model,
messages=history,
tool_choice="generate",
user=user_email,
stream=False,
)
message_content = response.choices[0].message.content or ""
# Get extra fields from the response
tool_messages = get_tool_messages(response.choices[0])
tool_authorization = get_tool_authorization(response.choices[0])
role = response.choices[0].message.role
if role == "assistant":
message_content = markdownify_urls(message_content)
console.print(
f"\n[bold blue]Assistant ({model}):[/bold blue] ", Markdown(message_content)
)
else:
console.print(f"\n[bold magenta]{role}:[/bold magenta] {message_content}")
history += tool_messages
history.append({"role": role, "content": message_content})
return ChatInteractionResult(history, tool_messages, tool_authorization)
def wait_for_authorization_completion(client: Arcade, tool_authorization: dict | None) -> None:
"""
Wait for the authorization for a tool call to complete i.e., wait for the user to click on
the approval link and authorize Arcade.
"""
if tool_authorization is None:
return
auth_response = AuthResponse.model_validate(tool_authorization)
while auth_response.status != "completed":
time.sleep(0.5)
auth_response = client.auth.status(auth_response)
def get_tool_authorization(
choice: Union[ChatCompletionChoice, ChatCompletionChunkChoice],
) -> dict | None:
"""
Get the tool authorization from a chat response's choice.
"""
if hasattr(choice, "tool_authorizations") and choice.tool_authorizations:
return choice.tool_authorizations[0] # type: ignore[no-any-return]
return None
def is_authorization_pending(tool_authorization: dict | None) -> bool:
"""
Check if the authorization for a tool call is pending.
Expects a chat response's choice.tool_authorizations as input.
"""
is_auth_pending = (
tool_authorization is not None and tool_authorization.get("status", "") == "pending"
)
return is_auth_pending