diff --git a/instructor/patch.py b/instructor/patch.py index 446ac67..29293fe 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -92,11 +92,13 @@ def handle_response_model( elif mode == Mode.JSON or mode == Mode.MD_JSON: if mode == Mode.JSON: new_kwargs["response_format"] = {"type": "json_object"} - # check that the first message is a system message - # if it is not, add a system message to the beginning message = f"""Make sure that your response to any message matches the json_schema below, - do not deviate at all: \n{response_model.model_json_schema()['properties']} - """ + do not deviate at all: \n{response_model.model_json_schema()['properties']} + """ + # Check for nested models + if '$defs' in response_model.model_json_schema(): + message += f"\nHere are some more definitions to adhere too:\n{response_model.model_json_schema()['$defs']}" + else: message = f""" As a genius expert, your task is to understand the content and provide @@ -110,6 +112,8 @@ def handle_response_model( }, ) new_kwargs["stop"] = "```" + # check that the first message is a system message + # if it is not, add a system message to the beginning if new_kwargs["messages"][0]["role"] != "system": new_kwargs["messages"].insert( 0, diff --git a/tests/openai/evals/test_nested_structures.py b/tests/openai/evals/test_nested_structures.py new file mode 100644 index 0000000..7818558 --- /dev/null +++ b/tests/openai/evals/test_nested_structures.py @@ -0,0 +1,50 @@ +from typing import Iterable +from openai import OpenAI +from pydantic import BaseModel, Field +from typing import List + +import pytest + +import instructor +from instructor.function_calls import Mode + + +class Item(BaseModel): + name: str + price: float + + +class Order(BaseModel): + items: List[Item] = Field(..., default_factory=list) + customer: str + + +@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS, Mode.MD_JSON]) +def test_nested(mode): + client = instructor.patch(OpenAI(), mode=mode) + + content = """ + Order Details: + Customer: Jason + Items: + + Name: Apple, Price: 0.50 + Name: Bread, Price: 2.00 + Name: Milk, Price: 1.50 + """ + + resp = client.chat.completions.create( + model="gpt-3.5-turbo-1106", + response_model=Order, + messages=[ + { + "role": "user", + "content": content, + }, + ], + ) + + assert len(resp.items) == 3 + assert {x.name.lower() for x in resp.items} == {"apple", "bread", "milk"} + assert {x.price for x in resp.items} == {0.5, 2.0, 1.5} + assert resp.customer.lower() == "jason"