mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Query planner prototype (#3)
This commit is contained in:
+116
-31
@@ -1,11 +1,14 @@
|
||||
from functools import lru_cache
|
||||
import openai
|
||||
import enum
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from pydantic import Field
|
||||
from typing import List, Tuple
|
||||
from openai_function_call import OpenAISchema
|
||||
from tenacity import retry, stop_after_attempt
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
class QueryType(str, enum.Enum):
|
||||
@@ -18,6 +21,24 @@ class QueryType(str, enum.Enum):
|
||||
MERGE_MULTIPLE_RESPONSES = "MERGE_MULTIPLE_RESPONSES"
|
||||
|
||||
|
||||
class ComputeQuery(OpenAISchema):
    """
    Models the computation of a single query.

    This is a stand-in for whatever system would really answer the query,
    for example a RAG system such as llamaindex.
    """

    # Text of the query being computed.
    query: str
    # Answer to the query; defaults to the literal placeholder string "...".
    response: str = "..."
|
||||
|
||||
|
||||
class MergedResponses(OpenAISchema):
    """
    Models the merged response of multiple queries.

    Currently we just concatenate them, but much more complex merging
    strategies are possible.
    """

    # One computed query per sub-question that is being merged.
    responses: List[ComputeQuery]
|
||||
|
||||
|
||||
class Query(OpenAISchema):
|
||||
"""
|
||||
Class representing a single question in a question answer subquery.
|
||||
@@ -38,6 +59,31 @@ class Query(OpenAISchema):
|
||||
description="Type of question we are asking, either a single question or a multi question merge when there are multiple questions",
|
||||
)
|
||||
|
||||
async def execute(self, dependency_func):
    """
    Execute this query node and return its computed result.

    A ``SINGLE_QUESTION`` node is computed directly from its own question.
    Any other node first resolves its dependency ids through
    ``dependency_func``, executes those sub-queries concurrently, and then
    appends their JSON-serialised answers to its own question as context.

    Args:
        dependency_func: Callable that receives ``self.dependancies`` and
            returns the query objects to run first; each returned object
            must provide an ``execute(dependency_func=...)`` coroutine.

    Returns:
        ComputeQuery: The computed query. Only ``query`` is set here, so
        ``response`` keeps its default value.
    """
    # The braces were missing here, so the literal text `self.question`
    # was printed instead of the actual question.
    print("Executing", f"`{self.question}`")
    print("Executing with", len(self.dependancies), "dependancies")

    if self.node_type == QueryType.SINGLE_QUESTION:
        resp = ComputeQuery(
            query=self.question,
        )
        # Placeholder delay standing in for the real computation.
        await asyncio.sleep(1)
        pprint(resp.dict())
        return resp

    sub_queries = dependency_func(self.dependancies)
    # Run every dependency concurrently; gather keeps the input order.
    computed_queries = await asyncio.gather(
        *[q.execute(dependency_func=dependency_func) for q in sub_queries]
    )
    sub_answers = MergedResponses(responses=computed_queries)
    # The sub-answers are passed along as JSON context after the question.
    merged_query = f"{self.question}\nContext: {sub_answers.json()}"
    resp = ComputeQuery(
        query=merged_query,
    )
    # Placeholder delay standing in for the real merge computation.
    await asyncio.sleep(2)
    pprint(resp.dict())
    return resp
|
||||
|
||||
|
||||
class QueryPlan(OpenAISchema):
|
||||
"""
|
||||
@@ -49,12 +95,24 @@ class QueryPlan(OpenAISchema):
|
||||
..., description="The original question we are asking"
|
||||
)
|
||||
|
||||
async def execute(self):
    """
    Execute the plan, starting from the last entry of ``query_graph``.

    The last node is treated as the original question. It is given
    ``self.dependencies`` so it can look up the queries it depends on.

    Returns:
        Whatever the last node's ``execute`` coroutine returns.
    """
    # A topological sort would be the proper approach; starting from the
    # final node and resolving dependencies on demand is easier to follow.
    root_query = self.query_graph[-1]
    print(f"Executing query plan from `{root_query.question}`")
    result = await root_query.execute(dependency_func=self.dependencies)
    return result
|
||||
|
||||
def dependencies(self, idz: List[int]) -> List[Query]:
    """
    Returns the queries in the plan whose id is one of the given ids.

    Args:
        idz: Ids of the queries to look up.

    Returns:
        The matching queries, in the order they appear in ``query_graph``.
        Ids that match no query are ignored.
    """
    # Build the lookup set once instead of rescanning the id list per node.
    wanted = set(idz)
    return [q for q in self.query_graph if q.id in wanted]
