fix exceptions

This commit is contained in:
Jason Liu
2023-12-15 11:10:23 -05:00
parent c80dd23ec0
commit 870e2c80b8
3 changed files with 49 additions and 34 deletions
+4 -2
View File
@@ -1,6 +1,8 @@
class IncompleteOutputException(Exception):
"""Exception raised when the output from LLM is incomplete due to max tokens limit reached."""
def __init__(self, message="The output is incomplete due to a max_tokens length limit."):
def __init__(
self, message="The output is incomplete due to a max_tokens length limit."
):
self.message = message
super().__init__(self.message)
super().__init__(self.message)
+5 -5
View File
@@ -122,9 +122,9 @@ class OpenAISchema(BaseModel):
Returns:
cls (OpenAISchema): An instance of the class
"""
if completion.choices[0].finish_reason == 'length':
if completion.choices[0].finish_reason == "length":
raise IncompleteOutputException()
if stream_multitask:
return cls.from_streaming_response(completion, mode)
@@ -183,9 +183,9 @@ class OpenAISchema(BaseModel):
Returns:
cls (OpenAISchema): An instance of the class
"""
if completion.choices[0].finish_reason == 'length':
if completion.choices[0].finish_reason == "length":
raise IncompleteOutputException()
if stream_multitask:
return await cls.from_streaming_response_async(completion, mode)
@@ -232,4 +232,4 @@ def openai_schema(cls) -> OpenAISchema:
cls.__name__,
__base__=(cls, OpenAISchema),
)
) # type: ignore
) # type: ignore
+40 -27
View File
@@ -2,9 +2,10 @@ import pytest
from pydantic import BaseModel
from openai.resources.chat.completions import ChatCompletion
from instructor import openai_schema, OpenAISchema, Mode
from instructor import openai_schema, OpenAISchema
from instructor.exceptions import IncompleteOutputException
@pytest.fixture
def test_model():
class TestModel(OpenAISchema):
@@ -12,29 +13,28 @@ def test_model():
data: str
return TestModel
@pytest.fixture
def mock_completion(request):
finish_reason = 'stop'
data_content = "{\n\"data\": \"complete data\"\n}"
finish_reason = "stop"
data_content = '{\n"data": "complete data"\n}'
if hasattr(request, 'param'):
finish_reason = request.param.get('finish_reason', finish_reason)
data_content = request.param.get('data_content', data_content)
if hasattr(request, "param"):
finish_reason = request.param.get("finish_reason", finish_reason)
data_content = request.param.get("data_content", data_content)
mock_choices = [{
"index": 0,
"message": {
"role": "assistant",
"function_call": {
"name": "TestModel",
"arguments": data_content
mock_choices = [
{
"index": 0,
"message": {
"role": "assistant",
"function_call": {"name": "TestModel", "arguments": data_content},
"content": data_content,
},
"content": data_content,
},
"finish_reason": finish_reason
}]
"finish_reason": finish_reason,
}
]
completion = ChatCompletion(
id="test_id",
@@ -43,9 +43,10 @@ def mock_completion(request):
model="gpt-3.5-turbo",
object="chat.completion",
)
return completion
def test_openai_schema():
@openai_schema
class Dataframe(BaseModel):
@@ -83,22 +84,34 @@ def test_no_docstring():
== "Correctly extracted `Dummy` with all the required parameters with correct types"
)
@pytest.mark.parametrize('mock_completion', [{'finish_reason': 'length', 'data_content': '{\n\"data\": \"incomplete dat\"\n}'}], indirect=True)
@pytest.mark.parametrize(
"mock_completion",
[{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}],
indirect=True,
)
def test_incomplete_output_exception(test_model, mock_completion):
with pytest.raises(IncompleteOutputException):
test_model.from_response(mock_completion)
def test_complete_output_no_exception(test_model, mock_completion):
test_model_instance = test_model.from_response(mock_completion)
assert test_model_instance.data == "complete data"
@pytest.mark.asyncio
@pytest.mark.parametrize('mock_completion', [{'finish_reason': 'length', 'data_content': '{\n\"data\": \"incomplete dat\"\n}'}], indirect=True)
async def test_incomplete_output_exception(test_model, mock_completion):
@pytest.mark.parametrize(
"mock_completion",
[{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}],
indirect=True,
)
async def test_incomplete_output_exception_raise(test_model, mock_completion):
with pytest.raises(IncompleteOutputException):
await test_model.from_response(mock_completion)
@pytest.mark.asyncio
async def test_complete_output_no_exception(test_model, mock_completion):
async def test_async_complete_output_no_exception(test_model, mock_completion):
test_model_instance = await test_model.from_response_async(mock_completion)
assert test_model_instance.data == "complete data"
assert test_model_instance.data == "complete data"