mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 14:50:16 +00:00
chore: Include types to instructor.patch (#422)
This commit is contained in:
@@ -25,8 +25,10 @@ env:
|
||||
instructor/dsl/partialjson.py
|
||||
instructor/dsl/validators.py
|
||||
instructor/function_calls.py
|
||||
instructor/patch.py
|
||||
tests/test_function_calls.py
|
||||
tests/test_distil.py
|
||||
tests/test_patch.py
|
||||
|
||||
jobs:
|
||||
MyPy:
|
||||
|
||||
+77
-79
@@ -3,19 +3,22 @@ import json
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from functools import wraps
|
||||
from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError
|
||||
from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError # type: ignore[import-not-found]
|
||||
from json import JSONDecodeError
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
Dict,
|
||||
Generator,
|
||||
get_args,
|
||||
get_origin,
|
||||
List,
|
||||
Optional,
|
||||
overload,
|
||||
Protocol,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
@@ -34,13 +37,6 @@ from instructor.dsl.partial import PartialBase
|
||||
from .function_calls import Mode, OpenAISchema, openai_schema
|
||||
|
||||
logger = logging.getLogger("instructor")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
T_ParamSpec = ParamSpec("T_ParamSpec")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
@@ -60,8 +56,8 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
|
||||
|
||||
def handle_response_model(
|
||||
response_model: T, mode: Mode = Mode.TOOLS, **kwargs
|
||||
) -> Union[Type[OpenAISchema], dict]:
|
||||
response_model: Type[BaseModel], mode: Mode = Mode.TOOLS, **kwargs: Any
|
||||
) -> Tuple[Union[Type[OpenAISchema], ParallelBase], Dict[str, Any]]:
|
||||
"""Prepare the response model type hint, and returns the response_model
|
||||
along with the new modified kwargs needed to be able to use the response_model
|
||||
parameter with the patch function.
|
||||
@@ -89,15 +85,14 @@ def handle_response_model(
|
||||
new_kwargs["tool_choice"] = "auto"
|
||||
|
||||
# This is a special case for parallel models
|
||||
response_model = ParallelModel(typehint=response_model)
|
||||
return response_model, new_kwargs
|
||||
return ParallelModel(typehint=response_model), new_kwargs
|
||||
|
||||
# This is for all other single model cases
|
||||
if get_origin(response_model) is Iterable:
|
||||
iterable_element_class = get_args(response_model)[0]
|
||||
response_model = IterableModel(iterable_element_class)
|
||||
if not issubclass(response_model, OpenAISchema):
|
||||
response_model = openai_schema(response_model) # type: ignore
|
||||
response_model = openai_schema(response_model)
|
||||
|
||||
if new_kwargs.get("stream", False) and not issubclass(
|
||||
response_model, (IterableBase, PartialBase)
|
||||
@@ -107,8 +102,8 @@ def handle_response_model(
|
||||
)
|
||||
|
||||
if mode == Mode.FUNCTIONS:
|
||||
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
|
||||
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
|
||||
new_kwargs["functions"] = [response_model.openai_schema]
|
||||
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]}
|
||||
elif mode == Mode.TOOLS:
|
||||
new_kwargs["tools"] = [
|
||||
{
|
||||
@@ -168,14 +163,14 @@ def handle_response_model(
|
||||
|
||||
|
||||
def process_response(
|
||||
response: T,
|
||||
response: ChatCompletion,
|
||||
*,
|
||||
response_model: Type[T_Model],
|
||||
response_model: Union[Type[OpenAISchema], ParallelBase, None],
|
||||
stream: bool,
|
||||
validation_context: dict = None,
|
||||
strict=None,
|
||||
validation_context: Optional[Dict[str, Any]] = None,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
) -> Union[T_Model, T]:
|
||||
) -> Union[OpenAISchema, List[OpenAISchema]]:
|
||||
"""Processes a OpenAI response with the response model, if available.
|
||||
|
||||
Args:
|
||||
@@ -213,11 +208,12 @@ def process_response(
|
||||
# ? This really hints at the fact that we need a better way of
|
||||
# ? attaching usage data and the raw response to the model we return.
|
||||
if isinstance(model, IterableBase):
|
||||
return [task for task in model.tasks]
|
||||
return [task for task in model.tasks] # type: ignore[attr-defined]
|
||||
|
||||
if isinstance(response_model, ParallelBase):
|
||||
return model
|
||||
|
||||
assert hasattr(model, "_raw_response")
|
||||
model._raw_response = response
|
||||
return model
|
||||
|
||||
@@ -225,12 +221,12 @@ def process_response(
|
||||
async def process_response_async(
|
||||
response: ChatCompletion,
|
||||
*,
|
||||
response_model: Type[T_Model],
|
||||
stream: bool = False,
|
||||
validation_context: dict = None,
|
||||
response_model: Union[Type[OpenAISchema], ParallelBase, None],
|
||||
stream: bool,
|
||||
validation_context: Optional[Dict[str, Any]] = None,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
) -> T:
|
||||
) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]:
|
||||
"""Processes a OpenAI response with the response model, if available.
|
||||
It can use `validation_context` and `strict` to validate the response
|
||||
via the pydantic model
|
||||
@@ -250,42 +246,44 @@ async def process_response_async(
|
||||
and issubclass(response_model, (IterableBase, PartialBase))
|
||||
and stream
|
||||
):
|
||||
model = await response_model.from_streaming_response_async(
|
||||
await_model = await response_model.from_streaming_response_async(
|
||||
response,
|
||||
mode=mode,
|
||||
)
|
||||
return model
|
||||
return await_model
|
||||
|
||||
model = response_model.from_response(
|
||||
response,
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
) # type: ignore[var-annotated]
|
||||
|
||||
# ? This really hints at the fact that we need a better way of
|
||||
# ? attaching usage data and the raw response to the model we return.
|
||||
if isinstance(model, IterableBase):
|
||||
#! If the response model is a multitask, return the tasks
|
||||
assert hasattr(model, "tasks")
|
||||
return [task for task in model.tasks]
|
||||
|
||||
if isinstance(response_model, ParallelBase):
|
||||
return model
|
||||
|
||||
assert hasattr(model, "_raw_response")
|
||||
model._raw_response = response
|
||||
return model
|
||||
|
||||
|
||||
async def retry_async(
|
||||
func: Callable[T_ParamSpec, T_Retval],
|
||||
response_model: Type[T],
|
||||
validation_context,
|
||||
args,
|
||||
kwargs,
|
||||
max_retries: int | AsyncRetrying = 1,
|
||||
async def retry_async( # type: ignore[return]
|
||||
func: Callable[..., ChatCompletion],
|
||||
response_model: Union[Type[OpenAISchema], ParallelBase],
|
||||
validation_context: Optional[Dict[str, Any]],
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
max_retries: int,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
) -> T:
|
||||
) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]:
|
||||
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
|
||||
|
||||
# If max_retries is int, then create a AsyncRetrying object
|
||||
@@ -327,7 +325,7 @@ async def retry_async(
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
logger.debug(f"Error response: {response}")
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message))
|
||||
if mode == Mode.TOOLS:
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
@@ -361,22 +359,22 @@ async def retry_async(
|
||||
raise e.last_attempt.exception from e
|
||||
|
||||
|
||||
def retry_sync(
|
||||
func: Callable[T_ParamSpec, T_Retval],
|
||||
response_model: Type[T],
|
||||
validation_context: dict,
|
||||
args,
|
||||
kwargs,
|
||||
max_retries: int | Retrying = 1,
|
||||
def retry_sync( # type: ignore[return]
|
||||
func: Callable[..., ChatCompletion],
|
||||
response_model: Union[Type[OpenAISchema], ParallelBase],
|
||||
validation_context: Optional[Dict[str, Any]],
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
max_retries: int,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
):
|
||||
) -> Union[OpenAISchema, List[OpenAISchema]]:
|
||||
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
|
||||
|
||||
# If max_retries is int, then create a Retrying object
|
||||
if isinstance(max_retries, int):
|
||||
logger.debug(f"max_retries: {max_retries}")
|
||||
max_retries: Retrying = Retrying(
|
||||
max_retries = Retrying(
|
||||
stop=stop_after_attempt(max_retries),
|
||||
reraise=True,
|
||||
)
|
||||
@@ -444,7 +442,7 @@ def retry_sync(
|
||||
raise e.last_attempt.exception from e
|
||||
|
||||
|
||||
def is_async(func: Callable) -> bool:
|
||||
def is_async(func: Callable[..., Any]) -> bool:
|
||||
"""Returns true if the callable is async, accounting for wrapped callables"""
|
||||
return inspect.iscoroutinefunction(func) or (
|
||||
hasattr(func, "__wrapped__") and inspect.iscoroutinefunction(func.__wrapped__)
|
||||
@@ -474,12 +472,12 @@ Parameters:
|
||||
class InstructorChatCompletionCreate(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
response_model: Type[T_Model] = None,
|
||||
validation_context: dict = None,
|
||||
response_model: Union[Type[BaseModel], ParallelBase, None] = None,
|
||||
validation_context: Optional[Dict[str, Any]] = None,
|
||||
max_retries: int = 1,
|
||||
*args: T_ParamSpec.args,
|
||||
**kwargs: T_ParamSpec.kwargs,
|
||||
) -> T_Model:
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Type[BaseModel]:
|
||||
...
|
||||
|
||||
|
||||
@@ -492,7 +490,7 @@ def patch(
|
||||
|
||||
|
||||
@overload
|
||||
def patch(
|
||||
def patch( # type: ignore[misc]
|
||||
client: AsyncOpenAI,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
) -> AsyncOpenAI:
|
||||
@@ -501,15 +499,15 @@ def patch(
|
||||
|
||||
@overload
|
||||
def patch(
|
||||
create: Callable[T_ParamSpec, T_Retval],
|
||||
create: Callable[..., Any],
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
) -> InstructorChatCompletionCreate:
|
||||
...
|
||||
|
||||
|
||||
def patch(
|
||||
client: Union[OpenAI, AsyncOpenAI] = None,
|
||||
create: Callable[T_ParamSpec, T_Retval] = None,
|
||||
def patch( # type: ignore[misc]
|
||||
client: Union[OpenAI, AsyncOpenAI, None] = None,
|
||||
create: Optional[Callable[..., Any]] = None,
|
||||
mode: Mode = Mode.FUNCTIONS,
|
||||
) -> Union[OpenAI, AsyncOpenAI]:
|
||||
"""
|
||||
@@ -536,40 +534,40 @@ def patch(
|
||||
|
||||
@wraps(func)
|
||||
async def new_create_async(
|
||||
response_model: Type[T_Model] = None,
|
||||
validation_context: dict = None,
|
||||
response_model: Type[BaseModel],
|
||||
validation_context: Dict[str, Any],
|
||||
max_retries: int = 1,
|
||||
*args: T_ParamSpec.args,
|
||||
**kwargs: T_ParamSpec.kwargs,
|
||||
) -> T_Model:
|
||||
response_model, new_kwargs = handle_response_model(
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]:
|
||||
new_response_model, new_kwargs = handle_response_model(
|
||||
response_model=response_model, mode=mode, **kwargs
|
||||
)
|
||||
response = await retry_async(
|
||||
func=func,
|
||||
response_model=response_model,
|
||||
response_model=new_response_model,
|
||||
validation_context=validation_context,
|
||||
max_retries=max_retries,
|
||||
args=args,
|
||||
kwargs=new_kwargs,
|
||||
mode=mode,
|
||||
) # type: ignore
|
||||
)
|
||||
return response
|
||||
|
||||
@wraps(func)
|
||||
def new_create_sync(
|
||||
response_model: Type[T_Model] = None,
|
||||
validation_context: dict = None,
|
||||
response_model: Type[BaseModel],
|
||||
validation_context: Dict[str, Any],
|
||||
max_retries: int = 1,
|
||||
*args: T_ParamSpec.args,
|
||||
**kwargs: T_ParamSpec.kwargs,
|
||||
) -> T_Model:
|
||||
response_model, new_kwargs = handle_response_model(
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Union[OpenAISchema, List[OpenAISchema]]:
|
||||
new_response_model, new_kwargs = handle_response_model(
|
||||
response_model=response_model, mode=mode, **kwargs
|
||||
)
|
||||
response = retry_sync(
|
||||
func=func,
|
||||
response_model=response_model,
|
||||
response_model=new_response_model,
|
||||
validation_context=validation_context,
|
||||
max_retries=max_retries,
|
||||
args=args,
|
||||
@@ -588,7 +586,7 @@ def patch(
|
||||
return new_create
|
||||
|
||||
|
||||
def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
|
||||
def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS) -> AsyncOpenAI:
|
||||
"""
|
||||
No longer necessary, use `patch` instead.
|
||||
|
||||
|
||||
+10
-10
@@ -6,40 +6,40 @@ import instructor
|
||||
from instructor.patch import OVERRIDE_DOCS, is_async
|
||||
|
||||
|
||||
def test_patch_completes_successfully():
|
||||
def test_patch_completes_successfully() -> None:
|
||||
instructor.patch(OpenAI())
|
||||
|
||||
|
||||
def test_apatch_completes_successfully():
|
||||
def test_apatch_completes_successfully() -> None:
|
||||
instructor.apatch(AsyncOpenAI())
|
||||
|
||||
|
||||
def test_is_async_returns_true_if_function_is_async():
|
||||
async def async_function():
|
||||
def test_is_async_returns_true_if_function_is_async() -> None:
|
||||
async def async_function() -> None:
|
||||
pass
|
||||
|
||||
assert is_async(async_function) is True
|
||||
|
||||
|
||||
def test_is_async_returns_false_if_function_is_not_async():
|
||||
def sync_function():
|
||||
def test_is_async_returns_false_if_function_is_not_async() -> None:
|
||||
def sync_function() -> None:
|
||||
pass
|
||||
|
||||
assert is_async(sync_function) is False
|
||||
|
||||
|
||||
def test_is_async_returns_true_if_wrapped_function_is_async():
|
||||
async def async_function():
|
||||
def test_is_async_returns_true_if_wrapped_function_is_async() -> None:
|
||||
async def async_function() -> None:
|
||||
pass
|
||||
|
||||
@functools.wraps(async_function)
|
||||
def wrapped_function():
|
||||
def wrapped_function() -> None:
|
||||
pass
|
||||
|
||||
assert is_async(wrapped_function) is True
|
||||
|
||||
|
||||
def test_override_docs():
|
||||
def test_override_docs() -> None:
|
||||
assert (
|
||||
"response_model" in OVERRIDE_DOCS
|
||||
), "response_model should be in OVERRIDE_DOCS"
|
||||
|
||||
Reference in New Issue
Block a user