fixed forced imports + ensured return type in structure_response

This commit is contained in:
Siddhesh Agarwal
2024-11-01 14:24:34 +05:30
parent 56b1e65d70
commit fe06331662
5 changed files with 37 additions and 11 deletions
+8 -2
View File
@@ -1,7 +1,6 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
import anthropic
import instructor
from pydantic import BaseModel
@@ -34,6 +33,13 @@ class Anthropic(BaseProvider):
"""The raw Anthropic client."""
if not self.api_key:
raise ValueError("Anthropic API key is required")
try:
import anthropic
except ImportError as exc:
raise ImportError(
"Please install the `anthropic` package: `pip install anthropic`"
) from exc
return anthropic.Anthropic(api_key=self.api_key)
@cached_property
@@ -86,7 +92,7 @@ class Anthropic(BaseProvider):
response_model=response_model,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response
return response_model.model_validate(response)
@logger
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
+9 -4
View File
@@ -1,7 +1,6 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
import groq
import instructor
from pydantic import BaseModel
@@ -34,6 +33,12 @@ class Groq(BaseProvider):
"""The raw Groq client."""
if not self.api_key:
raise ValueError("Groq API key is required")
try:
import groq
except ImportError as exc:
raise ImportError(
"Please install the `groq` package: `pip install groq`"
) from exc
return groq.Groq(api_key=self.api_key)
@cached_property
@@ -85,7 +90,7 @@ class Groq(BaseProvider):
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response
return response_model.model_validate(response)
@logger
def generate_text(
@@ -94,7 +99,7 @@ class Groq(BaseProvider):
*,
llm_model: str,
**kwargs,
):
) -> str:
messages = [
{"role": "user", "content": prompt},
]
@@ -105,4 +110,4 @@ class Groq(BaseProvider):
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response.choices[0].message.content
return str(response.choices[0].message.content)
+7 -2
View File
@@ -2,7 +2,6 @@ from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
import instructor
import ollama as ol
from openai import OpenAI
from pydantic import BaseModel
@@ -36,6 +35,12 @@ class Ollama(BaseProvider):
"""The raw Ollama client."""
if not self.host_url:
raise ValueError("No ollama host url provided")
try:
import ollama as ol
except ImportError as exc:
raise ImportError(
"Please install the `ollama` package: `pip install ollama`"
) from exc
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
@cached_property
@@ -93,7 +98,7 @@ class Ollama(BaseProvider):
response_model=response_model,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response
return response_model.model_validate(response)
@logger
def generate_text(
+7 -2
View File
@@ -2,7 +2,6 @@ from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
import instructor
import openai as oa
from pydantic import BaseModel
from ..logging import logger
@@ -33,6 +32,12 @@ class OpenAI(BaseProvider):
"""The raw OpenAI client."""
if not self.api_key:
raise ValueError("OpenAI API key is required")
try:
import openai as oa
except ImportError as exc:
raise ImportError(
"Please install the `openai` package: `pip install openai`"
) from exc
return oa.OpenAI(api_key=self.api_key)
@cached_property
@@ -87,7 +92,7 @@ class OpenAI(BaseProvider):
response_model=response_model,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response
return response_model.model_validate(response)
@logger
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
+6 -1
View File
@@ -2,7 +2,6 @@ from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
import instructor
import openai as oa
from pydantic import BaseModel
from simplemind.models import Message
@@ -37,6 +36,12 @@ class XAI(BaseProvider):
"""The raw OpenAI client."""
if not self.api_key:
raise ValueError("XAI API key is required")
try:
import openai as oa
except ImportError as exc:
raise ImportError(
"Please install the `openai` package: `pip install openai`"
) from exc
return oa.OpenAI(
api_key=self.api_key,
base_url=BASE_URL,