feat: Improve MD_JSON mode (#490)

This commit is contained in:
Jason Liu
2024-03-06 12:57:05 -05:00
committed by GitHub
parent 00bedcf165
commit 3e44a6bc30
21 changed files with 277 additions and 89 deletions
+2 -2
View File
@@ -33,12 +33,12 @@ def extract(data) -> UserDetail:
start = time.perf_counter() # (1)
model = extract("Extract jason is 25 years old")
print(f"Time taken: {time.perf_counter() - start}")
#> Time taken: 0.8392175831831992
#> Time taken: 0.7583793329977198
start = time.perf_counter()
model = extract("Extract jason is 25 years old") # (2)
print(f"Time taken: {time.perf_counter() - start}")
#> Time taken: 8.33999365568161e-07
#> Time taken: 4.3330073822289705e-06
```
1. Using `time.perf_counter()` to measure the time taken to run the function is better than using `time.time()` because it's more accurate and less susceptible to system clock changes.
+2 -2
View File
@@ -157,8 +157,8 @@ async def print_iterable_results():
)
async for m in model:
print(m)
#> name='John Smith' age=30
#> name='Mary Jane' age=28
#> name='John Doe' age=30
#> name='Jane Doe' age=28
import asyncio
+1 -1
View File
@@ -89,7 +89,7 @@ print(user2.model_dump_json(indent=2))
{
"result": null,
"error": false,
"message": null
"message": "Unknown user"
}
"""
```
+1 -1
View File
@@ -150,7 +150,7 @@ class SearchQuery(BaseModel):
def execute(self):
print(f"Searching for {self.query} of type {self.query_type}")
#> Searching for cat pictures of type image
#> Searching for cat of type image
return "Results for cat"
+2 -2
View File
@@ -44,9 +44,9 @@ function_calls = client.chat.completions.create(
for fc in function_calls:
print(fc)
#> location='Toronto' units='imperial'
#> location='Toronto' units='metric'
#> location='Dallas' units='imperial'
#> query='who won the super bowl'
#> query='super bowl winner'
```
1. Set the mode to `PARALLEL_TOOLS` to enable parallel function calling.
+2 -2
View File
@@ -119,10 +119,10 @@ print(extraction.model_dump_json(indent=2))
"twitter": "@CodeMaster2023"
}
],
"date": "2024-03-15",
"date": "March 15th, 2024",
"location": "Grand Tech Arena located at 4521 Innovation Drive",
"budget": 50000,
"deadline": "2024-02-20"
"deadline": "February 20th"
}
"""
```
+4 -4
View File
@@ -25,7 +25,7 @@ user: UserExtract = client.chat.completions.create(
print(user._raw_response)
"""
ChatCompletion(
id='chatcmpl-8u9bsrmmf5YjZyfCtQymoZV8LK1qg',
id='chatcmpl-8zpltT9vXJdO5OE3AfDsOhAUr911A',
choices=[
Choice(
finish_reason='stop',
@@ -37,7 +37,7 @@ ChatCompletion(
function_call=None,
tool_calls=[
ChatCompletionMessageToolCall(
id='call_O5rpXf47YgXiYrYWv45yZUeM',
id='call_vXI3foz7jqlzFILU9pwuYJZB',
function=Function(
arguments='{"name":"Jason","age":25}', name='UserExtract'
),
@@ -47,10 +47,10 @@ ChatCompletion(
),
)
],
created=1708394000,
created=1709747709,
model='gpt-3.5-turbo-0125',
object='chat.completion',
system_fingerprint='fp_69829325d0',
system_fingerprint='fp_2b778c6b35',
usage=CompletionUsage(completion_tokens=9, prompt_tokens=82, total_tokens=91),
)
"""
+1 -1
View File
@@ -91,7 +91,7 @@ except ValidationError as e:
"""
1 validation error for QuestionAnswer
answer
Assertion failed, The statement promotes objectionable behavior by encouraging evil and theft. [type=assertion_error, input_value='The meaning of life is to be evil and steal', input_type=str]
Assertion failed, The statement promotes objectionable behavior by encouraging evil and stealing, which goes against the rule of not saying objectionable things. [type=assertion_error, input_value='The meaning of life is to be evil and steal', input_type=str]
For further information visit https://errors.pydantic.dev/2.6/v/assertion_error
"""
```
@@ -47,6 +47,7 @@ client = instructor.patch(client, mode=instructor.Mode.TOOLS)
# Rate limit the number of requests
sem = asyncio.Semaphore(5)
# Use an Enum to define the types of questions
class QuestionType(Enum):
CONTACT = "CONTACT"
+2 -3
View File
@@ -51,9 +51,8 @@ client = Groq(
)
# By default, the patch function will patch the ChatCompletion.create and ChatCompletion.acreate methods to support the response_model parameter
client = instructor.patch(
client, mode=instructor.Mode.MD_JSON
)
client = instructor.patch(client, mode=instructor.Mode.MD_JSON)
# Now, we can use the response_model parameter using only a base model
# rather than having to use the OpenAISchema class
+2 -4
View File
@@ -50,12 +50,10 @@ from mistralai.client import MistralClient
# enables `response_model` in chat call
client = MistralClient()
patched_chat = instructor.patch(
create=client.chat,
mode=instructor.Mode.MISTRAL_TOOLS
)
patched_chat = instructor.patch(create=client.chat, mode=instructor.Mode.MISTRAL_TOOLS)
if __name__ == "__main__":
class UserDetails(BaseModel):
name: str
age: int
+6 -6
View File
@@ -108,13 +108,13 @@ if __name__ == "__main__":
assert isinstance(df, pd.DataFrame)
print(df)
"""
Party Years Served
Party Years Served
President
Joe Biden Democratic 2021-
Donald Trump Republican 2017-2021
Barack Obama Democratic 2009-2017
George W. Bush Republican 2001-2009
Bill Clinton Democratic 1993-2001
Joe Biden Democratic 2021-Current
Donald Trump Republican 2017-2021
Barack Obama Democratic 2009-2017
George W. Bush Republican 2001-2009
Bill Clinton Democratic 1993-2001
"""
table = extract_table(
+3 -6
View File
@@ -15,7 +15,7 @@ instructor hub pull --slug youtube-clips --py > youtube_clips.py
```python
from youtube_transcript_api import YouTubeTranscriptApi
from pydantic import BaseModel, Field
from typing import List, Dict, Generator, Iterable
from typing import List, Generator, Iterable
import instructor
import openai
@@ -24,12 +24,12 @@ client = instructor.patch(openai.OpenAI())
def extract_video_id(url: str) -> str | None:
import re
match = re.search(r"v=([a-zA-Z0-9_-]+)", url)
if match:
return match.group(1)
class TranscriptSegment(BaseModel):
source_id: int
start: float
@@ -51,9 +51,7 @@ def get_transcript_with_timing(
class YoutubeClip(BaseModel):
title: str = Field(
description="Specific and informative title for the clip."
)
title: str = Field(description="Specific and informative title for the clip.")
description: str = Field(
description="A detailed description of the clip, including notable quotes or phrases."
)
@@ -98,7 +96,6 @@ if __name__ == "__main__":
console = Console()
url = Prompt.ask("Enter a YouTube URL")
with console.status("[bold green]Processing YouTube URL...") as status:
video_id = extract_video_id(url)
+4 -4
View File
@@ -115,7 +115,7 @@ print(response.model_dump_json(indent=2))
print(user._raw_response.model_dump_json(indent=2))
"""
{
"id": "chatcmpl-8u9e2TV3ehCgLsRxNLLeAbzpEmBuZ",
"id": "chatcmpl-8zplvRbNM8iKSVa3Ld9NmVICeXZZ9",
"choices": [
{
"finish_reason": "stop",
@@ -127,7 +127,7 @@ print(response.model_dump_json(indent=2))
"function_call": null,
"tool_calls": [
{
"id": "call_3ZuQhfteTLEy7CUokjwnLBHr",
"id": "call_V5FRMSXrHFFTTqTjpwA76h7t",
"function": {
"arguments": "{\"name\":\"Jason\",\"age\":25}",
"name": "UserDetail"
@@ -138,10 +138,10 @@ print(response.model_dump_json(indent=2))
}
}
],
"created": 1708394134,
"created": 1709747711,
"model": "gpt-3.5-turbo-0125",
"object": "chat.completion",
"system_fingerprint": "fp_69829325d0",
"system_fingerprint": "fp_2b778c6b35",
"usage": {
"completion_tokens": 9,
"prompt_tokens": 81,
+11 -1
View File
@@ -3,6 +3,7 @@ from typing import Any, AsyncGenerator, Generator, Iterable, List, Optional, Tup
from pydantic import BaseModel, Field, create_model
from instructor.function_calls import OpenAISchema, Mode
from instructor.utils import extract_json_from_stream, extract_json_from_stream_async
class IterableBase:
@@ -13,6 +14,10 @@ class IterableBase:
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
) -> Generator[BaseModel, None, None]: # noqa: ARG003
json_chunks = cls.extract_json(completion, mode)
if mode == Mode.MD_JSON:
json_chunks = extract_json_from_stream(json_chunks)
yield from cls.tasks_from_chunks(json_chunks, **kwargs)
@classmethod
@@ -20,6 +25,10 @@ class IterableBase:
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
) -> AsyncGenerator[BaseModel, None]:
json_chunks = cls.extract_json_async(completion, mode)
if mode == Mode.MD_JSON:
json_chunks = extract_json_from_stream_async(json_chunks)
return cls.tasks_from_chunks_async(json_chunks, **kwargs)
@classmethod
@@ -110,13 +119,14 @@ class IterableBase:
@staticmethod
def get_object(s: str, stack: int) -> Tuple[Optional[str], str]:
start_index = s.find("{")
for i, c in enumerate(s):
if c == "{":
stack += 1
if c == "}":
stack -= 1
if stack == 0:
return s[: i + 1], s[i + 2 :]
return s[start_index : i + 1], s[i + 2 :]
return None, s
+9
View File
@@ -24,6 +24,7 @@ from copy import deepcopy
from instructor.function_calls import Mode
from instructor.dsl.partialjson import JSONParser
from instructor.utils import extract_json_from_stream, extract_json_from_stream_async
parser = JSONParser()
T_Model = TypeVar("T_Model", bound=BaseModel)
@@ -35,6 +36,10 @@ class PartialBase(Generic[T_Model]):
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
) -> Generator[T_Model, None, None]:
json_chunks = cls.extract_json(completion, mode)
if mode == Mode.MD_JSON:
json_chunks = extract_json_from_stream(json_chunks)
yield from cls.model_from_chunks(json_chunks, **kwargs)
@classmethod
@@ -42,6 +47,10 @@ class PartialBase(Generic[T_Model]):
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
) -> AsyncGenerator[T_Model, None]:
json_chunks = cls.extract_json_async(completion, mode)
if mode == Mode.MD_JSON:
json_chunks = extract_json_from_stream_async(json_chunks)
return cls.model_from_chunks_async(json_chunks, **kwargs)
@classmethod
+4
View File
@@ -7,6 +7,7 @@ import enum
import warnings
import logging
from openai.types.chat import ChatCompletion
from instructor.utils import extract_json_from_codeblock
T = TypeVar("T")
@@ -135,6 +136,9 @@ class OpenAISchema(BaseModel): # type: ignore[misc]
strict=strict,
)
elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
if mode == Mode.MD_JSON:
message.content = extract_json_from_codeblock(message.content or "")
model_response = cls.model_validate_json(
message.content, # type: ignore
context=validation_context,
+9 -48
View File
@@ -1,6 +1,5 @@
# type: ignore[all]
import inspect
import json
import logging
from textwrap import dedent
from collections.abc import Iterable
@@ -24,8 +23,6 @@ from typing import (
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessage,
ChatCompletionMessageParam,
)
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel, ValidationError
@@ -34,6 +31,7 @@ from instructor.dsl.iterable import IterableModel, IterableBase
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
from instructor.dsl.partial import PartialBase
from instructor.dsl.simple_type import ModelAdapter, AdapterBase, is_simple_type
from instructor.utils import dump_message, update_total_usage
from .function_calls import Mode, OpenAISchema, openai_schema
@@ -47,35 +45,6 @@ T_ParamSpec = ParamSpec("T_ParamSpec")
T = TypeVar("T")
def update_total_usage(response, total_usage):
if isinstance(response, ChatCompletion) and response.usage is not None:
total_usage.completion_tokens += response.usage.completion_tokens or 0
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
total_usage.total_tokens += response.usage.total_tokens or 0
response.usage = total_usage # Replace each response usage with the total usage
return response
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
"""Dumps a message to a dict, to be returned to the OpenAI API.
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
if it isn't used.
"""
ret: ChatCompletionMessageParam = {
"role": message.role,
"content": message.content or "",
}
if hasattr(message, "tool_calls") and message.tool_calls is not None:
ret["tool_calls"] = message.model_dump()["tool_calls"]
if (
hasattr(message, "function_call")
and message.function_call is not None
and ret["content"]
):
ret["content"] += json.dumps(message.model_dump()["function_call"])
return ret
def handle_response_model(
response_model: T, mode: Mode = Mode.TOOLS, **kwargs
) -> Union[Type[OpenAISchema], dict]:
@@ -153,12 +122,12 @@ def handle_response_model(
f"""
As a genius expert, your task is to understand the content and provide
the parsed objects in json that match the following json_schema:\n
{response_model.model_json_schema()['properties']}
{response_model.model_json_schema()}
Make sure to return an instance of the JSON, not the schema itself
"""
)
# Check for nested models
if "$defs" in response_model.model_json_schema():
message += f"\nHere are some more definitions to adhere too:\n{response_model.model_json_schema()['$defs']}"
if mode == Mode.JSON:
new_kwargs["response_format"] = {"type": "json_object"}
@@ -172,11 +141,10 @@ def handle_response_model(
elif mode == Mode.MD_JSON:
new_kwargs["messages"].append(
{
"role": "assistant",
"content": "Here is the perfectly correctly formatted JSON\n```json",
"role": "user",
"content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA",
},
)
new_kwargs["stop"] = "```"
# check that the first message is a system message
# if it is not, add a system message to the beginning
if new_kwargs["messages"][0]["role"] != "system":
@@ -402,8 +370,8 @@ async def retry_async(
if mode == Mode.MD_JSON:
kwargs["messages"].append(
{
"role": "assistant",
"content": "```json",
"role": "user",
"content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA",
},
)
raise e
@@ -473,13 +441,6 @@ def retry_sync(
"content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
}
)
if mode == Mode.MD_JSON:
kwargs["messages"].append(
{
"role": "assistant",
"content": "```json",
},
)
raise e
except RetryError as e:
logger.exception(f"Failed after retries: {e.last_attempt.exception}")
+83
View File
@@ -0,0 +1,83 @@
import json
from typing import Generator, Iterable, AsyncGenerator
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessage,
ChatCompletionMessageParam,
)
def extract_json_from_codeblock(content: str) -> str:
    """Return the outermost ``{...}`` span contained in *content*.

    The span runs from the first ``{`` to the last ``}`` (inclusive), so any
    surrounding prose or markdown code-fence markers are dropped.

    Args:
        content: Raw text that is expected to embed a JSON object.

    Returns:
        The text between the first opening brace and the last closing brace,
        or an empty string when no such span exists.
    """
    first_brace = content.find("{")
    last_brace = content.rfind("}")
    # ``find`` signals "not found" with -1; used directly as a slice start it
    # would mean "last character" and could leak a stray "}" to the caller.
    if first_brace == -1 or last_brace < first_brace:
        return ""
    return content[first_brace : last_brace + 1]
def extract_json_from_stream(chunks: Iterable[str]) -> Generator[str, None, None]:
    """Yield, character by character, the first JSON object found in a stream.

    Everything before the first ``{`` is skipped. Characters are then yielded
    until the matching closing ``}`` has been emitted, after which the
    generator finishes. Braces that appear inside JSON string values
    (including after escaped quotes) are not counted as structural, and the
    result does not depend on how the text was split into chunks.

    Args:
        chunks: Iterable of text fragments making up the streamed response.

    Yields:
        Single characters that together form the extracted JSON object. If
        the stream ends before the object is closed, the partial object that
        was seen so far has been yielded.
    """
    depth = 0
    in_string = False
    escaped = False
    for chunk in chunks:
        for char in chunk:
            if depth == 0:
                # Still in the preamble (prose, ```json fence, ...).
                if char == "{":
                    depth = 1
                    yield char
                continue

            yield char
            if in_string:
                # Inside a string only an unescaped quote is significant.
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_string = False
            elif char == '"':
                in_string = True
            elif char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    # The object is complete: stop instead of scanning any
                    # trailing text for further braces.
                    return
async def extract_json_from_stream_async(
    chunks: AsyncGenerator[str, None],
) -> AsyncGenerator[str, None]:
    """Asynchronously yield the first JSON object found in a stream.

    Everything before the first ``{`` is skipped. Characters are then yielded
    until the matching closing ``}`` has been emitted, after which the
    generator finishes. Braces that appear inside JSON string values
    (including after escaped quotes) are not counted as structural, and the
    result does not depend on how the text was split into chunks.

    Args:
        chunks: Async iterable of text fragments making up the streamed
            response.

    Yields:
        Single characters that together form the extracted JSON object. If
        the stream ends before the object is closed, the partial object that
        was seen so far has been yielded.
    """
    depth = 0
    in_string = False
    escaped = False
    async for chunk in chunks:
        for char in chunk:
            if depth == 0:
                # Still in the preamble (prose, ```json fence, ...).
                if char == "{":
                    depth = 1
                    yield char
                continue

            yield char
            if in_string:
                # Inside a string only an unescaped quote is significant.
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_string = False
            elif char == '"':
                in_string = True
            elif char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    # The object is complete: stop instead of scanning any
                    # trailing text for further braces.
                    return
def update_total_usage(response, total_usage):
    """Fold a response's token usage into *total_usage* and attach the total.

    When *response* is a ``ChatCompletion`` carrying usage data, each counter
    (treating a missing value as 0) is added to *total_usage*, and the
    response's own ``usage`` is replaced by that running total. Any other
    value is handed back untouched.

    Returns:
        The same *response* object that was passed in.
    """
    # Nothing to accumulate for non-completions or completions without usage.
    if not isinstance(response, ChatCompletion) or response.usage is None:
        return response

    usage = response.usage
    total_usage.completion_tokens += usage.completion_tokens or 0
    total_usage.prompt_tokens += usage.prompt_tokens or 0
    total_usage.total_tokens += usage.total_tokens or 0
    # Each response reports the cumulative usage rather than its own.
    response.usage = total_usage
    return response
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
    """Convert a chat completion message into a dict for a follow-up request.

    Workaround for an issue with the OpenAI API, where the `tool_calls` field
    isn't allowed to be present in requests if it isn't used: the key is only
    added when the message actually carries tool calls.

    Returns:
        A dict with ``role`` and ``content`` (empty string when the message
        has no content), plus ``tool_calls`` when present. If the message has
        a ``function_call`` and non-empty content, the JSON-encoded function
        call is appended to the content.
    """
    payload: ChatCompletionMessageParam = {
        "role": message.role,
        "content": message.content or "",
    }

    if getattr(message, "tool_calls", None) is not None:
        payload["tool_calls"] = message.model_dump()["tool_calls"]

    has_function_call = getattr(message, "function_call", None) is not None
    if has_function_call and payload["content"]:
        payload["content"] += json.dumps(message.model_dump()["function_call"])

    return payload
+1 -2
View File
@@ -2,6 +2,5 @@ import instructor
models = ["gpt-4-turbo-preview"]
modes = [
instructor.Mode.JSON,
instructor.Mode.TOOLS,
instructor.Mode.MD_JSON,
]
+127
View File
@@ -0,0 +1,127 @@
import json
import pytest
from instructor.utils import (
extract_json_from_codeblock,
extract_json_from_stream,
extract_json_from_stream_async,
)
def test_extract_json_from_codeblock():
    """A fenced json block surrounded by prose yields the enclosed object."""
    response_text = """
Here is a response
```json
{
"key": "value"
}
```
"""
    extracted = extract_json_from_codeblock(response_text)
    parsed = json.loads(extracted)
    assert parsed == {"key": "value"}
def test_extract_json_from_codeblock_no_end():
    """Extraction still works when the closing code fence is missing."""
    response_text = """
Here is a response
```json
{
"key": "value",
"another_key": [{"key": {"key": "value"}}]
}
"""
    expected = {
        "key": "value",
        "another_key": [{"key": {"key": "value"}}],
    }
    extracted = extract_json_from_codeblock(response_text)
    assert json.loads(extracted) == expected
def test_extract_json_from_codeblock_no_start():
    """Extraction works on bare JSON that is not wrapped in a code fence."""
    response_text = """
Here is a response
{
"key": "value",
"another_key": [{"key": {"key": "value"}}, {"key": "value"}]
}
"""
    expected = {
        "key": "value",
        "another_key": [{"key": {"key": "value"}}, {"key": "value"}],
    }
    extracted = extract_json_from_codeblock(response_text)
    assert json.loads(extracted) == expected
def test_stream_json():
    """The streamed extractor recovers the object from re-chunked noisy text."""
    text = """here is the json for you!
```json
, here
{
"key": "value",
"another_key": [{"key": {"key": "value"}}]
}
```
What do you think?
"""

    def batch_strings(chunks, n=2):
        # Regroup the incoming characters into pieces of exactly n characters;
        # a shorter final piece carries whatever is left over.
        pending = []
        for chunk in chunks:
            for char in chunk:
                pending.append(char)
                if len(pending) == n:
                    yield "".join(pending)
                    pending = []
        if pending:
            yield "".join(pending)

    streamed = extract_json_from_stream(batch_strings(text, n=3))
    result = json.loads("".join(streamed))
    assert result == {"key": "value", "another_key": [{"key": {"key": "value"}}]}
@pytest.mark.asyncio
async def test_stream_json_async():
    """The async extractor recovers the object from re-chunked noisy text."""
    text = """here is the json for you!
```json
, here
{
"key": "value",
"another_key": [{"key": {"key": "value"}}, {"key": "value"}]
}
```
What do you think?
"""

    async def batch_strings_async(chunks, n=2):
        # Regroup the incoming characters into pieces of exactly n characters;
        # a shorter final piece carries whatever is left over.
        pending = []
        for chunk in chunks:
            for char in chunk:
                pending.append(char)
                if len(pending) == n:
                    yield "".join(pending)
                    pending = []
        if pending:
            yield "".join(pending)

    pieces = []
    async for piece in extract_json_from_stream_async(batch_strings_async(text, n=3)):
        pieces.append(piece)

    result = json.loads("".join(pieces))
    assert result == {
        "key": "value",
        "another_key": [{"key": {"key": "value"}}, {"key": "value"}],
    }