improvements

This commit is contained in:
2024-10-29 06:51:25 -04:00
parent 03204aa9a2
commit fb8b109545
6 changed files with 16 additions and 44 deletions
+6 -5
View File
@@ -6,13 +6,15 @@ from .settings import settings
def create_conversation(llm_model=None, llm_provider=None):
"""Create a new conversation."""
return Conversation(llm_model=llm_model, llm_provider=llm_provider)
return Conversation(
llm_model=llm_model, llm_provider=llm_provider or settings.DEFAULT_LLM_PROVIDER
)
def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=None):
"""Generate structured data from a given prompt."""
provider = find_provider(llm_provider)
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
return provider.structured_response(
prompt=prompt,
@@ -23,8 +25,7 @@ def generate_data(prompt, *, llm_model=None, llm_provider=None, response_model=N
def generate_text(prompt, *, llm_model=None, llm_provider=None):
"""Generate text from a given prompt."""
provider = find_provider(llm_provider)
provider = find_provider(llm_provider or settings.DEFAULT_LLM_PROVIDER)
return provider.generate_text(prompt=prompt, llm_model=llm_model)
@@ -35,5 +36,5 @@ __all__ = [
"find_provider",
"generate_data",
"generate_text",
"settings"
"settings",
]
+9 -2
View File
@@ -1,7 +1,9 @@
import uuid
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
from .utils import find_provider
@@ -20,8 +22,13 @@ class SMBaseModel(BaseModel):
return str(self)
class BasePlugin(SMBaseModel):
"""The base plugin class."""
class BasePlugin(ABC):
"""The base conversation plugin class."""
@abstractmethod
def send_hook(self, conversation: "Conversation"):
"""Send a hook to the plugin."""
raise NotImplementedError
class Message(SMBaseModel):
-2
View File
@@ -2,8 +2,6 @@ from abc import ABC, abstractmethod
from instructor import Instructor
# from ..models import Conversation, Message
class BaseProvider(ABC):
"""The base provider class."""
+1
View File
@@ -13,6 +13,7 @@ class Settings(BaseSettings):
GROQ_API_KEY: Optional[SecretStr] = Field(None, description="API key for Groq")
OPENAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for OpenAI")
XAI_API_KEY: Optional[SecretStr] = Field(None, description="API key for xAI")
DEFAULT_LLM_PROVIDER: str = Field("openai", description="The default LLM provider")
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
-29
View File
@@ -1,29 +0,0 @@
import simplemind as sm
class SimpleMemoryPlugin:
def __init__(self):
self.memories = [
"the earth has fictionally beeen destroyed.",
"the moon is made of cheese.",
]
def yield_memories(self):
return (m for m in self.memories)
def send_hook(self, conversation: sm.Conversation):
for m in self.yield_memories():
conversation.add_message(role="system", text=m)
conversation = sm.create_conversation(llm_model="grok-beta", llm_provider="xai")
conversation.add_plugin(SimpleMemoryPlugin())
conversation.add_message(
role="user",
text="Write a poem about the moon",
)
r = conversation.send()
print(r.text)
-6
View File
@@ -1,6 +0,0 @@
import simplemind as sm
conversation = sm.create_conversation(llm_model="gpt-4o-mini", llm_provider="openai")
conversation.add_message("user", "Hi there, how are you?")
reply = conversation.send()
print(reply)