mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
636 lines
23 KiB
Python
636 lines
23 KiB
Python
import inspect
|
|
import json
|
|
import logging
|
|
from textwrap import dedent
|
|
from collections.abc import Iterable
|
|
from functools import wraps
|
|
from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError
|
|
from json import JSONDecodeError
|
|
from typing import (
|
|
Callable,
|
|
Optional,
|
|
ParamSpec,
|
|
Protocol,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
get_args,
|
|
get_origin,
|
|
overload,
|
|
)
|
|
|
|
from openai import AsyncOpenAI, OpenAI
|
|
from openai.types.chat import (
|
|
ChatCompletion,
|
|
ChatCompletionMessage,
|
|
ChatCompletionMessageParam,
|
|
)
|
|
from openai.types.completion_usage import CompletionUsage
|
|
from pydantic import BaseModel, ValidationError
|
|
|
|
from instructor.dsl.iterable import IterableModel, IterableBase
|
|
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
|
|
from instructor.dsl.partial import PartialBase
|
|
from instructor.dsl.simple_type import ModelAdapter, AdapterBase, is_simple_type
|
|
|
|
from .function_calls import Mode, OpenAISchema, openai_schema
|
|
|
|
# Shared "instructor" logger used by every helper in this module.
logger = logging.getLogger("instructor")

# Generic type variables used by the signatures below. ``T`` was previously
# declared twice; a single declaration is enough (and avoids a redefinition
# error from static type checkers).
T = TypeVar("T")
T_Model = TypeVar("T_Model", bound=BaseModel)
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")
|
|
|
|
|
|
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
    """Turn an SDK chat message back into a request-ready message dict.

    Works around an OpenAI API quirk: a ``tool_calls`` key is rejected on request
    messages that do not use it, so the key is only added when the message
    actually carries tool calls. A legacy ``function_call`` is JSON-encoded and
    appended to the ``content`` string rather than sent as its own field.
    """
    dumped: ChatCompletionMessageParam = {
        "role": message.role,
        "content": message.content or "",
    }

    has_tool_calls = getattr(message, "tool_calls", None) is not None
    if has_tool_calls:
        dumped["tool_calls"] = message.model_dump()["tool_calls"]

    has_function_call = getattr(message, "function_call", None) is not None
    if has_function_call:
        dumped["content"] += json.dumps(message.model_dump()["function_call"])

    return dumped
