author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-05-23 07:22:23 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-05-23 07:22:23 +0000
commit     0d1477fdf20125df42fe49355b757625417c8f8c (patch)
tree       0ace7a95d185b2b1ae36e25e341bf92cd9021cb0 /sqlglot/dataframe/sql
parent     Releasing debian version 23.16.0-1. (diff)
Merging upstream version 24.0.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dataframe/sql')
-rw-r--r--  sqlglot/dataframe/sql/__init__.py      18
-rw-r--r--  sqlglot/dataframe/sql/_typing.py       18
-rw-r--r--  sqlglot/dataframe/sql/column.py       342
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py    862
-rw-r--r--  sqlglot/dataframe/sql/functions.py   1270
-rw-r--r--  sqlglot/dataframe/sql/group.py         59
-rw-r--r--  sqlglot/dataframe/sql/normalize.py     78
-rw-r--r--  sqlglot/dataframe/sql/operations.py    53
-rw-r--r--  sqlglot/dataframe/sql/readwriter.py   108
-rw-r--r--  sqlglot/dataframe/sql/session.py      199
-rw-r--r--  sqlglot/dataframe/sql/transforms.py     9
-rw-r--r--  sqlglot/dataframe/sql/types.py        212
-rw-r--r--  sqlglot/dataframe/sql/util.py          32
-rw-r--r--  sqlglot/dataframe/sql/window.py       136
14 files changed, 0 insertions, 3396 deletions
diff --git a/sqlglot/dataframe/sql/__init__.py b/sqlglot/dataframe/sql/__init__.py
deleted file mode 100644
index 3f90802..0000000
--- a/sqlglot/dataframe/sql/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from sqlglot.dataframe.sql.column import Column
-from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
-from sqlglot.dataframe.sql.group import GroupedData
-from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
-from sqlglot.dataframe.sql.session import SparkSession
-from sqlglot.dataframe.sql.window import Window, WindowSpec
-
-__all__ = [
- "SparkSession",
- "DataFrame",
- "GroupedData",
- "Column",
- "DataFrameNaFunctions",
- "Window",
- "WindowSpec",
- "DataFrameReader",
- "DataFrameWriter",
-]
diff --git a/sqlglot/dataframe/sql/_typing.py b/sqlglot/dataframe/sql/_typing.py
deleted file mode 100644
index fb46026..0000000
--- a/sqlglot/dataframe/sql/_typing.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from __future__ import annotations
-
-import datetime
-import typing as t
-
-from sqlglot import expressions as exp
-
-if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql.column import Column
- from sqlglot.dataframe.sql.types import StructType
-
-ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
-ColumnOrName = t.Union[Column, str]
-ColumnOrLiteral = t.Union[
- Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime
-]
-SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]]
-OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
deleted file mode 100644
index 724c5bf..0000000
--- a/sqlglot/dataframe/sql/column.py
+++ /dev/null
@@ -1,342 +0,0 @@
-from __future__ import annotations
-
-import typing as t
-
-import sqlglot
-from sqlglot import expressions as exp
-from sqlglot.dataframe.sql.types import DataType
-from sqlglot.helper import flatten, is_iterable
-
-if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql._typing import ColumnOrLiteral
- from sqlglot.dataframe.sql.window import WindowSpec
-
-
-class Column:
- def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
- from sqlglot.dataframe.sql.session import SparkSession
-
- if isinstance(expression, Column):
- expression = expression.expression # type: ignore
- elif expression is None or not isinstance(expression, (str, exp.Expression)):
- expression = self._lit(expression).expression # type: ignore
- elif not isinstance(expression, exp.Column):
- parsed = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect)
- if parsed is None:
- raise ValueError(f"Could not parse {expression}")
- expression = parsed.transform(
- SparkSession().dialect.normalize_identifier, copy=False
- )
-
- self.expression: exp.Expression = expression # type: ignore
-
- def __repr__(self):
- return repr(self.expression)
-
- def __hash__(self):
- return hash(self.expression)
-
- def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore
- return self.binary_op(exp.EQ, other)
-
- def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore
- return self.binary_op(exp.NEQ, other)
-
- def __gt__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.GT, other)
-
- def __ge__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.GTE, other)
-
- def __lt__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.LT, other)
-
- def __le__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.LTE, other)
-
- def __and__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.And, other)
-
- def __or__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.Or, other)
-
- def __mod__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.Mod, other)
-
- def __add__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.Add, other)
-
- def __sub__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.Sub, other)
-
- def __mul__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.Mul, other)
-
- def __truediv__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.Div, other)
-
- def __div__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.Div, other)
-
- def __neg__(self) -> Column:
- return self.unary_op(exp.Neg)
-
- def __radd__(self, other: ColumnOrLiteral) -> Column:
- return self.inverse_binary_op(exp.Add, other)
-
- def __rsub__(self, other: ColumnOrLiteral) -> Column:
- return self.inverse_binary_op(exp.Sub, other)
-
- def __rmul__(self, other: ColumnOrLiteral) -> Column:
- return self.inverse_binary_op(exp.Mul, other)
-
- def __rdiv__(self, other: ColumnOrLiteral) -> Column:
- return self.inverse_binary_op(exp.Div, other)
-
- def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
- return self.inverse_binary_op(exp.Div, other)
-
- def __rmod__(self, other: ColumnOrLiteral) -> Column:
- return self.inverse_binary_op(exp.Mod, other)
-
- def __pow__(self, power: ColumnOrLiteral, modulo=None):
- return Column(exp.Pow(this=self.expression, expression=Column(power).expression))
-
- def __rpow__(self, power: ColumnOrLiteral):
- return Column(exp.Pow(this=Column(power).expression, expression=self.expression))
-
- def __invert__(self):
- return self.unary_op(exp.Not)
-
- def __rand__(self, other: ColumnOrLiteral) -> Column:
- return self.inverse_binary_op(exp.And, other)
-
- def __ror__(self, other: ColumnOrLiteral) -> Column:
- return self.inverse_binary_op(exp.Or, other)
-
- @classmethod
- def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
- return cls(value)
-
- @classmethod
- def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
- return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
-
- @classmethod
- def _lit(cls, value: ColumnOrLiteral) -> Column:
- if isinstance(value, dict):
- columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
- return cls(exp.Struct(expressions=columns))
- return cls(exp.convert(value))
-
- @classmethod
- def invoke_anonymous_function(
- cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
- ) -> Column:
- columns = [] if column is None else [cls.ensure_col(column)]
- column_args = [cls.ensure_col(arg) for arg in args]
- expressions = [x.expression for x in columns + column_args]
- new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
- return Column(new_expression)
-
- @classmethod
- def invoke_expression_over_column(
- cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
- ) -> Column:
- ensured_column = None if column is None else cls.ensure_col(column)
- ensure_expression_values = {
- k: (
- [Column.ensure_col(x).expression for x in v]
- if is_iterable(v)
- else Column.ensure_col(v).expression
- )
- for k, v in kwargs.items()
- if v is not None
- }
- new_expression = (
- callable_expression(**ensure_expression_values)
- if ensured_column is None
- else callable_expression(
- this=ensured_column.column_expression, **ensure_expression_values
- )
- )
- return Column(new_expression)
-
- def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
- return Column(
- klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
- )
-
- def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
- return Column(
- klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
- )
-
- def unary_op(self, klass: t.Callable, **kwargs) -> Column:
- return Column(klass(this=self.column_expression, **kwargs))
-
- @property
- def is_alias(self):
- return isinstance(self.expression, exp.Alias)
-
- @property
- def is_column(self):
- return isinstance(self.expression, exp.Column)
-
- @property
- def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
- return self.expression.unalias()
-
- @property
- def alias_or_name(self) -> str:
- return self.expression.alias_or_name
-
- @classmethod
- def ensure_literal(cls, value) -> Column:
- from sqlglot.dataframe.sql.functions import lit
-
- if isinstance(value, cls):
- value = value.expression
- if not isinstance(value, exp.Literal):
- return lit(value)
- return Column(value)
-
- def copy(self) -> Column:
- return Column(self.expression.copy())
-
- def set_table_name(self, table_name: str, copy=False) -> Column:
- expression = self.expression.copy() if copy else self.expression
- expression.set("table", exp.to_identifier(table_name))
- return Column(expression)
-
- def sql(self, **kwargs) -> str:
- from sqlglot.dataframe.sql.session import SparkSession
-
- return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
-
- def alias(self, name: str) -> Column:
- from sqlglot.dataframe.sql.session import SparkSession
-
- dialect = SparkSession().dialect
- alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect)
- new_expression = exp.alias_(
- self.column_expression,
- alias.this if isinstance(alias, exp.Column) else name,
- dialect=dialect,
- )
- return Column(new_expression)
-
- def asc(self) -> Column:
- new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
- return Column(new_expression)
-
- def desc(self) -> Column:
- new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
- return Column(new_expression)
-
- asc_nulls_first = asc
-
- def asc_nulls_last(self) -> Column:
- new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
- return Column(new_expression)
-
- def desc_nulls_first(self) -> Column:
- new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
- return Column(new_expression)
-
- desc_nulls_last = desc
-
- def when(self, condition: Column, value: t.Any) -> Column:
- from sqlglot.dataframe.sql.functions import when
-
- column_with_if = when(condition, value)
- if not isinstance(self.expression, exp.Case):
- return column_with_if
- new_column = self.copy()
- new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
- return new_column
-
- def otherwise(self, value: t.Any) -> Column:
- from sqlglot.dataframe.sql.functions import lit
-
- true_value = value if isinstance(value, Column) else lit(value)
- new_column = self.copy()
- new_column.expression.set("default", true_value.column_expression)
- return new_column
-
- def isNull(self) -> Column:
- new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
- return Column(new_expression)
-
- def isNotNull(self) -> Column:
- new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
- return Column(new_expression)
-
- def cast(self, dataType: t.Union[str, DataType]) -> Column:
- """
- Functionality Difference: PySpark's cast accepts an instance of PySpark's DataType class.
- Sqlglot doesn't replicate that class, so this accepts a string or a sqlglot DataType instead.
- """
- from sqlglot.dataframe.sql.session import SparkSession
-
- if isinstance(dataType, DataType):
- dataType = dataType.simpleString()
- return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))
-
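- # Illustrative sketch: both calls below should produce the same CAST, assuming a
- # DataFrame column named "age" and LongType from sqlglot.dataframe.sql.types:
- #   df.age.cast("bigint")
- #   df.age.cast(types.LongType())
-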
- def startswith(self, value: t.Union[str, Column]) -> Column:
- value = self._lit(value) if not isinstance(value, Column) else value
- return self.invoke_anonymous_function(self, "STARTSWITH", value)
-
- def endswith(self, value: t.Union[str, Column]) -> Column:
- value = self._lit(value) if not isinstance(value, Column) else value
- return self.invoke_anonymous_function(self, "ENDSWITH", value)
-
- def rlike(self, regexp: str) -> Column:
- return self.invoke_expression_over_column(
- column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
- )
-
- def like(self, other: str):
- return self.invoke_expression_over_column(
- self, exp.Like, expression=self._lit(other).expression
- )
-
- def ilike(self, other: str):
- return self.invoke_expression_over_column(
- self, exp.ILike, expression=self._lit(other).expression
- )
-
- def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
- startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
- length = self._lit(length) if not isinstance(length, Column) else length
- return Column.invoke_expression_over_column(
- self, exp.Substring, start=startPos.expression, length=length.expression
- )
-
- def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
- columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
- expressions = [self._lit(x).expression for x in columns]
- return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore
-
- def between(
- self,
- lowerBound: t.Union[ColumnOrLiteral],
- upperBound: t.Union[ColumnOrLiteral],
- ) -> Column:
- lower_bound_exp = (
- self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
- )
- upper_bound_exp = (
- self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
- )
- return Column(
- exp.Between(
- this=self.column_expression,
- low=lower_bound_exp.expression,
- high=upper_bound_exp.expression,
- )
- )
-
- def over(self, window: WindowSpec) -> Column:
- window_expression = window.expression.copy()
- window_expression.set("this", self.column_expression)
- return Column(window_expression)
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
deleted file mode 100644
index 8316c36..0000000
--- a/sqlglot/dataframe/sql/dataframe.py
+++ /dev/null
@@ -1,862 +0,0 @@
-from __future__ import annotations
-
-import functools
-import logging
-import typing as t
-import zlib
-from copy import copy
-
-import sqlglot
-from sqlglot import Dialect, expressions as exp
-from sqlglot.dataframe.sql import functions as F
-from sqlglot.dataframe.sql.column import Column
-from sqlglot.dataframe.sql.group import GroupedData
-from sqlglot.dataframe.sql.normalize import normalize
-from sqlglot.dataframe.sql.operations import Operation, operation
-from sqlglot.dataframe.sql.readwriter import DataFrameWriter
-from sqlglot.dataframe.sql.transforms import replace_id_value
-from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
-from sqlglot.dataframe.sql.window import Window
-from sqlglot.helper import ensure_list, object_to_dict, seq_get
-
-if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql._typing import (
- ColumnLiterals,
- ColumnOrLiteral,
- ColumnOrName,
- OutputExpressionContainer,
- )
- from sqlglot.dataframe.sql.session import SparkSession
- from sqlglot.dialects.dialect import DialectType
-
-logger = logging.getLogger("sqlglot")
-
-JOIN_HINTS = {
- "BROADCAST",
- "BROADCASTJOIN",
- "MAPJOIN",
- "MERGE",
- "SHUFFLEMERGE",
- "MERGEJOIN",
- "SHUFFLE_HASH",
- "SHUFFLE_REPLICATE_NL",
-}
-
-
-class DataFrame:
- def __init__(
- self,
- spark: SparkSession,
- expression: exp.Select,
- branch_id: t.Optional[str] = None,
- sequence_id: t.Optional[str] = None,
- last_op: Operation = Operation.INIT,
- pending_hints: t.Optional[t.List[exp.Expression]] = None,
- output_expression_container: t.Optional[OutputExpressionContainer] = None,
- **kwargs,
- ):
- self.spark = spark
- self.expression = expression
- self.branch_id = branch_id or self.spark._random_branch_id
- self.sequence_id = sequence_id or self.spark._random_sequence_id
- self.last_op = last_op
- self.pending_hints = pending_hints or []
- self.output_expression_container = output_expression_container or exp.Select()
-
- def __getattr__(self, column_name: str) -> Column:
- return self[column_name]
-
- def __getitem__(self, column_name: str) -> Column:
- column_name = f"{self.branch_id}.{column_name}"
- return Column(column_name)
-
- def __copy__(self):
- return self.copy()
-
- @property
- def sparkSession(self):
- return self.spark
-
- @property
- def write(self):
- return DataFrameWriter(self)
-
- @property
- def latest_cte_name(self) -> str:
- if not self.expression.ctes:
- from_exp = self.expression.args["from"]
- if from_exp.alias_or_name:
- return from_exp.alias_or_name
- table_alias = from_exp.find(exp.TableAlias)
- if not table_alias:
- raise RuntimeError(
- f"Could not find an alias name for this expression: {self.expression}"
- )
- return table_alias.alias_or_name
- return self.expression.ctes[-1].alias
-
- @property
- def pending_join_hints(self):
- return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
-
- @property
- def pending_partition_hints(self):
- return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
-
- @property
- def columns(self) -> t.List[str]:
- return self.expression.named_selects
-
- @property
- def na(self) -> DataFrameNaFunctions:
- return DataFrameNaFunctions(self)
-
- def _replace_cte_names_with_hashes(self, expression: exp.Select):
- replacement_mapping = {}
- for cte in expression.ctes:
- old_name_id = cte.args["alias"].this
- new_hashed_id = exp.to_identifier(
- self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
- )
- replacement_mapping[old_name_id] = new_hashed_id
- expression = expression.transform(replace_id_value, replacement_mapping).assert_is(
- exp.Select
- )
- return expression
-
- def _create_cte_from_expression(
- self,
- expression: exp.Expression,
- branch_id: t.Optional[str] = None,
- sequence_id: t.Optional[str] = None,
- **kwargs,
- ) -> t.Tuple[exp.CTE, str]:
- name = self._create_hash_from_expression(expression)
- expression_to_cte = expression.copy()
- expression_to_cte.set("with", None)
- cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
- cte.set("branch_id", branch_id or self.branch_id)
- cte.set("sequence_id", sequence_id or self.sequence_id)
- return cte, name
-
- @t.overload
- def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...
-
- @t.overload
- def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...
-
- def _ensure_list_of_columns(self, cols):
- return Column.ensure_cols(ensure_list(cols))
-
- def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
- cols = self._ensure_list_of_columns(cols)
- normalize(self.spark, expression or self.expression, cols)
- return cols
-
- def _ensure_and_normalize_col(self, col):
- col = Column.ensure_col(col)
- normalize(self.spark, self.expression, col)
- return col
-
- def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
- df = self._resolve_pending_hints()
- sequence_id = sequence_id or df.sequence_id
- expression = df.expression.copy()
- cte_expression, cte_name = df._create_cte_from_expression(
- expression=expression, sequence_id=sequence_id
- )
- new_expression = df._add_ctes_to_expression(
- exp.Select(), expression.ctes + [cte_expression]
- )
- sel_columns = df._get_outer_select_columns(cte_expression)
- new_expression = new_expression.from_(cte_name).select(
- *[x.alias_or_name for x in sel_columns]
- )
- return df.copy(expression=new_expression, sequence_id=sequence_id)
-
- def _resolve_pending_hints(self) -> DataFrame:
- df = self.copy()
- if not self.pending_hints:
- return df
- expression = df.expression
- hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
- for hint in df.pending_partition_hints:
- hint_expression.append("expressions", hint)
- df.pending_hints.remove(hint)
-
- join_aliases = {
- join_table.alias_or_name
- for join_table in get_tables_from_expression_with_join(expression)
- }
- if join_aliases:
- for hint in df.pending_join_hints:
- for sequence_id_expression in hint.expressions:
- sequence_id_or_name = sequence_id_expression.alias_or_name
- sequence_ids_to_match = [sequence_id_or_name]
- if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
- sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
- sequence_id_or_name
- ]
- matching_ctes = [
- cte
- for cte in reversed(expression.ctes)
- if cte.args["sequence_id"] in sequence_ids_to_match
- ]
- for matching_cte in matching_ctes:
- if matching_cte.alias_or_name in join_aliases:
- sequence_id_expression.set("this", matching_cte.args["alias"].this)
- df.pending_hints.remove(hint)
- break
- hint_expression.append("expressions", hint)
- if hint_expression.expressions:
- expression.set("hint", hint_expression)
- return df
-
- def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
- hint_name = hint_name.upper()
- hint_expression = (
- exp.JoinHint(
- this=hint_name,
- expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
- )
- if hint_name in JOIN_HINTS
- else exp.Anonymous(
- this=hint_name, expressions=[parameter.expression for parameter in args]
- )
- )
- new_df = self.copy()
- new_df.pending_hints.append(hint_expression)
- return new_df
-
- def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
- other_df = other._convert_leaf_to_cte()
- base_expression = self.expression.copy()
- base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
- all_ctes = base_expression.ctes
- other_df.expression.set("with", None)
- base_expression.set("with", None)
- operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
- operation.set("with", exp.With(expressions=all_ctes))
- return self.copy(expression=operation)._convert_leaf_to_cte()
-
- def _cache(self, storage_level: str):
- df = self._convert_leaf_to_cte()
- df.expression.ctes[-1].set("cache_storage_level", storage_level)
- return df
-
- @classmethod
- def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
- expression = expression.copy()
- with_expression = expression.args.get("with")
- if with_expression:
- existing_ctes = with_expression.expressions
- existing_cte_names = {x.alias_or_name for x in existing_ctes}
- for cte in ctes:
- if cte.alias_or_name not in existing_cte_names:
- existing_ctes.append(cte)
- else:
- existing_ctes = ctes
- expression.set("with", exp.With(expressions=existing_ctes))
- return expression
-
- @classmethod
- def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
- expression = item.expression if isinstance(item, DataFrame) else item
- return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]
-
- @classmethod
- def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
- from sqlglot.dataframe.sql.session import SparkSession
-
- value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
- return f"t{zlib.crc32(value)}"[:6]
-
- def _get_select_expressions(
- self,
- ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
- select_expressions: t.List[
- t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
- ] = []
- main_select_ctes: t.List[exp.CTE] = []
- for cte in self.expression.ctes:
- cache_storage_level = cte.args.get("cache_storage_level")
- if cache_storage_level:
- select_expression = cte.this.copy()
- select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
- select_expression.set("cte_alias_name", cte.alias_or_name)
- select_expression.set("cache_storage_level", cache_storage_level)
- select_expressions.append((exp.Cache, select_expression))
- else:
- main_select_ctes.append(cte)
- main_select = self.expression.copy()
- if main_select_ctes:
- main_select.set("with", exp.With(expressions=main_select_ctes))
- expression_select_pair = (type(self.output_expression_container), main_select)
- select_expressions.append(expression_select_pair) # type: ignore
- return select_expressions
-
- def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
- from sqlglot.dataframe.sql.session import SparkSession
-
- dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)
-
- df = self._resolve_pending_hints()
- select_expressions = df._get_select_expressions()
- output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
- replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
-
- for expression_type, select_expression in select_expressions:
- select_expression = select_expression.transform(
- replace_id_value, replacement_mapping
- ).assert_is(exp.Select)
- if optimize:
- select_expression = t.cast(
- exp.Select, self.spark._optimize(select_expression, dialect=dialect)
- )
-
- select_expression = df._replace_cte_names_with_hashes(select_expression)
-
- expression: t.Union[exp.Select, exp.Cache, exp.Drop]
- if expression_type == exp.Cache:
- cache_table_name = df._create_hash_from_expression(select_expression)
- cache_table = exp.to_table(cache_table_name)
- original_alias_name = select_expression.args["cte_alias_name"]
-
- replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore
- cache_table_name
- )
- sqlglot.schema.add_table(
- cache_table_name,
- {
- expression.alias_or_name: expression.type.sql(dialect=dialect)
- for expression in select_expression.expressions
- },
- dialect=dialect,
- )
-
- cache_storage_level = select_expression.args["cache_storage_level"]
- options = [
- exp.Literal.string("storageLevel"),
- exp.Literal.string(cache_storage_level),
- ]
- expression = exp.Cache(
- this=cache_table, expression=select_expression, lazy=True, options=options
- )
-
- # We will drop the "view" if it exists before running the cache table
- output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
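- # Illustrative sketch of the statements this emits (the hashed table name is hypothetical):
- #   DROP VIEW IF EXISTS t12345
- #   CACHE LAZY TABLE t12345 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT ...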
- elif expression_type == exp.Create:
- expression = df.output_expression_container.copy()
- expression.set("expression", select_expression)
- elif expression_type == exp.Insert:
- expression = df.output_expression_container.copy()
- select_without_ctes = select_expression.copy()
- select_without_ctes.set("with", None)
- expression.set("expression", select_without_ctes)
-
- if select_expression.ctes:
- expression.set("with", exp.With(expressions=select_expression.ctes))
- elif expression_type == exp.Select:
- expression = select_expression
- else:
- raise ValueError(f"Invalid expression type: {expression_type}")
-
- output_expressions.append(expression)
-
- return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]
-
- def copy(self, **kwargs) -> DataFrame:
- return DataFrame(**object_to_dict(self, **kwargs))
-
- @operation(Operation.SELECT)
- def select(self, *cols, **kwargs) -> DataFrame:
- cols = self._ensure_and_normalize_cols(cols)
- kwargs["append"] = kwargs.get("append", False)
- if self.expression.args.get("joins"):
- ambiguous_cols = [
- col
- for col in cols
- if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
- ]
- if ambiguous_cols:
- join_table_identifiers = [
- x.this for x in get_tables_from_expression_with_join(self.expression)
- ]
- cte_names_in_join = [x.this for x in join_table_identifiers]
- # If we have columns that resolve to multiple CTE expressions, then we want to use each CTE
- # left-to-right, and therefore we allow multiple columns with the same name in the result.
- # This matches the behavior of Spark.
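- # Illustrative: after df1.join(df2, on="id"), selecting F.col("name") twice resolves the
- # first "name" against df1's CTE and the second against df2's CTE (assuming both have it).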
- resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
- for ambiguous_col in ambiguous_cols:
- ctes_with_column = [
- cte
- for cte in self.expression.ctes
- if cte.alias_or_name in cte_names_in_join
- and ambiguous_col.alias_or_name in cte.this.named_selects
- ]
- # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
- # use the same CTE we used before
- cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
- if cte:
- resolved_column_position[ambiguous_col] += 1
- else:
- cte = ctes_with_column[resolved_column_position[ambiguous_col]]
- ambiguous_col.expression.set("table", cte.alias_or_name)
- return self.copy(
- expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
- )
-
- @operation(Operation.NO_OP)
- def alias(self, name: str, **kwargs) -> DataFrame:
- new_sequence_id = self.spark._random_sequence_id
- df = self.copy()
- for join_hint in df.pending_join_hints:
- for expression in join_hint.expressions:
- if expression.alias_or_name == self.sequence_id:
- expression.set("this", Column.ensure_col(new_sequence_id).expression)
- df.spark._add_alias_to_mapping(name, new_sequence_id)
- return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
-
- @operation(Operation.WHERE)
- def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
- col = self._ensure_and_normalize_col(column)
- return self.copy(expression=self.expression.where(col.expression))
-
- filter = where
-
- @operation(Operation.GROUP_BY)
- def groupBy(self, *cols, **kwargs) -> GroupedData:
- columns = self._ensure_and_normalize_cols(cols)
- return GroupedData(self, columns, self.last_op)
-
- @operation(Operation.SELECT)
- def agg(self, *exprs, **kwargs) -> DataFrame:
- cols = self._ensure_and_normalize_cols(exprs)
- return self.groupBy().agg(*cols)
-
- @operation(Operation.FROM)
- def join(
- self,
- other_df: DataFrame,
- on: t.Union[str, t.List[str], Column, t.List[Column]],
- how: str = "inner",
- **kwargs,
- ) -> DataFrame:
- other_df = other_df._convert_leaf_to_cte()
- join_columns = self._ensure_list_of_columns(on)
- # We will determine actual "join on" expression later so we don't provide it at first
- join_expression = self.expression.join(
- other_df.latest_cte_name, join_type=how.replace("_", " ")
- )
- join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
- self_columns = self._get_outer_select_columns(join_expression)
- other_columns = self._get_outer_select_columns(other_df)
- # Determines the join clause and select columns to be used based on what type of columns were
- # provided for the join. The columns returned change based on how the on expression is provided.
- if isinstance(join_columns[0].expression, exp.Column):
- """
- Unique characteristics of join on column names only:
- * The column names are put at the front of the select list
- * Only the join column names are deduplicated across the entire select list (other duplicates are allowed)
- """
- table_names = [
- table.alias_or_name
- for table in get_tables_from_expression_with_join(join_expression)
- ]
- potential_ctes = [
- cte
- for cte in join_expression.ctes
- if cte.alias_or_name in table_names
- and cte.alias_or_name != other_df.latest_cte_name
- ]
- # Determine the table to reference for the left side of the join by checking each of the left side
- # tables and seeing if they have the column being referenced.
- join_column_pairs = []
- for join_column in join_columns:
- num_matching_ctes = 0
- for cte in potential_ctes:
- if join_column.alias_or_name in cte.this.named_selects:
- left_column = join_column.copy().set_table_name(cte.alias_or_name)
- right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
- join_column_pairs.append((left_column, right_column))
- num_matching_ctes += 1
- if num_matching_ctes > 1:
- raise ValueError(
- f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
- )
- elif num_matching_ctes == 0:
- raise ValueError(
- f"Column {join_column.alias_or_name} does not exist in any of the tables."
- )
- join_clause = functools.reduce(
- lambda x, y: x & y,
- [left_column == right_column for left_column, right_column in join_column_pairs],
- )
- join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
- # To match Spark's behavior, only the join clause gets deduplicated, and it is put at the front of the column list
- select_column_names = [
- (
- column.alias_or_name
- if not isinstance(column.expression.this, exp.Star)
- else column.sql()
- )
- for column in self_columns + other_columns
- ]
- select_column_names = [
- column_name
- for column_name in select_column_names
- if column_name not in join_column_names
- ]
- select_column_names = join_column_names + select_column_names
- else:
- """
- Unique characteristics of join on expressions:
- * There is no deduplication of the results.
- * The left dataframe's columns go first and the right's come after. No sort preference is given to join columns
- """
- join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
- if len(join_columns) > 1:
- join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
- join_clause = join_columns[0]
- select_column_names = [column.alias_or_name for column in self_columns + other_columns]
-
- # Update the on expression with the actual join clause to replace the dummy one from before
- join_expression.args["joins"][-1].set("on", join_clause.expression)
- new_df = self.copy(expression=join_expression)
- new_df.pending_join_hints.extend(self.pending_join_hints)
- new_df.pending_hints.extend(other_df.pending_hints)
- new_df = new_df.select.__wrapped__(new_df, *select_column_names)
- return new_df
-
- @operation(Operation.ORDER_BY)
- def orderBy(
- self,
- *cols: t.Union[str, Column],
- ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
- ) -> DataFrame:
- """
- This implementation lets any pre-ordered columns take priority over whatever is provided in `ascending`.
- Spark has irregular behavior here and can result in runtime errors. Users shouldn't mix the two anyway,
- so this is unlikely to come up.
- """
- columns = self._ensure_and_normalize_cols(cols)
- pre_ordered_col_indexes = [
- i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
- ]
- if ascending is None:
- ascending = [True] * len(columns)
- elif not isinstance(ascending, list):
- ascending = [ascending] * len(columns)
- ascending = [bool(x) for x in ascending]
- assert len(columns) == len(
- ascending
- ), "The length of items in ascending must equal the number of columns provided"
- col_and_ascending = list(zip(columns, ascending))
- order_by_columns = [
- (
- exp.Ordered(this=col.expression, desc=not asc)
- if i not in pre_ordered_col_indexes
- else columns[i].column_expression
- )
- for i, (col, asc) in enumerate(col_and_ascending)
- ]
- return self.copy(expression=self.expression.order_by(*order_by_columns))
-
- sort = orderBy
-
- @operation(Operation.FROM)
- def union(self, other: DataFrame) -> DataFrame:
- return self._set_operation(exp.Union, other, False)
-
- unionAll = union
-
- @operation(Operation.FROM)
- def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
- l_columns = self.columns
- r_columns = other.columns
- if not allowMissingColumns:
- l_expressions = l_columns
- r_expressions = l_columns
- else:
- l_expressions = []
- r_expressions = []
- r_columns_unused = copy(r_columns)
- for l_column in l_columns:
- l_expressions.append(l_column)
- if l_column in r_columns:
- r_expressions.append(l_column)
- r_columns_unused.remove(l_column)
- else:
- r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
- for r_column in r_columns_unused:
- l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
- r_expressions.append(r_column)
- r_df = (
- other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
- )
- l_df = self.copy()
- if allowMissingColumns:
- l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
- return l_df._set_operation(exp.Union, r_df, False)
-
- @operation(Operation.FROM)
- def intersect(self, other: DataFrame) -> DataFrame:
- return self._set_operation(exp.Intersect, other, True)
-
- @operation(Operation.FROM)
- def intersectAll(self, other: DataFrame) -> DataFrame:
- return self._set_operation(exp.Intersect, other, False)
-
- @operation(Operation.FROM)
- def exceptAll(self, other: DataFrame) -> DataFrame:
- return self._set_operation(exp.Except, other, False)
-
- @operation(Operation.SELECT)
- def distinct(self) -> DataFrame:
- return self.copy(expression=self.expression.distinct())
-
- @operation(Operation.SELECT)
- def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
- if not subset:
- return self.distinct()
- column_names = ensure_list(subset)
- window = Window.partitionBy(*column_names).orderBy(*column_names)
- return (
- self.copy()
- .withColumn("row_num", F.row_number().over(window))
- .where(F.col("row_num") == F.lit(1))
- .drop("row_num")
- )
-
- @operation(Operation.FROM)
- def dropna(
- self,
- how: str = "any",
- thresh: t.Optional[int] = None,
- subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
- ) -> DataFrame:
- minimum_non_null = thresh or 0  # will be determined later if thresh is None
- new_df = self.copy()
- all_columns = self._get_outer_select_columns(new_df.expression)
- if subset:
- null_check_columns = self._ensure_and_normalize_cols(subset)
- else:
- null_check_columns = all_columns
- if thresh is None:
- minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
- else:
- minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
- if minimum_num_nulls > len(null_check_columns):
- raise RuntimeError(
- f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
- f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
- )
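- # Illustrative arithmetic: with 3 checked columns and thresh=2 (keep rows having at
- # least 2 non-null values), a row is dropped once it has 3 - 2 + 1 = 2 nulls.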
- if_null_checks = [
- F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
- ]
- nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
- num_nulls = nulls_added_together.alias("num_nulls")
- new_df = new_df.select(num_nulls, append=True)
- filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
- final_df = filtered_df.select(*all_columns)
- return final_df
-
- @operation(Operation.FROM)
- def fillna(
- self,
- value: t.Union[ColumnLiterals],
- subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
- ) -> DataFrame:
- """
- Functionality Difference: If you provide a value to replace a null and that type conflicts
- with the type of the column, then PySpark will just ignore your replacement.
- This implementation will try to cast them to be the same in some cases, so they won't always match.
- It's best not to mix types, so make sure the replacement is the same type as the column.
-
- Possibility for improvement: Use `typeof` function to get the type of the column
- and check if it matches the type of the value provided. If not then make it null.
- """
- from sqlglot.dataframe.sql.functions import lit
-
- values = None
- columns = None
- new_df = self.copy()
- all_columns = self._get_outer_select_columns(new_df.expression)
- all_column_mapping = {column.alias_or_name: column for column in all_columns}
- if isinstance(value, dict):
- values = list(value.values())
- columns = self._ensure_and_normalize_cols(list(value))
- if not columns:
- columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
- if not values:
- values = [value] * len(columns)
- value_columns = [lit(value) for value in values]
-
- null_replacement_mapping = {
- column.alias_or_name: (
- F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
- )
- for column, value in zip(columns, value_columns)
- }
- null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
- null_replacement_columns = [
- null_replacement_mapping[column.alias_or_name] for column in all_columns
- ]
- new_df = new_df.select(*null_replacement_columns)
- return new_df
-
- @operation(Operation.FROM)
- def replace(
- self,
- to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
- value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
- subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
- ) -> DataFrame:
- from sqlglot.dataframe.sql.functions import lit
-
- old_values = None
- new_df = self.copy()
- all_columns = self._get_outer_select_columns(new_df.expression)
- all_column_mapping = {column.alias_or_name: column for column in all_columns}
-
- columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
- if isinstance(to_replace, dict):
- old_values = list(to_replace)
- new_values = list(to_replace.values())
- elif isinstance(to_replace, list):
- assert isinstance(value, list), "value must be a list since the replacements are a list"
- assert len(to_replace) == len(
- value
- ), "the replacements and values must be the same length"
- old_values = to_replace
- new_values = value
- else:
- old_values = [to_replace] * len(columns)
- new_values = [value] * len(columns)
- old_values = [lit(value) for value in old_values]
- new_values = [lit(value) for value in new_values]
-
- replacement_mapping = {}
- for column in columns:
- expression = Column(None)
- for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
- if i == 0:
- expression = F.when(column == old_value, new_value)
- else:
- expression = expression.when(column == old_value, new_value) # type: ignore
- replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
- column.expression.alias_or_name
- )
-
- replacement_mapping = {**all_column_mapping, **replacement_mapping}
- replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
- new_df = new_df.select(*replacement_columns)
- return new_df
-
- @operation(Operation.SELECT)
- def withColumn(self, colName: str, col: Column) -> DataFrame:
- col = self._ensure_and_normalize_col(col)
- existing_col_names = self.expression.named_selects
- existing_col_index = (
- existing_col_names.index(colName) if colName in existing_col_names else None
- )
- if existing_col_index is not None:
- expression = self.expression.copy()
- expression.expressions[existing_col_index] = col.expression
- return self.copy(expression=expression)
- return self.copy().select(col.alias(colName), append=True)
-
- @operation(Operation.SELECT)
- def withColumnRenamed(self, existing: str, new: str):
- expression = self.expression.copy()
- existing_columns = [
- expression
- for expression in expression.expressions
- if expression.alias_or_name == existing
- ]
- if not existing_columns:
- raise ValueError("Tried to rename a column that doesn't exist")
- for existing_column in existing_columns:
- if isinstance(existing_column, exp.Column):
- existing_column.replace(exp.alias_(existing_column, new))
- else:
- existing_column.set("alias", exp.to_identifier(new))
- return self.copy(expression=expression)
-
- @operation(Operation.SELECT)
- def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
- all_columns = self._get_outer_select_columns(self.expression)
- drop_cols = self._ensure_and_normalize_cols(cols)
- new_columns = [
- col
- for col in all_columns
- if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
- ]
- return self.copy().select(*new_columns, append=False)
-
- @operation(Operation.LIMIT)
- def limit(self, num: int) -> DataFrame:
- return self.copy(expression=self.expression.limit(num))
-
- @operation(Operation.NO_OP)
- def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
- parameter_list = ensure_list(parameters)
- parameter_columns = (
- self._ensure_list_of_columns(parameter_list)
- if parameters
- else Column.ensure_cols([self.sequence_id])
- )
- return self._hint(name, parameter_columns)
-
- @operation(Operation.NO_OP)
- def repartition(
- self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
- ) -> DataFrame:
- num_partition_cols = self._ensure_list_of_columns(numPartitions)
- columns = self._ensure_and_normalize_cols(cols)
- args = num_partition_cols + columns
- return self._hint("repartition", args)
-
- @operation(Operation.NO_OP)
- def coalesce(self, numPartitions: int) -> DataFrame:
- num_partitions = Column.ensure_cols([numPartitions])
- return self._hint("coalesce", num_partitions)
-
- @operation(Operation.NO_OP)
- def cache(self) -> DataFrame:
- return self._cache(storage_level="MEMORY_AND_DISK")
-
- @operation(Operation.NO_OP)
- def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
- """
- Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
- """
- return self._cache(storageLevel)
-
-
-class DataFrameNaFunctions:
- def __init__(self, df: DataFrame):
- self.df = df
-
- def drop(
- self,
- how: str = "any",
- thresh: t.Optional[int] = None,
- subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
- ) -> DataFrame:
- return self.df.dropna(how=how, thresh=thresh, subset=subset)
-
- def fill(
- self,
- value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
- subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
- ) -> DataFrame:
- return self.df.fillna(value=value, subset=subset)
-
- def replace(
- self,
- to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
- value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
- subset: t.Optional[t.Union[str, t.List[str]]] = None,
- ) -> DataFrame:
- return self.df.replace(to_replace=to_replace, value=value, subset=subset)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
deleted file mode 100644
index 81b7d61..0000000
--- a/sqlglot/dataframe/sql/functions.py
+++ /dev/null
@@ -1,1270 +0,0 @@
-from __future__ import annotations
-
-import typing as t
-
-from sqlglot import exp as expression
-from sqlglot.dataframe.sql.column import Column
-from sqlglot.helper import ensure_list, flatten as _flatten
-
-if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql._typing import ColumnOrLiteral, ColumnOrName
- from sqlglot.dataframe.sql.dataframe import DataFrame
-
-
-def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column:
- return Column(column_name)
-
-
-def lit(value: t.Optional[t.Any] = None) -> Column:
- if isinstance(value, str):
- return Column(expression.Literal.string(str(value)))
- return Column(value)
-
-
-def greatest(*cols: ColumnOrName) -> Column:
- if len(cols) > 1:
- return Column.invoke_expression_over_column(
- cols[0], expression.Greatest, expressions=cols[1:]
- )
- return Column.invoke_expression_over_column(cols[0], expression.Greatest)
-
-
-def least(*cols: ColumnOrName) -> Column:
- if len(cols) > 1:
- return Column.invoke_expression_over_column(cols[0], expression.Least, expressions=cols[1:])
- return Column.invoke_expression_over_column(cols[0], expression.Least)
-
-
-def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
- columns = [Column.ensure_col(x) for x in [col] + list(cols)]
- return Column(
- expression.Count(this=expression.Distinct(expressions=[x.expression for x in columns]))
- )
-
-
-def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
- return count_distinct(col, *cols)
-
-
-def when(condition: Column, value: t.Any) -> Column:
- true_value = value if isinstance(value, Column) else lit(value)
- return Column(
- expression.Case(
- ifs=[expression.If(this=condition.column_expression, true=true_value.column_expression)]
- )
- )
-
-
-def asc(col: ColumnOrName) -> Column:
- return Column.ensure_col(col).asc()
-
-
-def desc(col: ColumnOrName):
- return Column.ensure_col(col).desc()
-
-
-def broadcast(df: DataFrame) -> DataFrame:
- return df.hint("broadcast")
-
-
-def sqrt(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Sqrt)
-
-
-def abs(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Abs)
-
-
-def max(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Max)
-
-
-def min(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Min)
-
-
-def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.ArgMax, expression=ord)
-
-
-def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.ArgMin, expression=ord)
-
-
-def count(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Count)
-
-
-def sum(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Sum)
-
-
-def avg(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Avg)
-
-
-def mean(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "MEAN")
-
-
-def sumDistinct(col: ColumnOrName) -> Column:
- return sum_distinct(col)
-
-
-def sum_distinct(col: ColumnOrName) -> Column:
- raise NotImplementedError("Sum distinct is not currently implemented")
-
-
-def product(col: ColumnOrName) -> Column:
- raise NotImplementedError("Product is not currently implemented")
-
-
-def acos(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ACOS")
-
-
-def acosh(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ACOSH")
-
-
-def asin(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ASIN")
-
-
-def asinh(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ASINH")
-
-
-def atan(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ATAN")
-
-
-def atan2(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
- return Column.invoke_anonymous_function(col1, "ATAN2", col2)
-
-
-def atanh(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ATANH")
-
-
-def cbrt(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Cbrt)
-
-
-def ceil(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Ceil)
-
-
-def cos(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "COS")
-
-
-def cosh(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "COSH")
-
-
-def cot(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "COT")
-
-
-def csc(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "CSC")
-
-
-def exp(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Exp)
-
-
-def expm1(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "EXPM1")
-
-
-def floor(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Floor)
-
-
-def log10(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(lit(10), expression.Log, expression=col)
-
-
-def log1p(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "LOG1P")
-
-
-def log2(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(lit(2), expression.Log, expression=col)
-
-
-def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
- if arg2 is None:
- return Column.invoke_expression_over_column(arg1, expression.Ln)
- return Column.invoke_expression_over_column(arg1, expression.Log, expression=arg2)
-
-
-def rint(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "RINT")
-
-
-def sec(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "SEC")
-
-
-def signum(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Sign)
-
-
-def sin(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "SIN")
-
-
-def sinh(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "SINH")
-
-
-def tan(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "TAN")
-
-
-def tanh(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "TANH")
-
-
-def toDegrees(col: ColumnOrName) -> Column:
- return degrees(col)
-
-
-def degrees(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "DEGREES")
-
-
-def toRadians(col: ColumnOrName) -> Column:
- return radians(col)
-
-
-def radians(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "RADIANS")
-
-
-def bitwiseNOT(col: ColumnOrName) -> Column:
- return bitwise_not(col)
-
-
-def bitwise_not(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.BitwiseNot)
-
-
-def asc_nulls_first(col: ColumnOrName) -> Column:
- return Column.ensure_col(col).asc_nulls_first()
-
-
-def asc_nulls_last(col: ColumnOrName) -> Column:
- return Column.ensure_col(col).asc_nulls_last()
-
-
-def desc_nulls_first(col: ColumnOrName) -> Column:
- return Column.ensure_col(col).desc_nulls_first()
-
-
-def desc_nulls_last(col: ColumnOrName) -> Column:
- return Column.ensure_col(col).desc_nulls_last()
-
-
-def stddev(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Stddev)
-
-
-def stddev_samp(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.StddevSamp)
-
-
-def stddev_pop(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.StddevPop)
-
-
-def variance(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Variance)
-
-
-def var_samp(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Variance)
-
-
-def var_pop(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.VariancePop)
-
-
-def skewness(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "SKEWNESS")
-
-
-def kurtosis(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "KURTOSIS")
-
-
-def collect_list(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.ArrayAgg)
-
-
-def collect_set(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.ArrayUniqueAgg)
-
-
-def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
- return Column.invoke_anonymous_function(col1, "HYPOT", col2)
-
-
-def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
- return Column.invoke_expression_over_column(col1, expression.Pow, expression=col2)
-
-
-def row_number() -> Column:
- return Column(expression.Anonymous(this="ROW_NUMBER"))
-
-
-def dense_rank() -> Column:
- return Column(expression.Anonymous(this="DENSE_RANK"))
-
-
-def rank() -> Column:
- return Column(expression.Anonymous(this="RANK"))
-
-
-def cume_dist() -> Column:
- return Column(expression.Anonymous(this="CUME_DIST"))
-
-
-def percent_rank() -> Column:
- return Column(expression.Anonymous(this="PERCENT_RANK"))
-
-
-def approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
- return approx_count_distinct(col, rsd)
-
-
-def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column:
- if rsd is None:
- return Column.invoke_expression_over_column(col, expression.ApproxDistinct)
- return Column.invoke_expression_over_column(col, expression.ApproxDistinct, accuracy=rsd)
-
-
-def coalesce(*cols: ColumnOrName) -> Column:
- if len(cols) > 1:
- return Column.invoke_expression_over_column(
- cols[0], expression.Coalesce, expressions=cols[1:]
- )
- return Column.invoke_expression_over_column(cols[0], expression.Coalesce)
-
-
-def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col1, expression.Corr, expression=col2)
-
-
-def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col1, expression.CovarPop, expression=col2)
-
-
-def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col1, expression.CovarSamp, expression=col2)
-
-
-def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
- this = Column.invoke_expression_over_column(col, expression.First)
- if ignorenulls:
- return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
- return this
-
-
-def grouping_id(*cols: ColumnOrName) -> Column:
- if not cols:
- return Column.invoke_anonymous_function(None, "GROUPING_ID")
- if len(cols) == 1:
- return Column.invoke_anonymous_function(cols[0], "GROUPING_ID")
- return Column.invoke_anonymous_function(cols[0], "GROUPING_ID", *cols[1:])
-
-
-def input_file_name() -> Column:
- return Column.invoke_anonymous_function(None, "INPUT_FILE_NAME")
-
-
-def isnan(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.IsNan)
-
-
-def isnull(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ISNULL")
-
-
-def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
- this = Column.invoke_expression_over_column(col, expression.Last)
- if ignorenulls:
- return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
- return this
-
-
-def monotonically_increasing_id() -> Column:
- return Column.invoke_anonymous_function(None, "MONOTONICALLY_INCREASING_ID")
-
-
-def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col1, "NANVL", col2)
-
-
-def percentile_approx(
- col: ColumnOrName,
- percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
- accuracy: t.Optional[t.Union[ColumnOrLiteral, int]] = None,
-) -> Column:
- if accuracy:
- return Column.invoke_expression_over_column(
- col, expression.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
- )
- return Column.invoke_expression_over_column(
- col, expression.ApproxQuantile, quantile=lit(percentage)
- )
-
-
-def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
- return Column.invoke_expression_over_column(seed, expression.Rand)
-
-
-def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
- return Column.invoke_expression_over_column(seed, expression.Randn)
-
-
-def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
- if scale is not None:
- return Column.invoke_expression_over_column(col, expression.Round, decimals=scale)
- return Column.invoke_expression_over_column(col, expression.Round)
-
-
-def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
- if scale is not None:
- return Column.invoke_anonymous_function(col, "BROUND", scale)
- return Column.invoke_anonymous_function(col, "BROUND")
-
-
-def shiftleft(col: ColumnOrName, numBits: int) -> Column:
- return Column.invoke_expression_over_column(
- col, expression.BitwiseLeftShift, expression=numBits
- )
-
-
-def shiftLeft(col: ColumnOrName, numBits: int) -> Column:
- return shiftleft(col, numBits)
-
-
-def shiftright(col: ColumnOrName, numBits: int) -> Column:
- return Column.invoke_expression_over_column(
- col, expression.BitwiseRightShift, expression=numBits
- )
-
-
-def shiftRight(col: ColumnOrName, numBits: int) -> Column:
- return shiftright(col, numBits)
-
-
-def shiftrightunsigned(col: ColumnOrName, numBits: int) -> Column:
- return Column.invoke_anonymous_function(col, "SHIFTRIGHTUNSIGNED", numBits)
-
-
-def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column:
- return shiftrightunsigned(col, numBits)
-
-
-def expr(str: str) -> Column:
- return Column(str)
-
-
-def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column:
- columns = ensure_list(col) + list(cols)
- return Column.invoke_expression_over_column(None, expression.Struct, expressions=columns)
-
-
-def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column:
- return Column.invoke_anonymous_function(col, "CONV", fromBase, toBase)
-
-
-def factorial(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "FACTORIAL")
-
-
-def lag(
- col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
-) -> Column:
- return Column.invoke_expression_over_column(
- col, expression.Lag, offset=None if offset == 1 else offset, default=default
- )
-
-
-def lead(
- col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
-) -> Column:
- return Column.invoke_expression_over_column(
- col, expression.Lead, offset=None if offset == 1 else offset, default=default
- )
-
-
-def nth_value(
- col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
-) -> Column:
- this = Column.invoke_expression_over_column(
- col, expression.NthValue, offset=None if offset == 1 else offset
- )
- if ignoreNulls:
- return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
- return this
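first, last, and nth_value optionally wrap their node in exp.IgnoreNulls; a small standalone sketch of what that wrapper renders to:

    from sqlglot import expressions as exp

    node = exp.IgnoreNulls(this=exp.Last(this=exp.column("x")))
    print(node.sql(dialect="spark"))  # LAST(x) IGNORE NULLS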
-
-
-def ntile(n: int) -> Column:
- return Column.invoke_anonymous_function(None, "NTILE", n)
-
-
-def current_date() -> Column:
- return Column.invoke_expression_over_column(None, expression.CurrentDate)
-
-
-def current_timestamp() -> Column:
- return Column.invoke_expression_over_column(None, expression.CurrentTimestamp)
-
-
-def date_format(col: ColumnOrName, format: str) -> Column:
- return Column.invoke_expression_over_column(col, expression.TimeToStr, format=lit(format))
-
-
-def year(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Year)
-
-
-def quarter(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Quarter)
-
-
-def month(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Month)
-
-
-def dayofweek(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.DayOfWeek)
-
-
-def dayofmonth(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.DayOfMonth)
-
-
-def dayofyear(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.DayOfYear)
-
-
-def hour(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "HOUR")
-
-
-def minute(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "MINUTE")
-
-
-def second(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "SECOND")
-
-
-def weekofyear(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.WeekOfYear)
-
-
-def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(year, "MAKE_DATE", month, day)
-
-
-def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(
- col, expression.DateAdd, expression=days, unit=expression.Var(this="DAY")
- )
-
-
-def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(
- col, expression.DateSub, expression=days, unit=expression.Var(this="DAY")
- )
-
-
-def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(end, expression.DateDiff, expression=start)
-
-
-def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
- return Column.invoke_expression_over_column(start, expression.AddMonths, expression=months)
-
-
-def months_between(
- date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
-) -> Column:
- if roundOff is None:
- return Column.invoke_expression_over_column(
- date1, expression.MonthsBetween, expression=date2
- )
-
- return Column.invoke_expression_over_column(
- date1, expression.MonthsBetween, expression=date2, roundoff=roundOff
- )
-
-
-def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
- if format is not None:
- return Column.invoke_expression_over_column(
- col, expression.TsOrDsToDate, format=lit(format)
- )
- return Column.invoke_expression_over_column(col, expression.TsOrDsToDate)
-
-
-def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
- if format is not None:
- return Column.invoke_expression_over_column(col, expression.StrToTime, format=lit(format))
-
- return Column.ensure_col(col).cast("timestamp")
-
-
-def trunc(col: ColumnOrName, format: str) -> Column:
- return Column.invoke_expression_over_column(col, expression.DateTrunc, unit=lit(format))
-
-
-def date_trunc(format: str, timestamp: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(
- timestamp, expression.TimestampTrunc, unit=lit(format)
- )
-
-
-def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
- return Column.invoke_anonymous_function(col, "NEXT_DAY", lit(dayOfWeek))
-
-
-def last_day(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.LastDay)
-
-
-def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
- if format is not None:
- return Column.invoke_expression_over_column(col, expression.UnixToStr, format=lit(format))
- return Column.invoke_expression_over_column(col, expression.UnixToStr)
-
-
-def unix_timestamp(
- timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
-) -> Column:
- if format is not None:
- return Column.invoke_expression_over_column(
- timestamp, expression.StrToUnix, format=lit(format)
- )
- return Column.invoke_expression_over_column(timestamp, expression.StrToUnix)
-
-
-def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
- tz_column = tz if isinstance(tz, Column) else lit(tz)
- return Column.invoke_expression_over_column(timestamp, expression.AtTimeZone, zone=tz_column)
-
-
-def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column:
- tz_column = tz if isinstance(tz, Column) else lit(tz)
- return Column.invoke_expression_over_column(timestamp, expression.FromTimeZone, zone=tz_column)
-
-
-def timestamp_seconds(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "TIMESTAMP_SECONDS")
-
-
-def window(
- timeColumn: ColumnOrName,
- windowDuration: str,
- slideDuration: t.Optional[str] = None,
- startTime: t.Optional[str] = None,
-) -> Column:
- if slideDuration is not None and startTime is not None:
- return Column.invoke_anonymous_function(
- timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime)
- )
- if slideDuration is not None:
- return Column.invoke_anonymous_function(
- timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)
- )
- if startTime is not None:
- return Column.invoke_anonymous_function(
- timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime)
- )
- return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration))
-
-
-def session_window(timeColumn: ColumnOrName, gapDuration: ColumnOrName) -> Column:
- gap_duration_column = gapDuration if isinstance(gapDuration, Column) else lit(gapDuration)
- return Column.invoke_anonymous_function(timeColumn, "SESSION_WINDOW", gap_duration_column)
-
-
-def crc32(col: ColumnOrName) -> Column:
- column = col if isinstance(col, Column) else lit(col)
- return Column.invoke_anonymous_function(column, "CRC32")
-
-
-def md5(col: ColumnOrName) -> Column:
- column = col if isinstance(col, Column) else lit(col)
- return Column.invoke_expression_over_column(column, expression.MD5)
-
-
-def sha1(col: ColumnOrName) -> Column:
- column = col if isinstance(col, Column) else lit(col)
- return Column.invoke_expression_over_column(column, expression.SHA)
-
-
-def sha2(col: ColumnOrName, numBits: int) -> Column:
- column = col if isinstance(col, Column) else lit(col)
- return Column.invoke_expression_over_column(column, expression.SHA2, length=lit(numBits))
-
-
-def hash(*cols: ColumnOrName) -> Column:
- args = cols[1:] if len(cols) > 1 else []
- return Column.invoke_anonymous_function(cols[0], "HASH", *args)
-
-
-def xxhash64(*cols: ColumnOrName) -> Column:
- args = cols[1:] if len(cols) > 1 else []
- return Column.invoke_anonymous_function(cols[0], "XXHASH64", *args)
-
-
-def assert_true(col: ColumnOrName, errorMsg: t.Optional[ColumnOrName] = None) -> Column:
- if errorMsg is not None:
- error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg)
- return Column.invoke_anonymous_function(col, "ASSERT_TRUE", error_msg_col)
- return Column.invoke_anonymous_function(col, "ASSERT_TRUE")
-
-
-def raise_error(errorMsg: ColumnOrName) -> Column:
- error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg)
- return Column.invoke_anonymous_function(error_msg_col, "RAISE_ERROR")
-
-
-def upper(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Upper)
-
-
-def lower(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Lower)
-
-
-def ascii(col: ColumnOrLiteral) -> Column:
- return Column.invoke_anonymous_function(col, "ASCII")
-
-
-def base64(col: ColumnOrLiteral) -> Column:
- return Column.invoke_expression_over_column(col, expression.ToBase64)
-
-
-def unbase64(col: ColumnOrLiteral) -> Column:
- return Column.invoke_expression_over_column(col, expression.FromBase64)
-
-
-def ltrim(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "LTRIM")
-
-
-def rtrim(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "RTRIM")
-
-
-def trim(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Trim)
-
-
-def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(
- None, expression.ConcatWs, expressions=[lit(sep)] + list(cols)
- )
-
-
-def decode(col: ColumnOrName, charset: str) -> Column:
- return Column.invoke_expression_over_column(
- col, expression.Decode, charset=expression.Literal.string(charset)
- )
-
-
-def encode(col: ColumnOrName, charset: str) -> Column:
- return Column.invoke_expression_over_column(
- col, expression.Encode, charset=expression.Literal.string(charset)
- )
-
-
-def format_number(col: ColumnOrName, d: int) -> Column:
- return Column.invoke_anonymous_function(col, "FORMAT_NUMBER", lit(d))
-
-
-def format_string(format: str, *cols: ColumnOrName) -> Column:
- format_col = lit(format)
- columns = [Column.ensure_col(x) for x in cols]
- return Column.invoke_anonymous_function(format_col, "FORMAT_STRING", *columns)
-
-
-def instr(col: ColumnOrName, substr: str) -> Column:
- return Column.invoke_anonymous_function(col, "INSTR", lit(substr))
-
-
-def overlay(
- src: ColumnOrName,
- replace: ColumnOrName,
- pos: t.Union[ColumnOrName, int],
- len: t.Optional[t.Union[ColumnOrName, int]] = None,
-) -> Column:
- if len is not None:
- return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos, len)
- return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos)
-
-
-def sentences(
- string: ColumnOrName,
- language: t.Optional[ColumnOrName] = None,
- country: t.Optional[ColumnOrName] = None,
-) -> Column:
- if language is not None and country is not None:
- return Column.invoke_anonymous_function(string, "SENTENCES", language, country)
- if language is not None:
- return Column.invoke_anonymous_function(string, "SENTENCES", language)
- if country is not None:
- return Column.invoke_anonymous_function(string, "SENTENCES", lit("en"), country)
- return Column.invoke_anonymous_function(string, "SENTENCES")
-
-
-def substring(str: ColumnOrName, pos: int, len: int) -> Column:
- return Column.ensure_col(str).substr(pos, len)
-
-
-def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
- return Column.invoke_anonymous_function(str, "SUBSTRING_INDEX", lit(delim), lit(count))
-
-
-def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(left, expression.Levenshtein, expression=right)
-
-
-def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
- substr_col = lit(substr)
- if pos is not None:
- return Column.invoke_expression_over_column(
- str, expression.StrPosition, substr=substr_col, position=pos
- )
- return Column.invoke_expression_over_column(str, expression.StrPosition, substr=substr_col)
-
-
-def lpad(col: ColumnOrName, len: int, pad: str) -> Column:
- return Column.invoke_anonymous_function(col, "LPAD", lit(len), lit(pad))
-
-
-def rpad(col: ColumnOrName, len: int, pad: str) -> Column:
- return Column.invoke_anonymous_function(col, "RPAD", lit(len), lit(pad))
-
-
-def repeat(col: ColumnOrName, n: int) -> Column:
- return Column.invoke_expression_over_column(col, expression.Repeat, times=lit(n))
-
-
-def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
- if limit is not None:
- return Column.invoke_expression_over_column(
- str, expression.RegexpSplit, expression=lit(pattern).expression, limit=limit
- )
- return Column.invoke_expression_over_column(
- str, expression.RegexpSplit, expression=lit(pattern)
- )
-
-
-def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column:
- return Column.invoke_expression_over_column(
- str,
- expression.RegexpExtract,
- expression=lit(pattern),
- group=idx,
- )
-
-
-def regexp_replace(
- str: ColumnOrName, pattern: str, replacement: str, position: t.Optional[int] = None
-) -> Column:
- return Column.invoke_expression_over_column(
- str,
- expression.RegexpReplace,
- expression=lit(pattern),
- replacement=lit(replacement),
- position=position,
- )
-
-
-def initcap(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Initcap)
-
-
-def soundex(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "SOUNDEX")
-
-
-def bin(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "BIN")
-
-
-def hex(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Hex)
-
-
-def unhex(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Unhex)
-
-
-def length(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Length)
-
-
-def octet_length(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "OCTET_LENGTH")
-
-
-def bit_length(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "BIT_LENGTH")
-
-
-def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column:
- return Column.invoke_anonymous_function(srcCol, "TRANSLATE", lit(matching), lit(replace))
-
-
-def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
- columns = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols
- return Column.invoke_expression_over_column(None, expression.Array, expressions=columns)
-
-
-def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
- cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
- return Column.invoke_expression_over_column(
- None,
- expression.VarMap,
- keys=array(*cols[::2]).expression,
- values=array(*cols[1::2]).expression,
- )
-
-
-def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(None, expression.Map, keys=col1, values=col2)
-
-
-def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
- value_col = value if isinstance(value, Column) else lit(value)
- return Column.invoke_expression_over_column(
- col, expression.ArrayContains, expression=value_col.expression
- )
-
-
-def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2))
-
-
-def slice(
- x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
-) -> Column:
- start_col = start if isinstance(start, Column) else lit(start)
- length_col = length if isinstance(length, Column) else lit(length)
- return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)
-
-
-def array_join(
- col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
-) -> Column:
- if null_replacement is not None:
- return Column.invoke_expression_over_column(
- col, expression.ArrayToString, expression=lit(delimiter), null=lit(null_replacement)
- )
- return Column.invoke_expression_over_column(
- col, expression.ArrayToString, expression=lit(delimiter)
- )
-
-
-def concat(*cols: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(None, expression.Concat, expressions=cols)
-
-
-def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
- value_col = value if isinstance(value, Column) else lit(value)
- return Column.invoke_anonymous_function(col, "ARRAY_POSITION", value_col)
-
-
-def element_at(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
- value_col = value if isinstance(value, Column) else lit(value)
- return Column.invoke_anonymous_function(col, "ELEMENT_AT", value_col)
-
-
-def array_remove(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
- value_col = value if isinstance(value, Column) else lit(value)
- return Column.invoke_anonymous_function(col, "ARRAY_REMOVE", value_col)
-
-
-def array_distinct(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ARRAY_DISTINCT")
-
-
-def array_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col1, "ARRAY_INTERSECT", Column.ensure_col(col2))
-
-
-def array_union(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col1, "ARRAY_UNION", Column.ensure_col(col2))
-
-
-def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col1, "ARRAY_EXCEPT", Column.ensure_col(col2))
-
-
-def explode(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Explode)
-
-
-def posexplode(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Posexplode)
-
-
-def explode_outer(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.ExplodeOuter)
-
-
-def posexplode_outer(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.PosexplodeOuter)
-
-
-def get_json_object(col: ColumnOrName, path: str) -> Column:
- return Column.invoke_expression_over_column(col, expression.JSONExtract, expression=lit(path))
-
-
-def json_tuple(col: ColumnOrName, *fields: str) -> Column:
- return Column.invoke_anonymous_function(col, "JSON_TUPLE", *[lit(field) for field in fields])
-
-
-def from_json(
- col: ColumnOrName,
- schema: t.Union[Column, str],
- options: t.Optional[t.Dict[str, str]] = None,
-) -> Column:
- schema = schema if isinstance(schema, Column) else lit(schema)
- if options is not None:
- options_col = create_map([lit(x) for x in _flatten(options.items())])
- return Column.invoke_anonymous_function(col, "FROM_JSON", schema, options_col)
- return Column.invoke_anonymous_function(col, "FROM_JSON", schema)
-
-
-def to_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
- if options is not None:
- options_col = create_map([lit(x) for x in _flatten(options.items())])
- return Column.invoke_expression_over_column(col, expression.JSONFormat, options=options_col)
- return Column.invoke_expression_over_column(col, expression.JSONFormat)
-
-
-def schema_of_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
- if options is not None:
- options_col = create_map([lit(x) for x in _flatten(options.items())])
- return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON", options_col)
- return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON")
-
-
-def schema_of_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
- if options is not None:
- options_col = create_map([lit(x) for x in _flatten(options.items())])
- return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV", options_col)
- return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV")
-
-
-def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
- if options is not None:
- options_col = create_map([lit(x) for x in _flatten(options.items())])
- return Column.invoke_anonymous_function(col, "TO_CSV", options_col)
- return Column.invoke_anonymous_function(col, "TO_CSV")
-
-
-def size(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.ArraySize)
-
-
-def array_min(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ARRAY_MIN")
-
-
-def array_max(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ARRAY_MAX")
-
-
-def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
- if asc is not None:
- return Column.invoke_expression_over_column(col, expression.SortArray, asc=asc)
- return Column.invoke_expression_over_column(col, expression.SortArray)
-
-
-def array_sort(
- col: ColumnOrName,
- comparator: t.Optional[t.Callable[[Column, Column], Column]] = None,
-) -> Column:
- if comparator is not None:
- f_expression = _get_lambda_from_func(comparator)
- return Column.invoke_expression_over_column(
- col, expression.ArraySort, expression=f_expression
- )
- return Column.invoke_expression_over_column(col, expression.ArraySort)
-
-
-def shuffle(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "SHUFFLE")
-
-
-def reverse(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "REVERSE")
-
-
-def flatten(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Flatten)
-
-
-def map_keys(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "MAP_KEYS")
-
-
-def map_values(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "MAP_VALUES")
-
-
-def map_entries(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "MAP_ENTRIES")
-
-
-def map_from_entries(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.MapFromEntries)
-
-
-def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column:
- count_col = count if isinstance(count, Column) else lit(count)
- return Column.invoke_anonymous_function(col, "ARRAY_REPEAT", count_col)
-
-
-def arrays_zip(*cols: ColumnOrName) -> Column:
- if len(cols) == 1:
- return Column.invoke_anonymous_function(cols[0], "ARRAYS_ZIP")
- return Column.invoke_anonymous_function(cols[0], "ARRAYS_ZIP", *cols[1:])
-
-
-def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
- columns = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
- if len(columns) == 1:
- return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT")
- return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])
-
-
-def sequence(
- start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
-) -> Column:
- if step is not None:
- return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step)
- return Column.invoke_anonymous_function(start, "SEQUENCE", stop)
-
-
-def from_csv(
- col: ColumnOrName,
- schema: t.Union[Column, str],
- options: t.Optional[t.Dict[str, str]] = None,
-) -> Column:
- schema = schema if isinstance(schema, Column) else lit(schema)
- if options is not None:
- option_cols = create_map([lit(x) for x in _flatten(options.items())])
- return Column.invoke_anonymous_function(col, "FROM_CSV", schema, option_cols)
- return Column.invoke_anonymous_function(col, "FROM_CSV", schema)
-
-
-def aggregate(
- col: ColumnOrName,
- initialValue: ColumnOrName,
- merge: t.Callable[[Column, Column], Column],
- finish: t.Optional[t.Callable[[Column], Column]] = None,
-) -> Column:
- merge_exp = _get_lambda_from_func(merge)
- if finish is not None:
- finish_exp = _get_lambda_from_func(finish)
- return Column.invoke_expression_over_column(
- col,
- expression.Reduce,
- initial=initialValue,
- merge=Column(merge_exp),
- finish=Column(finish_exp),
- )
- return Column.invoke_expression_over_column(
- col, expression.Reduce, initial=initialValue, merge=Column(merge_exp)
- )
-
-
-def transform(
- col: ColumnOrName,
- f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
-) -> Column:
- f_expression = _get_lambda_from_func(f)
- return Column.invoke_expression_over_column(
- col, expression.Transform, expression=Column(f_expression)
- )
-
-
-def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
- f_expression = _get_lambda_from_func(f)
- return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression))
-
-
-def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
- f_expression = _get_lambda_from_func(f)
- return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
-
-
-def filter(
- col: ColumnOrName,
- f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
-) -> Column:
- f_expression = _get_lambda_from_func(f)
- return Column.invoke_expression_over_column(
- col, expression.ArrayFilter, expression=f_expression
- )
-
-
-def zip_with(
- left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]
-) -> Column:
- f_expression = _get_lambda_from_func(f)
- return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
-
-
-def transform_keys(col: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
- f_expression = _get_lambda_from_func(f)
- return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression))
-
-
-def transform_values(col: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
- f_expression = _get_lambda_from_func(f)
- return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression))
-
-
-def map_filter(col: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
- f_expression = _get_lambda_from_func(f)
- return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression))
-
-
-def map_zip_with(
- col1: ColumnOrName,
- col2: ColumnOrName,
- f: t.Callable[[Column, Column, Column], Column],
-) -> Column:
- f_expression = _get_lambda_from_func(f)
- return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression))
-
-
-def _lambda_quoted(value: str) -> t.Optional[bool]:
- return False if value == "_" else None
-
-
-def _get_lambda_from_func(lambda_expression: t.Callable):
- variables = [
- expression.to_identifier(x, quoted=_lambda_quoted(x))
- for x in lambda_expression.__code__.co_varnames
- ]
- return expression.Lambda(
- this=lambda_expression(*[Column(x) for x in variables]).expression,
- expressions=variables,
- )
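_get_lambda_from_func reads the callable's parameter names and invokes it with column placeholders to obtain the SQL lambda body. A standalone sketch of the same idea against sqlglot's public API (the helper name here is illustrative):

    from sqlglot import expressions as exp

    def lambda_from_func(fn):
        # The callable's parameter names become the SQL lambda's variables.
        names = fn.__code__.co_varnames
        params = [exp.to_identifier(name) for name in names]
        # Calling fn with column nodes named after the parameters yields the body.
        body = fn(*[exp.column(name) for name in names])
        return exp.Lambda(this=body, expressions=params)

    tree = lambda_from_func(lambda x, y: exp.Add(this=x, expression=y))
    print(tree.sql(dialect="spark"))  # (x, y) -> x + y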
diff --git a/sqlglot/dataframe/sql/group.py b/sqlglot/dataframe/sql/group.py
deleted file mode 100644
index ba27c17..0000000
--- a/sqlglot/dataframe/sql/group.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from __future__ import annotations
-
-import typing as t
-
-from sqlglot.dataframe.sql import functions as F
-from sqlglot.dataframe.sql.column import Column
-from sqlglot.dataframe.sql.operations import Operation, operation
-
-if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql.dataframe import DataFrame
-
-
-class GroupedData:
- def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
- self._df = df.copy()
- self.spark = df.spark
- self.last_op = last_op
- self.group_by_cols = group_by_cols
-
- def _get_function_applied_columns(
- self, func_name: str, cols: t.Tuple[str, ...]
- ) -> t.List[Column]:
- func_name = func_name.lower()
- return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
-
- @operation(Operation.SELECT)
- def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
- columns = (
- [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
- if isinstance(exprs[0], dict)
- else exprs
- )
- cols = self._df._ensure_and_normalize_cols(columns)
-
- expression = self._df.expression.group_by(
- *[x.expression for x in self.group_by_cols]
- ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
- return self._df.copy(expression=expression)
-
- def count(self) -> DataFrame:
- return self.agg(F.count("*").alias("count"))
-
- def mean(self, *cols: str) -> DataFrame:
- return self.avg(*cols)
-
- def avg(self, *cols: str) -> DataFrame:
- return self.agg(*self._get_function_applied_columns("avg", cols))
-
- def max(self, *cols: str) -> DataFrame:
- return self.agg(*self._get_function_applied_columns("max", cols))
-
- def min(self, *cols: str) -> DataFrame:
- return self.agg(*self._get_function_applied_columns("min", cols))
-
- def sum(self, *cols: str) -> DataFrame:
- return self.agg(*self._get_function_applied_columns("sum", cols))
-
- def pivot(self, *cols: str) -> DataFrame:
- raise NotImplementedError("Sum distinct is not currently implemented")
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
deleted file mode 100644
index b246641..0000000
--- a/sqlglot/dataframe/sql/normalize.py
+++ /dev/null
@@ -1,78 +0,0 @@
-from __future__ import annotations
-
-import typing as t
-
-from sqlglot import expressions as exp
-from sqlglot.dataframe.sql.column import Column
-from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
-from sqlglot.helper import ensure_list
-
-if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql.session import SparkSession
-
- NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
-
-
-def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
- expr = ensure_list(expr)
- expressions = _ensure_expressions(expr)
- for expression in expressions:
- identifiers = expression.find_all(exp.Identifier)
- for identifier in identifiers:
- identifier.transform(spark.dialect.normalize_identifier)
- replace_alias_name_with_cte_name(spark, expression_context, identifier)
- replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
-
-
-def replace_alias_name_with_cte_name(
- spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
-):
- if id.alias_or_name in spark.name_to_sequence_id_mapping:
- for cte in reversed(expression_context.ctes):
- if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
- _set_alias_name(id, cte.alias_or_name)
- break
-
-
-def replace_branch_and_sequence_ids_with_cte_name(
- spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
-):
- if id.alias_or_name in spark.known_ids:
- # Check if we have a join and if both tables in that join share a common branch id.
- # If so, this should reference the left table by default, unless the id is a sequence
- # id, in which case it keeps that reference. This handles a weird edge case in Spark
- # that shouldn't be common in practice.
- if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
- join_table_aliases = [
- x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
- ]
- ctes_in_join = [
- cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
- ]
- assert len(ctes_in_join) == 2
- if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
- _set_alias_name(id, ctes_in_join[0].alias_or_name)
- return
-
- for cte in reversed(expression_context.ctes):
- if id.alias_or_name in (cte.args["branch_id"], cte.args["sequence_id"]):
- _set_alias_name(id, cte.alias_or_name)
- return
-
-
-def _set_alias_name(id: exp.Identifier, name: str):
- id.set("this", name)
-
-
-def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
- results = []
- for value in values:
- if isinstance(value, str):
- results.append(Column.ensure_col(value).expression)
- elif isinstance(value, Column):
- results.append(value.expression)
- elif isinstance(value, exp.Expression):
- results.append(value)
- else:
- raise ValueError(f"Got an invalid type to normalize: {type(value)}")
- return results
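The identifier normalization itself is delegated to the dialect. A standalone sketch of what spark.dialect.normalize_identifier does:

    from sqlglot import Dialect, expressions as exp

    spark = Dialect.get_or_raise("spark")
    # Spark identifiers are case-insensitive, so normalization lowercases them.
    print(spark.normalize_identifier(exp.to_identifier("MyCol")).sql())  # mycol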
diff --git a/sqlglot/dataframe/sql/operations.py b/sqlglot/dataframe/sql/operations.py
deleted file mode 100644
index e4c106b..0000000
--- a/sqlglot/dataframe/sql/operations.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from __future__ import annotations
-
-import functools
-import typing as t
-from enum import IntEnum
-
-if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql.dataframe import DataFrame
- from sqlglot.dataframe.sql.group import GroupedData
-
-
-class Operation(IntEnum):
- INIT = -1
- NO_OP = 0
- FROM = 1
- WHERE = 2
- GROUP_BY = 3
- HAVING = 4
- SELECT = 5
- ORDER_BY = 6
- LIMIT = 7
-
-
-def operation(op: Operation):
- """
- Decorator used around DataFrame methods to indicate, via the ordered Operation enum, what type
- of operation is being performed. This determines which operations should be pushed into a CTE
- vs. included with the previous operation.
-
- Ex: After a user does a join, we want to allow them to select which columns from the different
- tables they want to carry through to the following operation. If we put that join in a CTE
- preemptively, the user would not have a chance to select the columns they want in cases where
- names overlap.
- """
-
- def decorator(func: t.Callable):
- @functools.wraps(func)
- def wrapper(self: DataFrame, *args, **kwargs):
- if self.last_op == Operation.INIT:
- self = self._convert_leaf_to_cte()
- self.last_op = Operation.NO_OP
- last_op = self.last_op
- new_op = op if op != Operation.NO_OP else last_op
- if new_op < last_op or (last_op == new_op == Operation.SELECT):
- self = self._convert_leaf_to_cte()
- df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs)
- df.last_op = new_op # type: ignore
- return df
-
- wrapper.__wrapped__ = func # type: ignore
- return wrapper
-
- return decorator
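The CTE-promotion rule reduces to an integer comparison on the enum. A standalone sketch (the helper name is illustrative):

    from enum import IntEnum

    class Op(IntEnum):
        FROM = 1
        WHERE = 2
        GROUP_BY = 3
        HAVING = 4
        SELECT = 5
        ORDER_BY = 6
        LIMIT = 7

    def needs_new_cte(last_op: Op, new_op: Op) -> bool:
        # Mirrors the decorator: going "backwards" in clause order, or stacking
        # SELECT on SELECT, forces the current query into a CTE first.
        return new_op < last_op or (last_op == new_op == Op.SELECT)

    print(needs_new_cte(Op.SELECT, Op.WHERE))  # True: WHERE after SELECT wraps in a CTE
    print(needs_new_cte(Op.WHERE, Op.SELECT))  # False: SELECT after WHERE extends the query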
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
deleted file mode 100644
index 0804486..0000000
--- a/sqlglot/dataframe/sql/readwriter.py
+++ /dev/null
@@ -1,108 +0,0 @@
-from __future__ import annotations
-
-import typing as t
-
-import sqlglot
-from sqlglot import expressions as exp
-from sqlglot.helper import object_to_dict
-
-if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql.dataframe import DataFrame
- from sqlglot.dataframe.sql.session import SparkSession
-
-
-class DataFrameReader:
- def __init__(self, spark: SparkSession):
- self.spark = spark
-
- def table(self, tableName: str) -> DataFrame:
- from sqlglot.dataframe.sql.dataframe import DataFrame
- from sqlglot.dataframe.sql.session import SparkSession
-
- sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)
-
- return DataFrame(
- self.spark,
- exp.Select()
- .from_(
- exp.to_table(tableName, dialect=SparkSession().dialect).transform(
- SparkSession().dialect.normalize_identifier
- )
- )
- .select(
- *(
- column
- for column in sqlglot.schema.column_names(
- tableName, dialect=SparkSession().dialect
- )
- )
- ),
- )
-
-
-class DataFrameWriter:
- def __init__(
- self,
- df: DataFrame,
- spark: t.Optional[SparkSession] = None,
- mode: t.Optional[str] = None,
- by_name: bool = False,
- ):
- self._df = df
- self._spark = spark or df.spark
- self._mode = mode
- self._by_name = by_name
-
- def copy(self, **kwargs) -> DataFrameWriter:
- return DataFrameWriter(
- **{
- k[1:] if k.startswith("_") else k: v
- for k, v in object_to_dict(self, **kwargs).items()
- }
- )
-
- def sql(self, **kwargs) -> t.List[str]:
- return self._df.sql(**kwargs)
-
- def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
- return self.copy(_mode=saveMode)
-
- @property
- def byName(self):
- return self.copy(by_name=True)
-
- def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
- from sqlglot.dataframe.sql.session import SparkSession
-
- output_expression_container = exp.Insert(
- **{
- "this": exp.to_table(tableName),
- "overwrite": overwrite,
- }
- )
- df = self._df.copy(output_expression_container=output_expression_container)
- if self._by_name:
- columns = sqlglot.schema.column_names(
- tableName, only_visible=True, dialect=SparkSession().dialect
- )
- df = df._convert_leaf_to_cte().select(*columns)
-
- return self.copy(_df=df)
-
- def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
- if format is not None:
- raise NotImplementedError("Providing Format in the save as table is not supported")
- exists, replace, mode = None, None, mode or str(self._mode)
- if mode == "append":
- return self.insertInto(name)
- if mode == "ignore":
- exists = True
- if mode == "overwrite":
- replace = True
- output_expression_container = exp.Create(
- this=exp.to_table(name),
- kind="TABLE",
- exists=exists,
- replace=replace,
- )
- return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
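How the writer's mode surfaces in the generated DDL, sketched against the pre-24.0 API (requires sqlglot < 24; names are illustrative):

    import sqlglot
    from sqlglot.dataframe.sql import SparkSession

    sqlglot.schema.add_table("src", {"id": "int"}, dialect="spark")
    writer = SparkSession().table("src").write.mode("overwrite").saveAsTable("dst")
    print(writer.sql(dialect="spark")[0])
    # roughly: CREATE OR REPLACE TABLE dst AS SELECT id FROM src
    # mode("ignore") would instead emit CREATE TABLE IF NOT EXISTS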
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
deleted file mode 100644
index 4e47aaa..0000000
--- a/sqlglot/dataframe/sql/session.py
+++ /dev/null
@@ -1,199 +0,0 @@
-from __future__ import annotations
-
-import typing as t
-import uuid
-from collections import defaultdict
-
-import sqlglot
-from sqlglot import Dialect, expressions as exp
-from sqlglot.dataframe.sql import functions as F
-from sqlglot.dataframe.sql.dataframe import DataFrame
-from sqlglot.dataframe.sql.readwriter import DataFrameReader
-from sqlglot.dataframe.sql.types import StructType
-from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
-from sqlglot.helper import classproperty
-from sqlglot.optimizer import optimize
-from sqlglot.optimizer.qualify_columns import quote_identifiers
-
-if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
-
-
-class SparkSession:
- DEFAULT_DIALECT = "spark"
- _instance = None
-
- def __init__(self):
- if not hasattr(self, "known_ids"):
- self.known_ids = set()
- self.known_branch_ids = set()
- self.known_sequence_ids = set()
- self.name_to_sequence_id_mapping = defaultdict(list)
- self.incrementing_id = 1
- self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)
-
- def __new__(cls, *args, **kwargs) -> SparkSession:
- if cls._instance is None:
- cls._instance = super().__new__(cls)
- return cls._instance
-
- @property
- def read(self) -> DataFrameReader:
- return DataFrameReader(self)
-
- def table(self, tableName: str) -> DataFrame:
- return self.read.table(tableName)
-
- def createDataFrame(
- self,
- data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
- schema: t.Optional[SchemaInput] = None,
- samplingRatio: t.Optional[float] = None,
- verifySchema: bool = False,
- ) -> DataFrame:
- from sqlglot.dataframe.sql.dataframe import DataFrame
-
- if samplingRatio is not None or verifySchema:
- raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
- if schema is not None and (
- not isinstance(schema, (StructType, str, list))
- or (isinstance(schema, list) and not isinstance(schema[0], str))
- ):
- raise NotImplementedError("Only schema of either list or string of list supported")
- if not data:
- raise ValueError("Must provide data to create into a DataFrame")
-
- column_mapping: t.Dict[str, t.Optional[str]]
- if schema is not None:
- column_mapping = get_column_mapping_from_schema_input(schema)
- elif isinstance(data[0], dict):
- column_mapping = {col_name.strip(): None for col_name in data[0]}
- else:
- column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}
-
- data_expressions = [
- exp.tuple_(
- *map(
- lambda x: F.lit(x).expression,
- row if not isinstance(row, dict) else row.values(),
- )
- )
- for row in data
- ]
-
- sel_columns = [
- (
- F.col(name).cast(data_type).alias(name).expression
- if data_type is not None
- else F.col(name).expression
- )
- for name, data_type in column_mapping.items()
- ]
-
- select_kwargs = {
- "expressions": sel_columns,
- "from": exp.From(
- this=exp.Values(
- expressions=data_expressions,
- alias=exp.TableAlias(
- this=exp.to_identifier(self._auto_incrementing_name),
- columns=[exp.to_identifier(col_name) for col_name in column_mapping],
- ),
- ),
- ),
- }
-
- sel_expression = exp.Select(**select_kwargs)
- return DataFrame(self, sel_expression)
-
- def _optimize(
- self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
- ) -> exp.Expression:
- dialect = dialect or self.dialect
- quote_identifiers(expression, dialect=dialect)
- return optimize(expression, dialect=dialect)
-
- def sql(self, sqlQuery: str) -> DataFrame:
- expression = self._optimize(sqlglot.parse_one(sqlQuery, read=self.dialect))
- if isinstance(expression, exp.Select):
- df = DataFrame(self, expression)
- df = df._convert_leaf_to_cte()
- elif isinstance(expression, (exp.Create, exp.Insert)):
- select_expression = expression.expression.copy()
- if isinstance(expression, exp.Insert):
- select_expression.set("with", expression.args.get("with"))
- expression.set("with", None)
- del expression.args["expression"]
- df = DataFrame(self, select_expression, output_expression_container=expression) # type: ignore
- df = df._convert_leaf_to_cte()
- else:
- raise ValueError(
- "Unknown expression type provided in the SQL. Please create an issue with the SQL."
- )
- return df
-
- @property
- def _auto_incrementing_name(self) -> str:
- name = f"a{self.incrementing_id}"
- self.incrementing_id += 1
- return name
-
- @property
- def _random_branch_id(self) -> str:
- id = self._random_id
- self.known_branch_ids.add(id)
- return id
-
- @property
- def _random_sequence_id(self):
- id = self._random_id
- self.known_sequence_ids.add(id)
- return id
-
- @property
- def _random_id(self) -> str:
- id = "r" + uuid.uuid4().hex
- self.known_ids.add(id)
- return id
-
- @property
- def _join_hint_names(self) -> t.Set[str]:
- return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}
-
- def _add_alias_to_mapping(self, name: str, sequence_id: str):
- self.name_to_sequence_id_mapping[name].append(sequence_id)
-
- class Builder:
- SQLFRAME_DIALECT_KEY = "sqlframe.dialect"
-
- def __init__(self):
- self.dialect = "spark"
-
- def __getattr__(self, item) -> SparkSession.Builder:
- return self
-
- def __call__(self, *args, **kwargs):
- return self
-
- def config(
- self,
- key: t.Optional[str] = None,
- value: t.Optional[t.Any] = None,
- *,
- map: t.Optional[t.Dict[str, t.Any]] = None,
- **kwargs: t.Any,
- ) -> SparkSession.Builder:
- if key == self.SQLFRAME_DIALECT_KEY:
- self.dialect = value
- elif map and self.SQLFRAME_DIALECT_KEY in map:
- self.dialect = map[self.SQLFRAME_DIALECT_KEY]
- return self
-
- def getOrCreate(self) -> SparkSession:
- spark = SparkSession()
- spark.dialect = Dialect.get_or_raise(self.dialect)
- return spark
-
- @classproperty
- def builder(cls) -> Builder:
- return cls.Builder()
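The Builder shim swallows everything except the dialect key; a sketch (pre-24.0 API):

    from sqlglot.dataframe.sql import SparkSession

    spark = (
        SparkSession.builder
        .appName("ignored")                    # no-op: __getattr__/__call__ return the builder
        .config("sqlframe.dialect", "duckdb")  # the one key that is honored
        .getOrCreate()
    )
    print(spark.dialect)  # the duckdb Dialect instance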
diff --git a/sqlglot/dataframe/sql/transforms.py b/sqlglot/dataframe/sql/transforms.py
deleted file mode 100644
index b3dcc12..0000000
--- a/sqlglot/dataframe/sql/transforms.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import typing as t
-
-from sqlglot import expressions as exp
-
-
-def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]):
- if isinstance(node, exp.Identifier) and node in replacement_mapping:
- node = node.replace(replacement_mapping[node].copy())
- return node
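replace_id_value is meant to be passed to Expression.transform. A self-contained sketch (the helper is copied verbatim from above):

    import typing as t

    from sqlglot import expressions as exp, parse_one

    def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]):
        if isinstance(node, exp.Identifier) and node in replacement_mapping:
            node = node.replace(replacement_mapping[node].copy())
        return node

    mapping = {exp.to_identifier("a"): exp.to_identifier("b")}
    renamed = parse_one("SELECT a FROM a").transform(replace_id_value, mapping)
    print(renamed.sql())  # SELECT b FROM b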
diff --git a/sqlglot/dataframe/sql/types.py b/sqlglot/dataframe/sql/types.py
deleted file mode 100644
index a63e505..0000000
--- a/sqlglot/dataframe/sql/types.py
+++ /dev/null
@@ -1,212 +0,0 @@
-import typing as t
-
-
-class DataType:
- def __repr__(self) -> str:
- return self.__class__.__name__ + "()"
-
- def __hash__(self) -> int:
- return hash(str(self))
-
- def __eq__(self, other: t.Any) -> bool:
- return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
-
- def __ne__(self, other: t.Any) -> bool:
- return not self.__eq__(other)
-
- def __str__(self) -> str:
- return self.typeName()
-
- @classmethod
- def typeName(cls) -> str:
- return cls.__name__[:-4].lower()
-
- def simpleString(self) -> str:
- return str(self)
-
- def jsonValue(self) -> t.Union[str, t.Dict[str, t.Any]]:
- return str(self)
-
-
-class DataTypeWithLength(DataType):
- def __init__(self, length: int):
- self.length = length
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}({self.length})"
-
- def __str__(self) -> str:
- return f"{self.typeName()}({self.length})"
-
-
-class StringType(DataType):
- pass
-
-
-class CharType(DataTypeWithLength):
- pass
-
-
-class VarcharType(DataTypeWithLength):
- pass
-
-
-class BinaryType(DataType):
- pass
-
-
-class BooleanType(DataType):
- pass
-
-
-class DateType(DataType):
- pass
-
-
-class TimestampType(DataType):
- pass
-
-
-class TimestampNTZType(DataType):
- @classmethod
- def typeName(cls) -> str:
- return "timestamp_ntz"
-
-
-class DecimalType(DataType):
- def __init__(self, precision: int = 10, scale: int = 0):
- self.precision = precision
- self.scale = scale
-
- def simpleString(self) -> str:
- return f"decimal({self.precision}, {self.scale})"
-
- def jsonValue(self) -> str:
- return f"decimal({self.precision}, {self.scale})"
-
- def __repr__(self) -> str:
- return f"DecimalType({self.precision}, {self.scale})"
-
-
-class DoubleType(DataType):
- pass
-
-
-class FloatType(DataType):
- pass
-
-
-class ByteType(DataType):
- def __str__(self) -> str:
- return "tinyint"
-
-
-class IntegerType(DataType):
- def __str__(self) -> str:
- return "int"
-
-
-class LongType(DataType):
- def __str__(self) -> str:
- return "bigint"
-
-
-class ShortType(DataType):
- def __str__(self) -> str:
- return "smallint"
-
-
-class ArrayType(DataType):
- def __init__(self, elementType: DataType, containsNull: bool = True):
- self.elementType = elementType
- self.containsNull = containsNull
-
- def __repr__(self) -> str:
- return f"ArrayType({self.elementType, str(self.containsNull)}"
-
- def simpleString(self) -> str:
- return f"array<{self.elementType.simpleString()}>"
-
- def jsonValue(self) -> t.Dict[str, t.Any]:
- return {
- "type": self.typeName(),
- "elementType": self.elementType.jsonValue(),
- "containsNull": self.containsNull,
- }
-
-
-class MapType(DataType):
- def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True):
- self.keyType = keyType
- self.valueType = valueType
- self.valueContainsNull = valueContainsNull
-
- def __repr__(self) -> str:
- return f"MapType({self.keyType}, {self.valueType}, {str(self.valueContainsNull)})"
-
- def simpleString(self) -> str:
- return f"map<{self.keyType.simpleString()}, {self.valueType.simpleString()}>"
-
- def jsonValue(self) -> t.Dict[str, t.Any]:
- return {
- "type": self.typeName(),
- "keyType": self.keyType.jsonValue(),
- "valueType": self.valueType.jsonValue(),
- "valueContainsNull": self.valueContainsNull,
- }
-
-
-class StructField(DataType):
- def __init__(
- self,
- name: str,
- dataType: DataType,
- nullable: bool = True,
- metadata: t.Optional[t.Dict[str, t.Any]] = None,
- ):
- self.name = name
- self.dataType = dataType
- self.nullable = nullable
- self.metadata = metadata or {}
-
- def __repr__(self) -> str:
- return f"StructField('{self.name}', {self.dataType}, {str(self.nullable)})"
-
- def simpleString(self) -> str:
- return f"{self.name}:{self.dataType.simpleString()}"
-
- def jsonValue(self) -> t.Dict[str, t.Any]:
- return {
- "name": self.name,
- "type": self.dataType.jsonValue(),
- "nullable": self.nullable,
- "metadata": self.metadata,
- }
-
-
-class StructType(DataType):
- def __init__(self, fields: t.Optional[t.List[StructField]] = None):
- if not fields:
- self.fields = []
- self.names = []
- else:
- self.fields = fields
- self.names = [f.name for f in fields]
-
- def __iter__(self) -> t.Iterator[StructField]:
- return iter(self.fields)
-
- def __len__(self) -> int:
- return len(self.fields)
-
- def __repr__(self) -> str:
- return f"StructType({', '.join(str(field) for field in self)})"
-
- def simpleString(self) -> str:
- return f"struct<{', '.join(x.simpleString() for x in self)}>"
-
- def jsonValue(self) -> t.Dict[str, t.Any]:
- return {"type": self.typeName(), "fields": [x.jsonValue() for x in self]}
-
- def fieldNames(self) -> t.List[str]:
- return list(self.names)
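These classes mirror pyspark's simpleString/jsonValue surface; a short sketch (pre-24.0 API this diff removes):

    from sqlglot.dataframe.sql.types import ArrayType, LongType, StringType, StructField, StructType

    schema = StructType([
        StructField("id", LongType()),
        StructField("tags", ArrayType(StringType())),
    ])
    print(schema.simpleString())  # struct<id:bigint, tags:array<string>>
    print(schema.fieldNames())    # ['id', 'tags']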
diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py
deleted file mode 100644
index 4b9fbb1..0000000
--- a/sqlglot/dataframe/sql/util.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from __future__ import annotations
-
-import typing as t
-
-from sqlglot import expressions as exp
-from sqlglot.dataframe.sql import types
-
-if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql._typing import SchemaInput
-
-
-def get_column_mapping_from_schema_input(schema: SchemaInput) -> t.Dict[str, t.Optional[str]]:
- if isinstance(schema, dict):
- return schema
- elif isinstance(schema, str):
- col_name_type_strs = [x.strip() for x in schema.split(",")]
- return {
- name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
- for name_type_str in col_name_type_strs
- }
- elif isinstance(schema, types.StructType):
- return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema}
- return {x.strip(): None for x in schema} # type: ignore
-
-
-def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]:
- if not expression.args.get("joins"):
- return []
-
- left_table = expression.args["from"].this
- other_tables = [join.this for join in expression.args["joins"]]
- return [left_table] + other_tables
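All three schema spellings accepted by createDataFrame collapse to the same column mapping; a sketch (pre-24.0 API):

    from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input

    print(get_column_mapping_from_schema_input("id: int, name: string"))
    # {'id': 'int', 'name': 'string'}
    print(get_column_mapping_from_schema_input(["id", "name"]))
    # {'id': None, 'name': None}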
diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py
deleted file mode 100644
index 9e2fabd..0000000
--- a/sqlglot/dataframe/sql/window.py
+++ /dev/null
@@ -1,136 +0,0 @@
-from __future__ import annotations
-
-import sys
-import typing as t
-
-from sqlglot import expressions as exp
-from sqlglot.dataframe.sql import functions as F
-from sqlglot.helper import flatten
-
-if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql._typing import ColumnOrName
-
-
-class Window:
- _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808
- _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807
- _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
- _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)
-
- unboundedPreceding: int = _JAVA_MIN_LONG
-
- unboundedFollowing: int = _JAVA_MAX_LONG
-
- currentRow: int = 0
-
- @classmethod
- def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
- return WindowSpec().partitionBy(*cols)
-
- @classmethod
- def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
- return WindowSpec().orderBy(*cols)
-
- @classmethod
- def rowsBetween(cls, start: int, end: int) -> WindowSpec:
- return WindowSpec().rowsBetween(start, end)
-
- @classmethod
- def rangeBetween(cls, start: int, end: int) -> WindowSpec:
- return WindowSpec().rangeBetween(start, end)
-
-
-class WindowSpec:
- def __init__(self, expression: exp.Expression = exp.Window()):
- self.expression = expression
-
- def copy(self):
- return WindowSpec(self.expression.copy())
-
- def sql(self, **kwargs) -> str:
- from sqlglot.dataframe.sql.session import SparkSession
-
- return self.expression.sql(dialect=SparkSession().dialect, **kwargs)
-
- def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
- from sqlglot.dataframe.sql.column import Column
-
- cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
- expressions = [Column.ensure_col(x).expression for x in cols]
- window_spec = self.copy()
- partition_by_expressions = window_spec.expression.args.get("partition_by", [])
- partition_by_expressions.extend(expressions)
- window_spec.expression.set("partition_by", partition_by_expressions)
- return window_spec
-
- def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
- from sqlglot.dataframe.sql.column import Column
-
- cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
- expressions = [Column.ensure_col(x).expression for x in cols]
- window_spec = self.copy()
- if window_spec.expression.args.get("order") is None:
- window_spec.expression.set("order", exp.Order(expressions=[]))
- order_by = window_spec.expression.args["order"].expressions
- order_by.extend(expressions)
- window_spec.expression.args["order"].set("expressions", order_by)
- return window_spec
-
- def _calc_start_end(
- self, start: int, end: int
- ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
- kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
- "start_side": None,
- "end_side": None,
- }
- if start == Window.currentRow:
- kwargs["start"] = "CURRENT ROW"
- else:
- kwargs = {
- **kwargs,
- **{
- "start_side": "PRECEDING",
- "start": (
- "UNBOUNDED"
- if start <= Window.unboundedPreceding
- else F.lit(start).expression
- ),
- },
- }
- if end == Window.currentRow:
- kwargs["end"] = "CURRENT ROW"
- else:
- kwargs = {
- **kwargs,
- **{
- "end_side": "FOLLOWING",
- "end": (
- "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression
- ),
- },
- }
- return kwargs
-
- def rowsBetween(self, start: int, end: int) -> WindowSpec:
- window_spec = self.copy()
- spec = self._calc_start_end(start, end)
- spec["kind"] = "ROWS"
- window_spec.expression.set(
- "spec",
- exp.WindowSpec(
- **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
- ),
- )
- return window_spec
-
- def rangeBetween(self, start: int, end: int) -> WindowSpec:
- window_spec = self.copy()
- spec = self._calc_start_end(start, end)
- spec["kind"] = "RANGE"
- window_spec.expression.set(
- "spec",
- exp.WindowSpec(
- **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
- ),
- )
- return window_spec
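Putting the window pieces together (pre-24.0 API; output shown approximately):

    from sqlglot.dataframe.sql import Window

    spec = (
        Window.partitionBy("dept")
        .orderBy("salary")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    print(spec.sql())
    # roughly: OVER (PARTITION BY dept ORDER BY salary
    #          ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)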