diff --git a/instructor/dsl/iterable.py b/instructor/dsl/iterable.py index 295fb67..b0fc882 100644 --- a/instructor/dsl/iterable.py +++ b/instructor/dsl/iterable.py @@ -110,7 +110,7 @@ def IterableModel( subtask_class: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, -): +) -> Type[BaseModel]: """ Dynamically create a IterableModel OpenAISchema that can be used to segment multiple tasks given a base class. This creates class that can be used to create a toolkit @@ -187,5 +187,7 @@ def IterableModel( if description is None else description ) - + assert issubclass( + new_cls, OpenAISchema + ), "The new class should be a subclass of OpenAISchema" return new_cls diff --git a/instructor/dsl/partialjson.py b/instructor/dsl/partialjson.py index 62d2490..47ce242 100644 --- a/instructor/dsl/partialjson.py +++ b/instructor/dsl/partialjson.py @@ -125,7 +125,7 @@ class JSONParser: s = s[end + 1 :] return json.loads(str_val), s - def parse_number(self, s, e): + def parse_number(self, s): i = 0 while i < len(s) and s[i] in "0123456789.-": i += 1 @@ -139,7 +139,7 @@ class JSONParser: if "." in num_str or "e" in num_str or "E" in num_str else int(num_str) ) - except ValueError: + except ValueError as e: raise e return num, s diff --git a/instructor/function_calls.py b/instructor/function_calls.py index 42221a7..6872b25 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -1,11 +1,13 @@ +from typing import Type, TypeVar, Self from docstring_parser import parse from functools import wraps from pydantic import BaseModel, create_model from instructor.exceptions import IncompleteOutputException - import enum import warnings +T = TypeVar("T") + class Mode(enum.Enum): """The mode to use for patching the client""" @@ -118,11 +120,11 @@ class OpenAISchema(BaseModel): @classmethod def from_response( cls, - completion, - validation_context=None, + completion: T, + validation_context: dict = None, strict: bool = None, mode: Mode = Mode.TOOLS, - ): + ) -> Self: """Execute the function from the response of an openai chat completion Parameters: @@ -176,9 +178,11 @@ class OpenAISchema(BaseModel): async def from_response_async( cls, completion, - validation_context=None, + validation_context: dict = None, strict: bool = None, mode: Mode = Mode.TOOLS, + stream_multitask: bool = False, + stream_partial: bool = False, ): """Execute the function from the response of an openai chat completion @@ -230,7 +234,7 @@ class OpenAISchema(BaseModel): raise ValueError(f"Invalid patch mode: {mode}") -def openai_schema(cls) -> OpenAISchema: +def openai_schema(cls: Type[BaseModel]) -> OpenAISchema: if not issubclass(cls, BaseModel): raise TypeError("Class must be a subclass of pydantic.BaseModel") @@ -239,4 +243,4 @@ def openai_schema(cls) -> OpenAISchema: cls.__name__, __base__=(cls, OpenAISchema), ) - ) # type: ignore + ) diff --git a/instructor/patch.py b/instructor/patch.py index dfa3335..1985c5a 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -4,7 +4,18 @@ import logging from collections.abc import Iterable from functools import wraps from json import JSONDecodeError -from typing import Callable, Optional, Type, Union, get_args, get_origin +from typing import ( + Callable, + Optional, + ParamSpec, + Protocol, + Type, + TypeVar, + Union, + get_args, + get_origin, + overload, +) from openai import AsyncOpenAI, OpenAI from openai.types.chat import ( @@ -22,24 +33,11 @@ from .function_calls import Mode, OpenAISchema, openai_schema logger = logging.getLogger("instructor") -OVERRIDE_DOCS = """ -Creates a new chat completion for the provided messages and parameters. -See: https://platform.openai.com/docs/api-reference/chat-completions/create - -Additional Notes: - -Using the `response_model` parameter, you can specify a response model to use for parsing the response from OpenAI's API. If its present, the response will be parsed using the response model, otherwise it will be returned as is. - -If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method. - -If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts. - -Parameters: - response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None) - max_retries (int): The maximum number of retries to attempt if the response is not valid (default: 0) - validation_context (dict): The validation context to use for validating the response (default: None) -""" +T_Model = TypeVar("T_Model", bound=BaseModel) +T_Retval = TypeVar("T_Retval") +T_ParamSpec = ParamSpec("T_ParamSpec") +T = TypeVar("T") def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam: @@ -60,11 +58,24 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam: def handle_response_model( - *, - response_model: Type[BaseModel], - kwargs, - mode: Mode = Mode.FUNCTIONS, -): + response_model: T, mode: Mode = Mode.TOOLS, **kwargs +) -> Union[Type[OpenAISchema], dict]: + """Prepare the response model type hint, and returns the response_model + along with the new modified kwargs needed to be able to use the response_model + parameter with the patch function. + + + Args: + response_model (T): The response model to use for parsing the response + mode (Mode, optional): The openai completion mode. Defaults to Mode.TOOLS. + + Raises: + NotImplementedError: When using stream=True with a non-iterable response_model + ValueError: When using an invalid patch mode + + Returns: + Union[Type[OpenAISchema], dict]: The response model to use for parsing the response + """ new_kwargs = kwargs.copy() if response_model is not None: if get_origin(response_model) is Iterable: @@ -143,24 +154,26 @@ def handle_response_model( def process_response( - response, + response: T, *, - response_model: Type[BaseModel], + response_model: Type[T_Model], stream: bool, validation_context: dict = None, strict=None, mode: Mode = Mode.FUNCTIONS, -): # type: ignore +) -> Union[T_Model, T]: """Processes a OpenAI response with the response model, if available. - It can use `validation_context` and `strict` to validate the response - via the pydantic model Args: - response (ChatCompletion): The response from OpenAI's API - response_model (BaseModel): The response model to use for parsing the response + response (T): The response from OpenAI's API + response_model (Type[T_Model]): The response model to use for parsing the response stream (bool): Whether the response is a stream validation_context (dict, optional): The validation context to use for validating the response. Defaults to None. - strict (bool, optional): Whether to use strict json parsing. Defaults to None. + strict (_type_, optional): Whether to use strict json parsing. Defaults to None. + mode (Mode, optional): The openai completion mode. Defaults to Mode.FUNCTIONS. + + Returns: + Union[T_Model, T]: The parsed response, if a response model is available, otherwise the response as is from the SDK """ if response_model is None: return response @@ -188,14 +201,14 @@ def process_response( async def process_response_async( - response, + response: ChatCompletion, *, - response_model: Type[BaseModel], - stream: bool, + response_model: Type[T_Model], + stream: bool = False, validation_context: dict = None, - strict=None, + strict: Optional[bool] = None, mode: Mode = Mode.FUNCTIONS, -): # type: ignore +) -> T: """Processes a OpenAI response with the response model, if available. It can use `validation_context` and `strict` to validate the response via the pydantic model @@ -230,15 +243,15 @@ async def process_response_async( async def retry_async( - func, - response_model, + func: Callable[T_ParamSpec, T_Retval], + response_model: Type[T], validation_context, args, kwargs, max_retries, strict: Optional[bool] = None, mode: Mode = Mode.FUNCTIONS, -): +) -> T: retries = 0 total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0) while retries <= max_retries: @@ -292,12 +305,12 @@ async def retry_async( def retry_sync( - func, - response_model, - validation_context, + func: Callable[T_ParamSpec, T_Retval], + response_model: Type[T], + validation_context: dict, args, kwargs, - max_retries, + max_retries: int = 1, strict: Optional[bool] = None, mode: Mode = Mode.FUNCTIONS, ): @@ -362,61 +375,67 @@ def is_async(func: Callable) -> bool: ) -def wrap_chatcompletion(func: Callable, mode: Mode = Mode.FUNCTIONS) -> Callable: - func_is_async = is_async(func) +OVERRIDE_DOCS = """ +Creates a new chat completion for the provided messages and parameters. - @wraps(func) - async def new_chatcompletion_async( - response_model=None, - validation_context=None, - max_retries=1, - *args, - **kwargs, - ): - response_model, new_kwargs = handle_response_model( - response_model=response_model, kwargs=kwargs, mode=mode - ) # type: ignore - response = await retry_async( - func=func, - response_model=response_model, - validation_context=validation_context, - max_retries=max_retries, - args=args, - kwargs=new_kwargs, - mode=mode, - ) # type: ignore - return response +See: https://platform.openai.com/docs/api-reference/chat-completions/create - @wraps(func) - def new_chatcompletion_sync( - response_model=None, - validation_context=None, - max_retries=1, - *args, - **kwargs, - ): - response_model, new_kwargs = handle_response_model( - response_model=response_model, kwargs=kwargs, mode=mode - ) # type: ignore - response = retry_sync( - func=func, - response_model=response_model, - validation_context=validation_context, - max_retries=max_retries, - args=args, - kwargs=new_kwargs, - mode=mode, - ) # type: ignore - return response +Additional Notes: - wrapper_function = ( - new_chatcompletion_async if func_is_async else new_chatcompletion_sync - ) - wrapper_function.__doc__ = OVERRIDE_DOCS - return wrapper_function +Using the `response_model` parameter, you can specify a response model to use for parsing the response from OpenAI's API. If its present, the response will be parsed using the response model, otherwise it will be returned as is. + +If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method. + +If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts. + +Parameters: + response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None) + max_retries (int): The maximum number of retries to attempt if the response is not valid (default: 0) + validation_context (dict): The validation context to use for validating the response (default: None) +""" -def patch(client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS): +class InstructorChatCompletionCreate(Protocol): + def __call__( + self, + response_model: Type[T_Model] = None, + validation_context: dict = None, + max_retries: int = 1, + *args: T_ParamSpec.args, + **kwargs: T_ParamSpec.kwargs, + ) -> T_Model: + ... + + +@overload +def patch( + client: OpenAI, + mode: Mode = Mode.FUNCTIONS, +) -> OpenAI: + ... + + +@overload +def patch( + client: AsyncOpenAI, + mode: Mode = Mode.FUNCTIONS, +) -> AsyncOpenAI: + ... + + +@overload +def patch( + create: Callable[T_ParamSpec, T_Retval], + mode: Mode = Mode.FUNCTIONS, +) -> InstructorChatCompletionCreate: + ... + + +def patch( + client: Union[OpenAI, AsyncOpenAI] = None, + create: Callable[T_ParamSpec, T_Retval] = None, + mode: Mode = Mode.FUNCTIONS, +) -> Union[OpenAI, AsyncOpenAI]: """ Patch the `client.chat.completions.create` method @@ -429,10 +448,68 @@ def patch(client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS): """ logger.debug(f"Patching `client.chat.completions.create` with {mode=}") - client.chat.completions.create = wrap_chatcompletion( - client.chat.completions.create, mode=mode - ) - return client + + if create is not None: + func = create + elif client is not None: + func = client.chat.completions.create + else: + raise ValueError("Either client or create must be provided") + + func_is_async = is_async(func) + + @wraps(func) + async def new_create_async( + response_model: Type[T_Model] = None, + validation_context: dict = None, + max_retries: int = 1, + *args: T_ParamSpec.args, + **kwargs: T_ParamSpec.kwargs, + ) -> T_Model: + response_model, new_kwargs = handle_response_model( + response_model=response_model, mode=mode, **kwargs + ) + response = await retry_async( + func=func, + response_model=response_model, + validation_context=validation_context, + max_retries=max_retries, + args=args, + kwargs=new_kwargs, + mode=mode, + ) # type: ignore + return response + + @wraps(func) + def new_create_sync( + response_model: Type[T_Model] = None, + validation_context: dict = None, + max_retries: int = 1, + *args: T_ParamSpec.args, + **kwargs: T_ParamSpec.kwargs, + ) -> T_Model: + response_model, new_kwargs = handle_response_model( + response_model=response_model, mode=mode, **kwargs + ) + response = retry_sync( + func=func, + response_model=response_model, + validation_context=validation_context, + max_retries=max_retries, + args=args, + kwargs=new_kwargs, + mode=mode, + ) + return response + + new_create = new_create_async if func_is_async else new_create_sync + new_create.__doc__ = OVERRIDE_DOCS + + if client is not None: + client.chat.completions.create = new_create + return client + else: + return new_create def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS): @@ -448,4 +525,9 @@ def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS): - `validation_context` parameter to validate the response using the pydantic model - `strict` parameter to use strict json parsing """ + import warnings + + warnings.warn( + "apatch is deprecated, use patch instead", DeprecationWarning, stacklevel=2 + ) return patch(client, mode=mode) diff --git a/tests/openai/test_parallel.py b/tests/openai/test_parallel.py new file mode 100644 index 0000000..8f1b261 --- /dev/null +++ b/tests/openai/test_parallel.py @@ -0,0 +1,80 @@ +from typing import Iterable, Literal, Union +from pydantic import BaseModel + +import pytest +import instructor + + +class Weather(BaseModel): + location: str + units: Literal["imperial", "metric"] + + +class GoogleSearch(BaseModel): + query: str + + +def test_sync_parallel_tools_or(client): + client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS) + resp = client.chat.completions.create( + model="gpt-4-turbo-preview", + messages=[ + {"role": "system", "content": "You must always use tools"}, + { + "role": "user", + "content": "What is the weather in toronto and dallas and who won the super bowl?", + }, + ], + response_model=Iterable[Weather | GoogleSearch], + ) + assert len(list(resp)) == 3 + + +@pytest.mark.asyncio +async def test_async_parallel_tools_or(aclient): + client = instructor.patch(aclient, mode=instructor.Mode.PARALLEL_TOOLS) + resp = await client.chat.completions.create( + model="gpt-4-turbo-preview", + messages=[ + {"role": "system", "content": "You must always use tools"}, + { + "role": "user", + "content": "What is the weather in toronto and dallas and who won the super bowl?", + }, + ], + response_model=Iterable[Weather | GoogleSearch], + ) + assert len(list(resp)) == 3 + + +def test_sync_parallel_tools_one(client): + client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS) + resp = client.chat.completions.create( + model="gpt-4-turbo-preview", + messages=[ + {"role": "system", "content": "You must always use tools"}, + { + "role": "user", + "content": "What is the weather in toronto and dallas?", + }, + ], + response_model=Iterable[Weather], + ) + assert len(list(resp)) == 2 + + +@pytest.mark.asyncio +async def test_async_parallel_tools_one(aclient): + client = instructor.patch(aclient, mode=instructor.Mode.PARALLEL_TOOLS) + resp = await client.chat.completions.create( + model="gpt-4-turbo-preview", + messages=[ + {"role": "system", "content": "You must always use tools"}, + { + "role": "user", + "content": "What is the weather in toronto and dallas?", + }, + ], + response_model=Iterable[Weather], + ) + assert len(list(resp)) == 2 diff --git a/tests/test_patch.py b/tests/test_patch.py index bf4810f..6d6d998 100644 --- a/tests/test_patch.py +++ b/tests/test_patch.py @@ -10,7 +10,7 @@ from openai.types.chat.chat_completion_message_tool_call import ( ) import instructor -from instructor.patch import OVERRIDE_DOCS, dump_message, is_async, wrap_chatcompletion +from instructor.patch import OVERRIDE_DOCS, dump_message, is_async def test_patch_completes_successfully(): @@ -21,27 +21,6 @@ def test_apatch_completes_successfully(): instructor.apatch(AsyncOpenAI()) -@pytest.mark.asyncio -async def test_wrap_chatcompletion_wraps_async_input_function(): - async def input_function(*args, **kwargs): - return "Hello, World!" - - wrapped_function = wrap_chatcompletion(input_function) - result = await wrapped_function() - - assert result == "Hello, World!" - - -def test_wrap_chatcompletion_wraps_input_function(): - def input_function(*args, **kwargs): - return "Hello, World!" - - wrapped_function = wrap_chatcompletion(input_function) - result = wrapped_function() - - assert result == "Hello, World!" - - def test_is_async_returns_true_if_function_is_async(): async def async_function(): pass