typing fixes
This commit is contained in:
parent
c15982af3f
commit
3b36caceb9
3 changed files with 20 additions and 8 deletions
|
|
@ -265,7 +265,9 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr
|
|||
raise DatabaseOperationError("Failed to perform text search")
|
||||
|
||||
|
||||
def vector_search(keyword: str, results: int, source: bool = True, note: bool = True):
|
||||
def vector_search(
|
||||
keyword: List[float], results: int, source: bool = True, note: bool = True
|
||||
):
|
||||
if not keyword:
|
||||
raise InvalidInputError("Search keyword cannot be empty")
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from open_notebook.domain.models import DefaultModels, Model
|
||||
from open_notebook.models.embedding_models import (
|
||||
EmbeddingModel,
|
||||
GeminiEmbeddingModel,
|
||||
OllamaEmbeddingModel,
|
||||
OpenAIEmbeddingModel,
|
||||
|
|
@ -10,6 +11,7 @@ from open_notebook.models.embedding_models import (
|
|||
from open_notebook.models.llms import (
|
||||
AnthropicLanguageModel,
|
||||
GeminiLanguageModel,
|
||||
LanguageModel,
|
||||
LiteLLMLanguageModel,
|
||||
OllamaLanguageModel,
|
||||
OpenAILanguageModel,
|
||||
|
|
@ -17,10 +19,14 @@ from open_notebook.models.llms import (
|
|||
VertexAILanguageModel,
|
||||
VertexAnthropicLanguageModel,
|
||||
)
|
||||
from open_notebook.models.speech_to_text_models import OpenAISpeechToTextModel
|
||||
from open_notebook.models.speech_to_text_models import (
|
||||
OpenAISpeechToTextModel,
|
||||
SpeechToTextModel,
|
||||
)
|
||||
from open_notebook.models.text_to_speech_models import (
|
||||
ElevenLabsTextToSpeechModel,
|
||||
OpenAITextToSpeechModel,
|
||||
TextToSpeechModel,
|
||||
)
|
||||
|
||||
# Unified model class map with type information
|
||||
|
|
@ -77,7 +83,9 @@ class ModelManager:
|
|||
self.refresh_defaults()
|
||||
return self._default_models
|
||||
|
||||
def get_model(self, model_id: str, **kwargs) -> object:
|
||||
def get_model(
|
||||
self, model_id: str, **kwargs
|
||||
) -> Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel]:
|
||||
"""
|
||||
Get a model instance based on model_id. Uses caching to avoid recreating instances.
|
||||
|
||||
|
|
@ -110,12 +118,14 @@ class ModelManager:
|
|||
|
||||
# Special handling for language models that need langchain conversion
|
||||
if model.type == "language":
|
||||
model_instance = model_instance.to_langchain()
|
||||
model_instance = model_instance
|
||||
|
||||
self._model_cache[cache_key] = model_instance
|
||||
return model_instance
|
||||
|
||||
def get_default_model(self, model_type: str, **kwargs) -> object:
|
||||
def get_default_model(
|
||||
self, model_type: str, **kwargs
|
||||
) -> Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel]:
|
||||
"""
|
||||
Get the default model for a specific type.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import streamlit as st
|
||||
|
||||
from open_notebook.domain.notebook import text_search, vector_search
|
||||
from open_notebook.models import model_manager
|
||||
from open_notebook.models import EmbeddingModel, model_manager
|
||||
from stream_app.note import note_list_item
|
||||
from stream_app.source import source_list_item
|
||||
from stream_app.utils import page_commons
|
||||
|
|
@ -11,7 +11,7 @@ st.set_page_config(
|
|||
)
|
||||
page_commons()
|
||||
|
||||
EMBEDDING_MODEL = model_manager.get_default_model("embedding")
|
||||
EMBEDDING_MODEL: EmbeddingModel = model_manager.get_default_model("embedding")
|
||||
|
||||
# search_tab, ask_tab = st.tabs(["Search", "Ask"])
|
||||
# notebooks = Notebook.get_all()
|
||||
|
|
|
|||
Loading…
Reference in a new issue