From f5b922ade80192995039eb8d747dcdf9626912e8 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Fri, 1 Nov 2024 12:25:44 +0530 Subject: [PATCH] added proper type hinting --- simplemind/providers/_base.py | 5 ++++- simplemind/providers/anthropic.py | 7 +++++-- simplemind/providers/gemini.py | 10 +++++++--- simplemind/providers/groq.py | 10 +++++++--- simplemind/providers/ollama.py | 5 ++++- simplemind/providers/openai.py | 7 +++++-- simplemind/providers/xai.py | 18 ++++++++++++------ 7 files changed, 44 insertions(+), 18 deletions(-) diff --git a/simplemind/providers/_base.py b/simplemind/providers/_base.py index 42ffb9b..4485246 100644 --- a/simplemind/providers/_base.py +++ b/simplemind/providers/_base.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod from functools import cached_property -from typing import Any, Type, TypeVar +from typing import TYPE_CHECKING, Any, Type, TypeVar from instructor import Instructor from pydantic import BaseModel +if TYPE_CHECKING: + from ..models import Conversation, Message + T = TypeVar("T", bound=BaseModel) diff --git a/simplemind/providers/anthropic.py b/simplemind/providers/anthropic.py index 9978a17..b51e40f 100644 --- a/simplemind/providers/anthropic.py +++ b/simplemind/providers/anthropic.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Type, TypeVar +from typing import TYPE_CHECKING, Type, TypeVar import anthropic import instructor @@ -9,6 +9,9 @@ from ..logging import logger from ..settings import settings from ._base import BaseProvider +if TYPE_CHECKING: + from ..models import Conversation, Message + T = TypeVar("T", bound=BaseModel) @@ -39,7 +42,7 @@ class Anthropic(BaseProvider): return instructor.from_anthropic(self.client) @logger - def send_conversation(self, conversation: "Conversation", **kwargs): + def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message": """Send a conversation to the Anthropic API.""" from ..models import Message diff --git a/simplemind/providers/gemini.py b/simplemind/providers/gemini.py index 4a20262..cd91471 100644 --- a/simplemind/providers/gemini.py +++ b/simplemind/providers/gemini.py @@ -2,7 +2,7 @@ # IT is not currently working as desired. from functools import cached_property -from typing import Type, TypeVar +from typing import TYPE_CHECKING, Type, TypeVar import google.generativeai as genai import instructor @@ -12,12 +12,16 @@ from ..logging import logger from ..settings import settings from ._base import BaseProvider -PROVIDER_NAME = "gemini" -DEFAULT_MODEL = "models/gemini-1.5-flash-latest" +if TYPE_CHECKING: + from ..models import Conversation, Message T = TypeVar("T", bound=BaseModel) +PROVIDER_NAME = "gemini" +DEFAULT_MODEL = "models/gemini-1.5-flash-latest" + + class Gemini(BaseProvider): NAME = PROVIDER_NAME DEFAULT_MODEL = DEFAULT_MODEL diff --git a/simplemind/providers/groq.py b/simplemind/providers/groq.py index c3a82f8..5114e90 100644 --- a/simplemind/providers/groq.py +++ b/simplemind/providers/groq.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Type, TypeVar +from typing import TYPE_CHECKING, Type, TypeVar import groq import instructor @@ -9,13 +9,17 @@ from ..logging import logger from ..settings import settings from ._base import BaseProvider +if TYPE_CHECKING: + from ..models import Conversation, Message + +T = TypeVar("T", bound=BaseModel) + + PROVIDER_NAME = "groq" DEFAULT_MODEL = "llama3-8b-8192" DEFAULT_MAX_TOKENS = 1_000 DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} -T = TypeVar("T", bound=BaseModel) - class Groq(BaseProvider): NAME = PROVIDER_NAME diff --git a/simplemind/providers/ollama.py b/simplemind/providers/ollama.py index e9c3220..c422eb4 100644 --- a/simplemind/providers/ollama.py +++ b/simplemind/providers/ollama.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Type, TypeVar +from typing import TYPE_CHECKING, Type, TypeVar import instructor import ollama as ol @@ -10,6 +10,9 @@ from ..logging import logger from ..settings import settings from ._base import BaseProvider +if TYPE_CHECKING: + from ..models import Conversation, Message + T = TypeVar("T", bound=BaseModel) diff --git a/simplemind/providers/openai.py b/simplemind/providers/openai.py index f99e9b7..2a05080 100644 --- a/simplemind/providers/openai.py +++ b/simplemind/providers/openai.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Type, TypeVar +from typing import TYPE_CHECKING, Type, TypeVar import instructor import openai as oa @@ -9,6 +9,9 @@ from ..logging import logger from ..settings import settings from ._base import BaseProvider +if TYPE_CHECKING: + from ..models import Conversation, Message + T = TypeVar("T", bound=BaseModel) PROVIDER_NAME = "openai" @@ -38,7 +41,7 @@ class OpenAI(BaseProvider): return instructor.from_openai(self.client) @logger - def send_conversation(self, conversation: "Conversation", **kwargs): + def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message": """Send a conversation to the OpenAI API.""" from ..models import Message diff --git a/simplemind/providers/xai.py b/simplemind/providers/xai.py index 936bca2..afc1faf 100644 --- a/simplemind/providers/xai.py +++ b/simplemind/providers/xai.py @@ -1,22 +1,28 @@ from functools import cached_property -from typing import Type, TypeVar +from typing import TYPE_CHECKING, Type, TypeVar import instructor import openai as oa from pydantic import BaseModel +from simplemind.models import Message + from ..logging import logger from ..settings import settings from ._base import BaseProvider +if TYPE_CHECKING: + from ..models import Conversation, Message + +T = TypeVar("T", bound=BaseModel) + + PROVIDER_NAME = "xai" DEFAULT_MODEL = "grok-beta" BASE_URL = "https://api.x.ai/v1" DEFAULT_MAX_TOKENS = 1000 DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS} -T = TypeVar("T", bound=BaseModel) - class XAI(BaseProvider): NAME = PROVIDER_NAME @@ -42,7 +48,7 @@ class XAI(BaseProvider): return instructor.from_openai(self.client) @logger - def send_conversation(self, conversation: "Conversation", **kwargs): + def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message": """Send a conversation to the OpenAI API.""" from ..models import Message @@ -75,7 +81,7 @@ class XAI(BaseProvider): raise NotImplementedError("XAI does not support structured responses") @logger - def generate_text(self, prompt: str, *, llm_model: str, **kwargs): + def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str: messages = [ {"role": "user", "content": prompt}, ] @@ -86,4 +92,4 @@ class XAI(BaseProvider): **{**self.DEFAULT_KWARGS, **kwargs}, ) - return response.choices[0].message.content + return str(response.choices[0].message.content)