Reasking logic on validations (#98)

* working cleaned up patch

* Reasking logic

* clean up

* remove

* clean up tests
This commit is contained in:
Jason Liu
2023-09-08 00:58:29 -05:00
committed by GitHub
parent cffbb04dca
commit 1cc45e3faf
5 changed files with 235 additions and 68 deletions
+82
@@ -0,0 +1,82 @@
# Reasking When Validation Fails
Validators are a great tool for ensuring some property of the outputs. Once you have called `instructor.patch()` on the `openai` client, the `max_retries` parameter sets how many times the client may reask the model: if validation fails, the API call is reattempted up to that many times. It's another layer of defense against two forms of bad output:
1. Pydantic Validation Errors
2. JSON Decoding Errors
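To make the two failure modes concrete, here is a minimal stdlib-only sketch. `check_user` and `classify_failure` are hypothetical helpers, not part of `instructor`, and a plain `ValueError` stands in for Pydantic's `ValidationError`.

```python
import json
from json import JSONDecodeError


def check_user(data: dict) -> dict:
    """Hypothetical stand-in for a Pydantic model's validation step."""
    if not isinstance(data.get("age"), int):
        raise ValueError("age must be an integer")
    return data


def classify_failure(raw: str) -> str:
    """Report which failure mode, if any, a raw model output triggers."""
    try:
        check_user(json.loads(raw))
    except JSONDecodeError:
        # Caught first because JSONDecodeError is a subclass of ValueError.
        return "json decoding error"
    except ValueError:
        return "validation error"
    return "ok"


print(classify_failure('{"name": "JASON", "age": 25}'))     # well-formed and valid
print(classify_failure('{"name": "JASON", "age": "old"}'))  # parses, fails validation
print(classify_failure('{"name": "JASON", "age": 25'))      # truncated JSON
```

Both failures surface as exceptions, which is why a single `except` clause can drive the reask logic for either case.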
## Future Improvements
!!! note "Contributions Welcome"
    The current retry mechanism relies on a while loop. For a more robust solution, contributions to integrate the `tenacity` library are welcome.
## Example: Using Validators for Reasking
This example uses Pydantic's field validators in tandem with the `max_retries` parameter. If the `name` field fails validation, the `openai` client reattempts the API call. Here we use a plain validator, but we can also use [LLMs for validation](validation.md).
### Step 1: Define the Response Model with Validators
```python
import instructor
from pydantic import BaseModel, field_validator

# Apply the patch to the OpenAI client
instructor.patch()


class UserDetails(BaseModel):
    name: str
    age: int

    @field_validator("name")
    @classmethod
    def validate_name(cls, v):
        if v.upper() != v:
            raise ValueError("Name must be in uppercase.")
        return v
```
Here, the `UserDetails` class includes a validator for the `name` attribute. The validator checks that the name is in uppercase and raises a `ValueError` otherwise.
### Step 2: Exception Handling and Reasking
When validation fails, several steps are taken:
1. The existing messages are retained for the new API request.
2. The previous function call's response is added back.
3. A user prompt is included to reask the model, with details on the error.
```python
try:
    ...
except (ValidationError, JSONDecodeError) as e:
    kwargs["messages"].append(dict(**response.choices[0].message))
    kwargs["messages"].append(
        {
            "role": "user",
            "content": f"Please correct the function call; errors encountered:\n{e}",
        }
    )
```
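The three steps can also be sketched as a standalone function. This is an illustration under assumptions, not the library's code: `build_reask_messages` is a hypothetical name, the messages are plain dicts, and the sketch returns a new list rather than mutating `kwargs["messages"]` in place as the snippet above does.

```python
from typing import Any, Dict, List

Message = Dict[str, Any]


def build_reask_messages(
    messages: List[Message], assistant_message: Message, error: Exception
) -> List[Message]:
    """Return a new history: old messages, the failed reply, a correction prompt."""
    reask_prompt = {
        "role": "user",
        "content": f"Please correct the function call; errors encountered:\n{error}",
    }
    # 1. retain existing messages, 2. add back the failed reply, 3. reask with the error
    return [*messages, dict(assistant_message), reask_prompt]


history = [{"role": "user", "content": "Extract jason is 25 years old"}]
failed_reply = {
    "role": "assistant",
    "content": None,
    "function_call": {
        "name": "UserDetails",
        "arguments": '{"name": "jason", "age": 25}',
    },
}
retry_history = build_reask_messages(
    history, failed_reply, ValueError("Name must be in uppercase.")
)
for message in retry_history:
    print(message["role"])
```

Keeping the failed reply in the history lets the model see exactly what it produced alongside the error it needs to fix.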
## Using the Client with Retries
Here, the `UserDetails` model is passed as the `response_model`, and `max_retries` is set to 2.
```python
import openai

model = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    response_model=UserDetails,
    max_retries=2,
    messages=[
        {"role": "user", "content": "Extract jason is 25 years old"},
    ],
)

assert model.name == "JASON"
```
The `max_retries` parameter will trigger up to 2 reattempts if the `name` attribute fails the uppercase validation in `UserDetails`.
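To illustrate the retry budget without calling the API, here is a stdlib-only sketch. Everything in it is hypothetical (`validate_user`, `call_with_retries`, `fake_model`, and the `calls` counter are illustrative names), and a plain `ValueError` stands in for Pydantic's `ValidationError`. It shows that `max_retries=2` allows one initial call plus up to two reasks.

```python
import json
from typing import Callable, Dict, List

Message = Dict[str, str]


def validate_user(raw: str) -> dict:
    """Hypothetical validator: parse JSON and require an uppercase name."""
    data = json.loads(raw)
    if data["name"] != data["name"].upper():
        raise ValueError("Name must be in uppercase.")
    return data


def call_with_retries(
    model: Callable[[List[Message]], str],
    messages: List[Message],
    max_retries: int = 0,
) -> dict:
    """Call `model` once, then reask up to `max_retries` more times on failure."""
    messages = list(messages)  # avoid mutating the caller's list
    for attempt in range(max_retries + 1):
        raw = model(messages)
        try:
            return validate_user(raw)
        except ValueError as e:  # also covers JSONDecodeError, a ValueError subclass
            if attempt == max_retries:
                raise  # budget exhausted: surface the last error
            messages.append({"role": "assistant", "content": raw})
            messages.append(
                {"role": "user", "content": f"Please correct the output:\n{e}"}
            )


calls: List[int] = []


def fake_model(messages: List[Message]) -> str:
    """Toy model: answers in lowercase until it has been shown an error."""
    calls.append(len(messages))
    name = "JASON" if len(messages) > 1 else "jason"
    return json.dumps({"name": name, "age": 25})


user = call_with_retries(
    fake_model,
    [{"role": "user", "content": "Extract jason is 25 years old"}],
    max_retries=2,
)
print(user["name"], len(calls))
```

With `max_retries=0` the same toy model would fail on its only attempt, and the validation error would propagate to the caller.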
## Takeaways
Instead of framing "self-critique" or "self-reflection" in AI as new concepts, we can view them as validation errors with clear error messages that the system can use to heal. This approach leverages existing programming practices for error handling, avoiding the need for new methodologies. We simplify the issue into code we already know how to write and leverage Pydantic's powerful validation system to do so.
+3 -3
@@ -1,6 +1,6 @@
# Introduction to Validation in Pydantic and LLMs
Validation is crucial when using Large Language Models (LLMs) for data extraction. It ensures data integrity, enables reasking for better results, and allows for overwriting incorrect values. Pydantic offers versatile validation capabilities suitable for use with LLM outputs.
Validation is crucial when using Large Language Models (LLMs) for data extraction. It ensures data integrity, enables [reasking for better results](reask.md), and allows for overwriting incorrect values. Pydantic offers versatile validation capabilities suitable for use with LLM outputs.
!!! note "Pydantic Validation Docs"
@@ -14,14 +14,14 @@ Validation is crucial when using Large Language Models (LLMs) for data extractio
## Importance of LLM Validation
- **Data Integrity**: Enforces data quality standards.
- **Reasking**: Utilizes Pydantic's error messages to improve LLM outputs.
- **[Reasking](reask.md)**: Utilizes Pydantic's error messages to improve LLM outputs.
- **Overwriting**: Overwrites incorrect values during API calls.
## Code Examples
### Simple Validation with Pydantic
The example uses a custom validator function to enforce a rule on the name attribute. If a user fails to input a full name (first and last name separated by a space), Pydantic will raise a validation error. This is useful for pre-processing data generated or extracted by an LLM. In the future, we can use this error to reask the model when appropriate.
The example uses a custom validator function to enforce a rule on the name attribute. If a user fails to input a full name (first and last name separated by a space), Pydantic will raise a validation error. If you want the LLM to fix the error automatically, check out our [reasking docs](reask.md).
```python
from pydantic import BaseModel, ValidationError
+115 -63
@@ -1,72 +1,14 @@
from functools import wraps
from json import JSONDecodeError
from pydantic import ValidationError
import openai
import inspect
from typing import Callable, Optional, Type, Union
from typing import Callable, Type
from pydantic import BaseModel
from .function_calls import OpenAISchema, openai_schema
def wrap_chatcompletion(func: Callable) -> Callable:
    is_async = inspect.iscoroutinefunction(func)
    if is_async:

        @wraps(func)
        async def new_chatcompletion(
            *args,
            response_model: Optional[Union[Type[BaseModel], Type[OpenAISchema]]] = None,
            **kwargs
        ):  # type: ignore
            if response_model is not None:
                if not issubclass(response_model, OpenAISchema):
                    response_model = openai_schema(response_model)
                kwargs["functions"] = [response_model.openai_schema]
                kwargs["function_call"] = {"name": response_model.openai_schema["name"]}

            if kwargs.get("stream", False) and response_model is not None:
                import warnings

                warnings.warn(
                    "stream=True is not supported when using response_model parameter"
                )

            response = await func(*args, **kwargs)
            if response_model is not None:
                model = response_model.from_response(response)
                model._raw_response = response
                return model
            return response

    else:

        @wraps(func)
        def new_chatcompletion(
            *args,
            response_model: Optional[Union[Type[BaseModel], Type[OpenAISchema]]] = None,
            **kwargs
        ):
            if response_model is not None:
                if not issubclass(response_model, OpenAISchema):
                    response_model = openai_schema(response_model)
                kwargs["functions"] = [response_model.openai_schema]
                kwargs["function_call"] = {"name": response_model.openai_schema["name"]}

            if kwargs.get("stream", False) and response_model is not None:
                import warnings

                warnings.warn(
                    "stream=True is not supported when using response_model parameter"
                )

            response = func(*args, **kwargs)
            if response_model is not None:
                model = response_model.from_response(response)
                model._raw_response = response
                return model
            return response

new_chatcompletion.__doc__ = """
OVERRIDE_DOCS = """
Creates a new chat completion for the provided messages and parameters.
See: https://platform.openai.com/docs/api-reference/chat-completions/create
@@ -82,8 +24,118 @@ If need to obtain the raw response from OpenAI's API, you can access it using th
Parameters:
response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
max_retries (int): The maximum number of retries to attempt if the response is not valid (default: 0)
"""
return new_chatcompletion
def handle_response_model(response_model: Type[BaseModel], kwargs):
    new_kwargs = kwargs.copy()
    if response_model is not None:
        if not issubclass(response_model, OpenAISchema):
            response_model = openai_schema(response_model)  # type: ignore
        new_kwargs["functions"] = [response_model.openai_schema]  # type: ignore
        new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]}  # type: ignore

    if new_kwargs.get("stream", False) and response_model is not None:
        import warnings

        warnings.warn(
            "stream=True is not supported when using response_model parameter"
        )

    return response_model, new_kwargs


def process_response(response, response_model):  # type: ignore
    if response_model is not None:
        model = response_model.from_response(response)
        model._raw_response = response
        return model
    return response

async def retry_async(func, response_model, args, kwargs, max_retries):
    retries = 0
    while retries <= max_retries:
        try:
            response = await func(*args, **kwargs)
            return process_response(response, response_model), None
        except (ValidationError, JSONDecodeError) as e:
            kwargs["messages"].append(dict(**response.choices[0].message))
            kwargs["messages"].append(
                {
                    "role": "user",
                    "content": f"Recall the function correctly, exceptions found\n{e}",
                }
            )
            retries += 1
            if retries > max_retries:
                raise e


def retry_sync(func, response_model, args, kwargs, max_retries):
    retries = 0
    new_kwargs = kwargs.copy()
    while retries <= max_retries:
        # Catches ValidationError and JSONDecodeError
        try:
            response = func(*args, **kwargs)
            return process_response(response, response_model), None
        except (ValidationError, JSONDecodeError) as e:
            kwargs["messages"].append(dict(**response.choices[0].message))
            kwargs["messages"].append(
                {
                    "role": "user",
                    "content": f"Recall the function correctly, exceptions found\n{e}",
                }
            )
            retries += 1
            if retries > max_retries:
                raise e

def wrap_chatcompletion(func: Callable) -> Callable:
    is_async = inspect.iscoroutinefunction(func)

    @wraps(func)
    async def new_chatcompletion_async(
        *args, response_model=None, max_retries=0, **kwargs
    ):
        response_model, new_kwargs = handle_response_model(response_model, kwargs)
        response, error = await retry_async(
            func=func,
            response_model=response_model,
            max_retries=max_retries,
            args=args,
            kwargs=new_kwargs,
        )  # type: ignore
        if error:
            raise ValueError(error)
        # retry_async has already run process_response on the raw response
        return response

    @wraps(func)
    def new_chatcompletion_sync(*args, response_model=None, max_retries=0, **kwargs):
        response_model, new_kwargs = handle_response_model(response_model, kwargs)
        response, error = retry_sync(
            func=func,
            response_model=response_model,
            max_retries=max_retries,
            args=args,
            kwargs=new_kwargs,
        )  # type: ignore
        if error:
            raise ValueError(error)
        return response

    wrapper_function = new_chatcompletion_async if is_async else new_chatcompletion_sync
    wrapper_function.__doc__ = OVERRIDE_DOCS
    return wrapper_function

def process_response(response, response_model):
    if response_model is not None:
        model = response_model.from_response(response)
        model._raw_response = response
        return model
    return response

original_chatcompletion = openai.ChatCompletion.create
+2 -1
@@ -46,8 +46,9 @@ nav:
- Introduction:
- Getting Started: 'index.md'
- Prompt Engineering Tips: 'tips/index.md'
- Meta Functions:
- Helpers:
- Validations (self critique): "validation.md"
- Reasking via Validators: "reask.md"
- Multiple Extractions: "multitask.md"
- Handling Missing Content: "maybe.md"
- Philosophy: 'philosophy.md'
+33 -1
@@ -4,7 +4,7 @@ import openai
from instructor import patch
@pytest.mark.skip(reason="Needs openai call")
@pytest.mark.skip("Not implemented")
def test_runmodel():
    patch()
@@ -24,3 +24,35 @@ def test_runmodel():
    assert hasattr(
        model, "_raw_response"
    ), "The raw response should be available from OpenAI"


@pytest.mark.skip("Not implemented")
def test_runmodel_validator():
    patch()

    from pydantic import field_validator

    class UserExtract(BaseModel):
        name: str
        age: int

        @field_validator("name")
        @classmethod
        def validate_name(cls, v):
            if v.upper() != v:
                raise ValueError("Name should be uppercase")
            return v

    model = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        response_model=UserExtract,
        max_retries=2,
        messages=[
            {"role": "user", "content": "Extract jason is 25 years old"},
        ],
    )
    assert isinstance(model, UserExtract), "Should be instance of UserExtract"
    assert model.name == "JASON"
    assert hasattr(
        model, "_raw_response"
    ), "The raw response should be available from OpenAI"