mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 14:50:16 +00:00
fix: Adding tests to make sure non-stream iterables work (#413)
This commit is contained in:
+2
-3
@@ -213,8 +213,7 @@ def process_response(
|
||||
|
||||
# ? This really hints at the fact that we need a better way of
|
||||
# ? attaching usage data and the raw response to the model we return.
|
||||
if isinstance(response_model, IterableBase):
|
||||
#! If the response model is a multitask, return the tasks
|
||||
if isinstance(model, IterableBase):
|
||||
return [task for task in model.tasks]
|
||||
|
||||
if isinstance(response_model, ParallelBase):
|
||||
@@ -267,7 +266,7 @@ async def process_response_async(
|
||||
|
||||
# ? This really hints at the fact that we need a better way of
|
||||
# ? attaching usage data and the raw response to the model we return.
|
||||
if isinstance(response_model, IterableBase):
|
||||
if isinstance(model, IterableBase):
|
||||
#! If the response model is a multitask, return the tasks
|
||||
return [task for task in model.tasks]
|
||||
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "instructor"
|
||||
version = "0.5.0"
|
||||
version = "0.5.1"
|
||||
description = "structured outputs for llm"
|
||||
authors = ["Jason Liu <jason@jxnl.co>"]
|
||||
license = "MIT"
|
||||
|
||||
@@ -22,7 +22,6 @@ def test_multi_user(model, mode, client):
|
||||
def stream_extract(input: str) -> Iterable[User]:
|
||||
return client.chat.completions.create(
|
||||
model=model,
|
||||
stream=True,
|
||||
response_model=Users,
|
||||
messages=[
|
||||
{
|
||||
@@ -54,6 +53,73 @@ def test_multi_user(model, mode, client):
|
||||
async def test_multi_user_tools_mode_async(model, mode, aclient):
|
||||
client = instructor.patch(aclient, mode=mode)
|
||||
|
||||
async def stream_extract(input: str) -> Iterable[User]:
|
||||
return await client.chat.completions.create(
|
||||
model=model,
|
||||
response_model=Users,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"Consider the data below:\n{input}"
|
||||
"Correctly segment it into entitites"
|
||||
"Make sure the JSON is correct"
|
||||
),
|
||||
},
|
||||
],
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
resp = []
|
||||
async for user in await stream_extract(input="Jason is 20, Sarah is 30"):
|
||||
resp.append(user)
|
||||
print(resp)
|
||||
assert len(resp) == 2
|
||||
assert resp[0].name == "Jason"
|
||||
assert resp[0].age == 20
|
||||
assert resp[1].name == "Sarah"
|
||||
assert resp[1].age == 30
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, mode", product(models, modes))
|
||||
def test_multi_user_stream(model, mode, client):
|
||||
client = instructor.patch(client, mode=mode)
|
||||
|
||||
def stream_extract(input: str) -> Iterable[User]:
|
||||
return client.chat.completions.create(
|
||||
model=model,
|
||||
stream=True,
|
||||
response_model=Users,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a perfect entity extraction system",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"Consider the data below:\n{input}"
|
||||
"Correctly segment it into entitites"
|
||||
"Make sure the JSON is correct"
|
||||
),
|
||||
},
|
||||
],
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
resp = [user for user in stream_extract(input="Jason is 20, Sarah is 30")]
|
||||
assert len(resp) == 2
|
||||
assert resp[0].name == "Jason"
|
||||
assert resp[0].age == 20
|
||||
assert resp[1].name == "Sarah"
|
||||
assert resp[1].age == 30
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model, mode", product(models, modes))
|
||||
async def test_multi_user_tools_mode_async_stream(model, mode, aclient):
|
||||
client = instructor.patch(aclient, mode=mode)
|
||||
|
||||
async def stream_extract(input: str) -> Iterable[User]:
|
||||
return await client.chat.completions.create(
|
||||
model=model,
|
||||
|
||||
Reference in New Issue
Block a user