mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
ran isort on all files
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
from _context import sm
|
||||
|
||||
from pydantic import BaseModel
|
||||
import openai
|
||||
import faiss
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import openai
|
||||
from _context import sm
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ContextualMemoryPlugin(sm.BasePlugin):
|
||||
def __init__(
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from typing import List, Iterator
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Iterator, List
|
||||
|
||||
from _context import sm
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Movie(BaseModel):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from _context import sm
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal
|
||||
|
||||
from _context import sm
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SentimentAnalysis(BaseModel):
|
||||
sentiment: Literal["positive", "negative", "neutral"]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import simplemind as sm
|
||||
import time
|
||||
|
||||
import simplemind as sm
|
||||
|
||||
|
||||
class ConversationPlugin(sm.BasePlugin):
|
||||
def post_send_hook(self, conversation, response):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import List, Type
|
||||
|
||||
from .models import Conversation, BasePlugin, BaseModel
|
||||
from .utils import find_provider
|
||||
from .models import BaseModel, BasePlugin, Conversation
|
||||
from .settings import settings
|
||||
from .utils import find_provider
|
||||
|
||||
|
||||
class Session:
|
||||
@@ -81,7 +81,7 @@ def generate_data(
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
llm_provider: str | None = None,
|
||||
response_model: Type[BaseModel] = None,
|
||||
response_model: Type[BaseModel],
|
||||
**kwargs,
|
||||
) -> BaseModel:
|
||||
"""Generate structured data from a given prompt."""
|
||||
|
||||
@@ -2,12 +2,10 @@ import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .utils import find_provider
|
||||
|
||||
|
||||
MESSAGE_ROLE = Literal["system", "user", "assistant"]
|
||||
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ from ._base import BaseProvider
|
||||
from .anthropic import Anthropic
|
||||
from .gemini import Gemini
|
||||
from .groq import Groq
|
||||
from .openai import OpenAI
|
||||
from .ollama import Ollama
|
||||
from .openai import OpenAI
|
||||
from .xai import XAI
|
||||
|
||||
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI]
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Type, TypeVar
|
||||
|
||||
from instructor import Instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
@@ -27,7 +31,7 @@ class BaseProvider(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def structured_response(self, prompt: str, response_model, **kwargs):
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
"""Get a structured response."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import anthropic
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "anthropic"
|
||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
@@ -55,7 +61,7 @@ class Anthropic(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, model: str, response_model, **kwargs):
|
||||
def structured_response(self, model: str, response_model: Type[T], **kwargs) -> T:
|
||||
response = self.structured_client.messages.create(
|
||||
model=model or self.DEFAULT_MODEL, response_model=response_model, **kwargs
|
||||
)
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import instructor
|
||||
import google.generativeai as genai
|
||||
from typing import Type, TypeVar
|
||||
|
||||
from ._base import BaseProvider
|
||||
import google.generativeai as genai
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..models import Conversation, Message
|
||||
from ..settings import settings
|
||||
from ..models import Message, Conversation
|
||||
from ._base import BaseProvider
|
||||
|
||||
PROVIDER_NAME = "gemini"
|
||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class Gemini(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
@@ -59,12 +64,13 @@ class Gemini(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt: str, response_model, **kwargs):
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
"""Send a structured response to the Gemini API."""
|
||||
try:
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle the exception appropriately, e.g., log the error or raise a custom exception
|
||||
@@ -77,7 +83,9 @@ class Gemini(BaseProvider):
|
||||
"""Generate text using the Gemini API."""
|
||||
try:
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": prompt}], response_model=None
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_model=None,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle the exception appropriately, e.g., log the error or raise a custom exception
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import groq
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
PROVIDER_NAME = "groq"
|
||||
DEFAULT_MODEL = "llama3-8b-8192"
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class Groq(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
@@ -57,7 +62,7 @@ class Groq(BaseProvider):
|
||||
llm_provider=PROVIDER_NAME,
|
||||
)
|
||||
|
||||
def structured_response(self, prompt: str, response_model, **kwargs):
|
||||
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
import ollama as ol
|
||||
import instructor
|
||||
from openai import OpenAI
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import ollama as ol
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
PROVIDER_NAME = "ollama"
|
||||
DEFAULT_MODEL = "llama3.2"
|
||||
@@ -58,8 +64,9 @@ class Ollama(BaseProvider):
|
||||
)
|
||||
|
||||
def structured_response(
|
||||
self, prompt: str, response_model, *, llm_model: str, **kwargs
|
||||
):
|
||||
self, prompt: str, response_model: Type[T], *, llm_model: str, **kwargs
|
||||
) -> T:
|
||||
"""Get a structured response from the Ollama API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
@@ -72,7 +79,8 @@ class Ollama(BaseProvider):
|
||||
)
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt: str, *, llm_model: str):
|
||||
def generate_text(self, prompt: str, *, llm_model: str) -> str:
|
||||
"""Generate text using the Ollama API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import instructor
|
||||
import openai as oa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
PROVIDER_NAME = "openai"
|
||||
DEFAULT_MODEL = "gpt-4o-mini"
|
||||
@@ -52,13 +57,18 @@ class OpenAI(BaseProvider):
|
||||
)
|
||||
|
||||
def structured_response(
|
||||
self, prompt: str, response_model, *, llm_model: str | None = None, **kwargs
|
||||
):
|
||||
self,
|
||||
prompt: str,
|
||||
response_model: Type[T],
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Get a structured response from the OpenAI API."""
|
||||
# Ensure messages are provided in kwargs
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.structured_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
@@ -68,12 +78,11 @@ class OpenAI(BaseProvider):
|
||||
return response
|
||||
|
||||
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
|
||||
"""Generate text using the OpenAI API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import instructor
|
||||
import openai as oa
|
||||
|
||||
from ._base import BaseProvider
|
||||
from ..settings import settings
|
||||
from ._base import BaseProvider
|
||||
|
||||
PROVIDER_NAME = "xai"
|
||||
DEFAULT_MODEL = "grok-beta"
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
import difflib
|
||||
|
||||
from .providers import providers, BaseProvider
|
||||
from .providers import BaseProvider, providers
|
||||
|
||||
_PROVIDER_NAMES = [provider.NAME.lower() for provider in providers]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user