diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 26897e4..d76a711 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -17,6 +17,13 @@ env: instructor/cli/usage.py instructor/exceptions.py instructor/distil.py + instructor/dsl/citation.py + instructor/dsl/iterable.py + instructor/dsl/maybe.py + instructor/dsl/parallel.py + instructor/dsl/partial.py + instructor/dsl/partialjson.py + instructor/dsl/validators.py instructor/function_calls.py tests/test_function_calls.py tests/test_distil.py diff --git a/instructor/dsl/citation.py b/instructor/dsl/citation.py index b324880..df72b24 100644 --- a/instructor/dsl/citation.py +++ b/instructor/dsl/citation.py @@ -1,8 +1,8 @@ from pydantic import BaseModel, Field, FieldValidationInfo, model_validator -from typing import List +from typing import Generator, List, Tuple -class CitationMixin(BaseModel): +class CitationMixin(BaseModel): # type: ignore[misc] """ Helpful mixing that can use `validation_context={"context": context}` in `from_response` to find the span of the substring_phrase in the context. @@ -57,7 +57,7 @@ class CitationMixin(BaseModel): description="List of unique and specific substrings of the quote that was used to answer the question.", ) - @model_validator(mode="after") + @model_validator(mode="after") # type: ignore[misc] def validate_sources(self, info: FieldValidationInfo) -> "CitationMixin": """ For each substring_phrase, find the span of the substring_phrase in the context. @@ -75,8 +75,10 @@ class CitationMixin(BaseModel): self.substring_quotes = [text_chunks[span[0] : span[1]] for span in spans] return self - def _get_span(self, quote, context, errs=5): - import regex + def _get_span( + self, quote: str, context: str, errs: int = 5 + ) -> Generator[Tuple[int, int], None, None]: + import regex # type: ignore[import-untyped] minor = quote major = context @@ -90,6 +92,6 @@ class CitationMixin(BaseModel): if s is not None: yield from s.spans() - def get_spans(self, context): + def get_spans(self, context: str) -> Generator[Tuple[int, int], None, None]: for quote in self.substring_quotes: yield from self._get_span(quote, context) diff --git a/instructor/dsl/iterable.py b/instructor/dsl/iterable.py index 21bea2b..fd2c9e2 100644 --- a/instructor/dsl/iterable.py +++ b/instructor/dsl/iterable.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Type, Any +from typing import Any, AsyncGenerator, Generator, Iterable, List, Optional, Tuple, Type from pydantic import BaseModel, Field, create_model @@ -6,20 +6,26 @@ from instructor.function_calls import OpenAISchema, Mode class IterableBase: - task_type = None # type: ignore + task_type = None # type: ignore[var-annotated] @classmethod - def from_streaming_response(cls, completion, mode: Mode, **kwargs: Any): # noqa: ARG003 + def from_streaming_response( + cls, completion: Iterable[Any], mode: Mode, **kwargs: Any + ) -> Generator[BaseModel, None, None]: # noqa: ARG003 json_chunks = cls.extract_json(completion, mode) - yield from cls.tasks_from_chunks(json_chunks) + yield from cls.tasks_from_chunks(json_chunks, **kwargs) @classmethod - async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs): + async def from_streaming_response_async( + cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any + ) -> AsyncGenerator[BaseModel, None]: json_chunks = cls.extract_json_async(completion, mode) return cls.tasks_from_chunks_async(json_chunks, **kwargs) @classmethod - def tasks_from_chunks(cls, json_chunks, **kwargs): + def tasks_from_chunks( + cls, json_chunks: Iterable[str], **kwargs: Any + ) -> Generator[BaseModel, None, None]: started = False potential_object = "" for chunk in json_chunks: @@ -32,11 +38,14 @@ class IterableBase: task_json, potential_object = cls.get_object(potential_object, 0) if task_json: - obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore + assert cls.task_type is not None + obj = cls.task_type.model_validate_json(task_json, **kwargs) yield obj @classmethod - async def tasks_from_chunks_async(cls, json_chunks, **kwargs): + async def tasks_from_chunks_async( + cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any + ) -> AsyncGenerator[BaseModel, None]: started = False potential_object = "" async for chunk in json_chunks: @@ -49,11 +58,14 @@ class IterableBase: task_json, potential_object = cls.get_object(potential_object, 0) if task_json: - obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore + assert cls.task_type is not None + obj = cls.task_type.model_validate_json(task_json, **kwargs) yield obj @staticmethod - def extract_json(completion, mode: Mode): + def extract_json( + completion: Iterable[Any], mode: Mode + ) -> Generator[str, None, None]: for chunk in completion: try: if chunk.choices: @@ -74,7 +86,9 @@ class IterableBase: pass @staticmethod - async def extract_json_async(completion, mode: Mode): + async def extract_json_async( + completion: AsyncGenerator[Any, None], mode: Mode + ) -> AsyncGenerator[str, None]: async for chunk in completion: try: if chunk.choices: @@ -95,15 +109,15 @@ class IterableBase: pass @staticmethod - def get_object(str, stack): - for i, c in enumerate(str): + def get_object(s: str, stack: int) -> Tuple[Optional[str], str]: + for i, c in enumerate(s): if c == "{": stack += 1 if c == "}": stack -= 1 if stack == 0: - return str[: i + 1], str[i + 2 :] - return None, str + return s[: i + 1], s[i + 2 :] + return None, s def IterableModel( @@ -166,7 +180,7 @@ def IterableModel( name = f"Iterable{task_name}" list_tasks = ( - List[subtask_class], + List[subtask_class], # type: ignore[valid-type] Field( default_factory=list, repr=False, @@ -177,7 +191,7 @@ def IterableModel( new_cls = create_model( name, tasks=list_tasks, - __base__=(OpenAISchema, IterableBase), # type: ignore + __base__=(OpenAISchema, IterableBase), ) # set the class constructor BaseModel new_cls.task_type = subtask_class diff --git a/instructor/dsl/maybe.py b/instructor/dsl/maybe.py index 621cf4d..98e3ea9 100644 --- a/instructor/dsl/maybe.py +++ b/instructor/dsl/maybe.py @@ -1,10 +1,10 @@ from pydantic import BaseModel, Field, create_model -from typing import Type, Optional, TypeVar, Generic +from typing import Generic, Optional, Type, TypeVar T = TypeVar("T", bound=BaseModel) -class MaybeBase(BaseModel, Generic[T]): +class MaybeBase(BaseModel, Generic[T]): # type: ignore[misc] """ Extract a result from a model, if any, otherwise set the error and message fields. """ @@ -13,8 +13,8 @@ class MaybeBase(BaseModel, Generic[T]): error: bool = Field(default=False) message: Optional[str] - def __bool__(self): - return self.result is not None # type: ignore + def __bool__(self) -> bool: + return self.result is not None def Maybe(model: Type[T]) -> Type[MaybeBase[T]]: diff --git a/instructor/dsl/parallel.py b/instructor/dsl/parallel.py index 7369134..c1756c7 100644 --- a/instructor/dsl/parallel.py +++ b/instructor/dsl/parallel.py @@ -1,10 +1,22 @@ -from typing import Type, TypeVar, Union, get_origin, get_args -from types import UnionType +from typing import ( + Any, + Dict, + Generator, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + get_args, + get_origin, +) +from types import UnionType # type: ignore[attr-defined] from instructor.function_calls import OpenAISchema, Mode, openai_schema from collections.abc import Iterable -T = TypeVar("T") +T = TypeVar("T", bound=OpenAISchema) class ParallelBase: @@ -16,11 +28,11 @@ class ParallelBase: def from_response( self, - response, + response: Any, mode: Mode, - validation_context=None, - strict: bool = None, - ) -> Iterable[Union[T]]: + validation_context: Optional[Any] = None, + strict: Optional[bool] = None, + ) -> Generator[T, None, None]: #! We expect this from the OpenAISchema class, We should address #! this with a protocol or an abstract class... @jxnlco assert mode == Mode.PARALLEL_TOOLS, "Mode must be PARALLEL_TOOLS" @@ -32,7 +44,7 @@ class ParallelBase: ) -def get_types_array(typehint: Type[Iterable[Union[T]]]): +def get_types_array(typehint: Type[Iterable[Union[T]]]) -> Tuple[Type[T], ...]: should_be_iterable = get_origin(typehint) assert should_be_iterable is Iterable @@ -50,7 +62,7 @@ def get_types_array(typehint: Type[Iterable[Union[T]]]): return get_args(typehint) -def handle_parallel_model(typehint: Type[Iterable[Union[T]]]): +def handle_parallel_model(typehint: Type[Iterable[Union[T]]]) -> List[Dict[str, Any]]: the_types = get_types_array(typehint) return [ {"type": "function", "function": openai_schema(model).openai_schema} @@ -58,6 +70,6 @@ def handle_parallel_model(typehint: Type[Iterable[Union[T]]]): ] -def ParallelModel(typehint): +def ParallelModel(typehint: Type[Iterable[Union[T]]]) -> ParallelBase: the_types = get_types_array(typehint) return ParallelBase(*[model for model in the_types]) diff --git a/instructor/dsl/partial.py b/instructor/dsl/partial.py index f439312..cb4cc02 100644 --- a/instructor/dsl/partial.py +++ b/instructor/dsl/partial.py @@ -8,7 +8,18 @@ from pydantic import BaseModel, create_model from pydantic.fields import FieldInfo -from typing import TypeVar, NoReturn, get_args, get_origin, Optional, Generic +from typing import ( + Any, + AsyncGenerator, + Generator, + Generic, + get_args, + get_origin, + Iterable, + NoReturn, + Optional, + TypeVar, +) from copy import deepcopy from instructor.function_calls import Mode @@ -21,17 +32,23 @@ Model = TypeVar("Model", bound=BaseModel) class PartialBase: @classmethod - def from_streaming_response(cls, completion, mode: Mode, **kwargs): + def from_streaming_response( + cls, completion: Iterable[Any], mode: Mode, **kwargs: Any + ) -> Generator[Model, None, None]: json_chunks = cls.extract_json(completion, mode) yield from cls.model_from_chunks(json_chunks, **kwargs) @classmethod - async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs): + async def from_streaming_response_async( + cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any + ) -> AsyncGenerator[Model, None]: json_chunks = cls.extract_json_async(completion, mode) return cls.model_from_chunks_async(json_chunks, **kwargs) @classmethod - def model_from_chunks(cls, json_chunks, **kwargs): + def model_from_chunks( + cls, json_chunks: Iterable[Any], **kwargs: Any + ) -> Generator[Model, None, None]: prev_obj = None potential_object = "" for chunk in json_chunks: @@ -42,7 +59,7 @@ class PartialBase: parser.parse(potential_object) if potential_object.strip() else None ) if task_json: - obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore + obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined] if obj != prev_obj: obj.__dict__[ "chunk" @@ -51,7 +68,9 @@ class PartialBase: yield obj @classmethod - async def model_from_chunks_async(cls, json_chunks, **kwargs): + async def model_from_chunks_async( + cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any + ) -> AsyncGenerator[Model, None]: potential_object = "" prev_obj = None async for chunk in json_chunks: @@ -62,7 +81,7 @@ class PartialBase: parser.parse(potential_object) if potential_object.strip() else None ) if task_json: - obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore + obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined] if obj != prev_obj: obj.__dict__[ "chunk" @@ -71,7 +90,9 @@ class PartialBase: yield obj @staticmethod - def extract_json(completion, mode: Mode): + def extract_json( + completion: Iterable[Any], mode: Mode + ) -> Generator[str, None, None]: for chunk in completion: try: if chunk.choices: @@ -92,7 +113,9 @@ class PartialBase: pass @staticmethod - async def extract_json_async(completion, mode: Mode): + async def extract_json_async( + completion: AsyncGenerator[Any, None], mode: Mode + ) -> AsyncGenerator[str, None]: async for chunk in completion: try: if chunk.choices: @@ -169,7 +192,7 @@ class Partial(Generic[Model]): # Recursively apply Partial to each of the generic arguments modified_args = tuple( - Partial[arg] + Partial[arg] # type: ignore[valid-type] if isinstance(arg, type) and issubclass(arg, BaseModel) else arg for arg in generic_args diff --git a/instructor/dsl/partialjson.py b/instructor/dsl/partialjson.py index dfceef3..3e215e3 100644 --- a/instructor/dsl/partialjson.py +++ b/instructor/dsl/partialjson.py @@ -8,12 +8,12 @@ The above copyright notice and this permission notice shall be included in all c THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - +from typing import Any, Dict, Optional, Tuple import json class JSONParser: - def __init__(self): + def __init__(self) -> None: self.parsers = { " ": self.parse_space, "\r": self.parse_space, @@ -30,26 +30,26 @@ class JSONParser: for c in "0123456789.-": self.parsers[c] = self.parse_number - self.last_parse_reminding = None + self.last_parse_reminding: Optional[str] = None self.on_extra_token = self.default_on_extra_token - def default_on_extra_token(self, text, data, reminding): + def default_on_extra_token(self, text: str, data: Any, reminding: str) -> None: pass - def parse(self, s): + def parse(self, s: str) -> Dict[str, Any]: if len(s) >= 1: try: return json.loads(s) except json.JSONDecodeError as e: data, reminding = self.parse_any(s, e) self.last_parse_reminding = reminding - if self.on_extra_token and reminding: - self.on_extra_token(s, data, reminding) + if self.on_extra_token is not None and reminding: + self.on_extra_token(s, data, reminding) # type: ignore[no-untyped-call] return json.loads(json.dumps(data)) else: return json.loads("{}") - def parse_any(self, s, e): + def parse_any(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]: if not s: raise e parser = self.parsers.get(s[0]) @@ -57,10 +57,10 @@ class JSONParser: raise e return parser(s, e) - def parse_space(self, s, e): + def parse_space(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]: return self.parse_any(s.strip(), e) - def parse_array(self, s, e): + def parse_array(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]: s = s[1:] # skip starting '[' acc = [] s = s.strip() @@ -76,9 +76,9 @@ class JSONParser: s = s.strip() return acc, s - def parse_object(self, s, e): + def parse_object(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]: s = s[1:] # skip starting '{' - acc = {} + acc: Dict[str, Any] = {} s = s.strip() while s: if s[0] == "}": @@ -114,7 +114,7 @@ class JSONParser: s = s.strip() return acc, s - def parse_string(self, s, e): # noqa: ARG002 + def parse_string(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]: # noqa: ARG002 end = s.find('"', 1) while end != -1 and s[end - 1] == "\\": # Handle escaped quotes end = s.find('"', end + 1) @@ -125,7 +125,7 @@ class JSONParser: s = s[end + 1 :] return json.loads(str_val), s - def parse_number(self, s, e): + def parse_number(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]: i = 0 while i < len(s) and s[i] in "0123456789.-": i += 1 @@ -143,17 +143,17 @@ class JSONParser: raise e return num, s - def parse_true(self, s, e): + def parse_true(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]: if s.startswith("true"): return True, s[4:] raise e - def parse_false(self, s, e): + def parse_false(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]: if s.startswith("false"): return False, s[5:] raise e - def parse_null(self, s, e): + def parse_null(self, s: str, e: json.JSONDecodeError) -> Tuple[Any, str]: if s.startswith("null"): return None, s[4:] raise e diff --git a/instructor/dsl/validators.py b/instructor/dsl/validators.py index 0c6f137..90365d8 100644 --- a/instructor/dsl/validators.py +++ b/instructor/dsl/validators.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, Optional from openai import OpenAI from pydantic import Field @@ -33,7 +33,7 @@ def llm_validator( model: str = "gpt-3.5-turbo", temperature: float = 0, openai_client: OpenAI = None, -): +) -> Callable[[str], str]: """ Create a validator that uses the LLM to validate an attribute @@ -70,7 +70,7 @@ def llm_validator( openai_client = openai_client if openai_client else patch(OpenAI()) - def llm(v): + def llm(v: str) -> str: resp = openai_client.chat.completions.create( response_model=Validator, messages=[ @@ -85,7 +85,7 @@ def llm_validator( ], model=model, temperature=temperature, - ) # type: ignore + ) # If the response is not valid, return the reason, this could be used in # the future to generate a better response, via reasking mechanism. @@ -99,7 +99,7 @@ def llm_validator( return llm -def openai_moderation(client: OpenAI = None): +def openai_moderation(client: Optional[OpenAI] = None) -> Callable[[str], str]: """ Validates a message using OpenAI moderation model.