mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
feat: Improve MD_JSON mode (#490)
This commit is contained in:
@@ -33,12 +33,12 @@ def extract(data) -> UserDetail:
|
||||
start = time.perf_counter() # (1)
|
||||
model = extract("Extract jason is 25 years old")
|
||||
print(f"Time taken: {time.perf_counter() - start}")
|
||||
#> Time taken: 0.8392175831831992
|
||||
#> Time taken: 0.7583793329977198
|
||||
|
||||
start = time.perf_counter()
|
||||
model = extract("Extract jason is 25 years old") # (2)
|
||||
print(f"Time taken: {time.perf_counter() - start}")
|
||||
#> Time taken: 8.33999365568161e-07
|
||||
#> Time taken: 4.3330073822289705e-06
|
||||
```
|
||||
|
||||
1. Using `time.perf_counter()` to measure the time taken to run the function is better than using `time.time()` because it's more accurate and less susceptible to system clock changes.
|
||||
|
||||
@@ -157,8 +157,8 @@ async def print_iterable_results():
|
||||
)
|
||||
async for m in model:
|
||||
print(m)
|
||||
#> name='John Smith' age=30
|
||||
#> name='Mary Jane' age=28
|
||||
#> name='John Doe' age=30
|
||||
#> name='Jane Doe' age=28
|
||||
|
||||
|
||||
import asyncio
|
||||
|
||||
@@ -89,7 +89,7 @@ print(user2.model_dump_json(indent=2))
|
||||
{
|
||||
"result": null,
|
||||
"error": false,
|
||||
"message": null
|
||||
"message": "Unknown user"
|
||||
}
|
||||
"""
|
||||
```
|
||||
|
||||
@@ -150,7 +150,7 @@ class SearchQuery(BaseModel):
|
||||
|
||||
def execute(self):
|
||||
print(f"Searching for {self.query} of type {self.query_type}")
|
||||
#> Searching for cat pictures of type image
|
||||
#> Searching for cat of type image
|
||||
return "Results for cat"
|
||||
|
||||
|
||||
|
||||
@@ -44,9 +44,9 @@ function_calls = client.chat.completions.create(
|
||||
|
||||
for fc in function_calls:
|
||||
print(fc)
|
||||
#> location='Toronto' units='imperial'
|
||||
#> location='Toronto' units='metric'
|
||||
#> location='Dallas' units='imperial'
|
||||
#> query='who won the super bowl'
|
||||
#> query='super bowl winner'
|
||||
```
|
||||
|
||||
1. Set the mode to `PARALLEL_TOOLS` to enable parallel function calling.
|
||||
|
||||
@@ -119,10 +119,10 @@ print(extraction.model_dump_json(indent=2))
|
||||
"twitter": "@CodeMaster2023"
|
||||
}
|
||||
],
|
||||
"date": "2024-03-15",
|
||||
"date": "March 15th, 2024",
|
||||
"location": "Grand Tech Arena located at 4521 Innovation Drive",
|
||||
"budget": 50000,
|
||||
"deadline": "2024-02-20"
|
||||
"deadline": "February 20th"
|
||||
}
|
||||
"""
|
||||
```
|
||||
|
||||
@@ -25,7 +25,7 @@ user: UserExtract = client.chat.completions.create(
|
||||
print(user._raw_response)
|
||||
"""
|
||||
ChatCompletion(
|
||||
id='chatcmpl-8u9bsrmmf5YjZyfCtQymoZV8LK1qg',
|
||||
id='chatcmpl-8zpltT9vXJdO5OE3AfDsOhAUr911A',
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason='stop',
|
||||
@@ -37,7 +37,7 @@ ChatCompletion(
|
||||
function_call=None,
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id='call_O5rpXf47YgXiYrYWv45yZUeM',
|
||||
id='call_vXI3foz7jqlzFILU9pwuYJZB',
|
||||
function=Function(
|
||||
arguments='{"name":"Jason","age":25}', name='UserExtract'
|
||||
),
|
||||
@@ -47,10 +47,10 @@ ChatCompletion(
|
||||
),
|
||||
)
|
||||
],
|
||||
created=1708394000,
|
||||
created=1709747709,
|
||||
model='gpt-3.5-turbo-0125',
|
||||
object='chat.completion',
|
||||
system_fingerprint='fp_69829325d0',
|
||||
system_fingerprint='fp_2b778c6b35',
|
||||
usage=CompletionUsage(completion_tokens=9, prompt_tokens=82, total_tokens=91),
|
||||
)
|
||||
"""
|
||||
|
||||
@@ -91,7 +91,7 @@ except ValidationError as e:
|
||||
"""
|
||||
1 validation error for QuestionAnswer
|
||||
answer
|
||||
Assertion failed, The statement promotes objectionable behavior by encouraging evil and theft. [type=assertion_error, input_value='The meaning of life is to be evil and steal', input_type=str]
|
||||
Assertion failed, The statement promotes objectionable behavior by encouraging evil and stealing, which goes against the rule of not saying objectionable things. [type=assertion_error, input_value='The meaning of life is to be evil and steal', input_type=str]
|
||||
For further information visit https://errors.pydantic.dev/2.6/v/assertion_error
|
||||
"""
|
||||
```
|
||||
|
||||
@@ -47,6 +47,7 @@ client = instructor.patch(client, mode=instructor.Mode.TOOLS)
|
||||
# Rate limit the number of requests
|
||||
sem = asyncio.Semaphore(5)
|
||||
|
||||
|
||||
# Use an Enum to define the types of questions
|
||||
class QuestionType(Enum):
|
||||
CONTACT = "CONTACT"
|
||||
|
||||
+2
-3
@@ -51,9 +51,8 @@ client = Groq(
|
||||
)
|
||||
|
||||
# By default, the patch function will patch the ChatCompletion.create and ChatCompletion.create methods to support the response_model parameter
|
||||
client = instructor.patch(
|
||||
client, mode=instructor.Mode.MD_JSON
|
||||
)
|
||||
client = instructor.patch(client, mode=instructor.Mode.MD_JSON)
|
||||
|
||||
|
||||
# Now, we can use the response_model parameter using only a base model
|
||||
# rather than having to use the OpenAISchema class
|
||||
|
||||
+2
-4
@@ -50,12 +50,10 @@ from mistralai.client import MistralClient
|
||||
# enables `response_model` in chat call
|
||||
client = MistralClient()
|
||||
|
||||
patched_chat = instructor.patch(
|
||||
create=client.chat,
|
||||
mode=instructor.Mode.MISTRAL_TOOLS
|
||||
)
|
||||
patched_chat = instructor.patch(create=client.chat, mode=instructor.Mode.MISTRAL_TOOLS)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
class UserDetails(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
@@ -108,13 +108,13 @@ if __name__ == "__main__":
|
||||
assert isinstance(df, pd.DataFrame)
|
||||
print(df)
|
||||
"""
|
||||
Party Years Served
|
||||
Party Years Served
|
||||
President
|
||||
Joe Biden Democratic 2021-
|
||||
Donald Trump Republican 2017-2021
|
||||
Barack Obama Democratic 2009-2017
|
||||
George W. Bush Republican 2001-2009
|
||||
Bill Clinton Democratic 1993-2001
|
||||
Joe Biden Democratic 2021-Current
|
||||
Donald Trump Republican 2017-2021
|
||||
Barack Obama Democratic 2009-2017
|
||||
George W. Bush Republican 2001-2009
|
||||
Bill Clinton Democratic 1993-2001
|
||||
"""
|
||||
|
||||
table = extract_table(
|
||||
|
||||
@@ -15,7 +15,7 @@ instructor hub pull --slug youtube-clips --py > youtube_clips.py
|
||||
```python
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Generator, Iterable
|
||||
from typing import List, Generator, Iterable
|
||||
import instructor
|
||||
import openai
|
||||
|
||||
@@ -24,12 +24,12 @@ client = instructor.patch(openai.OpenAI())
|
||||
|
||||
def extract_video_id(url: str) -> str | None:
|
||||
import re
|
||||
|
||||
match = re.search(r"v=([a-zA-Z0-9_-]+)", url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
|
||||
|
||||
class TranscriptSegment(BaseModel):
|
||||
source_id: int
|
||||
start: float
|
||||
@@ -51,9 +51,7 @@ def get_transcript_with_timing(
|
||||
|
||||
|
||||
class YoutubeClip(BaseModel):
|
||||
title: str = Field(
|
||||
description="Specific and informative title for the clip."
|
||||
)
|
||||
title: str = Field(description="Specific and informative title for the clip.")
|
||||
description: str = Field(
|
||||
description="A detailed description of the clip, including notable quotes or phrases."
|
||||
)
|
||||
@@ -98,7 +96,6 @@ if __name__ == "__main__":
|
||||
console = Console()
|
||||
url = Prompt.ask("Enter a YouTube URL")
|
||||
|
||||
|
||||
with console.status("[bold green]Processing YouTube URL...") as status:
|
||||
video_id = extract_video_id(url)
|
||||
|
||||
|
||||
+4
-4
@@ -115,7 +115,7 @@ print(response.model_dump_json(indent=2))
|
||||
print(user._raw_response.model_dump_json(indent=2))
|
||||
"""
|
||||
{
|
||||
"id": "chatcmpl-8u9e2TV3ehCgLsRxNLLeAbzpEmBuZ",
|
||||
"id": "chatcmpl-8zplvRbNM8iKSVa3Ld9NmVICeXZZ9",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
@@ -127,7 +127,7 @@ print(response.model_dump_json(indent=2))
|
||||
"function_call": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_3ZuQhfteTLEy7CUokjwnLBHr",
|
||||
"id": "call_V5FRMSXrHFFTTqTjpwA76h7t",
|
||||
"function": {
|
||||
"arguments": "{\"name\":\"Jason\",\"age\":25}",
|
||||
"name": "UserDetail"
|
||||
@@ -138,10 +138,10 @@ print(response.model_dump_json(indent=2))
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1708394134,
|
||||
"created": 1709747711,
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "fp_69829325d0",
|
||||
"system_fingerprint": "fp_2b778c6b35",
|
||||
"usage": {
|
||||
"completion_tokens": 9,
|
||||
"prompt_tokens": 81,
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, AsyncGenerator, Generator, Iterable, List, Optional, Tup
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from instructor.function_calls import OpenAISchema, Mode
|
||||
from instructor.utils import extract_json_from_stream, extract_json_from_stream_async
|
||||
|
||||
|
||||
class IterableBase:
|
||||
@@ -13,6 +14,10 @@ class IterableBase:
|
||||
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
|
||||
) -> Generator[BaseModel, None, None]: # noqa: ARG003
|
||||
json_chunks = cls.extract_json(completion, mode)
|
||||
|
||||
if mode == Mode.MD_JSON:
|
||||
json_chunks = extract_json_from_stream(json_chunks)
|
||||
|
||||
yield from cls.tasks_from_chunks(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@@ -20,6 +25,10 @@ class IterableBase:
|
||||
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
|
||||
) -> AsyncGenerator[BaseModel, None]:
|
||||
json_chunks = cls.extract_json_async(completion, mode)
|
||||
|
||||
if mode == Mode.MD_JSON:
|
||||
json_chunks = extract_json_from_stream_async(json_chunks)
|
||||
|
||||
return cls.tasks_from_chunks_async(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@@ -110,13 +119,14 @@ class IterableBase:
|
||||
|
||||
@staticmethod
|
||||
def get_object(s: str, stack: int) -> Tuple[Optional[str], str]:
|
||||
start_index = s.find("{")
|
||||
for i, c in enumerate(s):
|
||||
if c == "{":
|
||||
stack += 1
|
||||
if c == "}":
|
||||
stack -= 1
|
||||
if stack == 0:
|
||||
return s[: i + 1], s[i + 2 :]
|
||||
return s[start_index : i + 1], s[i + 2 :]
|
||||
return None, s
|
||||
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from copy import deepcopy
|
||||
|
||||
from instructor.function_calls import Mode
|
||||
from instructor.dsl.partialjson import JSONParser
|
||||
from instructor.utils import extract_json_from_stream, extract_json_from_stream_async
|
||||
|
||||
parser = JSONParser()
|
||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||
@@ -35,6 +36,10 @@ class PartialBase(Generic[T_Model]):
|
||||
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
|
||||
) -> Generator[T_Model, None, None]:
|
||||
json_chunks = cls.extract_json(completion, mode)
|
||||
|
||||
if mode == Mode.MD_JSON:
|
||||
json_chunks = extract_json_from_stream(json_chunks)
|
||||
|
||||
yield from cls.model_from_chunks(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@@ -42,6 +47,10 @@ class PartialBase(Generic[T_Model]):
|
||||
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
|
||||
) -> AsyncGenerator[T_Model, None]:
|
||||
json_chunks = cls.extract_json_async(completion, mode)
|
||||
|
||||
if mode == Mode.MD_JSON:
|
||||
json_chunks = extract_json_from_stream_async(json_chunks)
|
||||
|
||||
return cls.model_from_chunks_async(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -7,6 +7,7 @@ import enum
|
||||
import warnings
|
||||
import logging
|
||||
from openai.types.chat import ChatCompletion
|
||||
from instructor.utils import extract_json_from_codeblock
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -135,6 +136,9 @@ class OpenAISchema(BaseModel): # type: ignore[misc]
|
||||
strict=strict,
|
||||
)
|
||||
elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
|
||||
if mode == Mode.MD_JSON:
|
||||
message.content = extract_json_from_codeblock(message.content or "")
|
||||
|
||||
model_response = cls.model_validate_json(
|
||||
message.content, # type: ignore
|
||||
context=validation_context,
|
||||
|
||||
+9
-48
@@ -1,6 +1,5 @@
|
||||
# type: ignore[all]
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from textwrap import dedent
|
||||
from collections.abc import Iterable
|
||||
@@ -24,8 +23,6 @@ from typing import (
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from pydantic import BaseModel, ValidationError
|
||||
@@ -34,6 +31,7 @@ from instructor.dsl.iterable import IterableModel, IterableBase
|
||||
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
|
||||
from instructor.dsl.partial import PartialBase
|
||||
from instructor.dsl.simple_type import ModelAdapter, AdapterBase, is_simple_type
|
||||
from instructor.utils import dump_message, update_total_usage
|
||||
|
||||
from .function_calls import Mode, OpenAISchema, openai_schema
|
||||
|
||||
@@ -47,35 +45,6 @@ T_ParamSpec = ParamSpec("T_ParamSpec")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def update_total_usage(response, total_usage):
|
||||
if isinstance(response, ChatCompletion) and response.usage is not None:
|
||||
total_usage.completion_tokens += response.usage.completion_tokens or 0
|
||||
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
|
||||
total_usage.total_tokens += response.usage.total_tokens or 0
|
||||
response.usage = total_usage # Replace each response usage with the total usage
|
||||
return response
|
||||
|
||||
|
||||
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
"""Dumps a message to a dict, to be returned to the OpenAI API.
|
||||
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
|
||||
if it isn't used.
|
||||
"""
|
||||
ret: ChatCompletionMessageParam = {
|
||||
"role": message.role,
|
||||
"content": message.content or "",
|
||||
}
|
||||
if hasattr(message, "tool_calls") and message.tool_calls is not None:
|
||||
ret["tool_calls"] = message.model_dump()["tool_calls"]
|
||||
if (
|
||||
hasattr(message, "function_call")
|
||||
and message.function_call is not None
|
||||
and ret["content"]
|
||||
):
|
||||
ret["content"] += json.dumps(message.model_dump()["function_call"])
|
||||
return ret
|
||||
|
||||
|
||||
def handle_response_model(
|
||||
response_model: T, mode: Mode = Mode.TOOLS, **kwargs
|
||||
) -> Union[Type[OpenAISchema], dict]:
|
||||
@@ -153,12 +122,12 @@ def handle_response_model(
|
||||
f"""
|
||||
As a genius expert, your task is to understand the content and provide
|
||||
the parsed objects in json that match the following json_schema:\n
|
||||
{response_model.model_json_schema()['properties']}
|
||||
|
||||
{response_model.model_json_schema()}
|
||||
|
||||
Make sure to return an instance of the JSON, not the schema itself
|
||||
"""
|
||||
)
|
||||
# Check for nested models
|
||||
if "$defs" in response_model.model_json_schema():
|
||||
message += f"\nHere are some more definitions to adhere too:\n{response_model.model_json_schema()['$defs']}"
|
||||
|
||||
if mode == Mode.JSON:
|
||||
new_kwargs["response_format"] = {"type": "json_object"}
|
||||
@@ -172,11 +141,10 @@ def handle_response_model(
|
||||
elif mode == Mode.MD_JSON:
|
||||
new_kwargs["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Here is the perfectly correctly formatted JSON\n```json",
|
||||
"role": "user",
|
||||
"content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA",
|
||||
},
|
||||
)
|
||||
new_kwargs["stop"] = "```"
|
||||
# check that the first message is a system message
|
||||
# if it is not, add a system message to the beginning
|
||||
if new_kwargs["messages"][0]["role"] != "system":
|
||||
@@ -402,8 +370,8 @@ async def retry_async(
|
||||
if mode == Mode.MD_JSON:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "```json",
|
||||
"role": "user",
|
||||
"content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA",
|
||||
},
|
||||
)
|
||||
raise e
|
||||
@@ -473,13 +441,6 @@ def retry_sync(
|
||||
"content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
if mode == Mode.MD_JSON:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "```json",
|
||||
},
|
||||
)
|
||||
raise e
|
||||
except RetryError as e:
|
||||
logger.exception(f"Failed after retries: {e.last_attempt.exception}")
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
import json
|
||||
from typing import Generator, Iterable, AsyncGenerator
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
)
|
||||
|
||||
|
||||
def extract_json_from_codeblock(content: str) -> str:
|
||||
first_paren = content.find("{")
|
||||
last_paren = content.rfind("}")
|
||||
return content[first_paren : last_paren + 1]
|
||||
|
||||
|
||||
def extract_json_from_stream(chunks: Iterable[str]) -> Generator[str, None, None]:
|
||||
capturing = False
|
||||
brace_count = 0
|
||||
for chunk in chunks:
|
||||
for char in chunk:
|
||||
if char == "{":
|
||||
capturing = True
|
||||
brace_count += 1
|
||||
yield char
|
||||
elif char == "}" and capturing:
|
||||
brace_count -= 1
|
||||
yield char
|
||||
if brace_count == 0:
|
||||
capturing = False
|
||||
break # Cease yielding upon closing the current JSON object
|
||||
elif capturing:
|
||||
yield char
|
||||
|
||||
|
||||
async def extract_json_from_stream_async(
|
||||
chunks: AsyncGenerator[str, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
capturing = False
|
||||
brace_count = 0
|
||||
async for chunk in chunks:
|
||||
for char in chunk:
|
||||
if char == "{":
|
||||
capturing = True
|
||||
brace_count += 1
|
||||
yield char
|
||||
elif char == "}" and capturing:
|
||||
brace_count -= 1
|
||||
yield char
|
||||
if brace_count == 0:
|
||||
capturing = False
|
||||
break # Cease yielding upon closing the current JSON object
|
||||
elif capturing:
|
||||
yield char
|
||||
|
||||
|
||||
def update_total_usage(response, total_usage):
|
||||
if isinstance(response, ChatCompletion) and response.usage is not None:
|
||||
total_usage.completion_tokens += response.usage.completion_tokens or 0
|
||||
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
|
||||
total_usage.total_tokens += response.usage.total_tokens or 0
|
||||
response.usage = total_usage # Replace each response usage with the total usage
|
||||
return response
|
||||
|
||||
|
||||
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
"""Dumps a message to a dict, to be returned to the OpenAI API.
|
||||
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
|
||||
if it isn't used.
|
||||
"""
|
||||
ret: ChatCompletionMessageParam = {
|
||||
"role": message.role,
|
||||
"content": message.content or "",
|
||||
}
|
||||
if hasattr(message, "tool_calls") and message.tool_calls is not None:
|
||||
ret["tool_calls"] = message.model_dump()["tool_calls"]
|
||||
if (
|
||||
hasattr(message, "function_call")
|
||||
and message.function_call is not None
|
||||
and ret["content"]
|
||||
):
|
||||
ret["content"] += json.dumps(message.model_dump()["function_call"])
|
||||
return ret
|
||||
@@ -2,6 +2,5 @@ import instructor
|
||||
|
||||
models = ["gpt-4-turbo-preview"]
|
||||
modes = [
|
||||
instructor.Mode.JSON,
|
||||
instructor.Mode.TOOLS,
|
||||
instructor.Mode.MD_JSON,
|
||||
]
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
import json
|
||||
import pytest
|
||||
from instructor.utils import (
|
||||
extract_json_from_codeblock,
|
||||
extract_json_from_stream,
|
||||
extract_json_from_stream_async,
|
||||
)
|
||||
|
||||
|
||||
def test_extract_json_from_codeblock():
|
||||
example = """
|
||||
Here is a response
|
||||
|
||||
```json
|
||||
{
|
||||
"key": "value"
|
||||
}
|
||||
```
|
||||
"""
|
||||
result = extract_json_from_codeblock(example)
|
||||
assert json.loads(result) == {"key": "value"}
|
||||
|
||||
|
||||
def test_extract_json_from_codeblock_no_end():
|
||||
example = """
|
||||
Here is a response
|
||||
|
||||
```json
|
||||
{
|
||||
"key": "value",
|
||||
"another_key": [{"key": {"key": "value"}}]
|
||||
}
|
||||
"""
|
||||
result = extract_json_from_codeblock(example)
|
||||
assert json.loads(result) == {
|
||||
"key": "value",
|
||||
"another_key": [{"key": {"key": "value"}}],
|
||||
}
|
||||
|
||||
|
||||
def test_extract_json_from_codeblock_no_start():
|
||||
example = """
|
||||
Here is a response
|
||||
|
||||
{
|
||||
"key": "value",
|
||||
"another_key": [{"key": {"key": "value"}}, {"key": "value"}]
|
||||
}
|
||||
"""
|
||||
result = extract_json_from_codeblock(example)
|
||||
assert json.loads(result) == {
|
||||
"key": "value",
|
||||
"another_key": [{"key": {"key": "value"}}, {"key": "value"}],
|
||||
}
|
||||
|
||||
|
||||
def test_stream_json():
|
||||
text = """here is the json for you!
|
||||
|
||||
```json
|
||||
, here
|
||||
{
|
||||
"key": "value",
|
||||
"another_key": [{"key": {"key": "value"}}]
|
||||
}
|
||||
```
|
||||
|
||||
What do you think?
|
||||
"""
|
||||
|
||||
def batch_strings(chunks, n=2):
|
||||
batch = ""
|
||||
for chunk in chunks:
|
||||
for char in chunk:
|
||||
batch += char
|
||||
if len(batch) == n:
|
||||
yield batch
|
||||
batch = ""
|
||||
if batch: # Yield any remaining characters in the last batch
|
||||
yield batch
|
||||
|
||||
result = json.loads(
|
||||
"".join(list(extract_json_from_stream(batch_strings(text, n=3))))
|
||||
)
|
||||
assert result == {"key": "value", "another_key": [{"key": {"key": "value"}}]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_json_async():
|
||||
text = """here is the json for you!
|
||||
|
||||
```json
|
||||
, here
|
||||
{
|
||||
"key": "value",
|
||||
"another_key": [{"key": {"key": "value"}}, {"key": "value"}]
|
||||
}
|
||||
```
|
||||
|
||||
What do you think?
|
||||
"""
|
||||
|
||||
async def batch_strings_async(chunks, n=2):
|
||||
batch = ""
|
||||
for chunk in chunks:
|
||||
for char in chunk:
|
||||
batch += char
|
||||
if len(batch) == n:
|
||||
yield batch
|
||||
batch = ""
|
||||
if batch: # Yield any remaining characters in the last batch
|
||||
yield batch
|
||||
|
||||
result = json.loads(
|
||||
"".join(
|
||||
[
|
||||
chunk
|
||||
async for chunk in extract_json_from_stream_async(
|
||||
batch_strings_async(text, n=3)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
assert result == {
|
||||
"key": "value",
|
||||
"another_key": [{"key": {"key": "value"}}, {"key": "value"}],
|
||||
}
|
||||
Reference in New Issue
Block a user