1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
|
from __future__ import annotations

import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot.helper import object_to_dict

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.dataframe import DataFrame
    from sqlglot.dataframe.sql.session import SparkSession
class DataFrameReader:
    """Loads registered tables as DataFrames bound to a SparkSession."""

    def __init__(self, spark: SparkSession):
        # Every DataFrame produced by this reader is attached to this session.
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        """Return a DataFrame that selects all schema columns of ``tableName``.

        The table is registered with ``sqlglot.schema`` first, and the
        projection is then built from ``sqlglot.schema.column_names``.
        """
        # Imported here because the module-level import of DataFrame only
        # happens under TYPE_CHECKING (presumably to avoid an import cycle).
        from sqlglot.dataframe.sql.dataframe import DataFrame

        sqlglot.schema.add_table(tableName)
        query = exp.Select().from_(tableName)
        query = query.select(*sqlglot.schema.column_names(tableName))
        return DataFrame(self.spark, query)
class DataFrameWriter:
    """Builds write operations (INSERT / CREATE TABLE) on top of a DataFrame.

    Configuring methods never mutate ``self``: each returns a new writer
    produced by :meth:`copy`.
    """

    def __init__(
        self,
        df: DataFrame,
        spark: t.Optional[SparkSession] = None,
        mode: t.Optional[str] = None,
        by_name: bool = False,
    ):
        self._df = df
        # Default to the session the DataFrame itself is bound to.
        self._spark = spark or df.spark
        # Save mode used by saveAsTable when no explicit mode is passed to it.
        self._mode = mode
        # When True, insertInto projects the target table's visible columns.
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        """Return a new writer built from this writer's attributes, overridden by ``kwargs``.

        Callers pass overrides under the private attribute names (``_df``,
        ``_mode``, ``_by_name``). The leading underscore is stripped so the
        keys line up with ``__init__``'s parameter names.
        """
        return DataFrameWriter(
            **{
                k[1:] if k.startswith("_") else k: v
                for k, v in object_to_dict(self, **kwargs).items()
            }
        )

    def sql(self, **kwargs) -> t.List[str]:
        """Render the wrapped DataFrame as SQL statements.

        ``kwargs`` are forwarded unchanged to the DataFrame's ``sql`` method.
        """
        return self._df.sql(**kwargs)

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        """Return a copy of this writer whose default save mode is ``saveMode``."""
        return self.copy(_mode=saveMode)

    @property
    def byName(self) -> DataFrameWriter:
        """Return a copy of this writer with by-name column insertion enabled."""
        # Override the private attribute key directly, like every other copy()
        # caller, rather than relying on a later dict key shadowing ``_by_name``.
        return self.copy(_by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        """Return a writer whose DataFrame is wrapped in an INSERT into ``tableName``.

        Args:
            tableName: Target table name.
            overwrite: Passed through as the INSERT expression's ``overwrite`` arg.

        Returns:
            A new writer; ``self`` is left unchanged.
        """
        output_expression_container = exp.Insert(
            this=exp.to_table(tableName),
            overwrite=overwrite,
        )
        df = self._df.copy(output_expression_container=output_expression_container)
        if self._by_name:
            # Select the target table's visible columns from the DataFrame after
            # converting its leaf into a CTE.
            columns = sqlglot.schema.column_names(tableName, only_visible=True)
            df = df._convert_leaf_to_cte().select(*columns)
        return self.copy(_df=df)

    def saveAsTable(
        self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None
    ) -> DataFrameWriter:
        """Return a writer that saves the DataFrame as table ``name``.

        Args:
            name: Target table name.
            format: Not supported; must be ``None``.
            mode: Save mode. It overrides the one set via :meth:`mode` and is
                compared case-insensitively, as Spark does:

                - ``"append"``: delegates to :meth:`insertInto`.
                - ``"ignore"``: CREATE with ``exists=True``.
                - ``"overwrite"``: CREATE with ``replace=True``.
                - anything else (e.g. ``"error"`` or no mode): plain CREATE TABLE.

        Returns:
            A new writer; ``self`` is left unchanged.

        Raises:
            NotImplementedError: If ``format`` is given.
        """
        if format is not None:
            raise NotImplementedError("Providing Format in the save as table is not supported")
        # Normalise once: missing mode -> "" (matches nothing below), and
        # lower-case so "Overwrite"/"APPEND" behave like Spark's parser.
        save_mode = (mode or self._mode or "").lower()
        if save_mode == "append":
            return self.insertInto(name)
        output_expression_container = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=True if save_mode == "ignore" else None,
            replace=True if save_mode == "overwrite" else None,
        )
        return self.copy(
            _df=self._df.copy(output_expression_container=output_expression_container)
        )
|