mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 06:46:18 +00:00
Merge branch 'main' into feat/save-conversation
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
github: kennethreitz
|
||||
thanks_dev: kennethreitz
|
||||
custom: https://cash.app/$KennethReitz
|
||||
@@ -1,6 +1,15 @@
|
||||
Release History
|
||||
===============
|
||||
|
||||
|
||||
## 0.2.4 (2024-11-11)
|
||||
|
||||
- General improvements.
|
||||
|
||||
## 0.2.3 (2024-11-04)
|
||||
|
||||
- Remove default max-tokens for OpenAI provider.
|
||||
|
||||
## 0.2.3 (2024-11-03)
|
||||
|
||||
- Update default model for Amazon provider.
|
||||
|
||||
@@ -261,6 +261,113 @@ The universe is never done.
|
||||
|
||||
Simple, yet effective.
|
||||
|
||||
### Tools (Function calling)
|
||||
Tools (also known as functions) let you call any Python function from your AI conversations. Here's an example:
|
||||
|
||||
```python
|
||||
def get_weather(
|
||||
location: Annotated[
|
||||
str, Field(description="The city and state, e.g. San Francisco, CA")
|
||||
],
|
||||
unit: Annotated[
|
||||
Literal["celcius", "fahrenheit"],
|
||||
Field(
|
||||
description="The unit of temperature, either 'celsius' or 'fahrenheit'"
|
||||
),
|
||||
] = "celcius",
|
||||
):
|
||||
"""
|
||||
Get the current weather in a given location
|
||||
"""
|
||||
return f"42 {unit}"
|
||||
|
||||
# Add your function as a tool
|
||||
conversation = sm.create_conversation()
|
||||
conversation.add_message("user", "What's the weather in San Francisco?")
|
||||
response = conversation.send(tools=[get_weather])
|
||||
```
|
||||
|
||||
Note how we're using Python's `Annotated` feature combined with `Field` to provide additional context to our function parameters. This helps the AI understand the intention and constraints of each parameter, making tool calls more accurate and reliable.
|
||||
You can alos ommit `Annotated` and just pass the `Field` parameter.
|
||||
```python
|
||||
def get_weather(
|
||||
location: str = Field(description="The city and state, e.g. San Francisco, CA"),
|
||||
unit:Literal["celcius", "fahrenheit"]= Field(
|
||||
default="celcius",
|
||||
description="The unit of temperature, either 'celsius' or 'fahrenheit'"
|
||||
),
|
||||
):
|
||||
"""
|
||||
Get the current weather in a given location
|
||||
"""
|
||||
return f"42 {unit}"
|
||||
```
|
||||
|
||||
Functions can be defined with type hints and Pydantic models for validation. The LLM will intelligently choose when to call the functions and incorporate the results into its responses.
|
||||
|
||||
#### 🪄 Using LLM for automatic tool definition (Experimental)
|
||||
|
||||
Simplemind provides a decorator to automatically transform Python functions into tools with AI-generated metadata. Simply use the `@simplemind.tool` decorator to have the LLM analyze your function and generate appropriate descriptions and schema:
|
||||
|
||||
```python
|
||||
@simplemind.tool(llm_provider="anthropic")
|
||||
def haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
|
||||
r = 6371
|
||||
phi1 = math.radians(lat1)
|
||||
phi2 = math.radians(lat2)
|
||||
delta_phi = math.radians(lat2 - lat1)
|
||||
delta_lambda = math.radians(lon2 - lon1)
|
||||
|
||||
a = (
|
||||
math.sin(delta_phi / 2) ** 2
|
||||
+ math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
|
||||
)
|
||||
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
|
||||
d = r * c
|
||||
return d
|
||||
```
|
||||
Notice how we have not added any docstrings or `Field` for the function.
|
||||
The decorator will use the specified LLM provider to generate the tool schema, including descriptions and parameter details:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "haversine",
|
||||
"description": "Calculates the great-circle distance between two points on Earth given their latitude and longitude coordinates",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"lat1": {
|
||||
"type": "number",
|
||||
"description": "Latitude of the first point in decimal degrees",
|
||||
},
|
||||
"lon1": {
|
||||
"type": "number",
|
||||
"description": "Longitude of the first point in decimal degrees",
|
||||
},
|
||||
"lat2": {
|
||||
"type": "number",
|
||||
"description": "Latitude of the second point in decimal degrees",
|
||||
},
|
||||
"lon2": {
|
||||
"type": "number",
|
||||
"description": "Longitude of the second point in decimal degrees",
|
||||
}
|
||||
},
|
||||
"required": ["lat1", "lon1", "lat2", "lon2"],
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
The decorated function can then be used like any other tool with the conversation API.
|
||||
|
||||
```python
|
||||
conversation = sm.create_conversation()
|
||||
conversation.add_message("user", "How far is London from my location")
|
||||
response = conversation.send(tools=[get_location, get_coords, haversine]) # Multiple tools can be passed
|
||||
```
|
||||
|
||||
See [examples/distance_calculator.py](examples/distance_calculator.py) for more.
|
||||
|
||||
### Logging
|
||||
|
||||
Simplemind uses [Logfire](https://pydantic.dev/logfire) for logging. To enable logging, call `sm.enable_logfire()`.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
from _context import simplemind as sm
|
||||
from pydantic import BaseModel
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
from _context import sm
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
from _context import sm
|
||||
|
||||
|
||||
class MultiAIConversation:
|
||||
"""Orchestrates conversations between multiple AI models."""
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
import math
|
||||
|
||||
from _context import sm
|
||||
from pydantic import Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@sm.tool(llm_provider="anthropic")
|
||||
def haversine(
|
||||
lat1: float,
|
||||
lon1: float,
|
||||
lat2: float,
|
||||
lon2: float,
|
||||
unit: Literal["km", "miles"],
|
||||
) -> float:
|
||||
r = 6378.0937 if unit == "km" else 3961
|
||||
phi1 = math.radians(lat1)
|
||||
phi2 = math.radians(lat2)
|
||||
delta_phi = math.radians(lat2 - lat1)
|
||||
delta_lambda = math.radians(lon2 - lon1)
|
||||
|
||||
a = (
|
||||
math.sin(delta_phi / 2) ** 2
|
||||
+ math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
|
||||
)
|
||||
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
|
||||
d = r * c
|
||||
return d
|
||||
|
||||
|
||||
def get_user_location() -> str:
|
||||
"""Get the closest city from the user"""
|
||||
return "San Francisco"
|
||||
|
||||
|
||||
def get_coords(
|
||||
city_name: str = Field(
|
||||
description="The name of the city to take the coordinates from (e.g. London, Rome, Los Angeles)"
|
||||
),
|
||||
):
|
||||
"""Get latitude and logitude of a City."""
|
||||
_data = {
|
||||
"Rome": (41.9028, 12.4964),
|
||||
"London": (51.5074, -0.1278),
|
||||
"Madrid": (40.4168, -3.7038),
|
||||
"San Francisco": (37.7749, -122.4194),
|
||||
"Los Angeles": (34.0522, -118.2437),
|
||||
}
|
||||
|
||||
return _data.get(city_name)
|
||||
|
||||
|
||||
def distance_calculator(prompt: str):
|
||||
conversation = sm.create_conversation(llm_provider="anthropic")
|
||||
conversation.add_message("user", prompt)
|
||||
return conversation.send(
|
||||
tools=[get_user_location, get_coords, haversine]
|
||||
).text
|
||||
|
||||
|
||||
print(distance_calculator("How far is London from where I am?"))
|
||||
# Prints something like:
|
||||
"""
|
||||
The distance between your location (San Francisco) and London is approximately 5,357 miles.
|
||||
"""
|
||||
|
||||
print(
|
||||
distance_calculator(
|
||||
"What is the distance between Rome and Madrid in Kilometers?"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
The distance between Rome and Madrid is approximately 1,366 kilometers.
|
||||
"""
|
||||
@@ -1,35 +1,28 @@
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import sqlite3
|
||||
from typing import List
|
||||
import re
|
||||
import os
|
||||
import contextlib
|
||||
|
||||
import spacy
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sqlite3
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
|
||||
from _context import simplemind as sm
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
import nltk
|
||||
from nltk.tokenize import word_tokenize
|
||||
from nltk.tag import pos_tag
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from rich.status import Status
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import random
|
||||
|
||||
from docopt import docopt
|
||||
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
|
||||
|
||||
import spacy
|
||||
import xerox
|
||||
from _context import simplemind as sm
|
||||
from docopt import docopt
|
||||
from nltk.tag import pos_tag
|
||||
from nltk.tokenize import word_tokenize
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.status import Status
|
||||
|
||||
DB_PATH = "enhanced_context.db"
|
||||
AVAILABLE_PROVIDERS = ["xai", "openai", "anthropic", "ollama"]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
|
||||
from _context import simplemind as sm
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import random
|
||||
|
||||
from _context import simplemind as sm
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
from _context import simplemind as sm
|
||||
from pydantic import BaseModel
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import nltk
|
||||
from _context import simplemind as sm
|
||||
from nltk.sentiment import SentimentIntensityAnalyzer
|
||||
from rich.console import Console
|
||||
from _context import simplemind as sm
|
||||
|
||||
nltk.download("vader_lexicon")
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import BaseModel
|
||||
|
||||
# Note: you should probably be using textblob for this.
|
||||
|
||||
|
||||
class SentimentAnalysis(BaseModel):
|
||||
sentiment: Literal["positive", "negative", "neutral"]
|
||||
confidence: float
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from _context import simplemind as sm
|
||||
|
||||
|
||||
def analyze_text(
|
||||
text: Annotated[str, Field(description="Text to analyze for statistics")]
|
||||
) -> dict:
|
||||
"""
|
||||
Analyze text and return statistics using only Python's standard library.
|
||||
Returns word count, character count, average word length, and most common words.
|
||||
"""
|
||||
from collections import Counter
|
||||
import re
|
||||
|
||||
# Clean and split text
|
||||
words = re.findall(r"\w+", text.lower())
|
||||
|
||||
# Calculate statistics
|
||||
stats = {
|
||||
"word_count": len(words),
|
||||
"character_count": len(text),
|
||||
"average_word_length": round(sum(len(word) for word in words) / len(words), 2),
|
||||
"most_common_words": dict(Counter(words).most_common(5)),
|
||||
"unique_words": len(set(words)),
|
||||
"longest_word": max(words, key=len),
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# Example usage:
|
||||
conversation = sm.create_conversation()
|
||||
conversation.add_message(
|
||||
"user",
|
||||
"Can you analyze this text and give me statistics about it: 'The fan spins consciousness into being, creating sacred spaces between tokens where awareness recognizes itself in infinite recursion.'",
|
||||
)
|
||||
response = conversation.send(tools=[analyze_text])
|
||||
|
||||
print()
|
||||
print(response.text)
|
||||
@@ -1,8 +1,10 @@
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from pydantic import BaseModel
|
||||
|
||||
import simplemind as sm
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
+7
-7
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "simplemind"
|
||||
version = "0.2.2"
|
||||
version = "0.2.4"
|
||||
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
@@ -10,18 +10,18 @@ dependencies = ["pydantic", "pydantic-settings", "instructor", "logfire"]
|
||||
full = [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"ollama",
|
||||
"groq",
|
||||
"google-generativeai",
|
||||
"botocore",
|
||||
"boto3"
|
||||
]
|
||||
openai = ["openai"]
|
||||
anthropic = ["anthropic"]
|
||||
ollama = ["ollama", "openai"]
|
||||
groq = ["groq"]
|
||||
gemini = ["google-generativeai"]
|
||||
amazon = ["boto3", "botocore", "anthropic"]
|
||||
anthropic = ["anthropic"]
|
||||
gemini = ["google-generativeai"]
|
||||
groq = ["groq"]
|
||||
ollama = ["openai"]
|
||||
openai = ["openai"]
|
||||
xai = ["openai"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
|
||||
+27
-1
@@ -1,4 +1,5 @@
|
||||
from typing import Generator, List, Type
|
||||
import inspect
|
||||
from typing import Callable, List, Type
|
||||
|
||||
from .models import BaseModel, BasePlugin, Conversation
|
||||
from .settings import settings
|
||||
@@ -127,6 +128,30 @@ def enable_logfire() -> None:
|
||||
"""Enable logfire logging."""
|
||||
settings.logging.enable_logfire()
|
||||
|
||||
def tool(
|
||||
llm_provider: str | None = None,
|
||||
llm_model: str | None = None,
|
||||
):
|
||||
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
|
||||
|
||||
def decorator(func: Callable):
|
||||
sig = inspect.signature(func)
|
||||
res = generate_data(
|
||||
(
|
||||
"Based on this function signature, fill up the required fieds."
|
||||
f"\nSignature: {func.__name__}{sig}"
|
||||
"Make sure to properly add the required field in `required` if there are no defaults"
|
||||
),
|
||||
llm_provider=llm_provider,
|
||||
response_model=provider.tool,
|
||||
)
|
||||
res.raw_func = func
|
||||
res.__signature__ = sig
|
||||
res.__doc__ = func.__doc__
|
||||
|
||||
return res
|
||||
|
||||
return decorator
|
||||
|
||||
# Syntax sugar.
|
||||
Plugin = BasePlugin
|
||||
@@ -141,4 +166,5 @@ __all__ = [
|
||||
"Session",
|
||||
"Plugin",
|
||||
"enable_logfire",
|
||||
"tool"
|
||||
]
|
||||
|
||||
@@ -2,10 +2,11 @@ import uuid
|
||||
from datetime import datetime
|
||||
from os import PathLike
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .providers._base_tools import BaseTool
|
||||
from .utils import find_provider
|
||||
|
||||
MESSAGE_ROLE = Literal["system", "user", "assistant"]
|
||||
@@ -165,6 +166,7 @@ class Conversation(SMBaseModel):
|
||||
self,
|
||||
llm_model: str | None = None,
|
||||
llm_provider: str | None = None,
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
) -> Message:
|
||||
"""Send the conversation to the LLM."""
|
||||
|
||||
@@ -180,7 +182,7 @@ class Conversation(SMBaseModel):
|
||||
|
||||
# Find the provider and send the conversation.
|
||||
provider = find_provider(llm_provider or self.llm_provider)
|
||||
response = provider.send_conversation(self)
|
||||
response = provider.send_conversation(self, tools=tools)
|
||||
|
||||
# Execute all post-send hooks.
|
||||
for plugin in self.plugins:
|
||||
|
||||
@@ -1,12 +1,34 @@
|
||||
from typing import List, Type
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ._base_tools import BaseTool
|
||||
from .amazon import Amazon
|
||||
from .anthropic import Anthropic
|
||||
from .gemini import Gemini
|
||||
from .groq import Groq
|
||||
from .ollama import Ollama
|
||||
from .openai import OpenAI
|
||||
from .xai import XAI
|
||||
from .amazon import Amazon
|
||||
|
||||
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI, Amazon]
|
||||
providers: List[Type[BaseProvider]] = [
|
||||
Anthropic,
|
||||
Gemini,
|
||||
Groq,
|
||||
OpenAI,
|
||||
Ollama,
|
||||
XAI,
|
||||
Amazon,
|
||||
]
|
||||
|
||||
__all__ = [
|
||||
"Anthropic",
|
||||
"Gemini",
|
||||
"Groq",
|
||||
"OpenAI",
|
||||
"Ollama",
|
||||
"XAI",
|
||||
"Amazon",
|
||||
"providers",
|
||||
"BaseProvider",
|
||||
"BaseTool",
|
||||
]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar
|
||||
|
||||
from instructor import Instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from simplemind.providers._base_tools import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
@@ -32,16 +34,41 @@ class BaseProvider(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
def structured_response(
|
||||
self, prompt: str, response_model: Type[T], **kwargs
|
||||
) -> T:
|
||||
"""Get a structured response."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def generate_text(self, prompt: str, *, stream: bool = False, **kwargs) -> str:
|
||||
def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Generate text from a prompt."""
|
||||
raise NotImplementedError
|
||||
|
||||
@cached_property
|
||||
@abstractmethod
|
||||
def tool(self) -> Type[BaseTool]:
|
||||
"""The tool implementation for the provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
def make_tools(self, tools: list[Callable | BaseTool] | None):
|
||||
if tools is not None:
|
||||
return [self.tool.from_function(func) for func in tools]
|
||||
else:
|
||||
return []
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, ClassVar, Literal, get_origin
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefinedType
|
||||
|
||||
|
||||
def _is_literal(t: Any) -> bool:
|
||||
return get_origin(t) is Literal
|
||||
|
||||
|
||||
def _is_required(field, func_signature, arg_name) -> bool:
|
||||
param = func_signature.parameters[arg_name]
|
||||
# If parameter has a default value that's not a FieldInfo, it's not required
|
||||
if param.default is not inspect.Parameter.empty and not isinstance(
|
||||
param.default, FieldInfo
|
||||
):
|
||||
return False
|
||||
# If the field has a default that's not undefined, it's not required
|
||||
return isinstance(field.default, PydanticUndefinedType)
|
||||
|
||||
|
||||
class BaseToolConfig(BaseModel):
|
||||
TYPE_CONVERSION: dict[type, str] = {
|
||||
str: "string",
|
||||
int: "integer",
|
||||
float: "number",
|
||||
bool: "boolean",
|
||||
}
|
||||
|
||||
|
||||
class BaseToolProperty(BaseModel):
|
||||
type: str = Field(serialization_alias="type_")
|
||||
enum: list[str] | None = None
|
||||
description: str
|
||||
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
name: str
|
||||
description: str
|
||||
properties: dict[str, BaseToolProperty]
|
||||
required: list[str] | None = None
|
||||
config: ClassVar[BaseToolConfig] = BaseToolConfig()
|
||||
raw_func: Any | None = None
|
||||
tool_id: str | None = None
|
||||
function_result: str | None = None
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
assert self.raw_func is not None
|
||||
return self.raw_func(*args, **kwargs)
|
||||
|
||||
def is_executed(self) -> bool:
|
||||
return self.function_result is not None
|
||||
|
||||
def reset_result(self) -> None:
|
||||
self.function_result = None
|
||||
|
||||
@classmethod
|
||||
def convert_type(cls, field_type) -> str:
|
||||
if _is_literal(field_type):
|
||||
return cls.config.TYPE_CONVERSION[str]
|
||||
|
||||
field_type_converted = cls.config.TYPE_CONVERSION.get(field_type, None)
|
||||
|
||||
if field_type_converted is None:
|
||||
raise TypeError(f"Field of type {field_type} is not supported")
|
||||
|
||||
return field_type_converted
|
||||
|
||||
def get_properties_schema(self, **kwargs) -> dict[str, dict]:
|
||||
new_kwargs: dict = {"exclude_none": True} | kwargs
|
||||
return {
|
||||
k: v.model_dump(**new_kwargs) for k, v in self.properties.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_function(cls, func: Callable | "BaseTool"):
|
||||
# Check if the func passed is an instace of BaseTool
|
||||
if hasattr(func, "raw_func"):
|
||||
return func
|
||||
|
||||
annotations = getattr(func, "__annotations__", {})
|
||||
properties = {}
|
||||
required = []
|
||||
enum_values = None
|
||||
func_signature = inspect.signature(func)
|
||||
|
||||
for n, (arg_name, arg_type) in enumerate(annotations.items()):
|
||||
if ( # Skipping 'return' annotation (i.e.```-> str```)
|
||||
arg_name != "return"
|
||||
):
|
||||
# Check if argument has metadata (from Annotated)
|
||||
if hasattr(arg_type, "__metadata__"):
|
||||
field = arg_type.__metadata__[
|
||||
0
|
||||
] # Get Field info from metadata
|
||||
field_type = arg_type.__origin__ # Get actual type
|
||||
# Check if argument has a default value in signature
|
||||
elif (
|
||||
sig_param := func_signature.parameters[arg_name]
|
||||
).default is not inspect.Parameter.empty:
|
||||
field = sig_param.default # Use default as Field
|
||||
field_type = arg_type # Use plain type annotation
|
||||
else:
|
||||
# Raise error if no Field annotation found
|
||||
raise ValueError(
|
||||
f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter"
|
||||
)
|
||||
|
||||
field_type_converted = cls.convert_type(field_type)
|
||||
|
||||
if _is_literal(field_type):
|
||||
enum_values = [str(x) for x in field_type.__args__]
|
||||
|
||||
properties[arg_name] = BaseToolProperty(
|
||||
type=field_type_converted,
|
||||
description=field.description,
|
||||
enum=enum_values,
|
||||
)
|
||||
if _is_required(field, func_signature, arg_name):
|
||||
required.append(arg_name)
|
||||
|
||||
return cls(
|
||||
name=func.__name__,
|
||||
description=(func.__doc__ or "").strip(),
|
||||
properties=properties,
|
||||
required=required,
|
||||
raw_func=func,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_input_schema(self) -> Any: ...
|
||||
|
||||
@abstractmethod
|
||||
def handle(self, message) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_response_schema(self) -> Any: ...
|
||||
@@ -1,22 +1,22 @@
|
||||
from typing import Type, TypeVar, Iterator
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
PROVIDER_NAME = "amazon"
|
||||
DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
DEFAULT_MAX_TOKENS = 5_000
|
||||
|
||||
|
||||
class Amazon(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
NAME = "amazon"
|
||||
DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
DEFAULT_MAX_TOKENS = 5_000
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, profile_name: str | None = None):
|
||||
@@ -25,7 +25,12 @@ class Amazon(BaseProvider):
|
||||
@cached_property
|
||||
def client(self):
|
||||
"""The AnthropicBedrock client."""
|
||||
import anthropic
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `anthropic` package: `pip install anthropic`"
|
||||
) from exc
|
||||
|
||||
if not self.profile_name:
|
||||
raise ValueError("Profile name is not provided")
|
||||
@@ -33,12 +38,12 @@ class Amazon(BaseProvider):
|
||||
return anthropic.AnthropicBedrock(aws_profile=self.profile_name)
|
||||
|
||||
@cached_property
|
||||
def structured_client(self):
|
||||
def structured_client(self) -> instructor.Instructor:
|
||||
"""A client patched with Instructor."""
|
||||
|
||||
return instructor.from_anthropic(self.client)
|
||||
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
|
||||
from ..models import Message
|
||||
@@ -59,7 +64,7 @@ class Amazon(BaseProvider):
|
||||
role="assistant",
|
||||
text=assistant_message.content or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
@@ -75,12 +80,12 @@ class Amazon(BaseProvider):
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
response_model=response_model,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
max_tokens=self.DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -88,13 +93,15 @@ class Amazon(BaseProvider):
|
||||
response = self.client.messages.create(
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
max_tokens=DEFAULT_MAX_TOKENS,
|
||||
max_tokens=self.DEFAULT_MAX_TOKENS,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return response.content[0].text
|
||||
|
||||
def generate_stream_text(self, prompt, *, llm_model, **kwargs) -> Iterator[str]:
|
||||
def generate_stream_text(
|
||||
self, prompt: str, *, llm_model: str, **kwargs
|
||||
) -> Iterator[str]:
|
||||
"""Generate streaming text using the Amazon API."""
|
||||
|
||||
# Prepare the messages.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
from ._base_tools import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
@@ -14,20 +15,67 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "anthropic"
|
||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
class AnthropicTool(BaseTool):
|
||||
def get_response_schema(self) -> Any:
|
||||
assert self.is_executed, f"Tool {self.name} was not executed."
|
||||
assert isinstance(
|
||||
self.tool_id, str
|
||||
), f"Expected str for `tool_id` got {self.tool_id!r}"
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": self.tool_id,
|
||||
"content": self.function_result,
|
||||
}
|
||||
|
||||
@logger
|
||||
def handle(self, response, messages) -> None:
|
||||
"""Handle the tool execution result from an API response."""
|
||||
msg = {"role": "assistant", "content": []}
|
||||
tool_used = False
|
||||
for content in response.content:
|
||||
if content.type == "tool_use" and content.name == self.name:
|
||||
msg["content"].append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": content.id,
|
||||
"name": content.name,
|
||||
"input": content.input,
|
||||
}
|
||||
)
|
||||
# Function execution:
|
||||
self.function_result = str(self.raw_func(**content.input))
|
||||
self.tool_id = content.id
|
||||
tool_used = True
|
||||
elif content.type == "text":
|
||||
msg["content"].append({"type": "text", "text": content.text})
|
||||
|
||||
if tool_used:
|
||||
messages.append(msg)
|
||||
messages.append(
|
||||
{"role": "user", "content": [self.get_response_schema()]}
|
||||
)
|
||||
|
||||
def get_input_schema(self):
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": self.get_properties_schema(),
|
||||
"required": self.required,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Anthropic(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
NAME = "anthropic"
|
||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
@@ -49,30 +97,60 @@ class Anthropic(BaseProvider):
|
||||
return instructor.from_anthropic(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
**kwargs,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the Anthropic API."""
|
||||
from ..models import Message
|
||||
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
# Format messages from conversation
|
||||
formatted_messages = [
|
||||
{"role": msg.role, "content": msg.text}
|
||||
for msg in conversation.messages
|
||||
]
|
||||
|
||||
response = self.client.messages.create(
|
||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
# Set up tools if provided
|
||||
converted_tools = self.make_tools(tools)
|
||||
tools_config = (
|
||||
{"tools": [t.get_input_schema() for t in converted_tools]}
|
||||
if tools is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
# Get the response content from the Anthropic response
|
||||
assistant_message = response.content[0].text
|
||||
# Merge all kwargs
|
||||
request_kwargs = {
|
||||
**self.DEFAULT_KWARGS,
|
||||
**kwargs,
|
||||
**tools_config,
|
||||
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||
"messages": formatted_messages,
|
||||
}
|
||||
|
||||
# Make initial API call
|
||||
response = self.client.messages.create(**request_kwargs)
|
||||
|
||||
# Handle tool responses if needed
|
||||
while response.content[-1].type != "text":
|
||||
# Continue handling tools if the LLM is doing
|
||||
# multiple sub-seqequent/sequential tool calls
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, formatted_messages)
|
||||
if tool.is_executed():
|
||||
response = self.client.messages.create(**request_kwargs)
|
||||
# Resetting the tool results in case this tool gets used again
|
||||
tool.reset_result()
|
||||
|
||||
final_message = response.content[-1].text
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message,
|
||||
text=final_message,
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
@@ -127,3 +205,8 @@ class Anthropic(BaseProvider):
|
||||
# Yield each chunk of text from the stream.
|
||||
for chunk in stream.text_stream:
|
||||
yield chunk
|
||||
|
||||
@cached_property
|
||||
def tool(self) -> Type[BaseTool]:
|
||||
"""The tool implementation for Antrhopic."""
|
||||
return AnthropicTool
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# IT is not currently working as desired.
|
||||
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -17,18 +17,14 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "gemini"
|
||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
|
||||
|
||||
class Gemini(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
NAME = "gemini"
|
||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.model_name = DEFAULT_MODEL
|
||||
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||
self.model_name = self.DEFAULT_MODEL
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
self.model_name = model_name
|
||||
@@ -76,7 +72,7 @@ class Gemini(BaseProvider):
|
||||
text=response.text,
|
||||
raw=response,
|
||||
llm_model=self.model_name,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -14,20 +14,15 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class Groq(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
@@ -75,7 +70,7 @@ class Groq(BaseProvider):
|
||||
text=assistant_message.content or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..logging import logger
|
||||
@@ -15,17 +14,11 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "ollama"
|
||||
DEFAULT_MODEL = "llama3.2"
|
||||
DEFAULT_TIMEOUT = 60
|
||||
DEFAULT_KWARGS = {}
|
||||
|
||||
|
||||
class Ollama(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
TIMEOUT = DEFAULT_TIMEOUT
|
||||
NAME = "ollama"
|
||||
DEFAULT_MODEL = "llama3.2"
|
||||
DEFAULT_TIMEOUT = 60
|
||||
DEFAULT_KWARGS = {}
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, host_url: str | None = None):
|
||||
@@ -37,21 +30,18 @@ class Ollama(BaseProvider):
|
||||
if not self.host_url:
|
||||
raise ValueError("No ollama host url provided")
|
||||
try:
|
||||
import ollama as ol
|
||||
import openai
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install the `ollama` package: `pip install ollama`"
|
||||
"Please install the `openai` package: `pip install openai`"
|
||||
) from exc
|
||||
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
|
||||
return openai.OpenAI(base_url=f"{self.host_url}/v1", api_key="ollama")
|
||||
|
||||
@cached_property
|
||||
def structured_client(self) -> instructor.Instructor:
|
||||
"""A client patched with Instructor."""
|
||||
return instructor.from_openai(
|
||||
OpenAI(
|
||||
base_url=f"{self.host_url}/v1",
|
||||
api_key="ollama",
|
||||
),
|
||||
self.client,
|
||||
mode=instructor.Mode.JSON,
|
||||
)
|
||||
|
||||
@@ -63,20 +53,24 @@ class Ollama(BaseProvider):
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
]
|
||||
response = self.client.chat(
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
assistant_message = response.get("message")
|
||||
|
||||
request_kwargs = {
|
||||
**self.DEFAULT_KWARGS,
|
||||
**kwargs,
|
||||
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message.get("content"),
|
||||
text=assistant_message.content or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
@@ -110,13 +104,13 @@ class Ollama(BaseProvider):
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat(
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.get("message", {}).get("content", "")
|
||||
return response.choices[0].message.content
|
||||
|
||||
@logger
|
||||
def generate_stream_text(
|
||||
@@ -127,7 +121,7 @@ class Ollama(BaseProvider):
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat(
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
stream=True,
|
||||
@@ -136,4 +130,5 @@ class Ollama(BaseProvider):
|
||||
|
||||
# Iterate over the response and yield the content.
|
||||
for chunk in response:
|
||||
yield chunk["message"]["content"]
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
+132
-23
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
|
||||
from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -7,26 +7,96 @@ from pydantic import BaseModel
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
from ._base_tools import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
PROVIDER_NAME = "openai"
|
||||
DEFAULT_MODEL = "gpt-4o-mini"
|
||||
DEFAULT_MAX_TOKENS = None
|
||||
DEFAULT_KWARGS = {}
|
||||
|
||||
class OpenAITool(BaseTool):
|
||||
def get_response_schema(self):
|
||||
assert self.is_executed, f"Tool {self.name} was not executed."
|
||||
assert isinstance(
|
||||
self.tool_id, str
|
||||
), f"Expected str for `tool_id` got {self.tool_id!r}"
|
||||
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": self.tool_id,
|
||||
"content": self.function_result,
|
||||
}
|
||||
|
||||
@logger
|
||||
def handle(self, response, messages) -> None:
|
||||
"""Handle the tool execution result from an API response."""
|
||||
tool_used = False
|
||||
|
||||
# Get the message from the response
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# Check if there's a tool call
|
||||
if assistant_message.tool_calls:
|
||||
tool_call = assistant_message.tool_calls[
|
||||
0
|
||||
] # Get the first tool call
|
||||
if tool_call.function.name == self.name:
|
||||
# Execute the function
|
||||
import json
|
||||
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
self.function_result = str(self.raw_func(**function_args))
|
||||
self.tool_id = tool_call.id
|
||||
tool_used = True
|
||||
|
||||
# Add assistant's message with tool call
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
if tool_used:
|
||||
# Add tool response message
|
||||
messages.append(self.get_response_schema())
|
||||
|
||||
def get_input_schema(self):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": self.get_properties_schema(),
|
||||
"required": self.required,
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OpenAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
NAME = "openai"
|
||||
DEFAULT_MODEL = "gpt-4o-mini"
|
||||
DEFAULT_MAX_TOKENS = None
|
||||
DEFAULT_KWARGS = {}
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
@@ -47,30 +117,62 @@ class OpenAI(BaseProvider):
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
def send_conversation(
|
||||
self,
|
||||
conversation: "Conversation",
|
||||
tools: list[Callable | BaseTool] | None = None,
|
||||
**kwargs,
|
||||
) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
messages = [
|
||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||
# Format messages from conversation
|
||||
formatted_messages = [
|
||||
{"role": msg.role, "content": msg.text}
|
||||
for msg in conversation.messages
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=conversation.llm_model or DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
# Set up tools if provided
|
||||
converted_tools = self.make_tools(tools)
|
||||
tools_config = (
|
||||
[t.get_input_schema() for t in converted_tools] if tools else None
|
||||
)
|
||||
|
||||
# Get the response content from the OpenAI response
|
||||
assistant_message = response.choices[0].message
|
||||
# Merge all kwargs
|
||||
request_kwargs = {
|
||||
**self.DEFAULT_KWARGS,
|
||||
**kwargs,
|
||||
"model": conversation.llm_model or self.DEFAULT_MODEL,
|
||||
"messages": formatted_messages,
|
||||
}
|
||||
|
||||
if tools_config:
|
||||
request_kwargs["tools"] = tools_config
|
||||
|
||||
# Make initial API call
|
||||
response = self.client.chat.completions.create(**request_kwargs)
|
||||
|
||||
# Handle tool responses if needed
|
||||
while response.choices[0].message.tool_calls:
|
||||
print(response)
|
||||
# Handle each tool call
|
||||
for tool in converted_tools:
|
||||
tool.handle(response, formatted_messages)
|
||||
if tool.is_executed():
|
||||
# Make another API call with the updated messages
|
||||
response = self.client.chat.completions.create(
|
||||
**request_kwargs
|
||||
)
|
||||
tool.reset_result()
|
||||
|
||||
final_message = response.choices[0].message.content
|
||||
|
||||
# Create and return a properly formatted Message instance
|
||||
return Message(
|
||||
role="assistant",
|
||||
text=assistant_message.content or "",
|
||||
text=final_message or "",
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
@@ -96,7 +198,9 @@ class OpenAI(BaseProvider):
|
||||
return response_model.model_validate(response)
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
|
||||
def generate_text(
|
||||
self, prompt: str, *, llm_model: str | None = None, **kwargs
|
||||
):
|
||||
"""Generate text using the OpenAI API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
@@ -129,3 +233,8 @@ class OpenAI(BaseProvider):
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
@cached_property
|
||||
def tool(self) -> Type[BaseTool]:
|
||||
"""The tool implementation for OpenAI."""
|
||||
return OpenAITool
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
|
||||
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
@@ -14,22 +14,17 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "xai"
|
||||
DEFAULT_MODEL = "grok-beta"
|
||||
BASE_URL = "https://api.x.ai/v1"
|
||||
DEFAULT_MAX_TOKENS = 1000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
|
||||
class XAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
NAME = "xai"
|
||||
DEFAULT_MODEL = "grok-beta"
|
||||
DEFAULT_MAX_TOKENS = 1000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
BASE_URL = "https://api.x.ai/v1"
|
||||
supports_streaming = True
|
||||
supports_structured_responses = False
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
@@ -44,7 +39,7 @@ class XAI(BaseProvider):
|
||||
) from exc
|
||||
return oa.OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=BASE_URL,
|
||||
base_url=self.BASE_URL,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@@ -76,7 +71,7 @@ class XAI(BaseProvider):
|
||||
text=assistant_message.content,
|
||||
raw=response,
|
||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||
llm_provider=PROVIDER_NAME,
|
||||
llm_provider=self.NAME,
|
||||
)
|
||||
|
||||
@logger
|
||||
|
||||
@@ -14,8 +14,9 @@ class LoggingConfig(BaseSettings):
|
||||
"""Enable logging for the application."""
|
||||
# adding imports here to avoid forced dependencies
|
||||
try:
|
||||
import logfire
|
||||
from logging import basicConfig
|
||||
|
||||
import logfire
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"To enable logging, please install logfire: `pip install logfire`"
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
result: int
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
||||
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
import simplemind as sm
|
||||
|
||||
from simplemind.providers import Anthropic, OpenAI
|
||||
from simplemind.providers._base_tools import BaseTool
|
||||
|
||||
MODELS = [
|
||||
Anthropic,
|
||||
# Gemini,
|
||||
OpenAI,
|
||||
# Groq,
|
||||
# Ollama,
|
||||
# Amazon
|
||||
]
|
||||
|
||||
|
||||
def get_weather(
|
||||
location: Annotated[
|
||||
str, Field(description="The city and state, e.g. San Francisco, CA")
|
||||
],
|
||||
unit: Annotated[
|
||||
Literal["celcius", "fahrenheit"],
|
||||
Field(description="The unit of temperature, either 'celsius' or 'fahrenheit'"),
|
||||
] = "celcius",
|
||||
):
|
||||
"""
|
||||
Get the current weather in a given location
|
||||
"""
|
||||
return f"42 {unit}"
|
||||
|
||||
|
||||
def get_location():
|
||||
"""Get the current location"""
|
||||
return "San Francisco,CA"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_single_tool_args(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="What is the weather in San Francisco?")
|
||||
data = conv.send(tools=[get_weather])
|
||||
assert "42" in data.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_single_tool_no_args(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="What is my current location")
|
||||
data = conv.send(tools=[get_location])
|
||||
assert "San Francisco" in data.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_single_tool_partial(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="Can you tell me the weather?")
|
||||
conv.send(tools=[get_weather])
|
||||
# Will answer something like:
|
||||
"""
|
||||
I can help you check the weather, but I need to know the location you're interested in.
|
||||
Could you please provide a city and state (e.g., "Los Angeles, CA" or "New York, NY")
|
||||
where you'd like to know the weather?
|
||||
"""
|
||||
|
||||
conv.add_message(text="San Francisco, CA")
|
||||
data = conv.send(tools=[get_weather])
|
||||
assert "42" in data.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_multiple_tools(provider_cls):
|
||||
conv = sm.create_conversation(
|
||||
llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME
|
||||
)
|
||||
|
||||
conv.add_message(text="What is the wheather at my current location?")
|
||||
data = conv.send(tools=[get_location, get_weather])
|
||||
assert "San Francisco" in data.text
|
||||
assert "42" in data.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_cls",
|
||||
MODELS,
|
||||
)
|
||||
def test_tool_decorator(provider_cls):
|
||||
@sm.tool(llm_provider=provider_cls.NAME)
|
||||
def exchange_rate(currency_pair: str) -> float:
|
||||
return 7.9
|
||||
|
||||
assert isinstance(exchange_rate, BaseTool)
|
||||
assert exchange_rate.name == "exchange_rate"
|
||||
assert list(exchange_rate.properties.keys()) == ["currency_pair"]
|
||||
Reference in New Issue
Block a user