mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
Compare commits
40 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 75a42044e5 | |||
| cc66dbf8e5 | |||
| a174e60a1e | |||
| b03695f626 | |||
| 082bc24e91 | |||
| aca1b87180 | |||
| 1ff4c5660e | |||
| 241a7ab402 | |||
| 76fa7521eb | |||
| cbec2c5f6d | |||
| 34f463839c | |||
| c648a922b4 | |||
| 873f5ba5f8 | |||
| 28a7b2f140 | |||
| 173162e798 | |||
| cd0be3ad89 | |||
| 3dd2e1b248 | |||
| ad1800840d | |||
| d62f297b68 | |||
| a2597709d2 | |||
| 1455b5ba13 | |||
| 0fb54d1987 | |||
| fe06331662 | |||
| 56b1e65d70 | |||
| 4b3e1bc6dd | |||
| f5b922ade8 | |||
| 3a7383425f | |||
| 92c10fc41e | |||
| 75c42278a2 | |||
| c25f1e1058 | |||
| 2a5966eb10 | |||
| f19263d309 | |||
| 25b742db1f | |||
| caceba381d | |||
| 0795464fd7 | |||
| 8d83050a64 | |||
| d82effdfb1 | |||
| e648292cb3 | |||
| 37a9333be3 | |||
| cbc3739411 |
@@ -4,3 +4,4 @@ export GROQ_API_KEY=""
|
||||
export OLLAMA_HOST_URL=""
|
||||
export OPENAI_API_KEY=""
|
||||
export XAI_API_KEY=""
|
||||
export AMAZON_PROFILE_NAME=""
|
||||
|
||||
@@ -1,6 +1,18 @@
|
||||
Release History
|
||||
===============
|
||||
|
||||
## 0.2.0 (2024-11-01)
|
||||
|
||||
- Add Amazon Bedrock provider.
|
||||
- Make all provider optional dependencies. Use `$ pip install 'simplemind[full]'` to install all providers.
|
||||
- General improvements.
|
||||
|
||||
## 0.1.7 (2024-11-01)
|
||||
|
||||
- Add `logger` decorator.
|
||||
- Add `sm.enable_logfire()` function.
|
||||
- General improvements.
|
||||
|
||||
## 0.1.6 (2024-10-31)
|
||||
|
||||
- Add `sm.Plugin` syntax sugar.
|
||||
|
||||
@@ -19,6 +19,7 @@ With Simplemind, tapping into AI is as easy as a friendly conversation.
|
||||
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`. The APIs remain identital between all supported providers/models.
|
||||
|
||||
- [**Anthropic's Claude**](https://www.anthropic.com/claude)
|
||||
- [**Amazon Bedrock**](https://aws.amazon.com/bedrock/)
|
||||
- [**Google's Gemini**](https://gemini.google/)
|
||||
- [**Groq's Groq**](https://groq.com/)
|
||||
- [**Ollama**](https://ollama.com)
|
||||
@@ -28,6 +29,7 @@ To specify a specific provider or model, you can use the `llm_provider` and `llm
|
||||
If you want to see Simplemind support, additional providers or models, please send a pull request!
|
||||
|
||||
## Why SimpleMind?
|
||||
|
||||
- **Intuitive**: Built with Pythonic simplicity and readability in mind.
|
||||
- **For Humans**: Emphasizes a human-friendly interface, just like `requests` for HTTP.
|
||||
- **Open Source**: Simplemind is open source, and contributions are always welcome!
|
||||
@@ -39,7 +41,7 @@ Also, why not? :)
|
||||
Simplemind takes care of the complex API calls so you can focus on what matters—building, experimenting, and creating.
|
||||
|
||||
```bash
|
||||
$ pip install simplemind
|
||||
$ pip install 'simplemind[full]'
|
||||
```
|
||||
|
||||
First, authenticate your API keys by setting them in the environment variables:
|
||||
@@ -56,7 +58,6 @@ Next, import Simplemind and start using it:
|
||||
import simplemind as sm
|
||||
```
|
||||
|
||||
|
||||
## Examples
|
||||
|
||||
Here are some examples of how to use Simplemind:
|
||||
@@ -90,6 +91,33 @@ class Poem(BaseModel):
|
||||
title='Eternal Embrace' content='In the quiet hours of the night,\nWhen stars whisper secrets bright,\nTwo hearts beat in a gentle rhyme,\nDancing through the sands of time.\n\nWith every glance, a spark ignites,\nA flame that warms the coldest nights,\nIn laughter shared and whispers sweet,\nLove paints the world, a masterpiece.\n\nThrough stormy skies and sunlit days,\nIn myriad forms, it finds its ways,\nA tender touch, a knowing sigh,\nIn love’s embrace, we learn to fly.\n\nAs seasons change and moments fade,\nIn the tapestry of dreams we’ve laid,\nLove’s threads endure, forever bind,\nA timeless bond, two souls aligned.\n\nSo here’s to love, both bright and true,\nA gift we give, anew, anew,\nIn every heartbeat, every prayer,\nA story written in the air.'
|
||||
```
|
||||
|
||||
#### A more complex example
|
||||
|
||||
```python
|
||||
class InstructionStep(BaseModel):
|
||||
step_number: int
|
||||
instruction: str
|
||||
|
||||
class RecipeIngredient(BaseModel):
|
||||
name: str
|
||||
quantity: float
|
||||
unit: str
|
||||
|
||||
class Recipe(BaseModel):
|
||||
name: str
|
||||
ingredients: list[RecipeIngredient]
|
||||
instructions: list[InstructionStep]
|
||||
|
||||
recipe = sm.generate_data(
|
||||
"Write a recipe for chocolate chip cookies",
|
||||
llm_model="gpt-4o-mini",
|
||||
llm_provider="openai",
|
||||
response_model=Recipe,
|
||||
)
|
||||
```
|
||||
|
||||
Special thanks to [@jxnl](https://github.com/jxnl) for building [Instructor](https://github.com/jxnl/instructor), which makes this possible!
|
||||
|
||||
### Conversational AI
|
||||
|
||||
SimpleMind also allows for easy conversational flows:
|
||||
@@ -163,6 +191,7 @@ conversation.add_message(
|
||||
text="Please write a poem about the moon",
|
||||
)
|
||||
```
|
||||
|
||||
```pycon
|
||||
>>> conversation.send()
|
||||
In the vast expanse where stars do play,
|
||||
@@ -198,11 +227,18 @@ The universe is never done.
|
||||
|
||||
Simple, yet effective.
|
||||
|
||||
### Logging
|
||||
|
||||
Simplemind uses [logfire](https://logfire.ai) for logging. To enable logging, call `sm.enable_logfire()`.
|
||||
|
||||
### More Examples
|
||||
|
||||
Please see the [examples](examples) directory for executable examples.
|
||||
|
||||
-------------------
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions of all kinds. Feel free to open issues for bug reports or feature requests, and submit pull requests to make SimpleMind even better.
|
||||
|
||||
To get started:
|
||||
@@ -213,8 +249,9 @@ To get started:
|
||||
4. Submit a pull request.
|
||||
|
||||
## License
|
||||
|
||||
Simplemind is licensed under the Apache 2.0 License.
|
||||
|
||||
## Acknowledgements
|
||||
Simplemind is inspired by the philosophy of "code for humans" and aims to make working with AI models accessible to all. Special thanks to the open-source community for their contributions and inspiration.
|
||||
|
||||
Simplemind is inspired by the philosophy of "code for humans" and aims to make working with AI models accessible to all. Special thanks to the open-source community for their contributions and inspiration.
|
||||
|
||||
+1
-1
@@ -16,7 +16,7 @@ import simplemind
|
||||
project = "simplemind"
|
||||
copyright = "2024 Kenneth Reitz"
|
||||
author = "Kenneth Reitz"
|
||||
release = "v0.1.6"
|
||||
release = "v0.2.0"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
+13
-2
@@ -1,10 +1,21 @@
|
||||
[project]
|
||||
name = "simplemind"
|
||||
version = "0.1.6"
|
||||
version = "0.2.0"
|
||||
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "ollama", "groq", "google-generativeai"]
|
||||
dependencies = ["pydantic", "pydantic-settings", "instructor", "logfire"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
full = [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"ollama",
|
||||
"groq",
|
||||
"google-generativeai",
|
||||
"botocore",
|
||||
"boto3"
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
|
||||
@@ -113,6 +113,11 @@ def generate_text(
|
||||
return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs)
|
||||
|
||||
|
||||
def enable_logfire() -> None:
|
||||
"""Enable logfire logging."""
|
||||
settings.logging.enable_logfire()
|
||||
|
||||
|
||||
# Syntax sugar.
|
||||
Plugin = BasePlugin
|
||||
|
||||
@@ -125,4 +130,5 @@ __all__ = [
|
||||
"BasePlugin",
|
||||
"Session",
|
||||
"Plugin",
|
||||
"enable_logfire",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import time
|
||||
from typing import Any, Callable
|
||||
|
||||
import logfire
|
||||
|
||||
from .settings import settings
|
||||
|
||||
|
||||
def logger(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""A decorator that logs the function parameters, function returns,
|
||||
and exceptions raised if logging is enabled, using logfire.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
if not settings.logging.is_enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
logfire.info(f"Calling {func.__name__} with args: {args}, kwargs: {kwargs}")
|
||||
t1 = time.perf_counter()
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
t2 = time.perf_counter()
|
||||
logfire.info(f"{func.__name__} returned: {result} in {t2-t1} seconds")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
t2 = time.perf_counter()
|
||||
logfire.error(f"Error in {func.__name__}: {e} in {t2-t1} seconds")
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
@@ -1,6 +1,6 @@
|
||||
from types import TracebackType
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -7,5 +7,6 @@ from .groq import Groq
|
||||
from .ollama import Ollama
|
||||
from .openai import OpenAI
|
||||
from .xai import XAI
|
||||
from .amazon import Amazon
|
||||
|
||||
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI]
|
||||
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI, Amazon]
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import Any, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Type, TypeVar
|
||||
|
||||
from instructor import Instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import anthropic
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
PROVIDER_NAME = "amazon"
|
||||
DEFAULT_MODEL = "anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
DEFAULT_MAX_TOKENS = 5_000
|
||||
|
||||
|
||||
class Amazon(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, profile_name: str | None = None):
|
||||
self.profile_name = profile_name or settings.AMAZON_PROFILE_NAME
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""The AnthropicBedrock client."""
|
||||
if not self.profile_name:
|
||||
raise ValueError("Profile name is not provided")
|
||||
|
||||
return anthropic.AnthropicBedrock(aws_profile=self.profile_name)
|
||||
|
||||
@property
|
||||
def structured_client(self):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_anthropic(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message.content or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(
|
||||
self, prompt, response_model: Type[T], *, llm_model: str | None = None, **kwargs
|
||||
) -> T:
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
response_model=response_model,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.messages.create(
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return response.content[0].text
|
||||
@@ -1,13 +1,16 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import anthropic
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
@@ -30,6 +33,13 @@ class Anthropic(BaseProvider):
|
||||
"""The raw Anthropic client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Anthropic API key is required")
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `anthropic` package: `pip install anthropic`"
|
||||
) from exc
|
||||
|
||||
return anthropic.Anthropic(api_key=self.api_key)
|
||||
|
||||
@cached_property
|
||||
@@ -37,7 +47,8 @@ class Anthropic(BaseProvider):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_anthropic(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the Anthropic API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -63,6 +74,7 @@ class Anthropic(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
def structured_response(
|
||||
self, response_model: Type[T], *, llm_model: str | None = None, **kwargs
|
||||
) -> T:
|
||||
@@ -80,8 +92,9 @@ class Anthropic(BaseProvider):
|
||||
response_model=response_model,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
|
||||
@@ -2,21 +2,25 @@
|
||||
# IT is not currently working as desired.
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import google.generativeai as genai
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
PROVIDER_NAME = "gemini"
|
||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "gemini"
|
||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
|
||||
|
||||
class Gemini(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
@@ -25,12 +29,21 @@ class Gemini(BaseProvider):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.model_name = DEFAULT_MODEL
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
self.model_name = model_name
|
||||
|
||||
@cached_property
|
||||
def client(self, model_name: str = DEFAULT_MODEL):
|
||||
def client(self):
|
||||
"""The raw Gemini client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Gemini API key is required")
|
||||
self.model_name = model_name
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `google-generativeai` package: `pip install google-generativeai`"
|
||||
) from exc
|
||||
genai.configure(api_key=self.api_key)
|
||||
return genai.GenerativeModel(model_name=self.model_name)
|
||||
|
||||
@cached_property
|
||||
@@ -38,6 +51,7 @@ class Gemini(BaseProvider):
|
||||
"""A Gemini client patched with Instructor."""
|
||||
return instructor.from_gemini(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
||||
"""Send a conversation to the Gemini API."""
|
||||
from ..models import Message
|
||||
@@ -64,9 +78,11 @@ class Gemini(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
"""Send a structured response to the Gemini API."""
|
||||
llm_model = kwargs.pop("llm_model", self.model_name)
|
||||
# Only try to pop if the key exists
|
||||
kwargs.pop("llm_model", None) # Add default value of None
|
||||
|
||||
try:
|
||||
response = self.structured_client.chat.completions.create(
|
||||
@@ -79,12 +95,12 @@ class Gemini(BaseProvider):
|
||||
raise RuntimeError(
|
||||
f"Failed to send structured response to Gemini API: {e}"
|
||||
) from e
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate text using the Gemini API."""
|
||||
llm_model = kwargs.pop("llm_model", self.model_name)
|
||||
|
||||
kwargs.pop("llm_model")
|
||||
try:
|
||||
response = self.client.generate_content(prompt, **kwargs)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,22 +1,29 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import groq
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
PROVIDER_NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class Groq(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
@@ -26,6 +33,12 @@ class Groq(BaseProvider):
|
||||
"""The raw Groq client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Groq API key is required")
|
||||
try:
|
||||
import groq
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `groq` package: `pip install groq`"
|
||||
) from exc
|
||||
return groq.Groq(api_key=self.api_key)
|
||||
|
||||
@cached_property
|
||||
@@ -33,6 +46,7 @@ class Groq(BaseProvider):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_groq(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
@@ -48,7 +62,7 @@ class Groq(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Get the response content from the Groq response
|
||||
@@ -63,6 +77,7 @@ class Groq(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
@@ -73,17 +88,18 @@ class Groq(BaseProvider):
|
||||
messages=messages,
|
||||
response_model=response_model,
|
||||
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str,
|
||||
**kwargs,
|
||||
):
|
||||
) -> str:
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -91,7 +107,7 @@ class Groq(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
return str(response.choices[0].message.content)
|
||||
|
||||
@@ -1,25 +1,30 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import ollama as ol
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "ollama"
|
||||
DEFAULT_MODEL = "llama3.2"
|
||||
DEFAULT_TIMEOUT = 60
|
||||
DEFAULT_KWARGS = {}
|
||||
|
||||
|
||||
class Ollama(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
TIMEOUT = DEFAULT_TIMEOUT
|
||||
|
||||
def __init__(self, host_url: str | None = None):
|
||||
@@ -30,6 +35,12 @@ class Ollama(BaseProvider):
|
||||
"""The raw Ollama client."""
|
||||
if not self.host_url:
|
||||
raise ValueError("No ollama host url provided")
|
||||
try:
|
||||
import ollama as ol
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `ollama` package: `pip install ollama`"
|
||||
) from exc
|
||||
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
|
||||
|
||||
@cached_property
|
||||
@@ -43,7 +54,8 @@ class Ollama(BaseProvider):
|
||||
mode=instructor.Mode.JSON,
|
||||
)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the Ollama API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -51,7 +63,9 @@ class Ollama(BaseProvider):
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
response = self.client.chat(
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
assistant_message = response.get("message")
|
||||
|
||||
@@ -64,6 +78,7 @@ class Ollama(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
def structured_response(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -81,18 +96,23 @@ class Ollama(BaseProvider):
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
def generate_text(self, prompt: str, *, llm_model: str | None = None) -> str:
|
||||
@logger
|
||||
def generate_text(
|
||||
self, prompt: str, *, llm_model: str | None = None, **kwargs
|
||||
) -> str:
|
||||
"""Generate text using the Ollama API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat(
|
||||
messages=messages, model=llm_model or self.DEFAULT_MODEL
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.get("message", {}).get("content", "")
|
||||
|
||||
@@ -1,22 +1,28 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
PROVIDER_NAME = "openai"
|
||||
DEFAULT_MODEL = "gpt-4o-mini"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class OpenAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
@@ -26,6 +32,12 @@ class OpenAI(BaseProvider):
|
||||
"""The raw OpenAI client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key is required")
|
||||
try:
|
||||
import openai as oa
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `openai` package: `pip install openai`"
|
||||
) from exc
|
||||
return oa.OpenAI(api_key=self.api_key)
|
||||
|
||||
@cached_property
|
||||
@@ -33,7 +45,8 @@ class OpenAI(BaseProvider):
|
||||
"""A OpenAI client with Instructor."""
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -42,7 +55,9 @@ class OpenAI(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
@@ -57,6 +72,7 @@ class OpenAI(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
def structured_response(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -74,16 +90,19 @@ class OpenAI(BaseProvider):
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
|
||||
"""Generate text using the OpenAI API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
@@ -1,20 +1,30 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "xai"
|
||||
DEFAULT_MODEL = "grok-beta"
|
||||
BASE_URL = "https://api.x.ai/v1"
|
||||
DEFAULT_MAX_TOKENS = 1000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class XAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
@@ -24,6 +34,12 @@ class XAI(BaseProvider):
|
||||
"""The raw OpenAI client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("XAI API key is required")
|
||||
try:
|
||||
import openai as oa
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `openai` package: `pip install openai`"
|
||||
) from exc
|
||||
return oa.OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=BASE_URL,
|
||||
@@ -34,7 +50,8 @@ class XAI(BaseProvider):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -45,7 +62,7 @@ class XAI(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
@@ -60,10 +77,14 @@ class XAI(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt: str, response_model, *, llm_model: str):
|
||||
@logger
|
||||
def structured_response(
|
||||
self, prompt: str, response_model: Type[T], *, llm_model: str
|
||||
) -> T:
|
||||
raise NotImplementedError("XAI does not support structured responses")
|
||||
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -71,7 +92,7 @@ class XAI(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
return str(response.choices[0].message.content)
|
||||
|
||||
+29
-5
@@ -1,23 +1,47 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import Field, SecretStr, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
logging_level = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||
|
||||
|
||||
class LoggingConfig(BaseSettings):
|
||||
"""The class that holds all the logging settings for the application."""
|
||||
|
||||
enabled: bool = Field(False, description="Enable logging")
|
||||
level: logging_level = Field("INFO", description="The logging level")
|
||||
is_enabled: bool = Field(False, description="Enable logging")
|
||||
|
||||
model_config = SettingsConfigDict(extra="forbid")
|
||||
|
||||
def enable_logfire(self, **kwargs) -> None:
|
||||
"""Enable logging for the application."""
|
||||
# adding imports here to avoid forced dependencies
|
||||
try:
|
||||
import logfire
|
||||
from logging import basicConfig
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"To enable logging, please install logfire: `pip install logfire`"
|
||||
) from e
|
||||
|
||||
self.is_enabled = True
|
||||
logfire.configure(**kwargs)
|
||||
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
|
||||
|
||||
try:
|
||||
logfire.configure(**kwargs)
|
||||
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
|
||||
except Exception as e:
|
||||
self.is_enabled = False # Reset flag on failure
|
||||
raise RuntimeError("Failed to configure logging") from e
|
||||
|
||||
def disable_logfire(self) -> None:
|
||||
"""Disable logging for the application."""
|
||||
self.is_enabled = False
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""The class that holds all the API keys for the application."""
|
||||
|
||||
AMAZON_PROFILE_NAME: Optional[str] = Field("default", description="AWS Named Profile")
|
||||
ANTHROPIC_API_KEY: Optional[SecretStr] = Field(
|
||||
None, description="API key for Anthropic"
|
||||
)
|
||||
|
||||
+2
-2
@@ -1,8 +1,8 @@
|
||||
import pytest
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the project root to the Python path.
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -16,6 +18,7 @@ class ResponseModel(BaseModel):
|
||||
OpenAI,
|
||||
Groq,
|
||||
Ollama,
|
||||
Amazon
|
||||
],
|
||||
)
|
||||
def test_generate_data(provider_cls):
|
||||
@@ -25,4 +28,4 @@ def test_generate_data(provider_cls):
|
||||
data = provider.structured_response(prompt=prompt, response_model=ResponseModel)
|
||||
|
||||
assert isinstance(data, ResponseModel)
|
||||
assert type(data.result) == int
|
||||
assert isinstance(data.result, int)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -11,6 +11,7 @@ from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama
|
||||
OpenAI,
|
||||
Groq,
|
||||
Ollama,
|
||||
Amazon,
|
||||
],
|
||||
)
|
||||
def test_generate_text(provider_cls):
|
||||
|
||||
Reference in New Issue
Block a user