From 62e6a40389066889aa21181f38257f3dd00f2b65 Mon Sep 17 00:00:00 2001 From: Jason Liu Date: Wed, 31 Jan 2024 23:03:38 -0500 Subject: [PATCH] Implement Parallel Function Calls with `List[Union[T]]` (#378) --- .github/workflows/test.yml | 5 +- docs/api.md | 8 ++- docs/concepts/parallel.md | 62 +++++++++++++++++++++ examples/parallel/run.py | 33 +++++++++++ examples/patching/oai.py | 1 - examples/patching/pcalls.py | 86 ++++++++++++++++++++++++++++ instructor/dsl/parallel.py | 63 +++++++++++++++++++++ instructor/function_calls.py | 105 +---------------------------------- instructor/patch.py | 58 +++++++++++++++---- mkdocs.yml | 1 + tests/test_function_calls.py | 8 --- 11 files changed, 301 insertions(+), 129 deletions(-) create mode 100644 docs/concepts/parallel.md create mode 100644 examples/parallel/run.py create mode 100644 examples/patching/pcalls.py create mode 100644 instructor/dsl/parallel.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4e19041..c7063c0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,14 +1,11 @@ name: Test - on: push: - branches: + branches: - main pull_request_target: branches: - main - - jobs: release: runs-on: ubuntu-latest diff --git a/docs/api.md b/docs/api.md index dbf1323..01e2495 100644 --- a/docs/api.md +++ b/docs/api.md @@ -4,10 +4,12 @@ ::: instructor.dsl.validators -::: instructor.dsl.citation +::: instructor.dsl.iterable -::: instructor.dsl.multitask +::: instructor.dsl.partial + +::: instructor.dsl.parallel ::: instructor.dsl.maybe -::: instructor.function_calls \ No newline at end of file +::: instructor.function_calls diff --git a/docs/concepts/parallel.md b/docs/concepts/parallel.md new file mode 100644 index 0000000..bee433a --- /dev/null +++ b/docs/concepts/parallel.md @@ -0,0 +1,62 @@ +# Parallel Tools + +One of the latest capabilities that OpenAI has recently introduced is parallel function calling. +To learn more you can read up on [this](https://platform.openai.com/docs/guides/function-calling/parallel-function-calling) + +!!! warning "Experimental Feature" + + This feature is currently in preview and is subject to change. only supported by the `gpt-4-turbo-preview` model. + +## Understanding Parallel Function Calling + +By using parallel function callings that allow you to call multiple functions in a single request, you can significantly reduce the latency of your application without having to use tricks with now one builds a schema. + +```python hl_lines="19 31" +import openai +import instructor + +from typing import Iterable, Literal +from pydantic import BaseModel + + +class Weather(BaseModel): + location: str + units: Literal["imperial", "metric"] + + +class GoogleSearch(BaseModel): + query: str + + +client = instructor.patch( + openai.OpenAI(), + mode=instructor.Mode.PARALLEL_TOOLS #(1)! +) + +function_calls = client.chat.completions.create( + model="gpt-4-turbo-preview", + messages=[ + {"role": "system", "content": "You must always use tools"}, + { + "role": "user", + "content": "What is the weather in toronto and dallas and who won the super bowl?", + }, + ], + response_model=Iterable[Weather | GoogleSearch], #(2)! +) + +for fc in function_calls: + print(fc) + """ +``` + +1. Set the mode to `PARALLEL_TOOLS` to enable parallel function calling. +2. Set the response model to `Iterable[Weather | GoogleSearch]` to indicate that the response will be a list of `Weather` and `GoogleSearch` objects. This is necessary because the response will be a list of objects, and we need to specify the types of the objects in the list. + +```python +Weather(location='toronto', units='imperial') +Weather(location='dallas', units='imperial') +GoogleSearch(query='who won the super bowl?') +``` + +Noticed that the `response_model` Must be in the form `Iterable[Type1 | Type2 | ...]` or `Iterable[Type1]` where `Type1` and `Type2` are the types of the objects that will be returned in the response. diff --git a/examples/parallel/run.py b/examples/parallel/run.py new file mode 100644 index 0000000..1046c02 --- /dev/null +++ b/examples/parallel/run.py @@ -0,0 +1,33 @@ +import openai +import instructor + +from typing import Iterable, Literal +from pydantic import BaseModel + + +class Weather(BaseModel): + location: str + units: Literal["imperial", "metric"] + + +class GoogleSearch(BaseModel): + query: str + + +client = openai.OpenAI() + +client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS) +resp = client.chat.completions.create( + model="gpt-4-turbo-preview", + messages=[ + {"role": "system", "content": "You must always use tools"}, + { + "role": "user", + "content": "What is the weather in toronto and dallas and who won the super bowl?", + }, + ], + response_model=Iterable[Weather | GoogleSearch], +) + +for r in resp: + print(r) diff --git a/examples/patching/oai.py b/examples/patching/oai.py index 205e9bf..a671e65 100644 --- a/examples/patching/oai.py +++ b/examples/patching/oai.py @@ -1,4 +1,3 @@ -import os import instructor from openai import OpenAI diff --git a/examples/patching/pcalls.py b/examples/patching/pcalls.py new file mode 100644 index 0000000..fd1a076 --- /dev/null +++ b/examples/patching/pcalls.py @@ -0,0 +1,86 @@ +from typing import Iterable, Literal, List, Union +from pydantic import BaseModel +from instructor import OpenAISchema + +import time +import openai +import instructor + + +client = openai.OpenAI() + + +class Weather(OpenAISchema): + location: str + units: Literal["imperial", "metric"] + + +class GoogleSearch(OpenAISchema): + query: str + + +if __name__ == "__main__": + + class Query(BaseModel): + query: List[Union[Weather, GoogleSearch]] + + client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS) + + start = time.perf_counter() + resp = client.chat.completions.create( + model="gpt-4-turbo-preview", + messages=[ + {"role": "system", "content": "You must always use tools"}, + { + "role": "user", + "content": "What is the weather in toronto and dallas and who won the super bowl?", + }, + ], + response_model=Iterable[Union[Weather, GoogleSearch]], + ) + print(f"# Time: {time.perf_counter() - start:.2f}") + + print("# Instructor: Question with Toronto and Super Bowl") + print([model for model in resp]) + + start = time.perf_counter() + resp = client.chat.completions.create( + model="gpt-4-turbo-preview", + messages=[ + { + "role": "user", + "content": "What is the weather in toronto and dallas?", + }, + ], + tools=[ + {"type": "function", "function": Weather.openai_schema}, + {"type": "function", "function": GoogleSearch.openai_schema}, + ], + tool_choice="auto", + ) + print(f"# Time: {time.perf_counter() - start:.2f}") + + print("# Question with Toronto and Dallas") + for tool_call in resp.choices[0].message.tool_calls: + print(tool_call.model_dump_json(indent=2)) + + start = time.perf_counter() + resp = client.chat.completions.create( + model="gpt-4-turbo-preview", + messages=[ + { + "role": "user", + "content": "What is the weather in toronto? and who won the super bowl?", + }, + ], + tools=[ + {"type": "function", "function": Weather.openai_schema}, + {"type": "function", "function": GoogleSearch.openai_schema}, + ], + tool_choice="auto", + ) + print(f"# Time: {time.perf_counter() - start:.2f}") + + print("# Question with Toronto and Super Bowl") + for tool_call in resp.choices[0].message.tool_calls: + print(tool_call.model_dump_json(indent=2)) diff --git a/instructor/dsl/parallel.py b/instructor/dsl/parallel.py new file mode 100644 index 0000000..7369134 --- /dev/null +++ b/instructor/dsl/parallel.py @@ -0,0 +1,63 @@ +from typing import Type, TypeVar, Union, get_origin, get_args +from types import UnionType + +from instructor.function_calls import OpenAISchema, Mode, openai_schema +from collections.abc import Iterable + +T = TypeVar("T") + + +class ParallelBase: + def __init__(self, *models: Type[OpenAISchema]): + # Note that for everything else we've created a class, but for parallel base it is an instance + assert len(models) > 0, "At least one model is required" + self.models = models + self.registry = {model.__name__: model for model in models} + + def from_response( + self, + response, + mode: Mode, + validation_context=None, + strict: bool = None, + ) -> Iterable[Union[T]]: + #! We expect this from the OpenAISchema class, We should address + #! this with a protocol or an abstract class... @jxnlco + assert mode == Mode.PARALLEL_TOOLS, "Mode must be PARALLEL_TOOLS" + for tool_call in response.choices[0].message.tool_calls: + name = tool_call.function.name + arguments = tool_call.function.arguments + yield self.registry[name].model_validate_json( + arguments, context=validation_context, strict=strict + ) + + +def get_types_array(typehint: Type[Iterable[Union[T]]]): + should_be_iterable = get_origin(typehint) + assert should_be_iterable is Iterable + + if get_origin(get_args(typehint)[0]) is Union: + # works for Iterable[Union[int, str]] + the_types = get_args(get_args(typehint)[0]) + return the_types + + if get_origin(get_args(typehint)[0]) is UnionType: + # works for Iterable[Union[int, str]] + the_types = get_args(get_args(typehint)[0]) + return the_types + + # works for Iterable[int] + return get_args(typehint) + + +def handle_parallel_model(typehint: Type[Iterable[Union[T]]]): + the_types = get_types_array(typehint) + return [ + {"type": "function", "function": openai_schema(model).openai_schema} + for model in the_types + ] + + +def ParallelModel(typehint): + the_types = get_types_array(typehint) + return ParallelBase(*[model for model in the_types]) diff --git a/instructor/function_calls.py b/instructor/function_calls.py index 6872b25..5b3ba1d 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -1,4 +1,4 @@ -from typing import Type, TypeVar, Self +from typing import Type, TypeVar from docstring_parser import parse from functools import wraps from pydantic import BaseModel, create_model @@ -13,6 +13,7 @@ class Mode(enum.Enum): """The mode to use for patching the client""" FUNCTIONS: str = "function_call" + PARALLEL_TOOLS: str = "parallel_tool_call" TOOLS: str = "tool_call" JSON: str = "json_mode" MD_JSON: str = "markdown_json_mode" @@ -34,46 +35,6 @@ class Mode(enum.Enum): class OpenAISchema(BaseModel): - """ - Augments a Pydantic model with OpenAI's schema for function calling - - This class augments a Pydantic model with OpenAI's schema for function calling. The schema is generated from the model's signature and docstring. The schema can be used to validate the response from OpenAI's API and extract the function call. - - ## Usage - - ```python - from instructor import OpenAISchema - - class User(OpenAISchema): - name: str - age: int - - completion = openai.ChatCompletion.create( - model="gpt-3.5-turbo", - messages=[{ - "content": "Jason is 20 years old", - "role": "user" - }], - functions=[User.openai_schema], - function_call={"name": User.openai_schema["name"]}, - ) - - user = User.from_response(completion) - - print(user.model_dump()) - ``` - ## Result - - ``` - { - "name": "Jason Liu", - "age": 20, - } - ``` - - - """ - @classmethod @property def openai_schema(cls): @@ -124,7 +85,7 @@ class OpenAISchema(BaseModel): validation_context: dict = None, strict: bool = None, mode: Mode = Mode.TOOLS, - ) -> Self: + ): """Execute the function from the response of an openai chat completion Parameters: @@ -133,7 +94,6 @@ class OpenAISchema(BaseModel): validation_context (dict): The validation context to use for validating the response strict (bool): Whether to use strict json parsing mode (Mode): The openai completion mode - stream_multitask (bool): Whether to stream a multitask response Returns: cls (OpenAISchema): An instance of the class @@ -174,65 +134,6 @@ class OpenAISchema(BaseModel): else: raise ValueError(f"Invalid patch mode: {mode}") - @classmethod - async def from_response_async( - cls, - completion, - validation_context: dict = None, - strict: bool = None, - mode: Mode = Mode.TOOLS, - stream_multitask: bool = False, - stream_partial: bool = False, - ): - """Execute the function from the response of an openai chat completion - - Parameters: - completion (openai.ChatCompletion): The response from an openai chat completion - validation_context (dict): The validation context to use for validating the response - strict (bool): Whether to use strict json parsing - mode (Mode): The openai completion mode - stream_multitask (bool): Whether to stream a multitask response - - Returns: - cls (OpenAISchema): An instance of the class - """ - - if completion.choices[0].finish_reason == "length": - raise IncompleteOutputException() - - message = completion.choices[0].message - - if mode == Mode.FUNCTIONS: - assert ( - message.function_call.name == cls.openai_schema["name"] - ), "Function name does not match" - return cls.model_validate_json( - message.function_call.arguments, - context=validation_context, - strict=strict, - ) - elif mode == Mode.TOOLS: - assert ( - len(message.tool_calls) == 1 - ), "Instructor does not support multiple tool calls, use List[Model] instead." - tool_call = message.tool_calls[0] - assert ( - tool_call.function.name == cls.openai_schema["name"] - ), "Tool name does not match" - return cls.model_validate_json( - tool_call.function.arguments, - context=validation_context, - strict=strict, - ) - elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}: - return cls.model_validate_json( - message.content, - context=validation_context, - strict=strict, - ) - else: - raise ValueError(f"Invalid patch mode: {mode}") - def openai_schema(cls: Type[BaseModel]) -> OpenAISchema: if not issubclass(cls, BaseModel): diff --git a/instructor/patch.py b/instructor/patch.py index 1985c5a..86dff2f 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -27,11 +27,13 @@ from openai.types.completion_usage import CompletionUsage from pydantic import BaseModel, ValidationError from instructor.dsl.iterable import IterableModel, IterableBase +from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model from instructor.dsl.partial import PartialBase from .function_calls import Mode, OpenAISchema, openai_schema logger = logging.getLogger("instructor") +T = TypeVar("T") T_Model = TypeVar("T_Model", bound=BaseModel) @@ -78,6 +80,19 @@ def handle_response_model( """ new_kwargs = kwargs.copy() if response_model is not None: + # This a special case for parallel tools + if mode == Mode.PARALLEL_TOOLS: + assert ( + new_kwargs.get("stream", False) is False + ), "stream=True is not supported when using PARALLEL_TOOLS mode" + new_kwargs["tools"] = handle_parallel_model(response_model) + new_kwargs["tool_choice"] = "auto" + + # This is a special case for parallel models + response_model = ParallelModel(typehint=response_model) + return response_model, new_kwargs + + # This is for all other single model cases if get_origin(response_model) is Iterable: iterable_element_class = get_args(response_model)[0] response_model = IterableModel(iterable_element_class) @@ -178,7 +193,11 @@ def process_response( if response_model is None: return response - if issubclass(response_model, (IterableBase, PartialBase)) and stream: + if ( + inspect.isclass(response_model) + and issubclass(response_model, (IterableBase, PartialBase)) + and stream + ): model = response_model.from_streaming_response( response, mode=mode, @@ -191,12 +210,17 @@ def process_response( strict=strict, mode=mode, ) - model._raw_response = response - if issubclass(response_model, IterableBase): - # If the response model is a multitask, return the tasks + # ? This really hints at the fact that we need a better way of + # ? attaching usage data and the raw response to the model we return. + if isinstance(response_model, IterableBase): + #! If the response model is a multitask, return the tasks return [task for task in model.tasks] + if isinstance(response_model, ParallelBase): + return model + + model._raw_response = response return model @@ -223,22 +247,34 @@ async def process_response_async( if response_model is None: return response - if issubclass(response_model, (IterableBase, PartialBase)) and stream: + if ( + inspect.isclass(response_model) + and issubclass(response_model, (IterableBase, PartialBase)) + and stream + ): model = await response_model.from_streaming_response_async( response, mode=mode, ) return model - model = await response_model.from_response_async( + model = response_model.from_response( response, validation_context=validation_context, strict=strict, mode=mode, ) + + # ? This really hints at the fact that we need a better way of + # ? attaching usage data and the raw response to the model we return. + if isinstance(response_model, IterableBase): + #! If the response model is a multitask, return the tasks + return [task for task in model.tasks] + + if isinstance(response_model, ParallelBase): + return model + model._raw_response = response - if issubclass(response_model, IterableBase): - return model.tasks return model @@ -276,7 +312,6 @@ async def retry_async( except (ValidationError, JSONDecodeError) as e: logger.exception(f"Retrying, exception: {e}") logger.debug(f"Error response: {response}") - kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore if mode == Mode.TOOLS: kwargs["messages"].append( { @@ -286,6 +321,7 @@ async def retry_async( "content": "failure", } ) + kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore kwargs["messages"].append( { "role": "user", @@ -339,16 +375,16 @@ def retry_sync( except (ValidationError, JSONDecodeError) as e: logger.exception(f"Retrying, exception: {e}") logger.debug(f"Error response: {response}") - kwargs["messages"].append(dump_message(response.choices[0].message)) if mode == Mode.TOOLS: kwargs["messages"].append( { "role": "tool", "tool_call_id": response.choices[0].message.tool_calls[0].id, "name": response.choices[0].message.tool_calls[0].function.name, - "content": "failure", + "content": f"Recall the function correctly, fix the errors and exceptions found\n{e}", } ) + kwargs["messages"].append(dump_message(response.choices[0].message)) kwargs["messages"].append( { "role": "user", diff --git a/mkdocs.yml b/mkdocs.yml index b8f3c7a..c5b5ce3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -134,6 +134,7 @@ nav: - Fields: 'concepts/fields.md' - Missing: "concepts/maybe.md" - Patching: 'concepts/patching.md' + - Parallel Tools: 'concepts/parallel.md' - Stream Iterable: "concepts/lists.md" - Stream Partial: "concepts/partial.md" - Raw Response: 'concepts/raw_response.md' diff --git a/tests/test_function_calls.py b/tests/test_function_calls.py index cd03126..9a0db9e 100644 --- a/tests/test_function_calls.py +++ b/tests/test_function_calls.py @@ -112,11 +112,3 @@ def test_complete_output_no_exception(test_model, mock_completion): async def test_incomplete_output_exception_raise(test_model, mock_completion): with pytest.raises(IncompleteOutputException): await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS) - - -@pytest.mark.asyncio -async def test_async_complete_output_no_exception(test_model, mock_completion): - test_model_instance = await test_model.from_response_async( - mock_completion, mode=instructor.Mode.FUNCTIONS - ) - assert test_model_instance.data == "complete data"