From 6ecced4624ae95bdd4de4e74103c10be48111a46 Mon Sep 17 00:00:00 2001 From: Jason Liu Date: Wed, 20 Mar 2024 10:13:31 -0400 Subject: [PATCH] doc: make diskcache async compatible --- examples/caching/example_diskcache.py | 56 +++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/examples/caching/example_diskcache.py b/examples/caching/example_diskcache.py index 1217341..dd67ebf 100644 --- a/examples/caching/example_diskcache.py +++ b/examples/caching/example_diskcache.py @@ -3,10 +3,11 @@ import inspect import instructor import diskcache -from openai import OpenAI +from openai import OpenAI, AsyncOpenAI from pydantic import BaseModel client = instructor.patch(OpenAI()) +aclient = instructor.patch(AsyncOpenAI()) class UserDetail(BaseModel): @@ -23,6 +24,8 @@ def instructor_cache(func): if not issubclass(return_type, BaseModel): raise ValueError("The return type must be a Pydantic model") + is_async = inspect.iscoroutinefunction(func) + @functools.wraps(func) def wrapper(*args, **kwargs): key = f"{func.__name__}-{functools._make_key(args, kwargs, typed=False)}" @@ -39,7 +42,23 @@ def instructor_cache(func): return result - return wrapper + @functools.wraps(func) + async def awrapper(*args, **kwargs): + key = f"{func.__name__}-{functools._make_key(args, kwargs, typed=False)}" + # Check if the result is already cached + if (cached := cache.get(key)) is not None: + # Deserialize from JSON based on the return type + if issubclass(return_type, BaseModel): + return return_type.model_validate_json(cached) + + # Call the function and cache its result + result = await func(*args, **kwargs) + serialized_result = result.model_dump_json() + cache.set(key, serialized_result) + + return result + + return wrapper if not is_async else awrapper @instructor_cache @@ -50,7 +69,18 @@ def extract(data) -> UserDetail: messages=[ {"role": "user", "content": data}, ], - ) + ) # type: ignore + + +@instructor_cache +async def aextract(data) -> UserDetail: + return await aclient.chat.completions.create( + model="gpt-3.5-turbo", + response_model=UserDetail, + messages=[ + {"role": "user", "content": data}, + ], + ) # type: ignore def test_extract(): @@ -69,7 +99,27 @@ def test_extract(): print(f"Time taken: {time.perf_counter() - start}") +async def atest_extract(): + import time + + start = time.perf_counter() + model = await aextract("Extract jason is 25 years old") + assert model.name.lower() == "jason" + assert model.age == 25 + print(f"Time taken: {time.perf_counter() - start}") + + start = time.perf_counter() + model = await aextract("Extract jason is 25 years old") + assert model.name.lower() == "jason" + assert model.age == 25 + print(f"Time taken: {time.perf_counter() - start}") + + if __name__ == "__main__": test_extract() # Time taken: 0.7285366660216823 # Time taken: 9.841693099588156e-05 + + import asyncio + + asyncio.run(atest_extract())