add gemini support
This commit is contained in:
parent
7648caca7b
commit
aaa7831ab1
4 changed files with 28 additions and 2 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from open_notebook.llms import (
|
||||
AnthropicLanguageModel,
|
||||
GeminiLanguageModel,
|
||||
LiteLLMLanguageModel,
|
||||
OllamaLanguageModel,
|
||||
OpenAILanguageModel,
|
||||
|
|
@ -17,6 +18,7 @@ PROVIDER_CLASS_MAP = {
|
|||
"vertexai": VertexAILanguageModel,
|
||||
"anthropic": AnthropicLanguageModel,
|
||||
"openai": OpenAILanguageModel,
|
||||
"gemini": GeminiLanguageModel,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Any, Dict, Optional
|
|||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_google_vertexai import ChatVertexAI
|
||||
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
|
||||
from langchain_ollama.chat_models import ChatOllama
|
||||
|
|
@ -62,7 +63,7 @@ class OllamaLanguageModel(LanguageModel):
|
|||
base_url=self.base_url,
|
||||
# keep_alive="10m",
|
||||
num_predict=self.max_tokens,
|
||||
temperature=self.temperature,
|
||||
temperature=self.temperature or 0.5,
|
||||
verbose=True,
|
||||
top_p=self.top_p,
|
||||
)
|
||||
|
|
@ -90,6 +91,7 @@ class VertexAnthropicLanguageModel(LanguageModel):
|
|||
streaming=False,
|
||||
kwargs=self.kwargs,
|
||||
top_p=self.top_p,
|
||||
temperature=self.temperature or 0.5,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -136,6 +138,26 @@ class VertexAILanguageModel(LanguageModel):
|
|||
location=self.location,
|
||||
project=self.project,
|
||||
safety_settings=None,
|
||||
temperature=self.temperature or 0.5,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeminiLanguageModel(LanguageModel):
|
||||
"""
|
||||
Language model that uses the Gemini Family of chat models.
|
||||
"""
|
||||
|
||||
model_name: str
|
||||
|
||||
def to_langchain(self) -> ChatGoogleGenerativeAI:
|
||||
"""
|
||||
Convert the language model to a LangChain chat model.
|
||||
"""
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=self.model_name,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=self.temperature or 0.5,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -188,6 +210,7 @@ class AnthropicLanguageModel(LanguageModel):
|
|||
streaming=False,
|
||||
timeout=30,
|
||||
top_p=self.top_p,
|
||||
temperature=self.temperature or 0.5,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
2
poetry.lock
generated
2
poetry.lock
generated
|
|
@ -6063,4 +6063,4 @@ type = ["pytest-mypy"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "b92bbd2ce61e78ccc2e182627cf0ba5d98ccf849898e5e941d5d17e74a7827ab"
|
||||
content-hash = "5f7bdea405c6c6433fa805b3321ac1550e13deee0d3a3c04e38136cd6992f5b1"
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ langchain-anthropic = "^0.2.3"
|
|||
langchain-ollama = "^0.2.0"
|
||||
langchain-google-vertexai = "^2.0.5"
|
||||
sdblpy = "^0.3.0"
|
||||
langchain-google-genai = "^2.0.1"
|
||||
podcastfy = "^0.2.8"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
|
|
|||
Loading…
Reference in a new issue