diff --git a/examples/classification/multi_prediction.py b/examples/classification/multi_prediction.py index 1f4f261..107460e 100644 --- a/examples/classification/multi_prediction.py +++ b/examples/classification/multi_prediction.py @@ -9,18 +9,14 @@ patch() # Define new Enum class for multiple labels class MultiLabels(str, enum.Enum): - TECH_ISSUE = "tech_issue" BILLING = "billing" GENERAL_QUERY = "general_query" + HARDWARE = "hardware" # Adjust the prediction model to accommodate a list of labels class MultiClassPrediction(BaseModel): - """ - List of correct class labels for the given text (Multi Class) - """ - - class_labels: List[MultiLabels] + predicted_labels: List[MultiLabels] # Modify the classify function @@ -38,7 +34,8 @@ def multi_classify(data: str) -> MultiClassPrediction: # Example using a support ticket -ticket = "My account is locked and I can't access my billing info." +ticket = ( + "My account is locked and I can't access my billing info. Phone is also broken." +) prediction = multi_classify(ticket) -assert MultiLabels.TECH_ISSUE in prediction.class_labels -assert MultiLabels.BILLING in prediction.class_labels +print(prediction)