diff --git a/instructor/anthropic_utils.py b/instructor/anthropic_utils.py index ed8ee94..d9fd9f7 100644 --- a/instructor/anthropic_utils.py +++ b/instructor/anthropic_utils.py @@ -50,6 +50,7 @@ def _add_params( # TODO: handling of nested params with the same name properties = model_dict.get("properties", {}) list_found = False + nested_list_found = False for field_name, details in properties.items(): parameter = ET.SubElement(root, "parameter") @@ -74,11 +75,19 @@ def _add_params( field_type = details.get( "type", "unknown" ) # Might be better to fail here if there is no type since pydantic models require types + + if "array" in field_type and "items" not in details: + raise ValueError("Invalid array item.") - # Adjust type if array - if "array" in field_type or "List" in field_type: + # Check for nested List + if "array" in field_type and "$ref" in details["items"]: type_element.text = f"List[{details['title']}]" list_found = True + nested_list_found = True + # Check for non-nested List + elif "array" in field_type and "type" in details["items"]: + type_element.text = f"List[{details['items']['type']}]" + list_found = True else: type_element.text = field_type @@ -105,22 +114,13 @@ def _add_params( reference, references, ) - elif field_type == "array": # Handling for List[] type + elif field_type == "array" and nested_list_found: # Handling for List[] type nested_params = ET.SubElement(parameter, "parameters") list_found |= _add_params( nested_params, _resolve_reference(references, details["items"]["$ref"]), references, ) - elif "array" in field_type: # Handling for optional List[] type - nested_params = ET.SubElement(parameter, "parameters") - list_found |= _add_params( - nested_params, - _resolve_reference( - references, details["anyOf"][0]["items"]["$ref"] - ), # CHANGE - references, - ) return list_found diff --git a/instructor/function_calls.py b/instructor/function_calls.py index 3e0f9c9..f63040e 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -69,6 +69,7 @@ class OpenAISchema(BaseModel): # type: ignore[misc] for line in parseString(json_to_xml(cls)).toprettyxml().splitlines()[1:] ) + @classmethod def from_response( cls, diff --git a/poetry.lock b/poetry.lock index 259f349..59e0874 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -2677,6 +2677,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, diff --git a/tests/anthropic/test_simple.py b/tests/anthropic/test_simple.py index c8cce0a..205306a 100644 --- a/tests/anthropic/test_simple.py +++ b/tests/anthropic/test_simple.py @@ -63,6 +63,31 @@ def test_nested_type(): assert resp.address.street_name == "First Avenue" +def test_list(): + class User(BaseModel): + name: str + age: int + family: List[str] + + resp = create( + model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307 + max_tokens=1024, + max_retries=0, + messages=[ + { + "role": "user", + "content": "Create a user for a model with a name, age, and family members.", + } + ], + response_model=User, + ) + + assert isinstance(resp, User) + assert isinstance(resp.family, List) + for member in resp.family: + assert isinstance(member, str) + + def test_nested_list(): class Properties(BaseModel): key: str