From 639a208fa57ea674d165c4837e96f3ae4d7e3e61 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 19 Feb 2023 14:45:09 +0100 Subject: Merging upstream version 11.1.3. Signed-off-by: Daniel Baumann --- sqlglot/dataframe/sql/column.py | 1 + sqlglot/dataframe/sql/functions.py | 235 ++++++++++++++++++++----------------- 2 files changed, 130 insertions(+), 106 deletions(-) (limited to 'sqlglot/dataframe') diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index f5b0974..609b2a4 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -143,6 +143,7 @@ class Column: if is_iterable(v) else Column.ensure_col(v).expression for k, v in kwargs.items() + if v is not None } new_expression = ( callable_expression(**ensure_expression_values) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 47d5e7b..0262d54 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing as t -from sqlglot import expressions as glotexp +from sqlglot import exp as expression from sqlglot.dataframe.sql.column import Column from sqlglot.helper import ensure_list from sqlglot.helper import flatten as _flatten @@ -18,25 +18,29 @@ def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column: def lit(value: t.Optional[t.Any] = None) -> Column: if isinstance(value, str): - return Column(glotexp.Literal.string(str(value))) + return Column(expression.Literal.string(str(value))) return Column(value) def greatest(*cols: ColumnOrName) -> Column: if len(cols) > 1: - return Column.invoke_expression_over_column(cols[0], glotexp.Greatest, expressions=cols[1:]) - return Column.invoke_expression_over_column(cols[0], glotexp.Greatest) + return Column.invoke_expression_over_column( + cols[0], expression.Greatest, expressions=cols[1:] + ) + return Column.invoke_expression_over_column(cols[0], expression.Greatest) def least(*cols: ColumnOrName) -> Column: if len(cols) > 1: - return Column.invoke_expression_over_column(cols[0], glotexp.Least, expressions=cols[1:]) - return Column.invoke_expression_over_column(cols[0], glotexp.Least) + return Column.invoke_expression_over_column(cols[0], expression.Least, expressions=cols[1:]) + return Column.invoke_expression_over_column(cols[0], expression.Least) def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: columns = [Column.ensure_col(x) for x in [col] + list(cols)] - return Column(glotexp.Count(this=glotexp.Distinct(expressions=[x.expression for x in columns]))) + return Column( + expression.Count(this=expression.Distinct(expressions=[x.expression for x in columns])) + ) def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: @@ -46,8 +50,8 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: def when(condition: Column, value: t.Any) -> Column: true_value = value if isinstance(value, Column) else lit(value) return Column( - glotexp.Case( - ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)] + expression.Case( + ifs=[expression.If(this=condition.column_expression, true=true_value.column_expression)] ) ) @@ -65,19 +69,19 @@ def broadcast(df: DataFrame) -> DataFrame: def sqrt(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Sqrt) + return Column.invoke_expression_over_column(col, expression.Sqrt) def abs(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Abs) + return Column.invoke_expression_over_column(col, expression.Abs) def max(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Max) + return Column.invoke_expression_over_column(col, expression.Max) def min(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Min) + return Column.invoke_expression_over_column(col, expression.Min) def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column: @@ -89,15 +93,15 @@ def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column: def count(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Count) + return Column.invoke_expression_over_column(col, expression.Count) def sum(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Sum) + return Column.invoke_expression_over_column(col, expression.Sum) def avg(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Avg) + return Column.invoke_expression_over_column(col, expression.Avg) def mean(col: ColumnOrName) -> Column: @@ -149,7 +153,7 @@ def cbrt(col: ColumnOrName) -> Column: def ceil(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Ceil) + return Column.invoke_expression_over_column(col, expression.Ceil) def cos(col: ColumnOrName) -> Column: @@ -169,7 +173,7 @@ def csc(col: ColumnOrName) -> Column: def exp(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Exp) + return Column.invoke_expression_over_column(col, expression.Exp) def expm1(col: ColumnOrName) -> Column: @@ -177,11 +181,11 @@ def expm1(col: ColumnOrName) -> Column: def floor(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Floor) + return Column.invoke_expression_over_column(col, expression.Floor) def log10(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Log10) + return Column.invoke_expression_over_column(col, expression.Log10) def log1p(col: ColumnOrName) -> Column: @@ -189,13 +193,13 @@ def log1p(col: ColumnOrName) -> Column: def log2(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Log2) + return Column.invoke_expression_over_column(col, expression.Log2) def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column: if arg2 is None: - return Column.invoke_expression_over_column(arg1, glotexp.Ln) - return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=arg2) + return Column.invoke_expression_over_column(arg1, expression.Ln) + return Column.invoke_expression_over_column(arg1, expression.Log, expression=arg2) def rint(col: ColumnOrName) -> Column: @@ -247,7 +251,7 @@ def bitwiseNOT(col: ColumnOrName) -> Column: def bitwise_not(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.BitwiseNot) + return Column.invoke_expression_over_column(col, expression.BitwiseNot) def asc_nulls_first(col: ColumnOrName) -> Column: @@ -267,27 +271,27 @@ def desc_nulls_last(col: ColumnOrName) -> Column: def stddev(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Stddev) + return Column.invoke_expression_over_column(col, expression.Stddev) def stddev_samp(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.StddevSamp) + return Column.invoke_expression_over_column(col, expression.StddevSamp) def stddev_pop(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.StddevPop) + return Column.invoke_expression_over_column(col, expression.StddevPop) def variance(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Variance) + return Column.invoke_expression_over_column(col, expression.Variance) def var_samp(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Variance) + return Column.invoke_expression_over_column(col, expression.Variance) def var_pop(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.VariancePop) + return Column.invoke_expression_over_column(col, expression.VariancePop) def skewness(col: ColumnOrName) -> Column: @@ -299,11 +303,11 @@ def kurtosis(col: ColumnOrName) -> Column: def collect_list(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.ArrayAgg) + return Column.invoke_expression_over_column(col, expression.ArrayAgg) def collect_set(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.SetAgg) + return Column.invoke_expression_over_column(col, expression.SetAgg) def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: @@ -311,27 +315,27 @@ def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float] def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: - return Column.invoke_expression_over_column(col1, glotexp.Pow, expression=col2) + return Column.invoke_expression_over_column(col1, expression.Pow, expression=col2) def row_number() -> Column: - return Column(glotexp.Anonymous(this="ROW_NUMBER")) + return Column(expression.Anonymous(this="ROW_NUMBER")) def dense_rank() -> Column: - return Column(glotexp.Anonymous(this="DENSE_RANK")) + return Column(expression.Anonymous(this="DENSE_RANK")) def rank() -> Column: - return Column(glotexp.Anonymous(this="RANK")) + return Column(expression.Anonymous(this="RANK")) def cume_dist() -> Column: - return Column(glotexp.Anonymous(this="CUME_DIST")) + return Column(expression.Anonymous(this="CUME_DIST")) def percent_rank() -> Column: - return Column(glotexp.Anonymous(this="PERCENT_RANK")) + return Column(expression.Anonymous(this="PERCENT_RANK")) def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: @@ -340,14 +344,16 @@ def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Col def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: if rsd is None: - return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct) - return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=rsd) + return Column.invoke_expression_over_column(col, expression.ApproxDistinct) + return Column.invoke_expression_over_column(col, expression.ApproxDistinct, accuracy=rsd) def coalesce(*cols: ColumnOrName) -> Column: if len(cols) > 1: - return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce, expressions=cols[1:]) - return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce) + return Column.invoke_expression_over_column( + cols[0], expression.Coalesce, expressions=cols[1:] + ) + return Column.invoke_expression_over_column(cols[0], expression.Coalesce) def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: @@ -409,10 +415,10 @@ def percentile_approx( ) -> Column: if accuracy: return Column.invoke_expression_over_column( - col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy + col, expression.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy ) return Column.invoke_expression_over_column( - col, glotexp.ApproxQuantile, quantile=lit(percentage) + col, expression.ApproxQuantile, quantile=lit(percentage) ) @@ -426,8 +432,8 @@ def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column: def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: if scale is not None: - return Column.invoke_expression_over_column(col, glotexp.Round, decimals=scale) - return Column.invoke_expression_over_column(col, glotexp.Round) + return Column.invoke_expression_over_column(col, expression.Round, decimals=scale) + return Column.invoke_expression_over_column(col, expression.Round) def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: @@ -437,7 +443,9 @@ def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: def shiftleft(col: ColumnOrName, numBits: int) -> Column: - return Column.invoke_expression_over_column(col, glotexp.BitwiseLeftShift, expression=numBits) + return Column.invoke_expression_over_column( + col, expression.BitwiseLeftShift, expression=numBits + ) def shiftLeft(col: ColumnOrName, numBits: int) -> Column: @@ -445,7 +453,9 @@ def shiftLeft(col: ColumnOrName, numBits: int) -> Column: def shiftright(col: ColumnOrName, numBits: int) -> Column: - return Column.invoke_expression_over_column(col, glotexp.BitwiseRightShift, expression=numBits) + return Column.invoke_expression_over_column( + col, expression.BitwiseRightShift, expression=numBits + ) def shiftRight(col: ColumnOrName, numBits: int) -> Column: @@ -466,7 +476,7 @@ def expr(str: str) -> Column: def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column: columns = ensure_list(col) + list(cols) - return Column.invoke_expression_over_column(None, glotexp.Struct, expressions=columns) + return Column.invoke_expression_over_column(None, expression.Struct, expressions=columns) def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column: @@ -512,19 +522,19 @@ def ntile(n: int) -> Column: def current_date() -> Column: - return Column.invoke_expression_over_column(None, glotexp.CurrentDate) + return Column.invoke_expression_over_column(None, expression.CurrentDate) def current_timestamp() -> Column: - return Column.invoke_expression_over_column(None, glotexp.CurrentTimestamp) + return Column.invoke_expression_over_column(None, expression.CurrentTimestamp) def date_format(col: ColumnOrName, format: str) -> Column: - return Column.invoke_expression_over_column(col, glotexp.TimeToStr, format=lit(format)) + return Column.invoke_expression_over_column(col, expression.TimeToStr, format=lit(format)) def year(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Year) + return Column.invoke_expression_over_column(col, expression.Year) def quarter(col: ColumnOrName) -> Column: @@ -532,19 +542,19 @@ def quarter(col: ColumnOrName) -> Column: def month(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Month) + return Column.invoke_expression_over_column(col, expression.Month) def dayofweek(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DayOfWeek) + return Column.invoke_expression_over_column(col, expression.DayOfWeek) def dayofmonth(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DayOfMonth) + return Column.invoke_expression_over_column(col, expression.DayOfMonth) def dayofyear(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DayOfYear) + return Column.invoke_expression_over_column(col, expression.DayOfYear) def hour(col: ColumnOrName) -> Column: @@ -560,7 +570,7 @@ def second(col: ColumnOrName) -> Column: def weekofyear(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.WeekOfYear) + return Column.invoke_expression_over_column(col, expression.WeekOfYear) def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column: @@ -568,15 +578,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=days) + return Column.invoke_expression_over_column(col, expression.DateAdd, expression=days) def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=days) + return Column.invoke_expression_over_column(col, expression.DateSub, expression=days) def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=start) + return Column.invoke_expression_over_column(end, expression.DateDiff, expression=start) def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column: @@ -593,8 +603,10 @@ def months_between( def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: if format is not None: - return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate, format=lit(format)) - return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate) + return Column.invoke_expression_over_column( + col, expression.TsOrDsToDate, format=lit(format) + ) + return Column.invoke_expression_over_column(col, expression.TsOrDsToDate) def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: @@ -604,11 +616,13 @@ def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: def trunc(col: ColumnOrName, format: str) -> Column: - return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format)) + return Column.invoke_expression_over_column(col, expression.DateTrunc, unit=lit(format)) def date_trunc(format: str, timestamp: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format)) + return Column.invoke_expression_over_column( + timestamp, expression.TimestampTrunc, unit=lit(format) + ) def next_day(col: ColumnOrName, dayOfWeek: str) -> Column: @@ -621,8 +635,8 @@ def last_day(col: ColumnOrName) -> Column: def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column: if format is not None: - return Column.invoke_expression_over_column(col, glotexp.UnixToStr, format=lit(format)) - return Column.invoke_expression_over_column(col, glotexp.UnixToStr) + return Column.invoke_expression_over_column(col, expression.UnixToStr, format=lit(format)) + return Column.invoke_expression_over_column(col, expression.UnixToStr) def unix_timestamp( @@ -630,9 +644,9 @@ def unix_timestamp( ) -> Column: if format is not None: return Column.invoke_expression_over_column( - timestamp, glotexp.StrToUnix, format=lit(format) + timestamp, expression.StrToUnix, format=lit(format) ) - return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix) + return Column.invoke_expression_over_column(timestamp, expression.StrToUnix) def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: @@ -719,11 +733,11 @@ def raise_error(errorMsg: ColumnOrName) -> Column: def upper(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Upper) + return Column.invoke_expression_over_column(col, expression.Upper) def lower(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Lower) + return Column.invoke_expression_over_column(col, expression.Lower) def ascii(col: ColumnOrLiteral) -> Column: @@ -747,24 +761,24 @@ def rtrim(col: ColumnOrName) -> Column: def trim(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Trim) + return Column.invoke_expression_over_column(col, expression.Trim) def concat_ws(sep: str, *cols: ColumnOrName) -> Column: return Column.invoke_expression_over_column( - None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols) + None, expression.ConcatWs, expressions=[lit(sep)] + list(cols) ) def decode(col: ColumnOrName, charset: str) -> Column: return Column.invoke_expression_over_column( - col, glotexp.Decode, charset=glotexp.Literal.string(charset) + col, expression.Decode, charset=expression.Literal.string(charset) ) def encode(col: ColumnOrName, charset: str) -> Column: return Column.invoke_expression_over_column( - col, glotexp.Encode, charset=glotexp.Literal.string(charset) + col, expression.Encode, charset=expression.Literal.string(charset) ) @@ -816,16 +830,16 @@ def substring_index(str: ColumnOrName, delim: str, count: int) -> Column: def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(left, glotexp.Levenshtein, expression=right) + return Column.invoke_expression_over_column(left, expression.Levenshtein, expression=right) def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column: substr_col = lit(substr) if pos is not None: return Column.invoke_expression_over_column( - str, glotexp.StrPosition, substr=substr_col, position=pos + str, expression.StrPosition, substr=substr_col, position=pos ) - return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col) + return Column.invoke_expression_over_column(str, expression.StrPosition, substr=substr_col) def lpad(col: ColumnOrName, len: int, pad: str) -> Column: @@ -837,21 +851,26 @@ def rpad(col: ColumnOrName, len: int, pad: str) -> Column: def repeat(col: ColumnOrName, n: int) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Repeat, times=lit(n)) + return Column.invoke_expression_over_column(col, expression.Repeat, times=lit(n)) def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column: if limit is not None: return Column.invoke_expression_over_column( - str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=limit + str, expression.RegexpSplit, expression=lit(pattern).expression, limit=limit ) - return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern)) + return Column.invoke_expression_over_column( + str, expression.RegexpSplit, expression=lit(pattern) + ) def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column: - if idx is not None: - return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern), idx) - return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern)) + return Column.invoke_expression_over_column( + str, + expression.RegexpExtract, + expression=lit(pattern), + group=idx, + ) def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column: @@ -859,7 +878,7 @@ def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column: def initcap(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Initcap) + return Column.invoke_expression_over_column(col, expression.Initcap) def soundex(col: ColumnOrName) -> Column: @@ -871,15 +890,15 @@ def bin(col: ColumnOrName) -> Column: def hex(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Hex) + return Column.invoke_expression_over_column(col, expression.Hex) def unhex(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Unhex) + return Column.invoke_expression_over_column(col, expression.Unhex) def length(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Length) + return Column.invoke_expression_over_column(col, expression.Length) def octet_length(col: ColumnOrName) -> Column: @@ -896,27 +915,27 @@ def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column: def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: columns = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols - return Column.invoke_expression_over_column(None, glotexp.Array, expressions=columns) + return Column.invoke_expression_over_column(None, expression.Array, expressions=columns) def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore return Column.invoke_expression_over_column( None, - glotexp.VarMap, + expression.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression, ) def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(None, glotexp.Map, keys=col1, values=col2) + return Column.invoke_expression_over_column(None, expression.Map, keys=col1, values=col2) def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column: value_col = value if isinstance(value, Column) else lit(value) return Column.invoke_expression_over_column( - col, glotexp.ArrayContains, expression=value_col.expression + col, expression.ArrayContains, expression=value_col.expression ) @@ -943,7 +962,7 @@ def array_join( def concat(*cols: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols) + return Column.invoke_expression_over_column(None, expression.Concat, expressions=cols) def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column: @@ -978,11 +997,11 @@ def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column: def explode(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Explode) + return Column.invoke_expression_over_column(col, expression.Explode) def posexplode(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.Posexplode) + return Column.invoke_expression_over_column(col, expression.Posexplode) def explode_outer(col: ColumnOrName) -> Column: @@ -994,7 +1013,7 @@ def posexplode_outer(col: ColumnOrName) -> Column: def get_json_object(col: ColumnOrName, path: str) -> Column: - return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path)) + return Column.invoke_expression_over_column(col, expression.JSONExtract, path=lit(path)) def json_tuple(col: ColumnOrName, *fields: str) -> Column: @@ -1042,7 +1061,7 @@ def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> C def size(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, glotexp.ArraySize) + return Column.invoke_expression_over_column(col, expression.ArraySize) def array_min(col: ColumnOrName) -> Column: @@ -1055,8 +1074,8 @@ def array_max(col: ColumnOrName) -> Column: def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column: if asc is not None: - return Column.invoke_expression_over_column(col, glotexp.SortArray, asc=asc) - return Column.invoke_expression_over_column(col, glotexp.SortArray) + return Column.invoke_expression_over_column(col, expression.SortArray, asc=asc) + return Column.invoke_expression_over_column(col, expression.SortArray) def array_sort( @@ -1065,8 +1084,10 @@ def array_sort( ) -> Column: if comparator is not None: f_expression = _get_lambda_from_func(comparator) - return Column.invoke_expression_over_column(col, glotexp.ArraySort, expression=f_expression) - return Column.invoke_expression_over_column(col, glotexp.ArraySort) + return Column.invoke_expression_over_column( + col, expression.ArraySort, expression=f_expression + ) + return Column.invoke_expression_over_column(col, expression.ArraySort) def shuffle(col: ColumnOrName) -> Column: @@ -1146,13 +1167,13 @@ def aggregate( finish_exp = _get_lambda_from_func(finish) return Column.invoke_expression_over_column( col, - glotexp.Reduce, + expression.Reduce, initial=initialValue, merge=Column(merge_exp), finish=Column(finish_exp), ) return Column.invoke_expression_over_column( - col, glotexp.Reduce, initial=initialValue, merge=Column(merge_exp) + col, expression.Reduce, initial=initialValue, merge=Column(merge_exp) ) @@ -1179,7 +1200,9 @@ def filter( f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], ) -> Column: f_expression = _get_lambda_from_func(f) - return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression) + return Column.invoke_expression_over_column( + col, expression.ArrayFilter, expression=f_expression + ) def zip_with( @@ -1219,10 +1242,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]: def _get_lambda_from_func(lambda_expression: t.Callable): variables = [ - glotexp.to_identifier(x, quoted=_lambda_quoted(x)) + expression.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames ] - return glotexp.Lambda( + return expression.Lambda( this=lambda_expression(*[Column(x) for x in variables]).expression, expressions=variables, ) -- cgit v1.2.3