Diffstat (limited to 'sqlglot/dataframe/sql/dataframe.py')
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py  730
1 file changed, 730 insertions(+), 0 deletions(-)
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
new file mode 100644
index 0000000..322dcf2
--- /dev/null
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -0,0 +1,730 @@
+from __future__ import annotations
+
+import functools
+import typing as t
+import zlib
+from copy import copy
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.group import GroupedData
+from sqlglot.dataframe.sql.normalize import normalize
+from sqlglot.dataframe.sql.operations import Operation, operation
+from sqlglot.dataframe.sql.readwriter import DataFrameWriter
+from sqlglot.dataframe.sql.transforms import replace_id_value
+from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.dataframe.sql.window import Window
+from sqlglot.helper import ensure_list, object_to_dict
+from sqlglot.optimizer import optimize as optimize_func
+from sqlglot.optimizer.qualify_columns import qualify_columns
+
+if t.TYPE_CHECKING:
+    from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+    from sqlglot.dataframe.sql.session import SparkSession
+
+
+JOIN_HINTS = {
+    "BROADCAST",
+    "BROADCASTJOIN",
+    "MAPJOIN",
+    "MERGE",
+    "SHUFFLEMERGE",
+    "MERGEJOIN",
+    "SHUFFLE_HASH",
+    "SHUFFLE_REPLICATE_NL",
+}
+
+
+class DataFrame:
+    def __init__(
+        self,
+        spark: SparkSession,
+        expression: exp.Select,
+        branch_id: t.Optional[str] = None,
+        sequence_id: t.Optional[str] = None,
+        last_op: Operation = Operation.INIT,
+        pending_hints: t.Optional[t.List[exp.Expression]] = None,
+        output_expression_container: t.Optional[OutputExpressionContainer] = None,
+        **kwargs,
+    ):
+        self.spark = spark
+        self.expression = expression
+        self.branch_id = branch_id or self.spark._random_branch_id
+        self.sequence_id = sequence_id or self.spark._random_sequence_id
+        self.last_op = last_op
+        self.pending_hints = pending_hints or []
+        self.output_expression_container = output_expression_container or exp.Select()
+
+    def __getattr__(self, column_name: str) -> Column:
+        return self[column_name]
+
+    def __getitem__(self, column_name: str) -> Column:
+        column_name = f"{self.branch_id}.{column_name}"
+        return Column(column_name)
+
+    def __copy__(self):
+        return self.copy()
+
+    @property
+    def sparkSession(self):
+        return self.spark
+
+    @property
+    def write(self):
+        return DataFrameWriter(self)
+
+    @property
+    def latest_cte_name(self) -> str:
+        if not self.expression.ctes:
+            from_exp = self.expression.args["from"]
+            if from_exp.alias_or_name:
+                return from_exp.alias_or_name
+            table_alias = from_exp.find(exp.TableAlias)
+            if not table_alias:
+                raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+            return table_alias.alias_or_name
+        return self.expression.ctes[-1].alias
+
+    @property
+    def pending_join_hints(self):
+        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
+
+    @property
+    def pending_partition_hints(self):
+        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
+
+    @property
+    def columns(self) -> t.List[str]:
+        return self.expression.named_selects
+
+    @property
+    def na(self) -> DataFrameNaFunctions:
+        return DataFrameNaFunctions(self)
+
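+    # CTE names are generated randomly per session, so before rendering they are replaced
+    # with a deterministic hash of the CTE's SQL (see `_create_hash_from_expression`).
+    # For example (names illustrative), a CTE created as `a12345` might be renamed to
+    # `t31415`, keeping the generated SQL stable across runs.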
+    def _replace_cte_names_with_hashes(self, expression: exp.Select):
+        expression = expression.copy()
+        ctes = expression.ctes
+        replacement_mapping = {}
+        for cte in ctes:
+            old_name_id = cte.args["alias"].this
+            new_hashed_id = exp.to_identifier(
+                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
+            )
+            replacement_mapping[old_name_id] = new_hashed_id
+            cte.set("alias", exp.TableAlias(this=new_hashed_id))
+        expression = expression.transform(replace_id_value, replacement_mapping)
+        return expression
+
+    def _create_cte_from_expression(
+        self,
+        expression: exp.Expression,
+        branch_id: t.Optional[str] = None,
+        sequence_id: t.Optional[str] = None,
+        **kwargs,
+    ) -> t.Tuple[exp.CTE, str]:
+        name = self.spark._random_name
+        expression_to_cte = expression.copy()
+        expression_to_cte.set("with", None)
+        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
+        cte.set("branch_id", branch_id or self.branch_id)
+        cte.set("sequence_id", sequence_id or self.sequence_id)
+        return cte, name
+
+    def _ensure_list_of_columns(
+        self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
+    ) -> t.List[Column]:
+        columns = ensure_list(cols)
+        columns = Column.ensure_cols(columns)
+        return columns
+
+    def _ensure_and_normalize_cols(self, cols):
+        cols = self._ensure_list_of_columns(cols)
+        normalize(self.spark, self.expression, cols)
+        return cols
+
+    def _ensure_and_normalize_col(self, col):
+        col = Column.ensure_col(col)
+        normalize(self.spark, self.expression, col)
+        return col
+
+    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
+        df = self._resolve_pending_hints()
+        sequence_id = sequence_id or df.sequence_id
+        expression = df.expression.copy()
+        cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
+        new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+        sel_columns = df._get_outer_select_columns(cte_expression)
+        new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+        return df.copy(expression=new_expression, sequence_id=sequence_id)
+
+    def _resolve_pending_hints(self) -> DataFrame:
+        df = self.copy()
+        if not self.pending_hints:
+            return df
+        expression = df.expression
+        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
+        for hint in df.pending_partition_hints:
+            hint_expression.args.get("expressions").append(hint)
+            df.pending_hints.remove(hint)
+
+        join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+        if join_aliases:
+            for hint in df.pending_join_hints:
+                for sequence_id_expression in hint.expressions:
+                    sequence_id_or_name = sequence_id_expression.alias_or_name
+                    sequence_ids_to_match = [sequence_id_or_name]
+                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
+                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+                    matching_ctes = [
+                        cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+                    ]
+                    for matching_cte in matching_ctes:
+                        if matching_cte.alias_or_name in join_aliases:
+                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
+                            df.pending_hints.remove(hint)
+                            break
+                hint_expression.args.get("expressions").append(hint)
+        if hint_expression.expressions:
+            expression.set("hint", hint_expression)
+        return df
+
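+    # Hints are buffered on the DataFrame and only folded into the expression by
+    # `_resolve_pending_hints`. Join hints (see JOIN_HINTS) initially point at a
+    # sequence id and are rewritten above to the alias of the matching CTE, so the
+    # rendered SQL carries a hint comment such as `/*+ BROADCAST(t12345) */`
+    # (alias illustrative).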
+    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
+        hint_name = hint_name.upper()
+        hint_expression = (
+            exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+            if hint_name in JOIN_HINTS
+            else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+        )
+        new_df = self.copy()
+        new_df.pending_hints.append(hint_expression)
+        return new_df
+
+    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
+        other_df = other._convert_leaf_to_cte()
+        base_expression = self.expression.copy()
+        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
+        all_ctes = base_expression.ctes
+        other_df.expression.set("with", None)
+        base_expression.set("with", None)
+        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
+        operation.set("with", exp.With(expressions=all_ctes))
+        return self.copy(expression=operation)._convert_leaf_to_cte()
+
+    def _cache(self, storage_level: str):
+        df = self._convert_leaf_to_cte()
+        df.expression.ctes[-1].set("cache_storage_level", storage_level)
+        return df
+
+    @classmethod
+    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
+        expression = expression.copy()
+        with_expression = expression.args.get("with")
+        if with_expression:
+            existing_ctes = with_expression.expressions
+            existing_cte_names = {x.alias_or_name for x in existing_ctes}
+            for cte in ctes:
+                if cte.alias_or_name not in existing_cte_names:
+                    existing_ctes.append(cte)
+        else:
+            existing_ctes = ctes
+        expression.set("with", exp.With(expressions=existing_ctes))
+        return expression
+
+    @classmethod
+    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
+        expression = item.expression if isinstance(item, DataFrame) else item
+        return [Column(x) for x in expression.find(exp.Select).expressions]
+
+    @classmethod
+    def _create_hash_from_expression(cls, expression: exp.Select):
+        value = expression.sql(dialect="spark").encode("utf-8")
+        return f"t{zlib.crc32(value)}"[:6]
+
+    def _get_select_expressions(
+        self,
+    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
+        select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
+        main_select_ctes: t.List[exp.CTE] = []
+        for cte in self.expression.ctes:
+            cache_storage_level = cte.args.get("cache_storage_level")
+            if cache_storage_level:
+                select_expression = cte.this.copy()
+                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
+                select_expression.set("cte_alias_name", cte.alias_or_name)
+                select_expression.set("cache_storage_level", cache_storage_level)
+                select_expressions.append((exp.Cache, select_expression))
+            else:
+                main_select_ctes.append(cte)
+        main_select = self.expression.copy()
+        if main_select_ctes:
+            main_select.set("with", exp.With(expressions=main_select_ctes))
+        expression_select_pair = (type(self.output_expression_container), main_select)
+        select_expressions.append(expression_select_pair)  # type: ignore
+        return select_expressions
+
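+    # Note that `sql()` below returns a list of statements rather than a single string:
+    # every cached CTE is emitted as its own `DROP VIEW IF EXISTS ...` followed by a
+    # `CACHE LAZY TABLE ... OPTIONS(...)`, ahead of the final SELECT/CREATE/INSERT, so
+    # callers are expected to execute the statements in order.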
+    def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
+        df = self._resolve_pending_hints()
+        select_expressions = df._get_select_expressions()
+        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
+        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
+        for expression_type, select_expression in select_expressions:
+            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
+            if optimize:
+                select_expression = optimize_func(select_expression)
+            select_expression = df._replace_cte_names_with_hashes(select_expression)
+            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
+            if expression_type == exp.Cache:
+                cache_table_name = df._create_hash_from_expression(select_expression)
+                cache_table = exp.to_table(cache_table_name)
+                original_alias_name = select_expression.args["cte_alias_name"]
+                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
+                sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
+                cache_storage_level = select_expression.args["cache_storage_level"]
+                options = [
+                    exp.Literal.string("storageLevel"),
+                    exp.Literal.string(cache_storage_level),
+                ]
+                expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
+                # We will drop the "view" if it exists before running the cache table
+                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
+            elif expression_type == exp.Create:
+                expression = df.output_expression_container.copy()
+                expression.set("expression", select_expression)
+            elif expression_type == exp.Insert:
+                expression = df.output_expression_container.copy()
+                select_without_ctes = select_expression.copy()
+                select_without_ctes.set("with", None)
+                expression.set("expression", select_without_ctes)
+                if select_expression.ctes:
+                    expression.set("with", exp.With(expressions=select_expression.ctes))
+            elif expression_type == exp.Select:
+                expression = select_expression
+            else:
+                raise ValueError(f"Invalid expression type: {expression_type}")
+            output_expressions.append(expression)
+
+        return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+
+    def copy(self, **kwargs) -> DataFrame:
+        return DataFrame(**object_to_dict(self, **kwargs))
+
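+    # When a select follows a join, unqualified column names are disambiguated below:
+    # a name found in exactly one joined CTE is qualified with that CTE's alias, while
+    # a name present on both sides is assumed to refer to the left table.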
+    @operation(Operation.SELECT)
+    def select(self, *cols, **kwargs) -> DataFrame:
+        cols = self._ensure_and_normalize_cols(cols)
+        kwargs["append"] = kwargs.get("append", False)
+        if self.expression.args.get("joins"):
+            ambiguous_cols = [col for col in cols if not col.column_expression.table]
+            if ambiguous_cols:
+                join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
+                cte_names_in_join = [x.this for x in join_table_identifiers]
+                for ambiguous_col in ambiguous_cols:
+                    ctes_with_column = [
+                        cte
+                        for cte in self.expression.ctes
+                        if cte.alias_or_name in cte_names_in_join
+                        and ambiguous_col.alias_or_name in cte.this.named_selects
+                    ]
+                    # If the select column does not specify a table and there is a join
+                    # then we assume they are referring to the left table
+                    if len(ctes_with_column) > 1:
+                        table_identifier = self.expression.args["from"].args["expressions"][0].this
+                    else:
+                        table_identifier = ctes_with_column[0].args["alias"].this
+                    ambiguous_col.expression.set("table", table_identifier)
+        expression = self.expression.select(*[x.expression for x in cols], **kwargs)
+        qualify_columns(expression, sqlglot.schema)
+        return self.copy(expression=expression, **kwargs)
+
+    @operation(Operation.NO_OP)
+    def alias(self, name: str, **kwargs) -> DataFrame:
+        new_sequence_id = self.spark._random_sequence_id
+        df = self.copy()
+        for join_hint in df.pending_join_hints:
+            for expression in join_hint.expressions:
+                if expression.alias_or_name == self.sequence_id:
+                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
+        df.spark._add_alias_to_mapping(name, new_sequence_id)
+        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
+
+    @operation(Operation.WHERE)
+    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
+        col = self._ensure_and_normalize_col(column)
+        return self.copy(expression=self.expression.where(col.expression))
+
+    filter = where
+
+    @operation(Operation.GROUP_BY)
+    def groupBy(self, *cols, **kwargs) -> GroupedData:
+        columns = self._ensure_and_normalize_cols(cols)
+        return GroupedData(self, columns, self.last_op)
+
+    @operation(Operation.SELECT)
+    def agg(self, *exprs, **kwargs) -> DataFrame:
+        cols = self._ensure_and_normalize_cols(exprs)
+        return self.groupBy().agg(*cols)
+
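+    # `on` accepts either column names or boolean Column expressions. Plain names become
+    # equality conditions between the latest CTEs of the two frames, e.g. on=["id"]
+    # (illustrative) produces `t1.id = t2.id`, and the join columns are deduplicated in
+    # the final select so they appear only once.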
+ """ + columns = self._ensure_and_normalize_cols(cols) + pre_ordered_col_indexes = [ + x + for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)] + if x is not None + ] + if ascending is None: + ascending = [True] * len(columns) + elif not isinstance(ascending, list): + ascending = [ascending] * len(columns) + ascending = [bool(x) for i, x in enumerate(ascending)] + assert len(columns) == len( + ascending + ), "The length of items in ascending must equal the number of columns provided" + col_and_ascending = list(zip(columns, ascending)) + order_by_columns = [ + exp.Ordered(this=col.expression, desc=not asc) + if i not in pre_ordered_col_indexes + else columns[i].column_expression + for i, (col, asc) in enumerate(col_and_ascending) + ] + return self.copy(expression=self.expression.order_by(*order_by_columns)) + + sort = orderBy + + @operation(Operation.FROM) + def union(self, other: DataFrame) -> DataFrame: + return self._set_operation(exp.Union, other, False) + + unionAll = union + + @operation(Operation.FROM) + def unionByName(self, other: DataFrame, allowMissingColumns: bool = False): + l_columns = self.columns + r_columns = other.columns + if not allowMissingColumns: + l_expressions = l_columns + r_expressions = l_columns + else: + l_expressions = [] + r_expressions = [] + r_columns_unused = copy(r_columns) + for l_column in l_columns: + l_expressions.append(l_column) + if l_column in r_columns: + r_expressions.append(l_column) + r_columns_unused.remove(l_column) + else: + r_expressions.append(exp.alias_(exp.Null(), l_column)) + for r_column in r_columns_unused: + l_expressions.append(exp.alias_(exp.Null(), r_column)) + r_expressions.append(r_column) + r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) + l_df = self.copy() + if allowMissingColumns: + l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions)) + return l_df._set_operation(exp.Union, r_df, False) + + @operation(Operation.FROM) + def intersect(self, other: DataFrame) -> DataFrame: + return self._set_operation(exp.Intersect, other, True) + + @operation(Operation.FROM) + def intersectAll(self, other: DataFrame) -> DataFrame: + return self._set_operation(exp.Intersect, other, False) + + @operation(Operation.FROM) + def exceptAll(self, other: DataFrame) -> DataFrame: + return self._set_operation(exp.Except, other, False) + + @operation(Operation.SELECT) + def distinct(self) -> DataFrame: + return self.copy(expression=self.expression.distinct()) + + @operation(Operation.SELECT) + def dropDuplicates(self, subset: t.Optional[t.List[str]] = None): + if not subset: + return self.distinct() + column_names = ensure_list(subset) + window = Window.partitionBy(*column_names).orderBy(*column_names) + return ( + self.copy() + .withColumn("row_num", F.row_number().over(window)) + .where(F.col("row_num") == F.lit(1)) + .drop("row_num") + ) + + @operation(Operation.FROM) + def dropna( + self, + how: str = "any", + thresh: t.Optional[int] = None, + subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, + ) -> DataFrame: + minimum_non_null = thresh or 0 # will be determined later if thresh is null + new_df = self.copy() + all_columns = self._get_outer_select_columns(new_df.expression) + if subset: + null_check_columns = self._ensure_and_normalize_cols(subset) + else: + null_check_columns = all_columns + if thresh is None: + minimum_num_nulls = 1 if how == "any" else len(null_check_columns) + else: + 
+    @operation(Operation.SELECT)
+    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
+        if not subset:
+            return self.distinct()
+        column_names = ensure_list(subset)
+        window = Window.partitionBy(*column_names).orderBy(*column_names)
+        return (
+            self.copy()
+            .withColumn("row_num", F.row_number().over(window))
+            .where(F.col("row_num") == F.lit(1))
+            .drop("row_num")
+        )
+
+    @operation(Operation.FROM)
+    def dropna(
+        self,
+        how: str = "any",
+        thresh: t.Optional[int] = None,
+        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+    ) -> DataFrame:
+        minimum_non_null = thresh or 0  # will be determined later if thresh is null
+        new_df = self.copy()
+        all_columns = self._get_outer_select_columns(new_df.expression)
+        if subset:
+            null_check_columns = self._ensure_and_normalize_cols(subset)
+        else:
+            null_check_columns = all_columns
+        if thresh is None:
+            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
+        else:
+            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
+        if minimum_num_nulls > len(null_check_columns):
+            raise RuntimeError(
+                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
+                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
+            )
+        if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
+        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
+        num_nulls = nulls_added_together.alias("num_nulls")
+        new_df = new_df.select(num_nulls, append=True)
+        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
+        final_df = filtered_df.select(*all_columns)
+        return final_df
+
+    @operation(Operation.FROM)
+    def fillna(
+        self,
+        value: ColumnLiterals,
+        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+    ) -> DataFrame:
+        """
+        Functionality Difference: If you provide a value to replace a null and that type conflicts
+        with the type of the column, then PySpark will just ignore your replacement.
+        This implementation tries to cast them to be the same in some cases, so they won't always match.
+        It's best not to mix types: make sure the replacement is the same type as the column.
+
+        Possibility for improvement: Use the `typeof` function to get the type of the column
+        and check if it matches the type of the value provided. If not, then make it null.
+        """
+        from sqlglot.dataframe.sql.functions import lit
+
+        values = None
+        columns = None
+        new_df = self.copy()
+        all_columns = self._get_outer_select_columns(new_df.expression)
+        all_column_mapping = {column.alias_or_name: column for column in all_columns}
+        if isinstance(value, dict):
+            values = value.values()
+            columns = self._ensure_and_normalize_cols(list(value))
+        if not columns:
+            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
+        if not values:
+            values = [value] * len(columns)
+        value_columns = [lit(value) for value in values]
+
+        null_replacement_mapping = {
+            column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
+            for column, value in zip(columns, value_columns)
+        }
+        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
+        null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
+        new_df = new_df.select(*null_replacement_columns)
+        return new_df
+
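+    # Each affected column is rewritten as a searched CASE chain, roughly:
+    # CASE WHEN col = old1 THEN new1 WHEN col = old2 THEN new2 ... ELSE col END AS col.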
+    @operation(Operation.FROM)
+    def replace(
+        self,
+        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
+        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
+        subset: t.Optional[t.Union[str, t.List[str]]] = None,
+    ) -> DataFrame:
+        from sqlglot.dataframe.sql.functions import lit
+
+        old_values = None
+        subset = ensure_list(subset)
+        new_df = self.copy()
+        all_columns = self._get_outer_select_columns(new_df.expression)
+        all_column_mapping = {column.alias_or_name: column for column in all_columns}
+
+        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
+        if isinstance(to_replace, dict):
+            old_values = list(to_replace)
+            new_values = list(to_replace.values())
+        elif not old_values and isinstance(to_replace, list):
+            assert isinstance(value, list), "value must be a list since the replacements are a list"
+            assert len(to_replace) == len(value), "the replacements and values must be the same length"
+            old_values = to_replace
+            new_values = value
+        else:
+            old_values = [to_replace] * len(columns)
+            new_values = [value] * len(columns)
+        old_values = [lit(value) for value in old_values]
+        new_values = [lit(value) for value in new_values]
+
+        replacement_mapping = {}
+        for column in columns:
+            expression = Column(None)
+            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
+                if i == 0:
+                    expression = F.when(column == old_value, new_value)
+                else:
+                    expression = expression.when(column == old_value, new_value)  # type: ignore
+            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
+                column.expression.alias_or_name
+            )
+
+        replacement_mapping = {**all_column_mapping, **replacement_mapping}
+        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
+        new_df = new_df.select(*replacement_columns)
+        return new_df
+
+    @operation(Operation.SELECT)
+    def withColumn(self, colName: str, col: Column) -> DataFrame:
+        col = self._ensure_and_normalize_col(col)
+        existing_col_names = self.expression.named_selects
+        existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
+        if existing_col_index is not None:
+            expression = self.expression.copy()
+            expression.expressions[existing_col_index] = col.expression
+            return self.copy(expression=expression)
+        return self.copy().select(col.alias(colName), append=True)
+
+    @operation(Operation.SELECT)
+    def withColumnRenamed(self, existing: str, new: str):
+        expression = self.expression.copy()
+        existing_columns = [
+            expression for expression in expression.expressions if expression.alias_or_name == existing
+        ]
+        if not existing_columns:
+            raise ValueError("Tried to rename a column that doesn't exist")
+        for existing_column in existing_columns:
+            if isinstance(existing_column, exp.Column):
+                existing_column.replace(exp.alias_(existing_column.copy(), new))
+            else:
+                existing_column.set("alias", exp.to_identifier(new))
+        return self.copy(expression=expression)
+
+    @operation(Operation.SELECT)
+    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
+        all_columns = self._get_outer_select_columns(self.expression)
+        drop_cols = self._ensure_and_normalize_cols(cols)
+        new_columns = [
+            col
+            for col in all_columns
+            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
+        ]
+        return self.copy().select(*new_columns, append=False)
+
+    @operation(Operation.LIMIT)
+    def limit(self, num: int) -> DataFrame:
+        return self.copy(expression=self.expression.limit(num))
+
+    @operation(Operation.NO_OP)
+    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
+        parameter_list = ensure_list(parameters)
+        parameter_columns = (
+            self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
+        )
+        return self._hint(name, parameter_columns)
+
+    @operation(Operation.NO_OP)
+    def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
+        num_partitions = Column.ensure_cols(ensure_list(numPartitions))
+        columns = self._ensure_and_normalize_cols(cols)
+        args = num_partitions + columns
+        return self._hint("repartition", args)
+
+    @operation(Operation.NO_OP)
+    def coalesce(self, numPartitions: int) -> DataFrame:
+        num_partitions = Column.ensure_cols([numPartitions])
+        return self._hint("coalesce", num_partitions)
+
+    @operation(Operation.NO_OP)
+    def cache(self) -> DataFrame:
+        return self._cache(storage_level="MEMORY_AND_DISK")
+
+    @operation(Operation.NO_OP)
+    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
+        """
+        Storage Level Options:
+        https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
+        """
+        return self._cache(storageLevel)
+
+
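+# Mirrors pyspark.sql.DataFrameNaFunctions: thin delegations so that `df.na.drop(...)`,
+# `df.na.fill(...)`, and `df.na.replace(...)` behave the same as the corresponding
+# DataFrame methods.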
+class DataFrameNaFunctions:
+    def __init__(self, df: DataFrame):
+        self.df = df
+
+    def drop(
+        self,
+        how: str = "any",
+        thresh: t.Optional[int] = None,
+        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+    ) -> DataFrame:
+        return self.df.dropna(how=how, thresh=thresh, subset=subset)
+
+    def fill(
+        self,
+        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
+        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+    ) -> DataFrame:
+        return self.df.fillna(value=value, subset=subset)
+
+    def replace(
+        self,
+        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
+        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
+        subset: t.Optional[t.Union[str, t.List[str]]] = None,
+    ) -> DataFrame:
+        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
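A minimal usage sketch of this module (the `SparkSession` and `functions` imports come from the accompanying modules in this package; the table name and schema below are illustrative):

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    # Register a schema so column qualification can resolve names (illustrative table).
    sqlglot.schema.add_table("employee", {"employee_id": "INT", "age": "INT"})

    df = (
        SparkSession()
        .table("employee")
        .groupBy(F.col("age"))
        .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
    )

    # sql() returns a list of statements; here a single SELECT string.
    print(df.sql(dialect="spark")[0])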