Make the postgres toolikit better (#509)

Per the discussions in the [blog
post](https://docs.google.com/document/d/1wZi1yKWKOyg1dpueA2eBvTtYlUjqqp7e_M1nuAoF8NY/edit?tab=t.0),
let's update the postgres toolkit!

<img width="2141" height="1364" alt="Screenshot 2025-07-22 at 5 28
06 PM"
src="https://github.com/user-attachments/assets/1b6a5e0a-9429-4c16-9a0c-ac36ea520bea"
/>
This commit is contained in:
Evan Tahler 2025-07-23 15:28:57 -07:00 committed by GitHub
parent 30739dc44a
commit 4144a42392
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 589 additions and 82 deletions

View file

@ -5,3 +5,4 @@ sdreyer
wdawson
byrro
torresmateo
evantahler

View file

@ -78,27 +78,103 @@ class DatabaseEngine:
cls._engines.clear()
@classmethod
def sanitize_query(cls, query: str) -> str:
"""
Sanitize a query to not break our read-only session.
THIS IS REALLY UNSAFE AND SHOULD NOT BE USED IN PRODUCTION. USE A DATABASE CONNECTION WITH A READ-ONLY USER AND PREPARE STATEMENTS.
There are also valid reasons for the ";" character, and this prevents that.
"""
def sanitize_query( # noqa: C901
cls,
select_clause: str,
from_clause: str,
limit: int,
offset: int,
join_clause: str | None,
where_clause: str | None,
having_clause: str | None,
group_by_clause: str | None,
order_by_clause: str | None,
with_clause: str | None,
) -> tuple[str, dict[str, Any]]:
# Remove the leading keywords from the clauses if they are present
if select_clause.strip().split(" ")[0].upper() == "SELECT":
select_clause = select_clause.strip()[6:]
parts = query.split(";")
if len(parts) > 1:
raise RetryableToolError(
"Multiple statements are not allowed in a single query.",
developer_message="Multiple statements are not allowed in a single query.",
additional_prompt_content="Split your query into multiple queries and try again.",
)
if from_clause.strip().split(" ")[0].upper() == "FROM":
from_clause = from_clause.strip()[4:]
words = parts[0].split(" ")
if words[0].upper().strip() != "SELECT":
if join_clause and join_clause.strip().split(" ")[0].upper() == "JOIN":
join_clause = join_clause.strip()[4:]
if where_clause and where_clause.strip().split(" ")[0].upper() == "WHERE":
where_clause = where_clause.strip()[5:]
if group_by_clause and group_by_clause.strip().split(" ")[0].upper() == "GROUP BY":
group_by_clause = group_by_clause.strip()[8:]
if order_by_clause and order_by_clause.strip().split(" ")[0].upper() == "ORDER BY":
order_by_clause = order_by_clause.strip()[8:]
if having_clause and having_clause.strip().split(" ")[0].upper() == "HAVING":
having_clause = having_clause.strip()[6:]
first_select_word = select_clause.strip().split(" ")[0].upper()
if first_select_word in [
"INSERT",
"UPDATE",
"DELETE",
"CREATE",
"ALTER",
"DROP",
"TRUNCATE",
"REINDEX",
"VACUUM",
"ANALYZE",
"COMMENT",
]:
raise RetryableToolError(
"Only SELECT queries are allowed.",
developer_message="Only SELECT queries are allowed.",
additional_prompt_content="Use the <DiscoverTables> and <GetTableSchema> tools to discover the tables and try again.",
)
return f"{query}"
if select_clause.strip() == "*":
raise RetryableToolError(
"Do not use * in the select clause. Use a comma separated list of columns you wish to return.",
)
if limit > MAX_ROWS_RETURNED:
raise RetryableToolError(
f"Limit is too high. Maximum is {MAX_ROWS_RETURNED}.",
)
if offset < 0:
raise RetryableToolError(
"Offset must be greater than or equal to 0.",
developer_message="Offset must be greater than or equal to 0.",
)
if limit <= 0:
raise RetryableToolError(
"Limit must be greater than 0.",
developer_message="Limit must be greater than 0.",
)
# Build query with identifiers directly interpolated, but use parameters for values
parts = []
if with_clause:
parts.append(f"WITH {with_clause}")
parts.append(f"SELECT {select_clause} FROM {from_clause}") # noqa: S608
if join_clause:
parts.append(f"JOIN {join_clause}")
if where_clause:
parts.append(f"WHERE {where_clause}")
if group_by_clause:
parts.append(f"GROUP BY {group_by_clause}")
if having_clause:
parts.append(f"HAVING {having_clause}")
if order_by_clause:
parts.append(f"ORDER BY {order_by_clause}")
parts.append("LIMIT :limit OFFSET :offset")
query = " ".join(parts)
# Only use parameters for values, not identifiers
parameters = {
"limit": limit,
"offset": offset,
}
return query, parameters

View file

@ -29,7 +29,7 @@ async def discover_tables(
) -> list[str]:
"""Discover all the tables in the postgres database when the list of tables is not known.
THIS TOOL SHOULD ALWAYS BE USED BEFORE ANY OTHER TOOL THAT REQUIRES A TABLE NAME.
ALWAYS use this tool before any other tool that requires a table name.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("DATABASE_CONNECTION_STRING")
@ -47,7 +47,7 @@ async def get_table_schema(
"""
Get the schema/structure of a postgres table in the postgres database when the schema is not known, and the name of the table is provided.
THIS TOOL SHOULD ALWAYS BE USED BEFORE EXECUTING ANY QUERY. ALL TABLES IN THE QUERY MUST BE DISCOVERED FIRST USING THE <DiscoverTables> TOOL.
This tool should ALWAYS be used before executing any query. All tables in the query must be discovered first using the <DiscoverTables> tool.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("DATABASE_CONNECTION_STRING")
@ -56,35 +56,86 @@ async def get_table_schema(
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
async def execute_query(
async def execute_select_query(
context: ToolContext,
query: Annotated[str, "The postgres SQL query to execute. Only SELECT queries are allowed."],
select_clause: Annotated[
str,
"This is the part of the SQL query that comes after the SELECT keyword wish a comma separated list of columns you wish to return. Do not include the SELECT keyword.",
],
from_clause: Annotated[
str,
"This is the part of the SQL query that comes after the FROM keyword. Do not include the FROM keyword.",
],
limit: Annotated[
int,
"The maximum number of rows to return. This is the LIMIT clause of the query. Default: 100.",
] = 100,
offset: Annotated[
int, "The number of rows to skip. This is the OFFSET clause of the query. Default: 0."
] = 0,
join_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the JOIN keyword. Do not include the JOIN keyword. If no join is needed, leave this blank.",
] = None,
where_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the WHERE keyword. Do not include the WHERE keyword. If no where clause is needed, leave this blank.",
] = None,
having_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the HAVING keyword. Do not include the HAVING keyword. If no having clause is needed, leave this blank.",
] = None,
group_by_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the GROUP BY keyword. Do not include the GROUP BY keyword. If no group by clause is needed, leave this blank.",
] = None,
order_by_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the ORDER BY keyword. Do not include the ORDER BY keyword. If no order by clause is needed, leave this blank.",
] = None,
with_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the WITH keyword when basing the query on a virtual table. If no WITH clause is needed, leave this blank.",
] = None,
) -> list[str]:
"""
You have a connection to a postgres database.
Execute a query and return the results against the postgres database.
Execute a SELECT query and return the results against the postgres database. No other queries (INSERT, UPDATE, DELETE, etc.) are allowed.
ONLY USE THIS TOOL IF YOU HAVE ALREADY LOADED THE SCHEMA OF THE TABLES YOU NEED TO QUERY. USE THE <GetTableSchema> TOOL TO LOAD THE SCHEMA IF NOT ALREADY KNOWN.
ONLY use this tool if you have already loaded the schema of the tables you need to query. Use the <GetTableSchema> tool to load the schema if not already known.
The final query will be constructed as follows:
SELECT {select_query_part} FROM {from_clause} JOIN {join_clause} WHERE {where_clause} HAVING {having_clause} ORDER BY {order_by_clause} LIMIT {limit} OFFSET {offset}
When running queries, follow these rules which will help avoid errors:
* Never "select *" from a table. Always select the columns you need.
* Always order your results by the most important columns first. If you aren't sure, order by the primary key.
* Always use case-insensitive queries to match strings in the query.
* Always trim strings in the query.
* Prefer LIKE queries over direct string matches or regex queries.
* Only join on columns that are indexed or the primary key. Do not join on arbitrary columns.
Only SELECT queries are allowed. Do not use INSERT, UPDATE, DELETE, or other DML statements. This tool will reject them.
Unless otherwise specified, ensure that query has a LIMIT of 100 for all results. This tool will enforce that no more than 1000 rows are returned at maximum.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("DATABASE_CONNECTION_STRING")
) as engine:
try:
return await _execute_query(engine, query)
return await _execute_query(
engine,
select_clause=select_clause,
from_clause=from_clause,
limit=limit,
offset=offset,
join_clause=join_clause,
where_clause=where_clause,
having_clause=having_clause,
group_by_clause=group_by_clause,
order_by_clause=order_by_clause,
with_clause=with_clause,
)
except Exception as e:
raise RetryableToolError(
f"Query failed: {e}",
developer_message=f"Query '{query}' failed.",
developer_message=f"Query failed with parameters: select_clause={select_clause}, from_clause={from_clause}, limit={limit}, offset={offset}, join_clause={join_clause}, where_clause={where_clause}, having_clause={having_clause}, order_by_clause={order_by_clause}, with_clause={with_clause}.",
additional_prompt_content="Load the database schema <GetTableSchema> or use the <DiscoverTables> tool to discover the tables and try again.",
retry_after_ms=10,
) from e
@ -168,11 +219,35 @@ async def _get_table_schema(engine: AsyncEngine, schema_name: str, table_name: s
async def _execute_query(
engine: AsyncEngine, query: str, params: dict[str, Any] | None = None
engine: AsyncEngine,
select_clause: str,
from_clause: str,
limit: int,
offset: int,
join_clause: str | None,
where_clause: str | None,
having_clause: str | None,
group_by_clause: str | None,
order_by_clause: str | None,
with_clause: str | None,
) -> list[str]:
"""Execute a query and return the results."""
async with engine.connect() as connection:
result = await connection.execute(text(DatabaseEngine.sanitize_query(query)), params)
query, parameters = DatabaseEngine.sanitize_query(
select_clause=select_clause,
from_clause=from_clause,
limit=limit,
offset=offset,
join_clause=join_clause,
where_clause=where_clause,
having_clause=having_clause,
group_by_clause=group_by_clause,
order_by_clause=order_by_clause,
with_clause=with_clause,
)
print(f"Query: {query}")
print(f"Parameters: {parameters}")
result = await connection.execute(text(query), parameters)
rows = result.fetchall()
results = [str(row) for row in rows]
return results[:MAX_ROWS_RETURNED]

View file

@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "arcade_postgres"
version = "0.1.0"
version = "0.2.0"
description = "Tools to query and explore a postgres database"
requires-python = ">=3.10"
dependencies = [

View file

@ -33,40 +33,261 @@ INSERT INTO "public"."messages" (
"created_at",
"updated_at"
)
VALUES (
VALUES -- User 1 (Alice) - 3 messages
(
1,
'Evan says hello',
'Hello everyone!',
1,
'2025-01-10 10:00:00.000000',
'2025-01-10 10:00:00.000000'
),
(
2,
'How is everyone doing today?',
1,
'2025-01-10 11:30:00.000000',
'2025-01-10 11:30:00.000000'
),
(
3,
'2025-04-10 17:21:05.504468',
'2025-04-10 17:21:05.504468'
'Great to see you all here!',
1,
'2025-01-10 14:15:00.000000',
'2025-01-10 14:15:00.000000'
),
-- User 2 (Bob) - 2 messages
(
4,
'Hi Alice! Doing well, thanks for asking.',
2,
'2025-01-10 11:35:00.000000',
'2025-01-10 11:35:00.000000'
),
(
5100,
'Hello! The current time is 2025-01-13T14:38:39.204Z',
12,
'2025-01-13 06:38:39.210897',
'2025-01-13 06:38:39.210897'
5,
'Anyone up for a game later?',
2,
'2025-01-10 16:20:00.000000',
'2025-01-10 16:20:00.000000'
),
-- User 3 (Charlie) - 3 messages
(
6,
'Count me in for the game!',
3,
'2025-01-10 16:25:00.000000',
'2025-01-10 16:25:00.000000'
),
(
5101,
'Hello! The current time is 2025-01-13T14:55:32.560Z',
12,
'2025-01-13 06:55:32.56934',
'2025-01-13 06:55:32.56934'
7,
'What time works for everyone?',
3,
'2025-01-10 16:30:00.000000',
'2025-01-10 16:30:00.000000'
),
(
5102,
'Hello! The current time is 2025-01-13T15:00:37.250Z',
12,
'2025-01-13 07:00:37.261816',
'2025-01-13 07:00:37.261816'
8,
'I can play around 8 PM',
3,
'2025-01-10 17:00:00.000000',
'2025-01-10 17:00:00.000000'
),
-- User 4 (Diana) - 2 messages
(
9,
'8 PM works for me too!',
4,
'2025-01-10 17:05:00.000000',
'2025-01-10 17:05:00.000000'
),
(
10,
'What game should we play?',
4,
'2025-01-10 17:10:00.000000',
'2025-01-10 17:10:00.000000'
),
-- User 5 (Evan) - 3 messages
(
11,
'I suggest we try the new arcade game!',
5,
'2025-01-10 17:15:00.000000',
'2025-01-10 17:15:00.000000'
),
(
5319,
'Hello! The current time is 2025-01-14T07:17:07.115Z',
12,
'2025-01-13 23:17:07.123393',
'2025-01-13 23:17:07.123393'
'It has great multiplayer features',
5,
'2025-01-10 17:20:00.000000',
'2025-01-10 17:20:00.000000'
),
(
13,
'Perfect timing for a weekend session',
5,
'2025-01-10 18:00:00.000000',
'2025-01-10 18:00:00.000000'
),
-- User 6 (Fiona) - 2 messages
(
14,
'Sounds like fun! I love arcade games.',
6,
'2025-01-10 18:05:00.000000',
'2025-01-10 18:05:00.000000'
),
(
15,
'Should I bring snacks?',
6,
'2025-01-10 18:10:00.000000',
'2025-01-10 18:10:00.000000'
),
-- User 7 (George) - 3 messages
(
16,
'Snacks are always welcome!',
7,
'2025-01-10 18:15:00.000000',
'2025-01-10 18:15:00.000000'
),
(
17,
'I can bring some drinks',
7,
'2025-01-10 18:20:00.000000',
'2025-01-10 18:20:00.000000'
),
(
18,
'This is going to be awesome',
7,
'2025-01-10 19:00:00.000000',
'2025-01-10 19:00:00.000000'
),
-- User 8 (Helen) - 2 messages
(
19,
'I agree! Cannot wait for the game night.',
8,
'2025-01-10 19:05:00.000000',
'2025-01-10 19:05:00.000000'
),
(
20,
'Should we set up a Discord call?',
8,
'2025-01-10 19:10:00.000000',
'2025-01-10 19:10:00.000000'
),
-- User 9 (Ian) - 3 messages
(
21,
'Discord would be perfect for voice chat',
9,
'2025-01-10 19:15:00.000000',
'2025-01-10 19:15:00.000000'
),
(
22,
'I will create a server for us',
9,
'2025-01-10 19:20:00.000000',
'2025-01-10 19:20:00.000000'
),
(
23,
'Link will be shared in a few minutes',
9,
'2025-01-10 19:25:00.000000',
'2025-01-10 19:25:00.000000'
),
-- User 10 (Julia) - 2 messages
(
24,
'Thanks Ian! You are the best.',
10,
'2025-01-10 19:30:00.000000',
'2025-01-10 19:30:00.000000'
),
(
25,
'See you all at 8 PM!',
10,
'2025-01-10 19:35:00.000000',
'2025-01-10 19:35:00.000000'
),
-- Additional messages for Evan (user_id 5) - 10 more messages
(
26,
'Just finished setting up the game server!',
5,
'2025-01-10 20:00:00.000000',
'2025-01-10 20:00:00.000000'
),
(
27,
'Everyone should be able to connect now',
5,
'2025-01-10 20:05:00.000000',
'2025-01-10 20:05:00.000000'
),
(
28,
'I added some custom maps too',
5,
'2025-01-10 20:10:00.000000',
'2025-01-10 20:10:00.000000'
),
(
29,
'The graphics look amazing on this new version',
5,
'2025-01-10 20:15:00.000000',
'2025-01-10 20:15:00.000000'
),
(
30,
'Hope you all enjoy the new features',
5,
'2025-01-10 20:20:00.000000',
'2025-01-10 20:20:00.000000'
),
(
31,
'I also set up a leaderboard system',
5,
'2025-01-10 20:25:00.000000',
'2025-01-10 20:25:00.000000'
),
(
32,
'We can track high scores now',
5,
'2025-01-10 20:30:00.000000',
'2025-01-10 20:30:00.000000'
),
(
33,
'The game supports up to 8 players simultaneously',
5,
'2025-01-10 20:35:00.000000',
'2025-01-10 20:35:00.000000'
),
(
34,
'I tested it earlier and it runs smoothly',
5,
'2025-01-10 20:40:00.000000',
'2025-01-10 20:40:00.000000'
),
(
35,
'Cannot wait to see everyone online tonight!',
5,
'2025-01-10 20:45:00.000000',
'2025-01-10 20:45:00.000000'
);
INSERT INTO "public"."users" (
"id",
@ -79,36 +300,100 @@ INSERT INTO "public"."users" (
)
VALUES (
1,
'Mario',
'mario@example.com',
'Alice',
'alice@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E',
'2024-09-01 20:49:38.759432',
'2024-09-02 03:49:39.927',
'active'
),
(
3,
'Evan',
'evantahler@gmail.com',
2,
'Bob',
'bob@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY',
'2024-09-02 17:49:23.377425',
'2024-09-02 17:49:23.377425',
'active'
),
(
12,
'Admin',
'admin@arcade.dev',
3,
'Charlie',
'charlie@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo',
'2024-10-13 15:01:30.792909',
'2024-10-13 15:01:30.792909',
'inactive'
'2024-09-03 10:30:15.123456',
'2024-09-03 10:30:15.123456',
'active'
),
(
4,
'Diana',
'diana@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$xyz123ABC456DEF789GHI$SampleHashForDiana123',
'2024-09-04 14:20:30.654321',
'2024-09-04 14:20:30.654321',
'active'
),
(
5,
'Evan',
'evan@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$evanHash123$EvanPasswordHash456',
'2024-09-05 09:15:45.987654',
'2024-09-05 09:15:45.987654',
'active'
),
(
6,
'Fiona',
'fiona@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$fionaHash456$FionaPasswordHash789',
'2024-09-06 16:45:12.345678',
'2024-09-06 16:45:12.345678',
'active'
),
(
7,
'George',
'george@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$georgeHash789$GeorgePasswordHash012',
'2024-09-07 11:30:25.876543',
'2024-09-07 11:30:25.876543',
'active'
),
(
8,
'Helen',
'helen@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$helenHash012$HelenPasswordHash345',
'2024-09-08 13:25:40.234567',
'2024-09-08 13:25:40.234567',
'active'
),
(
9,
'Ian',
'ian@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$ianHash345$IanPasswordHash678',
'2024-09-09 08:40:55.765432',
'2024-09-09 08:40:55.765432',
'active'
),
(
10,
'Julia',
'julia@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$juliaHash678$JuliaPasswordHash901',
'2024-09-10 15:55:18.123456',
'2024-09-10 15:55:18.123456',
'active'
);
ALTER TABLE "public"."messages"
ADD FOREIGN KEY ("user_id") REFERENCES "public"."users"("id");
-- set pk to 13
ALTER SEQUENCE users_id_seq RESTART WITH 13;
-- set pk to 11
ALTER SEQUENCE users_id_seq RESTART WITH 11;
-- Indices
CREATE UNIQUE INDEX name_idx ON public.users USING btree (name);
CREATE UNIQUE INDEX email_idx ON public.users USING btree (email);
DROP INDEX IF EXISTS users_email_unique;
CREATE UNIQUE INDEX users_email_unique ON public.users USING btree (email);

View file

@ -7,7 +7,7 @@ from arcade_postgres.tools.postgres import (
DatabaseEngine,
discover_schemas,
discover_tables,
execute_query,
execute_select_query,
get_table_schema,
)
from arcade_tdk import ToolContext, ToolSecretItem
@ -89,31 +89,101 @@ async def test_get_table_schema(mock_context) -> None:
@pytest.mark.asyncio
async def test_execute_query(mock_context) -> None:
assert await execute_query(mock_context, "SELECT id, name, email FROM users WHERE id = 1") == [
"(1, 'Mario', 'mario@example.com')"
async def test_execute_select_query(mock_context) -> None:
assert await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
where_clause="id = 1",
) == [
"(1, 'Alice', 'alice@example.com')",
]
assert await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
order_by_clause="id",
limit=1,
offset=1,
) == [
"(2, 'Bob', 'bob@example.com')",
]
@pytest.mark.asyncio
async def test_execute_query_with_no_results(mock_context) -> None:
# does not raise an error
assert await execute_query(mock_context, "SELECT * FROM users WHERE id = 9999999999") == []
async def test_execute_select_query_with_keywords(mock_context) -> None:
assert await execute_select_query(
mock_context,
select_clause="SELECT id, name, email",
from_clause="FROM users",
limit=1,
) == [
"(1, 'Alice', 'alice@example.com')",
]
@pytest.mark.asyncio
async def test_execute_query_with_problem(mock_context) -> None:
async def test_execute_select_query_with_join(mock_context) -> None:
assert await execute_select_query(
mock_context,
select_clause="u.id, u.name, u.email, m.id, m.body",
from_clause="users u",
join_clause="messages m ON u.id = m.user_id",
limit=1,
) == [
"(1, 'Alice', 'alice@example.com', 1, 'Hello everyone!')",
]
@pytest.mark.asyncio
async def test_execute_select_query_with_group_by(mock_context) -> None:
assert await execute_select_query(
mock_context,
select_clause="u.name, COUNT(m.id) AS message_count",
from_clause="messages m",
join_clause="users u ON m.user_id = u.id",
group_by_clause="u.name",
order_by_clause="message_count DESC",
limit=2,
) == [
"('Evan', 13)",
"('Alice', 3)",
]
@pytest.mark.asyncio
async def test_execute_select_query_with_no_results(mock_context) -> None:
# does not raise an error
assert (
await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
where_clause="id = 9999999999",
)
== []
)
@pytest.mark.asyncio
async def test_execute_select_query_with_problem(mock_context) -> None:
# 'foo' is not a valid id
with pytest.raises(RetryableToolError) as e:
await execute_query(mock_context, "SELECT * FROM users WHERE id = 'foo'")
assert "invalid input syntax" in str(e.value)
await execute_select_query(
mock_context,
select_clause="*",
from_clause="users",
where_clause="id = 'foo'",
)
assert "Do not use * in the select clause" in str(e.value)
@pytest.mark.asyncio
async def test_execute_query_rejects_non_select(mock_context) -> None:
async def test_execute_select_query_rejects_non_select(mock_context) -> None:
with pytest.raises(RetryableToolError) as e:
await execute_query(
await execute_select_query(
mock_context,
"INSERT INTO users (name, email, password_hash) VALUES ('Luigi', 'luigi@example.com', 'password')",
select_clause="INSERT INTO users (name, email, password_hash) VALUES ('Luigi', 'luigi@example.com', 'password')",
from_clause="users",
)
assert "Only SELECT queries are allowed" in str(e.value)