mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Fix async usage (#167)
Co-authored-by: Jason Liu <jxnl@users.noreply.github.com> Co-authored-by: Jason Liu <jason@jxnl.co>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from .function_calls import OpenAISchema, openai_function, openai_schema
|
||||
from .distil import FinetuneFormat, Instructions
|
||||
from .dsl import CitationMixin, Maybe, MultiTask, llm_validator
|
||||
from .function_calls import OpenAISchema, openai_function, openai_schema
|
||||
from .dsl import MultiTask, Maybe, llm_validator, CitationMixin
|
||||
from .patch import patch, apatch
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from pydantic import BaseModel, create_model, Field
|
||||
from typing import Optional, List, Type
|
||||
from instructor import OpenAISchema
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from instructor.function_calls import OpenAISchema
|
||||
|
||||
|
||||
class MultiTaskBase:
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from pydantic import Field
|
||||
from typing import Optional
|
||||
|
||||
from openai import OpenAI
|
||||
import instructor
|
||||
from pydantic import Field
|
||||
|
||||
from instructor.function_calls import OpenAISchema
|
||||
|
||||
|
||||
class Validator(instructor.OpenAISchema):
|
||||
class Validator(OpenAISchema):
|
||||
"""
|
||||
Validate if an attribute is correct and if not,
|
||||
return a new value with an error message
|
||||
|
||||
+45
-13
@@ -1,9 +1,13 @@
|
||||
import inspect
|
||||
|
||||
from functools import wraps
|
||||
from json import JSONDecodeError
|
||||
from pydantic import ValidationError, BaseModel
|
||||
from typing import Callable, Type, Optional
|
||||
from logging import warn
|
||||
from typing import Callable, Optional, Type, Union
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from .function_calls import OpenAISchema, openai_schema
|
||||
|
||||
OVERRIDE_DOCS = """
|
||||
@@ -66,6 +70,18 @@ def process_response(
|
||||
return response
|
||||
|
||||
|
||||
def dump_message(message: ChatCompletionMessage) -> dict:
|
||||
"""Dumps a message to a dict, to be returned to the OpenAI API.
|
||||
|
||||
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
|
||||
if it isn't used.
|
||||
"""
|
||||
dumped_message = message.model_dump()
|
||||
if not dumped_message.get("tool_calls"):
|
||||
del dumped_message["tool_calls"]
|
||||
return dumped_message
|
||||
|
||||
|
||||
async def retry_async(
|
||||
func,
|
||||
response_model,
|
||||
@@ -78,7 +94,7 @@ async def retry_async(
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
response = await func(*args, **kwargs)
|
||||
response: ChatCompletion = await func(*args, **kwargs)
|
||||
return (
|
||||
process_response(
|
||||
response,
|
||||
@@ -122,7 +138,7 @@ def retry_sync(
|
||||
None,
|
||||
)
|
||||
except (ValidationError, JSONDecodeError) as e:
|
||||
kwargs["messages"].append(response.choices[0].message) # type: ignore
|
||||
kwargs["messages"].append(dump_message(response.choices[0].message))
|
||||
kwargs["messages"].append(
|
||||
{
|
||||
"role": "user",
|
||||
@@ -134,7 +150,16 @@ def retry_sync(
|
||||
raise e
|
||||
|
||||
|
||||
def wrap_chatcompletion(func: Callable, is_async: bool = None) -> Callable:
|
||||
def is_async(func: Callable) -> bool:
|
||||
"""Returns true if the callable is async, accounting for wrapped callables"""
|
||||
return inspect.iscoroutinefunction(func) or (
|
||||
hasattr(func, "__wrapped__") and inspect.iscoroutinefunction(func.__wrapped__)
|
||||
)
|
||||
|
||||
|
||||
def wrap_chatcompletion(func: Callable) -> Callable:
|
||||
func_is_async = is_async(func)
|
||||
|
||||
@wraps(func)
|
||||
async def new_chatcompletion_async(
|
||||
response_model=None,
|
||||
@@ -177,12 +202,14 @@ def wrap_chatcompletion(func: Callable, is_async: bool = None) -> Callable:
|
||||
raise ValueError(error)
|
||||
return response
|
||||
|
||||
wrapper_function = new_chatcompletion_async if is_async else new_chatcompletion_sync
|
||||
wrapper_function = (
|
||||
new_chatcompletion_async if func_is_async else new_chatcompletion_sync
|
||||
)
|
||||
wrapper_function.__doc__ = OVERRIDE_DOCS
|
||||
return wrapper_function
|
||||
|
||||
|
||||
def patch(client):
|
||||
def patch(client: Union[OpenAI, AsyncOpenAI]):
|
||||
"""
|
||||
Patch the `client.chat.completions.create` method
|
||||
|
||||
@@ -198,9 +225,11 @@ def patch(client):
|
||||
return client
|
||||
|
||||
|
||||
def apatch(client):
|
||||
def apatch(client: AsyncOpenAI):
|
||||
"""
|
||||
Patch the `client.chat.completions.acreate` and `client.chat.completions.acreate` methods
|
||||
No longer necessary, use `patch` instead.
|
||||
|
||||
Patch the `client.chat.completions.create` method
|
||||
|
||||
Enables the following features:
|
||||
|
||||
@@ -209,7 +238,10 @@ def apatch(client):
|
||||
- `validation_context` parameter to validate the response using the pydantic model
|
||||
- `strict` parameter to use strict json parsing
|
||||
"""
|
||||
client.chat.completions.create = wrap_chatcompletion(
|
||||
client.chat.completions.create, is_async=True
|
||||
|
||||
# Emit a deprecation warning
|
||||
warn(
|
||||
"instructor.apatch is deprecated, use instructor.patch instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return client
|
||||
return patch(client)
|
||||
|
||||
Generated
+32
-34
@@ -264,13 +264,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "httpcore"
|
||||
version = "1.0.1"
|
||||
version = "1.0.2"
|
||||
description = "A minimal low-level HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "httpcore-1.0.1-py3-none-any.whl", hash = "sha256:c5e97ef177dca2023d0b9aad98e49507ef5423e9f1d94ffe2cfe250aa28e63b0"},
|
||||
{file = "httpcore-1.0.1.tar.gz", hash = "sha256:fce1ddf9b606cfb98132ab58865c3728c52c8e4c3c46e2aabb3674464a186e92"},
|
||||
{file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"},
|
||||
{file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -410,16 +410,6 @@ files = [
|
||||
{file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"},
|
||||
{file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"},
|
||||
{file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"},
|
||||
{file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"},
|
||||
{file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"},
|
||||
{file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"},
|
||||
@@ -592,13 +582,13 @@ mkdocstrings = ">=0.20"
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.1.1"
|
||||
description = "Client library for the openai API"
|
||||
version = "1.2.3"
|
||||
description = "The official Python library for the openai API"
|
||||
optional = false
|
||||
python-versions = ">=3.7.1"
|
||||
files = [
|
||||
{file = "openai-1.1.1-py3-none-any.whl", hash = "sha256:1496418b132c88352bcfffa8c24e83a69f0e01b1484cbb7bb48f722aad8fd6e1"},
|
||||
{file = "openai-1.1.1.tar.gz", hash = "sha256:80e49cb21d8445f6d51339b8af7376fc83302c78ab78578b78133ef89634869d"},
|
||||
{file = "openai-1.2.3-py3-none-any.whl", hash = "sha256:d8d1221d777c3b2d12468f17410bf929ca0cb06e9556586e61f5a4255f0cf2b4"},
|
||||
{file = "openai-1.2.3.tar.gz", hash = "sha256:800d206ec02c8310400f07b3bb52e158751f3a419e75d080117d913f358bf0d5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -646,13 +636,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "platformdirs"
|
||||
version = "3.11.0"
|
||||
version = "4.0.0"
|
||||
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "platformdirs-3.11.0-py3-none-any.whl", hash = "sha256:e9d171d00af68be50e9202731309c4e658fd8bc76f55c11c7dd760d023bda68e"},
|
||||
{file = "platformdirs-3.11.0.tar.gz", hash = "sha256:cf8ee52a3afdb965072dcc652433e0c7e3e40cf5ea1477cd4b3b1d2eb75495b3"},
|
||||
{file = "platformdirs-4.0.0-py3-none-any.whl", hash = "sha256:118c954d7e949b35437270383a3f2531e99dd93cf7ce4dc8340d3356d30f173b"},
|
||||
{file = "platformdirs-4.0.0.tar.gz", hash = "sha256:cb633b2bcf10c51af60beb0ab06d2f1d69064b43abf4c185ca6b28865f3f9731"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@@ -827,13 +817,13 @@ plugins = ["importlib-metadata"]
|
||||
|
||||
[[package]]
|
||||
name = "pymdown-extensions"
|
||||
version = "10.3.1"
|
||||
version = "10.4"
|
||||
description = "Extension pack for Python Markdown."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pymdown_extensions-10.3.1-py3-none-any.whl", hash = "sha256:8cba67beb2a1318cdaf742d09dff7c0fc4cafcc290147ade0f8fb7b71522711a"},
|
||||
{file = "pymdown_extensions-10.3.1.tar.gz", hash = "sha256:f6c79941498a458852853872e379e7bab63888361ba20992fc8b4f8a9b61735e"},
|
||||
{file = "pymdown_extensions-10.4-py3-none-any.whl", hash = "sha256:cfc28d6a09d19448bcbf8eee3ce098c7d17ff99f7bd3069db4819af181212037"},
|
||||
{file = "pymdown_extensions-10.4.tar.gz", hash = "sha256:bc46f11749ecd4d6b71cf62396104b4a200bad3498cb0f5dad1b8502fe461a35"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -865,6 +855,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "0.21.1"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"},
|
||||
{file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=7.0.0"
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
|
||||
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.8.2"
|
||||
@@ -891,7 +899,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
|
||||
@@ -899,15 +906,8 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
|
||||
@@ -924,7 +924,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
|
||||
@@ -932,7 +931,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
|
||||
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
|
||||
@@ -1245,4 +1243,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "5395b7333475a6f4a6544bbeaf991d8d781606cef00e291c396ef1ebdc85c270"
|
||||
content-hash = "48097711e7152fde9f43e23c8dcd2253cf1f872da82da2189cd25acee7ee3a0a"
|
||||
|
||||
@@ -24,6 +24,7 @@ mkdocs = "^1.4.3"
|
||||
mkdocs-material = "^9.1.18"
|
||||
mkdocstrings = "^0.22.0"
|
||||
mkdocstrings-python = "^1.1.2"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
+60
-2
@@ -1,14 +1,17 @@
|
||||
import functools
|
||||
import pytest
|
||||
import instructor
|
||||
|
||||
from pydantic import BaseModel, ValidationError, BeforeValidator
|
||||
from pydantic import BaseModel, Field, ValidationError, BeforeValidator
|
||||
from openai import OpenAI, AsyncOpenAI
|
||||
from instructor import llm_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
from instructor.patch import is_async, wrap_chatcompletion
|
||||
|
||||
client = instructor.patch(OpenAI())
|
||||
aclient = instructor.apatch(AsyncOpenAI())
|
||||
aclient = instructor.patch(AsyncOpenAI())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -78,6 +81,61 @@ def test_runmodel_validator():
|
||||
model, "_raw_response"
|
||||
), "The raw response should be available from OpenAI"
|
||||
|
||||
|
||||
def test_patch_completes_successfully():
|
||||
instructor.patch(OpenAI())
|
||||
|
||||
|
||||
def test_apatch_completes_successfully():
|
||||
instructor.apatch(AsyncOpenAI())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrap_chatcompletion_wraps_async_input_function():
|
||||
async def input_function(*args, **kwargs):
|
||||
return "Hello, World!"
|
||||
|
||||
wrapped_function = wrap_chatcompletion(input_function)
|
||||
result = await wrapped_function()
|
||||
|
||||
assert result == "Hello, World!"
|
||||
|
||||
|
||||
def test_wrap_chatcompletion_wraps_input_function():
|
||||
def input_function(*args, **kwargs):
|
||||
return "Hello, World!"
|
||||
|
||||
wrapped_function = wrap_chatcompletion(input_function)
|
||||
result = wrapped_function()
|
||||
|
||||
assert result == "Hello, World!"
|
||||
|
||||
|
||||
def test_is_async_returns_true_if_function_is_async():
|
||||
async def async_function():
|
||||
pass
|
||||
|
||||
assert is_async(async_function) is True
|
||||
|
||||
|
||||
def test_is_async_returns_false_if_function_is_not_async():
|
||||
def sync_function():
|
||||
pass
|
||||
|
||||
assert is_async(sync_function) is False
|
||||
|
||||
|
||||
def test_is_async_returns_true_if_wrapped_function_is_async():
|
||||
async def async_function():
|
||||
pass
|
||||
|
||||
@functools.wraps(async_function)
|
||||
def wrapped_function():
|
||||
pass
|
||||
|
||||
assert is_async(wrapped_function) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_runmodel_validator():
|
||||
aclient = instructor.apatch(AsyncOpenAI())
|
||||
|
||||
Reference in New Issue
Block a user