diff --git a/instructor/__init__.py b/instructor/__init__.py index 6a659e5..fdee478 100644 --- a/instructor/__init__.py +++ b/instructor/__init__.py @@ -1,4 +1,5 @@ from .function_calls import OpenAISchema, openai_function, openai_schema +from .distil import FinetuneFormat, distil, track from .dsl import MultiTask, Maybe, llm_validator, CitationMixin from .patch import patch @@ -11,4 +12,7 @@ __all__ = [ "openai_schema", "patch", "llm_validator", + "FinetuneFormat", + "distil", + "track", ] diff --git a/instructor/distil.py b/instructor/distil.py new file mode 100644 index 0000000..1160e09 --- /dev/null +++ b/instructor/distil.py @@ -0,0 +1,212 @@ +import enum +import functools +import inspect +import json + +from typing import Any, Callable, Optional +from pydantic import BaseModel, validate_call + +import inspect +import logging + +from instructor import openai_schema + +distil_logger = logging.getLogger("instructor.distil") + + +def logging(level=logging.INFO, handler=None, log_to_file=True, filename_prefix=None): + """ + Configure the instructor module's logging. + + :param level: Log level. + :param handler: Optional logging handler. If not provided, defaults to FileHandler or NullHandler based on log_to_file. + :param log_to_file: If True and no handler is provided, logs to a file. + :param filename: Optional filename for logging if log_to_file is True. Defaults to 'instructor.log'. + """ + distil_logger.setLevel(level) + + # Clear existing handlers + for h in distil_logger.handlers[:]: + distil_logger.removeHandler(h) + + if handler: + distil_logger.addHandler(handler) + elif log_to_file: + filename = filename_prefix or "instructor.log" + file_handler = logging.FileHandler(filename) + file_handler.setFormatter(logging.Formatter("%(message)s")) + distil_logger.addHandler(file_handler) + else: + distil_logger.addHandler(logging.NullHandler()) + + +class FinetuneFormat(enum.Enum): + MESSAGES: str = "messages" + RAW: str = "raw" + + +def get_signature_from_fn(fn: Callable) -> str: + """ + Get the function signature as a string. + + :Example: + + >>> def my_function(a: int, b: int) -> int: + >>> return a + b + >>> + >>> get_signature_from_fn(my_function) + "def my_function(a: int, b: int) -> int" + + :param fn: Function to get the signature for. + :return: Function signature as a string. + """ + sig = inspect.signature(fn) + lines = f"def {fn.__name__}{sig}" + docstring = inspect.getdoc(fn) + if docstring: + formatted_docstring = f'"""\n{docstring}\n"""' + else: + formatted_docstring = "" + return f"{lines}\n{formatted_docstring}" + + +@functools.lru_cache() +def format_function(func: Callable) -> str: + """ + Format a function as a string with docstring and body. + """ + source_lines = inspect.getsourcelines(func) + definition = " ".join(source_lines[0]).strip() + + docstring = inspect.getdoc(func) + if docstring: + formatted_docstring = f'"""\n{docstring}\n"""' + else: + formatted_docstring = "" + + body = inspect.getsource(func) + body = body.replace(f"def {func.__name__}", "") + + return f"{definition}\n{formatted_docstring}\n{body}" + + +def is_return_type_base_model_or_instance(func: Callable[..., Any]) -> bool: + """ + Check if the return type of a function is a pydantic BaseModel or an instance of it. + + :param func: Function to check. + :return: True if the return type is a pydantic BaseModel or an instance of it. + """ + return_type = inspect.signature(func).return_annotation + return inspect.isclass(return_type) and issubclass(return_type, BaseModel) + + +@validate_call +def track( + fn: Callable[..., Any], + args: tuple, + kwargs: dict, + resp: BaseModel, + name: Optional[str] = None, + finetune_format: FinetuneFormat = FinetuneFormat.RAW, +): + """ + Track the function call and response in a log file, later used for finetuning. + + :param fn: Function to track. + :param args: Arguments passed to the function. + :param kwargs: Keyword arguments passed to the function. + :param resp: Response returned by the function. + :param name: Name of the function to track. Defaults to the function name. + :param finetune_format: Format to use for finetuning. Defaults to "raw". + """ + name = name if name else fn.__name__ + base_model: BaseModel = type(resp) + + if finetune_format == FinetuneFormat.RAW: + function_body = dict( + fn_name=name, + fn_repr=format_function(fn), + args=args, + kwargs=kwargs, + resp=resp.model_dump(), + schema=base_model.model_json_schema(), + ) + distil_logger.info(json.dumps(function_body)) + return + + if finetune_format == FinetuneFormat.MESSAGES: + # This is the format that OpenAI's API expects for a finetune call + openai_function_call = openai_schema(base_model).openai_schema + function_definition = get_signature_from_fn(fn) + function_body = { + "messages": [ + { + "role": "system", + "content": f"Return the response from the function call.\n\n {function_definition}", + }, + { + "role": "user", + "content": f"Return the results of the function with the following arguments:\n\n {name}(*{args}, **{kwargs})", + }, + { + "role": "function", + "function_call": { + "name": openai_function_call["name"], + "augments": resp.model_dump(), + }, + }, + ], + "functions": [openai_function_call], + "function_call": {"name": name}, + } + distil_logger.info(json.dumps(function_body)) + return + raise ValueError(f"Invalid finetune format: {finetune_format}") + + +def distil( + *args, + name: str = None, + mode: str = "distil", + fine_tune_format: FinetuneFormat = FinetuneFormat.RAW, +): + """ + Decorator to track the function call and response, supports distillation and dispatch modes. + + If used without arguments, it must be used as a decorator. + + :Example: + + >>> @distil + >>> def my_function() -> MyModel: + >>> return MyModel() + >>> + >>> @distil(name="my_function") + >>> def my_function() -> MyModel: + >>> return MyModel() + + :param fn: Function to track. + :param name: Name of the function to track. Defaults to the function name. + :param mode: Mode to use for distillation. Defaults to "distil". + """ + allowed_modes = {"distil", "dispatch"} + assert mode in allowed_modes, f"Must be in {allowed_modes}" + assert mode == "distil", "Only distil mode is supported at the moment." + + def _wrap_distil(fn): + msg = f"Return type hint for {fn} must subclass `pydantic.BaseModel'" + assert is_return_type_base_model_or_instance(fn), msg + + @functools.wraps(fn) + def _distil(*args, **kwargs): + resp = fn(*args, **kwargs) + track(fn, args, kwargs, resp, name=name, finetune_format=fine_tune_format) + return resp + + return _distil + + if len(args) == 1 and callable(args[0]): + return _wrap_distil(args[0]) + + return _wrap_distil diff --git a/tests/test_distil.py b/tests/test_distil.py new file mode 100644 index 0000000..c07bd62 --- /dev/null +++ b/tests/test_distil.py @@ -0,0 +1,69 @@ +from pydantic import BaseModel +from instructor.distil import ( + distil, + format_function, + get_signature_from_fn, + is_return_type_base_model_or_instance, +) + +# Replace `your_module_name` with your actual module name + + +class SimpleModel(BaseModel): + data: int + + +def test_is_return_type_base_model_or_instance(): + def valid_function() -> SimpleModel: + return SimpleModel(data=1) + + def invalid_function() -> int: + return 1 + + assert is_return_type_base_model_or_instance(valid_function) + assert not is_return_type_base_model_or_instance(invalid_function) + + +def test_get_signature_from_fn(): + def test_function(a: int, b: str) -> float: + """Sample docstring""" + pass + + result = get_signature_from_fn(test_function) + expected = "def test_function(a: int, b: str) -> float" + assert expected in result + assert "Sample docstring" in result + + +def test_format_function(): + def sample_function(x: int) -> SimpleModel: + """This is a docstring.""" + return SimpleModel(data=x) + + formatted = format_function(sample_function) + assert "def sample_function(x: int) -> SimpleModel:" in formatted + assert '"""This is a docstring."""' in formatted + assert "return SimpleModel(data=x)" in formatted + + +def test_distil_decorator_without_arguments(): + @distil + def test_func(x: int) -> SimpleModel: + return SimpleModel(data=x) + + result = test_func(42) + assert result.data == 42 + + +def test_distil_decorator_with_name_argument(): + @distil(name="custom_name") + def another_test_func(x: int) -> SimpleModel: + return SimpleModel(data=x) + + result = another_test_func(55) + assert result.data == 55 + + +# Mock track function for decorator tests +def mock_track(*args, **kwargs): + pass