summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/column.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe/sql/column.py')
-rw-r--r--sqlglot/dataframe/sql/column.py46
1 files changed, 35 insertions, 11 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index e66aaa8..f9e1c5b 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -18,7 +18,11 @@ class Column:
expression = expression.expression # type: ignore
elif expression is None or not isinstance(expression, (str, exp.Expression)):
expression = self._lit(expression).expression # type: ignore
- self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
+
+ expression = sqlglot.maybe_parse(expression, dialect="spark")
+ if expression is None:
+ raise ValueError(f"Could not parse {expression}")
+ self.expression: exp.Expression = expression
def __repr__(self):
return repr(self.expression)
@@ -135,21 +139,29 @@ class Column:
) -> Column:
ensured_column = None if column is None else cls.ensure_col(column)
ensure_expression_values = {
- k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression
+ k: [Column.ensure_col(x).expression for x in v]
+ if is_iterable(v)
+ else Column.ensure_col(v).expression
for k, v in kwargs.items()
}
new_expression = (
callable_expression(**ensure_expression_values)
if ensured_column is None
- else callable_expression(this=ensured_column.column_expression, **ensure_expression_values)
+ else callable_expression(
+ this=ensured_column.column_expression, **ensure_expression_values
+ )
)
return Column(new_expression)
def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
- return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
+ return Column(
+ klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
+ )
def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
- return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
+ return Column(
+ klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
+ )
def unary_op(self, klass: t.Callable, **kwargs) -> Column:
return Column(klass(this=self.column_expression, **kwargs))
@@ -188,7 +200,7 @@ class Column:
expression.set("table", exp.to_identifier(table_name))
return Column(expression)
- def sql(self, **kwargs) -> Column:
+ def sql(self, **kwargs) -> str:
return self.expression.sql(**{"dialect": "spark", **kwargs})
def alias(self, name: str) -> Column:
@@ -265,10 +277,14 @@ class Column:
)
def like(self, other: str):
- return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
+ return self.invoke_expression_over_column(
+ self, exp.Like, expression=self._lit(other).expression
+ )
def ilike(self, other: str):
- return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
+ return self.invoke_expression_over_column(
+ self, exp.ILike, expression=self._lit(other).expression
+ )
def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
@@ -287,10 +303,18 @@ class Column:
lowerBound: t.Union[ColumnOrLiteral],
upperBound: t.Union[ColumnOrLiteral],
) -> Column:
- lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
- upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+ lower_bound_exp = (
+ self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
+ )
+ upper_bound_exp = (
+ self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+ )
return Column(
- exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
+ exp.Between(
+ this=self.column_expression,
+ low=lower_bound_exp.expression,
+ high=upper_bound_exp.expression,
+ )
)
def over(self, window: WindowSpec) -> Column: