Implementing the Distillation Decorator prototype (#118)

This commit is contained in:
Jason Liu
2023-10-15 14:54:55 -04:00
committed by GitHub
parent 41c4636dd2
commit 9fe3c83eca
9 changed files with 610 additions and 1 deletions
+110
View File
@@ -0,0 +1,110 @@
---
draft: False
date: 2023-10-17
tags:
- python
- distilation
- function calling
- tinetuning
---
# Introduction to `Instructions` from `Instructor`, finetuning from Python functions.
The core philosophy with the `instructor` library is to make language models backwards compatible with existing code. By adding Pydantic in the mix we're able to easily work with LLMs without much worry.
However, many times, a single function isn't just one LLM call. After the results are returned theres [validation](/docs/validation.md), some additional processing and formatting before you `return` the result.
But the promise of LLMs is that they can do all of this in one go. So how do we get there? Finetuning end to end is a great tool for enhancing language models. Instructor uses type hints via Pydantic to maintain backward compatibility. Distillation focuses on fine-tuning language models to imitate specific functions.
## Challenges in Fine-tuning
Fine-tuning a model isn't as straightforward as just writing `def f(a, b): return a * b` to teach a model three-digit multiplication. Substantial data preparation is required, making logging for data collection cumbersome. Luckily OpenAI not only provides a fine-tuning script but also one for function calling which simplies the process backed by structured outputs! More over, the finetune allows us to avoid passing the schema to the model, resulting in less tokens being used!
## Role of Instructor in Easing the Process
The feature `from instructor import Instructions` simplifies this. It decorates Python functions that return Pydantic objects, automatically creating a fine-tuning dataset when provided a handler for logging. This allows you to finetune a model to imitate a function's behavior.
## How to Use Instructor's Distillation Feature
Here's an example to illustrate its use:
```python
import logging
import random
from pydantic import BaseModel
from instructor import Instructions
logging.basicConfig(level=logging.INFO)
instructions = Instructions(
name="three_digit_multiply",
finetune_format="messages",
log_handlers=[logging.FileHandler("math_finetunes.jsonl")]
)
class Multiply(BaseModel):
a: int
b: int
result: int
@instructions.distil
def fn(a: int, b: int) -> Multiply:
resp = a * b
return Multiply(a=a, b=b, result=resp)
for _ in range(10):
a = random.randint(100, 999)
b = random.randint(100, 999)
print(fn(a, b))
```
## Logging output
```python
{
"messages": [
{"role": "system", "content": 'Predict the results of this function: ...'},
{"role": "user", "content": 'Return fn(133, b=539)'},
{"role": "assistant",
"function_call":
{
"name": "Multiply",
"arguments": '{"a":133,"b":539,"result":89509}'
}
}
],
"functions": [
{"name": "Multiply", "description": "Correctly extracted `Multiply`..."}
]
}
```
## Why Instructor and Distillation are Useful
Many systems are not as simple as a single `openai.ChatCompletion.create` call, instead we often create objects, do additional processing, validation, error correction, and then return the result. This is a lot of work, and it's easy to make mistakes. Instructor's `distil` feature makes this process easier by:
1. Streamlines complex functions with validations, making them more efficient.
2. Facilitates the integration of classical machine learning with language models.
By understanding and leveraging these capabilities, you can create powerful, fine-tuned language models with ease. To learn more about how to use the file to finetune a model, check out the [cli](/docs/cli/finetune.md)
## Next Steps
This post is mostly a peek of what I've been working on this week. Once we have a model trained I'd like to be able to dynamically swap the implemetnation of a function with a model. This would allow us to do things like:
```python
from instructor import Instructions
instructions = Instructions(
name="three_digit_multiply",
)
@instructions.distil(model='gpt-3.5-turbo:finetuned', swap=True)
def fn(a: int, b: int) -> Multiply:
resp = a + b
return Multiply(a=a, b=b, result=resp)
```
Now we can swap out the implementation of `fn` with calling the finetuned model, since we know the response type is still `Multiply` we can use instructor behind the scenes and have it be backwards compatible with the existing code.
Now if you're thinking wow, I'd love a backend service to do this for continously, you're in luck! Please check out the survey at [useinstructor.com](https://useinstructor.com) and let us know who you are.
+93
View File
@@ -0,0 +1,93 @@
# Distilling python functions into LLM
`Instructions` from the `Instructor` library offers a seamless way to make language models backward compatible with existing Python functions. By employing Pydantic type hints, it not only ensures compatibility but also facilitates fine-tuning language models to emulate these functions end-to-end.
## The Challenges in Function-Level Fine-Tuning
Unlike simple script-level fine-tuning, replicating the behavior of a Python function in a language model involves intricate data preparation. For instance, teaching a model to execute three-digit multiplication is not as trivial as implementing `def f(a, b): return a * b`. OpenAI's fine-tuning script coupled with their function calling utility provides a structured output, thereby simplifying the data collection process. Additionally, this eliminates the need for passing the schema to the model, thus conserving tokens.
## The Role of `Instructions` in Simplifying the Fine-Tuning Process
By using `Instructions`, you can annotate a Python function that returns a Pydantic object, thereby automating the dataset creation for fine-tuning. A handler for logging is all that's needed to build this dataset.
## How to Implement `Instructions` in Your Code
Here's a step-by-step example:
```python
import logging
from pydantic import BaseModel
from instructor import Instructions
logging.basicConfig(level=logging.INFO)
instructions = Instructions(
name="three_digit_multiply",
finetune_format="messages",
log_handlers=[logging.FileHandler("math_finetunes.jsonl")]
)
class Multiply(BaseModel):
a: int
b: int
result: int
@instructions.distil
def fn(a: int, b: int) -> Multiply:
resp = a + b
return Multiply(a=a, b=b, result=resp)
```
## Custom Log Handlers for Data Collection
While the example above uses a file-based log handler, you can easily extend this to custom log handlers for different storage solutions. The following skeleton code illustrates how to create a log handler for an S3 bucket:
```python
import logging
import boto3
class S3LogHandler(logging.Handler):
def __init__(self, bucket, key):
logging.Handler.__init__(self)
self.bucket = bucket
self.key = key
def emit(self, record):
s3 = boto3.client('s3')
log_entry = self.format(record)
s3.put_object(Body=log_entry, Bucket=self.bucket, Key=self.key)
```
You can add this custom log handler to `Instructions` as shown:
```python
instructions = Instructions(
name="three_digit_multiply",
finetune_format="messages",
log_handlers=[S3LogHandler(bucket='your-bucket', key='your-key')]
)
```
## Why `Instructions` is a Game-Changer
1. It condenses complex, multi-step functions with validations into a single fine-tuned model.
2. It integrates language models with classical machine learning seamlessly.
## Next Steps and Future Scope
Going forward, the aim is to dynamically switch between the Python function and its fine-tuned model representation. This could look like:
```python
from instructor import Instructions
instructions = Instructions(
name="three_digit_multiply",
)
@instructions.distil(model='gpt-3.5-turbo:finetuned', swap=True)
def fn(a: int, b: int) -> Multiply:
resp = a + b
return Multiply(a=a, b=b, result=resp)
```
This dynamic switching retains backward compatibility while improving efficiency, opening up exciting avenues for future developments.
@@ -0,0 +1,10 @@
{"messages": [{"role": "system", "content": "Predict the results of this function:\n\ndef fn(a: int, b: int, c: str) -> __main__.Multiply\n\"\"\"\n_summary_\n\nArgs:\n a (int): _description_\n b (int): _description_\n c (str): _description_\n\nReturns:\n Response: _description_\n\"\"\""}, {"role": "user", "content": "Return `fn(540, b=677, c=\"hello\")`"}, {"role": "assistant", "function_call": {"name": "Multiply", "arguments": "{\n \"a\": 540,\n \"b\": 677,\n \"result\": 1217\n}"}}], "functions": [{"name": "Multiply", "description": "Correctly extracted `Multiply` with all the required parameters with correct types", "parameters": {"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}, "result": {"description": "The result of the multiplication", "type": "integer"}}, "required": ["a", "b", "result"], "type": "object"}}]}
{"messages": [{"role": "system", "content": "Predict the results of this function:\n\ndef fn(a: int, b: int, c: str) -> __main__.Multiply\n\"\"\"\n_summary_\n\nArgs:\n a (int): _description_\n b (int): _description_\n c (str): _description_\n\nReturns:\n Response: _description_\n\"\"\""}, {"role": "user", "content": "Return `fn(798, b=534, c=\"hello\")`"}, {"role": "assistant", "function_call": {"name": "Multiply", "arguments": "{\n \"a\": 798,\n \"b\": 534,\n \"result\": 1332\n}"}}], "functions": [{"name": "Multiply", "description": "Correctly extracted `Multiply` with all the required parameters with correct types", "parameters": {"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}, "result": {"description": "The result of the multiplication", "type": "integer"}}, "required": ["a", "b", "result"], "type": "object"}}]}
{"messages": [{"role": "system", "content": "Predict the results of this function:\n\ndef fn(a: int, b: int, c: str) -> __main__.Multiply\n\"\"\"\n_summary_\n\nArgs:\n a (int): _description_\n b (int): _description_\n c (str): _description_\n\nReturns:\n Response: _description_\n\"\"\""}, {"role": "user", "content": "Return `fn(608, b=669, c=\"hello\")`"}, {"role": "assistant", "function_call": {"name": "Multiply", "arguments": "{\n \"a\": 608,\n \"b\": 669,\n \"result\": 1277\n}"}}], "functions": [{"name": "Multiply", "description": "Correctly extracted `Multiply` with all the required parameters with correct types", "parameters": {"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}, "result": {"description": "The result of the multiplication", "type": "integer"}}, "required": ["a", "b", "result"], "type": "object"}}]}
{"messages": [{"role": "system", "content": "Predict the results of this function:\n\ndef fn(a: int, b: int, c: str) -> __main__.Multiply\n\"\"\"\n_summary_\n\nArgs:\n a (int): _description_\n b (int): _description_\n c (str): _description_\n\nReturns:\n Response: _description_\n\"\"\""}, {"role": "user", "content": "Return `fn(982, b=768, c=\"hello\")`"}, {"role": "assistant", "function_call": {"name": "Multiply", "arguments": "{\n \"a\": 982,\n \"b\": 768,\n \"result\": 1750\n}"}}], "functions": [{"name": "Multiply", "description": "Correctly extracted `Multiply` with all the required parameters with correct types", "parameters": {"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}, "result": {"description": "The result of the multiplication", "type": "integer"}}, "required": ["a", "b", "result"], "type": "object"}}]}
{"messages": [{"role": "system", "content": "Predict the results of this function:\n\ndef fn(a: int, b: int, c: str) -> __main__.Multiply\n\"\"\"\n_summary_\n\nArgs:\n a (int): _description_\n b (int): _description_\n c (str): _description_\n\nReturns:\n Response: _description_\n\"\"\""}, {"role": "user", "content": "Return `fn(994, b=682, c=\"hello\")`"}, {"role": "assistant", "function_call": {"name": "Multiply", "arguments": "{\n \"a\": 994,\n \"b\": 682,\n \"result\": 1676\n}"}}], "functions": [{"name": "Multiply", "description": "Correctly extracted `Multiply` with all the required parameters with correct types", "parameters": {"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}, "result": {"description": "The result of the multiplication", "type": "integer"}}, "required": ["a", "b", "result"], "type": "object"}}]}
{"messages": [{"role": "system", "content": "Predict the results of this function:\n\ndef fn(a: int, b: int, c: str) -> __main__.Multiply\n\"\"\"\n_summary_\n\nArgs:\n a (int): _description_\n b (int): _description_\n c (str): _description_\n\nReturns:\n Response: _description_\n\"\"\""}, {"role": "user", "content": "Return `fn(467, b=754, c=\"hello\")`"}, {"role": "assistant", "function_call": {"name": "Multiply", "arguments": "{\n \"a\": 467,\n \"b\": 754,\n \"result\": 1221\n}"}}], "functions": [{"name": "Multiply", "description": "Correctly extracted `Multiply` with all the required parameters with correct types", "parameters": {"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}, "result": {"description": "The result of the multiplication", "type": "integer"}}, "required": ["a", "b", "result"], "type": "object"}}]}
{"messages": [{"role": "system", "content": "Predict the results of this function:\n\ndef fn(a: int, b: int, c: str) -> __main__.Multiply\n\"\"\"\n_summary_\n\nArgs:\n a (int): _description_\n b (int): _description_\n c (str): _description_\n\nReturns:\n Response: _description_\n\"\"\""}, {"role": "user", "content": "Return `fn(497, b=364, c=\"hello\")`"}, {"role": "assistant", "function_call": {"name": "Multiply", "arguments": "{\n \"a\": 497,\n \"b\": 364,\n \"result\": 861\n}"}}], "functions": [{"name": "Multiply", "description": "Correctly extracted `Multiply` with all the required parameters with correct types", "parameters": {"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}, "result": {"description": "The result of the multiplication", "type": "integer"}}, "required": ["a", "b", "result"], "type": "object"}}]}
{"messages": [{"role": "system", "content": "Predict the results of this function:\n\ndef fn(a: int, b: int, c: str) -> __main__.Multiply\n\"\"\"\n_summary_\n\nArgs:\n a (int): _description_\n b (int): _description_\n c (str): _description_\n\nReturns:\n Response: _description_\n\"\"\""}, {"role": "user", "content": "Return `fn(840, b=821, c=\"hello\")`"}, {"role": "assistant", "function_call": {"name": "Multiply", "arguments": "{\n \"a\": 840,\n \"b\": 821,\n \"result\": 1661\n}"}}], "functions": [{"name": "Multiply", "description": "Correctly extracted `Multiply` with all the required parameters with correct types", "parameters": {"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}, "result": {"description": "The result of the multiplication", "type": "integer"}}, "required": ["a", "b", "result"], "type": "object"}}]}
{"messages": [{"role": "system", "content": "Predict the results of this function:\n\ndef fn(a: int, b: int, c: str) -> __main__.Multiply\n\"\"\"\n_summary_\n\nArgs:\n a (int): _description_\n b (int): _description_\n c (str): _description_\n\nReturns:\n Response: _description_\n\"\"\""}, {"role": "user", "content": "Return `fn(646, b=835, c=\"hello\")`"}, {"role": "assistant", "function_call": {"name": "Multiply", "arguments": "{\n \"a\": 646,\n \"b\": 835,\n \"result\": 1481\n}"}}], "functions": [{"name": "Multiply", "description": "Correctly extracted `Multiply` with all the required parameters with correct types", "parameters": {"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}, "result": {"description": "The result of the multiplication", "type": "integer"}}, "required": ["a", "b", "result"], "type": "object"}}]}
{"messages": [{"role": "system", "content": "Predict the results of this function:\n\ndef fn(a: int, b: int, c: str) -> __main__.Multiply\n\"\"\"\n_summary_\n\nArgs:\n a (int): _description_\n b (int): _description_\n c (str): _description_\n\nReturns:\n Response: _description_\n\"\"\""}, {"role": "user", "content": "Return `fn(926, b=196, c=\"hello\")`"}, {"role": "assistant", "function_call": {"name": "Multiply", "arguments": "{\n \"a\": 926,\n \"b\": 196,\n \"result\": 1122\n}"}}], "functions": [{"name": "Multiply", "description": "Correctly extracted `Multiply` with all the required parameters with correct types", "parameters": {"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}, "result": {"description": "The result of the multiplication", "type": "integer"}}, "required": ["a", "b", "result"], "type": "object"}}]}
+79
View File
@@ -0,0 +1,79 @@
import logging
from pydantic import BaseModel, Field
from instructor import Instructions
logging.basicConfig(level=logging.INFO)
# Usage
instructions = Instructions(
name="three_digit_multiply",
finetune_format="messages",
log_handlers=[
logging.FileHandler("math_finetunes.jsonl"),
],
)
class Multiply(BaseModel):
a: int
b: int
result: int = Field(..., description="The result of the multiplication")
@instructions.distil
def fn(a: int, b: int, c: str) -> Multiply:
"""_summary_
Args:
a (int): _description_
b (int): _description_
c (str): _description_
Returns:
Response: _description_
"""
resp = a + b
return Multiply(a=a, b=b, result=resp)
if __name__ == "__main__":
import random
# A log will look like this:
log_line = {
"messages": [
{
"role": "system",
"content": 'Predict the results of this function:\n\ndef fn(a: int, b: int, c: str) -> __main__.Response\n"""\n_summary_\n\nArgs:\n a (int): _description_\n b (int): _description_\n c (str): _description_\n\nReturns:\n Response: _description_\n"""',
},
{"role": "user", "content": 'Return fn(133, b=539, c="hello")'},
{
"role": "assistant",
"function_call": {
"name": "Response",
"arguments": '{"a":133,"b":539,"result":672}',
},
},
],
"functions": [
{
"name": "Response",
"description": "Correctly extracted `Response` with all the required parameters with correct types",
"parameters": {
"properties": {
"a": {"type": "integer"},
"b": {"type": "integer"},
"result": {"type": "integer"},
},
"required": ["a", "b", "result"],
"type": "object",
},
}
],
}
for _ in range(10):
a = random.randint(100, 999)
b = random.randint(100, 999)
print("returning", fn(a, b=b, c="hello"))
+3
View File
@@ -1,4 +1,5 @@
from .function_calls import OpenAISchema, openai_function, openai_schema
from .distil import FinetuneFormat, Instructions
from .dsl import MultiTask, Maybe, llm_validator, CitationMixin
from .patch import patch, unpatch
@@ -11,5 +12,7 @@ __all__ = [
"openai_schema",
"patch",
"llm_validator",
"FinetuneFormat",
"Instructions",
"unpatch",
]
+214
View File
@@ -0,0 +1,214 @@
import enum
import functools
import inspect
import json
import logging
from typing import Any, Callable, List, Optional
import uuid
from pydantic import BaseModel, validate_call
from instructor import openai_schema
class FinetuneFormat(enum.Enum):
MESSAGES: str = "messages"
RAW: str = "raw"
def get_signature_from_fn(fn: Callable) -> str:
"""
Get the function signature as a string.
:Example:
>>> def my_function(a: int, b: int) -> int:
>>> return a + b
>>>
>>> get_signature_from_fn(my_function)
"def my_function(a: int, b: int) -> int"
:param fn: Function to get the signature for.
:return: Function signature as a string.
"""
sig = inspect.signature(fn)
lines = f"def {fn.__name__}{sig}"
docstring = inspect.getdoc(fn)
if docstring:
formatted_docstring = f'"""\n{docstring}\n"""'
else:
formatted_docstring = ""
return f"{lines}\n{formatted_docstring}"
@functools.lru_cache()
def format_function(func: Callable) -> str:
"""
Format a function as a string with docstring and body.
"""
source_lines = inspect.getsourcelines(func)
definition = " ".join(source_lines[0]).strip()
docstring = inspect.getdoc(func)
if docstring:
formatted_docstring = f'"""\n{docstring}\n"""'
else:
formatted_docstring = ""
body = inspect.getsource(func)
body = body.replace(f"def {func.__name__}", "")
return f"{definition}\n{formatted_docstring}\n{body}"
def is_return_type_base_model_or_instance(func: Callable[..., Any]) -> bool:
"""
Check if the return type of a function is a pydantic BaseModel or an instance of it.
:param func: Function to check.
:return: True if the return type is a pydantic BaseModel or an instance of it.
"""
return_type = inspect.signature(func).return_annotation
assert (
return_type != inspect.Signature.empty
), "Must have a return type hint that is a pydantic BaseModel"
return inspect.isclass(return_type) and issubclass(return_type, BaseModel)
class Instructions:
def __init__(
self,
name: str = None,
id: str = None,
log_handlers: List[logging.Handler] = None,
finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
indent: int = 2,
):
self.name = name
self.id = id or str(uuid.uuid4())
self.unique_id = str(uuid.uuid4())
self.finetune_format = finetune_format
self.indent = indent
self.logger = logging.getLogger(self.name)
for handler in log_handlers or []:
self.logger.addHandler(handler)
def distil(
self,
*args,
name: str = None,
mode: str = "distil",
fine_tune_format: FinetuneFormat = None,
):
"""
Decorator to track the function call and response, supports distillation and dispatch modes.
If used without arguments, it must be used as a decorator.
:Example:
>>> @distil
>>> def my_function() -> MyModel:
>>> return MyModel()
>>>
>>> @distil(name="my_function")
>>> def my_function() -> MyModel:
>>> return MyModel()
:param fn: Function to track.
:param name: Name of the function to track. Defaults to the function name.
:param mode: Mode to use for distillation. Defaults to "distil".
"""
allowed_modes = {"distil", "dispatch"}
assert mode in allowed_modes, f"Must be in {allowed_modes}"
assert mode == "distil", "Only distil mode is supported at the moment."
if fine_tune_format is None:
fine_tune_format = self.finetune_format
def _wrap_distil(fn):
msg = f"Return type hint for {fn} must subclass `pydantic.BaseModel'"
assert is_return_type_base_model_or_instance(fn), msg
@functools.wraps(fn)
def _distil(*args, **kwargs):
resp = fn(*args, **kwargs)
self.track(
fn, args, kwargs, resp, name=name, finetune_format=fine_tune_format
)
return resp
return _distil
if len(args) == 1 and callable(args[0]):
return _wrap_distil(args[0])
return _wrap_distil
@validate_call
def track(
self,
fn: Callable[..., Any],
args: tuple,
kwargs: dict,
resp: BaseModel,
name: Optional[str] = None,
finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
):
"""
Track the function call and response in a log file, later used for finetuning.
:param fn: Function to track.
:param args: Arguments passed to the function.
:param kwargs: Keyword arguments passed to the function.
:param resp: Response returned by the function.
:param name: Name of the function to track. Defaults to the function name.
:param finetune_format: Format to use for finetuning. Defaults to "raw".
"""
name = name if name else fn.__name__
base_model: BaseModel = type(resp)
if finetune_format == FinetuneFormat.MESSAGES:
openai_function_call = openai_schema(base_model).openai_schema
func_def = get_signature_from_fn(fn).replace(fn.__name__, name)
str_args = ", ".join(map(str, args))
str_kwargs = (
", ".join(f"{k}={json.dumps(v)}" for k, v in kwargs.items()) or None
)
call_args = ", ".join(filter(None, [str_args, str_kwargs]))
function_body = {
"messages": [
{
"role": "system",
"content": f"Predict the results of this function:\n\n{func_def}",
},
{
"role": "user",
"content": f"Return `{name}({call_args})`",
},
{
"role": "assistant",
"function_call": {
"name": openai_function_call["name"],
"arguments": resp.model_dump_json(indent=self.indent),
},
},
],
"functions": [openai_function_call],
}
self.logger.info(json.dumps(function_body))
if finetune_format == FinetuneFormat.RAW:
function_body = dict(
fn_name=name,
fn_repr=format_function(fn),
args=args,
kwargs=kwargs,
resp=resp.model_dump(),
schema=base_model.model_json_schema(),
)
self.logger.info(json.dumps(function_body))
+1
View File
@@ -56,6 +56,7 @@ nav:
- Introduction:
- Getting Started: 'index.md'
- Prompt Engineering Tips: 'tips/index.md'
- Distillation: 'distillation.md'
- Helpers:
- Reasking and Validation Overview: "reask_validation.md"
- Multiple Extractions: "multitask.md"
+1 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "instructor"
version = "0.2.8"
version = "0.2.9"
description = "Helper functions that allow us to improve openai's function_call ergonomics"
authors = ["Jason <jason@jxnl.co>"]
license = "MIT"
+99
View File
@@ -0,0 +1,99 @@
import pytest
import openai
from pydantic import BaseModel
from instructor.distil import (
Instructions,
format_function,
get_signature_from_fn,
is_return_type_base_model_or_instance,
)
# Replace `your_module_name` with your actual module name
instructions = Instructions(
name="test_distil",
)
class SimpleModel(BaseModel):
data: int
def test_must_have_hint():
with pytest.raises(AssertionError):
@instructions.distil
def test_func(x: int):
return SimpleModel(data=x)
def test_must_be_base_model():
with pytest.raises(AssertionError):
@instructions.distil
def test_func(x) -> int:
return SimpleModel(data=x)
def test_is_return_type_base_model_or_instance():
def valid_function() -> SimpleModel:
return SimpleModel(data=1)
def invalid_function() -> int:
return 1
assert is_return_type_base_model_or_instance(valid_function)
assert not is_return_type_base_model_or_instance(invalid_function)
def test_get_signature_from_fn():
def test_function(a: int, b: str) -> float:
"""Sample docstring"""
pass
result = get_signature_from_fn(test_function)
expected = "def test_function(a: int, b: str) -> float"
assert expected in result
assert "Sample docstring" in result
def test_format_function():
def sample_function(x: int) -> SimpleModel:
"""This is a docstring."""
return SimpleModel(data=x)
formatted = format_function(sample_function)
assert "def sample_function(x: int) -> SimpleModel:" in formatted
assert '"""This is a docstring."""' in formatted
assert "return SimpleModel(data=x)" in formatted
def test_distil_decorator_without_arguments():
@instructions.distil
def test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)
result = test_func(42)
assert result.data == 42
def test_distil_decorator_with_name_argument():
@instructions.distil(name="custom_name")
def another_test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)
result = another_test_func(55)
assert result.data == 55
# Mock track function for decorator tests
def mock_track(*args, **kwargs):
pass
def fn(a: int, b: int) -> int:
return openai.ChatCompletion.create(
messages=[],
model="davinci",
response_model=SimpleModel,
)