Refactor code structure and imports in simplemind package

This commit is contained in:
2024-10-28 17:02:15 -04:00
parent 8b043acf7f
commit ed144fa701
10 changed files with 143 additions and 73 deletions
+3
View File
@@ -0,0 +1,3 @@
from .core.client import Client
__all__ = ["Client"]
+5
View File
@@ -22,3 +22,8 @@ class Conversation:
self.add_message(response.text, role="assistant")
return response
def set_model(self, model: str):
"""Set the model for the conversation."""
self.client.set_model(model)
return True
+2
View File
@@ -3,6 +3,8 @@ from .models import AIResponse
from ..concepts.context import Context
from ..providers.base import BaseClientProvider
from .config import settings
class SimpleMind:
"""Main class for SimpleMind functionality."""
+30 -14
View File
@@ -1,32 +1,48 @@
import os
from typing import Optional
from simplemind.core.models import Conversation, AIResponse
from simplemind.concepts.context import Context
from simplemind.providers.openai import OpenAI
from simplemind.providers.anthropic import Anthropic
import logging
from .models import Conversation, AIResponse
from ..concepts.context import Context
from .errors import ProviderError
from .logger import logger
class Client:
def __init__(self, api_key: str, context: Optional[Context] = None):
self.api_key = api_key
self.context = context or Context()
self.providers = self._initialize_providers()
def __init__(self, api_key=None):
self.providers = {}
def _initialize_providers(self):
return {
"openai": OpenAI(api_key=self.api_key),
"anthropic": Anthropic(api_key=self.api_key),
# Auto-detect available API keys from environment
api_keys = {
"openai": os.getenv("OPENAI_API_KEY"),
"anthropic": os.getenv("ANTHROPIC_API_KEY"),
# Add other providers as needed
}
# Initialize providers for which we have API keys
for provider, key in api_keys.items():
if key:
self.providers[provider] = self._initialize_provider(provider, key)
def _initialize_provider(self, provider_name, api_key):
if provider_name == "openai":
from ..providers.openai import OpenAI
return OpenAI(api_key)
elif provider_name == "anthropic":
from ..providers.anthropic import Anthropic
return Anthropic(api_key)
# Add other providers as needed
def create_conversation(
self, provider: str = "openai", context: Optional[Context] = None
) -> Conversation:
if provider not in self.providers:
raise ValueError(f"Provider '{provider}' not supported.")
conversation_context = context or Context()
return self.providers[provider].create_conversation(
initial_message="Hello!", context=self.context.model_dump()
initial_message="Hello!", context=conversation_context.model_dump()
)
def _handle_api_error(self, error: Exception, operation: str):
+37 -5
View File
@@ -1,6 +1,6 @@
from pydantic import Field, field_validator
from pydantic import Field, field_validator, ValidationError
from pydantic_settings import BaseSettings
from typing import Optional
from typing import Optional, Dict, Set
class Settings(BaseSettings):
@@ -10,12 +10,44 @@ class Settings(BaseSettings):
default_model: str = Field("gpt-4", env="DEFAULT_MODEL")
log_level: str = Field("INFO", env="LOG_LEVEL")
# Map of provider names to their required environment variables
PROVIDER_REQUIREMENTS: Dict[str, str] = {
"openai": "openai_api_key",
"anthropic": "anthropic_api_key",
"ollama": "ollama_host_url",
}
@field_validator("*", mode="before")
def check_required(cls, v, info):
if info.field_name in info.data and info.data[info.field_name] is None:
raise ValueError(f"{info.field_name} is required")
def check_empty_string(cls, v):
if isinstance(v, str) and not v.strip():
return None
return v
def validate_provider(self, provider: str) -> bool:
"""
Validate that the necessary API key exists for a given provider.
Raises ValueError if the provider is not properly configured.
"""
if provider not in self.PROVIDER_REQUIREMENTS:
raise ValueError(f"Unknown provider: {provider}")
required_key = self.PROVIDER_REQUIREMENTS[provider]
if getattr(self, required_key) is None:
raise ValueError(
f"Missing API key for {provider}. "
f"Please set {required_key.upper()} environment variable."
)
return True
@property
def available_providers(self) -> Set[str]:
"""Return a set of properly configured providers."""
return {
provider
for provider, key in self.PROVIDER_REQUIREMENTS.items()
if getattr(self, key) is not None
}
class Config:
env_file = ".env"
case_sensitive = False
+20 -10
View File
@@ -1,6 +1,6 @@
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
import uuid
from ..concepts.context import Context
from datetime import datetime
@@ -12,24 +12,30 @@ class AIRequest(BaseModel):
return self.text
class AIResponse(BaseModel):
text: str
response: Any
metadata: Dict[str, Any] = {}
def __str__(self):
return self.text
class Message(BaseModel):
role: str # "user", "assistant", "system"
content: str
created_at: datetime = datetime.now()
class Choice(BaseModel):
message: Message
index: int = 0
class AIResponse(BaseModel):
choices: List[Choice]
@property
def content(self) -> str:
"""Helper to get the first message content directly."""
return self.choices[0].message.content
class Conversation(BaseModel):
id: str
messages: List[Message] = []
context: Optional[Context] = None
created_at: datetime = datetime.now()
updated_at: datetime = datetime.now()
@@ -44,6 +50,10 @@ class Conversation(BaseModel):
self.updated_at = datetime.now()
return message
def set_context(self, context: Context):
"""Sets the context for the conversation."""
self.context = context
class ConversationRequest(BaseModel):
conversation_id: Optional[str] = None
+10 -2
View File
@@ -10,6 +10,7 @@ from ..core.logger import logger
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
DEFAULT_MAX_TOKENS = 4096
class Anthropic(BaseClientProvider):
@@ -22,6 +23,8 @@ class Anthropic(BaseClientProvider):
self._api_key = os.getenv("ANTHROPIC_API_KEY")
if not self._api_key:
raise ValueError("Anthropic API key not provided.")
logger.debug(f"API key length: {len(self._api_key) if self._api_key else 0}")
base_client = BaseAnthropic(api_key=self._api_key)
self.client = instructor.from_anthropic(base_client)
if not self.test_connection():
@@ -55,12 +58,17 @@ class Anthropic(BaseClientProvider):
params = {
"messages": messages,
"model": self.model,
"max_tokens": DEFAULT_MAX_TOKENS,
}
if conversation.context:
params["context"] = conversation.context
params["context"] = (
vars(conversation.context)
if hasattr(conversation.context, "__dict__")
else dict(conversation.context)
)
try:
completion = self.client.completions.create(**params)
completion = self.client.completions.create(response_model=str, **params)
response_text = completion.completion
metadata = {"model": completion.model, "usage": completion.usage}
logger.info("Generated response from Anthropic.")
+5
View File
@@ -20,27 +20,32 @@ class BaseClientProvider:
def login(self):
"""Initializes the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def test_connection(self) -> bool:
"""Tests the connection to the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def health_check(self):
"""Checks the health of the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
@property
def available_models(self) -> List[str]:
"""Returns the available models from the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def message(self, message: str, **kwargs) -> AIResponse:
"""Generates a response from the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
+22 -6
View File
@@ -1,10 +1,12 @@
from typing import List, Optional
from openai import OpenAI as BaseOpenAI
from ..core.errors import AuthenticationError, ProviderError
from simplemind.core.config import settings
from ..core.config import settings
from ..core.logger import logger
from .base import BaseClientProvider
DEFAULT_MODEL = "gpt-4"
class OpenAI(BaseClientProvider):
def __init__(self, model: str = "gpt-4", api_key: Optional[str] = None):
@@ -47,7 +49,7 @@ class OpenAI(BaseClientProvider):
logger.error(f"OpenAI API error: {e}")
raise ProviderError(f"OpenAI API error: {e}")
def generate_response(self, conversation) -> str:
def generate_response(self, conversation, **kwargs) -> str:
"""Generate a response using the OpenAI API."""
try:
messages = [
@@ -55,10 +57,24 @@ class OpenAI(BaseClientProvider):
for msg in conversation.messages
]
response = self.client.chat.completions.create(
model=self.model, messages=messages
)
# Ensure we're using a valid model name, not the API key
if not isinstance(self.model, str) or self.model.startswith("sk-"):
logger.warning(
f"Invalid model name detected. Falling back to {DEFAULT_MODEL!r}"
)
model_name = DEFAULT_MODEL
else:
model_name = self.model
return response.choices[0].message.content
r = self.client.chat.completions.create(
model=model_name, # Use the validated model name
messages=messages,
**kwargs,
)
return r
except Exception as e:
self._handle_api_error(e)
def generate_text(self, conversation, **kwargs) -> str:
"""Generate a text response using the OpenAI API."""
return self.generate_response(conversation, **kwargs).choices[0].message.content
+9 -36
View File
@@ -1,42 +1,15 @@
import os
from pprint import pprint
from pydantic import BaseModel
import simplemind
from simplemind.concepts.context import Context
from simplemind.plugins.kv import KVPlugin
from simplemind.plugins.basic_memory import BasicMemoryPlugin
from simplemind.chains.reverse_text import ReverseTextChain
from simplemind.core.client import Client
from simplemind.core import settings
class CustomContext(Context):
def __init__(self):
super().__init__()
self.add_plugin("kv", KVPlugin())
# self.add_plugin("basic_memory", BasicMemoryPlugin())
print(settings)
# Initialize client without explicit API key
ai = simplemind.Client()
# Initialize context and client
ctx = CustomContext()
aiclient = Client(
context=ctx,
api_key=os.environ["OPENAI_API_KEY"],
)
print(ai.available_models)
# Test connection and available models
print(aiclient.available_models)
# Example usage
conversation = aiclient.create_conversation(provider="anthropic")
conversation.set_context(ctx)
response = aiclient.send_message(
conversation, "Who is Kenneth Reitz?", provider="anthropic"
)
# response = aiclient.send_message(
# conversation, "Who is Kenneth Reitz?", provider="openai"
# )
# print(response)
# reverse_chain = ReverseTextChain()
# result = reverse_chain.run("Hello, World!")
# print(result) # Output: !dlroW ,olleH
# The provider will automatically use OPENAI_API_KEY from environment
conversation = ai.create_conversation(provider="openai")
response = ai.send_message(conversation, "Who is Kenneth Reitz?")
print(response)