Add streaming support to XAI provider and update example usage

This commit is contained in:
2024-11-02 16:42:28 -04:00
parent 27d30ccfe8
commit d6afbd1fd0
2 changed files with 24 additions and 3 deletions
+1 -3
View File
@@ -1,9 +1,7 @@
from _context import sm
# Defaults to the default provider (openai)
r = sm.generate_text(
"Write a poem about the moon", llm_model="gpt-4o-mini", stream=True
)
r = sm.generate_text("Write a poem about the moon", llm_provider="xai", stream=True)
for chunk in r:
print(chunk, end="", flush=True)
+23
View File
@@ -25,6 +25,7 @@ class XAI(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
supports_streaming = True
supports_structured_responses = False
def __init__(self, api_key: str | None = None):
@@ -86,14 +87,36 @@ class XAI(BaseProvider):
@logger
def generate_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
# Prepare the messages.
messages = [
{"role": "user", "content": prompt},
]
# Make the request.
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
**{**self.DEFAULT_KWARGS, **kwargs},
)
# Return the response content.
return str(response.choices[0].message.content)
@logger
def generate_stream_text(self, prompt: str, *, llm_model: str, **kwargs) -> str:
# Prepare the messages.
messages = [
{"role": "user", "content": prompt},
]
# Make the request.
response = self.client.chat.completions.create(
messages=messages,
model=llm_model or self.DEFAULT_MODEL,
stream=True,
**{**self.DEFAULT_KWARGS, **kwargs},
)
# Iterate over the response and yield the content.
for chunk in response:
yield chunk.choices[0].delta.content