added proper type hinting

This commit is contained in:
Siddhesh Agarwal
2024-11-01 12:25:44 +05:30
parent 3a7383425f
commit f5b922ade8
7 changed files with 44 additions and 18 deletions
+4 -1
View File
@@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, Type, TypeVar
from typing import TYPE_CHECKING, Any, Type, TypeVar
from instructor import Instructor
from pydantic import BaseModel
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
+5 -2
View File
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar
import anthropic
import instructor
@@ -9,6 +9,9 @@ from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
@@ -39,7 +42,7 @@ class Anthropic(BaseProvider):
return instructor.from_anthropic(self.client)
@logger
def send_conversation(self, conversation: "Conversation", **kwargs):
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the Anthropic API."""
from ..models import Message
+7 -3
View File
@@ -2,7 +2,7 @@
# IT is not currently working as desired.
from functools import cached_property
from typing import Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar
import google.generativeai as genai
import instructor
@@ -12,12 +12,16 @@ from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
PROVIDER_NAME = "gemini"
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "gemini"
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
class Gemini(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
+7 -3
View File
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar
import groq
import instructor
@@ -9,13 +9,17 @@ from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
T = TypeVar("T", bound=BaseModel)
class Groq(BaseProvider):
NAME = PROVIDER_NAME
+4 -1
View File
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar
import instructor
import ollama as ol
@@ -10,6 +10,9 @@ from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
+5 -2
View File
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar
import instructor
import openai as oa
@@ -9,6 +9,9 @@ from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "openai"
@@ -38,7 +41,7 @@ class OpenAI(BaseProvider):
return instructor.from_openai(self.client)
@logger
def send_conversation(self, conversation: "Conversation", **kwargs):
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the OpenAI API."""
from ..models import Message
+12 -6
View File
@@ -1,22 +1,28 @@
from functools import cached_property
from typing import Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar
import instructor
import openai as oa
from pydantic import BaseModel
from simplemind.models import Message
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "xai"
DEFAULT_MODEL = "grok-beta"
BASE_URL = "https://api.x.ai/v1"
DEFAULT_MAX_TOKENS = 1000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
T = TypeVar("T", bound=BaseModel)
class XAI(BaseProvider):
NAME = PROVIDER_NAME
@@ -42,7 +48,7 @@ class XAI(BaseProvider):
return instructor.from_openai(self.client)
@logger
def send_conversation(self, conversation: "Conversation", **kwargs):
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the OpenAI API."""
from ..models import Message
@@ -75,7 +81,7 @@ class XAI(BaseProvider):
raise NotImplementedError("XAI does not support structured responses")
@logger
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
messages = [
{"role": "user", "content": prompt},
]
@@ -86,4 +92,4 @@ class XAI(BaseProvider):
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response.choices[0].message.content
return str(response.choices[0].message.content)