mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
chore: include types to instructor.dsl (#419)
This commit is contained in:
@@ -17,6 +17,13 @@ env:
|
||||
instructor/cli/usage.py
|
||||
instructor/exceptions.py
|
||||
instructor/distil.py
|
||||
instructor/dsl/citation.py
|
||||
instructor/dsl/iterable.py
|
||||
instructor/dsl/maybe.py
|
||||
instructor/dsl/parallel.py
|
||||
instructor/dsl/partial.py
|
||||
instructor/dsl/partialjson.py
|
||||
instructor/dsl/validators.py
|
||||
instructor/function_calls.py
|
||||
tests/test_function_calls.py
|
||||
tests/test_distil.py
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from pydantic import BaseModel, Field, FieldValidationInfo, model_validator
|
||||
from typing import List
|
||||
from typing import Generator, List, Tuple
|
||||
|
||||
|
||||
class CitationMixin(BaseModel):
|
||||
class CitationMixin(BaseModel): # type: ignore[misc]
|
||||
"""
|
||||
Helpful mixing that can use `validation_context={"context": context}` in `from_response` to find the span of the substring_phrase in the context.
|
||||
|
||||
@@ -57,7 +57,7 @@ class CitationMixin(BaseModel):
|
||||
description="List of unique and specific substrings of the quote that was used to answer the question.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
@model_validator(mode="after") # type: ignore[misc]
|
||||
def validate_sources(self, info: FieldValidationInfo) -> "CitationMixin":
|
||||
"""
|
||||
For each substring_phrase, find the span of the substring_phrase in the context.
|
||||
@@ -75,8 +75,10 @@ class CitationMixin(BaseModel):
|
||||
self.substring_quotes = [text_chunks[span[0] : span[1]] for span in spans]
|
||||
return self
|
||||
|
||||
def _get_span(self, quote, context, errs=5):
|
||||
import regex
|
||||
def _get_span(
|
||||
self, quote: str, context: str, errs: int = 5
|
||||
) -> Generator[Tuple[int, int], None, None]:
|
||||
import regex # type: ignore[import-untyped]
|
||||
|
||||
minor = quote
|
||||
major = context
|
||||
@@ -90,6 +92,6 @@ class CitationMixin(BaseModel):
|
||||
if s is not None:
|
||||
yield from s.spans()
|
||||
|
||||
def get_spans(self, context):
|
||||
def get_spans(self, context: str) -> Generator[Tuple[int, int], None, None]:
|
||||
for quote in self.substring_quotes:
|
||||
yield from self._get_span(quote, context)
|
||||
|
||||
+31
-17
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Type, Any
|
||||
from typing import Any, AsyncGenerator, Generator, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
@@ -6,20 +6,26 @@ from instructor.function_calls import OpenAISchema, Mode
|
||||
|
||||
|
||||
class IterableBase:
|
||||
task_type = None # type: ignore
|
||||
task_type = None # type: ignore[var-annotated]
|
||||
|
||||
@classmethod
|
||||
def from_streaming_response(cls, completion, mode: Mode, **kwargs: Any): # noqa: ARG003
|
||||
def from_streaming_response(
|
||||
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
|
||||
) -> Generator[BaseModel, None, None]: # noqa: ARG003
|
||||
json_chunks = cls.extract_json(completion, mode)
|
||||
yield from cls.tasks_from_chunks(json_chunks)
|
||||
yield from cls.tasks_from_chunks(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs):
|
||||
async def from_streaming_response_async(
|
||||
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
|
||||
) -> AsyncGenerator[BaseModel, None]:
|
||||
json_chunks = cls.extract_json_async(completion, mode)
|
||||
return cls.tasks_from_chunks_async(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def tasks_from_chunks(cls, json_chunks, **kwargs):
|
||||
def tasks_from_chunks(
|
||||
cls, json_chunks: Iterable[str], **kwargs: Any
|
||||
) -> Generator[BaseModel, None, None]:
|
||||
started = False
|
||||
potential_object = ""
|
||||
for chunk in json_chunks:
|
||||
@@ -32,11 +38,14 @@ class IterableBase:
|
||||
|
||||
task_json, potential_object = cls.get_object(potential_object, 0)
|
||||
if task_json:
|
||||
obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore
|
||||
assert cls.task_type is not None
|
||||
obj = cls.task_type.model_validate_json(task_json, **kwargs)
|
||||
yield obj
|
||||
|
||||
@classmethod
|
||||
async def tasks_from_chunks_async(cls, json_chunks, **kwargs):
|
||||
async def tasks_from_chunks_async(
|
||||
cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
|
||||
) -> AsyncGenerator[BaseModel, None]:
|
||||
started = False
|
||||
potential_object = ""
|
||||
async for chunk in json_chunks:
|
||||
@@ -49,11 +58,14 @@ class IterableBase:
|
||||
|
||||
task_json, potential_object = cls.get_object(potential_object, 0)
|
||||
if task_json:
|
||||
obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore
|
||||
assert cls.task_type is not None
|
||||
obj = cls.task_type.model_validate_json(task_json, **kwargs)
|
||||
yield obj
|
||||
|
||||
@staticmethod
|
||||
def extract_json(completion, mode: Mode):
|
||||
def extract_json(
|
||||
completion: Iterable[Any], mode: Mode
|
||||
) -> Generator[str, None, None]:
|
||||
for chunk in completion:
|
||||
try:
|
||||
if chunk.choices:
|
||||
@@ -74,7 +86,9 @@ class IterableBase:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
async def extract_json_async(completion, mode: Mode):
|
||||
async def extract_json_async(
|
||||
completion: AsyncGenerator[Any, None], mode: Mode
|
||||
) -> AsyncGenerator[str, None]:
|
||||
async for chunk in completion:
|
||||
try:
|
||||
if chunk.choices:
|
||||
@@ -95,15 +109,15 @@ class IterableBase:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_object(str, stack):
|
||||
for i, c in enumerate(str):
|
||||
def get_object(s: str, stack: int) -> Tuple[Optional[str], str]:
|
||||
for i, c in enumerate(s):
|
||||
if c == "{":
|
||||
stack += 1
|
||||
if c == "}":
|
||||
stack -= 1
|
||||
if stack == 0:
|
||||
return str[: i + 1], str[i + 2 :]
|
||||
return None, str
|
||||
return s[: i + 1], s[i + 2 :]
|
||||
return None, s
|
||||
|
||||
|
||||
def IterableModel(
|
||||
@@ -166,7 +180,7 @@ def IterableModel(
|
||||
name = f"Iterable{task_name}"
|
||||
|
||||
list_tasks = (
|
||||
List[subtask_class],
|
||||
List[subtask_class], # type: ignore[valid-type]
|
||||
Field(
|
||||
default_factory=list,
|
||||
repr=False,
|
||||
@@ -177,7 +191,7 @@ def IterableModel(
|
||||
new_cls = create_model(
|
||||
name,
|
||||
tasks=list_tasks,
|
||||
__base__=(OpenAISchema, IterableBase), # type: ignore
|
||||
__base__=(OpenAISchema, IterableBase),
|
||||
)
|
||||
# set the class constructor BaseModel
|
||||
new_cls.task_type = subtask_class
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from typing import Type, Optional, TypeVar, Generic
|
||||
from typing import Generic, Optional, Type, TypeVar
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class MaybeBase(BaseModel, Generic[T]):
|
||||
class MaybeBase(BaseModel, Generic[T]): # type: ignore[misc]
|
||||
"""
|
||||
Extract a result from a model, if any, otherwise set the error and message fields.
|
||||
"""
|
||||
@@ -13,8 +13,8 @@ class MaybeBase(BaseModel, Generic[T]):
|
||||
error: bool = Field(default=False)
|
||||
message: Optional[str]
|
||||
|
||||
def __bool__(self):
|
||||
return self.result is not None # type: ignore
|
||||
def __bool__(self) -> bool:
|
||||
return self.result is not None
|
||||
|
||||
|
||||
def Maybe(model: Type[T]) -> Type[MaybeBase[T]]:
|
||||
|
||||
+22
-10
@@ -1,10 +1,22 @@
|
||||
from typing import Type, TypeVar, Union, get_origin, get_args
|
||||
from types import UnionType
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
from types import UnionType # type: ignore[attr-defined]
|
||||
|
||||
from instructor.function_calls import OpenAISchema, Mode, openai_schema
|
||||
from collections.abc import Iterable
|
||||
|
||||
T = TypeVar("T")
|
||||
T = TypeVar("T", bound=OpenAISchema)
|
||||
|
||||
|
||||
class ParallelBase:
|
||||
@@ -16,11 +28,11 @@ class ParallelBase:
|
||||
|
||||
def from_response(
|
||||
self,
|
||||
response,
|
||||
response: Any,
|
||||
mode: Mode,
|
||||
validation_context=None,
|
||||
strict: bool = None,
|
||||
) -> Iterable[Union[T]]:
|
||||
validation_context: Optional[Any] = None,
|
||||
strict: Optional[bool] = None,
|
||||
) -> Generator[T, None, None]:
|
||||
#! We expect this from the OpenAISchema class, We should address
|
||||
#! this with a protocol or an abstract class... @jxnlco
|
||||
assert mode == Mode.PARALLEL_TOOLS, "Mode must be PARALLEL_TOOLS"
|
||||
@@ -32,7 +44,7 @@ class ParallelBase:
|
||||
)
|
||||
|
||||
|
||||
def get_types_array(typehint: Type[Iterable[Union[T]]]):
|
||||
def get_types_array(typehint: Type[Iterable[Union[T]]]) -> Tuple[Type[T], ...]:
|
||||
should_be_iterable = get_origin(typehint)
|
||||
assert should_be_iterable is Iterable
|
||||
|
||||
@@ -50,7 +62,7 @@ def get_types_array(typehint: Type[Iterable[Union[T]]]):
|
||||
return get_args(typehint)
|
||||
|
||||
|
||||
def handle_parallel_model(typehint: Type[Iterable[Union[T]]]):
|
||||
def handle_parallel_model(typehint: Type[Iterable[Union[T]]]) -> List[Dict[str, Any]]:
|
||||
the_types = get_types_array(typehint)
|
||||
return [
|
||||
{"type": "function", "function": openai_schema(model).openai_schema}
|
||||
@@ -58,6 +70,6 @@ def handle_parallel_model(typehint: Type[Iterable[Union[T]]]):
|
||||
]
|
||||
|
||||
|
||||
def ParallelModel(typehint):
|
||||
def ParallelModel(typehint: Type[Iterable[Union[T]]]) -> ParallelBase:
|
||||
the_types = get_types_array(typehint)
|
||||
return ParallelBase(*[model for model in the_types])
|
||||
|
||||
+33
-10
@@ -8,7 +8,18 @@
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing import TypeVar, NoReturn, get_args, get_origin, Optional, Generic
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Generator,
|
||||
Generic,
|
||||
get_args,
|
||||
get_origin,
|
||||
Iterable,
|
||||
NoReturn,
|
||||
Optional,
|
||||
TypeVar,
|
||||
)
|
||||
from copy import deepcopy
|
||||
|
||||
from instructor.function_calls import Mode
|
||||
@@ -21,17 +32,23 @@ Model = TypeVar("Model", bound=BaseModel)
|
||||
|
||||
class PartialBase:
|
||||
@classmethod
|
||||
def from_streaming_response(cls, completion, mode: Mode, **kwargs):
|
||||
def from_streaming_response(
|
||||
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
|
||||
) -> Generator[Model, None, None]:
|
||||
json_chunks = cls.extract_json(completion, mode)
|
||||
yield from cls.model_from_chunks(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs):
|
||||
async def from_streaming_response_async(
|
||||
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
|
||||
) -> AsyncGenerator[Model, None]:
|
||||
json_chunks = cls.extract_json_async(completion, mode)
|
||||
return cls.model_from_chunks_async(json_chunks, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def model_from_chunks(cls, json_chunks, **kwargs):
|
||||
def model_from_chunks(
|
||||
cls, json_chunks: Iterable[Any], **kwargs: Any
|
||||
) -> Generator[Model, None, None]:
|
||||
prev_obj = None
|
||||
potential_object = ""
|
||||
for chunk in json_chunks:
|
||||
@@ -42,7 +59,7 @@ class PartialBase:
|
||||
parser.parse(potential_object) if potential_object.strip() else None
|
||||
)
|
||||
if task_json:
|
||||
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore
|
||||
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
|
||||
if obj != prev_obj:
|
||||
obj.__dict__[
|
||||
"chunk"
|
||||
@@ -51,7 +68,9 @@ class PartialBase:
|
||||
yield obj
|
||||
|
||||
@classmethod
|
||||
async def model_from_chunks_async(cls, json_chunks, **kwargs):
|
||||
async def model_from_chunks_async(
|
||||
cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
|
||||
) -> AsyncGenerator[Model, None]:
|
||||
potential_object = ""
|
||||
prev_obj = None
|
||||
async for chunk in json_chunks:
|
||||
@@ -62,7 +81,7 @@ class PartialBase:
|
||||
parser.parse(potential_object) if potential_object.strip() else None
|
||||
)
|
||||
if task_json:
|
||||
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore
|
||||
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
|
||||
if obj != prev_obj:
|
||||
obj.__dict__[
|
||||
"chunk"
|
||||
@@ -71,7 +90,9 @@ class PartialBase:
|
||||
yield obj
|
||||
|
||||
@staticmethod
|
||||
def extract_json(completion, mode: Mode):
|
||||
def extract_json(
|
||||
completion: Iterable[Any], mode: Mode
|
||||
) -> Generator[str, None, None]:
|
||||
for chunk in completion:
|
||||
try:
|
||||
if chunk.choices:
|
||||
@@ -92,7 +113,9 @@ class PartialBase:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
async def extract_json_async(completion, mode: Mode):
|
||||
async def extract_json_async(
|
||||
completion: AsyncGenerator[Any, None], mode: Mode
|
||||
) -> AsyncGenerator[str, None]:
|
||||
async for chunk in completion:
|
||||
try:
|
||||
if chunk.choices:
|
||||
@@ -169,7 +192,7 @@ class Partial(Generic[Model]):
|
||||
|
||||
# Recursively apply Partial to each of the generic arguments
|
||||
modified_args = tuple(
|
||||
Partial[arg]
|
||||
Partial[arg] # type: ignore[valid-type]
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel)
|
||||
else arg
|
||||
for arg in generic_args
|
||||
|
||||
@@ -8,12 +8,12 @@ The above copyright notice and this permission notice shall be included in all c
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
import json
|
||||
|
||||
|
||||
class JSONParser:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.parsers = {
|
||||
" ": self.parse_space,
|
||||
"\r": self.parse_space,
|
||||
@@ -30,26 +30,26 @@ class JSONParser:
|
||||
for c in "0123456789.-":
|
||||
self.parsers[c] = self.parse_number
|
||||
|
||||
self.last_parse_reminding = None
|
||||
self.last_parse_reminding: Optional[str] = None
|
||||
self.on_extra_token = self.default_on_extra_token
|
||||
|
||||
def default_on_extra_token(self, text, data, reminding):
|
||||
def default_on_extra_token(self, text: str, data: Any, reminding: str) -> None:
|
||||
pass
|
||||
|
||||
def parse(self, s):
|
||||
def parse(self, s: str) -> Dict[str, Any]:
|
||||
if len(s) >= 1:
|
||||
try:
|
||||
return json.loads(s)
|
||||
except json.JSONDecodeError as e:
|
||||
data, reminding = self.parse_any(s, e)
|
||||
self.last_parse_reminding = reminding
|
||||
if self.on_extra_token and reminding:
|
||||
self.on_extra_token(s, data, reminding)
|
||||
if self.on_extra_token is not None and reminding:
|
||||
self.on_extra_token(s, data, reminding) # type: ignore[no-untyped-call]
|
||||
return json.loads(json.dumps(data))
|
||||
else:
|
||||
return json.loads("{}")
|
||||
|
||||
def parse_any(self, s, e):
|
||||
def parse_any(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
|
||||
if not s:
|
||||
raise e
|
||||
parser = self.parsers.get(s[0])
|
||||
@@ -57,10 +57,10 @@ class JSONParser:
|
||||
raise e
|
||||
return parser(s, e)
|
||||
|
||||
def parse_space(self, s, e):
|
||||
def parse_space(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
|
||||
return self.parse_any(s.strip(), e)
|
||||
|
||||
def parse_array(self, s, e):
|
||||
def parse_array(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
|
||||
s = s[1:] # skip starting '['
|
||||
acc = []
|
||||
s = s.strip()
|
||||
@@ -76,9 +76,9 @@ class JSONParser:
|
||||
s = s.strip()
|
||||
return acc, s
|
||||
|
||||
def parse_object(self, s, e):
|
||||
def parse_object(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
|
||||
s = s[1:] # skip starting '{'
|
||||
acc = {}
|
||||
acc: Dict[str, Any] = {}
|
||||
s = s.strip()
|
||||
while s:
|
||||
if s[0] == "}":
|
||||
@@ -114,7 +114,7 @@ class JSONParser:
|
||||
s = s.strip()
|
||||
return acc, s
|
||||
|
||||
def parse_string(self, s, e): # noqa: ARG002
|
||||
def parse_string(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]: # noqa: ARG002
|
||||
end = s.find('"', 1)
|
||||
while end != -1 and s[end - 1] == "\\": # Handle escaped quotes
|
||||
end = s.find('"', end + 1)
|
||||
@@ -125,7 +125,7 @@ class JSONParser:
|
||||
s = s[end + 1 :]
|
||||
return json.loads(str_val), s
|
||||
|
||||
def parse_number(self, s, e):
|
||||
def parse_number(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
|
||||
i = 0
|
||||
while i < len(s) and s[i] in "0123456789.-":
|
||||
i += 1
|
||||
@@ -143,17 +143,17 @@ class JSONParser:
|
||||
raise e
|
||||
return num, s
|
||||
|
||||
def parse_true(self, s, e):
|
||||
def parse_true(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
|
||||
if s.startswith("true"):
|
||||
return True, s[4:]
|
||||
raise e
|
||||
|
||||
def parse_false(self, s, e):
|
||||
def parse_false(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
|
||||
if s.startswith("false"):
|
||||
return False, s[5:]
|
||||
raise e
|
||||
|
||||
def parse_null(self, s, e):
|
||||
def parse_null(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
|
||||
if s.startswith("null"):
|
||||
return None, s[4:]
|
||||
raise e
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
from pydantic import Field
|
||||
@@ -33,7 +33,7 @@ def llm_validator(
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0,
|
||||
openai_client: OpenAI = None,
|
||||
):
|
||||
) -> Callable[[str], str]:
|
||||
"""
|
||||
Create a validator that uses the LLM to validate an attribute
|
||||
|
||||
@@ -70,7 +70,7 @@ def llm_validator(
|
||||
|
||||
openai_client = openai_client if openai_client else patch(OpenAI())
|
||||
|
||||
def llm(v):
|
||||
def llm(v: str) -> str:
|
||||
resp = openai_client.chat.completions.create(
|
||||
response_model=Validator,
|
||||
messages=[
|
||||
@@ -85,7 +85,7 @@ def llm_validator(
|
||||
],
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
) # type: ignore
|
||||
)
|
||||
|
||||
# If the response is not valid, return the reason, this could be used in
|
||||
# the future to generate a better response, via reasking mechanism.
|
||||
@@ -99,7 +99,7 @@ def llm_validator(
|
||||
return llm
|
||||
|
||||
|
||||
def openai_moderation(client: OpenAI = None):
|
||||
def openai_moderation(client: Optional[OpenAI] = None) -> Callable[[str], str]:
|
||||
"""
|
||||
Validates a message using OpenAI moderation model.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user