mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Clean up streaming code (#377)
This commit is contained in:
+17
-1
@@ -99,4 +99,20 @@ for user in users:
|
||||
>>> name="John" "age"=10
|
||||
```
|
||||
|
||||
This streaming is still a prototype, but should work quite well for simple schemas.
|
||||
## Asynchronous Streaming
|
||||
|
||||
`instructor` also supports asynchronous streaming. This is useful when you want to stream a response model and process the results as they come in; to iterate over the results, use the `async for` syntax.
|
||||
|
||||
```python
|
||||
model = await client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
response_model=Iterable[UserExtract],
|
||||
max_retries=2,
|
||||
stream=stream,
|
||||
messages=[
|
||||
{"role": "user", "content": "Make two up people"},
|
||||
],
|
||||
)
|
||||
async for m in model:
|
||||
assert isinstance(m, UserExtract)
|
||||
```
|
||||
|
||||
@@ -103,3 +103,21 @@ for extraction in extraction_stream:
|
||||
This will output the following:
|
||||
|
||||

|
||||
|
||||
## Asynchronous Streaming
|
||||
|
||||
`instructor` also supports asynchronous streaming. This is useful when you want to stream a response model and process the results as they come in; to iterate over the results, use the `async for` syntax.
|
||||
|
||||
```python
|
||||
model = await client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
response_model=Partial[UserExtract],
|
||||
max_retries=2,
|
||||
stream=True,
|
||||
messages=[
|
||||
{"role": "user", "content": "Jason Liu is 12 years old"},
|
||||
],
|
||||
)
|
||||
async for m in model:
|
||||
assert isinstance(m, UserExtract)
|
||||
```
|
||||
@@ -3,7 +3,7 @@ from .dsl import (
|
||||
CitationMixin,
|
||||
Maybe,
|
||||
Partial,
|
||||
MultiTask,
|
||||
IterableModel,
|
||||
llm_validator,
|
||||
openai_moderation,
|
||||
)
|
||||
@@ -13,7 +13,7 @@ from .patch import apatch, patch
|
||||
__all__ = [
|
||||
"OpenAISchema",
|
||||
"CitationMixin",
|
||||
"MultiTask",
|
||||
"IterableModel",
|
||||
"Maybe",
|
||||
"Partial",
|
||||
"openai_schema",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .multitask import MultiTask
|
||||
from .iterable import IterableModel
|
||||
from .maybe import Maybe
|
||||
from .partial import Partial
|
||||
from .validators import llm_validator, openai_moderation
|
||||
@@ -6,7 +6,7 @@ from .citation import CitationMixin
|
||||
|
||||
__all__ = [ # noqa: F405
|
||||
"CitationMixin",
|
||||
"MultiTask",
|
||||
"IterableModel",
|
||||
"Maybe",
|
||||
"Partial",
|
||||
"llm_validator",
|
||||
|
||||
@@ -5,21 +5,21 @@ from pydantic import BaseModel, Field, create_model
|
||||
from instructor.function_calls import OpenAISchema, Mode
|
||||
|
||||
|
||||
class MultiTaskBase:
|
||||
class IterableBase:
|
||||
task_type = None # type: ignore
|
||||
|
||||
@classmethod
|
||||
def from_streaming_response(cls, completion, mode: Mode):
|
||||
def from_streaming_response(cls, completion, mode: Mode, **kwargs):
|
||||
json_chunks = cls.extract_json(completion, mode)
|
||||
yield from cls.tasks_from_chunks(json_chunks)
|
||||
|
||||
@classmethod
|
||||
async def from_streaming_response_async(cls, completion, mode: Mode):
|
||||
async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs):
|
||||
json_chunks = cls.extract_json_async(completion, mode)
|
||||
return cls.tasks_from_chunks_async(json_chunks)
|
||||
return cls.tasks_from_chunks_async(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def tasks_from_chunks(cls, json_chunks):
|
||||
def tasks_from_chunks(cls, json_chunks, **kwargs):
|
||||
started = False
|
||||
potential_object = ""
|
||||
for chunk in json_chunks:
|
||||
@@ -32,11 +32,11 @@ class MultiTaskBase:
|
||||
|
||||
task_json, potential_object = cls.get_object(potential_object, 0)
|
||||
if task_json:
|
||||
obj = cls.task_type.model_validate_json(task_json) # type: ignore
|
||||
obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore
|
||||
yield obj
|
||||
|
||||
@classmethod
|
||||
async def tasks_from_chunks_async(cls, json_chunks):
|
||||
async def tasks_from_chunks_async(cls, json_chunks, **kwargs):
|
||||
started = False
|
||||
potential_object = ""
|
||||
async for chunk in json_chunks:
|
||||
@@ -49,7 +49,7 @@ class MultiTaskBase:
|
||||
|
||||
task_json, potential_object = cls.get_object(potential_object, 0)
|
||||
if task_json:
|
||||
obj = cls.task_type.model_validate_json(task_json) # type: ignore
|
||||
obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore
|
||||
yield obj
|
||||
|
||||
@staticmethod
|
||||
@@ -106,13 +106,13 @@ class MultiTaskBase:
|
||||
return None, str
|
||||
|
||||
|
||||
def MultiTask(
|
||||
def IterableModel(
|
||||
subtask_class: Type[BaseModel],
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Dynamically create a MultiTask OpenAISchema that can be used to segment multiple
|
||||
Dynamically create an IterableModel OpenAISchema that can be used to segment multiple
|
||||
tasks given a base class. This creates a class that can be used to create a toolkit
|
||||
for a specific task, names and descriptions are automatically generated. However
|
||||
they can be overridden.
|
||||
@@ -121,14 +121,14 @@ def MultiTask(
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
from instructor import MultiTask
|
||||
from instructor import IterableModel
|
||||
|
||||
class User(BaseModel):
|
||||
name: str = Field(description="The name of the person")
|
||||
age: int = Field(description="The age of the person")
|
||||
role: str = Field(description="The role of the person")
|
||||
|
||||
MultiUser = MultiTask(User)
|
||||
MultiUser = IterableModel(User)
|
||||
```
|
||||
|
||||
## Result
|
||||
@@ -163,7 +163,7 @@ def MultiTask(
|
||||
"""
|
||||
task_name = subtask_class.__name__ if name is None else name
|
||||
|
||||
name = f"Multi{task_name}"
|
||||
name = f"Iterable{task_name}"
|
||||
|
||||
list_tasks = (
|
||||
List[subtask_class],
|
||||
@@ -177,7 +177,7 @@ def MultiTask(
|
||||
new_cls = create_model(
|
||||
name,
|
||||
tasks=list_tasks,
|
||||
__base__=(OpenAISchema, MultiTaskBase), # type: ignore
|
||||
__base__=(OpenAISchema, IterableBase), # type: ignore
|
||||
)
|
||||
# set the class constructor BaseModel
|
||||
new_cls.task_type = subtask_class
|
||||
@@ -21,17 +21,17 @@ Model = TypeVar("Model", bound=BaseModel)
|
||||
|
||||
class PartialBase:
|
||||
@classmethod
|
||||
def from_streaming_response(cls, completion, mode: Mode):
|
||||
def from_streaming_response(cls, completion, mode: Mode, **kwargs):
|
||||
json_chunks = cls.extract_json(completion, mode)
|
||||
yield from cls.model_from_chunks(json_chunks)
|
||||
yield from cls.model_from_chunks(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def from_streaming_response_async(cls, completion, mode: Mode):
|
||||
async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs):
|
||||
json_chunks = cls.extract_json_async(completion, mode)
|
||||
return cls.model_from_chunks_async(json_chunks)
|
||||
return cls.model_from_chunks_async(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def model_from_chunks(cls, json_chunks):
|
||||
def model_from_chunks(cls, json_chunks, **kwargs):
|
||||
prev_obj = None
|
||||
potential_object = ""
|
||||
for chunk in json_chunks:
|
||||
@@ -42,7 +42,7 @@ class PartialBase:
|
||||
parser.parse(potential_object) if potential_object.strip() else None
|
||||
)
|
||||
if task_json:
|
||||
obj = cls.model_validate(task_json, strict=None) # type: ignore
|
||||
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore
|
||||
if obj != prev_obj:
|
||||
obj.__dict__[
|
||||
"chunk"
|
||||
@@ -51,8 +51,9 @@ class PartialBase:
|
||||
yield obj
|
||||
|
||||
@classmethod
|
||||
async def model_from_chunks_async(cls, json_chunks):
|
||||
async def model_from_chunks_async(cls, json_chunks, **kwargs):
|
||||
potential_object = ""
|
||||
prev_obj = None
|
||||
async for chunk in json_chunks:
|
||||
potential_object += chunk
|
||||
|
||||
@@ -61,7 +62,7 @@ class PartialBase:
|
||||
parser.parse(potential_object) if potential_object.strip() else None
|
||||
)
|
||||
if task_json:
|
||||
obj = cls.model_validate(task_json, strict=None) # type: ignore
|
||||
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore
|
||||
if obj != prev_obj:
|
||||
obj.__dict__[
|
||||
"chunk"
|
||||
|
||||
@@ -122,8 +122,6 @@ class OpenAISchema(BaseModel):
|
||||
validation_context=None,
|
||||
strict: bool = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
stream_multitask: bool = False,
|
||||
stream_partial: bool = False,
|
||||
):
|
||||
"""Execute the function from the response of an openai chat completion
|
||||
|
||||
@@ -138,12 +136,6 @@ class OpenAISchema(BaseModel):
|
||||
Returns:
|
||||
cls (OpenAISchema): An instance of the class
|
||||
"""
|
||||
if stream_multitask:
|
||||
return cls.from_streaming_response(completion, mode)
|
||||
|
||||
if stream_partial:
|
||||
return cls.from_streaming_response(completion, mode)
|
||||
|
||||
if completion.choices[0].finish_reason == "length":
|
||||
raise IncompleteOutputException()
|
||||
|
||||
@@ -187,8 +179,6 @@ class OpenAISchema(BaseModel):
|
||||
validation_context=None,
|
||||
strict: bool = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
stream_multitask: bool = False,
|
||||
stream_partial: bool = False,
|
||||
):
|
||||
"""Execute the function from the response of an openai chat completion
|
||||
|
||||
@@ -203,12 +193,6 @@ class OpenAISchema(BaseModel):
|
||||
cls (OpenAISchema): An instance of the class
|
||||
"""
|
||||
|
||||
if stream_multitask:
|
||||
return await cls.from_streaming_response_async(completion, mode)
|
||||
|
||||
if stream_partial:
|
||||
return cls.from_streaming_response_async(completion, mode)
|
||||
|
||||
if completion.choices[0].finish_reason == "length":
|
||||
raise IncompleteOutputException()
|
||||
|
||||
|
||||
+38
-29
@@ -15,7 +15,7 @@ from openai.types.chat import (
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from instructor.dsl.multitask import MultiTask, MultiTaskBase
|
||||
from instructor.dsl.iterable import IterableModel, IterableBase
|
||||
from instructor.dsl.partial import PartialBase
|
||||
|
||||
from .function_calls import Mode, OpenAISchema, openai_schema
|
||||
@@ -69,12 +69,12 @@ def handle_response_model(
|
||||
if response_model is not None:
|
||||
if get_origin(response_model) is Iterable:
|
||||
iterable_element_class = get_args(response_model)[0]
|
||||
response_model = MultiTask(iterable_element_class)
|
||||
response_model = IterableModel(iterable_element_class)
|
||||
if not issubclass(response_model, OpenAISchema):
|
||||
response_model = openai_schema(response_model) # type: ignore
|
||||
|
||||
if new_kwargs.get("stream", False) and not issubclass(
|
||||
response_model, (MultiTaskBase, PartialBase)
|
||||
response_model, (IterableBase, PartialBase)
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"stream=True is not supported when using response_model parameter for non-iterables"
|
||||
@@ -162,23 +162,29 @@ def process_response(
|
||||
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
|
||||
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
|
||||
"""
|
||||
if response_model is not None:
|
||||
is_model_multitask = issubclass(response_model, MultiTaskBase)
|
||||
is_model_partial = issubclass(response_model, PartialBase)
|
||||
model = response_model.from_response(
|
||||
if response_model is None:
|
||||
return response
|
||||
|
||||
if issubclass(response_model, (IterableBase, PartialBase)) and stream:
|
||||
model = response_model.from_streaming_response(
|
||||
response,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
stream_multitask=stream and is_model_multitask,
|
||||
stream_partial=stream and is_model_partial,
|
||||
)
|
||||
if not stream:
|
||||
model._raw_response = response
|
||||
if is_model_multitask:
|
||||
return model.tasks
|
||||
return model
|
||||
return response
|
||||
|
||||
model = response_model.from_response(
|
||||
response,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
model._raw_response = response
|
||||
|
||||
if issubclass(response_model, IterableBase):
|
||||
# If the response model is an iterable model, return its tasks as a list
|
||||
return [task for task in model.tasks]
|
||||
|
||||
return model
|
||||
|
||||
|
||||
async def process_response_async(
|
||||
@@ -201,23 +207,26 @@ async def process_response_async(
|
||||
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
|
||||
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
|
||||
"""
|
||||
if response_model is not None:
|
||||
is_model_multitask = issubclass(response_model, MultiTaskBase)
|
||||
is_model_partial = issubclass(response_model, PartialBase)
|
||||
model = await response_model.from_response_async(
|
||||
if response_model is None:
|
||||
return response
|
||||
|
||||
if issubclass(response_model, (IterableBase, PartialBase)) and stream:
|
||||
model = await response_model.from_streaming_response_async(
|
||||
response,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
stream_multitask=stream and is_model_multitask,
|
||||
stream_partial=stream and is_model_partial,
|
||||
)
|
||||
if not stream:
|
||||
model._raw_response = response
|
||||
if is_model_multitask:
|
||||
return model.tasks
|
||||
return model
|
||||
return response
|
||||
|
||||
model = await response_model.from_response_async(
|
||||
response,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
model._raw_response = response
|
||||
if issubclass(response_model, IterableBase):
|
||||
return model.tasks
|
||||
return model
|
||||
|
||||
|
||||
async def retry_async(
|
||||
|
||||
+2
-2
@@ -134,8 +134,8 @@ nav:
|
||||
- Fields: 'concepts/fields.md'
|
||||
- Missing: "concepts/maybe.md"
|
||||
- Patching: 'concepts/patching.md'
|
||||
- List (Streaming): "concepts/lists.md"
|
||||
- Partial (Streaming): "concepts/field_streaming.md"
|
||||
- Stream Iterable: "concepts/lists.md"
|
||||
- Stream Partial: "concepts/partial.md"
|
||||
- Raw Response: 'concepts/raw_response.md'
|
||||
- FastAPI: 'concepts/fastapi.md'
|
||||
- Caching: 'concepts/caching.md'
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
from itertools import product
|
||||
from typing import Iterable
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
import instructor
|
||||
from instructor.dsl.partial import Partial
|
||||
|
||||
from tests.openai.util import models, modes
|
||||
|
||||
|
||||
class UserExtract(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, mode, stream", product(models, modes, [True, False]))
|
||||
def test_iterable_model(model, mode, stream, client):
|
||||
client = instructor.patch(client, mode=mode)
|
||||
model = client.chat.completions.create(
|
||||
model=model,
|
||||
response_model=Iterable[UserExtract],
|
||||
max_retries=2,
|
||||
stream=stream,
|
||||
messages=[
|
||||
{"role": "user", "content": "Make two up people"},
|
||||
],
|
||||
)
|
||||
for m in model:
|
||||
assert isinstance(m, UserExtract)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, mode, stream", product(models, modes, [True, False]))
|
||||
@pytest.mark.asyncio
|
||||
async def test_iterable_model_async(model, mode, stream, aclient):
|
||||
aclient = instructor.patch(aclient, mode=mode)
|
||||
model = await aclient.chat.completions.create(
|
||||
model=model,
|
||||
response_model=Iterable[UserExtract],
|
||||
max_retries=2,
|
||||
stream=stream,
|
||||
messages=[
|
||||
{"role": "user", "content": "Make two up people"},
|
||||
],
|
||||
)
|
||||
if stream:
|
||||
async for m in model:
|
||||
assert isinstance(m, UserExtract)
|
||||
else:
|
||||
for m in model:
|
||||
assert isinstance(m, UserExtract)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model,mode", product(models, modes))
|
||||
def test_partial_model(model, mode, client):
|
||||
client = instructor.patch(client, mode=mode)
|
||||
model = client.chat.completions.create(
|
||||
model=model,
|
||||
response_model=Partial[UserExtract],
|
||||
max_retries=2,
|
||||
stream=True,
|
||||
messages=[
|
||||
{"role": "user", "content": "Jason Liu is 12 years old"},
|
||||
],
|
||||
)
|
||||
for m in model:
|
||||
assert isinstance(m, UserExtract)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model,mode", product(models, modes))
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_model_async(model, mode, aclient):
|
||||
aclient = instructor.patch(aclient, mode=mode)
|
||||
model = await aclient.chat.completions.create(
|
||||
model=model,
|
||||
response_model=Partial[UserExtract],
|
||||
max_retries=2,
|
||||
stream=True,
|
||||
messages=[
|
||||
{"role": "user", "content": "Jason Liu is 12 years old"},
|
||||
],
|
||||
)
|
||||
async for m in model:
|
||||
assert isinstance(m, UserExtract)
|
||||
@@ -1,8 +1,7 @@
|
||||
import instructor
|
||||
|
||||
models = ["gpt-3.5-turbo-1106", "gpt-4-1106-preview"]
|
||||
models = ["gpt-4-turbo-preview"]
|
||||
modes = [
|
||||
instructor.Mode.FUNCTIONS,
|
||||
instructor.Mode.JSON,
|
||||
instructor.Mode.TOOLS,
|
||||
]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from instructor import OpenAISchema
|
||||
from instructor.dsl import MultiTask
|
||||
from instructor.dsl import IterableModel
|
||||
|
||||
|
||||
def test_multi_task():
|
||||
@@ -9,9 +9,9 @@ def test_multi_task():
|
||||
id: int
|
||||
query: str
|
||||
|
||||
multitask = MultiTask(Search)
|
||||
assert multitask.openai_schema["name"] == "MultiSearch"
|
||||
IterableSearch = IterableModel(Search)
|
||||
assert IterableSearch.openai_schema["name"] == "IterableSearch"
|
||||
assert (
|
||||
multitask.openai_schema["description"]
|
||||
IterableSearch.openai_schema["description"]
|
||||
== "Correct segmentation of `Search` tasks"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user