summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r--sqlglot/dataframe/sql/column.py4
-rw-r--r--sqlglot/dataframe/sql/functions.py30
2 files changed, 25 insertions, 9 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index f4cfeba..fcfd71e 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -114,7 +114,7 @@ class Column:
return self.inverse_binary_op(exp.Or, other)
@classmethod
- def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+ def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
return cls(value)
@classmethod
@@ -259,7 +259,7 @@ class Column:
new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
return Column(new_expression)
- def cast(self, dataType: t.Union[str, DataType]):
+ def cast(self, dataType: t.Union[str, DataType]) -> Column:
"""
Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
Sqlglot doesn't currently replicate this class so it only accepts a string
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index bdc1fb4..1549a07 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -600,8 +600,13 @@ def months_between(
date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
) -> Column:
if roundOff is None:
- return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
- return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)
+ return Column.invoke_expression_over_column(
+ date1, expression.MonthsBetween, expression=date2
+ )
+
+ return Column.invoke_expression_over_column(
+ date1, expression.MonthsBetween, expression=date2, roundoff=roundOff
+ )
def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
@@ -614,8 +619,9 @@ def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
if format is not None:
- return Column.invoke_anonymous_function(col, "TO_TIMESTAMP", lit(format))
- return Column.invoke_anonymous_function(col, "TO_TIMESTAMP")
+ return Column.invoke_expression_over_column(col, expression.StrToTime, format=lit(format))
+
+ return Column.ensure_col(col).cast("timestamp")
def trunc(col: ColumnOrName, format: str) -> Column:
@@ -875,8 +881,16 @@ def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None)
)
-def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:
- return Column.invoke_anonymous_function(str, "REGEXP_REPLACE", lit(pattern), lit(replacement))
+def regexp_replace(
+ str: ColumnOrName, pattern: str, replacement: str, position: t.Optional[int] = None
+) -> Column:
+ return Column.invoke_expression_over_column(
+ str,
+ expression.RegexpReplace,
+ expression=lit(pattern),
+ replacement=lit(replacement),
+ position=position,
+ )
def initcap(col: ColumnOrName) -> Column:
@@ -1186,7 +1200,9 @@ def transform(
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
) -> Column:
f_expression = _get_lambda_from_func(f)
- return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
+ return Column.invoke_expression_over_column(
+ col, expression.Transform, expression=Column(f_expression)
+ )
def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column: