Query planner prototype (#3)

This commit is contained in:
Jason Liu
2023-06-20 21:17:04 +09:00
committed by GitHub
parent 92be0d53f2
commit 7bfe767ff4
+116 -31
View File
@@ -1,11 +1,14 @@
from functools import lru_cache
import openai
import enum
import json
import asyncio
from pydantic import Field
from typing import List, Tuple
from openai_function_call import OpenAISchema
from tenacity import retry, stop_after_attempt
from pprint import pprint
class QueryType(str, enum.Enum):
@@ -18,6 +21,24 @@ class QueryType(str, enum.Enum):
MERGE_MULTIPLE_RESPONSES = "MERGE_MULTIPLE_RESPONSES"
class ComputeQuery(OpenAISchema):
"""
Models a computation of a query, assume this can be some RAG system like llamaindex
"""
query: str
response: str = "..."
class MergedResponses(OpenAISchema):
"""
Models a merged response of multiple queries.
Currently we just concatinate them but we can do much more complex things.
"""
responses: List[ComputeQuery]
class Query(OpenAISchema):
"""
Class representing a single question in a question answer subquery.
@@ -38,6 +59,31 @@ class Query(OpenAISchema):
description="Type of question we are asking, either a single question or a multi question merge when there are multiple questions",
)
async def execute(self, dependency_func):
print("Executing", f"`self.question`")
print("Executing with", len(self.dependancies), "dependancies")
if self.node_type == QueryType.SINGLE_QUESTION:
resp = ComputeQuery(
query=self.question,
)
await asyncio.sleep(1)
pprint(resp.dict())
return resp
sub_queries = dependency_func(self.dependancies)
computed_queries = await asyncio.gather(
*[q.execute(dependency_func=dependency_func) for q in sub_queries]
)
sub_answers = MergedResponses(responses=computed_queries)
merged_query = f"{self.question}\nContext: {sub_answers.json()}"
resp = ComputeQuery(
query=merged_query,
)
await asyncio.sleep(2)
pprint(resp.dict())
return resp
class QueryPlan(OpenAISchema):
"""
@@ -49,12 +95,24 @@ class QueryPlan(OpenAISchema):
..., description="The original question we are asking"
)
async def execute(self):
# this should be done with a topological sort, but this is easier to understand
original_question = self.query_graph[-1]
print(f"Executing query plan from `{original_question.question}`")
return await original_question.execute(dependency_func=self.dependencies)
def dependencies(self, idz: List[int]) -> List[Query]:
"""
Returns the dependencies of the query with the given id.
"""
return [q for q in self.query_graph if q.id in idz]
Query.update_forward_refs()
QueryPlan.update_forward_refs()
def query_planner(question: str) -> QueryPlan:
def query_planner(question: str, plan=False) -> QueryPlan:
PLANNING_MODEL = "gpt-4"
ANSWERING_MODEL = "gpt-3.5-turbo-0613"
@@ -65,31 +123,32 @@ def query_planner(question: str) -> QueryPlan:
},
{
"role": "user",
"content": f"Consider: {question}\n Before you call the function, think step by step to get a correct query plan.",
},
{
"role": "assistant",
"content": "Lets think step by step to find the correct query plan that does not make any assuptions of what is known.",
"content": f"Consider: {question}\nGenerate the correct query plan.",
},
]
completion = openai.ChatCompletion.create(
model=PLANNING_MODEL,
temperature=0,
messages=messages,
max_tokens=1000,
)
if plan:
messages.append(
{
"role": "assistant",
"content": "Lets think step by step to find the correct query plan that does not make any assuptions of what is known.",
},
)
completion = openai.ChatCompletion.create(
model=PLANNING_MODEL,
temperature=0,
messages=messages,
max_tokens=1000,
)
messages.append(completion.choices[0].message)
messages.append(completion.choices[0].message)
print(messages[-1])
messages.append(
{
"role": "user",
"content": "Using that information produce the complete and correct query plan.",
}
)
messages.append(
{
"role": "user",
"content": "Using that information produce the complete and correct query plan.",
}
)
completion = openai.ChatCompletion.create(
model=ANSWERING_MODEL,
@@ -107,17 +166,43 @@ if __name__ == "__main__":
from pprint import pprint
plan = query_planner(
"What is the difference in populations of Canada and the Jason's home country?"
"What is the difference in populations of Canada and the Jason's home country?",
plan=False,
)
pprint(plan.dict())
"""
{'question': {'dependancies': [{'dependancies': [],
'node_type': <QueryType.SINGLE_QUESTION: 'SINGLE'>,
'question': 'What is the capital of Canada?'},
{'dependancies': [],
'node_type': <QueryType.SINGLE_QUESTION: 'SINGLE'>,
'question': "What is Jason's home country?"}],
'node_type': <QueryType.MERGE_MULTIPLE_RESPONSES: 'MERGE_MULTIPLE_RESPONSES'>,
'question': "What is of Canada and the Jason's "
'home country?'}}
{'query_graph': [{'dependancies': [],
'id': 1,
'node_type': <QueryType.SINGLE_QUESTION: 'SINGLE'>,
'question': "Identify Jason's home country"},
{'dependancies': [],
'id': 2,
'node_type': <QueryType.SINGLE_QUESTION: 'SINGLE'>,
'question': 'Find the population of Canada'},
{'dependancies': [1],
'id': 3,
'node_type': <QueryType.SINGLE_QUESTION: 'SINGLE'>,
'question': "Find the population of Jason's home country"},
{'dependancies': [2, 3],
'id': 4,
'node_type': <QueryType.SINGLE_QUESTION: 'SINGLE'>,
'question': 'Calculate the difference in populations between '
"Canada and Jason's home country"}]}
"""
asyncio.run(plan.execute())
"""
Executing query plan from `What is the difference in populations of Canada and Jason's home country?`
Executing `What is the difference in populations of Canada and Jason's home country?`
Executing with 2 dependancies
Executing `What is the population of Canada?`
Executing `What is the population of Jason's home country?`
{'query': 'What is the population of Canada?', 'response': '...'}
{'query': "What is the population of Jason's home country?", 'response': '...'}
{'query': "What is the difference in populations of Canada and Jason's home "
'country?'
'Context: {"responses": [{"query": "What is the population of '
'Canada?", "response": "..."}, {"query": "What is the population of '
'Jason's home country?", "response": "..."}]}',
'response': '...'}
"""