fix multi commit

This commit is contained in:
Jason
2023-09-13 22:52:09 -04:00
parent d88f23321f
commit 3f0e0477a9
+6 -9
View File
@@ -9,18 +9,14 @@ patch()
# Define new Enum class for multiple labels
class MultiLabels(str, enum.Enum):
TECH_ISSUE = "tech_issue"
BILLING = "billing"
GENERAL_QUERY = "general_query"
HARDWARE = "hardware"
# Adjust the prediction model to accommodate a list of labels
class MultiClassPrediction(BaseModel):
"""
List of correct class labels for the given text (Multi Class)
"""
class_labels: List[MultiLabels]
predicted_labels: List[MultiLabels]
# Modify the classify function
@@ -38,7 +34,8 @@ def multi_classify(data: str) -> MultiClassPrediction:
# Example using a support ticket
ticket = "My account is locked and I can't access my billing info."
ticket = (
"My account is locked and I can't access my billing info. Phone is also broken."
)
prediction = multi_classify(ticket)
assert MultiLabels.TECH_ISSUE in prediction.class_labels
assert MultiLabels.BILLING in prediction.class_labels
print(prediction)