Merge pull request #8 from Siddhesh-Agarwal/main

This commit is contained in:
2024-10-28 23:49:03 -04:00
committed by GitHub
8 changed files with 119 additions and 33 deletions
+1
View File
@@ -1,3 +1,4 @@
export OPENAI_API_KEY=""
export ANTHROPIC_API_KEY=""
export XAI_API_KEY=""
export GROQ_API_KEY=""
+35 -2
View File
@@ -1,6 +1,6 @@
import uuid
from typing import List, Dict, Any, Optional
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
@@ -17,8 +17,41 @@ class SMBaseModel(BaseModel):
return str(self)
class BaseProvider(SMBaseModel):
"""The base provider class."""
__name__ = "BaseProvider"
DEFAULT_MODEL = "DEFAULT_MODEL"
@property
def client(self):
"""The instructor client for the provider."""
raise NotImplementedError
@property
def structured_client(self):
"""The structured client for the provider."""
raise NotImplementedError
def send_conversation(self, conversation: "Conversation"):
"""Send a conversation to the provider."""
raise NotImplementedError
def structured_response(self, prompt: str, response_model, **kwargs):
"""Get a structured response."""
raise NotImplementedError
def generate_text(self, prompt: str, **kwargs):
"""Generate text from a prompt."""
raise NotImplementedError
class BasePlugin(SMBaseModel):
"""The base plugin class."""
class Message(SMBaseModel):
role: str
role: Literal["system", "user", "assistant"]
text: str
meta: Dict[str, Any] = {}
raw: Optional[Any] = None
+3 -2
View File
@@ -1,5 +1,6 @@
from .openai import OpenAI
from .anthropic import Anthropic
from .groq import Groq
from .openai import OpenAI
from .xai import XAI
providers = [OpenAI, Anthropic, XAI]
providers = [Anthropic, Groq, OpenAI, XAI]
+7 -7
View File
@@ -1,19 +1,21 @@
from typing import Union
import anthropic
import instructor
# from ..models import Conversation, Message
from ..settings import settings
from simplemind.models import BaseProvider, Conversation, Message
from simplemind.settings import settings
PROVIDER_NAME = "anthropic"
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
DEFAULT_MAX_TOKENS = 1000
class Anthropic:
class Anthropic(BaseProvider):
__name__ = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: str = None):
def __init__(self, api_key: Union[str, None] = None):
self.api_key = api_key or settings.ANTHROPIC_API_KEY
@property
@@ -24,12 +26,10 @@ class Anthropic:
@property
def structured_client(self):
"""A client patched with Instructor."""
return instructor.from_anthropic(anthropic.Anthropic(api_key=self.api_key))
return instructor.from_anthropic(self.client)
def send_conversation(self, conversation: "Conversation"):
"""Send a conversation to the Anthropic API."""
from ..models import Message
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
+50
View File
@@ -0,0 +1,50 @@
from typing import Union
import groq
import instructor
from simplemind.models import BaseProvider, Conversation, Message
from simplemind.settings import settings
PROVIDER_NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
class Groq(BaseProvider):
__name__ = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
self.api_key = api_key or settings.GROQ_API_KEY
@property
def client(self):
"""The raw Groq client."""
return groq.Groq(api_key=self.api_key)
@property
def structured_client(self):
"""A client patched with Instructor."""
return instructor.from_groq(self.client)
def send_conversation(self, conversation: "Conversation"):
"""Send a conversation to the Groq API."""
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
response = self.client.chat.completions.create(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
)
# Get the response content from the Groq response
assistant_message = response.choices[0].message
# Create and return a properly formatted Message instance
return Message(
role="assistant",
text=assistant_message.content,
raw=response,
llm_model=conversation.llm_model or DEFAULT_MODEL,
llm_provider=PROVIDER_NAME,
)
+11 -10
View File
@@ -1,17 +1,20 @@
import openai as oa
import instructor
from typing import Union
from ..settings import settings
import instructor
import openai as oa
from simplemind.models import BaseProvider, Conversation, Message
from simplemind.settings import settings
PROVIDER_NAME = "openai"
DEFAULT_MODEL = "gpt-4o-mini"
class OpenAI:
class OpenAI(BaseProvider):
__name__ = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: str = None):
def __init__(self, api_key: Union[str, None] = None):
self.api_key = api_key or settings.OPENAI_API_KEY
@property
@@ -21,13 +24,11 @@ class OpenAI:
@property
def structured_client(self):
"""A client patched with Instructor."""
return instructor.patch(oa.OpenAI(api_key=self.api_key))
"""A OpenAI client with Instructor."""
return instructor.from_openai(self.client)
def send_conversation(self, conversation: "Conversation"):
"""Send a conversation to the OpenAI API."""
from ..models import Message
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
@@ -42,7 +43,7 @@ class OpenAI:
# Create and return a properly formatted Message instance
return Message(
role="assistant",
text=assistant_message.content,
text=assistant_message.content or "",
raw=response,
llm_model=conversation.llm_model or DEFAULT_MODEL,
llm_provider=PROVIDER_NAME,
+10 -11
View File
@@ -1,8 +1,10 @@
import openai as oa
import instructor
from typing import Union
# from ..models import Conversation, Message
from ..settings import settings
import instructor
import openai as oa
from simplemind.models import BaseProvider, Conversation, Message
from simplemind.settings import settings
PROVIDER_NAME = "xai"
DEFAULT_MODEL = "grok-beta"
@@ -10,11 +12,11 @@ BASE_URL = "https://api.x.ai/v1"
DEFAULT_MAX_TOKENS = 1000
class XAI:
class XAI(BaseProvider):
__name__ = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: str = None):
def __init__(self, api_key: Union[str, None] = None):
self.api_key = api_key or settings.XAI_API_KEY
@property
@@ -29,13 +31,10 @@ class XAI:
@property
def structured_client(self):
"""A client patched with Instructor."""
return instructor.patch(
oa.OpenAI(api_key=self.api_key, base_url="https://api.x.ai/v1")
)
return instructor.from_openai(self.client)
def send_conversation(self, conversation: "Conversation"):
"""Send a conversation to the OpenAI API."""
from ..models import Message
messages = [
{"role": msg.role, "content": msg.text} for msg in conversation.messages
@@ -57,7 +56,7 @@ class XAI:
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt, response_model, *, llm_model):
def structured_response(self, prompt: str, response_model, *, llm_model):
raise NotImplementedError("XAI does not support structured responses")
def generate_text(self, prompt, *, llm_model):
+2 -1
View File
@@ -3,8 +3,9 @@ from pydantic_settings import BaseSettings
class Settings(BaseSettings):
OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY")
ANTHROPIC_API_KEY: str = Field(..., env="ANTHROPIC_API_KEY")
GROQ_API_KEY: str = Field(..., env="GROQ_API_KEY")
OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY")
XAI_API_KEY: str = Field(..., env="XAI_API_KEY")