Add new Mode for Anyscale's json schema schema type (#273)

This commit is contained in:
Jason Liu
2023-12-12 23:11:57 -05:00
committed by GitHub
parent 329aad023b
commit aeec164482
2 changed files with 16 additions and 11 deletions
+2 -7
View File
@@ -11,6 +11,7 @@ class Mode(enum.Enum):
FUNCTIONS: str = "function_call"
TOOLS: str = "tool_call"
JSON: str = "json_mode"
JSON_SCHEMA: str = "json_schema_mode"
MD_JSON: str = "markdown_json_mode"
@@ -147,13 +148,7 @@ class OpenAISchema(BaseModel):
context=validation_context,
strict=strict,
)
elif mode == Mode.JSON:
return cls.model_validate_json(
message.content,
context=validation_context,
strict=strict,
)
elif mode == Mode.MD_JSON:
elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
return cls.model_validate_json(
message.content,
context=validation_context,
+14 -4
View File
@@ -91,7 +91,9 @@ def handle_response_model(
"type": "function",
"function": {"name": response_model.openai_schema["name"]},
}
elif mode == Mode.JSON or mode == Mode.MD_JSON:
elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}:
# If its a JSON Mode we need to massage the prompt a bit
# in order to get the response we want in a json format
if mode == Mode.JSON:
new_kwargs["response_format"] = {"type": "json_object"}
message = f"""Make sure that your response to any message matches the json_schema below,
@@ -101,10 +103,16 @@ def handle_response_model(
if "$defs" in response_model.model_json_schema():
message += f"\nHere are some more definitions to adhere too:\n{response_model.model_json_schema()['$defs']}"
else:
elif mode == Mode.JSON_SCHEMA:
new_kwargs["response_format"] = {
"type": "json_schema",
"schema": response_model.model_json_schema(),
}
elif mode == Mode.MD_JSON:
message = f"""
As a genius expert, your task is to understand the content and provide
the parsed objects in json that match the following json_schema (do not deviate at all and its okay if you cant be exact):\n
the parsed objects in json that match the following json_schema:\n
{response_model.model_json_schema()['properties']}
"""
# Check for nested models
@@ -118,9 +126,11 @@ def handle_response_model(
},
)
new_kwargs["stop"] = "```"
else:
raise ValueError(f"Invalid patch mode: {mode}")
# check that the first message is a system message
# if it is not, add a system message to the beginning
if new_kwargs["messages"][0]["role"] != "system":
if new_kwargs["messages"][0]["role"] != "system" and len(message) > 0:
new_kwargs["messages"].insert(
0,
{