pass strict through from create to from_response (#119)

This commit is contained in:
Jason Liu
2023-10-20 12:05:50 -04:00
committed by GitHub
parent 5191283f86
commit b77fd5dd0a
2 changed files with 67 additions and 23 deletions
+13 -4
View File
@@ -98,7 +98,7 @@ class openai_function:
return wrapper(*args, **kwargs)
def from_response(self, completion, throw_error=True):
def from_response(self, completion, throw_error=True, strict: bool = None):
"""
Parse the response from OpenAI's API and return the function call
@@ -118,7 +118,7 @@ class openai_function:
), "Function name does not match"
function_call = message["function_call"]
arguments = json.loads(function_call["arguments"], strict=False)
arguments = json.loads(function_call["arguments"], strict=strict)
return self.validate_func(**arguments)
@@ -209,13 +209,20 @@ class OpenAISchema(BaseModel):
}
@classmethod
def from_response(cls, completion, throw_error=True, validation_context=None):
def from_response(
cls,
completion,
throw_error: bool = True,
validation_context=None,
strict: bool = None,
):
"""Execute the function from the response of an openai chat completion
Parameters:
completion (openai.ChatCompletion): The response from an openai chat completion
throw_error (bool): Whether to throw an error if the function call is not detected
validation_context (dict): The validation context to use for validating the response
strict (bool): Whether to use strict json parsing
Returns:
cls (OpenAISchema): An instance of the class
@@ -229,7 +236,9 @@ class OpenAISchema(BaseModel):
), "Function name does not match"
return cls.model_validate_json(
message["function_call"]["arguments"], context=validation_context
message["function_call"]["arguments"],
context=validation_context,
strict=strict,
)
+54 -19
View File
@@ -3,7 +3,7 @@ from json import JSONDecodeError
from pydantic import ValidationError
import openai
import inspect
from typing import Callable, Type
from typing import Callable, Type, Optional
from pydantic import BaseModel
from .function_calls import OpenAISchema, openai_schema
@@ -46,10 +46,20 @@ def handle_response_model(response_model: Type[BaseModel], kwargs):
return response_model, new_kwargs
def process_response(response, response_model, validation_context=None): # type: ignore
def process_response(response, response_model, validation_context: dict = None, strict=None): # type: ignore
"""Processes a OpenAI response with the response model, if available
It can use `validation_context` and `strict` to validate the response
via the pydantic model
Args:
response (ChatCompletion): The response from OpenAI's API
response_model (BaseModel): The response model to use for parsing the response
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
"""
if response_model is not None:
model = response_model.from_response(
response, validation_context=validation_context
response, validation_context=validation_context, strict=strict
)
model._raw_response = response
return model
@@ -57,13 +67,27 @@ def process_response(response, response_model, validation_context=None): # type
async def retry_async(
func, response_model, validation_context, args, kwargs, max_retries
func,
response_model,
validation_context,
args,
kwargs,
max_retries,
strict: Optional[bool] = None,
):
retries = 0
while retries <= max_retries:
try:
response = await func(*args, **kwargs)
return process_response(response, response_model, validation_context), None
return (
process_response(
response,
response_model,
validation_context,
strict=strict,
),
None,
)
except (ValidationError, JSONDecodeError) as e:
kwargs["messages"].append(dict(**response.choices[0].message)) # type: ignore
kwargs["messages"].append(
@@ -77,14 +101,26 @@ async def retry_async(
raise e
def retry_sync(func, response_model, validation_context, args, kwargs, max_retries):
def retry_sync(
func,
response_model,
validation_context,
args,
kwargs,
max_retries,
strict: Optional[bool] = None,
):
retries = 0
new_kwargs = kwargs.copy()
while retries <= max_retries:
# Excepts ValidationError, and JSONDecodeError
try:
response = func(*args, **kwargs)
return process_response(response, response_model, validation_context), None
return (
process_response(
response, response_model, validation_context, strict=strict
),
None,
)
except (ValidationError, JSONDecodeError) as e:
kwargs["messages"].append(dict(**response.choices[0].message)) # type: ignore
kwargs["messages"].append(
@@ -140,23 +176,20 @@ def wrap_chatcompletion(func: Callable) -> Callable:
return wrapper_function
def process_response(response, response_model, validation_context=None):
if response_model is not None:
model = response_model.from_response(
response, validation_context=validation_context
)
model._raw_response = response
return model
return response
original_chatcompletion = openai.ChatCompletion.create
original_chatcompletion_async = openai.ChatCompletion.acreate
def patch():
"""
Patch the `openai.ChatCompletion.create` and `openai.ChatCompletion.acreate` methods to support the `response_model` parameter.
Patch the `openai.ChatCompletion.create` and `openai.ChatCompletion.acreate` methods
Enables the following features:
- `response_model` parameter to parse the response from OpenAI's API
- `max_retries` parameter to retry the function if the response is not valid
- `validation_context` parameter to validate the response using the pydantic model
- `strict` parameter to use strict json parsing
## Usage
@@ -180,6 +213,8 @@ def patch():
},
],
response_model=User,
validation_context={...},
strict=True,
)
print(user.model_dump())