Clean up streaming code (#377)

This commit is contained in:
Jason Liu
2024-01-31 18:31:33 -05:00
committed by GitHub
parent eeb2fb2f3d
commit ef76cfe97b
12 changed files with 190 additions and 80 deletions
+17 -1
View File
@@ -99,4 +99,20 @@ for user in users:
>>> name="John" "age"=10
```
This streaming is still a prototype, but should work quite well for simple schemas.
## Asynchronous Streaming
`instructor` also supports asynchronous streaming. This is useful when you want to stream a response model and process the results as they come in; note that you need to use the `async for` syntax to iterate over the results.
```python
model = await client.chat.completions.create(
model="gpt-4",
response_model=Iterable[UserExtract],
max_retries=2,
stream=stream,
messages=[
{"role": "user", "content": "Make two up people"},
],
)
async for m in model:
assert isinstance(m, UserExtract)
```
@@ -103,3 +103,21 @@ for extraction in extraction_stream:
This will output the following:
![Partial Streaming Gif](../img/partial.gif)
## Asynchronous Streaming
`instructor` also supports asynchronous streaming. This is useful when you want to stream a response model and process the results as they come in; note that you need to use the `async for` syntax to iterate over the results.
```python
model = await client.chat.completions.create(
model="gpt-4",
response_model=Partial[UserExtract],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)
async for m in model:
assert isinstance(m, UserExtract)
```
+2 -2
View File
@@ -3,7 +3,7 @@ from .dsl import (
CitationMixin,
Maybe,
Partial,
MultiTask,
IterableModel,
llm_validator,
openai_moderation,
)
@@ -13,7 +13,7 @@ from .patch import apatch, patch
__all__ = [
"OpenAISchema",
"CitationMixin",
"MultiTask",
"IterableModel",
"Maybe",
"Partial",
"openai_schema",
+2 -2
View File
@@ -1,4 +1,4 @@
from .multitask import MultiTask
from .iterable import IterableModel
from .maybe import Maybe
from .partial import Partial
from .validators import llm_validator, openai_moderation
@@ -6,7 +6,7 @@ from .citation import CitationMixin
__all__ = [ # noqa: F405
"CitationMixin",
"MultiTask",
"IterableModel",
"Maybe",
"Partial",
"llm_validator",
@@ -5,21 +5,21 @@ from pydantic import BaseModel, Field, create_model
from instructor.function_calls import OpenAISchema, Mode
class MultiTaskBase:
class IterableBase:
task_type = None # type: ignore
@classmethod
def from_streaming_response(cls, completion, mode: Mode):
def from_streaming_response(cls, completion, mode: Mode, **kwargs):
json_chunks = cls.extract_json(completion, mode)
yield from cls.tasks_from_chunks(json_chunks)
@classmethod
async def from_streaming_response_async(cls, completion, mode: Mode):
async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs):
json_chunks = cls.extract_json_async(completion, mode)
return cls.tasks_from_chunks_async(json_chunks)
return cls.tasks_from_chunks_async(json_chunks, **kwargs)
@classmethod
def tasks_from_chunks(cls, json_chunks):
def tasks_from_chunks(cls, json_chunks, **kwargs):
started = False
potential_object = ""
for chunk in json_chunks:
@@ -32,11 +32,11 @@ class MultiTaskBase:
task_json, potential_object = cls.get_object(potential_object, 0)
if task_json:
obj = cls.task_type.model_validate_json(task_json) # type: ignore
obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore
yield obj
@classmethod
async def tasks_from_chunks_async(cls, json_chunks):
async def tasks_from_chunks_async(cls, json_chunks, **kwargs):
started = False
potential_object = ""
async for chunk in json_chunks:
@@ -49,7 +49,7 @@ class MultiTaskBase:
task_json, potential_object = cls.get_object(potential_object, 0)
if task_json:
obj = cls.task_type.model_validate_json(task_json) # type: ignore
obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore
yield obj
@staticmethod
@@ -106,13 +106,13 @@ class MultiTaskBase:
return None, str
def MultiTask(
def IterableModel(
subtask_class: Type[BaseModel],
name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Dynamically create a MultiTask OpenAISchema that can be used to segment multiple
Dynamically create a IterableModel OpenAISchema that can be used to segment multiple
tasks given a base class. This creates class that can be used to create a toolkit
for a specific task, names and descriptions are automatically generated. However
they can be overridden.
@@ -121,14 +121,14 @@ def MultiTask(
```python
from pydantic import BaseModel, Field
from instructor import MultiTask
from instructor import IterableModel
class User(BaseModel):
name: str = Field(description="The name of the person")
age: int = Field(description="The age of the person")
role: str = Field(description="The role of the person")
MultiUser = MultiTask(User)
MultiUser = IterableModel(User)
```
## Result
@@ -163,7 +163,7 @@ def MultiTask(
"""
task_name = subtask_class.__name__ if name is None else name
name = f"Multi{task_name}"
name = f"Iterable{task_name}"
list_tasks = (
List[subtask_class],
@@ -177,7 +177,7 @@ def MultiTask(
new_cls = create_model(
name,
tasks=list_tasks,
__base__=(OpenAISchema, MultiTaskBase), # type: ignore
__base__=(OpenAISchema, IterableBase), # type: ignore
)
# set the class constructor BaseModel
new_cls.task_type = subtask_class
+9 -8
View File
@@ -21,17 +21,17 @@ Model = TypeVar("Model", bound=BaseModel)
class PartialBase:
@classmethod
def from_streaming_response(cls, completion, mode: Mode):
def from_streaming_response(cls, completion, mode: Mode, **kwargs):
json_chunks = cls.extract_json(completion, mode)
yield from cls.model_from_chunks(json_chunks)
yield from cls.model_from_chunks(json_chunks, **kwargs)
@classmethod
async def from_streaming_response_async(cls, completion, mode: Mode):
async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs):
json_chunks = cls.extract_json_async(completion, mode)
return cls.model_from_chunks_async(json_chunks)
return cls.model_from_chunks_async(json_chunks, **kwargs)
@classmethod
def model_from_chunks(cls, json_chunks):
def model_from_chunks(cls, json_chunks, **kwargs):
prev_obj = None
potential_object = ""
for chunk in json_chunks:
@@ -42,7 +42,7 @@ class PartialBase:
parser.parse(potential_object) if potential_object.strip() else None
)
if task_json:
obj = cls.model_validate(task_json, strict=None) # type: ignore
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore
if obj != prev_obj:
obj.__dict__[
"chunk"
@@ -51,8 +51,9 @@ class PartialBase:
yield obj
@classmethod
async def model_from_chunks_async(cls, json_chunks):
async def model_from_chunks_async(cls, json_chunks, **kwargs):
potential_object = ""
prev_obj = None
async for chunk in json_chunks:
potential_object += chunk
@@ -61,7 +62,7 @@ class PartialBase:
parser.parse(potential_object) if potential_object.strip() else None
)
if task_json:
obj = cls.model_validate(task_json, strict=None) # type: ignore
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore
if obj != prev_obj:
obj.__dict__[
"chunk"
-16
View File
@@ -122,8 +122,6 @@ class OpenAISchema(BaseModel):
validation_context=None,
strict: bool = None,
mode: Mode = Mode.TOOLS,
stream_multitask: bool = False,
stream_partial: bool = False,
):
"""Execute the function from the response of an openai chat completion
@@ -138,12 +136,6 @@ class OpenAISchema(BaseModel):
Returns:
cls (OpenAISchema): An instance of the class
"""
if stream_multitask:
return cls.from_streaming_response(completion, mode)
if stream_partial:
return cls.from_streaming_response(completion, mode)
if completion.choices[0].finish_reason == "length":
raise IncompleteOutputException()
@@ -187,8 +179,6 @@ class OpenAISchema(BaseModel):
validation_context=None,
strict: bool = None,
mode: Mode = Mode.TOOLS,
stream_multitask: bool = False,
stream_partial: bool = False,
):
"""Execute the function from the response of an openai chat completion
@@ -203,12 +193,6 @@ class OpenAISchema(BaseModel):
cls (OpenAISchema): An instance of the class
"""
if stream_multitask:
return await cls.from_streaming_response_async(completion, mode)
if stream_partial:
return cls.from_streaming_response_async(completion, mode)
if completion.choices[0].finish_reason == "length":
raise IncompleteOutputException()
+38 -29
View File
@@ -15,7 +15,7 @@ from openai.types.chat import (
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel, ValidationError
from instructor.dsl.multitask import MultiTask, MultiTaskBase
from instructor.dsl.iterable import IterableModel, IterableBase
from instructor.dsl.partial import PartialBase
from .function_calls import Mode, OpenAISchema, openai_schema
@@ -69,12 +69,12 @@ def handle_response_model(
if response_model is not None:
if get_origin(response_model) is Iterable:
iterable_element_class = get_args(response_model)[0]
response_model = MultiTask(iterable_element_class)
response_model = IterableModel(iterable_element_class)
if not issubclass(response_model, OpenAISchema):
response_model = openai_schema(response_model) # type: ignore
if new_kwargs.get("stream", False) and not issubclass(
response_model, (MultiTaskBase, PartialBase)
response_model, (IterableBase, PartialBase)
):
raise NotImplementedError(
"stream=True is not supported when using response_model parameter for non-iterables"
@@ -162,23 +162,29 @@ def process_response(
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
"""
if response_model is not None:
is_model_multitask = issubclass(response_model, MultiTaskBase)
is_model_partial = issubclass(response_model, PartialBase)
model = response_model.from_response(
if response_model is None:
return response
if issubclass(response_model, (IterableBase, PartialBase)) and stream:
model = response_model.from_streaming_response(
response,
validation_context=validation_context,
strict=strict,
mode=mode,
stream_multitask=stream and is_model_multitask,
stream_partial=stream and is_model_partial,
)
if not stream:
model._raw_response = response
if is_model_multitask:
return model.tasks
return model
return response
model = response_model.from_response(
response,
validation_context=validation_context,
strict=strict,
mode=mode,
)
model._raw_response = response
if issubclass(response_model, IterableBase):
# If the response model is a multitask, return the tasks
return [task for task in model.tasks]
return model
async def process_response_async(
@@ -201,23 +207,26 @@ async def process_response_async(
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
"""
if response_model is not None:
is_model_multitask = issubclass(response_model, MultiTaskBase)
is_model_partial = issubclass(response_model, PartialBase)
model = await response_model.from_response_async(
if response_model is None:
return response
if issubclass(response_model, (IterableBase, PartialBase)) and stream:
model = await response_model.from_streaming_response_async(
response,
validation_context=validation_context,
strict=strict,
mode=mode,
stream_multitask=stream and is_model_multitask,
stream_partial=stream and is_model_partial,
)
if not stream:
model._raw_response = response
if is_model_multitask:
return model.tasks
return model
return response
model = await response_model.from_response_async(
response,
validation_context=validation_context,
strict=strict,
mode=mode,
)
model._raw_response = response
if issubclass(response_model, IterableBase):
return model.tasks
return model
async def retry_async(
+2 -2
View File
@@ -134,8 +134,8 @@ nav:
- Fields: 'concepts/fields.md'
- Missing: "concepts/maybe.md"
- Patching: 'concepts/patching.md'
- List (Streaming): "concepts/lists.md"
- Partial (Streaming): "concepts/field_streaming.md"
- Stream Iterable: "concepts/lists.md"
- Stream Partial: "concepts/partial.md"
- Raw Response: 'concepts/raw_response.md'
- FastAPI: 'concepts/fastapi.md'
- Caching: 'concepts/caching.md'
+83
View File
@@ -0,0 +1,83 @@
from itertools import product
from typing import Iterable
from pydantic import BaseModel
import pytest
import instructor
from instructor.dsl.partial import Partial
from tests.openai.util import models, modes
class UserExtract(BaseModel):
    """Minimal response model used as the extraction target in the streaming tests."""

    name: str
    age: int
@pytest.mark.parametrize("model, mode, stream", product(models, modes, [True, False]))
def test_iterable_model(model, mode, stream, client):
    """Check that `Iterable[UserExtract]` yields `UserExtract` items, streamed or not."""
    patched_client = instructor.patch(client, mode=mode)
    request_messages = [
        {"role": "user", "content": "Make two up people"},
    ]
    users = patched_client.chat.completions.create(
        model=model,
        response_model=Iterable[UserExtract],
        max_retries=2,
        stream=stream,
        messages=request_messages,
    )
    # The result is iterated with a plain `for` in both the streaming and
    # non-streaming cases.
    for user in users:
        assert isinstance(user, UserExtract)
@pytest.mark.parametrize("model, mode, stream", product(models, modes, [True, False]))
@pytest.mark.asyncio
async def test_iterable_model_async(model, mode, stream, aclient):
    """Async variant: check that `Iterable[UserExtract]` yields `UserExtract` items."""
    patched_client = instructor.patch(aclient, mode=mode)
    request_messages = [
        {"role": "user", "content": "Make two up people"},
    ]
    users = await patched_client.chat.completions.create(
        model=model,
        response_model=Iterable[UserExtract],
        max_retries=2,
        stream=stream,
        messages=request_messages,
    )

    if not stream:
        # Without streaming the awaited result is iterated synchronously.
        for user in users:
            assert isinstance(user, UserExtract)
        return

    # With streaming the awaited result must be consumed with `async for`.
    async for user in users:
        assert isinstance(user, UserExtract)
@pytest.mark.parametrize("model,mode", product(models, modes))
def test_partial_model(model, mode, client):
    """Check that streaming `Partial[UserExtract]` yields `UserExtract` instances."""
    patched_client = instructor.patch(client, mode=mode)
    request_messages = [
        {"role": "user", "content": "Jason Liu is 12 years old"},
    ]
    partial_stream = patched_client.chat.completions.create(
        model=model,
        response_model=Partial[UserExtract],
        max_retries=2,
        stream=True,
        messages=request_messages,
    )
    for partial_user in partial_stream:
        assert isinstance(partial_user, UserExtract)
@pytest.mark.parametrize("model,mode", product(models, modes))
@pytest.mark.asyncio
async def test_partial_model_async(model, mode, aclient):
    """Async variant: check that streaming `Partial[UserExtract]` yields `UserExtract` instances."""
    patched_client = instructor.patch(aclient, mode=mode)
    request_messages = [
        {"role": "user", "content": "Jason Liu is 12 years old"},
    ]
    partial_stream = await patched_client.chat.completions.create(
        model=model,
        response_model=Partial[UserExtract],
        max_retries=2,
        stream=True,
        messages=request_messages,
    )
    async for partial_user in partial_stream:
        assert isinstance(partial_user, UserExtract)
+1 -2
View File
@@ -1,8 +1,7 @@
import instructor
models = ["gpt-3.5-turbo-1106", "gpt-4-1106-preview"]
models = ["gpt-4-turbo-preview"]
modes = [
instructor.Mode.FUNCTIONS,
instructor.Mode.JSON,
instructor.Mode.TOOLS,
]
+4 -4
View File
@@ -1,5 +1,5 @@
from instructor import OpenAISchema
from instructor.dsl import MultiTask
from instructor.dsl import IterableModel
def test_multi_task():
@@ -9,9 +9,9 @@ def test_multi_task():
id: int
query: str
multitask = MultiTask(Search)
assert multitask.openai_schema["name"] == "MultiSearch"
IterableSearch = IterableModel(Search)
assert IterableSearch.openai_schema["name"] == "IterableSearch"
assert (
multitask.openai_schema["description"]
IterableSearch.openai_schema["description"]
== "Correct segmentation of `Search` tasks"
)