From 11d62099730be0dbbb47fb5f61068b66d57b1961 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Mon, 28 Oct 2024 06:32:53 -0400 Subject: [PATCH] test --- pyproject.toml | 4 ++++ simplemind/integrations/base.py | 21 ++++++++++++--------- simplemind/integrations/openai.py | 5 ++++- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d497720..4429a0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,10 @@ readme = "README.md" requires-python = ">=3.11" dependencies = ["pydantic", "instructor"] +[dependency-groups] +openai = ["openai"] +claude = ["anthropic"] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/simplemind/integrations/base.py b/simplemind/integrations/base.py index 0f533f7..2b6cf83 100644 --- a/simplemind/integrations/base.py +++ b/simplemind/integrations/base.py @@ -1,17 +1,20 @@ import os +import logging class BaseClientProvider: - def __init__(self, *, api_key=None, environ_name=None): - self._environ_name = environ_name - # TODO: reverse order, potentially? + def __init__(self, *, api_key=None, api_key_environ_name=None): + self.logger = logging.getLogger(self.__class__.__name__) + + self._api_key_environ_name = api_key_environ_name + + # Load API key from environment if not provided self._api_key = api_key or self._load_from_environ() def _load_from_environ(self): - if self._environ_name: - self.api_key = os.environ.get(self._environ_name) - else: - self.api_key = None + if self._api_key_environ_name: + return os.environ.get(self._api_key_environ_name) + return None def test_connection(self): raise NotImplementedError("This method must be implemented by the AI provider client.") @@ -25,5 +28,5 @@ class BaseClientProvider: def available_models(self): raise NotImplementedError("This method must be implemented by the AI provider client.") - # TODO: logging provider. - # + def features(self): + raise NotImplementedError("This method must be implemented by the AI provider client.") diff --git a/simplemind/integrations/openai.py b/simplemind/integrations/openai.py index 5f6892d..85e0203 100644 --- a/simplemind/integrations/openai.py +++ b/simplemind/integrations/openai.py @@ -2,4 +2,7 @@ from .base import BaseClientProvider class OpenAI(BaseClientProvider): - pass + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def test_connection(self):