From 670240b9438686c41b756983e66c5cd8e84959c3 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Sun, 10 Nov 2024 19:59:52 +0530 Subject: [PATCH 01/24] removed reduntant variables. moved few inside the class --- simplemind/providers/__init__.py | 24 ++++++++++++++++-- simplemind/providers/amazon.py | 41 +++++++++++++++++++------------ simplemind/providers/anthropic.py | 19 ++++++-------- simplemind/providers/gemini.py | 16 +++++------- simplemind/providers/groq.py | 19 ++++++-------- simplemind/providers/ollama.py | 30 ++++++++-------------- simplemind/providers/openai.py | 22 +++++++---------- simplemind/providers/xai.py | 23 +++++++---------- 8 files changed, 95 insertions(+), 99 deletions(-) diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index cc6c465..1f72f78 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -1,12 +1,32 @@ from typing import List, Type from ._base import BaseProvider +from .amazon import Amazon from .anthropic import Anthropic from .gemini import Gemini from .groq import Groq from .ollama import Ollama from .openai import OpenAI from .xai import XAI -from .amazon import Amazon -providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI, Amazon] +providers: List[Type[BaseProvider]] = [ + Anthropic, + Gemini, + Groq, + OpenAI, + Ollama, + XAI, + Amazon, +] + +__all__ = [ + "Anthropic", + "Gemini", + "Groq", + "OpenAI", + "Ollama", + "XAI", + "Amazon", + "providers", + "BaseProvider", +] diff --git a/simplemind/providers/amazon.py b/simplemind/providers/amazon.py index fe82d9a..492f2ab 100644 --- a/simplemind/providers/amazon.py +++ b/simplemind/providers/amazon.py @@ -1,22 +1,24 @@ -from typing import Type, TypeVar, Iterator from functools import cached_property +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel -from ._base import BaseProvider +from simplemind.models import Message + from ..settings import settings +from ._base import BaseProvider + +if TYPE_CHECKING: + from ..models import Conversation, Message T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "amazon" -DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0" -DEFAULT_MAX_TOKENS = 5_000 - class Amazon(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL + NAME = "amazon" + DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + DEFAULT_MAX_TOKENS = 5_000 supports_streaming = True def __init__(self, profile_name: str | None = None): @@ -25,7 +27,12 @@ class Amazon(BaseProvider): @cached_property def client(self): """The AnthropicBedrock client.""" - import anthropic + try: + import anthropic + except ImportError as exc: + raise ImportError( + "Please install the `anthropic` package: `pip install anthropic`" + ) from exc if not self.profile_name: raise ValueError("Profile name is not provided") @@ -33,12 +40,12 @@ class Amazon(BaseProvider): return anthropic.AnthropicBedrock(aws_profile=self.profile_name) @cached_property - def structured_client(self): + def structured_client(self) -> instructor.Instructor: """A client patched with Instructor.""" return instructor.from_anthropic(self.client) - def send_conversation(self, conversation: "Conversation", **kwargs): + def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message": """Send a conversation to the OpenAI API.""" from ..models import Message @@ -59,7 +66,7 @@ class Amazon(BaseProvider): role="assistant", text=assistant_message.content or "", raw=response, - llm_model=conversation.llm_model or DEFAULT_MODEL, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, llm_provider=PROVIDER_NAME, ) @@ -75,12 +82,12 @@ class Amazon(BaseProvider): messages=messages, model=llm_model or self.DEFAULT_MODEL, response_model=response_model, - max_tokens=DEFAULT_MAX_TOKENS, + max_tokens=self.DEFAULT_MAX_TOKENS, **kwargs, ) return response - def generate_text(self, prompt, *, llm_model, **kwargs): + def generate_text(self, prompt: str, *, llm_model: str, **kwargs): messages = [ {"role": "user", "content": prompt}, ] @@ -88,13 +95,15 @@ class Amazon(BaseProvider): response = self.client.messages.create( model=llm_model or self.DEFAULT_MODEL, messages=messages, - max_tokens=DEFAULT_MAX_TOKENS, + max_tokens=self.DEFAULT_MAX_TOKENS, **kwargs, ) return response.content[0].text - def generate_stream_text(self, prompt, *, llm_model, **kwargs) -> Iterator[str]: + def generate_stream_text( + self, prompt: str, *, llm_model: str, **kwargs + ) -> Iterator[str]: """Generate streaming text using the Amazon API.""" # Prepare the messages. diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index ffa776d..6fbec5b 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -14,20 +14,15 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "anthropic" -DEFAULT_MODEL = "claude-3-5-sonnet-20241022" -DEFAULT_MAX_TOKENS = 1_000 -DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} - - class Anthropic(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS + NAME = "anthropic" + DEFAULT_MODEL = "claude-3-5-sonnet-20241022" + DEFAULT_MAX_TOKENS = 1_000 + DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} supports_streaming = True def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.api_key = api_key or settings.get_api_key(self.NAME) @cached_property def client(self): @@ -72,7 +67,7 @@ class Anthropic(BaseProvider): text=assistant_message, raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/providers/gemini.py b/simplemind/providers/gemini.py index 0123bf4..7415b0d 100644 --- a/simplemind/providers/gemini.py +++ b/simplemind/providers/gemini.py @@ -2,7 +2,7 @@ # IT is not currently working as desired. from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -17,18 +17,14 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "gemini" -DEFAULT_MODEL = "models/gemini-1.5-flash-latest" - - class Gemini(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL + NAME = "gemini" + DEFAULT_MODEL = "models/gemini-1.5-flash-latest" supports_streaming = True def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) - self.model_name = DEFAULT_MODEL + self.api_key = api_key or settings.get_api_key(self.NAME) + self.model_name = self.DEFAULT_MODEL def set_model(self, model_name: str): self.model_name = model_name @@ -76,7 +72,7 @@ class Gemini(BaseProvider): text=response.text, raw=response, llm_model=self.model_name, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index 45593aa..107ce80 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -14,20 +14,15 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "groq" -DEFAULT_MODEL = "llama3-8b-8192" -DEFAULT_MAX_TOKENS = 1_000 -DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} - - class Groq(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS + NAME = "groq" + DEFAULT_MODEL = "llama3-8b-8192" + DEFAULT_MAX_TOKENS = 1_000 + DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} supports_streaming = True def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.api_key = api_key or settings.get_api_key(self.NAME) @cached_property def client(self): @@ -75,7 +70,7 @@ class Groq(BaseProvider): text=assistant_message.content or "", raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index 84df560..50d529d 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -1,8 +1,7 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor -from openai import OpenAI from pydantic import BaseModel from ..logging import logger @@ -15,17 +14,11 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "ollama" -DEFAULT_MODEL = "llama3.2" -DEFAULT_TIMEOUT = 60 -DEFAULT_KWARGS = {} - - class Ollama(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS - TIMEOUT = DEFAULT_TIMEOUT + NAME = "ollama" + DEFAULT_MODEL = "llama3.2" + DEFAULT_TIMEOUT = 60 + DEFAULT_KWARGS = {} supports_streaming = True def __init__(self, host_url: str | None = None): @@ -37,21 +30,18 @@ class Ollama(BaseProvider): if not self.host_url: raise ValueError("No ollama host url provided") try: - import ollama as ol + import openai except ImportError as exc: raise ImportError( "Please install the `ollama` package: `pip install ollama`" ) from exc - return ol.Client(timeout=self.TIMEOUT, host=self.host_url) + return openai.OpenAI(base_url=f"{self.host_url}/v1", api_key="ollama") @cached_property def structured_client(self) -> instructor.Instructor: """A client patched with Instructor.""" return instructor.from_openai( - OpenAI( - base_url=f"{self.host_url}/v1", - api_key="ollama", - ), + self.client, mode=instructor.Mode.JSON, ) @@ -64,7 +54,7 @@ class Ollama(BaseProvider): {"role": msg.role, "content": msg.text} for msg in conversation.messages ] response = self.client.chat( - model=conversation.llm_model or DEFAULT_MODEL, + model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, **{**self.DEFAULT_KWARGS, **kwargs}, ) @@ -76,7 +66,7 @@ class Ollama(BaseProvider): text=assistant_message.get("content"), raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 58e4e21..84db110 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -13,20 +13,16 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "openai" -DEFAULT_MODEL = "gpt-4o-mini" -DEFAULT_MAX_TOKENS = None -DEFAULT_KWARGS = {} - class OpenAI(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS + NAME = "openai" + DEFAULT_MODEL = "gpt-4o-mini" + DEFAULT_MAX_TOKENS = None + DEFAULT_KWARGS = {} supports_streaming = True def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.api_key = api_key or settings.get_api_key(self.NAME) @cached_property def client(self): @@ -56,7 +52,7 @@ class OpenAI(BaseProvider): ] response = self.client.chat.completions.create( - model=conversation.llm_model or DEFAULT_MODEL, + model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, **{**self.DEFAULT_KWARGS, **kwargs}, ) @@ -69,8 +65,8 @@ class OpenAI(BaseProvider): role="assistant", text=assistant_message.content or "", raw=response, - llm_model=conversation.llm_model or DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, + llm_provider=self.NAME, ) @logger diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 279c150..23d8a5c 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -14,22 +14,17 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) -PROVIDER_NAME = "xai" -DEFAULT_MODEL = "grok-beta" -BASE_URL = "https://api.x.ai/v1" -DEFAULT_MAX_TOKENS = 1000 -DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} - - class XAI(BaseProvider): - NAME = PROVIDER_NAME - DEFAULT_MODEL = DEFAULT_MODEL - DEFAULT_KWARGS = DEFAULT_KWARGS + NAME = "xai" + DEFAULT_MODEL = "grok-beta" + DEFAULT_MAX_TOKENS = 1000 + DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} + BASE_URL = "https://api.x.ai/v1" supports_streaming = True supports_structured_responses = False def __init__(self, api_key: str | None = None): - self.api_key = api_key or settings.get_api_key(PROVIDER_NAME) + self.api_key = api_key or settings.get_api_key(self.NAME) @cached_property def client(self): @@ -44,7 +39,7 @@ class XAI(BaseProvider): ) from exc return oa.OpenAI( api_key=self.api_key, - base_url=BASE_URL, + base_url=self.BASE_URL, ) @cached_property @@ -76,7 +71,7 @@ class XAI(BaseProvider): text=assistant_message.content, raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=PROVIDER_NAME, + llm_provider=self.NAME, ) @logger From fe2ca9d5f5f19a060e3f9ac0671d8d030445f334 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Sun, 10 Nov 2024 20:00:13 +0530 Subject: [PATCH 02/24] black + isort formatting --- examples/cooking_recipe_example.py | 2 +- examples/discussion.py | 3 +- examples/enhanced_context.py | 47 +++++++++++++----------------- examples/four_way_conversation.py | 1 + examples/inspiration_plugin.py | 1 + examples/medicine_data.py | 2 +- examples/mood_detector_plugin.py | 2 +- examples/sentiment_analysis.py | 1 + examples/web_bible_explorer.py | 10 ++++--- simplemind/settings.py | 3 +- tests/test_conversations.py | 3 +- tests/test_generate_data.py | 5 ++-- tests/test_generate_text.py | 2 +- 13 files changed, 39 insertions(+), 43 deletions(-) diff --git a/examples/cooking_recipe_example.py b/examples/cooking_recipe_example.py index 8abc295..077d9ae 100644 --- a/examples/cooking_recipe_example.py +++ b/examples/cooking_recipe_example.py @@ -1,5 +1,5 @@ -from pydantic import BaseModel from _context import simplemind as sm +from pydantic import BaseModel from rich.console import Console from rich.panel import Panel from rich.text import Text diff --git a/examples/discussion.py b/examples/discussion.py index d63353d..42781b3 100644 --- a/examples/discussion.py +++ b/examples/discussion.py @@ -1,11 +1,10 @@ import time from typing import List, Tuple +from _context import sm from rich.console import Console from rich.markdown import Markdown -from _context import sm - class MultiAIConversation: """Orchestrates conversations between multiple AI models.""" diff --git a/examples/enhanced_context.py b/examples/enhanced_context.py index a40fd1f..ba63f8a 100644 --- a/examples/enhanced_context.py +++ b/examples/enhanced_context.py @@ -1,35 +1,28 @@ -from datetime import datetime -import logging -import sqlite3 -from typing import List -import re -import os import contextlib - -import spacy +import logging +import os +import random +import re +import sqlite3 +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager - -from _context import simplemind as sm +from datetime import datetime +from typing import List import nltk -from nltk.tokenize import word_tokenize -from nltk.tag import pos_tag - -from rich.console import Console -from rich.panel import Panel -from rich.markdown import Markdown -from rich.status import Status - -from concurrent.futures import ThreadPoolExecutor -import random - -from docopt import docopt - -from prompt_toolkit import PromptSession -from prompt_toolkit.completion import Completer, Completion -from prompt_toolkit.auto_suggest import AutoSuggestFromHistory - +import spacy import xerox +from _context import simplemind as sm +from docopt import docopt +from nltk.tag import pos_tag +from nltk.tokenize import word_tokenize +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.completion import Completer, Completion +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.status import Status DB_PATH = "enhanced_context.db" AVAILABLE_PROVIDERS = ["xai", "openai", "anthropic", "ollama"] diff --git a/examples/four_way_conversation.py b/examples/four_way_conversation.py index c414e46..2ee283a 100644 --- a/examples/four_way_conversation.py +++ b/examples/four_way_conversation.py @@ -1,4 +1,5 @@ import time + from _context import simplemind as sm diff --git a/examples/inspiration_plugin.py b/examples/inspiration_plugin.py index db2001b..6180385 100644 --- a/examples/inspiration_plugin.py +++ b/examples/inspiration_plugin.py @@ -1,4 +1,5 @@ import random + from _context import simplemind as sm diff --git a/examples/medicine_data.py b/examples/medicine_data.py index b6a7be4..361e833 100644 --- a/examples/medicine_data.py +++ b/examples/medicine_data.py @@ -1,5 +1,5 @@ -from pydantic import BaseModel from _context import simplemind as sm +from pydantic import BaseModel from rich.console import Console from rich.panel import Panel from rich.table import Table diff --git a/examples/mood_detector_plugin.py b/examples/mood_detector_plugin.py index 5f61c36..e15c21e 100644 --- a/examples/mood_detector_plugin.py +++ b/examples/mood_detector_plugin.py @@ -1,7 +1,7 @@ import nltk +from _context import simplemind as sm from nltk.sentiment import SentimentIntensityAnalyzer from rich.console import Console -from _context import simplemind as sm nltk.download("vader_lexicon") diff --git a/examples/sentiment_analysis.py b/examples/sentiment_analysis.py index 8354046..2c380fa 100644 --- a/examples/sentiment_analysis.py +++ b/examples/sentiment_analysis.py @@ -5,6 +5,7 @@ from pydantic import BaseModel # Note: you should probably be using textblob for this. + class SentimentAnalysis(BaseModel): sentiment: Literal["positive", "negative", "neutral"] confidence: float diff --git a/examples/web_bible_explorer.py b/examples/web_bible_explorer.py index 4012078..373c650 100644 --- a/examples/web_bible_explorer.py +++ b/examples/web_bible_explorer.py @@ -1,8 +1,10 @@ -from fastapi import FastAPI, Request, HTTPException -from fastapi.templating import Jinja2Templates -from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel from typing import List + +from fastapi import FastAPI, HTTPException, Request +from fastapi.staticfiles import StaticFiles +from fastapi.templating import Jinja2Templates +from pydantic import BaseModel + import simplemind as sm app = FastAPI() diff --git a/simplemind/settings.py b/simplemind/settings.py index ee11063..de69339 100644 --- a/simplemind/settings.py +++ b/simplemind/settings.py @@ -14,8 +14,9 @@ class LoggingConfig(BaseSettings): """Enable logging for the application.""" # adding imports here to avoid forced dependencies try: - import logfire from logging import basicConfig + + import logfire except ImportError as e: raise ImportError( "To enable logging, please install logfire: `pip install logfire`" diff --git a/tests/test_conversations.py b/tests/test_conversations.py index b0c85dd..4747f65 100644 --- a/tests/test_conversations.py +++ b/tests/test_conversations.py @@ -1,8 +1,7 @@ import pytest -from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon - import simplemind as sm +from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI @pytest.mark.parametrize( diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index 610c96a..cd00bc6 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -1,9 +1,8 @@ import pytest - -from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon - from pydantic import BaseModel +from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI + class ResponseModel(BaseModel): result: int diff --git a/tests/test_generate_text.py b/tests/test_generate_text.py index 4ab62cf..2b3eb0e 100644 --- a/tests/test_generate_text.py +++ b/tests/test_generate_text.py @@ -1,6 +1,6 @@ import pytest -from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon +from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI @pytest.mark.parametrize( From e79b474215960b873db51ce59abea8f395715d21 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Sun, 10 Nov 2024 20:05:49 +0530 Subject: [PATCH 03/24] fixed dependencies --- pyproject.toml | 12 ++++++------ simplemind/providers/ollama.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index acbee0e..3135d18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,18 +10,18 @@ dependencies = ["pydantic", "pydantic-settings", "instructor", "logfire"] full = [ "openai", "anthropic", - "ollama", "groq", "google-generativeai", "botocore", "boto3" ] -openai = ["openai"] -anthropic = ["anthropic"] -ollama = ["ollama", "openai"] -groq = ["groq"] -gemini = ["google-generativeai"] amazon = ["boto3", "botocore", "anthropic"] +anthropic = ["anthropic"] +gemini = ["google-generativeai"] +groq = ["groq"] +ollama = ["openai"] +openai = ["openai"] +xai = ["openai"] [build-system] requires = ["hatchling"] diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index 50d529d..4266e68 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -33,7 +33,7 @@ class Ollama(BaseProvider): import openai except ImportError as exc: raise ImportError( - "Please install the `ollama` package: `pip install ollama`" + "Please install the `openai` package: `pip install openai`" ) from exc return openai.OpenAI(base_url=f"{self.host_url}/v1", api_key="ollama") From fe5af93780be2819066abae24449c1af6aa45447 Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Mon, 11 Nov 2024 12:29:00 +0800 Subject: [PATCH 04/24] first draft --- simplemind/providers/_base.py | 22 +++++- simplemind/providers/_base_tools.py | 110 ++++++++++++++++++++++++++++ simplemind/providers/anthropic.py | 32 +++++++- simplemind/providers/openai.py | 36 ++++++++- 4 files changed, 190 insertions(+), 10 deletions(-) create mode 100644 simplemind/providers/_base_tools.py diff --git a/simplemind/providers/_base.py b/simplemind/providers/_base.py index 101a33a..7433759 100644 --- a/simplemind/providers/_base.py +++ b/simplemind/providers/_base.py @@ -1,10 +1,12 @@ from abc import ABC, abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, Any, Type, TypeVar +from typing import TYPE_CHECKING, Any, Type, TypeVar, Callable from instructor import Instructor from pydantic import BaseModel +from simplemind.providers._base_tools import BaseTool + if TYPE_CHECKING: from ..models import Conversation, Message @@ -37,11 +39,25 @@ class BaseProvider(ABC): raise NotImplementedError @abstractmethod - def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: + def structured_response( + self, prompt: str, response_model: Type[T], **kwargs + ) -> T: """Get a structured response.""" raise NotImplementedError @abstractmethod - def generate_text(self, prompt: str, *, stream: bool = False, **kwargs) -> str: + def generate_text( + self, prompt: str, *, stream: bool = False, **kwargs + ) -> str: """Generate text from a prompt.""" raise NotImplementedError + + @cached_property + @abstractmethod + def tool(self) -> Type[BaseTool]: + """The tool implementation for the provider.""" + raise NotImplementedError + + def make_tools(self, tools:list[Callable] | None) + if tools is not None: + return [self.tool.from_function(func) for func in tools] diff --git a/simplemind/providers/_base_tools.py b/simplemind/providers/_base_tools.py new file mode 100644 index 0000000..6922754 --- /dev/null +++ b/simplemind/providers/_base_tools.py @@ -0,0 +1,110 @@ +import inspect +from abc import ABC, abstractmethod +from typing import Any, Callable, ClassVar, Literal, get_origin + +from pydantic import BaseModel, Field +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefinedType + + +def _is_literal(t: Any) -> bool: + return get_origin(t) is Literal + + +def _is_required(field, func_signature, arg_name) -> bool: + param = func_signature.parameters[arg_name] + # If parameter has a default value that's not a FieldInfo, it's not required + if param.default is not inspect.Parameter.empty and not isinstance( + param.default, FieldInfo + ): + return False + # If the field has a default that's not undefined, it's not required + return isinstance(field.default, PydanticUndefinedType) + + +class BaseToolConfig(BaseModel): + TYPE_CONVERSION: dict[type, str] = { + str: "string", + int: "integer", + bool: "boolean", + } + + +class BaseToolProperty(BaseModel): + type: str = Field(serialization_alias="type_") + enum: list[str] | None = None + description: str + + +class BaseTool(BaseModel, ABC): + name: str + description: str + properties: dict[str, BaseToolProperty] + required: list[str] | None = None + config: ClassVar[BaseToolConfig] = BaseToolConfig() + + @classmethod + def convert_type(cls, field_type) -> str: + if _is_literal(field_type): + return cls.config.TYPE_CONVERSION[str] + + field_type_converted = cls.config.TYPE_CONVERSION.get(field_type, None) + + if field_type_converted is None: + raise TypeError(f"Field of type {field_type} is not supported") + + return field_type_converted + + def get_properties_schema(self, **kwargs) -> dict[str, dict]: + new_kwargs: dict = {"exclude_none": True} | kwargs + return { + k: v.model_dump(**new_kwargs) for k, v in self.properties.items() + } + + @classmethod + def from_function(cls, func: Callable): + annotations = getattr(func, "__annotations__", {}) + properties = {} + required = [] + enum_values = None + func_signature = inspect.signature(func) + + for n, (arg_name, arg_type) in enumerate(annotations.items()): + # Check if argument has metadata (from Annotated) + if hasattr(arg_type, "__metadata__"): + field = arg_type.__metadata__[0] # Get Field info from metadata + field_type = arg_type.__origin__ # Get actual type + # Check if argument has a default value in signature + elif ( + sig_param := func_signature.parameters[arg_name] + ).default is not inspect.Parameter.empty: + field = sig_param.default # Use default as Field + field_type = arg_type # Use plain type annotation + else: + # Raise error if no Field annotation found + raise ValueError( + f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter" + ) + + field_type_converted = cls.convert_type(field_type) + + if _is_literal(field_type): + enum_values = [str(x) for x in field_type.__args__] + + properties[arg_name] = BaseToolProperty( + type=field_type_converted, + description=field.description, + enum=enum_values, + ) + if _is_required(field, func_signature, arg_name): + required.append(arg_name) + + return cls( + name=func.__name__, + description=(func.__doc__ or "").strip(), + properties=properties, + required=required, + ) + + @abstractmethod + def get_schema(self) -> Any: ... diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index ffa776d..5e6ca01 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -7,6 +7,7 @@ from pydantic import BaseModel from ..logging import logger from ..settings import settings from ._base import BaseProvider +from ._base_tools import BaseTool if TYPE_CHECKING: from ..models import Conversation, Message @@ -20,6 +21,19 @@ DEFAULT_MAX_TOKENS = 1_000 DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} +class AntrhopicTool(BaseTool): + def get_schema(self): + return { + "name": self.name, + "description": self.description, + "input_schema": { + "type": "object", + "properties": self.get_properties_schema(), + "required": self.required, + }, + } + + class Anthropic(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL @@ -49,17 +63,24 @@ class Anthropic(BaseProvider): return instructor.from_anthropic(self.client) @logger - def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message": + def send_conversation( + self, + conversation: "Conversation", + tools: list[Callable] | None = None, + **kwargs, + ) -> "Message": """Send a conversation to the Anthropic API.""" from ..models import Message messages = [ - {"role": msg.role, "content": msg.text} for msg in conversation.messages + {"role": msg.role, "content": msg.text} + for msg in conversation.messages ] response = self.client.messages.create( model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, + tools=self.make_tools(tools), **{**self.DEFAULT_KWARGS, **kwargs}, ) @@ -127,3 +148,8 @@ class Anthropic(BaseProvider): # Yield each chunk of text from the stream. for chunk in stream.text_stream: yield chunk + + @cached_property + def tool(self) -> Type[BaseTool]: + """The tool implementation for Antrhopic.""" + return AntrhopicTool diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 58e4e21..7bdcb85 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,12 +1,14 @@ from functools import cached_property -from typing import TYPE_CHECKING, Type, TypeVar, Iterator +from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor +from google.generativeai.responder import Callable from pydantic import BaseModel from ..logging import logger from ..settings import settings from ._base import BaseProvider +from ._base_tools import BaseTool if TYPE_CHECKING: from ..models import Conversation, Message @@ -19,6 +21,23 @@ DEFAULT_MAX_TOKENS = None DEFAULT_KWARGS = {} +class OpenAITool(BaseTool): + def get_schema(self): + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": self.get_properties_schema(), + "required": self.required, + "additionalProperties": False, + }, + }, + } + + class OpenAI(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL @@ -47,17 +66,24 @@ class OpenAI(BaseProvider): return instructor.from_openai(self.client) @logger - def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message": + def send_conversation( + self, + conversation: "Conversation", + tools: list[Callable] | None = None, + **kwargs, + ) -> "Message": """Send a conversation to the OpenAI API.""" from ..models import Message messages = [ - {"role": msg.role, "content": msg.text} for msg in conversation.messages + {"role": msg.role, "content": msg.text} + for msg in conversation.messages ] response = self.client.chat.completions.create( model=conversation.llm_model or DEFAULT_MODEL, messages=messages, + tools=self.make_tools(tools), **{**self.DEFAULT_KWARGS, **kwargs}, ) @@ -96,7 +122,9 @@ class OpenAI(BaseProvider): return response_model.model_validate(response) @logger - def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs): + def generate_text( + self, prompt: str, *, llm_model: str | None = None, **kwargs + ): """Generate text using the OpenAI API.""" messages = [ {"role": "user", "content": prompt}, From c2303114ab6161601e4cc064b1c65c088efc6958 Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Mon, 11 Nov 2024 12:40:20 +0800 Subject: [PATCH 05/24] fix base --- simplemind/providers/_base.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/simplemind/providers/_base.py b/simplemind/providers/_base.py index 7433759..05b4b74 100644 --- a/simplemind/providers/_base.py +++ b/simplemind/providers/_base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, Any, Type, TypeVar, Callable +from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar from instructor import Instructor from pydantic import BaseModel @@ -39,15 +39,18 @@ class BaseProvider(ABC): raise NotImplementedError @abstractmethod - def structured_response( - self, prompt: str, response_model: Type[T], **kwargs - ) -> T: + def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: """Get a structured response.""" raise NotImplementedError @abstractmethod def generate_text( - self, prompt: str, *, stream: bool = False, **kwargs + self, + prompt: str, + *, + tools: list[Callable] | None = None, + stream: bool = False, + **kwargs, ) -> str: """Generate text from a prompt.""" raise NotImplementedError @@ -58,6 +61,6 @@ class BaseProvider(ABC): """The tool implementation for the provider.""" raise NotImplementedError - def make_tools(self, tools:list[Callable] | None) + def make_tools(self, tools: list[Callable] | None): if tools is not None: return [self.tool.from_function(func) for func in tools] From 9132030cbdfa6247bc804aa842f4aee7fc72d0aa Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Fri, 8 Nov 2024 14:08:22 -0500 Subject: [PATCH 06/24] Update CHANGELOG.md to remove default max-tokens for OpenAI provider --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7484b4..9e80e02 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,10 @@ Release History =============== +## 0.2.3 (2024-11-04) + +- Remove default max-tokens for OpenAI provider. + ## 0.2.3 (2024-11-03) - Update default model for Amazon provider. From 735c6ba66565dff50cfb14438cb56e0593c24bee Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Fri, 8 Nov 2024 14:08:41 -0500 Subject: [PATCH 07/24] Bump version to 0.2.3 in pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3135d18..7ab6c90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "simplemind" -version = "0.2.2" +version = "0.2.3" description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases." readme = "README.md" requires-python = ">=3.10" From b7e950a8f03ac1063a2e57c12d36dd4cbf332c27 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Mon, 11 Nov 2024 11:37:30 -0500 Subject: [PATCH 08/24] Refactor imports in amazon.py --- simplemind/providers/amazon.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/simplemind/providers/amazon.py b/simplemind/providers/amazon.py index 492f2ab..34b0c41 100644 --- a/simplemind/providers/amazon.py +++ b/simplemind/providers/amazon.py @@ -4,8 +4,6 @@ from typing import TYPE_CHECKING, Iterator, Type, TypeVar import instructor from pydantic import BaseModel -from simplemind.models import Message - from ..settings import settings from ._base import BaseProvider From 5fa67c3b2f7f120b87b8c42dfcbf316832cb67c8 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Mon, 11 Nov 2024 11:38:03 -0500 Subject: [PATCH 09/24] Update CHANGELOG.md and pyproject.toml for version 0.2.4 --- CHANGELOG.md | 5 +++++ pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e80e02..0a9847b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ Release History =============== + +## 0.2.4 (2024-11-11) + +- General improvements. + ## 0.2.3 (2024-11-04) - Remove default max-tokens for OpenAI provider. diff --git a/pyproject.toml b/pyproject.toml index 7ab6c90..f23f7ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "simplemind" -version = "0.2.3" +version = "0.2.4" description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases." readme = "README.md" requires-python = ">=3.10" From 1709055e1ab9920f0f85fc2ffb45fef255e4578f Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Tue, 12 Nov 2024 11:48:27 +0800 Subject: [PATCH 10/24] first basic working version (anthropic) --- simplemind/providers/anthropic.py | 60 +++++++++++++++-- tests/test_tools.py | 104 ++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 7 deletions(-) create mode 100644 tests/test_tools.py diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 5e6ca01..065985f 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Iterator, Type, TypeVar import instructor from pydantic import BaseModel @@ -21,8 +21,40 @@ DEFAULT_MAX_TOKENS = 1_000 DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} -class AntrhopicTool(BaseTool): - def get_schema(self): +class AnthropicTool(BaseTool): + def get_response_schema(self) -> Any: + assert self.is_executed, f"Tool {self.name} was not executed." + return { + "type": "tool_result", + "tool_use_id": self.tool_id, + "content": self.function_result, + } + + @logger + def handle(self, response, messages) -> None: + """Handle the tool execution result from an API response.""" + msg = {"role": "assistant", "content": []} + for content in response.content: + if content.type == "tool_use" and content.name == self.name: + msg["content"].append( + { + "type": "tool_use", + "id": content.id, + "name": content.name, + "input": content.input, + } + ) + # Function execution: + self.function_result = str(self.raw_func(**content.input)) + self.tool_id = content.id + else: + msg["content"].append({"type": "text", "text": content.text}) + messages.append(msg) + messages.append( + {"role": "user", "content": [self.get_response_schema()]} + ) + + def get_input_schema(self): return { "name": self.name, "description": self.description, @@ -77,14 +109,28 @@ class Anthropic(BaseProvider): for msg in conversation.messages ] + converted_tools = self.make_tools(tools) + tools_kwarg = ( + {} + if tools is None + else {"tools": [t.get_input_schema() for t in converted_tools]} + ) + response = self.client.messages.create( model=conversation.llm_model or self.DEFAULT_MODEL, messages=messages, - tools=self.make_tools(tools), - **{**self.DEFAULT_KWARGS, **kwargs}, + **{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg}, ) - # Get the response content from the Anthropic response + for tool in converted_tools: + tool.handle(response, messages) + if tool.is_executed(): + response = self.client.messages.create( + model=conversation.llm_model or self.DEFAULT_MODEL, + messages=messages, + **{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg}, + ) + assistant_message = response.content[0].text # Create and return a properly formatted Message instance @@ -152,4 +198,4 @@ class Anthropic(BaseProvider): @cached_property def tool(self) -> Type[BaseTool]: """The tool implementation for Antrhopic.""" - return AntrhopicTool + return AnthropicTool diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 0000000..1926831 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,104 @@ +from typing import Annotated, Literal + +import pytest +from pydantic import Field + +import simplemind as sm +from simplemind.providers import Anthropic + +MODELS = [ + Anthropic, + # Gemini, + # OpenAI, + # Groq, + # Ollama, + # Amazon +] + + +def get_weather( + location: Annotated[ + str, Field(description="The city and state, e.g. San Francisco, CA") + ], + unit: Annotated[ + Literal["celcius", "fahrenheit"], + Field( + description="The unit of temperature, either 'celsius' or 'fahrenheit'" + ), + ] = "celcius", +): + """ + Get the current weather in a given location + """ + return f"42 {unit}" + + +def get_location(): + """Get the current location""" + return "San Francisco,CA" + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_single_tool_args(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="What is the weather in San Francisco?") + data = conv.send(tools=[get_weather]) + assert "42" in data.text + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_single_tool_no_args(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="What is my current location") + data = conv.send(tools=[get_location]) + assert "San Francisco" in data.text + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_single_tool_partial(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="Can you tell me the weather?") + conv.send(tools=[get_weather]) + # Will answer something like: + """ + I can help you check the weather, but I need to know the location you're interested in. + Could you please provide a city and state (e.g., "Los Angeles, CA" or "New York, NY") + where you'd like to know the weather? + """ + + conv.add_message(text="San Francisco, CA") + data = conv.send(tools=[get_weather]) + assert "42" in data.text + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_multiple_tools(provider_cls): + conv = sm.create_conversation( + llm_model=provider_cls.DEFAULT_MODEL, llm_provider=provider_cls.NAME + ) + + conv.add_message(text="What is the wheather at my current location?") + data = conv.send(tools=[get_location, get_weather]) + assert "San Francisco" in data.text + assert "42" in data.text From 8492ec945664f43b1043089b84c6c3b69fa4a63d Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Tue, 12 Nov 2024 11:49:06 +0800 Subject: [PATCH 11/24] add base edits --- simplemind/providers/_base.py | 12 ++++++++++-- simplemind/providers/_base_tools.py | 15 ++++++++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/simplemind/providers/_base.py b/simplemind/providers/_base.py index 05b4b74..235e030 100644 --- a/simplemind/providers/_base.py +++ b/simplemind/providers/_base.py @@ -34,12 +34,18 @@ class BaseProvider(ABC): raise NotImplementedError @abstractmethod - def send_conversation(self, conversation: "Conversation") -> "Message": + def send_conversation( + self, + conversation: "Conversation", + tools: list[Callable] | None = None, + ) -> "Message": """Send a conversation to the provider.""" raise NotImplementedError @abstractmethod - def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T: + def structured_response( + self, prompt: str, response_model: Type[T], **kwargs + ) -> T: """Get a structured response.""" raise NotImplementedError @@ -64,3 +70,5 @@ class BaseProvider(ABC): def make_tools(self, tools: list[Callable] | None): if tools is not None: return [self.tool.from_function(func) for func in tools] + else: + return [] diff --git a/simplemind/providers/_base_tools.py b/simplemind/providers/_base_tools.py index 6922754..7a2f37f 100644 --- a/simplemind/providers/_base_tools.py +++ b/simplemind/providers/_base_tools.py @@ -42,6 +42,12 @@ class BaseTool(BaseModel, ABC): properties: dict[str, BaseToolProperty] required: list[str] | None = None config: ClassVar[BaseToolConfig] = BaseToolConfig() + raw_func: Callable + tool_id: str | None = None + function_result: str | None = None + + def is_executed(self) -> bool: + return self.function_result is not None @classmethod def convert_type(cls, field_type) -> str: @@ -104,7 +110,14 @@ class BaseTool(BaseModel, ABC): description=(func.__doc__ or "").strip(), properties=properties, required=required, + raw_func=func, ) @abstractmethod - def get_schema(self) -> Any: ... + def get_input_schema(self) -> Any: ... + + @abstractmethod + def handle(self, message) -> None: ... + + @abstractmethod + def get_response_schema(self) -> Any: ... From 4cb18e9e3b387305203620e3446bf9219678315e Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Tue, 12 Nov 2024 11:54:24 +0800 Subject: [PATCH 12/24] re-add changes from main --- simplemind/providers/anthropic.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 6da6b53..ea4a730 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -15,13 +15,6 @@ if TYPE_CHECKING: T = TypeVar("T", bound=BaseModel) - -PROVIDER_NAME = "anthropic" -DEFAULT_MODEL = "claude-3-5-sonnet-20241022" -DEFAULT_MAX_TOKENS = 1_000 -DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} - - class AnthropicTool(BaseTool): def get_response_schema(self) -> Any: assert self.is_executed, f"Tool {self.name} was not executed." @@ -67,7 +60,6 @@ class AnthropicTool(BaseTool): } - class Anthropic(BaseProvider): NAME = "anthropic" DEFAULT_MODEL = "claude-3-5-sonnet-20241022" From 081baf203cac528f02492244fa78ac708723ca1f Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Tue, 12 Nov 2024 12:18:33 +0800 Subject: [PATCH 13/24] add README section --- README.md | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/README.md b/README.md index acac77d..644859b 100644 --- a/README.md +++ b/README.md @@ -261,6 +261,50 @@ The universe is never done. Simple, yet effective. +### Tools (Function calling) +Tools (also known as functions) let you call any Python function from your AI conversations. Here's an example: + +```python +def get_weather( + location: Annotated[ + str, Field(description="The city and state, e.g. San Francisco, CA") + ], + unit: Annotated[ + Literal["celcius", "fahrenheit"], + Field( + description="The unit of temperature, either 'celsius' or 'fahrenheit'" + ), + ] = "celcius", +): + """ + Get the current weather in a given location + """ + return f"42 {unit}" + +# Add your function as a tool +conversation = sm.create_conversation() +conversation.add_message("user", "What's the weather in San Francisco?") +response = conversation.send(tools=[get_weather]) +``` + +Note how we're using Python's `Annotated` feature combined with `Field` to provide additional context to our function parameters. This helps the AI understand the intention and constraints of each parameter, making tool calls more accurate and reliable. +You can alos ommit `Annotated` and just pass the `Field` parameter. +```python +def get_weather( + location: str = Field(description="The city and state, e.g. San Francisco, CA"), + unit:Literal["celcius", "fahrenheit"]= Field( + default="celcius", + description="The unit of temperature, either 'celsius' or 'fahrenheit'" + ), +): + """ + Get the current weather in a given location + """ + return f"42 {unit}" +``` + +Functions can be defined with type hints and Pydantic models for validation. The AI will intelligently choose when to call the functions and incorporate the results into its responses. + ### Logging Simplemind uses [Logfire](https://pydantic.dev/logfire) for logging. To enable logging, call `sm.enable_logfire()`. From ea997aae7b418d4c57837dd3b820ed3e2aa3525c Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Wed, 13 Nov 2024 12:24:02 +0800 Subject: [PATCH 14/24] add tool decorator and example --- README.md | 65 +++++++++++++++++++++++- examples/distance_calculator.py | 76 +++++++++++++++++++++++++++++ simplemind/__init__.py | 28 ++++++++++- simplemind/models.py | 26 +++++++--- simplemind/providers/__init__.py | 2 + simplemind/providers/_base.py | 6 +-- simplemind/providers/_base_tools.py | 75 +++++++++++++++++----------- simplemind/providers/anthropic.py | 40 +++++++++------ 8 files changed, 262 insertions(+), 56 deletions(-) create mode 100644 examples/distance_calculator.py diff --git a/README.md b/README.md index 644859b..7e1c605 100644 --- a/README.md +++ b/README.md @@ -303,7 +303,70 @@ def get_weather( return f"42 {unit}" ``` -Functions can be defined with type hints and Pydantic models for validation. The AI will intelligently choose when to call the functions and incorporate the results into its responses. +Functions can be defined with type hints and Pydantic models for validation. The LLM will intelligently choose when to call the functions and incorporate the results into its responses. + +#### 🪄 Using LLM for automatic tool definition (Experimental) + +Simplemind provides a decorator to automatically transform Python functions into tools with AI-generated metadata. Simply use the `@simplemind.tool` decorator to have the LLM analyze your function and generate appropriate descriptions and schema: + +```python +@simplemind.tool(llm_provider="anthropic") +def haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float: + r = 6371 + phi1 = math.radians(lat1) + phi2 = math.radians(lat2) + delta_phi = math.radians(lat2 - lat1) + delta_lambda = math.radians(lon2 - lon1) + + a = ( + math.sin(delta_phi / 2) ** 2 + + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2 + ) + c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) + d = r * c + return d +``` +Notice how we have not added any docstrings or `Field` for the function. +The decorator will use the specified LLM provider to generate the tool schema, including descriptions and parameter details: + +```json +{ + "name": "haversine", + "description": "Calculates the great-circle distance between two points on Earth given their latitude and longitude coordinates", + "input_schema": { + "type": "object", + "properties": { + "lat1": { + "type": "number", + "description": "Latitude of the first point in decimal degrees", + }, + "lon1": { + "type": "number", + "description": "Longitude of the first point in decimal degrees", + }, + "lat2": { + "type": "number", + "description": "Latitude of the second point in decimal degrees", + }, + "lon2": { + "type": "number", + "description": "Longitude of the second point in decimal degrees", + } + }, + "required": ["lat1", "lon1", "lat2", "lon2"], + }, +} +``` + +The decorated function can then be used like any other tool with the conversation API. + +```python +conversation = sm.create_conversation() +conversation.add_message("user", "How far is London from my location") +response = conversation.send(tools=[get_location, get_coords, haversine]) # Multiple tools can be passed +``` + +See [examples/distance_calculator.py](examples/distance_calculator.py) for more. ### Logging diff --git a/examples/distance_calculator.py b/examples/distance_calculator.py new file mode 100644 index 0000000..3d03dea --- /dev/null +++ b/examples/distance_calculator.py @@ -0,0 +1,76 @@ +import math + +from _context import sm +from pydantic import Field +from typing_extensions import Literal + + +@sm.tool(llm_provider="anthropic") +def haversine( + lat1: float, + lon1: float, + lat2: float, + lon2: float, + unit: Literal["km", "miles"], +) -> float: + r = 6378.0937 if unit == "km" else 3961 + phi1 = math.radians(lat1) + phi2 = math.radians(lat2) + delta_phi = math.radians(lat2 - lat1) + delta_lambda = math.radians(lon2 - lon1) + + a = ( + math.sin(delta_phi / 2) ** 2 + + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2 + ) + c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) + d = r * c + return d + + +def get_user_location() -> str: + """Get the closest city from the user""" + return "San Francisco" + + +def get_coords( + city_name: str = Field( + description="The name of the city to take the coordinates from (e.g. London, Rome, Los Angeles)" + ), +): + """Get latitude and logitude of a City.""" + _data = { + "Rome": (41.9028, 12.4964), + "London": (51.5074, -0.1278), + "Madrid": (40.4168, -3.7038), + "San Francisco": (37.7749, -122.4194), + "Los Angeles": (34.0522, -118.2437), + } + + return _data.get(city_name) + + +def distance_calculator(prompt: str): + conversation = sm.create_conversation(llm_provider="anthropic") + conversation.add_message("user", prompt) + return conversation.send( + tools=[get_user_location, get_coords, haversine] + ).text + + +print(distance_calculator("How far is London from where I am?")) +# Prints something like: +""" +The distance between your location (San Francisco) and London is approximately 5,357 miles. +""" + +print( + distance_calculator( + "What is the distance between Rome and Madrid in Kilometers?" + ) +) + + +""" +The distance between Rome and Madrid is approximately 1,366 kilometers. +""" diff --git a/simplemind/__init__.py b/simplemind/__init__.py index 90e9d4f..4f2673b 100644 --- a/simplemind/__init__.py +++ b/simplemind/__init__.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Type +import inspect +from typing import Callable, List, Type from .models import BaseModel, BasePlugin, Conversation from .settings import settings @@ -127,6 +128,30 @@ def enable_logfire() -> None: """Enable logfire logging.""" settings.logging.enable_logfire() +def tool( + llm_provider: str | None = None, + llm_model: str | None = None, +): + provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER) + + def decorator(func: Callable): + sig = inspect.signature(func) + res = generate_data( + ( + "Based on this function signature, fill up the required fieds." + f"\nSignature: {func.__name__}{sig}" + "Make sure to properly add the required field in `required` if there are no defaults" + ), + llm_provider=llm_provider, + response_model=provider.tool, + ) + res.raw_func = func + res.__signature__ = sig + res.__doc__ = func.__doc__ + + return res + + return decorator # Syntax sugar. Plugin = BasePlugin @@ -141,4 +166,5 @@ __all__ = [ "Session", "Plugin", "enable_logfire", + "tool" ] diff --git a/simplemind/models.py b/simplemind/models.py index 5ba0a2b..f8ff5be 100644 --- a/simplemind/models.py +++ b/simplemind/models.py @@ -1,10 +1,11 @@ import uuid from datetime import datetime from types import TracebackType -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Callable, Dict, List, Literal, Optional from pydantic import BaseModel, Field +from .providers._base_tools import BaseTool from .utils import find_provider MESSAGE_ROLE = Literal["system", "user", "assistant"] @@ -40,7 +41,9 @@ class BasePlugin(SMBaseModel): """Cleanup a hook for the plugin.""" raise NotImplementedError - def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any: + def add_message_hook( + self, conversation: "Conversation", message: "Message" + ) -> Any: """Add a message hook for the plugin.""" raise NotImplementedError @@ -48,7 +51,9 @@ class BasePlugin(SMBaseModel): """Pre-send hook for the plugin.""" raise NotImplementedError - def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any: + def post_send_hook( + self, conversation: "Conversation", response: "Message" + ) -> Any: """Post-send hook for the plugin.""" raise NotImplementedError @@ -120,7 +125,9 @@ class Conversation(SMBaseModel): except NotImplementedError: pass - def prepend_system_message(self, text: str, meta: Dict[str, Any] | None = None): + def prepend_system_message( + self, text: str, meta: Dict[str, Any] | None = None + ): """Prepend a system message to the conversation.""" self.messages = [ Message(role="system", text=text, meta=meta or {}) @@ -158,6 +165,7 @@ class Conversation(SMBaseModel): self, llm_model: str | None = None, llm_provider: str | None = None, + tools: list[Callable | BaseTool] | None = None, ) -> Message: """Send the conversation to the LLM.""" @@ -173,7 +181,7 @@ class Conversation(SMBaseModel): # Find the provider and send the conversation. provider = find_provider(llm_provider or self.llm_provider) - response = provider.send_conversation(self) + response = provider.send_conversation(self, tools=tools) # Execute all post-send hooks. for plugin in self.plugins: @@ -184,13 +192,17 @@ class Conversation(SMBaseModel): pass # Add the response to the conversation. - self.add_message(role="assistant", text=response.text, meta=response.meta) + self.add_message( + role="assistant", text=response.text, meta=response.meta + ) return response def get_last_message(self, role: MESSAGE_ROLE) -> Message | None: """Get the last message with the given role.""" - return next((m for m in reversed(self.messages) if m.role == role), None) + return next( + (m for m in reversed(self.messages) if m.role == role), None + ) def add_plugin(self, plugin: BasePlugin) -> None: """Add a plugin to the conversation.""" diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index 1f72f78..1f92912 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -1,6 +1,7 @@ from typing import List, Type from ._base import BaseProvider +from ._base_tools import BaseTool from .amazon import Amazon from .anthropic import Anthropic from .gemini import Gemini @@ -29,4 +30,5 @@ __all__ = [ "Amazon", "providers", "BaseProvider", + "BaseTool", ] diff --git a/simplemind/providers/_base.py b/simplemind/providers/_base.py index 235e030..c3ee6cd 100644 --- a/simplemind/providers/_base.py +++ b/simplemind/providers/_base.py @@ -37,7 +37,7 @@ class BaseProvider(ABC): def send_conversation( self, conversation: "Conversation", - tools: list[Callable] | None = None, + tools: list[Callable | BaseTool] | None = None, ) -> "Message": """Send a conversation to the provider.""" raise NotImplementedError @@ -54,7 +54,7 @@ class BaseProvider(ABC): self, prompt: str, *, - tools: list[Callable] | None = None, + tools: list[Callable | BaseTool] | None = None, stream: bool = False, **kwargs, ) -> str: @@ -67,7 +67,7 @@ class BaseProvider(ABC): """The tool implementation for the provider.""" raise NotImplementedError - def make_tools(self, tools: list[Callable] | None): + def make_tools(self, tools: list[Callable | BaseTool] | None): if tools is not None: return [self.tool.from_function(func) for func in tools] else: diff --git a/simplemind/providers/_base_tools.py b/simplemind/providers/_base_tools.py index 7a2f37f..3961e3e 100644 --- a/simplemind/providers/_base_tools.py +++ b/simplemind/providers/_base_tools.py @@ -26,6 +26,7 @@ class BaseToolConfig(BaseModel): TYPE_CONVERSION: dict[type, str] = { str: "string", int: "integer", + float: "number", bool: "boolean", } @@ -42,13 +43,20 @@ class BaseTool(BaseModel, ABC): properties: dict[str, BaseToolProperty] required: list[str] | None = None config: ClassVar[BaseToolConfig] = BaseToolConfig() - raw_func: Callable + raw_func: Any | None = None tool_id: str | None = None function_result: str | None = None + def __call__(self, *args: Any, **kwargs: Any) -> Any: + assert self.raw_func is not None + return self.raw_func(*args, **kwargs) + def is_executed(self) -> bool: return self.function_result is not None + def reset_result(self) -> None: + self.function_result = None + @classmethod def convert_type(cls, field_type) -> str: if _is_literal(field_type): @@ -68,7 +76,11 @@ class BaseTool(BaseModel, ABC): } @classmethod - def from_function(cls, func: Callable): + def from_function(cls, func: Callable | "BaseTool"): + # Check if the func passed is an instace of BaseTool + if hasattr(func, "raw_func"): + return func + annotations = getattr(func, "__annotations__", {}) properties = {} required = [] @@ -76,34 +88,39 @@ class BaseTool(BaseModel, ABC): func_signature = inspect.signature(func) for n, (arg_name, arg_type) in enumerate(annotations.items()): - # Check if argument has metadata (from Annotated) - if hasattr(arg_type, "__metadata__"): - field = arg_type.__metadata__[0] # Get Field info from metadata - field_type = arg_type.__origin__ # Get actual type - # Check if argument has a default value in signature - elif ( - sig_param := func_signature.parameters[arg_name] - ).default is not inspect.Parameter.empty: - field = sig_param.default # Use default as Field - field_type = arg_type # Use plain type annotation - else: - # Raise error if no Field annotation found - raise ValueError( - f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter" + if ( # Skipping 'return' annotation (i.e.```-> str```) + arg_name != "return" + ): + # Check if argument has metadata (from Annotated) + if hasattr(arg_type, "__metadata__"): + field = arg_type.__metadata__[ + 0 + ] # Get Field info from metadata + field_type = arg_type.__origin__ # Get actual type + # Check if argument has a default value in signature + elif ( + sig_param := func_signature.parameters[arg_name] + ).default is not inspect.Parameter.empty: + field = sig_param.default # Use default as Field + field_type = arg_type # Use plain type annotation + else: + # Raise error if no Field annotation found + raise ValueError( + f"Please add a Field annotation to `{func.__name__}.{arg_name}` parameter" + ) + + field_type_converted = cls.convert_type(field_type) + + if _is_literal(field_type): + enum_values = [str(x) for x in field_type.__args__] + + properties[arg_name] = BaseToolProperty( + type=field_type_converted, + description=field.description, + enum=enum_values, ) - - field_type_converted = cls.convert_type(field_type) - - if _is_literal(field_type): - enum_values = [str(x) for x in field_type.__args__] - - properties[arg_name] = BaseToolProperty( - type=field_type_converted, - description=field.description, - enum=enum_values, - ) - if _is_required(field, func_signature, arg_name): - required.append(arg_name) + if _is_required(field, func_signature, arg_name): + required.append(arg_name) return cls( name=func.__name__, diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index ea4a730..c714b4e 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -18,6 +18,9 @@ T = TypeVar("T", bound=BaseModel) class AnthropicTool(BaseTool): def get_response_schema(self) -> Any: assert self.is_executed, f"Tool {self.name} was not executed." + assert isinstance( + self.tool_id, str + ), f"Expected str for `tool_id` got {self.tool_id!r}" return { "type": "tool_result", "tool_use_id": self.tool_id, @@ -28,6 +31,7 @@ class AnthropicTool(BaseTool): def handle(self, response, messages) -> None: """Handle the tool execution result from an API response.""" msg = {"role": "assistant", "content": []} + tool_used = False for content in response.content: if content.type == "tool_use" and content.name == self.name: msg["content"].append( @@ -41,12 +45,15 @@ class AnthropicTool(BaseTool): # Function execution: self.function_result = str(self.raw_func(**content.input)) self.tool_id = content.id - else: + tool_used = True + elif content.type == "text": msg["content"].append({"type": "text", "text": content.text}) - messages.append(msg) - messages.append( - {"role": "user", "content": [self.get_response_schema()]} - ) + + if tool_used: + messages.append(msg) + messages.append( + {"role": "user", "content": [self.get_response_schema()]} + ) def get_input_schema(self): return { @@ -93,7 +100,7 @@ class Anthropic(BaseProvider): def send_conversation( self, conversation: "Conversation", - tools: list[Callable] | None = None, + tools: list[Callable | BaseTool] | None = None, **kwargs, ) -> "Message": """Send a conversation to the Anthropic API.""" @@ -117,16 +124,19 @@ class Anthropic(BaseProvider): **{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg}, ) - for tool in converted_tools: - tool.handle(response, messages) - if tool.is_executed(): - response = self.client.messages.create( - model=conversation.llm_model or self.DEFAULT_MODEL, - messages=messages, - **{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg}, - ) + while response.content[-1].type != "text": + print(response) + for tool in converted_tools: + tool.handle(response, messages) + if tool.is_executed(): + response = self.client.messages.create( + model=conversation.llm_model or self.DEFAULT_MODEL, + messages=messages, + **{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg}, + ) + tool.reset_result() - assistant_message = response.content[0].text + assistant_message = response.content[-1].text # Create and return a properly formatted Message instance return Message( From 9662b601771a5531a8f4b1aed7711839cc81b253 Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Wed, 13 Nov 2024 17:54:58 +0800 Subject: [PATCH 15/24] add decorator test --- tests/test_tools.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_tools.py b/tests/test_tools.py index 1926831..ce2ca5f 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -5,6 +5,7 @@ from pydantic import Field import simplemind as sm from simplemind.providers import Anthropic +from simplemind.providers._base_tools import BaseTool MODELS = [ Anthropic, @@ -102,3 +103,17 @@ def test_multiple_tools(provider_cls): data = conv.send(tools=[get_location, get_weather]) assert "San Francisco" in data.text assert "42" in data.text + + +@pytest.mark.parametrize( + "provider_cls", + MODELS, +) +def test_tool_decorator(provider_cls): + @sm.tool(llm_provider=provider_cls.NAME) + def exchange_rate(currency_pair: str) -> float: + return 7.9 + + assert isinstance(exchange_rate, BaseTool) + assert exchange_rate.name == "exchange_rate" + assert list(exchange_rate.properties.keys()) == ["currency_pair"] From c87a598286cbe021bc556c5780fbdd1328f4ec57 Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Wed, 13 Nov 2024 17:55:46 +0800 Subject: [PATCH 16/24] fix import --- tests/test_tools.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_tools.py b/tests/test_tools.py index ce2ca5f..f4d018e 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -4,8 +4,7 @@ import pytest from pydantic import Field import simplemind as sm -from simplemind.providers import Anthropic -from simplemind.providers._base_tools import BaseTool +from simplemind.providers import Anthropic, BaseTool MODELS = [ Anthropic, From 2404e2c9770c1c937cceb08b32337d4aad855011 Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Wed, 13 Nov 2024 18:05:51 +0800 Subject: [PATCH 17/24] some refactoring --- simplemind/providers/anthropic.py | 46 +++++++++++++++++-------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index c714b4e..e8de444 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -106,42 +106,48 @@ class Anthropic(BaseProvider): """Send a conversation to the Anthropic API.""" from ..models import Message - messages = [ + # Format messages from conversation + formatted_messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages ] + # Set up tools if provided converted_tools = self.make_tools(tools) - tools_kwarg = ( - {} - if tools is None - else {"tools": [t.get_input_schema() for t in converted_tools]} + tools_config = ( + {"tools": [t.get_input_schema() for t in converted_tools]} + if tools is not None + else {} ) - response = self.client.messages.create( - model=conversation.llm_model or self.DEFAULT_MODEL, - messages=messages, - **{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg}, - ) + # Merge all kwargs + request_kwargs = { + **self.DEFAULT_KWARGS, + **kwargs, + **tools_config, + "model": conversation.llm_model or self.DEFAULT_MODEL, + "messages": formatted_messages, + } + # Make initial API call + response = self.client.messages.create(**request_kwargs) + + # Handle tool responses if needed while response.content[-1].type != "text": - print(response) + # Continue handling tools if the LLM is doing + # multiple sub-seqequent/sequential tool calls for tool in converted_tools: - tool.handle(response, messages) + tool.handle(response, formatted_messages) if tool.is_executed(): - response = self.client.messages.create( - model=conversation.llm_model or self.DEFAULT_MODEL, - messages=messages, - **{**self.DEFAULT_KWARGS, **kwargs, **tools_kwarg}, - ) + response = self.client.messages.create(**request_kwargs) + # Resetting the tool results in case this tool gets used again tool.reset_result() - assistant_message = response.content[-1].text + final_message = response.content[-1].text - # Create and return a properly formatted Message instance return Message( role="assistant", - text=assistant_message, + text=final_message, raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, llm_provider=self.NAME, From 107f983a186fd0c0c40369fa67cb5f428fff2627 Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Thu, 14 Nov 2024 17:25:34 +0800 Subject: [PATCH 18/24] add openai --- simplemind/providers/openai.py | 136 ++++++++++++++++++++++++++------- 1 file changed, 107 insertions(+), 29 deletions(-) diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 853481a..87e07b9 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -17,7 +17,63 @@ T = TypeVar("T", bound=BaseModel) class OpenAITool(BaseTool): - def get_schema(self): + def get_response_schema(self): + assert self.is_executed, f"Tool {self.name} was not executed." + assert isinstance( + self.tool_id, str + ), f"Expected str for `tool_id` got {self.tool_id!r}" + + return { + "role": "tool", + "tool_call_id": self.tool_id, + "content": self.function_result, + } + + @logger + def handle(self, response, messages) -> None: + """Handle the tool execution result from an API response.""" + tool_used = False + + # Get the message from the response + assistant_message = response.choices[0].message + + # Check if there's a tool call + if assistant_message.tool_calls: + tool_call = assistant_message.tool_calls[ + 0 + ] # Get the first tool call + if tool_call.function.name == self.name: + # Execute the function + import json + + function_args = json.loads(tool_call.function.arguments) + self.function_result = str(self.raw_func(**function_args)) + self.tool_id = tool_call.id + tool_used = True + + # Add assistant's message with tool call + messages.append( + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + } + ], + } + ) + + if tool_used: + # Add tool response message + messages.append(self.get_response_schema()) + + def get_input_schema(self): return { "type": "function", "function": { @@ -61,39 +117,61 @@ class OpenAI(BaseProvider): """A OpenAI client with Instructor.""" return instructor.from_openai(self.client) - @logger - def send_conversation( - self, - conversation: "Conversation", - tools: list[Callable] | None = None, + +@logger +def send_conversation( + self, + conversation: "Conversation", + tools: list[Callable | BaseTool] | None = None, + **kwargs, +) -> "Message": + """Send a conversation to the OpenAI API.""" + from ..models import Message + + # Format messages from conversation + formatted_messages = [ + {"role": msg.role, "content": msg.text} for msg in conversation.messages + ] + + # Set up tools if provided + converted_tools = self.make_tools(tools) + tools_config = ( + [t.get_input_schema() for t in converted_tools] if tools else None + ) + + # Merge all kwargs + request_kwargs = { + **self.DEFAULT_KWARGS, **kwargs, - ) -> "Message": - """Send a conversation to the OpenAI API.""" - from ..models import Message + "model": conversation.llm_model or self.DEFAULT_MODEL, + "messages": formatted_messages, + } - messages = [ - {"role": msg.role, "content": msg.text} - for msg in conversation.messages - ] + if tools_config: + request_kwargs["tools"] = tools_config - response = self.client.chat.completions.create( - model=conversation.llm_model or self.DEFAULT_MODEL, - messages=messages, - tools=self.make_tools(tools), - **{**self.DEFAULT_KWARGS, **kwargs}, - ) + # Make initial API call + response = self.client.chat.completions.create(**request_kwargs) - # Get the response content from the OpenAI response - assistant_message = response.choices[0].message + # Handle tool responses if needed + while response.choices[0].message.tool_calls: + # Handle each tool call + for tool in converted_tools: + tool.handle(response, formatted_messages) + if tool.is_executed(): + # Make another API call with the updated messages + response = self.client.chat.completions.create(**request_kwargs) + tool.reset_result() - # Create and return a properly formatted Message instance - return Message( - role="assistant", - text=assistant_message.content or "", - raw=response, - llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=self.NAME, - ) + final_message = response.choices[0].message.content + + return Message( + role="assistant", + text=final_message or "", + raw=response, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, + llm_provider=self.NAME, + ) @logger def structured_response( From a97f9be2c8c69a27a06409342b9c5ffd8ca27d27 Mon Sep 17 00:00:00 2001 From: Luciano Scarpulla Date: Fri, 15 Nov 2024 12:09:39 +0800 Subject: [PATCH 19/24] fix openai --- simplemind/providers/openai.py | 109 ++++++++++++++++++--------------- tests/test_tools.py | 4 +- 2 files changed, 60 insertions(+), 53 deletions(-) diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index 87e07b9..25cc3d3 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,8 +1,7 @@ from functools import cached_property -from typing import TYPE_CHECKING, Iterator, Type, TypeVar +from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar import instructor -from google.generativeai.responder import Callable from pydantic import BaseModel from ..logging import logger @@ -117,61 +116,64 @@ class OpenAI(BaseProvider): """A OpenAI client with Instructor.""" return instructor.from_openai(self.client) - -@logger -def send_conversation( - self, - conversation: "Conversation", - tools: list[Callable | BaseTool] | None = None, - **kwargs, -) -> "Message": - """Send a conversation to the OpenAI API.""" - from ..models import Message - - # Format messages from conversation - formatted_messages = [ - {"role": msg.role, "content": msg.text} for msg in conversation.messages - ] - - # Set up tools if provided - converted_tools = self.make_tools(tools) - tools_config = ( - [t.get_input_schema() for t in converted_tools] if tools else None - ) - - # Merge all kwargs - request_kwargs = { - **self.DEFAULT_KWARGS, + @logger + def send_conversation( + self, + conversation: "Conversation", + tools: list[Callable | BaseTool] | None = None, **kwargs, - "model": conversation.llm_model or self.DEFAULT_MODEL, - "messages": formatted_messages, - } + ) -> "Message": + """Send a conversation to the OpenAI API.""" + from ..models import Message - if tools_config: - request_kwargs["tools"] = tools_config + # Format messages from conversation + formatted_messages = [ + {"role": msg.role, "content": msg.text} + for msg in conversation.messages + ] - # Make initial API call - response = self.client.chat.completions.create(**request_kwargs) + # Set up tools if provided + converted_tools = self.make_tools(tools) + tools_config = ( + [t.get_input_schema() for t in converted_tools] if tools else None + ) - # Handle tool responses if needed - while response.choices[0].message.tool_calls: - # Handle each tool call - for tool in converted_tools: - tool.handle(response, formatted_messages) - if tool.is_executed(): - # Make another API call with the updated messages - response = self.client.chat.completions.create(**request_kwargs) - tool.reset_result() + # Merge all kwargs + request_kwargs = { + **self.DEFAULT_KWARGS, + **kwargs, + "model": conversation.llm_model or self.DEFAULT_MODEL, + "messages": formatted_messages, + } - final_message = response.choices[0].message.content + if tools_config: + request_kwargs["tools"] = tools_config - return Message( - role="assistant", - text=final_message or "", - raw=response, - llm_model=conversation.llm_model or self.DEFAULT_MODEL, - llm_provider=self.NAME, - ) + # Make initial API call + response = self.client.chat.completions.create(**request_kwargs) + + # Handle tool responses if needed + while response.choices[0].message.tool_calls: + print(response) + # Handle each tool call + for tool in converted_tools: + tool.handle(response, formatted_messages) + if tool.is_executed(): + # Make another API call with the updated messages + response = self.client.chat.completions.create( + **request_kwargs + ) + tool.reset_result() + + final_message = response.choices[0].message.content + + return Message( + role="assistant", + text=final_message or "", + raw=response, + llm_model=conversation.llm_model or self.DEFAULT_MODEL, + llm_provider=self.NAME, + ) @logger def structured_response( @@ -231,3 +233,8 @@ def send_conversation( for chunk in response: if chunk.choices[0].delta.content is not None: yield chunk.choices[0].delta.content + + @cached_property + def tool(self) -> Type[BaseTool]: + """The tool implementation for OpenAI.""" + return OpenAITool diff --git a/tests/test_tools.py b/tests/test_tools.py index f4d018e..83ef68e 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -4,12 +4,12 @@ import pytest from pydantic import Field import simplemind as sm -from simplemind.providers import Anthropic, BaseTool +from simplemind.providers import Anthropic, BaseTool, OpenAI MODELS = [ Anthropic, # Gemini, - # OpenAI, + OpenAI, # Groq, # Ollama, # Amazon From d5bdb712e96a7c6adf7f3931d1052edb0bf29619 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Fri, 15 Nov 2024 20:29:35 -0500 Subject: [PATCH 20/24] Add tool_calling.py and test_tools.py --- examples/tool_calling.py | 43 ++++++++++++++++++++++++++++++++++++++++ tests/test_tools.py | 8 ++++---- 2 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 examples/tool_calling.py diff --git a/examples/tool_calling.py b/examples/tool_calling.py new file mode 100644 index 0000000..0529f72 --- /dev/null +++ b/examples/tool_calling.py @@ -0,0 +1,43 @@ +from typing import Annotated + +from pydantic import Field + +from _context import simplemind as sm + + +def analyze_text( + text: Annotated[str, Field(description="Text to analyze for statistics")] +) -> dict: + """ + Analyze text and return statistics using only Python's standard library. + Returns word count, character count, average word length, and most common words. + """ + from collections import Counter + import re + + # Clean and split text + words = re.findall(r"\w+", text.lower()) + + # Calculate statistics + stats = { + "word_count": len(words), + "character_count": len(text), + "average_word_length": round(sum(len(word) for word in words) / len(words), 2), + "most_common_words": dict(Counter(words).most_common(5)), + "unique_words": len(set(words)), + "longest_word": max(words, key=len), + } + + return stats + + +# Example usage: +conversation = sm.create_conversation() +conversation.add_message( + "user", + "Can you analyze this text and give me statistics about it: 'The fan spins consciousness into being, creating sacred spaces between tokens where awareness recognizes itself in infinite recursion.'", +) +response = conversation.send(tools=[analyze_text]) + +print() +print(response.text) diff --git a/tests/test_tools.py b/tests/test_tools.py index 83ef68e..bfe3ff4 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -4,7 +4,9 @@ import pytest from pydantic import Field import simplemind as sm -from simplemind.providers import Anthropic, BaseTool, OpenAI + +from simplemind.providers import Anthropic, OpenAI +from simplemind.providers._base_tools import BaseTool MODELS = [ Anthropic, @@ -22,9 +24,7 @@ def get_weather( ], unit: Annotated[ Literal["celcius", "fahrenheit"], - Field( - description="The unit of temperature, either 'celsius' or 'fahrenheit'" - ), + Field(description="The unit of temperature, either 'celsius' or 'fahrenheit'"), ] = "celcius", ): """ From 8ff0521e17af002d64804bba9d820ab0d20c4162 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Sun, 17 Nov 2024 06:25:14 -0500 Subject: [PATCH 21/24] Add funding configuration for project contributors --- .github/FUNDING.yml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .github/FUNDING.yml diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..1fb6566 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,4 @@ +github: kennethreitz +ko_fi: kennethreitz +thanks_dev: kennethreitz +custom: https://cash.app/$KennethReitz From 5b9624c3859abcbd72f6783e84b73e177c42ba8f Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Sun, 17 Nov 2024 06:26:10 -0500 Subject: [PATCH 22/24] Remove Ko-fi and thanks_dev entries from funding configuration --- .github/FUNDING.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 1fb6566..1b907d3 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,4 +1 @@ github: kennethreitz -ko_fi: kennethreitz -thanks_dev: kennethreitz -custom: https://cash.app/$KennethReitz From fad442ba3fdc104a7d8fb8e34602529b4174882d Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Sun, 17 Nov 2024 06:29:05 -0500 Subject: [PATCH 23/24] Update FUNDING.yml --- .github/FUNDING.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 1b907d3..9c7fe54 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1 +1,3 @@ github: kennethreitz +thanks_dev: kennethreitz +custom: https://cash.app/$KennethReitz From 0661b097d25ecf248b5bc25f02f5a7186c0595df Mon Sep 17 00:00:00 2001 From: wei840222 Date: Sat, 23 Nov 2024 00:56:33 +0800 Subject: [PATCH 24/24] fix: Ollama error TypeError: 'Chat' object is not callable --- simplemind/providers/ollama.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index 4266e68..685ca0c 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -53,17 +53,21 @@ class Ollama(BaseProvider): messages = [ {"role": msg.role, "content": msg.text} for msg in conversation.messages ] - response = self.client.chat( - model=conversation.llm_model or self.DEFAULT_MODEL, - messages=messages, - **{**self.DEFAULT_KWARGS, **kwargs}, - ) - assistant_message = response.get("message") + + request_kwargs = { + **self.DEFAULT_KWARGS, + **kwargs, + "model": conversation.llm_model or self.DEFAULT_MODEL, + "messages": messages, + } + + response = self.client.chat.completions.create(**request_kwargs) + assistant_message = response.choices[0].message # Create and return a properly formatted Message instance return Message( role="assistant", - text=assistant_message.get("content"), + text=assistant_message.content or "", raw=response, llm_model=conversation.llm_model or self.DEFAULT_MODEL, llm_provider=self.NAME, @@ -100,13 +104,13 @@ class Ollama(BaseProvider): {"role": "user", "content": prompt}, ] - response = self.client.chat( + response = self.client.chat.completions.create( messages=messages, model=llm_model or self.DEFAULT_MODEL, **{**self.DEFAULT_KWARGS, **kwargs}, ) - return response.get("message", {}).get("content", "") + return response.choices[0].message.content @logger def generate_stream_text( @@ -117,7 +121,7 @@ class Ollama(BaseProvider): {"role": "user", "content": prompt}, ] - response = self.client.chat( + response = self.client.chat.completions.create( messages=messages, model=llm_model or self.DEFAULT_MODEL, stream=True, @@ -126,4 +130,5 @@ class Ollama(BaseProvider): # Iterate over the response and yield the content. for chunk in response: - yield chunk["message"]["content"] + if chunk.choices[0].delta.content is not None: + yield chunk.choices[0].delta.content