diff --git a/instructor/patch.py b/instructor/patch.py index 878c35d..e000d38 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -213,8 +213,7 @@ def process_response( # ? This really hints at the fact that we need a better way of # ? attaching usage data and the raw response to the model we return. - if isinstance(response_model, IterableBase): - #! If the response model is a multitask, return the tasks + if isinstance(model, IterableBase): return [task for task in model.tasks] if isinstance(response_model, ParallelBase): @@ -267,7 +266,7 @@ async def process_response_async( # ? This really hints at the fact that we need a better way of # ? attaching usage data and the raw response to the model we return. - if isinstance(response_model, IterableBase): + if isinstance(model, IterableBase): #! If the response model is a multitask, return the tasks return [task for task in model.tasks] diff --git a/pyproject.toml b/pyproject.toml index 4bddb80..cc824fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "instructor" -version = "0.5.0" +version = "0.5.1" description = "structured outputs for llm" authors = ["Jason Liu "] license = "MIT" diff --git a/tests/openai/test_multitask.py b/tests/openai/test_multitask.py index 23b869f..1700a47 100644 --- a/tests/openai/test_multitask.py +++ b/tests/openai/test_multitask.py @@ -22,7 +22,6 @@ def test_multi_user(model, mode, client): def stream_extract(input: str) -> Iterable[User]: return client.chat.completions.create( model=model, - stream=True, response_model=Users, messages=[ { @@ -54,6 +53,73 @@ def test_multi_user(model, mode, client): async def test_multi_user_tools_mode_async(model, mode, aclient): client = instructor.patch(aclient, mode=mode) + async def stream_extract(input: str) -> Iterable[User]: + return await client.chat.completions.create( + model=model, + response_model=Users, + messages=[ + { + "role": "user", + "content": ( + f"Consider the data below:\n{input}" + "Correctly segment it into entitites" + "Make sure the JSON is correct" + ), + }, + ], + max_tokens=1000, + ) + + resp = [] + async for user in await stream_extract(input="Jason is 20, Sarah is 30"): + resp.append(user) + print(resp) + assert len(resp) == 2 + assert resp[0].name == "Jason" + assert resp[0].age == 20 + assert resp[1].name == "Sarah" + assert resp[1].age == 30 + + +@pytest.mark.parametrize("model, mode", product(models, modes)) +def test_multi_user_stream(model, mode, client): + client = instructor.patch(client, mode=mode) + + def stream_extract(input: str) -> Iterable[User]: + return client.chat.completions.create( + model=model, + stream=True, + response_model=Users, + messages=[ + { + "role": "system", + "content": "You are a perfect entity extraction system", + }, + { + "role": "user", + "content": ( + f"Consider the data below:\n{input}" + "Correctly segment it into entitites" + "Make sure the JSON is correct" + ), + }, + ], + max_tokens=1000, + ) + + resp = [user for user in stream_extract(input="Jason is 20, Sarah is 30")] + assert len(resp) == 2 + assert resp[0].name == "Jason" + assert resp[0].age == 20 + assert resp[1].name == "Sarah" + assert resp[1].age == 30 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model, mode", product(models, modes)) +async def test_multi_user_tools_mode_async_stream(model, mode, aclient): + client = instructor.patch(aclient, mode=mode) + async def stream_extract(input: str) -> Iterable[User]: return await client.chat.completions.create( model=model,