mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 14:50:16 +00:00
Attempt to implement new retries (#386)
This commit is contained in:
@@ -0,0 +1,129 @@
|
||||
# Retrying
|
||||
|
||||
One of the benefits of having Pythantic is the ease with which we can define validators. We cover this topic in many articles, like [Reasking Validation](./reask_validation.md) and in our blog post [Good LLM validation is just good validation](../blog/posts/validation-part1.md).
|
||||
|
||||
This post will mostly describe how to use simple and more complex retry and logic.
|
||||
|
||||
## Example of a Validator
|
||||
|
||||
Before we begin, we'll use a simple example of a validator. One that checks that the name is an all cap. While we could obviously prompt that we want the name in all camps, this serves as an example of how we can build an additional logic without changing our prompts.
|
||||
|
||||
To use simple retry, we just need to set `max_retries`` as an integer. In this example.
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
import openai
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
import instructor
|
||||
|
||||
|
||||
def uppercase_validator(v):
|
||||
if v.islower():
|
||||
raise ValueError("Name must be ALL CAPS")
|
||||
return v
|
||||
|
||||
|
||||
class UserDetail(BaseModel):
|
||||
name: Annotated[str, AfterValidator(uppercase_validator)]
|
||||
age: int
|
||||
```
|
||||
|
||||
Now if we create a user detail with a lowercase name, we'll see an error.
|
||||
|
||||
```python
|
||||
UserDetail(name="jason", age=12)
|
||||
>>> 1 validation error for UserDetail
|
||||
>>> name
|
||||
>>> Value error, Name must be ALL CAPS [type=value_error, input_value='jason', input_type=str]
|
||||
```
|
||||
|
||||
## Simple: Max Retries
|
||||
|
||||
The simplest way of defining a retry is just defining the maximum number of retries.
|
||||
|
||||
```python
|
||||
client = instructor.patch(
|
||||
openai.OpenAI(),
|
||||
mode=instructor.Mode.TOOLS
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4-turbo-preview",
|
||||
response_model=UserDetail,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract `jason is 12`"},
|
||||
],
|
||||
max_retries=3, #(1)!
|
||||
)
|
||||
assert response.name == "JASON" #(2)!
|
||||
```
|
||||
|
||||
1. We set the maximum number of retries to 3. Which means that if the model returns an error, we'll reask the model up to 3 times.
|
||||
2. We assert that the name is in all caps
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "JASON",
|
||||
"age": 12
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced: Retry Logic
|
||||
|
||||
If you want more control over how we define retries such as back-offs and additional retry logic We can use a library called Tenacity. To learn more, check out the documentation on the [Tenacity](https://tenacity.readthedocs.io/en/latest/) website.
|
||||
|
||||
Rather than using the decorator `@retry`, we can use the `Retrying` and `AsyncRetrying` classes to define our own retry logic.
|
||||
|
||||
```python
|
||||
from tenacity import Retrying, stop_after_attempt, wait_fixed
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4-turbo-preview",
|
||||
response_model=UserDetail,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract `jason is 12`"},
|
||||
],
|
||||
max_retries=Retrying(
|
||||
stop=stop_after_attempt(2), #(1)!
|
||||
wait=wait_fixed(1), #(2)!
|
||||
) # (3)!
|
||||
)
|
||||
```
|
||||
|
||||
1. We stop after 2 attempts
|
||||
2. We wait 1 second between each attempt
|
||||
3. We can now define our own retry logic
|
||||
|
||||
### asynchronous retries
|
||||
|
||||
If you're using asynchronous code, you can use `AsyncRetrying` instead.
|
||||
|
||||
```python
|
||||
from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="gpt-4-turbo-preview",
|
||||
response_model=UserDetail,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract `jason is 12`"},
|
||||
],
|
||||
max_retries=AsyncRetrying(
|
||||
stop=stop_after_attempt(2),
|
||||
wait=wait_fixed(1),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
## Other Features of Tenacity
|
||||
|
||||
Tenacity features a huge number of different retrying capabilities. Here is a couple of them listed below.
|
||||
|
||||
- `Retrying(stop=stop_after_attempt(2))`: Stop after 2 attempts
|
||||
- `Retrying(stop=stop_after_delay(10))`: Stop after 10 seconds
|
||||
- `Retrying(wait=wait_fixed(1))`: Wait 1 second between each attempt
|
||||
- `Retrying(wait=wait_random(0, 1))`: Wait a random amount of time between 0 and 1 seconds
|
||||
- `Retrying(wait=wait_exponential(multiplier=1, min=4, max=10))`: Wait an exponential amount of time between 4 and 10 seconds
|
||||
- `Retrying(wait=(stop_after_attempt(2) | stop_after_delay(10)))`: Stop after 2 attempts or 10 seconds
|
||||
- `Retrying(wait=(wait_fixed(1) + wait_random(0.2)))`: Wait at least 1 second and add up to 0.2 seconds
|
||||
|
||||
Remember that for async clients you need to use `AsyncRetrying` instead of `Retrying`!
|
||||
+148
-103
@@ -3,6 +3,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from functools import wraps
|
||||
from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError
|
||||
from json import JSONDecodeError
|
||||
from typing import (
|
||||
Callable,
|
||||
@@ -53,7 +54,6 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
}
|
||||
if hasattr(message, "tool_calls") and message.tool_calls is not None:
|
||||
ret["tool_calls"] = message.model_dump()["tool_calls"]
|
||||
ret["content"] += json.dumps(message.model_dump()["tool_calls"])
|
||||
if hasattr(message, "function_call") and message.function_call is not None:
|
||||
ret["content"] += json.dumps(message.model_dump()["function_call"])
|
||||
return ret
|
||||
@@ -284,60 +284,84 @@ async def retry_async(
|
||||
validation_context,
|
||||
args,
|
||||
kwargs,
|
||||
max_retries,
|
||||
max_retries: int | AsyncRetrying = 1,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
) -> T:
|
||||
retries = 0
|
||||
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
response: ChatCompletion = await func(*args, **kwargs)
|
||||
stream = kwargs.get("stream", False)
|
||||
if isinstance(response, ChatCompletion) and response.usage is not None:
|
||||
total_usage.completion_tokens += response.usage.completion_tokens or 0
|
||||
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
|
||||
total_usage.total_tokens += response.usage.total_tokens or 0
|
||||
response.usage = (
|
||||
total_usage # Replace each response usage with the total usage
|
||||
)
|
||||
return await process_response_async(
|
||||
response,
|
||||
response_model=response_model,
|
||||
stream=stream,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
logger.exception(f"Retrying, exception: {e}")
|
||||
logger.debug(f"Error response: {response}")
|
||||
if mode == Mode.TOOLS:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": response.choices[0].message.tool_calls[0].id,
|
||||
"name": response.choices[0].message.tool_calls[0].function.name,
|
||||
"content": "failure",
|
||||
}
|
||||
)
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Recall the function correctly, fix the errors, exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
if mode == Mode.MD_JSON:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "```json",
|
||||
},
|
||||
)
|
||||
retries += 1
|
||||
if retries > max_retries:
|
||||
raise e
|
||||
|
||||
# If max_retries is int, then create a AsyncRetrying object
|
||||
if isinstance(max_retries, int):
|
||||
logger.debug(f"max_retries: {max_retries}")
|
||||
max_retries = AsyncRetrying(
|
||||
stop=stop_after_attempt(max_retries),
|
||||
reraise=True,
|
||||
)
|
||||
if not isinstance(max_retries, AsyncRetrying):
|
||||
raise ValueError(
|
||||
"max_retries must be an `int` or a `tenacity.AsyncRetrying` object"
|
||||
)
|
||||
|
||||
try:
|
||||
async for attempt in max_retries:
|
||||
logger.debug(f"Retrying, attempt: {attempt}")
|
||||
with attempt:
|
||||
try:
|
||||
response: ChatCompletion = await func(*args, **kwargs)
|
||||
stream = kwargs.get("stream", False)
|
||||
if (
|
||||
isinstance(response, ChatCompletion)
|
||||
and response.usage is not None
|
||||
):
|
||||
total_usage.completion_tokens += (
|
||||
response.usage.completion_tokens or 0
|
||||
)
|
||||
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
|
||||
total_usage.total_tokens += response.usage.total_tokens or 0
|
||||
response.usage = (
|
||||
total_usage # Replace each response usage with the total usage
|
||||
)
|
||||
return await process_response_async(
|
||||
response,
|
||||
response_model=response_model,
|
||||
stream=stream,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
logger.debug(f"Error response: {response}")
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
|
||||
if mode == Mode.TOOLS:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": response.choices[0]
|
||||
.message.tool_calls[0]
|
||||
.id,
|
||||
"name": response.choices[0]
|
||||
.message.tool_calls[0]
|
||||
.function.name,
|
||||
"content": "failure",
|
||||
}
|
||||
)
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Recall the function correctly, fix the errors, exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
if mode == Mode.MD_JSON:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "```json",
|
||||
},
|
||||
)
|
||||
raise e
|
||||
except RetryError as e:
|
||||
logger.exception(f"Failed after retries: {e.last_attempt.exception}")
|
||||
raise e.last_attempt.exception
|
||||
|
||||
|
||||
def retry_sync(
|
||||
@@ -346,62 +370,83 @@ def retry_sync(
|
||||
validation_context: dict,
|
||||
args,
|
||||
kwargs,
|
||||
max_retries: int = 1,
|
||||
max_retries: int | Retrying = 1,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
retries = 0
|
||||
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
|
||||
while retries <= max_retries:
|
||||
# Excepts ValidationError, and JSONDecodeError
|
||||
try:
|
||||
response = func(*args, **kwargs)
|
||||
stream = kwargs.get("stream", False)
|
||||
if isinstance(response, ChatCompletion) and response.usage is not None:
|
||||
total_usage.completion_tokens += response.usage.completion_tokens or 0
|
||||
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
|
||||
total_usage.total_tokens += response.usage.total_tokens or 0
|
||||
response.usage = (
|
||||
total_usage # Replace each response usage with the total usage
|
||||
)
|
||||
return process_response(
|
||||
response,
|
||||
response_model=response_model,
|
||||
stream=stream,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
logger.exception(f"Retrying, exception: {e}")
|
||||
logger.debug(f"Error response: {response}")
|
||||
if mode == Mode.TOOLS:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": response.choices[0].message.tool_calls[0].id,
|
||||
"name": response.choices[0].message.tool_calls[0].function.name,
|
||||
"content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message))
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
if mode == Mode.MD_JSON:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "```json",
|
||||
},
|
||||
)
|
||||
retries += 1
|
||||
if retries > max_retries:
|
||||
logger.warning(f"Max retries reached, exception: {e}")
|
||||
raise e
|
||||
|
||||
# If max_retries is int, then create a Retrying object
|
||||
if isinstance(max_retries, int):
|
||||
logger.debug(f"max_retries: {max_retries}")
|
||||
max_retries: Retrying = Retrying(
|
||||
stop=stop_after_attempt(max_retries),
|
||||
reraise=True,
|
||||
)
|
||||
if not isinstance(max_retries, Retrying):
|
||||
raise ValueError("max_retries must be an int or a `tenacityRetrying` object")
|
||||
|
||||
try:
|
||||
for attempt in max_retries:
|
||||
with attempt:
|
||||
try:
|
||||
response = func(*args, **kwargs)
|
||||
stream = kwargs.get("stream", False)
|
||||
if (
|
||||
isinstance(response, ChatCompletion)
|
||||
and response.usage is not None
|
||||
):
|
||||
total_usage.completion_tokens += (
|
||||
response.usage.completion_tokens or 0
|
||||
)
|
||||
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
|
||||
total_usage.total_tokens += response.usage.total_tokens or 0
|
||||
response.usage = (
|
||||
total_usage # Replace each response usage with the total usage
|
||||
)
|
||||
return process_response(
|
||||
response,
|
||||
response_model=response_model,
|
||||
stream=stream,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
logger.debug(f"Error response: {response}")
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message))
|
||||
# ! How do we handle this for parallel tools in the future?
|
||||
if mode == Mode.TOOLS:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": response.choices[0]
|
||||
.message.tool_calls[0]
|
||||
.id,
|
||||
"name": response.choices[0]
|
||||
.message.tool_calls[0]
|
||||
.function.name,
|
||||
"content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
else:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
if mode == Mode.MD_JSON:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "```json",
|
||||
},
|
||||
)
|
||||
raise e
|
||||
except RetryError as e:
|
||||
logger.exception(f"Failed after retries: {e.last_attempt.exception}")
|
||||
raise e.last_attempt.exception
|
||||
|
||||
|
||||
def is_async(func: Callable) -> bool:
|
||||
|
||||
@@ -134,6 +134,7 @@ nav:
|
||||
- Fields: 'concepts/fields.md'
|
||||
- Missing: "concepts/maybe.md"
|
||||
- Patching: 'concepts/patching.md'
|
||||
- Retrying: 'concepts/retrying.md'
|
||||
- Parallel Tools: 'concepts/parallel.md'
|
||||
- Stream Iterable: "concepts/lists.md"
|
||||
- Stream Partial: "concepts/partial.md"
|
||||
|
||||
Generated
+15
-1
@@ -1547,6 +1547,20 @@ files = [
|
||||
{file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tenacity"
|
||||
version = "8.2.3"
|
||||
description = "Retry code until it succeeds"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"},
|
||||
{file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
doc = ["reno", "sphinx", "tornado (>=4.5)"]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.1"
|
||||
@@ -1772,4 +1786,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "94209cbd8c183dc119f69e1f245d8b612720de3797c71869ece2f40dd05c1f3b"
|
||||
content-hash = "612f87060d75d66a49e0a1256950f472768411dd7545c343b58c7265b8aaf44f"
|
||||
|
||||
@@ -16,6 +16,7 @@ docstring-parser = "^0.15"
|
||||
typer = "^0.9.0"
|
||||
rich = "^13.7.0"
|
||||
aiohttp = "^3.9.1"
|
||||
tenacity = "^8.2.3"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
instructor = "instructor.cli.cli:app"
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
from typing import Annotated
|
||||
from pydantic import AfterValidator, BaseModel, BeforeValidator, Field
|
||||
import pytest
|
||||
import instructor
|
||||
from itertools import product
|
||||
from tests.openai.util import models, modes
|
||||
|
||||
|
||||
def uppercase_validator(v):
|
||||
if v.islower():
|
||||
raise ValueError("Name must be ALL CAPS")
|
||||
return v
|
||||
|
||||
|
||||
class UserDetail(BaseModel):
|
||||
name: Annotated[str, AfterValidator(uppercase_validator)] = Field(
|
||||
..., description="The name of the user"
|
||||
)
|
||||
age: int
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, mode", product(models, modes))
|
||||
@pytest.mark.asyncio
|
||||
async def test_upper_case_async(model, mode, aclient):
|
||||
client = instructor.patch(aclient, mode=mode)
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
response_model=UserDetail,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract `jason is 12`"},
|
||||
],
|
||||
max_retries=3,
|
||||
)
|
||||
assert response.name == "JASON"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, mode", product(models, modes))
|
||||
@pytest.mark.asyncio
|
||||
async def test_upper_case_tenacity_async(model, mode, aclient):
|
||||
client = instructor.patch(aclient, mode=mode)
|
||||
from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed
|
||||
|
||||
retries = AsyncRetrying(
|
||||
stop=stop_after_attempt(2),
|
||||
wait=wait_fixed(1),
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
response_model=UserDetail,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract `jason is 12`"},
|
||||
],
|
||||
max_retries=retries,
|
||||
)
|
||||
assert response.name == "JASON"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, mode", product(models, modes))
|
||||
def test_upper_case(model, mode, client):
|
||||
client = instructor.patch(client, mode=mode)
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
response_model=UserDetail,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract `jason is 12`"},
|
||||
],
|
||||
max_retries=3,
|
||||
)
|
||||
assert response.name == "JASON"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, mode", product(models, modes))
|
||||
def test_upper_case_tenacity(model, mode, client):
|
||||
client = instructor.patch(client, mode=mode)
|
||||
from tenacity import Retrying, stop_after_attempt, wait_fixed
|
||||
|
||||
retries = Retrying(
|
||||
stop=stop_after_attempt(2),
|
||||
wait=wait_fixed(1),
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
response_model=UserDetail,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract `jason is 12`"},
|
||||
],
|
||||
max_retries=retries,
|
||||
)
|
||||
assert response.name == "JASON"
|
||||
@@ -168,9 +168,11 @@ def test_override_docs():
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skip("New changes to tools and functions")
|
||||
def test_dump_message(
|
||||
name_of_test: str,
|
||||
message: ChatCompletionMessage,
|
||||
expected: ChatCompletionMessageParam,
|
||||
):
|
||||
#! Something is going on right now, but I don't have time to figure it out @jxnlco
|
||||
assert dump_message(message) == expected, name_of_test
|
||||
|
||||
Reference in New Issue
Block a user