mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Include types to instructor.function_calls and tests (#394)
Co-authored-by: Jason Liu <jxnl@users.noreply.github.com>
This commit is contained in:
@@ -13,8 +13,11 @@ env:
|
||||
instructor/_types/_alias.py
|
||||
instructor/cli/cli.py
|
||||
instructor/cli/files.py
|
||||
instructor/cli/jobs.py
|
||||
instructor/cli/usage.py
|
||||
instructor/exceptions.py
|
||||
instructor/function_calls.py
|
||||
tests/test_function_calls.py
|
||||
|
||||
jobs:
|
||||
MyPy:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Type, TypeVar
|
||||
from typing import Any, Dict, Optional, Type, TypeVar
|
||||
from docstring_parser import parse
|
||||
from functools import wraps
|
||||
from pydantic import BaseModel, create_model
|
||||
@@ -19,7 +19,7 @@ class Mode(enum.Enum):
|
||||
MD_JSON: str = "markdown_json_mode"
|
||||
JSON_SCHEMA: str = "json_schema_mode"
|
||||
|
||||
def __new__(cls, value):
|
||||
def __new__(cls, value: str) -> "Mode":
|
||||
member = object.__new__(cls)
|
||||
member._value_ = value
|
||||
|
||||
@@ -34,10 +34,10 @@ class Mode(enum.Enum):
|
||||
return member
|
||||
|
||||
|
||||
class OpenAISchema(BaseModel):
|
||||
@classmethod
|
||||
class OpenAISchema(BaseModel): # type: ignore[misc]
|
||||
@classmethod # type: ignore[misc]
|
||||
@property
|
||||
def openai_schema(cls):
|
||||
def openai_schema(cls) -> Dict[str, Any]:
|
||||
"""
|
||||
Return the schema in the format of OpenAI's schema as jsonschema
|
||||
|
||||
@@ -82,10 +82,10 @@ class OpenAISchema(BaseModel):
|
||||
def from_response(
|
||||
cls,
|
||||
completion: T,
|
||||
validation_context: dict = None,
|
||||
strict: bool = None,
|
||||
validation_context: Optional[Dict[str, Any]] = None,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
):
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute the function from the response of an openai chat completion
|
||||
|
||||
Parameters:
|
||||
@@ -98,6 +98,8 @@ class OpenAISchema(BaseModel):
|
||||
Returns:
|
||||
cls (OpenAISchema): An instance of the class
|
||||
"""
|
||||
assert hasattr(completion, "choices")
|
||||
|
||||
if completion.choices[0].finish_reason == "length":
|
||||
raise IncompleteOutputException()
|
||||
|
||||
@@ -105,7 +107,7 @@ class OpenAISchema(BaseModel):
|
||||
|
||||
if mode == Mode.FUNCTIONS:
|
||||
assert (
|
||||
message.function_call.name == cls.openai_schema["name"]
|
||||
message.function_call.name == cls.openai_schema["name"] # type: ignore[index]
|
||||
), "Function name does not match"
|
||||
return cls.model_validate_json(
|
||||
message.function_call.arguments,
|
||||
@@ -118,7 +120,7 @@ class OpenAISchema(BaseModel):
|
||||
), "Instructor does not support multiple tool calls, use List[Model] instead."
|
||||
tool_call = message.tool_calls[0]
|
||||
assert (
|
||||
tool_call.function.name == cls.openai_schema["name"]
|
||||
tool_call.function.name == cls.openai_schema["name"] # type: ignore[index]
|
||||
), "Tool name does not match"
|
||||
return cls.model_validate_json(
|
||||
tool_call.function.arguments,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Type, TypeVar
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from openai.resources.chat.completions import ChatCompletion
|
||||
@@ -7,17 +8,20 @@ import instructor
|
||||
from instructor.exceptions import IncompleteOutputException
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_model():
|
||||
class TestModel(OpenAISchema):
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@pytest.fixture # type: ignore[misc]
|
||||
def test_model() -> Type[OpenAISchema]:
|
||||
class TestModel(OpenAISchema): # type: ignore[misc]
|
||||
name: str = "TestModel"
|
||||
data: str
|
||||
|
||||
return TestModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_completion(request):
|
||||
@pytest.fixture # type: ignore[misc]
|
||||
def mock_completion(request: T) -> ChatCompletion:
|
||||
finish_reason = "stop"
|
||||
data_content = '{\n"data": "complete data"\n}'
|
||||
|
||||
@@ -48,9 +52,9 @@ def mock_completion(request):
|
||||
return completion
|
||||
|
||||
|
||||
def test_openai_schema():
|
||||
def test_openai_schema() -> None:
|
||||
@openai_schema
|
||||
class Dataframe(BaseModel):
|
||||
class Dataframe(BaseModel): # type: ignore[misc]
|
||||
"""
|
||||
Class representing a dataframe. This class is used to convert
|
||||
data into a frame that can be used by pandas.
|
||||
@@ -59,16 +63,16 @@ def test_openai_schema():
|
||||
data: str
|
||||
columns: str
|
||||
|
||||
def to_pandas(self):
|
||||
def to_pandas(self) -> None:
|
||||
pass
|
||||
|
||||
assert hasattr(Dataframe, "openai_schema")
|
||||
assert hasattr(Dataframe, "from_response")
|
||||
assert hasattr(Dataframe, "to_pandas")
|
||||
assert Dataframe.openai_schema["name"] == "Dataframe" # type: ignore
|
||||
assert Dataframe.openai_schema["name"] == "Dataframe"
|
||||
|
||||
|
||||
def test_openai_schema_raises_error():
|
||||
def test_openai_schema_raises_error() -> None:
|
||||
with pytest.raises(TypeError, match="must be a subclass of pydantic.BaseModel"):
|
||||
|
||||
@openai_schema
|
||||
@@ -76,8 +80,8 @@ def test_openai_schema_raises_error():
|
||||
pass
|
||||
|
||||
|
||||
def test_no_docstring():
|
||||
class Dummy(OpenAISchema):
|
||||
def test_no_docstring() -> None:
|
||||
class Dummy(OpenAISchema): # type: ignore[misc]
|
||||
attr: str
|
||||
|
||||
assert (
|
||||
@@ -90,25 +94,31 @@ def test_no_docstring():
|
||||
"mock_completion",
|
||||
[{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_incomplete_output_exception(test_model, mock_completion):
|
||||
) # type: ignore[misc]
|
||||
def test_incomplete_output_exception(
|
||||
test_model: Type[OpenAISchema], mock_completion: ChatCompletion
|
||||
) -> None:
|
||||
with pytest.raises(IncompleteOutputException):
|
||||
test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
|
||||
|
||||
|
||||
def test_complete_output_no_exception(test_model, mock_completion):
|
||||
def test_complete_output_no_exception(
|
||||
test_model: Type[OpenAISchema], mock_completion: ChatCompletion
|
||||
) -> None:
|
||||
test_model_instance = test_model.from_response(
|
||||
mock_completion, mode=instructor.Mode.FUNCTIONS
|
||||
)
|
||||
assert test_model_instance.data == "complete data"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio # type: ignore[misc]
|
||||
@pytest.mark.parametrize(
|
||||
"mock_completion",
|
||||
[{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_incomplete_output_exception_raise(test_model, mock_completion):
|
||||
) # type: ignore[misc]
|
||||
async def test_incomplete_output_exception_raise(
|
||||
test_model: Type[OpenAISchema], mock_completion: ChatCompletion
|
||||
) -> None:
|
||||
with pytest.raises(IncompleteOutputException):
|
||||
await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
|
||||
|
||||
Reference in New Issue
Block a user