fix tests

This commit is contained in:
Dominik Kundel 2025-03-20 09:52:15 -07:00
parent c7ce154637
commit aec066649c
4 changed files with 26 additions and 18 deletions

View file

@ -42,16 +42,14 @@ class FinancialResearchManager:
is_done=True,
hide_checkmark=True,
)
self.printer.update_item(
"start", "Starting financial research...", is_done=True)
self.printer.update_item("start", "Starting financial research...", is_done=True)
search_plan = await self._plan_searches(query)
search_results = await self._perform_searches(search_plan)
report = await self._write_report(query, search_results)
verification = await self._verify_report(report)
final_report = f"Report summary\n\n{report.short_summary}"
self.printer.update_item(
"final_report", final_report, is_done=True)
self.printer.update_item("final_report", final_report, is_done=True)
self.printer.end()
@ -76,8 +74,7 @@ class FinancialResearchManager:
async def _perform_searches(self, search_plan: FinancialSearchPlan) -> Sequence[str]:
with custom_span("Search the web"):
self.printer.update_item("searching", "Searching...")
tasks = [asyncio.create_task(self._search(item))
for item in search_plan.searches]
tasks = [asyncio.create_task(self._search(item)) for item in search_plan.searches]
results: list[str] = []
num_completed = 0
for task in asyncio.as_completed(tasks):
@ -112,8 +109,7 @@ class FinancialResearchManager:
tool_description="Use to get a short writeup of potential red flags",
custom_output_extractor=_summary_extractor,
)
writer_with_tools = writer_agent.clone(
tools=[fundamentals_tool, risk_tool])
writer_with_tools = writer_agent.clone(tools=[fundamentals_tool, risk_tool])
self.printer.update_item("writing", "Thinking about report...")
input_data = f"Original query: {query}\nSummarized search results: {search_results}"
result = Runner.run_streamed(writer_with_tools, input_data)
@ -126,8 +122,7 @@ class FinancialResearchManager:
next_message = 0
async for _ in result.stream_events():
if time.time() - last_update > 5 and next_message < len(update_messages):
self.printer.update_item(
"writing", update_messages[next_message])
self.printer.update_item("writing", update_messages[next_message])
next_message += 1
last_update = time.time()
self.printer.mark_item_done("writing")

View file

@ -10,6 +10,7 @@ class Printer:
Simple wrapper to stream status updates. Used by the financial bot
manager as it orchestrates planning, search and writing.
"""
def __init__(self, console: Console) -> None:
self.live = Live(console=console)
self.items: dict[str, tuple[str, bool]] = {}

14
tests/voice/conftest.py Normal file
View file

@ -0,0 +1,14 @@
import os
import sys
import pytest
def pytest_collection_modifyitems(config, items):
if sys.version_info[:2] == (3, 9):
this_dir = os.path.dirname(__file__)
skip_marker = pytest.mark.skip(reason="Skipped on Python 3.9")
for item in items:
if item.fspath.dirname.startswith(this_dir):
item.add_marker(skip_marker)

14
uv.lock
View file

@ -1028,21 +1028,24 @@ wheels = [
[[package]]
name = "openai"
version = "1.67.0"
version = "1.68.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "distro" },
{ name = "httpx" },
{ name = "jiter" },
{ name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
{ name = "numpy", version = "2.2.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
{ name = "pydantic" },
{ name = "sniffio" },
{ name = "sounddevice" },
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a7/63/6fd027fa4cb7c3b6bee4c3150f44803b3a7e4335f0b6e49e83a0c51c321b/openai-1.67.0.tar.gz", hash = "sha256:3b386a866396daa4bf80e05a891c50a7746ecd7863b8a27423b62136e3b8f6bc", size = 403596 }
sdist = { url = "https://files.pythonhosted.org/packages/58/ea/58102e9bfda09edc963e6e877e39cca12706b46ebf35d5fc9da7b8af10f2/openai-1.68.0.tar.gz", hash = "sha256:c570c06c9ba10f98b891ac30a3dd7b5c89ed48094c711c7a3f35fb5ade6c0757", size = 413039 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/42/de/b42ddabe211411645105ae99ad93f4f3984f53be7ced2ad441378c27f62e/openai-1.67.0-py3-none-any.whl", hash = "sha256:dbbb144f38739fc0e1d951bc67864647fca0b9ffa05aef6b70eeea9f71d79663", size = 580168 },
{ url = "https://files.pythonhosted.org/packages/a5/b6/bd67b7031572cba7d8451d82ac4a990b3a96bbd3b037634726b48ac972c8/openai-1.68.0-py3-none-any.whl", hash = "sha256:20e279b0f3a78cb4a95f3eab2a180f3ee30c6a196aeebd6bf642a4f88ab85ee1", size = 605645 },
]
[[package]]
@ -1051,14 +1054,11 @@ version = "0.0.5"
source = { editable = "." }
dependencies = [
{ name = "griffe" },
{ name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
{ name = "numpy", version = "2.2.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
{ name = "openai" },
{ name = "pydantic" },
{ name = "requests" },
{ name = "types-requests" },
{ name = "typing-extensions" },
{ name = "websockets" },
]
[package.optional-dependencies]
@ -1090,14 +1090,12 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "griffe", specifier = ">=1.5.6,<2" },
{ name = "numpy", specifier = ">=2.0.2" },
{ name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" },
{ name = "openai", specifier = ">=1.66.5" },
{ name = "pydantic", specifier = ">=2.10,<3" },
{ name = "requests", specifier = ">=2.0,<3" },
{ name = "types-requests", specifier = ">=2.0,<3" },
{ name = "typing-extensions", specifier = ">=4.12.2,<5" },
{ name = "websockets", specifier = ">=15.0.1" },
{ name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<16" },
]
provides-extras = ["voice"]