This commit is contained in:
Jason Liu
2023-11-25 20:27:57 -05:00
parent 3fd2bbd6d6
commit 2a5cfcecbd
3 changed files with 16 additions and 15 deletions
@@ -8,6 +8,7 @@ from openai import OpenAI
from pydantic import BaseModel
class Labels(str, enum.Enum):
SPAM = "spam"
NOT_SPAM = "not_spam"
@@ -20,6 +21,7 @@ class SinglePrediction(BaseModel):
class_label: Labels
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"]
modes = [instructor.Mode.FUNCTIONS, instructor.Mode.JSON, instructor.Mode.TOOLS]
data = [
@@ -33,6 +35,7 @@ data = [
),
]
@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
def test_classification(model, data, mode):
client = instructor.patch(OpenAI(), mode=mode)
@@ -67,6 +70,7 @@ class MultiLabels(str, enum.Enum):
class MultiClassPrediction(BaseModel):
predicted_labels: List[MultiLabels]
data = [
(
"I am having trouble with my billing",
@@ -82,9 +86,9 @@ data = [
),
]
@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
def test_multi_classify(model, data, mode):
client = instructor.patch(OpenAI(), mode=mode)
if mode == instructor.Mode.JSON and model in {"gpt-3.5-turbo", "gpt-4"}:
@@ -103,5 +107,5 @@ def test_multi_classify(model, data, mode):
"content": f"Classify the following support ticket: {input}",
},
],
)
assert set(resp.predicted_labels) == set(expected)
)
assert set(resp.predicted_labels) == set(expected)
@@ -8,6 +8,7 @@ from openai import AsyncOpenAI, OpenAI
from pydantic import BaseModel
class SinglePrediction(BaseModel):
"""
Correct class label for the given text
@@ -15,19 +16,15 @@ class SinglePrediction(BaseModel):
class_label: Literal["spam", "not_spam"]
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"]
modes = [instructor.Mode.FUNCTIONS, instructor.Mode.JSON, instructor.Mode.TOOLS]
data = [
(
"I am a spammer",
"spam"
),
(
"I am not a spammer",
"not_spam"
),
("I am a spammer", "spam"),
("I am not a spammer", "not_spam"),
]
@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
@pytest.mark.asyncio
async def test_classification(model, data, mode):
@@ -56,6 +53,7 @@ async def test_classification(model, data, mode):
class MultiClassPrediction(BaseModel):
predicted_labels: List[Literal["billing", "general_query", "hardware"]]
data = [
(
"I am having trouble with my billing",
@@ -71,10 +69,10 @@ data = [
),
]
@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
@pytest.mark.asyncio
async def test_multi_classify(model, data, mode):
client = instructor.patch(AsyncOpenAI(), mode=mode)
if mode == instructor.Mode.JSON and model in {"gpt-3.5-turbo", "gpt-4"}:
@@ -93,5 +91,5 @@ async def test_multi_classify(model, data, mode):
"content": f"Classify the following support ticket: {input}",
},
],
)
assert set(resp.predicted_labels) == set(expected)
)
assert set(resp.predicted_labels) == set(expected)
-1
View File
@@ -83,7 +83,6 @@ The contract can be terminated with a 30-day notice, unless there are outstandin
"""
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"]
modes = [instructor.Mode.FUNCTIONS, instructor.Mode.JSON, instructor.Mode.TOOLS]