From f41b8e4948d034a39959e08dc191227a03757980 Mon Sep 17 00:00:00 2001 From: Jason Liu Date: Fri, 9 Feb 2024 17:59:42 -0500 Subject: [PATCH] fix(parallel): enhance error handling in get_types_array and add test case (#423) --- ellipsis.yaml | 3 ++- instructor/dsl/parallel.py | 3 ++- instructor/patch.py | 3 ++- tests/openai/test_parallel.py | 17 +++++++++++++++++ 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/ellipsis.yaml b/ellipsis.yaml index fdc51af..fbe251e 100644 --- a/ellipsis.yaml +++ b/ellipsis.yaml @@ -3,7 +3,7 @@ version: 1.1 pr_review: auto_review_enabled: true auto_summarize_pr: true - confidence_threshold: 0.9 + confidence_threshold: 0.85 rules: # Control what gets flagged during PR review with custom rules. Here are some to get you started: - "Code should be DRY (Dont Repeat Yourself)" @@ -13,3 +13,4 @@ pr_review: - "If library code changes, expect documentation to be updated" - "If library code changes, check if tests are updated" - "If a new `md` file is created in `docs` make sure its added to mkdocs.yml" + - "Assertions should always have an error message that is formatted well. " diff --git a/instructor/dsl/parallel.py b/instructor/dsl/parallel.py index c1756c7..831c1c0 100644 --- a/instructor/dsl/parallel.py +++ b/instructor/dsl/parallel.py @@ -46,7 +46,8 @@ class ParallelBase: def get_types_array(typehint: Type[Iterable[Union[T]]]) -> Tuple[Type[T], ...]: should_be_iterable = get_origin(typehint) - assert should_be_iterable is Iterable + if should_be_iterable is not Iterable: + raise TypeError(f"Model should be with Iterable instead if {typehint}") if get_origin(get_args(typehint)[0]) is Union: # works for Iterable[Union[int, str]] diff --git a/instructor/patch.py b/instructor/patch.py index e000d38..29c8a1e 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -339,9 +339,10 @@ async def retry_async( "name": response.choices[0] .message.tool_calls[0] .function.name, - "content": "failure", + "content": "Exceptions found\n{e}\nRecall the function correctly.", } ) + kwargs["messages"].append( { "role": "user", diff --git a/tests/openai/test_parallel.py b/tests/openai/test_parallel.py index 6f4b17b..16e9015 100644 --- a/tests/openai/test_parallel.py +++ b/tests/openai/test_parallel.py @@ -14,6 +14,23 @@ class GoogleSearch(BaseModel): query: str +def test_sync_parallel_tools__error(client): + client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS) + + with pytest.raises(TypeError): + resp = client.chat.completions.create( + model="gpt-4-turbo-preview", + messages=[ + {"role": "system", "content": "You must always use tools"}, + { + "role": "user", + "content": "What is the weather in toronto and dallas and who won the super bowl?", + }, + ], + response_model=Weather, + ) + + def test_sync_parallel_tools_or(client): client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS) resp = client.chat.completions.create(