summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe/sql/functions.py')
-rw-r--r--sqlglot/dataframe/sql/functions.py100
1 files changed, 76 insertions, 24 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index bc002e5..dbfb06f 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -45,7 +45,11 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
def when(condition: Column, value: t.Any) -> Column:
true_value = value if isinstance(value, Column) else lit(value)
- return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]))
+ return Column(
+ glotexp.Case(
+ ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]
+ )
+ )
def asc(col: ColumnOrName) -> Column:
@@ -407,7 +411,9 @@ def percentile_approx(
return Column.invoke_expression_over_column(
col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
)
- return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage))
+ return Column.invoke_expression_over_column(
+ col, glotexp.ApproxQuantile, quantile=lit(percentage)
+ )
def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
@@ -471,7 +477,9 @@ def factorial(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "FACTORIAL")
-def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column:
+def lag(
+ col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
+) -> Column:
if default is not None:
return Column.invoke_anonymous_function(col, "LAG", offset, default)
if offset != 1:
@@ -479,7 +487,9 @@ def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[Colu
return Column.invoke_anonymous_function(col, "LAG")
-def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
+def lead(
+ col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
+) -> Column:
if default is not None:
return Column.invoke_anonymous_function(col, "LEAD", offset, default)
if offset != 1:
@@ -487,7 +497,9 @@ def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.A
return Column.invoke_anonymous_function(col, "LEAD")
-def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
+def nth_value(
+ col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
+) -> Column:
if ignoreNulls is not None:
raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
if offset != 1:
@@ -571,7 +583,9 @@ def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Colum
return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)
-def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
+def months_between(
+ date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
+) -> Column:
if roundOff is None:
return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)
@@ -611,9 +625,13 @@ def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
return Column.invoke_expression_over_column(col, glotexp.UnixToStr)
-def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
+def unix_timestamp(
+ timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
+) -> Column:
if format is not None:
- return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format))
+ return Column.invoke_expression_over_column(
+ timestamp, glotexp.StrToUnix, format=lit(format)
+ )
return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
@@ -642,7 +660,9 @@ def window(
timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime)
)
if slideDuration is not None:
- return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration))
+ return Column.invoke_anonymous_function(
+ timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)
+ )
if startTime is not None:
return Column.invoke_anonymous_function(
timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime)
@@ -731,7 +751,9 @@ def trim(col: ColumnOrName) -> Column:
def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols))
+ return Column.invoke_expression_over_column(
+ None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)
+ )
def decode(col: ColumnOrName, charset: str) -> Column:
@@ -768,7 +790,9 @@ def overlay(
def sentences(
- string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None
+ string: ColumnOrName,
+ language: t.Optional[ColumnOrName] = None,
+ country: t.Optional[ColumnOrName] = None,
) -> Column:
if language is not None and country is not None:
return Column.invoke_anonymous_function(string, "SENTENCES", language, country)
@@ -794,7 +818,9 @@ def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
substr_col = lit(substr)
if pos is not None:
- return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos)
+ return Column.invoke_expression_over_column(
+ str, glotexp.StrPosition, substr=substr_col, position=pos
+ )
return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
@@ -872,7 +898,10 @@ def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
return Column.invoke_expression_over_column(
- None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression
+ None,
+ glotexp.VarMap,
+ keys=array(*cols[::2]).expression,
+ values=array(*cols[1::2]).expression,
)
@@ -882,29 +911,39 @@ def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
value_col = value if isinstance(value, Column) else lit(value)
- return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression)
+ return Column.invoke_expression_over_column(
+ col, glotexp.ArrayContains, expression=value_col.expression
+ )
def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2))
-def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column:
+def slice(
+ x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
+) -> Column:
start_col = start if isinstance(start, Column) else lit(start)
length_col = length if isinstance(length, Column) else lit(length)
return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)
-def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column:
+def array_join(
+ col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
+) -> Column:
if null_replacement is not None:
- return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement))
+ return Column.invoke_anonymous_function(
+ col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)
+ )
return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))
def concat(*cols: ColumnOrName) -> Column:
if len(cols) == 1:
return Column.invoke_anonymous_function(cols[0], "CONCAT")
- return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]])
+ return Column.invoke_anonymous_function(
+ cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
+ )
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
@@ -1076,7 +1115,9 @@ def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])
-def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column:
+def sequence(
+ start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
+) -> Column:
if step is not None:
return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step)
return Column.invoke_anonymous_function(start, "SEQUENCE", stop)
@@ -1103,12 +1144,15 @@ def aggregate(
merge_exp = _get_lambda_from_func(merge)
if finish is not None:
finish_exp = _get_lambda_from_func(finish)
- return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
+ return Column.invoke_anonymous_function(
+ col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
+ )
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
def transform(
- col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]
+ col: ColumnOrName,
+ f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
@@ -1124,12 +1168,17 @@ def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
-def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column:
+def filter(
+ col: ColumnOrName,
+ f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
+) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
-def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
+def zip_with(
+ left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]
+) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
@@ -1163,7 +1212,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]:
def _get_lambda_from_func(lambda_expression: t.Callable):
- variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames]
+ variables = [
+ glotexp.to_identifier(x, quoted=_lambda_quoted(x))
+ for x in lambda_expression.__code__.co_varnames
+ ]
return glotexp.Lambda(
this=lambda_expression(*[Column(x) for x in variables]).expression,
expressions=variables,