diff --git a/README.md b/README.md
index f308648..c427f69 100644
--- a/README.md
+++ b/README.md
@@ -106,6 +106,101 @@ user_details = UserDetails.from_response(completion)
print(user_details) # UserDetails(name="John Doe", age=30)
```
+### Example 3: Using the DSL
+
+```python
+from openai_function_call import OpenAISchema
+from openai_function_call.dsl import ChatCompletion, MultiTask, messages as m
+
+# Define a subtask you'd like to extract from then,
+# We'll use MultTask to easily map it to a List[Search]
+# so we can extract more than one
+class Search(OpenAISchema):
+ id: int
+ query: str
+
+task = (
+ ChatCompletion(name="Acme Inc Email Segmentation", model="gpt3.5-turbo-0613")
+ | m.ExpertSystem(task="Segment emails into search queries")
+ | MultiTask(subtask_class=Search)
+ | m.TaggedMessage(
+ tag="email",
+ content="Can you find the video I sent last week and also the post about dogs",
+ )
+ | m.TipsMessage(
+ tips=[
+ "When unsure about the correct segmentation, try to think about the task as a whole",
+ "If acronyms are used expand them to their full form",
+ "Use multiple phrases to describe the same thing",
+ ]
+ )
+ | m.ChainOfThought()
+)
+# Its important that this just builds you request,
+# all these | operators are overloaded and all we do is compile
+# it to the openai kwargs
+assert isinstance(task, ChatCompletion)
+pprint(task.kwargs, indent=3)
+"""
+{
+ "messages": [
+ {
+ "role": "system",
+ "content": "You are a world class, state of the art agent capable
+ of correctly completing the task: `Segment emails into search queries`"
+ },
+ {
+ "role": "user",
+ "content": "Can you find the video I sent last week and also the post about dogs"
+ },
+ ...
+ {
+ "role": "assistant",
+ "content": "Lets think step by step to get the correct answer:"
+ }
+ ],
+ "functions": [
+ {
+ "name": "MultiSearch",
+ "description": "Correct segmentation of `Search` tasks",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "tasks": {
+ "description": "Correctly segmented list of `Search` tasks",
+ "type": "array",
+ "items": {"$ref": "#/definitions/Search"}
+ }
+ },
+ "definitions": {
+ "Search": {
+ "type": "object",
+ "properties": {
+ "id": {"type": "integer"},
+ "query": {"type": "string"}
+ },
+ "required": ["id", "query"]
+ }
+ },
+ "required": ["tasks"]
+ }
+ }
+ ],
+ "function_call": {"name": "MultiSearch"},
+ "max_tokens": 1000,
+ "temperature": 0.1,
+ "model": "gpt3.5-turbo-0613"
+}
+"""
+
+# Once we call .create we'll be returned with a multitask object that contains our list of task
+result = tasks.create()
+
+for task in result.tasks:
+ # We can now extract the list of tasks as we could normally
+ assert isinstance(task, Search)
+```
+
## Advanced Usage
If you want to see more examples checkout the examples folder!
diff --git a/openai_function_call/__init__.py b/openai_function_call/__init__.py
index 323c534..bc0fb3c 100644
--- a/openai_function_call/__init__.py
+++ b/openai_function_call/__init__.py
@@ -1,111 +1,3 @@
-# MIT License
-#
-# Copyright (c) 2023 Jason Liu
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
+from .function_calls import OpenAISchema, openai_function
-import json
-from functools import wraps
-from typing import Any, Callable
-from pydantic import validate_arguments, BaseModel
-
-
-def _remove_a_key(d, remove_key) -> None:
- """Remove a key from a dictionary recursively"""
- if isinstance(d, dict):
- for key in list(d.keys()):
- if key == remove_key:
- del d[key]
- else:
- _remove_a_key(d[key], remove_key)
-
-
-class openai_function:
- def __init__(self, func: Callable) -> None:
- self.func = func
- self.validate_func = validate_arguments(func)
- parameters = self.validate_func.model.schema()
- parameters["properties"] = {
- k: v
- for k, v in parameters["properties"].items()
- if k not in ("v__duplicate_kwargs", "args", "kwargs")
- }
- parameters["required"] = sorted(
- parameters["properties"]
- ) # bug workaround see lc
- _remove_a_key(parameters, "title")
- _remove_a_key(parameters, "additionalProperties")
- self.openai_schema = {
- "name": self.func.__name__,
- "description": self.func.__doc__,
- "parameters": parameters,
- }
- self.model = self.validate_func.model
-
- def __call__(self, *args: Any, **kwargs: Any) -> Any:
- @wraps(self.func)
- def wrapper(*args, **kwargs):
- return self.validate_func(*args, **kwargs)
-
- return wrapper(*args, **kwargs)
-
- def from_response(self, completion, throw_error=True):
- """Execute the function from the response of an openai chat completion"""
- message = completion.choices[0].message
-
- if throw_error:
- assert "function_call" in message, "No function call detected"
- assert (
- message["function_call"]["name"] == self.openai_schema["name"]
- ), "Function name does not match"
-
- function_call = message["function_call"]
- arguments = json.loads(function_call["arguments"])
- return self.validate_func(**arguments)
-
-
-class OpenAISchema(BaseModel):
- @classmethod
- @property
- def openai_schema(cls):
- schema = cls.schema()
- parameters = {
- k: v for k, v in schema.items() if k not in ("title", "description")
- }
- parameters["required"] = sorted(parameters["properties"])
- _remove_a_key(parameters, "title")
- return {
- "name": schema["title"],
- "description": schema["description"],
- "parameters": parameters,
- }
-
- @classmethod
- def from_response(cls, completion, throw_error=True):
- message = completion.choices[0].message
-
- if throw_error:
- assert "function_call" in message, "No function call detected"
- assert (
- message["function_call"]["name"] == cls.openai_schema["name"]
- ), "Function name does not match"
-
- function_call = message["function_call"]
- arguments = json.loads(function_call["arguments"])
- return cls(**arguments)
+__all__ = ["OpenAISchema", "openai_function"]
diff --git a/openai_function_call/dsl/__init__.py b/openai_function_call/dsl/__init__.py
new file mode 100644
index 0000000..3a49644
--- /dev/null
+++ b/openai_function_call/dsl/__init__.py
@@ -0,0 +1,5 @@
+from .completion import ChatCompletion
+from .multitask import MultiTask
+from .messages import *
+
+__all__ = ["ChatCompletion", "MultiTask", "messages"]
diff --git a/openai_function_call/dsl/completion.py b/openai_function_call/dsl/completion.py
index a0b2b80..90501ee 100644
--- a/openai_function_call/dsl/completion.py
+++ b/openai_function_call/dsl/completion.py
@@ -102,71 +102,4 @@ class ChatCompletion(BaseModel):
if self.function:
return self.function.from_response(await completion)
return await completion
-
-
-def MultiTask(
- subtask_class: Type[OpenAISchema],
- name: Optional[str] = None,
- description: Optional[str] = None,
-):
- """
- Dynamically create a MultiTask OpenAISchema that can be used to segment multiple
- tasks given a base class. This creates class that can be used to create a toolkit
- for a specific task, names and descriptions are automatically generated. However
- they can be overridden.
-
- :param subtask_class: The base class to use for the MultiTask
- :param name: The name of the MultiTask
- :param description: The description of the MultiTask
-
- :return: new schema class called `Multi{subtask_class.name}`
- """
- task_name = subtask_class.__name__ if name is None else name
-
- name = f"Multi{task_name}"
-
- list_tasks = (
- List[subtask_class],
- Field(
- default_factory=list,
- repr=False,
- description=f"Correctly segmented list of `{task_name}` tasks",
- ),
- )
-
- new_cls = create_model(name, tasks=list_tasks, __base__=(OpenAISchema,))
-
- new_cls.__doc__ = (
- f"Correct segmentation of `{task_name}` tasks"
- if description is None
- else description
- )
-
- return new_cls
-
-
-if __name__ == "__main__":
- from pprint import pprint
-
- class Search(OpenAISchema):
- id: int
- query: str
-
- task = (
- ChatCompletion(name="Acme Inc Email Segmentation", model="gpt3.5-turbo-0613")
- | ExpertSystem(task="Segment emails into search queries")
- | MultiTask(subtask_class=Search)
- | TaggedMessage(
- tag="email",
- content="Can you find the video I sent last week and also the post about dogs",
- )
- | TipsMessage(
- tips=[
- "When unsure about the correct segmentation, try to think about the task as a whole",
- "If acronyms are used expand them to their full form",
- "Use multiple phrases to describe the same thing",
- ]
- )
- | ChainOfThought()
- )
- assert isinstance(task, ChatCompletion)
+
\ No newline at end of file
diff --git a/openai_function_call/dsl/messages.py b/openai_function_call/dsl/messages.py
new file mode 100644
index 0000000..d93eee1
--- /dev/null
+++ b/openai_function_call/dsl/messages.py
@@ -0,0 +1,81 @@
+from enum import Enum, auto
+from pydantic.dataclasses import dataclass
+from pydantic import Field
+from typing import Optional, List
+
+
+class MessageRole(Enum):
+ USER = auto()
+ SYSTEM = auto()
+ ASSISTANT = auto()
+
+
+@dataclass
+class Message:
+ content: str = Field(default=None, repr=True)
+ role: MessageRole = Field(default=MessageRole.USER, repr=False)
+ name: Optional[str] = Field(default=None)
+
+ def dict(self):
+ assert self.content is not None, "Content must be set!"
+ obj = {
+ "role": self.role.name.lower(),
+ "content": self.content,
+ }
+ if self.name and self.role == MessageRole.USER:
+ obj["name"] = self.name
+ return obj
+
+
+@dataclass
+class SystemMessage(Message):
+ def __post_init__(self):
+ self.role = MessageRole.SYSTEM
+
+
+@dataclass
+class UserMessage(Message):
+ def __post_init__(self):
+ self.role = MessageRole.USER
+
+
+@dataclass
+class TaggedMessage(Message):
+ tag: str = Field(default="data", repr=True)
+
+ def __post_init__(self):
+ self.role = MessageRole.USER
+ self.content = f"<{self.tag}>{self.content}{self.tag}>"
+
+
+@dataclass
+class AssistantMessage(Message):
+ def __post_init__(self):
+ self.role = MessageRole.ASSISTANT
+
+
+@dataclass
+class ExpertSystem(Message):
+ task: str = Field(default=None, repr=True)
+
+ def __post_init__(self):
+ self.role = MessageRole.SYSTEM
+ self.content = f"You are a world class, state of the art agent capable of correctly completing the task: `{self.task}`"
+
+
+@dataclass
+class TipsMessage(Message):
+ tips: List[str] = Field(default_factory=list)
+ header: str = "Here are some tips to help you complete the task"
+
+ def __post_init__(self):
+ self.role = MessageRole.USER
+ tips = "\n* ".join(self.tips)
+ self.content = f"{self.header}:\n\n* {tips}"
+
+
+@dataclass
+class ChainOfThought(Message):
+ def __post_init__(self):
+ self.role = MessageRole.ASSISTANT
+ self.content = "Lets think step by step to get the correct answer:"
diff --git a/openai_function_call/dsl/multitask.py b/openai_function_call/dsl/multitask.py
new file mode 100644
index 0000000..7f3f85d
--- /dev/null
+++ b/openai_function_call/dsl/multitask.py
@@ -0,0 +1,44 @@
+from pydantic import create_model, Field
+from typing import Optional, List, Type
+from ..function_calls import OpenAISchema
+
+
+def MultiTask(
+ subtask_class: Type[OpenAISchema],
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+):
+ """
+ Dynamically create a MultiTask OpenAISchema that can be used to segment multiple
+ tasks given a base class. This creates class that can be used to create a toolkit
+ for a specific task, names and descriptions are automatically generated. However
+ they can be overridden.
+
+ :param subtask_class: The base class to use for the MultiTask
+ :param name: The name of the MultiTask
+ :param description: The description of the MultiTask
+
+ :return: new schema class called `Multi{subtask_class.name}`
+ """
+ task_name = subtask_class.__name__ if name is None else name
+
+ name = f"Multi{task_name}"
+
+ list_tasks = (
+ List[subtask_class],
+ Field(
+ default_factory=list,
+ repr=False,
+ description=f"Correctly segmented list of `{task_name}` tasks",
+ ),
+ )
+
+ new_cls = create_model(name, tasks=list_tasks, __base__=(OpenAISchema,))
+
+ new_cls.__doc__ = (
+ f"Correct segmentation of `{task_name}` tasks"
+ if description is None
+ else description
+ )
+
+ return new_cls
diff --git a/openai_function_call/function_calls.py b/openai_function_call/function_calls.py
new file mode 100644
index 0000000..323c534
--- /dev/null
+++ b/openai_function_call/function_calls.py
@@ -0,0 +1,111 @@
+# MIT License
+#
+# Copyright (c) 2023 Jason Liu
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import json
+from functools import wraps
+from typing import Any, Callable
+from pydantic import validate_arguments, BaseModel
+
+
+def _remove_a_key(d, remove_key) -> None:
+ """Remove a key from a dictionary recursively"""
+ if isinstance(d, dict):
+ for key in list(d.keys()):
+ if key == remove_key:
+ del d[key]
+ else:
+ _remove_a_key(d[key], remove_key)
+
+
+class openai_function:
+ def __init__(self, func: Callable) -> None:
+ self.func = func
+ self.validate_func = validate_arguments(func)
+ parameters = self.validate_func.model.schema()
+ parameters["properties"] = {
+ k: v
+ for k, v in parameters["properties"].items()
+ if k not in ("v__duplicate_kwargs", "args", "kwargs")
+ }
+ parameters["required"] = sorted(
+ parameters["properties"]
+ ) # bug workaround see lc
+ _remove_a_key(parameters, "title")
+ _remove_a_key(parameters, "additionalProperties")
+ self.openai_schema = {
+ "name": self.func.__name__,
+ "description": self.func.__doc__,
+ "parameters": parameters,
+ }
+ self.model = self.validate_func.model
+
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
+ @wraps(self.func)
+ def wrapper(*args, **kwargs):
+ return self.validate_func(*args, **kwargs)
+
+ return wrapper(*args, **kwargs)
+
+ def from_response(self, completion, throw_error=True):
+ """Execute the function from the response of an openai chat completion"""
+ message = completion.choices[0].message
+
+ if throw_error:
+ assert "function_call" in message, "No function call detected"
+ assert (
+ message["function_call"]["name"] == self.openai_schema["name"]
+ ), "Function name does not match"
+
+ function_call = message["function_call"]
+ arguments = json.loads(function_call["arguments"])
+ return self.validate_func(**arguments)
+
+
+class OpenAISchema(BaseModel):
+ @classmethod
+ @property
+ def openai_schema(cls):
+ schema = cls.schema()
+ parameters = {
+ k: v for k, v in schema.items() if k not in ("title", "description")
+ }
+ parameters["required"] = sorted(parameters["properties"])
+ _remove_a_key(parameters, "title")
+ return {
+ "name": schema["title"],
+ "description": schema["description"],
+ "parameters": parameters,
+ }
+
+ @classmethod
+ def from_response(cls, completion, throw_error=True):
+ message = completion.choices[0].message
+
+ if throw_error:
+ assert "function_call" in message, "No function call detected"
+ assert (
+ message["function_call"]["name"] == cls.openai_schema["name"]
+ ), "Function name does not match"
+
+ function_call = message["function_call"]
+ arguments = json.loads(function_call["arguments"])
+ return cls(**arguments)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_completion.py b/tests/test_completion.py
new file mode 100644
index 0000000..d9801bc
--- /dev/null
+++ b/tests/test_completion.py
@@ -0,0 +1,28 @@
+from openai_function_call import OpenAISchema
+from openai_function_call.dsl import ChatCompletion, MultiTask, messages as m
+
+
+def test_chatcompletion_has_kwargs():
+ class Search(OpenAISchema):
+ id: int
+ query: str
+
+ task = (
+ ChatCompletion(name="Acme Inc Email Segmentation", model="gpt3.5-turbo-0613")
+ | m.ExpertSystem(task="Segment emails into search queries")
+ | MultiTask(subtask_class=Search)
+ | m.TaggedMessage(
+ tag="email",
+ content="Can you find the video I sent last week and also the post about dogs",
+ )
+ | m.TipsMessage(
+ tips=[
+ "When unsure about the correct segmentation, try to think about the task as a whole",
+ "If acronyms are used expand them to their full form",
+ "Use multiple phrases to describe the same thing",
+ ]
+ )
+ | m.ChainOfThought()
+ )
+ assert isinstance(task, ChatCompletion)
+ assert isinstance(task.kwargs, dict)
diff --git a/tests/test_messages.py b/tests/test_messages.py
new file mode 100644
index 0000000..c422944
--- /dev/null
+++ b/tests/test_messages.py
@@ -0,0 +1,55 @@
+from openai_function_call.dsl import messages as m
+
+
+def test_create_message():
+ assert m.Message(
+ role=m.MessageRole.SYSTEM,
+ content="Hello, world!",
+ ).dict() == {
+ "role": "system",
+ "content": "Hello, world!",
+ }
+
+
+def test_create_user_message():
+ assert m.UserMessage(
+ content="Hello, world!",
+ ).dict() == {
+ "role": "user",
+ "content": "Hello, world!",
+ }
+
+
+def test_create_system_message():
+ assert m.SystemMessage(content="I am nice").dict() == {
+ "role": "system",
+ "content": "I am nice",
+ }
+
+
+def test_assistance_message():
+ assert m.AssistantMessage(content="I am nice").dict() == {
+ "role": "assistant",
+ "content": "I am nice",
+ }
+
+
+def test_create_tagged_message():
+ assert m.TaggedMessage(content="I am nice", tag="data").dict() == {
+ "role": "user",
+ "content": "I am nice",
+ }
+
+
+def test_expert_system_message():
+ assert m.ExpertSystem(task="task").dict() == {
+ "role": "system",
+ "content": "You are a world class, state of the art agent capable of correctly completing the task: `task`",
+ }
+
+
+def test_chain_of_thought_message():
+ assert m.ChainOfThought().dict() == {
+ "role": "assistant",
+ "content": "Lets think step by step to get the correct answer:",
+ }
diff --git a/tests/test_multitask.py b/tests/test_multitask.py
new file mode 100644
index 0000000..d6564cf
--- /dev/null
+++ b/tests/test_multitask.py
@@ -0,0 +1,76 @@
+from openai_function_call.dsl import MultiTask
+from openai_function_call import OpenAISchema
+
+
+def test_multi_task():
+ class Search(OpenAISchema):
+ """This is the search docstring"""
+
+ id: int
+ query: str
+
+ multitask = MultiTask(subtask_class=Search)
+ assert multitask.openai_schema == {
+ "description": "Correct segmentation of `Search` tasks",
+ "name": "MultiSearch",
+ "parameters": {
+ "definitions": {
+ "Search": {
+ "properties": {
+ "id": {"type": "integer"},
+ "query": {"type": "string"},
+ },
+ "required": ["id", "query"],
+ "description": "This is the search docstring",
+ "type": "object",
+ }
+ },
+ "properties": {
+ "tasks": {
+ "description": "Correctly segmented list of `Search` tasks",
+ "items": {"$ref": "#/definitions/Search"},
+ "type": "array",
+ }
+ },
+ "required": ["tasks"],
+ "type": "object",
+ },
+ }
+
+
+def test_multi_task_with_name_and_desc():
+ class Search(OpenAISchema):
+ """This is the search docstring"""
+
+ id: int
+ query: str
+
+ multitask = MultiTask(
+ subtask_class=Search, name="MyCustomName", description="MyCustomDesc"
+ )
+ assert multitask.openai_schema == {
+ "description": "MyCustomDesc",
+ "name": "MultiMyCustomName",
+ "parameters": {
+ "definitions": {
+ "Search": {
+ "properties": {
+ "id": {"type": "integer"},
+ "query": {"type": "string"},
+ },
+ "required": ["id", "query"],
+ "description": "This is the search docstring",
+ "type": "object",
+ }
+ },
+ "properties": {
+ "tasks": {
+ "description": "Correctly segmented list of `MyCustomName` tasks",
+ "items": {"$ref": "#/definitions/Search"},
+ "type": "array",
+ }
+ },
+ "required": ["tasks"],
+ "type": "object",
+ },
+ }