support ollama, vertex and gemini embedding models

This commit is contained in:
LUIS NOVO 2024-10-30 15:19:16 -03:00
parent 2f07f0ab49
commit f64897fbf8
6 changed files with 122 additions and 5 deletions

View file

@ -1,5 +1,10 @@
from open_notebook.domain.models import Model
from open_notebook.models.embedding_models import OpenAIEmbeddingModel
from open_notebook.models.embedding_models import (
GeminiEmbeddingModel,
OllamaEmbeddingModel,
OpenAIEmbeddingModel,
VertexEmbeddingModel,
)
from open_notebook.models.llms import (
AnthropicLanguageModel,
GeminiLanguageModel,
@ -26,6 +31,9 @@ MODEL_CLASS_MAP = {
},
"embedding": {
"openai": OpenAIEmbeddingModel,
"gemini": GeminiEmbeddingModel,
"vertexai": VertexEmbeddingModel,
"ollama": OllamaEmbeddingModel,
},
"speech_to_text": {
"openai": OpenAISpeechToTextModel,

View file

@ -2,11 +2,16 @@
Classes for supporting different embedding models
"""
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from openai import OpenAI
import requests
# todo: add support for multiple embeddings (array)
@dataclass
@ -25,11 +30,68 @@ class EmbeddingModel(ABC):
raise NotImplementedError
@dataclass
class OllamaEmbeddingModel(EmbeddingModel):
    """Embedding model served by an Ollama instance over HTTP."""

    model_name: str
    # Evaluated once, when the class body runs; later changes to the
    # environment variable do not affect this default.
    base_url: str = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434")

    def embed(self, text: str) -> List[float]:
        """
        Embeds the content using an Ollama embedding model

        Newlines in ``text`` are replaced with spaces before the request is
        sent to ``<base_url>/api/embed``.

        Returns:
            The first entry of the ``embeddings`` list in the JSON response.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        text = text.replace("\n", " ")
        # Tolerate a base URL configured with a trailing slash.
        endpoint = self.base_url.rstrip("/") + "/api/embed"
        response = requests.post(
            endpoint,
            json={"model": self.model_name, "input": [text]},
        )
        # Fail with the HTTP error itself rather than a KeyError on the
        # missing "embeddings" key of an error payload.
        response.raise_for_status()
        return response.json()["embeddings"][0]
@dataclass
class GeminiEmbeddingModel(EmbeddingModel):
    """Embedding model backed by the ``google.generativeai`` package."""

    model_name: str

    def embed(self, text: str) -> List[float]:
        """
        Embeds the content using a Gemini embedding model

        The model name is passed with a ``models/`` prefix; the prefix is
        added when ``model_name`` does not already start with it.

        Returns:
            The ``embedding`` entry of the ``genai.embed_content`` result.
        """
        # Imported at call time rather than at module import.
        import google.generativeai as genai

        model_name = (
            self.model_name
            if self.model_name.startswith("models/")
            else f"models/{self.model_name}"
        )
        result = genai.embed_content(model=model_name, content=text)
        return result["embedding"]
@dataclass
class VertexEmbeddingModel(EmbeddingModel):
    """Embedding model backed by ``vertexai.language_models``."""

    model_name: str

    def embed(self, text: str) -> List[float]:
        """
        Embeds the content using a Vertex AI text embedding model

        Returns:
            The ``values`` of the first embedding returned for ``text``.
        """
        # Imported at call time rather than at module import.
        from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel

        embedder = TextEmbeddingModel.from_pretrained(self.model_name)
        # get_embeddings is given a list, so wrap the single text in a
        # one-element batch and read back its only result.
        batch = [TextEmbeddingInput(text)]
        results = embedder.get_embeddings(batch)
        return results[0].values
@dataclass
class OpenAIEmbeddingModel(EmbeddingModel):
model_name: str
def embed(self, text: str) -> List[float]:
from openai import OpenAI
"""
Embeds the content using Open AI embedding
"""

View file

@ -17,6 +17,8 @@ from langchain_ollama.chat_models import ChatOllama
from langchain_openai.chat_models import ChatOpenAI
from pydantic import SecretStr
# future: is there value in returning langchain-specific models?
@dataclass
class LanguageModel(ABC):

View file

@ -6,8 +6,6 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
from openai import OpenAI
@dataclass
class SpeechToTextModel(ABC):
@ -33,6 +31,8 @@ class OpenAISpeechToTextModel(SpeechToTextModel):
"""
Transcribes an audio file into text
"""
from openai import OpenAI
# todo: make this Singleton
client = OpenAI()
with open(audio_file_path, "rb") as audio:

45
poetry.lock generated
View file

@ -1932,6 +1932,27 @@ qtconsole = ["qtconsole"]
test = ["packaging", "pickleshare", "pytest", "pytest-asyncio (<0.22)", "testpath"]
test-extra = ["curio", "ipython[test]", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "trio"]
[[package]]
name = "ipywidgets"
version = "8.1.5"
description = "Jupyter interactive widgets"
optional = false
python-versions = ">=3.7"
files = [
{file = "ipywidgets-8.1.5-py3-none-any.whl", hash = "sha256:3290f526f87ae6e77655555baba4f36681c555b8bdbbff430b70e52c34c86245"},
{file = "ipywidgets-8.1.5.tar.gz", hash = "sha256:870e43b1a35656a80c18c9503bbf2d16802db1cb487eec6fab27d683381dde17"},
]
[package.dependencies]
comm = ">=0.1.3"
ipython = ">=6.1.0"
jupyterlab-widgets = ">=3.0.12,<3.1.0"
traitlets = ">=4.3.1"
widgetsnbextension = ">=4.0.12,<4.1.0"
[package.extras]
test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"]
[[package]]
name = "jedi"
version = "0.19.1"
@ -2163,6 +2184,17 @@ files = [
{file = "jupyterlab_pygments-0.3.0.tar.gz", hash = "sha256:721aca4d9029252b11cfa9d185e5b5af4d54772bb8072f9b7036f4170054d35d"},
]
[[package]]
name = "jupyterlab-widgets"
version = "3.0.13"
description = "Jupyter interactive widgets for JupyterLab"
optional = false
python-versions = ">=3.7"
files = [
{file = "jupyterlab_widgets-3.0.13-py3-none-any.whl", hash = "sha256:e3cda2c233ce144192f1e29914ad522b2f4c40e77214b0cc97377ca3d323db54"},
{file = "jupyterlab_widgets-3.0.13.tar.gz", hash = "sha256:a2966d385328c1942b683a8cd96b89b8dd82c8b8f81dda902bb2bc06d46f5bed"},
]
[[package]]
name = "langchain"
version = "0.3.4"
@ -6167,6 +6199,17 @@ files = [
[package.extras]
test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
[[package]]
name = "widgetsnbextension"
version = "4.0.13"
description = "Jupyter interactive widgets for Jupyter Notebook"
optional = false
python-versions = ">=3.7"
files = [
{file = "widgetsnbextension-4.0.13-py3-none-any.whl", hash = "sha256:74b2692e8500525cc38c2b877236ba51d34541e6385eeed5aec15a70f88a6c71"},
{file = "widgetsnbextension-4.0.13.tar.gz", hash = "sha256:ffcb67bc9febd10234a362795f643927f4e0c05d9342c727b65d2384f8feacb6"},
]
[[package]]
name = "win32-setctime"
version = "1.1.0"
@ -6324,4 +6367,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "4fa191c6df5a7a355eb0d61f9560ec70e4671ac49cd54fa3a166c1e25c325671"
content-hash = "265ed7b26b19c54847b8e549f09ccbf8be68120b34f392fb5b8afc9ffccd62ac"

View file

@ -46,12 +46,14 @@ bs4 = "^0.0.2"
python-docx = "^1.1.2"
python-pptx = "^1.0.2"
openpyxl = "^3.1.5"
google-generativeai = "^0.8.3"
[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.5"
ruff = "^0.5.5"
mypy = "^1.11.1"
types-requests = "^2.32.0.20241016"
ipywidgets = "^8.1.5"
[build-system]
requires = ["poetry-core"]