mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
test
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
from .claude import Claude
|
||||
from .anthropic import Anthropic
|
||||
from .openai import OpenAI
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
import os
|
||||
|
||||
import instructor
|
||||
from anthropic import Anthropic as BaseAnthropic
|
||||
|
||||
from .base import BaseClientProvider
|
||||
|
||||
|
||||
class Anthropic(BaseClientProvider):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.login()
|
||||
|
||||
def login(self):
|
||||
"""Initialize Anthropic client, with Instructor enabled."""
|
||||
|
||||
# Default to environment variable if not provided.
|
||||
if self._api_key is None:
|
||||
self._api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
|
||||
base_client = BaseAnthropic(api_key=self._api_key)
|
||||
self.client = instructor.from_anthropic(base_client)
|
||||
# assert self.test_connection()
|
||||
|
||||
@property
|
||||
def available_models(self):
|
||||
"""Returns the available models from the Anthropic client."""
|
||||
|
||||
# TODO: scrape from website or embed
|
||||
return [
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
]
|
||||
|
||||
# def test_connection(self):
|
||||
# """Test the connection to Anthropic. Returns True if successful."""
|
||||
|
||||
# raise NotImplementedError("Anthropic test_connection not implemented.")
|
||||
@@ -2,11 +2,16 @@
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
DEFAULT_MODEL = "gpt-4o"
|
||||
|
||||
|
||||
class BaseClientProvider:
|
||||
|
||||
def __init__(self, *, api_key=None):
|
||||
def __init__(self, *, model=DEFAULT_MODEL, api_key=None):
|
||||
# self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.client = None
|
||||
self.model = model
|
||||
|
||||
# Load API key from environment if not provided
|
||||
self._api_key = api_key
|
||||
@@ -23,11 +28,11 @@ class BaseClientProvider:
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def generate_response(self, request):
|
||||
"""Generates a response from the AI provider client."""
|
||||
# def generate_response(self, request):
|
||||
# """Generates a response from the AI provider client."""
|
||||
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
# msg = "This method must be implemented by the AI provider client."
|
||||
# raise NotImplementedError(msg)
|
||||
|
||||
def health_check(self):
|
||||
"""Checks the health of the AI provider client."""
|
||||
@@ -43,8 +48,20 @@ class BaseClientProvider:
|
||||
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def features(self):
|
||||
"""Returns the features of the AI provider client."""
|
||||
# def features(self):
|
||||
# """Returns the features of the AI provider client."""
|
||||
|
||||
msg = "This method must be implemented by the AI provider client."
|
||||
raise NotImplementedError(msg)
|
||||
# msg = "This method must be implemented by the AI provider client."
|
||||
# raise NotImplementedError(msg)
|
||||
|
||||
# def structured_response(self, model, message, **kwargs):
|
||||
# pass
|
||||
|
||||
# def structured_conversation(self, model, message, **kwargs):
|
||||
# pass
|
||||
|
||||
# def single_message(self, model, message, **kwargs):
|
||||
# return self.generate_response(message)
|
||||
|
||||
# def start_conversation(self, model, message, **kwargs):
|
||||
# pass
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from .base import BaseClientProvider
|
||||
|
||||
class Claude(BaseClientProvider):
|
||||
pass
|
||||
@@ -4,6 +4,7 @@ import instructor
|
||||
from openai import OpenAI as BaseOpenAI
|
||||
|
||||
from .base import BaseClientProvider
|
||||
from ..models import AIResponse
|
||||
|
||||
|
||||
class OpenAI(BaseClientProvider):
|
||||
@@ -19,10 +20,9 @@ class OpenAI(BaseClientProvider):
|
||||
if self._api_key is None:
|
||||
self._api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
base_client = BaseOpenAI(api_key=self._api_key)
|
||||
self.client = instructor.from_openai(base_client)
|
||||
self.test_connection()
|
||||
|
||||
self.client = BaseOpenAI(api_key=self._api_key)
|
||||
self.instructor_client = instructor.from_openai(self.client)
|
||||
assert self.test_connection()
|
||||
|
||||
@property
|
||||
def available_models(self):
|
||||
@@ -35,4 +35,22 @@ class OpenAI(BaseClientProvider):
|
||||
return [g for g in gen()]
|
||||
|
||||
def test_connection(self):
|
||||
pass
|
||||
"""Test the connection to OpenAI. Returns True if successful."""
|
||||
|
||||
return bool(len(self.available_models))
|
||||
|
||||
def single_message(self, message, *, response_model=False, **kwargs):
|
||||
"""Generates a response from the OpenAI client."""
|
||||
|
||||
client = self.client if not response_model else self.instructor_client
|
||||
|
||||
completion = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": message}],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return AIResponse(
|
||||
response=completion,
|
||||
text=completion.choices[0].message.content,
|
||||
)
|
||||
|
||||
+11
-1
@@ -1,9 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, ClassVar
|
||||
|
||||
|
||||
class AIRequest(BaseModel):
|
||||
prompt: str
|
||||
text: str
|
||||
parameters: dict = {}
|
||||
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
|
||||
class AIResponse(BaseModel):
|
||||
text: str
|
||||
response: Any
|
||||
metadata: dict = {}
|
||||
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
@@ -3,6 +3,20 @@ import simplemind
|
||||
context = None
|
||||
|
||||
openai = simplemind.integrations.OpenAI()
|
||||
openai.login()
|
||||
|
||||
print(openai.test_connection())
|
||||
print(openai.available_models)
|
||||
|
||||
print()
|
||||
print()
|
||||
message = "who is kennethreitz?"
|
||||
|
||||
print(f"> {message}")
|
||||
print(openai.single_message(message))
|
||||
|
||||
# claude = simplemind.integrations.Anthropic()
|
||||
|
||||
# # print(claude.test_connection())
|
||||
# # print(claude.available_models)
|
||||
|
||||
# claude.login()
|
||||
|
||||
Reference in New Issue
Block a user