mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Types!!! (#372)
Co-authored-by: Luke Van Seters <lukevanseters@gmail.com> Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
@@ -110,7 +110,7 @@ def IterableModel(
|
||||
subtask_class: Type[BaseModel],
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
):
|
||||
) -> Type[BaseModel]:
|
||||
"""
|
||||
Dynamically create a IterableModel OpenAISchema that can be used to segment multiple
|
||||
tasks given a base class. This creates class that can be used to create a toolkit
|
||||
@@ -187,5 +187,7 @@ def IterableModel(
|
||||
if description is None
|
||||
else description
|
||||
)
|
||||
|
||||
assert issubclass(
|
||||
new_cls, OpenAISchema
|
||||
), "The new class should be a subclass of OpenAISchema"
|
||||
return new_cls
|
||||
|
||||
@@ -125,7 +125,7 @@ class JSONParser:
|
||||
s = s[end + 1 :]
|
||||
return json.loads(str_val), s
|
||||
|
||||
def parse_number(self, s, e):
|
||||
def parse_number(self, s):
|
||||
i = 0
|
||||
while i < len(s) and s[i] in "0123456789.-":
|
||||
i += 1
|
||||
@@ -139,7 +139,7 @@ class JSONParser:
|
||||
if "." in num_str or "e" in num_str or "E" in num_str
|
||||
else int(num_str)
|
||||
)
|
||||
except ValueError:
|
||||
except ValueError as e:
|
||||
raise e
|
||||
return num, s
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from typing import Type, TypeVar, Self
|
||||
from docstring_parser import parse
|
||||
from functools import wraps
|
||||
from pydantic import BaseModel, create_model
|
||||
from instructor.exceptions import IncompleteOutputException
|
||||
|
||||
import enum
|
||||
import warnings
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Mode(enum.Enum):
|
||||
"""The mode to use for patching the client"""
|
||||
@@ -118,11 +120,11 @@ class OpenAISchema(BaseModel):
|
||||
@classmethod
|
||||
def from_response(
|
||||
cls,
|
||||
completion,
|
||||
validation_context=None,
|
||||
completion: T,
|
||||
validation_context: dict = None,
|
||||
strict: bool = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
):
|
||||
) -> Self:
|
||||
"""Execute the function from the response of an openai chat completion
|
||||
|
||||
Parameters:
|
||||
@@ -176,9 +178,11 @@ class OpenAISchema(BaseModel):
|
||||
async def from_response_async(
|
||||
cls,
|
||||
completion,
|
||||
validation_context=None,
|
||||
validation_context: dict = None,
|
||||
strict: bool = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
stream_multitask: bool = False,
|
||||
stream_partial: bool = False,
|
||||
):
|
||||
"""Execute the function from the response of an openai chat completion
|
||||
|
||||
@@ -230,7 +234,7 @@ class OpenAISchema(BaseModel):
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
|
||||
def openai_schema(cls) -> OpenAISchema:
|
||||
def openai_schema(cls: Type[BaseModel]) -> OpenAISchema:
|
||||
if not issubclass(cls, BaseModel):
|
||||
raise TypeError("Class must be a subclass of pydantic.BaseModel")
|
||||
|
||||
@@ -239,4 +243,4 @@ def openai_schema(cls) -> OpenAISchema:
|
||||
cls.__name__,
|
||||
__base__=(cls, OpenAISchema),
|
||||
)
|
||||
) # type: ignore
|
||||
)
|
||||
|
||||
+179
-97
@@ -4,7 +4,18 @@ import logging
|
||||
from collections.abc import Iterable
|
||||
from functools import wraps
|
||||
from json import JSONDecodeError
|
||||
from typing import Callable, Optional, Type, Union, get_args, get_origin
|
||||
from typing import (
|
||||
Callable,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
overload,
|
||||
)
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat import (
|
||||
@@ -22,24 +33,11 @@ from .function_calls import Mode, OpenAISchema, openai_schema
|
||||
|
||||
logger = logging.getLogger("instructor")
|
||||
|
||||
OVERRIDE_DOCS = """
|
||||
Creates a new chat completion for the provided messages and parameters.
|
||||
|
||||
See: https://platform.openai.com/docs/api-reference/chat-completions/create
|
||||
|
||||
Additional Notes:
|
||||
|
||||
Using the `response_model` parameter, you can specify a response model to use for parsing the response from OpenAI's API. If its present, the response will be parsed using the response model, otherwise it will be returned as is.
|
||||
|
||||
If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method.
|
||||
|
||||
If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts.
|
||||
|
||||
Parameters:
|
||||
response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
|
||||
max_retries (int): The maximum number of retries to attempt if the response is not valid (default: 0)
|
||||
validation_context (dict): The validation context to use for validating the response (default: None)
|
||||
"""
|
||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
T_ParamSpec = ParamSpec("T_ParamSpec")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
@@ -60,11 +58,24 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
|
||||
|
||||
def handle_response_model(
|
||||
*,
|
||||
response_model: Type[BaseModel],
|
||||
kwargs,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
response_model: T, mode: Mode = Mode.TOOLS, **kwargs
|
||||
) -> Union[Type[OpenAISchema], dict]:
|
||||
"""Prepare the response model type hint, and returns the response_model
|
||||
along with the new modified kwargs needed to be able to use the response_model
|
||||
parameter with the patch function.
|
||||
|
||||
|
||||
Args:
|
||||
response_model (T): The response model to use for parsing the response
|
||||
mode (Mode, optional): The openai completion mode. Defaults to Mode.TOOLS.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: When using stream=True with a non-iterable response_model
|
||||
ValueError: When using an invalid patch mode
|
||||
|
||||
Returns:
|
||||
Union[Type[OpenAISchema], dict]: The response model to use for parsing the response
|
||||
"""
|
||||
new_kwargs = kwargs.copy()
|
||||
if response_model is not None:
|
||||
if get_origin(response_model) is Iterable:
|
||||
@@ -143,24 +154,26 @@ def handle_response_model(
|
||||
|
||||
|
||||
def process_response(
|
||||
response,
|
||||
response: T,
|
||||
*,
|
||||
response_model: Type[BaseModel],
|
||||
response_model: Type[T_Model],
|
||||
stream: bool,
|
||||
validation_context: dict = None,
|
||||
strict=None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
): # type: ignore
|
||||
) -> Union[T_Model, T]:
|
||||
"""Processes a OpenAI response with the response model, if available.
|
||||
It can use `validation_context` and `strict` to validate the response
|
||||
via the pydantic model
|
||||
|
||||
Args:
|
||||
response (ChatCompletion): The response from OpenAI's API
|
||||
response_model (BaseModel): The response model to use for parsing the response
|
||||
response (T): The response from OpenAI's API
|
||||
response_model (Type[T_Model]): The response model to use for parsing the response
|
||||
stream (bool): Whether the response is a stream
|
||||
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
|
||||
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
|
||||
strict (_type_, optional): Whether to use strict json parsing. Defaults to None.
|
||||
mode (Mode, optional): The openai completion mode. Defaults to Mode.FUNCTIONS.
|
||||
|
||||
Returns:
|
||||
Union[T_Model, T]: The parsed response, if a response model is available, otherwise the response as is from the SDK
|
||||
"""
|
||||
if response_model is None:
|
||||
return response
|
||||
@@ -188,14 +201,14 @@ def process_response(
|
||||
|
||||
|
||||
async def process_response_async(
|
||||
response,
|
||||
response: ChatCompletion,
|
||||
*,
|
||||
response_model: Type[BaseModel],
|
||||
stream: bool,
|
||||
response_model: Type[T_Model],
|
||||
stream: bool = False,
|
||||
validation_context: dict = None,
|
||||
strict=None,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
): # type: ignore
|
||||
) -> T:
|
||||
"""Processes a OpenAI response with the response model, if available.
|
||||
It can use `validation_context` and `strict` to validate the response
|
||||
via the pydantic model
|
||||
@@ -230,15 +243,15 @@ async def process_response_async(
|
||||
|
||||
|
||||
async def retry_async(
|
||||
func,
|
||||
response_model,
|
||||
func: Callable[T_ParamSpec, T_Retval],
|
||||
response_model: Type[T],
|
||||
validation_context,
|
||||
args,
|
||||
kwargs,
|
||||
max_retries,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
) -> T:
|
||||
retries = 0
|
||||
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
|
||||
while retries <= max_retries:
|
||||
@@ -292,12 +305,12 @@ async def retry_async(
|
||||
|
||||
|
||||
def retry_sync(
|
||||
func,
|
||||
response_model,
|
||||
validation_context,
|
||||
func: Callable[T_ParamSpec, T_Retval],
|
||||
response_model: Type[T],
|
||||
validation_context: dict,
|
||||
args,
|
||||
kwargs,
|
||||
max_retries,
|
||||
max_retries: int = 1,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
@@ -362,61 +375,67 @@ def is_async(func: Callable) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def wrap_chatcompletion(func: Callable, mode: Mode = Mode.FUNCTIONS) -> Callable:
|
||||
func_is_async = is_async(func)
|
||||
OVERRIDE_DOCS = """
|
||||
Creates a new chat completion for the provided messages and parameters.
|
||||
|
||||
@wraps(func)
|
||||
async def new_chatcompletion_async(
|
||||
response_model=None,
|
||||
validation_context=None,
|
||||
max_retries=1,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
response_model, new_kwargs = handle_response_model(
|
||||
response_model=response_model, kwargs=kwargs, mode=mode
|
||||
) # type: ignore
|
||||
response = await retry_async(
|
||||
func=func,
|
||||
response_model=response_model,
|
||||
validation_context=validation_context,
|
||||
max_retries=max_retries,
|
||||
args=args,
|
||||
kwargs=new_kwargs,
|
||||
mode=mode,
|
||||
) # type: ignore
|
||||
return response
|
||||
See: https://platform.openai.com/docs/api-reference/chat-completions/create
|
||||
|
||||
@wraps(func)
|
||||
def new_chatcompletion_sync(
|
||||
response_model=None,
|
||||
validation_context=None,
|
||||
max_retries=1,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
response_model, new_kwargs = handle_response_model(
|
||||
response_model=response_model, kwargs=kwargs, mode=mode
|
||||
) # type: ignore
|
||||
response = retry_sync(
|
||||
func=func,
|
||||
response_model=response_model,
|
||||
validation_context=validation_context,
|
||||
max_retries=max_retries,
|
||||
args=args,
|
||||
kwargs=new_kwargs,
|
||||
mode=mode,
|
||||
) # type: ignore
|
||||
return response
|
||||
Additional Notes:
|
||||
|
||||
wrapper_function = (
|
||||
new_chatcompletion_async if func_is_async else new_chatcompletion_sync
|
||||
)
|
||||
wrapper_function.__doc__ = OVERRIDE_DOCS
|
||||
return wrapper_function
|
||||
Using the `response_model` parameter, you can specify a response model to use for parsing the response from OpenAI's API. If its present, the response will be parsed using the response model, otherwise it will be returned as is.
|
||||
|
||||
If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method.
|
||||
|
||||
If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts.
|
||||
|
||||
Parameters:
|
||||
response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
|
||||
max_retries (int): The maximum number of retries to attempt if the response is not valid (default: 0)
|
||||
validation_context (dict): The validation context to use for validating the response (default: None)
|
||||
"""
|
||||
|
||||
|
||||
def patch(client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS):
|
||||
class InstructorChatCompletionCreate(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
response_model: Type[T_Model] = None,
|
||||
validation_context: dict = None,
|
||||
max_retries: int = 1,
|
||||
*args: T_ParamSpec.args,
|
||||
**kwargs: T_ParamSpec.kwargs,
|
||||
) -> T_Model:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def patch(
|
||||
client: OpenAI,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
) -> OpenAI:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def patch(
|
||||
client: AsyncOpenAI,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
) -> AsyncOpenAI:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def patch(
|
||||
create: Callable[T_ParamSpec, T_Retval],
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
) -> InstructorChatCompletionCreate:
|
||||
...
|
||||
|
||||
|
||||
def patch(
|
||||
client: Union[OpenAI, AsyncOpenAI] = None,
|
||||
create: Callable[T_ParamSpec, T_Retval] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
) -> Union[OpenAI, AsyncOpenAI]:
|
||||
"""
|
||||
Patch the `client.chat.completions.create` method
|
||||
|
||||
@@ -429,10 +448,68 @@ def patch(client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS):
|
||||
"""
|
||||
|
||||
logger.debug(f"Patching `client.chat.completions.create` with {mode=}")
|
||||
client.chat.completions.create = wrap_chatcompletion(
|
||||
client.chat.completions.create, mode=mode
|
||||
)
|
||||
return client
|
||||
|
||||
if create is not None:
|
||||
func = create
|
||||
elif client is not None:
|
||||
func = client.chat.completions.create
|
||||
else:
|
||||
raise ValueError("Either client or create must be provided")
|
||||
|
||||
func_is_async = is_async(func)
|
||||
|
||||
@wraps(func)
|
||||
async def new_create_async(
|
||||
response_model: Type[T_Model] = None,
|
||||
validation_context: dict = None,
|
||||
max_retries: int = 1,
|
||||
*args: T_ParamSpec.args,
|
||||
**kwargs: T_ParamSpec.kwargs,
|
||||
) -> T_Model:
|
||||
response_model, new_kwargs = handle_response_model(
|
||||
response_model=response_model, mode=mode, **kwargs
|
||||
)
|
||||
response = await retry_async(
|
||||
func=func,
|
||||
response_model=response_model,
|
||||
validation_context=validation_context,
|
||||
max_retries=max_retries,
|
||||
args=args,
|
||||
kwargs=new_kwargs,
|
||||
mode=mode,
|
||||
) # type: ignore
|
||||
return response
|
||||
|
||||
@wraps(func)
|
||||
def new_create_sync(
|
||||
response_model: Type[T_Model] = None,
|
||||
validation_context: dict = None,
|
||||
max_retries: int = 1,
|
||||
*args: T_ParamSpec.args,
|
||||
**kwargs: T_ParamSpec.kwargs,
|
||||
) -> T_Model:
|
||||
response_model, new_kwargs = handle_response_model(
|
||||
response_model=response_model, mode=mode, **kwargs
|
||||
)
|
||||
response = retry_sync(
|
||||
func=func,
|
||||
response_model=response_model,
|
||||
validation_context=validation_context,
|
||||
max_retries=max_retries,
|
||||
args=args,
|
||||
kwargs=new_kwargs,
|
||||
mode=mode,
|
||||
)
|
||||
return response
|
||||
|
||||
new_create = new_create_async if func_is_async else new_create_sync
|
||||
new_create.__doc__ = OVERRIDE_DOCS
|
||||
|
||||
if client is not None:
|
||||
client.chat.completions.create = new_create
|
||||
return client
|
||||
else:
|
||||
return new_create
|
||||
|
||||
|
||||
def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
|
||||
@@ -448,4 +525,9 @@ def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
|
||||
- `validation_context` parameter to validate the response using the pydantic model
|
||||
- `strict` parameter to use strict json parsing
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"apatch is deprecated, use patch instead", DeprecationWarning, stacklevel=2
|
||||
)
|
||||
return patch(client, mode=mode)
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
from typing import Iterable, Literal, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
import pytest
|
||||
import instructor
|
||||
|
||||
|
||||
class Weather(BaseModel):
|
||||
location: str
|
||||
units: Literal["imperial", "metric"]
|
||||
|
||||
|
||||
class GoogleSearch(BaseModel):
|
||||
query: str
|
||||
|
||||
|
||||
def test_sync_parallel_tools_or(client):
|
||||
client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS)
|
||||
resp = client.chat.completions.create(
|
||||
model="gpt-4-turbo-preview",
|
||||
messages=[
|
||||
{"role": "system", "content": "You must always use tools"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in toronto and dallas and who won the super bowl?",
|
||||
},
|
||||
],
|
||||
response_model=Iterable[Weather | GoogleSearch],
|
||||
)
|
||||
assert len(list(resp)) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_parallel_tools_or(aclient):
|
||||
client = instructor.patch(aclient, mode=instructor.Mode.PARALLEL_TOOLS)
|
||||
resp = await client.chat.completions.create(
|
||||
model="gpt-4-turbo-preview",
|
||||
messages=[
|
||||
{"role": "system", "content": "You must always use tools"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in toronto and dallas and who won the super bowl?",
|
||||
},
|
||||
],
|
||||
response_model=Iterable[Weather | GoogleSearch],
|
||||
)
|
||||
assert len(list(resp)) == 3
|
||||
|
||||
|
||||
def test_sync_parallel_tools_one(client):
|
||||
client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS)
|
||||
resp = client.chat.completions.create(
|
||||
model="gpt-4-turbo-preview",
|
||||
messages=[
|
||||
{"role": "system", "content": "You must always use tools"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in toronto and dallas?",
|
||||
},
|
||||
],
|
||||
response_model=Iterable[Weather],
|
||||
)
|
||||
assert len(list(resp)) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_parallel_tools_one(aclient):
|
||||
client = instructor.patch(aclient, mode=instructor.Mode.PARALLEL_TOOLS)
|
||||
resp = await client.chat.completions.create(
|
||||
model="gpt-4-turbo-preview",
|
||||
messages=[
|
||||
{"role": "system", "content": "You must always use tools"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in toronto and dallas?",
|
||||
},
|
||||
],
|
||||
response_model=Iterable[Weather],
|
||||
)
|
||||
assert len(list(resp)) == 2
|
||||
+1
-22
@@ -10,7 +10,7 @@ from openai.types.chat.chat_completion_message_tool_call import (
|
||||
)
|
||||
|
||||
import instructor
|
||||
from instructor.patch import OVERRIDE_DOCS, dump_message, is_async, wrap_chatcompletion
|
||||
from instructor.patch import OVERRIDE_DOCS, dump_message, is_async
|
||||
|
||||
|
||||
def test_patch_completes_successfully():
|
||||
@@ -21,27 +21,6 @@ def test_apatch_completes_successfully():
|
||||
instructor.apatch(AsyncOpenAI())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrap_chatcompletion_wraps_async_input_function():
|
||||
async def input_function(*args, **kwargs):
|
||||
return "Hello, World!"
|
||||
|
||||
wrapped_function = wrap_chatcompletion(input_function)
|
||||
result = await wrapped_function()
|
||||
|
||||
assert result == "Hello, World!"
|
||||
|
||||
|
||||
def test_wrap_chatcompletion_wraps_input_function():
|
||||
def input_function(*args, **kwargs):
|
||||
return "Hello, World!"
|
||||
|
||||
wrapped_function = wrap_chatcompletion(input_function)
|
||||
result = wrapped_function()
|
||||
|
||||
assert result == "Hello, World!"
|
||||
|
||||
|
||||
def test_is_async_returns_true_if_function_is_async():
|
||||
async def async_function():
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user