feat: pass extra_body through to LiteLLM acompletion (#638)

**Purpose**  
Allow arbitrary `extra_body` parameters (e.g. `cached_content`) to be
forwarded into the LiteLLM call. Useful for context caching in Gemini
models
([docs](https://ai.google.dev/gemini-api/docs/caching?lang=python)).

**Example usage**  
```python
import os
from agents import Agent, ModelSettings
from agents.extensions.models.litellm_model import LitellmModel

cache_name = "cachedContents/34jopukfx5di"  # previously stored context

gemini_model = LitellmModel(
    model="gemini/gemini-1.5-flash-002",
    api_key=os.getenv("GOOGLE_API_KEY")
)

agent = Agent(
    name="Cached Gemini Agent",
    model=gemini_model,
    model_settings=ModelSettings(
        extra_body={"cached_content": cache_name}
    )
)
```
This commit is contained in:
Ashok Saravanan 2025-05-14 11:34:27 -05:00 committed by GitHub
parent 2c46dae377
commit 1994f9d4c4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 47 additions and 0 deletions

View file

@ -269,6 +269,8 @@ class LitellmModel(Model):
extra_kwargs["extra_query"] = model_settings.extra_query
if model_settings.metadata:
extra_kwargs["metadata"] = model_settings.metadata
if model_settings.extra_body and isinstance(model_settings.extra_body, dict):
extra_kwargs.update(model_settings.extra_body)
ret = await litellm.acompletion(
model=self.model,

View file

@ -0,0 +1,45 @@
import litellm
import pytest
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from agents.extensions.models.litellm_model import LitellmModel
from agents.model_settings import ModelSettings
from agents.models.interface import ModelTracing
@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_extra_body_is_forwarded(monkeypatch):
    """Check that ``ModelSettings.extra_body`` entries reach ``litellm.acompletion``.

    ``litellm.acompletion`` is swapped for a stub that records the keyword
    arguments it is called with. After one ``get_response`` call, every
    ``extra_body`` key/value pair must be present among those recorded kwargs.
    """
    expected_extras = {"cached_content": "some_cache", "foo": 123}
    seen_kwargs: dict[str, object] = {}

    async def recording_acompletion(model, messages=None, **kwargs):
        # Remember what the model forwarded, then hand back a minimal response.
        seen_kwargs.update(kwargs)
        reply = Message(role="assistant", content="ok")
        return ModelResponse(
            choices=[Choices(index=0, message=reply)],
            usage=Usage(0, 0, 0),
        )

    monkeypatch.setattr(litellm, "acompletion", recording_acompletion)

    litellm_model = LitellmModel(model="test-model")
    await litellm_model.get_response(
        system_instructions=None,
        input=[],
        # Pass a copy so the expectation below is independent of the settings object.
        model_settings=ModelSettings(
            temperature=0.1,
            extra_body=dict(expected_extras),
        ),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )

    # Subset check: other kwargs (e.g. defaults) may be present as well.
    assert expected_extras.items() <= seen_kwargs.items()