diff --git a/pyproject.toml b/pyproject.toml index 1c4cebc..57b6a1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ version = "0.1.0" description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases." readme = "README.md" requires-python = ">=3.11" -dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic"] +dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "groq"] [build-system] requires = ["hatchling"] diff --git a/simplemind/__init__.py b/simplemind/__init__.py index b98e75c..2c7ca0e 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -32,3 +32,13 @@ def generate_text(prompt, *, llm_model=None, llm_provider=None): provider = find_provider(llm_provider) return provider.generate_text(prompt=prompt, llm_model=llm_model) + + +__all__ = [ + "Conversation", + "SimpleMind", + "create_conversation", + "find_provider", + "generate_data", + "generate_text", +] diff --git a/simplemind/models.py b/simplemind/models.py index 397e773..5846489 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -4,7 +4,9 @@ from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, Field -from .utils import find_provider +from simplemind.utils import find_provider + +MESSAGE_ROLE = Literal["system", "user", "assistant"] class SMBaseModel(BaseModel): @@ -17,41 +19,12 @@ class SMBaseModel(BaseModel): return str(self) -class BaseProvider(SMBaseModel): - """The base provider class.""" - - __name__ = "BaseProvider" - DEFAULT_MODEL = "DEFAULT_MODEL" - - @property - def client(self): - """The instructor client for the provider.""" - raise NotImplementedError - - @property - def structured_client(self): - """The structured client for the provider.""" - raise NotImplementedError - - def send_conversation(self, conversation: "Conversation"): - """Send a conversation to the provider.""" - raise NotImplementedError - - def structured_response(self, prompt: str, response_model, **kwargs): - """Get a structured response.""" - raise NotImplementedError - - def generate_text(self, prompt: str, **kwargs): - """Generate text from a prompt.""" - raise NotImplementedError - - class BasePlugin(SMBaseModel): """The base plugin class.""" class Message(SMBaseModel): - role: Literal["system", "user", "assistant"] + role: MESSAGE_ROLE text: str meta: Dict[str, Any] = {} raw: Optional[Any] = None @@ -62,7 +35,7 @@ class Message(SMBaseModel): return f"" @classmethod - def from_raw_response(cls, *, text, raw): + def from_raw_response(cls, *, text: str, raw): self = cls() self.text = text self.raw = raw @@ -79,8 +52,12 @@ class Conversation(SMBaseModel): def __str__(self): return f"" - def add_message(self, role: str, text: str, meta: Dict[str, Any] = {}): + def add_message( + self, role: MESSAGE_ROLE, text: str, meta: Optional[Dict[str, Any]] = None + ): """Add a new message to the conversation.""" + if meta is None: + meta = {} self.messages.append(Message(role=role, text=text, meta=meta)) def send( diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index bf54bf5..9e38db1 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -1,6 +1,9 @@ -from .anthropic import Anthropic -from .groq import Groq -from .openai import OpenAI -from .xai import XAI +from typing import List, Type -providers = [Anthropic, Groq, OpenAI, XAI] +from simplemind.providers._base import BaseProvider +from simplemind.providers.anthropic import Anthropic +from simplemind.providers.groq import Groq +from simplemind.providers.openai import OpenAI +from simplemind.providers.xai import XAI + +providers: List[Type[BaseProvider]] = [Anthropic, Groq, OpenAI, XAI] diff --git a/simplemind/providers/_base.py b/simplemind/providers/_base.py new file mode 100644 index 0000000..ff4df23 --- /dev/null +++ b/simplemind/providers/_base.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod + +from instructor import Instructor + +from simplemind.models import Conversation, Message + + +class BaseProvider(ABC): + """The base provider class.""" + + __name__: str + DEFAULT_MODEL: str + + @property + @abstractmethod + def client(self): + """The instructor client for the provider.""" + raise NotImplementedError + + @property + @abstractmethod + def structured_client(self) -> Instructor: + """The structured client for the provider.""" + raise NotImplementedError + + @abstractmethod + def send_conversation(self, conversation: Conversation) -> Message: + """Send a conversation to the provider.""" + raise NotImplementedError + + @abstractmethod + def structured_response(self, prompt: str, response_model, **kwargs): + """Get a structured response.""" + raise NotImplementedError + + @abstractmethod + def generate_text(self, prompt: str, **kwargs) -> str: + """Generate text from a prompt.""" + raise NotImplementedError diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index e4cdb7a..4efa976 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -3,7 +3,8 @@ from typing import Union import anthropic import instructor -from simplemind.models import BaseProvider, Conversation, Message +from simplemind.models import Conversation, Message +from simplemind.providers._base import BaseProvider from simplemind.settings import settings PROVIDER_NAME = "anthropic" @@ -16,11 +17,13 @@ class Anthropic(BaseProvider): DEFAULT_MODEL = DEFAULT_MODEL def __init__(self, api_key: Union[str, None] = None): - self.api_key = api_key or settings.ANTHROPIC_API_KEY + self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @property def client(self): """The raw Anthropic client.""" + if not self.api_key: + raise ValueError("Anthropic API key is required") return anthropic.Anthropic(api_key=self.api_key) @property diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 465f428..4cd48f5 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -3,7 +3,8 @@ from typing import Union import groq import instructor -from simplemind.models import BaseProvider, Conversation, Message +from simplemind.models import Conversation, Message +from simplemind.providers._base import BaseProvider from simplemind.settings import settings PROVIDER_NAME = "groq" @@ -15,11 +16,13 @@ class Groq(BaseProvider): DEFAULT_MODEL = DEFAULT_MODEL def __init__(self, api_key: Union[str, None] = None): - self.api_key = api_key or settings.GROQ_API_KEY + self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @property def client(self): """The raw Groq client.""" + if not self.api_key: + raise ValueError("Groq API key is required") return groq.Groq(api_key=self.api_key) @property @@ -27,7 +30,7 @@ class Groq(BaseProvider): """A client patched with Instructor.""" return instructor.from_groq(self.client) - def send_conversation(self, conversation: "Conversation"): + def send_conversation(self, conversation: Conversation) -> Message: """Send a conversation to the Groq API.""" messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages @@ -43,8 +46,31 @@ class Groq(BaseProvider): # Create and return a properly formatted Message instance return Message( role="assistant", - text=assistant_message.content, + text=assistant_message.content or "", raw=response, llm_model=conversation.llm_model or DEFAULT_MODEL, llm_provider=PROVIDER_NAME, ) + + def structured_response(self, prompt: str, response_model): + # Ensure messages are provided in kwargs + messages = [ + {"role": "user", "content": prompt}, + ] + + response = self.structured_client.chat.completions.create( + messages=messages, + response_model=response_model, + ) + return response + + def generate_text(self, prompt: str, *, llm_model: str): + messages = [ + {"role": "user", "content": prompt}, + ] + + response = self.structured_client.chat.completions.create( + messages=messages, model=llm_model + ) + + return response.choices[0].message.content diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 4eb2795..6302747 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -3,7 +3,8 @@ from typing import Union import instructor import openai as oa -from simplemind.models import BaseProvider, Conversation, Message +from simplemind.models import Conversation, Message +from simplemind.providers._base import BaseProvider from simplemind.settings import settings PROVIDER_NAME = "openai" @@ -15,11 +16,13 @@ class OpenAI(BaseProvider): DEFAULT_MODEL = DEFAULT_MODEL def __init__(self, api_key: Union[str, None] = None): - self.api_key = api_key or settings.OPENAI_API_KEY + self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @property def client(self): """The raw OpenAI client.""" + if not self.api_key: + raise ValueError("OpenAI API key is required") return oa.OpenAI(api_key=self.api_key) @property @@ -49,7 +52,7 @@ class OpenAI(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt, response_model, *, llm_model): + def structured_response(self, prompt, response_model, *, llm_model: str): # Ensure messages are provided in kwargs messages = [ {"role": "user", "content": prompt}, diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 04b9709..6b22020 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -3,7 +3,8 @@ from typing import Union import instructor import openai as oa -from simplemind.models import BaseProvider, Conversation, Message +from simplemind.models import Conversation, Message +from simplemind.providers._base import BaseProvider from simplemind.settings import settings PROVIDER_NAME = "xai" @@ -17,15 +18,16 @@ class XAI(BaseProvider): DEFAULT_MODEL = DEFAULT_MODEL def __init__(self, api_key: Union[str, None] = None): - self.api_key = api_key or settings.XAI_API_KEY + self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @property def client(self): """The raw OpenAI client.""" - + if not self.api_key: + raise ValueError("XAI API key is required") return oa.OpenAI( - api_key=settings.XAI_API_KEY, - base_url="https://api.x.ai/v1", + api_key=self.api_key, + base_url=BASE_URL, ) @property diff --git a/simplemind/settings.py b/simplemind/settings.py index 0c3f319..54ed61f 100644 --- a/simplemind/settings.py +++ b/simplemind/settings.py @@ -1,12 +1,38 @@ -from pydantic import Field -from pydantic_settings import BaseSettings +from typing import Optional, Union + +from pydantic import Field, SecretStr, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): - ANTHROPIC_API_KEY: str = Field(..., env="ANTHROPIC_API_KEY") - GROQ_API_KEY: str = Field(..., env="GROQ_API_KEY") - OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY") - XAI_API_KEY: str = Field(..., env="XAI_API_KEY") + """The class that holds all the API keys for the application.""" + + ANTHROPIC_API_KEY: Optional[SecretStr] = Field( + None, description="API key for Anthropic" + ) + GROQ_API_KEY: Optional[SecretStr] = Field(None, description="API key for Groq") + OPENAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for OpenAI") + XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI") + + model_config = SettingsConfigDict( + env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore" + ) + + @field_validator("*", mode="before") + @classmethod + def empty_str_to_none(cls, v: str) -> Optional[str]: + """Convert empty strings to None for optional fields.""" + if v == "": + return None + return v + + def get_api_key(self, provider: str) -> Union[str, None]: + """ + Safely get API key for a specific provider. + Returns the key as a string or None if not set. + """ + key = getattr(self, f"{provider.upper()}_API_KEY", None) + return key.get_secret_value() if key else None settings = Settings() diff --git a/simplemind/utils.py b/simplemind/utils.py index 8304881..38e0c0d 100644 --- a/simplemind/utils.py +++ b/simplemind/utils.py @@ -1,10 +1,13 @@ -from .providers import providers +from typing import Union + +from simplemind.providers import providers -def find_provider(provider_name: str): +def find_provider(provider_name: Union[str, None]): """Find a provider by name.""" - for provider_class in providers: - if provider_class.__name__.lower() == provider_name.lower(): - # Instantiate the provider - return provider_class() + if provider_name: + for provider_class in providers: + if provider_class.__name__.lower() == provider_name.lower(): + # Instantiate the provider + return provider_class() raise ValueError(f"Provider {provider_name} not found")