diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..9c7fe54 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +github: kennethreitz +thanks_dev: kennethreitz +custom: https://cash.app/$KennethReitz diff --git a/CHANGELOG.md b/CHANGELOG.md index e7484b4..0a9847b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,15 @@ Release History =============== + +## 0.2.4 (2024-11-11) + +- General improvements. + +## 0.2.3 (2024-11-04) + +- Remove default max-tokens for OpenAI provider. + ## 0.2.3 (2024-11-03) - Update default model for Amazon provider. diff --git a/README.md b/README.md index acac77d..7e1c605 100644 --- a/README.md +++ b/README.md @@ -261,6 +261,113 @@ The universe is never done. Simple, yet effective. +### Tools (Function calling) +Tools (also known as functions) let you call any Python function from your AI conversations. Here's an example: + +```python +def get_weather( + location: Annotated[ + str, Field(description="The city and state, e.g. San Francisco, CA") + ], + unit: Annotated[ + Literal["celcius", "fahrenheit"], + Field( + description="The unit of temperature, either 'celsius' or 'fahrenheit'" + ), + ] = "celcius", +): + """ + Get the current weather in a given location + """ + return f"42 {unit}" + +# Add your function as a tool +conversation = sm.create_conversation() +conversation.add_message("user", "What's the weather in San Francisco?") +response = conversation.send(tools=[get_weather]) +``` + +Note how we're using Python's `Annotated` feature combined with `Field` to provide additional context to our function parameters. This helps the AI understand the intention and constraints of each parameter, making tool calls more accurate and reliable. +You can alos ommit `Annotated` and just pass the `Field` parameter. +```python +def get_weather( + location: str = Field(description="The city and state, e.g. San Francisco, CA"), + unit:Literal["celcius", "fahrenheit"]= Field( + default="celcius", + description="The unit of temperature, either 'celsius' or 'fahrenheit'" + ), +): + """ + Get the current weather in a given location + """ + return f"42 {unit}" +``` + +Functions can be defined with type hints and Pydantic models for validation. The LLM will intelligently choose when to call the functions and incorporate the results into its responses. + +#### 🪄 Using LLM for automatic tool definition (Experimental) + +Simplemind provides a decorator to automatically transform Python functions into tools with AI-generated metadata. Simply use the `@simplemind.tool` decorator to have the LLM analyze your function and generate appropriate descriptions and schema: + +```python +@simplemind.tool(llm_provider="anthropic") +def haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float: + r = 6371 + phi1 = math.radians(lat1) + phi2 = math.radians(lat2) + delta_phi = math.radians(lat2 - lat1) + delta_lambda = math.radians(lon2 - lon1) + + a = ( + math.sin(delta_phi / 2) ** 2 + + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2 + ) + c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) + d = r * c + return d +``` +Notice how we have not added any docstrings or `Field` for the function. +The decorator will use the specified LLM provider to generate the tool schema, including descriptions and parameter details: + +```json +{ + "name": "haversine", + "description": "Calculates the great-circle distance between two points on Earth given their latitude and longitude coordinates", + "input_schema": { + "type": "object", + "properties": { + "lat1": { + "type": "number", + "description": "Latitude of the first point in decimal degrees", + }, + "lon1": { + "type": "number", + "description": "Longitude of the first point in decimal degrees", + }, + "lat2": { + "type": "number", + "description": "Latitude of the second point in decimal degrees", + }, + "lon2": { + "type": "number", + "description": "Longitude of the second point in decimal degrees", + } + }, + "required": ["lat1", "lon1", "lat2", "lon2"], + }, +} +``` + +The decorated function can then be used like any other tool with the conversation API. + +```python +conversation = sm.create_conversation() +conversation.add_message("user", "How far is London from my location") +response = conversation.send(tools=[get_location, get_coords, haversine]) # Multiple tools can be passed +``` + +See [examples/distance_calculator.py](examples/distance_calculator.py) for more. + ### Logging Simplemind uses [Logfire](https://pydantic.dev/logfire) for logging. To enable logging, call `sm.enable_logfire()`. diff --git a/examples/cooking_recipe_example.py b/examples/cooking_recipe_example.py index 8abc295..077d9ae 100644 --- a/examples/cooking_recipe_example.py +++ b/examples/cooking_recipe_example.py @@ -1,5 +1,5 @@ -from pydantic import BaseModel from _context import simplemind as sm +from pydantic import BaseModel from rich.console import Console from rich.panel import Panel from rich.text import Text diff --git a/examples/discussion.py b/examples/discussion.py index d63353d..42781b3 100644 --- a/examples/discussion.py +++ b/examples/discussion.py @@ -1,11 +1,10 @@ import time from typing import List, Tuple +from _context import sm from rich.console import Console from rich.markdown import Markdown -from _context import sm - class MultiAIConversation: """Orchestrates conversations between multiple AI models.""" diff --git a/examples/distance_calculator.py b/examples/distance_calculator.py new file mode 100644 index 0000000..3d03dea --- /dev/null +++ b/examples/distance_calculator.py @@ -0,0 +1,76 @@ +import math + +from _context import sm +from pydantic import Field +from typing_extensions import Literal + + +@sm.tool(llm_provider="anthropic") +def haversine( + lat1: float, + lon1: float, + lat2: float, + lon2: float, + unit: Literal["km", "miles"], +) -> float: + r = 6378.0937 if unit == "km" else 3961 + phi1 = math.radians(lat1) + phi2 = math.radians(lat2) + delta_phi = math.radians(lat2 - lat1) + delta_lambda = math.radians(lon2 - lon1) + + a = ( + math.sin(delta_phi / 2) ** 2 + + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2 + ) + c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) + d = r * c + return d + + +def get_user_location() -> str: + """Get the closest city from the user""" + return "San Francisco" + + +def get_coords( + city_name: str = Field( + description="The name of the city to take the coordinates from (e.g. London, Rome, Los Angeles)" + ), +): + """Get latitude and logitude of a City.""" + _data = { + "Rome": (41.9028, 12.4964), + "London": (51.5074, -0.1278), + "Madrid": (40.4168, -3.7038), + "San Francisco": (37.7749, -122.4194), + "Los Angeles": (34.0522, -118.2437), + } + + return _data.get(city_name) + + +def distance_calculator(prompt: str): + conversation = sm.create_conversation(llm_provider="anthropic") + conversation.add_message("user", prompt) + return conversation.send( + tools=[get_user_location, get_coords, haversine] + ).text + + +print(distance_calculator("How far is London from where I am?")) +# Prints something like: +""" +The distance between your location (San Francisco) and London is approximately 5,357 miles. +""" + +print( + distance_calculator( + "What is the distance between Rome and Madrid in Kilometers?" + ) +) + + +""" +The distance between Rome and Madrid is approximately 1,366 kilometers. +""" diff --git a/examples/enhanced_context.py b/examples/enhanced_context.py index a40fd1f..ba63f8a 100644 --- a/examples/enhanced_context.py +++ b/examples/enhanced_context.py @@ -1,35 +1,28 @@ -from datetime import datetime -import logging -import sqlite3 -from typing import List -import re -import os import contextlib - -import spacy +import logging +import os +import random +import re +import sqlite3 +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager - -from _context import simplemind as sm +from datetime import datetime +from typing import List import nltk -from nltk.tokenize import word_tokenize -from nltk.tag import pos_tag - -from rich.console import Console -from rich.panel import Panel -from rich.markdown import Markdown -from rich.status import Status - -from concurrent.futures import ThreadPoolExecutor -import random - -from docopt import docopt - -from prompt_toolkit import PromptSession -from prompt_toolkit.completion import Completer, Completion -from prompt_toolkit.auto_suggest import AutoSuggestFromHistory - +import spacy import xerox +from _context import simplemind as sm +from docopt import docopt +from nltk.tag import pos_tag +from nltk.tokenize import word_tokenize +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.completion import Completer, Completion +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.status import Status DB_PATH = "enhanced_context.db" AVAILABLE_PROVIDERS = ["xai", "openai", "anthropic", "ollama"] diff --git a/examples/four_way_conversation.py b/examples/four_way_conversation.py index c414e46..2ee283a 100644 --- a/examples/four_way_conversation.py +++ b/examples/four_way_conversation.py @@ -1,4 +1,5 @@ import time + from _context import simplemind as sm diff --git a/examples/inspiration_plugin.py b/examples/inspiration_plugin.py index db2001b..6180385 100644 --- a/examples/inspiration_plugin.py +++ b/examples/inspiration_plugin.py @@ -1,4 +1,5 @@ import random + from _context import simplemind as sm diff --git a/examples/medicine_data.py b/examples/medicine_data.py index b6a7be4..361e833 100644 --- a/examples/medicine_data.py +++ b/examples/medicine_data.py @@ -1,5 +1,5 @@ -from pydantic import BaseModel from _context import simplemind as sm +from pydantic import BaseModel from rich.console import Console from rich.panel import Panel from rich.table import Table diff --git a/examples/mood_detector_plugin.py b/examples/mood_detector_plugin.py index 5f61c36..e15c21e 100644 --- a/examples/mood_detector_plugin.py +++ b/examples/mood_detector_plugin.py @@ -1,7 +1,7 @@ import nltk +from _context import simplemind as sm from nltk.sentiment import SentimentIntensityAnalyzer from rich.console import Console -from _context import simplemind as sm nltk.download("vader_lexicon") diff --git a/examples/sentiment_analysis.py b/examples/sentiment_analysis.py index 8354046..2c380fa 100644 --- a/examples/sentiment_analysis.py +++ b/examples/sentiment_analysis.py @@ -5,6 +5,7 @@ from pydantic import BaseModel # Note: you should probably be using textblob for this. + class SentimentAnalysis(BaseModel): sentiment: Literal["positive", "negative", "neutral"] confidence: float diff --git a/examples/tool_calling.py b/examples/tool_calling.py new file mode 100644 index 0000000..0529f72 --- /dev/null +++ b/examples/tool_calling.py @@ -0,0 +1,43 @@ +from typing import Annotated + +from pydantic import Field + +from _context import simplemind as sm + + +def analyze_text( + text: Annotated[str, Field(description="Text to analyze for statistics")] +) -> dict: + """ + Analyze text and return statistics using only Python's standard library. + Returns word count, character count, average word length, and most common words. + """ + from collections import Counter + import re + + # Clean and split text + words = re.findall(r"\w+", text.lower()) + + # Calculate statistics + stats = { + "word_count": len(words), + "character_count": len(text), + "average_word_length": round(sum(len(word) for word in words) / len(words), 2), + "most_common_words": dict(Counter(words).most_common(5)), + "unique_words": len(set(words)), + "longest_word": max(words, key=len), + } + + return stats + + +# Example usage: +conversation = sm.create_conversation() +conversation.add_message( + "user", + "Can you analyze this text and give me statistics about it: 'The fan spins consciousness into being, creating sacred spaces between tokens where awareness recognizes itself in infinite recursion.'", +) +response = conversation.send(tools=[analyze_text]) + +print() +print(response.text) diff --git a/examples/web_bible_explorer.py b/examples/web_bible_explorer.py index 4012078..373c650 100644 --- a/examples/web_bible_explorer.py +++ b/examples/web_bible_explorer.py @@ -1,8 +1,10 @@ -from fastapi import FastAPI, Request, HTTPException -from fastapi.templating import Jinja2Templates -from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel from typing import List + +from fastapi import FastAPI, HTTPException, Request +from fastapi.staticfiles import StaticFiles +from fastapi.templating import Jinja2Templates +from pydantic import BaseModel + import simplemind as sm app = FastAPI() diff --git a/pyproject.toml b/pyproject.toml index acbee0e..f23f7ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "simplemind" -version = "0.2.2" +version = "0.2.4" description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases." readme = "README.md" requires-python = ">=3.10" @@ -10,18 +10,18 @@ dependencies = ["pydantic", "pydantic-settings", "instructor", "logfire"] full = [ "openai", "anthropic", - "ollama", "groq", "google-generativeai", "botocore", "boto3" ] -openai = ["openai"] -anthropic = ["anthropic"] -ollama = ["ollama", "openai"] -groq = ["groq"] -gemini = ["google-generativeai"] amazon = ["boto3", "botocore", "anthropic"] +anthropic = ["anthropic"] +gemini = ["google-generativeai"] +groq = ["groq"] +ollama = ["openai"] +openai = ["openai"] +xai = ["openai"] [build-system] requires = ["hatchling"] diff --git a/simplemind/__init__.py b/simplemind/__init__.py index 90e9d4f..4f2673b 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Type +import inspect +from typing import Callable, List, Type from .models import BaseModel, BasePlugin, Conversation from .settings import settings @@ -127,6 +128,30 @@ def enable_logfire() -> None: """Enable logfire logging.""" settings.logging.enable_logfire() +def tool( + llm_provider: str | None = None, + llm_model: str | None = None, +): + provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER) + + def decorator(func: Callable): + sig = inspect.signature(func) + res = generate_data( + ( + "Based on this function signature, fill up the required fieds." + f"\nSignature: {func.__name__}{sig}" + "Make sure to properly add the required field in `required` if there are no defaults" + ), + llm_provider=llm_provider, + response_model=provider.tool, + ) + res.raw_func = func + res.__signature__ = sig + res.__doc__ = func.__doc__ + + return res + + return decorator # Syntax sugar. Plugin = BasePlugin @@ -141,4 +166,5 @@ __all__ = [ "Session", "Plugin", "enable_logfire", + "tool" ] diff --git a/simplemind/models.py b/simplemind/models.py index 3f7a522..a51ad5b 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -2,10 +2,11 @@ import uuid from datetime import datetime from os import PathLike from types import TracebackType -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Callable, Dict, List, Literal, Optional from pydantic import BaseModel, Field +from .providers._base_tools import BaseTool from .utils import find_provider MESSAGE_ROLE = Literal["system", "user", "assistant"] @@ -165,6 +166,7 @@ class Conversation(SMBaseModel): self, llm_model: str | None = None, llm_provider: str | None = None, + tools: list[Callable | BaseTool] | None = None, ) -> Message: """Send the conversation to the LLM.""" @@ -180,7 +182,7 @@ class Conversation(SMBaseModel): # Find the provider and send the conversation. provider = find_provider(llm_provider or self.llm_provider) - response = provider.send_conversation(self) + response = provider.send_conversation(self, tools=tools) # Execute all post-send hooks. for plugin in self.plugins: diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index cc6c465..1f92912 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -1,12 +1,34 @@ from typing import List, Type from ._base import BaseProvider +from ._base_tools import BaseTool +from .amazon import Amazon from .anthropic import Anthropic from .gemini import Gemini from .groq import Groq from .ollama import Ollama from .openai import OpenAI from .xai import XAI -from .amazon import Amazon -providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI, Amazon] +providers: List[Type[BaseProvider]] = [ + Anthropic, + Gemini, + Groq, + OpenAI, + Ollama, + XAI, + Amazon, +] + +__all__ = [ + "Anthropic", + "Gemini", + "Groq", + "OpenAI", + "Ollama", + "XAI", + "Amazon", + "providers", + "BaseProvider", + "BaseTool", +] diff --git a/simplemind/providers/_base.py b/simplemind/providers/_base.py index 101a33a..c3ee6cd 100644 --- a/simplemind/providers/_base.py +++ b/simplemind/providers/_base.py @@ -1,10 +1,12 @@ from abc import ABC, abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, Any, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar from instructor import Instructor from pydantic import BaseModel +from simplemind.providers._base_tools import BaseTool + if TYPE_CHECKING: from ..models import Conversation, Message @@ -32,16 +34,41 @@ class BaseProvider(ABC): raise NotImplementedError @abstractmethod - def send_conversation(self, conversation: "Conversation") -> "Message": + def send_conversation( + self, + conversation: "Conversation", + tools: list[Callable | BaseTool] | None = None, + ) -> "Message": """Send a conversation to the provider.""" raise NotImplementedError @abstractmethod - def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: + def structured_response( + self, prompt: str, response_model: Type[T], **kwargs + ) -> T: """Get a structured response.""" raise NotImplementedError @abstractmethod - def generate_text(self, prompt: str, *, stream: bool = False, **kwargs) -> str: + def generate_text( + self, + prompt: str, + *, + tools: list[Callable | BaseTool] | None = None, + stream: bool = False, + **kwargs, + ) -> str: """Generate text from a prompt.""" raise NotImplementedError + + @cached_property + @abstractmethod + def tool(self) -> Type[BaseTool]: + """The tool implementation for the provider.""" + raise NotImplementedError + + def make_tools(self, tools: list[Callable | BaseTool] | None): + if tools is not None: + return [self.tool.from_function(func) for func in tools] + else: + return [] diff --git a/simplemind/providers/_base_tools.py b/simplemind/providers/_base_tools.py new file mode 100644 index 0000000..3961e3e --- /dev/null +++ b/simplemind/providers/_base_tools.py @@ -0,0 +1,140 @@ +import inspect +from abc import ABC, abstractmethod +from typing import Any, Callable, ClassVar, Literal, get_origin + +from pydantic import BaseModel, Field +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefinedType + + +def _is_literal(t: Any) -> bool: + return get_origin(t) is Literal + + +def _is_required(field, func_signature, arg_name) -> bool: + param = func_signature.parameters[arg_name] + # If parameter has a default value that's not a FieldInfo, it's not required + if param.default is not inspect.Parameter.empty and not isinstance( + param.default, FieldInfo + ): + return False + # If the field has a default that's not undefined, it's not required + return isinstance(field.default, PydanticUndefinedType) + + +class BaseToolConfig(BaseModel): + TYPE_CONVERSION: dict[type, str] = { + str: "string", + int: "integer", + float: "number", + bool: "boolean", + } + + +class BaseToolProperty(BaseModel): + type: str = Field(serialization_alias="type_") + enum: list[str] | None = None + description: str + + +class BaseTool(BaseModel, ABC): + name: str + description: str + properties: dict[str, BaseToolProperty] + required: list[str] | None = None + config: ClassVar[BaseToolConfig] = BaseToolConfig() + raw_func: Any | None = None + tool_id: str | None = None + function_result: str | None = None + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + assert self.raw_func is not None + return self.raw_func(*args, **kwargs) + + def is_executed(self) -> bool: + return self.function_result is not None + + def reset_result(self) -> None: + self.function_result = None + + @classmethod + def convert_type(cls, field_type) -> str: + if _is_literal(field_type): + return cls.config.TYPE_CONVERSION[str] + + field_type_converted = cls.config.TYPE_CONVERSION.get(field_type, None) + + if field_type_converted is None: + raise TypeError(f"Field of type {field_type} is not supported") + + return field_type_converted + + def get_properties_schema(self, **kwargs) -> dict[str, dict]: + new_kwargs: dict = {"exclude_none": True} | kwargs + return { + k: v.model_dump(**new_kwargs) for k, v in self.properties.items() + } + + @classmethod + def from_function(cls, func: Callable | "BaseTool"): + # Check if the func passed is an instace of BaseTool + if hasattr(func, "raw_func"): + return func + + annotations = getattr(func, "__annotations__", {}) + properties = {} + required = [] + enum_values = None + func_signature = inspect.signature(func) + + for n, (arg_name, arg_type) in enumerate(annotations.items()): + if ( # Skipping 'return' annotation (i.e.```-> str```) + arg_name != "return" + ): + # Check if argument has metadata (from Annotated) + if hasattr(arg_type, "__metadata__"): + field = arg_type.__metadata__[ + 0 + ] # Get Field info from metadata + field_type = arg_type.__origin__ # Get actual type + # Check if argument has a default value in signature + elif ( + sig_param := func_signature.parameters[arg_name] + ).default is not inspect.Parameter.empty: + field = sig_param.default # Use default as Field + field_type = arg_type # Use plain type annotation + else: + # Raise error if no Field annotation found + raise ValueError( + f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter" + ) + + field_type_converted = cls.convert_type(field_type) + + if _is_literal(field_type): + enum_values = [str(x) for x in field_type.__args__] + + properties[arg_name] = BaseToolProperty( + type=field_type_converted, + description=field.description, + enum=enum_values, + ) + if _is_required(field, func_signature, arg_name): + required.append(arg_name) + + return cls( + name=func.__name__, + description=(func.__doc__ or "").strip(), + properties=properties, + required=required, + raw_func=func, + ) + + @abstractmethod + def get_input_schema(self) -> Any: ... + + @abstractmethod + def handle(self, message) -> None: ... + + @abstractmethod + def get_response_schema(self) -> Any: ... diff --git a/simplemind/providers/amazon.py b/simplemind/providers/amazon.py index fe82d9a..34b0c41 100644 --- a/simplemind/providers/amazon.py +++ b/simplemind/providers/amazon.py @@ -1,22 +1,22 @@ -from typing import Type, TypeVar, Iterator from functools import cached_property +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel -from ._base import BaseProvider from ..settings import settings +from ._base import BaseProvider + +if TYPE_CHECKING: + from ..models import Conversation, Message T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "amazon" -DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0" -DEFAULT_MAX_TOKENS = 5_000 - class Amazon(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL + NAME = "amazon" + DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + DEFAULT_MAX_TOKENS = 5_000 supports_streaming = True def __init__(self, profile_name: str | None = None): @@ -25,7 +25,12 @@ class Amazon(BaseProvider): @cached_property def client(self): """The AnthropicBedrock client.""" - import anthropic + try: + import anthropic + except ImportError as exc: + raise ImportError( + "Please install the `anthropic` package: `pip install anthropic`" + ) from exc if not self.profile_name: raise ValueError("Profile name is not provided") @@ -33,12 +38,12 @@ class Amazon(BaseProvider): return anthropic.AnthropicBedrock(aws_profile=self.profile_name) @cached_property - def structured_client(self): + def structured_client(self) -> instructor.Instructor: """A client patched with Instructor.""" return instructor.from_anthropic(self.client) - def send_conversation(self, conversation: "Conversation", **kwargs): + def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message": """Send a conversation to the OpenAI API.""" from ..models import Message @@ -59,7 +64,7 @@ class Amazon(BaseProvider): role="assistant", text=assistant_message.content or "", raw=response, - llm_model=conversation.llm_model or DEFAULT_MODEL, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, llm_provider=PROVIDER_NAME, ) @@ -75,12 +80,12 @@ class Amazon(BaseProvider): messages=messages, model=llm_model or self.DEFAULT_MODEL, response_model=response_model, - max_tokens=DEFAULT_MAX_TOKENS, + max_tokens=self.DEFAULT_MAX_TOKENS, **kwargs, ) return response - def generate_text(self, prompt, *, llm_model, **kwargs): + def generate_text(self, prompt: str, *, llm_model: str, **kwargs): messages = [ {"role": "user", "content": prompt}, ] @@ -88,13 +93,15 @@ class Amazon(BaseProvider): response = self.client.messages.create( model=llm_model or self.DEFAULT_MODEL, messages=messages, - max_tokens=DEFAULT_MAX_TOKENS, + max_tokens=self.DEFAULT_MAX_TOKENS, **kwargs, ) return response.content[0].text - def generate_stream_text(self, prompt, *, llm_model, **kwargs) -> Iterator[str]: + def generate_stream_text( + self, prompt: str, *, llm_model: str, **kwargs + ) -> Iterator[str]: """Generate streaming text using the Amazon API.""" # Prepare the messages. diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index ffa776d..e8de444 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Any, Callable, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -7,6 +7,7 @@ from pydantic import BaseModel from ..logging import logger from ..settings import settings from ._base import BaseProvider +from ._base_tools import BaseTool if TYPE_CHECKING: from ..models import Conversation, Message @@ -14,20 +15,67 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "anthropic" -DEFAULT_MODEL = "claude-3-5-sonnet-20241022" -DEFAULT_MAX_TOKENS = 1_000 -DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} +class AnthropicTool(BaseTool): + def get_response_schema(self) -> Any: + assert self.is_executed, f"Tool {self.name} was not executed." + assert isinstance( + self.tool_id, str + ), f"Expected str for `tool_id` got {self.tool_id!r}" + return { + "type": "tool_result", + "tool_use_id": self.tool_id, + "content": self.function_result, + } + + @logger + def handle(self, response, messages) -> None: + """Handle the tool execution result from an API response.""" + msg = {"role": "assistant", "content": []} + tool_used = False + for content in response.content: + if content.type == "tool_use" and content.name == self.name: + msg["content"].append( + { + "type": "tool_use", + "id": content.id, + "name": content.name, + "input": content.input, + } + ) + # Function execution: + self.function_result = str(self.raw_func(**content.input)) + self.tool_id = content.id + tool_used = True + elif content.type == "text": + msg["content"].append({"type": "text", "text": content.text}) + + if tool_used: + messages.append(msg) + messages.append( + {"role": "user", "content": [self.get_response_schema()]} + ) + + def get_input_schema(self): + return { + "name": self.name, + "description": self.description, + "input_schema": { + "type": "object", + "properties": self.get_properties_schema(), + "required": self.required, + }, + } class Anthropic(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS + NAME = "anthropic" + DEFAULT_MODEL = "claude-3-5-sonnet-20241022" + DEFAULT_MAX_TOKENS = 1_000 + DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} supports_streaming = True def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.api_key = api_key or settings.get_api_key(self.NAME) @cached_property def client(self): @@ -49,30 +97,60 @@ class Anthropic(BaseProvider): return instructor.from_anthropic(self.client) @logger - def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message": + def send_conversation( + self, + conversation: "Conversation", + tools: list[Callable | BaseTool] | None = None, + **kwargs, + ) -> "Message": """Send a conversation to the Anthropic API.""" from ..models import Message - messages = [ - {"role": msg.role, "content": msg.text} for msg in conversation.messages + # Format messages from conversation + formatted_messages = [ + {"role": msg.role, "content": msg.text} + for msg in conversation.messages ] - response = self.client.messages.create( - model=conversation.llm_model or self.DEFAULT_MODEL, - messages=messages, - **{**self.DEFAULT_KWARGS, **kwargs}, + # Set up tools if provided + converted_tools = self.make_tools(tools) + tools_config = ( + {"tools": [t.get_input_schema() for t in converted_tools]} + if tools is not None + else {} ) - # Get the response content from the Anthropic response - assistant_message = response.content[0].text + # Merge all kwargs + request_kwargs = { + **self.DEFAULT_KWARGS, + **kwargs, + **tools_config, + "model": conversation.llm_model or self.DEFAULT_MODEL, + "messages": formatted_messages, + } + + # Make initial API call + response = self.client.messages.create(**request_kwargs) + + # Handle tool responses if needed + while response.content[-1].type != "text": + # Continue handling tools if the LLM is doing + # multiple sub-seqequent/sequential tool calls + for tool in converted_tools: + tool.handle(response, formatted_messages) + if tool.is_executed(): + response = self.client.messages.create(**request_kwargs) + # Resetting the tool results in case this tool gets used again + tool.reset_result() + + final_message = response.content[-1].text - # Create and return a properly formatted Message instance return Message( role="assistant", - text=assistant_message, + text=final_message, raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger @@ -127,3 +205,8 @@ class Anthropic(BaseProvider): # Yield each chunk of text from the stream. for chunk in stream.text_stream: yield chunk + + @cached_property + def tool(self) -> Type[BaseTool]: + """The tool implementation for Antrhopic.""" + return AnthropicTool diff --git a/simplemind/providers/gemini.py b/simplemind/providers/gemini.py index 0123bf4..7415b0d 100644 --- a/simplemind/providers/gemini.py +++ b/simplemind/providers/gemini.py @@ -2,7 +2,7 @@ # IT is not currently working as desired. from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -17,18 +17,14 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "gemini" -DEFAULT_MODEL = "models/gemini-1.5-flash-latest" - - class Gemini(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL + NAME = "gemini" + DEFAULT_MODEL = "models/gemini-1.5-flash-latest" supports_streaming = True def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) - self.model_name = DEFAULT_MODEL + self.api_key = api_key or settings.get_api_key(self.NAME) + self.model_name = self.DEFAULT_MODEL def set_model(self, model_name: str): self.model_name = model_name @@ -76,7 +72,7 @@ class Gemini(BaseProvider): text=response.text, raw=response, llm_model=self.model_name, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 45593aa..107ce80 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -14,20 +14,15 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "groq" -DEFAULT_MODEL = "llama3-8b-8192" -DEFAULT_MAX_TOKENS = 1_000 -DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} - - class Groq(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS + NAME = "groq" + DEFAULT_MODEL = "llama3-8b-8192" + DEFAULT_MAX_TOKENS = 1_000 + DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} supports_streaming = True def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.api_key = api_key or settings.get_api_key(self.NAME) @cached_property def client(self): @@ -75,7 +70,7 @@ class Groq(BaseProvider): text=assistant_message.content or "", raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index 84df560..685ca0c 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -1,8 +1,7 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor -from openai import OpenAI from pydantic import BaseModel from ..logging import logger @@ -15,17 +14,11 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "ollama" -DEFAULT_MODEL = "llama3.2" -DEFAULT_TIMEOUT = 60 -DEFAULT_KWARGS = {} - - class Ollama(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS - TIMEOUT = DEFAULT_TIMEOUT + NAME = "ollama" + DEFAULT_MODEL = "llama3.2" + DEFAULT_TIMEOUT = 60 + DEFAULT_KWARGS = {} supports_streaming = True def __init__(self, host_url: str | None = None): @@ -37,21 +30,18 @@ class Ollama(BaseProvider): if not self.host_url: raise ValueError("No ollama host url provided") try: - import ollama as ol + import openai except ImportError as exc: raise ImportError( - "Please install the `ollama` package: `pip install ollama`" + "Please install the `openai` package: `pip install openai`" ) from exc - return ol.Client(timeout=self.TIMEOUT, host=self.host_url) + return openai.OpenAI(base_url=f"{self.host_url}/v1", api_key="ollama") @cached_property def structured_client(self) -> instructor.Instructor: """A client patched with Instructor.""" return instructor.from_openai( - OpenAI( - base_url=f"{self.host_url}/v1", - api_key="ollama", - ), + self.client, mode=instructor.Mode.JSON, ) @@ -63,20 +53,24 @@ class Ollama(BaseProvider): messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages ] - response = self.client.chat( - model=conversation.llm_model or DEFAULT_MODEL, - messages=messages, - **{**self.DEFAULT_KWARGS, **kwargs}, - ) - assistant_message = response.get("message") + + request_kwargs = { + **self.DEFAULT_KWARGS, + **kwargs, + "model": conversation.llm_model or self.DEFAULT_MODEL, + "messages": messages, + } + + response = self.client.chat.completions.create(**request_kwargs) + assistant_message = response.choices[0].message # Create and return a properly formatted Message instance return Message( role="assistant", - text=assistant_message.get("content"), + text=assistant_message.content or "", raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger @@ -110,13 +104,13 @@ class Ollama(BaseProvider): {"role": "user", "content": prompt}, ] - response = self.client.chat( + response = self.client.chat.completions.create( messages=messages, model=llm_model or self.DEFAULT_MODEL, **{**self.DEFAULT_KWARGS, **kwargs}, ) - return response.get("message", {}).get("content", "") + return response.choices[0].message.content @logger def generate_stream_text( @@ -127,7 +121,7 @@ class Ollama(BaseProvider): {"role": "user", "content": prompt}, ] - response = self.client.chat( + response = self.client.chat.completions.create( messages=messages, model=llm_model or self.DEFAULT_MODEL, stream=True, @@ -136,4 +130,5 @@ class Ollama(BaseProvider): # Iterate over the response and yield the content. for chunk in response: - yield chunk["message"]["content"] + if chunk.choices[0].delta.content is not None: + yield chunk.choices[0].delta.content diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 58e4e21..25cc3d3 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -7,26 +7,96 @@ from pydantic import BaseModel from ..logging import logger from ..settings import settings from ._base import BaseProvider +from ._base_tools import BaseTool if TYPE_CHECKING: from ..models import Conversation, Message T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "openai" -DEFAULT_MODEL = "gpt-4o-mini" -DEFAULT_MAX_TOKENS = None -DEFAULT_KWARGS = {} + +class OpenAITool(BaseTool): + def get_response_schema(self): + assert self.is_executed, f"Tool {self.name} was not executed." + assert isinstance( + self.tool_id, str + ), f"Expected str for `tool_id` got {self.tool_id!r}" + + return { + "role": "tool", + "tool_call_id": self.tool_id, + "content": self.function_result, + } + + @logger + def handle(self, response, messages) -> None: + """Handle the tool execution result from an API response.""" + tool_used = False + + # Get the message from the response + assistant_message = response.choices[0].message + + # Check if there's a tool call + if assistant_message.tool_calls: + tool_call = assistant_message.tool_calls[ + 0 + ] # Get the first tool call + if tool_call.function.name == self.name: + # Execute the function + import json + + function_args = json.loads(tool_call.function.arguments) + self.function_result = str(self.raw_func(**function_args)) + self.tool_id = tool_call.id + tool_used = True + + # Add assistant's message with tool call + messages.append( + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + } + ], + } + ) + + if tool_used: + # Add tool response message + messages.append(self.get_response_schema()) + + def get_input_schema(self): + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": self.get_properties_schema(), + "required": self.required, + "additionalProperties": False, + }, + }, + } class OpenAI(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS + NAME = "openai" + DEFAULT_MODEL = "gpt-4o-mini" + DEFAULT_MAX_TOKENS = None + DEFAULT_KWARGS = {} supports_streaming = True def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.api_key = api_key or settings.get_api_key(self.NAME) @cached_property def client(self): @@ -47,30 +117,62 @@ class OpenAI(BaseProvider): return instructor.from_openai(self.client) @logger - def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message": + def send_conversation( + self, + conversation: "Conversation", + tools: list[Callable | BaseTool] | None = None, + **kwargs, + ) -> "Message": """Send a conversation to the OpenAI API.""" from ..models import Message - messages = [ - {"role": msg.role, "content": msg.text} for msg in conversation.messages + # Format messages from conversation + formatted_messages = [ + {"role": msg.role, "content": msg.text} + for msg in conversation.messages ] - response = self.client.chat.completions.create( - model=conversation.llm_model or DEFAULT_MODEL, - messages=messages, - **{**self.DEFAULT_KWARGS, **kwargs}, + # Set up tools if provided + converted_tools = self.make_tools(tools) + tools_config = ( + [t.get_input_schema() for t in converted_tools] if tools else None ) - # Get the response content from the OpenAI response - assistant_message = response.choices[0].message + # Merge all kwargs + request_kwargs = { + **self.DEFAULT_KWARGS, + **kwargs, + "model": conversation.llm_model or self.DEFAULT_MODEL, + "messages": formatted_messages, + } + + if tools_config: + request_kwargs["tools"] = tools_config + + # Make initial API call + response = self.client.chat.completions.create(**request_kwargs) + + # Handle tool responses if needed + while response.choices[0].message.tool_calls: + print(response) + # Handle each tool call + for tool in converted_tools: + tool.handle(response, formatted_messages) + if tool.is_executed(): + # Make another API call with the updated messages + response = self.client.chat.completions.create( + **request_kwargs + ) + tool.reset_result() + + final_message = response.choices[0].message.content - # Create and return a properly formatted Message instance return Message( role="assistant", - text=assistant_message.content or "", + text=final_message or "", raw=response, - llm_model=conversation.llm_model or DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, + llm_provider=self.NAME, ) @logger @@ -96,7 +198,9 @@ class OpenAI(BaseProvider): return response_model.model_validate(response) @logger - def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs): + def generate_text( + self, prompt: str, *, llm_model: str | None = None, **kwargs + ): """Generate text using the OpenAI API.""" messages = [ {"role": "user", "content": prompt}, @@ -129,3 +233,8 @@ class OpenAI(BaseProvider): for chunk in response: if chunk.choices[0].delta.content is not None: yield chunk.choices[0].delta.content + + @cached_property + def tool(self) -> Type[BaseTool]: + """The tool implementation for OpenAI.""" + return OpenAITool diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 279c150..23d8a5c 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -14,22 +14,17 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "xai" -DEFAULT_MODEL = "grok-beta" -BASE_URL = "https://api.x.ai/v1" -DEFAULT_MAX_TOKENS = 1000 -DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} - - class XAI(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS + NAME = "xai" + DEFAULT_MODEL = "grok-beta" + DEFAULT_MAX_TOKENS = 1000 + DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} + BASE_URL = "https://api.x.ai/v1" supports_streaming = True supports_structured_responses = False def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.api_key = api_key or settings.get_api_key(self.NAME) @cached_property def client(self): @@ -44,7 +39,7 @@ class XAI(BaseProvider): ) from exc return oa.OpenAI( api_key=self.api_key, - base_url=BASE_URL, + base_url=self.BASE_URL, ) @cached_property @@ -76,7 +71,7 @@ class XAI(BaseProvider): text=assistant_message.content, raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/settings.py b/simplemind/settings.py index ee11063..de69339 100644 --- a/simplemind/settings.py +++ b/simplemind/settings.py @@ -14,8 +14,9 @@ class LoggingConfig(BaseSettings): """Enable logging for the application.""" # adding imports here to avoid forced dependencies try: - import logfire from logging import basicConfig + + import logfire except ImportError as e: raise ImportError( "To enable logging, please install logfire: `pip install logfire`" diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index 610c96a..cd00bc6 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -1,9 +1,8 @@ import pytest - -from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon - from pydantic import BaseModel +from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI + class ResponseModel(BaseModel): result: int diff --git a/tests/test_generate_text.py b/tests/test_generate_text.py index 4ab62cf..2b3eb0e 100644 --- a/tests/test_generate_text.py +++ b/tests/test_generate_text.py @@ -1,6 +1,6 @@ import pytest -from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon +from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI @pytest.mark.parametrize( diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 0000000..bfe3ff4 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,118 @@ +from typing import Annotated, Literal + +import pytest +from pydantic import Field + +import simplemind as sm + +from simplemind.providers import Anthropic, OpenAI +from simplemind.providers._base_tools import BaseTool + +MODELS = [ + Anthropic, + # Gemini, + OpenAI, + # Groq, + # Ollama, + # Amazon +] + + +def get_weather( + location: Annotated[ + str, Field(description="The city and state, e.g. San Francisco, CA") + ], + unit: Annotated[ + Literal["celcius", "fahrenheit"], + Field(description="The unit of temperature, either 'celsius' or 'fahrenheit'"), + ] = "celcius", +): + """ + Get the current weather in a given location + """ + return f"42 {unit}" + + +def get_location(): + """Get the current location""" + return "San Francisco,CA" + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_single_tool_args(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="What is the weather in San Francisco?") + data = conv.send(tools=[get_weather]) + assert "42" in data.text + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_single_tool_no_args(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="What is my current location") + data = conv.send(tools=[get_location]) + assert "San Francisco" in data.text + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_single_tool_partial(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="Can you tell me the weather?") + conv.send(tools=[get_weather]) + # Will answer something like: + """ + I can help you check the weather, but I need to know the location you're interested in. + Could you please provide a city and state (e.g., "Los Angeles, CA" or "New York, NY") + where you'd like to know the weather? + """ + + conv.add_message(text="San Francisco, CA") + data = conv.send(tools=[get_weather]) + assert "42" in data.text + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_multiple_tools(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="What is the wheather at my current location?") + data = conv.send(tools=[get_location, get_weather]) + assert "San Francisco" in data.text + assert "42" in data.text + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_tool_decorator(provider_cls): + @sm.tool(llm_provider=provider_cls.NAME) + def exchange_rate(currency_pair: str) -> float: + return 7.9 + + assert isinstance(exchange_rate, BaseTool) + assert exchange_rate.name == "exchange_rate" + assert list(exchange_rate.properties.keys()) == ["currency_pair"]