Co-authored-by: Luke Van Seters <lukevanseters@gmail.com>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Jason Liu
2024-01-31 22:08:07 -05:00
committed by GitHub
parent ef76cfe97b
commit c7f1ceeb5c
6 changed files with 277 additions and 130 deletions
+4 -2
View File
@@ -110,7 +110,7 @@ def IterableModel(
subtask_class: Type[BaseModel],
name: Optional[str] = None,
description: Optional[str] = None,
):
) -> Type[BaseModel]:
"""
Dynamically create an IterableModel OpenAISchema that can be used to segment multiple
tasks given a base class. This creates a class that can be used to create a toolkit
@@ -187,5 +187,7 @@ def IterableModel(
if description is None
else description
)
assert issubclass(
new_cls, OpenAISchema
), "The new class should be a subclass of OpenAISchema"
return new_cls
+2 -2
View File
@@ -125,7 +125,7 @@ class JSONParser:
s = s[end + 1 :]
return json.loads(str_val), s
def parse_number(self, s, e):
def parse_number(self, s):
i = 0
while i < len(s) and s[i] in "0123456789.-":
i += 1
@@ -139,7 +139,7 @@ class JSONParser:
if "." in num_str or "e" in num_str or "E" in num_str
else int(num_str)
)
except ValueError:
except ValueError as e:
raise e
return num, s
+11 -7
View File
@@ -1,11 +1,13 @@
from typing import Type, TypeVar, Self
from docstring_parser import parse
from functools import wraps
from pydantic import BaseModel, create_model
from instructor.exceptions import IncompleteOutputException
import enum
import warnings
T = TypeVar("T")
class Mode(enum.Enum):
"""The mode to use for patching the client"""
@@ -118,11 +120,11 @@ class OpenAISchema(BaseModel):
@classmethod
def from_response(
cls,
completion,
validation_context=None,
completion: T,
validation_context: dict = None,
strict: bool = None,
mode: Mode = Mode.TOOLS,
):
) -> Self:
"""Execute the function from the response of an openai chat completion
Parameters:
@@ -176,9 +178,11 @@ class OpenAISchema(BaseModel):
async def from_response_async(
cls,
completion,
validation_context=None,
validation_context: dict = None,
strict: bool = None,
mode: Mode = Mode.TOOLS,
stream_multitask: bool = False,
stream_partial: bool = False,
):
"""Execute the function from the response of an openai chat completion
@@ -230,7 +234,7 @@ class OpenAISchema(BaseModel):
raise ValueError(f"Invalid patch mode: {mode}")
def openai_schema(cls) -> OpenAISchema:
def openai_schema(cls: Type[BaseModel]) -> OpenAISchema:
if not issubclass(cls, BaseModel):
raise TypeError("Class must be a subclass of pydantic.BaseModel")
@@ -239,4 +243,4 @@ def openai_schema(cls) -> OpenAISchema:
cls.__name__,
__base__=(cls, OpenAISchema),
)
) # type: ignore
)
+179 -97
View File
@@ -4,7 +4,18 @@ import logging
from collections.abc import Iterable
from functools import wraps
from json import JSONDecodeError
from typing import Callable, Optional, Type, Union, get_args, get_origin
from typing import (
Callable,
Optional,
ParamSpec,
Protocol,
Type,
TypeVar,
Union,
get_args,
get_origin,
overload,
)
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import (
@@ -22,24 +33,11 @@ from .function_calls import Mode, OpenAISchema, openai_schema
logger = logging.getLogger("instructor")
OVERRIDE_DOCS = """
Creates a new chat completion for the provided messages and parameters.
See: https://platform.openai.com/docs/api-reference/chat-completions/create
Additional Notes:
Using the `response_model` parameter, you can specify a response model to use for parsing the response from OpenAI's API. If its present, the response will be parsed using the response model, otherwise it will be returned as is.
If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method.
If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts.
Parameters:
response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
max_retries (int): The maximum number of retries to attempt if the response is not valid (default: 0)
validation_context (dict): The validation context to use for validating the response (default: None)
"""
T_Model = TypeVar("T_Model", bound=BaseModel)
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")
T = TypeVar("T")
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
@@ -60,11 +58,24 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
def handle_response_model(
*,
response_model: Type[BaseModel],
kwargs,
mode: Mode = Mode.FUNCTIONS,
):
response_model: T, mode: Mode = Mode.TOOLS, **kwargs
) -> Union[Type[OpenAISchema], dict]:
"""Prepare the response model type hint, and returns the response_model
along with the new modified kwargs needed to be able to use the response_model
parameter with the patch function.
Args:
response_model (T): The response model to use for parsing the response
mode (Mode, optional): The openai completion mode. Defaults to Mode.TOOLS.
Raises:
NotImplementedError: When using stream=True with a non-iterable response_model
ValueError: When using an invalid patch mode
Returns:
Union[Type[OpenAISchema], dict]: The response model to use for parsing the response
"""
new_kwargs = kwargs.copy()
if response_model is not None:
if get_origin(response_model) is Iterable:
@@ -143,24 +154,26 @@ def handle_response_model(
def process_response(
response,
response: T,
*,
response_model: Type[BaseModel],
response_model: Type[T_Model],
stream: bool,
validation_context: dict = None,
strict=None,
mode: Mode = Mode.FUNCTIONS,
): # type: ignore
) -> Union[T_Model, T]:
"""Processes an OpenAI response with the response model, if available.
It can use `validation_context` and `strict` to validate the response
via the pydantic model
Args:
response (ChatCompletion): The response from OpenAI's API
response_model (BaseModel): The response model to use for parsing the response
response (T): The response from OpenAI's API
response_model (Type[T_Model]): The response model to use for parsing the response
stream (bool): Whether the response is a stream
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
mode (Mode, optional): The openai completion mode. Defaults to Mode.FUNCTIONS.
Returns:
Union[T_Model, T]: The parsed response, if a response model is available, otherwise the response as is from the SDK
"""
if response_model is None:
return response
@@ -188,14 +201,14 @@ def process_response(
async def process_response_async(
response,
response: ChatCompletion,
*,
response_model: Type[BaseModel],
stream: bool,
response_model: Type[T_Model],
stream: bool = False,
validation_context: dict = None,
strict=None,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
): # type: ignore
) -> T:
"""Processes an OpenAI response with the response model, if available.
It can use `validation_context` and `strict` to validate the response
via the pydantic model
@@ -230,15 +243,15 @@ async def process_response_async(
async def retry_async(
func,
response_model,
func: Callable[T_ParamSpec, T_Retval],
response_model: Type[T],
validation_context,
args,
kwargs,
max_retries,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
):
) -> T:
retries = 0
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
while retries <= max_retries:
@@ -292,12 +305,12 @@ async def retry_async(
def retry_sync(
func,
response_model,
validation_context,
func: Callable[T_ParamSpec, T_Retval],
response_model: Type[T],
validation_context: dict,
args,
kwargs,
max_retries,
max_retries: int = 1,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
):
@@ -362,61 +375,67 @@ def is_async(func: Callable) -> bool:
)
def wrap_chatcompletion(func: Callable, mode: Mode = Mode.FUNCTIONS) -> Callable:
func_is_async = is_async(func)
OVERRIDE_DOCS = """
Creates a new chat completion for the provided messages and parameters.
@wraps(func)
async def new_chatcompletion_async(
response_model=None,
validation_context=None,
max_retries=1,
*args,
**kwargs,
):
response_model, new_kwargs = handle_response_model(
response_model=response_model, kwargs=kwargs, mode=mode
) # type: ignore
response = await retry_async(
func=func,
response_model=response_model,
validation_context=validation_context,
max_retries=max_retries,
args=args,
kwargs=new_kwargs,
mode=mode,
) # type: ignore
return response
See: https://platform.openai.com/docs/api-reference/chat-completions/create
@wraps(func)
def new_chatcompletion_sync(
response_model=None,
validation_context=None,
max_retries=1,
*args,
**kwargs,
):
response_model, new_kwargs = handle_response_model(
response_model=response_model, kwargs=kwargs, mode=mode
) # type: ignore
response = retry_sync(
func=func,
response_model=response_model,
validation_context=validation_context,
max_retries=max_retries,
args=args,
kwargs=new_kwargs,
mode=mode,
) # type: ignore
return response
Additional Notes:
wrapper_function = (
new_chatcompletion_async if func_is_async else new_chatcompletion_sync
)
wrapper_function.__doc__ = OVERRIDE_DOCS
return wrapper_function
Using the `response_model` parameter, you can specify a response model to use for parsing the response from OpenAI's API. If it's present, the response will be parsed using the response model; otherwise it will be returned as is.
If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method.
If you need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts.
Parameters:
response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
max_retries (int): The maximum number of retries to attempt if the response is not valid (default: 1)
validation_context (dict): The validation context to use for validating the response (default: None)
"""
def patch(client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS):
class InstructorChatCompletionCreate(Protocol):
    """Structural type for a chat-completion `create` callable that also
    accepts the `response_model`, `validation_context` and `max_retries`
    arguments.

    It is the declared return type of the `patch(create=...)` overload.
    """

    def __call__(
        self,
        response_model: Optional[Type[T_Model]] = None,
        validation_context: Optional[dict] = None,
        max_retries: int = 1,
        *args: T_ParamSpec.args,
        **kwargs: T_ParamSpec.kwargs,
    ) -> T_Model:
        ...
