fix: add handling for List[non-object] types (#521)

Co-authored-by: Jason Liu <jxnl@users.noreply.github.com>
This commit is contained in:
shreya w
2024-03-22 20:37:58 -04:00
committed by GitHub
parent a9d6cd8f3b
commit cea534fd22
4 changed files with 40 additions and 13 deletions
+12 -12
View File
@@ -50,6 +50,7 @@ def _add_params(
# TODO: handling of nested params with the same name
properties = model_dict.get("properties", {})
list_found = False
nested_list_found = False
for field_name, details in properties.items():
parameter = ET.SubElement(root, "parameter")
@@ -74,11 +75,19 @@ def _add_params(
field_type = details.get(
"type", "unknown"
) # Might be better to fail here if there is no type since pydantic models require types
if "array" in field_type and "items" not in details:
raise ValueError("Invalid array item.")
# Adjust type if array
if "array" in field_type or "List" in field_type:
# Check for nested List
if "array" in field_type and "$ref" in details["items"]:
type_element.text = f"List[{details['title']}]"
list_found = True
nested_list_found = True
# Check for non-nested List
elif "array" in field_type and "type" in details["items"]:
type_element.text = f"List[{details['items']['type']}]"
list_found = True
else:
type_element.text = field_type
@@ -105,22 +114,13 @@ def _add_params(
reference,
references,
)
elif field_type == "array": # Handling for List[] type
elif field_type == "array" and nested_list_found: # Handling for List[] type
nested_params = ET.SubElement(parameter, "parameters")
list_found |= _add_params(
nested_params,
_resolve_reference(references, details["items"]["$ref"]),
references,
)
elif "array" in field_type: # Handling for optional List[] type
nested_params = ET.SubElement(parameter, "parameters")
list_found |= _add_params(
nested_params,
_resolve_reference(
references, details["anyOf"][0]["items"]["$ref"]
), # CHANGE
references,
)
return list_found
+1
View File
@@ -69,6 +69,7 @@ class OpenAISchema(BaseModel): # type: ignore[misc]
for line in parseString(json_to_xml(cls)).toprettyxml().splitlines()[1:]
)
@classmethod
def from_response(
cls,
Generated
+2 -1
View File
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "aiohttp"
@@ -2677,6 +2677,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
+25
View File
@@ -63,6 +63,31 @@ def test_nested_type():
assert resp.address.street_name == "First Avenue"
def test_list():
class User(BaseModel):
name: str
age: int
family: List[str]
resp = create(
model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307
max_tokens=1024,
max_retries=0,
messages=[
{
"role": "user",
"content": "Create a user for a model with a name, age, and family members.",
}
],
response_model=User,
)
assert isinstance(resp, User)
assert isinstance(resp.family, List)
for member in resp.family:
assert isinstance(member, str)
def test_nested_list():
class Properties(BaseModel):
key: str