
sqlglot.dataframe.sql

 1from sqlglot.dataframe.sql.column import Column
 2from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
 3from sqlglot.dataframe.sql.group import GroupedData
 4from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
 5from sqlglot.dataframe.sql.session import SparkSession
 6from sqlglot.dataframe.sql.window import Window, WindowSpec
 7
 8__all__ = [
 9    "SparkSession",
10    "DataFrame",
11    "GroupedData",
12    "Column",
13    "DataFrameNaFunctions",
14    "Window",
15    "WindowSpec",
16    "DataFrameReader",
17    "DataFrameWriter",
18]
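This package exposes a PySpark-like DataFrame API that generates SQL instead of executing it. A minimal end-to-end sketch (the employee table and its columns are invented for illustration): register the table's column types with sqlglot.schema so references can be resolved, build the query with DataFrame calls, then ask the DataFrame for its SQL.

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession
    from sqlglot.dataframe.sql import functions as F

    # Register column types so the optimizer can resolve and qualify columns.
    sqlglot.schema.add_table(
        "employee",
        {"employee_id": "INT", "fname": "STRING", "lname": "STRING", "age": "INT"},
        dialect="spark",
    )

    spark = SparkSession()
    df = (
        spark.table("employee")
        .groupBy(F.col("age"))
        .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
    )

    # sql() returns a list of SQL strings, one per output statement.
    print(df.sql(pretty=True)[0])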
class SparkSession:
 21class SparkSession:
 22    DEFAULT_DIALECT = "spark"
 23    _instance = None
 24
 25    def __init__(self):
 26        if not hasattr(self, "known_ids"):
 27            self.known_ids = set()
 28            self.known_branch_ids = set()
 29            self.known_sequence_ids = set()
 30            self.name_to_sequence_id_mapping = defaultdict(list)
 31            self.incrementing_id = 1
 32            self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)()
 33
 34    def __new__(cls, *args, **kwargs) -> SparkSession:
 35        if cls._instance is None:
 36            cls._instance = super().__new__(cls)
 37        return cls._instance
 38
 39    @property
 40    def read(self) -> DataFrameReader:
 41        return DataFrameReader(self)
 42
 43    def table(self, tableName: str) -> DataFrame:
 44        return self.read.table(tableName)
 45
 46    def createDataFrame(
 47        self,
 48        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
 49        schema: t.Optional[SchemaInput] = None,
 50        samplingRatio: t.Optional[float] = None,
 51        verifySchema: bool = False,
 52    ) -> DataFrame:
 53        from sqlglot.dataframe.sql.dataframe import DataFrame
 54
 55        if samplingRatio is not None or verifySchema:
 56            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
 57        if schema is not None and (
 58            not isinstance(schema, (StructType, str, list))
 59            or (isinstance(schema, list) and not isinstance(schema[0], str))
 60        ):
 61            raise NotImplementedError("Only schema of either list or string of list supported")
 62        if not data:
 63            raise ValueError("Must provide data to create into a DataFrame")
 64
 65        column_mapping: t.Dict[str, t.Optional[str]]
 66        if schema is not None:
 67            column_mapping = get_column_mapping_from_schema_input(schema)
 68        elif isinstance(data[0], dict):
 69            column_mapping = {col_name.strip(): None for col_name in data[0]}
 70        else:
 71            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}
 72
 73        data_expressions = [
 74            exp.Tuple(
 75                expressions=list(
 76                    map(
 77                        lambda x: F.lit(x).expression,
 78                        row if not isinstance(row, dict) else row.values(),
 79                    )
 80                )
 81            )
 82            for row in data
 83        ]
 84
 85        sel_columns = [
 86            F.col(name).cast(data_type).alias(name).expression
 87            if data_type is not None
 88            else F.col(name).expression
 89            for name, data_type in column_mapping.items()
 90        ]
 91
 92        select_kwargs = {
 93            "expressions": sel_columns,
 94            "from": exp.From(
 95                this=exp.Values(
 96                    expressions=data_expressions,
 97                    alias=exp.TableAlias(
 98                        this=exp.to_identifier(self._auto_incrementing_name),
 99                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
100                    ),
101                ),
102            ),
103        }
104
105        sel_expression = exp.Select(**select_kwargs)
106        return DataFrame(self, sel_expression)
107
108    def sql(self, sqlQuery: str) -> DataFrame:
109        expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
110        if isinstance(expression, exp.Select):
111            df = DataFrame(self, expression)
112            df = df._convert_leaf_to_cte()
113        elif isinstance(expression, (exp.Create, exp.Insert)):
114            select_expression = expression.expression.copy()
115            if isinstance(expression, exp.Insert):
116                select_expression.set("with", expression.args.get("with"))
117                expression.set("with", None)
118            del expression.args["expression"]
119            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
120            df = df._convert_leaf_to_cte()
121        else:
122            raise ValueError(
123                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
124            )
125        return df
126
127    @property
128    def _auto_incrementing_name(self) -> str:
129        name = f"a{self.incrementing_id}"
130        self.incrementing_id += 1
131        return name
132
133    @property
134    def _random_branch_id(self) -> str:
135        id = self._random_id
136        self.known_branch_ids.add(id)
137        return id
138
139    @property
140    def _random_sequence_id(self):
141        id = self._random_id
142        self.known_sequence_ids.add(id)
143        return id
144
145    @property
146    def _random_id(self) -> str:
147        id = "r" + uuid.uuid4().hex
148        self.known_ids.add(id)
149        return id
150
151    @property
152    def _join_hint_names(self) -> t.Set[str]:
153        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}
154
155    def _add_alias_to_mapping(self, name: str, sequence_id: str):
156        self.name_to_sequence_id_mapping[name].append(sequence_id)
157
158    class Builder:
159        SQLFRAME_DIALECT_KEY = "sqlframe.dialect"
160
161        def __init__(self):
162            self.dialect = "spark"
163
164        def __getattr__(self, item) -> SparkSession.Builder:
165            return self
166
167        def __call__(self, *args, **kwargs):
168            return self
169
170        def config(
171            self,
172            key: t.Optional[str] = None,
173            value: t.Optional[t.Any] = None,
174            *,
175            map: t.Optional[t.Dict[str, t.Any]] = None,
176            **kwargs: t.Any,
177        ) -> SparkSession.Builder:
178            if key == self.SQLFRAME_DIALECT_KEY:
179                self.dialect = value
180            elif map and self.SQLFRAME_DIALECT_KEY in map:
181                self.dialect = map[self.SQLFRAME_DIALECT_KEY]
182            return self
183
184        def getOrCreate(self) -> SparkSession:
185            spark = SparkSession()
186            spark.dialect = Dialect.get_or_raise(self.dialect)()
187            return spark
188
189    @classproperty
190    def builder(cls) -> Builder:
191        return cls.Builder()
DEFAULT_DIALECT = 'spark'
def table(self, tableName: str) -> DataFrame:
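A thin wrapper around read.table. A small sketch, assuming employee was registered with sqlglot.schema.add_table as in the example at the top of the page:

    from sqlglot.dataframe.sql import SparkSession

    spark = SparkSession()
    df = spark.table("employee")  # equivalent to spark.read.table("employee")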
def createDataFrame( self, data: Sequence[Union[Dict[str, ColumnLiterals], List[ColumnLiterals], Tuple]], schema: Optional[SchemaInput] = None, samplingRatio: Optional[float] = None, verifySchema: bool = False) -> DataFrame:
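A sketch of the two most common inputs: rows as dicts (column names come from the keys of the first row) and rows as tuples with a list of column names. Values and names are illustrative.

    from sqlglot.dataframe.sql import SparkSession

    spark = SparkSession()

    # Column names inferred from the keys of the first dict.
    df = spark.createDataFrame([{"cola": 1, "colb": "x"}, {"cola": 2, "colb": "y"}])

    # Or pass tuples together with a list of column names.
    df2 = spark.createDataFrame([(1, "x"), (2, "y")], schema=["cola", "colb"])

    # The generated statement is a SELECT over a VALUES clause.
    print(df.sql(pretty=True)[0])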
def sql(self, sqlQuery: str) -> DataFrame:
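sql turns an existing SELECT (or a CREATE/INSERT that wraps a SELECT) into a DataFrame so it can be transformed further. A sketch, assuming the referenced table is registered in sqlglot.schema; the table and columns are invented:

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession
    from sqlglot.dataframe.sql import functions as F

    sqlglot.schema.add_table("employee", {"employee_id": "INT", "age": "INT"}, dialect="spark")

    spark = SparkSession()
    df = spark.sql("SELECT employee_id, age FROM employee").where(F.col("age") > 30)
    print(df.sql(pretty=True)[0])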
class SparkSession.Builder:
SQLFRAME_DIALECT_KEY = 'sqlframe.dialect'
dialect
def config( self, key: Optional[str] = None, value: Optional[Any] = None, *, map: Optional[Dict[str, Any]] = None, **kwargs: Any) -> SparkSession.Builder:
def getOrCreate(self) -> SparkSession:
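The builder accepts (and silently ignores) the usual PySpark builder calls; the only setting it reads is the sqlframe.dialect key, which becomes the dialect of the generated SQL. A sketch:

    from sqlglot.dataframe.sql import SparkSession

    spark = (
        SparkSession.builder
        .appName("ignored")                       # accepted but has no effect
        .config("sqlframe.dialect", "bigquery")   # dialect used when generating SQL
        .getOrCreate()
    )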
class DataFrame:
 49class DataFrame:
 50    def __init__(
 51        self,
 52        spark: SparkSession,
 53        expression: exp.Select,
 54        branch_id: t.Optional[str] = None,
 55        sequence_id: t.Optional[str] = None,
 56        last_op: Operation = Operation.INIT,
 57        pending_hints: t.Optional[t.List[exp.Expression]] = None,
 58        output_expression_container: t.Optional[OutputExpressionContainer] = None,
 59        **kwargs,
 60    ):
 61        self.spark = spark
 62        self.expression = expression
 63        self.branch_id = branch_id or self.spark._random_branch_id
 64        self.sequence_id = sequence_id or self.spark._random_sequence_id
 65        self.last_op = last_op
 66        self.pending_hints = pending_hints or []
 67        self.output_expression_container = output_expression_container or exp.Select()
 68
 69    def __getattr__(self, column_name: str) -> Column:
 70        return self[column_name]
 71
 72    def __getitem__(self, column_name: str) -> Column:
 73        column_name = f"{self.branch_id}.{column_name}"
 74        return Column(column_name)
 75
 76    def __copy__(self):
 77        return self.copy()
 78
 79    @property
 80    def sparkSession(self):
 81        return self.spark
 82
 83    @property
 84    def write(self):
 85        return DataFrameWriter(self)
 86
 87    @property
 88    def latest_cte_name(self) -> str:
 89        if not self.expression.ctes:
 90            from_exp = self.expression.args["from"]
 91            if from_exp.alias_or_name:
 92                return from_exp.alias_or_name
 93            table_alias = from_exp.find(exp.TableAlias)
 94            if not table_alias:
 95                raise RuntimeError(
 96                    f"Could not find an alias name for this expression: {self.expression}"
 97                )
 98            return table_alias.alias_or_name
 99        return self.expression.ctes[-1].alias
100
101    @property
102    def pending_join_hints(self):
103        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
104
105    @property
106    def pending_partition_hints(self):
107        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
108
109    @property
110    def columns(self) -> t.List[str]:
111        return self.expression.named_selects
112
113    @property
114    def na(self) -> DataFrameNaFunctions:
115        return DataFrameNaFunctions(self)
116
117    def _replace_cte_names_with_hashes(self, expression: exp.Select):
118        replacement_mapping = {}
119        for cte in expression.ctes:
120            old_name_id = cte.args["alias"].this
121            new_hashed_id = exp.to_identifier(
122                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
123            )
124            replacement_mapping[old_name_id] = new_hashed_id
125            expression = expression.transform(replace_id_value, replacement_mapping)
126        return expression
127
128    def _create_cte_from_expression(
129        self,
130        expression: exp.Expression,
131        branch_id: t.Optional[str] = None,
132        sequence_id: t.Optional[str] = None,
133        **kwargs,
134    ) -> t.Tuple[exp.CTE, str]:
135        name = self._create_hash_from_expression(expression)
136        expression_to_cte = expression.copy()
137        expression_to_cte.set("with", None)
138        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
139        cte.set("branch_id", branch_id or self.branch_id)
140        cte.set("sequence_id", sequence_id or self.sequence_id)
141        return cte, name
142
143    @t.overload
144    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
145        ...
146
147    @t.overload
148    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
149        ...
150
151    def _ensure_list_of_columns(self, cols):
152        return Column.ensure_cols(ensure_list(cols))
153
154    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
155        cols = self._ensure_list_of_columns(cols)
156        normalize(self.spark, expression or self.expression, cols)
157        return cols
158
159    def _ensure_and_normalize_col(self, col):
160        col = Column.ensure_col(col)
161        normalize(self.spark, self.expression, col)
162        return col
163
164    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
165        df = self._resolve_pending_hints()
166        sequence_id = sequence_id or df.sequence_id
167        expression = df.expression.copy()
168        cte_expression, cte_name = df._create_cte_from_expression(
169            expression=expression, sequence_id=sequence_id
170        )
171        new_expression = df._add_ctes_to_expression(
172            exp.Select(), expression.ctes + [cte_expression]
173        )
174        sel_columns = df._get_outer_select_columns(cte_expression)
175        new_expression = new_expression.from_(cte_name).select(
176            *[x.alias_or_name for x in sel_columns]
177        )
178        return df.copy(expression=new_expression, sequence_id=sequence_id)
179
180    def _resolve_pending_hints(self) -> DataFrame:
181        df = self.copy()
182        if not self.pending_hints:
183            return df
184        expression = df.expression
185        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
186        for hint in df.pending_partition_hints:
187            hint_expression.append("expressions", hint)
188            df.pending_hints.remove(hint)
189
190        join_aliases = {
191            join_table.alias_or_name
192            for join_table in get_tables_from_expression_with_join(expression)
193        }
194        if join_aliases:
195            for hint in df.pending_join_hints:
196                for sequence_id_expression in hint.expressions:
197                    sequence_id_or_name = sequence_id_expression.alias_or_name
198                    sequence_ids_to_match = [sequence_id_or_name]
199                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
200                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
201                            sequence_id_or_name
202                        ]
203                    matching_ctes = [
204                        cte
205                        for cte in reversed(expression.ctes)
206                        if cte.args["sequence_id"] in sequence_ids_to_match
207                    ]
208                    for matching_cte in matching_ctes:
209                        if matching_cte.alias_or_name in join_aliases:
210                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
211                            df.pending_hints.remove(hint)
212                            break
213                hint_expression.append("expressions", hint)
214        if hint_expression.expressions:
215            expression.set("hint", hint_expression)
216        return df
217
218    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
219        hint_name = hint_name.upper()
220        hint_expression = (
221            exp.JoinHint(
222                this=hint_name,
223                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
224            )
225            if hint_name in JOIN_HINTS
226            else exp.Anonymous(
227                this=hint_name, expressions=[parameter.expression for parameter in args]
228            )
229        )
230        new_df = self.copy()
231        new_df.pending_hints.append(hint_expression)
232        return new_df
233
234    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
235        other_df = other._convert_leaf_to_cte()
236        base_expression = self.expression.copy()
237        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
238        all_ctes = base_expression.ctes
239        other_df.expression.set("with", None)
240        base_expression.set("with", None)
241        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
242        operation.set("with", exp.With(expressions=all_ctes))
243        return self.copy(expression=operation)._convert_leaf_to_cte()
244
245    def _cache(self, storage_level: str):
246        df = self._convert_leaf_to_cte()
247        df.expression.ctes[-1].set("cache_storage_level", storage_level)
248        return df
249
250    @classmethod
251    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
252        expression = expression.copy()
253        with_expression = expression.args.get("with")
254        if with_expression:
255            existing_ctes = with_expression.expressions
256            existsing_cte_names = {x.alias_or_name for x in existing_ctes}
257            for cte in ctes:
258                if cte.alias_or_name not in existsing_cte_names:
259                    existing_ctes.append(cte)
260        else:
261            existing_ctes = ctes
262        expression.set("with", exp.With(expressions=existing_ctes))
263        return expression
264
265    @classmethod
266    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
267        expression = item.expression if isinstance(item, DataFrame) else item
268        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]
269
270    @classmethod
271    def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
272        from sqlglot.dataframe.sql.session import SparkSession
273
274        value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
275        return f"t{zlib.crc32(value)}"[:6]
276
277    def _get_select_expressions(
278        self,
279    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
280        select_expressions: t.List[
281            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
282        ] = []
283        main_select_ctes: t.List[exp.CTE] = []
284        for cte in self.expression.ctes:
285            cache_storage_level = cte.args.get("cache_storage_level")
286            if cache_storage_level:
287                select_expression = cte.this.copy()
288                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
289                select_expression.set("cte_alias_name", cte.alias_or_name)
290                select_expression.set("cache_storage_level", cache_storage_level)
291                select_expressions.append((exp.Cache, select_expression))
292            else:
293                main_select_ctes.append(cte)
294        main_select = self.expression.copy()
295        if main_select_ctes:
296            main_select.set("with", exp.With(expressions=main_select_ctes))
297        expression_select_pair = (type(self.output_expression_container), main_select)
298        select_expressions.append(expression_select_pair)  # type: ignore
299        return select_expressions
300
301    def sql(
302        self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
303    ) -> t.List[str]:
304        from sqlglot.dataframe.sql.session import SparkSession
305
306        if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
307            logger.warning(
308                f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern."
309            )
310        df = self._resolve_pending_hints()
311        select_expressions = df._get_select_expressions()
312        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
313        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
314        for expression_type, select_expression in select_expressions:
315            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
316            if optimize:
317                quote_identifiers(select_expression)
318                select_expression = t.cast(
319                    exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
320                )
321            select_expression = df._replace_cte_names_with_hashes(select_expression)
322            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
323            if expression_type == exp.Cache:
324                cache_table_name = df._create_hash_from_expression(select_expression)
325                cache_table = exp.to_table(cache_table_name)
326                original_alias_name = select_expression.args["cte_alias_name"]
327
328                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
329                    cache_table_name
330                )
331                sqlglot.schema.add_table(
332                    cache_table_name,
333                    {
334                        expression.alias_or_name: expression.type.sql(
335                            dialect=SparkSession().dialect
336                        )
337                        for expression in select_expression.expressions
338                    },
339                    dialect=SparkSession().dialect,
340                )
341                cache_storage_level = select_expression.args["cache_storage_level"]
342                options = [
343                    exp.Literal.string("storageLevel"),
344                    exp.Literal.string(cache_storage_level),
345                ]
346                expression = exp.Cache(
347                    this=cache_table, expression=select_expression, lazy=True, options=options
348                )
349                # We will drop the "view" if it exists before running the cache table
350                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
351            elif expression_type == exp.Create:
352                expression = df.output_expression_container.copy()
353                expression.set("expression", select_expression)
354            elif expression_type == exp.Insert:
355                expression = df.output_expression_container.copy()
356                select_without_ctes = select_expression.copy()
357                select_without_ctes.set("with", None)
358                expression.set("expression", select_without_ctes)
359                if select_expression.ctes:
360                    expression.set("with", exp.With(expressions=select_expression.ctes))
361            elif expression_type == exp.Select:
362                expression = select_expression
363            else:
364                raise ValueError(f"Invalid expression type: {expression_type}")
365            output_expressions.append(expression)
366
367        return [
368            expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
369            for expression in output_expressions
370        ]
371
372    def copy(self, **kwargs) -> DataFrame:
373        return DataFrame(**object_to_dict(self, **kwargs))
374
375    @operation(Operation.SELECT)
376    def select(self, *cols, **kwargs) -> DataFrame:
377        cols = self._ensure_and_normalize_cols(cols)
378        kwargs["append"] = kwargs.get("append", False)
379        if self.expression.args.get("joins"):
380            ambiguous_cols = [
381                col
382                for col in cols
383                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
384            ]
385            if ambiguous_cols:
386                join_table_identifiers = [
387                    x.this for x in get_tables_from_expression_with_join(self.expression)
388                ]
389                cte_names_in_join = [x.this for x in join_table_identifiers]
390                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
391                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
392                # of Spark.
393                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
394                for ambiguous_col in ambiguous_cols:
395                    ctes_with_column = [
396                        cte
397                        for cte in self.expression.ctes
398                        if cte.alias_or_name in cte_names_in_join
399                        and ambiguous_col.alias_or_name in cte.this.named_selects
400                    ]
401                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
402                    # use the same CTE we used before
403                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
404                    if cte:
405                        resolved_column_position[ambiguous_col] += 1
406                    else:
407                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
408                    ambiguous_col.expression.set("table", cte.alias_or_name)
409        return self.copy(
410            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
411        )
412
413    @operation(Operation.NO_OP)
414    def alias(self, name: str, **kwargs) -> DataFrame:
415        new_sequence_id = self.spark._random_sequence_id
416        df = self.copy()
417        for join_hint in df.pending_join_hints:
418            for expression in join_hint.expressions:
419                if expression.alias_or_name == self.sequence_id:
420                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
421        df.spark._add_alias_to_mapping(name, new_sequence_id)
422        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
423
424    @operation(Operation.WHERE)
425    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
426        col = self._ensure_and_normalize_col(column)
427        return self.copy(expression=self.expression.where(col.expression))
428
429    filter = where
430
431    @operation(Operation.GROUP_BY)
432    def groupBy(self, *cols, **kwargs) -> GroupedData:
433        columns = self._ensure_and_normalize_cols(cols)
434        return GroupedData(self, columns, self.last_op)
435
436    @operation(Operation.SELECT)
437    def agg(self, *exprs, **kwargs) -> DataFrame:
438        cols = self._ensure_and_normalize_cols(exprs)
439        return self.groupBy().agg(*cols)
440
441    @operation(Operation.FROM)
442    def join(
443        self,
444        other_df: DataFrame,
445        on: t.Union[str, t.List[str], Column, t.List[Column]],
446        how: str = "inner",
447        **kwargs,
448    ) -> DataFrame:
449        other_df = other_df._convert_leaf_to_cte()
450        join_columns = self._ensure_list_of_columns(on)
451        # We will determine actual "join on" expression later so we don't provide it at first
452        join_expression = self.expression.join(
453            other_df.latest_cte_name, join_type=how.replace("_", " ")
454        )
455        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
456        self_columns = self._get_outer_select_columns(join_expression)
457        other_columns = self._get_outer_select_columns(other_df)
458        # Determines the join clause and select columns to be used passed on what type of columns were provided for
459        # the join. The columns returned changes based on how the on expression is provided.
460        if isinstance(join_columns[0].expression, exp.Column):
461            """
462            Unique characteristics of join on column names only:
463            * The column names are put at the front of the select list
464            * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
465            """
466            table_names = [
467                table.alias_or_name
468                for table in get_tables_from_expression_with_join(join_expression)
469            ]
470            potential_ctes = [
471                cte
472                for cte in join_expression.ctes
473                if cte.alias_or_name in table_names
474                and cte.alias_or_name != other_df.latest_cte_name
475            ]
476            # Determine the table to reference for the left side of the join by checking each of the left side
477            # tables and see if they have the column being referenced.
478            join_column_pairs = []
479            for join_column in join_columns:
480                num_matching_ctes = 0
481                for cte in potential_ctes:
482                    if join_column.alias_or_name in cte.this.named_selects:
483                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
484                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
485                        join_column_pairs.append((left_column, right_column))
486                        num_matching_ctes += 1
487                if num_matching_ctes > 1:
488                    raise ValueError(
489                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
490                    )
491                elif num_matching_ctes == 0:
492                    raise ValueError(
493                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
494                    )
495            join_clause = functools.reduce(
496                lambda x, y: x & y,
497                [left_column == right_column for left_column, right_column in join_column_pairs],
498            )
499            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
500            # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list
501            select_column_names = [
502                column.alias_or_name
503                if not isinstance(column.expression.this, exp.Star)
504                else column.sql()
505                for column in self_columns + other_columns
506            ]
507            select_column_names = [
508                column_name
509                for column_name in select_column_names
510                if column_name not in join_column_names
511            ]
512            select_column_names = join_column_names + select_column_names
513        else:
514            """
515            Unique characteristics of join on expressions:
516            * There is no deduplication of the results.
517            * The left join dataframe columns go first and right come after. No sort preference is given to join columns
518            """
519            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
520            if len(join_columns) > 1:
521                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
522            join_clause = join_columns[0]
523            select_column_names = [column.alias_or_name for column in self_columns + other_columns]
524
525        # Update the on expression with the actual join clause to replace the dummy one from before
526        join_expression.args["joins"][-1].set("on", join_clause.expression)
527        new_df = self.copy(expression=join_expression)
528        new_df.pending_join_hints.extend(self.pending_join_hints)
529        new_df.pending_hints.extend(other_df.pending_hints)
530        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
531        return new_df
532
533    @operation(Operation.ORDER_BY)
534    def orderBy(
535        self,
536        *cols: t.Union[str, Column],
537        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
538    ) -> DataFrame:
539        """
540        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
541        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
542        is unlikely to come up.
543        """
544        columns = self._ensure_and_normalize_cols(cols)
545        pre_ordered_col_indexes = [
546            x
547            for x in [
548                i if isinstance(col.expression, exp.Ordered) else None
549                for i, col in enumerate(columns)
550            ]
551            if x is not None
552        ]
553        if ascending is None:
554            ascending = [True] * len(columns)
555        elif not isinstance(ascending, list):
556            ascending = [ascending] * len(columns)
557        ascending = [bool(x) for i, x in enumerate(ascending)]
558        assert len(columns) == len(
559            ascending
560        ), "The length of items in ascending must equal the number of columns provided"
561        col_and_ascending = list(zip(columns, ascending))
562        order_by_columns = [
563            exp.Ordered(this=col.expression, desc=not asc)
564            if i not in pre_ordered_col_indexes
565            else columns[i].column_expression
566            for i, (col, asc) in enumerate(col_and_ascending)
567        ]
568        return self.copy(expression=self.expression.order_by(*order_by_columns))
569
570    sort = orderBy
571
572    @operation(Operation.FROM)
573    def union(self, other: DataFrame) -> DataFrame:
574        return self._set_operation(exp.Union, other, False)
575
576    unionAll = union
577
578    @operation(Operation.FROM)
579    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
580        l_columns = self.columns
581        r_columns = other.columns
582        if not allowMissingColumns:
583            l_expressions = l_columns
584            r_expressions = l_columns
585        else:
586            l_expressions = []
587            r_expressions = []
588            r_columns_unused = copy(r_columns)
589            for l_column in l_columns:
590                l_expressions.append(l_column)
591                if l_column in r_columns:
592                    r_expressions.append(l_column)
593                    r_columns_unused.remove(l_column)
594                else:
595                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
596            for r_column in r_columns_unused:
597                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
598                r_expressions.append(r_column)
599        r_df = (
600            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
601        )
602        l_df = self.copy()
603        if allowMissingColumns:
604            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
605        return l_df._set_operation(exp.Union, r_df, False)
606
607    @operation(Operation.FROM)
608    def intersect(self, other: DataFrame) -> DataFrame:
609        return self._set_operation(exp.Intersect, other, True)
610
611    @operation(Operation.FROM)
612    def intersectAll(self, other: DataFrame) -> DataFrame:
613        return self._set_operation(exp.Intersect, other, False)
614
615    @operation(Operation.FROM)
616    def exceptAll(self, other: DataFrame) -> DataFrame:
617        return self._set_operation(exp.Except, other, False)
618
619    @operation(Operation.SELECT)
620    def distinct(self) -> DataFrame:
621        return self.copy(expression=self.expression.distinct())
622
623    @operation(Operation.SELECT)
624    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
625        if not subset:
626            return self.distinct()
627        column_names = ensure_list(subset)
628        window = Window.partitionBy(*column_names).orderBy(*column_names)
629        return (
630            self.copy()
631            .withColumn("row_num", F.row_number().over(window))
632            .where(F.col("row_num") == F.lit(1))
633            .drop("row_num")
634        )
635
636    @operation(Operation.FROM)
637    def dropna(
638        self,
639        how: str = "any",
640        thresh: t.Optional[int] = None,
641        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
642    ) -> DataFrame:
643        minimum_non_null = thresh or 0  # will be determined later if thresh is null
644        new_df = self.copy()
645        all_columns = self._get_outer_select_columns(new_df.expression)
646        if subset:
647            null_check_columns = self._ensure_and_normalize_cols(subset)
648        else:
649            null_check_columns = all_columns
650        if thresh is None:
651            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
652        else:
653            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
654        if minimum_num_nulls > len(null_check_columns):
655            raise RuntimeError(
656                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
657                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
658            )
659        if_null_checks = [
660            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
661        ]
662        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
663        num_nulls = nulls_added_together.alias("num_nulls")
664        new_df = new_df.select(num_nulls, append=True)
665        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
666        final_df = filtered_df.select(*all_columns)
667        return final_df
668
669    @operation(Operation.FROM)
670    def fillna(
671        self,
672        value: t.Union[ColumnLiterals],
673        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
674    ) -> DataFrame:
675        """
676        Functionality Difference: If you provide a value to replace a null and that type conflicts
677        with the type of the column then PySpark will just ignore your replacement.
678        This will try to cast them to be the same in some cases. So they won't always match.
679        Best to not mix types so make sure replacement is the same type as the column
680
681        Possibility for improvement: Use `typeof` function to get the type of the column
682        and check if it matches the type of the value provided. If not then make it null.
683        """
684        from sqlglot.dataframe.sql.functions import lit
685
686        values = None
687        columns = None
688        new_df = self.copy()
689        all_columns = self._get_outer_select_columns(new_df.expression)
690        all_column_mapping = {column.alias_or_name: column for column in all_columns}
691        if isinstance(value, dict):
692            values = list(value.values())
693            columns = self._ensure_and_normalize_cols(list(value))
694        if not columns:
695            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
696        if not values:
697            values = [value] * len(columns)
698        value_columns = [lit(value) for value in values]
699
700        null_replacement_mapping = {
701            column.alias_or_name: (
702                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
703            )
704            for column, value in zip(columns, value_columns)
705        }
706        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
707        null_replacement_columns = [
708            null_replacement_mapping[column.alias_or_name] for column in all_columns
709        ]
710        new_df = new_df.select(*null_replacement_columns)
711        return new_df
712
713    @operation(Operation.FROM)
714    def replace(
715        self,
716        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
717        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
718        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
719    ) -> DataFrame:
720        from sqlglot.dataframe.sql.functions import lit
721
722        old_values = None
723        new_df = self.copy()
724        all_columns = self._get_outer_select_columns(new_df.expression)
725        all_column_mapping = {column.alias_or_name: column for column in all_columns}
726
727        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
728        if isinstance(to_replace, dict):
729            old_values = list(to_replace)
730            new_values = list(to_replace.values())
731        elif not old_values and isinstance(to_replace, list):
732            assert isinstance(value, list), "value must be a list since the replacements are a list"
733            assert len(to_replace) == len(
734                value
735            ), "the replacements and values must be the same length"
736            old_values = to_replace
737            new_values = value
738        else:
739            old_values = [to_replace] * len(columns)
740            new_values = [value] * len(columns)
741        old_values = [lit(value) for value in old_values]
742        new_values = [lit(value) for value in new_values]
743
744        replacement_mapping = {}
745        for column in columns:
746            expression = Column(None)
747            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
748                if i == 0:
749                    expression = F.when(column == old_value, new_value)
750                else:
751                    expression = expression.when(column == old_value, new_value)  # type: ignore
752            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
753                column.expression.alias_or_name
754            )
755
756        replacement_mapping = {**all_column_mapping, **replacement_mapping}
757        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
758        new_df = new_df.select(*replacement_columns)
759        return new_df
760
761    @operation(Operation.SELECT)
762    def withColumn(self, colName: str, col: Column) -> DataFrame:
763        col = self._ensure_and_normalize_col(col)
764        existing_col_names = self.expression.named_selects
765        existing_col_index = (
766            existing_col_names.index(colName) if colName in existing_col_names else None
767        )
768        if existing_col_index:
769            expression = self.expression.copy()
770            expression.expressions[existing_col_index] = col.expression
771            return self.copy(expression=expression)
772        return self.copy().select(col.alias(colName), append=True)
773
774    @operation(Operation.SELECT)
775    def withColumnRenamed(self, existing: str, new: str):
776        expression = self.expression.copy()
777        existing_columns = [
778            expression
779            for expression in expression.expressions
780            if expression.alias_or_name == existing
781        ]
782        if not existing_columns:
783            raise ValueError("Tried to rename a column that doesn't exist")
784        for existing_column in existing_columns:
785            if isinstance(existing_column, exp.Column):
786                existing_column.replace(exp.alias_(existing_column, new))
787            else:
788                existing_column.set("alias", exp.to_identifier(new))
789        return self.copy(expression=expression)
790
791    @operation(Operation.SELECT)
792    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
793        all_columns = self._get_outer_select_columns(self.expression)
794        drop_cols = self._ensure_and_normalize_cols(cols)
795        new_columns = [
796            col
797            for col in all_columns
798            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
799        ]
800        return self.copy().select(*new_columns, append=False)
801
802    @operation(Operation.LIMIT)
803    def limit(self, num: int) -> DataFrame:
804        return self.copy(expression=self.expression.limit(num))
805
806    @operation(Operation.NO_OP)
807    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
808        parameter_list = ensure_list(parameters)
809        parameter_columns = (
810            self._ensure_list_of_columns(parameter_list)
811            if parameters
812            else Column.ensure_cols([self.sequence_id])
813        )
814        return self._hint(name, parameter_columns)
815
816    @operation(Operation.NO_OP)
817    def repartition(
818        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
819    ) -> DataFrame:
820        num_partition_cols = self._ensure_list_of_columns(numPartitions)
821        columns = self._ensure_and_normalize_cols(cols)
822        args = num_partition_cols + columns
823        return self._hint("repartition", args)
824
825    @operation(Operation.NO_OP)
826    def coalesce(self, numPartitions: int) -> DataFrame:
827        num_partitions = Column.ensure_cols([numPartitions])
828        return self._hint("coalesce", num_partitions)
829
830    @operation(Operation.NO_OP)
831    def cache(self) -> DataFrame:
832        return self._cache(storage_level="MEMORY_AND_DISK")
833
834    @operation(Operation.NO_OP)
835    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
836        """
837        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
838        """
839        return self._cache(storageLevel)
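Before the per-member reference below, a short sketch of chaining a few of these operations (the table and columns are invented; the table must be registered with sqlglot.schema as in the SparkSession examples):

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession
    from sqlglot.dataframe.sql import functions as F

    sqlglot.schema.add_table(
        "employee", {"employee_id": "INT", "fname": "STRING", "age": "INT"}, dialect="spark"
    )

    spark = SparkSession()
    df = (
        spark.table("employee")
        .where(F.col("age") >= 18)
        .withColumn("age_next_year", F.col("age") + 1)
        .orderBy("age", ascending=False)
        .limit(10)
    )
    print(df.sql(pretty=True)[0])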
DataFrame( spark: SparkSession, expression: sqlglot.expressions.Select, branch_id: Optional[str] = None, sequence_id: Optional[str] = None, last_op: sqlglot.dataframe.sql.operations.Operation = <Operation.INIT: -1>, pending_hints: Optional[List[sqlglot.expressions.Expression]] = None, output_expression_container: Optional[OutputExpressionContainer] = None, **kwargs)
spark
expression
branch_id
sequence_id
last_op
pending_hints
output_expression_container
sparkSession
write
latest_cte_name: str
pending_join_hints
pending_partition_hints
columns: List[str]
def sql( self, dialect: Optional[DialectType] = None, optimize: bool = True, **kwargs) -> List[str]:
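The output dialect comes from the active SparkSession (set through SparkSession.builder.config('sqlframe.dialect', ...)); the dialect argument here is only compared against it and emits a warning when they differ. Extra keyword arguments such as pretty are forwarded to the generator. A sketch with an invented table:

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession

    sqlglot.schema.add_table("employee", {"employee_id": "INT", "age": "INT"}, dialect="spark")

    spark = SparkSession.builder.config("sqlframe.dialect", "bigquery").getOrCreate()
    df = spark.table("employee")

    for statement in df.sql(pretty=True):   # one SQL string per output statement
        print(statement)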
301    def sql(
302        self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
303    ) -> t.List[str]:
304        from sqlglot.dataframe.sql.session import SparkSession
305
306        if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
307            logger.warning(
308                f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern."
309            )
310        df = self._resolve_pending_hints()
311        select_expressions = df._get_select_expressions()
312        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
313        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
314        for expression_type, select_expression in select_expressions:
315            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
316            if optimize:
317                quote_identifiers(select_expression)
318                select_expression = t.cast(
319                    exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
320                )
321            select_expression = df._replace_cte_names_with_hashes(select_expression)
322            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
323            if expression_type == exp.Cache:
324                cache_table_name = df._create_hash_from_expression(select_expression)
325                cache_table = exp.to_table(cache_table_name)
326                original_alias_name = select_expression.args["cte_alias_name"]
327
328                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
329                    cache_table_name
330                )
331                sqlglot.schema.add_table(
332                    cache_table_name,
333                    {
334                        expression.alias_or_name: expression.type.sql(
335                            dialect=SparkSession().dialect
336                        )
337                        for expression in select_expression.expressions
338                    },
339                    dialect=SparkSession().dialect,
340                )
341                cache_storage_level = select_expression.args["cache_storage_level"]
342                options = [
343                    exp.Literal.string("storageLevel"),
344                    exp.Literal.string(cache_storage_level),
345                ]
346                expression = exp.Cache(
347                    this=cache_table, expression=select_expression, lazy=True, options=options
348                )
349                # We will drop the "view" if it exists before running the cache table
350                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
351            elif expression_type == exp.Create:
352                expression = df.output_expression_container.copy()
353                expression.set("expression", select_expression)
354            elif expression_type == exp.Insert:
355                expression = df.output_expression_container.copy()
356                select_without_ctes = select_expression.copy()
357                select_without_ctes.set("with", None)
358                expression.set("expression", select_without_ctes)
359                if select_expression.ctes:
360                    expression.set("with", exp.With(expressions=select_expression.ctes))
361            elif expression_type == exp.Select:
362                expression = select_expression
363            else:
364                raise ValueError(f"Invalid expression type: {expression_type}")
365            output_expressions.append(expression)
366
367        return [
368            expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
369            for expression in output_expressions
370        ]
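
Illustrative usage (a minimal sketch; the data and column names are hypothetical, and the session and createDataFrame calls are as documented earlier on this page). sql returns one SQL string per generated statement and defaults to the session's dialect:

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()  # singleton; DEFAULT_DIALECT is "spark"
df = (
    spark.createDataFrame([{"id": 1, "name": "Jack"}])
    .where(F.col("id") > F.lit(0))
    .select("id", "name")
)

# sql() returns a list of SQL strings, one per output expression
for statement in df.sql(pretty=True):
    print(statement)
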
def copy(self, **kwargs) -> DataFrame:
372    def copy(self, **kwargs) -> DataFrame:
373        return DataFrame(**object_to_dict(self, **kwargs))
@operation(Operation.SELECT)
def select(self, *cols, **kwargs) -> DataFrame:
375    @operation(Operation.SELECT)
376    def select(self, *cols, **kwargs) -> DataFrame:
377        cols = self._ensure_and_normalize_cols(cols)
378        kwargs["append"] = kwargs.get("append", False)
379        if self.expression.args.get("joins"):
380            ambiguous_cols = [
381                col
382                for col in cols
383                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
384            ]
385            if ambiguous_cols:
386                join_table_identifiers = [
387                    x.this for x in get_tables_from_expression_with_join(self.expression)
388                ]
389                cte_names_in_join = [x.this for x in join_table_identifiers]
390                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
391                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
392                # of Spark.
393                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
394                for ambiguous_col in ambiguous_cols:
395                    ctes_with_column = [
396                        cte
397                        for cte in self.expression.ctes
398                        if cte.alias_or_name in cte_names_in_join
399                        and ambiguous_col.alias_or_name in cte.this.named_selects
400                    ]
401                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
402                    # use the same CTE we used before
403                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
404                    if cte:
405                        resolved_column_position[ambiguous_col] += 1
406                    else:
407                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
408                    ambiguous_col.expression.set("table", cte.alias_or_name)
409        return self.copy(
410            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
411        )
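
For illustration (hypothetical tables and column names), unqualified columns selected after a join are resolved against the joined CTEs left to right, mirroring Spark:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
employees = spark.createDataFrame([{"store_id": 1, "name": "Jack"}])
stores = spark.createDataFrame([{"store_id": 1, "city": "NYC"}])

joined = employees.join(stores, on="store_id", how="inner")
# "name" and "city" each exist on only one side, so no qualification is needed
result = joined.select("store_id", "name", "city")
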
@operation(Operation.NO_OP)
def alias(self, name: str, **kwargs) -> DataFrame:
413    @operation(Operation.NO_OP)
414    def alias(self, name: str, **kwargs) -> DataFrame:
415        new_sequence_id = self.spark._random_sequence_id
416        df = self.copy()
417        for join_hint in df.pending_join_hints:
418            for expression in join_hint.expressions:
419                if expression.alias_or_name == self.sequence_id:
420                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
421        df.spark._add_alias_to_mapping(name, new_sequence_id)
422        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
@operation(Operation.WHERE)
def where( self, column: Union[Column, bool], **kwargs) -> DataFrame:
424    @operation(Operation.WHERE)
425    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
426        col = self._ensure_and_normalize_col(column)
427        return self.copy(expression=self.expression.where(col.expression))
@operation(Operation.WHERE)
def filter( self, column: Union[Column, bool], **kwargs) -> DataFrame:
424    @operation(Operation.WHERE)
425    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
426        col = self._ensure_and_normalize_col(column)
427        return self.copy(expression=self.expression.where(col.expression))
@operation(Operation.GROUP_BY)
def groupBy(self, *cols, **kwargs) -> GroupedData:
431    @operation(Operation.GROUP_BY)
432    def groupBy(self, *cols, **kwargs) -> GroupedData:
433        columns = self._ensure_and_normalize_cols(cols)
434        return GroupedData(self, columns, self.last_op)
@operation(Operation.SELECT)
def agg(self, *exprs, **kwargs) -> DataFrame:
436    @operation(Operation.SELECT)
437    def agg(self, *exprs, **kwargs) -> DataFrame:
438        cols = self._ensure_and_normalize_cols(exprs)
439        return self.groupBy().agg(*cols)
@operation(Operation.FROM)
def join( self, other_df: DataFrame, on: Union[str, List[str], Column, List[Column]], how: str = 'inner', **kwargs) -> DataFrame:
441    @operation(Operation.FROM)
442    def join(
443        self,
444        other_df: DataFrame,
445        on: t.Union[str, t.List[str], Column, t.List[Column]],
446        how: str = "inner",
447        **kwargs,
448    ) -> DataFrame:
449        other_df = other_df._convert_leaf_to_cte()
450        join_columns = self._ensure_list_of_columns(on)
451        # We will determine actual "join on" expression later so we don't provide it at first
452        join_expression = self.expression.join(
453            other_df.latest_cte_name, join_type=how.replace("_", " ")
454        )
455        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
456        self_columns = self._get_outer_select_columns(join_expression)
457        other_columns = self._get_outer_select_columns(other_df)
458        # Determines the join clause and select columns to be used based on what type of columns were provided for
459        # the join. The columns returned changes based on how the on expression is provided.
460        if isinstance(join_columns[0].expression, exp.Column):
461            """
462            Unique characteristics of join on column names only:
463            * The column names are put at the front of the select list
464            * Only the join column names are deduplicated across the entire select list (duplicates of other columns are allowed)
465            """
466            table_names = [
467                table.alias_or_name
468                for table in get_tables_from_expression_with_join(join_expression)
469            ]
470            potential_ctes = [
471                cte
472                for cte in join_expression.ctes
473                if cte.alias_or_name in table_names
474                and cte.alias_or_name != other_df.latest_cte_name
475            ]
476            # Determine the table to reference for the left side of the join by checking each of the left side
477            # tables and see if they have the column being referenced.
478            join_column_pairs = []
479            for join_column in join_columns:
480                num_matching_ctes = 0
481                for cte in potential_ctes:
482                    if join_column.alias_or_name in cte.this.named_selects:
483                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
484                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
485                        join_column_pairs.append((left_column, right_column))
486                        num_matching_ctes += 1
487                if num_matching_ctes > 1:
488                    raise ValueError(
489                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
490                    )
491                elif num_matching_ctes == 0:
492                    raise ValueError(
493                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
494                    )
495            join_clause = functools.reduce(
496                lambda x, y: x & y,
497                [left_column == right_column for left_column, right_column in join_column_pairs],
498            )
499            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
500            # To match Spark behavior, only the join columns are deduplicated and they are put at the front of the column list
501            select_column_names = [
502                column.alias_or_name
503                if not isinstance(column.expression.this, exp.Star)
504                else column.sql()
505                for column in self_columns + other_columns
506            ]
507            select_column_names = [
508                column_name
509                for column_name in select_column_names
510                if column_name not in join_column_names
511            ]
512            select_column_names = join_column_names + select_column_names
513        else:
514            """
515            Unique characteristics of join on expressions:
516            * There is no deduplication of the results.
517            * The left join dataframe columns go first and right come after. No sort preference is given to join columns
518            """
519            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
520            if len(join_columns) > 1:
521                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
522            join_clause = join_columns[0]
523            select_column_names = [column.alias_or_name for column in self_columns + other_columns]
524
525        # Update the on expression with the actual join clause to replace the dummy one from before
526        join_expression.args["joins"][-1].set("on", join_clause.expression)
527        new_df = self.copy(expression=join_expression)
528        new_df.pending_join_hints.extend(self.pending_join_hints)
529        new_df.pending_hints.extend(other_df.pending_hints)
530        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
531        return new_df
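
A short sketch of joining on column names (hypothetical data and names); the join keys are deduplicated and moved to the front of the select list:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
orders = spark.createDataFrame([{"store_id": 1, "item_id": 2, "qty": 3}])
items = spark.createDataFrame([{"store_id": 1, "item_id": 2, "price": 9.99}])

# Join on column names: "store_id" and "item_id" each appear once, at the front
joined = orders.join(items, on=["store_id", "item_id"], how="left")
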
@operation(Operation.ORDER_BY)
def orderBy( self, *cols: Union[str, Column], ascending: Union[Any, List[Any], NoneType] = None) -> DataFrame:
533    @operation(Operation.ORDER_BY)
534    def orderBy(
535        self,
536        *cols: t.Union[str, Column],
537        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
538    ) -> DataFrame:
539        """
540        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
541        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
542        is unlikely to come up.
543        """
544        columns = self._ensure_and_normalize_cols(cols)
545        pre_ordered_col_indexes = [
546            x
547            for x in [
548                i if isinstance(col.expression, exp.Ordered) else None
549                for i, col in enumerate(columns)
550            ]
551            if x is not None
552        ]
553        if ascending is None:
554            ascending = [True] * len(columns)
555        elif not isinstance(ascending, list):
556            ascending = [ascending] * len(columns)
557        ascending = [bool(x) for i, x in enumerate(ascending)]
558        assert len(columns) == len(
559            ascending
560        ), "The length of items in ascending must equal the number of columns provided"
561        col_and_ascending = list(zip(columns, ascending))
562        order_by_columns = [
563            exp.Ordered(this=col.expression, desc=not asc)
564            if i not in pre_ordered_col_indexes
565            else columns[i].column_expression
566            for i, (col, asc) in enumerate(col_and_ascending)
567        ]
568        return self.copy(expression=self.expression.order_by(*order_by_columns))

This implementation lets any ordered columns take priority over whatever is provided in ascending. Spark's behavior here is irregular and can result in runtime errors. Users shouldn't mix the two anyway, so this is unlikely to come up.
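
A minimal sketch of the two ways of controlling sort direction (hypothetical data and names):

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([{"name": "Jack", "age": 37}])

# A single ascending flag is broadcast to every column
df.orderBy("age", "name", ascending=False)

# A column that is already ordered keeps its own direction
df.orderBy(F.col("age").desc(), F.col("name"), ascending=[True, True])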

@operation(Operation.ORDER_BY)
def sort( self, *cols: Union[str, Column], ascending: Union[Any, List[Any], NoneType] = None) -> DataFrame:
533    @operation(Operation.ORDER_BY)
534    def orderBy(
535        self,
536        *cols: t.Union[str, Column],
537        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
538    ) -> DataFrame:
539        """
540        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
541        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
542        is unlikely to come up.
543        """
544        columns = self._ensure_and_normalize_cols(cols)
545        pre_ordered_col_indexes = [
546            x
547            for x in [
548                i if isinstance(col.expression, exp.Ordered) else None
549                for i, col in enumerate(columns)
550            ]
551            if x is not None
552        ]
553        if ascending is None:
554            ascending = [True] * len(columns)
555        elif not isinstance(ascending, list):
556            ascending = [ascending] * len(columns)
557        ascending = [bool(x) for i, x in enumerate(ascending)]
558        assert len(columns) == len(
559            ascending
560        ), "The length of items in ascending must equal the number of columns provided"
561        col_and_ascending = list(zip(columns, ascending))
562        order_by_columns = [
563            exp.Ordered(this=col.expression, desc=not asc)
564            if i not in pre_ordered_col_indexes
565            else columns[i].column_expression
566            for i, (col, asc) in enumerate(col_and_ascending)
567        ]
568        return self.copy(expression=self.expression.order_by(*order_by_columns))

This implementation lets any ordered columns take priority over whatever is provided in ascending. Spark's behavior here is irregular and can result in runtime errors. Users shouldn't mix the two anyway, so this is unlikely to come up.

@operation(Operation.FROM)
def union( self, other: DataFrame) -> DataFrame:
572    @operation(Operation.FROM)
573    def union(self, other: DataFrame) -> DataFrame:
574        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionAll( self, other: DataFrame) -> DataFrame:
572    @operation(Operation.FROM)
573    def union(self, other: DataFrame) -> DataFrame:
574        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionByName( self, other: DataFrame, allowMissingColumns: bool = False):
578    @operation(Operation.FROM)
579    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
580        l_columns = self.columns
581        r_columns = other.columns
582        if not allowMissingColumns:
583            l_expressions = l_columns
584            r_expressions = l_columns
585        else:
586            l_expressions = []
587            r_expressions = []
588            r_columns_unused = copy(r_columns)
589            for l_column in l_columns:
590                l_expressions.append(l_column)
591                if l_column in r_columns:
592                    r_expressions.append(l_column)
593                    r_columns_unused.remove(l_column)
594                else:
595                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
596            for r_column in r_columns_unused:
597                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
598                r_expressions.append(r_column)
599        r_df = (
600            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
601        )
602        l_df = self.copy()
603        if allowMissingColumns:
604            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
605        return l_df._set_operation(exp.Union, r_df, False)
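
For example (hypothetical data and names), columns are matched by name rather than by position, and with allowMissingColumns=True any column missing from one side is filled with NULL:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df_left = spark.createDataFrame([{"id": 1, "name": "Jack"}])
df_right = spark.createDataFrame([{"name": "Kate", "id": 2, "age": 27}])

# "age" only exists on the right, so it becomes NULL for rows from the left side
combined = df_left.unionByName(df_right, allowMissingColumns=True)
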
@operation(Operation.FROM)
def intersect( self, other: DataFrame) -> DataFrame:
607    @operation(Operation.FROM)
608    def intersect(self, other: DataFrame) -> DataFrame:
609        return self._set_operation(exp.Intersect, other, True)
@operation(Operation.FROM)
def intersectAll( self, other: DataFrame) -> DataFrame:
611    @operation(Operation.FROM)
612    def intersectAll(self, other: DataFrame) -> DataFrame:
613        return self._set_operation(exp.Intersect, other, False)
@operation(Operation.FROM)
def exceptAll( self, other: DataFrame) -> DataFrame:
615    @operation(Operation.FROM)
616    def exceptAll(self, other: DataFrame) -> DataFrame:
617        return self._set_operation(exp.Except, other, False)
@operation(Operation.SELECT)
def distinct(self) -> DataFrame:
619    @operation(Operation.SELECT)
620    def distinct(self) -> DataFrame:
621        return self.copy(expression=self.expression.distinct())
@operation(Operation.SELECT)
def dropDuplicates(self, subset: Optional[List[str]] = None):
623    @operation(Operation.SELECT)
624    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
625        if not subset:
626            return self.distinct()
627        column_names = ensure_list(subset)
628        window = Window.partitionBy(*column_names).orderBy(*column_names)
629        return (
630            self.copy()
631            .withColumn("row_num", F.row_number().over(window))
632            .where(F.col("row_num") == F.lit(1))
633            .drop("row_num")
634        )
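
For example (hypothetical data and names), passing a subset keeps one row per distinct combination of the listed columns, implemented with the ROW_NUMBER window shown above:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([{"name": "Jack", "age": 37}, {"name": "Jack", "age": 40}])

deduped = df.dropDuplicates(["name"])  # one row per "name"
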
@operation(Operation.FROM)
def dropna( self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> DataFrame:
636    @operation(Operation.FROM)
637    def dropna(
638        self,
639        how: str = "any",
640        thresh: t.Optional[int] = None,
641        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
642    ) -> DataFrame:
643        minimum_non_null = thresh or 0  # will be determined later if thresh is null
644        new_df = self.copy()
645        all_columns = self._get_outer_select_columns(new_df.expression)
646        if subset:
647            null_check_columns = self._ensure_and_normalize_cols(subset)
648        else:
649            null_check_columns = all_columns
650        if thresh is None:
651            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
652        else:
653            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
654        if minimum_num_nulls > len(null_check_columns):
655            raise RuntimeError(
656                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
657                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
658            )
659        if_null_checks = [
660            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
661        ]
662        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
663        num_nulls = nulls_added_together.alias("num_nulls")
664        new_df = new_df.select(num_nulls, append=True)
665        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
666        final_df = filtered_df.select(*all_columns)
667        return final_df
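
Illustrative usage (hypothetical data and names) of the how, thresh, and subset parameters:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([{"a": 1, "b": None}, {"a": None, "b": None}])

df.dropna(how="any")                     # drop rows containing at least one null
df.dropna(how="all")                     # drop rows where every column is null
df.dropna(thresh=1, subset=["a", "b"])   # keep rows with at least one non-null value
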
@operation(Operation.FROM)
def fillna( self, value: ColumnLiterals, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> DataFrame:
669    @operation(Operation.FROM)
670    def fillna(
671        self,
672        value: t.Union[ColumnLiterals],
673        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
674    ) -> DataFrame:
675        """
676        Functionality Difference: If you provide a value to replace a null and that type conflicts
677        with the type of the column then PySpark will just ignore your replacement.
678        This will try to cast them to be the same in some cases. So they won't always match.
679        Best to not mix types so make sure replacement is the same type as the column
680
681        Possibility for improvement: Use `typeof` function to get the type of the column
682        and check if it matches the type of the value provided. If not then make it null.
683        """
684        from sqlglot.dataframe.sql.functions import lit
685
686        values = None
687        columns = None
688        new_df = self.copy()
689        all_columns = self._get_outer_select_columns(new_df.expression)
690        all_column_mapping = {column.alias_or_name: column for column in all_columns}
691        if isinstance(value, dict):
692            values = list(value.values())
693            columns = self._ensure_and_normalize_cols(list(value))
694        if not columns:
695            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
696        if not values:
697            values = [value] * len(columns)
698        value_columns = [lit(value) for value in values]
699
700        null_replacement_mapping = {
701            column.alias_or_name: (
702                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
703            )
704            for column, value in zip(columns, value_columns)
705        }
706        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
707        null_replacement_columns = [
708            null_replacement_mapping[column.alias_or_name] for column in all_columns
709        ]
710        new_df = new_df.select(*null_replacement_columns)
711        return new_df

Functionality difference: if you provide a replacement value whose type conflicts with the column's type, PySpark will simply ignore your replacement, while this implementation may cast the two to a common type in some cases, so the results won't always match. It is best not to mix types: make sure the replacement value has the same type as the column.

Possibility for improvement: use the typeof function to get the column's type, check whether it matches the type of the provided value, and fall back to null if it does not.
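
Illustrative usage (hypothetical data and names); per the note above, keep the replacement value the same type as the column:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([{"name": "Jack", "age": None}])

df.fillna(0, subset=["age"])              # replace nulls in "age" with 0
df.fillna({"name": "unknown", "age": 0})  # per-column replacements via a dict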

@operation(Operation.FROM)
def replace( self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[Collection[ColumnOrName], ColumnOrName, NoneType] = None) -> DataFrame:
713    @operation(Operation.FROM)
714    def replace(
715        self,
716        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
717        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
718        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
719    ) -> DataFrame:
720        from sqlglot.dataframe.sql.functions import lit
721
722        old_values = None
723        new_df = self.copy()
724        all_columns = self._get_outer_select_columns(new_df.expression)
725        all_column_mapping = {column.alias_or_name: column for column in all_columns}
726
727        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
728        if isinstance(to_replace, dict):
729            old_values = list(to_replace)
730            new_values = list(to_replace.values())
731        elif not old_values and isinstance(to_replace, list):
732            assert isinstance(value, list), "value must be a list since the replacements are a list"
733            assert len(to_replace) == len(
734                value
735            ), "the replacements and values must be the same length"
736            old_values = to_replace
737            new_values = value
738        else:
739            old_values = [to_replace] * len(columns)
740            new_values = [value] * len(columns)
741        old_values = [lit(value) for value in old_values]
742        new_values = [lit(value) for value in new_values]
743
744        replacement_mapping = {}
745        for column in columns:
746            expression = Column(None)
747            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
748                if i == 0:
749                    expression = F.when(column == old_value, new_value)
750                else:
751                    expression = expression.when(column == old_value, new_value)  # type: ignore
752            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
753                column.expression.alias_or_name
754            )
755
756        replacement_mapping = {**all_column_mapping, **replacement_mapping}
757        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
758        new_df = new_df.select(*replacement_columns)
759        return new_df
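
Illustrative usage (hypothetical data and names) of the three accepted forms of to_replace:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([{"name": "Jack", "age": 37}])

df.replace(37, 38, subset=["age"])               # single old value -> new value
df.replace([37, 40], [38, 41], subset=["age"])   # paired lists of old and new values
df.replace({"Jack": "Joe"}, subset=["name"])     # dict mapping old values to new values
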
@operation(Operation.SELECT)
def withColumn( self, colName: str, col: Column) -> DataFrame:
761    @operation(Operation.SELECT)
762    def withColumn(self, colName: str, col: Column) -> DataFrame:
763        col = self._ensure_and_normalize_col(col)
764        existing_col_names = self.expression.named_selects
765        existing_col_index = (
766            existing_col_names.index(colName) if colName in existing_col_names else None
767        )
768        if existing_col_index is not None:
769            expression = self.expression.copy()
770            expression.expressions[existing_col_index] = col.expression
771            return self.copy(expression=expression)
772        return self.copy().select(col.alias(colName), append=True)
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
774    @operation(Operation.SELECT)
775    def withColumnRenamed(self, existing: str, new: str):
776        expression = self.expression.copy()
777        existing_columns = [
778            expression
779            for expression in expression.expressions
780            if expression.alias_or_name == existing
781        ]
782        if not existing_columns:
783            raise ValueError("Tried to rename a column that doesn't exist")
784        for existing_column in existing_columns:
785            if isinstance(existing_column, exp.Column):
786                existing_column.replace(exp.alias_(existing_column, new))
787            else:
788                existing_column.set("alias", exp.to_identifier(new))
789        return self.copy(expression=expression)
@operation(Operation.SELECT)
def drop( self, *cols: Union[str, Column]) -> DataFrame:
791    @operation(Operation.SELECT)
792    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
793        all_columns = self._get_outer_select_columns(self.expression)
794        drop_cols = self._ensure_and_normalize_cols(cols)
795        new_columns = [
796            col
797            for col in all_columns
798            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
799        ]
800        return self.copy().select(*new_columns, append=False)
@operation(Operation.LIMIT)
def limit(self, num: int) -> DataFrame:
802    @operation(Operation.LIMIT)
803    def limit(self, num: int) -> DataFrame:
804        return self.copy(expression=self.expression.limit(num))
@operation(Operation.NO_OP)
def hint( self, name: str, *parameters: Union[str, int, NoneType]) -> DataFrame:
806    @operation(Operation.NO_OP)
807    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
808        parameter_list = ensure_list(parameters)
809        parameter_columns = (
810            self._ensure_list_of_columns(parameter_list)
811            if parameters
812            else Column.ensure_cols([self.sequence_id])
813        )
814        return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
def repartition( self, numPartitions: Union[int, ColumnOrName], *cols: ColumnOrName) -> DataFrame:
816    @operation(Operation.NO_OP)
817    def repartition(
818        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
819    ) -> DataFrame:
820        num_partition_cols = self._ensure_list_of_columns(numPartitions)
821        columns = self._ensure_and_normalize_cols(cols)
822        args = num_partition_cols + columns
823        return self._hint("repartition", args)
@operation(Operation.NO_OP)
def coalesce(self, numPartitions: int) -> DataFrame:
825    @operation(Operation.NO_OP)
826    def coalesce(self, numPartitions: int) -> DataFrame:
827        num_partitions = Column.ensure_cols([numPartitions])
828        return self._hint("coalesce", num_partitions)
@operation(Operation.NO_OP)
def cache(self) -> DataFrame:
830    @operation(Operation.NO_OP)
831    def cache(self) -> DataFrame:
832        return self._cache(storage_level="MEMORY_AND_DISK")
@operation(Operation.NO_OP)
def persist( self, storageLevel: str = 'MEMORY_AND_DISK_SER') -> DataFrame:
834    @operation(Operation.NO_OP)
835    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
836        """
837        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
838        """
839        return self._cache(storageLevel)
class GroupedData:
14class GroupedData:
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
20
21    def _get_function_applied_columns(
22        self, func_name: str, cols: t.Tuple[str, ...]
23    ) -> t.List[Column]:
24        func_name = func_name.lower()
25        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
26
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
40
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
43
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
46
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
49
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
52
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
55
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
58
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Pivot is not currently implemented")
GroupedData( df: DataFrame, group_by_cols: List[Column], last_op: sqlglot.dataframe.sql.operations.Operation)
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
spark
last_op
group_by_cols
@operation(Operation.SELECT)
def agg( self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame:
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
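
Illustrative usage (hypothetical data and names); agg accepts either Column expressions or a dict mapping column names to aggregate function names:

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([{"store_id": 1, "age": 37}])

df.groupBy("store_id").agg(F.count("*").alias("cnt"), F.max("age"))
df.groupBy("store_id").agg({"age": "avg"})  # dict form: {column_name: agg_func_name}
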
def count(self) -> DataFrame:
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
def mean(self, *cols: str) -> DataFrame:
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
def avg(self, *cols: str) -> DataFrame:
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
def max(self, *cols: str) -> DataFrame:
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
def min(self, *cols: str) -> DataFrame:
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
def sum(self, *cols: str) -> DataFrame:
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
def pivot(self, *cols: str) -> DataFrame:
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Pivot is not currently implemented")
class Column:
 16class Column:
 17    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
 18        from sqlglot.dataframe.sql.session import SparkSession
 19
 20        if isinstance(expression, Column):
 21            expression = expression.expression  # type: ignore
 22        elif expression is None or not isinstance(expression, (str, exp.Expression)):
 23            expression = self._lit(expression).expression  # type: ignore
 24        elif not isinstance(expression, exp.Column):
 25            expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
 26                SparkSession().dialect.normalize_identifier, copy=False
 27            )
 28        if expression is None:
 29            raise ValueError(f"Could not parse {expression}")
 30
 31        self.expression: exp.Expression = expression  # type: ignore
 32
 33    def __repr__(self):
 34        return repr(self.expression)
 35
 36    def __hash__(self):
 37        return hash(self.expression)
 38
 39    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 40        return self.binary_op(exp.EQ, other)
 41
 42    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 43        return self.binary_op(exp.NEQ, other)
 44
 45    def __gt__(self, other: ColumnOrLiteral) -> Column:
 46        return self.binary_op(exp.GT, other)
 47
 48    def __ge__(self, other: ColumnOrLiteral) -> Column:
 49        return self.binary_op(exp.GTE, other)
 50
 51    def __lt__(self, other: ColumnOrLiteral) -> Column:
 52        return self.binary_op(exp.LT, other)
 53
 54    def __le__(self, other: ColumnOrLiteral) -> Column:
 55        return self.binary_op(exp.LTE, other)
 56
 57    def __and__(self, other: ColumnOrLiteral) -> Column:
 58        return self.binary_op(exp.And, other)
 59
 60    def __or__(self, other: ColumnOrLiteral) -> Column:
 61        return self.binary_op(exp.Or, other)
 62
 63    def __mod__(self, other: ColumnOrLiteral) -> Column:
 64        return self.binary_op(exp.Mod, other)
 65
 66    def __add__(self, other: ColumnOrLiteral) -> Column:
 67        return self.binary_op(exp.Add, other)
 68
 69    def __sub__(self, other: ColumnOrLiteral) -> Column:
 70        return self.binary_op(exp.Sub, other)
 71
 72    def __mul__(self, other: ColumnOrLiteral) -> Column:
 73        return self.binary_op(exp.Mul, other)
 74
 75    def __truediv__(self, other: ColumnOrLiteral) -> Column:
 76        return self.binary_op(exp.Div, other)
 77
 78    def __div__(self, other: ColumnOrLiteral) -> Column:
 79        return self.binary_op(exp.Div, other)
 80
 81    def __neg__(self) -> Column:
 82        return self.unary_op(exp.Neg)
 83
 84    def __radd__(self, other: ColumnOrLiteral) -> Column:
 85        return self.inverse_binary_op(exp.Add, other)
 86
 87    def __rsub__(self, other: ColumnOrLiteral) -> Column:
 88        return self.inverse_binary_op(exp.Sub, other)
 89
 90    def __rmul__(self, other: ColumnOrLiteral) -> Column:
 91        return self.inverse_binary_op(exp.Mul, other)
 92
 93    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
 94        return self.inverse_binary_op(exp.Div, other)
 95
 96    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
 97        return self.inverse_binary_op(exp.Div, other)
 98
 99    def __rmod__(self, other: ColumnOrLiteral) -> Column:
100        return self.inverse_binary_op(exp.Mod, other)
101
102    def __pow__(self, power: ColumnOrLiteral, modulo=None):
103        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))
104
105    def __rpow__(self, power: ColumnOrLiteral):
106        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))
107
108    def __invert__(self):
109        return self.unary_op(exp.Not)
110
111    def __rand__(self, other: ColumnOrLiteral) -> Column:
112        return self.inverse_binary_op(exp.And, other)
113
114    def __ror__(self, other: ColumnOrLiteral) -> Column:
115        return self.inverse_binary_op(exp.Or, other)
116
117    @classmethod
118    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
119        return cls(value)
120
121    @classmethod
122    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
123        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
124
125    @classmethod
126    def _lit(cls, value: ColumnOrLiteral) -> Column:
127        if isinstance(value, dict):
128            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
129            return cls(exp.Struct(expressions=columns))
130        return cls(exp.convert(value))
131
132    @classmethod
133    def invoke_anonymous_function(
134        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
135    ) -> Column:
136        columns = [] if column is None else [cls.ensure_col(column)]
137        column_args = [cls.ensure_col(arg) for arg in args]
138        expressions = [x.expression for x in columns + column_args]
139        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
140        return Column(new_expression)
141
142    @classmethod
143    def invoke_expression_over_column(
144        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
145    ) -> Column:
146        ensured_column = None if column is None else cls.ensure_col(column)
147        ensure_expression_values = {
148            k: [Column.ensure_col(x).expression for x in v]
149            if is_iterable(v)
150            else Column.ensure_col(v).expression
151            for k, v in kwargs.items()
152            if v is not None
153        }
154        new_expression = (
155            callable_expression(**ensure_expression_values)
156            if ensured_column is None
157            else callable_expression(
158                this=ensured_column.column_expression, **ensure_expression_values
159            )
160        )
161        return Column(new_expression)
162
163    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
164        return Column(
165            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
166        )
167
168    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
169        return Column(
170            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
171        )
172
173    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
174        return Column(klass(this=self.column_expression, **kwargs))
175
176    @property
177    def is_alias(self):
178        return isinstance(self.expression, exp.Alias)
179
180    @property
181    def is_column(self):
182        return isinstance(self.expression, exp.Column)
183
184    @property
185    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
186        return self.expression.unalias()
187
188    @property
189    def alias_or_name(self) -> str:
190        return self.expression.alias_or_name
191
192    @classmethod
193    def ensure_literal(cls, value) -> Column:
194        from sqlglot.dataframe.sql.functions import lit
195
196        if isinstance(value, cls):
197            value = value.expression
198        if not isinstance(value, exp.Literal):
199            return lit(value)
200        return Column(value)
201
202    def copy(self) -> Column:
203        return Column(self.expression.copy())
204
205    def set_table_name(self, table_name: str, copy=False) -> Column:
206        expression = self.expression.copy() if copy else self.expression
207        expression.set("table", exp.to_identifier(table_name))
208        return Column(expression)
209
210    def sql(self, **kwargs) -> str:
211        from sqlglot.dataframe.sql.session import SparkSession
212
213        return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
214
215    def alias(self, name: str) -> Column:
216        from sqlglot.dataframe.sql.session import SparkSession
217
218        dialect = SparkSession().dialect
219        alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect)
220        new_expression = exp.alias_(
221            self.column_expression,
222            alias.this if isinstance(alias, exp.Column) else name,
223            dialect=dialect,
224        )
225        return Column(new_expression)
226
227    def asc(self) -> Column:
228        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
229        return Column(new_expression)
230
231    def desc(self) -> Column:
232        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
233        return Column(new_expression)
234
235    asc_nulls_first = asc
236
237    def asc_nulls_last(self) -> Column:
238        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
239        return Column(new_expression)
240
241    def desc_nulls_first(self) -> Column:
242        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
243        return Column(new_expression)
244
245    desc_nulls_last = desc
246
247    def when(self, condition: Column, value: t.Any) -> Column:
248        from sqlglot.dataframe.sql.functions import when
249
250        column_with_if = when(condition, value)
251        if not isinstance(self.expression, exp.Case):
252            return column_with_if
253        new_column = self.copy()
254        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
255        return new_column
256
257    def otherwise(self, value: t.Any) -> Column:
258        from sqlglot.dataframe.sql.functions import lit
259
260        true_value = value if isinstance(value, Column) else lit(value)
261        new_column = self.copy()
262        new_column.expression.set("default", true_value.column_expression)
263        return new_column
264
265    def isNull(self) -> Column:
266        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
267        return Column(new_expression)
268
269    def isNotNull(self) -> Column:
270        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
271        return Column(new_expression)
272
273    def cast(self, dataType: t.Union[str, DataType]) -> Column:
274        """
275        Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
276        Sqlglot doesn't currently replicate this class so it only accepts a string
277        """
278        from sqlglot.dataframe.sql.session import SparkSession
279
280        if isinstance(dataType, DataType):
281            dataType = dataType.simpleString()
282        return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))
283
284    def startswith(self, value: t.Union[str, Column]) -> Column:
285        value = self._lit(value) if not isinstance(value, Column) else value
286        return self.invoke_anonymous_function(self, "STARTSWITH", value)
287
288    def endswith(self, value: t.Union[str, Column]) -> Column:
289        value = self._lit(value) if not isinstance(value, Column) else value
290        return self.invoke_anonymous_function(self, "ENDSWITH", value)
291
292    def rlike(self, regexp: str) -> Column:
293        return self.invoke_expression_over_column(
294            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
295        )
296
297    def like(self, other: str):
298        return self.invoke_expression_over_column(
299            self, exp.Like, expression=self._lit(other).expression
300        )
301
302    def ilike(self, other: str):
303        return self.invoke_expression_over_column(
304            self, exp.ILike, expression=self._lit(other).expression
305        )
306
307    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
308        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
309        length = self._lit(length) if not isinstance(length, Column) else length
310        return Column.invoke_expression_over_column(
311            self, exp.Substring, start=startPos.expression, length=length.expression
312        )
313
314    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
315        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
316        expressions = [self._lit(x).expression for x in columns]
317        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
318
319    def between(
320        self,
321        lowerBound: t.Union[ColumnOrLiteral],
322        upperBound: t.Union[ColumnOrLiteral],
323    ) -> Column:
324        lower_bound_exp = (
325            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
326        )
327        upper_bound_exp = (
328            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
329        )
330        return Column(
331            exp.Between(
332                this=self.column_expression,
333                low=lower_bound_exp.expression,
334                high=upper_bound_exp.expression,
335            )
336        )
337
338    def over(self, window: WindowSpec) -> Column:
339        window_expression = window.expression.copy()
340        window_expression.set("this", self.column_expression)
341        return Column(window_expression)
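
A Column wraps a sqlglot expression, and its operators and methods build new expressions rather than evaluating anything. A small sketch (hypothetical column names):

from sqlglot.dataframe.sql import functions as F

age = F.col("age")

age_bucket = (
    F.when(age < F.lit(18), F.lit("minor"))
    .when(age.between(18, 64), F.lit("adult"))
    .otherwise(F.lit("senior"))
    .alias("age_bucket")
)

is_known = F.col("name").isin("Jack", "Kate") | F.col("name").isNull()
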
Column( expression: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType])
17    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
18        from sqlglot.dataframe.sql.session import SparkSession
19
20        if isinstance(expression, Column):
21            expression = expression.expression  # type: ignore
22        elif expression is None or not isinstance(expression, (str, exp.Expression)):
23            expression = self._lit(expression).expression  # type: ignore
24        elif not isinstance(expression, exp.Column):
25            expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
26                SparkSession().dialect.normalize_identifier, copy=False
27            )
28        if expression is None:
29            raise ValueError(f"Could not parse {expression}")
30
31        self.expression: exp.Expression = expression  # type: ignore
@classmethod
def ensure_col( cls, value: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType]) -> Column:
117    @classmethod
118    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
119        return cls(value)
@classmethod
def ensure_cols( cls, args: List[Union[ColumnOrLiteral, sqlglot.expressions.Expression]]) -> List[Column]:
121    @classmethod
122    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
123        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
@classmethod
def invoke_anonymous_function( cls, column: Optional[ColumnOrLiteral], func_name: str, *args: Optional[ColumnOrLiteral]) -> Column:
132    @classmethod
133    def invoke_anonymous_function(
134        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
135    ) -> Column:
136        columns = [] if column is None else [cls.ensure_col(column)]
137        column_args = [cls.ensure_col(arg) for arg in args]
138        expressions = [x.expression for x in columns + column_args]
139        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
140        return Column(new_expression)
@classmethod
def invoke_expression_over_column( cls, column: Optional[ColumnOrLiteral], callable_expression: Callable, **kwargs) -> Column:
142    @classmethod
143    def invoke_expression_over_column(
144        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
145    ) -> Column:
146        ensured_column = None if column is None else cls.ensure_col(column)
147        ensure_expression_values = {
148            k: [Column.ensure_col(x).expression for x in v]
149            if is_iterable(v)
150            else Column.ensure_col(v).expression
151            for k, v in kwargs.items()
152            if v is not None
153        }
154        new_expression = (
155            callable_expression(**ensure_expression_values)
156            if ensured_column is None
157            else callable_expression(
158                this=ensured_column.column_expression, **ensure_expression_values
159            )
160        )
161        return Column(new_expression)
def binary_op( self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> Column:
163    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
164        return Column(
165            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
166        )
def inverse_binary_op( self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> Column:
168    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
169        return Column(
170            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
171        )
def unary_op(self, klass: Callable, **kwargs) -> Column:
173    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
174        return Column(klass(this=self.column_expression, **kwargs))
is_alias
is_column
alias_or_name: str
@classmethod
def ensure_literal(cls, value) -> Column:
192    @classmethod
193    def ensure_literal(cls, value) -> Column:
194        from sqlglot.dataframe.sql.functions import lit
195
196        if isinstance(value, cls):
197            value = value.expression
198        if not isinstance(value, exp.Literal):
199            return lit(value)
200        return Column(value)
def copy(self) -> Column:
202    def copy(self) -> Column:
203        return Column(self.expression.copy())
def set_table_name(self, table_name: str, copy=False) -> Column:
205    def set_table_name(self, table_name: str, copy=False) -> Column:
206        expression = self.expression.copy() if copy else self.expression
207        expression.set("table", exp.to_identifier(table_name))
208        return Column(expression)
def sql(self, **kwargs) -> str:
210    def sql(self, **kwargs) -> str:
211        from sqlglot.dataframe.sql.session import SparkSession
212
213        return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
def alias(self, name: str) -> Column:
215    def alias(self, name: str) -> Column:
216        from sqlglot.dataframe.sql.session import SparkSession
217
218        dialect = SparkSession().dialect
219        alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect)
220        new_expression = exp.alias_(
221            self.column_expression,
222            alias.this if isinstance(alias, exp.Column) else name,
223            dialect=dialect,
224        )
225        return Column(new_expression)
def asc(self) -> Column:
227    def asc(self) -> Column:
228        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
229        return Column(new_expression)
def desc(self) -> Column:
231    def desc(self) -> Column:
232        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
233        return Column(new_expression)
def asc_nulls_first(self) -> Column:
227    def asc(self) -> Column:
228        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
229        return Column(new_expression)
def asc_nulls_last(self) -> Column:
237    def asc_nulls_last(self) -> Column:
238        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
239        return Column(new_expression)
def desc_nulls_first(self) -> Column:
241    def desc_nulls_first(self) -> Column:
242        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
243        return Column(new_expression)
def desc_nulls_last(self) -> Column:
231    def desc(self) -> Column:
232        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
233        return Column(new_expression)
def when( self, condition: Column, value: Any) -> Column:
247    def when(self, condition: Column, value: t.Any) -> Column:
248        from sqlglot.dataframe.sql.functions import when
249
250        column_with_if = when(condition, value)
251        if not isinstance(self.expression, exp.Case):
252            return column_with_if
253        new_column = self.copy()
254        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
255        return new_column
def otherwise(self, value: Any) -> Column:
257    def otherwise(self, value: t.Any) -> Column:
258        from sqlglot.dataframe.sql.functions import lit
259
260        true_value = value if isinstance(value, Column) else lit(value)
261        new_column = self.copy()
262        new_column.expression.set("default", true_value.column_expression)
263        return new_column
def isNull(self) -> Column:
265    def isNull(self) -> Column:
266        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
267        return Column(new_expression)
def isNotNull(self) -> Column:
269    def isNotNull(self) -> Column:
270        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
271        return Column(new_expression)
def cast( self, dataType: Union[str, sqlglot.dataframe.sql.types.DataType]) -> Column:
273    def cast(self, dataType: t.Union[str, DataType]) -> Column:
274        """
275        Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
276        Sqlglot doesn't currently replicate this class so it only accepts a string
277        """
278        from sqlglot.dataframe.sql.session import SparkSession
279
280        if isinstance(dataType, DataType):
281            dataType = dataType.simpleString()
282        return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))

Functionality difference: PySpark's cast accepts a datatype instance of the DataType class. Sqlglot doesn't currently replicate this class, so it only accepts a string.
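
For example (hypothetical column names), the target type is given as a SQL type string:

from sqlglot.dataframe.sql import functions as F

F.col("age").cast("string")
F.col("amount").cast("decimal(10, 2)")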

def startswith( self, value: Union[str, Column]) -> Column:
284    def startswith(self, value: t.Union[str, Column]) -> Column:
285        value = self._lit(value) if not isinstance(value, Column) else value
286        return self.invoke_anonymous_function(self, "STARTSWITH", value)
def endswith( self, value: Union[str, Column]) -> Column:
288    def endswith(self, value: t.Union[str, Column]) -> Column:
289        value = self._lit(value) if not isinstance(value, Column) else value
290        return self.invoke_anonymous_function(self, "ENDSWITH", value)
def rlike(self, regexp: str) -> Column:
292    def rlike(self, regexp: str) -> Column:
293        return self.invoke_expression_over_column(
294            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
295        )
def like(self, other: str):
297    def like(self, other: str):
298        return self.invoke_expression_over_column(
299            self, exp.Like, expression=self._lit(other).expression
300        )
def ilike(self, other: str):
302    def ilike(self, other: str):
303        return self.invoke_expression_over_column(
304            self, exp.ILike, expression=self._lit(other).expression
305        )
def substr( self, startPos: Union[int, Column], length: Union[int, Column]) -> Column:
307    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
308        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
309        length = self._lit(length) if not isinstance(length, Column) else length
310        return Column.invoke_expression_over_column(
311            self, exp.Substring, start=startPos.expression, length=length.expression
312        )
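The string helpers above delegate either to anonymous functions (STARTSWITH, ENDSWITH) or to dedicated expressions (RegexpLike, Like, ILike, Substring). A rough usage sketch, assuming a hypothetical users table and that DataFrame.where behaves as in PySpark:

import sqlglot
from sqlglot.dataframe.sql import SparkSession
from sqlglot.dataframe.sql import functions as F

sqlglot.schema.add_table("users", {"email": "STRING"}, dialect="spark")  # hypothetical table

spark = SparkSession()
df = (
    spark.table("users")
    .where(F.col("email").rlike(r"@example\.(com|org)$"))  # RLIKE predicate
    .select(F.col("email").substr(1, 5).alias("prefix"))   # SUBSTRING(email, 1, 5)
)
print(df.sql()[0])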
def isin(self, *cols: Union[ColumnOrLiteral, Iterable[ColumnOrLiteral]]):
314    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
315        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
316        expressions = [self._lit(x).expression for x in columns]
317        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
def between(self, lowerBound: ColumnOrLiteral, upperBound: ColumnOrLiteral) -> Column:
319    def between(
320        self,
321        lowerBound: t.Union[ColumnOrLiteral],
322        upperBound: t.Union[ColumnOrLiteral],
323    ) -> Column:
324        lower_bound_exp = (
325            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
326        )
327        upper_bound_exp = (
328            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
329        )
330        return Column(
331            exp.Between(
332                this=self.column_expression,
333                low=lower_bound_exp.expression,
334                high=upper_bound_exp.expression,
335            )
336        )
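isin flattens its arguments into an IN list and between wraps the column in a BETWEEN predicate. A sketch under the same kind of assumptions (made-up orders table; & combining Columns as in PySpark):

import sqlglot
from sqlglot.dataframe.sql import SparkSession
from sqlglot.dataframe.sql import functions as F

sqlglot.schema.add_table("orders", {"status": "STRING", "total": "DOUBLE"}, dialect="spark")  # hypothetical

spark = SparkSession()
df = spark.table("orders").where(
    F.col("status").isin("NEW", "PAID")   # -> IN ('NEW', 'PAID')
    & F.col("total").between(10, 100)     # -> BETWEEN 10 AND 100
)
print(df.sql()[0])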
def over(self, window: WindowSpec) -> Column:
338    def over(self, window: WindowSpec) -> Column:
339        window_expression = window.expression.copy()
340        window_expression.set("this", self.column_expression)
341        return Column(window_expression)
class DataFrameNaFunctions:
842class DataFrameNaFunctions:
843    def __init__(self, df: DataFrame):
844        self.df = df
845
846    def drop(
847        self,
848        how: str = "any",
849        thresh: t.Optional[int] = None,
850        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
851    ) -> DataFrame:
852        return self.df.dropna(how=how, thresh=thresh, subset=subset)
853
854    def fill(
855        self,
856        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
857        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
858    ) -> DataFrame:
859        return self.df.fillna(value=value, subset=subset)
860
861    def replace(
862        self,
863        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
864        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
865        subset: t.Optional[t.Union[str, t.List[str]]] = None,
866    ) -> DataFrame:
867        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
DataFrameNaFunctions(df: DataFrame)
843    def __init__(self, df: DataFrame):
844        self.df = df
df
def drop( self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> DataFrame:
846    def drop(
847        self,
848        how: str = "any",
849        thresh: t.Optional[int] = None,
850        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
851    ) -> DataFrame:
852        return self.df.dropna(how=how, thresh=thresh, subset=subset)
def fill( self, value: Union[int, bool, float, str, Dict[str, Any]], subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> DataFrame:
854    def fill(
855        self,
856        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
857        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
858    ) -> DataFrame:
859        return self.df.fillna(value=value, subset=subset)
def replace( self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[str, List[str], NoneType] = None) -> DataFrame:
861    def replace(
862        self,
863        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
864        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
865        subset: t.Optional[t.Union[str, t.List[str]]] = None,
866    ) -> DataFrame:
867        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
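These are thin wrappers over DataFrame.dropna, DataFrame.fillna, and DataFrame.replace. Assuming the DataFrame exposes this class through an na attribute as PySpark does, and using a made-up users table, usage looks roughly like:

import sqlglot
from sqlglot.dataframe.sql import SparkSession

sqlglot.schema.add_table("users", {"name": "STRING", "age": "INT"}, dialect="spark")  # hypothetical

spark = SparkSession()
df = spark.table("users")

cleaned = df.na.drop(how="any", subset=["name"])  # delegates to df.dropna(...)
filled = cleaned.na.fill(0, subset=["age"])       # delegates to df.fillna(...)
print(filled.sql()[0])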
class Window:
15class Window:
16    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
17    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
18    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
19    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)
20
21    unboundedPreceding: int = _JAVA_MIN_LONG
22
23    unboundedFollowing: int = _JAVA_MAX_LONG
24
25    currentRow: int = 0
26
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
30
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
34
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
38
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
unboundedPreceding: int = -9223372036854775808
unboundedFollowing: int = 9223372036854775807
currentRow: int = 0
@classmethod
def partitionBy(cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> WindowSpec:
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
@classmethod
def orderBy(cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> WindowSpec:
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
@classmethod
def rowsBetween(cls, start: int, end: int) -> WindowSpec:
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
@classmethod
def rangeBetween(cls, start: int, end: int) -> WindowSpec:
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
class WindowSpec:
 44class WindowSpec:
 45    def __init__(self, expression: exp.Expression = exp.Window()):
 46        self.expression = expression
 47
 48    def copy(self):
 49        return WindowSpec(self.expression.copy())
 50
 51    def sql(self, **kwargs) -> str:
 52        from sqlglot.dataframe.sql.session import SparkSession
 53
 54        return self.expression.sql(dialect=SparkSession().dialect, **kwargs)
 55
 56    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 57        from sqlglot.dataframe.sql.column import Column
 58
 59        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 60        expressions = [Column.ensure_col(x).expression for x in cols]
 61        window_spec = self.copy()
 62        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
 63        partition_by_expressions.extend(expressions)
 64        window_spec.expression.set("partition_by", partition_by_expressions)
 65        return window_spec
 66
 67    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 68        from sqlglot.dataframe.sql.column import Column
 69
 70        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 71        expressions = [Column.ensure_col(x).expression for x in cols]
 72        window_spec = self.copy()
 73        if window_spec.expression.args.get("order") is None:
 74            window_spec.expression.set("order", exp.Order(expressions=[]))
 75        order_by = window_spec.expression.args["order"].expressions
 76        order_by.extend(expressions)
 77        window_spec.expression.args["order"].set("expressions", order_by)
 78        return window_spec
 79
 80    def _calc_start_end(
 81        self, start: int, end: int
 82    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
 83        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
 84            "start_side": None,
 85            "end_side": None,
 86        }
 87        if start == Window.currentRow:
 88            kwargs["start"] = "CURRENT ROW"
 89        else:
 90            kwargs = {
 91                **kwargs,
 92                **{
 93                    "start_side": "PRECEDING",
 94                    "start": "UNBOUNDED"
 95                    if start <= Window.unboundedPreceding
 96                    else F.lit(start).expression,
 97                },
 98            }
 99        if end == Window.currentRow:
100            kwargs["end"] = "CURRENT ROW"
101        else:
102            kwargs = {
103                **kwargs,
104                **{
105                    "end_side": "FOLLOWING",
106                    "end": "UNBOUNDED"
107                    if end >= Window.unboundedFollowing
108                    else F.lit(end).expression,
109                },
110            }
111        return kwargs
112
113    def rowsBetween(self, start: int, end: int) -> WindowSpec:
114        window_spec = self.copy()
115        spec = self._calc_start_end(start, end)
116        spec["kind"] = "ROWS"
117        window_spec.expression.set(
118            "spec",
119            exp.WindowSpec(
120                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
121            ),
122        )
123        return window_spec
124
125    def rangeBetween(self, start: int, end: int) -> WindowSpec:
126        window_spec = self.copy()
127        spec = self._calc_start_end(start, end)
128        spec["kind"] = "RANGE"
129        window_spec.expression.set(
130            "spec",
131            exp.WindowSpec(
132                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
133            ),
134        )
135        return window_spec
WindowSpec(expression: sqlglot.expressions.Expression = exp.Window())
45    def __init__(self, expression: exp.Expression = exp.Window()):
46        self.expression = expression
expression
def copy(self):
48    def copy(self):
49        return WindowSpec(self.expression.copy())
def sql(self, **kwargs) -> str:
51    def sql(self, **kwargs) -> str:
52        from sqlglot.dataframe.sql.session import SparkSession
53
54        return self.expression.sql(dialect=SparkSession().dialect, **kwargs)
def partitionBy(self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> WindowSpec:
56    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
57        from sqlglot.dataframe.sql.column import Column
58
59        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
60        expressions = [Column.ensure_col(x).expression for x in cols]
61        window_spec = self.copy()
62        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
63        partition_by_expressions.extend(expressions)
64        window_spec.expression.set("partition_by", partition_by_expressions)
65        return window_spec
def orderBy(self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> WindowSpec:
67    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
68        from sqlglot.dataframe.sql.column import Column
69
70        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
71        expressions = [Column.ensure_col(x).expression for x in cols]
72        window_spec = self.copy()
73        if window_spec.expression.args.get("order") is None:
74            window_spec.expression.set("order", exp.Order(expressions=[]))
75        order_by = window_spec.expression.args["order"].expressions
76        order_by.extend(expressions)
77        window_spec.expression.args["order"].set("expressions", order_by)
78        return window_spec
def rowsBetween(self, start: int, end: int) -> WindowSpec:
113    def rowsBetween(self, start: int, end: int) -> WindowSpec:
114        window_spec = self.copy()
115        spec = self._calc_start_end(start, end)
116        spec["kind"] = "ROWS"
117        window_spec.expression.set(
118            "spec",
119            exp.WindowSpec(
120                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
121            ),
122        )
123        return window_spec
def rangeBetween(self, start: int, end: int) -> WindowSpec:
125    def rangeBetween(self, start: int, end: int) -> WindowSpec:
126        window_spec = self.copy()
127        spec = self._calc_start_end(start, end)
128        spec["kind"] = "RANGE"
129        window_spec.expression.set(
130            "spec",
131            exp.WindowSpec(
132                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
133            ),
134        )
135        return window_spec
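Putting Window, WindowSpec, and Column.over together: partitionBy and orderBy accumulate into the window expression, while rowsBetween/rangeBetween set its frame. A sketch with a made-up sales table, assuming F.sum accepts a column name as in PySpark:

import sqlglot
from sqlglot.dataframe.sql import SparkSession, Window
from sqlglot.dataframe.sql import functions as F

sqlglot.schema.add_table(
    "sales", {"region": "STRING", "day": "DATE", "amount": "DOUBLE"}, dialect="spark"
)  # hypothetical table

spark = SparkSession()
window = (
    Window.partitionBy("region")
    .orderBy("day")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)  # frame: UNBOUNDED PRECEDING to CURRENT ROW
)
df = spark.table("sales").select(
    F.col("region"),
    F.col("day"),
    F.sum("amount").over(window).alias("running_total"),
)
print(df.sql(pretty=True)[0])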
class DataFrameReader:
15class DataFrameReader:
16    def __init__(self, spark: SparkSession):
17        self.spark = spark
18
19    def table(self, tableName: str) -> DataFrame:
20        from sqlglot.dataframe.sql.dataframe import DataFrame
21        from sqlglot.dataframe.sql.session import SparkSession
22
23        sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)
24
25        return DataFrame(
26            self.spark,
27            exp.Select()
28            .from_(
29                exp.to_table(tableName, dialect=SparkSession().dialect).transform(
30                    SparkSession().dialect.normalize_identifier
31                )
32            )
33            .select(
34                *(
35                    column
36                    for column in sqlglot.schema.column_names(
37                        tableName, dialect=SparkSession().dialect
38                    )
39                )
40            ),
41        )
DataFrameReader(spark: SparkSession)
16    def __init__(self, spark: SparkSession):
17        self.spark = spark
spark
def table(self, tableName: str) -> DataFrame:
19    def table(self, tableName: str) -> DataFrame:
20        from sqlglot.dataframe.sql.dataframe import DataFrame
21        from sqlglot.dataframe.sql.session import SparkSession
22
23        sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)
24
25        return DataFrame(
26            self.spark,
27            exp.Select()
28            .from_(
29                exp.to_table(tableName, dialect=SparkSession().dialect).transform(
30                    SparkSession().dialect.normalize_identifier
31                )
32            )
33            .select(
34                *(
35                    column
36                    for column in sqlglot.schema.column_names(
37                        tableName, dialect=SparkSession().dialect
38                    )
39                )
40            ),
41        )
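Because table() pulls its column list from sqlglot.schema, registering the table's columns first makes the generated SELECT explicit; an unregistered table is added to the schema without columns. A sketch with a made-up table name:

import sqlglot
from sqlglot.dataframe.sql import SparkSession

# Register columns up front so read.table can expand them (hypothetical table).
sqlglot.schema.add_table(
    "db.employee",
    {"employee_id": "INT", "fname": "STRING", "lname": "STRING"},
    dialect="spark",
)

spark = SparkSession()
df = spark.read.table("db.employee")  # same as spark.table("db.employee")
print(df.sql()[0])                    # SELECT over the three registered columns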
class DataFrameWriter:
 44class DataFrameWriter:
 45    def __init__(
 46        self,
 47        df: DataFrame,
 48        spark: t.Optional[SparkSession] = None,
 49        mode: t.Optional[str] = None,
 50        by_name: bool = False,
 51    ):
 52        self._df = df
 53        self._spark = spark or df.spark
 54        self._mode = mode
 55        self._by_name = by_name
 56
 57    def copy(self, **kwargs) -> DataFrameWriter:
 58        return DataFrameWriter(
 59            **{
 60                k[1:] if k.startswith("_") else k: v
 61                for k, v in object_to_dict(self, **kwargs).items()
 62            }
 63        )
 64
 65    def sql(self, **kwargs) -> t.List[str]:
 66        return self._df.sql(**kwargs)
 67
 68    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
 69        return self.copy(_mode=saveMode)
 70
 71    @property
 72    def byName(self):
 73        return self.copy(by_name=True)
 74
 75    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
 76        from sqlglot.dataframe.sql.session import SparkSession
 77
 78        output_expression_container = exp.Insert(
 79            **{
 80                "this": exp.to_table(tableName),
 81                "overwrite": overwrite,
 82            }
 83        )
 84        df = self._df.copy(output_expression_container=output_expression_container)
 85        if self._by_name:
 86            columns = sqlglot.schema.column_names(
 87                tableName, only_visible=True, dialect=SparkSession().dialect
 88            )
 89            df = df._convert_leaf_to_cte().select(*columns)
 90
 91        return self.copy(_df=df)
 92
 93    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
 94        if format is not None:
 95            raise NotImplementedError("Providing Format in the save as table is not supported")
 96        exists, replace, mode = None, None, mode or str(self._mode)
 97        if mode == "append":
 98            return self.insertInto(name)
 99        if mode == "ignore":
100            exists = True
101        if mode == "overwrite":
102            replace = True
103        output_expression_container = exp.Create(
104            this=exp.to_table(name),
105            kind="TABLE",
106            exists=exists,
107            replace=replace,
108        )
109        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
DataFrameWriter( df: DataFrame, spark: Optional[SparkSession] = None, mode: Optional[str] = None, by_name: bool = False)
45    def __init__(
46        self,
47        df: DataFrame,
48        spark: t.Optional[SparkSession] = None,
49        mode: t.Optional[str] = None,
50        by_name: bool = False,
51    ):
52        self._df = df
53        self._spark = spark or df.spark
54        self._mode = mode
55        self._by_name = by_name
def copy(self, **kwargs) -> DataFrameWriter:
57    def copy(self, **kwargs) -> DataFrameWriter:
58        return DataFrameWriter(
59            **{
60                k[1:] if k.startswith("_") else k: v
61                for k, v in object_to_dict(self, **kwargs).items()
62            }
63        )
def sql(self, **kwargs) -> List[str]:
65    def sql(self, **kwargs) -> t.List[str]:
66        return self._df.sql(**kwargs)
def mode( self, saveMode: Optional[str]) -> DataFrameWriter:
68    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
69        return self.copy(_mode=saveMode)
byName
def insertInto( self, tableName: str, overwrite: Optional[bool] = None) -> DataFrameWriter:
75    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
76        from sqlglot.dataframe.sql.session import SparkSession
77
78        output_expression_container = exp.Insert(
79            **{
80                "this": exp.to_table(tableName),
81                "overwrite": overwrite,
82            }
83        )
84        df = self._df.copy(output_expression_container=output_expression_container)
85        if self._by_name:
86            columns = sqlglot.schema.column_names(
87                tableName, only_visible=True, dialect=SparkSession().dialect
88            )
89            df = df._convert_leaf_to_cte().select(*columns)
90
91        return self.copy(_df=df)
def saveAsTable( self, name: str, format: Optional[str] = None, mode: Optional[str] = None):
 93    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
 94        if format is not None:
 95            raise NotImplementedError("Providing Format in the save as table is not supported")
 96        exists, replace, mode = None, None, mode or str(self._mode)
 97        if mode == "append":
 98            return self.insertInto(name)
 99        if mode == "ignore":
100            exists = True
101        if mode == "overwrite":
102            replace = True
103        output_expression_container = exp.Create(
104            this=exp.to_table(name),
105            kind="TABLE",
106            exists=exists,
107            replace=replace,
108        )
109        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
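Tying the writer together: mode stores the save mode, saveAsTable turns it into either an insert (append) or a CREATE TABLE (IF NOT EXISTS for ignore, OR REPLACE for overwrite), and sql() renders the statements for the wrapped DataFrame. A sketch, assuming the DataFrame exposes the writer through a write attribute as PySpark does and using a made-up table name:

from sqlglot.dataframe.sql import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1, "Jack"), (2, "Jill")], schema=["id", "name"])

# "overwrite" leads to a CREATE OR REPLACE TABLE ... AS SELECT over the values above;
# "append" would instead route through insertInto.
writer = df.write.mode("overwrite").saveAsTable("db.employee")
for statement in writer.sql():
    print(statement)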