doc: make diskcache async compatible

This commit is contained in:
Jason Liu
2024-03-20 10:13:31 -04:00
parent 120038f8e7
commit 6ecced4624
+53 -3
View File
@@ -3,10 +3,11 @@ import inspect
import instructor
import diskcache
from openai import OpenAI
from openai import OpenAI, AsyncOpenAI
from pydantic import BaseModel
client = instructor.patch(OpenAI())
aclient = instructor.patch(AsyncOpenAI())
class UserDetail(BaseModel):
@@ -23,6 +24,8 @@ def instructor_cache(func):
if not issubclass(return_type, BaseModel):
raise ValueError("The return type must be a Pydantic model")
is_async = inspect.iscoroutinefunction(func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
key = f"{func.__name__}-{functools._make_key(args, kwargs, typed=False)}"
@@ -39,7 +42,23 @@ def instructor_cache(func):
return result
return wrapper
@functools.wraps(func)
async def awrapper(*args, **kwargs):
    """Async variant of the caching wrapper: await ``func`` only on a miss.

    The cache key combines the wrapped function's name with a key derived
    from the call arguments. Results are stored as JSON produced by
    ``model_dump_json`` and rebuilt with ``model_validate_json``.
    """
    cache_key = f"{func.__name__}-{functools._make_key(args, kwargs, typed=False)}"

    # Cache hit: rebuild the model from the stored JSON instead of calling
    # the wrapped coroutine again.
    hit = cache.get(cache_key)
    if hit is not None and issubclass(return_type, BaseModel):
        return return_type.model_validate_json(hit)

    # Cache miss: await the real call, then persist its JSON form.
    fresh = await func(*args, **kwargs)
    cache.set(cache_key, fresh.model_dump_json())
    return fresh
return wrapper if not is_async else awrapper
@instructor_cache
@@ -50,7 +69,18 @@ def extract(data) -> UserDetail:
messages=[
{"role": "user", "content": data},
],
)
) # type: ignore
@instructor_cache
async def aextract(data) -> UserDetail:
    """Ask the async client for a ``UserDetail`` extracted from ``data``.

    ``data`` is sent as the content of a single user message.
    """
    user_messages = [
        {"role": "user", "content": data},
    ]
    response = await aclient.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserDetail,
        messages=user_messages,
    )  # type: ignore
    return response
def test_extract():
@@ -69,7 +99,27 @@ def test_extract():
print(f"Time taken: {time.perf_counter() - start}")
async def atest_extract():
    """Call ``aextract`` twice with the same prompt, checking and timing each call."""
    import time

    prompt = "Extract jason is 25 years old"

    # Both iterations pass identical arguments, so the second call is
    # expected to be served from the cache (assuming the first call stored it).
    for _ in range(2):
        began = time.perf_counter()
        model = await aextract(prompt)
        assert model.name.lower() == "jason"
        assert model.age == 25
        print(f"Time taken: {time.perf_counter() - began}")
if __name__ == "__main__":
    import asyncio

    # Synchronous demo first. Sample output from one run:
    test_extract()
    # Time taken: 0.7285366660216823
    # Time taken: 9.841693099588156e-05

    # Then the async demo, driven by a fresh event loop.
    asyncio.run(atest_extract())