mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 06:46:18 +00:00
cleanup
This commit is contained in:
+18
-1
@@ -1 +1,18 @@
|
|||||||
from .core import SimpleMind
|
from .models import Conversation
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleMind:
|
||||||
|
def create_conversation(self, *, llm_model=None, llm_provider=None):
|
||||||
|
return Conversation()
|
||||||
|
|
||||||
|
def structured_response(
|
||||||
|
self, *, llm_model=None, llm_provider=None, response_model=None
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def create_conversation():
|
||||||
|
return SimpleMind().create_conversation()
|
||||||
|
|
||||||
|
|
||||||
|
globals().update(locals())
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def decide(self, context, *args, **kwargs):
|
|
||||||
pass
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
|
|
||||||
class BaseChain(ABC):
|
|
||||||
"""Abstract base class for implementing chain operations.
|
|
||||||
|
|
||||||
A Chain represents a processing step that can be executed on input data
|
|
||||||
and should be implemented by concrete classes to define specific behaviors.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def run(self, input_data: str) -> str:
|
|
||||||
"""Execute the chain's operation on the input data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_data: The input string to be processed by the chain.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The processed output string.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the input data is invalid or cannot be processed.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
from .base import BaseChain
|
|
||||||
|
|
||||||
|
|
||||||
class ReverseTextChain(BaseChain):
|
|
||||||
"""Chain that reverses input text.
|
|
||||||
|
|
||||||
This chain takes a text input and returns it reversed. For example,
|
|
||||||
"hello" becomes "olleh".
|
|
||||||
"""
|
|
||||||
|
|
||||||
def run(self, input_data: str) -> str:
|
|
||||||
"""Reverse the input text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_data: The text to reverse.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The reversed text.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If input_data is not a string.
|
|
||||||
"""
|
|
||||||
if not isinstance(input_data, str):
|
|
||||||
raise TypeError("Input must be a string")
|
|
||||||
return input_data[::-1]
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
from typing import Dict, Any
|
|
||||||
from simplemind.plugins.base import BasePlugin
|
|
||||||
|
|
||||||
|
|
||||||
class Context(BaseModel):
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
|
||||||
plugins: Dict[str, BasePlugin] = {}
|
|
||||||
|
|
||||||
def add_plugin(self, name: str, plugin: BasePlugin):
|
|
||||||
self.plugins[name] = plugin
|
|
||||||
|
|
||||||
def execute_plugin(self, name: str, *args, **kwargs):
|
|
||||||
if name in self.plugins:
|
|
||||||
return self.plugins[name].execute(self, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Plugin '{name}' not found in context.")
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
from typing import Dict, Any, Optional
|
|
||||||
from .models import AIResponse
|
|
||||||
from ..concepts.context import Context
|
|
||||||
from ..providers.base import BaseClientProvider
|
|
||||||
|
|
||||||
from .config import settings
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleMind:
|
|
||||||
"""Main class for SimpleMind functionality."""
|
|
||||||
|
|
||||||
def __init__(self, provider: str = "openai", context: Optional[Context] = None):
|
|
||||||
"""Initialize SimpleMind with the specified provider."""
|
|
||||||
self.provider = provider
|
|
||||||
self.context = context or Context()
|
|
||||||
self._client = self._get_provider()
|
|
||||||
|
|
||||||
def _get_provider(self) -> BaseClientProvider:
|
|
||||||
"""Get the appropriate provider client."""
|
|
||||||
from ..providers.openai import OpenAI
|
|
||||||
from ..providers.anthropic import Anthropic
|
|
||||||
|
|
||||||
# Initialize providers based on environment variables.
|
|
||||||
providers = {}
|
|
||||||
if settings.openai_api_key:
|
|
||||||
providers.update({"openai": OpenAI})
|
|
||||||
if settings.anthropic_api_key:
|
|
||||||
providers.update({"anthropic": Anthropic})
|
|
||||||
|
|
||||||
if self.provider not in providers:
|
|
||||||
raise ValueError(
|
|
||||||
f"Provider '{self.provider}' not supported. Available providers: {list(providers.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return providers[self.provider]()
|
|
||||||
|
|
||||||
def generate(self, prompt: str, **kwargs) -> AIResponse:
|
|
||||||
"""Generate a response using the configured provider."""
|
|
||||||
|
|
||||||
return self._client.message(prompt, **kwargs)
|
|
||||||
|
|
||||||
def create_conversation(self) -> str:
|
|
||||||
"""Create a new conversation and return its ID."""
|
|
||||||
|
|
||||||
initial_message = "You are a helpful assistant."
|
|
||||||
|
|
||||||
conversation = self._client.create_conversation(initial_message)
|
|
||||||
return conversation
|
|
||||||
|
|
||||||
def add_message(self, conversation_id: str, message: str) -> AIResponse:
|
|
||||||
"""Send a message in an existing conversation."""
|
|
||||||
|
|
||||||
return self._client.add_message(conversation_id, message)
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from .models import Conversation, AIResponse
|
|
||||||
from ..concepts.context import Context
|
|
||||||
|
|
||||||
from .errors import ProviderError
|
|
||||||
from .logger import logger
|
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
|
||||||
def __init__(self, api_key=None):
|
|
||||||
self.providers = {}
|
|
||||||
|
|
||||||
# Auto-detect available API keys from environment
|
|
||||||
api_keys = {
|
|
||||||
"openai": os.getenv("OPENAI_API_KEY"),
|
|
||||||
"anthropic": os.getenv("ANTHROPIC_API_KEY"),
|
|
||||||
# Add other providers as needed
|
|
||||||
}
|
|
||||||
|
|
||||||
# Initialize providers for which we have API keys
|
|
||||||
for provider, key in api_keys.items():
|
|
||||||
if key:
|
|
||||||
self.providers[provider] = self._initialize_provider(provider, key)
|
|
||||||
|
|
||||||
def _initialize_provider(self, provider_name, api_key):
|
|
||||||
if provider_name == "openai":
|
|
||||||
from ..providers.openai import OpenAI
|
|
||||||
|
|
||||||
return OpenAI(api_key)
|
|
||||||
elif provider_name == "anthropic":
|
|
||||||
from ..providers.anthropic import Anthropic
|
|
||||||
|
|
||||||
return Anthropic(api_key)
|
|
||||||
# Add other providers as needed
|
|
||||||
|
|
||||||
def create_conversation(
|
|
||||||
self, provider: str = "openai", context: Optional[Context] = None
|
|
||||||
) -> Conversation:
|
|
||||||
if provider not in self.providers:
|
|
||||||
raise ValueError(f"Provider '{provider}' not supported.")
|
|
||||||
conversation_context = context or Context()
|
|
||||||
return self.providers[provider].create_conversation(
|
|
||||||
initial_message="Hello!", context=conversation_context.model_dump()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_api_error(self, error: Exception, operation: str):
|
|
||||||
"""Handle API errors in a consistent way."""
|
|
||||||
logger.error(f"Error during {operation}: {str(error)}")
|
|
||||||
raise ProviderError(f"Failed to {operation}: {str(error)}") from error
|
|
||||||
|
|
||||||
def send_message(
|
|
||||||
self, conversation: Conversation, message: str, provider: str = "openai"
|
|
||||||
) -> AIResponse:
|
|
||||||
if provider not in self.providers:
|
|
||||||
raise ValueError(f"Provider '{provider}' not supported.")
|
|
||||||
try:
|
|
||||||
return self.providers[provider].send_message(conversation.id, message)
|
|
||||||
except Exception as e:
|
|
||||||
self._handle_api_error(e, "send message")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available_models(self):
|
|
||||||
available = {}
|
|
||||||
for name, provider in self.providers.items():
|
|
||||||
available[name] = provider.available_models
|
|
||||||
return available
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
from pydantic import Field, field_validator, ValidationError
|
|
||||||
from pydantic_settings import BaseSettings
|
|
||||||
from typing import Optional, Dict, Set
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
openai_api_key: Optional[str] = Field(None, env="OPENAI_API_KEY")
|
|
||||||
anthropic_api_key: Optional[str] = Field(None, env="ANTHROPIC_API_KEY")
|
|
||||||
ollama_host_url: Optional[str] = Field(None, env="OLLAMA_HOST_URL")
|
|
||||||
default_model: str = Field("gpt-4", env="DEFAULT_MODEL")
|
|
||||||
log_level: str = Field("INFO", env="LOG_LEVEL")
|
|
||||||
|
|
||||||
# Map of provider names to their required environment variables
|
|
||||||
PROVIDER_REQUIREMENTS: Dict[str, str] = {
|
|
||||||
"openai": "openai_api_key",
|
|
||||||
"anthropic": "anthropic_api_key",
|
|
||||||
"ollama": "ollama_host_url",
|
|
||||||
}
|
|
||||||
|
|
||||||
@field_validator("*", mode="before")
|
|
||||||
def check_empty_string(cls, v):
|
|
||||||
if isinstance(v, str) and not v.strip():
|
|
||||||
return None
|
|
||||||
return v
|
|
||||||
|
|
||||||
def validate_provider(self, provider: str) -> bool:
|
|
||||||
"""
|
|
||||||
Validate that the necessary API key exists for a given provider.
|
|
||||||
Raises ValueError if the provider is not properly configured.
|
|
||||||
"""
|
|
||||||
if provider not in self.PROVIDER_REQUIREMENTS:
|
|
||||||
raise ValueError(f"Unknown provider: {provider}")
|
|
||||||
|
|
||||||
required_key = self.PROVIDER_REQUIREMENTS[provider]
|
|
||||||
if getattr(self, required_key) is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"Missing API key for {provider}. "
|
|
||||||
f"Please set {required_key.upper()} environment variable."
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available_providers(self) -> Set[str]:
|
|
||||||
"""Return a set of properly configured providers."""
|
|
||||||
return {
|
|
||||||
provider
|
|
||||||
for provider, key in self.PROVIDER_REQUIREMENTS.items()
|
|
||||||
if getattr(self, key) is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
env_file = ".env"
|
|
||||||
case_sensitive = False
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
class SimpleMindError(Exception):
|
|
||||||
"""Base exception for SimpleMind errors."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderError(SimpleMindError):
|
|
||||||
"""Raised when there's an error with the AI provider."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigurationError(SimpleMindError):
|
|
||||||
"""Raised when there's a configuration error."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationError(SimpleMindError):
|
|
||||||
"""Raised when authentication fails."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
def setup_logger(name: str) -> logging.Logger:
|
|
||||||
logger = logging.getLogger(name)
|
|
||||||
if not logger.hasHandlers():
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
handler = logging.StreamHandler()
|
|
||||||
formatter = logging.Formatter('[%(asctime)s] %(levelname)s - %(message)s')
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(handler)
|
|
||||||
return logger
|
|
||||||
|
|
||||||
# Initialize a global logger
|
|
||||||
logger = setup_logger("simplemind")
|
|
||||||
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from ..concepts.context import Context
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
class AIRequest(BaseModel):
|
|
||||||
text: str
|
|
||||||
parameters: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.text
|
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
|
||||||
role: str # "user", "assistant", "system"
|
|
||||||
content: str
|
|
||||||
created_at: datetime = datetime.now()
|
|
||||||
|
|
||||||
|
|
||||||
class Choice(BaseModel):
|
|
||||||
message: Message
|
|
||||||
index: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class AIResponse(BaseModel):
|
|
||||||
choices: List[Choice]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def content(self) -> str:
|
|
||||||
"""Helper to get the first message content directly."""
|
|
||||||
return self.choices[0].message.content
|
|
||||||
|
|
||||||
|
|
||||||
class Conversation(BaseModel):
|
|
||||||
id: str
|
|
||||||
messages: List[Message] = []
|
|
||||||
context: Optional[Context] = None
|
|
||||||
created_at: datetime = datetime.now()
|
|
||||||
updated_at: datetime = datetime.now()
|
|
||||||
|
|
||||||
def get_messages(self) -> List[Message]:
|
|
||||||
"""Returns a list of messages in the conversation."""
|
|
||||||
return self.messages
|
|
||||||
|
|
||||||
def add_message(self, role: str, content: str) -> Message:
|
|
||||||
"""Adds a new message to the conversation."""
|
|
||||||
message = Message(role=role, content=content)
|
|
||||||
self.messages.append(message)
|
|
||||||
self.updated_at = datetime.now()
|
|
||||||
return message
|
|
||||||
|
|
||||||
def set_context(self, context: Context):
|
|
||||||
"""Sets the context for the conversation."""
|
|
||||||
self.context = context
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationRequest(BaseModel):
|
|
||||||
conversation_id: Optional[str] = None
|
|
||||||
message: str
|
|
||||||
context_update: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationResponse(BaseModel):
|
|
||||||
conversation_id: str
|
|
||||||
messages: List[Message]
|
|
||||||
metadata: Dict[str, Any] = {}
|
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
import uuid
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from .utils import find_provider
|
||||||
|
|
||||||
|
|
||||||
|
class SMBaseModel(BaseModel):
|
||||||
|
date_created: datetime = Field(default_factory=datetime.now)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"<{self.__class__.__name__} {self.model_dump_json()}>"
|
||||||
|
|
||||||
|
|
||||||
|
class Message(SMBaseModel):
|
||||||
|
role: str
|
||||||
|
text: str
|
||||||
|
meta: Dict[str, Any] = {}
|
||||||
|
raw: Optional[Any] = None
|
||||||
|
llm_model: Optional[str] = None
|
||||||
|
llm_provider: Optional[str] = None
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"<Message role={self.role} text={self.text!r}>"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_raw_response(cls, *, text, raw):
|
||||||
|
self = cls()
|
||||||
|
self.text = text
|
||||||
|
self.raw = raw
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class Conversation(SMBaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
messages: List[Message] = []
|
||||||
|
llm_model: Optional[str] = None
|
||||||
|
llm_provider: Optional[str] = None
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"<Conversation id={self.id!r}>"
|
||||||
|
|
||||||
|
def add_message(self, role: str, text: str, meta: Dict[str, Any] = {}):
|
||||||
|
"""Add a new message to the conversation."""
|
||||||
|
self.messages.append(Message(role=role, text=text, meta=meta))
|
||||||
|
|
||||||
|
def send(
|
||||||
|
self, llm_model: Optional[str] = None, llm_provider: Optional[str] = None
|
||||||
|
) -> Message:
|
||||||
|
"""Send the conversation to the LLM."""
|
||||||
|
provider = find_provider(llm_provider or self.llm_provider)
|
||||||
|
return provider.send_conversation(self)
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
|
||||||
role: str
|
|
||||||
content: str
|
|
||||||
created_at: datetime = datetime.now()
|
|
||||||
|
|
||||||
class Conversation(BaseModel):
|
|
||||||
id: str
|
|
||||||
messages: List[Message] = []
|
|
||||||
context: Dict[str, Any] = {}
|
|
||||||
created_at: datetime = datetime.now()
|
|
||||||
updated_at: datetime = datetime.now()
|
|
||||||
|
|
||||||
def add_message(self, role: str, content: str) -> Message:
|
|
||||||
message = Message(role=role, content=content)
|
|
||||||
self.messages.append(message)
|
|
||||||
self.updated_at = datetime.now()
|
|
||||||
return message
|
|
||||||
|
|
||||||
class AIResponse(BaseModel):
|
|
||||||
text: str
|
|
||||||
response: Any
|
|
||||||
metadata: Dict[str, Any] = {}
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin(ABC):
|
|
||||||
"""Base class for all SimpleMind plugins."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.enabled: bool = True
|
|
||||||
self.name: str = self.__class__.__name__
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def process(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Process the context and return modified context.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: The current conversation context
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Modified context dictionary
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def enable(self) -> None:
|
|
||||||
"""Enable the plugin."""
|
|
||||||
self.enabled = True
|
|
||||||
|
|
||||||
def disable(self) -> None:
|
|
||||||
"""Disable the plugin."""
|
|
||||||
self.enabled = False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_enabled(self) -> bool:
|
|
||||||
return self.enabled
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
from .base import BasePlugin
|
|
||||||
|
|
||||||
|
|
||||||
class BasicMemoryPlugin(BasePlugin):
|
|
||||||
def __init__(self):
|
|
||||||
self.memory = []
|
|
||||||
|
|
||||||
def execute(self, context, message):
|
|
||||||
self.memory.append(message)
|
|
||||||
return self.memory
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
from .base import BasePlugin
|
|
||||||
|
|
||||||
|
|
||||||
class KVPlugin(BasePlugin):
|
|
||||||
def __init__(self):
|
|
||||||
self.store = {}
|
|
||||||
|
|
||||||
def process(self, key: str, value=None):
|
|
||||||
"""
|
|
||||||
Get or set a value in the key-value store.
|
|
||||||
If value is None, returns the value for the key.
|
|
||||||
If value is provided, sets the value for the key and returns it.
|
|
||||||
"""
|
|
||||||
if value is None:
|
|
||||||
return self.store.get(key)
|
|
||||||
|
|
||||||
self.store[key] = value
|
|
||||||
return value
|
|
||||||
@@ -1,18 +1,3 @@
|
|||||||
from ..core.config import settings
|
from .openai import OpenAI
|
||||||
|
|
||||||
__all__ = []
|
providers = [OpenAI]
|
||||||
|
|
||||||
if settings.anthropic_api_key:
|
|
||||||
from .anthropic import Anthropic
|
|
||||||
|
|
||||||
__all__.append("Anthropic")
|
|
||||||
|
|
||||||
if settings.openai_api_key:
|
|
||||||
from .openai import OpenAI
|
|
||||||
|
|
||||||
__all__.append("OpenAI")
|
|
||||||
|
|
||||||
if settings.ollama_host_url:
|
|
||||||
from .ollama import Ollama
|
|
||||||
|
|
||||||
__all__.append("Ollama")
|
|
||||||
|
|||||||
@@ -1,80 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import instructor
|
|
||||||
from anthropic import Anthropic as BaseAnthropic
|
|
||||||
|
|
||||||
from .base import BaseClientProvider
|
|
||||||
from ..core.models import AIResponse, Conversation
|
|
||||||
from ..core.logger import logger
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
|
||||||
DEFAULT_MAX_TOKENS = 4096
|
|
||||||
|
|
||||||
|
|
||||||
class Anthropic(BaseClientProvider):
|
|
||||||
def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
|
|
||||||
super().__init__(model=model, api_key=api_key)
|
|
||||||
self.login()
|
|
||||||
|
|
||||||
def login(self):
|
|
||||||
if not self._api_key:
|
|
||||||
self._api_key = os.getenv("ANTHROPIC_API_KEY")
|
|
||||||
if not self._api_key:
|
|
||||||
raise ValueError("Anthropic API key not provided.")
|
|
||||||
logger.debug(f"API key length: {len(self._api_key) if self._api_key else 0}")
|
|
||||||
|
|
||||||
base_client = BaseAnthropic(api_key=self._api_key)
|
|
||||||
self.client = instructor.from_anthropic(base_client)
|
|
||||||
if not self.test_connection():
|
|
||||||
raise ConnectionError("Failed to connect to Anthropic API.")
|
|
||||||
logger.info("Logged in to Anthropic successfully.")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available_models(self) -> List[str]:
|
|
||||||
try:
|
|
||||||
return [
|
|
||||||
"claude-3-opus-20240229",
|
|
||||||
"claude-3-5-sonnet-20240620",
|
|
||||||
"claude-3-haiku-20240307",
|
|
||||||
]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching models: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def test_connection(self) -> bool:
|
|
||||||
models = self.available_models
|
|
||||||
if models:
|
|
||||||
logger.info(f"Available models: {models}")
|
|
||||||
return True
|
|
||||||
logger.warning("No available models found.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def generate_response(self, conversation: Conversation) -> AIResponse:
|
|
||||||
messages = [
|
|
||||||
{"role": msg.role, "content": msg.content} for msg in conversation.messages
|
|
||||||
]
|
|
||||||
params = {
|
|
||||||
"messages": messages,
|
|
||||||
"model": self.model,
|
|
||||||
"max_tokens": DEFAULT_MAX_TOKENS,
|
|
||||||
}
|
|
||||||
if conversation.context:
|
|
||||||
params["context"] = (
|
|
||||||
vars(conversation.context)
|
|
||||||
if hasattr(conversation.context, "__dict__")
|
|
||||||
else dict(conversation.context)
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
completion = self.client.completions.create(response_model=str, **params)
|
|
||||||
response_text = completion.completion
|
|
||||||
metadata = {"model": completion.model, "usage": completion.usage}
|
|
||||||
logger.info("Generated response from Anthropic.")
|
|
||||||
return AIResponse(
|
|
||||||
text=response_text, response=completion, metadata=metadata
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error generating response: {e}")
|
|
||||||
raise e
|
|
||||||
@@ -1,112 +0,0 @@
|
|||||||
# import logging
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from ..core.models import AIResponse, Conversation, Message
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_MODEL = "gpt-4o"
|
|
||||||
|
|
||||||
|
|
||||||
class BaseClientProvider:
|
|
||||||
|
|
||||||
def __init__(self, *, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
|
|
||||||
# self.logger = logging.getLogger(self.__class__.__name__)
|
|
||||||
self.client = None
|
|
||||||
self.model = model
|
|
||||||
self._api_key = api_key
|
|
||||||
self.conversations: Dict[str, Conversation] = {}
|
|
||||||
|
|
||||||
def login(self):
|
|
||||||
"""Initializes the AI provider client."""
|
|
||||||
|
|
||||||
msg = "This method must be implemented by the AI provider client."
|
|
||||||
raise NotImplementedError(msg)
|
|
||||||
|
|
||||||
def test_connection(self) -> bool:
|
|
||||||
"""Tests the connection to the AI provider client."""
|
|
||||||
|
|
||||||
msg = "This method must be implemented by the AI provider client."
|
|
||||||
raise NotImplementedError(msg)
|
|
||||||
|
|
||||||
def health_check(self):
|
|
||||||
"""Checks the health of the AI provider client."""
|
|
||||||
|
|
||||||
msg = "This method must be implemented by the AI provider client."
|
|
||||||
raise NotImplementedError(msg)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available_models(self) -> List[str]:
|
|
||||||
"""Returns the available models from the AI provider client."""
|
|
||||||
|
|
||||||
msg = "This method must be implemented by the AI provider client."
|
|
||||||
raise NotImplementedError(msg)
|
|
||||||
|
|
||||||
def message(self, message: str, **kwargs) -> AIResponse:
|
|
||||||
"""Generates a response from the AI provider client."""
|
|
||||||
|
|
||||||
msg = "This method must be implemented by the AI provider client."
|
|
||||||
raise NotImplementedError(msg)
|
|
||||||
|
|
||||||
# Uncomment and implement additional methods as needed
|
|
||||||
# def features(self):
|
|
||||||
# """Returns the features of the AI provider client."""
|
|
||||||
# msg = "This method must be implemented by the AI provider client."
|
|
||||||
# raise NotImplementedError(msg)
|
|
||||||
|
|
||||||
# def structured_response(self, model, message, **kwargs):
|
|
||||||
# pass
|
|
||||||
|
|
||||||
# def structured_conversation(self, model, message, **kwargs):
|
|
||||||
# pass
|
|
||||||
|
|
||||||
# def single_message(self, model, message, **kwargs):
|
|
||||||
# return self.generate_response(message)
|
|
||||||
|
|
||||||
# def start_conversation(self, model, message, **kwargs):
|
|
||||||
# pass
|
|
||||||
|
|
||||||
def create_conversation(
|
|
||||||
self, initial_message: str, context: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Conversation:
|
|
||||||
conv_id = str(uuid.uuid4())
|
|
||||||
conversation = Conversation(
|
|
||||||
id=conv_id,
|
|
||||||
messages=[Message(role="user", content=initial_message)],
|
|
||||||
context=context or {},
|
|
||||||
)
|
|
||||||
self.conversations[conv_id] = conversation
|
|
||||||
return conversation
|
|
||||||
|
|
||||||
def send_message(
|
|
||||||
self,
|
|
||||||
conversation_id: str,
|
|
||||||
message: str,
|
|
||||||
context_update: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> AIResponse:
|
|
||||||
if conversation_id not in self.conversations:
|
|
||||||
raise ValueError("Conversation ID does not exist.")
|
|
||||||
|
|
||||||
conversation = self.conversations[conversation_id]
|
|
||||||
conversation.messages.append(Message(role="user", content=message))
|
|
||||||
|
|
||||||
if context_update:
|
|
||||||
conversation.context.update(context_update)
|
|
||||||
|
|
||||||
response = self.generate_response(conversation)
|
|
||||||
conversation.messages.append(
|
|
||||||
Message(role="assistant", content=response.choices[0].message.content)
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
|
|
||||||
def generate_response(self, conversation: Conversation) -> AIResponse:
|
|
||||||
"""Generates a response based on the conversation."""
|
|
||||||
raise NotImplementedError(
|
|
||||||
"This method must be implemented by the AI provider client."
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_conversation(self, conversation_id: str) -> Conversation:
|
|
||||||
if conversation_id not in self.conversations:
|
|
||||||
raise ValueError("Conversation ID does not exist.")
|
|
||||||
return self.conversations[conversation_id]
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
import os
|
|
||||||
import uuid
|
|
||||||
from typing import List, Optional, Dict, Any
|
|
||||||
from ..core.errors import ProviderError
|
|
||||||
from ..core.logger import logger
|
|
||||||
from .base import BaseClientProvider
|
|
||||||
from some_module import BaseOllama # Replace 'some_module' with the actual module name
|
|
||||||
|
|
||||||
TIMEOUT = 60
|
|
||||||
|
|
||||||
|
|
||||||
class Ollama(BaseClientProvider):
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.login()
|
|
||||||
self.conversation = []
|
|
||||||
|
|
||||||
def login(self):
|
|
||||||
"""Initialize Ollama client, with Instructor enabled."""
|
|
||||||
if not os.environ.get("OLLAMA_HOST_URL"):
|
|
||||||
raise ValueError("Please set the OLLAMA_HOST_URL environment variable")
|
|
||||||
|
|
||||||
if not os.environ.get("OLLAMA_MODEL"):
|
|
||||||
raise ValueError("Please set the OLLAMA_MODEL environment variable")
|
|
||||||
else:
|
|
||||||
self.model = os.environ.get("OLLAMA_MODEL")
|
|
||||||
|
|
||||||
self.client = BaseOllama(
|
|
||||||
timeout=TIMEOUT, host=os.environ.get("OLLAMA_HOST_URL")
|
|
||||||
)
|
|
||||||
assert self.test_connection()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available_models(self):
|
|
||||||
"""Returns the available models from the Ollama client."""
|
|
||||||
|
|
||||||
def gen():
|
|
||||||
for model in self.client.list().get("models"):
|
|
||||||
yield model
|
|
||||||
|
|
||||||
return [g for g in gen()]
|
|
||||||
|
|
||||||
def test_connection(self):
|
|
||||||
"""Test the connection to Ollama. Returns True if successful."""
|
|
||||||
|
|
||||||
return bool(len(self.available_models))
|
|
||||||
|
|
||||||
def generate_text(self, prompt, *, response_model=False, **kwargs):
|
|
||||||
use_instructor = bool(response_model)
|
|
||||||
|
|
||||||
client = self.instructor_client if use_instructor else self.client
|
|
||||||
|
|
||||||
# Parameters for the Ollama client.
|
|
||||||
params = {
|
|
||||||
"prompt": prompt,
|
|
||||||
"model": self.model,
|
|
||||||
}
|
|
||||||
params.update(kwargs)
|
|
||||||
|
|
||||||
if use_instructor:
|
|
||||||
params["response_model"] = response_model
|
|
||||||
|
|
||||||
# Make the request to Ollama.
|
|
||||||
completion = client.generate(**params)
|
|
||||||
if use_instructor:
|
|
||||||
return completion.model_dump()
|
|
||||||
|
|
||||||
else:
|
|
||||||
return AIResponse(
|
|
||||||
response=completion,
|
|
||||||
text=completion.get("response"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def message(
|
|
||||||
self, message=None, message_history=None, response_model=False, **kwargs
|
|
||||||
):
|
|
||||||
"""Generates a response from the Ollama client."""
|
|
||||||
use_instructor = bool(response_model)
|
|
||||||
|
|
||||||
client = self.instructor_client if use_instructor else self.client
|
|
||||||
|
|
||||||
# Parameters for the Ollama client.
|
|
||||||
all_messages = []
|
|
||||||
if message_history:
|
|
||||||
all_messages.extend(message_history)
|
|
||||||
if message:
|
|
||||||
all_messages.append({"role": "user", "content": message})
|
|
||||||
params = {
|
|
||||||
"messages": all_messages,
|
|
||||||
"model": self.model,
|
|
||||||
}
|
|
||||||
params.update(kwargs)
|
|
||||||
|
|
||||||
if use_instructor:
|
|
||||||
params["response_model"] = response_model
|
|
||||||
|
|
||||||
# Make the request to Ollama.
|
|
||||||
completion = client.chat(**params)
|
|
||||||
if use_instructor:
|
|
||||||
return completion.model_dump()
|
|
||||||
|
|
||||||
else:
|
|
||||||
return AIResponse(
|
|
||||||
response=completion,
|
|
||||||
text=completion.get("message").get("content"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def start_conversation(self):
|
|
||||||
return Conversation(self)
|
|
||||||
@@ -1,69 +1,57 @@
|
|||||||
from typing import List, Optional
|
import openai as oa
|
||||||
import openai
|
import instructor
|
||||||
from ..core.errors import AuthenticationError, ProviderError
|
|
||||||
from ..core.config import settings
|
|
||||||
from ..core.logger import logger
|
|
||||||
from .base import BaseClientProvider
|
|
||||||
|
|
||||||
DEFAULT_MODEL = "gpt-4o"
|
# from ..models import Conversation, Message
|
||||||
|
from ..settings import settings
|
||||||
|
|
||||||
|
PROVIDER_NAME = "openai"
|
||||||
|
DEFAULT_MODEL = "gpt-4o-mini"
|
||||||
|
|
||||||
|
|
||||||
class OpenAI(BaseClientProvider):
|
class OpenAI:
|
||||||
def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
|
__name__ = PROVIDER_NAME
|
||||||
super().__init__(model=model, api_key=api_key)
|
DEFAULT_MODEL = DEFAULT_MODEL
|
||||||
self.login()
|
|
||||||
|
|
||||||
def login(self) -> None:
|
def __init__(self, api_key: str = None):
|
||||||
if not self._api_key:
|
self.api_key = api_key or settings.OPENAI_API_KEY
|
||||||
self._api_key = settings.openai_api_key
|
|
||||||
if not self._api_key:
|
|
||||||
raise AuthenticationError("OpenAI API key not provided")
|
|
||||||
|
|
||||||
try:
|
|
||||||
openai.api_key = self._api_key
|
|
||||||
if not self.test_connection():
|
|
||||||
raise ProviderError("Failed to connect to OpenAI API")
|
|
||||||
logger.info("Successfully connected to OpenAI")
|
|
||||||
except Exception as e:
|
|
||||||
raise ProviderError(f"OpenAI initialization failed: {str(e)}")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def available_models(self) -> List[str]:
|
def client(self):
|
||||||
try:
|
"""The raw OpenAI client."""
|
||||||
models = openai.models.list()
|
return oa.OpenAI(api_key=self.api_key)
|
||||||
return [model.id for model in models["data"]]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching models: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def test_connection(self):
|
@property
|
||||||
"""Test the connection to OpenAI API."""
|
def structured_client(self):
|
||||||
try:
|
"""A client patched with Instructor."""
|
||||||
openai.models.list()
|
return instructor.patch(oa.OpenAI(api_key=self.api_key))
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Connection test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _handle_api_error(self, e: Exception) -> None:
|
def send_conversation(self, conversation: "Conversation"):
|
||||||
"""Handle API errors."""
|
"""Send a conversation to the OpenAI API."""
|
||||||
logger.error(f"OpenAI API error: {e}")
|
from ..models import Message
|
||||||
raise ProviderError(f"OpenAI API error: {e}")
|
|
||||||
|
|
||||||
def add_message(self, conversation_id, message, *args, **kwargs) -> str:
|
messages = [
|
||||||
"""Generate a response using the OpenAI API."""
|
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||||
try:
|
]
|
||||||
# Create a client instance using the API key
|
|
||||||
client = openai.OpenAI(api_key=self._api_key)
|
|
||||||
|
|
||||||
# Create the message for the conversation
|
response = self.client.chat.completions.create(
|
||||||
messages = [{"role": "user", "content": message}]
|
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
|
||||||
|
)
|
||||||
|
|
||||||
# Use the new API syntax
|
# Get the response content from the OpenAI response
|
||||||
response = client.chat.completions.create(
|
assistant_message = response.choices[0].message
|
||||||
model=self.model, messages=messages, *args, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
return response.choices[0].message.content
|
# Create and return a properly formatted Message instance
|
||||||
except Exception as e:
|
return Message(
|
||||||
self._handle_api_error(e)
|
role="assistant",
|
||||||
|
text=assistant_message.content,
|
||||||
|
raw=response,
|
||||||
|
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||||
|
llm_provider=PROVIDER_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
def structured_response(self, model, response_model, **kwargs):
|
||||||
|
client = instructor.patch(oa.OpenAI(api_key=self.api_key))
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=model, response_model=response_model, **kwargs
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|||||||
@@ -0,0 +1,10 @@
|
|||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY")
|
||||||
|
ANTHROPIC_API_KEY: str = Field(..., env="ANTHROPIC_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
from simplemind.providers.openai import OpenAI
|
|
||||||
from simplemind.core.errors import AuthenticationError, ProviderError
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIProvider(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.api_key = "test_key"
|
|
||||||
|
|
||||||
@patch("simplemind.providers.openai.openai")
|
|
||||||
def test_initialization(self, mock_openai):
|
|
||||||
mock_openai.Model.list.return_value = {"data": ["gpt-4"]}
|
|
||||||
provider = OpenAI(api_key=self.api_key)
|
|
||||||
self.assertIsNotNone(provider.client)
|
|
||||||
mock_openai.Model.list.assert_called_once()
|
|
||||||
|
|
||||||
def test_missing_api_key(self):
|
|
||||||
with self.assertRaises(AuthenticationError):
|
|
||||||
OpenAI(api_key=None)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
from .providers import providers
|
||||||
|
|
||||||
|
|
||||||
|
def find_provider(provider_name: str):
|
||||||
|
"""Find a provider by name."""
|
||||||
|
for provider_class in providers:
|
||||||
|
if provider_class.__name__.lower() == provider_name.lower():
|
||||||
|
# Instantiate the provider
|
||||||
|
return provider_class()
|
||||||
|
raise ValueError(f"Provider {provider_name} not found")
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
import faiss
|
|
||||||
import numpy as np
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
|
|
||||||
class FAISSStore:
|
|
||||||
def __init__(self, dimension: int):
|
|
||||||
self.dimension = dimension
|
|
||||||
self.index = faiss.IndexFlatL2(dimension)
|
|
||||||
self.ids = []
|
|
||||||
|
|
||||||
def add_embeddings(self, embeddings: np.ndarray, ids: List[str]):
|
|
||||||
self.index.add(embeddings)
|
|
||||||
self.ids.extend(ids)
|
|
||||||
|
|
||||||
def search(self, query_embedding: np.ndarray, top_k: int = 5):
|
|
||||||
distances, indices = self.index.search(query_embedding, top_k)
|
|
||||||
results = [(self.ids[idx], distances[i]) for i, idx in enumerate(indices[0])]
|
|
||||||
return results
|
|
||||||
@@ -1,9 +1,7 @@
|
|||||||
from simplemind import SimpleMind
|
import simplemind as sm
|
||||||
|
|
||||||
sm = SimpleMind()
|
|
||||||
|
|
||||||
# The provider will automatically use OPENAI_API_KEY from environment
|
|
||||||
conversation = sm.create_conversation()
|
conversation = sm.create_conversation()
|
||||||
r = sm.add_message(conversation.id, "Who is Kenneth Reitz?")
|
conversation.add_message(role="user", text="Hello, how are you?")
|
||||||
|
r = conversation.send(llm_model="gpt-4o-mini", llm_provider="openai")
|
||||||
|
|
||||||
print(r)
|
print(r)
|
||||||
#
|
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
from simplemind.providers.openai import OpenAI
|
|
||||||
from simplemind.core.errors import AuthenticationError, ProviderError
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIProvider(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.api_key = "test_key"
|
|
||||||
|
|
||||||
@patch("simplemind.integrations.openai.BaseOpenAI")
|
|
||||||
def test_initialization(self, mock_openai):
|
|
||||||
provider = OpenAI(api_key=self.api_key)
|
|
||||||
self.assertIsNotNone(provider.client)
|
|
||||||
|
|
||||||
def test_missing_api_key(self):
|
|
||||||
with self.assertRaises(AuthenticationError):
|
|
||||||
OpenAI(api_key=None)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
Reference in New Issue
Block a user