Include types to instructor.function_calls and tests (#394)

Co-authored-by: Jason Liu <jxnl@users.noreply.github.com>
This commit is contained in:
Ezzeri Esa
2024-02-04 18:24:45 -08:00
committed by GitHub
parent 85ead574c8
commit 75908132a0
3 changed files with 43 additions and 28 deletions
+3
View File
@@ -13,8 +13,11 @@ env:
instructor/_types/_alias.py
instructor/cli/cli.py
instructor/cli/files.py
instructor/cli/jobs.py
instructor/cli/usage.py
instructor/exceptions.py
instructor/function_calls.py
tests/test_function_calls.py
jobs:
MyPy:
+12 -10
View File
@@ -1,4 +1,4 @@
from typing import Type, TypeVar
from typing import Any, Dict, Optional, Type, TypeVar
from docstring_parser import parse
from functools import wraps
from pydantic import BaseModel, create_model
@@ -19,7 +19,7 @@ class Mode(enum.Enum):
MD_JSON: str = "markdown_json_mode"
JSON_SCHEMA: str = "json_schema_mode"
def __new__(cls, value):
def __new__(cls, value: str) -> "Mode":
member = object.__new__(cls)
member._value_ = value
@@ -34,10 +34,10 @@ class Mode(enum.Enum):
return member
class OpenAISchema(BaseModel):
@classmethod
class OpenAISchema(BaseModel): # type: ignore[misc]
@classmethod # type: ignore[misc]
@property
def openai_schema(cls):
def openai_schema(cls) -> Dict[str, Any]:
"""
Return the schema in the format of OpenAI's schema as jsonschema
@@ -82,10 +82,10 @@ class OpenAISchema(BaseModel):
def from_response(
cls,
completion: T,
validation_context: dict = None,
strict: bool = None,
validation_context: Optional[Dict[str, Any]] = None,
strict: Optional[bool] = None,
mode: Mode = Mode.TOOLS,
):
) -> Dict[str, Any]:
"""Execute the function from the response of an openai chat completion
Parameters:
@@ -98,6 +98,8 @@ class OpenAISchema(BaseModel):
Returns:
cls (OpenAISchema): An instance of the class
"""
assert hasattr(completion, "choices")
if completion.choices[0].finish_reason == "length":
raise IncompleteOutputException()
@@ -105,7 +107,7 @@ class OpenAISchema(BaseModel):
if mode == Mode.FUNCTIONS:
assert (
message.function_call.name == cls.openai_schema["name"]
message.function_call.name == cls.openai_schema["name"] # type: ignore[index]
), "Function name does not match"
return cls.model_validate_json(
message.function_call.arguments,
@@ -118,7 +120,7 @@ class OpenAISchema(BaseModel):
), "Instructor does not support multiple tool calls, use List[Model] instead."
tool_call = message.tool_calls[0]
assert (
tool_call.function.name == cls.openai_schema["name"]
tool_call.function.name == cls.openai_schema["name"] # type: ignore[index]
), "Tool name does not match"
return cls.model_validate_json(
tool_call.function.arguments,
+28 -18
View File
@@ -1,3 +1,4 @@
from typing import Type, TypeVar
import pytest
from pydantic import BaseModel
from openai.resources.chat.completions import ChatCompletion
@@ -7,17 +8,20 @@ import instructor
from instructor.exceptions import IncompleteOutputException
@pytest.fixture
def test_model():
class TestModel(OpenAISchema):
T = TypeVar("T")
@pytest.fixture # type: ignore[misc]
def test_model() -> Type[OpenAISchema]:
class TestModel(OpenAISchema): # type: ignore[misc]
name: str = "TestModel"
data: str
return TestModel
@pytest.fixture
def mock_completion(request):
@pytest.fixture # type: ignore[misc]
def mock_completion(request: T) -> ChatCompletion:
finish_reason = "stop"
data_content = '{\n"data": "complete data"\n}'
@@ -48,9 +52,9 @@ def mock_completion(request):
return completion
def test_openai_schema():
def test_openai_schema() -> None:
@openai_schema
class Dataframe(BaseModel):
class Dataframe(BaseModel): # type: ignore[misc]
"""
Class representing a dataframe. This class is used to convert
data into a frame that can be used by pandas.
@@ -59,16 +63,16 @@ def test_openai_schema():
data: str
columns: str
def to_pandas(self):
def to_pandas(self) -> None:
pass
assert hasattr(Dataframe, "openai_schema")
assert hasattr(Dataframe, "from_response")
assert hasattr(Dataframe, "to_pandas")
assert Dataframe.openai_schema["name"] == "Dataframe" # type: ignore
assert Dataframe.openai_schema["name"] == "Dataframe"
def test_openai_schema_raises_error():
def test_openai_schema_raises_error() -> None:
with pytest.raises(TypeError, match="must be a subclass of pydantic.BaseModel"):
@openai_schema
@@ -76,8 +80,8 @@ def test_openai_schema_raises_error():
pass
def test_no_docstring():
class Dummy(OpenAISchema):
def test_no_docstring() -> None:
class Dummy(OpenAISchema): # type: ignore[misc]
attr: str
assert (
@@ -90,25 +94,31 @@ def test_no_docstring():
"mock_completion",
[{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}],
indirect=True,
)
def test_incomplete_output_exception(test_model, mock_completion):
) # type: ignore[misc]
def test_incomplete_output_exception(
test_model: Type[OpenAISchema], mock_completion: ChatCompletion
) -> None:
with pytest.raises(IncompleteOutputException):
test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
def test_complete_output_no_exception(test_model, mock_completion):
def test_complete_output_no_exception(
test_model: Type[OpenAISchema], mock_completion: ChatCompletion
) -> None:
test_model_instance = test_model.from_response(
mock_completion, mode=instructor.Mode.FUNCTIONS
)
assert test_model_instance.data == "complete data"
@pytest.mark.asyncio
@pytest.mark.asyncio # type: ignore[misc]
@pytest.mark.parametrize(
"mock_completion",
[{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}],
indirect=True,
)
async def test_incomplete_output_exception_raise(test_model, mock_completion):
) # type: ignore[misc]
async def test_incomplete_output_exception_raise(
test_model: Type[OpenAISchema], mock_completion: ChatCompletion
) -> None:
with pytest.raises(IncompleteOutputException):
await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)