diff --git a/examples/contextual_memory.py b/examples/contextual_memory.py index 2e4e542..b10b0c5 100644 --- a/examples/contextual_memory.py +++ b/examples/contextual_memory.py @@ -1,12 +1,12 @@ -from _context import sm - -from pydantic import BaseModel -import openai -import faiss -import numpy as np import os import pickle +import faiss +import numpy as np +import openai +from _context import sm +from pydantic import BaseModel + class ContextualMemoryPlugin(sm.BasePlugin): def __init__( diff --git a/examples/generate_data.py b/examples/generate_data.py index a682644..75220a3 100644 --- a/examples/generate_data.py +++ b/examples/generate_data.py @@ -1,8 +1,7 @@ -from typing import List, Iterator - -from pydantic import BaseModel +from typing import Iterator, List from _context import sm +from pydantic import BaseModel class Movie(BaseModel): diff --git a/examples/sentiment_analysis.py b/examples/sentiment_analysis.py index 9cda733..7b855b6 100644 --- a/examples/sentiment_analysis.py +++ b/examples/sentiment_analysis.py @@ -1,8 +1,8 @@ -from _context import sm - -from pydantic import BaseModel from typing import Literal +from _context import sm +from pydantic import BaseModel + class SentimentAnalysis(BaseModel): sentiment: Literal["positive", "negative", "neutral"] diff --git a/examples/two_llms_talking.py b/examples/two_llms_talking.py index d098599..f88dd33 100644 --- a/examples/two_llms_talking.py +++ b/examples/two_llms_talking.py @@ -1,6 +1,7 @@ -import simplemind as sm import time +import simplemind as sm + class ConversationPlugin(sm.BasePlugin): def post_send_hook(self, conversation, response): diff --git a/simplemind/__init__.py b/simplemind/__init__.py index 49a3f99..90313de 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -1,8 +1,8 @@ from typing import List, Type -from .models import Conversation, BasePlugin, BaseModel -from .utils import find_provider +from .models import BaseModel, BasePlugin, Conversation from .settings import settings +from .utils import find_provider class Session: @@ -81,7 +81,7 @@ def generate_data( *, llm_model: str | None = None, llm_provider: str | None = None, - response_model: Type[BaseModel] = None, + response_model: Type[BaseModel], **kwargs, ) -> BaseModel: """Generate structured data from a given prompt.""" diff --git a/simplemind/models.py b/simplemind/models.py index c2a8e67..6923ead 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -2,12 +2,10 @@ import uuid from datetime import datetime from typing import Any, Dict, List, Literal, Optional - from pydantic import BaseModel, Field from .utils import find_provider - MESSAGE_ROLE = Literal["system", "user", "assistant"] diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index d142a4d..201f851 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -4,8 +4,8 @@ from ._base import BaseProvider from .anthropic import Anthropic from .gemini import Gemini from .groq import Groq -from .openai import OpenAI from .ollama import Ollama +from .openai import OpenAI from .xai import XAI providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI] diff --git a/simplemind/providers/_base.py b/simplemind/providers/_base.py index dc1abdc..a8d4a4c 100644 --- a/simplemind/providers/_base.py +++ b/simplemind/providers/_base.py @@ -1,6 +1,10 @@ from abc import ABC, abstractmethod +from typing import Type, TypeVar from instructor import Instructor +from pydantic import BaseModel + +T = TypeVar("T", bound=BaseModel) class BaseProvider(ABC): @@ -27,7 +31,7 @@ class BaseProvider(ABC): raise NotImplementedError @abstractmethod - def structured_response(self, prompt: str, response_model, **kwargs): + def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: """Get a structured response.""" raise NotImplementedError diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index bfb31c6..38cdbc3 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -1,8 +1,14 @@ +from typing import Type, TypeVar + import anthropic import instructor +from pydantic import BaseModel -from ._base import BaseProvider from ..settings import settings +from ._base import BaseProvider + +T = TypeVar("T", bound=BaseModel) + PROVIDER_NAME = "anthropic" DEFAULT_MODEL = "claude-3-5-sonnet-20241022" @@ -55,7 +61,7 @@ class Anthropic(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, model: str, response_model, **kwargs): + def structured_response(self, model: str, response_model: Type[T], **kwargs) -> T: response = self.structured_client.messages.create( model=model or self.DEFAULT_MODEL, response_model=response_model, **kwargs ) diff --git a/simplemind/providers/gemini.py b/simplemind/providers/gemini.py index 5bdd25d..ec21edc 100644 --- a/simplemind/providers/gemini.py +++ b/simplemind/providers/gemini.py @@ -1,13 +1,18 @@ -import instructor -import google.generativeai as genai +from typing import Type, TypeVar -from ._base import BaseProvider +import google.generativeai as genai +import instructor +from pydantic import BaseModel + +from ..models import Conversation, Message from ..settings import settings -from ..models import Message, Conversation +from ._base import BaseProvider PROVIDER_NAME = "gemini" DEFAULT_MODEL = "models/gemini-1.5-flash-latest" +T = TypeVar("T", bound=BaseModel) + class Gemini(BaseProvider): NAME = PROVIDER_NAME @@ -59,12 +64,13 @@ class Gemini(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt: str, response_model, **kwargs): + def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: """Send a structured response to the Gemini API.""" try: response = self.structured_client.chat.completions.create( messages=[{"role": "user", "content": prompt}], response_model=response_model, + **kwargs, ) except Exception as e: # Handle the exception appropriately, e.g., log the error or raise a custom exception @@ -77,7 +83,9 @@ class Gemini(BaseProvider): """Generate text using the Gemini API.""" try: response = self.structured_client.chat.completions.create( - messages=[{"role": "user", "content": prompt}], response_model=None + messages=[{"role": "user", "content": prompt}], + response_model=None, + **kwargs, ) except Exception as e: # Handle the exception appropriately, e.g., log the error or raise a custom exception diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 62bd54c..574ff10 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -1,12 +1,17 @@ +from typing import Type, TypeVar + import groq import instructor +from pydantic import BaseModel -from ._base import BaseProvider from ..settings import settings +from ._base import BaseProvider PROVIDER_NAME = "groq" DEFAULT_MODEL = "llama3-8b-8192" +T = TypeVar("T", bound=BaseModel) + class Groq(BaseProvider): NAME = PROVIDER_NAME @@ -57,7 +62,7 @@ class Groq(BaseProvider): llm_provider=PROVIDER_NAME, ) - def structured_response(self, prompt: str, response_model, **kwargs): + def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: # Ensure messages are provided in kwargs messages = [ {"role": "user", "content": prompt}, diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index d10a036..9950db9 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -1,9 +1,15 @@ -import ollama as ol -import instructor -from openai import OpenAI +from typing import Type, TypeVar + +import instructor +import ollama as ol +from openai import OpenAI +from pydantic import BaseModel -from ._base import BaseProvider from ..settings import settings +from ._base import BaseProvider + +T = TypeVar("T", bound=BaseModel) + PROVIDER_NAME = "ollama" DEFAULT_MODEL = "llama3.2" @@ -58,8 +64,9 @@ class Ollama(BaseProvider): ) def structured_response( - self, prompt: str, response_model, *, llm_model: str, **kwargs - ): + self, prompt: str, response_model: Type[T], *, llm_model: str, **kwargs + ) -> T: + """Get a structured response from the Ollama API.""" messages = [ {"role": "user", "content": prompt}, ] @@ -72,7 +79,8 @@ class Ollama(BaseProvider): ) return response - def generate_text(self, prompt: str, *, llm_model: str): + def generate_text(self, prompt: str, *, llm_model: str) -> str: + """Generate text using the Ollama API.""" messages = [ {"role": "user", "content": prompt}, ] diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 549f28a..cc660be 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,8 +1,13 @@ +from typing import Type, TypeVar + import instructor import openai as oa +from pydantic import BaseModel -from ._base import BaseProvider from ..settings import settings +from ._base import BaseProvider + +T = TypeVar("T", bound=BaseModel) PROVIDER_NAME = "openai" DEFAULT_MODEL = "gpt-4o-mini" @@ -52,13 +57,18 @@ class OpenAI(BaseProvider): ) def structured_response( - self, prompt: str, response_model, *, llm_model: str | None = None, **kwargs - ): + self, + prompt: str, + response_model: Type[T], + *, + llm_model: str | None = None, + **kwargs, + ) -> T: + """Get a structured response from the OpenAI API.""" # Ensure messages are provided in kwargs messages = [ {"role": "user", "content": prompt}, ] - response = self.structured_client.chat.completions.create( messages=messages, model=llm_model or self.DEFAULT_MODEL, @@ -68,12 +78,11 @@ class OpenAI(BaseProvider): return response def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs): + """Generate text using the OpenAI API.""" messages = [ {"role": "user", "content": prompt}, ] - response = self.client.chat.completions.create( messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs ) - return response.choices[0].message.content diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 697c407..78cb096 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -1,8 +1,8 @@ import instructor import openai as oa -from ._base import BaseProvider from ..settings import settings +from ._base import BaseProvider PROVIDER_NAME = "xai" DEFAULT_MODEL = "grok-beta" diff --git a/simplemind/utils.py b/simplemind/utils.py index 67723c5..baabbae 100644 --- a/simplemind/utils.py +++ b/simplemind/utils.py @@ -1,6 +1,6 @@ import difflib -from .providers import providers, BaseProvider +from .providers import BaseProvider, providers _PROVIDER_NAMES = [provider.NAME.lower() for provider in providers]