From 1994f9d4c4bb807fd5460919b4352b715ef1324a Mon Sep 17 00:00:00 2001 From: Ashok Saravanan <90977640+AshokSaravanan222@users.noreply.github.com> Date: Wed, 14 May 2025 11:34:27 -0500 Subject: [PATCH] feat: pass extra_body through to LiteLLM acompletion (#638) **Purpose** Allow arbitrary `extra_body` parameters (e.g. `cached_content`) to be forwarded into the LiteLLM call. Useful for context caching in Gemini models ([docs](https://ai.google.dev/gemini-api/docs/caching?lang=python)). **Example usage** ```python import os from agents import Agent, ModelSettings from agents.extensions.models.litellm_model import LitellmModel cache_name = "cachedContents/34jopukfx5di" # previously stored context gemini_model = LitellmModel( model="gemini/gemini-1.5-flash-002", api_key=os.getenv("GOOGLE_API_KEY") ) agent = Agent( name="Cached Gemini Agent", model=gemini_model, model_settings=ModelSettings( extra_body={"cached_content": cache_name} ) ) --- src/agents/extensions/models/litellm_model.py | 2 + tests/models/test_litellm_extra_body.py | 45 +++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 tests/models/test_litellm_extra_body.py diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index dc672ac..d3b25a1 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -269,6 +269,8 @@ class LitellmModel(Model): extra_kwargs["extra_query"] = model_settings.extra_query if model_settings.metadata: extra_kwargs["metadata"] = model_settings.metadata + if model_settings.extra_body and isinstance(model_settings.extra_body, dict): + extra_kwargs.update(model_settings.extra_body) ret = await litellm.acompletion( model=self.model, diff --git a/tests/models/test_litellm_extra_body.py b/tests/models/test_litellm_extra_body.py new file mode 100644 index 0000000..ac56c25 --- /dev/null +++ b/tests/models/test_litellm_extra_body.py @@ -0,0 +1,45 @@ +import litellm +import pytest +from litellm.types.utils import Choices, Message, ModelResponse, Usage + +from agents.extensions.models.litellm_model import LitellmModel +from agents.model_settings import ModelSettings +from agents.models.interface import ModelTracing + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_extra_body_is_forwarded(monkeypatch): + """ + Forward `extra_body` entries into litellm.acompletion kwargs. + + This ensures that user-provided parameters (e.g. cached_content) + arrive alongside default arguments. + """ + captured: dict[str, object] = {} + + async def fake_acompletion(model, messages=None, **kwargs): + captured.update(kwargs) + msg = Message(role="assistant", content="ok") + choice = Choices(index=0, message=msg) + return ModelResponse(choices=[choice], usage=Usage(0, 0, 0)) + + monkeypatch.setattr(litellm, "acompletion", fake_acompletion) + settings = ModelSettings( + temperature=0.1, + extra_body={"cached_content": "some_cache", "foo": 123} + ) + model = LitellmModel(model="test-model") + + await model.get_response( + system_instructions=None, + input=[], + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + + assert {"cached_content": "some_cache", "foo": 123}.items() <= captured.items()