From b06eb4faa32e7b5e0f3628a611490efd042868fa Mon Sep 17 00:00:00 2001 From: Jason Liu Date: Fri, 29 Mar 2024 08:14:26 -0400 Subject: [PATCH] refactor(OpenAISchema): improve `from_response` readability and update tests (#543) --- instructor/function_calls.py | 146 ++++++++++++++++++++++------------ tests/llm/test_openai/util.py | 1 + tests/test_function_calls.py | 4 +- 3 files changed, 97 insertions(+), 54 deletions(-) diff --git a/instructor/function_calls.py b/instructor/function_calls.py index ca77bb6..5b4c893 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -4,9 +4,11 @@ from docstring_parser import parse from functools import wraps from pydantic import BaseModel, create_model from openai.types.chat import ChatCompletion -from instructor.exceptions import IncompleteOutputException +from typing import Any, Dict, Optional, Type from instructor.mode import Mode from instructor.utils import extract_json_from_codeblock +from instructor.exceptions import IncompleteOutputException +from instructor.mode import Mode import logging @@ -90,66 +92,106 @@ class OpenAISchema(BaseModel): # type: ignore[misc] cls (OpenAISchema): An instance of the class """ if mode == Mode.ANTHROPIC_TOOLS: - try: - from instructor.anthropic_utils import extract_xml, xml_to_model - except ImportError as err: - raise ImportError( - "Please 'pip install anthropic' package to proceed." - ) from err - assert hasattr(completion, "content") - return xml_to_model(cls, extract_xml(completion.content[0].text)) # type:ignore + return cls.parse_anthropic_tools(completion) if mode == Mode.ANTHROPIC_JSON: - assert hasattr(completion, "content") - text = completion.content[0].text # type: ignore - extra_text = extract_json_from_codeblock(text) - return cls.model_validate_json(extra_text) - - assert hasattr(completion, "choices"), "No choices in completion" + return cls.parse_anthropic_json(completion, validation_context, strict) if completion.choices[0].finish_reason == "length": - logger.error("Incomplete output detected, should increase max_tokens") raise IncompleteOutputException() - # If Anthropic, this should be different - message = completion.choices[0].message - if mode == Mode.FUNCTIONS: - assert ( - message.function_call.name == cls.openai_schema["name"] # type: ignore[index] - ), "Function name does not match" - model_response = cls.model_validate_json( - message.function_call.arguments, # type: ignore[attr-defined] - context=validation_context, - strict=strict, - ) - elif mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}: - assert ( - len(message.tool_calls or []) == 1 - ), "Instructor does not support multiple tool calls, use List[Model] instead." - tool_call = message.tool_calls[0] # type: ignore - assert ( - tool_call.function.name == cls.openai_schema["name"] # type: ignore[index] - ), "Tool name does not match" - model_response = cls.model_validate_json( - tool_call.function.arguments, - context=validation_context, - strict=strict, - ) - elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}: - if mode == Mode.MD_JSON: - message.content = extract_json_from_codeblock(message.content or "") + return cls.parse_functions(completion, validation_context, strict) - model_response = cls.model_validate_json( - message.content, # type: ignore - context=validation_context, - strict=strict, - ) - else: - raise ValueError(f"Invalid patch mode: {mode}") + if mode in {Mode.TOOLS, Mode.MISTRAL_TOOLS}: + return cls.parse_tools(completion, validation_context, strict) - # TODO: add logging or response handler - return model_response + if mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}: + return cls.parse_json(completion, validation_context, strict) + + raise ValueError(f"Invalid patch mode: {mode}") + + @classmethod + def parse_anthropic_tools( + cls: Type[BaseModel], + completion: ChatCompletion, + ) -> BaseModel: + try: + from instructor.anthropic_utils import extract_xml, xml_to_model + except ImportError as err: + raise ImportError( + "Please 'pip install anthropic xmltodict' package to proceed." + ) from err + assert hasattr(completion, "content") + return xml_to_model(cls, extract_xml(completion.content[0].text)) # type:ignore + + @classmethod + def parse_anthropic_json( + cls: Type[BaseModel], + completion: ChatCompletion, + validation_context: Optional[Dict[str, Any]] = None, + strict: Optional[bool] = None, + ) -> BaseModel: + assert hasattr(completion, "content") + text = completion.content[0].text # type: ignore + extra_text = extract_json_from_codeblock(text) + return cls.model_validate_json( + extra_text, context=validation_context, strict=strict + ) + + @classmethod + def parse_functions( + cls: Type[BaseModel], + completion: ChatCompletion, + validation_context: Optional[Dict[str, Any]] = None, + strict: Optional[bool] = None, + ) -> BaseModel: + message = completion.choices[0].message + assert ( + message.function_call.name == cls.openai_schema["name"] # type: ignore[index] + ), "Function name does not match" + return cls.model_validate_json( + message.function_call.arguments, # type: ignore[attr-defined] + context=validation_context, + strict=strict, + ) + + @classmethod + def parse_tools( + cls: Type[BaseModel], + completion: ChatCompletion, + validation_context: Optional[Dict[str, Any]] = None, + strict: Optional[bool] = None, + ) -> BaseModel: + message = completion.choices[0].message + assert ( + len(message.tool_calls or []) == 1 + ), "Instructor does not support multiple tool calls, use List[Model] instead." + tool_call = message.tool_calls[0] # type: ignore + assert ( + tool_call.function.name == cls.openai_schema["name"] # type: ignore[index] + ), "Tool name does not match" + return cls.model_validate_json( + tool_call.function.arguments, + context=validation_context, + strict=strict, + ) + + @classmethod + def parse_json( + cls: Type[BaseModel], + completion: ChatCompletion, + validation_context: Optional[Dict[str, Any]] = None, + strict: Optional[bool] = None, + ) -> BaseModel: + message = completion.choices[0].message.content or "" + message = extract_json_from_codeblock(message) + + return cls.model_validate_json( + message, + context=validation_context, + strict=strict, + ) def openai_schema(cls: Type[BaseModel]) -> OpenAISchema: diff --git a/tests/llm/test_openai/util.py b/tests/llm/test_openai/util.py index 8bc658d..244382b 100644 --- a/tests/llm/test_openai/util.py +++ b/tests/llm/test_openai/util.py @@ -2,5 +2,6 @@ import instructor models = ["gpt-4-turbo-preview"] modes = [ + instructor.Mode.TOOLS, instructor.Mode.MD_JSON, ] diff --git a/tests/test_function_calls.py b/tests/test_function_calls.py index c4fb9ea..e35c093 100644 --- a/tests/test_function_calls.py +++ b/tests/test_function_calls.py @@ -117,8 +117,8 @@ def test_complete_output_no_exception( [{"finish_reason": "length", "data_content": '{\n"data": "incomplete dat"\n}'}], indirect=True, ) # type: ignore[misc] -async def test_incomplete_output_exception_raise( +def test_incomplete_output_exception_raise( test_model: Type[OpenAISchema], mock_completion: ChatCompletion ) -> None: with pytest.raises(IncompleteOutputException): - await test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS) + test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)