mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 14:50:16 +00:00
Fixed the system prompt for JSON mode, enabling the use of Pydantic nested models (#249)
Co-authored-by: Jason Liu <jason@jxnl.co> Co-authored-by: Jason Liu <jxnl@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
cb96010bba
commit
202f9cb227
+8
-4
@@ -92,11 +92,13 @@ def handle_response_model(
|
||||
elif mode == Mode.JSON or mode == Mode.MD_JSON:
|
||||
if mode == Mode.JSON:
|
||||
new_kwargs["response_format"] = {"type": "json_object"}
|
||||
# check that the first message is a system message
|
||||
# if it is not, add a system message to the beginning
|
||||
message = f"""Make sure that your response to any message matches the json_schema below,
|
||||
do not deviate at all: \n{response_model.model_json_schema()['properties']}
|
||||
"""
|
||||
do not deviate at all: \n{response_model.model_json_schema()['properties']}
|
||||
"""
|
||||
# Check for nested models
|
||||
if '$defs' in response_model.model_json_schema():
|
||||
message += f"\nHere are some more definitions to adhere too:\n{response_model.model_json_schema()['$defs']}"
|
||||
|
||||
else:
|
||||
message = f"""
|
||||
As a genius expert, your task is to understand the content and provide
|
||||
@@ -110,6 +112,8 @@ def handle_response_model(
|
||||
},
|
||||
)
|
||||
new_kwargs["stop"] = "```"
|
||||
# check that the first message is a system message
|
||||
# if it is not, add a system message to the beginning
|
||||
if new_kwargs["messages"][0]["role"] != "system":
|
||||
new_kwargs["messages"].insert(
|
||||
0,
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import Iterable
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
import instructor
|
||||
from instructor.function_calls import Mode
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
|
||||
|
||||
class Order(BaseModel):
|
||||
items: List[Item] = Field(..., default_factory=list)
|
||||
customer: str
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS, Mode.MD_JSON])
|
||||
def test_nested(mode):
|
||||
client = instructor.patch(OpenAI(), mode=mode)
|
||||
|
||||
content = """
|
||||
Order Details:
|
||||
Customer: Jason
|
||||
Items:
|
||||
|
||||
Name: Apple, Price: 0.50
|
||||
Name: Bread, Price: 2.00
|
||||
Name: Milk, Price: 1.50
|
||||
"""
|
||||
|
||||
resp = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-1106",
|
||||
response_model=Order,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
assert len(resp.items) == 3
|
||||
assert {x.name.lower() for x in resp.items} == {"apple", "bread", "milk"}
|
||||
assert {x.price for x in resp.items} == {0.5, 2.0, 1.5}
|
||||
assert resp.customer.lower() == "jason"
|
||||
Reference in New Issue
Block a user