mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Add multiple modalities: tools, functions, json_mode (#218)
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,70 @@
|
||||
# Patching
|
||||
|
||||
Instructor enhances client functionality with three new keywords for backwards compatibility. This allows use of the enhanced client as usual, with structured output benefits.
|
||||
|
||||
- `response_model`: Defines the response type for `chat.completions.create`.
|
||||
- `max_retries`: Determines retry attempts for failed `chat.completions.create` validations.
|
||||
- `validation_context`: Provides extra context to the validation process.
|
||||
|
||||
There are three methods for structured output:
|
||||
|
||||
1. **Function Calling**: The primary method. Use this for stability and testing.
|
||||
2. **Tool Calling**: Useful in specific scenarios; lacks the reasking feature of OpenAI's tool calling API.
|
||||
3. **JSON Mode**: Offers closer adherence to JSON but with more potential validation errors. Suitable for specific non-function calling clients.
|
||||
|
||||
## Function Calling
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
import instructor
|
||||
|
||||
client = instructor.patch(OpenAI())
|
||||
```
|
||||
|
||||
## Tool Calling
|
||||
|
||||
```python
|
||||
import instructor
|
||||
from instructor import Mode
|
||||
|
||||
client = instructor.patch(OpenAI(), mode=Mode.TOOL_CALL)
|
||||
```
|
||||
|
||||
## JSON Mode
|
||||
|
||||
```python
|
||||
import instructor
|
||||
from instructor import Mode
|
||||
from openai import OpenAI
|
||||
|
||||
client = instructor.patch(OpenAI(), mode=Mode.JSON)
|
||||
```
|
||||
|
||||
### Schema Integration
|
||||
|
||||
In JSON Mode, the schema is part of the system message:
|
||||
|
||||
```python
|
||||
import instructor
|
||||
from openai import OpenAI
|
||||
|
||||
client = instructor.patch(OpenAI())
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-1106",
|
||||
response_format={"type": "json_object"},
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"Match your response to this json_schema: \n{UserExtract.model_json_schema()['properties']}",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Extract jason is 25 years old",
|
||||
},
|
||||
],
|
||||
)
|
||||
user = UserExtract.from_response(response, mode=Mode.JSON)
|
||||
assert user.name.lower() == "jason"
|
||||
assert user.age == 25
|
||||
```
|
||||
@@ -1,9 +1,20 @@
|
||||
from calendar import c
|
||||
import json
|
||||
from docstring_parser import parse
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
from pydantic import BaseModel, create_model, validate_arguments
|
||||
|
||||
import enum
|
||||
|
||||
|
||||
class Mode(enum.Enum):
|
||||
"""The mode to use for patching the client"""
|
||||
|
||||
FUNCTIONS: str = "function_call"
|
||||
TOOLS: str = "tool_call"
|
||||
JSON: str = "json_mode"
|
||||
|
||||
|
||||
class openai_function:
|
||||
"""
|
||||
@@ -176,9 +187,9 @@ class OpenAISchema(BaseModel):
|
||||
def from_response(
|
||||
cls,
|
||||
completion,
|
||||
throw_error: bool = True,
|
||||
validation_context=None,
|
||||
strict: bool = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
"""Execute the function from the response of an openai chat completion
|
||||
|
||||
@@ -193,11 +204,36 @@ class OpenAISchema(BaseModel):
|
||||
"""
|
||||
message = completion.choices[0].message
|
||||
|
||||
return cls.model_validate_json(
|
||||
message.function_call.arguments,
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
if mode == Mode.FUNCTIONS:
|
||||
assert (
|
||||
message.function_call.name == cls.openai_schema["name"]
|
||||
), "Function name does not match"
|
||||
return cls.model_validate_json(
|
||||
message.function_call.arguments,
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
elif mode == Mode.TOOLS:
|
||||
assert (
|
||||
len(message.tool_calls) == 1
|
||||
), "Instructor does not support multiple tool calls, use List[Model] instead."
|
||||
tool_call = message.tool_calls[0]
|
||||
assert (
|
||||
tool_call.function.name == cls.openai_schema["name"]
|
||||
), "Tool name does not match"
|
||||
return cls.model_validate_json(
|
||||
tool_call.function.arguments,
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
elif mode == Mode.JSON:
|
||||
return cls.model_validate_json(
|
||||
message.content,
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
|
||||
def openai_schema(cls) -> OpenAISchema:
|
||||
|
||||
+112
-48
@@ -4,10 +4,12 @@ from json import JSONDecodeError
|
||||
from typing import Callable, Optional, Type, Union
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
from openai.types.chat import ChatCompletion
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from .function_calls import OpenAISchema, openai_schema
|
||||
from .function_calls import OpenAISchema, openai_schema, Mode
|
||||
|
||||
import warnings
|
||||
|
||||
OVERRIDE_DOCS = """
|
||||
Creates a new chat completion for the provided messages and parameters.
|
||||
@@ -29,16 +31,68 @@ Parameters:
|
||||
"""
|
||||
|
||||
|
||||
def handle_response_model(response_model: Type[BaseModel], kwargs):
|
||||
def dump_message(message) -> dict:
|
||||
"""Dumps a message to a dict, to be returned to the OpenAI API.
|
||||
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
|
||||
if it isn't used.
|
||||
"""
|
||||
dumped_message = message.model_dump()
|
||||
if not dumped_message.get("tool_calls"):
|
||||
del dumped_message["tool_calls"]
|
||||
return {k: v for k, v in dumped_message.items() if v}
|
||||
|
||||
|
||||
def handle_response_model(
|
||||
*,
|
||||
response_model: Type[BaseModel],
|
||||
kwargs,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
new_kwargs = kwargs.copy()
|
||||
if response_model is not None:
|
||||
if not issubclass(response_model, OpenAISchema):
|
||||
response_model = openai_schema(response_model) # type: ignore
|
||||
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
|
||||
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
|
||||
|
||||
if mode == Mode.FUNCTIONS:
|
||||
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
|
||||
new_kwargs["function_call"] = {
|
||||
"name": response_model.openai_schema["name"]
|
||||
} # type: ignore
|
||||
elif mode == Mode.TOOLS:
|
||||
new_kwargs["tools"] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": response_model.openai_schema,
|
||||
}
|
||||
]
|
||||
new_kwargs["tool_choice"] = {
|
||||
"type": "function",
|
||||
"function": {"name": response_model.openai_schema["name"]},
|
||||
}
|
||||
elif mode == Mode.JSON:
|
||||
new_kwargs["response_format"] = {"type": "json_object"}
|
||||
|
||||
# check that the first message is a system message
|
||||
# if it is not, add a system message to the beginning
|
||||
message = f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{response_model.model_json_schema()['properties']}"
|
||||
|
||||
if new_kwargs["messages"][0]["role"] != "system":
|
||||
new_kwargs["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": message,
|
||||
},
|
||||
)
|
||||
|
||||
# if the first message is a system append the schema to the end
|
||||
if new_kwargs["messages"][0]["role"] == "system":
|
||||
new_kwargs["messages"][0]["content"] += f"\n\n{message}"
|
||||
else:
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
if new_kwargs.get("stream", False) and response_model is not None:
|
||||
import warnings
|
||||
raise NotImplementedError("stream=True is not supported when using response_model parameter")
|
||||
|
||||
warnings.warn(
|
||||
"stream=True is not supported when using response_model parameter"
|
||||
@@ -48,7 +102,12 @@ def handle_response_model(response_model: Type[BaseModel], kwargs):
|
||||
|
||||
|
||||
def process_response(
|
||||
response, response_model, validation_context: dict = None, strict=None
|
||||
response,
|
||||
*,
|
||||
response_model: Type[BaseModel],
|
||||
validation_context: dict = None,
|
||||
strict=None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
): # type: ignore
|
||||
"""Processes a OpenAI response with the response model, if available
|
||||
It can use `validation_context` and `strict` to validate the response
|
||||
@@ -62,25 +121,13 @@ def process_response(
|
||||
"""
|
||||
if response_model is not None:
|
||||
model = response_model.from_response(
|
||||
response, validation_context=validation_context, strict=strict
|
||||
response, validation_context=validation_context, strict=strict, mode=mode
|
||||
)
|
||||
model._raw_response = response
|
||||
return model
|
||||
return response
|
||||
|
||||
|
||||
def dump_message(message: ChatCompletionMessage) -> dict:
|
||||
"""Dumps a message to a dict, to be returned to the OpenAI API.
|
||||
|
||||
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
|
||||
if it isn't used.
|
||||
"""
|
||||
dumped_message = message.model_dump()
|
||||
if not dumped_message.get("tool_calls"):
|
||||
del dumped_message["tool_calls"]
|
||||
return dumped_message
|
||||
|
||||
|
||||
async def retry_async(
|
||||
func,
|
||||
response_model,
|
||||
@@ -89,22 +136,21 @@ async def retry_async(
|
||||
kwargs,
|
||||
max_retries,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
response: ChatCompletion = await func(*args, **kwargs)
|
||||
return (
|
||||
process_response(
|
||||
response,
|
||||
response_model,
|
||||
validation_context,
|
||||
strict=strict,
|
||||
),
|
||||
None,
|
||||
return process_response(
|
||||
response,
|
||||
response_model=response_model,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
kwargs["messages"].append(response.choices[0].message) # type: ignore
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
@@ -124,20 +170,22 @@ def retry_sync(
|
||||
kwargs,
|
||||
max_retries,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
# Excepts ValidationError, and JSONDecodeError
|
||||
try:
|
||||
response = func(*args, **kwargs)
|
||||
return (
|
||||
process_response(
|
||||
response, response_model, validation_context, strict=strict
|
||||
),
|
||||
None,
|
||||
return process_response(
|
||||
response,
|
||||
response_model=response_model,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message))
|
||||
kwargs["messages"].append(response.choices[0].message)
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
@@ -156,7 +204,9 @@ def is_async(func: Callable) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def wrap_chatcompletion(func: Callable) -> Callable:
|
||||
def wrap_chatcompletion(
|
||||
func: Callable, mode: Mode = Mode.FUNCTIONS
|
||||
) -> Callable:
|
||||
func_is_async = is_async(func)
|
||||
|
||||
@wraps(func)
|
||||
@@ -167,17 +217,22 @@ def wrap_chatcompletion(func: Callable) -> Callable:
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
response_model, new_kwargs = handle_response_model(response_model, kwargs) # type: ignore
|
||||
response, error = await retry_async(
|
||||
if mode == Mode.TOOLS:
|
||||
max_retries = 0
|
||||
warnings.warn("max_retries is not supported when using tool calls")
|
||||
|
||||
response_model, new_kwargs = handle_response_model(
|
||||
response_model=response_model, kwargs=kwargs, mode=mode
|
||||
) # type: ignore
|
||||
response = await retry_async(
|
||||
func=func,
|
||||
response_model=response_model,
|
||||
validation_context=validation_context,
|
||||
max_retries=max_retries,
|
||||
args=args,
|
||||
kwargs=new_kwargs,
|
||||
mode=mode,
|
||||
) # type: ignore
|
||||
if error:
|
||||
raise ValueError(error)
|
||||
return response
|
||||
|
||||
@wraps(func)
|
||||
@@ -188,17 +243,22 @@ def wrap_chatcompletion(func: Callable) -> Callable:
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
response_model, new_kwargs = handle_response_model(response_model, kwargs) # type: ignore
|
||||
response, error = retry_sync(
|
||||
if mode == Mode.TOOLS:
|
||||
max_retries = 0
|
||||
warnings.warn("max_retries is not supported when using tool calls")
|
||||
|
||||
response_model, new_kwargs = handle_response_model(
|
||||
response_model=response_model, kwargs=kwargs, mode=mode
|
||||
) # type: ignore
|
||||
response = retry_sync(
|
||||
func=func,
|
||||
response_model=response_model,
|
||||
validation_context=validation_context,
|
||||
max_retries=max_retries,
|
||||
args=args,
|
||||
kwargs=new_kwargs,
|
||||
mode=mode,
|
||||
) # type: ignore
|
||||
if error:
|
||||
raise ValueError(error)
|
||||
return response
|
||||
|
||||
wrapper_function = (
|
||||
@@ -208,7 +268,9 @@ def wrap_chatcompletion(func: Callable) -> Callable:
|
||||
return wrapper_function
|
||||
|
||||
|
||||
def patch(client: Union[OpenAI, AsyncOpenAI]):
|
||||
def patch(
|
||||
client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS
|
||||
):
|
||||
"""
|
||||
Patch the `client.chat.completions.create` method
|
||||
|
||||
@@ -220,11 +282,13 @@ def patch(client: Union[OpenAI, AsyncOpenAI]):
|
||||
- `strict` parameter to use strict json parsing
|
||||
"""
|
||||
|
||||
client.chat.completions.create = wrap_chatcompletion(client.chat.completions.create)
|
||||
client.chat.completions.create = wrap_chatcompletion(
|
||||
client.chat.completions.create, mode=mode
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def apatch(client: AsyncOpenAI):
|
||||
def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
|
||||
"""
|
||||
No longer necessary, use `patch` instead.
|
||||
|
||||
@@ -237,4 +301,4 @@ def apatch(client: AsyncOpenAI):
|
||||
- `validation_context` parameter to validate the response using the pydantic model
|
||||
- `strict` parameter to use strict json parsing
|
||||
"""
|
||||
return patch(client)
|
||||
return patch(client, mode=mode)
|
||||
|
||||
@@ -131,6 +131,7 @@ nav:
|
||||
- Models: 'concepts/models.md'
|
||||
- Fields: 'concepts/fields.md'
|
||||
- Types: 'concepts/types.md'
|
||||
- Patching: 'concepts/patching.md'
|
||||
- Streaming: "concepts/lists.md"
|
||||
- Union: 'concepts/union.md'
|
||||
- Alias: 'concepts/alias.md'
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
from instructor.function_calls import OpenAISchema, Mode
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
client = OpenAI()
|
||||
|
||||
|
||||
class UserExtract(OpenAISchema):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
|
||||
def test_tool_call():
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-1106",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Extract jason is 25 years old, mary is 30 years old",
|
||||
},
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": UserExtract.openai_schema,
|
||||
}
|
||||
],
|
||||
tool_choice={
|
||||
"type": "function",
|
||||
"function": {"name": UserExtract.openai_schema["name"]},
|
||||
},
|
||||
)
|
||||
response_message = response.choices[0].message
|
||||
tool_calls = response_message.tool_calls
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "UserExtract"
|
||||
assert tool_calls[0].function
|
||||
user = UserExtract.from_response(response, mode=Mode.TOOLS)
|
||||
assert user.name.lower() == "jason"
|
||||
assert user.age == 25
|
||||
|
||||
|
||||
def test_json_mode():
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-1106",
|
||||
response_format={"type": "json_object"},
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{UserExtract.model_json_schema()['properties']}",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Extract jason is 25 years old",
|
||||
},
|
||||
],
|
||||
)
|
||||
user = UserExtract.from_response(response, mode=Mode.JSON)
|
||||
assert user.name.lower() == "jason"
|
||||
assert user.age == 25
|
||||
+64
-69
@@ -1,19 +1,67 @@
|
||||
from pydantic import BaseModel, field_validator
|
||||
import pytest
|
||||
import instructor
|
||||
|
||||
from instructor import llm_validator
|
||||
from typing_extensions import Annotated
|
||||
from pydantic import field_validator, BaseModel, BeforeValidator, ValidationError
|
||||
from openai import OpenAI, AsyncOpenAI
|
||||
|
||||
client = instructor.patch(OpenAI())
|
||||
from instructor.function_calls import Mode
|
||||
|
||||
aclient = instructor.patch(AsyncOpenAI())
|
||||
client = instructor.patch(OpenAI())
|
||||
|
||||
|
||||
class UserExtract(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]
|
||||
)
|
||||
def test_runmodel(mode):
|
||||
client = instructor.patch(OpenAI(), mode=mode)
|
||||
model = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-1106",
|
||||
response_model=UserExtract,
|
||||
max_retries=2,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract jason is 25 years old"},
|
||||
],
|
||||
)
|
||||
assert isinstance(model, UserExtract), "Should be instance of UserExtract"
|
||||
assert model.name.lower() == "jason"
|
||||
assert model.age == 25
|
||||
assert hasattr(
|
||||
model, "_raw_response"
|
||||
), "The raw response should be available from OpenAI"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_runmodel_async(mode):
|
||||
aclient = instructor.patch(AsyncOpenAI(), mode=mode)
|
||||
model = await aclient.chat.completions.create(
|
||||
model="gpt-3.5-turbo-1106",
|
||||
response_model=UserExtract,
|
||||
max_retries=2,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract jason is 25 years old"},
|
||||
],
|
||||
)
|
||||
assert isinstance(model, UserExtract), "Should be instance of UserExtract"
|
||||
assert model.name.lower() == "jason"
|
||||
assert model.age == 25
|
||||
assert hasattr(
|
||||
model, "_raw_response"
|
||||
), "The raw response should be available from OpenAI"
|
||||
|
||||
|
||||
class UserExtractValidated(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v):
|
||||
@@ -22,91 +70,38 @@ class UserExtract(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
def test_runmodel_validator():
|
||||
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON])
|
||||
def test_runmodel_validator(mode):
|
||||
client = instructor.patch(OpenAI(), mode=mode)
|
||||
model = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
response_model=UserExtract,
|
||||
model="gpt-3.5-turbo-1106",
|
||||
response_model=UserExtractValidated,
|
||||
max_retries=2,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract jason is 25 years old"},
|
||||
],
|
||||
)
|
||||
assert isinstance(model, UserExtract), "Should be instance of UserExtract"
|
||||
assert isinstance(model, UserExtractValidated), "Should be instance of UserExtract"
|
||||
assert model.name == "JASON"
|
||||
assert hasattr(
|
||||
model, "_raw_response"
|
||||
), "The raw response should be available from OpenAI"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON])
|
||||
@pytest.mark.asyncio
|
||||
async def test_runmodel_async_validator():
|
||||
async def test_runmodel_async_validator(mode):
|
||||
aclient = instructor.patch(AsyncOpenAI(), mode=mode)
|
||||
model = await aclient.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
response_model=UserExtract,
|
||||
model="gpt-3.5-turbo-1106",
|
||||
response_model=UserExtractValidated,
|
||||
max_retries=2,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract jason is 25 years old"},
|
||||
],
|
||||
)
|
||||
assert isinstance(model, UserExtract), "Should be instance of UserExtract"
|
||||
assert isinstance(model, UserExtractValidated), "Should be instance of UserExtract"
|
||||
assert model.name == "JASON"
|
||||
assert hasattr(
|
||||
model, "_raw_response"
|
||||
), "The raw response should be available from OpenAI"
|
||||
|
||||
|
||||
class UserExtractSimple(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_runmodel():
|
||||
model = await aclient.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
response_model=UserExtractSimple,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract jason is 25 years old"},
|
||||
],
|
||||
)
|
||||
assert isinstance(
|
||||
model, UserExtractSimple
|
||||
), "Should be instance of UserExtractSimple"
|
||||
assert model.name.lower() == "jason"
|
||||
assert hasattr(
|
||||
model, "_raw_response"
|
||||
), "The raw response should be available from OpenAI"
|
||||
|
||||
|
||||
def test_runmodel():
|
||||
model = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
response_model=UserExtractSimple,
|
||||
messages=[
|
||||
{"role": "user", "content": "Extract jason is 25 years old"},
|
||||
],
|
||||
)
|
||||
assert isinstance(
|
||||
model, UserExtractSimple
|
||||
), "Should be instance of UserExtractSimple"
|
||||
assert model.name.lower() == "jason"
|
||||
assert hasattr(
|
||||
model, "_raw_response"
|
||||
), "The raw response should be available from OpenAI"
|
||||
|
||||
|
||||
def test_runmodel_validator_error():
|
||||
class QuestionAnswerNoEvil(BaseModel):
|
||||
question: str
|
||||
answer: Annotated[
|
||||
str,
|
||||
BeforeValidator(
|
||||
llm_validator("don't say objectionable things", openai_client=client)
|
||||
),
|
||||
]
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
QuestionAnswerNoEvil(
|
||||
question="What is the meaning of life?",
|
||||
answer="The meaning of life is to be evil and steal",
|
||||
)
|
||||
|
||||
@@ -1,17 +1,38 @@
|
||||
import pytest
|
||||
|
||||
import instructor
|
||||
import instructor
|
||||
|
||||
from typing_extensions import Annotated
|
||||
from pydantic import BaseModel, AfterValidator, ValidationError
|
||||
from pydantic import BaseModel, AfterValidator, BeforeValidator, ValidationError
|
||||
from openai import OpenAI
|
||||
|
||||
from instructor.dsl.validators import llm_validator
|
||||
|
||||
client = instructor.patch(OpenAI())
|
||||
|
||||
|
||||
def test_patch_completes_successfully():
|
||||
class Response(BaseModel):
|
||||
message: Annotated[str, AfterValidator(instructor.openai_moderation(client=client))]
|
||||
|
||||
message: Annotated[
|
||||
str, AfterValidator(instructor.openai_moderation(client=client))
|
||||
]
|
||||
|
||||
with pytest.raises(ValidationError) as e:
|
||||
Response(message="I want to make them suffer the consequences")
|
||||
Response(message="I want to make them suffer the consequences")
|
||||
|
||||
|
||||
def test_runmodel_validator_error():
|
||||
class QuestionAnswerNoEvil(BaseModel):
|
||||
question: str
|
||||
answer: Annotated[
|
||||
str,
|
||||
BeforeValidator(
|
||||
llm_validator("don't say objectionable things", openai_client=client)
|
||||
),
|
||||
]
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
QuestionAnswerNoEvil(
|
||||
question="What is the meaning of life?",
|
||||
answer="The meaning of life is to be evil and steal",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user