mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 06:46:18 +00:00
fixed forced imports + ensured return type in structure_response
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import anthropic
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -34,6 +33,13 @@ class Anthropic(BaseProvider):
|
||||
"""The raw Anthropic client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Anthropic API key is required")
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `anthropic` package: `pip install anthropic`"
|
||||
) from exc
|
||||
|
||||
return anthropic.Anthropic(api_key=self.api_key)
|
||||
|
||||
@cached_property
|
||||
@@ -86,7 +92,7 @@ class Anthropic(BaseProvider):
|
||||
response_model=response_model,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import groq
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -34,6 +33,12 @@ class Groq(BaseProvider):
|
||||
"""The raw Groq client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Groq API key is required")
|
||||
try:
|
||||
import groq
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `groq` package: `pip install groq`"
|
||||
) from exc
|
||||
return groq.Groq(api_key=self.api_key)
|
||||
|
||||
@cached_property
|
||||
@@ -85,7 +90,7 @@ class Groq(BaseProvider):
|
||||
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(
|
||||
@@ -94,7 +99,7 @@ class Groq(BaseProvider):
|
||||
*,
|
||||
llm_model: str,
|
||||
**kwargs,
|
||||
):
|
||||
) -> str:
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -105,4 +110,4 @@ class Groq(BaseProvider):
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
return str(response.choices[0].message.content)
|
||||
|
||||
@@ -2,7 +2,6 @@ from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import ollama as ol
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -36,6 +35,12 @@ class Ollama(BaseProvider):
|
||||
"""The raw Ollama client."""
|
||||
if not self.host_url:
|
||||
raise ValueError("No ollama host url provided")
|
||||
try:
|
||||
import ollama as ol
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `ollama` package: `pip install ollama`"
|
||||
) from exc
|
||||
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
|
||||
|
||||
@cached_property
|
||||
@@ -93,7 +98,7 @@ class Ollama(BaseProvider):
|
||||
response_model=response_model,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(
|
||||
|
||||
@@ -2,7 +2,6 @@ from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
@@ -33,6 +32,12 @@ class OpenAI(BaseProvider):
|
||||
"""The raw OpenAI client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key is required")
|
||||
try:
|
||||
import openai as oa
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `openai` package: `pip install openai`"
|
||||
) from exc
|
||||
return oa.OpenAI(api_key=self.api_key)
|
||||
|
||||
@cached_property
|
||||
@@ -87,7 +92,7 @@ class OpenAI(BaseProvider):
|
||||
response_model=response_model,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
|
||||
|
||||
@@ -2,7 +2,6 @@ from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from simplemind.models import Message
|
||||
@@ -37,6 +36,12 @@ class XAI(BaseProvider):
|
||||
"""The raw OpenAI client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("XAI API key is required")
|
||||
try:
|
||||
import openai as oa
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `openai` package: `pip install openai`"
|
||||
) from exc
|
||||
return oa.OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=BASE_URL,
|
||||
|
||||
Reference in New Issue
Block a user