Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r--  sqlglot/dataframe/sql/_typing.pyi    14
-rw-r--r--  sqlglot/dataframe/sql/column.py      46
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py  158
-rw-r--r--  sqlglot/dataframe/sql/functions.py  100
-rw-r--r--  sqlglot/dataframe/sql/group.py       10
-rw-r--r--  sqlglot/dataframe/sql/normalize.py   13
-rw-r--r--  sqlglot/dataframe/sql/readwriter.py  16
-rw-r--r--  sqlglot/dataframe/sql/session.py     17
-rw-r--r--  sqlglot/dataframe/sql/types.py        6
-rw-r--r--  sqlglot/dataframe/sql/window.py      27
10 files changed, 305 insertions(+), 102 deletions(-)
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi
index f1a03ea..67c8c09 100644
--- a/sqlglot/dataframe/sql/_typing.pyi
+++ b/sqlglot/dataframe/sql/_typing.pyi
@@ -10,11 +10,17 @@ if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql.types import StructType
 
 ColumnLiterals = t.TypeVar(
-    "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+    "ColumnLiterals",
+    bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
 )
 ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
 ColumnOrLiteral = t.TypeVar(
-    "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+    "ColumnOrLiteral",
+    bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
+)
+SchemaInput = t.TypeVar(
+    "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
+)
+OutputExpressionContainer = t.TypeVar(
+    "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
 )
-SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
-OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index e66aaa8..f9e1c5b 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -18,7 +18,11 @@ class Column:
             expression = expression.expression  # type: ignore
         elif expression is None or not isinstance(expression, (str, exp.Expression)):
             expression = self._lit(expression).expression  # type: ignore
-        self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
+
+        expression = sqlglot.maybe_parse(expression, dialect="spark")
+        if expression is None:
+            raise ValueError(f"Could not parse {expression}")
+        self.expression: exp.Expression = expression
 
     def __repr__(self):
         return repr(self.expression)
@@ -135,21 +139,29 @@ class Column:
     ) -> Column:
         ensured_column = None if column is None else cls.ensure_col(column)
         ensure_expression_values = {
-            k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression
+            k: [Column.ensure_col(x).expression for x in v]
+            if is_iterable(v)
+            else Column.ensure_col(v).expression
             for k, v in kwargs.items()
         }
         new_expression = (
             callable_expression(**ensure_expression_values)
             if ensured_column is None
-            else callable_expression(this=ensured_column.column_expression, **ensure_expression_values)
+            else callable_expression(
+                this=ensured_column.column_expression, **ensure_expression_values
+            )
         )
         return Column(new_expression)
 
     def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
-        return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
+        return Column(
+            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
+        )
 
     def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
-        return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
+        return Column(
+            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
+        )
 
     def unary_op(self, klass: t.Callable, **kwargs) -> Column:
         return Column(klass(this=self.column_expression, **kwargs))
@@ -188,7 +200,7 @@ class Column:
         expression.set("table", exp.to_identifier(table_name))
         return Column(expression)
 
-    def sql(self, **kwargs) -> Column:
+    def sql(self, **kwargs) -> str:
         return self.expression.sql(**{"dialect": "spark", **kwargs})
 
     def alias(self, name: str) -> Column:
@@ -265,10 +277,14 @@ class Column:
         )
 
     def like(self, other: str):
-        return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
+        return self.invoke_expression_over_column(
+            self, exp.Like, expression=self._lit(other).expression
+        )
 
     def ilike(self, other: str):
-        return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
+        return self.invoke_expression_over_column(
+            self, exp.ILike, expression=self._lit(other).expression
+        )
 
     def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
         startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
@@ -287,10 +303,18 @@ class Column:
         lowerBound: t.Union[ColumnOrLiteral],
         upperBound: t.Union[ColumnOrLiteral],
     ) -> Column:
-        lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
-        upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+        lower_bound_exp = (
+            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
+        )
+        upper_bound_exp = (
+            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+        )
         return Column(
-            exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
+            exp.Between(
+                this=self.column_expression,
+                low=lower_bound_exp.expression,
+                high=upper_bound_exp.expression,
+            )
         )
 
     def over(self, window: WindowSpec) -> Column:
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 322dcf2..40cd6c9 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -21,7 +21,12 @@ from sqlglot.optimizer import optimize as optimize_func
 from sqlglot.optimizer.qualify_columns import qualify_columns
 
 if t.TYPE_CHECKING:
-    from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+    from sqlglot.dataframe.sql._typing import (
+        ColumnLiterals,
+        ColumnOrLiteral,
+        ColumnOrName,
+        OutputExpressionContainer,
+    )
     from sqlglot.dataframe.sql.session import SparkSession
@@ -83,7 +88,9 @@ class DataFrame:
                 return from_exp.alias_or_name
             table_alias = from_exp.find(exp.TableAlias)
             if not table_alias:
-                raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+                raise RuntimeError(
+                    f"Could not find an alias name for this expression: {self.expression}"
+                )
             return table_alias.alias_or_name
         return self.expression.ctes[-1].alias
@@ -132,12 +139,16 @@ class DataFrame:
         cte.set("sequence_id", sequence_id or self.sequence_id)
         return cte, name
 
-    def _ensure_list_of_columns(
-        self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
-    ) -> t.List[Column]:
-        columns = ensure_list(cols)
-        columns = Column.ensure_cols(columns)
-        return columns
+    @t.overload
+    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
+        ...
+
+    @t.overload
+    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
+        ...
+
+    def _ensure_list_of_columns(self, cols):
+        return Column.ensure_cols(ensure_list(cols))
 
     def _ensure_and_normalize_cols(self, cols):
         cols = self._ensure_list_of_columns(cols)
@@ -153,10 +164,16 @@ class DataFrame:
         df = self._resolve_pending_hints()
         sequence_id = sequence_id or df.sequence_id
         expression = df.expression.copy()
-        cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
-        new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+        cte_expression, cte_name = df._create_cte_from_expression(
+            expression=expression, sequence_id=sequence_id
+        )
+        new_expression = df._add_ctes_to_expression(
+            exp.Select(), expression.ctes + [cte_expression]
+        )
         sel_columns = df._get_outer_select_columns(cte_expression)
-        new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+        new_expression = new_expression.from_(cte_name).select(
+            *[x.alias_or_name for x in sel_columns]
+        )
         return df.copy(expression=new_expression, sequence_id=sequence_id)
 
     def _resolve_pending_hints(self) -> DataFrame:
@@ -169,16 +186,23 @@ class DataFrame:
                 hint_expression.args.get("expressions").append(hint)
                 df.pending_hints.remove(hint)
 
-        join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+        join_aliases = {
+            join_table.alias_or_name
+            for join_table in get_tables_from_expression_with_join(expression)
+        }
         if join_aliases:
             for hint in df.pending_join_hints:
                 for sequence_id_expression in hint.expressions:
                     sequence_id_or_name = sequence_id_expression.alias_or_name
                     sequence_ids_to_match = [sequence_id_or_name]
                     if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
-                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
+                            sequence_id_or_name
+                        ]
                     matching_ctes = [
-                        cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+                        cte
+                        for cte in reversed(expression.ctes)
+                        if cte.args["sequence_id"] in sequence_ids_to_match
                     ]
                     for matching_cte in matching_ctes:
                         if matching_cte.alias_or_name in join_aliases:
@@ -193,9 +217,14 @@ class DataFrame:
     def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
         hint_name = hint_name.upper()
         hint_expression = (
-            exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+            exp.JoinHint(
+                this=hint_name,
+                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
+            )
             if hint_name in JOIN_HINTS
-            else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+            else exp.Anonymous(
+                this=hint_name, expressions=[parameter.expression for parameter in args]
+            )
         )
         new_df = self.copy()
         new_df.pending_hints.append(hint_expression)
@@ -245,7 +274,9 @@ class DataFrame:
     def _get_select_expressions(
         self,
     ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
-        select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
+        select_expressions: t.List[
+            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
+        ] = []
         main_select_ctes: t.List[exp.CTE] = []
         for cte in self.expression.ctes:
             cache_storage_level = cte.args.get("cache_storage_level")
@@ -279,14 +310,19 @@ class DataFrame:
                 cache_table_name = df._create_hash_from_expression(select_expression)
                 cache_table = exp.to_table(cache_table_name)
                 original_alias_name = select_expression.args["cte_alias_name"]
-                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
+
+                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
+                    cache_table_name
+                )
                 sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
                 cache_storage_level = select_expression.args["cache_storage_level"]
                 options = [
                     exp.Literal.string("storageLevel"),
                     exp.Literal.string(cache_storage_level),
                 ]
-                expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
+                expression = exp.Cache(
+                    this=cache_table, expression=select_expression, lazy=True, options=options
+                )
                 # We will drop the "view" if it exists before running the cache table
                 output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
             elif expression_type == exp.Create:
@@ -305,7 +341,9 @@ class DataFrame:
                 raise ValueError(f"Invalid expression type: {expression_type}")
             output_expressions.append(expression)
 
-        return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+        return [
+            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
+        ]
 
     def copy(self, **kwargs) -> DataFrame:
         return DataFrame(**object_to_dict(self, **kwargs))
@@ -317,7 +355,9 @@ class DataFrame:
         if self.expression.args.get("joins"):
             ambiguous_cols = [col for col in cols if not col.column_expression.table]
             if ambiguous_cols:
-                join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
+                join_table_identifiers = [
+                    x.this for x in get_tables_from_expression_with_join(self.expression)
+                ]
                 cte_names_in_join = [x.this for x in join_table_identifiers]
                 for ambiguous_col in ambiguous_cols:
                     ctes_with_column = [
@@ -367,14 +407,20 @@ class DataFrame:
     @operation(Operation.FROM)
     def join(
-        self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
+        self,
+        other_df: DataFrame,
+        on: t.Union[str, t.List[str], Column, t.List[Column]],
+        how: str = "inner",
+        **kwargs,
     ) -> DataFrame:
         other_df = other_df._convert_leaf_to_cte()
         pre_join_self_latest_cte_name = self.latest_cte_name
         columns = self._ensure_and_normalize_cols(on)
         join_type = how.replace("_", " ")
         if isinstance(columns[0].expression, exp.Column):
-            join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
+            join_columns = [
+                Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
+            ]
             join_clause = functools.reduce(
                 lambda x, y: x & y,
                 [
@@ -402,7 +448,9 @@ class DataFrame:
             for column in self._get_outer_select_columns(other_df)
         ]
         column_value_mapping = {
-            column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
+            column.alias_or_name
+            if not isinstance(column.expression.this, exp.Star)
+            else column.sql(): column
             for column in other_columns + self_columns + join_columns
         }
         all_columns = [
@@ -410,16 +458,22 @@ class DataFrame:
             for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
         ]
         new_df = self.copy(
-            expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
+            expression=self.expression.join(
+                other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
+            )
+        )
+        new_df.expression = new_df._add_ctes_to_expression(
+            new_df.expression, other_df.expression.ctes
         )
-        new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
         new_df.pending_hints.extend(other_df.pending_hints)
         new_df = new_df.select.__wrapped__(new_df, *all_columns)
         return new_df
 
     @operation(Operation.ORDER_BY)
     def orderBy(
-        self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
+        self,
+        *cols: t.Union[str, Column],
+        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
     ) -> DataFrame:
         """
         This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
@@ -429,7 +483,10 @@ class DataFrame:
         columns = self._ensure_and_normalize_cols(cols)
         pre_ordered_col_indexes = [
             x
-            for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
+            for x in [
+                i if isinstance(col.expression, exp.Ordered) else None
+                for i, col in enumerate(columns)
+            ]
             if x is not None
         ]
         if ascending is None:
@@ -478,7 +535,9 @@ class DataFrame:
         for r_column in r_columns_unused:
             l_expressions.append(exp.alias_(exp.Null(), r_column))
             r_expressions.append(r_column)
-        r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+        r_df = (
+            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+        )
         l_df = self.copy()
         if allowMissingColumns:
             l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
@@ -536,7 +595,9 @@ class DataFrame:
                 f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                 f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
             )
-        if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
+        if_null_checks = [
+            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
+        ]
         nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
         num_nulls = nulls_added_together.alias("num_nulls")
         new_df = new_df.select(num_nulls, append=True)
@@ -576,11 +637,15 @@ class DataFrame:
         value_columns = [lit(value) for value in values]
 
         null_replacement_mapping = {
-            column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
+            column.alias_or_name: (
+                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
+            )
             for column, value in zip(columns, value_columns)
         }
         null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
-        null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
+        null_replacement_columns = [
+            null_replacement_mapping[column.alias_or_name] for column in all_columns
+        ]
         new_df = new_df.select(*null_replacement_columns)
         return new_df
@@ -589,12 +654,11 @@ class DataFrame:
         self,
         to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
         value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
-        subset: t.Optional[t.Union[str, t.List[str]]] = None,
+        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
     ) -> DataFrame:
         from sqlglot.dataframe.sql.functions import lit
 
         old_values = None
-        subset = ensure_list(subset)
         new_df = self.copy()
         all_columns = self._get_outer_select_columns(new_df.expression)
         all_column_mapping = {column.alias_or_name: column for column in all_columns}
@@ -605,7 +669,9 @@ class DataFrame:
             new_values = list(to_replace.values())
         elif not old_values and isinstance(to_replace, list):
             assert isinstance(value, list), "value must be a list since the replacements are a list"
-            assert len(to_replace) == len(value), "the replacements and values must be the same length"
+            assert len(to_replace) == len(
+                value
+            ), "the replacements and values must be the same length"
             old_values = to_replace
             new_values = value
         else:
@@ -635,7 +701,9 @@ class DataFrame:
     def withColumn(self, colName: str, col: Column) -> DataFrame:
         col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
-        existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
+        existing_col_index = (
+            existing_col_names.index(colName) if colName in existing_col_names else None
+        )
         if existing_col_index:
             expression = self.expression.copy()
             expression.expressions[existing_col_index] = col.expression
@@ -645,7 +713,11 @@ class DataFrame:
     @operation(Operation.SELECT)
     def withColumnRenamed(self, existing: str, new: str):
         expression = self.expression.copy()
-        existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
+        existing_columns = [
+            expression
+            for expression in expression.expressions
+            if expression.alias_or_name == existing
+        ]
         if not existing_columns:
             raise ValueError("Tried to rename a column that doesn't exist")
         for existing_column in existing_columns:
@@ -674,15 +746,19 @@ class DataFrame:
     def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
         parameter_list = ensure_list(parameters)
         parameter_columns = (
-            self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
+            self._ensure_list_of_columns(parameter_list)
+            if parameters
+            else Column.ensure_cols([self.sequence_id])
         )
         return self._hint(name, parameter_columns)
 
     @operation(Operation.NO_OP)
-    def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
-        num_partitions = Column.ensure_cols(ensure_list(numPartitions))
+    def repartition(
+        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
+    ) -> DataFrame:
+        num_partition_cols = self._ensure_list_of_columns(numPartitions)
         columns = self._ensure_and_normalize_cols(cols)
-        args = num_partitions + columns
+        args = num_partition_cols + columns
         return self._hint("repartition", args)
 
     @operation(Operation.NO_OP)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index bc002e5..dbfb06f 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -45,7 +45,11 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
 
 def when(condition: Column, value: t.Any) -> Column:
     true_value = value if isinstance(value, Column) else lit(value)
-    return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]))
+    return Column(
+        glotexp.Case(
+            ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]
+        )
+    )
 
 
 def asc(col: ColumnOrName) -> Column:
@@ -407,7 +411,9 @@ def percentile_approx(
         return Column.invoke_expression_over_column(
             col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
         )
-    return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage))
+    return Column.invoke_expression_over_column(
+        col, glotexp.ApproxQuantile, quantile=lit(percentage)
+    )
 
 
 def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
@@ -471,7 +477,9 @@ def factorial(col: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(col, "FACTORIAL")
 
 
-def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column:
+def lag(
+    col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
+) -> Column:
     if default is not None:
         return Column.invoke_anonymous_function(col, "LAG", offset, default)
     if offset != 1:
@@ -479,7 +487,9 @@ def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[Colu
     return Column.invoke_anonymous_function(col, "LAG")
 
 
-def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
+def lead(
+    col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
+) -> Column:
     if default is not None:
         return Column.invoke_anonymous_function(col, "LEAD", offset, default)
     if offset != 1:
@@ -487,7 +497,9 @@ def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.A
     return Column.invoke_anonymous_function(col, "LEAD")
 
 
-def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
+def nth_value(
+    col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
+) -> Column:
     if ignoreNulls is not None:
         raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
     if offset != 1:
@@ -571,7 +583,9 @@ def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Colum
     return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)
 
 
-def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
+def months_between(
+    date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
+) -> Column:
     if roundOff is None:
         return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
     return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)
@@ -611,9 +625,13 @@ def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
     return Column.invoke_expression_over_column(col, glotexp.UnixToStr)
 
 
-def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
+def unix_timestamp(
+    timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
+) -> Column:
     if format is not None:
-        return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format))
+        return Column.invoke_expression_over_column(
+            timestamp, glotexp.StrToUnix, format=lit(format)
+        )
     return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
@@ -642,7 +660,9 @@ def window(
             timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime)
         )
     if slideDuration is not None:
-        return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration))
+        return Column.invoke_anonymous_function(
+            timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)
+        )
     if startTime is not None:
         return Column.invoke_anonymous_function(
             timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime)
@@ -731,7 +751,9 @@ def trim(col: ColumnOrName) -> Column:
 
 
 def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols))
+    return Column.invoke_expression_over_column(
+        None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)
+    )
 
 
 def decode(col: ColumnOrName, charset: str) -> Column:
@@ -768,7 +790,9 @@ def overlay(
 
 
 def sentences(
-    string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None
+    string: ColumnOrName,
+    language: t.Optional[ColumnOrName] = None,
+    country: t.Optional[ColumnOrName] = None,
 ) -> Column:
     if language is not None and country is not None:
         return Column.invoke_anonymous_function(string, "SENTENCES", language, country)
@@ -794,7 +818,9 @@ def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
 def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
     substr_col = lit(substr)
     if pos is not None:
-        return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos)
+        return Column.invoke_expression_over_column(
+            str, glotexp.StrPosition, substr=substr_col, position=pos
+        )
     return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
@@ -872,7 +898,10 @@ def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
 
 def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
     cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols  # type: ignore
     return Column.invoke_expression_over_column(
-        None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression
+        None,
+        glotexp.VarMap,
+        keys=array(*cols[::2]).expression,
+        values=array(*cols[1::2]).expression,
     )
@@ -882,29 +911,39 @@ def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
 
 def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
     value_col = value if isinstance(value, Column) else lit(value)
-    return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression)
+    return Column.invoke_expression_over_column(
+        col, glotexp.ArrayContains, expression=value_col.expression
+    )
 
 
 def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2))
 
 
-def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column:
+def slice(
+    x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
+) -> Column:
     start_col = start if isinstance(start, Column) else lit(start)
     length_col = length if isinstance(length, Column) else lit(length)
     return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)
 
 
-def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column:
+def array_join(
+    col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
+) -> Column:
     if null_replacement is not None:
-        return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement))
+        return Column.invoke_anonymous_function(
+            col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)
+        )
     return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))
 
 
 def concat(*cols: ColumnOrName) -> Column:
     if len(cols) == 1:
         return Column.invoke_anonymous_function(cols[0], "CONCAT")
-    return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]])
+    return Column.invoke_anonymous_function(
+        cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
+    )
 
 
 def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
@@ -1076,7 +1115,9 @@ def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
     return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])
 
 
-def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column:
+def sequence(
+    start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
+) -> Column:
     if step is not None:
         return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step)
     return Column.invoke_anonymous_function(start, "SEQUENCE", stop)
@@ -1103,12 +1144,15 @@ def aggregate(
     merge_exp = _get_lambda_from_func(merge)
     if finish is not None:
         finish_exp = _get_lambda_from_func(finish)
-        return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
+        return Column.invoke_anonymous_function(
+            col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
+        )
     return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
 
 
 def transform(
-    col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]
+    col: ColumnOrName,
+    f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
 ) -> Column:
     f_expression = _get_lambda_from_func(f)
     return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
@@ -1124,12 +1168,17 @@ def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
     return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
 
 
-def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column:
+def filter(
+    col: ColumnOrName,
+    f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
+) -> Column:
     f_expression = _get_lambda_from_func(f)
     return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
 
 
-def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
+def zip_with(
+    left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]
+) -> Column:
     f_expression = _get_lambda_from_func(f)
     return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
@@ -1163,7 +1212,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]:
 
 
 def _get_lambda_from_func(lambda_expression: t.Callable):
-    variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames]
+    variables = [
+        glotexp.to_identifier(x, quoted=_lambda_quoted(x))
+        for x in lambda_expression.__code__.co_varnames
+    ]
     return glotexp.Lambda(
         this=lambda_expression(*[Column(x) for x in variables]).expression,
         expressions=variables,
diff --git a/sqlglot/dataframe/sql/group.py b/sqlglot/dataframe/sql/group.py
index 947aace..ba27c17 100644
--- a/sqlglot/dataframe/sql/group.py
+++ b/sqlglot/dataframe/sql/group.py
@@ -17,7 +17,9 @@ class GroupedData:
         self.last_op = last_op
         self.group_by_cols = group_by_cols
 
-    def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
+    def _get_function_applied_columns(
+        self, func_name: str, cols: t.Tuple[str, ...]
+    ) -> t.List[Column]:
         func_name = func_name.lower()
         return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
@@ -30,9 +32,9 @@ class GroupedData:
         )
         cols = self._df._ensure_and_normalize_cols(columns)
 
-        expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
-            *[x.expression for x in self.group_by_cols + cols], append=False
-        )
+        expression = self._df.expression.group_by(
+            *[x.expression for x in self.group_by_cols]
+        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
         return self._df.copy(expression=expression)
 
     def count(self) -> DataFrame:
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
index 1513946..75feba7 100644
--- a/sqlglot/dataframe/sql/normalize.py
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -23,7 +23,9 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
         replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
 
 
-def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
+def replace_alias_name_with_cte_name(
+    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
+):
     if id.alias_or_name in spark.name_to_sequence_id_mapping:
         for cte in reversed(expression_context.ctes):
             if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
@@ -40,8 +42,12 @@ def replace_branch_and_sequence_ids_with_cte_name(
     # id then it keeps that reference. This handles the weird edge case in spark that shouldn't
     # be common in practice
     if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
-        join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
-        ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
+        join_table_aliases = [
+            x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
+        ]
+        ctes_in_join = [
+            cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
+        ]
         if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
             assert len(ctes_in_join) == 2
             _set_alias_name(id, ctes_in_join[0].alias_or_name)
@@ -58,7 +64,6 @@ def _set_alias_name(id: exp.Identifier, name: str):
 
 
 def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
-    values = ensure_list(values)
     results = []
     for value in values:
         if isinstance(value, str):
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index 4830035..febc664 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -19,12 +19,19 @@ class DataFrameReader:
         from sqlglot.dataframe.sql.dataframe import DataFrame
 
         sqlglot.schema.add_table(tableName)
-        return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
+        return DataFrame(
+            self.spark,
+            exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
+        )
 
 
 class DataFrameWriter:
     def __init__(
-        self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
+        self,
+        df: DataFrame,
+        spark: t.Optional[SparkSession] = None,
+        mode: t.Optional[str] = None,
+        by_name: bool = False,
     ):
         self._df = df
         self._spark = spark or df.spark
@@ -33,7 +40,10 @@ class DataFrameWriter:
     def copy(self, **kwargs) -> DataFrameWriter:
         return DataFrameWriter(
-            **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
+            **{
+                k[1:] if k.startswith("_") else k: v
+                for k, v in object_to_dict(self, **kwargs).items()
+            }
         )
 
     def sql(self, **kwargs) -> t.List[str]:
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index 1ea86d1..8cb16ef 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -67,13 +67,20 @@ class SparkSession:
         data_expressions = [
             exp.Tuple(
-                expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
+                expressions=list(
+                    map(
+                        lambda x: F.lit(x).expression,
+                        row if not isinstance(row, dict) else row.values(),
+                    )
+                )
             )
             for row in data
         ]
 
         sel_columns = [
-            F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
+            F.col(name).cast(data_type).alias(name).expression
+            if data_type is not None
+            else F.col(name).expression
             for name, data_type in column_mapping.items()
         ]
@@ -106,10 +113,12 @@ class SparkSession:
             select_expression.set("with", expression.args.get("with"))
             expression.set("with", None)
             del expression.args["expression"]
-            df = DataFrame(self, select_expression, output_expression_container=expression)
+            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
             df = df._convert_leaf_to_cte()
         else:
-            raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
+            raise ValueError(
+                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
+            )
         return df
 
     @property
diff --git a/sqlglot/dataframe/sql/types.py b/sqlglot/dataframe/sql/types.py
index dc5c05a..a63e505 100644
--- a/sqlglot/dataframe/sql/types.py
+++ b/sqlglot/dataframe/sql/types.py
@@ -158,7 +158,11 @@ class MapType(DataType):
 
 class StructField(DataType):
     def __init__(
-        self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
+        self,
+        name: str,
+        dataType: DataType,
+        nullable: bool = True,
+        metadata: t.Optional[t.Dict[str, t.Any]] = None,
     ):
         self.name = name
         self.dataType = dataType
diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py
index 842f366..c54c07e 100644
--- a/sqlglot/dataframe/sql/window.py
+++ b/sqlglot/dataframe/sql/window.py
@@ -74,8 +74,13 @@ class WindowSpec:
         window_spec.expression.args["order"].set("expressions", order_by)
         return window_spec
 
-    def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
-        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
+    def _calc_start_end(
+        self, start: int, end: int
+    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
+        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
+            "start_side": None,
+            "end_side": None,
+        }
         if start == Window.currentRow:
             kwargs["start"] = "CURRENT ROW"
         else:
@@ -83,7 +88,9 @@ class WindowSpec:
                 **kwargs,
                 **{
                     "start_side": "PRECEDING",
-                    "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
+                    "start": "UNBOUNDED"
+                    if start <= Window.unboundedPreceding
+                    else F.lit(start).expression,
                 },
             }
         if end == Window.currentRow:
@@ -93,7 +100,9 @@ class WindowSpec:
                 **kwargs,
                 **{
                     "end_side": "FOLLOWING",
-                    "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
+                    "end": "UNBOUNDED"
+                    if end >= Window.unboundedFollowing
+                    else F.lit(end).expression,
                 },
             }
         return kwargs
@@ -103,7 +112,10 @@ class WindowSpec:
         spec = self._calc_start_end(start, end)
         spec["kind"] = "ROWS"
         window_spec.expression.set(
-            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+            "spec",
+            exp.WindowSpec(
+                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
+            ),
         )
         return window_spec
@@ -112,6 +124,9 @@ class WindowSpec:
         spec = self._calc_start_end(start, end)
         spec["kind"] = "RANGE"
         window_spec.expression.set(
-            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+            "spec",
+            exp.WindowSpec(
+                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
+            ),
         )
         return window_spec
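
Editor's note on the `@t.overload` pattern that `dataframe.py` adopts above for `_ensure_list_of_columns`: the decorated stubs exist purely for type-checkers, giving callers a precise `t.List[Column]` return type for both scalar and collection inputs, while a single untyped implementation handles every case at runtime. A minimal self-contained sketch of the same idea (the `ensure_list` helper below is a hypothetical stand-in, not sqlglot's actual function):

import typing as t

@t.overload
def ensure_list(value: t.Collection[int]) -> t.List[int]: ...

@t.overload
def ensure_list(value: int) -> t.List[int]: ...

def ensure_list(value):
    # Runtime implementation: wrap a scalar in a list, pass collections through.
    return list(value) if isinstance(value, (list, tuple, set)) else [value]

assert ensure_list(1) == [1]          # scalar is wrapped
assert ensure_list([1, 2]) == [1, 2]  # collection is preserved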