diff --git a/simplemind/__init__.py b/simplemind/__init__.py index fb4d02d..986b3e9 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -1 +1,18 @@ -from .core import SimpleMind +from .models import Conversation + + +class SimpleMind: + def create_conversation(self, *, llm_model=None, llm_provider=None): + return Conversation() + + def structured_response( + self, *, llm_model=None, llm_provider=None, response_model=None + ): + pass + + +def create_conversation(): + return SimpleMind().create_conversation() + + +globals().update(locals()) diff --git a/simplemind/agents/__init__.py b/simplemind/agents/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/simplemind/agents/base.py b/simplemind/agents/base.py deleted file mode 100644 index cdc32a2..0000000 --- a/simplemind/agents/base.py +++ /dev/null @@ -1,7 +0,0 @@ -from abc import ABC, abstractmethod - - -class BaseAgent(ABC): - @abstractmethod - def decide(self, context, *args, **kwargs): - pass diff --git a/simplemind/chains/__init__.py b/simplemind/chains/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/simplemind/chains/base.py b/simplemind/chains/base.py deleted file mode 100644 index 41013c7..0000000 --- a/simplemind/chains/base.py +++ /dev/null @@ -1,24 +0,0 @@ -from abc import ABC, abstractmethod - - -class BaseChain(ABC): - """Abstract base class for implementing chain operations. - - A Chain represents a processing step that can be executed on input data - and should be implemented by concrete classes to define specific behaviors. - """ - - @abstractmethod - def run(self, input_data: str) -> str: - """Execute the chain's operation on the input data. - - Args: - input_data: The input string to be processed by the chain. - - Returns: - The processed output string. - - Raises: - ValueError: If the input data is invalid or cannot be processed. - """ - pass diff --git a/simplemind/chains/reverse_text.py b/simplemind/chains/reverse_text.py deleted file mode 100644 index cccdcbc..0000000 --- a/simplemind/chains/reverse_text.py +++ /dev/null @@ -1,25 +0,0 @@ -from .base import BaseChain - - -class ReverseTextChain(BaseChain): - """Chain that reverses input text. - - This chain takes a text input and returns it reversed. For example, - "hello" becomes "olleh". - """ - - def run(self, input_data: str) -> str: - """Reverse the input text. - - Args: - input_data: The text to reverse. - - Returns: - The reversed text. - - Raises: - TypeError: If input_data is not a string. - """ - if not isinstance(input_data, str): - raise TypeError("Input must be a string") - return input_data[::-1] diff --git a/simplemind/concepts/__init__.py b/simplemind/concepts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/simplemind/concepts/context.py b/simplemind/concepts/context.py deleted file mode 100644 index 4e7da47..0000000 --- a/simplemind/concepts/context.py +++ /dev/null @@ -1,17 +0,0 @@ -from pydantic import BaseModel -from typing import Dict, Any -from simplemind.plugins.base import BasePlugin - - -class Context(BaseModel): - model_config = {"arbitrary_types_allowed": True} - plugins: Dict[str, BasePlugin] = {} - - def add_plugin(self, name: str, plugin: BasePlugin): - self.plugins[name] = plugin - - def execute_plugin(self, name: str, *args, **kwargs): - if name in self.plugins: - return self.plugins[name].execute(self, *args, **kwargs) - else: - raise ValueError(f"Plugin '{name}' not found in context.") diff --git a/simplemind/core/__init__.py b/simplemind/core/__init__.py deleted file mode 100644 index da9fda1..0000000 --- a/simplemind/core/__init__.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Dict, Any, Optional -from .models import AIResponse -from ..concepts.context import Context -from ..providers.base import BaseClientProvider - -from .config import settings - - -class SimpleMind: - """Main class for SimpleMind functionality.""" - - def __init__(self, provider: str = "openai", context: Optional[Context] = None): - """Initialize SimpleMind with the specified provider.""" - self.provider = provider - self.context = context or Context() - self._client = self._get_provider() - - def _get_provider(self) -> BaseClientProvider: - """Get the appropriate provider client.""" - from ..providers.openai import OpenAI - from ..providers.anthropic import Anthropic - - # Initialize providers based on environment variables. - providers = {} - if settings.openai_api_key: - providers.update({"openai": OpenAI}) - if settings.anthropic_api_key: - providers.update({"anthropic": Anthropic}) - - if self.provider not in providers: - raise ValueError( - f"Provider '{self.provider}' not supported. Available providers: {list(providers.keys())}" - ) - - return providers[self.provider]() - - def generate(self, prompt: str, **kwargs) -> AIResponse: - """Generate a response using the configured provider.""" - - return self._client.message(prompt, **kwargs) - - def create_conversation(self) -> str: - """Create a new conversation and return its ID.""" - - initial_message = "You are a helpful assistant." - - conversation = self._client.create_conversation(initial_message) - return conversation - - def add_message(self, conversation_id: str, message: str) -> AIResponse: - """Send a message in an existing conversation.""" - - return self._client.add_message(conversation_id, message) diff --git a/simplemind/core/client.py b/simplemind/core/client.py deleted file mode 100644 index d8a857a..0000000 --- a/simplemind/core/client.py +++ /dev/null @@ -1,68 +0,0 @@ -import os -from typing import Optional - -from .models import Conversation, AIResponse -from ..concepts.context import Context - -from .errors import ProviderError -from .logger import logger - - -class Client: - def __init__(self, api_key=None): - self.providers = {} - - # Auto-detect available API keys from environment - api_keys = { - "openai": os.getenv("OPENAI_API_KEY"), - "anthropic": os.getenv("ANTHROPIC_API_KEY"), - # Add other providers as needed - } - - # Initialize providers for which we have API keys - for provider, key in api_keys.items(): - if key: - self.providers[provider] = self._initialize_provider(provider, key) - - def _initialize_provider(self, provider_name, api_key): - if provider_name == "openai": - from ..providers.openai import OpenAI - - return OpenAI(api_key) - elif provider_name == "anthropic": - from ..providers.anthropic import Anthropic - - return Anthropic(api_key) - # Add other providers as needed - - def create_conversation( - self, provider: str = "openai", context: Optional[Context] = None - ) -> Conversation: - if provider not in self.providers: - raise ValueError(f"Provider '{provider}' not supported.") - conversation_context = context or Context() - return self.providers[provider].create_conversation( - initial_message="Hello!", context=conversation_context.model_dump() - ) - - def _handle_api_error(self, error: Exception, operation: str): - """Handle API errors in a consistent way.""" - logger.error(f"Error during {operation}: {str(error)}") - raise ProviderError(f"Failed to {operation}: {str(error)}") from error - - def send_message( - self, conversation: Conversation, message: str, provider: str = "openai" - ) -> AIResponse: - if provider not in self.providers: - raise ValueError(f"Provider '{provider}' not supported.") - try: - return self.providers[provider].send_message(conversation.id, message) - except Exception as e: - self._handle_api_error(e, "send message") - - @property - def available_models(self): - available = {} - for name, provider in self.providers.items(): - available[name] = provider.available_models - return available diff --git a/simplemind/core/config.py b/simplemind/core/config.py deleted file mode 100644 index 1defeaf..0000000 --- a/simplemind/core/config.py +++ /dev/null @@ -1,56 +0,0 @@ -from pydantic import Field, field_validator, ValidationError -from pydantic_settings import BaseSettings -from typing import Optional, Dict, Set - - -class Settings(BaseSettings): - openai_api_key: Optional[str] = Field(None, env="OPENAI_API_KEY") - anthropic_api_key: Optional[str] = Field(None, env="ANTHROPIC_API_KEY") - ollama_host_url: Optional[str] = Field(None, env="OLLAMA_HOST_URL") - default_model: str = Field("gpt-4", env="DEFAULT_MODEL") - log_level: str = Field("INFO", env="LOG_LEVEL") - - # Map of provider names to their required environment variables - PROVIDER_REQUIREMENTS: Dict[str, str] = { - "openai": "openai_api_key", - "anthropic": "anthropic_api_key", - "ollama": "ollama_host_url", - } - - @field_validator("*", mode="before") - def check_empty_string(cls, v): - if isinstance(v, str) and not v.strip(): - return None - return v - - def validate_provider(self, provider: str) -> bool: - """ - Validate that the necessary API key exists for a given provider. - Raises ValueError if the provider is not properly configured. - """ - if provider not in self.PROVIDER_REQUIREMENTS: - raise ValueError(f"Unknown provider: {provider}") - - required_key = self.PROVIDER_REQUIREMENTS[provider] - if getattr(self, required_key) is None: - raise ValueError( - f"Missing API key for {provider}. " - f"Please set {required_key.upper()} environment variable." - ) - return True - - @property - def available_providers(self) -> Set[str]: - """Return a set of properly configured providers.""" - return { - provider - for provider, key in self.PROVIDER_REQUIREMENTS.items() - if getattr(self, key) is not None - } - - class Config: - env_file = ".env" - case_sensitive = False - - -settings = Settings() diff --git a/simplemind/core/errors.py b/simplemind/core/errors.py deleted file mode 100644 index 3aaf310..0000000 --- a/simplemind/core/errors.py +++ /dev/null @@ -1,22 +0,0 @@ -class SimpleMindError(Exception): - """Base exception for SimpleMind errors.""" - - pass - - -class ProviderError(SimpleMindError): - """Raised when there's an error with the AI provider.""" - - pass - - -class ConfigurationError(SimpleMindError): - """Raised when there's a configuration error.""" - - pass - - -class AuthenticationError(SimpleMindError): - """Raised when authentication fails.""" - - pass diff --git a/simplemind/core/logger.py b/simplemind/core/logger.py deleted file mode 100644 index 8c2e2e7..0000000 --- a/simplemind/core/logger.py +++ /dev/null @@ -1,15 +0,0 @@ -import logging - -def setup_logger(name: str) -> logging.Logger: - logger = logging.getLogger(name) - if not logger.hasHandlers(): - logger.setLevel(logging.INFO) - handler = logging.StreamHandler() - formatter = logging.Formatter('[%(asctime)s] %(levelname)s - %(message)s') - handler.setFormatter(formatter) - logger.addHandler(handler) - return logger - -# Initialize a global logger -logger = setup_logger("simplemind") - diff --git a/simplemind/core/models.py b/simplemind/core/models.py deleted file mode 100644 index eaac571..0000000 --- a/simplemind/core/models.py +++ /dev/null @@ -1,67 +0,0 @@ -from pydantic import BaseModel -from typing import Any, Dict, List, Optional -from ..concepts.context import Context -from datetime import datetime - - -class AIRequest(BaseModel): - text: str - parameters: Dict[str, Any] = {} - - def __str__(self): - return self.text - - -class Message(BaseModel): - role: str # "user", "assistant", "system" - content: str - created_at: datetime = datetime.now() - - -class Choice(BaseModel): - message: Message - index: int = 0 - - -class AIResponse(BaseModel): - choices: List[Choice] - - @property - def content(self) -> str: - """Helper to get the first message content directly.""" - return self.choices[0].message.content - - -class Conversation(BaseModel): - id: str - messages: List[Message] = [] - context: Optional[Context] = None - created_at: datetime = datetime.now() - updated_at: datetime = datetime.now() - - def get_messages(self) -> List[Message]: - """Returns a list of messages in the conversation.""" - return self.messages - - def add_message(self, role: str, content: str) -> Message: - """Adds a new message to the conversation.""" - message = Message(role=role, content=content) - self.messages.append(message) - self.updated_at = datetime.now() - return message - - def set_context(self, context: Context): - """Sets the context for the conversation.""" - self.context = context - - -class ConversationRequest(BaseModel): - conversation_id: Optional[str] = None - message: str - context_update: Optional[Dict[str, Any]] = None - - -class ConversationResponse(BaseModel): - conversation_id: str - messages: List[Message] - metadata: Dict[str, Any] = {} diff --git a/simplemind/models.py b/simplemind/models.py new file mode 100644 index 0000000..50d931d --- /dev/null +++ b/simplemind/models.py @@ -0,0 +1,54 @@ +import uuid +from typing import List, Dict, Any, Optional +from datetime import datetime + +from pydantic import BaseModel, Field + +from .utils import find_provider + + +class SMBaseModel(BaseModel): + date_created: datetime = Field(default_factory=datetime.now) + + def __str__(self): + return f"<{self.__class__.__name__} {self.model_dump_json()}>" + + +class Message(SMBaseModel): + role: str + text: str + meta: Dict[str, Any] = {} + raw: Optional[Any] = None + llm_model: Optional[str] = None + llm_provider: Optional[str] = None + + def __str__(self): + return f"" + + @classmethod + def from_raw_response(cls, *, text, raw): + self = cls() + self.text = text + self.raw = raw + return self + + +class Conversation(SMBaseModel): + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + messages: List[Message] = [] + llm_model: Optional[str] = None + llm_provider: Optional[str] = None + + def __str__(self): + return f"" + + def add_message(self, role: str, text: str, meta: Dict[str, Any] = {}): + """Add a new message to the conversation.""" + self.messages.append(Message(role=role, text=text, meta=meta)) + + def send( + self, llm_model: Optional[str] = None, llm_provider: Optional[str] = None + ) -> Message: + """Send the conversation to the LLM.""" + provider = find_provider(llm_provider or self.llm_provider) + return provider.send_conversation(self) diff --git a/simplemind/models/base.py b/simplemind/models/base.py deleted file mode 100644 index 9981ed1..0000000 --- a/simplemind/models/base.py +++ /dev/null @@ -1,26 +0,0 @@ -from datetime import datetime -from typing import Any, Dict, List, Optional -from pydantic import BaseModel - -class Message(BaseModel): - role: str - content: str - created_at: datetime = datetime.now() - -class Conversation(BaseModel): - id: str - messages: List[Message] = [] - context: Dict[str, Any] = {} - created_at: datetime = datetime.now() - updated_at: datetime = datetime.now() - - def add_message(self, role: str, content: str) -> Message: - message = Message(role=role, content=content) - self.messages.append(message) - self.updated_at = datetime.now() - return message - -class AIResponse(BaseModel): - text: str - response: Any - metadata: Dict[str, Any] = {} diff --git a/simplemind/plugins/__init__.py b/simplemind/plugins/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/simplemind/plugins/base.py b/simplemind/plugins/base.py deleted file mode 100644 index f346f70..0000000 --- a/simplemind/plugins/base.py +++ /dev/null @@ -1,34 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, Optional - - -class BasePlugin(ABC): - """Base class for all SimpleMind plugins.""" - - def __init__(self): - self.enabled: bool = True - self.name: str = self.__class__.__name__ - - @abstractmethod - def process(self, context: Dict[str, Any]) -> Dict[str, Any]: - """Process the context and return modified context. - - Args: - context: The current conversation context - - Returns: - Modified context dictionary - """ - pass - - def enable(self) -> None: - """Enable the plugin.""" - self.enabled = True - - def disable(self) -> None: - """Disable the plugin.""" - self.enabled = False - - @property - def is_enabled(self) -> bool: - return self.enabled diff --git a/simplemind/plugins/basic_memory.py b/simplemind/plugins/basic_memory.py deleted file mode 100644 index d3365e4..0000000 --- a/simplemind/plugins/basic_memory.py +++ /dev/null @@ -1,10 +0,0 @@ -from .base import BasePlugin - - -class BasicMemoryPlugin(BasePlugin): - def __init__(self): - self.memory = [] - - def execute(self, context, message): - self.memory.append(message) - return self.memory diff --git a/simplemind/plugins/kv.py b/simplemind/plugins/kv.py deleted file mode 100644 index 0e0ebbe..0000000 --- a/simplemind/plugins/kv.py +++ /dev/null @@ -1,18 +0,0 @@ -from .base import BasePlugin - - -class KVPlugin(BasePlugin): - def __init__(self): - self.store = {} - - def process(self, key: str, value=None): - """ - Get or set a value in the key-value store. - If value is None, returns the value for the key. - If value is provided, sets the value for the key and returns it. - """ - if value is None: - return self.store.get(key) - - self.store[key] = value - return value diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index 9a4b90d..f6725d5 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -1,18 +1,3 @@ -from ..core.config import settings +from .openai import OpenAI -__all__ = [] - -if settings.anthropic_api_key: - from .anthropic import Anthropic - - __all__.append("Anthropic") - -if settings.openai_api_key: - from .openai import OpenAI - - __all__.append("OpenAI") - -if settings.ollama_host_url: - from .ollama import Ollama - - __all__.append("Ollama") +providers = [OpenAI] diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py deleted file mode 100644 index e17047a..0000000 --- a/simplemind/providers/anthropic.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -from typing import List, Optional - -import instructor -from anthropic import Anthropic as BaseAnthropic - -from .base import BaseClientProvider -from ..core.models import AIResponse, Conversation -from ..core.logger import logger - - -DEFAULT_MODEL = "claude-3-5-sonnet-20241022" -DEFAULT_MAX_TOKENS = 4096 - - -class Anthropic(BaseClientProvider): - def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None): - super().__init__(model=model, api_key=api_key) - self.login() - - def login(self): - if not self._api_key: - self._api_key = os.getenv("ANTHROPIC_API_KEY") - if not self._api_key: - raise ValueError("Anthropic API key not provided.") - logger.debug(f"API key length: {len(self._api_key) if self._api_key else 0}") - - base_client = BaseAnthropic(api_key=self._api_key) - self.client = instructor.from_anthropic(base_client) - if not self.test_connection(): - raise ConnectionError("Failed to connect to Anthropic API.") - logger.info("Logged in to Anthropic successfully.") - - @property - def available_models(self) -> List[str]: - try: - return [ - "claude-3-opus-20240229", - "claude-3-5-sonnet-20240620", - "claude-3-haiku-20240307", - ] - except Exception as e: - logger.error(f"Error fetching models: {e}") - return [] - - def test_connection(self) -> bool: - models = self.available_models - if models: - logger.info(f"Available models: {models}") - return True - logger.warning("No available models found.") - return False - - def generate_response(self, conversation: Conversation) -> AIResponse: - messages = [ - {"role": msg.role, "content": msg.content} for msg in conversation.messages - ] - params = { - "messages": messages, - "model": self.model, - "max_tokens": DEFAULT_MAX_TOKENS, - } - if conversation.context: - params["context"] = ( - vars(conversation.context) - if hasattr(conversation.context, "__dict__") - else dict(conversation.context) - ) - - try: - completion = self.client.completions.create(response_model=str, **params) - response_text = completion.completion - metadata = {"model": completion.model, "usage": completion.usage} - logger.info("Generated response from Anthropic.") - return AIResponse( - text=response_text, response=completion, metadata=metadata - ) - except Exception as e: - logger.error(f"Error generating response: {e}") - raise e diff --git a/simplemind/providers/base.py b/simplemind/providers/base.py deleted file mode 100644 index fb72740..0000000 --- a/simplemind/providers/base.py +++ /dev/null @@ -1,112 +0,0 @@ -# import logging - -from pydantic import BaseModel -from typing import Any, Dict, List, Optional -from ..core.models import AIResponse, Conversation, Message -import uuid - - -DEFAULT_MODEL = "gpt-4o" - - -class BaseClientProvider: - - def __init__(self, *, model: str = DEFAULT_MODEL, api_key: Optional[str] = None): - # self.logger = logging.getLogger(self.__class__.__name__) - self.client = None - self.model = model - self._api_key = api_key - self.conversations: Dict[str, Conversation] = {} - - def login(self): - """Initializes the AI provider client.""" - - msg = "This method must be implemented by the AI provider client." - raise NotImplementedError(msg) - - def test_connection(self) -> bool: - """Tests the connection to the AI provider client.""" - - msg = "This method must be implemented by the AI provider client." - raise NotImplementedError(msg) - - def health_check(self): - """Checks the health of the AI provider client.""" - - msg = "This method must be implemented by the AI provider client." - raise NotImplementedError(msg) - - @property - def available_models(self) -> List[str]: - """Returns the available models from the AI provider client.""" - - msg = "This method must be implemented by the AI provider client." - raise NotImplementedError(msg) - - def message(self, message: str, **kwargs) -> AIResponse: - """Generates a response from the AI provider client.""" - - msg = "This method must be implemented by the AI provider client." - raise NotImplementedError(msg) - - # Uncomment and implement additional methods as needed - # def features(self): - # """Returns the features of the AI provider client.""" - # msg = "This method must be implemented by the AI provider client." - # raise NotImplementedError(msg) - - # def structured_response(self, model, message, **kwargs): - # pass - - # def structured_conversation(self, model, message, **kwargs): - # pass - - # def single_message(self, model, message, **kwargs): - # return self.generate_response(message) - - # def start_conversation(self, model, message, **kwargs): - # pass - - def create_conversation( - self, initial_message: str, context: Optional[Dict[str, Any]] = None - ) -> Conversation: - conv_id = str(uuid.uuid4()) - conversation = Conversation( - id=conv_id, - messages=[Message(role="user", content=initial_message)], - context=context or {}, - ) - self.conversations[conv_id] = conversation - return conversation - - def send_message( - self, - conversation_id: str, - message: str, - context_update: Optional[Dict[str, Any]] = None, - ) -> AIResponse: - if conversation_id not in self.conversations: - raise ValueError("Conversation ID does not exist.") - - conversation = self.conversations[conversation_id] - conversation.messages.append(Message(role="user", content=message)) - - if context_update: - conversation.context.update(context_update) - - response = self.generate_response(conversation) - conversation.messages.append( - Message(role="assistant", content=response.choices[0].message.content) - ) - return response - - def generate_response(self, conversation: Conversation) -> AIResponse: - """Generates a response based on the conversation.""" - raise NotImplementedError( - "This method must be implemented by the AI provider client." - ) - - def get_conversation(self, conversation_id: str) -> Conversation: - if conversation_id not in self.conversations: - raise ValueError("Conversation ID does not exist.") - return self.conversations[conversation_id] diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py deleted file mode 100644 index 78ee08f..0000000 --- a/simplemind/providers/ollama.py +++ /dev/null @@ -1,110 +0,0 @@ -import os -import uuid -from typing import List, Optional, Dict, Any -from ..core.errors import ProviderError -from ..core.logger import logger -from .base import BaseClientProvider -from some_module import BaseOllama # Replace 'some_module' with the actual module name - -TIMEOUT = 60 - - -class Ollama(BaseClientProvider): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.login() - self.conversation = [] - - def login(self): - """Initialize Ollama client, with Instructor enabled.""" - if not os.environ.get("OLLAMA_HOST_URL"): - raise ValueError("Please set the OLLAMA_HOST_URL environment variable") - - if not os.environ.get("OLLAMA_MODEL"): - raise ValueError("Please set the OLLAMA_MODEL environment variable") - else: - self.model = os.environ.get("OLLAMA_MODEL") - - self.client = BaseOllama( - timeout=TIMEOUT, host=os.environ.get("OLLAMA_HOST_URL") - ) - assert self.test_connection() - - @property - def available_models(self): - """Returns the available models from the Ollama client.""" - - def gen(): - for model in self.client.list().get("models"): - yield model - - return [g for g in gen()] - - def test_connection(self): - """Test the connection to Ollama. Returns True if successful.""" - - return bool(len(self.available_models)) - - def generate_text(self, prompt, *, response_model=False, **kwargs): - use_instructor = bool(response_model) - - client = self.instructor_client if use_instructor else self.client - - # Parameters for the Ollama client. - params = { - "prompt": prompt, - "model": self.model, - } - params.update(kwargs) - - if use_instructor: - params["response_model"] = response_model - - # Make the request to Ollama. - completion = client.generate(**params) - if use_instructor: - return completion.model_dump() - - else: - return AIResponse( - response=completion, - text=completion.get("response"), - ) - - def message( - self, message=None, message_history=None, response_model=False, **kwargs - ): - """Generates a response from the Ollama client.""" - use_instructor = bool(response_model) - - client = self.instructor_client if use_instructor else self.client - - # Parameters for the Ollama client. - all_messages = [] - if message_history: - all_messages.extend(message_history) - if message: - all_messages.append({"role": "user", "content": message}) - params = { - "messages": all_messages, - "model": self.model, - } - params.update(kwargs) - - if use_instructor: - params["response_model"] = response_model - - # Make the request to Ollama. - completion = client.chat(**params) - if use_instructor: - return completion.model_dump() - - else: - return AIResponse( - response=completion, - text=completion.get("message").get("content"), - ) - - def start_conversation(self): - return Conversation(self) diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 1b83b41..06c8f09 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,69 +1,57 @@ -from typing import List, Optional -import openai -from ..core.errors import AuthenticationError, ProviderError -from ..core.config import settings -from ..core.logger import logger -from .base import BaseClientProvider +import openai as oa +import instructor -DEFAULT_MODEL = "gpt-4o" +# from ..models import Conversation, Message +from ..settings import settings + +PROVIDER_NAME = "openai" +DEFAULT_MODEL = "gpt-4o-mini" -class OpenAI(BaseClientProvider): - def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None): - super().__init__(model=model, api_key=api_key) - self.login() +class OpenAI: + __name__ = PROVIDER_NAME + DEFAULT_MODEL = DEFAULT_MODEL - def login(self) -> None: - if not self._api_key: - self._api_key = settings.openai_api_key - if not self._api_key: - raise AuthenticationError("OpenAI API key not provided") - - try: - openai.api_key = self._api_key - if not self.test_connection(): - raise ProviderError("Failed to connect to OpenAI API") - logger.info("Successfully connected to OpenAI") - except Exception as e: - raise ProviderError(f"OpenAI initialization failed: {str(e)}") + def __init__(self, api_key: str = None): + self.api_key = api_key or settings.OPENAI_API_KEY @property - def available_models(self) -> List[str]: - try: - models = openai.models.list() - return [model.id for model in models["data"]] - except Exception as e: - logger.error(f"Error fetching models: {e}") - return [] + def client(self): + """The raw OpenAI client.""" + return oa.OpenAI(api_key=self.api_key) - def test_connection(self): - """Test the connection to OpenAI API.""" - try: - openai.models.list() - return True - except Exception as e: - logger.error(f"Connection test failed: {e}") - return False + @property + def structured_client(self): + """A client patched with Instructor.""" + return instructor.patch(oa.OpenAI(api_key=self.api_key)) - def _handle_api_error(self, e: Exception) -> None: - """Handle API errors.""" - logger.error(f"OpenAI API error: {e}") - raise ProviderError(f"OpenAI API error: {e}") + def send_conversation(self, conversation: "Conversation"): + """Send a conversation to the OpenAI API.""" + from ..models import Message - def add_message(self, conversation_id, message, *args, **kwargs) -> str: - """Generate a response using the OpenAI API.""" - try: - # Create a client instance using the API key - client = openai.OpenAI(api_key=self._api_key) + messages = [ + {"role": msg.role, "content": msg.text} for msg in conversation.messages + ] - # Create the message for the conversation - messages = [{"role": "user", "content": message}] + response = self.client.chat.completions.create( + model=conversation.llm_model or DEFAULT_MODEL, messages=messages + ) - # Use the new API syntax - response = client.chat.completions.create( - model=self.model, messages=messages, *args, **kwargs - ) + # Get the response content from the OpenAI response + assistant_message = response.choices[0].message - return response.choices[0].message.content - except Exception as e: - self._handle_api_error(e) + # Create and return a properly formatted Message instance + return Message( + role="assistant", + text=assistant_message.content, + raw=response, + llm_model=conversation.llm_model or DEFAULT_MODEL, + llm_provider=PROVIDER_NAME, + ) + + def structured_response(self, model, response_model, **kwargs): + client = instructor.patch(oa.OpenAI(api_key=self.api_key)) + response = client.chat.completions.create( + model=model, response_model=response_model, **kwargs + ) + return response diff --git a/simplemind/settings.py b/simplemind/settings.py new file mode 100644 index 0000000..0189beb --- /dev/null +++ b/simplemind/settings.py @@ -0,0 +1,10 @@ +from pydantic import Field +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY") + ANTHROPIC_API_KEY: str = Field(..., env="ANTHROPIC_API_KEY") + + +settings = Settings() diff --git a/simplemind/tests/test_openai.py b/simplemind/tests/test_openai.py deleted file mode 100644 index 7721493..0000000 --- a/simplemind/tests/test_openai.py +++ /dev/null @@ -1,24 +0,0 @@ -import unittest -from unittest.mock import patch, MagicMock -from simplemind.providers.openai import OpenAI -from simplemind.core.errors import AuthenticationError, ProviderError - - -class TestOpenAIProvider(unittest.TestCase): - def setUp(self): - self.api_key = "test_key" - - @patch("simplemind.providers.openai.openai") - def test_initialization(self, mock_openai): - mock_openai.Model.list.return_value = {"data": ["gpt-4"]} - provider = OpenAI(api_key=self.api_key) - self.assertIsNotNone(provider.client) - mock_openai.Model.list.assert_called_once() - - def test_missing_api_key(self): - with self.assertRaises(AuthenticationError): - OpenAI(api_key=None) - - -if __name__ == "__main__": - unittest.main() diff --git a/simplemind/utils.py b/simplemind/utils.py new file mode 100644 index 0000000..8304881 --- /dev/null +++ b/simplemind/utils.py @@ -0,0 +1,10 @@ +from .providers import providers + + +def find_provider(provider_name: str): + """Find a provider by name.""" + for provider_class in providers: + if provider_class.__name__.lower() == provider_name.lower(): + # Instantiate the provider + return provider_class() + raise ValueError(f"Provider {provider_name} not found") diff --git a/simplemind/utils/__init__.py b/simplemind/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/simplemind/utils/faiss_store.py b/simplemind/utils/faiss_store.py deleted file mode 100644 index 9bbc559..0000000 --- a/simplemind/utils/faiss_store.py +++ /dev/null @@ -1,19 +0,0 @@ -import faiss -import numpy as np -from typing import List - - -class FAISSStore: - def __init__(self, dimension: int): - self.dimension = dimension - self.index = faiss.IndexFlatL2(dimension) - self.ids = [] - - def add_embeddings(self, embeddings: np.ndarray, ids: List[str]): - self.index.add(embeddings) - self.ids.extend(ids) - - def search(self, query_embedding: np.ndarray, top_k: int = 5): - distances, indices = self.index.search(query_embedding, top_k) - results = [(self.ids[idx], distances[i]) for i, idx in enumerate(indices[0])] - return results diff --git a/t.py b/t.py index 2863b50..b0ecfa4 100644 --- a/t.py +++ b/t.py @@ -1,9 +1,7 @@ -from simplemind import SimpleMind +import simplemind as sm -sm = SimpleMind() - -# The provider will automatically use OPENAI_API_KEY from environment conversation = sm.create_conversation() -r = sm.add_message(conversation.id, "Who is Kenneth Reitz?") +conversation.add_message(role="user", text="Hello, how are you?") +r = conversation.send(llm_model="gpt-4o-mini", llm_provider="openai") + print(r) -# diff --git a/tests/test_openai.py b/tests/test_openai.py deleted file mode 100644 index fd2d54c..0000000 --- a/tests/test_openai.py +++ /dev/null @@ -1,22 +0,0 @@ -import unittest -from unittest.mock import patch, MagicMock -from simplemind.providers.openai import OpenAI -from simplemind.core.errors import AuthenticationError, ProviderError - - -class TestOpenAIProvider(unittest.TestCase): - def setUp(self): - self.api_key = "test_key" - - @patch("simplemind.integrations.openai.BaseOpenAI") - def test_initialization(self, mock_openai): - provider = OpenAI(api_key=self.api_key) - self.assertIsNotNone(provider.client) - - def test_missing_api_key(self): - with self.assertRaises(AuthenticationError): - OpenAI(api_key=None) - - -if __name__ == "__main__": - unittest.main()