|
|
|
|
|
|
def handle_response_model(
    response_model: T, mode: Mode = Mode.TOOLS, **kwargs
) -> tuple[Union[Type[OpenAISchema], ParallelBase, None], dict]:
    """Prepare the response model and the request kwargs for a patched call.

    Converts ``response_model`` into something that can parse the completion
    (simple types are wrapped in a ``ModelAdapter``, ``Iterable[X]`` becomes an
    ``IterableModel``, plain models go through ``openai_schema``) and returns a
    copy of ``kwargs`` extended with whatever the chosen ``mode`` needs
    (``functions``/``function_call``, ``tools``/``tool_choice``,
    ``response_format``, extra prompt messages, ``stop``).

    The caller's ``kwargs`` and its ``messages`` list are never mutated: the
    returned kwargs hold their own copy of the message list, and an existing
    system message is replaced by an updated copy rather than edited in place.

    Args:
        response_model (T): The response model to use for parsing the response.
            When ``None`` the kwargs are returned (copied) without changes.
        mode (Mode, optional): The openai completion mode. Defaults to Mode.TOOLS.
        **kwargs: The keyword arguments destined for ``create``.

    Raises:
        AssertionError: When using stream=True in PARALLEL_TOOLS mode
            (only while assertions are enabled).
        NotImplementedError: When using stream=True with a non-iterable response_model.
        ValueError: When using an invalid patch mode.

    Returns:
        tuple: ``(response_model, new_kwargs)`` — the model to parse the response
        with (a ``ParallelModel`` instance in PARALLEL_TOOLS mode) and the
        kwargs to call ``create`` with.
    """
    new_kwargs = kwargs.copy()
    # ``kwargs.copy()`` is shallow: give the new kwargs their own message list so
    # the messages inserted/appended below (and any re-ask messages appended on
    # retries) never leak into the list the caller passed in.
    if isinstance(new_kwargs.get("messages"), list):
        new_kwargs["messages"] = list(new_kwargs["messages"])

    if response_model is None:
        return response_model, new_kwargs

    # Handles the case where the response_model is a simple type
    # Literal, Annotated, Union, str, int, float, bool, Enum
    # We wrap the response_model in a ModelAdapter that sets 'content' as the response
    if is_simple_type(response_model):
        response_model = ModelAdapter[response_model]

    # This a special case for parallel tools
    if mode == Mode.PARALLEL_TOOLS:
        assert (
            new_kwargs.get("stream", False) is False
        ), "stream=True is not supported when using PARALLEL_TOOLS mode"
        new_kwargs["tools"] = handle_parallel_model(response_model)
        new_kwargs["tool_choice"] = "auto"

        # This is a special case for parallel models
        response_model = ParallelModel(typehint=response_model)
        return response_model, new_kwargs

    # This is for all other single model cases
    if get_origin(response_model) is Iterable:
        iterable_element_class = get_args(response_model)[0]
        response_model = IterableModel(iterable_element_class)
    if not issubclass(response_model, OpenAISchema):
        response_model = openai_schema(response_model)  # type: ignore

    if new_kwargs.get("stream", False) and not issubclass(
        response_model, (IterableBase, PartialBase)
    ):
        raise NotImplementedError(
            "stream=True is not supported when using response_model parameter for non-iterables"
        )

    if mode == Mode.FUNCTIONS:
        new_kwargs["functions"] = [response_model.openai_schema]  # type: ignore
        new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]}  # type: ignore
    elif mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}:
        new_kwargs["tools"] = [
            {
                "type": "function",
                "function": response_model.openai_schema,
            }
        ]
        if mode == Mode.MISTRAL_TOOLS:
            new_kwargs["tool_choice"] = "any"
        else:
            new_kwargs["tool_choice"] = {
                "type": "function",
                "function": {"name": response_model.openai_schema["name"]},
            }
    elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}:
        # Compute the schema once; it is used for the prompt and JSON_SCHEMA mode.
        schema = response_model.model_json_schema()
        # If its a JSON Mode we need to massage the prompt a bit
        # in order to get the response we want in a json format
        message = dedent(
            f"""
            As a genius expert, your task is to understand the content and provide
            the parsed objects in json that match the following json_schema:\n
            {schema['properties']}
            """
        )
        # Check for nested models
        if "$defs" in schema:
            message += f"\nHere are some more definitions to adhere too:\n{schema['$defs']}"

        if mode == Mode.JSON:
            new_kwargs["response_format"] = {"type": "json_object"}
        elif mode == Mode.JSON_SCHEMA:
            new_kwargs["response_format"] = {
                "type": "json_object",
                "schema": schema,
            }
        elif mode == Mode.MD_JSON:
            new_kwargs["messages"].append(
                {
                    "role": "assistant",
                    "content": "Here is the perfectly correctly formatted JSON\n```json",
                },
            )
            new_kwargs["stop"] = "```"

        messages = new_kwargs["messages"]
        # Make sure the conversation starts with a system message carrying the
        # schema prompt; an empty conversation simply gets a new one.
        if not messages or messages[0]["role"] != "system":
            messages.insert(
                0,
                {
                    "role": "system",
                    "content": message,
                },
            )
        else:
            # Replace the caller's system message dict with an updated copy
            # instead of appending to it in place.
            messages[0] = {
                **messages[0],
                "content": messages[0]["content"] + f"\n\n{message}",
            }
    else:
        raise ValueError(f"Invalid patch mode: {mode}")

    return response_model, new_kwargs
