summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-02-19 13:45:09 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-02-19 13:45:09 +0000
commit639a208fa57ea674d165c4837e96f3ae4d7e3e61 (patch)
treef4d66da146c396d407cecefb5b405e609af1109e /sqlglot/dataframe
parentReleasing debian version 11.0.1-1. (diff)
downloadsqlglot-639a208fa57ea674d165c4837e96f3ae4d7e3e61.tar.xz
sqlglot-639a208fa57ea674d165c4837e96f3ae4d7e3e61.zip
Merging upstream version 11.1.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r--sqlglot/dataframe/sql/column.py1
-rw-r--r--sqlglot/dataframe/sql/functions.py235
2 files changed, 130 insertions, 106 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index f5b0974..609b2a4 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -143,6 +143,7 @@ class Column:
if is_iterable(v)
else Column.ensure_col(v).expression
for k, v in kwargs.items()
+ if v is not None
}
new_expression = (
callable_expression(**ensure_expression_values)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 47d5e7b..0262d54 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import expressions as glotexp
+from sqlglot import exp as expression
from sqlglot.dataframe.sql.column import Column
from sqlglot.helper import ensure_list
from sqlglot.helper import flatten as _flatten
@@ -18,25 +18,29 @@ def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column:
def lit(value: t.Optional[t.Any] = None) -> Column:
if isinstance(value, str):
- return Column(glotexp.Literal.string(str(value)))
+ return Column(expression.Literal.string(str(value)))
return Column(value)
def greatest(*cols: ColumnOrName) -> Column:
if len(cols) > 1:
- return Column.invoke_expression_over_column(cols[0], glotexp.Greatest, expressions=cols[1:])
- return Column.invoke_expression_over_column(cols[0], glotexp.Greatest)
+ return Column.invoke_expression_over_column(
+ cols[0], expression.Greatest, expressions=cols[1:]
+ )
+ return Column.invoke_expression_over_column(cols[0], expression.Greatest)
def least(*cols: ColumnOrName) -> Column:
if len(cols) > 1:
- return Column.invoke_expression_over_column(cols[0], glotexp.Least, expressions=cols[1:])
- return Column.invoke_expression_over_column(cols[0], glotexp.Least)
+ return Column.invoke_expression_over_column(cols[0], expression.Least, expressions=cols[1:])
+ return Column.invoke_expression_over_column(cols[0], expression.Least)
def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
columns = [Column.ensure_col(x) for x in [col] + list(cols)]
- return Column(glotexp.Count(this=glotexp.Distinct(expressions=[x.expression for x in columns])))
+ return Column(
+ expression.Count(this=expression.Distinct(expressions=[x.expression for x in columns]))
+ )
def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
@@ -46,8 +50,8 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
def when(condition: Column, value: t.Any) -> Column:
true_value = value if isinstance(value, Column) else lit(value)
return Column(
- glotexp.Case(
- ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]
+ expression.Case(
+ ifs=[expression.If(this=condition.column_expression, true=true_value.column_expression)]
)
)
@@ -65,19 +69,19 @@ def broadcast(df: DataFrame) -> DataFrame:
def sqrt(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Sqrt)
+ return Column.invoke_expression_over_column(col, expression.Sqrt)
def abs(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Abs)
+ return Column.invoke_expression_over_column(col, expression.Abs)
def max(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Max)
+ return Column.invoke_expression_over_column(col, expression.Max)
def min(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Min)
+ return Column.invoke_expression_over_column(col, expression.Min)
def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
@@ -89,15 +93,15 @@ def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
def count(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Count)
+ return Column.invoke_expression_over_column(col, expression.Count)
def sum(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Sum)
+ return Column.invoke_expression_over_column(col, expression.Sum)
def avg(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Avg)
+ return Column.invoke_expression_over_column(col, expression.Avg)
def mean(col: ColumnOrName) -> Column:
@@ -149,7 +153,7 @@ def cbrt(col: ColumnOrName) -> Column:
def ceil(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Ceil)
+ return Column.invoke_expression_over_column(col, expression.Ceil)
def cos(col: ColumnOrName) -> Column:
@@ -169,7 +173,7 @@ def csc(col: ColumnOrName) -> Column:
def exp(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Exp)
+ return Column.invoke_expression_over_column(col, expression.Exp)
def expm1(col: ColumnOrName) -> Column:
@@ -177,11 +181,11 @@ def expm1(col: ColumnOrName) -> Column:
def floor(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Floor)
+ return Column.invoke_expression_over_column(col, expression.Floor)
def log10(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Log10)
+ return Column.invoke_expression_over_column(col, expression.Log10)
def log1p(col: ColumnOrName) -> Column:
@@ -189,13 +193,13 @@ def log1p(col: ColumnOrName) -> Column:
def log2(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Log2)
+ return Column.invoke_expression_over_column(col, expression.Log2)
def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
if arg2 is None:
- return Column.invoke_expression_over_column(arg1, glotexp.Ln)
- return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=arg2)
+ return Column.invoke_expression_over_column(arg1, expression.Ln)
+ return Column.invoke_expression_over_column(arg1, expression.Log, expression=arg2)
def rint(col: ColumnOrName) -> Column:
@@ -247,7 +251,7 @@ def bitwiseNOT(col: ColumnOrName) -> Column:
def bitwise_not(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.BitwiseNot)
+ return Column.invoke_expression_over_column(col, expression.BitwiseNot)
def asc_nulls_first(col: ColumnOrName) -> Column:
@@ -267,27 +271,27 @@ def desc_nulls_last(col: ColumnOrName) -> Column:
def stddev(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Stddev)
+ return Column.invoke_expression_over_column(col, expression.Stddev)
def stddev_samp(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.StddevSamp)
+ return Column.invoke_expression_over_column(col, expression.StddevSamp)
def stddev_pop(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.StddevPop)
+ return Column.invoke_expression_over_column(col, expression.StddevPop)
def variance(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Variance)
+ return Column.invoke_expression_over_column(col, expression.Variance)
def var_samp(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Variance)
+ return Column.invoke_expression_over_column(col, expression.Variance)
def var_pop(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.VariancePop)
+ return Column.invoke_expression_over_column(col, expression.VariancePop)
def skewness(col: ColumnOrName) -> Column:
@@ -299,11 +303,11 @@ def kurtosis(col: ColumnOrName) -> Column:
def collect_list(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.ArrayAgg)
+ return Column.invoke_expression_over_column(col, expression.ArrayAgg)
def collect_set(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.SetAgg)
+ return Column.invoke_expression_over_column(col, expression.SetAgg)
def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
@@ -311,27 +315,27 @@ def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]
def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
- return Column.invoke_expression_over_column(col1, glotexp.Pow, expression=col2)
+ return Column.invoke_expression_over_column(col1, expression.Pow, expression=col2)
def row_number() -> Column:
- return Column(glotexp.Anonymous(this="ROW_NUMBER"))
+ return Column(expression.Anonymous(this="ROW_NUMBER"))
def dense_rank() -> Column:
- return Column(glotexp.Anonymous(this="DENSE_RANK"))
+ return Column(expression.Anonymous(this="DENSE_RANK"))
def rank() -> Column:
- return Column(glotexp.Anonymous(this="RANK"))
+ return Column(expression.Anonymous(this="RANK"))
def cume_dist() -> Column:
- return Column(glotexp.Anonymous(this="CUME_DIST"))
+ return Column(expression.Anonymous(this="CUME_DIST"))
def percent_rank() -> Column:
- return Column(glotexp.Anonymous(this="PERCENT_RANK"))
+ return Column(expression.Anonymous(this="PERCENT_RANK"))
def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
@@ -340,14 +344,16 @@ def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Col
def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
if rsd is None:
- return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct)
- return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=rsd)
+ return Column.invoke_expression_over_column(col, expression.ApproxDistinct)
+ return Column.invoke_expression_over_column(col, expression.ApproxDistinct, accuracy=rsd)
def coalesce(*cols: ColumnOrName) -> Column:
if len(cols) > 1:
- return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce, expressions=cols[1:])
- return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce)
+ return Column.invoke_expression_over_column(
+ cols[0], expression.Coalesce, expressions=cols[1:]
+ )
+ return Column.invoke_expression_over_column(cols[0], expression.Coalesce)
def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
@@ -409,10 +415,10 @@ def percentile_approx(
) -> Column:
if accuracy:
return Column.invoke_expression_over_column(
- col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
+ col, expression.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
)
return Column.invoke_expression_over_column(
- col, glotexp.ApproxQuantile, quantile=lit(percentage)
+ col, expression.ApproxQuantile, quantile=lit(percentage)
)
@@ -426,8 +432,8 @@ def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
if scale is not None:
- return Column.invoke_expression_over_column(col, glotexp.Round, decimals=scale)
- return Column.invoke_expression_over_column(col, glotexp.Round)
+ return Column.invoke_expression_over_column(col, expression.Round, decimals=scale)
+ return Column.invoke_expression_over_column(col, expression.Round)
def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
@@ -437,7 +443,9 @@ def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
def shiftleft(col: ColumnOrName, numBits: int) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.BitwiseLeftShift, expression=numBits)
+ return Column.invoke_expression_over_column(
+ col, expression.BitwiseLeftShift, expression=numBits
+ )
def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
@@ -445,7 +453,9 @@ def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
def shiftright(col: ColumnOrName, numBits: int) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.BitwiseRightShift, expression=numBits)
+ return Column.invoke_expression_over_column(
+ col, expression.BitwiseRightShift, expression=numBits
+ )
def shiftRight(col: ColumnOrName, numBits: int) -> Column:
@@ -466,7 +476,7 @@ def expr(str: str) -> Column:
def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column:
columns = ensure_list(col) + list(cols)
- return Column.invoke_expression_over_column(None, glotexp.Struct, expressions=columns)
+ return Column.invoke_expression_over_column(None, expression.Struct, expressions=columns)
def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column:
@@ -512,19 +522,19 @@ def ntile(n: int) -> Column:
def current_date() -> Column:
- return Column.invoke_expression_over_column(None, glotexp.CurrentDate)
+ return Column.invoke_expression_over_column(None, expression.CurrentDate)
def current_timestamp() -> Column:
- return Column.invoke_expression_over_column(None, glotexp.CurrentTimestamp)
+ return Column.invoke_expression_over_column(None, expression.CurrentTimestamp)
def date_format(col: ColumnOrName, format: str) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.TimeToStr, format=lit(format))
+ return Column.invoke_expression_over_column(col, expression.TimeToStr, format=lit(format))
def year(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Year)
+ return Column.invoke_expression_over_column(col, expression.Year)
def quarter(col: ColumnOrName) -> Column:
@@ -532,19 +542,19 @@ def quarter(col: ColumnOrName) -> Column:
def month(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Month)
+ return Column.invoke_expression_over_column(col, expression.Month)
def dayofweek(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DayOfWeek)
+ return Column.invoke_expression_over_column(col, expression.DayOfWeek)
def dayofmonth(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DayOfMonth)
+ return Column.invoke_expression_over_column(col, expression.DayOfMonth)
def dayofyear(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DayOfYear)
+ return Column.invoke_expression_over_column(col, expression.DayOfYear)
def hour(col: ColumnOrName) -> Column:
@@ -560,7 +570,7 @@ def second(col: ColumnOrName) -> Column:
def weekofyear(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.WeekOfYear)
+ return Column.invoke_expression_over_column(col, expression.WeekOfYear)
def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
@@ -568,15 +578,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=days)
+ return Column.invoke_expression_over_column(col, expression.DateAdd, expression=days)
def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=days)
+ return Column.invoke_expression_over_column(col, expression.DateSub, expression=days)
def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=start)
+ return Column.invoke_expression_over_column(end, expression.DateDiff, expression=start)
def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
@@ -593,8 +603,10 @@ def months_between(
def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
if format is not None:
- return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate, format=lit(format))
- return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate)
+ return Column.invoke_expression_over_column(
+ col, expression.TsOrDsToDate, format=lit(format)
+ )
+ return Column.invoke_expression_over_column(col, expression.TsOrDsToDate)
def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
@@ -604,11 +616,13 @@ def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
def trunc(col: ColumnOrName, format: str) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format))
+ return Column.invoke_expression_over_column(col, expression.DateTrunc, unit=lit(format))
def date_trunc(format: str, timestamp: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format))
+ return Column.invoke_expression_over_column(
+ timestamp, expression.TimestampTrunc, unit=lit(format)
+ )
def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
@@ -621,8 +635,8 @@ def last_day(col: ColumnOrName) -> Column:
def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
if format is not None:
- return Column.invoke_expression_over_column(col, glotexp.UnixToStr, format=lit(format))
- return Column.invoke_expression_over_column(col, glotexp.UnixToStr)
+ return Column.invoke_expression_over_column(col, expression.UnixToStr, format=lit(format))
+ return Column.invoke_expression_over_column(col, expression.UnixToStr)
def unix_timestamp(
@@ -630,9 +644,9 @@ def unix_timestamp(
) -> Column:
if format is not None:
return Column.invoke_expression_over_column(
- timestamp, glotexp.StrToUnix, format=lit(format)
+ timestamp, expression.StrToUnix, format=lit(format)
)
- return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
+ return Column.invoke_expression_over_column(timestamp, expression.StrToUnix)
def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
@@ -719,11 +733,11 @@ def raise_error(errorMsg: ColumnOrName) -> Column:
def upper(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Upper)
+ return Column.invoke_expression_over_column(col, expression.Upper)
def lower(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Lower)
+ return Column.invoke_expression_over_column(col, expression.Lower)
def ascii(col: ColumnOrLiteral) -> Column:
@@ -747,24 +761,24 @@ def rtrim(col: ColumnOrName) -> Column:
def trim(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Trim)
+ return Column.invoke_expression_over_column(col, expression.Trim)
def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(
- None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)
+ None, expression.ConcatWs, expressions=[lit(sep)] + list(cols)
)
def decode(col: ColumnOrName, charset: str) -> Column:
return Column.invoke_expression_over_column(
- col, glotexp.Decode, charset=glotexp.Literal.string(charset)
+ col, expression.Decode, charset=expression.Literal.string(charset)
)
def encode(col: ColumnOrName, charset: str) -> Column:
return Column.invoke_expression_over_column(
- col, glotexp.Encode, charset=glotexp.Literal.string(charset)
+ col, expression.Encode, charset=expression.Literal.string(charset)
)
@@ -816,16 +830,16 @@ def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(left, glotexp.Levenshtein, expression=right)
+ return Column.invoke_expression_over_column(left, expression.Levenshtein, expression=right)
def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
substr_col = lit(substr)
if pos is not None:
return Column.invoke_expression_over_column(
- str, glotexp.StrPosition, substr=substr_col, position=pos
+ str, expression.StrPosition, substr=substr_col, position=pos
)
- return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
+ return Column.invoke_expression_over_column(str, expression.StrPosition, substr=substr_col)
def lpad(col: ColumnOrName, len: int, pad: str) -> Column:
@@ -837,21 +851,26 @@ def rpad(col: ColumnOrName, len: int, pad: str) -> Column:
def repeat(col: ColumnOrName, n: int) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Repeat, times=lit(n))
+ return Column.invoke_expression_over_column(col, expression.Repeat, times=lit(n))
def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
if limit is not None:
return Column.invoke_expression_over_column(
- str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=limit
+ str, expression.RegexpSplit, expression=lit(pattern).expression, limit=limit
)
- return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern))
+ return Column.invoke_expression_over_column(
+ str, expression.RegexpSplit, expression=lit(pattern)
+ )
def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column:
- if idx is not None:
- return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern), idx)
- return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern))
+ return Column.invoke_expression_over_column(
+ str,
+ expression.RegexpExtract,
+ expression=lit(pattern),
+ group=idx,
+ )
def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:
@@ -859,7 +878,7 @@ def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:
def initcap(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Initcap)
+ return Column.invoke_expression_over_column(col, expression.Initcap)
def soundex(col: ColumnOrName) -> Column:
@@ -871,15 +890,15 @@ def bin(col: ColumnOrName) -> Column:
def hex(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Hex)
+ return Column.invoke_expression_over_column(col, expression.Hex)
def unhex(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Unhex)
+ return Column.invoke_expression_over_column(col, expression.Unhex)
def length(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Length)
+ return Column.invoke_expression_over_column(col, expression.Length)
def octet_length(col: ColumnOrName) -> Column:
@@ -896,27 +915,27 @@ def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column:
def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
columns = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols
- return Column.invoke_expression_over_column(None, glotexp.Array, expressions=columns)
+ return Column.invoke_expression_over_column(None, expression.Array, expressions=columns)
def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
return Column.invoke_expression_over_column(
None,
- glotexp.VarMap,
+ expression.VarMap,
keys=array(*cols[::2]).expression,
values=array(*cols[1::2]).expression,
)
def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(None, glotexp.Map, keys=col1, values=col2)
+ return Column.invoke_expression_over_column(None, expression.Map, keys=col1, values=col2)
def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
value_col = value if isinstance(value, Column) else lit(value)
return Column.invoke_expression_over_column(
- col, glotexp.ArrayContains, expression=value_col.expression
+ col, expression.ArrayContains, expression=value_col.expression
)
@@ -943,7 +962,7 @@ def array_join(
def concat(*cols: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols)
+ return Column.invoke_expression_over_column(None, expression.Concat, expressions=cols)
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
@@ -978,11 +997,11 @@ def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def explode(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Explode)
+ return Column.invoke_expression_over_column(col, expression.Explode)
def posexplode(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.Posexplode)
+ return Column.invoke_expression_over_column(col, expression.Posexplode)
def explode_outer(col: ColumnOrName) -> Column:
@@ -994,7 +1013,7 @@ def posexplode_outer(col: ColumnOrName) -> Column:
def get_json_object(col: ColumnOrName, path: str) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path))
+ return Column.invoke_expression_over_column(col, expression.JSONExtract, path=lit(path))
def json_tuple(col: ColumnOrName, *fields: str) -> Column:
@@ -1042,7 +1061,7 @@ def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> C
def size(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.ArraySize)
+ return Column.invoke_expression_over_column(col, expression.ArraySize)
def array_min(col: ColumnOrName) -> Column:
@@ -1055,8 +1074,8 @@ def array_max(col: ColumnOrName) -> Column:
def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
if asc is not None:
- return Column.invoke_expression_over_column(col, glotexp.SortArray, asc=asc)
- return Column.invoke_expression_over_column(col, glotexp.SortArray)
+ return Column.invoke_expression_over_column(col, expression.SortArray, asc=asc)
+ return Column.invoke_expression_over_column(col, expression.SortArray)
def array_sort(
@@ -1065,8 +1084,10 @@ def array_sort(
) -> Column:
if comparator is not None:
f_expression = _get_lambda_from_func(comparator)
- return Column.invoke_expression_over_column(col, glotexp.ArraySort, expression=f_expression)
- return Column.invoke_expression_over_column(col, glotexp.ArraySort)
+ return Column.invoke_expression_over_column(
+ col, expression.ArraySort, expression=f_expression
+ )
+ return Column.invoke_expression_over_column(col, expression.ArraySort)
def shuffle(col: ColumnOrName) -> Column:
@@ -1146,13 +1167,13 @@ def aggregate(
finish_exp = _get_lambda_from_func(finish)
return Column.invoke_expression_over_column(
col,
- glotexp.Reduce,
+ expression.Reduce,
initial=initialValue,
merge=Column(merge_exp),
finish=Column(finish_exp),
)
return Column.invoke_expression_over_column(
- col, glotexp.Reduce, initial=initialValue, merge=Column(merge_exp)
+ col, expression.Reduce, initial=initialValue, merge=Column(merge_exp)
)
@@ -1179,7 +1200,9 @@ def filter(
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
) -> Column:
f_expression = _get_lambda_from_func(f)
- return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
+ return Column.invoke_expression_over_column(
+ col, expression.ArrayFilter, expression=f_expression
+ )
def zip_with(
@@ -1219,10 +1242,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]:
def _get_lambda_from_func(lambda_expression: t.Callable):
variables = [
- glotexp.to_identifier(x, quoted=_lambda_quoted(x))
+ expression.to_identifier(x, quoted=_lambda_quoted(x))
for x in lambda_expression.__code__.co_varnames
]
- return glotexp.Lambda(
+ return expression.Lambda(
this=lambda_expression(*[Column(x) for x in variables]).expression,
expressions=variables,
)