diff --git a/.envrc.template b/.envrc.template index c365f90..16f0415 100644 --- a/.envrc.template +++ b/.envrc.template @@ -1,5 +1,6 @@ -export OPENAI_API_KEY="" export ANTHROPIC_API_KEY="" -export XAI_API_KEY="" -export OLLAMA_HOST_URL="" +export GEMINI_API_KEY="" export GROQ_API_KEY="" +export OLLAMA_HOST_URL="" +export OPENAI_API_KEY="" +export XAI_API_KEY="" diff --git a/README.md b/README.md index 8726049..3910082 100644 --- a/README.md +++ b/README.md @@ -20,13 +20,14 @@ With Simplemind, tapping into AI is as easy as a friendly conversation. To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`. -- **[OpenAI's GPT](https://openai.com/gpt)** -- **[Anthropic's Claude](https://www.anthropic.com/claude)** -- **[xAI's Grok](https://x.ai/)** -- **[Groq's Groq](https://groq.com/)** -- **[Ollama](https://ollama.com)** +- [**Anthropic's Claude**](https://www.anthropic.com/claude) +- [**Google's Gemini**](https://ai.google.dev/gemini-api) +- [**Groq's Groq**](https://groq.com/) +- [**Ollama**](https://ollama.com) +- [**OpenAI's GPT**](https://openai.com/gpt) +- [**xAI's Grok**](https://x.ai/) -If you'd like to see Simplemind support additional providers or models, please send a pull request! +If you want to see Simplemind support, additional providers or models, please request a pull! ## Why SimpleMind? - **Intuitive**: Built with Pythonic simplicity and readability in mind. @@ -140,7 +141,7 @@ response = gpt_4o_mini.generate_text( Harnessing the power of Python, you can easily create your own plugins to add additional functionality to your conversations: ```python -class SimpleMemoryPlugin: +class SimpleMemoryPlugin(sm.BasePlugin): def __init__(self): self.memories = [ "the earth has fictionally beeen destroyed.", diff --git a/examples/contextual_memory.py b/examples/contextual_memory.py index 5797213..b10b0c5 100644 --- a/examples/contextual_memory.py +++ b/examples/contextual_memory.py @@ -1,14 +1,14 @@ -from _context import sm - -from pydantic import BaseModel -import openai -import faiss -import numpy as np import os import pickle +import faiss +import numpy as np +import openai +from _context import sm +from pydantic import BaseModel -class ContextualMemoryPlugin: + +class ContextualMemoryPlugin(sm.BasePlugin): def __init__( self, api_key: str, diff --git a/examples/generate_data.py b/examples/generate_data.py index 888d694..75220a3 100644 --- a/examples/generate_data.py +++ b/examples/generate_data.py @@ -1,8 +1,7 @@ -from typing import List, Iterator - -from pydantic import BaseModel +from typing import Iterator, List from _context import sm +from pydantic import BaseModel class Movie(BaseModel): @@ -25,7 +24,7 @@ class QuotesList(BaseModel): quotes: List[MovieQuote] -def gen_quotes(n=10) -> Iterator[MovieQuote]: +def gen_quotes(n: int = 10) -> Iterator[MovieQuote]: """Generate a list of quotes from famous movies.""" for q in sm.generate_data( diff --git a/examples/math_plugin.py b/examples/math_plugin.py index 8637073..436d04f 100644 --- a/examples/math_plugin.py +++ b/examples/math_plugin.py @@ -1,9 +1,11 @@ from _context import sm -class MathPlugin: +class MathPlugin(sm.BasePlugin): def send_hook(self, conversation: sm.Conversation): last_user_message = conversation.get_last_message(role="user") + if last_user_message is None: + return if "calculate" in last_user_message.text.lower(): expression = last_user_message.text.lower().replace("calculate", "").strip() try: @@ -14,7 +16,7 @@ class MathPlugin: except Exception: conversation.add_message( role="assistant", - text="I'm sorry, I couldn't compute that expression.", + text="I'm sorry, I couldn't compute that expression. Please try again.", ) diff --git a/examples/sentiment_analysis.py b/examples/sentiment_analysis.py index 9cda733..7b855b6 100644 --- a/examples/sentiment_analysis.py +++ b/examples/sentiment_analysis.py @@ -1,8 +1,8 @@ -from _context import sm - -from pydantic import BaseModel from typing import Literal +from _context import sm +from pydantic import BaseModel + class SentimentAnalysis(BaseModel): sentiment: Literal["positive", "negative", "neutral"] diff --git a/examples/simple_memory.py b/examples/simple_memory.py index c2974da..a0d119e 100644 --- a/examples/simple_memory.py +++ b/examples/simple_memory.py @@ -1,7 +1,7 @@ from _context import sm -class SimpleMemoryPlugin(sm.BasePlugin): +class SimpleMemoryPlugin: def __init__(self): self.memories = [ "the earth has fictionally beeen destroyed.", diff --git a/examples/two_llms_talking.py b/examples/two_llms_talking.py index 4971867..f88dd33 100644 --- a/examples/two_llms_talking.py +++ b/examples/two_llms_talking.py @@ -1,6 +1,7 @@ -import simplemind as sm import time +import simplemind as sm + class ConversationPlugin(sm.BasePlugin): def post_send_hook(self, conversation, response): @@ -8,7 +9,7 @@ class ConversationPlugin(sm.BasePlugin): print(f"{conversation.llm_model}:\n{response.text.strip()}\n\n------------\n") -def have_conversation(rounds=3): +def have_conversation(rounds: int = 3): # Create two conversations - one for each AI with ( sm.create_conversation( diff --git a/pyproject.toml b/pyproject.toml index c86f0c2..8c287c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,8 +3,8 @@ name = "simplemind" version = "0.1.4" description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases." readme = "README.md" -requires-python = ">=3.11" -dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "ollama", "groq"] +requires-python = ">=3.10" +dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "ollama", "groq", "google-generativeai"] [build-system] requires = ["hatchling"] diff --git a/simplemind/__init__.py b/simplemind/__init__.py index 3d87feb..90313de 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -1,8 +1,8 @@ -from typing import List, Optional, Type +from typing import List, Type -from .models import Conversation, BasePlugin, BaseModel -from .utils import find_provider +from .models import BaseModel, BasePlugin, Conversation from .settings import settings +from .utils import find_provider class Session: @@ -56,9 +56,9 @@ class Session: def create_conversation( *, - llm_model=None, - llm_provider=None, - plugins: Optional[List[BasePlugin]] = None, + llm_model: str | None = None, + llm_provider: str | None = None, + plugins: List[BasePlugin] | None = None, **kwargs, ) -> Conversation: """Create a new conversation.""" @@ -77,7 +77,12 @@ def create_conversation( def generate_data( - prompt, *, llm_model=None, llm_provider=None, response_model=None, **kwargs + prompt: str, + *, + llm_model: str | None = None, + llm_provider: str | None = None, + response_model: Type[BaseModel], + **kwargs, ) -> BaseModel: """Generate structured data from a given prompt.""" @@ -92,7 +97,13 @@ def generate_data( ) -def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs) -> str: +def generate_text( + prompt: str, + *, + llm_model: str | None = None, + llm_provider: str | None = None, + **kwargs, +) -> str: """Generate text from a given prompt.""" # Find the provider. diff --git a/simplemind/models.py b/simplemind/models.py index ec9f173..e548435 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -1,18 +1,18 @@ +from types import TracebackType import uuid -from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Dict, List, Literal, Optional - from pydantic import BaseModel, Field -from .utils import find_provider - +from .providers import find_provider MESSAGE_ROLE = Literal["system", "user", "assistant"] class SMBaseModel(BaseModel): + """The base SimpleMind model class.""" + date_created: datetime = Field(default_factory=datetime.now) def __str__(self): @@ -22,34 +22,36 @@ class SMBaseModel(BaseModel): return str(self) -class BasePlugin: +class BasePlugin(SMBaseModel): """The base conversation plugin class.""" # Plugin metadata. meta: Dict[str, Any] = {} - def initialize_hook(self, conversation: "Conversation"): + def initialize_hook(self, conversation: "Conversation") -> Any: """Initialize a hook for the plugin.""" raise NotImplementedError - def cleanup_hook(self, conversation: "Conversation"): + def cleanup_hook(self, conversation: "Conversation") -> Any: """Cleanup a hook for the plugin.""" raise NotImplementedError - def add_message_hook(self, conversation: "Conversation", message: "Message"): + def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any: """Add a message hook for the plugin.""" raise NotImplementedError - def pre_send_hook(self, conversation: "Conversation"): + def pre_send_hook(self, conversation: "Conversation") -> Any: """Pre-send hook for the plugin.""" raise NotImplementedError - def post_send_hook(self, conversation: "Conversation", response: "Message"): + def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any: """Post-send hook for the plugin.""" raise NotImplementedError class Message(SMBaseModel): + """A message in a conversation.""" + role: MESSAGE_ROLE text: str meta: Dict[str, Any] = {} @@ -61,7 +63,16 @@ class Message(SMBaseModel): return f"" @classmethod - def from_raw_response(cls, *, text: str, raw): + def from_raw_response(cls, *, text: str, raw: Any) -> "Message": + """Create a Message instance from a raw response. + + Args: + text (str): The message text. + raw (Any): The raw response data. + + Returns: + Message: A new Message instance. + """ self = cls() self.text = text self.raw = raw @@ -69,11 +80,13 @@ class Message(SMBaseModel): class Conversation(SMBaseModel): + """A conversation between a user and an assistant.""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) messages: List[Message] = [] llm_model: Optional[str] = None llm_provider: Optional[str] = None - plugins: List[Any] = [] + plugins: List[BasePlugin] = [] def __str__(self): return f"" @@ -89,8 +102,13 @@ class Conversation(SMBaseModel): return self - def __exit__(self, exc_type, exc_value, traceback): - # Execute all cleanup hooks. + def __exit__( + self, + exc_type: type[BaseException], + exc_value: BaseException, + traceback: TracebackType, + ) -> None: + """Execute all cleanup hooks.""" for plugin in self.plugins: if hasattr(plugin, "cleanup_hook"): try: @@ -99,7 +117,7 @@ class Conversation(SMBaseModel): pass def prepend_system_message( - self, role: str, text: str, meta: Optional[Dict[str, Any]] = None + self, role: MESSAGE_ROLE, text: str, meta: Dict[str, Any] | None = None ): """Prepend a system message to the conversation.""" self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages @@ -127,7 +145,9 @@ class Conversation(SMBaseModel): self.messages.append(Message(role=role, text=text, meta=meta)) def send( - self, llm_model: Optional[str] = None, llm_provider: Optional[str] = None + self, + llm_model: str | None = None, + llm_provider: str | None = None, ) -> Message: """Send the conversation to the LLM.""" @@ -156,10 +176,10 @@ class Conversation(SMBaseModel): return response - def get_last_message(self, role: MESSAGE_ROLE) -> Optional[Message]: + def get_last_message(self, role: MESSAGE_ROLE) -> Message | None: """Get the last message with the given role.""" return next((m for m in reversed(self.messages) if m.role == role), None) - def add_plugin(self, plugin: Any): + def add_plugin(self, plugin: BasePlugin) -> None: """Add a plugin to the conversation.""" self.plugins.append(plugin) diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index f9b1983..201f851 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -2,9 +2,10 @@ from typing import List, Type from ._base import BaseProvider from .anthropic import Anthropic +from .gemini import Gemini from .groq import Groq -from .openai import OpenAI from .ollama import Ollama +from .openai import OpenAI from .xai import XAI -providers: List[Type[BaseProvider]] = [Anthropic, Groq, OpenAI, Ollama, XAI] +providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI] diff --git a/simplemind/providers/_base.py b/simplemind/providers/_base.py index dc1abdc..42ffb9b 100644 --- a/simplemind/providers/_base.py +++ b/simplemind/providers/_base.py @@ -1,6 +1,11 @@ from abc import ABC, abstractmethod +from functools import cached_property +from typing import Any, Type, TypeVar from instructor import Instructor +from pydantic import BaseModel + +T = TypeVar("T", bound=BaseModel) class BaseProvider(ABC): @@ -9,13 +14,13 @@ class BaseProvider(ABC): NAME: str DEFAULT_MODEL: str - @property + @cached_property @abstractmethod - def client(self): + def client(self) -> Any: """The instructor client for the provider.""" raise NotImplementedError - @property + @cached_property @abstractmethod def structured_client(self) -> Instructor: """The structured client for the provider.""" @@ -27,7 +32,7 @@ class BaseProvider(ABC): raise NotImplementedError @abstractmethod - def structured_response(self, prompt: str, response_model, **kwargs): + def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: """Get a structured response.""" raise NotImplementedError diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 1028b09..4a79dce 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -1,10 +1,15 @@ -from typing import Union +from functools import cached_property +from typing import Type, TypeVar import anthropic import instructor +from pydantic import BaseModel -from ._base import BaseProvider from ..settings import settings +from ._base import BaseProvider + +T = TypeVar("T", bound=BaseModel) + PROVIDER_NAME = "anthropic" DEFAULT_MODEL = "claude-3-5-sonnet-20241022" @@ -15,17 +20,17 @@ class Anthropic(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL - def __init__(self, api_key: Union[str, None] = None): + def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) - @property + @cached_property def client(self): """The raw Anthropic client.""" if not self.api_key: raise ValueError("Anthropic API key is required") return anthropic.Anthropic(api_key=self.api_key) - @property + @cached_property def structured_client(self): """A client patched with Instructor.""" return instructor.from_anthropic(self.client) @@ -57,13 +62,13 @@ class Anthropic(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, model, response_model, **kwargs): + def structured_response(self, model: str, response_model: Type[T], **kwargs) -> T: response = self.structured_client.messages.create( - model=model, response_model=response_model or self.DEFAULT_MODEL, **kwargs + model=model or self.DEFAULT_MODEL, response_model=response_model, **kwargs ) return response - def generate_text(self, prompt, *, llm_model, **kwargs): + def generate_text(self, prompt: str, *, llm_model: str, **kwargs): messages = [ {"role": "user", "content": prompt}, ] diff --git a/simplemind/providers/gemini.py b/simplemind/providers/gemini.py new file mode 100644 index 0000000..1e7a932 --- /dev/null +++ b/simplemind/providers/gemini.py @@ -0,0 +1,94 @@ +from functools import cached_property +from typing import Type, TypeVar + +import google.generativeai as genai +import instructor +from pydantic import BaseModel + +from ..models import Conversation, Message +from ..settings import settings +from ._base import BaseProvider + +PROVIDER_NAME = "gemini" +DEFAULT_MODEL = "models/gemini-1.5-flash-latest" + +T = TypeVar("T", bound=BaseModel) + + +class Gemini(BaseProvider): + NAME = PROVIDER_NAME + DEFAULT_MODEL = DEFAULT_MODEL + + def __init__(self, api_key: str | None = None): + self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.model_name = DEFAULT_MODEL + + @cached_property + def client(self, model_name: str = DEFAULT_MODEL): + """The raw Gemini client.""" + if not self.api_key: + raise ValueError("Gemini API key is required") + self.model_name = model_name + return genai.GenerativeModel(model_name=self.model_name) + + @cached_property + def structured_client(self): + """A Gemini client patched with Instructor.""" + return instructor.from_gemini(self.client) + + def send_conversation(self, conversation: "Conversation") -> "Message": + """Send a conversation to the Gemini API.""" + + messages = [ + { + "role": msg.role, + "content": msg.text, + "metadata": msg.meta or {}, + } + for msg in conversation.messages + ] + + try: + response = self.structured_client.chat.completions.create( + messages=messages, response_model=None + ) + except Exception as e: + # Handle the exception appropriately, e.g., log the error or raise a custom exception + raise RuntimeError(f"Failed to send conversation to Gemini API: {e}") from e + + # Create and return a properly formatted Message instance + return Message( + role="assistant", + text=str(response), + raw=response, + llm_model=self.model_name, + llm_provider=PROVIDER_NAME, + ) + + def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: + """Send a structured response to the Gemini API.""" + try: + response = self.structured_client.chat.completions.create( + messages=[{"role": "user", "content": prompt}], + response_model=response_model, + **kwargs, + ) + except Exception as e: + # Handle the exception appropriately, e.g., log the error or raise a custom exception + raise RuntimeError( + f"Failed to send structured response to Gemini API: {e}" + ) from e + return response + + def generate_text(self, prompt: str, **kwargs) -> str: + """Generate text using the Gemini API.""" + try: + response = self.structured_client.chat.completions.create( + messages=[{"role": "user", "content": prompt}], + response_model=None, + **kwargs, + ) + except Exception as e: + # Handle the exception appropriately, e.g., log the error or raise a custom exception + raise RuntimeError(f"Failed to generate text with Gemini API: {e}") from e + return str(response) diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 47f14ae..d1a3c8e 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -1,30 +1,34 @@ -from typing import Union +from functools import cached_property +from typing import Type, TypeVar import groq import instructor +from pydantic import BaseModel -from ._base import BaseProvider from ..settings import settings +from ._base import BaseProvider PROVIDER_NAME = "groq" DEFAULT_MODEL = "llama3-8b-8192" +T = TypeVar("T", bound=BaseModel) + class Groq(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL - def __init__(self, api_key: Union[str, None] = None): + def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) - @property + @cached_property def client(self): """The raw Groq client.""" if not self.api_key: raise ValueError("Groq API key is required") return groq.Groq(api_key=self.api_key) - @property + @cached_property def structured_client(self): """A client patched with Instructor.""" return instructor.from_groq(self.client) @@ -59,7 +63,7 @@ class Groq(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt: str, response_model, **kwargs): + def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: # Ensure messages are provided in kwargs messages = [ {"role": "user", "content": prompt}, diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index 5ba019a..7d1ac22 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -1,9 +1,16 @@ -import ollama as ol -import instructor -from openai import OpenAI +from functools import cached_property +from typing import Type, TypeVar + +import instructor +import ollama as ol +from openai import OpenAI +from pydantic import BaseModel -from ._base import BaseProvider from ..settings import settings +from ._base import BaseProvider + +T = TypeVar("T", bound=BaseModel) + PROVIDER_NAME = "ollama" DEFAULT_MODEL = "llama3.2" @@ -15,18 +22,18 @@ class Ollama(BaseProvider): DEFAULT_MODEL = DEFAULT_MODEL TIMEOUT = DEFAULT_TIMEOUT - def __init__(self, host_url: str = None): + def __init__(self, host_url: str | None = None): self.host_url = host_url or settings.OLLAMA_HOST_URL - @property + @cached_property def client(self): """The raw Ollama client.""" if not self.host_url: raise ValueError("No ollama host url provided") return ol.Client(timeout=self.TIMEOUT, host=self.host_url) - @property - def structured_client(self): + @cached_property + def structured_client(self) -> instructor.Instructor: """A client patched with Instructor.""" return instructor.from_openai( OpenAI( @@ -36,7 +43,7 @@ class Ollama(BaseProvider): mode=instructor.Mode.JSON, ) - def send_conversation(self, conversation: "Conversation"): + def send_conversation(self, conversation: "Conversation") -> "Message": """Send a conversation to the Ollama API.""" from ..models import Message @@ -57,7 +64,10 @@ class Ollama(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs): + def structured_response( + self, prompt: str, response_model: Type[T], *, llm_model: str, **kwargs + ) -> T: + """Get a structured response from the Ollama API.""" messages = [ {"role": "user", "content": prompt}, ] @@ -70,7 +80,8 @@ class Ollama(BaseProvider): ) return response - def generate_text(self, prompt, *, llm_model): + def generate_text(self, prompt: str, *, llm_model: str) -> str: + """Generate text using the Ollama API.""" messages = [ {"role": "user", "content": prompt}, ] @@ -79,4 +90,4 @@ class Ollama(BaseProvider): messages=messages, model=llm_model or self.DEFAULT_MODEL ) - return response.get("message").get("content") + return response.get("message", {}).get("content", "") diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 3895096..b6fa568 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,10 +1,14 @@ -from typing import Union +from functools import cached_property +from typing import Type, TypeVar import instructor import openai as oa +from pydantic import BaseModel -from ._base import BaseProvider from ..settings import settings +from ._base import BaseProvider + +T = TypeVar("T", bound=BaseModel) PROVIDER_NAME = "openai" DEFAULT_MODEL = "gpt-4o-mini" @@ -14,17 +18,17 @@ class OpenAI(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL - def __init__(self, api_key: Union[str, None] = None): + def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) - @property + @cached_property def client(self): """The raw OpenAI client.""" if not self.api_key: raise ValueError("OpenAI API key is required") return oa.OpenAI(api_key=self.api_key) - @property + @cached_property def structured_client(self): """A OpenAI client with Instructor.""" return instructor.from_openai(self.client) @@ -53,12 +57,19 @@ class OpenAI(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs): + def structured_response( + self, + prompt: str, + response_model: Type[T], + *, + llm_model: str | None = None, + **kwargs, + ) -> T: + """Get a structured response from the OpenAI API.""" # Ensure messages are provided in kwargs messages = [ {"role": "user", "content": prompt}, ] - response = self.structured_client.chat.completions.create( messages=messages, model=llm_model or self.DEFAULT_MODEL, @@ -67,13 +78,12 @@ class OpenAI(BaseProvider): ) return response - def generate_text(self, prompt, *, llm_model, **kwargs): + def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs): + """Generate text using the OpenAI API.""" messages = [ {"role": "user", "content": prompt}, ] - response = self.client.chat.completions.create( messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs ) - return response.choices[0].message.content diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 8b8b84d..dd32e0b 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -1,10 +1,10 @@ -from typing import Union +from functools import cached_property import instructor import openai as oa -from ._base import BaseProvider from ..settings import settings +from ._base import BaseProvider PROVIDER_NAME = "xai" DEFAULT_MODEL = "grok-beta" @@ -16,10 +16,10 @@ class XAI(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL - def __init__(self, api_key: Union[str, None] = None): + def __init__(self, api_key: str | None = None): self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) - @property + @cached_property def client(self): """The raw OpenAI client.""" if not self.api_key: @@ -29,7 +29,7 @@ class XAI(BaseProvider): base_url=BASE_URL, ) - @property + @cached_property def structured_client(self): """A client patched with Instructor.""" return instructor.from_openai(self.client) @@ -60,10 +60,10 @@ class XAI(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt: str, response_model, *, llm_model): + def structured_response(self, prompt: str, response_model, *, llm_model: str): raise NotImplementedError("XAI does not support structured responses") - def generate_text(self, prompt, *, llm_model, **kwargs): + def generate_text(self, prompt: str, *, llm_model: str, **kwargs): messages = [ {"role": "user", "content": prompt}, ] diff --git a/simplemind/settings.py b/simplemind/settings.py index 4068ebd..1c35347 100644 --- a/simplemind/settings.py +++ b/simplemind/settings.py @@ -1,8 +1,19 @@ -from typing import Optional, Union +from typing import Literal, Optional, Union from pydantic import Field, SecretStr, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict +logging_level = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + + +class LoggingConfig(BaseSettings): + """The class that holds all the logging settings for the application.""" + + enabled: bool = Field(False, description="Enable logging") + level: logging_level = Field("INFO", description="The logging level") + + model_config = SettingsConfigDict(extra="forbid") + class Settings(BaseSettings): """The class that holds all the API keys for the application.""" @@ -11,6 +22,7 @@ class Settings(BaseSettings): None, description="API key for Anthropic" ) GROQ_API_KEY: Optional[SecretStr] = Field(None, description="API key for Groq") + GEMINI_API_KEY: Optional[SecretStr] = Field(None, description="API key for Gemini") OPENAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for OpenAI") OLLAMA_HOST_URL: Optional[str] = Field( "http://127.0.0.1:11434", description="Fully qualified host URL for Ollama" @@ -22,6 +34,7 @@ class Settings(BaseSettings): model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore" ) + logging: LoggingConfig = LoggingConfig() @field_validator("*", mode="before") @classmethod diff --git a/simplemind/utils.py b/simplemind/utils.py index 67477ca..0226686 100644 --- a/simplemind/utils.py +++ b/simplemind/utils.py @@ -1,24 +1,26 @@ import difflib -from typing import Optional, Type -from .providers import providers, BaseProvider +from .providers import BaseProvider, providers _PROVIDER_NAMES = [provider.NAME.lower() for provider in providers] -def find_provider(provider_name: str) -> BaseProvider: +def find_provider(provider_name: str | None) -> BaseProvider: """ Find and instantiate a provider by name. Parameters: - provider_name (Union[str, None]): The name of the provider to find. + provider_name (Union[str, None]): The name of the provider to find. Returns: - An instance of the provider class if found. + An instance of the provider class if found. Raises: - ValueError: If the provider is not found, with a suggestion for the closest match. + ValueError: If the provider is not specified or is not found, with a suggestion for the closest match. """ + if provider_name is None: + raise ValueError("No provider specified.") + # Find the provider by name. for provider_class in providers: if provider_class.NAME.lower() == provider_name.lower(): @@ -29,10 +31,8 @@ def find_provider(provider_name: str) -> BaseProvider: provider_found = difflib.get_close_matches( provider_name.lower(), _PROVIDER_NAMES, n=1 ) - if provider_found: raise ValueError( f"Provider {provider_name!r} not found. Did you mean {provider_found[0]!r}?" ) - else: - raise ValueError(f"Provider {provider_name} not found.") + raise ValueError(f"Provider {provider_name} not found.")