mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
88a3a56c1a
# Add Spark SQL support * Add Spark SQL support. It can connect to Spark via building a local/remote SparkSession. * Include a notebook example I tried some complicated queries (window function, table joins), and the tool works well. Compared to the [Spark Dataframe agent](https://python.langchain.com/en/latest/modules/agents/toolkits/examples/spark.html), this tool is able to generate queries across multiple tables. --------- # Your PR Title (What it does) <!-- Thank you for contributing to LangChain! Your PR will appear in our next release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting <!-- If you're adding a new integration, include an integration test and an example notebook showing its use! --> ## Who can review? Community members can review the PR once tests pass. 
Tag maintainers/contributors who might be interested: <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Tracing / Callbacks - @agola11 Async - @agola11 DataLoaders - @eyurtsev Models - @hwchase17 - @agola11 Agents / Tools / Toolkits - @vowelparrot VectorStores / Retrievers / Memory - @dev2049 --> --------- Co-authored-by: Gengliang Wang <gengliang@apache.org> Co-authored-by: Mike W <62768671+skcoirz@users.noreply.github.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: UmerHA <40663591+UmerHA@users.noreply.github.com> Co-authored-by: 张城铭 <z@hyperf.io> Co-authored-by: assert <zhangchengming@kkguan.com> Co-authored-by: blob42 <spike@w530> Co-authored-by: Yuekai Zhang <zhangyuekai@foxmail.com> Co-authored-by: Richard He <he.yucheng@outlook.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com> Co-authored-by: Leonid Ganeline <leo.gan.57@gmail.com> Co-authored-by: Alexey Nominas <60900649+Chae4ek@users.noreply.github.com> Co-authored-by: elBarkey <elbarkey@gmail.com> Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com> Co-authored-by: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Co-authored-by: so2liu <yangliu35@outlook.com> Co-authored-by: Viswanadh Rayavarapu <44315599+vishwa-rn@users.noreply.github.com> Co-authored-by: Chakib Ben Ziane <contact@blob42.xyz> Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Co-authored-by: Daniel Chalef <daniel.chalef@private.org> Co-authored-by: Jari Bakken <jari.bakken@gmail.com> Co-authored-by: escafati <scafatieugenio@gmail.com>
175 lines
6.8 KiB
Python
175 lines
6.8 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Any, Iterable, List, Optional
|
|
|
|
if TYPE_CHECKING:
|
|
from pyspark.sql import DataFrame, Row, SparkSession
|
|
|
|
|
|
class SparkSQL:
    """Wrapper around a ``SparkSession`` exposing SQL helpers as plain strings.

    Provides table discovery, table descriptions (trimmed ``CREATE`` statement
    plus optional sample rows) and query execution whose results are rendered
    with ``str``.
    """

    def __init__(
        self,
        spark_session: Optional[SparkSession] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
    ) -> None:
        """Create the wrapper and resolve the set of usable tables.

        Args:
            spark_session: Session to use. When omitted (or falsy), one is
                obtained with ``SparkSession.builder.getOrCreate()``.
            catalog: If given, set as the session's current catalog.
            schema: If given, set as the session's current database.
            ignore_tables: Table names excluded from the usable tables.
            include_tables: If non-empty, only these tables are usable
                (takes precedence over ``ignore_tables``).
            sample_rows_in_table_info: Number of sample rows appended to each
                table description by ``get_table_info``; ``0`` disables them.

        Raises:
            ValueError: If pyspark cannot be imported, or if a name in
                ``include_tables`` / ``ignore_tables`` is not returned by
                ``SHOW TABLES``.
            TypeError: If ``sample_rows_in_table_info`` is not an ``int``.
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError as e:
            raise ValueError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            ) from e

        # Validate the cheap argument first so a bad call fails before any
        # Spark work (session creation, catalog switches, SHOW TABLES).
        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")
        self._sample_rows_in_table_info = sample_rows_in_table_info

        self._spark = (
            spark_session if spark_session else SparkSession.builder.getOrCreate()
        )
        if catalog is not None:
            self._spark.catalog.setCurrentCatalog(catalog)
        if schema is not None:
            self._spark.catalog.setCurrentDatabase(schema)

        self._all_tables = set(self._get_all_table_names())

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )

        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )

        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SparkSQL:
        """Create an instance backed by a remote Spark Connect session.

        For example: SparkSQL.from_uri("sc://localhost:15002")

        Args:
            database_uri: Spark Connect URI passed to ``SparkSession.builder.remote``.
            engine_args: Unused by this method (presumably kept for parity with
                other ``from_uri`` constructors).
            **kwargs: Forwarded to ``SparkSQL.__init__``.

        Raises:
            ValueError: If pyspark cannot be imported.
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError as e:
            raise ValueError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            ) from e

        spark = SparkSession.builder.remote(database_uri).getOrCreate()
        return cls(spark, **kwargs)

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available, sorted.

        Returns the ``include_tables`` if any were given, otherwise every table
        from ``SHOW TABLES`` minus ``ignore_tables``.
        """
        # Sorting the result can help the LLM understand it, and it also makes
        # the order deterministic (raw set iteration order is hash-dependent).
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    def _get_all_table_names(self) -> Iterable[str]:
        """Return the ``tableName`` column of ``SHOW TABLES``."""
        rows = self._spark.sql("SHOW TABLES").select("tableName").collect()
        return [row.tableName for row in rows]

    def _get_create_table_stmt(self, table: str) -> str:
        """Return the ``SHOW CREATE TABLE`` statement for ``table``, trimmed.

        The data source provider clause (``USING ...``) and everything after it
        is dropped to reduce the number of tokens. A statement without such a
        clause (e.g. a view definition) is returned whole.
        """
        statement = (
            self._spark.sql(f"SHOW CREATE TABLE {table}").collect()[0].createtab_stmt
        )
        # Ignore the data source provider and options to reduce the number of
        # tokens. The clause is expected at the start of its own line (column
        # definitions are indented) -- anchoring on the newline avoids cutting
        # at a "USING" inside e.g. a column comment.
        using_clause_index = statement.find("\nUSING")
        if using_clause_index == -1:
            # find() returned -1: slicing with it would drop the last char.
            return statement.rstrip() + ";"
        # Keep the newline so the output layout is "...)\n;" as before.
        return statement[: using_clause_index + 1] + ";"

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Args:
            table_names: Tables to describe; defaults to all usable tables.

        Raises:
            ValueError: If a requested table is not a usable table.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        tables = []
        for table_name in all_table_names:
            table_info = self._get_create_table_stmt(table_name)
            if self._sample_rows_in_table_info:
                table_info += "\n\n/*"
                table_info += f"\n{self._get_sample_spark_rows(table_name)}\n"
                table_info += "*/"
            tables.append(table_info)
        return "\n\n".join(tables)

    def _get_sample_spark_rows(self, table: str) -> str:
        """Render a header, the column names and up to N sample rows of ``table``.

        Columns and values are tab-separated. If fetching the rows fails, the
        sample section is left empty (deliberately best-effort).
        """
        query = f"SELECT * FROM {table} LIMIT {self._sample_rows_in_table_info}"
        df = self._spark.sql(query)
        columns_str = "\t".join(field.name for field in df.schema.fields)
        try:
            sample_rows = self._get_dataframe_results(df)
            # save the sample rows in string format
            sample_rows_str = "\n".join("\t".join(row) for row in sample_rows)
        except Exception:
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _convert_row_as_tuple(self, row: Row) -> tuple:
        """Convert a Row into a tuple of ``str``-converted field values."""
        return tuple(str(value) for value in row.asDict().values())

    def _get_dataframe_results(self, df: DataFrame) -> list:
        """Collect ``df`` and convert each row with ``_convert_row_as_tuple``."""
        return [self._convert_row_as_tuple(row) for row in df.collect()]

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute ``command`` and return ``str`` of the list of row tuples.

        Args:
            command: SQL statement passed to ``SparkSession.sql``.
            fetch: ``"one"`` limits the result to the first row; any other
                value returns all rows.
        """
        df = self._spark.sql(command)
        if fetch == "one":
            df = df.limit(1)
        return str(self._get_dataframe_results(df))

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Like ``get_table_info``, but return errors as ``"Error: <message>"``.

        Intended for tool boundaries where raising would abort the caller, so
        any exception is converted into an error string.
        """
        try:
            return self.get_table_info(table_names)
        except Exception as e:
            # Format the error message
            return f"Error: {e}"

    def run_no_throw(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, ``"[]"`` is returned.

        If the statement throws an error, the error message is returned as
        ``"Error: <message>"``.
        """
        # Catch broadly: pyspark < 3.4 has no ``pyspark.errors`` module, and
        # not every failure (e.g. py4j-wrapped Java errors) is a
        # PySparkException, yet this method promises never to raise.
        try:
            return self.run(command, fetch)
        except Exception as e:
            # Format the error message
            return f"Error: {e}"
|