This commit is contained in:
2024-10-28 06:42:01 -04:00
parent 11d6209973
commit 2bbaab20d6
4 changed files with 53 additions and 13 deletions
+6
View File
@@ -0,0 +1,6 @@
{
"workbench.colorCustomizations": {
"sash.hoverBorder": "#d5f4a4"
},
"peacock.color": "#c0ef76"
}
+35 -13
View File
@@ -1,32 +1,54 @@
import os
import logging
class BaseClientProvider:
def __init__(self, *, api_key=None, api_key_environ_name=None):
def __init__(self, *, api_key_environ_key=None, api_key=None):
self.logger = logging.getLogger(self.__class__.__name__)
self._api_key_environ_name = api_key_environ_name
self.client = None
# Load API key from environment if not provided
self._api_key = api_key or self._load_from_environ()
self._api_key = api_key or self._load_from_environ(self._api_key_environ_name)
def _load_from_environ(self):
if self._api_key_environ_name:
return os.environ.get(self._api_key_environ_name)
return None
@classmethod
def from_environ(cls, environ_key):
"""Loads the API key from the environment (recommended)."""
return cls(api_key=os.environ.get(environ_key))
def initialize(self):
"""Initializes the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def test_connection(self):
raise NotImplementedError("This method must be implemented by the AI provider client.")
"""Tests the connection to the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def generate_response(self, request):
raise NotImplementedError("This method must be implemented by the AI provider client.")
"""Generates a response from the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def health_check(self):
raise NotImplementedError("This method must be implemented by the AI provider client.")
"""Checks the health of the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def available_models(self):
raise NotImplementedError("This method must be implemented by the AI provider client.")
"""Returns the available models from the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def features(self):
raise NotImplementedError("This method must be implemented by the AI provider client.")
"""Returns the features of the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
+12
View File
@@ -4,5 +4,17 @@ from .base import BaseClientProvider
class OpenAI(BaseClientProvider):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.initialize()
def initialize(self):
assert self._api_key, "API key is required for OpenAI client"
assert self._api_key.startswith("sk-"), "OpenAI API key must start with 'sk-'"
self.logger.info("Initializing OpenAI client")
self.logger.debug(f"API key: {self._api_key}")
self.test_connection()
def test_connection(self):
View File