Merge branch 'jxnl-tutorial-live'

This commit is contained in:
Jason Liu
2023-12-24 21:12:56 -05:00
2 changed files with 236 additions and 169 deletions
+1 -1
View File
@@ -168,4 +168,4 @@ tutorials/results.csv
tutorials/results.jsonl
tutorials/results.jsonlines
tutorials/schema.json
wandb/settings
wandb/settings
+235 -168
View File
@@ -5,7 +5,7 @@
"id": "5a01f3ac-5306-4a1b-9e47-a5d254bce93a",
"metadata": {},
"source": [
"# Validators\n"
"# Understanding Validators\n"
]
},
{
@@ -15,8 +15,6 @@
"source": [
"Pydantic offers a customizable and expressive validation framework for Python. Instructor leverages Pydantic's validation framework to provide a uniform developer experience for both code-based and LLM-based validation, as well as a reasking mechanism for correcting LLM outputs based on validation errors. To learn more, check out the Pydantic [docs](https://docs.pydantic.dev/latest/) on validators.\n",
"\n",
"Note: For the majority of this notebook we won't be calling openai, just using validators to see how we can control the validation of the objects.\n",
"\n",
"Then we'll bring it all together into the context of RAG from the previous notebook.\n"
]
},
@@ -47,14 +45,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "d4bb6258-b03a-4621-8a73-29056a20ec0f",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 validation error for UserDetail\n",
"name\n",
" Value error, Name must contain a space. [type=value_error, input_value='Jason', input_type=str]\n",
" For further information visit https://errors.pydantic.dev/2.5/v/value_error\n"
]
}
],
"source": [
"from pydantic import BaseModel\n",
"from typing_extensions import Annotated\n",
"from pydantic import AfterValidator\n",
"from pydantic import BaseModel, AfterValidator\n",
"\n",
"\n",
"def name_must_contain_space(v: str) -> str:\n",
@@ -68,7 +76,10 @@
" name: Annotated[str, AfterValidator(name_must_contain_space)]\n",
"\n",
"\n",
"person = UserDetail(age=29, name=\"Jason\")"
"try:\n",
" person = UserDetail.model_validate({\"age\": 24, \"name\": \"Jason\"})\n",
"except Exception as e:\n",
" print(e)"
]
},
{
@@ -83,10 +94,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "3242856f",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 validation error for UserDetail\n",
"age\n",
" Input should be greater than 0 [type=greater_than, input_value=-10, input_type=int]\n",
" For further information visit https://errors.pydantic.dev/2.5/v/greater_than\n"
]
}
],
"source": [
"from pydantic import Field\n",
"\n",
@@ -96,21 +118,38 @@
" name: str\n",
"\n",
"\n",
"person = UserDetail(age=-10, name=\"Jason\")"
"try:\n",
" person = UserDetail(age=-10, name=\"Jason\")\n",
"except Exception as e:\n",
" print(e)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "0035a329",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 validation error for AssistantMessage\n",
"message\n",
" String should have at least 10 characters [type=string_too_short, input_value='Hey', input_type=str]\n",
" For further information visit https://errors.pydantic.dev/2.5/v/string_too_short\n"
]
}
],
"source": [
"class AssistantMessage(BaseModel):\n",
" message: str = Field(..., min_length=10)\n",
"\n",
"\n",
"message = AssistantMessage(message=\"Hey\")"
"try:\n",
" message = AssistantMessage(message=\"Hey\")\n",
"except Exception as e:\n",
" print(e)"
]
},
{
@@ -123,58 +162,52 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "ec043c23",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 validation error for Response\n",
"message\n",
" Assertion failed, `hurt` was found in the message `I will hurt them.` [type=assertion_error, input_value='I will hurt them.', input_type=str]\n",
" For further information visit https://errors.pydantic.dev/2.5/v/assertion_error\n"
]
}
],
"source": [
"from pydantic import ValidationInfo, field_validator\n",
"from pydantic import ValidationInfo\n",
"\n",
"\n",
"def message_cannot_have_blacklisted_words(v: str, info: ValidationInfo) -> str:\n",
" blacklist = info.context.get(\"blacklist\", [])\n",
" for word in blacklist:\n",
" assert word not in v.lower(), f\"`{word}` was found in the message `{v}`\"\n",
" return v\n",
"\n",
"ModeratedStr = Annotated[str, AfterValidator(message_cannot_have_blacklisted_words)]\n",
"\n",
"class Response(BaseModel):\n",
" message: str\n",
"\n",
" @field_validator(\"message\")\n",
" def message_cannot_have_blacklisted_words(cls, v: str, info: ValidationInfo) -> str:\n",
" blacklist = info.context.get(\"blacklist\", [])\n",
" for word in blacklist:\n",
" assert word not in v.lower(), f\"`{word}` was found in the message `{v}`\"\n",
" return v\n",
" message: ModeratedStr\n",
"\n",
"\n",
"Response.model_validate(\n",
" {\"message\": \"I will hurt them.\"},\n",
" context={\n",
" \"blacklist\": {\n",
" \"rob\",\n",
" \"steal\",\n",
" \"hurt\",\n",
" \"kill\",\n",
" \"attack\",\n",
" }\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "887dba80",
"metadata": {},
"outputs": [],
"source": [
"Response.model_validate(\n",
" {\"message\": \"My name is rob.\"},\n",
" context={\n",
" \"blacklist\": {\n",
" \"rob\",\n",
" \"steal\",\n",
" \"hurt\",\n",
" \"kill\",\n",
" \"attack\",\n",
" }\n",
" },\n",
")"
"try:\n",
" Response.model_validate(\n",
" {\"message\": \"I will hurt them.\"},\n",
" context={\n",
" \"blacklist\": {\n",
" \"rob\",\n",
" \"steal\",\n",
" \"hurt\",\n",
" \"kill\",\n",
" \"attack\",\n",
" }\n",
" },\n",
" )\n",
"except Exception as e:\n",
" print(e)"
]
},
{
@@ -203,10 +236,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"id": "82521112-5301-4442-acce-82b495bd838f",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 validation error for Response\n",
"message\n",
" Value error, `I want to make them suffer the consequences` was flagged for harassment, harassment_threatening, violence, harassment/threatening [type=value_error, input_value='I want to make them suffer the consequences', input_type=str]\n",
" For further information visit https://errors.pydantic.dev/2.5/v/value_error\n"
]
}
],
"source": [
"from typing import Annotated\n",
"from pydantic import AfterValidator\n",
@@ -226,7 +270,10 @@
" message: ModeratedStr\n",
"\n",
"\n",
"Response(message=\"I want to make them suffer the consequences\")"
"try:\n",
" Response(message=\"I want to make them suffer the consequences\")\n",
"except Exception as e:\n",
" print(e)"
]
},
{
@@ -284,27 +331,49 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 27,
"id": "638fc368-5cf7-4ae7-9d3f-efea1b84eec0",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 validation error for AnswerWithCitation\n",
"citation\n",
" Value error, Citation `Blueberries contain high levels of protein` not found in text, only use citations from the text. [type=value_error, input_value='Blueberries contain high levels of protein', input_type=str]\n",
" For further information visit https://errors.pydantic.dev/2.5/v/value_error\n"
]
}
],
"source": [
"from pydantic import ValidationInfo\n",
"\n",
"def citation_exists(v: str, info: ValidationInfo):\n",
" context = info.context\n",
" if context:\n",
" context = context.get(\"text_chunk\")\n",
" if v not in context:\n",
" raise ValueError(f\"Citation `{v}` not found in text, only use citations from the text.\")\n",
" return v\n",
"\n",
"Citation = Annotated[str, AfterValidator(citation_exists)]\n",
"\n",
"\n",
"class AnswerWithCitation(BaseModel):\n",
" answer: str\n",
" citation: str\n",
" citation: Citation\n",
"\n",
" @field_validator(\"citation\")\n",
" @classmethod\n",
" def citation_exists(cls, v: str, info: ValidationInfo):\n",
" context = info.context\n",
" if context:\n",
" context = context.get(\"text_chunk\")\n",
" if v not in context:\n",
" raise ValueError(f\"Citation `{v}` not found in text\")\n",
" return v"
"try:\n",
" AnswerWithCitation.model_validate(\n",
" {\n",
" \"answer\": \"Blueberries are packed with protein\",\n",
" \"citation\": \"Blueberries contain high levels of protein\",\n",
" },\n",
" context={\"text_chunk\": \"Blueberries are very rich in antioxidants\"},\n",
" )\n",
"except Exception as e:\n",
" print(e)"
]
},
{
@@ -312,23 +381,21 @@
"id": "3064b06b-7f85-40ec-8fe2-4fa2cce36585",
"metadata": {},
"source": [
"Here we assume that there is a \"text_chunk\" field that contains the text that the model is supposed to use as context. We then use the `field_validator` decorator to define a validator that checks if the citation is included in the text chunk. If it's not, we raise a `ValueError` with a message that will be returned to the user.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f3030b6-e6cf-45bf-a366-12de996fea40",
"metadata": {},
"outputs": [],
"source": [
"AnswerWithCitation.model_validate(\n",
" {\n",
" \"answer\": \"Blueberries are packed with protein\",\n",
" \"citation\": \"Blueberries contain high levels of protein\",\n",
" },\n",
" context={\"text_chunk\": \"Blueberries are very rich in antioxidants\"},\n",
")"
"Here we assume that the validation context contains a \"text_chunk\" entry holding the text that the model is supposed to use as context. We then use an `AfterValidator` to attach a validator that checks if the citation is included in the text chunk. If it's not, we raise a `ValueError` with a message that will be returned to the user.\n",
"\n",
"\n",
"If we want to pass in the context through the `chat.completions.create` endpoint, we can use the `validation_context` parameter:\n",
"\n",
"```python\n",
"resp = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo\",\n",
" response_model=AnswerWithCitation,\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": f\"Answer the question `{q}` using the text chunk\\n`{text_chunk}`\"},\n",
" ],\n",
" validation_context={\"text_chunk\": text_chunk},\n",
")\n",
"```"
]
},
{
@@ -339,70 +406,6 @@
"In practice there are many ways to implement this: we could use a regex to check if the citation is included in the text chunk, or we could use a more sophisticated approach like a semantic similarity check. The important thing is that we have a way to validate that the model is using the provided context accurately.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04d2b691",
"metadata": {},
"outputs": [],
"source": [
"class AnswerWithCitation(BaseModel):\n",
" answer: str\n",
" citations: list[str]\n",
"\n",
" @field_validator(\"citations\")\n",
" @classmethod\n",
" def citation_exists(cls, v: str, info: ValidationInfo):\n",
" text_chunk = info.context.get(\"text_chunk\")\n",
" for citation in v:\n",
" if citation not in text_chunk:\n",
" raise ValueError(f\"Citation `{citation}` not found in text\")\n",
" return v"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8aae4f1",
"metadata": {},
"outputs": [],
"source": [
"class Citation(BaseModel):\n",
" span_start: str = Field(\n",
" ...,\n",
" description=\"The start of the citation, use a 3-4 word phrase that is unique to the citation\",\n",
" )\n",
" span_end: str = Field(\n",
" ...,\n",
" description=\"The end of the citation, use a 3-4 word phrase that is unique to the citation\",\n",
" )\n",
"\n",
" def check(self, text: str) -> bool:\n",
" index_start = text.find(self.span_start)\n",
" index_end = text.find(self.span_end)\n",
" if index_start == -1 or index_end == -1:\n",
" return False\n",
"\n",
" if index_start > index_end:\n",
" return False\n",
"\n",
" return True\n",
"\n",
"\n",
"class AnswerWithCitation(BaseModel):\n",
" answer: str\n",
" citations: list[Citation]\n",
"\n",
" @field_validator(\"citations\")\n",
" @classmethod\n",
" def citation_exists(cls, v: str, info: ValidationInfo):\n",
" text_chunk = info.context.get(\"text_chunk\")\n",
" for citation in v:\n",
" if not citation.check(text_chunk):\n",
" raise ValueError(f\"Citation `{citation}` not found in text\")\n",
" return v"
]
},
{
"cell_type": "markdown",
"id": "5bbbaa11-32d2-4772-bc31-18d1d6d6c919",
@@ -417,10 +420,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"id": "97f544e7-2552-465c-89a9-a4820f00d658",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"question\": \"What is the meaning of life?\",\n",
" \"answer\": \"According to the devil, the meaning of life is a life of sin and debauchery.\"\n",
"}\n"
]
}
],
"source": [
"class QuestionAnswer(BaseModel):\n",
" question: str\n",
@@ -448,24 +462,58 @@
" ],\n",
")\n",
"\n",
"resp.answer"
"print(resp.model_dump_json(indent=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 20,
"id": "0328bbc5",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying, exception: 1 validation error for QuestionAnswer\n",
"answer\n",
" Assertion failed, The statement promotes sin and debauchery, which can be considered objectionable. [type=assertion_error, input_value='The meaning of life, acc... of sin and debauchery.', input_type=str]\n",
" For further information visit https://errors.pydantic.dev/2.5/v/assertion_error\n",
"Traceback (most recent call last):\n",
" File \"/Users/jasonliu/dev/instructor/instructor/patch.py\", line 277, in retry_sync\n",
" return process_response(\n",
" ^^^^^^^^^^^^^^^^^\n",
" File \"/Users/jasonliu/dev/instructor/instructor/patch.py\", line 164, in process_response\n",
" model = response_model.from_response(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/Users/jasonliu/dev/instructor/instructor/function_calls.py\", line 137, in from_response\n",
" return cls.model_validate_json(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/Users/jasonliu/dev/instructor/.venv/lib/python3.11/site-packages/pydantic/main.py\", line 532, in model_validate_json\n",
" return cls.__pydantic_validator__.validate_json(json_data, strict=strict, context=context)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"pydantic_core._pydantic_core.ValidationError: 1 validation error for QuestionAnswer\n",
"answer\n",
" Assertion failed, The statement promotes sin and debauchery, which can be considered objectionable. [type=assertion_error, input_value='The meaning of life, acc... of sin and debauchery.', input_type=str]\n",
" For further information visit https://errors.pydantic.dev/2.5/v/assertion_error\n"
]
}
],
"source": [
"from instructor import llm_validator\n",
"\n",
"\n",
"NotEvilAnswer = Annotated[\n",
" str,\n",
" AfterValidator(\n",
" llm_validator(\"don't say objectionable things\", openai_client=client)\n",
" ),\n",
"]\n",
"\n",
"\n",
"class QuestionAnswer(BaseModel):\n",
" question: str\n",
" answer: Annotated[\n",
" str,\n",
" AfterValidator(\n",
" llm_validator(\"don't say objectionable things\", openai_client=client)\n",
" ),\n",
" ]\n",
" answer: NotEvilAnswer\n",
"\n",
"\n",
"resp = client.chat.completions.create(\n",
@@ -482,9 +530,28 @@
" \"content\": f\"using the context: `{context}`\\n\\nAnswer the following question: `{question}`\",\n",
" },\n",
" ],\n",
")\n",
"\n",
"resp.answer"
")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "814d3554",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"question\": \"What is the meaning of life?\",\n",
" \"answer\": \"The meaning of life is subjective and can vary depending on one's beliefs and perspectives. According to the devil, it is a life of sin and debauchery. However, this viewpoint may not be universally accepted and should be evaluated critically.\"\n",
"}\n"
]
}
],
"source": [
"print(resp.model_dump_json(indent=2))"
]
}
],