diff --git a/examples/classification/test_run.py b/examples/classification/test_run.py new file mode 100644 index 0000000..83c0fa6 --- /dev/null +++ b/examples/classification/test_run.py @@ -0,0 +1,32 @@ +import pytest + +from examples.planning.run import extract_person, extract_people, Person + + +@pytest.mark.asyncio +async def test_extract_person(): + # Test the extract_person function with a known input + text = "John is 45 years old" + expected_person = Person(name="John", age=45) + person = await extract_person(text) + assert ( + person == expected_person + ), "The extracted person does not match the expected person" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "names_and_ages, expected_people", + [ + ( + ["Alice is 30 years old", "Bob is 24 years old"], + [Person(name="Alice", age=30), Person(name="Bob", age=24)], + ) + ], +) +async def test_extract_people(names_and_ages, expected_people): + # Test the extract_people function with a list of known inputs + people = await extract_people(names_and_ages) + assert ( + people == expected_people + ), "The extracted people do not match the expected people" diff --git a/examples/match_language/run_v1.py b/examples/match_language/run_v1.py index 67326fa..28f2d7e 100644 --- a/examples/match_language/run_v1.py +++ b/examples/match_language/run_v1.py @@ -50,6 +50,7 @@ client = patch(AsyncOpenAI()) class GeneratedSummary(BaseModel): summary: str + async def summarize_text(text: str): response = await client.chat.completions.create( model="gpt-3.5-turbo", diff --git a/examples/match_language/run_v2.py b/examples/match_language/run_v2.py index e8ce294..84ca894 100644 --- a/examples/match_language/run_v2.py +++ b/examples/match_language/run_v2.py @@ -53,6 +53,7 @@ class GeneratedSummary(BaseModel): ) summary: str + async def summarize_text(text: str): response = await client.chat.completions.create( model="gpt-3.5-turbo", @@ -94,4 +95,4 @@ if __name__ == "__main__": Source: de, Summary: de, Match: True, Detected: de Source: hi, Summary: hi, Match: True, Detected: hi Source: ja, Summary: ja, Match: True, Detected: ja - """ \ No newline at end of file + """ diff --git a/instructor/anthropic_utils.py b/instructor/anthropic_utils.py index d9fd9f7..ab38a48 100644 --- a/instructor/anthropic_utils.py +++ b/instructor/anthropic_utils.py @@ -75,7 +75,7 @@ def _add_params( field_type = details.get( "type", "unknown" ) # Might be better to fail here if there is no type since pydantic models require types - + if "array" in field_type and "items" not in details: raise ValueError("Invalid array item.") @@ -83,7 +83,7 @@ def _add_params( if "array" in field_type and "$ref" in details["items"]: type_element.text = f"List[{details['title']}]" list_found = True - nested_list_found = True + nested_list_found = True # Check for non-nested List elif "array" in field_type and "type" in details["items"]: type_element.text = f"List[{details['items']['type']}]" diff --git a/instructor/function_calls.py b/instructor/function_calls.py index f63040e..3e0f9c9 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -69,7 +69,6 @@ class OpenAISchema(BaseModel): # type: ignore[misc] for line in parseString(json_to_xml(cls)).toprettyxml().splitlines()[1:] ) - @classmethod def from_response( cls, diff --git a/tests/anthropic/test_simple.py b/tests/anthropic/test_simple.py index 205306a..98c8666 100644 --- a/tests/anthropic/test_simple.py +++ b/tests/anthropic/test_simple.py @@ -68,9 +68,9 @@ def test_list(): name: str age: int family: List[str] - + resp = create( - model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307 + model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307 max_tokens=1024, max_retries=0, messages=[ @@ -81,13 +81,13 @@ def test_list(): ], response_model=User, ) - + assert isinstance(resp, User) assert isinstance(resp.family, List) for member in resp.family: assert isinstance(member, str) - + def test_nested_list(): class Properties(BaseModel): key: str