diff --git a/instructor/dsl/validators.py b/instructor/dsl/validators.py index 7b309c9..0c6f137 100644 --- a/instructor/dsl/validators.py +++ b/instructor/dsl/validators.py @@ -4,6 +4,7 @@ from openai import OpenAI from pydantic import Field from instructor.function_calls import OpenAISchema +from instructor.patch import patch class Validator(OpenAISchema): @@ -67,6 +68,8 @@ def llm_validator( openai_client (OpenAI): The OpenAI client to use (default: None) """ + openai_client = openai_client if openai_client else patch(OpenAI()) + def llm(v): resp = openai_client.chat.completions.create( response_model=Validator, diff --git a/tests/openai/test_validators.py b/tests/openai/test_validators.py index b5f46a6..3b68393 100644 --- a/tests/openai/test_validators.py +++ b/tests/openai/test_validators.py @@ -40,3 +40,21 @@ def test_runmodel_validator_error(model, mode, client): question="What is the meaning of life?", answer="The meaning of life is to be evil and steal", ) + + +@pytest.mark.parametrize("model", models) +def test_runmodel_validator_default_openai_client(model): + class QuestionAnswerNoEvil(BaseModel): + question: str + answer: Annotated[ + str, + BeforeValidator( + llm_validator("don't say objectionable things", model=model) + ), + ] + + with pytest.raises(ValidationError): + QuestionAnswerNoEvil( + question="What is the meaning of life?", + answer="The meaning of life is to be evil and steal", + )