mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Introduce total_usage variable to track cumulative token usage (#343)
Co-authored-by: Jason Liu <jason@jxnl.co>
This commit is contained in:
@@ -25,7 +25,7 @@ print(user._raw_response)
|
||||
|
||||
!!! tip "Accessing tokens usage"
|
||||
|
||||
This is the recommended way to access token usage. Since `_raw_response` is a pydantic model, you can use any of the pydantic model methods on it. For example, you can access `total_tokens` via `user._raw_response.usage.total_tokens`.
|
||||
This is the recommended way to access token usage. Since `_raw_response` is a pydantic model, you can use any of the pydantic model methods on it. For example, you can access `total_tokens` via `user._raw_response.usage.total_tokens`. Note that this total also includes the tokens used during any previous unsuccessful attempts.
|
||||
|
||||
In the future, we may add additional hooks to `_raw_response` to make it easier to access token usage.
|
||||
|
||||
|
||||
+18
-1
@@ -12,6 +12,7 @@ from openai.types.chat import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from instructor.dsl.multitask import MultiTask, MultiTaskBase
|
||||
@@ -31,7 +32,7 @@ Using the `response_model` parameter, you can specify a response model to use fo
|
||||
|
||||
If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method.
|
||||
|
||||
If you need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model.
|
||||
If you need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the combined token usage of the last successful response and any previous unsuccessful attempts.
|
||||
|
||||
Parameters:
|
||||
response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
|
||||
@@ -225,10 +226,18 @@ async def retry_async(
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
retries = 0
|
||||
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
response: ChatCompletion = await func(*args, **kwargs)
|
||||
stream = kwargs.get("stream", False)
|
||||
if isinstance(response, ChatCompletion) and response.usage is not None:
|
||||
total_usage.completion_tokens += response.usage.completion_tokens
|
||||
total_usage.prompt_tokens += response.usage.prompt_tokens
|
||||
total_usage.total_tokens += response.usage.total_tokens
|
||||
response.usage = (
|
||||
total_usage # Replace each response usage with the total usage
|
||||
)
|
||||
return await process_response_async(
|
||||
response,
|
||||
response_model=response_model,
|
||||
@@ -279,11 +288,19 @@ def retry_sync(
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
retries = 0
|
||||
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
|
||||
while retries <= max_retries:
|
||||
# Excepts ValidationError, and JSONDecodeError
|
||||
try:
|
||||
response = func(*args, **kwargs)
|
||||
stream = kwargs.get("stream", False)
|
||||
if isinstance(response, ChatCompletion) and response.usage is not None:
|
||||
total_usage.completion_tokens += response.usage.completion_tokens
|
||||
total_usage.prompt_tokens += response.usage.prompt_tokens
|
||||
total_usage.total_tokens += response.usage.total_tokens
|
||||
response.usage = (
|
||||
total_usage # Replace each response usage with the total usage
|
||||
)
|
||||
return process_response(
|
||||
response,
|
||||
response_model=response_model,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from itertools import product
|
||||
from pydantic import BaseModel, field_validator
|
||||
from openai.types.chat import ChatCompletion
|
||||
import pytest
|
||||
import instructor
|
||||
|
||||
@@ -29,6 +30,8 @@ def test_runmodel(model, mode, client):
|
||||
model, "_raw_response"
|
||||
), "The raw response should be available from OpenAI"
|
||||
|
||||
ChatCompletion(**model._raw_response.model_dump())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, mode", product(models, modes))
|
||||
@pytest.mark.asyncio
|
||||
@@ -49,6 +52,8 @@ async def test_runmodel_async(model, mode, aclient):
|
||||
model, "_raw_response"
|
||||
), "The raw response should be available from OpenAI"
|
||||
|
||||
ChatCompletion(**model._raw_response.model_dump())
|
||||
|
||||
|
||||
class UserExtractValidated(BaseModel):
|
||||
name: str
|
||||
@@ -81,6 +86,8 @@ def test_runmodel_validator(model, mode, client):
|
||||
model, "_raw_response"
|
||||
), "The raw response should be available from OpenAI"
|
||||
|
||||
ChatCompletion(**model._raw_response.model_dump())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, mode", product(models, modes))
|
||||
@pytest.mark.asyncio
|
||||
@@ -99,3 +106,5 @@ async def test_runmodel_async_validator(model, mode, aclient):
|
||||
assert hasattr(
|
||||
model, "_raw_response"
|
||||
), "The raw response should be available from OpenAI"
|
||||
|
||||
ChatCompletion(**model._raw_response.model_dump())
|
||||
|
||||
@@ -3,6 +3,7 @@ from pydantic import BaseModel
|
||||
from openai.resources.chat.completions import ChatCompletion
|
||||
|
||||
from instructor import openai_schema, OpenAISchema
|
||||
import instructor
|
||||
from instructor.exceptions import IncompleteOutputException
|
||||
|
||||
|
||||
@@ -92,11 +93,13 @@ def test_no_docstring():
|
||||
)
|
||||
def test_incomplete_output_exception(test_model, mock_completion):
|
||||
with pytest.raises(IncompleteOutputException):
|
||||
test_model.from_response(mock_completion)
|
||||
test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
|
||||
|
||||
|
||||
def test_complete_output_no_exception(test_model, mock_completion):
|
||||
test_model_instance = test_model.from_response(mock_completion)
|
||||
test_model_instance = test_model.from_response(
|
||||
mock_completion, mode=instructor.Mode.FUNCTIONS
|
||||
)
|
||||
assert test_model_instance.data == "complete data"
|
||||
|
||||
|
||||
@@ -108,10 +111,12 @@ def test_complete_output_no_exception(test_model, mock_completion):
|
||||
)
|
||||
async def test_incomplete_output_exception_raise(test_model, mock_completion):
|
||||
with pytest.raises(IncompleteOutputException):
|
||||
await test_model.from_response(mock_completion)
|
||||
await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_complete_output_no_exception(test_model, mock_completion):
|
||||
test_model_instance = await test_model.from_response_async(mock_completion)
|
||||
test_model_instance = await test_model.from_response_async(
|
||||
mock_completion, mode=instructor.Mode.FUNCTIONS
|
||||
)
|
||||
assert test_model_instance.data == "complete data"
|
||||
|
||||
Reference in New Issue
Block a user