Files
instructor/instructor/distil.py
T

269 lines
8.8 KiB
Python

import enum
import json
import uuid
import logging
import inspect
import functools
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar
from pydantic import BaseModel, validate_call
from openai import OpenAI
from instructor.function_calls import openai_schema
# Type variable standing in for the return type of a decorated function in
# the decorator annotations below.
T_Retval = TypeVar("T_Retval")
class FinetuneFormat(enum.Enum):
    """Formats in which a tracked function call can be logged for finetuning.

    ``MESSAGES`` selects a chat-style record (messages plus a function call);
    ``RAW`` selects a plain record of the function, its arguments and its
    response.
    """

    # Members are deliberately left unannotated: an explicit ``: str``
    # annotation describes the raw value rather than the member, and type
    # checkers flag or mis-infer annotated enum members.
    MESSAGES = "messages"
    RAW = "raw"
def get_signature_from_fn(fn: Callable[..., Any]) -> str:
    """
    Render a function's ``def`` header, followed by its docstring, as text.

    The result is always ``def <name><signature>`` plus a newline. When the
    function has a docstring, the cleaned docstring follows, wrapped in
    triple double quotes on their own lines; otherwise nothing follows the
    newline.

    :Example:
    >>> def my_function(a: int, b: int) -> int:
    ...     return a + b
    >>> get_signature_from_fn(my_function)
    'def my_function(a: int, b: int) -> int\\n'

    :param fn: Function to describe.
    :return: The header line and optional quoted docstring.
    """
    signature = inspect.signature(fn)
    header = f"def {fn.__name__}{signature}"

    doc = inspect.getdoc(fn)
    quoted_doc = f'"""\n{doc}\n"""' if doc else ""

    return "\n".join((header, quoted_doc))
@functools.lru_cache()
def format_function(func: Callable[..., Any]) -> str:
    """
    Render a function as text: its source, its docstring, and its body.

    The three parts are joined with newlines, in this order:

    1. every source line of ``func`` joined with single spaces and stripped,
    2. the cleaned docstring wrapped in triple double quotes (empty string
       when there is no docstring),
    3. the source of ``func`` with each ``def <name>`` occurrence removed.

    Results are memoised per function object by ``functools.lru_cache``.

    :param func: Function whose source can be retrieved by ``inspect``.
    :return: The formatted text described above.
    """
    # NOTE(review): part 1 holds the whole function source, not only the
    # ``def`` line, so the body text appears twice in the result. Presumably
    # only the header was intended -- confirm before changing, since the
    # output feeds finetuning data.
    src_lines, _first_lineno = inspect.getsourcelines(func)
    definition = " ".join(src_lines).strip()

    doc = inspect.getdoc(func)
    quoted_doc = f'"""\n{doc}\n"""' if doc else ""

    body = inspect.getsource(func).replace(f"def {func.__name__}", "")

    return "\n".join([definition, quoted_doc, body])
def is_return_type_base_model_or_instance(func: Callable[..., Any]) -> bool:
    """
    Tell whether a function's return annotation is a pydantic ``BaseModel`` subclass.

    :param func: Function whose return annotation is inspected.
    :return: True when the annotation is a class that subclasses ``BaseModel``;
        False for any other annotation (including non-class annotations).
    :raises AssertionError: If ``func`` has no return annotation at all.
    """
    annotation = inspect.signature(func).return_annotation

    has_annotation = annotation != inspect.Signature.empty
    assert has_annotation, "Must have a return type hint that is a pydantic BaseModel"

    # ``issubclass`` only accepts classes, so rule out everything else first.
    if not inspect.isclass(annotation):
        return False
    return issubclass(annotation, BaseModel)
