This commit is contained in:
Jason Liu
2023-10-07 14:43:19 -04:00
parent 9663daa4f2
commit 8383795b0e
3 changed files with 285 additions and 0 deletions
+4
View File
@@ -1,4 +1,5 @@
from .function_calls import OpenAISchema, openai_function, openai_schema
from .distil import FinetuneFormat, distil, track
from .dsl import MultiTask, Maybe, llm_validator, CitationMixin
from .patch import patch
@@ -11,4 +12,7 @@ __all__ = [
"openai_schema",
"patch",
"llm_validator",
"FinetuneFormat",
"distil",
"track",
]
+212
View File
@@ -0,0 +1,212 @@
import enum
import functools
import inspect
import json
from typing import Any, Callable, Optional
from pydantic import BaseModel, validate_call
import inspect
import logging
from instructor import openai_schema
distil_logger = logging.getLogger("instructor.distil")
def logging(level=logging.INFO, handler=None, log_to_file=True, filename_prefix=None):
"""
Configure the instructor module's logging.
:param level: Log level.
:param handler: Optional logging handler. If not provided, defaults to FileHandler or NullHandler based on log_to_file.
:param log_to_file: If True and no handler is provided, logs to a file.
:param filename: Optional filename for logging if log_to_file is True. Defaults to 'instructor.log'.
"""
distil_logger.setLevel(level)
# Clear existing handlers
for h in distil_logger.handlers[:]:
distil_logger.removeHandler(h)
if handler:
distil_logger.addHandler(handler)
elif log_to_file:
filename = filename_prefix or "instructor.log"
file_handler = logging.FileHandler(filename)
file_handler.setFormatter(logging.Formatter("%(message)s"))
distil_logger.addHandler(file_handler)
else:
distil_logger.addHandler(logging.NullHandler())
class FinetuneFormat(enum.Enum):
MESSAGES: str = "messages"
RAW: str = "raw"
def get_signature_from_fn(fn: Callable) -> str:
"""
Get the function signature as a string.
:Example:
>>> def my_function(a: int, b: int) -> int:
>>> return a + b
>>>
>>> get_signature_from_fn(my_function)
"def my_function(a: int, b: int) -> int"
:param fn: Function to get the signature for.
:return: Function signature as a string.
"""
sig = inspect.signature(fn)
lines = f"def {fn.__name__}{sig}"
docstring = inspect.getdoc(fn)
if docstring:
formatted_docstring = f'"""\n{docstring}\n"""'
else:
formatted_docstring = ""
return f"{lines}\n{formatted_docstring}"
@functools.lru_cache()
def format_function(func: Callable) -> str:
"""
Format a function as a string with docstring and body.
"""
source_lines = inspect.getsourcelines(func)
definition = " ".join(source_lines[0]).strip()
docstring = inspect.getdoc(func)
if docstring:
formatted_docstring = f'"""\n{docstring}\n"""'
else:
formatted_docstring = ""
body = inspect.getsource(func)
body = body.replace(f"def {func.__name__}", "")
return f"{definition}\n{formatted_docstring}\n{body}"
def is_return_type_base_model_or_instance(func: Callable[..., Any]) -> bool:
"""
Check if the return type of a function is a pydantic BaseModel or an instance of it.
:param func: Function to check.
:return: True if the return type is a pydantic BaseModel or an instance of it.
"""
return_type = inspect.signature(func).return_annotation
return inspect.isclass(return_type) and issubclass(return_type, BaseModel)
@validate_call
def track(
fn: Callable[..., Any],
args: tuple,
kwargs: dict,
resp: BaseModel,
name: Optional[str] = None,
finetune_format: FinetuneFormat = FinetuneFormat.RAW,
):
"""
Track the function call and response in a log file, later used for finetuning.
:param fn: Function to track.
:param args: Arguments passed to the function.
:param kwargs: Keyword arguments passed to the function.
:param resp: Response returned by the function.
:param name: Name of the function to track. Defaults to the function name.
:param finetune_format: Format to use for finetuning. Defaults to "raw".
"""
name = name if name else fn.__name__
base_model: BaseModel = type(resp)
if finetune_format == FinetuneFormat.RAW:
function_body = dict(
fn_name=name,
fn_repr=format_function(fn),
args=args,
kwargs=kwargs,
resp=resp.model_dump(),
schema=base_model.model_json_schema(),
)
distil_logger.info(json.dumps(function_body))
return
if finetune_format == FinetuneFormat.MESSAGES:
# This is the format that OpenAI's API expects for a finetune call
openai_function_call = openai_schema(base_model).openai_schema
function_definition = get_signature_from_fn(fn)
function_body = {
"messages": [
{
"role": "system",
"content": f"Return the response from the function call.\n\n {function_definition}",
},
{
"role": "user",
"content": f"Return the results of the function with the following arguments:\n\n {name}(*{args}, **{kwargs})",
},
{
"role": "function",
"function_call": {
"name": openai_function_call["name"],
"augments": resp.model_dump(),
},
},
],
"functions": [openai_function_call],
"function_call": {"name": name},
}
distil_logger.info(json.dumps(function_body))
return
raise ValueError(f"Invalid finetune format: {finetune_format}")
def distil(
*args,
name: str = None,
mode: str = "distil",
fine_tune_format: FinetuneFormat = FinetuneFormat.RAW,
):
"""
Decorator to track the function call and response, supports distillation and dispatch modes.
If used without arguments, it must be used as a decorator.
:Example:
>>> @distil
>>> def my_function() -> MyModel:
>>> return MyModel()
>>>
>>> @distil(name="my_function")
>>> def my_function() -> MyModel:
>>> return MyModel()
:param fn: Function to track.
:param name: Name of the function to track. Defaults to the function name.
:param mode: Mode to use for distillation. Defaults to "distil".
"""
allowed_modes = {"distil", "dispatch"}
assert mode in allowed_modes, f"Must be in {allowed_modes}"
assert mode == "distil", "Only distil mode is supported at the moment."
def _wrap_distil(fn):
msg = f"Return type hint for {fn} must subclass `pydantic.BaseModel'"
assert is_return_type_base_model_or_instance(fn), msg
@functools.wraps(fn)
def _distil(*args, **kwargs):
resp = fn(*args, **kwargs)
track(fn, args, kwargs, resp, name=name, finetune_format=fine_tune_format)
return resp
return _distil
if len(args) == 1 and callable(args[0]):
return _wrap_distil(args[0])
return _wrap_distil
+69
View File
@@ -0,0 +1,69 @@
from pydantic import BaseModel
from instructor.distil import (
distil,
format_function,
get_signature_from_fn,
is_return_type_base_model_or_instance,
)
# Replace `your_module_name` with your actual module name
class SimpleModel(BaseModel):
data: int
def test_is_return_type_base_model_or_instance():
def valid_function() -> SimpleModel:
return SimpleModel(data=1)
def invalid_function() -> int:
return 1
assert is_return_type_base_model_or_instance(valid_function)
assert not is_return_type_base_model_or_instance(invalid_function)
def test_get_signature_from_fn():
def test_function(a: int, b: str) -> float:
"""Sample docstring"""
pass
result = get_signature_from_fn(test_function)
expected = "def test_function(a: int, b: str) -> float"
assert expected in result
assert "Sample docstring" in result
def test_format_function():
def sample_function(x: int) -> SimpleModel:
"""This is a docstring."""
return SimpleModel(data=x)
formatted = format_function(sample_function)
assert "def sample_function(x: int) -> SimpleModel:" in formatted
assert '"""This is a docstring."""' in formatted
assert "return SimpleModel(data=x)" in formatted
def test_distil_decorator_without_arguments():
@distil
def test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)
result = test_func(42)
assert result.data == 42
def test_distil_decorator_with_name_argument():
@distil(name="custom_name")
def another_test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)
result = another_test_func(55)
assert result.data == 55
# Mock track function for decorator tests
def mock_track(*args, **kwargs):
pass