Files
langchain/langchain/chat_models/base.py
T
Jeremiah Lowin 925dd3e59e Add async versions of predict() and predict_messages() (#4867)
# Add async versions of predict() and predict_messages()

#4615 introduced a unifying interface for "base" and "chat" LLM models
via the new `predict()` and `predict_messages()` methods that allow both
types of models to operate on string and message-based inputs,
respectively.

This PR adds async versions of the same (`apredict()` and
`apredict_messages()`) that are identical except for their use of
`agenerate()` in place of `generate()`, which means they repurpose all
existing work on the async backend.


## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:
        @hwchase17 (follows his work on #4615)
        @agola11 (async)

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
2023-05-23 17:22:49 -07:00

282 lines
9.1 KiB
Python

import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Mapping, Optional, Sequence
from pydantic import Extra, Field, root_validator
import langchain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
ChatResult,
HumanMessage,
LLMResult,
PromptValue,
)
def _get_verbosity() -> bool:
    """Return the current value of the module-level ``langchain.verbose`` flag.

    Used as a ``default_factory`` so the flag is read when a model is
    instantiated rather than when the class is defined.
    """
    verbose_flag = langchain.verbose
    return verbose_flag
class BaseChatModel(BaseLanguageModel, ABC):
    """Abstract base class for chat models operating on lists of messages.

    Subclasses implement ``_generate``, ``_agenerate`` and ``_llm_type``.
    This class adds callback handling, batching over several message lists,
    and the convenience entry points ``__call__``, ``predict``,
    ``predict_messages`` and their async counterparts.
    """

    verbose: bool = Field(default_factory=_get_verbosity)
    """Whether to print out response text."""
    callbacks: Callbacks = Field(default=None, exclude=True)
    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn when ``callback_manager`` is supplied and move it to ``callbacks``.

        Args:
            values: The field values being validated.

        Returns:
            The same mapping; when ``callback_manager`` was set, its value has
            been popped and stored under ``callbacks`` instead.
        """
        # Guard clause: nothing to migrate unless the deprecated field is set.
        if values.get("callback_manager") is None:
            return values
        warnings.warn(
            "callback_manager is deprecated. Please use callbacks instead.",
            DeprecationWarning,
        )
        values["callbacks"] = values.pop("callback_manager", None)
        return values

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        """Merge the per-call ``llm_output`` values into one dict.

        This base implementation ignores its input and returns an empty dict.
        """
        return {}

    def generate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """Run the model once per message list and bundle the results.

        Args:
            messages: One list of messages per independent model call.
            stop: Optional stop sequences forwarded to ``_generate``.
            callbacks: Callbacks combined with ``self.callbacks`` for this run.

        Returns:
            An ``LLMResult`` holding the generations of every call plus the
            combined ``llm_output``.

        Raises:
            Whatever ``_generate`` raises; the error is first reported through
            ``run_manager.on_llm_error``.
        """
        invocation_params = self.dict()
        invocation_params["stop"] = stop
        manager = CallbackManager.configure(callbacks, self.callbacks, self.verbose)
        run_manager = manager.on_chat_model_start(
            {"name": self.__class__.__name__},
            messages,
            invocation_params=invocation_params,
        )
        # Only pass ``run_manager`` to implementations whose signature has it.
        accepts_run_manager = inspect.signature(self._generate).parameters.get(
            "run_manager"
        )
        try:
            results = []
            for message_list in messages:
                if accepts_run_manager:
                    chat_result = self._generate(
                        message_list, stop=stop, run_manager=run_manager
                    )
                else:
                    chat_result = self._generate(message_list, stop=stop)
                results.append(chat_result)
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_llm_error(e)
            raise e
        combined_output = self._combine_llm_outputs(
            [chat_result.llm_output for chat_result in results]
        )
        all_generations = [chat_result.generations for chat_result in results]
        output = LLMResult(generations=all_generations, llm_output=combined_output)
        run_manager.on_llm_end(output)
        return output

    async def agenerate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """Async variant of ``generate`` built on ``_agenerate``.

        The per-message-list calls are awaited together via ``asyncio.gather``.

        Args:
            messages: One list of messages per independent model call.
            stop: Optional stop sequences forwarded to ``_agenerate``.
            callbacks: Callbacks combined with ``self.callbacks`` for this run.

        Returns:
            An ``LLMResult`` holding the generations of every call plus the
            combined ``llm_output``.

        Raises:
            Whatever ``_agenerate`` raises; the error is first reported through
            ``run_manager.on_llm_error``.
        """
        invocation_params = self.dict()
        invocation_params["stop"] = stop
        manager = AsyncCallbackManager.configure(
            callbacks, self.callbacks, self.verbose
        )
        run_manager = await manager.on_chat_model_start(
            {"name": self.__class__.__name__},
            messages,
            invocation_params=invocation_params,
        )
        # Only pass ``run_manager`` to implementations whose signature has it.
        accepts_run_manager = inspect.signature(self._agenerate).parameters.get(
            "run_manager"
        )
        try:
            pending = []
            for message_list in messages:
                if accepts_run_manager:
                    coro = self._agenerate(
                        message_list, stop=stop, run_manager=run_manager
                    )
                else:
                    coro = self._agenerate(message_list, stop=stop)
                pending.append(coro)
            results = await asyncio.gather(*pending)
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_llm_error(e)
            raise e
        combined_output = self._combine_llm_outputs(
            [chat_result.llm_output for chat_result in results]
        )
        all_generations = [chat_result.generations for chat_result in results]
        output = LLMResult(generations=all_generations, llm_output=combined_output)
        await run_manager.on_llm_end(output)
        return output

    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """Convert each prompt with ``to_messages()`` and delegate to ``generate``."""
        message_lists = [prompt.to_messages() for prompt in prompts]
        return self.generate(message_lists, stop=stop, callbacks=callbacks)

    async def agenerate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """Convert each prompt with ``to_messages()`` and delegate to ``agenerate``."""
        message_lists = [prompt.to_messages() for prompt in prompts]
        return await self.agenerate(message_lists, stop=stop, callbacks=callbacks)

    @abstractmethod
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> ChatResult:
        """Produce a ``ChatResult`` for a single list of messages (sync)."""

    @abstractmethod
    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> ChatResult:
        """Produce a ``ChatResult`` for a single list of messages (async)."""

    def __call__(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> BaseMessage:
        """Run the model on one message list and return the first reply message.

        Raises:
            ValueError: If the first generation is not a ``ChatGeneration``.
        """
        llm_result = self.generate([messages], stop=stop, callbacks=callbacks)
        first_generation = llm_result.generations[0][0]
        if not isinstance(first_generation, ChatGeneration):
            raise ValueError("Unexpected generation type")
        return first_generation.message

    async def _call_async(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> BaseMessage:
        """Async counterpart of ``__call__`` built on ``agenerate``.

        Raises:
            ValueError: If the first generation is not a ``ChatGeneration``.
        """
        llm_result = await self.agenerate([messages], stop=stop, callbacks=callbacks)
        first_generation = llm_result.generations[0][0]
        if not isinstance(first_generation, ChatGeneration):
            raise ValueError("Unexpected generation type")
        return first_generation.message

    def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
        """Treat the chat model like a text LLM; thin wrapper over ``predict``."""
        reply_text = self.predict(message, stop=stop)
        return reply_text

    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
        """Send ``text`` as a single ``HumanMessage`` and return the reply content.

        Args:
            text: Content of the human message.
            stop: Optional stop sequences; copied into a list when given.
        """
        stop_list = None if stop is None else list(stop)
        reply = self([HumanMessage(content=text)], stop=stop_list)
        return reply.content

    def predict_messages(
        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
    ) -> BaseMessage:
        """Run the model on ``messages`` and return the reply message.

        Args:
            messages: The conversation to send to the model.
            stop: Optional stop sequences; copied into a list when given.
        """
        stop_list = None if stop is None else list(stop)
        return self(messages, stop=stop_list)

    async def apredict(
        self, text: str, *, stop: Optional[Sequence[str]] = None
    ) -> str:
        """Async version of ``predict``; goes through ``_call_async``.

        Args:
            text: Content of the human message.
            stop: Optional stop sequences; copied into a list when given.
        """
        stop_list = None if stop is None else list(stop)
        reply = await self._call_async([HumanMessage(content=text)], stop=stop_list)
        return reply.content

    async def apredict_messages(
        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
    ) -> BaseMessage:
        """Async version of ``predict_messages``; goes through ``_call_async``.

        Args:
            messages: The conversation to send to the model.
            stop: Optional stop sequences; copied into a list when given.
        """
        stop_list = None if stop is None else list(stop)
        return await self._call_async(messages, stop=stop_list)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters (empty in this base class)."""
        return {}

    @property
    @abstractmethod
    def _llm_type(self) -> str:
        """Return type of chat model."""

    def dict(self, **kwargs: Any) -> Dict:
        """Return the identifying parameters plus a ``"_type"`` entry.

        ``kwargs`` are accepted but not used by this implementation.
        """
        return {**self._identifying_params, "_type": self._llm_type}
class SimpleChatModel(BaseChatModel):
    """Chat model base whose subclasses only implement a string-returning ``_call``.

    ``_generate`` wraps the string from ``_call`` in an ``AIMessage`` /
    ``ChatGeneration`` / ``ChatResult``; ``_agenerate`` runs that synchronous
    path in the event loop's default executor.
    """

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> ChatResult:
        """Call ``_call`` and wrap its string output in a ``ChatResult``.

        Args:
            messages: The conversation to send to the model.
            stop: Optional stop sequences forwarded to ``_call``.
            run_manager: Optional callback manager forwarded to ``_call``.

        Returns:
            A ``ChatResult`` containing a single ``ChatGeneration`` whose
            message is an ``AIMessage`` holding the returned text.
        """
        output_str = self._call(messages, stop=stop, run_manager=run_manager)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @abstractmethod
    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Simpler interface: return the reply text for ``messages``."""

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> ChatResult:
        """Run the synchronous ``_generate`` in the loop's default executor.

        Args:
            messages: The conversation to send to the model.
            stop: Optional stop sequences forwarded to ``_generate``.
            run_manager: Optional callback manager forwarded to ``_generate``.

        Returns:
            The ``ChatResult`` produced by ``_generate``.
        """
        # NOTE(review): the async run manager is handed to the sync
        # ``_generate``/``_call`` path unchanged; confirm implementations
        # cope with that.
        func = partial(self._generate, messages, stop=stop, run_manager=run_manager)
        # Inside a coroutine a loop is always running; get_running_loop() is
        # the form the asyncio docs recommend over get_event_loop() here.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func)