class Instructions:
    """
    Track calls to decorated functions for finetuning, or dispatch them to a model.

    Functions are registered with :meth:`distil`. In ``"distil"`` mode the
    function runs normally and each call is logged through ``self.logger``;
    in ``"dispatch"`` mode the function is not run and a chat completion is
    requested from ``self.client`` instead.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        id: Optional[str] = None,
        log_handlers: Optional[List[logging.Handler]] = None,
        finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
        indent: int = 2,
        include_code_body: bool = False,
        openai_client: Optional[OpenAI] = None,
    ) -> None:
        """
        Instructions for distillation and dispatch.

        :param name: Name of the instructions; also the name passed to
            ``logging.getLogger`` (``None`` selects the root logger).
        :param id: ID of the instructions. Defaults to a random UUID4 string.
        :param log_handlers: Handlers to attach to the logger.
        :param finetune_format: Default format for tracked calls when
            :meth:`distil` is not given one explicitly.
        :param indent: Indentation used when dumping the response to JSON in
            the ``MESSAGES`` format.
        :param include_code_body: Whether prompts contain the full function
            source instead of only its signature and docstring.
        :param openai_client: Client used in dispatch mode. Defaults to a new
            ``OpenAI()`` instance.
        """
        self.name = name
        self.id = id or str(uuid.uuid4())
        self.unique_id = str(uuid.uuid4())
        self.finetune_format = finetune_format
        self.indent = indent
        self.include_code_body = include_code_body
        self.client = openai_client or OpenAI()

        self.logger = logging.getLogger(self.name)
        for handler in log_handlers or []:
            self.logger.addHandler(handler)

    def distil(
        self,
        *args: Any,
        name: Optional[str] = None,
        mode: str = "distil",
        model: str = "gpt-3.5-turbo",
        fine_tune_format: Optional[FinetuneFormat] = None,
    ) -> Callable[
        [Callable[..., Any]],
        Callable[[Callable[..., T_Retval]], Callable[..., T_Retval]],
    ]:
        """
        Decorator to track the function call and response, supports distillation and dispatch modes.

        Works both bare and with keyword arguments. The decorated function
        must declare a return annotation that subclasses ``pydantic.BaseModel``.

        :Example:

        >>> @distil
        >>> def my_function() -> MyModel:
        >>>     return MyModel()
        >>>
        >>> @distil(name="my_function")
        >>> def my_function() -> MyModel:
        >>>     return MyModel()

        :param args: When the decorator is used bare, the single function to
            wrap; otherwise empty.
        :param name: Name of the function to track. Defaults to the function name.
        :param mode: ``"distil"`` (run the function and log the call) or
            ``"dispatch"`` (ask the model instead of running the function).
        :param model: Model name passed to the client in dispatch mode.
        :param fine_tune_format: Format for tracked calls. Defaults to the
            instance's ``finetune_format``.
        :return: The wrapped function (bare use) or a decorator (keyword use).
        :raises AssertionError: If ``mode`` is not allowed, or the decorated
            function's return annotation is missing or not a ``BaseModel`` subclass.
        """
        allowed_modes = {"distil", "dispatch"}
        assert mode in allowed_modes, f"Must be in {allowed_modes}"

        if fine_tune_format is None:
            fine_tune_format = self.finetune_format

        def _wrap_distil(
            fn: Callable[..., Any],
        ) -> Callable[[Callable[..., T_Retval]], Callable[..., T_Retval]]:
            msg = f"Return type hint for {fn} must subclass `pydantic.BaseModel'"
            assert is_return_type_base_model_or_instance(fn), msg
            return_base_model = inspect.signature(fn).return_annotation

            # Resolve the tracked name once, under a new identifier. Rebinding
            # ``name`` inside a nested wrapper would make it local to that
            # wrapper and raise UnboundLocalError on its first read.
            fn_name = name if name else fn.__name__

            @functools.wraps(fn)
            def _dispatch(*args: Any, **kwargs: Any) -> Any:
                openai_kwargs = self.openai_kwargs(
                    name=fn_name,
                    fn=fn,
                    args=args,
                    kwargs=kwargs,
                    base_model=return_base_model,
                )
                # NOTE(review): ``response_model`` is presumably handled by an
                # instructor-patched client; confirm that the default
                # ``OpenAI()`` client accepts it.
                return self.client.chat.completions.create(
                    **openai_kwargs, model=model, response_model=return_base_model
                )

            @functools.wraps(fn)
            def _distil(*args: Any, **kwargs: Any) -> Any:
                resp = fn(*args, **kwargs)
                self.track(
                    fn,
                    args,
                    kwargs,
                    resp,
                    name=fn_name,
                    finetune_format=fine_tune_format,
                )
                return resp

            return _dispatch if mode == "dispatch" else _distil

        # Bare ``@distil`` usage: the function itself arrives as the only
        # positional argument.
        if len(args) == 1 and callable(args[0]):
            return _wrap_distil(args[0])

        return _wrap_distil

    @validate_call  # type: ignore[misc]
    def track(
        self,
        fn: Callable[..., Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        resp: BaseModel,
        name: Optional[str] = None,
        finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
    ) -> None:
        """
        Log one function call and its response as a JSON record for finetuning.

        The record is emitted at INFO level through ``self.logger``.

        :param fn: Function to track.
        :param args: Arguments passed to the function.
        :param kwargs: Keyword arguments passed to the function.
        :param resp: Response returned by the function.
        :param name: Name of the function to track. Defaults to the function name.
        :param finetune_format: Format to use for finetuning. Defaults to
            ``FinetuneFormat.MESSAGES``.
        """
        name = name if name else fn.__name__
        base_model: Type[BaseModel] = type(resp)

        if finetune_format == FinetuneFormat.MESSAGES:
            openai_function_call = openai_schema(base_model).openai_schema
            openai_kwargs = self.openai_kwargs(name, fn, args, kwargs, base_model)
            # The recorded response becomes the assistant's function call.
            openai_kwargs["messages"].append(
                {
                    "role": "assistant",
                    "function_call": {
                        "name": base_model.__name__,
                        "arguments": resp.model_dump_json(indent=self.indent),
                    },
                }
            )
            openai_kwargs["functions"] = [openai_function_call]
            self.logger.info(json.dumps(openai_kwargs))
        elif finetune_format == FinetuneFormat.RAW:
            function_body = dict(
                fn_name=name,
                fn_repr=format_function(fn),
                args=args,
                kwargs=kwargs,
                resp=resp.model_dump(),
                schema=base_model.model_json_schema(),
            )
            self.logger.info(json.dumps(function_body))

    def openai_kwargs(
        self,
        name: str,
        fn: Callable[..., Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        base_model: Type[BaseModel],
    ) -> Dict[str, Any]:
        """
        Build the chat ``messages`` payload describing one function call.

        The system message contains the function definition (full source when
        ``self.include_code_body`` is set, otherwise signature and docstring);
        the user message asks for the result of ``name(<call arguments>)``.

        :param name: Name shown for the function in the user message.
        :param fn: Function being described.
        :param args: Positional arguments, rendered with ``str``.
        :param kwargs: Keyword arguments, rendered as ``key=<JSON value>``.
        :param base_model: Return model of the function (not used when
            building the messages).
        :return: Dict with a single ``"messages"`` key.
        """
        if self.include_code_body:
            func_def = format_function(fn)
        else:
            func_def = get_signature_from_fn(fn)

        str_args = ", ".join(map(str, args))
        str_kwargs = ", ".join(f"{k}={json.dumps(v)}" for k, v in kwargs.items())
        # Drop whichever part is empty so no stray ", " is left in the call.
        call_args = ", ".join(part for part in (str_args, str_kwargs) if part)

        return {
            "messages": [
                {
                    "role": "system",
                    "content": f"Predict the results of this function:\n\n{func_def}",
                },
                {
                    "role": "user",
                    "content": f"Return `{name}({call_args})`",
                },
            ],
        }