mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
added proper type hinting
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import Any, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Type, TypeVar
|
||||
|
||||
from instructor import Instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import anthropic
|
||||
import instructor
|
||||
@@ -9,6 +9,9 @@ from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
@@ -39,7 +42,7 @@ class Anthropic(BaseProvider):
|
||||
return instructor.from_anthropic(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the Anthropic API."""
|
||||
from ..models import Message
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# IT is not currently working as desired.
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import google.generativeai as genai
|
||||
import instructor
|
||||
@@ -12,12 +12,16 @@ from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
PROVIDER_NAME = "gemini"
|
||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "gemini"
|
||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
|
||||
|
||||
class Gemini(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import groq
|
||||
import instructor
|
||||
@@ -9,13 +9,17 @@ from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
DEFAULT_MAX_TOKENS = 1_000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class Groq(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import ollama as ol
|
||||
@@ -10,6 +10,9 @@ from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
@@ -9,6 +9,9 @@ from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
PROVIDER_NAME = "openai"
|
||||
@@ -38,7 +41,7 @@ class OpenAI(BaseProvider):
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
|
||||
@@ -1,22 +1,28 @@
|
||||
from functools import cached_property
|
||||
from typing import Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from simplemind.models import Message
|
||||
|
||||
from ..logging import logger
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import Conversation, Message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "xai"
|
||||
DEFAULT_MODEL = "grok-beta"
|
||||
BASE_URL = "https://api.x.ai/v1"
|
||||
DEFAULT_MAX_TOKENS = 1000
|
||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class XAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
@@ -42,7 +48,7 @@ class XAI(BaseProvider):
|
||||
return instructor.from_openai(self.client)
|
||||
|
||||
@logger
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
||||
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||
"""Send a conversation to the OpenAI API."""
|
||||
from ..models import Message
|
||||
|
||||
@@ -75,7 +81,7 @@ class XAI(BaseProvider):
|
||||
raise NotImplementedError("XAI does not support structured responses")
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -86,4 +92,4 @@ class XAI(BaseProvider):
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
return str(response.choices[0].message.content)
|
||||
|
||||
Reference in New Issue
Block a user