mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
implement dispatch
This commit is contained in:
+36
-6
@@ -4,10 +4,12 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import Any, Callable, List, Optional
|
||||
import uuid
|
||||
from typing import Any, Callable, List, Optional, Type
|
||||
from pydantic import BaseModel, validate_call
|
||||
|
||||
import uuid
|
||||
import openai
|
||||
|
||||
from instructor import openai_schema
|
||||
|
||||
|
||||
@@ -85,6 +87,16 @@ class Instructions:
|
||||
indent: int = 2,
|
||||
include_code_body: bool = False,
|
||||
):
|
||||
"""
|
||||
Instructions for distillation and dispatch.
|
||||
|
||||
:param name: Name of the instructions.
|
||||
:param id: ID of the instructions.
|
||||
:param log_handlers: List of log handlers to use.
|
||||
:param finetune_format: Format to use for finetuning.
|
||||
:param indent: Indentation to use for finetuning.
|
||||
:param include_code_body: Whether to include the code body in the finetuning.
|
||||
"""
|
||||
self.name = name
|
||||
self.id = id or str(uuid.uuid4())
|
||||
self.unique_id = str(uuid.uuid4())
|
||||
@@ -132,6 +144,20 @@ class Instructions:
|
||||
def _wrap_distil(fn):
|
||||
msg = f"Return type hint for {fn} must subclass `pydantic.BaseModel'"
|
||||
assert is_return_type_base_model_or_instance(fn), msg
|
||||
return_base_model = inspect.signature(fn).return_annotation
|
||||
|
||||
@functools.wraps(fn)
|
||||
def _dispatch(*args, **kwargs):
|
||||
openai_kwargs = self.openai_kwargs(
|
||||
name=name,
|
||||
fn=fn,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
base_model=return_base_model,
|
||||
)
|
||||
return openai.ChatCompletion.create(
|
||||
**openai_kwargs, response_model=return_base_model
|
||||
)
|
||||
|
||||
@functools.wraps(fn)
|
||||
def _distil(*args, **kwargs):
|
||||
@@ -142,7 +168,11 @@ class Instructions:
|
||||
|
||||
return resp
|
||||
|
||||
return _distil
|
||||
if mode == "dispatch":
|
||||
return _dispatch
|
||||
|
||||
if mode == "distil":
|
||||
return _distil
|
||||
|
||||
if len(args) == 1 and callable(args[0]):
|
||||
return _wrap_distil(args[0])
|
||||
@@ -173,6 +203,7 @@ class Instructions:
|
||||
base_model: BaseModel = type(resp)
|
||||
|
||||
if finetune_format == FinetuneFormat.MESSAGES:
|
||||
openai_function_call = openai_schema(base_model).openai_schema
|
||||
openai_kwargs = self.openai_kwargs(name, fn, args, kwargs, base_model)
|
||||
openai_kwargs["messages"].append(
|
||||
{
|
||||
@@ -183,6 +214,8 @@ class Instructions:
|
||||
},
|
||||
}
|
||||
)
|
||||
openai_kwargs["functions"] = [openai_function_call]
|
||||
openai_kwargs["function_call"] = {"name": openai_function_call["name"]}
|
||||
self.logger.info(json.dumps(openai_kwargs))
|
||||
|
||||
if finetune_format == FinetuneFormat.RAW:
|
||||
@@ -197,8 +230,6 @@ class Instructions:
|
||||
self.logger.info(json.dumps(function_body))
|
||||
|
||||
def openai_kwargs(self, name, fn, args, kwargs, base_model):
|
||||
openai_function_call = openai_schema(base_model).openai_schema
|
||||
|
||||
if self.include_code_body:
|
||||
func_def = format_function(fn)
|
||||
else:
|
||||
@@ -221,6 +252,5 @@ class Instructions:
|
||||
"content": f"Return `{name}({call_args})`",
|
||||
},
|
||||
],
|
||||
"functions": [openai_function_call],
|
||||
}
|
||||
return function_body
|
||||
|
||||
Reference in New Issue
Block a user