From 7555ef01ae06d1c810472fe2f95f65fb5f4b4d42 Mon Sep 17 00:00:00 2001 From: Jason Liu Date: Wed, 14 Jun 2023 19:45:23 +0900 Subject: [PATCH] create code --- openai_function_call.py | 59 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 openai_function_call.py diff --git a/openai_function_call.py b/openai_function_call.py new file mode 100644 index 0000000..4ed50e4 --- /dev/null +++ b/openai_function_call.py @@ -0,0 +1,59 @@ +import json +from functools import wraps +from typing import Any, Callable +from pydantic import validate_arguments, BaseModel + + +class openai_function: + def __init__(self, func: Callable) -> None: + self.func = func + self.validate_func = validate_arguments(func) + self.openai_schema = { + "name": self.func.__name__, + "description": self.func.__doc__, + "parameters": self.validate_func.model.schema(), + } + self.model = self.validate_func.model + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + @wraps(self.func) + def wrapper(*args, **kwargs): + return self.validate_func(*args, **kwargs) + return wrapper(*args, **kwargs) + + def from_response(self, completion, throw_error=True): + """Execute the function from the response of an openai chat completion""" + message = completion.choices[0].message + + if throw_error: + assert "function_call" in message, "No function call detected" + assert message["function_call"]["name"] == self.schema["name"], "Function name does not match" + + function_call = message["function_call"] + arguments = json.loads(function_call["arguments"]) + return self.validate_func(**arguments) + + +class OpenAISchema(BaseModel): + + @classmethod + @property + def openai_schema(cls): + schema = cls.schema() + return { + "name": schema["title"], + "description": schema["description"], + "parameters": schema + } + + @classmethod + def from_response(cls, completion, throw_error=True): + message = completion.choices[0].message + + if throw_error: + assert "function_call" in message, "No function call detected" + assert message["function_call"]["name"] == cls.openai_schema["name"], "Function name does not match" + + function_call = message["function_call"] + arguments = json.loads(function_call["arguments"]) + return cls(**arguments)