chore: include types to instructor.dsl (#419)

This commit is contained in:
Ezzeri Esa
2024-02-08 17:13:37 -08:00
committed by GitHub
parent 805161b70f
commit 25f8214286
8 changed files with 127 additions and 69 deletions
+7
View File
@@ -17,6 +17,13 @@ env:
instructor/cli/usage.py
instructor/exceptions.py
instructor/distil.py
instructor/dsl/citation.py
instructor/dsl/iterable.py
instructor/dsl/maybe.py
instructor/dsl/parallel.py
instructor/dsl/partial.py
instructor/dsl/partialjson.py
instructor/dsl/validators.py
instructor/function_calls.py
tests/test_function_calls.py
tests/test_distil.py
+8 -6
View File
@@ -1,8 +1,8 @@
from pydantic import BaseModel, Field, FieldValidationInfo, model_validator
from typing import List
from typing import Generator, List, Tuple
class CitationMixin(BaseModel):
class CitationMixin(BaseModel): # type: ignore[misc]
"""
Helpful mixing that can use `validation_context={"context": context}` in `from_response` to find the span of the substring_phrase in the context.
@@ -57,7 +57,7 @@ class CitationMixin(BaseModel):
description="List of unique and specific substrings of the quote that was used to answer the question.",
)
@model_validator(mode="after")
@model_validator(mode="after") # type: ignore[misc]
def validate_sources(self, info: FieldValidationInfo) -> "CitationMixin":
"""
For each substring_phrase, find the span of the substring_phrase in the context.
@@ -75,8 +75,10 @@ class CitationMixin(BaseModel):
self.substring_quotes = [text_chunks[span[0] : span[1]] for span in spans]
return self
def _get_span(self, quote, context, errs=5):
import regex
def _get_span(
self, quote: str, context: str, errs: int = 5
) -> Generator[Tuple[int, int], None, None]:
import regex # type: ignore[import-untyped]
minor = quote
major = context
@@ -90,6 +92,6 @@ class CitationMixin(BaseModel):
if s is not None:
yield from s.spans()
def get_spans(self, context):
def get_spans(self, context: str) -> Generator[Tuple[int, int], None, None]:
for quote in self.substring_quotes:
yield from self._get_span(quote, context)
+31 -17
View File
@@ -1,4 +1,4 @@
from typing import List, Optional, Type, Any
from typing import Any, AsyncGenerator, Generator, Iterable, List, Optional, Tuple, Type
from pydantic import BaseModel, Field, create_model
@@ -6,20 +6,26 @@ from instructor.function_calls import OpenAISchema, Mode
class IterableBase:
task_type = None # type: ignore
task_type = None # type: ignore[var-annotated]
@classmethod
def from_streaming_response(cls, completion, mode: Mode, **kwargs: Any): # noqa: ARG003
def from_streaming_response(
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
) -> Generator[BaseModel, None, None]: # noqa: ARG003
json_chunks = cls.extract_json(completion, mode)
yield from cls.tasks_from_chunks(json_chunks)
yield from cls.tasks_from_chunks(json_chunks, **kwargs)
@classmethod
async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs):
async def from_streaming_response_async(
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
) -> AsyncGenerator[BaseModel, None]:
json_chunks = cls.extract_json_async(completion, mode)
return cls.tasks_from_chunks_async(json_chunks, **kwargs)
@classmethod
def tasks_from_chunks(cls, json_chunks, **kwargs):
def tasks_from_chunks(
cls, json_chunks: Iterable[str], **kwargs: Any
) -> Generator[BaseModel, None, None]:
started = False
potential_object = ""
for chunk in json_chunks:
@@ -32,11 +38,14 @@ class IterableBase:
task_json, potential_object = cls.get_object(potential_object, 0)
if task_json:
obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore
assert cls.task_type is not None
obj = cls.task_type.model_validate_json(task_json, **kwargs)
yield obj
@classmethod
async def tasks_from_chunks_async(cls, json_chunks, **kwargs):
async def tasks_from_chunks_async(
cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
) -> AsyncGenerator[BaseModel, None]:
started = False
potential_object = ""
async for chunk in json_chunks:
@@ -49,11 +58,14 @@ class IterableBase:
task_json, potential_object = cls.get_object(potential_object, 0)
if task_json:
obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore
assert cls.task_type is not None
obj = cls.task_type.model_validate_json(task_json, **kwargs)
yield obj
@staticmethod
def extract_json(completion, mode: Mode):
def extract_json(
completion: Iterable[Any], mode: Mode
) -> Generator[str, None, None]:
for chunk in completion:
try:
if chunk.choices:
@@ -74,7 +86,9 @@ class IterableBase:
pass
@staticmethod
async def extract_json_async(completion, mode: Mode):
async def extract_json_async(
completion: AsyncGenerator[Any, None], mode: Mode
) -> AsyncGenerator[str, None]:
async for chunk in completion:
try:
if chunk.choices:
@@ -95,15 +109,15 @@ class IterableBase:
pass
@staticmethod
def get_object(str, stack):
for i, c in enumerate(str):
def get_object(s: str, stack: int) -> Tuple[Optional[str], str]:
for i, c in enumerate(s):
if c == "{":
stack += 1
if c == "}":
stack -= 1
if stack == 0:
return str[: i + 1], str[i + 2 :]
return None, str
return s[: i + 1], s[i + 2 :]
return None, s
def IterableModel(
@@ -166,7 +180,7 @@ def IterableModel(
name = f"Iterable{task_name}"
list_tasks = (
List[subtask_class],
List[subtask_class], # type: ignore[valid-type]
Field(
default_factory=list,
repr=False,
@@ -177,7 +191,7 @@ def IterableModel(
new_cls = create_model(
name,
tasks=list_tasks,
__base__=(OpenAISchema, IterableBase), # type: ignore
__base__=(OpenAISchema, IterableBase),
)
# set the class constructor BaseModel
new_cls.task_type = subtask_class
+4 -4
View File
@@ -1,10 +1,10 @@
from pydantic import BaseModel, Field, create_model
from typing import Type, Optional, TypeVar, Generic
from typing import Generic, Optional, Type, TypeVar
T = TypeVar("T", bound=BaseModel)
class MaybeBase(BaseModel, Generic[T]):
class MaybeBase(BaseModel, Generic[T]): # type: ignore[misc]
"""
Extract a result from a model, if any, otherwise set the error and message fields.
"""
@@ -13,8 +13,8 @@ class MaybeBase(BaseModel, Generic[T]):
error: bool = Field(default=False)
message: Optional[str]
def __bool__(self):
return self.result is not None # type: ignore
def __bool__(self) -> bool:
return self.result is not None
def Maybe(model: Type[T]) -> Type[MaybeBase[T]]:
+22 -10
View File
@@ -1,10 +1,22 @@
from typing import Type, TypeVar, Union, get_origin, get_args
from types import UnionType
from typing import (
Any,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
get_args,
get_origin,
)
from types import UnionType # type: ignore[attr-defined]
from instructor.function_calls import OpenAISchema, Mode, openai_schema
from collections.abc import Iterable
T = TypeVar("T")
T = TypeVar("T", bound=OpenAISchema)
class ParallelBase:
@@ -16,11 +28,11 @@ class ParallelBase:
def from_response(
self,
response,
response: Any,
mode: Mode,
validation_context=None,
strict: bool = None,
) -> Iterable[Union[T]]:
validation_context: Optional[Any] = None,
strict: Optional[bool] = None,
) -> Generator[T, None, None]:
#! We expect this from the OpenAISchema class, We should address
#! this with a protocol or an abstract class... @jxnlco
assert mode == Mode.PARALLEL_TOOLS, "Mode must be PARALLEL_TOOLS"
@@ -32,7 +44,7 @@ class ParallelBase:
)
def get_types_array(typehint: Type[Iterable[Union[T]]]):
def get_types_array(typehint: Type[Iterable[Union[T]]]) -> Tuple[Type[T], ...]:
should_be_iterable = get_origin(typehint)
assert should_be_iterable is Iterable
@@ -50,7 +62,7 @@ def get_types_array(typehint: Type[Iterable[Union[T]]]):
return get_args(typehint)
def handle_parallel_model(typehint: Type[Iterable[Union[T]]]):
def handle_parallel_model(typehint: Type[Iterable[Union[T]]]) -> List[Dict[str, Any]]:
the_types = get_types_array(typehint)
return [
{"type": "function", "function": openai_schema(model).openai_schema}
@@ -58,6 +70,6 @@ def handle_parallel_model(typehint: Type[Iterable[Union[T]]]):
]
def ParallelModel(typehint):
def ParallelModel(typehint: Type[Iterable[Union[T]]]) -> ParallelBase:
the_types = get_types_array(typehint)
return ParallelBase(*[model for model in the_types])
+33 -10
View File
@@ -8,7 +8,18 @@
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
from typing import TypeVar, NoReturn, get_args, get_origin, Optional, Generic
from typing import (
Any,
AsyncGenerator,
Generator,
Generic,
get_args,
get_origin,
Iterable,
NoReturn,
Optional,
TypeVar,
)
from copy import deepcopy
from instructor.function_calls import Mode
@@ -21,17 +32,23 @@ Model = TypeVar("Model", bound=BaseModel)
class PartialBase:
@classmethod
def from_streaming_response(cls, completion, mode: Mode, **kwargs):
def from_streaming_response(
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
) -> Generator[Model, None, None]:
json_chunks = cls.extract_json(completion, mode)
yield from cls.model_from_chunks(json_chunks, **kwargs)
@classmethod
async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs):
async def from_streaming_response_async(
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
) -> AsyncGenerator[Model, None]:
json_chunks = cls.extract_json_async(completion, mode)
return cls.model_from_chunks_async(json_chunks, **kwargs)
@classmethod
def model_from_chunks(cls, json_chunks, **kwargs):
def model_from_chunks(
cls, json_chunks: Iterable[Any], **kwargs: Any
) -> Generator[Model, None, None]:
prev_obj = None
potential_object = ""
for chunk in json_chunks:
@@ -42,7 +59,7 @@ class PartialBase:
parser.parse(potential_object) if potential_object.strip() else None
)
if task_json:
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
if obj != prev_obj:
obj.__dict__[
"chunk"
@@ -51,7 +68,9 @@ class PartialBase:
yield obj
@classmethod
async def model_from_chunks_async(cls, json_chunks, **kwargs):
async def model_from_chunks_async(
cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
) -> AsyncGenerator[Model, None]:
potential_object = ""
prev_obj = None
async for chunk in json_chunks:
@@ -62,7 +81,7 @@ class PartialBase:
parser.parse(potential_object) if potential_object.strip() else None
)
if task_json:
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
if obj != prev_obj:
obj.__dict__[
"chunk"
@@ -71,7 +90,9 @@ class PartialBase:
yield obj
@staticmethod
def extract_json(completion, mode: Mode):
def extract_json(
completion: Iterable[Any], mode: Mode
) -> Generator[str, None, None]:
for chunk in completion:
try:
if chunk.choices:
@@ -92,7 +113,9 @@ class PartialBase:
pass
@staticmethod
async def extract_json_async(completion, mode: Mode):
async def extract_json_async(
completion: AsyncGenerator[Any, None], mode: Mode
) -> AsyncGenerator[str, None]:
async for chunk in completion:
try:
if chunk.choices:
@@ -169,7 +192,7 @@ class Partial(Generic[Model]):
# Recursively apply Partial to each of the generic arguments
modified_args = tuple(
Partial[arg]
Partial[arg] # type: ignore[valid-type]
if isinstance(arg, type) and issubclass(arg, BaseModel)
else arg
for arg in generic_args
+17 -17
View File
@@ -8,12 +8,12 @@ The above copyright notice and this permission notice shall be included in all c
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
from typing import Any, Dict, Optional, Tuple
import json
class JSONParser:
def __init__(self):
def __init__(self) -> None:
self.parsers = {
" ": self.parse_space,
"\r": self.parse_space,
@@ -30,26 +30,26 @@ class JSONParser:
for c in "0123456789.-":
self.parsers[c] = self.parse_number
self.last_parse_reminding = None
self.last_parse_reminding: Optional[str] = None
self.on_extra_token = self.default_on_extra_token
def default_on_extra_token(self, text, data, reminding):
def default_on_extra_token(self, text: str, data: Any, reminding: str) -> None:
pass
def parse(self, s):
def parse(self, s: str) -> Dict[str, Any]:
if len(s) >= 1:
try:
return json.loads(s)
except json.JSONDecodeError as e:
data, reminding = self.parse_any(s, e)
self.last_parse_reminding = reminding
if self.on_extra_token and reminding:
self.on_extra_token(s, data, reminding)
if self.on_extra_token is not None and reminding:
self.on_extra_token(s, data, reminding) # type: ignore[no-untyped-call]
return json.loads(json.dumps(data))
else:
return json.loads("{}")
def parse_any(self, s, e):
def parse_any(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
if not s:
raise e
parser = self.parsers.get(s[0])
@@ -57,10 +57,10 @@ class JSONParser:
raise e
return parser(s, e)
def parse_space(self, s, e):
def parse_space(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
return self.parse_any(s.strip(), e)
def parse_array(self, s, e):
def parse_array(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
s = s[1:] # skip starting '['
acc = []
s = s.strip()
@@ -76,9 +76,9 @@ class JSONParser:
s = s.strip()
return acc, s
def parse_object(self, s, e):
def parse_object(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
s = s[1:] # skip starting '{'
acc = {}
acc: Dict[str, Any] = {}
s = s.strip()
while s:
if s[0] == "}":
@@ -114,7 +114,7 @@ class JSONParser:
s = s.strip()
return acc, s
def parse_string(self, s, e): # noqa: ARG002
def parse_string(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]: # noqa: ARG002
end = s.find('"', 1)
while end != -1 and s[end - 1] == "\\": # Handle escaped quotes
end = s.find('"', end + 1)
@@ -125,7 +125,7 @@ class JSONParser:
s = s[end + 1 :]
return json.loads(str_val), s
def parse_number(self, s, e):
def parse_number(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
i = 0
while i < len(s) and s[i] in "0123456789.-":
i += 1
@@ -143,17 +143,17 @@ class JSONParser:
raise e
return num, s
def parse_true(self, s, e):
def parse_true(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
if s.startswith("true"):
return True, s[4:]
raise e
def parse_false(self, s, e):
def parse_false(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
if s.startswith("false"):
return False, s[5:]
raise e
def parse_null(self, s, e):
def parse_null(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]:
if s.startswith("null"):
return None, s[4:]
raise e
+5 -5
View File
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Callable, Optional
from openai import OpenAI
from pydantic import Field
@@ -33,7 +33,7 @@ def llm_validator(
model: str = "gpt-3.5-turbo",
temperature: float = 0,
openai_client: OpenAI = None,
):
) -> Callable[[str], str]:
"""
Create a validator that uses the LLM to validate an attribute
@@ -70,7 +70,7 @@ def llm_validator(
openai_client = openai_client if openai_client else patch(OpenAI())
def llm(v):
def llm(v: str) -> str:
resp = openai_client.chat.completions.create(
response_model=Validator,
messages=[
@@ -85,7 +85,7 @@ def llm_validator(
],
model=model,
temperature=temperature,
) # type: ignore
)
# If the response is not valid, return the reason, this could be used in
# the future to generate a better response, via reasking mechanism.
@@ -99,7 +99,7 @@ def llm_validator(
return llm
def openai_moderation(client: OpenAI = None):
def openai_moderation(client: Optional[OpenAI] = None) -> Callable[[str], str]:
"""
Validates a message using OpenAI moderation model.