diff options
Diffstat (limited to 'sqlglot/dataframe/sql/functions.py')
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 35 |
1 files changed, 19 insertions, 16 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index a388cb4..29e7c55 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -368,7 +368,10 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: - return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls) + this = Column.invoke_expression_over_column(col, expression.First) + if ignorenulls: + return Column.invoke_expression_over_column(this, expression.IgnoreNulls) + return this def grouping_id(*cols: ColumnOrName) -> Column: @@ -392,7 +395,10 @@ def isnull(col: ColumnOrName) -> Column: def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: - return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls) + this = Column.invoke_expression_over_column(col, expression.Last) + if ignorenulls: + return Column.invoke_expression_over_column(this, expression.IgnoreNulls) + return this def monotonically_increasing_id() -> Column: @@ -485,31 +491,28 @@ def factorial(col: ColumnOrName) -> Column: def lag( col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None ) -> Column: - if default is not None: - return Column.invoke_anonymous_function(col, "LAG", offset, default) - if offset != 1: - return Column.invoke_anonymous_function(col, "LAG", offset) - return Column.invoke_anonymous_function(col, "LAG") + return Column.invoke_expression_over_column( + col, expression.Lag, offset=None if offset == 1 else offset, default=default + ) def lead( col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None ) -> Column: - if default is not None: - return Column.invoke_anonymous_function(col, "LEAD", offset, default) - if offset != 1: - return Column.invoke_anonymous_function(col, "LEAD", offset) - return Column.invoke_anonymous_function(col, "LEAD") + return Column.invoke_expression_over_column( + col, expression.Lead, offset=None if offset == 1 else offset, default=default + ) def nth_value( col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None ) -> Column: + this = Column.invoke_expression_over_column( + col, expression.NthValue, offset=None if offset == 1 else offset + ) if ignoreNulls is not None: - raise NotImplementedError("There is currently not support for `ignoreNulls` parameter") - if offset != 1: - return Column.invoke_anonymous_function(col, "NTH_VALUE", offset) - return Column.invoke_anonymous_function(col, "NTH_VALUE") + return Column.invoke_expression_over_column(this, expression.IgnoreNulls) + return this def ntile(n: int) -> Column: |