1 Commits

23 changed files with 75 additions and 501 deletions
-1
View File
@@ -4,4 +4,3 @@ export GROQ_API_KEY=""
export OLLAMA_HOST_URL=""
export OPENAI_API_KEY=""
export XAI_API_KEY=""
export AMAZON_PROFILE_NAME=""
-1
View File
@@ -167,4 +167,3 @@ cython_debug/
src/**
requirements.txt
Pipfile
-21
View File
@@ -1,31 +1,10 @@
Release History
===============
## 0.2.0 (2024-11-01)
- Add Amazon Bedrock provider.
- Make all provider optional dependencies. Use `$ pip install 'simplemind[full]'` to install all providers.
- General improvements.
## 0.1.7 (2024-11-01)
- Add `logger` decorator.
- Add `sm.enable_logfire()` function.
- General improvements.
## 0.1.6 (2024-10-31)
- Add `sm.Plugin` syntax sugar.
- Improvements to Anthropic provider, related to max tokens.
- General improvements.
- Add tests for structured response.
- Add `llm_model` to `structured_response`.
## 0.1.5 (2024-10-31)
- Add Gemini provider.
- Add structured response to Gemini provider.
- Support for Python 3.10.
## 0.1.4 (2024-10-30)
+8 -44
View File
@@ -6,6 +6,8 @@ Simplemind is AI library designed to simplify your experience with AI APIs in Py
![simplemind](https://github.com/user-attachments/assets/36df2103-2583-4958-ad5e-19cda7740256)
[![Auto Wiki](https://img.shields.io/badge/Auto_Wiki-Mutable.ai-blue)](https://mutable.ai/kennethreitz/simplemind)
## Features
With Simplemind, tapping into AI is as easy as a friendly conversation.
@@ -16,20 +18,18 @@ With Simplemind, tapping into AI is as easy as a friendly conversation.
## Supported APIs
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`. The APIs remain identital between all supported providers/models.
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`.
- [**Anthropic's Claude**](https://www.anthropic.com/claude)
- [**Amazon Bedrock**](https://aws.amazon.com/bedrock/)
- [**Google's Gemini**](https://gemini.google/)
- [**Groq's Groq**](https://groq.com/)
- [**Ollama**](https://ollama.com)
- [**OpenAI's GPT**](https://openai.com/gpt)
- [**xAI's Grok**](https://x.ai/)
If you want to see Simplemind support, additional providers or models, please send a pull request!
If you want to see Simplemind support, additional providers or models, please request a pull!
## Why SimpleMind?
- **Intuitive**: Built with Pythonic simplicity and readability in mind.
- **For Humans**: Emphasizes a human-friendly interface, just like `requests` for HTTP.
- **Open Source**: Simplemind is open source, and contributions are always welcome!
@@ -41,7 +41,7 @@ Also, why not? :)
Simplemind takes care of the complex API calls so you can focus on what matters—building, experimenting, and creating.
```bash
$ pip install 'simplemind[full]'
$ pip install simplemind
```
First, authenticate your API keys by setting them in the environment variables:
@@ -50,7 +50,7 @@ First, authenticate your API keys by setting them in the environment variables:
$ export OPENAI_API_KEY="sk-..."
```
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, and `GROQ_API_KEY`.
Next, import Simplemind and start using it:
@@ -58,6 +58,7 @@ Next, import Simplemind and start using it:
import simplemind as sm
```
## Examples
Here are some examples of how to use Simplemind:
@@ -91,33 +92,6 @@ class Poem(BaseModel):
title='Eternal Embrace' content='In the quiet hours of the night,\nWhen stars whisper secrets bright,\nTwo hearts beat in a gentle rhyme,\nDancing through the sands of time.\n\nWith every glance, a spark ignites,\nA flame that warms the coldest nights,\nIn laughter shared and whispers sweet,\nLove paints the world, a masterpiece.\n\nThrough stormy skies and sunlit days,\nIn myriad forms, it finds its ways,\nA tender touch, a knowing sigh,\nIn loves embrace, we learn to fly.\n\nAs seasons change and moments fade,\nIn the tapestry of dreams weve laid,\nLoves threads endure, forever bind,\nA timeless bond, two souls aligned.\n\nSo heres to love, both bright and true,\nA gift we give, anew, anew,\nIn every heartbeat, every prayer,\nA story written in the air.'
```
#### A more complex example
```python
class InstructionStep(BaseModel):
step_number: int
instruction: str
class RecipeIngredient(BaseModel):
name: str
quantity: float
unit: str
class Recipe(BaseModel):
name: str
ingredients: list[RecipeIngredient]
instructions: list[InstructionStep]
recipe = sm.generate_data(
"Write a recipe for chocolate chip cookies",
llm_model="gpt-4o-mini",
llm_provider="openai",
response_model=Recipe,
)
```
Special thanks to [@jxnl](https://github.com/jxnl) for building [Instructor](https://github.com/jxnl/instructor), which makes this possible!
### Conversational AI
SimpleMind also allows for easy conversational flows:
@@ -191,7 +165,6 @@ conversation.add_message(
text="Please write a poem about the moon",
)
```
```pycon
>>> conversation.send()
In the vast expanse where stars do play,
@@ -227,18 +200,11 @@ The universe is never done.
Simple, yet effective.
### Logging
Simplemind uses [logfire](https://logfire.ai) for logging. To enable logging, call `sm.enable_logfire()`.
### More Examples
Please see the [examples](examples) directory for executable examples.
---
-------------------
## Contributing
We welcome contributions of all kinds. Feel free to open issues for bug reports or feature requests, and submit pull requests to make SimpleMind even better.
To get started:
@@ -249,9 +215,7 @@ To get started:
4. Submit a pull request.
## License
Simplemind is licensed under the Apache 2.0 License.
## Acknowledgements
Simplemind is inspired by the philosophy of "code for humans" and aims to make working with AI models accessible to all. Special thanks to the open-source community for their contributions and inspiration.
+1 -1
View File
@@ -16,7 +16,7 @@ import simplemind
project = "simplemind"
copyright = "2024 Kenneth Reitz"
author = "Kenneth Reitz"
release = "v0.2.0"
release = "v0.1.5"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+2 -13
View File
@@ -1,21 +1,10 @@
[project]
name = "simplemind"
version = "0.2.0"
version = "0.1.5"
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
readme = "README.md"
requires-python = ">=3.10"
dependencies = ["pydantic", "pydantic-settings", "instructor", "logfire"]
[project.optional-dependencies]
full = [
"openai",
"anthropic",
"ollama",
"groq",
"google-generativeai",
"botocore",
"boto3"
]
dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "ollama", "groq", "google-generativeai"]
[build-system]
requires = ["hatchling"]
+1 -11
View File
@@ -16,7 +16,7 @@ class Session:
self,
*,
llm_provider: str = settings.DEFAULT_LLM_PROVIDER,
llm_model: str | None = None,
llm_model: str = settings.DEFAULT_LLM_MODEL,
**kwargs,
):
self.llm_provider = llm_provider
@@ -113,14 +113,6 @@ def generate_text(
return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs)
def enable_logfire() -> None:
"""Enable logfire logging."""
settings.logging.enable_logfire()
# Syntax sugar.
Plugin = BasePlugin
__all__ = [
"create_conversation",
"find_provider",
@@ -129,6 +121,4 @@ __all__ = [
"settings",
"BasePlugin",
"Session",
"Plugin",
"enable_logfire",
]
-33
View File
@@ -1,33 +0,0 @@
import time
from typing import Any, Callable
import logfire
from .settings import settings
def logger(func: Callable[..., Any]) -> Callable[..., Any]:
"""A decorator that logs the function parameters, function returns,
and exceptions raised if logging is enabled, using logfire.
"""
def wrapper(*args, **kwargs) -> Any:
if not settings.logging.is_enabled:
return func(*args, **kwargs)
logfire.info(f"Calling {func.__name__} with args: {args}, kwargs: {kwargs}")
t1 = time.perf_counter()
try:
result = func(*args, **kwargs)
t2 = time.perf_counter()
logfire.info(f"{func.__name__} returned: {result} in {t2-t1} seconds")
return result
except Exception as e:
t2 = time.perf_counter()
logfire.error(f"Error in {func.__name__}: {e} in {t2-t1} seconds")
raise e
return wrapper
+1 -1
View File
@@ -1,6 +1,6 @@
from types import TracebackType
import uuid
from datetime import datetime
from types import TracebackType
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
+1 -2
View File
@@ -7,6 +7,5 @@ from .groq import Groq
from .ollama import Ollama
from .openai import OpenAI
from .xai import XAI
from .amazon import Amazon
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI, Amazon]
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI]
+1 -4
View File
@@ -1,13 +1,10 @@
from abc import ABC, abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Any, Type, TypeVar
from typing import Any, Type, TypeVar
from instructor import Instructor
from pydantic import BaseModel
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
-90
View File
@@ -1,90 +0,0 @@
from typing import Type, TypeVar
import instructor
import anthropic
from pydantic import BaseModel
from ._base import BaseProvider
from ..settings import settings
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "amazon"
DEFAULT_MODEL = "anthropic.claude-3-sonnet-20240229-v1:0"
DEFAULT_MAX_TOKENS = 5_000
class Amazon(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, profile_name: str | None = None):
self.profile_name = profile_name or settings.AMAZON_PROFILE_NAME
@property
def client(self):
"""The AnthropicBedrock client."""
if not self.profile_name:
raise ValueError("Profile name is not provided")
return anthropic.AnthropicBedrock(aws_profile=self.profile_name)
@property
def structured_client(self):
"""A client patched with Instructor."""
return instructor.from_anthropic(self.client)
def send_conversation(self, conversation: "Conversation", **kwargs):
"""Send a conversation to the OpenAI API."""
from ..models import Message
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
response = self.client.chat.completions.create(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
)
# Get the response content from the OpenAI response
assistant_message = response.choices[0].message
# Create and return a properly formatted Message instance
return Message(
role="assistant",
text=assistant_message.content or "",
raw=response,
llm_model=conversation.llm_model or DEFAULT_MODEL,
llm_provider=PROVIDER_NAME,
)
def structured_response(
self, prompt, response_model: Type[T], *, llm_model: str | None = None, **kwargs
) -> T:
# Ensure messages are provided in kwargs
messages = [
{"role": "user", "content": prompt},
]
response = self.structured_client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
response_model=response_model,
max_tokens=DEFAULT_MAX_TOKENS,
**kwargs,
)
return response
def generate_text(self, prompt, *, llm_model, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
response = self.client.messages.create(
model=llm_model or self.DEFAULT_MODEL,
messages=messages,
max_tokens=DEFAULT_MAX_TOKENS,
**kwargs,
)
return response.content[0].text
+11 -37
View File
@@ -1,29 +1,24 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
from typing import Type, TypeVar
import anthropic
import instructor
from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "anthropic"
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
DEFAULT_MAX_TOKENS = 1000
class Anthropic(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -33,13 +28,6 @@ class Anthropic(BaseProvider):
"""The raw Anthropic client."""
if not self.api_key:
raise ValueError("Anthropic API key is required")
try:
import anthropic
except ImportError as exc:
raise ImportError(
"Please install the `anthropic` package: `pip install anthropic`"
) from exc
return anthropic.Anthropic(api_key=self.api_key)
@cached_property
@@ -47,8 +35,7 @@ class Anthropic(BaseProvider):
"""A client patched with Instructor."""
return instructor.from_anthropic(self.client)
@logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
def send_conversation(self, conversation: "Conversation", **kwargs):
"""Send a conversation to the Anthropic API."""
from ..models import Message
@@ -59,7 +46,8 @@ class Anthropic(BaseProvider):
response = self.client.messages.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
max_tokens=DEFAULT_MAX_TOKENS,
**kwargs,
)
# Get the response content from the Anthropic response
@@ -74,27 +62,12 @@ class Anthropic(BaseProvider):
llm_provider=PROVIDER_NAME,
)
@logger
def structured_response(
self, response_model: Type[T], *, llm_model: str | None = None, **kwargs
) -> T:
model = llm_model or self.DEFAULT_MODEL
# Extract the prompt from kwargs if it exists
prompt = kwargs.pop("prompt", kwargs.pop("messages", ""))
# Format the messages properly
messages = [{"role": "user", "content": prompt}]
def structured_response(self, model: str, response_model: Type[T], **kwargs) -> T:
response = self.structured_client.messages.create(
model=model,
messages=messages, # Add the messages parameter
response_model=response_model,
**{**self.DEFAULT_KWARGS, **kwargs},
model=model or self.DEFAULT_MODEL, response_model=response_model, **kwargs
)
return response_model.model_validate(response)
return response
@logger
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
messages = [
{"role": "user", "content": prompt},
@@ -103,7 +76,8 @@ class Anthropic(BaseProvider):
response = self.client.messages.create(
model=llm_model or self.DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
max_tokens=DEFAULT_MAX_TOKENS,
**kwargs,
)
return response.content[0].text
+10 -26
View File
@@ -2,24 +2,20 @@
# IT is not currently working as desired.
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
from typing import Type, TypeVar
import google.generativeai as genai
import instructor
from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "gemini"
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
T = TypeVar("T", bound=BaseModel)
class Gemini(BaseProvider):
NAME = PROVIDER_NAME
@@ -29,21 +25,12 @@ class Gemini(BaseProvider):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
self.model_name = DEFAULT_MODEL
def set_model(self, model_name: str):
self.model_name = model_name
@cached_property
def client(self):
def client(self, model_name: str = DEFAULT_MODEL):
"""The raw Gemini client."""
if not self.api_key:
raise ValueError("Gemini API key is required")
try:
import google.generativeai as genai
except ImportError as exc:
raise ImportError(
"Please install the `google-generativeai` package: `pip install google-generativeai`"
) from exc
genai.configure(api_key=self.api_key)
self.model_name = model_name
return genai.GenerativeModel(model_name=self.model_name)
@cached_property
@@ -51,7 +38,6 @@ class Gemini(BaseProvider):
"""A Gemini client patched with Instructor."""
return instructor.from_gemini(self.client)
@logger
def send_conversation(self, conversation: "Conversation") -> "Message":
"""Send a conversation to the Gemini API."""
from ..models import Message
@@ -78,11 +64,9 @@ class Gemini(BaseProvider):
llm_provider=PROVIDER_NAME,
)
@logger
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
"""Send a structured response to the Gemini API."""
# Only try to pop if the key exists
kwargs.pop("llm_model", None) # Add default value of None
llm_model = kwargs.pop("llm_model", self.model_name)
try:
response = self.structured_client.chat.completions.create(
@@ -95,12 +79,12 @@ class Gemini(BaseProvider):
raise RuntimeError(
f"Failed to send structured response to Gemini API: {e}"
) from e
return response_model.model_validate(response)
return response
@logger
def generate_text(self, prompt: str, **kwargs) -> str:
"""Generate text using the Gemini API."""
kwargs.pop("llm_model")
llm_model = kwargs.pop("llm_model", self.model_name)
try:
response = self.client.generate_content(prompt, **kwargs)
except Exception as e:
+10 -27
View File
@@ -1,29 +1,22 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
from typing import Type, TypeVar
import groq
import instructor
from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
T = TypeVar("T", bound=BaseModel)
class Groq(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -33,12 +26,6 @@ class Groq(BaseProvider):
"""The raw Groq client."""
if not self.api_key:
raise ValueError("Groq API key is required")
try:
import groq
except ImportError as exc:
raise ImportError(
"Please install the `groq` package: `pip install groq`"
) from exc
return groq.Groq(api_key=self.api_key)
@cached_property
@@ -46,7 +33,6 @@ class Groq(BaseProvider):
"""A client patched with Instructor."""
return instructor.from_groq(self.client)
@logger
def send_conversation(
self,
conversation: "Conversation",
@@ -62,7 +48,7 @@ class Groq(BaseProvider):
response = self.client.chat.completions.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
**kwargs,
)
# Get the response content from the Groq response
@@ -77,7 +63,6 @@ class Groq(BaseProvider):
llm_provider=PROVIDER_NAME,
)
@logger
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
# Ensure messages are provided in kwargs
messages = [
@@ -87,19 +72,17 @@ class Groq(BaseProvider):
response = self.structured_client.chat.completions.create(
messages=messages,
response_model=response_model,
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
**{**self.DEFAULT_KWARGS, **kwargs},
**kwargs,
)
return response_model.model_validate(response)
return response
@logger
def generate_text(
self,
prompt: str,
*,
llm_model: str,
**kwargs,
) -> str:
):
messages = [
{"role": "user", "content": prompt},
]
@@ -107,7 +90,7 @@ class Groq(BaseProvider):
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs},
**kwargs,
)
return str(response.choices[0].message.content)
return response.choices[0].message.content
+9 -34
View File
@@ -1,30 +1,25 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
from typing import Type, TypeVar
import instructor
import ollama as ol
from openai import OpenAI
from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "ollama"
DEFAULT_MODEL = "llama3.2"
DEFAULT_TIMEOUT = 60
DEFAULT_KWARGS = {}
class Ollama(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
TIMEOUT = DEFAULT_TIMEOUT
def __init__(self, host_url: str | None = None):
@@ -35,12 +30,6 @@ class Ollama(BaseProvider):
"""The raw Ollama client."""
if not self.host_url:
raise ValueError("No ollama host url provided")
try:
import ollama as ol
except ImportError as exc:
raise ImportError(
"Please install the `ollama` package: `pip install ollama`"
) from exc
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
@cached_property
@@ -54,8 +43,7 @@ class Ollama(BaseProvider):
mode=instructor.Mode.JSON,
)
@logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
def send_conversation(self, conversation: "Conversation") -> "Message":
"""Send a conversation to the Ollama API."""
from ..models import Message
@@ -63,9 +51,7 @@ class Ollama(BaseProvider):
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
response = self.client.chat(
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
)
assistant_message = response.get("message")
@@ -78,14 +64,8 @@ class Ollama(BaseProvider):
llm_provider=PROVIDER_NAME,
)
@logger
def structured_response(
self,
prompt: str,
response_model: Type[T],
*,
llm_model: str | None = None,
**kwargs,
self, prompt: str, response_model: Type[T], *, llm_model: str, **kwargs
) -> T:
"""Get a structured response from the Ollama API."""
messages = [
@@ -96,23 +76,18 @@ class Ollama(BaseProvider):
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
response_model=response_model,
**{**self.DEFAULT_KWARGS, **kwargs},
**kwargs,
)
return response_model.model_validate(response)
return response
@logger
def generate_text(
self, prompt: str, *, llm_model: str | None = None, **kwargs
) -> str:
def generate_text(self, prompt: str, *, llm_model: str) -> str:
"""Generate text using the Ollama API."""
messages = [
{"role": "user", "content": prompt},
]
response = self.client.chat(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs},
messages=messages, model=llm_model or self.DEFAULT_MODEL
)
return response.get("message", {}).get("content", "")
+7 -26
View File
@@ -1,28 +1,22 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
from typing import Type, TypeVar
import instructor
import openai as oa
from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "openai"
DEFAULT_MODEL = "gpt-4o-mini"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class OpenAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -32,12 +26,6 @@ class OpenAI(BaseProvider):
"""The raw OpenAI client."""
if not self.api_key:
raise ValueError("OpenAI API key is required")
try:
import openai as oa
except ImportError as exc:
raise ImportError(
"Please install the `openai` package: `pip install openai`"
) from exc
return oa.OpenAI(api_key=self.api_key)
@cached_property
@@ -45,8 +33,7 @@ class OpenAI(BaseProvider):
"""A OpenAI client with Instructor."""
return instructor.from_openai(self.client)
@logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
def send_conversation(self, conversation: "Conversation", **kwargs):
"""Send a conversation to the OpenAI API."""
from ..models import Message
@@ -55,9 +42,7 @@ class OpenAI(BaseProvider):
]
response = self.client.chat.completions.create(
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
)
# Get the response content from the OpenAI response
@@ -72,7 +57,6 @@ class OpenAI(BaseProvider):
llm_provider=PROVIDER_NAME,
)
@logger
def structured_response(
self,
prompt: str,
@@ -90,19 +74,16 @@ class OpenAI(BaseProvider):
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
response_model=response_model,
**{**self.DEFAULT_KWARGS, **kwargs},
**kwargs,
)
return response_model.model_validate(response)
return response
@logger
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
"""Generate text using the OpenAI API."""
messages = [
{"role": "user", "content": prompt},
]
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs},
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
)
return response.choices[0].message.content
+7 -28
View File
@@ -1,30 +1,20 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
import instructor
from pydantic import BaseModel
import openai as oa
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "xai"
DEFAULT_MODEL = "grok-beta"
BASE_URL = "https://api.x.ai/v1"
DEFAULT_MAX_TOKENS = 1000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class XAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -34,12 +24,6 @@ class XAI(BaseProvider):
"""The raw OpenAI client."""
if not self.api_key:
raise ValueError("XAI API key is required")
try:
import openai as oa
except ImportError as exc:
raise ImportError(
"Please install the `openai` package: `pip install openai`"
) from exc
return oa.OpenAI(
api_key=self.api_key,
base_url=BASE_URL,
@@ -50,8 +34,7 @@ class XAI(BaseProvider):
"""A client patched with Instructor."""
return instructor.from_openai(self.client)
@logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
def send_conversation(self, conversation: "Conversation", **kwargs):
"""Send a conversation to the OpenAI API."""
from ..models import Message
@@ -62,7 +45,7 @@ class XAI(BaseProvider):
response = self.client.chat.completions.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
**kwargs,
)
# Get the response content from the OpenAI response
@@ -77,14 +60,10 @@ class XAI(BaseProvider):
llm_provider=PROVIDER_NAME,
)
@logger
def structured_response(
self, prompt: str, response_model: Type[T], *, llm_model: str
) -> T:
def structured_response(self, prompt: str, response_model, *, llm_model: str):
raise NotImplementedError("XAI does not support structured responses")
@logger
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
@@ -92,7 +71,7 @@ class XAI(BaseProvider):
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs},
**kwargs,
)
return str(response.choices[0].message.content)
return response.choices[0].message.content
+6 -29
View File
@@ -1,47 +1,23 @@
from typing import Optional, Union
from typing import Literal, Optional, Union
from pydantic import Field, SecretStr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
logging_level = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
class LoggingConfig(BaseSettings):
"""The class that holds all the logging settings for the application."""
is_enabled: bool = Field(False, description="Enable logging")
enabled: bool = Field(False, description="Enable logging")
level: logging_level = Field("INFO", description="The logging level")
model_config = SettingsConfigDict(extra="forbid")
def enable_logfire(self, **kwargs) -> None:
"""Enable logging for the application."""
# adding imports here to avoid forced dependencies
try:
import logfire
from logging import basicConfig
except ImportError as e:
raise ImportError(
"To enable logging, please install logfire: `pip install logfire`"
) from e
self.is_enabled = True
logfire.configure(**kwargs)
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
try:
logfire.configure(**kwargs)
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
except Exception as e:
self.is_enabled = False # Reset flag on failure
raise RuntimeError("Failed to configure logging") from e
def disable_logfire(self) -> None:
"""Disable logging for the application."""
self.is_enabled = False
class Settings(BaseSettings):
"""The class that holds all the API keys for the application."""
AMAZON_PROFILE_NAME: Optional[str] = Field("default", description="AWS Named Profile")
ANTHROPIC_API_KEY: Optional[SecretStr] = Field(
None, description="API key for Anthropic"
)
@@ -53,6 +29,7 @@ class Settings(BaseSettings):
)
XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI")
DEFAULT_LLM_PROVIDER: str = Field("openai", description="The default LLM provider")
DEFAULT_LLM_MODEL: str = Field("gpt-4o-mini", description="The default LLM model")
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
-15
View File
@@ -1,15 +0,0 @@
import os
import sys
import pytest
# Add the project root to the Python path.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from simplemind import Session
@pytest.fixture
def sm():
"""Fixture that provides a simplemind Session instance with default settings."""
return Session()
-2
View File
@@ -1,2 +0,0 @@
def test_basic_math():
assert 1 + 1 == 2
-31
View File
@@ -1,31 +0,0 @@
import pytest
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
from pydantic import BaseModel
class ResponseModel(BaseModel):
result: int
@pytest.mark.parametrize(
"provider_cls",
[
Anthropic,
Gemini,
OpenAI,
Groq,
Ollama,
Amazon
],
)
def test_generate_data(provider_cls):
provider = provider_cls()
prompt = "What is 2+2?"
data = provider.structured_response(prompt=prompt, response_model=ResponseModel)
assert isinstance(data, ResponseModel)
assert isinstance(data.result, int)
-24
View File
@@ -1,24 +0,0 @@
import pytest
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
@pytest.mark.parametrize(
"provider_cls",
[
Anthropic,
Gemini,
OpenAI,
Groq,
Ollama,
Amazon,
],
)
def test_generate_text(provider_cls):
provider = provider_cls()
prompt = "What is 2+2?"
response = provider.generate_text(prompt=prompt, llm_model=provider.DEFAULT_MODEL)
assert isinstance(response, str)
assert len(response) > 0