mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
68 lines
2.0 KiB
Python
68 lines
2.0 KiB
Python
from typing import Optional, Tuple
|
|
|
|
import sqlalchemy
|
|
from pgvector.sqlalchemy import Vector
|
|
from sqlalchemy.dialects.postgresql import JSON, UUID
|
|
from sqlalchemy.orm import Session, relationship
|
|
|
|
from langchain.vectorstores.pgvector import BaseModel
|
|
|
|
|
|
class CollectionStore(BaseModel):
    """ORM model for a named collection of embeddings.

    Rows live in the ``langchain_pg_collection`` table. ``EmbeddingStore``
    rows reference a collection through their ``collection_id`` column and
    are reachable from it via the ``embeddings`` relationship.
    """

    __tablename__ = "langchain_pg_collection"

    # Collection name used for lookups. NOTE(review): the column is not
    # declared unique, so duplicate names are possible; get_by_name returns
    # only the first match.
    name = sqlalchemy.Column(sqlalchemy.String)
    # Arbitrary JSON metadata attached to the collection.
    cmetadata = sqlalchemy.Column(JSON)

    # One-to-many link to the embeddings in this collection.
    # passive_deletes=True tells the ORM not to load and delete child rows
    # itself when a collection is deleted, leaving that to the database
    # (presumably via the ON DELETE CASCADE on the embedding foreign key).
    embeddings = relationship(
        "EmbeddingStore",
        back_populates="collection",
        passive_deletes=True,
    )

    @classmethod
    def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
        """Return the first collection whose ``name`` equals *name*, or ``None``."""
        return session.query(cls).filter(cls.name == name).first()  # type: ignore

    @classmethod
    def get_or_create(
        cls,
        session: Session,
        name: str,
        cmetadata: Optional[dict] = None,
    ) -> Tuple["CollectionStore", bool]:
        """Get a collection by name, creating it if it does not exist.

        Args:
            session: Session used for the lookup and, if needed, the insert.
            name: Name of the collection to fetch or create.
            cmetadata: Metadata stored on a newly created collection. It is
                ignored when a collection named ``name`` already exists.

        Returns:
            A ``(collection, created)`` tuple, where ``created`` is ``True``
            only if this call inserted a new collection.
        """
        collection = cls.get_by_name(session, name)
        # Compare against None explicitly: get_by_name signals "not found"
        # with None, not with a falsy object.
        if collection is not None:
            return collection, False

        collection = cls(name=name, cmetadata=cmetadata)
        session.add(collection)
        # Note: this commits the session's current transaction, including any
        # other pending changes the caller has made on the same session.
        session.commit()
        return collection, True
|
|
|
|
|
|
class EmbeddingStore(BaseModel):
    """ORM model for a single stored embedding, kept in ``langchain_pg_embedding``."""

    __tablename__ = "langchain_pg_embedding"

    # Foreign key to the owning collection's ``uuid`` column; ON DELETE
    # CASCADE lets the database drop embeddings along with their collection.
    collection_id = sqlalchemy.Column(
        UUID(as_uuid=True),
        sqlalchemy.ForeignKey(CollectionStore.__tablename__ + ".uuid", ondelete="CASCADE"),
    )
    # ORM-side link back to the owning collection (pairs with its ``embeddings``).
    collection = relationship(CollectionStore, back_populates="embeddings")

    # pgvector column; Vector(None) declares it without a fixed dimension.
    embedding: Vector = sqlalchemy.Column(Vector(None))
    # Optional text stored alongside the vector (presumably the embedded document).
    document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
    # Optional JSON metadata for this embedding.
    cmetadata = sqlalchemy.Column(JSON, nullable=True)

    # Optional caller-supplied identifier (any user-defined id).
    custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|