fix: Adding tests to make sure non-stream iterables work (#413)

This commit is contained in:
Jason Liu
2024-02-07 09:46:56 -05:00
committed by GitHub
parent 482143f4c3
commit f9389c18d3
3 changed files with 70 additions and 5 deletions
+2 -3
View File
@@ -213,8 +213,7 @@ def process_response(
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(response_model, IterableBase):
#! If the response model is a multitask, return the tasks
if isinstance(model, IterableBase):
return [task for task in model.tasks]
if isinstance(response_model, ParallelBase):
@@ -267,7 +266,7 @@ async def process_response_async(
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(response_model, IterableBase):
if isinstance(model, IterableBase):
#! If the response model is a multitask, return the tasks
return [task for task in model.tasks]
+1 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "instructor"
version = "0.5.0"
version = "0.5.1"
description = "structured outputs for llm"
authors = ["Jason Liu <jason@jxnl.co>"]
license = "MIT"
+67 -1
View File
@@ -22,7 +22,6 @@ def test_multi_user(model, mode, client):
def stream_extract(input: str) -> Iterable[User]:
return client.chat.completions.create(
model=model,
stream=True,
response_model=Users,
messages=[
{
@@ -54,6 +53,73 @@ def test_multi_user(model, mode, client):
async def test_multi_user_tools_mode_async(model, mode, aclient):
client = instructor.patch(aclient, mode=mode)
async def stream_extract(input: str) -> Iterable[User]:
return await client.chat.completions.create(
model=model,
response_model=Users,
messages=[
{
"role": "user",
"content": (
f"Consider the data below:\n{input}"
"Correctly segment it into entitites"
"Make sure the JSON is correct"
),
},
],
max_tokens=1000,
)
resp = []
async for user in await stream_extract(input="Jason is 20, Sarah is 30"):
resp.append(user)
print(resp)
assert len(resp) == 2
assert resp[0].name == "Jason"
assert resp[0].age == 20
assert resp[1].name == "Sarah"
assert resp[1].age == 30
@pytest.mark.parametrize("model, mode", product(models, modes))
def test_multi_user_stream(model, mode, client):
client = instructor.patch(client, mode=mode)
def stream_extract(input: str) -> Iterable[User]:
return client.chat.completions.create(
model=model,
stream=True,
response_model=Users,
messages=[
{
"role": "system",
"content": "You are a perfect entity extraction system",
},
{
"role": "user",
"content": (
f"Consider the data below:\n{input}"
"Correctly segment it into entitites"
"Make sure the JSON is correct"
),
},
],
max_tokens=1000,
)
resp = [user for user in stream_extract(input="Jason is 20, Sarah is 30")]
assert len(resp) == 2
assert resp[0].name == "Jason"
assert resp[0].age == 20
assert resp[1].name == "Sarah"
assert resp[1].age == 30
@pytest.mark.asyncio
@pytest.mark.parametrize("model, mode", product(models, modes))
async def test_multi_user_tools_mode_async_stream(model, mode, aclient):
client = instructor.patch(aclient, mode=mode)
async def stream_extract(input: str) -> Iterable[User]:
return await client.chat.completions.create(
model=model,