From 75908132a03174eecfd1ee606c60ece709da49e9 Mon Sep 17 00:00:00 2001 From: Ezzeri Esa Date: Sun, 4 Feb 2024 18:24:45 -0800 Subject: [PATCH] Include types to instructor.function_calls and tests (#394) Co-authored-by: Jason Liu --- .github/workflows/mypy.yml | 3 +++ instructor/function_calls.py | 22 +++++++++-------- tests/test_function_calls.py | 46 ++++++++++++++++++++++-------------- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 48fe24d..6775ac8 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -13,8 +13,11 @@ env: instructor/_types/_alias.py instructor/cli/cli.py instructor/cli/files.py + instructor/cli/jobs.py instructor/cli/usage.py instructor/exceptions.py + instructor/function_calls.py + tests/test_function_calls.py jobs: MyPy: diff --git a/instructor/function_calls.py b/instructor/function_calls.py index 5b3ba1d..f24f9a5 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -1,4 +1,4 @@ -from typing import Type, TypeVar +from typing import Any, Dict, Optional, Type, TypeVar from docstring_parser import parse from functools import wraps from pydantic import BaseModel, create_model @@ -19,7 +19,7 @@ class Mode(enum.Enum): MD_JSON: str = "markdown_json_mode" JSON_SCHEMA: str = "json_schema_mode" - def __new__(cls, value): + def __new__(cls, value: str) -> "Mode": member = object.__new__(cls) member._value_ = value @@ -34,10 +34,10 @@ class Mode(enum.Enum): return member -class OpenAISchema(BaseModel): - @classmethod +class OpenAISchema(BaseModel): # type: ignore[misc] + @classmethod # type: ignore[misc] @property - def openai_schema(cls): + def openai_schema(cls) -> Dict[str, Any]: """ Return the schema in the format of OpenAI's schema as jsonschema @@ -82,10 +82,10 @@ class OpenAISchema(BaseModel): def from_response( cls, completion: T, - validation_context: dict = None, - strict: bool = None, + validation_context: Optional[Dict[str, Any]] = None, + strict: Optional[bool] = None, mode: Mode = Mode.TOOLS, - ): + ) -> Dict[str, Any]: """Execute the function from the response of an openai chat completion Parameters: @@ -98,6 +98,8 @@ class OpenAISchema(BaseModel): Returns: cls (OpenAISchema): An instance of the class """ + assert hasattr(completion, "choices") + if completion.choices[0].finish_reason == "length": raise IncompleteOutputException() @@ -105,7 +107,7 @@ class OpenAISchema(BaseModel): if mode == Mode.FUNCTIONS: assert ( - message.function_call.name == cls.openai_schema["name"] + message.function_call.name == cls.openai_schema["name"] # type: ignore[index] ), "Function name does not match" return cls.model_validate_json( message.function_call.arguments, @@ -118,7 +120,7 @@ class OpenAISchema(BaseModel): ), "Instructor does not support multiple tool calls, use List[Model] instead." tool_call = message.tool_calls[0] assert ( - tool_call.function.name == cls.openai_schema["name"] + tool_call.function.name == cls.openai_schema["name"] # type: ignore[index] ), "Tool name does not match" return cls.model_validate_json( tool_call.function.arguments, diff --git a/tests/test_function_calls.py b/tests/test_function_calls.py index 9a0db9e..c4fb9ea 100644 --- a/tests/test_function_calls.py +++ b/tests/test_function_calls.py @@ -1,3 +1,4 @@ +from typing import Type, TypeVar import pytest from pydantic import BaseModel from openai.resources.chat.completions import ChatCompletion @@ -7,17 +8,20 @@ import instructor from instructor.exceptions import IncompleteOutputException -@pytest.fixture -def test_model(): - class TestModel(OpenAISchema): +T = TypeVar("T") + + +@pytest.fixture # type: ignore[misc] +def test_model() -> Type[OpenAISchema]: + class TestModel(OpenAISchema): # type: ignore[misc] name: str = "TestModel" data: str return TestModel -@pytest.fixture -def mock_completion(request): +@pytest.fixture # type: ignore[misc] +def mock_completion(request: T) -> ChatCompletion: finish_reason = "stop" data_content = '{\n"data": "complete data"\n}' @@ -48,9 +52,9 @@ def mock_completion(request): return completion -def test_openai_schema(): +def test_openai_schema() -> None: @openai_schema - class Dataframe(BaseModel): + class Dataframe(BaseModel): # type: ignore[misc] """ Class representing a dataframe. This class is used to convert data into a frame that can be used by pandas. @@ -59,16 +63,16 @@ def test_openai_schema(): data: str columns: str - def to_pandas(self): + def to_pandas(self) -> None: pass assert hasattr(Dataframe, "openai_schema") assert hasattr(Dataframe, "from_response") assert hasattr(Dataframe, "to_pandas") - assert Dataframe.openai_schema["name"] == "Dataframe" # type: ignore + assert Dataframe.openai_schema["name"] == "Dataframe" -def test_openai_schema_raises_error(): +def test_openai_schema_raises_error() -> None: with pytest.raises(TypeError, match="must be a subclass of pydantic.BaseModel"): @openai_schema @@ -76,8 +80,8 @@ def test_openai_schema_raises_error(): pass -def test_no_docstring(): - class Dummy(OpenAISchema): +def test_no_docstring() -> None: + class Dummy(OpenAISchema): # type: ignore[misc] attr: str assert ( @@ -90,25 +94,31 @@ def test_no_docstring(): "mock_completion", [{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}], indirect=True, -) -def test_incomplete_output_exception(test_model, mock_completion): +) # type: ignore[misc] +def test_incomplete_output_exception( + test_model: Type[OpenAISchema], mock_completion: ChatCompletion +) -> None: with pytest.raises(IncompleteOutputException): test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS) -def test_complete_output_no_exception(test_model, mock_completion): +def test_complete_output_no_exception( + test_model: Type[OpenAISchema], mock_completion: ChatCompletion +) -> None: test_model_instance = test_model.from_response( mock_completion, mode=instructor.Mode.FUNCTIONS ) assert test_model_instance.data == "complete data" -@pytest.mark.asyncio +@pytest.mark.asyncio # type: ignore[misc] @pytest.mark.parametrize( "mock_completion", [{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}], indirect=True, -) -async def test_incomplete_output_exception_raise(test_model, mock_completion): +) # type: ignore[misc] +async def test_incomplete_output_exception_raise( + test_model: Type[OpenAISchema], mock_completion: ChatCompletion +) -> None: with pytest.raises(IncompleteOutputException): await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)