mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
create code
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
import json
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
from pydantic import validate_arguments, BaseModel
|
||||
|
||||
|
||||
class openai_function:
|
||||
def __init__(self, func: Callable) -> None:
|
||||
self.func = func
|
||||
self.validate_func = validate_arguments(func)
|
||||
self.openai_schema = {
|
||||
"name": self.func.__name__,
|
||||
"description": self.func.__doc__,
|
||||
"parameters": self.validate_func.model.schema(),
|
||||
}
|
||||
self.model = self.validate_func.model
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
@wraps(self.func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return self.validate_func(*args, **kwargs)
|
||||
return wrapper(*args, **kwargs)
|
||||
|
||||
def from_response(self, completion, throw_error=True):
|
||||
"""Execute the function from the response of an openai chat completion"""
|
||||
message = completion.choices[0].message
|
||||
|
||||
if throw_error:
|
||||
assert "function_call" in message, "No function call detected"
|
||||
assert message["function_call"]["name"] == self.schema["name"], "Function name does not match"
|
||||
|
||||
function_call = message["function_call"]
|
||||
arguments = json.loads(function_call["arguments"])
|
||||
return self.validate_func(**arguments)
|
||||
|
||||
|
||||
class OpenAISchema(BaseModel):
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def openai_schema(cls):
|
||||
schema = cls.schema()
|
||||
return {
|
||||
"name": schema["title"],
|
||||
"description": schema["description"],
|
||||
"parameters": schema
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_response(cls, completion, throw_error=True):
|
||||
message = completion.choices[0].message
|
||||
|
||||
if throw_error:
|
||||
assert "function_call" in message, "No function call detected"
|
||||
assert message["function_call"]["name"] == cls.openai_schema["name"], "Function name does not match"
|
||||
|
||||
function_call = message["function_call"]
|
||||
arguments = json.loads(function_call["arguments"])
|
||||
return cls(**arguments)
|
||||
Reference in New Issue
Block a user