fix: Improve type hinting, update response models handling, add logging, and fix bugs (#484)

This commit is contained in:
Jason Liu
2024-03-04 18:27:49 -05:00
committed by GitHub
parent 63fe8a365a
commit a6803a28f4
10 changed files with 110 additions and 96 deletions
+5 -20
View File
@@ -28,27 +28,12 @@ user = client.chat.completions.create(
) # type: ignore
"""
DEBUG:httpx:load_ssl_context verify=True cert=None trust_env=True http2=False
DEBUG:httpx:load_verify_locations cafile='/Users/jasonliu/dev/instructor/.venv/lib/python3.11/site-packages/certifi/cacert.pem'
DEBUG:instructor:Patching `client.chat.completions.create` with mode=<Mode.FUNCTIONS: 'function_call'>
...
DEBUG:instructor:Patching `client.chat.completions.create` with mode=<Mode.TOOLS: 'tool_call'>
DEBUG:instructor:Instructor Request: mode.value='tool_call', response_model=<class '__main__.UserDetail'>, new_kwargs={'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'Extract Jason is 25 years old'}], 'tools': [{'type': 'function', 'function': {'name': 'UserDetail', 'description': 'Correctly extracted `UserDetail` with all the required parameters with correct types', 'parameters': {'properties': {'name': {'title': 'Name', 'type': 'string'}, 'age': {'title': 'Age', 'type': 'integer'}}, 'required': ['age', 'name'], 'type': 'object'}}}], 'tool_choice': {'type': 'function', 'function': {'name': 'UserDetail'}}}
DEBUG:instructor:max_retries: 1
DEBUG:openai._base_client:Request options: {'method': 'post', 'url': '/chat/completions', 'files': None, 'json_data': {'messages': [{'role': 'user', 'content': 'Extract Jason is 25 years old'}], 'model': 'gpt-3.5-turbo', 'function_call': {'name': 'UserDetail'}, 'functions': [{'name': 'UserDetail', 'description': 'Correctly extracted `UserDetail` with all the required parameters with correct types', 'parameters': {'properties': {'name': {'title': 'Name', 'type': 'string'}, 'age': {'title': 'Age', 'type': 'integer'}}, 'required': ['age', 'name'], 'type': 'object'}}]}}
DEBUG:httpcore.connection:connect_tcp.started host='api.openai.com' port=443 local_address=None timeout=5.0 socket_options=None
DEBUG:httpcore.connection:connect_tcp.complete return_value=<httpcore._backends.sync.SyncStream object at 0x105062c90>
DEBUG:httpcore.connection:start_tls.started ssl_context=<ssl.SSLContext object at 0x100748680> server_hostname='api.openai.com' timeout=5.0
DEBUG:httpcore.connection:start_tls.complete return_value=<httpcore._backends.sync.SyncStream object at 0x101caa150>
DEBUG:httpcore.http11:send_request_headers.started request=<Request [b'POST']>
DEBUG:httpcore.http11:send_request_headers.complete
DEBUG:httpcore.http11:send_request_body.started request=<Request [b'POST']>
DEBUG:httpcore.http11:send_request_body.complete
DEBUG:httpcore.http11:receive_response_headers.started request=<Request [b'POST']>
DEBUG:httpcore.http11:receive_response_headers.complete return_value=(b'HTTP/1.1', 200, b'OK', [(b'Date', b'Mon, 12 Feb 2024 14:55:45 GMT'), (b'Content-Type', b'application/json'), (b'Transfer-Encoding', b'chunked'), (b'Connection', b'keep-alive'), (b'access-control-allow-origin', b'*'), (b'Cache-Control', b'no-cache, must-revalidate'), (b'openai-model', b'gpt-3.5-turbo-0613'), (b'openai-organization', b'scribe-ai'), (b'openai-processing-ms', b'483'), (b'openai-version', b'2020-10-01'), (b'strict-transport-security', b'max-age=15724800; includeSubDomains'), (b'x-ratelimit-limit-requests', b'10000'), (b'x-ratelimit-limit-tokens', b'2000000'), (b'x-ratelimit-remaining-requests', b'9999'), (b'x-ratelimit-remaining-tokens', b'1999975'), (b'x-ratelimit-reset-requests', b'6ms'), (b'x-ratelimit-reset-tokens', b'0s'), (b'x-request-id', b'req_f0fa476897ae165fc50fa90b7968595b'), (b'CF-Cache-Status', b'DYNAMIC'), (b'Set-Cookie', b'__cf_bm=e2_yCrwo4frh6Oq4ZufCEhNJ4lSGJ2.MMtk45X8lrMM-1707749745-1-AfWk8CyACc7aZo6GpCI82FBfI/wmPEFZLNO/Cr3eavTW3xKVFCS7G9jvwYTFLXjJr0cttYsXeLAnjwipw18R0Vo=; path=/; expires=Mon, 12-Feb-24 15:25:45 GMT; domain=.api.openai.com; HttpOnly; Secure; SameSite=None'), (b'Set-Cookie', b'_cfuvid=PyVVCGSMxTg1p.woYvHVVC9E3n69faOs5FOxaDdjXOM-1707749745711-0-604800000; path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None'), (b'Server', b'cloudflare'), (b'CF-RAY', b'8545aca30c1fa22f-YYZ'), (b'Content-Encoding', b'gzip'), (b'alt-svc', b'h3=":443"; ma=86400')])
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
DEBUG:httpcore.http11:receive_response_body.started request=<Request [b'POST']>
DEBUG:httpcore.http11:receive_response_body.complete
DEBUG:httpcore.http11:response_closed.started
DEBUG:httpcore.http11:response_closed.complete
DEBUG:openai._base_client:HTTP Request: POST https://api.openai.com/v1/chat/completions "200 OK"
...
DEBUG:instructor:Instructor Pre-Response: ChatCompletion(id='chatcmpl-8zBxMxsOqm5Sj6yeEI38PnU2r6ncC', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_E1cftF5U0zEjzIbWt3q0ZLbN', function=Function(arguments='{"name":"Jason","age":25}', name='UserDetail'), type='function')]))], created=1709594660, model='gpt-3.5-turbo-0125', object='chat.completion', system_fingerprint='fp_2b778c6b35', usage=CompletionUsage(completion_tokens=9, prompt_tokens=81, total_tokens=90))
DEBUG:httpcore.connection:close.started
DEBUG:httpcore.connection:close.complete
"""
+2 -2
View File
@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field, FieldValidationInfo, model_validator
from pydantic import BaseModel, Field, model_validator, ValidationInfo
from typing import Generator, List, Tuple
@@ -58,7 +58,7 @@ class CitationMixin(BaseModel): # type: ignore[misc]
)
@model_validator(mode="after") # type: ignore[misc]
def validate_sources(self, info: FieldValidationInfo) -> "CitationMixin":
def validate_sources(self, info: ValidationInfo) -> "CitationMixin":
"""
For each substring_phrase, find the span of the substring_phrase in the context.
If the span is not found, remove the substring_phrase from the list.
+1 -1
View File
@@ -191,7 +191,7 @@ def IterableModel(
new_cls = create_model(
name,
tasks=list_tasks,
__base__=(OpenAISchema, IterableBase),
__base__=(OpenAISchema, IterableBase), # type: ignore
)
# set the class constructor BaseModel
new_cls.task_type = subtask_class
+5 -5
View File
@@ -12,7 +12,7 @@ from typing import (
get_origin,
)
from types import UnionType # type: ignore[attr-defined]
from pydantic import BaseModel
from instructor.function_calls import OpenAISchema, Mode, openai_schema
from collections.abc import Iterable
@@ -32,7 +32,7 @@ class ParallelBase:
mode: Mode,
validation_context: Optional[Any] = None,
strict: Optional[bool] = None,
) -> Generator[T, None, None]:
) -> Generator[BaseModel, None, None]:
#! We expect this from the OpenAISchema class; we should address
#! this with a protocol or an abstract class... @jxnlco
assert mode == Mode.PARALLEL_TOOLS, "Mode must be PARALLEL_TOOLS"
@@ -44,7 +44,7 @@ class ParallelBase:
)
def get_types_array(typehint: Type[Iterable[Union[T]]]) -> Tuple[Type[T], ...]:
def get_types_array(typehint: Type[Iterable[T]]) -> Tuple[Type[T], ...]:
should_be_iterable = get_origin(typehint)
if should_be_iterable is not Iterable:
raise TypeError(f"Model should be with Iterable instead if {typehint}")
@@ -63,7 +63,7 @@ def get_types_array(typehint: Type[Iterable[Union[T]]]) -> Tuple[Type[T], ...]:
return get_args(typehint)
def handle_parallel_model(typehint: Type[Iterable[Union[T]]]) -> List[Dict[str, Any]]:
def handle_parallel_model(typehint: Type[Iterable[T]]) -> List[Dict[str, Any]]:
the_types = get_types_array(typehint)
return [
{"type": "function", "function": openai_schema(model).openai_schema}
@@ -71,6 +71,6 @@ def handle_parallel_model(typehint: Type[Iterable[Union[T]]]) -> List[Dict[str,
]
def ParallelModel(typehint: Type[Iterable[Union[T]]]) -> ParallelBase:
def ParallelModel(typehint: Type[Iterable[T]]) -> ParallelBase:
the_types = get_types_array(typehint)
return ParallelBase(*[model for model in the_types])
+18 -16
View File
@@ -19,6 +19,7 @@ from typing import (
NoReturn,
Optional,
TypeVar,
Type,
)
from copy import deepcopy
@@ -26,29 +27,28 @@ from instructor.function_calls import Mode
from instructor.dsl.partialjson import JSONParser
parser = JSONParser()
Model = TypeVar("Model", bound=BaseModel)
T_Model = TypeVar("T_Model", bound=BaseModel)
class PartialBase:
class PartialBase(Generic[T_Model]):
@classmethod
def from_streaming_response(
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
) -> Generator[Model, None, None]:
) -> Generator[T_Model, None, None]:
json_chunks = cls.extract_json(completion, mode)
yield from cls.model_from_chunks(json_chunks, **kwargs)
@classmethod
async def from_streaming_response_async(
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
) -> AsyncGenerator[Model, None]:
) -> AsyncGenerator[T_Model, None]:
json_chunks = cls.extract_json_async(completion, mode)
return cls.model_from_chunks_async(json_chunks, **kwargs)
@classmethod
def model_from_chunks(
cls, json_chunks: Iterable[Any], **kwargs: Any
) -> Generator[Model, None, None]:
) -> Generator[T_Model, None, None]:
prev_obj = None
potential_object = ""
for chunk in json_chunks:
@@ -70,7 +70,7 @@ class PartialBase:
@classmethod
async def model_from_chunks_async(
cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
) -> AsyncGenerator[Model, None]:
) -> AsyncGenerator[T_Model, None]:
potential_object = ""
prev_obj = None
async for chunk in json_chunks:
@@ -136,7 +136,7 @@ class PartialBase:
pass
class Partial(Generic[Model]):
class Partial(Generic[T_Model]):
"""Generate a new class with all attributes optionals.
Notes:
@@ -151,7 +151,7 @@ class Partial(Generic[Model]):
cls,
*args: object, # noqa :ARG003
**kwargs: object, # noqa :ARG003
) -> "Partial[Model]":
) -> "Partial[T_Model]":
"""Cannot instantiate.
Raises:
@@ -173,8 +173,8 @@ class Partial(Generic[Model]):
def __class_getitem__( # type: ignore[override]
cls,
wrapped_class: type[Model],
) -> type[Model]:
wrapped_class: type[T_Model],
) -> type[T_Model]:
"""Convert model to a partial model with all fields being optionals."""
def _make_field_optional(
@@ -199,7 +199,9 @@ class Partial(Generic[Model]):
)
# Reconstruct the generic type with modified arguments
tmp_field.annotation = Optional[generic_base[modified_args]]
tmp_field.annotation = (
Optional[generic_base[modified_args]] if generic_base else None
)
tmp_field.default = None
# If the field is a BaseModel, then recursively convert its
# attributes to optionals.
@@ -211,12 +213,12 @@ class Partial(Generic[Model]):
tmp_field.default = None
return tmp_field.annotation, tmp_field
return create_model( # type: ignore[no-any-return, call-overload]
f"Partial{wrapped_class.__name__}",
return create_model(
__model_name=f"Partial{wrapped_class.__name__}",
__base__=(wrapped_class, PartialBase),
__module__=wrapped_class.__module__,
**{
field_name: _make_field_optional(field_info)
for field_name, field_info in wrapped_class.model_fields.items()
for field_name, field_info in wrapped_class.__fields__.items()
},
)
) # type: ignore[all]
+1 -1
View File
@@ -139,7 +139,7 @@ class JSONParser:
if "." in num_str or "e" in num_str or "E" in num_str
else int(num_str)
)
except ValueError as e:
except json.JSONDecodeError as e:
raise e
return num, s
+2 -2
View File
@@ -32,7 +32,7 @@ def llm_validator(
allow_override: bool = False,
model: str = "gpt-3.5-turbo",
temperature: float = 0,
openai_client: OpenAI = None,
openai_client: Optional[OpenAI] = None,
) -> Callable[[str], str]:
"""
Create a validator that uses the LLM to validate an attribute
@@ -85,7 +85,7 @@ def llm_validator(
],
model=model,
temperature=temperature,
)
) # type: ignore[all]
# If the response is not valid, return the reason, this could be used in
# the future to generate a better response, via reasking mechanism.
+26 -17
View File
@@ -5,20 +5,24 @@ from pydantic import BaseModel, create_model
from instructor.exceptions import IncompleteOutputException
import enum
import warnings
import logging
from openai.types.chat import ChatCompletion
T = TypeVar("T")
logger = logging.getLogger("instructor")
class Mode(enum.Enum):
"""The mode to use for patching the client"""
FUNCTIONS: str = "function_call"
PARALLEL_TOOLS: str = "parallel_tool_call"
TOOLS: str = "tool_call"
MISTRAL_TOOLS: str = "mistral_tools"
JSON: str = "json_mode"
MD_JSON: str = "markdown_json_mode"
JSON_SCHEMA: str = "json_schema_mode"
FUNCTIONS = "function_call"
PARALLEL_TOOLS = "parallel_tool_call"
TOOLS = "tool_call"
MISTRAL_TOOLS = "mistral_tools"
JSON = "json_mode"
MD_JSON = "markdown_json_mode"
JSON_SCHEMA = "json_schema_mode"
def __new__(cls, value: str) -> "Mode":
member = object.__new__(cls)
@@ -82,11 +86,11 @@ class OpenAISchema(BaseModel): # type: ignore[misc]
@classmethod
def from_response(
cls,
completion: T,
completion: ChatCompletion,
validation_context: Optional[Dict[str, Any]] = None,
strict: Optional[bool] = None,
mode: Mode = Mode.TOOLS,
) -> Dict[str, Any]:
) -> BaseModel:
"""Execute the function from the response of an openai chat completion
Parameters:
@@ -102,41 +106,46 @@ class OpenAISchema(BaseModel): # type: ignore[misc]
assert hasattr(completion, "choices")
if completion.choices[0].finish_reason == "length":
logger.error("Incomplete output detected, should increase max_tokens")
raise IncompleteOutputException()
# If Anthropic, this should be different
message = completion.choices[0].message
if mode == Mode.FUNCTIONS:
assert (
message.function_call.name == cls.openai_schema["name"] # type: ignore[index]
), "Function name does not match"
return cls.model_validate_json(
message.function_call.arguments,
model_response = cls.model_validate_json(
message.function_call.arguments, # type: ignore[attr-defined]
context=validation_context,
strict=strict,
)
elif mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}:
assert (
len(message.tool_calls) == 1
len(message.tool_calls or []) == 1
), "Instructor does not support multiple tool calls, use List[Model] instead."
tool_call = message.tool_calls[0]
tool_call = message.tool_calls[0] # type: ignore
assert (
tool_call.function.name == cls.openai_schema["name"] # type: ignore[index]
), "Tool name does not match"
return cls.model_validate_json(
model_response = cls.model_validate_json(
tool_call.function.arguments,
context=validation_context,
strict=strict,
)
elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
return cls.model_validate_json(
message.content,
model_response = cls.model_validate_json(
message.content, # type: ignore
context=validation_context,
strict=strict,
)
else:
raise ValueError(f"Invalid patch mode: {mode}")
# TODO: add logging or response handler
return model_response
def openai_schema(cls: Type[BaseModel]) -> OpenAISchema:
if not issubclass(cls, BaseModel):
@@ -147,4 +156,4 @@ def openai_schema(cls: Type[BaseModel]) -> OpenAISchema:
cls.__name__,
__base__=(cls, OpenAISchema),
)
)
) # type: ignore[all]
+49 -31
View File
@@ -1,3 +1,4 @@
# type: ignore[all]
import inspect
import json
import logging
@@ -8,6 +9,7 @@ from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError
from json import JSONDecodeError
from typing import (
Callable,
Generator,
Optional,
ParamSpec,
Protocol,
@@ -45,6 +47,15 @@ T_ParamSpec = ParamSpec("T_ParamSpec")
T = TypeVar("T")
def update_total_usage(response, total_usage):
    """Accumulate a response's token usage into ``total_usage``.

    When ``response`` is a ``ChatCompletion`` that carries usage data, its
    completion, prompt and total token counts (missing/falsy counts are
    treated as 0) are added to ``total_usage`` in place, and the response's
    ``usage`` attribute is then pointed at that same ``total_usage`` object.
    Any other response is passed through untouched.

    Args:
        response: The object returned by the wrapped completion call.
        total_usage: Running usage accumulator; mutated in place.

    Returns:
        The same ``response`` object that was passed in.
    """
    # Nothing to accumulate for non-ChatCompletion responses or when the
    # API did not report usage.
    if not isinstance(response, ChatCompletion) or response.usage is None:
        return response

    reported = response.usage
    total_usage.completion_tokens += reported.completion_tokens or 0
    total_usage.prompt_tokens += reported.prompt_tokens or 0
    total_usage.total_tokens += reported.total_tokens or 0

    # Replace each response usage with the total usage
    response.usage = total_usage
    return response
def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
"""Dumps a message to a dict, to be returned to the OpenAI API.
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
@@ -56,7 +67,11 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
}
if hasattr(message, "tool_calls") and message.tool_calls is not None:
ret["tool_calls"] = message.model_dump()["tool_calls"]
if hasattr(message, "function_call") and message.function_call is not None:
if (
hasattr(message, "function_call")
and message.function_call is not None
and ret["content"]
):
ret["content"] += json.dumps(message.model_dump()["function_call"])
return ret
@@ -177,18 +192,29 @@ def handle_response_model(
new_kwargs["messages"][0]["content"] += f"\n\n{message}"
else:
raise ValueError(f"Invalid patch mode: {mode}")
logger.debug(
f"Instructor Request: {mode.value=}, {response_model=}, {new_kwargs=}",
extra={
"mode": mode.value,
"response_model": response_model.__name__
if response_model is not None
else None,
"new_kwargs": new_kwargs,
},
)
return response_model, new_kwargs
def process_response(
response: T,
response: T_Model,
*,
response_model: Type[T_Model],
response_model: Type[OpenAISchema | BaseModel],
stream: bool,
validation_context: dict = None,
validation_context: Optional[dict] = None,
strict=None,
mode: Mode = Mode.TOOLS,
) -> Union[T_Model, T]:
) -> T_Model | Generator[T_Model, None, None]:
"""Processes a OpenAI response with the response model, if available.
Args:
@@ -202,7 +228,13 @@ def process_response(
Returns:
Union[T_Model, T]: The parsed response, if a response model is available, otherwise the response as is from the SDK
"""
logger.debug(
f"Instructor Raw Response: {response}",
)
if response_model is None:
logger.debug("No response model, returning response as is")
return response
if (
@@ -244,12 +276,12 @@ def process_response(
async def process_response_async(
response: ChatCompletion,
*,
response_model: Type[T_Model],
response_model: Type[T_Model | OpenAISchema | BaseModel],
stream: bool = False,
validation_context: dict = None,
validation_context: Optional[dict] = None,
strict: Optional[bool] = None,
mode: Mode = Mode.TOOLS,
) -> T:
) -> T_Model | ChatCompletion:
"""Processes a OpenAI response with the response model, if available.
It can use `validation_context` and `strict` to validate the response
via the pydantic model
@@ -261,6 +293,10 @@ async def process_response_async(
validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
strict (bool, optional): Whether to use strict json parsing. Defaults to None.
"""
logger.debug(
f"Instructor Raw Response: {response}",
)
if response_model is None:
return response
@@ -329,18 +365,9 @@ async def retry_async(
logger.debug(f"Retrying, attempt: {attempt}")
with attempt:
try:
response: ChatCompletion = await func(*args, **kwargs)
response: ChatCompletion = await func(*args, **kwargs) # type: ignore
stream = kwargs.get("stream", False)
if (
isinstance(response, ChatCompletion)
and response.usage is not None
):
total_usage.completion_tokens += (
response.usage.completion_tokens or 0
)
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
total_usage.total_tokens += response.usage.total_tokens or 0
response.usage = total_usage # Replace each response usage with the total usage
response = update_total_usage(response, total_usage)
return await process_response_async(
response,
response_model=response_model,
@@ -348,9 +375,9 @@ async def retry_async(
validation_context=validation_context,
strict=strict,
mode=mode,
)
) # type: ignore[all]
except (ValidationError, JSONDecodeError) as e:
logger.debug(f"Error response: {response}")
logger.debug(f"Error response: {response}", e)
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
if mode == Mode.TOOLS:
kwargs["messages"].append(
@@ -413,16 +440,7 @@ def retry_sync(
try:
response = func(*args, **kwargs)
stream = kwargs.get("stream", False)
if (
isinstance(response, ChatCompletion)
and response.usage is not None
):
total_usage.completion_tokens += (
response.usage.completion_tokens or 0
)
total_usage.prompt_tokens += response.usage.prompt_tokens or 0
total_usage.total_tokens += response.usage.total_tokens or 0
response.usage = total_usage # Replace each response usage with the total usage
response = update_total_usage(response, total_usage)
return process_response(
response,
response_model=response_model,
+1 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "instructor"
version = "0.6.2"
version = "0.6.3"
description = "structured outputs for llm"
authors = ["Jason Liu <jason@jxnl.co>"]
license = "MIT"