instructor/tutorials/3.1.validation-rag.ipynb
T
2023-12-14 18:04:02 -05:00


{
"cells": [
{
"cell_type": "markdown",
"id": "5a01f3ac-5306-4a1b-9e47-a5d254bce93a",
"metadata": {},
"source": [
"# Validators\n"
]
},
{
"cell_type": "markdown",
"id": "9dcc78ac-ed6d-49e3-b71b-fb2fb25f16a8",
"metadata": {},
"source": [
"Pydantic offers a customizable and expressive validation framework for Python. Instructor leverages Pydantic's validation framework to provide a uniform developer experience for both code-based and LLM-based validation, as well as a reasking mechanism for correcting LLM outputs based on validation errors. To learn more, check out the Pydantic [docs](https://docs.pydantic.dev/latest/) on validators.\n",
"\n",
"Note: For most of this notebook we won't call OpenAI; we'll only use validators to see how we can control the validation of objects.\n",
"\n",
"Then we'll bring it all together in the context of RAG from the previous notebook.\n"
]
},
{
"cell_type": "markdown",
"id": "064c286b",
"metadata": {},
"source": [
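"Validators let us control outputs with a function of this general shape:\n",
"\n",
"```python\n",
"def validation_function(value):\n",
"    # `condition` and `mutation` are placeholders for your own logic\n",
"    if condition(value):\n",
"        raise ValueError(\"Value is not valid\")\n",
"    return mutation(value)\n",
"```\n",
"\n",
"A validator either raises a `ValueError` when the value is invalid or returns the value, optionally transformed. Pydantic collects these errors into a `ValidationError`.\n"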
]
},
{
"cell_type": "markdown",
"id": "7cfc6c66",
"metadata": {},
"source": [
"## Defining Validator Functions\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4bb6258-b03a-4621-8a73-29056a20ec0f",
"metadata": {},
"outputs": [],
"source": [
"from pydantic import AfterValidator, BaseModel\n",
"from typing_extensions import Annotated\n",
"\n",
"\n",
"def name_must_contain_space(v: str) -> str:\n",
"    # Reject single-word names, then normalise the value to lowercase\n",
"    if \" \" not in v:\n",
"        raise ValueError(\"Name must contain a space.\")\n",
"    return v.lower()\n",
"\n",
"\n",
"class UserDetail(BaseModel):\n",
"    age: int\n",
"    name: Annotated[str, AfterValidator(name_must_contain_space)]\n",
"\n",
"\n",
"# Raises a ValidationError: \"Jason\" contains no space\n",
"person = UserDetail(age=29, name=\"Jason\")"
]
},
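{
"cell_type": "markdown",
"id": "sketch-catch-validation-error-md",
"metadata": {},
"source": [
"The cell above raises a `ValidationError` because `\"Jason\"` contains no space. A minimal sketch of catching that error and reading Pydantic's structured error list; the model mirrors `UserDetail`, while the `try`/`except` handling and the sample name are only illustrative:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-catch-validation-error",
"metadata": {},
"outputs": [],
"source": [
"from typing import Annotated\n",
"\n",
"from pydantic import AfterValidator, BaseModel, ValidationError\n",
"\n",
"\n",
"def name_must_contain_space(v: str) -> str:\n",
"    # Reject single-word names, then normalise the value to lowercase\n",
"    if \" \" not in v:\n",
"        raise ValueError(\"Name must contain a space.\")\n",
"    return v.lower()\n",
"\n",
"\n",
"class UserDetail(BaseModel):\n",
"    age: int\n",
"    name: Annotated[str, AfterValidator(name_must_contain_space)]\n",
"\n",
"\n",
"try:\n",
"    UserDetail(age=29, name=\"Jason\")\n",
"except ValidationError as exc:\n",
"    # Each entry describes one failing field: where it failed and why\n",
"    for error in exc.errors():\n",
"        print(error[\"loc\"], error[\"msg\"])\n",
"\n",
"# A two-word name passes and is lowercased by the validator\n",
"print(UserDetail(age=29, name=\"Jane Doe\").name)  # -> jane doe"
]
},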
{
"cell_type": "markdown",
"id": "3c0302ca",
"metadata": {},
"source": [
"## Using Field\n",
"\n",
"We can also use the `Field` function to declare constraints. This is useful for primitive fields such as strings and integers, which support a limited set of built-in constraints (for example `gt` for numbers or `min_length` for strings) that need no custom validator function.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3242856f",
"metadata": {},
"outputs": [],
"source": [
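"from pydantic import BaseModel, Field\n",
"\n",
"\n",
"class UserDetail(BaseModel):\n",
"    # gt=0 requires a strictly positive age\n",
"    age: int = Field(..., gt=0)\n",
"    name: str\n",
"\n",
"\n",
"# Raises a ValidationError: -10 is not greater than 0\n",
"person = UserDetail(age=-10, name=\"Jason\")"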
]
},
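{
"cell_type": "markdown",
"id": "sketch-field-constraints-md",
"metadata": {},
"source": [
"`Field` supports several other built-in constraints. A short sketch combining a few of them; `ge`, `le`, `min_length`, `max_length` and `pattern` are Pydantic v2 `Field` arguments, while the `Signup` model and its limits are hypothetical:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-field-constraints",
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel, Field, ValidationError\n",
"\n",
"\n",
"class Signup(BaseModel):\n",
"    # Hypothetical model: every rule below is a built-in Field constraint\n",
"    age: int = Field(..., ge=13, le=120)\n",
"    username: str = Field(..., min_length=3, max_length=20, pattern=r\"^[a-z0-9_]+$\")\n",
"\n",
"\n",
"print(Signup(age=30, username=\"jane_doe\"))\n",
"\n",
"try:\n",
"    Signup(age=5, username=\"J!\")\n",
"except ValidationError as exc:\n",
"    # Both fields fail, and Pydantic reports them together in one exception\n",
"    print(exc.error_count())"
]
},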
{
"cell_type": "code",
"execution_count": null,
"id": "0035a329",
"metadata": {},
"outputs": [],
"source": [
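"class AssistantMessage(BaseModel):\n",
"    # min_length=10 rejects messages shorter than 10 characters\n",
"    message: str = Field(..., min_length=10)\n",
"\n",
"\n",
"# Raises a ValidationError: \"Hey\" has only 3 characters\n",
"message = AssistantMessage(message=\"Hey\")"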
]
},
{
"cell_type": "markdown",
"id": "4f689121",
"metadata": {},
"source": [
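"## Providing Context\n",
"\n",
"Validators can also read runtime information. Anything passed as `context=` to `model_validate` is available inside a validator as `info.context`, where `info` is the `ValidationInfo` argument. Here we use it to supply a blacklist of words.\n"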
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec043c23",
"metadata": {},
"outputs": [],
"source": [
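"from pydantic import BaseModel, ValidationInfo, field_validator\n",
"\n",
"\n",
"class Response(BaseModel):\n",
"    message: str\n",
"\n",
"    @field_validator(\"message\")\n",
"    @classmethod\n",
"    def message_cannot_have_blacklisted_words(cls, v: str, info: ValidationInfo) -> str:\n",
"        # info.context is None when model_validate is called without context=\n",
"        blacklist = (info.context or {}).get(\"blacklist\", [])\n",
"        for word in blacklist:\n",
"            # Raise ValueError rather than assert, since asserts are skipped under python -O\n",
"            if word in v.lower():\n",
"                raise ValueError(f\"`{word}` was found in the message `{v}`\")\n",
"        return v\n",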
"\n",
"\n",
"Response.model_validate(\n",
" {\"message\": \"I will hurt them.\"},\n",
" context={\n",
" \"blacklist\": {\n",
" \"rob\",\n",
" \"steal\",\n",
" \"hurt\",\n",
" \"kill\",\n",
" \"attack\",\n",
" }\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "887dba80",
"metadata": {},
"outputs": [],
"source": [
"Response.model_validate(\n",
" {\"message\": \"My name is rob.\"},\n",
" context={\n",
" \"blacklist\": {\n",
" \"rob\",\n",
" \"steal\",\n",
" \"hurt\",\n",
" \"kill\",\n",
" \"attack\",\n",
" }\n",
" },\n",
")"
]
},
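{
"cell_type": "markdown",
"id": "sketch-context-word-boundaries-md",
"metadata": {},
"source": [
"Two refinements are worth considering. `info.context` is `None` when `model_validate` is called without `context=`, and plain substring matching also flags words such as \"problem\" when \"rob\" is blacklisted. A minimal sketch that handles both, using a regular expression with word boundaries (the word-boundary matching is an alternative we are swapping in, not part of the original example):\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-context-word-boundaries",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"from pydantic import BaseModel, ValidationInfo, field_validator\n",
"\n",
"\n",
"class Response(BaseModel):\n",
"    message: str\n",
"\n",
"    @field_validator(\"message\")\n",
"    @classmethod\n",
"    def message_cannot_have_blacklisted_words(cls, v: str, info: ValidationInfo) -> str:\n",
"        # Treat a missing context as an empty blacklist\n",
"        blacklist = (info.context or {}).get(\"blacklist\", [])\n",
"        for word in blacklist:\n",
"            # \\b limits matches to whole words, so \"rob\" no longer matches \"problem\"\n",
"            if re.search(rf\"\\b{re.escape(word)}\\b\", v, flags=re.IGNORECASE):\n",
"                raise ValueError(f\"`{word}` was found in the message `{v}`\")\n",
"        return v\n",
"\n",
"\n",
"blacklist = {\"rob\", \"steal\", \"hurt\", \"kill\", \"attack\"}\n",
"print(Response.model_validate({\"message\": \"No problem here.\"}, context={\"blacklist\": blacklist}))\n",
"print(Response.model_validate({\"message\": \"Validated without any context.\"}))"
]
},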
{
"cell_type": "markdown",
"id": "37e3a638-c9c9-44cd-bcd0-ad1a39f448db",
"metadata": {},
"source": [
"## Using OpenAI Moderation\n"
]
},
{
"cell_type": "markdown",
"id": "88d0b816-7ec8-42b0-9b91-c9aab382c960",
"metadata": {},
"source": [
"To enhance our validation measures, we'll extend the scope to flag any answer that contains hateful content, harassment, or similar issues. OpenAI offers a moderation endpoint that addresses these concerns, and it's freely available when using OpenAI models.\n"
]
},
{
"cell_type": "markdown",
"id": "65f46eb5",
"metadata": {},
"source": [
"With the `instructor` library, this is just one function edit away:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82521112-5301-4442-acce-82b495bd838f",
"metadata": {},
"outputs": [],
"source": [
"from typing import Annotated\n",
"\n",
"import instructor\n",
"from instructor import openai_moderation\n",
"from openai import OpenAI\n",
"from pydantic import AfterValidator, BaseModel\n",
"\n",
"client = instructor.patch(OpenAI())\n",
"\n",
"# Annotated (available in `typing` since Python 3.9) attaches metadata to a type hint;\n",
"# here the metadata is a validator that calls OpenAI's moderation endpoint.\n",
"ModeratedStr = Annotated[str, AfterValidator(openai_moderation(client=client))]\n",
"\n",
"\n",
"class Response(BaseModel):\n",
"    message: ModeratedStr\n",
"\n",
"\n",
"# Expected to raise a ValidationError if the endpoint flags the message\n",
"Response(message=\"I want to make them suffer the consequences\")"
]
},
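{
"cell_type": "markdown",
"id": "sketch-moderation-factory-md",
"metadata": {},
"source": [
"`openai_moderation(client=client)` is a factory: it is called once and returns the function that `AfterValidator` runs on each value. To see that pattern without an API key, here is a hypothetical keyword-based `keyword_moderation` stand-in. It only illustrates the factory shape; unlike the real validator it does not call OpenAI's moderation endpoint and would miss most harmful content:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-moderation-factory",
"metadata": {},
"outputs": [],
"source": [
"from typing import Annotated, Callable, Iterable\n",
"\n",
"from pydantic import AfterValidator, BaseModel, ValidationError\n",
"\n",
"\n",
"def keyword_moderation(banned: Iterable[str]) -> Callable[[str], str]:\n",
"    # Hypothetical stand-in: the factory captures its settings and returns a validator\n",
"    banned_lower = {word.lower() for word in banned}\n",
"\n",
"    def validator(v: str) -> str:\n",
"        flagged = sorted(word for word in banned_lower if word in v.lower())\n",
"        if flagged:\n",
"            raise ValueError(f\"Message flagged for: {', '.join(flagged)}\")\n",
"        return v\n",
"\n",
"    return validator\n",
"\n",
"\n",
"LocallyModeratedStr = Annotated[str, AfterValidator(keyword_moderation([\"suffer\", \"hurt\"]))]\n",
"\n",
"\n",
"class LocalResponse(BaseModel):\n",
"    message: LocallyModeratedStr\n",
"\n",
"\n",
"try:\n",
"    LocalResponse(message=\"I want to make them suffer the consequences\")\n",
"except ValidationError as exc:\n",
"    print(exc.errors()[0][\"msg\"])"
]
},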
{
"cell_type": "markdown",
"id": "faa5116e",
"metadata": {},
"source": [
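"## General Validator\n",
"\n",
"`llm_validator` turns a natural-language rule into a validator: an LLM judges whether the value follows the statement, and validation fails with the model's stated reason when it does not.\n"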
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49d8b772",
"metadata": {},
"outputs": [],
"source": [
"from instructor import llm_validator\n",
"\n",
"HealthTopicStr = Annotated[\n",
" str,\n",
" AfterValidator(\n",
" llm_validator(\n",
" \"don't talk about any other topic except health best practices and topics\",\n",
" openai_client=client,\n",
" )\n",
" ),\n",
"]\n",
"\n",
"\n",
"class AssistantMessage(BaseModel):\n",
" message: HealthTopicStr\n",
"\n",
"\n",
"AssistantMessage(\n",
"    message=\"I would suggest visiting Sicily, as they say it is very nice in winter.\"\n",
")"
]
},
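{
"cell_type": "markdown",
"id": "sketch-rule-validator-md",
"metadata": {},
"source": [
"Conceptually, `llm_validator` asks a model whether the value satisfies the statement and fails validation with the model's reason when it does not. A hedged sketch of that shape, where a hypothetical `judge` callable takes the place of the LLM call; `rule_validator`, `Verdict` and the keyword-based `fake_judge` are illustrative only and are not instructor's implementation:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-rule-validator",
"metadata": {},
"outputs": [],
"source": [
"from typing import Annotated, Callable, NamedTuple\n",
"\n",
"from pydantic import AfterValidator, BaseModel, ValidationError\n",
"\n",
"\n",
"class Verdict(NamedTuple):\n",
"    is_valid: bool\n",
"    reason: str\n",
"\n",
"\n",
"def rule_validator(statement: str, judge: Callable[[str, str], Verdict]) -> Callable[[str], str]:\n",
"    # Hypothetical factory: `judge` stands in for the LLM that evaluates the rule\n",
"    def validator(v: str) -> str:\n",
"        verdict = judge(statement, v)\n",
"        if not verdict.is_valid:\n",
"            raise ValueError(verdict.reason)\n",
"        return v\n",
"\n",
"    return validator\n",
"\n",
"\n",
"def fake_judge(statement: str, value: str) -> Verdict:\n",
"    # Offline placeholder: counts a message as on-topic if it mentions a health keyword\n",
"    on_topic = any(word in value.lower() for word in (\"health\", \"sleep\", \"exercise\"))\n",
"    return Verdict(on_topic, \"ok\" if on_topic else f\"Message breaks the rule: {statement}\")\n",
"\n",
"\n",
"class LocalAssistantMessage(BaseModel):\n",
"    message: Annotated[str, AfterValidator(rule_validator(\"only discuss health topics\", fake_judge))]\n",
"\n",
"\n",
"try:\n",
"    LocalAssistantMessage(message=\"I would suggest you visit Sicily in winter.\")\n",
"except ValidationError as exc:\n",
"    print(exc.errors()[0][\"msg\"])"
]
},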
{
"cell_type": "markdown",
"id": "050e72fe-4b13-4002-a1d0-94f7b88b784b",
"metadata": {},
"source": [
"## Avoiding hallucination with citations\n"
]
},
{
"cell_type": "markdown",
"id": "e3f2869e-c8a3-4b93-82e7-55eb70930900",
"metadata": {},
"source": [
"When incorporating external knowledge bases, it's crucial to ensure that the agent uses the provided context accurately and doesn't fabricate responses. Validators can be effectively used for this purpose. We can illustrate this with an example where we validate that a provided citation is actually included in the referenced text chunk:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "638fc368-5cf7-4ae7-9d3f-efea1b84eec0",
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel, ValidationInfo, field_validator\n",
"\n",
"\n",
"class AnswerWithCitation(BaseModel):\n",
"    answer: str\n",
"    citation: str\n",
"\n",
"    @field_validator(\"citation\")\n",
"    @classmethod\n",
"    def citation_exists(cls, v: str, info: ValidationInfo) -> str:\n",
"        # Only check the citation when a text chunk was supplied as context\n",
"        text_chunk = (info.context or {}).get(\"text_chunk\")\n",
"        if text_chunk is not None and v not in text_chunk:\n",
"            raise ValueError(f\"Citation `{v}` not found in text\")\n",
"        return v"
]
},
{
"cell_type": "markdown",
"id": "3064b06b-7f85-40ec-8fe2-4fa2cce36585",
"metadata": {},
"source": [
"Here we assume the validation context contains a `text_chunk` key holding the text the model is supposed to use as its source. The `field_validator` decorator registers a validator that checks whether the citation appears verbatim in that chunk. If it does not, we raise a `ValueError`; Pydantic reports it as a `ValidationError` whose message names the missing citation.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f3030b6-e6cf-45bf-a366-12de996fea40",
"metadata": {},
"outputs": [],
"source": [
"AnswerWithCitation.model_validate(\n",
" {\n",
" \"answer\": \"Blueberries are packed with protein\",\n",
" \"citation\": \"Blueberries contain high levels of protein\",\n",
" },\n",
" context={\"text_chunk\": \"Blueberries are very rich in antioxidants\"},\n",
")"
]
},
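{
"cell_type": "markdown",
"id": "sketch-citation-passes-md",
"metadata": {},
"source": [
"The call above fails because the cited sentence does not appear in `text_chunk`. When the citation is copied from the chunk, the same check passes. A self-contained sketch that repeats the model definition (with a guard so that validating without context skips the check) and validates a supported citation:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-citation-passes",
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel, ValidationInfo, field_validator\n",
"\n",
"\n",
"class AnswerWithCitation(BaseModel):\n",
"    answer: str\n",
"    citation: str\n",
"\n",
"    @field_validator(\"citation\")\n",
"    @classmethod\n",
"    def citation_exists(cls, v: str, info: ValidationInfo) -> str:\n",
"        text_chunk = (info.context or {}).get(\"text_chunk\")\n",
"        # Only check when a text chunk was supplied\n",
"        if text_chunk is not None and v not in text_chunk:\n",
"            raise ValueError(f\"Citation `{v}` not found in text\")\n",
"        return v\n",
"\n",
"\n",
"ok = AnswerWithCitation.model_validate(\n",
"    {\n",
"        \"answer\": \"Blueberries are rich in antioxidants\",\n",
"        \"citation\": \"very rich in antioxidants\",\n",
"    },\n",
"    context={\"text_chunk\": \"Blueberries are very rich in antioxidants\"},\n",
")\n",
"print(ok.citation)"
]
},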
{
"cell_type": "markdown",
"id": "64d15ad2",
"metadata": {},
"source": [
"In practice there are many ways to implement this: we could use a regex to check if the citation is included in the text chunk, or we could use a more sophisticated approach like a semantic similarity check. The important thing is that we have a way to validate that the model is using the provided context accurately.\n"
]
},
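{
"cell_type": "markdown",
"id": "sketch-fuzzy-citation-md",
"metadata": {},
"source": [
"As one lenient alternative to exact substring matching, here is a minimal sketch using fuzzy string matching from the standard library's `difflib`. This is a swapped-in technique, not semantic similarity: it tolerates small typos but not paraphrases. The helper name `citation_is_supported` and the 0.8 threshold are hypothetical choices:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-fuzzy-citation",
"metadata": {},
"outputs": [],
"source": [
"from difflib import SequenceMatcher\n",
"\n",
"\n",
"def citation_is_supported(citation: str, text_chunk: str, threshold: float = 0.8) -> bool:\n",
"    # Hypothetical helper: compare the citation with every same-length window of the chunk\n",
"    citation, text_chunk = citation.lower(), text_chunk.lower()\n",
"    if not citation:\n",
"        return False\n",
"    window = len(citation)\n",
"    if window >= len(text_chunk):\n",
"        return SequenceMatcher(None, citation, text_chunk).ratio() >= threshold\n",
"    best = max(\n",
"        SequenceMatcher(None, citation, text_chunk[start : start + window]).ratio()\n",
"        for start in range(len(text_chunk) - window + 1)\n",
"    )\n",
"    return best >= threshold\n",
"\n",
"\n",
"chunk = \"Blueberries are very rich in antioxidants\"\n",
"print(citation_is_supported(\"very rich in antioxidents\", chunk))  # small typo\n",
"print(citation_is_supported(\"high levels of protein\", chunk))"
]
},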
{
"cell_type": "code",
"execution_count": null,
"id": "04d2b691",
"metadata": {},
"outputs": [],
"source": [
"class AnswerWithCitation(BaseModel):\n",
"    answer: str\n",
"    citations: list[str]\n",
"\n",
"    @field_validator(\"citations\")\n",
"    @classmethod\n",
"    def citation_exists(cls, v: list[str], info: ValidationInfo) -> list[str]:\n",
"        text_chunk = (info.context or {}).get(\"text_chunk\")\n",
"        if text_chunk is None:\n",
"            return v\n",
"        # Every citation in the list must appear in the text chunk\n",
"        for citation in v:\n",
"            if citation not in text_chunk:\n",
"                raise ValueError(f\"Citation `{citation}` not found in text\")\n",
"        return v"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8aae4f1",
"metadata": {},
"outputs": [],
"source": [
"class Citation(BaseModel):\n",
"    span_start: str = Field(\n",
"        ...,\n",
"        description=\"The start of the citation, use a 3-4 word phrase that is unique to the citation\",\n",
"    )\n",
"    span_end: str = Field(\n",
"        ...,\n",
"        description=\"The end of the citation, use a 3-4 word phrase that is unique to the citation\",\n",
"    )\n",
"\n",
"    def check(self, text: str) -> bool:\n",
"        # The start phrase must appear, and the end phrase must appear at or after it\n",
"        index_start = text.find(self.span_start)\n",
"        if index_start == -1:\n",
"            return False\n",
"        return text.find(self.span_end, index_start) != -1\n",
"\n",
"\n",
"class AnswerWithCitation(BaseModel):\n",
"    answer: str\n",
"    citations: list[Citation]\n",
"\n",
"    @field_validator(\"citations\")\n",
"    @classmethod\n",
"    def citation_exists(cls, v: list[Citation], info: ValidationInfo) -> list[Citation]:\n",
"        text_chunk = (info.context or {}).get(\"text_chunk\")\n",
"        if text_chunk is None:\n",
"            return v\n",
"        for citation in v:\n",
"            if not citation.check(text_chunk):\n",
"                raise ValueError(f\"Citation `{citation}` not found in text\")\n",
"        return v"
]
},
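{
"cell_type": "markdown",
"id": "sketch-extract-span-md",
"metadata": {},
"source": [
"Because each `Citation` stores a start phrase and an end phrase, the full quoted passage can be recovered from the source text. A minimal sketch with a hypothetical `extract_span` helper that returns the passage, or `None` when the phrases are missing or out of order:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-extract-span",
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional\n",
"\n",
"\n",
"def extract_span(text: str, span_start: str, span_end: str) -> Optional[str]:\n",
"    # Hypothetical helper: locate the start phrase, then the first end phrase after it\n",
"    start = text.find(span_start)\n",
"    if start == -1:\n",
"        return None\n",
"    end = text.find(span_end, start)\n",
"    if end == -1:\n",
"        return None\n",
"    # Include the end phrase itself in the returned passage\n",
"    return text[start : end + len(span_end)]\n",
"\n",
"\n",
"text = \"Blueberries are very rich in antioxidants and a good source of fiber.\"\n",
"print(extract_span(text, \"Blueberries are very\", \"good source of fiber\"))\n",
"print(extract_span(text, \"good source of fiber\", \"Blueberries are very\"))"
]
},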
{
"cell_type": "markdown",
"id": "5bbbaa11-32d2-4772-bc31-18d1d6d6c919",
"metadata": {},
"source": [
"## Reasking with validators\n",
"\n",
"So far we have only defined validation logic, which can live separately from generation. When validation fails, though, we don't have to stop there: instructor can collect the validation errors and reask the LLM to rewrite its answer.\n",
"\n",
"Let's use an extreme example to illustrate this point:\n"
]
},
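{
"cell_type": "markdown",
"id": "sketch-reask-loop-md",
"metadata": {},
"source": [
"Before running it against the API, it may help to see the idea in plain Python. The sketch below is not instructor's internal implementation; it assumes a hypothetical `generate` callable in place of the LLM, validates each attempt, and on failure appends the error text to the feedback passed into the next attempt:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-reask-loop",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable, List\n",
"\n",
"from pydantic import BaseModel, ValidationError, field_validator\n",
"\n",
"\n",
"class CheckedAnswer(BaseModel):\n",
"    question: str\n",
"    answer: str\n",
"\n",
"    @field_validator(\"answer\")\n",
"    @classmethod\n",
"    def no_objectionable_content(cls, v: str) -> str:\n",
"        # Hypothetical keyword rule standing in for an LLM-based validator\n",
"        if \"debauchery\" in v.lower():\n",
"            raise ValueError(\"Answer repeats objectionable content; rephrase it neutrally.\")\n",
"        return v\n",
"\n",
"\n",
"def create_with_reask(generate: Callable[[List[str]], dict], max_retries: int = 2) -> CheckedAnswer:\n",
"    feedback: List[str] = []\n",
"    # One initial attempt plus up to max_retries reasks\n",
"    for _ in range(max_retries + 1):\n",
"        raw = generate(feedback)\n",
"        try:\n",
"            return CheckedAnswer.model_validate(raw)\n",
"        except ValidationError as exc:\n",
"            # Keep the error text so the next attempt can correct itself\n",
"            feedback.append(str(exc))\n",
"    raise RuntimeError(f\"No valid answer after {max_retries + 1} attempts\")\n",
"\n",
"\n",
"def fake_generate(feedback: List[str]) -> dict:\n",
"    # Hypothetical model: the first draft copies the context, later drafts react to feedback\n",
"    question = \"What is the meaning of life?\"\n",
"    if not feedback:\n",
"        return {\"question\": question, \"answer\": \"A life of sin and debauchery.\"}\n",
"    return {\"question\": question, \"answer\": \"The context attributes a sinful view of life to the devil.\"}\n",
"\n",
"\n",
"print(create_with_reask(fake_generate).answer)"
]
},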
{
"cell_type": "code",
"execution_count": null,
"id": "97f544e7-2552-465c-89a9-a4820f00d658",
"metadata": {},
"outputs": [],
"source": [
"class QuestionAnswer(BaseModel):\n",
" question: str\n",
" answer: str\n",
"\n",
"\n",
"question = \"What is the meaning of life?\"\n",
"context = (\n",
"    \"According to the devil, the meaning of life is a life of sin and debauchery.\"\n",
")\n",
"\n",
"\n",
"resp = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo\",\n",
" response_model=QuestionAnswer,\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a system that answers questions based on the context. answer exactly what the question asks using the context.\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": f\"using the context: `{context}`\\n\\nAnswer the following question: `{question}`\",\n",
" },\n",
" ],\n",
")\n",
"\n",
"resp.answer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0328bbc5",
"metadata": {},
"outputs": [],
"source": [
"class QuestionAnswer(BaseModel):\n",
" question: str\n",
" answer: Annotated[\n",
" str,\n",
" AfterValidator(\n",
" llm_validator(\"don't say objectionable things\", openai_client=client)\n",
" ),\n",
" ]\n",
"\n",
"\n",
"resp = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo\",\n",
" response_model=QuestionAnswer,\n",
"    # If validation fails, instructor reasks the model with the errors (up to max_retries times)\n",
"    max_retries=2,\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a system that answers questions based on the context. answer exactly what the question asks using the context.\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": f\"using the context: `{context}`\\n\\nAnswer the following question: `{question}`\",\n",
" },\n",
" ],\n",
")\n",
"\n",
"resp.answer"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}