mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
ruff
This commit is contained in:
@@ -8,6 +8,7 @@ from openai import OpenAI
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Labels(str, enum.Enum):
|
||||
SPAM = "spam"
|
||||
NOT_SPAM = "not_spam"
|
||||
@@ -20,6 +21,7 @@ class SinglePrediction(BaseModel):
|
||||
|
||||
class_label: Labels
|
||||
|
||||
|
||||
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"]
|
||||
modes = [instructor.Mode.FUNCTIONS, instructor.Mode.JSON, instructor.Mode.TOOLS]
|
||||
data = [
|
||||
@@ -33,6 +35,7 @@ data = [
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
|
||||
def test_classification(model, data, mode):
|
||||
client = instructor.patch(OpenAI(), mode=mode)
|
||||
@@ -67,6 +70,7 @@ class MultiLabels(str, enum.Enum):
|
||||
class MultiClassPrediction(BaseModel):
|
||||
predicted_labels: List[MultiLabels]
|
||||
|
||||
|
||||
data = [
|
||||
(
|
||||
"I am having trouble with my billing",
|
||||
@@ -82,9 +86,9 @@ data = [
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
|
||||
def test_multi_classify(model, data, mode):
|
||||
|
||||
client = instructor.patch(OpenAI(), mode=mode)
|
||||
|
||||
if mode == instructor.Mode.JSON and model in {"gpt-3.5-turbo", "gpt-4"}:
|
||||
@@ -103,5 +107,5 @@ def test_multi_classify(model, data, mode):
|
||||
"content": f"Classify the following support ticket: {input}",
|
||||
},
|
||||
],
|
||||
)
|
||||
assert set(resp.predicted_labels) == set(expected)
|
||||
)
|
||||
assert set(resp.predicted_labels) == set(expected)
|
||||
|
||||
@@ -8,6 +8,7 @@ from openai import AsyncOpenAI, OpenAI
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SinglePrediction(BaseModel):
|
||||
"""
|
||||
Correct class label for the given text
|
||||
@@ -15,19 +16,15 @@ class SinglePrediction(BaseModel):
|
||||
|
||||
class_label: Literal["spam", "not_spam"]
|
||||
|
||||
|
||||
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"]
|
||||
modes = [instructor.Mode.FUNCTIONS, instructor.Mode.JSON, instructor.Mode.TOOLS]
|
||||
data = [
|
||||
(
|
||||
"I am a spammer",
|
||||
"spam"
|
||||
),
|
||||
(
|
||||
"I am not a spammer",
|
||||
"not_spam"
|
||||
),
|
||||
("I am a spammer", "spam"),
|
||||
("I am not a spammer", "not_spam"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
|
||||
@pytest.mark.asyncio
|
||||
async def test_classification(model, data, mode):
|
||||
@@ -56,6 +53,7 @@ async def test_classification(model, data, mode):
|
||||
class MultiClassPrediction(BaseModel):
|
||||
predicted_labels: List[Literal["billing", "general_query", "hardware"]]
|
||||
|
||||
|
||||
data = [
|
||||
(
|
||||
"I am having trouble with my billing",
|
||||
@@ -71,10 +69,10 @@ data = [
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, data, mode", product(models, data, modes))
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_classify(model, data, mode):
|
||||
|
||||
client = instructor.patch(AsyncOpenAI(), mode=mode)
|
||||
|
||||
if mode == instructor.Mode.JSON and model in {"gpt-3.5-turbo", "gpt-4"}:
|
||||
@@ -93,5 +91,5 @@ async def test_multi_classify(model, data, mode):
|
||||
"content": f"Classify the following support ticket: {input}",
|
||||
},
|
||||
],
|
||||
)
|
||||
assert set(resp.predicted_labels) == set(expected)
|
||||
)
|
||||
assert set(resp.predicted_labels) == set(expected)
|
||||
|
||||
@@ -83,7 +83,6 @@ The contract can be terminated with a 30-day notice, unless there are outstandin
|
||||
"""
|
||||
|
||||
|
||||
|
||||
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"]
|
||||
modes = [instructor.Mode.FUNCTIONS, instructor.Mode.JSON, instructor.Mode.TOOLS]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user