diff options
Diffstat (limited to 'sqlglot/dataframe/sql/functions.py')
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 30 |
1 files changed, 23 insertions, 7 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index bdc1fb4..1549a07 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -600,8 +600,13 @@ def months_between( date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None ) -> Column: if roundOff is None: - return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2) - return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff) + return Column.invoke_expression_over_column( + date1, expression.MonthsBetween, expression=date2 + ) + + return Column.invoke_expression_over_column( + date1, expression.MonthsBetween, expression=date2, roundoff=roundOff + ) def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: @@ -614,8 +619,9 @@ def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: if format is not None: - return Column.invoke_anonymous_function(col, "TO_TIMESTAMP", lit(format)) - return Column.invoke_anonymous_function(col, "TO_TIMESTAMP") + return Column.invoke_expression_over_column(col, expression.StrToTime, format=lit(format)) + + return Column.ensure_col(col).cast("timestamp") def trunc(col: ColumnOrName, format: str) -> Column: @@ -875,8 +881,16 @@ def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) ) -def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column: - return Column.invoke_anonymous_function(str, "REGEXP_REPLACE", lit(pattern), lit(replacement)) +def regexp_replace( + str: ColumnOrName, pattern: str, replacement: str, position: t.Optional[int] = None +) -> Column: + return Column.invoke_expression_over_column( + str, + expression.RegexpReplace, + expression=lit(pattern), + replacement=lit(replacement), + position=position, + ) def initcap(col: ColumnOrName) -> Column: @@ -1186,7 +1200,9 @@ def transform( f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], ) -> Column: f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression)) + return Column.invoke_expression_over_column( + col, expression.Transform, expression=Column(f_expression) + ) def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column: |