summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r--sqlglot/dataframe/sql/column.py10
-rw-r--r--sqlglot/dataframe/sql/functions.py252
2 files changed, 89 insertions, 173 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index 2391080..e66aaa8 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -5,7 +5,7 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.types import DataType
-from sqlglot.helper import flatten
+from sqlglot.helper import flatten, is_iterable
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnOrLiteral
@@ -134,10 +134,14 @@ class Column:
cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
) -> Column:
ensured_column = None if column is None else cls.ensure_col(column)
+ ensure_expression_values = {
+ k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression
+ for k, v in kwargs.items()
+ }
new_expression = (
- callable_expression(**kwargs)
+ callable_expression(**ensure_expression_values)
if ensured_column is None
- else callable_expression(this=ensured_column.column_expression, **kwargs)
+ else callable_expression(this=ensured_column.column_expression, **ensure_expression_values)
)
return Column(new_expression)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 4c6de30..bc002e5 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -1,7 +1,6 @@
from __future__ import annotations
import typing as t
-from inspect import signature
from sqlglot import expressions as glotexp
from sqlglot.dataframe.sql.column import Column
@@ -24,17 +23,15 @@ def lit(value: t.Optional[t.Any] = None) -> Column:
def greatest(*cols: ColumnOrName) -> Column:
- columns = [Column.ensure_col(col) for col in cols]
- return Column.invoke_expression_over_column(
- columns[0], glotexp.Greatest, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None
- )
+ if len(cols) > 1:
+ return Column.invoke_expression_over_column(cols[0], glotexp.Greatest, expressions=cols[1:])
+ return Column.invoke_expression_over_column(cols[0], glotexp.Greatest)
def least(*cols: ColumnOrName) -> Column:
- columns = [Column.ensure_col(col) for col in cols]
- return Column.invoke_expression_over_column(
- columns[0], glotexp.Least, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None
- )
+ if len(cols) > 1:
+ return Column.invoke_expression_over_column(cols[0], glotexp.Least, expressions=cols[1:])
+ return Column.invoke_expression_over_column(cols[0], glotexp.Least)
def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
@@ -194,7 +191,7 @@ def log2(col: ColumnOrName) -> Column:
def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
if arg2 is None:
return Column.invoke_expression_over_column(arg1, glotexp.Ln)
- return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=Column.ensure_col(arg2).expression)
+ return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=arg2)
def rint(col: ColumnOrName) -> Column:
@@ -310,7 +307,7 @@ def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]
def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
- return Column.invoke_anonymous_function(col1, "POW", col2)
+ return Column.invoke_expression_over_column(col1, glotexp.Pow, power=col2)
def row_number() -> Column:
@@ -340,14 +337,13 @@ def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Col
def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
if rsd is None:
return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct)
- return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=Column.ensure_col(rsd).expression)
+ return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=rsd)
def coalesce(*cols: ColumnOrName) -> Column:
- columns = [Column.ensure_col(col) for col in cols]
- return Column.invoke_expression_over_column(
- columns[0], glotexp.Coalesce, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None
- )
+ if len(cols) > 1:
+ return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce, expressions=cols[1:])
+ return Column.invoke_expression_over_column(cols[0], glotexp.Coalesce)
def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
@@ -405,11 +401,13 @@ def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def percentile_approx(
col: ColumnOrName,
percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
- accuracy: t.Optional[t.Union[ColumnOrLiteral]] = None,
+ accuracy: t.Optional[t.Union[ColumnOrLiteral, int]] = None,
) -> Column:
if accuracy:
- return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage, accuracy)
- return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage)
+ return Column.invoke_expression_over_column(
+ col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
+ )
+ return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage))
def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
@@ -422,7 +420,7 @@ def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
if scale is not None:
- return Column.invoke_expression_over_column(col, glotexp.Round, decimals=glotexp.convert(scale))
+ return Column.invoke_expression_over_column(col, glotexp.Round, decimals=scale)
return Column.invoke_expression_over_column(col, glotexp.Round)
@@ -433,9 +431,7 @@ def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
def shiftleft(col: ColumnOrName, numBits: int) -> Column:
- return Column.invoke_expression_over_column(
- col, glotexp.BitwiseLeftShift, expression=Column.ensure_col(numBits).expression
- )
+ return Column.invoke_expression_over_column(col, glotexp.BitwiseLeftShift, expression=numBits)
def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
@@ -443,9 +439,7 @@ def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
def shiftright(col: ColumnOrName, numBits: int) -> Column:
- return Column.invoke_expression_over_column(
- col, glotexp.BitwiseRightShift, expression=Column.ensure_col(numBits).expression
- )
+ return Column.invoke_expression_over_column(col, glotexp.BitwiseRightShift, expression=numBits)
def shiftRight(col: ColumnOrName, numBits: int) -> Column:
@@ -466,8 +460,7 @@ def expr(str: str) -> Column:
def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column:
columns = ensure_list(col) + list(cols)
- expressions = [Column.ensure_col(column).expression for column in columns]
- return Column(glotexp.Struct(expressions=expressions))
+ return Column.invoke_expression_over_column(None, glotexp.Struct, expressions=columns)
def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column:
@@ -515,7 +508,7 @@ def current_timestamp() -> Column:
def date_format(col: ColumnOrName, format: str) -> Column:
- return Column.invoke_anonymous_function(col, "DATE_FORMAT", lit(format))
+ return Column.invoke_expression_over_column(col, glotexp.TimeToStr, format=lit(format))
def year(col: ColumnOrName) -> Column:
@@ -563,15 +556,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=Column.ensure_col(days).expression)
+ return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=days)
def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=Column.ensure_col(days).expression)
+ return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=days)
def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=Column.ensure_col(start).expression)
+ return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=start)
def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
@@ -586,8 +579,8 @@ def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optiona
def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
if format is not None:
- return Column.invoke_anonymous_function(col, "TO_DATE", lit(format))
- return Column.invoke_anonymous_function(col, "TO_DATE")
+ return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate, format=lit(format))
+ return Column.invoke_expression_over_column(col, glotexp.TsOrDsToDate)
def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
@@ -597,11 +590,11 @@ def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
def trunc(col: ColumnOrName, format: str) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format).expression)
+ return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format))
def date_trunc(format: str, timestamp: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format).expression)
+ return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format))
def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
@@ -614,14 +607,14 @@ def last_day(col: ColumnOrName) -> Column:
def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
if format is not None:
- return Column.invoke_anonymous_function(col, "FROM_UNIXTIME", lit(format))
- return Column.invoke_anonymous_function(col, "FROM_UNIXTIME")
+ return Column.invoke_expression_over_column(col, glotexp.UnixToStr, format=lit(format))
+ return Column.invoke_expression_over_column(col, glotexp.UnixToStr)
def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
if format is not None:
- return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP", lit(format))
- return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP")
+ return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format))
+ return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
@@ -738,10 +731,7 @@ def trim(col: ColumnOrName) -> Column:
def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
- columns = [Column(col) for col in cols]
- return Column.invoke_expression_over_column(
- None, glotexp.ConcatWs, expressions=[x.expression for x in [lit(sep)] + list(columns)]
- )
+ return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols))
def decode(col: ColumnOrName, charset: str) -> Column:
@@ -798,18 +788,14 @@ def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(
- left, glotexp.Levenshtein, expression=Column.ensure_col(right).expression
- )
+ return Column.invoke_expression_over_column(left, glotexp.Levenshtein, expression=right)
def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
substr_col = lit(substr)
- pos_column = lit(pos)
- str_column = Column.ensure_col(str)
if pos is not None:
- return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column, pos_column)
- return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column)
+ return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos)
+ return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
def lpad(col: ColumnOrName, len: int, pad: str) -> Column:
@@ -821,15 +807,15 @@ def rpad(col: ColumnOrName, len: int, pad: str) -> Column:
def repeat(col: ColumnOrName, n: int) -> Column:
- return Column.invoke_anonymous_function(col, "REPEAT", n)
+ return Column.invoke_expression_over_column(col, glotexp.Repeat, times=lit(n))
def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
if limit is not None:
return Column.invoke_expression_over_column(
- str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=lit(limit).expression
+ str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=limit
)
- return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern).expression)
+ return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern))
def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column:
@@ -879,9 +865,8 @@ def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column:
def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
- cols = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
- cols = [Column.ensure_col(col).expression for col in cols] # type: ignore
- return Column.invoke_expression_over_column(None, glotexp.Array, expressions=cols)
+ columns = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols
+ return Column.invoke_expression_over_column(None, glotexp.Array, expressions=columns)
def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
@@ -892,7 +877,7 @@ def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col1, "MAP_FROM_ARRAYS", col2)
+ return Column.invoke_expression_over_column(None, glotexp.Map, keys=col1, values=col2)
def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
@@ -970,7 +955,7 @@ def posexplode_outer(col: ColumnOrName) -> Column:
def get_json_object(col: ColumnOrName, path: str) -> Column:
- return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path).expression)
+ return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path))
def json_tuple(col: ColumnOrName, *fields: str) -> Column:
@@ -1031,11 +1016,17 @@ def array_max(col: ColumnOrName) -> Column:
def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
if asc is not None:
- return Column.invoke_anonymous_function(col, "SORT_ARRAY", lit(asc))
- return Column.invoke_anonymous_function(col, "SORT_ARRAY")
+ return Column.invoke_expression_over_column(col, glotexp.SortArray, asc=asc)
+ return Column.invoke_expression_over_column(col, glotexp.SortArray)
-def array_sort(col: ColumnOrName) -> Column:
+def array_sort(
+ col: ColumnOrName,
+ comparator: t.Optional[t.Union[t.Callable[[Column, Column], Column]]] = None,
+) -> Column:
+ if comparator is not None:
+ f_expression = _get_lambda_from_func(comparator)
+ return Column.invoke_expression_over_column(col, glotexp.ArraySort, expression=f_expression)
return Column.invoke_expression_over_column(col, glotexp.ArraySort)
@@ -1108,130 +1099,53 @@ def aggregate(
initialValue: ColumnOrName,
merge: t.Callable[[Column, Column], Column],
finish: t.Optional[t.Callable[[Column], Column]] = None,
- accumulator_name: str = "acc",
- target_row_name: str = "x",
) -> Column:
- merge_exp = glotexp.Lambda(
- this=merge(Column(accumulator_name), Column(target_row_name)).expression,
- expressions=[
- glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name)),
- glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name)),
- ],
- )
+ merge_exp = _get_lambda_from_func(merge)
if finish is not None:
- finish_exp = glotexp.Lambda(
- this=finish(Column(accumulator_name)).expression,
- expressions=[glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name))],
- )
+ finish_exp = _get_lambda_from_func(finish)
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
def transform(
- col: ColumnOrName,
- f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
- target_row_name: str = "x",
- row_count_name: str = "i",
+ col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]
) -> Column:
- num_arguments = len(signature(f).parameters)
- expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))]
- columns = [Column(target_row_name)]
- if num_arguments > 1:
- columns.append(Column(row_count_name))
- expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name)))
-
- f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions)
+ f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
-def exists(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column:
- f_expression = glotexp.Lambda(
- this=f(Column(target_row_name)).expression,
- expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))],
- )
+def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
+ f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression))
-def forall(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column:
- f_expression = glotexp.Lambda(
- this=f(Column(target_row_name)).expression,
- expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))],
- )
-
+def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
+ f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
-def filter(
- col: ColumnOrName,
- f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
- target_row_name: str = "x",
- row_count_name: str = "i",
-) -> Column:
- num_arguments = len(signature(f).parameters)
- expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))]
- columns = [Column(target_row_name)]
- if num_arguments > 1:
- columns.append(Column(row_count_name))
- expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name)))
-
- f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions)
- return Column.invoke_anonymous_function(col, "FILTER", Column(f_expression))
-
-
-def zip_with(
- left: ColumnOrName,
- right: ColumnOrName,
- f: t.Callable[[Column, Column], Column],
- left_name: str = "x",
- right_name: str = "y",
-) -> Column:
- f_expression = glotexp.Lambda(
- this=f(Column(left_name), Column(right_name)).expression,
- expressions=[
- glotexp.to_identifier(left_name, quoted=_lambda_quoted(left_name)),
- glotexp.to_identifier(right_name, quoted=_lambda_quoted(right_name)),
- ],
- )
+def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column:
+ f_expression = _get_lambda_from_func(f)
+ return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
+
+def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
+ f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
-def transform_keys(
- col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v"
-) -> Column:
- f_expression = glotexp.Lambda(
- this=f(Column(key_name), Column(value_name)).expression,
- expressions=[
- glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
- glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
- ],
- )
+def transform_keys(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column:
+ f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression))
-def transform_values(
- col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v"
-) -> Column:
- f_expression = glotexp.Lambda(
- this=f(Column(key_name), Column(value_name)).expression,
- expressions=[
- glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
- glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
- ],
- )
+def transform_values(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column:
+ f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression))
-def map_filter(
- col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v"
-) -> Column:
- f_expression = glotexp.Lambda(
- this=f(Column(key_name), Column(value_name)).expression,
- expressions=[
- glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
- glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
- ],
- )
+def map_filter(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column:
+ f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression))
@@ -1239,20 +1153,18 @@ def map_zip_with(
col1: ColumnOrName,
col2: ColumnOrName,
f: t.Union[t.Callable[[Column, Column, Column], Column]],
- key_name: str = "k",
- value1: str = "v1",
- value2: str = "v2",
) -> Column:
- f_expression = glotexp.Lambda(
- this=f(Column(key_name), Column(value1), Column(value2)).expression,
- expressions=[
- glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
- glotexp.to_identifier(value1, quoted=_lambda_quoted(value1)),
- glotexp.to_identifier(value2, quoted=_lambda_quoted(value2)),
- ],
- )
+ f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression))
def _lambda_quoted(value: str) -> t.Optional[bool]:
return False if value == "_" else None
+
+
+def _get_lambda_from_func(lambda_expression: t.Callable):
+ variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames]
+ return glotexp.Lambda(
+ this=lambda_expression(*[Column(x) for x in variables]).expression,
+ expressions=variables,
+ )