
sqlglot.dataframe.sql

from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
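
A minimal end-to-end sketch of how these pieces fit together (the `employee` table and its columns are illustrative; registering the schema with `sqlglot.schema.add_table` lets the optimizer qualify and type columns):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Register an (illustrative) table schema so column references can be resolved.
sqlglot.schema.add_table("employee", {"employee_id": "INT", "fname": "STRING", "age": "INT"})

df = (
    SparkSession()
    .table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)

print(df.sql(pretty=True)[0])  # .sql() returns a list of SQL strings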
class SparkSession:
class SparkSession:
    known_ids: t.ClassVar[t.Set[str]] = set()
    known_branch_ids: t.ClassVar[t.Set[str]] = set()
    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)

    def __init__(self):
        self.incrementing_id = 1

    def __getattr__(self, name: str) -> SparkSession:
        return self

    def __call__(self, *args, **kwargs) -> SparkSession:
        return self

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.Tuple(
                expressions=list(
                    map(
                        lambda x: F.lit(x).expression,
                        row if not isinstance(row, dict) else row.values(),
                    )
                )
            )
            for row in data
        ]

        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression
            if data_type is not None
            else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                expressions=[
                    exp.Values(
                        expressions=data_expressions,
                        alias=exp.TableAlias(
                            this=exp.to_identifier(self._auto_incrementing_name),
                            columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                        ),
                    ),
                ],
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = sqlglot.parse_one(sqlQuery, read="spark")
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_name(self) -> str:
        return "r" + uuid.uuid4().hex

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = self._random_name
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)
SparkSession()
Creates a session. Each session keeps an incrementing counter used to generate unique auto-incrementing table alias names; known ids and the name-to-sequence-id mapping are tracked at the class level.
def table(self, tableName: str) -> sqlglot.dataframe.sql.DataFrame:
Returns a DataFrame over the given table, equivalent to `self.read.table(tableName)`.
def createDataFrame( self, data: Sequence[Union[Dict[str, ColumnLiterals], List[ColumnLiterals], Tuple]], schema: Optional[SchemaInput] = None, samplingRatio: Optional[float] = None, verifySchema: bool = False) -> sqlglot.dataframe.sql.DataFrame:
Creates a DataFrame from rows supplied in memory by compiling them into a VALUES expression. `data` may be a sequence of dicts, lists, or tuples; `schema` may be a `StructType`, a string, or a list of column-name strings. `samplingRatio` and `verifySchema` are not supported and raise `NotImplementedError`. When no schema is given, column names are taken from dict keys, or default to `_1`, `_2`, and so on for tuples and lists.
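
A short sketch of the accepted input shapes (the column names here are arbitrary):

spark = SparkSession()

# Tuples plus a list-of-names schema
df1 = spark.createDataFrame([(1, "Jack"), (2, "Jill")], schema=["id", "name"])

# Dicts: column names are taken from the keys
df2 = spark.createDataFrame([{"id": 1, "name": "Jack"}])

print(df1.sql()[0])  # a single SELECT over a VALUES expression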
def sql(self, sqlQuery: str) -> sqlglot.dataframe.sql.DataFrame:
Converts a Spark SQL string into a DataFrame. A SELECT becomes the DataFrame's expression directly; for CREATE and INSERT the inner select is extracted and the outer statement is kept as the output expression container so it is re-emitted by `DataFrame.sql`. Any other statement type raises a `ValueError`.
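
For example, parsing Spark SQL and emitting it in another dialect (assuming the `employee` table was registered with `sqlglot.schema.add_table` as above):

spark = SparkSession()
df = spark.sql("SELECT employee_id FROM employee WHERE age > 30")
print(df.sql(dialect="bigquery")[0])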
class DataFrame:
class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
            expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self.spark._random_name
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
        ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
        ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        # Move the current expression into a CTE and select from it.
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in expression.find(exp.Select).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Select):
        value = expression.sql(dialect="spark").encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                select_expression = optimize_func(select_expression)
            select_expression = df._replace_cte_names_with_hashes(select_expression)
            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql("spark")
                        for expression in select_expression.expressions
                    },
                )
                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )
                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)
                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")
            output_expressions.append(expression)

        return [
            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
        ]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [col for col in cols if not col.column_expression.table]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # If the select column does not specify a table and there is a join
                    # then we assume they are referring to the left table
                    if len(ctes_with_column) > 1:
                        table_identifier = self.expression.args["from"].args["expressions"][0].this
                    else:
                        table_identifier = ctes_with_column[0].args["alias"].this
                    ambiguous_col.expression.set("table", table_identifier)
        expression = self.expression.select(*[x.expression for x in cols], **kwargs)
        qualify_columns(expression, sqlglot.schema)
        return self.copy(expression=expression, **kwargs)

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        pre_join_self_latest_cte_name = self.latest_cte_name
        columns = self._ensure_and_normalize_cols(on)
        join_type = how.replace("_", " ")
        if isinstance(columns[0].expression, exp.Column):
            join_columns = [
                Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
            ]
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [
                    col.copy().set_table_name(pre_join_self_latest_cte_name)
                    == col.copy().set_table_name(other_df.latest_cte_name)
                    for col in columns
                ],
            )
        else:
            if len(columns) > 1:
                columns = [functools.reduce(lambda x, y: x & y, columns)]
            join_clause = columns[0]
            join_columns = [
                Column(x).set_table_name(pre_join_self_latest_cte_name)
                if i % 2 == 0
                else Column(x).set_table_name(other_df.latest_cte_name)
                for i, x in enumerate(join_clause.expression.find_all(exp.Column))
            ]
        self_columns = [
            column.set_table_name(pre_join_self_latest_cte_name, copy=True)
            for column in self._get_outer_select_columns(self)
        ]
        other_columns = [
            column.set_table_name(other_df.latest_cte_name, copy=True)
            for column in self._get_outer_select_columns(other_df)
        ]
        column_value_mapping = {
            column.alias_or_name
            if not isinstance(column.expression.this, exp.Star)
            else column.sql(): column
            for column in other_columns + self_columns + join_columns
        }
        all_columns = [
            column_value_mapping[name]
            for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
        ]
        new_df = self.copy(
            expression=self.expression.join(
                other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
            )
        )
        new_df.expression = new_df._add_ctes_to_expression(
            new_df.expression, other_df.expression.ctes
        )
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *all_columns)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
        is unlikely to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            x
            for x in [
                i if isinstance(col.expression, exp.Ordered) else None
                for i, col in enumerate(columns)
            ]
            if x is not None
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        ascending = [bool(x) for x in ascending]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            exp.Ordered(this=col.expression, desc=not asc)
            if i not in pre_ordered_col_indexes
            else columns[i].column_expression
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # will be determined later if thresh is null
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
        if minimum_num_nulls > len(null_check_columns):
            raise RuntimeError(
                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
            )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: t.Union[ColumnLiterals],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality Difference: If you provide a value to replace a null and that type conflicts
        with the type of the column then PySpark will just ignore your replacement.
        This will try to cast them to be the same in some cases. So they won't always match.
        Best to not mix types, so make sure the replacement is the same type as the column.

        Possibility for improvement: Use the `typeof` function to get the type of the column
        and check if it matches the type of the value provided. If not then make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        # Index 0 is falsy, so check for None explicitly when replacing the first column
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column.copy(), new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
DataFrame( spark: sqlglot.dataframe.sql.SparkSession, expression: sqlglot.expressions.Select, branch_id: Optional[str] = None, sequence_id: Optional[str] = None, last_op: sqlglot.dataframe.sql.operations.Operation = <Operation.INIT: -1>, pending_hints: Optional[List[sqlglot.expressions.Expression]] = None, output_expression_container: Optional[OutputExpressionContainer] = None, **kwargs)
Wraps a sqlglot `Select` expression together with the session and bookkeeping state: branch id, sequence id, the last operation applied, pending hints, and the output expression container used for CREATE/INSERT output.
def sql(self, dialect='spark', optimize=True, **kwargs) -> List[str]:
Renders this DataFrame as SQL. Pending hints are resolved, CTE names are replaced with content hashes, and the expression is optionally run through sqlglot's optimizer. Cached CTEs are emitted as separate CACHE TABLE statements (each preceded by DROP VIEW IF EXISTS), which is why a list of statements is returned. `dialect` and any extra keyword arguments are forwarded to sqlglot's generator.
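
For example (extra keyword arguments such as `pretty` are forwarded to the generator; a cached DataFrame would yield multiple statements here):

for statement in df.sql(dialect="duckdb", pretty=True):
    print(statement)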
def copy(self, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
Returns a copy of this DataFrame, with any attribute overridden by the given keyword arguments.
@operation(Operation.SELECT)
def select(self, *cols, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
Projects the given columns. If the underlying expression contains joins and a selected column does not specify a table, a column present in more than one joined CTE is assumed to refer to the left table.
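
For example (column names are illustrative):

df = spark.table("employee")
df = df.select("employee_id", F.col("age").alias("employee_age"))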
@operation(Operation.NO_OP)
def alias(self, name: str, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
Gives this DataFrame a name that join hints can later refer to, recording the name against a fresh sequence id.
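
A sketch of the intended use, naming a DataFrame so a later join hint can target it (table, column, and alias names are illustrative):

emp = spark.table("employee").alias("emp")
store = spark.table("store")
joined = store.join(emp, on="store_id", how="inner").hint("broadcast", "emp")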
@operation(Operation.WHERE)
def where( self, column: Union[sqlglot.dataframe.sql.Column, bool], **kwargs) -> sqlglot.dataframe.sql.DataFrame:
Filters rows using the given Column expression.
@operation(Operation.WHERE)
def filter( self, column: Union[sqlglot.dataframe.sql.Column, bool], **kwargs) -> sqlglot.dataframe.sql.DataFrame:
Alias of `where`.
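
For example, both forms generate the same WHERE clause:

adults = spark.table("employee").where(F.col("age") >= F.lit(21))
adults = spark.table("employee").filter(F.col("age") >= F.lit(21))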
@operation(Operation.GROUP_BY)
def groupBy(self, *cols, **kwargs) -> sqlglot.dataframe.sql.GroupedData:
Groups by the given columns and returns a `GroupedData` object for aggregation.
@operation(Operation.SELECT)
def agg(self, *exprs, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
Aggregates over the whole DataFrame, shorthand for `groupBy().agg(*exprs)`.
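
For example (`F.max` is assumed from `sqlglot.dataframe.sql.functions`; column names are illustrative):

# Ungrouped aggregate via agg()
oldest = spark.table("employee").agg(F.max(F.col("age")).alias("max_age"))

# Grouped aggregate via groupBy().agg()
per_name = spark.table("employee").groupBy("fname").agg(F.max(F.col("age")).alias("max_age"))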
@operation(Operation.FROM)
def join( self, other_df: sqlglot.dataframe.sql.DataFrame, on: Union[str, List[str], sqlglot.dataframe.sql.Column, List[sqlglot.dataframe.sql.Column]], how: str = 'inner', **kwargs) -> sqlglot.dataframe.sql.DataFrame:
Joins this DataFrame to another. `on` may be column names, in which case an equality condition is generated for each, or arbitrary Column expressions; underscores in `how` are converted to spaces (e.g. `left_outer` becomes `LEFT OUTER`). The other DataFrame's CTEs and pending hints are merged into the result, and join columns are deduplicated in the output selection.
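
For example (assuming both tables' schemas were registered with `sqlglot.schema.add_table`; names are illustrative):

emp = spark.table("employee")
store = spark.table("store")
joined = emp.join(store, on=["store_id"], how="left_outer")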
@operation(Operation.ORDER_BY)
def orderBy( self, *cols: Union[str, sqlglot.dataframe.sql.Column], ascending: Union[Any, List[Any], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
476    @operation(Operation.ORDER_BY)
477    def orderBy(
478        self,
479        *cols: t.Union[str, Column],
480        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
481    ) -> DataFrame:
482        """
483        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
484        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
485        is unlikely to come up.
486        """
487        columns = self._ensure_and_normalize_cols(cols)
488        pre_ordered_col_indexes = [
489            x
490            for x in [
491                i if isinstance(col.expression, exp.Ordered) else None
492                for i, col in enumerate(columns)
493            ]
494            if x is not None
495        ]
496        if ascending is None:
497            ascending = [True] * len(columns)
498        elif not isinstance(ascending, list):
499            ascending = [ascending] * len(columns)
500        ascending = [bool(x) for x in ascending]
501        assert len(columns) == len(
502            ascending
503        ), "The length of items in ascending must equal the number of columns provided"
504        col_and_ascending = list(zip(columns, ascending))
505        order_by_columns = [
506            exp.Ordered(this=col.expression, desc=not asc)
507            if i not in pre_ordered_col_indexes
508            else columns[i].column_expression
509            for i, (col, asc) in enumerate(col_and_ascending)
510        ]
511        return self.copy(expression=self.expression.order_by(*order_by_columns))

This implementation lets any pre-ordered columns take priority over whatever is provided in ascending. Spark's behavior here is irregular and can result in runtime errors; users shouldn't mix the two anyway, so this is unlikely to come up.
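
For example, a column already wrapped by asc()/desc() keeps its own ordering and the matching ascending entry is ignored. A minimal sketch (the table and columns are illustrative):

    import sqlglot
    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"name": "STRING", "age": "INT"})
    df = SparkSession().table("employee")

    # "age" is pre-ordered via desc(), so it keeps DESC; the ascending flag
    # only takes effect for "name".
    ordered = df.orderBy(F.col("age").desc(), "name", ascending=[True, True])
    print(ordered.sql(pretty=True)[0])
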

@operation(Operation.ORDER_BY)
def sort( self, *cols: Union[str, sqlglot.dataframe.sql.Column], ascending: Union[Any, List[Any], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
476    @operation(Operation.ORDER_BY)
477    def orderBy(
478        self,
479        *cols: t.Union[str, Column],
480        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
481    ) -> DataFrame:
482        """
483        This implementation lets any pre-ordered columns take priority over whatever is provided in
484        `ascending`. Spark's behavior here is irregular and can result in runtime errors. Users shouldn't
485        mix the two anyway, so this is unlikely to come up.
486        """
487        columns = self._ensure_and_normalize_cols(cols)
488        pre_ordered_col_indexes = [
489            x
490            for x in [
491                i if isinstance(col.expression, exp.Ordered) else None
492                for i, col in enumerate(columns)
493            ]
494            if x is not None
495        ]
496        if ascending is None:
497            ascending = [True] * len(columns)
498        elif not isinstance(ascending, list):
499            ascending = [ascending] * len(columns)
500        ascending = [bool(x) for x in ascending]
501        assert len(columns) == len(
502            ascending
503        ), "The length of items in ascending must equal the number of columns provided"
504        col_and_ascending = list(zip(columns, ascending))
505        order_by_columns = [
506            exp.Ordered(this=col.expression, desc=not asc)
507            if i not in pre_ordered_col_indexes
508            else columns[i].column_expression
509            for i, (col, asc) in enumerate(col_and_ascending)
510        ]
511        return self.copy(expression=self.expression.order_by(*order_by_columns))

This implementation lets any pre-ordered columns take priority over whatever is provided in ascending. Spark's behavior here is irregular and can result in runtime errors; users shouldn't mix the two anyway, so this is unlikely to come up.

@operation(Operation.FROM)
def union( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
515    @operation(Operation.FROM)
516    def union(self, other: DataFrame) -> DataFrame:
517        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionAll( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
515    @operation(Operation.FROM)
516    def union(self, other: DataFrame) -> DataFrame:
517        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionByName( self, other: sqlglot.dataframe.sql.DataFrame, allowMissingColumns: bool = False):
521    @operation(Operation.FROM)
522    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
523        l_columns = self.columns
524        r_columns = other.columns
525        if not allowMissingColumns:
526            l_expressions = l_columns
527            r_expressions = l_columns
528        else:
529            l_expressions = []
530            r_expressions = []
531            r_columns_unused = copy(r_columns)
532            for l_column in l_columns:
533                l_expressions.append(l_column)
534                if l_column in r_columns:
535                    r_expressions.append(l_column)
536                    r_columns_unused.remove(l_column)
537                else:
538                    r_expressions.append(exp.alias_(exp.Null(), l_column))
539            for r_column in r_columns_unused:
540                l_expressions.append(exp.alias_(exp.Null(), r_column))
541                r_expressions.append(r_column)
542        r_df = (
543            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
544        )
545        l_df = self.copy()
546        if allowMissingColumns:
547            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
548        return l_df._set_operation(exp.Union, r_df, False)
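
A sketch of the NULL-filling behavior when allowMissingColumns=True (illustrative schemas; DataFrame.sql() returning a list of statements is assumed):

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    # Illustrative schemas with partially overlapping columns.
    sqlglot.schema.add_table("t1", {"a": "INT", "b": "INT"})
    sqlglot.schema.add_table("t2", {"b": "INT", "c": "INT"})

    spark = SparkSession()
    # With allowMissingColumns=True, "c" is NULL-filled on the left side
    # and "a" is NULL-filled on the right side before the UNION.
    unioned = spark.table("t1").unionByName(spark.table("t2"), allowMissingColumns=True)
    print(unioned.sql(pretty=True)[0])
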
@operation(Operation.FROM)
def intersect( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
550    @operation(Operation.FROM)
551    def intersect(self, other: DataFrame) -> DataFrame:
552        return self._set_operation(exp.Intersect, other, True)
@operation(Operation.FROM)
def intersectAll( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
554    @operation(Operation.FROM)
555    def intersectAll(self, other: DataFrame) -> DataFrame:
556        return self._set_operation(exp.Intersect, other, False)
@operation(Operation.FROM)
def exceptAll( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
558    @operation(Operation.FROM)
559    def exceptAll(self, other: DataFrame) -> DataFrame:
560        return self._set_operation(exp.Except, other, False)
@operation(Operation.SELECT)
def distinct(self) -> sqlglot.dataframe.sql.DataFrame:
562    @operation(Operation.SELECT)
563    def distinct(self) -> DataFrame:
564        return self.copy(expression=self.expression.distinct())
@operation(Operation.SELECT)
def dropDuplicates(self, subset: Optional[List[str]] = None):
566    @operation(Operation.SELECT)
567    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
568        if not subset:
569            return self.distinct()
570        column_names = ensure_list(subset)
571        window = Window.partitionBy(*column_names).orderBy(*column_names)
572        return (
573            self.copy()
574            .withColumn("row_num", F.row_number().over(window))
575            .where(F.col("row_num") == F.lit(1))
576            .drop("row_num")
577        )
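
Since dropDuplicates with a subset is implemented via a ROW_NUMBER window rather than DISTINCT, the generated SQL looks roughly like the sketch below (table and columns are illustrative):

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"id": "INT", "name": "STRING"})
    df = SparkSession().table("employee")

    # Keeps one row per "name": ROW_NUMBER() OVER (PARTITION BY name ...) = 1,
    # then the helper row_num column is dropped.
    deduped = df.dropDuplicates(["name"])
    print(deduped.sql(pretty=True)[0])
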
@operation(Operation.FROM)
def dropna( self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
579    @operation(Operation.FROM)
580    def dropna(
581        self,
582        how: str = "any",
583        thresh: t.Optional[int] = None,
584        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
585    ) -> DataFrame:
586        minimum_non_null = thresh or 0  # will be determined later if thresh is null
587        new_df = self.copy()
588        all_columns = self._get_outer_select_columns(new_df.expression)
589        if subset:
590            null_check_columns = self._ensure_and_normalize_cols(subset)
591        else:
592            null_check_columns = all_columns
593        if thresh is None:
594            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
595        else:
596            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
597        if minimum_num_nulls > len(null_check_columns):
598            raise RuntimeError(
599                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
600                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
601            )
602        if_null_checks = [
603            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
604        ]
605        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
606        num_nulls = nulls_added_together.alias("num_nulls")
607        new_df = new_df.select(num_nulls, append=True)
608        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
609        final_df = filtered_df.select(*all_columns)
610        return final_df
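
A sketch of the null-count arithmetic above: with how="any" a single NULL drops the row, while thresh=2 keeps rows with at least two non-NULL values among the checked columns (illustrative schema):

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("t", {"a": "INT", "b": "INT", "c": "INT"})
    df = SparkSession().table("t")

    # how="any": one NULL among the checked columns drops the row.
    any_null_dropped = df.dropna(how="any")
    # thresh=2: keep rows with at least 2 non-NULL values among a, b, c,
    # i.e. rows whose null count is below 3 - 2 + 1 = 2.
    at_least_two = df.dropna(thresh=2, subset=["a", "b", "c"])
    print(at_least_two.sql(pretty=True)[0])
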
@operation(Operation.FROM)
def fillna( self, value: ColumnLiterals, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
612    @operation(Operation.FROM)
613    def fillna(
614        self,
615        value: t.Union[ColumnLiterals],
616        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
617    ) -> DataFrame:
618        """
619        Functionality Difference: If the replacement value's type conflicts with the column's type,
620        PySpark will silently ignore the replacement, while this implementation may cast one to match
621        the other, so the two won't always agree. Best not to mix types: make sure the replacement is
622        the same type as the column.
623
624        Possibility for improvement: use the `typeof` function to get the type of the column and
625        check whether it matches the type of the value provided. If not, make it null.
626        """
627        from sqlglot.dataframe.sql.functions import lit
628
629        values = None
630        columns = None
631        new_df = self.copy()
632        all_columns = self._get_outer_select_columns(new_df.expression)
633        all_column_mapping = {column.alias_or_name: column for column in all_columns}
634        if isinstance(value, dict):
635            values = list(value.values())
636            columns = self._ensure_and_normalize_cols(list(value))
637        if not columns:
638            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
639        if not values:
640            values = [value] * len(columns)
641        value_columns = [lit(value) for value in values]
642
643        null_replacement_mapping = {
644            column.alias_or_name: (
645                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
646            )
647            for column, value in zip(columns, value_columns)
648        }
649        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
650        null_replacement_columns = [
651            null_replacement_mapping[column.alias_or_name] for column in all_columns
652        ]
653        new_df = new_df.select(*null_replacement_columns)
654        return new_df

Functionality Difference: If the replacement value's type conflicts with the column's type, PySpark will silently ignore the replacement, while this implementation may cast one to match the other, so the two won't always agree. Best not to mix types: make sure the replacement is the same type as the column.

Possibility for improvement: use the typeof function to get the type of the column and check whether it matches the type of the value provided. If not, make it null.
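
A minimal sketch; per the note above, the replacements match the column types (the table and schema are illustrative):

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("t", {"a": "INT", "b": "STRING"})
    df = SparkSession().table("t")

    # A dict fills per column; a scalar value would apply to all columns.
    filled = df.fillna({"a": 0, "b": "unknown"})
    print(filled.sql(pretty=True)[0])
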

@operation(Operation.FROM)
def replace( self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[Collection[ColumnOrName], ColumnOrName, NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
656    @operation(Operation.FROM)
657    def replace(
658        self,
659        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
660        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
661        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
662    ) -> DataFrame:
663        from sqlglot.dataframe.sql.functions import lit
664
665        old_values = None
666        new_df = self.copy()
667        all_columns = self._get_outer_select_columns(new_df.expression)
668        all_column_mapping = {column.alias_or_name: column for column in all_columns}
669
670        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
671        if isinstance(to_replace, dict):
672            old_values = list(to_replace)
673            new_values = list(to_replace.values())
674        elif not old_values and isinstance(to_replace, list):
675            assert isinstance(value, list), "value must be a list since the replacements are a list"
676            assert len(to_replace) == len(
677                value
678            ), "the replacements and values must be the same length"
679            old_values = to_replace
680            new_values = value
681        else:
682            old_values = [to_replace] * len(columns)
683            new_values = [value] * len(columns)
684        old_values = [lit(value) for value in old_values]
685        new_values = [lit(value) for value in new_values]
686
687        replacement_mapping = {}
688        for column in columns:
689            expression = Column(None)
690            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
691                if i == 0:
692                    expression = F.when(column == old_value, new_value)
693                else:
694                    expression = expression.when(column == old_value, new_value)  # type: ignore
695            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
696                column.expression.alias_or_name
697            )
698
699        replacement_mapping = {**all_column_mapping, **replacement_mapping}
700        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
701        new_df = new_df.select(*replacement_columns)
702        return new_df
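
A sketch of the pairwise-list form, which compiles to a chained CASE WHEN per affected column (illustrative schema):

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("t", {"a": "INT", "b": "INT"})
    df = SparkSession().table("t")

    # Pairwise replacement, 1 -> 100 and 2 -> 200, restricted to column "a".
    replaced = df.replace([1, 2], [100, 200], subset=["a"])
    print(replaced.sql(pretty=True)[0])
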
@operation(Operation.SELECT)
def withColumn( self, colName: str, col: sqlglot.dataframe.sql.Column) -> sqlglot.dataframe.sql.DataFrame:
704    @operation(Operation.SELECT)
705    def withColumn(self, colName: str, col: Column) -> DataFrame:
706        col = self._ensure_and_normalize_col(col)
707        existing_col_names = self.expression.named_selects
708        existing_col_index = (
709            existing_col_names.index(colName) if colName in existing_col_names else None
710        )
711        if existing_col_index is not None:
712            expression = self.expression.copy()
713            expression.expressions[existing_col_index] = col.expression
714            return self.copy(expression=expression)
715        return self.copy().select(col.alias(colName), append=True)
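
A minimal sketch of both branches: a new name appends a select expression, while an existing name is replaced in place (illustrative table):

    import sqlglot
    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("t", {"a": "INT"})
    df = SparkSession().table("t")

    # "a_doubled" doesn't exist yet, so it is appended to the select list.
    with_double = df.withColumn("a_doubled", F.col("a") * 2)
    print(with_double.sql(pretty=True)[0])
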
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
717    @operation(Operation.SELECT)
718    def withColumnRenamed(self, existing: str, new: str):
719        expression = self.expression.copy()
720        existing_columns = [
721            expression
722            for expression in expression.expressions
723            if expression.alias_or_name == existing
724        ]
725        if not existing_columns:
726            raise ValueError("Tried to rename a column that doesn't exist")
727        for existing_column in existing_columns:
728            if isinstance(existing_column, exp.Column):
729                existing_column.replace(exp.alias_(existing_column.copy(), new))
730            else:
731                existing_column.set("alias", exp.to_identifier(new))
732        return self.copy(expression=expression)
@operation(Operation.SELECT)
def drop( self, *cols: Union[str, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.DataFrame:
734    @operation(Operation.SELECT)
735    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
736        all_columns = self._get_outer_select_columns(self.expression)
737        drop_cols = self._ensure_and_normalize_cols(cols)
738        new_columns = [
739            col
740            for col in all_columns
741            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
742        ]
743        return self.copy().select(*new_columns, append=False)
@operation(Operation.LIMIT)
def limit(self, num: int) -> sqlglot.dataframe.sql.DataFrame:
745    @operation(Operation.LIMIT)
746    def limit(self, num: int) -> DataFrame:
747        return self.copy(expression=self.expression.limit(num))
@operation(Operation.NO_OP)
def hint( self, name: str, *parameters: Union[str, int, NoneType]) -> sqlglot.dataframe.sql.DataFrame:
749    @operation(Operation.NO_OP)
750    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
751        parameter_list = ensure_list(parameters)
752        parameter_columns = (
753            self._ensure_list_of_columns(parameter_list)
754            if parameters
755            else Column.ensure_cols([self.sequence_id])
756        )
757        return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
def repartition( self, numPartitions: Union[int, ColumnOrName], *cols: ColumnOrName) -> sqlglot.dataframe.sql.DataFrame:
759    @operation(Operation.NO_OP)
760    def repartition(
761        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
762    ) -> DataFrame:
763        num_partition_cols = self._ensure_list_of_columns(numPartitions)
764        columns = self._ensure_and_normalize_cols(cols)
765        args = num_partition_cols + columns
766        return self._hint("repartition", args)
@operation(Operation.NO_OP)
def coalesce(self, numPartitions: int) -> sqlglot.dataframe.sql.DataFrame:
768    @operation(Operation.NO_OP)
769    def coalesce(self, numPartitions: int) -> DataFrame:
770        num_partitions = Column.ensure_cols([numPartitions])
771        return self._hint("coalesce", num_partitions)
@operation(Operation.NO_OP)
def cache(self) -> sqlglot.dataframe.sql.DataFrame:
773    @operation(Operation.NO_OP)
774    def cache(self) -> DataFrame:
775        return self._cache(storage_level="MEMORY_AND_DISK")
@operation(Operation.NO_OP)
def persist( self, storageLevel: str = 'MEMORY_AND_DISK_SER') -> sqlglot.dataframe.sql.DataFrame:
777    @operation(Operation.NO_OP)
778    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
779        """
780        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
781        """
782        return self._cache(storageLevel)
class GroupedData:
14class GroupedData:
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
20
21    def _get_function_applied_columns(
22        self, func_name: str, cols: t.Tuple[str, ...]
23    ) -> t.List[Column]:
24        func_name = func_name.lower()
25        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
26
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
40
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
43
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
46
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
49
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
52
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
55
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
58
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Pivot is not currently implemented")
GroupedData( df: sqlglot.dataframe.sql.DataFrame, group_by_cols: List[sqlglot.dataframe.sql.Column], last_op: sqlglot.dataframe.sql.operations.Operation)
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
@operation(Operation.SELECT)
def agg( self, *exprs: Union[sqlglot.dataframe.sql.Column, Dict[str, str]]) -> sqlglot.dataframe.sql.DataFrame:
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
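
A sketch of both agg forms (the employee table is illustrative; groupBy and DataFrame.sql() returning a list of statements are assumed from the surrounding API):

    import sqlglot
    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"dept": "STRING", "salary": "INT"})
    df = SparkSession().table("employee")

    # Dict form maps column name -> aggregate function name;
    # the Column form is explicit and supports aliases.
    by_dict = df.groupBy("dept").agg({"salary": "avg"})
    by_col = df.groupBy("dept").agg(F.max(F.col("salary")).alias("top_salary"))
    print(by_col.sql(pretty=True)[0])
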
def count(self) -> sqlglot.dataframe.sql.DataFrame:
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
def mean(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
def avg(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
def max(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
def min(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
def sum(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
def pivot(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Pivot is not currently implemented")
class Column:
 16class Column:
 17    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
 18        if isinstance(expression, Column):
 19            expression = expression.expression  # type: ignore
 20        elif expression is None or not isinstance(expression, (str, exp.Expression)):
 21            expression = self._lit(expression).expression  # type: ignore
 22
 23        expression = sqlglot.maybe_parse(expression, dialect="spark")
 24        if expression is None:
 25            raise ValueError(f"Could not parse {expression}")
 26        self.expression: exp.Expression = expression
 27
 28    def __repr__(self):
 29        return repr(self.expression)
 30
 31    def __hash__(self):
 32        return hash(self.expression)
 33
 34    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 35        return self.binary_op(exp.EQ, other)
 36
 37    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 38        return self.binary_op(exp.NEQ, other)
 39
 40    def __gt__(self, other: ColumnOrLiteral) -> Column:
 41        return self.binary_op(exp.GT, other)
 42
 43    def __ge__(self, other: ColumnOrLiteral) -> Column:
 44        return self.binary_op(exp.GTE, other)
 45
 46    def __lt__(self, other: ColumnOrLiteral) -> Column:
 47        return self.binary_op(exp.LT, other)
 48
 49    def __le__(self, other: ColumnOrLiteral) -> Column:
 50        return self.binary_op(exp.LTE, other)
 51
 52    def __and__(self, other: ColumnOrLiteral) -> Column:
 53        return self.binary_op(exp.And, other)
 54
 55    def __or__(self, other: ColumnOrLiteral) -> Column:
 56        return self.binary_op(exp.Or, other)
 57
 58    def __mod__(self, other: ColumnOrLiteral) -> Column:
 59        return self.binary_op(exp.Mod, other)
 60
 61    def __add__(self, other: ColumnOrLiteral) -> Column:
 62        return self.binary_op(exp.Add, other)
 63
 64    def __sub__(self, other: ColumnOrLiteral) -> Column:
 65        return self.binary_op(exp.Sub, other)
 66
 67    def __mul__(self, other: ColumnOrLiteral) -> Column:
 68        return self.binary_op(exp.Mul, other)
 69
 70    def __truediv__(self, other: ColumnOrLiteral) -> Column:
 71        return self.binary_op(exp.Div, other)
 72
 73    def __div__(self, other: ColumnOrLiteral) -> Column:
 74        return self.binary_op(exp.Div, other)
 75
 76    def __neg__(self) -> Column:
 77        return self.unary_op(exp.Neg)
 78
 79    def __radd__(self, other: ColumnOrLiteral) -> Column:
 80        return self.inverse_binary_op(exp.Add, other)
 81
 82    def __rsub__(self, other: ColumnOrLiteral) -> Column:
 83        return self.inverse_binary_op(exp.Sub, other)
 84
 85    def __rmul__(self, other: ColumnOrLiteral) -> Column:
 86        return self.inverse_binary_op(exp.Mul, other)
 87
 88    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
 89        return self.inverse_binary_op(exp.Div, other)
 90
 91    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
 92        return self.inverse_binary_op(exp.Div, other)
 93
 94    def __rmod__(self, other: ColumnOrLiteral) -> Column:
 95        return self.inverse_binary_op(exp.Mod, other)
 96
 97    def __pow__(self, power: ColumnOrLiteral, modulo=None):
 98        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))
 99
100    def __rpow__(self, power: ColumnOrLiteral):
101        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))
102
103    def __invert__(self):
104        return self.unary_op(exp.Not)
105
106    def __rand__(self, other: ColumnOrLiteral) -> Column:
107        return self.inverse_binary_op(exp.And, other)
108
109    def __ror__(self, other: ColumnOrLiteral) -> Column:
110        return self.inverse_binary_op(exp.Or, other)
111
112    @classmethod
113    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
114        return cls(value)
115
116    @classmethod
117    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
118        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
119
120    @classmethod
121    def _lit(cls, value: ColumnOrLiteral) -> Column:
122        if isinstance(value, dict):
123            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
124            return cls(exp.Struct(expressions=columns))
125        return cls(exp.convert(value))
126
127    @classmethod
128    def invoke_anonymous_function(
129        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
130    ) -> Column:
131        columns = [] if column is None else [cls.ensure_col(column)]
132        column_args = [cls.ensure_col(arg) for arg in args]
133        expressions = [x.expression for x in columns + column_args]
134        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
135        return Column(new_expression)
136
137    @classmethod
138    def invoke_expression_over_column(
139        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
140    ) -> Column:
141        ensured_column = None if column is None else cls.ensure_col(column)
142        ensure_expression_values = {
143            k: [Column.ensure_col(x).expression for x in v]
144            if is_iterable(v)
145            else Column.ensure_col(v).expression
146            for k, v in kwargs.items()
147            if v is not None
148        }
149        new_expression = (
150            callable_expression(**ensure_expression_values)
151            if ensured_column is None
152            else callable_expression(
153                this=ensured_column.column_expression, **ensure_expression_values
154            )
155        )
156        return Column(new_expression)
157
158    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
159        return Column(
160            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
161        )
162
163    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
164        return Column(
165            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
166        )
167
168    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
169        return Column(klass(this=self.column_expression, **kwargs))
170
171    @property
172    def is_alias(self):
173        return isinstance(self.expression, exp.Alias)
174
175    @property
176    def is_column(self):
177        return isinstance(self.expression, exp.Column)
178
179    @property
180    def column_expression(self) -> exp.Column:
181        return self.expression.unalias()
182
183    @property
184    def alias_or_name(self) -> str:
185        return self.expression.alias_or_name
186
187    @classmethod
188    def ensure_literal(cls, value) -> Column:
189        from sqlglot.dataframe.sql.functions import lit
190
191        if isinstance(value, cls):
192            value = value.expression
193        if not isinstance(value, exp.Literal):
194            return lit(value)
195        return Column(value)
196
197    def copy(self) -> Column:
198        return Column(self.expression.copy())
199
200    def set_table_name(self, table_name: str, copy=False) -> Column:
201        expression = self.expression.copy() if copy else self.expression
202        expression.set("table", exp.to_identifier(table_name))
203        return Column(expression)
204
205    def sql(self, **kwargs) -> str:
206        return self.expression.sql(**{"dialect": "spark", **kwargs})
207
208    def alias(self, name: str) -> Column:
209        new_expression = exp.alias_(self.column_expression, name)
210        return Column(new_expression)
211
212    def asc(self) -> Column:
213        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
214        return Column(new_expression)
215
216    def desc(self) -> Column:
217        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
218        return Column(new_expression)
219
220    asc_nulls_first = asc
221
222    def asc_nulls_last(self) -> Column:
223        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
224        return Column(new_expression)
225
226    def desc_nulls_first(self) -> Column:
227        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
228        return Column(new_expression)
229
230    desc_nulls_last = desc
231
232    def when(self, condition: Column, value: t.Any) -> Column:
233        from sqlglot.dataframe.sql.functions import when
234
235        column_with_if = when(condition, value)
236        if not isinstance(self.expression, exp.Case):
237            return column_with_if
238        new_column = self.copy()
239        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
240        return new_column
241
242    def otherwise(self, value: t.Any) -> Column:
243        from sqlglot.dataframe.sql.functions import lit
244
245        true_value = value if isinstance(value, Column) else lit(value)
246        new_column = self.copy()
247        new_column.expression.set("default", true_value.column_expression)
248        return new_column
249
250    def isNull(self) -> Column:
251        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
252        return Column(new_expression)
253
254    def isNotNull(self) -> Column:
255        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
256        return Column(new_expression)
257
258    def cast(self, dataType: t.Union[str, DataType]):
259        """
260        Functionality Difference: PySpark's cast accepts an instance of the DataType class.
261        sqlglot doesn't currently replicate this class, so a DataType instance is converted to its string form.
262        """
263        if isinstance(dataType, DataType):
264            dataType = dataType.simpleString()
265        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))
266
267    def startswith(self, value: t.Union[str, Column]) -> Column:
268        value = self._lit(value) if not isinstance(value, Column) else value
269        return self.invoke_anonymous_function(self, "STARTSWITH", value)
270
271    def endswith(self, value: t.Union[str, Column]) -> Column:
272        value = self._lit(value) if not isinstance(value, Column) else value
273        return self.invoke_anonymous_function(self, "ENDSWITH", value)
274
275    def rlike(self, regexp: str) -> Column:
276        return self.invoke_expression_over_column(
277            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
278        )
279
280    def like(self, other: str):
281        return self.invoke_expression_over_column(
282            self, exp.Like, expression=self._lit(other).expression
283        )
284
285    def ilike(self, other: str):
286        return self.invoke_expression_over_column(
287            self, exp.ILike, expression=self._lit(other).expression
288        )
289
290    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
291        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
292        length = self._lit(length) if not isinstance(length, Column) else length
293        return Column.invoke_expression_over_column(
294            self, exp.Substring, start=startPos.expression, length=length.expression
295        )
296
297    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
298        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
299        expressions = [self._lit(x).expression for x in columns]
300        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
301
302    def between(
303        self,
304        lowerBound: t.Union[ColumnOrLiteral],
305        upperBound: t.Union[ColumnOrLiteral],
306    ) -> Column:
307        lower_bound_exp = (
308            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
309        )
310        upper_bound_exp = (
311            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
312        )
313        return Column(
314            exp.Between(
315                this=self.column_expression,
316                low=lower_bound_exp.expression,
317                high=upper_bound_exp.expression,
318            )
319        )
320
321    def over(self, window: WindowSpec) -> Column:
322        window_expression = window.expression.copy()
323        window_expression.set("this", self.column_expression)
324        return Column(window_expression)
Column( expression: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType])
17    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
18        if isinstance(expression, Column):
19            expression = expression.expression  # type: ignore
20        elif expression is None or not isinstance(expression, (str, exp.Expression)):
21            expression = self._lit(expression).expression  # type: ignore
22
23        expression = sqlglot.maybe_parse(expression, dialect="spark")
24        if expression is None:
25            raise ValueError(f"Could not parse {expression}")
26        self.expression: exp.Expression = expression
@classmethod
def ensure_col( cls, value: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType]):
112    @classmethod
113    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
114        return cls(value)
@classmethod
def ensure_cols( cls, args: List[Union[ColumnOrLiteral, sqlglot.expressions.Expression]]) -> List[sqlglot.dataframe.sql.Column]:
116    @classmethod
117    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
118        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
@classmethod
def invoke_anonymous_function( cls, column: Optional[ColumnOrLiteral], func_name: str, *args: Optional[ColumnOrLiteral]) -> sqlglot.dataframe.sql.Column:
127    @classmethod
128    def invoke_anonymous_function(
129        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
130    ) -> Column:
131        columns = [] if column is None else [cls.ensure_col(column)]
132        column_args = [cls.ensure_col(arg) for arg in args]
133        expressions = [x.expression for x in columns + column_args]
134        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
135        return Column(new_expression)
@classmethod
def invoke_expression_over_column( cls, column: Optional[ColumnOrLiteral], callable_expression: Callable, **kwargs) -> sqlglot.dataframe.sql.Column:
137    @classmethod
138    def invoke_expression_over_column(
139        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
140    ) -> Column:
141        ensured_column = None if column is None else cls.ensure_col(column)
142        ensure_expression_values = {
143            k: [Column.ensure_col(x).expression for x in v]
144            if is_iterable(v)
145            else Column.ensure_col(v).expression
146            for k, v in kwargs.items()
147            if v is not None
148        }
149        new_expression = (
150            callable_expression(**ensure_expression_values)
151            if ensured_column is None
152            else callable_expression(
153                this=ensured_column.column_expression, **ensure_expression_values
154            )
155        )
156        return Column(new_expression)
def binary_op( self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> sqlglot.dataframe.sql.Column:
158    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
159        return Column(
160            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
161        )
def inverse_binary_op( self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> sqlglot.dataframe.sql.Column:
163    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
164        return Column(
165            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
166        )
def unary_op(self, klass: Callable, **kwargs) -> sqlglot.dataframe.sql.Column:
168    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
169        return Column(klass(this=self.column_expression, **kwargs))
@classmethod
def ensure_literal(cls, value) -> sqlglot.dataframe.sql.Column:
187    @classmethod
188    def ensure_literal(cls, value) -> Column:
189        from sqlglot.dataframe.sql.functions import lit
190
191        if isinstance(value, cls):
192            value = value.expression
193        if not isinstance(value, exp.Literal):
194            return lit(value)
195        return Column(value)
def copy(self) -> sqlglot.dataframe.sql.Column:
197    def copy(self) -> Column:
198        return Column(self.expression.copy())
def set_table_name(self, table_name: str, copy=False) -> sqlglot.dataframe.sql.Column:
200    def set_table_name(self, table_name: str, copy=False) -> Column:
201        expression = self.expression.copy() if copy else self.expression
202        expression.set("table", exp.to_identifier(table_name))
203        return Column(expression)
def sql(self, **kwargs) -> str:
205    def sql(self, **kwargs) -> str:
206        return self.expression.sql(**{"dialect": "spark", **kwargs})
def alias(self, name: str) -> sqlglot.dataframe.sql.Column:
208    def alias(self, name: str) -> Column:
209        new_expression = exp.alias_(self.column_expression, name)
210        return Column(new_expression)
def asc(self) -> sqlglot.dataframe.sql.Column:
212    def asc(self) -> Column:
213        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
214        return Column(new_expression)
def desc(self) -> sqlglot.dataframe.sql.Column:
216    def desc(self) -> Column:
217        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
218        return Column(new_expression)
def asc_nulls_first(self) -> sqlglot.dataframe.sql.Column:
212    def asc(self) -> Column:
213        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
214        return Column(new_expression)
def asc_nulls_last(self) -> sqlglot.dataframe.sql.Column:
222    def asc_nulls_last(self) -> Column:
223        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
224        return Column(new_expression)
def desc_nulls_first(self) -> sqlglot.dataframe.sql.Column:
226    def desc_nulls_first(self) -> Column:
227        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
228        return Column(new_expression)
def desc_nulls_last(self) -> sqlglot.dataframe.sql.Column:
216    def desc(self) -> Column:
217        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
218        return Column(new_expression)
def when( self, condition: sqlglot.dataframe.sql.Column, value: Any) -> sqlglot.dataframe.sql.Column:
232    def when(self, condition: Column, value: t.Any) -> Column:
233        from sqlglot.dataframe.sql.functions import when
234
235        column_with_if = when(condition, value)
236        if not isinstance(self.expression, exp.Case):
237            return column_with_if
238        new_column = self.copy()
239        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
240        return new_column
def otherwise(self, value: Any) -> sqlglot.dataframe.sql.Column:
242    def otherwise(self, value: t.Any) -> Column:
243        from sqlglot.dataframe.sql.functions import lit
244
245        true_value = value if isinstance(value, Column) else lit(value)
246        new_column = self.copy()
247        new_column.expression.set("default", true_value.column_expression)
248        return new_column
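
A sketch of how chained when() calls merge into a single CASE expression whose ELSE branch is set by otherwise() (the column name is illustrative):

    from sqlglot.dataframe.sql import functions as F

    # Each additional when() extends the "ifs" of the same CASE expression;
    # otherwise() sets its default (ELSE) branch.
    bucket = (
        F.when(F.col("age") < 18, F.lit("minor"))
        .when(F.col("age") < 65, F.lit("adult"))
        .otherwise(F.lit("senior"))
    )
    print(bucket.sql())
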
def isNull(self) -> sqlglot.dataframe.sql.Column:
250    def isNull(self) -> Column:
251        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
252        return Column(new_expression)
def isNotNull(self) -> sqlglot.dataframe.sql.Column:
254    def isNotNull(self) -> Column:
255        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
256        return Column(new_expression)
def cast(self, dataType: Union[str, sqlglot.dataframe.sql.types.DataType]):
258    def cast(self, dataType: t.Union[str, DataType]):
259        """
260        Functionality Difference: PySpark's cast accepts an instance of the DataType class.
261        sqlglot doesn't currently replicate this class, so a DataType instance is converted to its string form.
262        """
263        if isinstance(dataType, DataType):
264            dataType = dataType.simpleString()
265        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))

Functionality Difference: PySpark's cast accepts an instance of the DataType class; sqlglot doesn't currently replicate this class, so a DataType instance is converted to its string form.
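
A minimal sketch of both accepted forms (assuming this module's types submodule, which mirrors PySpark's type classes):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql import types

    # Both forms produce the same CAST; the DataType instance is converted
    # to its string representation (simpleString) first.
    print(F.col("age").cast("string").sql())
    print(F.col("age").cast(types.StringType()).sql())
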

def startswith( self, value: Union[str, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.Column:
267    def startswith(self, value: t.Union[str, Column]) -> Column:
268        value = self._lit(value) if not isinstance(value, Column) else value
269        return self.invoke_anonymous_function(self, "STARTSWITH", value)
def endswith( self, value: Union[str, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.Column:
271    def endswith(self, value: t.Union[str, Column]) -> Column:
272        value = self._lit(value) if not isinstance(value, Column) else value
273        return self.invoke_anonymous_function(self, "ENDSWITH", value)
def rlike(self, regexp: str) -> sqlglot.dataframe.sql.Column:
275    def rlike(self, regexp: str) -> Column:
276        return self.invoke_expression_over_column(
277            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
278        )
def like(self, other: str):
280    def like(self, other: str):
281        return self.invoke_expression_over_column(
282            self, exp.Like, expression=self._lit(other).expression
283        )
def ilike(self, other: str):
285    def ilike(self, other: str):
286        return self.invoke_expression_over_column(
287            self, exp.ILike, expression=self._lit(other).expression
288        )
def substr( self, startPos: Union[int, sqlglot.dataframe.sql.Column], length: Union[int, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.Column:
290    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
291        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
292        length = self._lit(length) if not isinstance(length, Column) else length
293        return Column.invoke_expression_over_column(
294            self, exp.Substring, start=startPos.expression, length=length.expression
295        )
def isin( self, *cols: Union[ColumnOrLiteral, Iterable[ColumnOrLiteral]]):
297    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
298        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
299        expressions = [self._lit(x).expression for x in columns]
300        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
def between( self, lowerBound: ColumnOrLiteral, upperBound: ColumnOrLiteral) -> sqlglot.dataframe.sql.Column:
302    def between(
303        self,
304        lowerBound: t.Union[ColumnOrLiteral],
305        upperBound: t.Union[ColumnOrLiteral],
306    ) -> Column:
307        lower_bound_exp = (
308            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
309        )
310        upper_bound_exp = (
311            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
312        )
313        return Column(
314            exp.Between(
315                this=self.column_expression,
316                low=lower_bound_exp.expression,
317                high=upper_bound_exp.expression,
318            )
319        )
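
A minimal sketch (the column name is illustrative):

    from sqlglot.dataframe.sql import functions as F

    # Plain literals are wrapped via _lit; Column bounds pass through unchanged.
    expr = F.col("age").between(18, 65)
    print(expr.sql())  # age BETWEEN 18 AND 65
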
def over( self, window: WindowSpec) -> sqlglot.dataframe.sql.Column:
321    def over(self, window: WindowSpec) -> Column:
322        window_expression = window.expression.copy()
323        window_expression.set("this", self.column_expression)
324        return Column(window_expression)
class DataFrameNaFunctions:
785class DataFrameNaFunctions:
786    def __init__(self, df: DataFrame):
787        self.df = df
788
789    def drop(
790        self,
791        how: str = "any",
792        thresh: t.Optional[int] = None,
793        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
794    ) -> DataFrame:
795        return self.df.dropna(how=how, thresh=thresh, subset=subset)
796
797    def fill(
798        self,
799        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
800        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
801    ) -> DataFrame:
802        return self.df.fillna(value=value, subset=subset)
803
804    def replace(
805        self,
806        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
807        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
808        subset: t.Optional[t.Union[str, t.List[str]]] = None,
809    ) -> DataFrame:
810        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
DataFrameNaFunctions(df: sqlglot.dataframe.sql.DataFrame)
786    def __init__(self, df: DataFrame):
787        self.df = df
def drop( self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
789    def drop(
790        self,
791        how: str = "any",
792        thresh: t.Optional[int] = None,
793        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
794    ) -> DataFrame:
795        return self.df.dropna(how=how, thresh=thresh, subset=subset)
def fill( self, value: Union[int, bool, float, str, Dict[str, Any]], subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
797    def fill(
798        self,
799        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
800        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
801    ) -> DataFrame:
802        return self.df.fillna(value=value, subset=subset)
def replace( self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[str, List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
804    def replace(
805        self,
806        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
807        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
808        subset: t.Optional[t.Union[str, t.List[str]]] = None,
809    ) -> DataFrame:
810        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
class Window:
15class Window:
16    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
17    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
18    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
19    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)
20
21    unboundedPreceding: int = _JAVA_MIN_LONG
22
23    unboundedFollowing: int = _JAVA_MAX_LONG
24
25    currentRow: int = 0
26
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
30
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
34
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
38
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
Window()
@classmethod
def partitionBy( cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
@classmethod
def orderBy( cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
@classmethod
def rowsBetween(cls, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
@classmethod
def rangeBetween(cls, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
class WindowSpec:
 44class WindowSpec:
 45    def __init__(self, expression: exp.Expression = exp.Window()):
 46        self.expression = expression
 47
 48    def copy(self):
 49        return WindowSpec(self.expression.copy())
 50
 51    def sql(self, **kwargs) -> str:
 52        return self.expression.sql(dialect="spark", **kwargs)
 53
 54    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 55        from sqlglot.dataframe.sql.column import Column
 56
 57        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 58        expressions = [Column.ensure_col(x).expression for x in cols]
 59        window_spec = self.copy()
 60        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
 61        partition_by_expressions.extend(expressions)
 62        window_spec.expression.set("partition_by", partition_by_expressions)
 63        return window_spec
 64
 65    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 66        from sqlglot.dataframe.sql.column import Column
 67
 68        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 69        expressions = [Column.ensure_col(x).expression for x in cols]
 70        window_spec = self.copy()
 71        if window_spec.expression.args.get("order") is None:
 72            window_spec.expression.set("order", exp.Order(expressions=[]))
 73        order_by = window_spec.expression.args["order"].expressions
 74        order_by.extend(expressions)
 75        window_spec.expression.args["order"].set("expressions", order_by)
 76        return window_spec
 77
 78    def _calc_start_end(
 79        self, start: int, end: int
 80    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
 81        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
 82            "start_side": None,
 83            "end_side": None,
 84        }
 85        if start == Window.currentRow:
 86            kwargs["start"] = "CURRENT ROW"
 87        else:
 88            kwargs = {
 89                **kwargs,
 90                **{
 91                    "start_side": "PRECEDING",
 92                    "start": "UNBOUNDED"
 93                    if start <= Window.unboundedPreceding
 94                    else F.lit(start).expression,
 95                },
 96            }
 97        if end == Window.currentRow:
 98            kwargs["end"] = "CURRENT ROW"
 99        else:
100            kwargs = {
101                **kwargs,
102                **{
103                    "end_side": "FOLLOWING",
104                    "end": "UNBOUNDED"
105                    if end >= Window.unboundedFollowing
106                    else F.lit(end).expression,
107                },
108            }
109        return kwargs
110
111    def rowsBetween(self, start: int, end: int) -> WindowSpec:
112        window_spec = self.copy()
113        spec = self._calc_start_end(start, end)
114        spec["kind"] = "ROWS"
115        window_spec.expression.set(
116            "spec",
117            exp.WindowSpec(
118                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
119            ),
120        )
121        return window_spec
122
123    def rangeBetween(self, start: int, end: int) -> WindowSpec:
124        window_spec = self.copy()
125        spec = self._calc_start_end(start, end)
126        spec["kind"] = "RANGE"
127        window_spec.expression.set(
128            "spec",
129            exp.WindowSpec(
130                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
131            ),
132        )
133        return window_spec
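Taken together, these methods build a window specification immutably: every call works on a copy of the underlying exp.Window, so partially built specs can be shared. A minimal usage sketch (column names are hypothetical, and Column.over is assumed to be the PySpark-style hook this port mirrors):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.window import Window

    # Each call copies the spec, so `base` stays reusable.
    base = Window.partitionBy("store_id").orderBy("age")
    window = base.rowsBetween(Window.unboundedPreceding, Window.currentRow)

    # Attach the spec to an aggregate to form a window function column.
    running_total = F.sum("salary").over(window)

Since sql() always renders with dialect="spark", window.sql() yields the Spark SQL for the spec, roughly PARTITION BY store_id ORDER BY age ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW inside an OVER clause.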
WindowSpec(expression: sqlglot.expressions.Expression = (WINDOW ))
def copy(self):
def sql(self, **kwargs) -> str:
def partitionBy(self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
def orderBy(self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
def rowsBetween(self, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
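Both frame methods delegate to _calc_start_end, which maps the integer bounds onto SQL frame keywords: the Window.unboundedPreceding / Window.unboundedFollowing sentinels become UNBOUNDED PRECEDING / FOLLOWING, Window.currentRow becomes CURRENT ROW, and any other integer is passed through as a literal offset via F.lit. A short sketch (column name hypothetical):

    # ROWS frame pinned between the partition start and the current row:
    spec = Window.orderBy("age").rowsBetween(Window.unboundedPreceding, Window.currentRow)
    # start -> "UNBOUNDED" with start_side "PRECEDING"
    # end   -> "CURRENT ROW"
    # kind  -> "ROWS"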
def rangeBetween(self, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
class DataFrameReader:
15class DataFrameReader:
16    def __init__(self, spark: SparkSession):
17        self.spark = spark
18
19    def table(self, tableName: str) -> DataFrame:
20        from sqlglot.dataframe.sql.dataframe import DataFrame
21
22        sqlglot.schema.add_table(tableName)
23        return DataFrame(
24            self.spark,
25            exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
26        )
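Note that table() registers the name with sqlglot's global schema and then selects whatever columns that schema knows about, so register the table's columns first if you want the projection to list them. A minimal sketch (table and column names are hypothetical):

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession

    # Register the columns so sqlglot.schema.column_names() has something to return.
    sqlglot.schema.add_table("employee", {"employee_id": "INT", "age": "INT"})

    df = SparkSession().read.table("employee")
    for statement in df.sql():  # DataFrame.sql() returns a list of SQL strings
        print(statement)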
DataFrameReader(spark: sqlglot.dataframe.sql.SparkSession)
def table(self, tableName: str) -> sqlglot.dataframe.sql.DataFrame:
class DataFrameWriter:
29class DataFrameWriter:
30    def __init__(
31        self,
32        df: DataFrame,
33        spark: t.Optional[SparkSession] = None,
34        mode: t.Optional[str] = None,
35        by_name: bool = False,
36    ):
37        self._df = df
38        self._spark = spark or df.spark
39        self._mode = mode
40        self._by_name = by_name
41
42    def copy(self, **kwargs) -> DataFrameWriter:
43        return DataFrameWriter(
44            **{
45                k[1:] if k.startswith("_") else k: v
46                for k, v in object_to_dict(self, **kwargs).items()
47            }
48        )
49
50    def sql(self, **kwargs) -> t.List[str]:
51        return self._df.sql(**kwargs)
52
53    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
54        return self.copy(_mode=saveMode)
55
56    @property
57    def byName(self):
58        return self.copy(by_name=True)
59
60    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
61        output_expression_container = exp.Insert(
62            **{
63                "this": exp.to_table(tableName),
64                "overwrite": overwrite,
65            }
66        )
67        df = self._df.copy(output_expression_container=output_expression_container)
68        if self._by_name:
69            columns = sqlglot.schema.column_names(tableName, only_visible=True)
70            df = df._convert_leaf_to_cte().select(*columns)
71
72        return self.copy(_df=df)
73
74    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
75        if format is not None:
76            raise NotImplementedError("Providing Format in the save as table is not supported")
77        exists, replace, mode = None, None, mode or str(self._mode)
78        if mode == "append":
79            return self.insertInto(name)
80        if mode == "ignore":
81            exists = True
82        if mode == "overwrite":
83            replace = True
84        output_expression_container = exp.Create(
85            this=exp.to_table(name),
86            kind="TABLE",
87            exists=exists,
88            replace=replace,
89        )
90        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
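The writer is immutable in the same way: copy() strips the leading underscore from the private attributes so they can be overridden per call, and each public method returns a fresh DataFrameWriter. A minimal sketch (df stands for any sqlglot DataFrame, whose write property is assumed to expose this writer as in PySpark; the table name is hypothetical):

    # byName realigns the DataFrame's columns to the target table's visible columns.
    writer = df.write.byName.insertInto("employee")
    for statement in writer.sql():
        print(statement)  # an INSERT INTO employee SELECT ... statement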
DataFrameWriter(df: sqlglot.dataframe.sql.DataFrame, spark: Optional[sqlglot.dataframe.sql.SparkSession] = None, mode: Optional[str] = None, by_name: bool = False)
def copy(self, **kwargs) -> sqlglot.dataframe.sql.DataFrameWriter:
def sql(self, **kwargs) -> List[str]:
def mode(self, saveMode: Optional[str]) -> sqlglot.dataframe.sql.DataFrameWriter:
def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> sqlglot.dataframe.sql.DataFrameWriter:
def saveAsTable(self, name: str, format: Optional[str] = None, mode: Optional[str] = None):
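The mode string picks the output expression: "append" short-circuits to insertInto, "ignore" sets exists=True (CREATE TABLE IF NOT EXISTS), "overwrite" sets replace=True (CREATE OR REPLACE TABLE), and any other mode yields a plain CREATE TABLE. Illustrated (identifiers hypothetical):

    writer = df.write
    writer.saveAsTable("t", mode="append")     # -> INSERT INTO t ...
    writer.saveAsTable("t", mode="ignore")     # -> CREATE TABLE IF NOT EXISTS t ...
    writer.saveAsTable("t", mode="overwrite")  # -> CREATE OR REPLACE TABLE t ...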