refactor(OpenAISchema): improve from_response readability and update tests (#543)

This commit is contained in:
Jason Liu
2024-03-29 08:14:26 -04:00
committed by GitHub
parent 5c2496c946
commit b06eb4faa3
3 changed files with 97 additions and 54 deletions
+94 -52
View File
@@ -4,9 +4,11 @@ from docstring_parser import parse
from functools import wraps
from pydantic import BaseModel, create_model
from openai.types.chat import ChatCompletion
from instructor.exceptions import IncompleteOutputException
from typing import Any, Dict, Optional, Type
from instructor.mode import Mode
from instructor.utils import extract_json_from_codeblock
from instructor.exceptions import IncompleteOutputException
from instructor.mode import Mode
import logging
@@ -90,66 +92,106 @@ class OpenAISchema(BaseModel): # type: ignore[misc]
cls (OpenAISchema): An instance of the class
"""
if mode == Mode.ANTHROPIC_TOOLS:
try:
from instructor.anthropic_utils import extract_xml, xml_to_model
except ImportError as err:
raise ImportError(
"Please 'pip install anthropic' package to proceed."
) from err
assert hasattr(completion, "content")
return xml_to_model(cls, extract_xml(completion.content[0].text)) # type:ignore
return cls.parse_anthropic_tools(completion)
if mode == Mode.ANTHROPIC_JSON:
assert hasattr(completion, "content")
text = completion.content[0].text # type: ignore
extra_text = extract_json_from_codeblock(text)
return cls.model_validate_json(extra_text)
assert hasattr(completion, "choices"), "No choices in completion"
return cls.parse_anthropic_json(completion, validation_context, strict)
if completion.choices[0].finish_reason == "length":
logger.error("Incomplete output detected, should increase max_tokens")
raise IncompleteOutputException()
# If Anthropic, this should be different
message = completion.choices[0].message
if mode == Mode.FUNCTIONS:
assert (
message.function_call.name == cls.openai_schema["name"] # type: ignore[index]
), "Function name does not match"
model_response = cls.model_validate_json(
message.function_call.arguments, # type: ignore[attr-defined]
context=validation_context,
strict=strict,
)
elif mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}:
assert (
len(message.tool_calls or []) == 1
), "Instructor does not support multiple tool calls, use List[Model] instead."
tool_call = message.tool_calls[0] # type: ignore
assert (
tool_call.function.name == cls.openai_schema["name"] # type: ignore[index]
), "Tool name does not match"
model_response = cls.model_validate_json(
tool_call.function.arguments,
context=validation_context,
strict=strict,
)
elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
if mode == Mode.MD_JSON:
message.content = extract_json_from_codeblock(message.content or "")
return cls.parse_functions(completion, validation_context, strict)
model_response = cls.model_validate_json(
message.content, # type: ignore
context=validation_context,
strict=strict,
)
else:
raise ValueError(f"Invalid patch mode: {mode}")
if mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}:
return cls.parse_tools(completion, validation_context, strict)
# TODO: add logging or response handler
return model_response
if mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
return cls.parse_json(completion, validation_context, strict)
raise ValueError(f"Invalid patch mode: {mode}")
@classmethod
def parse_anthropic_tools(
    cls: Type[BaseModel],
    completion: ChatCompletion,
) -> BaseModel:
    """Build a model from the first content block of an Anthropic tools reply.

    The text of ``completion.content[0]`` is handed to ``extract_xml`` and the
    result is converted with ``xml_to_model`` for this class.

    Args:
        completion: Response object; it must expose a ``content`` attribute.

    Returns:
        Whatever ``xml_to_model(cls, ...)`` produces for the extracted XML.

    Raises:
        ImportError: If ``instructor.anthropic_utils`` (or the two helpers
            taken from it) cannot be imported.
        AssertionError: If ``completion`` has no ``content`` attribute.
    """
    # Imported lazily so the Anthropic extras are only needed for this mode.
    try:
        from instructor.anthropic_utils import extract_xml, xml_to_model
    except ImportError as err:
        raise ImportError(
            "Please 'pip install anthropic xmltodict' package to proceed."
        ) from err

    assert hasattr(completion, "content")
    first_block_text = completion.content[0].text  # type: ignore
    xml_payload = extract_xml(first_block_text)
    return xml_to_model(cls, xml_payload)  # type: ignore
@classmethod
def parse_anthropic_json(
    cls: Type[BaseModel],
    completion: ChatCompletion,
    validation_context: Optional[Dict[str, Any]] = None,
    strict: Optional[bool] = None,
) -> BaseModel:
    """Validate the first content block of an Anthropic JSON reply.

    The text of ``completion.content[0]`` is passed through
    ``extract_json_from_codeblock`` and the result is validated with
    ``cls.model_validate_json``.

    Args:
        completion: Response object; it must expose a ``content`` attribute.
        validation_context: Forwarded as ``context`` to the validator.
        strict: Forwarded as ``strict`` to the validator.

    Returns:
        The validated model instance.

    Raises:
        AssertionError: If ``completion`` has no ``content`` attribute.
    """
    assert hasattr(completion, "content")
    first_block = completion.content[0]  # type: ignore
    json_payload = extract_json_from_codeblock(first_block.text)
    return cls.model_validate_json(
        json_payload,
        context=validation_context,
        strict=strict,
    )
@classmethod
def parse_functions(
    cls: Type[BaseModel],
    completion: ChatCompletion,
    validation_context: Optional[Dict[str, Any]] = None,
    strict: Optional[bool] = None,
) -> BaseModel:
    """Validate the function-call arguments of the first choice.

    Reads ``completion.choices[0].message.function_call``, checks that its
    name equals ``cls.openai_schema["name"]`` and validates its ``arguments``
    string with ``cls.model_validate_json``.

    Args:
        completion: Chat completion whose first choice carries the call.
        validation_context: Forwarded as ``context`` to the validator.
        strict: Forwarded as ``strict`` to the validator.

    Returns:
        The validated model instance.

    Raises:
        AssertionError: If the function name differs from the schema name.
    """
    called = completion.choices[0].message.function_call
    assert called.name == cls.openai_schema["name"], "Function name does not match"  # type: ignore[index]
    raw_arguments = called.arguments  # type: ignore[attr-defined]
    return cls.model_validate_json(
        raw_arguments,
        context=validation_context,
        strict=strict,
    )
@classmethod
def parse_tools(
    cls: Type[BaseModel],
    completion: ChatCompletion,
    validation_context: Optional[Dict[str, Any]] = None,
    strict: Optional[bool] = None,
) -> BaseModel:
    """Validate the arguments of the single tool call in the first choice.

    Requires exactly one entry in ``message.tool_calls`` (a missing or empty
    list fails the same check), requires that entry's function name to equal
    ``cls.openai_schema["name"]``, then validates the function's ``arguments``
    string with ``cls.model_validate_json``.

    Args:
        completion: Chat completion whose first choice carries the tool call.
        validation_context: Forwarded as ``context`` to the validator.
        strict: Forwarded as ``strict`` to the validator.

    Returns:
        The validated model instance.

    Raises:
        AssertionError: If the number of tool calls is not exactly one, or
            if the tool's function name differs from the schema name.
    """
    reply = completion.choices[0].message
    assert len(reply.tool_calls or []) == 1, (
        "Instructor does not support multiple tool calls, use List[Model] instead."
    )

    invoked = reply.tool_calls[0].function  # type: ignore
    assert invoked.name == cls.openai_schema["name"], "Tool name does not match"  # type: ignore[index]

    return cls.model_validate_json(
        invoked.arguments,
        context=validation_context,
        strict=strict,
    )
@classmethod
def parse_json(
    cls: Type[BaseModel],
    completion: ChatCompletion,
    validation_context: Optional[Dict[str, Any]] = None,
    strict: Optional[bool] = None,
) -> BaseModel:
    """Validate the text content of the first choice as JSON for this model.

    The message content (or ``""`` when it is falsy) is always passed through
    ``extract_json_from_codeblock`` before being validated with
    ``cls.model_validate_json``.

    Args:
        completion: Chat completion whose first choice holds the content.
        validation_context: Forwarded as ``context`` to the validator.
        strict: Forwarded as ``strict`` to the validator.

    Returns:
        The validated model instance.
    """
    raw_content = completion.choices[0].message.content or ""
    json_payload = extract_json_from_codeblock(raw_content)
    return cls.model_validate_json(
        json_payload, context=validation_context, strict=strict
    )
def openai_schema(cls: Type[BaseModel]) -> OpenAISchema:
+1
View File
@@ -2,5 +2,6 @@ import instructor
# Model names shared by this package; presumably the models the tests are
# parametrized over — TODO confirm against the test modules that import it.
models = ["gpt-4-turbo-preview"]
# instructor modes listed alongside the models above (tool calling and
# Markdown-wrapped JSON).
modes = [
instructor.Mode.TOOLS,
instructor.Mode.MD_JSON,
]
+2 -2
View File
@@ -117,8 +117,8 @@ def test_complete_output_no_exception(
[{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}],
indirect=True,
) # type: ignore[misc]
async def test_incomplete_output_exception_raise(
def test_incomplete_output_exception_raise(
test_model: Type[OpenAISchema], mock_completion: ChatCompletion
) -> None:
with pytest.raises(IncompleteOutputException):
await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)
test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)