From f0d7889021d4eff2554cc946dd76e838fe4baedf Mon Sep 17 00:00:00 2001 From: Ezzeri Esa Date: Wed, 14 Feb 2024 14:37:31 -0800 Subject: [PATCH] chore: Include types to instructor.patch (#422) --- .github/workflows/mypy.yml | 2 + instructor/patch.py | 156 ++++++++++++++++++------------------- tests/test_patch.py | 20 ++--- 3 files changed, 89 insertions(+), 89 deletions(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index d76a711..ef34ea3 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -25,8 +25,10 @@ env: instructor/dsl/partialjson.py instructor/dsl/validators.py instructor/function_calls.py + instructor/patch.py tests/test_function_calls.py tests/test_distil.py + tests/test_patch.py jobs: MyPy: diff --git a/instructor/patch.py b/instructor/patch.py index e5adde1..0d3f068 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -3,19 +3,22 @@ import json import logging from collections.abc import Iterable from functools import wraps -from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError +from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError # type: ignore[import-not-found] from json import JSONDecodeError from typing import ( + Any, Callable, - Optional, - ParamSpec, - Protocol, - Type, - TypeVar, - Union, + Dict, + Generator, get_args, get_origin, + List, + Optional, overload, + Protocol, + Tuple, + Type, + Union, ) from openai import AsyncOpenAI, OpenAI @@ -34,13 +37,6 @@ from instructor.dsl.partial import PartialBase from .function_calls import Mode, OpenAISchema, openai_schema logger = logging.getLogger("instructor") -T = TypeVar("T") - - -T_Model = TypeVar("T_Model", bound=BaseModel) -T_Retval = TypeVar("T_Retval") -T_ParamSpec = ParamSpec("T_ParamSpec") -T = TypeVar("T") def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam: @@ -60,8 +56,8 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam: def handle_response_model( - response_model: T, mode: Mode = Mode.TOOLS, **kwargs -) -> Union[Type[OpenAISchema], dict]: + response_model: Type[BaseModel], mode: Mode = Mode.TOOLS, **kwargs: Any +) -> Tuple[Union[Type[OpenAISchema], ParallelBase], Dict[str, Any]]: """Prepare the response model type hint, and returns the response_model along with the new modified kwargs needed to be able to use the response_model parameter with the patch function. @@ -89,15 +85,14 @@ def handle_response_model( new_kwargs["tool_choice"] = "auto" # This is a special case for parallel models - response_model = ParallelModel(typehint=response_model) - return response_model, new_kwargs + return ParallelModel(typehint=response_model), new_kwargs # This is for all other single model cases if get_origin(response_model) is Iterable: iterable_element_class = get_args(response_model)[0] response_model = IterableModel(iterable_element_class) if not issubclass(response_model, OpenAISchema): - response_model = openai_schema(response_model) # type: ignore + response_model = openai_schema(response_model) if new_kwargs.get("stream", False) and not issubclass( response_model, (IterableBase, PartialBase) @@ -107,8 +102,8 @@ def handle_response_model( ) if mode == Mode.FUNCTIONS: - new_kwargs["functions"] = [response_model.openai_schema] # type: ignore - new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore + new_kwargs["functions"] = [response_model.openai_schema] + new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} elif mode == Mode.TOOLS: new_kwargs["tools"] = [ { @@ -168,14 +163,14 @@ def handle_response_model( def process_response( - response: T, + response: ChatCompletion, *, - response_model: Type[T_Model], + response_model: Union[Type[OpenAISchema], ParallelBase, None], stream: bool, - validation_context: dict = None, - strict=None, + validation_context: Optional[Dict[str, Any]] = None, + strict: Optional[bool] = None, mode: Mode = Mode.FUNCTIONS, -) -> Union[T_Model, T]: +) -> Union[OpenAISchema, List[OpenAISchema]]: """Processes a OpenAI response with the response model, if available. Args: @@ -213,11 +208,12 @@ def process_response( # ? This really hints at the fact that we need a better way of # ? attaching usage data and the raw response to the model we return. if isinstance(model, IterableBase): - return [task for task in model.tasks] + return [task for task in model.tasks] # type: ignore[attr-defined] if isinstance(response_model, ParallelBase): return model + assert hasattr(model, "_raw_response") model._raw_response = response return model @@ -225,12 +221,12 @@ def process_response( async def process_response_async( response: ChatCompletion, *, - response_model: Type[T_Model], - stream: bool = False, - validation_context: dict = None, + response_model: Union[Type[OpenAISchema], ParallelBase, None], + stream: bool, + validation_context: Optional[Dict[str, Any]] = None, strict: Optional[bool] = None, mode: Mode = Mode.FUNCTIONS, -) -> T: +) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]: """Processes a OpenAI response with the response model, if available. It can use `validation_context` and `strict` to validate the response via the pydantic model @@ -250,42 +246,44 @@ async def process_response_async( and issubclass(response_model, (IterableBase, PartialBase)) and stream ): - model = await response_model.from_streaming_response_async( + await_model = await response_model.from_streaming_response_async( response, mode=mode, ) - return model + return await_model model = response_model.from_response( response, validation_context=validation_context, strict=strict, mode=mode, - ) + ) # type: ignore[var-annotated] # ? This really hints at the fact that we need a better way of # ? attaching usage data and the raw response to the model we return. if isinstance(model, IterableBase): #! If the response model is a multitask, return the tasks + assert hasattr(model, "tasks") return [task for task in model.tasks] if isinstance(response_model, ParallelBase): return model + assert hasattr(model, "_raw_response") model._raw_response = response return model -async def retry_async( - func: Callable[T_ParamSpec, T_Retval], - response_model: Type[T], - validation_context, - args, - kwargs, - max_retries: int | AsyncRetrying = 1, +async def retry_async( # type: ignore[return] + func: Callable[..., ChatCompletion], + response_model: Union[Type[OpenAISchema], ParallelBase], + validation_context: Optional[Dict[str, Any]], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + max_retries: int, strict: Optional[bool] = None, mode: Mode = Mode.FUNCTIONS, -) -> T: +) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]: total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0) # If max_retries is int, then create a AsyncRetrying object @@ -327,7 +325,7 @@ async def retry_async( ) except (ValidationError, JSONDecodeError) as e: logger.debug(f"Error response: {response}") - kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore + kwargs["messages"].append(dump_message(response.choices[0].message)) if mode == Mode.TOOLS: kwargs["messages"].append( { @@ -361,22 +359,22 @@ async def retry_async( raise e.last_attempt.exception from e -def retry_sync( - func: Callable[T_ParamSpec, T_Retval], - response_model: Type[T], - validation_context: dict, - args, - kwargs, - max_retries: int | Retrying = 1, +def retry_sync( # type: ignore[return] + func: Callable[..., ChatCompletion], + response_model: Union[Type[OpenAISchema], ParallelBase], + validation_context: Optional[Dict[str, Any]], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + max_retries: int, strict: Optional[bool] = None, mode: Mode = Mode.FUNCTIONS, -): +) -> Union[OpenAISchema, List[OpenAISchema]]: total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0) # If max_retries is int, then create a Retrying object if isinstance(max_retries, int): logger.debug(f"max_retries: {max_retries}") - max_retries: Retrying = Retrying( + max_retries = Retrying( stop=stop_after_attempt(max_retries), reraise=True, ) @@ -444,7 +442,7 @@ def retry_sync( raise e.last_attempt.exception from e -def is_async(func: Callable) -> bool: +def is_async(func: Callable[..., Any]) -> bool: """Returns true if the callable is async, accounting for wrapped callables""" return inspect.iscoroutinefunction(func) or ( hasattr(func, "__wrapped__") and inspect.iscoroutinefunction(func.__wrapped__) @@ -474,12 +472,12 @@ Parameters: class InstructorChatCompletionCreate(Protocol): def __call__( self, - response_model: Type[T_Model] = None, - validation_context: dict = None, + response_model: Union[Type[BaseModel], ParallelBase, None] = None, + validation_context: Optional[Dict[str, Any]] = None, max_retries: int = 1, - *args: T_ParamSpec.args, - **kwargs: T_ParamSpec.kwargs, - ) -> T_Model: + *args: Any, + **kwargs: Any, + ) -> Type[BaseModel]: ... @@ -492,7 +490,7 @@ def patch( @overload -def patch( +def patch( # type: ignore[misc] client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS, ) -> AsyncOpenAI: @@ -501,15 +499,15 @@ def patch( @overload def patch( - create: Callable[T_ParamSpec, T_Retval], + create: Callable[..., Any], mode: Mode = Mode.FUNCTIONS, ) -> InstructorChatCompletionCreate: ... -def patch( - client: Union[OpenAI, AsyncOpenAI] = None, - create: Callable[T_ParamSpec, T_Retval] = None, +def patch( # type: ignore[misc] + client: Union[OpenAI, AsyncOpenAI, None] = None, + create: Optional[Callable[..., Any]] = None, mode: Mode = Mode.FUNCTIONS, ) -> Union[OpenAI, AsyncOpenAI]: """ @@ -536,40 +534,40 @@ def patch( @wraps(func) async def new_create_async( - response_model: Type[T_Model] = None, - validation_context: dict = None, + response_model: Type[BaseModel], + validation_context: Dict[str, Any], max_retries: int = 1, - *args: T_ParamSpec.args, - **kwargs: T_ParamSpec.kwargs, - ) -> T_Model: - response_model, new_kwargs = handle_response_model( + *args: Any, + **kwargs: Any, + ) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]: + new_response_model, new_kwargs = handle_response_model( response_model=response_model, mode=mode, **kwargs ) response = await retry_async( func=func, - response_model=response_model, + response_model=new_response_model, validation_context=validation_context, max_retries=max_retries, args=args, kwargs=new_kwargs, mode=mode, - ) # type: ignore + ) return response @wraps(func) def new_create_sync( - response_model: Type[T_Model] = None, - validation_context: dict = None, + response_model: Type[BaseModel], + validation_context: Dict[str, Any], max_retries: int = 1, - *args: T_ParamSpec.args, - **kwargs: T_ParamSpec.kwargs, - ) -> T_Model: - response_model, new_kwargs = handle_response_model( + *args: Any, + **kwargs: Any, + ) -> Union[OpenAISchema, List[OpenAISchema]]: + new_response_model, new_kwargs = handle_response_model( response_model=response_model, mode=mode, **kwargs ) response = retry_sync( func=func, - response_model=response_model, + response_model=new_response_model, validation_context=validation_context, max_retries=max_retries, args=args, @@ -588,7 +586,7 @@ def patch( return new_create -def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS): +def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS) -> AsyncOpenAI: """ No longer necessary, use `patch` instead. diff --git a/tests/test_patch.py b/tests/test_patch.py index 0418a1e..ac28e0f 100644 --- a/tests/test_patch.py +++ b/tests/test_patch.py @@ -6,40 +6,40 @@ import instructor from instructor.patch import OVERRIDE_DOCS, is_async -def test_patch_completes_successfully(): +def test_patch_completes_successfully() -> None: instructor.patch(OpenAI()) -def test_apatch_completes_successfully(): +def test_apatch_completes_successfully() -> None: instructor.apatch(AsyncOpenAI()) -def test_is_async_returns_true_if_function_is_async(): - async def async_function(): +def test_is_async_returns_true_if_function_is_async() -> None: + async def async_function() -> None: pass assert is_async(async_function) is True -def test_is_async_returns_false_if_function_is_not_async(): - def sync_function(): +def test_is_async_returns_false_if_function_is_not_async() -> None: + def sync_function() -> None: pass assert is_async(sync_function) is False -def test_is_async_returns_true_if_wrapped_function_is_async(): - async def async_function(): +def test_is_async_returns_true_if_wrapped_function_is_async() -> None: + async def async_function() -> None: pass @functools.wraps(async_function) - def wrapped_function(): + def wrapped_function() -> None: pass assert is_async(wrapped_function) is True -def test_override_docs(): +def test_override_docs() -> None: assert ( "response_model" in OVERRIDE_DOCS ), "response_model should be in OVERRIDE_DOCS"