add support for GROQ models

This commit is contained in:
LUIS NOVO 2024-11-13 15:09:56 -03:00
parent 9ba5709a3c
commit 321234e485
6 changed files with 95 additions and 10 deletions

View file

@ -10,6 +10,7 @@ from open_notebook.models.embedding_models import (
from open_notebook.models.llms import (
AnthropicLanguageModel,
GeminiLanguageModel,
GroqLanguageModel,
LanguageModel,
LiteLLMLanguageModel,
OllamaLanguageModel,
@ -20,6 +21,7 @@ from open_notebook.models.llms import (
XAILanguageModel,
)
from open_notebook.models.speech_to_text_models import (
GroqSpeechToTextModel,
OpenAISpeechToTextModel,
SpeechToTextModel,
)
@ -46,6 +48,7 @@ MODEL_CLASS_MAP: Dict[str, ProviderMap] = {
"openai": OpenAILanguageModel,
"gemini": GeminiLanguageModel,
"xai": XAILanguageModel,
"groq": GroqLanguageModel,
},
"embedding": {
"openai": OpenAIEmbeddingModel,
@ -55,6 +58,7 @@ MODEL_CLASS_MAP: Dict[str, ProviderMap] = {
},
"speech_to_text": {
"openai": OpenAISpeechToTextModel,
"groq": GroqSpeechToTextModel,
},
"text_to_speech": {
"openai": OpenAITextToSpeechModel,

View file

@ -13,6 +13,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
from langchain_groq.chat_models import ChatGroq
from langchain_ollama.chat_models import ChatOllama
from langchain_openai.chat_models import ChatOpenAI
from pydantic import SecretStr
@ -191,6 +192,30 @@ class OpenRouterLanguageModel(LanguageModel):
)
@dataclass
class GroqLanguageModel(LanguageModel):
"""
Language model that uses the Groq chat model.
"""
model_name: str
def to_langchain(self) -> ChatGroq:
"""
Convert the language model to a LangChain chat model for Groq.
"""
kwargs = self.kwargs
kwargs["top_p"] = self.top_p
return ChatGroq(
model=self.model_name,
temperature=self.temperature or 0.5,
max_tokens=self.max_tokens,
model_kwargs=kwargs,
stop_sequences=None,
)
@dataclass
class XAILanguageModel(LanguageModel):
"""

View file

@ -40,3 +40,22 @@ class OpenAISpeechToTextModel(SpeechToTextModel):
model=self.model_name, file=audio
)
return transcription.text
@dataclass
class GroqSpeechToTextModel(SpeechToTextModel):
model_name: str
def transcribe(self, audio_file_path: str) -> str:
"""
Transcribes an audio file into text
"""
from groq import Groq
# todo: make this Singleton
client = Groq()
with open(audio_file_path, "rb") as audio:
transcription = client.audio.transcriptions.create(
model=self.model_name, file=audio
)
return transcription.text

View file

@ -30,6 +30,7 @@ model_types = [
provider_status["ollama"] = os.environ.get("OLLAMA_API_BASE") is not None
provider_status["openai"] = os.environ.get("OPENAI_API_KEY") is not None
provider_status["groq"] = os.environ.get("GROQ_API_KEY") is not None
provider_status["xai"] = os.environ.get("XAI_API_KEY") is not None
provider_status["vertexai"] = (
os.environ.get("VERTEX_PROJECT") is not None

54
poetry.lock generated
View file

@ -1166,8 +1166,8 @@ googleapis-common-protos = ">=1.56.2,<2.0.dev0"
grpcio = {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
grpcio-status = {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
proto-plus = [
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
{version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
requests = ">=2.18.0,<3.0.0.dev0"
@ -1346,8 +1346,8 @@ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extr
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev"
proto-plus = [
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
{version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
]
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
@ -1389,8 +1389,8 @@ files = [
google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
proto-plus = [
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
{version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
]
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
@ -1578,6 +1578,25 @@ files = [
docs = ["Sphinx", "furo"]
test = ["objgraph", "psutil"]
[[package]]
name = "groq"
version = "0.12.0"
description = "The official Python library for the groq API"
optional = false
python-versions = ">=3.8"
files = [
{file = "groq-0.12.0-py3-none-any.whl", hash = "sha256:e8aa1529f82a01b2d15394b7ea242af9ee9387f65bdd1b91ce9a10f5a911dac1"},
{file = "groq-0.12.0.tar.gz", hash = "sha256:569229e2dadfc428b0df3d2987407691a4e3bc035b5849a65ef4909514a4605e"},
]
[package.dependencies]
anyio = ">=3.5.0,<5"
distro = ">=1.7.0,<2"
httpx = ">=0.23.0,<1"
pydantic = ">=1.9.0,<3"
sniffio = "*"
typing-extensions = ">=4.7,<5"
[[package]]
name = "grpc-google-iam-v1"
version = "0.13.1"
@ -2247,8 +2266,8 @@ langchain-core = ">=0.3.15,<0.4.0"
langchain-text-splitters = ">=0.3.0,<0.4.0"
langsmith = ">=0.1.17,<0.2.0"
numpy = [
{version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""},
{version = ">=1,<2", markers = "python_version < \"3.12\""},
{version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""},
]
pydantic = ">=2.7.4,<3.0.0"
PyYAML = ">=5.3"
@ -2292,8 +2311,8 @@ langchain = ">=0.3.6,<0.4.0"
langchain-core = ">=0.3.14,<0.4.0"
langsmith = ">=0.1.125,<0.2.0"
numpy = [
{version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""},
{version = ">=1,<2", markers = "python_version < \"3.12\""},
{version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""},
]
pydantic-settings = ">=2.4.0,<3.0.0"
PyYAML = ">=5.3"
@ -2317,8 +2336,8 @@ jsonpatch = ">=1.33,<2.0"
langsmith = ">=0.1.125,<0.2.0"
packaging = ">=23.2,<25"
pydantic = [
{version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""},
{version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""},
{version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""},
]
PyYAML = ">=5.3"
tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0"
@ -2366,6 +2385,21 @@ pydantic = ">=2,<3"
anthropic = ["anthropic[vertexai] (>=0.35.0,<1)"]
mistral = ["langchain-mistralai (>=0.2.0,<1)"]
[[package]]
name = "langchain-groq"
version = "0.2.1"
description = "An integration package connecting Groq and LangChain"
optional = false
python-versions = "<4.0,>=3.9"
files = [
{file = "langchain_groq-0.2.1-py3-none-any.whl", hash = "sha256:98d282fd9d7d99b0f55de0a1daea2d5d350ef697e3cb5e97de06aeba4eca8679"},
{file = "langchain_groq-0.2.1.tar.gz", hash = "sha256:a59c81d1a15dc97abf4fdb4c2589f98109313eda147e6b378829222d4d929792"},
]
[package.dependencies]
groq = ">=0.4.1,<1"
langchain-core = ">=0.3.15,<0.4.0"
[[package]]
name = "langchain-ollama"
version = "0.2.0"
@ -2502,8 +2536,8 @@ files = [
httpx = ">=0.23.0,<1"
orjson = ">=3.9.14,<4.0.0"
pydantic = [
{version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""},
{version = ">=1,<3", markers = "python_full_version < \"3.12.4\""},
{version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""},
]
requests = ">=2,<3"
requests-toolbelt = ">=1.0.0,<2.0.0"
@ -3562,8 +3596,8 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@ -4186,8 +4220,8 @@ files = [
annotated-types = ">=0.6.0"
pydantic-core = "2.23.4"
typing-extensions = [
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
{version = ">=4.6.1", markers = "python_version < \"3.13\""},
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
]
[package.extras]
@ -6431,4 +6465,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "d7a79be658b4a848e346be1958ce4ef50da966a00319ddb5b9edc24be96c5aba"
content-hash = "93b2d5c2ae9dd34b47c12f14b07b76d7d48c57c5eec78b09ae08a1d3a3e747dd"

View file

@ -47,6 +47,8 @@ python-docx = "^1.1.2"
python-pptx = "^1.0.2"
openpyxl = "^3.1.5"
google-generativeai = "^0.8.3"
langchain-groq = "^0.2.1"
groq = "^0.12.0"
[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.5"