feat: pass extra_body through to LiteLLM acompletion (#638)
**Purpose** Allow arbitrary `extra_body` parameters (e.g. `cached_content`) to be forwarded into the LiteLLM call. Useful for context caching in Gemini models ([docs](https://ai.google.dev/gemini-api/docs/caching?lang=python)). **Example usage** ```python import os from agents import Agent, ModelSettings from agents.extensions.models.litellm_model import LitellmModel cache_name = "cachedContents/34jopukfx5di" # previously stored context gemini_model = LitellmModel( model="gemini/gemini-1.5-flash-002", api_key=os.getenv("GOOGLE_API_KEY") ) agent = Agent( name="Cached Gemini Agent", model=gemini_model, model_settings=ModelSettings( extra_body={"cached_content": cache_name} ) )
This commit is contained in:
parent
2c46dae377
commit
1994f9d4c4
2 changed files with 47 additions and 0 deletions
|
|
@ -269,6 +269,8 @@ class LitellmModel(Model):
|
||||||
extra_kwargs["extra_query"] = model_settings.extra_query
|
extra_kwargs["extra_query"] = model_settings.extra_query
|
||||||
if model_settings.metadata:
|
if model_settings.metadata:
|
||||||
extra_kwargs["metadata"] = model_settings.metadata
|
extra_kwargs["metadata"] = model_settings.metadata
|
||||||
|
if model_settings.extra_body and isinstance(model_settings.extra_body, dict):
|
||||||
|
extra_kwargs.update(model_settings.extra_body)
|
||||||
|
|
||||||
ret = await litellm.acompletion(
|
ret = await litellm.acompletion(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
|
|
||||||
45
tests/models/test_litellm_extra_body.py
Normal file
45
tests/models/test_litellm_extra_body.py
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
import litellm
|
||||||
|
import pytest
|
||||||
|
from litellm.types.utils import Choices, Message, ModelResponse, Usage
|
||||||
|
|
||||||
|
from agents.extensions.models.litellm_model import LitellmModel
|
||||||
|
from agents.model_settings import ModelSettings
|
||||||
|
from agents.models.interface import ModelTracing
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.allow_call_model_methods
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extra_body_is_forwarded(monkeypatch):
|
||||||
|
"""
|
||||||
|
Forward `extra_body` entries into litellm.acompletion kwargs.
|
||||||
|
|
||||||
|
This ensures that user-provided parameters (e.g. cached_content)
|
||||||
|
arrive alongside default arguments.
|
||||||
|
"""
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
async def fake_acompletion(model, messages=None, **kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
msg = Message(role="assistant", content="ok")
|
||||||
|
choice = Choices(index=0, message=msg)
|
||||||
|
return ModelResponse(choices=[choice], usage=Usage(0, 0, 0))
|
||||||
|
|
||||||
|
monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
|
||||||
|
settings = ModelSettings(
|
||||||
|
temperature=0.1,
|
||||||
|
extra_body={"cached_content": "some_cache", "foo": 123}
|
||||||
|
)
|
||||||
|
model = LitellmModel(model="test-model")
|
||||||
|
|
||||||
|
await model.get_response(
|
||||||
|
system_instructions=None,
|
||||||
|
input=[],
|
||||||
|
model_settings=settings,
|
||||||
|
tools=[],
|
||||||
|
output_schema=None,
|
||||||
|
handoffs=[],
|
||||||
|
tracing=ModelTracing.DISABLED,
|
||||||
|
previous_response_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert {"cached_content": "some_cache", "foo": 123}.items() <= captured.items()
|
||||||
Loading…
Reference in a new issue