From 21d3a968d2206a20ee91dfeb5ed441fd9bdab5b8 Mon Sep 17 00:00:00 2001 From: Jason Liu Date: Tue, 30 Jan 2024 23:02:07 -0500 Subject: [PATCH] Implement Validated Multi-Classes --- examples/validated-multiclass/output.json | 26 ++++++ examples/validated-multiclass/run.py | 101 ++++++++++++++++++++++ 2 files changed, 127 insertions(+) create mode 100644 examples/validated-multiclass/output.json create mode 100644 examples/validated-multiclass/run.py diff --git a/examples/validated-multiclass/output.json b/examples/validated-multiclass/output.json new file mode 100644 index 0000000..c4c284b --- /dev/null +++ b/examples/validated-multiclass/output.json @@ -0,0 +1,26 @@ +{ + "texts": [ + "What is your phone number?", + "What is your email address?", + "What is your address?", + "What is your privacy policy?" + ], + "predictions": [ + { + "id": 1, + "name": "phone" + }, + { + "id": 2, + "name": "email" + }, + { + "id": 3, + "name": "address" + }, + { + "id": 4, + "name": "Other" + } + ] +} \ No newline at end of file diff --git a/examples/validated-multiclass/run.py b/examples/validated-multiclass/run.py new file mode 100644 index 0000000..3d4bea8 --- /dev/null +++ b/examples/validated-multiclass/run.py @@ -0,0 +1,101 @@ +from typing import List +from pydantic import BaseModel, ValidationInfo, model_validator +import openai +import instructor +import asyncio + +client = instructor.patch( + openai.AsyncOpenAI(), +) + + +class Tag(BaseModel): + id: int + name: str + + @model_validator(mode="after") + def validate_ids(self, info: ValidationInfo): + context = info.context + if context: + tags: List[Tag] = context.get("tags") + assert self.id in { + tag.id for tag in tags + }, f"Tag ID {self.id} not found in context" + assert self.name in { + tag.name for tag in tags + }, f"Tag name {self.name} not found in context" + return self + + +class TagWithInstructions(Tag): + instructions: str + + +class TagRequest(BaseModel): + texts: List[str] + tags: List[TagWithInstructions] + + +class TagResponse(BaseModel): + texts: List[str] + predictions: List[Tag] + + +async def tag_single_request(text: str, tags: List[Tag]) -> Tag: + allowed_tags = [(tag.id, tag.name) for tag in tags] + allowed_tags_str = ", ".join([f"`{tag}`" for tag in allowed_tags]) + return await client.chat.completions.create( + model="gpt-4-turbo-preview", + messages=[ + { + "role": "system", + "content": "You are a world-class text tagging system.", + }, + {"role": "user", "content": f"Describe the following text: `{text}`"}, + { + "role": "user", + "content": f"Here are the allowed tags: {allowed_tags_str}", + }, + ], + response_model=Tag, + # Minizises the hallucination of tags that are not in the allowed tags. + validation_context={"tags": tags}, + ) + + +async def tag_request(request: TagRequest) -> TagResponse: + predictions = await asyncio.gather( + *[tag_single_request(text, request.tags) for text in request.texts] + ) + return TagResponse( + texts=request.texts, + predictions=predictions, + ) + + +if __name__ == "__main__": + # Tags will be a range of different topics. + # Such as personal, phone, email, etc. + tags = [ + TagWithInstructions(id=0, name="personal", instructions="Personal information"), + TagWithInstructions(id=1, name="phone", instructions="Phone number"), + TagWithInstructions(id=2, name="email", instructions="Email address"), + TagWithInstructions(id=3, name="address", instructions="Address"), + TagWithInstructions(id=4, name="Other", instructions="Other information"), + ] + + # Texts will be a range of different questions. + # Such as "How much does it cost?", "What is your privacy policy?", etc. + texts = [ + "What is your phone number?", + "What is your email address?", + "What is your address?", + "What is your privacy policy?", + ] + + # The request will contain the texts and the tags. + request = TagRequest(texts=texts, tags=tags) + + # The response will contain the texts, the predicted tags, and the confidence. + response = asyncio.run(tag_request(request)) + print(response.model_dump_json(indent=2))