19 Commits

Author SHA1 Message Date
kennethreitz cd0be3ad89 Refactor LoggingConfig methods for enabling and disabling logging 2024-11-01 08:36:05 -04:00
kennethreitz 3dd2e1b248 Refactor Gemini provider to handle missing llm_model key 2024-11-01 08:28:53 -04:00
Siddhesh Agarwal ad1800840d small changes 2024-11-01 15:27:15 +05:30
Siddhesh Agarwal d62f297b68 removed unused variable 2024-11-01 15:16:20 +05:30
Siddhesh Agarwal a2597709d2 gemini works as expected 2024-11-01 14:55:22 +05:30
Siddhesh Agarwal 1455b5ba13 remove unused import 2024-11-01 14:31:19 +05:30
Siddhesh Agarwal 0fb54d1987 circular import problem solve 2024-11-01 14:31:01 +05:30
Siddhesh Agarwal fe06331662 fixed forced imports + ensured return type in structure_response 2024-11-01 14:24:34 +05:30
Siddhesh Agarwal 56b1e65d70 moved logging functions to LoggingConfig from Settings 2024-11-01 13:06:06 +05:30
Siddhesh Agarwal 4b3e1bc6dd added methods to toggle logging 2024-11-01 12:55:24 +05:30
Siddhesh Agarwal f5b922ade8 added proper type hinting 2024-11-01 12:25:44 +05:30
Siddhesh Agarwal 3a7383425f sorted imports 2024-11-01 11:09:54 +05:30
Siddhesh Agarwal 92c10fc41e added logging 2024-11-01 11:07:04 +05:30
kennethreitz caceba381d Refactor default_kwargs logic in Ollama provider 2024-10-31 19:49:33 -04:00
kennethreitz 0795464fd7 Merge pull request #24 from barisozmen/default_kwargs
Add default kwargs logic to Groq, OpenAI, XAI, and Ollama providers
2024-10-31 19:48:02 -04:00
Barış Özmen d82effdfb1 added default_kwargs logic to xAI provider 2024-11-01 00:18:57 +03:00
Barış Özmen e648292cb3 added default_kwargs logic to Ollama provider 2024-11-01 00:17:22 +03:00
Barış Özmen 37a9333be3 added default_kwargs logic to OpenAI provider 2024-11-01 00:15:49 +03:00
Barış Özmen cbc3739411 added default_kwargs logic to Groq provider 2024-11-01 00:14:41 +03:00
13 changed files with 217 additions and 60 deletions
+27
View File
@@ -0,0 +1,27 @@
import time
from typing import Any, Callable
import logfire
from .settings import settings
def logger(func: Callable[..., Any]) -> Callable[..., Any]:
"""A @logger decorator that logs the function parameters, function returns, and exceptions raised if logging is enabled."""
def wrapper(*args, **kwargs) -> Any:
if not settings.logging.enabled:
return func(*args, **kwargs)
logfire.info(f"Calling {func.__name__} with args: {args}, kwargs: {kwargs}")
t1 = time.perf_counter()
try:
result = func(*args, **kwargs)
t2 = time.perf_counter()
logfire.info(f"{func.__name__} returned: {result} in {t2-t1} seconds")
return result
except Exception as e:
t2 = time.perf_counter()
logfire.error(f"Error in {func.__name__}: {e} in {t2-t1} seconds")
raise e
return wrapper
+1 -1
View File
@@ -1,6 +1,6 @@
from types import TracebackType
import uuid
from datetime import datetime
from types import TracebackType
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
+4 -1
View File
@@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, Type, TypeVar
from typing import TYPE_CHECKING, Any, Type, TypeVar
from instructor import Instructor
from pydantic import BaseModel
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
+17 -4
View File
@@ -1,13 +1,16 @@
from functools import cached_property
from typing import Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar
import anthropic
import instructor
from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
@@ -30,6 +33,13 @@ class Anthropic(BaseProvider):
"""The raw Anthropic client."""
if not self.api_key:
raise ValueError("Anthropic API key is required")
try:
import anthropic
except ImportError as exc:
raise ImportError(
"Please install the `anthropic` package: `pip install anthropic`"
) from exc
return anthropic.Anthropic(api_key=self.api_key)
@cached_property
@@ -37,7 +47,8 @@ class Anthropic(BaseProvider):
"""A client patched with Instructor."""
return instructor.from_anthropic(self.client)
def send_conversation(self, conversation: "Conversation", **kwargs):
@logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the Anthropic API."""
from ..models import Message
@@ -63,6 +74,7 @@ class Anthropic(BaseProvider):
llm_provider=PROVIDER_NAME,
)
@logger
def structured_response(
self, response_model: Type[T], *, llm_model: str | None = None, **kwargs
) -> T:
@@ -80,8 +92,9 @@ class Anthropic(BaseProvider):
response_model=response_model,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response
return response_model.model_validate(response)
@logger
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
messages = [
{"role": "user", "content": prompt},
+26 -10
View File
@@ -2,21 +2,25 @@
# IT is not currently working as desired.
from functools import cached_property
from typing import Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar
import google.generativeai as genai
import instructor
from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
PROVIDER_NAME = "gemini"
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "gemini"
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
class Gemini(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
@@ -25,12 +29,21 @@ class Gemini(BaseProvider):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
self.model_name = DEFAULT_MODEL
def set_model(self, model_name: str):
self.model_name = model_name
@cached_property
def client(self, model_name: str = DEFAULT_MODEL):
def client(self):
"""The raw Gemini client."""
if not self.api_key:
raise ValueError("Gemini API key is required")
self.model_name = model_name
try:
import google.generativeai as genai
except ImportError as exc:
raise ImportError(
"Please install the `google-generativeai` package: `pip install google-generativeai`"
) from exc
genai.configure(api_key=self.api_key)
return genai.GenerativeModel(model_name=self.model_name)
@cached_property
@@ -38,6 +51,7 @@ class Gemini(BaseProvider):
"""A Gemini client patched with Instructor."""
return instructor.from_gemini(self.client)
@logger
def send_conversation(self, conversation: "Conversation") -> "Message":
"""Send a conversation to the Gemini API."""
from ..models import Message
@@ -64,9 +78,11 @@ class Gemini(BaseProvider):
llm_provider=PROVIDER_NAME,
)
@logger
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
"""Send a structured response to the Gemini API."""
llm_model = kwargs.pop("llm_model", self.model_name)
# Only try to pop if the key exists
kwargs.pop("llm_model", None) # Add default value of None
try:
response = self.structured_client.chat.completions.create(
@@ -79,12 +95,12 @@ class Gemini(BaseProvider):
raise RuntimeError(
f"Failed to send structured response to Gemini API: {e}"
) from e
return response
return response_model.model_validate(response)
@logger
def generate_text(self, prompt: str, **kwargs) -> str:
"""Generate text using the Gemini API."""
llm_model = kwargs.pop("llm_model", self.model_name)
kwargs.pop("llm_model")
try:
response = self.client.generate_content(prompt, **kwargs)
except Exception as e:
+26 -10
View File
@@ -1,22 +1,29 @@
from functools import cached_property
from typing import Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar
import groq
import instructor
from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
PROVIDER_NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class Groq(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -26,6 +33,12 @@ class Groq(BaseProvider):
"""The raw Groq client."""
if not self.api_key:
raise ValueError("Groq API key is required")
try:
import groq
except ImportError as exc:
raise ImportError(
"Please install the `groq` package: `pip install groq`"
) from exc
return groq.Groq(api_key=self.api_key)
@cached_property
@@ -33,6 +46,7 @@ class Groq(BaseProvider):
"""A client patched with Instructor."""
return instructor.from_groq(self.client)
@logger
def send_conversation(
self,
conversation: "Conversation",
@@ -48,7 +62,7 @@ class Groq(BaseProvider):
response = self.client.chat.completions.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
# Get the response content from the Groq response
@@ -63,6 +77,7 @@ class Groq(BaseProvider):
llm_provider=PROVIDER_NAME,
)
@logger
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
# Ensure messages are provided in kwargs
messages = [
@@ -73,17 +88,18 @@ class Groq(BaseProvider):
messages=messages,
response_model=response_model,
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response
return response_model.model_validate(response)
@logger
def generate_text(
self,
prompt: str,
*,
llm_model: str,
**kwargs,
):
) -> str:
messages = [
{"role": "user", "content": prompt},
]
@@ -91,7 +107,7 @@ class Groq(BaseProvider):
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response.choices[0].message.content
return str(response.choices[0].message.content)
+28 -8
View File
@@ -1,25 +1,30 @@
from functools import cached_property
from typing import Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar
import instructor
import ollama as ol
from openai import OpenAI
from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "ollama"
DEFAULT_MODEL = "llama3.2"
DEFAULT_TIMEOUT = 60
DEFAULT_KWARGS = {}
class Ollama(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
TIMEOUT = DEFAULT_TIMEOUT
def __init__(self, host_url: str | None = None):
@@ -30,6 +35,12 @@ class Ollama(BaseProvider):
"""The raw Ollama client."""
if not self.host_url:
raise ValueError("No ollama host url provided")
try:
import ollama as ol
except ImportError as exc:
raise ImportError(
"Please install the `ollama` package: `pip install ollama`"
) from exc
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
@cached_property
@@ -43,7 +54,8 @@ class Ollama(BaseProvider):
mode=instructor.Mode.JSON,
)
def send_conversation(self, conversation: "Conversation") -> "Message":
@logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the Ollama API."""
from ..models import Message
@@ -51,7 +63,9 @@ class Ollama(BaseProvider):
{"role": msg.role, "content": msg.text} for msg in conversation.messages
]
response = self.client.chat(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
)
assistant_message = response.get("message")
@@ -64,6 +78,7 @@ class Ollama(BaseProvider):
llm_provider=PROVIDER_NAME,
)
@logger
def structured_response(
self,
prompt: str,
@@ -81,18 +96,23 @@ class Ollama(BaseProvider):
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
response_model=response_model,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response
return response_model.model_validate(response)
def generate_text(self, prompt: str, *, llm_model: str | None = None) -> str:
@logger
def generate_text(
self, prompt: str, *, llm_model: str | None = None, **kwargs
) -> str:
"""Generate text using the Ollama API."""
messages = [
{"role": "user", "content": prompt},
]
response = self.client.chat(
messages=messages, model=llm_model or self.DEFAULT_MODEL
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response.get("message", {}).get("content", "")
+26 -7
View File
@@ -1,22 +1,28 @@
from functools import cached_property
from typing import Type, TypeVar
from typing import TYPE_CHECKING, Type, TypeVar
import instructor
import openai as oa
from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "openai"
DEFAULT_MODEL = "gpt-4o-mini"
DEFAULT_MAX_TOKENS = 1_000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class OpenAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -26,6 +32,12 @@ class OpenAI(BaseProvider):
"""The raw OpenAI client."""
if not self.api_key:
raise ValueError("OpenAI API key is required")
try:
import openai as oa
except ImportError as exc:
raise ImportError(
"Please install the `openai` package: `pip install openai`"
) from exc
return oa.OpenAI(api_key=self.api_key)
@cached_property
@@ -33,7 +45,8 @@ class OpenAI(BaseProvider):
"""A OpenAI client with Instructor."""
return instructor.from_openai(self.client)
def send_conversation(self, conversation: "Conversation", **kwargs):
@logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the OpenAI API."""
from ..models import Message
@@ -42,7 +55,9 @@ class OpenAI(BaseProvider):
]
response = self.client.chat.completions.create(
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
model=conversation.llm_model or DEFAULT_MODEL,
messages=messages,
**{**self.DEFAULT_KWARGS, **kwargs},
)
# Get the response content from the OpenAI response
@@ -57,6 +72,7 @@ class OpenAI(BaseProvider):
llm_provider=PROVIDER_NAME,
)
@logger
def structured_response(
self,
prompt: str,
@@ -74,16 +90,19 @@ class OpenAI(BaseProvider):
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
response_model=response_model,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response
return response_model.model_validate(response)
@logger
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
"""Generate text using the OpenAI API."""
messages = [
{"role": "user", "content": prompt},
]
response = self.client.chat.completions.create(
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response.choices[0].message.content
+28 -7
View File
@@ -1,20 +1,30 @@
from functools import cached_property
from typing import TYPE_CHECKING, Type, TypeVar
import instructor
import openai as oa
from pydantic import BaseModel
from ..logging import logger
from ..settings import settings
from ._base import BaseProvider
if TYPE_CHECKING:
from ..models import Conversation, Message
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "xai"
DEFAULT_MODEL = "grok-beta"
BASE_URL = "https://api.x.ai/v1"
DEFAULT_MAX_TOKENS = 1000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class XAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -24,6 +34,12 @@ class XAI(BaseProvider):
"""The raw OpenAI client."""
if not self.api_key:
raise ValueError("XAI API key is required")
try:
import openai as oa
except ImportError as exc:
raise ImportError(
"Please install the `openai` package: `pip install openai`"
) from exc
return oa.OpenAI(
api_key=self.api_key,
base_url=BASE_URL,
@@ -34,7 +50,8 @@ class XAI(BaseProvider):
"""A client patched with Instructor."""
return instructor.from_openai(self.client)
def send_conversation(self, conversation: "Conversation", **kwargs):
@logger
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
"""Send a conversation to the OpenAI API."""
from ..models import Message
@@ -45,7 +62,7 @@ class XAI(BaseProvider):
response = self.client.chat.completions.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
# Get the response content from the OpenAI response
@@ -60,10 +77,14 @@ class XAI(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt: str, response_model, *, llm_model: str):
@logger
def structured_response(
self, prompt: str, response_model: Type[T], *, llm_model: str
) -> T:
raise NotImplementedError("XAI does not support structured responses")
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
@logger
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
messages = [
{"role": "user", "content": prompt},
]
@@ -71,7 +92,7 @@ class XAI(BaseProvider):
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response.choices[0].message.content
return str(response.choices[0].message.content)
+27 -4
View File
@@ -1,19 +1,42 @@
from typing import Literal, Optional, Union
from typing import Optional, Union
from pydantic import Field, SecretStr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
logging_level = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
class LoggingConfig(BaseSettings):
"""The class that holds all the logging settings for the application."""
enabled: bool = Field(False, description="Enable logging")
level: logging_level = Field("INFO", description="The logging level")
model_config = SettingsConfigDict(extra="forbid")
def enable_logfire(self, **kwargs) -> None:
"""Enable logging for the application."""
# adding imports here to avoid forced dependencies
try:
import logfire
from logging import basicConfig
except ImportError as e:
raise ImportError(
"To enable logging, please install logfire: `pip install logfire`"
) from e
self.enabled = True
logfire.configure(**kwargs)
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
try:
logfire.configure(**kwargs)
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
except Exception as e:
self.enabled = False # Reset flag on failure
raise RuntimeError("Failed to configure logging") from e
def disable_logfire(self) -> None:
"""Disable logging for the application."""
self.enabled = False
class Settings(BaseSettings):
"""The class that holds all the API keys for the application."""
+2 -2
View File
@@ -1,8 +1,8 @@
import pytest
import os
import sys
import pytest
# Add the project root to the Python path.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+4 -4
View File
@@ -1,8 +1,8 @@
import pytest
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama
from pydantic import BaseModel
import pytest
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
class ResponseModel(BaseModel):
result: int
@@ -25,4 +25,4 @@ def test_generate_data(provider_cls):
data = provider.structured_response(prompt=prompt, response_model=ResponseModel)
assert isinstance(data, ResponseModel)
assert type(data.result) == int
assert isinstance(data.result, int)
+1 -2
View File
@@ -1,6 +1,5 @@
import pytest
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
@pytest.mark.parametrize(