|
|
|
|
|
|
def process_response(
    response: T,
    *,
    response_model: Type[T_Model],
    stream: bool,
    validation_context: Optional[dict] = None,
    strict: Optional[bool] = None,
    mode: Mode = Mode.TOOLS,
) -> Union[T_Model, T]:
    """Processes an OpenAI response with the response model, if available.

    Args:
        response (T): The response from OpenAI's API (a stream when ``stream``).
        response_model (Type[T_Model]): The response model to use for parsing the
            response. When ``None`` the response is returned unchanged.
        stream (bool): Whether the response is a stream.
        validation_context (dict, optional): The validation context to use for
            validating the response. Defaults to None.
        strict (bool, optional): Forwarded to ``from_response`` as its ``strict``
            flag. Defaults to None.
        mode (Mode, optional): The openai completion mode. Defaults to Mode.TOOLS.

    Returns:
        Union[T_Model, T]:
            - the raw response when ``response_model`` is None;
            - the result of ``from_streaming_response`` for streamed
              Iterable/Partial models;
            - a list of the parsed tasks for Iterable models;
            - the parsed value as-is for parallel models;
            - ``model.content`` for simple types wrapped by an adapter;
            - otherwise the parsed model, with the raw response attached as
              ``_raw_response``.
    """
    if response_model is None:
        return response

    if (
        inspect.isclass(response_model)
        and issubclass(response_model, (IterableBase, PartialBase))
        and stream
    ):
        model = response_model.from_streaming_response(
            response,
            mode=mode,
        )
        return model

    model = response_model.from_response(
        response,
        validation_context=validation_context,
        strict=strict,
        mode=mode,
    )

    # ? This really hints at the fact that we need a better way of
    # ? attaching usage data and the raw response to the model we return.
    if isinstance(model, IterableBase):
        logger.debug("Returning tasks from IterableBase")
        return list(model.tasks)

    # ``response_model`` (not ``model``) is checked here: in PARALLEL_TOOLS mode
    # the response model itself is a ParallelModel instance.
    if isinstance(response_model, ParallelBase):
        logger.debug("Returning model from ParallelBase")
        return model

    if isinstance(model, AdapterBase):
        logger.debug("Returning model from AdapterBase")
        return model.content

    model._raw_response = response
    return model
|
|
|
|
|
|
async def process_response_async(
    response: ChatCompletion,
    *,
    response_model: Type[T_Model],
    stream: bool = False,
    validation_context: Optional[dict] = None,
    strict: Optional[bool] = None,
    mode: Mode = Mode.TOOLS,
) -> T:
    """Processes an OpenAI response with the response model, if available.

    Async variant: streamed Iterable/Partial models are consumed with
    ``from_streaming_response_async``. It can use `validation_context` and
    `strict` to validate the response via the pydantic model.

    Args:
        response (ChatCompletion): The response from OpenAI's API (an async
            stream when ``stream``).
        response_model (BaseModel): The response model to use for parsing the
            response. When ``None`` the response is returned unchanged.
        stream (bool): Whether the response is a stream. Defaults to False.
        validation_context (dict, optional): The validation context to use for
            validating the response. Defaults to None.
        strict (bool, optional): Forwarded to ``from_response`` as its ``strict``
            flag. Defaults to None.
        mode (Mode, optional): The openai completion mode. Defaults to Mode.TOOLS.

    Returns:
        The raw response when ``response_model`` is None; the streamed result for
        Iterable/Partial models; a list of tasks for Iterable models; the parsed
        value for parallel models; ``model.content`` for adapter-wrapped simple
        types; otherwise the parsed model with ``_raw_response`` attached.
    """
    if response_model is None:
        return response

    if (
        inspect.isclass(response_model)
        and issubclass(response_model, (IterableBase, PartialBase))
        and stream
    ):
        model = await response_model.from_streaming_response_async(
            response,
            mode=mode,
        )
        return model

    model = response_model.from_response(
        response,
        validation_context=validation_context,
        strict=strict,
        mode=mode,
    )

    # ? This really hints at the fact that we need a better way of
    # ? attaching usage data and the raw response to the model we return.
    if isinstance(model, IterableBase):
        logger.debug("Returning tasks from IterableBase")
        return list(model.tasks)

    # ``response_model`` (not ``model``) is checked here: in PARALLEL_TOOLS mode
    # the response model itself is a ParallelModel instance.
    if isinstance(response_model, ParallelBase):
        logger.debug("Returning model from ParallelBase")
        return model

    if isinstance(model, AdapterBase):
        logger.debug("Returning model from AdapterBase")
        return model.content

    model._raw_response = response
    return model
