mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Fix async usage (#167)
Co-authored-by: Jason Liu <jxnl@users.noreply.github.com> Co-authored-by: Jason Liu <jason@jxnl.co>
This commit is contained in:
+45
-13
@@ -1,9 +1,13 @@
|
||||
import inspect
|
||||
|
||||
from functools import wraps
|
||||
from json import JSONDecodeError
|
||||
from pydantic import ValidationError, BaseModel
|
||||
from typing import Callable, Type, Optional
|
||||
from logging import warn
|
||||
from typing import Callable, Optional, Type, Union
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from .function_calls import OpenAISchema, openai_schema
|
||||
|
||||
OVERRIDE_DOCS = """
|
||||
@@ -66,6 +70,18 @@ def process_response(
|
||||
return response
|
||||
|
||||
|
||||
def dump_message(message: ChatCompletionMessage) -> dict:
|
||||
"""Dumps a message to a dict, to be returned to the OpenAI API.
|
||||
|
||||
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
|
||||
if it isn't used.
|
||||
"""
|
||||
dumped_message = message.model_dump()
|
||||
if not dumped_message.get("tool_calls"):
|
||||
del dumped_message["tool_calls"]
|
||||
return dumped_message
|
||||
|
||||
|
||||
async def retry_async(
|
||||
func,
|
||||
response_model,
|
||||
@@ -78,7 +94,7 @@ async def retry_async(
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
response = await func(*args, **kwargs)
|
||||
response: ChatCompletion = await func(*args, **kwargs)
|
||||
return (
|
||||
process_response(
|
||||
response,
|
||||
@@ -122,7 +138,7 @@ def retry_sync(
|
||||
None,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
kwargs["messages"].append(response.choices[0].message) # type: ignore
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message))
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
@@ -134,7 +150,16 @@ def retry_sync(
|
||||
raise e
|
||||
|
||||
|
||||
def wrap_chatcompletion(func: Callable, is_async: bool = None) -> Callable:
|
||||
def is_async(func: Callable) -> bool:
|
||||
"""Returns true if the callable is async, accounting for wrapped callables"""
|
||||
return inspect.iscoroutinefunction(func) or (
|
||||
hasattr(func, "__wrapped__") and inspect.iscoroutinefunction(func.__wrapped__)
|
||||
)
|
||||
|
||||
|
||||
def wrap_chatcompletion(func: Callable) -> Callable:
|
||||
func_is_async = is_async(func)
|
||||
|
||||
@wraps(func)
|
||||
async def new_chatcompletion_async(
|
||||
response_model=None,
|
||||
@@ -177,12 +202,14 @@ def wrap_chatcompletion(func: Callable, is_async: bool = None) -> Callable:
|
||||
raise ValueError(error)
|
||||
return response
|
||||
|
||||
wrapper_function = new_chatcompletion_async if is_async else new_chatcompletion_sync
|
||||
wrapper_function = (
|
||||
new_chatcompletion_async if func_is_async else new_chatcompletion_sync
|
||||
)
|
||||
wrapper_function.__doc__ = OVERRIDE_DOCS
|
||||
return wrapper_function
|
||||
|
||||
|
||||
def patch(client):
|
||||
def patch(client: Union[OpenAI, AsyncOpenAI]):
|
||||
"""
|
||||
Patch the `client.chat.completions.create` method
|
||||
|
||||
@@ -198,9 +225,11 @@ def patch(client):
|
||||
return client
|
||||
|
||||
|
||||
def apatch(client):
|
||||
def apatch(client: AsyncOpenAI):
|
||||
"""
|
||||
Patch the `client.chat.completions.acreate` and `client.chat.completions.acreate` methods
|
||||
No longer necessary, use `patch` instead.
|
||||
|
||||
Patch the `client.chat.completions.create` method
|
||||
|
||||
Enables the following features:
|
||||
|
||||
@@ -209,7 +238,10 @@ def apatch(client):
|
||||
- `validation_context` parameter to validate the response using the pydantic model
|
||||
- `strict` parameter to use strict json parsing
|
||||
"""
|
||||
client.chat.completions.create = wrap_chatcompletion(
|
||||
client.chat.completions.create, is_async=True
|
||||
|
||||
# Emit a deprecation warning
|
||||
warn(
|
||||
"instructor.apatch is deprecated, use instructor.patch instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return client
|
||||
return patch(client)
|
||||
|
||||
Reference in New Issue
Block a user