mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
doc: make diskcache async compatible
This commit is contained in:
@@ -3,10 +3,11 @@ import inspect
|
||||
import instructor
|
||||
import diskcache
|
||||
|
||||
from openai import OpenAI
|
||||
from openai import OpenAI, AsyncOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
client = instructor.patch(OpenAI())
|
||||
aclient = instructor.patch(AsyncOpenAI())
|
||||
|
||||
|
||||
class UserDetail(BaseModel):
|
||||
@@ -23,6 +24,8 @@ def instructor_cache(func):
|
||||
if not issubclass(return_type, BaseModel):
|
||||
raise ValueError("The return type must be a Pydantic model")
|
||||
|
||||
is_async = inspect.iscoroutinefunction(func)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
key = f"{func.__name__}-{functools._make_key(args, kwargs, typed=False)}"
|
||||
@@ -39,7 +42,23 @@ def instructor_cache(func):
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
@functools.wraps(func)
|
||||
async def awrapper(*args, **kwargs):
|
||||
key = f"{func.__name__}-{functools._make_key(args, kwargs, typed=False)}"
|
||||
# Check if the result is already cached
|
||||
if (cached := cache.get(key)) is not None:
|
||||
# Deserialize from JSON based on the return type
|
||||
if issubclass(return_type, BaseModel):
|
||||
return return_type.model_validate_json(cached)
|
||||
|
||||
# Call the function and cache its result
|
||||
result = await func(*args, **kwargs)
|
||||
serialized_result = result.model_dump_json()
|
||||
cache.set(key, serialized_result)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper if not is_async else awrapper
|
||||
|
||||
|
||||
@instructor_cache
|
||||
@@ -50,7 +69,18 @@ def extract(data) -> UserDetail:
|
||||
messages=[
|
||||
{"role": "user", "content": data},
|
||||
],
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
@instructor_cache
|
||||
async def aextract(data) -> UserDetail:
|
||||
return await aclient.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
response_model=UserDetail,
|
||||
messages=[
|
||||
{"role": "user", "content": data},
|
||||
],
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_extract():
|
||||
@@ -69,7 +99,27 @@ def test_extract():
|
||||
print(f"Time taken: {time.perf_counter() - start}")
|
||||
|
||||
|
||||
async def atest_extract():
|
||||
import time
|
||||
|
||||
start = time.perf_counter()
|
||||
model = await aextract("Extract jason is 25 years old")
|
||||
assert model.name.lower() == "jason"
|
||||
assert model.age == 25
|
||||
print(f"Time taken: {time.perf_counter() - start}")
|
||||
|
||||
start = time.perf_counter()
|
||||
model = await aextract("Extract jason is 25 years old")
|
||||
assert model.name.lower() == "jason"
|
||||
assert model.age == 25
|
||||
print(f"Time taken: {time.perf_counter() - start}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_extract()
|
||||
# Time taken: 0.7285366660216823
|
||||
# Time taken: 9.841693099588156e-05
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(atest_extract())
|
||||
|
||||
Reference in New Issue
Block a user