Files
instructor/tests/openai/test_parallel.py
T
Jason Liu c7f1ceeb5c Types!!! (#372)
Co-authored-by: Luke Van Seters <lukevanseters@gmail.com>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
2024-01-31 22:08:07 -05:00

81 lines
2.4 KiB
Python

from typing import Iterable, Literal, Union
from pydantic import BaseModel
import pytest
import instructor
class Weather(BaseModel):
    # Response model the tests below pass to instructor for the weather part
    # of the prompt. NOTE(review): documented with `#` comments rather than a
    # docstring because a docstring may alter the schema description sent to
    # the model — confirm before converting.
    location: str
    # Only the two literal strings "imperial" or "metric" validate.
    units: Literal["imperial", "metric"]
class GoogleSearch(BaseModel):
    # Response model paired with `Weather` in the union tests, intended for
    # the non-weather part of the prompt. NOTE(review): `#` comment instead
    # of a docstring for the same schema-description reason as `Weather`.
    query: str
def test_sync_parallel_tools_or(client):
    """Parallel tool calls with a union response model yield one object per call.

    The prompt names two cities plus one unrelated question, and the test
    expects exactly three results, each an instance of a union member.
    ``Union[...]`` is used instead of ``Weather | GoogleSearch`` because the
    ``|`` operator on classes only exists on Python 3.10+.
    """
    # Keep the fixture untouched; work with a separately named patched client.
    patched_client = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS)
    resp = patched_client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=[
            {"role": "system", "content": "You must always use tools"},
            {
                "role": "user",
                "content": "What is the weather in toronto and dallas and who won the super bowl?",
            },
        ],
        response_model=Iterable[Union[Weather, GoogleSearch]],
    )
    # Materialize once so the results can be both counted and inspected.
    results = list(resp)
    assert len(results) == 3
    assert all(isinstance(r, (Weather, GoogleSearch)) for r in results)
@pytest.mark.asyncio
async def test_async_parallel_tools_or(aclient):
    """Async variant: a union response model yields one object per tool call.

    Mirrors the sync test: three results are expected, each an instance of a
    union member. ``Union[...]`` keeps the test runnable before Python 3.10,
    where ``Weather | GoogleSearch`` raises ``TypeError``.
    """
    patched_client = instructor.patch(aclient, mode=instructor.Mode.PARALLEL_TOOLS)
    resp = await patched_client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=[
            {"role": "system", "content": "You must always use tools"},
            {
                "role": "user",
                "content": "What is the weather in toronto and dallas and who won the super bowl?",
            },
        ],
        response_model=Iterable[Union[Weather, GoogleSearch]],
    )
    # Materialize once so the results can be both counted and inspected.
    results = list(resp)
    assert len(results) == 3
    assert all(isinstance(r, (Weather, GoogleSearch)) for r in results)
def test_sync_parallel_tools_one(client):
    """A single-model ``Iterable`` response yields one object per tool call.

    The prompt names two cities and the test expects exactly two results.
    """
    user_prompt = "What is the weather in toronto and dallas?"
    conversation = [
        {"role": "system", "content": "You must always use tools"},
        {"role": "user", "content": user_prompt},
    ]
    patched = instructor.patch(client, mode=instructor.Mode.PARALLEL_TOOLS)
    weather_calls = patched.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=conversation,
        response_model=Iterable[Weather],
    )
    assert len(list(weather_calls)) == 2
@pytest.mark.asyncio
async def test_async_parallel_tools_one(aclient):
    """Async variant: a single-model ``Iterable`` yields one object per call.

    The prompt names two cities and the test expects exactly two results.
    """
    user_prompt = "What is the weather in toronto and dallas?"
    conversation = [
        {"role": "system", "content": "You must always use tools"},
        {"role": "user", "content": user_prompt},
    ]
    patched = instructor.patch(aclient, mode=instructor.Mode.PARALLEL_TOOLS)
    weather_calls = await patched.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=conversation,
        response_model=Iterable[Weather],
    )
    assert len(list(weather_calls)) == 2