|
|
|
|
|
|
async def retry_async(
    func: Callable[T_ParamSpec, T_Retval],
    response_model: Type[T],
    validation_context,
    args,
    kwargs,
    max_retries: int | AsyncRetrying = 1,
    strict: Optional[bool] = None,
    mode: Mode = Mode.TOOLS,
) -> T:
    """Await ``func`` and parse its response, re-asking the model on parse errors.

    Each attempt awaits ``func(*args, **kwargs)`` and parses the result with
    ``process_response_async``. When parsing raises ``ValidationError`` or
    ``JSONDecodeError``, the failed assistant message plus an error description
    are appended to the conversation and the attempt is retried according to
    ``max_retries``. Token usage from all attempts is accumulated into the
    ``usage`` of the latest ``ChatCompletion``.

    Args:
        func: The (unpatched) async completion function.
        response_model: Model used to parse the response (may be None).
        validation_context: Forwarded to ``process_response_async``.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func``. Re-ask messages are appended to a
            private copy of ``kwargs["messages"]``; the caller's list is untouched.
        max_retries: Total number of attempts, or a ``tenacity.AsyncRetrying``.
        strict: Forwarded to ``process_response_async``.
        mode: The completion mode; decides how the re-ask message is phrased.

    Returns:
        The parsed response from ``process_response_async``.

    Raises:
        ValueError: If ``max_retries`` is neither an int nor an ``AsyncRetrying``.
        Exception: The last attempt's exception once retries are exhausted.
    """
    total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)

    # If max_retries is int, then create a AsyncRetrying object
    if isinstance(max_retries, int):
        logger.debug(f"max_retries: {max_retries}")
        max_retries = AsyncRetrying(
            stop=stop_after_attempt(max_retries),
            reraise=True,
        )
    # A synchronous ``Retrying`` cannot drive the ``async for`` loop below.
    if not isinstance(max_retries, AsyncRetrying):
        raise ValueError(
            "max_retries must be an `int` or a `tenacity.AsyncRetrying` object"
        )

    # Work on a private copy of the conversation so re-ask messages never leak
    # into the caller's ``messages`` list.
    if isinstance(kwargs.get("messages"), list):
        kwargs = {**kwargs, "messages": list(kwargs["messages"])}

    stream = kwargs.get("stream", False)

    try:
        async for attempt in max_retries:
            logger.debug(f"Retrying, attempt: {attempt}")
            with attempt:
                # Called outside the parse ``try`` so a failing request can never
                # leave ``response`` unbound or pointing at a previous attempt.
                response: ChatCompletion = await func(*args, **kwargs)
                if isinstance(response, ChatCompletion) and response.usage is not None:
                    total_usage.completion_tokens += response.usage.completion_tokens or 0
                    total_usage.prompt_tokens += response.usage.prompt_tokens or 0
                    total_usage.total_tokens += response.usage.total_tokens or 0
                    # Replace each response usage with the total usage
                    response.usage = total_usage
                try:
                    return await process_response_async(
                        response,
                        response_model=response_model,
                        stream=stream,
                        validation_context=validation_context,
                        strict=strict,
                        mode=mode,
                    )
                except (ValidationError, JSONDecodeError) as e:
                    logger.debug(f"Error response: {response}")
                    kwargs["messages"].append(dump_message(response.choices[0].message))  # type: ignore
                    tool_calls = (
                        response.choices[0].message.tool_calls
                        if mode == Mode.TOOLS
                        else None
                    )
                    # Answer the (first) tool call when there is one; otherwise
                    # fall back to a plain user message.
                    if tool_calls:
                        kwargs["messages"].append(
                            {
                                "role": "tool",
                                "tool_call_id": tool_calls[0].id,
                                "name": tool_calls[0].function.name,
                                "content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
                            }
                        )
                    else:
                        kwargs["messages"].append(
                            {
                                "role": "user",
                                "content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
                            }
                        )
                    if mode == Mode.MD_JSON:
                        kwargs["messages"].append(
                            {
                                "role": "assistant",
                                "content": "```json",
                            },
                        )
                    raise
    except RetryError as e:
        # ``last_attempt`` is a Future; ``exception()`` must be called to get
        # the exception object (the bare attribute is a bound method).
        last_exc = e.last_attempt.exception()
        logger.exception(f"Failed after retries: {last_exc}")
        if last_exc is None:
            raise
        raise last_exc from e
