Support patching openai (#78)

* update docs

* add patch

* bump version
This commit is contained in:
Jason Liu
2023-08-17 01:25:29 +08:00
committed by GitHub
parent 0cb80886fa
commit 944366847f
7 changed files with 215 additions and 5 deletions
+56 -2
View File
@@ -22,7 +22,7 @@ Welcome to the Quick Start Guide for OpenAI Function Call. This guide will walk
### Requirements
This library depends on **Pydantic** an **OpenAI** that's all.
This library depends on **Pydantic** and **OpenAI** that's all.
### Installation
@@ -35,7 +35,58 @@ To get started with OpenAI Function Call, you need to install it using `pip`. Ru
$ pip install openai_function_call
```
## Quick Start
## Quick Start with Patching ChatCompletion
To simplify your work with OpenAI models and streamline the extraction of Pydantic objects from prompts, we offer a patching mechanism for the `ChatCompletion`` class. Here's a step-by-step guide:
### Step 1: Import and Patch the Module
First, import the required libraries and apply the patch function to the OpenAI module. This exposes new functionality with the response_model parameter.
```python
import openai
from pydantic import BaseModel
from openai_function_call import patch
patch()
```
### Step 2: Define the Pydantic Model
Create a Pydantic model to define the structure of the data you want to extract. This model will map directly to the information in the prompt.
```python
class UserDetail(BaseModel):
name: str
age: int
```
### Step 3: Extract Data with ChatCompletion
Use the openai.ChatCompletion.create method to send a prompt and extract the data into the Pydantic object. The response_model parameter specifies the Pydantic model to use for extraction.
```python
user: UserDetail = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
response_model=UserDetail,
messages=[
{"role": "user", "content": "Extract Jason is 25 years old"},
]
)
```
### Step 4: Validate the Extracted Data
You can then validate the extracted data by asserting the expected values. By adding the type things you also get a bunch of nice benefits with your IDE like spell check and auto complete!
```python
assert user.name == "Jason"
assert user.age == 25
```
## Introduction to `OpenAISchema`
If you want more control than just passing a single class we can use the `OpenAISchema` which extends `BaseModel`.
This quick start guide contains the follow sections:
@@ -64,6 +115,9 @@ In this schema, we define a `UserDetails` class that extends `OpenAISchema`. We
To enhance the performance of the OpenAI language model, you can add additional prompting in the form of docstrings and field descriptions. They can provide context and guide the model on how to process the data.
!!! note Using `patch`
these docstrings and fields descriptions are powered by `pydantic.BaseModel` so they'll work via the patching approach as well.
```python hl_lines="5 6"
from openai_function_call import OpenAISchema
from pydantic import Field
+23
View File
@@ -0,0 +1,23 @@
import openai
from pydantic import BaseModel
from openai_function_call import patch
# By default, the patch function will patch the ChatCompletion.create and ChatCompletion.acreate methods. to support response_model parameter
patch()
# Now, we can use the response_model parameter using only a base model
# rather than having to use the OpenAISchema class
class UserExtract(BaseModel):
name: str
age: int
user: UserExtract = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
response_model=UserExtract,
messages=[
{"role": "user", "content": "Extract jason is 25 years old"},
],
) # type: ignore
print(user)
+1 -1
View File
@@ -44,7 +44,7 @@ markdown_extensions:
- admonition
nav:
- Introduction:
- OpenAISchema: 'index.md'
- Getting Started: 'index.md'
- MultiTask: "multitask.md"
- Philosophy: 'philosophy.md'
- Use Cases:
+8 -1
View File
@@ -1,4 +1,11 @@
from .function_calls import OpenAISchema, openai_function, openai_schema
from .dsl.multitask import MultiTask
from .patch import patch
__all__ = ["OpenAISchema", "openai_function", "MultiTask", "openai_schema"]
__all__ = [
"OpenAISchema",
"openai_function",
"MultiTask",
"openai_schema",
"patch",
]
+100
View File
@@ -0,0 +1,100 @@
from functools import wraps
import openai
import inspect
from typing import Callable, Optional, Type, Union
from pydantic import BaseModel
from openai_function_call import OpenAISchema, openai_schema
def wrap_chatcompletion(func: Callable) -> Callable:
is_async = inspect.iscoroutinefunction(func)
if is_async:
@wraps(func)
async def new_chatcompletion(
*args,
response_model: Optional[Union[Type[BaseModel], Type[OpenAISchema]]] = None,
**kwargs
): # type: ignore
if response_model is not None:
if not issubclass(response_model, OpenAISchema):
response_model = openai_schema(response_model)
kwargs["functions"] = [response_model.openai_schema]
kwargs["function_call"] = {"name": response_model.openai_schema["name"]}
if kwargs.get("stream", False) and response_model is not None:
import warnings
warnings.warn(
"stream=True is not supported when using response_model parameter"
)
response = await func(*args, **kwargs)
if response_model is not None:
model = response_model.from_response(response)
model._raw_response = response
return model
return response
else:
@wraps(func)
def new_chatcompletion(
*args,
response_model: Optional[Union[Type[BaseModel], Type[OpenAISchema]]] = None,
**kwargs
):
if response_model is not None:
if not issubclass(response_model, OpenAISchema):
response_model = openai_schema(response_model)
kwargs["functions"] = [response_model.openai_schema]
kwargs["function_call"] = {"name": response_model.openai_schema["name"]}
if kwargs.get("stream", False) and response_model is not None:
import warnings
warnings.warn(
"stream=True is not supported when using response_model parameter"
)
response = func(*args, **kwargs)
if response_model is not None:
model = response_model.from_response(response)
model._raw_response = response
return model
return response
new_chatcompletion.__doc__ = """
Creates a new chat completion for the provided messages and parameters.
See: https://platform.openai.com/docs/api-reference/chat-completions/create
Additional Notes:
Using the `response_model` parameter, you can specify a response model to use for parsing the response from OpenAI's API. If its present, the response will be parsed using the response model, otherwise it will be returned as is.
If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method.
If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model.
Parameters:
response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
"""
return new_chatcompletion
original_chatcompletion = openai.ChatCompletion.create
original_chatcompletion_async = openai.ChatCompletion.acreate
def patch():
openai.ChatCompletion.create = wrap_chatcompletion(original_chatcompletion)
openai.ChatCompletion.acreate = wrap_chatcompletion(original_chatcompletion_async)
def unpatch():
openai.ChatCompletion.create = original_chatcompletion
openai.ChatCompletion.acreate = original_chatcompletion_async
+1 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "openai-function-call"
version = "0.2.2"
version = "0.2.3"
description = "Helper functions that allow us to improve openai's function_call ergonomics"
authors = ["Jason <jason@jxnl.co>"]
license = "MIT"
+26
View File
@@ -0,0 +1,26 @@
from pydantic import BaseModel
import pytest
import openai
from openai_function_call import patch
@pytest.mark.skip(reason="Needs openai call")
def test_runmodel():
patch()
class UserExtract(BaseModel):
name: str
age: int
model = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
response_model=UserExtract,
messages=[
{"role": "user", "content": "Extract jason is 25 years old"},
],
)
assert isinstance(model, UserExtract), "Should be instance of UserExtract"
assert model.name.lower() == "jason"
assert hasattr(
model, "_raw_response"
), "The raw response should be available from OpenAI"