mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cd0be3ad89 | |||
| 3dd2e1b248 | |||
| ad1800840d | |||
| d62f297b68 | |||
| a2597709d2 | |||
| 1455b5ba13 | |||
| 0fb54d1987 | |||
| fe06331662 | |||
| 56b1e65d70 | |||
| 4b3e1bc6dd | |||
| f5b922ade8 | |||
| 3a7383425f | |||
| 92c10fc41e | |||
| caceba381d | |||
| 0795464fd7 | |||
| d82effdfb1 | |||
| e648292cb3 | |||
| 37a9333be3 | |||
| cbc3739411 |
@@ -0,0 +1,27 @@
|
|||||||
|
import time
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
import logfire
|
||||||
|
|
||||||
|
from .settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def logger(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
"""A @logger decorator that logs the function parameters, function returns, and exceptions raised if logging is enabled."""
|
||||||
|
|
||||||
|
def wrapper(*args, **kwargs) -> Any:
|
||||||
|
if not settings.logging.enabled:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
logfire.info(f"Calling {func.__name__} with args: {args}, kwargs: {kwargs}")
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
try:
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
t2 = time.perf_counter()
|
||||||
|
logfire.info(f"{func.__name__} returned: {result} in {t2-t1} seconds")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
t2 = time.perf_counter()
|
||||||
|
logfire.error(f"Error in {func.__name__}: {e} in {t2-t1} seconds")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return wrapper
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from types import TracebackType
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from types import TracebackType
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Type, TypeVar
|
from typing import TYPE_CHECKING, Any, Type, TypeVar
|
||||||
|
|
||||||
from instructor import Instructor
|
from instructor import Instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Type, TypeVar
|
from typing import TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
import anthropic
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
from ._base import BaseProvider
|
from ._base import BaseProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
@@ -30,6 +33,13 @@ class Anthropic(BaseProvider):
|
|||||||
"""The raw Anthropic client."""
|
"""The raw Anthropic client."""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("Anthropic API key is required")
|
raise ValueError("Anthropic API key is required")
|
||||||
|
try:
|
||||||
|
import anthropic
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `anthropic` package: `pip install anthropic`"
|
||||||
|
) from exc
|
||||||
|
|
||||||
return anthropic.Anthropic(api_key=self.api_key)
|
return anthropic.Anthropic(api_key=self.api_key)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@@ -37,7 +47,8 @@ class Anthropic(BaseProvider):
|
|||||||
"""A client patched with Instructor."""
|
"""A client patched with Instructor."""
|
||||||
return instructor.from_anthropic(self.client)
|
return instructor.from_anthropic(self.client)
|
||||||
|
|
||||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
@logger
|
||||||
|
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||||
"""Send a conversation to the Anthropic API."""
|
"""Send a conversation to the Anthropic API."""
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
|
|
||||||
@@ -63,6 +74,7 @@ class Anthropic(BaseProvider):
|
|||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=PROVIDER_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@logger
|
||||||
def structured_response(
|
def structured_response(
|
||||||
self, response_model: Type[T], *, llm_model: str | None = None, **kwargs
|
self, response_model: Type[T], *, llm_model: str | None = None, **kwargs
|
||||||
) -> T:
|
) -> T:
|
||||||
@@ -80,8 +92,9 @@ class Anthropic(BaseProvider):
|
|||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
return response
|
return response_model.model_validate(response)
|
||||||
|
|
||||||
|
@logger
|
||||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
|
|||||||
@@ -2,21 +2,25 @@
|
|||||||
# IT is not currently working as desired.
|
# IT is not currently working as desired.
|
||||||
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Type, TypeVar
|
from typing import TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
import google.generativeai as genai
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
from ._base import BaseProvider
|
from ._base import BaseProvider
|
||||||
|
|
||||||
PROVIDER_NAME = "gemini"
|
if TYPE_CHECKING:
|
||||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
PROVIDER_NAME = "gemini"
|
||||||
|
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||||
|
|
||||||
|
|
||||||
class Gemini(BaseProvider):
|
class Gemini(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = PROVIDER_NAME
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = DEFAULT_MODEL
|
||||||
@@ -25,12 +29,21 @@ class Gemini(BaseProvider):
|
|||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||||
self.model_name = DEFAULT_MODEL
|
self.model_name = DEFAULT_MODEL
|
||||||
|
|
||||||
|
def set_model(self, model_name: str):
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def client(self, model_name: str = DEFAULT_MODEL):
|
def client(self):
|
||||||
"""The raw Gemini client."""
|
"""The raw Gemini client."""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("Gemini API key is required")
|
raise ValueError("Gemini API key is required")
|
||||||
self.model_name = model_name
|
try:
|
||||||
|
import google.generativeai as genai
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `google-generativeai` package: `pip install google-generativeai`"
|
||||||
|
) from exc
|
||||||
|
genai.configure(api_key=self.api_key)
|
||||||
return genai.GenerativeModel(model_name=self.model_name)
|
return genai.GenerativeModel(model_name=self.model_name)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@@ -38,6 +51,7 @@ class Gemini(BaseProvider):
|
|||||||
"""A Gemini client patched with Instructor."""
|
"""A Gemini client patched with Instructor."""
|
||||||
return instructor.from_gemini(self.client)
|
return instructor.from_gemini(self.client)
|
||||||
|
|
||||||
|
@logger
|
||||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
def send_conversation(self, conversation: "Conversation") -> "Message":
|
||||||
"""Send a conversation to the Gemini API."""
|
"""Send a conversation to the Gemini API."""
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
@@ -64,9 +78,11 @@ class Gemini(BaseProvider):
|
|||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=PROVIDER_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@logger
|
||||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||||
"""Send a structured response to the Gemini API."""
|
"""Send a structured response to the Gemini API."""
|
||||||
llm_model = kwargs.pop("llm_model", self.model_name)
|
# Only try to pop if the key exists
|
||||||
|
kwargs.pop("llm_model", None) # Add default value of None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.structured_client.chat.completions.create(
|
response = self.structured_client.chat.completions.create(
|
||||||
@@ -79,12 +95,12 @@ class Gemini(BaseProvider):
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Failed to send structured response to Gemini API: {e}"
|
f"Failed to send structured response to Gemini API: {e}"
|
||||||
) from e
|
) from e
|
||||||
return response
|
return response_model.model_validate(response)
|
||||||
|
|
||||||
|
@logger
|
||||||
def generate_text(self, prompt: str, **kwargs) -> str:
|
def generate_text(self, prompt: str, **kwargs) -> str:
|
||||||
"""Generate text using the Gemini API."""
|
"""Generate text using the Gemini API."""
|
||||||
llm_model = kwargs.pop("llm_model", self.model_name)
|
kwargs.pop("llm_model")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.client.generate_content(prompt, **kwargs)
|
response = self.client.generate_content(prompt, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,22 +1,29 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Type, TypeVar
|
from typing import TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
import groq
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
from ._base import BaseProvider
|
from ._base import BaseProvider
|
||||||
|
|
||||||
PROVIDER_NAME = "groq"
|
if TYPE_CHECKING:
|
||||||
DEFAULT_MODEL = "llama3-8b-8192"
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
PROVIDER_NAME = "groq"
|
||||||
|
DEFAULT_MODEL = "llama3-8b-8192"
|
||||||
|
DEFAULT_MAX_TOKENS = 1_000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
|
|
||||||
|
|
||||||
class Groq(BaseProvider):
|
class Groq(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = PROVIDER_NAME
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = DEFAULT_MODEL
|
||||||
|
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||||
@@ -26,6 +33,12 @@ class Groq(BaseProvider):
|
|||||||
"""The raw Groq client."""
|
"""The raw Groq client."""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("Groq API key is required")
|
raise ValueError("Groq API key is required")
|
||||||
|
try:
|
||||||
|
import groq
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `groq` package: `pip install groq`"
|
||||||
|
) from exc
|
||||||
return groq.Groq(api_key=self.api_key)
|
return groq.Groq(api_key=self.api_key)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@@ -33,6 +46,7 @@ class Groq(BaseProvider):
|
|||||||
"""A client patched with Instructor."""
|
"""A client patched with Instructor."""
|
||||||
return instructor.from_groq(self.client)
|
return instructor.from_groq(self.client)
|
||||||
|
|
||||||
|
@logger
|
||||||
def send_conversation(
|
def send_conversation(
|
||||||
self,
|
self,
|
||||||
conversation: "Conversation",
|
conversation: "Conversation",
|
||||||
@@ -48,7 +62,7 @@ class Groq(BaseProvider):
|
|||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the response content from the Groq response
|
# Get the response content from the Groq response
|
||||||
@@ -63,6 +77,7 @@ class Groq(BaseProvider):
|
|||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=PROVIDER_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@logger
|
||||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||||
# Ensure messages are provided in kwargs
|
# Ensure messages are provided in kwargs
|
||||||
messages = [
|
messages = [
|
||||||
@@ -73,17 +88,18 @@ class Groq(BaseProvider):
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
|
model=kwargs.pop("llm_model", self.DEFAULT_MODEL),
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
return response
|
return response_model.model_validate(response)
|
||||||
|
|
||||||
|
@logger
|
||||||
def generate_text(
|
def generate_text(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
*,
|
*,
|
||||||
llm_model: str,
|
llm_model: str,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> str:
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
@@ -91,7 +107,7 @@ class Groq(BaseProvider):
|
|||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.choices[0].message.content
|
return str(response.choices[0].message.content)
|
||||||
|
|||||||
@@ -1,25 +1,30 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Type, TypeVar
|
from typing import TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
import ollama as ol
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
from ._base import BaseProvider
|
from ._base import BaseProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "ollama"
|
PROVIDER_NAME = "ollama"
|
||||||
DEFAULT_MODEL = "llama3.2"
|
DEFAULT_MODEL = "llama3.2"
|
||||||
DEFAULT_TIMEOUT = 60
|
DEFAULT_TIMEOUT = 60
|
||||||
|
DEFAULT_KWARGS = {}
|
||||||
|
|
||||||
|
|
||||||
class Ollama(BaseProvider):
|
class Ollama(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = PROVIDER_NAME
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = DEFAULT_MODEL
|
||||||
|
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||||
TIMEOUT = DEFAULT_TIMEOUT
|
TIMEOUT = DEFAULT_TIMEOUT
|
||||||
|
|
||||||
def __init__(self, host_url: str | None = None):
|
def __init__(self, host_url: str | None = None):
|
||||||
@@ -30,6 +35,12 @@ class Ollama(BaseProvider):
|
|||||||
"""The raw Ollama client."""
|
"""The raw Ollama client."""
|
||||||
if not self.host_url:
|
if not self.host_url:
|
||||||
raise ValueError("No ollama host url provided")
|
raise ValueError("No ollama host url provided")
|
||||||
|
try:
|
||||||
|
import ollama as ol
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `ollama` package: `pip install ollama`"
|
||||||
|
) from exc
|
||||||
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
|
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@@ -43,7 +54,8 @@ class Ollama(BaseProvider):
|
|||||||
mode=instructor.Mode.JSON,
|
mode=instructor.Mode.JSON,
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_conversation(self, conversation: "Conversation") -> "Message":
|
@logger
|
||||||
|
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||||
"""Send a conversation to the Ollama API."""
|
"""Send a conversation to the Ollama API."""
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
|
|
||||||
@@ -51,7 +63,9 @@ class Ollama(BaseProvider):
|
|||||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||||
]
|
]
|
||||||
response = self.client.chat(
|
response = self.client.chat(
|
||||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages
|
model=conversation.llm_model or DEFAULT_MODEL,
|
||||||
|
messages=messages,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
assistant_message = response.get("message")
|
assistant_message = response.get("message")
|
||||||
|
|
||||||
@@ -64,6 +78,7 @@ class Ollama(BaseProvider):
|
|||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=PROVIDER_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@logger
|
||||||
def structured_response(
|
def structured_response(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -81,18 +96,23 @@ class Ollama(BaseProvider):
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
return response
|
return response_model.model_validate(response)
|
||||||
|
|
||||||
def generate_text(self, prompt: str, *, llm_model: str | None = None) -> str:
|
@logger
|
||||||
|
def generate_text(
|
||||||
|
self, prompt: str, *, llm_model: str | None = None, **kwargs
|
||||||
|
) -> str:
|
||||||
"""Generate text using the Ollama API."""
|
"""Generate text using the Ollama API."""
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
|
|
||||||
response = self.client.chat(
|
response = self.client.chat(
|
||||||
messages=messages, model=llm_model or self.DEFAULT_MODEL
|
messages=messages,
|
||||||
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.get("message", {}).get("content", "")
|
return response.get("message", {}).get("content", "")
|
||||||
|
|||||||
@@ -1,22 +1,28 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Type, TypeVar
|
from typing import TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
import openai as oa
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
from ._base import BaseProvider
|
from ._base import BaseProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
PROVIDER_NAME = "openai"
|
PROVIDER_NAME = "openai"
|
||||||
DEFAULT_MODEL = "gpt-4o-mini"
|
DEFAULT_MODEL = "gpt-4o-mini"
|
||||||
|
DEFAULT_MAX_TOKENS = 1_000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
|
|
||||||
|
|
||||||
class OpenAI(BaseProvider):
|
class OpenAI(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = PROVIDER_NAME
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = DEFAULT_MODEL
|
||||||
|
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||||
@@ -26,6 +32,12 @@ class OpenAI(BaseProvider):
|
|||||||
"""The raw OpenAI client."""
|
"""The raw OpenAI client."""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("OpenAI API key is required")
|
raise ValueError("OpenAI API key is required")
|
||||||
|
try:
|
||||||
|
import openai as oa
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `openai` package: `pip install openai`"
|
||||||
|
) from exc
|
||||||
return oa.OpenAI(api_key=self.api_key)
|
return oa.OpenAI(api_key=self.api_key)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@@ -33,7 +45,8 @@ class OpenAI(BaseProvider):
|
|||||||
"""A OpenAI client with Instructor."""
|
"""A OpenAI client with Instructor."""
|
||||||
return instructor.from_openai(self.client)
|
return instructor.from_openai(self.client)
|
||||||
|
|
||||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
@logger
|
||||||
|
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||||
"""Send a conversation to the OpenAI API."""
|
"""Send a conversation to the OpenAI API."""
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
|
|
||||||
@@ -42,7 +55,9 @@ class OpenAI(BaseProvider):
|
|||||||
]
|
]
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=conversation.llm_model or DEFAULT_MODEL, messages=messages, **kwargs
|
model=conversation.llm_model or DEFAULT_MODEL,
|
||||||
|
messages=messages,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the response content from the OpenAI response
|
# Get the response content from the OpenAI response
|
||||||
@@ -57,6 +72,7 @@ class OpenAI(BaseProvider):
|
|||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=PROVIDER_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@logger
|
||||||
def structured_response(
|
def structured_response(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -74,16 +90,19 @@ class OpenAI(BaseProvider):
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
return response
|
return response_model.model_validate(response)
|
||||||
|
|
||||||
|
@logger
|
||||||
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
|
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
|
||||||
"""Generate text using the OpenAI API."""
|
"""Generate text using the OpenAI API."""
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
|
messages=messages,
|
||||||
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content
|
return response.choices[0].message.content
|
||||||
|
|||||||
@@ -1,20 +1,30 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
from typing import TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
import openai as oa
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..logging import logger
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
from ._base import BaseProvider
|
from ._base import BaseProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "xai"
|
PROVIDER_NAME = "xai"
|
||||||
DEFAULT_MODEL = "grok-beta"
|
DEFAULT_MODEL = "grok-beta"
|
||||||
BASE_URL = "https://api.x.ai/v1"
|
BASE_URL = "https://api.x.ai/v1"
|
||||||
DEFAULT_MAX_TOKENS = 1000
|
DEFAULT_MAX_TOKENS = 1000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
|
|
||||||
|
|
||||||
class XAI(BaseProvider):
|
class XAI(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = PROVIDER_NAME
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = DEFAULT_MODEL
|
||||||
|
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||||
@@ -24,6 +34,12 @@ class XAI(BaseProvider):
|
|||||||
"""The raw OpenAI client."""
|
"""The raw OpenAI client."""
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("XAI API key is required")
|
raise ValueError("XAI API key is required")
|
||||||
|
try:
|
||||||
|
import openai as oa
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `openai` package: `pip install openai`"
|
||||||
|
) from exc
|
||||||
return oa.OpenAI(
|
return oa.OpenAI(
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
base_url=BASE_URL,
|
base_url=BASE_URL,
|
||||||
@@ -34,7 +50,8 @@ class XAI(BaseProvider):
|
|||||||
"""A client patched with Instructor."""
|
"""A client patched with Instructor."""
|
||||||
return instructor.from_openai(self.client)
|
return instructor.from_openai(self.client)
|
||||||
|
|
||||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
@logger
|
||||||
|
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||||
"""Send a conversation to the OpenAI API."""
|
"""Send a conversation to the OpenAI API."""
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
|
|
||||||
@@ -45,7 +62,7 @@ class XAI(BaseProvider):
|
|||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=conversation.llm_model or self.DEFAULT_MODEL,
|
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the response content from the OpenAI response
|
# Get the response content from the OpenAI response
|
||||||
@@ -60,10 +77,14 @@ class XAI(BaseProvider):
|
|||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=PROVIDER_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
def structured_response(self, prompt: str, response_model, *, llm_model: str):
|
@logger
|
||||||
|
def structured_response(
|
||||||
|
self, prompt: str, response_model: Type[T], *, llm_model: str
|
||||||
|
) -> T:
|
||||||
raise NotImplementedError("XAI does not support structured responses")
|
raise NotImplementedError("XAI does not support structured responses")
|
||||||
|
|
||||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
@logger
|
||||||
|
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
@@ -71,7 +92,7 @@ class XAI(BaseProvider):
|
|||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
**kwargs,
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.choices[0].message.content
|
return str(response.choices[0].message.content)
|
||||||
|
|||||||
+27
-4
@@ -1,19 +1,42 @@
|
|||||||
from typing import Literal, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import Field, SecretStr, field_validator
|
from pydantic import Field, SecretStr, field_validator
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
logging_level = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
|
||||||
|
|
||||||
|
|
||||||
class LoggingConfig(BaseSettings):
|
class LoggingConfig(BaseSettings):
|
||||||
"""The class that holds all the logging settings for the application."""
|
"""The class that holds all the logging settings for the application."""
|
||||||
|
|
||||||
enabled: bool = Field(False, description="Enable logging")
|
enabled: bool = Field(False, description="Enable logging")
|
||||||
level: logging_level = Field("INFO", description="The logging level")
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(extra="forbid")
|
model_config = SettingsConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
def enable_logfire(self, **kwargs) -> None:
|
||||||
|
"""Enable logging for the application."""
|
||||||
|
# adding imports here to avoid forced dependencies
|
||||||
|
try:
|
||||||
|
import logfire
|
||||||
|
from logging import basicConfig
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"To enable logging, please install logfire: `pip install logfire`"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
self.enabled = True
|
||||||
|
logfire.configure(**kwargs)
|
||||||
|
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
|
||||||
|
|
||||||
|
try:
|
||||||
|
logfire.configure(**kwargs)
|
||||||
|
basicConfig(handlers=[logfire.LogfireLoggingHandler()])
|
||||||
|
except Exception as e:
|
||||||
|
self.enabled = False # Reset flag on failure
|
||||||
|
raise RuntimeError("Failed to configure logging") from e
|
||||||
|
|
||||||
|
def disable_logfire(self) -> None:
|
||||||
|
"""Disable logging for the application."""
|
||||||
|
self.enabled = False
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
"""The class that holds all the API keys for the application."""
|
"""The class that holds all the API keys for the application."""
|
||||||
|
|||||||
+2
-2
@@ -1,8 +1,8 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
# Add the project root to the Python path.
|
# Add the project root to the Python path.
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||||
|
|
||||||
|
|
||||||
class ResponseModel(BaseModel):
|
class ResponseModel(BaseModel):
|
||||||
result: int
|
result: int
|
||||||
@@ -25,4 +25,4 @@ def test_generate_data(provider_cls):
|
|||||||
data = provider.structured_response(prompt=prompt, response_model=ResponseModel)
|
data = provider.structured_response(prompt=prompt, response_model=ResponseModel)
|
||||||
|
|
||||||
assert isinstance(data, ResponseModel)
|
assert isinstance(data, ResponseModel)
|
||||||
assert type(data.result) == int
|
assert isinstance(data.result, int)
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
from simplemind.providers import Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
Reference in New Issue
Block a user