mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
fix: big refactor of patch.py (#493)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from .mode import Mode
|
||||
from .process_response import handle_response_model
|
||||
from .distil import FinetuneFormat, Instructions
|
||||
from .dsl import (
|
||||
CitationMixin,
|
||||
@@ -7,8 +9,9 @@ from .dsl import (
|
||||
llm_validator,
|
||||
openai_moderation,
|
||||
)
|
||||
from .function_calls import OpenAISchema, openai_schema, Mode
|
||||
from .patch import apatch, patch, handle_parallel_model, handle_response_model
|
||||
from .function_calls import OpenAISchema, openai_schema
|
||||
from .patch import apatch, patch
|
||||
from .process_response import handle_parallel_model
|
||||
|
||||
__all__ = [
|
||||
"OpenAISchema",
|
||||
|
||||
@@ -2,7 +2,8 @@ from typing import Any, AsyncGenerator, Generator, Iterable, List, Optional, Tup
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from instructor.function_calls import OpenAISchema, Mode
|
||||
from instructor.function_calls import OpenAISchema
|
||||
from instructor.mode import Mode
|
||||
from instructor.utils import extract_json_from_stream, extract_json_from_stream_async
|
||||
|
||||
|
||||
|
||||
@@ -13,9 +13,11 @@ from typing import (
|
||||
)
|
||||
from types import UnionType # type: ignore[attr-defined]
|
||||
from pydantic import BaseModel
|
||||
from instructor.function_calls import OpenAISchema, Mode, openai_schema
|
||||
from instructor.function_calls import OpenAISchema, openai_schema
|
||||
from collections.abc import Iterable
|
||||
|
||||
from instructor.mode import Mode
|
||||
|
||||
T = TypeVar("T", bound=OpenAISchema)
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import (
|
||||
)
|
||||
from copy import deepcopy
|
||||
|
||||
from instructor.function_calls import Mode
|
||||
from instructor.mode import Mode
|
||||
from instructor.dsl.partialjson import JSONParser
|
||||
from instructor.utils import extract_json_from_stream, extract_json_from_stream_async
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Callable, Optional
|
||||
from openai import OpenAI
|
||||
from pydantic import Field
|
||||
|
||||
import instructor
|
||||
from instructor.function_calls import OpenAISchema
|
||||
from instructor.patch import patch
|
||||
|
||||
|
||||
class Validator(OpenAISchema):
|
||||
@@ -68,7 +68,7 @@ def llm_validator(
|
||||
openai_client (OpenAI): The OpenAI client to use (default: None)
|
||||
"""
|
||||
|
||||
openai_client = openai_client if openai_client else patch(OpenAI())
|
||||
openai_client = openai_client if openai_client else instructor.patch(OpenAI())
|
||||
|
||||
def llm(v: str) -> str:
|
||||
resp = openai_client.chat.completions.create(
|
||||
|
||||
@@ -3,43 +3,16 @@ from docstring_parser import parse
|
||||
from functools import wraps
|
||||
from pydantic import BaseModel, create_model
|
||||
from instructor.exceptions import IncompleteOutputException
|
||||
import enum
|
||||
import warnings
|
||||
import logging
|
||||
from openai.types.chat import ChatCompletion
|
||||
from instructor.mode import Mode
|
||||
from instructor.utils import extract_json_from_codeblock
|
||||
import logging
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger("instructor")
|
||||
|
||||
|
||||
class Mode(enum.Enum):
|
||||
"""The mode to use for patching the client"""
|
||||
|
||||
FUNCTIONS = "function_call"
|
||||
PARALLEL_TOOLS = "parallel_tool_call"
|
||||
TOOLS = "tool_call"
|
||||
MISTRAL_TOOLS = "mistral_tools"
|
||||
JSON = "json_mode"
|
||||
MD_JSON = "markdown_json_mode"
|
||||
JSON_SCHEMA = "json_schema_mode"
|
||||
|
||||
def __new__(cls, value: str) -> "Mode":
|
||||
member = object.__new__(cls)
|
||||
member._value_ = value
|
||||
|
||||
# Deprecation warning for FUNCTIONS
|
||||
if value == "function_call":
|
||||
warnings.warn(
|
||||
"FUNCTIONS is deprecated and will be removed in future versions",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return member
|
||||
|
||||
|
||||
class OpenAISchema(BaseModel): # type: ignore[misc]
|
||||
@classmethod # type: ignore[misc]
|
||||
@property
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
import enum
|
||||
import warnings
|
||||
|
||||
|
||||
class Mode(enum.Enum):
|
||||
"""The mode to use for patching the client"""
|
||||
|
||||
FUNCTIONS = "function_call"
|
||||
PARALLEL_TOOLS = "parallel_tool_call"
|
||||
TOOLS = "tool_call"
|
||||
MISTRAL_TOOLS = "mistral_tools"
|
||||
JSON = "json_mode"
|
||||
MD_JSON = "markdown_json_mode"
|
||||
JSON_SCHEMA = "json_schema_mode"
|
||||
|
||||
def __new__(cls, value: str) -> "Mode":
|
||||
member = object.__new__(cls)
|
||||
member._value_ = value
|
||||
|
||||
# Deprecation warning for FUNCTIONS
|
||||
if value == "function_call":
|
||||
warnings.warn(
|
||||
"FUNCTIONS is deprecated and will be removed in future versions",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return member
|
||||
+6
-456
@@ -1,479 +1,30 @@
|
||||
# type: ignore[all]
|
||||
import inspect
|
||||
import logging
|
||||
from textwrap import dedent
|
||||
from collections.abc import Iterable
|
||||
from functools import wraps
|
||||
from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError
|
||||
from json import JSONDecodeError
|
||||
from typing import (
|
||||
Callable,
|
||||
Generator,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
overload,
|
||||
)
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from instructor.dsl.iterable import IterableModel, IterableBase
|
||||
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
|
||||
from instructor.dsl.partial import PartialBase
|
||||
from instructor.dsl.simple_type import ModelAdapter, AdapterBase, is_simple_type
|
||||
from instructor.utils import dump_message, update_total_usage
|
||||
from instructor.process_response import handle_response_model
|
||||
from instructor.retry import retry_async, retry_sync
|
||||
from instructor.utils import is_async
|
||||
|
||||
from .function_calls import Mode, OpenAISchema, openai_schema
|
||||
from instructor.mode import Mode
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("instructor")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
T_ParamSpec = ParamSpec("T_ParamSpec")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def handle_response_model(
|
||||
response_model: T, mode: Mode = Mode.TOOLS, **kwargs
|
||||
) -> Union[Type[OpenAISchema], dict]:
|
||||
"""Prepare the response model type hint, and returns the response_model
|
||||
along with the new modified kwargs needed to be able to use the response_model
|
||||
parameter with the patch function.
|
||||
|
||||
|
||||
Args:
|
||||
response_model (T): The response model to use for parsing the response
|
||||
mode (Mode, optional): The openai completion mode. Defaults to Mode.TOOLS.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: When using stream=True with a non-iterable response_model
|
||||
ValueError: When using an invalid patch mode
|
||||
|
||||
Returns:
|
||||
Union[Type[OpenAISchema], dict]: The response model to use for parsing the response
|
||||
"""
|
||||
new_kwargs = kwargs.copy()
|
||||
if response_model is not None:
|
||||
# Handles the case where the response_model is a simple type
|
||||
# Literal, Annotated, Union, str, int, float, bool, Enum
|
||||
# We wrap the response_model in a ModelAdapter that sets 'content' as the response
|
||||
if is_simple_type(response_model):
|
||||
response_model = ModelAdapter[response_model]
|
||||
|
||||
# This a special case for parallel tools
|
||||
if mode == Mode.PARALLEL_TOOLS:
|
||||
assert (
|
||||
new_kwargs.get("stream", False) is False
|
||||
), "stream=True is not supported when using PARALLEL_TOOLS mode"
|
||||
new_kwargs["tools"] = handle_parallel_model(response_model)
|
||||
new_kwargs["tool_choice"] = "auto"
|
||||
|
||||
# This is a special case for parallel models
|
||||
response_model = ParallelModel(typehint=response_model)
|
||||
return response_model, new_kwargs
|
||||
|
||||
# This is for all other single model cases
|
||||
if get_origin(response_model) is Iterable:
|
||||
iterable_element_class = get_args(response_model)[0]
|
||||
response_model = IterableModel(iterable_element_class)
|
||||
if not issubclass(response_model, OpenAISchema):
|
||||
response_model = openai_schema(response_model) # type: ignore
|
||||
|
||||
if new_kwargs.get("stream", False) and not issubclass(
|
||||
response_model, (IterableBase, PartialBase)
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"stream=True is not supported when using response_model parameter for non-iterables"
|
||||
)
|
||||
|
||||
if mode == Mode.FUNCTIONS:
|
||||
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
|
||||
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
|
||||
elif mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}:
|
||||
new_kwargs["tools"] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": response_model.openai_schema,
|
||||
}
|
||||
]
|
||||
if mode == Mode.MISTRAL_TOOLS:
|
||||
new_kwargs["tool_choice"] = "any"
|
||||
else:
|
||||
new_kwargs["tool_choice"] = {
|
||||
"type": "function",
|
||||
"function": {"name": response_model.openai_schema["name"]},
|
||||
}
|
||||
elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}:
|
||||
# If its a JSON Mode we need to massage the prompt a bit
|
||||
# in order to get the response we want in a json format
|
||||
message = dedent(
|
||||
f"""
|
||||
As a genius expert, your task is to understand the content and provide
|
||||
the parsed objects in json that match the following json_schema:\n
|
||||
|
||||
{response_model.model_json_schema()}
|
||||
|
||||
Make sure to return an instance of the JSON, not the schema itself
|
||||
"""
|
||||
)
|
||||
|
||||
if mode == Mode.JSON:
|
||||
new_kwargs["response_format"] = {"type": "json_object"}
|
||||
|
||||
elif mode == Mode.JSON_SCHEMA:
|
||||
new_kwargs["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": response_model.model_json_schema(),
|
||||
}
|
||||
|
||||
elif mode == Mode.MD_JSON:
|
||||
new_kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA",
|
||||
},
|
||||
)
|
||||
# check that the first message is a system message
|
||||
# if it is not, add a system message to the beginning
|
||||
if new_kwargs["messages"][0]["role"] != "system":
|
||||
new_kwargs["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": message,
|
||||
},
|
||||
)
|
||||
# if it is, system append the schema to the end
|
||||
else:
|
||||
new_kwargs["messages"][0]["content"] += f"\n\n{message}"
|
||||
else:
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
logger.debug(
|
||||
f"Instructor Request: {mode.value=}, {response_model=}, {new_kwargs=}",
|
||||
extra={
|
||||
"mode": mode.value,
|
||||
"response_model": response_model.__name__
|
||||
if response_model is not None
|
||||
else None,
|
||||
"new_kwargs": new_kwargs,
|
||||
},
|
||||
)
|
||||
return response_model, new_kwargs
|
||||
|
||||
|
||||
def process_response(
|
||||
response: T_Model,
|
||||
*,
|
||||
response_model: Type[OpenAISchema | BaseModel],
|
||||
stream: bool,
|
||||
validation_context: Optional[dict] = None,
|
||||
strict=None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
) -> T_Model | Generator[T_Model, None, None]:
|
||||
"""Processes a OpenAI response with the response model, if available.
|
||||
|
||||
Args:
|
||||
response (T): The response from OpenAI's API
|
||||
response_model (Type[T_Model]): The response model to use for parsing the response
|
||||
stream (bool): Whether the response is a stream
|
||||
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
|
||||
strict (_type_, optional): Whether to use strict json parsing. Defaults to None.
|
||||
mode (Mode, optional): The openai completion mode. Defaults to Mode.FUNCTIONS.
|
||||
|
||||
Returns:
|
||||
Union[T_Model, T]: The parsed response, if a response model is available, otherwise the response as is from the SDK
|
||||
"""
|
||||
|
||||
logger.debug(
|
||||
f"Instructor Raw Response: {response}",
|
||||
)
|
||||
|
||||
if response_model is None:
|
||||
logger.debug("No response model, returning response as is")
|
||||
return response
|
||||
|
||||
if (
|
||||
inspect.isclass(response_model)
|
||||
and issubclass(response_model, (IterableBase, PartialBase))
|
||||
and stream
|
||||
):
|
||||
model = response_model.from_streaming_response(
|
||||
response,
|
||||
mode=mode,
|
||||
)
|
||||
return model
|
||||
|
||||
model = response_model.from_response(
|
||||
response,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
# ? This really hints at the fact that we need a better way of
|
||||
# ? attaching usage data and the raw response to the model we return.
|
||||
if isinstance(model, IterableBase):
|
||||
logger.debug(f"Returning takes from IterableBase")
|
||||
return [task for task in model.tasks]
|
||||
|
||||
if isinstance(response_model, ParallelBase):
|
||||
logger.debug(f"Returning model from ParallelBase")
|
||||
return model
|
||||
|
||||
if isinstance(model, AdapterBase):
|
||||
logger.debug(f"Returning model from AdapterBase")
|
||||
return model.content
|
||||
|
||||
model._raw_response = response
|
||||
return model
|
||||
|
||||
|
||||
async def process_response_async(
|
||||
response: ChatCompletion,
|
||||
*,
|
||||
response_model: Type[T_Model | OpenAISchema | BaseModel],
|
||||
stream: bool = False,
|
||||
validation_context: Optional[dict] = None,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
) -> T_Model | ChatCompletion:
|
||||
"""Processes a OpenAI response with the response model, if available.
|
||||
It can use `validation_context` and `strict` to validate the response
|
||||
via the pydantic model
|
||||
|
||||
Args:
|
||||
response (ChatCompletion): The response from OpenAI's API
|
||||
response_model (BaseModel): The response model to use for parsing the response
|
||||
stream (bool): Whether the response is a stream
|
||||
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
|
||||
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
|
||||
"""
|
||||
|
||||
logger.debug(
|
||||
f"Instructor Raw Response: {response}",
|
||||
)
|
||||
if response_model is None:
|
||||
return response
|
||||
|
||||
if (
|
||||
inspect.isclass(response_model)
|
||||
and issubclass(response_model, (IterableBase, PartialBase))
|
||||
and stream
|
||||
):
|
||||
model = await response_model.from_streaming_response_async(
|
||||
response,
|
||||
mode=mode,
|
||||
)
|
||||
return model
|
||||
|
||||
model = response_model.from_response(
|
||||
response,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
# ? This really hints at the fact that we need a better way of
|
||||
# ? attaching usage data and the raw response to the model we return.
|
||||
if isinstance(model, IterableBase):
|
||||
logger.debug(f"Returning takes from IterableBase")
|
||||
return [task for task in model.tasks]
|
||||
|
||||
if isinstance(response_model, ParallelBase):
|
||||
logger.debug(f"Returning model from ParallelBase")
|
||||
return model
|
||||
|
||||
if isinstance(model, AdapterBase):
|
||||
logger.debug(f"Returning model from AdapterBase")
|
||||
return model.content
|
||||
|
||||
model._raw_response = response
|
||||
return model
|
||||
|
||||
|
||||
async def retry_async(
|
||||
func: Callable[T_ParamSpec, T_Retval],
|
||||
response_model: Type[T],
|
||||
validation_context,
|
||||
args,
|
||||
kwargs,
|
||||
max_retries: int | AsyncRetrying = 1,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
) -> T:
|
||||
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
|
||||
|
||||
# If max_retries is int, then create a AsyncRetrying object
|
||||
if isinstance(max_retries, int):
|
||||
logger.debug(f"max_retries: {max_retries}")
|
||||
max_retries = AsyncRetrying(
|
||||
stop=stop_after_attempt(max_retries),
|
||||
reraise=True,
|
||||
)
|
||||
if not isinstance(max_retries, (AsyncRetrying, Retrying)):
|
||||
raise ValueError(
|
||||
"max_retries must be an `int` or a `tenacity.AsyncRetrying` object"
|
||||
)
|
||||
|
||||
try:
|
||||
async for attempt in max_retries:
|
||||
logger.debug(f"Retrying, attempt: {attempt}")
|
||||
with attempt:
|
||||
try:
|
||||
response: ChatCompletion = await func(*args, **kwargs) # type: ignore
|
||||
stream = kwargs.get("stream", False)
|
||||
response = update_total_usage(response, total_usage)
|
||||
return await process_response_async(
|
||||
response,
|
||||
response_model=response_model,
|
||||
stream=stream,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
) # type: ignore[all]
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
logger.debug(f"Error response: {response}", e)
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
|
||||
if mode == Mode.TOOLS:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": response.choices[0]
|
||||
.message.tool_calls[0]
|
||||
.id,
|
||||
"name": response.choices[0]
|
||||
.message.tool_calls[0]
|
||||
.function.name,
|
||||
"content": "Exceptions found\n{e}\nRecall the function correctly.",
|
||||
}
|
||||
)
|
||||
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Recall the function correctly, fix the errors, exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
if mode == Mode.MD_JSON:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA",
|
||||
},
|
||||
)
|
||||
raise e
|
||||
except RetryError as e:
|
||||
logger.exception(f"Failed after retries: {e.last_attempt.exception}")
|
||||
raise e.last_attempt.exception from e
|
||||
|
||||
|
||||
def retry_sync(
|
||||
func: Callable[T_ParamSpec, T_Retval],
|
||||
response_model: Type[T],
|
||||
validation_context: dict,
|
||||
args,
|
||||
kwargs,
|
||||
max_retries: int | Retrying = 1,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
):
|
||||
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
|
||||
|
||||
# If max_retries is int, then create a Retrying object
|
||||
if isinstance(max_retries, int):
|
||||
logger.debug(f"max_retries: {max_retries}")
|
||||
max_retries: Retrying = Retrying(
|
||||
stop=stop_after_attempt(max_retries),
|
||||
reraise=True,
|
||||
)
|
||||
if not isinstance(max_retries, (Retrying, AsyncRetrying)):
|
||||
raise ValueError("max_retries must be an int or a `tenacity.Retrying` object")
|
||||
|
||||
try:
|
||||
for attempt in max_retries:
|
||||
with attempt:
|
||||
try:
|
||||
response = func(*args, **kwargs)
|
||||
stream = kwargs.get("stream", False)
|
||||
response = update_total_usage(response, total_usage)
|
||||
return process_response(
|
||||
response,
|
||||
response_model=response_model,
|
||||
stream=stream,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
logger.debug(f"Error response: {response}")
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message))
|
||||
# ! How do we handle this for parallel tools in the future?
|
||||
if mode == Mode.TOOLS:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": response.choices[0]
|
||||
.message.tool_calls[0]
|
||||
.id,
|
||||
"name": response.choices[0]
|
||||
.message.tool_calls[0]
|
||||
.function.name,
|
||||
"content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
else:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
|
||||
}
|
||||
)
|
||||
raise e
|
||||
except RetryError as e:
|
||||
logger.exception(f"Failed after retries: {e.last_attempt.exception}")
|
||||
raise e.last_attempt.exception from e
|
||||
|
||||
|
||||
def is_async(func: Callable) -> bool:
|
||||
"""Returns true if the callable is async, accounting for wrapped callables"""
|
||||
is_coroutine = inspect.iscoroutinefunction(func)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
is_coroutine = is_coroutine or inspect.iscoroutinefunction(func)
|
||||
return is_coroutine
|
||||
|
||||
|
||||
OVERRIDE_DOCS = """
|
||||
Creates a new chat completion for the provided messages and parameters.
|
||||
|
||||
See: https://platform.openai.com/docs/api-reference/chat-completions/create
|
||||
|
||||
Additional Notes:
|
||||
|
||||
Using the `response_model` parameter, you can specify a response model to use for parsing the response from OpenAI's API. If its present, the response will be parsed using the response model, otherwise it will be returned as is.
|
||||
|
||||
If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method.
|
||||
|
||||
If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts.
|
||||
|
||||
Parameters:
|
||||
response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
|
||||
max_retries (int): The maximum number of retries to attempt if the response is not valid (default: 0)
|
||||
validation_context (dict): The validation context to use for validating the response (default: None)
|
||||
"""
|
||||
|
||||
|
||||
class InstructorChatCompletionCreate(Protocol):
|
||||
@@ -584,7 +135,6 @@ def patch(
|
||||
return response
|
||||
|
||||
new_create = new_create_async if func_is_async else new_create_sync
|
||||
new_create.__doc__ = OVERRIDE_DOCS
|
||||
|
||||
if client is not None:
|
||||
client.chat.completions.create = new_create
|
||||
|
||||
@@ -0,0 +1,297 @@
|
||||
# type: ignore[all]
|
||||
|
||||
from collections.abc import Iterable
|
||||
from textwrap import dedent
|
||||
from instructor.dsl.iterable import IterableBase, IterableModel
|
||||
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
|
||||
from instructor.dsl.partial import PartialBase
|
||||
from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
|
||||
from instructor.function_calls import OpenAISchema, openai_schema
|
||||
|
||||
from openai.types.chat import ChatCompletion
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from typing import (
|
||||
Generator,
|
||||
Optional,
|
||||
Type,
|
||||
Tuple,
|
||||
get_args,
|
||||
get_origin,
|
||||
TypeVar,
|
||||
ParamSpec,
|
||||
Any,
|
||||
Dict,
|
||||
)
|
||||
|
||||
from instructor.mode import Mode
|
||||
|
||||
|
||||
logger = logging.getLogger("instructor")
|
||||
|
||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
T_ParamSpec = ParamSpec("T_ParamSpec")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def process_response_async(
|
||||
response: ChatCompletion,
|
||||
*,
|
||||
response_model: Type[T_Model | OpenAISchema | BaseModel],
|
||||
stream: bool = False,
|
||||
validation_context: Optional[dict] = None,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
) -> T_Model | ChatCompletion:
|
||||
"""Processes a OpenAI response with the response model, if available.
|
||||
It can use `validation_context` and `strict` to validate the response
|
||||
via the pydantic model
|
||||
|
||||
Args:
|
||||
response (ChatCompletion): The response from OpenAI's API
|
||||
response_model (BaseModel): The response model to use for parsing the response
|
||||
stream (bool): Whether the response is a stream
|
||||
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
|
||||
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
|
||||
"""
|
||||
|
||||
logger.debug(
|
||||
f"Instructor Raw Response: {response}",
|
||||
)
|
||||
if response_model is None:
|
||||
return response
|
||||
|
||||
if (
|
||||
inspect.isclass(response_model)
|
||||
and issubclass(response_model, (IterableBase, PartialBase))
|
||||
and stream
|
||||
):
|
||||
model = await response_model.from_streaming_response_async(
|
||||
response,
|
||||
mode=mode,
|
||||
)
|
||||
return model
|
||||
|
||||
model = response_model.from_response(
|
||||
response,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
# ? This really hints at the fact that we need a better way of
|
||||
# ? attaching usage data and the raw response to the model we return.
|
||||
if isinstance(model, IterableBase):
|
||||
logger.debug(f"Returning takes from IterableBase")
|
||||
return [task for task in model.tasks]
|
||||
|
||||
if isinstance(response_model, ParallelBase):
|
||||
logger.debug(f"Returning model from ParallelBase")
|
||||
return model
|
||||
|
||||
if isinstance(model, AdapterBase):
|
||||
logger.debug(f"Returning model from AdapterBase")
|
||||
return model.content
|
||||
|
||||
model._raw_response = response
|
||||
return model
|
||||
|
||||
|
||||
def process_response(
|
||||
response: T_Model,
|
||||
*,
|
||||
response_model: Type[OpenAISchema | BaseModel],
|
||||
stream: bool,
|
||||
validation_context: Optional[dict] = None,
|
||||
strict=None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
) -> T_Model | Generator[T_Model, None, None] | ChatCompletion:
|
||||
"""Processes a OpenAI response with the response model, if available.
|
||||
|
||||
Args:
|
||||
response (T): The response from OpenAI's API
|
||||
response_model (Type[T_Model]): The response model to use for parsing the response
|
||||
stream (bool): Whether the response is a stream
|
||||
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
|
||||
strict (_type_, optional): Whether to use strict json parsing. Defaults to None.
|
||||
mode (Mode, optional): The openai completion mode. Defaults to Mode.FUNCTIONS.
|
||||
|
||||
Returns:
|
||||
Union[T_Model, T]: The parsed response, if a response model is available, otherwise the response as is from the SDK
|
||||
"""
|
||||
|
||||
logger.debug(
|
||||
f"Instructor Raw Response: {response}",
|
||||
)
|
||||
|
||||
if response_model is None:
|
||||
logger.debug("No response model, returning response as is")
|
||||
return response
|
||||
|
||||
if (
|
||||
inspect.isclass(response_model)
|
||||
and issubclass(response_model, (IterableBase, PartialBase))
|
||||
and stream
|
||||
):
|
||||
model = response_model.from_streaming_response(
|
||||
response,
|
||||
mode=mode,
|
||||
)
|
||||
return model
|
||||
|
||||
model = response_model.from_response(
|
||||
response,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
# ? This really hints at the fact that we need a better way of
|
||||
# ? attaching usage data and the raw response to the model we return.
|
||||
if isinstance(model, IterableBase):
|
||||
logger.debug(f"Returning takes from IterableBase")
|
||||
return [task for task in model.tasks]
|
||||
|
||||
if isinstance(response_model, ParallelBase):
|
||||
logger.debug(f"Returning model from ParallelBase")
|
||||
return model
|
||||
|
||||
if isinstance(model, AdapterBase):
|
||||
logger.debug(f"Returning model from AdapterBase")
|
||||
return model.content
|
||||
|
||||
model._raw_response = response
|
||||
return model
|
||||
|
||||
|
||||
def handle_response_model(
|
||||
response_model: T, mode: Mode = Mode.TOOLS, **kwargs
|
||||
) -> Tuple[Type[OpenAISchema], Dict[str, Any]]:
|
||||
"""Prepare the response model type hint, and returns the response_model
|
||||
along with the new modified kwargs needed to be able to use the response_model
|
||||
parameter with the patch function.
|
||||
|
||||
|
||||
Args:
|
||||
response_model (T): The response model to use for parsing the response
|
||||
mode (Mode, optional): The openai completion mode. Defaults to Mode.TOOLS.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: When using stream=True with a non-iterable response_model
|
||||
ValueError: When using an invalid patch mode
|
||||
|
||||
Returns:
|
||||
Union[Type[OpenAISchema], dict]: The response model to use for parsing the response
|
||||
"""
|
||||
new_kwargs = kwargs.copy()
|
||||
if response_model is not None:
|
||||
# Handles the case where the response_model is a simple type
|
||||
# Literal, Annotated, Union, str, int, float, bool, Enum
|
||||
# We wrap the response_model in a ModelAdapter that sets 'content' as the response
|
||||
if is_simple_type(response_model):
|
||||
response_model = ModelAdapter[response_model]
|
||||
|
||||
# This a special case for parallel tools
|
||||
if mode == Mode.PARALLEL_TOOLS:
|
||||
assert (
|
||||
new_kwargs.get("stream", False) is False
|
||||
), "stream=True is not supported when using PARALLEL_TOOLS mode"
|
||||
new_kwargs["tools"] = handle_parallel_model(response_model)
|
||||
new_kwargs["tool_choice"] = "auto"
|
||||
|
||||
# This is a special case for parallel models
|
||||
response_model = ParallelModel(typehint=response_model)
|
||||
return response_model, new_kwargs
|
||||
|
||||
# This is for all other single model cases
|
||||
if get_origin(response_model) is Iterable:
|
||||
iterable_element_class = get_args(response_model)[0]
|
||||
response_model = IterableModel(iterable_element_class)
|
||||
if not issubclass(response_model, OpenAISchema):
|
||||
response_model = openai_schema(response_model) # type: ignore
|
||||
|
||||
if new_kwargs.get("stream", False) and not issubclass(
|
||||
response_model, (IterableBase, PartialBase)
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"stream=True is not supported when using response_model parameter for non-iterables"
|
||||
)
|
||||
|
||||
if mode == Mode.FUNCTIONS:
|
||||
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
|
||||
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
|
||||
elif mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}:
|
||||
new_kwargs["tools"] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": response_model.openai_schema,
|
||||
}
|
||||
]
|
||||
if mode == Mode.MISTRAL_TOOLS:
|
||||
new_kwargs["tool_choice"] = "any"
|
||||
else:
|
||||
new_kwargs["tool_choice"] = {
|
||||
"type": "function",
|
||||
"function": {"name": response_model.openai_schema["name"]},
|
||||
}
|
||||
elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}:
|
||||
# If its a JSON Mode we need to massage the prompt a bit
|
||||
# in order to get the response we want in a json format
|
||||
message = dedent(
|
||||
f"""
|
||||
As a genius expert, your task is to understand the content and provide
|
||||
the parsed objects in json that match the following json_schema:\n
|
||||
|
||||
{response_model.model_json_schema()}
|
||||
|
||||
Make sure to return an instance of the JSON, not the schema itself
|
||||
"""
|
||||
)
|
||||
|
||||
if mode == Mode.JSON:
|
||||
new_kwargs["response_format"] = {"type": "json_object"}
|
||||
|
||||
elif mode == Mode.JSON_SCHEMA:
|
||||
new_kwargs["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": response_model.model_json_schema(),
|
||||
}
|
||||
|
||||
elif mode == Mode.MD_JSON:
|
||||
new_kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA",
|
||||
},
|
||||
)
|
||||
# check that the first message is a system message
|
||||
# if it is not, add a system message to the beginning
|
||||
if new_kwargs["messages"][0]["role"] != "system":
|
||||
new_kwargs["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": message,
|
||||
},
|
||||
)
|
||||
# if it is, system append the schema to the end
|
||||
else:
|
||||
new_kwargs["messages"][0]["content"] += f"\n\n{message}"
|
||||
else:
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
logger.debug(
|
||||
f"Instructor Request: {mode.value=}, {response_model=}, {new_kwargs=}",
|
||||
extra={
|
||||
"mode": mode.value,
|
||||
"response_model": response_model.__name__
|
||||
if response_model is not None
|
||||
else None,
|
||||
"new_kwargs": new_kwargs,
|
||||
},
|
||||
)
|
||||
return response_model, new_kwargs
|
||||
@@ -0,0 +1,143 @@
|
||||
# type: ignore[all]
|
||||
import logging
|
||||
|
||||
from openai.types.chat import ChatCompletion
|
||||
from instructor.mode import Mode
|
||||
from instructor.process_response import process_response, process_response_async
|
||||
from instructor.utils import dump_message, update_total_usage
|
||||
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from pydantic import ValidationError
|
||||
from tenacity import AsyncRetrying, RetryError, Retrying, stop_after_attempt
|
||||
|
||||
|
||||
from json import JSONDecodeError
|
||||
from pydantic import BaseModel
|
||||
from typing import Callable, Optional, Type, TypeVar, ParamSpec
|
||||
|
||||
logger = logging.getLogger("instructor")
|
||||
|
||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
T_ParamSpec = ParamSpec("T_ParamSpec")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def reask_messages(response: ChatCompletion, mode: Mode, exception: Exception):
|
||||
yield dump_message(response.choices[0].message)
|
||||
|
||||
if mode == Mode.TOOLS:
|
||||
for tool_call in response.choices[0].message.tool_calls: # type: ignore
|
||||
yield {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call.function.name,
|
||||
"content": f"Validation Error found:\n{exception}\nRecall the function correctly, fix the errors",
|
||||
}
|
||||
|
||||
# TODO: Give users more control on configuration
|
||||
if mode == Mode.MD_JSON:
|
||||
yield {
|
||||
"role": "user",
|
||||
"content": f"Correct your JSON ONLY RESPONSE, based on the following errors:\n{exception}",
|
||||
}
|
||||
else:
|
||||
yield {
|
||||
"role": "user",
|
||||
"content": f"Recall the function correctly, fix the errors, exceptions found\n{exception}",
|
||||
}
|
||||
|
||||
|
||||
def retry_sync(
|
||||
func: Callable[T_ParamSpec, T_Retval],
|
||||
response_model: Type[T_Model],
|
||||
validation_context: dict,
|
||||
args,
|
||||
kwargs,
|
||||
max_retries: int | Retrying = 1,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
) -> T_Model:
|
||||
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
|
||||
|
||||
# If max_retries is int, then create a Retrying object
|
||||
if isinstance(max_retries, int):
|
||||
logger.debug(f"max_retries: {max_retries}")
|
||||
max_retries = Retrying(
|
||||
stop=stop_after_attempt(max_retries),
|
||||
reraise=True,
|
||||
)
|
||||
if not isinstance(max_retries, (Retrying, AsyncRetrying)):
|
||||
raise ValueError("max_retries must be an int or a `tenacity.Retrying` object")
|
||||
|
||||
try:
|
||||
for attempt in max_retries:
|
||||
with attempt:
|
||||
try:
|
||||
response = func(*args, **kwargs)
|
||||
stream = kwargs.get("stream", False)
|
||||
response = update_total_usage(response, total_usage)
|
||||
return process_response(
|
||||
response,
|
||||
response_model=response_model,
|
||||
stream=stream,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
logger.debug(f"Error response: {response}")
|
||||
kwargs["messages"].extend(reask_messages(response, mode, e))
|
||||
raise e
|
||||
except RetryError as e:
|
||||
logger.exception(f"Failed after retries: {e.last_attempt.exception}")
|
||||
raise e.last_attempt.exception from e
|
||||
|
||||
|
||||
async def retry_async(
|
||||
func: Callable[T_ParamSpec, T_Retval],
|
||||
response_model: Type[T],
|
||||
validation_context,
|
||||
args,
|
||||
kwargs,
|
||||
max_retries: int | AsyncRetrying = 1,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
) -> T:
|
||||
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
|
||||
|
||||
# If max_retries is int, then create a AsyncRetrying object
|
||||
if isinstance(max_retries, int):
|
||||
logger.debug(f"max_retries: {max_retries}")
|
||||
max_retries = AsyncRetrying(
|
||||
stop=stop_after_attempt(max_retries),
|
||||
reraise=True,
|
||||
)
|
||||
if not isinstance(max_retries, (AsyncRetrying, Retrying)):
|
||||
raise ValueError(
|
||||
"max_retries must be an `int` or a `tenacity.AsyncRetrying` object"
|
||||
)
|
||||
|
||||
try:
|
||||
async for attempt in max_retries:
|
||||
logger.debug(f"Retrying, attempt: {attempt}")
|
||||
with attempt:
|
||||
try:
|
||||
response: ChatCompletion = await func(*args, **kwargs) # type: ignore
|
||||
stream = kwargs.get("stream", False)
|
||||
response = update_total_usage(response, total_usage)
|
||||
return await process_response_async(
|
||||
response,
|
||||
response_model=response_model,
|
||||
stream=stream,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
) # type: ignore[all]
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
logger.debug(f"Error response: {response}", e)
|
||||
kwargs["messages"].extend(reask_messages(response, mode, e))
|
||||
raise e
|
||||
except RetryError as e:
|
||||
logger.exception(f"Failed after retries: {e.last_attempt.exception}")
|
||||
raise e.last_attempt.exception from e
|
||||
+16
-2
@@ -1,5 +1,8 @@
|
||||
import inspect
|
||||
import json
|
||||
from typing import Generator, Iterable, AsyncGenerator
|
||||
from typing import Callable, Generator, Iterable, AsyncGenerator, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
@@ -7,6 +10,8 @@ from openai.types.chat import (
|
||||
ChatCompletionMessageParam,
|
||||
)
|
||||
|
||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||
|
||||
|
||||
def extract_json_from_codeblock(content: str) -> str:
|
||||
first_paren = content.find("{")
|
||||
@@ -54,7 +59,7 @@ async def extract_json_from_stream_async(
|
||||
yield char
|
||||
|
||||
|
||||
def update_total_usage(response, total_usage):
|
||||
def update_total_usage(response: T_Model, total_usage) -> T_Model | ChatCompletion:
|
||||
if isinstance(response, ChatCompletion) and response.usage is not None:
|
||||
total_usage.completion_tokens += response.usage.completion_tokens or 0
|
||||
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
|
||||
@@ -81,3 +86,12 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
):
|
||||
ret["content"] += json.dumps(message.model_dump()["function_call"])
|
||||
return ret
|
||||
|
||||
|
||||
def is_async(func: Callable) -> bool:
|
||||
"""Returns true if the callable is async, accounting for wrapped callables"""
|
||||
is_coroutine = inspect.iscoroutinefunction(func)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
is_coroutine = is_coroutine or inspect.iscoroutinefunction(func)
|
||||
return is_coroutine
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "instructor"
|
||||
version = "0.6.3"
|
||||
version = "0.6.4"
|
||||
description = "structured outputs for llm"
|
||||
authors = ["Jason Liu <jason@jxnl.co>"]
|
||||
license = "MIT"
|
||||
|
||||
+1
-7
@@ -3,7 +3,7 @@ import functools
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
|
||||
import instructor
|
||||
from instructor.patch import OVERRIDE_DOCS, is_async
|
||||
from instructor.utils import is_async
|
||||
|
||||
|
||||
def test_patch_completes_successfully():
|
||||
@@ -71,9 +71,3 @@ def test_is_async_returns_true_if_triple_wrapped_function_is_async():
|
||||
pass
|
||||
|
||||
assert is_async(triple_wrapped_function) is True
|
||||
|
||||
|
||||
def test_override_docs():
|
||||
assert (
|
||||
"response_model" in OVERRIDE_DOCS
|
||||
), "response_model should be in OVERRIDE_DOCS"
|
||||
|
||||
Reference in New Issue
Block a user