mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
Compare commits
93 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cd0be3ad89 | |||
| 3dd2e1b248 | |||
| ad1800840d | |||
| d62f297b68 | |||
| a2597709d2 | |||
| 1455b5ba13 | |||
| 0fb54d1987 | |||
| fe06331662 | |||
| 56b1e65d70 | |||
| 4b3e1bc6dd | |||
| f5b922ade8 | |||
| 3a7383425f | |||
| 92c10fc41e | |||
| caceba381d | |||
| 0795464fd7 | |||
| d82effdfb1 | |||
| e648292cb3 | |||
| 37a9333be3 | |||
| cbc3739411 | |||
| 7c8f22bef1 | |||
| 9c3f2a6df3 | |||
| febf5473d5 | |||
| 48ac97f070 | |||
| c41a3f00fb | |||
| 25ee4ae32c | |||
| 984721f02b | |||
| 69c8723770 | |||
| 0c10d5676a | |||
| e0ddf41e15 | |||
| f940ae2dfd | |||
| 85fa4f5879 | |||
| 44581e8fe3 | |||
| 9503ec7fd3 | |||
| 418f36dcc0 | |||
| bf9683cfd0 | |||
| 3909588f3e | |||
| 33d8f18bff | |||
| d7388ef0d5 | |||
| 02d10bfda9 | |||
| 5dc6e7b006 | |||
| 62933c8553 | |||
| f0a6be73f8 | |||
| 9257a04f34 | |||
| 64dbe9a2e7 | |||
| ccb8311089 | |||
| 0c29380501 | |||
| 7b43208a03 | |||
| e931fd0eae | |||
| 736d942527 | |||
| 3505c8758d | |||
| 308886e608 | |||
| 9c18d726d5 | |||
| 8f43b660ea | |||
| 222d3025b1 | |||
| fb6c4c289b | |||
| c28e2a3839 | |||
| 2bed7221b3 | |||
| 1504edad78 | |||
| fd7289c8d3 | |||
| c4674fc98f | |||
| 25806221eb | |||
| 5505a3e18d | |||
| 48291c37c5 | |||
| 4b2b094ea6 | |||
| 33e4046ac3 | |||
| 7fe8e91111 | |||
| 42fc0e6bc5 | |||
| ec4f6f9c06 | |||
| 499d3b3e14 | |||
| dd2f5a46d2 | |||
| bd0c739c9a | |||
| 473a054afa | |||
| 55c28a2356 | |||
| 9bd1653b5e | |||
| 59401c4be4 | |||
| 20ad9437e5 | |||
| 9db95cc87b | |||
| d711afec68 | |||
| 9d7fd4cce5 | |||
| 4aa470bb20 | |||
| 88e118cb53 | |||
| 73316c32a3 | |||
| e1331822aa | |||
| baee6e9959 | |||
| 8096609c2e | |||
| 4225f61df3 | |||
| 034e967ecb | |||
| f9c4cce9a4 | |||
| 78f6649969 | |||
| 4f1e52b1f8 | |||
| 74c09d5c87 | |||
| 1f66bac645 | |||
| f828f9991b |
+4
-3
@@ -1,5 +1,6 @@
|
|||||||
export OPENAI_API_KEY=""
|
|
||||||
export ANTHROPIC_API_KEY=""
|
export ANTHROPIC_API_KEY=""
|
||||||
export XAI_API_KEY=""
|
export GEMINI_API_KEY=""
|
||||||
export OLLAMA_HOST_URL=""
|
|
||||||
export GROQ_API_KEY=""
|
export GROQ_API_KEY=""
|
||||||
|
export OLLAMA_HOST_URL=""
|
||||||
|
export OPENAI_API_KEY=""
|
||||||
|
export XAI_API_KEY=""
|
||||||
|
|||||||
@@ -167,3 +167,4 @@ cython_debug/
|
|||||||
|
|
||||||
src/**
|
src/**
|
||||||
requirements.txt
|
requirements.txt
|
||||||
|
Pipfile
|
||||||
|
|||||||
+15
-1
@@ -1,9 +1,23 @@
|
|||||||
Release History
|
Release History
|
||||||
===============
|
===============
|
||||||
|
|
||||||
|
## 0.1.6 (2024-10-31)
|
||||||
|
|
||||||
|
- Add `sm.Plugin` syntax sugar.
|
||||||
|
- Improvements to Anthropic provider, related to max tokens.
|
||||||
|
- General improvements.
|
||||||
|
- Add tests for structured response.
|
||||||
|
- Add `llm_model` to `structured_response`.
|
||||||
|
|
||||||
|
## 0.1.5 (2024-10-31)
|
||||||
|
|
||||||
|
- Add Gemini provider.
|
||||||
|
- Add structured response to Gemini provider.
|
||||||
|
- Support for Python 3.10.
|
||||||
|
|
||||||
## 0.1.4 (2024-10-30)
|
## 0.1.4 (2024-10-30)
|
||||||
|
|
||||||
- Introduce `Session` class to manage multiple conversations.
|
- Introduce `Session` class to manage repeatability.
|
||||||
- General improvements.
|
- General improvements.
|
||||||
|
|
||||||
## 0.1.3 (2024-10-30)
|
## 0.1.3 (2024-10-30)
|
||||||
|
|||||||
@@ -1,43 +1,46 @@
|
|||||||
# SimpleMind: AI for Humans™
|
# Simplemind: AI for Humans™
|
||||||
|
|
||||||
[](https://mutable.ai/kennethreitz/simplemind)
|
**Keep it simple, keep it human.**
|
||||||
|
|
||||||
SimpleMind is an AI library designed to simplify your experience with AI APIs in Python. Inspired by a "for humans" philosophy, it abstracts away complexity, giving developers an intuitive and human-friendly way to interact with powerful AI capabilities.
|
Simplemind is AI library designed to simplify your experience with AI APIs in Python. Inspired by a "for humans" philosophy, it abstracts away complexity, giving developers an intuitive and human-friendly way to interact with powerful AI capabilities.
|
||||||
|
|
||||||
With SimpleMind, tapping into AI is as easy as a friendly conversation.
|

|
||||||
|
|
||||||
```bash
|
|
||||||
$ pip install simplemind
|
|
||||||
```
|
|
||||||
|
|
||||||
**Note**: SimpleMind is currently in beta. We welcome feedback and contributions to help make it even better.
|
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
|
With Simplemind, tapping into AI is as easy as a friendly conversation.
|
||||||
|
|
||||||
- **Easy-to-use AI tools**: SimpleMind provides simple interfaces to popular AI services.
|
- **Easy-to-use AI tools**: SimpleMind provides simple interfaces to popular AI services.
|
||||||
- **Human-centered design**: The library prioritizes readability and usability—no need to be an expert to start experimenting.
|
- **Human-centered design**: The library prioritizes readability and usability—no need to be an expert to start experimenting.
|
||||||
- **Minimal configuration**: Get started quickly, without worrying about configuration headaches.
|
- **Minimal configuration**: Get started quickly, without worrying about configuration headaches.
|
||||||
|
|
||||||
## Supported APIs
|
## Supported APIs
|
||||||
|
|
||||||
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`.
|
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`. The APIs remain identital between all supported providers/models.
|
||||||
|
|
||||||
- **[OpenAI's GPT](https://openai.com/gpt)**
|
- [**Anthropic's Claude**](https://www.anthropic.com/claude)
|
||||||
- **[Anthropic's Claude](https://www.anthropic.com/claude)**
|
- [**Google's Gemini**](https://gemini.google/)
|
||||||
- **[xAI's Grok](https://x.ai/)**
|
- [**Groq's Groq**](https://groq.com/)
|
||||||
- **[Groq's Groq](https://groq.com/)**
|
- [**Ollama**](https://ollama.com)
|
||||||
- **[Ollama](https://ollama.com)**
|
- [**OpenAI's GPT**](https://openai.com/gpt)
|
||||||
|
- [**xAI's Grok**](https://x.ai/)
|
||||||
|
|
||||||
If you'd like to see SimpleMind support additional providers or models, please send a pull request!
|
If you want to see Simplemind support, additional providers or models, please send a pull request!
|
||||||
|
|
||||||
## Why SimpleMind?
|
## Why SimpleMind?
|
||||||
- **Intuitive**: Built with Pythonic simplicity and readability in mind.
|
- **Intuitive**: Built with Pythonic simplicity and readability in mind.
|
||||||
- **For Humans**: Emphasizes a human-friendly interface, just like `requests` for HTTP.
|
- **For Humans**: Emphasizes a human-friendly interface, just like `requests` for HTTP.
|
||||||
- **Open Source**: SimpleMind is open source, and contributions are always welcome!
|
- **Open Source**: Simplemind is open source, and contributions are always welcome!
|
||||||
|
|
||||||
|
Also, why not? :)
|
||||||
|
|
||||||
## Quickstart
|
## Quickstart
|
||||||
|
|
||||||
SimpleMind takes care of the complex API calls so you can focus on what matters—building, experimenting, and creating.
|
Simplemind takes care of the complex API calls so you can focus on what matters—building, experimenting, and creating.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ pip install simplemind
|
||||||
|
```
|
||||||
|
|
||||||
First, authenticate your API keys by setting them in the environment variables:
|
First, authenticate your API keys by setting them in the environment variables:
|
||||||
|
|
||||||
@@ -45,9 +48,9 @@ First, authenticate your API keys by setting them in the environment variables:
|
|||||||
$ export OPENAI_API_KEY="sk-..."
|
$ export OPENAI_API_KEY="sk-..."
|
||||||
```
|
```
|
||||||
|
|
||||||
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, and `GROQ_API_KEY`.
|
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
|
||||||
|
|
||||||
Next, import SimpleMind and start using it:
|
Next, import Simplemind and start using it:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import simplemind as sm
|
import simplemind as sm
|
||||||
@@ -56,7 +59,7 @@ import simplemind as sm
|
|||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
Here are some examples of how to use SimpleMind:
|
Here are some examples of how to use Simplemind:
|
||||||
|
|
||||||
### Text Completion
|
### Text Completion
|
||||||
|
|
||||||
@@ -114,8 +117,6 @@ To continue the conversation, you can call `conversation.send()` again, which re
|
|||||||
You can use the `Session` class to set default parameters for all calls:
|
You can use the `Session` class to set default parameters for all calls:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import simplemind as sm
|
|
||||||
|
|
||||||
# Create a session with defaults
|
# Create a session with defaults
|
||||||
gpt_4o_mini = sm.Session(llm_provider="openai", llm_model="gpt-4o-mini")
|
gpt_4o_mini = sm.Session(llm_provider="openai", llm_model="gpt-4o-mini")
|
||||||
|
|
||||||
@@ -138,10 +139,7 @@ response = gpt_4o_mini.generate_text(
|
|||||||
Harnessing the power of Python, you can easily create your own plugins to add additional functionality to your conversations:
|
Harnessing the power of Python, you can easily create your own plugins to add additional functionality to your conversations:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import simplemind as sm
|
class SimpleMemoryPlugin(sm.BasePlugin):
|
||||||
|
|
||||||
|
|
||||||
class SimpleMemoryPlugin:
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.memories = [
|
self.memories = [
|
||||||
"the earth has fictionally beeen destroyed.",
|
"the earth has fictionally beeen destroyed.",
|
||||||
@@ -198,8 +196,12 @@ A reminder that in tales and fun,
|
|||||||
The universe is never done.
|
The universe is never done.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Simple, yet effective.
|
||||||
|
|
||||||
Please see the [examples](examples) directory for executable examples.
|
Please see the [examples](examples) directory for executable examples.
|
||||||
|
|
||||||
|
-------------------
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
We welcome contributions of all kinds. Feel free to open issues for bug reports or feature requests, and submit pull requests to make SimpleMind even better.
|
We welcome contributions of all kinds. Feel free to open issues for bug reports or feature requests, and submit pull requests to make SimpleMind even better.
|
||||||
|
|
||||||
@@ -211,11 +213,8 @@ To get started:
|
|||||||
4. Submit a pull request.
|
4. Submit a pull request.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
SimpleMind is licensed under the Apache 2.0 License.
|
Simplemind is licensed under the Apache 2.0 License.
|
||||||
|
|
||||||
## Acknowledgements
|
## Acknowledgements
|
||||||
SimpleMind is inspired by the philosophy of "code for humans" and aims to make working with AI models accessible to all. Special thanks to the open-source community for their contributions and inspiration.
|
Simplemind is inspired by the philosophy of "code for humans" and aims to make working with AI models accessible to all. Special thanks to the open-source community for their contributions and inspiration.
|
||||||
|
|
||||||
---------------
|
|
||||||
|
|
||||||
SimpleMind: Keep it simple, keep it human.
|
|
||||||
|
|||||||
+1
-1
@@ -16,7 +16,7 @@ import simplemind
|
|||||||
project = "simplemind"
|
project = "simplemind"
|
||||||
copyright = "2024 Kenneth Reitz"
|
copyright = "2024 Kenneth Reitz"
|
||||||
author = "Kenneth Reitz"
|
author = "Kenneth Reitz"
|
||||||
release = "v0.1.4"
|
release = "v0.1.6"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
from _context import sm
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import openai
|
|
||||||
import faiss
|
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
import openai
|
||||||
|
from _context import sm
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
class ContextualMemoryPlugin:
|
|
||||||
|
class ContextualMemoryPlugin(sm.BasePlugin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
from typing import List, Iterator
|
from typing import Iterator, List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from _context import sm
|
from _context import sm
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class Movie(BaseModel):
|
class Movie(BaseModel):
|
||||||
@@ -25,7 +24,7 @@ class QuotesList(BaseModel):
|
|||||||
quotes: List[MovieQuote]
|
quotes: List[MovieQuote]
|
||||||
|
|
||||||
|
|
||||||
def gen_quotes(n=10) -> Iterator[MovieQuote]:
|
def gen_quotes(n: int = 10) -> Iterator[MovieQuote]:
|
||||||
"""Generate a list of quotes from famous movies."""
|
"""Generate a list of quotes from famous movies."""
|
||||||
|
|
||||||
for q in sm.generate_data(
|
for q in sm.generate_data(
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
from _context import sm
|
from _context import sm
|
||||||
|
|
||||||
|
|
||||||
class MathPlugin:
|
class MathPlugin(sm.BasePlugin):
|
||||||
def send_hook(self, conversation: sm.Conversation):
|
def send_hook(self, conversation: sm.Conversation):
|
||||||
last_user_message = conversation.get_last_message(role="user")
|
last_user_message = conversation.get_last_message(role="user")
|
||||||
|
if last_user_message is None:
|
||||||
|
return
|
||||||
if "calculate" in last_user_message.text.lower():
|
if "calculate" in last_user_message.text.lower():
|
||||||
expression = last_user_message.text.lower().replace("calculate", "").strip()
|
expression = last_user_message.text.lower().replace("calculate", "").strip()
|
||||||
try:
|
try:
|
||||||
@@ -14,7 +16,7 @@ class MathPlugin:
|
|||||||
except Exception:
|
except Exception:
|
||||||
conversation.add_message(
|
conversation.add_message(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
text="I'm sorry, I couldn't compute that expression.",
|
text="I'm sorry, I couldn't compute that expression. Please try again.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from _context import sm
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
from _context import sm
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class SentimentAnalysis(BaseModel):
|
class SentimentAnalysis(BaseModel):
|
||||||
sentiment: Literal["positive", "negative", "neutral"]
|
sentiment: Literal["positive", "negative", "neutral"]
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from _context import sm
|
from _context import sm
|
||||||
|
|
||||||
|
|
||||||
class SimpleMemoryPlugin(sm.BasePlugin):
|
class SimpleMemoryPlugin:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.memories = [
|
self.memories = [
|
||||||
"the earth has fictionally beeen destroyed.",
|
"the earth has fictionally beeen destroyed.",
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import simplemind as sm
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import simplemind as sm
|
||||||
|
|
||||||
|
|
||||||
class ConversationPlugin(sm.BasePlugin):
|
class ConversationPlugin(sm.BasePlugin):
|
||||||
def post_send_hook(self, conversation, response):
|
def post_send_hook(self, conversation, response):
|
||||||
@@ -8,7 +9,7 @@ class ConversationPlugin(sm.BasePlugin):
|
|||||||
print(f"{conversation.llm_model}:\n{response.text.strip()}\n\n------------\n")
|
print(f"{conversation.llm_model}:\n{response.text.strip()}\n\n------------\n")
|
||||||
|
|
||||||
|
|
||||||
def have_conversation(rounds=3):
|
def have_conversation(rounds: int = 3):
|
||||||
# Create two conversations - one for each AI
|
# Create two conversations - one for each AI
|
||||||
with (
|
with (
|
||||||
sm.create_conversation(
|
sm.create_conversation(
|
||||||
|
|||||||
Binary file not shown.
|
After Width: | Height: | Size: 944 KiB |
+3
-3
@@ -1,10 +1,10 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "simplemind"
|
name = "simplemind"
|
||||||
version = "0.1.4"
|
version = "0.1.6"
|
||||||
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.10"
|
||||||
dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "ollama", "groq"]
|
dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "ollama", "groq", "google-generativeai"]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
|
|||||||
+32
-10
@@ -1,8 +1,8 @@
|
|||||||
from typing import List, Optional, Type
|
from typing import List, Type
|
||||||
|
|
||||||
from .models import Conversation, BasePlugin, BaseModel
|
from .models import BaseModel, BasePlugin, Conversation
|
||||||
from .utils import find_provider
|
|
||||||
from .settings import settings
|
from .settings import settings
|
||||||
|
from .utils import find_provider
|
||||||
|
|
||||||
|
|
||||||
class Session:
|
class Session:
|
||||||
@@ -16,7 +16,7 @@ class Session:
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
llm_provider: str = settings.DEFAULT_LLM_PROVIDER,
|
llm_provider: str = settings.DEFAULT_LLM_PROVIDER,
|
||||||
llm_model: str = settings.DEFAULT_LLM_MODEL,
|
llm_model: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.llm_provider = llm_provider
|
self.llm_provider = llm_provider
|
||||||
@@ -46,7 +46,7 @@ class Session:
|
|||||||
**merged_kwargs,
|
**merged_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_conversation(self, **kwargs) -> "Conversation":
|
def create_conversation(self, **kwargs) -> Conversation:
|
||||||
"""Create a conversation using the session's default provider and model."""
|
"""Create a conversation using the session's default provider and model."""
|
||||||
merged_kwargs = {**self.default_kwargs, **kwargs}
|
merged_kwargs = {**self.default_kwargs, **kwargs}
|
||||||
return create_conversation(
|
return create_conversation(
|
||||||
@@ -55,13 +55,18 @@ class Session:
|
|||||||
|
|
||||||
|
|
||||||
def create_conversation(
|
def create_conversation(
|
||||||
llm_model=None, llm_provider=None, *, plugins: Optional[List[BasePlugin]] = None
|
*,
|
||||||
):
|
llm_model: str | None = None,
|
||||||
|
llm_provider: str | None = None,
|
||||||
|
plugins: List[BasePlugin] | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Conversation:
|
||||||
"""Create a new conversation."""
|
"""Create a new conversation."""
|
||||||
|
|
||||||
# Create the conversation.
|
# Create the conversation.
|
||||||
conversation = Conversation(
|
conversation = Conversation(
|
||||||
llm_model=llm_model, llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER
|
llm_model=llm_model,
|
||||||
|
llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add plugins to the conversation.
|
# Add plugins to the conversation.
|
||||||
@@ -71,7 +76,14 @@ def create_conversation(
|
|||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
|
|
||||||
def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=None):
|
def generate_data(
|
||||||
|
prompt: str,
|
||||||
|
*,
|
||||||
|
llm_model: str | None = None,
|
||||||
|
llm_provider: str | None = None,
|
||||||
|
response_model: Type[BaseModel],
|
||||||
|
**kwargs,
|
||||||
|
) -> BaseModel:
|
||||||
"""Generate structured data from a given prompt."""
|
"""Generate structured data from a given prompt."""
|
||||||
|
|
||||||
# Find the provider.
|
# Find the provider.
|
||||||
@@ -85,7 +97,13 @@ def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=N
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs):
|
def generate_text(
|
||||||
|
prompt: str,
|
||||||
|
*,
|
||||||
|
llm_model: str | None = None,
|
||||||
|
llm_provider: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
"""Generate text from a given prompt."""
|
"""Generate text from a given prompt."""
|
||||||
|
|
||||||
# Find the provider.
|
# Find the provider.
|
||||||
@@ -95,6 +113,9 @@ def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs):
|
|||||||
return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs)
|
return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Syntax sugar.
|
||||||
|
Plugin = BasePlugin
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"create_conversation",
|
"create_conversation",
|
||||||
"find_provider",
|
"find_provider",
|
||||||
@@ -103,4 +124,5 @@ __all__ = [
|
|||||||
"settings",
|
"settings",
|
||||||
"BasePlugin",
|
"BasePlugin",
|
||||||
"Session",
|
"Session",
|
||||||
|
"Plugin",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,27 @@
|
|||||||
|
import time
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
import logfire
|
||||||
|
|
||||||
|
from .settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def logger(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
"""A @logger decorator that logs the function parameters, function returns, and exceptions raised if logging is enabled."""
|
||||||
|
|
||||||
|
def wrapper(*args, **kwargs) -> Any:
|
||||||
|
if not settings.logging.enabled:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
logfire.info(f"Calling {func.__name__} with args: {args}, kwargs: {kwargs}")
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
try:
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
t2 = time.perf_counter()
|
||||||
|
logfire.info(f"{func.__name__} returned: {result} in {t2-t1} seconds")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
t2 = time.perf_counter()
|
||||||
|
logfire.error(f"Error in {func.__name__}: {e} in {t2-t1} seconds")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return wrapper
|
||||||
+37
-17
@@ -1,18 +1,18 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from types import TracebackType
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .utils import find_provider
|
from .utils import find_provider
|
||||||
|
|
||||||
|
|
||||||
MESSAGE_ROLE = Literal["system", "user", "assistant"]
|
MESSAGE_ROLE = Literal["system", "user", "assistant"]
|
||||||
|
|
||||||
|
|
||||||
class SMBaseModel(BaseModel):
|
class SMBaseModel(BaseModel):
|
||||||
|
"""The base SimpleMind model class."""
|
||||||
|
|
||||||
date_created: datetime = Field(default_factory=datetime.now)
|
date_created: datetime = Field(default_factory=datetime.now)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@@ -22,34 +22,36 @@ class SMBaseModel(BaseModel):
|
|||||||
return str(self)
|
return str(self)
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin:
|
class BasePlugin(SMBaseModel):
|
||||||
"""The base conversation plugin class."""
|
"""The base conversation plugin class."""
|
||||||
|
|
||||||
# Plugin metadata.
|
# Plugin metadata.
|
||||||
meta: Dict[str, Any] = {}
|
meta: Dict[str, Any] = {}
|
||||||
|
|
||||||
def initialize_hook(self, conversation: "Conversation"):
|
def initialize_hook(self, conversation: "Conversation") -> Any:
|
||||||
"""Initialize a hook for the plugin."""
|
"""Initialize a hook for the plugin."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def cleanup_hook(self, conversation: "Conversation"):
|
def cleanup_hook(self, conversation: "Conversation") -> Any:
|
||||||
"""Cleanup a hook for the plugin."""
|
"""Cleanup a hook for the plugin."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def add_message_hook(self, conversation: "Conversation", message: "Message"):
|
def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any:
|
||||||
"""Add a message hook for the plugin."""
|
"""Add a message hook for the plugin."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def pre_send_hook(self, conversation: "Conversation"):
|
def pre_send_hook(self, conversation: "Conversation") -> Any:
|
||||||
"""Pre-send hook for the plugin."""
|
"""Pre-send hook for the plugin."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def post_send_hook(self, conversation: "Conversation", response: "Message"):
|
def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any:
|
||||||
"""Post-send hook for the plugin."""
|
"""Post-send hook for the plugin."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class Message(SMBaseModel):
|
class Message(SMBaseModel):
|
||||||
|
"""A message in a conversation."""
|
||||||
|
|
||||||
role: MESSAGE_ROLE
|
role: MESSAGE_ROLE
|
||||||
text: str
|
text: str
|
||||||
meta: Dict[str, Any] = {}
|
meta: Dict[str, Any] = {}
|
||||||
@@ -61,7 +63,16 @@ class Message(SMBaseModel):
|
|||||||
return f"<Message role={self.role} text={self.text!r}>"
|
return f"<Message role={self.role} text={self.text!r}>"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_raw_response(cls, *, text: str, raw):
|
def from_raw_response(cls, *, text: str, raw: Any) -> "Message":
|
||||||
|
"""Create a Message instance from a raw response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): The message text.
|
||||||
|
raw (Any): The raw response data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Message: A new Message instance.
|
||||||
|
"""
|
||||||
self = cls()
|
self = cls()
|
||||||
self.text = text
|
self.text = text
|
||||||
self.raw = raw
|
self.raw = raw
|
||||||
@@ -69,11 +80,13 @@ class Message(SMBaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Conversation(SMBaseModel):
|
class Conversation(SMBaseModel):
|
||||||
|
"""A conversation between a user and an assistant."""
|
||||||
|
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
messages: List[Message] = []
|
messages: List[Message] = []
|
||||||
llm_model: Optional[str] = None
|
llm_model: Optional[str] = None
|
||||||
llm_provider: Optional[str] = None
|
llm_provider: Optional[str] = None
|
||||||
plugins: List[Any] = []
|
plugins: List[BasePlugin] = []
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"<Conversation id={self.id!r}>"
|
return f"<Conversation id={self.id!r}>"
|
||||||
@@ -89,8 +102,13 @@ class Conversation(SMBaseModel):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(
|
||||||
# Execute all cleanup hooks.
|
self,
|
||||||
|
exc_type: type[BaseException],
|
||||||
|
exc_value: BaseException,
|
||||||
|
traceback: TracebackType,
|
||||||
|
) -> None:
|
||||||
|
"""Execute all cleanup hooks."""
|
||||||
for plugin in self.plugins:
|
for plugin in self.plugins:
|
||||||
if hasattr(plugin, "cleanup_hook"):
|
if hasattr(plugin, "cleanup_hook"):
|
||||||
try:
|
try:
|
||||||
@@ -99,7 +117,7 @@ class Conversation(SMBaseModel):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def prepend_system_message(
|
def prepend_system_message(
|
||||||
self, role: str, text: str, meta: Optional[Dict[str, Any]] = None
|
self, role: MESSAGE_ROLE, text: str, meta: Dict[str, Any] | None = None
|
||||||
):
|
):
|
||||||
"""Prepend a system message to the conversation."""
|
"""Prepend a system message to the conversation."""
|
||||||
self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages
|
self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages
|
||||||
@@ -127,7 +145,9 @@ class Conversation(SMBaseModel):
|
|||||||
self.messages.append(Message(role=role, text=text, meta=meta))
|
self.messages.append(Message(role=role, text=text, meta=meta))
|
||||||
|
|
||||||
def send(
|
def send(
|
||||||
self, llm_model: Optional[str] = None, llm_provider: Optional[str] = None
|
self,
|
||||||
|
llm_model: str | None = None,
|
||||||
|
llm_provider: str | None = None,
|
||||||
) -> Message:
|
) -> Message:
|
||||||
"""Send the conversation to the LLM."""
|
"""Send the conversation to the LLM."""
|
||||||
|
|
||||||
@@ -156,10 +176,10 @@ class Conversation(SMBaseModel):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def get_last_message(self, role: MESSAGE_ROLE) -> Optional[Message]:
|
def get_last_message(self, role: MESSAGE_ROLE) -> Message | None:
|
||||||
"""Get the last message with the given role."""
|
"""Get the last message with the given role."""
|
||||||
return next((m for m in reversed(self.messages) if m.role == role), None)
|
return next((m for m in reversed(self.messages) if m.role == role), None)
|
||||||
|
|
||||||
def add_plugin(self, plugin: Any):
|
def add_plugin(self, plugin: BasePlugin) -> None:
|
||||||
"""Add a plugin to the conversation."""
|
"""Add a plugin to the conversation."""
|
||||||
self.plugins.append(plugin)
|
self.plugins.append(plugin)
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ from typing import List, Type
|
|||||||
|
|
||||||
from ._base import BaseProvider
|
from ._base import BaseProvider
|
||||||
from .anthropic import Anthropic
|
from .anthropic import Anthropic
|
||||||
|
from .gemini import Gemini
|
||||||
from .groq import Groq
|
from .groq import Groq
|
||||||
from .openai import OpenAI
|
|
||||||
from .ollama import Ollama
|
from .ollama import Ollama
|
||||||
|
from .openai import OpenAI
|
||||||
from .xai import XAI
|
from .xai import XAI
|
||||||
|
|
||||||
providers: List[Type[BaseProvider]] = [Anthropic, Groq, OpenAI, Ollama, XAI]
|
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI]
|
||||||
|
|||||||
@@ -1,6 +1,14 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import TYPE_CHECKING, Any, Type, TypeVar
|
||||||
|
|
||||||
from instructor import Instructor
|
from instructor import Instructor
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
class BaseProvider(ABC):
|
class BaseProvider(ABC):
|
||||||
@@ -9,13 +17,13 @@ class BaseProvider(ABC):
|
|||||||
NAME: str
|
NAME: str
|
||||||
DEFAULT_MODEL: str
|
DEFAULT_MODEL: str
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def client(self):
|
def client(self) -> Any:
|
||||||
"""The instructor client for the provider."""
|
"""The instructor client for the provider."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def structured_client(self) -> Instructor:
|
def structured_client(self) -> Instructor:
|
||||||
"""The structured client for the provider."""
|
"""The structured client for the provider."""
|
||||||
@@ -27,7 +35,7 @@ class BaseProvider(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def structured_response(self, prompt: str, response_model, **kwargs):
|
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||||
"""Get a structured response."""
|
"""Get a structured response."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -1,36 +1,54 @@
|
|||||||
from typing import Union
|
from functools import cached_property
|
||||||
|
from typing import TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
import anthropic
|
|
||||||
import instructor
|
import instructor
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ._base import BaseProvider
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
|
from ._base import BaseProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "anthropic"
|
PROVIDER_NAME = "anthropic"
|
||||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||||
DEFAULT_MAX_TOKENS = 1000
|
DEFAULT_MAX_TOKENS = 1_000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
|
|
||||||
|
|
||||||
class Anthropic(BaseProvider):
|
class Anthropic(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = PROVIDER_NAME
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = DEFAULT_MODEL
|
||||||
|
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||||
|
|
||||||
def __init__(self, api_key: Union[str, None] = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
"""The raw Anthropic client."""
|
"""The raw Anthropic client."""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("Anthropic API key is required")
|
raise ValueError("Anthropic API key is required")
|
||||||
|
try:
|
||||||
|
import anthropic
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `anthropic` package: `pip install anthropic`"
|
||||||
|
) from exc
|
||||||
|
|
||||||
return anthropic.Anthropic(api_key=self.api_key)
|
return anthropic.Anthropic(api_key=self.api_key)
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def structured_client(self):
|
def structured_client(self):
|
||||||
"""A client patched with Instructor."""
|
"""A client patched with Instructor."""
|
||||||
return instructor.from_anthropic(self.client)
|
return instructor.from_anthropic(self.client)
|
||||||
|
|
||||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
@logger
|
||||||
|
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||||
"""Send a conversation to the Anthropic API."""
|
"""Send a conversation to the Anthropic API."""
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
|
|
||||||
@@ -41,8 +59,7 @@ class Anthropic(BaseProvider):
|
|||||||
response = self.client.messages.create(
|
response = self.client.messages.create(
|
||||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=DEFAULT_MAX_TOKENS,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the response content from the Anthropic response
|
# Get the response content from the Anthropic response
|
||||||
@@ -57,13 +74,28 @@ class Anthropic(BaseProvider):
|
|||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=PROVIDER_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
def structured_response(self, model, response_model, **kwargs):
|
@logger
|
||||||
response = self.structured_client.messages.create(
|
def structured_response(
|
||||||
model=model, response_model=response_model or self.DEFAULT_MODEL, **kwargs
|
self, response_model: Type[T], *, llm_model: str | None = None, **kwargs
|
||||||
)
|
) -> T:
|
||||||
return response
|
model = llm_model or self.DEFAULT_MODEL
|
||||||
|
|
||||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
# Extract the prompt from kwargs if it exists
|
||||||
|
prompt = kwargs.pop("prompt", kwargs.pop("messages", ""))
|
||||||
|
|
||||||
|
# Format the messages properly
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
|
||||||
|
response = self.structured_client.messages.create(
|
||||||
|
model=model,
|
||||||
|
messages=messages, # Add the messages parameter
|
||||||
|
response_model=response_model,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
|
)
|
||||||
|
return response_model.model_validate(response)
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
@@ -71,8 +103,7 @@ class Anthropic(BaseProvider):
|
|||||||
response = self.client.messages.create(
|
response = self.client.messages.create(
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=DEFAULT_MAX_TOKENS,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.content[0].text
|
return response.content[0].text
|
||||||
|
|||||||
@@ -0,0 +1,109 @@
|
|||||||
|
# TODO: this is a placeholder file for the Gemini provider
|
||||||
|
# IT is not currently working as desired.
|
||||||
|
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
|
import instructor
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..logging import logger
|
||||||
|
from ..settings import settings
|
||||||
|
from ._base import BaseProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
PROVIDER_NAME = "gemini"
|
||||||
|
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||||
|
|
||||||
|
|
||||||
|
class Gemini(BaseProvider):
|
||||||
|
NAME = PROVIDER_NAME
|
||||||
|
DEFAULT_MODEL = DEFAULT_MODEL
|
||||||
|
|
||||||
|
def __init__(self, api_key: str | None = None):
|
||||||
|
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||||
|
self.model_name = DEFAULT_MODEL
|
||||||
|
|
||||||
|
def set_model(self, model_name: str):
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def client(self):
|
||||||
|
"""The raw Gemini client."""
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError("Gemini API key is required")
|
||||||
|
try:
|
||||||
|
import google.generativeai as genai
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `google-generativeai` package: `pip install google-generativeai`"
|
||||||
|
) from exc
|
||||||
|
genai.configure(api_key=self.api_key)
|
||||||
|
return genai.GenerativeModel(model_name=self.model_name)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def structured_client(self):
|
||||||
|
"""A Gemini client patched with Instructor."""
|
||||||
|
return instructor.from_gemini(self.client)
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def send_conversation(self, conversation: "Conversation") -> "Message":
|
||||||
|
"""Send a conversation to the Gemini API."""
|
||||||
|
from ..models import Message
|
||||||
|
|
||||||
|
# Convert messages to Gemini's format
|
||||||
|
chat = self.client.start_chat()
|
||||||
|
|
||||||
|
# Send all previous messages to establish context
|
||||||
|
for msg in conversation.messages[:-1]: # All messages except the last one
|
||||||
|
chat.send_message(msg.text)
|
||||||
|
|
||||||
|
# Send the final message and get response
|
||||||
|
try:
|
||||||
|
response = chat.send_message(conversation.messages[-1].text)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to send conversation to Gemini API: {e}") from e
|
||||||
|
|
||||||
|
# Create and return a properly formatted Message instance
|
||||||
|
return Message(
|
||||||
|
role="assistant",
|
||||||
|
text=response.text,
|
||||||
|
raw=response,
|
||||||
|
llm_model=self.model_name,
|
||||||
|
llm_provider=PROVIDER_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||||
|
"""Send a structured response to the Gemini API."""
|
||||||
|
# Only try to pop if the key exists
|
||||||
|
kwargs.pop("llm_model", None) # Add default value of None
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.structured_client.chat.completions.create(
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
response_model=response_model,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Handle the exception appropriately, e.g., log the error or raise a custom exception
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to send structured response to Gemini API: {e}"
|
||||||
|
) from e
|
||||||
|
return response_model.model_validate(response)
|
||||||
|
|
||||||
|
@logger
|
||||||
|
def generate_text(self, prompt: str, **kwargs) -> str:
|
||||||
|
"""Generate text using the Gemini API."""
|
||||||
|
kwargs.pop("llm_model")
|
||||||
|
try:
|
||||||
|
response = self.client.generate_content(prompt, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
# Handle the exception appropriately, e.g., log the error or raise a custom exception
|
||||||
|
raise RuntimeError(f"Failed to generate text with Gemini API: {e}") from e
|
||||||
|
return response.text
|
||||||
@@ -1,34 +1,52 @@
|
|||||||
from typing import Union
|
from functools import cached_property
|
||||||
|
from typing import TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
import groq
|
|
||||||
import instructor
|
import instructor
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ._base import BaseProvider
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
|
from ._base import BaseProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "groq"
|
PROVIDER_NAME = "groq"
|
||||||
DEFAULT_MODEL = "llama3-8b-8192"
|
DEFAULT_MODEL = "llama3-8b-8192"
|
||||||
|
DEFAULT_MAX_TOKENS = 1_000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
|
|
||||||
|
|
||||||
class Groq(BaseProvider):
|
class Groq(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = PROVIDER_NAME
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = DEFAULT_MODEL
|
||||||
|
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||||
|
|
||||||
def __init__(self, api_key: Union[str, None] = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
"""The raw Groq client."""
|
"""The raw Groq client."""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("Groq API key is required")
|
raise ValueError("Groq API key is required")
|
||||||
|
try:
|
||||||
|
import groq
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `groq` package: `pip install groq`"
|
||||||
|
) from exc
|
||||||
return groq.Groq(api_key=self.api_key)
|
return groq.Groq(api_key=self.api_key)
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def structured_client(self):
|
def structured_client(self):
|
||||||
"""A client patched with Instructor."""
|
"""A client patched with Instructor."""
|
||||||
return instructor.from_groq(self.client)
|
return instructor.from_groq(self.client)
|
||||||
|
|
||||||
|
@logger
|
||||||
def send_conversation(
|
def send_conversation(
|
||||||
self,
|
self,
|
||||||
conversation: "Conversation",
|
conversation: "Conversation",
|
||||||
@@ -44,7 +62,7 @@ class Groq(BaseProvider):
|
|||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the response content from the Groq response
|
# Get the response content from the Groq response
|
||||||
@@ -59,7 +77,8 @@ class Groq(BaseProvider):
|
|||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=PROVIDER_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
def structured_response(self, prompt: str, response_model, **kwargs):
|
@logger
|
||||||
|
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||||
# Ensure messages are provided in kwargs
|
# Ensure messages are provided in kwargs
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
@@ -68,17 +87,19 @@ class Groq(BaseProvider):
|
|||||||
response = self.structured_client.chat.completions.create(
|
response = self.structured_client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
**kwargs,
|
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
return response
|
return response_model.model_validate(response)
|
||||||
|
|
||||||
|
@logger
|
||||||
def generate_text(
|
def generate_text(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
*,
|
*,
|
||||||
llm_model: str,
|
llm_model: str,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> str:
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
@@ -86,7 +107,7 @@ class Groq(BaseProvider):
|
|||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.choices[0].message.content
|
return str(response.choices[0].message.content)
|
||||||
|
|||||||
@@ -1,32 +1,50 @@
|
|||||||
import ollama as ol
|
from functools import cached_property
|
||||||
|
from typing import TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ._base import BaseProvider
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
|
from ._base import BaseProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "ollama"
|
PROVIDER_NAME = "ollama"
|
||||||
DEFAULT_MODEL = "llama3.2"
|
DEFAULT_MODEL = "llama3.2"
|
||||||
DEFAULT_TIMEOUT = 60
|
DEFAULT_TIMEOUT = 60
|
||||||
|
DEFAULT_KWARGS = {}
|
||||||
|
|
||||||
|
|
||||||
class Ollama(BaseProvider):
|
class Ollama(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = PROVIDER_NAME
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = DEFAULT_MODEL
|
||||||
|
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||||
TIMEOUT = DEFAULT_TIMEOUT
|
TIMEOUT = DEFAULT_TIMEOUT
|
||||||
|
|
||||||
def __init__(self, host_url: str = None):
|
def __init__(self, host_url: str | None = None):
|
||||||
self.host_url = host_url or settings.OLLAMA_HOST_URL
|
self.host_url = host_url or settings.OLLAMA_HOST_URL
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
"""The raw Ollama client."""
|
"""The raw Ollama client."""
|
||||||
if not self.host_url:
|
if not self.host_url:
|
||||||
raise ValueError("No ollama host url provided")
|
raise ValueError("No ollama host url provided")
|
||||||
|
try:
|
||||||
|
import ollama as ol
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `ollama` package: `pip install ollama`"
|
||||||
|
) from exc
|
||||||
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
|
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def structured_client(self):
|
def structured_client(self) -> instructor.Instructor:
|
||||||
"""A client patched with Instructor."""
|
"""A client patched with Instructor."""
|
||||||
return instructor.from_openai(
|
return instructor.from_openai(
|
||||||
OpenAI(
|
OpenAI(
|
||||||
@@ -36,7 +54,8 @@ class Ollama(BaseProvider):
|
|||||||
mode=instructor.Mode.JSON,
|
mode=instructor.Mode.JSON,
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_conversation(self, conversation: "Conversation"):
|
@logger
|
||||||
|
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||||
"""Send a conversation to the Ollama API."""
|
"""Send a conversation to the Ollama API."""
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
|
|
||||||
@@ -44,7 +63,9 @@ class Ollama(BaseProvider):
|
|||||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||||
]
|
]
|
||||||
response = self.client.chat(
|
response = self.client.chat(
|
||||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
|
model=conversation.llm_model or DEFAULT_MODEL,
|
||||||
|
messages=messages,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
assistant_message = response.get("message")
|
assistant_message = response.get("message")
|
||||||
|
|
||||||
@@ -57,7 +78,16 @@ class Ollama(BaseProvider):
|
|||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=PROVIDER_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
|
@logger
|
||||||
|
def structured_response(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
response_model: Type[T],
|
||||||
|
*,
|
||||||
|
llm_model: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> T:
|
||||||
|
"""Get a structured response from the Ollama API."""
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
@@ -66,17 +96,23 @@ class Ollama(BaseProvider):
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
return response
|
return response_model.model_validate(response)
|
||||||
|
|
||||||
def generate_text(self, prompt, *, llm_model):
|
@logger
|
||||||
|
def generate_text(
|
||||||
|
self, prompt: str, *, llm_model: str | None = None, **kwargs
|
||||||
|
) -> str:
|
||||||
|
"""Generate text using the Ollama API."""
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
|
|
||||||
response = self.client.chat(
|
response = self.client.chat(
|
||||||
messages=messages, model=llm_model or self.DEFAULT_MODEL
|
messages=messages,
|
||||||
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.get("message").get("content")
|
return response.get("message", {}).get("content", "")
|
||||||
|
|||||||
@@ -1,35 +1,52 @@
|
|||||||
from typing import Union
|
from functools import cached_property
|
||||||
|
from typing import TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
import openai as oa
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ._base import BaseProvider
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
|
from ._base import BaseProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
PROVIDER_NAME = "openai"
|
PROVIDER_NAME = "openai"
|
||||||
DEFAULT_MODEL = "gpt-4o-mini"
|
DEFAULT_MODEL = "gpt-4o-mini"
|
||||||
|
DEFAULT_MAX_TOKENS = 1_000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
|
|
||||||
|
|
||||||
class OpenAI(BaseProvider):
|
class OpenAI(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = PROVIDER_NAME
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = DEFAULT_MODEL
|
||||||
|
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||||
|
|
||||||
def __init__(self, api_key: Union[str, None] = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
"""The raw OpenAI client."""
|
"""The raw OpenAI client."""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("OpenAI API key is required")
|
raise ValueError("OpenAI API key is required")
|
||||||
|
try:
|
||||||
|
import openai as oa
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `openai` package: `pip install openai`"
|
||||||
|
) from exc
|
||||||
return oa.OpenAI(api_key=self.api_key)
|
return oa.OpenAI(api_key=self.api_key)
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def structured_client(self):
|
def structured_client(self):
|
||||||
"""A OpenAI client with Instructor."""
|
"""A OpenAI client with Instructor."""
|
||||||
return instructor.from_openai(self.client)
|
return instructor.from_openai(self.client)
|
||||||
|
|
||||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
@logger
|
||||||
|
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||||
"""Send a conversation to the OpenAI API."""
|
"""Send a conversation to the OpenAI API."""
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
|
|
||||||
@@ -38,7 +55,9 @@ class OpenAI(BaseProvider):
|
|||||||
]
|
]
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
|
model=conversation.llm_model or DEFAULT_MODEL,
|
||||||
|
messages=messages,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the response content from the OpenAI response
|
# Get the response content from the OpenAI response
|
||||||
@@ -53,27 +72,37 @@ class OpenAI(BaseProvider):
|
|||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=PROVIDER_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
|
@logger
|
||||||
|
def structured_response(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
response_model: Type[T],
|
||||||
|
*,
|
||||||
|
llm_model: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> T:
|
||||||
|
"""Get a structured response from the OpenAI API."""
|
||||||
# Ensure messages are provided in kwargs
|
# Ensure messages are provided in kwargs
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
|
|
||||||
response = self.structured_client.chat.completions.create(
|
response = self.structured_client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
return response
|
return response_model.model_validate(response)
|
||||||
|
|
||||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
@logger
|
||||||
|
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
|
||||||
|
"""Generate text using the OpenAI API."""
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
|
messages=messages,
|
||||||
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.choices[0].message.content
|
return response.choices[0].message.content
|
||||||
|
|||||||
+33
-12
@@ -1,40 +1,57 @@
|
|||||||
from typing import Union
|
from functools import cached_property
|
||||||
|
from typing import TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
import openai as oa
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ._base import BaseProvider
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
|
from ._base import BaseProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "xai"
|
PROVIDER_NAME = "xai"
|
||||||
DEFAULT_MODEL = "grok-beta"
|
DEFAULT_MODEL = "grok-beta"
|
||||||
BASE_URL = "https://api.x.ai/v1"
|
BASE_URL = "https://api.x.ai/v1"
|
||||||
DEFAULT_MAX_TOKENS = 1000
|
DEFAULT_MAX_TOKENS = 1000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
|
|
||||||
|
|
||||||
class XAI(BaseProvider):
|
class XAI(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = PROVIDER_NAME
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = DEFAULT_MODEL
|
||||||
|
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||||
|
|
||||||
def __init__(self, api_key: Union[str, None] = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
"""The raw OpenAI client."""
|
"""The raw OpenAI client."""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("XAI API key is required")
|
raise ValueError("XAI API key is required")
|
||||||
|
try:
|
||||||
|
import openai as oa
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `openai` package: `pip install openai`"
|
||||||
|
) from exc
|
||||||
return oa.OpenAI(
|
return oa.OpenAI(
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
base_url=BASE_URL,
|
base_url=BASE_URL,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def structured_client(self):
|
def structured_client(self):
|
||||||
"""A client patched with Instructor."""
|
"""A client patched with Instructor."""
|
||||||
return instructor.from_openai(self.client)
|
return instructor.from_openai(self.client)
|
||||||
|
|
||||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
@logger
|
||||||
|
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||||
"""Send a conversation to the OpenAI API."""
|
"""Send a conversation to the OpenAI API."""
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
|
|
||||||
@@ -45,7 +62,7 @@ class XAI(BaseProvider):
|
|||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the response content from the OpenAI response
|
# Get the response content from the OpenAI response
|
||||||
@@ -60,10 +77,14 @@ class XAI(BaseProvider):
|
|||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=PROVIDER_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
def structured_response(self, prompt: str, response_model, *, llm_model):
|
@logger
|
||||||
|
def structured_response(
|
||||||
|
self, prompt: str, response_model: Type[T], *, llm_model: str
|
||||||
|
) -> T:
|
||||||
raise NotImplementedError("XAI does not support structured responses")
|
raise NotImplementedError("XAI does not support structured responses")
|
||||||
|
|
||||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
@logger
|
||||||
|
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
@@ -71,7 +92,7 @@ class XAI(BaseProvider):
|
|||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.choices[0].message.content
|
return str(response.choices[0].message.content)
|
||||||
|
|||||||
+36
-1
@@ -4,6 +4,40 @@ from pydantic import Field, SecretStr, field_validator
|
|||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingConfig(BaseSettings):
|
||||||
|
"""The class that holds all the logging settings for the application."""
|
||||||
|
|
||||||
|
enabled: bool = Field(False, description="Enable logging")
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
def enable_logfire(self, **kwargs) -> None:
|
||||||
|
"""Enable logging for the application."""
|
||||||
|
# adding imports here to avoid forced dependencies
|
||||||
|
try:
|
||||||
|
import logfire
|
||||||
|
from logging import basicConfig
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"To enable logging, please install logfire: `pip install logfire`"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
self.enabled = True
|
||||||
|
logfire.configure(**kwargs)
|
||||||
|
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
|
||||||
|
|
||||||
|
try:
|
||||||
|
logfire.configure(**kwargs)
|
||||||
|
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
|
||||||
|
except Exception as e:
|
||||||
|
self.enabled = False # Reset flag on failure
|
||||||
|
raise RuntimeError("Failed to configure logging") from e
|
||||||
|
|
||||||
|
def disable_logfire(self) -> None:
|
||||||
|
"""Disable logging for the application."""
|
||||||
|
self.enabled = False
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
"""The class that holds all the API keys for the application."""
|
"""The class that holds all the API keys for the application."""
|
||||||
|
|
||||||
@@ -11,17 +45,18 @@ class Settings(BaseSettings):
|
|||||||
None, description="API key for Anthropic"
|
None, description="API key for Anthropic"
|
||||||
)
|
)
|
||||||
GROQ_API_KEY: Optional[SecretStr] = Field(None, description="API key for Groq")
|
GROQ_API_KEY: Optional[SecretStr] = Field(None, description="API key for Groq")
|
||||||
|
GEMINI_API_KEY: Optional[SecretStr] = Field(None, description="API key for Gemini")
|
||||||
OPENAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for OpenAI")
|
OPENAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for OpenAI")
|
||||||
OLLAMA_HOST_URL: Optional[str] = Field(
|
OLLAMA_HOST_URL: Optional[str] = Field(
|
||||||
"http://127.0.0.1:11434", description="Fully qualified host URL for Ollama"
|
"http://127.0.0.1:11434", description="Fully qualified host URL for Ollama"
|
||||||
)
|
)
|
||||||
XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI")
|
XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI")
|
||||||
DEFAULT_LLM_PROVIDER: str = Field("openai", description="The default LLM provider")
|
DEFAULT_LLM_PROVIDER: str = Field("openai", description="The default LLM provider")
|
||||||
DEFAULT_LLM_MODEL: str = Field("gpt-4o-mini", description="The default LLM model")
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
|
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
|
||||||
)
|
)
|
||||||
|
logging: LoggingConfig = LoggingConfig()
|
||||||
|
|
||||||
@field_validator("*", mode="before")
|
@field_validator("*", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
+9
-9
@@ -1,24 +1,26 @@
|
|||||||
import difflib
|
import difflib
|
||||||
from typing import Optional, Type
|
|
||||||
|
|
||||||
from .providers import providers, BaseProvider
|
from .providers import BaseProvider, providers
|
||||||
|
|
||||||
_PROVIDER_NAMES = [provider.NAME.lower() for provider in providers]
|
_PROVIDER_NAMES = [provider.NAME.lower() for provider in providers]
|
||||||
|
|
||||||
|
|
||||||
def find_provider(provider_name: str) -> BaseProvider:
|
def find_provider(provider_name: str | None) -> BaseProvider:
|
||||||
"""
|
"""
|
||||||
Find and instantiate a provider by name.
|
Find and instantiate a provider by name.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
provider_name (Union[str, None]): The name of the provider to find.
|
provider_name (Union[str, None]): The name of the provider to find.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An instance of the provider class if found.
|
An instance of the provider class if found.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the provider is not found, with a suggestion for the closest match.
|
ValueError: If the provider is not specified or is not found, with a suggestion for the closest match.
|
||||||
"""
|
"""
|
||||||
|
if provider_name is None:
|
||||||
|
raise ValueError("No provider specified.")
|
||||||
|
|
||||||
# Find the provider by name.
|
# Find the provider by name.
|
||||||
for provider_class in providers:
|
for provider_class in providers:
|
||||||
if provider_class.NAME.lower() == provider_name.lower():
|
if provider_class.NAME.lower() == provider_name.lower():
|
||||||
@@ -29,10 +31,8 @@ def find_provider(provider_name: str) -> BaseProvider:
|
|||||||
provider_found = difflib.get_close_matches(
|
provider_found = difflib.get_close_matches(
|
||||||
provider_name.lower(), _PROVIDER_NAMES, n=1
|
provider_name.lower(), _PROVIDER_NAMES, n=1
|
||||||
)
|
)
|
||||||
|
|
||||||
if provider_found:
|
if provider_found:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Provider {provider_name!r} not found. Did you mean {provider_found[0]!r}?"
|
f"Provider {provider_name!r} not found. Did you mean {provider_found[0]!r}?"
|
||||||
)
|
)
|
||||||
else:
|
raise ValueError(f"Provider {provider_name} not found.")
|
||||||
raise ValueError(f"Provider {provider_name} not found.")
|
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Add the project root to the Python path.
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from simplemind import Session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sm():
|
||||||
|
"""Fixture that provides a simplemind Session instance with default settings."""
|
||||||
|
return Session()
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
def test_basic_math():
|
||||||
|
assert 1 + 1 == 2
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseModel(BaseModel):
|
||||||
|
result: int
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls",
|
||||||
|
[
|
||||||
|
Anthropic,
|
||||||
|
Gemini,
|
||||||
|
OpenAI,
|
||||||
|
Groq,
|
||||||
|
Ollama,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_generate_data(provider_cls):
|
||||||
|
provider = provider_cls()
|
||||||
|
prompt = "What is 2+2?"
|
||||||
|
|
||||||
|
data = provider.structured_response(prompt=prompt, response_model=ResponseModel)
|
||||||
|
|
||||||
|
assert isinstance(data, ResponseModel)
|
||||||
|
assert isinstance(data.result, int)
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
import pytest
|
||||||
|
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_cls",
|
||||||
|
[
|
||||||
|
Anthropic,
|
||||||
|
Gemini,
|
||||||
|
OpenAI,
|
||||||
|
Groq,
|
||||||
|
Ollama,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_generate_text(provider_cls):
|
||||||
|
provider = provider_cls()
|
||||||
|
prompt = "What is 2+2?"
|
||||||
|
|
||||||
|
response = provider.generate_text(prompt=prompt, llm_model=provider.DEFAULT_MODEL)
|
||||||
|
|
||||||
|
assert isinstance(response, str)
|
||||||
|
assert len(response) > 0
|
||||||
Reference in New Issue
Block a user