mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
Refactor import paths in models.py and utils.py
``
This commit is contained in:
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from simplemind.utils import find_provider
|
||||
from .utils import find_provider
|
||||
|
||||
|
||||
MESSAGE_ROLE = Literal["system", "user", "assistant"]
|
||||
|
||||
|
||||
@@ -2,13 +2,13 @@ from abc import ABC, abstractmethod
|
||||
|
||||
from instructor import Instructor
|
||||
|
||||
from simplemind.models import Conversation, Message
|
||||
# from ..models import Conversation, Message
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
"""The base provider class."""
|
||||
|
||||
__name__: str
|
||||
NAME: str
|
||||
DEFAULT_MODEL: str
|
||||
|
||||
@property
|
||||
@@ -24,7 +24,7 @@ class BaseProvider(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_conversation(self, conversation: Conversation) -> Message:
|
||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
||||
"""Send a conversation to the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -3,9 +3,8 @@ from typing import Union
|
||||
import anthropic
|
||||
import instructor
|
||||
|
||||
from simplemind.models import Conversation, Message
|
||||
from simplemind.providers._base import BaseProvider
|
||||
from simplemind.settings import settings
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
|
||||
PROVIDER_NAME = "anthropic"
|
||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
@@ -13,7 +12,7 @@ DEFAULT_MAX_TOKENS = 1000
|
||||
|
||||
|
||||
class Anthropic(BaseProvider):
|
||||
__name__ = PROVIDER_NAME
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
@@ -33,6 +32,8 @@ class Anthropic(BaseProvider):
|
||||
|
||||
def send_conversation(self, conversation: "Conversation"):
|
||||
"""Send a conversation to the Anthropic API."""
|
||||
from ..models import Message
|
||||
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
@@ -43,8 +44,7 @@ class Anthropic(BaseProvider):
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
# assistant_message = response.choices[0].message
|
||||
# Get the response content from the Anthropic response
|
||||
assistant_message = response.content[0].text
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
|
||||
@@ -3,16 +3,15 @@ from typing import Union
|
||||
import groq
|
||||
import instructor
|
||||
|
||||
from simplemind.models import Conversation, Message
|
||||
from simplemind.providers._base import BaseProvider
|
||||
from simplemind.settings import settings
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
|
||||
PROVIDER_NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
|
||||
|
||||
class Groq(BaseProvider):
|
||||
__name__ = PROVIDER_NAME
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
@@ -30,8 +29,10 @@ class Groq(BaseProvider):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_groq(self.client)
|
||||
|
||||
def send_conversation(self, conversation: Conversation) -> Message:
|
||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
||||
"""Send a conversation to the Groq API."""
|
||||
from ..models import Message
|
||||
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
|
||||
@@ -3,16 +3,15 @@ from typing import Union
|
||||
import instructor
|
||||
import openai as oa
|
||||
|
||||
from simplemind.models import Conversation, Message
|
||||
from simplemind.providers._base import BaseProvider
|
||||
from simplemind.settings import settings
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
|
||||
PROVIDER_NAME = "openai"
|
||||
DEFAULT_MODEL = "gpt-4o-mini"
|
||||
|
||||
|
||||
class OpenAI(BaseProvider):
|
||||
__name__ = PROVIDER_NAME
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
@@ -32,6 +31,8 @@ class OpenAI(BaseProvider):
|
||||
|
||||
def send_conversation(self, conversation: "Conversation"):
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
|
||||
@@ -3,9 +3,8 @@ from typing import Union
|
||||
import instructor
|
||||
import openai as oa
|
||||
|
||||
from simplemind.models import Conversation, Message
|
||||
from simplemind.providers._base import BaseProvider
|
||||
from simplemind.settings import settings
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
|
||||
PROVIDER_NAME = "xai"
|
||||
DEFAULT_MODEL = "grok-beta"
|
||||
@@ -14,7 +13,7 @@ DEFAULT_MAX_TOKENS = 1000
|
||||
|
||||
|
||||
class XAI(BaseProvider):
|
||||
__name__ = PROVIDER_NAME
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
@@ -37,6 +36,7 @@ class XAI(BaseProvider):
|
||||
|
||||
def send_conversation(self, conversation: "Conversation"):
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
|
||||
+2
-2
@@ -1,13 +1,13 @@
|
||||
from typing import Union
|
||||
|
||||
from simplemind.providers import providers
|
||||
from .providers import providers
|
||||
|
||||
|
||||
def find_provider(provider_name: Union[str, None]):
|
||||
"""Find a provider by name."""
|
||||
if provider_name:
|
||||
for provider_class in providers:
|
||||
if provider_class.__name__.lower() == provider_name.lower():
|
||||
if provider_class.NAME.lower() == provider_name.lower():
|
||||
# Instantiate the provider
|
||||
return provider_class()
|
||||
raise ValueError(f"Provider {provider_name} not found")
|
||||
|
||||
Reference in New Issue
Block a user