From fe06331662b8f85f2c692963604c2e4c6f05c14f Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Fri, 1 Nov 2024 14:24:34 +0530 Subject: [PATCH] fixed forced imports + ensured return type in structure_response --- simplemind/providers/anthropic.py | 10 ++++++++-- simplemind/providers/groq.py | 13 +++++++++---- simplemind/providers/ollama.py | 9 +++++++-- simplemind/providers/openai.py | 9 +++++++-- simplemind/providers/xai.py | 7 ++++++- 5 files changed, 37 insertions(+), 11 deletions(-) diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index b51e40f..4798933 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -1,7 +1,6 @@ from functools import cached_property from typing import TYPE_CHECKING, Type, TypeVar -import anthropic import instructor from pydantic import BaseModel @@ -34,6 +33,13 @@ class Anthropic(BaseProvider): """The raw Anthropic client.""" if not self.api_key: raise ValueError("Anthropic API key is required") + try: + import anthropic + except ImportError as exc: + raise ImportError( + "Please install the `anthropic` package: `pip install anthropic`" + ) from exc + return anthropic.Anthropic(api_key=self.api_key) @cached_property @@ -86,7 +92,7 @@ class Anthropic(BaseProvider): response_model=response_model, **{**self.DEFAULT_KWARGS, **kwargs}, ) - return response + return response_model.model_validate(response) @logger def generate_text(self, prompt: str, *, llm_model: str, **kwargs): diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 5114e90..5b2801f 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -1,7 +1,6 @@ from functools import cached_property from typing import TYPE_CHECKING, Type, TypeVar -import groq import instructor from pydantic import BaseModel @@ -34,6 +33,12 @@ class Groq(BaseProvider): """The raw Groq client.""" if not self.api_key: raise ValueError("Groq API key is required") + try: + import groq + except ImportError as exc: + raise ImportError( + "Please install the `groq` package: `pip install groq`" + ) from exc return groq.Groq(api_key=self.api_key) @cached_property @@ -85,7 +90,7 @@ class Groq(BaseProvider): model=kwargs.pop("llm_model", self.DEFAULT_MODEL), **{**self.DEFAULT_KWARGS, **kwargs}, ) - return response + return response_model.model_validate(response) @logger def generate_text( @@ -94,7 +99,7 @@ class Groq(BaseProvider): *, llm_model: str, **kwargs, - ): + ) -> str: messages = [ {"role": "user", "content": prompt}, ] @@ -105,4 +110,4 @@ class Groq(BaseProvider): **{**self.DEFAULT_KWARGS, **kwargs}, ) - return response.choices[0].message.content + return str(response.choices[0].message.content) diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index c422eb4..3e00c25 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -2,7 +2,6 @@ from functools import cached_property from typing import TYPE_CHECKING, Type, TypeVar import instructor -import ollama as ol from openai import OpenAI from pydantic import BaseModel @@ -36,6 +35,12 @@ class Ollama(BaseProvider): """The raw Ollama client.""" if not self.host_url: raise ValueError("No ollama host url provided") + try: + import ollama as ol + except ImportError as exc: + raise ImportError( + "Please install the `ollama` package: `pip install ollama`" + ) from exc return ol.Client(timeout=self.TIMEOUT, host=self.host_url) @cached_property @@ -93,7 +98,7 @@ class Ollama(BaseProvider): response_model=response_model, **{**self.DEFAULT_KWARGS, **kwargs}, ) - return response + return response_model.model_validate(response) @logger def generate_text( diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 2a05080..fb197e5 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -2,7 +2,6 @@ from functools import cached_property from typing import TYPE_CHECKING, Type, TypeVar import instructor -import openai as oa from pydantic import BaseModel from ..logging import logger @@ -33,6 +32,12 @@ class OpenAI(BaseProvider): """The raw OpenAI client.""" if not self.api_key: raise ValueError("OpenAI API key is required") + try: + import openai as oa + except ImportError as exc: + raise ImportError( + "Please install the `openai` package: `pip install openai`" + ) from exc return oa.OpenAI(api_key=self.api_key) @cached_property @@ -87,7 +92,7 @@ class OpenAI(BaseProvider): response_model=response_model, **{**self.DEFAULT_KWARGS, **kwargs}, ) - return response + return response_model.model_validate(response) @logger def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs): diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index afc1faf..60e2654 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -2,7 +2,6 @@ from functools import cached_property from typing import TYPE_CHECKING, Type, TypeVar import instructor -import openai as oa from pydantic import BaseModel from simplemind.models import Message @@ -37,6 +36,12 @@ class XAI(BaseProvider): """The raw OpenAI client.""" if not self.api_key: raise ValueError("XAI API key is required") + try: + import openai as oa + except ImportError as exc: + raise ImportError( + "Please install the `openai` package: `pip install openai`" + ) from exc return oa.OpenAI( api_key=self.api_key, base_url=BASE_URL,