chore: Include types to instructor.distil and tests (#396)

This commit is contained in:
Ezzeri Esa
2024-02-08 11:52:10 -08:00
committed by GitHub
parent f65ba6b407
commit d6326824ab
3 changed files with 56 additions and 39 deletions
+2
View File
@@ -16,8 +16,10 @@ env:
instructor/cli/jobs.py
instructor/cli/usage.py
instructor/exceptions.py
instructor/distil.py
instructor/function_calls.py
tests/test_function_calls.py
tests/test_distil.py
jobs:
MyPy:
+37 -25
View File
@@ -5,19 +5,22 @@ import logging
import inspect
import functools
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar
from pydantic import BaseModel, validate_call
from openai import OpenAI
from instructor.function_calls import openai_schema
T_Retval = TypeVar("T_Retval")
class FinetuneFormat(enum.Enum):
MESSAGES: str = "messages"
RAW: str = "raw"
def get_signature_from_fn(fn: Callable) -> str:
def get_signature_from_fn(fn: Callable[..., Any]) -> str:
"""
Get the function signature as a string.
@@ -43,7 +46,7 @@ def get_signature_from_fn(fn: Callable) -> str:
@functools.lru_cache()
def format_function(func: Callable) -> str:
def format_function(func: Callable[..., Any]) -> str:
"""
Format a function as a string with docstring and body.
"""
@@ -79,14 +82,14 @@ def is_return_type_base_model_or_instance(func: Callable[..., Any]) -> bool:
class Instructions:
def __init__(
self,
name: str = None,
id: str = None,
log_handlers: List[logging.Handler] = None,
name: Optional[str] = None,
id: Optional[str] = None,
log_handlers: Optional[List[logging.Handler]] = None,
finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
indent: int = 2,
include_code_body: bool = False,
openai_client: OpenAI = None,
):
openai_client: Optional[OpenAI] = None,
) -> None:
"""
Instructions for distillation and dispatch.
@@ -111,12 +114,15 @@ class Instructions:
def distil(
self,
*args,
name: str = None,
*args: Any,
name: Optional[str] = None,
mode: str = "distil",
model: str = "gpt-3.5-turbo",
fine_tune_format: FinetuneFormat = None,
):
fine_tune_format: Optional[FinetuneFormat] = None,
) -> Callable[
[Callable[..., Any]],
Callable[[Callable[..., T_Retval]], Callable[..., T_Retval]],
]:
"""
Decorator to track the function call and response, supports distillation and dispatch modes.
@@ -142,13 +148,16 @@ class Instructions:
if fine_tune_format is None:
fine_tune_format = self.finetune_format
def _wrap_distil(fn):
def _wrap_distil(
fn: Callable[..., Any],
) -> Callable[[Callable[..., T_Retval]], Callable[..., T_Retval]]:
msg = f"Return type hint for {fn} must subclass `pydantic.BaseModel'"
assert is_return_type_base_model_or_instance(fn), msg
return_base_model = inspect.signature(fn).return_annotation
@functools.wraps(fn)
def _dispatch(*args, **kwargs):
def _dispatch(*args: Any, **kwargs: Any) -> Callable[..., T_Retval]:
name = name if name else fn.__name__
openai_kwargs = self.openai_kwargs(
name=name,
fn=fn,
@@ -161,7 +170,7 @@ class Instructions:
)
@functools.wraps(fn)
def _distil(*args, **kwargs):
def _distil(*args: Any, **kwargs: Any) -> Callable[..., T_Retval]:
resp = fn(*args, **kwargs)
self.track(
fn, args, kwargs, resp, name=name, finetune_format=fine_tune_format
@@ -169,27 +178,23 @@ class Instructions:
return resp
if mode == "dispatch":
return _dispatch
if mode == "distil":
return _distil
return _dispatch if mode == "dispatch" else _distil
if len(args) == 1 and callable(args[0]):
return _wrap_distil(args[0])
return _wrap_distil
@validate_call
@validate_call # type: ignore[misc]
def track(
self,
fn: Callable[..., Any],
args: tuple,
kwargs: dict,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
resp: BaseModel,
name: Optional[str] = None,
finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
):
) -> None:
"""
Track the function call and response in a log file, later used for finetuning.
@@ -229,7 +234,14 @@ class Instructions:
)
self.logger.info(json.dumps(function_body))
def openai_kwargs(self, name, fn, args, kwargs, base_model):
def openai_kwargs(
self,
name: str,
fn: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
base_model: Type[BaseModel],
) -> Dict[str, Any]:
if self.include_code_body:
func_def = format_function(fn)
else:
+17 -14
View File
@@ -1,3 +1,4 @@
from typing import Any, Dict, Callable, Tuple, cast
import pytest
import instructor
@@ -18,27 +19,27 @@ instructions = Instructions(
)
class SimpleModel(BaseModel):
class SimpleModel(BaseModel): # type: ignore[misc]
data: int
def test_must_have_hint():
def test_must_have_hint() -> None:
with pytest.raises(AssertionError):
@instructions.distil
def test_func(x: int):
def test_func(x: int): # type: ignore[no-untyped-def]
return SimpleModel(data=x)
def test_must_be_base_model():
def test_must_be_base_model() -> None:
with pytest.raises(AssertionError):
@instructions.distil
def test_func(x) -> int:
def test_func(x: int) -> int:
return SimpleModel(data=x)
def test_is_return_type_base_model_or_instance():
def test_is_return_type_base_model_or_instance() -> None:
def valid_function() -> SimpleModel:
return SimpleModel(data=1)
@@ -49,8 +50,8 @@ def test_is_return_type_base_model_or_instance():
assert not is_return_type_base_model_or_instance(invalid_function)
def test_get_signature_from_fn():
def test_function(a: int, b: str) -> float:
def test_get_signature_from_fn() -> None:
def test_function(a: int, b: str) -> float: # type: ignore[empty-body]
"""Sample docstring"""
pass
@@ -60,7 +61,7 @@ def test_get_signature_from_fn():
assert "Sample docstring" in result
def test_format_function():
def test_format_function() -> None:
def sample_function(x: int) -> SimpleModel:
"""This is a docstring."""
return SimpleModel(data=x)
@@ -71,26 +72,28 @@ def test_format_function():
assert "return SimpleModel(data=x)" in formatted
def test_distil_decorator_without_arguments():
def test_distil_decorator_without_arguments() -> None:
@instructions.distil
def test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)
result = test_func(42)
casted_test_func = cast(Callable[[int], SimpleModel], test_func)
result: SimpleModel = casted_test_func(42)
assert result.data == 42
def test_distil_decorator_with_name_argument():
def test_distil_decorator_with_name_argument() -> None:
@instructions.distil(name="custom_name")
def another_test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)
result = another_test_func(55)
casted_another_test_func = cast(Callable[[int], SimpleModel], another_test_func)
result: SimpleModel = casted_another_test_func(55)
assert result.data == 55
# Mock track function for decorator tests
def mock_track(*args, **kwargs):
def mock_track(*args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> None:
pass