Fix async usage (#167)

Co-authored-by: Jason Liu <jxnl@users.noreply.github.com>
Co-authored-by: Jason Liu <jason@jxnl.co>
This commit is contained in:
Isaac Poulton
2023-11-14 23:57:06 +00:00
committed by GitHub
parent 78bf56921f
commit 4de58fd157
7 changed files with 150 additions and 56 deletions
+45 -13
View File
@@ -1,9 +1,13 @@
import inspect
from functools import wraps
from json import JSONDecodeError
from pydantic import ValidationError, BaseModel
from typing import Callable, Type, Optional
from logging import warn
from typing import Callable, Optional, Type, Union
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from pydantic import BaseModel, ValidationError
from .function_calls import OpenAISchema, openai_schema
OVERRIDE_DOCS = """
@@ -66,6 +70,18 @@ def process_response(
return response
def dump_message(message: ChatCompletionMessage) -> dict:
"""Dumps a message to a dict, to be returned to the OpenAI API.
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
if it isn't used.
"""
dumped_message = message.model_dump()
if not dumped_message.get("tool_calls"):
del dumped_message["tool_calls"]
return dumped_message
async def retry_async(
func,
response_model,
@@ -78,7 +94,7 @@ async def retry_async(
retries = 0
while retries <= max_retries:
try:
response = await func(*args, **kwargs)
response: ChatCompletion = await func(*args, **kwargs)
return (
process_response(
response,
@@ -122,7 +138,7 @@ def retry_sync(
None,
)
except (ValidationError, JSONDecodeError) as e:
kwargs["messages"].append(response.choices[0].message) # type: ignore
kwargs["messages"].append(dump_message(response.choices[0].message))
kwargs["messages"].append(
{
"role": "user",
@@ -134,7 +150,16 @@ def retry_sync(
raise e
def wrap_chatcompletion(func: Callable, is_async: bool = None) -> Callable:
def is_async(func: Callable) -> bool:
"""Returns true if the callable is async, accounting for wrapped callables"""
return inspect.iscoroutinefunction(func) or (
hasattr(func, "__wrapped__") and inspect.iscoroutinefunction(func.__wrapped__)
)
def wrap_chatcompletion(func: Callable) -> Callable:
func_is_async = is_async(func)
@wraps(func)
async def new_chatcompletion_async(
response_model=None,
@@ -177,12 +202,14 @@ def wrap_chatcompletion(func: Callable, is_async: bool = None) -> Callable:
raise ValueError(error)
return response
wrapper_function = new_chatcompletion_async if is_async else new_chatcompletion_sync
wrapper_function = (
new_chatcompletion_async if func_is_async else new_chatcompletion_sync
)
wrapper_function.__doc__ = OVERRIDE_DOCS
return wrapper_function
def patch(client):
def patch(client: Union[OpenAI, AsyncOpenAI]):
"""
Patch the `client.chat.completions.create` method
@@ -198,9 +225,11 @@ def patch(client):
return client
def apatch(client):
def apatch(client: AsyncOpenAI):
"""
Patch the `client.chat.completions.acreate` and `client.chat.completions.acreate` methods
No longer necessary, use `patch` instead.
Patch the `client.chat.completions.create` method
Enables the following features:
@@ -209,7 +238,10 @@ def apatch(client):
- `validation_context` parameter to validate the response using the pydantic model
- `strict` parameter to use strict json parsing
"""
client.chat.completions.create = wrap_chatcompletion(
client.chat.completions.create, is_async=True
# Emit a deprecation warning
warn(
"instructor.apatch is deprecated, use instructor.patch instead",
DeprecationWarning,
)
return client
return patch(client)