diff --git a/instructor/distil.py b/instructor/distil.py index 5d52972..3ea844f 100644 --- a/instructor/distil.py +++ b/instructor/distil.py @@ -4,10 +4,12 @@ import inspect import json import logging -from typing import Any, Callable, List, Optional -import uuid +from typing import Any, Callable, List, Optional, Type from pydantic import BaseModel, validate_call +import uuid +import openai + from instructor import openai_schema @@ -85,6 +87,16 @@ class Instructions: indent: int = 2, include_code_body: bool = False, ): + """ + Instructions for distillation and dispatch. + + :param name: Name of the instructions. + :param id: ID of the instructions. + :param log_handlers: List of log handlers to use. + :param finetune_format: Format to use for finetuning. + :param indent: Indentation to use for finetuning. + :param include_code_body: Whether to include the code body in the finetuning. + """ self.name = name self.id = id or str(uuid.uuid4()) self.unique_id = str(uuid.uuid4()) @@ -132,6 +144,20 @@ class Instructions: def _wrap_distil(fn): msg = f"Return type hint for {fn} must subclass `pydantic.BaseModel'" assert is_return_type_base_model_or_instance(fn), msg + return_base_model = inspect.signature(fn).return_annotation + + @functools.wraps(fn) + def _dispatch(*args, **kwargs): + openai_kwargs = self.openai_kwargs( + name=name, + fn=fn, + args=args, + kwargs=kwargs, + base_model=return_base_model, + ) + return openai.ChatCompletion.create( + **openai_kwargs, response_model=return_base_model + ) @functools.wraps(fn) def _distil(*args, **kwargs): @@ -142,7 +168,11 @@ class Instructions: return resp - return _distil + if mode == "dispatch": + return _dispatch + + if mode == "distil": + return _distil if len(args) == 1 and callable(args[0]): return _wrap_distil(args[0]) @@ -173,6 +203,7 @@ class Instructions: base_model: BaseModel = type(resp) if finetune_format == FinetuneFormat.MESSAGES: + openai_function_call = openai_schema(base_model).openai_schema openai_kwargs = self.openai_kwargs(name, fn, args, kwargs, base_model) openai_kwargs["messages"].append( { @@ -183,6 +214,8 @@ class Instructions: }, } ) + openai_kwargs["functions"] = [openai_function_call] + openai_kwargs["function_call"] = {"name": openai_function_call["name"]} self.logger.info(json.dumps(openai_kwargs)) if finetune_format == FinetuneFormat.RAW: @@ -197,8 +230,6 @@ class Instructions: self.logger.info(json.dumps(function_body)) def openai_kwargs(self, name, fn, args, kwargs, base_model): - openai_function_call = openai_schema(base_model).openai_schema - if self.include_code_body: func_def = format_function(fn) else: @@ -221,6 +252,5 @@ class Instructions: "content": f"Return `{name}({call_args})`", }, ], - "functions": [openai_function_call], } return function_body