Refactor find_provider function to remove Optional type and adjust return type to BaseProvider; update import statements for consistency

This commit is contained in:
2024-10-30 18:00:26 -04:00
parent d972f1cd85
commit 2309c30b8f
2 changed files with 14 additions and 13 deletions
+6 -6
View File
@@ -1,10 +1,10 @@
from typing import List, Type
from simplemind.providers._base import BaseProvider
from simplemind.providers.anthropic import Anthropic
from simplemind.providers.groq import Groq
from simplemind.providers.openai import OpenAI
from simplemind.providers.ollama import Ollama
from simplemind.providers.xai import XAI
from ._base import BaseProvider
from .anthropic import Anthropic
from .groq import Groq
from .openai import OpenAI
from .ollama import Ollama
from .xai import XAI
providers: List[Type[BaseProvider]] = [Anthropic, Groq, OpenAI, Ollama, XAI]
+8 -7
View File
@@ -6,7 +6,7 @@ from .providers import providers, BaseProvider
_PROVIDER_NAMES = [provider.NAME.lower() for provider in providers]
def find_provider(provider_name: Optional[str]) -> Type[BaseProvider]:
def find_provider(provider_name: str) -> BaseProvider:
"""
Find and instantiate a provider by name.
@@ -19,15 +19,16 @@ def find_provider(provider_name: Optional[str]) -> Type[BaseProvider]:
Raises:
ValueError: If the provider is not found, with a suggestion for the closest match.
"""
if provider_name:
for provider_class in providers:
if provider_class.NAME.lower() == provider_name.lower():
# Instantiate the provider
return provider_class()
# Find the provider by name.
for provider_class in providers:
if provider_class.NAME.lower() == provider_name.lower():
# Instantiate the provider
return provider_class()
# Find the closest match
provider_found = difflib.get_close_matches(
provider_name.lower(), _PROVIDER_NAMES, n=1
) # Show only one suggestion
)
if provider_found:
raise ValueError(