From ca086bb5743b2e42bb2a8b4c0812646ff7e06866 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Tue, 29 Oct 2024 09:28:30 +0530 Subject: [PATCH 1/9] added expection for missing key --- simplemind/providers/anthropic.py | 2 ++ simplemind/providers/groq.py | 4 +++- simplemind/providers/openai.py | 2 ++ simplemind/providers/xai.py | 7 ++++--- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index e4cdb7a..328577f 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -21,6 +21,8 @@ class Anthropic(BaseProvider): @property def client(self): """The raw Anthropic client.""" + if not self.api_key: + raise ValueError("Anthropic API key is required") return anthropic.Anthropic(api_key=self.api_key) @property diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 465f428..ec5e6ac 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -20,6 +20,8 @@ class Groq(BaseProvider): @property def client(self): """The raw Groq client.""" + if not self.api_key: + raise ValueError("Groq API key is required") return groq.Groq(api_key=self.api_key) @property @@ -43,7 +45,7 @@ class Groq(BaseProvider): # Create and return a properly formatted Message instance return Message( role="assistant", - text=assistant_message.content, + text=assistant_message.content or "", raw=response, llm_model=conversation.llm_model or DEFAULT_MODEL, llm_provider=PROVIDER_NAME, diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 4eb2795..2006f0c 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -20,6 +20,8 @@ class OpenAI(BaseProvider): @property def client(self): """The raw OpenAI client.""" + if not self.api_key: + raise ValueError("OpenAI API key is required") return oa.OpenAI(api_key=self.api_key) @property diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 04b9709..df4cc92 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -22,10 +22,11 @@ class XAI(BaseProvider): @property def client(self): """The raw OpenAI client.""" - + if not self.api_key: + raise ValueError("XAI API key is required") return oa.OpenAI( - api_key=settings.XAI_API_KEY, - base_url="https://api.x.ai/v1", + api_key=self.api_key, + base_url=BASE_URL, ) @property From 787212bae0ddbd8ab05938583a92fde40b6b8c73 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Tue, 29 Oct 2024 11:38:36 +0530 Subject: [PATCH 2/9] improved BaseProvider --- simplemind/models.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/simplemind/models.py b/simplemind/models.py index 397e773..9ff1c44 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod import uuid from datetime import datetime from typing import Any, Dict, List, Literal, Optional @@ -17,31 +18,36 @@ class SMBaseModel(BaseModel): return str(self) -class BaseProvider(SMBaseModel): +class BaseProvider(SMBaseModel, ABC): """The base provider class.""" __name__ = "BaseProvider" DEFAULT_MODEL = "DEFAULT_MODEL" @property + @abstractmethod def client(self): """The instructor client for the provider.""" raise NotImplementedError @property + @abstractmethod def structured_client(self): """The structured client for the provider.""" raise NotImplementedError - def send_conversation(self, conversation: "Conversation"): + @abstractmethod + def send_conversation(self, conversation: "Conversation") -> "Message": """Send a conversation to the provider.""" raise NotImplementedError + @abstractmethod def structured_response(self, prompt: str, response_model, **kwargs): """Get a structured response.""" raise NotImplementedError - def generate_text(self, prompt: str, **kwargs): + @abstractmethod + def generate_text(self, prompt: str, **kwargs) -> str: """Generate text from a prompt.""" raise NotImplementedError From a2c81b9b9f6a09eed8068183e1ce8256937552f9 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Tue, 29 Oct 2024 11:39:01 +0530 Subject: [PATCH 3/9] added more methods for groq --- simplemind/providers/groq.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index ec5e6ac..a880fd5 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -29,7 +29,7 @@ class Groq(BaseProvider): """A client patched with Instructor.""" return instructor.from_groq(self.client) - def send_conversation(self, conversation: "Conversation"): + def send_conversation(self, conversation: Conversation) -> Message: """Send a conversation to the Groq API.""" messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages @@ -50,3 +50,26 @@ class Groq(BaseProvider): llm_model=conversation.llm_model or DEFAULT_MODEL, llm_provider=PROVIDER_NAME, ) + + def structured_response(self, prompt: str, response_model): + # Ensure messages are provided in kwargs + messages = [ + {"role": "user", "content": prompt}, + ] + + response = self.structured_client.chat.completions.create( + messages=messages, + response_model=response_model, + ) + return response + + def generate_text(self, prompt: str, *, llm_model: str): + messages = [ + {"role": "user", "content": prompt}, + ] + + response = self.structured_client.chat.completions.create( + messages=messages, model=llm_model + ) + + return response.choices[0].message.content From d1c116171232453c9f73944b82237fccf9b30e96 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Tue, 29 Oct 2024 12:34:53 +0530 Subject: [PATCH 4/9] improved settings --- simplemind/settings.py | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/simplemind/settings.py b/simplemind/settings.py index 0c3f319..ae81a64 100644 --- a/simplemind/settings.py +++ b/simplemind/settings.py @@ -1,12 +1,37 @@ -from pydantic import Field -from pydantic_settings import BaseSettings +from typing import Optional, Union +from pydantic import Field, SecretStr, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): - ANTHROPIC_API_KEY: str = Field(..., env="ANTHROPIC_API_KEY") - GROQ_API_KEY: str = Field(..., env="GROQ_API_KEY") - OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY") - XAI_API_KEY: str = Field(..., env="XAI_API_KEY") + """The class that holds all the API keys for the application.""" + + ANTHROPIC_API_KEY: Optional[SecretStr] = Field( + None, description="API key for Anthropic" + ) + GROQ_API_KEY: Optional[SecretStr] = Field(None, description="API key for Groq") + OPENAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for OpenAI") + XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI") + + model_config = SettingsConfigDict( + env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore" + ) + + @field_validator("*", mode="before") + @classmethod + def empty_str_to_none(cls, v: str) -> Optional[str]: + """Convert empty strings to None for optional fields.""" + if v == "": + return None + return v + + def get_api_key(self, provider: str) -> Union[str, None]: + """ + Safely get API key for a specific provider. + Returns the key as a string or None if not set. + """ + key = getattr(self, f"{provider.upper()}_API_KEY", None) + return key.get_secret_value() if key else None settings = Settings() From 906fadc03dfc046a66b23de3747c33dc8c594e1b Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Tue, 29 Oct 2024 12:36:17 +0530 Subject: [PATCH 5/9] moved BaseProvider to prevent cyclic imports --- simplemind/models.py | 44 +++++-------------------------- simplemind/providers/__init__.py | 4 ++- simplemind/providers/_base.py | 39 +++++++++++++++++++++++++++ simplemind/providers/anthropic.py | 3 ++- simplemind/providers/groq.py | 3 ++- simplemind/providers/openai.py | 5 ++-- simplemind/providers/xai.py | 3 ++- 7 files changed, 57 insertions(+), 44 deletions(-) create mode 100644 simplemind/providers/_base.py diff --git a/simplemind/models.py b/simplemind/models.py index 9ff1c44..f53c35d 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod import uuid from datetime import datetime from typing import Any, Dict, List, Literal, Optional @@ -8,6 +7,9 @@ from pydantic import BaseModel, Field from .utils import find_provider +MESSAGE_ROLE = Literal["system", "user", "assistant"] + + class SMBaseModel(BaseModel): date_created: datetime = Field(default_factory=datetime.now) @@ -18,46 +20,12 @@ class SMBaseModel(BaseModel): return str(self) -class BaseProvider(SMBaseModel, ABC): - """The base provider class.""" - - __name__ = "BaseProvider" - DEFAULT_MODEL = "DEFAULT_MODEL" - - @property - @abstractmethod - def client(self): - """The instructor client for the provider.""" - raise NotImplementedError - - @property - @abstractmethod - def structured_client(self): - """The structured client for the provider.""" - raise NotImplementedError - - @abstractmethod - def send_conversation(self, conversation: "Conversation") -> "Message": - """Send a conversation to the provider.""" - raise NotImplementedError - - @abstractmethod - def structured_response(self, prompt: str, response_model, **kwargs): - """Get a structured response.""" - raise NotImplementedError - - @abstractmethod - def generate_text(self, prompt: str, **kwargs) -> str: - """Generate text from a prompt.""" - raise NotImplementedError - - class BasePlugin(SMBaseModel): """The base plugin class.""" class Message(SMBaseModel): - role: Literal["system", "user", "assistant"] + role: MESSAGE_ROLE text: str meta: Dict[str, Any] = {} raw: Optional[Any] = None @@ -68,7 +36,7 @@ class Message(SMBaseModel): return f"" @classmethod - def from_raw_response(cls, *, text, raw): + def from_raw_response(cls, *, text: str, raw): self = cls() self.text = text self.raw = raw @@ -85,7 +53,7 @@ class Conversation(SMBaseModel): def __str__(self): return f"" - def add_message(self, role: str, text: str, meta: Dict[str, Any] = {}): + def add_message(self, role: MESSAGE_ROLE, text: str, meta: Dict[str, Any] = {}): """Add a new message to the conversation.""" self.messages.append(Message(role=role, text=text, meta=meta)) diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index bf54bf5..5941019 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -1,6 +1,8 @@ +from typing import List, Type from .anthropic import Anthropic from .groq import Groq from .openai import OpenAI from .xai import XAI +from ._base import BaseProvider -providers = [Anthropic, Groq, OpenAI, XAI] +providers: List[Type[BaseProvider]] = [Anthropic, Groq, OpenAI, XAI] diff --git a/simplemind/providers/_base.py b/simplemind/providers/_base.py new file mode 100644 index 0000000..ff4df23 --- /dev/null +++ b/simplemind/providers/_base.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod + +from instructor import Instructor + +from simplemind.models import Conversation, Message + + +class BaseProvider(ABC): + """The base provider class.""" + + __name__: str + DEFAULT_MODEL: str + + @property + @abstractmethod + def client(self): + """The instructor client for the provider.""" + raise NotImplementedError + + @property + @abstractmethod + def structured_client(self) -> Instructor: + """The structured client for the provider.""" + raise NotImplementedError + + @abstractmethod + def send_conversation(self, conversation: Conversation) -> Message: + """Send a conversation to the provider.""" + raise NotImplementedError + + @abstractmethod + def structured_response(self, prompt: str, response_model, **kwargs): + """Get a structured response.""" + raise NotImplementedError + + @abstractmethod + def generate_text(self, prompt: str, **kwargs) -> str: + """Generate text from a prompt.""" + raise NotImplementedError diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 328577f..38adec8 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -3,7 +3,8 @@ from typing import Union import anthropic import instructor -from simplemind.models import BaseProvider, Conversation, Message +from simplemind.models import Conversation, Message +from simplemind.providers._base import BaseProvider from simplemind.settings import settings PROVIDER_NAME = "anthropic" diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index a880fd5..98e3009 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -3,7 +3,8 @@ from typing import Union import groq import instructor -from simplemind.models import BaseProvider, Conversation, Message +from simplemind.models import Conversation, Message +from simplemind.providers._base import BaseProvider from simplemind.settings import settings PROVIDER_NAME = "groq" diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 2006f0c..c090286 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -3,7 +3,8 @@ from typing import Union import instructor import openai as oa -from simplemind.models import BaseProvider, Conversation, Message +from simplemind.models import Conversation, Message +from simplemind.providers._base import BaseProvider from simplemind.settings import settings PROVIDER_NAME = "openai" @@ -51,7 +52,7 @@ class OpenAI(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt, response_model, *, llm_model): + def structured_response(self, prompt, response_model, *, llm_model: str): # Ensure messages are provided in kwargs messages = [ {"role": "user", "content": prompt}, diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index df4cc92..84d349d 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -3,7 +3,8 @@ from typing import Union import instructor import openai as oa -from simplemind.models import BaseProvider, Conversation, Message +from simplemind.models import Conversation, Message +from simplemind.providers._base import BaseProvider from simplemind.settings import settings PROVIDER_NAME = "xai" From f2996ac21d0eaf0b45b0aa0e325dbef8593bfac1 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Tue, 29 Oct 2024 12:37:03 +0530 Subject: [PATCH 6/9] handled None case in find_providers --- simplemind/utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/simplemind/utils.py b/simplemind/utils.py index 8304881..c1bd2b7 100644 --- a/simplemind/utils.py +++ b/simplemind/utils.py @@ -1,10 +1,12 @@ +from typing import Union from .providers import providers -def find_provider(provider_name: str): +def find_provider(provider_name: Union[str, None]): """Find a provider by name.""" - for provider_class in providers: - if provider_class.__name__.lower() == provider_name.lower(): - # Instantiate the provider - return provider_class() + if provider_name: + for provider_class in providers: + if provider_class.__name__.lower() == provider_name.lower(): + # Instantiate the provider + return provider_class() raise ValueError(f"Provider {provider_name} not found") From 8b76f5b54e828948c03e6c1801f3c7502cd8c9e2 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Tue, 29 Oct 2024 13:52:03 +0530 Subject: [PATCH 7/9] changed api key access method --- simplemind/providers/anthropic.py | 2 +- simplemind/providers/groq.py | 2 +- simplemind/providers/openai.py | 2 +- simplemind/providers/xai.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 38adec8..4efa976 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -17,7 +17,7 @@ class Anthropic(BaseProvider): DEFAULT_MODEL = DEFAULT_MODEL def __init__(self, api_key: Union[str, None] = None): - self.api_key = api_key or settings.ANTHROPIC_API_KEY + self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @property def client(self): diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 98e3009..4cd48f5 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -16,7 +16,7 @@ class Groq(BaseProvider): DEFAULT_MODEL = DEFAULT_MODEL def __init__(self, api_key: Union[str, None] = None): - self.api_key = api_key or settings.GROQ_API_KEY + self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @property def client(self): diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index c090286..6302747 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -16,7 +16,7 @@ class OpenAI(BaseProvider): DEFAULT_MODEL = DEFAULT_MODEL def __init__(self, api_key: Union[str, None] = None): - self.api_key = api_key or settings.OPENAI_API_KEY + self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @property def client(self): diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 84d349d..6b22020 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -18,7 +18,7 @@ class XAI(BaseProvider): DEFAULT_MODEL = DEFAULT_MODEL def __init__(self, api_key: Union[str, None] = None): - self.api_key = api_key or settings.XAI_API_KEY + self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) @property def client(self): From 9c1ad7ed4556007baad929fdc4c4bd1e80d6eacb Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Tue, 29 Oct 2024 13:52:25 +0530 Subject: [PATCH 8/9] fixed import paths --- simplemind/__init__.py | 10 ++++++++++ simplemind/models.py | 9 ++++++--- simplemind/providers/__init__.py | 11 ++++++----- simplemind/settings.py | 1 + simplemind/utils.py | 3 ++- 5 files changed, 25 insertions(+), 9 deletions(-) diff --git a/simplemind/__init__.py b/simplemind/__init__.py index b98e75c..2c7ca0e 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -32,3 +32,13 @@ def generate_text(prompt, *, llm_model=None, llm_provider=None): provider = find_provider(llm_provider) return provider.generate_text(prompt=prompt, llm_model=llm_model) + + +__all__ = [ + "Conversation", + "SimpleMind", + "create_conversation", + "find_provider", + "generate_data", + "generate_text", +] diff --git a/simplemind/models.py b/simplemind/models.py index f53c35d..5846489 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -4,8 +4,7 @@ from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, Field -from .utils import find_provider - +from simplemind.utils import find_provider MESSAGE_ROLE = Literal["system", "user", "assistant"] @@ -53,8 +52,12 @@ class Conversation(SMBaseModel): def __str__(self): return f"" - def add_message(self, role: MESSAGE_ROLE, text: str, meta: Dict[str, Any] = {}): + def add_message( + self, role: MESSAGE_ROLE, text: str, meta: Optional[Dict[str, Any]] = None + ): """Add a new message to the conversation.""" + if meta is None: + meta = {} self.messages.append(Message(role=role, text=text, meta=meta)) def send( diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index 5941019..9e38db1 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -1,8 +1,9 @@ from typing import List, Type -from .anthropic import Anthropic -from .groq import Groq -from .openai import OpenAI -from .xai import XAI -from ._base import BaseProvider + +from simplemind.providers._base import BaseProvider +from simplemind.providers.anthropic import Anthropic +from simplemind.providers.groq import Groq +from simplemind.providers.openai import OpenAI +from simplemind.providers.xai import XAI providers: List[Type[BaseProvider]] = [Anthropic, Groq, OpenAI, XAI] diff --git a/simplemind/settings.py b/simplemind/settings.py index ae81a64..54ed61f 100644 --- a/simplemind/settings.py +++ b/simplemind/settings.py @@ -1,4 +1,5 @@ from typing import Optional, Union + from pydantic import Field, SecretStr, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict diff --git a/simplemind/utils.py b/simplemind/utils.py index c1bd2b7..38e0c0d 100644 --- a/simplemind/utils.py +++ b/simplemind/utils.py @@ -1,5 +1,6 @@ from typing import Union -from .providers import providers + +from simplemind.providers import providers def find_provider(provider_name: Union[str, None]): From c420b06a66ab907e1e08c5b3a8a4e45a71736d4a Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Tue, 29 Oct 2024 14:01:15 +0530 Subject: [PATCH 9/9] updated deps --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1c4cebc..57b6a1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ version = "0.1.0" description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases." readme = "README.md" requires-python = ">=3.11" -dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic"] +dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "groq"] [build-system] requires = ["hatchling"]