diff --git a/.envrc.template b/.envrc.template index 8438939..cb0f439 100644 --- a/.envrc.template +++ b/.envrc.template @@ -1,3 +1,4 @@ export OPENAI_API_KEY="" export ANTHROPIC_API_KEY="" export XAI_API_KEY="" +export GROQ_API_KEY="" \ No newline at end of file diff --git a/simplemind/models.py b/simplemind/models.py index 4a85c99..397e773 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -1,6 +1,6 @@ import uuid -from typing import List, Dict, Any, Optional from datetime import datetime +from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, Field @@ -17,8 +17,41 @@ class SMBaseModel(BaseModel): return str(self) +class BaseProvider(SMBaseModel): + """The base provider class.""" + + __name__ = "BaseProvider" + DEFAULT_MODEL = "DEFAULT_MODEL" + + @property + def client(self): + """The instructor client for the provider.""" + raise NotImplementedError + + @property + def structured_client(self): + """The structured client for the provider.""" + raise NotImplementedError + + def send_conversation(self, conversation: "Conversation"): + """Send a conversation to the provider.""" + raise NotImplementedError + + def structured_response(self, prompt: str, response_model, **kwargs): + """Get a structured response.""" + raise NotImplementedError + + def generate_text(self, prompt: str, **kwargs): + """Generate text from a prompt.""" + raise NotImplementedError + + +class BasePlugin(SMBaseModel): + """The base plugin class.""" + + class Message(SMBaseModel): - role: str + role: Literal["system", "user", "assistant"] text: str meta: Dict[str, Any] = {} raw: Optional[Any] = None diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index 87787ea..bf54bf5 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -1,5 +1,6 @@ -from .openai import OpenAI from .anthropic import Anthropic +from .groq import Groq +from .openai import OpenAI from .xai import XAI -providers = [OpenAI, Anthropic, XAI] +providers = [Anthropic, Groq, OpenAI, XAI] diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 12ec1c4..e4cdb7a 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -1,19 +1,21 @@ +from typing import Union + import anthropic import instructor -# from ..models import Conversation, Message -from ..settings import settings +from simplemind.models import BaseProvider, Conversation, Message +from simplemind.settings import settings PROVIDER_NAME = "anthropic" DEFAULT_MODEL = "claude-3-5-sonnet-20241022" DEFAULT_MAX_TOKENS = 1000 -class Anthropic: +class Anthropic(BaseProvider): __name__ = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL - def __init__(self, api_key: str = None): + def __init__(self, api_key: Union[str, None] = None): self.api_key = api_key or settings.ANTHROPIC_API_KEY @property @@ -24,12 +26,10 @@ class Anthropic: @property def structured_client(self): """A client patched with Instructor.""" - return instructor.from_anthropic(anthropic.Anthropic(api_key=self.api_key)) + return instructor.from_anthropic(self.client) def send_conversation(self, conversation: "Conversation"): """Send a conversation to the Anthropic API.""" - from ..models import Message - messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages ] diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py new file mode 100644 index 0000000..465f428 --- /dev/null +++ b/simplemind/providers/groq.py @@ -0,0 +1,50 @@ +from typing import Union + +import groq +import instructor + +from simplemind.models import BaseProvider, Conversation, Message +from simplemind.settings import settings + +PROVIDER_NAME = "groq" +DEFAULT_MODEL = "llama3-8b-8192" + + +class Groq(BaseProvider): + __name__ = PROVIDER_NAME + DEFAULT_MODEL = DEFAULT_MODEL + + def __init__(self, api_key: Union[str, None] = None): + self.api_key = api_key or settings.GROQ_API_KEY + + @property + def client(self): + """The raw Groq client.""" + return groq.Groq(api_key=self.api_key) + + @property + def structured_client(self): + """A client patched with Instructor.""" + return instructor.from_groq(self.client) + + def send_conversation(self, conversation: "Conversation"): + """Send a conversation to the Groq API.""" + messages = [ + {"role": msg.role, "content": msg.text} for msg in conversation.messages + ] + + response = self.client.chat.completions.create( + model=conversation.llm_model or DEFAULT_MODEL, messages=messages + ) + + # Get the response content from the Groq response + assistant_message = response.choices[0].message + + # Create and return a properly formatted Message instance + return Message( + role="assistant", + text=assistant_message.content, + raw=response, + llm_model=conversation.llm_model or DEFAULT_MODEL, + llm_provider=PROVIDER_NAME, + ) diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index f385d58..4eb2795 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,17 +1,20 @@ -import openai as oa -import instructor +from typing import Union -from ..settings import settings +import instructor +import openai as oa + +from simplemind.models import BaseProvider, Conversation, Message +from simplemind.settings import settings PROVIDER_NAME = "openai" DEFAULT_MODEL = "gpt-4o-mini" -class OpenAI: +class OpenAI(BaseProvider): __name__ = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL - def __init__(self, api_key: str = None): + def __init__(self, api_key: Union[str, None] = None): self.api_key = api_key or settings.OPENAI_API_KEY @property @@ -21,13 +24,11 @@ class OpenAI: @property def structured_client(self): - """A client patched with Instructor.""" - return instructor.patch(oa.OpenAI(api_key=self.api_key)) + """A OpenAI client with Instructor.""" + return instructor.from_openai(self.client) def send_conversation(self, conversation: "Conversation"): """Send a conversation to the OpenAI API.""" - from ..models import Message - messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages ] @@ -42,7 +43,7 @@ class OpenAI: # Create and return a properly formatted Message instance return Message( role="assistant", - text=assistant_message.content, + text=assistant_message.content or "", raw=response, llm_model=conversation.llm_model or DEFAULT_MODEL, llm_provider=PROVIDER_NAME, diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 4451f00..04b9709 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -1,8 +1,10 @@ -import openai as oa -import instructor +from typing import Union -# from ..models import Conversation, Message -from ..settings import settings +import instructor +import openai as oa + +from simplemind.models import BaseProvider, Conversation, Message +from simplemind.settings import settings PROVIDER_NAME = "xai" DEFAULT_MODEL = "grok-beta" @@ -10,11 +12,11 @@ BASE_URL = "https://api.x.ai/v1" DEFAULT_MAX_TOKENS = 1000 -class XAI: +class XAI(BaseProvider): __name__ = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL - def __init__(self, api_key: str = None): + def __init__(self, api_key: Union[str, None] = None): self.api_key = api_key or settings.XAI_API_KEY @property @@ -29,13 +31,10 @@ class XAI: @property def structured_client(self): """A client patched with Instructor.""" - return instructor.patch( - oa.OpenAI(api_key=self.api_key, base_url="https://api.x.ai/v1") - ) + return instructor.from_openai(self.client) def send_conversation(self, conversation: "Conversation"): """Send a conversation to the OpenAI API.""" - from ..models import Message messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages @@ -57,7 +56,7 @@ class XAI: llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt, response_model, *, llm_model): + def structured_response(self, prompt: str, response_model, *, llm_model): raise NotImplementedError("XAI does not support structured responses") def generate_text(self, prompt, *, llm_model): diff --git a/simplemind/settings.py b/simplemind/settings.py index 7828508..0c3f319 100644 --- a/simplemind/settings.py +++ b/simplemind/settings.py @@ -3,8 +3,9 @@ from pydantic_settings import BaseSettings class Settings(BaseSettings): - OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY") ANTHROPIC_API_KEY: str = Field(..., env="ANTHROPIC_API_KEY") + GROQ_API_KEY: str = Field(..., env="GROQ_API_KEY") + OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY") XAI_API_KEY: str = Field(..., env="XAI_API_KEY")