Diffstat (limited to 'sqlglot/dataframe/sql')
-rw-r--r--  sqlglot/dataframe/sql/__init__.py   |   18
-rw-r--r--  sqlglot/dataframe/sql/_typing.pyi   |   20
-rw-r--r--  sqlglot/dataframe/sql/column.py     |  295
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py  |  730
-rw-r--r--  sqlglot/dataframe/sql/functions.py  | 1258
-rw-r--r--  sqlglot/dataframe/sql/group.py      |   57
-rw-r--r--  sqlglot/dataframe/sql/normalize.py  |   72
-rw-r--r--  sqlglot/dataframe/sql/operations.py |   53
-rw-r--r--  sqlglot/dataframe/sql/readwriter.py |   79
-rw-r--r--  sqlglot/dataframe/sql/session.py    |  148
-rw-r--r--  sqlglot/dataframe/sql/transforms.py |    9
-rw-r--r--  sqlglot/dataframe/sql/types.py      |  208
-rw-r--r--  sqlglot/dataframe/sql/util.py       |   32
-rw-r--r--  sqlglot/dataframe/sql/window.py     |  117
14 files changed, 3096 insertions(+), 0 deletions(-)
diff --git a/sqlglot/dataframe/sql/__init__.py b/sqlglot/dataframe/sql/__init__.py
new file mode 100644
index 0000000..3f90802
--- /dev/null
+++ b/sqlglot/dataframe/sql/__init__.py
@@ -0,0 +1,18 @@
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
+from sqlglot.dataframe.sql.group import GroupedData
+from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.dataframe.sql.window import Window, WindowSpec
+
+__all__ = [
+    "SparkSession",
+    "DataFrame",
+    "GroupedData",
+    "Column",
+    "DataFrameNaFunctions",
+    "Window",
+    "WindowSpec",
+    "DataFrameReader",
+    "DataFrameWriter",
+]
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi
new file mode 100644
index 0000000..f1a03ea
--- /dev/null
+++ b/sqlglot/dataframe/sql/_typing.pyi
@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+import datetime
+import typing as t
+
+from sqlglot import expressions as exp
+
+if t.TYPE_CHECKING:
+    from sqlglot.dataframe.sql.column import Column
+    from sqlglot.dataframe.sql.types import StructType
+
+ColumnLiterals = t.TypeVar(
+    "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+)
+ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
+ColumnOrLiteral = t.TypeVar(
+    "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+)
+SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
+OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])
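The stub above only declares type aliases, so a short illustration of how they are consumed may help. The helper below is hypothetical (it is not part of this commit); it simply mirrors the ColumnOrName convention used throughout functions.py: accept either a Column or a bare column name and normalize it to a sqlglot expression.

    import typing as t

    from sqlglot.dataframe.sql.column import Column

    # Hypothetical helper, for illustration only: ColumnOrName is bound to
    # t.Union[Column, str], so any API typed with it accepts either form.
    def to_expression(col: t.Union[Column, str]):
        return Column.ensure_col(col).expression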
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
new file mode 100644
index 0000000..2391080
--- /dev/null
+++ b/sqlglot/dataframe/sql/column.py
@@ -0,0 +1,295 @@
+from __future__ import annotations
+
+import typing as t
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql.types import DataType
+from sqlglot.helper import flatten
+
+if t.TYPE_CHECKING:
+    from sqlglot.dataframe.sql._typing import ColumnOrLiteral
+    from sqlglot.dataframe.sql.window import WindowSpec
+
+
+class Column:
+    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+        if isinstance(expression, Column):
+            expression = expression.expression  # type: ignore
+        elif expression is None or not isinstance(expression, (str, exp.Expression)):
+            expression = self._lit(expression).expression  # type: ignore
+        self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
+
+    def __repr__(self):
+        return repr(self.expression)
+
+    def __hash__(self):
+        return hash(self.expression)
+
+    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
+        return self.binary_op(exp.EQ, other)
+
+    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
+        return self.binary_op(exp.NEQ, other)
+
+    def __gt__(self, other: ColumnOrLiteral) -> Column:
+        return self.binary_op(exp.GT, other)
+
+    def __ge__(self, other: ColumnOrLiteral) -> Column:
+        return self.binary_op(exp.GTE, other)
+
+    def __lt__(self, other: ColumnOrLiteral) -> Column:
+        return self.binary_op(exp.LT, other)
+
+    def __le__(self, other: ColumnOrLiteral) -> Column:
+        return self.binary_op(exp.LTE, other)
+
+    def __and__(self, other: ColumnOrLiteral) -> Column:
+        return self.binary_op(exp.And, other)
+
+    def __or__(self, other: ColumnOrLiteral) -> Column:
+        return self.binary_op(exp.Or, other)
+
+    def __mod__(self, other: ColumnOrLiteral) -> Column:
+        return self.binary_op(exp.Mod, other)
+
+    def __add__(self, other: ColumnOrLiteral) -> Column:
+        return self.binary_op(exp.Add, other)
+
+    def __sub__(self, other: ColumnOrLiteral) -> Column:
+        return self.binary_op(exp.Sub, other)
+
+    def __mul__(self, other: ColumnOrLiteral) -> Column:
+        return self.binary_op(exp.Mul, other)
+
+    def __truediv__(self, other: ColumnOrLiteral) -> Column:
+        return self.binary_op(exp.Div, other)
+
+    def __div__(self, other: ColumnOrLiteral) -> Column:
+        return self.binary_op(exp.Div, other)
+
+    def __neg__(self) -> Column:
+        return self.unary_op(exp.Neg)
+
+    def __radd__(self, other: ColumnOrLiteral) -> Column:
+        return self.inverse_binary_op(exp.Add, other)
+
+    def __rsub__(self, other: ColumnOrLiteral) -> Column:
+        return self.inverse_binary_op(exp.Sub, other)
+
+    def __rmul__(self, other: ColumnOrLiteral) -> Column:
+        return self.inverse_binary_op(exp.Mul, other)
+
+    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
+        return self.inverse_binary_op(exp.Div, other)
+
+    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
+        return self.inverse_binary_op(exp.Div, other)
+
+    def __rmod__(self, other: ColumnOrLiteral) -> Column:
+        return self.inverse_binary_op(exp.Mod, other)
+
+    def __pow__(self, power: ColumnOrLiteral, modulo=None):
+        return Column(exp.Pow(this=self.expression, power=Column(power).expression))
+
+    def __rpow__(self, power: ColumnOrLiteral):
+        return Column(exp.Pow(this=Column(power).expression, power=self.expression))
+
+    def __invert__(self):
+        return self.unary_op(exp.Not)
+
+    def __rand__(self, other: ColumnOrLiteral) -> Column:
+        return self.inverse_binary_op(exp.And, other)
+
+    def __ror__(self, other: ColumnOrLiteral) -> Column:
+        return self.inverse_binary_op(exp.Or, other)
+
+    @classmethod
+    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+        return cls(value)
+
+    @classmethod
+    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
+        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
+
+    @classmethod
+    def _lit(cls, value: ColumnOrLiteral) -> Column:
+        if isinstance(value, dict):
+            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
+            return cls(exp.Struct(expressions=columns))
+        return cls(exp.convert(value))
+
+    @classmethod
+    def invoke_anonymous_function(
+        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
+    ) -> Column:
+        columns = [] if column is None else [cls.ensure_col(column)]
+        column_args = [cls.ensure_col(arg) for arg in args]
+        expressions = [x.expression for x in columns + column_args]
+        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
+        return Column(new_expression)
+
+    @classmethod
+    def invoke_expression_over_column(
+        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
+    ) -> Column:
+        ensured_column = None if column is None else cls.ensure_col(column)
+        new_expression = (
+            callable_expression(**kwargs)
+            if ensured_column is None
+            else callable_expression(this=ensured_column.column_expression, **kwargs)
+        )
+        return Column(new_expression)
+
+    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
+        return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
+
+    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
+        return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
+
+    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
+        return Column(klass(this=self.column_expression, **kwargs))
+
+    @property
+    def is_alias(self):
+        return isinstance(self.expression, exp.Alias)
+
+    @property
+    def is_column(self):
+        return isinstance(self.expression, exp.Column)
+
+    @property
+    def column_expression(self) -> exp.Column:
+        return self.expression.unalias()
+
+    @property
+    def alias_or_name(self) -> str:
+        return self.expression.alias_or_name
+
+    @classmethod
+    def ensure_literal(cls, value) -> Column:
+        from sqlglot.dataframe.sql.functions import lit
+
+        if isinstance(value, cls):
+            value = value.expression
+        if not isinstance(value, exp.Literal):
+            return lit(value)
+        return Column(value)
+
+    def copy(self) -> Column:
+        return Column(self.expression.copy())
+
+    def set_table_name(self, table_name: str, copy=False) -> Column:
+        expression = self.expression.copy() if copy else self.expression
+        expression.set("table", exp.to_identifier(table_name))
+        return Column(expression)
+
+    def sql(self, **kwargs) -> Column:
+        return self.expression.sql(**{"dialect": "spark", **kwargs})
+
+    def alias(self, name: str) -> Column:
+        new_expression = exp.alias_(self.column_expression, name)
+        return Column(new_expression)
+
+    def asc(self) -> Column:
+        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
+        return Column(new_expression)
+
+    def desc(self) -> Column:
+        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
+        return Column(new_expression)
+
+    asc_nulls_first = asc
+
+    def asc_nulls_last(self) -> Column:
+        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
+        return Column(new_expression)
+
+    def desc_nulls_first(self) -> Column:
+        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
+        return Column(new_expression)
+
+    desc_nulls_last = desc
+
+    def when(self, condition: Column, value: t.Any) -> Column:
+        from sqlglot.dataframe.sql.functions import when
+
+        column_with_if = when(condition, value)
+        if not isinstance(self.expression, exp.Case):
+            return column_with_if
+        new_column = self.copy()
+        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
+        return new_column
+
+    def otherwise(self, value: t.Any) -> Column:
+        from sqlglot.dataframe.sql.functions import lit
+
+        true_value = value if isinstance(value, Column) else lit(value)
+        new_column = self.copy()
+        new_column.expression.set("default", true_value.column_expression)
+        return new_column
+
+    def isNull(self) -> Column:
+        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
+        return Column(new_expression)
+
+    def isNotNull(self) -> Column:
+        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
+        return Column(new_expression)
+
+    def cast(self, dataType: t.Union[str, DataType]):
+        """
+        Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
+        Sqlglot doesn't currently replicate this class so it only accepts a string
+        """
+        if isinstance(dataType, DataType):
+            dataType = dataType.simpleString()
+        new_expression = exp.Cast(this=self.column_expression, to=dataType)
+        return Column(new_expression)
+
+    def startswith(self, value: t.Union[str, Column]) -> Column:
+        value = self._lit(value) if not isinstance(value, Column) else value
+        return self.invoke_anonymous_function(self, "STARTSWITH", value)
+
+    def endswith(self, value: t.Union[str, Column]) -> Column:
+        value = self._lit(value) if not isinstance(value, Column) else value
+        return self.invoke_anonymous_function(self, "ENDSWITH", value)
+
+    def rlike(self, regexp: str) -> Column:
+        return self.invoke_expression_over_column(
+            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
+        )
+
+    def like(self, other: str):
+        return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
+
+    def ilike(self, other: str):
+        return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
+
+    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
+        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
+        length = self._lit(length) if not isinstance(length, Column) else length
+        return Column.invoke_expression_over_column(
+            self, exp.Substring, start=startPos.expression, length=length.expression
+        )
+
+    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
+        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
+        expressions = [self._lit(x).expression for x in columns]
+        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
+
+    def between(
+        self,
+        lowerBound: t.Union[ColumnOrLiteral],
+        upperBound: t.Union[ColumnOrLiteral],
+    ) -> Column:
+        lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
+        upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+        return Column(
+            exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
+        )
+
+    def over(self, window: WindowSpec) -> Column:
+        window_expression = window.expression.copy()
+        window_expression.set("this", self.column_expression)
+        return Column(window_expression)
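Column is a thin wrapper around a sqlglot expression tree: the operator overloads above build exp.* nodes rather than evaluating anything, and sql() renders the tree with the Spark dialect by default. A minimal sketch of what that looks like in practice (the printed output is indicative, not taken from a test):

    from sqlglot.dataframe.sql.column import Column

    # __mul__ wraps both operands in exp.Mul; alias() wraps the result in exp.Alias.
    total = (Column("price") * Column("qty")).alias("total")
    print(total.sql())  # e.g. price * qty AS total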
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
new file mode 100644
index 0000000..322dcf2
--- /dev/null
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -0,0 +1,730 @@
+from __future__ import annotations
+
+import functools
+import typing as t
+import zlib
+from copy import copy
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.group import GroupedData
+from sqlglot.dataframe.sql.normalize import normalize
+from sqlglot.dataframe.sql.operations import Operation, operation
+from sqlglot.dataframe.sql.readwriter import DataFrameWriter
+from sqlglot.dataframe.sql.transforms import replace_id_value
+from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.dataframe.sql.window import Window
+from sqlglot.helper import ensure_list, object_to_dict
+from sqlglot.optimizer import optimize as optimize_func
+from sqlglot.optimizer.qualify_columns import qualify_columns
+
+if t.TYPE_CHECKING:
+    from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+    from sqlglot.dataframe.sql.session import SparkSession
+
+
+JOIN_HINTS = {
+    "BROADCAST",
+    "BROADCASTJOIN",
+    "MAPJOIN",
+    "MERGE",
+    "SHUFFLEMERGE",
+    "MERGEJOIN",
+    "SHUFFLE_HASH",
+    "SHUFFLE_REPLICATE_NL",
+}
+
+
+class DataFrame:
+    def __init__(
+        self,
+        spark: SparkSession,
+        expression: exp.Select,
+        branch_id: t.Optional[str] = None,
+        sequence_id: t.Optional[str] = None,
+        last_op: Operation = Operation.INIT,
+        pending_hints: t.Optional[t.List[exp.Expression]] = None,
+        output_expression_container: t.Optional[OutputExpressionContainer] = None,
+        **kwargs,
+    ):
+        self.spark = spark
+        self.expression = expression
+        self.branch_id = branch_id or self.spark._random_branch_id
+        self.sequence_id = sequence_id or self.spark._random_sequence_id
+        self.last_op = last_op
+        self.pending_hints = pending_hints or []
+        self.output_expression_container = output_expression_container or exp.Select()
+
+    def __getattr__(self, column_name: str) -> Column:
+        return self[column_name]
+
+    def __getitem__(self, column_name: str) -> Column:
+        column_name = f"{self.branch_id}.{column_name}"
+        return Column(column_name)
+
+    def __copy__(self):
+        return self.copy()
+
+    @property
+    def sparkSession(self):
+        return self.spark
+
+    @property
+    def write(self):
+        return DataFrameWriter(self)
+
+    @property
+    def latest_cte_name(self) -> str:
+        if not self.expression.ctes:
+            from_exp = self.expression.args["from"]
+            if from_exp.alias_or_name:
+                return from_exp.alias_or_name
+            table_alias = from_exp.find(exp.TableAlias)
+            if not table_alias:
+                raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+            return table_alias.alias_or_name
+        return self.expression.ctes[-1].alias
+
+    @property
+    def pending_join_hints(self):
+        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
+
+    @property
+    def pending_partition_hints(self):
+        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
+
+    @property
+    def columns(self) -> t.List[str]:
+        return self.expression.named_selects
+
+    @property
+    def na(self) -> DataFrameNaFunctions:
+        return DataFrameNaFunctions(self)
+
+    def _replace_cte_names_with_hashes(self, expression: exp.Select):
+        expression = expression.copy()
+        ctes = expression.ctes
+        replacement_mapping = {}
+        for cte in ctes:
+            old_name_id = cte.args["alias"].this
+            new_hashed_id = exp.to_identifier(
+                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
+            )
+            replacement_mapping[old_name_id] = new_hashed_id
+            cte.set("alias", exp.TableAlias(this=new_hashed_id))
+        expression = expression.transform(replace_id_value, replacement_mapping)
+        return expression
+
+    def _create_cte_from_expression(
+        self,
+        expression: exp.Expression,
+        branch_id: t.Optional[str] = None,
+        sequence_id: t.Optional[str] = None,
+        **kwargs,
+    ) -> t.Tuple[exp.CTE, str]:
+        name = self.spark._random_name
+        expression_to_cte = expression.copy()
+        expression_to_cte.set("with", None)
+        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
+        cte.set("branch_id", branch_id or self.branch_id)
+        cte.set("sequence_id", sequence_id or self.sequence_id)
+        return cte, name
+
+    def _ensure_list_of_columns(
+        self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
+    ) -> t.List[Column]:
+        columns = ensure_list(cols)
+        columns = Column.ensure_cols(columns)
+        return columns
+
+    def _ensure_and_normalize_cols(self, cols):
+        cols = self._ensure_list_of_columns(cols)
+        normalize(self.spark, self.expression, cols)
+        return cols
+
+    def _ensure_and_normalize_col(self, col):
+        col = Column.ensure_col(col)
+        normalize(self.spark, self.expression, col)
+        return col
+
+    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
+        df = self._resolve_pending_hints()
+        sequence_id = sequence_id or df.sequence_id
+        expression = df.expression.copy()
+        cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
+        new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+        sel_columns = df._get_outer_select_columns(cte_expression)
+        new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+        return df.copy(expression=new_expression, sequence_id=sequence_id)
+
+    def _resolve_pending_hints(self) -> DataFrame:
+        df = self.copy()
+        if not self.pending_hints:
+            return df
+        expression = df.expression
+        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
+        for hint in df.pending_partition_hints:
+            hint_expression.args.get("expressions").append(hint)
+            df.pending_hints.remove(hint)
+
+        join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+        if join_aliases:
+            for hint in df.pending_join_hints:
+                for sequence_id_expression in hint.expressions:
+                    sequence_id_or_name = sequence_id_expression.alias_or_name
+                    sequence_ids_to_match = [sequence_id_or_name]
+                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
+                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+                    matching_ctes = [
+                        cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+                    ]
+                    for matching_cte in matching_ctes:
+                        if matching_cte.alias_or_name in join_aliases:
+                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
+                            df.pending_hints.remove(hint)
+                            break
+                hint_expression.args.get("expressions").append(hint)
+        if hint_expression.expressions:
+            expression.set("hint", hint_expression)
+        return df
+
+    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
+        hint_name = hint_name.upper()
+        hint_expression = (
+            exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+            if hint_name in JOIN_HINTS
+            else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+        )
+        new_df = self.copy()
+        new_df.pending_hints.append(hint_expression)
+        return new_df
+
+    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
+        other_df = other._convert_leaf_to_cte()
+        base_expression = self.expression.copy()
+        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
+        all_ctes = base_expression.ctes
+        other_df.expression.set("with", None)
+        base_expression.set("with", None)
+        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
+        operation.set("with", exp.With(expressions=all_ctes))
+        return self.copy(expression=operation)._convert_leaf_to_cte()
+
+    def _cache(self, storage_level: str):
+        df = self._convert_leaf_to_cte()
+        df.expression.ctes[-1].set("cache_storage_level", storage_level)
+        return df
+
+    @classmethod
+    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
+        expression = expression.copy()
+        with_expression = expression.args.get("with")
+        if with_expression:
+            existing_ctes = with_expression.expressions
+            existsing_cte_names = {x.alias_or_name for x in existing_ctes}
+            for cte in ctes:
+                if cte.alias_or_name not in existsing_cte_names:
+                    existing_ctes.append(cte)
+        else:
+            existing_ctes = ctes
+        expression.set("with", exp.With(expressions=existing_ctes))
+        return expression
+
+    @classmethod
+    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
+        expression = item.expression if isinstance(item, DataFrame) else item
+        return [Column(x) for x in expression.find(exp.Select).expressions]
+
+    @classmethod
+    def _create_hash_from_expression(cls, expression: exp.Select):
+        value = expression.sql(dialect="spark").encode("utf-8")
+        return f"t{zlib.crc32(value)}"[:6]
+
+    def _get_select_expressions(
+        self,
+    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
+        select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
+        main_select_ctes: t.List[exp.CTE] = []
+        for cte in self.expression.ctes:
+            cache_storage_level = cte.args.get("cache_storage_level")
+            if cache_storage_level:
+                select_expression = cte.this.copy()
+                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
+                select_expression.set("cte_alias_name", cte.alias_or_name)
+                select_expression.set("cache_storage_level", cache_storage_level)
+                select_expressions.append((exp.Cache, select_expression))
+            else:
+                main_select_ctes.append(cte)
+        main_select = self.expression.copy()
+        if main_select_ctes:
+            main_select.set("with", exp.With(expressions=main_select_ctes))
+        expression_select_pair = (type(self.output_expression_container), main_select)
+        select_expressions.append(expression_select_pair)  # type: ignore
+        return select_expressions
+
+    def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
+        df = self._resolve_pending_hints()
+        select_expressions = df._get_select_expressions()
+        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
+        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
+        for expression_type, select_expression in select_expressions:
+            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
+            if optimize:
+                select_expression = optimize_func(select_expression)
+            select_expression = df._replace_cte_names_with_hashes(select_expression)
+            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
+            if expression_type == exp.Cache:
+                cache_table_name = df._create_hash_from_expression(select_expression)
+                cache_table = exp.to_table(cache_table_name)
+                original_alias_name = select_expression.args["cte_alias_name"]
+                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
+                sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
+                cache_storage_level = select_expression.args["cache_storage_level"]
+                options = [
+                    exp.Literal.string("storageLevel"),
+                    exp.Literal.string(cache_storage_level),
+                ]
+                expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
+                # We will drop the "view" if it exists before running the cache table
+                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
+            elif expression_type == exp.Create:
+                expression = df.output_expression_container.copy()
+                expression.set("expression", select_expression)
+            elif expression_type == exp.Insert:
+                expression = df.output_expression_container.copy()
+                select_without_ctes = select_expression.copy()
+                select_without_ctes.set("with", None)
+                expression.set("expression", select_without_ctes)
+                if select_expression.ctes:
+                    expression.set("with", exp.With(expressions=select_expression.ctes))
+            elif expression_type == exp.Select:
+                expression = select_expression
+            else:
+                raise ValueError(f"Invalid expression type: {expression_type}")
+            output_expressions.append(expression)
+
+        return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+
+    def copy(self, **kwargs) -> DataFrame:
+        return DataFrame(**object_to_dict(self, **kwargs))
+
+    @operation(Operation.SELECT)
+    def select(self, *cols, **kwargs) -> DataFrame:
+        cols = self._ensure_and_normalize_cols(cols)
+        kwargs["append"] = kwargs.get("append", False)
+        if self.expression.args.get("joins"):
+            ambiguous_cols = [col for col in cols if not col.column_expression.table]
+            if ambiguous_cols:
+                join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
+                cte_names_in_join = [x.this for x in join_table_identifiers]
+                for ambiguous_col in ambiguous_cols:
+                    ctes_with_column = [
+                        cte
+                        for cte in self.expression.ctes
+                        if cte.alias_or_name in cte_names_in_join
+                        and ambiguous_col.alias_or_name in cte.this.named_selects
+                    ]
+                    # If the select column does not specify a table and there is a join
+                    # then we assume they are referring to the left table
+                    if len(ctes_with_column) > 1:
+                        table_identifier = self.expression.args["from"].args["expressions"][0].this
+                    else:
+                        table_identifier = ctes_with_column[0].args["alias"].this
+                    ambiguous_col.expression.set("table", table_identifier)
+        expression = self.expression.select(*[x.expression for x in cols], **kwargs)
+        qualify_columns(expression, sqlglot.schema)
+        return self.copy(expression=expression, **kwargs)
+
+    @operation(Operation.NO_OP)
+    def alias(self, name: str, **kwargs) -> DataFrame:
+        new_sequence_id = self.spark._random_sequence_id
+        df = self.copy()
+        for join_hint in df.pending_join_hints:
+            for expression in join_hint.expressions:
+                if expression.alias_or_name == self.sequence_id:
+                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
+        df.spark._add_alias_to_mapping(name, new_sequence_id)
+        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
+
+    @operation(Operation.WHERE)
+    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
+        col = self._ensure_and_normalize_col(column)
+        return self.copy(expression=self.expression.where(col.expression))
+
+    filter = where
+
+    @operation(Operation.GROUP_BY)
+    def groupBy(self, *cols, **kwargs) -> GroupedData:
+        columns = self._ensure_and_normalize_cols(cols)
+        return GroupedData(self, columns, self.last_op)
+
+    @operation(Operation.SELECT)
+    def agg(self, *exprs, **kwargs) -> DataFrame:
+        cols = self._ensure_and_normalize_cols(exprs)
+        return self.groupBy().agg(*cols)
+
+    @operation(Operation.FROM)
+    def join(
+        self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
+    ) -> DataFrame:
+        other_df = other_df._convert_leaf_to_cte()
+        pre_join_self_latest_cte_name = self.latest_cte_name
+        columns = self._ensure_and_normalize_cols(on)
+        join_type = how.replace("_", " ")
+        if isinstance(columns[0].expression, exp.Column):
+            join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
+            join_clause = functools.reduce(
+                lambda x, y: x & y,
+                [
+                    col.copy().set_table_name(pre_join_self_latest_cte_name)
+                    == col.copy().set_table_name(other_df.latest_cte_name)
+                    for col in columns
+                ],
+            )
+        else:
+            if len(columns) > 1:
+                columns = [functools.reduce(lambda x, y: x & y, columns)]
+            join_clause = columns[0]
+            join_columns = [
+                Column(x).set_table_name(pre_join_self_latest_cte_name)
+                if i % 2 == 0
+                else Column(x).set_table_name(other_df.latest_cte_name)
+                for i, x in enumerate(join_clause.expression.find_all(exp.Column))
+            ]
+        self_columns = [
+            column.set_table_name(pre_join_self_latest_cte_name, copy=True)
+            for column in self._get_outer_select_columns(self)
+        ]
+        other_columns = [
+            column.set_table_name(other_df.latest_cte_name, copy=True)
+            for column in self._get_outer_select_columns(other_df)
+        ]
+        column_value_mapping = {
+            column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
+            for column in other_columns + self_columns + join_columns
+        }
+        all_columns = [
+            column_value_mapping[name]
+            for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
+        ]
+        new_df = self.copy(
+            expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
+        )
+        new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
+        new_df.pending_hints.extend(other_df.pending_hints)
+        new_df = new_df.select.__wrapped__(new_df, *all_columns)
+        return new_df
+
+    @operation(Operation.ORDER_BY)
+    def orderBy(
+        self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
+    ) -> DataFrame:
+        """
+        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
+        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
+        is unlikely to come up.
+        """
+        columns = self._ensure_and_normalize_cols(cols)
+        pre_ordered_col_indexes = [
+            x
+            for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
+            if x is not None
+        ]
+        if ascending is None:
+            ascending = [True] * len(columns)
+        elif not isinstance(ascending, list):
+            ascending = [ascending] * len(columns)
+        ascending = [bool(x) for i, x in enumerate(ascending)]
+        assert len(columns) == len(
+            ascending
+        ), "The length of items in ascending must equal the number of columns provided"
+        col_and_ascending = list(zip(columns, ascending))
+        order_by_columns = [
+            exp.Ordered(this=col.expression, desc=not asc)
+            if i not in pre_ordered_col_indexes
+            else columns[i].column_expression
+            for i, (col, asc) in enumerate(col_and_ascending)
+        ]
+        return self.copy(expression=self.expression.order_by(*order_by_columns))
+
+    sort = orderBy
+
+    @operation(Operation.FROM)
+    def union(self, other: DataFrame) -> DataFrame:
+        return self._set_operation(exp.Union, other, False)
+
+    unionAll = union
+
+    @operation(Operation.FROM)
+    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
+        l_columns = self.columns
+        r_columns = other.columns
+        if not allowMissingColumns:
+            l_expressions = l_columns
+            r_expressions = l_columns
+        else:
+            l_expressions = []
+            r_expressions = []
+            r_columns_unused = copy(r_columns)
+            for l_column in l_columns:
+                l_expressions.append(l_column)
+                if l_column in r_columns:
+                    r_expressions.append(l_column)
+                    r_columns_unused.remove(l_column)
+                else:
+                    r_expressions.append(exp.alias_(exp.Null(), l_column))
+            for r_column in r_columns_unused:
+                l_expressions.append(exp.alias_(exp.Null(), r_column))
+                r_expressions.append(r_column)
+        r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+        l_df = self.copy()
+        if allowMissingColumns:
+            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
+        return l_df._set_operation(exp.Union, r_df, False)
+
+    @operation(Operation.FROM)
+    def intersect(self, other: DataFrame) -> DataFrame:
+        return self._set_operation(exp.Intersect, other, True)
+
+    @operation(Operation.FROM)
+    def intersectAll(self, other: DataFrame) -> DataFrame:
+        return self._set_operation(exp.Intersect, other, False)
+
+    @operation(Operation.FROM)
+    def exceptAll(self, other: DataFrame) -> DataFrame:
+        return self._set_operation(exp.Except, other, False)
+
+    @operation(Operation.SELECT)
+    def distinct(self) -> DataFrame:
+        return self.copy(expression=self.expression.distinct())
+
+    @operation(Operation.SELECT)
+    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
+        if not subset:
+            return self.distinct()
+        column_names = ensure_list(subset)
+        window = Window.partitionBy(*column_names).orderBy(*column_names)
+        return (
+            self.copy()
+            .withColumn("row_num", F.row_number().over(window))
+            .where(F.col("row_num") == F.lit(1))
+            .drop("row_num")
+        )
+
+    @operation(Operation.FROM)
+    def dropna(
+        self,
+        how: str = "any",
+        thresh: t.Optional[int] = None,
+        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+    ) -> DataFrame:
+        minimum_non_null = thresh or 0  # will be determined later if thresh is null
+        new_df = self.copy()
+        all_columns = self._get_outer_select_columns(new_df.expression)
+        if subset:
+            null_check_columns = self._ensure_and_normalize_cols(subset)
+        else:
+            null_check_columns = all_columns
+        if thresh is None:
+            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
+        else:
+            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
+        if minimum_num_nulls > len(null_check_columns):
+            raise RuntimeError(
+                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
+                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
+            )
+        if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
+        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
+        num_nulls = nulls_added_together.alias("num_nulls")
+        new_df = new_df.select(num_nulls, append=True)
+        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
+        final_df = filtered_df.select(*all_columns)
+        return final_df
+
+    @operation(Operation.FROM)
+    def fillna(
+        self,
+        value: t.Union[ColumnLiterals],
+        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+    ) -> DataFrame:
+        """
+        Functionality Difference: If you provide a value to replace a null and that type conflicts
+        with the type of the column then PySpark will just ignore your replacement.
+        This will try to cast them to be the same in some cases. So they won't always match.
+        Best to not mix types so make sure replacement is the same type as the column
+
+        Possibility for improvement: Use `typeof` function to get the type of the column
+        and check if it matches the type of the value provided. If not then make it null.
+        """
+        from sqlglot.dataframe.sql.functions import lit
+
+        values = None
+        columns = None
+        new_df = self.copy()
+        all_columns = self._get_outer_select_columns(new_df.expression)
+        all_column_mapping = {column.alias_or_name: column for column in all_columns}
+        if isinstance(value, dict):
+            values = value.values()
+            columns = self._ensure_and_normalize_cols(list(value))
+        if not columns:
+            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
+        if not values:
+            values = [value] * len(columns)
+        value_columns = [lit(value) for value in values]
+
+        null_replacement_mapping = {
+            column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
+            for column, value in zip(columns, value_columns)
+        }
+        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
+        null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
+        new_df = new_df.select(*null_replacement_columns)
+        return new_df
+
+    @operation(Operation.FROM)
+    def replace(
+        self,
+        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
+        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
+        subset: t.Optional[t.Union[str, t.List[str]]] = None,
+    ) -> DataFrame:
+        from sqlglot.dataframe.sql.functions import lit
+
+        old_values = None
+        subset = ensure_list(subset)
+        new_df = self.copy()
+        all_columns = self._get_outer_select_columns(new_df.expression)
+        all_column_mapping = {column.alias_or_name: column for column in all_columns}
+
+        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
+        if isinstance(to_replace, dict):
+            old_values = list(to_replace)
+            new_values = list(to_replace.values())
+        elif not old_values and isinstance(to_replace, list):
+            assert isinstance(value, list), "value must be a list since the replacements are a list"
+            assert len(to_replace) == len(value), "the replacements and values must be the same length"
+            old_values = to_replace
+            new_values = value
+        else:
+            old_values = [to_replace] * len(columns)
+            new_values = [value] * len(columns)
+        old_values = [lit(value) for value in old_values]
+        new_values = [lit(value) for value in new_values]
+
+        replacement_mapping = {}
+        for column in columns:
+            expression = Column(None)
+            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
+                if i == 0:
+                    expression = F.when(column == old_value, new_value)
+                else:
+                    expression = expression.when(column == old_value, new_value)  # type: ignore
+            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
+                column.expression.alias_or_name
+            )
+
+        replacement_mapping = {**all_column_mapping, **replacement_mapping}
+        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
+        new_df = new_df.select(*replacement_columns)
+        return new_df
+
+    @operation(Operation.SELECT)
+    def withColumn(self, colName: str, col: Column) -> DataFrame:
+        col = self._ensure_and_normalize_col(col)
+        existing_col_names = self.expression.named_selects
+        existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
+        if existing_col_index:
+            expression = self.expression.copy()
+            expression.expressions[existing_col_index] = col.expression
+            return self.copy(expression=expression)
+        return self.copy().select(col.alias(colName), append=True)
+
+    @operation(Operation.SELECT)
+    def withColumnRenamed(self, existing: str, new: str):
+        expression = self.expression.copy()
+        existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
+        if not existing_columns:
+            raise ValueError("Tried to rename a column that doesn't exist")
+        for existing_column in existing_columns:
+            if isinstance(existing_column, exp.Column):
+                existing_column.replace(exp.alias_(existing_column.copy(), new))
+            else:
+                existing_column.set("alias", exp.to_identifier(new))
+        return self.copy(expression=expression)
+
+    @operation(Operation.SELECT)
+    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
+        all_columns = self._get_outer_select_columns(self.expression)
+        drop_cols = self._ensure_and_normalize_cols(cols)
+        new_columns = [
+            col
+            for col in all_columns
+            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
+        ]
+        return self.copy().select(*new_columns, append=False)
+
+    @operation(Operation.LIMIT)
+    def limit(self, num: int) -> DataFrame:
+        return self.copy(expression=self.expression.limit(num))
+
+    @operation(Operation.NO_OP)
+    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
+        parameter_list = ensure_list(parameters)
+        parameter_columns = (
+            self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
+        )
+        return self._hint(name, parameter_columns)
+
+    @operation(Operation.NO_OP)
+    def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
+        num_partitions = Column.ensure_cols(ensure_list(numPartitions))
+        columns = self._ensure_and_normalize_cols(cols)
+        args = num_partitions + columns
+        return self._hint("repartition", args)
+
+    @operation(Operation.NO_OP)
+    def coalesce(self, numPartitions: int) -> DataFrame:
+        num_partitions = Column.ensure_cols([numPartitions])
+        return self._hint("coalesce", num_partitions)
+
+    @operation(Operation.NO_OP)
+    def cache(self) -> DataFrame:
+        return self._cache(storage_level="MEMORY_AND_DISK")
+
+    @operation(Operation.NO_OP)
+    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
+        """
+        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
+        """
+        return self._cache(storageLevel)
+
+
+class DataFrameNaFunctions:
+    def __init__(self, df: DataFrame):
+        self.df = df
+
+    def drop(
+        self,
+        how: str = "any",
+        thresh: t.Optional[int] = None,
+        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+    ) -> DataFrame:
+        return self.df.dropna(how=how, thresh=thresh, subset=subset)
+
+    def fill(
+        self,
+        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
+        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
+    ) -> DataFrame:
+        return self.df.fillna(value=value, subset=subset)
+
+    def replace(
+        self,
+        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
+        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
+        subset: t.Optional[t.Union[str, t.List[str]]] = None,
+    ) -> DataFrame:
+        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
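DataFrame never touches data either: each chained call grows a CTE-based exp.Select, and sql() returns a list of statements because caching can emit DROP VIEW and CACHE TABLE statements ahead of the main query. A rough usage sketch, assuming the session.py and readwriter.py files listed in the diffstat (not shown in this section) expose the usual SparkSession().read.table entry point:

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession, functions as F

    # Register the table schema first so qualify_columns can resolve names.
    sqlglot.schema.add_table("employee", {"id": "INT", "age": "INT"})

    spark = SparkSession()
    df = spark.read.table("employee").where(F.col("age") > F.lit(18)).select("id")
    print(df.sql(dialect="spark")[0])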
t.TYPE_CHECKING: + from sqlglot.dataframe.sql._typing import ColumnOrLiteral, ColumnOrName + from sqlglot.dataframe.sql.dataframe import DataFrame + + +def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column: + return Column(column_name) + + +def lit(value: t.Optional[t.Any] = None) -> Column: + if isinstance(value, str): + return Column(glotexp.Literal.string(str(value))) + return Column(value) + + +def greatest(*cols: ColumnOrName) -> Column: + columns = [Column.ensure_col(col) for col in cols] + return Column.invoke_expression_over_column( + columns[0], glotexp.Greatest, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None + ) + + +def least(*cols: ColumnOrName) -> Column: + columns = [Column.ensure_col(col) for col in cols] + return Column.invoke_expression_over_column( + columns[0], glotexp.Least, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None + ) + + +def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: + columns = [Column.ensure_col(x) for x in [col] + list(cols)] + return Column(glotexp.Count(this=glotexp.Distinct(expressions=[x.expression for x in columns]))) + + +def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: + return count_distinct(col, *cols) + + +def when(condition: Column, value: t.Any) -> Column: + true_value = value if isinstance(value, Column) else lit(value) + return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)])) + + +def asc(col: ColumnOrName) -> Column: + return Column.ensure_col(col).asc() + + +def desc(col: ColumnOrName): + return Column.ensure_col(col).desc() + + +def broadcast(df: DataFrame) -> DataFrame: + return df.hint("broadcast") + + +def sqrt(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Sqrt) + + +def abs(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Abs) + + +def max(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Max) + + +def min(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Min) + + +def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MAX_BY", ord) + + +def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MIN_BY", ord) + + +def count(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Count) + + +def sum(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Sum) + + +def avg(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Avg) + + +def mean(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MEAN") + + +def sumDistinct(col: ColumnOrName) -> Column: + return sum_distinct(col) + + +def sum_distinct(col: ColumnOrName) -> Column: + raise NotImplementedError("Sum distinct is not currently implemented") + + +def product(col: ColumnOrName) -> Column: + raise NotImplementedError("Product is not currently implemented") + + +def acos(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ACOS") + + +def acosh(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ACOSH") + + +def asin(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ASIN") + + +def asinh(col: ColumnOrName) -> Column: + return 
Column.invoke_anonymous_function(col, "ASINH") + + +def atan(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ATAN") + + +def atan2(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: + return Column.invoke_anonymous_function(col1, "ATAN2", col2) + + +def atanh(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ATANH") + + +def cbrt(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "CBRT") + + +def ceil(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Ceil) + + +def cos(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "COS") + + +def cosh(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "COSH") + + +def cot(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "COT") + + +def csc(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "CSC") + + +def exp(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Exp) + + +def expm1(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "EXPM1") + + +def floor(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Floor) + + +def log10(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Log10) + + +def log1p(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "LOG1P") + + +def log2(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Log2) + + +def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column: + if arg2 is None: + return Column.invoke_expression_over_column(arg1, glotexp.Ln) + return Column.invoke_expression_over_column(arg1, glotexp.Log, expression=Column.ensure_col(arg2).expression) + + +def rint(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "RINT") + + +def sec(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SEC") + + +def signum(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SIGNUM") + + +def sin(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SIN") + + +def sinh(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SINH") + + +def tan(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "TAN") + + +def tanh(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "TANH") + + +def toDegrees(col: ColumnOrName) -> Column: + return degrees(col) + + +def degrees(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "DEGREES") + + +def toRadians(col: ColumnOrName) -> Column: + return radians(col) + + +def radians(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "RADIANS") + + +def bitwiseNOT(col: ColumnOrName) -> Column: + return bitwise_not(col) + + +def bitwise_not(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.BitwiseNot) + + +def asc_nulls_first(col: ColumnOrName) -> Column: + return Column.ensure_col(col).asc_nulls_first() + + +def asc_nulls_last(col: ColumnOrName) -> Column: + return Column.ensure_col(col).asc_nulls_last() + + +def desc_nulls_first(col: ColumnOrName) -> Column: + return Column.ensure_col(col).desc_nulls_first() + + +def desc_nulls_last(col: ColumnOrName) -> Column: + 
return Column.ensure_col(col).desc_nulls_last() + + +def stddev(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Stddev) + + +def stddev_samp(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.StddevSamp) + + +def stddev_pop(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.StddevPop) + + +def variance(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Variance) + + +def var_samp(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Variance) + + +def var_pop(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.VariancePop) + + +def skewness(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SKEWNESS") + + +def kurtosis(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "KURTOSIS") + + +def collect_list(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.ArrayAgg) + + +def collect_set(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.SetAgg) + + +def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: + return Column.invoke_anonymous_function(col1, "HYPOT", col2) + + +def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: + return Column.invoke_anonymous_function(col1, "POW", col2) + + +def row_number() -> Column: + return Column(glotexp.Anonymous(this="ROW_NUMBER")) + + +def dense_rank() -> Column: + return Column(glotexp.Anonymous(this="DENSE_RANK")) + + +def rank() -> Column: + return Column(glotexp.Anonymous(this="RANK")) + + +def cume_dist() -> Column: + return Column(glotexp.Anonymous(this="CUME_DIST")) + + +def percent_rank() -> Column: + return Column(glotexp.Anonymous(this="PERCENT_RANK")) + + +def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: + return approx_count_distinct(col, rsd) + + +def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: + if rsd is None: + return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct) + return Column.invoke_expression_over_column(col, glotexp.ApproxDistinct, accuracy=Column.ensure_col(rsd).expression) + + +def coalesce(*cols: ColumnOrName) -> Column: + columns = [Column.ensure_col(col) for col in cols] + return Column.invoke_expression_over_column( + columns[0], glotexp.Coalesce, expressions=[col.expression for col in columns[1:]] if len(columns) > 1 else None + ) + + +def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "CORR", col2) + + +def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "COVAR_POP", col2) + + +def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2) + + +def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: + if ignorenulls is not None: + return Column.invoke_anonymous_function(col, "FIRST", ignorenulls) + return Column.invoke_anonymous_function(col, "FIRST") + + +def grouping_id(*cols: ColumnOrName) -> Column: + if not cols: + return Column.invoke_anonymous_function(None, "GROUPING_ID") + if len(cols) == 1: + return Column.invoke_anonymous_function(cols[0], "GROUPING_ID") + return 
Column.invoke_anonymous_function(cols[0], "GROUPING_ID", *cols[1:]) + + +def input_file_name() -> Column: + return Column.invoke_anonymous_function(None, "INPUT_FILE_NAME") + + +def isnan(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ISNAN") + + +def isnull(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ISNULL") + + +def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: + if ignorenulls is not None: + return Column.invoke_anonymous_function(col, "LAST", ignorenulls) + return Column.invoke_anonymous_function(col, "LAST") + + +def monotonically_increasing_id() -> Column: + return Column.invoke_anonymous_function(None, "MONOTONICALLY_INCREASING_ID") + + +def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "NANVL", col2) + + +def percentile_approx( + col: ColumnOrName, + percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]], + accuracy: t.Optional[t.Union[ColumnOrLiteral]] = None, +) -> Column: + if accuracy: + return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage, accuracy) + return Column.invoke_anonymous_function(col, "PERCENTILE_APPROX", percentage) + + +def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column: + return Column.invoke_anonymous_function(seed, "RAND") + + +def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column: + return Column.invoke_anonymous_function(seed, "RANDN") + + +def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: + if scale is not None: + return Column.invoke_expression_over_column(col, glotexp.Round, decimals=glotexp.convert(scale)) + return Column.invoke_expression_over_column(col, glotexp.Round) + + +def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: + if scale is not None: + return Column.invoke_anonymous_function(col, "BROUND", scale) + return Column.invoke_anonymous_function(col, "BROUND") + + +def shiftleft(col: ColumnOrName, numBits: int) -> Column: + return Column.invoke_expression_over_column( + col, glotexp.BitwiseLeftShift, expression=Column.ensure_col(numBits).expression + ) + + +def shiftLeft(col: ColumnOrName, numBits: int) -> Column: + return shiftleft(col, numBits) + + +def shiftright(col: ColumnOrName, numBits: int) -> Column: + return Column.invoke_expression_over_column( + col, glotexp.BitwiseRightShift, expression=Column.ensure_col(numBits).expression + ) + + +def shiftRight(col: ColumnOrName, numBits: int) -> Column: + return shiftright(col, numBits) + + +def shiftrightunsigned(col: ColumnOrName, numBits: int) -> Column: + return Column.invoke_anonymous_function(col, "SHIFTRIGHTUNSIGNED", numBits) + + +def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column: + return shiftrightunsigned(col, numBits) + + +def expr(str: str) -> Column: + return Column(str) + + +def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column: + columns = ensure_list(col) + list(cols) + expressions = [Column.ensure_col(column).expression for column in columns] + return Column(glotexp.Struct(expressions=expressions)) + + +def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column: + return Column.invoke_anonymous_function(col, "CONV", fromBase, toBase) + + +def factorial(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "FACTORIAL") + + +def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column: + if default is 
not None:
+        return Column.invoke_anonymous_function(col, "LAG", offset, default)
+    if offset != 1:
+        return Column.invoke_anonymous_function(col, "LAG", offset)
+    return Column.invoke_anonymous_function(col, "LAG")
+
+
+def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
+    if default is not None:
+        return Column.invoke_anonymous_function(col, "LEAD", offset, default)
+    if offset != 1:
+        return Column.invoke_anonymous_function(col, "LEAD", offset)
+    return Column.invoke_anonymous_function(col, "LEAD")
+
+
+def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
+    if ignoreNulls is not None:
+        raise NotImplementedError("There is currently no support for the `ignoreNulls` parameter")
+    if offset != 1:
+        return Column.invoke_anonymous_function(col, "NTH_VALUE", offset)
+    return Column.invoke_anonymous_function(col, "NTH_VALUE")
+
+
+def ntile(n: int) -> Column:
+    return Column.invoke_anonymous_function(None, "NTILE", n)
+
+
+def current_date() -> Column:
+    return Column.invoke_expression_over_column(None, glotexp.CurrentDate)
+
+
+def current_timestamp() -> Column:
+    return Column.invoke_expression_over_column(None, glotexp.CurrentTimestamp)
+
+
+def date_format(col: ColumnOrName, format: str) -> Column:
+    return Column.invoke_anonymous_function(col, "DATE_FORMAT", lit(format))
+
+
+def year(col: ColumnOrName) -> Column:
+    return Column.invoke_expression_over_column(col, glotexp.Year)
+
+
+def quarter(col: ColumnOrName) -> Column:
+    return Column.invoke_anonymous_function(col, "QUARTER")
+
+
+def month(col: ColumnOrName) -> Column:
+    return Column.invoke_expression_over_column(col, glotexp.Month)
+
+
+def dayofweek(col: ColumnOrName) -> Column:
+    return Column.invoke_anonymous_function(col, "DAYOFWEEK")
+
+
+def dayofmonth(col: ColumnOrName) -> Column:
+    return Column.invoke_anonymous_function(col, "DAYOFMONTH")
+
+
+def dayofyear(col: ColumnOrName) -> Column:
+    return Column.invoke_anonymous_function(col, "DAYOFYEAR")
+
+
+def hour(col: ColumnOrName) -> Column:
+    return Column.invoke_anonymous_function(col, "HOUR")
+
+
+def minute(col: ColumnOrName) -> Column:
+    return Column.invoke_anonymous_function(col, "MINUTE")
+
+
+def second(col: ColumnOrName) -> Column:
+    return Column.invoke_anonymous_function(col, "SECOND")
+
+
+def weekofyear(col: ColumnOrName) -> Column:
+    return Column.invoke_anonymous_function(col, "WEEKOFYEAR")
+
+
+def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
+    return Column.invoke_anonymous_function(year, "MAKE_DATE", month, day)
+
+
+def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
+    return Column.invoke_expression_over_column(col, glotexp.DateAdd, expression=Column.ensure_col(days).expression)
+
+
+def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
+    return Column.invoke_expression_over_column(col, glotexp.DateSub, expression=Column.ensure_col(days).expression)
+
+
+def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
+    return Column.invoke_expression_over_column(end, glotexp.DateDiff, expression=Column.ensure_col(start).expression)
+
+
+def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
+    return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)
+
+
+def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
+    if roundOff is None:
+        return Column.invoke_anonymous_function(date1,
"MONTHS_BETWEEN", date2) + return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff) + + +def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: + if format is not None: + return Column.invoke_anonymous_function(col, "TO_DATE", lit(format)) + return Column.invoke_anonymous_function(col, "TO_DATE") + + +def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: + if format is not None: + return Column.invoke_anonymous_function(col, "TO_TIMESTAMP", lit(format)) + return Column.invoke_anonymous_function(col, "TO_TIMESTAMP") + + +def trunc(col: ColumnOrName, format: str) -> Column: + return Column.invoke_expression_over_column(col, glotexp.DateTrunc, unit=lit(format).expression) + + +def date_trunc(format: str, timestamp: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(timestamp, glotexp.TimestampTrunc, unit=lit(format).expression) + + +def next_day(col: ColumnOrName, dayOfWeek: str) -> Column: + return Column.invoke_anonymous_function(col, "NEXT_DAY", lit(dayOfWeek)) + + +def last_day(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "LAST_DAY") + + +def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column: + if format is not None: + return Column.invoke_anonymous_function(col, "FROM_UNIXTIME", lit(format)) + return Column.invoke_anonymous_function(col, "FROM_UNIXTIME") + + +def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column: + if format is not None: + return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP", lit(format)) + return Column.invoke_anonymous_function(timestamp, "UNIX_TIMESTAMP") + + +def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: + tz_column = tz if isinstance(tz, Column) else lit(tz) + return Column.invoke_anonymous_function(timestamp, "FROM_UTC_TIMESTAMP", tz_column) + + +def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: + tz_column = tz if isinstance(tz, Column) else lit(tz) + return Column.invoke_anonymous_function(timestamp, "TO_UTC_TIMESTAMP", tz_column) + + +def timestamp_seconds(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "TIMESTAMP_SECONDS") + + +def window( + timeColumn: ColumnOrName, + windowDuration: str, + slideDuration: t.Optional[str] = None, + startTime: t.Optional[str] = None, +) -> Column: + if slideDuration is not None and startTime is not None: + return Column.invoke_anonymous_function( + timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime) + ) + if slideDuration is not None: + return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)) + if startTime is not None: + return Column.invoke_anonymous_function( + timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime) + ) + return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration)) + + +def session_window(timeColumn: ColumnOrName, gapDuration: ColumnOrName) -> Column: + gap_duration_column = gapDuration if isinstance(gapDuration, Column) else lit(gapDuration) + return Column.invoke_anonymous_function(timeColumn, "SESSION_WINDOW", gap_duration_column) + + +def crc32(col: ColumnOrName) -> Column: + column = col if isinstance(col, Column) else lit(col) + return Column.invoke_anonymous_function(column, "CRC32") + + +def md5(col: ColumnOrName) -> Column: + column = col if isinstance(col, Column) else lit(col) + return 
Column.invoke_anonymous_function(column, "MD5") + + +def sha1(col: ColumnOrName) -> Column: + column = col if isinstance(col, Column) else lit(col) + return Column.invoke_anonymous_function(column, "SHA1") + + +def sha2(col: ColumnOrName, numBits: int) -> Column: + column = col if isinstance(col, Column) else lit(col) + num_bits = lit(numBits) + return Column.invoke_anonymous_function(column, "SHA2", num_bits) + + +def hash(*cols: ColumnOrName) -> Column: + args = cols[1:] if len(cols) > 1 else [] + return Column.invoke_anonymous_function(cols[0], "HASH", *args) + + +def xxhash64(*cols: ColumnOrName) -> Column: + args = cols[1:] if len(cols) > 1 else [] + return Column.invoke_anonymous_function(cols[0], "XXHASH64", *args) + + +def assert_true(col: ColumnOrName, errorMsg: t.Optional[ColumnOrName] = None) -> Column: + if errorMsg is not None: + error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg) + return Column.invoke_anonymous_function(col, "ASSERT_TRUE", error_msg_col) + return Column.invoke_anonymous_function(col, "ASSERT_TRUE") + + +def raise_error(errorMsg: ColumnOrName) -> Column: + error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg) + return Column.invoke_anonymous_function(error_msg_col, "RAISE_ERROR") + + +def upper(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Upper) + + +def lower(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Lower) + + +def ascii(col: ColumnOrLiteral) -> Column: + return Column.invoke_anonymous_function(col, "ASCII") + + +def base64(col: ColumnOrLiteral) -> Column: + return Column.invoke_anonymous_function(col, "BASE64") + + +def unbase64(col: ColumnOrLiteral) -> Column: + return Column.invoke_anonymous_function(col, "UNBASE64") + + +def ltrim(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "LTRIM") + + +def rtrim(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "RTRIM") + + +def trim(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Trim) + + +def concat_ws(sep: str, *cols: ColumnOrName) -> Column: + columns = [Column(col) for col in cols] + return Column.invoke_expression_over_column( + None, glotexp.ConcatWs, expressions=[x.expression for x in [lit(sep)] + list(columns)] + ) + + +def decode(col: ColumnOrName, charset: str) -> Column: + return Column.invoke_anonymous_function(col, "DECODE", lit(charset)) + + +def encode(col: ColumnOrName, charset: str) -> Column: + return Column.invoke_anonymous_function(col, "ENCODE", lit(charset)) + + +def format_number(col: ColumnOrName, d: int) -> Column: + return Column.invoke_anonymous_function(col, "FORMAT_NUMBER", lit(d)) + + +def format_string(format: str, *cols: ColumnOrName) -> Column: + format_col = lit(format) + columns = [Column.ensure_col(x) for x in cols] + return Column.invoke_anonymous_function(format_col, "FORMAT_STRING", *columns) + + +def instr(col: ColumnOrName, substr: str) -> Column: + return Column.invoke_anonymous_function(col, "INSTR", lit(substr)) + + +def overlay( + src: ColumnOrName, + replace: ColumnOrName, + pos: t.Union[ColumnOrName, int], + len: t.Optional[t.Union[ColumnOrName, int]] = None, +) -> Column: + if len is not None: + return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos, len) + return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos) + + +def sentences( + string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: 
t.Optional[ColumnOrName] = None +) -> Column: + if language is not None and country is not None: + return Column.invoke_anonymous_function(string, "SENTENCES", language, country) + if language is not None: + return Column.invoke_anonymous_function(string, "SENTENCES", language) + if country is not None: + return Column.invoke_anonymous_function(string, "SENTENCES", lit("en"), country) + return Column.invoke_anonymous_function(string, "SENTENCES") + + +def substring(str: ColumnOrName, pos: int, len: int) -> Column: + return Column.ensure_col(str).substr(pos, len) + + +def substring_index(str: ColumnOrName, delim: str, count: int) -> Column: + return Column.invoke_anonymous_function(str, "SUBSTRING_INDEX", lit(delim), lit(count)) + + +def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column: + return Column.invoke_expression_over_column( + left, glotexp.Levenshtein, expression=Column.ensure_col(right).expression + ) + + +def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column: + substr_col = lit(substr) + pos_column = lit(pos) + str_column = Column.ensure_col(str) + if pos is not None: + return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column, pos_column) + return Column.invoke_anonymous_function(substr_col, "LOCATE", str_column) + + +def lpad(col: ColumnOrName, len: int, pad: str) -> Column: + return Column.invoke_anonymous_function(col, "LPAD", lit(len), lit(pad)) + + +def rpad(col: ColumnOrName, len: int, pad: str) -> Column: + return Column.invoke_anonymous_function(col, "RPAD", lit(len), lit(pad)) + + +def repeat(col: ColumnOrName, n: int) -> Column: + return Column.invoke_anonymous_function(col, "REPEAT", n) + + +def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column: + if limit is not None: + return Column.invoke_expression_over_column( + str, glotexp.RegexpSplit, expression=lit(pattern).expression, limit=lit(limit).expression + ) + return Column.invoke_expression_over_column(str, glotexp.RegexpSplit, expression=lit(pattern).expression) + + +def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column: + if idx is not None: + return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern), idx) + return Column.invoke_anonymous_function(str, "REGEXP_EXTRACT", lit(pattern)) + + +def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column: + return Column.invoke_anonymous_function(str, "REGEXP_REPLACE", lit(pattern), lit(replacement)) + + +def initcap(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Initcap) + + +def soundex(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SOUNDEX") + + +def bin(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "BIN") + + +def hex(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "HEX") + + +def unhex(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "UNHEX") + + +def length(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Length) + + +def octet_length(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "OCTET_LENGTH") + + +def bit_length(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "BIT_LENGTH") + + +def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column: + return Column.invoke_anonymous_function(srcCol, "TRANSLATE", lit(matching), lit(replace)) + + 
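These string helpers only build sqlglot expression trees; nothing is executed until the tree is rendered as SQL. A minimal sketch of the round trip, with invented column names:

from sqlglot.dataframe.sql import functions as F

col = F.concat_ws("-", F.lower("city"), F.trim("state"))
# Rendering the underlying expression produces the Spark SQL text.
print(col.expression.sql(dialect="spark"))  # CONCAT_WS('-', LOWER(city), TRIM(state))
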
+def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: + cols = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols # type: ignore + cols = [Column.ensure_col(col).expression for col in cols] # type: ignore + return Column.invoke_expression_over_column(None, glotexp.Array, expressions=cols) + + +def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: + cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore + return Column.invoke_expression_over_column( + None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression + ) + + +def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "MAP_FROM_ARRAYS", col2) + + +def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column: + value_col = value if isinstance(value, Column) else lit(value) + return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression) + + +def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2)) + + +def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column: + start_col = start if isinstance(start, Column) else lit(start) + length_col = length if isinstance(length, Column) else lit(length) + return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col) + + +def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column: + if null_replacement is not None: + return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)) + return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter)) + + +def concat(*cols: ColumnOrName) -> Column: + if len(cols) == 1: + return Column.invoke_anonymous_function(cols[0], "CONCAT") + return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]) + + +def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column: + value_col = value if isinstance(value, Column) else lit(value) + return Column.invoke_anonymous_function(col, "ARRAY_POSITION", value_col) + + +def element_at(col: ColumnOrName, value: ColumnOrLiteral) -> Column: + value_col = value if isinstance(value, Column) else lit(value) + return Column.invoke_anonymous_function(col, "ELEMENT_AT", value_col) + + +def array_remove(col: ColumnOrName, value: ColumnOrLiteral) -> Column: + value_col = value if isinstance(value, Column) else lit(value) + return Column.invoke_anonymous_function(col, "ARRAY_REMOVE", value_col) + + +def array_distinct(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ARRAY_DISTINCT") + + +def array_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "ARRAY_INTERSECT", Column.ensure_col(col2)) + + +def array_union(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "ARRAY_UNION", Column.ensure_col(col2)) + + +def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col1, "ARRAY_EXCEPT", Column.ensure_col(col2)) + + +def explode(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.Explode) + + +def posexplode(col: ColumnOrName) -> Column: + return 
Column.invoke_expression_over_column(col, glotexp.Posexplode) + + +def explode_outer(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "EXPLODE_OUTER") + + +def posexplode_outer(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "POSEXPLODE_OUTER") + + +def get_json_object(col: ColumnOrName, path: str) -> Column: + return Column.invoke_expression_over_column(col, glotexp.JSONExtract, path=lit(path).expression) + + +def json_tuple(col: ColumnOrName, *fields: str) -> Column: + return Column.invoke_anonymous_function(col, "JSON_TUPLE", *[lit(field) for field in fields]) + + +def from_json( + col: ColumnOrName, + schema: t.Union[Column, str], + options: t.Optional[t.Dict[str, str]] = None, +) -> Column: + schema = schema if isinstance(schema, Column) else lit(schema) + if options is not None: + options_col = create_map([lit(x) for x in _flatten(options.items())]) + return Column.invoke_anonymous_function(col, "FROM_JSON", schema, options_col) + return Column.invoke_anonymous_function(col, "FROM_JSON", schema) + + +def to_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: + if options is not None: + options_col = create_map([lit(x) for x in _flatten(options.items())]) + return Column.invoke_anonymous_function(col, "TO_JSON", options_col) + return Column.invoke_anonymous_function(col, "TO_JSON") + + +def schema_of_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: + if options is not None: + options_col = create_map([lit(x) for x in _flatten(options.items())]) + return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON", options_col) + return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON") + + +def schema_of_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: + if options is not None: + options_col = create_map([lit(x) for x in _flatten(options.items())]) + return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV", options_col) + return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV") + + +def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: + if options is not None: + options_col = create_map([lit(x) for x in _flatten(options.items())]) + return Column.invoke_anonymous_function(col, "TO_CSV", options_col) + return Column.invoke_anonymous_function(col, "TO_CSV") + + +def size(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.ArraySize) + + +def array_min(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ARRAY_MIN") + + +def array_max(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "ARRAY_MAX") + + +def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column: + if asc is not None: + return Column.invoke_anonymous_function(col, "SORT_ARRAY", lit(asc)) + return Column.invoke_anonymous_function(col, "SORT_ARRAY") + + +def array_sort(col: ColumnOrName) -> Column: + return Column.invoke_expression_over_column(col, glotexp.ArraySort) + + +def shuffle(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "SHUFFLE") + + +def reverse(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "REVERSE") + + +def flatten(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "FLATTEN") + + +def map_keys(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MAP_KEYS") + + +def map_values(col: ColumnOrName) -> Column: + 
return Column.invoke_anonymous_function(col, "MAP_VALUES") + + +def map_entries(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MAP_ENTRIES") + + +def map_from_entries(col: ColumnOrName) -> Column: + return Column.invoke_anonymous_function(col, "MAP_FROM_ENTRIES") + + +def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column: + count_col = count if isinstance(count, Column) else lit(count) + return Column.invoke_anonymous_function(col, "ARRAY_REPEAT", count_col) + + +def array_zip(*cols: ColumnOrName) -> Column: + if len(cols) == 1: + return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP") + return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP", *cols[1:]) + + +def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: + columns = list(flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore + if len(columns) == 1: + return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT") + return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:]) + + +def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column: + if step is not None: + return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step) + return Column.invoke_anonymous_function(start, "SEQUENCE", stop) + + +def from_csv( + col: ColumnOrName, + schema: t.Union[Column, str], + options: t.Optional[t.Dict[str, str]] = None, +) -> Column: + schema = schema if isinstance(schema, Column) else lit(schema) + if options is not None: + option_cols = create_map([lit(x) for x in _flatten(options.items())]) + return Column.invoke_anonymous_function(col, "FROM_CSV", schema, option_cols) + return Column.invoke_anonymous_function(col, "FROM_CSV", schema) + + +def aggregate( + col: ColumnOrName, + initialValue: ColumnOrName, + merge: t.Callable[[Column, Column], Column], + finish: t.Optional[t.Callable[[Column], Column]] = None, + accumulator_name: str = "acc", + target_row_name: str = "x", +) -> Column: + merge_exp = glotexp.Lambda( + this=merge(Column(accumulator_name), Column(target_row_name)).expression, + expressions=[ + glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name)), + glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name)), + ], + ) + if finish is not None: + finish_exp = glotexp.Lambda( + this=finish(Column(accumulator_name)).expression, + expressions=[glotexp.to_identifier(accumulator_name, quoted=_lambda_quoted(accumulator_name))], + ) + return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)) + return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp)) + + +def transform( + col: ColumnOrName, + f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], + target_row_name: str = "x", + row_count_name: str = "i", +) -> Column: + num_arguments = len(signature(f).parameters) + expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))] + columns = [Column(target_row_name)] + if num_arguments > 1: + columns.append(Column(row_count_name)) + expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name))) + + f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions) + return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression)) + + +def exists(col: ColumnOrName, f: t.Callable[[Column], Column], 
target_row_name: str = "x") -> Column: + f_expression = glotexp.Lambda( + this=f(Column(target_row_name)).expression, + expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))], + ) + return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression)) + + +def forall(col: ColumnOrName, f: t.Callable[[Column], Column], target_row_name: str = "x") -> Column: + f_expression = glotexp.Lambda( + this=f(Column(target_row_name)).expression, + expressions=[glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))], + ) + + return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression)) + + +def filter( + col: ColumnOrName, + f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], + target_row_name: str = "x", + row_count_name: str = "i", +) -> Column: + num_arguments = len(signature(f).parameters) + expressions = [glotexp.to_identifier(target_row_name, quoted=_lambda_quoted(target_row_name))] + columns = [Column(target_row_name)] + if num_arguments > 1: + columns.append(Column(row_count_name)) + expressions.append(glotexp.to_identifier(row_count_name, quoted=_lambda_quoted(row_count_name))) + + f_expression = glotexp.Lambda(this=f(*columns).expression, expressions=expressions) + return Column.invoke_anonymous_function(col, "FILTER", Column(f_expression)) + + +def zip_with( + left: ColumnOrName, + right: ColumnOrName, + f: t.Callable[[Column, Column], Column], + left_name: str = "x", + right_name: str = "y", +) -> Column: + f_expression = glotexp.Lambda( + this=f(Column(left_name), Column(right_name)).expression, + expressions=[ + glotexp.to_identifier(left_name, quoted=_lambda_quoted(left_name)), + glotexp.to_identifier(right_name, quoted=_lambda_quoted(right_name)), + ], + ) + + return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression)) + + +def transform_keys( + col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v" +) -> Column: + f_expression = glotexp.Lambda( + this=f(Column(key_name), Column(value_name)).expression, + expressions=[ + glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)), + glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)), + ], + ) + return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression)) + + +def transform_values( + col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v" +) -> Column: + f_expression = glotexp.Lambda( + this=f(Column(key_name), Column(value_name)).expression, + expressions=[ + glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)), + glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)), + ], + ) + return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression)) + + +def map_filter( + col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]], key_name: str = "k", value_name: str = "v" +) -> Column: + f_expression = glotexp.Lambda( + this=f(Column(key_name), Column(value_name)).expression, + expressions=[ + glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)), + glotexp.to_identifier(value_name, quoted=_lambda_quoted(value_name)), + ], + ) + return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression)) + + +def map_zip_with( + col1: ColumnOrName, + col2: ColumnOrName, + f: t.Union[t.Callable[[Column, Column, Column], Column]], + key_name: str = "k", + value1: str = "v1", + 
value2: str = "v2", +) -> Column: + f_expression = glotexp.Lambda( + this=f(Column(key_name), Column(value1), Column(value2)).expression, + expressions=[ + glotexp.to_identifier(key_name, quoted=_lambda_quoted(key_name)), + glotexp.to_identifier(value1, quoted=_lambda_quoted(value1)), + glotexp.to_identifier(value2, quoted=_lambda_quoted(value2)), + ], + ) + return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression)) + + +def _lambda_quoted(value: str) -> t.Optional[bool]: + return False if value == "_" else None diff --git a/sqlglot/dataframe/sql/group.py b/sqlglot/dataframe/sql/group.py new file mode 100644 index 0000000..947aace --- /dev/null +++ b/sqlglot/dataframe/sql/group.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import typing as t + +from sqlglot.dataframe.sql import functions as F +from sqlglot.dataframe.sql.column import Column +from sqlglot.dataframe.sql.operations import Operation, operation + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql.dataframe import DataFrame + + +class GroupedData: + def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation): + self._df = df.copy() + self.spark = df.spark + self.last_op = last_op + self.group_by_cols = group_by_cols + + def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]: + func_name = func_name.lower() + return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols] + + @operation(Operation.SELECT) + def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: + columns = ( + [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] + if isinstance(exprs[0], dict) + else exprs + ) + cols = self._df._ensure_and_normalize_cols(columns) + + expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select( + *[x.expression for x in self.group_by_cols + cols], append=False + ) + return self._df.copy(expression=expression) + + def count(self) -> DataFrame: + return self.agg(F.count("*").alias("count")) + + def mean(self, *cols: str) -> DataFrame: + return self.avg(*cols) + + def avg(self, *cols: str) -> DataFrame: + return self.agg(*self._get_function_applied_columns("avg", cols)) + + def max(self, *cols: str) -> DataFrame: + return self.agg(*self._get_function_applied_columns("max", cols)) + + def min(self, *cols: str) -> DataFrame: + return self.agg(*self._get_function_applied_columns("min", cols)) + + def sum(self, *cols: str) -> DataFrame: + return self.agg(*self._get_function_applied_columns("sum", cols)) + + def pivot(self, *cols: str) -> DataFrame: + raise NotImplementedError("Sum distinct is not currently implemented") diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py new file mode 100644 index 0000000..1513946 --- /dev/null +++ b/sqlglot/dataframe/sql/normalize.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import expressions as exp +from sqlglot.dataframe.sql.column import Column +from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join +from sqlglot.helper import ensure_list + +NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column]) + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql.session import SparkSession + + +def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]): + expr = ensure_list(expr) + expressions = _ensure_expressions(expr) + for 
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
new file mode 100644
index 0000000..1513946
--- /dev/null
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql.column import Column
+from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.helper import ensure_list
+
+NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
+
+if t.TYPE_CHECKING:
+    from sqlglot.dataframe.sql.session import SparkSession
+
+
+def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
+    expr = ensure_list(expr)
+    expressions = _ensure_expressions(expr)
+    for expression in expressions:
+        identifiers = expression.find_all(exp.Identifier)
+        for identifier in identifiers:
+            replace_alias_name_with_cte_name(spark, expression_context, identifier)
+            replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
+
+
+def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
+    if id.alias_or_name in spark.name_to_sequence_id_mapping:
+        for cte in reversed(expression_context.ctes):
+            if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
+                _set_alias_name(id, cte.alias_or_name)
+                break
+
+
+def replace_branch_and_sequence_ids_with_cte_name(
+    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
+):
+    if id.alias_or_name in spark.known_ids:
+        # Check if we have a join and whether both tables in that join share a common branch id.
+        # If so, this reference defaults to the left table, unless the id is a sequence id, in
+        # which case it keeps that reference. This handles a weird edge case in Spark that should
+        # be uncommon in practice.
+        if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
+            join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
+            ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
+            assert len(ctes_in_join) == 2
+            if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
+                _set_alias_name(id, ctes_in_join[0].alias_or_name)
+                return
+
+        for cte in reversed(expression_context.ctes):
+            if id.alias_or_name in (cte.args["branch_id"], cte.args["sequence_id"]):
+                _set_alias_name(id, cte.alias_or_name)
+                return
+
+
+def _set_alias_name(id: exp.Identifier, name: str):
+    id.set("this", name)
+
+
+def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
+    values = ensure_list(values)
+    results = []
+    for value in values:
+        if isinstance(value, str):
+            results.append(Column.ensure_col(value).expression)
+        elif isinstance(value, Column):
+            results.append(value.expression)
+        elif isinstance(value, exp.Expression):
+            results.append(value)
+        else:
+            raise ValueError(f"Got an invalid type to normalize: {type(value)}")
+    return results
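The in-place rewrite in _set_alias_name is what lets normalization retarget a stale alias at a CTE name without rebuilding the tree. A minimal sketch of the same mutation (the names t1 and a1 are invented):

from sqlglot import expressions as exp

ident = exp.to_identifier("t1")
ident.set("this", "a1")  # what _set_alias_name does under the hood
assert ident.sql() == "a1"
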
+ """ + + def decorator(func: t.Callable): + @functools.wraps(func) + def wrapper(self: DataFrame, *args, **kwargs): + if self.last_op == Operation.INIT: + self = self._convert_leaf_to_cte() + self.last_op = Operation.NO_OP + last_op = self.last_op + new_op = op if op != Operation.NO_OP else last_op + if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT): + self = self._convert_leaf_to_cte() + df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs) + df.last_op = new_op # type: ignore + return df + + wrapper.__wrapped__ = func # type: ignore + return wrapper + + return decorator diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py new file mode 100644 index 0000000..4830035 --- /dev/null +++ b/sqlglot/dataframe/sql/readwriter.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import typing as t + +import sqlglot +from sqlglot import expressions as exp +from sqlglot.helper import object_to_dict + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql.dataframe import DataFrame + from sqlglot.dataframe.sql.session import SparkSession + + +class DataFrameReader: + def __init__(self, spark: SparkSession): + self.spark = spark + + def table(self, tableName: str) -> DataFrame: + from sqlglot.dataframe.sql.dataframe import DataFrame + + sqlglot.schema.add_table(tableName) + return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName))) + + +class DataFrameWriter: + def __init__( + self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False + ): + self._df = df + self._spark = spark or df.spark + self._mode = mode + self._by_name = by_name + + def copy(self, **kwargs) -> DataFrameWriter: + return DataFrameWriter( + **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()} + ) + + def sql(self, **kwargs) -> t.List[str]: + return self._df.sql(**kwargs) + + def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter: + return self.copy(_mode=saveMode) + + @property + def byName(self): + return self.copy(by_name=True) + + def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: + output_expression_container = exp.Insert( + **{ + "this": exp.to_table(tableName), + "overwrite": overwrite, + } + ) + df = self._df.copy(output_expression_container=output_expression_container) + if self._by_name: + columns = sqlglot.schema.column_names(tableName, only_visible=True) + df = df._convert_leaf_to_cte().select(*columns) + + return self.copy(_df=df) + + def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None): + if format is not None: + raise NotImplementedError("Providing Format in the save as table is not supported") + exists, replace, mode = None, None, mode or str(self._mode) + if mode == "append": + return self.insertInto(name) + if mode == "ignore": + exists = True + if mode == "overwrite": + replace = True + output_expression_container = exp.Create( + this=exp.to_table(name), + kind="TABLE", + exists=exists, + replace=replace, + ) + return self.copy(_df=self._df.copy(output_expression_container=output_expression_container)) diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py new file mode 100644 index 0000000..1ea86d1 --- /dev/null +++ b/sqlglot/dataframe/sql/session.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import typing as t +import uuid +from collections import defaultdict + +import 
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
new file mode 100644
index 0000000..4830035
--- /dev/null
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+import typing as t
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.helper import object_to_dict
+
+if t.TYPE_CHECKING:
+    from sqlglot.dataframe.sql.dataframe import DataFrame
+    from sqlglot.dataframe.sql.session import SparkSession
+
+
+class DataFrameReader:
+    def __init__(self, spark: SparkSession):
+        self.spark = spark
+
+    def table(self, tableName: str) -> DataFrame:
+        from sqlglot.dataframe.sql.dataframe import DataFrame
+
+        sqlglot.schema.add_table(tableName)
+        return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
+
+
+class DataFrameWriter:
+    def __init__(
+        self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
+    ):
+        self._df = df
+        self._spark = spark or df.spark
+        self._mode = mode
+        self._by_name = by_name
+
+    def copy(self, **kwargs) -> DataFrameWriter:
+        return DataFrameWriter(
+            **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
+        )
+
+    def sql(self, **kwargs) -> t.List[str]:
+        return self._df.sql(**kwargs)
+
+    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
+        return self.copy(_mode=saveMode)
+
+    @property
+    def byName(self):
+        return self.copy(by_name=True)
+
+    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
+        output_expression_container = exp.Insert(
+            **{
+                "this": exp.to_table(tableName),
+                "overwrite": overwrite,
+            }
+        )
+        df = self._df.copy(output_expression_container=output_expression_container)
+        if self._by_name:
+            columns = sqlglot.schema.column_names(tableName, only_visible=True)
+            df = df._convert_leaf_to_cte().select(*columns)
+
+        return self.copy(_df=df)
+
+    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
+        if format is not None:
+            raise NotImplementedError("Providing a format in saveAsTable is not supported")
+        exists, replace, mode = None, None, mode or str(self._mode)
+        if mode == "append":
+            return self.insertInto(name)
+        if mode == "ignore":
+            exists = True
+        if mode == "overwrite":
+            replace = True
+        output_expression_container = exp.Create(
+            this=exp.to_table(name),
+            kind="TABLE",
+            exists=exists,
+            replace=replace,
+        )
+        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
new file mode 100644
index 0000000..1ea86d1
--- /dev/null
+++ b/sqlglot/dataframe/sql/session.py
@@ -0,0 +1,148 @@
+from __future__ import annotations
+
+import typing as t
+import uuid
+from collections import defaultdict
+
+import sqlglot
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.dataframe import DataFrame
+from sqlglot.dataframe.sql.readwriter import DataFrameReader
+from sqlglot.dataframe.sql.types import StructType
+from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
+
+if t.TYPE_CHECKING:
+    from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
+
+
+class SparkSession:
+    known_ids: t.ClassVar[t.Set[str]] = set()
+    known_branch_ids: t.ClassVar[t.Set[str]] = set()
+    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
+    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)
+
+    def __init__(self):
+        self.incrementing_id = 1
+
+    def __getattr__(self, name: str) -> SparkSession:
+        return self
+
+    def __call__(self, *args, **kwargs) -> SparkSession:
+        return self
+
+    @property
+    def read(self) -> DataFrameReader:
+        return DataFrameReader(self)
+
+    def table(self, tableName: str) -> DataFrame:
+        return self.read.table(tableName)
+
+    def createDataFrame(
+        self,
+        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
+        schema: t.Optional[SchemaInput] = None,
+        samplingRatio: t.Optional[float] = None,
+        verifySchema: bool = False,
+    ) -> DataFrame:
+        from sqlglot.dataframe.sql.dataframe import DataFrame
+
+        if samplingRatio is not None or verifySchema:
+            raise NotImplementedError("The samplingRatio and verifySchema arguments are not supported")
+        if schema is not None and (
+            not isinstance(schema, (StructType, str, list))
+            or (isinstance(schema, list) and not isinstance(schema[0], str))
+        ):
+            raise NotImplementedError("Only a StructType, a string, or a list of strings is supported for schema")
+        if not data:
+            raise ValueError("Must provide data to create a DataFrame")
+
+        column_mapping: t.Dict[str, t.Optional[str]]
+        if schema is not None:
+            column_mapping = get_column_mapping_from_schema_input(schema)
+        elif isinstance(data[0], dict):
+            column_mapping = {col_name.strip(): None for col_name in data[0]}
+        else:
+            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}
+
+        data_expressions = [
+            exp.Tuple(
+                expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
+            )
+            for row in data
+        ]
+
+        sel_columns = [
+            F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
+            for name, data_type in column_mapping.items()
+        ]
+
+        select_kwargs = {
+            "expressions": sel_columns,
+            "from": exp.From(
+                expressions=[
+                    exp.Subquery(
+                        this=exp.Values(expressions=data_expressions),
+                        alias=exp.TableAlias(
+                            this=exp.to_identifier(self._auto_incrementing_name),
+                            columns=[exp.to_identifier(col_name) for col_name in column_mapping],
+                        ),
+                    )
+                ]
+            ),
+        }
+
+        sel_expression = exp.Select(**select_kwargs)
+        return DataFrame(self, sel_expression)
+
+    def sql(self, sqlQuery: str) -> DataFrame:
+        expression = sqlglot.parse_one(sqlQuery, read="spark")
+        if isinstance(expression, exp.Select):
+            df = DataFrame(self, expression)
+            df = df._convert_leaf_to_cte()
+        elif isinstance(expression, (exp.Create, exp.Insert)):
+            select_expression = expression.expression.copy()
+            if isinstance(expression, exp.Insert):
+                select_expression.set("with", expression.args.get("with"))
+                expression.set("with", None)
+            del expression.args["expression"]
+            df = DataFrame(self, select_expression, output_expression_container=expression)
+            df = df._convert_leaf_to_cte()
+        else:
+            raise ValueError("Unknown
expression type provided in the SQL. Please create an issue with the SQL.") + return df + + @property + def _auto_incrementing_name(self) -> str: + name = f"a{self.incrementing_id}" + self.incrementing_id += 1 + return name + + @property + def _random_name(self) -> str: + return f"a{str(uuid.uuid4())[:8]}" + + @property + def _random_branch_id(self) -> str: + id = self._random_id + self.known_branch_ids.add(id) + return id + + @property + def _random_sequence_id(self): + id = self._random_id + self.known_sequence_ids.add(id) + return id + + @property + def _random_id(self) -> str: + id = f"a{str(uuid.uuid4())[:8]}" + self.known_ids.add(id) + return id + + @property + def _join_hint_names(self) -> t.Set[str]: + return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"} + + def _add_alias_to_mapping(self, name: str, sequence_id: str): + self.name_to_sequence_id_mapping[name].append(sequence_id) diff --git a/sqlglot/dataframe/sql/transforms.py b/sqlglot/dataframe/sql/transforms.py new file mode 100644 index 0000000..b3dcc12 --- /dev/null +++ b/sqlglot/dataframe/sql/transforms.py @@ -0,0 +1,9 @@ +import typing as t + +from sqlglot import expressions as exp + + +def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]): + if isinstance(node, exp.Identifier) and node in replacement_mapping: + node = node.replace(replacement_mapping[node].copy()) + return node diff --git a/sqlglot/dataframe/sql/types.py b/sqlglot/dataframe/sql/types.py new file mode 100644 index 0000000..dc5c05a --- /dev/null +++ b/sqlglot/dataframe/sql/types.py @@ -0,0 +1,208 @@ +import typing as t + + +class DataType: + def __repr__(self) -> str: + return self.__class__.__name__ + "()" + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: t.Any) -> bool: + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other: t.Any) -> bool: + return not self.__eq__(other) + + def __str__(self) -> str: + return self.typeName() + + @classmethod + def typeName(cls) -> str: + return cls.__name__[:-4].lower() + + def simpleString(self) -> str: + return str(self) + + def jsonValue(self) -> t.Union[str, t.Dict[str, t.Any]]: + return str(self) + + +class DataTypeWithLength(DataType): + def __init__(self, length: int): + self.length = length + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.length})" + + def __str__(self) -> str: + return f"{self.typeName()}({self.length})" + + +class StringType(DataType): + pass + + +class CharType(DataTypeWithLength): + pass + + +class VarcharType(DataTypeWithLength): + pass + + +class BinaryType(DataType): + pass + + +class BooleanType(DataType): + pass + + +class DateType(DataType): + pass + + +class TimestampType(DataType): + pass + + +class TimestampNTZType(DataType): + @classmethod + def typeName(cls) -> str: + return "timestamp_ntz" + + +class DecimalType(DataType): + def __init__(self, precision: int = 10, scale: int = 0): + self.precision = precision + self.scale = scale + + def simpleString(self) -> str: + return f"decimal({self.precision}, {self.scale})" + + def jsonValue(self) -> str: + return f"decimal({self.precision}, {self.scale})" + + def __repr__(self) -> str: + return f"DecimalType({self.precision}, {self.scale})" + + +class DoubleType(DataType): + pass + + +class FloatType(DataType): + pass + + +class ByteType(DataType): + def __str__(self) -> str: + return "tinyint" + + +class IntegerType(DataType): + def __str__(self) -> str: + return "int" + + +class 
LongType(DataType):
+    def __str__(self) -> str:
+        return "bigint"
+
+
+class ShortType(DataType):
+    def __str__(self) -> str:
+        return "smallint"
+
+
+class ArrayType(DataType):
+    def __init__(self, elementType: DataType, containsNull: bool = True):
+        self.elementType = elementType
+        self.containsNull = containsNull
+
+    def __repr__(self) -> str:
+        return f"ArrayType({self.elementType}, {str(self.containsNull)})"
+
+    def simpleString(self) -> str:
+        return f"array<{self.elementType.simpleString()}>"
+
+    def jsonValue(self) -> t.Dict[str, t.Any]:
+        return {
+            "type": self.typeName(),
+            "elementType": self.elementType.jsonValue(),
+            "containsNull": self.containsNull,
+        }
+
+
+class MapType(DataType):
+    def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True):
+        self.keyType = keyType
+        self.valueType = valueType
+        self.valueContainsNull = valueContainsNull
+
+    def __repr__(self) -> str:
+        return f"MapType({self.keyType}, {self.valueType}, {str(self.valueContainsNull)})"
+
+    def simpleString(self) -> str:
+        return f"map<{self.keyType.simpleString()}, {self.valueType.simpleString()}>"
+
+    def jsonValue(self) -> t.Dict[str, t.Any]:
+        return {
+            "type": self.typeName(),
+            "keyType": self.keyType.jsonValue(),
+            "valueType": self.valueType.jsonValue(),
+            "valueContainsNull": self.valueContainsNull,
+        }
+
+
+class StructField(DataType):
+    def __init__(
+        self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
+    ):
+        self.name = name
+        self.dataType = dataType
+        self.nullable = nullable
+        self.metadata = metadata or {}
+
+    def __repr__(self) -> str:
+        return f"StructField('{self.name}', {self.dataType}, {str(self.nullable)})"
+
+    def simpleString(self) -> str:
+        return f"{self.name}:{self.dataType.simpleString()}"
+
+    def jsonValue(self) -> t.Dict[str, t.Any]:
+        return {
+            "name": self.name,
+            "type": self.dataType.jsonValue(),
+            "nullable": self.nullable,
+            "metadata": self.metadata,
+        }
+
+
+class StructType(DataType):
+    def __init__(self, fields: t.Optional[t.List[StructField]] = None):
+        if not fields:
+            self.fields = []
+            self.names = []
+        else:
+            self.fields = fields
+            self.names = [f.name for f in fields]
+
+    def __iter__(self) -> t.Iterator[StructField]:
+        return iter(self.fields)
+
+    def __len__(self) -> int:
+        return len(self.fields)
+
+    def __repr__(self) -> str:
+        return f"StructType({', '.join(str(field) for field in self)})"
+
+    def simpleString(self) -> str:
+        return f"struct<{', '.join(x.simpleString() for x in self)}>"
+
+    def jsonValue(self) -> t.Dict[str, t.Any]:
+        return {"type": self.typeName(), "fields": [x.jsonValue() for x in self]}
+
+    def fieldNames(self) -> t.List[str]:
+        return list(self.names)
diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py
new file mode 100644
index 0000000..575d18a
--- /dev/null
+++ b/sqlglot/dataframe/sql/util.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import expressions as exp
+from sqlglot.dataframe.sql import types
+
+if t.TYPE_CHECKING:
+    from sqlglot.dataframe.sql._typing import SchemaInput
+
+
+def get_column_mapping_from_schema_input(schema: SchemaInput) -> t.Dict[str, t.Optional[str]]:
+    if isinstance(schema, dict):
+        return schema
+    elif isinstance(schema, str):
+        col_name_type_strs = [x.strip() for x in schema.split(",")]
+        return {
+            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
+            for name_type_str in col_name_type_strs
+        }
+    elif isinstance(schema,
types.StructType): + return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema} + return {x.strip(): None for x in schema} # type: ignore + + +def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]: + if not expression.args.get("joins"): + return [] + + left_table = expression.args["from"].args["expressions"][0] + other_tables = [join.this for join in expression.args["joins"]] + return [left_table] + other_tables diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py new file mode 100644 index 0000000..842f366 --- /dev/null +++ b/sqlglot/dataframe/sql/window.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import sys +import typing as t + +from sqlglot import expressions as exp +from sqlglot.dataframe.sql import functions as F +from sqlglot.helper import flatten + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql._typing import ColumnOrName + + +class Window: + _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808 + _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807 + _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG) + _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG) + + unboundedPreceding: int = _JAVA_MIN_LONG + + unboundedFollowing: int = _JAVA_MAX_LONG + + currentRow: int = 0 + + @classmethod + def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: + return WindowSpec().partitionBy(*cols) + + @classmethod + def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: + return WindowSpec().orderBy(*cols) + + @classmethod + def rowsBetween(cls, start: int, end: int) -> WindowSpec: + return WindowSpec().rowsBetween(start, end) + + @classmethod + def rangeBetween(cls, start: int, end: int) -> WindowSpec: + return WindowSpec().rangeBetween(start, end) + + +class WindowSpec: + def __init__(self, expression: exp.Expression = exp.Window()): + self.expression = expression + + def copy(self): + return WindowSpec(self.expression.copy()) + + def sql(self, **kwargs) -> str: + return self.expression.sql(dialect="spark", **kwargs) + + def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: + from sqlglot.dataframe.sql.column import Column + + cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore + expressions = [Column.ensure_col(x).expression for x in cols] + window_spec = self.copy() + partition_by_expressions = window_spec.expression.args.get("partition_by", []) + partition_by_expressions.extend(expressions) + window_spec.expression.set("partition_by", partition_by_expressions) + return window_spec + + def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: + from sqlglot.dataframe.sql.column import Column + + cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore + expressions = [Column.ensure_col(x).expression for x in cols] + window_spec = self.copy() + if window_spec.expression.args.get("order") is None: + window_spec.expression.set("order", exp.Order(expressions=[])) + order_by = window_spec.expression.args["order"].expressions + order_by.extend(expressions) + window_spec.expression.args["order"].set("expressions", order_by) + return window_spec + + def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]: + kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None} + if start == Window.currentRow: + 
kwargs["start"] = "CURRENT ROW" + else: + kwargs = { + **kwargs, + **{ + "start_side": "PRECEDING", + "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression, + }, + } + if end == Window.currentRow: + kwargs["end"] = "CURRENT ROW" + else: + kwargs = { + **kwargs, + **{ + "end_side": "FOLLOWING", + "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression, + }, + } + return kwargs + + def rowsBetween(self, start: int, end: int) -> WindowSpec: + window_spec = self.copy() + spec = self._calc_start_end(start, end) + spec["kind"] = "ROWS" + window_spec.expression.set( + "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}) + ) + return window_spec + + def rangeBetween(self, start: int, end: int) -> WindowSpec: + window_spec = self.copy() + spec = self._calc_start_end(start, end) + spec["kind"] = "RANGE" + window_spec.expression.set( + "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}) + ) + return window_spec |