mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
fix multi commit
This commit is contained in:
@@ -9,18 +9,14 @@ patch()
|
||||
|
||||
# Define new Enum class for multiple labels
|
||||
class MultiLabels(str, enum.Enum):
|
||||
TECH_ISSUE = "tech_issue"
|
||||
BILLING = "billing"
|
||||
GENERAL_QUERY = "general_query"
|
||||
HARDWARE = "hardware"
|
||||
|
||||
|
||||
# Adjust the prediction model to accommodate a list of labels
|
||||
class MultiClassPrediction(BaseModel):
|
||||
"""
|
||||
List of correct class labels for the given text (Multi Class)
|
||||
"""
|
||||
|
||||
class_labels: List[MultiLabels]
|
||||
predicted_labels: List[MultiLabels]
|
||||
|
||||
|
||||
# Modify the classify function
|
||||
@@ -38,7 +34,8 @@ def multi_classify(data: str) -> MultiClassPrediction:
|
||||
|
||||
|
||||
# Example using a support ticket
|
||||
ticket = "My account is locked and I can't access my billing info."
|
||||
ticket = (
|
||||
"My account is locked and I can't access my billing info. Phone is also broken."
|
||||
)
|
||||
prediction = multi_classify(ticket)
|
||||
assert MultiLabels.TECH_ISSUE in prediction.class_labels
|
||||
assert MultiLabels.BILLING in prediction.class_labels
|
||||
print(prediction)
|
||||
|
||||
Reference in New Issue
Block a user