From b77fd5dd0a1882438cbb0892ee0591db83a2b3b9 Mon Sep 17 00:00:00 2001 From: Jason Liu Date: Fri, 20 Oct 2023 12:05:50 -0400 Subject: [PATCH] pass strict through from create to from_response (#119) --- instructor/function_calls.py | 17 +++++++-- instructor/patch.py | 73 ++++++++++++++++++++++++++---------- 2 files changed, 67 insertions(+), 23 deletions(-) diff --git a/instructor/function_calls.py b/instructor/function_calls.py index c759aae..9daa86b 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -98,7 +98,7 @@ class openai_function: return wrapper(*args, **kwargs) - def from_response(self, completion, throw_error=True): + def from_response(self, completion, throw_error=True, strict: bool = None): """ Parse the response from OpenAI's API and return the function call @@ -118,7 +118,7 @@ class openai_function: ), "Function name does not match" function_call = message["function_call"] - arguments = json.loads(function_call["arguments"], strict=False) + arguments = json.loads(function_call["arguments"], strict=strict) return self.validate_func(**arguments) @@ -209,13 +209,20 @@ class OpenAISchema(BaseModel): } @classmethod - def from_response(cls, completion, throw_error=True, validation_context=None): + def from_response( + cls, + completion, + throw_error: bool = True, + validation_context=None, + strict: bool = None, + ): """Execute the function from the response of an openai chat completion Parameters: completion (openai.ChatCompletion): The response from an openai chat completion throw_error (bool): Whether to throw an error if the function call is not detected validation_context (dict): The validation context to use for validating the response + strict (bool): Whether to use strict json parsing Returns: cls (OpenAISchema): An instance of the class @@ -229,7 +236,9 @@ class OpenAISchema(BaseModel): ), "Function name does not match" return cls.model_validate_json( - message["function_call"]["arguments"], context=validation_context + 
message["function_call"]["arguments"], + context=validation_context, + strict=strict, ) diff --git a/instructor/patch.py b/instructor/patch.py index 3833225..aa6f0a6 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -3,7 +3,7 @@ from json import JSONDecodeError from pydantic import ValidationError import openai import inspect -from typing import Callable, Type +from typing import Callable, Type, Optional from pydantic import BaseModel from .function_calls import OpenAISchema, openai_schema @@ -46,10 +46,20 @@ def handle_response_model(response_model: Type[BaseModel], kwargs): return response_model, new_kwargs -def process_response(response, response_model, validation_context=None): # type: ignore +def process_response(response, response_model, validation_context: dict = None, strict=None): # type: ignore + """Processes an OpenAI response with the response model, if available + It can use `validation_context` and `strict` to validate the response + via the pydantic model + + Args: + response (ChatCompletion): The response from OpenAI's API + response_model (BaseModel): The response model to use for parsing the response + validation_context (dict, optional): The validation context to use for validating the response. Defaults to None. + strict (bool, optional): Whether to use strict json parsing. Defaults to None. 
+ """ if response_model is not None: model = response_model.from_response( - response, validation_context=validation_context + response, validation_context=validation_context, strict=strict ) model._raw_response = response return model @@ -57,13 +67,27 @@ def process_response(response, response_model, validation_context=None): # type async def retry_async( - func, response_model, validation_context, args, kwargs, max_retries + func, + response_model, + validation_context, + args, + kwargs, + max_retries, + strict: Optional[bool] = None, ): retries = 0 while retries <= max_retries: try: response = await func(*args, **kwargs) - return process_response(response, response_model, validation_context), None + return ( + process_response( + response, + response_model, + validation_context, + strict=strict, + ), + None, + ) except (ValidationError, JSONDecodeError) as e: kwargs["messages"].append(dict(**response.choices[0].message)) # type: ignore kwargs["messages"].append( @@ -77,14 +101,26 @@ async def retry_async( raise e -def retry_sync(func, response_model, validation_context, args, kwargs, max_retries): +def retry_sync( + func, + response_model, + validation_context, + args, + kwargs, + max_retries, + strict: Optional[bool] = None, +): retries = 0 - new_kwargs = kwargs.copy() while retries <= max_retries: # Excepts ValidationError, and JSONDecodeError try: response = func(*args, **kwargs) - return process_response(response, response_model, validation_context), None + return ( + process_response( + response, response_model, validation_context, strict=strict + ), + None, + ) except (ValidationError, JSONDecodeError) as e: kwargs["messages"].append(dict(**response.choices[0].message)) # type: ignore kwargs["messages"].append( @@ -140,23 +176,20 @@ def wrap_chatcompletion(func: Callable) -> Callable: return wrapper_function -def process_response(response, response_model, validation_context=None): - if response_model is not None: - model = response_model.from_response( - 
response, validation_context=validation_context - ) - model._raw_response = response - return model - return response - - original_chatcompletion = openai.ChatCompletion.create original_chatcompletion_async = openai.ChatCompletion.acreate def patch(): """ - Patch the `openai.ChatCompletion.create` and `openai.ChatCompletion.acreate` methods to support the `response_model` parameter. + Patch the `openai.ChatCompletion.create` and `openai.ChatCompletion.acreate` methods + + Enables the following features: + + - `response_model` parameter to parse the response from OpenAI's API + - `max_retries` parameter to retry the function if the response is not valid + - `validation_context` parameter to validate the response using the pydantic model + - `strict` parameter to use strict json parsing ## Usage @@ -180,6 +213,8 @@ def patch(): }, ], response_model=User, + validation_context={...}, + strict=True, ) print(user.model_dump())