ran isort on all files

This commit is contained in:
Siddhesh Agarwal
2024-10-31 16:58:47 +05:30
parent 7fe8e91111
commit 33e4046ac3
15 changed files with 83 additions and 45 deletions
+6 -6
View File
@@ -1,12 +1,12 @@
from _context import sm
from pydantic import BaseModel
import openai
import faiss
import numpy as np
import os
import pickle
import faiss
import numpy as np
import openai
from _context import sm
from pydantic import BaseModel
class ContextualMemoryPlugin(sm.BasePlugin):
def __init__(
+2 -3
View File
@@ -1,8 +1,7 @@
from typing import List, Iterator
from pydantic import BaseModel
from typing import Iterator, List
from _context import sm
from pydantic import BaseModel
class Movie(BaseModel):
+3 -3
View File
@@ -1,8 +1,8 @@
from _context import sm
from pydantic import BaseModel
from typing import Literal
from _context import sm
from pydantic import BaseModel
class SentimentAnalysis(BaseModel):
sentiment: Literal["positive", "negative", "neutral"]
+2 -1
View File
@@ -1,6 +1,7 @@
import simplemind as sm
import time
import simplemind as sm
class ConversationPlugin(sm.BasePlugin):
def post_send_hook(self, conversation, response):
+3 -3
View File
@@ -1,8 +1,8 @@
from typing import List, Type
from .models import Conversation, BasePlugin, BaseModel
from .utils import find_provider
from .models import BaseModel, BasePlugin, Conversation
from .settings import settings
from .utils import find_provider
class Session:
@@ -81,7 +81,7 @@ def generate_data(
*,
llm_model: str | None = None,
llm_provider: str | None = None,
response_model: Type[BaseModel] = None,
response_model: Type[BaseModel],
**kwargs,
) -> BaseModel:
"""Generate structured data from a given prompt."""
-2
View File
@@ -2,12 +2,10 @@ import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
from .utils import find_provider
MESSAGE_ROLE = Literal["system", "user", "assistant"]
+1 -1
View File
@@ -4,8 +4,8 @@ from ._base import BaseProvider
from .anthropic import Anthropic
from .gemini import Gemini
from .groq import Groq
from .openai import OpenAI
from .ollama import Ollama
from .openai import OpenAI
from .xai import XAI
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI]
+5 -1
View File
@@ -1,6 +1,10 @@
from abc import ABC, abstractmethod
from typing import Type, TypeVar
from instructor import Instructor
from pydantic import BaseModel
T = TypeVar("T", bound=BaseModel)
class BaseProvider(ABC):
@@ -27,7 +31,7 @@ class BaseProvider(ABC):
raise NotImplementedError
@abstractmethod
def structured_response(self, prompt: str, response_model, **kwargs):
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
"""Get a structured response."""
raise NotImplementedError
+8 -2
View File
@@ -1,8 +1,14 @@
from typing import Type, TypeVar
import anthropic
import instructor
from pydantic import BaseModel
from ._base import BaseProvider
from ..settings import settings
from ._base import BaseProvider
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "anthropic"
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
@@ -55,7 +61,7 @@ class Anthropic(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, model: str, response_model, **kwargs):
def structured_response(self, model: str, response_model: Type[T], **kwargs) -> T:
response = self.structured_client.messages.create(
model=model or self.DEFAULT_MODEL, response_model=response_model, **kwargs
)
+14 -6
View File
@@ -1,13 +1,18 @@
import instructor
import google.generativeai as genai
from typing import Type, TypeVar
from ._base import BaseProvider
import google.generativeai as genai
import instructor
from pydantic import BaseModel
from ..models import Conversation, Message
from ..settings import settings
from ..models import Message, Conversation
from ._base import BaseProvider
PROVIDER_NAME = "gemini"
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
T = TypeVar("T", bound=BaseModel)
class Gemini(BaseProvider):
NAME = PROVIDER_NAME
@@ -59,12 +64,13 @@ class Gemini(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt: str, response_model, **kwargs):
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
"""Send a structured response to the Gemini API."""
try:
response = self.structured_client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
response_model=response_model,
**kwargs,
)
except Exception as e:
# Handle the exception appropriately, e.g., log the error or raise a custom exception
@@ -77,7 +83,9 @@ class Gemini(BaseProvider):
"""Generate text using the Gemini API."""
try:
response = self.structured_client.chat.completions.create(
messages=[{"role": "user", "content": prompt}], response_model=None
messages=[{"role": "user", "content": prompt}],
response_model=None,
**kwargs,
)
except Exception as e:
# Handle the exception appropriately, e.g., log the error or raise a custom exception
+7 -2
View File
@@ -1,12 +1,17 @@
from typing import Type, TypeVar
import groq
import instructor
from pydantic import BaseModel
from ._base import BaseProvider
from ..settings import settings
from ._base import BaseProvider
PROVIDER_NAME = "groq"
DEFAULT_MODEL = "llama3-8b-8192"
T = TypeVar("T", bound=BaseModel)
class Groq(BaseProvider):
NAME = PROVIDER_NAME
@@ -57,7 +62,7 @@ class Groq(BaseProvider):
llm_provider=PROVIDER_NAME,
)
def structured_response(self, prompt: str, response_model, **kwargs):
def structured_response(self, prompt: str, response_model: Type[T], **kwargs) -> T:
# Ensure messages are provided in kwargs
messages = [
{"role": "user", "content": prompt},
+15 -7
View File
@@ -1,9 +1,15 @@
import ollama as ol
import instructor
from openai import OpenAI
from typing import Type, TypeVar
import instructor
import ollama as ol
from openai import OpenAI
from pydantic import BaseModel
from ._base import BaseProvider
from ..settings import settings
from ._base import BaseProvider
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "ollama"
DEFAULT_MODEL = "llama3.2"
@@ -58,8 +64,9 @@ class Ollama(BaseProvider):
)
def structured_response(
self, prompt: str, response_model, *, llm_model: str, **kwargs
):
self, prompt: str, response_model: Type[T], *, llm_model: str, **kwargs
) -> T:
"""Get a structured response from the Ollama API."""
messages = [
{"role": "user", "content": prompt},
]
@@ -72,7 +79,8 @@ class Ollama(BaseProvider):
)
return response
def generate_text(self, prompt: str, *, llm_model: str):
def generate_text(self, prompt: str, *, llm_model: str) -> str:
"""Generate text using the Ollama API."""
messages = [
{"role": "user", "content": prompt},
]
+15 -6
View File
@@ -1,8 +1,13 @@
from typing import Type, TypeVar
import instructor
import openai as oa
from pydantic import BaseModel
from ._base import BaseProvider
from ..settings import settings
from ._base import BaseProvider
T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "openai"
DEFAULT_MODEL = "gpt-4o-mini"
@@ -52,13 +57,18 @@ class OpenAI(BaseProvider):
)
def structured_response(
self, prompt: str, response_model, *, llm_model: str | None = None, **kwargs
):
self,
prompt: str,
response_model: Type[T],
*,
llm_model: str | None = None,
**kwargs,
) -> T:
"""Get a structured response from the OpenAI API."""
# Ensure messages are provided in kwargs
messages = [
{"role": "user", "content": prompt},
]
response = self.structured_client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
@@ -68,12 +78,11 @@ class OpenAI(BaseProvider):
return response
def generate_text(self, prompt: str, *, llm_model: str | None = None, **kwargs):
"""Generate text using the OpenAI API."""
messages = [
{"role": "user", "content": prompt},
]
response = self.client.chat.completions.create(
messages=messages, model=llm_model or self.DEFAULT_MODEL, **kwargs
)
return response.choices[0].message.content
+1 -1
View File
@@ -1,8 +1,8 @@
import instructor
import openai as oa
from ._base import BaseProvider
from ..settings import settings
from ._base import BaseProvider
PROVIDER_NAME = "xai"
DEFAULT_MODEL = "grok-beta"
+1 -1
View File
@@ -1,6 +1,6 @@
import difflib
from .providers import providers, BaseProvider
from .providers import BaseProvider, providers
_PROVIDER_NAMES = [provider.NAME.lower() for provider in providers]