mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
added finish reason exception IncompleteOutputException (#279)
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
class IncompleteOutputException(Exception):
|
||||
"""Exception raised when the output from LLM is incomplete due to max tokens limit reached."""
|
||||
|
||||
def __init__(self, message="The output is incomplete due to a max_tokens length limit."):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
@@ -1,6 +1,7 @@
|
||||
from docstring_parser import parse
|
||||
from functools import wraps
|
||||
from pydantic import BaseModel, create_model
|
||||
from instructor.exceptions import IncompleteOutputException
|
||||
|
||||
import enum
|
||||
|
||||
@@ -121,6 +122,9 @@ class OpenAISchema(BaseModel):
|
||||
Returns:
|
||||
cls (OpenAISchema): An instance of the class
|
||||
"""
|
||||
if completion.choices[0].finish_reason == 'length':
|
||||
raise IncompleteOutputException()
|
||||
|
||||
if stream_multitask:
|
||||
return cls.from_streaming_response(completion, mode)
|
||||
|
||||
@@ -179,6 +183,9 @@ class OpenAISchema(BaseModel):
|
||||
Returns:
|
||||
cls (OpenAISchema): An instance of the class
|
||||
"""
|
||||
if completion.choices[0].finish_reason == 'length':
|
||||
raise IncompleteOutputException()
|
||||
|
||||
if stream_multitask:
|
||||
return await cls.from_streaming_response_async(completion, mode)
|
||||
|
||||
@@ -225,4 +232,4 @@ def openai_schema(cls) -> OpenAISchema:
|
||||
cls.__name__,
|
||||
__base__=(cls, OpenAISchema),
|
||||
)
|
||||
) # type: ignore
|
||||
) # type: ignore
|
||||
@@ -1,8 +1,50 @@
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from openai.resources.chat.completions import ChatCompletion
|
||||
|
||||
from instructor import openai_schema, OpenAISchema
|
||||
from instructor import openai_schema, OpenAISchema, Mode
|
||||
from instructor.exceptions import IncompleteOutputException
|
||||
|
||||
@pytest.fixture
|
||||
def test_model():
|
||||
class TestModel(OpenAISchema):
|
||||
name: str = "TestModel"
|
||||
data: str
|
||||
|
||||
return TestModel
|
||||
|
||||
@pytest.fixture
|
||||
def mock_completion(request):
|
||||
finish_reason = 'stop'
|
||||
data_content = "{\n\"data\": \"complete data\"\n}"
|
||||
|
||||
if hasattr(request, 'param'):
|
||||
finish_reason = request.param.get('finish_reason', finish_reason)
|
||||
data_content = request.param.get('data_content', data_content)
|
||||
|
||||
|
||||
mock_choices = [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"function_call": {
|
||||
"name": "TestModel",
|
||||
"arguments": data_content
|
||||
},
|
||||
"content": data_content,
|
||||
},
|
||||
"finish_reason": finish_reason
|
||||
}]
|
||||
|
||||
completion = ChatCompletion(
|
||||
id="test_id",
|
||||
choices=mock_choices,
|
||||
created=1234567890,
|
||||
model="gpt-3.5-turbo",
|
||||
object="chat.completion",
|
||||
)
|
||||
|
||||
return completion
|
||||
|
||||
def test_openai_schema():
|
||||
@openai_schema
|
||||
@@ -40,3 +82,23 @@ def test_no_docstring():
|
||||
Dummy.openai_schema["description"]
|
||||
== "Correctly extracted `Dummy` with all the required parameters with correct types"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('mock_completion', [{'finish_reason': 'length', 'data_content': '{\n\"data\": \"incomplete dat\"\n}'}], indirect=True)
|
||||
def test_incomplete_output_exception(test_model, mock_completion):
|
||||
with pytest.raises(IncompleteOutputException):
|
||||
test_model.from_response(mock_completion)
|
||||
|
||||
def test_complete_output_no_exception(test_model, mock_completion):
|
||||
test_model_instance = test_model.from_response(mock_completion)
|
||||
assert test_model_instance.data == "complete data"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize('mock_completion', [{'finish_reason': 'length', 'data_content': '{\n\"data\": \"incomplete dat\"\n}'}], indirect=True)
|
||||
async def test_incomplete_output_exception(test_model, mock_completion):
|
||||
with pytest.raises(IncompleteOutputException):
|
||||
await test_model.from_response(mock_completion)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_output_no_exception(test_model, mock_completion):
|
||||
test_model_instance = await test_model.from_response_async(mock_completion)
|
||||
assert test_model_instance.data == "complete data"
|
||||
Reference in New Issue
Block a user