diff --git a/instructor/anthropic_utils.py b/instructor/anthropic_utils.py index 7f6b12f..52510f3 100644 --- a/instructor/anthropic_utils.py +++ b/instructor/anthropic_utils.py @@ -83,10 +83,21 @@ def _add_params( if ( isinstance(details, dict) and "$ref" in details ): # Checking if there are nested params + reference = _resolve_reference(references, details["$ref"]) + + if 'enum' in reference: + type_element.text = reference['type'] + enum_values = reference['enum'] + values = ET.SubElement(parameter, "values") + for value in enum_values: + value_element = ET.SubElement(values, "value") + value_element.text = value + continue + nested_params = ET.SubElement(parameter, "parameters") list_found |= _add_params( nested_params, - _resolve_reference(references, details["$ref"]), + reference, references, ) elif field_type == "array": # Handling for List[] type diff --git a/tests/anthropic/test_anthropic.py b/tests/anthropic/test_anthropic.py index 8bf5d94..5cbbf33 100644 --- a/tests/anthropic/test_anthropic.py +++ b/tests/anthropic/test_anthropic.py @@ -1,12 +1,14 @@ -import pytest -import anthropic -import instructor -from pydantic import BaseModel +from enum import Enum from typing import List -create = instructor.patch( - create=anthropic.Anthropic().messages.create, mode=instructor.Mode.ANTHROPIC_TOOLS -) +import anthropic +import pytest +from pydantic import BaseModel + +import instructor + +create = instructor.patch(create=anthropic.Anthropic().messages.create, mode=instructor.Mode.ANTHROPIC_TOOLS) + @pytest.mark.skip def test_anthropic(): @@ -33,3 +35,31 @@ def test_anthropic(): ) # type: ignore assert isinstance(resp, User) + + +@pytest.mark.skip +def test_anthropic_enum(): + class ProgrammingLanguage(Enum): + PYTHON = "python" + JAVASCRIPT = "javascript" + TYPESCRIPT = "typescript" + UNKNOWN = "unknown" + OTHER = "other" + + class SimpleEnum(BaseModel): + language: ProgrammingLanguage + + resp = create( + model="claude-3-haiku-20240307", + max_tokens=1024, + max_retries=0, + messages=[ + { + "role": "user", + "content": "What is your favorite programming language?", + } + ], + response_model=SimpleEnum, + ) # type: ignore + + assert isinstance(resp, SimpleEnum)