mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
Refactor code structure and imports in simplemind package
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from .core.client import Client
|
||||
|
||||
__all__ = ["Client"]
|
||||
|
||||
@@ -22,3 +22,8 @@ class Conversation:
|
||||
self.add_message(response.text, role="assistant")
|
||||
|
||||
return response
|
||||
|
||||
def set_model(self, model: str):
|
||||
"""Set the model for the conversation."""
|
||||
self.client.set_model(model)
|
||||
return True
|
||||
|
||||
@@ -3,6 +3,8 @@ from .models import AIResponse
|
||||
from ..concepts.context import Context
|
||||
from ..providers.base import BaseClientProvider
|
||||
|
||||
from .config import settings
|
||||
|
||||
|
||||
class SimpleMind:
|
||||
"""Main class for SimpleMind functionality."""
|
||||
|
||||
+30
-14
@@ -1,32 +1,48 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from simplemind.core.models import Conversation, AIResponse
|
||||
from simplemind.concepts.context import Context
|
||||
from simplemind.providers.openai import OpenAI
|
||||
from simplemind.providers.anthropic import Anthropic
|
||||
import logging
|
||||
|
||||
from .models import Conversation, AIResponse
|
||||
from ..concepts.context import Context
|
||||
|
||||
from .errors import ProviderError
|
||||
from .logger import logger
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self, api_key: str, context: Optional[Context] = None):
|
||||
self.api_key = api_key
|
||||
self.context = context or Context()
|
||||
self.providers = self._initialize_providers()
|
||||
def __init__(self, api_key=None):
|
||||
self.providers = {}
|
||||
|
||||
def _initialize_providers(self):
|
||||
return {
|
||||
"openai": OpenAI(api_key=self.api_key),
|
||||
"anthropic": Anthropic(api_key=self.api_key),
|
||||
# Auto-detect available API keys from environment
|
||||
api_keys = {
|
||||
"openai": os.getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.getenv("ANTHROPIC_API_KEY"),
|
||||
# Add other providers as needed
|
||||
}
|
||||
|
||||
# Initialize providers for which we have API keys
|
||||
for provider, key in api_keys.items():
|
||||
if key:
|
||||
self.providers[provider] = self._initialize_provider(provider, key)
|
||||
|
||||
def _initialize_provider(self, provider_name, api_key):
|
||||
if provider_name == "openai":
|
||||
from ..providers.openai import OpenAI
|
||||
|
||||
return OpenAI(api_key)
|
||||
elif provider_name == "anthropic":
|
||||
from ..providers.anthropic import Anthropic
|
||||
|
||||
return Anthropic(api_key)
|
||||
# Add other providers as needed
|
||||
|
||||
def create_conversation(
|
||||
self, provider: str = "openai", context: Optional[Context] = None
|
||||
) -> Conversation:
|
||||
if provider not in self.providers:
|
||||
raise ValueError(f"Provider '{provider}' not supported.")
|
||||
conversation_context = context or Context()
|
||||
return self.providers[provider].create_conversation(
|
||||
initial_message="Hello!", context=self.context.model_dump()
|
||||
initial_message="Hello!", context=conversation_context.model_dump()
|
||||
)
|
||||
|
||||
def _handle_api_error(self, error: Exception, operation: str):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic import Field, field_validator, ValidationError
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict, Set
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -10,12 +10,44 @@ class Settings(BaseSettings):
|
||||
default_model: str = Field("gpt-4", env="DEFAULT_MODEL")
|
||||
log_level: str = Field("INFO", env="LOG_LEVEL")
|
||||
|
||||
# Map of provider names to their required environment variables
|
||||
PROVIDER_REQUIREMENTS: Dict[str, str] = {
|
||||
"openai": "openai_api_key",
|
||||
"anthropic": "anthropic_api_key",
|
||||
"ollama": "ollama_host_url",
|
||||
}
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
def check_required(cls, v, info):
|
||||
if info.field_name in info.data and info.data[info.field_name] is None:
|
||||
raise ValueError(f"{info.field_name} is required")
|
||||
def check_empty_string(cls, v):
|
||||
if isinstance(v, str) and not v.strip():
|
||||
return None
|
||||
return v
|
||||
|
||||
def validate_provider(self, provider: str) -> bool:
|
||||
"""
|
||||
Validate that the necessary API key exists for a given provider.
|
||||
Raises ValueError if the provider is not properly configured.
|
||||
"""
|
||||
if provider not in self.PROVIDER_REQUIREMENTS:
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
required_key = self.PROVIDER_REQUIREMENTS[provider]
|
||||
if getattr(self, required_key) is None:
|
||||
raise ValueError(
|
||||
f"Missing API key for {provider}. "
|
||||
f"Please set {required_key.upper()} environment variable."
|
||||
)
|
||||
return True
|
||||
|
||||
@property
|
||||
def available_providers(self) -> Set[str]:
|
||||
"""Return a set of properly configured providers."""
|
||||
return {
|
||||
provider
|
||||
for provider, key in self.PROVIDER_REQUIREMENTS.items()
|
||||
if getattr(self, key) is not None
|
||||
}
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = False
|
||||
|
||||
+20
-10
@@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
from ..concepts.context import Context
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -12,24 +12,30 @@ class AIRequest(BaseModel):
|
||||
return self.text
|
||||
|
||||
|
||||
class AIResponse(BaseModel):
|
||||
text: str
|
||||
response: Any
|
||||
metadata: Dict[str, Any] = {}
|
||||
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str # "user", "assistant", "system"
|
||||
content: str
|
||||
created_at: datetime = datetime.now()
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
message: Message
|
||||
index: int = 0
|
||||
|
||||
|
||||
class AIResponse(BaseModel):
|
||||
choices: List[Choice]
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
"""Helper to get the first message content directly."""
|
||||
return self.choices[0].message.content
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
id: str
|
||||
messages: List[Message] = []
|
||||
context: Optional[Context] = None
|
||||
created_at: datetime = datetime.now()
|
||||
updated_at: datetime = datetime.now()
|
||||
|
||||
@@ -44,6 +50,10 @@ class Conversation(BaseModel):
|
||||
self.updated_at = datetime.now()
|
||||
return message
|
||||
|
||||
def set_context(self, context: Context):
|
||||
"""Sets the context for the conversation."""
|
||||
self.context = context
|
||||
|
||||
|
||||
class ConversationRequest(BaseModel):
|
||||
conversation_id: Optional[str] = None
|
||||
|
||||
@@ -10,6 +10,7 @@ from ..core.logger import logger
|
||||
|
||||
|
||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
DEFAULT_MAX_TOKENS = 4096
|
||||
|
||||
|
||||
class Anthropic(BaseClientProvider):
|
||||
@@ -22,6 +23,8 @@ class Anthropic(BaseClientProvider):
|
||||
self._api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not self._api_key:
|
||||
raise ValueError("Anthropic API key not provided.")
|
||||
logger.debug(f"API key length: {len(self._api_key) if self._api_key else 0}")
|
||||
|
||||
base_client = BaseAnthropic(api_key=self._api_key)
|
||||
self.client = instructor.from_anthropic(base_client)
|
||||
if not self.test_connection():
|
||||
@@ -55,12 +58,17 @@ class Anthropic(BaseClientProvider):
|
||||
params = {
|
||||
"messages": messages,
|
||||
"model": self.model,
|
||||
"max_tokens": DEFAULT_MAX_TOKENS,
|
||||
}
|
||||
if conversation.context:
|
||||
params["context"] = conversation.context
|
||||
params["context"] = (
|
||||
vars(conversation.context)
|
||||
if hasattr(conversation.context, "__dict__")
|
||||
else dict(conversation.context)
|
||||
)
|
||||
|
||||
try:
|
||||
completion = self.client.completions.create(**params)
|
||||
completion = self.client.completions.create(response_model=str, **params)
|
||||
response_text = completion.completion
|
||||
metadata = {"model": completion.model, "usage": completion.usage}
|
||||
logger.info("Generated response from Anthropic.")
|
||||
|
||||
@@ -20,27 +20,32 @@ class BaseClientProvider:
|
||||
|
||||
def login(self):
|
||||
"""Initializes the AI provider client."""
|
||||
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""Tests the connection to the AI provider client."""
|
||||
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def health_check(self):
|
||||
"""Checks the health of the AI provider client."""
|
||||
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
@property
|
||||
def available_models(self) -> List[str]:
|
||||
"""Returns the available models from the AI provider client."""
|
||||
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def message(self, message: str, **kwargs) -> AIResponse:
|
||||
"""Generates a response from the AI provider client."""
|
||||
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import List, Optional
|
||||
from openai import OpenAI as BaseOpenAI
|
||||
from ..core.errors import AuthenticationError, ProviderError
|
||||
from simplemind.core.config import settings
|
||||
from ..core.config import settings
|
||||
from ..core.logger import logger
|
||||
from .base import BaseClientProvider
|
||||
|
||||
DEFAULT_MODEL = "gpt-4"
|
||||
|
||||
|
||||
class OpenAI(BaseClientProvider):
|
||||
def __init__(self, model: str = "gpt-4", api_key: Optional[str] = None):
|
||||
@@ -47,7 +49,7 @@ class OpenAI(BaseClientProvider):
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise ProviderError(f"OpenAI API error: {e}")
|
||||
|
||||
def generate_response(self, conversation) -> str:
|
||||
def generate_response(self, conversation, **kwargs) -> str:
|
||||
"""Generate a response using the OpenAI API."""
|
||||
try:
|
||||
messages = [
|
||||
@@ -55,10 +57,24 @@ class OpenAI(BaseClientProvider):
|
||||
for msg in conversation.messages
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model, messages=messages
|
||||
)
|
||||
# Ensure we're using a valid model name, not the API key
|
||||
if not isinstance(self.model, str) or self.model.startswith("sk-"):
|
||||
logger.warning(
|
||||
f"Invalid model name detected. Falling back to {DEFAULT_MODEL!r}"
|
||||
)
|
||||
model_name = DEFAULT_MODEL
|
||||
else:
|
||||
model_name = self.model
|
||||
|
||||
return response.choices[0].message.content
|
||||
r = self.client.chat.completions.create(
|
||||
model=model_name, # Use the validated model name
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
)
|
||||
return r
|
||||
except Exception as e:
|
||||
self._handle_api_error(e)
|
||||
|
||||
def generate_text(self, conversation, **kwargs) -> str:
|
||||
"""Generate a text response using the OpenAI API."""
|
||||
return self.generate_response(conversation, **kwargs).choices[0].message.content
|
||||
|
||||
@@ -1,42 +1,15 @@
|
||||
import os
|
||||
from pprint import pprint
|
||||
from pydantic import BaseModel
|
||||
import simplemind
|
||||
from simplemind.concepts.context import Context
|
||||
from simplemind.plugins.kv import KVPlugin
|
||||
from simplemind.plugins.basic_memory import BasicMemoryPlugin
|
||||
from simplemind.chains.reverse_text import ReverseTextChain
|
||||
from simplemind.core.client import Client
|
||||
|
||||
from simplemind.core import settings
|
||||
|
||||
class CustomContext(Context):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_plugin("kv", KVPlugin())
|
||||
# self.add_plugin("basic_memory", BasicMemoryPlugin())
|
||||
print(settings)
|
||||
|
||||
# Initialize client without explicit API key
|
||||
ai = simplemind.Client()
|
||||
|
||||
# Initialize context and client
|
||||
ctx = CustomContext()
|
||||
aiclient = Client(
|
||||
context=ctx,
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
)
|
||||
print(ai.available_models)
|
||||
|
||||
# Test connection and available models
|
||||
print(aiclient.available_models)
|
||||
|
||||
# Example usage
|
||||
conversation = aiclient.create_conversation(provider="anthropic")
|
||||
conversation.set_context(ctx)
|
||||
response = aiclient.send_message(
|
||||
conversation, "Who is Kenneth Reitz?", provider="anthropic"
|
||||
)
|
||||
# response = aiclient.send_message(
|
||||
# conversation, "Who is Kenneth Reitz?", provider="openai"
|
||||
# )
|
||||
# print(response)
|
||||
|
||||
# reverse_chain = ReverseTextChain()
|
||||
# result = reverse_chain.run("Hello, World!")
|
||||
# print(result) # Output: !dlroW ,olleH
|
||||
# The provider will automatically use OPENAI_API_KEY from environment
|
||||
conversation = ai.create_conversation(provider="openai")
|
||||
response = ai.send_message(conversation, "Who is Kenneth Reitz?")
|
||||
print(response)
|
||||
|
||||
Reference in New Issue
Block a user