support ollama, vertex and gemini embedding models
This commit is contained in:
parent
2f07f0ab49
commit
f64897fbf8
6 changed files with 122 additions and 5 deletions
|
|
@ -1,5 +1,10 @@
|
|||
from open_notebook.domain.models import Model
|
||||
from open_notebook.models.embedding_models import OpenAIEmbeddingModel
|
||||
from open_notebook.models.embedding_models import (
|
||||
GeminiEmbeddingModel,
|
||||
OllamaEmbeddingModel,
|
||||
OpenAIEmbeddingModel,
|
||||
VertexEmbeddingModel,
|
||||
)
|
||||
from open_notebook.models.llms import (
|
||||
AnthropicLanguageModel,
|
||||
GeminiLanguageModel,
|
||||
|
|
@ -26,6 +31,9 @@ MODEL_CLASS_MAP = {
|
|||
},
|
||||
"embedding": {
|
||||
"openai": OpenAIEmbeddingModel,
|
||||
"gemini": GeminiEmbeddingModel,
|
||||
"vertexai": VertexEmbeddingModel,
|
||||
"ollama": OllamaEmbeddingModel,
|
||||
},
|
||||
"speech_to_text": {
|
||||
"openai": OpenAISpeechToTextModel,
|
||||
|
|
|
|||
|
|
@ -2,11 +2,16 @@
|
|||
Classes for supporting different embedding models
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
import requests
|
||||
|
||||
# todo: add support for multiple embeddings (array)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -25,11 +30,68 @@ class EmbeddingModel(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class OllamaEmbeddingModel(EmbeddingModel):
|
||||
model_name: str
|
||||
base_url: str = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434")
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""
|
||||
Embeds the content using Open AI embedding
|
||||
"""
|
||||
# todo: make this Singleton
|
||||
text = text.replace("\n", " ")
|
||||
response = requests.post(
|
||||
f"{self.base_url}/api/embed",
|
||||
json={"model": self.model_name, "input": [text]},
|
||||
)
|
||||
return response.json()["embeddings"][0]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeminiEmbeddingModel(EmbeddingModel):
|
||||
model_name: str
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
import google.generativeai as genai
|
||||
|
||||
"""
|
||||
Embeds the content using Open AI embedding
|
||||
"""
|
||||
model_name = (
|
||||
self.model_name
|
||||
if self.model_name.startswith("models/")
|
||||
else f"models/{self.model_name}"
|
||||
)
|
||||
result = genai.embed_content(model=model_name, content=text)
|
||||
|
||||
return result["embedding"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class VertexEmbeddingModel(EmbeddingModel):
|
||||
model_name: str
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
|
||||
|
||||
texts = [text]
|
||||
# The dimensionality of the output embeddings.
|
||||
# dimensionality = 256
|
||||
# The task type for embedding. Check the available tasks in the model's documentation.
|
||||
model = TextEmbeddingModel.from_pretrained(self.model_name)
|
||||
inputs = [TextEmbeddingInput(text) for text in texts]
|
||||
embeddings = model.get_embeddings(inputs)
|
||||
return embeddings[0].values
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIEmbeddingModel(EmbeddingModel):
|
||||
model_name: str
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
from openai import OpenAI
|
||||
|
||||
"""
|
||||
Embeds the content using Open AI embedding
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ from langchain_ollama.chat_models import ChatOllama
|
|||
from langchain_openai.chat_models import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
# future: is there a value on returning langchain specific models?
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageModel(ABC):
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@ from abc import ABC, abstractmethod
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechToTextModel(ABC):
|
||||
|
|
@ -33,6 +31,8 @@ class OpenAISpeechToTextModel(SpeechToTextModel):
|
|||
"""
|
||||
Transcribes an audio file into text
|
||||
"""
|
||||
from openai import OpenAI
|
||||
|
||||
# todo: make this Singleton
|
||||
client = OpenAI()
|
||||
with open(audio_file_path, "rb") as audio:
|
||||
|
|
|
|||
45
poetry.lock
generated
45
poetry.lock
generated
|
|
@ -1932,6 +1932,27 @@ qtconsole = ["qtconsole"]
|
|||
test = ["packaging", "pickleshare", "pytest", "pytest-asyncio (<0.22)", "testpath"]
|
||||
test-extra = ["curio", "ipython[test]", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "trio"]
|
||||
|
||||
[[package]]
|
||||
name = "ipywidgets"
|
||||
version = "8.1.5"
|
||||
description = "Jupyter interactive widgets"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "ipywidgets-8.1.5-py3-none-any.whl", hash = "sha256:3290f526f87ae6e77655555baba4f36681c555b8bdbbff430b70e52c34c86245"},
|
||||
{file = "ipywidgets-8.1.5.tar.gz", hash = "sha256:870e43b1a35656a80c18c9503bbf2d16802db1cb487eec6fab27d683381dde17"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
comm = ">=0.1.3"
|
||||
ipython = ">=6.1.0"
|
||||
jupyterlab-widgets = ">=3.0.12,<3.1.0"
|
||||
traitlets = ">=4.3.1"
|
||||
widgetsnbextension = ">=4.0.12,<4.1.0"
|
||||
|
||||
[package.extras]
|
||||
test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"]
|
||||
|
||||
[[package]]
|
||||
name = "jedi"
|
||||
version = "0.19.1"
|
||||
|
|
@ -2163,6 +2184,17 @@ files = [
|
|||
{file = "jupyterlab_pygments-0.3.0.tar.gz", hash = "sha256:721aca4d9029252b11cfa9d185e5b5af4d54772bb8072f9b7036f4170054d35d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jupyterlab-widgets"
|
||||
version = "3.0.13"
|
||||
description = "Jupyter interactive widgets for JupyterLab"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "jupyterlab_widgets-3.0.13-py3-none-any.whl", hash = "sha256:e3cda2c233ce144192f1e29914ad522b2f4c40e77214b0cc97377ca3d323db54"},
|
||||
{file = "jupyterlab_widgets-3.0.13.tar.gz", hash = "sha256:a2966d385328c1942b683a8cd96b89b8dd82c8b8f81dda902bb2bc06d46f5bed"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langchain"
|
||||
version = "0.3.4"
|
||||
|
|
@ -6167,6 +6199,17 @@ files = [
|
|||
[package.extras]
|
||||
test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
|
||||
|
||||
[[package]]
|
||||
name = "widgetsnbextension"
|
||||
version = "4.0.13"
|
||||
description = "Jupyter interactive widgets for Jupyter Notebook"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "widgetsnbextension-4.0.13-py3-none-any.whl", hash = "sha256:74b2692e8500525cc38c2b877236ba51d34541e6385eeed5aec15a70f88a6c71"},
|
||||
{file = "widgetsnbextension-4.0.13.tar.gz", hash = "sha256:ffcb67bc9febd10234a362795f643927f4e0c05d9342c727b65d2384f8feacb6"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "win32-setctime"
|
||||
version = "1.1.0"
|
||||
|
|
@ -6324,4 +6367,4 @@ type = ["pytest-mypy"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "4fa191c6df5a7a355eb0d61f9560ec70e4671ac49cd54fa3a166c1e25c325671"
|
||||
content-hash = "265ed7b26b19c54847b8e549f09ccbf8be68120b34f392fb5b8afc9ffccd62ac"
|
||||
|
|
|
|||
|
|
@ -46,12 +46,14 @@ bs4 = "^0.0.2"
|
|||
python-docx = "^1.1.2"
|
||||
python-pptx = "^1.0.2"
|
||||
openpyxl = "^3.1.5"
|
||||
google-generativeai = "^0.8.3"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ipykernel = "^6.29.5"
|
||||
ruff = "^0.5.5"
|
||||
mypy = "^1.11.1"
|
||||
types-requests = "^2.32.0.20241016"
|
||||
ipywidgets = "^8.1.5"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue