QOL improvements for arcade chat (#27)
Includes these improvements: - Up-arrow loads your previous message history in the session (if available) - Shows the target engine URL when the chat starts - Shows the model name in the chat: e.g. `Assistant (gpt-4o-mini): ` - URLs are re-written as clickable markdown links in the chat output - mypy cleanup
This commit is contained in:
parent
950e075750
commit
aa1b59497b
2 changed files with 43 additions and 12 deletions
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import readline
|
||||
import threading
|
||||
import uuid
|
||||
import webbrowser
|
||||
|
|
@ -12,11 +13,12 @@ from rich.markup import escape
|
|||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from arcade.cli.authn import check_existing_login, LocalAuthCallbackServer
|
||||
from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login
|
||||
from arcade.cli.utils import (
|
||||
OrderCommands,
|
||||
create_cli_catalog,
|
||||
display_streamed_markdown,
|
||||
markdownify_urls,
|
||||
validate_and_get_config,
|
||||
)
|
||||
from arcade.client import Arcade
|
||||
|
|
@ -150,13 +152,20 @@ def chat(
|
|||
"bold magenta underline",
|
||||
),
|
||||
"\n",
|
||||
"\n",
|
||||
"Chatting with Arcade Engine at " + config.engine_url,
|
||||
)
|
||||
console.print(chat_header)
|
||||
|
||||
while True:
|
||||
user_input = console.input(
|
||||
f"\n[magenta][bold]User[/bold] {user_attribution}:[/magenta] "
|
||||
)
|
||||
console.print(f"\n[magenta][bold]User[/bold] {user_attribution}:[/magenta] ")
|
||||
|
||||
# Use input() instead of console.input() to leverage readline history
|
||||
user_input = input()
|
||||
|
||||
# Add the input to history
|
||||
readline.add_history(user_input)
|
||||
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
if stream:
|
||||
|
|
@ -169,8 +178,7 @@ def chat(
|
|||
user=user_email,
|
||||
stream=True,
|
||||
)
|
||||
role, message = display_streamed_markdown(stream_response)
|
||||
messages.append({"role": role, "content": message})
|
||||
role, message_content = display_streamed_markdown(stream_response, model)
|
||||
else:
|
||||
response = client.chat.completions.create( # type: ignore[call-overload]
|
||||
model=model,
|
||||
|
|
@ -183,11 +191,14 @@ def chat(
|
|||
role = response.choices[0].message.role
|
||||
|
||||
if role == "assistant":
|
||||
console.print("\n[bold blue]Assistant:[/bold blue] ", Markdown(message_content))
|
||||
message_content = markdownify_urls(message_content)
|
||||
console.print(
|
||||
f"\n[bold blue]Assistant ({model}):[/bold blue] ", Markdown(message_content)
|
||||
)
|
||||
else:
|
||||
console.print(f"\n[bold magenta]{role}:[/bold magenta] {message_content}")
|
||||
|
||||
messages.append({"role": role, "content": message_content})
|
||||
messages.append({"role": role, "content": message_content})
|
||||
|
||||
except KeyboardInterrupt:
|
||||
console.print("Chat stopped by user.", style="bold blue")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from typing import TYPE_CHECKING
|
|||
import typer
|
||||
from openai.resources.chat.completions import ChatCompletionChunk, Stream
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
from typer.core import TyperGroup
|
||||
from typer.models import Context
|
||||
|
|
@ -55,10 +54,11 @@ def create_cli_catalog(
|
|||
return catalog
|
||||
|
||||
|
||||
def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str, str]:
|
||||
def display_streamed_markdown(stream: Stream[ChatCompletionChunk], model: str) -> tuple[str, str]:
|
||||
"""
|
||||
Display the streamed markdown chunks as a single line.
|
||||
"""
|
||||
from rich.live import Live
|
||||
|
||||
full_message = ""
|
||||
role = ""
|
||||
|
|
@ -69,12 +69,32 @@ def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str,
|
|||
if role == "":
|
||||
role = choice.delta.role or ""
|
||||
if role == "assistant":
|
||||
console.print("\n[bold blue]Assistant:[/bold blue] ")
|
||||
console.print(f"\n[bold blue]Assistant ({model}):[/bold blue] ")
|
||||
if chunk_message:
|
||||
full_message += chunk_message
|
||||
markdown_chunk = Markdown(full_message)
|
||||
live.update(markdown_chunk)
|
||||
return role, full_message
|
||||
|
||||
# Markdownify URLs in the final message if applicable
|
||||
if role == "assistant":
|
||||
full_message = markdownify_urls(full_message)
|
||||
live.update(Markdown(full_message))
|
||||
|
||||
return role, full_message
|
||||
|
||||
|
||||
def markdownify_urls(message: str) -> str:
|
||||
"""
|
||||
Convert URLs in the message to markdown links.
|
||||
"""
|
||||
import re
|
||||
|
||||
# This regex will match URLs that are not already formatted as markdown links:
|
||||
# [Link text](https://example.com)
|
||||
url_pattern = r"(?<!\]\()https?://\S+"
|
||||
|
||||
# Wrap all URLs in the message with markdown links
|
||||
return re.sub(url_pattern, r"[Link](\g<0>)", message)
|
||||
|
||||
|
||||
def validate_and_get_config(
|
||||
|
|
|
|||
Loading…
Reference in a new issue