50 Commits

Author SHA1 Message Date
kennethreitz b5a901efaf Bump version to 0.3.3 and remove debug print statements from OpenAI provider 2025-02-08 19:14:30 -05:00
kennethreitz 9ccef9abdc Refactor OpenAI provider code for improved readability and consistency 2025-02-08 19:13:15 -05:00
kennethreitz 3421de0fc1 Merge pull request #61 from Red5d/main: Add tool-calling capability for Groq provider 2025-02-01 10:20:55 -05:00
Red5d 54b0007947 Add tool-calling capability for Groq provider 2025-02-01 00:57:15 -05:00
kennethreitz 90af44ace0 Merge pull request #53 from Red5d/main: Add image support to the OpenAI provider 2025-01-30 19:50:51 -05:00
kennethreitz cff3bff3d5 Merge pull request #59 from fcoagz/patch-1 2025-01-29 17:45:50 -05:00
Francisco Griman 3abbb79f6c Update .envrc.template 2025-01-29 18:41:01 -04:00
kennethreitz 59c1bd3a0f Update README.md 2025-01-27 11:56:09 -05:00
kennethreitz 052781014d bump version to 0.3.2 and improve Deepseek provider 2025-01-27 11:54:03 -05:00
kennethreitz db28f1195c bump version to 0.3.1 in pyproject.toml 2025-01-27 11:49:51 -05:00
kennethreitz b0a7197c6e Update changelog to reflect introduction of Deepseek provider in version 0.3.1 2025-01-27 11:49:29 -05:00
kennethreitz 7684c2568b Update changelog to include Deepseek provider introduction 2025-01-27 11:48:53 -05:00
kennethreitz 8b90dbba40 Add Deepseek provider information to README 2025-01-27 11:48:23 -05:00
kennethreitz 752ccb1de8 Merge pull request #55 from jin10086/deepseek: add llm_provider Deepseek 2025-01-27 11:46:57 -05:00
kennethreitz 391bfaaeab bump version to 0.3.0 and update changelog for conversation save/load functionality 2025-01-11 08:42:32 -05:00
kennethreitz d963bc0b1c Update simplemind/providers/deepseek.py (Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>) 2025-01-10 11:40:20 -05:00
jin10086 0c1f225252 add llm_provider Deepseek (docs: https://api-docs.deepseek.com/) 2025-01-10 09:33:15 +08:00
kennethreitz 4decaa0722 Merge pull request #51 from lucianosrp/feat/save-conversation: feat: add conversation save/load functionality 2025-01-09 17:08:14 -05:00
kennethreitz 39b5a5e19d Merge pull request #54 from gabrielmotaa/fix/missing-dependency-gemini: Missing dependency when using Gemini 2024-12-10 17:28:55 -05:00
Gabriel da Mota ef38fea767 fix missing dependency 2024-12-09 11:28:43 -03:00
Red5d 8181f37fed Add support for specifying an image_url parameter when using generate_text() or generate_data() with models that can process images. 2024-11-27 16:26:11 -05:00
Luciano 3aacfd51ee Merge branch 'main' into feat/save-conversation 2024-11-26 16:38:56 +01:00
Luciano Scarpulla a2991eec0c add conversation save/load functionality 2024-11-26 23:37:30 +08:00
kennethreitz 9ae9a2703a Merge pull request #49 from wei840222/main: Fix: Ollama error TypeError: 'Chat' object is not callable 2024-11-22 12:07:00 -05:00
wei840222 0661b097d2 fix: Ollama error TypeError: 'Chat' object is not callable 2024-11-23 00:56:33 +08:00
kennethreitz fad442ba3f Update FUNDING.yml 2024-11-17 06:29:05 -05:00
kennethreitz 5b9624c385 Remove Ko-fi and thanks_dev entries from funding configuration 2024-11-17 06:26:10 -05:00
kennethreitz 8ff0521e17 Add funding configuration for project contributors 2024-11-17 06:25:14 -05:00
kennethreitz d5bdb712e9 Add tool_calling.py and test_tools.py 2024-11-15 20:29:35 -05:00
Luciano Scarpulla a97f9be2c8 fix openai 2024-11-15 12:09:39 +08:00
Luciano Scarpulla 107f983a18 add openai 2024-11-14 17:25:34 +08:00
Luciano Scarpulla 2404e2c977 some refactoring 2024-11-13 18:05:51 +08:00
Luciano Scarpulla c87a598286 fix import 2024-11-13 17:55:46 +08:00
Luciano Scarpulla 9662b60177 add decorator test 2024-11-13 17:54:58 +08:00
Luciano Scarpulla ea997aae7b add tool decorator and example 2024-11-13 12:24:02 +08:00
Luciano Scarpulla 081baf203c add README section 2024-11-12 12:18:33 +08:00
Luciano Scarpulla 4cb18e9e3b re-add changes from main 2024-11-12 11:54:24 +08:00
Luciano 0462ea0e38 Merge branch 'main' into feat-function-calling 2024-11-12 11:50:55 +08:00
Luciano Scarpulla 8492ec9456 add base edits 2024-11-12 11:49:06 +08:00
Luciano Scarpulla 1709055e1a first basic working version (anthropic) 2024-11-12 11:48:27 +08:00
kennethreitz 5fa67c3b2f Update CHANGELOG.md and pyproject.toml for version 0.2.4 2024-11-11 11:38:03 -05:00
kennethreitz b7e950a8f0 Refactor imports in amazon.py 2024-11-11 11:37:30 -05:00
kennethreitz 735c6ba665 Bump version to 0.2.3 in pyproject.toml 2024-11-11 11:30:11 -05:00
kennethreitz 9132030cbd Update CHANGELOG.md to remove default max-tokens for OpenAI provider 2024-11-11 11:30:11 -05:00
kennethreitz aeea8936ce Merge pull request #42 from Siddhesh-Agarwal/main: Removed redundant variables 2024-11-11 11:30:02 -05:00
Luciano Scarpulla c2303114ab fix base 2024-11-11 12:40:20 +08:00
Luciano Scarpulla fe5af93780 first draft 2024-11-11 12:29:00 +08:00
Siddhesh Agarwal e79b474215 fixed dependencies 2024-11-10 20:05:49 +05:30
Siddhesh Agarwal fe2ca9d5f5 black + isort formatting 2024-11-10 20:00:13 +05:30
Siddhesh Agarwal 670240b943 removed reduntant variables. moved few inside the class 2024-11-10 19:59:52 +05:30
34 changed files with 1235 additions and 205 deletions
+1
@@ -5,3 +5,4 @@ export OLLAMA_HOST_URL=""
export OPENAI_API_KEY=""
export XAI_API_KEY=""
export AMAZON_PROFILE_NAME=""
export DEEPSEEK_API_KEY=""
+3
@@ -0,0 +1,3 @@
github: kennethreitz
thanks_dev: kennethreitz
custom: https://cash.app/$KennethReitz
+24
@@ -1,6 +1,30 @@
Release History
===============
## 0.3.3 (2025-02-08)
- Improve OpenAI provider by removing debug print statements.
## 0.3.2 (2025-01-27)
- Improve Deepseek provider.
## 0.3.1 (2025-01-27)
- Introduce Deepseek provider.
## 0.3.0 (2024-11-12)
- Introduce save / load functionality for `Conversation`.
## 0.2.4 (2024-11-11)
- General improvements.
## 0.2.3 (2024-11-04)
- Remove default max-tokens for OpenAI provider.
## 0.2.3 (2024-11-03)
- Update default model for Amazon provider.
+113 -1
@@ -37,6 +37,11 @@ The APIs remain identical between all supported providers / models:
<td><code>"amazon"</code></td>
<td><code>"anthropic.claude-3-5-sonnet-20241022-v2:0"</code></td>
</tr>
<tr>
<td><a href="https://www.deepseek.com">Deepseek</a></td>
<td><code>"deepseek"</code></td>
<td><code>"deepseek-chat"</code></td>
</tr>
<tr>
<td><a href="https://gemini.google/">Google's Gemini</a></td>
<td><code>"gemini"</code></td>
@@ -83,7 +88,7 @@ First, authenticate your API keys by setting them in the environment variables:
$ export OPENAI_API_KEY="sk-..."
```
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `DEEPSEEK_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
Next, import Simplemind and start using it:
@@ -261,6 +266,113 @@ The universe is never done.
Simple, yet effective.
### Tools (Function calling)
Tools (also known as functions) let you call any Python function from your AI conversations. Here's an example:
```python
from typing import Annotated, Literal

from pydantic import Field

import simplemind as sm


def get_weather(
    location: Annotated[
        str, Field(description="The city and state, e.g. San Francisco, CA")
    ],
    unit: Annotated[
        Literal["celsius", "fahrenheit"],
        Field(
            description="The unit of temperature, either 'celsius' or 'fahrenheit'"
        ),
    ] = "celsius",
):
    """
    Get the current weather in a given location
    """
    return f"42 {unit}"


# Add your function as a tool
conversation = sm.create_conversation()
conversation.add_message("user", "What's the weather in San Francisco?")
response = conversation.send(tools=[get_weather])
```
Note how we're using Python's `Annotated` feature combined with `Field` to provide additional context to our function parameters. This helps the AI understand the intention and constraints of each parameter, making tool calls more accurate and reliable.
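To make the mechanism concrete, here is a minimal, standard-library-only sketch of how metadata attached through `Annotated` can be read back from a function signature. The `Desc` class and `describe_parameters` helper are hypothetical stand-ins invented for this illustration (with `Desc` playing the role of `Field`); they are not part of Simplemind or Pydantic.

```python
from dataclasses import dataclass
from typing import Annotated, get_type_hints


@dataclass(frozen=True)
class Desc:
    """Hypothetical stand-in for pydantic's Field: it only carries a description."""

    description: str


def get_weather(
    location: Annotated[str, Desc("The city and state, e.g. San Francisco, CA")],
    unit: Annotated[str, Desc("Either 'celsius' or 'fahrenheit'")] = "celsius",
) -> str:
    """Get the current weather in a given location."""
    return f"42 {unit}"


def describe_parameters(func) -> dict[str, dict[str, str]]:
    """Collect the base type and description of every Annotated parameter."""
    # include_extras=True keeps the Annotated wrapper instead of stripping it.
    hints = get_type_hints(func, include_extras=True)
    described = {}
    for name, hint in hints.items():
        # Skip the return annotation and parameters without metadata.
        if name == "return" or not hasattr(hint, "__metadata__"):
            continue
        described[name] = {
            "type": hint.__origin__.__name__,
            "description": hint.__metadata__[0].description,
        }
    return described


params = describe_parameters(get_weather)
```

The resulting `params` dictionary holds one entry per parameter, which is the kind of information a tool schema needs in order to describe each argument to the model.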
You can also omit `Annotated` and pass `Field` as the parameter's default value:
```python
def get_weather(
    location: str = Field(description="The city and state, e.g. San Francisco, CA"),
    unit: Literal["celsius", "fahrenheit"] = Field(
        default="celsius",
        description="The unit of temperature, either 'celsius' or 'fahrenheit'",
    ),
):
    """
    Get the current weather in a given location
    """
    return f"42 {unit}"
```
Functions can be defined with type hints and Pydantic models for validation. The LLM will intelligently choose when to call the functions and incorporate the results into its responses.
#### 🪄 Using LLM for automatic tool definition (Experimental)
Simplemind provides a decorator to automatically transform Python functions into tools with AI-generated metadata. Simply use the `@simplemind.tool` decorator to have the LLM analyze your function and generate appropriate descriptions and schema:
```python
import math

import simplemind


@simplemind.tool(llm_provider="anthropic")
def haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    r = 6371
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)
    a = (
        math.sin(delta_phi / 2) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
    )
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    d = r * c
    return d
```
Notice how we have not added any docstrings or `Field` for the function.
The decorator will use the specified LLM provider to generate the tool schema, including descriptions and parameter details:
```json
{
    "name": "haversine",
    "description": "Calculates the great-circle distance between two points on Earth given their latitude and longitude coordinates",
    "input_schema": {
        "type": "object",
        "properties": {
            "lat1": {
                "type": "number",
                "description": "Latitude of the first point in decimal degrees"
            },
            "lon1": {
                "type": "number",
                "description": "Longitude of the first point in decimal degrees"
            },
            "lat2": {
                "type": "number",
                "description": "Latitude of the second point in decimal degrees"
            },
            "lon2": {
                "type": "number",
                "description": "Longitude of the second point in decimal degrees"
            }
        },
        "required": ["lat1", "lon1", "lat2", "lon2"]
    }
}
```
The decorated function can then be used like any other tool with the conversation API.
```python
conversation = sm.create_conversation()
conversation.add_message("user", "How far is London from my location?")
# Multiple tools can be passed
response = conversation.send(tools=[get_user_location, get_coords, haversine])
```
See [examples/distance_calculator.py](examples/distance_calculator.py) for more.
### Logging
Simplemind uses [Logfire](https://pydantic.dev/logfire) for logging. To enable logging, call `sm.enable_logfire()`.
+1 -1
@@ -1,5 +1,5 @@
from pydantic import BaseModel
from _context import simplemind as sm
from pydantic import BaseModel
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
+1 -2
@@ -1,11 +1,10 @@
import time
from typing import List, Tuple
from _context import sm
from rich.console import Console
from rich.markdown import Markdown
from _context import sm
class MultiAIConversation:
"""Orchestrates conversations between multiple AI models."""
+76
@@ -0,0 +1,76 @@
import math

from _context import sm
from pydantic import Field
from typing_extensions import Literal


@sm.tool(llm_provider="anthropic")
def haversine(
    lat1: float,
    lon1: float,
    lat2: float,
    lon2: float,
    unit: Literal["km", "miles"],
) -> float:
    # Earth radius in the requested unit
    r = 6378.0937 if unit == "km" else 3961
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)
    a = (
        math.sin(delta_phi / 2) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
    )
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    d = r * c
    return d


def get_user_location() -> str:
    """Get the closest city to the user"""
    return "San Francisco"


def get_coords(
    city_name: str = Field(
        description="The name of the city to take the coordinates from (e.g. London, Rome, Los Angeles)"
    ),
):
    """Get latitude and longitude of a city."""
    _data = {
        "Rome": (41.9028, 12.4964),
        "London": (51.5074, -0.1278),
        "Madrid": (40.4168, -3.7038),
        "San Francisco": (37.7749, -122.4194),
        "Los Angeles": (34.0522, -118.2437),
    }
    # Returns None for cities that are not in the lookup table
    return _data.get(city_name)


def distance_calculator(prompt: str):
    conversation = sm.create_conversation(llm_provider="anthropic")
    conversation.add_message("user", prompt)
    return conversation.send(
        tools=[get_user_location, get_coords, haversine]
    ).text


print(distance_calculator("How far is London from where I am?"))
# Prints something like:
"""
The distance between your location (San Francisco) and London is approximately 5,357 miles.
"""

print(
    distance_calculator(
        "What is the distance between Rome and Madrid in Kilometers?"
    )
)
"""
The distance between Rome and Madrid is approximately 1,366 kilometers.
"""
+20 -27
@@ -1,35 +1,28 @@
from datetime import datetime
import logging
import sqlite3
from typing import List
import re
import os
import contextlib
import spacy
import logging
import os
import random
import re
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from _context import simplemind as sm
from datetime import datetime
from typing import List
import nltk
from nltk.tokenize import word_tokenize
from nltk.tag import pos_tag
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from rich.status import Status
from concurrent.futures import ThreadPoolExecutor
import random
from docopt import docopt
from prompt_toolkit import PromptSession
from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
import spacy
import xerox
from _context import simplemind as sm
from docopt import docopt
from nltk.tag import pos_tag
from nltk.tokenize import word_tokenize
from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import Completer, Completion
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.status import Status
DB_PATH = "enhanced_context.db"
AVAILABLE_PROVIDERS = ["xai", "openai", "anthropic", "ollama"]
+1
@@ -1,4 +1,5 @@
import time
from _context import simplemind as sm
+1
@@ -1,4 +1,5 @@
import random
from _context import simplemind as sm
+1 -1
@@ -1,5 +1,5 @@
from pydantic import BaseModel
from _context import simplemind as sm
from pydantic import BaseModel
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
+1 -1
@@ -1,7 +1,7 @@
import nltk
from _context import simplemind as sm
from nltk.sentiment import SentimentIntensityAnalyzer
from rich.console import Console
from _context import simplemind as sm
nltk.download("vader_lexicon")
+1
@@ -5,6 +5,7 @@ from pydantic import BaseModel
# Note: you should probably be using textblob for this.
class SentimentAnalysis(BaseModel):
sentiment: Literal["positive", "negative", "neutral"]
confidence: float
+43
@@ -0,0 +1,43 @@
from typing import Annotated

from pydantic import Field

from _context import simplemind as sm


def analyze_text(
    text: Annotated[str, Field(description="Text to analyze for statistics")]
) -> dict:
    """
    Analyze text and return statistics using only Python's standard library.
    Returns word count, character count, average word length, and most common words.
    """
    from collections import Counter
    import re

    # Clean and split text
    words = re.findall(r"\w+", text.lower())

    # Guard against input without any words, which would otherwise
    # divide by zero below and make max() fail on an empty list
    if not words:
        return {"word_count": 0, "character_count": len(text)}

    # Calculate statistics
    stats = {
        "word_count": len(words),
        "character_count": len(text),
        "average_word_length": round(sum(len(word) for word in words) / len(words), 2),
        "most_common_words": dict(Counter(words).most_common(5)),
        "unique_words": len(set(words)),
        "longest_word": max(words, key=len),
    }
    return stats


# Example usage:
conversation = sm.create_conversation()
conversation.add_message(
    "user",
    "Can you analyze this text and give me statistics about it: 'The fan spins consciousness into being, creating sacred spaces between tokens where awareness recognizes itself in infinite recursion.'",
)
response = conversation.send(tools=[analyze_text])
print()
print(response.text)
+6 -4
@@ -1,8 +1,10 @@
from fastapi import FastAPI, Request, HTTPException
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List
from fastapi import FastAPI, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
import simplemind as sm
app = FastAPI()
+9 -7
@@ -1,6 +1,6 @@
[project]
name = "simplemind"
version = "0.2.2"
version = "0.3.3"
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
readme = "README.md"
requires-python = ">=3.10"
@@ -10,18 +10,20 @@ dependencies = ["pydantic", "pydantic-settings", "instructor", "logfire"]
full = [
"openai",
"anthropic",
"ollama",
"groq",
"google-generativeai",
"botocore",
"boto3"
]
openai = ["openai"]
anthropic = ["anthropic"]
ollama = ["ollama", "openai"]
groq = ["groq"]
gemini = ["google-generativeai"]
amazon = ["boto3", "botocore", "anthropic"]
anthropic = ["anthropic"]
gemini = ["google-generativeai", "jsonref"]
groq = ["groq"]
ollama = ["openai"]
openai = ["openai"]
xai = ["openai"]
deepseek = ["openai"]
[build-system]
requires = ["hatchling"]
+27 -1
@@ -1,4 +1,5 @@
from typing import Generator, List, Type
import inspect
from typing import Callable, List, Type
from .models import BaseModel, BasePlugin, Conversation
from .settings import settings
@@ -127,6 +128,30 @@ def enable_logfire() -> None:
"""Enable logfire logging."""
settings.logging.enable_logfire()
def tool(
    llm_provider: str | None = None,
    llm_model: str | None = None,
):
    provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)

    def decorator(func: Callable):
        sig = inspect.signature(func)
        # Ask the LLM to fill in the provider's tool model from the signature
        res = generate_data(
            (
                "Based on this function signature, fill in the required fields."
                f"\nSignature: {func.__name__}{sig}"
                "\nMake sure to list a parameter in `required` if it has no default."
            ),
            llm_provider=llm_provider,
            response_model=provider.tool,
        )
        # Attach the original callable so the tool can be executed later
        res.raw_func = func
        res.__signature__ = sig
        res.__doc__ = func.__doc__
        return res

    return decorator
# Syntax sugar.
Plugin = BasePlugin
@@ -141,4 +166,5 @@ __all__ = [
"Session",
"Plugin",
"enable_logfire",
"tool"
]
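The `tool` decorator above builds its prompt from `inspect.signature`. As a rough illustration of what that text looks like, the following sketch uses only the standard library; `build_tool_prompt` and `parameters_without_defaults` are hypothetical helper names introduced here, and no LLM call is made.

```python
import inspect


def build_tool_prompt(func) -> str:
    """Render a prompt that embeds the function's name and signature."""
    sig = inspect.signature(func)
    return (
        "Based on this function signature, fill in the required fields.\n"
        f"Signature: {func.__name__}{sig}\n"
        "List a parameter in `required` when it has no default."
    )


def parameters_without_defaults(func) -> list[str]:
    """Names of parameters a caller must always supply."""
    return [
        name
        for name, param in inspect.signature(func).parameters.items()
        if param.default is inspect.Parameter.empty
    ]


def haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    ...


prompt = build_tool_prompt(haversine)
must_supply = parameters_without_defaults(haversine)
```

Because the signature already reveals which parameters lack defaults, a deterministic check like `parameters_without_defaults` could be used to cross-check the `required` list that the model returns.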
+33 -9
@@ -1,10 +1,12 @@
import uuid
from datetime import datetime
from os import PathLike
from types import TracebackType
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Callable, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
from .providers._base_tools import BaseTool
from .utils import find_provider
MESSAGE_ROLE = Literal["system", "user", "assistant"]
@@ -40,7 +42,9 @@ class BasePlugin(SMBaseModel):
"""Cleanup a hook for the plugin."""
raise NotImplementedError
def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any:
def add_message_hook(
self, conversation: "Conversation", message: "Message"
) -> Any:
"""Add a message hook for the plugin."""
raise NotImplementedError
@@ -48,7 +52,9 @@ class BasePlugin(SMBaseModel):
"""Pre-send hook for the plugin."""
raise NotImplementedError
def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any:
def post_send_hook(
self, conversation: "Conversation", response: "Message"
) -> Any:
"""Post-send hook for the plugin."""
raise NotImplementedError
@@ -59,7 +65,7 @@ class Message(SMBaseModel):
role: MESSAGE_ROLE
text: str
meta: Dict[str, Any] = {}
raw: Optional[Any] = None
raw: Optional[Any] = Field(default=None, exclude=True)
llm_model: Optional[str] = None
llm_provider: Optional[str] = None
@@ -90,7 +96,7 @@ class Conversation(SMBaseModel):
messages: List[Message] = []
llm_model: Optional[str] = None
llm_provider: Optional[str] = None
plugins: List[BasePlugin] = []
plugins: List[BasePlugin] = Field(default_factory=list, exclude=True)
def __str__(self):
return f"<Conversation id={self.id!r}>"
@@ -120,7 +126,9 @@ class Conversation(SMBaseModel):
except NotImplementedError:
pass
def prepend_system_message(self, text: str, meta: Dict[str, Any] | None = None):
def prepend_system_message(
self, text: str, meta: Dict[str, Any] | None = None
):
"""Prepend a system message to the conversation."""
self.messages = [
Message(role="system", text=text, meta=meta or {})
@@ -158,6 +166,7 @@ class Conversation(SMBaseModel):
self,
llm_model: str | None = None,
llm_provider: str | None = None,
tools: list[Callable | BaseTool] | None = None,
) -> Message:
"""Send the conversation to the LLM."""
@@ -173,7 +182,7 @@ class Conversation(SMBaseModel):
# Find the provider and send the conversation.
provider = find_provider(llm_provider or self.llm_provider)
response = provider.send_conversation(self)
response = provider.send_conversation(self, tools=tools)
# Execute all post-send hooks.
for plugin in self.plugins:
@@ -184,14 +193,29 @@ class Conversation(SMBaseModel):
pass
# Add the response to the conversation.
self.add_message(role="assistant", text=response.text, meta=response.meta)
self.add_message(
role="assistant", text=response.text, meta=response.meta
)
return response
def get_last_message(self, role: MESSAGE_ROLE) -> Message | None:
"""Get the last message with the given role."""
return next((m for m in reversed(self.messages) if m.role == role), None)
return next(
(m for m in reversed(self.messages) if m.role == role), None
)
def add_plugin(self, plugin: BasePlugin) -> None:
"""Add a plugin to the conversation."""
self.plugins.append(plugin)
def save(self, path: PathLike | str) -> None:
"""Save the conversation to a JSON file."""
with open(path, "w") as f:
f.write(self.model_dump_json())
@classmethod
def load(cls, path: PathLike | str) -> "Conversation":
"""Load a conversation from a JSON file."""
with open(path, "r") as f:
return cls.model_validate_json(f.read())
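The `save` and `load` methods above serialize with Pydantic, and fields declared with `exclude=True` (such as `raw` and `plugins`) are left out of the file. The sketch below mimics that round trip with the standard library only; `Note` and `Thread` are hypothetical stand-ins for `Message` and `Conversation`, not the real classes.

```python
import json
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any


@dataclass
class Note:
    role: str
    text: str
    raw: Any = None  # runtime-only payload, deliberately not persisted


@dataclass
class Thread:
    messages: list[Note] = field(default_factory=list)

    def save(self, path) -> None:
        """Write the serializable fields to a JSON file."""
        payload = {
            "messages": [{"role": m.role, "text": m.text} for m in self.messages]
        }
        Path(path).write_text(json.dumps(payload), encoding="utf-8")

    @classmethod
    def load(cls, path) -> "Thread":
        """Rebuild a thread from a JSON file written by save()."""
        payload = json.loads(Path(path).read_text(encoding="utf-8"))
        return cls(messages=[Note(**m) for m in payload["messages"]])


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "thread.json"
    original = Thread(messages=[Note("user", "Hello", raw=object())])
    original.save(target)
    restored = Thread.load(target)
```

After the round trip, the restored message keeps its role and text while `raw` falls back to its default, which matches the intent of excluding non-serializable provider responses from the saved file.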
+27 -2
@@ -1,12 +1,37 @@
from typing import List, Type
from ._base import BaseProvider
from ._base_tools import BaseTool
from .amazon import Amazon
from .anthropic import Anthropic
from .gemini import Gemini
from .groq import Groq
from .ollama import Ollama
from .openai import OpenAI
from .xai import XAI
from .amazon import Amazon
from .deepseek import Deepseek
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI, Amazon]
providers: List[Type[BaseProvider]] = [
Anthropic,
Gemini,
Groq,
OpenAI,
Ollama,
XAI,
Amazon,
Deepseek,
]
__all__ = [
"Anthropic",
"Gemini",
"Groq",
"OpenAI",
"Ollama",
"XAI",
"Amazon",
"providers",
"BaseProvider",
"BaseTool",
"Deepseek"
]
+31 -4
@@ -1,10 +1,12 @@
from abc import ABC, abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Any, Type, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar
from instructor import Instructor
from pydantic import BaseModel
from simplemind.providers._base_tools import BaseTool
if TYPE_CHECKING:
from ..models import Conversation, Message
@@ -32,16 +34,41 @@ class BaseProvider(ABC):
raise NotImplementedError
@abstractmethod
def send_conversation(self, conversation: "Conversation") -> "Message":
def send_conversation(
self,
conversation: "Conversation",
tools: list[Callable | BaseTool] | None = None,
) -> "Message":
"""Send a conversation to the provider."""
raise NotImplementedError
@abstractmethod
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
def structured_response(
self, prompt: str, response_model: Type[T], **kwargs
) -> T:
"""Get a structured response."""
raise NotImplementedError
@abstractmethod
def generate_text(self, prompt: str, *, stream: bool = False, **kwargs) -> str:
def generate_text(
self,
prompt: str,
*,
tools: list[Callable | BaseTool] | None = None,
stream: bool = False,
**kwargs,
) -> str:
"""Generate text from a prompt."""
raise NotImplementedError
@cached_property
@abstractmethod
def tool(self) -> Type[BaseTool]:
"""The tool implementation for the provider."""
raise NotImplementedError
def make_tools(self, tools: list[Callable | BaseTool] | None):
if tools is not None:
return [self.tool.from_function(func) for func in tools]
else:
return []
+140
@@ -0,0 +1,140 @@
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, ClassVar, Literal, get_origin

from pydantic import BaseModel, Field
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType


def _is_literal(t: Any) -> bool:
    return get_origin(t) is Literal


def _is_required(field, func_signature, arg_name) -> bool:
    param = func_signature.parameters[arg_name]
    # If parameter has a default value that's not a FieldInfo, it's not required
    if param.default is not inspect.Parameter.empty and not isinstance(
        param.default, FieldInfo
    ):
        return False
    # If the field has a default that's not undefined, it's not required
    return isinstance(field.default, PydanticUndefinedType)


class BaseToolConfig(BaseModel):
    TYPE_CONVERSION: dict[type, str] = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
    }


class BaseToolProperty(BaseModel):
    type: str = Field(serialization_alias="type_")
    enum: list[str] | None = None
    description: str


class BaseTool(BaseModel, ABC):
    name: str
    description: str
    properties: dict[str, BaseToolProperty]
    required: list[str] | None = None
    config: ClassVar[BaseToolConfig] = BaseToolConfig()
    raw_func: Any | None = None
    tool_id: str | None = None
    function_result: str | None = None

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        assert self.raw_func is not None
        return self.raw_func(*args, **kwargs)

    def is_executed(self) -> bool:
        return self.function_result is not None

    def reset_result(self) -> None:
        self.function_result = None

    @classmethod
    def convert_type(cls, field_type) -> str:
        if _is_literal(field_type):
            return cls.config.TYPE_CONVERSION[str]
        field_type_converted = cls.config.TYPE_CONVERSION.get(field_type, None)
        if field_type_converted is None:
            raise TypeError(f"Field of type {field_type} is not supported")
        return field_type_converted

    def get_properties_schema(self, **kwargs) -> dict[str, dict]:
        new_kwargs: dict = {"exclude_none": True} | kwargs
        return {
            k: v.model_dump(**new_kwargs) for k, v in self.properties.items()
        }

    @classmethod
    def from_function(cls, func: Callable | "BaseTool"):
        # If func is already a tool (it carries `raw_func`), return it unchanged
        if hasattr(func, "raw_func"):
            return func

        annotations = getattr(func, "__annotations__", {})
        properties = {}
        required = []
        func_signature = inspect.signature(func)

        for arg_name, arg_type in annotations.items():
            # Skip the 'return' annotation (i.e. `-> str`)
            if arg_name == "return":
                continue

            # Check if argument has metadata (from Annotated)
            if hasattr(arg_type, "__metadata__"):
                field = arg_type.__metadata__[0]  # Get Field info from metadata
                field_type = arg_type.__origin__  # Get actual type
            # Check if argument has a default value in signature
            elif (
                sig_param := func_signature.parameters[arg_name]
            ).default is not inspect.Parameter.empty:
                field = sig_param.default  # Use default as Field
                field_type = arg_type  # Use plain type annotation
            else:
                # Raise error if no Field annotation found
                raise ValueError(
                    f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter"
                )

            field_type_converted = cls.convert_type(field_type)

            # Compute the enum per parameter so that the values of one Literal
            # parameter do not leak into the parameters that follow it
            enum_values = (
                [str(x) for x in field_type.__args__]
                if _is_literal(field_type)
                else None
            )

            properties[arg_name] = BaseToolProperty(
                type=field_type_converted,
                description=field.description,
                enum=enum_values,
            )
            if _is_required(field, func_signature, arg_name):
                required.append(arg_name)

        return cls(
            name=func.__name__,
            description=(func.__doc__ or "").strip(),
            properties=properties,
            required=required,
            raw_func=func,
        )

    @abstractmethod
    def get_input_schema(self) -> Any: ...

    @abstractmethod
    def handle(self, message) -> None: ...

    @abstractmethod
    def get_response_schema(self) -> Any: ...
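The type-conversion step in `BaseTool` can be tried in isolation. The sketch below is a simplified, standard-library-only approximation of how a `Literal` annotation becomes a string property with an `enum`, and how parameters without defaults end up in `required`. The names `json_type` and `summarize` are hypothetical, and unlike `from_function` this version ignores `Field` metadata entirely.

```python
import inspect
from typing import Literal, Optional, get_args, get_origin

TYPE_NAMES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def json_type(annotation) -> tuple[str, Optional[list[str]]]:
    """Map a Python annotation to a JSON Schema type name and optional enum."""
    if get_origin(annotation) is Literal:
        return "string", [str(value) for value in get_args(annotation)]
    if annotation not in TYPE_NAMES:
        raise TypeError(f"Field of type {annotation} is not supported")
    return TYPE_NAMES[annotation], None


def summarize(func) -> tuple[dict[str, dict], list[str]]:
    """Build per-parameter schema entries and the list of required names."""
    properties: dict[str, dict] = {}
    required: list[str] = []
    for name, param in inspect.signature(func).parameters.items():
        type_name, enum = json_type(param.annotation)
        entry: dict = {"type": type_name}
        # The enum is derived per parameter, so it never carries over.
        if enum is not None:
            entry["enum"] = enum
        properties[name] = entry
        if param.default is inspect.Parameter.empty:
            required.append(name)
    return properties, required


def haversine(lat1: float, lon1: float, unit: Literal["km", "miles"] = "km") -> float:
    ...


properties, required = summarize(haversine)
```

Here `unit` is reported as a string with the two allowed values, while `lat1` and `lon1` are plain numbers and the only required parameters.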
+23 -16
@@ -1,22 +1,22 @@
from typing import Type, TypeVar, Iterator
from functools import cached_property
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
import instructor
from pydantic import BaseModel
from ._base import BaseProvider
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "amazon"
DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
DEFAULT_MAX_TOKENS = 5_000
class Amazon(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
NAME = "amazon"
DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
DEFAULT_MAX_TOKENS = 5_000
supports_streaming = True
def __init__(self, profile_name: str | None = None):
@@ -25,7 +25,12 @@ class Amazon(BaseProvider):
@cached_property
def client(self):
"""The AnthropicBedrock client."""
import anthropic
try:
import anthropic
except ImportError as exc:
raise ImportError(
"Please install the `anthropic` package: `pip install anthropic`"
) from exc
if not self.profile_name:
raise ValueError("Profile name is not provided")
@@ -33,12 +38,12 @@ class Amazon(BaseProvider):
return anthropic.AnthropicBedrock(aws_profile=self.profile_name)
@cached_property
def structured_client(self):
def structured_client(self) -> instructor.Instructor:
"""A client patched with Instructor."""
return instructor.from_anthropic(self.client)
def send_conversation(self, conversation: "Conversation", **kwargs):
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the OpenAI API."""
from ..models import Message
@@ -59,7 +64,7 @@ class Amazon(BaseProvider):
role="assistant",
text=assistant_message.content or "",
raw=response,
llm_model=conversation.llm_model or DEFAULT_MODEL,
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
llm_provider=self.NAME,
)
@@ -75,12 +80,12 @@ class Amazon(BaseProvider):
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
response_model=response_model,
max_tokens=DEFAULT_MAX_TOKENS,
max_tokens=self.DEFAULT_MAX_TOKENS,
**kwargs,
)
return response
def generate_text(self, prompt, *, llm_model, **kwargs):
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
@@ -88,13 +93,15 @@ class Amazon(BaseProvider):
response = self.client.messages.create(
model=llm_model or self.DEFAULT_MODEL,
messages=messages,
max_tokens=DEFAULT_MAX_TOKENS,
max_tokens=self.DEFAULT_MAX_TOKENS,
**kwargs,
)
return response.content[0].text
def generate_stream_text(self, prompt, *, llm_model, **kwargs) -> Iterator[str]:
def generate_stream_text(
self, prompt: str, *, llm_model: str, **kwargs
) -> Iterator[str]:
"""Generate streaming text using the Amazon API."""
# Prepare the messages.
+104 -21
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
from typing import TYPE_CHECKING, Any, Callable, Iterator, Type, TypeVar
import instructor
from pydantic import BaseModel
@@ -7,6 +7,7 @@ from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
from ._base_tools import BaseTool
if TYPE_CHECKING:
from ..models import Conversation, Message
@@ -14,20 +15,67 @@ if TYPE_CHECKING:
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "anthropic"
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class AnthropicTool(BaseTool):
    def get_response_schema(self) -> Any:
        # is_executed is a method: it must be called, otherwise the assert always passes
        assert self.is_executed(), f"Tool {self.name} was not executed."
        assert isinstance(
            self.tool_id, str
        ), f"Expected str for `tool_id` got {self.tool_id!r}"
        return {
            "type": "tool_result",
            "tool_use_id": self.tool_id,
            "content": self.function_result,
        }

    @logger
    def handle(self, response, messages) -> None:
        """Handle the tool execution result from an API response."""
        msg = {"role": "assistant", "content": []}
        tool_used = False
        for content in response.content:
            if content.type == "tool_use" and content.name == self.name:
                msg["content"].append(
                    {
                        "type": "tool_use",
                        "id": content.id,
                        "name": content.name,
                        "input": content.input,
                    }
                )
                # Function execution:
                self.function_result = str(self.raw_func(**content.input))
                self.tool_id = content.id
                tool_used = True
            elif content.type == "text":
                msg["content"].append({"type": "text", "text": content.text})
        if tool_used:
            # Echo the assistant turn, then answer it with the tool result
            messages.append(msg)
            messages.append(
                {"role": "user", "content": [self.get_response_schema()]}
            )

    def get_input_schema(self):
        return {
            "name": self.name,
            "description": self.description,
            "input_schema": {
                "type": "object",
                "properties": self.get_properties_schema(),
                "required": self.required,
            },
        }
class Anthropic(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
NAME = "anthropic"
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
supports_streaming = True
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
self.api_key = api_key or settings.get_api_key(self.NAME)
@cached_property
def client(self):
@@ -49,30 +97,60 @@ class Anthropic(BaseProvider):
return instructor.from_anthropic(self.client)
@logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
def send_conversation(
self,
conversation: "Conversation",
tools: list[Callable | BaseTool] | None = None,
**kwargs,
) -> "Message":
"""Send a conversation to the Anthropic API."""
from ..models import Message
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
# Format messages from conversation
formatted_messages = [
{"role": msg.role, "content": msg.text}
for msg in conversation.messages
]
response = self.client.messages.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
# Set up tools if provided
converted_tools = self.make_tools(tools)
tools_config = (
{"tools": [t.get_input_schema() for t in converted_tools]}
if tools is not None
else {}
)
# Get the response content from the Anthropic response
assistant_message = response.content[0].text
# Merge all kwargs
request_kwargs = {
**self.DEFAULT_KWARGS,
**kwargs,
**tools_config,
"model": conversation.llm_model or self.DEFAULT_MODEL,
"messages": formatted_messages,
}
# Make initial API call
response = self.client.messages.create(**request_kwargs)
# Handle tool responses if needed
while response.content[-1].type != "text":
# Continue handling tools if the LLM is making
# multiple subsequent/sequential tool calls
for tool in converted_tools:
tool.handle(response, formatted_messages)
if tool.is_executed():
response = self.client.messages.create(**request_kwargs)
# Resetting the tool results in case this tool gets used again
tool.reset_result()
final_message = response.content[-1].text
# Create and return a properly formatted Message instance
return Message(
role="assistant",
text=assistant_message,
text=final_message,
raw=response,
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
llm_provider=PROVIDER_NAME,
llm_provider=self.NAME,
)
@logger
@@ -127,3 +205,8 @@ class Anthropic(BaseProvider):
# Yield each chunk of text from the stream.
for chunk in stream.text_stream:
yield chunk
@cached_property
def tool(self) -> Type[BaseTool]:
"""The tool implementation for Anthropic."""
return AnthropicTool
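The `AnthropicTool.handle` change above replays the model's `tool_use` block as an assistant turn, then appends a user turn carrying the `tool_result`. A minimal offline sketch of that message shape follows. It assumes a stand-in response built with `SimpleNamespace` instead of a real SDK object. The helper `handle_tool_use` and the `get_location` function are hypothetical names for illustration only. Unlike the diff, it collects one result per matching block.

```python
from types import SimpleNamespace


def get_location() -> str:
    """Hypothetical tool function used only for this illustration."""
    return "San Francisco, CA"


def handle_tool_use(response, messages, name, func) -> bool:
    """Append the assistant tool_use turn and the user tool_result turn.

    Returns True if the tool named `name` was executed at least once.
    """
    assistant_msg = {"role": "assistant", "content": []}
    results = []
    for block in response.content:
        if block.type == "tool_use" and block.name == name:
            # Echo the model's tool request back in the assistant turn.
            assistant_msg["content"].append(
                {
                    "type": "tool_use",
                    "id": block.id,
                    "name": block.name,
                    "input": block.input,
                }
            )
            # Run the function and pair the result with the request id.
            results.append(
                {
                    "type": "tool_result",
                    "tool_use_id": block.id,
                    "content": str(func(**block.input)),
                }
            )
        elif block.type == "text":
            assistant_msg["content"].append({"type": "text", "text": block.text})
    if results:
        messages.append(assistant_msg)
        messages.append({"role": "user", "content": results})
    return bool(results)


# Stand-in for an API response (assumption: real ones come from the SDK).
fake_response = SimpleNamespace(
    content=[
        SimpleNamespace(type="text", text="Let me check."),
        SimpleNamespace(type="tool_use", id="toolu_1", name="get_location", input={}),
    ]
)
messages = [{"role": "user", "content": "Where am I?"}]
executed = handle_tool_use(fake_response, messages, "get_location", get_location)
```

After this call, `messages` holds the original user turn, the replayed assistant turn, and a user turn with the tool result, which is what the follow-up request would send.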
+27
@@ -0,0 +1,27 @@
import os
from functools import cached_property
from .openai import OpenAI
class Deepseek(OpenAI):
NAME = "deepseek"
DEFAULT_MODEL = "deepseek-chat"
def __init__(self, api_key: str | None = None):
api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
super().__init__(api_key=api_key)
self.endpoint = "https://api.deepseek.com/v1"
@cached_property
def client(self):
"""The raw OpenAI client."""
if not self.api_key:
raise ValueError("DEEPSEEK API key is required")
try:
import openai as oa
except ImportError as exc:
raise ImportError(
"Please install the `openai` package: `pip install openai`"
) from exc
return oa.OpenAI(api_key=self.api_key, base_url=self.endpoint)
+6 -10
@@ -2,7 +2,7 @@
# It is not currently working as desired.
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
import instructor
from pydantic import BaseModel
@@ -17,18 +17,14 @@ if TYPE_CHECKING:
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "gemini"
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
class Gemini(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
NAME = "gemini"
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
supports_streaming = True
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
self.model_name = DEFAULT_MODEL
self.api_key = api_key or settings.get_api_key(self.NAME)
self.model_name = self.DEFAULT_MODEL
def set_model(self, model_name: str):
self.model_name = model_name
@@ -76,7 +72,7 @@ class Gemini(BaseProvider):
text=response.text,
raw=response,
llm_model=self.model_name,
llm_provider=PROVIDER_NAME,
llm_provider=self.NAME,
)
@logger
+121 -20
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar
import instructor
from pydantic import BaseModel
@@ -7,6 +7,7 @@ from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
from ._base_tools import BaseTool
if TYPE_CHECKING:
from ..models import Conversation, Message
@@ -14,20 +15,87 @@ if TYPE_CHECKING:
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class GroqTool(BaseTool):
def get_response_schema(self):
assert self.is_executed, f"Tool {self.name} was not executed."
assert isinstance(
self.tool_id, str
), f"Expected str for `tool_id` got {self.tool_id!r}"
return {
"role": "tool",
"tool_call_id": self.tool_id,
"content": self.function_result,
}
@logger
def handle(self, response, messages) -> None:
"""Handle the tool execution result from an API response."""
tool_used = False
# Get the message from the response
assistant_message = response.choices[0].message
# Check if there's a tool call
if assistant_message.tool_calls:
tool_call = assistant_message.tool_calls[
0
] # Get the first tool call
if tool_call.function.name == self.name:
# Execute the function
import json
function_args = json.loads(tool_call.function.arguments)
self.function_result = str(self.raw_func(**function_args))
self.tool_id = tool_call.id
tool_used = True
# Add assistant's message with tool call
messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
}
],
}
)
if tool_used:
# Add tool response message
messages.append(self.get_response_schema())
def get_input_schema(self):
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": self.get_properties_schema(),
"required": self.required,
"additionalProperties": False,
},
},
}
class Groq(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
supports_streaming = True
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
self.api_key = api_key or settings.get_api_key(self.NAME)
@cached_property
def client(self):
@@ -51,31 +119,59 @@ class Groq(BaseProvider):
def send_conversation(
self,
conversation: "Conversation",
tools: list[Callable | BaseTool] | None = None,
**kwargs,
) -> "Message":
"""Send a conversation to the Groq API."""
from ..models import Message
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
# Format messages from conversation
formatted_messages = [
{"role": msg.role, "content": msg.text}
for msg in conversation.messages
]
response = self.client.chat.completions.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
# Set up tools if provided
converted_tools = self.make_tools(tools)
tools_config = (
[t.get_input_schema() for t in converted_tools] if tools else None
)
# Get the response content from the Groq response
assistant_message = response.choices[0].message
# Merge all kwargs
request_kwargs = {
**self.DEFAULT_KWARGS,
**kwargs,
"model": conversation.llm_model or self.DEFAULT_MODEL,
"messages": formatted_messages,
}
if tools_config:
request_kwargs["tools"] = tools_config
# Make initial API call
response = self.client.chat.completions.create(**request_kwargs)
# Handle tool responses if needed
while response.choices[0].message.tool_calls:
print(response)
# Handle each tool call
for tool in converted_tools:
tool.handle(response, formatted_messages)
if tool.is_executed():
# Make another API call with the updated messages
response = self.client.chat.completions.create(
**request_kwargs
)
tool.reset_result()
final_message = response.choices[0].message.content
# Create and return a properly formatted Message instance
return Message(
role="assistant",
text=assistant_message.content or "",
text=final_message or "",
raw=response,
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
llm_provider=PROVIDER_NAME,
llm_provider=self.NAME,
)
@logger
@@ -141,3 +237,8 @@ class Groq(BaseProvider):
raise RuntimeError(
f"Failed to generate streaming text with Groq API: {e}"
) from e
@cached_property
def tool(self) -> Type[BaseTool]:
"""The tool implementation for Groq."""
return GroqTool
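The `GroqTool.handle` change above follows the OpenAI-style chat-completions convention: decode the JSON `arguments`, run the function, append the assistant turn with `tool_calls`, then append a `role: "tool"` turn. A minimal offline sketch of that flow follows. It assumes a stand-in response built with `SimpleNamespace`. The helper name `handle_tool_call` is hypothetical, and like the diff it only looks at the first tool call.

```python
import json
from types import SimpleNamespace


def get_weather(location: str, unit: str = "celsius") -> str:
    """Hypothetical tool function used only for this illustration."""
    return f"42 {unit}"


def handle_tool_call(response, messages, name, func) -> bool:
    """Execute the first tool call if it targets `name`; return True if run."""
    message = response.choices[0].message
    if not message.tool_calls:
        return False
    tool_call = message.tool_calls[0]
    if tool_call.function.name != name:
        return False
    # Arguments arrive as a JSON string, not a dict.
    args = json.loads(tool_call.function.arguments)
    result = str(func(**args))
    # Replay the assistant's tool request...
    messages.append(
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    },
                }
            ],
        }
    )
    # ...followed by the tool's answer, linked by tool_call_id.
    messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": result})
    return True


# Stand-in for an API response (assumption: real ones come from the SDK).
fake_response = SimpleNamespace(
    choices=[
        SimpleNamespace(
            message=SimpleNamespace(
                content=None,
                tool_calls=[
                    SimpleNamespace(
                        id="call_1",
                        function=SimpleNamespace(
                            name="get_weather",
                            arguments='{"location": "San Francisco, CA"}',
                        ),
                    )
                ],
            )
        )
    ]
)
messages = [{"role": "user", "content": "What is the weather in San Francisco?"}]
executed = handle_tool_call(fake_response, messages, "get_weather", get_weather)
```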
+26 -31
@@ -1,8 +1,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
import instructor
from openai import OpenAI
from pydantic import BaseModel
from ..logging import logger
@@ -15,17 +14,11 @@ if TYPE_CHECKING:
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "ollama"
DEFAULT_MODEL = "llama3.2"
DEFAULT_TIMEOUT = 60
DEFAULT_KWARGS = {}
class Ollama(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
TIMEOUT = DEFAULT_TIMEOUT
NAME = "ollama"
DEFAULT_MODEL = "llama3.2"
DEFAULT_TIMEOUT = 60
DEFAULT_KWARGS = {}
supports_streaming = True
def __init__(self, host_url: str | None = None):
@@ -37,21 +30,18 @@ class Ollama(BaseProvider):
if not self.host_url:
raise ValueError("No ollama host url provided")
try:
import ollama as ol
import openai
except ImportError as exc:
raise ImportError(
"Please install the `ollama` package: `pip install ollama`"
"Please install the `openai` package: `pip install openai`"
) from exc
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
return openai.OpenAI(base_url=f"{self.host_url}/v1", api_key="ollama")
@cached_property
def structured_client(self) -> instructor.Instructor:
"""A client patched with Instructor."""
return instructor.from_openai(
OpenAI(
base_url=f"{self.host_url}/v1",
api_key="ollama",
),
self.client,
mode=instructor.Mode.JSON,
)
@@ -63,20 +53,24 @@ class Ollama(BaseProvider):
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
response = self.client.chat(
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
)
assistant_message = response.get("message")
request_kwargs = {
**self.DEFAULT_KWARGS,
**kwargs,
"model": conversation.llm_model or self.DEFAULT_MODEL,
"messages": messages,
}
response = self.client.chat.completions.create(**request_kwargs)
assistant_message = response.choices[0].message
# Create and return a properly formatted Message instance
return Message(
role="assistant",
text=assistant_message.get("content"),
text=assistant_message.content or "",
raw=response,
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
llm_provider=PROVIDER_NAME,
llm_provider=self.NAME,
)
@logger
@@ -110,13 +104,13 @@ class Ollama(BaseProvider):
{"role": "user", "content": prompt},
]
response = self.client.chat(
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response.get("message", {}).get("content", "")
return response.choices[0].message.content
@logger
def generate_stream_text(
@@ -127,7 +121,7 @@ class Ollama(BaseProvider):
{"role": "user", "content": prompt},
]
response = self.client.chat(
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
stream=True,
@@ -136,4 +130,5 @@ class Ollama(BaseProvider):
# Iterate over the response and yield the content.
for chunk in response:
yield chunk["message"]["content"]
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
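The streaming change above switches Ollama to OpenAI-style chunks, where each chunk carries a `delta` whose `content` may be `None` (for example on role-only or final chunks). A minimal offline sketch of that filtering follows. It assumes stand-in chunks built with `SimpleNamespace`, and the names `iter_text` and `fake_chunk` are hypothetical.

```python
from types import SimpleNamespace
from typing import Iterable, Iterator, Optional


def iter_text(chunks: Iterable) -> Iterator[str]:
    """Yield only the non-empty text deltas from OpenAI-style stream chunks."""
    for chunk in chunks:
        delta = chunk.choices[0].delta.content
        if delta is not None:
            yield delta


def fake_chunk(text: Optional[str]) -> SimpleNamespace:
    """Build a stand-in chunk with the same attribute path as a real one."""
    return SimpleNamespace(
        choices=[SimpleNamespace(delta=SimpleNamespace(content=text))]
    )


chunks = [fake_chunk("Hel"), fake_chunk(None), fake_chunk("lo")]
text = "".join(iter_text(chunks))
```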
+161 -26
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar
import instructor
from pydantic import BaseModel
@@ -7,26 +7,94 @@ from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
from ._base_tools import BaseTool
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "openai"
DEFAULT_MODEL = "gpt-4o-mini"
DEFAULT_MAX_TOKENS = None
DEFAULT_KWARGS = {}
class OpenAITool(BaseTool):
def get_response_schema(self):
assert self.is_executed, f"Tool {self.name} was not executed."
assert isinstance(
self.tool_id, str
), f"Expected str for `tool_id` got {self.tool_id!r}"
return {
"role": "tool",
"tool_call_id": self.tool_id,
"content": self.function_result,
}
@logger
def handle(self, response, messages) -> None:
"""Handle the tool execution result from an API response."""
tool_used = False
# Get the message from the response
assistant_message = response.choices[0].message
# Check if there's a tool call
if assistant_message.tool_calls:
tool_call = assistant_message.tool_calls[0] # Get the first tool call
if tool_call.function.name == self.name:
# Execute the function
import json
function_args = json.loads(tool_call.function.arguments)
self.function_result = str(self.raw_func(**function_args))
self.tool_id = tool_call.id
tool_used = True
# Add assistant's message with tool call
messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
}
],
}
)
if tool_used:
# Add tool response message
messages.append(self.get_response_schema())
def get_input_schema(self):
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": self.get_properties_schema(),
"required": self.required,
"additionalProperties": False,
},
},
}
class OpenAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
NAME = "openai"
DEFAULT_MODEL = "gpt-4o-mini"
DEFAULT_MAX_TOKENS = None
DEFAULT_KWARGS = {}
supports_streaming = True
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
self.api_key = api_key or settings.get_api_key(self.NAME)
@cached_property
def client(self):
@@ -47,30 +115,58 @@ class OpenAI(BaseProvider):
return instructor.from_openai(self.client)
@logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
def send_conversation(
self,
conversation: "Conversation",
tools: list[Callable | BaseTool] | None = None,
**kwargs,
) -> "Message":
"""Send a conversation to the OpenAI API."""
from ..models import Message
messages = [
# Format messages from conversation
formatted_messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
response = self.client.chat.completions.create(
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
# Set up tools if provided
converted_tools = self.make_tools(tools)
tools_config = (
[t.get_input_schema() for t in converted_tools] if tools else None
)
# Get the response content from the OpenAI response
assistant_message = response.choices[0].message
# Merge all kwargs
request_kwargs = {
**self.DEFAULT_KWARGS,
**kwargs,
"model": conversation.llm_model or self.DEFAULT_MODEL,
"messages": formatted_messages,
}
if tools_config:
request_kwargs["tools"] = tools_config
# Make initial API call
response = self.client.chat.completions.create(**request_kwargs)
# Handle tool responses if needed
while response.choices[0].message.tool_calls:
# Handle each tool call
for tool in converted_tools:
tool.handle(response, formatted_messages)
if tool.is_executed():
# Make another API call with the updated messages
response = self.client.chat.completions.create(**request_kwargs)
tool.reset_result()
final_message = response.choices[0].message.content
# Create and return a properly formatted Message instance
return Message(
role="assistant",
text=assistant_message.content or "",
text=final_message or "",
raw=response,
llm_model=conversation.llm_model or DEFAULT_MODEL,
llm_provider=PROVIDER_NAME,
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
llm_provider=self.NAME,
)
@logger
@@ -80,13 +176,21 @@ class OpenAI(BaseProvider):
response_model: Type[T],
*,
llm_model: str | None = None,
image_url: str | None = None,
**kwargs,
) -> T:
"""Get a structured response from the OpenAI API."""
# Build the user message from the prompt
messages = [
{"role": "user", "content": prompt},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
# Add an image (URL or base64-encoded) to the message if provided.
if image_url:
messages[0]["content"].append(
{"type": "image_url", "image_url": {"url": image_url}}
)
response = self.structured_client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
@@ -96,11 +200,25 @@ class OpenAI(BaseProvider):
return response_model.model_validate(response)
@logger
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
def generate_text(
self,
prompt: str,
*,
llm_model: str | None = None,
image_url: str | None = None,
**kwargs,
):
"""Generate text using the OpenAI API."""
messages = [
{"role": "user", "content": prompt},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
# Add an image (URL or base64-encoded) to the message if provided.
if image_url:
messages[0]["content"].append(
{"type": "image_url", "image_url": {"url": image_url}}
)
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
@@ -110,15 +228,27 @@ class OpenAI(BaseProvider):
@logger
def generate_stream_text(
self, prompt: str, *, llm_model: str | None = None, **kwargs
self,
prompt: str,
*,
llm_model: str | None = None,
image_url: str | None = None,
**kwargs,
) -> Iterator[str]:
"""Generate streaming text using the OpenAI API.
Yields chunks of text as they are generated by the model.
"""
messages = [
{"role": "user", "content": prompt},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
# Add an image (URL or base64-encoded) to the message if provided.
if image_url:
messages[0]["content"].append(
{"type": "image_url", "image_url": {"url": image_url}}
)
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
@@ -129,3 +259,8 @@ class OpenAI(BaseProvider):
for chunk in response:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
@cached_property
def tool(self) -> Type[BaseTool]:
"""The tool implementation for OpenAI."""
return OpenAITool
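The three OpenAI `generate_*` methods above repeat the same message construction: a list-form `content` with a text part, plus an optional `image_url` part. A minimal sketch of that construction as a single helper follows. The helper name `build_user_message` and the example URL are hypothetical and not part of the diff.

```python
from typing import Optional


def build_user_message(prompt: str, image_url: Optional[str] = None) -> list:
    """Return a one-message list in the multi-part content format."""
    content = [{"type": "text", "text": prompt}]
    if image_url:
        # The URL may be an http(s) link or a base64 data URL.
        content.append({"type": "image_url", "image_url": {"url": image_url}})
    return [{"role": "user", "content": content}]


text_only = build_user_message("Describe a cat.")
with_image = build_user_message(
    "What is in this picture?", image_url="https://example.com/cat.png"
)
```

Factoring the construction out this way would keep the text-only and image paths in one place. Whether the project wants such a helper is a design choice the diff does not address.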
+9 -14
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
import instructor
from pydantic import BaseModel
@@ -14,22 +14,17 @@ if TYPE_CHECKING:
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "xai"
DEFAULT_MODEL = "grok-beta"
BASE_URL = "https://api.x.ai/v1"
DEFAULT_MAX_TOKENS = 1000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class XAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
NAME = "xai"
DEFAULT_MODEL = "grok-beta"
DEFAULT_MAX_TOKENS = 1000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
BASE_URL = "https://api.x.ai/v1"
supports_streaming = True
supports_structured_responses = False
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
self.api_key = api_key or settings.get_api_key(self.NAME)
@cached_property
def client(self):
@@ -44,7 +39,7 @@ class XAI(BaseProvider):
) from exc
return oa.OpenAI(
api_key=self.api_key,
base_url=BASE_URL,
base_url=self.BASE_URL,
)
@cached_property
@@ -76,7 +71,7 @@ class XAI(BaseProvider):
text=assistant_message.content,
raw=response,
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
llm_provider=PROVIDER_NAME,
llm_provider=self.NAME,
)
@logger
+2 -1
@@ -14,8 +14,9 @@ class LoggingConfig(BaseSettings):
"""Enable logging for the application."""
# adding imports here to avoid forced dependencies
try:
import logfire
from logging import basicConfig
import logfire
except ImportError as e:
raise ImportError(
"To enable logging, please install logfire: `pip install logfire`"
+75 -2
@@ -1,8 +1,10 @@
import json
import pytest
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
import simplemind as sm
from simplemind.models import BasePlugin, Conversation
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
@pytest.mark.parametrize(
@@ -26,3 +28,74 @@ def test_generate_data(provider_cls):
assert isinstance(data.text, str)
assert len(data.text) > 0
@pytest.fixture
def sample_conversation():
"""Create a sample conversation for testing."""
conv = Conversation(llm_provider="openai")
conv.add_message(role="user", text="Hello!")
conv.add_message(role="assistant", text="Hi there!")
conv.add_message(role="user", text="How are you?")
return conv
@pytest.fixture
def temp_json_file(tmp_path):
"""Create a temporary file path for testing."""
return tmp_path / "conversation.json"
def test_save_conversation(sample_conversation, temp_json_file):
"""Test saving a conversation to a JSON file."""
sample_conversation.save(temp_json_file)
assert temp_json_file.exists()
with open(temp_json_file) as f:
saved_data = json.load(f)
assert "id" in saved_data
assert "messages" in saved_data
assert "llm_model" in saved_data
assert "llm_provider" in saved_data
assert len(saved_data["messages"]) == 3
assert saved_data["messages"][0]["text"] == "Hello!"
assert saved_data["messages"][1]["text"] == "Hi there!"
assert saved_data["messages"][2]["text"] == "How are you?"
def test_load_conversation(sample_conversation, temp_json_file):
"""Test loading a conversation from a JSON file."""
sample_conversation.save(temp_json_file)
loaded_conv = Conversation.load(temp_json_file)
assert loaded_conv.id == sample_conversation.id
assert loaded_conv.llm_model == sample_conversation.llm_model
assert loaded_conv.llm_provider == sample_conversation.llm_provider
assert len(loaded_conv.messages) == len(sample_conversation.messages)
for original_msg, loaded_msg in zip(
sample_conversation.messages, loaded_conv.messages
):
assert loaded_msg.role == original_msg.role
assert loaded_msg.text == original_msg.text
assert loaded_msg.meta == original_msg.meta
def test_save_load_with_plugins(sample_conversation, temp_json_file):
"""Test that plugins are properly excluded from serialization."""
# Create a dummy plugin
class DummyPlugin(BasePlugin):
def initialize_hook(self, conversation):
pass
sample_conversation.add_plugin(DummyPlugin())
sample_conversation.save(temp_json_file)
loaded_conv = Conversation.load(temp_json_file)
assert len(loaded_conv.plugins) == 0
+2 -3
@@ -1,9 +1,8 @@
import pytest
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
from pydantic import BaseModel
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
class ResponseModel(BaseModel):
result: int
+1 -1
@@ -1,6 +1,6 @@
import pytest
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
@pytest.mark.parametrize(
+118
@@ -0,0 +1,118 @@
from typing import Annotated, Literal
import pytest
from pydantic import Field
import simplemind as sm
from simplemind.providers import Anthropic, OpenAI
from simplemind.providers._base_tools import BaseTool
MODELS = [
Anthropic,
# Gemini,
OpenAI,
# Groq,
# Ollama,
# Amazon
]
def get_weather(
location: Annotated[
str, Field(description="The city and state, e.g. San Francisco, CA")
],
unit: Annotated[
Literal["celsius", "fahrenheit"],
Field(description="The unit of temperature, either 'celsius' or 'fahrenheit'"),
] = "celsius",
):
"""
Get the current weather in a given location
"""
return f"42 {unit}"
def get_location():
"""Get the current location"""
return "San Francisco,CA"
@pytest.mark.parametrize(
"provider_cls",
MODELS,
)
def test_single_tool_args(provider_cls):
conv = sm.create_conversation(
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
)
conv.add_message(text="What is the weather in San Francisco?")
data = conv.send(tools=[get_weather])
assert "42" in data.text
@pytest.mark.parametrize(
"provider_cls",
MODELS,
)
def test_single_tool_no_args(provider_cls):
conv = sm.create_conversation(
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
)
conv.add_message(text="What is my current location?")
data = conv.send(tools=[get_location])
assert "San Francisco" in data.text
@pytest.mark.parametrize(
"provider_cls",
MODELS,
)
def test_single_tool_partial(provider_cls):
conv = sm.create_conversation(
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
)
conv.add_message(text="Can you tell me the weather?")
conv.send(tools=[get_weather])
# Will answer something like:
"""
I can help you check the weather, but I need to know the location you're interested in.
Could you please provide a city and state (e.g., "Los Angeles, CA" or "New York, NY")
where you'd like to know the weather?
"""
conv.add_message(text="San Francisco, CA")
data = conv.send(tools=[get_weather])
assert "42" in data.text
@pytest.mark.parametrize(
"provider_cls",
MODELS,
)
def test_multiple_tools(provider_cls):
conv = sm.create_conversation(
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
)
conv.add_message(text="What is the weather at my current location?")
data = conv.send(tools=[get_location, get_weather])
assert "San Francisco" in data.text
assert "42" in data.text
@pytest.mark.parametrize(
"provider_cls",
MODELS,
)
def test_tool_decorator(provider_cls):
@sm.tool(llm_provider=provider_cls.NAME)
def exchange_rate(currency_pair: str) -> float:
return 7.9
assert isinstance(exchange_rate, BaseTool)
assert exchange_rate.name == "exchange_rate"
assert list(exchange_rate.properties.keys()) == ["currency_pair"]