mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
Compare commits
128 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 75a42044e5 | |||
| cc66dbf8e5 | |||
| a174e60a1e | |||
| b03695f626 | |||
| 082bc24e91 | |||
| aca1b87180 | |||
| 1ff4c5660e | |||
| 241a7ab402 | |||
| 76fa7521eb | |||
| cbec2c5f6d | |||
| 34f463839c | |||
| c648a922b4 | |||
| 873f5ba5f8 | |||
| 28a7b2f140 | |||
| 173162e798 | |||
| cd0be3ad89 | |||
| 3dd2e1b248 | |||
| ad1800840d | |||
| d62f297b68 | |||
| a2597709d2 | |||
| 1455b5ba13 | |||
| 0fb54d1987 | |||
| fe06331662 | |||
| 56b1e65d70 | |||
| 4b3e1bc6dd | |||
| f5b922ade8 | |||
| 3a7383425f | |||
| 92c10fc41e | |||
| 75c42278a2 | |||
| c25f1e1058 | |||
| 2a5966eb10 | |||
| f19263d309 | |||
| 25b742db1f | |||
| caceba381d | |||
| 0795464fd7 | |||
| 8d83050a64 | |||
| d82effdfb1 | |||
| e648292cb3 | |||
| 37a9333be3 | |||
| cbc3739411 | |||
| 7c8f22bef1 | |||
| 9c3f2a6df3 | |||
| febf5473d5 | |||
| 48ac97f070 | |||
| c41a3f00fb | |||
| 25ee4ae32c | |||
| 984721f02b | |||
| 69c8723770 | |||
| 0c10d5676a | |||
| e0ddf41e15 | |||
| f940ae2dfd | |||
| 85fa4f5879 | |||
| 44581e8fe3 | |||
| 9503ec7fd3 | |||
| 418f36dcc0 | |||
| bf9683cfd0 | |||
| 3909588f3e | |||
| 33d8f18bff | |||
| d7388ef0d5 | |||
| 02d10bfda9 | |||
| 5dc6e7b006 | |||
| 62933c8553 | |||
| f0a6be73f8 | |||
| 9257a04f34 | |||
| 64dbe9a2e7 | |||
| ccb8311089 | |||
| 0c29380501 | |||
| 7b43208a03 | |||
| e931fd0eae | |||
| 736d942527 | |||
| 3505c8758d | |||
| 308886e608 | |||
| 9c18d726d5 | |||
| 8f43b660ea | |||
| 222d3025b1 | |||
| fb6c4c289b | |||
| c28e2a3839 | |||
| 2bed7221b3 | |||
| 1504edad78 | |||
| fd7289c8d3 | |||
| c4674fc98f | |||
| 25806221eb | |||
| 5505a3e18d | |||
| 48291c37c5 | |||
| 4b2b094ea6 | |||
| 33e4046ac3 | |||
| 7fe8e91111 | |||
| 42fc0e6bc5 | |||
| ec4f6f9c06 | |||
| 499d3b3e14 | |||
| dd2f5a46d2 | |||
| bd0c739c9a | |||
| 473a054afa | |||
| 55c28a2356 | |||
| 9bd1653b5e | |||
| 59401c4be4 | |||
| 20ad9437e5 | |||
| 9db95cc87b | |||
| d711afec68 | |||
| 9d7fd4cce5 | |||
| 4aa470bb20 | |||
| 88e118cb53 | |||
| 73316c32a3 | |||
| e1331822aa | |||
| baee6e9959 | |||
| 8096609c2e | |||
| 4225f61df3 | |||
| 034e967ecb | |||
| f9c4cce9a4 | |||
| 78f6649969 | |||
| 4f1e52b1f8 | |||
| 74c09d5c87 | |||
| 1f66bac645 | |||
| f828f9991b | |||
| f4de0049f9 | |||
| 524869668d | |||
| a589850288 | |||
| 4f38b44145 | |||
| 4babdcebd9 | |||
| 8474f101f2 | |||
| e9e47e27a1 | |||
| 2309c30b8f | |||
| d972f1cd85 | |||
| e34f9b106c | |||
| 1405c3bbb0 | |||
| 624c132a59 | |||
| 63a0fea60a | |||
| 7b794930ac |
+5
-3
@@ -1,5 +1,7 @@
|
||||
export OPENAI_API_KEY=""
|
||||
export ANTHROPIC_API_KEY=""
|
||||
export XAI_API_KEY=""
|
||||
export OLLAMA_HOST_URL=""
|
||||
export GEMINI_API_KEY=""
|
||||
export GROQ_API_KEY=""
|
||||
export OLLAMA_HOST_URL=""
|
||||
export OPENAI_API_KEY=""
|
||||
export XAI_API_KEY=""
|
||||
export AMAZON_PROFILE_NAME=""
|
||||
|
||||
@@ -167,3 +167,4 @@ cython_debug/
|
||||
|
||||
src/**
|
||||
requirements.txt
|
||||
Pipfile
|
||||
|
||||
@@ -1,6 +1,36 @@
|
||||
Release History
|
||||
===============
|
||||
|
||||
## 0.2.0 (2024-11-01)
|
||||
|
||||
- Add Amazon Bedrock provider.
|
||||
- Make all provider optional dependencies. Use `$ pip install 'simplemind[full]'` to install all providers.
|
||||
- General improvements.
|
||||
|
||||
## 0.1.7 (2024-11-01)
|
||||
|
||||
- Add `logger` decorator.
|
||||
- Add `sm.enable_logfire()` function.
|
||||
- General improvements.
|
||||
|
||||
## 0.1.6 (2024-10-31)
|
||||
|
||||
- Add `sm.Plugin` syntax sugar.
|
||||
- Improvements to Anthropic provider, related to max tokens.
|
||||
- General improvements.
|
||||
- Add tests for structured response.
|
||||
- Add `llm_model` to `structured_response`.
|
||||
|
||||
## 0.1.5 (2024-10-31)
|
||||
|
||||
- Add Gemini provider.
|
||||
- Add structured response to Gemini provider.
|
||||
- Support for Python 3.10.
|
||||
|
||||
## 0.1.4 (2024-10-30)
|
||||
|
||||
- Introduce `Session` class to manage repeatability.
|
||||
- General improvements.
|
||||
|
||||
## 0.1.3 (2024-10-30)
|
||||
|
||||
|
||||
@@ -1,39 +1,48 @@
|
||||
# SimpleMind: AI for Humans™
|
||||
# Simplemind: AI for Humans™
|
||||
|
||||
[](https://mutable.ai/kennethreitz/simplemind)
|
||||
**Keep it simple, keep it human.**
|
||||
|
||||
SimpleMind is an AI library designed to simplify your experience with AI APIs in Python. Inspired by a "for humans" philosophy, it abstracts away complexity, giving developers an intuitive and human-friendly way to interact with powerful AI capabilities. With SimpleMind, tapping into AI is as easy as a friendly conversation.
|
||||
Simplemind is AI library designed to simplify your experience with AI APIs in Python. Inspired by a "for humans" philosophy, it abstracts away complexity, giving developers an intuitive and human-friendly way to interact with powerful AI capabilities.
|
||||
|
||||
```bash
|
||||
$ pip install simplemind
|
||||
```
|
||||

|
||||
|
||||
## Features
|
||||
|
||||
With Simplemind, tapping into AI is as easy as a friendly conversation.
|
||||
|
||||
- **Easy-to-use AI tools**: SimpleMind provides simple interfaces to popular AI services.
|
||||
- **Human-centered design**: The library prioritizes readability and usability—no need to be an expert to start experimenting.
|
||||
- **Minimal configuration**: Get started quickly, without worrying about configuration headaches.
|
||||
|
||||
## Supported APIs
|
||||
|
||||
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`.
|
||||
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`. The APIs remain identital between all supported providers/models.
|
||||
|
||||
- **[OpenAI's GPT](https://openai.com/gpt)**
|
||||
- **[Anthropic's Claude](https://www.anthropic.com/claude)**
|
||||
- **[xAI's Grok](https://x.ai/)**
|
||||
- **[Groq's Groq](https://groq.com/)**
|
||||
- **[Ollama](https://ollama.com)**
|
||||
- [**Anthropic's Claude**](https://www.anthropic.com/claude)
|
||||
- [**Amazon Bedrock**](https://aws.amazon.com/bedrock/)
|
||||
- [**Google's Gemini**](https://gemini.google/)
|
||||
- [**Groq's Groq**](https://groq.com/)
|
||||
- [**Ollama**](https://ollama.com)
|
||||
- [**OpenAI's GPT**](https://openai.com/gpt)
|
||||
- [**xAI's Grok**](https://x.ai/)
|
||||
|
||||
If you'd like to see SimpleMind support additional providers or models, please send a pull request!
|
||||
If you want to see Simplemind support, additional providers or models, please send a pull request!
|
||||
|
||||
## Why SimpleMind?
|
||||
|
||||
- **Intuitive**: Built with Pythonic simplicity and readability in mind.
|
||||
- **For Humans**: Emphasizes a human-friendly interface, just like `requests` for HTTP.
|
||||
- **Open Source**: SimpleMind is open source, and contributions are always welcome!
|
||||
- **Open Source**: Simplemind is open source, and contributions are always welcome!
|
||||
|
||||
Also, why not? :)
|
||||
|
||||
## Quickstart
|
||||
|
||||
SimpleMind takes care of the complex API calls so you can focus on what matters—building, experimenting, and creating.
|
||||
Simplemind takes care of the complex API calls so you can focus on what matters—building, experimenting, and creating.
|
||||
|
||||
```bash
|
||||
$ pip install 'simplemind[full]'
|
||||
```
|
||||
|
||||
First, authenticate your API keys by setting them in the environment variables:
|
||||
|
||||
@@ -41,18 +50,17 @@ First, authenticate your API keys by setting them in the environment variables:
|
||||
$ export OPENAI_API_KEY="sk-..."
|
||||
```
|
||||
|
||||
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, and `GROQ_API_KEY`.
|
||||
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
|
||||
|
||||
Next, import SimpleMind and start using it:
|
||||
Next, import Simplemind and start using it:
|
||||
|
||||
```python
|
||||
import simplemind as sm
|
||||
```
|
||||
|
||||
|
||||
## Examples
|
||||
|
||||
Here are some examples of how to use SimpleMind:
|
||||
Here are some examples of how to use Simplemind:
|
||||
|
||||
### Text Completion
|
||||
|
||||
@@ -83,6 +91,33 @@ class Poem(BaseModel):
|
||||
title='Eternal Embrace' content='In the quiet hours of the night,\nWhen stars whisper secrets bright,\nTwo hearts beat in a gentle rhyme,\nDancing through the sands of time.\n\nWith every glance, a spark ignites,\nA flame that warms the coldest nights,\nIn laughter shared and whispers sweet,\nLove paints the world, a masterpiece.\n\nThrough stormy skies and sunlit days,\nIn myriad forms, it finds its ways,\nA tender touch, a knowing sigh,\nIn love’s embrace, we learn to fly.\n\nAs seasons change and moments fade,\nIn the tapestry of dreams we’ve laid,\nLove’s threads endure, forever bind,\nA timeless bond, two souls aligned.\n\nSo here’s to love, both bright and true,\nA gift we give, anew, anew,\nIn every heartbeat, every prayer,\nA story written in the air.'
|
||||
```
|
||||
|
||||
#### A more complex example
|
||||
|
||||
```python
|
||||
class InstructionStep(BaseModel):
|
||||
step_number: int
|
||||
instruction: str
|
||||
|
||||
class RecipeIngredient(BaseModel):
|
||||
name: str
|
||||
quantity: float
|
||||
unit: str
|
||||
|
||||
class Recipe(BaseModel):
|
||||
name: str
|
||||
ingredients: list[RecipeIngredient]
|
||||
instructions: list[InstructionStep]
|
||||
|
||||
recipe = sm.generate_data(
|
||||
"Write a recipe for chocolate chip cookies",
|
||||
llm_model="gpt-4o-mini",
|
||||
llm_provider="openai",
|
||||
response_model=Recipe,
|
||||
)
|
||||
```
|
||||
|
||||
Special thanks to [@jxnl](https://github.com/jxnl) for building [Instructor](https://github.com/jxnl/instructor), which makes this possible!
|
||||
|
||||
### Conversational AI
|
||||
|
||||
SimpleMind also allows for easy conversational flows:
|
||||
@@ -105,15 +140,34 @@ To continue the conversation, you can call `conversation.send()` again, which re
|
||||
<Message role=assistant text="The meaning of life is a profound philosophical question that has been explored by cultures, religions, and philosophers for centuries. Different people and belief systems offer varying interpretations:\n\n1. **Religious Perspectives:** Many religions propose that the meaning of life is to fulfill a divine purpose, serve God, or reach an afterlife. For example, Christianity often emphasizes love, faith, and service to God and others as central to life’s meaning.\n\n2. **Philosophical Views:** Philosophers offer diverse answers. Existentialists like Jean-Paul Sartre argue that life has no inherent meaning, and it is up to individuals to create their own purpose. Others, like Aristotle, suggest that achieving eudaimonia (flourishing or happiness) through virtuous living is the key to a meaningful life.\n\n3. **Scientific and Secular Approaches:** Some people find meaning through understanding the natural world, contributing to human knowledge, or through personal accomplishments and happiness. They may view life’s meaning as a product of connection, legacy, or the pursuit of knowledge and creativity.\n\n4. **Personal Perspective:** For many, the meaning of life is deeply personal, involving their relationships, passions, and goals. These individuals define life’s purpose through experiences, connections, and the impact they have on others and the world.\n\nUltimately, the meaning of life is a subjective question, with each person finding their own answers based on their beliefs, experiences, and reflections.">
|
||||
```
|
||||
|
||||
### Stop Repeating Yourself
|
||||
|
||||
You can use the `Session` class to set default parameters for all calls:
|
||||
|
||||
```python
|
||||
# Create a session with defaults
|
||||
gpt_4o_mini = sm.Session(llm_provider="openai", llm_model="gpt-4o-mini")
|
||||
|
||||
# Now all calls use these defaults
|
||||
response = gpt_4o_mini.generate_text("Hello!")
|
||||
conversation = gpt_4o_mini.create_conversation()
|
||||
```
|
||||
|
||||
This maintains the simplicity of the original API while reducing repetition. The session object also supports overriding defaults on a per-call basis:
|
||||
|
||||
```python
|
||||
response = gpt_4o_mini.generate_text(
|
||||
"Complex task here",
|
||||
llm_model="gpt-4"
|
||||
)
|
||||
```
|
||||
|
||||
### Basic Memory Plugin
|
||||
|
||||
Harnessing the power of Python, you can easily create your own plugins to add additional functionality to your conversations:
|
||||
|
||||
```python
|
||||
import simplemind as sm
|
||||
|
||||
|
||||
class SimpleMemoryPlugin:
|
||||
class SimpleMemoryPlugin(sm.BasePlugin):
|
||||
def __init__(self):
|
||||
self.memories = [
|
||||
"the earth has fictionally beeen destroyed.",
|
||||
@@ -137,6 +191,7 @@ conversation.add_message(
|
||||
text="Please write a poem about the moon",
|
||||
)
|
||||
```
|
||||
|
||||
```pycon
|
||||
>>> conversation.send()
|
||||
In the vast expanse where stars do play,
|
||||
@@ -170,9 +225,20 @@ A reminder that in tales and fun,
|
||||
The universe is never done.
|
||||
```
|
||||
|
||||
Simple, yet effective.
|
||||
|
||||
### Logging
|
||||
|
||||
Simplemind uses [logfire](https://logfire.ai) for logging. To enable logging, call `sm.enable_logfire()`.
|
||||
|
||||
### More Examples
|
||||
|
||||
Please see the [examples](examples) directory for executable examples.
|
||||
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions of all kinds. Feel free to open issues for bug reports or feature requests, and submit pull requests to make SimpleMind even better.
|
||||
|
||||
To get started:
|
||||
@@ -183,11 +249,9 @@ To get started:
|
||||
4. Submit a pull request.
|
||||
|
||||
## License
|
||||
SimpleMind is licensed under the Apache 2.0 License.
|
||||
|
||||
Simplemind is licensed under the Apache 2.0 License.
|
||||
|
||||
## Acknowledgements
|
||||
SimpleMind is inspired by the philosophy of "code for humans" and aims to make working with AI models accessible to all. Special thanks to the open-source community for their contributions and inspiration.
|
||||
|
||||
---------------
|
||||
|
||||
SimpleMind: Keep it simple, keep it human.
|
||||
Simplemind is inspired by the philosophy of "code for humans" and aims to make working with AI models accessible to all. Special thanks to the open-source community for their contributions and inspiration.
|
||||
|
||||
+2
-2
@@ -14,9 +14,9 @@ sys.path.insert(0, os.path.abspath(".."))
|
||||
import simplemind
|
||||
|
||||
project = "simplemind"
|
||||
copyright = "2024, Kenneth Reitz"
|
||||
copyright = "2024 Kenneth Reitz"
|
||||
author = "Kenneth Reitz"
|
||||
release = "v0.1.3"
|
||||
release = "v0.2.0"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from _context import sm
|
||||
|
||||
from pydantic import BaseModel
|
||||
import openai
|
||||
import faiss
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import openai
|
||||
from _context import sm
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ContextualMemoryPlugin:
|
||||
|
||||
class ContextualMemoryPlugin(sm.BasePlugin):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from typing import List, Iterator
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Iterator, List
|
||||
|
||||
from _context import sm
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Movie(BaseModel):
|
||||
@@ -25,7 +24,7 @@ class QuotesList(BaseModel):
|
||||
quotes: List[MovieQuote]
|
||||
|
||||
|
||||
def gen_quotes(n=10) -> Iterator[MovieQuote]:
|
||||
def gen_quotes(n: int = 10) -> Iterator[MovieQuote]:
|
||||
"""Generate a list of quotes from famous movies."""
|
||||
|
||||
for q in sm.generate_data(
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from _context import sm
|
||||
|
||||
|
||||
class MathPlugin:
|
||||
class MathPlugin(sm.BasePlugin):
|
||||
def send_hook(self, conversation: sm.Conversation):
|
||||
last_user_message = conversation.get_last_message(role="user")
|
||||
if last_user_message is None:
|
||||
return
|
||||
if "calculate" in last_user_message.text.lower():
|
||||
expression = last_user_message.text.lower().replace("calculate", "").strip()
|
||||
try:
|
||||
@@ -14,7 +16,7 @@ class MathPlugin:
|
||||
except Exception:
|
||||
conversation.add_message(
|
||||
role="assistant",
|
||||
text="I'm sorry, I couldn't compute that expression.",
|
||||
text="I'm sorry, I couldn't compute that expression. Please try again.",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from _context import sm
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal
|
||||
|
||||
from _context import sm
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SentimentAnalysis(BaseModel):
|
||||
sentiment: Literal["positive", "negative", "neutral"]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from _context import sm
|
||||
|
||||
|
||||
class SimpleMemoryPlugin(sm.BasePlugin):
|
||||
class SimpleMemoryPlugin:
|
||||
def __init__(self):
|
||||
self.memories = [
|
||||
"the earth has fictionally beeen destroyed.",
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
import time
|
||||
|
||||
import simplemind as sm
|
||||
|
||||
|
||||
class ConversationPlugin(sm.BasePlugin):
|
||||
def post_send_hook(self, conversation, response):
|
||||
# Print the LLM model and the response text.
|
||||
print(f"{conversation.llm_model}:\n{response.text.strip()}\n\n------------\n")
|
||||
|
||||
|
||||
def have_conversation(rounds: int = 3):
|
||||
# Create two conversations - one for each AI
|
||||
with (
|
||||
sm.create_conversation(
|
||||
llm_model="claude-3-5-sonnet-20241022", llm_provider="anthropic"
|
||||
) as claude_conv,
|
||||
sm.create_conversation(
|
||||
llm_model="llama3.2", llm_provider="ollama"
|
||||
) as llama_conv,
|
||||
):
|
||||
|
||||
# Add our plugin to both
|
||||
plugin = ConversationPlugin()
|
||||
claude_conv.add_plugin(plugin)
|
||||
llama_conv.add_plugin(plugin)
|
||||
|
||||
# Start the conversation
|
||||
prompt = "What do you think about the future of artificial intelligence? Please keep your response brief."
|
||||
claude_conv.add_message("user", prompt, meta={})
|
||||
claude_response = claude_conv.send()
|
||||
|
||||
# Have them discuss back and forth
|
||||
for _ in range(rounds):
|
||||
# Llama responds to Claude
|
||||
llama_conv.add_message(
|
||||
"user",
|
||||
f"Respond to this statement from another AI: {claude_response.text}",
|
||||
meta={},
|
||||
)
|
||||
llama_response = llama_conv.send()
|
||||
|
||||
time.sleep(1) # Add a small delay between responses
|
||||
|
||||
# Claude responds to Llama
|
||||
claude_conv.add_message(
|
||||
"user",
|
||||
f"Respond to this statement from another AI: {llama_response.text}",
|
||||
meta={},
|
||||
)
|
||||
claude_response = claude_conv.send()
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting AI conversation...\n")
|
||||
have_conversation()
|
||||
print("\nConversation ended.")
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 944 KiB |
+14
-3
@@ -1,10 +1,21 @@
|
||||
[project]
|
||||
name = "simplemind"
|
||||
version = "0.1.3"
|
||||
version = "0.2.0"
|
||||
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "ollama", "groq"]
|
||||
requires-python = ">=3.10"
|
||||
dependencies = ["pydantic", "pydantic-settings", "instructor", "logfire"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
full = [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"ollama",
|
||||
"groq",
|
||||
"google-generativeai",
|
||||
"botocore",
|
||||
"boto3"
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
|
||||
+86
-8
@@ -1,18 +1,72 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Type
|
||||
|
||||
from .models import Conversation, BasePlugin
|
||||
from .utils import find_provider
|
||||
from .models import BaseModel, BasePlugin, Conversation
|
||||
from .settings import settings
|
||||
from .utils import find_provider
|
||||
|
||||
|
||||
class Session:
|
||||
"""A session object that maintains configuration across multiple API calls.
|
||||
|
||||
Similar to `requests.Session`, this allows you to specify default settings
|
||||
that will be used for all operations within the session.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
llm_provider: str = settings.DEFAULT_LLM_PROVIDER,
|
||||
llm_model: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.llm_provider = llm_provider
|
||||
self.llm_model = llm_model
|
||||
self.default_kwargs = kwargs
|
||||
|
||||
def generate_text(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate text using the session's default provider and model."""
|
||||
merged_kwargs = {**self.default_kwargs, **kwargs}
|
||||
return generate_text(
|
||||
prompt=prompt,
|
||||
llm_provider=self.llm_provider,
|
||||
llm_model=self.llm_model,
|
||||
**merged_kwargs,
|
||||
)
|
||||
|
||||
def generate_data(
|
||||
self, prompt: str, response_model: Type[BaseModel], **kwargs
|
||||
) -> BaseModel:
|
||||
"""Generate structured data using the session's default provider and model."""
|
||||
merged_kwargs = {**self.default_kwargs, **kwargs}
|
||||
return generate_data(
|
||||
prompt=prompt,
|
||||
response_model=response_model,
|
||||
llm_provider=self.llm_provider,
|
||||
llm_model=self.llm_model,
|
||||
**merged_kwargs,
|
||||
)
|
||||
|
||||
def create_conversation(self, **kwargs) -> Conversation:
|
||||
"""Create a conversation using the session's default provider and model."""
|
||||
merged_kwargs = {**self.default_kwargs, **kwargs}
|
||||
return create_conversation(
|
||||
llm_provider=self.llm_provider, llm_model=self.llm_model, **merged_kwargs
|
||||
)
|
||||
|
||||
|
||||
def create_conversation(
|
||||
llm_model=None, llm_provider=None, *, plugins: Optional[List[BasePlugin]] = None
|
||||
):
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
llm_provider: str | None = None,
|
||||
plugins: List[BasePlugin] | None = None,
|
||||
**kwargs,
|
||||
) -> Conversation:
|
||||
"""Create a new conversation."""
|
||||
|
||||
# Create the conversation.
|
||||
conversation = Conversation(
|
||||
llm_model=llm_model, llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER
|
||||
llm_model=llm_model,
|
||||
llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER,
|
||||
)
|
||||
|
||||
# Add plugins to the conversation.
|
||||
@@ -22,7 +76,14 @@ def create_conversation(
|
||||
return conversation
|
||||
|
||||
|
||||
def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=None):
|
||||
def generate_data(
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
llm_provider: str | None = None,
|
||||
response_model: Type[BaseModel],
|
||||
**kwargs,
|
||||
) -> BaseModel:
|
||||
"""Generate structured data from a given prompt."""
|
||||
|
||||
# Find the provider.
|
||||
@@ -36,7 +97,13 @@ def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=N
|
||||
)
|
||||
|
||||
|
||||
def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs):
|
||||
def generate_text(
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
llm_provider: str | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Generate text from a given prompt."""
|
||||
|
||||
# Find the provider.
|
||||
@@ -46,6 +113,14 @@ def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs):
|
||||
return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs)
|
||||
|
||||
|
||||
def enable_logfire() -> None:
|
||||
"""Enable logfire logging."""
|
||||
settings.logging.enable_logfire()
|
||||
|
||||
|
||||
# Syntax sugar.
|
||||
Plugin = BasePlugin
|
||||
|
||||
__all__ = [
|
||||
"create_conversation",
|
||||
"find_provider",
|
||||
@@ -53,4 +128,7 @@ __all__ = [
|
||||
"generate_text",
|
||||
"settings",
|
||||
"BasePlugin",
|
||||
"Session",
|
||||
"Plugin",
|
||||
"enable_logfire",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import time
|
||||
from typing import Any, Callable
|
||||
|
||||
import logfire
|
||||
|
||||
from .settings import settings
|
||||
|
||||
|
||||
def logger(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""A decorator that logs the function parameters, function returns,
|
||||
and exceptions raised if logging is enabled, using logfire.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
if not settings.logging.is_enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
logfire.info(f"Calling {func.__name__} with args: {args}, kwargs: {kwargs}")
|
||||
t1 = time.perf_counter()
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
t2 = time.perf_counter()
|
||||
logfire.info(f"{func.__name__} returned: {result} in {t2-t1} seconds")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
t2 = time.perf_counter()
|
||||
logfire.error(f"Error in {func.__name__}: {e} in {t2-t1} seconds")
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
+37
-22
@@ -1,18 +1,18 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .utils import find_provider
|
||||
|
||||
|
||||
MESSAGE_ROLE = Literal["system", "user", "assistant"]
|
||||
|
||||
|
||||
class SMBaseModel(BaseModel):
|
||||
"""The base SimpleMind model class."""
|
||||
|
||||
date_created: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
def __str__(self):
|
||||
@@ -22,39 +22,36 @@ class SMBaseModel(BaseModel):
|
||||
return str(self)
|
||||
|
||||
|
||||
class BasePlugin(ABC):
|
||||
class BasePlugin(SMBaseModel):
|
||||
"""The base conversation plugin class."""
|
||||
|
||||
# Plugin metadata.
|
||||
meta: Dict[str, Any] = {}
|
||||
|
||||
# @abstractmethod
|
||||
def initialize_hook(self, conversation: "Conversation"):
|
||||
def initialize_hook(self, conversation: "Conversation") -> Any:
|
||||
"""Initialize a hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
# @abstractmethod
|
||||
def cleanup_hook(self, conversation: "Conversation"):
|
||||
def cleanup_hook(self, conversation: "Conversation") -> Any:
|
||||
"""Cleanup a hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
# @abstractmethod
|
||||
def add_message_hook(self, conversation: "Conversation", message: "Message"):
|
||||
def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any:
|
||||
"""Add a message hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
# @abstractmethod
|
||||
def pre_send_hook(self, conversation: "Conversation"):
|
||||
def pre_send_hook(self, conversation: "Conversation") -> Any:
|
||||
"""Pre-send hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
# @abstractmethod
|
||||
def post_send_hook(self, conversation: "Conversation", response: "Message"):
|
||||
def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any:
|
||||
"""Post-send hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Message(SMBaseModel):
|
||||
"""A message in a conversation."""
|
||||
|
||||
role: MESSAGE_ROLE
|
||||
text: str
|
||||
meta: Dict[str, Any] = {}
|
||||
@@ -66,7 +63,16 @@ class Message(SMBaseModel):
|
||||
return f"<Message role={self.role} text={self.text!r}>"
|
||||
|
||||
@classmethod
|
||||
def from_raw_response(cls, *, text: str, raw):
|
||||
def from_raw_response(cls, *, text: str, raw: Any) -> "Message":
|
||||
"""Create a Message instance from a raw response.
|
||||
|
||||
Args:
|
||||
text (str): The message text.
|
||||
raw (Any): The raw response data.
|
||||
|
||||
Returns:
|
||||
Message: A new Message instance.
|
||||
"""
|
||||
self = cls()
|
||||
self.text = text
|
||||
self.raw = raw
|
||||
@@ -74,11 +80,13 @@ class Message(SMBaseModel):
|
||||
|
||||
|
||||
class Conversation(SMBaseModel):
|
||||
"""A conversation between a user and an assistant."""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
messages: List[Message] = []
|
||||
llm_model: Optional[str] = None
|
||||
llm_provider: Optional[str] = None
|
||||
plugins: List[Any] = []
|
||||
plugins: List[BasePlugin] = []
|
||||
|
||||
def __str__(self):
|
||||
return f"<Conversation id={self.id!r}>"
|
||||
@@ -94,8 +102,13 @@ class Conversation(SMBaseModel):
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
# Execute all cleanup hooks.
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException],
|
||||
exc_value: BaseException,
|
||||
traceback: TracebackType,
|
||||
) -> None:
|
||||
"""Execute all cleanup hooks."""
|
||||
for plugin in self.plugins:
|
||||
if hasattr(plugin, "cleanup_hook"):
|
||||
try:
|
||||
@@ -104,7 +117,7 @@ class Conversation(SMBaseModel):
|
||||
pass
|
||||
|
||||
def prepend_system_message(
|
||||
self, role: str, text: str, meta: Optional[Dict[str, Any]] = None
|
||||
self, role: MESSAGE_ROLE, text: str, meta: Dict[str, Any] | None = None
|
||||
):
|
||||
"""Prepend a system message to the conversation."""
|
||||
self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages
|
||||
@@ -132,7 +145,9 @@ class Conversation(SMBaseModel):
|
||||
self.messages.append(Message(role=role, text=text, meta=meta))
|
||||
|
||||
def send(
|
||||
self, llm_model: Optional[str] = None, llm_provider: Optional[str] = None
|
||||
self,
|
||||
llm_model: str | None = None,
|
||||
llm_provider: str | None = None,
|
||||
) -> Message:
|
||||
"""Send the conversation to the LLM."""
|
||||
|
||||
@@ -161,10 +176,10 @@ class Conversation(SMBaseModel):
|
||||
|
||||
return response
|
||||
|
||||
def get_last_message(self, role: MESSAGE_ROLE) -> Optional[Message]:
|
||||
def get_last_message(self, role: MESSAGE_ROLE) -> Message | None:
|
||||
"""Get the last message with the given role."""
|
||||
return next((m for m in reversed(self.messages) if m.role == role), None)
|
||||
|
||||
def add_plugin(self, plugin: Any):
|
||||
def add_plugin(self, plugin: BasePlugin) -> None:
|
||||
"""Add a plugin to the conversation."""
|
||||
self.plugins.append(plugin)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import List, Type
|
||||
|
||||
from simplemind.providers._base import BaseProvider
|
||||
from simplemind.providers.anthropic import Anthropic
|
||||
from simplemind.providers.groq import Groq
|
||||
from simplemind.providers.openai import OpenAI
|
||||
from simplemind.providers.ollama import Ollama
|
||||
from simplemind.providers.xai import XAI
|
||||
from ._base import BaseProvider
|
||||
from .anthropic import Anthropic
|
||||
from .gemini import Gemini
|
||||
from .groq import Groq
|
||||
from .ollama import Ollama
|
||||
from .openai import OpenAI
|
||||
from .xai import XAI
|
||||
from .amazon import Amazon
|
||||
|
||||
providers: List[Type[BaseProvider]] = [Anthropic, Groq, OpenAI, Ollama, XAI]
|
||||
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI, Amazon]
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Type, TypeVar
|
||||
|
||||
from instructor import Instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
@@ -9,13 +17,13 @@ class BaseProvider(ABC):
|
||||
NAME: str
|
||||
DEFAULT_MODEL: str
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
@abstractmethod
|
||||
def client(self):
|
||||
def client(self) -> Any:
|
||||
"""The instructor client for the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
@abstractmethod
|
||||
def structured_client(self) -> Instructor:
|
||||
"""The structured client for the provider."""
|
||||
@@ -27,7 +35,7 @@ class BaseProvider(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def structured_response(self, prompt: str, response_model, **kwargs):
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
"""Get a structured response."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import anthropic
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
PROVIDER_NAME = "amazon"
|
||||
DEFAULT_MODEL = "anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
DEFAULT_MAX_TOKENS = 5_000
|
||||
|
||||
|
||||
class Amazon(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, profile_name: str | None = None):
|
||||
self.profile_name = profile_name or settings.AMAZON_PROFILE_NAME
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""The AnthropicBedrock client."""
|
||||
if not self.profile_name:
|
||||
raise ValueError("Profile name is not provided")
|
||||
|
||||
return anthropic.AnthropicBedrock(aws_profile=self.profile_name)
|
||||
|
||||
@property
|
||||
def structured_client(self):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_anthropic(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message.content or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(
|
||||
self, prompt, response_model: Type[T], *, llm_model: str | None = None, **kwargs
|
||||
) -> T:
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
response_model=response_model,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.messages.create(
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return response.content[0].text
|
||||
@@ -1,36 +1,54 @@
|
||||
from typing import Union
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import anthropic
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "anthropic"
|
||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
DEFAULT_MAX_TOKENS = 1000
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class Anthropic(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def client(self):
|
||||
"""The raw Anthropic client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Anthropic API key is required")
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `anthropic` package: `pip install anthropic`"
|
||||
) from exc
|
||||
|
||||
return anthropic.Anthropic(api_key=self.api_key)
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def structured_client(self):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_anthropic(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the Anthropic API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -41,8 +59,7 @@ class Anthropic(BaseProvider):
|
||||
response = self.client.messages.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Get the response content from the Anthropic response
|
||||
@@ -57,13 +74,28 @@ class Anthropic(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, model, response_model, **kwargs):
|
||||
response = self.structured_client.messages.create(
|
||||
model=model, response_model=response_model or self.DEFAULT_MODEL, **kwargs
|
||||
)
|
||||
return response
|
||||
@logger
|
||||
def structured_response(
|
||||
self, response_model: Type[T], *, llm_model: str | None = None, **kwargs
|
||||
) -> T:
|
||||
model = llm_model or self.DEFAULT_MODEL
|
||||
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
# Extract the prompt from kwargs if it exists
|
||||
prompt = kwargs.pop("prompt", kwargs.pop("messages", ""))
|
||||
|
||||
# Format the messages properly
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
response = self.structured_client.messages.create(
|
||||
model=model,
|
||||
messages=messages, # Add the messages parameter
|
||||
response_model=response_model,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -71,8 +103,7 @@ class Anthropic(BaseProvider):
|
||||
response = self.client.messages.create(
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.content[0].text
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
# TODO: this is a placeholder file for the Gemini provider
|
||||
# IT is not currently working as desired.
|
||||
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "gemini"
|
||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
|
||||
|
||||
class Gemini(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.model_name = DEFAULT_MODEL
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
self.model_name = model_name
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
"""The raw Gemini client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Gemini API key is required")
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `google-generativeai` package: `pip install google-generativeai`"
|
||||
) from exc
|
||||
genai.configure(api_key=self.api_key)
|
||||
return genai.GenerativeModel(model_name=self.model_name)
|
||||
|
||||
@cached_property
|
||||
def structured_client(self):
|
||||
"""A Gemini client patched with Instructor."""
|
||||
return instructor.from_gemini(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
||||
"""Send a conversation to the Gemini API."""
|
||||
from ..models import Message
|
||||
|
||||
# Convert messages to Gemini's format
|
||||
chat = self.client.start_chat()
|
||||
|
||||
# Send all previous messages to establish context
|
||||
for msg in conversation.messages[:-1]: # All messages except the last one
|
||||
chat.send_message(msg.text)
|
||||
|
||||
# Send the final message and get response
|
||||
try:
|
||||
response = chat.send_message(conversation.messages[-1].text)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to send conversation to Gemini API: {e}") from e
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=response.text,
|
||||
raw=response,
|
||||
llm_model=self.model_name,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
"""Send a structured response to the Gemini API."""
|
||||
# Only try to pop if the key exists
|
||||
kwargs.pop("llm_model", None) # Add default value of None
|
||||
|
||||
try:
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle the exception appropriately, e.g., log the error or raise a custom exception
|
||||
raise RuntimeError(
|
||||
f"Failed to send structured response to Gemini API: {e}"
|
||||
) from e
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate text using the Gemini API."""
|
||||
kwargs.pop("llm_model")
|
||||
try:
|
||||
response = self.client.generate_content(prompt, **kwargs)
|
||||
except Exception as e:
|
||||
# Handle the exception appropriately, e.g., log the error or raise a custom exception
|
||||
raise RuntimeError(f"Failed to generate text with Gemini API: {e}") from e
|
||||
return response.text
|
||||
@@ -1,34 +1,52 @@
|
||||
from typing import Union
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import groq
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class Groq(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def client(self):
|
||||
"""The raw Groq client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Groq API key is required")
|
||||
try:
|
||||
import groq
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `groq` package: `pip install groq`"
|
||||
) from exc
|
||||
return groq.Groq(api_key=self.api_key)
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def structured_client(self):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_groq(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
@@ -44,7 +62,7 @@ class Groq(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Get the response content from the Groq response
|
||||
@@ -59,7 +77,8 @@ class Groq(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt: str, response_model, **kwargs):
|
||||
@logger
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
@@ -68,17 +87,19 @@ class Groq(BaseProvider):
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str,
|
||||
**kwargs,
|
||||
):
|
||||
) -> str:
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -86,7 +107,7 @@ class Groq(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
return str(response.choices[0].message.content)
|
||||
|
||||
@@ -1,32 +1,50 @@
|
||||
import ollama as ol
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "ollama"
|
||||
DEFAULT_MODEL = "llama3.2"
|
||||
DEFAULT_TIMEOUT = 60
|
||||
DEFAULT_KWARGS = {}
|
||||
|
||||
|
||||
class Ollama(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
TIMEOUT = DEFAULT_TIMEOUT
|
||||
|
||||
def __init__(self, host_url: str = None):
|
||||
def __init__(self, host_url: str | None = None):
|
||||
self.host_url = host_url or settings.OLLAMA_HOST_URL
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def client(self):
|
||||
"""The raw Ollama client."""
|
||||
if not self.host_url:
|
||||
raise ValueError("No ollama host url provided")
|
||||
try:
|
||||
import ollama as ol
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `ollama` package: `pip install ollama`"
|
||||
) from exc
|
||||
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
|
||||
|
||||
@property
|
||||
def structured_client(self):
|
||||
@cached_property
|
||||
def structured_client(self) -> instructor.Instructor:
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_openai(
|
||||
OpenAI(
|
||||
@@ -36,7 +54,8 @@ class Ollama(BaseProvider):
|
||||
mode=instructor.Mode.JSON,
|
||||
)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation"):
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the Ollama API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -44,7 +63,9 @@ class Ollama(BaseProvider):
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
response = self.client.chat(
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
assistant_message = response.get("message")
|
||||
|
||||
@@ -57,7 +78,16 @@ class Ollama(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
|
||||
@logger
|
||||
def structured_response(
|
||||
self,
|
||||
prompt: str,
|
||||
response_model: Type[T],
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Get a structured response from the Ollama API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -66,17 +96,23 @@ class Ollama(BaseProvider):
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
def generate_text(self, prompt, *, llm_model):
|
||||
@logger
|
||||
def generate_text(
|
||||
self, prompt: str, *, llm_model: str | None = None, **kwargs
|
||||
) -> str:
|
||||
"""Generate text using the Ollama API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat(
|
||||
messages=messages, model=llm_model or self.DEFAULT_MODEL
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.get("message").get("content")
|
||||
return response.get("message", {}).get("content", "")
|
||||
|
||||
@@ -1,35 +1,52 @@
|
||||
from typing import Union
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
PROVIDER_NAME = "openai"
|
||||
DEFAULT_MODEL = "gpt-4o-mini"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class OpenAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def client(self):
|
||||
"""The raw OpenAI client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key is required")
|
||||
try:
|
||||
import openai as oa
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `openai` package: `pip install openai`"
|
||||
) from exc
|
||||
return oa.OpenAI(api_key=self.api_key)
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def structured_client(self):
|
||||
"""A OpenAI client with Instructor."""
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -38,7 +55,9 @@ class OpenAI(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
@@ -53,27 +72,37 @@ class OpenAI(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
|
||||
@logger
|
||||
def structured_response(
|
||||
self,
|
||||
prompt: str,
|
||||
response_model: Type[T],
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Get a structured response from the OpenAI API."""
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
|
||||
"""Generate text using the OpenAI API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
+33
-12
@@ -1,40 +1,57 @@
|
||||
from typing import Union
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "xai"
|
||||
DEFAULT_MODEL = "grok-beta"
|
||||
BASE_URL = "https://api.x.ai/v1"
|
||||
DEFAULT_MAX_TOKENS = 1000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class XAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
|
||||
def __init__(self, api_key: Union[str, None] = None):
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def client(self):
|
||||
"""The raw OpenAI client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("XAI API key is required")
|
||||
try:
|
||||
import openai as oa
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `openai` package: `pip install openai`"
|
||||
) from exc
|
||||
return oa.OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=BASE_URL,
|
||||
)
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def structured_client(self):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -45,7 +62,7 @@ class XAI(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
@@ -60,10 +77,14 @@ class XAI(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt: str, response_model, *, llm_model):
|
||||
@logger
|
||||
def structured_response(
|
||||
self, prompt: str, response_model: Type[T], *, llm_model: str
|
||||
) -> T:
|
||||
raise NotImplementedError("XAI does not support structured responses")
|
||||
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -71,7 +92,7 @@ class XAI(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
return str(response.choices[0].message.content)
|
||||
|
||||
@@ -4,13 +4,49 @@ from pydantic import Field, SecretStr, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class LoggingConfig(BaseSettings):
|
||||
"""The class that holds all the logging settings for the application."""
|
||||
|
||||
is_enabled: bool = Field(False, description="Enable logging")
|
||||
|
||||
model_config = SettingsConfigDict(extra="forbid")
|
||||
|
||||
def enable_logfire(self, **kwargs) -> None:
|
||||
"""Enable logging for the application."""
|
||||
# adding imports here to avoid forced dependencies
|
||||
try:
|
||||
import logfire
|
||||
from logging import basicConfig
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"To enable logging, please install logfire: `pip install logfire`"
|
||||
) from e
|
||||
|
||||
self.is_enabled = True
|
||||
logfire.configure(**kwargs)
|
||||
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
|
||||
|
||||
try:
|
||||
logfire.configure(**kwargs)
|
||||
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
|
||||
except Exception as e:
|
||||
self.is_enabled = False # Reset flag on failure
|
||||
raise RuntimeError("Failed to configure logging") from e
|
||||
|
||||
def disable_logfire(self) -> None:
|
||||
"""Disable logging for the application."""
|
||||
self.is_enabled = False
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""The class that holds all the API keys for the application."""
|
||||
|
||||
AMAZON_PROFILE_NAME: Optional[str] = Field("default", description="AWS Named Profile")
|
||||
ANTHROPIC_API_KEY: Optional[SecretStr] = Field(
|
||||
None, description="API key for Anthropic"
|
||||
)
|
||||
GROQ_API_KEY: Optional[SecretStr] = Field(None, description="API key for Groq")
|
||||
GEMINI_API_KEY: Optional[SecretStr] = Field(None, description="API key for Gemini")
|
||||
OPENAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for OpenAI")
|
||||
OLLAMA_HOST_URL: Optional[str] = Field(
|
||||
"http://127.0.0.1:11434", description="Fully qualified host URL for Ollama"
|
||||
@@ -21,6 +57,7 @@ class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
|
||||
)
|
||||
logging: LoggingConfig = LoggingConfig()
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
|
||||
+25
-13
@@ -1,26 +1,38 @@
|
||||
import difflib
|
||||
from typing import Union
|
||||
|
||||
from .providers import providers
|
||||
from .providers import BaseProvider, providers
|
||||
|
||||
_PROVIDER_NAMES = [provider.NAME.lower() for provider in providers]
|
||||
|
||||
|
||||
def find_provider(provider_name: Union[str, None]):
|
||||
"""Find a provider by name."""
|
||||
if provider_name:
|
||||
for provider_class in providers:
|
||||
if provider_class.NAME.lower() == provider_name.lower():
|
||||
# Instantiate the provider
|
||||
return provider_class()
|
||||
def find_provider(provider_name: str | None) -> BaseProvider:
|
||||
"""
|
||||
Find and instantiate a provider by name.
|
||||
|
||||
Parameters:
|
||||
provider_name (Union[str, None]): The name of the provider to find.
|
||||
|
||||
Returns:
|
||||
An instance of the provider class if found.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not specified or is not found, with a suggestion for the closest match.
|
||||
"""
|
||||
if provider_name is None:
|
||||
raise ValueError("No provider specified.")
|
||||
|
||||
# Find the provider by name.
|
||||
for provider_class in providers:
|
||||
if provider_class.NAME.lower() == provider_name.lower():
|
||||
# Instantiate the provider
|
||||
return provider_class()
|
||||
|
||||
# Find the closest match
|
||||
provider_found = difflib.get_close_matches(
|
||||
provider_name.lower(), _PROVIDER_NAMES, n=1
|
||||
) # Show only one suggestion
|
||||
|
||||
)
|
||||
if provider_found:
|
||||
raise ValueError(
|
||||
f"Provider {provider_name!r} not found. Did you mean {provider_found[0]!r}?"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Provider {provider_name} not found.")
|
||||
raise ValueError(f"Provider {provider_name} not found.")
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the project root to the Python path.
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from simplemind import Session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sm():
|
||||
"""Fixture that provides a simplemind Session instance with default settings."""
|
||||
return Session()
|
||||
@@ -0,0 +1,2 @@
|
||||
def test_basic_math():
|
||||
assert 1 + 1 == 2
|
||||
@@ -0,0 +1,31 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
result: int
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
[
|
||||
Anthropic,
|
||||
Gemini,
|
||||
OpenAI,
|
||||
Groq,
|
||||
Ollama,
|
||||
Amazon
|
||||
],
|
||||
)
|
||||
def test_generate_data(provider_cls):
|
||||
provider = provider_cls()
|
||||
prompt = "What is 2+2?"
|
||||
|
||||
data = provider.structured_response(prompt=prompt, response_model=ResponseModel)
|
||||
|
||||
assert isinstance(data, ResponseModel)
|
||||
assert isinstance(data.result, int)
|
||||
@@ -0,0 +1,24 @@
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
[
|
||||
Anthropic,
|
||||
Gemini,
|
||||
OpenAI,
|
||||
Groq,
|
||||
Ollama,
|
||||
Amazon,
|
||||
],
|
||||
)
|
||||
def test_generate_text(provider_cls):
|
||||
provider = provider_cls()
|
||||
prompt = "What is 2+2?"
|
||||
|
||||
response = provider.generate_text(prompt=prompt, llm_model=provider.DEFAULT_MODEL)
|
||||
|
||||
assert isinstance(response, str)
|
||||
assert len(response) > 0
|
||||
Reference in New Issue
Block a user