diff --git a/instructor/__init__.py b/instructor/__init__.py index b5cef02..80d2c86 100644 --- a/instructor/__init__.py +++ b/instructor/__init__.py @@ -1,3 +1,5 @@ +from .mode import Mode +from .process_response import handle_response_model from .distil import FinetuneFormat, Instructions from .dsl import ( CitationMixin, @@ -7,8 +9,9 @@ from .dsl import ( llm_validator, openai_moderation, ) -from .function_calls import OpenAISchema, openai_schema, Mode -from .patch import apatch, patch, handle_parallel_model, handle_response_model +from .function_calls import OpenAISchema, openai_schema +from .patch import apatch, patch +from .process_response import handle_parallel_model __all__ = [ "OpenAISchema", diff --git a/instructor/dsl/iterable.py b/instructor/dsl/iterable.py index 19c14da..5f9d556 100644 --- a/instructor/dsl/iterable.py +++ b/instructor/dsl/iterable.py @@ -2,7 +2,8 @@ from typing import Any, AsyncGenerator, Generator, Iterable, List, Optional, Tup from pydantic import BaseModel, Field, create_model -from instructor.function_calls import OpenAISchema, Mode +from instructor.function_calls import OpenAISchema +from instructor.mode import Mode from instructor.utils import extract_json_from_stream, extract_json_from_stream_async diff --git a/instructor/dsl/parallel.py b/instructor/dsl/parallel.py index 1ab0e20..1840951 100644 --- a/instructor/dsl/parallel.py +++ b/instructor/dsl/parallel.py @@ -13,9 +13,11 @@ from typing import ( ) from types import UnionType # type: ignore[attr-defined] from pydantic import BaseModel -from instructor.function_calls import OpenAISchema, Mode, openai_schema +from instructor.function_calls import OpenAISchema, openai_schema from collections.abc import Iterable +from instructor.mode import Mode + T = TypeVar("T", bound=OpenAISchema) diff --git a/instructor/dsl/partial.py b/instructor/dsl/partial.py index 81ae5d6..0778fae 100644 --- a/instructor/dsl/partial.py +++ b/instructor/dsl/partial.py @@ -22,7 +22,7 @@ from typing import ( ) from copy import deepcopy -from instructor.function_calls import Mode +from instructor.mode import Mode from instructor.dsl.partialjson import JSONParser from instructor.utils import extract_json_from_stream, extract_json_from_stream_async diff --git a/instructor/dsl/validators.py b/instructor/dsl/validators.py index fba60c9..9f7b77a 100644 --- a/instructor/dsl/validators.py +++ b/instructor/dsl/validators.py @@ -3,8 +3,8 @@ from typing import Callable, Optional from openai import OpenAI from pydantic import Field +import instructor from instructor.function_calls import OpenAISchema -from instructor.patch import patch class Validator(OpenAISchema): @@ -68,7 +68,7 @@ def llm_validator( openai_client (OpenAI): The OpenAI client to use (default: None) """ - openai_client = openai_client if openai_client else patch(OpenAI()) + openai_client = openai_client if openai_client else instructor.patch(OpenAI()) def llm(v: str) -> str: resp = openai_client.chat.completions.create( diff --git a/instructor/function_calls.py b/instructor/function_calls.py index bc9cbd7..36af25d 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -3,43 +3,16 @@ from docstring_parser import parse from functools import wraps from pydantic import BaseModel, create_model from instructor.exceptions import IncompleteOutputException -import enum -import warnings -import logging from openai.types.chat import ChatCompletion +from instructor.mode import Mode from instructor.utils import extract_json_from_codeblock +import logging T = TypeVar("T") logger = logging.getLogger("instructor") -class Mode(enum.Enum): - """The mode to use for patching the client""" - - FUNCTIONS = "function_call" - PARALLEL_TOOLS = "parallel_tool_call" - TOOLS = "tool_call" - MISTRAL_TOOLS = "mistral_tools" - JSON = "json_mode" - MD_JSON = "markdown_json_mode" - JSON_SCHEMA = "json_schema_mode" - - def __new__(cls, value: str) -> "Mode": - member = object.__new__(cls) - member._value_ = value - - # Deprecation warning for FUNCTIONS - if value == "function_call": - warnings.warn( - "FUNCTIONS is deprecated and will be removed in future versions", - DeprecationWarning, - stacklevel=2, - ) - - return member - - class OpenAISchema(BaseModel): # type: ignore[misc] @classmethod # type: ignore[misc] @property diff --git a/instructor/mode.py b/instructor/mode.py new file mode 100644 index 0000000..f04c4a2 --- /dev/null +++ b/instructor/mode.py @@ -0,0 +1,28 @@ +import enum +import warnings + + +class Mode(enum.Enum): + """The mode to use for patching the client""" + + FUNCTIONS = "function_call" + PARALLEL_TOOLS = "parallel_tool_call" + TOOLS = "tool_call" + MISTRAL_TOOLS = "mistral_tools" + JSON = "json_mode" + MD_JSON = "markdown_json_mode" + JSON_SCHEMA = "json_schema_mode" + + def __new__(cls, value: str) -> "Mode": + member = object.__new__(cls) + member._value_ = value + + # Deprecation warning for FUNCTIONS + if value == "function_call": + warnings.warn( + "FUNCTIONS is deprecated and will be removed in future versions", + DeprecationWarning, + stacklevel=2, + ) + + return member diff --git a/instructor/patch.py b/instructor/patch.py index 0c1572f..ad3d7fc 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -1,479 +1,30 @@ # type: ignore[all] -import inspect -import logging -from textwrap import dedent -from collections.abc import Iterable from functools import wraps -from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError -from json import JSONDecodeError from typing import ( Callable, - Generator, - Optional, ParamSpec, Protocol, Type, TypeVar, Union, - get_args, - get_origin, overload, ) from openai import AsyncOpenAI, OpenAI -from openai.types.chat import ( - ChatCompletion, -) -from openai.types.completion_usage import CompletionUsage -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel -from instructor.dsl.iterable import IterableModel, IterableBase -from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model -from instructor.dsl.partial import PartialBase -from instructor.dsl.simple_type import ModelAdapter, AdapterBase, is_simple_type -from instructor.utils import dump_message, update_total_usage +from instructor.process_response import handle_response_model +from instructor.retry import retry_async, retry_sync +from instructor.utils import is_async -from .function_calls import Mode, OpenAISchema, openai_schema +from instructor.mode import Mode +import logging logger = logging.getLogger("instructor") -T = TypeVar("T") - T_Model = TypeVar("T_Model", bound=BaseModel) T_Retval = TypeVar("T_Retval") T_ParamSpec = ParamSpec("T_ParamSpec") -T = TypeVar("T") - - -def handle_response_model( - response_model: T, mode: Mode = Mode.TOOLS, **kwargs -) -> Union[Type[OpenAISchema], dict]: - """Prepare the response model type hint, and returns the response_model - along with the new modified kwargs needed to be able to use the response_model - parameter with the patch function. - - - Args: - response_model (T): The response model to use for parsing the response - mode (Mode, optional): The openai completion mode. Defaults to Mode.TOOLS. - - Raises: - NotImplementedError: When using stream=True with a non-iterable response_model - ValueError: When using an invalid patch mode - - Returns: - Union[Type[OpenAISchema], dict]: The response model to use for parsing the response - """ - new_kwargs = kwargs.copy() - if response_model is not None: - # Handles the case where the response_model is a simple type - # Literal, Annotated, Union, str, int, float, bool, Enum - # We wrap the response_model in a ModelAdapter that sets 'content' as the response - if is_simple_type(response_model): - response_model = ModelAdapter[response_model] - - # This a special case for parallel tools - if mode == Mode.PARALLEL_TOOLS: - assert ( - new_kwargs.get("stream", False) is False - ), "stream=True is not supported when using PARALLEL_TOOLS mode" - new_kwargs["tools"] = handle_parallel_model(response_model) - new_kwargs["tool_choice"] = "auto" - - # This is a special case for parallel models - response_model = ParallelModel(typehint=response_model) - return response_model, new_kwargs - - # This is for all other single model cases - if get_origin(response_model) is Iterable: - iterable_element_class = get_args(response_model)[0] - response_model = IterableModel(iterable_element_class) - if not issubclass(response_model, OpenAISchema): - response_model = openai_schema(response_model) # type: ignore - - if new_kwargs.get("stream", False) and not issubclass( - response_model, (IterableBase, PartialBase) - ): - raise NotImplementedError( - "stream=True is not supported when using response_model parameter for non-iterables" - ) - - if mode == Mode.FUNCTIONS: - new_kwargs["functions"] = [response_model.openai_schema] # type: ignore - new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore - elif mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}: - new_kwargs["tools"] = [ - { - "type": "function", - "function": response_model.openai_schema, - } - ] - if mode == Mode.MISTRAL_TOOLS: - new_kwargs["tool_choice"] = "any" - else: - new_kwargs["tool_choice"] = { - "type": "function", - "function": {"name": response_model.openai_schema["name"]}, - } - elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}: - # If its a JSON Mode we need to massage the prompt a bit - # in order to get the response we want in a json format - message = dedent( - f""" - As a genius expert, your task is to understand the content and provide - the parsed objects in json that match the following json_schema:\n - - {response_model.model_json_schema()} - - Make sure to return an instance of the JSON, not the schema itself - """ - ) - - if mode == Mode.JSON: - new_kwargs["response_format"] = {"type": "json_object"} - - elif mode == Mode.JSON_SCHEMA: - new_kwargs["response_format"] = { - "type": "json_object", - "schema": response_model.model_json_schema(), - } - - elif mode == Mode.MD_JSON: - new_kwargs["messages"].append( - { - "role": "user", - "content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA", - }, - ) - # check that the first message is a system message - # if it is not, add a system message to the beginning - if new_kwargs["messages"][0]["role"] != "system": - new_kwargs["messages"].insert( - 0, - { - "role": "system", - "content": message, - }, - ) - # if it is, system append the schema to the end - else: - new_kwargs["messages"][0]["content"] += f"\n\n{message}" - else: - raise ValueError(f"Invalid patch mode: {mode}") - - logger.debug( - f"Instructor Request: {mode.value=}, {response_model=}, {new_kwargs=}", - extra={ - "mode": mode.value, - "response_model": response_model.__name__ - if response_model is not None - else None, - "new_kwargs": new_kwargs, - }, - ) - return response_model, new_kwargs - - -def process_response( - response: T_Model, - *, - response_model: Type[OpenAISchema | BaseModel], - stream: bool, - validation_context: Optional[dict] = None, - strict=None, - mode: Mode = Mode.TOOLS, -) -> T_Model | Generator[T_Model, None, None]: - """Processes a OpenAI response with the response model, if available. - - Args: - response (T): The response from OpenAI's API - response_model (Type[T_Model]): The response model to use for parsing the response - stream (bool): Whether the response is a stream - validation_context (dict, optional): The validation context to use for validating the response. Defaults to None. - strict (_type_, optional): Whether to use strict json parsing. Defaults to None. - mode (Mode, optional): The openai completion mode. Defaults to Mode.FUNCTIONS. - - Returns: - Union[T_Model, T]: The parsed response, if a response model is available, otherwise the response as is from the SDK - """ - - logger.debug( - f"Instructor Raw Response: {response}", - ) - - if response_model is None: - logger.debug("No response model, returning response as is") - return response - - if ( - inspect.isclass(response_model) - and issubclass(response_model, (IterableBase, PartialBase)) - and stream - ): - model = response_model.from_streaming_response( - response, - mode=mode, - ) - return model - - model = response_model.from_response( - response, - validation_context=validation_context, - strict=strict, - mode=mode, - ) - - # ? This really hints at the fact that we need a better way of - # ? attaching usage data and the raw response to the model we return. - if isinstance(model, IterableBase): - logger.debug(f"Returning takes from IterableBase") - return [task for task in model.tasks] - - if isinstance(response_model, ParallelBase): - logger.debug(f"Returning model from ParallelBase") - return model - - if isinstance(model, AdapterBase): - logger.debug(f"Returning model from AdapterBase") - return model.content - - model._raw_response = response - return model - - -async def process_response_async( - response: ChatCompletion, - *, - response_model: Type[T_Model | OpenAISchema | BaseModel], - stream: bool = False, - validation_context: Optional[dict] = None, - strict: Optional[bool] = None, - mode: Mode = Mode.TOOLS, -) -> T_Model | ChatCompletion: - """Processes a OpenAI response with the response model, if available. - It can use `validation_context` and `strict` to validate the response - via the pydantic model - - Args: - response (ChatCompletion): The response from OpenAI's API - response_model (BaseModel): The response model to use for parsing the response - stream (bool): Whether the response is a stream - validation_context (dict, optional): The validation context to use for validating the response. Defaults to None. - strict (bool, optional): Whether to use strict json parsing. Defaults to None. - """ - - logger.debug( - f"Instructor Raw Response: {response}", - ) - if response_model is None: - return response - - if ( - inspect.isclass(response_model) - and issubclass(response_model, (IterableBase, PartialBase)) - and stream - ): - model = await response_model.from_streaming_response_async( - response, - mode=mode, - ) - return model - - model = response_model.from_response( - response, - validation_context=validation_context, - strict=strict, - mode=mode, - ) - - # ? This really hints at the fact that we need a better way of - # ? attaching usage data and the raw response to the model we return. - if isinstance(model, IterableBase): - logger.debug(f"Returning takes from IterableBase") - return [task for task in model.tasks] - - if isinstance(response_model, ParallelBase): - logger.debug(f"Returning model from ParallelBase") - return model - - if isinstance(model, AdapterBase): - logger.debug(f"Returning model from AdapterBase") - return model.content - - model._raw_response = response - return model - - -async def retry_async( - func: Callable[T_ParamSpec, T_Retval], - response_model: Type[T], - validation_context, - args, - kwargs, - max_retries: int | AsyncRetrying = 1, - strict: Optional[bool] = None, - mode: Mode = Mode.TOOLS, -) -> T: - total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0) - - # If max_retries is int, then create a AsyncRetrying object - if isinstance(max_retries, int): - logger.debug(f"max_retries: {max_retries}") - max_retries = AsyncRetrying( - stop=stop_after_attempt(max_retries), - reraise=True, - ) - if not isinstance(max_retries, (AsyncRetrying, Retrying)): - raise ValueError( - "max_retries must be an `int` or a `tenacity.AsyncRetrying` object" - ) - - try: - async for attempt in max_retries: - logger.debug(f"Retrying, attempt: {attempt}") - with attempt: - try: - response: ChatCompletion = await func(*args, **kwargs) # type: ignore - stream = kwargs.get("stream", False) - response = update_total_usage(response, total_usage) - return await process_response_async( - response, - response_model=response_model, - stream=stream, - validation_context=validation_context, - strict=strict, - mode=mode, - ) # type: ignore[all] - except (ValidationError, JSONDecodeError) as e: - logger.debug(f"Error response: {response}", e) - kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore - if mode == Mode.TOOLS: - kwargs["messages"].append( - { - "role": "tool", - "tool_call_id": response.choices[0] - .message.tool_calls[0] - .id, - "name": response.choices[0] - .message.tool_calls[0] - .function.name, - "content": "Exceptions found\n{e}\nRecall the function correctly.", - } - ) - - kwargs["messages"].append( - { - "role": "user", - "content": f"Recall the function correctly, fix the errors, exceptions found\n{e}", - } - ) - if mode == Mode.MD_JSON: - kwargs["messages"].append( - { - "role": "user", - "content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA", - }, - ) - raise e - except RetryError as e: - logger.exception(f"Failed after retries: {e.last_attempt.exception}") - raise e.last_attempt.exception from e - - -def retry_sync( - func: Callable[T_ParamSpec, T_Retval], - response_model: Type[T], - validation_context: dict, - args, - kwargs, - max_retries: int | Retrying = 1, - strict: Optional[bool] = None, - mode: Mode = Mode.TOOLS, -): - total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0) - - # If max_retries is int, then create a Retrying object - if isinstance(max_retries, int): - logger.debug(f"max_retries: {max_retries}") - max_retries: Retrying = Retrying( - stop=stop_after_attempt(max_retries), - reraise=True, - ) - if not isinstance(max_retries, (Retrying, AsyncRetrying)): - raise ValueError("max_retries must be an int or a `tenacity.Retrying` object") - - try: - for attempt in max_retries: - with attempt: - try: - response = func(*args, **kwargs) - stream = kwargs.get("stream", False) - response = update_total_usage(response, total_usage) - return process_response( - response, - response_model=response_model, - stream=stream, - validation_context=validation_context, - strict=strict, - mode=mode, - ) - except (ValidationError, JSONDecodeError) as e: - logger.debug(f"Error response: {response}") - kwargs["messages"].append(dump_message(response.choices[0].message)) - # ! How do we handle this for parallel tools in the future? - if mode == Mode.TOOLS: - kwargs["messages"].append( - { - "role": "tool", - "tool_call_id": response.choices[0] - .message.tool_calls[0] - .id, - "name": response.choices[0] - .message.tool_calls[0] - .function.name, - "content": f"Recall the function correctly, fix the errors and exceptions found\n{e}", - } - ) - else: - kwargs["messages"].append( - { - "role": "user", - "content": f"Recall the function correctly, fix the errors and exceptions found\n{e}", - } - ) - raise e - except RetryError as e: - logger.exception(f"Failed after retries: {e.last_attempt.exception}") - raise e.last_attempt.exception from e - - -def is_async(func: Callable) -> bool: - """Returns true if the callable is async, accounting for wrapped callables""" - is_coroutine = inspect.iscoroutinefunction(func) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - is_coroutine = is_coroutine or inspect.iscoroutinefunction(func) - return is_coroutine - - -OVERRIDE_DOCS = """ -Creates a new chat completion for the provided messages and parameters. - -See: https://platform.openai.com/docs/api-reference/chat-completions/create - -Additional Notes: - -Using the `response_model` parameter, you can specify a response model to use for parsing the response from OpenAI's API. If its present, the response will be parsed using the response model, otherwise it will be returned as is. - -If `stream=True` is specified, the response will be parsed using the `from_stream_response` method of the response model, if available, otherwise it will be parsed using the `from_response` method. - -If need to obtain the raw response from OpenAI's API, you can access it using the `_raw_response` attribute of the response model. The `_raw_response.usage` attribute is modified to reflect the token usage from the last successful response as well as from any previous unsuccessful attempts. - -Parameters: - response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None) - max_retries (int): The maximum number of retries to attempt if the response is not valid (default: 0) - validation_context (dict): The validation context to use for validating the response (default: None) -""" class InstructorChatCompletionCreate(Protocol): @@ -584,7 +135,6 @@ def patch( return response new_create = new_create_async if func_is_async else new_create_sync - new_create.__doc__ = OVERRIDE_DOCS if client is not None: client.chat.completions.create = new_create diff --git a/instructor/process_response.py b/instructor/process_response.py new file mode 100644 index 0000000..a59e7bd --- /dev/null +++ b/instructor/process_response.py @@ -0,0 +1,297 @@ +# type: ignore[all] + +from collections.abc import Iterable +from textwrap import dedent +from instructor.dsl.iterable import IterableBase, IterableModel +from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model +from instructor.dsl.partial import PartialBase +from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type +from instructor.function_calls import OpenAISchema, openai_schema + +from openai.types.chat import ChatCompletion +from pydantic import BaseModel + + +import inspect +import logging +from typing import ( + Generator, + Optional, + Type, + Tuple, + get_args, + get_origin, + TypeVar, + ParamSpec, + Any, + Dict, +) + +from instructor.mode import Mode + + +logger = logging.getLogger("instructor") + +T_Model = TypeVar("T_Model", bound=BaseModel) +T_Retval = TypeVar("T_Retval") +T_ParamSpec = ParamSpec("T_ParamSpec") +T = TypeVar("T") + + +async def process_response_async( + response: ChatCompletion, + *, + response_model: Type[T_Model | OpenAISchema | BaseModel], + stream: bool = False, + validation_context: Optional[dict] = None, + strict: Optional[bool] = None, + mode: Mode = Mode.TOOLS, +) -> T_Model | ChatCompletion: + """Processes a OpenAI response with the response model, if available. + It can use `validation_context` and `strict` to validate the response + via the pydantic model + + Args: + response (ChatCompletion): The response from OpenAI's API + response_model (BaseModel): The response model to use for parsing the response + stream (bool): Whether the response is a stream + validation_context (dict, optional): The validation context to use for validating the response. Defaults to None. + strict (bool, optional): Whether to use strict json parsing. Defaults to None. + """ + + logger.debug( + f"Instructor Raw Response: {response}", + ) + if response_model is None: + return response + + if ( + inspect.isclass(response_model) + and issubclass(response_model, (IterableBase, PartialBase)) + and stream + ): + model = await response_model.from_streaming_response_async( + response, + mode=mode, + ) + return model + + model = response_model.from_response( + response, + validation_context=validation_context, + strict=strict, + mode=mode, + ) + + # ? This really hints at the fact that we need a better way of + # ? attaching usage data and the raw response to the model we return. + if isinstance(model, IterableBase): + logger.debug(f"Returning takes from IterableBase") + return [task for task in model.tasks] + + if isinstance(response_model, ParallelBase): + logger.debug(f"Returning model from ParallelBase") + return model + + if isinstance(model, AdapterBase): + logger.debug(f"Returning model from AdapterBase") + return model.content + + model._raw_response = response + return model + + +def process_response( + response: T_Model, + *, + response_model: Type[OpenAISchema | BaseModel], + stream: bool, + validation_context: Optional[dict] = None, + strict=None, + mode: Mode = Mode.TOOLS, +) -> T_Model | Generator[T_Model, None, None] | ChatCompletion: + """Processes a OpenAI response with the response model, if available. + + Args: + response (T): The response from OpenAI's API + response_model (Type[T_Model]): The response model to use for parsing the response + stream (bool): Whether the response is a stream + validation_context (dict, optional): The validation context to use for validating the response. Defaults to None. + strict (_type_, optional): Whether to use strict json parsing. Defaults to None. + mode (Mode, optional): The openai completion mode. Defaults to Mode.FUNCTIONS. + + Returns: + Union[T_Model, T]: The parsed response, if a response model is available, otherwise the response as is from the SDK + """ + + logger.debug( + f"Instructor Raw Response: {response}", + ) + + if response_model is None: + logger.debug("No response model, returning response as is") + return response + + if ( + inspect.isclass(response_model) + and issubclass(response_model, (IterableBase, PartialBase)) + and stream + ): + model = response_model.from_streaming_response( + response, + mode=mode, + ) + return model + + model = response_model.from_response( + response, + validation_context=validation_context, + strict=strict, + mode=mode, + ) + + # ? This really hints at the fact that we need a better way of + # ? attaching usage data and the raw response to the model we return. + if isinstance(model, IterableBase): + logger.debug(f"Returning takes from IterableBase") + return [task for task in model.tasks] + + if isinstance(response_model, ParallelBase): + logger.debug(f"Returning model from ParallelBase") + return model + + if isinstance(model, AdapterBase): + logger.debug(f"Returning model from AdapterBase") + return model.content + + model._raw_response = response + return model + + +def handle_response_model( + response_model: T, mode: Mode = Mode.TOOLS, **kwargs +) -> Tuple[Type[OpenAISchema], Dict[str, Any]]: + """Prepare the response model type hint, and returns the response_model + along with the new modified kwargs needed to be able to use the response_model + parameter with the patch function. + + + Args: + response_model (T): The response model to use for parsing the response + mode (Mode, optional): The openai completion mode. Defaults to Mode.TOOLS. + + Raises: + NotImplementedError: When using stream=True with a non-iterable response_model + ValueError: When using an invalid patch mode + + Returns: + Union[Type[OpenAISchema], dict]: The response model to use for parsing the response + """ + new_kwargs = kwargs.copy() + if response_model is not None: + # Handles the case where the response_model is a simple type + # Literal, Annotated, Union, str, int, float, bool, Enum + # We wrap the response_model in a ModelAdapter that sets 'content' as the response + if is_simple_type(response_model): + response_model = ModelAdapter[response_model] + + # This a special case for parallel tools + if mode == Mode.PARALLEL_TOOLS: + assert ( + new_kwargs.get("stream", False) is False + ), "stream=True is not supported when using PARALLEL_TOOLS mode" + new_kwargs["tools"] = handle_parallel_model(response_model) + new_kwargs["tool_choice"] = "auto" + + # This is a special case for parallel models + response_model = ParallelModel(typehint=response_model) + return response_model, new_kwargs + + # This is for all other single model cases + if get_origin(response_model) is Iterable: + iterable_element_class = get_args(response_model)[0] + response_model = IterableModel(iterable_element_class) + if not issubclass(response_model, OpenAISchema): + response_model = openai_schema(response_model) # type: ignore + + if new_kwargs.get("stream", False) and not issubclass( + response_model, (IterableBase, PartialBase) + ): + raise NotImplementedError( + "stream=True is not supported when using response_model parameter for non-iterables" + ) + + if mode == Mode.FUNCTIONS: + new_kwargs["functions"] = [response_model.openai_schema] # type: ignore + new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore + elif mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}: + new_kwargs["tools"] = [ + { + "type": "function", + "function": response_model.openai_schema, + } + ] + if mode == Mode.MISTRAL_TOOLS: + new_kwargs["tool_choice"] = "any" + else: + new_kwargs["tool_choice"] = { + "type": "function", + "function": {"name": response_model.openai_schema["name"]}, + } + elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}: + # If its a JSON Mode we need to massage the prompt a bit + # in order to get the response we want in a json format + message = dedent( + f""" + As a genius expert, your task is to understand the content and provide + the parsed objects in json that match the following json_schema:\n + + {response_model.model_json_schema()} + + Make sure to return an instance of the JSON, not the schema itself + """ + ) + + if mode == Mode.JSON: + new_kwargs["response_format"] = {"type": "json_object"} + + elif mode == Mode.JSON_SCHEMA: + new_kwargs["response_format"] = { + "type": "json_object", + "schema": response_model.model_json_schema(), + } + + elif mode == Mode.MD_JSON: + new_kwargs["messages"].append( + { + "role": "user", + "content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA", + }, + ) + # check that the first message is a system message + # if it is not, add a system message to the beginning + if new_kwargs["messages"][0]["role"] != "system": + new_kwargs["messages"].insert( + 0, + { + "role": "system", + "content": message, + }, + ) + # if it is, system append the schema to the end + else: + new_kwargs["messages"][0]["content"] += f"\n\n{message}" + else: + raise ValueError(f"Invalid patch mode: {mode}") + + logger.debug( + f"Instructor Request: {mode.value=}, {response_model=}, {new_kwargs=}", + extra={ + "mode": mode.value, + "response_model": response_model.__name__ + if response_model is not None + else None, + "new_kwargs": new_kwargs, + }, + ) + return response_model, new_kwargs diff --git a/instructor/retry.py b/instructor/retry.py new file mode 100644 index 0000000..2ed50a8 --- /dev/null +++ b/instructor/retry.py @@ -0,0 +1,143 @@ +# type: ignore[all] +import logging + +from openai.types.chat import ChatCompletion +from instructor.mode import Mode +from instructor.process_response import process_response, process_response_async +from instructor.utils import dump_message, update_total_usage + +from openai.types.completion_usage import CompletionUsage +from pydantic import ValidationError +from tenacity import AsyncRetrying, RetryError, Retrying, stop_after_attempt + + +from json import JSONDecodeError +from pydantic import BaseModel +from typing import Callable, Optional, Type, TypeVar, ParamSpec + +logger = logging.getLogger("instructor") + +T_Model = TypeVar("T_Model", bound=BaseModel) +T_Retval = TypeVar("T_Retval") +T_ParamSpec = ParamSpec("T_ParamSpec") +T = TypeVar("T") + + +def reask_messages(response: ChatCompletion, mode: Mode, exception: Exception): + yield dump_message(response.choices[0].message) + + if mode == Mode.TOOLS: + for tool_call in response.choices[0].message.tool_calls: # type: ignore + yield { + "role": "tool", + "tool_call_id": tool_call.id, + "name": tool_call.function.name, + "content": f"Validation Error found:\n{exception}\nRecall the function correctly, fix the errors", + } + + # TODO: Give users more control on configuration + if mode == Mode.MD_JSON: + yield { + "role": "user", + "content": f"Correct your JSON ONLY RESPONSE, based on the following errors:\n{exception}", + } + else: + yield { + "role": "user", + "content": f"Recall the function correctly, fix the errors, exceptions found\n{exception}", + } + + +def retry_sync( + func: Callable[T_ParamSpec, T_Retval], + response_model: Type[T_Model], + validation_context: dict, + args, + kwargs, + max_retries: int | Retrying = 1, + strict: Optional[bool] = None, + mode: Mode = Mode.TOOLS, +) -> T_Model: + total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0) + + # If max_retries is int, then create a Retrying object + if isinstance(max_retries, int): + logger.debug(f"max_retries: {max_retries}") + max_retries = Retrying( + stop=stop_after_attempt(max_retries), + reraise=True, + ) + if not isinstance(max_retries, (Retrying, AsyncRetrying)): + raise ValueError("max_retries must be an int or a `tenacity.Retrying` object") + + try: + for attempt in max_retries: + with attempt: + try: + response = func(*args, **kwargs) + stream = kwargs.get("stream", False) + response = update_total_usage(response, total_usage) + return process_response( + response, + response_model=response_model, + stream=stream, + validation_context=validation_context, + strict=strict, + mode=mode, + ) + except (ValidationError, JSONDecodeError) as e: + logger.debug(f"Error response: {response}") + kwargs["messages"].extend(reask_messages(response, mode, e)) + raise e + except RetryError as e: + logger.exception(f"Failed after retries: {e.last_attempt.exception}") + raise e.last_attempt.exception from e + + +async def retry_async( + func: Callable[T_ParamSpec, T_Retval], + response_model: Type[T], + validation_context, + args, + kwargs, + max_retries: int | AsyncRetrying = 1, + strict: Optional[bool] = None, + mode: Mode = Mode.TOOLS, +) -> T: + total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0) + + # If max_retries is int, then create a AsyncRetrying object + if isinstance(max_retries, int): + logger.debug(f"max_retries: {max_retries}") + max_retries = AsyncRetrying( + stop=stop_after_attempt(max_retries), + reraise=True, + ) + if not isinstance(max_retries, (AsyncRetrying, Retrying)): + raise ValueError( + "max_retries must be an `int` or a `tenacity.AsyncRetrying` object" + ) + + try: + async for attempt in max_retries: + logger.debug(f"Retrying, attempt: {attempt}") + with attempt: + try: + response: ChatCompletion = await func(*args, **kwargs) # type: ignore + stream = kwargs.get("stream", False) + response = update_total_usage(response, total_usage) + return await process_response_async( + response, + response_model=response_model, + stream=stream, + validation_context=validation_context, + strict=strict, + mode=mode, + ) # type: ignore[all] + except (ValidationError, JSONDecodeError) as e: + logger.debug(f"Error response: {response}", e) + kwargs["messages"].extend(reask_messages(response, mode, e)) + raise e + except RetryError as e: + logger.exception(f"Failed after retries: {e.last_attempt.exception}") + raise e.last_attempt.exception from e diff --git a/instructor/utils.py b/instructor/utils.py index 2ed3b33..6cbab00 100644 --- a/instructor/utils.py +++ b/instructor/utils.py @@ -1,5 +1,8 @@ +import inspect import json -from typing import Generator, Iterable, AsyncGenerator +from typing import Callable, Generator, Iterable, AsyncGenerator, TypeVar + +from pydantic import BaseModel from openai.types.chat import ( ChatCompletion, @@ -7,6 +10,8 @@ from openai.types.chat import ( ChatCompletionMessageParam, ) +T_Model = TypeVar("T_Model", bound=BaseModel) + def extract_json_from_codeblock(content: str) -> str: first_paren = content.find("{") @@ -54,7 +59,7 @@ async def extract_json_from_stream_async( yield char -def update_total_usage(response, total_usage): +def update_total_usage(response: T_Model, total_usage) -> T_Model | ChatCompletion: if isinstance(response, ChatCompletion) and response.usage is not None: total_usage.completion_tokens += response.usage.completion_tokens or 0 total_usage.prompt_tokens += response.usage.prompt_tokens or 0 @@ -81,3 +86,12 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam: ): ret["content"] += json.dumps(message.model_dump()["function_call"]) return ret + + +def is_async(func: Callable) -> bool: + """Returns true if the callable is async, accounting for wrapped callables""" + is_coroutine = inspect.iscoroutinefunction(func) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + is_coroutine = is_coroutine or inspect.iscoroutinefunction(func) + return is_coroutine diff --git a/pyproject.toml b/pyproject.toml index 4311b11..2d44247 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "instructor" -version = "0.6.3" +version = "0.6.4" description = "structured outputs for llm" authors = ["Jason Liu "] license = "MIT" diff --git a/tests/test_patch.py b/tests/test_patch.py index b0e72e5..2a25c6e 100644 --- a/tests/test_patch.py +++ b/tests/test_patch.py @@ -3,7 +3,7 @@ import functools from openai import AsyncOpenAI, OpenAI import instructor -from instructor.patch import OVERRIDE_DOCS, is_async +from instructor.utils import is_async def test_patch_completes_successfully(): @@ -71,9 +71,3 @@ def test_is_async_returns_true_if_triple_wrapped_function_is_async(): pass assert is_async(triple_wrapped_function) is True - - -def test_override_docs(): - assert ( - "response_model" in OVERRIDE_DOCS - ), "response_model should be in OVERRIDE_DOCS"