summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-02-08 05:38:42 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-02-08 05:38:42 +0000
commitc66e4a33e1a07c439f03fe47f146a6c6482bf6df (patch)
treecfdf01111c063b3e50841695e6c2768833aea4dc /sqlglot/dataframe/sql
parentReleasing debian version 20.11.0-1. (diff)
downloadsqlglot-c66e4a33e1a07c439f03fe47f146a6c6482bf6df.tar.xz
sqlglot-c66e4a33e1a07c439f03fe47f146a6c6482bf6df.zip
Merging upstream version 21.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dataframe/sql')
-rw-r--r--sqlglot/dataframe/sql/dataframe.py6
-rw-r--r--sqlglot/dataframe/sql/functions.py35
2 files changed, 23 insertions, 18 deletions
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 0bacbf9..7e3f07b 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -140,10 +140,12 @@ class DataFrame:
return cte, name
@t.overload
- def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...
+ def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
+ ...
@t.overload
- def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...
+ def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
+ ...
def _ensure_list_of_columns(self, cols):
return Column.ensure_cols(ensure_list(cols))
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index a388cb4..29e7c55 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -368,7 +368,10 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
- return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls)
+ this = Column.invoke_expression_over_column(col, expression.First)
+ if ignorenulls:
+ return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
+ return this
def grouping_id(*cols: ColumnOrName) -> Column:
@@ -392,7 +395,10 @@ def isnull(col: ColumnOrName) -> Column:
def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
- return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls)
+ this = Column.invoke_expression_over_column(col, expression.Last)
+ if ignorenulls:
+ return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
+ return this
def monotonically_increasing_id() -> Column:
@@ -485,31 +491,28 @@ def factorial(col: ColumnOrName) -> Column:
def lag(
col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
) -> Column:
- if default is not None:
- return Column.invoke_anonymous_function(col, "LAG", offset, default)
- if offset != 1:
- return Column.invoke_anonymous_function(col, "LAG", offset)
- return Column.invoke_anonymous_function(col, "LAG")
+ return Column.invoke_expression_over_column(
+ col, expression.Lag, offset=None if offset == 1 else offset, default=default
+ )
def lead(
col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
) -> Column:
- if default is not None:
- return Column.invoke_anonymous_function(col, "LEAD", offset, default)
- if offset != 1:
- return Column.invoke_anonymous_function(col, "LEAD", offset)
- return Column.invoke_anonymous_function(col, "LEAD")
+ return Column.invoke_expression_over_column(
+ col, expression.Lead, offset=None if offset == 1 else offset, default=default
+ )
def nth_value(
col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
) -> Column:
+ this = Column.invoke_expression_over_column(
+ col, expression.NthValue, offset=None if offset == 1 else offset
+ )
if ignoreNulls is not None:
- raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
- if offset != 1:
- return Column.invoke_anonymous_function(col, "NTH_VALUE", offset)
- return Column.invoke_anonymous_function(col, "NTH_VALUE")
+ return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
+ return this
def ntile(n: int) -> Column: