This commit is contained in:
2024-10-28 07:57:03 -04:00
parent 47efa51ac6
commit e9354d915c
7 changed files with 117 additions and 21 deletions
+1 -1
View File
@@ -1,2 +1,2 @@
from .claude import Claude
from .anthropic import Anthropic
from .openai import OpenAI
+41
View File
@@ -0,0 +1,41 @@
import os
import instructor
from anthropic import Anthropic as BaseAnthropic
from .base import BaseClientProvider
class Anthropic(BaseClientProvider):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.login()
def login(self):
"""Initialize Anthropic client, with Instructor enabled."""
# Default to environment variable if not provided.
if self._api_key is None:
self._api_key = os.getenv("ANTHROPIC_API_KEY")
base_client = BaseAnthropic(api_key=self._api_key)
self.client = instructor.from_anthropic(base_client)
# assert self.test_connection()
@property
def available_models(self):
"""Returns the available models from the Anthropic client."""
# TODO: scrape from website or embed
return [
"claude-3-opus-20240229",
"claude-3-5-sonnet-20240620",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20240620",
]
# def test_connection(self):
# """Test the connection to Anthropic. Returns True if successful."""
# raise NotImplementedError("Anthropic test_connection not implemented.")
+26 -9
View File
@@ -2,11 +2,16 @@
from pydantic import BaseModel
DEFAULT_MODEL = "gpt-4o"
class BaseClientProvider:
def __init__(self, *, api_key=None):
def __init__(self, *, model=DEFAULT_MODEL, api_key=None):
# self.logger = logging.getLogger(self.__class__.__name__)
self.client = None
self.model = model
# Load API key from environment if not provided
self._api_key = api_key
@@ -23,11 +28,11 @@ class BaseClientProvider:
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
def generate_response(self, request):
"""Generates a response from the AI provider client."""
# def generate_response(self, request):
# """Generates a response from the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
# msg = "This method must be implemented by the AI provider client."
# raise NotImplementedError(msg)
def health_check(self):
"""Checks the health of the AI provider client."""
@@ -43,8 +48,20 @@ class BaseClientProvider:
raise NotImplementedError(msg)
def features(self):
"""Returns the features of the AI provider client."""
# def features(self):
# """Returns the features of the AI provider client."""
msg = "This method must be implemented by the AI provider client."
raise NotImplementedError(msg)
# msg = "This method must be implemented by the AI provider client."
# raise NotImplementedError(msg)
# def structured_response(self, model, message, **kwargs):
# pass
# def structured_conversation(self, model, message, **kwargs):
# pass
# def single_message(self, model, message, **kwargs):
# return self.generate_response(message)
# def start_conversation(self, model, message, **kwargs):
# pass
-4
View File
@@ -1,4 +0,0 @@
from .base import BaseClientProvider
class Claude(BaseClientProvider):
pass
+23 -5
View File
@@ -4,6 +4,7 @@ import instructor
from openai import OpenAI as BaseOpenAI
from .base import BaseClientProvider
from ..models import AIResponse
class OpenAI(BaseClientProvider):
@@ -19,10 +20,9 @@ class OpenAI(BaseClientProvider):
if self._api_key is None:
self._api_key = os.getenv("OPENAI_API_KEY")
base_client = BaseOpenAI(api_key=self._api_key)
self.client = instructor.from_openai(base_client)
self.test_connection()
self.client = BaseOpenAI(api_key=self._api_key)
self.instructor_client = instructor.from_openai(self.client)
assert self.test_connection()
@property
def available_models(self):
@@ -35,4 +35,22 @@ class OpenAI(BaseClientProvider):
return [g for g in gen()]
def test_connection(self):
pass
"""Test the connection to OpenAI. Returns True if successful."""
return bool(len(self.available_models))
def single_message(self, message, *, response_model=False, **kwargs):
"""Generates a response from the OpenAI client."""
client = self.client if not response_model else self.instructor_client
completion = self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": message}],
**kwargs,
)
return AIResponse(
response=completion,
text=completion.choices[0].message.content,
)
+11 -1
View File
@@ -1,9 +1,19 @@
from pydantic import BaseModel
from typing import Any, ClassVar
class AIRequest(BaseModel):
prompt: str
text: str
parameters: dict = {}
def __str__(self):
return self.text
class AIResponse(BaseModel):
text: str
response: Any
metadata: dict = {}
def __str__(self):
return self.text
+15 -1
View File
@@ -3,6 +3,20 @@ import simplemind
context = None
openai = simplemind.integrations.OpenAI()
openai.login()
print(openai.test_connection())
print(openai.available_models)
print()
print()
message = "who is kennethreitz?"
print(f"> {message}")
print(openai.single_message(message))
# claude = simplemind.integrations.Anthropic()
# # print(claude.test_connection())
# # print(claude.available_models)
# claude.login()