Introduce total_usage variable to track cumulative token usage (#343)

Co-authored-by: Jason Liu <jason@jxnl.co>
This commit is contained in:
lazyhope
2024-01-14 11:57:28 +08:00
committed by GitHub
parent 3f901bc8d3
commit 36871e6917
4 changed files with 37 additions and 6 deletions
+1 -1
View File
@@ -25,7 +25,7 @@ print(user._raw_response)
!!! tip "Accessing tokens usage"
This is the recommended way to access the tokens usage, since it is a pydantic model you can use any of the pydantic model methods on it. For example, you can access the `total_tokens` by doing `user._raw_response.usage.total_tokens`.
This is the recommended way to access the tokens usage, since it is a pydantic model you can use any of the pydantic model methods on it. For example, you can access the `total_tokens` by doing `user._raw_response.usage.total_tokens`. Note that this also includes the tokens used during any previous unsuccessful attempts.
In the future, we may add additional hooks to the `raw_response` to make it easier to access the tokens usage.
+18 -1
View File
@@ -12,6 +12,7 @@ from openai.types.chat import (
ChatCompletionMessage,
ChatCompletionMessageParam,
)
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel, ValidationError
from instructor.dsl.multitask import MultiTask, MultiTaskBase
@@ -31,7 +32,7 @@ Using the `response_model` parameter, you can specify a response model to use fo
If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method.
If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model.
If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts.
Parameters:
response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
@@ -225,10 +226,18 @@ async def retry_async(
mode: Mode = Mode.FUNCTIONS,
):
retries = 0
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
while retries <= max_retries:
try:
response: ChatCompletion = await func(*args, **kwargs)
stream = kwargs.get("stream", False)
if isinstance(response, ChatCompletion) and response.usage is not None:
total_usage.completion_tokens += response.usage.completion_tokens
total_usage.prompt_tokens += response.usage.prompt_tokens
total_usage.total_tokens += response.usage.total_tokens
response.usage = (
total_usage # Replace each response usage with the total usage
)
return await process_response_async(
response,
response_model=response_model,
@@ -279,11 +288,19 @@ def retry_sync(
mode: Mode = Mode.FUNCTIONS,
):
retries = 0
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
while retries <= max_retries:
# Excepts ValidationError, and JSONDecodeError
try:
response = func(*args, **kwargs)
stream = kwargs.get("stream", False)
if isinstance(response, ChatCompletion) and response.usage is not None:
total_usage.completion_tokens += response.usage.completion_tokens
total_usage.prompt_tokens += response.usage.prompt_tokens
total_usage.total_tokens += response.usage.total_tokens
response.usage = (
total_usage # Replace each response usage with the total usage
)
return process_response(
response,
response_model=response_model,
+9
View File
@@ -1,5 +1,6 @@
from itertools import product
from pydantic import BaseModel, field_validator
from openai.types.chat import ChatCompletion
import pytest
import instructor
@@ -29,6 +30,8 @@ def test_runmodel(model, mode, client):
model, "_raw_response"
), "The raw response should be available from OpenAI"
ChatCompletion(**model._raw_response.model_dump())
@pytest.mark.parametrize("model, mode", product(models, modes))
@pytest.mark.asyncio
@@ -49,6 +52,8 @@ async def test_runmodel_async(model, mode, aclient):
model, "_raw_response"
), "The raw response should be available from OpenAI"
ChatCompletion(**model._raw_response.model_dump())
class UserExtractValidated(BaseModel):
name: str
@@ -81,6 +86,8 @@ def test_runmodel_validator(model, mode, client):
model, "_raw_response"
), "The raw response should be available from OpenAI"
ChatCompletion(**model._raw_response.model_dump())
@pytest.mark.parametrize("model, mode", product(models, modes))
@pytest.mark.asyncio
@@ -99,3 +106,5 @@ async def test_runmodel_async_validator(model, mode, aclient):
assert hasattr(
model, "_raw_response"
), "The raw response should be available from OpenAI"
ChatCompletion(**model._raw_response.model_dump())
+9 -4
View File
@@ -3,6 +3,7 @@ from pydantic import BaseModel
from openai.resources.chat.completions import ChatCompletion
from instructor import openai_schema, OpenAISchema
import instructor
from instructor.exceptions import IncompleteOutputException
@@ -92,11 +93,13 @@ def test_no_docstring():
)
def test_incomplete_output_exception(test_model, mock_completion):
with pytest.raises(IncompleteOutputException):
test_model.from_response(mock_completion)
test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
def test_complete_output_no_exception(test_model, mock_completion):
test_model_instance = test_model.from_response(mock_completion)
test_model_instance = test_model.from_response(
mock_completion, mode=instructor.Mode.FUNCTIONS
)
assert test_model_instance.data == "complete data"
@@ -108,10 +111,12 @@ def test_complete_output_no_exception(test_model, mock_completion):
)
async def test_incomplete_output_exception_raise(test_model, mock_completion):
with pytest.raises(IncompleteOutputException):
await test_model.from_response(mock_completion)
await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
@pytest.mark.asyncio
async def test_async_complete_output_no_exception(test_model, mock_completion):
test_model_instance = await test_model.from_response_async(mock_completion)
test_model_instance = await test_model.from_response_async(
mock_completion, mode=instructor.Mode.FUNCTIONS
)
assert test_model_instance.data == "complete data"