This commit is contained in:
Jason Liu
2023-11-25 19:56:40 -05:00
parent 767b88c786
commit d3a567f2ff
10 changed files with 91 additions and 33 deletions
+8 -4
View File
@@ -8,11 +8,14 @@ from pydantic import BaseModel
client = instructor.patch(OpenAI())
class UserDetail(BaseModel):
name: str
age: int
cache = diskcache.Cache('./my_cache_directory')
cache = diskcache.Cache("./my_cache_directory")
def instructor_cache(func):
"""Cache a function that returns a Pydantic model"""
@@ -38,6 +41,7 @@ def instructor_cache(func):
return wrapper
@instructor_cache
def extract(data) -> UserDetail:
return client.chat.completions.create(
@@ -45,12 +49,12 @@ def extract(data) -> UserDetail:
response_model=UserDetail,
messages=[
{"role": "user", "content": data},
]
],
)
def test_extract():
import time
import time
start = time.perf_counter()
model = extract("Extract jason is 25 years old")
@@ -68,4 +72,4 @@ def test_extract():
if __name__ == "__main__":
test_extract()
# Time taken: 0.7285366660216823
# Time taken: 9.841693099588156e-05
# Time taken: 9.841693099588156e-05
+6 -3
View File
@@ -9,6 +9,7 @@ from openai import OpenAI
client = instructor.patch(OpenAI())
cache = redis.Redis("localhost")
def instructor_cache(func):
"""Cache a function that returns a Pydantic model"""
return_type = inspect.signature(func).return_annotation
@@ -38,6 +39,7 @@ class UserDetail(BaseModel):
name: str
age: int
@instructor_cache
def extract(data) -> UserDetail:
# Assuming client.chat.completions.create returns a UserDetail instance
@@ -46,11 +48,12 @@ def extract(data) -> UserDetail:
response_model=UserDetail,
messages=[
{"role": "user", "content": data},
]
],
)
def test_extract():
import time
import time
start = time.perf_counter()
model = extract("Extract jason is 25 years old")
@@ -68,4 +71,4 @@ def test_extract():
if __name__ == "__main__":
test_extract()
# Time taken: 0.798335583996959
# Time taken: 0.00017016706988215446
# Time taken: 0.00017016706988215446
+6 -3
View File
@@ -5,10 +5,12 @@ import functools
client = instructor.patch(OpenAI())
class UserDetail(BaseModel):
name: str
age: int
@functools.lru_cache
def extract(data):
return client.chat.completions.create(
@@ -16,12 +18,12 @@ def extract(data):
response_model=UserDetail,
messages=[
{"role": "user", "content": data},
]
],
)
def test_extract():
import time
import time
start = time.perf_counter()
model = extract("Extract jason is 25 years old")
@@ -35,7 +37,8 @@ def test_extract():
assert model.age == 25
print(f"Time taken: {time.perf_counter() - start}")
if __name__ == "__main__":
test_extract()
# Time taken: 0.9267581660533324
# Time taken: 1.2080417945981026e-06
# Time taken: 1.2080417945981026e-06
+3 -2
View File
@@ -1,4 +1,4 @@
import instructor
import instructor
from instructor import openai_moderation
@@ -8,8 +8,9 @@ from openai import OpenAI
client = instructor.patch(OpenAI())
class Response(BaseModel):
message: Annotated[str, AfterValidator(openai_moderation(client=client))]
response = Response(message="I want to make them suffer the consequences")
response = Response(message="I want to make them suffer the consequences")
+2 -1
View File
@@ -1,6 +1,6 @@
from .distil import FinetuneFormat, Instructions
from .dsl import CitationMixin, Maybe, MultiTask, llm_validator, openai_moderation
from .function_calls import OpenAISchema, openai_function, openai_schema
from .function_calls import OpenAISchema, openai_function, openai_schema, Mode
from .patch import apatch, patch
__all__ = [
@@ -10,6 +10,7 @@ __all__ = [
"MultiTask",
"Maybe",
"openai_schema",
"Mode",
"patch",
"apatch",
"llm_validator",
+6 -3
View File
@@ -99,6 +99,7 @@ def llm_validator(
return llm
def openai_moderation(client: OpenAI = None):
"""
Validates a message using OpenAI moderation model.
@@ -133,8 +134,10 @@ def openai_moderation(client: OpenAI = None):
out = response.results[0]
cats = out.categories.model_dump()
if out.flagged:
raise ValueError(f"`{v}` was flagged for {', '.join(cat for cat in cats if cats[cat])}")
raise ValueError(
f"`{v}` was flagged for {', '.join(cat for cat in cats if cats[cat])}"
)
return v
return validate_message_with_openai_mod
+6 -10
View File
@@ -55,9 +55,7 @@ def handle_response_model(
if mode == Mode.FUNCTIONS:
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
new_kwargs["function_call"] = {
"name": response_model.openai_schema["name"]
} # type: ignore
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
elif mode == Mode.TOOLS:
new_kwargs["tools"] = [
{
@@ -92,7 +90,9 @@ def handle_response_model(
raise ValueError(f"Invalid patch mode: {mode}")
if new_kwargs.get("stream", False) and response_model is not None:
raise NotImplementedError("stream=True is not supported when using response_model parameter")
raise NotImplementedError(
"stream=True is not supported when using response_model parameter"
)
warnings.warn(
"stream=True is not supported when using response_model parameter"
@@ -204,9 +204,7 @@ def is_async(func: Callable) -> bool:
)
def wrap_chatcompletion(
func: Callable, mode: Mode = Mode.FUNCTIONS
) -> Callable:
def wrap_chatcompletion(func: Callable, mode: Mode = Mode.FUNCTIONS) -> Callable:
func_is_async = is_async(func)
@wraps(func)
@@ -268,9 +266,7 @@ def wrap_chatcompletion(
return wrapper_function
def patch(
client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS
):
def patch(client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS):
"""
Patch the `client.chat.completions.create` method
+51
View File
@@ -0,0 +1,51 @@
import pytest
from itertools import product
from pydantic import BaseModel
from openai import OpenAI
import instructor
from instructor.function_calls import Mode
class UserDetails(BaseModel):
name: str
age: int
# Lists for models, test data, and modes
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"]
test_data = [
("Jason is 10", "Jason", 10),
("Alice is 25", "Alice", 25),
("Bob is 35", "Bob", 35),
]
modes = [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]
@pytest.mark.parametrize("model, data, mode", product(models, test_data, modes))
def test_extract(model, data, mode):
sample_data, expected_name, expected_age = data
if mode == Mode.JSON and model in {"gpt-3.5-turbo", "gpt-4"}:
pytest.skip(
"JSON mode is not supported for gpt-3.5-turbo and gpt-4, skipping test"
)
# Setting up the client with the instructor patch
client = instructor.patch(OpenAI(), mode=mode)
# Calling the extract function with the provided model, sample data, and mode
response = client.chat.completions.create(
model=model,
response_model=UserDetails,
messages=[
{"role": "user", "content": sample_data},
],
)
# Assertions
assert (
response.name == expected_name
), f"Expected name {expected_name}, got {response.name}"
assert (
response.age == expected_age
), f"Expected age {expected_age}, got {response.age}"
+1 -1
View File
@@ -76,4 +76,4 @@ def test_mode(mode):
],
)
assert user.name.lower() == "jason"
assert user.age == 25
assert user.age == 25
+2 -6
View File
@@ -15,9 +15,7 @@ class UserExtract(BaseModel):
age: int
@pytest.mark.parametrize(
"mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]
)
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS])
def test_runmodel(mode):
client = instructor.patch(OpenAI(), mode=mode)
model = client.chat.completions.create(
@@ -36,9 +34,7 @@ def test_runmodel(mode):
), "The raw response should be available from OpenAI"
@pytest.mark.parametrize(
"mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]
)
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS])
@pytest.mark.asyncio
async def test_runmodel_async(mode):
aclient = instructor.patch(AsyncOpenAI(), mode=mode)