Files

104 lines
2.8 KiB
Python

from typing import Any, Dict, Callable, Tuple, cast
import pytest
import instructor
from openai import OpenAI
from pydantic import BaseModel
from instructor.distil import (
Instructions,
format_function,
get_signature_from_fn,
is_return_type_base_model_or_instance,
)
client = instructor.patch(OpenAI())
instructions = Instructions(
name="test_distil",
)
class SimpleModel(BaseModel): # type: ignore[misc]
data: int
def test_must_have_hint() -> None:
with pytest.raises(AssertionError):
@instructions.distil
def test_func(x: int): # type: ignore[no-untyped-def]
return SimpleModel(data=x)
def test_must_be_base_model() -> None:
with pytest.raises(AssertionError):
@instructions.distil
def test_func(x: int) -> int:
return SimpleModel(data=x)
def test_is_return_type_base_model_or_instance() -> None:
def valid_function() -> SimpleModel:
return SimpleModel(data=1)
def invalid_function() -> int:
return 1
assert is_return_type_base_model_or_instance(valid_function)
assert not is_return_type_base_model_or_instance(invalid_function)
def test_get_signature_from_fn() -> None:
def test_function(a: int, b: str) -> float: # type: ignore[empty-body]
"""Sample docstring"""
pass
result = get_signature_from_fn(test_function)
expected = "def test_function(a: int, b: str) -> float"
assert expected in result
assert "Sample docstring" in result
def test_format_function() -> None:
def sample_function(x: int) -> SimpleModel:
"""This is a docstring."""
return SimpleModel(data=x)
formatted = format_function(sample_function)
assert "def sample_function(x: int) -> SimpleModel:" in formatted
assert '"""This is a docstring."""' in formatted
assert "return SimpleModel(data=x)" in formatted
def test_distil_decorator_without_arguments() -> None:
@instructions.distil
def test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)
casted_test_func = cast(Callable[[int], SimpleModel], test_func)
result: SimpleModel = casted_test_func(42)
assert result.data == 42
def test_distil_decorator_with_name_argument() -> None:
@instructions.distil(name="custom_name")
def another_test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)
casted_another_test_func = cast(Callable[[int], SimpleModel], another_test_func)
result: SimpleModel = casted_another_test_func(55)
assert result.data == 55
# Mock track function for decorator tests
def mock_track(*args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> None:
pass
def fn(a: int, b: int) -> int:
return client.chat.completions.create(
messages=[], model="davinci", response_model=SimpleModel
)