mirror of
https://github.com/kennethreitz/simplemind.git
synced 2026-06-05 14:50:16 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4d2c81850e | |||
| 64246658b0 | |||
| f0aff7814b | |||
| 72121c121d | |||
| 028e89b080 | |||
| e13d03f40b | |||
| 0fc49c7e13 | |||
| d6afbd1fd0 | |||
| 27d30ccfe8 |
@@ -3,7 +3,10 @@ Release History
|
||||
|
||||
## 0.2.2 (2024-11-02)
|
||||
|
||||
- Add openai streaming support (set `stream=True` to `generate_text`).
|
||||
- `conv.prepend_system_message` now uses system role by default.
|
||||
- Add `provider.supports_streaming` property.
|
||||
- Add `provider.supports_structured_response` property.
|
||||
- General improvements.
|
||||
|
||||
## 0.2.1 (2024-11-01)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from _context import sm
|
||||
|
||||
# Defaults to the default provider (openai)
|
||||
r = sm.generate_text(
|
||||
"Write a poem about the moon", llm_model="gpt-4o-mini", stream=True
|
||||
)
|
||||
r = sm.generate_text("Write a poem about the moon", llm_provider="gemini", stream=True)
|
||||
|
||||
for chunk in r:
|
||||
print(chunk, end="", flush=True)
|
||||
|
||||
@@ -17,6 +17,7 @@ class BaseProvider(ABC):
|
||||
NAME: str
|
||||
DEFAULT_MODEL: str
|
||||
supports_streaming: bool = False
|
||||
supports_structured_responses: bool = True
|
||||
|
||||
@cached_property
|
||||
@abstractmethod
|
||||
|
||||
@@ -17,6 +17,7 @@ DEFAULT_MAX_TOKENS = 5_000
|
||||
class Amazon(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, profile_name: str | None = None):
|
||||
self.profile_name = profile_name or settings.AMAZON_PROFILE_NAME
|
||||
@@ -92,3 +93,24 @@ class Amazon(BaseProvider):
|
||||
)
|
||||
|
||||
return response.content[0].text
|
||||
|
||||
def generate_stream_text(self, prompt, *, llm_model, **kwargs):
|
||||
"""Generate streaming text using the Amazon API."""
|
||||
|
||||
# Prepare the messages.
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
# Send the request to the API.
|
||||
response = self.client.messages.create(
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Yield the text chunks.
|
||||
for chunk in response:
|
||||
if chunk.text:
|
||||
yield chunk.text
|
||||
|
||||
@@ -24,6 +24,7 @@ class Anthropic(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
@@ -107,3 +108,20 @@ class Anthropic(BaseProvider):
|
||||
)
|
||||
|
||||
return response.content[0].text
|
||||
|
||||
@logger
|
||||
def generate_stream_text(self, prompt: str, *, llm_model: str, **kwargs):
|
||||
# Prepare the messages.
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
# Make the request.
|
||||
with self.client.messages.stream(
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
messages=messages,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
) as stream:
|
||||
# Yield each chunk of text from the stream.
|
||||
for chunk in stream.text_stream:
|
||||
yield chunk
|
||||
|
||||
@@ -24,6 +24,7 @@ DEFAULT_MODEL = "models/gemini-1.5-flash-latest"
|
||||
class Gemini(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
@@ -107,3 +108,17 @@ class Gemini(BaseProvider):
|
||||
# Handle the exception appropriately, e.g., log the error or raise a custom exception
|
||||
raise RuntimeError(f"Failed to generate text with Gemini API: {e}") from e
|
||||
return response.text
|
||||
|
||||
@logger
|
||||
def generate_stream_text(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate streaming text using the Gemini API."""
|
||||
kwargs.pop("llm_model", None)
|
||||
try:
|
||||
response = self.client.generate_content(prompt, stream=True, **kwargs)
|
||||
for chunk in response:
|
||||
if chunk.text:
|
||||
yield chunk.text
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to generate streaming text with Gemini API: {e}"
|
||||
) from e
|
||||
|
||||
@@ -24,6 +24,7 @@ class Groq(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
@@ -111,3 +112,32 @@ class Groq(BaseProvider):
|
||||
)
|
||||
|
||||
return str(response.choices[0].message.content)
|
||||
|
||||
@logger
|
||||
def generate_stream_text(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
llm_model: str | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Generate streaming text using the Groq API."""
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
stream=True,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
try:
|
||||
for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to generate streaming text with Groq API: {e}"
|
||||
) from e
|
||||
|
||||
@@ -26,6 +26,7 @@ class Ollama(BaseProvider):
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
TIMEOUT = DEFAULT_TIMEOUT
|
||||
supports_streaming = True
|
||||
|
||||
def __init__(self, host_url: str | None = None):
|
||||
self.host_url = host_url or settings.OLLAMA_HOST_URL
|
||||
@@ -116,3 +117,21 @@ class Ollama(BaseProvider):
|
||||
)
|
||||
|
||||
return response.get("message", {}).get("content", "")
|
||||
|
||||
@logger
|
||||
def generate_stream_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
||||
# Prepare the messages.
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
response = self.client.chat(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
stream=True,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Iterate over the response and yield the content.
|
||||
for chunk in response:
|
||||
yield chunk["message"]["content"]
|
||||
|
||||
@@ -24,6 +24,7 @@ class OpenAI(BaseProvider):
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
supports_streaming = True
|
||||
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
|
||||
@@ -25,6 +25,8 @@ class XAI(BaseProvider):
|
||||
NAME = PROVIDER_NAME
|
||||
DEFAULT_MODEL = DEFAULT_MODEL
|
||||
DEFAULT_KWARGS = DEFAULT_KWARGS
|
||||
supports_streaming = True
|
||||
supports_structured_responses = False
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
|
||||
@@ -85,14 +87,36 @@ class XAI(BaseProvider):
|
||||
|
||||
@logger
|
||||
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
||||
# Prepare the messages.
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
# Make the request.
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Return the response content.
|
||||
return str(response.choices[0].message.content)
|
||||
|
||||
@logger
|
||||
def generate_stream_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
|
||||
# Prepare the messages.
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
# Make the request.
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model or self.DEFAULT_MODEL,
|
||||
stream=True,
|
||||
**{**self.DEFAULT_KWARGS, **kwargs},
|
||||
)
|
||||
|
||||
# Iterate over the response and yield the content.
|
||||
for chunk in response:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
Reference in New Issue
Block a user