diff --git a/simplemind/models.py b/simplemind/models.py index 5846489..8617904 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -4,7 +4,8 @@ from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, Field -from simplemind.utils import find_provider +from .utils import find_provider + MESSAGE_ROLE = Literal["system", "user", "assistant"] diff --git a/simplemind/providers/_base.py b/simplemind/providers/_base.py index ff4df23..43715d3 100644 --- a/simplemind/providers/_base.py +++ b/simplemind/providers/_base.py @@ -2,13 +2,13 @@ from abc import ABC, abstractmethod from instructor import Instructor -from simplemind.models import Conversation, Message +# from ..models import Conversation, Message class BaseProvider(ABC): """The base provider class.""" - __name__: str + NAME: str DEFAULT_MODEL: str @property @@ -24,7 +24,7 @@ class BaseProvider(ABC): raise NotImplementedError @abstractmethod - def send_conversation(self, conversation: Conversation) -> Message: + def send_conversation(self, conversation: "Conversation") -> "Message": """Send a conversation to the provider.""" raise NotImplementedError diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 4efa976..b164c83 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -3,9 +3,8 @@ from typing import Union import anthropic import instructor -from simplemind.models import Conversation, Message -from simplemind.providers._base import BaseProvider -from simplemind.settings import settings +from ._base import BaseProvider +from ..settings import settings PROVIDER_NAME = "anthropic" DEFAULT_MODEL = "claude-3-5-sonnet-20241022" @@ -13,7 +12,7 @@ DEFAULT_MAX_TOKENS = 1000 class Anthropic(BaseProvider): - __name__ = PROVIDER_NAME + NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL def __init__(self, api_key: Union[str, None] = None): @@ -33,6 +32,8 @@ class Anthropic(BaseProvider): def send_conversation(self, conversation: "Conversation"): """Send a conversation to the Anthropic API.""" + from ..models import Message + messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages ] @@ -43,8 +44,7 @@ class Anthropic(BaseProvider): max_tokens=DEFAULT_MAX_TOKENS, ) - # Get the response content from the OpenAI response - # assistant_message = response.choices[0].message + # Get the response content from the Anthropic response assistant_message = response.content[0].text # Create and return a properly formatted Message instance diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 4cd48f5..9fd4b71 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -3,16 +3,15 @@ from typing import Union import groq import instructor -from simplemind.models import Conversation, Message -from simplemind.providers._base import BaseProvider -from simplemind.settings import settings +from ._base import BaseProvider +from ..settings import settings PROVIDER_NAME = "groq" DEFAULT_MODEL = "llama3-8b-8192" class Groq(BaseProvider): - __name__ = PROVIDER_NAME + NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL def __init__(self, api_key: Union[str, None] = None): @@ -30,8 +29,10 @@ class Groq(BaseProvider): """A client patched with Instructor.""" return instructor.from_groq(self.client) - def send_conversation(self, conversation: Conversation) -> Message: + def send_conversation(self, conversation: "Conversation") -> "Message": """Send a conversation to the Groq API.""" + from ..models import Message + messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages ] diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 6302747..ee82105 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -3,16 +3,15 @@ from typing import Union import instructor import openai as oa -from simplemind.models import Conversation, Message -from simplemind.providers._base import BaseProvider -from simplemind.settings import settings +from ._base import BaseProvider +from ..settings import settings PROVIDER_NAME = "openai" DEFAULT_MODEL = "gpt-4o-mini" class OpenAI(BaseProvider): - __name__ = PROVIDER_NAME + NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL def __init__(self, api_key: Union[str, None] = None): @@ -32,6 +31,8 @@ class OpenAI(BaseProvider): def send_conversation(self, conversation: "Conversation"): """Send a conversation to the OpenAI API.""" + from ..models import Message + messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages ] diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 6b22020..1d6e58b 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -3,9 +3,8 @@ from typing import Union import instructor import openai as oa -from simplemind.models import Conversation, Message -from simplemind.providers._base import BaseProvider -from simplemind.settings import settings +from ._base import BaseProvider +from ..settings import settings PROVIDER_NAME = "xai" DEFAULT_MODEL = "grok-beta" @@ -14,7 +13,7 @@ DEFAULT_MAX_TOKENS = 1000 class XAI(BaseProvider): - __name__ = PROVIDER_NAME + NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL def __init__(self, api_key: Union[str, None] = None): @@ -37,6 +36,7 @@ class XAI(BaseProvider): def send_conversation(self, conversation: "Conversation"): """Send a conversation to the OpenAI API.""" + from ..models import Message messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages diff --git a/simplemind/utils.py b/simplemind/utils.py index 38e0c0d..7427312 100644 --- a/simplemind/utils.py +++ b/simplemind/utils.py @@ -1,13 +1,13 @@ from typing import Union -from simplemind.providers import providers +from .providers import providers def find_provider(provider_name: Union[str, None]): """Find a provider by name.""" if provider_name: for provider_class in providers: - if provider_class.__name__.lower() == provider_name.lower(): + if provider_class.NAME.lower() == provider_name.lower(): # Instantiate the provider return provider_class() raise ValueError(f"Provider {provider_name} not found")