diff options
Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r-- | sqlglot/dataframe/sql/column.py | 8 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/dataframe.py | 22 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 2 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/normalize.py | 4 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/session.py | 8 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/window.py | 14 |
6 files changed, 33 insertions, 25 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index ca85376..724c5bf 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -144,9 +144,11 @@ class Column: ) -> Column: ensured_column = None if column is None else cls.ensure_col(column) ensure_expression_values = { - k: [Column.ensure_col(x).expression for x in v] - if is_iterable(v) - else Column.ensure_col(v).expression + k: ( + [Column.ensure_col(x).expression for x in v] + if is_iterable(v) + else Column.ensure_col(v).expression + ) for k, v in kwargs.items() if v is not None } diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 68d36fe..0bacbf9 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -140,12 +140,10 @@ class DataFrame: return cte, name @t.overload - def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: - ... + def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ... @t.overload - def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: - ... + def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ... def _ensure_list_of_columns(self, cols): return Column.ensure_cols(ensure_list(cols)) @@ -496,9 +494,11 @@ class DataFrame: join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs] # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list select_column_names = [ - column.alias_or_name - if not isinstance(column.expression.this, exp.Star) - else column.sql() + ( + column.alias_or_name + if not isinstance(column.expression.this, exp.Star) + else column.sql() + ) for column in self_columns + other_columns ] select_column_names = [ @@ -552,9 +552,11 @@ class DataFrame: ), "The length of items in ascending must equal the number of columns provided" col_and_ascending = list(zip(columns, ascending)) order_by_columns = [ - exp.Ordered(this=col.expression, desc=not asc) - if i not in pre_ordered_col_indexes - else columns[i].column_expression + ( + exp.Ordered(this=col.expression, desc=not asc) + if i not in pre_ordered_col_indexes + else columns[i].column_expression + ) for i, (col, asc) in enumerate(col_and_ascending) ] return self.copy(expression=self.expression.order_by(*order_by_columns)) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 141a302..a388cb4 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -661,7 +661,7 @@ def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: tz_column = tz if isinstance(tz, Column) else lit(tz) - return Column.invoke_anonymous_function(timestamp, "TO_UTC_TIMESTAMP", tz_column) + return Column.invoke_expression_over_column(timestamp, expression.FromTimeZone, zone=tz_column) def timestamp_seconds(col: ColumnOrName) -> Column: diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py index f68bacb..b246641 100644 --- a/sqlglot/dataframe/sql/normalize.py +++ b/sqlglot/dataframe/sql/normalize.py @@ -7,11 +7,11 @@ from sqlglot.dataframe.sql.column import Column from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join from sqlglot.helper import ensure_list -NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column]) - if t.TYPE_CHECKING: from sqlglot.dataframe.sql.session import SparkSession + NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column]) + def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]): expr = ensure_list(expr) diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index 4a33ef9..f518ac2 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -82,9 +82,11 @@ class SparkSession: ] sel_columns = [ - F.col(name).cast(data_type).alias(name).expression - if data_type is not None - else F.col(name).expression + ( + F.col(name).cast(data_type).alias(name).expression + if data_type is not None + else F.col(name).expression + ) for name, data_type in column_mapping.items() ] diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py index c1d913f..9e2fabd 100644 --- a/sqlglot/dataframe/sql/window.py +++ b/sqlglot/dataframe/sql/window.py @@ -90,9 +90,11 @@ class WindowSpec: **kwargs, **{ "start_side": "PRECEDING", - "start": "UNBOUNDED" - if start <= Window.unboundedPreceding - else F.lit(start).expression, + "start": ( + "UNBOUNDED" + if start <= Window.unboundedPreceding + else F.lit(start).expression + ), }, } if end == Window.currentRow: @@ -102,9 +104,9 @@ class WindowSpec: **kwargs, **{ "end_side": "FOLLOWING", - "end": "UNBOUNDED" - if end >= Window.unboundedFollowing - else F.lit(end).expression, + "end": ( + "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression + ), }, } return kwargs |