chore: Include types to instructor.patch (#422)

This commit is contained in:
Ezzeri Esa
2024-02-14 14:37:31 -08:00
committed by GitHub
parent fdc1fc2bb9
commit f0d7889021
3 changed files with 89 additions and 89 deletions
+2
View File
@@ -25,8 +25,10 @@ env:
instructor/dsl/partialjson.py
instructor/dsl/validators.py
instructor/function_calls.py
instructor/patch.py
tests/test_function_calls.py
tests/test_distil.py
tests/test_patch.py
jobs:
MyPy:
+77 -79
View File
@@ -3,19 +3,22 @@ import json
import logging
from collections.abc import Iterable
from functools import wraps
from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError
from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError # type: ignore[import-not-found]
from json import JSONDecodeError
from typing import (
Any,
Callable,
Optional,
ParamSpec,
Protocol,
Type,
TypeVar,
Union,
Dict,
Generator,
get_args,
get_origin,
List,
Optional,
overload,
Protocol,
Tuple,
Type,
Union,
)
from openai import AsyncOpenAI, OpenAI
@@ -34,13 +37,6 @@ from instructor.dsl.partial import PartialBase
from .function_calls import Mode, OpenAISchema, openai_schema
logger = logging.getLogger("instructor")
T = TypeVar("T")
T_Model = TypeVar("T_Model", bound=BaseModel)
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")
T = TypeVar("T")
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
@@ -60,8 +56,8 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
def handle_response_model(
response_model: T, mode: Mode = Mode.TOOLS, **kwargs
) -> Union[Type[OpenAISchema], dict]:
response_model: Type[BaseModel], mode: Mode = Mode.TOOLS, **kwargs: Any
) -> Tuple[Union[Type[OpenAISchema], ParallelBase], Dict[str, Any]]:
"""Prepare the response model type hint, and returns the response_model
along with the new modified kwargs needed to be able to use the response_model
parameter with the patch function.
@@ -89,15 +85,14 @@ def handle_response_model(
new_kwargs["tool_choice"] = "auto"
# This is a special case for parallel models
response_model = ParallelModel(typehint=response_model)
return response_model, new_kwargs
return ParallelModel(typehint=response_model), new_kwargs
# This is for all other single model cases
if get_origin(response_model) is Iterable:
iterable_element_class = get_args(response_model)[0]
response_model = IterableModel(iterable_element_class)
if not issubclass(response_model, OpenAISchema):
response_model = openai_schema(response_model) # type: ignore
response_model = openai_schema(response_model)
if new_kwargs.get("stream", False) and not issubclass(
response_model, (IterableBase, PartialBase)
@@ -107,8 +102,8 @@ def handle_response_model(
)
if mode == Mode.FUNCTIONS:
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
new_kwargs["functions"] = [response_model.openai_schema]
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]}
elif mode == Mode.TOOLS:
new_kwargs["tools"] = [
{
@@ -168,14 +163,14 @@ def handle_response_model(
def process_response(
response: T,
response: ChatCompletion,
*,
response_model: Type[T_Model],
response_model: Union[Type[OpenAISchema], ParallelBase, None],
stream: bool,
validation_context: dict = None,
strict=None,
validation_context: Optional[Dict[str, Any]] = None,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
) -> Union[T_Model, T]:
) -> Union[OpenAISchema, List[OpenAISchema]]:
"""Processes an OpenAI response with the response model, if available.
Args:
@@ -213,11 +208,12 @@ def process_response(
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
return [task for task in model.tasks]
return [task for task in model.tasks] # type: ignore[attr-defined]
if isinstance(response_model, ParallelBase):
return model
assert hasattr(model, "_raw_response")
model._raw_response = response
return model
@@ -225,12 +221,12 @@ def process_response(
async def process_response_async(
response: ChatCompletion,
*,
response_model: Type[T_Model],
stream: bool = False,
validation_context: dict = None,
response_model: Union[Type[OpenAISchema], ParallelBase, None],
stream: bool,
validation_context: Optional[Dict[str, Any]] = None,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
) -> T:
) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]:
"""Processes an OpenAI response with the response model, if available.
It can use `validation_context` and `strict` to validate the response
via the pydantic model
@@ -250,42 +246,44 @@ async def process_response_async(
and issubclass(response_model, (IterableBase, PartialBase))
and stream
):
model = await response_model.from_streaming_response_async(
await_model = await response_model.from_streaming_response_async(
response,
mode=mode,
)
return model
return await_model
model = response_model.from_response(
response,
validation_context=validation_context,
strict=strict,
mode=mode,
)
) # type: ignore[var-annotated]
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
#! If the response model is a multitask, return the tasks
assert hasattr(model, "tasks")
return [task for task in model.tasks]
if isinstance(response_model, ParallelBase):
return model
assert hasattr(model, "_raw_response")
model._raw_response = response
return model
async def retry_async(
func: Callable[T_ParamSpec, T_Retval],
response_model: Type[T],
validation_context,
args,
kwargs,
max_retries: int | AsyncRetrying = 1,
async def retry_async( # type: ignore[return]
func: Callable[..., ChatCompletion],
response_model: Union[Type[OpenAISchema], ParallelBase],
validation_context: Optional[Dict[str, Any]],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
max_retries: int,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
) -> T:
) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]:
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
# If max_retries is int, then create an AsyncRetrying object
@@ -327,7 +325,7 @@ async def retry_async(
)
except (ValidationError, JSONDecodeError) as e:
logger.debug(f"Error response: {response}")
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
kwargs["messages"].append(dump_message(response.choices[0].message))
if mode == Mode.TOOLS:
kwargs["messages"].append(
{
@@ -361,22 +359,22 @@ async def retry_async(
raise e.last_attempt.exception from e
def retry_sync(
func: Callable[T_ParamSpec, T_Retval],
response_model: Type[T],
validation_context: dict,
args,
kwargs,
max_retries: int | Retrying = 1,
def retry_sync( # type: ignore[return]
func: Callable[..., ChatCompletion],
response_model: Union[Type[OpenAISchema], ParallelBase],
validation_context: Optional[Dict[str, Any]],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
max_retries: int,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
):
) -> Union[OpenAISchema, List[OpenAISchema]]:
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
# If max_retries is int, then create a Retrying object
if isinstance(max_retries, int):
logger.debug(f"max_retries: {max_retries}")
max_retries: Retrying = Retrying(
max_retries = Retrying(
stop=stop_after_attempt(max_retries),
reraise=True,
)
@@ -444,7 +442,7 @@ def retry_sync(
raise e.last_attempt.exception from e
def is_async(func: Callable) -> bool:
def is_async(func: Callable[..., Any]) -> bool:
"""Returns true if the callable is async, accounting for wrapped callables"""
return inspect.iscoroutinefunction(func) or (
hasattr(func, "__wrapped__") and inspect.iscoroutinefunction(func.__wrapped__)
@@ -474,12 +472,12 @@ Parameters:
class InstructorChatCompletionCreate(Protocol):
def __call__(
self,
response_model: Type[T_Model] = None,
validation_context: dict = None,
response_model: Union[Type[BaseModel], ParallelBase, None] = None,
validation_context: Optional[Dict[str, Any]] = None,
max_retries: int = 1,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
*args: Any,
**kwargs: Any,
) -> Type[BaseModel]:
...
@@ -492,7 +490,7 @@ def patch(
@overload
def patch(
def patch( # type: ignore[misc]
client: AsyncOpenAI,
mode: Mode = Mode.FUNCTIONS,
) -> AsyncOpenAI:
@@ -501,15 +499,15 @@ def patch(
@overload
def patch(
create: Callable[T_ParamSpec, T_Retval],
create: Callable[..., Any],
mode: Mode = Mode.FUNCTIONS,
) -> InstructorChatCompletionCreate:
...
def patch(
client: Union[OpenAI, AsyncOpenAI] = None,
create: Callable[T_ParamSpec, T_Retval] = None,
def patch( # type: ignore[misc]
client: Union[OpenAI, AsyncOpenAI, None] = None,
create: Optional[Callable[..., Any]] = None,
mode: Mode = Mode.FUNCTIONS,
) -> Union[OpenAI, AsyncOpenAI]:
"""
@@ -536,40 +534,40 @@ def patch(
@wraps(func)
async def new_create_async(
response_model: Type[T_Model] = None,
validation_context: dict = None,
response_model: Type[BaseModel],
validation_context: Dict[str, Any],
max_retries: int = 1,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
response_model, new_kwargs = handle_response_model(
*args: Any,
**kwargs: Any,
) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]:
new_response_model, new_kwargs = handle_response_model(
response_model=response_model, mode=mode, **kwargs
)
response = await retry_async(
func=func,
response_model=response_model,
response_model=new_response_model,
validation_context=validation_context,
max_retries=max_retries,
args=args,
kwargs=new_kwargs,
mode=mode,
) # type: ignore
)
return response
@wraps(func)
def new_create_sync(
response_model: Type[T_Model] = None,
validation_context: dict = None,
response_model: Type[BaseModel],
validation_context: Dict[str, Any],
max_retries: int = 1,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
response_model, new_kwargs = handle_response_model(
*args: Any,
**kwargs: Any,
) -> Union[OpenAISchema, List[OpenAISchema]]:
new_response_model, new_kwargs = handle_response_model(
response_model=response_model, mode=mode, **kwargs
)
response = retry_sync(
func=func,
response_model=response_model,
response_model=new_response_model,
validation_context=validation_context,
max_retries=max_retries,
args=args,
@@ -588,7 +586,7 @@ def patch(
return new_create
def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS) -> AsyncOpenAI:
"""
No longer necessary, use `patch` instead.
+10 -10
View File
@@ -6,40 +6,40 @@ import instructor
from instructor.patch import OVERRIDE_DOCS, is_async
def test_patch_completes_successfully():
def test_patch_completes_successfully() -> None:
instructor.patch(OpenAI())
def test_apatch_completes_successfully():
def test_apatch_completes_successfully() -> None:
instructor.apatch(AsyncOpenAI())
def test_is_async_returns_true_if_function_is_async():
async def async_function():
def test_is_async_returns_true_if_function_is_async() -> None:
async def async_function() -> None:
pass
assert is_async(async_function) is True
def test_is_async_returns_false_if_function_is_not_async():
def sync_function():
def test_is_async_returns_false_if_function_is_not_async() -> None:
def sync_function() -> None:
pass
assert is_async(sync_function) is False
def test_is_async_returns_true_if_wrapped_function_is_async():
async def async_function():
def test_is_async_returns_true_if_wrapped_function_is_async() -> None:
async def async_function() -> None:
pass
@functools.wraps(async_function)
def wrapped_function():
def wrapped_function() -> None:
pass
assert is_async(wrapped_function) is True
def test_override_docs():
def test_override_docs() -> None:
assert (
"response_model" in OVERRIDE_DOCS
), "response_model should be in OVERRIDE_DOCS"