|
|
|
|
|
|
def retry_sync(
    func: Callable[T_ParamSpec, T_Retval],
    response_model: Type[T],
    validation_context: dict,
    args,
    kwargs,
    max_retries: int | Retrying = 1,
    strict: Optional[bool] = None,
    mode: Mode = Mode.TOOLS,
):
    """Call ``func`` and parse its response, re-asking the model on parse errors.

    Each attempt calls ``func(*args, **kwargs)`` and parses the result with
    ``process_response``. When parsing raises ``ValidationError`` or
    ``JSONDecodeError``, the failed assistant message plus an error description
    are appended to the conversation and the attempt is retried according to
    ``max_retries``. Token usage from all attempts is accumulated into the
    ``usage`` of the latest ``ChatCompletion``.

    Args:
        func: The (unpatched) completion function.
        response_model: Model used to parse the response (may be None).
        validation_context: Forwarded to ``process_response``.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func``. Re-ask messages are appended to a
            private copy of ``kwargs["messages"]``; the caller's list is untouched.
        max_retries: Total number of attempts, or a ``tenacity.Retrying`` object.
        strict: Forwarded to ``process_response``.
        mode: The completion mode; decides how the re-ask message is phrased.

    Returns:
        The parsed response from ``process_response``.

    Raises:
        ValueError: If ``max_retries`` is neither an int nor a ``Retrying``.
        Exception: The last attempt's exception once retries are exhausted.
    """
    total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)

    # If max_retries is int, then create a Retrying object
    if isinstance(max_retries, int):
        logger.debug(f"max_retries: {max_retries}")
        max_retries = Retrying(
            stop=stop_after_attempt(max_retries),
            reraise=True,
        )
    # An ``AsyncRetrying`` cannot drive the synchronous ``for`` loop below.
    if not isinstance(max_retries, Retrying):
        raise ValueError("max_retries must be an int or a `tenacity.Retrying` object")

    # Work on a private copy of the conversation so re-ask messages never leak
    # into the caller's ``messages`` list.
    if isinstance(kwargs.get("messages"), list):
        kwargs = {**kwargs, "messages": list(kwargs["messages"])}

    stream = kwargs.get("stream", False)

    try:
        for attempt in max_retries:
            with attempt:
                # Called outside the parse ``try`` so a failing request can never
                # leave ``response`` unbound or pointing at a previous attempt.
                response = func(*args, **kwargs)
                if isinstance(response, ChatCompletion) and response.usage is not None:
                    total_usage.completion_tokens += response.usage.completion_tokens or 0
                    total_usage.prompt_tokens += response.usage.prompt_tokens or 0
                    total_usage.total_tokens += response.usage.total_tokens or 0
                    # Replace each response usage with the total usage
                    response.usage = total_usage
                try:
                    return process_response(
                        response,
                        response_model=response_model,
                        stream=stream,
                        validation_context=validation_context,
                        strict=strict,
                        mode=mode,
                    )
                except (ValidationError, JSONDecodeError) as e:
                    logger.debug(f"Error response: {response}")
                    kwargs["messages"].append(dump_message(response.choices[0].message))
                    # ! How do we handle this for parallel tools in the future?
                    tool_calls = (
                        response.choices[0].message.tool_calls
                        if mode == Mode.TOOLS
                        else None
                    )
                    # Answer the (first) tool call when there is one; otherwise
                    # fall back to a plain user message.
                    if tool_calls:
                        kwargs["messages"].append(
                            {
                                "role": "tool",
                                "tool_call_id": tool_calls[0].id,
                                "name": tool_calls[0].function.name,
                                "content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
                            }
                        )
                    else:
                        kwargs["messages"].append(
                            {
                                "role": "user",
                                "content": f"Recall the function correctly, fix the errors and exceptions found\n{e}",
                            }
                        )
                    if mode == Mode.MD_JSON:
                        kwargs["messages"].append(
                            {
                                "role": "assistant",
                                "content": "```json",
                            },
                        )
                    raise
    except RetryError as e:
        # ``last_attempt`` is a Future; ``exception()`` must be called to get
        # the exception object (the bare attribute is a bound method).
        last_exc = e.last_attempt.exception()
        logger.exception(f"Failed after retries: {last_exc}")
        if last_exc is None:
            raise
        raise last_exc from e
