diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py index 3e3b72a..c131abd 100644 --- a/open_notebook/models/__init__.py +++ b/open_notebook/models/__init__.py @@ -10,6 +10,7 @@ from open_notebook.models.embedding_models import ( from open_notebook.models.llms import ( AnthropicLanguageModel, GeminiLanguageModel, + GroqLanguageModel, LanguageModel, LiteLLMLanguageModel, OllamaLanguageModel, @@ -20,6 +21,7 @@ from open_notebook.models.llms import ( XAILanguageModel, ) from open_notebook.models.speech_to_text_models import ( + GroqSpeechToTextModel, OpenAISpeechToTextModel, SpeechToTextModel, ) @@ -46,6 +48,7 @@ MODEL_CLASS_MAP: Dict[str, ProviderMap] = { "openai": OpenAILanguageModel, "gemini": GeminiLanguageModel, "xai": XAILanguageModel, + "groq": GroqLanguageModel, }, "embedding": { "openai": OpenAIEmbeddingModel, @@ -55,6 +58,7 @@ MODEL_CLASS_MAP: Dict[str, ProviderMap] = { }, "speech_to_text": { "openai": OpenAISpeechToTextModel, + "groq": GroqSpeechToTextModel, }, "text_to_speech": { "openai": OpenAITextToSpeechModel, diff --git a/open_notebook/models/llms.py b/open_notebook/models/llms.py index 018a604..03d8c67 100644 --- a/open_notebook/models/llms.py +++ b/open_notebook/models/llms.py @@ -13,6 +13,7 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_vertexai import ChatVertexAI from langchain_google_vertexai.model_garden import ChatAnthropicVertex +from langchain_groq.chat_models import ChatGroq from langchain_ollama.chat_models import ChatOllama from langchain_openai.chat_models import ChatOpenAI from pydantic import SecretStr @@ -191,6 +192,30 @@ class OpenRouterLanguageModel(LanguageModel): ) +@dataclass +class GroqLanguageModel(LanguageModel): + """ + Language model that uses the Groq chat model. + """ + + model_name: str + + def to_langchain(self) -> ChatGroq: + """ + Convert the language model to a LangChain chat model for Groq. + """ + kwargs = self.kwargs + kwargs["top_p"] = self.top_p + + return ChatGroq( + model=self.model_name, + temperature=self.temperature or 0.5, + max_tokens=self.max_tokens, + model_kwargs=kwargs, + stop_sequences=None, + ) + + @dataclass class XAILanguageModel(LanguageModel): """ diff --git a/open_notebook/models/speech_to_text_models.py b/open_notebook/models/speech_to_text_models.py index aa89d51..113339b 100644 --- a/open_notebook/models/speech_to_text_models.py +++ b/open_notebook/models/speech_to_text_models.py @@ -40,3 +40,22 @@ class OpenAISpeechToTextModel(SpeechToTextModel): model=self.model_name, file=audio ) return transcription.text + + +@dataclass +class GroqSpeechToTextModel(SpeechToTextModel): + model_name: str + + def transcribe(self, audio_file_path: str) -> str: + """ + Transcribes an audio file into text + """ + from groq import Groq + + # todo: make this Singleton + client = Groq() + with open(audio_file_path, "rb") as audio: + transcription = client.audio.transcriptions.create( + model=self.model_name, file=audio + ) + return transcription.text diff --git a/pages/7_⚙️_Settings.py b/pages/7_⚙️_Settings.py index 955c995..67108f2 100644 --- a/pages/7_⚙️_Settings.py +++ b/pages/7_⚙️_Settings.py @@ -30,6 +30,7 @@ model_types = [ provider_status["ollama"] = os.environ.get("OLLAMA_API_BASE") is not None provider_status["openai"] = os.environ.get("OPENAI_API_KEY") is not None +provider_status["groq"] = os.environ.get("GROQ_API_KEY") is not None provider_status["xai"] = os.environ.get("XAI_API_KEY") is not None provider_status["vertexai"] = ( os.environ.get("VERTEX_PROJECT") is not None diff --git a/poetry.lock b/poetry.lock index 11dffc5..0f80bfd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1166,8 +1166,8 @@ googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""} grpcio-status = {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""} proto-plus = [ - {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""}, + {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" requests = ">=2.18.0,<3.0.0.dev0" @@ -1346,8 +1346,8 @@ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extr google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev" proto-plus = [ - {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""}, + {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, ] protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" @@ -1389,8 +1389,8 @@ files = [ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" proto-plus = [ - {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""}, + {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, ] protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" @@ -1578,6 +1578,25 @@ files = [ docs = ["Sphinx", "furo"] test = ["objgraph", "psutil"] +[[package]] +name = "groq" +version = "0.12.0" +description = "The official Python library for the groq API" +optional = false +python-versions = ">=3.8" +files = [ + {file = "groq-0.12.0-py3-none-any.whl", hash = "sha256:e8aa1529f82a01b2d15394b7ea242af9ee9387f65bdd1b91ce9a10f5a911dac1"}, + {file = "groq-0.12.0.tar.gz", hash = "sha256:569229e2dadfc428b0df3d2987407691a4e3bc035b5849a65ef4909514a4605e"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +typing-extensions = ">=4.7,<5" + [[package]] name = "grpc-google-iam-v1" version = "0.13.1" @@ -2247,8 +2266,8 @@ langchain-core = ">=0.3.15,<0.4.0" langchain-text-splitters = ">=0.3.0,<0.4.0" langsmith = ">=0.1.17,<0.2.0" numpy = [ - {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, {version = ">=1,<2", markers = "python_version < \"3.12\""}, + {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, ] pydantic = ">=2.7.4,<3.0.0" PyYAML = ">=5.3" @@ -2292,8 +2311,8 @@ langchain = ">=0.3.6,<0.4.0" langchain-core = ">=0.3.14,<0.4.0" langsmith = ">=0.1.125,<0.2.0" numpy = [ - {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, {version = ">=1,<2", markers = "python_version < \"3.12\""}, + {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, ] pydantic-settings = ">=2.4.0,<3.0.0" PyYAML = ">=5.3" @@ -2317,8 +2336,8 @@ jsonpatch = ">=1.33,<2.0" langsmith = ">=0.1.125,<0.2.0" packaging = ">=23.2,<25" pydantic = [ - {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, {version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""}, + {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, ] PyYAML = ">=5.3" tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0" @@ -2366,6 +2385,21 @@ pydantic = ">=2,<3" anthropic = ["anthropic[vertexai] (>=0.35.0,<1)"] mistral = ["langchain-mistralai (>=0.2.0,<1)"] +[[package]] +name = "langchain-groq" +version = "0.2.1" +description = "An integration package connecting Groq and LangChain" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "langchain_groq-0.2.1-py3-none-any.whl", hash = "sha256:98d282fd9d7d99b0f55de0a1daea2d5d350ef697e3cb5e97de06aeba4eca8679"}, + {file = "langchain_groq-0.2.1.tar.gz", hash = "sha256:a59c81d1a15dc97abf4fdb4c2589f98109313eda147e6b378829222d4d929792"}, +] + +[package.dependencies] +groq = ">=0.4.1,<1" +langchain-core = ">=0.3.15,<0.4.0" + [[package]] name = "langchain-ollama" version = "0.2.0" @@ -2502,8 +2536,8 @@ files = [ httpx = ">=0.23.0,<1" orjson = ">=3.9.14,<4.0.0" pydantic = [ - {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, + {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, ] requests = ">=2,<3" requests-toolbelt = ">=1.0.0,<2.0.0" @@ -3562,8 +3596,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -4186,8 +4220,8 @@ files = [ annotated-types = ">=0.6.0" pydantic-core = "2.23.4" typing-extensions = [ - {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, {version = ">=4.6.1", markers = "python_version < \"3.13\""}, + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, ] [package.extras] @@ -6431,4 +6465,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "d7a79be658b4a848e346be1958ce4ef50da966a00319ddb5b9edc24be96c5aba" +content-hash = "93b2d5c2ae9dd34b47c12f14b07b76d7d48c57c5eec78b09ae08a1d3a3e747dd" diff --git a/pyproject.toml b/pyproject.toml index a4a623f..9ae242f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,8 @@ python-docx = "^1.1.2" python-pptx = "^1.0.2" openpyxl = "^3.1.5" google-generativeai = "^0.8.3" +langchain-groq = "^0.2.1" +groq = "^0.12.0" [tool.poetry.group.dev.dependencies] ipykernel = "^6.29.5"