summary refs log tree commit diff stats
path: root/sqlglot/dataframe
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r--sqlglot/dataframe/sql/column.py8
-rw-r--r--sqlglot/dataframe/sql/dataframe.py22
-rw-r--r--sqlglot/dataframe/sql/functions.py2
-rw-r--r--sqlglot/dataframe/sql/normalize.py4
-rw-r--r--sqlglot/dataframe/sql/session.py8
-rw-r--r--sqlglot/dataframe/sql/window.py14
6 files changed, 33 insertions, 25 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index ca85376..724c5bf 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -144,9 +144,11 @@ class Column:
) -> Column:
ensured_column = None if column is None else cls.ensure_col(column)
ensure_expression_values = {
- k: [Column.ensure_col(x).expression for x in v]
- if is_iterable(v)
- else Column.ensure_col(v).expression
+ k: (
+ [Column.ensure_col(x).expression for x in v]
+ if is_iterable(v)
+ else Column.ensure_col(v).expression
+ )
for k, v in kwargs.items()
if v is not None
}
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 68d36fe..0bacbf9 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -140,12 +140,10 @@ class DataFrame:
return cte, name
@t.overload
- def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
- ...
+ def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...
@t.overload
- def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
- ...
+ def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...
def _ensure_list_of_columns(self, cols):
return Column.ensure_cols(ensure_list(cols))
@@ -496,9 +494,11 @@ class DataFrame:
join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
# To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list
select_column_names = [
- column.alias_or_name
- if not isinstance(column.expression.this, exp.Star)
- else column.sql()
+ (
+ column.alias_or_name
+ if not isinstance(column.expression.this, exp.Star)
+ else column.sql()
+ )
for column in self_columns + other_columns
]
select_column_names = [
@@ -552,9 +552,11 @@ class DataFrame:
), "The length of items in ascending must equal the number of columns provided"
col_and_ascending = list(zip(columns, ascending))
order_by_columns = [
- exp.Ordered(this=col.expression, desc=not asc)
- if i not in pre_ordered_col_indexes
- else columns[i].column_expression
+ (
+ exp.Ordered(this=col.expression, desc=not asc)
+ if i not in pre_ordered_col_indexes
+ else columns[i].column_expression
+ )
for i, (col, asc) in enumerate(col_and_ascending)
]
return self.copy(expression=self.expression.order_by(*order_by_columns))
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 141a302..a388cb4 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -661,7 +661,7 @@ def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
tz_column = tz if isinstance(tz, Column) else lit(tz)
- return Column.invoke_anonymous_function(timestamp, "TO_UTC_TIMESTAMP", tz_column)
+ return Column.invoke_expression_over_column(timestamp, expression.FromTimeZone, zone=tz_column)
def timestamp_seconds(col: ColumnOrName) -> Column:
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
index f68bacb..b246641 100644
--- a/sqlglot/dataframe/sql/normalize.py
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -7,11 +7,11 @@ from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.helper import ensure_list
-NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
-
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.session import SparkSession
+ NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
+
def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
expr = ensure_list(expr)
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index 4a33ef9..f518ac2 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -82,9 +82,11 @@ class SparkSession:
]
sel_columns = [
- F.col(name).cast(data_type).alias(name).expression
- if data_type is not None
- else F.col(name).expression
+ (
+ F.col(name).cast(data_type).alias(name).expression
+ if data_type is not None
+ else F.col(name).expression
+ )
for name, data_type in column_mapping.items()
]
diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py
index c1d913f..9e2fabd 100644
--- a/sqlglot/dataframe/sql/window.py
+++ b/sqlglot/dataframe/sql/window.py
@@ -90,9 +90,11 @@ class WindowSpec:
**kwargs,
**{
"start_side": "PRECEDING",
- "start": "UNBOUNDED"
- if start <= Window.unboundedPreceding
- else F.lit(start).expression,
+ "start": (
+ "UNBOUNDED"
+ if start <= Window.unboundedPreceding
+ else F.lit(start).expression
+ ),
},
}
if end == Window.currentRow:
@@ -102,9 +104,9 @@ class WindowSpec:
**kwargs,
**{
"end_side": "FOLLOWING",
- "end": "UNBOUNDED"
- if end >= Window.unboundedFollowing
- else F.lit(end).expression,
+ "end": (
+ "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression
+ ),
},
}
return kwargs