add Amazon Bedrock provider

This commit is contained in:
Stan Zubarev
2024-10-31 19:34:50 -04:00
parent 7c8f22bef1
commit 8d83050a64
3 changed files with 90 additions and 1 deletions
+2 -1
View File
@@ -7,5 +7,6 @@ from .groq import Groq
from .ollama import Ollama
from .openai import OpenAI
from .xai import XAI
from .amazon import Amazon
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI]
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI, Amazon]
+85
View File
@@ -0,0 +1,85 @@
from typing import Union
import json
import instructor
import anthropic
from ._base import BaseProvider
from ..settings import settings
PROVIDER_NAME = "amazon"
DEFAULT_MODEL = "anthropic.claude-3-sonnet-20240229-v1:0"
DEFAULT_MAX_TOKENS = 5000
class Amazon(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@property
def client(self):
"""The AnthropicBedrock client."""
if not self.api_key:
raise ValueError("Profile name is not provided")
return anthropic.AnthropicBedrock(aws_profile=self.api_key)
@property
def structured_client(self):
"""A client patched with Instructor."""
return instructor.from_anthropic(self.client)
def send_conversation(self, conversation: "Conversation", **kwargs):
"""Send a conversation to the OpenAI API."""
from ..models import Message
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
response = self.client.chat.completions.create(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
)
# Get the response content from the OpenAI response
assistant_message = response.choices[0].message
# Create and return a properly formatted Message instance
return Message(
role="assistant",
text=assistant_message.content or "",
raw=response,
llm_model=conversation.llm_model or DEFAULT_MODEL,
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
# Ensure messages are provided in kwargs
messages = [
{"role": "user", "content": prompt},
]
response = self.structured_client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
response_model=response_model,
max_tokens=DEFAULT_MAX_TOKENS,
**kwargs,
)
return response
def generate_text(self, prompt, *, llm_model, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
response = self.client.messages.create(
model=llm_model or self.DEFAULT_MODEL,
messages=messages,
max_tokens=DEFAULT_MAX_TOKENS,
**kwargs,
)
return response.content[0].text
+3
View File
@@ -18,6 +18,9 @@ class LoggingConfig(BaseSettings):
class Settings(BaseSettings):
"""The class that holds all the API keys for the application."""
AMAZON_API_KEY: Optional[SecretStr] = Field(
"default", description="AWS Named Profile"
)
ANTHROPIC_API_KEY: Optional[SecretStr] = Field(
None, description="API key for Anthropic"
)