add support for GROQ models
This commit is contained in:
parent
9ba5709a3c
commit
321234e485
6 changed files with 95 additions and 10 deletions
|
|
@ -10,6 +10,7 @@ from open_notebook.models.embedding_models import (
|
|||
from open_notebook.models.llms import (
|
||||
AnthropicLanguageModel,
|
||||
GeminiLanguageModel,
|
||||
GroqLanguageModel,
|
||||
LanguageModel,
|
||||
LiteLLMLanguageModel,
|
||||
OllamaLanguageModel,
|
||||
|
|
@ -20,6 +21,7 @@ from open_notebook.models.llms import (
|
|||
XAILanguageModel,
|
||||
)
|
||||
from open_notebook.models.speech_to_text_models import (
|
||||
GroqSpeechToTextModel,
|
||||
OpenAISpeechToTextModel,
|
||||
SpeechToTextModel,
|
||||
)
|
||||
|
|
@ -46,6 +48,7 @@ MODEL_CLASS_MAP: Dict[str, ProviderMap] = {
|
|||
"openai": OpenAILanguageModel,
|
||||
"gemini": GeminiLanguageModel,
|
||||
"xai": XAILanguageModel,
|
||||
"groq": GroqLanguageModel,
|
||||
},
|
||||
"embedding": {
|
||||
"openai": OpenAIEmbeddingModel,
|
||||
|
|
@ -55,6 +58,7 @@ MODEL_CLASS_MAP: Dict[str, ProviderMap] = {
|
|||
},
|
||||
"speech_to_text": {
|
||||
"openai": OpenAISpeechToTextModel,
|
||||
"groq": GroqSpeechToTextModel,
|
||||
},
|
||||
"text_to_speech": {
|
||||
"openai": OpenAITextToSpeechModel,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
|
|||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_google_vertexai import ChatVertexAI
|
||||
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
|
||||
from langchain_groq.chat_models import ChatGroq
|
||||
from langchain_ollama.chat_models import ChatOllama
|
||||
from langchain_openai.chat_models import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
|
@ -191,6 +192,30 @@ class OpenRouterLanguageModel(LanguageModel):
|
|||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroqLanguageModel(LanguageModel):
|
||||
"""
|
||||
Language model that uses the Groq chat model.
|
||||
"""
|
||||
|
||||
model_name: str
|
||||
|
||||
def to_langchain(self) -> ChatGroq:
|
||||
"""
|
||||
Convert the language model to a LangChain chat model for Groq.
|
||||
"""
|
||||
kwargs = self.kwargs
|
||||
kwargs["top_p"] = self.top_p
|
||||
|
||||
return ChatGroq(
|
||||
model=self.model_name,
|
||||
temperature=self.temperature or 0.5,
|
||||
max_tokens=self.max_tokens,
|
||||
model_kwargs=kwargs,
|
||||
stop_sequences=None,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class XAILanguageModel(LanguageModel):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -40,3 +40,22 @@ class OpenAISpeechToTextModel(SpeechToTextModel):
|
|||
model=self.model_name, file=audio
|
||||
)
|
||||
return transcription.text
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroqSpeechToTextModel(SpeechToTextModel):
|
||||
model_name: str
|
||||
|
||||
def transcribe(self, audio_file_path: str) -> str:
|
||||
"""
|
||||
Transcribes an audio file into text
|
||||
"""
|
||||
from groq import Groq
|
||||
|
||||
# todo: make this Singleton
|
||||
client = Groq()
|
||||
with open(audio_file_path, "rb") as audio:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
model=self.model_name, file=audio
|
||||
)
|
||||
return transcription.text
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ model_types = [
|
|||
|
||||
provider_status["ollama"] = os.environ.get("OLLAMA_API_BASE") is not None
|
||||
provider_status["openai"] = os.environ.get("OPENAI_API_KEY") is not None
|
||||
provider_status["groq"] = os.environ.get("GROQ_API_KEY") is not None
|
||||
provider_status["xai"] = os.environ.get("XAI_API_KEY") is not None
|
||||
provider_status["vertexai"] = (
|
||||
os.environ.get("VERTEX_PROJECT") is not None
|
||||
|
|
|
|||
54
poetry.lock
generated
54
poetry.lock
generated
|
|
@ -1166,8 +1166,8 @@ googleapis-common-protos = ">=1.56.2,<2.0.dev0"
|
|||
grpcio = {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
|
||||
grpcio-status = {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
|
||||
proto-plus = [
|
||||
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
|
||||
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
|
||||
]
|
||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
|
||||
requests = ">=2.18.0,<3.0.0.dev0"
|
||||
|
|
@ -1346,8 +1346,8 @@ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extr
|
|||
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
|
||||
grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev"
|
||||
proto-plus = [
|
||||
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
|
||||
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
|
||||
]
|
||||
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
|
||||
|
||||
|
|
@ -1389,8 +1389,8 @@ files = [
|
|||
google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
|
||||
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
|
||||
proto-plus = [
|
||||
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
|
||||
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
|
||||
]
|
||||
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
|
||||
|
||||
|
|
@ -1578,6 +1578,25 @@ files = [
|
|||
docs = ["Sphinx", "furo"]
|
||||
test = ["objgraph", "psutil"]
|
||||
|
||||
[[package]]
|
||||
name = "groq"
|
||||
version = "0.12.0"
|
||||
description = "The official Python library for the groq API"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "groq-0.12.0-py3-none-any.whl", hash = "sha256:e8aa1529f82a01b2d15394b7ea242af9ee9387f65bdd1b91ce9a10f5a911dac1"},
|
||||
{file = "groq-0.12.0.tar.gz", hash = "sha256:569229e2dadfc428b0df3d2987407691a4e3bc035b5849a65ef4909514a4605e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=3.5.0,<5"
|
||||
distro = ">=1.7.0,<2"
|
||||
httpx = ">=0.23.0,<1"
|
||||
pydantic = ">=1.9.0,<3"
|
||||
sniffio = "*"
|
||||
typing-extensions = ">=4.7,<5"
|
||||
|
||||
[[package]]
|
||||
name = "grpc-google-iam-v1"
|
||||
version = "0.13.1"
|
||||
|
|
@ -2247,8 +2266,8 @@ langchain-core = ">=0.3.15,<0.4.0"
|
|||
langchain-text-splitters = ">=0.3.0,<0.4.0"
|
||||
langsmith = ">=0.1.17,<0.2.0"
|
||||
numpy = [
|
||||
{version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1,<2", markers = "python_version < \"3.12\""},
|
||||
{version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
pydantic = ">=2.7.4,<3.0.0"
|
||||
PyYAML = ">=5.3"
|
||||
|
|
@ -2292,8 +2311,8 @@ langchain = ">=0.3.6,<0.4.0"
|
|||
langchain-core = ">=0.3.14,<0.4.0"
|
||||
langsmith = ">=0.1.125,<0.2.0"
|
||||
numpy = [
|
||||
{version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1,<2", markers = "python_version < \"3.12\""},
|
||||
{version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
pydantic-settings = ">=2.4.0,<3.0.0"
|
||||
PyYAML = ">=5.3"
|
||||
|
|
@ -2317,8 +2336,8 @@ jsonpatch = ">=1.33,<2.0"
|
|||
langsmith = ">=0.1.125,<0.2.0"
|
||||
packaging = ">=23.2,<25"
|
||||
pydantic = [
|
||||
{version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""},
|
||||
{version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""},
|
||||
{version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""},
|
||||
]
|
||||
PyYAML = ">=5.3"
|
||||
tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0"
|
||||
|
|
@ -2366,6 +2385,21 @@ pydantic = ">=2,<3"
|
|||
anthropic = ["anthropic[vertexai] (>=0.35.0,<1)"]
|
||||
mistral = ["langchain-mistralai (>=0.2.0,<1)"]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-groq"
|
||||
version = "0.2.1"
|
||||
description = "An integration package connecting Groq and LangChain"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "langchain_groq-0.2.1-py3-none-any.whl", hash = "sha256:98d282fd9d7d99b0f55de0a1daea2d5d350ef697e3cb5e97de06aeba4eca8679"},
|
||||
{file = "langchain_groq-0.2.1.tar.gz", hash = "sha256:a59c81d1a15dc97abf4fdb4c2589f98109313eda147e6b378829222d4d929792"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
groq = ">=0.4.1,<1"
|
||||
langchain-core = ">=0.3.15,<0.4.0"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-ollama"
|
||||
version = "0.2.0"
|
||||
|
|
@ -2502,8 +2536,8 @@ files = [
|
|||
httpx = ">=0.23.0,<1"
|
||||
orjson = ">=3.9.14,<4.0.0"
|
||||
pydantic = [
|
||||
{version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""},
|
||||
{version = ">=1,<3", markers = "python_full_version < \"3.12.4\""},
|
||||
{version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""},
|
||||
]
|
||||
requests = ">=2,<3"
|
||||
requests-toolbelt = ">=1.0.0,<2.0.0"
|
||||
|
|
@ -3562,8 +3596,8 @@ files = [
|
|||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
|
|
@ -4186,8 +4220,8 @@ files = [
|
|||
annotated-types = ">=0.6.0"
|
||||
pydantic-core = "2.23.4"
|
||||
typing-extensions = [
|
||||
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=4.6.1", markers = "python_version < \"3.13\""},
|
||||
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
|
|
@ -6431,4 +6465,4 @@ type = ["pytest-mypy"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "d7a79be658b4a848e346be1958ce4ef50da966a00319ddb5b9edc24be96c5aba"
|
||||
content-hash = "93b2d5c2ae9dd34b47c12f14b07b76d7d48c57c5eec78b09ae08a1d3a3e747dd"
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ python-docx = "^1.1.2"
|
|||
python-pptx = "^1.0.2"
|
||||
openpyxl = "^3.1.5"
|
||||
google-generativeai = "^0.8.3"
|
||||
langchain-groq = "^0.2.1"
|
||||
groq = "^0.12.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ipykernel = "^6.29.5"
|
||||
|
|
|
|||
Loading…
Reference in a new issue