From 7bfe767ff4cdc52ab4eb517f6505494fb4c0179d Mon Sep 17 00:00:00 2001
From: Jason Liu
Date: Tue, 20 Jun 2023 21:17:04 +0900
Subject: [PATCH] Query planner prototype (#3)

---
 question_answer_subquery.py | 158 ++++++++++++++++++++++++++++--------
 1 file changed, 127 insertions(+), 31 deletions(-)

diff --git a/question_answer_subquery.py b/question_answer_subquery.py
index a87da70..76c0dbc 100644
--- a/question_answer_subquery.py
+++ b/question_answer_subquery.py
@@ -1,11 +1,14 @@
+from functools import lru_cache
 import openai
 import enum
 import json
+import asyncio
 
 from pydantic import Field
 from typing import List, Tuple
 from openai_function_call import OpenAISchema
 from tenacity import retry, stop_after_attempt
+from pprint import pprint
 
 
 class QueryType(str, enum.Enum):
@@ -18,6 +21,24 @@ class QueryType(str, enum.Enum):
     MERGE_MULTIPLE_RESPONSES = "MERGE_MULTIPLE_RESPONSES"
 
 
+class ComputeQuery(OpenAISchema):
+    """
+    Models a computation of a query, assume this can be some RAG system like llamaindex
+    """
+
+    query: str
+    response: str = "..."
+
+
+class MergedResponses(OpenAISchema):
+    """
+    Models a merged response of multiple queries.
+    Currently we just concatenate them but we can do much more complex things.
+    """
+
+    responses: List[ComputeQuery]
+
+
 class Query(OpenAISchema):
     """
     Class representing a single question in a question answer subquery.
@@ -38,6 +59,38 @@ class Query(OpenAISchema):
         description="Type of question we are asking, either a single question or a multi question merge when there are multiple questions",
     )
 
+    async def execute(self, dependency_func):
+        """
+        Resolve this query into a `ComputeQuery`.
+
+        A SINGLE_QUESTION is computed directly. Otherwise the queries
+        returned by `dependency_func(self.dependancies)` are executed
+        concurrently and their results are appended as context.
+        """
+        print("Executing", f"`{self.question}`")
+        print("Executing with", len(self.dependancies), "dependancies")
+
+        if self.node_type == QueryType.SINGLE_QUESTION:
+            resp = ComputeQuery(
+                query=self.question,
+            )
+            await asyncio.sleep(1)
+            pprint(resp.dict())
+            return resp
+
+        sub_queries = dependency_func(self.dependancies)
+        computed_queries = await asyncio.gather(
+            *[q.execute(dependency_func=dependency_func) for q in sub_queries]
+        )
+        sub_answers = MergedResponses(responses=computed_queries)
+        merged_query = f"{self.question}\nContext: {sub_answers.json()}"
+        resp = ComputeQuery(
+            query=merged_query,
+        )
+        await asyncio.sleep(2)
+        pprint(resp.dict())
+        return resp
+
 
 class QueryPlan(OpenAISchema):
     """
@@ -49,12 +102,28 @@ class QueryPlan(OpenAISchema):
         ..., description="The original question we are asking"
     )
 
+    async def execute(self):
+        """
+        Execute the plan starting from the last query in `query_graph`,
+        which is treated as the original question.
+        """
+        # this should be done with a topological sort, but this is easier to understand
+        original_question = self.query_graph[-1]
+        print(f"Executing query plan from `{original_question.question}`")
+        return await original_question.execute(dependency_func=self.dependencies)
+
+    def dependencies(self, idz: List[int]) -> List[Query]:
+        """
+        Returns the queries in the query graph whose id is in `idz`.
+        """
+        return [q for q in self.query_graph if q.id in idz]
+
 
 Query.update_forward_refs()
 QueryPlan.update_forward_refs()
 
 
-def query_planner(question: str) -> QueryPlan:
+def query_planner(question: str, plan=False) -> QueryPlan:
     PLANNING_MODEL = "gpt-4"
     ANSWERING_MODEL = "gpt-3.5-turbo-0613"
 
@@ -65,31 +134,32 @@ def query_planner(question: str) -> QueryPlan:
         },
         {
             "role": "user",
-            "content": f"Consider: {question}\n Before you call the function, think step by step to get a correct query plan.",
-        },
-        {
-            "role": "assistant",
-            "content": "Lets think step by step to find the correct query plan that does not make any assuptions of what is known.",
+            "content": f"Consider: {question}\nGenerate the correct query plan.",
         },
     ]
 
-    completion = openai.ChatCompletion.create(
-        model=PLANNING_MODEL,
-        temperature=0,
-        messages=messages,
-        max_tokens=1000,
-    )
+    if plan:
+        messages.append(
+            {
+                "role": "assistant",
+                "content": "Lets think step by step to find the correct query plan that does not make any assuptions of what is known.",
+            },
+        )
+        completion = openai.ChatCompletion.create(
+            model=PLANNING_MODEL,
+            temperature=0,
+            messages=messages,
+            max_tokens=1000,
+        )
 
-    messages.append(completion.choices[0].message)
+        messages.append(completion.choices[0].message)
 
-    print(messages[-1])
-
-    messages.append(
-        {
-            "role": "user",
-            "content": "Using that information produce the complete and correct query plan.",
-        }
-    )
+        messages.append(
+            {
+                "role": "user",
+                "content": "Using that information produce the complete and correct query plan.",
+            }
+        )
 
     completion = openai.ChatCompletion.create(
         model=ANSWERING_MODEL,
@@ -107,17 +177,43 @@ if __name__ == "__main__":
     from pprint import pprint
 
     plan = query_planner(
-        "What is the difference in populations of Canada and the Jason's home country?"
+        "What is the difference in populations of Canada and the Jason's home country?",
+        plan=False,
     )
     pprint(plan.dict())
     """
-    {'question': {'dependancies': [{'dependancies': [],
-                                    'node_type': ,
-                                    'question': 'What is the capital of Canada?'},
-                                   {'dependancies': [],
-                                    'node_type': ,
-                                    'question': "What is Jason's home country?"}],
-                  'node_type': ,
-                  'question': "What is of Canada and the Jason's "
-                              'home country?'}}
+    {'query_graph': [{'dependancies': [],
+                      'id': 1,
+                      'node_type': ,
+                      'question': "Identify Jason's home country"},
+                     {'dependancies': [],
+                      'id': 2,
+                      'node_type': ,
+                      'question': 'Find the population of Canada'},
+                     {'dependancies': [1],
+                      'id': 3,
+                      'node_type': ,
+                      'question': "Find the population of Jason's home country"},
+                     {'dependancies': [2, 3],
+                      'id': 4,
+                      'node_type': ,
+                      'question': 'Calculate the difference in populations between '
+                                  "Canada and Jason's home country"}]}
+    """
+
+    asyncio.run(plan.execute())
+    """
+    Executing query plan from `What is the difference in populations of Canada and Jason's home country?`
+    Executing `What is the difference in populations of Canada and Jason's home country?`
+    Executing with 2 dependancies
+    Executing `What is the population of Canada?`
+    Executing `What is the population of Jason's home country?`
+    {'query': 'What is the population of Canada?', 'response': '...'}
+    {'query': "What is the population of Jason's home country?", 'response': '...'}
+    {'query': "What is the difference in populations of Canada and Jason's home "
+              'country?'
+              'Context: {"responses": [{"query": "What is the population of '
+              'Canada?", "response": "..."}, {"query": "What is the population of '
+              'Jason's home country?", "response": "..."}]}',
+     'response': '...'}
     """