diff --git a/instructor/function_calls.py b/instructor/function_calls.py index bb3e5dd..1826707 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -11,7 +11,6 @@ class Mode(enum.Enum): FUNCTIONS: str = "function_call" TOOLS: str = "tool_call" JSON: str = "json_mode" - JSON_SCHEMA: str = "json_schema_mode" MD_JSON: str = "markdown_json_mode" @@ -148,7 +147,13 @@ class OpenAISchema(BaseModel): context=validation_context, strict=strict, ) - elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}: + elif mode == Mode.JSON: + return cls.model_validate_json( + message.content, + context=validation_context, + strict=strict, + ) + elif mode == Mode.MD_JSON: return cls.model_validate_json( message.content, context=validation_context, diff --git a/instructor/patch.py b/instructor/patch.py index 2c4db14..418c6c1 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -91,9 +91,7 @@ def handle_response_model( "type": "function", "function": {"name": response_model.openai_schema["name"]}, } - elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}: - # If its a JSON Mode we need to massage the prompt a bit - # in order to get the response we want in a json format + elif mode == Mode.JSON or mode == Mode.MD_JSON: if mode == Mode.JSON: new_kwargs["response_format"] = {"type": "json_object"} message = f"""Make sure that your response to any message matches the json_schema below, @@ -103,16 +101,10 @@ def handle_response_model( if "$defs" in response_model.model_json_schema(): message += f"\nHere are some more definitions to adhere too:\n{response_model.model_json_schema()['$defs']}" - elif mode == Mode.JSON_SCHEMA: - new_kwargs["response_format"] = { - "type": "json_schema", - "schema": response_model.model_json_schema(), - } - - elif mode == Mode.MD_JSON: + else: message = f""" As a genius expert, your task is to understand the content and provide - the parsed objects in json that match the following json_schema:\n + the parsed objects in json that match the following json_schema (do not deviate at all and its okay if you cant be exact):\n {response_model.model_json_schema()['properties']} """ # Check for nested models @@ -126,11 +118,9 @@ def handle_response_model( }, ) new_kwargs["stop"] = "```" - else: - raise ValueError(f"Invalid patch mode: {mode}") # check that the first message is a system message # if it is not, add a system message to the beginning - if new_kwargs["messages"][0]["role"] != "system" and len(message) > 0: + if new_kwargs["messages"][0]["role"] != "system": new_kwargs["messages"].insert( 0, {