|
||||
|
||||
|
||||
Query.update_forward_refs()
|
||||
QueryPlan.update_forward_refs()
|
||||
|
||||
|
||||
def query_planner(question: str) -> QueryPlan:
|
||||
def query_planner(question: str, plan=False) -> QueryPlan:
|
||||
PLANNING_MODEL = "gpt-4"
|
||||
ANSWERING_MODEL = "gpt-3.5-turbo-0613"
|
||||
|
||||
@@ -65,31 +123,32 @@ def query_planner(question: str) -> QueryPlan:
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Consider: {question}\n Before you call the function, think step by step to get a correct query plan.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Lets think step by step to find the correct query plan that does not make any assuptions of what is known.",
|
||||
"content": f"Consider: {question}\nGenerate the correct query plan.",
|
||||
},
|
||||
]
|
||||
|
||||
completion = openai.ChatCompletion.create(
|
||||
model=PLANNING_MODEL,
|
||||
temperature=0,
|
||||
messages=messages,
|
||||
max_tokens=1000,
|
||||
)
|
||||
if plan:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Lets think step by step to find the correct query plan that does not make any assuptions of what is known.",
|
||||
},
|
||||
)
|
||||
completion = openai.ChatCompletion.create(
|
||||
model=PLANNING_MODEL,
|
||||
temperature=0,
|
||||
messages=messages,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
messages.append(completion.choices[0].message)
|
||||
messages.append(completion.choices[0].message)
|
||||
|
||||
print(messages[-1])
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Using that information produce the complete and correct query plan.",
|
||||
}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Using that information produce the complete and correct query plan.",
|
||||
}
|
||||
)
|
||||
|
||||
completion = openai.ChatCompletion.create(
|
||||
model=ANSWERING_MODEL,
|
||||
@@ -107,17 +166,43 @@ if __name__ == "__main__":
|
||||
from pprint import pprint
|
||||
|
||||
plan = query_planner(
|
||||
"What is the difference in populations of Canada and the Jason's home country?"
|
||||
"What is the difference in populations of Canada and the Jason's home country?",
|
||||
plan=False,
|
||||
)
|
||||
pprint(plan.dict())
|
||||
"""
|
||||
{'question': {'dependancies': [{'dependancies': [],
|
||||
'node_type': <QueryType.SINGLE_QUESTION: 'SINGLE'>,
|
||||
'question': 'What is the capital of Canada?'},
|
||||
{'dependancies': [],
|
||||
'node_type': <QueryType.SINGLE_QUESTION: 'SINGLE'>,
|
||||
'question': "What is Jason's home country?"}],
|
||||
'node_type': <QueryType.MERGE_MULTIPLE_RESPONSES: 'MERGE_MULTIPLE_RESPONSES'>,
|
||||
'question': "What is of Canada and the Jason's "
|
||||
'home country?'}}
|
||||
{'query_graph': [{'dependancies': [],
|
||||
'id': 1,
|
||||
'node_type': <QueryType.SINGLE_QUESTION: 'SINGLE'>,
|
||||
'question': "Identify Jason's home country"},
|
||||
{'dependancies': [],
|
||||
'id': 2,
|
||||
'node_type': <QueryType.SINGLE_QUESTION: 'SINGLE'>,
|
||||
'question': 'Find the population of Canada'},
|
||||
{'dependancies': [1],
|
||||
'id': 3,
|
||||
'node_type': <QueryType.SINGLE_QUESTION: 'SINGLE'>,
|
||||
'question': "Find the population of Jason's home country"},
|
||||
{'dependancies': [2, 3],
|
||||
'id': 4,
|
||||
'node_type': <QueryType.SINGLE_QUESTION: 'SINGLE'>,
|
||||
'question': 'Calculate the difference in populations between '
|
||||
"Canada and Jason's home country"}]}
|
||||
"""
|
||||
|
||||
asyncio.run(plan.execute())
|
||||
"""
|
||||
Executing query plan from `What is the difference in populations of Canada and Jason's home country?`
|
||||
Executing `What is the difference in populations of Canada and Jason's home country?`
|
||||
Executing with 2 dependancies
|
||||
Executing `What is the population of Canada?`
|
||||
Executing `What is the population of Jason's home country?`
|
||||
{'query': 'What is the population of Canada?', 'response': '...'}
|
||||
{'query': "What is the population of Jason's home country?", 'response': '...'}
|
||||
{'query': "What is the difference in populations of Canada and Jason's home "
|
||||
'country?'
|
||||
'Context: {"responses": [{"query": "What is the population of '
|
||||
'Canada?", "response": "..."}, {"query": "What is the population of '
|
||||
'Jason's home country?", "response": "..."}]}',
|
||||
'response': '...'}
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user