Poll for auth completion in arcade chat (#66)
This PR re-organizes arcade chat implementation and also adds polling for auth response to complete (if any) ### arcade chat streaming with auth:  ### arcade chat without streaming with auth:  --------- Co-authored-by: Nate Barbettini <nathanaelb@gmail.com>
This commit is contained in:
parent
94f77f26af
commit
9eb9f77a92
2 changed files with 133 additions and 48 deletions
|
|
@ -9,7 +9,6 @@ from urllib.parse import urlencode
|
|||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.markup import escape
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
|
@ -21,11 +20,11 @@ from arcade.cli.utils import (
|
|||
apply_config_overrides,
|
||||
create_cli_catalog,
|
||||
display_eval_results,
|
||||
display_streamed_markdown,
|
||||
display_tool_messages,
|
||||
get_tool_messages,
|
||||
markdownify_urls,
|
||||
handle_chat_interaction,
|
||||
is_authorization_pending,
|
||||
validate_and_get_config,
|
||||
wait_for_authorization_completion,
|
||||
)
|
||||
from arcade.client import Arcade
|
||||
from arcade.client.errors import EngineNotHealthyError, EngineOfflineError
|
||||
|
|
@ -245,50 +244,23 @@ def chat(
|
|||
|
||||
history.append({"role": "user", "content": user_input})
|
||||
|
||||
tool_messages: list[dict] = []
|
||||
chat_result = handle_chat_interaction(client, model, history, user_email, stream)
|
||||
history = chat_result.history
|
||||
tool_messages = chat_result.tool_messages
|
||||
tool_authorization = chat_result.tool_authorization
|
||||
|
||||
if stream:
|
||||
# TODO Fix this in the client so users don't deal with these
|
||||
# typing issues
|
||||
stream_response = client.chat.completions.create( # type: ignore[call-overload]
|
||||
model=model,
|
||||
messages=history,
|
||||
tool_choice="generate",
|
||||
user=user_email,
|
||||
stream=True,
|
||||
)
|
||||
role, message_content, tool_messages = display_streamed_markdown(
|
||||
stream_response, model
|
||||
)
|
||||
|
||||
history += tool_messages
|
||||
else:
|
||||
response = client.chat.completions.create( # type: ignore[call-overload]
|
||||
model=model,
|
||||
messages=history,
|
||||
tool_choice="generate",
|
||||
user=user_email,
|
||||
stream=False,
|
||||
)
|
||||
message_content = response.choices[0].message.content or ""
|
||||
|
||||
tool_messages = get_tool_messages(response.choices[0])
|
||||
history += tool_messages
|
||||
|
||||
role = response.choices[0].message.role
|
||||
if role == "assistant":
|
||||
message_content = markdownify_urls(message_content)
|
||||
console.print(
|
||||
f"\n[bold blue]Assistant ({model}):[/bold blue] ", Markdown(message_content)
|
||||
)
|
||||
else:
|
||||
console.print(f"\n[bold magenta]{role}:[/bold magenta] {message_content}")
|
||||
# wait for tool authorizations to complete, if any
|
||||
if is_authorization_pending(tool_authorization):
|
||||
with console.status("Waiting for you to authorize the action...", spinner="dots"):
|
||||
wait_for_authorization_completion(client, tool_authorization)
|
||||
# re-run the chat request now that authorization is complete
|
||||
chat_result = handle_chat_interaction(client, model, history, user_email, stream)
|
||||
history = chat_result.history
|
||||
tool_messages = chat_result.tool_messages
|
||||
|
||||
if debug:
|
||||
display_tool_messages(tool_messages)
|
||||
|
||||
history.append({"role": role, "content": message_content})
|
||||
|
||||
except KeyboardInterrupt:
|
||||
console.print("Chat stopped by user.", style="bold blue")
|
||||
typer.Exit()
|
||||
|
|
|
|||
|
|
@ -1,12 +1,18 @@
|
|||
from typing import TYPE_CHECKING, Any
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import typer
|
||||
from openai.resources.chat.completions import ChatCompletionChunk, Stream
|
||||
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
|
||||
from openai.types.chat.chat_completion_chunk import Choice as ChatCompletionChunkChoice
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from typer.core import TyperGroup
|
||||
from typer.models import Context
|
||||
|
||||
from arcade.client.client import Arcade
|
||||
from arcade.client.schema import AuthResponse
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.core.config_model import Config
|
||||
from arcade.core.errors import ToolkitLoadError
|
||||
|
|
@ -75,9 +81,15 @@ def get_tool_messages(choice: dict) -> list[dict]:
|
|||
return []
|
||||
|
||||
|
||||
def display_streamed_markdown(
|
||||
stream: Stream[ChatCompletionChunk], model: str
|
||||
) -> tuple[str, str, list]:
|
||||
@dataclass
|
||||
class StreamingResult:
|
||||
role: str
|
||||
full_message: str
|
||||
tool_messages: list
|
||||
tool_authorization: dict | None
|
||||
|
||||
|
||||
def handle_streaming_content(stream: Stream[ChatCompletionChunk], model: str) -> StreamingResult:
|
||||
"""
|
||||
Display the streamed markdown chunks as a single line.
|
||||
"""
|
||||
|
|
@ -85,6 +97,7 @@ def display_streamed_markdown(
|
|||
|
||||
full_message = ""
|
||||
tool_messages = []
|
||||
tool_authorization = None
|
||||
role = ""
|
||||
with Live(console=console, refresh_per_second=10) as live:
|
||||
for chunk in stream:
|
||||
|
|
@ -101,13 +114,14 @@ def display_streamed_markdown(
|
|||
|
||||
# Display and get tool messages if they exist
|
||||
tool_messages += get_tool_messages(choice) # type: ignore[arg-type]
|
||||
tool_authorization = get_tool_authorization(choice)
|
||||
|
||||
# Markdownify URLs in the final message if applicable
|
||||
if role == "assistant":
|
||||
full_message = markdownify_urls(full_message)
|
||||
live.update(Markdown(full_message))
|
||||
|
||||
return role, full_message, tool_messages
|
||||
return StreamingResult(role, full_message, tool_messages, tool_authorization)
|
||||
|
||||
|
||||
def markdownify_urls(message: str) -> str:
|
||||
|
|
@ -268,3 +282,102 @@ def _format_evaluation(evaluation: "EvaluationResult") -> str:
|
|||
f"\n Actual: {actual}"
|
||||
)
|
||||
return "\n".join(result_lines)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatInteractionResult:
|
||||
history: list[dict]
|
||||
tool_messages: list[dict]
|
||||
tool_authorization: dict | None
|
||||
|
||||
|
||||
def handle_chat_interaction(
|
||||
client: Arcade, model: str, history: list[dict], user_email: str | None, stream: bool = False
|
||||
) -> ChatInteractionResult:
|
||||
"""
|
||||
Handle a single chat-request/chat-response interaction for both streamed and non-streamed responses.
|
||||
Handling the chat response includes:
|
||||
- Streaming the response if the stream flag is set
|
||||
- Displaying the response in the console
|
||||
- Getting the tool messages and tool authorization from the response
|
||||
- Updating the history with the response, tool calls, and tool responses
|
||||
"""
|
||||
if stream:
|
||||
# TODO Fix this in the client so users don't deal with these
|
||||
# typing issues
|
||||
response = client.chat.completions.create( # type: ignore[call-overload]
|
||||
model=model,
|
||||
messages=history,
|
||||
tool_choice="generate",
|
||||
user=user_email,
|
||||
stream=True,
|
||||
)
|
||||
streaming_result = handle_streaming_content(response, model)
|
||||
role, message_content = streaming_result.role, streaming_result.full_message
|
||||
tool_messages, tool_authorization = (
|
||||
streaming_result.tool_messages,
|
||||
streaming_result.tool_authorization,
|
||||
)
|
||||
else:
|
||||
response = client.chat.completions.create( # type: ignore[call-overload]
|
||||
model=model,
|
||||
messages=history,
|
||||
tool_choice="generate",
|
||||
user=user_email,
|
||||
stream=False,
|
||||
)
|
||||
message_content = response.choices[0].message.content or ""
|
||||
|
||||
# Get extra fields from the response
|
||||
tool_messages = get_tool_messages(response.choices[0])
|
||||
tool_authorization = get_tool_authorization(response.choices[0])
|
||||
|
||||
role = response.choices[0].message.role
|
||||
if role == "assistant":
|
||||
message_content = markdownify_urls(message_content)
|
||||
console.print(
|
||||
f"\n[bold blue]Assistant ({model}):[/bold blue] ", Markdown(message_content)
|
||||
)
|
||||
else:
|
||||
console.print(f"\n[bold magenta]{role}:[/bold magenta] {message_content}")
|
||||
|
||||
history += tool_messages
|
||||
history.append({"role": role, "content": message_content})
|
||||
|
||||
return ChatInteractionResult(history, tool_messages, tool_authorization)
|
||||
|
||||
|
||||
def wait_for_authorization_completion(client: Arcade, tool_authorization: dict | None) -> None:
|
||||
"""
|
||||
Wait for the authorization for a tool call to complete i.e., wait for the user to click on
|
||||
the approval link and authorize Arcade.
|
||||
"""
|
||||
if tool_authorization is None:
|
||||
return
|
||||
auth_response = AuthResponse.model_validate(tool_authorization)
|
||||
|
||||
while auth_response.status != "completed":
|
||||
time.sleep(0.5)
|
||||
auth_response = client.auth.status(auth_response)
|
||||
|
||||
|
||||
def get_tool_authorization(
|
||||
choice: Union[ChatCompletionChoice, ChatCompletionChunkChoice],
|
||||
) -> dict | None:
|
||||
"""
|
||||
Get the tool authorization from a chat response's choice.
|
||||
"""
|
||||
if hasattr(choice, "tool_authorizations") and choice.tool_authorizations:
|
||||
return choice.tool_authorizations[0] # type: ignore[no-any-return]
|
||||
return None
|
||||
|
||||
|
||||
def is_authorization_pending(tool_authorization: dict | None) -> bool:
|
||||
"""
|
||||
Check if the authorization for a tool call is pending.
|
||||
Expects a chat response's choice.tool_authorizations as input.
|
||||
"""
|
||||
is_auth_pending = (
|
||||
tool_authorization is not None and tool_authorization.get("status", "") == "pending"
|
||||
)
|
||||
return is_auth_pending
|
||||
|
|
|
|||
Loading…
Reference in a new issue