mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 06:46:18 +00:00
cleanup
This commit is contained in:
+18
-1
@@ -1 +1,18 @@
|
||||
from .core import SimpleMind
|
||||
from .models import Conversation
|
||||
|
||||
|
||||
class SimpleMind:
|
||||
def create_conversation(self, *, llm_model=None, llm_provider=None):
|
||||
return Conversation()
|
||||
|
||||
def structured_response(
|
||||
self, *, llm_model=None, llm_provider=None, response_model=None
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def create_conversation():
|
||||
return SimpleMind().create_conversation()
|
||||
|
||||
|
||||
globals().update(locals())
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
@abstractmethod
|
||||
def decide(self, context, *args, **kwargs):
|
||||
pass
|
||||
@@ -1,24 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseChain(ABC):
|
||||
"""Abstract base class for implementing chain operations.
|
||||
|
||||
A Chain represents a processing step that can be executed on input data
|
||||
and should be implemented by concrete classes to define specific behaviors.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def run(self, input_data: str) -> str:
|
||||
"""Execute the chain's operation on the input data.
|
||||
|
||||
Args:
|
||||
input_data: The input string to be processed by the chain.
|
||||
|
||||
Returns:
|
||||
The processed output string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input data is invalid or cannot be processed.
|
||||
"""
|
||||
pass
|
||||
@@ -1,25 +0,0 @@
|
||||
from .base import BaseChain
|
||||
|
||||
|
||||
class ReverseTextChain(BaseChain):
|
||||
"""Chain that reverses input text.
|
||||
|
||||
This chain takes a text input and returns it reversed. For example,
|
||||
"hello" becomes "olleh".
|
||||
"""
|
||||
|
||||
def run(self, input_data: str) -> str:
|
||||
"""Reverse the input text.
|
||||
|
||||
Args:
|
||||
input_data: The text to reverse.
|
||||
|
||||
Returns:
|
||||
The reversed text.
|
||||
|
||||
Raises:
|
||||
TypeError: If input_data is not a string.
|
||||
"""
|
||||
if not isinstance(input_data, str):
|
||||
raise TypeError("Input must be a string")
|
||||
return input_data[::-1]
|
||||
@@ -1,17 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, Any
|
||||
from simplemind.plugins.base import BasePlugin
|
||||
|
||||
|
||||
class Context(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
plugins: Dict[str, BasePlugin] = {}
|
||||
|
||||
def add_plugin(self, name: str, plugin: BasePlugin):
|
||||
self.plugins[name] = plugin
|
||||
|
||||
def execute_plugin(self, name: str, *args, **kwargs):
|
||||
if name in self.plugins:
|
||||
return self.plugins[name].execute(self, *args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Plugin '{name}' not found in context.")
|
||||
@@ -1,53 +0,0 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from .models import AIResponse
|
||||
from ..concepts.context import Context
|
||||
from ..providers.base import BaseClientProvider
|
||||
|
||||
from .config import settings
|
||||
|
||||
|
||||
class SimpleMind:
|
||||
"""Main class for SimpleMind functionality."""
|
||||
|
||||
def __init__(self, provider: str = "openai", context: Optional[Context] = None):
|
||||
"""Initialize SimpleMind with the specified provider."""
|
||||
self.provider = provider
|
||||
self.context = context or Context()
|
||||
self._client = self._get_provider()
|
||||
|
||||
def _get_provider(self) -> BaseClientProvider:
|
||||
"""Get the appropriate provider client."""
|
||||
from ..providers.openai import OpenAI
|
||||
from ..providers.anthropic import Anthropic
|
||||
|
||||
# Initialize providers based on environment variables.
|
||||
providers = {}
|
||||
if settings.openai_api_key:
|
||||
providers.update({"openai": OpenAI})
|
||||
if settings.anthropic_api_key:
|
||||
providers.update({"anthropic": Anthropic})
|
||||
|
||||
if self.provider not in providers:
|
||||
raise ValueError(
|
||||
f"Provider '{self.provider}' not supported. Available providers: {list(providers.keys())}"
|
||||
)
|
||||
|
||||
return providers[self.provider]()
|
||||
|
||||
def generate(self, prompt: str, **kwargs) -> AIResponse:
|
||||
"""Generate a response using the configured provider."""
|
||||
|
||||
return self._client.message(prompt, **kwargs)
|
||||
|
||||
def create_conversation(self) -> str:
|
||||
"""Create a new conversation and return its ID."""
|
||||
|
||||
initial_message = "You are a helpful assistant."
|
||||
|
||||
conversation = self._client.create_conversation(initial_message)
|
||||
return conversation
|
||||
|
||||
def add_message(self, conversation_id: str, message: str) -> AIResponse:
|
||||
"""Send a message in an existing conversation."""
|
||||
|
||||
return self._client.add_message(conversation_id, message)
|
||||
@@ -1,68 +0,0 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from .models import Conversation, AIResponse
|
||||
from ..concepts.context import Context
|
||||
|
||||
from .errors import ProviderError
|
||||
from .logger import logger
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self, api_key=None):
|
||||
self.providers = {}
|
||||
|
||||
# Auto-detect available API keys from environment
|
||||
api_keys = {
|
||||
"openai": os.getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.getenv("ANTHROPIC_API_KEY"),
|
||||
# Add other providers as needed
|
||||
}
|
||||
|
||||
# Initialize providers for which we have API keys
|
||||
for provider, key in api_keys.items():
|
||||
if key:
|
||||
self.providers[provider] = self._initialize_provider(provider, key)
|
||||
|
||||
def _initialize_provider(self, provider_name, api_key):
|
||||
if provider_name == "openai":
|
||||
from ..providers.openai import OpenAI
|
||||
|
||||
return OpenAI(api_key)
|
||||
elif provider_name == "anthropic":
|
||||
from ..providers.anthropic import Anthropic
|
||||
|
||||
return Anthropic(api_key)
|
||||
# Add other providers as needed
|
||||
|
||||
def create_conversation(
|
||||
self, provider: str = "openai", context: Optional[Context] = None
|
||||
) -> Conversation:
|
||||
if provider not in self.providers:
|
||||
raise ValueError(f"Provider '{provider}' not supported.")
|
||||
conversation_context = context or Context()
|
||||
return self.providers[provider].create_conversation(
|
||||
initial_message="Hello!", context=conversation_context.model_dump()
|
||||
)
|
||||
|
||||
def _handle_api_error(self, error: Exception, operation: str):
|
||||
"""Handle API errors in a consistent way."""
|
||||
logger.error(f"Error during {operation}: {str(error)}")
|
||||
raise ProviderError(f"Failed to {operation}: {str(error)}") from error
|
||||
|
||||
def send_message(
|
||||
self, conversation: Conversation, message: str, provider: str = "openai"
|
||||
) -> AIResponse:
|
||||
if provider not in self.providers:
|
||||
raise ValueError(f"Provider '{provider}' not supported.")
|
||||
try:
|
||||
return self.providers[provider].send_message(conversation.id, message)
|
||||
except Exception as e:
|
||||
self._handle_api_error(e, "send message")
|
||||
|
||||
@property
|
||||
def available_models(self):
|
||||
available = {}
|
||||
for name, provider in self.providers.items():
|
||||
available[name] = provider.available_models
|
||||
return available
|
||||
@@ -1,56 +0,0 @@
|
||||
from pydantic import Field, field_validator, ValidationError
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional, Dict, Set
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
openai_api_key: Optional[str] = Field(None, env="OPENAI_API_KEY")
|
||||
anthropic_api_key: Optional[str] = Field(None, env="ANTHROPIC_API_KEY")
|
||||
ollama_host_url: Optional[str] = Field(None, env="OLLAMA_HOST_URL")
|
||||
default_model: str = Field("gpt-4", env="DEFAULT_MODEL")
|
||||
log_level: str = Field("INFO", env="LOG_LEVEL")
|
||||
|
||||
# Map of provider names to their required environment variables
|
||||
PROVIDER_REQUIREMENTS: Dict[str, str] = {
|
||||
"openai": "openai_api_key",
|
||||
"anthropic": "anthropic_api_key",
|
||||
"ollama": "ollama_host_url",
|
||||
}
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
def check_empty_string(cls, v):
|
||||
if isinstance(v, str) and not v.strip():
|
||||
return None
|
||||
return v
|
||||
|
||||
def validate_provider(self, provider: str) -> bool:
|
||||
"""
|
||||
Validate that the necessary API key exists for a given provider.
|
||||
Raises ValueError if the provider is not properly configured.
|
||||
"""
|
||||
if provider not in self.PROVIDER_REQUIREMENTS:
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
required_key = self.PROVIDER_REQUIREMENTS[provider]
|
||||
if getattr(self, required_key) is None:
|
||||
raise ValueError(
|
||||
f"Missing API key for {provider}. "
|
||||
f"Please set {required_key.upper()} environment variable."
|
||||
)
|
||||
return True
|
||||
|
||||
@property
|
||||
def available_providers(self) -> Set[str]:
|
||||
"""Return a set of properly configured providers."""
|
||||
return {
|
||||
provider
|
||||
for provider, key in self.PROVIDER_REQUIREMENTS.items()
|
||||
if getattr(self, key) is not None
|
||||
}
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = False
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -1,22 +0,0 @@
|
||||
class SimpleMindError(Exception):
|
||||
"""Base exception for SimpleMind errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ProviderError(SimpleMindError):
|
||||
"""Raised when there's an error with the AI provider."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ConfigurationError(SimpleMindError):
|
||||
"""Raised when there's a configuration error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationError(SimpleMindError):
|
||||
"""Raised when authentication fails."""
|
||||
|
||||
pass
|
||||
@@ -1,15 +0,0 @@
|
||||
import logging
|
||||
|
||||
def setup_logger(name: str) -> logging.Logger:
|
||||
logger = logging.getLogger(name)
|
||||
if not logger.hasHandlers():
|
||||
logger.setLevel(logging.INFO)
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter('[%(asctime)s] %(levelname)s - %(message)s')
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
return logger
|
||||
|
||||
# Initialize a global logger
|
||||
logger = setup_logger("simplemind")
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional
|
||||
from ..concepts.context import Context
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class AIRequest(BaseModel):
|
||||
text: str
|
||||
parameters: Dict[str, Any] = {}
|
||||
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str # "user", "assistant", "system"
|
||||
content: str
|
||||
created_at: datetime = datetime.now()
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
message: Message
|
||||
index: int = 0
|
||||
|
||||
|
||||
class AIResponse(BaseModel):
|
||||
choices: List[Choice]
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
"""Helper to get the first message content directly."""
|
||||
return self.choices[0].message.content
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
id: str
|
||||
messages: List[Message] = []
|
||||
context: Optional[Context] = None
|
||||
created_at: datetime = datetime.now()
|
||||
updated_at: datetime = datetime.now()
|
||||
|
||||
def get_messages(self) -> List[Message]:
|
||||
"""Returns a list of messages in the conversation."""
|
||||
return self.messages
|
||||
|
||||
def add_message(self, role: str, content: str) -> Message:
|
||||
"""Adds a new message to the conversation."""
|
||||
message = Message(role=role, content=content)
|
||||
self.messages.append(message)
|
||||
self.updated_at = datetime.now()
|
||||
return message
|
||||
|
||||
def set_context(self, context: Context):
|
||||
"""Sets the context for the conversation."""
|
||||
self.context = context
|
||||
|
||||
|
||||
class ConversationRequest(BaseModel):
|
||||
conversation_id: Optional[str] = None
|
||||
message: str
|
||||
context_update: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ConversationResponse(BaseModel):
|
||||
conversation_id: str
|
||||
messages: List[Message]
|
||||
metadata: Dict[str, Any] = {}
|
||||
@@ -0,0 +1,54 @@
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .utils import find_provider
|
||||
|
||||
|
||||
class SMBaseModel(BaseModel):
|
||||
date_created: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
def __str__(self):
|
||||
return f"<{self.__class__.__name__} {self.model_dump_json()}>"
|
||||
|
||||
|
||||
class Message(SMBaseModel):
|
||||
role: str
|
||||
text: str
|
||||
meta: Dict[str, Any] = {}
|
||||
raw: Optional[Any] = None
|
||||
llm_model: Optional[str] = None
|
||||
llm_provider: Optional[str] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"<Message role={self.role} text={self.text!r}>"
|
||||
|
||||
@classmethod
|
||||
def from_raw_response(cls, *, text, raw):
|
||||
self = cls()
|
||||
self.text = text
|
||||
self.raw = raw
|
||||
return self
|
||||
|
||||
|
||||
class Conversation(SMBaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
messages: List[Message] = []
|
||||
llm_model: Optional[str] = None
|
||||
llm_provider: Optional[str] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"<Conversation id={self.id!r}>"
|
||||
|
||||
def add_message(self, role: str, text: str, meta: Dict[str, Any] = {}):
|
||||
"""Add a new message to the conversation."""
|
||||
self.messages.append(Message(role=role, text=text, meta=meta))
|
||||
|
||||
def send(
|
||||
self, llm_model: Optional[str] = None, llm_provider: Optional[str] = None
|
||||
) -> Message:
|
||||
"""Send the conversation to the LLM."""
|
||||
provider = find_provider(llm_provider or self.llm_provider)
|
||||
return provider.send_conversation(self)
|
||||
@@ -1,26 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
created_at: datetime = datetime.now()
|
||||
|
||||
class Conversation(BaseModel):
|
||||
id: str
|
||||
messages: List[Message] = []
|
||||
context: Dict[str, Any] = {}
|
||||
created_at: datetime = datetime.now()
|
||||
updated_at: datetime = datetime.now()
|
||||
|
||||
def add_message(self, role: str, content: str) -> Message:
|
||||
message = Message(role=role, content=content)
|
||||
self.messages.append(message)
|
||||
self.updated_at = datetime.now()
|
||||
return message
|
||||
|
||||
class AIResponse(BaseModel):
|
||||
text: str
|
||||
response: Any
|
||||
metadata: Dict[str, Any] = {}
|
||||
@@ -1,34 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class BasePlugin(ABC):
|
||||
"""Base class for all SimpleMind plugins."""
|
||||
|
||||
def __init__(self):
|
||||
self.enabled: bool = True
|
||||
self.name: str = self.__class__.__name__
|
||||
|
||||
@abstractmethod
|
||||
def process(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Process the context and return modified context.
|
||||
|
||||
Args:
|
||||
context: The current conversation context
|
||||
|
||||
Returns:
|
||||
Modified context dictionary
|
||||
"""
|
||||
pass
|
||||
|
||||
def enable(self) -> None:
|
||||
"""Enable the plugin."""
|
||||
self.enabled = True
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the plugin."""
|
||||
self.enabled = False
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
return self.enabled
|
||||
@@ -1,10 +0,0 @@
|
||||
from .base import BasePlugin
|
||||
|
||||
|
||||
class BasicMemoryPlugin(BasePlugin):
|
||||
def __init__(self):
|
||||
self.memory = []
|
||||
|
||||
def execute(self, context, message):
|
||||
self.memory.append(message)
|
||||
return self.memory
|
||||
@@ -1,18 +0,0 @@
|
||||
from .base import BasePlugin
|
||||
|
||||
|
||||
class KVPlugin(BasePlugin):
|
||||
def __init__(self):
|
||||
self.store = {}
|
||||
|
||||
def process(self, key: str, value=None):
|
||||
"""
|
||||
Get or set a value in the key-value store.
|
||||
If value is None, returns the value for the key.
|
||||
If value is provided, sets the value for the key and returns it.
|
||||
"""
|
||||
if value is None:
|
||||
return self.store.get(key)
|
||||
|
||||
self.store[key] = value
|
||||
return value
|
||||
@@ -1,18 +1,3 @@
|
||||
from ..core.config import settings
|
||||
from .openai import OpenAI
|
||||
|
||||
__all__ = []
|
||||
|
||||
if settings.anthropic_api_key:
|
||||
from .anthropic import Anthropic
|
||||
|
||||
__all__.append("Anthropic")
|
||||
|
||||
if settings.openai_api_key:
|
||||
from .openai import OpenAI
|
||||
|
||||
__all__.append("OpenAI")
|
||||
|
||||
if settings.ollama_host_url:
|
||||
from .ollama import Ollama
|
||||
|
||||
__all__.append("Ollama")
|
||||
providers = [OpenAI]
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import instructor
|
||||
from anthropic import Anthropic as BaseAnthropic
|
||||
|
||||
from .base import BaseClientProvider
|
||||
from ..core.models import AIResponse, Conversation
|
||||
from ..core.logger import logger
|
||||
|
||||
|
||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
DEFAULT_MAX_TOKENS = 4096
|
||||
|
||||
|
||||
class Anthropic(BaseClientProvider):
|
||||
def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
|
||||
super().__init__(model=model, api_key=api_key)
|
||||
self.login()
|
||||
|
||||
def login(self):
|
||||
if not self._api_key:
|
||||
self._api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not self._api_key:
|
||||
raise ValueError("Anthropic API key not provided.")
|
||||
logger.debug(f"API key length: {len(self._api_key) if self._api_key else 0}")
|
||||
|
||||
base_client = BaseAnthropic(api_key=self._api_key)
|
||||
self.client = instructor.from_anthropic(base_client)
|
||||
if not self.test_connection():
|
||||
raise ConnectionError("Failed to connect to Anthropic API.")
|
||||
logger.info("Logged in to Anthropic successfully.")
|
||||
|
||||
@property
|
||||
def available_models(self) -> List[str]:
|
||||
try:
|
||||
return [
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-haiku-20240307",
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching models: {e}")
|
||||
return []
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
models = self.available_models
|
||||
if models:
|
||||
logger.info(f"Available models: {models}")
|
||||
return True
|
||||
logger.warning("No available models found.")
|
||||
return False
|
||||
|
||||
def generate_response(self, conversation: Conversation) -> AIResponse:
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.content} for msg in conversation.messages
|
||||
]
|
||||
params = {
|
||||
"messages": messages,
|
||||
"model": self.model,
|
||||
"max_tokens": DEFAULT_MAX_TOKENS,
|
||||
}
|
||||
if conversation.context:
|
||||
params["context"] = (
|
||||
vars(conversation.context)
|
||||
if hasattr(conversation.context, "__dict__")
|
||||
else dict(conversation.context)
|
||||
)
|
||||
|
||||
try:
|
||||
completion = self.client.completions.create(response_model=str, **params)
|
||||
response_text = completion.completion
|
||||
metadata = {"model": completion.model, "usage": completion.usage}
|
||||
logger.info("Generated response from Anthropic.")
|
||||
return AIResponse(
|
||||
text=response_text, response=completion, metadata=metadata
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating response: {e}")
|
||||
raise e
|
||||
@@ -1,112 +0,0 @@
|
||||
# import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional
|
||||
from ..core.models import AIResponse, Conversation, Message
|
||||
import uuid
|
||||
|
||||
|
||||
DEFAULT_MODEL = "gpt-4o"
|
||||
|
||||
|
||||
class BaseClientProvider:
|
||||
|
||||
def __init__(self, *, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
|
||||
# self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.client = None
|
||||
self.model = model
|
||||
self._api_key = api_key
|
||||
self.conversations: Dict[str, Conversation] = {}
|
||||
|
||||
def login(self):
|
||||
"""Initializes the AI provider client."""
|
||||
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""Tests the connection to the AI provider client."""
|
||||
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def health_check(self):
|
||||
"""Checks the health of the AI provider client."""
|
||||
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
@property
|
||||
def available_models(self) -> List[str]:
|
||||
"""Returns the available models from the AI provider client."""
|
||||
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def message(self, message: str, **kwargs) -> AIResponse:
|
||||
"""Generates a response from the AI provider client."""
|
||||
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
# Uncomment and implement additional methods as needed
|
||||
# def features(self):
|
||||
# """Returns the features of the AI provider client."""
|
||||
# msg = "This method must be implemented by the AI provider client."
|
||||
# raise NotImplementedError(msg)
|
||||
|
||||
# def structured_response(self, model, message, **kwargs):
|
||||
# pass
|
||||
|
||||
# def structured_conversation(self, model, message, **kwargs):
|
||||
# pass
|
||||
|
||||
# def single_message(self, model, message, **kwargs):
|
||||
# return self.generate_response(message)
|
||||
|
||||
# def start_conversation(self, model, message, **kwargs):
|
||||
# pass
|
||||
|
||||
def create_conversation(
|
||||
self, initial_message: str, context: Optional[Dict[str, Any]] = None
|
||||
) -> Conversation:
|
||||
conv_id = str(uuid.uuid4())
|
||||
conversation = Conversation(
|
||||
id=conv_id,
|
||||
messages=[Message(role="user", content=initial_message)],
|
||||
context=context or {},
|
||||
)
|
||||
self.conversations[conv_id] = conversation
|
||||
return conversation
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
message: str,
|
||||
context_update: Optional[Dict[str, Any]] = None,
|
||||
) -> AIResponse:
|
||||
if conversation_id not in self.conversations:
|
||||
raise ValueError("Conversation ID does not exist.")
|
||||
|
||||
conversation = self.conversations[conversation_id]
|
||||
conversation.messages.append(Message(role="user", content=message))
|
||||
|
||||
if context_update:
|
||||
conversation.context.update(context_update)
|
||||
|
||||
response = self.generate_response(conversation)
|
||||
conversation.messages.append(
|
||||
Message(role="assistant", content=response.choices[0].message.content)
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_response(self, conversation: Conversation) -> AIResponse:
|
||||
"""Generates a response based on the conversation."""
|
||||
raise NotImplementedError(
|
||||
"This method must be implemented by the AI provider client."
|
||||
)
|
||||
|
||||
def get_conversation(self, conversation_id: str) -> Conversation:
|
||||
if conversation_id not in self.conversations:
|
||||
raise ValueError("Conversation ID does not exist.")
|
||||
return self.conversations[conversation_id]
|
||||
@@ -1,110 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import List, Optional, Dict, Any
|
||||
from ..core.errors import ProviderError
|
||||
from ..core.logger import logger
|
||||
from .base import BaseClientProvider
|
||||
from some_module import BaseOllama # Replace 'some_module' with the actual module name
|
||||
|
||||
TIMEOUT = 60
|
||||
|
||||
|
||||
class Ollama(BaseClientProvider):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.login()
|
||||
self.conversation = []
|
||||
|
||||
def login(self):
|
||||
"""Initialize Ollama client, with Instructor enabled."""
|
||||
if not os.environ.get("OLLAMA_HOST_URL"):
|
||||
raise ValueError("Please set the OLLAMA_HOST_URL environment variable")
|
||||
|
||||
if not os.environ.get("OLLAMA_MODEL"):
|
||||
raise ValueError("Please set the OLLAMA_MODEL environment variable")
|
||||
else:
|
||||
self.model = os.environ.get("OLLAMA_MODEL")
|
||||
|
||||
self.client = BaseOllama(
|
||||
timeout=TIMEOUT, host=os.environ.get("OLLAMA_HOST_URL")
|
||||
)
|
||||
assert self.test_connection()
|
||||
|
||||
@property
|
||||
def available_models(self):
|
||||
"""Returns the available models from the Ollama client."""
|
||||
|
||||
def gen():
|
||||
for model in self.client.list().get("models"):
|
||||
yield model
|
||||
|
||||
return [g for g in gen()]
|
||||
|
||||
def test_connection(self):
|
||||
"""Test the connection to Ollama. Returns True if successful."""
|
||||
|
||||
return bool(len(self.available_models))
|
||||
|
||||
def generate_text(self, prompt, *, response_model=False, **kwargs):
|
||||
use_instructor = bool(response_model)
|
||||
|
||||
client = self.instructor_client if use_instructor else self.client
|
||||
|
||||
# Parameters for the Ollama client.
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"model": self.model,
|
||||
}
|
||||
params.update(kwargs)
|
||||
|
||||
if use_instructor:
|
||||
params["response_model"] = response_model
|
||||
|
||||
# Make the request to Ollama.
|
||||
completion = client.generate(**params)
|
||||
if use_instructor:
|
||||
return completion.model_dump()
|
||||
|
||||
else:
|
||||
return AIResponse(
|
||||
response=completion,
|
||||
text=completion.get("response"),
|
||||
)
|
||||
|
||||
def message(
|
||||
self, message=None, message_history=None, response_model=False, **kwargs
|
||||
):
|
||||
"""Generates a response from the Ollama client."""
|
||||
use_instructor = bool(response_model)
|
||||
|
||||
client = self.instructor_client if use_instructor else self.client
|
||||
|
||||
# Parameters for the Ollama client.
|
||||
all_messages = []
|
||||
if message_history:
|
||||
all_messages.extend(message_history)
|
||||
if message:
|
||||
all_messages.append({"role": "user", "content": message})
|
||||
params = {
|
||||
"messages": all_messages,
|
||||
"model": self.model,
|
||||
}
|
||||
params.update(kwargs)
|
||||
|
||||
if use_instructor:
|
||||
params["response_model"] = response_model
|
||||
|
||||
# Make the request to Ollama.
|
||||
completion = client.chat(**params)
|
||||
if use_instructor:
|
||||
return completion.model_dump()
|
||||
|
||||
else:
|
||||
return AIResponse(
|
||||
response=completion,
|
||||
text=completion.get("message").get("content"),
|
||||
)
|
||||
|
||||
def start_conversation(self):
|
||||
return Conversation(self)
|
||||
@@ -1,69 +1,57 @@
|
||||
from typing import List, Optional
|
||||
import openai
|
||||
from ..core.errors import AuthenticationError, ProviderError
|
||||
from ..core.config import settings
|
||||
from ..core.logger import logger
|
||||
from .base import BaseClientProvider
|
||||
import openai as oa
|
||||
import instructor
|
||||
|
||||
DEFAULT_MODEL = "gpt-4o"
|
||||
# from ..models import Conversation, Message
|
||||
from ..settings import settings
|
||||
|
||||
PROVIDER_NAME = "openai"
|
||||
DEFAULT_MODEL = "gpt-4o-mini"
|
||||
|
||||
|
||||
class OpenAI(BaseClientProvider):
|
||||
def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
|
||||
super().__init__(model=model, api_key=api_key)
|
||||
self.login()
|
||||
class OpenAI:
|
||||
__name__ = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def login(self) -> None:
|
||||
if not self._api_key:
|
||||
self._api_key = settings.openai_api_key
|
||||
if not self._api_key:
|
||||
raise AuthenticationError("OpenAI API key not provided")
|
||||
|
||||
try:
|
||||
openai.api_key = self._api_key
|
||||
if not self.test_connection():
|
||||
raise ProviderError("Failed to connect to OpenAI API")
|
||||
logger.info("Successfully connected to OpenAI")
|
||||
except Exception as e:
|
||||
raise ProviderError(f"OpenAI initialization failed: {str(e)}")
|
||||
def __init__(self, api_key: str = None):
|
||||
self.api_key = api_key or settings.OPENAI_API_KEY
|
||||
|
||||
@property
|
||||
def available_models(self) -> List[str]:
|
||||
try:
|
||||
models = openai.models.list()
|
||||
return [model.id for model in models["data"]]
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching models: {e}")
|
||||
return []
|
||||
def client(self):
|
||||
"""The raw OpenAI client."""
|
||||
return oa.OpenAI(api_key=self.api_key)
|
||||
|
||||
def test_connection(self):
|
||||
"""Test the connection to OpenAI API."""
|
||||
try:
|
||||
openai.models.list()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Connection test failed: {e}")
|
||||
return False
|
||||
@property
|
||||
def structured_client(self):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.patch(oa.OpenAI(api_key=self.api_key))
|
||||
|
||||
def _handle_api_error(self, e: Exception) -> None:
|
||||
"""Handle API errors."""
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise ProviderError(f"OpenAI API error: {e}")
|
||||
def send_conversation(self, conversation: "Conversation"):
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
def add_message(self, conversation_id, message, *args, **kwargs) -> str:
|
||||
"""Generate a response using the OpenAI API."""
|
||||
try:
|
||||
# Create a client instance using the API key
|
||||
client = openai.OpenAI(api_key=self._api_key)
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
|
||||
# Create the message for the conversation
|
||||
messages = [{"role": "user", "content": message}]
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
|
||||
)
|
||||
|
||||
# Use the new API syntax
|
||||
response = client.chat.completions.create(
|
||||
model=self.model, messages=messages, *args, **kwargs
|
||||
)
|
||||
# Get the response content from the OpenAI response
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
self._handle_api_error(e)
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message.content,
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, model, response_model, **kwargs):
|
||||
client = instructor.patch(oa.OpenAI(api_key=self.api_key))
|
||||
response = client.chat.completions.create(
|
||||
model=model, response_model=response_model, **kwargs
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY")
|
||||
ANTHROPIC_API_KEY: str = Field(..., env="ANTHROPIC_API_KEY")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -1,24 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from simplemind.providers.openai import OpenAI
|
||||
from simplemind.core.errors import AuthenticationError, ProviderError
|
||||
|
||||
|
||||
class TestOpenAIProvider(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.api_key = "test_key"
|
||||
|
||||
@patch("simplemind.providers.openai.openai")
|
||||
def test_initialization(self, mock_openai):
|
||||
mock_openai.Model.list.return_value = {"data": ["gpt-4"]}
|
||||
provider = OpenAI(api_key=self.api_key)
|
||||
self.assertIsNotNone(provider.client)
|
||||
mock_openai.Model.list.assert_called_once()
|
||||
|
||||
def test_missing_api_key(self):
|
||||
with self.assertRaises(AuthenticationError):
|
||||
OpenAI(api_key=None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,10 @@
|
||||
from .providers import providers
|
||||
|
||||
|
||||
def find_provider(provider_name: str):
|
||||
"""Find a provider by name."""
|
||||
for provider_class in providers:
|
||||
if provider_class.__name__.lower() == provider_name.lower():
|
||||
# Instantiate the provider
|
||||
return provider_class()
|
||||
raise ValueError(f"Provider {provider_name} not found")
|
||||
@@ -1,19 +0,0 @@
|
||||
import faiss
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
|
||||
class FAISSStore:
|
||||
def __init__(self, dimension: int):
|
||||
self.dimension = dimension
|
||||
self.index = faiss.IndexFlatL2(dimension)
|
||||
self.ids = []
|
||||
|
||||
def add_embeddings(self, embeddings: np.ndarray, ids: List[str]):
|
||||
self.index.add(embeddings)
|
||||
self.ids.extend(ids)
|
||||
|
||||
def search(self, query_embedding: np.ndarray, top_k: int = 5):
|
||||
distances, indices = self.index.search(query_embedding, top_k)
|
||||
results = [(self.ids[idx], distances[i]) for i, idx in enumerate(indices[0])]
|
||||
return results
|
||||
@@ -1,9 +1,7 @@
|
||||
from simplemind import SimpleMind
|
||||
import simplemind as sm
|
||||
|
||||
sm = SimpleMind()
|
||||
|
||||
# The provider will automatically use OPENAI_API_KEY from environment
|
||||
conversation = sm.create_conversation()
|
||||
r = sm.add_message(conversation.id, "Who is Kenneth Reitz?")
|
||||
conversation.add_message(role="user", text="Hello, how are you?")
|
||||
r = conversation.send(llm_model="gpt-4o-mini", llm_provider="openai")
|
||||
|
||||
print(r)
|
||||
#
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from simplemind.providers.openai import OpenAI
|
||||
from simplemind.core.errors import AuthenticationError, ProviderError
|
||||
|
||||
|
||||
class TestOpenAIProvider(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.api_key = "test_key"
|
||||
|
||||
@patch("simplemind.integrations.openai.BaseOpenAI")
|
||||
def test_initialization(self, mock_openai):
|
||||
provider = OpenAI(api_key=self.api_key)
|
||||
self.assertIsNotNone(provider.client)
|
||||
|
||||
def test_missing_api_key(self):
|
||||
with self.assertRaises(AuthenticationError):
|
||||
OpenAI(api_key=None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user