49 Commits

Author SHA1 Message Date
kennethreitz cd0be3ad89 Refactor LoggingConfig methods for enabling and disabling logging 2024-11-01 08:36:05 -04:00
kennethreitz 3dd2e1b248 Refactor Gemini provider to handle missing llm_model key 2024-11-01 08:28:53 -04:00
Siddhesh Agarwal ad1800840d small changes 2024-11-01 15:27:15 +05:30
Siddhesh Agarwal d62f297b68 removed unused variable 2024-11-01 15:16:20 +05:30
Siddhesh Agarwal a2597709d2 gemini works as expected 2024-11-01 14:55:22 +05:30
Siddhesh Agarwal 1455b5ba13 remove unused import 2024-11-01 14:31:19 +05:30
Siddhesh Agarwal 0fb54d1987 circular import problem solve 2024-11-01 14:31:01 +05:30
Siddhesh Agarwal fe06331662 fixed forced imports + ensured return type in structure_response 2024-11-01 14:24:34 +05:30
Siddhesh Agarwal 56b1e65d70 moved logging functions to LoggingConfig from Settings 2024-11-01 13:06:06 +05:30
Siddhesh Agarwal 4b3e1bc6dd added methods to toggle logging 2024-11-01 12:55:24 +05:30
Siddhesh Agarwal f5b922ade8 added proper type hinting 2024-11-01 12:25:44 +05:30
Siddhesh Agarwal 3a7383425f sorted imports 2024-11-01 11:09:54 +05:30
Siddhesh Agarwal 92c10fc41e added logging 2024-11-01 11:07:04 +05:30
kennethreitz caceba381d Refactor default_kwargs logic in Ollama provider 2024-10-31 19:49:33 -04:00
kennethreitz 0795464fd7 Merge pull request #24 from barisozmen/default_kwargs
Add default kwargs logic to Groq, OpenAI, XAI, and Ollama providers
2024-10-31 19:48:02 -04:00
Barış Özmen d82effdfb1 added default_kwargs logic to xAI provider 2024-11-01 00:18:57 +03:00
Barış Özmen e648292cb3 added default_kwargs logic to Ollama provider 2024-11-01 00:17:22 +03:00
Barış Özmen 37a9333be3 added default_kwargs logic to OpenAI provider 2024-11-01 00:15:49 +03:00
Barış Özmen cbc3739411 added default_kwargs logic to Groq provider 2024-11-01 00:14:41 +03:00
kennethreitz 7c8f22bef1 Update version to v0.1.6 and add sm.Plugin syntax sugar 2024-10-31 16:35:24 -04:00
kennethreitz 9c3f2a6df3 Refactor Anthropic provider and add tests for structured response and llm_model in structured_response 2024-10-31 16:33:44 -04:00
kennethreitz febf5473d5 Refactor message parameter in Anthropic provider 2024-10-31 16:33:01 -04:00
kennethreitz 48ac97f070 Refactor messages parameter in Anthropic provider 2024-10-31 16:29:58 -04:00
kennethreitz c41a3f00fb Add test for generating text with different providers 2024-10-31 16:22:05 -04:00
kennethreitz 25ee4ae32c Add test for basic math 2024-10-31 16:21:59 -04:00
kennethreitz 984721f02b Add conftest.py with fixture for simplemind Session 2024-10-31 16:21:54 -04:00
kennethreitz 69c8723770 Refactor DEFAULT_LLM_MODEL parameter in Settings class 2024-10-31 16:21:43 -04:00
kennethreitz 0c10d5676a Refactor max_tokens parameter in Anthropic provider 2024-10-31 16:21:36 -04:00
kennethreitz e0ddf41e15 Refactor llm_model parameter in Session class 2024-10-31 16:21:31 -04:00
kennethreitz f940ae2dfd the irony is not lost 2024-10-31 16:08:18 -04:00
kennethreitz 85fa4f5879 Add Plugin syntax sugar and improve Anthropic provider for max tokens 2024-10-31 16:08:07 -04:00
kennethreitz 44581e8fe3 Merge pull request #23 from barisozmen/issue_15
Add default kwargs logic into Anthropic provider, which is superseded by user entered kwargs
2024-10-31 16:00:46 -04:00
Barış Özmen 9503ec7fd3 Remove duplicate max_tokens parameter 2024-10-31 22:58:13 +03:00
Barış Özmen 418f36dcc0 kwargs supersede default kwargs for Anthropic provider methods 2024-10-31 22:46:17 +03:00
kennethreitz bf9683cfd0 Refactor code to use syntax sugar for Plugin class 2024-10-31 15:38:58 -04:00
kennethreitz 3909588f3e chore: Update CHANGELOG to include support for Python 3.10 2024-10-31 14:54:51 -04:00
kennethreitz 33d8f18bff refactor: Update Gemini provider to handle conversation-based completions and add structured response 2024-10-31 13:54:33 -04:00
kennethreitz d7388ef0d5 Update README.md 2024-10-31 13:54:17 -04:00
kennethreitz 02d10bfda9 Update README.md 2024-10-31 13:53:24 -04:00
kennethreitz 5dc6e7b006 Update README.md 2024-10-31 13:53:11 -04:00
kennethreitz 62933c8553 Update README.md 2024-10-31 13:52:55 -04:00
kennethreitz f0a6be73f8 Update README.md 2024-10-31 13:52:34 -04:00
kennethreitz 9257a04f34 Update README.md 2024-10-31 13:43:35 -04:00
kennethreitz 64dbe9a2e7 Update README.md 2024-10-31 13:42:20 -04:00
kennethreitz ccb8311089 Update README.md 2024-10-31 13:42:07 -04:00
kennethreitz 0c29380501 Update README.md 2024-10-31 13:20:26 -04:00
kennethreitz 7b43208a03 Update README.md 2024-10-31 13:20:04 -04:00
kennethreitz e931fd0eae Update README.md 2024-10-31 13:19:22 -04:00
kennethreitz 736d942527 Update README.md 2024-10-31 13:18:35 -04:00
20 changed files with 332 additions and 71 deletions
+1
View File
@@ -167,3 +167,4 @@ cython_debug/
src/** src/**
requirements.txt requirements.txt
Pipfile
+14
View File
@@ -1,6 +1,20 @@
Release History Release History
=============== ===============
## 0.1.6 (2024-10-31)
- Add `sm.Plugin` syntax sugar.
- Improvements to Anthropic provider, related to max tokens.
- General improvements.
- Add tests for structured response.
- Add `llm_model` to `structured_response`.
## 0.1.5 (2024-10-31)
- Add Gemini provider.
- Add structured response to Gemini provider.
- Support for Python 3.10.
## 0.1.4 (2024-10-30) ## 0.1.4 (2024-10-30)
- Introduce `Session` class to manage repeatability. - Introduce `Session` class to manage repeatability.
+4 -5
View File
@@ -6,8 +6,6 @@ Simplemind is AI library designed to simplify your experience with AI APIs in Py
![simplemind](https://github.com/user-attachments/assets/36df2103-2583-4958-ad5e-19cda7740256) ![simplemind](https://github.com/user-attachments/assets/36df2103-2583-4958-ad5e-19cda7740256)
[![Auto Wiki](https://img.shields.io/badge/Auto_Wiki-Mutable.ai-blue)](https://mutable.ai/kennethreitz/simplemind)
## Features ## Features
With Simplemind, tapping into AI is as easy as a friendly conversation. With Simplemind, tapping into AI is as easy as a friendly conversation.
@@ -18,7 +16,7 @@ With Simplemind, tapping into AI is as easy as a friendly conversation.
## Supported APIs ## Supported APIs
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`. To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`. The APIs remain identital between all supported providers/models.
- [**Anthropic's Claude**](https://www.anthropic.com/claude) - [**Anthropic's Claude**](https://www.anthropic.com/claude)
- [**Google's Gemini**](https://gemini.google/) - [**Google's Gemini**](https://gemini.google/)
@@ -27,7 +25,7 @@ To specify a specific provider or model, you can use the `llm_provider` and `llm
- [**OpenAI's GPT**](https://openai.com/gpt) - [**OpenAI's GPT**](https://openai.com/gpt)
- [**xAI's Grok**](https://x.ai/) - [**xAI's Grok**](https://x.ai/)
If you want to see Simplemind support, additional providers or models, please request a pull! If you want to see Simplemind support, additional providers or models, please send a pull request!
## Why SimpleMind? ## Why SimpleMind?
- **Intuitive**: Built with Pythonic simplicity and readability in mind. - **Intuitive**: Built with Pythonic simplicity and readability in mind.
@@ -50,7 +48,7 @@ First, authenticate your API keys by setting them in the environment variables:
$ export OPENAI_API_KEY="sk-..." $ export OPENAI_API_KEY="sk-..."
``` ```
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, and `GROQ_API_KEY`. This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
Next, import Simplemind and start using it: Next, import Simplemind and start using it:
@@ -219,3 +217,4 @@ Simplemind is licensed under the Apache 2.0 License.
## Acknowledgements ## Acknowledgements
Simplemind is inspired by the philosophy of "code for humans" and aims to make working with AI models accessible to all. Special thanks to the open-source community for their contributions and inspiration. Simplemind is inspired by the philosophy of "code for humans" and aims to make working with AI models accessible to all. Special thanks to the open-source community for their contributions and inspiration.
+1 -1
View File
@@ -16,7 +16,7 @@ import simplemind
project = "simplemind" project = "simplemind"
copyright = "2024 Kenneth Reitz" copyright = "2024 Kenneth Reitz"
author = "Kenneth Reitz" author = "Kenneth Reitz"
release = "v0.1.4" release = "v0.1.6"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+1 -1
View File
@@ -1,6 +1,6 @@
[project] [project]
name = "simplemind" name = "simplemind"
version = "0.1.4" version = "0.1.6"
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases." description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
+5 -1
View File
@@ -16,7 +16,7 @@ class Session:
self, self,
*, *,
llm_provider: str = settings.DEFAULT_LLM_PROVIDER, llm_provider: str = settings.DEFAULT_LLM_PROVIDER,
llm_model: str = settings.DEFAULT_LLM_MODEL, llm_model: str | None = None,
**kwargs, **kwargs,
): ):
self.llm_provider = llm_provider self.llm_provider = llm_provider
@@ -113,6 +113,9 @@ def generate_text(
return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs) return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs)
# Syntax sugar.
Plugin = BasePlugin
__all__ = [ __all__ = [
"create_conversation", "create_conversation",
"find_provider", "find_provider",
@@ -121,4 +124,5 @@ __all__ = [
"settings", "settings",
"BasePlugin", "BasePlugin",
"Session", "Session",
"Plugin",
] ]
+27
View File
@@ -0,0 +1,27 @@
import time
from typing import Any, Callable
import logfire
from .settings import settings
def logger(func: Callable[..., Any]) -> Callable[..., Any]:
"""A @logger decorator that logs the function parameters, function returns, and exceptions raised if logging is enabled."""
def wrapper(*args, **kwargs) -> Any:
if not settings.logging.enabled:
return func(*args, **kwargs)
logfire.info(f"Calling {func.__name__} with args: {args}, kwargs: {kwargs}")
t1 = time.perf_counter()
try:
result = func(*args, **kwargs)
t2 = time.perf_counter()
logfire.info(f"{func.__name__} returned: {result} in {t2-t1} seconds")
return result
except Exception as e:
t2 = time.perf_counter()
logfire.error(f"Error in {func.__name__}: {e} in {t2-t1} seconds")
raise e
return wrapper
+1 -1
View File
@@ -1,6 +1,6 @@
from types import TracebackType
import uuid import uuid
from datetime import datetime from datetime import datetime
from types import TracebackType
from typing import Any, Dict, List, Literal, Optional from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
+4 -1
View File
@@ -1,10 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import cached_property from functools import cached_property
from typing import Any, Type, TypeVar from typing import TYPE_CHECKING, Any, Type, TypeVar
from instructor import Instructor from instructor import Instructor
from pydantic import BaseModel from pydantic import BaseModel
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
+39 -13
View File
@@ -1,24 +1,29 @@
from functools import cached_property from functools import cached_property
from typing import Type, TypeVar from typing import TYPE_CHECKING, Type, TypeVar
import anthropic
import instructor import instructor
from pydantic import BaseModel from pydantic import BaseModel
from ..logging import logger
from ..settings import settings from ..settings import settings
from ._base import BaseProvider from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "anthropic" PROVIDER_NAME = "anthropic"
DEFAULT_MODEL = "claude-3-5-sonnet-20241022" DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
DEFAULT_MAX_TOKENS = 1000 DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class Anthropic(BaseProvider): class Anthropic(BaseProvider):
NAME = PROVIDER_NAME NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None): def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -28,6 +33,13 @@ class Anthropic(BaseProvider):
"""The raw Anthropic client.""" """The raw Anthropic client."""
if not self.api_key: if not self.api_key:
raise ValueError("Anthropic API key is required") raise ValueError("Anthropic API key is required")
try:
import anthropic
except ImportError as exc:
raise ImportError(
"Please install the `anthropic` package: `pip install anthropic`"
) from exc
return anthropic.Anthropic(api_key=self.api_key) return anthropic.Anthropic(api_key=self.api_key)
@cached_property @cached_property
@@ -35,7 +47,8 @@ class Anthropic(BaseProvider):
"""A client patched with Instructor.""" """A client patched with Instructor."""
return instructor.from_anthropic(self.client) return instructor.from_anthropic(self.client)
def send_conversation(self, conversation: "Conversation", **kwargs): @logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the Anthropic API.""" """Send a conversation to the Anthropic API."""
from ..models import Message from ..models import Message
@@ -46,8 +59,7 @@ class Anthropic(BaseProvider):
response = self.client.messages.create( response = self.client.messages.create(
model=conversation.llm_model or self.DEFAULT_MODEL, model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages, messages=messages,
max_tokens=DEFAULT_MAX_TOKENS, **{**self.DEFAULT_KWARGS, **kwargs},
**kwargs,
) )
# Get the response content from the Anthropic response # Get the response content from the Anthropic response
@@ -62,12 +74,27 @@ class Anthropic(BaseProvider):
llm_provider=PROVIDER_NAME, llm_provider=PROVIDER_NAME,
) )
def structured_response(self, model: str, response_model: Type[T], **kwargs) -> T: @logger
response = self.structured_client.messages.create( def structured_response(
model=model or self.DEFAULT_MODEL, response_model=response_model, **kwargs self, response_model: Type[T], *, llm_model: str | None = None, **kwargs
) ) -> T:
return response model = llm_model or self.DEFAULT_MODEL
# Extract the prompt from kwargs if it exists
prompt = kwargs.pop("prompt", kwargs.pop("messages", ""))
# Format the messages properly
messages = [{"role": "user", "content": prompt}]
response = self.structured_client.messages.create(
model=model,
messages=messages, # Add the messages parameter
response_model=response_model,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response_model.model_validate(response)
@logger
def generate_text(self, prompt: str, *, llm_model: str, **kwargs): def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
messages = [ messages = [
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
@@ -76,8 +103,7 @@ class Anthropic(BaseProvider):
response = self.client.messages.create( response = self.client.messages.create(
model=llm_model or self.DEFAULT_MODEL, model=llm_model or self.DEFAULT_MODEL,
messages=messages, messages=messages,
max_tokens=DEFAULT_MAX_TOKENS, **{**self.DEFAULT_KWARGS, **kwargs},
**kwargs,
) )
return response.content[0].text return response.content[0].text
+26 -10
View File
@@ -2,21 +2,25 @@
# IT is not currently working as desired. # IT is not currently working as desired.
from functools import cached_property from functools import cached_property
from typing import Type, TypeVar from typing import TYPE_CHECKING, Type, TypeVar
import google.generativeai as genai
import instructor import instructor
from pydantic import BaseModel from pydantic import BaseModel
from ..logging import logger
from ..settings import settings from ..settings import settings
from ._base import BaseProvider from ._base import BaseProvider
PROVIDER_NAME = "gemini" if TYPE_CHECKING:
DEFAULT_MODEL = "models/gemini-1.5-flash-latest" from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "gemini"
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
class Gemini(BaseProvider): class Gemini(BaseProvider):
NAME = PROVIDER_NAME NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL DEFAULT_MODEL = DEFAULT_MODEL
@@ -25,12 +29,21 @@ class Gemini(BaseProvider):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
self.model_name = DEFAULT_MODEL self.model_name = DEFAULT_MODEL
def set_model(self, model_name: str):
self.model_name = model_name
@cached_property @cached_property
def client(self, model_name: str = DEFAULT_MODEL): def client(self):
"""The raw Gemini client.""" """The raw Gemini client."""
if not self.api_key: if not self.api_key:
raise ValueError("Gemini API key is required") raise ValueError("Gemini API key is required")
self.model_name = model_name try:
import google.generativeai as genai
except ImportError as exc:
raise ImportError(
"Please install the `google-generativeai` package: `pip install google-generativeai`"
) from exc
genai.configure(api_key=self.api_key)
return genai.GenerativeModel(model_name=self.model_name) return genai.GenerativeModel(model_name=self.model_name)
@cached_property @cached_property
@@ -38,6 +51,7 @@ class Gemini(BaseProvider):
"""A Gemini client patched with Instructor.""" """A Gemini client patched with Instructor."""
return instructor.from_gemini(self.client) return instructor.from_gemini(self.client)
@logger
def send_conversation(self, conversation: "Conversation") -> "Message": def send_conversation(self, conversation: "Conversation") -> "Message":
"""Send a conversation to the Gemini API.""" """Send a conversation to the Gemini API."""
from ..models import Message from ..models import Message
@@ -64,9 +78,11 @@ class Gemini(BaseProvider):
llm_provider=PROVIDER_NAME, llm_provider=PROVIDER_NAME,
) )
@logger
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
"""Send a structured response to the Gemini API.""" """Send a structured response to the Gemini API."""
llm_model = kwargs.pop("llm_model", self.model_name) # Only try to pop if the key exists
kwargs.pop("llm_model", None) # Add default value of None
try: try:
response = self.structured_client.chat.completions.create( response = self.structured_client.chat.completions.create(
@@ -79,12 +95,12 @@ class Gemini(BaseProvider):
raise RuntimeError( raise RuntimeError(
f"Failed to send structured response to Gemini API: {e}" f"Failed to send structured response to Gemini API: {e}"
) from e ) from e
return response return response_model.model_validate(response)
@logger
def generate_text(self, prompt: str, **kwargs) -> str: def generate_text(self, prompt: str, **kwargs) -> str:
"""Generate text using the Gemini API.""" """Generate text using the Gemini API."""
llm_model = kwargs.pop("llm_model", self.model_name) kwargs.pop("llm_model")
try: try:
response = self.client.generate_content(prompt, **kwargs) response = self.client.generate_content(prompt, **kwargs)
except Exception as e: except Exception as e:
+27 -10
View File
@@ -1,22 +1,29 @@
from functools import cached_property from functools import cached_property
from typing import Type, TypeVar from typing import TYPE_CHECKING, Type, TypeVar
import groq
import instructor import instructor
from pydantic import BaseModel from pydantic import BaseModel
from ..logging import logger
from ..settings import settings from ..settings import settings
from ._base import BaseProvider from ._base import BaseProvider
PROVIDER_NAME = "groq" if TYPE_CHECKING:
DEFAULT_MODEL = "llama3-8b-8192" from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class Groq(BaseProvider): class Groq(BaseProvider):
NAME = PROVIDER_NAME NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None): def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -26,6 +33,12 @@ class Groq(BaseProvider):
"""The raw Groq client.""" """The raw Groq client."""
if not self.api_key: if not self.api_key:
raise ValueError("Groq API key is required") raise ValueError("Groq API key is required")
try:
import groq
except ImportError as exc:
raise ImportError(
"Please install the `groq` package: `pip install groq`"
) from exc
return groq.Groq(api_key=self.api_key) return groq.Groq(api_key=self.api_key)
@cached_property @cached_property
@@ -33,6 +46,7 @@ class Groq(BaseProvider):
"""A client patched with Instructor.""" """A client patched with Instructor."""
return instructor.from_groq(self.client) return instructor.from_groq(self.client)
@logger
def send_conversation( def send_conversation(
self, self,
conversation: "Conversation", conversation: "Conversation",
@@ -48,7 +62,7 @@ class Groq(BaseProvider):
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=conversation.llm_model or self.DEFAULT_MODEL, model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages, messages=messages,
**kwargs, **{**self.DEFAULT_KWARGS, **kwargs},
) )
# Get the response content from the Groq response # Get the response content from the Groq response
@@ -63,6 +77,7 @@ class Groq(BaseProvider):
llm_provider=PROVIDER_NAME, llm_provider=PROVIDER_NAME,
) )
@logger
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
# Ensure messages are provided in kwargs # Ensure messages are provided in kwargs
messages = [ messages = [
@@ -72,17 +87,19 @@ class Groq(BaseProvider):
response = self.structured_client.chat.completions.create( response = self.structured_client.chat.completions.create(
messages=messages, messages=messages,
response_model=response_model, response_model=response_model,
**kwargs, model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
**{**self.DEFAULT_KWARGS, **kwargs},
) )
return response return response_model.model_validate(response)
@logger
def generate_text( def generate_text(
self, self,
prompt: str, prompt: str,
*, *,
llm_model: str, llm_model: str,
**kwargs, **kwargs,
): ) -> str:
messages = [ messages = [
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
] ]
@@ -90,7 +107,7 @@ class Groq(BaseProvider):
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
messages=messages, messages=messages,
model=llm_model or self.DEFAULT_MODEL, model=llm_model or self.DEFAULT_MODEL,
**kwargs, **{**self.DEFAULT_KWARGS, **kwargs},
) )
return response.choices[0].message.content return str(response.choices[0].message.content)
+34 -9
View File
@@ -1,25 +1,30 @@
from functools import cached_property from functools import cached_property
from typing import Type, TypeVar from typing import TYPE_CHECKING, Type, TypeVar
import instructor import instructor
import ollama as ol
from openai import OpenAI from openai import OpenAI
from pydantic import BaseModel from pydantic import BaseModel
from ..logging import logger
from ..settings import settings from ..settings import settings
from ._base import BaseProvider from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "ollama" PROVIDER_NAME = "ollama"
DEFAULT_MODEL = "llama3.2" DEFAULT_MODEL = "llama3.2"
DEFAULT_TIMEOUT = 60 DEFAULT_TIMEOUT = 60
DEFAULT_KWARGS = {}
class Ollama(BaseProvider): class Ollama(BaseProvider):
NAME = PROVIDER_NAME NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
TIMEOUT = DEFAULT_TIMEOUT TIMEOUT = DEFAULT_TIMEOUT
def __init__(self, host_url: str | None = None): def __init__(self, host_url: str | None = None):
@@ -30,6 +35,12 @@ class Ollama(BaseProvider):
"""The raw Ollama client.""" """The raw Ollama client."""
if not self.host_url: if not self.host_url:
raise ValueError("No ollama host url provided") raise ValueError("No ollama host url provided")
try:
import ollama as ol
except ImportError as exc:
raise ImportError(
"Please install the `ollama` package: `pip install ollama`"
) from exc
return ol.Client(timeout=self.TIMEOUT, host=self.host_url) return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
@cached_property @cached_property
@@ -43,7 +54,8 @@ class Ollama(BaseProvider):
mode=instructor.Mode.JSON, mode=instructor.Mode.JSON,
) )
def send_conversation(self, conversation: "Conversation") -> "Message": @logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the Ollama API.""" """Send a conversation to the Ollama API."""
from ..models import Message from ..models import Message
@@ -51,7 +63,9 @@ class Ollama(BaseProvider):
{"role": msg.role, "content": msg.text} for msg in conversation.messages {"role": msg.role, "content": msg.text} for msg in conversation.messages
] ]
response = self.client.chat( response = self.client.chat(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
) )
assistant_message = response.get("message") assistant_message = response.get("message")
@@ -64,8 +78,14 @@ class Ollama(BaseProvider):
llm_provider=PROVIDER_NAME, llm_provider=PROVIDER_NAME,
) )
@logger
def structured_response( def structured_response(
self, prompt: str, response_model: Type[T], *, llm_model: str, **kwargs self,
prompt: str,
response_model: Type[T],
*,
llm_model: str | None = None,
**kwargs,
) -> T: ) -> T:
"""Get a structured response from the Ollama API.""" """Get a structured response from the Ollama API."""
messages = [ messages = [
@@ -76,18 +96,23 @@ class Ollama(BaseProvider):
messages=messages, messages=messages,
model=llm_model or self.DEFAULT_MODEL, model=llm_model or self.DEFAULT_MODEL,
response_model=response_model, response_model=response_model,
**kwargs, **{**self.DEFAULT_KWARGS, **kwargs},
) )
return response return response_model.model_validate(response)
def generate_text(self, prompt: str, *, llm_model: str) -> str: @logger
def generate_text(
self, prompt: str, *, llm_model: str | None = None, **kwargs
) -> str:
"""Generate text using the Ollama API.""" """Generate text using the Ollama API."""
messages = [ messages = [
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
] ]
response = self.client.chat( response = self.client.chat(
messages=messages, model=llm_model or self.DEFAULT_MODEL messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs},
) )
return response.get("message", {}).get("content", "") return response.get("message", {}).get("content", "")
+26 -7
View File
@@ -1,22 +1,28 @@
from functools import cached_property from functools import cached_property
from typing import Type, TypeVar from typing import TYPE_CHECKING, Type, TypeVar
import instructor import instructor
import openai as oa
from pydantic import BaseModel from pydantic import BaseModel
from ..logging import logger
from ..settings import settings from ..settings import settings
from ._base import BaseProvider from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "openai" PROVIDER_NAME = "openai"
DEFAULT_MODEL = "gpt-4o-mini" DEFAULT_MODEL = "gpt-4o-mini"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class OpenAI(BaseProvider): class OpenAI(BaseProvider):
NAME = PROVIDER_NAME NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None): def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -26,6 +32,12 @@ class OpenAI(BaseProvider):
"""The raw OpenAI client.""" """The raw OpenAI client."""
if not self.api_key: if not self.api_key:
raise ValueError("OpenAI API key is required") raise ValueError("OpenAI API key is required")
try:
import openai as oa
except ImportError as exc:
raise ImportError(
"Please install the `openai` package: `pip install openai`"
) from exc
return oa.OpenAI(api_key=self.api_key) return oa.OpenAI(api_key=self.api_key)
@cached_property @cached_property
@@ -33,7 +45,8 @@ class OpenAI(BaseProvider):
"""A OpenAI client with Instructor.""" """A OpenAI client with Instructor."""
return instructor.from_openai(self.client) return instructor.from_openai(self.client)
def send_conversation(self, conversation: "Conversation", **kwargs): @logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the OpenAI API.""" """Send a conversation to the OpenAI API."""
from ..models import Message from ..models import Message
@@ -42,7 +55,9 @@ class OpenAI(BaseProvider):
] ]
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
) )
# Get the response content from the OpenAI response # Get the response content from the OpenAI response
@@ -57,6 +72,7 @@ class OpenAI(BaseProvider):
llm_provider=PROVIDER_NAME, llm_provider=PROVIDER_NAME,
) )
@logger
def structured_response( def structured_response(
self, self,
prompt: str, prompt: str,
@@ -74,16 +90,19 @@ class OpenAI(BaseProvider):
messages=messages, messages=messages,
model=llm_model or self.DEFAULT_MODEL, model=llm_model or self.DEFAULT_MODEL,
response_model=response_model, response_model=response_model,
**kwargs, **{**self.DEFAULT_KWARGS, **kwargs},
) )
return response return response_model.model_validate(response)
@logger
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs): def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
"""Generate text using the OpenAI API.""" """Generate text using the OpenAI API."""
messages = [ messages = [
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
] ]
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs},
) )
return response.choices[0].message.content return response.choices[0].message.content
+28 -7
View File
@@ -1,20 +1,30 @@
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
import instructor import instructor
import openai as oa from pydantic import BaseModel
from ..logging import logger
from ..settings import settings from ..settings import settings
from ._base import BaseProvider from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "xai" PROVIDER_NAME = "xai"
DEFAULT_MODEL = "grok-beta" DEFAULT_MODEL = "grok-beta"
BASE_URL = "https://api.x.ai/v1" BASE_URL = "https://api.x.ai/v1"
DEFAULT_MAX_TOKENS = 1000 DEFAULT_MAX_TOKENS = 1000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class XAI(BaseProvider): class XAI(BaseProvider):
NAME = PROVIDER_NAME NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None): def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -24,6 +34,12 @@ class XAI(BaseProvider):
"""The raw OpenAI client.""" """The raw OpenAI client."""
if not self.api_key: if not self.api_key:
raise ValueError("XAI API key is required") raise ValueError("XAI API key is required")
try:
import openai as oa
except ImportError as exc:
raise ImportError(
"Please install the `openai` package: `pip install openai`"
) from exc
return oa.OpenAI( return oa.OpenAI(
api_key=self.api_key, api_key=self.api_key,
base_url=BASE_URL, base_url=BASE_URL,
@@ -34,7 +50,8 @@ class XAI(BaseProvider):
"""A client patched with Instructor.""" """A client patched with Instructor."""
return instructor.from_openai(self.client) return instructor.from_openai(self.client)
def send_conversation(self, conversation: "Conversation", **kwargs): @logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the OpenAI API.""" """Send a conversation to the OpenAI API."""
from ..models import Message from ..models import Message
@@ -45,7 +62,7 @@ class XAI(BaseProvider):
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=conversation.llm_model or self.DEFAULT_MODEL, model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages, messages=messages,
**kwargs, **{**self.DEFAULT_KWARGS, **kwargs},
) )
# Get the response content from the OpenAI response # Get the response content from the OpenAI response
@@ -60,10 +77,14 @@ class XAI(BaseProvider):
llm_provider=PROVIDER_NAME, llm_provider=PROVIDER_NAME,
) )
def structured_response(self, prompt: str, response_model, *, llm_model: str): @logger
def structured_response(
self, prompt: str, response_model: Type[T], *, llm_model: str
) -> T:
raise NotImplementedError("XAI does not support structured responses") raise NotImplementedError("XAI does not support structured responses")
def generate_text(self, prompt: str, *, llm_model: str, **kwargs): @logger
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
messages = [ messages = [
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
] ]
@@ -71,7 +92,7 @@ class XAI(BaseProvider):
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
messages=messages, messages=messages,
model=llm_model or self.DEFAULT_MODEL, model=llm_model or self.DEFAULT_MODEL,
**kwargs, **{**self.DEFAULT_KWARGS, **kwargs},
) )
return response.choices[0].message.content return str(response.choices[0].message.content)
+27 -5
View File
@@ -1,19 +1,42 @@
from typing import Literal, Optional, Union from typing import Optional, Union
from pydantic import Field, SecretStr, field_validator from pydantic import Field, SecretStr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
logging_level = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
class LoggingConfig(BaseSettings): class LoggingConfig(BaseSettings):
"""The class that holds all the logging settings for the application.""" """The class that holds all the logging settings for the application."""
enabled: bool = Field(False, description="Enable logging") enabled: bool = Field(False, description="Enable logging")
level: logging_level = Field("INFO", description="The logging level")
model_config = SettingsConfigDict(extra="forbid") model_config = SettingsConfigDict(extra="forbid")
def enable_logfire(self, **kwargs) -> None:
"""Enable logging for the application."""
# adding imports here to avoid forced dependencies
try:
import logfire
from logging import basicConfig
except ImportError as e:
raise ImportError(
"To enable logging, please install logfire: `pip install logfire`"
) from e
self.enabled = True
logfire.configure(**kwargs)
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
try:
logfire.configure(**kwargs)
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
except Exception as e:
self.enabled = False # Reset flag on failure
raise RuntimeError("Failed to configure logging") from e
def disable_logfire(self) -> None:
"""Disable logging for the application."""
self.enabled = False
class Settings(BaseSettings): class Settings(BaseSettings):
"""The class that holds all the API keys for the application.""" """The class that holds all the API keys for the application."""
@@ -29,7 +52,6 @@ class Settings(BaseSettings):
) )
XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI") XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI")
DEFAULT_LLM_PROVIDER: str = Field("openai", description="The default LLM provider") DEFAULT_LLM_PROVIDER: str = Field("openai", description="The default LLM provider")
DEFAULT_LLM_MODEL: str = Field("gpt-4o-mini", description="The default LLM model")
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore" env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
+15
View File
@@ -0,0 +1,15 @@
import os
import sys
import pytest
# Add the project root to the Python path.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from simplemind import Session
@pytest.fixture
def sm():
"""Fixture that provides a simplemind Session instance with default settings."""
return Session()
+2
View File
@@ -0,0 +1,2 @@
def test_basic_math():
assert 1 + 1 == 2
+28
View File
@@ -0,0 +1,28 @@
from pydantic import BaseModel
import pytest
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
class ResponseModel(BaseModel):
result: int
@pytest.mark.parametrize(
"provider_cls",
[
Anthropic,
Gemini,
OpenAI,
Groq,
Ollama,
],
)
def test_generate_data(provider_cls):
provider = provider_cls()
prompt = "What is 2+2?"
data = provider.structured_response(prompt=prompt, response_model=ResponseModel)
assert isinstance(data, ResponseModel)
assert isinstance(data.result, int)
+22
View File
@@ -0,0 +1,22 @@
import pytest
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
@pytest.mark.parametrize(
"provider_cls",
[
Anthropic,
Gemini,
OpenAI,
Groq,
Ollama,
],
)
def test_generate_text(provider_cls):
provider = provider_cls()
prompt = "What is 2+2?"
response = provider.generate_text(prompt=prompt, llm_model=provider.DEFAULT_MODEL)
assert isinstance(response, str)
assert len(response) > 0