|
|
|
|
|
|
def is_async(func: Callable) -> bool:
    """Report whether ``func``, or anything it wraps, is a coroutine function.

    Decorators built with ``functools.wraps`` expose the decorated callable via
    ``__wrapped__``. The whole chain is collected first, so an async function
    hidden behind a synchronous-looking wrapper is still detected.
    """
    chain = [func]
    while hasattr(chain[-1], "__wrapped__"):
        chain.append(chain[-1].__wrapped__)
    return any(inspect.iscoroutinefunction(candidate) for candidate in chain)
|
|
|
|
|
|
# Docstring installed on the patched ``create`` function by ``patch``.
OVERRIDE_DOCS = """
Creates a new chat completion for the provided messages and parameters.

See: https://platform.openai.com/docs/api-reference/chat-completions/create

Additional Notes:

Using the `response_model` parameter, you can specify a response model to use for parsing the response from OpenAI's API. If it's present, the response will be parsed using the response model, otherwise it will be returned as is.

If `stream=True` is specified, the response model must be an `Iterable[...]` or partial model; the stream is then parsed with its `from_streaming_response` method. Streaming with any other response model raises `NotImplementedError`. Without streaming, the response is parsed with the `from_response` method.

If you need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts. Iterable models and simple types are returned as a plain list or value and carry no `_raw_response`.

Parameters:
    response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
    max_retries (int): The total number of attempts to make if the response is not valid, or a tenacity `Retrying`/`AsyncRetrying` object (default: 1)
    validation_context (dict): The validation context to use for validating the response (default: None)
"""
|
|
|
|
|
|
class InstructorChatCompletionCreate(Protocol):
    """Call signature of a ``chat.completions.create`` wrapped by ``patch``.

    On top of the wrapped function's own parameters it takes ``response_model``,
    ``validation_context`` and ``max_retries``, and returns the parsed model.
    """

    def __call__(
        self,
        response_model: Type[T_Model] = None,
        validation_context: dict = None,
        max_retries: int = 1,
        *args: T_ParamSpec.args,
        **kwargs: T_ParamSpec.kwargs,
    ) -> T_Model:
        """Create a chat completion and return it parsed as ``response_model``."""
|
|
|
|
|
|
@overload
def patch(
    client: OpenAI,
    mode: Mode = Mode.TOOLS,
) -> OpenAI:
    """Patch a synchronous client's ``chat.completions.create`` and return the client."""


@overload
def patch(
    client: AsyncOpenAI,
    mode: Mode = Mode.TOOLS,
) -> AsyncOpenAI:
    """Patch an asynchronous client's ``chat.completions.create`` and return the client."""


@overload
def patch(
    create: Callable[T_ParamSpec, T_Retval],
    mode: Mode = Mode.TOOLS,
) -> InstructorChatCompletionCreate:
    """Wrap a bare ``create`` callable and return the patched callable."""
