From d1f00706bff58b863b0a1c5bf4adf39d36049d4c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 11 Nov 2022 09:54:35 +0100 Subject: Merging upstream version 10.0.1. Signed-off-by: Daniel Baumann --- sqlglot/dataframe/sql/column.py | 46 +++++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 11 deletions(-) (limited to 'sqlglot/dataframe/sql/column.py') diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index e66aaa8..f9e1c5b 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -18,7 +18,11 @@ class Column: expression = expression.expression # type: ignore elif expression is None or not isinstance(expression, (str, exp.Expression)): expression = self._lit(expression).expression # type: ignore - self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark") + + expression = sqlglot.maybe_parse(expression, dialect="spark") + if expression is None: + raise ValueError(f"Could not parse {expression}") + self.expression: exp.Expression = expression def __repr__(self): return repr(self.expression) @@ -135,21 +139,29 @@ class Column: ) -> Column: ensured_column = None if column is None else cls.ensure_col(column) ensure_expression_values = { - k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression + k: [Column.ensure_col(x).expression for x in v] + if is_iterable(v) + else Column.ensure_col(v).expression for k, v in kwargs.items() } new_expression = ( callable_expression(**ensure_expression_values) if ensured_column is None - else callable_expression(this=ensured_column.column_expression, **ensure_expression_values) + else callable_expression( + this=ensured_column.column_expression, **ensure_expression_values + ) ) return Column(new_expression) def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: - return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)) + return Column( + klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs) + ) def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: - return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)) + return Column( + klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs) + ) def unary_op(self, klass: t.Callable, **kwargs) -> Column: return Column(klass(this=self.column_expression, **kwargs)) @@ -188,7 +200,7 @@ class Column: expression.set("table", exp.to_identifier(table_name)) return Column(expression) - def sql(self, **kwargs) -> Column: + def sql(self, **kwargs) -> str: return self.expression.sql(**{"dialect": "spark", **kwargs}) def alias(self, name: str) -> Column: @@ -265,10 +277,14 @@ class Column: ) def like(self, other: str): - return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression) + return self.invoke_expression_over_column( + self, exp.Like, expression=self._lit(other).expression + ) def ilike(self, other: str): - return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression) + return self.invoke_expression_over_column( + self, exp.ILike, expression=self._lit(other).expression + ) def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column: startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos @@ -287,10 +303,18 @@ class Column: lowerBound: t.Union[ColumnOrLiteral], upperBound: t.Union[ColumnOrLiteral], ) -> Column: - lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound - upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound + lower_bound_exp = ( + self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound + ) + upper_bound_exp = ( + self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound + ) return Column( - exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression) + exp.Between( + this=self.column_expression, + low=lower_bound_exp.expression, + high=upper_bound_exp.expression, + ) ) def over(self, window: WindowSpec) -> Column: -- cgit v1.2.3