This commit is contained in:
2024-10-28 18:21:25 -04:00
parent 300d5a1d81
commit 578f3fc11e
32 changed files with 143 additions and 890 deletions
+18 -1
View File
@@ -1 +1,18 @@
from .core import SimpleMind from .models import Conversation
class SimpleMind:
def create_conversation(self, *, llm_model=None, llm_provider=None):
return Conversation()
def structured_response(
self, *, llm_model=None, llm_provider=None, response_model=None
):
pass
def create_conversation():
return SimpleMind().create_conversation()
globals().update(locals())
View File
-7
View File
@@ -1,7 +0,0 @@
from abc import ABC, abstractmethod
class BaseAgent(ABC):
@abstractmethod
def decide(self, context, *args, **kwargs):
pass
View File
-24
View File
@@ -1,24 +0,0 @@
from abc import ABC, abstractmethod
class BaseChain(ABC):
"""Abstract base class for implementing chain operations.
A Chain represents a processing step that can be executed on input data
and should be implemented by concrete classes to define specific behaviors.
"""
@abstractmethod
def run(self, input_data: str) -> str:
"""Execute the chain's operation on the input data.
Args:
input_data: The input string to be processed by the chain.
Returns:
The processed output string.
Raises:
ValueError: If the input data is invalid or cannot be processed.
"""
pass
-25
View File
@@ -1,25 +0,0 @@
from .base import BaseChain
class ReverseTextChain(BaseChain):
"""Chain that reverses input text.
This chain takes a text input and returns it reversed. For example,
"hello" becomes "olleh".
"""
def run(self, input_data: str) -> str:
"""Reverse the input text.
Args:
input_data: The text to reverse.
Returns:
The reversed text.
Raises:
TypeError: If input_data is not a string.
"""
if not isinstance(input_data, str):
raise TypeError("Input must be a string")
return input_data[::-1]
View File
-17
View File
@@ -1,17 +0,0 @@
from pydantic import BaseModel
from typing import Dict, Any
from simplemind.plugins.base import BasePlugin
class Context(BaseModel):
model_config = {"arbitrary_types_allowed": True}
plugins: Dict[str, BasePlugin] = {}
def add_plugin(self, name: str, plugin: BasePlugin):
self.plugins[name] = plugin
def execute_plugin(self, name: str, *args, **kwargs):
if name in self.plugins:
return self.plugins[name].execute(self, *args, **kwargs)
else:
raise ValueError(f"Plugin '{name}' not found in context.")
-53
View File
@@ -1,53 +0,0 @@
from typing import Dict, Any, Optional
from .models import AIResponse
from ..concepts.context import Context
from ..providers.base import BaseClientProvider
from .config import settings
class SimpleMind:
"""Main class for SimpleMind functionality."""
def __init__(self, provider: str = "openai", context: Optional[Context] = None):
"""Initialize SimpleMind with the specified provider."""
self.provider = provider
self.context = context or Context()
self._client = self._get_provider()
def _get_provider(self) -> BaseClientProvider:
"""Get the appropriate provider client."""
from ..providers.openai import OpenAI
from ..providers.anthropic import Anthropic
# Initialize providers based on environment variables.
providers = {}
if settings.openai_api_key:
providers.update({"openai": OpenAI})
if settings.anthropic_api_key:
providers.update({"anthropic": Anthropic})
if self.provider not in providers:
raise ValueError(
f"Provider '{self.provider}' not supported. Available providers: {list(providers.keys())}"
)
return providers[self.provider]()
def generate(self, prompt: str, **kwargs) -> AIResponse:
"""Generate a response using the configured provider."""
return self._client.message(prompt, **kwargs)
def create_conversation(self) -> str:
"""Create a new conversation and return its ID."""
initial_message = "You are a helpful assistant."
conversation = self._client.create_conversation(initial_message)
return conversation
def add_message(self, conversation_id: str, message: str) -> AIResponse:
"""Send a message in an existing conversation."""
return self._client.add_message(conversation_id, message)
-68
View File
@@ -1,68 +0,0 @@
import os
from typing import Optional
from .models import Conversation, AIResponse
from ..concepts.context import Context
from .errors import ProviderError
from .logger import logger
class Client:
def __init__(self, api_key=None):
self.providers = {}
# Auto-detect available API keys from environment
api_keys = {
"openai": os.getenv("OPENAI_API_KEY"),
"anthropic": os.getenv("ANTHROPIC_API_KEY"),
# Add other providers as needed
}
# Initialize providers for which we have API keys
for provider, key in api_keys.items():
if key:
self.providers[provider] = self._initialize_provider(provider, key)
def _initialize_provider(self, provider_name, api_key):
if provider_name == "openai":
from ..providers.openai import OpenAI
return OpenAI(api_key)
elif provider_name == "anthropic":
from ..providers.anthropic import Anthropic
return Anthropic(api_key)
# Add other providers as needed
def create_conversation(
self, provider: str = "openai", context: Optional[Context] = None
) -> Conversation:
if provider not in self.providers:
raise ValueError(f"Provider '{provider}' not supported.")
conversation_context = context or Context()
return self.providers[provider].create_conversation(
initial_message="Hello!", context=conversation_context.model_dump()
)
def _handle_api_error(self, error: Exception, operation: str):
"""Handle API errors in a consistent way."""
logger.error(f"Error during {operation}: {str(error)}")
raise ProviderError(f"Failed to {operation}: {str(error)}") from error
def send_message(
self, conversation: Conversation, message: str, provider: str = "openai"
) -> AIResponse:
if provider not in self.providers:
raise ValueError(f"Provider '{provider}' not supported.")
try:
return self.providers[provider].send_message(conversation.id, message)
except Exception as e:
self._handle_api_error(e, "send message")
@property
def available_models(self):
available = {}
for name, provider in self.providers.items():
available[name] = provider.available_models
return available
-56
View File
@@ -1,56 +0,0 @@
from pydantic import Field, field_validator, ValidationError
from pydantic_settings import BaseSettings
from typing import Optional, Dict, Set
class Settings(BaseSettings):
openai_api_key: Optional[str] = Field(None, env="OPENAI_API_KEY")
anthropic_api_key: Optional[str] = Field(None, env="ANTHROPIC_API_KEY")
ollama_host_url: Optional[str] = Field(None, env="OLLAMA_HOST_URL")
default_model: str = Field("gpt-4", env="DEFAULT_MODEL")
log_level: str = Field("INFO", env="LOG_LEVEL")
# Map of provider names to their required environment variables
PROVIDER_REQUIREMENTS: Dict[str, str] = {
"openai": "openai_api_key",
"anthropic": "anthropic_api_key",
"ollama": "ollama_host_url",
}
@field_validator("*", mode="before")
def check_empty_string(cls, v):
if isinstance(v, str) and not v.strip():
return None
return v
def validate_provider(self, provider: str) -> bool:
"""
Validate that the necessary API key exists for a given provider.
Raises ValueError if the provider is not properly configured.
"""
if provider not in self.PROVIDER_REQUIREMENTS:
raise ValueError(f"Unknown provider: {provider}")
required_key = self.PROVIDER_REQUIREMENTS[provider]
if getattr(self, required_key) is None:
raise ValueError(
f"Missing API key for {provider}. "
f"Please set {required_key.upper()} environment variable."
)
return True
@property
def available_providers(self) -> Set[str]:
"""Return a set of properly configured providers."""
return {
provider
for provider, key in self.PROVIDER_REQUIREMENTS.items()
if getattr(self, key) is not None
}
class Config:
env_file = ".env"
case_sensitive = False
settings = Settings()
-22
View File
@@ -1,22 +0,0 @@
class SimpleMindError(Exception):
"""Base exception for SimpleMind errors."""
pass
class ProviderError(SimpleMindError):
"""Raised when there's an error with the AI provider."""
pass
class ConfigurationError(SimpleMindError):
"""Raised when there's a configuration error."""
pass
class AuthenticationError(SimpleMindError):
"""Raised when authentication fails."""
pass
-15
View File
@@ -1,15 +0,0 @@
import logging
def setup_logger(name: str) -> logging.Logger:
logger = logging.getLogger(name)
if not logger.hasHandlers():
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('[%(asctime)s] %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
# Initialize a global logger
logger = setup_logger("simplemind")
-67
View File
@@ -1,67 +0,0 @@
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from ..concepts.context import Context
from datetime import datetime
class AIRequest(BaseModel):
text: str
parameters: Dict[str, Any] = {}
def __str__(self):
return self.text
class Message(BaseModel):
role: str # "user", "assistant", "system"
content: str
created_at: datetime = datetime.now()
class Choice(BaseModel):
message: Message
index: int = 0
class AIResponse(BaseModel):
choices: List[Choice]
@property
def content(self) -> str:
"""Helper to get the first message content directly."""
return self.choices[0].message.content
class Conversation(BaseModel):
id: str
messages: List[Message] = []
context: Optional[Context] = None
created_at: datetime = datetime.now()
updated_at: datetime = datetime.now()
def get_messages(self) -> List[Message]:
"""Returns a list of messages in the conversation."""
return self.messages
def add_message(self, role: str, content: str) -> Message:
"""Adds a new message to the conversation."""
message = Message(role=role, content=content)
self.messages.append(message)
self.updated_at = datetime.now()
return message
def set_context(self, context: Context):
"""Sets the context for the conversation."""
self.context = context
class ConversationRequest(BaseModel):
conversation_id: Optional[str] = None
message: str
context_update: Optional[Dict[str, Any]] = None
class ConversationResponse(BaseModel):
conversation_id: str
messages: List[Message]
metadata: Dict[str, Any] = {}
+54
View File
@@ -0,0 +1,54 @@
import uuid
from typing import List, Dict, Any, Optional
from datetime import datetime
from pydantic import BaseModel, Field
from .utils import find_provider
class SMBaseModel(BaseModel):
date_created: datetime = Field(default_factory=datetime.now)
def __str__(self):
return f"<{self.__class__.__name__} {self.model_dump_json()}>"
class Message(SMBaseModel):
role: str
text: str
meta: Dict[str, Any] = {}
raw: Optional[Any] = None
llm_model: Optional[str] = None
llm_provider: Optional[str] = None
def __str__(self):
return f"<Message role={self.role} text={self.text!r}>"
@classmethod
def from_raw_response(cls, *, text, raw):
self = cls()
self.text = text
self.raw = raw
return self
class Conversation(SMBaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
messages: List[Message] = []
llm_model: Optional[str] = None
llm_provider: Optional[str] = None
def __str__(self):
return f"<Conversation id={self.id!r}>"
def add_message(self, role: str, text: str, meta: Dict[str, Any] = {}):
"""Add a new message to the conversation."""
self.messages.append(Message(role=role, text=text, meta=meta))
def send(
self, llm_model: Optional[str] = None, llm_provider: Optional[str] = None
) -> Message:
"""Send the conversation to the LLM."""
provider = find_provider(llm_provider or self.llm_provider)
return provider.send_conversation(self)
-26
View File
@@ -1,26 +0,0 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
class Message(BaseModel):
role: str
content: str
created_at: datetime = datetime.now()
class Conversation(BaseModel):
id: str
messages: List[Message] = []
context: Dict[str, Any] = {}
created_at: datetime = datetime.now()
updated_at: datetime = datetime.now()
def add_message(self, role: str, content: str) -> Message:
message = Message(role=role, content=content)
self.messages.append(message)
self.updated_at = datetime.now()
return message
class AIResponse(BaseModel):
text: str
response: Any
metadata: Dict[str, Any] = {}
View File
-34
View File
@@ -1,34 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
class BasePlugin(ABC):
"""Base class for all SimpleMind plugins."""
def __init__(self):
self.enabled: bool = True
self.name: str = self.__class__.__name__
@abstractmethod
def process(self, context: Dict[str, Any]) -> Dict[str, Any]:
"""Process the context and return modified context.
Args:
context: The current conversation context
Returns:
Modified context dictionary
"""
pass
def enable(self) -> None:
"""Enable the plugin."""
self.enabled = True
def disable(self) -> None:
"""Disable the plugin."""
self.enabled = False
@property
def is_enabled(self) -> bool:
return self.enabled
-10
View File
@@ -1,10 +0,0 @@
from .base import BasePlugin
class BasicMemoryPlugin(BasePlugin):
def __init__(self):
self.memory = []
def execute(self, context, message):
self.memory.append(message)
return self.memory
-18
View File
@@ -1,18 +0,0 @@
from .base import BasePlugin
class KVPlugin(BasePlugin):
def __init__(self):
self.store = {}
def process(self, key: str, value=None):
"""
Get or set a value in the key-value store.
If value is None, returns the value for the key.
If value is provided, sets the value for the key and returns it.
"""
if value is None:
return self.store.get(key)
self.store[key] = value
return value
+2 -17
View File
@@ -1,18 +1,3 @@
from ..core.config import settings from .openai import OpenAI
__all__ = [] providers = [OpenAI]
if settings.anthropic_api_key:
from .anthropic import Anthropic
__all__.append("Anthropic")
if settings.openai_api_key:
from .openai import OpenAI
__all__.append("OpenAI")
if settings.ollama_host_url:
from .ollama import Ollama
__all__.append("Ollama")
-80
View File
@@ -1,80 +0,0 @@
import os
from typing import List, Optional
import instructor
from anthropic import Anthropic as BaseAnthropic
from .base import BaseClientProvider
from ..core.models import AIResponse, Conversation
from ..core.logger import logger
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
DEFAULT_MAX_TOKENS = 4096
class Anthropic(BaseClientProvider):
def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
super().__init__(model=model, api_key=api_key)
self.login()
def login(self):
if not self._api_key:
self._api_key = os.getenv("ANTHROPIC_API_KEY")
if not self._api_key:
raise ValueError("Anthropic API key not provided.")
logger.debug(f"API key length: {len(self._api_key) if self._api_key else 0}")
base_client = BaseAnthropic(api_key=self._api_key)
self.client = instructor.from_anthropic(base_client)
if not self.test_connection():
raise ConnectionError("Failed to connect to Anthropic API.")
logger.info("Logged in to Anthropic successfully.")
@property
def available_models(self) -> List[str]:
try:
return [
"claude-3-opus-20240229",
"claude-3-5-sonnet-20240620",
"claude-3-haiku-20240307",
]
except Exception as e:
logger.error(f"Error fetching models: {e}")
return []
def test_connection(self) -> bool:
models = self.available_models
if models:
logger.info(f"Available models: {models}")
return True
logger.warning("No available models found.")
return False
def generate_response(self, conversation: Conversation) -> AIResponse:
messages = [
{"role": msg.role, "content": msg.content} for msg in conversation.messages
]
params = {
"messages": messages,
"model": self.model,
"max_tokens": DEFAULT_MAX_TOKENS,
}
if conversation.context:
params["context"] = (
vars(conversation.context)
if hasattr(conversation.context, "__dict__")
else dict(conversation.context)
)
try:
completion = self.client.completions.create(response_model=str, **params)
response_text = completion.completion
metadata = {"model": completion.model, "usage": completion.usage}
logger.info("Generated response from Anthropic.")
return AIResponse(
text=response_text, response=completion, metadata=metadata
)
except Exception as e:
logger.error(f"Error generating response: {e}")
raise e
-112
View File
@@ -1,112 +0,0 @@
# import logging
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from ..core.models import AIResponse, Conversation, Message
import uuid
DEFAULT_MODEL = "gpt-4o"
class BaseClientProvider:
def __init__(self, *, model: str = DEFAULT_MODEL, api_key: Optional[str] = None):
# self.logger = logging.getLogger(self.__class__.__name__)
self.client = None
self.model = model
self._api_key = api_key
self.conversations: Dict[str, Conversation] = {}
def login(self):
"""Initializes the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def test_connection(self) -> bool:
"""Tests the connection to the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def health_check(self):
"""Checks the health of the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
@property
def available_models(self) -> List[str]:
"""Returns the available models from the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def message(self, message: str, **kwargs) -> AIResponse:
"""Generates a response from the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
# Uncomment and implement additional methods as needed
# def features(self):
# """Returns the features of the AI provider client."""
# msg = "This method must be implemented by the AI provider client."
# raise NotImplementedError(msg)
# def structured_response(self, model, message, **kwargs):
# pass
# def structured_conversation(self, model, message, **kwargs):
# pass
# def single_message(self, model, message, **kwargs):
# return self.generate_response(message)
# def start_conversation(self, model, message, **kwargs):
# pass
def create_conversation(
self, initial_message: str, context: Optional[Dict[str, Any]] = None
) -> Conversation:
conv_id = str(uuid.uuid4())
conversation = Conversation(
id=conv_id,
messages=[Message(role="user", content=initial_message)],
context=context or {},
)
self.conversations[conv_id] = conversation
return conversation
def send_message(
self,
conversation_id: str,
message: str,
context_update: Optional[Dict[str, Any]] = None,
) -> AIResponse:
if conversation_id not in self.conversations:
raise ValueError("Conversation ID does not exist.")
conversation = self.conversations[conversation_id]
conversation.messages.append(Message(role="user", content=message))
if context_update:
conversation.context.update(context_update)
response = self.generate_response(conversation)
conversation.messages.append(
Message(role="assistant", content=response.choices[0].message.content)
)
return response
def generate_response(self, conversation: Conversation) -> AIResponse:
"""Generates a response based on the conversation."""
raise NotImplementedError(
"This method must be implemented by the AI provider client."
)
def get_conversation(self, conversation_id: str) -> Conversation:
if conversation_id not in self.conversations:
raise ValueError("Conversation ID does not exist.")
return self.conversations[conversation_id]
-110
View File
@@ -1,110 +0,0 @@
import os
import uuid
from typing import List, Optional, Dict, Any
from ..core.errors import ProviderError
from ..core.logger import logger
from .base import BaseClientProvider
from some_module import BaseOllama # Replace 'some_module' with the actual module name
TIMEOUT = 60
class Ollama(BaseClientProvider):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.login()
self.conversation = []
def login(self):
"""Initialize Ollama client, with Instructor enabled."""
if not os.environ.get("OLLAMA_HOST_URL"):
raise ValueError("Please set the OLLAMA_HOST_URL environment variable")
if not os.environ.get("OLLAMA_MODEL"):
raise ValueError("Please set the OLLAMA_MODEL environment variable")
else:
self.model = os.environ.get("OLLAMA_MODEL")
self.client = BaseOllama(
timeout=TIMEOUT, host=os.environ.get("OLLAMA_HOST_URL")
)
assert self.test_connection()
@property
def available_models(self):
"""Returns the available models from the Ollama client."""
def gen():
for model in self.client.list().get("models"):
yield model
return [g for g in gen()]
def test_connection(self):
"""Test the connection to Ollama. Returns True if successful."""
return bool(len(self.available_models))
def generate_text(self, prompt, *, response_model=False, **kwargs):
use_instructor = bool(response_model)
client = self.instructor_client if use_instructor else self.client
# Parameters for the Ollama client.
params = {
"prompt": prompt,
"model": self.model,
}
params.update(kwargs)
if use_instructor:
params["response_model"] = response_model
# Make the request to Ollama.
completion = client.generate(**params)
if use_instructor:
return completion.model_dump()
else:
return AIResponse(
response=completion,
text=completion.get("response"),
)
def message(
self, message=None, message_history=None, response_model=False, **kwargs
):
"""Generates a response from the Ollama client."""
use_instructor = bool(response_model)
client = self.instructor_client if use_instructor else self.client
# Parameters for the Ollama client.
all_messages = []
if message_history:
all_messages.extend(message_history)
if message:
all_messages.append({"role": "user", "content": message})
params = {
"messages": all_messages,
"model": self.model,
}
params.update(kwargs)
if use_instructor:
params["response_model"] = response_model
# Make the request to Ollama.
completion = client.chat(**params)
if use_instructor:
return completion.model_dump()
else:
return AIResponse(
response=completion,
text=completion.get("message").get("content"),
)
def start_conversation(self):
return Conversation(self)
+45 -57
View File
@@ -1,69 +1,57 @@
from typing import List, Optional import openai as oa
import openai import instructor
from ..core.errors import AuthenticationError, ProviderError
from ..core.config import settings
from ..core.logger import logger
from .base import BaseClientProvider
DEFAULT_MODEL = "gpt-4o" # from ..models import Conversation, Message
from ..settings import settings
PROVIDER_NAME = "openai"
DEFAULT_MODEL = "gpt-4o-mini"
class OpenAI(BaseClientProvider): class OpenAI:
def __init__(self, model: str = DEFAULT_MODEL, api_key: Optional[str] = None): __name__ = PROVIDER_NAME
super().__init__(model=model, api_key=api_key) DEFAULT_MODEL = DEFAULT_MODEL
self.login()
def login(self) -> None: def __init__(self, api_key: str = None):
if not self._api_key: self.api_key = api_key or settings.OPENAI_API_KEY
self._api_key = settings.openai_api_key
if not self._api_key:
raise AuthenticationError("OpenAI API key not provided")
try:
openai.api_key = self._api_key
if not self.test_connection():
raise ProviderError("Failed to connect to OpenAI API")
logger.info("Successfully connected to OpenAI")
except Exception as e:
raise ProviderError(f"OpenAI initialization failed: {str(e)}")
@property @property
def available_models(self) -> List[str]: def client(self):
try: """The raw OpenAI client."""
models = openai.models.list() return oa.OpenAI(api_key=self.api_key)
return [model.id for model in models["data"]]
except Exception as e:
logger.error(f"Error fetching models: {e}")
return []
def test_connection(self): @property
"""Test the connection to OpenAI API.""" def structured_client(self):
try: """A client patched with Instructor."""
openai.models.list() return instructor.patch(oa.OpenAI(api_key=self.api_key))
return True
except Exception as e:
logger.error(f"Connection test failed: {e}")
return False
def _handle_api_error(self, e: Exception) -> None: def send_conversation(self, conversation: "Conversation"):
"""Handle API errors.""" """Send a conversation to the OpenAI API."""
logger.error(f"OpenAI API error: {e}") from ..models import Message
raise ProviderError(f"OpenAI API error: {e}")
def add_message(self, conversation_id, message, *args, **kwargs) -> str: messages = [
"""Generate a response using the OpenAI API.""" {"role": msg.role, "content": msg.text} for msg in conversation.messages
try: ]
# Create a client instance using the API key
client = openai.OpenAI(api_key=self._api_key)
# Create the message for the conversation response = self.client.chat.completions.create(
messages = [{"role": "user", "content": message}] model=conversation.llm_model or DEFAULT_MODEL, messages=messages
)
# Use the new API syntax # Get the response content from the OpenAI response
response = client.chat.completions.create( assistant_message = response.choices[0].message
model=self.model, messages=messages, *args, **kwargs
)
return response.choices[0].message.content # Create and return a properly formatted Message instance
except Exception as e: return Message(
self._handle_api_error(e) role="assistant",
text=assistant_message.content,
raw=response,
llm_model=conversation.llm_model or DEFAULT_MODEL,
llm_provider=PROVIDER_NAME,
)
def structured_response(self, model, response_model, **kwargs):
client = instructor.patch(oa.OpenAI(api_key=self.api_key))
response = client.chat.completions.create(
model=model, response_model=response_model, **kwargs
)
return response
+10
View File
@@ -0,0 +1,10 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY")
ANTHROPIC_API_KEY: str = Field(..., env="ANTHROPIC_API_KEY")
settings = Settings()
-24
View File
@@ -1,24 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
from simplemind.providers.openai import OpenAI
from simplemind.core.errors import AuthenticationError, ProviderError
class TestOpenAIProvider(unittest.TestCase):
def setUp(self):
self.api_key = "test_key"
@patch("simplemind.providers.openai.openai")
def test_initialization(self, mock_openai):
mock_openai.Model.list.return_value = {"data": ["gpt-4"]}
provider = OpenAI(api_key=self.api_key)
self.assertIsNotNone(provider.client)
mock_openai.Model.list.assert_called_once()
def test_missing_api_key(self):
with self.assertRaises(AuthenticationError):
OpenAI(api_key=None)
if __name__ == "__main__":
unittest.main()
+10
View File
@@ -0,0 +1,10 @@
from .providers import providers
def find_provider(provider_name: str):
"""Find a provider by name."""
for provider_class in providers:
if provider_class.__name__.lower() == provider_name.lower():
# Instantiate the provider
return provider_class()
raise ValueError(f"Provider {provider_name} not found")
View File
-19
View File
@@ -1,19 +0,0 @@
import faiss
import numpy as np
from typing import List
class FAISSStore:
def __init__(self, dimension: int):
self.dimension = dimension
self.index = faiss.IndexFlatL2(dimension)
self.ids = []
def add_embeddings(self, embeddings: np.ndarray, ids: List[str]):
self.index.add(embeddings)
self.ids.extend(ids)
def search(self, query_embedding: np.ndarray, top_k: int = 5):
distances, indices = self.index.search(query_embedding, top_k)
results = [(self.ids[idx], distances[i]) for i, idx in enumerate(indices[0])]
return results
+4 -6
View File
@@ -1,9 +1,7 @@
from simplemind import SimpleMind import simplemind as sm
sm = SimpleMind()
# The provider will automatically use OPENAI_API_KEY from environment
conversation = sm.create_conversation() conversation = sm.create_conversation()
r = sm.add_message(conversation.id, "Who is Kenneth Reitz?") conversation.add_message(role="user", text="Hello, how are you?")
r = conversation.send(llm_model="gpt-4o-mini", llm_provider="openai")
print(r) print(r)
#
-22
View File
@@ -1,22 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
from simplemind.providers.openai import OpenAI
from simplemind.core.errors import AuthenticationError, ProviderError
class TestOpenAIProvider(unittest.TestCase):
def setUp(self):
self.api_key = "test_key"
@patch("simplemind.integrations.openai.BaseOpenAI")
def test_initialization(self, mock_openai):
provider = OpenAI(api_key=self.api_key)
self.assertIsNotNone(provider.client)
def test_missing_api_key(self):
with self.assertRaises(AuthenticationError):
OpenAI(api_key=None)
if __name__ == "__main__":
unittest.main()