mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
6bc8ae63ef
I'm using a hash function for the key just to make sure its length doesn't get out of hand; otherwise, the implementation is quite similar.
137 lines
4.8 KiB
Python
137 lines
4.8 KiB
Python
"""Beta Feature: base interface for cache."""
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from sqlalchemy import Column, Integer, String, create_engine, select
|
|
from sqlalchemy.engine.base import Engine
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import Session
|
|
|
|
from langchain.schema import Generation
|
|
|
|
# Value type stored by and returned from the cache backends in this module:
# the list of generations cached for one (prompt, llm_string) pair.
RETURN_VAL_TYPE = List[Generation]
|
|
|
|
|
|
class BaseCache(ABC):
    """Base interface for cache.

    A cache maps a ``(prompt, llm_string)`` pair to the list of generations
    an LLM produced for it. Subclasses must implement both methods.
    """

    @abstractmethod
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns:
            The cached generations for the pair, or ``None`` when nothing
            is cached.
        """
        ...

    @abstractmethod
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        Args:
            prompt: Prompt text the generations were produced for.
            llm_string: String identifying the LLM configuration.
            return_val: Generations to store for the pair.
        """
        ...
|
|
|
|
|
|
class InMemoryCache(BaseCache):
    """Cache that keeps its entries in a plain in-memory dictionary."""

    def __init__(self) -> None:
        """Start with nothing cached."""
        self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = dict()

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the generations cached for the pair, or ``None`` if absent."""
        key = (prompt, llm_string)
        return self._cache.get(key)

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store ``return_val`` for the pair, replacing any earlier entry."""
        key = (prompt, llm_string)
        self._cache[key] = return_val
|
|
|
|
|
|
# Declarative base for the ORM model(s) below; SQLAlchemyCache calls
# Base.metadata.create_all() on it to create their tables.
Base = declarative_base()
|
|
|
|
|
|
class FullLLMCache(Base): # type: ignore
    """SQLAlchemy ORM table for the full LLM cache (all generations).

    Each row holds the text of one generation. The composite primary key
    (prompt, llm, idx) lets several generations be stored per
    (prompt, llm) pair, ordered by ``idx``.
    """

    __tablename__ = "full_llm_cache"
    # Prompt text the generations were produced for.
    prompt = Column(String, primary_key=True)
    # The ``llm_string`` the generations were cached under.
    llm = Column(String, primary_key=True)
    # Position of this generation within the cached list.
    idx = Column(Integer, primary_key=True)
    # Text of the generation.
    response = Column(String)
|
|
|
|
|
|
class SQLAlchemyCache(BaseCache):
    """Cache that uses SQLAlchemy as a backend.

    Generations are stored one per row in the ``full_llm_cache`` table
    (``FullLLMCache``), keyed by prompt, llm_string and list position.
    """

    def __init__(self, engine: Engine):
        """Initialize by creating all tables.

        Args:
            engine: SQLAlchemy engine the cache reads from and writes to.
        """
        self.engine = engine
        Base.metadata.create_all(self.engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns:
            The cached generations in their stored order, or ``None`` if
            nothing is cached for the pair.
        """
        stmt = (
            select(FullLLMCache.response)
            .where(FullLLMCache.prompt == prompt)
            .where(FullLLMCache.llm == llm_string)
            .order_by(FullLLMCache.idx)
        )
        with Session(self.engine) as session:
            generations = [Generation(text=row[0]) for row in session.execute(stmt)]
        return generations or None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        Replaces whatever was cached for the pair with ``return_val``. All
        writes happen in a single transaction, so a failure part-way through
        leaves no partially written entry behind.

        Args:
            prompt: Prompt text the generations were produced for.
            llm_string: String identifying the LLM configuration.
            return_val: Generations to store, in order.
        """
        from sqlalchemy import delete

        items = [
            FullLLMCache(
                prompt=prompt, llm=llm_string, response=generation.text, idx=i
            )
            for i, generation in enumerate(return_val)
        ]
        with Session(self.engine) as session, session.begin():
            # Remove the old entry first: re-inserting an existing
            # (prompt, llm, idx) key would violate the primary key, and a
            # shorter new list would otherwise leave stale trailing rows.
            session.execute(
                delete(FullLLMCache)
                .where(FullLLMCache.prompt == prompt)
                .where(FullLLMCache.llm == llm_string)
            )
            session.add_all(items)
|
|
|
|
|
|
class SQLiteCache(SQLAlchemyCache):
    """Cache backed by a SQLite database file via SQLAlchemy."""

    def __init__(self, database_path: str = ".langchain.db"):
        """Build a SQLite engine for ``database_path`` and create the cache tables.

        Args:
            database_path: Path of the SQLite database file to use.
        """
        super().__init__(create_engine(f"sqlite:///{database_path}"))
|
|
|
|
|
|
class RedisCache(BaseCache):
    """Cache that uses Redis as a backend.

    Generation ``i`` for a (prompt, llm_string) pair is stored as a plain
    string value under the key ``"<sha256 hex digest>_<i>"`` (see ``_key``).
    """

    def __init__(self, redis_: Any):
        """Initialize by passing in Redis instance.

        Args:
            redis_: A ``redis.Redis`` client instance.

        Raises:
            ValueError: If the ``redis`` package cannot be imported, or if
                ``redis_`` is not a ``redis.Redis`` instance.
        """
        try:
            from redis import Redis
        except ImportError as e:
            raise ValueError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            ) from e
        if not isinstance(redis_, Redis):
            raise ValueError("Please pass in Redis object.")
        self.redis = redis_

    def _key(self, prompt: str, llm_string: str, idx: int) -> str:
        """Compute key from prompt, llm_string, and idx.

        The built-in ``hash()`` is deliberately not used: ``str`` hashes are
        salted per interpreter process, so keys would differ between
        processes and a shared Redis cache would never hit. The prompt
        length is prefixed so that pairs such as ("ab", "c") and ("a", "bc")
        map to different digests.
        """
        import hashlib

        payload = f"{len(prompt)}:{prompt}{llm_string}".encode("utf-8")
        return hashlib.sha256(payload).hexdigest() + "_" + str(idx)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Reads the keys for idx 0, 1, 2, ... until the first missing one.

        Returns:
            The cached generations, or ``None`` if nothing is cached.
        """
        generations = []
        idx = 0
        while True:
            # A single GET per index; only a missing key (None) ends the
            # list, so an empty-string generation is still returned.
            result = self.redis.get(self._key(prompt, llm_string, idx))
            if result is None:
                break
            if isinstance(result, bytes):
                result = result.decode()
            generations.append(Generation(text=result))
            idx += 1
        return generations or None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        All writes are sent in one redis-py pipeline (a single round trip;
        redis-py pipelines wrap commands in MULTI/EXEC by default). The key
        just past the last generation is deleted so that a previously cached,
        longer list is truncated, because ``lookup`` stops at the first
        missing index.
        """
        with self.redis.pipeline() as pipe:
            for i, generation in enumerate(return_val):
                pipe.set(self._key(prompt, llm_string, i), generation.text)
            pipe.delete(self._key(prompt, llm_string, len(return_val)))
            pipe.execute()
|