mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Implement Parallel Function Calls with List[Union[T]] (#378)
This commit is contained in:
@@ -1,14 +1,11 @@
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
branches:
|
||||
- main
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
+5
-3
@@ -4,10 +4,12 @@
|
||||
|
||||
::: instructor.dsl.validators
|
||||
|
||||
::: instructor.dsl.citation
|
||||
::: instructor.dsl.iterable
|
||||
|
||||
::: instructor.dsl.multitask
|
||||
::: instructor.dsl.partial
|
||||
|
||||
::: instructor.dsl.parallel
|
||||
|
||||
::: instructor.dsl.maybe
|
||||
|
||||
::: instructor.function_calls
|
||||
::: instructor.function_calls
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
# Parallel Tools
|
||||
|
||||
One of the latest capabilities that OpenAI has recently introduced is parallel function calling.
|
||||
To learn more you can read up on [this](https://platform.openai.com/docs/guides/function-calling/parallel-function-calling)
|
||||
|
||||
!!! warning "Experimental Feature"
|
||||
|
||||
This feature is currently in preview and is subject to change. only supported by the `gpt-4-turbo-preview` model.
|
||||
|
||||
## Understanding Parallel Function Calling
|
||||
|
||||
By using parallel function callings that allow you to call multiple functions in a single request, you can significantly reduce the latency of your application without having to use tricks with now one builds a schema.
|
||||
|
||||
```python hl_lines="19 31"
|
||||
import openai
|
||||
import instructor
|
||||
|
||||
from typing import Iterable, Literal
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Weather(BaseModel):
|
||||
location: str
|
||||
units: Literal["imperial", "metric"]
|
||||
|
||||
|
||||
class GoogleSearch(BaseModel):
|
||||
query: str
|
||||
|
||||
|
||||
client = instructor.patch(
|
||||
openai.OpenAI(),
|
||||
mode=instructor.Mode.PARALLEL_TOOLS #(1)!
|
||||
)
|
||||
|
||||
function_calls = client.chat.completions.create(
|
||||
model="gpt-4-turbo-preview",
|
||||
messages=[
|
||||
{"role": "system", "content": "You must always use tools"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in toronto and dallas and who won the super bowl?",
|
||||
},
|
||||
],
|
||||
response_model=Iterable[Weather | GoogleSearch], #(2)!
|
||||
)
|
||||
|
||||
for fc in function_calls:
|
||||
print(fc)
|
||||
"""
|
||||
```
|
||||
|
||||
1. Set the mode to `PARALLEL_TOOLS` to enable parallel function calling.
|
||||
2. Set the response model to `Iterable[Weather | GoogleSearch]` to indicate that the response will be a list of `Weather` and `GoogleSearch` objects. This is necessary because the response will be a list of objects, and we need to specify the types of the objects in the list.
|
||||
|
||||
```python
|
||||
Weather(location='toronto', units='imperial')
|
||||
Weather(location='dallas', units='imperial')
|
||||
GoogleSearch(query='who won the super bowl?')
|
||||
```
|
||||
|
||||
Noticed that the `response_model` Must be in the form `Iterable[Type1 | Type2 | ...]` or `Iterable[Type1]` where `Type1` and `Type2` are the types of the objects that will be returned in the response.
|
||||
@@ -0,0 +1,33 @@
|
||||
import openai
|
||||
import instructor
|
||||
|
||||
from typing import Iterable, Literal
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Weather(BaseModel):
|
||||
location: str
|
||||
units: Literal["imperial", "metric"]
|
||||
|
||||
|
||||
class GoogleSearch(BaseModel):
|
||||
query: str
|
||||
|
||||
|
||||
client = openai.OpenAI()
|
||||
|
||||
client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS)
|
||||
resp = client.chat.completions.create(
|
||||
model="gpt-4-turbo-preview",
|
||||
messages=[
|
||||
{"role": "system", "content": "You must always use tools"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in toronto and dallas and who won the super bowl?",
|
||||
},
|
||||
],
|
||||
response_model=Iterable[Weather | GoogleSearch],
|
||||
)
|
||||
|
||||
for r in resp:
|
||||
print(r)
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import instructor
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
from typing import Iterable, Literal, List, Union
|
||||
from pydantic import BaseModel
|
||||
from instructor import OpenAISchema
|
||||
|
||||
import time
|
||||
import openai
|
||||
import instructor
|
||||
|
||||
|
||||
client = openai.OpenAI()
|
||||
|
||||
|
||||
class Weather(OpenAISchema):
|
||||
location: str
|
||||
units: Literal["imperial", "metric"]
|
||||
|
||||
|
||||
class GoogleSearch(OpenAISchema):
|
||||
query: str
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
class Query(BaseModel):
|
||||
query: List[Union[Weather, GoogleSearch]]
|
||||
|
||||
client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS)
|
||||
|
||||
start = time.perf_counter()
|
||||
resp = client.chat.completions.create(
|
||||
model="gpt-4-turbo-preview",
|
||||
messages=[
|
||||
{"role": "system", "content": "You must always use tools"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in toronto and dallas and who won the super bowl?",
|
||||
},
|
||||
],
|
||||
response_model=Iterable[Union[Weather, GoogleSearch]],
|
||||
)
|
||||
print(f"# Time: {time.perf_counter() - start:.2f}")
|
||||
|
||||
print("# Instructor: Question with Toronto and Super Bowl")
|
||||
print([model for model in resp])
|
||||
|
||||
start = time.perf_counter()
|
||||
resp = client.chat.completions.create(
|
||||
model="gpt-4-turbo-preview",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in toronto and dallas?",
|
||||
},
|
||||
],
|
||||
tools=[
|
||||
{"type": "function", "function": Weather.openai_schema},
|
||||
{"type": "function", "function": GoogleSearch.openai_schema},
|
||||
],
|
||||
tool_choice="auto",
|
||||
)
|
||||
print(f"# Time: {time.perf_counter() - start:.2f}")
|
||||
|
||||
print("# Question with Toronto and Dallas")
|
||||
for tool_call in resp.choices[0].message.tool_calls:
|
||||
print(tool_call.model_dump_json(indent=2))
|
||||
|
||||
start = time.perf_counter()
|
||||
resp = client.chat.completions.create(
|
||||
model="gpt-4-turbo-preview",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in toronto? and who won the super bowl?",
|
||||
},
|
||||
],
|
||||
tools=[
|
||||
{"type": "function", "function": Weather.openai_schema},
|
||||
{"type": "function", "function": GoogleSearch.openai_schema},
|
||||
],
|
||||
tool_choice="auto",
|
||||
)
|
||||
print(f"# Time: {time.perf_counter() - start:.2f}")
|
||||
|
||||
print("# Question with Toronto and Super Bowl")
|
||||
for tool_call in resp.choices[0].message.tool_calls:
|
||||
print(tool_call.model_dump_json(indent=2))
|
||||
@@ -0,0 +1,63 @@
|
||||
from typing import Type, TypeVar, Union, get_origin, get_args
|
||||
from types import UnionType
|
||||
|
||||
from instructor.function_calls import OpenAISchema, Mode, openai_schema
|
||||
from collections.abc import Iterable
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ParallelBase:
|
||||
def __init__(self, *models: Type[OpenAISchema]):
|
||||
# Note that for everything else we've created a class, but for parallel base it is an instance
|
||||
assert len(models) > 0, "At least one model is required"
|
||||
self.models = models
|
||||
self.registry = {model.__name__: model for model in models}
|
||||
|
||||
def from_response(
|
||||
self,
|
||||
response,
|
||||
mode: Mode,
|
||||
validation_context=None,
|
||||
strict: bool = None,
|
||||
) -> Iterable[Union[T]]:
|
||||
#! We expect this from the OpenAISchema class, We should address
|
||||
#! this with a protocol or an abstract class... @jxnlco
|
||||
assert mode == Mode.PARALLEL_TOOLS, "Mode must be PARALLEL_TOOLS"
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
name = tool_call.function.name
|
||||
arguments = tool_call.function.arguments
|
||||
yield self.registry[name].model_validate_json(
|
||||
arguments, context=validation_context, strict=strict
|
||||
)
|
||||
|
||||
|
||||
def get_types_array(typehint: Type[Iterable[Union[T]]]):
|
||||
should_be_iterable = get_origin(typehint)
|
||||
assert should_be_iterable is Iterable
|
||||
|
||||
if get_origin(get_args(typehint)[0]) is Union:
|
||||
# works for Iterable[Union[int, str]]
|
||||
the_types = get_args(get_args(typehint)[0])
|
||||
return the_types
|
||||
|
||||
if get_origin(get_args(typehint)[0]) is UnionType:
|
||||
# works for Iterable[Union[int, str]]
|
||||
the_types = get_args(get_args(typehint)[0])
|
||||
return the_types
|
||||
|
||||
# works for Iterable[int]
|
||||
return get_args(typehint)
|
||||
|
||||
|
||||
def handle_parallel_model(typehint: Type[Iterable[Union[T]]]):
|
||||
the_types = get_types_array(typehint)
|
||||
return [
|
||||
{"type": "function", "function": openai_schema(model).openai_schema}
|
||||
for model in the_types
|
||||
]
|
||||
|
||||
|
||||
def ParallelModel(typehint):
|
||||
the_types = get_types_array(typehint)
|
||||
return ParallelBase(*[model for model in the_types])
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Type, TypeVar, Self
|
||||
from typing import Type, TypeVar
|
||||
from docstring_parser import parse
|
||||
from functools import wraps
|
||||
from pydantic import BaseModel, create_model
|
||||
@@ -13,6 +13,7 @@ class Mode(enum.Enum):
|
||||
"""The mode to use for patching the client"""
|
||||
|
||||
FUNCTIONS: str = "function_call"
|
||||
PARALLEL_TOOLS: str = "parallel_tool_call"
|
||||
TOOLS: str = "tool_call"
|
||||
JSON: str = "json_mode"
|
||||
MD_JSON: str = "markdown_json_mode"
|
||||
@@ -34,46 +35,6 @@ class Mode(enum.Enum):
|
||||
|
||||
|
||||
class OpenAISchema(BaseModel):
|
||||
"""
|
||||
Augments a Pydantic model with OpenAI's schema for function calling
|
||||
|
||||
This class augments a Pydantic model with OpenAI's schema for function calling. The schema is generated from the model's signature and docstring. The schema can be used to validate the response from OpenAI's API and extract the function call.
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from instructor import OpenAISchema
|
||||
|
||||
class User(OpenAISchema):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
completion = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"content": "Jason is 20 years old",
|
||||
"role": "user"
|
||||
}],
|
||||
functions=[User.openai_schema],
|
||||
function_call={"name": User.openai_schema["name"]},
|
||||
)
|
||||
|
||||
user = User.from_response(completion)
|
||||
|
||||
print(user.model_dump())
|
||||
```
|
||||
## Result
|
||||
|
||||
```
|
||||
{
|
||||
"name": "Jason Liu",
|
||||
"age": 20,
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def openai_schema(cls):
|
||||
@@ -124,7 +85,7 @@ class OpenAISchema(BaseModel):
|
||||
validation_context: dict = None,
|
||||
strict: bool = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
) -> Self:
|
||||
):
|
||||
"""Execute the function from the response of an openai chat completion
|
||||
|
||||
Parameters:
|
||||
@@ -133,7 +94,6 @@ class OpenAISchema(BaseModel):
|
||||
validation_context (dict): The validation context to use for validating the response
|
||||
strict (bool): Whether to use strict json parsing
|
||||
mode (Mode): The openai completion mode
|
||||
stream_multitask (bool): Whether to stream a multitask response
|
||||
|
||||
Returns:
|
||||
cls (OpenAISchema): An instance of the class
|
||||
@@ -174,65 +134,6 @@ class OpenAISchema(BaseModel):
|
||||
else:
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
@classmethod
|
||||
async def from_response_async(
|
||||
cls,
|
||||
completion,
|
||||
validation_context: dict = None,
|
||||
strict: bool = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
stream_multitask: bool = False,
|
||||
stream_partial: bool = False,
|
||||
):
|
||||
"""Execute the function from the response of an openai chat completion
|
||||
|
||||
Parameters:
|
||||
completion (openai.ChatCompletion): The response from an openai chat completion
|
||||
validation_context (dict): The validation context to use for validating the response
|
||||
strict (bool): Whether to use strict json parsing
|
||||
mode (Mode): The openai completion mode
|
||||
stream_multitask (bool): Whether to stream a multitask response
|
||||
|
||||
Returns:
|
||||
cls (OpenAISchema): An instance of the class
|
||||
"""
|
||||
|
||||
if completion.choices[0].finish_reason == "length":
|
||||
raise IncompleteOutputException()
|
||||
|
||||
message = completion.choices[0].message
|
||||
|
||||
if mode == Mode.FUNCTIONS:
|
||||
assert (
|
||||
message.function_call.name == cls.openai_schema["name"]
|
||||
), "Function name does not match"
|
||||
return cls.model_validate_json(
|
||||
message.function_call.arguments,
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
elif mode == Mode.TOOLS:
|
||||
assert (
|
||||
len(message.tool_calls) == 1
|
||||
), "Instructor does not support multiple tool calls, use List[Model] instead."
|
||||
tool_call = message.tool_calls[0]
|
||||
assert (
|
||||
tool_call.function.name == cls.openai_schema["name"]
|
||||
), "Tool name does not match"
|
||||
return cls.model_validate_json(
|
||||
tool_call.function.arguments,
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
|
||||
return cls.model_validate_json(
|
||||
message.content,
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
|
||||
def openai_schema(cls: Type[BaseModel]) -> OpenAISchema:
|
||||
if not issubclass(cls, BaseModel):
|
||||
|
||||
+47
-11
@@ -27,11 +27,13 @@ from openai.types.completion_usage import CompletionUsage
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from instructor.dsl.iterable import IterableModel, IterableBase
|
||||
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
|
||||
from instructor.dsl.partial import PartialBase
|
||||
|
||||
from .function_calls import Mode, OpenAISchema, openai_schema
|
||||
|
||||
logger = logging.getLogger("instructor")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||
@@ -78,6 +80,19 @@ def handle_response_model(
|
||||
"""
|
||||
new_kwargs = kwargs.copy()
|
||||
if response_model is not None:
|
||||
# This a special case for parallel tools
|
||||
if mode == Mode.PARALLEL_TOOLS:
|
||||
assert (
|
||||
new_kwargs.get("stream", False) is False
|
||||
), "stream=True is not supported when using PARALLEL_TOOLS mode"
|
||||
new_kwargs["tools"] = handle_parallel_model(response_model)
|
||||
new_kwargs["tool_choice"] = "auto"
|
||||
|
||||
# This is a special case for parallel models
|
||||
response_model = ParallelModel(typehint=response_model)
|
||||
return response_model, new_kwargs
|
||||
|
||||
# This is for all other single model cases
|
||||
if get_origin(response_model) is Iterable:
|
||||
iterable_element_class = get_args(response_model)[0]
|
||||
response_model = IterableModel(iterable_element_class)
|
||||
@@ -178,7 +193,11 @@ def process_response(
|
||||
if response_model is None:
|
||||
return response
|
||||
|
||||
if issubclass(response_model, (IterableBase, PartialBase)) and stream:
|
||||
if (
|
||||
inspect.isclass(response_model)
|
||||
and issubclass(response_model, (IterableBase, PartialBase))
|
||||
and stream
|
||||
):
|
||||
model = response_model.from_streaming_response(
|
||||
response,
|
||||
mode=mode,
|
||||
@@ -191,12 +210,17 @@ def process_response(
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
model._raw_response = response
|
||||
|
||||
if issubclass(response_model, IterableBase):
|
||||
# If the response model is a multitask, return the tasks
|
||||
# ? This really hints at the fact that we need a better way of
|
||||
# ? attaching usage data and the raw response to the model we return.
|
||||
if isinstance(response_model, IterableBase):
|
||||
#! If the response model is a multitask, return the tasks
|
||||
return [task for task in model.tasks]
|
||||
|
||||
if isinstance(response_model, ParallelBase):
|
||||
return model
|
||||
|
||||
model._raw_response = response
|
||||
return model
|
||||
|
||||
|
||||
@@ -223,22 +247,34 @@ async def process_response_async(
|
||||
if response_model is None:
|
||||
return response
|
||||
|
||||
if issubclass(response_model, (IterableBase, PartialBase)) and stream:
|
||||
if (
|
||||
inspect.isclass(response_model)
|
||||
and issubclass(response_model, (IterableBase, PartialBase))
|
||||
and stream
|
||||
):
|
||||
model = await response_model.from_streaming_response_async(
|
||||
response,
|
||||
mode=mode,
|
||||
)
|
||||
return model
|
||||
|
||||
model = await response_model.from_response_async(
|
||||
model = response_model.from_response(
|
||||
response,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
# ? This really hints at the fact that we need a better way of
|
||||
# ? attaching usage data and the raw response to the model we return.
|
||||
if isinstance(response_model, IterableBase):
|
||||
#! If the response model is a multitask, return the tasks
|
||||
return [task for task in model.tasks]
|
||||
|
||||
if isinstance(response_model, ParallelBase):
|
||||
return model
|
||||
|
||||
model._raw_response = response
|
||||
if issubclass(response_model, IterableBase):
|
||||
return model.tasks
|
||||
return model
|
||||
|
||||
|
||||
@@ -276,7 +312,6 @@ async def retry_async(
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
logger.exception(f"Retrying, exception: {e}")
|
||||
logger.debug(f"Error response: {response}")
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
|
||||
if mode == Mode.TOOLS:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
@@ -286,6 +321,7 @@ async def retry_async(
|
||||
"content": "failure",
|
||||
}
|
||||
)
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
@@ -339,16 +375,16 @@ def retry_sync(
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
logger.exception(f"Retrying, exception: {e}")
|
||||
logger.debug(f"Error response: {response}")
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message))
|
||||
if mode == Mode.TOOLS:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": response.choices[0].message.tool_calls[0].id,
|
||||
"name": response.choices[0].message.tool_calls[0].function.name,
|
||||
"content": "failure",
|
||||
"content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message))
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
@@ -134,6 +134,7 @@ nav:
|
||||
- Fields: 'concepts/fields.md'
|
||||
- Missing: "concepts/maybe.md"
|
||||
- Patching: 'concepts/patching.md'
|
||||
- Parallel Tools: 'concepts/parallel.md'
|
||||
- Stream Iterable: "concepts/lists.md"
|
||||
- Stream Partial: "concepts/partial.md"
|
||||
- Raw Response: 'concepts/raw_response.md'
|
||||
|
||||
@@ -112,11 +112,3 @@ def test_complete_output_no_exception(test_model, mock_completion):
|
||||
async def test_incomplete_output_exception_raise(test_model, mock_completion):
|
||||
with pytest.raises(IncompleteOutputException):
|
||||
await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_complete_output_no_exception(test_model, mock_completion):
|
||||
test_model_instance = await test_model.from_response_async(
|
||||
mock_completion, mode=instructor.Mode.FUNCTIONS
|
||||
)
|
||||
assert test_model_instance.data == "complete data"
|
||||
|
||||
Reference in New Issue
Block a user