Files
instructor/tests/openai/test_stream.py
T
2024-01-31 18:31:33 -05:00

84 lines
2.4 KiB
Python

from itertools import product
from typing import Iterable
from pydantic import BaseModel
import pytest
import instructor
from instructor.dsl.partial import Partial
from tests.openai.util import models, modes
class UserExtract(BaseModel):
name: str
age: int
@pytest.mark.parametrize("model, mode, stream", product(models, modes, [True, False]))
def test_iterable_model(model, mode, stream, client):
client = instructor.patch(client, mode=mode)
model = client.chat.completions.create(
model=model,
response_model=Iterable[UserExtract],
max_retries=2,
stream=stream,
messages=[
{"role": "user", "content": "Make two up people"},
],
)
for m in model:
assert isinstance(m, UserExtract)
@pytest.mark.parametrize("model, mode, stream", product(models, modes, [True, False]))
@pytest.mark.asyncio
async def test_iterable_model_async(model, mode, stream, aclient):
aclient = instructor.patch(aclient, mode=mode)
model = await aclient.chat.completions.create(
model=model,
response_model=Iterable[UserExtract],
max_retries=2,
stream=stream,
messages=[
{"role": "user", "content": "Make two up people"},
],
)
if stream:
async for m in model:
assert isinstance(m, UserExtract)
else:
for m in model:
assert isinstance(m, UserExtract)
@pytest.mark.parametrize("model,mode", product(models, modes))
def test_partial_model(model, mode, client):
client = instructor.patch(client, mode=mode)
model = client.chat.completions.create(
model=model,
response_model=Partial[UserExtract],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)
for m in model:
assert isinstance(m, UserExtract)
@pytest.mark.parametrize("model,mode", product(models, modes))
@pytest.mark.asyncio
async def test_partial_model_async(model, mode, aclient):
aclient = instructor.patch(aclient, mode=mode)
model = await aclient.chat.completions.create(
model=model,
response_model=Partial[UserExtract],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)
async for m in model:
assert isinstance(m, UserExtract)