This commit is contained in:
Jason Liu
2023-06-15 09:35:57 +09:00
committed by GitHub
parent 81d50b1ec2
commit eb74065e16
+11 -7
View File
@@ -19,23 +19,25 @@ class openai_function:
@wraps(self.func)
def wrapper(*args, **kwargs):
return self.validate_func(*args, **kwargs)
return wrapper(*args, **kwargs)
def from_response(self, completion, throw_error=True):
"""Execute the function from the response of an openai chat completion"""
message = completion.choices[0].message
if throw_error:
assert "function_call" in message, "No function call detected"
assert message["function_call"]["name"] == self.schema["name"], "Function name does not match"
assert (
message["function_call"]["name"] == self.openai_schema["name"]
), "Function name does not match"
function_call = message["function_call"]
arguments = json.loads(function_call["arguments"])
return self.validate_func(**arguments)
class OpenAISchema(BaseModel):
class OpenAISchema(BaseModel):
@classmethod
@property
def openai_schema(cls):
@@ -43,7 +45,7 @@ class OpenAISchema(BaseModel):
return {
"name": schema["title"],
"description": schema["description"],
"parameters": schema
"parameters": schema,
}
@classmethod
@@ -52,8 +54,10 @@ class OpenAISchema(BaseModel):
if throw_error:
assert "function_call" in message, "No function call detected"
assert message["function_call"]["name"] == cls.openai_schema["name"], "Function name does not match"
assert (
message["function_call"]["name"] == cls.openai_schema["name"]
), "Function name does not match"
function_call = message["function_call"]
arguments = json.loads(function_call["arguments"])
return cls(**arguments)