Refactor import paths in models.py and utils.py

``
This commit is contained in:
2024-10-29 06:21:14 -04:00
parent 3e21d1e22b
commit 7ac0f76839
7 changed files with 28 additions and 25 deletions
+2 -1
View File
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
from simplemind.utils import find_provider
from .utils import find_provider
MESSAGE_ROLE = Literal["system", "user", "assistant"]
+3 -3
View File
@@ -2,13 +2,13 @@ from abc import ABC, abstractmethod
from instructor import Instructor
from simplemind.models import Conversation, Message
# from ..models import Conversation, Message
class BaseProvider(ABC):
"""The base provider class."""
__name__: str
NAME: str
DEFAULT_MODEL: str
@property
@@ -24,7 +24,7 @@ class BaseProvider(ABC):
raise NotImplementedError
@abstractmethod
def send_conversation(self, conversation: Conversation) -> Message:
def send_conversation(self, conversation: "Conversation") -> "Message":
"""Send a conversation to the provider."""
raise NotImplementedError
+6 -6
View File
@@ -3,9 +3,8 @@ from typing import Union
import anthropic
import instructor
from simplemind.models import Conversation, Message
from simplemind.providers._base import BaseProvider
from simplemind.settings import settings
from ._base import BaseProvider
from ..settings import settings
PROVIDER_NAME = "anthropic"
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
@@ -13,7 +12,7 @@ DEFAULT_MAX_TOKENS = 1000
class Anthropic(BaseProvider):
__name__ = PROVIDER_NAME
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
@@ -33,6 +32,8 @@ class Anthropic(BaseProvider):
def send_conversation(self, conversation: "Conversation"):
"""Send a conversation to the Anthropic API."""
from ..models import Message
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
@@ -43,8 +44,7 @@ class Anthropic(BaseProvider):
max_tokens=DEFAULT_MAX_TOKENS,
)
# Get the response content from the OpenAI response
# assistant_message = response.choices[0].message
# Get the response content from the Anthropic response
assistant_message = response.content[0].text
# Create and return a properly formatted Message instance
+6 -5
View File
@@ -3,16 +3,15 @@ from typing import Union
import groq
import instructor
from simplemind.models import Conversation, Message
from simplemind.providers._base import BaseProvider
from simplemind.settings import settings
from ._base import BaseProvider
from ..settings import settings
PROVIDER_NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
class Groq(BaseProvider):
__name__ = PROVIDER_NAME
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
@@ -30,8 +29,10 @@ class Groq(BaseProvider):
"""A client patched with Instructor."""
return instructor.from_groq(self.client)
def send_conversation(self, conversation: Conversation) -> Message:
def send_conversation(self, conversation: "Conversation") -> "Message":
"""Send a conversation to the Groq API."""
from ..models import Message
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
+5 -4
View File
@@ -3,16 +3,15 @@ from typing import Union
import instructor
import openai as oa
from simplemind.models import Conversation, Message
from simplemind.providers._base import BaseProvider
from simplemind.settings import settings
from ._base import BaseProvider
from ..settings import settings
PROVIDER_NAME = "openai"
DEFAULT_MODEL = "gpt-4o-mini"
class OpenAI(BaseProvider):
__name__ = PROVIDER_NAME
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
@@ -32,6 +31,8 @@ class OpenAI(BaseProvider):
def send_conversation(self, conversation: "Conversation"):
"""Send a conversation to the OpenAI API."""
from ..models import Message
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
+4 -4
View File
@@ -3,9 +3,8 @@ from typing import Union
import instructor
import openai as oa
from simplemind.models import Conversation, Message
from simplemind.providers._base import BaseProvider
from simplemind.settings import settings
from ._base import BaseProvider
from ..settings import settings
PROVIDER_NAME = "xai"
DEFAULT_MODEL = "grok-beta"
@@ -14,7 +13,7 @@ DEFAULT_MAX_TOKENS = 1000
class XAI(BaseProvider):
__name__ = PROVIDER_NAME
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
@@ -37,6 +36,7 @@ class XAI(BaseProvider):
def send_conversation(self, conversation: "Conversation"):
"""Send a conversation to the OpenAI API."""
from ..models import Message
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
+2 -2
View File
@@ -1,13 +1,13 @@
from typing import Union
from simplemind.providers import providers
from .providers import providers
def find_provider(provider_name: Union[str, None]):
"""Find a provider by name."""
if provider_name:
for provider_class in providers:
if provider_class.__name__.lower() == provider_name.lower():
if provider_class.NAME.lower() == provider_name.lower():
# Instantiate the provider
return provider_class()
raise ValueError(f"Provider {provider_name} not found")