typing fixes

This commit is contained in:
LUIS NOVO 2024-11-01 21:40:28 -03:00
parent c15982af3f
commit 3b36caceb9
3 changed files with 20 additions and 8 deletions

View file

@ -265,7 +265,9 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr
raise DatabaseOperationError("Failed to perform text search")
def vector_search(keyword: str, results: int, source: bool = True, note: bool = True):
def vector_search(
keyword: List[float], results: int, source: bool = True, note: bool = True
):
if not keyword:
raise InvalidInputError("Search keyword cannot be empty")
try:

View file

@ -1,7 +1,8 @@
from typing import Dict, Optional
from typing import Dict, Optional, Union
from open_notebook.domain.models import DefaultModels, Model
from open_notebook.models.embedding_models import (
EmbeddingModel,
GeminiEmbeddingModel,
OllamaEmbeddingModel,
OpenAIEmbeddingModel,
@ -10,6 +11,7 @@ from open_notebook.models.embedding_models import (
from open_notebook.models.llms import (
AnthropicLanguageModel,
GeminiLanguageModel,
LanguageModel,
LiteLLMLanguageModel,
OllamaLanguageModel,
OpenAILanguageModel,
@ -17,10 +19,14 @@ from open_notebook.models.llms import (
VertexAILanguageModel,
VertexAnthropicLanguageModel,
)
from open_notebook.models.speech_to_text_models import OpenAISpeechToTextModel
from open_notebook.models.speech_to_text_models import (
OpenAISpeechToTextModel,
SpeechToTextModel,
)
from open_notebook.models.text_to_speech_models import (
ElevenLabsTextToSpeechModel,
OpenAITextToSpeechModel,
TextToSpeechModel,
)
# Unified model class map with type information
@ -77,7 +83,9 @@ class ModelManager:
self.refresh_defaults()
return self._default_models
def get_model(self, model_id: str, **kwargs) -> object:
def get_model(
self, model_id: str, **kwargs
) -> Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel]:
"""
Get a model instance based on model_id. Uses caching to avoid recreating instances.
@ -110,12 +118,14 @@ class ModelManager:
# Special handling for language models that need langchain conversion
if model.type == "language":
model_instance = model_instance.to_langchain()
model_instance = model_instance
self._model_cache[cache_key] = model_instance
return model_instance
def get_default_model(self, model_type: str, **kwargs) -> object:
def get_default_model(
self, model_type: str, **kwargs
) -> Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel]:
"""
Get the default model for a specific type.

View file

@ -1,7 +1,7 @@
import streamlit as st
from open_notebook.domain.notebook import text_search, vector_search
from open_notebook.models import model_manager
from open_notebook.models import EmbeddingModel, model_manager
from stream_app.note import note_list_item
from stream_app.source import source_list_item
from stream_app.utils import page_commons
@ -11,7 +11,7 @@ st.set_page_config(
)
page_commons()
EMBEDDING_MODEL = model_manager.get_default_model("embedding")
EMBEDDING_MODEL: EmbeddingModel = model_manager.get_default_model("embedding")
# search_tab, ask_tab = st.tabs(["Search", "Ask"])
# notebooks = Notebook.get_all()