mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 14:50:16 +00:00
fix: Improve type hinting, update response models handling, add logging, and fix bugs (#484)
This commit is contained in:
@@ -28,27 +28,12 @@ user = client.chat.completions.create(
|
||||
) # type: ignore
|
||||
|
||||
"""
|
||||
DEBUG:httpx:load_ssl_context verify=True cert=None trust_env=True http2=False
|
||||
DEBUG:httpx:load_verify_locations cafile='/Users/jasonliu/dev/instructor/.venv/lib/python3.11/site-packages/certifi/cacert.pem'
|
||||
DEBUG:instructor:Patching `client.chat.completions.create` with mode=<Mode.FUNCTIONS: 'function_call'>
|
||||
...
|
||||
DEBUG:instructor:Patching `client.chat.completions.create` with mode=<Mode.TOOLS: 'tool_call'>
|
||||
DEBUG:instructor:Instructor Request: mode.value='tool_call', response_model=<class '__main__.UserDetail'>, new_kwargs={'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'Extract Jason is 25 years old'}], 'tools': [{'type': 'function', 'function': {'name': 'UserDetail', 'description': 'Correctly extracted `UserDetail` with all the required parameters with correct types', 'parameters': {'properties': {'name': {'title': 'Name', 'type': 'string'}, 'age': {'title': 'Age', 'type': 'integer'}}, 'required': ['age', 'name'], 'type': 'object'}}}], 'tool_choice': {'type': 'function', 'function': {'name': 'UserDetail'}}}
|
||||
DEBUG:instructor:max_retries: 1
|
||||
DEBUG:openai._base_client:Request options: {'method': 'post', 'url': '/chat/completions', 'files': None, 'json_data': {'messages': [{'role': 'user', 'content': 'Extract Jason is 25 years old'}], 'model': 'gpt-3.5-turbo', 'function_call': {'name': 'UserDetail'}, 'functions': [{'name': 'UserDetail', 'description': 'Correctly extracted `UserDetail` with all the required parameters with correct types', 'parameters': {'properties': {'name': {'title': 'Name', 'type': 'string'}, 'age': {'title': 'Age', 'type': 'integer'}}, 'required': ['age', 'name'], 'type': 'object'}}]}}
|
||||
DEBUG:httpcore.connection:connect_tcp.started host='api.openai.com' port=443 local_address=None timeout=5.0 socket_options=None
|
||||
DEBUG:httpcore.connection:connect_tcp.complete return_value=<httpcore._backends.sync.SyncStream object at 0x105062c90>
|
||||
DEBUG:httpcore.connection:start_tls.started ssl_context=<ssl.SSLContext object at 0x100748680> server_hostname='api.openai.com' timeout=5.0
|
||||
DEBUG:httpcore.connection:start_tls.complete return_value=<httpcore._backends.sync.SyncStream object at 0x101caa150>
|
||||
DEBUG:httpcore.http11:send_request_headers.started request=<Request [b'POST']>
|
||||
DEBUG:httpcore.http11:send_request_headers.complete
|
||||
DEBUG:httpcore.http11:send_request_body.started request=<Request [b'POST']>
|
||||
DEBUG:httpcore.http11:send_request_body.complete
|
||||
DEBUG:httpcore.http11:receive_response_headers.started request=<Request [b'POST']>
|
||||
DEBUG:httpcore.http11:receive_response_headers.complete return_value=(b'HTTP/1.1', 200, b'OK', [(b'Date', b'Mon, 12 Feb 2024 14:55:45 GMT'), (b'Content-Type', b'application/json'), (b'Transfer-Encoding', b'chunked'), (b'Connection', b'keep-alive'), (b'access-control-allow-origin', b'*'), (b'Cache-Control', b'no-cache, must-revalidate'), (b'openai-model', b'gpt-3.5-turbo-0613'), (b'openai-organization', b'scribe-ai'), (b'openai-processing-ms', b'483'), (b'openai-version', b'2020-10-01'), (b'strict-transport-security', b'max-age=15724800; includeSubDomains'), (b'x-ratelimit-limit-requests', b'10000'), (b'x-ratelimit-limit-tokens', b'2000000'), (b'x-ratelimit-remaining-requests', b'9999'), (b'x-ratelimit-remaining-tokens', b'1999975'), (b'x-ratelimit-reset-requests', b'6ms'), (b'x-ratelimit-reset-tokens', b'0s'), (b'x-request-id', b'req_f0fa476897ae165fc50fa90b7968595b'), (b'CF-Cache-Status', b'DYNAMIC'), (b'Set-Cookie', b'__cf_bm=e2_yCrwo4frh6Oq4ZufCEhNJ4lSGJ2.MMtk45X8lrMM-1707749745-1-AfWk8CyACc7aZo6GpCI82FBfI/wmPEFZLNO/Cr3eavTW3xKVFCS7G9jvwYTFLXjJr0cttYsXeLAnjwipw18R0Vo=; path=/; expires=Mon, 12-Feb-24 15:25:45 GMT; domain=.api.openai.com; HttpOnly; Secure; SameSite=None'), (b'Set-Cookie', b'_cfuvid=PyVVCGSMxTg1p.woYvHVVC9E3n69faOs5FOxaDdjXOM-1707749745711-0-604800000; path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None'), (b'Server', b'cloudflare'), (b'CF-RAY', b'8545aca30c1fa22f-YYZ'), (b'Content-Encoding', b'gzip'), (b'alt-svc', b'h3=":443"; ma=86400')])
|
||||
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
|
||||
DEBUG:httpcore.http11:receive_response_body.started request=<Request [b'POST']>
|
||||
DEBUG:httpcore.http11:receive_response_body.complete
|
||||
DEBUG:httpcore.http11:response_closed.started
|
||||
DEBUG:httpcore.http11:response_closed.complete
|
||||
DEBUG:openai._base_client:HTTP Request: POST https://api.openai.com/v1/chat/completions "200 OK"
|
||||
...
|
||||
DEBUG:instructor:Instructor Pre-Response: ChatCompletion(id='chatcmpl-8zBxMxsOqm5Sj6yeEI38PnU2r6ncC', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_E1cftF5U0zEjzIbWt3q0ZLbN', function=Function(arguments='{"name":"Jason","age":25}', name='UserDetail'), type='function')]))], created=1709594660, model='gpt-3.5-turbo-0125', object='chat.completion', system_fingerprint='fp_2b778c6b35', usage=CompletionUsage(completion_tokens=9, prompt_tokens=81, total_tokens=90))
|
||||
DEBUG:httpcore.connection:close.started
|
||||
DEBUG:httpcore.connection:close.complete
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field, FieldValidationInfo, model_validator
|
||||
from pydantic import BaseModel, Field, model_validator, ValidationInfo
|
||||
from typing import Generator, List, Tuple
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ class CitationMixin(BaseModel): # type: ignore[misc]
|
||||
)
|
||||
|
||||
@model_validator(mode="after") # type: ignore[misc]
|
||||
def validate_sources(self, info: FieldValidationInfo) -> "CitationMixin":
|
||||
def validate_sources(self, info: ValidationInfo) -> "CitationMixin":
|
||||
"""
|
||||
For each substring_phrase, find the span of the substring_phrase in the context.
|
||||
If the span is not found, remove the substring_phrase from the list.
|
||||
|
||||
@@ -191,7 +191,7 @@ def IterableModel(
|
||||
new_cls = create_model(
|
||||
name,
|
||||
tasks=list_tasks,
|
||||
__base__=(OpenAISchema, IterableBase),
|
||||
__base__=(OpenAISchema, IterableBase), # type: ignore
|
||||
)
|
||||
# set the class constructor BaseModel
|
||||
new_cls.task_type = subtask_class
|
||||
|
||||
@@ -12,7 +12,7 @@ from typing import (
|
||||
get_origin,
|
||||
)
|
||||
from types import UnionType # type: ignore[attr-defined]
|
||||
|
||||
from pydantic import BaseModel
|
||||
from instructor.function_calls import OpenAISchema, Mode, openai_schema
|
||||
from collections.abc import Iterable
|
||||
|
||||
@@ -32,7 +32,7 @@ class ParallelBase:
|
||||
mode: Mode,
|
||||
validation_context: Optional[Any] = None,
|
||||
strict: Optional[bool] = None,
|
||||
) -> Generator[T, None, None]:
|
||||
) -> Generator[BaseModel, None, None]:
|
||||
#! We expect this from the OpenAISchema class, We should address
|
||||
#! this with a protocol or an abstract class... @jxnlco
|
||||
assert mode == Mode.PARALLEL_TOOLS, "Mode must be PARALLEL_TOOLS"
|
||||
@@ -44,7 +44,7 @@ class ParallelBase:
|
||||
)
|
||||
|
||||
|
||||
def get_types_array(typehint: Type[Iterable[Union[T]]]) -> Tuple[Type[T], ...]:
|
||||
def get_types_array(typehint: Type[Iterable[T]]) -> Tuple[Type[T], ...]:
|
||||
should_be_iterable = get_origin(typehint)
|
||||
if should_be_iterable is not Iterable:
|
||||
raise TypeError(f"Model should be with Iterable instead if {typehint}")
|
||||
@@ -63,7 +63,7 @@ def get_types_array(typehint: Type[Iterable[Union[T]]]) -> Tuple[Type[T], ...]:
|
||||
return get_args(typehint)
|
||||
|
||||
|
||||
def handle_parallel_model(typehint: Type[Iterable[Union[T]]]) -> List[Dict[str, Any]]:
|
||||
def handle_parallel_model(typehint: Type[Iterable[T]]) -> List[Dict[str, Any]]:
|
||||
the_types = get_types_array(typehint)
|
||||
return [
|
||||
{"type": "function", "function": openai_schema(model).openai_schema}
|
||||
@@ -71,6 +71,6 @@ def handle_parallel_model(typehint: Type[Iterable[Union[T]]]) -> List[Dict[str,
|
||||
]
|
||||
|
||||
|
||||
def ParallelModel(typehint: Type[Iterable[Union[T]]]) -> ParallelBase:
|
||||
def ParallelModel(typehint: Type[Iterable[T]]) -> ParallelBase:
|
||||
the_types = get_types_array(typehint)
|
||||
return ParallelBase(*[model for model in the_types])
|
||||
|
||||
+18
-16
@@ -19,6 +19,7 @@ from typing import (
|
||||
NoReturn,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Type,
|
||||
)
|
||||
from copy import deepcopy
|
||||
|
||||
@@ -26,29 +27,28 @@ from instructor.function_calls import Mode
|
||||
from instructor.dsl.partialjson import JSONParser
|
||||
|
||||
parser = JSONParser()
|
||||
|
||||
Model = TypeVar("Model", bound=BaseModel)
|
||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||
|
||||
|
||||
class PartialBase:
|
||||
class PartialBase(Generic[T_Model]):
|
||||
@classmethod
|
||||
def from_streaming_response(
|
||||
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
|
||||
) -> Generator[Model, None, None]:
|
||||
) -> Generator[T_Model, None, None]:
|
||||
json_chunks = cls.extract_json(completion, mode)
|
||||
yield from cls.model_from_chunks(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def from_streaming_response_async(
|
||||
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
|
||||
) -> AsyncGenerator[Model, None]:
|
||||
) -> AsyncGenerator[T_Model, None]:
|
||||
json_chunks = cls.extract_json_async(completion, mode)
|
||||
return cls.model_from_chunks_async(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def model_from_chunks(
|
||||
cls, json_chunks: Iterable[Any], **kwargs: Any
|
||||
) -> Generator[Model, None, None]:
|
||||
) -> Generator[T_Model, None, None]:
|
||||
prev_obj = None
|
||||
potential_object = ""
|
||||
for chunk in json_chunks:
|
||||
@@ -70,7 +70,7 @@ class PartialBase:
|
||||
@classmethod
|
||||
async def model_from_chunks_async(
|
||||
cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
|
||||
) -> AsyncGenerator[Model, None]:
|
||||
) -> AsyncGenerator[T_Model, None]:
|
||||
potential_object = ""
|
||||
prev_obj = None
|
||||
async for chunk in json_chunks:
|
||||
@@ -136,7 +136,7 @@ class PartialBase:
|
||||
pass
|
||||
|
||||
|
||||
class Partial(Generic[Model]):
|
||||
class Partial(Generic[T_Model]):
|
||||
"""Generate a new class with all attributes optionals.
|
||||
|
||||
Notes:
|
||||
@@ -151,7 +151,7 @@ class Partial(Generic[Model]):
|
||||
cls,
|
||||
*args: object, # noqa :ARG003
|
||||
**kwargs: object, # noqa :ARG003
|
||||
) -> "Partial[Model]":
|
||||
) -> "Partial[T_Model]":
|
||||
"""Cannot instantiate.
|
||||
|
||||
Raises:
|
||||
@@ -173,8 +173,8 @@ class Partial(Generic[Model]):
|
||||
|
||||
def __class_getitem__( # type: ignore[override]
|
||||
cls,
|
||||
wrapped_class: type[Model],
|
||||
) -> type[Model]:
|
||||
wrapped_class: type[T_Model],
|
||||
) -> type[T_Model]:
|
||||
"""Convert model to a partial model with all fields being optionals."""
|
||||
|
||||
def _make_field_optional(
|
||||
@@ -199,7 +199,9 @@ class Partial(Generic[Model]):
|
||||
)
|
||||
|
||||
# Reconstruct the generic type with modified arguments
|
||||
tmp_field.annotation = Optional[generic_base[modified_args]]
|
||||
tmp_field.annotation = (
|
||||
Optional[generic_base[modified_args]] if generic_base else None
|
||||
)
|
||||
tmp_field.default = None
|
||||
# If the field is a BaseModel, then recursively convert its
|
||||
# attributes to optionals.
|
||||
@@ -211,12 +213,12 @@ class Partial(Generic[Model]):
|
||||
tmp_field.default = None
|
||||
return tmp_field.annotation, tmp_field
|
||||
|
||||
return create_model( # type: ignore[no-any-return, call-overload]
|
||||
f"Partial{wrapped_class.__name__}",
|
||||
return create_model(
|
||||
__model_name=f"Partial{wrapped_class.__name__}",
|
||||
__base__=(wrapped_class, PartialBase),
|
||||
__module__=wrapped_class.__module__,
|
||||
**{
|
||||
field_name: _make_field_optional(field_info)
|
||||
for field_name, field_info in wrapped_class.model_fields.items()
|
||||
for field_name, field_info in wrapped_class.__fields__.items()
|
||||
},
|
||||
)
|
||||
) # type: ignore[all]
|
||||
|
||||
@@ -139,7 +139,7 @@ class JSONParser:
|
||||
if "." in num_str or "e" in num_str or "E" in num_str
|
||||
else int(num_str)
|
||||
)
|
||||
except ValueError as e:
|
||||
except json.JSONDecodeError as e:
|
||||
raise e
|
||||
return num, s
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ def llm_validator(
|
||||
allow_override: bool = False,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0,
|
||||
openai_client: OpenAI = None,
|
||||
openai_client: Optional[OpenAI] = None,
|
||||
) -> Callable[[str], str]:
|
||||
"""
|
||||
Create a validator that uses the LLM to validate an attribute
|
||||
@@ -85,7 +85,7 @@ def llm_validator(
|
||||
],
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
)
|
||||
) # type: ignore[all]
|
||||
|
||||
# If the response is not valid, return the reason, this could be used in
|
||||
# the future to generate a better response, via reasking mechanism.
|
||||
|
||||
@@ -5,20 +5,24 @@ from pydantic import BaseModel, create_model
|
||||
from instructor.exceptions import IncompleteOutputException
|
||||
import enum
|
||||
import warnings
|
||||
import logging
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger("instructor")
|
||||
|
||||
|
||||
class Mode(enum.Enum):
|
||||
"""The mode to use for patching the client"""
|
||||
|
||||
FUNCTIONS: str = "function_call"
|
||||
PARALLEL_TOOLS: str = "parallel_tool_call"
|
||||
TOOLS: str = "tool_call"
|
||||
MISTRAL_TOOLS: str = "mistral_tools"
|
||||
JSON: str = "json_mode"
|
||||
MD_JSON: str = "markdown_json_mode"
|
||||
JSON_SCHEMA: str = "json_schema_mode"
|
||||
FUNCTIONS = "function_call"
|
||||
PARALLEL_TOOLS = "parallel_tool_call"
|
||||
TOOLS = "tool_call"
|
||||
MISTRAL_TOOLS = "mistral_tools"
|
||||
JSON = "json_mode"
|
||||
MD_JSON = "markdown_json_mode"
|
||||
JSON_SCHEMA = "json_schema_mode"
|
||||
|
||||
def __new__(cls, value: str) -> "Mode":
|
||||
member = object.__new__(cls)
|
||||
@@ -82,11 +86,11 @@ class OpenAISchema(BaseModel): # type: ignore[misc]
|
||||
@classmethod
|
||||
def from_response(
|
||||
cls,
|
||||
completion: T,
|
||||
completion: ChatCompletion,
|
||||
validation_context: Optional[Dict[str, Any]] = None,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
) -> Dict[str, Any]:
|
||||
) -> BaseModel:
|
||||
"""Execute the function from the response of an openai chat completion
|
||||
|
||||
Parameters:
|
||||
@@ -102,41 +106,46 @@ class OpenAISchema(BaseModel): # type: ignore[misc]
|
||||
assert hasattr(completion, "choices")
|
||||
|
||||
if completion.choices[0].finish_reason == "length":
|
||||
logger.error("Incomplete output detected, should increase max_tokens")
|
||||
raise IncompleteOutputException()
|
||||
|
||||
# If Anthropic, this should be different
|
||||
message = completion.choices[0].message
|
||||
|
||||
if mode == Mode.FUNCTIONS:
|
||||
assert (
|
||||
message.function_call.name == cls.openai_schema["name"] # type: ignore[index]
|
||||
), "Function name does not match"
|
||||
return cls.model_validate_json(
|
||||
message.function_call.arguments,
|
||||
model_response = cls.model_validate_json(
|
||||
message.function_call.arguments, # type: ignore[attr-defined]
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
elif mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}:
|
||||
assert (
|
||||
len(message.tool_calls) == 1
|
||||
len(message.tool_calls or []) == 1
|
||||
), "Instructor does not support multiple tool calls, use List[Model] instead."
|
||||
tool_call = message.tool_calls[0]
|
||||
tool_call = message.tool_calls[0] # type: ignore
|
||||
assert (
|
||||
tool_call.function.name == cls.openai_schema["name"] # type: ignore[index]
|
||||
), "Tool name does not match"
|
||||
return cls.model_validate_json(
|
||||
model_response = cls.model_validate_json(
|
||||
tool_call.function.arguments,
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
|
||||
return cls.model_validate_json(
|
||||
message.content,
|
||||
model_response = cls.model_validate_json(
|
||||
message.content, # type: ignore
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
# TODO: add logging or response handler
|
||||
return model_response
|
||||
|
||||
|
||||
def openai_schema(cls: Type[BaseModel]) -> OpenAISchema:
|
||||
if not issubclass(cls, BaseModel):
|
||||
@@ -147,4 +156,4 @@ def openai_schema(cls: Type[BaseModel]) -> OpenAISchema:
|
||||
cls.__name__,
|
||||
__base__=(cls, OpenAISchema),
|
||||
)
|
||||
)
|
||||
) # type: ignore[all]
|
||||
|
||||
+49
-31
@@ -1,3 +1,4 @@
|
||||
# type: ignore[all]
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
@@ -8,6 +9,7 @@ from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError
|
||||
from json import JSONDecodeError
|
||||
from typing import (
|
||||
Callable,
|
||||
Generator,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
@@ -45,6 +47,15 @@ T_ParamSpec = ParamSpec("T_ParamSpec")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def update_total_usage(response, total_usage):
    """Fold one response's token usage into a running total.

    When ``response`` is a ``ChatCompletion`` that carries a ``usage``
    object, its completion/prompt/total token counts (missing counts are
    treated as 0) are added onto ``total_usage`` in place, and the
    response's ``usage`` attribute is then pointed at that accumulated
    object. Any other response is passed through untouched.

    Args:
        response: The object returned by the wrapped completion call.
        total_usage: Mutable usage accumulator exposing
            ``completion_tokens``, ``prompt_tokens`` and ``total_tokens``.

    Returns:
        The same ``response`` object that was passed in.
    """
    # Guard clauses: only ChatCompletion objects with usage are counted.
    if not isinstance(response, ChatCompletion):
        return response

    if response.usage is None:
        return response

    total_usage.completion_tokens += response.usage.completion_tokens or 0
    total_usage.prompt_tokens += response.usage.prompt_tokens or 0
    total_usage.total_tokens += response.usage.total_tokens or 0

    # The response now reports the cumulative usage rather than its own.
    response.usage = total_usage
    return response
|
||||
|
||||
|
||||
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
"""Dumps a message to a dict, to be returned to the OpenAI API.
|
||||
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
|
||||
@@ -56,7 +67,11 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
}
|
||||
if hasattr(message, "tool_calls") and message.tool_calls is not None:
|
||||
ret["tool_calls"] = message.model_dump()["tool_calls"]
|
||||
if hasattr(message, "function_call") and message.function_call is not None:
|
||||
if (
|
||||
hasattr(message, "function_call")
|
||||
and message.function_call is not None
|
||||
and ret["content"]
|
||||
):
|
||||
ret["content"] += json.dumps(message.model_dump()["function_call"])
|
||||
return ret
|
||||
|
||||
@@ -177,18 +192,29 @@ def handle_response_model(
|
||||
new_kwargs["messages"][0]["content"] += f"\n\n{message}"
|
||||
else:
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
logger.debug(
|
||||
f"Instructor Request: {mode.value=}, {response_model=}, {new_kwargs=}",
|
||||
extra={
|
||||
"mode": mode.value,
|
||||
"response_model": response_model.__name__
|
||||
if response_model is not None
|
||||
else None,
|
||||
"new_kwargs": new_kwargs,
|
||||
},
|
||||
)
|
||||
return response_model, new_kwargs
|
||||
|
||||
|
||||
def process_response(
|
||||
response: T,
|
||||
response: T_Model,
|
||||
*,
|
||||
response_model: Type[T_Model],
|
||||
response_model: Type[OpenAISchema | BaseModel],
|
||||
stream: bool,
|
||||
validation_context: dict = None,
|
||||
validation_context: Optional[dict] = None,
|
||||
strict=None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
) -> Union[T_Model, T]:
|
||||
) -> T_Model | Generator[T_Model, None, None]:
|
||||
"""Processes a OpenAI response with the response model, if available.
|
||||
|
||||
Args:
|
||||
@@ -202,7 +228,13 @@ def process_response(
|
||||
Returns:
|
||||
Union[T_Model, T]: The parsed response, if a response model is available, otherwise the response as is from the SDK
|
||||
"""
|
||||
|
||||
logger.debug(
|
||||
f"Instructor Raw Response: {response}",
|
||||
)
|
||||
|
||||
if response_model is None:
|
||||
logger.debug("No response model, returning response as is")
|
||||
return response
|
||||
|
||||
if (
|
||||
@@ -244,12 +276,12 @@ def process_response(
|
||||
async def process_response_async(
|
||||
response: ChatCompletion,
|
||||
*,
|
||||
response_model: Type[T_Model],
|
||||
response_model: Type[T_Model | OpenAISchema | BaseModel],
|
||||
stream: bool = False,
|
||||
validation_context: dict = None,
|
||||
validation_context: Optional[dict] = None,
|
||||
strict: Optional[bool] = None,
|
||||
mode: Mode = Mode.TOOLS,
|
||||
) -> T:
|
||||
) -> T_Model | ChatCompletion:
|
||||
"""Processes a OpenAI response with the response model, if available.
|
||||
It can use `validation_context` and `strict` to validate the response
|
||||
via the pydantic model
|
||||
@@ -261,6 +293,10 @@ async def process_response_async(
|
||||
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
|
||||
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
|
||||
"""
|
||||
|
||||
logger.debug(
|
||||
f"Instructor Raw Response: {response}",
|
||||
)
|
||||
if response_model is None:
|
||||
return response
|
||||
|
||||
@@ -329,18 +365,9 @@ async def retry_async(
|
||||
logger.debug(f"Retrying, attempt: {attempt}")
|
||||
with attempt:
|
||||
try:
|
||||
response: ChatCompletion = await func(*args, **kwargs)
|
||||
response: ChatCompletion = await func(*args, **kwargs) # type: ignore
|
||||
stream = kwargs.get("stream", False)
|
||||
if (
|
||||
isinstance(response, ChatCompletion)
|
||||
and response.usage is not None
|
||||
):
|
||||
total_usage.completion_tokens += (
|
||||
response.usage.completion_tokens or 0
|
||||
)
|
||||
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
|
||||
total_usage.total_tokens += response.usage.total_tokens or 0
|
||||
response.usage = total_usage # Replace each response usage with the total usage
|
||||
response = update_total_usage(response, total_usage)
|
||||
return await process_response_async(
|
||||
response,
|
||||
response_model=response_model,
|
||||
@@ -348,9 +375,9 @@ async def retry_async(
|
||||
validation_context=validation_context,
|
||||
strict=strict,
|
||||
mode=mode,
|
||||
)
|
||||
) # type: ignore[all]
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
logger.debug(f"Error response: {response}")
|
||||
logger.debug(f"Error response: {response}", e)
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
|
||||
if mode == Mode.TOOLS:
|
||||
kwargs["messages"].append(
|
||||
@@ -413,16 +440,7 @@ def retry_sync(
|
||||
try:
|
||||
response = func(*args, **kwargs)
|
||||
stream = kwargs.get("stream", False)
|
||||
if (
|
||||
isinstance(response, ChatCompletion)
|
||||
and response.usage is not None
|
||||
):
|
||||
total_usage.completion_tokens += (
|
||||
response.usage.completion_tokens or 0
|
||||
)
|
||||
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
|
||||
total_usage.total_tokens += response.usage.total_tokens or 0
|
||||
response.usage = total_usage # Replace each response usage with the total usage
|
||||
response = update_total_usage(response, total_usage)
|
||||
return process_response(
|
||||
response,
|
||||
response_model=response_model,
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "instructor"
|
||||
version = "0.6.2"
|
||||
version = "0.6.3"
|
||||
description = "structured outputs for llm"
|
||||
authors = ["Jason Liu <jason@jxnl.co>"]
|
||||
license = "MIT"
|
||||
|
||||
Reference in New Issue
Block a user