Start Slack toolkit (#17)

- Start a Slack toolkit with a few tools
- Update Google auth
- Show user's email in `arcade chat`
This commit is contained in:
Nate Barbettini 2024-08-22 16:12:42 -07:00 committed by GitHub
parent 3154298572
commit acba912816
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 272 additions and 57 deletions

View file

@ -220,17 +220,30 @@ def chat(
)
console.print(chat_header)
user = config.user.email if config.user and config.user.email else None
user_attribution = f" ({user})" if user else ""
while True:
user_input = console.input("\n[bold magenta]User: [/bold magenta]")
user_input = console.input(
f"\n[magenta][bold]User[/bold]{user_attribution}:[/magenta] "
)
messages.append({"role": "user", "content": user_input})
if stream:
stream_response = client.stream_complete(
model=model, messages=messages, tool_choice="generate"
model=model,
messages=messages,
tool_choice="generate",
user=user,
)
display_streamed_markdown(stream_response)
else:
response = client.complete(model=model, messages=messages, tool_choice="generate")
response = client.complete(
model=model,
messages=messages,
tool_choice="generate",
user=user,
)
message_content = response.choices[0].message.content or ""
role = response.choices[0].message.role
@ -350,7 +363,7 @@ def display_config_as_table(config: Config) -> None:
table.add_column("Name")
table.add_column("Value")
for section_name in config.dict():
for section_name in config.model_dump():
section = getattr(config, section_name)
if section:
section = section.dict()

View file

@ -26,6 +26,7 @@ from arcade.core.schema import (
GoogleRequirement,
InputParameter,
OAuth2Requirement,
SlackUserRequirement,
ToolAuthRequirement,
ToolContext,
ToolDefinition,
@ -42,7 +43,7 @@ from arcade.core.utils import (
snake_to_pascal_case,
)
from arcade.sdk.annotations import Inferrable
from arcade.sdk.auth import Google, OAuth2, ToolAuthorization
from arcade.sdk.auth import Google, OAuth2, SlackUser, ToolAuthorization
WireType = Literal["string", "integer", "float", "boolean", "json"]
@ -190,6 +191,10 @@ class ToolCatalog(BaseModel):
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
elif isinstance(auth_requirement, Google):
new_auth_requirement.google = GoogleRequirement(**auth_requirement.model_dump())
elif isinstance(auth_requirement, SlackUser):
new_auth_requirement.slack_user = SlackUserRequirement(
**auth_requirement.model_dump()
)
auth_requirement = new_auth_requirement
return ToolDefinition(

View file

@ -79,6 +79,13 @@ class GoogleRequirement(BaseModel):
"""The scope(s) needed for authorization."""
class SlackUserRequirement(BaseModel):
"""Indicates that the tool requires Slack (user token) authorization."""
scope: Optional[list[str]] = None
"""The scope(s) needed for authorization."""
class ToolAuthRequirement(BaseModel):
"""A requirement for authorization to use a tool."""
@ -91,6 +98,9 @@ class ToolAuthRequirement(BaseModel):
google: Optional[GoogleRequirement] = None
"""The Google requirement, if any."""
slack_user: Optional[SlackUserRequirement] = None
"""The Slack (user token) requirement, if any."""
class ToolRequirements(BaseModel):
"""The requirements for a tool to run."""

View file

@ -38,6 +38,16 @@ class Google(ToolAuthorization):
"""The scope(s) needed for the authorized action."""
class SlackUser(ToolAuthorization):
"""Marks a tool as requiring Slack (user token) authorization."""
def get_provider(self) -> str:
return "slack_user"
scope: Optional[list[str]] = None
"""The scope(s) needed for the authorized action."""
class GitHubApp(ToolAuthorization):
"""Marks a tool as requiring GitHub App authorization."""

View file

@ -2,9 +2,9 @@ from fastapi import FastAPI, HTTPException
from openai import AsyncOpenAI
from pydantic import BaseModel
from arcade_arithmetic.tools import arithmetic
from arcade_gmail.tools import gmail
from arcade_github.tools import public_repo, user
from arcade_github.tools import repo, user
from arcade_slack.tools import chat
from arcade.actor.fastapi.actor import FastAPIActor
@ -13,13 +13,16 @@ client = AsyncOpenAI(base_url="http://localhost:9099/v1")
app = FastAPI()
actor = FastAPIActor(app)
actor.register_tool(arithmetic.add)
actor.register_tool(arithmetic.multiply)
actor.register_tool(arithmetic.divide)
actor.register_tool(arithmetic.sqrt)
# actor.register_tool(arithmetic.add)
# actor.register_tool(arithmetic.multiply)
# actor.register_tool(arithmetic.divide)
# actor.register_tool(arithmetic.sqrt)
actor.register_tool(gmail.get_emails)
actor.register_tool(public_repo.count_stargazers)
actor.register_tool(gmail.write_draft)
actor.register_tool(repo.count_stargazers)
actor.register_tool(repo.search_issues)
actor.register_tool(user.set_starred)
actor.register_tool(chat.send_dm_to_user)
class ChatRequest(BaseModel):
@ -27,7 +30,7 @@ class ChatRequest(BaseModel):
@app.post("/chat")
async def chat(request: ChatRequest, tool_choice: str = "execute"):
async def postChat(request: ChatRequest, tool_choice: str = "execute"):
try:
raw_response = await client.chat.completions.create(
messages=[
@ -35,16 +38,15 @@ async def chat(request: ChatRequest, tool_choice: str = "execute"):
{"role": "user", "content": request.message},
],
model="gpt-4o-mini",
max_tokens=150,
max_tokens=500,
# TODO tests for tool choice
tools=[
"Add",
"Multiply",
"Divide",
"Sqrt",
"GetEmails",
"WriteDraft",
"CountStargazers",
"SetStarred",
"SearchIssues",
"SendDmToUser",
],
tool_choice=tool_choice,
user="sam",

View file

@ -116,15 +116,32 @@ googleapis-common-protos = "1.63.2"
type = "directory"
url = "../../toolkits/gmail"
[[package]]
name = "arcade-slack"
version = "0.1.0"
description = "Slack tools for LLMs"
optional = false
python-versions = "^3.10"
files = []
develop = true
[package.dependencies]
arcade-ai = "^0.1.0"
slack-sdk = "^3.31.0"
[package.source]
type = "directory"
url = "../../toolkits/slack"
[[package]]
name = "cachetools"
version = "5.4.0"
version = "5.5.0"
description = "Extensible memoizing collections and decorators"
optional = false
python-versions = ">=3.7"
files = [
{file = "cachetools-5.4.0-py3-none-any.whl", hash = "sha256:3ae3b49a3d5e28a77a0be2b37dbcb89005058959cb2323858c2657c4a8cab474"},
{file = "cachetools-5.4.0.tar.gz", hash = "sha256:b8adc2e7c07f105ced7bc56dbb6dfbe7c4a00acce20e2227b3f355be89bc6827"},
{file = "cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292"},
{file = "cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a"},
]
[[package]]
@ -625,13 +642,13 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
[[package]]
name = "openai"
version = "1.40.8"
version = "1.42.0"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.7.1"
files = [
{file = "openai-1.40.8-py3-none-any.whl", hash = "sha256:3ed4ddad48e0dde059c9b4d3dc240e47781beca2811e52ba449ddc4a471a2fd4"},
{file = "openai-1.40.8.tar.gz", hash = "sha256:e225f830b946378e214c5b2cfa8df28ba2aeb7e9d44f738cb2a926fd971f5bc0"},
{file = "openai-1.42.0-py3-none-any.whl", hash = "sha256:dc91e0307033a4f94931e5d03cc3b29b9717014ad5e73f9f2051b6cb5eda4d80"},
{file = "openai-1.42.0.tar.gz", hash = "sha256:c9d31853b4e0bc2dc8bd08003b462a006035655a701471695d0bfdc08529cde3"},
]
[package.dependencies]
@ -982,6 +999,20 @@ files = [
[package.dependencies]
pyasn1 = ">=0.1.3"
[[package]]
name = "slack-sdk"
version = "3.31.0"
description = "The Slack API Platform SDK for Python"
optional = false
python-versions = ">=3.6"
files = [
{file = "slack_sdk-3.31.0-py2.py3-none-any.whl", hash = "sha256:a120cc461e8ebb7d9175f171dbe0ded37a6878d9f7b96b28e4bad1227399047b"},
{file = "slack_sdk-3.31.0.tar.gz", hash = "sha256:740d2f9c49cbfcbd46fca56b4be9d527934c225312aac18fd2c0fca0ef6bc935"},
]
[package.extras]
optional = ["SQLAlchemy (>=1.4,<3)", "aiodns (>1.0)", "aiohttp (>=3.7.3,<4)", "boto3 (<=2)", "websocket-client (>=1,<2)", "websockets (>=9.1,<13)"]
[[package]]
name = "sniffio"
version = "1.3.1"
@ -1115,4 +1146,4 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "ba5d3be15ef9c2adf5c108b4f8ff3388f3a4df546fa84f0bcfeb56c2768171ad"
content-hash = "b40d9fea27174b3ac18d83a936744819ed484e1baa2aaba5ab54ae9e4681c416"

View file

@ -11,6 +11,7 @@ arcade-ai = {path = "../../arcade", develop = true}
arcade_arithmetic = {path = "../../toolkits/math", develop = true}
arcade_gmail = {path = "../../toolkits/gmail", develop = true}
arcade_github = {path = "../../toolkits/github", develop = true}
arcade_slack = {path = "../../toolkits/slack", develop = true}
[build-system]
requires = ["poetry-core"]

View file

@ -1,31 +0,0 @@
from typing import Annotated
from arcade.sdk import tool
import requests
# TODO: This does not support private repositories. https://app.clickup.com/t/86b1r3mhe
@tool
def count_stargazers(
owner: Annotated[str, "The owner of the repository"],
name: Annotated[str, "The name of the repository"],
) -> int:
"""Count the number of stargazers (stars) for a public GitHub repository.
For example, to count the number of stars for microsoft/vscode, you would use:
```
count_stargazers(owner="microsoft", name="vscode")
```
"""
url = f"https://api.github.com/repos/{owner}/{name}"
response = requests.get(url)
print(response)
if response.status_code == 200:
data = response.json()
return data.get("stargazers_count", 0)
else:
raise Exception(
f"Failed to fetch repository data. Status code: {response.status_code}"
)

View file

@ -0,0 +1,74 @@
from typing import Annotated
from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import GitHubApp
import requests
@tool(requires_auth=GitHubApp())
def search_issues(
context: ToolContext,
owner: Annotated[str, "The owner of the repository"],
name: Annotated[str, "The name of the repository"],
query: Annotated[str, "The query to search for"],
limit: Annotated[int, "The maximum number of issues to return"] = 10,
) -> dict[str, list[dict]]:
"""Search for issues in a GitHub repository."""
# Build the search query
url = f"https://api.github.com/search/issues?q={query}+is:issue+is:open+repo:{owner}/{name}+sort:created-desc&per_page={limit}"
# Make the API request
headers = {
"Authorization": f"token {context.authorization.token}",
"Accept": "application/vnd.github.v3+json",
}
response = requests.get(url, headers=headers)
# Check for successful response
# handle 422 for can't find repo
# TODO - how should errors bubble back up if tool_choice=execute
if response.status_code != 200:
raise Exception(f"Failed to fetch issues: {response.status_code}")
issues = response.json().get("items", [])
results = []
for issue in issues:
results.append(
{
"title": issue["title"],
"url": issue["html_url"],
"created_at": issue["created_at"],
}
)
return {"issues": results}
# TODO: This does not support private repositories. https://app.clickup.com/t/86b1r3mhe
@tool
def count_stargazers(
owner: Annotated[str, "The owner of the repository"],
name: Annotated[str, "The name of the repository"],
) -> int:
"""Count the number of stargazers (stars) for a public GitHub repository.
For example, to count the number of stars for microsoft/vscode, you would use:
```
count_stargazers(owner="microsoft", name="vscode")
```
"""
url = f"https://api.github.com/repos/{owner}/{name}"
response = requests.get(url)
print(response)
if response.status_code == 200:
data = response.json()
return data.get("stargazers_count", 0)
else:
raise Exception(
f"Failed to fetch repository data. Status code: {response.status_code}"
)

View file

@ -1,9 +1,11 @@
import base64
from email.mime.text import MIMEText
import re
from base64 import urlsafe_b64decode
from typing import Annotated
from bs4 import BeautifulSoup
from google.auth.credentials import Credentials
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from arcade.core.schema import ToolContext
@ -11,6 +13,40 @@ from arcade.sdk import tool
from arcade.sdk.auth import Google
@tool(
requires_auth=Google(
scope=["https://www.googleapis.com/auth/gmail.compose"],
)
)
async def write_draft(
context: ToolContext,
subject: Annotated[str, "The subject of the email"],
body: Annotated[str, "The body of the email"],
recipient: Annotated[str, "The recipient of the email"],
) -> Annotated[str, "The URL of the draft"]:
"""Compose a new email draft."""
# Set up the Gmail API client
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
message = MIMEText(body)
message["to"] = recipient
message["subject"] = subject
# Encode the message in base64
raw_message = base64.urlsafe_b64encode(message.as_bytes()).decode()
# Create the draft
draft = {"message": {"raw": raw_message}}
draft_message = service.users().drafts().create(userId="me", body=draft).execute()
return f"Draft created: {get_draft_url(draft_message['id'])}"
def get_draft_url(draft_id):
return f"https://mail.google.com/mail/u/0/#drafts/{draft_id}"
@tool(
requires_auth=Google(
scope=["https://www.googleapis.com/auth/gmail.readonly"],

View file

View file

@ -0,0 +1,47 @@
from typing import Annotated
from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import SlackUser
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
@tool(
requires_auth=SlackUser(
scope=["chat:write", "im:write", "users.profile:read", "users:read"],
)
)
def send_dm_to_user(
context: ToolContext,
user_name: Annotated[str, "The Slack username of the person you want to message"],
message: Annotated[str, "The message you want to send"],
):
"""Send a direct message to a user in Slack."""
slackClient = WebClient(token=context.authorization.token)
try:
# Step 1: Retrieve the user's Slack ID based on their username
response = slackClient.users_list()
user_id = None
for user in response["members"]:
if user["name"].lower() == user_name.lower():
user_id = user["id"]
break
if not user_id:
# does this end up as a developerMessage?
# does it end up in the LLM context?
# provide the dev an Error type that controls what ends up in the LLM context
raise ValueError(f"User with username '{user_name}' not found.")
# Step 2: Retrieve the DM channel ID with the user
im_response = slackClient.conversations_open(users=[user_id])
dm_channel_id = im_response["channel"]["id"]
# Step 3: Send the message as if it's from you (because we're using a user token)
slackClient.chat_postMessage(channel=dm_channel_id, text=message)
except SlackApiError as e:
# this should be caught also, not printed
print(f"Error sending message: {e.response['error']}")

View file

@ -0,0 +1,17 @@
[tool.poetry]
name = "arcade_slack"
version = "0.1.0"
description = "Slack tools for LLMs"
authors = ["Nate Barbettini <nate@arcade-ai.com>"]
[tool.poetry.dependencies]
python = "^3.10"
arcade-ai = "^0.1.0"
slack-sdk = "^3.31.0"
[tool.poetry.dev-dependencies]
pytest = "^7.4.0"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"