diff --git a/instructor/exceptions.py b/instructor/exceptions.py new file mode 100644 index 0000000..b36bd9f --- /dev/null +++ b/instructor/exceptions.py @@ -0,0 +1,6 @@ +class IncompleteOutputException(Exception): + """Exception raised when the output from LLM is incomplete due to max tokens limit reached.""" + + def __init__(self, message="The output is incomplete due to a max_tokens length limit."): + self.message = message + super().__init__(self.message) \ No newline at end of file diff --git a/instructor/function_calls.py b/instructor/function_calls.py index 3def9a0..4bef31f 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -1,6 +1,7 @@ from docstring_parser import parse from functools import wraps from pydantic import BaseModel, create_model +from instructor.exceptions import IncompleteOutputException import enum @@ -121,6 +122,9 @@ class OpenAISchema(BaseModel): Returns: cls (OpenAISchema): An instance of the class """ + if completion.choices[0].finish_reason == 'length': + raise IncompleteOutputException() + if stream_multitask: return cls.from_streaming_response(completion, mode) @@ -179,6 +183,9 @@ class OpenAISchema(BaseModel): Returns: cls (OpenAISchema): An instance of the class """ + if completion.choices[0].finish_reason == 'length': + raise IncompleteOutputException() + if stream_multitask: return await cls.from_streaming_response_async(completion, mode) @@ -225,4 +232,4 @@ def openai_schema(cls) -> OpenAISchema: cls.__name__, __base__=(cls, OpenAISchema), ) - ) # type: ignore + ) # type: ignore \ No newline at end of file diff --git a/tests/test_function_calls.py b/tests/test_function_calls.py index b5447d3..6514f04 100644 --- a/tests/test_function_calls.py +++ b/tests/test_function_calls.py @@ -1,8 +1,50 @@ import pytest from pydantic import BaseModel +from openai.resources.chat.completions import ChatCompletion -from instructor import openai_schema, OpenAISchema +from instructor import openai_schema, OpenAISchema, Mode +from instructor.exceptions import IncompleteOutputException +@pytest.fixture +def test_model(): + class TestModel(OpenAISchema): + name: str = "TestModel" + data: str + + return TestModel + +@pytest.fixture +def mock_completion(request): + finish_reason = 'stop' + data_content = "{\n\"data\": \"complete data\"\n}" + + if hasattr(request, 'param'): + finish_reason = request.param.get('finish_reason', finish_reason) + data_content = request.param.get('data_content', data_content) + + + mock_choices = [{ + "index": 0, + "message": { + "role": "assistant", + "function_call": { + "name": "TestModel", + "arguments": data_content + }, + "content": data_content, + }, + "finish_reason": finish_reason + }] + + completion = ChatCompletion( + id="test_id", + choices=mock_choices, + created=1234567890, + model="gpt-3.5-turbo", + object="chat.completion", + ) + + return completion def test_openai_schema(): @openai_schema @@ -40,3 +82,23 @@ def test_no_docstring(): Dummy.openai_schema["description"] == "Correctly extracted `Dummy` with all the required parameters with correct types" ) + +@pytest.mark.parametrize('mock_completion', [{'finish_reason': 'length', 'data_content': '{\n\"data\": \"incomplete dat\"\n}'}], indirect=True) +def test_incomplete_output_exception(test_model, mock_completion): + with pytest.raises(IncompleteOutputException): + test_model.from_response(mock_completion) + +def test_complete_output_no_exception(test_model, mock_completion): + test_model_instance = test_model.from_response(mock_completion) + assert test_model_instance.data == "complete data" + +@pytest.mark.asyncio +@pytest.mark.parametrize('mock_completion', [{'finish_reason': 'length', 'data_content': '{\n\"data\": \"incomplete dat\"\n}'}], indirect=True) +async def test_incomplete_output_exception(test_model, mock_completion): + with pytest.raises(IncompleteOutputException): + await test_model.from_response(mock_completion) + +@pytest.mark.asyncio +async def test_complete_output_no_exception(test_model, mock_completion): + test_model_instance = await test_model.from_response_async(mock_completion) + assert test_model_instance.data == "complete data" \ No newline at end of file