implement dispatch

This commit is contained in:
Jason Liu
2023-10-22 17:54:19 -04:00
parent edf29482e4
commit 7119d18257
+36 -6
View File
@@ -4,10 +4,12 @@ import inspect
import json
import logging
from typing import Any, Callable, List, Optional
import uuid
from typing import Any, Callable, List, Optional, Type
from pydantic import BaseModel, validate_call
import uuid
import openai
from instructor import openai_schema
@@ -85,6 +87,16 @@ class Instructions:
indent: int = 2,
include_code_body: bool = False,
):
"""
Instructions for distillation and dispatch.
:param name: Name of the instructions.
:param id: ID of the instructions.
:param log_handlers: List of log handlers to use.
:param finetune_format: Format to use for finetuning.
:param indent: Indentation to use for finetuning.
:param include_code_body: Whether to include the code body in the finetuning.
"""
self.name = name
self.id = id or str(uuid.uuid4())
self.unique_id = str(uuid.uuid4())
@@ -132,6 +144,20 @@ class Instructions:
def _wrap_distil(fn):
msg = f"Return type hint for {fn} must subclass `pydantic.BaseModel'"
assert is_return_type_base_model_or_instance(fn), msg
return_base_model = inspect.signature(fn).return_annotation
@functools.wraps(fn)
def _dispatch(*args, **kwargs):
openai_kwargs = self.openai_kwargs(
name=name,
fn=fn,
args=args,
kwargs=kwargs,
base_model=return_base_model,
)
return openai.ChatCompletion.create(
**openai_kwargs, response_model=return_base_model
)
@functools.wraps(fn)
def _distil(*args, **kwargs):
@@ -142,7 +168,11 @@ class Instructions:
return resp
return _distil
if mode == "dispatch":
return _dispatch
if mode == "distil":
return _distil
if len(args) == 1 and callable(args[0]):
return _wrap_distil(args[0])
@@ -173,6 +203,7 @@ class Instructions:
base_model: BaseModel = type(resp)
if finetune_format == FinetuneFormat.MESSAGES:
openai_function_call = openai_schema(base_model).openai_schema
openai_kwargs = self.openai_kwargs(name, fn, args, kwargs, base_model)
openai_kwargs["messages"].append(
{
@@ -183,6 +214,8 @@ class Instructions:
},
}
)
openai_kwargs["functions"] = [openai_function_call]
openai_kwargs["function_call"] = {"name": openai_function_call["name"]}
self.logger.info(json.dumps(openai_kwargs))
if finetune_format == FinetuneFormat.RAW:
@@ -197,8 +230,6 @@ class Instructions:
self.logger.info(json.dumps(function_body))
def openai_kwargs(self, name, fn, args, kwargs, base_model):
openai_function_call = openai_schema(base_model).openai_schema
if self.include_code_body:
func_def = format_function(fn)
else:
@@ -221,6 +252,5 @@ class Instructions:
"content": f"Return `{name}({call_args})`",
},
],
"functions": [openai_function_call],
}
return function_body