mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
import json
|
|
from functools import wraps
|
|
from typing import Any, Callable
|
|
from pydantic import validate_arguments, BaseModel
|
|
|
|
|
|
class openai_function:
    """Decorator exposing a plain function as an OpenAI function-calling tool.

    The wrapped function is validated with pydantic's ``validate_arguments``.
    The argument model it generates is exposed as ``self.model``. The JSON
    schema built from that model is exposed as ``self.openai_schema``, in
    the ``{"name", "description", "parameters"}`` shape the chat completion
    ``functions`` argument expects.
    """

    def __init__(self, func: Callable) -> None:
        # Copy __name__, __doc__, __module__, __wrapped__, ... from ``func`` so
        # the decorated object still looks like the original function. Done
        # first so the attributes assigned below cannot be overwritten by
        # anything in ``func.__dict__``.
        wraps(func)(self)
        self.func = func
        self.validate_func = validate_arguments(func)
        self.model = self.validate_func.model
        self.openai_schema = {
            "name": self.func.__name__,
            "description": self.func.__doc__,
            "parameters": self._clean_parameters(self.model.schema(), func),
        }

    @staticmethod
    def _clean_parameters(schema: dict, func: Callable) -> dict:
        """Return a copy of ``schema`` restricted to ``func``'s real parameters.

        The model that ``validate_arguments`` builds carries extra bookkeeping
        fields (e.g. ``args``, ``kwargs``, ``v__duplicate_kwargs``) that are not
        parameters of ``func``. Advertising them to the model only invites it
        to fill them in. The input ``schema`` is not mutated.
        """
        import inspect

        names = set(inspect.signature(func).parameters)
        cleaned = dict(schema)
        if "properties" in cleaned:
            cleaned["properties"] = {
                key: value
                for key, value in cleaned["properties"].items()
                if key in names
            }
        if "required" in cleaned:
            cleaned["required"] = [key for key in cleaned["required"] if key in names]
        return cleaned

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped function with pydantic-validated arguments."""
        return self.validate_func(*args, **kwargs)

    def from_response(self, completion, throw_error=True):
        """Execute the function from the response of an openai chat completion.

        Args:
            completion: chat completion response. Only ``choices[0].message``
                is read. That message must support ``in`` and ``[...]`` lookup of
                ``"function_call"``.
            throw_error: when true, check that the message contains a function
                call and that it names this function before executing it.

        Returns:
            Whatever the wrapped function returns for the JSON-decoded
            ``function_call["arguments"]``, after pydantic validation.

        Raises:
            AssertionError: if ``throw_error`` is true and there is no function
                call, or it targets a different function. This is raised
                explicitly, so the check still runs under ``python -O``.
        """
        message = completion.choices[0].message

        if throw_error:
            if "function_call" not in message:
                raise AssertionError("No function call detected")
            if message["function_call"]["name"] != self.openai_schema["name"]:
                raise AssertionError("Function name does not match")

        function_call = message["function_call"]
        arguments = json.loads(function_call["arguments"])
        return self.validate_func(**arguments)
|
|
|
|
|
|
class _classproperty(property):
    """Read-only property whose getter receives the class, not an instance.

    This replaces stacking ``@classmethod`` on ``@property``. That pattern was
    deprecated in Python 3.11, and from 3.13 classmethods no longer wrap other
    descriptors. Subclassing ``property`` keeps pydantic treating the attribute
    as a plain descriptor rather than a model field.
    """

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class OpenAISchema(BaseModel):
    """Pydantic model that can describe itself as an OpenAI function."""

    @_classproperty
    def openai_schema(cls):
        """OpenAI function definition built from the model's JSON schema.

        The model title is used as the function name. The model docstring
        (schema ``description``) is used as the description; a generated
        sentence replaces it when the model has no docstring. The full schema
        is used as ``parameters``.
        """
        schema = cls.schema()
        return {
            "name": schema["title"],
            # pydantic omits "description" for models without a docstring;
            # indexing it directly raised KeyError for such models.
            "description": schema.get(
                "description",
                f"Correctly extracted `{cls.__name__}` with all the required "
                "parameters with correct types",
            ),
            "parameters": schema,
        }

    @classmethod
    def from_response(cls, completion, throw_error=True):
        """Build an instance from the function call in an openai chat completion.

        Args:
            completion: chat completion response. Only ``choices[0].message``
                is read. That message must support ``in`` and ``[...]`` lookup of
                ``"function_call"``.
            throw_error: when true, check that the message contains a function
                call and that it names this model before parsing it.

        Returns:
            ``cls(**arguments)`` where ``arguments`` is the JSON-decoded
            ``function_call["arguments"]``. Validation errors from the model
            constructor propagate unchanged.

        Raises:
            AssertionError: if ``throw_error`` is true and there is no function
                call, or it targets a different function. This is raised
                explicitly, so the check still runs under ``python -O``.
        """
        message = completion.choices[0].message

        if throw_error:
            if "function_call" not in message:
                raise AssertionError("No function call detected")
            if message["function_call"]["name"] != cls.openai_schema["name"]:
                raise AssertionError("Function name does not match")

        function_call = message["function_call"]
        arguments = json.loads(function_call["arguments"])
        return cls(**arguments)
|