diff --git a/examples/cooking_recipe_example.py b/examples/cooking_recipe_example.py index 8abc295..077d9ae 100644 --- a/examples/cooking_recipe_example.py +++ b/examples/cooking_recipe_example.py @@ -1,5 +1,5 @@ -from pydantic import BaseModel from _context import simplemind as sm +from pydantic import BaseModel from rich.console import Console from rich.panel import Panel from rich.text import Text diff --git a/examples/discussion.py b/examples/discussion.py index d63353d..42781b3 100644 --- a/examples/discussion.py +++ b/examples/discussion.py @@ -1,11 +1,10 @@ import time from typing import List, Tuple +from _context import sm from rich.console import Console from rich.markdown import Markdown -from _context import sm - class MultiAIConversation: """Orchestrates conversations between multiple AI models.""" diff --git a/examples/enhanced_context.py b/examples/enhanced_context.py index a40fd1f..ba63f8a 100644 --- a/examples/enhanced_context.py +++ b/examples/enhanced_context.py @@ -1,35 +1,28 @@ -from datetime import datetime -import logging -import sqlite3 -from typing import List -import re -import os import contextlib - -import spacy +import logging +import os +import random +import re +import sqlite3 +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager - -from _context import simplemind as sm +from datetime import datetime +from typing import List import nltk -from nltk.tokenize import word_tokenize -from nltk.tag import pos_tag - -from rich.console import Console -from rich.panel import Panel -from rich.markdown import Markdown -from rich.status import Status - -from concurrent.futures import ThreadPoolExecutor -import random - -from docopt import docopt - -from prompt_toolkit import PromptSession -from prompt_toolkit.completion import Completer, Completion -from prompt_toolkit.auto_suggest import AutoSuggestFromHistory - +import spacy import xerox +from _context import simplemind as sm +from docopt import docopt +from nltk.tag import pos_tag +from nltk.tokenize import word_tokenize +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.completion import Completer, Completion +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.status import Status DB_PATH = "enhanced_context.db" AVAILABLE_PROVIDERS = ["xai", "openai", "anthropic", "ollama"] diff --git a/examples/four_way_conversation.py b/examples/four_way_conversation.py index c414e46..2ee283a 100644 --- a/examples/four_way_conversation.py +++ b/examples/four_way_conversation.py @@ -1,4 +1,5 @@ import time + from _context import simplemind as sm diff --git a/examples/inspiration_plugin.py b/examples/inspiration_plugin.py index db2001b..6180385 100644 --- a/examples/inspiration_plugin.py +++ b/examples/inspiration_plugin.py @@ -1,4 +1,5 @@ import random + from _context import simplemind as sm diff --git a/examples/medicine_data.py b/examples/medicine_data.py index b6a7be4..361e833 100644 --- a/examples/medicine_data.py +++ b/examples/medicine_data.py @@ -1,5 +1,5 @@ -from pydantic import BaseModel from _context import simplemind as sm +from pydantic import BaseModel from rich.console import Console from rich.panel import Panel from rich.table import Table diff --git a/examples/mood_detector_plugin.py b/examples/mood_detector_plugin.py index 5f61c36..e15c21e 100644 --- a/examples/mood_detector_plugin.py +++ b/examples/mood_detector_plugin.py @@ -1,7 +1,7 @@ import nltk +from _context import simplemind as sm from nltk.sentiment import SentimentIntensityAnalyzer from rich.console import Console -from _context import simplemind as sm nltk.download("vader_lexicon") diff --git a/examples/sentiment_analysis.py b/examples/sentiment_analysis.py index 8354046..2c380fa 100644 --- a/examples/sentiment_analysis.py +++ b/examples/sentiment_analysis.py @@ -5,6 +5,7 @@ from pydantic import BaseModel # Note: you should probably be using textblob for this. + class SentimentAnalysis(BaseModel): sentiment: Literal["positive", "negative", "neutral"] confidence: float diff --git a/examples/web_bible_explorer.py b/examples/web_bible_explorer.py index 4012078..373c650 100644 --- a/examples/web_bible_explorer.py +++ b/examples/web_bible_explorer.py @@ -1,8 +1,10 @@ -from fastapi import FastAPI, Request, HTTPException -from fastapi.templating import Jinja2Templates -from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel from typing import List + +from fastapi import FastAPI, HTTPException, Request +from fastapi.staticfiles import StaticFiles +from fastapi.templating import Jinja2Templates +from pydantic import BaseModel + import simplemind as sm app = FastAPI() diff --git a/pyproject.toml b/pyproject.toml index acbee0e..3135d18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,18 +10,18 @@ dependencies = ["pydantic", "pydantic-settings", "instructor", "logfire"] full = [ "openai", "anthropic", - "ollama", "groq", "google-generativeai", "botocore", "boto3" ] -openai = ["openai"] -anthropic = ["anthropic"] -ollama = ["ollama", "openai"] -groq = ["groq"] -gemini = ["google-generativeai"] amazon = ["boto3", "botocore", "anthropic"] +anthropic = ["anthropic"] +gemini = ["google-generativeai"] +groq = ["groq"] +ollama = ["openai"] +openai = ["openai"] +xai = ["openai"] [build-system] requires = ["hatchling"] diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index cc6c465..1f72f78 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -1,12 +1,32 @@ from typing import List, Type from ._base import BaseProvider +from .amazon import Amazon from .anthropic import Anthropic from .gemini import Gemini from .groq import Groq from .ollama import Ollama from .openai import OpenAI from .xai import XAI -from .amazon import Amazon -providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI, Amazon] +providers: List[Type[BaseProvider]] = [ + Anthropic, + Gemini, + Groq, + OpenAI, + Ollama, + XAI, + Amazon, +] + +__all__ = [ + "Anthropic", + "Gemini", + "Groq", + "OpenAI", + "Ollama", + "XAI", + "Amazon", + "providers", + "BaseProvider", +] diff --git a/simplemind/providers/amazon.py b/simplemind/providers/amazon.py index fe82d9a..492f2ab 100644 --- a/simplemind/providers/amazon.py +++ b/simplemind/providers/amazon.py @@ -1,22 +1,24 @@ -from typing import Type, TypeVar, Iterator from functools import cached_property +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel -from ._base import BaseProvider +from simplemind.models import Message + from ..settings import settings +from ._base import BaseProvider + +if TYPE_CHECKING: + from ..models import Conversation, Message T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "amazon" -DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0" -DEFAULT_MAX_TOKENS = 5_000 - class Amazon(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL + NAME = "amazon" + DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + DEFAULT_MAX_TOKENS = 5_000 supports_streaming = True def __init__(self, profile_name: str | None = None): @@ -25,7 +27,12 @@ class Amazon(BaseProvider): @cached_property def client(self): """The AnthropicBedrock client.""" - import anthropic + try: + import anthropic + except ImportError as exc: + raise ImportError( + "Please install the `anthropic` package: `pip install anthropic`" + ) from exc if not self.profile_name: raise ValueError("Profile name is not provided") @@ -33,12 +40,12 @@ class Amazon(BaseProvider): return anthropic.AnthropicBedrock(aws_profile=self.profile_name) @cached_property - def structured_client(self): + def structured_client(self) -> instructor.Instructor: """A client patched with Instructor.""" return instructor.from_anthropic(self.client) - def send_conversation(self, conversation: "Conversation", **kwargs): + def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message": """Send a conversation to the OpenAI API.""" from ..models import Message @@ -59,7 +66,7 @@ class Amazon(BaseProvider): role="assistant", text=assistant_message.content or "", raw=response, - llm_model=conversation.llm_model or DEFAULT_MODEL, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, llm_provider=PROVIDER_NAME, ) @@ -75,12 +82,12 @@ class Amazon(BaseProvider): messages=messages, model=llm_model or self.DEFAULT_MODEL, response_model=response_model, - max_tokens=DEFAULT_MAX_TOKENS, + max_tokens=self.DEFAULT_MAX_TOKENS, **kwargs, ) return response - def generate_text(self, prompt, *, llm_model, **kwargs): + def generate_text(self, prompt: str, *, llm_model: str, **kwargs): messages = [ {"role": "user", "content": prompt}, ] @@ -88,13 +95,15 @@ class Amazon(BaseProvider): response = self.client.messages.create( model=llm_model or self.DEFAULT_MODEL, messages=messages, - max_tokens=DEFAULT_MAX_TOKENS, + max_tokens=self.DEFAULT_MAX_TOKENS, **kwargs, ) return response.content[0].text - def generate_stream_text(self, prompt, *, llm_model, **kwargs) -> Iterator[str]: + def generate_stream_text( + self, prompt: str, *, llm_model: str, **kwargs + ) -> Iterator[str]: """Generate streaming text using the Amazon API.""" # Prepare the messages. diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index ffa776d..6fbec5b 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -14,20 +14,15 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "anthropic" -DEFAULT_MODEL = "claude-3-5-sonnet-20241022" -DEFAULT_MAX_TOKENS = 1_000 -DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} - - class Anthropic(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS + NAME = "anthropic" + DEFAULT_MODEL = "claude-3-5-sonnet-20241022" + DEFAULT_MAX_TOKENS = 1_000 + DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} supports_streaming = True def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.api_key = api_key or settings.get_api_key(self.NAME) @cached_property def client(self): @@ -72,7 +67,7 @@ class Anthropic(BaseProvider): text=assistant_message, raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/providers/gemini.py b/simplemind/providers/gemini.py index 0123bf4..7415b0d 100644 --- a/simplemind/providers/gemini.py +++ b/simplemind/providers/gemini.py @@ -2,7 +2,7 @@ # IT is not currently working as desired. from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -17,18 +17,14 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "gemini" -DEFAULT_MODEL = "models/gemini-1.5-flash-latest" - - class Gemini(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL + NAME = "gemini" + DEFAULT_MODEL = "models/gemini-1.5-flash-latest" supports_streaming = True def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) - self.model_name = DEFAULT_MODEL + self.api_key = api_key or settings.get_api_key(self.NAME) + self.model_name = self.DEFAULT_MODEL def set_model(self, model_name: str): self.model_name = model_name @@ -76,7 +72,7 @@ class Gemini(BaseProvider): text=response.text, raw=response, llm_model=self.model_name, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 45593aa..107ce80 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -14,20 +14,15 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "groq" -DEFAULT_MODEL = "llama3-8b-8192" -DEFAULT_MAX_TOKENS = 1_000 -DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} - - class Groq(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS + NAME = "groq" + DEFAULT_MODEL = "llama3-8b-8192" + DEFAULT_MAX_TOKENS = 1_000 + DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} supports_streaming = True def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.api_key = api_key or settings.get_api_key(self.NAME) @cached_property def client(self): @@ -75,7 +70,7 @@ class Groq(BaseProvider): text=assistant_message.content or "", raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index 84df560..4266e68 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -1,8 +1,7 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor -from openai import OpenAI from pydantic import BaseModel from ..logging import logger @@ -15,17 +14,11 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "ollama" -DEFAULT_MODEL = "llama3.2" -DEFAULT_TIMEOUT = 60 -DEFAULT_KWARGS = {} - - class Ollama(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS - TIMEOUT = DEFAULT_TIMEOUT + NAME = "ollama" + DEFAULT_MODEL = "llama3.2" + DEFAULT_TIMEOUT = 60 + DEFAULT_KWARGS = {} supports_streaming = True def __init__(self, host_url: str | None = None): @@ -37,21 +30,18 @@ class Ollama(BaseProvider): if not self.host_url: raise ValueError("No ollama host url provided") try: - import ollama as ol + import openai except ImportError as exc: raise ImportError( - "Please install the `ollama` package: `pip install ollama`" + "Please install the `openai` package: `pip install openai`" ) from exc - return ol.Client(timeout=self.TIMEOUT, host=self.host_url) + return openai.OpenAI(base_url=f"{self.host_url}/v1", api_key="ollama") @cached_property def structured_client(self) -> instructor.Instructor: """A client patched with Instructor.""" return instructor.from_openai( - OpenAI( - base_url=f"{self.host_url}/v1", - api_key="ollama", - ), + self.client, mode=instructor.Mode.JSON, ) @@ -64,7 +54,7 @@ class Ollama(BaseProvider): {"role": msg.role, "content": msg.text} for msg in conversation.messages ] response = self.client.chat( - model=conversation.llm_model or DEFAULT_MODEL, + model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, **{**self.DEFAULT_KWARGS, **kwargs}, ) @@ -76,7 +66,7 @@ class Ollama(BaseProvider): text=assistant_message.get("content"), raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 58e4e21..84db110 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -13,20 +13,16 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "openai" -DEFAULT_MODEL = "gpt-4o-mini" -DEFAULT_MAX_TOKENS = None -DEFAULT_KWARGS = {} - class OpenAI(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS + NAME = "openai" + DEFAULT_MODEL = "gpt-4o-mini" + DEFAULT_MAX_TOKENS = None + DEFAULT_KWARGS = {} supports_streaming = True def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.api_key = api_key or settings.get_api_key(self.NAME) @cached_property def client(self): @@ -56,7 +52,7 @@ class OpenAI(BaseProvider): ] response = self.client.chat.completions.create( - model=conversation.llm_model or DEFAULT_MODEL, + model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, **{**self.DEFAULT_KWARGS, **kwargs}, ) @@ -69,8 +65,8 @@ class OpenAI(BaseProvider): role="assistant", text=assistant_message.content or "", raw=response, - llm_model=conversation.llm_model or DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 279c150..23d8a5c 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -14,22 +14,17 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "xai" -DEFAULT_MODEL = "grok-beta" -BASE_URL = "https://api.x.ai/v1" -DEFAULT_MAX_TOKENS = 1000 -DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} - - class XAI(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS + NAME = "xai" + DEFAULT_MODEL = "grok-beta" + DEFAULT_MAX_TOKENS = 1000 + DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} + BASE_URL = "https://api.x.ai/v1" supports_streaming = True supports_structured_responses = False def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.api_key = api_key or settings.get_api_key(self.NAME) @cached_property def client(self): @@ -44,7 +39,7 @@ class XAI(BaseProvider): ) from exc return oa.OpenAI( api_key=self.api_key, - base_url=BASE_URL, + base_url=self.BASE_URL, ) @cached_property @@ -76,7 +71,7 @@ class XAI(BaseProvider): text=assistant_message.content, raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/settings.py b/simplemind/settings.py index ee11063..de69339 100644 --- a/simplemind/settings.py +++ b/simplemind/settings.py @@ -14,8 +14,9 @@ class LoggingConfig(BaseSettings): """Enable logging for the application.""" # adding imports here to avoid forced dependencies try: - import logfire from logging import basicConfig + + import logfire except ImportError as e: raise ImportError( "To enable logging, please install logfire: `pip install logfire`" diff --git a/tests/test_conversations.py b/tests/test_conversations.py index b0c85dd..4747f65 100644 --- a/tests/test_conversations.py +++ b/tests/test_conversations.py @@ -1,8 +1,7 @@ import pytest -from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon - import simplemind as sm +from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI @pytest.mark.parametrize( diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index 610c96a..cd00bc6 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -1,9 +1,8 @@ import pytest - -from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon - from pydantic import BaseModel +from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI + class ResponseModel(BaseModel): result: int diff --git a/tests/test_generate_text.py b/tests/test_generate_text.py index 4ab62cf..2b3eb0e 100644 --- a/tests/test_generate_text.py +++ b/tests/test_generate_text.py @@ -1,6 +1,6 @@ import pytest -from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon +from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI @pytest.mark.parametrize(