diff options
Diffstat (limited to 'sqlglot/dataframe/sql')
-rw-r--r-- | sqlglot/dataframe/sql/__init__.py | 18 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/_typing.py | 18 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/column.py | 342 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/dataframe.py | 862 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 1270 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/group.py | 59 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/normalize.py | 78 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/operations.py | 53 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/readwriter.py | 108 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/session.py | 199 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/transforms.py | 9 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/types.py | 212 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/util.py | 32 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/window.py | 136 |
14 files changed, 0 insertions, 3396 deletions
diff --git a/sqlglot/dataframe/sql/__init__.py b/sqlglot/dataframe/sql/__init__.py deleted file mode 100644 index 3f90802..0000000 --- a/sqlglot/dataframe/sql/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from sqlglot.dataframe.sql.column import Column -from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions -from sqlglot.dataframe.sql.group import GroupedData -from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter -from sqlglot.dataframe.sql.session import SparkSession -from sqlglot.dataframe.sql.window import Window, WindowSpec - -__all__ = [ - "SparkSession", - "DataFrame", - "GroupedData", - "Column", - "DataFrameNaFunctions", - "Window", - "WindowSpec", - "DataFrameReader", - "DataFrameWriter", -] diff --git a/sqlglot/dataframe/sql/_typing.py b/sqlglot/dataframe/sql/_typing.py deleted file mode 100644 index fb46026..0000000 --- a/sqlglot/dataframe/sql/_typing.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -import datetime -import typing as t - -from sqlglot import expressions as exp - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql.column import Column - from sqlglot.dataframe.sql.types import StructType - -ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] -ColumnOrName = t.Union[Column, str] -ColumnOrLiteral = t.Union[ - Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime -] -SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]] -OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert] diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py deleted file mode 100644 index 724c5bf..0000000 --- a/sqlglot/dataframe/sql/column.py +++ /dev/null @@ -1,342 +0,0 @@ -from __future__ import annotations - -import typing as t - -import sqlglot -from sqlglot import expressions as exp -from sqlglot.dataframe.sql.types import DataType -from sqlglot.helper import flatten, is_iterable - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import ColumnOrLiteral - from sqlglot.dataframe.sql.window import WindowSpec - - -class Column: - def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): - from sqlglot.dataframe.sql.session import SparkSession - - if isinstance(expression, Column): - expression = expression.expression # type: ignore - elif expression is None or not isinstance(expression, (str, exp.Expression)): - expression = self._lit(expression).expression # type: ignore - elif not isinstance(expression, exp.Column): - expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform( - SparkSession().dialect.normalize_identifier, copy=False - ) - if expression is None: - raise ValueError(f"Could not parse {expression}") - - self.expression: exp.Expression = expression # type: ignore - - def __repr__(self): - return repr(self.expression) - - def __hash__(self): - return hash(self.expression) - - def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore - return self.binary_op(exp.EQ, other) - - def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore - return self.binary_op(exp.NEQ, other) - - def __gt__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.GT, other) - - def __ge__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.GTE, other) - - def __lt__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.LT, other) - - def __le__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.LTE, other) - - def __and__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.And, other) - - def __or__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Or, other) - - def __mod__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Mod, other) - - def __add__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Add, other) - - def __sub__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Sub, other) - - def __mul__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Mul, other) - - def __truediv__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Div, other) - - def __div__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Div, other) - - def __neg__(self) -> Column: - return self.unary_op(exp.Neg) - - def __radd__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Add, other) - - def __rsub__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Sub, other) - - def __rmul__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Mul, other) - - def __rdiv__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Div, other) - - def __rtruediv__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Div, other) - - def __rmod__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Mod, other) - - def __pow__(self, power: ColumnOrLiteral, modulo=None): - return Column(exp.Pow(this=self.expression, expression=Column(power).expression)) - - def __rpow__(self, power: ColumnOrLiteral): - return Column(exp.Pow(this=Column(power).expression, expression=self.expression)) - - def __invert__(self): - return self.unary_op(exp.Not) - - def __rand__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.And, other) - - def __ror__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Or, other) - - @classmethod - def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column: - return cls(value) - - @classmethod - def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]: - return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args] - - @classmethod - def _lit(cls, value: ColumnOrLiteral) -> Column: - if isinstance(value, dict): - columns = [cls._lit(v).alias(k).expression for k, v in value.items()] - return cls(exp.Struct(expressions=columns)) - return cls(exp.convert(value)) - - @classmethod - def invoke_anonymous_function( - cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral] - ) -> Column: - columns = [] if column is None else [cls.ensure_col(column)] - column_args = [cls.ensure_col(arg) for arg in args] - expressions = [x.expression for x in columns + column_args] - new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions) - return Column(new_expression) - - @classmethod - def invoke_expression_over_column( - cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs - ) -> Column: - ensured_column = None if column is None else cls.ensure_col(column) - ensure_expression_values = { - k: ( - [Column.ensure_col(x).expression for x in v] - if is_iterable(v) - else Column.ensure_col(v).expression - ) - for k, v in kwargs.items() - if v is not None - } - new_expression = ( - callable_expression(**ensure_expression_values) - if ensured_column is None - else callable_expression( - this=ensured_column.column_expression, **ensure_expression_values - ) - ) - return Column(new_expression) - - def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: - return Column( - klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs) - ) - - def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: - return Column( - klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs) - ) - - def unary_op(self, klass: t.Callable, **kwargs) -> Column: - return Column(klass(this=self.column_expression, **kwargs)) - - @property - def is_alias(self): - return isinstance(self.expression, exp.Alias) - - @property - def is_column(self): - return isinstance(self.expression, exp.Column) - - @property - def column_expression(self) -> t.Union[exp.Column, exp.Literal]: - return self.expression.unalias() - - @property - def alias_or_name(self) -> str: - return self.expression.alias_or_name - - @classmethod - def ensure_literal(cls, value) -> Column: - from sqlglot.dataframe.sql.functions import lit - - if isinstance(value, cls): - value = value.expression - if not isinstance(value, exp.Literal): - return lit(value) - return Column(value) - - def copy(self) -> Column: - return Column(self.expression.copy()) - - def set_table_name(self, table_name: str, copy=False) -> Column: - expression = self.expression.copy() if copy else self.expression - expression.set("table", exp.to_identifier(table_name)) - return Column(expression) - - def sql(self, **kwargs) -> str: - from sqlglot.dataframe.sql.session import SparkSession - - return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) - - def alias(self, name: str) -> Column: - from sqlglot.dataframe.sql.session import SparkSession - - dialect = SparkSession().dialect - alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect) - new_expression = exp.alias_( - self.column_expression, - alias.this if isinstance(alias, exp.Column) else name, - dialect=dialect, - ) - return Column(new_expression) - - def asc(self) -> Column: - new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True) - return Column(new_expression) - - def desc(self) -> Column: - new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False) - return Column(new_expression) - - asc_nulls_first = asc - - def asc_nulls_last(self) -> Column: - new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False) - return Column(new_expression) - - def desc_nulls_first(self) -> Column: - new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True) - return Column(new_expression) - - desc_nulls_last = desc - - def when(self, condition: Column, value: t.Any) -> Column: - from sqlglot.dataframe.sql.functions import when - - column_with_if = when(condition, value) - if not isinstance(self.expression, exp.Case): - return column_with_if - new_column = self.copy() - new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"]) - return new_column - - def otherwise(self, value: t.Any) -> Column: - from sqlglot.dataframe.sql.functions import lit - - true_value = value if isinstance(value, Column) else lit(value) - new_column = self.copy() - new_column.expression.set("default", true_value.column_expression) - return new_column - - def isNull(self) -> Column: - new_expression = exp.Is(this=self.column_expression, expression=exp.Null()) - return Column(new_expression) - - def isNotNull(self) -> Column: - new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null())) - return Column(new_expression) - - def cast(self, dataType: t.Union[str, DataType]) -> Column: - """ - Functionality Difference: PySpark cast accepts a datatype instance of the datatype class - Sqlglot doesn't currently replicate this class so it only accepts a string - """ - from sqlglot.dataframe.sql.session import SparkSession - - if isinstance(dataType, DataType): - dataType = dataType.simpleString() - return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect)) - - def startswith(self, value: t.Union[str, Column]) -> Column: - value = self._lit(value) if not isinstance(value, Column) else value - return self.invoke_anonymous_function(self, "STARTSWITH", value) - - def endswith(self, value: t.Union[str, Column]) -> Column: - value = self._lit(value) if not isinstance(value, Column) else value - return self.invoke_anonymous_function(self, "ENDSWITH", value) - - def rlike(self, regexp: str) -> Column: - return self.invoke_expression_over_column( - column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression - ) - - def like(self, other: str): - return self.invoke_expression_over_column( - self, exp.Like, expression=self._lit(other).expression - ) - - def ilike(self, other: str): - return self.invoke_expression_over_column( - self, exp.ILike, expression=self._lit(other).expression - ) - - def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column: - startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos - length = self._lit(length) if not isinstance(length, Column) else length - return Column.invoke_expression_over_column( - self, exp.Substring, start=startPos.expression, length=length.expression - ) - - def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]): - columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore - expressions = [self._lit(x).expression for x in columns] - return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore - - def between( - self, - lowerBound: t.Union[ColumnOrLiteral], - upperBound: t.Union[ColumnOrLiteral], - ) -> Column: - lower_bound_exp = ( - self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound - ) - upper_bound_exp = ( - self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound - ) - return Column( - exp.Between( - this=self.column_expression, - low=lower_bound_exp.expression, - high=upper_bound_exp.expression, - ) - ) - - def over(self, window: WindowSpec) -> Column: - window_expression = window.expression.copy() - window_expression.set("this", self.column_expression) - return Column(window_expression) diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py deleted file mode 100644 index 8316c36..0000000 --- a/sqlglot/dataframe/sql/dataframe.py +++ /dev/null @@ -1,862 +0,0 @@ -from __future__ import annotations - -import functools -import logging -import typing as t -import zlib -from copy import copy - -import sqlglot -from sqlglot import Dialect, expressions as exp -from sqlglot.dataframe.sql import functions as F -from sqlglot.dataframe.sql.column import Column -from sqlglot.dataframe.sql.group import GroupedData -from sqlglot.dataframe.sql.normalize import normalize -from sqlglot.dataframe.sql.operations import Operation, operation -from sqlglot.dataframe.sql.readwriter import DataFrameWriter -from sqlglot.dataframe.sql.transforms import replace_id_value -from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join -from sqlglot.dataframe.sql.window import Window -from sqlglot.helper import ensure_list, object_to_dict, seq_get - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import ( - ColumnLiterals, - ColumnOrLiteral, - ColumnOrName, - OutputExpressionContainer, - ) - from sqlglot.dataframe.sql.session import SparkSession - from sqlglot.dialects.dialect import DialectType - -logger = logging.getLogger("sqlglot") - -JOIN_HINTS = { - "BROADCAST", - "BROADCASTJOIN", - "MAPJOIN", - "MERGE", - "SHUFFLEMERGE", - "MERGEJOIN", - "SHUFFLE_HASH", - "SHUFFLE_REPLICATE_NL", -} - - -class DataFrame: - def __init__( - self, - spark: SparkSession, - expression: exp.Select, - branch_id: t.Optional[str] = None, - sequence_id: t.Optional[str] = None, - last_op: Operation = Operation.INIT, - pending_hints: t.Optional[t.List[exp.Expression]] = None, - output_expression_container: t.Optional[OutputExpressionContainer] = None, - **kwargs, - ): - self.spark = spark - self.expression = expression - self.branch_id = branch_id or self.spark._random_branch_id - self.sequence_id = sequence_id or self.spark._random_sequence_id - self.last_op = last_op - self.pending_hints = pending_hints or [] - self.output_expression_container = output_expression_container or exp.Select() - - def __getattr__(self, column_name: str) -> Column: - return self[column_name] - - def __getitem__(self, column_name: str) -> Column: - column_name = f"{self.branch_id}.{column_name}" - return Column(column_name) - - def __copy__(self): - return self.copy() - - @property - def sparkSession(self): - return self.spark - - @property - def write(self): - return DataFrameWriter(self) - - @property - def latest_cte_name(self) -> str: - if not self.expression.ctes: - from_exp = self.expression.args["from"] - if from_exp.alias_or_name: - return from_exp.alias_or_name - table_alias = from_exp.find(exp.TableAlias) - if not table_alias: - raise RuntimeError( - f"Could not find an alias name for this expression: {self.expression}" - ) - return table_alias.alias_or_name - return self.expression.ctes[-1].alias - - @property - def pending_join_hints(self): - return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)] - - @property - def pending_partition_hints(self): - return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)] - - @property - def columns(self) -> t.List[str]: - return self.expression.named_selects - - @property - def na(self) -> DataFrameNaFunctions: - return DataFrameNaFunctions(self) - - def _replace_cte_names_with_hashes(self, expression: exp.Select): - replacement_mapping = {} - for cte in expression.ctes: - old_name_id = cte.args["alias"].this - new_hashed_id = exp.to_identifier( - self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"] - ) - replacement_mapping[old_name_id] = new_hashed_id - expression = expression.transform(replace_id_value, replacement_mapping).assert_is( - exp.Select - ) - return expression - - def _create_cte_from_expression( - self, - expression: exp.Expression, - branch_id: t.Optional[str] = None, - sequence_id: t.Optional[str] = None, - **kwargs, - ) -> t.Tuple[exp.CTE, str]: - name = self._create_hash_from_expression(expression) - expression_to_cte = expression.copy() - expression_to_cte.set("with", None) - cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0] - cte.set("branch_id", branch_id or self.branch_id) - cte.set("sequence_id", sequence_id or self.sequence_id) - return cte, name - - @t.overload - def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ... - - @t.overload - def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ... - - def _ensure_list_of_columns(self, cols): - return Column.ensure_cols(ensure_list(cols)) - - def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None): - cols = self._ensure_list_of_columns(cols) - normalize(self.spark, expression or self.expression, cols) - return cols - - def _ensure_and_normalize_col(self, col): - col = Column.ensure_col(col) - normalize(self.spark, self.expression, col) - return col - - def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame: - df = self._resolve_pending_hints() - sequence_id = sequence_id or df.sequence_id - expression = df.expression.copy() - cte_expression, cte_name = df._create_cte_from_expression( - expression=expression, sequence_id=sequence_id - ) - new_expression = df._add_ctes_to_expression( - exp.Select(), expression.ctes + [cte_expression] - ) - sel_columns = df._get_outer_select_columns(cte_expression) - new_expression = new_expression.from_(cte_name).select( - *[x.alias_or_name for x in sel_columns] - ) - return df.copy(expression=new_expression, sequence_id=sequence_id) - - def _resolve_pending_hints(self) -> DataFrame: - df = self.copy() - if not self.pending_hints: - return df - expression = df.expression - hint_expression = expression.args.get("hint") or exp.Hint(expressions=[]) - for hint in df.pending_partition_hints: - hint_expression.append("expressions", hint) - df.pending_hints.remove(hint) - - join_aliases = { - join_table.alias_or_name - for join_table in get_tables_from_expression_with_join(expression) - } - if join_aliases: - for hint in df.pending_join_hints: - for sequence_id_expression in hint.expressions: - sequence_id_or_name = sequence_id_expression.alias_or_name - sequence_ids_to_match = [sequence_id_or_name] - if sequence_id_or_name in df.spark.name_to_sequence_id_mapping: - sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[ - sequence_id_or_name - ] - matching_ctes = [ - cte - for cte in reversed(expression.ctes) - if cte.args["sequence_id"] in sequence_ids_to_match - ] - for matching_cte in matching_ctes: - if matching_cte.alias_or_name in join_aliases: - sequence_id_expression.set("this", matching_cte.args["alias"].this) - df.pending_hints.remove(hint) - break - hint_expression.append("expressions", hint) - if hint_expression.expressions: - expression.set("hint", hint_expression) - return df - - def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame: - hint_name = hint_name.upper() - hint_expression = ( - exp.JoinHint( - this=hint_name, - expressions=[exp.to_table(parameter.alias_or_name) for parameter in args], - ) - if hint_name in JOIN_HINTS - else exp.Anonymous( - this=hint_name, expressions=[parameter.expression for parameter in args] - ) - ) - new_df = self.copy() - new_df.pending_hints.append(hint_expression) - return new_df - - def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool): - other_df = other._convert_leaf_to_cte() - base_expression = self.expression.copy() - base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes) - all_ctes = base_expression.ctes - other_df.expression.set("with", None) - base_expression.set("with", None) - operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression) - operation.set("with", exp.With(expressions=all_ctes)) - return self.copy(expression=operation)._convert_leaf_to_cte() - - def _cache(self, storage_level: str): - df = self._convert_leaf_to_cte() - df.expression.ctes[-1].set("cache_storage_level", storage_level) - return df - - @classmethod - def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select: - expression = expression.copy() - with_expression = expression.args.get("with") - if with_expression: - existing_ctes = with_expression.expressions - existsing_cte_names = {x.alias_or_name for x in existing_ctes} - for cte in ctes: - if cte.alias_or_name not in existsing_cte_names: - existing_ctes.append(cte) - else: - existing_ctes = ctes - expression.set("with", exp.With(expressions=existing_ctes)) - return expression - - @classmethod - def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]: - expression = item.expression if isinstance(item, DataFrame) else item - return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions] - - @classmethod - def _create_hash_from_expression(cls, expression: exp.Expression) -> str: - from sqlglot.dataframe.sql.session import SparkSession - - value = expression.sql(dialect=SparkSession().dialect).encode("utf-8") - return f"t{zlib.crc32(value)}"[:6] - - def _get_select_expressions( - self, - ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]: - select_expressions: t.List[ - t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select] - ] = [] - main_select_ctes: t.List[exp.CTE] = [] - for cte in self.expression.ctes: - cache_storage_level = cte.args.get("cache_storage_level") - if cache_storage_level: - select_expression = cte.this.copy() - select_expression.set("with", exp.With(expressions=copy(main_select_ctes))) - select_expression.set("cte_alias_name", cte.alias_or_name) - select_expression.set("cache_storage_level", cache_storage_level) - select_expressions.append((exp.Cache, select_expression)) - else: - main_select_ctes.append(cte) - main_select = self.expression.copy() - if main_select_ctes: - main_select.set("with", exp.With(expressions=main_select_ctes)) - expression_select_pair = (type(self.output_expression_container), main_select) - select_expressions.append(expression_select_pair) # type: ignore - return select_expressions - - def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]: - from sqlglot.dataframe.sql.session import SparkSession - - dialect = Dialect.get_or_raise(dialect or SparkSession().dialect) - - df = self._resolve_pending_hints() - select_expressions = df._get_select_expressions() - output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] - replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {} - - for expression_type, select_expression in select_expressions: - select_expression = select_expression.transform( - replace_id_value, replacement_mapping - ).assert_is(exp.Select) - if optimize: - select_expression = t.cast( - exp.Select, self.spark._optimize(select_expression, dialect=dialect) - ) - - select_expression = df._replace_cte_names_with_hashes(select_expression) - - expression: t.Union[exp.Select, exp.Cache, exp.Drop] - if expression_type == exp.Cache: - cache_table_name = df._create_hash_from_expression(select_expression) - cache_table = exp.to_table(cache_table_name) - original_alias_name = select_expression.args["cte_alias_name"] - - replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore - cache_table_name - ) - sqlglot.schema.add_table( - cache_table_name, - { - expression.alias_or_name: expression.type.sql(dialect=dialect) - for expression in select_expression.expressions - }, - dialect=dialect, - ) - - cache_storage_level = select_expression.args["cache_storage_level"] - options = [ - exp.Literal.string("storageLevel"), - exp.Literal.string(cache_storage_level), - ] - expression = exp.Cache( - this=cache_table, expression=select_expression, lazy=True, options=options - ) - - # We will drop the "view" if it exists before running the cache table - output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW")) - elif expression_type == exp.Create: - expression = df.output_expression_container.copy() - expression.set("expression", select_expression) - elif expression_type == exp.Insert: - expression = df.output_expression_container.copy() - select_without_ctes = select_expression.copy() - select_without_ctes.set("with", None) - expression.set("expression", select_without_ctes) - - if select_expression.ctes: - expression.set("with", exp.With(expressions=select_expression.ctes)) - elif expression_type == exp.Select: - expression = select_expression - else: - raise ValueError(f"Invalid expression type: {expression_type}") - - output_expressions.append(expression) - - return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions] - - def copy(self, **kwargs) -> DataFrame: - return DataFrame(**object_to_dict(self, **kwargs)) - - @operation(Operation.SELECT) - def select(self, *cols, **kwargs) -> DataFrame: - cols = self._ensure_and_normalize_cols(cols) - kwargs["append"] = kwargs.get("append", False) - if self.expression.args.get("joins"): - ambiguous_cols = [ - col - for col in cols - if isinstance(col.column_expression, exp.Column) and not col.column_expression.table - ] - if ambiguous_cols: - join_table_identifiers = [ - x.this for x in get_tables_from_expression_with_join(self.expression) - ] - cte_names_in_join = [x.this for x in join_table_identifiers] - # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right - # and therefore we allow multiple columns with the same name in the result. This matches the behavior - # of Spark. - resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols} - for ambiguous_col in ambiguous_cols: - ctes_with_column = [ - cte - for cte in self.expression.ctes - if cte.alias_or_name in cte_names_in_join - and ambiguous_col.alias_or_name in cte.this.named_selects - ] - # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise, - # use the same CTE we used before - cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1) - if cte: - resolved_column_position[ambiguous_col] += 1 - else: - cte = ctes_with_column[resolved_column_position[ambiguous_col]] - ambiguous_col.expression.set("table", cte.alias_or_name) - return self.copy( - expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs - ) - - @operation(Operation.NO_OP) - def alias(self, name: str, **kwargs) -> DataFrame: - new_sequence_id = self.spark._random_sequence_id - df = self.copy() - for join_hint in df.pending_join_hints: - for expression in join_hint.expressions: - if expression.alias_or_name == self.sequence_id: - expression.set("this", Column.ensure_col(new_sequence_id).expression) - df.spark._add_alias_to_mapping(name, new_sequence_id) - return df._convert_leaf_to_cte(sequence_id=new_sequence_id) - - @operation(Operation.WHERE) - def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame: - col = self._ensure_and_normalize_col(column) - return self.copy(expression=self.expression.where(col.expression)) - - filter = where - - @operation(Operation.GROUP_BY) - def groupBy(self, *cols, **kwargs) -> GroupedData: - columns = self._ensure_and_normalize_cols(cols) - return GroupedData(self, columns, self.last_op) - - @operation(Operation.SELECT) - def agg(self, *exprs, **kwargs) -> DataFrame: - cols = self._ensure_and_normalize_cols(exprs) - return self.groupBy().agg(*cols) - - @operation(Operation.FROM) - def join( - self, - other_df: DataFrame, - on: t.Union[str, t.List[str], Column, t.List[Column]], - how: str = "inner", - **kwargs, - ) -> DataFrame: - other_df = other_df._convert_leaf_to_cte() - join_columns = self._ensure_list_of_columns(on) - # We will determine actual "join on" expression later so we don't provide it at first - join_expression = self.expression.join( - other_df.latest_cte_name, join_type=how.replace("_", " ") - ) - join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes) - self_columns = self._get_outer_select_columns(join_expression) - other_columns = self._get_outer_select_columns(other_df) - # Determines the join clause and select columns to be used passed on what type of columns were provided for - # the join. The columns returned changes based on how the on expression is provided. - if isinstance(join_columns[0].expression, exp.Column): - """ - Unique characteristics of join on column names only: - * The column names are put at the front of the select list - * The column names are deduplicated across the entire select list and only the column names (other dups are allowed) - """ - table_names = [ - table.alias_or_name - for table in get_tables_from_expression_with_join(join_expression) - ] - potential_ctes = [ - cte - for cte in join_expression.ctes - if cte.alias_or_name in table_names - and cte.alias_or_name != other_df.latest_cte_name - ] - # Determine the table to reference for the left side of the join by checking each of the left side - # tables and see if they have the column being referenced. - join_column_pairs = [] - for join_column in join_columns: - num_matching_ctes = 0 - for cte in potential_ctes: - if join_column.alias_or_name in cte.this.named_selects: - left_column = join_column.copy().set_table_name(cte.alias_or_name) - right_column = join_column.copy().set_table_name(other_df.latest_cte_name) - join_column_pairs.append((left_column, right_column)) - num_matching_ctes += 1 - if num_matching_ctes > 1: - raise ValueError( - f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name." - ) - elif num_matching_ctes == 0: - raise ValueError( - f"Column {join_column.alias_or_name} does not exist in any of the tables." - ) - join_clause = functools.reduce( - lambda x, y: x & y, - [left_column == right_column for left_column, right_column in join_column_pairs], - ) - join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs] - # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list - select_column_names = [ - ( - column.alias_or_name - if not isinstance(column.expression.this, exp.Star) - else column.sql() - ) - for column in self_columns + other_columns - ] - select_column_names = [ - column_name - for column_name in select_column_names - if column_name not in join_column_names - ] - select_column_names = join_column_names + select_column_names - else: - """ - Unique characteristics of join on expressions: - * There is no deduplication of the results. - * The left join dataframe columns go first and right come after. No sort preference is given to join columns - """ - join_columns = self._ensure_and_normalize_cols(join_columns, join_expression) - if len(join_columns) > 1: - join_columns = [functools.reduce(lambda x, y: x & y, join_columns)] - join_clause = join_columns[0] - select_column_names = [column.alias_or_name for column in self_columns + other_columns] - - # Update the on expression with the actual join clause to replace the dummy one from before - join_expression.args["joins"][-1].set("on", join_clause.expression) - new_df = self.copy(expression=join_expression) - new_df.pending_join_hints.extend(self.pending_join_hints) - new_df.pending_hints.extend(other_df.pending_hints) - new_df = new_df.select.__wrapped__(new_df, *select_column_names) - return new_df - - @operation(Operation.ORDER_BY) - def orderBy( - self, - *cols: t.Union[str, Column], - ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, - ) -> DataFrame: - """ - This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark - has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this - is unlikely to come up. - """ - columns = self._ensure_and_normalize_cols(cols) - pre_ordered_col_indexes = [ - i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered) - ] - if ascending is None: - ascending = [True] * len(columns) - elif not isinstance(ascending, list): - ascending = [ascending] * len(columns) - ascending = [bool(x) for i, x in enumerate(ascending)] - assert len(columns) == len( - ascending - ), "The length of items in ascending must equal the number of columns provided" - col_and_ascending = list(zip(columns, ascending)) - order_by_columns = [ - ( - exp.Ordered(this=col.expression, desc=not asc) - if i not in pre_ordered_col_indexes - else columns[i].column_expression - ) - for i, (col, asc) in enumerate(col_and_ascending) - ] - return self.copy(expression=self.expression.order_by(*order_by_columns)) - - sort = orderBy - - @operation(Operation.FROM) - def union(self, other: DataFrame) -> DataFrame: - return self._set_operation(exp.Union, other, False) - - unionAll = union - - @operation(Operation.FROM) - def unionByName(self, other: DataFrame, allowMissingColumns: bool = False): - l_columns = self.columns - r_columns = other.columns - if not allowMissingColumns: - l_expressions = l_columns - r_expressions = l_columns - else: - l_expressions = [] - r_expressions = [] - r_columns_unused = copy(r_columns) - for l_column in l_columns: - l_expressions.append(l_column) - if l_column in r_columns: - r_expressions.append(l_column) - r_columns_unused.remove(l_column) - else: - r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False)) - for r_column in r_columns_unused: - l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False)) - r_expressions.append(r_column) - r_df = ( - other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) - ) - l_df = self.copy() - if allowMissingColumns: - l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions)) - return l_df._set_operation(exp.Union, r_df, False) - - @operation(Operation.FROM) - def intersect(self, other: DataFrame) -> DataFrame: - return self._set_operation(exp.Intersect, other, True) - - @operation(Operation.FROM) - def intersectAll(self, other: DataFrame) -> DataFrame: - return self._set_operation(exp.Intersect, other, False) - - @operation(Operation.FROM) - def exceptAll(self, other: DataFrame) -> DataFrame: - return self._set_operation(exp.Except, other, False) - - @operation(Operation.SELECT) - def distinct(self) -> DataFrame: - return self.copy(expression=self.expression.distinct()) - - @operation(Operation.SELECT) - def dropDuplicates(self, subset: t.Optional[t.List[str]] = None): - if not subset: - return self.distinct() - column_names = ensure_list(subset) - window = Window.partitionBy(*column_names).orderBy(*column_names) - return ( - self.copy() - .withColumn("row_num", F.row_number().over(window)) - .where(F.col("row_num") == F.lit(1)) - .drop("row_num") - ) - - @operation(Operation.FROM) - def dropna( - self, - how: str = "any", - thresh: t.Optional[int] = None, - subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, - ) -> DataFrame: - minimum_non_null = thresh or 0 # will be determined later if thresh is null - new_df = self.copy() - all_columns = self._get_outer_select_columns(new_df.expression) - if subset: - null_check_columns = self._ensure_and_normalize_cols(subset) - else: - null_check_columns = all_columns - if thresh is None: - minimum_num_nulls = 1 if how == "any" else len(null_check_columns) - else: - minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1 - if minimum_num_nulls > len(null_check_columns): - raise RuntimeError( - f"The minimum num nulls for dropna must be less than or equal to the number of columns. " - f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}" - ) - if_null_checks = [ - F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns - ] - nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks) - num_nulls = nulls_added_together.alias("num_nulls") - new_df = new_df.select(num_nulls, append=True) - filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls)) - final_df = filtered_df.select(*all_columns) - return final_df - - @operation(Operation.FROM) - def fillna( - self, - value: t.Union[ColumnLiterals], - subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, - ) -> DataFrame: - """ - Functionality Difference: If you provide a value to replace a null and that type conflicts - with the type of the column then PySpark will just ignore your replacement. - This will try to cast them to be the same in some cases. So they won't always match. - Best to not mix types so make sure replacement is the same type as the column - - Possibility for improvement: Use `typeof` function to get the type of the column - and check if it matches the type of the value provided. If not then make it null. - """ - from sqlglot.dataframe.sql.functions import lit - - values = None - columns = None - new_df = self.copy() - all_columns = self._get_outer_select_columns(new_df.expression) - all_column_mapping = {column.alias_or_name: column for column in all_columns} - if isinstance(value, dict): - values = list(value.values()) - columns = self._ensure_and_normalize_cols(list(value)) - if not columns: - columns = self._ensure_and_normalize_cols(subset) if subset else all_columns - if not values: - values = [value] * len(columns) - value_columns = [lit(value) for value in values] - - null_replacement_mapping = { - column.alias_or_name: ( - F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name) - ) - for column, value in zip(columns, value_columns) - } - null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} - null_replacement_columns = [ - null_replacement_mapping[column.alias_or_name] for column in all_columns - ] - new_df = new_df.select(*null_replacement_columns) - return new_df - - @operation(Operation.FROM) - def replace( - self, - to_replace: t.Union[bool, int, float, str, t.List, t.Dict], - value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, - subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None, - ) -> DataFrame: - from sqlglot.dataframe.sql.functions import lit - - old_values = None - new_df = self.copy() - all_columns = self._get_outer_select_columns(new_df.expression) - all_column_mapping = {column.alias_or_name: column for column in all_columns} - - columns = self._ensure_and_normalize_cols(subset) if subset else all_columns - if isinstance(to_replace, dict): - old_values = list(to_replace) - new_values = list(to_replace.values()) - elif not old_values and isinstance(to_replace, list): - assert isinstance(value, list), "value must be a list since the replacements are a list" - assert len(to_replace) == len( - value - ), "the replacements and values must be the same length" - old_values = to_replace - new_values = value - else: - old_values = [to_replace] * len(columns) - new_values = [value] * len(columns) - old_values = [lit(value) for value in old_values] - new_values = [lit(value) for value in new_values] - - replacement_mapping = {} - for column in columns: - expression = Column(None) - for i, (old_value, new_value) in enumerate(zip(old_values, new_values)): - if i == 0: - expression = F.when(column == old_value, new_value) - else: - expression = expression.when(column == old_value, new_value) # type: ignore - replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias( - column.expression.alias_or_name - ) - - replacement_mapping = {**all_column_mapping, **replacement_mapping} - replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns] - new_df = new_df.select(*replacement_columns) - return new_df - - @operation(Operation.SELECT) - def withColumn(self, colName: str, col: Column) -> DataFrame: - col = self._ensure_and_normalize_col(col) - existing_col_names = self.expression.named_selects - existing_col_index = ( - existing_col_names.index(colName) if colName in existing_col_names else None - ) - if existing_col_index: - expression = self.expression.copy() - expression.expressions[existing_col_index] = col.expression - return self.copy(expression=expression) - return self.copy().select(col.alias(colName), append=True) - - @operation(Operation.SELECT) - def withColumnRenamed(self, existing: str, new: str): - expression = self.expression.copy() - existing_columns = [ - expression - for expression in expression.expressions - if expression.alias_or_name == existing - ] - if not existing_columns: - raise ValueError("Tried to rename a column that doesn't exist") - for existing_column in existing_columns: - if isinstance(existing_column, exp.Column): - existing_column.replace(exp.alias_(existing_column, new)) - else: - existing_column.set("alias", exp.to_identifier(new)) - return self.copy(expression=expression) - - @operation(Operation.SELECT) - def drop(self, *cols: t.Union[str, Column]) -> DataFrame: - all_columns = self._get_outer_select_columns(self.expression) - drop_cols = self._ensure_and_normalize_cols(cols) - new_columns = [ - col - for col in all_columns - if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols] - ] - return self.copy().select(*new_columns, append=False) - - @operation(Operation.LIMIT) - def limit(self, num: int) -> DataFrame: - return self.copy(expression=self.expression.limit(num)) - - @operation(Operation.NO_OP) - def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame: - parameter_list = ensure_list(parameters) - parameter_columns = ( - self._ensure_list_of_columns(parameter_list) - if parameters - else Column.ensure_cols([self.sequence_id]) - ) - return self._hint(name, parameter_columns) - - @operation(Operation.NO_OP) - def repartition( - self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName - ) -> DataFrame: - num_partition_cols = self._ensure_list_of_columns(numPartitions) - columns = self._ensure_and_normalize_cols(cols) - args = num_partition_cols + columns - return self._hint("repartition", args) - - @operation(Operation.NO_OP) - def coalesce(self, numPartitions: int) -> DataFrame: - num_partitions = Column.ensure_cols([numPartitions]) - return self._hint("coalesce", num_partitions) - - @operation(Operation.NO_OP) - def cache(self) -> DataFrame: - return self._cache(storage_level="MEMORY_AND_DISK") - - @operation(Operation.NO_OP) - def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame: - """ - Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html - """ - return self._cache(storageLevel) - - -class DataFrameNaFunctions: - def __init__(self, df: DataFrame): - self.df = df - - def drop( - self, - how: str = "any", - thresh: t.Optional[int] = None, - subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, - ) -> DataFrame: - return self.df.dropna(how=how, thresh=thresh, subset=subset) - - def fill( - self, - value: t.Union[int, bool, float, str, t.Dict[str, t.Any]], - subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, - ) -> DataFrame: - return self.df.fillna(value=value, subset=subset) - - def replace( - self, - to_replace: t.Union[bool, int, float, str, t.List, t.Dict], - value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, - subset: t.Optional[t.Union[str, t.List[str]]] = None, - ) -> DataFrame: - return self.df.replace(to_replace=to_replace, value=value, subset=subset) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py deleted file mode 100644 index 81b7d61..0000000 --- a/sqlglot/dataframe/sql/functions.py +++ /dev/null @@ -1,1270 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp as expression -from sqlglot.dataframe.sql.column import Column -from sqlglot.helper import ensure_list, flatten as _flatten - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import ColumnOrLiteral, ColumnOrName - from sqlglot.dataframe.sql.dataframe import DataFrame - - -def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column: - return Column(column_name) - - -def lit(value: t.Optional[t.Any] = None) -> Column: - if isinstance(value, str): - return Column(expression.Literal.string(str(value))) - return Column(value) - - -def greatest(*cols: ColumnOrName) -> Column: - if len(cols) > 1: - return Column.invoke_expression_over_column( - cols[0], expression.Greatest, expressions=cols[1:] - ) - return Column.invoke_expression_over_column(cols[0], expression.Greatest) - - -def least(*cols: ColumnOrName) -> Column: - if len(cols) > 1: - return Column.invoke_expression_over_column(cols[0], expression.Least, expressions=cols[1:]) - return Column.invoke_expression_over_column(cols[0], expression.Least) - - -def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: - columns = [Column.ensure_col(x) for x in [col] + list(cols)] - return Column( - expression.Count(this=expression.Distinct(expressions=[x.expression for x in columns])) - ) - - -def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: - return count_distinct(col, *cols) - - -def when(condition: Column, value: t.Any) -> Column: - true_value = value if isinstance(value, Column) else lit(value) - return Column( - expression.Case( - ifs=[expression.If(this=condition.column_expression, true=true_value.column_expression)] - ) - ) - - -def asc(col: ColumnOrName) -> Column: - return Column.ensure_col(col).asc() - - -def desc(col: ColumnOrName): - return Column.ensure_col(col).desc() - - -def broadcast(df: DataFrame) -> DataFrame: - return df.hint("broadcast") - - -def sqrt(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Sqrt) - - -def abs(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Abs) - - -def max(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Max) - - -def min(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Min) - - -def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.ArgMax, expression=ord) - - -def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.ArgMin, expression=ord) - - -def count(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Count) - - -def sum(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Sum) - - -def avg(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Avg) - - -def mean(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MEAN") - - -def sumDistinct(col: ColumnOrName) -> Column: - return sum_distinct(col) - - -def sum_distinct(col: ColumnOrName) -> Column: - raise NotImplementedError("Sum distinct is not currently implemented") - - -def product(col: ColumnOrName) -> Column: - raise NotImplementedError("Product is not currently implemented") - - -def acos(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ACOS") - - -def acosh(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ACOSH") - - -def asin(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ASIN") - - -def asinh(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ASINH") - - -def atan(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ATAN") - - -def atan2(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: - return Column.invoke_anonymous_function(col1, "ATAN2", col2) - - -def atanh(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ATANH") - - -def cbrt(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Cbrt) - - -def ceil(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Ceil) - - -def cos(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "COS") - - -def cosh(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "COSH") - - -def cot(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "COT") - - -def csc(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "CSC") - - -def exp(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Exp) - - -def expm1(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "EXPM1") - - -def floor(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Floor) - - -def log10(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(lit(10), expression.Log, expression=col) - - -def log1p(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "LOG1P") - - -def log2(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(lit(2), expression.Log, expression=col) - - -def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column: - if arg2 is None: - return Column.invoke_expression_over_column(arg1, expression.Ln) - return Column.invoke_expression_over_column(arg1, expression.Log, expression=arg2) - - -def rint(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "RINT") - - -def sec(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SEC") - - -def signum(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Sign) - - -def sin(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SIN") - - -def sinh(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SINH") - - -def tan(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "TAN") - - -def tanh(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "TANH") - - -def toDegrees(col: ColumnOrName) -> Column: - return degrees(col) - - -def degrees(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "DEGREES") - - -def toRadians(col: ColumnOrName) -> Column: - return radians(col) - - -def radians(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "RADIANS") - - -def bitwiseNOT(col: ColumnOrName) -> Column: - return bitwise_not(col) - - -def bitwise_not(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.BitwiseNot) - - -def asc_nulls_first(col: ColumnOrName) -> Column: - return Column.ensure_col(col).asc_nulls_first() - - -def asc_nulls_last(col: ColumnOrName) -> Column: - return Column.ensure_col(col).asc_nulls_last() - - -def desc_nulls_first(col: ColumnOrName) -> Column: - return Column.ensure_col(col).desc_nulls_first() - - -def desc_nulls_last(col: ColumnOrName) -> Column: - return Column.ensure_col(col).desc_nulls_last() - - -def stddev(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Stddev) - - -def stddev_samp(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.StddevSamp) - - -def stddev_pop(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.StddevPop) - - -def variance(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Variance) - - -def var_samp(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Variance) - - -def var_pop(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.VariancePop) - - -def skewness(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SKEWNESS") - - -def kurtosis(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "KURTOSIS") - - -def collect_list(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.ArrayAgg) - - -def collect_set(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.ArrayUniqueAgg) - - -def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: - return Column.invoke_anonymous_function(col1, "HYPOT", col2) - - -def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: - return Column.invoke_expression_over_column(col1, expression.Pow, expression=col2) - - -def row_number() -> Column: - return Column(expression.Anonymous(this="ROW_NUMBER")) - - -def dense_rank() -> Column: - return Column(expression.Anonymous(this="DENSE_RANK")) - - -def rank() -> Column: - return Column(expression.Anonymous(this="RANK")) - - -def cume_dist() -> Column: - return Column(expression.Anonymous(this="CUME_DIST")) - - -def percent_rank() -> Column: - return Column(expression.Anonymous(this="PERCENT_RANK")) - - -def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: - return approx_count_distinct(col, rsd) - - -def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: - if rsd is None: - return Column.invoke_expression_over_column(col, expression.ApproxDistinct) - return Column.invoke_expression_over_column(col, expression.ApproxDistinct, accuracy=rsd) - - -def coalesce(*cols: ColumnOrName) -> Column: - if len(cols) > 1: - return Column.invoke_expression_over_column( - cols[0], expression.Coalesce, expressions=cols[1:] - ) - return Column.invoke_expression_over_column(cols[0], expression.Coalesce) - - -def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col1, expression.Corr, expression=col2) - - -def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col1, expression.CovarPop, expression=col2) - - -def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col1, expression.CovarSamp, expression=col2) - - -def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: - this = Column.invoke_expression_over_column(col, expression.First) - if ignorenulls: - return Column.invoke_expression_over_column(this, expression.IgnoreNulls) - return this - - -def grouping_id(*cols: ColumnOrName) -> Column: - if not cols: - return Column.invoke_anonymous_function(None, "GROUPING_ID") - if len(cols) == 1: - return Column.invoke_anonymous_function(cols[0], "GROUPING_ID") - return Column.invoke_anonymous_function(cols[0], "GROUPING_ID", *cols[1:]) - - -def input_file_name() -> Column: - return Column.invoke_anonymous_function(None, "INPUT_FILE_NAME") - - -def isnan(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.IsNan) - - -def isnull(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ISNULL") - - -def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: - this = Column.invoke_expression_over_column(col, expression.Last) - if ignorenulls: - return Column.invoke_expression_over_column(this, expression.IgnoreNulls) - return this - - -def monotonically_increasing_id() -> Column: - return Column.invoke_anonymous_function(None, "MONOTONICALLY_INCREASING_ID") - - -def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "NANVL", col2) - - -def percentile_approx( - col: ColumnOrName, - percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]], - accuracy: t.Optional[t.Union[ColumnOrLiteral, int]] = None, -) -> Column: - if accuracy: - return Column.invoke_expression_over_column( - col, expression.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy - ) - return Column.invoke_expression_over_column( - col, expression.ApproxQuantile, quantile=lit(percentage) - ) - - -def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column: - return Column.invoke_expression_over_column(seed, expression.Rand) - - -def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column: - return Column.invoke_expression_over_column(seed, expression.Randn) - - -def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: - if scale is not None: - return Column.invoke_expression_over_column(col, expression.Round, decimals=scale) - return Column.invoke_expression_over_column(col, expression.Round) - - -def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: - if scale is not None: - return Column.invoke_anonymous_function(col, "BROUND", scale) - return Column.invoke_anonymous_function(col, "BROUND") - - -def shiftleft(col: ColumnOrName, numBits: int) -> Column: - return Column.invoke_expression_over_column( - col, expression.BitwiseLeftShift, expression=numBits - ) - - -def shiftLeft(col: ColumnOrName, numBits: int) -> Column: - return shiftleft(col, numBits) - - -def shiftright(col: ColumnOrName, numBits: int) -> Column: - return Column.invoke_expression_over_column( - col, expression.BitwiseRightShift, expression=numBits - ) - - -def shiftRight(col: ColumnOrName, numBits: int) -> Column: - return shiftright(col, numBits) - - -def shiftrightunsigned(col: ColumnOrName, numBits: int) -> Column: - return Column.invoke_anonymous_function(col, "SHIFTRIGHTUNSIGNED", numBits) - - -def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column: - return shiftrightunsigned(col, numBits) - - -def expr(str: str) -> Column: - return Column(str) - - -def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column: - columns = ensure_list(col) + list(cols) - return Column.invoke_expression_over_column(None, expression.Struct, expressions=columns) - - -def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column: - return Column.invoke_anonymous_function(col, "CONV", fromBase, toBase) - - -def factorial(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "FACTORIAL") - - -def lag( - col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None -) -> Column: - return Column.invoke_expression_over_column( - col, expression.Lag, offset=None if offset == 1 else offset, default=default - ) - - -def lead( - col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None -) -> Column: - return Column.invoke_expression_over_column( - col, expression.Lead, offset=None if offset == 1 else offset, default=default - ) - - -def nth_value( - col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None -) -> Column: - this = Column.invoke_expression_over_column( - col, expression.NthValue, offset=None if offset == 1 else offset - ) - if ignoreNulls is not None: - return Column.invoke_expression_over_column(this, expression.IgnoreNulls) - return this - - -def ntile(n: int) -> Column: - return Column.invoke_anonymous_function(None, "NTILE", n) - - -def current_date() -> Column: - return Column.invoke_expression_over_column(None, expression.CurrentDate) - - -def current_timestamp() -> Column: - return Column.invoke_expression_over_column(None, expression.CurrentTimestamp) - - -def date_format(col: ColumnOrName, format: str) -> Column: - return Column.invoke_expression_over_column(col, expression.TimeToStr, format=lit(format)) - - -def year(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Year) - - -def quarter(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Quarter) - - -def month(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Month) - - -def dayofweek(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.DayOfWeek) - - -def dayofmonth(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.DayOfMonth) - - -def dayofyear(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.DayOfYear) - - -def hour(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "HOUR") - - -def minute(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MINUTE") - - -def second(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SECOND") - - -def weekofyear(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.WeekOfYear) - - -def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(year, "MAKE_DATE", month, day) - - -def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column( - col, expression.DateAdd, expression=days, unit=expression.Var(this="DAY") - ) - - -def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column( - col, expression.DateSub, expression=days, unit=expression.Var(this="DAY") - ) - - -def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(end, expression.DateDiff, expression=start) - - -def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column(start, expression.AddMonths, expression=months) - - -def months_between( - date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None -) -> Column: - if roundOff is None: - return Column.invoke_expression_over_column( - date1, expression.MonthsBetween, expression=date2 - ) - - return Column.invoke_expression_over_column( - date1, expression.MonthsBetween, expression=date2, roundoff=roundOff - ) - - -def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: - if format is not None: - return Column.invoke_expression_over_column( - col, expression.TsOrDsToDate, format=lit(format) - ) - return Column.invoke_expression_over_column(col, expression.TsOrDsToDate) - - -def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: - if format is not None: - return Column.invoke_expression_over_column(col, expression.StrToTime, format=lit(format)) - - return Column.ensure_col(col).cast("timestamp") - - -def trunc(col: ColumnOrName, format: str) -> Column: - return Column.invoke_expression_over_column(col, expression.DateTrunc, unit=lit(format)) - - -def date_trunc(format: str, timestamp: ColumnOrName) -> Column: - return Column.invoke_expression_over_column( - timestamp, expression.TimestampTrunc, unit=lit(format) - ) - - -def next_day(col: ColumnOrName, dayOfWeek: str) -> Column: - return Column.invoke_anonymous_function(col, "NEXT_DAY", lit(dayOfWeek)) - - -def last_day(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.LastDay) - - -def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column: - if format is not None: - return Column.invoke_expression_over_column(col, expression.UnixToStr, format=lit(format)) - return Column.invoke_expression_over_column(col, expression.UnixToStr) - - -def unix_timestamp( - timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None -) -> Column: - if format is not None: - return Column.invoke_expression_over_column( - timestamp, expression.StrToUnix, format=lit(format) - ) - return Column.invoke_expression_over_column(timestamp, expression.StrToUnix) - - -def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: - tz_column = tz if isinstance(tz, Column) else lit(tz) - return Column.invoke_expression_over_column(timestamp, expression.AtTimeZone, zone=tz_column) - - -def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: - tz_column = tz if isinstance(tz, Column) else lit(tz) - return Column.invoke_expression_over_column(timestamp, expression.FromTimeZone, zone=tz_column) - - -def timestamp_seconds(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "TIMESTAMP_SECONDS") - - -def window( - timeColumn: ColumnOrName, - windowDuration: str, - slideDuration: t.Optional[str] = None, - startTime: t.Optional[str] = None, -) -> Column: - if slideDuration is not None and startTime is not None: - return Column.invoke_anonymous_function( - timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime) - ) - if slideDuration is not None: - return Column.invoke_anonymous_function( - timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration) - ) - if startTime is not None: - return Column.invoke_anonymous_function( - timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime) - ) - return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration)) - - -def session_window(timeColumn: ColumnOrName, gapDuration: ColumnOrName) -> Column: - gap_duration_column = gapDuration if isinstance(gapDuration, Column) else lit(gapDuration) - return Column.invoke_anonymous_function(timeColumn, "SESSION_WINDOW", gap_duration_column) - - -def crc32(col: ColumnOrName) -> Column: - column = col if isinstance(col, Column) else lit(col) - return Column.invoke_anonymous_function(column, "CRC32") - - -def md5(col: ColumnOrName) -> Column: - column = col if isinstance(col, Column) else lit(col) - return Column.invoke_expression_over_column(column, expression.MD5) - - -def sha1(col: ColumnOrName) -> Column: - column = col if isinstance(col, Column) else lit(col) - return Column.invoke_expression_over_column(column, expression.SHA) - - -def sha2(col: ColumnOrName, numBits: int) -> Column: - column = col if isinstance(col, Column) else lit(col) - return Column.invoke_expression_over_column(column, expression.SHA2, length=lit(numBits)) - - -def hash(*cols: ColumnOrName) -> Column: - args = cols[1:] if len(cols) > 1 else [] - return Column.invoke_anonymous_function(cols[0], "HASH", *args) - - -def xxhash64(*cols: ColumnOrName) -> Column: - args = cols[1:] if len(cols) > 1 else [] - return Column.invoke_anonymous_function(cols[0], "XXHASH64", *args) - - -def assert_true(col: ColumnOrName, errorMsg: t.Optional[ColumnOrName] = None) -> Column: - if errorMsg is not None: - error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg) - return Column.invoke_anonymous_function(col, "ASSERT_TRUE", error_msg_col) - return Column.invoke_anonymous_function(col, "ASSERT_TRUE") - - -def raise_error(errorMsg: ColumnOrName) -> Column: - error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg) - return Column.invoke_anonymous_function(error_msg_col, "RAISE_ERROR") - - -def upper(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Upper) - - -def lower(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Lower) - - -def ascii(col: ColumnOrLiteral) -> Column: - return Column.invoke_anonymous_function(col, "ASCII") - - -def base64(col: ColumnOrLiteral) -> Column: - return Column.invoke_expression_over_column(col, expression.ToBase64) - - -def unbase64(col: ColumnOrLiteral) -> Column: - return Column.invoke_expression_over_column(col, expression.FromBase64) - - -def ltrim(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "LTRIM") - - -def rtrim(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "RTRIM") - - -def trim(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Trim) - - -def concat_ws(sep: str, *cols: ColumnOrName) -> Column: - return Column.invoke_expression_over_column( - None, expression.ConcatWs, expressions=[lit(sep)] + list(cols) - ) - - -def decode(col: ColumnOrName, charset: str) -> Column: - return Column.invoke_expression_over_column( - col, expression.Decode, charset=expression.Literal.string(charset) - ) - - -def encode(col: ColumnOrName, charset: str) -> Column: - return Column.invoke_expression_over_column( - col, expression.Encode, charset=expression.Literal.string(charset) - ) - - -def format_number(col: ColumnOrName, d: int) -> Column: - return Column.invoke_anonymous_function(col, "FORMAT_NUMBER", lit(d)) - - -def format_string(format: str, *cols: ColumnOrName) -> Column: - format_col = lit(format) - columns = [Column.ensure_col(x) for x in cols] - return Column.invoke_anonymous_function(format_col, "FORMAT_STRING", *columns) - - -def instr(col: ColumnOrName, substr: str) -> Column: - return Column.invoke_anonymous_function(col, "INSTR", lit(substr)) - - -def overlay( - src: ColumnOrName, - replace: ColumnOrName, - pos: t.Union[ColumnOrName, int], - len: t.Optional[t.Union[ColumnOrName, int]] = None, -) -> Column: - if len is not None: - return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos, len) - return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos) - - -def sentences( - string: ColumnOrName, - language: t.Optional[ColumnOrName] = None, - country: t.Optional[ColumnOrName] = None, -) -> Column: - if language is not None and country is not None: - return Column.invoke_anonymous_function(string, "SENTENCES", language, country) - if language is not None: - return Column.invoke_anonymous_function(string, "SENTENCES", language) - if country is not None: - return Column.invoke_anonymous_function(string, "SENTENCES", lit("en"), country) - return Column.invoke_anonymous_function(string, "SENTENCES") - - -def substring(str: ColumnOrName, pos: int, len: int) -> Column: - return Column.ensure_col(str).substr(pos, len) - - -def substring_index(str: ColumnOrName, delim: str, count: int) -> Column: - return Column.invoke_anonymous_function(str, "SUBSTRING_INDEX", lit(delim), lit(count)) - - -def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(left, expression.Levenshtein, expression=right) - - -def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column: - substr_col = lit(substr) - if pos is not None: - return Column.invoke_expression_over_column( - str, expression.StrPosition, substr=substr_col, position=pos - ) - return Column.invoke_expression_over_column(str, expression.StrPosition, substr=substr_col) - - -def lpad(col: ColumnOrName, len: int, pad: str) -> Column: - return Column.invoke_anonymous_function(col, "LPAD", lit(len), lit(pad)) - - -def rpad(col: ColumnOrName, len: int, pad: str) -> Column: - return Column.invoke_anonymous_function(col, "RPAD", lit(len), lit(pad)) - - -def repeat(col: ColumnOrName, n: int) -> Column: - return Column.invoke_expression_over_column(col, expression.Repeat, times=lit(n)) - - -def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column: - if limit is not None: - return Column.invoke_expression_over_column( - str, expression.RegexpSplit, expression=lit(pattern).expression, limit=limit - ) - return Column.invoke_expression_over_column( - str, expression.RegexpSplit, expression=lit(pattern) - ) - - -def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column: - return Column.invoke_expression_over_column( - str, - expression.RegexpExtract, - expression=lit(pattern), - group=idx, - ) - - -def regexp_replace( - str: ColumnOrName, pattern: str, replacement: str, position: t.Optional[int] = None -) -> Column: - return Column.invoke_expression_over_column( - str, - expression.RegexpReplace, - expression=lit(pattern), - replacement=lit(replacement), - position=position, - ) - - -def initcap(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Initcap) - - -def soundex(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SOUNDEX") - - -def bin(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "BIN") - - -def hex(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Hex) - - -def unhex(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Unhex) - - -def length(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Length) - - -def octet_length(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "OCTET_LENGTH") - - -def bit_length(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "BIT_LENGTH") - - -def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column: - return Column.invoke_anonymous_function(srcCol, "TRANSLATE", lit(matching), lit(replace)) - - -def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: - columns = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols - return Column.invoke_expression_over_column(None, expression.Array, expressions=columns) - - -def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: - cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore - return Column.invoke_expression_over_column( - None, - expression.VarMap, - keys=array(*cols[::2]).expression, - values=array(*cols[1::2]).expression, - ) - - -def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(None, expression.Map, keys=col1, values=col2) - - -def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column: - value_col = value if isinstance(value, Column) else lit(value) - return Column.invoke_expression_over_column( - col, expression.ArrayContains, expression=value_col.expression - ) - - -def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2)) - - -def slice( - x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int] -) -> Column: - start_col = start if isinstance(start, Column) else lit(start) - length_col = length if isinstance(length, Column) else lit(length) - return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col) - - -def array_join( - col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None -) -> Column: - if null_replacement is not None: - return Column.invoke_expression_over_column( - col, expression.ArrayToString, expression=lit(delimiter), null=lit(null_replacement) - ) - return Column.invoke_expression_over_column( - col, expression.ArrayToString, expression=lit(delimiter) - ) - - -def concat(*cols: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(None, expression.Concat, expressions=cols) - - -def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column: - value_col = value if isinstance(value, Column) else lit(value) - return Column.invoke_anonymous_function(col, "ARRAY_POSITION", value_col) - - -def element_at(col: ColumnOrName, value: ColumnOrLiteral) -> Column: - value_col = value if isinstance(value, Column) else lit(value) - return Column.invoke_anonymous_function(col, "ELEMENT_AT", value_col) - - -def array_remove(col: ColumnOrName, value: ColumnOrLiteral) -> Column: - value_col = value if isinstance(value, Column) else lit(value) - return Column.invoke_anonymous_function(col, "ARRAY_REMOVE", value_col) - - -def array_distinct(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ARRAY_DISTINCT") - - -def array_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "ARRAY_INTERSECT", Column.ensure_col(col2)) - - -def array_union(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "ARRAY_UNION", Column.ensure_col(col2)) - - -def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "ARRAY_EXCEPT", Column.ensure_col(col2)) - - -def explode(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Explode) - - -def posexplode(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Posexplode) - - -def explode_outer(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.ExplodeOuter) - - -def posexplode_outer(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.PosexplodeOuter) - - -def get_json_object(col: ColumnOrName, path: str) -> Column: - return Column.invoke_expression_over_column(col, expression.JSONExtract, expression=lit(path)) - - -def json_tuple(col: ColumnOrName, *fields: str) -> Column: - return Column.invoke_anonymous_function(col, "JSON_TUPLE", *[lit(field) for field in fields]) - - -def from_json( - col: ColumnOrName, - schema: t.Union[Column, str], - options: t.Optional[t.Dict[str, str]] = None, -) -> Column: - schema = schema if isinstance(schema, Column) else lit(schema) - if options is not None: - options_col = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_anonymous_function(col, "FROM_JSON", schema, options_col) - return Column.invoke_anonymous_function(col, "FROM_JSON", schema) - - -def to_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: - if options is not None: - options_col = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_expression_over_column(col, expression.JSONFormat, options=options_col) - return Column.invoke_expression_over_column(col, expression.JSONFormat) - - -def schema_of_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: - if options is not None: - options_col = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON", options_col) - return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON") - - -def schema_of_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: - if options is not None: - options_col = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV", options_col) - return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV") - - -def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: - if options is not None: - options_col = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_anonymous_function(col, "TO_CSV", options_col) - return Column.invoke_anonymous_function(col, "TO_CSV") - - -def size(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.ArraySize) - - -def array_min(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ARRAY_MIN") - - -def array_max(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ARRAY_MAX") - - -def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column: - if asc is not None: - return Column.invoke_expression_over_column(col, expression.SortArray, asc=asc) - return Column.invoke_expression_over_column(col, expression.SortArray) - - -def array_sort( - col: ColumnOrName, - comparator: t.Optional[t.Union[t.Callable[[Column, Column], Column]]] = None, -) -> Column: - if comparator is not None: - f_expression = _get_lambda_from_func(comparator) - return Column.invoke_expression_over_column( - col, expression.ArraySort, expression=f_expression - ) - return Column.invoke_expression_over_column(col, expression.ArraySort) - - -def shuffle(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SHUFFLE") - - -def reverse(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "REVERSE") - - -def flatten(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Flatten) - - -def map_keys(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MAP_KEYS") - - -def map_values(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MAP_VALUES") - - -def map_entries(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MAP_ENTRIES") - - -def map_from_entries(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.MapFromEntries) - - -def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column: - count_col = count if isinstance(count, Column) else lit(count) - return Column.invoke_anonymous_function(col, "ARRAY_REPEAT", count_col) - - -def array_zip(*cols: ColumnOrName) -> Column: - if len(cols) == 1: - return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP") - return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP", *cols[1:]) - - -def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: - columns = list(flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore - if len(columns) == 1: - return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT") - return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:]) - - -def sequence( - start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None -) -> Column: - if step is not None: - return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step) - return Column.invoke_anonymous_function(start, "SEQUENCE", stop) - - -def from_csv( - col: ColumnOrName, - schema: t.Union[Column, str], - options: t.Optional[t.Dict[str, str]] = None, -) -> Column: - schema = schema if isinstance(schema, Column) else lit(schema) - if options is not None: - option_cols = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_anonymous_function(col, "FROM_CSV", schema, option_cols) - return Column.invoke_anonymous_function(col, "FROM_CSV", schema) - - -def aggregate( - col: ColumnOrName, - initialValue: ColumnOrName, - merge: t.Callable[[Column, Column], Column], - finish: t.Optional[t.Callable[[Column], Column]] = None, -) -> Column: - merge_exp = _get_lambda_from_func(merge) - if finish is not None: - finish_exp = _get_lambda_from_func(finish) - return Column.invoke_expression_over_column( - col, - expression.Reduce, - initial=initialValue, - merge=Column(merge_exp), - finish=Column(finish_exp), - ) - return Column.invoke_expression_over_column( - col, expression.Reduce, initial=initialValue, merge=Column(merge_exp) - ) - - -def transform( - col: ColumnOrName, - f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], -) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_expression_over_column( - col, expression.Transform, expression=Column(f_expression) - ) - - -def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression)) - - -def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression)) - - -def filter( - col: ColumnOrName, - f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], -) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_expression_over_column( - col, expression.ArrayFilter, expression=f_expression - ) - - -def zip_with( - left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column] -) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression)) - - -def transform_keys(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression)) - - -def transform_values(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression)) - - -def map_filter(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression)) - - -def map_zip_with( - col1: ColumnOrName, - col2: ColumnOrName, - f: t.Union[t.Callable[[Column, Column, Column], Column]], -) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression)) - - -def _lambda_quoted(value: str) -> t.Optional[bool]: - return False if value == "_" else None - - -def _get_lambda_from_func(lambda_expression: t.Callable): - variables = [ - expression.to_identifier(x, quoted=_lambda_quoted(x)) - for x in lambda_expression.__code__.co_varnames - ] - return expression.Lambda( - this=lambda_expression(*[Column(x) for x in variables]).expression, - expressions=variables, - ) diff --git a/sqlglot/dataframe/sql/group.py b/sqlglot/dataframe/sql/group.py deleted file mode 100644 index ba27c17..0000000 --- a/sqlglot/dataframe/sql/group.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot.dataframe.sql import functions as F -from sqlglot.dataframe.sql.column import Column -from sqlglot.dataframe.sql.operations import Operation, operation - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql.dataframe import DataFrame - - -class GroupedData: - def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation): - self._df = df.copy() - self.spark = df.spark - self.last_op = last_op - self.group_by_cols = group_by_cols - - def _get_function_applied_columns( - self, func_name: str, cols: t.Tuple[str, ...] - ) -> t.List[Column]: - func_name = func_name.lower() - return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols] - - @operation(Operation.SELECT) - def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: - columns = ( - [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] - if isinstance(exprs[0], dict) - else exprs - ) - cols = self._df._ensure_and_normalize_cols(columns) - - expression = self._df.expression.group_by( - *[x.expression for x in self.group_by_cols] - ).select(*[x.expression for x in self.group_by_cols + cols], append=False) - return self._df.copy(expression=expression) - - def count(self) -> DataFrame: - return self.agg(F.count("*").alias("count")) - - def mean(self, *cols: str) -> DataFrame: - return self.avg(*cols) - - def avg(self, *cols: str) -> DataFrame: - return self.agg(*self._get_function_applied_columns("avg", cols)) - - def max(self, *cols: str) -> DataFrame: - return self.agg(*self._get_function_applied_columns("max", cols)) - - def min(self, *cols: str) -> DataFrame: - return self.agg(*self._get_function_applied_columns("min", cols)) - - def sum(self, *cols: str) -> DataFrame: - return self.agg(*self._get_function_applied_columns("sum", cols)) - - def pivot(self, *cols: str) -> DataFrame: - raise NotImplementedError("Sum distinct is not currently implemented") diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py deleted file mode 100644 index b246641..0000000 --- a/sqlglot/dataframe/sql/normalize.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import expressions as exp -from sqlglot.dataframe.sql.column import Column -from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join -from sqlglot.helper import ensure_list - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql.session import SparkSession - - NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column]) - - -def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]): - expr = ensure_list(expr) - expressions = _ensure_expressions(expr) - for expression in expressions: - identifiers = expression.find_all(exp.Identifier) - for identifier in identifiers: - identifier.transform(spark.dialect.normalize_identifier) - replace_alias_name_with_cte_name(spark, expression_context, identifier) - replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier) - - -def replace_alias_name_with_cte_name( - spark: SparkSession, expression_context: exp.Select, id: exp.Identifier -): - if id.alias_or_name in spark.name_to_sequence_id_mapping: - for cte in reversed(expression_context.ctes): - if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]: - _set_alias_name(id, cte.alias_or_name) - break - - -def replace_branch_and_sequence_ids_with_cte_name( - spark: SparkSession, expression_context: exp.Select, id: exp.Identifier -): - if id.alias_or_name in spark.known_ids: - # Check if we have a join and if both the tables in that join share a common branch id - # If so we need to have this reference the left table by default unless the id is a sequence - # id then it keeps that reference. This handles the weird edge case in spark that shouldn't - # be common in practice - if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids: - join_table_aliases = [ - x.alias_or_name for x in get_tables_from_expression_with_join(expression_context) - ] - ctes_in_join = [ - cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases - ] - if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]: - assert len(ctes_in_join) == 2 - _set_alias_name(id, ctes_in_join[0].alias_or_name) - return - - for cte in reversed(expression_context.ctes): - if id.alias_or_name in (cte.args["branch_id"], cte.args["sequence_id"]): - _set_alias_name(id, cte.alias_or_name) - return - - -def _set_alias_name(id: exp.Identifier, name: str): - id.set("this", name) - - -def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]: - results = [] - for value in values: - if isinstance(value, str): - results.append(Column.ensure_col(value).expression) - elif isinstance(value, Column): - results.append(value.expression) - elif isinstance(value, exp.Expression): - results.append(value) - else: - raise ValueError(f"Got an invalid type to normalize: {type(value)}") - return results diff --git a/sqlglot/dataframe/sql/operations.py b/sqlglot/dataframe/sql/operations.py deleted file mode 100644 index e4c106b..0000000 --- a/sqlglot/dataframe/sql/operations.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations - -import functools -import typing as t -from enum import IntEnum - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql.dataframe import DataFrame - from sqlglot.dataframe.sql.group import GroupedData - - -class Operation(IntEnum): - INIT = -1 - NO_OP = 0 - FROM = 1 - WHERE = 2 - GROUP_BY = 3 - HAVING = 4 - SELECT = 5 - ORDER_BY = 6 - LIMIT = 7 - - -def operation(op: Operation): - """ - Decorator used around DataFrame methods to indicate what type of operation is being performed from the - ordered Operation enums. This is used to determine which operations should be performed on a CTE vs. - included with the previous operation. - - Ex: After a user does a join we want to allow them to select which columns for the different - tables that they want to carry through to the following operation. If we put that join in - a CTE preemptively then the user would not have a chance to select which column they want - in cases where there is overlap in names. - """ - - def decorator(func: t.Callable): - @functools.wraps(func) - def wrapper(self: DataFrame, *args, **kwargs): - if self.last_op == Operation.INIT: - self = self._convert_leaf_to_cte() - self.last_op = Operation.NO_OP - last_op = self.last_op - new_op = op if op != Operation.NO_OP else last_op - if new_op < last_op or (last_op == new_op == Operation.SELECT): - self = self._convert_leaf_to_cte() - df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs) - df.last_op = new_op # type: ignore - return df - - wrapper.__wrapped__ = func # type: ignore - return wrapper - - return decorator diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py deleted file mode 100644 index 0804486..0000000 --- a/sqlglot/dataframe/sql/readwriter.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -import typing as t - -import sqlglot -from sqlglot import expressions as exp -from sqlglot.helper import object_to_dict - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql.dataframe import DataFrame - from sqlglot.dataframe.sql.session import SparkSession - - -class DataFrameReader: - def __init__(self, spark: SparkSession): - self.spark = spark - - def table(self, tableName: str) -> DataFrame: - from sqlglot.dataframe.sql.dataframe import DataFrame - from sqlglot.dataframe.sql.session import SparkSession - - sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect) - - return DataFrame( - self.spark, - exp.Select() - .from_( - exp.to_table(tableName, dialect=SparkSession().dialect).transform( - SparkSession().dialect.normalize_identifier - ) - ) - .select( - *( - column - for column in sqlglot.schema.column_names( - tableName, dialect=SparkSession().dialect - ) - ) - ), - ) - - -class DataFrameWriter: - def __init__( - self, - df: DataFrame, - spark: t.Optional[SparkSession] = None, - mode: t.Optional[str] = None, - by_name: bool = False, - ): - self._df = df - self._spark = spark or df.spark - self._mode = mode - self._by_name = by_name - - def copy(self, **kwargs) -> DataFrameWriter: - return DataFrameWriter( - **{ - k[1:] if k.startswith("_") else k: v - for k, v in object_to_dict(self, **kwargs).items() - } - ) - - def sql(self, **kwargs) -> t.List[str]: - return self._df.sql(**kwargs) - - def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter: - return self.copy(_mode=saveMode) - - @property - def byName(self): - return self.copy(by_name=True) - - def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: - from sqlglot.dataframe.sql.session import SparkSession - - output_expression_container = exp.Insert( - **{ - "this": exp.to_table(tableName), - "overwrite": overwrite, - } - ) - df = self._df.copy(output_expression_container=output_expression_container) - if self._by_name: - columns = sqlglot.schema.column_names( - tableName, only_visible=True, dialect=SparkSession().dialect - ) - df = df._convert_leaf_to_cte().select(*columns) - - return self.copy(_df=df) - - def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None): - if format is not None: - raise NotImplementedError("Providing Format in the save as table is not supported") - exists, replace, mode = None, None, mode or str(self._mode) - if mode == "append": - return self.insertInto(name) - if mode == "ignore": - exists = True - if mode == "overwrite": - replace = True - output_expression_container = exp.Create( - this=exp.to_table(name), - kind="TABLE", - exists=exists, - replace=replace, - ) - return self.copy(_df=self._df.copy(output_expression_container=output_expression_container)) diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py deleted file mode 100644 index 4e47aaa..0000000 --- a/sqlglot/dataframe/sql/session.py +++ /dev/null @@ -1,199 +0,0 @@ -from __future__ import annotations - -import typing as t -import uuid -from collections import defaultdict - -import sqlglot -from sqlglot import Dialect, expressions as exp -from sqlglot.dataframe.sql import functions as F -from sqlglot.dataframe.sql.dataframe import DataFrame -from sqlglot.dataframe.sql.readwriter import DataFrameReader -from sqlglot.dataframe.sql.types import StructType -from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input -from sqlglot.helper import classproperty -from sqlglot.optimizer import optimize -from sqlglot.optimizer.qualify_columns import quote_identifiers - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput - - -class SparkSession: - DEFAULT_DIALECT = "spark" - _instance = None - - def __init__(self): - if not hasattr(self, "known_ids"): - self.known_ids = set() - self.known_branch_ids = set() - self.known_sequence_ids = set() - self.name_to_sequence_id_mapping = defaultdict(list) - self.incrementing_id = 1 - self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT) - - def __new__(cls, *args, **kwargs) -> SparkSession: - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - @property - def read(self) -> DataFrameReader: - return DataFrameReader(self) - - def table(self, tableName: str) -> DataFrame: - return self.read.table(tableName) - - def createDataFrame( - self, - data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]], - schema: t.Optional[SchemaInput] = None, - samplingRatio: t.Optional[float] = None, - verifySchema: bool = False, - ) -> DataFrame: - from sqlglot.dataframe.sql.dataframe import DataFrame - - if samplingRatio is not None or verifySchema: - raise NotImplementedError("Sampling Ratio and Verify Schema are not supported") - if schema is not None and ( - not isinstance(schema, (StructType, str, list)) - or (isinstance(schema, list) and not isinstance(schema[0], str)) - ): - raise NotImplementedError("Only schema of either list or string of list supported") - if not data: - raise ValueError("Must provide data to create into a DataFrame") - - column_mapping: t.Dict[str, t.Optional[str]] - if schema is not None: - column_mapping = get_column_mapping_from_schema_input(schema) - elif isinstance(data[0], dict): - column_mapping = {col_name.strip(): None for col_name in data[0]} - else: - column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)} - - data_expressions = [ - exp.tuple_( - *map( - lambda x: F.lit(x).expression, - row if not isinstance(row, dict) else row.values(), - ) - ) - for row in data - ] - - sel_columns = [ - ( - F.col(name).cast(data_type).alias(name).expression - if data_type is not None - else F.col(name).expression - ) - for name, data_type in column_mapping.items() - ] - - select_kwargs = { - "expressions": sel_columns, - "from": exp.From( - this=exp.Values( - expressions=data_expressions, - alias=exp.TableAlias( - this=exp.to_identifier(self._auto_incrementing_name), - columns=[exp.to_identifier(col_name) for col_name in column_mapping], - ), - ), - ), - } - - sel_expression = exp.Select(**select_kwargs) - return DataFrame(self, sel_expression) - - def _optimize( - self, expression: exp.Expression, dialect: t.Optional[Dialect] = None - ) -> exp.Expression: - dialect = dialect or self.dialect - quote_identifiers(expression, dialect=dialect) - return optimize(expression, dialect=dialect) - - def sql(self, sqlQuery: str) -> DataFrame: - expression = self._optimize(sqlglot.parse_one(sqlQuery, read=self.dialect)) - if isinstance(expression, exp.Select): - df = DataFrame(self, expression) - df = df._convert_leaf_to_cte() - elif isinstance(expression, (exp.Create, exp.Insert)): - select_expression = expression.expression.copy() - if isinstance(expression, exp.Insert): - select_expression.set("with", expression.args.get("with")) - expression.set("with", None) - del expression.args["expression"] - df = DataFrame(self, select_expression, output_expression_container=expression) # type: ignore - df = df._convert_leaf_to_cte() - else: - raise ValueError( - "Unknown expression type provided in the SQL. Please create an issue with the SQL." - ) - return df - - @property - def _auto_incrementing_name(self) -> str: - name = f"a{self.incrementing_id}" - self.incrementing_id += 1 - return name - - @property - def _random_branch_id(self) -> str: - id = self._random_id - self.known_branch_ids.add(id) - return id - - @property - def _random_sequence_id(self): - id = self._random_id - self.known_sequence_ids.add(id) - return id - - @property - def _random_id(self) -> str: - id = "r" + uuid.uuid4().hex - self.known_ids.add(id) - return id - - @property - def _join_hint_names(self) -> t.Set[str]: - return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"} - - def _add_alias_to_mapping(self, name: str, sequence_id: str): - self.name_to_sequence_id_mapping[name].append(sequence_id) - - class Builder: - SQLFRAME_DIALECT_KEY = "sqlframe.dialect" - - def __init__(self): - self.dialect = "spark" - - def __getattr__(self, item) -> SparkSession.Builder: - return self - - def __call__(self, *args, **kwargs): - return self - - def config( - self, - key: t.Optional[str] = None, - value: t.Optional[t.Any] = None, - *, - map: t.Optional[t.Dict[str, t.Any]] = None, - **kwargs: t.Any, - ) -> SparkSession.Builder: - if key == self.SQLFRAME_DIALECT_KEY: - self.dialect = value - elif map and self.SQLFRAME_DIALECT_KEY in map: - self.dialect = map[self.SQLFRAME_DIALECT_KEY] - return self - - def getOrCreate(self) -> SparkSession: - spark = SparkSession() - spark.dialect = Dialect.get_or_raise(self.dialect) - return spark - - @classproperty - def builder(cls) -> Builder: - return cls.Builder() diff --git a/sqlglot/dataframe/sql/transforms.py b/sqlglot/dataframe/sql/transforms.py deleted file mode 100644 index b3dcc12..0000000 --- a/sqlglot/dataframe/sql/transforms.py +++ /dev/null @@ -1,9 +0,0 @@ -import typing as t - -from sqlglot import expressions as exp - - -def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]): - if isinstance(node, exp.Identifier) and node in replacement_mapping: - node = node.replace(replacement_mapping[node].copy()) - return node diff --git a/sqlglot/dataframe/sql/types.py b/sqlglot/dataframe/sql/types.py deleted file mode 100644 index a63e505..0000000 --- a/sqlglot/dataframe/sql/types.py +++ /dev/null @@ -1,212 +0,0 @@ -import typing as t - - -class DataType: - def __repr__(self) -> str: - return self.__class__.__name__ + "()" - - def __hash__(self) -> int: - return hash(str(self)) - - def __eq__(self, other: t.Any) -> bool: - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other: t.Any) -> bool: - return not self.__eq__(other) - - def __str__(self) -> str: - return self.typeName() - - @classmethod - def typeName(cls) -> str: - return cls.__name__[:-4].lower() - - def simpleString(self) -> str: - return str(self) - - def jsonValue(self) -> t.Union[str, t.Dict[str, t.Any]]: - return str(self) - - -class DataTypeWithLength(DataType): - def __init__(self, length: int): - self.length = length - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.length})" - - def __str__(self) -> str: - return f"{self.typeName()}({self.length})" - - -class StringType(DataType): - pass - - -class CharType(DataTypeWithLength): - pass - - -class VarcharType(DataTypeWithLength): - pass - - -class BinaryType(DataType): - pass - - -class BooleanType(DataType): - pass - - -class DateType(DataType): - pass - - -class TimestampType(DataType): - pass - - -class TimestampNTZType(DataType): - @classmethod - def typeName(cls) -> str: - return "timestamp_ntz" - - -class DecimalType(DataType): - def __init__(self, precision: int = 10, scale: int = 0): - self.precision = precision - self.scale = scale - - def simpleString(self) -> str: - return f"decimal({self.precision}, {self.scale})" - - def jsonValue(self) -> str: - return f"decimal({self.precision}, {self.scale})" - - def __repr__(self) -> str: - return f"DecimalType({self.precision}, {self.scale})" - - -class DoubleType(DataType): - pass - - -class FloatType(DataType): - pass - - -class ByteType(DataType): - def __str__(self) -> str: - return "tinyint" - - -class IntegerType(DataType): - def __str__(self) -> str: - return "int" - - -class LongType(DataType): - def __str__(self) -> str: - return "bigint" - - -class ShortType(DataType): - def __str__(self) -> str: - return "smallint" - - -class ArrayType(DataType): - def __init__(self, elementType: DataType, containsNull: bool = True): - self.elementType = elementType - self.containsNull = containsNull - - def __repr__(self) -> str: - return f"ArrayType({self.elementType, str(self.containsNull)}" - - def simpleString(self) -> str: - return f"array<{self.elementType.simpleString()}>" - - def jsonValue(self) -> t.Dict[str, t.Any]: - return { - "type": self.typeName(), - "elementType": self.elementType.jsonValue(), - "containsNull": self.containsNull, - } - - -class MapType(DataType): - def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True): - self.keyType = keyType - self.valueType = valueType - self.valueContainsNull = valueContainsNull - - def __repr__(self) -> str: - return f"MapType({self.keyType}, {self.valueType}, {str(self.valueContainsNull)})" - - def simpleString(self) -> str: - return f"map<{self.keyType.simpleString()}, {self.valueType.simpleString()}>" - - def jsonValue(self) -> t.Dict[str, t.Any]: - return { - "type": self.typeName(), - "keyType": self.keyType.jsonValue(), - "valueType": self.valueType.jsonValue(), - "valueContainsNull": self.valueContainsNull, - } - - -class StructField(DataType): - def __init__( - self, - name: str, - dataType: DataType, - nullable: bool = True, - metadata: t.Optional[t.Dict[str, t.Any]] = None, - ): - self.name = name - self.dataType = dataType - self.nullable = nullable - self.metadata = metadata or {} - - def __repr__(self) -> str: - return f"StructField('{self.name}', {self.dataType}, {str(self.nullable)})" - - def simpleString(self) -> str: - return f"{self.name}:{self.dataType.simpleString()}" - - def jsonValue(self) -> t.Dict[str, t.Any]: - return { - "name": self.name, - "type": self.dataType.jsonValue(), - "nullable": self.nullable, - "metadata": self.metadata, - } - - -class StructType(DataType): - def __init__(self, fields: t.Optional[t.List[StructField]] = None): - if not fields: - self.fields = [] - self.names = [] - else: - self.fields = fields - self.names = [f.name for f in fields] - - def __iter__(self) -> t.Iterator[StructField]: - return iter(self.fields) - - def __len__(self) -> int: - return len(self.fields) - - def __repr__(self) -> str: - return f"StructType({', '.join(str(field) for field in self)})" - - def simpleString(self) -> str: - return f"struct<{', '.join(x.simpleString() for x in self)}>" - - def jsonValue(self) -> t.Dict[str, t.Any]: - return {"type": self.typeName(), "fields": [x.jsonValue() for x in self]} - - def fieldNames(self) -> t.List[str]: - return list(self.names) diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py deleted file mode 100644 index 4b9fbb1..0000000 --- a/sqlglot/dataframe/sql/util.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import expressions as exp -from sqlglot.dataframe.sql import types - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import SchemaInput - - -def get_column_mapping_from_schema_input(schema: SchemaInput) -> t.Dict[str, t.Optional[str]]: - if isinstance(schema, dict): - return schema - elif isinstance(schema, str): - col_name_type_strs = [x.strip() for x in schema.split(",")] - return { - name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() - for name_type_str in col_name_type_strs - } - elif isinstance(schema, types.StructType): - return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema} - return {x.strip(): None for x in schema} # type: ignore - - -def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]: - if not expression.args.get("joins"): - return [] - - left_table = expression.args["from"].this - other_tables = [join.this for join in expression.args["joins"]] - return [left_table] + other_tables diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py deleted file mode 100644 index 9e2fabd..0000000 --- a/sqlglot/dataframe/sql/window.py +++ /dev/null @@ -1,136 +0,0 @@ -from __future__ import annotations - -import sys -import typing as t - -from sqlglot import expressions as exp -from sqlglot.dataframe.sql import functions as F -from sqlglot.helper import flatten - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import ColumnOrName - - -class Window: - _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808 - _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807 - _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG) - _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG) - - unboundedPreceding: int = _JAVA_MIN_LONG - - unboundedFollowing: int = _JAVA_MAX_LONG - - currentRow: int = 0 - - @classmethod - def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: - return WindowSpec().partitionBy(*cols) - - @classmethod - def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: - return WindowSpec().orderBy(*cols) - - @classmethod - def rowsBetween(cls, start: int, end: int) -> WindowSpec: - return WindowSpec().rowsBetween(start, end) - - @classmethod - def rangeBetween(cls, start: int, end: int) -> WindowSpec: - return WindowSpec().rangeBetween(start, end) - - -class WindowSpec: - def __init__(self, expression: exp.Expression = exp.Window()): - self.expression = expression - - def copy(self): - return WindowSpec(self.expression.copy()) - - def sql(self, **kwargs) -> str: - from sqlglot.dataframe.sql.session import SparkSession - - return self.expression.sql(dialect=SparkSession().dialect, **kwargs) - - def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: - from sqlglot.dataframe.sql.column import Column - - cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore - expressions = [Column.ensure_col(x).expression for x in cols] - window_spec = self.copy() - partition_by_expressions = window_spec.expression.args.get("partition_by", []) - partition_by_expressions.extend(expressions) - window_spec.expression.set("partition_by", partition_by_expressions) - return window_spec - - def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: - from sqlglot.dataframe.sql.column import Column - - cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore - expressions = [Column.ensure_col(x).expression for x in cols] - window_spec = self.copy() - if window_spec.expression.args.get("order") is None: - window_spec.expression.set("order", exp.Order(expressions=[])) - order_by = window_spec.expression.args["order"].expressions - order_by.extend(expressions) - window_spec.expression.args["order"].set("expressions", order_by) - return window_spec - - def _calc_start_end( - self, start: int, end: int - ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]: - kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = { - "start_side": None, - "end_side": None, - } - if start == Window.currentRow: - kwargs["start"] = "CURRENT ROW" - else: - kwargs = { - **kwargs, - **{ - "start_side": "PRECEDING", - "start": ( - "UNBOUNDED" - if start <= Window.unboundedPreceding - else F.lit(start).expression - ), - }, - } - if end == Window.currentRow: - kwargs["end"] = "CURRENT ROW" - else: - kwargs = { - **kwargs, - **{ - "end_side": "FOLLOWING", - "end": ( - "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression - ), - }, - } - return kwargs - - def rowsBetween(self, start: int, end: int) -> WindowSpec: - window_spec = self.copy() - spec = self._calc_start_end(start, end) - spec["kind"] = "ROWS" - window_spec.expression.set( - "spec", - exp.WindowSpec( - **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} - ), - ) - return window_spec - - def rangeBetween(self, start: int, end: int) -> WindowSpec: - window_spec = self.copy() - spec = self._calc_start_end(start, end) - spec["kind"] = "RANGE" - window_spec.expression.set( - "spec", - exp.WindowSpec( - **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} - ), - ) - return window_spec |