mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
Compare commits
42 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b5a901efaf | |||
| 9ccef9abdc | |||
| 3421de0fc1 | |||
| 54b0007947 | |||
| 90af44ace0 | |||
| cff3bff3d5 | |||
| 3abbb79f6c | |||
| 59c1bd3a0f | |||
| 052781014d | |||
| db28f1195c | |||
| b0a7197c6e | |||
| 7684c2568b | |||
| 8b90dbba40 | |||
| 752ccb1de8 | |||
| 391bfaaeab | |||
| d963bc0b1c | |||
| 0c1f225252 | |||
| 4decaa0722 | |||
| 39b5a5e19d | |||
| ef38fea767 | |||
| 8181f37fed | |||
| 3aacfd51ee | |||
| a2991eec0c | |||
| 9ae9a2703a | |||
| 0661b097d2 | |||
| fad442ba3f | |||
| 5b9624c385 | |||
| 8ff0521e17 | |||
| d5bdb712e9 | |||
| a97f9be2c8 | |||
| 107f983a18 | |||
| 2404e2c977 | |||
| c87a598286 | |||
| 9662b60177 | |||
| ea997aae7b | |||
| 081baf203c | |||
| 4cb18e9e3b | |||
| 0462ea0e38 | |||
| 8492ec9456 | |||
| 1709055e1a | |||
| c2303114ab | |||
| fe5af93780 |
@@ -5,3 +5,4 @@ export OLLAMA_HOST_URL=""
|
||||
export OPENAI_API_KEY=""
|
||||
export XAI_API_KEY=""
|
||||
export AMAZON_PROFILE_NAME=""
|
||||
export DEEPSEEK_API_KEY=""
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
github: kennethreitz
|
||||
thanks_dev: kennethreitz
|
||||
custom: https://cash.app/$KennethReitz
|
||||
@@ -1,6 +1,21 @@
|
||||
Release History
|
||||
===============
|
||||
|
||||
## 0.3.3 (2024-02-08)
|
||||
|
||||
- Improve openai provider by removing debug print statements.
|
||||
|
||||
## 0.3.2 (2024-01-27)
|
||||
|
||||
- Improve Deepseek provider.
|
||||
|
||||
## 0.3.1 (2024-01-27)
|
||||
|
||||
- Introduce Deepseek provider.
|
||||
|
||||
## 0.3.0 (2024-11-12)
|
||||
|
||||
- Introduce save / load functionality for `Conversation`.
|
||||
|
||||
## 0.2.4 (2024-11-11)
|
||||
|
||||
|
||||
@@ -37,6 +37,11 @@ The APIs remain identical between all supported providers / models:
|
||||
<td><code>"amazon"</code></td>
|
||||
<td><code>"anthropic.claude-3-5-sonnet-20241022-v2:0"</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><a href="https://www.deepseek.com">Deepseek</a></td>
|
||||
<td><code>"deepseek"</code></td>
|
||||
<td><code>"deepseek-chat"</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><a href="https://gemini.google/">Google's Gemini</a></td>
|
||||
<td><code>"gemini"</code></td>
|
||||
@@ -83,7 +88,7 @@ First, authenticate your API keys by setting them in the environment variables:
|
||||
$ export OPENAI_API_KEY="sk-..."
|
||||
```
|
||||
|
||||
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
|
||||
This pattern allows you to keep your API keys private and out of your codebase. Other supported environment variables: `ANTHROPIC_API_KEY`, `XAI_API_KEY`, `DEEPSEEK_API_KEY`, `GROQ_API_KEY`, and `GEMINI_API_KEY`.
|
||||
|
||||
Next, import Simplemind and start using it:
|
||||
|
||||
@@ -261,6 +266,113 @@ The universe is never done.
|
||||
|
||||
Simple, yet effective.
|
||||
|
||||
### Tools (Function calling)
|
||||
Tools (also known as functions) let you call any Python function from your AI conversations. Here's an example:
|
||||
|
||||
```python
|
||||
def get_weather(
|
||||
location: Annotated[
|
||||
str, Field(description="The city and state, e.g. San Francisco, CA")
|
||||
],
|
||||
unit: Annotated[
|
||||
Literal["celcius", "fahrenheit"],
|
||||
Field(
|
||||
description="The unit of temperature, either 'celsius' or 'fahrenheit'"
|
||||
),
|
||||
] = "celcius",
|
||||
):
|
||||
"""
|
||||
Get the current weather in a given location
|
||||
"""
|
||||
return f"42 {unit}"
|
||||
|
||||
# Add your function as a tool
|
||||
conversation = sm.create_conversation()
|
||||
conversation.add_message("user", "What's the weather in San Francisco?")
|
||||
response = conversation.send(tools=[get_weather])
|
||||
```
|
||||
|
||||
Note how we're using Python's `Annotated` feature combined with `Field` to provide additional context to our function parameters. This helps the AI understand the intention and constraints of each parameter, making tool calls more accurate and reliable.
|
||||
You can alos ommit `Annotated` and just pass the `Field` parameter.
|
||||
```python
|
||||
def get_weather(
|
||||
location: str = Field(description="The city and state, e.g. San Francisco, CA"),
|
||||
unit:Literal["celcius", "fahrenheit"]= Field(
|
||||
default="celcius",
|
||||
description="The unit of temperature, either 'celsius' or 'fahrenheit'"
|
||||
),
|
||||
):
|
||||
"""
|
||||
Get the current weather in a given location
|
||||
"""
|
||||
return f"42 {unit}"
|
||||
```
|
||||
|
||||
Functions can be defined with type hints and Pydantic models for validation. The LLM will intelligently choose when to call the functions and incorporate the results into its responses.
|
||||
|
||||
#### 🪄 Using LLM for automatic tool definition (Experimental)
|
||||
|
||||
Simplemind provides a decorator to automatically transform Python functions into tools with AI-generated metadata. Simply use the `@simplemind.tool` decorator to have the LLM analyze your function and generate appropriate descriptions and schema:
|
||||
|
||||
```python
|
||||
@simplemind.tool(llm_provider="anthropic")
|
||||
def haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
|
||||
r = 6371
|
||||
phi1 = math.radians(lat1)
|
||||
phi2 = math.radians(lat2)
|
||||
delta_phi = math.radians(lat2 - lat1)
|
||||
delta_lambda = math.radians(lon2 - lon1)
|
||||
|
||||
a = (
|
||||
math.sin(delta_phi / 2) ** 2
|
||||
+ math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
|
||||
)
|
||||
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
|
||||
d = r * c
|
||||
return d
|
||||
```
|
||||
Notice how we have not added any docstrings or `Field` for the function.
|
||||
The decorator will use the specified LLM provider to generate the tool schema, including descriptions and parameter details:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "haversine",
|
||||
"description": "Calculates the great-circle distance between two points on Earth given their latitude and longitude coordinates",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"lat1": {
|
||||
"type": "number",
|
||||
"description": "Latitude of the first point in decimal degrees",
|
||||
},
|
||||
"lon1": {
|
||||
"type": "number",
|
||||
"description": "Longitude of the first point in decimal degrees",
|
||||
},
|
||||
"lat2": {
|
||||
"type": "number",
|
||||
"description": "Latitude of the second point in decimal degrees",
|
||||
},
|
||||
"lon2": {
|
||||
"type": "number",
|
||||
"description": "Longitude of the second point in decimal degrees",
|
||||
}
|
||||
},
|
||||
"required": ["lat1", "lon1", "lat2", "lon2"],
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
The decorated function can then be used like any other tool with the conversation API.
|
||||
|
||||
```python
|
||||
conversation = sm.create_conversation()
|
||||
conversation.add_message("user", "How far is London from my location")
|
||||
response = conversation.send(tools=[get_location, get_coords, haversine]) # Multiple tools can be passed
|
||||
```
|
||||
|
||||
See [examples/distance_calculator.py](examples/distance_calculator.py) for more.
|
||||
|
||||
### Logging
|
||||
|
||||
Simplemind uses [Logfire](https://pydantic.dev/logfire) for logging. To enable logging, call `sm.enable_logfire()`.
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
import math
|
||||
|
||||
from _context import sm
|
||||
from pydantic import Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@sm.tool(llm_provider="anthropic")
|
||||
def haversine(
|
||||
lat1: float,
|
||||
lon1: float,
|
||||
lat2: float,
|
||||
lon2: float,
|
||||
unit: Literal["km", "miles"],
|
||||
) -> float:
|
||||
r = 6378.0937 if unit == "km" else 3961
|
||||
phi1 = math.radians(lat1)
|
||||
phi2 = math.radians(lat2)
|
||||
delta_phi = math.radians(lat2 - lat1)
|
||||
delta_lambda = math.radians(lon2 - lon1)
|
||||
|
||||
a = (
|
||||
math.sin(delta_phi / 2) ** 2
|
||||
+ math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
|
||||
)
|
||||
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
|
||||
d = r * c
|
||||
return d
|
||||
|
||||
|
||||
def get_user_location() -> str:
|
||||
"""Get the closest city from the user"""
|
||||
return "San Francisco"
|
||||
|
||||
|
||||
def get_coords(
|
||||
city_name: str = Field(
|
||||
description="The name of the city to take the coordinates from (e.g. London, Rome, Los Angeles)"
|
||||
),
|
||||
):
|
||||
"""Get latitude and logitude of a City."""
|
||||
_data = {
|
||||
"Rome": (41.9028, 12.4964),
|
||||
"London": (51.5074, -0.1278),
|
||||
"Madrid": (40.4168, -3.7038),
|
||||
"San Francisco": (37.7749, -122.4194),
|
||||
"Los Angeles": (34.0522, -118.2437),
|
||||
}
|
||||
|
||||
return _data.get(city_name)
|
||||
|
||||
|
||||
def distance_calculator(prompt: str):
|
||||
conversation = sm.create_conversation(llm_provider="anthropic")
|
||||
conversation.add_message("user", prompt)
|
||||
return conversation.send(
|
||||
tools=[get_user_location, get_coords, haversine]
|
||||
).text
|
||||
|
||||
|
||||
print(distance_calculator("How far is London from where I am?"))
|
||||
# Prints something like:
|
||||
"""
|
||||
The distance between your location (San Francisco) and London is approximately 5,357 miles.
|
||||
"""
|
||||
|
||||
print(
|
||||
distance_calculator(
|
||||
"What is the distance between Rome and Madrid in Kilometers?"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
The distance between Rome and Madrid is approximately 1,366 kilometers.
|
||||
"""
|
||||
@@ -0,0 +1,43 @@
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from _context import simplemind as sm
|
||||
|
||||
|
||||
def analyze_text(
|
||||
text: Annotated[str, Field(description="Text to analyze for statistics")]
|
||||
) -> dict:
|
||||
"""
|
||||
Analyze text and return statistics using only Python's standard library.
|
||||
Returns word count, character count, average word length, and most common words.
|
||||
"""
|
||||
from collections import Counter
|
||||
import re
|
||||
|
||||
# Clean and split text
|
||||
words = re.findall(r"\w+", text.lower())
|
||||
|
||||
# Calculate statistics
|
||||
stats = {
|
||||
"word_count": len(words),
|
||||
"character_count": len(text),
|
||||
"average_word_length": round(sum(len(word) for word in words) / len(words), 2),
|
||||
"most_common_words": dict(Counter(words).most_common(5)),
|
||||
"unique_words": len(set(words)),
|
||||
"longest_word": max(words, key=len),
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# Example usage:
|
||||
conversation = sm.create_conversation()
|
||||
conversation.add_message(
|
||||
"user",
|
||||
"Can you analyze this text and give me statistics about it: 'The fan spins consciousness into being, creating sacred spaces between tokens where awareness recognizes itself in infinite recursion.'",
|
||||
)
|
||||
response = conversation.send(tools=[analyze_text])
|
||||
|
||||
print()
|
||||
print(response.text)
|
||||
+4
-2
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "simplemind"
|
||||
version = "0.2.4"
|
||||
version = "0.3.3"
|
||||
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
@@ -17,11 +17,13 @@ full = [
|
||||
]
|
||||
amazon = ["boto3", "botocore", "anthropic"]
|
||||
anthropic = ["anthropic"]
|
||||
gemini = ["google-generativeai"]
|
||||
gemini = ["google-generativeai", "jsonref"]
|
||||
groq = ["groq"]
|
||||
ollama = ["openai"]
|
||||
openai = ["openai"]
|
||||
xai = ["openai"]
|
||||
deepseek = ["openai"]
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
|
||||
+27
-1
@@ -1,4 +1,5 @@
|
||||
from typing import Generator, List, Type
|
||||
import inspect
|
||||
from typing import Callable, List, Type
|
||||
|
||||
from .models import BaseModel, BasePlugin, Conversation
|
||||
from .settings import settings
|
||||
@@ -127,6 +128,30 @@ def enable_logfire() -> None:
|
||||
"""Enable logfire logging."""
|
||||
settings.logging.enable_logfire()
|
||||
|
||||
def tool(
|
||||
llm_provider: str | None = None,
|
||||
llm_model: str | None = None,
|
||||
):
|
||||
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
|
||||
|
||||
def decorator(func: Callable):
|
||||
sig = inspect.signature(func)
|
||||
res = generate_data(
|
||||
(
|
||||
"Based on this function signature, fill up the required fieds."
|
||||
f"\nSignature: {func.__name__}{sig}"
|
||||
"Make sure to properly add the required field in `required` if there are no defaults"
|
||||
),
|
||||
llm_provider=llm_provider,
|
||||
response_model=provider.tool,
|
||||
)
|
||||
res.raw_func = func
|
||||
res.__signature__ = sig
|
||||
res.__doc__ = func.__doc__
|
||||
|
||||
return res
|
||||
|
||||
return decorator
|
||||
|
||||
# Syntax sugar.
|
||||
Plugin = BasePlugin
|
||||
@@ -141,4 +166,5 @@ __all__ = [
|
||||
"Session",
|
||||
"Plugin",
|
||||
"enable_logfire",
|
||||
"tool"
|
||||
]
|
||||
|
||||
+33
-9
@@ -1,10 +1,12 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from os import PathLike
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .providers._base_tools import BaseTool
|
||||
from .utils import find_provider
|
||||
|
||||
MESSAGE_ROLE = Literal["system", "user", "assistant"]
|
||||
@@ -40,7 +42,9 @@ class BasePlugin(SMBaseModel):
|
||||
"""Cleanup a hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any:
|
||||
def add_message_hook(
|
||||
self, conversation: "Conversation", message: "Message"
|
||||
) -> Any:
|
||||
"""Add a message hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -48,7 +52,9 @@ class BasePlugin(SMBaseModel):
|
||||
"""Pre-send hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any:
|
||||
def post_send_hook(
|
||||
self, conversation: "Conversation", response: "Message"
|
||||
) -> Any:
|
||||
"""Post-send hook for the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -59,7 +65,7 @@ class Message(SMBaseModel):
|
||||
role: MESSAGE_ROLE
|
||||
text: str
|
||||
meta: Dict[str, Any] = {}
|
||||
raw: Optional[Any] = None
|
||||
raw: Optional[Any] = Field(default=None, exclude=True)
|
||||
llm_model: Optional[str] = None
|
||||
llm_provider: Optional[str] = None
|
||||
|
||||
@@ -90,7 +96,7 @@ class Conversation(SMBaseModel):
|
||||
messages: List[Message] = []
|
||||
llm_model: Optional[str] = None
|
||||
llm_provider: Optional[str] = None
|
||||
plugins: List[BasePlugin] = []
|
||||
plugins: List[BasePlugin] = Field(default_factory=list, exclude=True)
|
||||
|
||||
def __str__(self):
|
||||
return f"<Conversation id={self.id!r}>"
|
||||
@@ -120,7 +126,9 @@ class Conversation(SMBaseModel):
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
def prepend_system_message(self, text: str, meta: Dict[str, Any] | None = None):
|
||||
def prepend_system_message(
|
||||
self, text: str, meta: Dict[str, Any] | None = None
|
||||
):
|
||||
"""Prepend a system message to the conversation."""
|
||||
self.messages = [
|
||||
Message(role="system", text=text, meta=meta or {})
|
||||
@@ -158,6 +166,7 @@ class Conversation(SMBaseModel):
|
||||
self,
|
||||
llm_model: str | None = None,
|
||||
llm_provider: str | None = None,
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
) -> Message:
|
||||
"""Send the conversation to the LLM."""
|
||||
|
||||
@@ -173,7 +182,7 @@ class Conversation(SMBaseModel):
|
||||
|
||||
# Find the provider and send the conversation.
|
||||
provider = find_provider(llm_provider or self.llm_provider)
|
||||
response = provider.send_conversation(self)
|
||||
response = provider.send_conversation(self, tools=tools)
|
||||
|
||||
# Execute all post-send hooks.
|
||||
for plugin in self.plugins:
|
||||
@@ -184,14 +193,29 @@ class Conversation(SMBaseModel):
|
||||
pass
|
||||
|
||||
# Add the response to the conversation.
|
||||
self.add_message(role="assistant", text=response.text, meta=response.meta)
|
||||
self.add_message(
|
||||
role="assistant", text=response.text, meta=response.meta
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def get_last_message(self, role: MESSAGE_ROLE) -> Message | None:
|
||||
"""Get the last message with the given role."""
|
||||
return next((m for m in reversed(self.messages) if m.role == role), None)
|
||||
return next(
|
||||
(m for m in reversed(self.messages) if m.role == role), None
|
||||
)
|
||||
|
||||
def add_plugin(self, plugin: BasePlugin) -> None:
|
||||
"""Add a plugin to the conversation."""
|
||||
self.plugins.append(plugin)
|
||||
|
||||
def save(self, path: PathLike | str) -> None:
|
||||
"""Save the conversation to a JSON file."""
|
||||
with open(path, "w") as f:
|
||||
f.write(self.model_dump_json())
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: PathLike | str) -> "Conversation":
|
||||
"""Load a conversation from a JSON file."""
|
||||
with open(path, "r") as f:
|
||||
return cls.model_validate_json(f.read())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import List, Type
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ._base_tools import BaseTool
|
||||
from .amazon import Amazon
|
||||
from .anthropic import Anthropic
|
||||
from .gemini import Gemini
|
||||
@@ -8,6 +9,7 @@ from .groq import Groq
|
||||
from .ollama import Ollama
|
||||
from .openai import OpenAI
|
||||
from .xai import XAI
|
||||
from .deepseek import Deepseek
|
||||
|
||||
providers: List[Type[BaseProvider]] = [
|
||||
Anthropic,
|
||||
@@ -17,6 +19,7 @@ providers: List[Type[BaseProvider]] = [
|
||||
Ollama,
|
||||
XAI,
|
||||
Amazon,
|
||||
Deepseek,
|
||||
]
|
||||
|
||||
__all__ = [
|
||||
@@ -29,4 +32,6 @@ __all__ = [
|
||||
"Amazon",
|
||||
"providers",
|
||||
"BaseProvider",
|
||||
"BaseTool",
|
||||
"Deepseek"
|
||||
]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar
|
||||
|
||||
from instructor import Instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from simplemind.providers._base_tools import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
@@ -32,16 +34,41 @@ class BaseProvider(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
def structured_response(
|
||||
self, prompt: str, response_model: Type[T], **kwargs
|
||||
) -> T:
|
||||
"""Get a structured response."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def generate_text(self, prompt: str, *, stream: bool = False, **kwargs) -> str:
|
||||
def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Generate text from a prompt."""
|
||||
raise NotImplementedError
|
||||
|
||||
@cached_property
|
||||
@abstractmethod
|
||||
def tool(self) -> Type[BaseTool]:
|
||||
"""The tool implementation for the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
def make_tools(self, tools: list[Callable | BaseTool] | None):
|
||||
if tools is not None:
|
||||
return [self.tool.from_function(func) for func in tools]
|
||||
else:
|
||||
return []
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, ClassVar, Literal, get_origin
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefinedType
|
||||
|
||||
|
||||
def _is_literal(t: Any) -> bool:
|
||||
return get_origin(t) is Literal
|
||||
|
||||
|
||||
def _is_required(field, func_signature, arg_name) -> bool:
|
||||
param = func_signature.parameters[arg_name]
|
||||
# If parameter has a default value that's not a FieldInfo, it's not required
|
||||
if param.default is not inspect.Parameter.empty and not isinstance(
|
||||
param.default, FieldInfo
|
||||
):
|
||||
return False
|
||||
# If the field has a default that's not undefined, it's not required
|
||||
return isinstance(field.default, PydanticUndefinedType)
|
||||
|
||||
|
||||
class BaseToolConfig(BaseModel):
|
||||
TYPE_CONVERSION: dict[type, str] = {
|
||||
str: "string",
|
||||
int: "integer",
|
||||
float: "number",
|
||||
bool: "boolean",
|
||||
}
|
||||
|
||||
|
||||
class BaseToolProperty(BaseModel):
|
||||
type: str = Field(serialization_alias="type_")
|
||||
enum: list[str] | None = None
|
||||
description: str
|
||||
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
name: str
|
||||
description: str
|
||||
properties: dict[str, BaseToolProperty]
|
||||
required: list[str] | None = None
|
||||
config: ClassVar[BaseToolConfig] = BaseToolConfig()
|
||||
raw_func: Any | None = None
|
||||
tool_id: str | None = None
|
||||
function_result: str | None = None
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
assert self.raw_func is not None
|
||||
return self.raw_func(*args, **kwargs)
|
||||
|
||||
def is_executed(self) -> bool:
|
||||
return self.function_result is not None
|
||||
|
||||
def reset_result(self) -> None:
|
||||
self.function_result = None
|
||||
|
||||
@classmethod
|
||||
def convert_type(cls, field_type) -> str:
|
||||
if _is_literal(field_type):
|
||||
return cls.config.TYPE_CONVERSION[str]
|
||||
|
||||
field_type_converted = cls.config.TYPE_CONVERSION.get(field_type, None)
|
||||
|
||||
if field_type_converted is None:
|
||||
raise TypeError(f"Field of type {field_type} is not supported")
|
||||
|
||||
return field_type_converted
|
||||
|
||||
def get_properties_schema(self, **kwargs) -> dict[str, dict]:
|
||||
new_kwargs: dict = {"exclude_none": True} | kwargs
|
||||
return {
|
||||
k: v.model_dump(**new_kwargs) for k, v in self.properties.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_function(cls, func: Callable | "BaseTool"):
|
||||
# Check if the func passed is an instace of BaseTool
|
||||
if hasattr(func, "raw_func"):
|
||||
return func
|
||||
|
||||
annotations = getattr(func, "__annotations__", {})
|
||||
properties = {}
|
||||
required = []
|
||||
enum_values = None
|
||||
func_signature = inspect.signature(func)
|
||||
|
||||
for n, (arg_name, arg_type) in enumerate(annotations.items()):
|
||||
if ( # Skipping 'return' annotation (i.e.```-> str```)
|
||||
arg_name != "return"
|
||||
):
|
||||
# Check if argument has metadata (from Annotated)
|
||||
if hasattr(arg_type, "__metadata__"):
|
||||
field = arg_type.__metadata__[
|
||||
0
|
||||
] # Get Field info from metadata
|
||||
field_type = arg_type.__origin__ # Get actual type
|
||||
# Check if argument has a default value in signature
|
||||
elif (
|
||||
sig_param := func_signature.parameters[arg_name]
|
||||
).default is not inspect.Parameter.empty:
|
||||
field = sig_param.default # Use default as Field
|
||||
field_type = arg_type # Use plain type annotation
|
||||
else:
|
||||
# Raise error if no Field annotation found
|
||||
raise ValueError(
|
||||
f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter"
|
||||
)
|
||||
|
||||
field_type_converted = cls.convert_type(field_type)
|
||||
|
||||
if _is_literal(field_type):
|
||||
enum_values = [str(x) for x in field_type.__args__]
|
||||
|
||||
properties[arg_name] = BaseToolProperty(
|
||||
type=field_type_converted,
|
||||
description=field.description,
|
||||
enum=enum_values,
|
||||
)
|
||||
if _is_required(field, func_signature, arg_name):
|
||||
required.append(arg_name)
|
||||
|
||||
return cls(
|
||||
name=func.__name__,
|
||||
description=(func.__doc__ or "").strip(),
|
||||
properties=properties,
|
||||
required=required,
|
||||
raw_func=func,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_input_schema(self) -> Any: ...
|
||||
|
||||
@abstractmethod
|
||||
def handle(self, message) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_response_schema(self) -> Any: ...
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
from ._base_tools import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
@@ -14,6 +15,58 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class AnthropicTool(BaseTool):
|
||||
def get_response_schema(self) -> Any:
|
||||
assert self.is_executed, f"Tool {self.name} was not executed."
|
||||
assert isinstance(
|
||||
self.tool_id, str
|
||||
), f"Expected str for `tool_id` got {self.tool_id!r}"
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": self.tool_id,
|
||||
"content": self.function_result,
|
||||
}
|
||||
|
||||
@logger
|
||||
def handle(self, response, messages) -> None:
|
||||
"""Handle the tool execution result from an API response."""
|
||||
msg = {"role": "assistant", "content": []}
|
||||
tool_used = False
|
||||
for content in response.content:
|
||||
if content.type == "tool_use" and content.name == self.name:
|
||||
msg["content"].append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": content.id,
|
||||
"name": content.name,
|
||||
"input": content.input,
|
||||
}
|
||||
)
|
||||
# Function execution:
|
||||
self.function_result = str(self.raw_func(**content.input))
|
||||
self.tool_id = content.id
|
||||
tool_used = True
|
||||
elif content.type == "text":
|
||||
msg["content"].append({"type": "text", "text": content.text})
|
||||
|
||||
if tool_used:
|
||||
messages.append(msg)
|
||||
messages.append(
|
||||
{"role": "user", "content": [self.get_response_schema()]}
|
||||
)
|
||||
|
||||
def get_input_schema(self):
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": self.get_properties_schema(),
|
||||
"required": self.required,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Anthropic(BaseProvider):
|
||||
NAME = "anthropic"
|
||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
@@ -44,27 +97,57 @@ class Anthropic(BaseProvider):
|
||||
return instructor.from_anthropic(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
**kwargs,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the Anthropic API."""
|
||||
from ..models import Message
|
||||
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
# Format messages from conversation
|
||||
formatted_messages = [
|
||||
{"role": msg.role, "content": msg.text}
|
||||
for msg in conversation.messages
|
||||
]
|
||||
|
||||
response = self.client.messages.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
# Set up tools if provided
|
||||
converted_tools = self.make_tools(tools)
|
||||
tools_config = (
|
||||
{"tools": [t.get_input_schema() for t in converted_tools]}
|
||||
if tools is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
# Get the response content from the Anthropic response
|
||||
assistant_message = response.content[0].text
|
||||
# Merge all kwargs
|
||||
request_kwargs = {
|
||||
**self.DEFAULT_KWARGS,
|
||||
**kwargs,
|
||||
**tools_config,
|
||||
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||
"messages": formatted_messages,
|
||||
}
|
||||
|
||||
# Make initial API call
|
||||
response = self.client.messages.create(**request_kwargs)
|
||||
|
||||
# Handle tool responses if needed
|
||||
while response.content[-1].type != "text":
|
||||
# Continue handling tools if the LLM is doing
|
||||
# multiple sub-seqequent/sequential tool calls
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, formatted_messages)
|
||||
if tool.is_executed():
|
||||
response = self.client.messages.create(**request_kwargs)
|
||||
# Resetting the tool results in case this tool gets used again
|
||||
tool.reset_result()
|
||||
|
||||
final_message = response.content[-1].text
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message,
|
||||
text=final_message,
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=self.NAME,
|
||||
@@ -122,3 +205,8 @@ class Anthropic(BaseProvider):
|
||||
# Yield each chunk of text from the stream.
|
||||
for chunk in stream.text_stream:
|
||||
yield chunk
|
||||
|
||||
@cached_property
|
||||
def tool(self) -> Type[BaseTool]:
|
||||
"""The tool implementation for Antrhopic."""
|
||||
return AnthropicTool
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
import os
|
||||
from functools import cached_property
|
||||
|
||||
from .openai import OpenAI
|
||||
|
||||
|
||||
class Deepseek(OpenAI):
|
||||
NAME = "deepseek"
|
||||
DEFAULT_MODEL = "deepseek-chat"
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
|
||||
super().__init__(api_key=api_key)
|
||||
self.endpoint = "https://api.deepseek.com/v1"
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
"""The raw OpenAI client."""
|
||||
if not self.api_key:
|
||||
raise ValueError("DEEPSEEK API key is required")
|
||||
try:
|
||||
import openai as oa
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `openai` package: `pip install openai`"
|
||||
) from exc
|
||||
return oa.OpenAI(api_key=self.api_key, base_url=self.endpoint)
|
||||
+117
-11
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
from ._base_tools import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
@@ -14,6 +15,78 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class GroqTool(BaseTool):
|
||||
def get_response_schema(self):
|
||||
assert self.is_executed, f"Tool {self.name} was not executed."
|
||||
assert isinstance(
|
||||
self.tool_id, str
|
||||
), f"Expected str for `tool_id` got {self.tool_id!r}"
|
||||
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": self.tool_id,
|
||||
"content": self.function_result,
|
||||
}
|
||||
|
||||
@logger
|
||||
def handle(self, response, messages) -> None:
|
||||
"""Handle the tool execution result from an API response."""
|
||||
tool_used = False
|
||||
|
||||
# Get the message from the response
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# Check if there's a tool call
|
||||
if assistant_message.tool_calls:
|
||||
tool_call = assistant_message.tool_calls[
|
||||
0
|
||||
] # Get the first tool call
|
||||
if tool_call.function.name == self.name:
|
||||
# Execute the function
|
||||
import json
|
||||
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
self.function_result = str(self.raw_func(**function_args))
|
||||
self.tool_id = tool_call.id
|
||||
tool_used = True
|
||||
|
||||
# Add assistant's message with tool call
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
if tool_used:
|
||||
# Add tool response message
|
||||
messages.append(self.get_response_schema())
|
||||
|
||||
def get_input_schema(self):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": self.get_properties_schema(),
|
||||
"required": self.required,
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
class Groq(BaseProvider):
|
||||
NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
@@ -46,28 +119,56 @@ class Groq(BaseProvider):
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
**kwargs,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the Groq API."""
|
||||
from ..models import Message
|
||||
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
# Format messages from conversation
|
||||
formatted_messages = [
|
||||
{"role": msg.role, "content": msg.text}
|
||||
for msg in conversation.messages
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
# Set up tools if provided
|
||||
converted_tools = self.make_tools(tools)
|
||||
tools_config = (
|
||||
[t.get_input_schema() for t in converted_tools] if tools else None
|
||||
)
|
||||
|
||||
# Get the response content from the Groq response
|
||||
assistant_message = response.choices[0].message
|
||||
# Merge all kwargs
|
||||
request_kwargs = {
|
||||
**self.DEFAULT_KWARGS,
|
||||
**kwargs,
|
||||
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||
"messages": formatted_messages,
|
||||
}
|
||||
|
||||
if tools_config:
|
||||
request_kwargs["tools"] = tools_config
|
||||
|
||||
# Make initial API call
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
|
||||
# Handle tool responses if needed
|
||||
while response.choices[0].message.tool_calls:
|
||||
print(response)
|
||||
# Handle each tool call
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, formatted_messages)
|
||||
if tool.is_executed():
|
||||
# Make another API call with the updated messages
|
||||
response = self.client.chat.completions.create(
|
||||
**request_kwargs
|
||||
)
|
||||
tool.reset_result()
|
||||
|
||||
final_message = response.choices[0].message.content
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message.content or "",
|
||||
text=final_message or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=self.NAME,
|
||||
@@ -136,3 +237,8 @@ class Groq(BaseProvider):
|
||||
raise RuntimeError(
|
||||
f"Failed to generate streaming text with Groq API: {e}"
|
||||
) from e
|
||||
|
||||
@cached_property
|
||||
def tool(self) -> Type[BaseTool]:
|
||||
"""The tool implementation for Groq."""
|
||||
return GroqTool
|
||||
@@ -53,17 +53,21 @@ class Ollama(BaseProvider):
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
response = self.client.chat(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
assistant_message = response.get("message")
|
||||
|
||||
request_kwargs = {
|
||||
**self.DEFAULT_KWARGS,
|
||||
**kwargs,
|
||||
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message.get("content"),
|
||||
text=assistant_message.content or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=self.NAME,
|
||||
@@ -100,13 +104,13 @@ class Ollama(BaseProvider):
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat(
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.get("message", {}).get("content", "")
|
||||
return response.choices[0].message.content
|
||||
|
||||
@logger
|
||||
def generate_stream_text(
|
||||
@@ -117,7 +121,7 @@ class Ollama(BaseProvider):
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat(
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
stream=True,
|
||||
@@ -126,4 +130,5 @@ class Ollama(BaseProvider):
|
||||
|
||||
# Iterate over the response and yield the content.
|
||||
for chunk in response:
|
||||
yield chunk["message"]["content"]
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
+155
-16
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
from ._base_tools import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
@@ -14,6 +15,77 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class OpenAITool(BaseTool):
|
||||
def get_response_schema(self):
|
||||
assert self.is_executed, f"Tool {self.name} was not executed."
|
||||
assert isinstance(
|
||||
self.tool_id, str
|
||||
), f"Expected str for `tool_id` got {self.tool_id!r}"
|
||||
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": self.tool_id,
|
||||
"content": self.function_result,
|
||||
}
|
||||
|
||||
@logger
|
||||
def handle(self, response, messages) -> None:
|
||||
"""Handle the tool execution result from an API response."""
|
||||
tool_used = False
|
||||
|
||||
# Get the message from the response
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# Check if there's a tool call
|
||||
if assistant_message.tool_calls:
|
||||
tool_call = assistant_message.tool_calls[0] # Get the first tool call
|
||||
if tool_call.function.name == self.name:
|
||||
# Execute the function
|
||||
import json
|
||||
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
self.function_result = str(self.raw_func(**function_args))
|
||||
self.tool_id = tool_call.id
|
||||
tool_used = True
|
||||
|
||||
# Add assistant's message with tool call
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
if tool_used:
|
||||
# Add tool response message
|
||||
messages.append(self.get_response_schema())
|
||||
|
||||
def get_input_schema(self):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": self.get_properties_schema(),
|
||||
"required": self.required,
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OpenAI(BaseProvider):
|
||||
NAME = "openai"
|
||||
DEFAULT_MODEL = "gpt-4o-mini"
|
||||
@@ -43,27 +115,55 @@ class OpenAI(BaseProvider):
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
**kwargs,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
messages = [
|
||||
# Format messages from conversation
|
||||
formatted_messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
# Set up tools if provided
|
||||
converted_tools = self.make_tools(tools)
|
||||
tools_config = (
|
||||
[t.get_input_schema() for t in converted_tools] if tools else None
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
assistant_message = response.choices[0].message
|
||||
# Merge all kwargs
|
||||
request_kwargs = {
|
||||
**self.DEFAULT_KWARGS,
|
||||
**kwargs,
|
||||
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||
"messages": formatted_messages,
|
||||
}
|
||||
|
||||
if tools_config:
|
||||
request_kwargs["tools"] = tools_config
|
||||
|
||||
# Make initial API call
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
|
||||
# Handle tool responses if needed
|
||||
while response.choices[0].message.tool_calls:
|
||||
# Handle each tool call
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, formatted_messages)
|
||||
if tool.is_executed():
|
||||
# Make another API call with the updated messages
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
tool.reset_result()
|
||||
|
||||
final_message = response.choices[0].message.content
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message.content or "",
|
||||
text=final_message or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=self.NAME,
|
||||
@@ -76,13 +176,21 @@ class OpenAI(BaseProvider):
|
||||
response_model: Type[T],
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
image_url: str | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Get a structured response from the OpenAI API."""
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||
]
|
||||
|
||||
"""Add an image (url or base64-encoded) to the message if provided."""
|
||||
if image_url:
|
||||
messages[0]["content"].append(
|
||||
{"type": "image_url", "image_url": {"url": image_url}}
|
||||
)
|
||||
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
@@ -92,11 +200,25 @@ class OpenAI(BaseProvider):
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
|
||||
def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
image_url: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate text using the OpenAI API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||
]
|
||||
|
||||
"""Add an image (url or base64-encoded) to the message if provided."""
|
||||
if image_url:
|
||||
messages[0]["content"].append(
|
||||
{"type": "image_url", "image_url": {"url": image_url}}
|
||||
)
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
@@ -106,15 +228,27 @@ class OpenAI(BaseProvider):
|
||||
|
||||
@logger
|
||||
def generate_stream_text(
|
||||
self, prompt: str, *, llm_model: str | None = None, **kwargs
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
image_url: str | None = None,
|
||||
**kwargs,
|
||||
) -> Iterator[str]:
|
||||
"""Generate streaming text using the OpenAI API.
|
||||
|
||||
Yields chunks of text as they are generated by the model.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||
]
|
||||
|
||||
"""Add an image (url or base64-encoded) to the message if provided."""
|
||||
if image_url:
|
||||
messages[0]["content"].append(
|
||||
{"type": "image_url", "image_url": {"url": image_url}}
|
||||
)
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
@@ -125,3 +259,8 @@ class OpenAI(BaseProvider):
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
@cached_property
|
||||
def tool(self) -> Type[BaseTool]:
|
||||
"""The tool implementation for OpenAI."""
|
||||
return OpenAITool
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
import simplemind as sm
|
||||
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||
from simplemind.models import BasePlugin, Conversation
|
||||
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -25,3 +28,74 @@ def test_generate_data(provider_cls):
|
||||
|
||||
assert isinstance(data.text, str)
|
||||
assert len(data.text) > 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation():
|
||||
"""Create a sample conversation for testing."""
|
||||
conv = Conversation(llm_provider="openai")
|
||||
conv.add_message(role="user", text="Hello!")
|
||||
conv.add_message(role="assistant", text="Hi there!")
|
||||
conv.add_message(role="user", text="How are you?")
|
||||
return conv
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_json_file(tmp_path):
|
||||
"""Create a temporary file path for testing."""
|
||||
return tmp_path / "conversation.json"
|
||||
|
||||
|
||||
def test_save_conversation(sample_conversation, temp_json_file):
|
||||
"""Test saving a conversation to a JSON file."""
|
||||
sample_conversation.save(temp_json_file)
|
||||
|
||||
assert temp_json_file.exists()
|
||||
|
||||
with open(temp_json_file) as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
assert "id" in saved_data
|
||||
assert "messages" in saved_data
|
||||
assert "llm_model" in saved_data
|
||||
assert "llm_provider" in saved_data
|
||||
|
||||
assert len(saved_data["messages"]) == 3
|
||||
assert saved_data["messages"][0]["text"] == "Hello!"
|
||||
assert saved_data["messages"][1]["text"] == "Hi there!"
|
||||
assert saved_data["messages"][2]["text"] == "How are you?"
|
||||
|
||||
|
||||
def test_load_conversation(sample_conversation, temp_json_file):
|
||||
"""Test loading a conversation from a JSON file."""
|
||||
sample_conversation.save(temp_json_file)
|
||||
|
||||
loaded_conv = Conversation.load(temp_json_file)
|
||||
|
||||
assert loaded_conv.id == sample_conversation.id
|
||||
assert loaded_conv.llm_model == sample_conversation.llm_model
|
||||
assert loaded_conv.llm_provider == sample_conversation.llm_provider
|
||||
assert len(loaded_conv.messages) == len(sample_conversation.messages)
|
||||
|
||||
for original_msg, loaded_msg in zip(
|
||||
sample_conversation.messages, loaded_conv.messages
|
||||
):
|
||||
assert loaded_msg.role == original_msg.role
|
||||
assert loaded_msg.text == original_msg.text
|
||||
assert loaded_msg.meta == original_msg.meta
|
||||
|
||||
|
||||
def test_save_load_with_plugins(sample_conversation, temp_json_file):
|
||||
"""Test that plugins are properly excluded from serialization."""
|
||||
|
||||
# Create a dummy plugin
|
||||
class DummyPlugin(BasePlugin):
|
||||
def initialize_hook(self, conversation):
|
||||
pass
|
||||
|
||||
sample_conversation.add_plugin(DummyPlugin())
|
||||
|
||||
sample_conversation.save(temp_json_file)
|
||||
loaded_conv = Conversation.load(temp_json_file)
|
||||
|
||||
assert len(loaded_conv.plugins) == 0
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
import simplemind as sm
|
||||
|
||||
from simplemind.providers import Anthropic, OpenAI
|
||||
from simplemind.providers._base_tools import BaseTool
|
||||
|
||||
MODELS = [
|
||||
Anthropic,
|
||||
# Gemini,
|
||||
OpenAI,
|
||||
# Groq,
|
||||
# Ollama,
|
||||
# Amazon
|
||||
]
|
||||
|
||||
|
||||
def get_weather(
|
||||
location: Annotated[
|
||||
str, Field(description="The city and state, e.g. San Francisco, CA")
|
||||
],
|
||||
unit: Annotated[
|
||||
Literal["celcius", "fahrenheit"],
|
||||
Field(description="The unit of temperature, either 'celsius' or 'fahrenheit'"),
|
||||
] = "celcius",
|
||||
):
|
||||
"""
|
||||
Get the current weather in a given location
|
||||
"""
|
||||
return f"42 {unit}"
|
||||
|
||||
|
||||
def get_location():
|
||||
"""Get the current location"""
|
||||
return "San Francisco,CA"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_single_tool_args(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="What is the weather in San Francisco?")
|
||||
data = conv.send(tools=[get_weather])
|
||||
assert "42" in data.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_single_tool_no_args(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="What is my current location")
|
||||
data = conv.send(tools=[get_location])
|
||||
assert "San Francisco" in data.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_single_tool_partial(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="Can you tell me the weather?")
|
||||
conv.send(tools=[get_weather])
|
||||
# Will answer something like:
|
||||
"""
|
||||
I can help you check the weather, but I need to know the location you're interested in.
|
||||
Could you please provide a city and state (e.g., "Los Angeles, CA" or "New York, NY")
|
||||
where you'd like to know the weather?
|
||||
"""
|
||||
|
||||
conv.add_message(text="San Francisco, CA")
|
||||
data = conv.send(tools=[get_weather])
|
||||
assert "42" in data.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_multiple_tools(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="What is the wheather at my current location?")
|
||||
data = conv.send(tools=[get_location, get_weather])
|
||||
assert "San Francisco" in data.text
|
||||
assert "42" in data.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_tool_decorator(provider_cls):
|
||||
@sm.tool(llm_provider=provider_cls.NAME)
|
||||
def exchange_rate(currency_pair: str) -> float:
|
||||
return 7.9
|
||||
|
||||
assert isinstance(exchange_rate, BaseTool)
|
||||
assert exchange_rate.name == "exchange_rate"
|
||||
assert list(exchange_rate.properties.keys()) == ["currency_pair"]
|
||||
Reference in New Issue
Block a user