Diffstat (limited to 'sqlglot/dataframe/sql/dataframe.py')
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py | 158
1 file changed, 117 insertions(+), 41 deletions(-)
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 322dcf2..40cd6c9 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -21,7 +21,12 @@
 from sqlglot.optimizer import optimize as optimize_func
 from sqlglot.optimizer.qualify_columns import qualify_columns

 if t.TYPE_CHECKING:
-    from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+    from sqlglot.dataframe.sql._typing import (
+        ColumnLiterals,
+        ColumnOrLiteral,
+        ColumnOrName,
+        OutputExpressionContainer,
+    )
     from sqlglot.dataframe.sql.session import SparkSession

@@ -83,7 +88,9 @@ class DataFrame:
                 return from_exp.alias_or_name
             table_alias = from_exp.find(exp.TableAlias)
             if not table_alias:
-                raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+                raise RuntimeError(
+                    f"Could not find an alias name for this expression: {self.expression}"
+                )
             return table_alias.alias_or_name
         return self.expression.ctes[-1].alias

@@ -132,12 +139,16 @@ class DataFrame:
         cte.set("sequence_id", sequence_id or self.sequence_id)
         return cte, name

-    def _ensure_list_of_columns(
-        self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
-    ) -> t.List[Column]:
-        columns = ensure_list(cols)
-        columns = Column.ensure_cols(columns)
-        return columns
+    @t.overload
+    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
+        ...
+
+    @t.overload
+    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
+        ...
+
+    def _ensure_list_of_columns(self, cols):
+        return Column.ensure_cols(ensure_list(cols))

     def _ensure_and_normalize_cols(self, cols):
         cols = self._ensure_list_of_columns(cols)
@@ -153,10 +164,16 @@ class DataFrame:
         df = self._resolve_pending_hints()
         sequence_id = sequence_id or df.sequence_id
         expression = df.expression.copy()
-        cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
-        new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+        cte_expression, cte_name = df._create_cte_from_expression(
+            expression=expression, sequence_id=sequence_id
+        )
+        new_expression = df._add_ctes_to_expression(
+            exp.Select(), expression.ctes + [cte_expression]
+        )
         sel_columns = df._get_outer_select_columns(cte_expression)
-        new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+        new_expression = new_expression.from_(cte_name).select(
+            *[x.alias_or_name for x in sel_columns]
+        )
         return df.copy(expression=new_expression, sequence_id=sequence_id)

     def _resolve_pending_hints(self) -> DataFrame:
@@ -169,16 +186,23 @@ def _resolve_pending_hints(self) -> DataFrame:
             hint_expression.args.get("expressions").append(hint)
             df.pending_hints.remove(hint)

-        join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+        join_aliases = {
+            join_table.alias_or_name
+            for join_table in get_tables_from_expression_with_join(expression)
+        }
         if join_aliases:
             for hint in df.pending_join_hints:
                 for sequence_id_expression in hint.expressions:
                     sequence_id_or_name = sequence_id_expression.alias_or_name
                     sequence_ids_to_match = [sequence_id_or_name]
                     if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
-                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
+                            sequence_id_or_name
+                        ]
                     matching_ctes = [
-                        cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+                        cte
+                        for cte in reversed(expression.ctes)
+                        if cte.args["sequence_id"] in sequence_ids_to_match
                     ]
                     for matching_cte in matching_ctes:
                         if matching_cte.alias_or_name in join_aliases:
@@ -193,9 +217,14 @@ class DataFrame:
     def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
         hint_name = hint_name.upper()
         hint_expression = (
-            exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+            exp.JoinHint(
+                this=hint_name,
+                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
+            )
             if hint_name in JOIN_HINTS
-            else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+            else exp.Anonymous(
+                this=hint_name, expressions=[parameter.expression for parameter in args]
+            )
         )
         new_df = self.copy()
         new_df.pending_hints.append(hint_expression)
@@ -245,7 +274,9 @@ class DataFrame:
     def _get_select_expressions(
         self,
     ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
-        select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
+        select_expressions: t.List[
+            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
+        ] = []
         main_select_ctes: t.List[exp.CTE] = []
         for cte in self.expression.ctes:
             cache_storage_level = cte.args.get("cache_storage_level")
@@ -279,14 +310,19 @@ class DataFrame:
                 cache_table_name = df._create_hash_from_expression(select_expression)
                 cache_table = exp.to_table(cache_table_name)
                 original_alias_name = select_expression.args["cte_alias_name"]
-                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
+
+                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
+                    cache_table_name
+                )
                 sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
                 cache_storage_level = select_expression.args["cache_storage_level"]
                 options = [
                     exp.Literal.string("storageLevel"),
                     exp.Literal.string(cache_storage_level),
                 ]
-                expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
+                expression = exp.Cache(
+                    this=cache_table, expression=select_expression, lazy=True, options=options
+                )
                 # We will drop the "view" if it exists before running the cache table
                 output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
             elif expression_type == exp.Create:
@@ -305,7 +341,9 @@ class DataFrame:
                 raise ValueError(f"Invalid expression type: {expression_type}")
             output_expressions.append(expression)

-        return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+        return [
+            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
+        ]

     def copy(self, **kwargs) -> DataFrame:
         return DataFrame(**object_to_dict(self, **kwargs))
@@ -317,7 +355,9 @@ class DataFrame:
         if self.expression.args.get("joins"):
             ambiguous_cols = [col for col in cols if not col.column_expression.table]
             if ambiguous_cols:
-                join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
+                join_table_identifiers = [
+                    x.this for x in get_tables_from_expression_with_join(self.expression)
+                ]
                 cte_names_in_join = [x.this for x in join_table_identifiers]
                 for ambiguous_col in ambiguous_cols:
                     ctes_with_column = [
@@ -367,14 +407,20 @@ class DataFrame:

     @operation(Operation.FROM)
     def join(
-        self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
+        self,
+        other_df: DataFrame,
+        on: t.Union[str, t.List[str], Column, t.List[Column]],
+        how: str = "inner",
+        **kwargs,
     ) -> DataFrame:
         other_df = other_df._convert_leaf_to_cte()
         pre_join_self_latest_cte_name = self.latest_cte_name
         columns = self._ensure_and_normalize_cols(on)
         join_type = how.replace("_", " ")
         if isinstance(columns[0].expression, exp.Column):
-            join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
+            join_columns = [
+                Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
+            ]
             join_clause = functools.reduce(
                 lambda x, y: x & y,
                 [
@@ -402,7 +448,9 @@ class DataFrame:
             for column in self._get_outer_select_columns(other_df)
         ]
         column_value_mapping = {
-            column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
+            column.alias_or_name
+            if not isinstance(column.expression.this, exp.Star)
+            else column.sql(): column
             for column in other_columns + self_columns + join_columns
         }
         all_columns = [
@@ -410,16 +458,22 @@ class DataFrame:
             column_value_mapping[name]
             for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
         ]
         new_df = self.copy(
-            expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
+            expression=self.expression.join(
+                other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
+            )
+        )
+        new_df.expression = new_df._add_ctes_to_expression(
+            new_df.expression, other_df.expression.ctes
         )
-        new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
         new_df.pending_hints.extend(other_df.pending_hints)
         new_df = new_df.select.__wrapped__(new_df, *all_columns)
         return new_df

     @operation(Operation.ORDER_BY)
     def orderBy(
-        self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
+        self,
+        *cols: t.Union[str, Column],
+        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
     ) -> DataFrame:
         """
         This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
@@ -429,7 +483,10 @@ class DataFrame:
         columns = self._ensure_and_normalize_cols(cols)
         pre_ordered_col_indexes = [
             x
-            for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
+            for x in [
+                i if isinstance(col.expression, exp.Ordered) else None
+                for i, col in enumerate(columns)
+            ]
             if x is not None
         ]
         if ascending is None:
@@ -478,7 +535,9 @@ class DataFrame:
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column))
                r_expressions.append(r_column)
-        r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+        r_df = (
+            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+        )
         l_df = self.copy()
         if allowMissingColumns:
             l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
@@ -536,7 +595,9 @@ class DataFrame:
                 f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                 f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
             )
-        if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
+        if_null_checks = [
+            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
+        ]
         nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
         num_nulls = nulls_added_together.alias("num_nulls")
         new_df = new_df.select(num_nulls, append=True)
@@ -576,11 +637,15 @@ class DataFrame:
         value_columns = [lit(value) for value in values]

         null_replacement_mapping = {
-            column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
+            column.alias_or_name: (
+                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
+            )
             for column, value in zip(columns, value_columns)
         }
         null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
-        null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
+        null_replacement_columns = [
+            null_replacement_mapping[column.alias_or_name] for column in all_columns
+        ]
         new_df = new_df.select(*null_replacement_columns)
         return new_df

@@ -589,12 +654,11 @@ class DataFrame:
         self,
         to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
         value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
-        subset: t.Optional[t.Union[str, t.List[str]]] = None,
+        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
     ) -> DataFrame:
         from sqlglot.dataframe.sql.functions import lit

         old_values = None
-        subset = ensure_list(subset)
         new_df = self.copy()
         all_columns = self._get_outer_select_columns(new_df.expression)
         all_column_mapping = {column.alias_or_name: column for column in all_columns}
@@ -605,7 +669,9 @@ class DataFrame:
             new_values = list(to_replace.values())
         elif not old_values and isinstance(to_replace, list):
             assert isinstance(value, list), "value must be a list since the replacements are a list"
-            assert len(to_replace) == len(value), "the replacements and values must be the same length"
+            assert len(to_replace) == len(
+                value
+            ), "the replacements and values must be the same length"
             old_values = to_replace
             new_values = value
         else:
@@ -635,7 +701,9 @@ class DataFrame:
     def withColumn(self, colName: str, col: Column) -> DataFrame:
         col = self._ensure_and_normalize_col(col)
         existing_col_names = self.expression.named_selects
-        existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
+        existing_col_index = (
+            existing_col_names.index(colName) if colName in existing_col_names else None
+        )
         if existing_col_index:
             expression = self.expression.copy()
             expression.expressions[existing_col_index] = col.expression
@@ -645,7 +713,11 @@ class DataFrame:
     @operation(Operation.SELECT)
     def withColumnRenamed(self, existing: str, new: str):
         expression = self.expression.copy()
-        existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
+        existing_columns = [
+            expression
+            for expression in expression.expressions
+            if expression.alias_or_name == existing
+        ]
         if not existing_columns:
             raise ValueError("Tried to rename a column that doesn't exist")
         for existing_column in existing_columns:
@@ -674,15 +746,19 @@ class DataFrame:
     def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
         parameter_list = ensure_list(parameters)
         parameter_columns = (
-            self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
+            self._ensure_list_of_columns(parameter_list)
+            if parameters
+            else Column.ensure_cols([self.sequence_id])
        )
         return self._hint(name, parameter_columns)

     @operation(Operation.NO_OP)
-    def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
-        num_partitions = Column.ensure_cols(ensure_list(numPartitions))
+    def repartition(
+        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
+    ) -> DataFrame:
+        num_partition_cols = self._ensure_list_of_columns(numPartitions)
         columns = self._ensure_and_normalize_cols(cols)
-        args = num_partitions + columns
+        args = num_partition_cols + columns
         return self._hint("repartition", args)

     @operation(Operation.NO_OP)
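
For context (not part of the commit): a minimal usage sketch of the reworked repartition signature, which now accepts an int and/or column names and folds them into a single hint via _hint. The employee table and its columns are hypothetical, registered only so the DataFrame API can resolve them.

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    # Hypothetical schema, registered so column resolution works.
    sqlglot.schema.add_table("employee", {"employee_id": "INT", "store_id": "INT"})

    spark = SparkSession()
    df = spark.table("employee")

    # numPartitions may be an int or a column name; both end up as arguments
    # of a single REPARTITION hint in the generated Spark SQL.
    print(df.repartition(8, "store_id").sql(dialect="spark")[0])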