mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Markdown JSON Mode (#246)
Co-authored-by: Jason Liu <jxnl@users.noreply.github.com>
This commit is contained in:
@@ -40,6 +40,20 @@ from openai import OpenAI
|
||||
client = instructor.patch(OpenAI(), mode=Mode.JSON)
|
||||
```
|
||||
|
||||
## Markdown JSON Mode
|
||||
|
||||
!!! warning "Experimental"
|
||||
|
||||
This is not recommended, and may not be supported in the future, this is just left to support vision models.
|
||||
|
||||
```python
|
||||
import instructor
|
||||
from instructor import Mode
|
||||
from openai import OpenAI
|
||||
|
||||
client = instructor.patch(OpenAI(), mode=Mode.MD_JSON)
|
||||
|
||||
```
|
||||
### Schema Integration
|
||||
|
||||
In JSON Mode, the schema is part of the system message:
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
import instructor
|
||||
from openai import OpenAI
|
||||
from typing import Iterable
|
||||
from pydantic import BaseModel
|
||||
import base64
|
||||
|
||||
client = instructor.patch(OpenAI(), mode=instructor.function_calls.Mode.MD_JSON)
|
||||
|
||||
class Circle(BaseModel):
|
||||
x: int
|
||||
y: int
|
||||
color: str
|
||||
|
||||
def encode_image(image_path):
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
def draw_circle(image_size, num_circles, path):
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
import random
|
||||
|
||||
image = Image.new("RGB", image_size, "white")
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
for _ in range(num_circles):
|
||||
# Randomize the circle properties
|
||||
radius = 100#random.randint(10, min(image_size)//5) # Radius between 10 and 1/5th of the smallest dimension
|
||||
x = random.randint(radius, image_size[0] - radius)
|
||||
y = random.randint(radius, image_size[1] - radius)
|
||||
color = ['red', 'black', 'blue', 'green'][random.randint(0, 3)]
|
||||
|
||||
circle_position = (x - radius, y - radius, x + radius, y + radius)
|
||||
print(f"Generating circle at {x, y} with color {color}")
|
||||
draw.ellipse(circle_position, fill=color, outline="black")
|
||||
|
||||
image.save(path)
|
||||
|
||||
img_path = 'circle.jpg'
|
||||
draw_circle((1024,1024), 1, img_path)
|
||||
base64_image = encode_image(img_path)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4-vision-preview",
|
||||
max_tokens=1800,
|
||||
response_model=Circle,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": 'find the circle'},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
print(f"Found circle with center at x: {response.x}, y: {response.y} and color: {response.color}")
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
from docstring_parser import parse
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
@@ -13,6 +14,7 @@ class Mode(enum.Enum):
|
||||
FUNCTIONS: str = "function_call"
|
||||
TOOLS: str = "tool_call"
|
||||
JSON: str = "json_mode"
|
||||
MD_JSON: str = "markdown_json_mode"
|
||||
|
||||
|
||||
class openai_function:
|
||||
@@ -237,6 +239,12 @@ class OpenAISchema(BaseModel):
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
elif mode == Mode.MD_JSON:
|
||||
return cls.model_validate_json(
|
||||
message.content,
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
|
||||
+35
-8
@@ -89,13 +89,27 @@ def handle_response_model(
|
||||
"type": "function",
|
||||
"function": {"name": response_model.openai_schema["name"]},
|
||||
}
|
||||
elif mode == Mode.JSON:
|
||||
new_kwargs["response_format"] = {"type": "json_object"}
|
||||
|
||||
# check that the first message is a system message
|
||||
# if it is not, add a system message to the beginning
|
||||
message = f"Make sure that your response to any message matches the json_schema below, do not deviate at all: \n{response_model.model_json_schema()['properties']}"
|
||||
|
||||
elif mode == Mode.JSON or mode == Mode.MD_JSON:
|
||||
if mode == Mode.JSON:
|
||||
new_kwargs["response_format"] = {"type": "json_object"}
|
||||
# check that the first message is a system message
|
||||
# if it is not, add a system message to the beginning
|
||||
message = f"""Make sure that your response to any message matches the json_schema below,
|
||||
do not deviate at all: \n{response_model.model_json_schema()['properties']}
|
||||
"""
|
||||
else:
|
||||
message = f"""
|
||||
As a genius expert, your task is to understand the content and provide
|
||||
the parsed objects in json that match the following json_schema (do not deviate at all and its okay if you cant be exact):\n
|
||||
{response_model.model_json_schema()['properties']}
|
||||
"""
|
||||
new_kwargs["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "```json",
|
||||
},
|
||||
)
|
||||
new_kwargs["stop"] = "```"
|
||||
if new_kwargs["messages"][0]["role"] != "system":
|
||||
new_kwargs["messages"].insert(
|
||||
0,
|
||||
@@ -110,7 +124,6 @@ def handle_response_model(
|
||||
new_kwargs["messages"][0]["content"] += f"\n\n{message}"
|
||||
else:
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
return response_model, new_kwargs
|
||||
|
||||
|
||||
@@ -182,6 +195,13 @@ async def retry_async(
|
||||
"content": f"Recall the function correctly, exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
if mode == Mode.MD_JSON:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "```json",
|
||||
},
|
||||
)
|
||||
retries += 1
|
||||
if retries > max_retries:
|
||||
raise e
|
||||
@@ -219,6 +239,13 @@ def retry_sync(
|
||||
"content": f"Recall the function correctly, exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
if mode == Mode.MD_JSON:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "```json",
|
||||
},
|
||||
)
|
||||
retries += 1
|
||||
if retries > max_retries:
|
||||
raise e
|
||||
|
||||
@@ -62,6 +62,27 @@ def test_json_mode():
|
||||
assert user.age == 25
|
||||
|
||||
|
||||
|
||||
def test_markdown_json_mode():
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-1106",
|
||||
response_format={"type": "json_object"},
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{UserExtract.model_json_schema()['properties']}",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Extract jason is 25 years old",
|
||||
},
|
||||
],
|
||||
)
|
||||
user = UserExtract.from_response(response, mode=Mode.MD_JSON)
|
||||
assert user.name.lower() == "jason"
|
||||
assert user.age == 25
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS])
|
||||
def test_mode(mode):
|
||||
client = instructor.patch(OpenAI(), mode=mode)
|
||||
|
||||
Reference in New Issue
Block a user