diff options
Diffstat (limited to 'sqlglot/dataframe/sql/column.py')
-rw-r--r-- | sqlglot/dataframe/sql/column.py | 342 |
1 files changed, 0 insertions, 342 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py deleted file mode 100644 index 724c5bf..0000000 --- a/sqlglot/dataframe/sql/column.py +++ /dev/null @@ -1,342 +0,0 @@ -from __future__ import annotations - -import typing as t - -import sqlglot -from sqlglot import expressions as exp -from sqlglot.dataframe.sql.types import DataType -from sqlglot.helper import flatten, is_iterable - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import ColumnOrLiteral - from sqlglot.dataframe.sql.window import WindowSpec - - -class Column: - def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): - from sqlglot.dataframe.sql.session import SparkSession - - if isinstance(expression, Column): - expression = expression.expression # type: ignore - elif expression is None or not isinstance(expression, (str, exp.Expression)): - expression = self._lit(expression).expression # type: ignore - elif not isinstance(expression, exp.Column): - expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform( - SparkSession().dialect.normalize_identifier, copy=False - ) - if expression is None: - raise ValueError(f"Could not parse {expression}") - - self.expression: exp.Expression = expression # type: ignore - - def __repr__(self): - return repr(self.expression) - - def __hash__(self): - return hash(self.expression) - - def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore - return self.binary_op(exp.EQ, other) - - def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore - return self.binary_op(exp.NEQ, other) - - def __gt__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.GT, other) - - def __ge__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.GTE, other) - - def __lt__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.LT, other) - - def __le__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.LTE, other) - - def __and__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.And, other) - - def __or__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Or, other) - - def __mod__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Mod, other) - - def __add__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Add, other) - - def __sub__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Sub, other) - - def __mul__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Mul, other) - - def __truediv__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Div, other) - - def __div__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Div, other) - - def __neg__(self) -> Column: - return self.unary_op(exp.Neg) - - def __radd__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Add, other) - - def __rsub__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Sub, other) - - def __rmul__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Mul, other) - - def __rdiv__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Div, other) - - def __rtruediv__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Div, other) - - def __rmod__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Mod, other) - - def __pow__(self, power: ColumnOrLiteral, modulo=None): - return Column(exp.Pow(this=self.expression, expression=Column(power).expression)) - - def __rpow__(self, power: ColumnOrLiteral): - return Column(exp.Pow(this=Column(power).expression, expression=self.expression)) - - def __invert__(self): - return self.unary_op(exp.Not) - - def __rand__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.And, other) - - def __ror__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Or, other) - - @classmethod - def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column: - return cls(value) - - @classmethod - def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]: - return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args] - - @classmethod - def _lit(cls, value: ColumnOrLiteral) -> Column: - if isinstance(value, dict): - columns = [cls._lit(v).alias(k).expression for k, v in value.items()] - return cls(exp.Struct(expressions=columns)) - return cls(exp.convert(value)) - - @classmethod - def invoke_anonymous_function( - cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral] - ) -> Column: - columns = [] if column is None else [cls.ensure_col(column)] - column_args = [cls.ensure_col(arg) for arg in args] - expressions = [x.expression for x in columns + column_args] - new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions) - return Column(new_expression) - - @classmethod - def invoke_expression_over_column( - cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs - ) -> Column: - ensured_column = None if column is None else cls.ensure_col(column) - ensure_expression_values = { - k: ( - [Column.ensure_col(x).expression for x in v] - if is_iterable(v) - else Column.ensure_col(v).expression - ) - for k, v in kwargs.items() - if v is not None - } - new_expression = ( - callable_expression(**ensure_expression_values) - if ensured_column is None - else callable_expression( - this=ensured_column.column_expression, **ensure_expression_values - ) - ) - return Column(new_expression) - - def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: - return Column( - klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs) - ) - - def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: - return Column( - klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs) - ) - - def unary_op(self, klass: t.Callable, **kwargs) -> Column: - return Column(klass(this=self.column_expression, **kwargs)) - - @property - def is_alias(self): - return isinstance(self.expression, exp.Alias) - - @property - def is_column(self): - return isinstance(self.expression, exp.Column) - - @property - def column_expression(self) -> t.Union[exp.Column, exp.Literal]: - return self.expression.unalias() - - @property - def alias_or_name(self) -> str: - return self.expression.alias_or_name - - @classmethod - def ensure_literal(cls, value) -> Column: - from sqlglot.dataframe.sql.functions import lit - - if isinstance(value, cls): - value = value.expression - if not isinstance(value, exp.Literal): - return lit(value) - return Column(value) - - def copy(self) -> Column: - return Column(self.expression.copy()) - - def set_table_name(self, table_name: str, copy=False) -> Column: - expression = self.expression.copy() if copy else self.expression - expression.set("table", exp.to_identifier(table_name)) - return Column(expression) - - def sql(self, **kwargs) -> str: - from sqlglot.dataframe.sql.session import SparkSession - - return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) - - def alias(self, name: str) -> Column: - from sqlglot.dataframe.sql.session import SparkSession - - dialect = SparkSession().dialect - alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect) - new_expression = exp.alias_( - self.column_expression, - alias.this if isinstance(alias, exp.Column) else name, - dialect=dialect, - ) - return Column(new_expression) - - def asc(self) -> Column: - new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True) - return Column(new_expression) - - def desc(self) -> Column: - new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False) - return Column(new_expression) - - asc_nulls_first = asc - - def asc_nulls_last(self) -> Column: - new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False) - return Column(new_expression) - - def desc_nulls_first(self) -> Column: - new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True) - return Column(new_expression) - - desc_nulls_last = desc - - def when(self, condition: Column, value: t.Any) -> Column: - from sqlglot.dataframe.sql.functions import when - - column_with_if = when(condition, value) - if not isinstance(self.expression, exp.Case): - return column_with_if - new_column = self.copy() - new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"]) - return new_column - - def otherwise(self, value: t.Any) -> Column: - from sqlglot.dataframe.sql.functions import lit - - true_value = value if isinstance(value, Column) else lit(value) - new_column = self.copy() - new_column.expression.set("default", true_value.column_expression) - return new_column - - def isNull(self) -> Column: - new_expression = exp.Is(this=self.column_expression, expression=exp.Null()) - return Column(new_expression) - - def isNotNull(self) -> Column: - new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null())) - return Column(new_expression) - - def cast(self, dataType: t.Union[str, DataType]) -> Column: - """ - Functionality Difference: PySpark cast accepts a datatype instance of the datatype class - Sqlglot doesn't currently replicate this class so it only accepts a string - """ - from sqlglot.dataframe.sql.session import SparkSession - - if isinstance(dataType, DataType): - dataType = dataType.simpleString() - return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect)) - - def startswith(self, value: t.Union[str, Column]) -> Column: - value = self._lit(value) if not isinstance(value, Column) else value - return self.invoke_anonymous_function(self, "STARTSWITH", value) - - def endswith(self, value: t.Union[str, Column]) -> Column: - value = self._lit(value) if not isinstance(value, Column) else value - return self.invoke_anonymous_function(self, "ENDSWITH", value) - - def rlike(self, regexp: str) -> Column: - return self.invoke_expression_over_column( - column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression - ) - - def like(self, other: str): - return self.invoke_expression_over_column( - self, exp.Like, expression=self._lit(other).expression - ) - - def ilike(self, other: str): - return self.invoke_expression_over_column( - self, exp.ILike, expression=self._lit(other).expression - ) - - def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column: - startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos - length = self._lit(length) if not isinstance(length, Column) else length - return Column.invoke_expression_over_column( - self, exp.Substring, start=startPos.expression, length=length.expression - ) - - def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]): - columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore - expressions = [self._lit(x).expression for x in columns] - return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore - - def between( - self, - lowerBound: t.Union[ColumnOrLiteral], - upperBound: t.Union[ColumnOrLiteral], - ) -> Column: - lower_bound_exp = ( - self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound - ) - upper_bound_exp = ( - self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound - ) - return Column( - exp.Between( - this=self.column_expression, - low=lower_bound_exp.expression, - high=upper_bound_exp.expression, - ) - ) - - def over(self, window: WindowSpec) -> Column: - window_expression = window.expression.copy() - window_expression.set("this", self.column_expression) - return Column(window_expression) |