diff --git a/simplemind/__init__.py b/simplemind/__init__.py index e69de29..78a6812 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -0,0 +1,3 @@ +from .core.client import Client + +__all__ = ["Client"] diff --git a/simplemind/conversation.py b/simplemind/conversation.py index 4bae53e..c5d976c 100644 --- a/simplemind/conversation.py +++ b/simplemind/conversation.py @@ -22,3 +22,8 @@ class Conversation: self.add_message(response.text, role="assistant") return response + + def set_model(self, model: str): + """Set the model for the conversation.""" + self.client.set_model(model) + return True diff --git a/simplemind/core/__init__.py b/simplemind/core/__init__.py index 0b5ce1a..2a6a490 100644 --- a/simplemind/core/__init__.py +++ b/simplemind/core/__init__.py @@ -3,6 +3,8 @@ from .models import AIResponse from ..concepts.context import Context from ..providers.base import BaseClientProvider +from .config import settings + class SimpleMind: """Main class for SimpleMind functionality.""" diff --git a/simplemind/core/client.py b/simplemind/core/client.py index f27de43..d8a857a 100644 --- a/simplemind/core/client.py +++ b/simplemind/core/client.py @@ -1,32 +1,48 @@ +import os from typing import Optional -from simplemind.core.models import Conversation, AIResponse -from simplemind.concepts.context import Context -from simplemind.providers.openai import OpenAI -from simplemind.providers.anthropic import Anthropic -import logging + +from .models import Conversation, AIResponse +from ..concepts.context import Context + from .errors import ProviderError from .logger import logger class Client: - def __init__(self, api_key: str, context: Optional[Context] = None): - self.api_key = api_key - self.context = context or Context() - self.providers = self._initialize_providers() + def __init__(self, api_key=None): + self.providers = {} - def _initialize_providers(self): - return { - "openai": OpenAI(api_key=self.api_key), - "anthropic": Anthropic(api_key=self.api_key), + # Auto-detect available API keys from environment + api_keys = { + "openai": os.getenv("OPENAI_API_KEY"), + "anthropic": os.getenv("ANTHROPIC_API_KEY"), + # Add other providers as needed } + # Initialize providers for which we have API keys + for provider, key in api_keys.items(): + if key: + self.providers[provider] = self._initialize_provider(provider, key) + + def _initialize_provider(self, provider_name, api_key): + if provider_name == "openai": + from ..providers.openai import OpenAI + + return OpenAI(api_key) + elif provider_name == "anthropic": + from ..providers.anthropic import Anthropic + + return Anthropic(api_key) + # Add other providers as needed + def create_conversation( self, provider: str = "openai", context: Optional[Context] = None ) -> Conversation: if provider not in self.providers: raise ValueError(f"Provider '{provider}' not supported.") + conversation_context = context or Context() return self.providers[provider].create_conversation( - initial_message="Hello!", context=self.context.model_dump() + initial_message="Hello!", context=conversation_context.model_dump() ) def _handle_api_error(self, error: Exception, operation: str): diff --git a/simplemind/core/config.py b/simplemind/core/config.py index 6e6e8f8..1defeaf 100644 --- a/simplemind/core/config.py +++ b/simplemind/core/config.py @@ -1,6 +1,6 @@ -from pydantic import Field, field_validator +from pydantic import Field, field_validator, ValidationError from pydantic_settings import BaseSettings -from typing import Optional +from typing import Optional, Dict, Set class Settings(BaseSettings): @@ -10,12 +10,44 @@ class Settings(BaseSettings): default_model: str = Field("gpt-4", env="DEFAULT_MODEL") log_level: str = Field("INFO", env="LOG_LEVEL") + # Map of provider names to their required environment variables + PROVIDER_REQUIREMENTS: Dict[str, str] = { + "openai": "openai_api_key", + "anthropic": "anthropic_api_key", + "ollama": "ollama_host_url", + } + @field_validator("*", mode="before") - def check_required(cls, v, info): - if info.field_name in info.data and info.data[info.field_name] is None: - raise ValueError(f"{info.field_name} is required") + def check_empty_string(cls, v): + if isinstance(v, str) and not v.strip(): + return None return v + def validate_provider(self, provider: str) -> bool: + """ + Validate that the necessary API key exists for a given provider. + Raises ValueError if the provider is not properly configured. + """ + if provider not in self.PROVIDER_REQUIREMENTS: + raise ValueError(f"Unknown provider: {provider}") + + required_key = self.PROVIDER_REQUIREMENTS[provider] + if getattr(self, required_key) is None: + raise ValueError( + f"Missing API key for {provider}. " + f"Please set {required_key.upper()} environment variable." + ) + return True + + @property + def available_providers(self) -> Set[str]: + """Return a set of properly configured providers.""" + return { + provider + for provider, key in self.PROVIDER_REQUIREMENTS.items() + if getattr(self, key) is not None + } + class Config: env_file = ".env" case_sensitive = False diff --git a/simplemind/core/models.py b/simplemind/core/models.py index 91fc603..eaac571 100644 --- a/simplemind/core/models.py +++ b/simplemind/core/models.py @@ -1,6 +1,6 @@ from pydantic import BaseModel from typing import Any, Dict, List, Optional -import uuid +from ..concepts.context import Context from datetime import datetime @@ -12,24 +12,30 @@ class AIRequest(BaseModel): return self.text -class AIResponse(BaseModel): - text: str - response: Any - metadata: Dict[str, Any] = {} - - def __str__(self): - return self.text - - class Message(BaseModel): role: str # "user", "assistant", "system" content: str created_at: datetime = datetime.now() +class Choice(BaseModel): + message: Message + index: int = 0 + + +class AIResponse(BaseModel): + choices: List[Choice] + + @property + def content(self) -> str: + """Helper to get the first message content directly.""" + return self.choices[0].message.content + + class Conversation(BaseModel): id: str messages: List[Message] = [] + context: Optional[Context] = None created_at: datetime = datetime.now() updated_at: datetime = datetime.now() @@ -44,6 +50,10 @@ class Conversation(BaseModel): self.updated_at = datetime.now() return message + def set_context(self, context: Context): + """Sets the context for the conversation.""" + self.context = context + class ConversationRequest(BaseModel): conversation_id: Optional[str] = None diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 15ac541..e17047a 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -10,6 +10,7 @@ from ..core.logger import logger DEFAULT_MODEL = "claude-3-5-sonnet-20241022" +DEFAULT_MAX_TOKENS = 4096 class Anthropic(BaseClientProvider): @@ -22,6 +23,8 @@ class Anthropic(BaseClientProvider): self._api_key = os.getenv("ANTHROPIC_API_KEY") if not self._api_key: raise ValueError("Anthropic API key not provided.") + logger.debug(f"API key length: {len(self._api_key) if self._api_key else 0}") + base_client = BaseAnthropic(api_key=self._api_key) self.client = instructor.from_anthropic(base_client) if not self.test_connection(): @@ -55,12 +58,17 @@ class Anthropic(BaseClientProvider): params = { "messages": messages, "model": self.model, + "max_tokens": DEFAULT_MAX_TOKENS, } if conversation.context: - params["context"] = conversation.context + params["context"] = ( + vars(conversation.context) + if hasattr(conversation.context, "__dict__") + else dict(conversation.context) + ) try: - completion = self.client.completions.create(**params) + completion = self.client.completions.create(response_model=str, **params) response_text = completion.completion metadata = {"model": completion.model, "usage": completion.usage} logger.info("Generated response from Anthropic.") diff --git a/simplemind/providers/base.py b/simplemind/providers/base.py index 938585c..fb72740 100644 --- a/simplemind/providers/base.py +++ b/simplemind/providers/base.py @@ -20,27 +20,32 @@ class BaseClientProvider: def login(self): """Initializes the AI provider client.""" + msg = "This method must be implemented by the AI provider client." raise NotImplementedError(msg) def test_connection(self) -> bool: """Tests the connection to the AI provider client.""" + msg = "This method must be implemented by the AI provider client." raise NotImplementedError(msg) def health_check(self): """Checks the health of the AI provider client.""" + msg = "This method must be implemented by the AI provider client." raise NotImplementedError(msg) @property def available_models(self) -> List[str]: """Returns the available models from the AI provider client.""" + msg = "This method must be implemented by the AI provider client." raise NotImplementedError(msg) def message(self, message: str, **kwargs) -> AIResponse: """Generates a response from the AI provider client.""" + msg = "This method must be implemented by the AI provider client." raise NotImplementedError(msg) diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 7d43a06..d92515f 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,10 +1,12 @@ from typing import List, Optional from openai import OpenAI as BaseOpenAI from ..core.errors import AuthenticationError, ProviderError -from simplemind.core.config import settings +from ..core.config import settings from ..core.logger import logger from .base import BaseClientProvider +DEFAULT_MODEL = "gpt-4" + class OpenAI(BaseClientProvider): def __init__(self, model: str = "gpt-4", api_key: Optional[str] = None): @@ -47,7 +49,7 @@ class OpenAI(BaseClientProvider): logger.error(f"OpenAI API error: {e}") raise ProviderError(f"OpenAI API error: {e}") - def generate_response(self, conversation) -> str: + def generate_response(self, conversation, **kwargs) -> str: """Generate a response using the OpenAI API.""" try: messages = [ @@ -55,10 +57,24 @@ class OpenAI(BaseClientProvider): for msg in conversation.messages ] - response = self.client.chat.completions.create( - model=self.model, messages=messages - ) + # Ensure we're using a valid model name, not the API key + if not isinstance(self.model, str) or self.model.startswith("sk-"): + logger.warning( + f"Invalid model name detected. Falling back to {DEFAULT_MODEL!r}" + ) + model_name = DEFAULT_MODEL + else: + model_name = self.model - return response.choices[0].message.content + r = self.client.chat.completions.create( + model=model_name, # Use the validated model name + messages=messages, + **kwargs, + ) + return r except Exception as e: self._handle_api_error(e) + + def generate_text(self, conversation, **kwargs) -> str: + """Generate a text response using the OpenAI API.""" + return self.generate_response(conversation, **kwargs).choices[0].message.content diff --git a/t.py b/t.py index a096c19..380bffc 100644 --- a/t.py +++ b/t.py @@ -1,42 +1,15 @@ -import os -from pprint import pprint -from pydantic import BaseModel import simplemind -from simplemind.concepts.context import Context -from simplemind.plugins.kv import KVPlugin -from simplemind.plugins.basic_memory import BasicMemoryPlugin -from simplemind.chains.reverse_text import ReverseTextChain -from simplemind.core.client import Client +from simplemind.core import settings -class CustomContext(Context): - def __init__(self): - super().__init__() - self.add_plugin("kv", KVPlugin()) - # self.add_plugin("basic_memory", BasicMemoryPlugin()) +print(settings) +# Initialize client without explicit API key +ai = simplemind.Client() -# Initialize context and client -ctx = CustomContext() -aiclient = Client( - context=ctx, - api_key=os.environ["OPENAI_API_KEY"], -) +print(ai.available_models) -# Test connection and available models -print(aiclient.available_models) - -# Example usage -conversation = aiclient.create_conversation(provider="anthropic") -conversation.set_context(ctx) -response = aiclient.send_message( - conversation, "Who is Kenneth Reitz?", provider="anthropic" -) -# response = aiclient.send_message( -# conversation, "Who is Kenneth Reitz?", provider="openai" -# ) -# print(response) - -# reverse_chain = ReverseTextChain() -# result = reverse_chain.run("Hello, World!") -# print(result) # Output: !dlroW ,olleH +# The provider will automatically use OPENAI_API_KEY from environment +conversation = ai.create_conversation(provider="openai") +response = ai.send_message(conversation, "Who is Kenneth Reitz?") +print(response)