From 9aedf730d633e35c77832816123ae242d7edd048 Mon Sep 17 00:00:00 2001 From: Joschka Braun <47435119+joschkabraun@users.noreply.github.com> Date: Fri, 1 Mar 2024 18:04:45 -0500 Subject: [PATCH] fix: account for multiple times wrapped functions (#476) --- instructor/patch.py | 8 +++++--- tests/test_patch.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/instructor/patch.py b/instructor/patch.py index fe4c1f6..63e85df 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -467,9 +467,11 @@ def retry_sync( def is_async(func: Callable) -> bool: """Returns true if the callable is async, accounting for wrapped callables""" - return inspect.iscoroutinefunction(func) or ( - hasattr(func, "__wrapped__") and inspect.iscoroutinefunction(func.__wrapped__) - ) + is_coroutine = inspect.iscoroutinefunction(func) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + is_coroutine = is_coroutine or inspect.iscoroutinefunction(func) + return is_coroutine OVERRIDE_DOCS = """ diff --git a/tests/test_patch.py b/tests/test_patch.py index 0418a1e..b0e72e5 100644 --- a/tests/test_patch.py +++ b/tests/test_patch.py @@ -39,6 +39,40 @@ def test_is_async_returns_true_if_wrapped_function_is_async(): assert is_async(wrapped_function) is True +def test_is_async_returns_true_if_double_wrapped_function_is_async(): + async def async_function(): + pass + + @functools.wraps(async_function) + def wrapped_function(): + pass + + @functools.wraps(wrapped_function) + def double_wrapped_function(): + pass + + assert is_async(double_wrapped_function) is True + + +def test_is_async_returns_true_if_triple_wrapped_function_is_async(): + async def async_function(): + pass + + @functools.wraps(async_function) + def wrapped_function(): + pass + + @functools.wraps(wrapped_function) + def double_wrapped_function(): + pass + + @functools.wraps(double_wrapped_function) + def triple_wrapped_function(): + pass + + assert is_async(triple_wrapped_function) is True + + def test_override_docs(): assert ( "response_model" in OVERRIDE_DOCS