fix: big refactor of patch.py (#493)

This commit is contained in:
Jason Liu
2024-03-07 02:07:11 -05:00
committed by GitHub
parent 3e44a6bc30
commit 8ced667cc2
13 changed files with 507 additions and 502 deletions
+5 -2
View File
@@ -1,3 +1,5 @@
from .mode import Mode
from .process_response import handle_response_model
from .distil import FinetuneFormat, Instructions
from .dsl import (
CitationMixin,
@@ -7,8 +9,9 @@ from .dsl import (
llm_validator,
openai_moderation,
)
from .function_calls import OpenAISchema, openai_schema, Mode
from .patch import apatch, patch, handle_parallel_model, handle_response_model
from .function_calls import OpenAISchema, openai_schema
from .patch import apatch, patch
from .process_response import handle_parallel_model
__all__ = [
"OpenAISchema",
+2 -1
View File
@@ -2,7 +2,8 @@ from typing import Any, AsyncGenerator, Generator, Iterable, List, Optional, Tup
from pydantic import BaseModel, Field, create_model
from instructor.function_calls import OpenAISchema, Mode
from instructor.function_calls import OpenAISchema
from instructor.mode import Mode
from instructor.utils import extract_json_from_stream, extract_json_from_stream_async
+3 -1
View File
@@ -13,9 +13,11 @@ from typing import (
)
from types import UnionType # type: ignore[attr-defined]
from pydantic import BaseModel
from instructor.function_calls import OpenAISchema, Mode, openai_schema
from instructor.function_calls import OpenAISchema, openai_schema
from collections.abc import Iterable
from instructor.mode import Mode
T = TypeVar("T", bound=OpenAISchema)
+1 -1
View File
@@ -22,7 +22,7 @@ from typing import (
)
from copy import deepcopy
from instructor.function_calls import Mode
from instructor.mode import Mode
from instructor.dsl.partialjson import JSONParser
from instructor.utils import extract_json_from_stream, extract_json_from_stream_async
+2 -2
View File
@@ -3,8 +3,8 @@ from typing import Callable, Optional
from openai import OpenAI
from pydantic import Field
import instructor
from instructor.function_calls import OpenAISchema
from instructor.patch import patch
class Validator(OpenAISchema):
@@ -68,7 +68,7 @@ def llm_validator(
openai_client (OpenAI): The OpenAI client to use (default: None)
"""
openai_client = openai_client if openai_client else patch(OpenAI())
openai_client = openai_client if openai_client else instructor.patch(OpenAI())
def llm(v: str) -> str:
resp = openai_client.chat.completions.create(
+2 -29
View File
@@ -3,43 +3,16 @@ from docstring_parser import parse
from functools import wraps
from pydantic import BaseModel, create_model
from instructor.exceptions import IncompleteOutputException
import enum
import warnings
import logging
from openai.types.chat import ChatCompletion
from instructor.mode import Mode
from instructor.utils import extract_json_from_codeblock
import logging
T = TypeVar("T")
logger = logging.getLogger("instructor")
class Mode(enum.Enum):
"""The mode to use for patching the client"""
FUNCTIONS = "function_call"
PARALLEL_TOOLS = "parallel_tool_call"
TOOLS = "tool_call"
MISTRAL_TOOLS = "mistral_tools"
JSON = "json_mode"
MD_JSON = "markdown_json_mode"
JSON_SCHEMA = "json_schema_mode"
def __new__(cls, value: str) -> "Mode":
member = object.__new__(cls)
member._value_ = value
# Deprecation warning for FUNCTIONS
if value == "function_call":
warnings.warn(
"FUNCTIONS is deprecated and will be removed in future versions",
DeprecationWarning,
stacklevel=2,
)
return member
class OpenAISchema(BaseModel): # type: ignore[misc]
@classmethod # type: ignore[misc]
@property
+28
View File
@@ -0,0 +1,28 @@
import enum
import warnings
class Mode(enum.Enum):
"""The mode to use for patching the client"""
FUNCTIONS = "function_call"
PARALLEL_TOOLS = "parallel_tool_call"
TOOLS = "tool_call"
MISTRAL_TOOLS = "mistral_tools"
JSON = "json_mode"
MD_JSON = "markdown_json_mode"
JSON_SCHEMA = "json_schema_mode"
def __new__(cls, value: str) -> "Mode":
member = object.__new__(cls)
member._value_ = value
# Deprecation warning for FUNCTIONS
if value == "function_call":
warnings.warn(
"FUNCTIONS is deprecated and will be removed in future versions",
DeprecationWarning,
stacklevel=2,
)
return member
+6 -456
View File
@@ -1,479 +1,30 @@
# type: ignore[all]
import inspect
import logging
from textwrap import dedent
from collections.abc import Iterable
from functools import wraps
from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError
from json import JSONDecodeError
from typing import (
Callable,
Generator,
Optional,
ParamSpec,
Protocol,
Type,
TypeVar,
Union,
get_args,
get_origin,
overload,
)
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import (
ChatCompletion,
)
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from instructor.dsl.iterable import IterableModel, IterableBase
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
from instructor.dsl.partial import PartialBase
from instructor.dsl.simple_type import ModelAdapter, AdapterBase, is_simple_type
from instructor.utils import dump_message, update_total_usage
from instructor.process_response import handle_response_model
from instructor.retry import retry_async, retry_sync
from instructor.utils import is_async
from .function_calls import Mode, OpenAISchema, openai_schema
from instructor.mode import Mode
import logging
logger = logging.getLogger("instructor")
T = TypeVar("T")
T_Model = TypeVar("T_Model", bound=BaseModel)
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")
T = TypeVar("T")
def handle_response_model(
response_model: T, mode: Mode = Mode.TOOLS, **kwargs
) -> Union[Type[OpenAISchema], dict]:
"""Prepare the response model type hint, and returns the response_model
along with the new modified kwargs needed to be able to use the response_model
parameter with the patch function.
Args:
response_model (T): The response model to use for parsing the response
mode (Mode, optional): The openai completion mode. Defaults to Mode.TOOLS.
Raises:
NotImplementedError: When using stream=True with a non-iterable response_model
ValueError: When using an invalid patch mode
Returns:
Union[Type[OpenAISchema], dict]: The response model to use for parsing the response
"""
new_kwargs = kwargs.copy()
if response_model is not None:
# Handles the case where the response_model is a simple type
# Literal, Annotated, Union, str, int, float, bool, Enum
# We wrap the response_model in a ModelAdapter that sets 'content' as the response
if is_simple_type(response_model):
response_model = ModelAdapter[response_model]
# This a special case for parallel tools
if mode == Mode.PARALLEL_TOOLS:
assert (
new_kwargs.get("stream", False) is False
), "stream=True is not supported when using PARALLEL_TOOLS mode"
new_kwargs["tools"] = handle_parallel_model(response_model)
new_kwargs["tool_choice"] = "auto"
# This is a special case for parallel models
response_model = ParallelModel(typehint=response_model)
return response_model, new_kwargs
# This is for all other single model cases
if get_origin(response_model) is Iterable:
iterable_element_class = get_args(response_model)[0]
response_model = IterableModel(iterable_element_class)
if not issubclass(response_model, OpenAISchema):
response_model = openai_schema(response_model) # type: ignore
if new_kwargs.get("stream", False) and not issubclass(
response_model, (IterableBase, PartialBase)
):
raise NotImplementedError(
"stream=True is not supported when using response_model parameter for non-iterables"
)
if mode == Mode.FUNCTIONS:
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
elif mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}:
new_kwargs["tools"] = [
{
"type": "function",
"function": response_model.openai_schema,
}
]
if mode == Mode.MISTRAL_TOOLS:
new_kwargs["tool_choice"] = "any"
else:
new_kwargs["tool_choice"] = {
"type": "function",
"function": {"name": response_model.openai_schema["name"]},
}
elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}:
# If its a JSON Mode we need to massage the prompt a bit
# in order to get the response we want in a json format
message = dedent(
f"""
As a genius expert, your task is to understand the content and provide
the parsed objects in json that match the following json_schema:\n
{response_model.model_json_schema()}
Make sure to return an instance of the JSON, not the schema itself
"""
)
if mode == Mode.JSON:
new_kwargs["response_format"] = {"type": "json_object"}
elif mode == Mode.JSON_SCHEMA:
new_kwargs["response_format"] = {
"type": "json_object",
"schema": response_model.model_json_schema(),
}
elif mode == Mode.MD_JSON:
new_kwargs["messages"].append(
{
"role": "user",
"content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA",
},
)
# check that the first message is a system message
# if it is not, add a system message to the beginning
if new_kwargs["messages"][0]["role"] != "system":
new_kwargs["messages"].insert(
0,
{
"role": "system",
"content": message,
},
)
# if it is, system append the schema to the end
else:
new_kwargs["messages"][0]["content"] += f"\n\n{message}"
else:
raise ValueError(f"Invalid patch mode: {mode}")
logger.debug(
f"Instructor Request: {mode.value=}, {response_model=}, {new_kwargs=}",
extra={
"mode": mode.value,
"response_model": response_model.__name__
if response_model is not None
else None,
"new_kwargs": new_kwargs,
},
)
return response_model, new_kwargs
def process_response(
response: T_Model,
*,
response_model: Type[OpenAISchema | BaseModel],
stream: bool,
validation_context: Optional[dict] = None,
strict=None,
mode: Mode = Mode.TOOLS,
) -> T_Model | Generator[T_Model, None, None]:
"""Processes a OpenAI response with the response model, if available.
Args:
response (T): The response from OpenAI's API
response_model (Type[T_Model]): The response model to use for parsing the response
stream (bool): Whether the response is a stream
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
strict (_type_, optional): Whether to use strict json parsing. Defaults to None.
mode (Mode, optional): The openai completion mode. Defaults to Mode.FUNCTIONS.
Returns:
Union[T_Model, T]: The parsed response, if a response model is available, otherwise the response as is from the SDK
"""
logger.debug(
f"Instructor Raw Response: {response}",
)
if response_model is None:
logger.debug("No response model, returning response as is")
return response
if (
inspect.isclass(response_model)
and issubclass(response_model, (IterableBase, PartialBase))
and stream
):
model = response_model.from_streaming_response(
response,
mode=mode,
)
return model
model = response_model.from_response(
response,
validation_context=validation_context,
strict=strict,
mode=mode,
)
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
logger.debug(f"Returning takes from IterableBase")
return [task for task in model.tasks]
if isinstance(response_model, ParallelBase):
logger.debug(f"Returning model from ParallelBase")
return model
if isinstance(model, AdapterBase):
logger.debug(f"Returning model from AdapterBase")
return model.content
model._raw_response = response
return model
async def process_response_async(
response: ChatCompletion,
*,
response_model: Type[T_Model | OpenAISchema | BaseModel],
stream: bool = False,
validation_context: Optional[dict] = None,
strict: Optional[bool] = None,
mode: Mode = Mode.TOOLS,
) -> T_Model | ChatCompletion:
"""Processes a OpenAI response with the response model, if available.
It can use `validation_context` and `strict` to validate the response
via the pydantic model
Args:
response (ChatCompletion): The response from OpenAI's API
response_model (BaseModel): The response model to use for parsing the response
stream (bool): Whether the response is a stream
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
"""
logger.debug(
f"Instructor Raw Response: {response}",
)
if response_model is None:
return response
if (
inspect.isclass(response_model)
and issubclass(response_model, (IterableBase, PartialBase))
and stream
):
model = await response_model.from_streaming_response_async(
response,
mode=mode,
)
return model
model = response_model.from_response(
response,
validation_context=validation_context,
strict=strict,
mode=mode,
)
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
logger.debug(f"Returning takes from IterableBase")
return [task for task in model.tasks]
if isinstance(response_model, ParallelBase):
logger.debug(f"Returning model from ParallelBase")
return model
if isinstance(model, AdapterBase):
logger.debug(f"Returning model from AdapterBase")
return model.content
model._raw_response = response
return model
async def retry_async(
func: Callable[T_ParamSpec, T_Retval],
response_model: Type[T],
validation_context,
args,
kwargs,
max_retries: int | AsyncRetrying = 1,
strict: Optional[bool] = None,
mode: Mode = Mode.TOOLS,
) -> T:
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
# If max_retries is int, then create a AsyncRetrying object
if isinstance(max_retries, int):
logger.debug(f"max_retries: {max_retries}")
max_retries = AsyncRetrying(
stop=stop_after_attempt(max_retries),
reraise=True,
)
if not isinstance(max_retries, (AsyncRetrying, Retrying)):
raise ValueError(
"max_retries must be an `int` or a `tenacity.AsyncRetrying` object"
)
try:
async for attempt in max_retries:
logger.debug(f"Retrying, attempt: {attempt}")
with attempt:
try:
response: ChatCompletion = await func(*args, **kwargs) # type: ignore
stream = kwargs.get("stream", False)
response = update_total_usage(response, total_usage)
return await process_response_async(
response,
response_model=response_model,
stream=stream,
validation_context=validation_context,
strict=strict,
mode=mode,
) # type: ignore[all]
except (ValidationError, JSONDecodeError) as e:
logger.debug(f"Error response: {response}", e)
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
if mode == Mode.TOOLS:
kwargs["messages"].append(
{
"role": "tool",
"tool_call_id": response.choices[0]
.message.tool_calls[0]
.id,
"name": response.choices[0]
.message.tool_calls[0]
.function.name,
"content": "Exceptions found\n{e}\nRecall the function correctly.",
}
)
kwargs["messages"].append(
{
"role": "user",
"content": f"Recall the function correctly, fix the errors, exceptions found\n{e}",
}
)
if mode == Mode.MD_JSON:
kwargs["messages"].append(
{
"role": "user",
"content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA",
},
)
raise e
except RetryError as e:
logger.exception(f"Failed after retries: {e.last_attempt.exception}")
raise e.last_attempt.exception from e
def retry_sync(
func: Callable[T_ParamSpec, T_Retval],
response_model: Type[T],
validation_context: dict,
args,
kwargs,
max_retries: int | Retrying = 1,
strict: Optional[bool] = None,
mode: Mode = Mode.TOOLS,
):
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
# If max_retries is int, then create a Retrying object
if isinstance(max_retries, int):
logger.debug(f"max_retries: {max_retries}")
max_retries: Retrying = Retrying(
stop=stop_after_attempt(max_retries),
reraise=True,
)
if not isinstance(max_retries, (Retrying, AsyncRetrying)):
raise ValueError("max_retries must be an int or a `tenacity.Retrying` object")
try:
for attempt in max_retries:
with attempt:
try:
response = func(*args, **kwargs)
stream = kwargs.get("stream", False)
response = update_total_usage(response, total_usage)
return process_response(
response,
response_model=response_model,
stream=stream,
validation_context=validation_context,
strict=strict,
mode=mode,
)
except (ValidationError, JSONDecodeError) as e:
logger.debug(f"Error response: {response}")
kwargs["messages"].append(dump_message(response.choices[0].message))
# ! How do we handle this for parallel tools in the future?
if mode == Mode.TOOLS:
kwargs["messages"].append(
{
"role": "tool",
"tool_call_id": response.choices[0]
.message.tool_calls[0]
.id,
"name": response.choices[0]
.message.tool_calls[0]
.function.name,
"content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
}
)
else:
kwargs["messages"].append(
{
"role": "user",
"content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
}
)
raise e
except RetryError as e:
logger.exception(f"Failed after retries: {e.last_attempt.exception}")
raise e.last_attempt.exception from e
def is_async(func: Callable) -> bool:
"""Returns true if the callable is async, accounting for wrapped callables"""
is_coroutine = inspect.iscoroutinefunction(func)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
is_coroutine = is_coroutine or inspect.iscoroutinefunction(func)
return is_coroutine
OVERRIDE_DOCS = """
Creates a new chat completion for the provided messages and parameters.
See: https://platform.openai.com/docs/api-reference/chat-completions/create
Additional Notes:
Using the `response_model` parameter, you can specify a response model to use for parsing the response from OpenAI's API. If its present, the response will be parsed using the response model, otherwise it will be returned as is.
If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method.
If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts.
Parameters:
response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
max_retries (int): The maximum number of retries to attempt if the response is not valid (default: 0)
validation_context (dict): The validation context to use for validating the response (default: None)
"""
class InstructorChatCompletionCreate(Protocol):
@@ -584,7 +135,6 @@ def patch(
return response
new_create = new_create_async if func_is_async else new_create_sync
new_create.__doc__ = OVERRIDE_DOCS
if client is not None:
client.chat.completions.create = new_create
+297
View File
@@ -0,0 +1,297 @@
# type: ignore[all]
from collections.abc import Iterable
from textwrap import dedent
from instructor.dsl.iterable import IterableBase, IterableModel
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
from instructor.dsl.partial import PartialBase
from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
from instructor.function_calls import OpenAISchema, openai_schema
from openai.types.chat import ChatCompletion
from pydantic import BaseModel
import inspect
import logging
from typing import (
Generator,
Optional,
Type,
Tuple,
get_args,
get_origin,
TypeVar,
ParamSpec,
Any,
Dict,
)
from instructor.mode import Mode
logger = logging.getLogger("instructor")
T_Model = TypeVar("T_Model", bound=BaseModel)
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")
T = TypeVar("T")
async def process_response_async(
response: ChatCompletion,
*,
response_model: Type[T_Model | OpenAISchema | BaseModel],
stream: bool = False,
validation_context: Optional[dict] = None,
strict: Optional[bool] = None,
mode: Mode = Mode.TOOLS,
) -> T_Model | ChatCompletion:
"""Processes a OpenAI response with the response model, if available.
It can use `validation_context` and `strict` to validate the response
via the pydantic model
Args:
response (ChatCompletion): The response from OpenAI's API
response_model (BaseModel): The response model to use for parsing the response
stream (bool): Whether the response is a stream
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
"""
logger.debug(
f"Instructor Raw Response: {response}",
)
if response_model is None:
return response
if (
inspect.isclass(response_model)
and issubclass(response_model, (IterableBase, PartialBase))
and stream
):
model = await response_model.from_streaming_response_async(
response,
mode=mode,
)
return model
model = response_model.from_response(
response,
validation_context=validation_context,
strict=strict,
mode=mode,
)
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
logger.debug(f"Returning takes from IterableBase")
return [task for task in model.tasks]
if isinstance(response_model, ParallelBase):
logger.debug(f"Returning model from ParallelBase")
return model
if isinstance(model, AdapterBase):
logger.debug(f"Returning model from AdapterBase")
return model.content
model._raw_response = response
return model
def process_response(
response: T_Model,
*,
response_model: Type[OpenAISchema | BaseModel],
stream: bool,
validation_context: Optional[dict] = None,
strict=None,
mode: Mode = Mode.TOOLS,
) -> T_Model | Generator[T_Model, None, None] | ChatCompletion:
"""Processes a OpenAI response with the response model, if available.
Args:
response (T): The response from OpenAI's API
response_model (Type[T_Model]): The response model to use for parsing the response
stream (bool): Whether the response is a stream
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
strict (_type_, optional): Whether to use strict json parsing. Defaults to None.
mode (Mode, optional): The openai completion mode. Defaults to Mode.FUNCTIONS.
Returns:
Union[T_Model, T]: The parsed response, if a response model is available, otherwise the response as is from the SDK
"""
logger.debug(
f"Instructor Raw Response: {response}",
)
if response_model is None:
logger.debug("No response model, returning response as is")
return response
if (
inspect.isclass(response_model)
and issubclass(response_model, (IterableBase, PartialBase))
and stream
):
model = response_model.from_streaming_response(
response,
mode=mode,
)
return model
model = response_model.from_response(
response,
validation_context=validation_context,
strict=strict,
mode=mode,
)
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
logger.debug(f"Returning takes from IterableBase")
return [task for task in model.tasks]
if isinstance(response_model, ParallelBase):
logger.debug(f"Returning model from ParallelBase")
return model
if isinstance(model, AdapterBase):
logger.debug(f"Returning model from AdapterBase")
return model.content
model._raw_response = response
return model
def handle_response_model(
response_model: T, mode: Mode = Mode.TOOLS, **kwargs
) -> Tuple[Type[OpenAISchema], Dict[str, Any]]:
"""Prepare the response model type hint, and returns the response_model
along with the new modified kwargs needed to be able to use the response_model
parameter with the patch function.
Args:
response_model (T): The response model to use for parsing the response
mode (Mode, optional): The openai completion mode. Defaults to Mode.TOOLS.
Raises:
NotImplementedError: When using stream=True with a non-iterable response_model
ValueError: When using an invalid patch mode
Returns:
Union[Type[OpenAISchema], dict]: The response model to use for parsing the response
"""
new_kwargs = kwargs.copy()
if response_model is not None:
# Handles the case where the response_model is a simple type
# Literal, Annotated, Union, str, int, float, bool, Enum
# We wrap the response_model in a ModelAdapter that sets 'content' as the response
if is_simple_type(response_model):
response_model = ModelAdapter[response_model]
# This a special case for parallel tools
if mode == Mode.PARALLEL_TOOLS:
assert (
new_kwargs.get("stream", False) is False
), "stream=True is not supported when using PARALLEL_TOOLS mode"
new_kwargs["tools"] = handle_parallel_model(response_model)
new_kwargs["tool_choice"] = "auto"
# This is a special case for parallel models
response_model = ParallelModel(typehint=response_model)
return response_model, new_kwargs
# This is for all other single model cases
if get_origin(response_model) is Iterable:
iterable_element_class = get_args(response_model)[0]
response_model = IterableModel(iterable_element_class)
if not issubclass(response_model, OpenAISchema):
response_model = openai_schema(response_model) # type: ignore
if new_kwargs.get("stream", False) and not issubclass(
response_model, (IterableBase, PartialBase)
):
raise NotImplementedError(
"stream=True is not supported when using response_model parameter for non-iterables"
)
if mode == Mode.FUNCTIONS:
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
elif mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}:
new_kwargs["tools"] = [
{
"type": "function",
"function": response_model.openai_schema,
}
]
if mode == Mode.MISTRAL_TOOLS:
new_kwargs["tool_choice"] = "any"
else:
new_kwargs["tool_choice"] = {
"type": "function",
"function": {"name": response_model.openai_schema["name"]},
}
elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}:
# If its a JSON Mode we need to massage the prompt a bit
# in order to get the response we want in a json format
message = dedent(
f"""
As a genius expert, your task is to understand the content and provide
the parsed objects in json that match the following json_schema:\n
{response_model.model_json_schema()}
Make sure to return an instance of the JSON, not the schema itself
"""
)
if mode == Mode.JSON:
new_kwargs["response_format"] = {"type": "json_object"}
elif mode == Mode.JSON_SCHEMA:
new_kwargs["response_format"] = {
"type": "json_object",
"schema": response_model.model_json_schema(),
}
elif mode == Mode.MD_JSON:
new_kwargs["messages"].append(
{
"role": "user",
"content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA",
},
)
# check that the first message is a system message
# if it is not, add a system message to the beginning
if new_kwargs["messages"][0]["role"] != "system":
new_kwargs["messages"].insert(
0,
{
"role": "system",
"content": message,
},
)
# if it is, system append the schema to the end
else:
new_kwargs["messages"][0]["content"] += f"\n\n{message}"
else:
raise ValueError(f"Invalid patch mode: {mode}")
logger.debug(
f"Instructor Request: {mode.value=}, {response_model=}, {new_kwargs=}",
extra={
"mode": mode.value,
"response_model": response_model.__name__
if response_model is not None
else None,
"new_kwargs": new_kwargs,
},
)
return response_model, new_kwargs
+143
View File
@@ -0,0 +1,143 @@
# type: ignore[all]
import logging
from openai.types.chat import ChatCompletion
from instructor.mode import Mode
from instructor.process_response import process_response, process_response_async
from instructor.utils import dump_message, update_total_usage
from openai.types.completion_usage import CompletionUsage
from pydantic import ValidationError
from tenacity import AsyncRetrying, RetryError, Retrying, stop_after_attempt
from json import JSONDecodeError
from pydantic import BaseModel
from typing import Callable, Optional, Type, TypeVar, ParamSpec
logger = logging.getLogger("instructor")
T_Model = TypeVar("T_Model", bound=BaseModel)
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")
T = TypeVar("T")
def reask_messages(response: ChatCompletion, mode: Mode, exception: Exception):
yield dump_message(response.choices[0].message)
if mode == Mode.TOOLS:
for tool_call in response.choices[0].message.tool_calls: # type: ignore
yield {
"role": "tool",
"tool_call_id": tool_call.id,
"name": tool_call.function.name,
"content": f"Validation Error found:\n{exception}\nRecall the function correctly, fix the errors",
}
# TODO: Give users more control on configuration
if mode == Mode.MD_JSON:
yield {
"role": "user",
"content": f"Correct your JSON ONLY RESPONSE, based on the following errors:\n{exception}",
}
else:
yield {
"role": "user",
"content": f"Recall the function correctly, fix the errors, exceptions found\n{exception}",
}
def retry_sync(
func: Callable[T_ParamSpec, T_Retval],
response_model: Type[T_Model],
validation_context: dict,
args,
kwargs,
max_retries: int | Retrying = 1,
strict: Optional[bool] = None,
mode: Mode = Mode.TOOLS,
) -> T_Model:
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
# If max_retries is int, then create a Retrying object
if isinstance(max_retries, int):
logger.debug(f"max_retries: {max_retries}")
max_retries = Retrying(
stop=stop_after_attempt(max_retries),
reraise=True,
)
if not isinstance(max_retries, (Retrying, AsyncRetrying)):
raise ValueError("max_retries must be an int or a `tenacity.Retrying` object")
try:
for attempt in max_retries:
with attempt:
try:
response = func(*args, **kwargs)
stream = kwargs.get("stream", False)
response = update_total_usage(response, total_usage)
return process_response(
response,
response_model=response_model,
stream=stream,
validation_context=validation_context,
strict=strict,
mode=mode,
)
except (ValidationError, JSONDecodeError) as e:
logger.debug(f"Error response: {response}")
kwargs["messages"].extend(reask_messages(response, mode, e))
raise e
except RetryError as e:
logger.exception(f"Failed after retries: {e.last_attempt.exception}")
raise e.last_attempt.exception from e
async def retry_async(
func: Callable[T_ParamSpec, T_Retval],
response_model: Type[T],
validation_context,
args,
kwargs,
max_retries: int | AsyncRetrying = 1,
strict: Optional[bool] = None,
mode: Mode = Mode.TOOLS,
) -> T:
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
# If max_retries is int, then create a AsyncRetrying object
if isinstance(max_retries, int):
logger.debug(f"max_retries: {max_retries}")
max_retries = AsyncRetrying(
stop=stop_after_attempt(max_retries),
reraise=True,
)
if not isinstance(max_retries, (AsyncRetrying, Retrying)):
raise ValueError(
"max_retries must be an `int` or a `tenacity.AsyncRetrying` object"
)
try:
async for attempt in max_retries:
logger.debug(f"Retrying, attempt: {attempt}")
with attempt:
try:
response: ChatCompletion = await func(*args, **kwargs) # type: ignore
stream = kwargs.get("stream", False)
response = update_total_usage(response, total_usage)
return await process_response_async(
response,
response_model=response_model,
stream=stream,
validation_context=validation_context,
strict=strict,
mode=mode,
) # type: ignore[all]
except (ValidationError, JSONDecodeError) as e:
logger.debug(f"Error response: {response}", e)
kwargs["messages"].extend(reask_messages(response, mode, e))
raise e
except RetryError as e:
logger.exception(f"Failed after retries: {e.last_attempt.exception}")
raise e.last_attempt.exception from e
+16 -2
View File
@@ -1,5 +1,8 @@
import inspect
import json
from typing import Generator, Iterable, AsyncGenerator
from typing import Callable, Generator, Iterable, AsyncGenerator, TypeVar
from pydantic import BaseModel
from openai.types.chat import (
ChatCompletion,
@@ -7,6 +10,8 @@ from openai.types.chat import (
ChatCompletionMessageParam,
)
T_Model = TypeVar("T_Model", bound=BaseModel)
def extract_json_from_codeblock(content: str) -> str:
first_paren = content.find("{")
@@ -54,7 +59,7 @@ async def extract_json_from_stream_async(
yield char
def update_total_usage(response, total_usage):
def update_total_usage(response: T_Model, total_usage) -> T_Model | ChatCompletion:
if isinstance(response, ChatCompletion) and response.usage is not None:
total_usage.completion_tokens += response.usage.completion_tokens or 0
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
@@ -81,3 +86,12 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
):
ret["content"] += json.dumps(message.model_dump()["function_call"])
return ret
def is_async(func: Callable) -> bool:
"""Returns true if the callable is async, accounting for wrapped callables"""
is_coroutine = inspect.iscoroutinefunction(func)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
is_coroutine = is_coroutine or inspect.iscoroutinefunction(func)
return is_coroutine
+1 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "instructor"
version = "0.6.3"
version = "0.6.4"
description = "structured outputs for llm"
authors = ["Jason Liu <jason@jxnl.co>"]
license = "MIT"
+1 -7
View File
@@ -3,7 +3,7 @@ import functools
from openai import AsyncOpenAI, OpenAI
import instructor
from instructor.patch import OVERRIDE_DOCS, is_async
from instructor.utils import is_async
def test_patch_completes_successfully():
@@ -71,9 +71,3 @@ def test_is_async_returns_true_if_triple_wrapped_function_is_async():
pass
assert is_async(triple_wrapped_function) is True
def test_override_docs():
assert (
"response_model" in OVERRIDE_DOCS
), "response_model should be in OVERRIDE_DOCS"