diff options
Diffstat (limited to 'sqlglot/dataframe/sql/column.py')
-rw-r--r-- | sqlglot/dataframe/sql/column.py | 295 |
1 files changed, 295 insertions, 0 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py new file mode 100644 index 0000000..2391080 --- /dev/null +++ b/sqlglot/dataframe/sql/column.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import typing as t + +import sqlglot +from sqlglot import expressions as exp +from sqlglot.dataframe.sql.types import DataType +from sqlglot.helper import flatten + +if t.TYPE_CHECKING: + from sqlglot.dataframe.sql._typing import ColumnOrLiteral + from sqlglot.dataframe.sql.window import WindowSpec + + +class Column: + def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): + if isinstance(expression, Column): + expression = expression.expression # type: ignore + elif expression is None or not isinstance(expression, (str, exp.Expression)): + expression = self._lit(expression).expression # type: ignore + self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark") + + def __repr__(self): + return repr(self.expression) + + def __hash__(self): + return hash(self.expression) + + def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore + return self.binary_op(exp.EQ, other) + + def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore + return self.binary_op(exp.NEQ, other) + + def __gt__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.GT, other) + + def __ge__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.GTE, other) + + def __lt__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.LT, other) + + def __le__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.LTE, other) + + def __and__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.And, other) + + def __or__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Or, other) + + def __mod__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Mod, other) + + def __add__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Add, other) + + def __sub__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Sub, other) + + def __mul__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Mul, other) + + def __truediv__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Div, other) + + def __div__(self, other: ColumnOrLiteral) -> Column: + return self.binary_op(exp.Div, other) + + def __neg__(self) -> Column: + return self.unary_op(exp.Neg) + + def __radd__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Add, other) + + def __rsub__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Sub, other) + + def __rmul__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Mul, other) + + def __rdiv__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Div, other) + + def __rtruediv__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Div, other) + + def __rmod__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Mod, other) + + def __pow__(self, power: ColumnOrLiteral, modulo=None): + return Column(exp.Pow(this=self.expression, power=Column(power).expression)) + + def __rpow__(self, power: ColumnOrLiteral): + return Column(exp.Pow(this=Column(power).expression, power=self.expression)) + + def __invert__(self): + return self.unary_op(exp.Not) + + def __rand__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.And, other) + + def __ror__(self, other: ColumnOrLiteral) -> Column: + return self.inverse_binary_op(exp.Or, other) + + @classmethod + def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): + return cls(value) + + @classmethod + def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]: + return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args] + + @classmethod + def _lit(cls, value: ColumnOrLiteral) -> Column: + if isinstance(value, dict): + columns = [cls._lit(v).alias(k).expression for k, v in value.items()] + return cls(exp.Struct(expressions=columns)) + return cls(exp.convert(value)) + + @classmethod + def invoke_anonymous_function( + cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral] + ) -> Column: + columns = [] if column is None else [cls.ensure_col(column)] + column_args = [cls.ensure_col(arg) for arg in args] + expressions = [x.expression for x in columns + column_args] + new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions) + return Column(new_expression) + + @classmethod + def invoke_expression_over_column( + cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs + ) -> Column: + ensured_column = None if column is None else cls.ensure_col(column) + new_expression = ( + callable_expression(**kwargs) + if ensured_column is None + else callable_expression(this=ensured_column.column_expression, **kwargs) + ) + return Column(new_expression) + + def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: + return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)) + + def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: + return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)) + + def unary_op(self, klass: t.Callable, **kwargs) -> Column: + return Column(klass(this=self.column_expression, **kwargs)) + + @property + def is_alias(self): + return isinstance(self.expression, exp.Alias) + + @property + def is_column(self): + return isinstance(self.expression, exp.Column) + + @property + def column_expression(self) -> exp.Column: + return self.expression.unalias() + + @property + def alias_or_name(self) -> str: + return self.expression.alias_or_name + + @classmethod + def ensure_literal(cls, value) -> Column: + from sqlglot.dataframe.sql.functions import lit + + if isinstance(value, cls): + value = value.expression + if not isinstance(value, exp.Literal): + return lit(value) + return Column(value) + + def copy(self) -> Column: + return Column(self.expression.copy()) + + def set_table_name(self, table_name: str, copy=False) -> Column: + expression = self.expression.copy() if copy else self.expression + expression.set("table", exp.to_identifier(table_name)) + return Column(expression) + + def sql(self, **kwargs) -> Column: + return self.expression.sql(**{"dialect": "spark", **kwargs}) + + def alias(self, name: str) -> Column: + new_expression = exp.alias_(self.column_expression, name) + return Column(new_expression) + + def asc(self) -> Column: + new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True) + return Column(new_expression) + + def desc(self) -> Column: + new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False) + return Column(new_expression) + + asc_nulls_first = asc + + def asc_nulls_last(self) -> Column: + new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False) + return Column(new_expression) + + def desc_nulls_first(self) -> Column: + new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True) + return Column(new_expression) + + desc_nulls_last = desc + + def when(self, condition: Column, value: t.Any) -> Column: + from sqlglot.dataframe.sql.functions import when + + column_with_if = when(condition, value) + if not isinstance(self.expression, exp.Case): + return column_with_if + new_column = self.copy() + new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"]) + return new_column + + def otherwise(self, value: t.Any) -> Column: + from sqlglot.dataframe.sql.functions import lit + + true_value = value if isinstance(value, Column) else lit(value) + new_column = self.copy() + new_column.expression.set("default", true_value.column_expression) + return new_column + + def isNull(self) -> Column: + new_expression = exp.Is(this=self.column_expression, expression=exp.Null()) + return Column(new_expression) + + def isNotNull(self) -> Column: + new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null())) + return Column(new_expression) + + def cast(self, dataType: t.Union[str, DataType]): + """ + Functionality Difference: PySpark cast accepts a datatype instance of the datatype class + Sqlglot doesn't currently replicate this class so it only accepts a string + """ + if isinstance(dataType, DataType): + dataType = dataType.simpleString() + new_expression = exp.Cast(this=self.column_expression, to=dataType) + return Column(new_expression) + + def startswith(self, value: t.Union[str, Column]) -> Column: + value = self._lit(value) if not isinstance(value, Column) else value + return self.invoke_anonymous_function(self, "STARTSWITH", value) + + def endswith(self, value: t.Union[str, Column]) -> Column: + value = self._lit(value) if not isinstance(value, Column) else value + return self.invoke_anonymous_function(self, "ENDSWITH", value) + + def rlike(self, regexp: str) -> Column: + return self.invoke_expression_over_column( + column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression + ) + + def like(self, other: str): + return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression) + + def ilike(self, other: str): + return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression) + + def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column: + startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos + length = self._lit(length) if not isinstance(length, Column) else length + return Column.invoke_expression_over_column( + self, exp.Substring, start=startPos.expression, length=length.expression + ) + + def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]): + columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore + expressions = [self._lit(x).expression for x in columns] + return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore + + def between( + self, + lowerBound: t.Union[ColumnOrLiteral], + upperBound: t.Union[ColumnOrLiteral], + ) -> Column: + lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound + upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound + return Column( + exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression) + ) + + def over(self, window: WindowSpec) -> Column: + window_expression = window.expression.copy() + window_expression.set("this", self.column_expression) + return Column(window_expression) |