diff options
Diffstat (limited to 'sqlglot/dataframe/sql/functions.py')
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 100 |
1 files changed, 76 insertions, 24 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index bc002e5..dbfb06f 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -45,7 +45,11 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: def when(condition: Column, value: t.Any) -> Column: true_value = value if isinstance(value, Column) else lit(value) - return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)])) + return Column( + glotexp.Case( + ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)] + ) + ) def asc(col: ColumnOrName) -> Column: @@ -407,7 +411,9 @@ def percentile_approx( return Column.invoke_expression_over_column( col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy ) - return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage)) + return Column.invoke_expression_over_column( + col, glotexp.ApproxQuantile, quantile=lit(percentage) + ) def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column: @@ -471,7 +477,9 @@ def factorial(col: ColumnOrName) -> Column: return Column.invoke_anonymous_function(col, "FACTORIAL") -def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column: +def lag( + col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None +) -> Column: if default is not None: return Column.invoke_anonymous_function(col, "LAG", offset, default) if offset != 1: @@ -479,7 +487,9 @@ def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[Colu return Column.invoke_anonymous_function(col, "LAG") -def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column: +def lead( + col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None +) -> Column: if default is not None: return Column.invoke_anonymous_function(col, "LEAD", offset, default) if offset != 1: @@ -487,7 +497,9 @@ def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.A return Column.invoke_anonymous_function(col, "LEAD") -def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column: +def nth_value( + col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None +) -> Column: if ignoreNulls is not None: raise NotImplementedError("There is currently not support for `ignoreNulls` parameter") if offset != 1: @@ -571,7 +583,9 @@ def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Colum return Column.invoke_anonymous_function(start, "ADD_MONTHS", months) -def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column: +def months_between( + date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None +) -> Column: if roundOff is None: return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2) return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff) @@ -611,9 +625,13 @@ def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column: return Column.invoke_expression_over_column(col, glotexp.UnixToStr) -def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column: +def unix_timestamp( + timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None +) -> Column: if format is not None: - return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format)) + return Column.invoke_expression_over_column( + timestamp, glotexp.StrToUnix, format=lit(format) + ) return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix) @@ -642,7 +660,9 @@ def window( timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime) ) if slideDuration is not None: - return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)) + return Column.invoke_anonymous_function( + timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration) + ) if startTime is not None: return Column.invoke_anonymous_function( timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime) @@ -731,7 +751,9 @@ def trim(col: ColumnOrName) -> Column: def concat_ws(sep: str, *cols: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)) + return Column.invoke_expression_over_column( + None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols) + ) def decode(col: ColumnOrName, charset: str) -> Column: @@ -768,7 +790,9 @@ def overlay( def sentences( - string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None + string: ColumnOrName, + language: t.Optional[ColumnOrName] = None, + country: t.Optional[ColumnOrName] = None, ) -> Column: if language is not None and country is not None: return Column.invoke_anonymous_function(string, "SENTENCES", language, country) @@ -794,7 +818,9 @@ def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column: def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column: substr_col = lit(substr) if pos is not None: - return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos) + return Column.invoke_expression_over_column( + str, glotexp.StrPosition, substr=substr_col, position=pos + ) return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col) @@ -872,7 +898,10 @@ def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore return Column.invoke_expression_over_column( - None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression + None, + glotexp.VarMap, + keys=array(*cols[::2]).expression, + values=array(*cols[1::2]).expression, ) @@ -882,29 +911,39 @@ def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column: value_col = value if isinstance(value, Column) else lit(value) - return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression) + return Column.invoke_expression_over_column( + col, glotexp.ArrayContains, expression=value_col.expression + ) def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column: return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2)) -def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column: +def slice( + x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int] +) -> Column: start_col = start if isinstance(start, Column) else lit(start) length_col = length if isinstance(length, Column) else lit(length) return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col) -def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column: +def array_join( + col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None +) -> Column: if null_replacement is not None: - return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)) + return Column.invoke_anonymous_function( + col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement) + ) return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter)) def concat(*cols: ColumnOrName) -> Column: if len(cols) == 1: return Column.invoke_anonymous_function(cols[0], "CONCAT") - return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]) + return Column.invoke_anonymous_function( + cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]] + ) def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column: @@ -1076,7 +1115,9 @@ def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:]) -def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column: +def sequence( + start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None +) -> Column: if step is not None: return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step) return Column.invoke_anonymous_function(start, "SEQUENCE", stop) @@ -1103,12 +1144,15 @@ def aggregate( merge_exp = _get_lambda_from_func(merge) if finish is not None: finish_exp = _get_lambda_from_func(finish) - return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)) + return Column.invoke_anonymous_function( + col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp) + ) return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp)) def transform( - col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]] + col: ColumnOrName, + f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], ) -> Column: f_expression = _get_lambda_from_func(f) return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression)) @@ -1124,12 +1168,17 @@ def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column: return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression)) -def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column: +def filter( + col: ColumnOrName, + f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], +) -> Column: f_expression = _get_lambda_from_func(f) return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression) -def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column: +def zip_with( + left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column] +) -> Column: f_expression = _get_lambda_from_func(f) return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression)) @@ -1163,7 +1212,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]: def _get_lambda_from_func(lambda_expression: t.Callable): - variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames] + variables = [ + glotexp.to_identifier(x, quoted=_lambda_quoted(x)) + for x in lambda_expression.__code__.co_varnames + ] return glotexp.Lambda( this=lambda_expression(*[Column(x) for x in variables]).expression, expressions=variables, |