mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
Compare commits
53 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c648a922b4 | |||
| 873f5ba5f8 | |||
| 28a7b2f140 | |||
| 173162e798 | |||
| cd0be3ad89 | |||
| 3dd2e1b248 | |||
| ad1800840d | |||
| d62f297b68 | |||
| a2597709d2 | |||
| 1455b5ba13 | |||
| 0fb54d1987 | |||
| fe06331662 | |||
| 56b1e65d70 | |||
| 4b3e1bc6dd | |||
| f5b922ade8 | |||
| 3a7383425f | |||
| 92c10fc41e | |||
| caceba381d | |||
| 0795464fd7 | |||
| d82effdfb1 | |||
| e648292cb3 | |||
| 37a9333be3 | |||
| cbc3739411 | |||
| 7c8f22bef1 | |||
| 9c3f2a6df3 | |||
| febf5473d5 | |||
| 48ac97f070 | |||
| c41a3f00fb | |||
| 25ee4ae32c | |||
| 984721f02b | |||
| 69c8723770 | |||
| 0c10d5676a | |||
| e0ddf41e15 | |||
| f940ae2dfd | |||
| 85fa4f5879 | |||
| 44581e8fe3 | |||
| 9503ec7fd3 | |||
| 418f36dcc0 | |||
| bf9683cfd0 | |||
| 3909588f3e | |||
| 33d8f18bff | |||
| d7388ef0d5 | |||
| 02d10bfda9 | |||
| 5dc6e7b006 | |||
| 62933c8553 | |||
| f0a6be73f8 | |||
| 9257a04f34 | |||
| 64dbe9a2e7 | |||
| ccb8311089 | |||
| 0c29380501 | |||
| 7b43208a03 | |||
| e931fd0eae | |||
| 736d942527 |
@@ -167,3 +167,4 @@ cython_debug/
|
||||
|
||||
src/**
|
||||
requirements.txt
|
||||
Pipfile
|
||||
|
||||
@@ -1,6 +1,26 @@
|
||||
Release History
|
||||
===============
|
||||
|
||||
## 0.1.7 (2024-11-01)
|
||||
|
||||
- Add `logger` decorator.
|
||||
- Add `sm.enable_logfire()` function.
|
||||
- General improvements.
|
||||
|
||||
## 0.1.6 (2024-10-31)
|
||||
|
||||
- Add `sm.Plugin` syntax sugar.
|
||||
- Improvements to Anthropic provider, related to max tokens.
|
||||
- General improvements.
|
||||
- Add tests for structured response.
|
||||
- Add `llm_model` to `structured_response`.
|
||||
|
||||
## 0.1.5 (2024-10-31)
|
||||
|
||||
- Add Gemini provider.
|
||||
- Add structured response to Gemini provider.
|
||||
- Support for Python 3.10.
|
||||
|
||||
## 0.1.4 (2024-10-30)
|
||||
|
||||
- Introduce `Session` class to manage repeatability.
|
||||
|
||||
@@ -6,8 +6,6 @@ Simplemind is AI library designed to simplify your experience with AI APIs in Py
|
||||
|
||||

|
||||
|
||||
[](https://mutable.ai/kennethreitz/simplemind)
|
||||
|
||||
## Features
|
||||
|
||||
With Simplemind, tapping into AI is as easy as a friendly conversation.
|
||||
@@ -18,7 +16,7 @@ With Simplemind, tapping into AI is as easy as a friendly conversation.
|
||||
|
||||
## Supported APIs
|
||||
|
||||
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`.
|
||||
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`. The APIs remain identital between all supported providers/models.
|
||||
|
||||
- [**Anthropic's Claude**](https://www.anthropic.com/claude)
|
||||
- [**Google's Gemini**](https://gemini.google/)
|
||||
@@ -27,7 +25,7 @@ To specify a specific provider or model, you can use the `llm_provider` and `llm
|
||||
- [**OpenAI's GPT**](https://openai.com/gpt)
|
||||
- [**xAI's Grok**](https://x.ai/)
|
||||
|
||||
If you want to see Simplemind support, additional providers or models, please request a pull!
|
||||
If you want to see Simplemind support, additional providers or models, please send a pull request!
|
||||
|
||||
## Why SimpleMind?
|
||||
- **Intuitive**: Built with Pythonic simplicity and readability in mind.
|
||||
@@ -50,7 +48,7 @@ First, authenticate your API keys by setting them in the environment variables:
|
||||
$ export OPENAI_API_KEY="sk-..."
|
||||
```
|
||||
|
||||
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, and `GROQ_API_KEY`.
|
||||
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
|
||||
|
||||
Next, import Simplemind and start using it:
|
||||
|
||||
@@ -219,3 +217,4 @@ Simplemind is licensed under the Apache 2.0 License.
|
||||
|
||||
## Acknowledgements
|
||||
Simplemind is inspired by the philosophy of "code for humans" and aims to make working with AI models accessible to all. Special thanks to the open-source community for their contributions and inspiration.
|
||||
|
||||
|
||||
+1
-1
@@ -16,7 +16,7 @@ import simplemind
|
||||
project = "simplemind"
|
||||
copyright = "2024 Kenneth Reitz"
|
||||
author = "Kenneth Reitz"
|
||||
release = "v0.1.4"
|
||||
release = "v0.1.7"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "simplemind"
|
||||
version = "0.1.4"
|
||||
version = "0.1.7"
|
||||
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
+11
-1
@@ -16,7 +16,7 @@ class Session:
|
||||
self,
|
||||
*,
|
||||
llm_provider: str = settings.DEFAULT_LLM_PROVIDER,
|
||||
llm_model: str = settings.DEFAULT_LLM_MODEL,
|
||||
llm_model: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.llm_provider = llm_provider
|
||||
@@ -113,6 +113,14 @@ def generate_text(
|
||||
return provider.generate_text(prompt=prompt, llm_model=llm_model, **kwargs)
|
||||
|
||||
|
||||
def enable_logfire() -> None:
|
||||
"""Enable logfire logging."""
|
||||
settings.logging.enable_logfire()
|
||||
|
||||
|
||||
# Syntax sugar.
|
||||
Plugin = BasePlugin
|
||||
|
||||
__all__ = [
|
||||
"create_conversation",
|
||||
"find_provider",
|
||||
@@ -121,4 +129,6 @@ __all__ = [
|
||||
"settings",
|
||||
"BasePlugin",
|
||||
"Session",
|
||||
"Plugin",
|
||||
"enable_logfire",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import time
|
||||
from typing import Any, Callable
|
||||
|
||||
import logfire
|
||||
|
||||
from .settings import settings
|
||||
|
||||
|
||||
def logger(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""A decorator that logs the function parameters, function returns,
|
||||
and exceptions raised if logging is enabled, using logfire.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
if not settings.logging.is_enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
logfire.info(f"Calling {func.__name__} with args: {args}, kwargs: {kwargs}")
|
||||
t1 = time.perf_counter()
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
t2 = time.perf_counter()
|
||||
logfire.info(f"{func.__name__} returned: {result} in {t2-t1} seconds")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
t2 = time.perf_counter()
|
||||
logfire.error(f"Error in {func.__name__}: {e} in {t2-t1} seconds")
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
@@ -1,6 +1,6 @@
|
||||
from types import TracebackType
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import Any, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Type, TypeVar
|
||||
|
||||
from instructor import Instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
|
||||
@@ -1,24 +1,29 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import anthropic
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "anthropic"
|
||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
DEFAULT_MAX_TOKENS = 1000
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class Anthropic(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
@@ -28,6 +33,13 @@ class Anthropic(BaseProvider):
|
||||
"""The raw Anthropic client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Anthropic API key is required")
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `anthropic` package: `pip install anthropic`"
|
||||
) from exc
|
||||
|
||||
return anthropic.Anthropic(api_key=self.api_key)
|
||||
|
||||
@cached_property
|
||||
@@ -35,7 +47,8 @@ class Anthropic(BaseProvider):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_anthropic(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the Anthropic API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -46,8 +59,7 @@ class Anthropic(BaseProvider):
|
||||
response = self.client.messages.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Get the response content from the Anthropic response
|
||||
@@ -62,12 +74,27 @@ class Anthropic(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, model: str, response_model: Type[T], **kwargs) -> T:
|
||||
response = self.structured_client.messages.create(
|
||||
model=model or self.DEFAULT_MODEL, response_model=response_model, **kwargs
|
||||
)
|
||||
return response
|
||||
@logger
|
||||
def structured_response(
|
||||
self, response_model: Type[T], *, llm_model: str | None = None, **kwargs
|
||||
) -> T:
|
||||
model = llm_model or self.DEFAULT_MODEL
|
||||
|
||||
# Extract the prompt from kwargs if it exists
|
||||
prompt = kwargs.pop("prompt", kwargs.pop("messages", ""))
|
||||
|
||||
# Format the messages properly
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
response = self.structured_client.messages.create(
|
||||
model=model,
|
||||
messages=messages, # Add the messages parameter
|
||||
response_model=response_model,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
@@ -76,8 +103,7 @@ class Anthropic(BaseProvider):
|
||||
response = self.client.messages.create(
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.content[0].text
|
||||
|
||||
@@ -2,21 +2,25 @@
|
||||
# IT is not currently working as desired.
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import google.generativeai as genai
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
PROVIDER_NAME = "gemini"
|
||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "gemini"
|
||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
|
||||
|
||||
class Gemini(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
@@ -25,12 +29,21 @@ class Gemini(BaseProvider):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.model_name = DEFAULT_MODEL
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
self.model_name = model_name
|
||||
|
||||
@cached_property
|
||||
def client(self, model_name: str = DEFAULT_MODEL):
|
||||
def client(self):
|
||||
"""The raw Gemini client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Gemini API key is required")
|
||||
self.model_name = model_name
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `google-generativeai` package: `pip install google-generativeai`"
|
||||
) from exc
|
||||
genai.configure(api_key=self.api_key)
|
||||
return genai.GenerativeModel(model_name=self.model_name)
|
||||
|
||||
@cached_property
|
||||
@@ -38,6 +51,7 @@ class Gemini(BaseProvider):
|
||||
"""A Gemini client patched with Instructor."""
|
||||
return instructor.from_gemini(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
||||
"""Send a conversation to the Gemini API."""
|
||||
from ..models import Message
|
||||
@@ -64,9 +78,11 @@ class Gemini(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
"""Send a structured response to the Gemini API."""
|
||||
llm_model = kwargs.pop("llm_model", self.model_name)
|
||||
# Only try to pop if the key exists
|
||||
kwargs.pop("llm_model", None) # Add default value of None
|
||||
|
||||
try:
|
||||
response = self.structured_client.chat.completions.create(
|
||||
@@ -79,12 +95,12 @@ class Gemini(BaseProvider):
|
||||
raise RuntimeError(
|
||||
f"Failed to send structured response to Gemini API: {e}"
|
||||
) from e
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate text using the Gemini API."""
|
||||
llm_model = kwargs.pop("llm_model", self.model_name)
|
||||
|
||||
kwargs.pop("llm_model")
|
||||
try:
|
||||
response = self.client.generate_content(prompt, **kwargs)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,22 +1,29 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import groq
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
PROVIDER_NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class Groq(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
@@ -26,6 +33,12 @@ class Groq(BaseProvider):
|
||||
"""The raw Groq client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Groq API key is required")
|
||||
try:
|
||||
import groq
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `groq` package: `pip install groq`"
|
||||
) from exc
|
||||
return groq.Groq(api_key=self.api_key)
|
||||
|
||||
@cached_property
|
||||
@@ -33,6 +46,7 @@ class Groq(BaseProvider):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_groq(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
@@ -48,7 +62,7 @@ class Groq(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Get the response content from the Groq response
|
||||
@@ -63,6 +77,7 @@ class Groq(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
@@ -72,17 +87,19 @@ class Groq(BaseProvider):
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str,
|
||||
**kwargs,
|
||||
):
|
||||
) -> str:
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -90,7 +107,7 @@ class Groq(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
return str(response.choices[0].message.content)
|
||||
|
||||
@@ -1,25 +1,30 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import ollama as ol
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "ollama"
|
||||
DEFAULT_MODEL = "llama3.2"
|
||||
DEFAULT_TIMEOUT = 60
|
||||
DEFAULT_KWARGS = {}
|
||||
|
||||
|
||||
class Ollama(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
TIMEOUT = DEFAULT_TIMEOUT
|
||||
|
||||
def __init__(self, host_url: str | None = None):
|
||||
@@ -30,6 +35,12 @@ class Ollama(BaseProvider):
|
||||
"""The raw Ollama client."""
|
||||
if not self.host_url:
|
||||
raise ValueError("No ollama host url provided")
|
||||
try:
|
||||
import ollama as ol
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `ollama` package: `pip install ollama`"
|
||||
) from exc
|
||||
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
|
||||
|
||||
@cached_property
|
||||
@@ -43,7 +54,8 @@ class Ollama(BaseProvider):
|
||||
mode=instructor.Mode.JSON,
|
||||
)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the Ollama API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -51,7 +63,9 @@ class Ollama(BaseProvider):
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
response = self.client.chat(
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
assistant_message = response.get("message")
|
||||
|
||||
@@ -64,8 +78,14 @@ class Ollama(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
def structured_response(
|
||||
self, prompt: str, response_model: Type[T], *, llm_model: str, **kwargs
|
||||
self,
|
||||
prompt: str,
|
||||
response_model: Type[T],
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Get a structured response from the Ollama API."""
|
||||
messages = [
|
||||
@@ -76,18 +96,23 @@ class Ollama(BaseProvider):
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
def generate_text(self, prompt: str, *, llm_model: str) -> str:
|
||||
@logger
|
||||
def generate_text(
|
||||
self, prompt: str, *, llm_model: str | None = None, **kwargs
|
||||
) -> str:
|
||||
"""Generate text using the Ollama API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat(
|
||||
messages=messages, model=llm_model or self.DEFAULT_MODEL
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.get("message", {}).get("content", "")
|
||||
|
||||
@@ -1,22 +1,28 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
PROVIDER_NAME = "openai"
|
||||
DEFAULT_MODEL = "gpt-4o-mini"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class OpenAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
@@ -26,6 +32,12 @@ class OpenAI(BaseProvider):
|
||||
"""The raw OpenAI client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key is required")
|
||||
try:
|
||||
import openai as oa
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `openai` package: `pip install openai`"
|
||||
) from exc
|
||||
return oa.OpenAI(api_key=self.api_key)
|
||||
|
||||
@cached_property
|
||||
@@ -33,7 +45,8 @@ class OpenAI(BaseProvider):
|
||||
"""A OpenAI client with Instructor."""
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -42,7 +55,9 @@ class OpenAI(BaseProvider):
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
@@ -57,6 +72,7 @@ class OpenAI(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
def structured_response(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -74,16 +90,19 @@ class OpenAI(BaseProvider):
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
|
||||
"""Generate text using the OpenAI API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
@@ -1,20 +1,30 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "xai"
|
||||
DEFAULT_MODEL = "grok-beta"
|
||||
BASE_URL = "https://api.x.ai/v1"
|
||||
DEFAULT_MAX_TOKENS = 1000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class XAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
@@ -24,6 +34,12 @@ class XAI(BaseProvider):
|
||||
"""The raw OpenAI client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("XAI API key is required")
|
||||
try:
|
||||
import openai as oa
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `openai` package: `pip install openai`"
|
||||
) from exc
|
||||
return oa.OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=BASE_URL,
|
||||
@@ -34,7 +50,8 @@ class XAI(BaseProvider):
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -45,7 +62,7 @@ class XAI(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
@@ -60,10 +77,14 @@ class XAI(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt: str, response_model, *, llm_model: str):
|
||||
@logger
|
||||
def structured_response(
|
||||
self, prompt: str, response_model: Type[T], *, llm_model: str
|
||||
) -> T:
|
||||
raise NotImplementedError("XAI does not support structured responses")
|
||||
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -71,7 +92,7 @@ class XAI(BaseProvider):
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**kwargs,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
return str(response.choices[0].message.content)
|
||||
|
||||
+28
-6
@@ -1,19 +1,42 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import Field, SecretStr, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
logging_level = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||
|
||||
|
||||
class LoggingConfig(BaseSettings):
|
||||
"""The class that holds all the logging settings for the application."""
|
||||
|
||||
enabled: bool = Field(False, description="Enable logging")
|
||||
level: logging_level = Field("INFO", description="The logging level")
|
||||
is_enabled: bool = Field(False, description="Enable logging")
|
||||
|
||||
model_config = SettingsConfigDict(extra="forbid")
|
||||
|
||||
def enable_logfire(self, **kwargs) -> None:
|
||||
"""Enable logging for the application."""
|
||||
# adding imports here to avoid forced dependencies
|
||||
try:
|
||||
import logfire
|
||||
from logging import basicConfig
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"To enable logging, please install logfire: `pip install logfire`"
|
||||
) from e
|
||||
|
||||
self.is_enabled = True
|
||||
logfire.configure(**kwargs)
|
||||
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
|
||||
|
||||
try:
|
||||
logfire.configure(**kwargs)
|
||||
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
|
||||
except Exception as e:
|
||||
self.is_enabled = False # Reset flag on failure
|
||||
raise RuntimeError("Failed to configure logging") from e
|
||||
|
||||
def disable_logfire(self) -> None:
|
||||
"""Disable logging for the application."""
|
||||
self.is_enabled = False
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""The class that holds all the API keys for the application."""
|
||||
@@ -29,7 +52,6 @@ class Settings(BaseSettings):
|
||||
)
|
||||
XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI")
|
||||
DEFAULT_LLM_PROVIDER: str = Field("openai", description="The default LLM provider")
|
||||
DEFAULT_LLM_MODEL: str = Field("gpt-4o-mini", description="The default LLM model")
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the project root to the Python path.
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from simplemind import Session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sm():
|
||||
"""Fixture that provides a simplemind Session instance with default settings."""
|
||||
return Session()
|
||||
@@ -0,0 +1,2 @@
|
||||
def test_basic_math():
|
||||
assert 1 + 1 == 2
|
||||
@@ -0,0 +1,28 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
import pytest
|
||||
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
result: int
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
[
|
||||
Anthropic,
|
||||
Gemini,
|
||||
OpenAI,
|
||||
Groq,
|
||||
Ollama,
|
||||
],
|
||||
)
|
||||
def test_generate_data(provider_cls):
|
||||
provider = provider_cls()
|
||||
prompt = "What is 2+2?"
|
||||
|
||||
data = provider.structured_response(prompt=prompt, response_model=ResponseModel)
|
||||
|
||||
assert isinstance(data, ResponseModel)
|
||||
assert isinstance(data.result, int)
|
||||
@@ -0,0 +1,22 @@
|
||||
import pytest
|
||||
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
[
|
||||
Anthropic,
|
||||
Gemini,
|
||||
OpenAI,
|
||||
Groq,
|
||||
Ollama,
|
||||
],
|
||||
)
|
||||
def test_generate_text(provider_cls):
|
||||
provider = provider_cls()
|
||||
prompt = "What is 2+2?"
|
||||
|
||||
response = provider.generate_text(prompt=prompt, llm_model=provider.DEFAULT_MODEL)
|
||||
|
||||
assert isinstance(response, str)
|
||||
assert len(response) > 0
|
||||
Reference in New Issue
Block a user