From d3a567f2ff4ee53e32ff18bb4382c5c34b22c6ba Mon Sep 17 00:00:00 2001 From: Jason Liu Date: Sat, 25 Nov 2023 19:56:40 -0500 Subject: [PATCH] clean up --- examples/caching/example_diskcache.py | 12 ++++-- examples/caching/example_redis.py | 9 +++-- examples/caching/lru.py | 9 +++-- examples/validators/moderation.py | 5 ++- instructor/__init__.py | 3 +- instructor/dsl/validators.py | 9 +++-- instructor/patch.py | 16 +++----- tests/openai/evals/test_extract_users.py | 51 ++++++++++++++++++++++++ tests/openai/test_modes.py | 2 +- tests/openai/test_patch.py | 8 +--- 10 files changed, 91 insertions(+), 33 deletions(-) create mode 100644 tests/openai/evals/test_extract_users.py diff --git a/examples/caching/example_diskcache.py b/examples/caching/example_diskcache.py index 8d94bb8..1217341 100644 --- a/examples/caching/example_diskcache.py +++ b/examples/caching/example_diskcache.py @@ -8,11 +8,14 @@ from pydantic import BaseModel client = instructor.patch(OpenAI()) + class UserDetail(BaseModel): name: str age: int -cache = diskcache.Cache('./my_cache_directory') + +cache = diskcache.Cache("./my_cache_directory") + def instructor_cache(func): """Cache a function that returns a Pydantic model""" @@ -38,6 +41,7 @@ def instructor_cache(func): return wrapper + @instructor_cache def extract(data) -> UserDetail: return client.chat.completions.create( @@ -45,12 +49,12 @@ def extract(data) -> UserDetail: response_model=UserDetail, messages=[ {"role": "user", "content": data}, - ] + ], ) def test_extract(): - import time + import time start = time.perf_counter() model = extract("Extract jason is 25 years old") @@ -68,4 +72,4 @@ def test_extract(): if __name__ == "__main__": test_extract() # Time taken: 0.7285366660216823 - # Time taken: 9.841693099588156e-05 \ No newline at end of file + # Time taken: 9.841693099588156e-05 diff --git a/examples/caching/example_redis.py b/examples/caching/example_redis.py index 2abee69..cdc1242 100644 --- a/examples/caching/example_redis.py +++ b/examples/caching/example_redis.py @@ -9,6 +9,7 @@ from openai import OpenAI client = instructor.patch(OpenAI()) cache = redis.Redis("localhost") + def instructor_cache(func): """Cache a function that returns a Pydantic model""" return_type = inspect.signature(func).return_annotation @@ -38,6 +39,7 @@ class UserDetail(BaseModel): name: str age: int + @instructor_cache def extract(data) -> UserDetail: # Assuming client.chat.completions.create returns a UserDetail instance @@ -46,11 +48,12 @@ def extract(data) -> UserDetail: response_model=UserDetail, messages=[ {"role": "user", "content": data}, - ] + ], ) + def test_extract(): - import time + import time start = time.perf_counter() model = extract("Extract jason is 25 years old") @@ -68,4 +71,4 @@ def test_extract(): if __name__ == "__main__": test_extract() # Time taken: 0.798335583996959 - # Time taken: 0.00017016706988215446 \ No newline at end of file + # Time taken: 0.00017016706988215446 diff --git a/examples/caching/lru.py b/examples/caching/lru.py index 3c26b10..c4e7352 100644 --- a/examples/caching/lru.py +++ b/examples/caching/lru.py @@ -5,10 +5,12 @@ import functools client = instructor.patch(OpenAI()) + class UserDetail(BaseModel): name: str age: int + @functools.lru_cache def extract(data): return client.chat.completions.create( @@ -16,12 +18,12 @@ def extract(data): response_model=UserDetail, messages=[ {"role": "user", "content": data}, - ] + ], ) def test_extract(): - import time + import time start = time.perf_counter() model = extract("Extract jason is 25 years old") @@ -35,7 +37,8 @@ def test_extract(): assert model.age == 25 print(f"Time taken: {time.perf_counter() - start}") + if __name__ == "__main__": test_extract() # Time taken: 0.9267581660533324 - # Time taken: 1.2080417945981026e-06 \ No newline at end of file + # Time taken: 1.2080417945981026e-06 diff --git a/examples/validators/moderation.py b/examples/validators/moderation.py index 9ced914..6cf228d 100644 --- a/examples/validators/moderation.py +++ b/examples/validators/moderation.py @@ -1,4 +1,4 @@ -import instructor +import instructor from instructor import openai_moderation @@ -8,8 +8,9 @@ from openai import OpenAI client = instructor.patch(OpenAI()) + class Response(BaseModel): message: Annotated[str, AfterValidator(openai_moderation(client=client))] -response = Response(message="I want to make them suffer the consequences") \ No newline at end of file +response = Response(message="I want to make them suffer the consequences") diff --git a/instructor/__init__.py b/instructor/__init__.py index 2aa2cb3..5d428ba 100644 --- a/instructor/__init__.py +++ b/instructor/__init__.py @@ -1,6 +1,6 @@ from .distil import FinetuneFormat, Instructions from .dsl import CitationMixin, Maybe, MultiTask, llm_validator, openai_moderation -from .function_calls import OpenAISchema, openai_function, openai_schema +from .function_calls import OpenAISchema, openai_function, openai_schema, Mode from .patch import apatch, patch __all__ = [ @@ -10,6 +10,7 @@ __all__ = [ "MultiTask", "Maybe", "openai_schema", + "Mode", "patch", "apatch", "llm_validator", diff --git a/instructor/dsl/validators.py b/instructor/dsl/validators.py index e67fcec..c5f7739 100644 --- a/instructor/dsl/validators.py +++ b/instructor/dsl/validators.py @@ -99,6 +99,7 @@ def llm_validator( return llm + def openai_moderation(client: OpenAI = None): """ Validates a message using OpenAI moderation model. @@ -133,8 +134,10 @@ def openai_moderation(client: OpenAI = None): out = response.results[0] cats = out.categories.model_dump() if out.flagged: - raise ValueError(f"`{v}` was flagged for {', '.join(cat for cat in cats if cats[cat])}") - + raise ValueError( + f"`{v}` was flagged for {', '.join(cat for cat in cats if cats[cat])}" + ) + return v - + return validate_message_with_openai_mod diff --git a/instructor/patch.py b/instructor/patch.py index 1830a8a..e6a96f0 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -55,9 +55,7 @@ def handle_response_model( if mode == Mode.FUNCTIONS: new_kwargs["functions"] = [response_model.openai_schema] # type: ignore - new_kwargs["function_call"] = { - "name": response_model.openai_schema["name"] - } # type: ignore + new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore elif mode == Mode.TOOLS: new_kwargs["tools"] = [ { @@ -92,7 +90,9 @@ def handle_response_model( raise ValueError(f"Invalid patch mode: {mode}") if new_kwargs.get("stream", False) and response_model is not None: - raise NotImplementedError("stream=True is not supported when using response_model parameter") + raise NotImplementedError( + "stream=True is not supported when using response_model parameter" + ) warnings.warn( "stream=True is not supported when using response_model parameter" @@ -204,9 +204,7 @@ def is_async(func: Callable) -> bool: ) -def wrap_chatcompletion( - func: Callable, mode: Mode = Mode.FUNCTIONS -) -> Callable: +def wrap_chatcompletion(func: Callable, mode: Mode = Mode.FUNCTIONS) -> Callable: func_is_async = is_async(func) @wraps(func) @@ -268,9 +266,7 @@ def wrap_chatcompletion( return wrapper_function -def patch( - client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS -): +def patch(client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS): """ Patch the `client.chat.completions.create` method diff --git a/tests/openai/evals/test_extract_users.py b/tests/openai/evals/test_extract_users.py new file mode 100644 index 0000000..9e0356d --- /dev/null +++ b/tests/openai/evals/test_extract_users.py @@ -0,0 +1,51 @@ +import pytest +from itertools import product +from pydantic import BaseModel +from openai import OpenAI +import instructor +from instructor.function_calls import Mode + + +class UserDetails(BaseModel): + name: str + age: int + + +# Lists for models, test data, and modes +models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"] +test_data = [ + ("Jason is 10", "Jason", 10), + ("Alice is 25", "Alice", 25), + ("Bob is 35", "Bob", 35), +] +modes = [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS] + + +@pytest.mark.parametrize("model, data, mode", product(models, test_data, modes)) +def test_extract(model, data, mode): + sample_data, expected_name, expected_age = data + + if mode == Mode.JSON and model in {"gpt-3.5-turbo", "gpt-4"}: + pytest.skip( + "JSON mode is not supported for gpt-3.5-turbo and gpt-4, skipping test" + ) + + # Setting up the client with the instructor patch + client = instructor.patch(OpenAI(), mode=mode) + + # Calling the extract function with the provided model, sample data, and mode + response = client.chat.completions.create( + model=model, + response_model=UserDetails, + messages=[ + {"role": "user", "content": sample_data}, + ], + ) + + # Assertions + assert ( + response.name == expected_name + ), f"Expected name {expected_name}, got {response.name}" + assert ( + response.age == expected_age + ), f"Expected age {expected_age}, got {response.age}" diff --git a/tests/openai/test_modes.py b/tests/openai/test_modes.py index a275f4d..3b6d358 100644 --- a/tests/openai/test_modes.py +++ b/tests/openai/test_modes.py @@ -76,4 +76,4 @@ def test_mode(mode): ], ) assert user.name.lower() == "jason" - assert user.age == 25 \ No newline at end of file + assert user.age == 25 diff --git a/tests/openai/test_patch.py b/tests/openai/test_patch.py index f6a8e31..5619c30 100644 --- a/tests/openai/test_patch.py +++ b/tests/openai/test_patch.py @@ -15,9 +15,7 @@ class UserExtract(BaseModel): age: int -@pytest.mark.parametrize( - "mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS] -) +@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]) def test_runmodel(mode): client = instructor.patch(OpenAI(), mode=mode) model = client.chat.completions.create( @@ -36,9 +34,7 @@ def test_runmodel(mode): ), "The raw response should be available from OpenAI" -@pytest.mark.parametrize( - "mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS] -) +@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]) @pytest.mark.asyncio async def test_runmodel_async(mode): aclient = instructor.patch(AsyncOpenAI(), mode=mode)