mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
pass strict through from create to from_response (#119)
This commit is contained in:
@@ -98,7 +98,7 @@ class openai_function:
|
||||
|
||||
return wrapper(*args, **kwargs)
|
||||
|
||||
def from_response(self, completion, throw_error=True):
|
||||
def from_response(self, completion, throw_error=True, strict: bool = None):
|
||||
"""
|
||||
Parse the response from OpenAI's API and return the function call
|
||||
|
||||
@@ -118,7 +118,7 @@ class openai_function:
|
||||
), "Function name does not match"
|
||||
|
||||
function_call = message["function_call"]
|
||||
arguments = json.loads(function_call["arguments"], strict=False)
|
||||
arguments = json.loads(function_call["arguments"], strict=strict)
|
||||
return self.validate_func(**arguments)
|
||||
|
||||
|
||||
@@ -209,13 +209,20 @@ class OpenAISchema(BaseModel):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_response(cls, completion, throw_error=True, validation_context=None):
|
||||
def from_response(
|
||||
cls,
|
||||
completion,
|
||||
throw_error: bool = True,
|
||||
validation_context=None,
|
||||
strict: bool = None,
|
||||
):
|
||||
"""Execute the function from the response of an openai chat completion
|
||||
|
||||
Parameters:
|
||||
completion (openai.ChatCompletion): The response from an openai chat completion
|
||||
throw_error (bool): Whether to throw an error if the function call is not detected
|
||||
validation_context (dict): The validation context to use for validating the response
|
||||
strict (bool): Whether to use strict json parsing
|
||||
|
||||
Returns:
|
||||
cls (OpenAISchema): An instance of the class
|
||||
@@ -229,7 +236,9 @@ class OpenAISchema(BaseModel):
|
||||
), "Function name does not match"
|
||||
|
||||
return cls.model_validate_json(
|
||||
message["function_call"]["arguments"], context=validation_context
|
||||
message["function_call"]["arguments"],
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
|
||||
|
||||
+54
-19
@@ -3,7 +3,7 @@ from json import JSONDecodeError
|
||||
from pydantic import ValidationError
|
||||
import openai
|
||||
import inspect
|
||||
from typing import Callable, Type
|
||||
from typing import Callable, Type, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from .function_calls import OpenAISchema, openai_schema
|
||||
@@ -46,10 +46,20 @@ def handle_response_model(response_model: Type[BaseModel], kwargs):
|
||||
return response_model, new_kwargs
|
||||
|
||||
|
||||
def process_response(response, response_model, validation_context=None): # type: ignore
|
||||
def process_response(response, response_model, validation_context: dict = None, strict=None): # type: ignore
|
||||
"""Processes a OpenAI response with the response model, if available
|
||||
It can use `validation_context` and `strict` to validate the response
|
||||
via the pydantic model
|
||||
|
||||
Args:
|
||||
response (ChatCompletion): The response from OpenAI's API
|
||||
response_model (BaseModel): The response model to use for parsing the response
|
||||
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
|
||||
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
|
||||
"""
|
||||
if response_model is not None:
|
||||
model = response_model.from_response(
|
||||
response, validation_context=validation_context
|
||||
response, validation_context=validation_context, strict=strict
|
||||
)
|
||||
model._raw_response = response
|
||||
return model
|
||||
@@ -57,13 +67,27 @@ def process_response(response, response_model, validation_context=None): # type
|
||||
|
||||
|
||||
async def retry_async(
|
||||
func, response_model, validation_context, args, kwargs, max_retries
|
||||
func,
|
||||
response_model,
|
||||
validation_context,
|
||||
args,
|
||||
kwargs,
|
||||
max_retries,
|
||||
strict: Optional[bool] = None,
|
||||
):
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
response = await func(*args, **kwargs)
|
||||
return process_response(response, response_model, validation_context), None
|
||||
return (
|
||||
process_response(
|
||||
response,
|
||||
response_model,
|
||||
validation_context,
|
||||
strict=strict,
|
||||
),
|
||||
None,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
kwargs["messages"].append(dict(**response.choices[0].message)) # type: ignore
|
||||
kwargs["messages"].append(
|
||||
@@ -77,14 +101,26 @@ async def retry_async(
|
||||
raise e
|
||||
|
||||
|
||||
def retry_sync(func, response_model, validation_context, args, kwargs, max_retries):
|
||||
def retry_sync(
|
||||
func,
|
||||
response_model,
|
||||
validation_context,
|
||||
args,
|
||||
kwargs,
|
||||
max_retries,
|
||||
strict: Optional[bool] = None,
|
||||
):
|
||||
retries = 0
|
||||
new_kwargs = kwargs.copy()
|
||||
while retries <= max_retries:
|
||||
# Excepts ValidationError, and JSONDecodeError
|
||||
try:
|
||||
response = func(*args, **kwargs)
|
||||
return process_response(response, response_model, validation_context), None
|
||||
return (
|
||||
process_response(
|
||||
response, response_model, validation_context, strict=strict
|
||||
),
|
||||
None,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
kwargs["messages"].append(dict(**response.choices[0].message)) # type: ignore
|
||||
kwargs["messages"].append(
|
||||
@@ -140,23 +176,20 @@ def wrap_chatcompletion(func: Callable) -> Callable:
|
||||
return wrapper_function
|
||||
|
||||
|
||||
def process_response(response, response_model, validation_context=None):
|
||||
if response_model is not None:
|
||||
model = response_model.from_response(
|
||||
response, validation_context=validation_context
|
||||
)
|
||||
model._raw_response = response
|
||||
return model
|
||||
return response
|
||||
|
||||
|
||||
original_chatcompletion = openai.ChatCompletion.create
|
||||
original_chatcompletion_async = openai.ChatCompletion.acreate
|
||||
|
||||
|
||||
def patch():
|
||||
"""
|
||||
Patch the `openai.ChatCompletion.create` and `openai.ChatCompletion.acreate` methods to support the `response_model` parameter.
|
||||
Patch the `openai.ChatCompletion.create` and `openai.ChatCompletion.acreate` methods
|
||||
|
||||
Enables the following features:
|
||||
|
||||
- `response_model` parameter to parse the response from OpenAI's API
|
||||
- `max_retries` parameter to retry the function if the response is not valid
|
||||
- `validation_context` parameter to validate the response using the pydantic model
|
||||
- `strict` parameter to use strict json parsing
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -180,6 +213,8 @@ def patch():
|
||||
},
|
||||
],
|
||||
response_model=User,
|
||||
validation_context={...},
|
||||
strict=True,
|
||||
)
|
||||
|
||||
print(user.model_dump())
|
||||
|
||||
Reference in New Issue
Block a user