summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe/sql')
-rw-r--r--sqlglot/dataframe/sql/__init__.py18
-rw-r--r--sqlglot/dataframe/sql/_typing.pyi20
-rw-r--r--sqlglot/dataframe/sql/column.py295
-rw-r--r--sqlglot/dataframe/sql/dataframe.py730
-rw-r--r--sqlglot/dataframe/sql/functions.py1258
-rw-r--r--sqlglot/dataframe/sql/group.py57
-rw-r--r--sqlglot/dataframe/sql/normalize.py72
-rw-r--r--sqlglot/dataframe/sql/operations.py53
-rw-r--r--sqlglot/dataframe/sql/readwriter.py79
-rw-r--r--sqlglot/dataframe/sql/session.py148
-rw-r--r--sqlglot/dataframe/sql/transforms.py9
-rw-r--r--sqlglot/dataframe/sql/types.py208
-rw-r--r--sqlglot/dataframe/sql/util.py32
-rw-r--r--sqlglot/dataframe/sql/window.py117
14 files changed, 3096 insertions, 0 deletions
diff --git a/sqlglot/dataframe/sql/__init__.py b/sqlglot/dataframe/sql/__init__.py
new file mode 100644
index 0000000..3f90802
--- /dev/null
+++ b/sqlglot/dataframe/sql/__init__.py
@@ -0,0 +1,18 @@
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
+from sqlglot.dataframe.sql.group import GroupedData
+from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql.window import Window, WindowSpec
+
+__all__ = [
+ "SparkSession",
+ "DataFrame",
+ "GroupedData",
+ "Column",
+ "DataFrameNaFunctions",
+ "Window",
+ "WindowSpec",
+ "DataFrameReader",
+ "DataFrameWriter",
+]
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi
new file mode 100644
index 0000000..f1a03ea
--- /dev/null
+++ b/sqlglot/dataframe/sql/_typing.pyi
@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+import datetime
+import typing as t
+
+from sqlglot import expressions as exp
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.column import Column
+ from sqlglot.dataframe.sql.types import StructType
+
+ColumnLiterals = t.TypeVar(
+ "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+)
+ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
+ColumnOrLiteral = t.TypeVar(
+ "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+)
+SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
+OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
new file mode 100644
index 0000000..2391080
--- /dev/null
+++ b/sqlglot/dataframe/sql/column.py
@@ -0,0 +1,295 @@
+from __future__ import annotations
+
+import typing as t
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql.types import DataType
+from sqlglot.helper import flatten
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnOrLiteral
+ from sqlglot.dataframe.sql.window import WindowSpec
+
+
+class Column:
+ def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+ if isinstance(expression, Column):
+ expression = expression.expression # type: ignore
+ elif expression is None or not isinstance(expression, (str, exp.Expression)):
+ expression = self._lit(expression).expression # type: ignore
+ self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
+
+ def __repr__(self):
+ return repr(self.expression)
+
+ def __hash__(self):
+ return hash(self.expression)
+
+ def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore
+ return self.binary_op(exp.EQ, other)
+
+ def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore
+ return self.binary_op(exp.NEQ, other)
+
+ def __gt__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.GT, other)
+
+ def __ge__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.GTE, other)
+
+ def __lt__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.LT, other)
+
+ def __le__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.LTE, other)
+
+ def __and__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.And, other)
+
+ def __or__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Or, other)
+
+ def __mod__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Mod, other)
+
+ def __add__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Add, other)
+
+ def __sub__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Sub, other)
+
+ def __mul__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Mul, other)
+
+ def __truediv__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Div, other)
+
+ def __div__(self, other: ColumnOrLiteral) -> Column:
+ return self.binary_op(exp.Div, other)
+
+ def __neg__(self) -> Column:
+ return self.unary_op(exp.Neg)
+
+ def __radd__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Add, other)
+
+ def __rsub__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Sub, other)
+
+ def __rmul__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Mul, other)
+
+ def __rdiv__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Div, other)
+
+ def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Div, other)
+
+ def __rmod__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Mod, other)
+
+ def __pow__(self, power: ColumnOrLiteral, modulo=None):
+ return Column(exp.Pow(this=self.expression, power=Column(power).expression))
+
+ def __rpow__(self, power: ColumnOrLiteral):
+ return Column(exp.Pow(this=Column(power).expression, power=self.expression))
+
+ def __invert__(self):
+ return self.unary_op(exp.Not)
+
+ def __rand__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.And, other)
+
+ def __ror__(self, other: ColumnOrLiteral) -> Column:
+ return self.inverse_binary_op(exp.Or, other)
+
+ @classmethod
+ def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+ return cls(value)
+
+ @classmethod
+ def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
+ return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
+
+ @classmethod
+ def _lit(cls, value: ColumnOrLiteral) -> Column:
+ if isinstance(value, dict):
+ columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
+ return cls(exp.Struct(expressions=columns))
+ return cls(exp.convert(value))
+
+ @classmethod
+ def invoke_anonymous_function(
+ cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
+ ) -> Column:
+ columns = [] if column is None else [cls.ensure_col(column)]
+ column_args = [cls.ensure_col(arg) for arg in args]
+ expressions = [x.expression for x in columns + column_args]
+ new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
+ return Column(new_expression)
+
+ @classmethod
+ def invoke_expression_over_column(
+ cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
+ ) -> Column:
+ ensured_column = None if column is None else cls.ensure_col(column)
+ new_expression = (
+ callable_expression(**kwargs)
+ if ensured_column is None
+ else callable_expression(this=ensured_column.column_expression, **kwargs)
+ )
+ return Column(new_expression)
+
+ def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
+ return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
+
+ def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
+ return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
+
+ def unary_op(self, klass: t.Callable, **kwargs) -> Column:
+ return Column(klass(this=self.column_expression, **kwargs))
+
+ @property
+ def is_alias(self):
+ return isinstance(self.expression, exp.Alias)
+
+ @property
+ def is_column(self):
+ return isinstance(self.expression, exp.Column)
+
+ @property
+ def column_expression(self) -> exp.Column:
+ return self.expression.unalias()
+
+ @property
+ def alias_or_name(self) -> str:
+ return self.expression.alias_or_name
+
+ @classmethod
+ def ensure_literal(cls, value) -> Column:
+ from sqlglot.dataframe.sql.functions import lit
+
+ if isinstance(value, cls):
+ value = value.expression
+ if not isinstance(value, exp.Literal):
+ return lit(value)
+ return Column(value)
+
+ def copy(self) -> Column:
+ return Column(self.expression.copy())
+
+ def set_table_name(self, table_name: str, copy=False) -> Column:
+ expression = self.expression.copy() if copy else self.expression
+ expression.set("table", exp.to_identifier(table_name))
+ return Column(expression)
+
+ def sql(self, **kwargs) -> Column:
+ return self.expression.sql(**{"dialect": "spark", **kwargs})
+
+ def alias(self, name: str) -> Column:
+ new_expression = exp.alias_(self.column_expression, name)
+ return Column(new_expression)
+
+ def asc(self) -> Column:
+ new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
+ return Column(new_expression)
+
+ def desc(self) -> Column:
+ new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
+ return Column(new_expression)
+
+ asc_nulls_first = asc
+
+ def asc_nulls_last(self) -> Column:
+ new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
+ return Column(new_expression)
+
+ def desc_nulls_first(self) -> Column:
+ new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
+ return Column(new_expression)
+
+ desc_nulls_last = desc
+
+ def when(self, condition: Column, value: t.Any) -> Column:
+ from sqlglot.dataframe.sql.functions import when
+
+ column_with_if = when(condition, value)
+ if not isinstance(self.expression, exp.Case):
+ return column_with_if
+ new_column = self.copy()
+ new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
+ return new_column
+
+ def otherwise(self, value: t.Any) -> Column:
+ from sqlglot.dataframe.sql.functions import lit
+
+ true_value = value if isinstance(value, Column) else lit(value)
+ new_column = self.copy()
+ new_column.expression.set("default", true_value.column_expression)
+ return new_column
+
+ def isNull(self) -> Column:
+ new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
+ return Column(new_expression)
+
+ def isNotNull(self) -> Column:
+ new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
+ return Column(new_expression)
+
+ def cast(self, dataType: t.Union[str, DataType]):
+ """
+ Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
+ Sqlglot doesn't currently replicate this class so it only accepts a string
+ """
+ if isinstance(dataType, DataType):
+ dataType = dataType.simpleString()
+ new_expression = exp.Cast(this=self.column_expression, to=dataType)
+ return Column(new_expression)
+
+ def startswith(self, value: t.Union[str, Column]) -> Column:
+ value = self._lit(value) if not isinstance(value, Column) else value
+ return self.invoke_anonymous_function(self, "STARTSWITH", value)
+
+ def endswith(self, value: t.Union[str, Column]) -> Column:
+ value = self._lit(value) if not isinstance(value, Column) else value
+ return self.invoke_anonymous_function(self, "ENDSWITH", value)
+
+ def rlike(self, regexp: str) -> Column:
+ return self.invoke_expression_over_column(
+ column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
+ )
+
+ def like(self, other: str):
+ return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
+
+ def ilike(self, other: str):
+ return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
+
+ def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
+ startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
+ length = self._lit(length) if not isinstance(length, Column) else length
+ return Column.invoke_expression_over_column(
+ self, exp.Substring, start=startPos.expression, length=length.expression
+ )
+
+ def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
+ columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
+ expressions = [self._lit(x).expression for x in columns]
+ return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore
+
+ def between(
+ self,
+ lowerBound: t.Union[ColumnOrLiteral],
+ upperBound: t.Union[ColumnOrLiteral],
+ ) -> Column:
+ lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
+ upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+ return Column(
+ exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
+ )
+
+ def over(self, window: WindowSpec) -> Column:
+ window_expression = window.expression.copy()
+ window_expression.set("this", self.column_expression)
+ return Column(window_expression)
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
new file mode 100644
index 0000000..322dcf2
--- /dev/null
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -0,0 +1,730 @@
+from __future__ import annotations
+
+import functools
+import typing as t
+import zlib
+from copy import copy
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.group import GroupedData
+from sqlglot.dataframe.sql.normalize import normalize
+from sqlglot.dataframe.sql.operations import Operation, operation
+from sqlglot.dataframe.sql.readwriter import DataFrameWriter
+from sqlglot.dataframe.sql.transforms import replace_id_value
+from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.dataframe.sql.window import Window
+from sqlglot.helper import ensure_list, object_to_dict
+from sqlglot.optimizer import optimize as optimize_func
+from sqlglot.optimizer.qualify_columns import qualify_columns
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+ from sqlglot.dataframe.sql.session import SparkSession
+
+
+JOIN_HINTS = {
+ "BROADCAST",
+ "BROADCASTJOIN",
+ "MAPJOIN",
+ "MERGE",
+ "SHUFFLEMERGE",
+ "MERGEJOIN",
+ "SHUFFLE_HASH",
+ "SHUFFLE_REPLICATE_NL",
+}
+
+
+class DataFrame:
+ def __init__(
+ self,
+ spark: SparkSession,
+ expression: exp.Select,
+ branch_id: t.Optional[str] = None,
+ sequence_id: t.Optional[str] = None,
+ last_op: Operation = Operation.INIT,
+ pending_hints: t.Optional[t.List[exp.Expression]] = None,
+ output_expression_container: t.Optional[OutputExpressionContainer] = None,
+ **kwargs,
+ ):
+ self.spark = spark
+ self.expression = expression
+ self.branch_id = branch_id or self.spark._random_branch_id
+ self.sequence_id = sequence_id or self.spark._random_sequence_id
+ self.last_op = last_op
+ self.pending_hints = pending_hints or []
+ self.output_expression_container = output_expression_container or exp.Select()
+
+ def __getattr__(self, column_name: str) -> Column:
+ return self[column_name]
+
+ def __getitem__(self, column_name: str) -> Column:
+ column_name = f"{self.branch_id}.{column_name}"
+ return Column(column_name)
+
+ def __copy__(self):
+ return self.copy()
+
+ @property
+ def sparkSession(self):
+ return self.spark
+
+ @property
+ def write(self):
+ return DataFrameWriter(self)
+
+ @property
+ def latest_cte_name(self) -> str:
+ if not self.expression.ctes:
+ from_exp = self.expression.args["from"]
+ if from_exp.alias_or_name:
+ return from_exp.alias_or_name
+ table_alias = from_exp.find(exp.TableAlias)
+ if not table_alias:
+ raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+ return table_alias.alias_or_name
+ return self.expression.ctes[-1].alias
+
+ @property
+ def pending_join_hints(self):
+ return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
+
+ @property
+ def pending_partition_hints(self):
+ return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
+
+ @property
+ def columns(self) -> t.List[str]:
+ return self.expression.named_selects
+
+ @property
+ def na(self) -> DataFrameNaFunctions:
+ return DataFrameNaFunctions(self)
+
+ def _replace_cte_names_with_hashes(self, expression: exp.Select):
+ expression = expression.copy()
+ ctes = expression.ctes
+ replacement_mapping = {}
+ for cte in ctes:
+ old_name_id = cte.args["alias"].this
+ new_hashed_id = exp.to_identifier(
+ self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
+ )
+ replacement_mapping[old_name_id] = new_hashed_id
+ cte.set("alias", exp.TableAlias(this=new_hashed_id))
+ expression = expression.transform(replace_id_value, replacement_mapping)
+ return expression
+
+ def _create_cte_from_expression(
+ self,
+ expression: exp.Expression,
+ branch_id: t.Optional[str] = None,
+ sequence_id: t.Optional[str] = None,
+ **kwargs,
+ ) -> t.Tuple[exp.CTE, str]:
+ name = self.spark._random_name
+ expression_to_cte = expression.copy()
+ expression_to_cte.set("with", None)
+ cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
+ cte.set("branch_id", branch_id or self.branch_id)
+ cte.set("sequence_id", sequence_id or self.sequence_id)
+ return cte, name
+
+ def _ensure_list_of_columns(
+ self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
+ ) -> t.List[Column]:
+ columns = ensure_list(cols)
+ columns = Column.ensure_cols(columns)
+ return columns
+
+ def _ensure_and_normalize_cols(self, cols):
+ cols = self._ensure_list_of_columns(cols)
+ normalize(self.spark, self.expression, cols)
+ return cols
+
+ def _ensure_and_normalize_col(self, col):
+ col = Column.ensure_col(col)
+ normalize(self.spark, self.expression, col)
+ return col
+
+ def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
+ df = self._resolve_pending_hints()
+ sequence_id = sequence_id or df.sequence_id
+ expression = df.expression.copy()
+ cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
+ new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+ sel_columns = df._get_outer_select_columns(cte_expression)
+ new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+ return df.copy(expression=new_expression, sequence_id=sequence_id)
+
+ def _resolve_pending_hints(self) -> DataFrame:
+ df = self.copy()
+ if not self.pending_hints:
+ return df
+ expression = df.expression
+ hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
+ for hint in df.pending_partition_hints:
+ hint_expression.args.get("expressions").append(hint)
+ df.pending_hints.remove(hint)
+
+ join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+ if join_aliases:
+ for hint in df.pending_join_hints:
+ for sequence_id_expression in hint.expressions:
+ sequence_id_or_name = sequence_id_expression.alias_or_name
+ sequence_ids_to_match = [sequence_id_or_name]
+ if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
+ sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+ matching_ctes = [
+ cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+ ]
+ for matching_cte in matching_ctes:
+ if matching_cte.alias_or_name in join_aliases:
+ sequence_id_expression.set("this", matching_cte.args["alias"].this)
+ df.pending_hints.remove(hint)
+ break
+ hint_expression.args.get("expressions").append(hint)
+ if hint_expression.expressions:
+ expression.set("hint", hint_expression)
+ return df
+
+ def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
+ hint_name = hint_name.upper()
+ hint_expression = (
+ exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+ if hint_name in JOIN_HINTS
+ else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+ )
+ new_df = self.copy()
+ new_df.pending_hints.append(hint_expression)
+ return new_df
+
+ def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
+ other_df = other._convert_leaf_to_cte()
+ base_expression = self.expression.copy()
+ base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
+ all_ctes = base_expression.ctes
+ other_df.expression.set("with", None)
+ base_expression.set("with", None)
+ operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
+ operation.set("with", exp.With(expressions=all_ctes))
+ return self.copy(expression=operation)._convert_leaf_to_cte()
+
+ def _cache(self, storage_level: str):
+ df = self._convert_leaf_to_cte()
+ df.expression.ctes[-1].set("cache_storage_level", storage_level)
+ return df
+
+ @classmethod
+ def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
+ expression = expression.copy()
+ with_expression = expression.args.get("with")
+ if with_expression:
+ existing_ctes = with_expression.expressions
+ existsing_cte_names = {x.alias_or_name for x in existing_ctes}
+ for cte in ctes:
+ if cte.alias_or_name not in existsing_cte_names:
+ existing_ctes.append(cte)
+ else:
+ existing_ctes = ctes
+ expression.set("with", exp.With(expressions=existing_ctes))
+ return expression
+
+ @classmethod
+ def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
+ expression = item.expression if isinstance(item, DataFrame) else item
+ return [Column(x) for x in expression.find(exp.Select).expressions]
+
+ @classmethod
+ def _create_hash_from_expression(cls, expression: exp.Select):
+ value = expression.sql(dialect="spark").encode("utf-8")
+ return f"t{zlib.crc32(value)}"[:6]
+
+ def _get_select_expressions(
+ self,
+ ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
+ select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
+ main_select_ctes: t.List[exp.CTE] = []
+ for cte in self.expression.ctes:
+ cache_storage_level = cte.args.get("cache_storage_level")
+ if cache_storage_level:
+ select_expression = cte.this.copy()
+ select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
+ select_expression.set("cte_alias_name", cte.alias_or_name)
+ select_expression.set("cache_storage_level", cache_storage_level)
+ select_expressions.append((exp.Cache, select_expression))
+ else:
+ main_select_ctes.append(cte)
+ main_select = self.expression.copy()
+ if main_select_ctes:
+ main_select.set("with", exp.With(expressions=main_select_ctes))
+ expression_select_pair = (type(self.output_expression_container), main_select)
+ select_expressions.append(expression_select_pair) # type: ignore
+ return select_expressions
+
+ def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
+ df = self._resolve_pending_hints()
+ select_expressions = df._get_select_expressions()
+ output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
+ replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
+ for expression_type, select_expression in select_expressions:
+ select_expression = select_expression.transform(replace_id_value, replacement_mapping)
+ if optimize:
+ select_expression = optimize_func(select_expression)
+ select_expression = df._replace_cte_names_with_hashes(select_expression)
+ expression: t.Union[exp.Select, exp.Cache, exp.Drop]
+ if expression_type == exp.Cache:
+ cache_table_name = df._create_hash_from_expression(select_expression)
+ cache_table = exp.to_table(cache_table_name)
+ original_alias_name = select_expression.args["cte_alias_name"]
+ replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
+ sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
+ cache_storage_level = select_expression.args["cache_storage_level"]
+ options = [
+ exp.Literal.string("storageLevel"),
+ exp.Literal.string(cache_storage_level),
+ ]
+ expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
+ # We will drop the "view" if it exists before running the cache table
+ output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
+ elif expression_type == exp.Create:
+ expression = df.output_expression_container.copy()
+ expression.set("expression", select_expression)
+ elif expression_type == exp.Insert:
+ expression = df.output_expression_container.copy()
+ select_without_ctes = select_expression.copy()
+ select_without_ctes.set("with", None)
+ expression.set("expression", select_without_ctes)
+ if select_expression.ctes:
+ expression.set("with", exp.With(expressions=select_expression.ctes))
+ elif expression_type == exp.Select:
+ expression = select_expression
+ else:
+ raise ValueError(f"Invalid expression type: {expression_type}")
+ output_expressions.append(expression)
+
+ return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+
+ def copy(self, **kwargs) -> DataFrame:
+ return DataFrame(**object_to_dict(self, **kwargs))
+
+ @operation(Operation.SELECT)
+ def select(self, *cols, **kwargs) -> DataFrame:
+ cols = self._ensure_and_normalize_cols(cols)
+ kwargs["append"] = kwargs.get("append", False)
+ if self.expression.args.get("joins"):
+ ambiguous_cols = [col for col in cols if not col.column_expression.table]
+ if ambiguous_cols:
+ join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
+ cte_names_in_join = [x.this for x in join_table_identifiers]
+ for ambiguous_col in ambiguous_cols:
+ ctes_with_column = [
+ cte
+ for cte in self.expression.ctes
+ if cte.alias_or_name in cte_names_in_join
+ and ambiguous_col.alias_or_name in cte.this.named_selects
+ ]
+ # If the select column does not specify a table and there is a join
+ # then we assume they are referring to the left table
+ if len(ctes_with_column) > 1:
+ table_identifier = self.expression.args["from"].args["expressions"][0].this
+ else:
+ table_identifier = ctes_with_column[0].args["alias"].this
+ ambiguous_col.expression.set("table", table_identifier)
+ expression = self.expression.select(*[x.expression for x in cols], **kwargs)
+ qualify_columns(expression, sqlglot.schema)
+ return self.copy(expression=expression, **kwargs)
+
+ @operation(Operation.NO_OP)
+ def alias(self, name: str, **kwargs) -> DataFrame:
+ new_sequence_id = self.spark._random_sequence_id
+ df = self.copy()
+ for join_hint in df.pending_join_hints:
+ for expression in join_hint.expressions:
+ if expression.alias_or_name == self.sequence_id:
+ expression.set("this", Column.ensure_col(new_sequence_id).expression)
+ df.spark._add_alias_to_mapping(name, new_sequence_id)
+ return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
+
+ @operation(Operation.WHERE)
+ def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
+ col = self._ensure_and_normalize_col(column)
+ return self.copy(expression=self.expression.where(col.expression))
+
+ filter = where
+
+ @operation(Operation.GROUP_BY)
+ def groupBy(self, *cols, **kwargs) -> GroupedData:
+ columns = self._ensure_and_normalize_cols(cols)
+ return GroupedData(self, columns, self.last_op)
+
+ @operation(Operation.SELECT)
+ def agg(self, *exprs, **kwargs) -> DataFrame:
+ cols = self._ensure_and_normalize_cols(exprs)
+ return self.groupBy().agg(*cols)
+
+ @operation(Operation.FROM)
+ def join(
+ self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
+ ) -> DataFrame:
+ other_df = other_df._convert_leaf_to_cte()
+ pre_join_self_latest_cte_name = self.latest_cte_name
+ columns = self._ensure_and_normalize_cols(on)
+ join_type = how.replace("_", " ")
+ if isinstance(columns[0].expression, exp.Column):
+ join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
+ join_clause = functools.reduce(
+ lambda x, y: x & y,
+ [
+ col.copy().set_table_name(pre_join_self_latest_cte_name)
+ == col.copy().set_table_name(other_df.latest_cte_name)
+ for col in columns
+ ],
+ )
+ else:
+ if len(columns) > 1:
+ columns = [functools.reduce(lambda x, y: x & y, columns)]
+ join_clause = columns[0]
+ join_columns = [
+ Column(x).set_table_name(pre_join_self_latest_cte_name)
+ if i % 2 == 0
+ else Column(x).set_table_name(other_df.latest_cte_name)
+ for i, x in enumerate(join_clause.expression.find_all(exp.Column))
+ ]
+ self_columns = [
+ column.set_table_name(pre_join_self_latest_cte_name, copy=True)
+ for column in self._get_outer_select_columns(self)
+ ]
+ other_columns = [
+ column.set_table_name(other_df.latest_cte_name, copy=True)
+ for column in self._get_outer_select_columns(other_df)
+ ]
+ column_value_mapping = {
+ column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
+ for column in other_columns + self_columns + join_columns
+ }
+ all_columns = [
+ column_value_mapping[name]
+ for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
+ ]
+ new_df = self.copy(
+ expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
+ )
+ new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
+ new_df.pending_hints.extend(other_df.pending_hints)
+ new_df = new_df.select.__wrapped__(new_df, *all_columns)
+ return new_df
+
+ @operation(Operation.ORDER_BY)
+ def orderBy(
+ self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
+ ) -> DataFrame:
+ """
+ This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
+ has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
+ is unlikely to come up.
+ """
+ columns = self._ensure_and_normalize_cols(cols)
+ pre_ordered_col_indexes = [
+ x
+ for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
+ if x is not None
+ ]
+ if ascending is None:
+ ascending = [True] * len(columns)
+ elif not isinstance(ascending, list):
+ ascending = [ascending] * len(columns)
+ ascending = [bool(x) for i, x in enumerate(ascending)]
+ assert len(columns) == len(
+ ascending
+ ), "The length of items in ascending must equal the number of columns provided"
+ col_and_ascending = list(zip(columns, ascending))
+ order_by_columns = [
+ exp.Ordered(this=col.expression, desc=not asc)
+ if i not in pre_ordered_col_indexes
+ else columns[i].column_expression
+ for i, (col, asc) in enumerate(col_and_ascending)
+ ]
+ return self.copy(expression=self.expression.order_by(*order_by_columns))
+
+ sort = orderBy
+
+ @operation(Operation.FROM)
+ def union(self, other: DataFrame) -> DataFrame:
+ return self._set_operation(exp.Union, other, False)
+
+ unionAll = union
+
+ @operation(Operation.FROM)
+ def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
+ l_columns = self.columns
+ r_columns = other.columns
+ if not allowMissingColumns:
+ l_expressions = l_columns
+ r_expressions = l_columns
+ else:
+ l_expressions = []
+ r_expressions = []
+ r_columns_unused = copy(r_columns)
+ for l_column in l_columns:
+ l_expressions.append(l_column)
+ if l_column in r_columns:
+ r_expressions.append(l_column)
+ r_columns_unused.remove(l_column)
+ else:
+ r_expressions.append(exp.alias_(exp.Null(), l_column))
+ for r_column in r_columns_unused:
+ l_expressions.append(exp.alias_(exp.Null(), r_column))
+ r_expressions.append(r_column)
+ r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+ l_df = self.copy()
+ if allowMissingColumns:
+ l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
+ return l_df._set_operation(exp.Union, r_df, False)
+
+ @operation(Operation.FROM)
+ def intersect(self, other: DataFrame) -> DataFrame:
+ return self._set_operation(exp.Intersect, other, True)
+
+ @operation(Operation.FROM)
+ def intersectAll(self, other: DataFrame) -> DataFrame:
+ return self._set_operation(exp.Intersect, other, False)
+
+ @operation(Operation.FROM)
+ def exceptAll(self, other: DataFrame) -> DataFrame:
+ return self._set_operation(exp.Except, other, False)
+
+ @operation(Operation.SELECT)
+ def distinct(self) -> DataFrame:
+ return self.copy(expression=self.expression.distinct())
+
+ @operation(Operation.SELECT)
+ def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
+ if not subset:
+ return self.distinct()
+ column_names = ensure_list(subset)
+ window = Window.partitionBy(*column_names).orderBy(*column_names)
+ return (
+ self.copy()
+ .withColumn("row_num", F.row_number().over(window))
+ .where(F.col("row_num") == F.lit(1))
+ .drop("row_num")
+ )
+
+ @operation(Operation.FROM)
+ def dropna(
+ self,
+ how: str = "any",
+ thresh: t.Optional[int] = None,
+ subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+ ) -> DataFrame:
+ minimum_non_null = thresh or 0 # will be determined later if thresh is null
+ new_df = self.copy()
+ all_columns = self._get_outer_select_columns(new_df.expression)
+ if subset:
+ null_check_columns = self._ensure_and_normalize_cols(subset)
+ else:
+ null_check_columns = all_columns
+ if thresh is None:
+ minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
+ else:
+ minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
+ if minimum_num_nulls > len(null_check_columns):
+ raise RuntimeError(
+ f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
+ f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
+ )
+ if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
+ nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
+ num_nulls = nulls_added_together.alias("num_nulls")
+ new_df = new_df.select(num_nulls, append=True)
+ filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
+ final_df = filtered_df.select(*all_columns)
+ return final_df
+
+ @operation(Operation.FROM)
+ def fillna(
+ self,
+ value: t.Union[ColumnLiterals],
+ subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+ ) -> DataFrame:
+ """
+ Functionality Difference: If you provide a value to replace a null and that type conflicts
+ with the type of the column then PySpark will just ignore your replacement.
+ This will try to cast them to be the same in some cases. So they won't always match.
+ Best to not mix types so make sure replacement is the same type as the column
+
+ Possibility for improvement: Use `typeof` function to get the type of the column
+ and check if it matches the type of the value provided. If not then make it null.
+ """
+ from sqlglot.dataframe.sql.functions import lit
+
+ values = None
+ columns = None
+ new_df = self.copy()
+ all_columns = self._get_outer_select_columns(new_df.expression)
+ all_column_mapping = {column.alias_or_name: column for column in all_columns}
+ if isinstance(value, dict):
+ values = value.values()
+ columns = self._ensure_and_normalize_cols(list(value))
+ if not columns:
+ columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
+ if not values:
+ values = [value] * len(columns)
+ value_columns = [lit(value) for value in values]
+
+ null_replacement_mapping = {
+ column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
+ for column, value in zip(columns, value_columns)
+ }
+ null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
+ null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
+ new_df = new_df.select(*null_replacement_columns)
+ return new_df
+
+ @operation(Operation.FROM)
+ def replace(
+ self,
+ to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
+ value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
+ subset: t.Optional[t.Union[str, t.List[str]]] = None,
+ ) -> DataFrame:
+ from sqlglot.dataframe.sql.functions import lit
+
+ old_values = None
+ subset = ensure_list(subset)
+ new_df = self.copy()
+ all_columns = self._get_outer_select_columns(new_df.expression)
+ all_column_mapping = {column.alias_or_name: column for column in all_columns}
+
+ columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
+ if isinstance(to_replace, dict):
+ old_values = list(to_replace)
+ new_values = list(to_replace.values())
+ elif not old_values and isinstance(to_replace, list):
+ assert isinstance(value, list), "value must be a list since the replacements are a list"
+ assert len(to_replace) == len(value), "the replacements and values must be the same length"
+ old_values = to_replace
+ new_values = value
+ else:
+ old_values = [to_replace] * len(columns)
+ new_values = [value] * len(columns)
+ old_values = [lit(value) for value in old_values]
+ new_values = [lit(value) for value in new_values]
+
+ replacement_mapping = {}
+ for column in columns:
+ expression = Column(None)
+ for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
+ if i == 0:
+ expression = F.when(column == old_value, new_value)
+ else:
+ expression = expression.when(column == old_value, new_value) # type: ignore
+ replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
+ column.expression.alias_or_name
+ )
+
+ replacement_mapping = {**all_column_mapping, **replacement_mapping}
+ replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
+ new_df = new_df.select(*replacement_columns)
+ return new_df
+
+ @operation(Operation.SELECT)
+ def withColumn(self, colName: str, col: Column) -> DataFrame:
+ col = self._ensure_and_normalize_col(col)
+ existing_col_names = self.expression.named_selects
+ existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
+ if existing_col_index:
+ expression = self.expression.copy()
+ expression.expressions[existing_col_index] = col.expression
+ return self.copy(expression=expression)
+ return self.copy().select(col.alias(colName), append=True)
+
+ @operation(Operation.SELECT)
+ def withColumnRenamed(self, existing: str, new: str):
+ expression = self.expression.copy()
+ existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
+ if not existing_columns:
+ raise ValueError("Tried to rename a column that doesn't exist")
+ for existing_column in existing_columns:
+ if isinstance(existing_column, exp.Column):
+ existing_column.replace(exp.alias_(existing_column.copy(), new))
+ else:
+ existing_column.set("alias", exp.to_identifier(new))
+ return self.copy(expression=expression)
+
+ @operation(Operation.SELECT)
+ def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
+ all_columns = self._get_outer_select_columns(self.expression)
+ drop_cols = self._ensure_and_normalize_cols(cols)
+ new_columns = [
+ col
+ for col in all_columns
+ if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
+ ]
+ return self.copy().select(*new_columns, append=False)
+
+ @operation(Operation.LIMIT)
+ def limit(self, num: int) -> DataFrame:
+ return self.copy(expression=self.expression.limit(num))
+
+ @operation(Operation.NO_OP)
+ def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
+ parameter_list = ensure_list(parameters)
+ parameter_columns = (
+ self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
+ )
+ return self._hint(name, parameter_columns)
+
+ @operation(Operation.NO_OP)
+ def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
+ num_partitions = Column.ensure_cols(ensure_list(numPartitions))
+ columns = self._ensure_and_normalize_cols(cols)
+ args = num_partitions + columns
+ return self._hint("repartition", args)
+
+ @operation(Operation.NO_OP)
+ def coalesce(self, numPartitions: int) -> DataFrame:
+ num_partitions = Column.ensure_cols([numPartitions])
+ return self._hint("coalesce", num_partitions)
+
+ @operation(Operation.NO_OP)
+ def cache(self) -> DataFrame:
+ return self._cache(storage_level="MEMORY_AND_DISK")
+
+ @operation(Operation.NO_OP)
+ def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
+ """
+ Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
+ """
+ return self._cache(storageLevel)
+
+
+class DataFrameNaFunctions:
+ def __init__(self, df: DataFrame):
+ self.df = df
+
+ def drop(
+ self,
+ how: str = "any",
+ thresh: t.Optional[int] = None,
+ subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+ ) -> DataFrame:
+ return self.df.dropna(how=how, thresh=thresh, subset=subset)
+
+ def fill(
+ self,
+ value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
+ subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+ ) -> DataFrame:
+ return self.df.fillna(value=value, subset=subset)
+
+ def replace(
+ self,
+ to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
+ value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
+ subset: t.Optional[t.Union[str, t.List[str]]] = None,
+ ) -> DataFrame:
+ return self.df.replace(to_replace=to_replace, value=value, subset=subset)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
new file mode 100644
index 0000000..4c6de30
--- /dev/null
+++ b/sqlglot/dataframe/sql/functions.py
@@ -0,0 +1,1258 @@
+from __future__ import annotations
+
+import typing as t
+from inspect import signature
+
+from sqlglot import expressions as glotexp
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.helper import ensure_list
+from sqlglot.helper import flatten as _flatten
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnOrLiteral, ColumnOrName
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+
+
+def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column:
+ return Column(column_name)
+
+
+def lit(value: t.Optional[t.Any] = None) -> Column:
+ if isinstance(value, str):
+ return Column(glotexp.Literal.string(str(value)))
+ return Column(value)
+
+
+def greatest(*cols: ColumnOrName) -> Column:
+ columns = [Column.ensure_col(col) for col in cols]
+ return Column.invoke_expression_over_column(
+ columns[0], glotexp.Greatest, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None
+ )
+
+
+def least(*cols: ColumnOrName) -> Column:
+ columns = [Column.ensure_col(col) for col in cols]
+ return Column.invoke_expression_over_column(
+ columns[0], glotexp.Least, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None
+ )
+
+
+def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
+ columns = [Column.ensure_col(x) for x in [col] + list(cols)]
+ return Column(glotexp.Count(this=glotexp.Distinct(expressions=[x.expression for x in columns])))
+
+
+def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
+ return count_distinct(col, *cols)
+
+
+def when(condition: Column, value: t.Any) -> Column:
+ true_value = value if isinstance(value, Column) else lit(value)
+ return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]))
+
+
+def asc(col: ColumnOrName) -> Column:
+ return Column.ensure_col(col).asc()
+
+
+def desc(col: ColumnOrName):
+ return Column.ensure_col(col).desc()
+
+
+def broadcast(df: DataFrame) -> DataFrame:
+ return df.hint("broadcast")
+
+
+def sqrt(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Sqrt)
+
+
+def abs(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Abs)
+
+
+def max(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Max)
+
+
+def min(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Min)
+
+
+def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MAX_BY", ord)
+
+
+def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MIN_BY", ord)
+
+
+def count(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Count)
+
+
+def sum(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Sum)
+
+
+def avg(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Avg)
+
+
+def mean(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MEAN")
+
+
+def sumDistinct(col: ColumnOrName) -> Column:
+ return sum_distinct(col)
+
+
+def sum_distinct(col: ColumnOrName) -> Column:
+ raise NotImplementedError("Sum distinct is not currently implemented")
+
+
+def product(col: ColumnOrName) -> Column:
+ raise NotImplementedError("Product is not currently implemented")
+
+
+def acos(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ACOS")
+
+
+def acosh(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ACOSH")
+
+
+def asin(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ASIN")
+
+
+def asinh(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ASINH")
+
+
+def atan(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ATAN")
+
+
+def atan2(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
+ return Column.invoke_anonymous_function(col1, "ATAN2", col2)
+
+
+def atanh(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ATANH")
+
+
+def cbrt(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "CBRT")
+
+
+def ceil(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Ceil)
+
+
+def cos(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "COS")
+
+
+def cosh(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "COSH")
+
+
+def cot(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "COT")
+
+
+def csc(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "CSC")
+
+
+def exp(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Exp)
+
+
+def expm1(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "EXPM1")
+
+
+def floor(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Floor)
+
+
+def log10(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Log10)
+
+
+def log1p(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "LOG1P")
+
+
+def log2(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Log2)
+
+
+def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
+ if arg2 is None:
+ return Column.invoke_expression_over_column(arg1, glotexp.Ln)
+ return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=Column.ensure_col(arg2).expression)
+
+
+def rint(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "RINT")
+
+
+def sec(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SEC")
+
+
+def signum(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SIGNUM")
+
+
+def sin(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SIN")
+
+
+def sinh(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SINH")
+
+
+def tan(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "TAN")
+
+
+def tanh(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "TANH")
+
+
+def toDegrees(col: ColumnOrName) -> Column:
+ return degrees(col)
+
+
+def degrees(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "DEGREES")
+
+
+def toRadians(col: ColumnOrName) -> Column:
+ return radians(col)
+
+
+def radians(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "RADIANS")
+
+
+def bitwiseNOT(col: ColumnOrName) -> Column:
+ return bitwise_not(col)
+
+
+def bitwise_not(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.BitwiseNot)
+
+
+def asc_nulls_first(col: ColumnOrName) -> Column:
+ return Column.ensure_col(col).asc_nulls_first()
+
+
+def asc_nulls_last(col: ColumnOrName) -> Column:
+ return Column.ensure_col(col).asc_nulls_last()
+
+
+def desc_nulls_first(col: ColumnOrName) -> Column:
+ return Column.ensure_col(col).desc_nulls_first()
+
+
+def desc_nulls_last(col: ColumnOrName) -> Column:
+ return Column.ensure_col(col).desc_nulls_last()
+
+
+def stddev(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Stddev)
+
+
+def stddev_samp(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.StddevSamp)
+
+
+def stddev_pop(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.StddevPop)
+
+
+def variance(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Variance)
+
+
+def var_samp(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Variance)
+
+
+def var_pop(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.VariancePop)
+
+
+def skewness(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SKEWNESS")
+
+
+def kurtosis(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "KURTOSIS")
+
+
+def collect_list(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.ArrayAgg)
+
+
+def collect_set(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.SetAgg)
+
+
+def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
+ return Column.invoke_anonymous_function(col1, "HYPOT", col2)
+
+
+def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
+ return Column.invoke_anonymous_function(col1, "POW", col2)
+
+
+def row_number() -> Column:
+ return Column(glotexp.Anonymous(this="ROW_NUMBER"))
+
+
+def dense_rank() -> Column:
+ return Column(glotexp.Anonymous(this="DENSE_RANK"))
+
+
+def rank() -> Column:
+ return Column(glotexp.Anonymous(this="RANK"))
+
+
+def cume_dist() -> Column:
+ return Column(glotexp.Anonymous(this="CUME_DIST"))
+
+
+def percent_rank() -> Column:
+ return Column(glotexp.Anonymous(this="PERCENT_RANK"))
+
+
+def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
+ return approx_count_distinct(col, rsd)
+
+
+def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
+ if rsd is None:
+ return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct)
+ return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=Column.ensure_col(rsd).expression)
+
+
+def coalesce(*cols: ColumnOrName) -> Column:
+ columns = [Column.ensure_col(col) for col in cols]
+ return Column.invoke_expression_over_column(
+ columns[0], glotexp.Coalesce, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None
+ )
+
+
+def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "CORR", col2)
+
+
+def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "COVAR_POP", col2)
+
+
+def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2)
+
+
+def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
+ if ignorenulls is not None:
+ return Column.invoke_anonymous_function(col, "FIRST", ignorenulls)
+ return Column.invoke_anonymous_function(col, "FIRST")
+
+
+def grouping_id(*cols: ColumnOrName) -> Column:
+ if not cols:
+ return Column.invoke_anonymous_function(None, "GROUPING_ID")
+ if len(cols) == 1:
+ return Column.invoke_anonymous_function(cols[0], "GROUPING_ID")
+ return Column.invoke_anonymous_function(cols[0], "GROUPING_ID", *cols[1:])
+
+
+def input_file_name() -> Column:
+ return Column.invoke_anonymous_function(None, "INPUT_FILE_NAME")
+
+
+def isnan(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ISNAN")
+
+
+def isnull(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ISNULL")
+
+
+def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
+ if ignorenulls is not None:
+ return Column.invoke_anonymous_function(col, "LAST", ignorenulls)
+ return Column.invoke_anonymous_function(col, "LAST")
+
+
+def monotonically_increasing_id() -> Column:
+ return Column.invoke_anonymous_function(None, "MONOTONICALLY_INCREASING_ID")
+
+
+def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "NANVL", col2)
+
+
+def percentile_approx(
+ col: ColumnOrName,
+ percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
+ accuracy: t.Optional[t.Union[ColumnOrLiteral]] = None,
+) -> Column:
+ if accuracy:
+ return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage, accuracy)
+ return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage)
+
+
+def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
+ return Column.invoke_anonymous_function(seed, "RAND")
+
+
+def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
+ return Column.invoke_anonymous_function(seed, "RANDN")
+
+
+def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
+ if scale is not None:
+ return Column.invoke_expression_over_column(col, glotexp.Round, decimals=glotexp.convert(scale))
+ return Column.invoke_expression_over_column(col, glotexp.Round)
+
+
+def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
+ if scale is not None:
+ return Column.invoke_anonymous_function(col, "BROUND", scale)
+ return Column.invoke_anonymous_function(col, "BROUND")
+
+
+def shiftleft(col: ColumnOrName, numBits: int) -> Column:
+ return Column.invoke_expression_over_column(
+ col, glotexp.BitwiseLeftShift, expression=Column.ensure_col(numBits).expression
+ )
+
+
+def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
+ return shiftleft(col, numBits)
+
+
+def shiftright(col: ColumnOrName, numBits: int) -> Column:
+ return Column.invoke_expression_over_column(
+ col, glotexp.BitwiseRightShift, expression=Column.ensure_col(numBits).expression
+ )
+
+
+def shiftRight(col: ColumnOrName, numBits: int) -> Column:
+ return shiftright(col, numBits)
+
+
+def shiftrightunsigned(col: ColumnOrName, numBits: int) -> Column:
+ return Column.invoke_anonymous_function(col, "SHIFTRIGHTUNSIGNED", numBits)
+
+
+def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column:
+ return shiftrightunsigned(col, numBits)
+
+
+def expr(str: str) -> Column:
+ return Column(str)
+
+
+def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column:
+ columns = ensure_list(col) + list(cols)
+ expressions = [Column.ensure_col(column).expression for column in columns]
+ return Column(glotexp.Struct(expressions=expressions))
+
+
+def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column:
+ return Column.invoke_anonymous_function(col, "CONV", fromBase, toBase)
+
+
+def factorial(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "FACTORIAL")
+
+
+def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column:
+ if default is not None:
+ return Column.invoke_anonymous_function(col, "LAG", offset, default)
+ if offset != 1:
+ return Column.invoke_anonymous_function(col, "LAG", offset)
+ return Column.invoke_anonymous_function(col, "LAG")
+
+
+def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
+ if default is not None:
+ return Column.invoke_anonymous_function(col, "LEAD", offset, default)
+ if offset != 1:
+ return Column.invoke_anonymous_function(col, "LEAD", offset)
+ return Column.invoke_anonymous_function(col, "LEAD")
+
+
+def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
+ if ignoreNulls is not None:
+ raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
+ if offset != 1:
+ return Column.invoke_anonymous_function(col, "NTH_VALUE", offset)
+ return Column.invoke_anonymous_function(col, "NTH_VALUE")
+
+
+def ntile(n: int) -> Column:
+ return Column.invoke_anonymous_function(None, "NTILE", n)
+
+
+def current_date() -> Column:
+ return Column.invoke_expression_over_column(None, glotexp.CurrentDate)
+
+
+def current_timestamp() -> Column:
+ return Column.invoke_expression_over_column(None, glotexp.CurrentTimestamp)
+
+
+def date_format(col: ColumnOrName, format: str) -> Column:
+ return Column.invoke_anonymous_function(col, "DATE_FORMAT", lit(format))
+
+
+def year(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Year)
+
+
+def quarter(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "QUARTER")
+
+
+def month(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Month)
+
+
+def dayofweek(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "DAYOFWEEK")
+
+
+def dayofmonth(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "DAYOFMONTH")
+
+
+def dayofyear(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "DAYOFYEAR")
+
+
+def hour(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "HOUR")
+
+
+def minute(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MINUTE")
+
+
+def second(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SECOND")
+
+
+def weekofyear(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "WEEKOFYEAR")
+
+
+def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(year, "MAKE_DATE", month, day)
+
+
+def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=Column.ensure_col(days).expression)
+
+
+def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=Column.ensure_col(days).expression)
+
+
+def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=Column.ensure_col(start).expression)
+
+
+def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
+ return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)
+
+
+def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
+ if roundOff is None:
+ return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
+ return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)
+
+
+def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
+ if format is not None:
+ return Column.invoke_anonymous_function(col, "TO_DATE", lit(format))
+ return Column.invoke_anonymous_function(col, "TO_DATE")
+
+
+def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
+ if format is not None:
+ return Column.invoke_anonymous_function(col, "TO_TIMESTAMP", lit(format))
+ return Column.invoke_anonymous_function(col, "TO_TIMESTAMP")
+
+
+def trunc(col: ColumnOrName, format: str) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format).expression)
+
+
+def date_trunc(format: str, timestamp: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format).expression)
+
+
+def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
+ return Column.invoke_anonymous_function(col, "NEXT_DAY", lit(dayOfWeek))
+
+
+def last_day(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "LAST_DAY")
+
+
+def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
+ if format is not None:
+ return Column.invoke_anonymous_function(col, "FROM_UNIXTIME", lit(format))
+ return Column.invoke_anonymous_function(col, "FROM_UNIXTIME")
+
+
+def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
+ if format is not None:
+ return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP", lit(format))
+ return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP")
+
+
+def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
+ tz_column = tz if isinstance(tz, Column) else lit(tz)
+ return Column.invoke_anonymous_function(timestamp, "FROM_UTC_TIMESTAMP", tz_column)
+
+
+def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
+ tz_column = tz if isinstance(tz, Column) else lit(tz)
+ return Column.invoke_anonymous_function(timestamp, "TO_UTC_TIMESTAMP", tz_column)
+
+
+def timestamp_seconds(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "TIMESTAMP_SECONDS")
+
+
+def window(
+ timeColumn: ColumnOrName,
+ windowDuration: str,
+ slideDuration: t.Optional[str] = None,
+ startTime: t.Optional[str] = None,
+) -> Column:
+ if slideDuration is not None and startTime is not None:
+ return Column.invoke_anonymous_function(
+ timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime)
+ )
+ if slideDuration is not None:
+ return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration))
+ if startTime is not None:
+ return Column.invoke_anonymous_function(
+ timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime)
+ )
+ return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration))
+
+
+def session_window(timeColumn: ColumnOrName, gapDuration: ColumnOrName) -> Column:
+ gap_duration_column = gapDuration if isinstance(gapDuration, Column) else lit(gapDuration)
+ return Column.invoke_anonymous_function(timeColumn, "SESSION_WINDOW", gap_duration_column)
+
+
+def crc32(col: ColumnOrName) -> Column:
+ column = col if isinstance(col, Column) else lit(col)
+ return Column.invoke_anonymous_function(column, "CRC32")
+
+
+def md5(col: ColumnOrName) -> Column:
+ column = col if isinstance(col, Column) else lit(col)
+ return Column.invoke_anonymous_function(column, "MD5")
+
+
+def sha1(col: ColumnOrName) -> Column:
+ column = col if isinstance(col, Column) else lit(col)
+ return Column.invoke_anonymous_function(column, "SHA1")
+
+
+def sha2(col: ColumnOrName, numBits: int) -> Column:
+ column = col if isinstance(col, Column) else lit(col)
+ num_bits = lit(numBits)
+ return Column.invoke_anonymous_function(column, "SHA2", num_bits)
+
+
+def hash(*cols: ColumnOrName) -> Column:
+ args = cols[1:] if len(cols) > 1 else []
+ return Column.invoke_anonymous_function(cols[0], "HASH", *args)
+
+
+def xxhash64(*cols: ColumnOrName) -> Column:
+ args = cols[1:] if len(cols) > 1 else []
+ return Column.invoke_anonymous_function(cols[0], "XXHASH64", *args)
+
+
+def assert_true(col: ColumnOrName, errorMsg: t.Optional[ColumnOrName] = None) -> Column:
+ if errorMsg is not None:
+ error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg)
+ return Column.invoke_anonymous_function(col, "ASSERT_TRUE", error_msg_col)
+ return Column.invoke_anonymous_function(col, "ASSERT_TRUE")
+
+
+def raise_error(errorMsg: ColumnOrName) -> Column:
+ error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg)
+ return Column.invoke_anonymous_function(error_msg_col, "RAISE_ERROR")
+
+
+def upper(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Upper)
+
+
+def lower(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Lower)
+
+
+def ascii(col: ColumnOrLiteral) -> Column:
+ return Column.invoke_anonymous_function(col, "ASCII")
+
+
+def base64(col: ColumnOrLiteral) -> Column:
+ return Column.invoke_anonymous_function(col, "BASE64")
+
+
+def unbase64(col: ColumnOrLiteral) -> Column:
+ return Column.invoke_anonymous_function(col, "UNBASE64")
+
+
+def ltrim(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "LTRIM")
+
+
+def rtrim(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "RTRIM")
+
+
+def trim(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Trim)
+
+
+def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
+ columns = [Column(col) for col in cols]
+ return Column.invoke_expression_over_column(
+ None, glotexp.ConcatWs, expressions=[x.expression for x in [lit(sep)] + list(columns)]
+ )
+
+
+def decode(col: ColumnOrName, charset: str) -> Column:
+ return Column.invoke_anonymous_function(col, "DECODE", lit(charset))
+
+
+def encode(col: ColumnOrName, charset: str) -> Column:
+ return Column.invoke_anonymous_function(col, "ENCODE", lit(charset))
+
+
+def format_number(col: ColumnOrName, d: int) -> Column:
+ return Column.invoke_anonymous_function(col, "FORMAT_NUMBER", lit(d))
+
+
+def format_string(format: str, *cols: ColumnOrName) -> Column:
+ format_col = lit(format)
+ columns = [Column.ensure_col(x) for x in cols]
+ return Column.invoke_anonymous_function(format_col, "FORMAT_STRING", *columns)
+
+
+def instr(col: ColumnOrName, substr: str) -> Column:
+ return Column.invoke_anonymous_function(col, "INSTR", lit(substr))
+
+
+def overlay(
+ src: ColumnOrName,
+ replace: ColumnOrName,
+ pos: t.Union[ColumnOrName, int],
+ len: t.Optional[t.Union[ColumnOrName, int]] = None,
+) -> Column:
+ if len is not None:
+ return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos, len)
+ return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos)
+
+
+def sentences(
+ string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None
+) -> Column:
+ if language is not None and country is not None:
+ return Column.invoke_anonymous_function(string, "SENTENCES", language, country)
+ if language is not None:
+ return Column.invoke_anonymous_function(string, "SENTENCES", language)
+ if country is not None:
+ return Column.invoke_anonymous_function(string, "SENTENCES", lit("en"), country)
+ return Column.invoke_anonymous_function(string, "SENTENCES")
+
+
+def substring(str: ColumnOrName, pos: int, len: int) -> Column:
+ return Column.ensure_col(str).substr(pos, len)
+
+
+def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
+ return Column.invoke_anonymous_function(str, "SUBSTRING_INDEX", lit(delim), lit(count))
+
+
+def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(
+ left, glotexp.Levenshtein, expression=Column.ensure_col(right).expression
+ )
+
+
+def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
+ substr_col = lit(substr)
+ pos_column = lit(pos)
+ str_column = Column.ensure_col(str)
+ if pos is not None:
+ return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column, pos_column)
+ return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column)
+
+
+def lpad(col: ColumnOrName, len: int, pad: str) -> Column:
+ return Column.invoke_anonymous_function(col, "LPAD", lit(len), lit(pad))
+
+
+def rpad(col: ColumnOrName, len: int, pad: str) -> Column:
+ return Column.invoke_anonymous_function(col, "RPAD", lit(len), lit(pad))
+
+
+def repeat(col: ColumnOrName, n: int) -> Column:
+ return Column.invoke_anonymous_function(col, "REPEAT", n)
+
+
+def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
+ if limit is not None:
+ return Column.invoke_expression_over_column(
+ str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=lit(limit).expression
+ )
+ return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern).expression)
+
+
+def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column:
+ if idx is not None:
+ return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern), idx)
+ return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern))
+
+
+def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:
+ return Column.invoke_anonymous_function(str, "REGEXP_REPLACE", lit(pattern), lit(replacement))
+
+
+def initcap(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Initcap)
+
+
+def soundex(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SOUNDEX")
+
+
+def bin(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "BIN")
+
+
+def hex(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "HEX")
+
+
+def unhex(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "UNHEX")
+
+
+def length(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Length)
+
+
+def octet_length(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "OCTET_LENGTH")
+
+
+def bit_length(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "BIT_LENGTH")
+
+
+def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column:
+ return Column.invoke_anonymous_function(srcCol, "TRANSLATE", lit(matching), lit(replace))
+
+
+def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
+ cols = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
+ cols = [Column.ensure_col(col).expression for col in cols] # type: ignore
+ return Column.invoke_expression_over_column(None, glotexp.Array, expressions=cols)
+
+
+def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
+ cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
+ return Column.invoke_expression_over_column(
+ None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression
+ )
+
+
+def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "MAP_FROM_ARRAYS", col2)
+
+
+def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
+ value_col = value if isinstance(value, Column) else lit(value)
+ return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression)
+
+
+def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2))
+
+
+def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column:
+ start_col = start if isinstance(start, Column) else lit(start)
+ length_col = length if isinstance(length, Column) else lit(length)
+ return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)
+
+
+def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column:
+ if null_replacement is not None:
+ return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement))
+ return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))
+
+
+def concat(*cols: ColumnOrName) -> Column:
+ if len(cols) == 1:
+ return Column.invoke_anonymous_function(cols[0], "CONCAT")
+ return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]])
+
+
+def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
+ value_col = value if isinstance(value, Column) else lit(value)
+ return Column.invoke_anonymous_function(col, "ARRAY_POSITION", value_col)
+
+
+def element_at(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
+ value_col = value if isinstance(value, Column) else lit(value)
+ return Column.invoke_anonymous_function(col, "ELEMENT_AT", value_col)
+
+
+def array_remove(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
+ value_col = value if isinstance(value, Column) else lit(value)
+ return Column.invoke_anonymous_function(col, "ARRAY_REMOVE", value_col)
+
+
+def array_distinct(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ARRAY_DISTINCT")
+
+
+def array_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "ARRAY_INTERSECT", Column.ensure_col(col2))
+
+
+def array_union(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "ARRAY_UNION", Column.ensure_col(col2))
+
+
+def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col1, "ARRAY_EXCEPT", Column.ensure_col(col2))
+
+
+def explode(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Explode)
+
+
+def posexplode(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.Posexplode)
+
+
+def explode_outer(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "EXPLODE_OUTER")
+
+
+def posexplode_outer(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "POSEXPLODE_OUTER")
+
+
+def get_json_object(col: ColumnOrName, path: str) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path).expression)
+
+
+def json_tuple(col: ColumnOrName, *fields: str) -> Column:
+ return Column.invoke_anonymous_function(col, "JSON_TUPLE", *[lit(field) for field in fields])
+
+
+def from_json(
+ col: ColumnOrName,
+ schema: t.Union[Column, str],
+ options: t.Optional[t.Dict[str, str]] = None,
+) -> Column:
+ schema = schema if isinstance(schema, Column) else lit(schema)
+ if options is not None:
+ options_col = create_map([lit(x) for x in _flatten(options.items())])
+ return Column.invoke_anonymous_function(col, "FROM_JSON", schema, options_col)
+ return Column.invoke_anonymous_function(col, "FROM_JSON", schema)
+
+
+def to_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
+ if options is not None:
+ options_col = create_map([lit(x) for x in _flatten(options.items())])
+ return Column.invoke_anonymous_function(col, "TO_JSON", options_col)
+ return Column.invoke_anonymous_function(col, "TO_JSON")
+
+
+def schema_of_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
+ if options is not None:
+ options_col = create_map([lit(x) for x in _flatten(options.items())])
+ return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON", options_col)
+ return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON")
+
+
+def schema_of_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
+ if options is not None:
+ options_col = create_map([lit(x) for x in _flatten(options.items())])
+ return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV", options_col)
+ return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV")
+
+
+def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
+ if options is not None:
+ options_col = create_map([lit(x) for x in _flatten(options.items())])
+ return Column.invoke_anonymous_function(col, "TO_CSV", options_col)
+ return Column.invoke_anonymous_function(col, "TO_CSV")
+
+
+def size(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.ArraySize)
+
+
+def array_min(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ARRAY_MIN")
+
+
+def array_max(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "ARRAY_MAX")
+
+
+def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
+ if asc is not None:
+ return Column.invoke_anonymous_function(col, "SORT_ARRAY", lit(asc))
+ return Column.invoke_anonymous_function(col, "SORT_ARRAY")
+
+
+def array_sort(col: ColumnOrName) -> Column:
+ return Column.invoke_expression_over_column(col, glotexp.ArraySort)
+
+
+def shuffle(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "SHUFFLE")
+
+
+def reverse(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "REVERSE")
+
+
+def flatten(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "FLATTEN")
+
+
+def map_keys(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MAP_KEYS")
+
+
+def map_values(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MAP_VALUES")
+
+
+def map_entries(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MAP_ENTRIES")
+
+
+def map_from_entries(col: ColumnOrName) -> Column:
+ return Column.invoke_anonymous_function(col, "MAP_FROM_ENTRIES")
+
+
+def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column:
+ count_col = count if isinstance(count, Column) else lit(count)
+ return Column.invoke_anonymous_function(col, "ARRAY_REPEAT", count_col)
+
+
+def array_zip(*cols: ColumnOrName) -> Column:
+ if len(cols) == 1:
+ return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP")
+ return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP", *cols[1:])
+
+
+def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
+ columns = list(flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
+ if len(columns) == 1:
+ return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT")
+ return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])
+
+
+def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column:
+ if step is not None:
+ return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step)
+ return Column.invoke_anonymous_function(start, "SEQUENCE", stop)
+
+
+def from_csv(
+ col: ColumnOrName,
+ schema: t.Union[Column, str],
+ options: t.Optional[t.Dict[str, str]] = None,
+) -> Column:
+ schema = schema if isinstance(schema, Column) else lit(schema)
+ if options is not None:
+ option_cols = create_map([lit(x) for x in _flatten(options.items())])
+ return Column.invoke_anonymous_function(col, "FROM_CSV", schema, option_cols)
+ return Column.invoke_anonymous_function(col, "FROM_CSV", schema)
+
+
+def aggregate(
+ col: ColumnOrName,
+ initialValue: ColumnOrName,
+ merge: t.Callable[[Column, Column], Column],
+ finish: t.Optional[t.Callable[[Column], Column]] = None,
+ accumulator_name: str = "acc",
+ target_row_name: str = "x",
+) -> Column:
+ merge_exp = glotexp.Lambda(
+ this=merge(Column(accumulator_name), Column(target_row_name)).expression,
+ expressions=[
+ glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name)),
+ glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name)),
+ ],
+ )
+ if finish is not None:
+ finish_exp = glotexp.Lambda(
+ this=finish(Column(accumulator_name)).expression,
+ expressions=[glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name))],
+ )
+ return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
+ return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
+
+
+def transform(
+ col: ColumnOrName,
+ f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
+ target_row_name: str = "x",
+ row_count_name: str = "i",
+) -> Column:
+ num_arguments = len(signature(f).parameters)
+ expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))]
+ columns = [Column(target_row_name)]
+ if num_arguments > 1:
+ columns.append(Column(row_count_name))
+ expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name)))
+
+ f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions)
+ return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
+
+
+def exists(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(target_row_name)).expression,
+ expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))],
+ )
+ return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression))
+
+
+def forall(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(target_row_name)).expression,
+ expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))],
+ )
+
+ return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
+
+
+def filter(
+ col: ColumnOrName,
+ f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
+ target_row_name: str = "x",
+ row_count_name: str = "i",
+) -> Column:
+ num_arguments = len(signature(f).parameters)
+ expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))]
+ columns = [Column(target_row_name)]
+ if num_arguments > 1:
+ columns.append(Column(row_count_name))
+ expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name)))
+
+ f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions)
+ return Column.invoke_anonymous_function(col, "FILTER", Column(f_expression))
+
+
+def zip_with(
+ left: ColumnOrName,
+ right: ColumnOrName,
+ f: t.Callable[[Column, Column], Column],
+ left_name: str = "x",
+ right_name: str = "y",
+) -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(left_name), Column(right_name)).expression,
+ expressions=[
+ glotexp.to_identifier(left_name, quoted=_lambda_quoted(left_name)),
+ glotexp.to_identifier(right_name, quoted=_lambda_quoted(right_name)),
+ ],
+ )
+
+ return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
+
+
+def transform_keys(
+ col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v"
+) -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(key_name), Column(value_name)).expression,
+ expressions=[
+ glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
+ glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
+ ],
+ )
+ return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression))
+
+
+def transform_values(
+ col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v"
+) -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(key_name), Column(value_name)).expression,
+ expressions=[
+ glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
+ glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
+ ],
+ )
+ return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression))
+
+
+def map_filter(
+ col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v"
+) -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(key_name), Column(value_name)).expression,
+ expressions=[
+ glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
+ glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)),
+ ],
+ )
+ return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression))
+
+
+def map_zip_with(
+ col1: ColumnOrName,
+ col2: ColumnOrName,
+ f: t.Union[t.Callable[[Column, Column, Column], Column]],
+ key_name: str = "k",
+ value1: str = "v1",
+ value2: str = "v2",
+) -> Column:
+ f_expression = glotexp.Lambda(
+ this=f(Column(key_name), Column(value1), Column(value2)).expression,
+ expressions=[
+ glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)),
+ glotexp.to_identifier(value1, quoted=_lambda_quoted(value1)),
+ glotexp.to_identifier(value2, quoted=_lambda_quoted(value2)),
+ ],
+ )
+ return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression))
+
+
+def _lambda_quoted(value: str) -> t.Optional[bool]:
+ return False if value == "_" else None
diff --git a/sqlglot/dataframe/sql/group.py b/sqlglot/dataframe/sql/group.py
new file mode 100644
index 0000000..947aace
--- /dev/null
+++ b/sqlglot/dataframe/sql/group.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.operations import Operation, operation
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+
+
+class GroupedData:
+ def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
+ self._df = df.copy()
+ self.spark = df.spark
+ self.last_op = last_op
+ self.group_by_cols = group_by_cols
+
+ def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
+ func_name = func_name.lower()
+ return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
+
+ @operation(Operation.SELECT)
+ def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
+ columns = (
+ [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
+ if isinstance(exprs[0], dict)
+ else exprs
+ )
+ cols = self._df._ensure_and_normalize_cols(columns)
+
+ expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
+ *[x.expression for x in self.group_by_cols + cols], append=False
+ )
+ return self._df.copy(expression=expression)
+
+ def count(self) -> DataFrame:
+ return self.agg(F.count("*").alias("count"))
+
+ def mean(self, *cols: str) -> DataFrame:
+ return self.avg(*cols)
+
+ def avg(self, *cols: str) -> DataFrame:
+ return self.agg(*self._get_function_applied_columns("avg", cols))
+
+ def max(self, *cols: str) -> DataFrame:
+ return self.agg(*self._get_function_applied_columns("max", cols))
+
+ def min(self, *cols: str) -> DataFrame:
+ return self.agg(*self._get_function_applied_columns("min", cols))
+
+ def sum(self, *cols: str) -> DataFrame:
+ return self.agg(*self._get_function_applied_columns("sum", cols))
+
+ def pivot(self, *cols: str) -> DataFrame:
+ raise NotImplementedError("Sum distinct is not currently implemented")
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
new file mode 100644
index 0000000..1513946
--- /dev/null
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.helper import ensure_list
+
+NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.session import SparkSession
+
+
+def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
+ expr = ensure_list(expr)
+ expressions = _ensure_expressions(expr)
+ for expression in expressions:
+ identifiers = expression.find_all(exp.Identifier)
+ for identifier in identifiers:
+ replace_alias_name_with_cte_name(spark, expression_context, identifier)
+ replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
+
+
+def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
+ if id.alias_or_name in spark.name_to_sequence_id_mapping:
+ for cte in reversed(expression_context.ctes):
+ if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
+ _set_alias_name(id, cte.alias_or_name)
+ break
+
+
+def replace_branch_and_sequence_ids_with_cte_name(
+ spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
+):
+ if id.alias_or_name in spark.known_ids:
+ # Check if we have a join and if both the tables in that join share a common branch id
+ # If so we need to have this reference the left table by default unless the id is a sequence
+ # id then it keeps that reference. This handles the weird edge case in spark that shouldn't
+ # be common in practice
+ if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
+ join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
+ ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
+ if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
+ assert len(ctes_in_join) == 2
+ _set_alias_name(id, ctes_in_join[0].alias_or_name)
+ return
+
+ for cte in reversed(expression_context.ctes):
+ if id.alias_or_name in (cte.args["branch_id"], cte.args["sequence_id"]):
+ _set_alias_name(id, cte.alias_or_name)
+ return
+
+
+def _set_alias_name(id: exp.Identifier, name: str):
+ id.set("this", name)
+
+
+def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
+ values = ensure_list(values)
+ results = []
+ for value in values:
+ if isinstance(value, str):
+ results.append(Column.ensure_col(value).expression)
+ elif isinstance(value, Column):
+ results.append(value.expression)
+ elif isinstance(value, exp.Expression):
+ results.append(value)
+ else:
+ raise ValueError(f"Got an invalid type to normalize: {type(value)}")
+ return results
diff --git a/sqlglot/dataframe/sql/operations.py b/sqlglot/dataframe/sql/operations.py
new file mode 100644
index 0000000..d51335c
--- /dev/null
+++ b/sqlglot/dataframe/sql/operations.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+import functools
+import typing as t
+from enum import IntEnum
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+ from sqlglot.dataframe.sql.group import GroupedData
+
+
+class Operation(IntEnum):
+ INIT = -1
+ NO_OP = 0
+ FROM = 1
+ WHERE = 2
+ GROUP_BY = 3
+ HAVING = 4
+ SELECT = 5
+ ORDER_BY = 6
+ LIMIT = 7
+
+
+def operation(op: Operation):
+ """
+ Decorator used around DataFrame methods to indicate what type of operation is being performed from the
+ ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
+ included with the previous operation.
+
+ Ex: After a user does a join we want to allow them to select which columns for the different
+ tables that they want to carry through to the following operation. If we put that join in
+ a CTE preemptively then the user would not have a chance to select which column they want
+ in cases where there is overlap in names.
+ """
+
+ def decorator(func: t.Callable):
+ @functools.wraps(func)
+ def wrapper(self: DataFrame, *args, **kwargs):
+ if self.last_op == Operation.INIT:
+ self = self._convert_leaf_to_cte()
+ self.last_op = Operation.NO_OP
+ last_op = self.last_op
+ new_op = op if op != Operation.NO_OP else last_op
+ if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT):
+ self = self._convert_leaf_to_cte()
+ df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs)
+ df.last_op = new_op # type: ignore
+ return df
+
+ wrapper.__wrapped__ = func # type: ignore
+ return wrapper
+
+ return decorator
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
new file mode 100644
index 0000000..4830035
--- /dev/null
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+import typing as t
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.helper import object_to_dict
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+ from sqlglot.dataframe.sql.session import SparkSession
+
+
+class DataFrameReader:
+ def __init__(self, spark: SparkSession):
+ self.spark = spark
+
+ def table(self, tableName: str) -> DataFrame:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+
+ sqlglot.schema.add_table(tableName)
+ return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
+
+
+class DataFrameWriter:
+ def __init__(
+ self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
+ ):
+ self._df = df
+ self._spark = spark or df.spark
+ self._mode = mode
+ self._by_name = by_name
+
+ def copy(self, **kwargs) -> DataFrameWriter:
+ return DataFrameWriter(
+ **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
+ )
+
+ def sql(self, **kwargs) -> t.List[str]:
+ return self._df.sql(**kwargs)
+
+ def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
+ return self.copy(_mode=saveMode)
+
+ @property
+ def byName(self):
+ return self.copy(by_name=True)
+
+ def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
+ output_expression_container = exp.Insert(
+ **{
+ "this": exp.to_table(tableName),
+ "overwrite": overwrite,
+ }
+ )
+ df = self._df.copy(output_expression_container=output_expression_container)
+ if self._by_name:
+ columns = sqlglot.schema.column_names(tableName, only_visible=True)
+ df = df._convert_leaf_to_cte().select(*columns)
+
+ return self.copy(_df=df)
+
+ def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
+ if format is not None:
+ raise NotImplementedError("Providing Format in the save as table is not supported")
+ exists, replace, mode = None, None, mode or str(self._mode)
+ if mode == "append":
+ return self.insertInto(name)
+ if mode == "ignore":
+ exists = True
+ if mode == "overwrite":
+ replace = True
+ output_expression_container = exp.Create(
+ this=exp.to_table(name),
+ kind="TABLE",
+ exists=exists,
+ replace=replace,
+ )
+ return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
new file mode 100644
index 0000000..1ea86d1
--- /dev/null
+++ b/sqlglot/dataframe/sql/session.py
@@ -0,0 +1,148 @@
+from __future__ import annotations
+
+import typing as t
+import uuid
+from collections import defaultdict
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.dataframe import DataFrame
+from sqlglot.dataframe.sql.readwriter import DataFrameReader
+from sqlglot.dataframe.sql.types import StructType
+from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
+
+
+class SparkSession:
+ known_ids: t.ClassVar[t.Set[str]] = set()
+ known_branch_ids: t.ClassVar[t.Set[str]] = set()
+ known_sequence_ids: t.ClassVar[t.Set[str]] = set()
+ name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)
+
+ def __init__(self):
+ self.incrementing_id = 1
+
+ def __getattr__(self, name: str) -> SparkSession:
+ return self
+
+ def __call__(self, *args, **kwargs) -> SparkSession:
+ return self
+
+ @property
+ def read(self) -> DataFrameReader:
+ return DataFrameReader(self)
+
+ def table(self, tableName: str) -> DataFrame:
+ return self.read.table(tableName)
+
+ def createDataFrame(
+ self,
+ data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
+ schema: t.Optional[SchemaInput] = None,
+ samplingRatio: t.Optional[float] = None,
+ verifySchema: bool = False,
+ ) -> DataFrame:
+ from sqlglot.dataframe.sql.dataframe import DataFrame
+
+ if samplingRatio is not None or verifySchema:
+ raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
+ if schema is not None and (
+ not isinstance(schema, (StructType, str, list))
+ or (isinstance(schema, list) and not isinstance(schema[0], str))
+ ):
+ raise NotImplementedError("Only schema of either list or string of list supported")
+ if not data:
+ raise ValueError("Must provide data to create into a DataFrame")
+
+ column_mapping: t.Dict[str, t.Optional[str]]
+ if schema is not None:
+ column_mapping = get_column_mapping_from_schema_input(schema)
+ elif isinstance(data[0], dict):
+ column_mapping = {col_name.strip(): None for col_name in data[0]}
+ else:
+ column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}
+
+ data_expressions = [
+ exp.Tuple(
+ expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
+ )
+ for row in data
+ ]
+
+ sel_columns = [
+ F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
+ for name, data_type in column_mapping.items()
+ ]
+
+ select_kwargs = {
+ "expressions": sel_columns,
+ "from": exp.From(
+ expressions=[
+ exp.Subquery(
+ this=exp.Values(expressions=data_expressions),
+ alias=exp.TableAlias(
+ this=exp.to_identifier(self._auto_incrementing_name),
+ columns=[exp.to_identifier(col_name) for col_name in column_mapping],
+ ),
+ )
+ ]
+ ),
+ }
+
+ sel_expression = exp.Select(**select_kwargs)
+ return DataFrame(self, sel_expression)
+
+ def sql(self, sqlQuery: str) -> DataFrame:
+ expression = sqlglot.parse_one(sqlQuery, read="spark")
+ if isinstance(expression, exp.Select):
+ df = DataFrame(self, expression)
+ df = df._convert_leaf_to_cte()
+ elif isinstance(expression, (exp.Create, exp.Insert)):
+ select_expression = expression.expression.copy()
+ if isinstance(expression, exp.Insert):
+ select_expression.set("with", expression.args.get("with"))
+ expression.set("with", None)
+ del expression.args["expression"]
+ df = DataFrame(self, select_expression, output_expression_container=expression)
+ df = df._convert_leaf_to_cte()
+ else:
+ raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
+ return df
+
+ @property
+ def _auto_incrementing_name(self) -> str:
+ name = f"a{self.incrementing_id}"
+ self.incrementing_id += 1
+ return name
+
+ @property
+ def _random_name(self) -> str:
+ return f"a{str(uuid.uuid4())[:8]}"
+
+ @property
+ def _random_branch_id(self) -> str:
+ id = self._random_id
+ self.known_branch_ids.add(id)
+ return id
+
+ @property
+ def _random_sequence_id(self):
+ id = self._random_id
+ self.known_sequence_ids.add(id)
+ return id
+
+ @property
+ def _random_id(self) -> str:
+ id = f"a{str(uuid.uuid4())[:8]}"
+ self.known_ids.add(id)
+ return id
+
+ @property
+ def _join_hint_names(self) -> t.Set[str]:
+ return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}
+
+ def _add_alias_to_mapping(self, name: str, sequence_id: str):
+ self.name_to_sequence_id_mapping[name].append(sequence_id)
diff --git a/sqlglot/dataframe/sql/transforms.py b/sqlglot/dataframe/sql/transforms.py
new file mode 100644
index 0000000..b3dcc12
--- /dev/null
+++ b/sqlglot/dataframe/sql/transforms.py
@@ -0,0 +1,9 @@
+import typing as t
+
+from sqlglot import expressions as exp
+
+
+def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]):
+ if isinstance(node, exp.Identifier) and node in replacement_mapping:
+ node = node.replace(replacement_mapping[node].copy())
+ return node
diff --git a/sqlglot/dataframe/sql/types.py b/sqlglot/dataframe/sql/types.py
new file mode 100644
index 0000000..dc5c05a
--- /dev/null
+++ b/sqlglot/dataframe/sql/types.py
@@ -0,0 +1,208 @@
+import typing as t
+
+
+class DataType:
+ def __repr__(self) -> str:
+ return self.__class__.__name__ + "()"
+
+ def __hash__(self) -> int:
+ return hash(str(self))
+
+ def __eq__(self, other: t.Any) -> bool:
+ return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
+
+ def __ne__(self, other: t.Any) -> bool:
+ return not self.__eq__(other)
+
+ def __str__(self) -> str:
+ return self.typeName()
+
+ @classmethod
+ def typeName(cls) -> str:
+ return cls.__name__[:-4].lower()
+
+ def simpleString(self) -> str:
+ return str(self)
+
+ def jsonValue(self) -> t.Union[str, t.Dict[str, t.Any]]:
+ return str(self)
+
+
+class DataTypeWithLength(DataType):
+ def __init__(self, length: int):
+ self.length = length
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self.length})"
+
+ def __str__(self) -> str:
+ return f"{self.typeName()}({self.length})"
+
+
+class StringType(DataType):
+ pass
+
+
+class CharType(DataTypeWithLength):
+ pass
+
+
+class VarcharType(DataTypeWithLength):
+ pass
+
+
+class BinaryType(DataType):
+ pass
+
+
+class BooleanType(DataType):
+ pass
+
+
+class DateType(DataType):
+ pass
+
+
+class TimestampType(DataType):
+ pass
+
+
+class TimestampNTZType(DataType):
+ @classmethod
+ def typeName(cls) -> str:
+ return "timestamp_ntz"
+
+
+class DecimalType(DataType):
+ def __init__(self, precision: int = 10, scale: int = 0):
+ self.precision = precision
+ self.scale = scale
+
+ def simpleString(self) -> str:
+ return f"decimal({self.precision}, {self.scale})"
+
+ def jsonValue(self) -> str:
+ return f"decimal({self.precision}, {self.scale})"
+
+ def __repr__(self) -> str:
+ return f"DecimalType({self.precision}, {self.scale})"
+
+
+class DoubleType(DataType):
+ pass
+
+
+class FloatType(DataType):
+ pass
+
+
+class ByteType(DataType):
+ def __str__(self) -> str:
+ return "tinyint"
+
+
+class IntegerType(DataType):
+ def __str__(self) -> str:
+ return "int"
+
+
+class LongType(DataType):
+ def __str__(self) -> str:
+ return "bigint"
+
+
+class ShortType(DataType):
+ def __str__(self) -> str:
+ return "smallint"
+
+
+class ArrayType(DataType):
+ def __init__(self, elementType: DataType, containsNull: bool = True):
+ self.elementType = elementType
+ self.containsNull = containsNull
+
+ def __repr__(self) -> str:
+ return f"ArrayType({self.elementType, str(self.containsNull)}"
+
+ def simpleString(self) -> str:
+ return f"array<{self.elementType.simpleString()}>"
+
+ def jsonValue(self) -> t.Dict[str, t.Any]:
+ return {
+ "type": self.typeName(),
+ "elementType": self.elementType.jsonValue(),
+ "containsNull": self.containsNull,
+ }
+
+
+class MapType(DataType):
+ def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True):
+ self.keyType = keyType
+ self.valueType = valueType
+ self.valueContainsNull = valueContainsNull
+
+ def __repr__(self) -> str:
+ return f"MapType({self.keyType}, {self.valueType}, {str(self.valueContainsNull)})"
+
+ def simpleString(self) -> str:
+ return f"map<{self.keyType.simpleString()}, {self.valueType.simpleString()}>"
+
+ def jsonValue(self) -> t.Dict[str, t.Any]:
+ return {
+ "type": self.typeName(),
+ "keyType": self.keyType.jsonValue(),
+ "valueType": self.valueType.jsonValue(),
+ "valueContainsNull": self.valueContainsNull,
+ }
+
+
+class StructField(DataType):
+ def __init__(
+ self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
+ ):
+ self.name = name
+ self.dataType = dataType
+ self.nullable = nullable
+ self.metadata = metadata or {}
+
+ def __repr__(self) -> str:
+ return f"StructField('{self.name}', {self.dataType}, {str(self.nullable)})"
+
+ def simpleString(self) -> str:
+ return f"{self.name}:{self.dataType.simpleString()}"
+
+ def jsonValue(self) -> t.Dict[str, t.Any]:
+ return {
+ "name": self.name,
+ "type": self.dataType.jsonValue(),
+ "nullable": self.nullable,
+ "metadata": self.metadata,
+ }
+
+
+class StructType(DataType):
+ def __init__(self, fields: t.Optional[t.List[StructField]] = None):
+ if not fields:
+ self.fields = []
+ self.names = []
+ else:
+ self.fields = fields
+ self.names = [f.name for f in fields]
+
+ def __iter__(self) -> t.Iterator[StructField]:
+ return iter(self.fields)
+
+ def __len__(self) -> int:
+ return len(self.fields)
+
+ def __repr__(self) -> str:
+ return f"StructType({', '.join(str(field) for field in self)})"
+
+ def simpleString(self) -> str:
+ return f"struct<{', '.join(x.simpleString() for x in self)}>"
+
+ def jsonValue(self) -> t.Dict[str, t.Any]:
+ return {"type": self.typeName(), "fields": [x.jsonValue() for x in self]}
+
+ def fieldNames(self) -> t.List[str]:
+ return list(self.names)
diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py
new file mode 100644
index 0000000..575d18a
--- /dev/null
+++ b/sqlglot/dataframe/sql/util.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import types
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import SchemaInput
+
+
+def get_column_mapping_from_schema_input(schema: SchemaInput) -> t.Dict[str, t.Optional[str]]:
+ if isinstance(schema, dict):
+ return schema
+ elif isinstance(schema, str):
+ col_name_type_strs = [x.strip() for x in schema.split(",")]
+ return {
+ name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
+ for name_type_str in col_name_type_strs
+ }
+ elif isinstance(schema, types.StructType):
+ return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema}
+ return {x.strip(): None for x in schema} # type: ignore
+
+
+def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]:
+ if not expression.args.get("joins"):
+ return []
+
+ left_table = expression.args["from"].args["expressions"][0]
+ other_tables = [join.this for join in expression.args["joins"]]
+ return [left_table] + other_tables
diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py
new file mode 100644
index 0000000..842f366
--- /dev/null
+++ b/sqlglot/dataframe/sql/window.py
@@ -0,0 +1,117 @@
+from __future__ import annotations
+
+import sys
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.helper import flatten
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql._typing import ColumnOrName
+
+
+class Window:
+ _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808
+ _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807
+ _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
+ _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)
+
+ unboundedPreceding: int = _JAVA_MIN_LONG
+
+ unboundedFollowing: int = _JAVA_MAX_LONG
+
+ currentRow: int = 0
+
+ @classmethod
+ def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
+ return WindowSpec().partitionBy(*cols)
+
+ @classmethod
+ def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
+ return WindowSpec().orderBy(*cols)
+
+ @classmethod
+ def rowsBetween(cls, start: int, end: int) -> WindowSpec:
+ return WindowSpec().rowsBetween(start, end)
+
+ @classmethod
+ def rangeBetween(cls, start: int, end: int) -> WindowSpec:
+ return WindowSpec().rangeBetween(start, end)
+
+
+class WindowSpec:
+ def __init__(self, expression: exp.Expression = exp.Window()):
+ self.expression = expression
+
+ def copy(self):
+ return WindowSpec(self.expression.copy())
+
+ def sql(self, **kwargs) -> str:
+ return self.expression.sql(dialect="spark", **kwargs)
+
+ def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
+ from sqlglot.dataframe.sql.column import Column
+
+ cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
+ expressions = [Column.ensure_col(x).expression for x in cols]
+ window_spec = self.copy()
+ partition_by_expressions = window_spec.expression.args.get("partition_by", [])
+ partition_by_expressions.extend(expressions)
+ window_spec.expression.set("partition_by", partition_by_expressions)
+ return window_spec
+
+ def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
+ from sqlglot.dataframe.sql.column import Column
+
+ cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
+ expressions = [Column.ensure_col(x).expression for x in cols]
+ window_spec = self.copy()
+ if window_spec.expression.args.get("order") is None:
+ window_spec.expression.set("order", exp.Order(expressions=[]))
+ order_by = window_spec.expression.args["order"].expressions
+ order_by.extend(expressions)
+ window_spec.expression.args["order"].set("expressions", order_by)
+ return window_spec
+
+ def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
+ kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
+ if start == Window.currentRow:
+ kwargs["start"] = "CURRENT ROW"
+ else:
+ kwargs = {
+ **kwargs,
+ **{
+ "start_side": "PRECEDING",
+ "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
+ },
+ }
+ if end == Window.currentRow:
+ kwargs["end"] = "CURRENT ROW"
+ else:
+ kwargs = {
+ **kwargs,
+ **{
+ "end_side": "FOLLOWING",
+ "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
+ },
+ }
+ return kwargs
+
+ def rowsBetween(self, start: int, end: int) -> WindowSpec:
+ window_spec = self.copy()
+ spec = self._calc_start_end(start, end)
+ spec["kind"] = "ROWS"
+ window_spec.expression.set(
+ "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+ )
+ return window_spec
+
+ def rangeBetween(self, start: int, end: int) -> WindowSpec:
+ window_spec = self.copy()
+ spec = self._calc_start_end(start, end)
+ spec["kind"] = "RANGE"
+ window_spec.expression.set(
+ "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+ )
+ return window_spec