From aeec16448294422cfe9e0d424adb755c1665a242 Mon Sep 17 00:00:00 2001 From: Jason Liu Date: Tue, 12 Dec 2023 23:11:57 -0500 Subject: [PATCH] Add new Mode for Anyscale's json schema schema type (#273) --- instructor/function_calls.py | 9 ++------- instructor/patch.py | 18 ++++++++++++++---- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/instructor/function_calls.py b/instructor/function_calls.py index 1826707..bb3e5dd 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -11,6 +11,7 @@ class Mode(enum.Enum): FUNCTIONS: str = "function_call" TOOLS: str = "tool_call" JSON: str = "json_mode" + JSON_SCHEMA: str = "json_schema_mode" MD_JSON: str = "markdown_json_mode" @@ -147,13 +148,7 @@ class OpenAISchema(BaseModel): context=validation_context, strict=strict, ) - elif mode == Mode.JSON: - return cls.model_validate_json( - message.content, - context=validation_context, - strict=strict, - ) - elif mode == Mode.MD_JSON: + elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}: return cls.model_validate_json( message.content, context=validation_context, diff --git a/instructor/patch.py b/instructor/patch.py index 418c6c1..2c4db14 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -91,7 +91,9 @@ def handle_response_model( "type": "function", "function": {"name": response_model.openai_schema["name"]}, } - elif mode == Mode.JSON or mode == Mode.MD_JSON: + elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}: + # If its a JSON Mode we need to massage the prompt a bit + # in order to get the response we want in a json format if mode == Mode.JSON: new_kwargs["response_format"] = {"type": "json_object"} message = f"""Make sure that your response to any message matches the json_schema below, @@ -101,10 +103,16 @@ def handle_response_model( if "$defs" in response_model.model_json_schema(): message += f"\nHere are some more definitions to adhere too:\n{response_model.model_json_schema()['$defs']}" - else: + elif mode == Mode.JSON_SCHEMA: + new_kwargs["response_format"] = { + "type": "json_schema", + "schema": response_model.model_json_schema(), + } + + elif mode == Mode.MD_JSON: message = f""" As a genius expert, your task is to understand the content and provide - the parsed objects in json that match the following json_schema (do not deviate at all and its okay if you cant be exact):\n + the parsed objects in json that match the following json_schema:\n {response_model.model_json_schema()['properties']} """ # Check for nested models @@ -118,9 +126,11 @@ def handle_response_model( }, ) new_kwargs["stop"] = "```" + else: + raise ValueError(f"Invalid patch mode: {mode}") # check that the first message is a system message # if it is not, add a system message to the beginning - if new_kwargs["messages"][0]["role"] != "system": + if new_kwargs["messages"][0]["role"] != "system" and len(message) > 0: new_kwargs["messages"].insert( 0, {