|
|
|
|
|
|
def patch(
    client: Union[OpenAI, AsyncOpenAI] = None,
    create: Callable[T_ParamSpec, T_Retval] = None,
    mode: Mode = Mode.TOOLS,
) -> Union[OpenAI, AsyncOpenAI, InstructorChatCompletionCreate]:
    """
    Patch the `client.chat.completions.create` method

    Enables the following features:

    - `response_model` parameter to parse the response from OpenAI's API
    - `max_retries` parameter to retry the function if the response is not valid
    - `validation_context` parameter to validate the response using the pydantic model
    - `strict` parameter to use strict json parsing (consumed by the patch, not
      forwarded to the underlying `create`)

    Args:
        client: An OpenAI/AsyncOpenAI client whose `chat.completions.create` is
            replaced in place. Ignored for wrapping when `create` is given.
        create: A bare `create` callable to wrap instead of a client method.
        mode: The completion mode used to build requests and parse responses.

    Returns:
        The patched client when `client` is given, otherwise the wrapped callable.

    Raises:
        ValueError: If neither `client` nor `create` is provided.
    """

    logger.debug(f"Patching `client.chat.completions.create` with {mode=}")

    if create is not None:
        func = create
    elif client is not None:
        func = client.chat.completions.create
    else:
        raise ValueError("Either client or create must be provided")

    func_is_async = is_async(func)

    @wraps(func)
    async def new_create_async(
        response_model: Type[T_Model] = None,
        validation_context: dict = None,
        max_retries: int = 1,
        *args: T_ParamSpec.args,
        **kwargs: T_ParamSpec.kwargs,
    ) -> T_Model:
        # ``strict`` is an instructor option; pop it so it is used for parsing
        # rather than forwarded to the underlying ``create``.
        strict = kwargs.pop("strict", None)
        response_model, new_kwargs = handle_response_model(
            response_model=response_model, mode=mode, **kwargs
        )
        response = await retry_async(
            func=func,
            response_model=response_model,
            validation_context=validation_context,
            max_retries=max_retries,
            args=args,
            kwargs=new_kwargs,
            strict=strict,
            mode=mode,
        )  # type: ignore
        return response

    @wraps(func)
    def new_create_sync(
        response_model: Type[T_Model] = None,
        validation_context: dict = None,
        max_retries: int = 1,
        *args: T_ParamSpec.args,
        **kwargs: T_ParamSpec.kwargs,
    ) -> T_Model:
        # ``strict`` is an instructor option; pop it so it is used for parsing
        # rather than forwarded to the underlying ``create``.
        strict = kwargs.pop("strict", None)
        response_model, new_kwargs = handle_response_model(
            response_model=response_model, mode=mode, **kwargs
        )
        response = retry_sync(
            func=func,
            response_model=response_model,
            validation_context=validation_context,
            max_retries=max_retries,
            args=args,
            kwargs=new_kwargs,
            strict=strict,
            mode=mode,
        )
        return response

    new_create = new_create_async if func_is_async else new_create_sync
    new_create.__doc__ = OVERRIDE_DOCS

    if client is not None:
        client.chat.completions.create = new_create
        return client
    else:
        return new_create
|
|
|
|
|
|
def apatch(client: AsyncOpenAI, mode: Mode = Mode.TOOLS):
    """
    Deprecated: use `patch`, which now handles async clients as well.

    Emits a `DeprecationWarning` and delegates to `patch`, which replaces
    `client.chat.completions.create` and enables:

    - `response_model` parameter to parse the response from OpenAI's API
    - `max_retries` parameter to retry the function if the response is not valid
    - `validation_context` parameter to validate the response using the pydantic model
    - `strict` parameter to use strict json parsing
    """
    import warnings

    deprecation_note = "apatch is deprecated, use patch instead"
    warnings.warn(deprecation_note, DeprecationWarning, stacklevel=2)
    patched_client = patch(client, mode=mode)
    return patched_client
|