from __future__ import annotations
import typing as t
import uuid
from collections import defaultdict
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
class SparkSession:
    """Minimal stand-in for ``pyspark.sql.SparkSession`` that builds sqlglot
    expression trees instead of executing queries.

    Class-level registries (shared across all sessions) record every generated
    id so that DataFrames built from this session can disambiguate branches,
    sequences, and aliases when resolving column lineage.
    """

    # Shared registries: ids are generated per call but recorded globally.
    known_ids: t.ClassVar[t.Set[str]] = set()
    known_branch_ids: t.ClassVar[t.Set[str]] = set()
    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)

    def __init__(self):
        # Counter backing `_auto_incrementing_name`; per-session so generated
        # aliases (a1, a2, ...) are deterministic within a session.
        self.incrementing_id = 1

    def __getattr__(self, name: str) -> SparkSession:
        # Any attribute NOT defined on this class resolves back to the session
        # itself, so builder-style chains (e.g. `spark.builder.appName(...)`)
        # are harmless no-ops.
        return self

    def __call__(self, *args, **kwargs) -> SparkSession:
        # Combined with `__getattr__`, lets arbitrary chained calls such as
        # `.getOrCreate()` return the session unchanged.
        return self

    @property
    def read(self) -> DataFrameReader:
        """Return a `DataFrameReader` bound to this session."""
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        """Return a DataFrame representing the named table."""
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        """Build a DataFrame whose expression is a SELECT over a VALUES clause.

        Args:
            data: Non-empty sequence of rows, each a dict (column -> value),
                list, or tuple. Dict rows contribute values in insertion order.
            schema: Optional schema as a ``StructType``, a string, or a list of
                column-name strings. Any other form is unsupported.
            samplingRatio: Unsupported; must be ``None``.
            verifySchema: Unsupported; must be ``False``.

        Returns:
            A DataFrame wrapping ``SELECT <cols> FROM (VALUES ...) AS aN(...)``.

        Raises:
            NotImplementedError: For unsupported schema forms, or if
                ``samplingRatio``/``verifySchema`` are supplied.
            ValueError: If ``data`` is empty.
        """
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            # Guard `schema[0]` so an empty list raises the explanatory error
            # below instead of an IndexError.
            or (isinstance(schema, list) and (not schema or not isinstance(schema[0], str)))
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        # Map each output column name to an optional cast type (None = no cast).
        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            # Positional rows get Spark-style default names: _1, _2, ...
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        # One Tuple expression per row; dict rows contribute their values in
        # insertion order, matching the column mapping built above.
        data_expressions = [
            exp.Tuple(
                expressions=[
                    F.lit(value).expression
                    for value in (row.values() if isinstance(row, dict) else row)
                ]
            )
            for row in data
        ]

        # Cast each column to its schema type when one was provided.
        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression
            if data_type is not None
            else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                expressions=[
                    exp.Values(
                        expressions=data_expressions,
                        alias=exp.TableAlias(
                            this=exp.to_identifier(self._auto_incrementing_name),
                            columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                        ),
                    ),
                ],
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        """Parse a Spark SQL string into a DataFrame.

        Supports SELECT statements directly, and CREATE/INSERT statements whose
        source is a SELECT: the inner SELECT becomes the DataFrame and the
        CREATE/INSERT wrapper is kept as the output expression container.

        Raises:
            ValueError: If the statement is not a SELECT, CREATE, or INSERT.
        """
        expression = sqlglot.parse_one(sqlQuery, read="spark")
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                # Move the WITH clause from the INSERT onto the SELECT so the
                # CTEs stay attached to the query they belong to.
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        """Next deterministic alias for this session: a1, a2, ..."""
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_name(self) -> str:
        """A random alias that is *not* recorded in `known_ids`."""
        return f"a{str(uuid.uuid4())[:8]}"

    @property
    def _random_branch_id(self) -> str:
        """A fresh random id, additionally recorded as a branch id."""
        branch_id = self._random_id
        self.known_branch_ids.add(branch_id)
        return branch_id

    @property
    def _random_sequence_id(self) -> str:
        """A fresh random id, additionally recorded as a sequence id."""
        sequence_id = self._random_id
        self.known_sequence_ids.add(sequence_id)
        return sequence_id

    @property
    def _random_id(self) -> str:
        """A random id ('a' + first 8 hex chars of a UUID4), recorded in `known_ids`."""
        new_id = f"a{str(uuid.uuid4())[:8]}"
        self.known_ids.add(new_id)
        return new_id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        """Join hint names recognized when generating Spark SQL."""
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str) -> None:
        # Record that `name` was used as an alias for the given sequence id.
        self.name_to_sequence_id_mapping[name].append(sequence_id)