update function annotations

This commit is contained in:
2024-11-02 17:14:53 -04:00
parent 22aff505c4
commit 25ba1a9289
7 changed files with 20 additions and 15 deletions
+2 -2
View File
@@ -1,4 +1,4 @@
from typing import Type, TypeVar
from typing import Type, TypeVar, Iterator
from functools import cached_property
import instructor
@@ -94,7 +94,7 @@ class Amazon(BaseProvider):
return response.content[0].text
def generate_stream_text(self, prompt, *, llm_model, **kwargs):
def generate_stream_text(self, prompt, *, llm_model, **kwargs) -> Iterator[str]:
"""Generate streaming text using the Amazon API."""
# Prepare the messages.
+4 -2
View File
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
import instructor
from pydantic import BaseModel
@@ -110,7 +110,9 @@ class Anthropic(BaseProvider):
return response.content[0].text
@logger
def generate_stream_text(self, prompt: str, *, llm_model: str, **kwargs):
def generate_stream_text(
self, prompt: str, *, llm_model: str, **kwargs
) -> Iterator[str]:
# Prepare the messages.
messages = [
{"role": "user", "content": prompt},
+2 -2
View File
@@ -2,7 +2,7 @@
# IT is not currently working as desired.
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
import instructor
from pydantic import BaseModel
@@ -110,7 +110,7 @@ class Gemini(BaseProvider):
return response.text
@logger
def generate_stream_text(self, prompt: str, **kwargs) -> str:
def generate_stream_text(self, prompt: str, **kwargs) -> Iterator[str]:
"""Generate streaming text using the Gemini API."""
kwargs.pop("llm_model", None)
try:
+2 -2
View File
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
import instructor
from pydantic import BaseModel
@@ -120,7 +120,7 @@ class Groq(BaseProvider):
*,
llm_model: str | None = None,
**kwargs,
) -> str:
) -> Iterator[str]:
"""Generate streaming text using the Groq API."""
messages = [
{"role": "user", "content": prompt},
+4 -2
View File
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
import instructor
from openai import OpenAI
@@ -119,7 +119,9 @@ class Ollama(BaseProvider):
return response.get("message", {}).get("content", "")
@logger
def generate_stream_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
def generate_stream_text(
self, prompt: str, *, llm_model: str, **kwargs
) -> Iterator[str]:
# Prepare the messages.
messages = [
{"role": "user", "content": prompt},
+2 -3
View File
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
import instructor
from pydantic import BaseModel
@@ -24,7 +24,6 @@ class OpenAI(BaseProvider):
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
supports_streaming = True
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -112,7 +111,7 @@ class OpenAI(BaseProvider):
@logger
def generate_stream_text(
self, prompt: str, *, llm_model: str | None = None, **kwargs
):
) -> Iterator[str]:
"""Generate streaming text using the OpenAI API.
Yields chunks of text as they are generated by the model.
+4 -2
View File
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
import instructor
from pydantic import BaseModel
@@ -103,7 +103,9 @@ class XAI(BaseProvider):
return str(response.choices[0].message.content)
@logger
def generate_stream_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
def generate_stream_text(
self, prompt: str, *, llm_model: str, **kwargs
) -> Iterator[str]:
# Prepare the messages.
messages = [
{"role": "user", "content": prompt},