
sqlglot.dataframe.sql

from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
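A minimal end-to-end sketch of how these exports fit together, following the usage pattern from the sqlglot README (the employee table and its columns are illustrative; registering the schema up front is what lets column references be resolved):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Register the table schema so column references can be normalized.
sqlglot.schema.add_table("employee", {"employee_id": "INT", "fname": "STRING", "age": "INT"})

df = (
    SparkSession()
    .table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)

# sql() returns a list of statements; a plain SELECT produces one.
print(df.sql(pretty=True)[0])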
class SparkSession:
class SparkSession:
    known_ids: t.ClassVar[t.Set[str]] = set()
    known_branch_ids: t.ClassVar[t.Set[str]] = set()
    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)

    def __init__(self):
        self.incrementing_id = 1

    def __getattr__(self, name: str) -> SparkSession:
        return self

    def __call__(self, *args, **kwargs) -> SparkSession:
        return self

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.Tuple(
                expressions=list(
                    map(
                        lambda x: F.lit(x).expression,
                        row if not isinstance(row, dict) else row.values(),
                    )
                )
            )
            for row in data
        ]

        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression
            if data_type is not None
            else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = sqlglot.parse_one(sqlQuery, read="spark")
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)
known_ids: ClassVar[Set[str]] = set()
known_branch_ids: ClassVar[Set[str]] = set()
known_sequence_ids: ClassVar[Set[str]] = set()
name_to_sequence_id_mapping: ClassVar[Dict[str, List[str]]] = defaultdict(list)
incrementing_id
def table(self, tableName: str) -> sqlglot.dataframe.sql.DataFrame:
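For example, assuming an illustrative employee table registered in sqlglot's schema:

spark = SparkSession()
sqlglot.schema.add_table("employee", {"employee_id": "INT", "age": "INT"})
df = spark.table("employee")  # shorthand for spark.read.table("employee")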
def createDataFrame( self, data: Sequence[Union[Dict[str, ColumnLiterals], List[ColumnLiterals], Tuple]], schema: Optional[SchemaInput] = None, samplingRatio: Optional[float] = None, verifySchema: bool = False) -> sqlglot.dataframe.sql.DataFrame:
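A short sketch of the accepted inputs, mirroring the checks above (values are illustrative): rows may be dicts, lists, or tuples, and schema may be omitted, a list of column names, or a schema string; samplingRatio and verifySchema raise NotImplementedError if set.

spark = SparkSession()

# Dict rows: column names are taken from the first row's keys.
df1 = spark.createDataFrame([{"cola": 1, "colb": "x"}])

# Tuple rows plus a list-of-strings schema.
df2 = spark.createDataFrame([(1, "x"), (2, "y")], schema=["cola", "colb"])

# Without a registered schema, skip the optimizer when generating SQL.
print(df2.sql(optimize=False)[0])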
def sql(self, sqlQuery: str) -> sqlglot.dataframe.sql.DataFrame:
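The query is parsed with the Spark dialect; a SELECT becomes the DataFrame itself, while CREATE and INSERT statements are carried through as the output expression container. A sketch with an illustrative table:

spark = SparkSession()
sqlglot.schema.add_table("t", {"a": "INT", "b": "INT"})
df = spark.sql("SELECT a, SUM(b) AS total FROM t GROUP BY a")
print(df.sql(dialect="spark")[0])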
class DataFrame:
class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
            expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self._create_hash_from_expression(expression)
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
        ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
        ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, expression or self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
        value = expression.sql(dialect="spark").encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                select_expression = t.cast(exp.Select, optimize_func(select_expression))
            select_expression = df._replace_cte_names_with_hashes(select_expression)
            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql("spark")
                        for expression in select_expression.expressions
                    },
                    dialect="spark",
                )
                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )
                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)
                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")
            output_expressions.append(expression)

        return [
            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
        ]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [
                col
                for col in cols
                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
            ]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
                # of Spark.
                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
                    # use the same CTE we used before
                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                    if cte:
                        resolved_column_position[ambiguous_col] += 1
                    else:
                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                    ambiguous_col.expression.set("table", cte.alias_or_name)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        join_columns = self._ensure_list_of_columns(on)
        # We will determine the actual "join on" expression later so we don't provide it at first
        join_expression = self.expression.join(
            other_df.latest_cte_name, join_type=how.replace("_", " ")
        )
        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
        self_columns = self._get_outer_select_columns(join_expression)
        other_columns = self._get_outer_select_columns(other_df)
        # Determines the join clause and select columns to be used based on what type of columns were provided for
        # the join. The columns returned change based on how the on expression is provided.
        if isinstance(join_columns[0].expression, exp.Column):
            """
            Unique characteristics of join on column names only:
            * The column names are put at the front of the select list
            * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
            """
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking each of the left side
            # tables and seeing if they have the column being referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [left_column == right_column for left_column, right_column in join_column_pairs],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match Spark behavior, only the join columns get deduplicated and they are put at the front
            # of the column list
            select_column_names = [
                column.alias_or_name
                if not isinstance(column.expression.this, exp.Star)
                else column.sql()
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            """
            Unique characteristics of join on expressions:
            * There is no deduplication of the results.
            * The left join dataframe columns go first and right come after. No sort preference is given to join columns
            """
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [column.alias_or_name for column in self_columns + other_columns]

        # Update the on expression with the actual join clause to replace the dummy one from before
        join_expression.args["joins"][-1].set("on", join_clause.expression)
        new_df = self.copy(expression=join_expression)
        new_df.pending_join_hints.extend(self.pending_join_hints)
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
        has irregular behavior here and can result in runtime errors. Users shouldn't be mixing the two anyway, so this
        is unlikely to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            x
            for x in [
                i if isinstance(col.expression, exp.Ordered) else None
                for i, col in enumerate(columns)
            ]
            if x is not None
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        ascending = [bool(x) for x in ascending]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            exp.Ordered(this=col.expression, desc=not asc)
            if i not in pre_ordered_col_indexes
            else columns[i].column_expression
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # will be determined later if thresh is null
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
        if minimum_num_nulls > len(null_check_columns):
            raise RuntimeError(
                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
            )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: t.Union[ColumnLiterals],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality Difference: If you provide a value to replace a null and that type conflicts
        with the type of the column, then PySpark will just ignore your replacement.
        This implementation will try to cast them to be the same in some cases, so the results won't always match.
        It is best not to mix types, so make sure the replacement is the same type as the column.

        Possibility for improvement: Use the `typeof` function to get the type of the column
        and check if it matches the type of the value provided. If not, then make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
DataFrame( spark: SparkSession, expression: sqlglot.expressions.Select, branch_id: Optional[str] = None, sequence_id: Optional[str] = None, last_op: sqlglot.dataframe.sql.operations.Operation = <Operation.INIT: -1>, pending_hints: Optional[List[sqlglot.expressions.Expression]] = None, output_expression_container: Optional[OutputExpressionContainer] = None, **kwargs)
spark
expression
branch_id
sequence_id
last_op
pending_hints
output_expression_container
sparkSession
write
latest_cte_name: str
pending_join_hints
pending_partition_hints
columns: List[str]
def sql(self, dialect='spark', optimize=True, **kwargs) -> List[str]:
def copy(self, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
@operation(Operation.SELECT)
def select(self, *cols, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
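For example (column names assume the illustrative employee table from earlier):

df.select("age")
df.select(F.col("age"), (F.col("age") + 1).alias("age_plus_one"))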
@operation(Operation.NO_OP)
def alias(self, name: str, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
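A sketch of why aliasing matters here: the alias name is recorded in the session's name-to-sequence-id mapping so that later join hints can refer to the frame by name (tables are illustrative):

df_store = spark.table("store").alias("store_df")
result = df_employee.join(df_store, on="store_id").hint("broadcast", "store_df")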
@operation(Operation.WHERE)
def where( self, column: Union[sqlglot.dataframe.sql.Column, bool], **kwargs) -> sqlglot.dataframe.sql.DataFrame:
@operation(Operation.WHERE)
def filter( self, column: Union[sqlglot.dataframe.sql.Column, bool], **kwargs) -> sqlglot.dataframe.sql.DataFrame:
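For example (columns are illustrative):

df.where(F.col("age") >= 40)
df.filter((F.col("age") >= 40) & (F.col("fname") == F.lit("Jane")))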
@operation(Operation.GROUP_BY)
def groupBy(self, *cols, **kwargs) -> sqlglot.dataframe.sql.GroupedData:
@operation(Operation.SELECT)
def agg(self, *exprs, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
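For example, with and without an explicit groupBy (columns are illustrative):

df.groupBy(F.col("age")).agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
df.agg(F.max(F.col("age")))  # no grouping: aggregates over the whole frame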
@operation(Operation.FROM)
def join( self, other_df: sqlglot.dataframe.sql.DataFrame, on: Union[str, List[str], sqlglot.dataframe.sql.Column, List[sqlglot.dataframe.sql.Column]], how: str = 'inner', **kwargs) -> sqlglot.dataframe.sql.DataFrame:
421    @operation(Operation.FROM)
422    def join(
423        self,
424        other_df: DataFrame,
425        on: t.Union[str, t.List[str], Column, t.List[Column]],
426        how: str = "inner",
427        **kwargs,
428    ) -> DataFrame:
429        other_df = other_df._convert_leaf_to_cte()
430        join_columns = self._ensure_list_of_columns(on)
431        # We will determine actual "join on" expression later so we don't provide it at first
432        join_expression = self.expression.join(
433            other_df.latest_cte_name, join_type=how.replace("_", " ")
434        )
435        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
436        self_columns = self._get_outer_select_columns(join_expression)
437        other_columns = self._get_outer_select_columns(other_df)
438        # Determines the join clause and select columns to be used based on what type of columns were provided for
439        # the join. The columns returned change based on how the on expression is provided.
440        if isinstance(join_columns[0].expression, exp.Column):
441            """
442            Unique characteristics of join on column names only:
443            * The column names are put at the front of the select list
444            * The join column names are deduplicated across the entire select list; other duplicates are still allowed
445            """
446            table_names = [
447                table.alias_or_name
448                for table in get_tables_from_expression_with_join(join_expression)
449            ]
450            potential_ctes = [
451                cte
452                for cte in join_expression.ctes
453                if cte.alias_or_name in table_names
454                and cte.alias_or_name != other_df.latest_cte_name
455            ]
456            # Determine the table to reference for the left side of the join by checking each of the left side
457            # tables and see if they have the column being referenced.
458            join_column_pairs = []
459            for join_column in join_columns:
460                num_matching_ctes = 0
461                for cte in potential_ctes:
462                    if join_column.alias_or_name in cte.this.named_selects:
463                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
464                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
465                        join_column_pairs.append((left_column, right_column))
466                        num_matching_ctes += 1
467                if num_matching_ctes > 1:
468                    raise ValueError(
469                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
470                    )
471                elif num_matching_ctes == 0:
472                    raise ValueError(
473                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
474                    )
475            join_clause = functools.reduce(
476                lambda x, y: x & y,
477                [left_column == right_column for left_column, right_column in join_column_pairs],
478            )
479            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
480            # To match Spark's behavior, only the join columns get deduplicated, and they are put at the front of the column list
481            select_column_names = [
482                column.alias_or_name
483                if not isinstance(column.expression.this, exp.Star)
484                else column.sql()
485                for column in self_columns + other_columns
486            ]
487            select_column_names = [
488                column_name
489                for column_name in select_column_names
490                if column_name not in join_column_names
491            ]
492            select_column_names = join_column_names + select_column_names
493        else:
494            """
495            Unique characteristics of join on expressions:
496            * There is no deduplication of the results.
497            * The left dataframe's columns go first and the right's come after. No ordering preference is given to join columns
498            """
499            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
500            if len(join_columns) > 1:
501                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
502            join_clause = join_columns[0]
503            select_column_names = [column.alias_or_name for column in self_columns + other_columns]
504
505        # Update the on expression with the actual join clause to replace the dummy one from before
506        join_expression.args["joins"][-1].set("on", join_clause.expression)
507        new_df = self.copy(expression=join_expression)
508        new_df.pending_join_hints.extend(self.pending_join_hints)
509        new_df.pending_hints.extend(other_df.pending_hints)
510        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
511        return new_df
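
A sketch of the column-name join mode described above (illustrative data). Joining on a name dedupes the key and moves it to the front of the select list; passing a Column expression as `on` would instead keep both sides' `id` columns, left side first:

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df1 = spark.createDataFrame([(1, "a")], ["id", "name"])
    df2 = spark.createDataFrame([(1, 100)], ["id", "score"])

    # Join on the column name: the result selects id, name, score with a
    # single, deduplicated id column.
    joined = df1.join(df2, on="id", how="inner")
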
@operation(Operation.ORDER_BY)
def orderBy( self, *cols: Union[str, sqlglot.dataframe.sql.Column], ascending: Union[Any, List[Any], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
513    @operation(Operation.ORDER_BY)
514    def orderBy(
515        self,
516        *cols: t.Union[str, Column],
517        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
518    ) -> DataFrame:
519        """
520        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
521        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
522        is unlikely to come up.
523        """
524        columns = self._ensure_and_normalize_cols(cols)
525        pre_ordered_col_indexes = [
526            x
527            for x in [
528                i if isinstance(col.expression, exp.Ordered) else None
529                for i, col in enumerate(columns)
530            ]
531            if x is not None
532        ]
533        if ascending is None:
534            ascending = [True] * len(columns)
535        elif not isinstance(ascending, list):
536            ascending = [ascending] * len(columns)
537        ascending = [bool(x) for x in ascending]
538        assert len(columns) == len(
539            ascending
540        ), "The length of items in ascending must equal the number of columns provided"
541        col_and_ascending = list(zip(columns, ascending))
542        order_by_columns = [
543            exp.Ordered(this=col.expression, desc=not asc)
544            if i not in pre_ordered_col_indexes
545            else columns[i].column_expression
546            for i, (col, asc) in enumerate(col_and_ascending)
547        ]
548        return self.copy(expression=self.expression.order_by(*order_by_columns))

This implementation lets any ordered columns take priority over whatever is provided in ascending. Spark's behavior here is irregular and can result in runtime errors. Users shouldn't mix the two anyway, so this is unlikely to come up.
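
A small sketch of both styles (illustrative data), though per the note above they shouldn't be mixed in one call:

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(2, "b"), (1, "a")], ["id", "name"])

    # Direction supplied per column via the ascending list...
    by_list = df.orderBy("name", "id", ascending=[True, False])
    # ...or baked into the column itself; pre-ordered columns take priority.
    by_col = df.orderBy(F.col("id").desc())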

@operation(Operation.ORDER_BY)
def sort( self, *cols: Union[str, sqlglot.dataframe.sql.Column], ascending: Union[Any, List[Any], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
513    @operation(Operation.ORDER_BY)
514    def orderBy(
515        self,
516        *cols: t.Union[str, Column],
517        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
518    ) -> DataFrame:
519        """
520        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
521        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
522        is unlikely to come up.
523        """
524        columns = self._ensure_and_normalize_cols(cols)
525        pre_ordered_col_indexes = [
526            x
527            for x in [
528                i if isinstance(col.expression, exp.Ordered) else None
529                for i, col in enumerate(columns)
530            ]
531            if x is not None
532        ]
533        if ascending is None:
534            ascending = [True] * len(columns)
535        elif not isinstance(ascending, list):
536            ascending = [ascending] * len(columns)
537        ascending = [bool(x) for x in ascending]
538        assert len(columns) == len(
539            ascending
540        ), "The length of items in ascending must equal the number of columns provided"
541        col_and_ascending = list(zip(columns, ascending))
542        order_by_columns = [
543            exp.Ordered(this=col.expression, desc=not asc)
544            if i not in pre_ordered_col_indexes
545            else columns[i].column_expression
546            for i, (col, asc) in enumerate(col_and_ascending)
547        ]
548        return self.copy(expression=self.expression.order_by(*order_by_columns))

This implementation lets any ordered columns take priority over whatever is provided in ascending. Spark's behavior here is irregular and can result in runtime errors. Users shouldn't mix the two anyway, so this is unlikely to come up.

@operation(Operation.FROM)
def union( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
552    @operation(Operation.FROM)
553    def union(self, other: DataFrame) -> DataFrame:
554        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionAll( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
552    @operation(Operation.FROM)
553    def union(self, other: DataFrame) -> DataFrame:
554        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionByName( self, other: sqlglot.dataframe.sql.DataFrame, allowMissingColumns: bool = False):
558    @operation(Operation.FROM)
559    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
560        l_columns = self.columns
561        r_columns = other.columns
562        if not allowMissingColumns:
563            l_expressions = l_columns
564            r_expressions = l_columns
565        else:
566            l_expressions = []
567            r_expressions = []
568            r_columns_unused = copy(r_columns)
569            for l_column in l_columns:
570                l_expressions.append(l_column)
571                if l_column in r_columns:
572                    r_expressions.append(l_column)
573                    r_columns_unused.remove(l_column)
574                else:
575                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
576            for r_column in r_columns_unused:
577                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
578                r_expressions.append(r_column)
579        r_df = (
580            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
581        )
582        l_df = self.copy()
583        if allowMissingColumns:
584            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
585        return l_df._set_operation(exp.Union, r_df, False)
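
A sketch of the NULL-padding behavior (illustrative data):

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df1 = spark.createDataFrame([(1, 2)], ["a", "b"])
    df2 = spark.createDataFrame([(3, 4)], ["b", "c"])

    # With allowMissingColumns=True the result has columns a, b, c; the side
    # missing a column contributes NULL for it.
    unioned = df1.unionByName(df2, allowMissingColumns=True)
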
@operation(Operation.FROM)
def intersect( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
587    @operation(Operation.FROM)
588    def intersect(self, other: DataFrame) -> DataFrame:
589        return self._set_operation(exp.Intersect, other, True)
@operation(Operation.FROM)
def intersectAll( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
591    @operation(Operation.FROM)
592    def intersectAll(self, other: DataFrame) -> DataFrame:
593        return self._set_operation(exp.Intersect, other, False)
@operation(Operation.FROM)
def exceptAll( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
595    @operation(Operation.FROM)
596    def exceptAll(self, other: DataFrame) -> DataFrame:
597        return self._set_operation(exp.Except, other, False)
@operation(Operation.SELECT)
def distinct(self) -> sqlglot.dataframe.sql.DataFrame:
599    @operation(Operation.SELECT)
600    def distinct(self) -> DataFrame:
601        return self.copy(expression=self.expression.distinct())
@operation(Operation.SELECT)
def dropDuplicates(self, subset: Optional[List[str]] = None):
603    @operation(Operation.SELECT)
604    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
605        if not subset:
606            return self.distinct()
607        column_names = ensure_list(subset)
608        window = Window.partitionBy(*column_names).orderBy(*column_names)
609        return (
610            self.copy()
611            .withColumn("row_num", F.row_number().over(window))
612            .where(F.col("row_num") == F.lit(1))
613            .drop("row_num")
614        )
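
A sketch of subset-based deduplication (illustrative data). Because it is implemented with row_number() over a window partitioned and ordered by the subset columns, which of the duplicate rows survives is arbitrary:

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(1, "a"), (1, "b"), (2, "c")], ["id", "name"])

    # Keeps exactly one row per distinct id.
    deduped = df.dropDuplicates(["id"])
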
@operation(Operation.FROM)
def dropna( self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
616    @operation(Operation.FROM)
617    def dropna(
618        self,
619        how: str = "any",
620        thresh: t.Optional[int] = None,
621        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
622    ) -> DataFrame:
623        minimum_non_null = thresh or 0  # will be determined later if thresh is null
624        new_df = self.copy()
625        all_columns = self._get_outer_select_columns(new_df.expression)
626        if subset:
627            null_check_columns = self._ensure_and_normalize_cols(subset)
628        else:
629            null_check_columns = all_columns
630        if thresh is None:
631            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
632        else:
633            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
634        if minimum_num_nulls > len(null_check_columns):
635            raise RuntimeError(
636                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
637                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
638            )
639        if_null_checks = [
640            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
641        ]
642        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
643        num_nulls = nulls_added_together.alias("num_nulls")
644        new_df = new_df.select(num_nulls, append=True)
645        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
646        final_df = filtered_df.select(*all_columns)
647        return final_df
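
A sketch of the three modes (illustrative data); `thresh` keeps rows with at least that many non-null values among the checked columns:

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(1, None), (None, None), (2, "b")], ["id", "name"])

    any_nulls = df.dropna(how="any")   # drop rows containing any null
    all_nulls = df.dropna(how="all")   # drop rows where every column is null
    thresh_1 = df.dropna(thresh=1)     # keep rows with >= 1 non-null value
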
@operation(Operation.FROM)
def fillna( self, value: ColumnLiterals, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
649    @operation(Operation.FROM)
650    def fillna(
651        self,
652        value: t.Union[ColumnLiterals],
653        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
654    ) -> DataFrame:
655        """
656        Functionality Difference: If you provide a value to replace a null and that type conflicts
657        with the type of the column then PySpark will just ignore your replacement.
658        This will try to cast them to be the same in some cases. So they won't always match.
659        Best to not mix types so make sure replacement is the same type as the column
660
661        Possibility for improvement: Use `typeof` function to get the type of the column
662        and check if it matches the type of the value provided. If not then make it null.
663        """
664        from sqlglot.dataframe.sql.functions import lit
665
666        values = None
667        columns = None
668        new_df = self.copy()
669        all_columns = self._get_outer_select_columns(new_df.expression)
670        all_column_mapping = {column.alias_or_name: column for column in all_columns}
671        if isinstance(value, dict):
672            values = list(value.values())
673            columns = self._ensure_and_normalize_cols(list(value))
674        if not columns:
675            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
676        if not values:
677            values = [value] * len(columns)
678        value_columns = [lit(value) for value in values]
679
680        null_replacement_mapping = {
681            column.alias_or_name: (
682                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
683            )
684            for column, value in zip(columns, value_columns)
685        }
686        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
687        null_replacement_columns = [
688            null_replacement_mapping[column.alias_or_name] for column in all_columns
689        ]
690        new_df = new_df.select(*null_replacement_columns)
691        return new_df

Functionality difference: if the replacement value's type conflicts with the column's type, PySpark simply ignores the replacement, whereas this implementation will try to cast them to match in some cases, so the two won't always agree. Best not to mix types: make sure the replacement is the same type as the column.

Possibility for improvement: use the typeof function to get the column's type and check whether it matches the type of the provided value; if not, substitute null.
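
A sketch of both calling conventions (illustrative data):

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(1, None), (2, "b")], ["id", "name"])

    one_value = df.fillna("unknown", subset=["name"])  # one value for the chosen columns
    per_column = df.fillna({"name": "unknown"})        # dict keys pick the columns;
                                                       # subset is ignored in this form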

@operation(Operation.FROM)
def replace( self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[Collection[ColumnOrName], ColumnOrName, NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
693    @operation(Operation.FROM)
694    def replace(
695        self,
696        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
697        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
698        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
699    ) -> DataFrame:
700        from sqlglot.dataframe.sql.functions import lit
701
702        old_values = None
703        new_df = self.copy()
704        all_columns = self._get_outer_select_columns(new_df.expression)
705        all_column_mapping = {column.alias_or_name: column for column in all_columns}
706
707        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
708        if isinstance(to_replace, dict):
709            old_values = list(to_replace)
710            new_values = list(to_replace.values())
711        elif not old_values and isinstance(to_replace, list):
712            assert isinstance(value, list), "value must be a list since the replacements are a list"
713            assert len(to_replace) == len(
714                value
715            ), "the replacements and values must be the same length"
716            old_values = to_replace
717            new_values = value
718        else:
719            old_values = [to_replace] * len(columns)
720            new_values = [value] * len(columns)
721        old_values = [lit(value) for value in old_values]
722        new_values = [lit(value) for value in new_values]
723
724        replacement_mapping = {}
725        for column in columns:
726            expression = Column(None)
727            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
728                if i == 0:
729                    expression = F.when(column == old_value, new_value)
730                else:
731                    expression = expression.when(column == old_value, new_value)  # type: ignore
732            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
733                column.expression.alias_or_name
734            )
735
736        replacement_mapping = {**all_column_mapping, **replacement_mapping}
737        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
738        new_df = new_df.select(*replacement_columns)
739        return new_df
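
A sketch of the three accepted shapes of `to_replace` (illustrative data); each compiles to a chained CASE WHEN expression per column:

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])

    scalar = df.replace("a", "z", subset=["name"])               # scalar -> scalar
    lists = df.replace(["a", "b"], ["x", "y"], subset=["name"])  # parallel lists
    mapping = df.replace({"a": "z"}, subset=["name"])            # dict of old -> new
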
@operation(Operation.SELECT)
def withColumn( self, colName: str, col: sqlglot.dataframe.sql.Column) -> sqlglot.dataframe.sql.DataFrame:
741    @operation(Operation.SELECT)
742    def withColumn(self, colName: str, col: Column) -> DataFrame:
743        col = self._ensure_and_normalize_col(col)
744        existing_col_names = self.expression.named_selects
745        existing_col_index = (
746            existing_col_names.index(colName) if colName in existing_col_names else None
747        )
748        if existing_col_index is not None:  # index 0 is a valid position, so compare against None
749            expression = self.expression.copy()
750            expression.expressions[existing_col_index] = col.expression
751            return self.copy(expression=expression)
752        return self.copy().select(col.alias(colName), append=True)
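
A sketch of the append vs. replace paths (illustrative data; `F.upper` is assumed to be available in sqlglot.dataframe.sql.functions):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(1, "a")], ["id", "name"])

    df = df.withColumn("id_plus_one", F.col("id") + 1)  # new name: column is appended
    df = df.withColumn("name", F.upper(F.col("name")))  # existing name: replaced in place
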
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
754    @operation(Operation.SELECT)
755    def withColumnRenamed(self, existing: str, new: str):
756        expression = self.expression.copy()
757        existing_columns = [
758            expression
759            for expression in expression.expressions
760            if expression.alias_or_name == existing
761        ]
762        if not existing_columns:
763            raise ValueError("Tried to rename a column that doesn't exist")
764        for existing_column in existing_columns:
765            if isinstance(existing_column, exp.Column):
766                existing_column.replace(exp.alias_(existing_column, new))
767            else:
768                existing_column.set("alias", exp.to_identifier(new))
769        return self.copy(expression=expression)
@operation(Operation.SELECT)
def drop( self, *cols: Union[str, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.DataFrame:
771    @operation(Operation.SELECT)
772    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
773        all_columns = self._get_outer_select_columns(self.expression)
774        drop_cols = self._ensure_and_normalize_cols(cols)
775        new_columns = [
776            col
777            for col in all_columns
778            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
779        ]
780        return self.copy().select(*new_columns, append=False)
@operation(Operation.LIMIT)
def limit(self, num: int) -> sqlglot.dataframe.sql.DataFrame:
782    @operation(Operation.LIMIT)
783    def limit(self, num: int) -> DataFrame:
784        return self.copy(expression=self.expression.limit(num))
@operation(Operation.NO_OP)
def hint( self, name: str, *parameters: Union[str, int, NoneType]) -> sqlglot.dataframe.sql.DataFrame:
786    @operation(Operation.NO_OP)
787    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
788        parameter_list = ensure_list(parameters)
789        parameter_columns = (
790            self._ensure_list_of_columns(parameter_list)
791            if parameters
792            else Column.ensure_cols([self.sequence_id])
793        )
794        return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
def repartition( self, numPartitions: Union[int, ColumnOrName], *cols: ColumnOrName) -> sqlglot.dataframe.sql.DataFrame:
796    @operation(Operation.NO_OP)
797    def repartition(
798        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
799    ) -> DataFrame:
800        num_partition_cols = self._ensure_list_of_columns(numPartitions)
801        columns = self._ensure_and_normalize_cols(cols)
802        args = num_partition_cols + columns
803        return self._hint("repartition", args)
@operation(Operation.NO_OP)
def coalesce(self, numPartitions: int) -> sqlglot.dataframe.sql.DataFrame:
805    @operation(Operation.NO_OP)
806    def coalesce(self, numPartitions: int) -> DataFrame:
807        num_partitions = Column.ensure_cols([numPartitions])
808        return self._hint("coalesce", num_partitions)
@operation(Operation.NO_OP)
def cache(self) -> sqlglot.dataframe.sql.DataFrame:
810    @operation(Operation.NO_OP)
811    def cache(self) -> DataFrame:
812        return self._cache(storage_level="MEMORY_AND_DISK")
@operation(Operation.NO_OP)
def persist( self, storageLevel: str = 'MEMORY_AND_DISK_SER') -> sqlglot.dataframe.sql.DataFrame:
814    @operation(Operation.NO_OP)
815    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
816        """
817        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
818        """
819        return self._cache(storageLevel)
class GroupedData:
14class GroupedData:
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
20
21    def _get_function_applied_columns(
22        self, func_name: str, cols: t.Tuple[str, ...]
23    ) -> t.List[Column]:
24        func_name = func_name.lower()
25        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
26
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
40
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
43
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
46
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
49
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
52
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
55
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
58
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Pivot is not currently implemented")
GroupedData( df: sqlglot.dataframe.sql.DataFrame, group_by_cols: List[sqlglot.dataframe.sql.Column], last_op: sqlglot.dataframe.sql.operations.Operation)
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
spark
last_op
group_by_cols
@operation(Operation.SELECT)
def agg( self, *exprs: Union[sqlglot.dataframe.sql.Column, Dict[str, str]]) -> sqlglot.dataframe.sql.DataFrame:
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
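
A sketch of the two forms agg accepts (illustrative data):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(1, "a"), (2, "a")], ["id", "grp"])

    dict_form = df.groupBy("grp").agg({"id": "max"})               # becomes MAX(id)
    col_form = df.groupBy("grp").agg(F.max("id").alias("max_id"))  # explicit Column
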
def count(self) -> sqlglot.dataframe.sql.DataFrame:
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
def mean(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
def avg(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
def max(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
def min(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
def sum(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
def pivot(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Pivot is not currently implemented")
class Column:
 17class Column:
 18    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
 19        if isinstance(expression, Column):
 20            expression = expression.expression  # type: ignore
 21        elif expression is None or not isinstance(expression, (str, exp.Expression)):
 22            expression = self._lit(expression).expression  # type: ignore
 23
 24        expression = sqlglot.maybe_parse(expression, dialect="spark")
 25        if expression is None:
 26            raise ValueError(f"Could not parse {expression}")
 27
 28        if isinstance(expression, exp.Column):
 29            expression.transform(Spark.normalize_identifier, copy=False)
 30
 31        self.expression: exp.Expression = expression
 32
 33    def __repr__(self):
 34        return repr(self.expression)
 35
 36    def __hash__(self):
 37        return hash(self.expression)
 38
 39    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 40        return self.binary_op(exp.EQ, other)
 41
 42    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 43        return self.binary_op(exp.NEQ, other)
 44
 45    def __gt__(self, other: ColumnOrLiteral) -> Column:
 46        return self.binary_op(exp.GT, other)
 47
 48    def __ge__(self, other: ColumnOrLiteral) -> Column:
 49        return self.binary_op(exp.GTE, other)
 50
 51    def __lt__(self, other: ColumnOrLiteral) -> Column:
 52        return self.binary_op(exp.LT, other)
 53
 54    def __le__(self, other: ColumnOrLiteral) -> Column:
 55        return self.binary_op(exp.LTE, other)
 56
 57    def __and__(self, other: ColumnOrLiteral) -> Column:
 58        return self.binary_op(exp.And, other)
 59
 60    def __or__(self, other: ColumnOrLiteral) -> Column:
 61        return self.binary_op(exp.Or, other)
 62
 63    def __mod__(self, other: ColumnOrLiteral) -> Column:
 64        return self.binary_op(exp.Mod, other)
 65
 66    def __add__(self, other: ColumnOrLiteral) -> Column:
 67        return self.binary_op(exp.Add, other)
 68
 69    def __sub__(self, other: ColumnOrLiteral) -> Column:
 70        return self.binary_op(exp.Sub, other)
 71
 72    def __mul__(self, other: ColumnOrLiteral) -> Column:
 73        return self.binary_op(exp.Mul, other)
 74
 75    def __truediv__(self, other: ColumnOrLiteral) -> Column:
 76        return self.binary_op(exp.Div, other)
 77
 78    def __div__(self, other: ColumnOrLiteral) -> Column:
 79        return self.binary_op(exp.Div, other)
 80
 81    def __neg__(self) -> Column:
 82        return self.unary_op(exp.Neg)
 83
 84    def __radd__(self, other: ColumnOrLiteral) -> Column:
 85        return self.inverse_binary_op(exp.Add, other)
 86
 87    def __rsub__(self, other: ColumnOrLiteral) -> Column:
 88        return self.inverse_binary_op(exp.Sub, other)
 89
 90    def __rmul__(self, other: ColumnOrLiteral) -> Column:
 91        return self.inverse_binary_op(exp.Mul, other)
 92
 93    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
 94        return self.inverse_binary_op(exp.Div, other)
 95
 96    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
 97        return self.inverse_binary_op(exp.Div, other)
 98
 99    def __rmod__(self, other: ColumnOrLiteral) -> Column:
100        return self.inverse_binary_op(exp.Mod, other)
101
102    def __pow__(self, power: ColumnOrLiteral, modulo=None):
103        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))
104
105    def __rpow__(self, power: ColumnOrLiteral):
106        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))
107
108    def __invert__(self):
109        return self.unary_op(exp.Not)
110
111    def __rand__(self, other: ColumnOrLiteral) -> Column:
112        return self.inverse_binary_op(exp.And, other)
113
114    def __ror__(self, other: ColumnOrLiteral) -> Column:
115        return self.inverse_binary_op(exp.Or, other)
116
117    @classmethod
118    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
119        return cls(value)
120
121    @classmethod
122    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
123        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
124
125    @classmethod
126    def _lit(cls, value: ColumnOrLiteral) -> Column:
127        if isinstance(value, dict):
128            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
129            return cls(exp.Struct(expressions=columns))
130        return cls(exp.convert(value))
131
132    @classmethod
133    def invoke_anonymous_function(
134        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
135    ) -> Column:
136        columns = [] if column is None else [cls.ensure_col(column)]
137        column_args = [cls.ensure_col(arg) for arg in args]
138        expressions = [x.expression for x in columns + column_args]
139        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
140        return Column(new_expression)
141
142    @classmethod
143    def invoke_expression_over_column(
144        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
145    ) -> Column:
146        ensured_column = None if column is None else cls.ensure_col(column)
147        ensure_expression_values = {
148            k: [Column.ensure_col(x).expression for x in v]
149            if is_iterable(v)
150            else Column.ensure_col(v).expression
151            for k, v in kwargs.items()
152            if v is not None
153        }
154        new_expression = (
155            callable_expression(**ensure_expression_values)
156            if ensured_column is None
157            else callable_expression(
158                this=ensured_column.column_expression, **ensure_expression_values
159            )
160        )
161        return Column(new_expression)
162
163    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
164        return Column(
165            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
166        )
167
168    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
169        return Column(
170            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
171        )
172
173    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
174        return Column(klass(this=self.column_expression, **kwargs))
175
176    @property
177    def is_alias(self):
178        return isinstance(self.expression, exp.Alias)
179
180    @property
181    def is_column(self):
182        return isinstance(self.expression, exp.Column)
183
184    @property
185    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
186        return self.expression.unalias()
187
188    @property
189    def alias_or_name(self) -> str:
190        return self.expression.alias_or_name
191
192    @classmethod
193    def ensure_literal(cls, value) -> Column:
194        from sqlglot.dataframe.sql.functions import lit
195
196        if isinstance(value, cls):
197            value = value.expression
198        if not isinstance(value, exp.Literal):
199            return lit(value)
200        return Column(value)
201
202    def copy(self) -> Column:
203        return Column(self.expression.copy())
204
205    def set_table_name(self, table_name: str, copy=False) -> Column:
206        expression = self.expression.copy() if copy else self.expression
207        expression.set("table", exp.to_identifier(table_name))
208        return Column(expression)
209
210    def sql(self, **kwargs) -> str:
211        return self.expression.sql(**{"dialect": "spark", **kwargs})
212
213    def alias(self, name: str) -> Column:
214        new_expression = exp.alias_(self.column_expression, name)
215        return Column(new_expression)
216
217    def asc(self) -> Column:
218        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
219        return Column(new_expression)
220
221    def desc(self) -> Column:
222        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
223        return Column(new_expression)
224
225    asc_nulls_first = asc
226
227    def asc_nulls_last(self) -> Column:
228        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
229        return Column(new_expression)
230
231    def desc_nulls_first(self) -> Column:
232        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
233        return Column(new_expression)
234
235    desc_nulls_last = desc
236
237    def when(self, condition: Column, value: t.Any) -> Column:
238        from sqlglot.dataframe.sql.functions import when
239
240        column_with_if = when(condition, value)
241        if not isinstance(self.expression, exp.Case):
242            return column_with_if
243        new_column = self.copy()
244        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
245        return new_column
246
247    def otherwise(self, value: t.Any) -> Column:
248        from sqlglot.dataframe.sql.functions import lit
249
250        true_value = value if isinstance(value, Column) else lit(value)
251        new_column = self.copy()
252        new_column.expression.set("default", true_value.column_expression)
253        return new_column
254
255    def isNull(self) -> Column:
256        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
257        return Column(new_expression)
258
259    def isNotNull(self) -> Column:
260        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
261        return Column(new_expression)
262
263    def cast(self, dataType: t.Union[str, DataType]):
264        """
265        Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
266        Sqlglot doesn't currently replicate this class so it only accepts a string
267        """
268        if isinstance(dataType, DataType):
269            dataType = dataType.simpleString()
270        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))
271
272    def startswith(self, value: t.Union[str, Column]) -> Column:
273        value = self._lit(value) if not isinstance(value, Column) else value
274        return self.invoke_anonymous_function(self, "STARTSWITH", value)
275
276    def endswith(self, value: t.Union[str, Column]) -> Column:
277        value = self._lit(value) if not isinstance(value, Column) else value
278        return self.invoke_anonymous_function(self, "ENDSWITH", value)
279
280    def rlike(self, regexp: str) -> Column:
281        return self.invoke_expression_over_column(
282            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
283        )
284
285    def like(self, other: str):
286        return self.invoke_expression_over_column(
287            self, exp.Like, expression=self._lit(other).expression
288        )
289
290    def ilike(self, other: str):
291        return self.invoke_expression_over_column(
292            self, exp.ILike, expression=self._lit(other).expression
293        )
294
295    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
296        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
297        length = self._lit(length) if not isinstance(length, Column) else length
298        return Column.invoke_expression_over_column(
299            self, exp.Substring, start=startPos.expression, length=length.expression
300        )
301
302    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
303        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
304        expressions = [self._lit(x).expression for x in columns]
305        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
306
307    def between(
308        self,
309        lowerBound: t.Union[ColumnOrLiteral],
310        upperBound: t.Union[ColumnOrLiteral],
311    ) -> Column:
312        lower_bound_exp = (
313            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
314        )
315        upper_bound_exp = (
316            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
317        )
318        return Column(
319            exp.Between(
320                this=self.column_expression,
321                low=lower_bound_exp.expression,
322                high=upper_bound_exp.expression,
323            )
324        )
325
326    def over(self, window: WindowSpec) -> Column:
327        window_expression = window.expression.copy()
328        window_expression.set("this", self.column_expression)
329        return Column(window_expression)
Column( expression: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType])
18    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
19        if isinstance(expression, Column):
20            expression = expression.expression  # type: ignore
21        elif expression is None or not isinstance(expression, (str, exp.Expression)):
22            expression = self._lit(expression).expression  # type: ignore
23
24        expression = sqlglot.maybe_parse(expression, dialect="spark")
25        if expression is None:
26            raise ValueError(f"Could not parse {expression}")
27
28        if isinstance(expression, exp.Column):
29            expression.transform(Spark.normalize_identifier, copy=False)
30
31        self.expression: exp.Expression = expression
@classmethod
def ensure_col( cls, value: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType]):
117    @classmethod
118    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
119        return cls(value)
@classmethod
def ensure_cols( cls, args: List[Union[ColumnOrLiteral, sqlglot.expressions.Expression]]) -> List[sqlglot.dataframe.sql.Column]:
121    @classmethod
122    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
123        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
@classmethod
def invoke_anonymous_function( cls, column: Optional[ColumnOrLiteral], func_name: str, *args: Optional[ColumnOrLiteral]) -> sqlglot.dataframe.sql.Column:
132    @classmethod
133    def invoke_anonymous_function(
134        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
135    ) -> Column:
136        columns = [] if column is None else [cls.ensure_col(column)]
137        column_args = [cls.ensure_col(arg) for arg in args]
138        expressions = [x.expression for x in columns + column_args]
139        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
140        return Column(new_expression)
@classmethod
def invoke_expression_over_column( cls, column: Optional[ColumnOrLiteral], callable_expression: Callable, **kwargs) -> sqlglot.dataframe.sql.Column:
142    @classmethod
143    def invoke_expression_over_column(
144        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
145    ) -> Column:
146        ensured_column = None if column is None else cls.ensure_col(column)
147        ensure_expression_values = {
148            k: [Column.ensure_col(x).expression for x in v]
149            if is_iterable(v)
150            else Column.ensure_col(v).expression
151            for k, v in kwargs.items()
152            if v is not None
153        }
154        new_expression = (
155            callable_expression(**ensure_expression_values)
156            if ensured_column is None
157            else callable_expression(
158                this=ensured_column.column_expression, **ensure_expression_values
159            )
160        )
161        return Column(new_expression)
def binary_op( self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> sqlglot.dataframe.sql.Column:
163    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
164        return Column(
165            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
166        )
def inverse_binary_op( self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> sqlglot.dataframe.sql.Column:
168    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
169        return Column(
170            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
171        )
def unary_op(self, klass: Callable, **kwargs) -> sqlglot.dataframe.sql.Column:
173    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
174        return Column(klass(this=self.column_expression, **kwargs))
is_alias
is_column
alias_or_name: str
@classmethod
def ensure_literal(cls, value) -> sqlglot.dataframe.sql.Column:
192    @classmethod
193    def ensure_literal(cls, value) -> Column:
194        from sqlglot.dataframe.sql.functions import lit
195
196        if isinstance(value, cls):
197            value = value.expression
198        if not isinstance(value, exp.Literal):
199            return lit(value)
200        return Column(value)
def copy(self) -> sqlglot.dataframe.sql.Column:
202    def copy(self) -> Column:
203        return Column(self.expression.copy())
def set_table_name(self, table_name: str, copy=False) -> sqlglot.dataframe.sql.Column:
205    def set_table_name(self, table_name: str, copy=False) -> Column:
206        expression = self.expression.copy() if copy else self.expression
207        expression.set("table", exp.to_identifier(table_name))
208        return Column(expression)
def sql(self, **kwargs) -> str:
210    def sql(self, **kwargs) -> str:
211        return self.expression.sql(**{"dialect": "spark", **kwargs})
def alias(self, name: str) -> sqlglot.dataframe.sql.Column:
213    def alias(self, name: str) -> Column:
214        new_expression = exp.alias_(self.column_expression, name)
215        return Column(new_expression)
def asc(self) -> sqlglot.dataframe.sql.Column:
217    def asc(self) -> Column:
218        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
219        return Column(new_expression)
def desc(self) -> sqlglot.dataframe.sql.Column:
221    def desc(self) -> Column:
222        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
223        return Column(new_expression)
def asc_nulls_first(self) -> sqlglot.dataframe.sql.Column:
217    def asc(self) -> Column:
218        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
219        return Column(new_expression)
def asc_nulls_last(self) -> sqlglot.dataframe.sql.Column:
227    def asc_nulls_last(self) -> Column:
228        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
229        return Column(new_expression)
def desc_nulls_first(self) -> sqlglot.dataframe.sql.Column:
231    def desc_nulls_first(self) -> Column:
232        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
233        return Column(new_expression)
def desc_nulls_last(self) -> sqlglot.dataframe.sql.Column:
221    def desc(self) -> Column:
222        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
223        return Column(new_expression)
def when( self, condition: sqlglot.dataframe.sql.Column, value: Any) -> sqlglot.dataframe.sql.Column:
237    def when(self, condition: Column, value: t.Any) -> Column:
238        from sqlglot.dataframe.sql.functions import when
239
240        column_with_if = when(condition, value)
241        if not isinstance(self.expression, exp.Case):
242            return column_with_if
243        new_column = self.copy()
244        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
245        return new_column
def otherwise(self, value: Any) -> sqlglot.dataframe.sql.Column:
247    def otherwise(self, value: t.Any) -> Column:
248        from sqlglot.dataframe.sql.functions import lit
249
250        true_value = value if isinstance(value, Column) else lit(value)
251        new_column = self.copy()
252        new_column.expression.set("default", true_value.column_expression)
253        return new_column
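
A sketch of building a CASE expression (illustrative data); chained when() calls extend the ifs list and otherwise() fills the default:

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(1,), (5,), (20,)], ["id"])

    size = (
        F.when(F.col("id") > 10, "big")
        .when(F.col("id") > 1, "medium")  # appended to the same CASE
        .otherwise("small")               # sets the ELSE branch
    )
    labeled = df.select(F.col("id"), size.alias("size"))
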
def isNull(self) -> sqlglot.dataframe.sql.Column:
255    def isNull(self) -> Column:
256        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
257        return Column(new_expression)
def isNotNull(self) -> sqlglot.dataframe.sql.Column:
259    def isNotNull(self) -> Column:
260        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
261        return Column(new_expression)
def cast(self, dataType: Union[str, sqlglot.dataframe.sql.types.DataType]):
263    def cast(self, dataType: t.Union[str, DataType]):
264        """
265        Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
266        Sqlglot doesn't currently replicate this class so it only accepts a string
267        """
268        if isinstance(dataType, DataType):
269            dataType = dataType.simpleString()
270        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))

Functionality difference: PySpark's cast accepts a DataType instance; sqlglot doesn't fully replicate this class, so a DataType is reduced to its simpleString() representation before casting.
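
A sketch of both accepted inputs, assuming sqlglot.dataframe.sql.types mirrors PySpark's type classes (e.g. StringType):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.types import StringType

    as_string = F.col("id").cast("string")      # type name as a string
    via_class = F.col("id").cast(StringType())  # DataType instance, reduced via simpleString()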

def startswith( self, value: Union[str, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.Column:
272    def startswith(self, value: t.Union[str, Column]) -> Column:
273        value = self._lit(value) if not isinstance(value, Column) else value
274        return self.invoke_anonymous_function(self, "STARTSWITH", value)
def endswith( self, value: Union[str, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.Column:
276    def endswith(self, value: t.Union[str, Column]) -> Column:
277        value = self._lit(value) if not isinstance(value, Column) else value
278        return self.invoke_anonymous_function(self, "ENDSWITH", value)
def rlike(self, regexp: str) -> sqlglot.dataframe.sql.Column:
280    def rlike(self, regexp: str) -> Column:
281        return self.invoke_expression_over_column(
282            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
283        )
def like(self, other: str):
285    def like(self, other: str):
286        return self.invoke_expression_over_column(
287            self, exp.Like, expression=self._lit(other).expression
288        )
def ilike(self, other: str):
290    def ilike(self, other: str):
291        return self.invoke_expression_over_column(
292            self, exp.ILike, expression=self._lit(other).expression
293        )
def substr( self, startPos: Union[int, sqlglot.dataframe.sql.Column], length: Union[int, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.Column:
295    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
296        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
297        length = self._lit(length) if not isinstance(length, Column) else length
298        return Column.invoke_expression_over_column(
299            self, exp.Substring, start=startPos.expression, length=length.expression
300        )
def isin( self, *cols: Union[ColumnOrLiteral, Iterable[ColumnOrLiteral]]):
302    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
303        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
304        expressions = [self._lit(x).expression for x in columns]
305        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
def between( self, lowerBound: ColumnOrLiteral, upperBound: ColumnOrLiteral) -> sqlglot.dataframe.sql.Column:
307    def between(
308        self,
309        lowerBound: t.Union[ColumnOrLiteral],
310        upperBound: t.Union[ColumnOrLiteral],
311    ) -> Column:
312        lower_bound_exp = (
313            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
314        )
315        upper_bound_exp = (
316            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
317        )
318        return Column(
319            exp.Between(
320                this=self.column_expression,
321                low=lower_bound_exp.expression,
322                high=upper_bound_exp.expression,
323            )
324        )
def over( self, window: sqlglot.dataframe.sql.WindowSpec) -> sqlglot.dataframe.sql.Column:
326    def over(self, window: WindowSpec) -> Column:
327        window_expression = window.expression.copy()
328        window_expression.set("this", self.column_expression)
329        return Column(window_expression)
class DataFrameNaFunctions:
822class DataFrameNaFunctions:
823    def __init__(self, df: DataFrame):
824        self.df = df
825
826    def drop(
827        self,
828        how: str = "any",
829        thresh: t.Optional[int] = None,
830        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
831    ) -> DataFrame:
832        return self.df.dropna(how=how, thresh=thresh, subset=subset)
833
834    def fill(
835        self,
836        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
837        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
838    ) -> DataFrame:
839        return self.df.fillna(value=value, subset=subset)
840
841    def replace(
842        self,
843        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
844        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
845        subset: t.Optional[t.Union[str, t.List[str]]] = None,
846    ) -> DataFrame:
847        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
DataFrameNaFunctions(df: sqlglot.dataframe.sql.DataFrame)
823    def __init__(self, df: DataFrame):
824        self.df = df
df
def drop( self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
826    def drop(
827        self,
828        how: str = "any",
829        thresh: t.Optional[int] = None,
830        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
831    ) -> DataFrame:
832        return self.df.dropna(how=how, thresh=thresh, subset=subset)
def fill( self, value: Union[int, bool, float, str, Dict[str, Any]], subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
834    def fill(
835        self,
836        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
837        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
838    ) -> DataFrame:
839        return self.df.fillna(value=value, subset=subset)
def replace( self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[str, List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
841    def replace(
842        self,
843        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
844        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
845        subset: t.Optional[t.Union[str, t.List[str]]] = None,
846    ) -> DataFrame:
847        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
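All three methods are thin delegations to the corresponding `DataFrame` methods (`dropna`, `fillna`, `replace`). A usage sketch, assuming `df` is an existing DataFrame and that, as in PySpark, `df.na` returns this class:

    cleaned = df.na.drop(how="all", subset=["name", "age"])     # -> df.dropna(...)
    filled = df.na.fill(0, subset=["age"])                      # -> df.fillna(...)
    swapped = df.na.replace("N/A", "unknown", subset=["name"])  # -> df.replace(...)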
class Window:
15class Window:
16    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
17    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
18    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
19    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)
20
21    unboundedPreceding: int = _JAVA_MIN_LONG
22
23    unboundedFollowing: int = _JAVA_MAX_LONG
24
25    currentRow: int = 0
26
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
30
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
34
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
38
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
unboundedPreceding: int = -9223372036854775808
unboundedFollowing: int = 9223372036854775807
currentRow: int = 0
@classmethod
def partitionBy(cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
@classmethod
def orderBy(cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
@classmethod
def rowsBetween(cls, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
@classmethod
def rangeBetween(cls, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
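The bound sentinels mirror PySpark: `unboundedPreceding` and `unboundedFollowing` are the Java long min/max, and `currentRow` is 0. Each classmethod simply starts a fresh `WindowSpec`, so specs can be built fluently:

    from sqlglot.dataframe.sql import Window

    # A running-total frame: from the start of the partition to the current row.
    spec = (
        Window.partitionBy("dept")
        .orderBy("salary")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )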
class WindowSpec:
 44class WindowSpec:
 45    def __init__(self, expression: exp.Expression = exp.Window()):
 46        self.expression = expression
 47
 48    def copy(self):
 49        return WindowSpec(self.expression.copy())
 50
 51    def sql(self, **kwargs) -> str:
 52        return self.expression.sql(dialect="spark", **kwargs)
 53
 54    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 55        from sqlglot.dataframe.sql.column import Column
 56
 57        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 58        expressions = [Column.ensure_col(x).expression for x in cols]
 59        window_spec = self.copy()
 60        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
 61        partition_by_expressions.extend(expressions)
 62        window_spec.expression.set("partition_by", partition_by_expressions)
 63        return window_spec
 64
 65    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 66        from sqlglot.dataframe.sql.column import Column
 67
 68        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 69        expressions = [Column.ensure_col(x).expression for x in cols]
 70        window_spec = self.copy()
 71        if window_spec.expression.args.get("order") is None:
 72            window_spec.expression.set("order", exp.Order(expressions=[]))
 73        order_by = window_spec.expression.args["order"].expressions
 74        order_by.extend(expressions)
 75        window_spec.expression.args["order"].set("expressions", order_by)
 76        return window_spec
 77
 78    def _calc_start_end(
 79        self, start: int, end: int
 80    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
 81        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
 82            "start_side": None,
 83            "end_side": None,
 84        }
 85        if start == Window.currentRow:
 86            kwargs["start"] = "CURRENT ROW"
 87        else:
 88            kwargs = {
 89                **kwargs,
 90                **{
 91                    "start_side": "PRECEDING",
 92                    "start": "UNBOUNDED"
 93                    if start <= Window.unboundedPreceding
 94                    else F.lit(start).expression,
 95                },
 96            }
 97        if end == Window.currentRow:
 98            kwargs["end"] = "CURRENT ROW"
 99        else:
100            kwargs = {
101                **kwargs,
102                **{
103                    "end_side": "FOLLOWING",
104                    "end": "UNBOUNDED"
105                    if end >= Window.unboundedFollowing
106                    else F.lit(end).expression,
107                },
108            }
109        return kwargs
110
111    def rowsBetween(self, start: int, end: int) -> WindowSpec:
112        window_spec = self.copy()
113        spec = self._calc_start_end(start, end)
114        spec["kind"] = "ROWS"
115        window_spec.expression.set(
116            "spec",
117            exp.WindowSpec(
118                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
119            ),
120        )
121        return window_spec
122
123    def rangeBetween(self, start: int, end: int) -> WindowSpec:
124        window_spec = self.copy()
125        spec = self._calc_start_end(start, end)
126        spec["kind"] = "RANGE"
127        window_spec.expression.set(
128            "spec",
129            exp.WindowSpec(
130                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
131            ),
132        )
133        return window_spec
WindowSpec(expression: sqlglot.expressions.Expression = exp.Window())
45    def __init__(self, expression: exp.Expression = exp.Window()):
46        self.expression = expression
expression
def copy(self):
48    def copy(self):
49        return WindowSpec(self.expression.copy())
def sql(self, **kwargs) -> str:
51    def sql(self, **kwargs) -> str:
52        return self.expression.sql(dialect="spark", **kwargs)
def partitionBy(self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
54    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
55        from sqlglot.dataframe.sql.column import Column
56
57        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
58        expressions = [Column.ensure_col(x).expression for x in cols]
59        window_spec = self.copy()
60        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
61        partition_by_expressions.extend(expressions)
62        window_spec.expression.set("partition_by", partition_by_expressions)
63        return window_spec
def orderBy(self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
65    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
66        from sqlglot.dataframe.sql.column import Column
67
68        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
69        expressions = [Column.ensure_col(x).expression for x in cols]
70        window_spec = self.copy()
71        if window_spec.expression.args.get("order") is None:
72            window_spec.expression.set("order", exp.Order(expressions=[]))
73        order_by = window_spec.expression.args["order"].expressions
74        order_by.extend(expressions)
75        window_spec.expression.args["order"].set("expressions", order_by)
76        return window_spec
def rowsBetween(self, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
111    def rowsBetween(self, start: int, end: int) -> WindowSpec:
112        window_spec = self.copy()
113        spec = self._calc_start_end(start, end)
114        spec["kind"] = "ROWS"
115        window_spec.expression.set(
116            "spec",
117            exp.WindowSpec(
118                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
119            ),
120        )
121        return window_spec
def rangeBetween(self, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
123    def rangeBetween(self, start: int, end: int) -> WindowSpec:
124        window_spec = self.copy()
125        spec = self._calc_start_end(start, end)
126        spec["kind"] = "RANGE"
127        window_spec.expression.set(
128            "spec",
129            exp.WindowSpec(
130                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
131            ),
132        )
133        return window_spec
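`_calc_start_end` maps the sentinels to `UNBOUNDED PRECEDING` / `UNBOUNDED FOLLOWING`, maps `currentRow` to `CURRENT ROW`, and turns any other integer into a literal offset; `rowsBetween` and `rangeBetween` differ only in the frame `kind`. A sketch of the rendered SQL (exact formatting may differ):

    from sqlglot.dataframe.sql import Window

    spec = (
        Window.partitionBy("dept")
        .orderBy("hired_at")
        .rangeBetween(Window.currentRow, Window.unboundedFollowing)
    )

    # Roughly: OVER (PARTITION BY dept ORDER BY hired_at
    #                RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
    print(spec.sql())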
class DataFrameReader:
16class DataFrameReader:
17    def __init__(self, spark: SparkSession):
18        self.spark = spark
19
20    def table(self, tableName: str) -> DataFrame:
21        from sqlglot.dataframe.sql.dataframe import DataFrame
22
23        sqlglot.schema.add_table(tableName, dialect="spark")
24
25        return DataFrame(
26            self.spark,
27            exp.Select()
28            .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
29            .select(
30                *(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
31            ),
32        )
DataFrameReader(spark: sqlglot.dataframe.sql.SparkSession)
17    def __init__(self, spark: SparkSession):
18        self.spark = spark
spark
def table(self, tableName: str) -> sqlglot.dataframe.sql.DataFrame:
20    def table(self, tableName: str) -> DataFrame:
21        from sqlglot.dataframe.sql.dataframe import DataFrame
22
23        sqlglot.schema.add_table(tableName, dialect="spark")
24
25        return DataFrame(
26            self.spark,
27            exp.Select()
28            .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
29            .select(
30                *(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
31            ),
32        )
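`table` registers the name with `sqlglot.schema` and selects the table's known columns, so registering a column mapping beforehand yields an explicit column list. A sketch (the schema entries and printed output are illustrative):

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession

    # Register columns so the generated SELECT can enumerate them.
    sqlglot.schema.add_table("employees", {"id": "INT", "name": "STRING"}, dialect="spark")

    spark = SparkSession()
    df = spark.read.table("employees")
    print(df.sql())  # roughly: ['SELECT id, name FROM employees']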
class DataFrameWriter:
35class DataFrameWriter:
36    def __init__(
37        self,
38        df: DataFrame,
39        spark: t.Optional[SparkSession] = None,
40        mode: t.Optional[str] = None,
41        by_name: bool = False,
42    ):
43        self._df = df
44        self._spark = spark or df.spark
45        self._mode = mode
46        self._by_name = by_name
47
48    def copy(self, **kwargs) -> DataFrameWriter:
49        return DataFrameWriter(
50            **{
51                k[1:] if k.startswith("_") else k: v
52                for k, v in object_to_dict(self, **kwargs).items()
53            }
54        )
55
56    def sql(self, **kwargs) -> t.List[str]:
57        return self._df.sql(**kwargs)
58
59    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
60        return self.copy(_mode=saveMode)
61
62    @property
63    def byName(self):
64        return self.copy(by_name=True)
65
66    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
67        output_expression_container = exp.Insert(
68            **{
69                "this": exp.to_table(tableName),
70                "overwrite": overwrite,
71            }
72        )
73        df = self._df.copy(output_expression_container=output_expression_container)
74        if self._by_name:
75            columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
76            df = df._convert_leaf_to_cte().select(*columns)
77
78        return self.copy(_df=df)
79
80    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
81        if format is not None:
82            raise NotImplementedError("Providing Format in the save as table is not supported")
83        exists, replace, mode = None, None, mode or str(self._mode)
84        if mode == "append":
85            return self.insertInto(name)
86        if mode == "ignore":
87            exists = True
88        if mode == "overwrite":
89            replace = True
90        output_expression_container = exp.Create(
91            this=exp.to_table(name),
92            kind="TABLE",
93            exists=exists,
94            replace=replace,
95        )
96        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
DataFrameWriter( df: sqlglot.dataframe.sql.DataFrame, spark: Optional[sqlglot.dataframe.sql.SparkSession] = None, mode: Optional[str] = None, by_name: bool = False)
36    def __init__(
37        self,
38        df: DataFrame,
39        spark: t.Optional[SparkSession] = None,
40        mode: t.Optional[str] = None,
41        by_name: bool = False,
42    ):
43        self._df = df
44        self._spark = spark or df.spark
45        self._mode = mode
46        self._by_name = by_name
def copy(self, **kwargs) -> sqlglot.dataframe.sql.DataFrameWriter:
48    def copy(self, **kwargs) -> DataFrameWriter:
49        return DataFrameWriter(
50            **{
51                k[1:] if k.startswith("_") else k: v
52                for k, v in object_to_dict(self, **kwargs).items()
53            }
54        )
def sql(self, **kwargs) -> List[str]:
56    def sql(self, **kwargs) -> t.List[str]:
57        return self._df.sql(**kwargs)
def mode( self, saveMode: Optional[str]) -> sqlglot.dataframe.sql.DataFrameWriter:
59    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
60        return self.copy(_mode=saveMode)
byName
def insertInto( self, tableName: str, overwrite: Optional[bool] = None) -> sqlglot.dataframe.sql.DataFrameWriter:
66    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
67        output_expression_container = exp.Insert(
68            **{
69                "this": exp.to_table(tableName),
70                "overwrite": overwrite,
71            }
72        )
73        df = self._df.copy(output_expression_container=output_expression_container)
74        if self._by_name:
75            columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
76            df = df._convert_leaf_to_cte().select(*columns)
77
78        return self.copy(_df=df)
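`insertInto` swaps the DataFrame's output container for an `exp.Insert`; when `byName` was used, the final projection is rebuilt from the target table's visible columns so column order no longer matters. For illustration, assuming `df` was built with this module and the target table is registered in `sqlglot.schema`:

    writer = df.write.byName.insertInto("db.employees")
    print(writer.sql())  # the INSERT statement(s) as a list of strings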
def saveAsTable( self, name: str, format: Optional[str] = None, mode: Optional[str] = None):
80    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
81        if format is not None:
82            raise NotImplementedError("Providing Format in the save as table is not supported")
83        exists, replace, mode = None, None, mode or str(self._mode)
84        if mode == "append":
85            return self.insertInto(name)
86        if mode == "ignore":
87            exists = True
88        if mode == "overwrite":
89            replace = True
90        output_expression_container = exp.Create(
91            this=exp.to_table(name),
92            kind="TABLE",
93            exists=exists,
94            replace=replace,
95        )
96        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
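`saveAsTable` translates write modes into DDL: `append` is routed to `insertInto`, `ignore` adds `IF NOT EXISTS`, and `overwrite` produces a `CREATE OR REPLACE TABLE`. A final sketch under the same assumptions:

    # "overwrite" maps to CREATE OR REPLACE TABLE ... AS SELECT ...
    for statement in df.write.mode("overwrite").saveAsTable("db.employees").sql():
        print(statement)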