diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py index fc431d7..3e3b72a 100644 --- a/open_notebook/models/__init__.py +++ b/open_notebook/models/__init__.py @@ -17,6 +17,7 @@ from open_notebook.models.llms import ( OpenRouterLanguageModel, VertexAILanguageModel, VertexAnthropicLanguageModel, + XAILanguageModel, ) from open_notebook.models.speech_to_text_models import ( OpenAISpeechToTextModel, @@ -44,6 +45,7 @@ MODEL_CLASS_MAP: Dict[str, ProviderMap] = { "anthropic": AnthropicLanguageModel, "openai": OpenAILanguageModel, "gemini": GeminiLanguageModel, + "xai": XAILanguageModel, }, "embedding": { "openai": OpenAIEmbeddingModel, diff --git a/open_notebook/models/llms.py b/open_notebook/models/llms.py index 3c9046d..52ff94b 100644 --- a/open_notebook/models/llms.py +++ b/open_notebook/models/llms.py @@ -171,7 +171,7 @@ class OpenRouterLanguageModel(LanguageModel): def to_langchain(self) -> ChatOpenAI: """ - Convert the language model to a LangChain chat model. + Convert the language model to a LangChain chat model for Open Router. """ kwargs = self.kwargs if self.json: @@ -191,6 +191,34 @@ class OpenRouterLanguageModel(LanguageModel): ) +@dataclass +class XAILanguageModel(LanguageModel): + """ + Language model that uses the OpenAI chat model for X.AI. + """ + + model_name: str + + def to_langchain(self) -> ChatOpenAI: + """ + Convert the language model to a LangChain chat model. + """ + kwargs = self.kwargs + if self.json: + kwargs["response_format"] = {"type": "json_object"} + + return ChatOpenAI( + model=self.model_name, + temperature=self.temperature or 0.5, + base_url=os.environ.get("XAI_BASE_URL", "https://api.x.ai/v1"), + max_tokens=self.max_tokens, + model_kwargs=kwargs, + streaming=self.streaming, + api_key=SecretStr(os.environ.get("XAI_API_KEY", "xai")), + top_p=self.top_p, + ) + + @dataclass class AnthropicLanguageModel(LanguageModel): """ diff --git a/pages/7_⚙️_Settings.py b/pages/7_⚙️_Settings.py index 135e889..e2b09f0 100644 --- a/pages/7_⚙️_Settings.py +++ b/pages/7_⚙️_Settings.py @@ -28,6 +28,7 @@ model_types = [ provider_status["ollama"] = os.environ.get("OLLAMA_API_BASE") is not None provider_status["openai"] = os.environ.get("OPENAI_API_KEY") is not None +provider_status["xai"] = os.environ.get("XAI_API_KEY") is not None provider_status["vertexai"] = ( os.environ.get("VERTEX_PROJECT") is not None and os.environ.get("VERTEX_LOCATION") is not None