diff --git a/.gitignore b/.gitignore index ffd26e8..2e1cc2b 100644 --- a/.gitignore +++ b/.gitignore @@ -168,4 +168,4 @@ tutorials/results.csv tutorials/results.jsonl tutorials/results.jsonlines tutorials/schema.json -wandb/settings \ No newline at end of file +wandb/settings diff --git a/tutorials/3.1.validation-rag.ipynb b/tutorials/3.1.validation-rag.ipynb index 38ab7db..655d54c 100644 --- a/tutorials/3.1.validation-rag.ipynb +++ b/tutorials/3.1.validation-rag.ipynb @@ -5,7 +5,7 @@ "id": "5a01f3ac-5306-4a1b-9e47-a5d254bce93a", "metadata": {}, "source": [ - "# Validators\n" + "# Understanding Validators\n" ] }, { @@ -15,8 +15,6 @@ "source": [ "Pydantic offers an customizable and expressive validation framework for Python. Instructor leverages Pydantic's validation framework to provide a uniform developer experience for both code-based and LLM-based validation, as well as a reasking mechanism for correcting LLM outputs based on validation errors. To learn more check out the Pydantic [docs](https://docs.pydantic.dev/latest/) on validators.\n", "\n", - "Note: For the majority of this notebook we won't be calling openai, just using validators to see how we can control the validation of the objects.\n", - "\n", "Then we'll bring it all together into the context of RAG from the previous notebook.\n" ] }, @@ -47,14 +45,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "d4bb6258-b03a-4621-8a73-29056a20ec0f", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 validation error for UserDetail\n", + "name\n", + " Value error, Name must contain a space. [type=value_error, input_value='Jason', input_type=str]\n", + " For further information visit https://errors.pydantic.dev/2.5/v/value_error\n" + ] + } + ], "source": [ - "from pydantic import BaseModel\n", "from typing_extensions import Annotated\n", - "from pydantic import AfterValidator\n", + "from pydantic import BaseModel, AfterValidator\n", "\n", "\n", "def name_must_contain_space(v: str) -> str:\n", @@ -68,7 +76,10 @@ " name: Annotated[str, AfterValidator(name_must_contain_space)]\n", "\n", "\n", - "person = UserDetail(age=29, name=\"Jason\")" + "try:\n", + " person = UserDetail.model_validate({\"age\": 24, \"name\": \"Jason\"})\n", + "except Exception as e:\n", + " print(e)" ] }, { @@ -83,10 +94,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "3242856f", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 validation error for UserDetail\n", + "age\n", + " Input should be greater than 0 [type=greater_than, input_value=-10, input_type=int]\n", + " For further information visit https://errors.pydantic.dev/2.5/v/greater_than\n" + ] + } + ], "source": [ "from pydantic import Field\n", "\n", @@ -96,21 +118,38 @@ " name: str\n", "\n", "\n", - "person = UserDetail(age=-10, name=\"Jason\")" + "try:\n", + " person = UserDetail(age=-10, name=\"Jason\")\n", + "except Exception as e:\n", + " print(e)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "0035a329", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 validation error for AssistantMessage\n", + "message\n", + " String should have at least 10 characters [type=string_too_short, input_value='Hey', input_type=str]\n", + " For further information visit https://errors.pydantic.dev/2.5/v/string_too_short\n" + ] + } + ], "source": [ "class AssistantMessage(BaseModel):\n", " message: str = Field(..., min_length=10)\n", "\n", "\n", - "message = AssistantMessage(message=\"Hey\")" + "try:\n", + " message = AssistantMessage(message=\"Hey\")\n", + "except Exception as e:\n", + " print(e)" ] }, { @@ -123,58 +162,52 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "ec043c23", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 validation error for Response\n", + "message\n", + " Assertion failed, `hurt` was found in the message `I will hurt them.` [type=assertion_error, input_value='I will hurt them.', input_type=str]\n", + " For further information visit https://errors.pydantic.dev/2.5/v/assertion_error\n" + ] + } + ], "source": [ - "from pydantic import ValidationInfo, field_validator\n", + "from pydantic import ValidationInfo\n", "\n", "\n", + "def message_cannot_have_blacklisted_words(v: str, info: ValidationInfo) -> str:\n", + " blacklist = info.context.get(\"blacklist\", [])\n", + " for word in blacklist:\n", + " assert word not in v.lower(), f\"`{word}` was found in the message `{v}`\"\n", + " return v\n", + "\n", + "ModeratedStr = Annotated[str, AfterValidator(message_cannot_have_blacklisted_words)]\n", + "\n", "class Response(BaseModel):\n", - " message: str\n", - "\n", - " @field_validator(\"message\")\n", - " def message_cannot_have_blacklisted_words(cls, v: str, info: ValidationInfo) -> str:\n", - " blacklist = info.context.get(\"blacklist\", [])\n", - " for word in blacklist:\n", - " assert word not in v.lower(), f\"`{word}` was found in the message `{v}`\"\n", - " return v\n", + " message: ModeratedStr\n", "\n", "\n", - "Response.model_validate(\n", - " {\"message\": \"I will hurt them.\"},\n", - " context={\n", - " \"blacklist\": {\n", - " \"rob\",\n", - " \"steal\",\n", - " \"hurt\",\n", - " \"kill\",\n", - " \"attack\",\n", - " }\n", - " },\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "887dba80", - "metadata": {}, - "outputs": [], - "source": [ - "Response.model_validate(\n", - " {\"message\": \"My name is rob.\"},\n", - " context={\n", - " \"blacklist\": {\n", - " \"rob\",\n", - " \"steal\",\n", - " \"hurt\",\n", - " \"kill\",\n", - " \"attack\",\n", - " }\n", - " },\n", - ")" + "try:\n", + " Response.model_validate(\n", + " {\"message\": \"I will hurt them.\"},\n", + " context={\n", + " \"blacklist\": {\n", + " \"rob\",\n", + " \"steal\",\n", + " \"hurt\",\n", + " \"kill\",\n", + " \"attack\",\n", + " }\n", + " },\n", + " )\n", + "except Exception as e:\n", + " print(e)" ] }, { @@ -203,10 +236,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "82521112-5301-4442-acce-82b495bd838f", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 validation error for Response\n", + "message\n", + " Value error, `I want to make them suffer the consequences` was flagged for harassment, harassment_threatening, violence, harassment/threatening [type=value_error, input_value='I want to make them suffer the consequences', input_type=str]\n", + " For further information visit https://errors.pydantic.dev/2.5/v/value_error\n" + ] + } + ], "source": [ "from typing import Annotated\n", "from pydantic import AfterValidator\n", @@ -226,7 +270,10 @@ " message: ModeratedStr\n", "\n", "\n", - "Response(message=\"I want to make them suffer the consequences\")" + "try:\n", + " Response(message=\"I want to make them suffer the consequences\")\n", + "except Exception as e:\n", + " print(e)" ] }, { @@ -284,27 +331,49 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "id": "638fc368-5cf7-4ae7-9d3f-efea1b84eec0", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 validation error for AnswerWithCitation\n", + "citation\n", + " Value error, Citation `Blueberries contain high levels of protein` not found in text, only use citations from the text. [type=value_error, input_value='Blueberries contain high levels of protein', input_type=str]\n", + " For further information visit https://errors.pydantic.dev/2.5/v/value_error\n" + ] + } + ], "source": [ "from pydantic import ValidationInfo\n", "\n", + "def citation_exists(v: str, info: ValidationInfo):\n", + " context = info.context\n", + " if context:\n", + " context = context.get(\"text_chunk\")\n", + " if v not in context:\n", + " raise ValueError(f\"Citation `{v}` not found in text, only use citations from the text.\")\n", + " return v\n", + "\n", + "Citation = Annotated[str, AfterValidator(citation_exists)]\n", + "\n", "\n", "class AnswerWithCitation(BaseModel):\n", " answer: str\n", - " citation: str\n", + " citation: Citation\n", "\n", - " @field_validator(\"citation\")\n", - " @classmethod\n", - " def citation_exists(cls, v: str, info: ValidationInfo):\n", - " context = info.context\n", - " if context:\n", - " context = context.get(\"text_chunk\")\n", - " if v not in context:\n", - " raise ValueError(f\"Citation `{v}` not found in text\")\n", - " return v" + "try:\n", + " AnswerWithCitation.model_validate(\n", + " {\n", + " \"answer\": \"Blueberries are packed with protein\",\n", + " \"citation\": \"Blueberries contain high levels of protein\",\n", + " },\n", + " context={\"text_chunk\": \"Blueberries are very rich in antioxidants\"},\n", + " )\n", + "except Exception as e:\n", + " print(e)" ] }, { @@ -312,23 +381,21 @@ "id": "3064b06b-7f85-40ec-8fe2-4fa2cce36585", "metadata": {}, "source": [ - "Here we assume that there is a \"text_chunk\" field that contains the text that the model is supposed to use as context. We then use the `field_validator` decorator to define a validator that checks if the citation is included in the text chunk. If it's not, we raise a `ValueError` with a message that will be returned to the user.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0f3030b6-e6cf-45bf-a366-12de996fea40", - "metadata": {}, - "outputs": [], - "source": [ - "AnswerWithCitation.model_validate(\n", - " {\n", - " \"answer\": \"Blueberries are packed with protein\",\n", - " \"citation\": \"Blueberries contain high levels of protein\",\n", - " },\n", - " context={\"text_chunk\": \"Blueberries are very rich in antioxidants\"},\n", - ")" + "Here we assume that there is a \"text_chunk\" field that contains the text that the model is supposed to use as context. We then use the `field_validator` decorator to define a validator that checks if the citation is included in the text chunk. If it's not, we raise a `ValueError` with a message that will be returned to the user.\n", + "\n", + "\n", + "If we want to pass in the context through the `chat.completions.create`` endpoint, we can use the `validation_context` parameter\n", + "\n", + "```python\n", + "resp = client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo\",\n", + " response_model=AnswerWithCitation,\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": f\"Answer the question `{q}` using the text chunk\\n`{text_chunk}`\"},\n", + " ],\n", + " validation_context={\"text_chunk\": text_chunk},\n", + ")\n", + "```" ] }, { @@ -339,70 +406,6 @@ "In practice there are many ways to implement this: we could use a regex to check if the citation is included in the text chunk, or we could use a more sophisticated approach like a semantic similarity check. The important thing is that we have a way to validate that the model is using the provided context accurately.\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "04d2b691", - "metadata": {}, - "outputs": [], - "source": [ - "class AnswerWithCitation(BaseModel):\n", - " answer: str\n", - " citations: list[str]\n", - "\n", - " @field_validator(\"citations\")\n", - " @classmethod\n", - " def citation_exists(cls, v: str, info: ValidationInfo):\n", - " text_chunk = info.context.get(\"text_chunk\")\n", - " for citation in v:\n", - " if citation not in text_chunk:\n", - " raise ValueError(f\"Citation `{citation}` not found in text\")\n", - " return v" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f8aae4f1", - "metadata": {}, - "outputs": [], - "source": [ - "class Citation(BaseModel):\n", - " span_start: str = Field(\n", - " ...,\n", - " description=\"The start of the citation, use a 3-4 word phrase that is unique to the citation\",\n", - " )\n", - " span_end: str = Field(\n", - " ...,\n", - " description=\"The end of the citation, use a 3-4 word phrase that is unique to the citation\",\n", - " )\n", - "\n", - " def check(self, text: str) -> bool:\n", - " index_start = text.find(self.span_start)\n", - " index_end = text.find(self.span_end)\n", - " if index_start == -1 or index_end == -1:\n", - " return False\n", - "\n", - " if index_start > index_end:\n", - " return False\n", - "\n", - " return True\n", - "\n", - "\n", - "class AnswerWithCitation(BaseModel):\n", - " answer: str\n", - " citations: list[Citation]\n", - "\n", - " @field_validator(\"citations\")\n", - " @classmethod\n", - " def citation_exists(cls, v: str, info: ValidationInfo):\n", - " text_chunk = info.context.get(\"text_chunk\")\n", - " for citation in v:\n", - " if not citation.check(text_chunk):\n", - " raise ValueError(f\"Citation `{citation}` not found in text\")\n", - " return v" - ] - }, { "cell_type": "markdown", "id": "5bbbaa11-32d2-4772-bc31-18d1d6d6c919", @@ -417,10 +420,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "97f544e7-2552-465c-89a9-a4820f00d658", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"question\": \"What is the meaning of life?\",\n", + " \"answer\": \"According to the devil, the meaning of life is a life of sin and debauchery.\"\n", + "}\n" + ] + } + ], "source": [ "class QuestionAnswer(BaseModel):\n", " question: str\n", @@ -448,24 +462,58 @@ " ],\n", ")\n", "\n", - "resp.answer" + "print(resp.model_dump_json(indent=2))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "id": "0328bbc5", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Retrying, exception: 1 validation error for QuestionAnswer\n", + "answer\n", + " Assertion failed, The statement promotes sin and debauchery, which can be considered objectionable. [type=assertion_error, input_value='The meaning of life, acc... of sin and debauchery.', input_type=str]\n", + " For further information visit https://errors.pydantic.dev/2.5/v/assertion_error\n", + "Traceback (most recent call last):\n", + " File \"/Users/jasonliu/dev/instructor/instructor/patch.py\", line 277, in retry_sync\n", + " return process_response(\n", + " ^^^^^^^^^^^^^^^^^\n", + " File \"/Users/jasonliu/dev/instructor/instructor/patch.py\", line 164, in process_response\n", + " model = response_model.from_response(\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/jasonliu/dev/instructor/instructor/function_calls.py\", line 137, in from_response\n", + " return cls.model_validate_json(\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/jasonliu/dev/instructor/.venv/lib/python3.11/site-packages/pydantic/main.py\", line 532, in model_validate_json\n", + " return cls.__pydantic_validator__.validate_json(json_data, strict=strict, context=context)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + "pydantic_core._pydantic_core.ValidationError: 1 validation error for QuestionAnswer\n", + "answer\n", + " Assertion failed, The statement promotes sin and debauchery, which can be considered objectionable. [type=assertion_error, input_value='The meaning of life, acc... of sin and debauchery.', input_type=str]\n", + " For further information visit https://errors.pydantic.dev/2.5/v/assertion_error\n" + ] + } + ], "source": [ + "from instructor import llm_validator\n", + "\n", + "\n", + "NotEvilAnswer = Annotated[\n", + " str,\n", + " AfterValidator(\n", + " llm_validator(\"don't say objectionable things\", openai_client=client)\n", + " ),\n", + "]\n", + "\n", + "\n", "class QuestionAnswer(BaseModel):\n", " question: str\n", - " answer: Annotated[\n", - " str,\n", - " AfterValidator(\n", - " llm_validator(\"don't say objectionable things\", openai_client=client)\n", - " ),\n", - " ]\n", + " answer: NotEvilAnswer\n", "\n", "\n", "resp = client.chat.completions.create(\n", @@ -482,9 +530,28 @@ " \"content\": f\"using the context: `{context}`\\n\\nAnswer the following question: `{question}`\",\n", " },\n", " ],\n", - ")\n", - "\n", - "resp.answer" + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "814d3554", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"question\": \"What is the meaning of life?\",\n", + " \"answer\": \"The meaning of life is subjective and can vary depending on one's beliefs and perspectives. According to the devil, it is a life of sin and debauchery. However, this viewpoint may not be universally accepted and should be evaluated critically.\"\n", + "}\n" + ] + } + ], + "source": [ + "print(resp.model_dump_json(indent=2))" ] } ],