Files
instructor/tests/openai/evals/test_classification_literals.py
T
Jason Liu 2a5cfcecbd ruff
2023-11-25 20:27:57 -05:00

96 lines
2.5 KiB
Python

import enum
from itertools import product
from typing import List, Literal
import pytest
import instructor
from openai import AsyncOpenAI, OpenAI
from pydantic import BaseModel
class SinglePrediction(BaseModel):
"""
Correct class label for the given text
"""
class_label: Literal["spam", "not_spam"]
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"]
modes = [instructor.Mode.FUNCTIONS, instructor.Mode.JSON, instructor.Mode.TOOLS]
data = [
("I am a spammer", "spam"),
("I am not a spammer", "not_spam"),
]
@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
@pytest.mark.asyncio
async def test_classification(model, data, mode):
client = instructor.patch(AsyncOpenAI(), mode=mode)
if mode == instructor.Mode.JSON and model in {"gpt-3.5-turbo", "gpt-4"}:
pytest.skip(
"JSON mode is not supported for gpt-3.5-turbo and gpt-4, skipping test"
)
input, expected = data
resp = await client.chat.completions.create(
model=model,
response_model=SinglePrediction,
messages=[
{
"role": "user",
"content": f"Classify the following text: {input}",
},
],
)
assert resp.class_label == expected
# Adjust the prediction model to accommodate a list of labels
class MultiClassPrediction(BaseModel):
predicted_labels: List[Literal["billing", "general_query", "hardware"]]
data = [
(
"I am having trouble with my billing",
["billing"],
),
(
"I am having trouble with my hardware",
["hardware"],
),
(
"I have a general query and a billing issue",
["general_query", "billing"],
),
]
@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
@pytest.mark.asyncio
async def test_multi_classify(model, data, mode):
client = instructor.patch(AsyncOpenAI(), mode=mode)
if mode == instructor.Mode.JSON and model in {"gpt-3.5-turbo", "gpt-4"}:
pytest.skip(
"JSON mode is not supported for gpt-3.5-turbo and gpt-4, skipping test"
)
input, expected = data
resp = await client.chat.completions.create(
model=model,
response_model=MultiClassPrediction,
messages=[
{
"role": "user",
"content": f"Classify the following support ticket: {input}",
},
],
)
assert set(resp.predicted_labels) == set(expected)