@overload
def patch(
client: OpenAI,
mode: Mode = Mode.FUNCTIONS,
) -> OpenAI:
...
@overload
def patch(
client: AsyncOpenAI,
mode: Mode = Mode.FUNCTIONS,
) -> AsyncOpenAI:
...
@overload
def patch(
create: Callable[T_ParamSpec, T_Retval],
mode: Mode = Mode.FUNCTIONS,
) -> InstructorChatCompletionCreate:
...
def patch(
client: Union[OpenAI, AsyncOpenAI] = None,
create: Callable[T_ParamSpec, T_Retval] = None,
mode: Mode = Mode.FUNCTIONS,
) -> Union[OpenAI, AsyncOpenAI]:
"""
Patch the `client.chat.completions.create` method
@@ -429,10 +448,68 @@ def patch(client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS):
"""
logger.debug(f"Patching `client.chat.completions.create` with {mode=}")
client.chat.completions.create = wrap_chatcompletion(
client.chat.completions.create, mode=mode
)
return client
if create is not None:
func = create
elif client is not None:
func = client.chat.completions.create
else:
raise ValueError("Either client or create must be provided")
func_is_async = is_async(func)
@wraps(func)
async def new_create_async(
response_model: Type[T_Model] = None,
validation_context: dict = None,
max_retries: int = 1,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
response_model, new_kwargs = handle_response_model(
response_model=response_model, mode=mode, **kwargs
)
response = await retry_async(
func=func,
response_model=response_model,
validation_context=validation_context,
max_retries=max_retries,
args=args,
kwargs=new_kwargs,
mode=mode,
) # type: ignore
return response
@wraps(func)
def new_create_sync(
response_model: Type[T_Model] = None,
validation_context: dict = None,
max_retries: int = 1,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
response_model, new_kwargs = handle_response_model(
response_model=response_model, mode=mode, **kwargs
)
response = retry_sync(
func=func,
response_model=response_model,
validation_context=validation_context,
max_retries=max_retries,
args=args,
kwargs=new_kwargs,
mode=mode,
)
return response
new_create = new_create_async if func_is_async else new_create_sync
new_create.__doc__ = OVERRIDE_DOCS
if client is not None:
client.chat.completions.create = new_create
return client
else:
return new_create
def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
@@ -448,4 +525,9 @@ def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
- `validation_context` parameter to validate the response using the pydantic model
- `strict` parameter to use strict json parsing
"""
import warnings
warnings.warn(
"apatch is deprecated, use patch instead", DeprecationWarning, stacklevel=2
)
return patch(client, mode=mode)
+80
View File
@@ -0,0 +1,80 @@
from typing import Iterable, Literal, Union
from pydantic import BaseModel
import pytest
import instructor
# Response model passed (inside `Iterable[...]`) as `response_model` by the
# parallel-tools tests in this file.
class Weather(BaseModel):
    # Presumably a place name; the test prompts mention toronto and dallas.
    location: str
    # Restricted to exactly these two literal values.
    units: Literal["imperial", "metric"]
# Second response model; combined with Weather as `Weather | GoogleSearch`
# in the `*_or` tests of this file.
class GoogleSearch(BaseModel):
    # Free-text search query.
    query: str
def test_sync_parallel_tools_or(client):
    """A sync client in PARALLEL_TOOLS mode yields three results for a union model."""
    patched = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS)
    conversation = [
        {"role": "system", "content": "You must always use tools"},
        {
            "role": "user",
            "content": "What is the weather in toronto and dallas and who won the super bowl?",
        },
    ]
    result = patched.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=conversation,
        response_model=Iterable[Weather | GoogleSearch],
    )
    # The response is iterable; materialize it before counting.
    tool_calls = list(result)
    assert len(tool_calls) == 3
@pytest.mark.asyncio
async def test_async_parallel_tools_or(aclient):
    """An async client in PARALLEL_TOOLS mode yields three results for a union model."""
    patched = instructor.patch(aclient, mode=instructor.Mode.PARALLEL_TOOLS)
    conversation = [
        {"role": "system", "content": "You must always use tools"},
        {
            "role": "user",
            "content": "What is the weather in toronto and dallas and who won the super bowl?",
        },
    ]
    result = await patched.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=conversation,
        response_model=Iterable[Weather | GoogleSearch],
    )
    # The awaited response is iterable; materialize it before counting.
    tool_calls = list(result)
    assert len(tool_calls) == 3
def test_sync_parallel_tools_one(client):
    """A sync client in PARALLEL_TOOLS mode yields two results for a single model."""
    patched = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS)
    conversation = [
        {"role": "system", "content": "You must always use tools"},
        {
            "role": "user",
            "content": "What is the weather in toronto and dallas?",
        },
    ]
    result = patched.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=conversation,
        response_model=Iterable[Weather],
    )
    # The response is iterable; materialize it before counting.
    tool_calls = list(result)
    assert len(tool_calls) == 2
@pytest.mark.asyncio
async def test_async_parallel_tools_one(aclient):
    """An async client in PARALLEL_TOOLS mode yields two results for a single model."""
    patched = instructor.patch(aclient, mode=instructor.Mode.PARALLEL_TOOLS)
    conversation = [
        {"role": "system", "content": "You must always use tools"},
        {
            "role": "user",
            "content": "What is the weather in toronto and dallas?",
        },
    ]
    result = await patched.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=conversation,
        response_model=Iterable[Weather],
    )
    # The awaited response is iterable; materialize it before counting.
    tool_calls = list(result)
    assert len(tool_calls) == 2
+1 -22
View File
@@ -10,7 +10,7 @@ from openai.types.chat.chat_completion_message_tool_call import (
)
import instructor
from instructor.patch import OVERRIDE_DOCS, dump_message, is_async, wrap_chatcompletion
from instructor.patch import OVERRIDE_DOCS, dump_message, is_async
def test_patch_completes_successfully():
@@ -21,27 +21,6 @@ def test_apatch_completes_successfully():
instructor.apatch(AsyncOpenAI())
@pytest.mark.asyncio
async def test_wrap_chatcompletion_wraps_async_input_function():
async def input_function(*args, **kwargs):
return "Hello, World!"
wrapped_function = wrap_chatcompletion(input_function)
result = await wrapped_function()
assert result == "Hello, World!"
def test_wrap_chatcompletion_wraps_input_function():
def input_function(*args, **kwargs):
return "Hello, World!"
wrapped_function = wrap_chatcompletion(input_function)
result = wrapped_function()
assert result == "Hello, World!"
def test_is_async_returns_true_if_function_is_async():
async def async_function():
pass