diff --git a/simplemind/providers/__init__.py b/simplemind/providers/__init__.py index 403d1a7..f9b1983 100644 --- a/simplemind/providers/__init__.py +++ b/simplemind/providers/__init__.py @@ -1,10 +1,10 @@ from typing import List, Type -from simplemind.providers._base import BaseProvider -from simplemind.providers.anthropic import Anthropic -from simplemind.providers.groq import Groq -from simplemind.providers.openai import OpenAI -from simplemind.providers.ollama import Ollama -from simplemind.providers.xai import XAI +from ._base import BaseProvider +from .anthropic import Anthropic +from .groq import Groq +from .openai import OpenAI +from .ollama import Ollama +from .xai import XAI providers: List[Type[BaseProvider]] = [Anthropic, Groq, OpenAI, Ollama, XAI] diff --git a/simplemind/utils.py b/simplemind/utils.py index 1ed57bc..67477ca 100644 --- a/simplemind/utils.py +++ b/simplemind/utils.py @@ -6,7 +6,7 @@ from .providers import providers, BaseProvider _PROVIDER_NAMES = [provider.NAME.lower() for provider in providers] -def find_provider(provider_name: Optional[str]) -> Type[BaseProvider]: +def find_provider(provider_name: str) -> BaseProvider: """ Find and instantiate a provider by name. @@ -19,15 +19,16 @@ def find_provider(provider_name: Optional[str]) -> Type[BaseProvider]: Raises: ValueError: If the provider is not found, with a suggestion for the closest match. """ - if provider_name: - for provider_class in providers: - if provider_class.NAME.lower() == provider_name.lower(): - # Instantiate the provider - return provider_class() + # Find the provider by name. + for provider_class in providers: + if provider_class.NAME.lower() == provider_name.lower(): + # Instantiate the provider + return provider_class() + # Find the closest match provider_found = difflib.get_close_matches( provider_name.lower(), _PROVIDER_NAMES, n=1 - ) # Show only one suggestion + ) if provider_found: raise ValueError(