mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
Merge pull request #11 from Siddhesh-Agarwal/main
This commit is contained in:
+1
-1
@@ -4,7 +4,7 @@ version = "0.1.0"
|
||||
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic"]
|
||||
dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "groq"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
|
||||
@@ -32,3 +32,13 @@ def generate_text(prompt, *, llm_model=None, llm_provider=None):
|
||||
provider = find_provider(llm_provider)
|
||||
|
||||
return provider.generate_text(prompt=prompt, llm_model=llm_model)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Conversation",
|
||||
"SimpleMind",
|
||||
"create_conversation",
|
||||
"find_provider",
|
||||
"generate_data",
|
||||
"generate_text",
|
||||
]
|
||||
|
||||
+10
-33
@@ -4,7 +4,9 @@ from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .utils import find_provider
|
||||
from simplemind.utils import find_provider
|
||||
|
||||
MESSAGE_ROLE = Literal["system", "user", "assistant"]
|
||||
|
||||
|
||||
class SMBaseModel(BaseModel):
|
||||
@@ -17,41 +19,12 @@ class SMBaseModel(BaseModel):
|
||||
return str(self)
|
||||
|
||||
|
||||
class BaseProvider(SMBaseModel):
|
||||
"""The base provider class."""
|
||||
|
||||
__name__ = "BaseProvider"
|
||||
DEFAULT_MODEL = "DEFAULT_MODEL"
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""The instructor client for the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def structured_client(self):
|
||||
"""The structured client for the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
def send_conversation(self, conversation: "Conversation"):
|
||||
"""Send a conversation to the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
def structured_response(self, prompt: str, response_model, **kwargs):
|
||||
"""Get a structured response."""
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_text(self, prompt: str, **kwargs):
|
||||
"""Generate text from a prompt."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BasePlugin(SMBaseModel):
|
||||
"""The base plugin class."""
|
||||
|
||||
|
||||
class Message(SMBaseModel):
|
||||
role: Literal["system", "user", "assistant"]
|
||||
role: MESSAGE_ROLE
|
||||
text: str
|
||||
meta: Dict[str, Any] = {}
|
||||
raw: Optional[Any] = None
|
||||
@@ -62,7 +35,7 @@ class Message(SMBaseModel):
|
||||
return f"<Message role={self.role} text={self.text!r}>"
|
||||
|
||||
@classmethod
|
||||
def from_raw_response(cls, *, text, raw):
|
||||
def from_raw_response(cls, *, text: str, raw):
|
||||
self = cls()
|
||||
self.text = text
|
||||
self.raw = raw
|
||||
@@ -79,8 +52,12 @@ class Conversation(SMBaseModel):
|
||||
def __str__(self):
|
||||
return f"<Conversation id={self.id!r}>"
|
||||
|
||||
def add_message(self, role: str, text: str, meta: Dict[str, Any] = {}):
|
||||
def add_message(
|
||||
self, role: MESSAGE_ROLE, text: str, meta: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""Add a new message to the conversation."""
|
||||
if meta is None:
|
||||
meta = {}
|
||||
self.messages.append(Message(role=role, text=text, meta=meta))
|
||||
|
||||
def send(
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from .anthropic import Anthropic
|
||||
from .groq import Groq
|
||||
from .openai import OpenAI
|
||||
from .xai import XAI
|
||||
from typing import List, Type
|
||||
|
||||
providers = [Anthropic, Groq, OpenAI, XAI]
|
||||
from simplemind.providers._base import BaseProvider
|
||||
from simplemind.providers.anthropic import Anthropic
|
||||
from simplemind.providers.groq import Groq
|
||||
from simplemind.providers.openai import OpenAI
|
||||
from simplemind.providers.xai import XAI
|
||||
|
||||
providers: List[Type[BaseProvider]] = [Anthropic, Groq, OpenAI, XAI]
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from instructor import Instructor
|
||||
|
||||
from simplemind.models import Conversation, Message
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
"""The base provider class."""
|
||||
|
||||
__name__: str
|
||||
DEFAULT_MODEL: str
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def client(self):
|
||||
"""The instructor client for the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def structured_client(self) -> Instructor:
|
||||
"""The structured client for the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_conversation(self, conversation: Conversation) -> Message:
|
||||
"""Send a conversation to the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def structured_response(self, prompt: str, response_model, **kwargs):
|
||||
"""Get a structured response."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def generate_text(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate text from a prompt."""
|
||||
raise NotImplementedError
|
||||
@@ -3,7 +3,8 @@ from typing import Union
|
||||
import anthropic
|
||||
import instructor
|
||||
|
||||
from simplemind.models import BaseProvider, Conversation, Message
|
||||
from simplemind.models import Conversation, Message
|
||||
from simplemind.providers._base import BaseProvider
|
||||
from simplemind.settings import settings
|
||||
|
||||
PROVIDER_NAME = "anthropic"
|
||||
@@ -16,11 +17,13 @@ class Anthropic(BaseProvider):
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
self.api_key = api_key or settings.ANTHROPIC_API_KEY
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""The raw Anthropic client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Anthropic API key is required")
|
||||
return anthropic.Anthropic(api_key=self.api_key)
|
||||
|
||||
@property
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Union
|
||||
import groq
|
||||
import instructor
|
||||
|
||||
from simplemind.models import BaseProvider, Conversation, Message
|
||||
from simplemind.models import Conversation, Message
|
||||
from simplemind.providers._base import BaseProvider
|
||||
from simplemind.settings import settings
|
||||
|
||||
PROVIDER_NAME = "groq"
|
||||
@@ -15,11 +16,13 @@ class Groq(BaseProvider):
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
self.api_key = api_key or settings.GROQ_API_KEY
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""The raw Groq client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Groq API key is required")
|
||||
return groq.Groq(api_key=self.api_key)
|
||||
|
||||
@property
|
||||
@@ -27,7 +30,7 @@ class Groq(BaseProvider):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_groq(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation"):
|
||||
def send_conversation(self, conversation: Conversation) -> Message:
|
||||
"""Send a conversation to the Groq API."""
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
@@ -43,8 +46,31 @@ class Groq(BaseProvider):
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message.content,
|
||||
text=assistant_message.content or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt: str, response_model):
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages,
|
||||
response_model=response_model,
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt: str, *, llm_model: str):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages, model=llm_model
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Union
|
||||
import instructor
|
||||
import openai as oa
|
||||
|
||||
from simplemind.models import BaseProvider, Conversation, Message
|
||||
from simplemind.models import Conversation, Message
|
||||
from simplemind.providers._base import BaseProvider
|
||||
from simplemind.settings import settings
|
||||
|
||||
PROVIDER_NAME = "openai"
|
||||
@@ -15,11 +16,13 @@ class OpenAI(BaseProvider):
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
self.api_key = api_key or settings.OPENAI_API_KEY
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""The raw OpenAI client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key is required")
|
||||
return oa.OpenAI(api_key=self.api_key)
|
||||
|
||||
@property
|
||||
@@ -49,7 +52,7 @@ class OpenAI(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt, response_model, *, llm_model):
|
||||
def structured_response(self, prompt, response_model, *, llm_model: str):
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Union
|
||||
import instructor
|
||||
import openai as oa
|
||||
|
||||
from simplemind.models import BaseProvider, Conversation, Message
|
||||
from simplemind.models import Conversation, Message
|
||||
from simplemind.providers._base import BaseProvider
|
||||
from simplemind.settings import settings
|
||||
|
||||
PROVIDER_NAME = "xai"
|
||||
@@ -17,15 +18,16 @@ class XAI(BaseProvider):
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
self.api_key = api_key or settings.XAI_API_KEY
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""The raw OpenAI client."""
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("XAI API key is required")
|
||||
return oa.OpenAI(
|
||||
api_key=settings.XAI_API_KEY,
|
||||
base_url="https://api.x.ai/v1",
|
||||
api_key=self.api_key,
|
||||
base_url=BASE_URL,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
+32
-6
@@ -1,12 +1,38 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import Field, SecretStr, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
ANTHROPIC_API_KEY: str = Field(..., env="ANTHROPIC_API_KEY")
|
||||
GROQ_API_KEY: str = Field(..., env="GROQ_API_KEY")
|
||||
OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY")
|
||||
XAI_API_KEY: str = Field(..., env="XAI_API_KEY")
|
||||
"""The class that holds all the API keys for the application."""
|
||||
|
||||
ANTHROPIC_API_KEY: Optional[SecretStr] = Field(
|
||||
None, description="API key for Anthropic"
|
||||
)
|
||||
GROQ_API_KEY: Optional[SecretStr] = Field(None, description="API key for Groq")
|
||||
OPENAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for OpenAI")
|
||||
XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI")
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
|
||||
)
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def empty_str_to_none(cls, v: str) -> Optional[str]:
|
||||
"""Convert empty strings to None for optional fields."""
|
||||
if v == "":
|
||||
return None
|
||||
return v
|
||||
|
||||
def get_api_key(self, provider: str) -> Union[str, None]:
|
||||
"""
|
||||
Safely get API key for a specific provider.
|
||||
Returns the key as a string or None if not set.
|
||||
"""
|
||||
key = getattr(self, f"{provider.upper()}_API_KEY", None)
|
||||
return key.get_secret_value() if key else None
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
+9
-6
@@ -1,10 +1,13 @@
|
||||
from .providers import providers
|
||||
from typing import Union
|
||||
|
||||
from simplemind.providers import providers
|
||||
|
||||
|
||||
def find_provider(provider_name: str):
|
||||
def find_provider(provider_name: Union[str, None]):
|
||||
"""Find a provider by name."""
|
||||
for provider_class in providers:
|
||||
if provider_class.__name__.lower() == provider_name.lower():
|
||||
# Instantiate the provider
|
||||
return provider_class()
|
||||
if provider_name:
|
||||
for provider_class in providers:
|
||||
if provider_class.__name__.lower() == provider_name.lower():
|
||||
# Instantiate the provider
|
||||
return provider_class()
|
||||
raise ValueError(f"Provider {provider_name} not found")
|
||||
|
||||
Reference in New Issue
Block a user