QOL improvements for arcade chat (#27)

Includes these improvements:
- Up-arrow loads your previous message history in the session (if
available)
- Shows the target engine URL when the chat starts
- Shows the model name in the chat: e.g. `Assistant (gpt-4o-mini): `
- URLs are re-written as clickable markdown links in the chat output
- mypy cleanup
This commit is contained in:
Nate Barbettini 2024-08-30 13:06:38 -07:00 committed by GitHub
parent 950e075750
commit aa1b59497b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 43 additions and 12 deletions

View file

@ -1,4 +1,5 @@
import os
import readline
import threading
import uuid
import webbrowser
@ -12,11 +13,12 @@ from rich.markup import escape
from rich.table import Table
from rich.text import Text
from arcade.cli.authn import check_existing_login, LocalAuthCallbackServer
from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login
from arcade.cli.utils import (
OrderCommands,
create_cli_catalog,
display_streamed_markdown,
markdownify_urls,
validate_and_get_config,
)
from arcade.client import Arcade
@ -150,13 +152,20 @@ def chat(
"bold magenta underline",
),
"\n",
"\n",
"Chatting with Arcade Engine at " + config.engine_url,
)
console.print(chat_header)
while True:
user_input = console.input(
f"\n[magenta][bold]User[/bold] {user_attribution}:[/magenta] "
)
console.print(f"\n[magenta][bold]User[/bold] {user_attribution}:[/magenta] ")
# Use input() instead of console.input() to leverage readline history
user_input = input()
# Add the input to history
readline.add_history(user_input)
messages.append({"role": "user", "content": user_input})
if stream:
@ -169,8 +178,7 @@ def chat(
user=user_email,
stream=True,
)
role, message = display_streamed_markdown(stream_response)
messages.append({"role": role, "content": message})
role, message_content = display_streamed_markdown(stream_response, model)
else:
response = client.chat.completions.create( # type: ignore[call-overload]
model=model,
@ -183,11 +191,14 @@ def chat(
role = response.choices[0].message.role
if role == "assistant":
console.print("\n[bold blue]Assistant:[/bold blue] ", Markdown(message_content))
message_content = markdownify_urls(message_content)
console.print(
f"\n[bold blue]Assistant ({model}):[/bold blue] ", Markdown(message_content)
)
else:
console.print(f"\n[bold magenta]{role}:[/bold magenta] {message_content}")
messages.append({"role": role, "content": message_content})
messages.append({"role": role, "content": message_content})
except KeyboardInterrupt:
console.print("Chat stopped by user.", style="bold blue")

View file

@ -3,7 +3,6 @@ from typing import TYPE_CHECKING
import typer
from openai.resources.chat.completions import ChatCompletionChunk, Stream
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from typer.core import TyperGroup
from typer.models import Context
@ -55,10 +54,11 @@ def create_cli_catalog(
return catalog
def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str, str]:
def display_streamed_markdown(stream: Stream[ChatCompletionChunk], model: str) -> tuple[str, str]:
"""
Display the streamed markdown chunks as a single line.
"""
from rich.live import Live
full_message = ""
role = ""
@ -69,12 +69,32 @@ def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str,
if role == "":
role = choice.delta.role or ""
if role == "assistant":
console.print("\n[bold blue]Assistant:[/bold blue] ")
console.print(f"\n[bold blue]Assistant ({model}):[/bold blue] ")
if chunk_message:
full_message += chunk_message
markdown_chunk = Markdown(full_message)
live.update(markdown_chunk)
return role, full_message
# Markdownify URLs in the final message if applicable
if role == "assistant":
full_message = markdownify_urls(full_message)
live.update(Markdown(full_message))
return role, full_message
def markdownify_urls(message: str) -> str:
"""
Convert URLs in the message to markdown links.
"""
import re
# This regex will match URLs that are not already formatted as markdown links:
# [Link text](https://example.com)
url_pattern = r"(?<!\]\()https?://\S+"
# Wrap all URLs in the message with markdown links
return re.sub(url_pattern, r"[Link](\g<0>)", message)
def validate_and_get_config(