diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dataframe/sql/column.py | 4 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 30 |
2 files changed, 25 insertions, 9 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index f4cfeba..fcfd71e 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -114,7 +114,7 @@ class Column: return self.inverse_binary_op(exp.Or, other) @classmethod - def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): + def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column: return cls(value) @classmethod @@ -259,7 +259,7 @@ class Column: new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null())) return Column(new_expression) - def cast(self, dataType: t.Union[str, DataType]): + def cast(self, dataType: t.Union[str, DataType]) -> Column: """ Functionality Difference: PySpark cast accepts a datatype instance of the datatype class Sqlglot doesn't currently replicate this class so it only accepts a string diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index bdc1fb4..1549a07 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -600,8 +600,13 @@ def months_between( date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None ) -> Column: if roundOff is None: - return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2) - return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff) + return Column.invoke_expression_over_column( + date1, expression.MonthsBetween, expression=date2 + ) + + return Column.invoke_expression_over_column( + date1, expression.MonthsBetween, expression=date2, roundoff=roundOff + ) def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: @@ -614,8 +619,9 @@ def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: if format is not None: - return Column.invoke_anonymous_function(col, "TO_TIMESTAMP", lit(format)) - return Column.invoke_anonymous_function(col, "TO_TIMESTAMP") + return Column.invoke_expression_over_column(col, expression.StrToTime, format=lit(format)) + + return Column.ensure_col(col).cast("timestamp") def trunc(col: ColumnOrName, format: str) -> Column: @@ -875,8 +881,16 @@ def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) ) -def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column: - return Column.invoke_anonymous_function(str, "REGEXP_REPLACE", lit(pattern), lit(replacement)) +def regexp_replace( + str: ColumnOrName, pattern: str, replacement: str, position: t.Optional[int] = None +) -> Column: + return Column.invoke_expression_over_column( + str, + expression.RegexpReplace, + expression=lit(pattern), + replacement=lit(replacement), + position=position, + ) def initcap(col: ColumnOrName) -> Column: @@ -1186,7 +1200,9 @@ def transform( f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], ) -> Column: f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression)) + return Column.invoke_expression_over_column( + col, expression.Transform, expression=Column(f_expression) + ) def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column: |