mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 22:50:18 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5fa67c3b2f | |||
| b7e950a8f0 | |||
| 735c6ba665 | |||
| 9132030cbd | |||
| aeea8936ce | |||
| e79b474215 | |||
| fe2ca9d5f5 | |||
| 670240b943 |
@@ -1,6 +1,15 @@
|
|||||||
Release History
|
Release History
|
||||||
===============
|
===============
|
||||||
|
|
||||||
|
|
||||||
|
## 0.2.4 (2024-11-11)
|
||||||
|
|
||||||
|
- General improvements.
|
||||||
|
|
||||||
|
## 0.2.3 (2024-11-04)
|
||||||
|
|
||||||
|
- Remove default max-tokens for OpenAI provider.
|
||||||
|
|
||||||
## 0.2.3 (2024-11-03)
|
## 0.2.3 (2024-11-03)
|
||||||
|
|
||||||
- Update default model for Amazon provider.
|
- Update default model for Amazon provider.
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
from _context import simplemind as sm
|
from _context import simplemind as sm
|
||||||
|
from pydantic import BaseModel
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import time
|
import time
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from _context import sm
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
from _context import sm
|
|
||||||
|
|
||||||
|
|
||||||
class MultiAIConversation:
|
class MultiAIConversation:
|
||||||
"""Orchestrates conversations between multiple AI models."""
|
"""Orchestrates conversations between multiple AI models."""
|
||||||
|
|||||||
@@ -1,35 +1,28 @@
|
|||||||
from datetime import datetime
|
|
||||||
import logging
|
|
||||||
import sqlite3
|
|
||||||
from typing import List
|
|
||||||
import re
|
|
||||||
import os
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import logging
|
||||||
import spacy
|
import os
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import sqlite3
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime
|
||||||
from _context import simplemind as sm
|
from typing import List
|
||||||
|
|
||||||
import nltk
|
import nltk
|
||||||
from nltk.tokenize import word_tokenize
|
import spacy
|
||||||
from nltk.tag import pos_tag
|
|
||||||
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.markdown import Markdown
|
|
||||||
from rich.status import Status
|
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import random
|
|
||||||
|
|
||||||
from docopt import docopt
|
|
||||||
|
|
||||||
from prompt_toolkit import PromptSession
|
|
||||||
from prompt_toolkit.completion import Completer, Completion
|
|
||||||
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
|
|
||||||
|
|
||||||
import xerox
|
import xerox
|
||||||
|
from _context import simplemind as sm
|
||||||
|
from docopt import docopt
|
||||||
|
from nltk.tag import pos_tag
|
||||||
|
from nltk.tokenize import word_tokenize
|
||||||
|
from prompt_toolkit import PromptSession
|
||||||
|
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
|
||||||
|
from prompt_toolkit.completion import Completer, Completion
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.status import Status
|
||||||
|
|
||||||
DB_PATH = "enhanced_context.db"
|
DB_PATH = "enhanced_context.db"
|
||||||
AVAILABLE_PROVIDERS = ["xai", "openai", "anthropic", "ollama"]
|
AVAILABLE_PROVIDERS = ["xai", "openai", "anthropic", "ollama"]
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from _context import simplemind as sm
|
from _context import simplemind as sm
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
from _context import simplemind as sm
|
from _context import simplemind as sm
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
from _context import simplemind as sm
|
from _context import simplemind as sm
|
||||||
|
from pydantic import BaseModel
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import nltk
|
import nltk
|
||||||
|
from _context import simplemind as sm
|
||||||
from nltk.sentiment import SentimentIntensityAnalyzer
|
from nltk.sentiment import SentimentIntensityAnalyzer
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from _context import simplemind as sm
|
|
||||||
|
|
||||||
nltk.download("vader_lexicon")
|
nltk.download("vader_lexicon")
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
# Note: you should probably be using textblob for this.
|
# Note: you should probably be using textblob for this.
|
||||||
|
|
||||||
|
|
||||||
class SentimentAnalysis(BaseModel):
|
class SentimentAnalysis(BaseModel):
|
||||||
sentiment: Literal["positive", "negative", "neutral"]
|
sentiment: Literal["positive", "negative", "neutral"]
|
||||||
confidence: float
|
confidence: float
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from fastapi import FastAPI, Request, HTTPException
|
|
||||||
from fastapi.templating import Jinja2Templates
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import simplemind as sm
|
import simplemind as sm
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|||||||
+7
-7
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "simplemind"
|
name = "simplemind"
|
||||||
version = "0.2.2"
|
version = "0.2.4"
|
||||||
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
description = "An experimental client for AI providers that intends to replace LangChain and LangGraph for most common use cases."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
@@ -10,18 +10,18 @@ dependencies = ["pydantic", "pydantic-settings", "instructor", "logfire"]
|
|||||||
full = [
|
full = [
|
||||||
"openai",
|
"openai",
|
||||||
"anthropic",
|
"anthropic",
|
||||||
"ollama",
|
|
||||||
"groq",
|
"groq",
|
||||||
"google-generativeai",
|
"google-generativeai",
|
||||||
"botocore",
|
"botocore",
|
||||||
"boto3"
|
"boto3"
|
||||||
]
|
]
|
||||||
openai = ["openai"]
|
|
||||||
anthropic = ["anthropic"]
|
|
||||||
ollama = ["ollama", "openai"]
|
|
||||||
groq = ["groq"]
|
|
||||||
gemini = ["google-generativeai"]
|
|
||||||
amazon = ["boto3", "botocore", "anthropic"]
|
amazon = ["boto3", "botocore", "anthropic"]
|
||||||
|
anthropic = ["anthropic"]
|
||||||
|
gemini = ["google-generativeai"]
|
||||||
|
groq = ["groq"]
|
||||||
|
ollama = ["openai"]
|
||||||
|
openai = ["openai"]
|
||||||
|
xai = ["openai"]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
|
|||||||
@@ -1,12 +1,32 @@
|
|||||||
from typing import List, Type
|
from typing import List, Type
|
||||||
|
|
||||||
from ._base import BaseProvider
|
from ._base import BaseProvider
|
||||||
|
from .amazon import Amazon
|
||||||
from .anthropic import Anthropic
|
from .anthropic import Anthropic
|
||||||
from .gemini import Gemini
|
from .gemini import Gemini
|
||||||
from .groq import Groq
|
from .groq import Groq
|
||||||
from .ollama import Ollama
|
from .ollama import Ollama
|
||||||
from .openai import OpenAI
|
from .openai import OpenAI
|
||||||
from .xai import XAI
|
from .xai import XAI
|
||||||
from .amazon import Amazon
|
|
||||||
|
|
||||||
providers: List[Type[BaseProvider]] = [Anthropic, Gemini, Groq, OpenAI, Ollama, XAI, Amazon]
|
providers: List[Type[BaseProvider]] = [
|
||||||
|
Anthropic,
|
||||||
|
Gemini,
|
||||||
|
Groq,
|
||||||
|
OpenAI,
|
||||||
|
Ollama,
|
||||||
|
XAI,
|
||||||
|
Amazon,
|
||||||
|
]
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Anthropic",
|
||||||
|
"Gemini",
|
||||||
|
"Groq",
|
||||||
|
"OpenAI",
|
||||||
|
"Ollama",
|
||||||
|
"XAI",
|
||||||
|
"Amazon",
|
||||||
|
"providers",
|
||||||
|
"BaseProvider",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,22 +1,22 @@
|
|||||||
from typing import Type, TypeVar, Iterator
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ._base import BaseProvider
|
|
||||||
from ..settings import settings
|
from ..settings import settings
|
||||||
|
from ._base import BaseProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..models import Conversation, Message
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
PROVIDER_NAME = "amazon"
|
|
||||||
DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
|
|
||||||
DEFAULT_MAX_TOKENS = 5_000
|
|
||||||
|
|
||||||
|
|
||||||
class Amazon(BaseProvider):
|
class Amazon(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "amazon"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||||
|
DEFAULT_MAX_TOKENS = 5_000
|
||||||
supports_streaming = True
|
supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, profile_name: str | None = None):
|
def __init__(self, profile_name: str | None = None):
|
||||||
@@ -25,7 +25,12 @@ class Amazon(BaseProvider):
|
|||||||
@cached_property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
"""The AnthropicBedrock client."""
|
"""The AnthropicBedrock client."""
|
||||||
import anthropic
|
try:
|
||||||
|
import anthropic
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install the `anthropic` package: `pip install anthropic`"
|
||||||
|
) from exc
|
||||||
|
|
||||||
if not self.profile_name:
|
if not self.profile_name:
|
||||||
raise ValueError("Profile name is not provided")
|
raise ValueError("Profile name is not provided")
|
||||||
@@ -33,12 +38,12 @@ class Amazon(BaseProvider):
|
|||||||
return anthropic.AnthropicBedrock(aws_profile=self.profile_name)
|
return anthropic.AnthropicBedrock(aws_profile=self.profile_name)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def structured_client(self):
|
def structured_client(self) -> instructor.Instructor:
|
||||||
"""A client patched with Instructor."""
|
"""A client patched with Instructor."""
|
||||||
|
|
||||||
return instructor.from_anthropic(self.client)
|
return instructor.from_anthropic(self.client)
|
||||||
|
|
||||||
def send_conversation(self, conversation: "Conversation", **kwargs):
|
def send_conversation(self, conversation: "Conversation", **kwargs) -> "Message":
|
||||||
"""Send a conversation to the OpenAI API."""
|
"""Send a conversation to the OpenAI API."""
|
||||||
|
|
||||||
from ..models import Message
|
from ..models import Message
|
||||||
@@ -59,7 +64,7 @@ class Amazon(BaseProvider):
|
|||||||
role="assistant",
|
role="assistant",
|
||||||
text=assistant_message.content or "",
|
text=assistant_message.content or "",
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=PROVIDER_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -75,12 +80,12 @@ class Amazon(BaseProvider):
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
max_tokens=DEFAULT_MAX_TOKENS,
|
max_tokens=self.DEFAULT_MAX_TOKENS,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def generate_text(self, prompt, *, llm_model, **kwargs):
|
def generate_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
@@ -88,13 +93,15 @@ class Amazon(BaseProvider):
|
|||||||
response = self.client.messages.create(
|
response = self.client.messages.create(
|
||||||
model=llm_model or self.DEFAULT_MODEL,
|
model=llm_model or self.DEFAULT_MODEL,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=DEFAULT_MAX_TOKENS,
|
max_tokens=self.DEFAULT_MAX_TOKENS,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.content[0].text
|
return response.content[0].text
|
||||||
|
|
||||||
def generate_stream_text(self, prompt, *, llm_model, **kwargs) -> Iterator[str]:
|
def generate_stream_text(
|
||||||
|
self, prompt: str, *, llm_model: str, **kwargs
|
||||||
|
) -> Iterator[str]:
|
||||||
"""Generate streaming text using the Amazon API."""
|
"""Generate streaming text using the Amazon API."""
|
||||||
|
|
||||||
# Prepare the messages.
|
# Prepare the messages.
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
|
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -14,20 +14,15 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "anthropic"
|
|
||||||
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
|
||||||
DEFAULT_MAX_TOKENS = 1_000
|
|
||||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
|
||||||
|
|
||||||
|
|
||||||
class Anthropic(BaseProvider):
|
class Anthropic(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "anthropic"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
DEFAULT_MAX_TOKENS = 1_000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
supports_streaming = True
|
supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
@@ -72,7 +67,7 @@ class Anthropic(BaseProvider):
|
|||||||
text=assistant_message,
|
text=assistant_message,
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=self.NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
# IT is not currently working as desired.
|
# IT is not currently working as desired.
|
||||||
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
|
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -17,18 +17,14 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "gemini"
|
|
||||||
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
|
||||||
|
|
||||||
|
|
||||||
class Gemini(BaseProvider):
|
class Gemini(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "gemini"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||||
supports_streaming = True
|
supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||||
self.model_name = DEFAULT_MODEL
|
self.model_name = self.DEFAULT_MODEL
|
||||||
|
|
||||||
def set_model(self, model_name: str):
|
def set_model(self, model_name: str):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
@@ -76,7 +72,7 @@ class Gemini(BaseProvider):
|
|||||||
text=response.text,
|
text=response.text,
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=self.model_name,
|
llm_model=self.model_name,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=self.NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
|
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -14,20 +14,15 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "groq"
|
|
||||||
DEFAULT_MODEL = "llama3-8b-8192"
|
|
||||||
DEFAULT_MAX_TOKENS = 1_000
|
|
||||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
|
||||||
|
|
||||||
|
|
||||||
class Groq(BaseProvider):
|
class Groq(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "groq"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "llama3-8b-8192"
|
||||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
DEFAULT_MAX_TOKENS = 1_000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
supports_streaming = True
|
supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
@@ -75,7 +70,7 @@ class Groq(BaseProvider):
|
|||||||
text=assistant_message.content or "",
|
text=assistant_message.content or "",
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=self.NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
|
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from openai import OpenAI
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ..logging import logger
|
from ..logging import logger
|
||||||
@@ -15,17 +14,11 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "ollama"
|
|
||||||
DEFAULT_MODEL = "llama3.2"
|
|
||||||
DEFAULT_TIMEOUT = 60
|
|
||||||
DEFAULT_KWARGS = {}
|
|
||||||
|
|
||||||
|
|
||||||
class Ollama(BaseProvider):
|
class Ollama(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "ollama"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "llama3.2"
|
||||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
DEFAULT_TIMEOUT = 60
|
||||||
TIMEOUT = DEFAULT_TIMEOUT
|
DEFAULT_KWARGS = {}
|
||||||
supports_streaming = True
|
supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, host_url: str | None = None):
|
def __init__(self, host_url: str | None = None):
|
||||||
@@ -37,21 +30,18 @@ class Ollama(BaseProvider):
|
|||||||
if not self.host_url:
|
if not self.host_url:
|
||||||
raise ValueError("No ollama host url provided")
|
raise ValueError("No ollama host url provided")
|
||||||
try:
|
try:
|
||||||
import ollama as ol
|
import openai
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Please install the `ollama` package: `pip install ollama`"
|
"Please install the `openai` package: `pip install openai`"
|
||||||
) from exc
|
) from exc
|
||||||
return ol.Client(timeout=self.TIMEOUT, host=self.host_url)
|
return openai.OpenAI(base_url=f"{self.host_url}/v1", api_key="ollama")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def structured_client(self) -> instructor.Instructor:
|
def structured_client(self) -> instructor.Instructor:
|
||||||
"""A client patched with Instructor."""
|
"""A client patched with Instructor."""
|
||||||
return instructor.from_openai(
|
return instructor.from_openai(
|
||||||
OpenAI(
|
self.client,
|
||||||
base_url=f"{self.host_url}/v1",
|
|
||||||
api_key="ollama",
|
|
||||||
),
|
|
||||||
mode=instructor.Mode.JSON,
|
mode=instructor.Mode.JSON,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -64,7 +54,7 @@ class Ollama(BaseProvider):
|
|||||||
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
{"role": msg.role, "content": msg.text} for msg in conversation.messages
|
||||||
]
|
]
|
||||||
response = self.client.chat(
|
response = self.client.chat(
|
||||||
model=conversation.llm_model or DEFAULT_MODEL,
|
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
@@ -76,7 +66,7 @@ class Ollama(BaseProvider):
|
|||||||
text=assistant_message.get("content"),
|
text=assistant_message.get("content"),
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=self.NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
|
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -13,20 +13,16 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
PROVIDER_NAME = "openai"
|
|
||||||
DEFAULT_MODEL = "gpt-4o-mini"
|
|
||||||
DEFAULT_MAX_TOKENS = None
|
|
||||||
DEFAULT_KWARGS = {}
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAI(BaseProvider):
|
class OpenAI(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "openai"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "gpt-4o-mini"
|
||||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
DEFAULT_MAX_TOKENS = None
|
||||||
|
DEFAULT_KWARGS = {}
|
||||||
supports_streaming = True
|
supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
@@ -56,7 +52,7 @@ class OpenAI(BaseProvider):
|
|||||||
]
|
]
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=conversation.llm_model or DEFAULT_MODEL,
|
model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||||
)
|
)
|
||||||
@@ -69,8 +65,8 @@ class OpenAI(BaseProvider):
|
|||||||
role="assistant",
|
role="assistant",
|
||||||
text=assistant_message.content or "",
|
text=assistant_message.content or "",
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=conversation.llm_model or DEFAULT_MODEL,
|
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=self.NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Type, TypeVar, Iterator
|
from typing import TYPE_CHECKING, Iterator, Type, TypeVar
|
||||||
|
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -14,22 +14,17 @@ if TYPE_CHECKING:
|
|||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
PROVIDER_NAME = "xai"
|
|
||||||
DEFAULT_MODEL = "grok-beta"
|
|
||||||
BASE_URL = "https://api.x.ai/v1"
|
|
||||||
DEFAULT_MAX_TOKENS = 1000
|
|
||||||
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
|
||||||
|
|
||||||
|
|
||||||
class XAI(BaseProvider):
|
class XAI(BaseProvider):
|
||||||
NAME = PROVIDER_NAME
|
NAME = "xai"
|
||||||
DEFAULT_MODEL = DEFAULT_MODEL
|
DEFAULT_MODEL = "grok-beta"
|
||||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
DEFAULT_MAX_TOKENS = 1000
|
||||||
|
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
|
||||||
|
BASE_URL = "https://api.x.ai/v1"
|
||||||
supports_streaming = True
|
supports_streaming = True
|
||||||
supports_structured_responses = False
|
supports_structured_responses = False
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
self.api_key = api_key or settings.get_api_key(self.NAME)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
@@ -44,7 +39,7 @@ class XAI(BaseProvider):
|
|||||||
) from exc
|
) from exc
|
||||||
return oa.OpenAI(
|
return oa.OpenAI(
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
base_url=BASE_URL,
|
base_url=self.BASE_URL,
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@@ -76,7 +71,7 @@ class XAI(BaseProvider):
|
|||||||
text=assistant_message.content,
|
text=assistant_message.content,
|
||||||
raw=response,
|
raw=response,
|
||||||
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
llm_model=conversation.llm_model or self.DEFAULT_MODEL,
|
||||||
llm_provider=PROVIDER_NAME,
|
llm_provider=self.NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@logger
|
@logger
|
||||||
|
|||||||
@@ -14,8 +14,9 @@ class LoggingConfig(BaseSettings):
|
|||||||
"""Enable logging for the application."""
|
"""Enable logging for the application."""
|
||||||
# adding imports here to avoid forced dependencies
|
# adding imports here to avoid forced dependencies
|
||||||
try:
|
try:
|
||||||
import logfire
|
|
||||||
from logging import basicConfig
|
from logging import basicConfig
|
||||||
|
|
||||||
|
import logfire
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"To enable logging, please install logfire: `pip install logfire`"
|
"To enable logging, please install logfire: `pip install logfire`"
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
|
||||||
|
|
||||||
import simplemind as sm
|
import simplemind as sm
|
||||||
|
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||||
|
|
||||||
|
|
||||||
class ResponseModel(BaseModel):
|
class ResponseModel(BaseModel):
|
||||||
result: int
|
result: int
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from simplemind.providers import Anthropic, Gemini, OpenAI, Groq, Ollama, Amazon
|
from simplemind.providers import Amazon, Anthropic, Gemini, Groq, Ollama, OpenAI
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
Reference in New Issue
Block a user