Implement Parallel Function Calls with List[Union[T]] (#378)

This commit is contained in:
Jason Liu
2024-01-31 23:03:38 -05:00
committed by GitHub
parent c7f1ceeb5c
commit 62e6a40389
11 changed files with 301 additions and 129 deletions
+1 -4
View File
@@ -1,14 +1,11 @@
name: Test
on:
push:
branches:
branches:
- main
pull_request_target:
branches:
- main
jobs:
release:
runs-on: ubuntu-latest
+5 -3
View File
@@ -4,10 +4,12 @@
::: instructor.dsl.validators
::: instructor.dsl.citation
::: instructor.dsl.iterable
::: instructor.dsl.multitask
::: instructor.dsl.partial
::: instructor.dsl.parallel
::: instructor.dsl.maybe
::: instructor.function_calls
::: instructor.function_calls
+62
View File
@@ -0,0 +1,62 @@
# Parallel Tools
One of the latest capabilities that OpenAI has recently introduced is parallel function calling.
To learn more you can read up on [this](https://platform.openai.com/docs/guides/function-calling/parallel-function-calling)
!!! warning "Experimental Feature"
This feature is currently in preview and is subject to change. only supported by the `gpt-4-turbo-preview` model.
## Understanding Parallel Function Calling
By using parallel function callings that allow you to call multiple functions in a single request, you can significantly reduce the latency of your application without having to use tricks with now one builds a schema.
```python hl_lines="19 31"
import openai
import instructor
from typing import Iterable, Literal
from pydantic import BaseModel
class Weather(BaseModel):
location: str
units: Literal["imperial", "metric"]
class GoogleSearch(BaseModel):
query: str
client = instructor.patch(
openai.OpenAI(),
mode=instructor.Mode.PARALLEL_TOOLS #(1)!
)
function_calls = client.chat.completions.create(
model="gpt-4-turbo-preview",
messages=[
{"role": "system", "content": "You must always use tools"},
{
"role": "user",
"content": "What is the weather in toronto and dallas and who won the super bowl?",
},
],
response_model=Iterable[Weather | GoogleSearch], #(2)!
)
for fc in function_calls:
print(fc)
"""
```
1. Set the mode to `PARALLEL_TOOLS` to enable parallel function calling.
2. Set the response model to `Iterable[Weather | GoogleSearch]` to indicate that the response will be a list of `Weather` and `GoogleSearch` objects. This is necessary because the response will be a list of objects, and we need to specify the types of the objects in the list.
```python
Weather(location='toronto', units='imperial')
Weather(location='dallas', units='imperial')
GoogleSearch(query='who won the super bowl?')
```
Noticed that the `response_model` Must be in the form `Iterable[Type1 | Type2 | ...]` or `Iterable[Type1]` where `Type1` and `Type2` are the types of the objects that will be returned in the response.
+33
View File
@@ -0,0 +1,33 @@
import openai
import instructor
from typing import Iterable, Literal
from pydantic import BaseModel
class Weather(BaseModel):
location: str
units: Literal["imperial", "metric"]
class GoogleSearch(BaseModel):
query: str
client = openai.OpenAI()
client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS)
resp = client.chat.completions.create(
model="gpt-4-turbo-preview",
messages=[
{"role": "system", "content": "You must always use tools"},
{
"role": "user",
"content": "What is the weather in toronto and dallas and who won the super bowl?",
},
],
response_model=Iterable[Weather | GoogleSearch],
)
for r in resp:
print(r)
-1
View File
@@ -1,4 +1,3 @@
import os
import instructor
from openai import OpenAI
+86
View File
@@ -0,0 +1,86 @@
from typing import Iterable, Literal, List, Union
from pydantic import BaseModel
from instructor import OpenAISchema
import time
import openai
import instructor
client = openai.OpenAI()
class Weather(OpenAISchema):
location: str
units: Literal["imperial", "metric"]
class GoogleSearch(OpenAISchema):
query: str
if __name__ == "__main__":
class Query(BaseModel):
query: List[Union[Weather, GoogleSearch]]
client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS)
start = time.perf_counter()
resp = client.chat.completions.create(
model="gpt-4-turbo-preview",
messages=[
{"role": "system", "content": "You must always use tools"},
{
"role": "user",
"content": "What is the weather in toronto and dallas and who won the super bowl?",
},
],
response_model=Iterable[Union[Weather, GoogleSearch]],
)
print(f"# Time: {time.perf_counter() - start:.2f}")
print("# Instructor: Question with Toronto and Super Bowl")
print([model for model in resp])
start = time.perf_counter()
resp = client.chat.completions.create(
model="gpt-4-turbo-preview",
messages=[
{
"role": "user",
"content": "What is the weather in toronto and dallas?",
},
],
tools=[
{"type": "function", "function": Weather.openai_schema},
{"type": "function", "function": GoogleSearch.openai_schema},
],
tool_choice="auto",
)
print(f"# Time: {time.perf_counter() - start:.2f}")
print("# Question with Toronto and Dallas")
for tool_call in resp.choices[0].message.tool_calls:
print(tool_call.model_dump_json(indent=2))
start = time.perf_counter()
resp = client.chat.completions.create(
model="gpt-4-turbo-preview",
messages=[
{
"role": "user",
"content": "What is the weather in toronto? and who won the super bowl?",
},
],
tools=[
{"type": "function", "function": Weather.openai_schema},
{"type": "function", "function": GoogleSearch.openai_schema},
],
tool_choice="auto",
)
print(f"# Time: {time.perf_counter() - start:.2f}")
print("# Question with Toronto and Super Bowl")
for tool_call in resp.choices[0].message.tool_calls:
print(tool_call.model_dump_json(indent=2))
+63
View File
@@ -0,0 +1,63 @@
from typing import Type, TypeVar, Union, get_origin, get_args
from types import UnionType
from instructor.function_calls import OpenAISchema, Mode, openai_schema
from collections.abc import Iterable
T = TypeVar("T")
class ParallelBase:
def __init__(self, *models: Type[OpenAISchema]):
# Note that for everything else we've created a class, but for parallel base it is an instance
assert len(models) > 0, "At least one model is required"
self.models = models
self.registry = {model.__name__: model for model in models}
def from_response(
self,
response,
mode: Mode,
validation_context=None,
strict: bool = None,
) -> Iterable[Union[T]]:
#! We expect this from the OpenAISchema class, We should address
#! this with a protocol or an abstract class... @jxnlco
assert mode == Mode.PARALLEL_TOOLS, "Mode must be PARALLEL_TOOLS"
for tool_call in response.choices[0].message.tool_calls:
name = tool_call.function.name
arguments = tool_call.function.arguments
yield self.registry[name].model_validate_json(
arguments, context=validation_context, strict=strict
)
def get_types_array(typehint: Type[Iterable[Union[T]]]):
should_be_iterable = get_origin(typehint)
assert should_be_iterable is Iterable
if get_origin(get_args(typehint)[0]) is Union:
# works for Iterable[Union[int, str]]
the_types = get_args(get_args(typehint)[0])
return the_types
if get_origin(get_args(typehint)[0]) is UnionType:
# works for Iterable[Union[int, str]]
the_types = get_args(get_args(typehint)[0])
return the_types
# works for Iterable[int]
return get_args(typehint)
def handle_parallel_model(typehint: Type[Iterable[Union[T]]]):
the_types = get_types_array(typehint)
return [
{"type": "function", "function": openai_schema(model).openai_schema}
for model in the_types
]
def ParallelModel(typehint):
the_types = get_types_array(typehint)
return ParallelBase(*[model for model in the_types])
+3 -102
View File
@@ -1,4 +1,4 @@
from typing import Type, TypeVar, Self
from typing import Type, TypeVar
from docstring_parser import parse
from functools import wraps
from pydantic import BaseModel, create_model
@@ -13,6 +13,7 @@ class Mode(enum.Enum):
"""The mode to use for patching the client"""
FUNCTIONS: str = "function_call"
PARALLEL_TOOLS: str = "parallel_tool_call"
TOOLS: str = "tool_call"
JSON: str = "json_mode"
MD_JSON: str = "markdown_json_mode"
@@ -34,46 +35,6 @@ class Mode(enum.Enum):
class OpenAISchema(BaseModel):
"""
Augments a Pydantic model with OpenAI's schema for function calling
This class augments a Pydantic model with OpenAI's schema for function calling. The schema is generated from the model's signature and docstring. The schema can be used to validate the response from OpenAI's API and extract the function call.
## Usage
```python
from instructor import OpenAISchema
class User(OpenAISchema):
name: str
age: int
completion = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[{
"content": "Jason is 20 years old",
"role": "user"
}],
functions=[User.openai_schema],
function_call={"name": User.openai_schema["name"]},
)
user = User.from_response(completion)
print(user.model_dump())
```
## Result
```
{
"name": "Jason Liu",
"age": 20,
}
```
"""
@classmethod
@property
def openai_schema(cls):
@@ -124,7 +85,7 @@ class OpenAISchema(BaseModel):
validation_context: dict = None,
strict: bool = None,
mode: Mode = Mode.TOOLS,
) -> Self:
):
"""Execute the function from the response of an openai chat completion
Parameters:
@@ -133,7 +94,6 @@ class OpenAISchema(BaseModel):
validation_context (dict): The validation context to use for validating the response
strict (bool): Whether to use strict json parsing
mode (Mode): The openai completion mode
stream_multitask (bool): Whether to stream a multitask response
Returns:
cls (OpenAISchema): An instance of the class
@@ -174,65 +134,6 @@ class OpenAISchema(BaseModel):
else:
raise ValueError(f"Invalid patch mode: {mode}")
@classmethod
async def from_response_async(
cls,
completion,
validation_context: dict = None,
strict: bool = None,
mode: Mode = Mode.TOOLS,
stream_multitask: bool = False,
stream_partial: bool = False,
):
"""Execute the function from the response of an openai chat completion
Parameters:
completion (openai.ChatCompletion): The response from an openai chat completion
validation_context (dict): The validation context to use for validating the response
strict (bool): Whether to use strict json parsing
mode (Mode): The openai completion mode
stream_multitask (bool): Whether to stream a multitask response
Returns:
cls (OpenAISchema): An instance of the class
"""
if completion.choices[0].finish_reason == "length":
raise IncompleteOutputException()
message = completion.choices[0].message
if mode == Mode.FUNCTIONS:
assert (
message.function_call.name == cls.openai_schema["name"]
), "Function name does not match"
return cls.model_validate_json(
message.function_call.arguments,
context=validation_context,
strict=strict,
)
elif mode == Mode.TOOLS:
assert (
len(message.tool_calls) == 1
), "Instructor does not support multiple tool calls, use List[Model] instead."
tool_call = message.tool_calls[0]
assert (
tool_call.function.name == cls.openai_schema["name"]
), "Tool name does not match"
return cls.model_validate_json(
tool_call.function.arguments,
context=validation_context,
strict=strict,
)
elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
return cls.model_validate_json(
message.content,
context=validation_context,
strict=strict,
)
else:
raise ValueError(f"Invalid patch mode: {mode}")
def openai_schema(cls: Type[BaseModel]) -> OpenAISchema:
if not issubclass(cls, BaseModel):
+47 -11
View File
@@ -27,11 +27,13 @@ from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel, ValidationError
from instructor.dsl.iterable import IterableModel, IterableBase
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
from instructor.dsl.partial import PartialBase
from .function_calls import Mode, OpenAISchema, openai_schema
logger = logging.getLogger("instructor")
T = TypeVar("T")
T_Model = TypeVar("T_Model", bound=BaseModel)
@@ -78,6 +80,19 @@ def handle_response_model(
"""
new_kwargs = kwargs.copy()
if response_model is not None:
# This a special case for parallel tools
if mode == Mode.PARALLEL_TOOLS:
assert (
new_kwargs.get("stream", False) is False
), "stream=True is not supported when using PARALLEL_TOOLS mode"
new_kwargs["tools"] = handle_parallel_model(response_model)
new_kwargs["tool_choice"] = "auto"
# This is a special case for parallel models
response_model = ParallelModel(typehint=response_model)
return response_model, new_kwargs
# This is for all other single model cases
if get_origin(response_model) is Iterable:
iterable_element_class = get_args(response_model)[0]
response_model = IterableModel(iterable_element_class)
@@ -178,7 +193,11 @@ def process_response(
if response_model is None:
return response
if issubclass(response_model, (IterableBase, PartialBase)) and stream:
if (
inspect.isclass(response_model)
and issubclass(response_model, (IterableBase, PartialBase))
and stream
):
model = response_model.from_streaming_response(
response,
mode=mode,
@@ -191,12 +210,17 @@ def process_response(
strict=strict,
mode=mode,
)
model._raw_response = response
if issubclass(response_model, IterableBase):
# If the response model is a multitask, return the tasks
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(response_model, IterableBase):
#! If the response model is a multitask, return the tasks
return [task for task in model.tasks]
if isinstance(response_model, ParallelBase):
return model
model._raw_response = response
return model
@@ -223,22 +247,34 @@ async def process_response_async(
if response_model is None:
return response
if issubclass(response_model, (IterableBase, PartialBase)) and stream:
if (
inspect.isclass(response_model)
and issubclass(response_model, (IterableBase, PartialBase))
and stream
):
model = await response_model.from_streaming_response_async(
response,
mode=mode,
)
return model
model = await response_model.from_response_async(
model = response_model.from_response(
response,
validation_context=validation_context,
strict=strict,
mode=mode,
)
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(response_model, IterableBase):
#! If the response model is a multitask, return the tasks
return [task for task in model.tasks]
if isinstance(response_model, ParallelBase):
return model
model._raw_response = response
if issubclass(response_model, IterableBase):
return model.tasks
return model
@@ -276,7 +312,6 @@ async def retry_async(
except (ValidationError, JSONDecodeError) as e:
logger.exception(f"Retrying, exception: {e}")
logger.debug(f"Error response: {response}")
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
if mode == Mode.TOOLS:
kwargs["messages"].append(
{
@@ -286,6 +321,7 @@ async def retry_async(
"content": "failure",
}
)
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
kwargs["messages"].append(
{
"role": "user",
@@ -339,16 +375,16 @@ def retry_sync(
except (ValidationError, JSONDecodeError) as e:
logger.exception(f"Retrying, exception: {e}")
logger.debug(f"Error response: {response}")
kwargs["messages"].append(dump_message(response.choices[0].message))
if mode == Mode.TOOLS:
kwargs["messages"].append(
{
"role": "tool",
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"name": response.choices[0].message.tool_calls[0].function.name,
"content": "failure",
"content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
}
)
kwargs["messages"].append(dump_message(response.choices[0].message))
kwargs["messages"].append(
{
"role": "user",
+1
View File
@@ -134,6 +134,7 @@ nav:
- Fields: 'concepts/fields.md'
- Missing: "concepts/maybe.md"
- Patching: 'concepts/patching.md'
- Parallel Tools: 'concepts/parallel.md'
- Stream Iterable: "concepts/lists.md"
- Stream Partial: "concepts/partial.md"
- Raw Response: 'concepts/raw_response.md'
-8
View File
@@ -112,11 +112,3 @@ def test_complete_output_no_exception(test_model, mock_completion):
async def test_incomplete_output_exception_raise(test_model, mock_completion):
with pytest.raises(IncompleteOutputException):
await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
@pytest.mark.asyncio
async def test_async_complete_output_no_exception(test_model, mock_completion):
test_model_instance = await test_model.from_response_async(
mock_completion, mode=instructor.Mode.FUNCTIONS
)
assert test_model_instance.data == "complete data"