Merge pull request #22 from Siddhesh-Agarwal/main

Added Gemini Provider
This commit is contained in:
2024-10-31 11:55:47 -04:00
committed by GitHub
21 changed files with 294 additions and 116 deletions
+4 -3
View File
@@ -1,5 +1,6 @@
export OPENAI_API_KEY=""
export ANTHROPIC_API_KEY=""
export XAI_API_KEY=""
export OLLAMA_HOST_URL=""
export GEMINI_API_KEY=""
export GROQ_API_KEY=""
export OLLAMA_HOST_URL=""
export OPENAI_API_KEY=""
export XAI_API_KEY=""
+8 -7
View File
@@ -20,13 +20,14 @@ With Simplemind, tapping into AI is as easy as a friendly conversation.
To specify a specific provider or model, you can use the `llm_provider` and `llm_model` parameters when calling: `generate_text`, `generate_data`, or `create_conversation`.
- **[OpenAI's GPT](https://openai.com/gpt)**
- **[Anthropic's Claude](https://www.anthropic.com/claude)**
- **[xAI's Grok](https://x.ai/)**
- **[Groq's Groq](https://groq.com/)**
- **[Ollama](https://ollama.com)**
- [**Anthropic's Claude**](https://www.anthropic.com/claude)
- [**Google's Gemini**](https://ai.google.dev/gemini-api)
- [**Groq's Groq**](https://groq.com/)
- [**Ollama**](https://ollama.com)
- [**OpenAI's GPT**](https://openai.com/gpt)
- [**xAI's Grok**](https://x.ai/)
If you'd like to see Simplemind support additional providers or models, please send a pull request!
If you want to see Simplemind support, additional providers or models, please request a pull!
## Why SimpleMind?
- **Intuitive**: Built with Pythonic simplicity and readability in mind.
@@ -140,7 +141,7 @@ response = gpt_4o_mini.generate_text(
Harnessing the power of Python, you can easily create your own plugins to add additional functionality to your conversations:
```python
class SimpleMemoryPlugin:
class SimpleMemoryPlugin(sm.BasePlugin):
def __init__(self):
self.memories = [
"the earth has fictionally beeen destroyed.",
+7 -7
View File
@@ -1,14 +1,14 @@
from _context import sm
from pydantic import BaseModel
import openai
import faiss
import numpy as np
import os
import pickle
import faiss
import numpy as np
import openai
from _context import sm
from pydantic import BaseModel
class ContextualMemoryPlugin:
class ContextualMemoryPlugin(sm.BasePlugin):
def __init__(
self,
api_key: str,
+3 -4
View File
@@ -1,8 +1,7 @@
from typing import List, Iterator
from pydantic import BaseModel
from typing import Iterator, List
from _context import sm
from pydantic import BaseModel
class Movie(BaseModel):
@@ -25,7 +24,7 @@ class QuotesList(BaseModel):
quotes: List[MovieQuote]
def gen_quotes(n=10) -> Iterator[MovieQuote]:
def gen_quotes(n: int = 10) -> Iterator[MovieQuote]:
"""Generate a list of quotes from famous movies."""
for q in sm.generate_data(
+4 -2
View File
@@ -1,9 +1,11 @@
from _context import sm
class MathPlugin:
class MathPlugin(sm.BasePlugin):
def send_hook(self, conversation: sm.Conversation):
last_user_message = conversation.get_last_message(role="user")
if last_user_message is None:
return
if "calculate" in last_user_message.text.lower():
expression = last_user_message.text.lower().replace("calculate", "").strip()
try:
@@ -14,7 +16,7 @@ class MathPlugin:
except Exception:
conversation.add_message(
role="assistant",
text="I'm sorry, I couldn't compute that expression.",
text="I'm sorry, I couldn't compute that expression. Please try again.",
)
+3 -3
View File
@@ -1,8 +1,8 @@
from _context import sm
from pydantic import BaseModel
from typing import Literal
from _context import sm
from pydantic import BaseModel
class SentimentAnalysis(BaseModel):
sentiment: Literal["positive", "negative", "neutral"]
+1 -1
View File
@@ -1,7 +1,7 @@
from _context import sm
class SimpleMemoryPlugin(sm.BasePlugin):
class SimpleMemoryPlugin:
def __init__(self):
self.memories = [
"the earth has fictionally beeen destroyed.",
+3 -2
View File
@@ -1,6 +1,7 @@
import simplemind as sm
import time
import simplemind as sm
class ConversationPlugin(sm.BasePlugin):
def post_send_hook(self, conversation, response):
@@ -8,7 +9,7 @@ class ConversationPlugin(sm.BasePlugin):
print(f"{conversation.llm_model}:\n{response.text.strip()}\n\n------------\n")
def have_conversation(rounds=3):
def have_conversation(rounds: int = 3):
# Create two conversations - one for each AI
with (
sm.create_conversation(
+2 -2
View File
@@ -3,8 +3,8 @@ name = "simplemind"
version = "0.1.4"
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
readme = "README.md"
requires-python = ">=3.11"
dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "ollama", "groq"]
requires-python = ">=3.10"
dependencies = ["pydantic", "pydantic-settings", "instructor", "openai", "anthropic", "ollama", "groq", "google-generativeai"]
[build-system]
requires = ["hatchling"]
+19 -8
View File
@@ -1,8 +1,8 @@
from typing import List, Optional, Type
from typing import List, Type
from .models import Conversation, BasePlugin, BaseModel
from .utils import find_provider
from .models import BaseModel, BasePlugin, Conversation
from .settings import settings
from .utils import find_provider
class Session:
@@ -56,9 +56,9 @@ class Session:
def create_conversation(
*,
llm_model=None,
llm_provider=None,
plugins: Optional[List[BasePlugin]] = None,
llm_model: str | None = None,
llm_provider: str | None = None,
plugins: List[BasePlugin] | None = None,
**kwargs,
) -> Conversation:
"""Create a new conversation."""
@@ -77,7 +77,12 @@ def create_conversation(
def generate_data(
prompt, *, llm_model=None, llm_provider=None, response_model=None, **kwargs
prompt: str,
*,
llm_model: str | None = None,
llm_provider: str | None = None,
response_model: Type[BaseModel],
**kwargs,
) -> BaseModel:
"""Generate structured data from a given prompt."""
@@ -92,7 +97,13 @@ def generate_data(
)
def generate_text(prompt, *, llm_model=None, llm_provider=None, **kwargs) -> str:
def generate_text(
prompt: str,
*,
llm_model: str | None = None,
llm_provider: str | None = None,
**kwargs,
) -> str:
"""Generate text from a given prompt."""
# Find the provider.
+38 -18
View File
@@ -1,18 +1,18 @@
from types import TracebackType
import uuid
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
from .utils import find_provider
from .providers import find_provider
MESSAGE_ROLE = Literal["system", "user", "assistant"]
class SMBaseModel(BaseModel):
"""The base SimpleMind model class."""
date_created: datetime = Field(default_factory=datetime.now)
def __str__(self):
@@ -22,34 +22,36 @@ class SMBaseModel(BaseModel):
return str(self)
class BasePlugin:
class BasePlugin(SMBaseModel):
"""The base conversation plugin class."""
# Plugin metadata.
meta: Dict[str, Any] = {}
def initialize_hook(self, conversation: "Conversation"):
def initialize_hook(self, conversation: "Conversation") -> Any:
"""Initialize a hook for the plugin."""
raise NotImplementedError
def cleanup_hook(self, conversation: "Conversation"):
def cleanup_hook(self, conversation: "Conversation") -> Any:
"""Cleanup a hook for the plugin."""
raise NotImplementedError
def add_message_hook(self, conversation: "Conversation", message: "Message"):
def add_message_hook(self, conversation: "Conversation", message: "Message") -> Any:
"""Add a message hook for the plugin."""
raise NotImplementedError
def pre_send_hook(self, conversation: "Conversation"):
def pre_send_hook(self, conversation: "Conversation") -> Any:
"""Pre-send hook for the plugin."""
raise NotImplementedError
def post_send_hook(self, conversation: "Conversation", response: "Message"):
def post_send_hook(self, conversation: "Conversation", response: "Message") -> Any:
"""Post-send hook for the plugin."""
raise NotImplementedError
class Message(SMBaseModel):
"""A message in a conversation."""
role: MESSAGE_ROLE
text: str
meta: Dict[str, Any] = {}
@@ -61,7 +63,16 @@ class Message(SMBaseModel):
return f"<Message role={self.role} text={self.text!r}>"
@classmethod
def from_raw_response(cls, *, text: str, raw):
def from_raw_response(cls, *, text: str, raw: Any) -> "Message":
"""Create a Message instance from a raw response.
Args:
text (str): The message text.
raw (Any): The raw response data.
Returns:
Message: A new Message instance.
"""
self = cls()
self.text = text
self.raw = raw
@@ -69,11 +80,13 @@ class Message(SMBaseModel):
class Conversation(SMBaseModel):
"""A conversation between a user and an assistant."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
messages: List[Message] = []
llm_model: Optional[str] = None
llm_provider: Optional[str] = None
plugins: List[Any] = []
plugins: List[BasePlugin] = []
def __str__(self):
return f"<Conversation id={self.id!r}>"
@@ -89,8 +102,13 @@ class Conversation(SMBaseModel):
return self
def __exit__(self, exc_type, exc_value, traceback):
# Execute all cleanup hooks.
def __exit__(
self,
exc_type: type[BaseException],
exc_value: BaseException,
traceback: TracebackType,
) -> None:
"""Execute all cleanup hooks."""
for plugin in self.plugins:
if hasattr(plugin, "cleanup_hook"):
try:
@@ -99,7 +117,7 @@ class Conversation(SMBaseModel):
pass
def prepend_system_message(
self, role: str, text: str, meta: Optional[Dict[str, Any]] = None
self, role: MESSAGE_ROLE, text: str, meta: Dict[str, Any] | None = None
):
"""Prepend a system message to the conversation."""
self.messages = [Message(role=role, text=text, meta=meta or {})] + self.messages
@@ -127,7 +145,9 @@ class Conversation(SMBaseModel):
self.messages.append(Message(role=role, text=text, meta=meta))
def send(
self, llm_model: Optional[str] = None, llm_provider: Optional[str] = None
self,
llm_model: str | None = None,
llm_provider: str | None = None,
) -> Message:
"""Send the conversation to the LLM."""
@@ -156,10 +176,10 @@ class Conversation(SMBaseModel):
return response
def get_last_message(self, role: MESSAGE_ROLE) -> Optional[Message]:
def get_last_message(self, role: MESSAGE_ROLE) -> Message | None:
"""Get the last message with the given role."""
return next((m for m in reversed(self.messages) if m.role == role), None)
def add_plugin(self, plugin: Any):
def add_plugin(self, plugin: BasePlugin) -> None:
"""Add a plugin to the conversation."""
self.plugins.append(plugin)
+3 -2
View File
@@ -2,9 +2,10 @@ from typing import List, Type
from ._base import BaseProvider
from .anthropic import Anthropic
from .gemini import Gemini
from .groq import Groq
from .openai import OpenAI
from .ollama import Ollama
from .openai import OpenAI
from .xai import XAI
providers: List[Type[BaseProvider]] = [Anthropic, Groq, OpenAI, Ollama, XAI]
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI]
+9 -4
View File
@@ -1,6 +1,11 @@
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, Type, TypeVar
from instructor import Instructor
from pydantic import BaseModel
T = TypeVar("T", bound=BaseModel)
class BaseProvider(ABC):
@@ -9,13 +14,13 @@ class BaseProvider(ABC):
NAME: str
DEFAULT_MODEL: str
@property
@cached_property
@abstractmethod
def client(self):
def client(self) -> Any:
"""The instructor client for the provider."""
raise NotImplementedError
@property
@cached_property
@abstractmethod
def structured_client(self) -> Instructor:
"""The structured client for the provider."""
@@ -27,7 +32,7 @@ class BaseProvider(ABC):
raise NotImplementedError
@abstractmethod
def structured_response(self, prompt: str, response_model, **kwargs):
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
"""Get a structured response."""
raise NotImplementedError
+13 -8
View File
@@ -1,10 +1,15 @@
from typing import Union
from functools import cached_property
from typing import Type, TypeVar
import anthropic
import instructor
from pydantic import BaseModel
from ._base import BaseProvider
from ..settings import settings
from ._base import BaseProvider
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "anthropic"
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
@@ -15,17 +20,17 @@ class Anthropic(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@property
@cached_property
def client(self):
"""The raw Anthropic client."""
if not self.api_key:
raise ValueError("Anthropic API key is required")
return anthropic.Anthropic(api_key=self.api_key)
@property
@cached_property
def structured_client(self):
"""A client patched with Instructor."""
return instructor.from_anthropic(self.client)
@@ -57,13 +62,13 @@ class Anthropic(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, model, response_model, **kwargs):
def structured_response(self, model: str, response_model: Type[T], **kwargs) -> T:
response = self.structured_client.messages.create(
model=model, response_model=response_model or self.DEFAULT_MODEL, **kwargs
model=model or self.DEFAULT_MODEL, response_model=response_model, **kwargs
)
return response
def generate_text(self, prompt, *, llm_model, **kwargs):
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
+94
View File
@@ -0,0 +1,94 @@
from functools import cached_property
from typing import Type, TypeVar
import google.generativeai as genai
import instructor
from pydantic import BaseModel
from ..models import Conversation, Message
from ..settings import settings
from ._base import BaseProvider
PROVIDER_NAME = "gemini"
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
T = TypeVar("T", bound=BaseModel)
class Gemini(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
self.model_name = DEFAULT_MODEL
@cached_property
def client(self, model_name: str = DEFAULT_MODEL):
"""The raw Gemini client."""
if not self.api_key:
raise ValueError("Gemini API key is required")
self.model_name = model_name
return genai.GenerativeModel(model_name=self.model_name)
@cached_property
def structured_client(self):
"""A Gemini client patched with Instructor."""
return instructor.from_gemini(self.client)
def send_conversation(self, conversation: "Conversation") -> "Message":
"""Send a conversation to the Gemini API."""
messages = [
{
"role": msg.role,
"content": msg.text,
"metadata": msg.meta or {},
}
for msg in conversation.messages
]
try:
response = self.structured_client.chat.completions.create(
messages=messages, response_model=None
)
except Exception as e:
# Handle the exception appropriately, e.g., log the error or raise a custom exception
raise RuntimeError(f"Failed to send conversation to Gemini API: {e}") from e
# Create and return a properly formatted Message instance
return Message(
role="assistant",
text=str(response),
raw=response,
llm_model=self.model_name,
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
"""Send a structured response to the Gemini API."""
try:
response = self.structured_client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
response_model=response_model,
**kwargs,
)
except Exception as e:
# Handle the exception appropriately, e.g., log the error or raise a custom exception
raise RuntimeError(
f"Failed to send structured response to Gemini API: {e}"
) from e
return response
def generate_text(self, prompt: str, **kwargs) -> str:
"""Generate text using the Gemini API."""
try:
response = self.structured_client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
response_model=None,
**kwargs,
)
except Exception as e:
# Handle the exception appropriately, e.g., log the error or raise a custom exception
raise RuntimeError(f"Failed to generate text with Gemini API: {e}") from e
return str(response)
+10 -6
View File
@@ -1,30 +1,34 @@
from typing import Union
from functools import cached_property
from typing import Type, TypeVar
import groq
import instructor
from pydantic import BaseModel
from ._base import BaseProvider
from ..settings import settings
from ._base import BaseProvider
PROVIDER_NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
T = TypeVar("T", bound=BaseModel)
class Groq(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@property
@cached_property
def client(self):
"""The raw Groq client."""
if not self.api_key:
raise ValueError("Groq API key is required")
return groq.Groq(api_key=self.api_key)
@property
@cached_property
def structured_client(self):
"""A client patched with Instructor."""
return instructor.from_groq(self.client)
@@ -59,7 +63,7 @@ class Groq(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt: str, response_model, **kwargs):
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
# Ensure messages are provided in kwargs
messages = [
{"role": "user", "content": prompt},
+23 -12
View File
@@ -1,9 +1,16 @@
import ollama as ol
import instructor
from openai import OpenAI
from functools import cached_property
from typing import Type, TypeVar
import instructor
import ollama as ol
from openai import OpenAI
from pydantic import BaseModel
from ._base import BaseProvider
from ..settings import settings
from ._base import BaseProvider
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "ollama"
DEFAULT_MODEL = "llama3.2"
@@ -15,18 +22,18 @@ class Ollama(BaseProvider):
DEFAULT_MODEL = DEFAULT_MODEL
TIMEOUT = DEFAULT_TIMEOUT
def __init__(self, host_url: str = None):
def __init__(self, host_url: str | None = None):
self.host_url = host_url or settings.OLLAMA_HOST_URL
@property
@cached_property
def client(self):
"""The raw Ollama client."""
if not self.host_url:
raise ValueError("No ollama host url provided")
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
@property
def structured_client(self):
@cached_property
def structured_client(self) -> instructor.Instructor:
"""A client patched with Instructor."""
return instructor.from_openai(
OpenAI(
@@ -36,7 +43,7 @@ class Ollama(BaseProvider):
mode=instructor.Mode.JSON,
)
def send_conversation(self, conversation: "Conversation"):
def send_conversation(self, conversation: "Conversation") -> "Message":
"""Send a conversation to the Ollama API."""
from ..models import Message
@@ -57,7 +64,10 @@ class Ollama(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
def structured_response(
self, prompt: str, response_model: Type[T], *, llm_model: str, **kwargs
) -> T:
"""Get a structured response from the Ollama API."""
messages = [
{"role": "user", "content": prompt},
]
@@ -70,7 +80,8 @@ class Ollama(BaseProvider):
)
return response
def generate_text(self, prompt, *, llm_model):
def generate_text(self, prompt: str, *, llm_model: str) -> str:
"""Generate text using the Ollama API."""
messages = [
{"role": "user", "content": prompt},
]
@@ -79,4 +90,4 @@ class Ollama(BaseProvider):
messages=messages, model=llm_model or self.DEFAULT_MODEL
)
return response.get("message").get("content")
return response.get("message", {}).get("content", "")
+20 -10
View File
@@ -1,10 +1,14 @@
from typing import Union
from functools import cached_property
from typing import Type, TypeVar
import instructor
import openai as oa
from pydantic import BaseModel
from ._base import BaseProvider
from ..settings import settings
from ._base import BaseProvider
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "openai"
DEFAULT_MODEL = "gpt-4o-mini"
@@ -14,17 +18,17 @@ class OpenAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@property
@cached_property
def client(self):
"""The raw OpenAI client."""
if not self.api_key:
raise ValueError("OpenAI API key is required")
return oa.OpenAI(api_key=self.api_key)
@property
@cached_property
def structured_client(self):
"""A OpenAI client with Instructor."""
return instructor.from_openai(self.client)
@@ -53,12 +57,19 @@ class OpenAI(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt, response_model, *, llm_model: str, **kwargs):
def structured_response(
self,
prompt: str,
response_model: Type[T],
*,
llm_model: str | None = None,
**kwargs,
) -> T:
"""Get a structured response from the OpenAI API."""
# Ensure messages are provided in kwargs
messages = [
{"role": "user", "content": prompt},
]
response = self.structured_client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
@@ -67,13 +78,12 @@ class OpenAI(BaseProvider):
)
return response
def generate_text(self, prompt, *, llm_model, **kwargs):
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
"""Generate text using the OpenAI API."""
messages = [
{"role": "user", "content": prompt},
]
response = self.client.chat.completions.create(
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
)
return response.choices[0].message.content
+7 -7
View File
@@ -1,10 +1,10 @@
from typing import Union
from functools import cached_property
import instructor
import openai as oa
from ._base import BaseProvider
from ..settings import settings
from ._base import BaseProvider
PROVIDER_NAME = "xai"
DEFAULT_MODEL = "grok-beta"
@@ -16,10 +16,10 @@ class XAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
def __init__(self, api_key: Union[str, None] = None):
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@property
@cached_property
def client(self):
"""The raw OpenAI client."""
if not self.api_key:
@@ -29,7 +29,7 @@ class XAI(BaseProvider):
base_url=BASE_URL,
)
@property
@cached_property
def structured_client(self):
"""A client patched with Instructor."""
return instructor.from_openai(self.client)
@@ -60,10 +60,10 @@ class XAI(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt: str, response_model, *, llm_model):
def structured_response(self, prompt: str, response_model, *, llm_model: str):
raise NotImplementedError("XAI does not support structured responses")
def generate_text(self, prompt, *, llm_model, **kwargs):
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
messages = [
{"role": "user", "content": prompt},
]
+14 -1
View File
@@ -1,8 +1,19 @@
from typing import Optional, Union
from typing import Literal, Optional, Union
from pydantic import Field, SecretStr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
logging_level = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
class LoggingConfig(BaseSettings):
"""The class that holds all the logging settings for the application."""
enabled: bool = Field(False, description="Enable logging")
level: logging_level = Field("INFO", description="The logging level")
model_config = SettingsConfigDict(extra="forbid")
class Settings(BaseSettings):
"""The class that holds all the API keys for the application."""
@@ -11,6 +22,7 @@ class Settings(BaseSettings):
None, description="API key for Anthropic"
)
GROQ_API_KEY: Optional[SecretStr] = Field(None, description="API key for Groq")
GEMINI_API_KEY: Optional[SecretStr] = Field(None, description="API key for Gemini")
OPENAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for OpenAI")
OLLAMA_HOST_URL: Optional[str] = Field(
"http://127.0.0.1:11434", description="Fully qualified host URL for Ollama"
@@ -22,6 +34,7 @@ class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
)
logging: LoggingConfig = LoggingConfig()
@field_validator("*", mode="before")
@classmethod
+9 -9
View File
@@ -1,24 +1,26 @@
import difflib
from typing import Optional, Type
from .providers import providers, BaseProvider
from .providers import BaseProvider, providers
_PROVIDER_NAMES = [provider.NAME.lower() for provider in providers]
def find_provider(provider_name: str) -> BaseProvider:
def find_provider(provider_name: str | None) -> BaseProvider:
"""
Find and instantiate a provider by name.
Parameters:
provider_name (Union[str, None]): The name of the provider to find.
provider_name (Union[str, None]): The name of the provider to find.
Returns:
An instance of the provider class if found.
An instance of the provider class if found.
Raises:
ValueError: If the provider is not found, with a suggestion for the closest match.
ValueError: If the provider is not specified or is not found, with a suggestion for the closest match.
"""
if provider_name is None:
raise ValueError("No provider specified.")
# Find the provider by name.
for provider_class in providers:
if provider_class.NAME.lower() == provider_name.lower():
@@ -29,10 +31,8 @@ def find_provider(provider_name: str) -> BaseProvider:
provider_found = difflib.get_close_matches(
provider_name.lower(), _PROVIDER_NAMES, n=1
)
if provider_found:
raise ValueError(
f"Provider {provider_name!r} not found. Did you mean {provider_found[0]!r}?"
)
else:
raise ValueError(f"Provider {provider_name} not found.")
raise ValueError(f"Provider {provider_name} not found.")