diff --git a/docs/concepts/lists.md b/docs/concepts/lists.md index 9ab2aeb..c8d076f 100644 --- a/docs/concepts/lists.md +++ b/docs/concepts/lists.md @@ -99,4 +99,20 @@ for user in users: >>> name="John" "age"=10 ``` -This streaming is still a prototype, but should work quite well for simple schemas. +## Asynchronous Streaming + +I also just want to call out in this example that `instructor` also supports asynchronous streaming. This is useful when you want to stream a response model and process the results as they come in, but you'll need to use the `async for` syntax to iterate over the results. + +```python +model = await client.chat.completions.create( + model="gpt-4", + response_model=Iterable[UserExtract], + max_retries=2, + stream=stream, + messages=[ + {"role": "user", "content": "Make two up people"}, + ], +) +async for m in model: + assert isinstance(m, UserExtract) +``` diff --git a/docs/concepts/field_streaming.md b/docs/concepts/partial.md similarity index 85% rename from docs/concepts/field_streaming.md rename to docs/concepts/partial.md index 0d6f1bf..f0ad348 100644 --- a/docs/concepts/field_streaming.md +++ b/docs/concepts/partial.md @@ -103,3 +103,21 @@ for extraction in extraction_stream: This will output the following: ![Partial Streaming Gif](../img/partial.gif) + +## Asynchronous Streaming + +I also just want to call out in this example that `instructor` also supports asynchronous streaming. This is useful when you want to stream a response model and process the results as they come in, but you'll need to use the `async for` syntax to iterate over the results. + +```python +model = await client.chat.completions.create( + model="gpt-4", + response_model=Partial[UserExtract], + max_retries=2, + stream=True, + messages=[ + {"role": "user", "content": "Jason Liu is 12 years old"}, + ], +) +async for m in model: + assert isinstance(m, UserExtract) +``` diff --git a/instructor/__init__.py b/instructor/__init__.py index 08b4111..2f2b3fc 100644 --- a/instructor/__init__.py +++ b/instructor/__init__.py @@ -3,7 +3,7 @@ from .dsl import ( CitationMixin, Maybe, Partial, - MultiTask, + IterableModel, llm_validator, openai_moderation, ) @@ -13,7 +13,7 @@ from .patch import apatch, patch __all__ = [ "OpenAISchema", "CitationMixin", - "MultiTask", + "IterableModel", "Maybe", "Partial", "openai_schema", diff --git a/instructor/dsl/__init__.py b/instructor/dsl/__init__.py index da5b3d1..83ae1de 100644 --- a/instructor/dsl/__init__.py +++ b/instructor/dsl/__init__.py @@ -1,4 +1,4 @@ -from .multitask import MultiTask +from .iterable import IterableModel from .maybe import Maybe from .partial import Partial from .validators import llm_validator, openai_moderation @@ -6,7 +6,7 @@ from .citation import CitationMixin __all__ = [ # noqa: F405 "CitationMixin", - "MultiTask", + "IterableModel", "Maybe", "Partial", "llm_validator", diff --git a/instructor/dsl/multitask.py b/instructor/dsl/iterable.py similarity index 90% rename from instructor/dsl/multitask.py rename to instructor/dsl/iterable.py index e03da8c..295fb67 100644 --- a/instructor/dsl/multitask.py +++ b/instructor/dsl/iterable.py @@ -5,21 +5,21 @@ from pydantic import BaseModel, Field, create_model from instructor.function_calls import OpenAISchema, Mode -class MultiTaskBase: +class IterableBase: task_type = None # type: ignore @classmethod - def from_streaming_response(cls, completion, mode: Mode): + def from_streaming_response(cls, completion, mode: Mode, **kwargs): json_chunks = cls.extract_json(completion, mode) yield from cls.tasks_from_chunks(json_chunks) @classmethod - async def from_streaming_response_async(cls, completion, mode: Mode): + async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs): json_chunks = cls.extract_json_async(completion, mode) - return cls.tasks_from_chunks_async(json_chunks) + return cls.tasks_from_chunks_async(json_chunks, **kwargs) @classmethod - def tasks_from_chunks(cls, json_chunks): + def tasks_from_chunks(cls, json_chunks, **kwargs): started = False potential_object = "" for chunk in json_chunks: @@ -32,11 +32,11 @@ class MultiTaskBase: task_json, potential_object = cls.get_object(potential_object, 0) if task_json: - obj = cls.task_type.model_validate_json(task_json) # type: ignore + obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore yield obj @classmethod - async def tasks_from_chunks_async(cls, json_chunks): + async def tasks_from_chunks_async(cls, json_chunks, **kwargs): started = False potential_object = "" async for chunk in json_chunks: @@ -49,7 +49,7 @@ class MultiTaskBase: task_json, potential_object = cls.get_object(potential_object, 0) if task_json: - obj = cls.task_type.model_validate_json(task_json) # type: ignore + obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore yield obj @staticmethod @@ -106,13 +106,13 @@ class MultiTaskBase: return None, str -def MultiTask( +def IterableModel( subtask_class: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, ): """ - Dynamically create a MultiTask OpenAISchema that can be used to segment multiple + Dynamically create a IterableModel OpenAISchema that can be used to segment multiple tasks given a base class. This creates class that can be used to create a toolkit for a specific task, names and descriptions are automatically generated. However they can be overridden. @@ -121,14 +121,14 @@ def MultiTask( ```python from pydantic import BaseModel, Field - from instructor import MultiTask + from instructor import IterableModel class User(BaseModel): name: str = Field(description="The name of the person") age: int = Field(description="The age of the person") role: str = Field(description="The role of the person") - MultiUser = MultiTask(User) + MultiUser = IterableModel(User) ``` ## Result @@ -163,7 +163,7 @@ def MultiTask( """ task_name = subtask_class.__name__ if name is None else name - name = f"Multi{task_name}" + name = f"Iterable{task_name}" list_tasks = ( List[subtask_class], @@ -177,7 +177,7 @@ def MultiTask( new_cls = create_model( name, tasks=list_tasks, - __base__=(OpenAISchema, MultiTaskBase), # type: ignore + __base__=(OpenAISchema, IterableBase), # type: ignore ) # set the class constructor BaseModel new_cls.task_type = subtask_class diff --git a/instructor/dsl/partial.py b/instructor/dsl/partial.py index 6f880cf..f439312 100644 --- a/instructor/dsl/partial.py +++ b/instructor/dsl/partial.py @@ -21,17 +21,17 @@ Model = TypeVar("Model", bound=BaseModel) class PartialBase: @classmethod - def from_streaming_response(cls, completion, mode: Mode): + def from_streaming_response(cls, completion, mode: Mode, **kwargs): json_chunks = cls.extract_json(completion, mode) - yield from cls.model_from_chunks(json_chunks) + yield from cls.model_from_chunks(json_chunks, **kwargs) @classmethod - async def from_streaming_response_async(cls, completion, mode: Mode): + async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs): json_chunks = cls.extract_json_async(completion, mode) - return cls.model_from_chunks_async(json_chunks) + return cls.model_from_chunks_async(json_chunks, **kwargs) @classmethod - def model_from_chunks(cls, json_chunks): + def model_from_chunks(cls, json_chunks, **kwargs): prev_obj = None potential_object = "" for chunk in json_chunks: @@ -42,7 +42,7 @@ class PartialBase: parser.parse(potential_object) if potential_object.strip() else None ) if task_json: - obj = cls.model_validate(task_json, strict=None) # type: ignore + obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore if obj != prev_obj: obj.__dict__[ "chunk" @@ -51,8 +51,9 @@ class PartialBase: yield obj @classmethod - async def model_from_chunks_async(cls, json_chunks): + async def model_from_chunks_async(cls, json_chunks, **kwargs): potential_object = "" + prev_obj = None async for chunk in json_chunks: potential_object += chunk @@ -61,7 +62,7 @@ class PartialBase: parser.parse(potential_object) if potential_object.strip() else None ) if task_json: - obj = cls.model_validate(task_json, strict=None) # type: ignore + obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore if obj != prev_obj: obj.__dict__[ "chunk" diff --git a/instructor/function_calls.py b/instructor/function_calls.py index 18c9f70..42221a7 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -122,8 +122,6 @@ class OpenAISchema(BaseModel): validation_context=None, strict: bool = None, mode: Mode = Mode.TOOLS, - stream_multitask: bool = False, - stream_partial: bool = False, ): """Execute the function from the response of an openai chat completion @@ -138,12 +136,6 @@ class OpenAISchema(BaseModel): Returns: cls (OpenAISchema): An instance of the class """ - if stream_multitask: - return cls.from_streaming_response(completion, mode) - - if stream_partial: - return cls.from_streaming_response(completion, mode) - if completion.choices[0].finish_reason == "length": raise IncompleteOutputException() @@ -187,8 +179,6 @@ class OpenAISchema(BaseModel): validation_context=None, strict: bool = None, mode: Mode = Mode.TOOLS, - stream_multitask: bool = False, - stream_partial: bool = False, ): """Execute the function from the response of an openai chat completion @@ -203,12 +193,6 @@ class OpenAISchema(BaseModel): cls (OpenAISchema): An instance of the class """ - if stream_multitask: - return await cls.from_streaming_response_async(completion, mode) - - if stream_partial: - return cls.from_streaming_response_async(completion, mode) - if completion.choices[0].finish_reason == "length": raise IncompleteOutputException() diff --git a/instructor/patch.py b/instructor/patch.py index d936091..dfa3335 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -15,7 +15,7 @@ from openai.types.chat import ( from openai.types.completion_usage import CompletionUsage from pydantic import BaseModel, ValidationError -from instructor.dsl.multitask import MultiTask, MultiTaskBase +from instructor.dsl.iterable import IterableModel, IterableBase from instructor.dsl.partial import PartialBase from .function_calls import Mode, OpenAISchema, openai_schema @@ -69,12 +69,12 @@ def handle_response_model( if response_model is not None: if get_origin(response_model) is Iterable: iterable_element_class = get_args(response_model)[0] - response_model = MultiTask(iterable_element_class) + response_model = IterableModel(iterable_element_class) if not issubclass(response_model, OpenAISchema): response_model = openai_schema(response_model) # type: ignore if new_kwargs.get("stream", False) and not issubclass( - response_model, (MultiTaskBase, PartialBase) + response_model, (IterableBase, PartialBase) ): raise NotImplementedError( "stream=True is not supported when using response_model parameter for non-iterables" @@ -162,23 +162,29 @@ def process_response( validation_context (dict, optional): The validation context to use for validating the response. Defaults to None. strict (bool, optional): Whether to use strict json parsing. Defaults to None. """ - if response_model is not None: - is_model_multitask = issubclass(response_model, MultiTaskBase) - is_model_partial = issubclass(response_model, PartialBase) - model = response_model.from_response( + if response_model is None: + return response + + if issubclass(response_model, (IterableBase, PartialBase)) and stream: + model = response_model.from_streaming_response( response, - validation_context=validation_context, - strict=strict, mode=mode, - stream_multitask=stream and is_model_multitask, - stream_partial=stream and is_model_partial, ) - if not stream: - model._raw_response = response - if is_model_multitask: - return model.tasks return model - return response + + model = response_model.from_response( + response, + validation_context=validation_context, + strict=strict, + mode=mode, + ) + model._raw_response = response + + if issubclass(response_model, IterableBase): + # If the response model is a multitask, return the tasks + return [task for task in model.tasks] + + return model async def process_response_async( @@ -201,23 +207,26 @@ async def process_response_async( validation_context (dict, optional): The validation context to use for validating the response. Defaults to None. strict (bool, optional): Whether to use strict json parsing. Defaults to None. """ - if response_model is not None: - is_model_multitask = issubclass(response_model, MultiTaskBase) - is_model_partial = issubclass(response_model, PartialBase) - model = await response_model.from_response_async( + if response_model is None: + return response + + if issubclass(response_model, (IterableBase, PartialBase)) and stream: + model = await response_model.from_streaming_response_async( response, - validation_context=validation_context, - strict=strict, mode=mode, - stream_multitask=stream and is_model_multitask, - stream_partial=stream and is_model_partial, ) - if not stream: - model._raw_response = response - if is_model_multitask: - return model.tasks return model - return response + + model = await response_model.from_response_async( + response, + validation_context=validation_context, + strict=strict, + mode=mode, + ) + model._raw_response = response + if issubclass(response_model, IterableBase): + return model.tasks + return model async def retry_async( diff --git a/mkdocs.yml b/mkdocs.yml index 6404bbe..b8f3c7a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -134,8 +134,8 @@ nav: - Fields: 'concepts/fields.md' - Missing: "concepts/maybe.md" - Patching: 'concepts/patching.md' - - List (Streaming): "concepts/lists.md" - - Partial (Streaming): "concepts/field_streaming.md" + - Stream Iterable: "concepts/lists.md" + - Stream Partial: "concepts/partial.md" - Raw Response: 'concepts/raw_response.md' - FastAPI: 'concepts/fastapi.md' - Caching: 'concepts/caching.md' diff --git a/tests/openai/test_stream.py b/tests/openai/test_stream.py new file mode 100644 index 0000000..4b9b3b0 --- /dev/null +++ b/tests/openai/test_stream.py @@ -0,0 +1,83 @@ +from itertools import product +from typing import Iterable +from pydantic import BaseModel +import pytest +import instructor +from instructor.dsl.partial import Partial + +from tests.openai.util import models, modes + + +class UserExtract(BaseModel): + name: str + age: int + + +@pytest.mark.parametrize("model, mode, stream", product(models, modes, [True, False])) +def test_iterable_model(model, mode, stream, client): + client = instructor.patch(client, mode=mode) + model = client.chat.completions.create( + model=model, + response_model=Iterable[UserExtract], + max_retries=2, + stream=stream, + messages=[ + {"role": "user", "content": "Make two up people"}, + ], + ) + for m in model: + assert isinstance(m, UserExtract) + + +@pytest.mark.parametrize("model, mode, stream", product(models, modes, [True, False])) +@pytest.mark.asyncio +async def test_iterable_model_async(model, mode, stream, aclient): + aclient = instructor.patch(aclient, mode=mode) + model = await aclient.chat.completions.create( + model=model, + response_model=Iterable[UserExtract], + max_retries=2, + stream=stream, + messages=[ + {"role": "user", "content": "Make two up people"}, + ], + ) + if stream: + async for m in model: + assert isinstance(m, UserExtract) + else: + for m in model: + assert isinstance(m, UserExtract) + + +@pytest.mark.parametrize("model,mode", product(models, modes)) +def test_partial_model(model, mode, client): + client = instructor.patch(client, mode=mode) + model = client.chat.completions.create( + model=model, + response_model=Partial[UserExtract], + max_retries=2, + stream=True, + messages=[ + {"role": "user", "content": "Jason Liu is 12 years old"}, + ], + ) + for m in model: + assert isinstance(m, UserExtract) + + +@pytest.mark.parametrize("model,mode", product(models, modes)) +@pytest.mark.asyncio +async def test_partial_model_async(model, mode, aclient): + aclient = instructor.patch(aclient, mode=mode) + model = await aclient.chat.completions.create( + model=model, + response_model=Partial[UserExtract], + max_retries=2, + stream=True, + messages=[ + {"role": "user", "content": "Jason Liu is 12 years old"}, + ], + ) + async for m in model: + assert isinstance(m, UserExtract) diff --git a/tests/openai/util.py b/tests/openai/util.py index 1234286..b118e3e 100644 --- a/tests/openai/util.py +++ b/tests/openai/util.py @@ -1,8 +1,7 @@ import instructor -models = ["gpt-3.5-turbo-1106", "gpt-4-1106-preview"] +models = ["gpt-4-turbo-preview"] modes = [ - instructor.Mode.FUNCTIONS, instructor.Mode.JSON, instructor.Mode.TOOLS, ] diff --git a/tests/test_multitask.py b/tests/test_multitask.py index 04bd014..4efcecd 100644 --- a/tests/test_multitask.py +++ b/tests/test_multitask.py @@ -1,5 +1,5 @@ from instructor import OpenAISchema -from instructor.dsl import MultiTask +from instructor.dsl import IterableModel def test_multi_task(): @@ -9,9 +9,9 @@ def test_multi_task(): id: int query: str - multitask = MultiTask(Search) - assert multitask.openai_schema["name"] == "MultiSearch" + IterableSearch = IterableModel(Search) + assert IterableSearch.openai_schema["name"] == "IterableSearch" assert ( - multitask.openai_schema["description"] + IterableSearch.openai_schema["description"] == "Correct segmentation of `Search` tasks" )