mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
clean up
This commit is contained in:
@@ -8,11 +8,14 @@ from pydantic import BaseModel
|
||||
|
||||
client = instructor.patch(OpenAI())
|
||||
|
||||
|
||||
class UserDetail(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
cache = diskcache.Cache('./my_cache_directory')
|
||||
|
||||
cache = diskcache.Cache("./my_cache_directory")
|
||||
|
||||
|
||||
def instructor_cache(func):
|
||||
"""Cache a function that returns a Pydantic model"""
|
||||
@@ -38,6 +41,7 @@ def instructor_cache(func):
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@instructor_cache
|
||||
def extract(data) -> UserDetail:
|
||||
return client.chat.completions.create(
|
||||
@@ -45,12 +49,12 @@ def extract(data) -> UserDetail:
|
||||
response_model=UserDetail,
|
||||
messages=[
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_extract():
|
||||
import time
|
||||
import time
|
||||
|
||||
start = time.perf_counter()
|
||||
model = extract("Extract jason is 25 years old")
|
||||
@@ -68,4 +72,4 @@ def test_extract():
|
||||
if __name__ == "__main__":
|
||||
test_extract()
|
||||
# Time taken: 0.7285366660216823
|
||||
# Time taken: 9.841693099588156e-05
|
||||
# Time taken: 9.841693099588156e-05
|
||||
|
||||
@@ -9,6 +9,7 @@ from openai import OpenAI
|
||||
client = instructor.patch(OpenAI())
|
||||
cache = redis.Redis("localhost")
|
||||
|
||||
|
||||
def instructor_cache(func):
|
||||
"""Cache a function that returns a Pydantic model"""
|
||||
return_type = inspect.signature(func).return_annotation
|
||||
@@ -38,6 +39,7 @@ class UserDetail(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
|
||||
@instructor_cache
|
||||
def extract(data) -> UserDetail:
|
||||
# Assuming client.chat.completions.create returns a UserDetail instance
|
||||
@@ -46,11 +48,12 @@ def extract(data) -> UserDetail:
|
||||
response_model=UserDetail,
|
||||
messages=[
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_extract():
|
||||
import time
|
||||
import time
|
||||
|
||||
start = time.perf_counter()
|
||||
model = extract("Extract jason is 25 years old")
|
||||
@@ -68,4 +71,4 @@ def test_extract():
|
||||
if __name__ == "__main__":
|
||||
test_extract()
|
||||
# Time taken: 0.798335583996959
|
||||
# Time taken: 0.00017016706988215446
|
||||
# Time taken: 0.00017016706988215446
|
||||
|
||||
@@ -5,10 +5,12 @@ import functools
|
||||
|
||||
client = instructor.patch(OpenAI())
|
||||
|
||||
|
||||
class UserDetail(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def extract(data):
|
||||
return client.chat.completions.create(
|
||||
@@ -16,12 +18,12 @@ def extract(data):
|
||||
response_model=UserDetail,
|
||||
messages=[
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_extract():
|
||||
import time
|
||||
import time
|
||||
|
||||
start = time.perf_counter()
|
||||
model = extract("Extract jason is 25 years old")
|
||||
@@ -35,7 +37,8 @@ def test_extract():
|
||||
assert model.age == 25
|
||||
print(f"Time taken: {time.perf_counter() - start}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_extract()
|
||||
# Time taken: 0.9267581660533324
|
||||
# Time taken: 1.2080417945981026e-06
|
||||
# Time taken: 1.2080417945981026e-06
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import instructor
|
||||
import instructor
|
||||
|
||||
from instructor import openai_moderation
|
||||
|
||||
@@ -8,8 +8,9 @@ from openai import OpenAI
|
||||
|
||||
client = instructor.patch(OpenAI())
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
message: Annotated[str, AfterValidator(openai_moderation(client=client))]
|
||||
|
||||
|
||||
response = Response(message="I want to make them suffer the consequences")
|
||||
response = Response(message="I want to make them suffer the consequences")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from .distil import FinetuneFormat, Instructions
|
||||
from .dsl import CitationMixin, Maybe, MultiTask, llm_validator, openai_moderation
|
||||
from .function_calls import OpenAISchema, openai_function, openai_schema
|
||||
from .function_calls import OpenAISchema, openai_function, openai_schema, Mode
|
||||
from .patch import apatch, patch
|
||||
|
||||
__all__ = [
|
||||
@@ -10,6 +10,7 @@ __all__ = [
|
||||
"MultiTask",
|
||||
"Maybe",
|
||||
"openai_schema",
|
||||
"Mode",
|
||||
"patch",
|
||||
"apatch",
|
||||
"llm_validator",
|
||||
|
||||
@@ -99,6 +99,7 @@ def llm_validator(
|
||||
|
||||
return llm
|
||||
|
||||
|
||||
def openai_moderation(client: OpenAI = None):
|
||||
"""
|
||||
Validates a message using OpenAI moderation model.
|
||||
@@ -133,8 +134,10 @@ def openai_moderation(client: OpenAI = None):
|
||||
out = response.results[0]
|
||||
cats = out.categories.model_dump()
|
||||
if out.flagged:
|
||||
raise ValueError(f"`{v}` was flagged for {', '.join(cat for cat in cats if cats[cat])}")
|
||||
|
||||
raise ValueError(
|
||||
f"`{v}` was flagged for {', '.join(cat for cat in cats if cats[cat])}"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
return validate_message_with_openai_mod
|
||||
|
||||
+6
-10
@@ -55,9 +55,7 @@ def handle_response_model(
|
||||
|
||||
if mode == Mode.FUNCTIONS:
|
||||
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
|
||||
new_kwargs["function_call"] = {
|
||||
"name": response_model.openai_schema["name"]
|
||||
} # type: ignore
|
||||
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
|
||||
elif mode == Mode.TOOLS:
|
||||
new_kwargs["tools"] = [
|
||||
{
|
||||
@@ -92,7 +90,9 @@ def handle_response_model(
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
if new_kwargs.get("stream", False) and response_model is not None:
|
||||
raise NotImplementedError("stream=True is not supported when using response_model parameter")
|
||||
raise NotImplementedError(
|
||||
"stream=True is not supported when using response_model parameter"
|
||||
)
|
||||
|
||||
warnings.warn(
|
||||
"stream=True is not supported when using response_model parameter"
|
||||
@@ -204,9 +204,7 @@ def is_async(func: Callable) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def wrap_chatcompletion(
|
||||
func: Callable, mode: Mode = Mode.FUNCTIONS
|
||||
) -> Callable:
|
||||
def wrap_chatcompletion(func: Callable, mode: Mode = Mode.FUNCTIONS) -> Callable:
|
||||
func_is_async = is_async(func)
|
||||
|
||||
@wraps(func)
|
||||
@@ -268,9 +266,7 @@ def wrap_chatcompletion(
|
||||
return wrapper_function
|
||||
|
||||
|
||||
def patch(
|
||||
client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS
|
||||
):
|
||||
def patch(client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS):
|
||||
"""
|
||||
Patch the `client.chat.completions.create` method
|
||||
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
import pytest
|
||||
from itertools import product
|
||||
from pydantic import BaseModel
|
||||
from openai import OpenAI
|
||||
import instructor
|
||||
from instructor.function_calls import Mode
|
||||
|
||||
|
||||
class UserDetails(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
|
||||
# Lists for models, test data, and modes
|
||||
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"]
|
||||
test_data = [
|
||||
("Jason is 10", "Jason", 10),
|
||||
("Alice is 25", "Alice", 25),
|
||||
("Bob is 35", "Bob", 35),
|
||||
]
|
||||
modes = [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model, data, mode", product(models, test_data, modes))
|
||||
def test_extract(model, data, mode):
|
||||
sample_data, expected_name, expected_age = data
|
||||
|
||||
if mode == Mode.JSON and model in {"gpt-3.5-turbo", "gpt-4"}:
|
||||
pytest.skip(
|
||||
"JSON mode is not supported for gpt-3.5-turbo and gpt-4, skipping test"
|
||||
)
|
||||
|
||||
# Setting up the client with the instructor patch
|
||||
client = instructor.patch(OpenAI(), mode=mode)
|
||||
|
||||
# Calling the extract function with the provided model, sample data, and mode
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
response_model=UserDetails,
|
||||
messages=[
|
||||
{"role": "user", "content": sample_data},
|
||||
],
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert (
|
||||
response.name == expected_name
|
||||
), f"Expected name {expected_name}, got {response.name}"
|
||||
assert (
|
||||
response.age == expected_age
|
||||
), f"Expected age {expected_age}, got {response.age}"
|
||||
@@ -76,4 +76,4 @@ def test_mode(mode):
|
||||
],
|
||||
)
|
||||
assert user.name.lower() == "jason"
|
||||
assert user.age == 25
|
||||
assert user.age == 25
|
||||
|
||||
@@ -15,9 +15,7 @@ class UserExtract(BaseModel):
|
||||
age: int
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]
|
||||
)
|
||||
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS])
|
||||
def test_runmodel(mode):
|
||||
client = instructor.patch(OpenAI(), mode=mode)
|
||||
model = client.chat.completions.create(
|
||||
@@ -36,9 +34,7 @@ def test_runmodel(mode):
|
||||
), "The raw response should be available from OpenAI"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]
|
||||
)
|
||||
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS])
|
||||
@pytest.mark.asyncio
|
||||
async def test_runmodel_async(mode):
|
||||
aclient = instructor.patch(AsyncOpenAI(), mode=mode)
|
||||
|
||||
Reference in New Issue
Block a user