mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
chore: Include types to instructor.distil and tests (#396)
This commit is contained in:
@@ -16,8 +16,10 @@ env:
|
||||
instructor/cli/jobs.py
|
||||
instructor/cli/usage.py
|
||||
instructor/exceptions.py
|
||||
instructor/distil.py
|
||||
instructor/function_calls.py
|
||||
tests/test_function_calls.py
|
||||
tests/test_distil.py
|
||||
|
||||
jobs:
|
||||
MyPy:
|
||||
|
||||
+37
-25
@@ -5,19 +5,22 @@ import logging
|
||||
import inspect
|
||||
import functools
|
||||
|
||||
from typing import Any, Callable, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar
|
||||
from pydantic import BaseModel, validate_call
|
||||
|
||||
from openai import OpenAI
|
||||
from instructor.function_calls import openai_schema
|
||||
|
||||
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
|
||||
|
||||
class FinetuneFormat(enum.Enum):
|
||||
MESSAGES: str = "messages"
|
||||
RAW: str = "raw"
|
||||
|
||||
|
||||
def get_signature_from_fn(fn: Callable) -> str:
|
||||
def get_signature_from_fn(fn: Callable[..., Any]) -> str:
|
||||
"""
|
||||
Get the function signature as a string.
|
||||
|
||||
@@ -43,7 +46,7 @@ def get_signature_from_fn(fn: Callable) -> str:
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def format_function(func: Callable) -> str:
|
||||
def format_function(func: Callable[..., Any]) -> str:
|
||||
"""
|
||||
Format a function as a string with docstring and body.
|
||||
"""
|
||||
@@ -79,14 +82,14 @@ def is_return_type_base_model_or_instance(func: Callable[..., Any]) -> bool:
|
||||
class Instructions:
|
||||
def __init__(
|
||||
self,
|
||||
name: str = None,
|
||||
id: str = None,
|
||||
log_handlers: List[logging.Handler] = None,
|
||||
name: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
log_handlers: Optional[List[logging.Handler]] = None,
|
||||
finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
|
||||
indent: int = 2,
|
||||
include_code_body: bool = False,
|
||||
openai_client: OpenAI = None,
|
||||
):
|
||||
openai_client: Optional[OpenAI] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Instructions for distillation and dispatch.
|
||||
|
||||
@@ -111,12 +114,15 @@ class Instructions:
|
||||
|
||||
def distil(
|
||||
self,
|
||||
*args,
|
||||
name: str = None,
|
||||
*args: Any,
|
||||
name: Optional[str] = None,
|
||||
mode: str = "distil",
|
||||
model: str = "gpt-3.5-turbo",
|
||||
fine_tune_format: FinetuneFormat = None,
|
||||
):
|
||||
fine_tune_format: Optional[FinetuneFormat] = None,
|
||||
) -> Callable[
|
||||
[Callable[..., Any]],
|
||||
Callable[[Callable[..., T_Retval]], Callable[..., T_Retval]],
|
||||
]:
|
||||
"""
|
||||
Decorator to track the function call and response, supports distillation and dispatch modes.
|
||||
|
||||
@@ -142,13 +148,16 @@ class Instructions:
|
||||
if fine_tune_format is None:
|
||||
fine_tune_format = self.finetune_format
|
||||
|
||||
def _wrap_distil(fn):
|
||||
def _wrap_distil(
|
||||
fn: Callable[..., Any],
|
||||
) -> Callable[[Callable[..., T_Retval]], Callable[..., T_Retval]]:
|
||||
msg = f"Return type hint for {fn} must subclass `pydantic.BaseModel'"
|
||||
assert is_return_type_base_model_or_instance(fn), msg
|
||||
return_base_model = inspect.signature(fn).return_annotation
|
||||
|
||||
@functools.wraps(fn)
|
||||
def _dispatch(*args, **kwargs):
|
||||
def _dispatch(*args: Any, **kwargs: Any) -> Callable[..., T_Retval]:
|
||||
name = name if name else fn.__name__
|
||||
openai_kwargs = self.openai_kwargs(
|
||||
name=name,
|
||||
fn=fn,
|
||||
@@ -161,7 +170,7 @@ class Instructions:
|
||||
)
|
||||
|
||||
@functools.wraps(fn)
|
||||
def _distil(*args, **kwargs):
|
||||
def _distil(*args: Any, **kwargs: Any) -> Callable[..., T_Retval]:
|
||||
resp = fn(*args, **kwargs)
|
||||
self.track(
|
||||
fn, args, kwargs, resp, name=name, finetune_format=fine_tune_format
|
||||
@@ -169,27 +178,23 @@ class Instructions:
|
||||
|
||||
return resp
|
||||
|
||||
if mode == "dispatch":
|
||||
return _dispatch
|
||||
|
||||
if mode == "distil":
|
||||
return _distil
|
||||
return _dispatch if mode == "dispatch" else _distil
|
||||
|
||||
if len(args) == 1 and callable(args[0]):
|
||||
return _wrap_distil(args[0])
|
||||
|
||||
return _wrap_distil
|
||||
|
||||
@validate_call
|
||||
@validate_call # type: ignore[misc]
|
||||
def track(
|
||||
self,
|
||||
fn: Callable[..., Any],
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
resp: BaseModel,
|
||||
name: Optional[str] = None,
|
||||
finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Track the function call and response in a log file, later used for finetuning.
|
||||
|
||||
@@ -229,7 +234,14 @@ class Instructions:
|
||||
)
|
||||
self.logger.info(json.dumps(function_body))
|
||||
|
||||
def openai_kwargs(self, name, fn, args, kwargs, base_model):
|
||||
def openai_kwargs(
|
||||
self,
|
||||
name: str,
|
||||
fn: Callable[..., Any],
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
base_model: Type[BaseModel],
|
||||
) -> Dict[str, Any]:
|
||||
if self.include_code_body:
|
||||
func_def = format_function(fn)
|
||||
else:
|
||||
|
||||
+17
-14
@@ -1,3 +1,4 @@
|
||||
from typing import Any, Dict, Callable, Tuple, cast
|
||||
import pytest
|
||||
import instructor
|
||||
|
||||
@@ -18,27 +19,27 @@ instructions = Instructions(
|
||||
)
|
||||
|
||||
|
||||
class SimpleModel(BaseModel):
|
||||
class SimpleModel(BaseModel): # type: ignore[misc]
|
||||
data: int
|
||||
|
||||
|
||||
def test_must_have_hint():
|
||||
def test_must_have_hint() -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
@instructions.distil
|
||||
def test_func(x: int):
|
||||
def test_func(x: int): # type: ignore[no-untyped-def]
|
||||
return SimpleModel(data=x)
|
||||
|
||||
|
||||
def test_must_be_base_model():
|
||||
def test_must_be_base_model() -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
@instructions.distil
|
||||
def test_func(x) -> int:
|
||||
def test_func(x: int) -> int:
|
||||
return SimpleModel(data=x)
|
||||
|
||||
|
||||
def test_is_return_type_base_model_or_instance():
|
||||
def test_is_return_type_base_model_or_instance() -> None:
|
||||
def valid_function() -> SimpleModel:
|
||||
return SimpleModel(data=1)
|
||||
|
||||
@@ -49,8 +50,8 @@ def test_is_return_type_base_model_or_instance():
|
||||
assert not is_return_type_base_model_or_instance(invalid_function)
|
||||
|
||||
|
||||
def test_get_signature_from_fn():
|
||||
def test_function(a: int, b: str) -> float:
|
||||
def test_get_signature_from_fn() -> None:
|
||||
def test_function(a: int, b: str) -> float: # type: ignore[empty-body]
|
||||
"""Sample docstring"""
|
||||
pass
|
||||
|
||||
@@ -60,7 +61,7 @@ def test_get_signature_from_fn():
|
||||
assert "Sample docstring" in result
|
||||
|
||||
|
||||
def test_format_function():
|
||||
def test_format_function() -> None:
|
||||
def sample_function(x: int) -> SimpleModel:
|
||||
"""This is a docstring."""
|
||||
return SimpleModel(data=x)
|
||||
@@ -71,26 +72,28 @@ def test_format_function():
|
||||
assert "return SimpleModel(data=x)" in formatted
|
||||
|
||||
|
||||
def test_distil_decorator_without_arguments():
|
||||
def test_distil_decorator_without_arguments() -> None:
|
||||
@instructions.distil
|
||||
def test_func(x: int) -> SimpleModel:
|
||||
return SimpleModel(data=x)
|
||||
|
||||
result = test_func(42)
|
||||
casted_test_func = cast(Callable[[int], SimpleModel], test_func)
|
||||
result: SimpleModel = casted_test_func(42)
|
||||
assert result.data == 42
|
||||
|
||||
|
||||
def test_distil_decorator_with_name_argument():
|
||||
def test_distil_decorator_with_name_argument() -> None:
|
||||
@instructions.distil(name="custom_name")
|
||||
def another_test_func(x: int) -> SimpleModel:
|
||||
return SimpleModel(data=x)
|
||||
|
||||
result = another_test_func(55)
|
||||
casted_another_test_func = cast(Callable[[int], SimpleModel], another_test_func)
|
||||
result: SimpleModel = casted_another_test_func(55)
|
||||
assert result.data == 55
|
||||
|
||||
|
||||
# Mock track function for decorator tests
|
||||
def mock_track(*args, **kwargs):
|
||||
def mock_track(*args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user