diff --git a/instructor/exceptions.py b/instructor/exceptions.py index b36bd9f..815a265 100644 --- a/instructor/exceptions.py +++ b/instructor/exceptions.py @@ -1,6 +1,8 @@ class IncompleteOutputException(Exception): """Exception raised when the output from LLM is incomplete due to max tokens limit reached.""" - def __init__(self, message="The output is incomplete due to a max_tokens length limit."): + def __init__( + self, message="The output is incomplete due to a max_tokens length limit." + ): self.message = message - super().__init__(self.message) \ No newline at end of file + super().__init__(self.message) diff --git a/instructor/function_calls.py b/instructor/function_calls.py index 4bef31f..53cb6d5 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -122,9 +122,9 @@ class OpenAISchema(BaseModel): Returns: cls (OpenAISchema): An instance of the class """ - if completion.choices[0].finish_reason == 'length': + if completion.choices[0].finish_reason == "length": raise IncompleteOutputException() - + if stream_multitask: return cls.from_streaming_response(completion, mode) @@ -183,9 +183,9 @@ class OpenAISchema(BaseModel): Returns: cls (OpenAISchema): An instance of the class """ - if completion.choices[0].finish_reason == 'length': + if completion.choices[0].finish_reason == "length": raise IncompleteOutputException() - + if stream_multitask: return await cls.from_streaming_response_async(completion, mode) @@ -232,4 +232,4 @@ def openai_schema(cls) -> OpenAISchema: cls.__name__, __base__=(cls, OpenAISchema), ) - ) # type: ignore \ No newline at end of file + ) # type: ignore diff --git a/tests/test_function_calls.py b/tests/test_function_calls.py index 6514f04..bf0315a 100644 --- a/tests/test_function_calls.py +++ b/tests/test_function_calls.py @@ -2,9 +2,10 @@ import pytest from pydantic import BaseModel from openai.resources.chat.completions import ChatCompletion -from instructor import openai_schema, OpenAISchema, Mode +from instructor import openai_schema, OpenAISchema from instructor.exceptions import IncompleteOutputException + @pytest.fixture def test_model(): class TestModel(OpenAISchema): @@ -12,29 +13,28 @@ def test_model(): data: str return TestModel - + + @pytest.fixture def mock_completion(request): - finish_reason = 'stop' - data_content = "{\n\"data\": \"complete data\"\n}" + finish_reason = "stop" + data_content = '{\n"data": "complete data"\n}' - if hasattr(request, 'param'): - finish_reason = request.param.get('finish_reason', finish_reason) - data_content = request.param.get('data_content', data_content) + if hasattr(request, "param"): + finish_reason = request.param.get("finish_reason", finish_reason) + data_content = request.param.get("data_content", data_content) - - mock_choices = [{ - "index": 0, - "message": { - "role": "assistant", - "function_call": { - "name": "TestModel", - "arguments": data_content + mock_choices = [ + { + "index": 0, + "message": { + "role": "assistant", + "function_call": {"name": "TestModel", "arguments": data_content}, + "content": data_content, }, - "content": data_content, - }, - "finish_reason": finish_reason - }] + "finish_reason": finish_reason, + } + ] completion = ChatCompletion( id="test_id", @@ -43,9 +43,10 @@ def mock_completion(request): model="gpt-3.5-turbo", object="chat.completion", ) - + return completion + def test_openai_schema(): @openai_schema class Dataframe(BaseModel): @@ -83,22 +84,34 @@ def test_no_docstring(): == "Correctly extracted `Dummy` with all the required parameters with correct types" ) -@pytest.mark.parametrize('mock_completion', [{'finish_reason': 'length', 'data_content': '{\n\"data\": \"incomplete dat\"\n}'}], indirect=True) + +@pytest.mark.parametrize( + "mock_completion", + [{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}], + indirect=True, +) def test_incomplete_output_exception(test_model, mock_completion): with pytest.raises(IncompleteOutputException): test_model.from_response(mock_completion) - + + def test_complete_output_no_exception(test_model, mock_completion): test_model_instance = test_model.from_response(mock_completion) assert test_model_instance.data == "complete data" + @pytest.mark.asyncio -@pytest.mark.parametrize('mock_completion', [{'finish_reason': 'length', 'data_content': '{\n\"data\": \"incomplete dat\"\n}'}], indirect=True) -async def test_incomplete_output_exception(test_model, mock_completion): +@pytest.mark.parametrize( + "mock_completion", + [{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}], + indirect=True, +) +async def test_incomplete_output_exception_raise(test_model, mock_completion): with pytest.raises(IncompleteOutputException): await test_model.from_response(mock_completion) - + + @pytest.mark.asyncio -async def test_complete_output_no_exception(test_model, mock_completion): +async def test_async_complete_output_no_exception(test_model, mock_completion): test_model_instance = await test_model.from_response_async(mock_completion) - assert test_model_instance.data == "complete data" \ No newline at end of file + assert test_model_instance.data == "complete data"