Diffstat (limited to 'sqlglot/dataframe/sql/dataframe.py')
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py | 158
1 file changed, 117 insertions(+), 41 deletions(-)
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 322dcf2..40cd6c9 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -21,7 +21,12 @@ from sqlglot.optimizer import optimize as optimize_func
from sqlglot.optimizer.qualify_columns import qualify_columns
if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+ from sqlglot.dataframe.sql._typing import (
+ ColumnLiterals,
+ ColumnOrLiteral,
+ ColumnOrName,
+ OutputExpressionContainer,
+ )
from sqlglot.dataframe.sql.session import SparkSession
@@ -83,7 +88,9 @@ class DataFrame:
return from_exp.alias_or_name
table_alias = from_exp.find(exp.TableAlias)
if not table_alias:
- raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+ raise RuntimeError(
+ f"Could not find an alias name for this expression: {self.expression}"
+ )
return table_alias.alias_or_name
return self.expression.ctes[-1].alias
@@ -132,12 +139,16 @@ class DataFrame:
cte.set("sequence_id", sequence_id or self.sequence_id)
return cte, name
- def _ensure_list_of_columns(
- self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
- ) -> t.List[Column]:
- columns = ensure_list(cols)
- columns = Column.ensure_cols(columns)
- return columns
+ @t.overload
+ def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
+ ...
+
+ @t.overload
+ def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
+ ...
+
+ def _ensure_list_of_columns(self, cols):
+ return Column.ensure_cols(ensure_list(cols))
def _ensure_and_normalize_cols(self, cols):
cols = self._ensure_list_of_columns(cols)
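
Note: the `@t.overload` pair above lets type checkers accept either a single `ColumnOrLiteral` or a collection of them while the runtime body stays a one-liner. A minimal self-contained sketch of the same pattern (names are illustrative, not from sqlglot):

    import typing as t

    T = t.TypeVar("T")

    @t.overload
    def as_list(value: t.List[T]) -> t.List[T]:
        ...

    @t.overload
    def as_list(value: T) -> t.List[T]:
        ...

    def as_list(value):
        # Single untyped runtime body; the stubs above only guide the checker.
        return list(value) if isinstance(value, list) else [value]

    assert as_list(1) == [1]
    assert as_list([1, 2]) == [1, 2]
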
@@ -153,10 +164,16 @@ class DataFrame:
df = self._resolve_pending_hints()
sequence_id = sequence_id or df.sequence_id
expression = df.expression.copy()
- cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
- new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+ cte_expression, cte_name = df._create_cte_from_expression(
+ expression=expression, sequence_id=sequence_id
+ )
+ new_expression = df._add_ctes_to_expression(
+ exp.Select(), expression.ctes + [cte_expression]
+ )
sel_columns = df._get_outer_select_columns(cte_expression)
- new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+ new_expression = new_expression.from_(cte_name).select(
+ *[x.alias_or_name for x in sel_columns]
+ )
return df.copy(expression=new_expression, sequence_id=sequence_id)
def _resolve_pending_hints(self) -> DataFrame:
@@ -169,16 +186,23 @@ class DataFrame:
hint_expression.args.get("expressions").append(hint)
df.pending_hints.remove(hint)
- join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+ join_aliases = {
+ join_table.alias_or_name
+ for join_table in get_tables_from_expression_with_join(expression)
+ }
if join_aliases:
for hint in df.pending_join_hints:
for sequence_id_expression in hint.expressions:
sequence_id_or_name = sequence_id_expression.alias_or_name
sequence_ids_to_match = [sequence_id_or_name]
if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
- sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+ sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
+ sequence_id_or_name
+ ]
matching_ctes = [
- cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+ cte
+ for cte in reversed(expression.ctes)
+ if cte.args["sequence_id"] in sequence_ids_to_match
]
for matching_cte in matching_ctes:
if matching_cte.alias_or_name in join_aliases:
@@ -193,9 +217,14 @@ class DataFrame:
def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
hint_name = hint_name.upper()
hint_expression = (
- exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+ exp.JoinHint(
+ this=hint_name,
+ expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
+ )
if hint_name in JOIN_HINTS
- else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+ else exp.Anonymous(
+ this=hint_name, expressions=[parameter.expression for parameter in args]
+ )
)
new_df = self.copy()
new_df.pending_hints.append(hint_expression)
@@ -245,7 +274,9 @@ class DataFrame:
def _get_select_expressions(
self,
) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
- select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
+ select_expressions: t.List[
+ t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
+ ] = []
main_select_ctes: t.List[exp.CTE] = []
for cte in self.expression.ctes:
cache_storage_level = cte.args.get("cache_storage_level")
@@ -279,14 +310,19 @@ class DataFrame:
cache_table_name = df._create_hash_from_expression(select_expression)
cache_table = exp.to_table(cache_table_name)
original_alias_name = select_expression.args["cte_alias_name"]
- replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
+
+ replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore
+ cache_table_name
+ )
sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
exp.Literal.string("storageLevel"),
exp.Literal.string(cache_storage_level),
]
- expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
+ expression = exp.Cache(
+ this=cache_table, expression=select_expression, lazy=True, options=options
+ )
# We will drop the "view" if it exists before running the cache table
output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
elif expression_type == exp.Create:
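
Note: a minimal sketch of the DROP VIEW + CACHE pair this hunk constructs, built from raw sqlglot expressions. The table name and query are hypothetical stand-ins for the generated hash name and the cached CTE's select:

    import sqlglot
    from sqlglot import exp

    cache_table = exp.to_table("cache_123")  # stands in for _create_hash_from_expression(...)
    select_expression = sqlglot.parse_one("SELECT a FROM x", read="spark")
    options = [exp.Literal.string("storageLevel"), exp.Literal.string("MEMORY_AND_DISK")]

    drop = exp.Drop(this=cache_table, exists=True, kind="VIEW")
    cache = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)

    print(drop.sql(dialect="spark"))   # roughly: DROP VIEW IF EXISTS cache_123
    print(cache.sql(dialect="spark"))  # roughly: CACHE LAZY TABLE cache_123 ... AS SELECT a FROM x
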
@@ -305,7 +341,9 @@ class DataFrame:
raise ValueError(f"Invalid expression type: {expression_type}")
output_expressions.append(expression)
- return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+ return [
+ expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
+ ]
def copy(self, **kwargs) -> DataFrame:
return DataFrame(**object_to_dict(self, **kwargs))
@@ -317,7 +355,9 @@ class DataFrame:
if self.expression.args.get("joins"):
ambiguous_cols = [col for col in cols if not col.column_expression.table]
if ambiguous_cols:
- join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
+ join_table_identifiers = [
+ x.this for x in get_tables_from_expression_with_join(self.expression)
+ ]
cte_names_in_join = [x.this for x in join_table_identifiers]
for ambiguous_col in ambiguous_cols:
ctes_with_column = [
@@ -367,14 +407,20 @@ class DataFrame:
@operation(Operation.FROM)
def join(
- self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
+ self,
+ other_df: DataFrame,
+ on: t.Union[str, t.List[str], Column, t.List[Column]],
+ how: str = "inner",
+ **kwargs,
) -> DataFrame:
other_df = other_df._convert_leaf_to_cte()
pre_join_self_latest_cte_name = self.latest_cte_name
columns = self._ensure_and_normalize_cols(on)
join_type = how.replace("_", " ")
if isinstance(columns[0].expression, exp.Column):
- join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
+ join_columns = [
+ Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
+ ]
join_clause = functools.reduce(
lambda x, y: x & y,
[
@@ -402,7 +448,9 @@ class DataFrame:
for column in self._get_outer_select_columns(other_df)
]
column_value_mapping = {
- column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
+ column.alias_or_name
+ if not isinstance(column.expression.this, exp.Star)
+ else column.sql(): column
for column in other_columns + self_columns + join_columns
}
all_columns = [
@@ -410,16 +458,22 @@ class DataFrame:
for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
]
new_df = self.copy(
- expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
+ expression=self.expression.join(
+ other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
+ )
+ )
+ new_df.expression = new_df._add_ctes_to_expression(
+ new_df.expression, other_df.expression.ctes
)
- new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
new_df.pending_hints.extend(other_df.pending_hints)
new_df = new_df.select.__wrapped__(new_df, *all_columns)
return new_df
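
Note: a hedged usage sketch of the reformatted `join` (schema and table names are hypothetical; the tables are registered with `sqlglot.schema.add_table` first so column resolution has something to qualify against):

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"id": "int", "store_id": "int"})
    sqlglot.schema.add_table("store", {"store_id": "int", "name": "string"})

    spark = SparkSession()
    df = spark.table("employee").join(spark.table("store"), on="store_id", how="left_outer")
    print(df.sql(dialect="spark"))  # .sql() returns a list with one statement per output expression
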
@operation(Operation.ORDER_BY)
def orderBy(
- self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
+ self,
+ *cols: t.Union[str, Column],
+ ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
) -> DataFrame:
"""
This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
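
Note: the docstring's rule in practice — an explicitly ordered column keeps its direction and only the remaining columns follow `ascending`. A hedged sketch (hypothetical table):

    import sqlglot
    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("t", {"a": "int", "b": "int"})
    df = SparkSession().table("t")
    # "a" stays DESC despite ascending=[1, 1]; "b" follows the flag.
    print(df.orderBy(F.col("a").desc(), "b", ascending=[1, 1]).sql(dialect="spark"))
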
@@ -429,7 +483,10 @@ class DataFrame:
columns = self._ensure_and_normalize_cols(cols)
pre_ordered_col_indexes = [
x
- for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
+ for x in [
+ i if isinstance(col.expression, exp.Ordered) else None
+ for i, col in enumerate(columns)
+ ]
if x is not None
]
if ascending is None:
@@ -478,7 +535,9 @@ class DataFrame:
for r_column in r_columns_unused:
l_expressions.append(exp.alias_(exp.Null(), r_column))
r_expressions.append(r_column)
- r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+ r_df = (
+ other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+ )
l_df = self.copy()
if allowMissingColumns:
l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
@@ -536,7 +595,9 @@ class DataFrame:
f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
)
- if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
+ if_null_checks = [
+ F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
+ ]
nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
num_nulls = nulls_added_together.alias("num_nulls")
new_df = new_df.select(num_nulls, append=True)
@@ -576,11 +637,15 @@ class DataFrame:
value_columns = [lit(value) for value in values]
null_replacement_mapping = {
- column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
+ column.alias_or_name: (
+ F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
+ )
for column, value in zip(columns, value_columns)
}
null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
- null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
+ null_replacement_columns = [
+ null_replacement_mapping[column.alias_or_name] for column in all_columns
+ ]
new_df = new_df.select(*null_replacement_columns)
return new_df
@@ -589,12 +654,11 @@ class DataFrame:
self,
to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
- subset: t.Optional[t.Union[str, t.List[str]]] = None,
+ subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
) -> DataFrame:
from sqlglot.dataframe.sql.functions import lit
old_values = None
- subset = ensure_list(subset)
new_df = self.copy()
all_columns = self._get_outer_select_columns(new_df.expression)
all_column_mapping = {column.alias_or_name: column for column in all_columns}
@@ -605,7 +669,9 @@ class DataFrame:
new_values = list(to_replace.values())
elif not old_values and isinstance(to_replace, list):
assert isinstance(value, list), "value must be a list since the replacements are a list"
- assert len(to_replace) == len(value), "the replacements and values must be the same length"
+ assert len(to_replace) == len(
+ value
+ ), "the replacements and values must be the same length"
old_values = to_replace
new_values = value
else:
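
Note: a sketch of the replacement normalization the asserts above enforce, written as a hypothetical standalone helper:

    def normalize_replacements(to_replace, value=None):
        if isinstance(to_replace, dict):  # {old: new} pairs carry their own values
            return list(to_replace), list(to_replace.values())
        if isinstance(to_replace, list):  # positional pairing, lengths must match
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(value), "the replacements and values must be the same length"
            return to_replace, value
        return [to_replace], [value]  # scalar -> single pair

    assert normalize_replacements({1: 10}) == ([1], [10])
    assert normalize_replacements([1, 2], [10, 20]) == ([1, 2], [10, 20])
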
@@ -635,7 +701,9 @@ class DataFrame:
def withColumn(self, colName: str, col: Column) -> DataFrame:
col = self._ensure_and_normalize_col(col)
existing_col_names = self.expression.named_selects
- existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
+ existing_col_index = (
+ existing_col_names.index(colName) if colName in existing_col_names else None
+ )
if existing_col_index:
expression = self.expression.copy()
expression.expressions[existing_col_index] = col.expression
@@ -645,7 +713,11 @@ class DataFrame:
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
expression = self.expression.copy()
- existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
+ existing_columns = [
+ expression
+ for expression in expression.expressions
+ if expression.alias_or_name == existing
+ ]
if not existing_columns:
raise ValueError("Tried to rename a column that doesn't exist")
for existing_column in existing_columns:
@@ -674,15 +746,19 @@ class DataFrame:
def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
parameter_list = ensure_list(parameters)
parameter_columns = (
- self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
+ self._ensure_list_of_columns(parameter_list)
+ if parameters
+ else Column.ensure_cols([self.sequence_id])
)
return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
- def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
- num_partitions = Column.ensure_cols(ensure_list(numPartitions))
+ def repartition(
+ self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
+ ) -> DataFrame:
+ num_partition_cols = self._ensure_list_of_columns(numPartitions)
columns = self._ensure_and_normalize_cols(cols)
- args = num_partitions + columns
+ args = num_partition_cols + columns
return self._hint("repartition", args)
@operation(Operation.NO_OP)
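
Note: end-to-end, the updated `repartition` accepts an int or a column for `numPartitions`; both are normalized through `_ensure_list_of_columns` and folded into the hint arguments. A hedged sketch (hypothetical table):

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("events", {"user_id": "int", "ts": "string"})
    df = SparkSession().table("events").repartition(8, "user_id").select("user_id")
    for statement in df.sql(dialect="spark"):  # one string per output expression
        print(statement)
    # Expect something like /*+ REPARTITION(8, user_id) */ on the generated SELECT.
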