mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
improvements
This commit is contained in:
@@ -6,13 +6,15 @@ from .settings import settings
|
||||
def create_conversation(llm_model=None, llm_provider=None):
|
||||
"""Create a new conversation."""
|
||||
|
||||
return Conversation(llm_model=llm_model, llm_provider=llm_provider)
|
||||
return Conversation(
|
||||
llm_model=llm_model, llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER
|
||||
)
|
||||
|
||||
|
||||
def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=None):
|
||||
"""Generate structured data from a given prompt."""
|
||||
|
||||
provider = find_provider(llm_provider)
|
||||
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
|
||||
|
||||
return provider.structured_response(
|
||||
prompt=prompt,
|
||||
@@ -23,8 +25,7 @@ def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=N
|
||||
|
||||
def generate_text(prompt, *, llm_model=None, llm_provider=None):
|
||||
"""Generate text from a given prompt."""
|
||||
|
||||
provider = find_provider(llm_provider)
|
||||
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
|
||||
|
||||
return provider.generate_text(prompt=prompt, llm_model=llm_model)
|
||||
|
||||
@@ -35,5 +36,5 @@ __all__ = [
|
||||
"find_provider",
|
||||
"generate_data",
|
||||
"generate_text",
|
||||
"settings"
|
||||
"settings",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .utils import find_provider
|
||||
@@ -20,8 +22,13 @@ class SMBaseModel(BaseModel):
|
||||
return str(self)
|
||||
|
||||
|
||||
class BasePlugin(SMBaseModel):
|
||||
"""The base plugin class."""
|
||||
class BasePlugin(ABC):
|
||||
"""The base conversation plugin class."""
|
||||
|
||||
@abstractmethod
|
||||
def send_hook(self, conversation: "Conversation"):
|
||||
"""Send a hook to the plugin."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Message(SMBaseModel):
|
||||
|
||||
@@ -2,8 +2,6 @@ from abc import ABC, abstractmethod
|
||||
|
||||
from instructor import Instructor
|
||||
|
||||
# from ..models import Conversation, Message
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
"""The base provider class."""
|
||||
|
||||
@@ -13,6 +13,7 @@ class Settings(BaseSettings):
|
||||
GROQ_API_KEY: Optional[SecretStr] = Field(None, description="API key for Groq")
|
||||
OPENAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for OpenAI")
|
||||
XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI")
|
||||
DEFAULT_LLM_PROVIDER: str = Field("openai", description="The default LLM provider")
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
import simplemind as sm
|
||||
|
||||
|
||||
class SimpleMemoryPlugin:
|
||||
def __init__(self):
|
||||
self.memories = [
|
||||
"the earth has fictionally beeen destroyed.",
|
||||
"the moon is made of cheese.",
|
||||
]
|
||||
|
||||
def yield_memories(self):
|
||||
return (m for m in self.memories)
|
||||
|
||||
def send_hook(self, conversation: sm.Conversation):
|
||||
for m in self.yield_memories():
|
||||
conversation.add_message(role="system", text=m)
|
||||
|
||||
|
||||
conversation = sm.create_conversation(llm_model="grok-beta", llm_provider="xai")
|
||||
conversation.add_plugin(SimpleMemoryPlugin())
|
||||
|
||||
|
||||
conversation.add_message(
|
||||
role="user",
|
||||
text="Write a poem about the moon",
|
||||
)
|
||||
r = conversation.send()
|
||||
|
||||
print(r.text)
|
||||
Reference in New Issue
Block a user