summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/readwriter.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe/sql/readwriter.py')
-rw-r--r--sqlglot/dataframe/sql/readwriter.py79
1 files changed, 79 insertions, 0 deletions
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
new file mode 100644
index 0000000..4830035
--- /dev/null
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+import typing as t
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.helper import object_to_dict
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+ from sqlglot.dataframe.sql.session import SparkSession
+
+
class DataFrameReader:
    """Builds DataFrame objects on behalf of a SparkSession."""

    def __init__(self, spark: SparkSession):
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        """Return a DataFrame that selects the columns of ``tableName``.

        The table name is first passed to ``sqlglot.schema.add_table``; the column
        list used in the SELECT is then read from ``sqlglot.schema.column_names``.
        """
        # Imported locally: the module-level import of DataFrame is TYPE_CHECKING-only.
        from sqlglot.dataframe.sql.dataframe import DataFrame

        sqlglot.schema.add_table(tableName)
        base_query = exp.Select().from_(tableName)
        column_names = sqlglot.schema.column_names(tableName)
        return DataFrame(self.spark, base_query.select(*column_names))
+
+
class DataFrameWriter:
    """Pairs a DataFrame with write settings (session, save mode, by-name flag).

    The methods that change a setting or attach an output expression return a new
    writer built through ``copy`` rather than mutating this one.
    """

    def __init__(
        self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
    ):
        self._df = df
        # Fall back to the DataFrame's own session when none (or a falsy one) is given.
        self._spark = spark if spark else df.spark
        self._mode = mode
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        """Return a new writer built from this one's state, with ``kwargs`` applied.

        Keys may be given with or without the leading underscore of the private
        attribute names; the underscore is stripped before calling the constructor.
        """
        # NOTE(review): presumably object_to_dict returns the instance attributes
        # overlaid with kwargs - verify against sqlglot.helper.
        state = object_to_dict(self, **kwargs)
        init_kwargs = {}
        for key, value in state.items():
            param_name = key[1:] if key.startswith("_") else key
            # A later entry that maps to the same parameter replaces the earlier one.
            init_kwargs[param_name] = value
        return DataFrameWriter(**init_kwargs)

    def sql(self, **kwargs) -> t.List[str]:
        """Delegate to the wrapped DataFrame's ``sql`` and return its result."""
        return self._df.sql(**kwargs)

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        """Return a copy of this writer whose save mode is ``saveMode``."""
        return self.copy(_mode=saveMode)

    @property
    def byName(self):
        """Return a copy of this writer with the by-name flag switched on."""
        return self.copy(by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        """Return a writer whose DataFrame is wrapped in an INSERT into ``tableName``.

        When the by-name flag is set, the DataFrame is additionally re-selected using
        the table's visible column names from ``sqlglot.schema``.
        """
        insert_expression = exp.Insert(this=exp.to_table(tableName), overwrite=overwrite)
        target_df = self._df.copy(output_expression_container=insert_expression)
        if self._by_name:
            visible_columns = sqlglot.schema.column_names(tableName, only_visible=True)
            target_df = target_df._convert_leaf_to_cte().select(*visible_columns)

        return self.copy(_df=target_df)

    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
        """Return a writer whose DataFrame is wrapped in a CREATE TABLE for ``name``.

        The explicit ``mode`` argument takes precedence over the writer's stored mode.
        ``"append"`` is routed to ``insertInto``; ``"ignore"`` sets ``exists`` and
        ``"overwrite"`` sets ``replace`` on the Create expression.

        Raises:
            NotImplementedError: if ``format`` is provided.
        """
        if format is not None:
            raise NotImplementedError("Providing Format in the save as table is not supported")

        # str() keeps the comparisons below string-to-string even when no mode was stored.
        effective_mode = mode or str(self._mode)
        if effective_mode == "append":
            return self.insertInto(name)

        exists = True if effective_mode == "ignore" else None
        replace = True if effective_mode == "overwrite" else None
        create_expression = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=exists,
            replace=replace,
        )
        return self.copy(_df=self._df.copy(output_expression_container=create_expression))