mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
refactor(OpenAISchema): improve from_response readability and update tests (#543)
This commit is contained in:
@@ -4,9 +4,11 @@ from docstring_parser import parse
|
||||
from functools import wraps
|
||||
from pydantic import BaseModel, create_model
|
||||
from openai.types.chat import ChatCompletion
|
||||
from instructor.exceptions import IncompleteOutputException
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from instructor.mode import Mode
|
||||
from instructor.utils import extract_json_from_codeblock
|
||||
from instructor.exceptions import IncompleteOutputException
|
||||
from instructor.mode import Mode
|
||||
import logging
|
||||
|
||||
|
||||
@@ -90,66 +92,106 @@ class OpenAISchema(BaseModel): # type: ignore[misc]
|
||||
cls (OpenAISchema): An instance of the class
|
||||
"""
|
||||
if mode == Mode.ANTHROPIC_TOOLS:
|
||||
try:
|
||||
from instructor.anthropic_utils import extract_xml, xml_to_model
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please 'pip install anthropic' package to proceed."
|
||||
) from err
|
||||
assert hasattr(completion, "content")
|
||||
return xml_to_model(cls, extract_xml(completion.content[0].text)) # type:ignore
|
||||
return cls.parse_anthropic_tools(completion)
|
||||
|
||||
if mode == Mode.ANTHROPIC_JSON:
|
||||
assert hasattr(completion, "content")
|
||||
text = completion.content[0].text # type: ignore
|
||||
extra_text = extract_json_from_codeblock(text)
|
||||
return cls.model_validate_json(extra_text)
|
||||
|
||||
assert hasattr(completion, "choices"), "No choices in completion"
|
||||
return cls.parse_anthropic_json(completion, validation_context, strict)
|
||||
|
||||
if completion.choices[0].finish_reason == "length":
|
||||
logger.error("Incomplete output detected, should increase max_tokens")
|
||||
raise IncompleteOutputException()
|
||||
|
||||
# If Anthropic, this should be different
|
||||
message = completion.choices[0].message
|
||||
|
||||
if mode == Mode.FUNCTIONS:
|
||||
assert (
|
||||
message.function_call.name == cls.openai_schema["name"] # type: ignore[index]
|
||||
), "Function name does not match"
|
||||
model_response = cls.model_validate_json(
|
||||
message.function_call.arguments, # type: ignore[attr-defined]
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
elif mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}:
|
||||
assert (
|
||||
len(message.tool_calls or []) == 1
|
||||
), "Instructor does not support multiple tool calls, use List[Model] instead."
|
||||
tool_call = message.tool_calls[0] # type: ignore
|
||||
assert (
|
||||
tool_call.function.name == cls.openai_schema["name"] # type: ignore[index]
|
||||
), "Tool name does not match"
|
||||
model_response = cls.model_validate_json(
|
||||
tool_call.function.arguments,
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
|
||||
if mode == Mode.MD_JSON:
|
||||
message.content = extract_json_from_codeblock(message.content or "")
|
||||
return cls.parse_functions(completion, validation_context, strict)
|
||||
|
||||
model_response = cls.model_validate_json(
|
||||
message.content, # type: ignore
|
||||
context=validation_context,
|
||||
strict=strict,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
if mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}:
|
||||
return cls.parse_tools(completion, validation_context, strict)
|
||||
|
||||
# TODO: add logging or response handler
|
||||
return model_response
|
||||
if mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
|
||||
return cls.parse_json(completion, validation_context, strict)
|
||||
|
||||
raise ValueError(f"Invalid patch mode: {mode}")
|
||||
|
||||
@classmethod
def parse_anthropic_tools(
    cls: Type[BaseModel],
    completion: ChatCompletion,
) -> BaseModel:
    """Build an instance of ``cls`` from an Anthropic tool-style completion.

    The text of the first content item of ``completion`` is passed through
    ``extract_xml`` and the result is handed to ``xml_to_model`` together
    with ``cls``.

    Args:
        completion: Response object; it must expose a ``content`` attribute.

    Returns:
        Whatever ``xml_to_model`` produces for ``cls``.

    Raises:
        ImportError: If the helpers cannot be imported from
            ``instructor.anthropic_utils``.
        AssertionError: If ``completion`` has no ``content`` attribute.
    """
    # Imported lazily so the helpers are only required when this parser runs.
    try:
        from instructor.anthropic_utils import extract_xml, xml_to_model
    except ImportError as err:
        raise ImportError(
            "Please 'pip install anthropic xmltodict' package to proceed."
        ) from err

    assert hasattr(completion, "content")
    first_item_text = completion.content[0].text  # type: ignore
    xml_payload = extract_xml(first_item_text)
    return xml_to_model(cls, xml_payload)  # type: ignore
|
||||
|
||||
@classmethod
def parse_anthropic_json(
    cls: Type[BaseModel],
    completion: ChatCompletion,
    validation_context: Optional[Dict[str, Any]] = None,
    strict: Optional[bool] = None,
) -> BaseModel:
    """Validate the JSON carried by an Anthropic JSON-mode completion.

    The text of the first content item is run through
    ``extract_json_from_codeblock`` and the result is validated with
    ``cls.model_validate_json``.

    Args:
        completion: Response object; it must expose a ``content`` attribute.
        validation_context: Forwarded to ``model_validate_json`` as ``context``.
        strict: Forwarded to ``model_validate_json`` as ``strict``.

    Returns:
        The validated instance of ``cls``.

    Raises:
        AssertionError: If ``completion`` has no ``content`` attribute.
    """
    assert hasattr(completion, "content")
    json_payload = extract_json_from_codeblock(
        completion.content[0].text  # type: ignore
    )
    validated = cls.model_validate_json(
        json_payload,
        context=validation_context,
        strict=strict,
    )
    return validated
|
||||
|
||||
@classmethod
def parse_functions(
    cls: Type[BaseModel],
    completion: ChatCompletion,
    validation_context: Optional[Dict[str, Any]] = None,
    strict: Optional[bool] = None,
) -> BaseModel:
    """Validate the function-call arguments of the first choice against ``cls``.

    Args:
        completion: Chat completion whose first choice carries a
            ``function_call`` on its message.
        validation_context: Forwarded to ``model_validate_json`` as ``context``.
        strict: Forwarded to ``model_validate_json`` as ``strict``.

    Returns:
        The instance of ``cls`` validated from the function-call arguments.

    Raises:
        AssertionError: If the returned function name differs from
            ``cls.openai_schema["name"]``.
    """
    call = completion.choices[0].message.function_call
    # Both operands stay inside the assert so nothing extra runs under ``-O``.
    assert call.name == cls.openai_schema["name"], "Function name does not match"  # type: ignore
    return cls.model_validate_json(
        call.arguments,  # type: ignore[attr-defined]
        context=validation_context,
        strict=strict,
    )
|
||||
|
||||
@classmethod
def parse_tools(
    cls: Type[BaseModel],
    completion: ChatCompletion,
    validation_context: Optional[Dict[str, Any]] = None,
    strict: Optional[bool] = None,
) -> BaseModel:
    """Validate the arguments of the single tool call in the first choice.

    Args:
        completion: Chat completion whose first choice carries ``tool_calls``
            on its message.
        validation_context: Forwarded to ``model_validate_json`` as ``context``.
        strict: Forwarded to ``model_validate_json`` as ``strict``.

    Returns:
        The instance of ``cls`` validated from the tool call's arguments.

    Raises:
        AssertionError: If the message does not contain exactly one tool call
            (a missing or empty list fails as well), or if the tool's function
            name differs from ``cls.openai_schema["name"]``.
    """
    calls = completion.choices[0].message.tool_calls
    assert len(calls or []) == 1, (
        "Instructor does not support multiple tool calls, use List[Model] instead."
    )
    called_function = calls[0].function  # type: ignore
    assert called_function.name == cls.openai_schema["name"], "Tool name does not match"  # type: ignore
    return cls.model_validate_json(
        called_function.arguments,
        context=validation_context,
        strict=strict,
    )
|
||||
|
||||
@classmethod
def parse_json(
    cls: Type[BaseModel],
    completion: ChatCompletion,
    validation_context: Optional[Dict[str, Any]] = None,
    strict: Optional[bool] = None,
) -> BaseModel:
    """Validate the first choice's message content as JSON for ``cls``.

    The content (an empty string when it is falsy) is passed through
    ``extract_json_from_codeblock`` before validation.

    Args:
        completion: Chat completion whose first choice holds the content.
        validation_context: Forwarded to ``model_validate_json`` as ``context``.
        strict: Forwarded to ``model_validate_json`` as ``strict``.

    Returns:
        The validated instance of ``cls``.
    """
    raw_content = completion.choices[0].message.content or ""
    json_payload = extract_json_from_codeblock(raw_content)
    return cls.model_validate_json(
        json_payload, context=validation_context, strict=strict
    )
|
||||
|
||||
|
||||
def openai_schema(cls: Type[BaseModel]) -> OpenAISchema:
|
||||
|
||||
@@ -2,5 +2,6 @@ import instructor
|
||||
|
||||
# Model names listed for the tests.
models = [
    "gpt-4-turbo-preview",
]
# Instructor modes listed for the tests.
modes = [instructor.Mode.TOOLS, instructor.Mode.MD_JSON]
|
||||
|
||||
@@ -117,8 +117,8 @@ def test_complete_output_no_exception(
|
||||
[{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}],
|
||||
indirect=True,
|
||||
) # type: ignore[misc]
|
||||
async def test_incomplete_output_exception_raise(
|
||||
def test_incomplete_output_exception_raise(
|
||||
test_model: Type[OpenAISchema], mock_completion: ChatCompletion
|
||||
) -> None:
|
||||
with pytest.raises(IncompleteOutputException):
|
||||
await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
|
||||
test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
|
||||
|
||||
Reference in New Issue
Block a user